Hi @eliza,
Thanks for your reply.
In the meantime, I reached the same conclusion and wrote a draft function, with a quick and very dirty docstring, to generalise the process. It takes a dataframe as input, plus the name of the column containing the unique ids of the data to count, and generates the lists of sources, targets, values and nodes (labels). It can also return the integer ids of chosen nodes, which can be used to build an array of colors, for instance to highlight some nodes.
The function is:
def gen_lists_for_sankey(df_src, col_with_unique_ids, return_int_of_critical_val=True, list_of_critical_val=["---"]):
    """
    df_src: a dataframe whose columns are all categories, except one, which contains the unique ids
        of the data to count.
    col_with_unique_ids: a string; the name of the column containing the unique identifiers.
        For instance, df_src could contain a column 'color', a column 'seat', a column 'horsepower', and a
        column containing the unique identifier, such as 'VIN' (vehicle identification number).
    return_int_of_critical_val: True or False; if True, the function also returns the integers of the
        nodes corresponding to the "critical values". These can be used to build a list of colors for the nodes.
    list_of_critical_val: the list of values we want to highlight on the Sankey diagram.

    Returns the lists of sources, targets, values, nodes (aka labels) and the dict mapping each column's
    unique values to their int value. These lists can be used to plot a Sankey diagram.
    Optionally, it can also return a dict of lists of ints, one list per critical value.
    The critical values for which we want the int values are passed as a list. By default, it returns
    the list of ints corresponding to a '---' in any column.
    The lists of ints corresponding to the critical values (aka suspicious values, missing values or default
    records) are useful to apply colors to specific nodes/links.
    Example of returned dict:
        {"---": [11, 12],
         "None": [8],
         ...}
    """
    memDict = {}  # column name -> {unique value: node int}
    # every column except the id column is a category level, kept in the order given
    new_idx = df_src.columns.to_list()
    new_idx.remove(col_with_unique_ids)
    list_of_sources = []
    list_of_targets = []
    list_of_values = []
    list_of_nodes = []
    max_int = 0
    # one set of links per pair of adjacent category columns
    for col, ncol in zip(new_idx, new_idx[1:]):
        df = df_src.groupby([col, ncol])[col_with_unique_ids].count().reset_index()
        for i in [col, ncol]:
            if i not in memDict:
                # number this column's unique values, continuing after the previous column's ints
                memDict[i] = {val: n + max_int for n, val in enumerate(df[i].unique())}
                max_int = max(memDict[i].values()) + 1
        # .map() converts each category value into its node int
        list_of_sources.extend(df[col].map(memDict[col]))
        list_of_targets.extend(df[ncol].map(memDict[ncol]))
        list_of_values.extend(df[col_with_unique_ids])
    # the node ints were assigned in insertion order, so the labels line up with them
    for k in memDict:
        list_of_nodes.extend(memDict[k].keys())
    if return_int_of_critical_val:
        dCritNodes = {}
        for critical_val in list_of_critical_val:
            listOfCriticalNodes = []
            for i in memDict:
                if critical_val in memDict[i]:
                    listOfCriticalNodes.append(memDict[i][critical_val])
            dCritNodes[critical_val] = listOfCriticalNodes
        return list_of_sources, list_of_targets, list_of_values, list_of_nodes, memDict, dCritNodes
    else:
        return list_of_sources, list_of_targets, list_of_values, list_of_nodes, memDict
How to use it:
import plotly.express as px
import plotly.graph_objects as go
sample = px.data.tips().reset_index()  # reset the index so that the df passed as input has the expected format (a column of unique ids), which is the one I was working with when I needed this function
src, trgt, val, lbl, mD, cD = gen_lists_for_sankey(sample[["index", "sex", "smoker", "day", "time"]], "index", return_int_of_critical_val=True, list_of_critical_val=["No"])
or, if the critical values are not needed:
src, trgt, val, lbl, mD = gen_lists_for_sankey(sample[["index", "sex", "smoker", "day", "time"]], "index", return_int_of_critical_val=False)
I know it's ugly and not PEP 8 compliant, but it does the job, and it's a draft.
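For reference, the core idea can be condensed into a shorter sketch: give every (column, value) pair a global integer node id, then count the unique ids for each pair of adjacent columns. This is only a rewrite under my own assumptions, not a tested drop-in replacement: the name `sankey_lists` is hypothetical, the critical-value dict is left out, and values are numbered in order of first appearance in the full dataframe (the function above numbers them in the sorted order of the groupby result), so the node order may differ.

```python
import pandas as pd


def sankey_lists(df, id_col):
    """Hypothetical condensed version: return sources, targets, values,
    labels and the {column: {value: node id}} mapping for a Sankey diagram."""
    cat_cols = [c for c in df.columns if c != id_col]

    # Assign consecutive node ids column by column, in order of first appearance.
    node_ids = {}
    labels = []
    for col in cat_cols:
        node_ids[col] = {}
        for value in df[col].dropna().unique():
            node_ids[col][value] = len(labels)
            labels.append(value)

    sources, targets, values = [], [], []
    for col, ncol in zip(cat_cols, cat_cols[1:]):
        # One link per observed (col, ncol) combination, weighted by the number of ids.
        counts = df.groupby([col, ncol])[id_col].count().reset_index()
        sources.extend(counts[col].map(node_ids[col]).tolist())
        targets.extend(counts[ncol].map(node_ids[ncol]).tolist())
        values.extend(counts[id_col].tolist())
    return sources, targets, values, labels, node_ids
```

Usage mirrors the original: `src, trgt, val, lbl, mD = sankey_lists(sample[["index", "sex", "smoker", "day", "time"]], "index")`.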
Then building a Sankey diagram with the highlighted nodes is straightforward:
x_domain = [0.1, 0.9]
color_nodes = [
"blue" if k in ["No", "Sat"] else
"lightgreen" if k not in ["Male", "Yes"] else
"indianred" for v in mD.values() for k in v.keys()
]
fig6 = go.Figure(data=[go.Sankey(
    node=dict(
        pad=15, thickness=20, line=dict(color="black", width=0.5),
        color=color_nodes,
        label=lbl,
    ),
    link=dict(
        source=src,
        target=trgt,
        value=val,
    ),
    domain=dict(x=x_domain, y=[0.2, 0.9])
)])
fig6.update_layout(title_text="Basic Sankey Diagram", font_size=10)
# And to add the names of the categories as annotations:
colnames = ["Cat#1", "Cat#2", "Cat#3", "Cat#4"]  # one name per category column: sex, smoker, day, time
x_domain_width = x_domain[1] - x_domain[0]
for xcoord, colname in enumerate(colnames):
    fig6.add_annotation(
        # spread the names evenly across the horizontal domain of the Sankey trace
        x=x_domain[0] + x_domain_width * xcoord / (len(colnames) - 1),
        y=0.1,
        xref="paper",
        yref="paper",
        text=colname,
        showarrow=False,
        xanchor="center",
        align="center",
    )
fig6.show()
Which gives
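The node colors above are hard-coded by label. When the function is called with `return_int_of_critical_val=True`, the dict of critical node ints (`cD`) could drive the colors instead, and each link could take a translucent version of its source node's color. A minimal sketch, assuming that coloring rule; the helper name `colors_from_critical_nodes` is hypothetical, and the rgba values are the CSS colors indianred and lightgreen.

```python
def colors_from_critical_nodes(n_nodes, critical_nodes, sources):
    """Hypothetical helper: color critical nodes indianred, all others
    lightgreen, and give each link a translucent copy of its source color."""
    highlight = "rgba(205, 92, 92, {a})"   # indianred
    default = "rgba(144, 238, 144, {a})"   # lightgreen
    # Flatten {"No": [2, 7], "---": []} into a set of node ints.
    critical = {i for ids in critical_nodes.values() for i in ids}
    node_colors = [(highlight if i in critical else default).format(a=1.0)
                   for i in range(n_nodes)]
    link_colors = [(highlight if s in critical else default).format(a=0.4)
                   for s in sources]
    return node_colors, link_colors
```

With the first usage form: `node_colors, link_colors = colors_from_critical_nodes(len(lbl), cD, src)`, then pass `color=node_colors` in the `node` dict and `color=link_colors` in the `link` dict.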
With a bit more work, the code generating the diagram could be moved into the function body. Variable names could be cleaned up, and we would get a Sankey diagram with a single line of code.
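One possible shape for that, sketched under the assumption that returning the figure as a plain dict is acceptable (`go.Figure` accepts a dict with `data` and `layout` keys): the hypothetical `sankey_figure_dict` below bundles the Sankey trace and the category annotations, reusing the padding, domain and annotation settings from the snippet above.

```python
def sankey_figure_dict(sources, targets, values, labels, node_colors=None,
                       link_colors=None, category_names=(),
                       x_domain=(0.1, 0.9), title="Sankey Diagram"):
    """Hypothetical wrapper: return a Plotly figure as a plain dict
    (one Sankey trace plus one annotation per category column)."""
    node = dict(pad=15, thickness=20, line=dict(color="black", width=0.5),
                label=list(labels))
    if node_colors is not None:
        node["color"] = list(node_colors)
    link = dict(source=list(sources), target=list(targets), value=list(values))
    if link_colors is not None:
        link["color"] = list(link_colors)
    trace = dict(type="sankey", node=node, link=link,
                 domain=dict(x=list(x_domain), y=[0.2, 0.9]))

    # Spread the category names evenly across the horizontal domain of the trace.
    n = len(category_names)
    width = x_domain[1] - x_domain[0]
    if n > 1:
        xs = [x_domain[0] + width * i / (n - 1) for i in range(n)]
    else:
        xs = [x_domain[0] + width / 2] * n
    annotations = [
        dict(x=x, y=0.1, xref="paper", yref="paper", text=name,
             showarrow=False, xanchor="center", align="center")
        for x, name in zip(xs, category_names)
    ]
    layout = dict(title=dict(text=title), font=dict(size=10),
                  annotations=annotations)
    return dict(data=[trace], layout=layout)
```

Combined with the function above, this gets close to the one-liner: `go.Figure(sankey_figure_dict(*gen_lists_for_sankey(sample[["index", "sex", "smoker", "day", "time"]], "index", return_int_of_critical_val=False)[:4], category_names=["Cat#1", "Cat#2", "Cat#3", "Cat#4"])).show()`.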
It would be nice if plotly.js had an option to disable the automatic positioning of the node labels along the bars: they are always on the right, except for the last bar, where they move to the left. On some of my charts, the labels collide with the previous bar.
Note: I will very likely clean up the above code by the end of next week and update this post afterwards. I was very reluctant to copy-paste it here, given the mess.