import plotly.offline as pyo
import plotly.graph_objs as go
from plotly.offline import iplot
pyo.init_notebook_mode()
import numpy as np
# 1-D sample points along each axis, split at 0 so that each sign
# combination of (x, y) gets its own grid and its own surface formula.
x_neg = np.linspace(-1, 0, 80)
x_pos = np.linspace(0, 1.5, 160)
y_neg = np.linspace(-1, 0, 80)
y_pos = np.linspace(0, 1.5, 160)

# 2-D meshes, one per quadrant; the numeric suffix (1-4) ties each mesh
# to the matching z_* array below.
x_pos_1, y_pos_1 = np.meshgrid(x_pos, y_pos)  # x >= 0, y >= 0
x_neg_2, y_pos_2 = np.meshgrid(x_neg, y_pos)  # x <= 0, y >= 0
x_pos_3, y_neg_3 = np.meshgrid(x_pos, y_neg)  # x >= 0, y <= 0
x_neg_4, y_neg_4 = np.meshgrid(x_neg, y_neg)  # x <= 0, y <= 0

# Surface heights, evaluated with a different expression on each quadrant.
# NOTE: this section previously assigned np.exp(-n * <expression>) to the
# same four names first. `n` is not defined anywhere in the visible script
# (NameError) and the results were overwritten immediately by the lines
# below, so those assignments were removed.
z_1 = (x_pos_1 + y_pos_1 - 1) ** 2
z_2 = x_neg_2 ** 2 + (y_pos_2 - 1) ** 2
z_3 = (x_pos_3 - 1) ** 2 + y_neg_3 ** 2
z_4 = (x_neg_4 + y_neg_4) ** 2 + 1
# Colour limits shared by all four surfaces so that equal heights get the
# same colour regardless of which quadrant they belong to.
min_value = 0
max_value = 0.5
color_limits = {"cmin": min_value, "cmax": max_value}

# The first-quadrant surface leaves showscale and opacity unset; the other
# three explicitly hide their colour scale and use opacity 0.9.
primary_surface = go.Surface(x=x_pos_1, y=y_pos_1, z=z_1, **color_limits)
secondary_surfaces = [
    go.Surface(
        x=grid_x,
        y=grid_y,
        z=grid_z,
        showscale=False,
        opacity=0.9,
        **color_limits,
    )
    for grid_x, grid_y, grid_z in (
        (x_neg_2, y_pos_2, z_2),
        (x_pos_3, y_neg_3, z_3),
        (x_neg_4, y_neg_4, z_4),
    )
]
fig = go.Figure(data=[primary_surface, *secondary_surfaces])

# 3-D scene configuration: fixed axis ranges, a cube aspect ratio, and
# axis titles for the two inputs (w_0, w_1) and the plotted quantity K(w).
scene_settings = {
    "zaxis": {"nticks": 4, "range": [0, 0.5]},
    "xaxis": {"range": [-0.5, 1.5]},
    "yaxis": {"range": [-0.5, 1.5]},
    "aspectmode": "cube",
    "xaxis_title": "w_0",
    "yaxis_title": "w_1",
    "zaxis_title": "K(w)",
}
fig.update_layout(
    scene=scene_settings,
    width=1200,
    height=800,
    margin={"r": 20, "l": 10, "b": 10, "t": 10},
)

# Render the figure inline in the notebook.
iplot(fig, filename="3_K_w_degenerate_node")