Decision tree plot with plot_tree


Is there any way to plot such a tree with Plotly? This is a decision tree drawn with plot_tree.
Thank you

@SaadKhan
Plotly can plot trees, and any other graph structure, if you provide the node positions and the list of edges. Such data are provided by graph layout algorithms. scikit-learn plots a decision tree with matplotlib, calling the function plot_tree, and uses graphviz to get the layout.
Decision trees are drawn with the Buchheim layout. The journal article by Buchheim that presents this algorithm can be downloaded here:
https://www.researchgate.net/publication/30508504_Improving_Walker’s_Algorithm_to_Run_in_Linear_Time. Unfortunately, neither NetworkX nor igraph provides this layout.
When I want to plot such a tree I use Julia, not Python, because the Julia package NetworkLayout.jl implements the Buchheim layout.

So it’s not possible with Python?

@SaadKhan
graphviz, which is involved in plotting scikit-learn decision trees, is not a Python library.
The only solution I see for now is to implement the Buchheim algorithm yourself in Python, and then plot your decision tree with Plotly, using the node positions returned by your code.
You can find Plotly examples of networks (in particular trees) by googling "plotly networks".
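If a full Buchheim implementation is too much work, a much simpler layout often gives a readable picture for small trees. The sketch below is NOT the Buchheim algorithm: it places the leaves at consecutive x positions from left to right and centers each parent above its outermost children, with y = -depth so the root is on top. The `children` dictionary format and the name `simple_tree_layout` are hypothetical choices for illustration; on unbalanced trees the result can be wider than a Buchheim layout.

```python
def simple_tree_layout(children, root=0):
    """Naive tree layout (NOT Buchheim).

    children: hypothetical dict mapping node -> list of its children, left to right.
    Returns a dict node -> (x, y), with y = -depth so the root is on top.
    """
    pos = {}
    next_leaf_x = 0

    def place(node, depth):
        nonlocal next_leaf_x
        kids = children.get(node, [])
        if not kids:
            x = next_leaf_x            # leaves get consecutive x positions
            next_leaf_x += 1
        else:
            xs = [place(k, depth + 1) for k in kids]
            x = (xs[0] + xs[-1]) / 2   # center the parent above its outermost children
        pos[node] = (x, -depth)
        return x

    place(root, 0)
    return pos


if __name__ == "__main__":
    # node 0 has children 1 and 2; node 1 has children 3 and 4
    print(simple_tree_layout({0: [1, 2], 1: [3, 4]}))
```

The returned positions can be stacked into an (n,2) coords array and plotted with Plotly like any other layout.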

Thank you for your replies. I tried, but it's making my program worse. Thanks for your time 🙂

@SaadKhan
A layout algorithm returns an array coords of shape (n,2) that records the coordinates of the tree nodes.
If E is the list of edges, represented as tuples (i,j), where i and j are the indices of the end nodes of an edge, then you need these three functions to plot the tree:

import plotly.graph_objects as go

def get_plotly_data(E, coords):
    # E is the list of tuples representing the graph edges
    # coords is the list (or array) of node coordinates
    N = len(coords)
    Xnodes = [coords[k][0] for k in range(N)]  # x-coordinates of nodes
    Ynodes = [coords[k][1] for k in range(N)]  # y-coordinates of nodes

    Xedges = []
    Yedges = []
    for e in E:
        # segment from node e[0] to node e[1]; None breaks the line before the next edge
        Xedges.extend([coords[e[0]][0], coords[e[1]][0], None])
        Yedges.extend([coords[e[0]][1], coords[e[1]][1], None])

    return Xnodes, Ynodes, Xedges, Yedges

def get_node_trace(x, y, labels, marker_size=5, marker_color='#6959CD',
                   line_color='rgb(50,50,50)', line_width=0.5):
    return go.Scatter(
        x=x,
        y=y,
        mode='markers',
        marker=dict(
            size=marker_size,
            color=marker_color,
            line=dict(color=line_color, width=line_width)
        ),
        text=labels,
        hoverinfo='text'
    )

def get_edge_trace(x, y, line_color='rgb(210,210,210)', line_width=1):
    return go.Scatter(
        x=x,
        y=y,
        mode='lines',
        line_color=line_color,
        line_width=line_width,
        hoverinfo='none'
    )

If the root is not placed at the origin of the axes, i.e. at the point with coordinates (0,0), but at some point (x0,y0), then you should translate it to the origin: all node positions are mapped to the new coords = coords - np.array([x0, y0]). To ensure that the rectangular boxes placed at the new node positions do not overlap, you can then apply a scaling transformation to these coords, i.e. multiply each x-coordinate by a factor a>0 and keep the y-coordinate unchanged.
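The translation and scaling just described take only a couple of NumPy lines. A minimal sketch, assuming the root's coordinates are in row 0 of `coords` (as in a preorder listing) and that `a` is a stretch factor chosen by eye; `normalize_coords` is a hypothetical helper name.

```python
import numpy as np

def normalize_coords(coords, root=0, a=1.5):
    """Move the root to (0, 0), then stretch the x-axis by a > 0 (hypothetical helper)."""
    coords = np.asarray(coords, dtype=float)
    x0, y0 = coords[root]
    shifted = coords - np.array([x0, y0])   # translation: root goes to the origin
    return shifted * np.array([a, 1.0])     # scaling: multiply x by a, keep y
```

For example, `normalize_coords([[0.5, 0.9], [0.25, 0.75], [0.75, 0.75]], a=1.5)` puts the root at (0, 0) and its two children at x = -0.375 and x = 0.375, both slightly below the root.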


How can E and coords be extracted from plot_tree(model)? Thank you

My code is:
from sklearn.tree import DecisionTreeClassifier, plot_tree

model = DecisionTreeClassifier()  # with some parameters
model.fit(X_train, y_train)
plot_tree(model);
It creates the figure shown in the picture.

Print

help(plot_tree)

to see its arguments and what it returns.

It returns a list of annotations:
[Text(0.5, 0.9166666666666666, 'X[2] <= 206.5\ngini = 0.633\nsamples = 233\nvalue = [105, 45, 83]'),
Text(0.2727272727272727, 0.75, 'X[0] <= 42.35\ngini = 0.418\nsamples = 147\nvalue = [104, 42, 1]'),
Text(0.18181818181818182, 0.5833333333333334, 'gini = 0.0\nsamples = 97\nvalue = [97, 0, 0]'),
Text(0.36363636363636365, 0.5833333333333334, 'X[4] <= 0.5\ngini = 0.274\nsamples = 50\nvalue = [7, 42, 1]'),
Text(0.18181818181818182, 0.4166666666666667, 'X[1] <= 16.0\ngini = 0.245\nsamples = 7\nvalue = [6, 0, 1]'),
Text(0.09090909090909091, 0.25, 'gini = 0.0\nsamples = 1\nvalue = [0, 0, 1]'),
Text(0.2727272727272727, 0.25, 'gini = 0.0\nsamples = 6\nvalue = [6, 0, 0]'),
Text(0.5454545454545454, 0.4166666666666667, 'X[3] <= 4075.0\ngini = 0.045\nsamples = 43\nvalue = [1, 42, 0]'),
Text(0.45454545454545453, 0.25, 'gini = 0.0\nsamples = 38\nvalue = [0, 38, 0]'),
Text(0.6363636363636364, 0.25, 'X[2] <= 196.5\ngini = 0.32\nsamples = 5\nvalue = [1, 4, 0]'),
Text(0.5454545454545454, 0.08333333333333333, 'gini = 0.0\nsamples = 1\nvalue = [1, 0, 0]'),
Text(0.7272727272727273, 0.08333333333333333, 'gini = 0.0\nsamples = 4\nvalue = [0, 4, 0]'),
Text(0.7272727272727273, 0.75, 'X[1] <= 17.65\ngini = 0.09\nsamples = 86\nvalue = [1, 3, 82]'),
Text(0.6363636363636364, 0.5833333333333334, 'gini = 0.0\nsamples = 82\nvalue = [0, 0, 82]'),
Text(0.8181818181818182, 0.5833333333333334, 'X[4] <= 0.5\ngini = 0.375\nsamples = 4\nvalue = [1, 3, 0]'),
Text(0.7272727272727273, 0.4166666666666667, 'gini = 0.0\nsamples = 1\nvalue = [1, 0, 0]'),
Text(0.9090909090909091, 0.4166666666666667, 'gini = 0.0\nsamples = 3\nvalue = [0, 3, 0]')]

I'm out right now and cannot check whether I'm right. I think that the first two numbers in each Text are the coordinates of node i, where i appears as an index in X. The rest of the text is the usual text displayed in the corresponding box.

And what should E be? Thank you

@SaadKhan

I plotted the points whose coordinates are the first two values in each Text, and indeed they are the coordinates of the tree nodes. Moreover, by assigning to each pair of coordinates its index in the list you pasted here, I was able to deduce that the nodes are listed in preorder, i.e. the tree was traversed in preorder. This information can help you get the list of edges, if you have some knowledge of tree traversal.
You need a function that extracts the list of edges from the array of coordinates. This plot could help, if you assign to each point the index of its coordinates in the array coords:

import plotly.graph_objects as go
import numpy as np

coords= np.array([[0.5, 0.9166666666666666],
[0.2727272727272727, 0.75],
[0.18181818181818182, 0.5833333333333334],
[0.36363636363636365, 0.5833333333333334],
[0.18181818181818182, 0.4166666666666667],
[0.09090909090909091, 0.25],
[0.2727272727272727, 0.25],
[0.5454545454545454, 0.4166666666666667],
[0.45454545454545453, 0.25],
[0.6363636363636364, 0.25],
[0.5454545454545454, 0.08333333333333333],
[0.7272727272727273, 0.08333333333333333],
[0.7272727272727273, 0.75],
[0.6363636363636364, 0.5833333333333334],
[0.8181818181818182, 0.5833333333333334],
[0.7272727272727273, 0.4166666666666667],
[0.9090909090909091, 0.4166666666666667]])
Xn, Yn= coords.T
fig=go.Figure(go.Scatter(x=Xn, y=Yn, mode="markers", marker_size=15))
fig.update_layout(width=750, height=550)
fig.show()
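One way to turn the preorder observation into code is sketched below. It assumes the nodes are listed in preorder (each parent before its subtree, left subtree before right) and that all nodes on the same level share exactly the same y-coordinate, with larger y meaning closer to the root, as in the plot_tree output above. Walking the nodes in order while keeping a stack of the current ancestors, the parent of each node is the nearest ancestor lying strictly above it. `edges_from_preorder` is a hypothetical name, not a library function; the coordinates are the ones listed above.

```python
import numpy as np

def edges_from_preorder(coords, tol=1e-9):
    """Deduce tree edges from node coordinates listed in preorder.

    Assumes row 0 is the root, rows are in preorder, and a larger y
    means a level closer to the root.
    """
    y = np.asarray(coords, dtype=float)[:, 1]
    edges = []
    stack = [0]  # path from the root to the previous node
    for k in range(1, len(y)):
        # drop nodes that are not strictly above node k (same level or deeper)
        while y[stack[-1]] <= y[k] + tol:
            stack.pop()
        edges.append((stack[-1], k))  # the remaining top is k's parent
        stack.append(k)
    return edges

coords = np.array([[0.5, 0.9166666666666666],
                   [0.2727272727272727, 0.75],
                   [0.18181818181818182, 0.5833333333333334],
                   [0.36363636363636365, 0.5833333333333334],
                   [0.18181818181818182, 0.4166666666666667],
                   [0.09090909090909091, 0.25],
                   [0.2727272727272727, 0.25],
                   [0.5454545454545454, 0.4166666666666667],
                   [0.45454545454545453, 0.25],
                   [0.6363636363636364, 0.25],
                   [0.5454545454545454, 0.08333333333333333],
                   [0.7272727272727273, 0.08333333333333333],
                   [0.7272727272727273, 0.75],
                   [0.6363636363636364, 0.5833333333333334],
                   [0.8181818181818182, 0.5833333333333334],
                   [0.7272727272727273, 0.4166666666666667],
                   [0.9090909090909091, 0.4166666666666667]])

E = edges_from_preorder(coords)
print(E)
```

The resulting E can be passed to get_plotly_data(E, coords) to obtain the edge segments.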

Brilliant idea, fantastic! I extracted the coordinates and the texts from plot_tree(model) with:
coords = []  # empty list
texts = []  # empty list
for item in plot_tree(model):
    coords.append(list(item.get_position()))
    texts.append(item.get_text())

The only thing left is to turn the circles of the scatter into rectangles with the text inside; that would be absolutely great.

@empet I just want to add the lines shown in the original picture, since they are required; the rest is done.
I used the functions you provided, but this is what I get. I have been trying for a long time now:

coords = []
texts = []
for item in plot_tree(model):
    coords.append(list(item.get_position()))
    texts.append(item.get_text())
Xn, Yn = np.array(coords).T

fig = go.Figure()
fig.add_trace(go.Scatter(x=Xn, y=Yn, mode="markers+text", marker_size=15, text=texts, textposition='top center'))

#++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Xedges, Yedges are not defined yet: they require the edge list E
fig.add_trace(go.Scatter(x=Xedges,
                         y=Yedges,
                         mode='lines',
                         line_color='blue',
                         line_width=5,
                         hoverinfo='none'
                         ))
fig.update_layout(width=750, height=550)

Any way to solve this?

@SaadKhan

E = [(0,1), (1,2), (1,3), (3,4), (4,5), (4,6), (3,7), (7,8), (7,9), (9,10), (9,11),
     (0,12), (12,13), (12,14), (14,15), (14,16)]
labels= np.arange(17)
Xnodes, Ynodes, Xedges, Yedges = get_plotly_data(E, coords)
nodes= get_node_trace(Xnodes, Ynodes, labels, marker_size=18, marker_color='#c0c0c0')
edges= get_edge_trace(Xedges, Yedges)
fig= go.Figure([edges, nodes])
fig.update_layout(title_text="Decision tree",
            title_x=0.5,
            font_size=12,
            showlegend=False,
            width=800,
            height=600,
            xaxis_visible=False,
            yaxis_visible=False, 
            template='none',     
            hovermode='closest',
            paper_bgcolor='#eeeeee')
fig.show()

In get_node_trace() I replaced mode='markers' by mode='markers+text' and commented out
hoverinfo='text', so that the node indices, i.e. the indices of their coordinates in the array
coords, are displayed on the plot.
Good luck, and have a nice weekend!!!
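The edge list can also be read directly from the fitted model instead of being deduced by hand. scikit-learn stores the tree structure in `model.tree_`: `children_left[i]` and `children_right[i]` hold the children of node `i`, and both are `-1` for a leaf. The sketch below renumbers the nodes in preorder (left child before right) as a safeguard in case the internal node ids are not already in that order, assuming, as observed above, that this is the order in which plot_tree returns its boxes, and that the plot was not truncated with plot_tree's `max_depth`. `edges_from_model` is a hypothetical helper, and the iris dataset is only there to have something to fit.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier


def edges_from_model(model):
    """Edge list (i, j) of a fitted sklearn tree, nodes numbered in preorder."""
    tree = model.tree_
    left, right = tree.children_left, tree.children_right
    preorder = {}   # sklearn node id -> position in preorder
    stack = [0]     # node 0 is the root
    while stack:
        node = int(stack.pop())
        preorder[node] = len(preorder)
        # push the right child first so the left child is visited first
        for child in (right[node], left[node]):
            if child != -1:   # -1 marks a missing child (leaf)
                stack.append(int(child))
    edges = []
    for node in range(tree.node_count):
        for child in (left[node], right[node]):
            if child != -1:
                edges.append((preorder[node], preorder[int(child)]))
    return sorted(edges)


X, y = load_iris(return_X_y=True)
model = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
E = edges_from_model(model)
print(E)
```

Together with the coords collected from plot_tree(model), this E can go straight into get_plotly_data(E, coords).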


Absolutely brilliant, thank you!
It's just me who doesn't understand how to get this E list, though you explained it very well. Thank you for your time.
