Add slider to directed weighted graph plot

Hi all,
I want to plot a directed graph with edge weights. On top of that, I want to add a slider to do some thresholding on which edge to show. A simple example: a slider with 2 points, on the first point the slider will show all edges, while on the second point the slide will ONLY show edges with weight > 0.5. This is just a simple example, I want to extend it to have 10 points on the slider (means 10 different thresholds).
The problem I’m having is that in order to visualize directed edges, I plot a single trace for every single edge to be able to set the trace’s opacity as the corresponding edge weight. I also add some hovering text on the edge. So I have a lot of traces on my plot, 1 for node, X traces for edge (where X is the number of edges) and X traces for edge hovering. How can I configure the sliders to do this thresholding for visibility? My idea is to set an id for each trace, contextualizing which “slider point” that particular edge belongs to, but this doesn’t seem work. Below is my code. Any help is appreciated! Thanks!

import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from tqdm import tqdm


G = nx.DiGraph()
edges = [(1,2), (2,4), (4,0), (5,1)]
edge_weights = {(1,2): 0.1, (2,4): 0.5, (4,0): 1, (5,1): 0.3}
node_weights = [0.1, 0.2, 0.3, 0.4, 0.5]

G.add_edges_from(edges)

pos = nx.spring_layout(G, k=10)  # For better example looking

# Plot nodes with go.Scatter
node_x = []
node_y = []
for node in G.nodes():
    x, y = pos[node]
    node_x.append(x)
    node_y.append(y)

traces = []
node_trace = go.Scatter(
    x=node_x, y=node_y, visible=True,
    marker=dict(
        showscale=True, color=node_weights,
        colorbar=dict(
            thickness=15,
            title='Node weights',
            xanchor='left',
            titleside='right'
        ),
        line_width=2))

traces.append(node_trace)


#Get edge trace IDs for slider filtering
edge_trace_ids = {}
for w in edge_weights:
    if w <= 0.5:
        edge_trace_ids[w] = 0
    else:
        edge_trace_ids[w] = 1


x0, y0, x1, y1 = [], [], [], []
i=0
for edge,weight in tqdm(zip(edges,edge_weights)):
    # Get the coordinates of each node
    edge_coord_source = pos[edge[0]]
    edge_coord_target = pos[edge[1]]
    x0.append(edge_coord_source[0])
    y0.append(edge_coord_source[1])
    x1.append(edge_coord_target[0])
    y1.append(edge_coord_target[1])

    # Plot edges hover with go.Scatter.
    middle_hover_trace = go.Scatter(
        x=[], y=[], hovertext=[], mode='markers', hoverinfo="text",
        marker={'size': 20, 'color': 'LightSkyBlue'}, opacity=0,
        visible=False,
        uid=edge_trace_ids[weight] #my attempt to assign ID for each trace for slider filtering
    )

    middle_hover_trace['x'] = tuple([(edge_coord_source[0] + edge_coord_target[0]) / 2])
    middle_hover_trace['y'] = tuple([(edge_coord_source[1] + edge_coord_target[1]) / 2])

    hovertext = f"Edge weight: {weight}"
    middle_hover_trace['hovertext'] = tuple([hovertext])

    traces.append(middle_hover_trace)

    # Plot edges with go.Scatter.
    # This has to be done in a loop to be able to set the opacity/color according to weights
    edge_trace = go.Scatter(
        x=tuple([edge_coord_source[0], edge_coord_target[0], None]),
        y=tuple([edge_coord_source[1], edge_coord_target[1], None]),
        mode='lines',
        marker=dict(color="red"),
        line_shape='spline',
        opacity=weight,
        visible=False,
        uid=edge_trace_ids[weight]
    )
    traces.append(edge_trace)
    i+=1


fig = go.Figure(
    data = traces,
    layout=go.Layout(
        title=dict(text="Directed Graph", font=dict(size=14, color='blue')),
        showlegend=False,
        margin=dict(b=20,l=5,r=5,t=40),
        annotations = [
            dict(
                    ax=x0[i], ay=y0[i], axref='x', ayref='y',
                    x=x1[i], y=y1[i], xref='x', yref='y',
                    showarrow=True, arrowhead=1, arrowsize=2,
                    arrowcolor="red" if w == 1 else "green",
                    opacity=w/3, visible=False, name=edge_trace_ids[w]
                )
                for i,w in enumerate(edge_weights)
        ],
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
    )
)



# Create and add slider
sliders_steps = 2
steps = []
for i in range(sliders_steps):
    step = dict(
        method="update",
        args=[{"visible": [False] * sliders_steps},
            {"title": "Slider switched to step: " + str(i)}],  # layout attribute
    )
    for fig_data in fig.data:
        if int(fig_data.uid) == i:
            fig_data.visible = True

    step["args"][0]["visible"][i] = True  # Toggle i'th trace to "visible"
    steps.append(step)

sliders = [dict(
    active=10,
    currentvalue={"prefix": "Frequency: "},
    pad={"t": 50},
    steps=steps
)]

fig.update_layout(sliders=sliders)