Show & Tell: Plotly subplots with individual legends, all interactions clientside

It would be very useful to have multiple legends with one graph, and this has been requested in several places.

Here is an example of my method in action (note that the x-axes in this graph are linked):

This has been done with Dash and plotly.py. All of the interactions are handled clientside, so they do not require re-querying any databases (whether MySQL or a Redis cache), and they do not require sending large data arrays back to the server.
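The clientside step is essentially a recursive "deep merge": a small legend-status patch (just `legendgroup` and `visible` per trace) is written over the full figure that already lives in the browser. The Python sketch below only illustrates roughly the same merge semantics as the `batchAssign` JavaScript function in the code; the name `deep_merge` is hypothetical, and copying instead of mutating the inputs is an assumption of this sketch, not something the original code does.

```python
from copy import deepcopy
from typing import Any


def deep_merge(*patches: Any) -> Any:
    """Merge patches left to right, recursing into dicts and lists.

    A value from a later patch overwrites the earlier one, except when both
    are containers of the same kind, in which case the merge recurses.
    None patches (e.g. an empty dcc.Store) are skipped. Inputs are copied,
    never mutated (an assumption of this sketch).
    """
    def merge(base: Any, patch: Any) -> Any:
        if isinstance(base, dict) and isinstance(patch, dict):
            out = dict(base)
            for key, value in patch.items():
                out[key] = merge(out[key], value) if key in out else deepcopy(value)
            return out
        if isinstance(base, list) and isinstance(patch, list):
            # Lists are merged by index, so trace i of the patch updates trace i
            out = list(base)
            for i, value in enumerate(patch):
                if i < len(out):
                    out[i] = merge(out[i], value)
                else:
                    out.append(deepcopy(value))
            return out
        # Scalars or mismatched types: the patch wins, unless it is None
        return base if patch is None else deepcopy(patch)

    result: Any = {}
    for patch in patches:
        if patch is not None:
            result = merge(result, patch)
    return result
```

For example, merging a figure whose first trace is visible with a status patch that sets that trace to `'legendonly'` keeps the trace's `y` data and only changes `visible`. In the app itself this has to run in the browser as JavaScript; the Python version is just convenient for reasoning about and testing the merge.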

Here is the code. It is quite long and may need adapting depending on your usage, but the general ideas are there and most of it is generic.

import dash
from dash import html, dcc, Input, Output, State
import plotly.graph_objects as go
import numpy as np

deep_merge = """
function batchAssign(patches) {
    function recursiveAssign(input, patch){
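        // Copy rather than mutate, so the figure held in dcc.Store is left untouched
        var outputR = Array.isArray(input) ? input.slice() : Object.assign({}, input);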
        for (var key in patch) {
            if(outputR[key] && typeof patch[key] == "object") {
                outputR[key] = recursiveAssign(outputR[key], patch[key])
            }
            else {
                outputR[key] = patch[key];
            }
        }
        return outputR;
    }

    return Array.prototype.reduce.call(arguments, recursiveAssign, {});
}
"""


# Initial states of the different "groups"
groups_visible_default = [
    dict(
        legendgroup = 'polynomial_group',
        name = 'Polynomial',
        visible = True,
        which_legend = 'legend1'
    ),
    dict(
        legendgroup = 'sinusoidal_group',
        name = 'Sinusoidal',
        visible = 'legendonly',
        which_legend = 'legend1'
    ),
    dict(
        legendgroup = 'exponential_group',
        name = 'Exponential',
        visible = True,
        which_legend = 'legend2'
    ),
]

def make_layout():
    legend_layout = dict(
        height = 70,
        width = 125,
        yaxis=dict(
            showticklabels = False,
            fixedrange = True,
            showgrid = False,
            zeroline = False,
            visible = False,
        ),
        xaxis=dict(
            showticklabels = False,
            fixedrange = True,
            showgrid = False,
            zeroline = False,
            visible = False,
        ),
        margin = dict(l=0, b=0, t=0, r=0),
    )
    # One dummy (all-NaN) trace per group, so each legend graph draws nothing but its legend entries
    traces_by_legend = {}
    for group_visible in groups_visible_default:
        traces_by_legend.setdefault(group_visible['which_legend'], []).append(go.Scatter(
            x = np.array([np.nan, np.nan]),
            y = np.array([np.nan, np.nan]),
            xaxis = 'x1',
            yaxis = 'y1',
            mode = 'lines',
            line = dict(color='black'),
            showlegend = True,
            legendgroup = group_visible['legendgroup'],
            name = group_visible['name'],
            visible = group_visible['visible']
        ))
    traces_legend1 = traces_by_legend.get('legend1', [])
    traces_legend2 = traces_by_legend.get('legend2', [])
    figure_legend1 = dict(
        data = traces_legend1,
        layout = legend_layout
    )
    figure_legend2 = dict(
        data = traces_legend2,
        layout = legend_layout
    )

    layout = html.Div(
        children=[
            dcc.Store(id='page_creation_signal', data=None),
            # "main_graph_figure"
            # contains the full figure, and can be quite large, so it is advantageous NOT to send it over the network back to the server
            dcc.Store(id='main_graph_figure', data=None),
            # "main_graph_legend_status"
            # contains the updated legend status and is merged with "main_graph_figure" to produce the final figure
            dcc.Store(id='main_graph_legend_status', data=None),
            # "main_graph_reduced"
            # is a simplified version of "main_graph_figure", and contains only the essential data needed to make "main_graph_legend_status"
            dcc.Store(id='main_graph_reduced', data=None),

            html.H1(children='Example of multiple legends'),

            html.Div(
                children = [
                    html.Div(
                        children = [
                            dcc.Graph(
                                id = 'main_graph',
                                figure = dict(
                                    data = [],
                                    layout = dict(),
                                ),
                                config = dict(
                                    responsive = True,
                                    displaylogo = False,
                                    modeBarButtonsToRemove = ['zoom2d', 'zoomIn2d', 'zoomOut2d', 'pan2d', 'select2d', 'lasso2d', 'autoScale2d', 'toggleSpikelines', 'hoverCompareCartesian', 'hoverClosestCartesian', 'resetScale2d', 'toImage']
                                )
                            )
                        ],
                        style = {
                          'width': '100%',
                          'height': '100%',
                          'position': 'absolute',
                          'top': '0',
                          'left': '0',
                          'zIndex': '10',
                          # 'backgroundColor': '#EBEBEB'
                        }
                    ),
                    html.Div(
                        children = [
                            dcc.Graph(
                                id = 'legend1',
                                figure = figure_legend1,
                                config = dict(
                                    responsive = True,
                                    displaylogo = False,
                                    modeBarButtonsToRemove = ['zoom2d', 'zoomIn2d', 'zoomOut2d', 'pan2d', 'select2d', 'lasso2d', 'autoScale2d', 'toggleSpikelines', 'hoverCompareCartesian', 'hoverClosestCartesian', 'resetScale2d', 'toImage']
                                )
                            )
                        ],
                        style = {
                          'width': '90px',
                          'height': '70px',
                          'position': 'absolute',
                          'top': '0',
                          'left': '0',
                          'zIndex': '999',
                          'marginLeft': '475px',
                          'marginTop': '110px',
                        }
                    ),
                    html.Div(
                        children = [
                            dcc.Graph(
                                id = 'legend2',
                                figure = figure_legend2,
                                config = dict(
                                    responsive = True,
                                    displaylogo = False,
                                    modeBarButtonsToRemove = ['zoom2d', 'zoomIn2d', 'zoomOut2d', 'pan2d', 'select2d', 'lasso2d', 'autoScale2d', 'toggleSpikelines', 'hoverCompareCartesian', 'hoverClosestCartesian', 'resetScale2d', 'toImage']
                                )
                            )
                        ],
                        style = {
                          'width': '90px',
                          'height': '70px',
                          'position': 'absolute',
                          'top': '0',
                          'left': '0',
                          'zIndex': '999',
                          'marginLeft': '1085px',
                          'marginTop': '110px',
                        }
                    ),
                ],
                style = {
                  'width': '1300px',
                  'height': '300px',
                  'position': 'relative'
                }
            )
        ]
    )

    return layout


def make_traces(groups_visible):
    # Note - this function may take a long time and may involve access to database(s), e.g. MySQL and/or redis-cache
    traces = []

    # In a real app this data would come from a database (e.g. cached in Redis); here it is generated
    x = np.linspace(0.0, 2.0*np.pi, 50000)

    # Graph 1
    # Make the polynomial group
    legendgroup = 'polynomial_group'
    group_options = next(item for item in groups_visible if item['legendgroup']==legendgroup)
    traces.append(go.Scatter(
        x = x,
        y = x**2/50.0,
        xaxis = 'x1',
        yaxis = 'y1',
        hoverlabel = dict(namelength=-1),
        legendgroup = legendgroup,
        visible = group_options['visible']
    ))
    traces.append(go.Scatter(
        x = x,
        y = x**3/250.0,
        xaxis = 'x1',
        yaxis = 'y1',
        hoverlabel = dict(namelength=-1),
        legendgroup = legendgroup,
        visible = group_options['visible']
    ))

    # Make the sinusoidal group
    legendgroup = 'sinusoidal_group'
    group_options = next(item for item in groups_visible if item['legendgroup'] == legendgroup)
    traces.append(go.Scatter(
        x = x,
        y = np.sin(x),
        xaxis = 'x1',
        yaxis = 'y1',
        hoverlabel = dict(namelength=-1),
        legendgroup = legendgroup,
        visible = group_options['visible']
    ))
    
    # Graph 2
    # Exponential plot
    legendgroup = 'exponential_group'
    group_options = next(item for item in groups_visible if item['legendgroup'] == legendgroup)
    traces.append(go.Scatter(
        x = x,
        y = -np.exp(x),
        xaxis = 'x2',
        yaxis = 'y2',
        hoverlabel = dict(namelength=-1),
        legendgroup = legendgroup,
        visible = group_options['visible']
    ))

    return traces


app = dash.Dash(__name__)

app.layout = make_layout()


# Create main_graph
@app.callback(
    [ Output('main_graph_figure','data'), Output('main_graph_reduced','data') ],
    [ Input('page_creation_signal','data') ],
    prevent_initial_call=False)  # must fire on page load: nothing else ever updates 'page_creation_signal'
def main_graph__figure(
        _dummy0,
    ):

    traces = make_traces(groups_visible_default)
    traces_reduced = []
    for trace in traces:
        traces_reduced.append(go.Scatter(
            legendgroup = trace['legendgroup'],
            visible = trace['visible']
        ))

    layout = dict(
        xaxis = dict(
            title = 'Time (s)',
            domain = [0.00, 0.47],
            hoverformat = '.2f',
            linecolor = 'black', 
            linewidth = 1,
            mirror = 'allticks',
            zerolinecolor = 'black',
            ticks = 'inside',
            matches = 'x2',
        ),
        xaxis2 = dict(
            title = 'Time (s)',
            domain = [0.53, 1.00],
            hoverformat = '.2f',
            linecolor = 'black', 
            linewidth = 1,
            mirror = 'allticks',
            zerolinecolor = 'black',
            ticks = 'inside',
        ),
        yaxis = dict(
            title = 'Variable-set 1 (units)',
            hoverformat = '.2f',
            linecolor = 'black',
            linewidth = 1,
            mirror = 'allticks',
            zerolinecolor = 'black',
            ticks = 'inside',
        ),
        yaxis2 = dict(
            title = 'Variable-set 2 (units)',
            hoverformat = '.2f',
            linecolor = 'black',
            linewidth = 1,
            mirror = 'allticks',
            zerolinecolor = 'black',
            ticks = 'inside',
            anchor = 'x2',
        ),
        hovermode = 'x',
        showlegend = False
    )

    figure = dict(
        data = traces,
        layout = layout
    )
    figure_reduced = dict(
        data = traces_reduced,
        layout = layout
    )

    return figure, figure_reduced


# Set the legend status
@app.callback(
    Output('main_graph_legend_status','data'),
    [
        Input('legend1','restyleData'),      # Input 1
        Input('legend2','restyleData')       # Input 2
    ],
    [
        State('main_graph_reduced','data'),  # State 1
        State('legend1','figure'),           # State 2
        State('legend2','figure'),           # State 3
    ],
    prevent_initial_call=True)
def main_graph__legend_status(
        _dummy1,         # Input 1
        _dummy2,         # Input 2
        figure_reduced,  # State 1
        legend1_figure,  # State 2
        legend2_figure   # State 3
    ):

    # Current visibility of each legendgroup, as toggled in the two legend graphs
    visibility = {
        data['legendgroup']: data['visible']
        for data in legend1_figure['data'] + legend2_figure['data']
    }

    # Copy that visibility onto every (reduced) trace belonging to the group
    traces = figure_reduced['data']
    for trace in traces:
        if trace['legendgroup'] in visibility:
            trace['visible'] = visibility[trace['legendgroup']]

    figure = dict(
        data = traces,
        layout = dict(showlegend=False)
    )

    return figure


app.clientside_callback(
    deep_merge,
    Output('main_graph', 'figure'),
    [
        Input('main_graph_figure', 'data'),
        Input('main_graph_legend_status','data')
    ]
)


if __name__ == '__main__':
    app.run(debug=True)
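The server-side half is deliberately cheap: the legend-status callback only sees the reduced figure (one `legendgroup`/`visible` pair per trace) plus the two tiny legend figures. As a sketch, the same matching step can be pulled out into a plain function that is easy to unit test without a running app. `legend_status_from_legends` is a hypothetical name, and unlike the callback it copies the traces instead of modifying `figure_reduced` in place.

```python
from typing import Any, Dict, List


def legend_status_from_legends(reduced_traces: List[Dict[str, Any]],
                               *legend_figures: Dict[str, Any]) -> Dict[str, Any]:
    """Copy each legend entry's visibility onto the matching main-graph traces."""
    # Current visibility per legendgroup; Plotly treats a missing 'visible' as True
    visibility = {
        entry["legendgroup"]: entry.get("visible", True)
        for fig in legend_figures
        for entry in fig.get("data", [])
    }
    traces = []
    for trace in reduced_traces:
        new_trace = dict(trace)  # shallow copy, so the caller's list is not mutated
        group = trace.get("legendgroup")
        if group in visibility:
            new_trace["visible"] = visibility[group]
        traces.append(new_trace)
    # Only visibility changes; the main graph keeps its own legend hidden
    return {"data": traces, "layout": {"showlegend": False}}
```

Inside the Dash callback this would be called as `legend_status_from_legends(figure_reduced['data'], legend1_figure, legend2_figure)`. Because the result contains no x/y arrays, the payload stays small however many points the main graph holds.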

This is really really nice. It would be great if we could figure out how to encapsulate this logic into a standalone function or library. Perhaps an AIO component?

Thanks for sharing!
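As a small first step towards such a standalone function, the legend-figure construction in `make_layout()` could be driven purely by a group list in the `groups_visible_default` format. `make_legend_figures` below is a hypothetical helper, not an existing Dash or Plotly API. It returns plain dicts (which `dcc.Graph` accepts as a `figure`) and uses `None` instead of NaN for the empty points, which is an assumption about the cleanest way to get a legend entry with nothing drawn.

```python
from copy import deepcopy
from typing import Any, Dict, List


def make_legend_figures(groups: List[Dict[str, Any]], height: int = 70,
                        width: int = 125) -> Dict[str, Dict[str, Any]]:
    """Build one legend-only figure per distinct 'which_legend' value."""
    hidden_axis = dict(visible=False, fixedrange=True, showgrid=False,
                       zeroline=False, showticklabels=False)
    base_layout = dict(height=height, width=width, margin=dict(l=0, r=0, t=0, b=0),
                       xaxis=hidden_axis, yaxis=hidden_axis)
    figures: Dict[str, Dict[str, Any]] = {}
    for group in groups:
        # Each legend gets its own copy of the layout
        fig = figures.setdefault(group["which_legend"],
                                 {"data": [], "layout": deepcopy(base_layout)})
        # Dummy trace with no plottable points: only its legend entry is drawn
        fig["data"].append(dict(
            type="scatter", mode="lines", x=[None, None], y=[None, None],
            line=dict(color="black"), showlegend=True,
            legendgroup=group["legendgroup"], name=group["name"],
            visible=group["visible"],
        ))
    return figures
```

With it, `make_legend_figures(groups_visible_default)['legend1']` could stand in for the hand-built `figure_legend1`; the two callbacks and the positioned `html.Div` containers would still need to be wired up per app.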


Hi @chriddyp! Has this been implemented as a function?

Would love to see this!


Hi Sebastian @Seb1,
Welcome to the community. I hope you’re enjoying Dash so far.

I do not believe this example has been developed into an AIO component, because I didn’t see anything under the Forum’s show-and-tell or tips-and-tricks tags.
Doing it the way @pfbuxton shows us above is a good option. I found a post that is not exactly what you’re looking for, but it could be another possible workaround. See here.

Hi @adamschroeder,

The workaround you linked to simply adds a large space between legend groups by doing this:

legend_tracegroupgap = 180

This means that each subplot can only be switched on or off as a whole; you can’t use the legends to control individual traces. Also, if you have many plots you will run into formatting difficulties, because the gap can only be set for the entire figure/layout, not per subplot.
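For comparison, `legend_tracegroupgap` is plotly.py's "magic underscore" shorthand for the nested `layout.legend.tracegroupgap` property. Written out as a plain layout dict (a minimal sketch, not the linked post's exact code), it is clear that there is still a single legend and a single gap value for the whole figure:

```python
# Plain-dict equivalent of fig.update_layout(legend_tracegroupgap=180):
# the magic-underscore name expands to layout.legend.tracegroupgap.
layout = dict(
    showlegend=True,
    legend=dict(
        tracegroupgap=180,  # one gap in px between legend groups, for the whole figure
    ),
)
```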

My workaround does allow for individual traces to be turned on/off using the legends.

I have not had time to develop my workaround into an AIO component. But to be honest, this would work a lot better if it were done on the JavaScript side, within plotly.js.
