Plot doesn't fully render when I update the color parameter

I used plotly/dash to create a website called www.kpgraphs.com that creates plots from college basketball data. I’m very proud of how it turned out except for one minor issue. When changing the “color by” option from conference to team and back to conference, the legend overlaps with the plot and covers some of the data. You can see this behavior by going to the page and trying it yourself. Is there some way to force the figure to be redrawn so this does not happen? Here is the portion of the code that creates the first figure on the page:

@app.callback(
    [Output('fig1', 'figure')],
    [Input('stat-column', 'value'),
     Input('number-teams', 'value'),
     Input('conf', 'value'),
     Input('color-by', 'value')]
)
def update_figure_1(stat_column_name, number_teams, conf, color_by):
    """Rebuild the bar chart shown in the 'fig1' graph component.

    Filters the module-level ``df`` to the selected conferences, sorts it by
    the chosen stat, keeps the first ``number_teams`` rows and draws one bar
    per team.

    Args:
        stat_column_name: Column of ``df`` plotted on the y axis and used as
            the sort key.
        number_teams: Row count passed to ``DataFrame.head`` after sorting.
        conf: Values of the 'Conf' column to keep. ``None`` is treated as an
            empty selection.
        color_by: ``'Team'`` colors bars per team using ``COLOR_DICT``; any
            other value colors them per conference using ``COLORS``.

    Returns:
        A one-element list holding the figure, matching the list-form
        ``Output`` declared in the decorator.
    """
    # NOTE(review): the filtered frame is published as a module-level global,
    # exactly as before, in case other callbacks read it. Module globals are
    # shared between all sessions of the app -- confirm this is intended.
    global dff

    # 'AdjD' and 'OppD' are ranked smallest-first; every other stat is ranked
    # largest-first.
    sort_ascending = stat_column_name in ['AdjD', 'OppD']
    dff = (
        df.loc[df['Conf'].isin(conf or [])]  # isin(None) would raise TypeError
        .sort_values(by=stat_column_name, ascending=sort_ascending)
        .head(number_teams)
        .reset_index()
        .copy()
    )

    if color_by == 'Team':
        color_column, color_map = 'Team', COLOR_DICT
    else:
        color_column, color_map = 'Conf', COLORS

    fig1 = px.bar(
        data_frame=dff,
        x='Team',
        y=stat_column_name,
        color=color_column,
        color_discrete_map=color_map,
        # Pin the x order to the sorted frame so bars stay ranked by the stat
        # regardless of how they are grouped by color.
        category_orders={'Team': list(dff['Team'])},
    )
    fig1.update_traces(
        hovertemplate='%{x}: %{y}',
        marker=dict(line=dict(width=2, color='DarkSlateGrey')),
    )

    # min()/max() raise ValueError on an empty sequence, so only pin the
    # y range when the filter actually matched some rows.
    if not dff.empty:
        lowest = min(dff[stat_column_name])
        highest = max(dff[stat_column_name])
        fig1.update_yaxes(
            range=[lowest - abs(lowest * .2), max(5, highest * 1.15)]
        )

    # NOTE(review): the animated transition may be related to the legend
    # overlapping the plot after switching 'color-by' (the number of traces
    # changes between the two modes). If the overlap persists, try removing
    # this line -- TODO confirm.
    fig1.update_layout(transition_duration=500)
    fig1.update_yaxes(title=stat_column_name)
    fig1.update_xaxes(title='')
    return [fig1]