Network Graphs with Graph Objects with Edge-Specific Coloring

Hello everyone,

I am looking for a way to create a network graph with Scatter graph objects, as in Network Graphs | Python | Plotly. I want to visualize a neural network, and so far I am getting this result: [screenshot]

The problem is, I don’t know how to make each edge a different color based on its weight. My code for the edge trace is as follows:

import plotly.graph_objects as go

edge_trace = go.Scatter(
    x=edge_position_list_x, y=edge_position_list_y,
    mode='lines',
    line_width=4)

My figure code looks like this:

fig = go.Figure(data=[edge_trace, node_trace],
                layout=go.Layout(
                    title='<br>Neural network structure visualization with Python',
                    margin=dict(b=20, l=5, r=5, t=40),
                    annotations=[dict(
                        text="Python code: <a href=''></a>",
                        showarrow=False,
                        xref="paper", yref="paper",
                        x=0.005, y=-0.002)],
                    xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                    yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)))

The ‘edge_colors’ variable is simply a list of the weights of the edges. The ‘edge_position_list’ variables are also lists, in which each pair of x or y coordinates is followed by a ‘None’ value. I’ve read the topic “[Network Graph] Set each edge with different color” (post #13 by kapital1), but it doesn’t explain how to do the same thing with Graph Objects. Any help is greatly appreciated!
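In Graph Objects, a scatter trace's `line.color` takes a single color, so one trace cannot give its segments different colors. A minimal sketch of one workaround, assuming `edge_colors` holds one numeric weight per edge in the same order as the edges in `edge_position_list_x`/`edge_position_list_y`: split the `None`-separated lists back into per-edge segments and build one line trace per edge, colored by interpolating between two endpoint colors. The helper names and the blue-to-red endpoints are hypothetical choices; the traces are plain dicts, which `go.Figure(data=...)` also accepts.

```python
def split_segments(xs, ys):
    """Split None-separated coordinate lists into one ([x0, x1], [y0, y1]) pair per edge."""
    segments, cur_x, cur_y = [], [], []
    for x, y in zip(xs, ys):
        if x is None or y is None:
            if cur_x:  # close the current edge at each None separator
                segments.append((cur_x, cur_y))
            cur_x, cur_y = [], []
        else:
            cur_x.append(x)
            cur_y.append(y)
    if cur_x:  # last edge, if the lists do not end with None
        segments.append((cur_x, cur_y))
    return segments


def weight_to_hex(w, w_min, w_max, low=(0, 0, 255), high=(255, 0, 0)):
    """Linearly interpolate between two RGB colors (hypothetical blue -> red)."""
    t = 0.0 if w_max == w_min else (w - w_min) / (w_max - w_min)
    rgb = [round(a + t * (b - a)) for a, b in zip(low, high)]
    return "#{:02x}{:02x}{:02x}".format(*rgb)


def per_edge_traces(xs, ys, weights, width=4):
    """Build one line-only scatter trace (as a dict) per edge, colored by its weight."""
    segments = split_segments(xs, ys)
    if len(segments) != len(weights):
        raise ValueError("need exactly one weight per edge")
    w_min, w_max = min(weights), max(weights)
    return [
        {
            "type": "scatter",
            "x": sx,
            "y": sy,
            "mode": "lines",
            "line": {"width": width, "color": weight_to_hex(w, w_min, w_max)},
            "hoverinfo": "none",
            "showlegend": False,
        }
        for (sx, sy), w in zip(segments, weights)
    ]
```

With the question's variables this would be used as `go.Figure(data=per_edge_traces(edge_position_list_x, edge_position_list_y, edge_colors) + [node_trace], layout=...)`; listing the edges first keeps the nodes drawn on top of them.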

Here is my solution for coloring and sizing edges by their weights, inspired by Julian West.

import networkx as nx
import plotly.graph_objects as go


def get_network_coordinates(df, val_1, val_2, attr, layout_type):
    ### val_1 & val_2, e.g. Taxa_1, Taxa_2 columns of df (str)
    ### attr, e.g. name of correlation column in df (str)
    ### layout_type (str), one of:
        ### fruchterman_reingold_layout
        ### circular_layout
        ### random_layout
        ### spring_layout
        ### spectral_layout
        ### minimum_spanning_tree
    edges_ = df[[val_1, val_2, attr]]
    Gx_ = nx.from_pandas_edgelist(edges_, val_1, val_2, edge_attr=[attr])
    layouts = {
        "fruchterman_reingold_layout": nx.fruchterman_reingold_layout,
        "circular_layout": nx.circular_layout,
        "random_layout": nx.random_layout,
        "spring_layout": nx.spring_layout,
        "spectral_layout": nx.spectral_layout,
    }
    if layout_type == "minimum_spanning_tree":
        # minimum_spanning_tree returns a graph, not node positions, so keep only
        # the tree's edges and then position them (spring_layout is one choice)
        Gx_ = nx.minimum_spanning_tree(Gx_)
        pos_ = nx.spring_layout(Gx_)
    elif layout_type in layouts:
        pos_ = layouts[layout_type](Gx_)
    else:
        raise ValueError(f"Unknown layout_type: {layout_type}")

    Xnodes = [pos_[n][0] for n in Gx_.nodes()]
    Ynodes = [pos_[n][1] for n in Gx_.nodes()]

    Xedges = []
    Yedges = []
    for e in Gx_.edges():
        # x and y coordinates of the two nodes defining edge e, then None to break the line
        Xedges.extend([pos_[e[0]][0], pos_[e[1]][0], None])
        Yedges.extend([pos_[e[0]][1], pos_[e[1]][1], None])
    return Xnodes, Ynodes, Xedges, Yedges, Gx_, pos_
# Get nodes and network graph coordinates
Xnodes, Ynodes, Xedges, Yedges, G, pos_ = get_network_coordinates(dataframe, "value_1", "value_2", "correlation", "circular_layout")

def assign_colour(correlation):
    # Function to assign correlation colors (neg becomes red, pos becomes green)
    if correlation <= 0:
        return "#ffa09b"  # red
    return "#9eccb7"  # green


def assign_thickness(correlation, benchmark_thickness=2, scaling_factor=1):
    # Function to assign correlation thickness based on absolute magnitude
    return benchmark_thickness * abs(correlation)**scaling_factor

def edge_traces():
    ### Function which creates a list of traces for the edges.
    ### A scatter trace has only one line colour and width, so to style edges by
    ### weight each edge has to be added as its own trace via a function like this.
    # assign colours to edges depending on positive or negative correlation
    # assign edge thickness depending on magnitude of correlation
    edge_colours = []
    edge_width = []
    for value in nx.get_edge_attributes(G, 'correlation').values():
        edge_colours.append(assign_colour(value))
        edge_width.append(assign_thickness(value))
    edge_trace_ = []
    # get_edge_attributes iterates the edges in the same order as G.edges, so index i matches
    for i, e in enumerate(G.edges):
        trace_ = go.Scatter(
            x=[pos_[e[0]][0], pos_[e[1]][0]],
            y=[pos_[e[0]][1], pos_[e[1]][1]],
            mode='lines',
            line=dict(width=edge_width[i], color=edge_colours[i]),
            hoverinfo='none',
            showlegend=False)
        edge_trace_.append(trace_)
    return edge_trace_
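Creating one trace per edge is straightforward, but a figure with thousands of traces can become sluggish. When the colors come from a small palette, such as the two colors of `assign_colour`, one alternative is to group edges by color and join each group into a single `None`-separated trace, so the figure has only as many edge traces as distinct colors. This is a sketch with hypothetical names (`grouped_edge_traces`, `colour_of`); the trade-off is that all edges in a group share one width, so the per-edge thickness from `assign_thickness` is lost.

```python
from collections import defaultdict


def grouped_edge_traces(pos, edges, colour_of, width=2):
    """Build one scatter trace (as a dict) per distinct colour instead of one per edge.

    pos: mapping node -> (x, y); edges: iterable of (u, v, weight) tuples;
    colour_of: function mapping a weight to a colour string.
    """
    xs, ys = defaultdict(list), defaultdict(list)
    for u, v, w in edges:
        c = colour_of(w)
        # same None-separated layout as Xedges/Yedges, but one pair of lists per colour
        xs[c].extend([pos[u][0], pos[v][0], None])
        ys[c].extend([pos[u][1], pos[v][1], None])
    return [
        {
            "type": "scatter",
            "x": xs[c],
            "y": ys[c],
            "mode": "lines",
            "line": {"width": width, "color": c},
            "hoverinfo": "none",
            "showlegend": False,
        }
        for c in xs
    ]
```

With the graph from above this could be called as `grouped_edge_traces(pos_, G.edges(data='correlation'), assign_colour)`, since `G.edges(data='correlation')` yields `(u, v, correlation)` tuples.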
# edges first, so the nodes are drawn on top of them
fig = go.Figure(data=edge_traces() + [node_trace])
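To show each edge's correlation on hover, a commonly used workaround is an extra marker-only trace with a transparent marker at every edge midpoint, because a line trace only reports hover at its data points, which here coincide with the nodes. A minimal sketch with a hypothetical helper name; it assumes fully transparent markers still receive hover events, and if they do not in a given Plotly version, a small, very faint marker serves the same purpose.

```python
def edge_midpoint_trace(pos, edges, label="correlation"):
    """Marker-only trace (as a dict) with one transparent marker at each edge midpoint.

    pos: mapping node -> (x, y); edges: iterable of (u, v, value) tuples,
    e.g. G.edges(data="correlation") in networkx.
    """
    mx, my, text = [], [], []
    for u, v, value in edges:
        mx.append((pos[u][0] + pos[v][0]) / 2)
        my.append((pos[u][1] + pos[v][1]) / 2)
        text.append(f"{u} - {v}: {label} = {value:.2f}")
    return {
        "type": "scatter",
        "x": mx,
        "y": my,
        "mode": "markers",
        "marker": {"size": 10, "opacity": 0},  # not drawn, only used as a hover target
        "text": text,
        "hoverinfo": "text",
        "showlegend": False,
    }
```

It would be added after the edges and nodes, e.g. `fig.add_trace(edge_midpoint_trace(pos_, G.edges(data="correlation")))`.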

Hope this helps