SHAP-like bee swarm plots

Hello,

I am trying to approximately reproduce, in Plotly, the bee swarm plot produced by the SHAP library. This is what it looks like:

This is my code:

import pandas as pd
import plotly.express as px

df = pd.read_csv('Shap_FI.csv')

values = df.iloc[:,2:].abs().mean(axis=0).sort_values().index
df_plot = pd.melt(df, id_vars=['transaction_id', 'predictions'], value_vars=values, var_name='Feature', value_name='SHAP')

fig = px.strip(df_plot, x='SHAP', y='Feature', color='predictions', stripmode='overlay', height=1024)
fig.update_layout(xaxis=dict(showgrid=True),
                  yaxis=dict(showgrid=True))
fig.show()

Which produces the following:

I do not care about the background colour, the heatmap legend, etc.

What I would like to achieve is to have row heights and dot densities proportional to frequency.

It looks like it is already doing this slightly, but it's not very visible. px.strip is built on box traces: the points are shown, the box lines are invisible, and jitter spreads the points out. So you could try moving the boxes closer together and increasing the amount of jitter.

Could you try adding the following to your code and see if it’s clearer?

fig = (
    fig
    # Make it so there is no gap between the supporting boxes
    .update_layout(boxgap=0)
    # Increase the jitter so it reaches the sides of the boxes
    .update_traces(jitter=1)
)

Thanks for pointing me in the right direction, Renaud!

Indeed, boxgap and jitter are the way to go. I still need to play a bit with the values, but I am almost there.

import pandas as pd
import plotly.express as px

df = pd.read_csv('Shap_FI.csv')

#values = df.iloc[:,2:].columns
values = df.iloc[:,2:].abs().mean(axis=0).sort_values().index
df_plot = pd.melt(df, id_vars=['transaction_id', 'predictions'], value_vars=values, var_name='Feature', value_name='SHAP')

fig = px.strip(df_plot, x='SHAP', y='Feature', color='predictions', stripmode='overlay', height=4000, width=1000)
fig.update_layout(xaxis=dict(showgrid=True, gridcolor='WhiteSmoke', zerolinecolor='Gainsboro'),
                  yaxis=dict(showgrid=True, gridcolor='WhiteSmoke', zerolinecolor='Gainsboro'))
fig.update_layout(plot_bgcolor='white')

fig = (
    fig
    # Make it so there is no gap between the supporting boxes
    .update_layout(boxgap=0)
    # Increase the jitter so it reaches the sides of the boxes
    .update_traces(jitter=1)
)

fig.write_html('plotly_beeswarm_test.html')
fig.show()
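The `values` line in the code above is what controls the row order. Here is a minimal check on a hypothetical stand-in for `Shap_FI.csv` (made-up `feat_*` columns placed after the two ID columns, as the `iloc[:, 2:]` slice assumes): sorting by ascending mean |SHAP| and passing that order to `pd.melt` makes `Feature` appear least important first. Plotly orders categories by first appearance and, on a vertical axis, draws the first one at the bottom, so the most important feature should land on the top row, as in SHAP's summary plot.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for Shap_FI.csv: two ID columns, then one SHAP column per feature.
rng = np.random.default_rng(0)
n = 200
df = pd.DataFrame({
    "transaction_id": np.arange(n),
    "predictions": rng.integers(0, 2, size=n),
    "feat_large": rng.normal(0, 2.0, size=n),
    "feat_small": rng.normal(0, 0.1, size=n),
    "feat_medium": rng.normal(0, 0.5, size=n),
})

# Original column order (the commented-out line) vs. ascending mean |SHAP|.
print(list(df.iloc[:, 2:].columns))  # ['feat_large', 'feat_small', 'feat_medium']
values = df.iloc[:, 2:].abs().mean(axis=0).sort_values().index
print(list(values))                  # ['feat_small', 'feat_medium', 'feat_large']

# melt stacks the value_vars in the order given, so 'Feature' follows that order.
df_plot = pd.melt(df, id_vars=["transaction_id", "predictions"],
                  value_vars=values, var_name="Feature", value_name="SHAP")
print(list(df_plot["Feature"].unique()))
print(df_plot.shape)                 # (600, 4)
```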
