✊🏿 Black Lives Matter. Please consider donating to Black Girls Code today.
⚡️ Concerned about the grid? Kyle Baranko teaches how to predict peak loads using XGBoost. Register for the August webinar!

How can I make a Plotly animation built with dcc.Graph and dcc.Slider smoother?

Hi, I created a Scattermapbox animation with dcc.Interval, dcc.Graph and dcc.Slider instead of using Plotly's built-in animation (animation_frame and animation_group). I did this because I want to retrieve the current slider value at every interval of the animation and whenever the user pauses it. I could not find a way to read the current slider value from a Plotly animation.

The user clicks the Play button to enable the dcc.Interval, and the slider then advances its value at each interval. Whenever the slider value changes, the graph is updated accordingly.

In the following code, rendering each frame takes a lot of work: the ‘update_figure’ callback does heavy processing and returns a large amount of data to the dcc.Graph every time the slider value changes.

Is there any way to load all of the data into the graph at the start and just switch frames when the slider value changes?

import dash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output, State
from dash.exceptions import PreventUpdate

import numpy as np
import pandas as pd
import plotly.graph_objs as go

# The Dash application instance. The layout and every @app.callback below
# are registered on this object.
app = dash.Dash(__name__)


def set_slider_calendar(dataframe, step=7):
    """Build slider mark labels that show only every ``step``-th date.

    Args:
        dataframe: Sequence of date labels. In this file it is the NumPy
            array from ``data['Date'].unique()``, but any iterable that
            yields the labels in slider order works (lists included).
        step: Label every ``step``-th entry, starting at index 0. Defaults
            to 7, the spacing this helper originally hard-coded.

    Returns:
        A list the same length as ``dataframe``. Entry ``i`` is the ``i``-th
        label when ``i % step == 0`` and ``''`` otherwise, so the slider is
        not cluttered with a label on every tick.

    Raises:
        ValueError: If ``step`` is not a positive integer.
    """
    if step < 1:
        raise ValueError(f'step must be a positive integer, got {step!r}')
    return [label if i % step == 0 else '' for i, label in enumerate(dataframe)]


# Placeholder; replace with a real Mapbox access token.
access_token = 'mapbox_token'
data_url = 'https://shahinrostami.com/datasets/time-series-19-covid-combined.csv'
data = pd.read_csv(data_url)
# Rows with no province/state fall back to the country name, so every row
# has a label to group and display by.
data['Province/State'] = data['Province/State'].fillna(data['Country/Region'])
data['Active'] = data['Confirmed'] - data['Recovered'] - data['Deaths']
# Marker size is the active-case count capped at 50. The vectorized clip
# replaces a per-row .loc loop (slow Python-level indexing on every row).
# Rows whose Active is NaN are still removed by dropna() below, as before.
# The lower bound of 0 guards against negative Active values (possible if the
# source counts are inconsistent), since Plotly rejects negative marker sizes.
data['Size'] = data['Active'].clip(lower=0, upper=50)
data = data.dropna()
# Distinct dates in first-appearance order; the slider index selects into this.
df_date = data['Date'].unique()

# The animation is driven server-side. The Play button enables the Interval,
# each tick moves the Slider, and the Slider value redraws the Graph.
app.layout = html.Div([
    dcc.Interval(
        id='interval',
        interval=2000,  # tick period in milliseconds (dcc.Interval units)
        n_intervals=0,
        # Fire at most once per remaining date after index 0.
        max_intervals=df_date.shape[0] - 1,
        disabled=True,  # paused until the Play button is clicked
    ),
    dcc.Graph(id='my_graph'),
    dcc.Slider(
        id='slider',
        min=0,
        max=df_date.shape[0] - 1,
        value=0,
        # Mark keys are the date indices as strings. Most labels are blank
        # (see set_slider_calendar) to keep the axis readable.
        marks={str(i): str(label)
               for i, label in enumerate(set_slider_calendar(df_date))},
        dots=True,
    ),
    # Dash/React style dicts expect camelCase CSS property names.
    html.Div(id='label', style={'marginTop': 20}),
    html.Button('Play', id='my_btn'),
])


@app.callback([Output('interval', 'disabled'), Output('my_btn', 'children')],
              [Input('my_btn', 'n_clicks')],
              [State('interval', 'disabled')])
def toggle_interval(click, value):
    """Toggle the animation between playing and paused on each button click.

    Args:
        click: ``n_clicks`` of the Play/Pause button. It is falsy (``None``)
            before the first click.
        value: Current ``disabled`` state of the interval.

    Returns:
        ``(new_disabled, button_label)``. The label reads ``'Pause'`` while
        the interval runs and ``'Play'`` while it is disabled.

    Raises:
        PreventUpdate: On the initial call, before the button has been clicked.
    """
    # Originally named ``display_value``, which the label callback in this
    # file also used, silently rebinding the name. Dash registers callbacks
    # at decoration time, so only the module-level name changes.
    if not click:
        raise PreventUpdate
    new_value = not value
    return new_value, 'Play' if new_value else 'Pause'


@app.callback(
    Output('label', 'children'),
    [Input('slider', 'value')],
)
def display_value(value):
    """Return the caption text for the date at slider index ``value``."""
    current_date = df_date[value]
    return f'Selected Calendar: {current_date} '


@app.callback(
    Output('slider', 'value'),
    [Input('interval', 'n_intervals')],
)
def update_slider(num):
    """Use the interval's tick count directly as the new slider index.

    This is what advances the animation: each interval tick moves the
    slider, which in turn triggers the label and figure callbacks.
    """
    slider_position = num
    return slider_position


@app.callback(
    Output('my_graph', 'figure'),
    [Input('slider', 'value')])
def update_figure(selected_year):
    """Build the map figure for the date at slider index ``selected_year``.

    Args:
        selected_year: Index into ``df_date`` chosen by the slider. Despite
            the name, it selects a date, not a year.

    Returns:
        A figure dict with one ``Scattermapbox`` trace per Province/State
        present on that date. Marker sizes come from the ``Size`` column.
    """
    filtered_df = data[data['Date'] == df_date[selected_year]]
    # groupby(sort=False) yields the same groups, in the same first-appearance
    # order, as looping over unique() and re-filtering. It scans the frame
    # once instead of once per state.
    traces = [
        go.Scattermapbox(
            lat=df_by_state['Lat'],
            lon=df_by_state['Long'],
            mode='markers',
            marker=go.scattermapbox.Marker(size=df_by_state['Size']),
            text=df_by_state['Province/State'],
            name=state,  # hover/legend show the state instead of "trace N"
        )
        for state, df_by_state in filtered_df.groupby('Province/State', sort=False)
    ]

    return {
        'data': traces,
        'layout': dict(
            autosize=True,
            hovermode='closest',
            # A constant uirevision keeps the user's pan/zoom between frames.
            # Without it, every slider step snaps the map back to the center
            # and zoom below.
            uirevision='map',
            mapbox=dict(
                accesstoken=access_token,
                bearing=0,
                center=dict(
                    lat=38.92,
                    lon=-77.07
                ),
                pitch=0,
                zoom=1
            ),
            # NOTE(review): Plotly layout transitions target cartesian traces;
            # they likely have no visible effect on Scattermapbox. Confirm.
            transition={
                'duration': 500,
                'easing': 'cubic-in-out'
            }
        )
    }


if __name__ == '__main__':
    # Run the app on Dash's built-in server. debug=True enables Dash's
    # development tooling, so use it for local development only.
    app.run_server(debug=True)