I’m looking for a “clean” way to remove the trendline from the marginal-distribution subplots of a figure created with Plotly Express. I know that’s a bit unclear, so please look at the following example:
Generating some fake data:
np.random.seed(42)
data = pd.DataFrame(np.random.randint(0, 100, (100, 4)), columns=["feature1", "feature2", "feature3", "feature4"])
data["label"] = np.random.choice(list("ABC"), 100)
data["is_outlier"] = np.random.choice([True, False], 100)
Creating a scatter plot with both the marginal and trendline options:
fig = px.scatter(
    data, x="feature1", y="feature2",
    color="label", symbol="is_outlier", symbol_map={True: "x", False: "circle"},
    log_x=False, marginal_x="box",
    log_y=False, marginal_y="box",
    trendline="ols", trendline_scope="overall", trendline_color_override='black',
    trendline_options=dict(log_x=False, log_y=False),
)
This yields a figure with a trendline in all 3 panels:
I looked into fig.data and found that the trendlines are its last 3 traces, and that the last 2 of those are the lines drawn in the top and right panels. Removing those 2 traces from fig.data removes the lines from those panels, as seen here:
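To check which entries of fig.data belong to which panel, it can help to list each trace's name and axis references. This is only a minimal sketch: summarize_traces is a hypothetical helper (not part of plotly), and it assumes nothing more than that each trace exposes type, name, xaxis, yaxis and showlegend attributes. The stand-in objects and their values are made up so the snippet runs on its own.

```python
from types import SimpleNamespace


def summarize_traces(traces):
    """Return one (index, type, name, xaxis, yaxis, showlegend) tuple per trace."""
    rows = []
    for i, t in enumerate(traces):
        rows.append((
            i,
            getattr(t, "type", None),
            getattr(t, "name", None),
            getattr(t, "xaxis", None),
            getattr(t, "yaxis", None),
            getattr(t, "showlegend", None),
        ))
    return rows


# Stand-in traces with made-up values; with a real figure, pass fig.data instead.
demo = [
    SimpleNamespace(type="scatter", name="A, True", xaxis="x", yaxis="y", showlegend=True),
    SimpleNamespace(type="box", name="A, True", xaxis="x3", yaxis="y3", showlegend=False),
    SimpleNamespace(type="scatter", name="trend", xaxis="x", yaxis="y", showlegend=False),
]
for row in summarize_traces(demo):
    print(row)
```

With the figure above, calling summarize_traces(fig.data) should show which of the trailing traces sit on the main axes and which sit on the marginal ones.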
fig2 = copy.deepcopy(fig)
fig2.data = fig2.data[:-2]
This creates a new issue: it also removes the trendline from the legend, which is not the behavior I want. So I first need to set showlegend=True on the third-to-last trace (the main panel’s trendline):
fig3 = copy.deepcopy(fig)
fig3.data[-3].showlegend = True
fig3.data = fig3.data[:-2]
This finally gives me the figure I wanted:
So I do have a solution, but it requires “manhandling” the fig object.
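For reference, the same workaround can be written so that it selects traces by name and axis rather than by position. This is a sketch under stated assumptions, not a cleaner API: drop_marginal_trendlines is a hypothetical helper, the default trace name "Overall Trendline" is an assumption that should be checked against fig.data, and the main panel is assumed to use the axis references "x" and "y". The stand-in objects only keep the snippet self-contained.

```python
from types import SimpleNamespace


def drop_marginal_trendlines(traces, trendline_name="Overall Trendline"):
    """Keep every trace except trendline copies outside the main (x, y) panel."""
    kept = []
    for t in traces:
        is_trendline = getattr(t, "name", None) == trendline_name
        in_main_panel = (
            getattr(t, "xaxis", None) in (None, "x")
            and getattr(t, "yaxis", None) in (None, "y")
        )
        if is_trendline and not in_main_panel:
            continue  # trendline copy drawn in a marginal panel: skip it
        if is_trendline:
            t.showlegend = True  # keep the legend entry on the surviving copy
        kept.append(t)
    return kept


# Stand-in traces with made-up values; the trendline name is an assumption.
demo = [
    SimpleNamespace(name="A, True", xaxis="x", yaxis="y", showlegend=True),
    SimpleNamespace(name="Overall Trendline", xaxis="x", yaxis="y", showlegend=False),
    SimpleNamespace(name="Overall Trendline", xaxis="x2", yaxis="y2", showlegend=False),
    SimpleNamespace(name="Overall Trendline", xaxis="x3", yaxis="y3", showlegend=True),
]
kept = drop_marginal_trendlines(demo)
print([(t.name, t.xaxis, t.showlegend) for t in kept])
```

On a copy of the real figure this would be applied as fig3.data = drop_marginal_trendlines(fig3.data). It still edits the figure after the fact, so it only makes the existing approach less dependent on trace order.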
Is there a better, cleaner way of achieving the same final figure?
###############
Full code:
import copy
import numpy as np
import pandas as pd
import plotly.io as pio
import plotly.express as px
pio.renderers.default = "browser"
np.random.seed(42)
data = pd.DataFrame(np.random.randint(0, 100, (100, 4)), columns=["feature1", "feature2", "feature3", "feature4"])
data["label"] = np.random.choice(list("ABC"), 100)
data["is_outlier"] = np.random.choice([True, False], 100)
fig = px.scatter(
    data, x="feature1", y="feature2",
    color="label", symbol="is_outlier", symbol_map={True: "x", False: "circle"},
    log_x=False, marginal_x="box",
    log_y=False, marginal_y="box",
    trendline="ols", trendline_scope="overall", trendline_color_override='black',
    trendline_options=dict(log_x=False, log_y=False),
)
fig.show()
fig2 = copy.deepcopy(fig)
fig2.data = fig2.data[:-2]
fig2.show()
fig3 = copy.deepcopy(fig)
fig3.data[-3].showlegend = True
fig3.data = fig3.data[:-2]
fig3.show()