
Yet another question about saving/sharing state in Dash

I have an application in which I’d like the user to be able to parametrize and run a simulation by doing the following:

1 - input some parameters in boxes/sliders/figures
2 - press run so that a simulation would execute
3 - visualize the output
4 - be able to pause the simulation and step through it

This of course requires saving some state. I’ve done my best to read through the community posts and the Dash docs on using global variables and caching to share data between callbacks and to persist data.

I think I have something that works, but it feels a bit “hacky” and it sometimes behaves oddly. I’ve included a simple example below that illustrates what stepping through would look like. In a nutshell, I use the iteration number as the cache key for the state of the simulation. To recover the state at step n, the function recursively calls itself with n-1 (which should already be cached, so the recursion never goes deeper than one level), updates that state, and returns the new state, which in turn gets cached.

Note that while in this example the state of the sim is simple enough to be “json dumped” into a Div, the real application involves an object with more complicated members/state (numpy vectors, matrices, RandomState, linked lists, BallTrees, etc.), and I’d like to avoid serializing/deserializing that object.
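If the app only ever runs in a single server process, one way to avoid serialization entirely is to keep live objects in a module-level dictionary keyed by session id, guarded by a lock because Flask's development server can handle requests in threads. This is a sketch under that single-process assumption: with several worker processes (e.g. gunicorn with multiple workers) each process would have its own dictionary, and nothing here evicts old sessions. The `SimStore` class and its method names are hypothetical.

```python
import threading


class SimStore:
    """Hypothetical in-process store mapping session ids to live (unserialized) objects."""

    def __init__(self, factory):
        self._factory = factory  # called to create a fresh object for a new session
        self._objects = {}
        self._lock = threading.Lock()

    def get(self, session_id):
        """Return the object for session_id, creating it on first use."""
        with self._lock:
            if session_id not in self._objects:
                self._objects[session_id] = self._factory()
            return self._objects[session_id]

    def reset(self, session_id):
        """Forget a session's object, e.g. when the user starts a new run."""
        with self._lock:
            self._objects.pop(session_id, None)
```

In a callback this would look like `obj = store.get(session_id)` followed by `obj.update(std_dev)`. The lock only protects the dictionary; if two callbacks can step the same object at once, that object would need its own lock as well.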

I have the following questions if any kind hearts out there would like to help answer:

  1. Is this the best way to implement what I’m trying to achieve (have state in Dash that is
    more complex than what could be saved in a Div through json), or am I overcomplicating things?

  2. If you run the example, focus the “Step” button, then press and hold Enter. The “Button has been clicked” and “c_SimObj has been run” counts track each other, but at times the caching seems to fail and the recursion restarts from n=0. Any hints as to what could be causing that? (It’s easy to spot when it happens if you set std_dev != 1.)

  3. What if I need to pass parameters to sim.update() (like std_dev in this example)? Right now
    I’m relying on the cached function being an inner function of the update function itself.

  4. Are there any links that discuss the performance of disk vs. Redis caching (memory, speed, etc.)? I’m trying to figure out whether it’s worth introducing Redis as a dependency.
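On question 3: because the simulation is seeded, its state after n steps is fully determined by the seed and the sequence of parameters applied at each step. One option, sketched here and not something Dash or Flask-Caching prescribes, is to make that parameter history part of the cache key instead of capturing std_dev in a closure. The history is a tuple so `functools.lru_cache` can hash it; `state_for` is a hypothetical name, and the `random.Random`-based step stands in for `sim.update()`.

```python
import functools
import random


@functools.lru_cache(maxsize=256)
def state_for(seed, std_devs):
    """Return (y_hist, rng_state) after one step per entry of std_devs.

    std_devs is a tuple such as (1.0, 1.0, 2.5) holding the std_dev used at
    each step, so changing a parameter simply produces a different cache key.
    Results are immutable tuples, so callers cannot corrupt cached entries.
    """
    if not std_devs:
        return (0.0,), random.Random(seed).getstate()
    # Reuse the cached state for the prefix (every step except the last one).
    prev_hist, prev_rng_state = state_for(seed, std_devs[:-1])
    rng = random.Random()
    rng.setstate(prev_rng_state)
    y = prev_hist[-1] + rng.gauss(0, 1) * std_devs[-1]
    return prev_hist + (y,), rng.getstate()
```

The trade-offs: the key grows with the number of steps (hashing it is O(n)), and on a cold cache the recursion depth equals the number of steps, which runs into Python's default recursion limit of roughly 1000. In a Dash app the std_dev history would itself have to be stored somewhere, for instance in a hidden div alongside the step count.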

Thanks,


import dash
from dash.dependencies import Input, Output, State
import dash_core_components as dcc
import dash_html_components as html

from flask_caching import Cache

import uuid
import numpy as np

import logging
logging.basicConfig(format='%(levelname)s|%(asctime)s|%(name)s|\t %(message)s',level=logging.WARNING)

log = logging.getLogger('test')
log.setLevel(logging.DEBUG)


class sim(object):
    '''
    Sample class to emulate what a sim class might look like
    with an initialization, some state, and an update function
    '''
    def __init__(self, seed = 0):
        self.n = 0
        self.y = 0
        self.rngState = np.random.RandomState(seed)
        self.nHist = [self.n]
        self.yHist = [self.y]

    def update(self, std_dev = 1):
        self.n += 1
        self.y += self.rngState.randn() * std_dev

        self.nHist.append(self.n)
        self.yHist.append(self.y)


app = dash.Dash()
app.css.config.serve_locally = True
app.scripts.config.serve_locally = True


cache = Cache(app.server, config={
    # 'CACHE_TYPE': 'redis', #Is this better performance wise?

    'CACHE_TYPE': 'filesystem',
    'CACHE_DIR': 'cache-directory',
    
    'CACHE_THRESHOLD': 10
})


#Yes, not kosher, but just for debugging :)
global nRuns
nRuns = 0


#the function itself is not cached
#as we want to allow for std_dev to be changed
def c_updateSimObj(session_id, n, std_dev):

    @cache.memoize(timeout=None)
    def _c_simObj(session_id, n):

        #Looks like using a global variable inside of
        #cached function does weird things ... leaving it out for now
        global nRuns
        nRuns += 1

        log.debug('-' * 10 + "s = {}, n = {}".format(session_id, n) + '-'*10)


        if n == 0:
            log.debug("n == 0: initialize and return object")
            s0 = sim()
            return s0

        #recursively ask for previous time
        #hopefully the previous one will still be cached and this 
        #won't require all n-1 calls to be made!!!!
        log.debug("n == %i: recursive calling with n-1=%i"%(n, n-1))
        st = _c_simObj(session_id, n-1)

        log.debug("Recovered simObj with n = %i"%st.n)

        #step the sim. Note use of parameter that is not part of the 
        #caching function
        st.update(std_dev)

        log.debug("Updated simObj to n = %i - Total Executions : %i"%(st.n, nRuns))

        #what are the implications of letting the cache grow?
        #will the framework handle it gracefully, or should we 
        #clear the cache to avoid things growing? something like the
        #following seems to work (unclear what performance hit is?)
        cache.delete_memoized(_c_simObj, session_id, n-1)

        #Also experimented with clearing periodically rather than at every call:
        # if n%10 == 0: for m in range(n-10,n): cache.delete_memoized(_c_simObj, session_id, m)

        return st

    #Call _c_simObj without std_dev which might change
    return _c_simObj(session_id, n)

#Layout in a function so that a new, unique session id
#is generated every time the page loads
def serve_layout():
    session_id = str(uuid.uuid4())

    return html.Div([
        html.H1(session_id, id='session-id'),
        html.Div(style={'width':300}, children = [
                 html.P('Standard Deviation'),
                 dcc.Slider(id='std_dev',value=1,min=0,max=5,step=0.1, marks={i: '{}'.format(i) for i in range(6)}),
                 html.P('')]),
        html.Button('Step', id='button', style={'height':50, 'width':100}),
        dcc.Graph(id='graph'),
        html.Div(id='textOutput')
    ])

app.layout = serve_layout


@app.callback(Output('textOutput', 'children'),
              [Input('button', 'n_clicks'),
               Input('session-id', 'children')])
def display_value(value, session_id):
    global nRuns

    return dcc.Markdown('''
Button has been clicked {} times

c_SimObj has been run {} times
        '''.format(value, nRuns))


#This is not Dash Kosher, but
#using as a way to compare to cached implementation
global globalSimObj
globalSimObj = sim()


@app.callback(Output('graph', 'figure'),
              [Input('button', 'n_clicks'),
               Input('session-id', 'children')],
              [State('std_dev','value')])
def display_graph(n_iter, session_id, std_dev):
   
    if n_iter is None: n_iter = 0
    if std_dev is None: std_dev = 1


    ####### The not OK way #######
    global globalSimObj
    #this is just an easy way to catch up to the number of
    #clicks in case session is restarted
    while(globalSimObj.n < n_iter): 
        globalSimObj.update(std_dev = std_dev)

    ####### The cached attempt #######

    #Calling the cached function which will return the 
    #sim object. Either by creating one, or getting
    #a cached copy and updating it. 
    cachedSimObj = c_updateSimObj(session_id, n_iter, std_dev) 

    ####### Plotting to visually check that both methods match ##### 
    data = [dict(type='scatter', mode='markers+lines', name='global',
                 x = globalSimObj.nHist, y = globalSimObj.yHist),
            dict(type='scatter', mode='markers+lines', name='cached',
                #Offsetting y trace to be able to easily compare
                x = cachedSimObj.nHist, y = np.array(cachedSimObj.yHist) + .1)]
    #'type' and 'mode' are trace attributes, so they go in data, not in layout
    layout = dict()
    fig = dict(data=data, layout=layout)

    return fig


if __name__ == '__main__':
    app.run_server(debug=True)

After some digging, I eventually stumbled on https://flask-sessionstore.readthedocs.io/en/latest/

It seems to do what I’d like (although with the filesystem backend things are slow, and stale data sometimes seems to be served).

I’m surprised this isn’t discussed as one of the ways to share data amongst callbacks or to store state data for sessions.

I’m new to Flask/Dash, but it’d be great if someone with a more in-depth understanding of the two frameworks could discuss using flask.session (server- or client-side).

Here is a simpler implementation (I think) which relies on Flask sessions:

A couple of caveats:

  • It looks like the latest Dash release broke setting app.layout to a function, so this code might not work in some versions
  • Code works as expected when deployed on pythonanywhere, but when run locally, updating the stored object seems to not properly propagate to the session dictionary.
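One possible explanation for the second caveat, not verified against flask-sessionstore specifically: the Flask documentation notes that changes to mutable structures stored in the session are not picked up automatically, and that you have to set `session.modified = True` yourself. `s_simObj.update()` mutates an object already in the session, so the session may never be flagged for saving. The sketch below shows the effect with Flask's default cookie-based session and a plain dict (a custom object would not fit in a cookie session); the `/step` route is hypothetical.

```python
from flask import Flask, session

app = Flask(__name__)
app.secret_key = "dev-only-secret"  # placeholder; use a real secret in production


@app.route("/step")
def step():
    state = session.setdefault("state", {"n": 0})
    state["n"] += 1          # in-place mutation: Flask does not notice this
    session.modified = True  # so mark the session as changed explicitly
    return str(state["n"])
```

Without the `session.modified` line the counter should stall: the second and third requests would both return 2. Another possibility worth ruling out is that both callbacks fire on the same click, each loads the session, and whichever one saves last overwrites the other's changes.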

Since this uses Flask-Sessionstore, all the data lives on the server side, so nothing has to be sent to or from the client (unlike with hidden divs). The other benefit is that the data does not need to be JSON serializable (it does still need to be picklable).
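To make the picklability point concrete: as far as I can tell, the filesystem backend stores session data with pickle, and NumPy's `RandomState` pickles together with its internal state, so a restored object continues the same random stream. A quick check, using a plain dict as a stand-in for the sim object (an instance of a module-level class pickles the same way); `step` is a hypothetical helper mirroring `sim.update()`.

```python
import pickle

import numpy as np


def step(state, std_dev=1.0):
    """Advance a dict-based stand-in for the sim object by one step."""
    state["n"] += 1
    state["y"] += state["rng"].randn() * std_dev
    state["yHist"].append(state["y"])


original = {"n": 0, "y": 0.0, "rng": np.random.RandomState(0), "yHist": [0.0]}
for _ in range(5):
    step(original)

# Round-trip through pickle, roughly what a file-based server-side store does.
restored = pickle.loads(pickle.dumps(original))

# The restored RNG resumes where the original left off, so both draw the same value.
step(original, std_dev=2.0)
step(restored, std_dev=2.0)
```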

Hopefully this helps someone out there, but I’d still love to hear from one of the Dash experts what they think about using this as an alternative to caching or to the hidden divs suggested in the tutorials.

Cheers,


import dash
from dash.dependencies import Input, Output, State
import dash_core_components as dcc
import dash_html_components as html

import time
import uuid
import numpy as np

import logging
logging.basicConfig(format='%(levelname)s|%(asctime)s|%(name)s|\t %(message)s',level=logging.WARNING)

log = logging.getLogger('test')
log.setLevel(logging.DEBUG)


#Using Flask sessions to keep track of per-session data
from flask import session as f_session
from flask_sessionstore import Session


class sim(object):
    '''
    Sample class to emulate what a sim class might look like
    with an initialization, some state, and an update function
    '''
    def __init__(self, seed = 0, session_id = None):
        self.n = 0
        self.y = 0
        self.rngState = np.random.RandomState(seed)
        self.nHist = [self.n]
        self.yHist = [self.y]

        self.session_id = session_id
        self.initTime = time.ctime()

    def update(self, std_dev = 1):
        self.n += 1
        self.y += self.rngState.randn() * std_dev

        self.nHist.append(self.n)
        self.yHist.append(self.y)


app = dash.Dash()
app.css.config.serve_locally = True
app.scripts.config.serve_locally = True

#Set up server side session
SESSION_TYPE='filesystem'
SESSION_FILE_DIR = 'cache-directory'
SESSION_FILE_THRESHOLD = 100
app.server.config.from_object(__name__)
app.server.secret_key = 'thisisonlyasecretefromtheuser!'.encode('utf8')
Session(app.server)


#Layout in a function so that a new, unique session id
#is generated every time the page loads
def serve_layout():
    session_id = str(uuid.uuid4())

    return html.Div([
        html.H1(session_id, id='session-id'),
        html.Div(style={'width':300}, children = [
                 html.P('Standard Deviation'),
                 dcc.Slider(id='std_dev',value=1,min=0,max=5,step=0.1, marks={i: '{}'.format(i) for i in range(6)}),
                 html.P('')]),
        html.Button('Step', id='button', style={'height':50, 'width':100}),
        dcc.Graph(id='graph'),
        html.Div(id='textOutput')
    ])

app.layout = serve_layout


@app.callback(Output('textOutput', 'children'),
              [Input('button', 'n_clicks'),
               Input('session-id', 'children')])
def display_value(value, session_id):
    global globalSimObj

    g_init  = '[not yet init]' if globalSimObj is None else globalSimObj.initTime
    g_id    = '[not yet init]' if globalSimObj is None else str(globalSimObj.session_id)[0:5]

    skey = 'simObj_' + session_id[0:5]
    s_simObj = f_session.get(skey, None)
    s_init   = '[not yet init]' if s_simObj is None else s_simObj.initTime
    s_id     = '[not yet init]' if s_simObj is None else str(s_simObj.session_id)[0:5]

    return dcc.Markdown('''
Button has been clicked {} times

globalSimObj [{}] was initialized at {}

sessionSimObj [{}] was initialized at {}
        '''.format(value, 
                   g_id, g_init, s_id, s_init))


#This is not Dash Kosher, but
#using as a way to compare to session implementation
global globalSimObj
globalSimObj = None

@app.callback(Output('graph', 'figure'),
              [Input('button', 'n_clicks'),
               Input('session-id', 'children')],
              [State('std_dev','value')])
def display_graph(n_iter, session_id, std_dev):
   
    if n_iter is None: n_iter = 0
    if std_dev is None: std_dev = 1


    ####### The not OK way #######
    global globalSimObj
    if globalSimObj is None:
        globalSimObj = sim(session_id = session_id)


    #retrieve simObj from session, using session id as part of the key
    skey = 'simObj_' + session_id[0:5]
    if skey not in f_session:
        f_session[skey] = sim(session_id=session_id)

    s_simObj = f_session[skey]

    log.debug('g_simObj.n = %i' % globalSimObj.n)
    log.debug('s_simObj.n = %i' % s_simObj.n)


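    s_simObj.update(std_dev = std_dev)
    #Mutating an object already stored in the session does not mark the
    #session as modified; flag it explicitly in case the backend only saves
    #modified sessions
    f_session.modified = True
    globalSimObj.update(std_dev = std_dev)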


    ####### Plotting to visually check that both methods match ##### 
    data = [dict(type='scatter', mode='markers+lines', name='global',
                 x = globalSimObj.nHist, y = globalSimObj.yHist),
            dict(type='scatter', mode='markers+lines', name='session',
                #Offsetting y trace to be able to easily compare
                x = s_simObj.nHist, y = np.array(s_simObj.yHist) + .1)]
    #'type' and 'mode' are trace attributes, so they go in data, not in layout
    layout = dict()
    fig = dict(data=data, layout=layout)

    return fig

if __name__ == '__main__':
    app.run_server(debug=True)

Not sure whether this would help you out as much as you want, but you might want to check out the new dcc.Store component, which has three storage modes (memory, local and session) that might offer a better compromise between speed and capacity.

Thanks @Mike3
I’ve managed to use it in one of my projects. Another one relies on data that is difficult to JSON serialize, so I’ll see whether it’s worthwhile there.