Parallel processing with pattern matching

Hello,

I am trying to use background=True alongside pattern matching so that I can update N plots in parallel. The issue I’m running into is that it only seems to use one thread per callback definition, rather than one per “matched” callback instance. For example, in the MRE below, the callback generates a plot after 5 s, after 2 s, or instantly, depending on which button is clicked. If you click one of the longer-running buttons and then the short-running one, it kills the first task and proceeds with only the most recently called one.

For reference, my end goal here is to have N plots that update independently based on input from other plots. So: parallel, pattern-matched cross-filtering!

import os
import time
import numpy as np
import dash_bootstrap_components as dbc
from dash import Dash, CeleryManager, Input, Output, html, callback, MATCH, ctx, dcc
from celery import Celery

def create_random_graph(n_pts):
    """Build a Plotly figure dict with ``n_pts`` markers on the line y = x.

    Despite the name, nothing here is random: point ``i`` is placed at
    ``(i / 10, i / 10)``. Both axes are fixed to the range ``[0, 1]``, so
    points with ``i > 10`` fall outside the visible area.

    Args:
        n_pts: Number of points to generate (``0`` yields an empty trace).

    Returns:
        A ``{'data': [...], 'layout': {...}}`` dict suitable for the
        ``figure`` property of ``dcc.Graph``. The ``x``/``y`` values are
        NumPy arrays.
    """
    trace = {
        'x': np.arange(n_pts) / 10,
        'y': np.arange(n_pts) / 10,
        'type': 'scatter',
        'mode': 'markers',
        # Highlight selected points in red.
        'selected': {'marker': {'color': 'red'}},
    }
    layout = {
        'title': 'Graph',
        'clickmode': 'event+select',
        'dragmode': 'select',
        # plotly.js reads this setting from the nested attribute
        # layout.newselection.mode. The flat "newselection_mode" spelling is
        # plotly.py magic-underscore shorthand, which is only expanded by
        # graph-object constructors/update methods, not for a raw dict.
        'newselection': {'mode': 'gradual'},
        'xaxis': {'range': [0, 1]},
        'yaxis': {'range': [0, 1]},
    }
    return {'data': [trace], 'layout': layout}

# The same REDIS_URL is used for both the Celery broker and the result backend.
redis_url = os.environ['REDIS_URL']
celery_app = Celery(__name__, broker=redis_url, backend=redis_url)

celery_settings = {
    # Pool choice:
    #   * worker_pool='threads' breaks because "kill is not implemented".
    #   * 'prefork' doesn't throw an error, but isn't parallel either.
    'worker_pool': 'prefork',
    'worker_concurrency': 4,
    # Use JSON for tasks, accepted content and results.
    'task_serializer': 'json',
    'accept_content': ['json'],
    'result_serializer': 'json',
}
celery_app.conf.update(**celery_settings)

# Manager handed to Dash so that background callbacks run as Celery tasks.
background_callback_manager = CeleryManager(celery_app)

# Dash application using the Bootstrap theme and the Celery-backed
# background callback manager defined above in this script.
app = Dash(
    __name__,
    external_stylesheets=[dbc.themes.BOOTSTRAP],
    background_callback_manager=background_callback_manager,
)

# Layout: a row of three graphs followed by a row of three buttons.
# Graph and button ids share the same 'index' value (1, 2, 3) so that a
# pattern-matching callback can pair each button with its graph.
app.layout = html.Div(
    [
        dbc.Row(
            [
                dbc.Col(
                    dcc.Graph(
                        id={'type': 'dynamic-graph', 'index': idx},
                        figure=create_random_graph(1),
                    ),
                    width=4,
                )
                for idx in (1, 2, 3)
            ]
        ),
        dbc.Row(
            [
                dbc.Col(
                    html.Button(
                        id={'type': "button", 'index': idx},
                        children=f"Update graph {idx}",
                    )
                )
                for idx in (1, 2, 3)
            ]
        ),
    ]
)

@callback(
    output=Output({'type': 'dynamic-graph', 'index': MATCH}, "figure"),
    inputs=Input({'type': 'button', 'index': MATCH}, "n_clicks"),
    background=True,
    prevent_initial_call=True,
)
def update_graph1(n_clicks):
    """Regenerate the graph whose index matches the clicked button.

    Registered as a background, pattern-matching callback. To imitate work
    of different durations, the callback sleeps 5 seconds for index 1,
    2 seconds for index 2, and not at all for any other index.

    Args:
        n_clicks: Click count of the triggering button; ``None`` is treated
            as 1.

    Returns:
        The figure from ``create_random_graph`` with ``n_clicks + 1`` points.
    """
    graph_index = ctx.triggered_id['index']

    # Simulated workload per graph index; indices not listed return at once.
    simulated_delay = {1: 5.0, 2: 2.0}.get(graph_index, 0.0)
    if simulated_delay:
        time.sleep(simulated_delay)

    clicks = 1 if n_clicks is None else n_clicks
    return create_random_graph(clicks + 1)

# Start the Dash server in debug mode only when this file is run as a script.
if __name__ == "__main__":
    app.run(debug=True)

Thanks for any and all help!

Isaac