Sankey diagram: can source and target only be lists of ints?

Hi there,

Referring to the Sankey diagram doc (Sankey traces in Python),

  • the source can be a Type: list, numpy array, or Pandas series of numbers, strings, or datetimes.
  • and the target can be Type: list, numpy array, or Pandas series of numbers, strings, or datetimes.

However, while the below minimal example works,

import plotly.graph_objects as go
fig = go.Figure(data=[go.Sankey(
    node = dict(
      pad = 15, thickness = 20, line = dict(color = "black", width = 0.5),
      label = ["Strat1", "Strat1"], color = "blue"
    ),
    link = dict(
      source = [0, 0], # indices correspond to labels, eg A1, A2, A1, B1, ...
      target = [1, 2],
      value = [8, 4]
  ))])
fig.update_layout(title_text="Basic Sankey Diagram", font_size=10)
fig.show()

The one below does not:

fig = go.Figure(data=[go.Sankey(
    node = dict(
      pad = 15, thickness = 20, line = dict(color = "black", width = 0.5),
      label = ["Strat1", "Strat1"], color = "blue"
    ),
    link = dict(
      source = ["A", "A"], # indices correspond to labels, eg A1, A2, A1, B1, ...
      target = ["B", "C"],
      value = [8, 4]
  ))])
fig.update_layout(title_text="Basic Sankey Diagram", font_size=10)
fig.show()

Am I doing something wrong, or is there a typo in the docs?

Hello!

I took a look at the plotly.js source code: source, target and value are all of valType data_array and take numeric values (integer node indices for source and target).
(plotly.js/attributes.js at a4ce2b71f083cdbcd6dcf04931496ee13d0872e8 · plotly/plotly.js · GitHub)

source: {
    valType: 'data_array',
    dflt: [],
    description: 'An integer number `[0..nodes.length - 1]` that represents the source node.'
},
target: {
    valType: 'data_array',
    dflt: [],
    description: 'An integer number `[0..nodes.length - 1]` that represents the target node.'
},
value: {
    valType: 'data_array',
    dflt: [],
    description: 'A numeric value representing the flow volume value.'
},
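Going by those attribute descriptions, one quick way to catch this before plotting is to check that every source/target entry is an integer in [0, len(labels) - 1]. This is a minimal sketch; check_sankey_links is a hypothetical helper, not part of Plotly.

```python
import numbers


def check_sankey_links(labels, source, target):
    """Return a list of problems in source/target (empty if none).

    Hypothetical helper: it mirrors the plotly.js attribute description,
    i.e. every entry must be an integer in [0, len(labels) - 1].
    """
    problems = []
    n = len(labels)
    for name, seq in (("source", source), ("target", target)):
        for pos, idx in enumerate(seq):
            # numbers.Integral also accepts NumPy integer types; bool is excluded
            if isinstance(idx, bool) or not isinstance(idx, numbers.Integral):
                problems.append(f"{name}[{pos}] = {idx!r} is not an integer")
            elif not 0 <= idx < n:
                problems.append(f"{name}[{pos}] = {idx} is outside [0, {n - 1}]")
    return problems


# The failing call from the question: string endpoints
for problem in check_sankey_links(["Strat1", "Strat1"], ["A", "A"], ["B", "C"]):
    print(problem)
```

The same check also flags the first example in the question, where target index 2 has no matching entry in label.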

If you want to use text data, here's an example of someone mapping text labels to integer indices for use in a Sankey chart.

import plotly.graph_objects as go  # Import the graph objects module

# Node labels; their position in this list defines their integer index
node_label = ["A1", "A2", "B1", "B2", "B3", "C1", "C2"]
node_dict = {y: x for x, y in enumerate(node_label)}
print(node_dict)

source = ['A1', 'A1', 'A1', 'A2', 'A2', 'A2', 'B1', 'B2', 'B2', 'B3', 'B3']
target = ['B1', 'B2', 'B3', 'B1', 'B2', 'B3', 'C1', 'C1', 'C2', 'C1', 'C2']
values = [10, 5, 15, 5, 20, 45, 15, 20, 5, 30, 30]

# Translate each text label into its integer index
source_node = [node_dict[x] for x in source]
target_node = [node_dict[x] for x in target]
# [0, 0, 0, 1, 1, 1, 2, 3, 3, 4, 4]
# [2, 3, 4, 2, 3, 4, 5, 5, 6, 5, 6]

fig = go.Figure(
    data=[go.Sankey(  # The plot we are interested in
        # Node information
        node=dict(
            label=node_label
        ),
        # Link information: integer indices into node_label
        link=dict(
            source=source_node,
            target=target_node,
            value=values
        ))])

# Show the plot
fig.show()
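Instead of typing node_label by hand, the label list can also be derived from the data itself, keeping the order in which each name first appears. A minimal pure-Python sketch; to_indices is a hypothetical name, and the source/target lists are repeated so the snippet runs on its own.

```python
def to_indices(source, target):
    """Map text endpoints to the integer indices go.Sankey expects.

    Returns (labels, source_idx, target_idx); labels are ordered by first
    appearance, scanning source first and then target.
    """
    # dict.fromkeys drops duplicates while keeping insertion order
    labels = list(dict.fromkeys(list(source) + list(target)))
    index_of = {label: i for i, label in enumerate(labels)}
    return labels, [index_of[s] for s in source], [index_of[t] for t in target]


source = ['A1', 'A1', 'A1', 'A2', 'A2', 'A2', 'B1', 'B2', 'B2', 'B3', 'B3']
target = ['B1', 'B2', 'B3', 'B1', 'B2', 'B3', 'C1', 'C1', 'C2', 'C1', 'C2']
labels, source_node, target_node = to_indices(source, target)
print(labels)  # ['A1', 'A2', 'B1', 'B2', 'B3', 'C1', 'C2']
```

For this data the derived order happens to match node_label, so source_node and target_node come out identical to the commented lists above; labels would then be passed as node=dict(label=labels).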

Best,
Eliza


Hi @eliza
Thanks for your reply.

In the meantime, I reached the same conclusion and wrote a draft function, with a quick and very dirty docstring, to generalise the process. It takes a dataframe as input, plus the name of the column holding the ids of the records to count, and generates the lists of sources, targets, values and node labels. It can also return the integer indices of some nodes, which can be used to build an array of colors should we want to highlight those nodes.


The function is:


def gen_lists_for_sankey(df_src, col_with_unique_ids, return_int_of_critical_val=True, list_of_critical_val=["---"]):
    """
    df_src: a dataframe whose columns are categories, except one, which contains the unique ids of the records to count.
    col_with_unique_ids: a string; the name of the column holding the unique identifiers.
    For instance, df_src could contain a column 'color', a column 'seat', a column 'horsepower', and a
    column containing the unique identifier, such as 'VIN' (vehicle identification number).
    return_int_of_critical_val: True or False; if True, the function also returns the integer indices of the
    nodes corresponding to the "critical values", which can be used to build a list of node colors.
    list_of_critical_val: the list of values we want to highlight on the Sankey diagram.

    Returns the lists of sources, targets, values, nodes (aka labels) and the dict mapping each
    column's unique values to their integer index. These lists can be used to plot a Sankey diagram.
    Optionally, it can also return a dict of lists of ints, one list per critical value.
    The critical values for which we want the integer indices are passed as a list. By default, it returns
    the indices of the nodes corresponding to a '---' in any column.
    The indices of the critical values (aka suspicious values, missing values or default
    records) are useful to apply colors to specific nodes/links.
    Example of returned dict:
    {"---": [11, 12],
     "None": [8],
     etc: []}
    """
    memDict = {}
    # Category columns, in order: one Sankey stage per column
    new_idx = df_src.columns.to_list()
    new_idx.remove(col_with_unique_ids)
    list_of_sources = []
    list_of_targets = []
    list_of_values = []
    list_of_nodes = []

    max_int = 0
    # Walk consecutive column pairs: (col1, col2), (col2, col3), ...
    for col, ncol in zip(new_idx, new_idx[1:]):
        # Count the records for each (col value, ncol value) combination
        df = df_src.groupby([col, ncol])[col_with_unique_ids].count().reset_index()
        for i in [col, ncol]:
            if i not in memDict:
                # Give each unique value of this column the next free integer index
                memDict[i] = {val: n + max_int for n, val in enumerate(df[i].unique())}
                max_int = max(memDict[i].values()) + 1

        # map() translates every category value into its integer node index
        list_of_sources.extend(df[col].map(memDict[col]))
        list_of_targets.extend(df[ncol].map(memDict[ncol]))
        list_of_values.extend(df[col_with_unique_ids])
    # Node labels in index order (memDict keeps insertion order)
    for k in memDict:
        list_of_nodes.extend(memDict[k].keys())

    if return_int_of_critical_val:
        dCritNodes = {}
        for critical_val in list_of_critical_val:
            listOfCriticalNodes = []
            for i in memDict:
                if critical_val in memDict[i]:
                    listOfCriticalNodes.append(memDict[i][critical_val])
            dCritNodes[critical_val] = listOfCriticalNodes
        return list_of_sources, list_of_targets, list_of_values, list_of_nodes, memDict, dCritNodes
    else:
        return list_of_sources, list_of_targets, list_of_values, list_of_nodes, memDict
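Stripped of the pandas specifics, the function walks consecutive column pairs, counts the records per (value, next value) combination, and gives every (column, value) pair its own node index, so the same value in two different columns becomes two separate nodes. A minimal pure-Python sketch of that idea, assuming the records are dicts; sankey_lists is a hypothetical name, its node order may differ from gen_lists_for_sankey (groupby sorts the values), and unlike groupby it does not drop missing values.

```python
from collections import Counter


def sankey_lists(rows, columns):
    """Build source/target/value/label lists from categorical records.

    rows: iterable of dicts, one per record
    columns: ordered category columns, one Sankey stage per column
    """
    rows = list(rows)
    node_id = {}  # (column, value) -> node index
    labels = []

    def node(col, val):
        # Assign node indices in first-seen order
        if (col, val) not in node_id:
            node_id[(col, val)] = len(labels)
            labels.append(val)
        return node_id[(col, val)]

    sources, targets, values = [], [], []
    for col, ncol in zip(columns, columns[1:]):
        # How many records go from each value of col to each value of ncol
        counts = Counter((row[col], row[ncol]) for row in rows)
        for (a, b), n in counts.items():
            sources.append(node(col, a))
            targets.append(node(ncol, b))
            values.append(n)
    return sources, targets, values, labels


rows = [
    {"sex": "Male", "smoker": "No", "time": "Dinner"},
    {"sex": "Female", "smoker": "No", "time": "Lunch"},
    {"sex": "Male", "smoker": "Yes", "time": "Dinner"},
    {"sex": "Male", "smoker": "No", "time": "Dinner"},
]
src, trgt, val, lbl = sankey_lists(rows, ["sex", "smoker", "time"])
print(src, trgt, val, lbl)
```

With a dataframe, df.to_dict("records") produces this list-of-dicts shape.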

How to use it:

import plotly.express as px

# reset_index() adds an "index" column, used as the unique identifier of each row; it makes sure
# the input df has the format I was working with when I needed this function
sample = px.data.tips().reset_index()
src, trgt, val, lbl, mD, cD = gen_lists_for_sankey(sample[["index", "sex", "smoker", "day", "time"]], "index", return_int_of_critical_val=True, list_of_critical_val=["No"])

or, without the dict of critical values:

src, trgt, val, lbl, mD = gen_lists_for_sankey(sample[["index", "sex", "smoker", "day", "time"]], "index", return_int_of_critical_val=False)
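With return_int_of_critical_val=True, cD maps each critical value to its node indices, so node colors could be built from it directly instead of matching label strings. A sketch under that assumption; colors_from_critical and the default color names are illustrative choices, not part of Plotly.

```python
def colors_from_critical(n_nodes, critical, highlight="indianred", default="lightgreen"):
    """Return one color per node, highlighting the indices listed in `critical`.

    critical: dict mapping a critical value to a list of node indices,
    shaped like the cD dict returned by gen_lists_for_sankey.
    """
    flagged = {i for indices in critical.values() for i in indices}
    return [highlight if i in flagged else default for i in range(n_nodes)]


print(colors_from_critical(5, {"No": [2, 4]}))
```

Usage would be color_nodes = colors_from_critical(len(lbl), cD), passed as node=dict(color=color_nodes, ...).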

I know it's ugly and not PEP 8, but it does the job, and it's a draft.

Then building a Sankey diagram with the highlighted nodes is straightforward:

import plotly.graph_objects as go

x_domain = [0.1, 0.9]
x_domain_width = x_domain[1] - x_domain[0]
# One color per node, in the same order as lbl (both follow mD)
color_nodes = [
    "blue" if k in ["No", "Sat"] else
    "lightgreen" if k not in ["Male", "Yes"] else
    "indianred" for v in mD.values() for k in v.keys()
]
fig6 = go.Figure(data=[go.Sankey(
    node=dict(
        pad=15, thickness=20, line=dict(color="black", width=0.5),
        color=color_nodes,
        label=lbl,
    ),
    link=dict(
        source=src,
        target=trgt,
        value=val,
        # optionally: color=<a list with one color per link>
    ),
    domain=dict(x=x_domain, y=[0.2, 0.9])
)])
fig6.update_layout(title_text="Basic Sankey Diagram", font_size=10)

# And to add the name of the categories as annotations below the diagram:
colnames = ["Cat#1", "Cat#2", "Cat#3", "Cat#4"]  # one per category column: sex, smoker, day, time
for xcoord, colname in enumerate(colnames):
    fig6.add_annotation(
        # Spread the names evenly across the Sankey's horizontal domain
        x=x_domain[0] + x_domain_width * xcoord / (len(colnames) - 1),
        y=0.1,
        xref="paper",
        yref="paper",
        text=colname,
        showarrow=False,
        xanchor="center",
        align="center",
    )
fig6.show()
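The annotation x positions follow x = x0 + width * i / (n - 1), which places n category names evenly from one edge of the Sankey domain to the other. That assumes the stages sit roughly at evenly spaced positions across the domain (nodes have a thickness, so the fit is approximate). As a standalone helper, with the hypothetical name stage_label_positions:

```python
def stage_label_positions(x_domain, n_stages):
    """Evenly spaced x positions (paper coordinates) for n_stages labels.

    Assumes the Sankey stages are spread from x_domain[0] to x_domain[1].
    """
    x0, x1 = x_domain
    if n_stages == 1:
        # A single stage: put its label in the middle of the domain
        return [(x0 + x1) / 2]
    width = x1 - x0
    return [x0 + width * i / (n_stages - 1) for i in range(n_stages)]


print(stage_label_positions([0.1, 0.9], 4))
```

The annotation loop could then iterate over zip(stage_label_positions(x_domain, len(colnames)), colnames) and pass each position as x; the n_stages == 1 branch also avoids the division by zero the inline formula hits with a single column name.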


With a bit more work, the code generating the diagram could be folded into the function body. Variable names could be cleaned up, and we would get a Sankey diagram with a single line of code.

It would be nice if plotly.js had an option to disable the automatic positioning of the node labels along the bars: they are always placed on the right, except for the last bar, where they move to the left. On some of my charts, these labels collide with the previous bar.

Note: I will very likely clean up the above code by the end of next week and update this post afterwards. I was very reluctant to copy-paste it here, given the mess :smiley:
