Store restyled figure state to use as output in callback

Continuing the discussion from Dynamically update stacked bar annotations based on legend click:

I have added a drill-down dimension by Product ID: clicking a bar shows the per-product breakdown of that category's sales for that month. The drill-down view also shows a back button that returns to the original graph.

Now, suppose I isolate a category using the legend and then drill down. When I click the back button, is it possible to return to the isolated bars I was viewing before the drill-down, rather than to the original graph?

Isolated graph:

After Drilldown:

After back button click - goes back to original graph:

Here is the full example:

from dash import Dash, dcc, html, Input, Output, State, Patch, no_update, ctx
import dash_bootstrap_components as dbc
import plotly.express as px
import pandas as pd
import numpy as np

# Month-end dates covering 2023. A MonthEnd offset object is used instead of
# the 'M' frequency alias: 'M' is deprecated since pandas 2.2 (in favour of
# 'ME', which older pandas does not accept), while the offset works on both.
months = pd.date_range(start='2023-01-01', end='2023-12-31', freq=pd.offsets.MonthEnd())

# Define categories and productIds
categories = ['Electronics', 'Clothing', 'Books', 'Groceries', 'Furniture']
prod_ids = ['A', 'B', 'C', 'D', 'E']

# One random sale amount per (month, category, product); seeded so every run
# produces the same data.
np.random.seed(0)
sales_amounts = np.random.randint(100, 1000, size=(len(months), len(categories), len(prod_ids)))

# (Month, Category) pair that is left out of the data entirely - presumably so
# one month has a missing category bar. Compared as a Timestamp instead of
# via str(month), which depended on Timestamp's string formatting.
missing_month, missing_category = pd.Timestamp('2023-03-31'), 'Books'

data = []
for i, month in enumerate(months):
    for j, category in enumerate(categories):
        if month == missing_month and category == missing_category:
            continue
        for k, prodId in enumerate(prod_ids):
            data.append([month, category, prodId, sales_amounts[i, j, k]])

df = pd.DataFrame(data, columns=['Month', 'Category', 'ProdId', 'Sale Amount'])

def create_sales_fig(df):
    """Build the stacked monthly sales bar chart with a total label per month.

    Sales are summed per (Month, Category) and drawn by ``px.bar`` with one
    colour per category. Each month then gets an arrow annotation, rotated
    90 degrees, carrying that month's overall total.

    :param df: sales rows with 'Month', 'Category' and 'Sale Amount' columns.
    :return: the plotly figure produced by ``px.bar`` plus the annotations.
    """
    per_category = (
        df.groupby(['Month', 'Category'])['Sale Amount']
        .sum()
        .reset_index()
    )
    per_month = per_category.groupby('Month')['Sale Amount'].sum().reset_index()

    figure = px.bar(
        per_category,
        x='Month',
        y='Sale Amount',
        color='Category',
        title='Total Sales by Month and Category',
        labels={'Sale Amount': 'Total Sales', 'Month': 'Month'},
        text='Sale Amount',
    )
    figure.update_layout(plot_bgcolor='white')
    figure.update_traces(
        textposition='inside',
        marker=dict(line=dict(color='black', width=1)),
    )

    # One total label per month, added in month order.
    for month, total in zip(per_month['Month'], per_month['Sale Amount']):
        figure.add_annotation(
            x=month,
            y=total,
            text=f'{total:.0f}',
            showarrow=True,
            textangle=90,
        )
    return figure

    
# Main chart, initially showing the full month/category figure.
sales_graph = dcc.Graph(id='sales-graph', figure=create_sales_fig(df))

# Card holding a back button (hidden via display:none until the callback
# changes its style) above the centred chart.
sales_card = dbc.Card(
    children=[
        dbc.Button(
            '🡠',
            id='back-button-2',
            outline=True,
            size="sm",
            className='mt-2 ml-2 col-1',
            style={'display': 'none'},
        ),
        dbc.Row(children=sales_graph, justify='center'),
    ],
    className='mt-3',
)
    
    
# dash-bootstrap-components ships no CSS of its own: without a Bootstrap
# stylesheet, the dbc.Card / dbc.Row / dbc.Button classes used in the layout
# (mt-3, justify='center', outline buttons, ...) have no visual effect.
app = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

app.layout = html.Div([
    html.H1("Sales Dashboard"),
    html.Div(sales_card)
])


def _visible_totals(data):
    """Sum y values per x over the traces that are currently shown.

    A trace counts when it has no 'visible' key (never restyled) or its
    'visible' is True; traces hidden through the legend ('legendonly' or
    False) are skipped. Keys keep first-seen order.

    :param data: the figure's ``data`` list of trace dicts with 'x' and 'y'.
    :return: dict mapping each x value to the summed y of the shown traces.
    """
    # NOTE(review): assumes 'x'/'y' arrive as plain lists in the figure State;
    # verify if plotly's typed-array (base64) serialization is ever enabled.
    totals = {}
    for trace in data:
        if trace.get('visible', True) is not True:
            continue
        for x, y in zip(trace['x'], trace['y']):
            totals[x] = totals.get(x, 0) + y
    return totals


def _retotal_annotations(figure):
    """Return a Patch that points the month-total annotations at the shown traces.

    Every existing annotation is first blanked (y=0, text '0', no arrow).
    Then, for each month that still has visible bars, the next annotation
    is moved to that month and labelled with its visible total.
    """
    patch = Patch()
    annotations = figure['layout']['annotations']
    for i in range(len(annotations)):
        patch['layout']['annotations'][i]['y'] = 0
        patch['layout']['annotations'][i]['text'] = f'{0:.0f}'
        patch['layout']['annotations'][i]['showarrow'] = False

    for i, (month, total) in enumerate(_visible_totals(figure['data']).items()):
        patch['layout']['annotations'][i]['y'] = total
        patch['layout']['annotations'][i]['x'] = month
        patch['layout']['annotations'][i]['showarrow'] = True
        # Same number format as the labels built by create_sales_fig.
        patch['layout']['annotations'][i]['text'] = f'{total:.0f}'

    # Kept from the original: write the current traces (with their 'visible'
    # flags) back into the figure.
    patch['data'] = figure['data']
    return patch


def _drilldown_figure(click_data, figure):
    """Build the per-product breakdown for the clicked (month, category) bar.

    :param click_data: the graph's clickData; only the first point is used.
    :param figure: the figure that was clicked (the callback's State).
    :return: a plotly figure, or None when the clicked trace is not one of the
        category traces (for example, a bar inside the drill-down chart itself).
    """
    point = click_data['points'][0]
    month_name = point['label']
    category = figure['data'][point['curveNumber']].get('name')
    if not df['Category'].eq(category).any():
        return None

    rows = df[(df['Month'] == month_name) & (df['Category'] == category)]
    breakdown = (
        rows.groupby('ProdId')['Sale Amount'].sum().reset_index()
        .rename(columns={'ProdId': 'Product ID'})
        .sort_values('Sale Amount', ascending=False)
    )

    fig = px.bar(breakdown, x='Product ID', y='Sale Amount', text='Sale Amount')
    fig.update_layout(plot_bgcolor='white',
                      font=dict(color='black'),
                      title=f'{category} Sales Breakdown for {month_name}',
                      title_x=0.5)
    fig.update_traces(marker=dict(line=dict(color='black', width=2)))
    return fig


@app.callback(
    Output('sales-graph', 'figure'),
    Output('back-button-2','style'),
    Input('sales-graph', 'restyleData'),
    Input('sales-graph', 'clickData'),
    Input('back-button-2','n_clicks'),
    State('sales-graph', 'figure'),
    prevent_initial_call=True
)
def update_sales_graph(style, click_data, nclicks, figure):
    """Route the sales graph's interactions to the right view.

    * legend click (restyleData): re-total the month annotations over the
      visible categories and keep the back button hidden;
    * bar click (clickData): replace the chart with that bar's per-product
      breakdown and show the back button;
    * back button: rebuild the full month/category chart and hide the button.

    ``style`` (the restyleData payload) and ``nclicks`` are not used; the
    current figure State is read instead.
    """
    trigger_id = ctx.triggered[0]["prop_id"]

    if trigger_id == 'sales-graph.restyleData':
        if not figure.get('layout', {}).get('annotations'):
            # Only the main chart has month-total annotations; a restyle of the
            # drill-down chart has nothing to re-total (previously a KeyError).
            return no_update, no_update
        return _retotal_annotations(figure), {'display': 'none'}

    if trigger_id == 'sales-graph.clickData':
        drilldown = _drilldown_figure(click_data, figure)
        if drilldown is None:
            # Click on a non-category trace (e.g. inside the drill-down view):
            # previously this produced an empty chart; now it is ignored.
            return no_update, no_update
        return drilldown, {'display': 'inline'}

    # Back button.
    return create_sales_fig(df), {'display': 'none'}

if __name__ == '__main__':
    # Dash.run is the supported entry point; run_server was removed in Dash 3.0.
    app.run(port=5555, debug=False)

I think you should be able to do this by adding a dcc.Store to your layout, with its ‘data’ added as both an Output and a State of your callback.

Then, on drill-down, save the current ('sales-graph', 'figure'), which the callback receives as a State, into the Store. On a back-button click, return the figure held in the Store as the ('sales-graph', 'figure') output.

@davidharris Thank you for your reply, I would really appreciate it if you could provide a working example for this.

These changes to your code seem (at first sight) to work (unchanged code lines not shown):

# Layout from the answer: the original page plus a dcc.Store ('figure-store')
# that the callback uses to hold the figure replaced by a drill-down.
app.layout = html.Div(
    children=[
        html.H1("Sales Dashboard"),
        html.Div(sales_card),
        dcc.Store(id='figure-store'),
    ]
)


# Callback from the answer: the Store's 'data' is added both as an Output
# (written on drill-down) and as a State (read on back-button click).
# Lines shown as `...` are unchanged from the original callback above
# (including the computation of trigger_id).
@app.callback(
    Output('sales-graph', 'figure'),
    Output('back-button-2','style'),
    Output('figure-store','data'),
    Input('sales-graph', 'restyleData'),
    Input('sales-graph', 'clickData'),
    Input('back-button-2','n_clicks'),
    State('sales-graph', 'figure'),
    State('figure-store','data'),
    prevent_initial_call=True
)
def update_sales_graph(style, click_data, nclicks, figure, stored_figure):
    ...
    if trigger_id == 'sales-graph.restyleData':
        ...        
        # Legend click: re-total annotations; leave the stored figure as is.
        return newFig, {'display':'none'}, no_update
    
    elif trigger_id == 'sales-graph.clickData':
        ...     
        # Save the figure being replaced, including the traces' 'visible'
        # flags set by legend clicks, so the back button can restore it.
        # NOTE(review): clickData also fires for clicks on bars of the
        # drill-down chart; that would overwrite the store with the
        # drill-down figure. Consider returning no_update for all three
        # outputs when the clicked trace is not a category trace.
        return product_id_units_fig, {'display':'inline'}, figure
    else:        
        # Back button: restore the figure saved at drill-down time.
        return stored_figure, {'display':'none'}, no_update


1 Like

@davidharris Thanks a lot! This works as intended!