Continuing the discussion from Dynamically update stacked bar annotations based on legend click:
I have added a drill-down dimension, Product Id. Clicking a bar in the chart shows the breakdown of that category's sales for the selected month by product. The drill-down view also displays a back button that returns to the original graph.
Now, let’s suppose I isolate a category using the legend and then drill down. When I click the back button, is it possible to return to the isolated bars that I was viewing before the drill-down?
Isolated graph:
After Drilldown:
After back button click - goes back to original graph:
The full example is below:
from dash import Dash, dcc, html, Input, Output, State, Patch, no_update, ctx
import dash_bootstrap_components as dbc
import plotly.express as px
import pandas as pd
import numpy as np
# Month-end timestamps for every month of 2023. An explicit MonthEnd offset is
# passed instead of the 'M' string alias: pandas 2.2 deprecated 'M' in favour
# of 'ME', so neither spelling works on every pandas version, while the offset
# object does.
months = pd.date_range(start='2023-01-01', end='2023-12-31', freq=pd.offsets.MonthEnd())

# Define categories and productIds
categories = ['Electronics', 'Clothing', 'Books', 'Groceries', 'Furniture']
prod_ids = ['A', 'B', 'C', 'D', 'E']

# Generate random sales amounts (seeded so the example is reproducible);
# one value per (month, category, product) combination.
np.random.seed(0)
sales_amounts = np.random.randint(100, 1000, size=(len(months), len(categories), len(prod_ids)))

# 'Books' has no rows for March, so the sample data contains a gap.
missing_month = pd.Timestamp('2023-03-31')

data = []
for i, month in enumerate(months):
    for j, category in enumerate(categories):
        for k, prodId in enumerate(prod_ids):
            if month == missing_month and category == 'Books':
                continue
            data.append([month, category, prodId, sales_amounts[i][j][k]])

df = pd.DataFrame(data, columns=['Month', 'Category', 'ProdId', 'Sale Amount'])
def create_sales_fig(df):
    """Build the stacked bar chart of sales per month, coloured by category.

    The frame is summed per (Month, Category) for the bars and per Month for
    the annotations: one annotation with an arrow is added at each month's
    overall total, with its text rotated by 90 degrees.

    Args:
        df: DataFrame with 'Month', 'Category' and 'Sale Amount' columns.

    Returns:
        The plotly figure produced by ``px.bar`` with the annotations added.
    """
    per_category = df.groupby(['Month', 'Category'])['Sale Amount'].sum().reset_index()
    per_month = per_category.groupby('Month')['Sale Amount'].sum().reset_index()

    fig = px.bar(
        per_category,
        x='Month',
        y='Sale Amount',
        color='Category',
        title='Total Sales by Month and Category',
        labels={'Sale Amount': 'Total Sales', 'Month': 'Month'},
        text='Sale Amount',
    )
    fig.update_layout(plot_bgcolor='white')
    fig.update_traces(
        textposition='inside',
        marker=dict(line=dict(color='black', width=1)),
    )

    # One total label per month, anchored at that month's summed sale amount.
    for month, amount in zip(per_month['Month'], per_month['Sale Amount']):
        fig.add_annotation(
            x=month,
            y=amount,
            text=f'{amount:.0f}',
            showarrow=True,
            textangle=90,
        )
    return fig
# Graph component showing the figure built from the full data set.
sales_graph = dcc.Graph(id='sales-graph', figure=create_sales_fig(df))

# Card containing the back button (rendered with display 'none' initially)
# followed by a centred row holding the graph.
sales_card = dbc.Card(
    children=[
        dbc.Button(
            children='🡠',
            id='back-button-2',
            outline=True,
            size="sm",
            className='mt-2 ml-2 col-1',
            style={'display': 'none'},
        ),
        dbc.Row(children=sales_graph, justify='center'),
    ],
    className='mt-3',
)

app = Dash(__name__)

# Page layout: a heading above the sales card.
app.layout = html.Div(
    children=[
        html.H1(children="Sales Dashboard"),
        html.Div(children=sales_card),
    ]
)
@app.callback(
    Output('sales-graph', 'figure'),
    Output('back-button-2', 'style'),
    Input('sales-graph', 'restyleData'),
    Input('sales-graph', 'clickData'),
    Input('back-button-2', 'n_clicks'),
    State('sales-graph', 'figure'),
    prevent_initial_call=True
)
def update_sales_graph(style, click_data, nclicks, figure):
    """Update the sales graph after a legend click, a bar click or a back click.

    Behaviour depends on which input triggered the callback:

    * ``restyleData``: the per-month totals are recomputed from the traces
      that are currently visible and the annotations are patched to match.
    * ``clickData``: the graph is replaced by a per-product breakdown of the
      clicked category and month, and the back button is shown.
    * anything else (the back button): the original overview figure is
      rebuilt from ``df`` and the back button is hidden again.

    Args:
        style: Value of the graph's ``restyleData`` property. Despite its
            name it is not a style dict; it is not read here.
        click_data: Value of the graph's ``clickData`` property.
        nclicks: Click count of the back button; not read here.
        figure: Current figure of the graph, as a dict.

    Returns:
        A ``(figure, back_button_style)`` tuple; both elements are
        ``no_update`` when a click cannot be mapped to a known category.
    """
    trigger_id = ctx.triggered[0]["prop_id"]

    if trigger_id == 'sales-graph.restyleData':
        # A figure without annotations simply gets nothing patched.
        annotations = figure['layout'].get('annotations') or []

        # Sum the bar heights per month over the visible traces only. A trace
        # without a 'visible' key is counted as visible; any value other than
        # True (e.g. 'legendonly') is skipped.
        # NOTE(review): assumes trace['x'] / trace['y'] arrive as plain lists.
        totals = {}
        for trace in figure['data']:
            if trace.get('visible', True) is not True:
                continue
            for month, amount in zip(trace['x'], trace['y']):
                if amount is None:
                    continue
                totals[month] = totals.get(month, 0) + amount

        # The first len(totals) annotations carry the totals; the remaining
        # ones are collapsed to zero with their arrow hidden.
        month_totals = list(totals.items())
        newFig = Patch()
        for i in range(len(annotations)):
            annotation = newFig['layout']['annotations'][i]
            if i < len(month_totals):
                month, total = month_totals[i]
                annotation['x'] = month
                annotation['y'] = total
                annotation['text'] = f'{total:.0f}'
                annotation['showarrow'] = True
            else:
                annotation['y'] = 0
                annotation['text'] = f'{0:.0f}'
                annotation['showarrow'] = False
        newFig['data'] = figure['data']
        return newFig, {'display': 'none'}

    if trigger_id == 'sales-graph.clickData':
        if not click_data or not click_data.get('points'):
            return no_update, no_update

        point = click_data['points'][0]
        month_name = point['label']
        category = figure['data'][point['curveNumber']].get('name')

        # Bars of the drill-down chart fire clickData as well, but their trace
        # is not one of the categories; leave the current view untouched
        # instead of rendering an empty breakdown.
        if not df['Category'].eq(category).any():
            return no_update, no_update

        selection = df[(df['Month'] == month_name) & (df['Category'] == category)]
        prod_id_total_sales = (
            selection.groupby('ProdId')['Sale Amount'].sum().reset_index()
            .rename(columns={'ProdId': 'Product ID'})
            .sort_values('Sale Amount', ascending=False)
        )

        product_id_units_fig = px.bar(
            prod_id_total_sales,
            x='Product ID',
            y='Sale Amount',
            text=prod_id_total_sales['Sale Amount'],
        )
        product_id_units_fig.update_layout(
            plot_bgcolor='white',
            font=dict(color='black'),
            title=f'{category} Sales Breakdown for {month_name}',
            title_x=0.5,
        )
        product_id_units_fig.update_traces(marker=dict(line=dict(color='black', width=2)))
        return product_id_units_fig, {'display': 'inline'}

    # Back button: rebuild the overview chart and hide the button.
    return create_sales_fig(df), {'display': 'none'}
if __name__ == '__main__':
    # `app.run` replaces the deprecated `app.run_server`.
    app.run(port=5555, debug=False)