I'm trying to show a legend for the edges of a multigraph, and I would like the edge labels to stay fixed on the plot.

def _compute_positions(G):
    """Compute 2D node positions for *G* by chaining several networkx layouts.

    Each stage seeds the next one: a spiral layout, a spring layout (seed 10),
    a Kamada-Kawai layout, and a final spring layout (seed 110) with
    ``scale=10`` centred on (0, 0).

    Returns:
        dict mapping each node of *G* to its (x, y) position.
    """
    spiral_pos = nx.spiral_layout(G, resolution=2.5)
    pos = nx.spring_layout(G, seed=10, k=1.5, pos=spiral_pos)
    pos = nx.kamada_kawai_layout(G, scale=1.2, pos=pos)
    return nx.spring_layout(G, seed=110, k=1.5, pos=pos, scale=10, center=(0, 0))


def _edge_geometry(G, pos):
    """Collect everything needed to draw the edges of *G*, in a single pass.

    Iterating ``G.edges(data="label")`` keeps each label paired with the edge
    it belongs to, so segments, midpoints, labels and arrows share one order.

    Args:
        G: the graph; each edge is expected to carry a ``label`` attribute.
        pos: dict mapping node -> (x, y).

    Returns:
        dict with keys:
          ``x_lines``/``y_lines``: ``[start, end, None]`` per edge (the None
              separates consecutive edges in a single Plotly line trace);
          ``x_mid``/``y_mid``: midpoint of each edge, where its label is drawn;
          ``labels``: edge labels, aligned with the midpoints;
          ``annotations``: one arrow annotation per edge, drawn from the
              source node (``ax``/``ay``) to the target node (``x``/``y``).
    """
    x_lines, y_lines = [], []
    x_mid, y_mid, labels = [], [], []
    annotations = []
    for source, target, label in G.edges(data="label"):
        (x0, y0), (x1, y1) = pos[source], pos[target]
        x_lines += [x0, x1, None]
        y_lines += [y0, y1, None]
        x_mid.append((x0 + x1) / 2)
        y_mid.append((y0 + y1) / 2)
        labels.append(label)
        annotations.append(dict(ax=x0, ay=y0, axref='x', ayref='y',
                                x=x1, y=y1, xref='x', yref='y',
                                showarrow=True, arrowhead=3, arrowsize=1.5))
    return dict(x_lines=x_lines, y_lines=y_lines, x_mid=x_mid, y_mid=y_mid,
                labels=labels, annotations=annotations)


def drawing_2D_interactive():
    """Build the saga multigraph and render it as an interactive 2D Plotly figure.

    Nodes, edges and edge labels come from ``load_nodes()``,
    ``load_edges(interactive=True)`` and ``load_edge_label_interactive()``.
    Node display labels are read from ``graph.json``. Each edge is drawn as
    a thin line plus an arrow annotation. Its label is rendered as a fixed
    text element at the edge midpoint, so it is always visible and not only
    on hover. The figure is emitted with ``py.plot``.

    Raises:
        ValueError: if fewer edge labels than edges are loaded.
    """
    G = nx.MultiDiGraph()
    print("START DRAWING")
    lista_nodi = load_nodes()
    print(len(lista_nodi))
    print("-----------------------------")
    print("Numero archi da rappresentare:")
    lista_archi = load_edges(interactive=True)
    print(len(lista_archi))
    label_archi = load_edge_label_interactive()
    if len(label_archi) < len(lista_archi):
        raise ValueError(
            f"{len(lista_archi)} edges but only {len(label_archi)} edge labels")

    G.add_nodes_from(lista_nodi)
    # zip pairs each edge with its label WITHOUT consuming label_archi.
    # The previous ``label_archi.pop(0)`` emptied the list, so every later
    # ``text=label_archi`` was [] and no edge label was ever drawn.
    for arco, label in zip(lista_archi, label_archi):
        G.add_edge(arco[0], arco[1], label=label)

    with open("graph.json", encoding="utf-8") as json_file:
        grafo_file = json.load(json_file)
    node_labels = [node['label'] for node in grafo_file["nodes"]]

    print('Numero nodi nel grafo')
    print(G.number_of_nodes())
    print('Numero archi nel grafo')
    print(G.number_of_edges())

    pos = _compute_positions(G)

    # Plotly needs separate X and Y lists. Nodes are taken in "1".."N" order
    # so that point i lines up with node_labels[i] (and lista_colori[i]).
    # NOTE(review): assumes node ids are the strings "1".."N" in graph.json
    # order; the old code hard-coded N = 42. Confirm against the data.
    node_ids = [str(i) for i in range(1, len(node_labels) + 1)]
    x_nodes_2D = [pos[n][0] for n in node_ids]
    y_nodes_2D = [pos[n][1] for n in node_ids]

    edges = _edge_geometry(G, pos)

    # Text-only trace: one label fixed at the midpoint of each edge.
    edge_label_trace = go.Scatter(x=edges["x_mid"],
                                  y=edges["y_mid"],
                                  mode='text',
                                  hoverinfo='all',
                                  text=edges["labels"],
                                  textposition='middle center')

    # Lines only. Labels are NOT attached here, because this trace has three
    # points per edge (start, end, None) and the text would be misplaced.
    trace_edges_2D = go.Scatter(x=edges["x_lines"],
                                y=edges["y_lines"],
                                mode='lines',
                                opacity=0.3,
                                line=dict(color='black', width=0.2),
                                hoverinfo='none')

    trace_nodes_2D = go.Scatter(x=x_nodes_2D,
                                y=y_nodes_2D,
                                mode='markers+text',
                                opacity=1,
                                # NOTE(review): lista_colori is not defined in
                                # this function; presumably a module-level list
                                # with one colour per node. Confirm.
                                marker=dict(symbol='circle',
                                            size=18,
                                            opacity=0.5,
                                            color=lista_colori,
                                            line=dict(color='white', width=0.2)),
                                text=node_labels,
                                hoverinfo='text')

    # Hide the 2D axes. The old ``scene=`` settings apply only to 3D plots,
    # and ``showbackground`` is not a valid 2D axis property.
    axis = dict(showline=False,
                zeroline=False,
                showgrid=False,
                showticklabels=False,
                title=dict(text=''))

    layout = go.Layout(title="Rappresentazione interattiva della saga Hrafnkel",
                       width=1080,
                       height=1025,
                       showlegend=False,
                       annotations=edges["annotations"],
                       xaxis=axis,
                       yaxis=axis,
                       margin=dict(t=100),
                       hovermode='closest')

    fig = go.Figure(data=[trace_edges_2D, trace_nodes_2D, edge_label_trace],
                    layout=layout)
    # The old ``fig.update_traces(mode="markers+text")`` was dropped. It
    # overrode every trace's explicit mode: edges became endpoint markers
    # and the text-only label trace gained markers.
    fig.update_layout(yaxis=dict(scaleanchor="x", scaleratio=1.3),
                      plot_bgcolor='rgb(255,255,255)',
                      uniformtext_minsize=3)
    py.plot(fig)