Hey all, I wrote the following logic to plot a networkx directed graph with plotly:
import plotly.graph_objects as go
import networkx as nx
import numpy as np
def calculate_node_positions(G):
pos = {}
operations = sorted([node for node in G.nodes() if node.startswith('O')],
key=lambda x: (int(x.split('_')[1]), int(x.split('_')[2])))
jobs = sorted(set([node.split('_')[1] for node in operations]))
max_ops = max([int(node.split('_')[2]) for node in operations])
# Calculate positions for operation nodes
for i, job in enumerate(jobs):
job_ops = [node for node in operations if node.split('_')[1] == job]
for j, op in enumerate(job_ops):
pos[op] = (j + 1, -i)
# Calculate positions for S and T nodes
pos['S'] = (0, -len(jobs) / 2 + 0.5)
pos['T'] = (max_ops + 1, -len(jobs) / 2 + 0.5)
return pos
def create_graph_figure(G, pos):
node_size = 30 # in pixels
node_radius = node_size / 2 # in pixels
arrow_length = 15 # in pixels
# Create node trace
node_x, node_y = zip(*pos.values())
node_trace = go.Scatter(
x=node_x, y=node_y,
mode='markers',
hoverinfo='text',
marker=dict(
showscale=True,
colorscale='YlGnBu',
color=[],
size=node_size,
colorbar=dict(
thickness=15,
title='Node Connections',
xanchor='left',
titleside='right'
),
line_width=2
)
)
# Create figure
fig = go.Figure(data=[node_trace],
layout=go.Layout(
title='NetworkX Directed Graph with Filled and Dotted Edges',
titlefont_size=16,
showlegend=False,
hovermode='closest',
margin=dict(b=20, l=5, r=5, t=40),
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
width=800,
height=600
))
# Add a function to recalculate edge positions
fig.add_trace(go.Scatter(x=[None], y=[None], mode='markers', marker=dict(size=0), hoverinfo='none'))
def update_edges(xrange, yrange):
filled_edges = []
dotted_edges = []
annotations = []
for edge in G.edges(data=True):
x0, y0 = pos[edge[0]]
x1, y1 = pos[edge[1]]
# Convert to pixel coordinates
px0 = (x0 - xrange[0]) / (xrange[1] - xrange[0]) * fig.layout.width
py0 = (y0 - yrange[0]) / (yrange[1] - yrange[0]) * fig.layout.height
px1 = (x1 - xrange[0]) / (xrange[1] - xrange[0]) * fig.layout.width
py1 = (y1 - yrange[0]) / (yrange[1] - yrange[0]) * fig.layout.height
dx = px1 - px0
dy = py1 - py0
# Calculate angle and adjust endpoints
angle = np.arctan2(dy, dx)
px1 = px0 + (dx - (node_radius * np.cos(angle)) * 1.1) # Adjusted factor
py1 = py0 + (dy - (node_radius * np.sin(angle)) * 1.1) # Adjusted factor
# Calculate arrow start point
px2 = px0 + (dx - ((arrow_length + node_radius) * np.cos(angle)) * 1.1) # Adjusted factor
py2 = py0 + (dy - ((arrow_length + node_radius) * np.sin(angle)) * 1.1) # Adjusted factor
# Convert back to data coordinates
x1 = xrange[0] + (px1 / fig.layout.width) * (xrange[1] - xrange[0])
y1 = yrange[0] + (py1 / fig.layout.height) * (yrange[1] - yrange[0])
x2 = xrange[0] + (px2 / fig.layout.width) * (xrange[1] - xrange[0])
y2 = yrange[0] + (py2 / fig.layout.height) * (yrange[1] - yrange[0])
if edge[2]['edge_type'] == 'filled':
filled_edges.append((x0, y0, x1, y1))
else:
dotted_edges.append((x0, y0, x1, y1))
annotations.append(
dict(
ax=x2, ay=y2,
axref='x', ayref='y',
x=x1, y=y1,
xref='x', yref='y',
showarrow=True,
arrowhead=2,
arrowwidth=1,
arrowcolor='#888'
)
)
edge_trace_filled = go.Scatter(
x=[x for edge in filled_edges for x in [edge[0], edge[2], None]],
y=[y for edge in filled_edges for y in [edge[1], edge[3], None]],
line=dict(width=1, color='#888'),
hoverinfo='none',
mode='lines'
)
edge_trace_dotted = go.Scatter(
x=[x for edge in dotted_edges for x in [edge[0], edge[2], None]],
y=[y for edge in dotted_edges for y in [edge[1], edge[3], None]],
line=dict(width=1, color='#888', dash='dot'),
hoverinfo='none',
mode='lines'
)
return [edge_trace_filled, edge_trace_dotted], annotations
fig.update_layout(
xaxis=dict(range=[min(node_x) - 0.5, max(node_x) + 0.5]),
yaxis=dict(range=[min(node_y) - 0.5, max(node_y) + 0.5])
)
initial_traces, initial_annotations = update_edges(fig.layout.xaxis.range, fig.layout.yaxis.range)
fig.add_traces(initial_traces)
fig.update_layout(annotations=initial_annotations)
def update_layout(xrange, yrange):
new_traces, new_annotations = update_edges(xrange, yrange)
return {'shapes': [], 'annotations': new_annotations}
fig.layout.on_change(update_layout, 'xaxis.range', 'yaxis.range')
return fig
# Create graph
G = nx.DiGraph()
G.add_edge('S', 'O_1_1', edge_type='filled')
G.add_edge('O_1_1', 'O_1_2', edge_type='dotted')
G.add_edge('O_1_2', 'T', edge_type='filled')
G.add_edge('S', 'O_2_1', edge_type='dotted')
G.add_edge('O_2_1', 'O_2_2', edge_type='filled')
G.add_edge('O_2_2', 'T', edge_type='dotted')
G.add_edge('O_2_2', 'O_1_2', edge_type='dotted')
G.add_edge('O_2_2', 'O_1_1', edge_type='dotted')
# Calculate positions
pos = calculate_node_positions(G)
# Create and show the figure
fig = create_graph_figure(G, pos)
fig.show()
The code orders the nodes based on their names, then to plot directed edges, the code connects the middle of a node to the border of a node and then uses an annotation to put an arrowhead at the end of the edge.
In the initial view, everything looks great but when zooming, the lengths of the edges donβt scale.
Can anyone help me with this?