Question about matching callback and patch method in multiple filter on ag-grid

(I modified detail of question and add entire code)

I want to apply multiple filter to data table in my dash app.

It works like these :

  1. Choose a table
  2. Choose a column to add
  3. Add this column.
  4. Choose wanted values (categorical column) or wanted range value (numeric or date column)
  5. Apply filter

And you can go back to 2 to add another column filter.

It works with same type of filter. For example, if you want to filtered data just bewteen categorical data using ‘filter_cat’, filters work sequentially. (Firstly filtered data → second filter apply → secondly filtered data)

However, if I want to apply different type of filter (filter for categorical value and filter for numerical value and datetime value), it does not work sequentially.

It means

  1. Firstly filtered data by categorical column filter
  2. Apply numeric column filter, but it does not apply filter on firstly filtered data, but on original data.

Moreover, firstly applied filter does not work anymore after adding second filter.

from dash import Dash,html,dcc,Input,Output,State,Patch,MATCH,ALL,ctx
from dash.exceptions import PreventUpdate
import dash_ag_grid as dag
import pandas as pd
import plotly.express as px
import dash_bootstrap_components as dbc

app = Dash(__name__)
# Sample data for demonstration
data_table1 = pd.DataFrame({
    'Category1': ['A', 'B', 'C', 'A', 'B'],
    'Category2': ['X', 'Y', 'X', 'Y', 'Z'],
    'Numeric1': [10, 15, 8, 12, 6],
    'Numeric2': [100, 200, 150, 50, 300],
    'Date1': pd.to_datetime(['2023-09-01', '2023-09-02', '2023-09-03', '2023-09-04', '2023-09-05']),
    'Date2': pd.to_datetime(['2023-09-01 08:00', '2023-09-02 10:00', '2023-09-03 12:00', '2023-09-04 14:00', '2023-09-05 16:00'])
})


data_table2 = pd.DataFrame({
    'Category3': ['A', 'B', 'C', 'A', 'B'],
    'Category4': ['X', 'Y', 'X', 'Y', 'Z'],
    'Numeric3': [10, 15, 8, 12, 6],
    'Numeric4': [100, 200, 150, 50, 300],
    'Date3': pd.to_datetime(['2023-09-01', '2023-09-02', '2023-09-03', '2023-09-04', '2023-09-05']),
    'Date4': pd.to_datetime(['2023-09-10 08:00', '2023-09-12 10:00', '2023-09-13 12:00', '2023-09-14 14:00', '2023-09-15 16:00'])
})



rowClassRules = {
    # apply green to 2008
"rounded": True,
}
rowStyle={
"border-radius": "10px"
}
defaultColDef = {
    "resizable": True,
    "sortable": True, 
    "filter": True,
    "initialWidth": 200,
    "wrapHeaderText": True,
    "autoHeaderHeight": True,
    "headerClass": 'center-header', "cellStyle": {'textAlign': 'center'}
}


table_configs = {
    "table1": {
        "df": data_table1,
        "columns": data_table1.columns,
    },
    "table2": {
        "df": data_table2,
        "columns": data_table2.columns,
    },
}
def get_selected_dataframe(selected_table):
    if selected_table == "table1":
        return data_table1
    elif selected_table == "table2":
        return data_table2
    else:
        return pd.DataFrame()
    


list_table = ['table1','table2']

dropdown_table = dcc.Dropdown(
            options=[{'label': i, 'value': i} for i in list_table],
            value = 'table1',
            id="filter_table",
            # clearable=False,
            style={"marginBottom": 10},
            multi=False
        )

dropdown_var_filter = dcc.Dropdown(
                        id='filter_variable_to_show',
                        options=[],
                        persistence=True,
                        multi=True,
                        placeholder='Select a table...',)

second_filter = dcc.Dropdown(
                        id='second_filter',
                        options=[],
                        value=[],
                        multi=False,
                        persistence=True,
                        placeholder='Select a columns...',)



table_output = html.Div(id='table_output')

@app.callback(
    Output('update-rowdata-grid', 'rowData'),
    Input('apply_filter_btn','n_clicks'),
    State({'type': 'filter_cat',"table":ALL ,'index': ALL}, 'value'),
    State({'type': 'filter_cat',"table":ALL ,'index': ALL}, 'id'),
    State({'type': 'filter_num','table':ALL, 'index': ALL}, 'value'),
    State({'type': 'filter_num',"table":ALL ,'index': ALL}, 'id'),
    State({'type': 'filter_date','table':ALL, 'index': ALL}, 'start_date'),
    State({'type': 'filter_date','table':ALL, 'index': ALL}, 'end_date'),
    State({'type': 'filter_date','table':ALL, 'index': ALL}, 'id'),
    State('filter_table', 'value'),
    State('second_filter','value'),
    prevent_initial_call=True
    
)
def apply_filter(n_clicks,cat,cat_id,num,num_id,start_date,end_date,date_id,selected_table,selected_columns):
    df = get_selected_dataframe(selected_table)
    dff = df.copy()
    column_type = df[selected_columns].dtype
    if n_clicks > 0 :
        print(n_clicks)
        if column_type == 'object' and cat[0]:
            # Without for, we cannot assign proper unique value to each column.
            # For example, cat could have a [['X'],['A']]. Here, 'X' is from column 1 and 'A' is from column 2
            # To link each unique value to proper column, I should use cat_id, containing information about column
            # And we should iterate it using for loop. dff is updated for each column.
            print('cat_filter')
            for idx,value in enumerate(cat_id):
                dff = dff[dff[value['index']].isin(cat[idx])]
                
        if column_type in ['int64', 'float64'] and num[0]:
            # Same as cat. But it is composed of two element (min & max value). We have to extract both
            print('num_filter')
            for idx,value in enumerate(num_id):
                dff = dff[(dff[value['index']] >= num[idx][0]) & (dff[value['index']] <= num[idx][1])]
                
        if column_type == 'datetime64[ns]' and start_date and end_date:
            # Same as cat and num.
            print('date_filter')
            for idx,value in enumerate(date_id):
                dff = dff[(dff[value['index']] >= start_date[idx]) & (dff[value['index']] <= end_date[idx])]
    return dff.to_dict('records')

@app.callback(
        Output('second_filter', 'options',allow_duplicate=True),
        Input({"type": "filter_column", "index": ALL},'value'),
        Input({"type": "filter_column", "index": ALL},'id'),
        Input('filter_table','value'),
        prevent_initial_call='initial_duplicate'
    )
def update_filter(value,col_id,selected_table):
    df = get_selected_dataframe(selected_table)
    if value :
        return [{"label": col, "value": col} for col in df.drop(columns = list(value),axis=1).columns]
    else :
        return [{"label": col, "value": col} for col in df.columns]



@app.callback(
    Output('filter_container','children',allow_duplicate=True),
    Input('add_filter_btn','n_clicks'),
    State("second_filter", "value"),
    State('filter_table','value'),
    prevent_initial_call='initial_duplicate'
)
def add_filter(n_clicks,selected_columns,selected_table):
    patched_children = Patch()
    df = get_selected_dataframe(selected_table)
    columns = df.columns
    
    if n_clicks != None and selected_columns:
        column_type = df[selected_columns].dtype
        if column_type == 'object':
            unique_values = df[selected_columns].unique()
            new_filter = html.Div([
                    html.H4('Used filter'),
                    dcc.Dropdown(
                        id={"type": "filter_column", "index": n_clicks},
                        value=selected_columns,
                        options=[{"label": col, "value": col} for col in columns],
                        disabled=True,
                    ),
                    dcc.Dropdown(
                    id={"type": "filter_cat", "table": selected_table, "index": selected_columns},
                    options=[{"label": str(val), "value": val} for val in unique_values],
                    placeholder="Select a value",
                    multi=True,
            ),
                    dbc.Row(dbc.Button("X", id={"type": "remove_btn", "table": selected_table, "index": selected_columns}, color="primary"))
            ])
        elif column_type in ['int64', 'float64']:
        # For Integer & Float type, create slider filter
            min_val = df[selected_columns].min()
            max_val = df[selected_columns].max()
            new_filter = html.Div([
                html.H4('Used filter'),
                    dcc.Dropdown(
                        id={"type": "filter_column", "index": n_clicks},
                        value=selected_columns,
                        options=[{"label": col, "value": col} for col in columns],
                        disabled=True,
                    ),
                    dcc.RangeSlider(
                    id={"type": "filter_num", "table": selected_table, "index": selected_columns},
                    min = min_val,
                    max = max_val,
            ),
                    dbc.Row(dbc.Button("X", id={"type": "remove_btn", "table": selected_table, "index": selected_columns}, color="primary"))
            ])
        elif column_type == 'datetime64[ns]':
        # For Integer & Float type, create slider filter
            min_date = df[selected_columns].min()
            max_date = df[selected_columns].max()
            new_filter = html.Div([
                html.H4('Used filter'),
                    dcc.Dropdown(
                        id={"type": "filter_column", "index": n_clicks},
                        value=selected_columns,
                        options=[{"label": col, "value": col} for col in columns],
                        disabled=True,
                    ),
                    dcc.DatePickerRange(
                        id={"type": "filter_date", "table": selected_table, "index": selected_columns},
                        min_date_allowed=min_date,
                        max_date_allowed=max_date,
                        display_format='DD/MM/YYYY',
                        clearable=True,
                        start_date=min_date,
                        end_date=max_date
                        

            ),
                    dbc.Row(dbc.Button("X", id={"type": "remove_btn", "table": selected_table, "index": selected_columns}, color="primary"))
            ])    
        patched_children.append(new_filter)
        return patched_children
    
    return patched_children

# @app.callback(
#     Output('filter_container','children',allow_duplicate=True),
#     Input({"type": "remove_btn", "table": ALL, "index": ALL},'n_clicks'),
#     prevent_initial_call=True
# )
# def remove_param_filter(n_clicks):
#     if n_clicks :
#         return None

@app.callback(
            Output('second_filter', 'options',allow_duplicate=True),
            Output('filter_container','children',allow_duplicate=True),
            Output('update-rowdata-grid', 'rowData',allow_duplicate=True),
             Input('clear-button','n_clicks'),
             State('filter_table', 'value'),
             prevent_initial_call=True)
def reset_filters(n_clicks, selected_table):
    if n_clicks:
        df = get_selected_dataframe(selected_table)
        return [{"label": col, "value": col} for col in df.columns],None,df.to_dict('records')
    else:
        raise PreventUpdate




@app.callback(
    Output('filter_variable_to_show','options'),
    Input('filter_table', 'value'),
)
def filter_col(selected_table):
    df = get_selected_dataframe(selected_table)
    return [{"label": col, "value": col} for col in df.columns]


@app.callback(
        Output('second_filter', 'options'),
        Input('filter_variable_to_show','value'),
        Input('filter_table','value')
    )
def update_filter(value,selected_table):
    df = get_selected_dataframe(selected_table)
    if value :
        return value
    else :
        return [{"label": col, "value": col} for col in df.columns]
    
@app.callback(
        Output('filter_container','children',allow_duplicate=True),
        Output('table_output', 'children'),
        Input('filter_table', 'value'),
        Input('filter_variable_to_show','value'),
        prevent_initial_call='initial_duplicate'
    )
def update_table(value,selected_columns):
    config = table_configs.get(value)
    if config:
        df = config["df"]
        if selected_columns:
            df = df[selected_columns]
        table = dag.AgGrid(
            id = "update-rowdata-grid",
            rowData=df.to_dict('records'),
            defaultColDef=defaultColDef,
            columnDefs=[{'field':i} for i in df.columns],
            columnSize="autoSize",
            dashGridOptions={"pagination": True},
            className="ag-theme-alpine",
            rowClassRules=rowClassRules,
            rowStyle=rowStyle,
        )
        return None,table
filter_container=html.Div(id="filter_container", children=[])


filter=dbc.Card(
    [
        dbc.CardHeader(html.H3("Filter")),
        dbc.CardBody(
            [
                dbc.Row(
                    children=[second_filter,
                            filter_container,
                            html.Hr()],
                    style={"height": "80%"},  # Adjust the height as per your requirement
                )
            ]
        ),
        dbc.CardBody(
            [
                dbc.Row(
                    children=[
                        # Apply button to create filter. After that, I want to create new filter section.
                        dbc.Col(dbc.Button("Add", id="add_filter_btn", color="primary"),width=6,),
                        dbc.Col(dbc.Button("Apply", id="apply_filter_btn", color="primary"),width=6,),
                        # Clear button to remove all filters
                        dbc.Col(dbc.Button("Clear", id="clear-button", color="danger"),width=6,),
                        html.Hr(),
                        
                    ],
                    style={"height": "20%"},  # Adjust the height as per your requirement
                )
            ]
        ),])

app.layout = html.Div(children = [
    html.H3("Table selection"),
    dropdown_table,
    html.Hr(),
    html.H3("Variable To SHOW"),
    dropdown_var_filter,
    html.Hr(),
    filter,
    html.H3("Output Table"),
    table_output])


# graph = dcc.Graph(id="my-graph",figure={})






if __name__ == "__main__":
    app.run_server(debug=True)

The code is quite long because it is entire test app, but you can focus on just two callback function : ‘add_filter’ and ‘apply_filter’

What I want to do is overlapping multiple filter bewteen different type of filter. So if I apply filter A => firstly filtered data. And after I apply filter B => secondly filtered data from filtered data.

How can I do that ?