Add slider to directed weighted graph plot

Hi all,
I want to plot a directed graph with edge weights. On top of that, I want to add a slider to do some thresholding on which edge to show. A simple example: a slider with 2 points, on the first point the slider will show all edges, while on the second point the slide will ONLY show edges with weight > 0.5. This is just a simple example, I want to extend it to have 10 points on the slider (means 10 different thresholds).
The problem Iโ€™m having is that in order to visualize directed edges, I plot a single trace for every single edge to be able to set the traceโ€™s opacity as the corresponding edge weight. I also add some hovering text on the edge. So I have a lot of traces on my plot, 1 for node, X traces for edge (where X is the number of edges) and X traces for edge hovering. How can I configure the sliders to do this thresholding for visibility? My idea is to set an id for each trace, contextualizing which โ€œslider pointโ€ that particular edge belongs to, but this doesnโ€™t seem work. Below is my code. Any help is appreciated! Thanks!

import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from tqdm import tqdm


G = nx.DiGraph()
edges = [(1,2), (2,4), (4,0), (5,1)]
edge_weights = {(1,2): 0.1, (2,4): 0.5, (4,0): 1, (5,1): 0.3}
node_weights = [0.1, 0.2, 0.3, 0.4, 0.5]

G.add_edges_from(edges)

pos = nx.spring_layout(G, k=10)  # For better example looking

# Plot nodes with go.Scatter
node_x = []
node_y = []
for node in G.nodes():
    x, y = pos[node]
    node_x.append(x)
    node_y.append(y)

traces = []
node_trace = go.Scatter(
    x=node_x, y=node_y, visible=True,
    marker=dict(
        showscale=True, color=node_weights,
        colorbar=dict(
            thickness=15,
            title='Node weights',
            xanchor='left',
            titleside='right'
        ),
        line_width=2))

traces.append(node_trace)


#Get edge trace IDs for slider filtering
edge_trace_ids = {}
for w in edge_weights:
    if w <= 0.5:
        edge_trace_ids[w] = 0
    else:
        edge_trace_ids[w] = 1


x0, y0, x1, y1 = [], [], [], []
i=0
for edge,weight in tqdm(zip(edges,edge_weights)):
    # Get the coordinates of each node
    edge_coord_source = pos[edge[0]]
    edge_coord_target = pos[edge[1]]
    x0.append(edge_coord_source[0])
    y0.append(edge_coord_source[1])
    x1.append(edge_coord_target[0])
    y1.append(edge_coord_target[1])

    # Plot edges hover with go.Scatter.
    middle_hover_trace = go.Scatter(
        x=[], y=[], hovertext=[], mode='markers', hoverinfo="text",
        marker={'size': 20, 'color': 'LightSkyBlue'}, opacity=0,
        visible=False,
        uid=edge_trace_ids[weight] #my attempt to assign ID for each trace for slider filtering
    )

    middle_hover_trace['x'] = tuple([(edge_coord_source[0] + edge_coord_target[0]) / 2])
    middle_hover_trace['y'] = tuple([(edge_coord_source[1] + edge_coord_target[1]) / 2])

    hovertext = f"Edge weight: {weight}"
    middle_hover_trace['hovertext'] = tuple([hovertext])

    traces.append(middle_hover_trace)

    # Plot edges with go.Scatter.
    # This has to be done in a loop to be able to set the opacity/color according to weights
    edge_trace = go.Scatter(
        x=tuple([edge_coord_source[0], edge_coord_target[0], None]),
        y=tuple([edge_coord_source[1], edge_coord_target[1], None]),
        mode='lines',
        marker=dict(color="red"),
        line_shape='spline',
        opacity=weight,
        visible=False,
        uid=edge_trace_ids[weight]
    )
    traces.append(edge_trace)
    i+=1


fig = go.Figure(
    data = traces,
    layout=go.Layout(
        title=dict(text="Directed Graph", font=dict(size=14, color='blue')),
        showlegend=False,
        margin=dict(b=20,l=5,r=5,t=40),
        annotations = [
            dict(
                    ax=x0[i], ay=y0[i], axref='x', ayref='y',
                    x=x1[i], y=y1[i], xref='x', yref='y',
                    showarrow=True, arrowhead=1, arrowsize=2,
                    arrowcolor="red" if w == 1 else "green",
                    opacity=w/3, visible=False, name=edge_trace_ids[w]
                )
                for i,w in enumerate(edge_weights)
        ],
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
    )
)



# Create and add slider
sliders_steps = 2
steps = []
for i in range(sliders_steps):
    step = dict(
        method="update",
        args=[{"visible": [False] * sliders_steps},
            {"title": "Slider switched to step: " + str(i)}],  # layout attribute
    )
    for fig_data in fig.data:
        if int(fig_data.uid) == i:
            fig_data.visible = True

    step["args"][0]["visible"][i] = True  # Toggle i'th trace to "visible"
    steps.append(step)

sliders = [dict(
    active=10,
    currentvalue={"prefix": "Frequency: "},
    pad={"t": 50},
    steps=steps
)]

fig.update_layout(sliders=sliders)