Dash graph 'Show Labels' button requires two clicks to show all labels

I created a risk quadrant using Dash and Plotly, but I am facing an issue: the button I added to the Plotly figure to show and hide labels requires two mouse clicks before all the product labels are displayed.

Please find the code below:

from dash import Dash, dcc, html, Input, Output
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np
import ssl

#fixing self signed/signed TLS cert checks
ssl._create_default_https_context = ssl._create_unverified_context

external_stylesheets = [
    'https://fonts.googleapis.com/css?family=Open+Sans:200,400,600|&display=swap',
]

#Scoreboard data api
URL = ''
fam_url= ''

df=pd.read_json(URL)
df=pd.DataFrame(df)

df=df[['prod', 'scorecard_full', 'risk_score', 'BU']]

#read the skipper api data
df1=pd.read_json(fam_url)
df1=pd.DataFrame(df1)
df1=df1[['full_prod','prod_family','business_unit']]

#merge both datasets
new_data=pd.merge(df, df1.drop_duplicates(subset='full_prod'), left_on=['prod'], right_on=['full_prod'], how='outer')
#drop extra columns
new_data=new_data.drop(columns=['full_prod', 'business_unit'])

#if empty values
new_data=new_data.replace('', np.nan, regex=True)
new_data.fillna('Misc', inplace=True)

labels_visible = 1          

#clean up column names 
risk_data = new_data.copy()
risk_data.rename(columns={'prod': 'Product Name', 'scorecard_full': 'Progress Score', 'risk_score': 'Risk Score', 'BU': 'Business Unit', 'prod_family': 'Product Family'},
inplace = True) 
#print(risk_data.tail())
filtered_data = risk_data.copy()

#create color palette
colorscales = px.colors.named_colorscales()

#create filter variables
available_bu=new_data['BU'].unique()
available_family=new_data['prod_family'].unique()

app = Dash(__name__, external_stylesheets=external_stylesheets)

app.css.config.serve_locally = True
app.title = 'Risk Quadrant'

   
#create Filter
def createFilter(select_bu, select_family):
    #print('Inside createFilter : ',select_bu, ', ',select_family)
    global filtered_data
    global risk_data
    if (select_bu == ['Product'] or select_bu==None or select_bu==[]):
        if select_family == ['Product'] or select_family==None or select_family==[]:
            filtered_data = risk_data.copy()
        
        else:
            filtered_data = risk_data.loc[risk_data['Product Family'].isin(select_family)]
    else:
        if select_family == ['Product'] or select_family==None or select_family==[]:
            filtered_data = risk_data.loc[risk_data['Business Unit'].isin(select_bu)]
        
        else:
            filtered_data_tmp = risk_data.loc[risk_data['Product Family'].isin(select_family)]
            filtered_data = filtered_data_tmp.loc[risk_data['Business Unit'].isin(select_bu)]
    return filtered_data


#HTML Front-end
app.layout = html.Div([
    
    html.Div([
        #html.Link(rel='stylesheet',href='/assets/main.css'),
        html.A(
            href="",
            children=[
            html.Img(id='logo', src='')]
        ),
        
        html.Div('Risk Quadrant', id='header')
    ], id='top_bar'),
    
    
    html.Div(
    [   html.P('Filters', id='sidebar_header'),
        dcc.Dropdown(id="select_bu",
                 options=[{'label': i,'value': i} for i in available_bu],
                 multi=True,
                 value=['Product'],
                 searchable=True,
                 clearable=True,
                 placeholder="Business Unit"
                 ),
        dcc.Dropdown(id="select_family",
                 options=[{'label': i,'value': i} for i in available_family],
                 multi=True,
                 value=['Product'],
                 searchable=True,
                 clearable=True,
                 placeholder="Product Family"
                 ),

    ], id='navbar'),

    html.Div(
    [
        dcc.Graph(id="scatter-plot", animate=True),
    ], id="graph")
    
])


@app.callback(
    Output("scatter-plot", "figure"),
    [Input(component_id='select_bu', component_property='value'),
    Input(component_id='select_family', component_property='value')])

#@app.server.route('/assets/main.css')

#Function to plot the data
def update_bar_chart(select_bu, select_family):

    global filtered_data
    filtered_data = createFilter(select_bu, select_family)
   
    #annotations = [dict(text=filtered_data['Product Name'],visible=False) ]
    #graph plotting
    fig = px.scatter(
        data_frame=filtered_data, x='Risk Score', y='Progress Score',
        range_x=(100,0), range_y=(0,100),
        #hover_name='Product Name',
        text='Product Name',
        template="simple_white",
        hover_data=['Product Name', 'Product Family', "Risk Score", "Progress Score"],
        #color='Risk Score',
        #color_continuous_scale=['green','orange','red']
    )

    fig.update_traces(textposition='top center')    #position for labels

    
    fig.layout.update(                              #buttons for labels
    updatemenus=[
        dict(
            type="buttons",
            direction="right",
            x=1.3,
            y=0.95,
            showactive=True,
            active=0,
            buttons=list([
                dict(
                    label="Show Labels",
                    method='restyle',
                    args=[{'text':[filtered_data['Product Name']]}],
                ),
                 dict(
                    label="Hide Labels",
                    method='restyle',
                    args=[{'text':None}],
                ) 
            ])
        )
    ])
    
    #integrate font-style
    fig.update_layout(
        font_family="Open Sans",
        legend_title_font_color="green",
            
        width=1000,
        height=800,
        autosize=False,
        
    )

    fig.update_layout(paper_bgcolor="#F5F5F5")

    #Add annotation on each quadrant
    fig.add_annotation(x=75, y=95,
                text="High Risk, High Progress",
                showarrow=False,
                font=dict(color='Orange',
                size=18)
                )

    fig.add_annotation(x=25, y=95,
                text="Low Risk, High Progress",
                showarrow=False,
                font=dict(color='Green',
                size=18))

    fig.add_annotation(x=75, y=5,
                text="High Risk, Low Progress",
                showarrow=False,
                font=dict(color='Red',
                size=18))
        
    fig.add_annotation(x=25, y=5,
                text="Low Risk, Low Progress",
                showarrow=False,
                font=dict(color='Orange',
                size=18))


    fig.add_annotation(x=5, y=95,
                text="Q1",
                showarrow=False,
                font=dict(color='Grey',
                size=14))

    fig.add_annotation(x=5, y=5,
                text="Q2",
                showarrow=False,
                font=dict(color='Grey',
                size=14))

    fig.add_annotation(x=95, y=5,
                text="Q3",
                showarrow=False,
                font=dict(color='Grey',
                size=14))

    fig.add_annotation(x=95, y=95,
                text="Q4",
                showarrow=False,
                font=dict(color='Grey',
                size=14))

    #add labels to x-axis
    fig.update_xaxes(title_text='Risk Score')
    fig.update_xaxes(title_font=dict(size=18, family='Open Sans', color='Dark Blue'),  dtick=10)
    #make gridlines and margins visible
    fig.update_xaxes(showline=True, linewidth=2, linecolor='Light Grey', mirror=True)
    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='Light Grey')
        
    #add labels to y-axis    
    fig.update_yaxes(title_text='Progress Score')
    fig.update_yaxes(title_font=dict(size=18, family='Open Sans', color='Dark Blue'),  dtick=10)
    #make gridlines and margins visible
    fig.update_yaxes(showline=True, linewidth=2, linecolor='Light Grey', mirror=True)
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='Light Grey')

    #create grid center lines
    fig.add_hline(y=50, line_color="Orange", line_width=3)
    fig.add_vline(x=50, line_color="Orange", line_width=3)

    return fig

app.run_server( debug=True, port=8082)