Decision Tree plot plot_tree


Is there any way to plot such a map with Plotly? This is a decision tree drawn with plot_tree.
Thank you


@SaadKhan
Plotly can plot trees, and any other graph structure, if you provide the node positions and the list of edges. Such data are provided by graph layout algorithms. scikit-learn plots a decision tree with matplotlib, by calling the function plot_tree, and uses graphviz to get the layout.
Decision trees are drawn with the Buchheim layout. The journal article by Buchheim that presents this algorithm can be downloaded here:
https://www.researchgate.net/publication/30508504_Improving_Walker’s_Algorithm_to_Run_in_Linear_Time. Unfortunately, neither NetworkX nor igraph provides this layout.
When I want to plot such a tree I use Julia, not Python, because the Julia package NetworkLayout.jl implements the Buchheim layout.

So it’s not possible with Python?

@SaadKhan
Graphviz, which is involved in plotting scikit-learn decision trees, is not a Python library.
The only solution I see now is to implement the Buchheim algorithm yourself in Python, and to plot your decision tree with Plotly, based on the node positions returned by your code.
You can find Plotly examples of networks (trees in particular) by searching for "plotly networks".
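As a rough illustration of implementing a layout yourself, here is a minimal sketch of a tree layout that is much simpler than Buchheim's. It is a swapped-in technique, not the Buchheim algorithm: leaves are placed at consecutive x positions from left to right, every internal node is centred over its first and last child, and y is minus the depth. The function name simple_tree_layout and the children-dictionary input are hypothetical choices for this example. The drawing can come out wider than a Buchheim layout, but no two nodes on the same level share an x position.

```python
def simple_tree_layout(children, root=0):
    """Return {node: (x, y)} for a rooted tree (a simple layout, NOT Buchheim).

    children maps each node to the list of its children; leaves may be
    missing from the dict or map to []. Leaves get x = 0, 1, 2, ... in
    depth-first, left-to-right order; each internal node is centred between
    its first and last child; y is minus the depth, so the root is at y = 0.
    """
    coords = {}
    next_x = 0  # next free x slot for a leaf

    def place(node, depth):
        nonlocal next_x
        kids = children.get(node, [])
        if not kids:
            coords[node] = (float(next_x), -float(depth))
            next_x += 1
        else:
            for kid in kids:
                place(kid, depth + 1)  # lay out the subtrees first
            x = (coords[kids[0]][0] + coords[kids[-1]][0]) / 2
            coords[node] = (x, -float(depth))

    place(root, 0)
    return coords
```

For example, simple_tree_layout({0: [1, 2], 1: [3, 4]}) puts the leaves 3, 4 and 2 at x = 0, 1, 2, node 1 at (0.5, -1) and the root at (1.25, 0); these positions can then be passed to Plotly as node coordinates.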

Thank you for your replies. I tried, but it's making my program worse. Thanks for your time 🙂

@SaadKhan
A layout algorithm returns an array coords of shape (n, 2) that records the coordinates of the tree nodes.
If E is the list of edges, represented as tuples (i, j), where i and j are the indices of the two end nodes of an edge, then you need these three functions to plot the tree:

import plotly.graph_objects as go

def get_plotly_data(E, coords):
    # E is the list of tuples representing the graph edges
    # coords is the list of node coordinates
    N = len(coords)
    Xnodes = [coords[k][0] for k in range(N)]  # x-coordinates of nodes
    Ynodes = [coords[k][1] for k in range(N)]  # y-coordinates of nodes

    # each edge contributes its two end points followed by None,
    # so that Plotly breaks the line between consecutive edges
    Xedges = []
    Yedges = []
    for e in E:
        Xedges.extend([coords[e[0]][0], coords[e[1]][0], None])
        Yedges.extend([coords[e[0]][1], coords[e[1]][1], None])

    return Xnodes, Ynodes, Xedges, Yedges

def get_node_trace(x, y, labels, marker_size=5, marker_color='#6959CD',
                   line_color='rgb(50,50,50)', line_width=0.5):
    return go.Scatter(
        x=x,
        y=y,
        mode='markers',
        marker=dict(
            size=marker_size,
            color=marker_color,
            line=dict(color=line_color, width=line_width)
        ),
        text=labels,
        hoverinfo='text'
    )

def get_edge_trace(x, y, line_color='rgb(210,210,210)', line_width=1):
    return go.Scatter(
        x=x,
        y=y,
        mode='lines',
        line_color=line_color,
        line_width=line_width,
        hoverinfo='none'
    )

If the root is not placed at the origin of the axes, i.e. at the point of coordinates (0, 0), but at some point (x0, y0), then you should translate it to the origin, and all node positions are mapped to the new coords = coords - np.array([x0, y0]). To ensure that the rectangular boxes placed at the new node positions do not overlap, you can apply a scaling transformation to the translated coords, i.e. multiply each x coordinate by a factor a > 0 and keep the y coordinate unchanged.
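The translation and scaling just described can be written with NumPy as below. This is a sketch: the function name normalize_coords is hypothetical, and it assumes the root is the node stored at row 0 of coords.

```python
import numpy as np

def normalize_coords(coords, root=0, a=1.0):
    """Translate so the root sits at (0, 0), then stretch x by a factor a > 0.

    coords is an (n, 2) array-like of node positions; root is the row of the
    root node (assumed to be 0 here).
    """
    if a <= 0:
        raise ValueError("scaling factor a must be positive")
    coords = np.asarray(coords, dtype=float)
    x0, y0 = coords[root]
    shifted = coords - np.array([x0, y0])   # root moves to the origin
    return shifted * np.array([a, 1.0])     # scale x only, keep y unchanged
```

Choosing a is a matter of trial and error: larger values spread the nodes of each level further apart horizontally without changing the vertical spacing.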


How can E and coords be extracted from plot_tree(model)? Thank you.

My code is:
from sklearn.tree import DecisionTreeClassifier, plot_tree

model = DecisionTreeClassifier()  # some parameters
model.fit(X_train, y_train)
plot_tree(model);
It creates the figure shown in the picture.

Print

help(plot_tree)

to see its arguments and what it returns.

It returns a list of annotations:
[Text(0.5, 0.9166666666666666, 'X[2] <= 206.5\ngini = 0.633\nsamples = 233\nvalue = [105, 45, 83]'),
Text(0.2727272727272727, 0.75, 'X[0] <= 42.35\ngini = 0.418\nsamples = 147\nvalue = [104, 42, 1]'),
Text(0.18181818181818182, 0.5833333333333334, 'gini = 0.0\nsamples = 97\nvalue = [97, 0, 0]'),
Text(0.36363636363636365, 0.5833333333333334, 'X[4] <= 0.5\ngini = 0.274\nsamples = 50\nvalue = [7, 42, 1]'),
Text(0.18181818181818182, 0.4166666666666667, 'X[1] <= 16.0\ngini = 0.245\nsamples = 7\nvalue = [6, 0, 1]'),
Text(0.09090909090909091, 0.25, 'gini = 0.0\nsamples = 1\nvalue = [0, 0, 1]'),
Text(0.2727272727272727, 0.25, 'gini = 0.0\nsamples = 6\nvalue = [6, 0, 0]'),
Text(0.5454545454545454, 0.4166666666666667, 'X[3] <= 4075.0\ngini = 0.045\nsamples = 43\nvalue = [1, 42, 0]'),
Text(0.45454545454545453, 0.25, 'gini = 0.0\nsamples = 38\nvalue = [0, 38, 0]'),
Text(0.6363636363636364, 0.25, 'X[2] <= 196.5\ngini = 0.32\nsamples = 5\nvalue = [1, 4, 0]'),
Text(0.5454545454545454, 0.08333333333333333, 'gini = 0.0\nsamples = 1\nvalue = [1, 0, 0]'),
Text(0.7272727272727273, 0.08333333333333333, 'gini = 0.0\nsamples = 4\nvalue = [0, 4, 0]'),
Text(0.7272727272727273, 0.75, 'X[1] <= 17.65\ngini = 0.09\nsamples = 86\nvalue = [1, 3, 82]'),
Text(0.6363636363636364, 0.5833333333333334, 'gini = 0.0\nsamples = 82\nvalue = [0, 0, 82]'),
Text(0.8181818181818182, 0.5833333333333334, 'X[4] <= 0.5\ngini = 0.375\nsamples = 4\nvalue = [1, 3, 0]'),
Text(0.7272727272727273, 0.4166666666666667, 'gini = 0.0\nsamples = 1\nvalue = [1, 0, 0]'),
Text(0.9090909090909091, 0.4166666666666667, 'gini = 0.0\nsamples = 3\nvalue = [0, 3, 0]')]

I'm out right now and cannot check whether I'm right. I think the first two numbers in each Text are the coordinates of node i, where i is the position of that Text in the list. The rest of the text is the usual text displayed in the corresponding box.

And what should E be? Thank you.

@SaadKhan

I plotted the points whose coordinates are the first two values in each Text, and indeed they are the coordinates of the tree nodes. Moreover, by assigning to each pair of coordinates its index in the list you pasted here, I was able to deduce that the nodes are listed in preorder, i.e. the tree was traversed by the preorder method. This information can help you get the list of edges, if you have some knowledge of tree traversal.
You need a function that extracts the list of edges from the array of coordinates. This plot can help, if you assign to each point the index of its coordinates in the array coords:

import plotly.graph_objects as go
import numpy as np

coords= np.array([[0.5, 0.9166666666666666],
[0.2727272727272727, 0.75],
[0.18181818181818182, 0.5833333333333334],
[0.36363636363636365, 0.5833333333333334],
[0.18181818181818182, 0.4166666666666667],
[0.09090909090909091, 0.25],
[0.2727272727272727, 0.25],
[0.5454545454545454, 0.4166666666666667],
[0.45454545454545453, 0.25],
[0.6363636363636364, 0.25],
[0.5454545454545454, 0.08333333333333333],
[0.7272727272727273, 0.08333333333333333],
[0.7272727272727273, 0.75],
[0.6363636363636364, 0.5833333333333334],
[0.8181818181818182, 0.5833333333333334],
[0.7272727272727273, 0.4166666666666667],
[0.9090909090909091, 0.4166666666666667]])
Xn, Yn= coords.T
fig=go.Figure(go.Scatter(x=Xn, y=Yn, mode="markers", marker_size=15))
fig.update_layout(width=750, height=550)
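One way to write such a function, sketched under the assumptions observed above (nodes listed in preorder, root at row 0, and all nodes of the same depth sharing one y value that decreases with depth), is to keep a stack holding the path from the root to the current node: in a preorder listing, the parent of each node is the last node on that path lying exactly one level higher. The name edges_from_preorder and the rounding of y values (to merge floating-point values of one level) are choices made for this example.

```python
import numpy as np

def edges_from_preorder(coords, decimals=6):
    """Rebuild the edge list of a tree whose nodes are listed in preorder.

    Assumes node k is row k of coords, the root is row 0, and all nodes at
    the same depth share the same y value, with y decreasing as depth grows.
    """
    coords = np.asarray(coords, dtype=float)
    ys = np.round(coords[:, 1], decimals)
    levels = sorted(set(ys), reverse=True)        # top level first
    depth = [levels.index(y) for y in ys]         # depth of each node

    edges = []
    stack = []  # path from the root to the most recently visited node
    for k, d in enumerate(depth):
        while stack and depth[stack[-1]] >= d:
            stack.pop()                           # leave finished subtrees
        if stack:
            edges.append((stack[-1], k))          # parent: last node one level up
        stack.append(k)
    return edges
```

Applied to the 17 rows of coords above, this returns [(0, 1), (1, 2), (1, 3), (3, 4), (4, 5), (4, 6), (3, 7), (7, 8), (7, 9), (9, 10), (9, 11), (0, 12), (12, 13), (12, 14), (14, 15), (14, 16)]. The sample counts agree, e.g. node 3 (samples = 50) splits into nodes 4 (samples = 7) and 7 (samples = 43).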

Brilliant idea, fantastic! I extracted the coordinates and the text from plot_tree(model) with:
coords = []  # empty list
texts = []  # empty list
for item in plot_tree(model):
    coords.append(list(item.get_position()))
    texts.append(item.get_text())

Now, if I could only turn the scatter's circles into rectangles and put the text inside them, that would be absolutely great.

@empet I just want to add the lines shown in the original picture, since they are required; the rest is done.
I used the functions you provided and got this; I've been trying for a very long time now:

import numpy as np
import plotly.graph_objects as go
from sklearn.tree import plot_tree

coords = []
texts = []
for item in plot_tree(model):
    coords.append(list(item.get_position()))
    texts.append(item.get_text())
coords = np.array(coords)
Xnodes, Ynodes = coords.T

fig = go.Figure()
fig.add_trace(go.Scatter(x=Xnodes, y=Ynodes, mode="markers+text", marker_size=15, text=texts, textposition='top center'))

# Xedges and Yedges should come from get_plotly_data(E, coords), but I don't know how to build E
fig.add_trace(go.Scatter(x=Xedges,
                         y=Yedges,
                         mode='lines',
                         line_color='blue',
                         line_width=5,
                         hoverinfo='none'
                         ))
fig.update_layout(width=750, height=550)

Any way to solve this?

@SaadKhan

E = [(0,1), (1,2), (1,3), (3,4), (4,5), (4,6), (3,7), (7,8), (7,9), (9,10), (9,11),
     (0,12), (12,13), (12,14), (14,15), (14,16)]
labels= np.arange(17)
Xnodes, Ynodes, Xedges, Yedges = get_plotly_data(E, coords)
nodes= get_node_trace(Xnodes, Ynodes, labels, marker_size=18, marker_color='#c0c0c0')
edges= get_edge_trace(Xedges, Yedges)
fig= go.Figure([edges, nodes])
fig.update_layout(title_text="Decision tree",
            title_x=0.5,
            font_size=12,
            showlegend=False,
            width=800,
            height=600,
            xaxis_visible=False,
            yaxis_visible=False, 
            template='none',     
            hovermode='closest',
            paper_bgcolor='#eeeeee')

In get_node_trace() I replaced mode='markers' with mode='markers+text' and commented out
hoverinfo='text', so that the node indices, i.e. the indices of their coordinates in the array
coords, are displayed on the plot.
Good luck, and have a nice weekend!!!


Absolutely brilliant, thank you!
It's just me who doesn't understand how to get this E list, even though you explained it very well. Thank you for your time.


Hi there!

Thanks to both of you for the very helpful discussion.

To complete the objective, here is a hack to retrieve edges from the tree:


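def extract_edges(clf):
    edges = []
    # children_left[i] and children_right[i] are the ids of node i's children;
    # -1 means node i has no child on that side (i.e. it is a leaf)
    for i, (x, y) in enumerate(
        zip(clf.tree_.children_left, clf.tree_.children_right)
    ):

        if x != -1:
            edges.append((i, x))
        if y != -1:
            edges.append((i, y))

    return edges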

I am glad someone resurfaced this item.

I have now had the pleasure of playing with Plotly tree plots, and also with Dash and Mermaid.

Tree plots are a lot less complicated than they look.
Documentation is here

Created from the example code

Mermaid is what got me excited. Docs here.
I can feel an article coming on.

The following code creates the flowchart below. It is boilerplate code, but I will elaborate on it in a later article. All I would say is: look at the other Dash modules, you may be surprised at what you find.

from dash_extensions.enrich import DashProxy
from dash_extensions import Mermaid

app = DashProxy()

app.layout = Mermaid(chart="""
flowchart TD
    Start --> SampleA
    Start --> SampleB
    Start --> SampleC
    Start --> SampleD
    Start --> SampleE
    Start --> SampleF
    Start --> SampleG
    SampleA & SampleB --> Experiment1
    SampleB & SampleC --> Experiment2
    SampleC & SampleD --> Experiment3
    SampleD & SampleE --> Experiment4
    SampleE & SampleF --> Experiment5
    SampleF & SampleG --> Experiment6
    Experiment1 --> Result1
    Experiment2 --> Result2
    Experiment3 --> Result3
    Experiment4 --> Result4
    Experiment5 --> Result5
    Experiment6 --> Result6
    Result1 & Result2 --> Experiment7
    Result3 & Result4 --> Experiment8
    Result5 & Result6 --> Experiment9
    Experiment7 --> Result7
    Experiment8 --> Result8
    Experiment9 --> Result9
    Result7 & Result8 --> Experiment10
    Result8 & Result9 --> Experiment11
    Experiment10 --> Result10
    Experiment11 --> Result11
    Result10 & Result11 --> Experiment12
    Experiment12 --> Result13
    Result13 --> FinalOutcome
""")

if __name__ == "__main__":
    app.run()  # on older Dash versions: app.run_server()
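Since Mermaid computes the layout itself, a decision tree could also be handed to the Mermaid component as a flowchart definition built from its edge list (for example, the output of extract_edges above) and the node texts. The helper tree_to_mermaid below is hypothetical, sketched on the assumption that wrapping each label in double quotes and using <br> for line breaks is enough escaping for the labels at hand; characters such as < may need Mermaid's entity codes in some setups.

```python
def tree_to_mermaid(edges, labels):
    """Build a Mermaid 'flowchart TD' definition from (parent, child) pairs.

    labels[i] is the text of node i. Newlines become <br>, and double quotes
    are replaced by single quotes because each label is wrapped in "...".
    """
    lines = ["flowchart TD"]
    for i, text in enumerate(labels):
        safe = text.replace('"', "'").replace("\n", "<br>")
        lines.append(f'    n{i}["{safe}"]')        # node declaration with label
    for parent, child in edges:
        lines.append(f"    n{parent} --> n{child}")  # one arrow per tree edge
    return "\n".join(lines)
```

For a fitted model and the texts list from plot_tree, this would be used as app.layout = Mermaid(chart=tree_to_mermaid(extract_edges(model), texts)).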

Hopefully, someone looking for info finds this helpful.
