Problem with a shared x-axis: subplot order and strange labels

Hello,

I am using Plotly with Dash. I want to plot a scatter plot and a related bar chart, so I am using a shared x-axis. A callback selects which feature to plot. However, I get strange behavior: for some features, the x-axis shows labels that do not belong to that feature.

Here is an example:


I get x-axis labels that do not belong to this feature.

However, when I stop sharing the x-axis, the problem does not appear:

Here is the code for the figure:

def plot_feature_stats(df, name_feature):
    """Plot per-category statistics for a single feature.

    The figure has two vertically stacked subplots sharing one x axis:

    * top row (70% of the height): line+marker traces for the mean
      prediction, the mean target and the two per-model mean predictions;
    * bottom row (30% of the height): a bar chart of the number of
      elements in each category.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format statistics table. Must contain the columns ``feature``,
        ``category``, ``n_elements``, ``mean_prediction``, ``mean_target``,
        ``mean_prediction_model1`` and ``mean_prediction_model2``.
    name_feature : str
        Value of the ``feature`` column whose rows are plotted.

    Returns
    -------
    plotly.graph_objects.Figure
        The two-row figure, 800 px high.
    """
    # Keep only the rows of the requested feature.
    df = df[df['feature'] == name_feature]
    # The exact category list (in dataframe order) shared by every trace and
    # by the shared x axis below.
    categories = df['category'].tolist()

    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True, shared_yaxes=False,
        row_heights=[0.7, 0.3],
    )

    # Bottom row: element counts. add_trace(row=, col=) replaces the
    # deprecated append_trace.
    fig.add_trace(go.Bar(
        x=categories, y=df['n_elements'],
        name='Number of elements for each category',
    ), row=2, col=1)

    # Top row: one line+marker trace per statistic column, in legend order.
    line_series = [
        ('mean_prediction', 'Mean prediction'),
        ('mean_target', 'Mean target'),
        ('mean_prediction_model1', 'Mean prediction Model1'),
        ('mean_prediction_model2', 'Mean predictions Model2'),
    ]
    for column, label in line_series:
        fig.add_trace(go.Scatter(
            x=categories, y=df[column],
            mode='lines+markers',
            name=label,
        ), row=1, col=1)

    fig.update_layout(
        title=f'Predictions for different values of the feature {name_feature}',
        height=800,
    )
    # Applies to both x axes (the figure has exactly two). Treat x values as
    # categories and pin the axis to this feature's categories in dataframe
    # order, so both subplots of the shared axis agree on the category list.
    # NOTE(review): stray labels surviving a Dash update may also come from
    # client-side axis state on matched category axes; confirm with the
    # plotly.js version in use.
    fig.update_xaxes(
        type='category',
        categoryorder='array',
        categoryarray=categories,
    )
    return fig