Remove Trendline from Marginal Distribution Figures

I’m looking for a “clean” way to remove the trendline from the marginal-distribution subplots created with Plotly Express. That description is a bit abstract, so please look at the following example.
Generating some fake data:

np.random.seed(42)
data = pd.DataFrame(np.random.randint(0, 100, (100, 4)), columns=["feature1", "feature2", "feature3", "feature4"])
data["label"] = np.random.choice(list("ABC"), 100)
data["is_outlier"] = np.random.choice([True, False], 100)

Creating a scatter plot with both marginal and trendline options:

fig = px.scatter(
    data, x="feature1", y="feature2",
    color="label", symbol="is_outlier", symbol_map={True: "x", False: "circle"},
    log_x=False, marginal_x="box",
    log_y=False, marginal_y="box",
    trendline="ols", trendline_scope="overall", trendline_color_override='black',
    trendline_options=dict(log_x=False, log_y=False),
)

This yields a figure with a trendline in all 3 panels:

I inspected fig.data and found that the trendlines are its last 3 traces, and that the last 2 are the lines drawn in the top and right panels. Dropping those 2 traces from fig.data removes the lines from those panels:

fig2 = copy.deepcopy(fig)
fig2.data = fig2.data[:-2]

This creates a new problem: it also removes the trendline from the legend, which I don’t want. So I first need to set showlegend=True on the 3rd-to-last trace (the main panel’s trendline):

fig3 = copy.deepcopy(fig)
fig3.data[-3].showlegend = True
fig3.data = fig3.data[:-2]

This finally gives me the figure I wanted:

So I do have a solution, but it requires “manhandling” the fig object.
Is there a better, cleaner way of achieving the same final figure?

###############
Full code:

import copy

import numpy as np
import pandas as pd
import plotly.io as pio
import plotly.express as px

pio.renderers.default = "browser"

np.random.seed(42)
data = pd.DataFrame(np.random.randint(0, 100, (100, 4)), columns=["feature1", "feature2", "feature3", "feature4"])
data["label"] = np.random.choice(list("ABC"), 100)
data["is_outlier"] = np.random.choice([True, False], 100)

fig = px.scatter(
    data, x="feature1", y="feature2",
    color="label", symbol="is_outlier", symbol_map={True: "x", False: "circle"},
    log_x=False, marginal_x="box",
    log_y=False, marginal_y="box",
    trendline="ols", trendline_scope="overall", trendline_color_override='black',
    trendline_options=dict(log_x=False, log_y=False),
)
fig.show()

fig2 = copy.deepcopy(fig)
fig2.data = fig2.data[:-2]
fig2.show()

fig3 = copy.deepcopy(fig)
fig3.data[-3].showlegend = True
fig3.data = fig3.data[:-2]
fig3.show()

Hi @JonNir, I actually think this might be a bug. You could open an issue on GitHub.

If you change to trendline_scope="trace" the behavior is different.

If you can live with your workaround, I personally would just leave it at that. That said, you could skip the trendline functionality of plotly.express (px) altogether, calculate the trendline with a third-party package, and then add it as a new trace to the scatter plot.

Thanks @AIMPED,
Here’s the issue I opened on GitHub.
Also, here’s a “cleaner” workaround someone suggested on SO.
