Change variable used for coloring using the new partial property updates

Hi,

I am trying to use the recently introduced partial property updates to change the variable used for coloring a scatter plot created using plotly.express. I am wondering if there is an easy way to do this without having to manually create all the new traces. Here is a minimal example illustrating what I am trying to achieve:

from dash import Dash, html, dcc, callback, Output, Input, callback_context, Patch
import plotly.express as px
import pandas as pd

# Source data: the unfiltered gapminder table (columns used below: country,
# continent, year, pop).
df = pd.read_csv(
    'https://raw.githubusercontent.com/plotly/datasets/master/gapminder_unfiltered.csv'
)

app = Dash(__name__)

# Page: a title, a multi-select of countries, a selector for the column that
# drives the line colours, and the graph the callback fills in.
app.layout = html.Div(
    children=[
        html.H1(children='Title of Dash App', style={'textAlign': 'center'}),
        dcc.Dropdown(
            options=df['country'].unique(),
            value=['Canada'],
            multi=True,
            id='dropdown-selection',
        ),
        dcc.Dropdown(
            options=['country', 'continent'],
            value='country',
            id='color-selection',
        ),
        dcc.Graph(id='graph-content'),
    ]
)



@callback(
    Output('graph-content', 'figure'),
    Input('dropdown-selection', 'value'),
    Input('color-selection', 'value'),
)
def update_graph(countries, color):
    """Return a population-over-time line chart for the selected countries.

    countries -- list of country names chosen in 'dropdown-selection'.
    color -- column name ('country' or 'continent') used to colour the lines.

    Open question: when only the colour selector changed, can the existing
    figure be updated with a ``Patch`` instead of rebuilding every trace?
    """
    trigger = callback_context.triggered[0]["prop_id"]
    if trigger == 'color-selection.value':
        patched_figure = Patch()
        # how to use a different variable for the color without manually creating all new traces etc?
        # return patched_figure

    # Current behaviour for every trigger: rebuild the whole figure.
    selection = df[df.country.isin(countries)]
    return px.line(selection, x='year', y='pop', line_group='country', color=color)

if __name__ == '__main__':
    # Dash.run replaces Dash.run_server, which is deprecated and removed in Dash 3.
    app.run(debug=True)

I am looking forward to your answers!

Hi @christoph.b !
I thought about what you need: it is possible, but I don't think it is easy.
You want to change the colors of the traces so that they are displayed either by country or grouped by continent.
If you want to use a Patch, that means you have to:

  • define the color for each trace either by country or grouped by continent
  • if grouped by continent, define the group of each trace
  • update the legend accordingly

Another difficulty: once the figure is created, the df is no longer accessible from the figure object. This means that if you displayed the colors by country, for example, and then want to patch the figure to color by continent, you will not be able to know the continent of each trace.
You can address this by using 'customdata' to keep the info of country and continent by trace (actually by data point) in the figure object.

Actually, you can also get the country and the continent of each trace from the hover text, but you would need to parse it to extract that information.

But I took up the challenge :muscle:, so here is my app in case it helps :grin:
I tried to comment as much as possible to make it clearer.

from dash import Dash, dcc, html, Input, Output, Patch, callback, ctx, State
import dash_bootstrap_components as dbc
import plotly.express as px

app = Dash(__name__)

# Gapminder sample data bundled with plotly.express.
df = px.data.gapminder()

# Page: a multi-select of countries, a switch to colour by continent instead of
# by country, and the graph the callback draws or patches.
app.layout = html.Div([
    dcc.Dropdown(id='dropdown-countries', options=df.country.unique(), value=['Canada', 'United States', 'France'],
                 multi=True),
    # Start explicitly "off" so the callback always receives a bool, never None.
    dbc.Switch(id="switch-group-continent", label="Group by continent", value=False),
    dcc.Graph(id='graph-content')
])


@callback(
    Output('graph-content', 'figure'),
    Input('dropdown-countries', 'value'),
    Input("switch-group-continent", "value"),
    State('graph-content', 'figure'),
)
def update_graph(countries, group_ON, fig):
    """Draw the population lines, or re-colour them in place when only the switch changed.

    A full figure is built with plotly.express when the country selection
    changes, on the initial call, or when there is no existing figure to patch.
    When only the "Group by continent" switch is toggled, a ``Patch`` renames,
    regroups and re-colours the existing traces instead of resending their data.

    Args:
        countries: countries selected in the dropdown (empty list or None means none).
        group_ON: switch value; truthy to colour by continent, falsy to colour by country.
        fig: the figure currently displayed, as a plain dict, or None if none yet.

    Returns:
        A new plotly figure, or a ``Patch`` describing the in-place changes.
    """
    # The switch value may be None before the user touches it; normalise to bool.
    group_on = bool(group_ON)

    # ctx.triggered_id is None on the initial call and fig is None until a figure
    # exists: in those cases, as for a dropdown change, build a fresh figure.
    if ctx.triggered_id != "switch-group-continent" or not fig:
        return px.line(
            df[df.country.isin(countries or [])],
            x='year', y='pop',
            color='continent' if group_on else 'country',
            line_group='country' if group_on else None,  # no need to group by 'country' if color = 'country'
            # keep both labels on every point so a later patch can regroup the traces
            custom_data=['country', 'continent'],
        )

    # customdata columns are [country, continent]; pick the one to group by.
    label_index = 1 if group_on else 0
    traces = fig.get('data', [])

    patched_figure = Patch()

    # patch the legend title
    patched_figure['layout']['legend']['title']['text'] = 'Continent' if group_on else 'Country'

    # Colours to cycle through: the template colorway, or plotly's default
    # qualitative palette if the template does not define one.
    template_layout = fig.get('layout', {}).get('template', {}).get('layout', {})
    colors = template_layout.get('colorway') or px.colors.qualitative.Plotly

    # Unique groups in order of first appearance. A set would do, but its
    # iteration order for strings changes between interpreter runs, which would
    # shuffle the colours; first-appearance order also mirrors how px assigns them.
    groups = dict.fromkeys(trace['customdata'][0][label_index] for trace in traces)
    # i % len(colors): wrap around the palette if there are more groups than colours
    color_map = {group: colors[i % len(colors)] for i, group in enumerate(groups)}

    legend_items = set()
    # patch the trace attributes depending on its group
    for i, trace in enumerate(traces):
        group = trace['customdata'][0][label_index]
        patched_figure['data'][i]['legendgroup'] = group
        patched_figure['data'][i]['name'] = group
        patched_figure['data'][i]['line']['color'] = color_map[group]
        # only the first trace of each group gets a legend entry
        patched_figure['data'][i]['showlegend'] = group not in legend_items
        legend_items.add(group)

    # NOTE(review): each trace's hovertemplate still shows the label it was
    # built with; patch it as well if the hover text must follow the grouping.
    return patched_figure


if __name__ == "__main__":
    # Dash.run replaces Dash.run_server, which is deprecated and removed in Dash 3.
    app.run(debug=True)

# NOTE:
# Same as:
# if group_ON:
#     groups = {trace['customdata'][0][1] for trace in fig['data']}
# else:
#     groups = {trace['customdata'][0][0] for trace in fig['data']}

# To keep the syntax short, I used a small trick: in Python, True == 1 and False == 0,
# so a bool can be used directly as a list index.
# Since trace['customdata'][0][0] is the trace's country and trace['customdata'][0][1] is its continent:
# if group_ON is True, we want the continent, and trace['customdata'][0][group_ON] selects it;
# if group_ON is False, we want the country, and trace['customdata'][0][group_ON] selects it.

2 Likes

Hi @Skiks

thank you so much for your effort!
I just tried your code and it looks like it does exactly what I want.
I am now going through it to understand it in detail and I’ll let you know if I managed to apply it to my problem :slightly_smiling_face:

1 Like

Hi @Skiks and @adamschroeder, I am having problems with background updates. Have they perhaps not been implemented yet?

@callback(Output("market_risk_plot", "figure"), Input("color-scheme-toggle", "n_clicks"), State("theme-store", "data"))
def my_callback(n_clicks, data):
    """Patch the plot-area background of "market_risk_plot" for the current colour scheme.

    Args:
        n_clicks: click count of the toggle button; only used as the trigger.
        data: content of "theme-store"; expected to hold a "colorScheme" key
            equal to "dark" or "light" (assumed from the branches; confirm).

    Returns:
        A ``Patch`` setting ``layout.plot_bgcolor``, or ``no_update`` when the
        store is empty or holds an unrecognised scheme (instead of implicitly
        returning None as the figure).
    """
    from dash import no_update

    # Plot-area background colour per scheme.
    # NOTE: the original also set layout["graph_line"], which is not a plotly
    # layout attribute. If the intent was to recolour the traces (e.g. "#007ACE"
    # for light), patch data[i]["line"]["color"] for each trace instead.
    bgcolors = {"dark": 'rgba(0,0,0,0)', "light": "#082255"}
    scheme = (data or {}).get("colorScheme")
    if scheme not in bgcolors:
        return no_update

    patched_figure = Patch()
    patched_figure["layout"]["plot_bgcolor"] = bgcolors[scheme]
    return patched_figure

I should probably add that other properties do work, like the one below, and that the plot is created with plotly express:
patched_figure["layout"]["title"]["font"]["color"] = "#082255"  # hex colours need the leading '#'

Hi @snowde !
I think the problem is “graph_line”, the layout doesn’t have this property.
I don’t know what kind of graph you are using, but here is an example with line graphs.

Here is the code:

import plotly.graph_objects as go
import numpy as np
from dash import Dash, html, dcc, Input, Output, Patch, State
import dash_bootstrap_components as dbc
from random import sample, random

app = Dash(__name__)

# 100 sample points over one period, shared by every trace.
t = np.linspace(0, 2 * np.pi, 100)

# n_trace sine curves; trace i is phase-shifted by (i / n_trace) * pi.
n_trace = 5
fig = go.Figure()
for i in range(n_trace):
    fig.add_trace(
        go.Scatter(
            x=t,
            y=np.sin(t + (i / n_trace) * np.pi),
            name=f'sin(t+{i}/{n_trace}*pi)',
        )
    )

# Two buttons (one per callback) and the graph they patch.
app.layout = html.Div(
    children=[
        dbc.Button("New background color", id="btn-bg-color"),
        dbc.Button("New plot line color", id="btn-line-color"),
        dcc.Graph(id='graph-content', figure=fig),
    ]
)


@app.callback(
    Output("graph-content", "figure"),
    Input("btn-bg-color", "n_clicks"),
    prevent_initial_call=True,
)
def change_bg_color(n):
    """Give the plot area a random, semi-transparent background colour.

    Only ``layout.plot_bgcolor`` is sent to the browser via a ``Patch``; the
    trace data is left untouched.

    Args:
        n: click count of "btn-bg-color"; only used as the trigger.

    Returns:
        A ``Patch`` setting ``layout.plot_bgcolor`` to an ``rgba(r,g,b, a)`` string.
    """
    # range(256) so every channel value 0-255 is possible (range(255) never yields 255);
    # sample() draws three distinct values.
    red, green, blue = sample(range(256), 3)
    alpha = random() / 2  # opacity in [0, 0.5)
    patched_fig = Patch()
    patched_fig['layout']['plot_bgcolor'] = f'rgba({red},{green},{blue}, {alpha})'
    return patched_fig


@app.callback(
    Output("graph-content", "figure", allow_duplicate=True),
    Input("btn-line-color", "n_clicks"),
    State("graph-content", "figure"),
    prevent_initial_call=True,
)
def change_line_color(n, fig):
    """Give every trace of the graph its own random line colour.

    The current figure is read only to count its traces; the returned ``Patch``
    touches nothing but ``data[i].line.color``.

    Args:
        n: click count of "btn-line-color"; only used as the trigger.
        fig: the figure currently displayed, as a plain dict.

    Returns:
        A ``Patch`` setting each trace's line colour to an ``rgb(r,g,b)`` string.
    """
    patched_fig = Patch()
    for i, _trace in enumerate((fig or {}).get('data', [])):
        # range(256) so every channel value 0-255 is possible (range(255) never yields 255)
        red, green, blue = sample(range(256), 3)
        patched_fig['data'][i]['line']['color'] = f'rgb({red},{green},{blue})'
    return patched_fig


if __name__ == "__main__":
    # Dash.run replaces Dash.run_server, which is deprecated and removed in Dash 3.
    app.run(debug=True)

3 Likes