How to group legends of two subplots in plotly.subplots?

I am working on Plotly code that displays a 2D random walk. I have also added a second subplot that shows the number of times the walker was in each quadrant. However, I have run into some plotting issues that I would like to solve.

Firstly, is there a way to group the legends of the two subplots instead of showing a separate legend for each? By grouping the legends, I mean that clicking one legend entry (such as the first quadrant) hides the corresponding traces in both subplots. For example, if I click the legend entry for the first quadrant, the random walk for that quadrant in the first subplot and its count in the second subplot should both disappear.

Secondly, I have set x- and y-axis titles for both subplots, but only one set of titles appears. Moreover, the x- and y-axis titles meant for the second subplot appear on the first subplot instead. Can you suggest a way to fix these issues?

I am attaching the code and the plot that I am getting.

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

stepsize = 0.5   # standard deviation of each Gaussian step along x and y
num_steps = 250

# Generating 2D random walk data: each row of `steps` is one (dx, dy)
# increment, and the running sum of the increments is the walker's position
# after every step (the starting point (0, 0) itself is not included).
steps = np.random.normal(0, stepsize, (num_steps, 2))
path = np.cumsum(steps, axis=0)
pos = path[-1].copy()  # final position of the walker
x, y = path[:, 0], path[:, 1]

# Determining quadrant for each step. Points lying on an axis are assigned to
# the non-negative side, so every step receives exactly one label 1-4.
east, north = x >= 0, y >= 0
quadrants = np.zeros(num_steps)
for label, mask in enumerate(
    (east & north, ~east & north, ~east & ~north, east & ~north), start=1
):
    quadrants[mask] = label

# Counting occurrences in each quadrant (index 0 -> quadrant 1, ...)
quadrant_counts = [np.sum(quadrants == label) for label in (1, 2, 3, 4)]

# Counting occurrences at (0,0): exact float equality, so only steps that land
# precisely on the origin are counted.
at_origin = (x == 0) & (y == 0)
origin_count = at_origin.sum()

# Marker / background colour for quadrants 1-4, in that order
colors = ['red', 'blue', 'green', 'purple']

# Figure 1: the walk, drawn as one markers+lines trace per quadrant, on top of
# translucent quadrant tints and black lines along the coordinate axes.
fig1 = go.Figure()
for q, quadrant_color in enumerate(colors, start=1):
    in_quadrant = quadrants == q
    fig1.add_trace(go.Scatter(
        x=x[in_quadrant],
        y=y[in_quadrant],
        name=f'Quadrant {q}',
        mode='markers+lines',
        marker=dict(color=quadrant_color, size=10),
        line=dict(color='black', width=0.5),
    ))

# Plot extent: the walk's bounding box padded by 2 units on every side.
left, right = min(x) - 2, max(x) + 2
bottom, top = min(y) - 2, max(y) + 2

# Background tint for each quadrant, listed in quadrant order 1-4 so that
# zipping with `colors` pairs each box with its quadrant's colour.
tint_boxes = [
    (0, 0, right, top),       # First quadrant
    (left, 0, 0, top),        # Second quadrant
    (left, bottom, 0, 0),     # Third quadrant
    (0, bottom, right, 0),    # Fourth quadrant
]
fig1.update_layout(shapes=[
    dict(
        type="rect",
        xref="x",
        yref="y",
        x0=x0, y0=y0, x1=x1, y1=y1,
        fillcolor=tint,
        opacity=0.2,
        line=dict(width=0),
        layer="below",
    )
    for (x0, y0, x1, y1), tint in zip(tint_boxes, colors)
])

# Black lines along x = 0 (vertical) and then y = 0 (horizontal).
for x0, y0, x1, y1 in ((0, bottom, 0, top), (left, 0, right, 0)):
    fig1.add_shape(
        type='line',
        x0=x0, y0=y0,
        x1=x1, y1=y1,
        line=dict(color='black', width=2),
    )

# Labelled markers for the first and last positions of the walk.
fig1.add_trace(go.Scatter(
    x=[x[0], x[-1]],
    y=[y[0], y[-1]],
    mode='markers+text',
    name='Start and End Point',
    text=['Starting Point', 'End Point'],
    textposition='bottom center',
    marker=dict(color='black', size=15),
    textfont=dict(color='black', size=10),
))

fig1.update_layout(
    height=600,
    width=1000,
    xaxis_title='x',
    yaxis_title='y',
    showlegend=True,
    legend=dict(orientation="h", yanchor="top", y=1.1, xanchor="right", x=1),
)




# Figure 2: one bar per quadrant holding its visit count, followed by a grey
# bar for the number of steps that landed exactly on the origin.
fig2 = go.Figure()
quadrant_bars = [
    go.Bar(x=[label], y=[count], marker_color=color, name=label)
    for label, count, color in zip(
        (f'Quadrant {q}' for q in range(1, 5)), quadrant_counts, colors
    )
]
origin_bar = go.Bar(
    x=['(0,0)'],
    y=[origin_count],
    marker_color='gray',
    name='(0,0) Position',
)
fig2.add_traces(quadrant_bars + [origin_bar])

fig2.update_layout(
    xaxis_title='Position',
    yaxis_title='Count',
    height=600,
    width=1000,
    showlegend=True,
    legend=dict(orientation="h", yanchor="top", y=1.1, xanchor="right", x=1),
)

# Creating subplots
# Two-row grid: row 1 is meant for the walk (fig1), row 2 for the counts (fig2).
# NOTE(review): this block is kept exactly as posted because it reproduces the
# two problems described in the question; see the notes below for the causes.
fig = make_subplots(rows=2, cols=1, subplot_titles=['2D Random Walk', 'Position Counts'],
                   vertical_spacing=0.15)

# Adding traces from fig1
# row=1, col=1 attaches each trace to the top subplot's axes.
for trace in fig1.data:
    fig.add_trace(trace, row=1, col=1)

# Adding shapes from fig1
# No row/col is passed here. The quadrant rectangles carry xref="x"/yref="y"
# from fig1, i.e. the top subplot's axes. The two axis-line shapes set no
# xref/yref, so they rely on Plotly's default reference — TODO confirm.
for shape in fig1.layout.shapes:
    fig.add_shape(shape)

# Adding layout attributes from figures
# NOTE(review): cause of the axis-title problem. In a standalone go.Figure,
# `xaxis_title`/`yaxis_title` set `layout.xaxis`/`layout.yaxis`, the figure's
# only axes. In this grid those same names belong to the TOP subplot (the
# bottom one is xaxis2/yaxis2, as the corrected version in this thread uses).
# So fig2's 'Position'/'Count' (applied second) overwrite fig1's 'x'/'y' on the
# top subplot, and the bottom subplot gets no titles at all. Setting titles with
# fig.update_xaxes(..., row=, col=) or xaxis2_title avoids this.
fig.update_layout(fig1.layout)
fig.update_layout(fig2.layout)

# Adding bar plot for the counts in the second subplot
# NOTE(review): cause of the legend problem. These bars reuse the names
# 'Quadrant 1'..'Quadrant 4' from the fig1 scatter traces, but no legendgroup
# is set. Each trace therefore gets its own independent legend entry, and
# clicking one does not hide its counterpart in the other subplot.
for trace in fig2.data:
    fig.add_trace(trace, row=2, col=1)

fig.update_layout(height=1200, width=1000, showlegend=True)


fig.show()

Yes, you can group legends like this. Give the corresponding traces the same legendgroup, then show the legend for one subplot only by setting showlegend=False on the other subplot's traces.
Also, if you use update_xaxes and update_yaxes to set axis properties (rather than fig.update_layout), you can specify the row and col of the subplot they apply to.

Here’s an extract of some code that does this sort of thing. It is not a self-contained, complete example, but hopefully it makes sense.

# Grid of line charts, one value column per subplot, with one line per station.
# Each station's lines share a legendgroup across all subplots, and only the
# top-left subplot adds legend entries. Clicking a station in the legend
# therefore toggles that station everywhere.
# Assumes `valcols` (value-column names, at least nrows * ncols of them) and
# `df_abm` (a DataFrame with 'date', 'station' and those value columns) are
# defined earlier — TODO confirm against the full script.
from plotly.colors import DEFAULT_PLOTLY_COLORS

nrows, ncols = 2, 3                 # grid shape; every index below derives from these
palette = DEFAULT_PLOTLY_COLORS     # colours are cycled if there are more stations

fig = make_subplots(rows=nrows, cols=ncols, subplot_titles=valcols[:nrows * ncols])

for row in range(1, nrows + 1):      # Rows are numbered from 1, not from 0
    for col in range(1, ncols + 1):  # Columns are numbered from 1, not from 0
        valindex = (row - 1) * ncols + col - 1  # The index of the value column
        val = valcols[valindex]                 # The name of the value column

        # Set the axis ranges of this subplot only (row/col select the subplot)
        fig.update_xaxes(range=[df_abm['date'].min(), df_abm['date'].max()], row=row, col=col)
        fig.update_yaxes(range=[df_abm[val].min(), df_abm[val].max()], row=row, col=col)

        # Add one line per station; only the top-left subplot shows legend items
        for i, (station, dfs) in enumerate(df_abm.groupby('station')):
            fig.add_trace(go.Scatter(x=dfs['date'], y=dfs[val],
                                     name=station, legendgroup=station,
                                     mode='lines', hoverinfo='x+y+name',
                                     line={'color': palette[i % len(palette)]},
                                     showlegend=(row == 1 and col == 1)),
                          row=row, col=col)

fig.show()
1 Like

Cool, it worked!
Attaching the updated version of the code and figure.

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# 2D random walk in two stacked subplots:
#   row 1 - the walk, coloured by quadrant, over translucent quadrant tints;
#   row 2 - the number of steps that fell in each quadrant.
# Both traces of a quadrant share a legendgroup, and only the row-1 trace has a
# legend entry. Clicking "Quadrant q" therefore hides it in both subplots.

stepsize = 0.5   # standard deviation of each Gaussian step along x and y
num_steps = 250
seed = None      # set to an int to make the walk reproducible

# Generating 2D random walk data: the position is the running sum of the steps
rng = np.random.default_rng(seed)
pos = np.cumsum(rng.normal(0, stepsize, (num_steps, 2)), axis=0)
x, y = pos[:, 0], pos[:, 1]

# Determining quadrant for each step (points on an axis count toward the
# non-negative side, so every step receives exactly one label 1-4)
quadrants = np.zeros(num_steps)
quadrants[(x >= 0) & (y >= 0)] = 1  # First quadrant
quadrants[(x < 0) & (y >= 0)] = 2   # Second quadrant
quadrants[(x < 0) & (y < 0)] = 3    # Third quadrant
quadrants[(x >= 0) & (y < 0)] = 4   # Fourth quadrant

# Counting occurrences in each quadrant (index 0 -> quadrant 1, ...)
quadrant_counts = [np.sum(quadrants == q) for q in range(1, 5)]

## color for each quadrant
colors = ['red', 'blue', 'green', 'purple']

# Extent of the walk subplot: bounding box padded by 2 units on every side
x_lo, x_hi = x.min() - 2, x.max() + 2
y_lo, y_hi = y.min() - 2, y.max() + 2

# Build the subplot figure directly (instead of copying layouts from separate
# figures) so every trace, shape and axis title is tied to an explicit row/col.
fig = make_subplots(rows=2, cols=1, subplot_titles=['2D Random Walk', 'Position Counts'],
                    vertical_spacing=0.1)

for q, color in enumerate(colors, start=1):
    name = f'Quadrant {q}'
    in_q = quadrants == q
    # Walk: steps outside quadrant q become NaN, so the connecting line breaks
    # there (Plotly does not connect gaps by default). Otherwise it would draw
    # segments between non-consecutive visits that cut across other quadrants.
    fig.add_trace(go.Scatter(
        x=np.where(in_q, x, np.nan),
        y=np.where(in_q, y, np.nan),
        name=name,
        legendgroup=name,
        mode='markers+lines',
        marker=dict(color=color, size=10),
        line=dict(color='black', width=1),
    ), row=1, col=1)
    # Count: same legendgroup, but no legend entry of its own, so each quadrant
    # is listed once and that single entry toggles both traces.
    fig.add_trace(go.Bar(
        x=[name],
        y=[quadrant_counts[q - 1]],
        marker_color=color,
        name=name,
        legendgroup=name,
        showlegend=False,
    ), row=2, col=1)

## text annotations indicating the starting and end point
fig.add_trace(go.Scatter(
    x=[x[0], x[-1]],
    y=[y[0], y[-1]],
    mode='markers+text',
    name='Start and End Point',
    text=['Starting Point', 'End Point'],
    textposition='bottom center',
    marker=dict(color='black', size=15),
    textfont=dict(color='black', size=10),
), row=1, col=1)

## colored background for each quadrant, in quadrant order 1-4
quadrant_boxes = [
    (0, 0, x_hi, y_hi),        # First quadrant
    (x_lo, 0, 0, y_hi),        # Second quadrant
    (x_lo, y_lo, 0, 0),        # Third quadrant
    (0, y_lo, x_hi, 0),        # Fourth quadrant
]
for (x0, y0, x1, y1), color in zip(quadrant_boxes, colors):
    fig.add_shape(type='rect', x0=x0, y0=y0, x1=x1, y1=y1,
                  fillcolor=color, opacity=0.2, line=dict(width=0), layer='below',
                  row=1, col=1)

# Black lines along the coordinate axes x = 0 and y = 0
fig.add_shape(type='line', x0=0, y0=y_lo, x1=0, y1=y_hi,
              line=dict(color='black', width=2), row=1, col=1)
fig.add_shape(type='line', x0=x_lo, y0=0, x1=x_hi, y1=0,
              line=dict(color='black', width=2), row=1, col=1)

# Axis titles per subplot (row/col avoids the xaxis vs xaxis2 naming pitfall)
fig.update_xaxes(title_text='x', row=1, col=1)
fig.update_yaxes(title_text='y', row=1, col=1)
fig.update_xaxes(title_text='Quadrant', row=2, col=1)
fig.update_yaxes(title_text='Count', row=2, col=1)

fig.update_layout(height=1200, width=1000, showlegend=True,
                  legend=dict(orientation="h", yanchor="top", y=1.1, xanchor="right", x=1))

fig.show()

1 Like