Hi,
I am basing my work on this post, but I am trying to produce three plots.
The three plots are:
1.) Annotated Heatmap
2.) ROC Curve
3.) Precision Recall Curve
I can get all three onto the output, but the confusion matrix overlaps (is drawn in the same area as) the precision-recall curve.
Here is my sample code:
‘’’
from plotly.subplots import make_subplots
from Settings import Settings
import numpy as np
import pandas as pd
import plotly.figure_factory as ff
import plotly.offline as pyo
import plotly.graph_objs as go
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_recall_curve, roc_curve, auc
# Builds one plotly figure holding three vertically stacked plots:
#   1. confusion matrix as an annotated heatmap  -> axes x  / y
#   2. ROC curve                                 -> axes x2 / y2
#   3. precision-recall curve                    -> axes x3 / y3
# Each trace must name the axis pair it is drawn on; a trace without
# xaxis/yaxis falls back to x / y and lands on top of the heatmap.
y_true = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
y_pred = [0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1]

# --- Plot 1: confusion matrix -------------------------------------------
# maybe do this part in a function?
# Keep the matrix under its own name: rebinding `confusion_matrix` would
# shadow the imported sklearn function and break any later call to it.
cm = confusion_matrix(y_true, y_pred).astype(int)
x = ['Predicted No', 'Predicted Yes']
y = ['Actual No', 'Actual Yes']
z_text = [[str(cell) for cell in row] for row in cm]
#fig = make_subplots(rows=3, cols=1, print_grid=False)
fig = ff.create_annotated_heatmap(x=x,
                                  y=y,
                                  z=cm,
                                  colorscale='Viridis',
                                  showscale=False,
                                  annotation_text=z_text)

# --- Plot 2: ROC chart ---------------------------------------------------
fpr, tpr, thresholds = roc_curve(y_true, y_pred)
roc_df = pd.DataFrame({'False Positive Rate': fpr,
                       'True Positive Rate': tpr},
                      index=thresholds)
roc_df.index.name = "Thresholds"
roc_df.columns.name = "Rate"
trace1 = go.Scatter(x=fpr,
                    y=tpr,
                    name='ROC Chart',
                    mode='lines',
                    xaxis='x2',
                    yaxis='y2'
                    )

# --- Plot 3: precision-recall curve --------------------------------------
precision, recall, thresholds = precision_recall_curve(y_true, y_pred)
trace2 = go.Scatter(x=recall,
                    y=precision,
                    mode='lines',
                    line=dict(width=2, color='navy'),
                    name='Precision-Recall curve',
                    # Bind the curve to the third axis pair; without these
                    # two settings it is drawn on the heatmap's x / y axes.
                    xaxis='x3',
                    yaxis='y3'
                    )

fig.add_traces([trace1, trace2])

# initialize xaxis2/yaxis2 and xaxis3/yaxis3
fig['layout']['xaxis2'] = {}
fig['layout']['yaxis2'] = {}
fig['layout']['xaxis3'] = {}
fig['layout']['yaxis3'] = {}

# Edit layout for subplots: each yaxis owns one vertical band of the figure.
fig.layout.yaxis.update({'domain': [0, .33],
                         'title': 'Confusion Matrix'})
fig.layout.xaxis.update({'anchor': 'y'})

# The graph's yaxis MUST BE anchored to the graph's xaxis
fig.layout.yaxis2.update({'domain': [0.34, 0.66],
                          'title': 'ROC Curve'})
fig.layout.xaxis2.update({'anchor': 'y2'})

fig.layout.yaxis3.update({'domain': [0.67, 1.],
                          'title': 'Precision Recall Curve'})
fig.layout.xaxis3.update({'anchor': 'y3'})

# A fourth plot follows the same recipe: give its trace xaxis='x4' and
# yaxis='y4', create layout.xaxis4 / layout.yaxis4, anchor xaxis4 to 'y4',
# and re-split the yaxis domains so that four bands fit inside [0, 1].
fig.show()
‘’’
Also, once I get the third plot working, how would I add a fourth plot?
Sadly, Dash is not an option for me currently.