Following this answer on Stack Overflow (https://stackoverflow.com/a/69829305/8849755), I wrote this function:
def add_grouped_legend(fig, data_frame, x, graph_dimensions):
    """Add a grouped legend to `fig`, one legend group per styling dimension.

    Based on the example here https://stackoverflow.com/a/69829305/8849755
    For each dimension, a helper figure is drawn with ``px.line`` using only
    that dimension and an all-NaN ``y``, so its traces contribute legend
    entries but no visible points. All entries of one dimension share a
    legend group titled with the column name. `fig` is modified in place.

    - fig: The figure in which to add such grouped legend.
    - data_frame: The data frame from which to create the legend, in principle
      it should be the same that was plotted in `fig`.
    - x: Name of the column of `data_frame` used for the x axis.
    - graph_dimensions: A dictionary with the arguments such as `color`,
      `symbol`, `line_dash` passed to plotly.express functions you want to
      group, with the names of the columns in the data_frame.
    """
    # One NaN per row: the legend traces must have the same length as
    # data_frame, but should not draw anything on the plot.
    nan_y = [float('nan')] * len(data_frame)

    legend_figures = []
    for dimension, column in graph_dimensions.items():
        px_kwargs = {dimension: column}
        legend_figure = px.line(
            data_frame,
            x=x,
            y=nan_y,
            **px_kwargs,
        ).update_traces(
            legendgrouptitle_text=column,
            legendgroup=str(px_kwargs),
        )
        if dimension != 'color':
            # Entries for non-color dimensions should not suggest a color,
            # so paint them black.
            legend_figure.update_traces(
                marker={'color': '#000000'},
                line={'color': '#000000'},
            )
        legend_figures.append(legend_figure)

    # Build every legend figure first and only then touch `fig`, so a failure
    # while building leaves `fig` unchanged.
    for legend_figure in legend_figures:
        fig.add_traces(legend_figure.data)
which does what I want. Here is a full example:
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np
def add_grouped_legend(fig, data_frame, x, graph_dimensions):
    """Add a grouped legend to `fig`, one legend group per styling dimension.

    Based on the example here https://stackoverflow.com/a/69829305/8849755
    For each dimension, a helper figure is drawn with ``px.line`` using only
    that dimension and an all-NaN ``y``, so its traces contribute legend
    entries but no visible points. All entries of one dimension share a
    legend group titled with the column name. `fig` is modified in place.

    - fig: The figure in which to add such grouped legend.
    - data_frame: The data frame from which to create the legend, in principle
      it should be the same that was plotted in `fig`.
    - x: Name of the column of `data_frame` used for the x axis.
    - graph_dimensions: A dictionary with the arguments such as `color`,
      `symbol`, `line_dash` passed to plotly.express functions you want to
      group, with the names of the columns in the data_frame.
    """
    # One NaN per row: the legend traces must have the same length as
    # data_frame, but should not draw anything on the plot.
    nan_y = [float('nan')] * len(data_frame)

    legend_figures = []
    for dimension, column in graph_dimensions.items():
        px_kwargs = {dimension: column}
        legend_figure = px.line(
            data_frame,
            x=x,
            y=nan_y,
            **px_kwargs,
        ).update_traces(
            legendgrouptitle_text=column,
            legendgroup=str(px_kwargs),
        )
        if dimension != 'color':
            # Entries for non-color dimensions should not suggest a color,
            # so paint them black.
            legend_figure.update_traces(
                marker={'color': '#000000'},
                line={'color': '#000000'},
            )
        legend_figures.append(legend_figure)

    # Build every legend figure first and only then touch `fig`, so a failure
    # while building leaves `fig` unchanged.
    for legend_figure in legend_figures:
        fig.add_traces(legend_figure.data)
# Generate sample data set ---
SIZE = 10


def _random_blocks(options):
    """Concatenate SIZE blocks, each being SIZE copies of one random pick from `options`."""
    return np.concatenate(
        [np.full(SIZE, np.random.choice(options, 1)) for _ in range(SIZE)]
    )


df = pd.DataFrame(
    {
        "x values": np.tile(np.linspace(0, SIZE - 1, SIZE), SIZE),
        "y values": np.sort(np.random.uniform(1, 1000, SIZE ** 2)),
        "Device": _random_blocks([52, 36, 34]),
        "Contact type": _random_blocks(["dot", "ring"]),
        "Device specs": _random_blocks(["laptop", "tablet", "console"]),
    }
)
# Blank out the last point of every block; presumably this breaks the line
# between consecutive blocks that share the same style combination.
last_x_mask = df["x values"] == SIZE - 1
df.loc[last_x_mask, "y values"] = np.nan

# The same column-to-style mapping drives both the plot and the grouped legend.
GRAPH_DIMENSIONS = {
    "color": "Device",
    "symbol": "Contact type",
    "line_dash": "Device specs",
}

# Do a regular plot ---
fig = px.line(
    data_frame=df,
    x="x values",
    y="y values",
    **GRAPH_DIMENSIONS,
)
# Add the grouped legend ---
add_grouped_legend(fig, df, "x values", GRAPH_DIMENSIONS)
fig.show()
and here is the result: