Improving performance on 3D surface plot with multiple traces

I am using Plotly to display a 3D surface plot for one in a sequence of images. I want to display some extra information, in the form of vertical planes in specific positions, on this plot. I also want to change the image being displayed as the 3D surface (based on a slider). This is being done in a Jupyter notebook.

I have all of this working, but the update of the surface is quite slow when the slider changes. I have determined that it is much slower when the vertical planes are being displayed. To some extent I understand this, as the scene is more complicated to draw. However, it seems that the amount of slowdown is excessive. Rough measurement shows that it takes more than 2x the time to update the scene when the planes are visible.

I am providing a simplified version of my code that reproduces this issue. Is there a better way to do what Iā€™m trying to do ā€“ such that it will update faster? I did try to use the frames property to do the animation, but FigureWidget does not support this, and I need to be able to get event notifications on the Python side, etc. I also had some problems with the animation not working correctly (blanking out an associated 2D plot when a second 2D plot was being animated).

Code:

import numpy as np
import cv2
import ipywidgets as ipw
import plotly.express as px
import plotly.graph_objects as go

from skimage import io
from IPython.display import clear_output
from plotly.subplots import make_subplots

def add_3d_shapes(fig, initially_visible, z1=0, z2=1024, rownum=1, colnum=1):
    for i in range(8):
        bx, by, bw, bh = [50 + 15*i, 50 + 15*i, 100, 100]
        x1 = bx
        x2 = bx + bw
        y1 = by
        y2 = by + bh
        single_color=[[0.0, 'red'], [1.0, 'red']]

        def add_plane(x, y, z):
            plane=dict(type='surface', x=x, y=y, z=z, colorscale=single_color, showscale=False, 
                       visible=initially_visible, name='plane')
            fig.add_trace(plane, row=rownum, col=colnum)

        x,z=np.meshgrid(np.linspace(x1, x2, 2), np.linspace(z1, z2, 2))
        y = y1*np.ones(x.shape)
        add_plane(x, y, z)
        y = y2*np.ones(x.shape)
        add_plane(x, y, z)

        y,z=np.meshgrid(np.linspace(y1, y2, 2), np.linspace(z1, z2, 2))
        x = x1*np.ones(y.shape)
        add_plane(x, y, z)
        x = x2*np.ones(y.shape)
        add_plane(x, y, z)

def on_img_num_change(imglist, titlelist, change):
    imgnum = int(change['new'])
    with fig.batch_update():
        fig.layout.annotations[0].update(text=titlelist[imgnum])
        fig.update_traces(selector=dict(name="imgdata"), z=imglist[imgnum])

def on_add3d_checkbox_change(change):
    with fig.batch_update():
        fig.update_traces(selector=dict(name="plane"), visible=change.new)

def draw_figure(draw3d_shapes):
    global fig
    
    specs = [[{'l': .01, 'r': .01, 'b': .01, 't': .01, 'is_3d': True}]]
    fig = make_subplots(rows=1, cols=1, horizontal_spacing=0.0, specs=specs,
                        vertical_spacing=0.0, subplot_titles=[titlelist[0]])

    img = imglist[0]

    trace = {}
    trace['type'] = 'surface'
    trace['z'] = img.astype(np.uint32)
    trace['colorscale'] = 'viridis'
    trace['name'] = 'imgdata'
    fig.add_trace(trace, row=1, col=1)

    fig.update_layout(width=350, height=205, margin=dict(l=5, r=25, t=40, b=5))
    fig.for_each_annotation(lambda a: a.update(font=dict(size=10)))

    fig.update_traces(showscale=False)
    if img.shape[0] != img.shape[1]:
        fig.update_scenes(aspectmode='manual',
                          aspectratio=go.layout.scene.Aspectratio(
                              x=1, y=img.shape[0]/img.shape[1], z=imgmax / (img.shape[0]*img.shape[1]))
                          )

    fig.update_scenes(xaxis_visible=False)
    fig.update_scenes(yaxis_visible=False)
    fig.update_scenes(zaxis_title="Intensity", zaxis_range=[0, imgmax])
    fig.update_scenes(xaxis_autorange="reversed")

    add_3d_shapes(fig, draw3d_shapes)

    fig = go.FigureWidget(fig)

    num_slider = ipw.widgets.IntSlider(
        value=0,
        min=0,
        max=len(imglist)-1,
        step=1,
        description='Image #',
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='d'
    )
    num_slider.observe(lambda change,imglist=imglist,titlelist=titlelist: on_img_num_change(imglist, titlelist, change), names='value')

    fig_hbox = ipw.HBox([fig], layout=ipw.Layout(display='flex', flex_flow='row',
                                                 justify_content='center'))
    img_num_control_hbox = ipw.HBox([num_slider], layout=ipw.Layout(display='flex', flex_flow='row',
                                                                    justify_content='center'))
    fig_vbox = ipw.VBox([fig_hbox, img_num_control_hbox])
    
    display(fig_vbox)

add3d_checkbox = ipw.widgets.Checkbox(False, description="Add 3D shapes", indent=False,
                                      layout=ipw.Layout(width='140px'))
add3d_checkbox.observe(on_add3d_checkbox_change, 'value')
display(add3d_checkbox)

output_cell = ipw.Output(layout={'border': '0px solid black'})
display(output_cell)

imgmax = 4000

imglist = io.imread("https://github.com/scikit-image/skimage-tutorials/raw/main/images/cells.tif")
titlelist = [f"image{n}" for n in range(len(imglist))]

with output_cell:
    draw_figure(add3d_checkbox.value)