I am working on a Plotly code that displays a 2D random walk. I have also added one more subplot that shows the number of times the walker was in each quadrant. However, I am facing some plot issues that I would like to solve.
Firstly, is there a way to group the legends of each subplot instead of showing legends for two subplots? By grouping the legends, I mean when I click on one legend (such as the first quadrant), it will make the corresponding traces on both subplots disappear. If I click on the legend of first quadrant
, then the random walk in the first subplot and the corresponding count in the second subplot will disappear for that quadrant.
Secondly, I have written the x and y axis titles for both subplots, but only one is appearing. Additionally, the x and y-axis titles of the second plot are appearing on the first subplot. Can you suggest a solution to fix these issues?
I am attaching the code and the plot that I am getting.
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
stepsize = 0.5
num_steps = 250
# Generating 2D random walk data
pos = np.array([0, 0])
path = []
for _ in range(num_steps):
pos = pos + np.random.normal(0, stepsize, 2)
path.append(list(pos))
path = np.array(path)
x = path[:, 0]
y = path[:, 1]
# Determining quadrant for each step
quadrants = np.zeros(num_steps)
quadrants[(x >= 0) & (y >= 0)] = 1 # First quadrant
quadrants[(x < 0) & (y >= 0)] = 2 # Second quadrant
quadrants[(x < 0) & (y < 0)] = 3 # Third quadrant
quadrants[(x >= 0) & (y < 0)] = 4 # Fourth quadrant
# Counting occurrences in each quadrant
quadrant_counts = [np.sum(quadrants == q) for q in range(1, 5)]
# Counting occurrences at (0,0)
origin_count = np.sum((x == 0) & (y == 0))
colors = ['red', 'blue', 'green', 'purple']
fig1 = go.Figure()
for q in range(1, 5):
x_q = x[quadrants == q]
y_q = y[quadrants == q]
fig1.add_trace(go.Scatter(
x=x_q,
y=y_q,
name=f'Quadrant {q}',
mode='markers+lines',
marker=dict(
color=colors[q-1],
size=10
),
line=dict(color='black', width=0.5),
))
# Adding colored rectangles for each quadrant
fig1.update_layout(
shapes=[
# First quadrant
dict(
type="rect",
xref="x",
yref="y",
x0=0,
y0=0,
x1=max(x)+2,
y1=max(y)+2,
fillcolor = 'red',
opacity = 0.2,
line=dict(width=0),
layer="below"
),
# Second quadrant
dict(
type="rect",
xref="x",
yref="y",
x0=min(x)-2,
y0=0,
x1=0,
y1=max(y)+2,
fillcolor="blue",
opacity = 0.2,
line=dict(width=0),
layer="below"
),
# Third quadrant
dict(
type="rect",
xref="x",
yref="y",
x0=min(x)-2,
y0=min(y)-2,
x1=0,
y1=0,
fillcolor="green",
opacity = 0.2,
line=dict(width=0),
layer="below"
),
# Fourth quadrant
dict(
type="rect",
xref="x",
yref="y",
x0=0,
y0=min(y)-2,
x1=max(x)+2,
y1=0,
fillcolor="purple",
opacity = 0.2,
line=dict(width=0),
layer="below"
)
]
)
fig1.add_shape(
type='line',
x0=0, y0=min(y)-2,
x1=0, y1=max(y)+2,
line=dict(color='black', width=2)
)
fig1.add_shape(
type='line',
x0=min(x)-2, y0=0,
x1=max(x)+2, y1=0,
line=dict(color='black', width=2)
)
fig1.add_trace(go.Scatter(
x=[x[0], x[-1]],
y=[y[0], y[-1]],
mode='markers+text',
name='Start and End Point',
text=['Starting Point', 'End Point'],
textposition='bottom center',
marker=dict(color='black', size=15),
textfont=dict(color='black', size=10)
))
fig1.update_layout(
height=600,
width=1000,
xaxis_title = 'x',
yaxis_title = 'y',
showlegend=True,
legend=dict( orientation="h",
yanchor="top",
y=1.1,
xanchor="right",
x=1),
)
# Creating the bar plot for the counts
fig2 = go.Figure()
for q in range(1, 5):
fig2.add_trace(go.Bar(
x=[f'Quadrant {q}'],
y=[quadrant_counts[q - 1]],
marker_color=colors[q - 1],
name=f'Quadrant {q}'
))
# Adding bar plot for (0,0) position count
fig2.add_trace(go.Bar(
x=['(0,0)'],
y=[origin_count],
marker_color='gray',
name='(0,0) Position'
))
fig2.update_layout(
xaxis_title='Position',
yaxis_title='Count',
height=600,
width=1000,
showlegend=True,
legend=dict( orientation="h",
yanchor="top",
y=1.1,
xanchor="right",
x=1)
)
# Creating subplots
fig = make_subplots(rows=2, cols=1, subplot_titles=['2D Random Walk', 'Position Counts'],
vertical_spacing=0.15)
# Adding traces from fig1
for trace in fig1.data:
fig.add_trace(trace, row=1, col=1)
# Adding shapes from fig1
for shape in fig1.layout.shapes:
fig.add_shape(shape)
# Adding layout attributes from figures
fig.update_layout(fig1.layout)
fig.update_layout(fig2.layout)
# Adding bar plot for the counts in the second subplot
for trace in fig2.data:
fig.add_trace(trace, row=2, col=1)
fig.update_layout(height=1200, width=1000, showlegend=True)
fig.show()