Directed Graph VIsualization Zooming Issues

Hey all, I wrote the following logic to plot a networkx directed graph with plotly:

import plotly.graph_objects as go
import networkx as nx
import numpy as np


def calculate_node_positions(G):
    pos = {}
    operations = sorted([node for node in G.nodes() if node.startswith('O')],
                        key=lambda x: (int(x.split('_')[1]), int(x.split('_')[2])))
    jobs = sorted(set([node.split('_')[1] for node in operations]))
    max_ops = max([int(node.split('_')[2]) for node in operations])

    # Calculate positions for operation nodes
    for i, job in enumerate(jobs):
        job_ops = [node for node in operations if node.split('_')[1] == job]
        for j, op in enumerate(job_ops):
            pos[op] = (j + 1, -i)

    # Calculate positions for S and T nodes
    pos['S'] = (0, -len(jobs) / 2 + 0.5)
    pos['T'] = (max_ops + 1, -len(jobs) / 2 + 0.5)

    return pos


def create_graph_figure(G, pos):
    node_size = 30  # in pixels
    node_radius = node_size / 2  # in pixels
    arrow_length = 15  # in pixels

    # Create node trace
    node_x, node_y = zip(*pos.values())
    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers',
        hoverinfo='text',
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            color=[],
            size=node_size,
            colorbar=dict(
                thickness=15,
                title='Node Connections',
                xanchor='left',
                titleside='right'
            ),
            line_width=2
        )
    )

    # Create figure
    fig = go.Figure(data=[node_trace],
                    layout=go.Layout(
                        title='NetworkX Directed Graph with Filled and Dotted Edges',
                        titlefont_size=16,
                        showlegend=False,
                        hovermode='closest',
                        margin=dict(b=20, l=5, r=5, t=40),
                        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                        width=800,
                        height=600
                    ))

    # Add a function to recalculate edge positions
    fig.add_trace(go.Scatter(x=[None], y=[None], mode='markers', marker=dict(size=0), hoverinfo='none'))

    def update_edges(xrange, yrange):
        filled_edges = []
        dotted_edges = []
        annotations = []

        for edge in G.edges(data=True):
            x0, y0 = pos[edge[0]]
            x1, y1 = pos[edge[1]]

            # Convert to pixel coordinates
            px0 = (x0 - xrange[0]) / (xrange[1] - xrange[0]) * fig.layout.width
            py0 = (y0 - yrange[0]) / (yrange[1] - yrange[0]) * fig.layout.height
            px1 = (x1 - xrange[0]) / (xrange[1] - xrange[0]) * fig.layout.width
            py1 = (y1 - yrange[0]) / (yrange[1] - yrange[0]) * fig.layout.height

            dx = px1 - px0
            dy = py1 - py0

            # Calculate angle and adjust endpoints
            angle = np.arctan2(dy, dx)
            px1 = px0 + (dx - (node_radius * np.cos(angle)) * 1.1)  # Adjusted factor
            py1 = py0 + (dy - (node_radius * np.sin(angle)) * 1.1)  # Adjusted factor

            # Calculate arrow start point
            px2 = px0 + (dx - ((arrow_length + node_radius) * np.cos(angle)) * 1.1)  # Adjusted factor
            py2 = py0 + (dy - ((arrow_length + node_radius) * np.sin(angle)) * 1.1)  # Adjusted factor

            # Convert back to data coordinates
            x1 = xrange[0] + (px1 / fig.layout.width) * (xrange[1] - xrange[0])
            y1 = yrange[0] + (py1 / fig.layout.height) * (yrange[1] - yrange[0])
            x2 = xrange[0] + (px2 / fig.layout.width) * (xrange[1] - xrange[0])
            y2 = yrange[0] + (py2 / fig.layout.height) * (yrange[1] - yrange[0])

            if edge[2]['edge_type'] == 'filled':
                filled_edges.append((x0, y0, x1, y1))
            else:
                dotted_edges.append((x0, y0, x1, y1))

            annotations.append(
                dict(
                    ax=x2, ay=y2,
                    axref='x', ayref='y',
                    x=x1, y=y1,
                    xref='x', yref='y',
                    showarrow=True,
                    arrowhead=2,
                    arrowwidth=1,
                    arrowcolor='#888'
                )
            )

        edge_trace_filled = go.Scatter(
            x=[x for edge in filled_edges for x in [edge[0], edge[2], None]],
            y=[y for edge in filled_edges for y in [edge[1], edge[3], None]],
            line=dict(width=1, color='#888'),
            hoverinfo='none',
            mode='lines'
        )

        edge_trace_dotted = go.Scatter(
            x=[x for edge in dotted_edges for x in [edge[0], edge[2], None]],
            y=[y for edge in dotted_edges for y in [edge[1], edge[3], None]],
            line=dict(width=1, color='#888', dash='dot'),
            hoverinfo='none',
            mode='lines'
        )

        return [edge_trace_filled, edge_trace_dotted], annotations

    fig.update_layout(
        xaxis=dict(range=[min(node_x) - 0.5, max(node_x) + 0.5]),
        yaxis=dict(range=[min(node_y) - 0.5, max(node_y) + 0.5])
    )

    initial_traces, initial_annotations = update_edges(fig.layout.xaxis.range, fig.layout.yaxis.range)
    fig.add_traces(initial_traces)
    fig.update_layout(annotations=initial_annotations)

    def update_layout(xrange, yrange):
        new_traces, new_annotations = update_edges(xrange, yrange)
        return {'shapes': [], 'annotations': new_annotations}

    fig.layout.on_change(update_layout, 'xaxis.range', 'yaxis.range')

    return fig


# Create graph
G = nx.DiGraph()
G.add_edge('S', 'O_1_1', edge_type='filled')
G.add_edge('O_1_1', 'O_1_2', edge_type='dotted')
G.add_edge('O_1_2', 'T', edge_type='filled')
G.add_edge('S', 'O_2_1', edge_type='dotted')
G.add_edge('O_2_1', 'O_2_2', edge_type='filled')
G.add_edge('O_2_2', 'T', edge_type='dotted')
G.add_edge('O_2_2', 'O_1_2', edge_type='dotted')
G.add_edge('O_2_2', 'O_1_1', edge_type='dotted')

# Calculate positions
pos = calculate_node_positions(G)

# Create and show the figure
fig = create_graph_figure(G, pos)
fig.show()

The code orders the nodes based on their names, then to plot directed edges, the code connects the middle of a node to the border of a node and then uses an annotation to put an arrowhead at the end of the edge.
In the initial view, everything looks great but when zooming, the lengths of the edges don’t scale.

Can anyone help me with this?

Hey @BeezyP welcome to the forums.

The example shown in the docs does not show this behavior. I did not check the differences between your approach and the one shown in the docs. Maybe it helps.