Figure Factory Subplots - 3 or more charts - blurred together

Hi,

I am basing my work on this post, but I'm trying to do 3 plots.

The 3 plots are:
1. Annotated Heatmap (the confusion matrix)
2. ROC Curve
3. Precision Recall Curve

I can get all 3 plots into the output, but the confusion matrix blurs together with the precision-recall curve.

Here is my sample code:

```python
from plotly.subplots import make_subplots  # only used by the commented-out line below
# from Settings import Settings  # project-local module, not needed for this example

import numpy as np
import pandas as pd
import plotly.figure_factory as ff
import plotly.offline as pyo
import plotly.graph_objs as go
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_recall_curve, roc_curve, auc

y_true = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
y_pred = [0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1]

# maybe do this part in a function?

# Named cm so the sklearn confusion_matrix function is not shadowed
cm = confusion_matrix(y_true, y_pred).astype(int)

x = ['Predicted No', 'Predicted Yes']
y = ['Actual No', 'Actual Yes']

# Cell labels for the heatmap, one string per matrix entry
z_text = [[str(value) for value in row] for row in cm]

# fig = make_subplots(rows=3, cols=1, print_grid=False)

fig = ff.create_annotated_heatmap(x=x,
                                  y=y,
                                  z=cm,
                                  colorscale='Viridis',
                                  showscale=False,
                                  annotation_text=z_text)

# ROC chart

fpr, tpr, thresholds = roc_curve(y_true, y_pred)

roc_df = pd.DataFrame({'False Positive Rate': fpr,
                       'True Positive Rate': tpr},
                      index=thresholds)

roc_df.index.name = "Thresholds"
roc_df.columns.name = "Rate"

trace1 = go.Scatter(x=fpr,
                    y=tpr,
                    name='ROC Chart',
                    mode='lines',
                    xaxis='x2',
                    yaxis='y2')

# Precision-recall chart

precision, recall, thresholds = precision_recall_curve(y_true, y_pred)

trace2 = go.Scatter(x=recall,
                    y=precision,
                    mode='lines',
                    line=dict(width=2, color='navy'),
                    name='Precision-Recall curve')

fig.add_traces([trace1])

# Initialize the x/y axes for the 2nd and 3rd plots

fig['layout']['xaxis2'] = {}
fig['layout']['yaxis2'] = {}
fig['layout']['xaxis3'] = {}
fig['layout']['yaxis3'] = {}

# Edit layout for subplots

fig.layout.yaxis.update({'domain': [0, .33]})
fig.layout.yaxis.update({'title': 'Confusion Matrix'})
fig.layout.xaxis.update({'anchor': 'y'})

# Each graph's x axis MUST BE anchored to that graph's y axis

fig.layout.yaxis2.update({'domain': [0.34, 0.66]})
fig.layout.yaxis2.update({'title': 'ROC Curve'})
fig.layout.xaxis2.update({'anchor': 'y2'})

fig.add_traces([trace2])

fig.layout.yaxis3.update({'domain': [0.67, 1.]})
fig.layout.yaxis3.update({'title': 'Precision Recall Curve'})
fig.layout.xaxis3.update({'anchor': 'y3'})

fig.show()
```

Also, once I get the 3rd plot working, how would I add a 4th one?

Sadly Dash is not an option for me currently.