Update subplot scene in for loop

I am trying to make a function to generate subplots with surface plots and scatters overlaid. I have nearly completed it, however I cannot think of a way to update the scene argument in update_layout, I am using a workaround that limits the function to a defined amount of `update_layout’ definitions based on the amount of plots I pass in.

However I wish to have accept a undefined amount. The issue comes from plotly incrementing the arugment names by 1 for each plot on the subplot. Is the a way to reproduce this - searching for answers seem to that using the exec command to do this is a bad idea.

def surface_col(
    df: list,
    x_string: list,
    y_string: list,
    z_string: list,
    plot_title: list,
):
    """_summary_

    Args:
        df (list): _description_
        x_string (list): _description_
        y_string (list): _description_
        z_string (list): _description_
        plot_title (list): _description_

    Returns:
        _type_: _description_
    """

    if len(df) == 1:
        plot = go.Figure()
    elif len(df) > 1:
        scenes = [dict(zip(["type"], ["surface"])) for x in range(0, len(df))]
        cbar_offset = 0.1
        cbar_loc = []
        x = len(df)
        while x > 0:
            cbar_loc.append((len(df) / (x * len(df))) - cbar_offset)
            x = x - 1

        plot = make_subplots(
            rows=1,
            cols=len(df),
            subplot_titles=(plot_title),
            specs=[scenes],
            shared_xaxes=True,
            shared_yaxes=True,
        )

    for plots in range(0, len(df)):
        print(plots)
        x = df[plots][x_string[plots]]
        y = df[plots][y_string[plots]]
        z = df[plots][z_string[plots]]

        x_i = np.linspace(start=float(min(x)), stop=float(max(x)), num=int(50))
        y_i = np.linspace(start=float(min(y)), stop=float(max(y)), num=int(50))

        X, Y = np.meshgrid(x_i, y_i)

        z_out = griddata((x, y), z, (X, Y), method="linear")

        x_out = x_i
        y_out = y_i

        plot.add_trace(
            go.Surface(
                z=z_out,
                x=x_out,
                y=y_out,
                contours={"x": {"show": True}, "z": {"show": True}},
                opacity=0.98,
                colorscale="plotly3",
                hovertemplate=x_string[plots]
                + ": %{x:.3f}"
                + "<br>"
                + y_string[plots]
                + ": %{y:.3f}</br>"
                + z_string[plots]
                + ": %{z:.3f}",
                colorbar=dict(
                    title=z_string[plots],
                    titleside="right",
                    titlefont=dict(size=12, family="Arial, sans"),
                    x=cbar_loc[plots],
                ),
            ),
            row=1,
            col=plots + 1,
        )

        plot.add_trace(
            go.Scatter3d(
                z=z,
                x=x,
                y=y,
                name=plot_title[plots],
                hovertemplate=x_string[plots]
                + ": %{x:.5f}"
                + "<br>"
                + y_string[plots]
                + ": %{y:.5f}</br>"
                + z_string[plots]
                + ": %{z:.5f}",
                mode="markers",
            ),
            row=1,
            col=plots + 1,
        )

        camera = dict(
            up=dict(x=0, y=0, z=1),
            center=dict(x=0, y=0, z=0),
            eye=dict(x=1.25, y=-1.25, z=1.25),
        )

    if len(df) == 1:
        plot.update_layout(
            template=template_white,
            scene=dict(
                scene_camera=camera,
                xaxis_title=x_string[0],
                yaxis_title=y_string[0],
                zaxis_title=z_string[0],
            ),
        )

    elif len(df) == 2:
        plot.update_layout(
            template=template_white,
            scene=dict(
                camera=camera,
                xaxis_title=x_string[0],
                yaxis_title=y_string[0],
                zaxis_title=z_string[0],
            ),
            scene2=dict(
                camera=camera,
                xaxis_title=x_string[1],
                yaxis_title=y_string[1],
                zaxis_title=z_string[1],
            ),
        )

    return plot