Associating subplots legends with each subplot and formatting subplot titles

Hello,

I want to make a graph similar to the one in the photo, but I can't find a way to attach each legend to its own subplot. Is there any way to do that?

Also, if I use subplot titles, is there any way to format them (font size and color)?

Thanks in advance,


Legends per subplot are unfortunately not possible, but you can emulate them with annotations; see the example below. As for subplot titles, they are annotations themselves (you can see this by running print(fig) to inspect the structure of the figure), so you can tune them with fig.update_annotations.

from plotly.subplots import make_subplots
import plotly.graph_objects as go

fig = make_subplots(
    rows=1, cols=2,
    subplot_titles=("Plot 1", "Plot 2"))

fig.add_trace(go.Scatter(x=[1, 2, 3], y=[4, 5, 6]),
              row=1, col=1)

fig.add_trace(go.Scatter(x=[20, 30, 40], y=[50, 60, 70]),
              row=1, col=2)


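fig.update_layout(title_text="Multiple Subplots with Titles",
                  showlegend=False)

# Subplot titles are annotations, so this changes their font size.
# Call it before adding the fake legends below, or they will be resized too.
fig.update_annotations(font_size=8)

# Emulate one legend per subplot with a text annotation in paper coordinates
for col in [1, 2]:
    fig.add_annotation(x=col / 2 - 0.4, y=0.8, xref="paper", yref="paper",
                       text="trace %d" % col, showarrow=False)
fig.show()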


Thanks @Emmanuelle, that's really helpful, although it is hard to place the annotations at the same position in every subplot.
I hope Plotly will soon offer legends per subplot, or at least allow xref and yref to refer to each subplot instead of the whole paper.
Thanks again.

You can use "x", "x2", etc. as xref, but then the position is in data coordinates of that axis.

This is why it would be useful to use subplots as references: using data coordinates as references doesn't guarantee consistent positioning.

Hi @Emmanuelle, your example has been really helpful. I have a related question on subplots. statsmodels.graphics.api.abline_plot accepts an axis handle for plotting. How can I get the axis handle of the (i, j)th subplot and pass it in, as below?

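import matplotlib.pyplot as plt
from statsmodels.graphics.api import abline_plot

fig, ax = plt.subplots()
# need to replace the above ax with the handle of the (i,j)th subplot
abline_plot(model_results=line_fit, ax=ax)  # line_fit: a fitted statsmodels regression result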

An example of abline_plot can be found here.

Hi @Emmanuelle, maybe this is easy to answer: how can I adjust the size of the whole subplot figure?
I have already played with column_widths and row_heights, but they only change the size of each cell relative to the others, keeping the overall size the same.

# Signature of plotly.subplots.make_subplots, for reference
def make_subplots(
    rows=1,
    cols=1,
    shared_xaxes=False,
    shared_yaxes=False,
    start_cell="top-left",
    print_grid=False,
    horizontal_spacing=None,
    vertical_spacing=None,
    subplot_titles=None,
    column_widths=None,
    row_heights=None,
    specs=None,
    insets=None,
    column_titles=None,
    row_titles=None,
    x_title=None,
    y_title=None,
    **kwargs
): ...


Edit:
Never mind, I discovered fig.update_layout(height=1000).

Hi all,

I don't have a perfect solution, but if you are facing this issue in Python, you can use the legendgroup trace property together with the layout's legend_tracegroupgap to fake individual legends per subplot. See my explanation here: [Plotly] How to make individual legends in subplot | Kaggle

You can try the code below to create such a figure (three stacked subplots, each with its own legend group):

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

df = px.data.gapminder().query("continent == 'Americas'")

# Countries shown in each row; each row gets its own legend group
countries_by_row = {
    1: ["Canada", "United States"],
    2: ["Mexico", "Colombia", "Brazil"],
    3: ["Argentina", "Chile"],
}

fig = make_subplots(rows=3, cols=1)

for row, countries in countries_by_row.items():
    for country in countries:
        country_df = df.query("country == @country")
        # add_trace replaces the deprecated append_trace
        fig.add_trace(
            go.Scatter(
                x=country_df["year"],
                y=country_df["lifeExp"],
                name=country,
                legendgroup=str(row),  # one group per subplot
            ),
            row=row, col=1,
        )

fig.update_layout(
    height=800,
    width=800,
    title_text="Life Expectancy in the Americas",
    xaxis3_title="Year",
    # Vertical gap (px) between legend groups; tune it so each group
    # lines up with its subplot
    legend_tracegroupgap=180,
)
# Same y-axis title and range on every subplot
fig.update_yaxes(title_text="Age", range=[50, 90])
fig.show()