3d surface plot not showing with Customize Colorscale

Hello,

I’m trying to construct the 3d surface plot with heatmap that can toggle through customize button.

However, when I tried to construct the customize colorscale, 3d surface plot is not expressed.
(The heatmap is correctly showing though)

I’ve attached the code snippet below, it would be really great if this can resolved.

def create_plot_3d():

    # fig_combined = go.Figure()
    fig = go.Figure()
     
    
    # 3D plot (error)
    traces_3d_per_barrier = len(energy_maker)
    traces = []; colorScaleList, colorScaleList_2 = [], []

    # for eB_, eaBarrier in enumerate(energy_barrierList):
    for eaBarrier in energy_barrierList:
        
        combined_x, combined_y, combined_z = [], [], []
        # fig_3d_temp = create_3d_plot(eaBarrier)
        
        for sliderE in energy_maker:
            x_data = data[eaBarrier +  f" Unscaled Action Exact {sliderE}"]
            z_data = data[eaBarrier +  f" Error {sliderE}"]
            
            combined_x.extend(x_data)
            combined_y.extend(np.repeat(sliderE, len(x_data)))
            combined_z.extend(z_data)
        
            # fig.add_trace(go.Scatter3d(
            #     x=x_data,
            #     y=np.repeat(sliderE, len(x_data)),
            #     z=z_data,
            #     mode='lines',
            #     line=dict(color='green', width=4),
            #     showlegend=False,
            #     opacity=0.4,
            #     hovertemplate='Tunneling Action: %{x}<br>Scaled Energy: %{y}<br>Error: %{z}<extra></extra>'
            # ))

        # scatter_trace = go.Scatter3d(
        #             x=combined_x,
        #             y=combined_y,
        #             z=combined_z,
        #             mode='markers',
        #             # line=dict(color='green', width=4),
        #             showlegend=False,
        #             opacity=0.4,
        #             marker=dict(color='green', size=1.2),
        #             hovertemplate='Tunneling Action: %{x}<br>Scaled Energy: %{y}<br>Error: %{z}<extra></extra>'
        #         )

        reshape_x = energy_maker
        reshape_y = data[eaBarrier +  " Unscaled Action Exact 0.01"]
        reshape_z = np.asarray(combined_z).reshape((len(reshape_x), len(reshape_y)))

        # Color range settings
        sort_z = np.sort(reshape_z, axis=None)
        normalize_sortZ = (sort_z - np.min(sort_z))/ (np.max(sort_z) - np.min(sort_z))
        where_sortZ = np.where(sort_z > 0)[0]; where2_sortZ = np.where((sort_z >= -2) & (sort_z <= 1))[0]
        
        # position_value = np.linspace(-2,1,num = len(viridis_list)) # Array of linear space to output the color code
        
        colorscale2 = []; new_viridis_list = [] # Colorscale, color code list
        if (np.max(sort_z) <= 1) and (np.min(sort_z) < -2):  # Maximum value <= 1
            print("(,1)")
            pos_index_max = np.where(-2 > sort_z)[0][-1]
            
            for min_max in np.arange(pos_index_max + 1, sort_z.shape[0]):
                (r,g,b,_) = cmap_color(sort_z[min_max])
                new_viridis_list.append(f"rgb({int(r * 255)}, {int(g * 255)}, {int(b * 255)})" )
            
            # for _llr in np.arange(pos_index_max + 1): # value < -2
            colorscale2.append([.0, new_viridis_list[0]])
                
            for l_r, llr_2 in enumerate(np.arange(pos_index_max + 1, sort_z.shape[0])): # value >= -2
                colorscale2.append([normalize_sortZ[llr_2], new_viridis_list[l_r]])

        elif (np.max(sort_z) > 1) and (np.min(sort_z) >= -2): # Minimum Value > -2
            print("[-2, )")
            pos_index_max = np.where(1 < sort_z)[0][0]
            pos_index_min2 = np.where(position_value >= np.min(sort_z))[0][0]

            for min_max in np.arange(pos_index_max - 1):
                (r,g,b,_) = cmap_color(sort_z[min_max])
                new_viridis_list.append(f"rgb({int(r * 255)}, {int(g * 255)}, {int(b * 255)})" )
            
            for _llr in np.arange(pos_index_max - 1):
                colorscale2.append([normalize_sortZ[_llr], viridis_list[_llr]])
            colorscale2.append([1.0, viridis_list[-1]])
            # for llr_2 in np.arange(pos_index_max - 1, sort_z.shape[0]):
            #     colorscale2.append([normalize_sortZ[llr_2], viridis_list[-1]])
        
        elif (np.max(sort_z) > 1) and (np.min(sort_z) < -2): # sort_z contains the [-2,1]
            print("Outside")
            pos_index_min = np.where(-2 > sort_z)[0][-1]
            pos_index_max = np.where(1 < sort_z)[0][0]
            # new_linspace = normalize_sortZ[pos_index_min + 1:pos_index_max]
            new_viridis_list = []
            for min_max in np.arange(pos_index_min, pos_index_max):
                (r,g,b,_) = cmap_color(sort_z[min_max])
                new_viridis_list.append(f"rgb({int(r * 255)}, {int(g * 255)}, {int(b * 255)})" )

            # for low_ in np.arange(pos_index_min):
            #     colorscale2.append([normalize_sortZ[low_], new_viridis_list[0]])
            colorscale2.append([.0, new_viridis_list[0]])
            for m_, middle_ in enumerate(np.arange(pos_index_min, pos_index_max)):
                # print(test_list[m_], sort_z[middle_],(normalize_sortZ[middle_], new_viridis_list[m_]))
                colorscale2.append([normalize_sortZ[middle_], new_viridis_list[m_]])
            # for high_ in np.arange(pos_index_max, sort_z.shape[0]):
            #     colorscale2.append([normalize_sortZ[high_], new_viridis_list[-1]])
            colorscale2.append([1.0, new_viridis_list[-1]])
        
        
        if len(where_sortZ)==0:
            colorscale = [(.0, "blue"),(1.0, "blue")]
            
        elif len(where_sortZ) != 0:
            index_0 = where_sortZ[0]
            colorscale = [(.0,"darkblue"),(normalize_sortZ[index_0-1], "blue"),(normalize_sortZ[index_0], "red"), (1.0,"darkred")]
            
        # print(colorscale)
        colorScaleList.append(colorscale); colorScaleList_2.append(colorscale2)
        # print(colorscale2)
        fig.add_trace(go.Surface(
        x=reshape_x,
        y=reshape_y,
        z=reshape_z,
        colorscale=colorscale2,
        visible=False
        ))
        # traces.extend([scatter_trace, surface_plot])
        
        # fig.add_trace()
        # for trace in fig_3d_temp.data:
            # fig_combined.add_trace(trace)

        # fig_combined.add_trace(trace)
        
        # Set the traces for the current barrier to True
        # visible_list = [False] * traces_3d_per_barrier * 4
        # start_index = eB_ * traces_3d_per_barrier
        # end_index = (eB_ + 1) * traces_3d_per_barrier
        # visible_list[start_index:end_index] = [True] * traces_3d_per_barrier

        # button_true_false.append(visible_list)
        # fig.update_traces(surface_plot)

    fig.update_scenes(
        xaxis_title=dict(text="Tunneling Action (θ₀)", font=dict(size=15)),
        yaxis_title=dict(text="Scaled Energy (ξ)", font=dict(size=15)),
        zaxis_title=dict(text="Tunneling Probability (error) (log | (P̃ - Pex) / Pex |)", font=dict(size=15)),
        xaxis=dict(tickmode='linear', tick0=0, dtick=1), #Major ticks
        yaxis=dict(type='linear'),
        zaxis=dict(type='linear'),)
    
    fig.update_layout(
        width=1100, height=800,
        template = 'plotly_white',
        updatemenus=[
        dict( # Toggle between the energy barrier
        buttons=
            list([
                dict(
                    args=[{"visible": [True if i == j else False for j in range(4)]},
                         {"Title":energy_barrierList[i]}],
                    label = energy_barrierList[i],
                    method="update",
                    ) for i in range(4)
            ]),
        direction="down",
        type="buttons",
        pad={"r": 5, "t": 5},
        showactive=True,
        x=1.4,
        xanchor="right",
        y=.0,
        yanchor="top",
        ),

        dict( # Surface / scatter3d
        buttons=
            list([
                dict(
                        args=["type", "surface"],
                        label = "3D Surface",
                        method="restyle",
                    ),
                 dict(
                        args=["type", "heatmap"],
                        label = "Heat Map",
                        method="restyle",
                    ),
            ]),
        direction="down",
        type="buttons",
        pad={"r": 5, "t": 5},
        showactive=True,
        x=1.2,
        xanchor="right",
        y=.0,
        yanchor="top",),

        dict(
            buttons=list([
                dict(
                    args=[{"contours.showlines": False, "type": "contour"}],
                    label="Hide lines",
                    method="restyle"
                ),
                dict(
                    args=[{"contours.showlines": True, "type": "contour"}],
                    label="Show lines",
                    method="restyle"
                ),
            ]),
            type = "buttons",
            direction="down",
            pad={"r": 5, "t": 5},
            showactive=True,
            x=0.9,
            xanchor="right",
            y=-.2,
            yanchor="top"
        ),

        dict( # Colorscale
            buttons=
            list([
                dict(
                    args=[{
                        "colorscale":colorScaleList,
                        "coloraxis_colorbar.tickvals":[-1,0,1],
                        "coloraxis_colorbar.ticktext":["Err < 0", "0", "Err > 0"],
                        "coloraxis_colorbar.title":"Error",
                          }],
                    label = "Binary",
                    method="update",),
                # dict(
                #     args=["colorscale", ['Viridis'] * 4],
                #     label = "Viridis",
                #     method="restyle",),
                dict(
                  args=[{
                        "colorscale":colorScaleList_2,
                        "coloraxis_colorbar.tickvals":[-2,-1,0,1],
                        "coloraxis_colorbar.ticktext":["Err. < -2", "-1",  "0", "Err > 1"],
                        "coloraxis_colorbar.title":"Error",
                  }],
                    label = "Viridis(-2 < Err < 1)",
                    method="update",),
            ]),
        direction="down",
        type="buttons",
        pad={"r": 5, "t": 5},
        showactive=True,
        x=1.7,
        xanchor="right",
        y=.0,
        yanchor="top",
        ),

            
        ],
        scene=dict(
        camera=dict(projection=dict(type="orthographic"))),
        margin=dict(l=10, r=10, b=10, t=10),
        modebar_add=[
        "v1hovermode","toggleSpikelines",
        ],
        xaxis_title="Scaled Energy (ξ)",
        yaxis_title="Tunneling Action (θ₀)",
    )


    # for mm in fig.layout.updatemenus:
    #     for button in mm.buttons:
    #         print(f"Button: {button.label}, Args: {button.args}")
    
    return fig