Loading custom classes and passing between callbacks

Hi everyone,

I’ve been trying to develop a web app that can load a custom class that we’ve developed for some predictive modeling work. If this is the wrong place, has already been answered, or needs another tag, please let me know.

The main flow of the app would be to select the model version that you want to explore. The first tab would be to select the model from a list, the second tab would show some metrics about the model that have already been calculated in the class, and then another tab to explore the data within the model itself.

The model is an sklearn model. Within the class is the saved fitted model, the data, and the figures all pregenerated during the fitting process. We want to avoid re-doing all the math and fitting in the dash app for speed.

Below shows the approximate model structure.

class model(ClassifierMixin, BaseEstimator):

    def __init__(self):
        return None
    
    def getdata(self):
        # Get the data sets
        self.X_ = get_x_data
        self.y_ = get_y_data
        return self

    def fit(self):
        # Fit the model.
        # X is features
        # y is target
        self.fitted_model = fit(self.X_, self.y_)
        return self
    
    def fig1(self, column):
        # Plots the data associated with the column
        self.fig1 = plot_fig_one(self.X_, column)
        return self

The second part would be creating the dash app that can display the different versions of the model. The main aspect here is that I want to get the model loaded in one callback, and then passed to all the others.

import base64
import io
import os
import sys
from pickle import load

import dash
import dash_bootstrap_components as dbc
import matplotlib.pyplot as plt
import plotly.express as px
from dash import Dash, Input, Output, callback, dash_table, dcc, html

external_stylesheets = [dbc.themes.BOOTSTRAP]
app = dash.Dash(__name__, external_stylesheets=external_stylesheets)

model_select_tab = dbc.Col([
    dcc.Dropdown(
        options = list(os.listdir('models/')), 
        value = f"{model_load_date}_xgb_74", 
        id='model_select_dropdown'
    ),
    dcc.Store(id = 'model'), # Not sure how to get the store to work for this
    dcc.Graph(id = 'acc_fig')
])

@callback(
    Output('model', 'data'), # Not sure how to get the store to work for this
    Output('acc_fig', 'figure'),
    Input('model_select_dropdown', 'value')
)
def update_data_exp_tab(value):
    with open(f"models/{value}", "rb") as f:
        model = load(f)
    
    accuracy_df = model.accuracy_df
    acc_fig = px.bar(
        accuracy_df,
        x="model",
        color="metric",
        y="mean_val",
        error_y="std_val",
        barmode="group",
        labels={
            "model": "Model",
            "metric": "Evaluation Metric",
            "mean_val": "Mean value across K-Folds",
        },
    ).update_yaxes(range=[0, 1.02], gridcolor="Black")
    print('done with fig')
    
    return acc_fig

data_exp_tab = dbc.Col([
    dcc.Dropdown(
        options = list(model.columns), 
        value = list(model.columns)[0], 
        id='data_exp_dropdown'
    ),
    dbc.Row([
        dbc.Col([
            html.Img(id='data_exp_plot')
        ], style={'textAlign': 'center'}),
    ])
])


@callback(
    Output(component_id='data_exp_plot', component_property='src'),
    Input('data_exp_dropdown', 'value'),
    Input('model', 'data'), # Not sure how to pass the unique class in.
)
def update_data_exp_plot(value, model):
    # I could load the model fresh again here, but I want to avoid that overhead.
    fig = model.fig1 # Pyplot figure generated in the training process previously.

    buf = io.BytesIO()  # in-memory buffer that receives the PNG bytes
    fig.savefig(buf, format="png", bbox_inches='tight')
    fig_data = base64.b64encode(buf.getbuffer()).decode("ascii")
    fig_bar_matplotlib = f'data:image/png;base64,{fig_data}'

    return fig_bar_matplotlib


app.layout = html.Div([
    html.H1('Dash Tabs component demo'),
    dcc.Tabs(
        id="tabs-example-graph", 
        value='cnt_sum',  # must match the value of one of the dcc.Tab children
        children=[
            dcc.Tab(
                label='Model Select', 
                value='cnt_sum',
                children = model_select_tab),
            dcc.Tab(
                label='Data Exploration', 
                value='data_exp',
                children = data_exp_tab),
        ])
])
    
if __name__ == '__main__':
    app.run(debug=True)

There are a few approaches you could try, but it’s hard to know which fits best without seeing all the details.

To store something in the ‘data’ property of a dcc.Store component, it must be JSON-serializable; you cannot pass arbitrary objects such as a pd.DataFrame or a custom class instance into it. If you are using dataframes or other Python objects, you will have to use their respective serialization / deserialization functions (or write your own). dcc.Store also has a relatively small limit on how much data you should keep in it, typically suggested to be only a few MB.
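As a rough illustration of the JSON point, here is a minimal sketch using only the standard library. FittedModel, accuracy_rows, is_json_serializable and the numbers are all made up for the example and stand in for your own class and data; the idea is just that the object itself is rejected while plain lists, dicts, strings and numbers pulled out of it survive the round trip.

```python
import json


class FittedModel:
    """Hypothetical stand-in for the custom model class."""

    def __init__(self):
        # Plain-Python version of an accuracy table (made-up numbers).
        self.accuracy_rows = [
            {"model": "xgb", "metric": "f1", "mean_val": 0.5, "std_val": 0.1},
        ]


def is_json_serializable(obj) -> bool:
    """Return True if json.dumps accepts obj, i.e. it could go in a dcc.Store."""
    try:
        json.dumps(obj)
        return True
    except (TypeError, ValueError):
        return False


model = FittedModel()

# The object itself is rejected, but its plain-data pieces are fine.
whole_object_ok = is_json_serializable(model)
rows_ok = is_json_serializable(model.accuracy_rows)

# Round trip: roughly what one callback would return and another would receive.
payload = json.dumps({"model_name": "demo_model", "accuracy": model.accuracy_rows})
restored = json.loads(payload)
```

If the accuracy table is a pandas DataFrame, something like accuracy_df.to_dict("records") should give you a list of dicts in this shape, provided the cell values are themselves plain numbers and strings.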

Another option to look at is caching or an in-memory database. A simple implementation could look as follows, assuming you have a small number of models that can each be referred to by, say, a string.

from functools import lru_cache

@lru_cache(maxsize=5) # or your number of models
def load_model(model_name: str):
    return load_the_model(model_name)

This essentially creates a wrapper around your model-loading call, so that if you use load_model in a callback, the first time any particular model loads it will take time, but subsequent calls will retrieve it from memory and be much quicker.
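To make the caching behaviour concrete, here is a self-contained sketch that unpickles a dummy object and checks that the second call is served from memory. MODEL_DIR, the "demo_model" file and the dict standing in for a model are assumptions made up for the example; in your app the folder would be models/ and the pickle would hold your own class.

```python
import pickle
import tempfile
from functools import lru_cache
from pathlib import Path

# Stand-in for the "models/" folder (hypothetical, so the sketch runs anywhere).
MODEL_DIR = Path(tempfile.mkdtemp())

# Write a dummy "model" to disk; a real app would already have its pickles.
with open(MODEL_DIR / "demo_model", "wb") as f:
    pickle.dump({"columns": ["a", "b"]}, f)


@lru_cache(maxsize=5)  # or your number of models
def load_model(model_name: str):
    """Unpickle a model once; repeat calls with the same name reuse the object."""
    with open(MODEL_DIR / model_name, "rb") as f:
        return pickle.load(f)


first = load_model("demo_model")   # read from disk
second = load_model("demo_model")  # returned from the cache

stats = load_model.cache_info()    # hit/miss counters kept by lru_cache
```

With this pattern the dcc.Store only needs to hold the selected model name (a plain string), and each callback calls load_model(name) itself. Keep in mind the cache lives inside a single Python process, so if you deploy with several workers each one loads its own copy the first time.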

Hopefully that gets you on track, cheers