Camera settings are not preserved with dropdown and sliders (go.Scatter3d)

Hello,

Is there a way to preserve the camera settings when moving the slider or changing the dropdown?

Currently I set `uirevision` to True, but this only preserves the camera settings of the plot that is first loaded when I run the app.
I can make a video if needed.

Thank you,
Simon

Here is my code:

import os
import pandas as pd

import dash
import dash_table
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output
import datatable as dt
import numpy as np
import plotly.graph_objects as go
from pandas.api.types import is_numeric_dtype

# Create the Dash application object that the layout and callbacks attach to.
app = dash.Dash(__name__)
# Keep a module-level handle on the application's underlying server object
# (presumably so an external WSGI runner can import it -- confirm).
server = app.server


# Load the input tables.

# Expression table: the first column supplies the row labels and is then
# removed from the frame.
correctedAll = pd.read_csv(
    "data/exprDat.normalized.log.corrected.SMALL.tsv",
    sep="\t",
    index_col=False,
)
correctedAll.index = correctedAll.iloc[:, 0].tolist()
correctedAll = correctedAll.iloc[:, 1:]

# 3D UMAP coordinates: rows are labelled by the "samples" column; the first
# row and the "samples" column itself are then dropped.
umap3D = dt.fread("data/umap_coordinates.tsv", sep="\t").to_pandas()
umap3D.columns = ["samples", "D1", "D2", "D3"]
umap3D.index = umap3D["samples"]
umap3D = umap3D.iloc[1:, 1:]

# Sample annotation: rows are labelled by "rownames.dat.", which is the first
# column dropped afterwards (assumes that column comes first -- confirm).
sample_annotation = dt.fread(
    "data/sampleAnnotation.normalized.log.corrected.tsv", sep="\t"
).to_pandas()
sample_annotation.index = sample_annotation["rownames.dat."]
sample_annotation = sample_annotation.iloc[:, 1:]

# One wide table: coordinates, annotations and expression values side by side,
# aligned on the row labels.
data = pd.concat((umap3D, sample_annotation, correctedAll), axis="columns")

# Empty frame with one row per expression-table row and one column per
# annotation that has a palette below.
colors = pd.DataFrame(
    index=correctedAll.index,
    columns=[
        "Dataset",
        "Stirparo.lineage",
        "Author.lineage",
        "Branches",
        "finalClusters",
    ],
)

# Marker colour for every category of each categorical annotation column.
cols = {
    "finalClusters": {
        "TB.early": "#800000",
        "TB.medium2": "#9A6324",
        "TB.medium3": "#808000",
        "TB.medium1": "#469990",
        "EPI": "#e6194B",
        "Pre.ST": "#000075",
        "TB.late": "#000000",
        "Pre.EVT": "#f58231",
        "ST": "#ffe119",
        "EVT": "#bfef45",
        "PrE": "#3cb44b",
        "early_TE": "#42d4f4",
        "medium_TE": "#4363d8",
        "late_TE": "#911eb4",
        "ysTE": "#f032e6",
        "TB.apoptosis": "#a9a9a9",
        "B1_B2": "#fabed4",
        "EightCells": "#ffd8b1",
        "Morula": "#fffac8",
        "Na": "#aaffc3",
    },
    "Dataset": {
        "Xiang": "red",
        "Zhou": "blue",
        "Petropoulos": "green",
    },
    "Stirparo.lineage": {
        "Na": "#800000",
        "undefined": "#000000",
        "TE": "#808000",
        "ICM": "#e6194B",
        "intermediate": "#bfef45",
        "prE": "#469990",
        "EPI": "#9A6324",
    },
    "Author.lineage": {
        "Na": "#800000",
        "EPI": "#9A6324",
        "TE": "#808000",
        "PrE": "#469990",
        "ICM": "#e6194B",
        "CTB": "#000075",
        "STB": "#000000",
        "EVT": "#f58231",
        "PSA_EPI": "#ffe119",
        "9.TE.NR2F2+": "#bfef45",
    },
    "Branches": {
        "Na": "#800000",
        "1.Pre-morula": "#9A6324",
        "2.Morula": "#808000",
        "3.Early blastocyst": "#469990",
        "5.Early trophectoderm": "#e6194B",
        "6.Epiblast": "#000075",
        "4.Inner cell mass": "#000000",
        "7.Primitive endoderm": "#f58231",
        "8.TE.NR2F2-": "#ffe119",
        "9.TE.NR2F2+": "#bfef45",
    },
}

# Page layout: a column selector above the 3D scatter plot.
app.layout = html.Div(
    [
        html.Label("Dropdown"),
        # One entry per column of `data`; the second column is preselected.
        dcc.Dropdown(
            id="dropdown",
            options=[dict(label=col, value=col) for col in data.columns],
            value=data.columns[1],
        ),
        # Placeholder figure shown until the callback supplies the real one.
        dcc.Graph(
            id="scatter-plot",
            figure=dict(
                layout=dict(
                    title="My Dash Graph",
                    height=700,  # px
                ),
            ),
        ),
    ]
)

@app.callback(
    Output("scatter-plot", "figure"),
    [Input("dropdown", "value")])
def update_scatter_plot(dropdown):
    """Build the 3D scatter figure coloured by the selected column.

    One slider step is created per sorted unique value of
    ``sample_annotation.ED``. Step ``k`` shows the rows of ``data`` whose
    ``ED`` is among the first ``k + 1`` values. A numeric column yields one
    trace per step, coloured on the viridis scale; any other column yields one
    trace per category per step, coloured from ``cols`` when a colour is
    defined there and left to Plotly's default colours otherwise.

    Args:
        dropdown: Name of the column of ``data`` used to colour the markers.

    Returns:
        A ``go.Figure`` with the last slider step active and its traces
        visible.
    """
    numeric = is_numeric_dtype(data[dropdown])
    # Columns without a predefined palette fall back to an empty mapping so
    # that selecting them does not raise KeyError.
    palette = cols.get(dropdown, {})
    eds = sorted(sample_annotation.ED.unique())

    fig = go.Figure()
    ed_used = []
    # trace_ends[k] is the exclusive end offset, within fig.data, of the
    # traces that belong to slider step k.
    trace_ends = []

    # Add the traces of every slider step, all hidden for now.
    for ed in eds:
        ed_used.append(ed)
        data_ed = data.loc[data["ED"].isin(ed_used)]

        if numeric:
            fig.add_trace(go.Scatter3d(
                visible=False,
                x=data_ed["D1"], y=data_ed["D2"], z=data_ed["D3"],
                mode="markers",
                marker=dict(size=12, color=data_ed[dropdown],
                            colorscale="viridis",
                            colorbar=dict(thickness=20)),
            ))
        else:
            for category in data_ed[dropdown].unique():
                # Build the mask from data_ed itself so that it always lines
                # up with the rows being filtered.
                data_plot = data_ed.loc[data_ed[dropdown] == category]
                fig.add_trace(go.Scatter3d(
                    visible=False,
                    x=data_plot["D1"], y=data_plot["D2"], z=data_plot["D3"],
                    name=category,
                    mode="markers",
                    marker=dict(size=12, color=palette.get(category)),
                ))
        trace_ends.append(len(fig.data))

    n_traces = len(fig.data)

    # One slider step per ED value, each revealing exactly its own traces.
    steps = []
    start = 0
    for ed, end in zip(eds, trace_ends):
        steps.append(dict(
            method="update",
            args=[{"visible": [start <= j < end for j in range(n_traces)]},
                  {"title": "Slider switched to ED: " + str(ed)}],
            label=str(ed)))
        start = end

    # Show every trace of the last step initially, to match the active step.
    last_start = trace_ends[-2] if len(trace_ends) > 1 else 0
    for trace in fig.data[last_start:]:
        trace.visible = True

    sliders = [dict(
        # `active` is an index into `steps`, not into fig.data.
        active=max(len(steps) - 1, 0),
        currentvalue={"prefix": "ED : "},
        pad={"t": 50},
        steps=steps
    )]

    # Use one shared range for the three axes. Series.min()/max() skip NaN,
    # unlike the built-in min()/max(), whose result depends on where NaN
    # appears in the sequence.
    coords = data[["D1", "D2", "D3"]]
    ranges = [coords.min(axis=1).min(), coords.max(axis=1).max()]
    fig.update_layout(
        sliders=sliders,
        scene={
            'xaxis': {'range': ranges, 'rangemode': 'tozero', 'tickmode': "linear", 'tick0': -5, 'dtick': 1},
            'yaxis': {'range': ranges, 'rangemode': 'tozero', 'tickmode': "linear", 'tick0': -5, 'dtick': 1},
            'zaxis': {'range': ranges, 'rangemode': 'tozero', 'tickmode': "linear", 'tick0': -5, 'dtick': 1},
            'aspectratio': {
                'x': 0,
                'y': 0,
                'z': 0,
            },
            "aspectmode": "cube",
            "uirevision": True,
        },
        # A numeric column is described by the colour bar, so the legend is
        # only shown for categorical columns.
        showlegend=not numeric,
    )
    return fig

# The server is started whether this file is executed directly or imported;
# debug mode is switched on only in the former case.
app.run_server(debug=(__name__ == '__main__'))