Announcing Dash Bio 1.0.0 🎉 : a one-stop-shop for bioinformatics and drug development visualizations.

Show & Tell: Plotly subplots with individual legends, all interactions clientside

It would be very useful to have multiple legends with one graph, and this has been mentioned in different places:

Here is an example of my method in action (note in this graph the x-axes are linked):
output

This has been done within Dash and plotly.py. All of the interactions are handled clientside, so they do not require re-querying any databases (whether that be MySQL or a Redis cache), nor do they require sending large data arrays back to the server.

Here is the code used. It is quite long and may need adapting depending on usage, but the general ideas are there and it is mostly “general”.

code
import dash
from dash import html, dcc, Input, Output, State, MATCH, ALL, no_update
import plotly.graph_objs as go
import numpy as np

# JavaScript source for the clientside callback registered further down.
#
# `batchAssign` folds every argument it receives (left to right) into a
# single object, starting from `{}`.  `recursiveAssign` copies each key of
# `patch` onto `input`; when the target already holds a truthy value for the
# key and the patch value has `typeof == "object"`, it recurses instead of
# overwriting.  Arrays are `typeof "object"` too, so they are merged index by
# index rather than replaced.  A `null`/`undefined` patch contributes nothing
# because the `for ... in` loop has no keys to visit.
#
# NOTE(review): nested objects from the first argument are assigned by
# reference into the accumulator, so later patches are merged *into* those
# same objects (the first argument's nested data is modified in place) --
# confirm this is acceptable for the stored figure.
deep_merge = """
function batchAssign(patches) {
    function recursiveAssign(input, patch){
        var outputR = Object(input);
        for (var key in patch) {
            if(outputR[key] && typeof patch[key] == "object") {
                outputR[key] = recursiveAssign(outputR[key], patch[key])
            }
            else {
                outputR[key] = patch[key];
            }
        }
        return outputR;
    }

    return Array.prototype.reduce.call(arguments, recursiveAssign, {});
}
"""


# Default state of every legend "group".
#
# Each entry describes one group of traces:
#   legendgroup  -- identifier shared by the traces belonging to the group
#   name         -- label shown for the group in its legend
#   visible      -- initial visibility (True, or 'legendonly' for hidden)
#   which_legend -- id of the stand-alone legend graph that lists the group
groups_visible_default = [
    {
        'legendgroup': 'polynomial_group',
        'name': 'Polynomial',
        'visible': True,
        'which_legend': 'legend1',
    },
    {
        'legendgroup': 'sinusoidal_group',
        'name': 'Sinusoidal',
        'visible': 'legendonly',
        'which_legend': 'legend1',
    },
    {
        'legendgroup': 'exponential_group',
        'name': 'Exponential',
        'visible': True,
        'which_legend': 'legend2',
    },
]

def _graph_config():
    """Return a fresh ``dcc.Graph`` config dict used by every graph in the layout."""
    # NOTE(review): Plotly's documented name for the download button appears
    # to be 'toImage'; the lowercase 'toimage' entry is kept exactly as
    # written -- confirm whether it actually hides that button.
    return dict(
        responsive = True,
        displaylogo = False,
        modeBarButtonsToRemove = ['zoom2d', 'zoomIn2d', 'zoomOut2d', 'pan2d', 'select2d', 'lasso2d', 'autoScale2d', 'toggleSpikelines', 'hoverCompareCartesian', 'hoverClosestCartesian', 'resetScale2d', 'toimage']
    )


def _legend_traces(which_legend):
    """Build the placeholder traces for the legend graph named ``which_legend``."""
    # The traces carry only NaN points: they exist solely so that Plotly
    # renders a clickable legend entry for each group.
    return [
        go.Scatter(
            x = np.array([np.nan, np.nan]),
            y = np.array([np.nan, np.nan]),
            xaxis = 'x1',
            yaxis = 'y1',
            mode = 'lines',
            line = dict(color='black'),
            showlegend = True,
            legendgroup = group_visible['legendgroup'],
            name = group_visible['name'],
            visible = group_visible['visible']
        )
        for group_visible in groups_visible_default
        if group_visible['which_legend'] == which_legend
    ]


def _legend_container(graph_id, figure, margin_left):
    """Wrap one legend graph in an absolutely positioned ``html.Div``."""
    return html.Div(
        children = [
            dcc.Graph(
                id = graph_id,
                figure = figure,
                config = _graph_config()
            )
        ],
        # The original style dict listed 'left' twice; a duplicated key in a
        # dict literal is silently collapsed, so it is written once here.
        style = {
          'width': '90px',
          'height': '70px',
          'position': 'absolute',
          'left': '0',
          'zIndex': '999',
          'marginLeft': margin_left,
          'marginTop': '110px',
        }
    )


def make_layout():
    """Build the page layout: three stores, a heading, the main graph and two legends.

    The main graph fills a fixed-size, relatively positioned container; the
    two small legend graphs ('legend1', 'legend2') are layered on top of it
    with a higher ``zIndex``.

    Returns:
        The root ``html.Div`` of the page.
    """
    # Axis-less, margin-less layout shared by both legend figures.
    hidden_axis = dict(
        showticklabels = False,
        fixedrange = True,
        showgrid = False,
        zeroline = False,
        visible = False,
    )
    legend_layout = dict(
        height = 70,
        width = 125,
        yaxis = dict(hidden_axis),
        xaxis = dict(hidden_axis),
        margin = dict(l=0, b=0, t=0, r=0),
    )
    figure_legend1 = dict(
        data = _legend_traces('legend1'),
        layout = legend_layout
    )
    figure_legend2 = dict(
        data = _legend_traces('legend2'),
        layout = legend_layout
    )

    main_graph_container = html.Div(
        children = [
            dcc.Graph(
                id = 'main_graph',
                figure = dict(
                    data = [],
                    layout = dict(),
                ),
                config = _graph_config()
            )
        ],
        style = {
          'width': '100%',
          'height': '100%',
          'position': 'absolute',
          'top': '0',
          'left': '0',
          'zIndex': '10',
        }
    )

    layout = html.Div(
        children=[
            dcc.Store(id='page_creation_signal', data=None),
            # "main_graph_figure"
            # contains the full figure, and can be quite large, so it is advantageous NOT to send it over the network back to the server
            dcc.Store(id='main_graph_figure', data=None),
            # "main_graph_legend_status"
            # contains the updated legend status and is merged with "main_graph_figure" to produce the final figure
            dcc.Store(id='main_graph_legend_status', data=None),
            # "main_graph_reduced"
            # is a simplified version of "main_graph_figure", and contains only the essential data needed to make "main_graph_legend_status"
            dcc.Store(id='main_graph_reduced', data=None),

            html.H1(children='Example of multiple legends'),

            html.Div(
                children = [
                    main_graph_container,
                    _legend_container('legend1', figure_legend1, '475px'),
                    _legend_container('legend2', figure_legend2, '1085px'),
                ],
                style = {
                  'width': '1300px',
                  'height': '300px',
                  'position': 'relative'
                }
            )
        ]
    )

    return layout


def make_traces(groups_visible):
    """Build the four data traces of the main graph.

    Args:
        groups_visible: iterable of dicts, each with at least the keys
            'legendgroup' and 'visible'.  The first entry whose
            'legendgroup' matches a group supplies that group's visibility.

    Returns:
        A list of ``go.Scatter`` traces: two polynomial curves and one sine
        curve on axes x1/y1, and one (negated) exponential curve on x2/y2.

    Raises:
        StopIteration: if one of the three groups has no entry in
            ``groups_visible``.
    """
    # Note - in a real application this step may be slow and may involve
    # access to database(s), e.g. MySQL and/or a redis cache.
    x = np.linspace(0.0, 2.0*np.pi, 50000)

    # (legend group, x-axis, y-axis, curves belonging to the group)
    curve_specs = (
        # Graph 1
        ('polynomial_group', 'x1', 'y1', (x**2/50.0, x**3/250.0)),
        ('sinusoidal_group', 'x1', 'y1', (np.sin(x),)),
        # Graph 2
        ('exponential_group', 'x2', 'y2', (-np.exp(x),)),
    )

    traces = []
    for group_name, x_ref, y_ref, curves in curve_specs:
        # One lookup per group, in the same order as the groups are drawn.
        settings = next(
            entry for entry in groups_visible
            if entry['legendgroup'] == group_name
        )
        for y_values in curves:
            traces.append(go.Scatter(
                x = x,
                y = y_values,
                xaxis = x_ref,
                yaxis = y_ref,
                hoverlabel = dict(namelength=-1),
                legendgroup = group_name,
                visible = settings['visible']
            ))

    return traces


# Create the Dash application.  `make_layout()` is called (not passed as a
# callable), so the layout is built once, when this module is executed.
app = dash.Dash(__name__)

app.layout = make_layout()


# Create main_graph
@app.callback(
    [ Output('main_graph_figure','data'), Output('main_graph_reduced','data') ],
    [ Input('page_creation_signal','data') ],
    prevent_initial_call=True)
def main_graph__figure(
        _dummy0,
    ):
    """Build the full main-graph figure and its reduced counterpart.

    Args:
        _dummy0: value of the 'page_creation_signal' store; only used as a
            trigger, its content is ignored.

    Returns:
        A ``(figure, figure_reduced)`` pair.  ``figure`` holds every trace
        with its data; ``figure_reduced`` holds, per trace, only the
        'legendgroup' and 'visible' attributes.  Both share the same layout.

    NOTE(review): a later callback in this file is also named
    ``main_graph__figure`` and rebinds this module-level name.
    NOTE(review): 'yaxis2' reuses the title 'Variable-set 1 (units)' --
    confirm that is intended.
    """
    full_traces = make_traces(groups_visible_default)
    slim_traces = [
        go.Scatter(legendgroup=full['legendgroup'], visible=full['visible'])
        for full in full_traces
    ]

    # Styling common to all four axes.
    frame = dict(
        hoverformat = '.2f',
        linecolor = 'black',
        linewidth = 1,
        mirror = 'allticks',
        zerolinecolor = 'black',
        ticks = 'inside',
    )
    shared_layout = dict(
        # The first x-axis follows the range of the second one.
        xaxis = dict(title='Time (s)', domain=[0.00, 0.47], **frame, matches='x2'),
        xaxis2 = dict(title='Time (s)', domain=[0.53, 1.00], **frame),
        yaxis = dict(title='Variable-set 1 (units)', **frame),
        yaxis2 = dict(title='Variable-set 1 (units)', **frame, anchor='x2'),
        hovermode = 'x',
        showlegend = False
    )

    return (
        dict(data=full_traces, layout=shared_layout),
        dict(data=slim_traces, layout=shared_layout),
    )


# Set the legend status
@app.callback(
    Output('main_graph_legend_status','data'),
    [
        Input('legend1','restyleData'),      # Input 1
        Input('legend2','restyleData')       # Input 2
    ],
    [
        State('main_graph_reduced','data'),  # State 1
        State('legend1','figure'),           # State 2
        State('legend2','figure'),           # State 3
    ],
    prevent_initial_call=True)
def main_graph__figure(
        _dummy1,         # Input 1
        _dummy2,         # Input 2
        figure_reduced,  # State 1
        legend1_figure,  # State 2
        legend2_figure   # State 3
    ):
    """Copy the visibility chosen in the legend graphs onto the reduced figure.

    Args:
        _dummy1: 'restyleData' of 'legend1'; used only as a trigger.
        _dummy2: 'restyleData' of 'legend2'; used only as a trigger.
        figure_reduced: content of the 'main_graph_reduced' store (a figure
            dict whose traces carry 'legendgroup' and 'visible'), or None
            while that store has not been filled yet.
        legend1_figure: current figure of the 'legend1' graph.
        legend2_figure: current figure of the 'legend2' graph.

    Returns:
        A figure dict containing the reduced traces with updated 'visible'
        values and a layout that only sets ``showlegend=False``; or
        ``no_update`` when any of the three states is still None.

    NOTE(review): this function shares its name with the callback defined
    earlier in this file and rebinds that module-level name.
    """
    # A legend can be clicked before the reduced figure has been stored;
    # leave the output untouched instead of failing on ``None['data']``.
    if figure_reduced is None or legend1_figure is None or legend2_figure is None:
        return no_update

    # Map each legend group to its visibility.  When a group appears more
    # than once the last entry wins, as it did with the nested loops.
    visible_by_group = {
        entry['legendgroup']: entry['visible']
        for entry in legend1_figure['data'] + legend2_figure['data']
    }

    traces = figure_reduced['data']
    for trace in traces:
        group = trace['legendgroup']
        if group in visible_by_group:
            trace['visible'] = visible_by_group[group]

    figure = dict(
        data = traces,
        layout = dict(showlegend=False)
    )

    return figure


# Register the JavaScript in `deep_merge` as a clientside callback: the
# contents of the 'main_graph_figure' and 'main_graph_legend_status' stores
# are passed to it (in that order) and its result becomes the figure of
# 'main_graph'.  No `prevent_initial_call` is given here.
app.clientside_callback(
    deep_merge,
    Output('main_graph', 'figure'),
    [
        Input('main_graph_figure', 'data'),
        Input('main_graph_legend_status','data')
    ]
)


# Start the server with debug mode enabled when the file is run as a script.
# NOTE(review): newer Dash releases favour `app.run(...)` over
# `app.run_server(...)` -- confirm against the installed Dash version.
if __name__ == '__main__':
    app.run_server(debug=True)
2 Likes

This is really really nice. It would be great if we could figure out how to encapsulate this logic into a standalone function or library. Perhaps an AIO component?

Thanks for sharing!