Dash Shared State with huge Python Objects: Interacting with Machine Learning Models

Use Case:

I have a big machine learning model (1-2 GB) that I want to interact with using Dash. I use Dash dropdowns and sliders to create novel data points that are fed to the ML model for predictions. This lets me interactively poke the model and get a real feel for how changes in the features affect the predictions.

Moreover, the model should be re-trained regularly when new real-time data comes in (~every 10 minutes).

The Problem:

What is the right way to keep my big model alive in a Dash framework?

Currently, I load the model as a global variable at Dash app startup (roughly as in the sketch below). This is fine for a first prototype, but probably not a good idea in the long run: when I re-train the model, the global state of the entire Dash app changes, which might have undesired side effects.
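
For reference, a minimal sketch of this prototype setup, assuming a pickled model file named `model.pkl` and a single toy feature (both placeholders, not my actual setup):

```python
# Minimal sketch of the current prototype: the model is a module-level global,
# loaded once when the Dash process starts, and every callback shares it.
import pickle

from dash import Dash, dcc, html, Input, Output

with open("model.pkl", "rb") as f:
    MODEL = pickle.load(f)  # 1-2 GB object, loaded a single time at startup

app = Dash(__name__)
app.layout = html.Div([
    dcc.Slider(id="feature-a", min=0, max=10, value=5),
    html.Div(id="prediction"),
])

@app.callback(Output("prediction", "children"), Input("feature-a", "value"))
def predict(feature_a):
    # Uses the shared global directly; fast, but re-training mutates global
    # state for every session of the app.
    return f"Prediction: {MODEL.predict([[feature_a]])[0]}"

if __name__ == "__main__":
    app.run_server(debug=True)
```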

To allow for shared state between callbacks, the Dash documentation suggests several approaches:

  • Use a hidden div or dcc.Store → Not useful here, since passing 1-2 GB back and forth between the front end and the back end doesn’t work.

  • Use a file or Redis cache → Almost as impractical as the previous suggestion: I cannot de-serialize my ML model every time a Dash component is triggered. Reconstructing a 1-2 GB Python object (a complex mixture of sklearn and PyTorch pipelines) from serialized JSON/pickle takes 30 seconds or more, and the nice interactivity is gone.

What is the best practice for handling a big Python object, such as an ML model, that needs to be shared between Dash callbacks? And how should occasional state changes of this big shared object, such as re-training, be handled?

Thanks!

I have used the following approach in the past:

  • Create a task for the training process. Schedule this task to run when new data arrives, e.g. every 10 minutes (how long does your training take?). The output of this task should be a servable model.
  • Create a task for serving the model. I used TensorFlow Serving for TensorFlow models and a small custom Flask server for other models. This task should monitor the model directory so that new models are loaded as they arrive (see the sketch after this list).
  • In your Dash app, query the process that serves the model for predictions. It will always have the latest model loaded, so you get predictions more or less instantly.
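
A rough sketch of the second and third points, assuming the training task writes pickled models into a `models/` directory; the paths, payload format, and check-on-request reload (instead of a real file-system watcher) are my own assumptions:

```python
# serve_model.py -- minimal sketch of a dedicated serving process.
# The large model stays in memory here and is swapped out whenever the
# training task drops a newer pickle into MODEL_DIR. Instead of a real
# file-system watcher, this sketch simply checks for a newer file on each
# request, which is cheap compared to a prediction.
import glob
import os
import pickle
import threading

from flask import Flask, jsonify, request

MODEL_DIR = "models"   # directory the training task writes into (assumption)
_model = None
_model_path = None
_lock = threading.Lock()

def _latest_model_path():
    candidates = glob.glob(os.path.join(MODEL_DIR, "*.pkl"))
    return max(candidates, key=os.path.getmtime) if candidates else None

def _refresh_model():
    """Load the newest model file if it differs from the one in memory."""
    global _model, _model_path
    path = _latest_model_path()
    if path and path != _model_path:
        with open(path, "rb") as f:
            new_model = pickle.load(f)  # slow, but only when the model changed
        with _lock:
            _model, _model_path = new_model, path

app = Flask(__name__)

@app.route("/predict", methods=["POST"])
def predict():
    _refresh_model()
    features = request.get_json()["features"]  # e.g. [[0.1, 2.3, ...]]
    with _lock:
        prediction = _model.predict(features)
    return jsonify({"prediction": [float(p) for p in prediction]})

if __name__ == "__main__":
    _refresh_model()
    app.run(port=5000)
```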

Thanks, I thought about having the model served separately, too.

Still, I suppose the easier way of keeping the model inside the Dash app directly is probably impossible?

It depends on the complexity of the model. I think in most cases the prediction step is fast (seconds), and you could thus do it in a callback without any issues. However, in your case the models are huge, so you need to preload them. The safe bet is a separate serving process, but preloading into a Redis cache might be sufficient.
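
For the separate-process route, the Dash callback itself stays lightweight: it only ships the user's inputs as a small JSON payload to the serving process and displays the result. The URL and payload shape below match the hypothetical Flask server sketched earlier, not a confirmed API:

```python
# Dash side of the separate-serving-process setup: the 1-2 GB model never
# lives in this process; the callback just makes an HTTP request.
import requests
from dash import Dash, dcc, html, Input, Output

PREDICT_URL = "http://localhost:5000/predict"  # hypothetical serving endpoint

app = Dash(__name__)
app.layout = html.Div([
    dcc.Slider(id="feature-a", min=0, max=10, value=5),
    html.Div(id="prediction"),
])

@app.callback(Output("prediction", "children"), Input("feature-a", "value"))
def predict(feature_a):
    # Only a tiny JSON payload crosses the process boundary, so the callback
    # stays fast even though the model itself is 1-2 GB.
    resp = requests.post(PREDICT_URL, json={"features": [[feature_a]]}, timeout=10)
    return f"Prediction: {resp.json()['prediction'][0]}"

if __name__ == "__main__":
    app.run_server(debug=True)
```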
