Callback function on Sankey diagram

Hi Plotly community!

I have just started with Plotly for Python and I am already stuck :confused:
When I use a callback function (e.g. on_click), how can I distinguish between a node and an edge object?

I am using the following approach:

def update_color(trace, points, state):
    """Colour the clicked Sankey node red (FigureWidget ``on_click`` callback).

    Args:
        trace: The trace that received the click (unused; the module-level
            ``fig.data[0]`` is updated instead).
        points: Plotly click points; ``points.point_inds`` holds the clicked
            indices.
        state: Input-device state (unused).

    Clicks whose index is not a valid node index are ignored instead of
    raising IndexError. Note that node and link indices share one index
    space: unless the link arrays are padded with one dummy entry per node,
    a link click whose index is below the node count will still recolour
    the node with that index.
    """
    if not points.point_inds:
        # Nothing was hit by this click.
        return
    clicked = points.point_inds[0]

    # Copy into a plain list, not a NumPy array: a NumPy array of strings
    # has a fixed-width dtype sized to the longest existing entry, so
    # assigning a longer colour string would be silently truncated.
    new_color = list(fig.data[0].node.color)
    if clicked >= len(new_color):
        # Index beyond the node list: this can only be a link click.
        return
    new_color[clicked] = "red"
    with fig.batch_update():
        fig.data[0].node.color = new_color

# Register update_color as the click handler of the figure's first (Sankey) trace.
fig.data[0].on_click(update_color)

The problem with this approach is that the indices of links and nodes overlap. Therefore, when I click on a link between nodes, the link's index is used to select and color the node that has that same index. Clearly, if there are more links than nodes, then such a click can also lead to an index-out-of-range error.
I would really appreciate your help here.

Hi [xanderwallace],
To avoid the overlapping-index issue between nodes and edges (links/transitions), you can try the following hack, which worked for me.

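  1. Count the total number of nodes and prepend that many None (dummy) entries to the source and target lists of the link dict, and pad the link's value, label and color lists by the same amount. Link indices then start at len(node_list), so a clicked index below len(node_list) is a node and any index at or above it is a link.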

Example code: pastebin link: https://paste.ubuntu.com/p/zD5HjpZsnq/

import plotly.graph_objects as go
import functools
import ipywidgets as widgets
from ipywidgets import Output

def click_detector(trace, points, state, node_list):
    """Print to the output widget whether a Sankey node or link was clicked.

    Meant to be used as a FigureWidget ``on_click`` callback, bound with
    ``functools.partial`` so that ``node_list`` is supplied.

    Args:
        trace: The trace that received the click (unused).
        points: Plotly click points; ``points.point_inds`` holds the clicked
            indices, and the last one is reported.
        state: Input-device state (unused).
        node_list: Node labels, in the same order as the trace's nodes.

    Assumes the trace's link arrays are padded with ``len(node_list)`` dummy
    entries. Under that assumption, clicked indices below ``len(node_list)``
    are nodes and all other indices are links.
    """
    # Clear the previous message before reporting this click.
    plot_output_print.clear_output()
    if not points.point_inds:
        # Nothing was hit; leave the output empty instead of raising IndexError.
        return
    click_in = points.point_inds[-1]
    with plot_output_print:
        if click_in < len(node_list):
            print("Node clicked", click_in, node_list[click_in])
        else:
            print("Edge clicked", click_in)

# Output widget that click_detector prints into; displayed under the figure.
plot_output_print = widgets.Output(layout={"border": "1px solid black"})

fig = go.FigureWidget()
# Order the nodes by the index you want reported on click.
node_label_list = [
    "W_NR_0", "W_Low_1", "M_NR_2", "M_L_3", "A_NR_4",
    "A_L_5", "Ref_NR_6", "Ref_Low_7", "Ref2_NR_8", "Ref2_L_9",
]

# Workaround for overlapping node/link click indices: every link array is
# prefixed with len(node_label_list) dummy entries (None source/target/value,
# fully transparent colour). Real links are therefore numbered from
# len(node_label_list) upward and never share a clicked index with a node.
sankey_trace = go.Sankey(
    arrangement="fixed",
    name="san",
    node=dict(
        pad=15,
        thickness=20,
        line=dict(color="red", width=0.5),
        label=node_label_list,
        color=[
            "rgb(99,99,99)",
            "rgb(0,109,44)",
            "rgb(253,141,60)",
            "rgb(179,0,0)",
            "rgb(99,99,99)",
            "rgb(0,109,44)",
            "rgb(253,141,60)",
            "rgb(179,0,0)",
            "rgb(200,0,0)",
            "rgb(255,0,255)",
        ],
        x=[0.1, 0.1, 0.3, 0.3, 0.5, 0.5, 0.7, 0.7, 0.8, 0.8],
        y=[0.7, 0.1, 0.7, 0.1, 0.7, 0.1, 0.7, 0.1, 0.7, 0.1],
    ),
    link=dict(
        source=[None] * len(node_label_list) + [0, 2, 4, 0, 3, 4, 1, 0, 4],
        target=[None] * len(node_label_list) + [2, 4, 7, 3, 4, 7, 2, 4, 6],
        value=[None] * len(node_label_list) + [1, 1, 1, 1, 1, 1, 1, 1, 1],
        label=["None_T"] * len(node_label_list)
        + ["same", "same", "same", "diff", "diff", "diff", "l", "k", "m"],
        color=["rgba(255,255,255,0)"] * len(node_label_list)
        + [
            "rgba(255,255,0,0.6)",
            "rgba(0,0,200,1)",
            "rgba(155,0,0,1)",
            "rgba(0,100,120,1)",
            "rgba(90,0,0,1)",
            "rgba(90,0,0,0.4)",
            "rgba(100, 155,160,0.4)",
            "rgba(50,170,0,0.4)",
            "rgba(100,30,100,1)",
        ],
    ),
)

fig.add_trace(sankey_trace)
fig.update_layout(
    title_text="Basic Sankey Diagram", font_size=10, height=600, width=1200
)

# Bind the node labels so the callback can name the clicked node.
fig.data[0].on_click(functools.partial(click_detector, node_list=node_label_list))
plot_output_print.clear_output()
# NOTE(review): `display` is not imported here; presumably the IPython/Jupyter
# builtin, so this must run inside a notebook.
display(fig, plot_output_print)