Is there any way to plot such a map with Plotly? This is a decision tree drawn with plot_tree.
Thank you
@SaadKhan
Plotly can plot trees, and any other graph structure, if you provide the node positions and the list of edges. Such data are provided by graph layout algorithms. scikit-learn plots a decision tree with matplotlib, calling the function plot_tree, and uses graphviz to get the layout.
Decision trees are drawn with the Buchheim layout. The journal article by Buchheim that presents this algorithm can be downloaded here:
https://www.researchgate.net/publication/30508504_Improving_Walker’s_Algorithm_to_Run_in_Linear_Time. Unfortunately, neither Networkx nor igraph provides this layout.
When I want to plot such a tree I use Julia, not Python, because the Julia package NetworkLayout.jl implements the Buchheim layout.
So it’s not possible with Python?
@SaadKhan
The graphviz involved in plotting scikit-learn decision trees is not Python.
The only solution I see now is to implement the Buchheim algorithm yourself in Python, and then plot your decision tree with Plotly, based on the node positions returned by your code.
You can find Plotly examples of networks (in particular trees), googling, “plotly, networks”.
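As a starting point for writing your own layout, here is a minimal sketch of a much simpler scheme than Buchheim's: leaves get consecutive x positions and every parent is centered over its first and last child. The function name simple_tree_layout and the children dictionary format are assumptions made for this illustration; it does not produce the compact drawings of the Buchheim algorithm.

```python
def simple_tree_layout(children, root=0):
    """Return {node: (x, y)} for a rooted tree.

    children: dict mapping a node to the list of its child nodes
              (hypothetical input format chosen for this sketch).
    Leaves are placed at x = 0, 1, 2, ... in traversal order;
    each internal node is centered over its first and last child.
    The depth is stored as a negative y, so the root is on top.
    """
    coords = {}
    next_leaf_x = 0

    def place(node, depth):
        nonlocal next_leaf_x
        kids = children.get(node, [])
        if not kids:
            x = next_leaf_x
            next_leaf_x += 1
        else:
            xs = [place(kid, depth + 1) for kid in kids]
            x = (xs[0] + xs[-1]) / 2
        coords[node] = (x, -depth)
        return x

    place(root, 0)
    return coords


# Small example: node 0 has children 1 and 2, node 1 has children 3 and 4.
layout = simple_tree_layout({0: [1, 2], 1: [3, 4]})
print(layout[0])  # (1.25, 0)
```

The returned dictionary can be turned into the coords array used later in this thread by listing the positions in node order.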
Thank you for your replies, I tried, but it’s making my program worse, I thank you for your time
@SaadKhan
A layout algorithm returns an array coords of shape (n,2) that records the coordinates of the tree nodes.
If E is the list of edges, represented as tuples (i,j), with i and j indicating the end nodes of an edge, then you need these three functions to plot the tree:
import plotly.graph_objects as go

def get_plotly_data(E, coords):
    # E is the list of tuples representing the graph edges
    # coords is the list of node coordinates
    N = len(coords)
    Xnodes = [coords[k][0] for k in range(N)]  # x-coordinates of nodes
    Ynodes = [coords[k][1] for k in range(N)]  # y-coordinates of nodes
    Xedges = []
    Yedges = []
    for e in E:
        # None separates consecutive segments within a single trace
        Xedges.extend([coords[e[0]][0], coords[e[1]][0], None])
        Yedges.extend([coords[e[0]][1], coords[e[1]][1], None])
    return Xnodes, Ynodes, Xedges, Yedges
def get_node_trace(x, y, labels, marker_size=5, marker_color='#6959CD',
                   line_color='rgb(50,50,50)', line_width=0.5):
    return go.Scatter(
        x=x,
        y=y,
        mode='markers',
        marker=dict(
            size=marker_size,
            color=marker_color,
            line=dict(color=line_color, width=line_width)
        ),
        text=labels,
        hoverinfo='text'
    )
def get_edge_trace(x, y, line_color='rgb(210,210,210)', line_width=1):
    return go.Scatter(
        x=x,
        y=y,
        mode='lines',
        line_color=line_color,
        line_width=line_width,
        hoverinfo='none'
    )
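To see why get_plotly_data appends None after every edge, here is a tiny example with made-up coordinates for a root and two children. The None entries break the line, so a single 'lines' trace can draw many separate segments instead of one connected polyline.

```python
# Hypothetical three-node tree: root 0 with children 1 and 2.
coords = [(0.0, 1.0), (-1.0, 0.0), (1.0, 0.0)]
E = [(0, 1), (0, 2)]

Xedges, Yedges = [], []
for i, j in E:
    # start point, end point, then None to end this segment
    Xedges += [coords[i][0], coords[j][0], None]
    Yedges += [coords[i][1], coords[j][1], None]

print(Xedges)  # [0.0, -1.0, None, 0.0, 1.0, None]
print(Yedges)  # [1.0, 0.0, None, 1.0, 0.0, None]
```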
If the root is not placed at the origin of the axes, i.e. at the point (0,0), but at some point (x0,y0), then you should translate it to the origin, so that all node positions are mapped to the new coords = coords - np.array([x0, y0]). To ensure that the rectangular boxes placed at the new node positions do not overlap, you can apply a scaling transformation to the translated coords, i.e. multiply each x coordinate by a factor a>0 and preserve the y coordinate.
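The translation and horizontal scaling described above can be sketched with NumPy as follows. The sample coordinates and the factor a = 2.5 are made up for illustration, and the root is assumed to be the first row of coords.

```python
import numpy as np

# Hypothetical layout: the root is the first row, located at (2, 3).
coords = np.array([[2.0, 3.0],
                   [1.0, 2.0],
                   [3.0, 2.0]])

x0, y0 = coords[0]                       # root position
shifted = coords - np.array([x0, y0])    # root moves to (0, 0)

a = 2.5                                  # horizontal stretch factor, a > 0
scaled = shifted * np.array([a, 1.0])    # stretch x, keep y unchanged

print(scaled)
```

Increasing a spreads the nodes horizontally without changing the levels of the tree.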
How can E and coords be extracted from plot_tree(model)? Thank you.
my code is:
model = DecisionTreeClassifier(some parameters)
model.fit(X_train, y_train)
plot_tree(model);
it creates that figure shown in the picture
Run help(plot_tree) to see its arguments and what it returns.
It returns a list of annotations:
[Text(0.5, 0.9166666666666666, 'X[2] <= 206.5\ngini = 0.633\nsamples = 233\nvalue = [105, 45, 83]'),
Text(0.2727272727272727, 0.75, 'X[0] <= 42.35\ngini = 0.418\nsamples = 147\nvalue = [104, 42, 1]'),
Text(0.18181818181818182, 0.5833333333333334, 'gini = 0.0\nsamples = 97\nvalue = [97, 0, 0]'),
Text(0.36363636363636365, 0.5833333333333334, 'X[4] <= 0.5\ngini = 0.274\nsamples = 50\nvalue = [7, 42, 1]'),
Text(0.18181818181818182, 0.4166666666666667, 'X[1] <= 16.0\ngini = 0.245\nsamples = 7\nvalue = [6, 0, 1]'),
Text(0.09090909090909091, 0.25, 'gini = 0.0\nsamples = 1\nvalue = [0, 0, 1]'),
Text(0.2727272727272727, 0.25, 'gini = 0.0\nsamples = 6\nvalue = [6, 0, 0]'),
Text(0.5454545454545454, 0.4166666666666667, 'X[3] <= 4075.0\ngini = 0.045\nsamples = 43\nvalue = [1, 42, 0]'),
Text(0.45454545454545453, 0.25, 'gini = 0.0\nsamples = 38\nvalue = [0, 38, 0]'),
Text(0.6363636363636364, 0.25, 'X[2] <= 196.5\ngini = 0.32\nsamples = 5\nvalue = [1, 4, 0]'),
Text(0.5454545454545454, 0.08333333333333333, 'gini = 0.0\nsamples = 1\nvalue = [1, 0, 0]'),
Text(0.7272727272727273, 0.08333333333333333, 'gini = 0.0\nsamples = 4\nvalue = [0, 4, 0]'),
Text(0.7272727272727273, 0.75, 'X[1] <= 17.65\ngini = 0.09\nsamples = 86\nvalue = [1, 3, 82]'),
Text(0.6363636363636364, 0.5833333333333334, 'gini = 0.0\nsamples = 82\nvalue = [0, 0, 82]'),
Text(0.8181818181818182, 0.5833333333333334, 'X[4] <= 0.5\ngini = 0.375\nsamples = 4\nvalue = [1, 3, 0]'),
Text(0.7272727272727273, 0.4166666666666667, 'gini = 0.0\nsamples = 1\nvalue = [1, 0, 0]'),
Text(0.9090909090909091, 0.4166666666666667, 'gini = 0.0\nsamples = 3\nvalue = [0, 3, 0]')]
Just now I’m out, and cannot check if I’m right or not. I think that the first two numbers in Text are the coordinates of the node i, where i appears as index in X. The rest of text is the usual text displayed in the corresponding box.
And what should E be? Thank you.
I plotted the points of coordinates given as the first two values in each Text, and indeed they are the coordinates of the tree nodes. Moreover assigning to each pair of coordinates its index in the list you pasted here, I was able to deduce that they are listed in preorder, i.e. the tree was traversed by the preorder method. This information could help to get the list of edges, if you have some knowledge on tree traversal.
You need a function that extracts the list of edges from the array of coordinates. This plot could help, if you assign to each point the index of its coordinates in the array coords:
import plotly.graph_objects as go
import numpy as np
coords= np.array([[0.5, 0.9166666666666666],
[0.2727272727272727, 0.75],
[0.18181818181818182, 0.5833333333333334],
[0.36363636363636365, 0.5833333333333334],
[0.18181818181818182, 0.4166666666666667],
[0.09090909090909091, 0.25],
[0.2727272727272727, 0.25],
[0.5454545454545454, 0.4166666666666667],
[0.45454545454545453, 0.25],
[0.6363636363636364, 0.25],
[0.5454545454545454, 0.08333333333333333],
[0.7272727272727273, 0.08333333333333333],
[0.7272727272727273, 0.75],
[0.6363636363636364, 0.5833333333333334],
[0.8181818181818182, 0.5833333333333334],
[0.7272727272727273, 0.4166666666666667],
[0.9090909090909091, 0.4166666666666667]])
Xn, Yn= coords.T
fig=go.Figure(go.Scatter(x=Xn, y=Yn, mode="markers", marker_size=15))
fig.update_layout(width=750, height=550)
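The preorder observation can be turned into a small function that recovers the edges from the coordinates alone. This is only a sketch under two assumptions: the nodes are listed in preorder, and the root has the largest y with every child drawn strictly below its parent. The name edges_from_preorder and the tolerance tol are introduced here for illustration.

```python
def edges_from_preorder(coords, tol=1e-9):
    """Recover (parent, child) edges from node positions listed in preorder.

    Assumes the root is at the top (largest y) and each child lies
    strictly below its parent.
    """
    edges = []
    stack = []  # indices on the path from the root to the current node
    for i, (_, y) in enumerate(coords):
        # Climb back up until the top of the stack lies strictly above node i.
        while stack and coords[stack[-1]][1] <= y + tol:
            stack.pop()
        if stack:
            edges.append((stack[-1], i))
        stack.append(i)
    return edges


# Hypothetical five-node example: root 0, inner node 1 with leaves 2 and 3, leaf 4.
sample = [(0.5, 0.9), (0.25, 0.5), (0.1, 0.1), (0.4, 0.1), (0.75, 0.5)]
print(edges_from_preorder(sample))  # [(0, 1), (1, 2), (1, 3), (0, 4)]
```

The stack always holds the ancestors of the current node, so the parent of a node is the nearest earlier node that sits on a higher level.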
Brilliant idea, fantastic! I extracted the coordinates and the text from plot_tree(model) with:
coords = []  # empty list
texts = []  # empty list
for item in plot_tree(model):
    coords.append(list(item.get_position()))
    texts.append(item.get_text())
Now it would be absolutely great to convert the circles of the scatter into rectangles and put the text inside.
@empet I just want to add the lines as shown in the original picture, as they are required; the rest is done.
I used the functions you provided and this is what I am getting. I have been trying for a very long time now.
coords = []
texts = []
for item in plot_tree(model):
    coords.append(list(item.get_position()))
    texts.append(item.get_text())
coords = np.array(coords)  # .T below needs an array, not a list
Xn, Yn = coords.T
fig = go.Figure()
fig.add_trace(go.Scatter(x=Xnodes, y=Ynodes, mode="markers+text", marker_size=15, text=texts, textposition='top center'))
#++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
fig.add_trace(go.Scatter(x=Xedges,
                         y=Yedges,
                         mode='lines',
                         line_color='blue',
                         line_width=5,
                         hoverinfo='none'
                         ))
fig.update_layout(width=750, height=550)
Any way to solve this?
# Edges as (parent, child) pairs; node indices follow the preorder listing of coords
E = [(0,1), (1,2), (1,3), (3,4), (4,5), (4,6), (3,7), (7,8), (7,9), (9,10), (9,11),
     (0,12), (12,13), (12,14), (14,15), (14,16)]
labels = np.arange(17)
Xnodes, Ynodes, Xedges, Yedges = get_plotly_data(E, coords)
nodes = get_node_trace(Xnodes, Ynodes, labels, marker_size=18, marker_color='#c0c0c0')
edges = get_edge_trace(Xedges, Yedges)
fig = go.Figure([edges, nodes])
fig.update_layout(title_text="Decision tree",
                  title_x=0.5,
                  font_size=12,
                  showlegend=False,
                  width=800,
                  height=600,
                  xaxis_visible=False,
                  yaxis_visible=False,
                  template='none',
                  hovermode='closest',
                  paper_bgcolor='#eeeeee')
In get_node_trace() I replaced mode='markers' by mode='markers+text', and commented out hoverinfo='text', in order to ensure that node indices, i.e. indices of their coordinates in the array coords, are displayed on the plot…
Good luck, and have a nice weekend!!!
Absolutely brilliant, thank you.
It is just me who does not understand how to get this E list, though you explained it very well. Thank you for your time.
Hi there!
Thanks to both of you for your very helpful discussion.
To complete the objective, here is a hack to retrieve edges from the tree:
def extract_edges(clf):
    edges = []
    # children_left / children_right hold -1 for leaf nodes
    for i, (x, y) in enumerate(
        zip(clf.tree_.children_left, clf.tree_.children_right)
    ):
        if x != -1:
            edges.append((i, x))
        if y != -1:
            edges.append((i, y))
    return edges
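To try the same idea without a fitted model, here is a sketch that works directly on two arrays shaped like clf.tree_.children_left and clf.tree_.children_right, where -1 marks a leaf. The arrays below are hand-written for the 17-node tree discussed in this thread (an assumption, not output copied from scikit-learn), and the helper name edges_from_children is hypothetical. It also assumes that the node numbering of tree_ matches the order of the annotations returned by plot_tree, which is worth checking on your own model.

```python
import numpy as np


def edges_from_children(children_left, children_right):
    """Build sorted (parent, child) pairs from two child-index arrays.

    A value of -1 means "no child", as in scikit-learn's tree_ arrays.
    """
    left = np.asarray(children_left)
    right = np.asarray(children_right)
    edges = []
    for children in (left, right):
        parents = np.flatnonzero(children != -1)
        # Convert to plain Python ints so the pairs can index lists directly.
        edges.extend((int(p), int(children[p])) for p in parents)
    return sorted(edges)


# Hand-written arrays for the tree in this thread (hypothetical data).
children_left = [1, 2, -1, 4, 5, -1, -1, 8, -1, 10, -1, -1, 13, -1, 15, -1, -1]
children_right = [12, 3, -1, 7, 6, -1, -1, 9, -1, 11, -1, -1, 14, -1, 16, -1, -1]

E = edges_from_children(children_left, children_right)
print(len(E))  # 16
```

A tree with n nodes always has n - 1 edges, which gives a quick sanity check on the result.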
I am glad someone resurfaced this item.
I have now had the pleasure of playing with Plotly Tree Plots and also Dash and mermaid
TreePlots are a lot less complicated than they look.
Documentation is here
Created from the example code
mermaid is what got me excited. Docs here
I can feel an article coming on.
The following code created the flowchart below. It is boilerplate code, but I will elaborate on it in a later article. All I would say is look at the other Dash modules; you may be surprised what you find.
from dash_extensions.enrich import DashProxy
from dash_extensions import Mermaid
app = DashProxy()
app.layout = Mermaid(chart="""
flowchart TD
Start --> SampleA
Start --> SampleB
Start --> SampleC
Start --> SampleD
Start --> SampleE
Start --> SampleF
Start --> SampleG
SampleA & SampleB --> Experiment1
SampleB & SampleC --> Experiment2
SampleC & SampleD --> Experiment3
SampleD & SampleE --> Experiment4
SampleE & SampleF --> Experiment5
SampleF & SampleG --> Experiment6
Experiment1 --> Result1
Experiment2 --> Result2
Experiment3 --> Result3
Experiment4 --> Result4
Experiment5 --> Result5
Experiment6 --> Result6
Result1 & Result2 --> Experiment7
Result3 & Result4 --> Experiment8
Result5 & Result6 --> Experiment9
Experiment7 --> Result7
Experiment8 --> Result8
Experiment9 --> Result9
Result7 & Result8 --> Experiment10
Result8 & Result9 --> Experiment11
Experiment10 --> Result10
Experiment11 --> Result11
Result10 & Result11 --> Experiment12
Experiment12 --> Result13
Result13 --> FinalOutcome
""")
if __name__ == "__main__":
    app.run_server()
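To connect this with the decision tree discussed earlier in the thread, a list of (parent, child) edges can be turned into Mermaid flowchart text with a few lines of plain Python. This is a sketch only: the function name edges_to_mermaid, the n0, n1, ... node ids, and the quoting of labels are choices made for this illustration, and labels with special characters may need extra escaping in Mermaid.

```python
def edges_to_mermaid(edges, labels=None):
    """Return Mermaid 'flowchart TD' text for a list of (parent, child) edges.

    labels: optional dict {node_index: text} (hypothetical format);
            double quotes inside a label are replaced by single quotes.
    """
    lines = ["flowchart TD"]
    if labels:
        for node, text in sorted(labels.items()):
            safe = text.replace('"', "'")
            lines.append(f'    n{node}["{safe}"]')
    for parent, child in edges:
        lines.append(f"    n{parent} --> n{child}")
    return "\n".join(lines)


chart = edges_to_mermaid(
    [(0, 1), (0, 2)],
    labels={0: "root", 1: "left leaf", 2: "right leaf"},
)
print(chart)
```

The resulting string can be passed as the chart argument of the Mermaid component shown above.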
Hopefully, someone looking for info finds this helpful.