Is there a way to recreate the SHAP values summary plot in Plotly?


I’m working on recreating the summary plot from the SHAP library using Plotly. I have two datasets:

  • A SHAP value dataset containing the SHAP values for each data point in my original dataset.
  • The original dataset, with the features one-hot encoded, so every value is either 0 or 1.
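
Each SHAP value is colored by the one-hot value at the same position, so the two tables have to line up: identical row labels, the same feature columns, and only 0/1 in the one-hot table. A minimal check, assuming both tables are pandas DataFrames (the helper name `check_inputs` is hypothetical; the data is the sample from this post):

```python
import pandas as pd


def check_inputs(shap_values: pd.DataFrame, train: pd.DataFrame) -> None:
    """Hypothetical sanity check to run before plotting."""
    # Same observations, with the same row labels in the same order
    if not shap_values.index.equals(train.index):
        raise ValueError("SHAP values and one-hot data have different row labels")
    # Same features (column order does not matter here)
    if set(shap_values.columns) != set(train.columns):
        raise ValueError("SHAP values and one-hot data have different columns")
    # One-hot data must be strictly 0/1
    if not train.isin([0, 1]).all().all():
        raise ValueError("one-hot data contains values other than 0 and 1")


shap_values = pd.DataFrame(
    {"A": [-0.065704, -0.096510, 0.062368, 0.062368, 0.063093],
     "B": [-0.168249, -0.173284, -0.168756, -0.168756, -0.169378]})
train = pd.DataFrame({"A": [0, 1, 1, 0, 0], "B": [1, 1, 0, 0, 1]})

check_inputs(shap_values, train)  # passes silently for this sample
```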

I want to create a beeswarm plot of the SHAP values in which each point is colored according to whether the corresponding value of the one-hot encoded feature is 0 or 1.

Here’s the code I have. It produces the plot I described, but as a separate figure for each variable:

import plotly.express as px

figures = []

for column in shap_values.columns:
    # One horizontal strip plot per feature, colored by its one-hot value (0 or 1)
    fig = px.strip(merged_df, x=column + '_shap', color=column + '_train',
                   orientation='h', stripmode='overlay')
    fig.update_layout(
        title=f'Beeswarm plot of the Shapley values for {column}',
        xaxis_title='Shapley value (impact on model output)',
    )
    figures.append(fig)

Is there a way to combine all these plots into a single plot, with one row per feature as in SHAP's summary plot?

Here’s a sample of the data:

import pandas as pd

shap_values = pd.DataFrame(
    {"A": [-0.065704, -0.096510, 0.062368, 0.062368, 0.063093],
     "B": [-0.168249, -0.173284, -0.168756, -0.168756, -0.169378]})

train = pd.DataFrame(
    {"A": [0, 1, 1, 0, 0],
     "B": [1, 1, 0, 0, 1]})

merged_df = shap_values.join(train, lsuffix='_shap', rsuffix='_train')
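
The plotting code selects columns by their `_shap` / `_train` suffixes, so it is worth confirming what `join` produces. `DataFrame.join` aligns rows on the index and applies `lsuffix`/`rsuffix` only to column names that appear in both frames. A quick check on the sample data:

```python
import pandas as pd

shap_values = pd.DataFrame(
    {"A": [-0.065704, -0.096510, 0.062368, 0.062368, 0.063093],
     "B": [-0.168249, -0.173284, -0.168756, -0.168756, -0.169378]})
train = pd.DataFrame({"A": [0, 1, 1, 0, 0], "B": [1, 1, 0, 0, 1]})

# Both frames have columns "A" and "B", so every column gets a suffix;
# rows are matched on the shared RangeIndex 0..4
merged_df = shap_values.join(train, lsuffix="_shap", rsuffix="_train")
print(list(merged_df.columns))  # ['A_shap', 'B_shap', 'A_train', 'B_train']
```

A column that exists in only one of the two frames keeps its original name without a suffix, so a filter on `'_shap'` would silently skip it.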

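The solution: reshape the data to long format (one row per observation and feature), so that a single px.strip call can put every feature on its own row of the y-axis.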

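import pandas as pd
import plotly.express as px

# Same sample data as in the question
shap_values = pd.DataFrame({
    "A": [-0.065704, -0.096510, 0.062368, 0.062368, 0.063093],
    "B": [-0.168249, -0.173284, -0.168756, -0.168756, -0.169378],
})

train = pd.DataFrame({
    "A": [0, 1, 1, 0, 0],
    "B": [1, 1, 0, 0, 1],
})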

# Joining SHAP values and one-hot encoded features
merged_df = shap_values.join(train, lsuffix='_shap', rsuffix='_train')

# Melt the SHAP columns to long format: one row per (observation, feature)
melted_df = merged_df.melt(
    value_vars=[col for col in merged_df.columns if col.endswith('_shap')],
    var_name='Feature',
    value_name='SHAP Value',
)
melted_df['Feature'] = melted_df['Feature'].str.replace('_shap', '', regex=False)

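# Attach each row's one-hot value with one vectorized assignment per feature
# (instead of a row-wise apply()). melt() stacks the *_shap columns block by
# block in their original row order, so each feature's block lines up with
# merged_df[feature + '_train'].
for feature in train.columns:
    feature_train = feature + '_train'
    melted_df.loc[melted_df['Feature'] == feature, 'One-hot Value'] = merged_df[feature_train].values

# String labels make the 0/1 colors explicitly categorical
melted_df['One-hot Value'] = melted_df['One-hot Value'].astype(int).astype(str)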

# Plot every feature in one figure: one row per feature, colored by its one-hot value
fig = px.strip(melted_df, x='SHAP Value', y='Feature', 
               color='One-hot Value',
               orientation='h', stripmode='overlay', 
               title='Bee Swarm Plot of SHAP Values by Feature')

fig.update_layout(xaxis_title='SHAP Value (Impact on Model Output)')
fig.show()
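
The loop above pairs SHAP and one-hot values by position, which works as long as melt() keeps the original row order. A sketch of a variant that pairs them by key instead: melt each table separately with an explicit row id, then merge on (row, feature). It also computes a feature order by mean absolute SHAP value, which is how SHAP's summary plot sorts features by default. The helper names `to_long_format` and `feature_order` and the `row` column are hypothetical; the data is the sample from the question.

```python
import pandas as pd


def to_long_format(shap_values: pd.DataFrame, train: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical helper: one row per (observation, feature) pair."""
    # Tag each observation with its original row label before melting
    shap_long = shap_values.assign(row=shap_values.index).melt(
        id_vars="row", var_name="Feature", value_name="SHAP Value")
    onehot_long = train.assign(row=train.index).melt(
        id_vars="row", var_name="Feature", value_name="One-hot Value")
    # Pair values by (row, Feature) key rather than by position
    long_df = shap_long.merge(onehot_long, on=["row", "Feature"], how="inner")
    # String labels so Plotly treats 0/1 as two discrete color groups
    long_df["One-hot Value"] = long_df["One-hot Value"].astype(str)
    return long_df


def feature_order(long_df: pd.DataFrame) -> list:
    """Hypothetical helper: features sorted by mean |SHAP|, least important first."""
    importance = long_df["SHAP Value"].abs().groupby(long_df["Feature"]).mean()
    return importance.sort_values().index.tolist()


shap_values = pd.DataFrame(
    {"A": [-0.065704, -0.096510, 0.062368, 0.062368, 0.063093],
     "B": [-0.168249, -0.173284, -0.168756, -0.168756, -0.169378]})
train = pd.DataFrame({"A": [0, 1, 1, 0, 0], "B": [1, 1, 0, 0, 1]})

long_df = to_long_format(shap_values, train)
order = feature_order(long_df)
print(order)  # ['A', 'B']
```

For the sample, mean |SHAP| is about 0.070 for A and 0.170 for B. The long table can go straight into the same `px.strip(long_df, x='SHAP Value', y='Feature', color='One-hot Value', orientation='h', stripmode='overlay')` call, and `fig.update_yaxes(categoryorder='array', categoryarray=order)` applies the order. On a categorical y-axis the first category is drawn at the bottom, so listing the least important feature first puts the most important one on top, as in SHAP's plot.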