Updating / extending px facet plots

I’m pretty new to Dash/plotly and have trouble updating the data of a facet / trellis plot without redrawing the whole thing.

Imagine my data points are 4-tuples: an entry that defines the facet, an entry that defines the trace color, and two lists with the x and y values for that combination of facet and color. For illustration, a data set with four such entries might look like this:
d1 = ('fac1', 'col1', [1,2,3,4,5], [10,3,6,7,8])
d2 = ('fac1', 'col2', [1,2,3,4,5], [3,7,2,6,4])
d3 = ('fac2', 'col1', [1,2,3,4,5], [45,23,75,21,64])
d4 = ('fac2', 'col2', [1,2,3,4,5], [99,16,47,13,7])

The following working example creates data like this and updates each trace with a new data point every second. The initial plot is created with px.line(df, x='x', y='y', color='color', facet_row='facet') after putting the data into the DataFrame df. How the data is created or converted isn’t really the point; look at the picture below and you should get the idea.
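The "long" format that px.line() expects has one row per point. A pandas-free sketch of the conversion that draw_new_figure performs inline (the helper name `to_long_rows` is hypothetical); its result could be passed to `pd.DataFrame(rows, columns=['facet', 'color', 'x', 'y'])`:

```python
def to_long_rows(data):
    """Flatten (facet, color, xs, ys) tuples into one (facet, color, x, y) row per point."""
    rows = []
    for facet, color, xs, ys in data:
        # one row per (x, y) pair, tagged with the facet and color it belongs to
        rows.extend((facet, color, x, y) for x, y in zip(xs, ys))
    return rows


d1 = ('fac1', 'col1', [1, 2, 3, 4, 5], [10, 3, 6, 7, 8])
d2 = ('fac1', 'col2', [1, 2, 3, 4, 5], [3, 7, 2, 6, 4])
rows = to_long_rows([d1, d2])
print(len(rows))  # 10
print(rows[0])    # ('fac1', 'col1', 1, 10)
```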

import random
import pandas as pd
import dash
import plotly.express as px
import dash_core_components as dcc
import dash_html_components as html
from dash.exceptions import PreventUpdate
from dash.dependencies import Input, Output, State

app = dash.Dash(__name__)
app.layout = html.Div([
	dcc.Graph(id='facet-graph', style={'height':'90vh'}),
	dcc.Interval(id='facet-interval', interval=1*1000),
	dcc.Store(id='facet-data'),
])

DATA_FACETS = ['fac1','fac2','fac3']
DATA_COLORS = ['col1','col2','col3','col4','col5']

def get_data(old_data):
	# just some example data
	new_data = []
	if not old_data:
		# generate 10 new datapoints for each combination of facet and color
		# each facet gets a different data range for illustration
		for (i,fac) in enumerate(DATA_FACETS):
			for col in DATA_COLORS:
				xdata = list(range(10))
				ydata = [random.randint(10**i,10**(i+1)) for _ in range(10)]
				new_data.append((fac, col, xdata, ydata))
	else:
		# extend the existing data by a single point
		old_idx = 0
		for (i,fac) in enumerate(DATA_FACETS):
			for col in DATA_COLORS:
				# throw away the oldest data point and append a new one
				old_data_entry = old_data[old_idx]
				old_xdata = old_data_entry[2]
				new_xdata = old_xdata[1:] + [old_xdata[-1]+1]
				old_ydata = old_data_entry[3]
				new_ydata = old_ydata[1:] + [random.randint(10**i,10**(i+1))]
				new_data.append((fac, col, new_xdata, new_ydata))
				old_idx += 1
	return new_data

@app.callback(Output('facet-data', 'data'),
					Input('facet-interval', 'n_intervals'),
					State('facet-data', 'data'),)
def update_data(n_intervals, old_data):
	new_data = get_data(old_data)
	return new_data

@app.callback(Output('facet-graph', 'figure'),
					Input('facet-data', 'data'),
					State('facet-graph', 'figure'),)
def update_figure(data, old_fig):
	if data is None:
		raise PreventUpdate
	
	if old_fig is None or len(old_fig['data']) == 0:
		# draw new figure from data
		return draw_new_figure(data)
	else:
		# update existing figure with new data
		return update_old_figure(data, old_fig)

def draw_new_figure(data):
	# pack the data in a data frame
	df_data = []
	for data_entry in data:
		fac, col, xdata, ydata = data_entry
		df_data.extend([(fac, col, x, y) for (x,y) in zip(xdata, ydata)])
	df = pd.DataFrame(df_data, columns=['facet', 'color', 'x', 'y'])
	
	# plot it!
	fig = px.line(df, x='x', y='y', color='color', facet_row='facet',)
	fig.update_yaxes(matches=None) # make y-axes independent of each other
	return fig

def update_old_figure(data, old_fig):
	# draw new figure because updating doesn't work :(
	return draw_new_figure(data)
	
	# This attempt doesn't work because apparently the px.line() call in
	# draw_new_figure does not create the traces in data order (I guess).
	for (i, data_entry) in enumerate(data):
		old_fig['data'][i]['x'] = data_entry[2]
		old_fig['data'][i]['y'] = data_entry[3]
	return old_fig

if __name__ == '__main__':
	app.run_server(debug=True)

This results in a graph like this (note how each facet has a different y-data dimension):

As you can see, in the update_figure callback I call different functions depending on whether I create the figure for the first time (-> draw_new_figure) or whether I could update the existing figure (-> update_old_figure). However, the update_old_figure function currently just calls draw_new_figure, because I couldn’t get the updating of the old figure to work.

update_old_figure contains my failed attempt: setting the corresponding x and y values directly in the figure object. However, calling px.line() with the color and facet_row options apparently doesn’t create the traces in the same order as the data, because the output of that attempt looked like this:

Note how some of the traces belonging in the lower facet (y range 1-1000) ended up in the other facets and vice versa.
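The scrambling comes from assigning data by position: trace i only receives the right data if px created trace i for data entry i, which, judging by the result above, it does not. A minimal sketch of the alternative, looking each trace up by a (facet, color) key instead. `update_traces_by_key` and its `axis_to_facet` argument are hypothetical, and the sketch assumes each px trace's `name` is its color value and its `yaxis` ('y', 'y2', ...) identifies the facet row:

```python
def update_traces_by_key(fig, data, axis_to_facet):
    """Copy x/y from `data` into the matching traces of a figure dict, ignoring trace order.

    Assumes each trace's 'name' is its color value and that `axis_to_facet`
    (hypothetical) maps each trace's 'yaxis' ('y', 'y2', ...) to its facet value.
    """
    # index the incoming data by (facet, color) instead of by list position
    by_key = {(facet, color): (xs, ys) for facet, color, xs, ys in data}
    for trace in fig['data']:
        # a trace without an explicit 'yaxis' sits on the first y-axis, 'y'
        facet = axis_to_facet.get(trace.get('yaxis', 'y'))
        key = (facet, trace.get('name'))
        if key in by_key:
            trace['x'], trace['y'] = by_key[key]
    return fig


# traces deliberately listed in a different order than the data
fig = {'data': [
    {'name': 'col2', 'yaxis': 'y2', 'x': [], 'y': []},
    {'name': 'col1', 'yaxis': 'y', 'x': [], 'y': []},
]}
data = [('fac1', 'col1', [1, 2], [10, 3]),
        ('fac2', 'col2', [1, 2], [99, 16])]
update_traces_by_key(fig, data, {'y': 'fac1', 'y2': 'fac2'})
print(fig['data'][0]['y'])  # [99, 16]
```

Because the lookup is keyed, the result no longer depends on the order in which px happened to create the traces.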

There are two main reasons why I want to update the data instead of redrawing the whole graph by calling px.line() repeatedly:

  • calling px.line() again doesn’t preserve zoom or trace selection
  • calling px.line() repeatedly is a lot slower than just updating the data
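On the first point, Plotly's `layout.uirevision` attribute exists for this case: if a redrawn figure carries the same truthy `uirevision` as the previous one, plotly.js keeps user-driven UI state such as zoom, and `legend.uirevision` (which defaults to `layout.uirevision`) does the same for traces toggled in the legend. With a figure object this would be `fig.update_layout(uirevision='constant')` inside draw_new_figure; the sketch below does the same on a plain figure dict, and `keep_ui_state` is a hypothetical helper name. It does nothing for the second point (speed).

```python
def keep_ui_state(fig_dict, revision='facet-plot'):
    """Set a constant layout.uirevision so that redraws keep zoom and legend selection.

    `revision` is an arbitrary truthy value; changing it (e.g. when the set of
    facets changes) tells plotly.js to reset the UI state.
    """
    fig_dict.setdefault('layout', {})['uirevision'] = revision
    return fig_dict


fig = keep_ui_state({'data': [], 'layout': {}})
print(fig['layout']['uirevision'])  # facet-plot
```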

So, am I missing some obvious way to update the trace data of a faceted graph?

I imagine I could iterate over the traces of the figure object and pick the matching data for each trace by looking at its “name” and “yaxis” properties, but that feels overly elaborate.
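The name/yaxis approach needs a mapping from each y-axis reference to its facet value. A minimal sketch that derives it from the figure's layout dict; `axis_to_facet_from_layout` is hypothetical, and it assumes px labels each facet row with an annotation whose text is `facet=<value>` and whose paper-coordinate `y` lies inside that row's y-axis `domain`, which is worth checking against your Plotly version. The domains and annotation positions in the usage lines are made up for illustration, not copied from a real px figure:

```python
def axis_to_facet_from_layout(layout, facet_col='facet'):
    """Map y-axis references ('y', 'y2', ...) to facet values via the row annotations.

    Assumes each facet row has an annotation with text '<facet_col>=<value>'
    whose paper-coordinate 'y' lies inside that row's y-axis 'domain'.
    """
    prefix = facet_col + '='
    mapping = {}
    for ann in layout.get('annotations', []):
        text = ann.get('text', '')
        if not text.startswith(prefix) or 'y' not in ann:
            continue  # not a facet-row label
        facet_value = text[len(prefix):]
        for key, axis in layout.items():
            # y-axes are stored as 'yaxis', 'yaxis2', ...; skip other layout keys
            if not key.startswith('yaxis') or not isinstance(axis, dict):
                continue
            domain = axis.get('domain')
            if domain and domain[0] <= ann['y'] <= domain[1]:
                # layout key 'yaxis2' corresponds to the trace reference 'y2'
                mapping['y' + key[len('yaxis'):]] = facet_value
    return mapping


layout = {
    'yaxis': {'domain': [0.0, 0.31]},
    'yaxis2': {'domain': [0.34, 0.65]},
    'yaxis3': {'domain': [0.68, 1.0]},
    'annotations': [
        {'text': 'facet=fac3', 'y': 0.155, 'yref': 'paper'},
        {'text': 'facet=fac2', 'y': 0.495, 'yref': 'paper'},
        {'text': 'facet=fac1', 'y': 0.84, 'yref': 'paper'},
    ],
}
print(axis_to_facet_from_layout(layout))
# {'y': 'fac3', 'y2': 'fac2', 'y3': 'fac1'}
```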

I found a workaround using make_subplots. It is also a bit troublesome, because I have to take care of legend groups, hide the legend entries for all but one subplot, and match the trace colors myself, but it does what I want. Is there a way to get similar behaviour with a facet plot?
Code for workaround:

import random
import dash
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import dash_core_components as dcc
import dash_html_components as html
from dash.exceptions import PreventUpdate
from dash.dependencies import Input, Output, State

app = dash.Dash(__name__)
app.layout = html.Div([
	dcc.Graph(id='facet-graph', style={'height':'90vh'}),
	dcc.Interval(id='facet-interval', interval=1*1000),
	dcc.Store(id='facet-data'),
])

DATA_FACETS = ['fac1','fac2','fac3']
DATA_COLORS = ['col1','col2','col3','col4','col5']

def get_data(old_data):
	# just some example data
	new_data = []
	if not old_data:
		# generate 10 new datapoints for each combination of facet and color
		# each facet gets a different data range for illustration
		for (i,fac) in enumerate(DATA_FACETS):
			for col in DATA_COLORS:
				xdata = list(range(10))
				ydata = [random.randint(10**i,10**(i+1)) for _ in range(10)]
				new_data.append((fac, col, xdata, ydata))
	else:
		# extend the existing data by a single point
		old_idx = 0
		for (i,fac) in enumerate(DATA_FACETS):
			for col in DATA_COLORS:
				# throw away the oldest data point and append a new one
				old_data_entry = old_data[old_idx]
				old_xdata = old_data_entry[2]
				new_xdata = old_xdata[1:] + [old_xdata[-1]+1]
				old_ydata = old_data_entry[3]
				new_ydata = old_ydata[1:] + [random.randint(10**i,10**(i+1))]
				new_data.append((fac, col, new_xdata, new_ydata))
				old_idx += 1
	return new_data

@app.callback(Output('facet-data', 'data'),
					Input('facet-interval', 'n_intervals'),
					State('facet-data', 'data'),)
def update_data(n_intervals, old_data):
	new_data = get_data(old_data)
	return new_data

@app.callback(Output('facet-graph', 'figure'),
					Input('facet-data', 'data'),
					State('facet-graph', 'figure'),)
def update_figure(data, old_fig):
	if data is None:
		raise PreventUpdate
	
	if old_fig is None or len(old_fig['data']) == 0:
		# draw new figure from data
		return draw_new_figure(data)
	else:
		# update existing figure with new data
		return update_old_figure(data, old_fig)

def draw_new_figure(data):
	fig = make_subplots(rows=len(DATA_FACETS), cols=1, shared_xaxes=True)
	for (fac, col, xdata, ydata) in data:
		fac_number = int(fac[-1])
		col_number = int(col[-1])
		# palette index col_number-1 gives 'col1' the first Plotly color,
		# like px does with its default template
		color = px.colors.qualitative.Plotly[col_number - 1]
		trace_args = dict(
			x = xdata, y = ydata, mode = 'lines',
			name = col, legendgroup = col, showlegend = fac_number==1,
			line = dict(color=color),
		)
		fig.add_trace(go.Scatter(**trace_args), row=fac_number, col=1)
	
	return fig

# named differently from the update_figure callback so it doesn't shadow it
def update_old_figure(data, old_fig):
	# Now it works because trace order is the same as data order.
	for (i, data_entry) in enumerate(data):
		old_fig['data'][i]['x'] = data_entry[2]
		old_fig['data'][i]['y'] = data_entry[3]
	return old_fig

if __name__ == '__main__':
	app.run_server(debug=True)

BTW, I’m aware I could also target the ‘extendData’ attribute of the graph as Output in the workaround, but this works well enough for this example.
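For reference, `extendData` on `dcc.Graph` takes a list `[update_dict, trace_indices, max_points]`. A sketch of building it from the tuples above; `extend_data_payload` is a hypothetical helper that sends only the newest point of each entry, and it relies on trace i matching data entry i, which holds for the make_subplots workaround but, as described above, not necessarily for a px facet plot. It would be returned from a callback with `Output('facet-graph', 'extendData')`; `max_points=10` keeps the same rolling window of 10 points as get_data:

```python
def extend_data_payload(data, max_points=10):
    """Build an extendData value for dcc.Graph from (facet, color, xs, ys) tuples.

    Each inner list in update['x'] / update['y'] holds the points appended to
    the trace at the same position in trace_indices. Trace i is assumed to
    correspond to data[i], as in the make_subplots workaround.
    """
    update = {'x': [[xs[-1]] for _, _, xs, _ in data],
              'y': [[ys[-1]] for _, _, _, ys in data]}
    return [update, list(range(len(data))), max_points]


data = [('fac1', 'col1', [1, 2, 3], [10, 3, 6]),
        ('fac1', 'col2', [1, 2, 3], [3, 7, 2])]
print(extend_data_payload(data))
# [{'x': [[3], [3]], 'y': [[6], [2]]}, [0, 1], 10]
```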