
Dash - Change trace values upon change in legend selection

Hello, everyone.
I’m currently working on a Dash interface with fairly standard functionality:

  • select parameters from a bunch of dropdowns and hit a button.

  • the callback function reacting to the button press:

    1. uses an imported class to fetch the numbers from a database,
    2. uses a second class to apply some logic and get the recomputed values,
    3. adds the individual traces, then calculates the mean trace and adds it as a hidden line, and
    4. plots the graph shown below.

I would like deselecting a trace in the legend to recompute the Mean trace from the remaining visible traces, and reselecting it to update the Mean trace again.
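The update itself boils down to recomputing the point-wise mean over whichever traces are still shown. A minimal sketch, assuming the figure is available as the plain dict Dash hands to callbacks (a "data" list of trace dicts), that each trace's y is a plain list of numbers, and that the averaged trace is the one named 'Mean'. The helper names is_shown and recompute_mean are hypothetical. Plotly marks a trace hidden through the legend with visible='legendonly'.

```python
from __future__ import annotations

from typing import Any


def is_shown(trace: dict[str, Any]) -> bool:
    """True unless the trace is hidden via the legend ('legendonly') or visible=False."""
    # Traces that were never toggled usually carry no 'visible' key at all.
    return trace.get("visible", True) not in ("legendonly", False)


def recompute_mean(figure: dict[str, Any], mean_name: str = "Mean") -> dict[str, Any]:
    """Set the y-values of the trace called `mean_name` to the point-wise mean
    of all other shown traces. Mutates and returns `figure`."""
    shown = [t["y"] for t in figure["data"] if t.get("name") != mean_name and is_shown(t)]
    for trace in figure["data"]:
        if trace.get("name") == mean_name:
            # zip(*shown) groups the i-th value of every shown trace, one x-position at a time.
            trace["y"] = [sum(col) / len(col) for col in zip(*shown)] if shown else []
    return figure
```

This mirrors analysis_result.mean(axis=0) from the callback, but restricted to the traces the user has left visible; with everything hidden, the Mean trace simply becomes empty.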

I have read about the clickmode='event+select' layout setting, but that seems to fire when points are clicked or selected inside the plot, whereas I need to react to clicks on legend items that show or hide traces. I may have missed something about that, though.
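Rather than clickmode, which concerns clicks on data points, dcc.Graph has a restyleData property that Dash updates when a legend item is toggled. For a single click it typically looks like [{'visible': ['legendonly']}, [2]], i.e. the new value(s) plus the indices of the affected traces; a double-click that isolates one trace sends one value per trace. A minimal sketch, with a hypothetical helper name, that copies those visibility changes into the figure dict:

```python
from __future__ import annotations

from typing import Any


def apply_legend_restyle(figure: dict[str, Any], restyle_data: list | None) -> dict[str, Any]:
    """Copy 'visible' changes from a dcc.Graph restyleData payload into the
    figure dict. Mutates and returns `figure`."""
    if not restyle_data:
        return figure  # no restyle event yet (e.g. the graph was just created)
    changes, trace_indices = restyle_data[0], restyle_data[1]
    values = changes.get("visible")
    if values is None:
        return figure  # some other style change, not a legend toggle
    if not isinstance(values, list):
        values = [values]  # treat a single scalar as applying to every listed trace
    if not values:
        return figure
    for position, trace_index in enumerate(trace_indices):
        # Reuse values cyclically if fewer values than trace indices were sent.
        figure["data"][trace_index]["visible"] = values[position % len(values)]
    return figure
```

Because the helper sets visibility instead of toggling it, applying the same event twice is harmless.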

Here is a simplified version of the code structure:

import data
import logic

import dash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output, State, MATCH, ALL

import plotly.graph_objects as go
import pandas as pd

data_importer = data.DataImporter()
data_logic = logic.Logic()


app = dash.Dash(__name__)
server = app.server


def build_layout():
    return html.Div(children=[
        html.H1('Great Tool'),
        html.Div(id='parameter-section', children=[
            # various drop downs
            html.Button('Calculate', id='calculate-button')]),
        html.Div(id='plot-output')])


@app.callback(
    Output('plot-output', 'children'),
    Input('calculate-button', 'n_clicks'),
    # bunch of States for the parameters
)
def create_plot(n_clicks):  # plus the various parameters passed
    parameters = {}   # parameters placed in a dictionary

    working_dataset = data_importer.process_request(**parameters)

    # getting the transformed data plus the x-axis values for the plot
    analysis_result, time_values = data_logic.analyse_data(working_dataset)

    # condition check to detect errors in calculation
    if isinstance(analysis_result, pd.DataFrame):
        fig = go.Figure()
        # add all the normal traces
        for name, row in analysis_result.iterrows():
            fig.add_trace(
                go.Scatter(
                    y=row,
                    x=time_values,
                    name=name,
                    mode='lines',
                    hoverinfo='name+y'
                )
            )

        # Create the trace of Mean values
        fig.add_trace(
            go.Scatter(
                y=analysis_result.mean(axis=0),
                x=time_values,
                name='Mean',
                mode='lines',
                line={'width': 0},
                showlegend=False,
                hovertemplate='<b>%{y}</b>',
                hoverinfo='name+y'
            )
        )

        fig.update_layout(hovermode='x unified')
        
        return dcc.Graph(figure=fig)


app.layout = build_layout


if __name__ == '__main__':
    app.run_server(debug=True, dev_tools_ui=True, port=80)
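One way to wire this into the app is sketched below. Assumptions, all hypothetical: create_plot gives the graph an id (here 'results-graph', i.e. dcc.Graph(id='results-graph', figure=fig)); since that graph only appears after the Calculate callback runs, the app is created with dash.Dash(__name__, suppress_callback_exceptions=True) so Dash accepts a callback on a component missing from the initial layout; and the Mean trace is identified by its name, since showlegend=False means it can never be toggled itself. The callback listens to restyleData, reads the current figure through State, applies the visibility change, recomputes Mean and returns the figure.

```python
from __future__ import annotations

from typing import Any

GRAPH_ID = "results-graph"  # hypothetical id; create_plot would need dcc.Graph(id=GRAPH_ID, ...)


def update_mean_on_legend_click(figure: dict[str, Any], restyle_data: list | None,
                                mean_name: str = "Mean") -> dict[str, Any] | None:
    """Apply a legend toggle to `figure` and recompute the Mean trace.
    Returns None when the event is not a visibility change."""
    if not restyle_data or "visible" not in restyle_data[0]:
        return None
    changes, indices = restyle_data[0], restyle_data[1]
    values = changes["visible"]
    values = values if isinstance(values, list) else [values]
    if not values:
        return None
    for position, idx in enumerate(indices):
        figure["data"][idx]["visible"] = values[position % len(values)]

    # Average only the traces that are still shown, excluding the Mean trace itself.
    shown = [t["y"] for t in figure["data"]
             if t.get("name") != mean_name
             and t.get("visible", True) not in ("legendonly", False)]
    for trace in figure["data"]:
        if trace.get("name") == mean_name:
            trace["y"] = [sum(col) / len(col) for col in zip(*shown)] if shown else []
    return figure


def register_legend_callback(app: Any) -> None:
    """Attach the legend callback to a Dash app. Dash is imported lazily so the
    helper above can be exercised without Dash installed."""
    from dash.dependencies import Input, Output, State
    from dash.exceptions import PreventUpdate

    @app.callback(
        Output(GRAPH_ID, "figure"),
        Input(GRAPH_ID, "restyleData"),
        State(GRAPH_ID, "figure"),
    )
    def _on_restyle(restyle_data, figure):
        updated = update_mean_on_legend_click(figure, restyle_data)
        if updated is None:
            raise PreventUpdate  # not a legend visibility change
        return updated
```

register_legend_callback(app) would be called once, right after app is created. This relies on each trace's y arriving as a plain list of numbers; depending on the Plotly version, numeric arrays may be serialized differently, in which case keeping the original numbers server-side (for example in a dcc.Store) and averaging those is a safer variant.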

Thanks in advance for your help.