
Slow dash app with complex geometry

Hi!

I’m working on a project whose main goal is to show data from a CSV file on a choropleth map. In the example below, the map shows the cities of 4 Central American countries, colored by the country each city belongs to. As you can see, the app is similar to the US Opioid Epidemic example.

The problem is that the geometry is very complex: when I change the column shown on the map, the app takes more than 8 seconds to refresh the view and plot the new data. I tried simplifying the geometry with mapshaper, but the delay is still the same.

I’m running the app in a container with Gunicorn and I tried to optimize it but the app performance didn’t get better. I don’t know if there’s something else which could help in this situation or if it’s an issue of the library.

The code is below. Thanks for the help.

import os
import pathlib
import re
import plotly.graph_objs as graph_objs
import numpy as np

import dash
import dash_core_components as dcc
import dash_html_components as html
import pandas as pd
from dash.dependencies import Input, Output, State
import cufflinks as cf
from sklearn.preprocessing import LabelEncoder
import plotly.express as px
import plotly.graph_objects as go

MAPBOX_TOKEN = 'my_box_token'
mapbox_style = "mapbox://styles/plotlymapbox/cjvprkf3t1kns1cqjxuxmwixz"
DEFAULT_COLORSCALE = [
    "#f2fffb",
    "#bbffeb",
    "#98ffe0",
    "#79ffd6",
    "#6df0c8",
    "#69e7c0",
    "#59dab2",
    "#45d0a5",
    "#31c194",
    "#2bb489",
    "#25a27b",
    "#1e906d",
    "#188463",
    "#157658",
    "#11684d",
    "#10523e",
]

map_type = [{"label":"Basico", "value":"basic"}, 
            {"label":"Calles", "value":"streets"}, 
            {"label":"Outdoors", "value":"outdoors"}, 
            {"label":"Claro", "value":"light"}, 
            {"label":"Oscuro", "value":"dark"}, 
            {"label":"Satélite", "value":"satellite"},
            {"label":"Satélite street", "value":"satellite-streets"},
            {"label":"Open Street Map", "value":"open-street-map"}
            ]

app = dash.Dash(
    __name__,
    meta_tags=[
        {"name": "viewport", "content": "width=device-width, initial-scale=1.0"}
    ]
)
server = app.server 

original_df = pd.read_csv("./data/csv/200923_PT1_A07_P04_BBDD_agregada.csv")
      
import geojson
with open("./data/geojson/0.005_reduced_data.json") as f:
    geojson_data = geojson.load(f)
    
non_selected_cols = ['ID', 'ID_1', 'ID_12',  'CDGO_MUN', 'MUN', 'geometry']
select= []
for var in list(original_df.columns):
    if (var not in non_selected_cols):
        select.append({'label':var, 'value':var})

categorical_columns = ['PAIS', 'DPTO', 'MUN']
categorical_to_numeric = {'DPTO':'CDGO_DPTO', 'MUN':'CDGO_MUN'}
   
def get_centroids_from_polygons():
    from geojson_utils import centroid 
    lat_list = []
    lon_list = []
    hover_list = []
    for feat in geojson_data['features']:
        hover_text = "<b>Pais: "+feat['properties']['PAIS'] + "</b> <br> Dpto: " + feat['properties']['DPTO'] + "<br> Mun: " + feat['properties']['MUN']
        hover_list.append(hover_text)
        if (feat['geometry']['type'] == 'MultiPolygon'):
            multipolygons = feat['geometry']['coordinates']
            lat_sum = 0
            lon_sum = 0
            for pol in multipolygons:
                format_pol = {'coordinates':pol ,'type':'Polygon'}
                multicentroids = centroid(format_pol)
                lon_sum += multicentroids['coordinates'][0]
                lat_sum += multicentroids['coordinates'][1]
                
            lat_list.append(lat_sum/len(multipolygons))
            lon_list.append(lon_sum/len(multipolygons))

        else:
            centroids = centroid(feat['geometry'])
            lon_list.append(centroids['coordinates'][0])
            lat_list.append(centroids['coordinates'][1])
    
    centroids_df = pd.DataFrame()
    centroids_df['Latitude'] = lat_list
    centroids_df['Longitude'] = lon_list
    centroids_df['Hover'] = hover_list
    
    return centroids_df

def preprocess_data():
    preprocessed_df = pd.DataFrame()
    lb_make = LabelEncoder()
    for column in original_df:
        column_values = list(original_df[column].values)
        if (column in categorical_columns):
            encoding_values = lb_make.fit_transform(column_values)
            preprocessed_df[column] = encoding_values
        else:
            preprocessed_values = []
            for value in column_values:
                if (np.isnan(float(str(value).replace(',','.')))):
                    preprocessed_values.append("NAN")
                else:
                    preprocessed_v = float(str(value).replace(',','.'))
                    preprocessed_values.append(preprocessed_v)
                    
            preprocessed_df[column] = preprocessed_values
    
    return preprocessed_df

def find_intervals(preprocessed_df):
    bin_dict = {}
    interval_dict = {}
    numeric_columns = [column for column in list(original_df.columns) if column not in categorical_columns]
    numeric_df = preprocessed_df.filter(numeric_columns)
    
    for column in list(numeric_df.columns):
        values = list(numeric_df[column].values)
        values_without_nan = [value for value in values if value != 'NAN']
        
        n_unique_values = len(np.unique(values_without_nan))  
        n_intervals = len(DEFAULT_COLORSCALE) if n_unique_values >= len(DEFAULT_COLORSCALE) else n_unique_values
        first = int(min(values_without_nan))
        second = max(values_without_nan) / n_intervals
        interval_column = []
        bins = []
        for i in range(1, n_intervals+1):
            interval_column.append((first, int(second*i) + 1))
            bins.append(str(first)+"-"+str(int(second*i) + 1))
            first = int(second*i) + 1
        
        interval_dict[column] = interval_column
        bin_dict[column] = bins
    
    return {'bins':bin_dict, 'intervals':interval_dict}

def plot_map(selected_var, map_type, preprocessed_df, interval_dict, bin_dict):
    colors = []
    bins = []
    values_to_plot = {}
    if (selected_var in categorical_to_numeric):
        selected_var = categorical_to_numeric[selected_var]
        
    if (selected_var == 'PAIS'):
        country_names = list(set(list(original_df['PAIS'].values)))
        countries = list(preprocessed_df['PAIS'].values)
        unique_countries = np.unique(countries)
        sorted_countries = sorted(set(list(original_df['PAIS'].values)), key=list(original_df['PAIS'].values).index)
        country_colors = {}
        for i in range(0, len(unique_countries)):
            country_colors[unique_countries[i]] = DEFAULT_COLORSCALE[i*len(unique_countries)]
            bins.append(sorted_countries[i])

        id_values = list(preprocessed_df['ID'].values)
        for i in range(0, len(id_values)):
            values_to_plot[id_values[i]] = countries[i]
            colors.append(country_colors[countries[i]])
    else:
        id_values = list(preprocessed_df['ID'].values)
        column_values = list(preprocessed_df[selected_var].values)
        for i in range(0, len(id_values)):
            if (column_values[i] != 'NAN'):
                values_to_plot[id_values[i]] = column_values[i]
        
        interval_list = interval_dict[selected_var]
        bins = bin_dict[selected_var]
        for value in list(values_to_plot.values()):
            for interval in interval_list:
                if (value >= interval[0] and value <= interval[1]):
                    index = interval_list.index(interval)
                    colors.append(DEFAULT_COLORSCALE[index])

    data = [
        dict(
            lat=centroids_df["Latitude"],
            lon=centroids_df["Longitude"],
            text=centroids_df["Hover"],
            type="scattermapbox",
            hoverinfo="text",
            marker=dict(size=5, opacity=0),
            hoverlabel={'bgcolor':colors, 'font':{'size': 13, 'color':'black'}}
        )
    ]
    annotations = [
        dict(
            showarrow=False,
            align="right",
            text="<b></b>",
            font=dict(color=DEFAULT_COLORSCALE[int(len(DEFAULT_COLORSCALE)/2)]),
            bgcolor="#1f2630",
            x=0.95,
            y=0.95,
        )
    ]
    cm = dict(zip(bins, DEFAULT_COLORSCALE))
    
    for i, bin in enumerate(reversed(bins)):
        color = cm[bin]
        annotations.append(
            dict(
                arrowcolor=color,
                text=bin,
                x=0.95,
                y=0.85 - (i / 20),
                ax=-60,
                ay=0,
                arrowwidth=5,
                arrowhead=0,
                bgcolor="#1f2630",
                font=dict(color=DEFAULT_COLORSCALE[int(len(DEFAULT_COLORSCALE)/2)]),
            )
        )
    
    sources=[{"type": "FeatureCollection", 'features': [feat]} \
          for feat in geojson_data['features'] if float(feat['properties']['ID']) in values_to_plot]
        
    layout = dict(
        mapbox=dict(
            style=map_type,
            #style=mapbox_style,
            uirevision=True,
            accesstoken=MAPBOX_TOKEN,
            layers=[dict(sourcetype = 'geojson',
                          source =sources[k],
                          below="water",
                          type = 'fill',
                          color = colors[k],
                          opacity=0.8
                        ) for k in range(len(sources))],
            center=dict(
                lat=14.3333300,
                lon=-91.3833300
            ),
            zoom=4,
        ),
        shapes=[
            {
                "type": "rect",
                "xref": "paper",
                "yref": "paper",
                "x0": 0,
                "y0": 0,
                "x1": 1,
                "y1": 1,
                "line": {"width": 1, "color": "#B0BEC5"},
            }
        ],
        hovermode="closest",
        margin=dict(r=0, l=0, t=0, b=0),
        annotations=annotations,
    )
    
    figure = {"data": data, "layout": layout}
    
    return figure

def plot_pie_chart(chart_dropdown, preprocessed_df):
    countries = list(original_df['PAIS'].values)
    var_values = list(preprocessed_df[chart_dropdown].values)
    adds = {}
    for i in range(0, len(var_values)):
        if (countries[i] not in adds):
            adds[countries[i]] = var_values[i]
        else:
            adds[countries[i]] += var_values[i]
    
    sorted_adds = {k: v for k, v in sorted(adds.items(), key=lambda item: item[1])}
    colours = {}
    colour_index = 0
    for key in sorted_adds:
        colours[key] = DEFAULT_COLORSCALE[colour_index]
        colour_index += 4
            
    total_df = pd.DataFrame()
    total_df['pais'] = list(sorted_adds.keys())
    total_df['total_pop'] = list(sorted_adds.values())
    
    fig = px.pie(total_df, values='total_pop', names='pais', color='pais', color_discrete_map=colours, title=chart_dropdown)
    fig_layout = fig["layout"]
    fig_data = fig["data"]
    fig_data[0]["marker"]["line"]["width"] = 0
    fig_data[0]["textposition"] = "outside"
    fig_layout["paper_bgcolor"] = "#1f2630"
    fig_layout["plot_bgcolor"] = "#1f2630"   # Background graphic: black
    fig_layout["font"]["color"] = DEFAULT_COLORSCALE[int(len(DEFAULT_COLORSCALE)/2)]  # Text colour: green
    fig_layout["title"]["font"]["color"] = DEFAULT_COLORSCALE[int(len(DEFAULT_COLORSCALE)/2)]  # Font colour: green
    fig_layout["xaxis"]["tickfont"]["color"] = DEFAULT_COLORSCALE[int(len(DEFAULT_COLORSCALE)/2)]
    fig_layout["yaxis"]["visible"] = True
    fig_layout["yaxis"]["tickfont"]["color"] = DEFAULT_COLORSCALE[int(len(DEFAULT_COLORSCALE)/2)]
    fig_layout["xaxis"]["gridcolor"] = "#5b5b5b" 
    fig_layout["yaxis"]["gridcolor"] = "#5b5b5b"
    fig_layout["margin"]["t"] = 75
    fig_layout["margin"]["r"] = 50
    fig_layout["margin"]["b"] = 100
    fig_layout["margin"]["l"] = 50

    fig.update_layout(showlegend=False)
    return fig

def plot_sunburst_chart(chart_dropdown, preprocessed_df):
    countries = list(original_df['PAIS'].values)
    unique_countries = list(set(countries))
    
    var_values = list(preprocessed_df[chart_dropdown].values)
    unique_levels = list(set(var_values))
    unique_levels.sort()

    new_df = pd.DataFrame()
    country_list = []
    risk_list = []
    total_list = []
    total_adds = {}
    for country in unique_countries:
        country_df = original_df[original_df['PAIS'] == country]
        country_values = list(country_df[chart_dropdown].values)
        counts = {i:country_values.count(i) for i in country_values}
        for level in unique_levels:
            country_list.append(country)
            risk_list.append(level)
            total_list.append(counts[level] if level in counts else 0)
            
        total_adds[country] = sum(total_list)
    
    new_df['pais'] = country_list
    new_df['nivel'] = risk_list
    new_df['total'] = total_list

    colours = {}
    colour_index = 0
    for key in total_adds:
        colours[key] = DEFAULT_COLORSCALE[colour_index]
        colour_index += 4
    
    fig = px.sunburst(new_df, path=['pais', 'nivel'], values='total', color='pais',
                      title=chart_dropdown, color_discrete_map=colours,
                      custom_data=['pais'])
    fig.update_traces(
        go.Sunburst(hovertemplate= ' pais=%{customdata[0]} <br> municipios=%{value:,.0f}<br>'))
    
    fig_layout = fig["layout"]
    fig_data = fig["data"]
    fig_data[0]["marker"]["line"]["width"] = 0
    fig_layout["paper_bgcolor"] = "#1f2630"
    fig_layout["plot_bgcolor"] = "#1f2630"   # Background graphic: black
    fig_layout["font"]["color"] = DEFAULT_COLORSCALE[int(len(DEFAULT_COLORSCALE)/2)]
    fig_layout["title"]["font"]["color"] = DEFAULT_COLORSCALE[int(len(DEFAULT_COLORSCALE)/2)]
    fig_layout["margin"]["t"] = 75
    fig_layout["margin"]["r"] = 50
    fig_layout["margin"]["b"] = 100
    fig_layout["margin"]["l"] = 50
    
    fig.update_layout(uniformtext=dict(minsize=10, mode='hide'))
    return fig

def plot_bar_chart(chart_dropdown, preprocessed_df):
    countries = list(original_df['PAIS'].values)
    unique_countries = list(set(countries))

    adds = {}
    for country in unique_countries:
        country_df = original_df[original_df['PAIS'] == country]
        country_values = list(country_df[chart_dropdown])
        float_country_values = []
        for value in country_values:
            float_country_values.append(value)
                
        adds[country] = sum(float_country_values)
    
    sorted_adds = {k: v for k, v in sorted(adds.items(), key=lambda item: item[1])}
    colours = {}
    colour_index = 0
    for key in sorted_adds:
        colours[key] = DEFAULT_COLORSCALE[colour_index]
        colour_index += 4
    
    import plotly.graph_objects as go
    fig = go.Figure(data=[go.Bar(
        x=list(sorted_adds.keys()),
        y=list(sorted_adds.values()),
        marker_color=list(colours.values()) # marker color can be a single color value or an iterable
    )])
    
    fig_layout = fig["layout"]
    fig_data = fig["data"]
    fig_data[0]["marker"]["opacity"] = 1
    fig_data[0]["marker"]["line"]["width"] = 0
    fig_data[0]["textposition"] = "outside"
    fig_layout["paper_bgcolor"] = "#1f2630"
    fig_layout["plot_bgcolor"] = "#1f2630"  
    fig_layout["font"]["color"] = DEFAULT_COLORSCALE[int(len(DEFAULT_COLORSCALE)/2)]
    fig_layout["title"]["font"]["color"] = DEFAULT_COLORSCALE[int(len(DEFAULT_COLORSCALE)/2)]
    fig_layout["xaxis"]["tickfont"]["color"] = DEFAULT_COLORSCALE[int(len(DEFAULT_COLORSCALE)/2)]
    fig_layout["yaxis"]["visible"] = True
    fig_layout["yaxis"]["tickfont"]["color"] = DEFAULT_COLORSCALE[int(len(DEFAULT_COLORSCALE)/2)]
    fig_layout["xaxis"]["gridcolor"] = "#5b5b5b"
    fig_layout["yaxis"]["gridcolor"] = "#5b5b5b"
    fig_layout["margin"]["t"] = 75
    fig_layout["margin"]["r"] = 50
    fig_layout["margin"]["b"] = 100
    fig_layout["margin"]["l"] = 50

    return fig            
    
centroids_df = get_centroids_from_polygons()
centroids_df.to_csv("./data/csv/centroids_df.csv")
preprocessed_df = preprocess_data()
preprocessed_df.to_csv("./data/csv/preprocessed_df.csv")

# Find intervals for numeric data
intervals_result = find_intervals(preprocessed_df)

app.layout = html.Div(
    id="root",
    children=[
        html.Div(
            id="header",
            children=[
                html.Img(id="logo", src=app.get_asset_url("dash-logo.png")),
                html.H4(children="Datos de Centroamérica"),
                html.P(
                    id="description",
                    children="Descripción???.",
                ),
            ],
        ),
        html.Div(
            id="app-container",
            children=[
                html.Div(
                    id="left-column",
                    children=[
                        html.P(id="select-map-type", children="Seleccione un tipo de mapa:"),
                        dcc.Dropdown(
                            options=map_type,
                            id="select-map-chartdown",
                            value="satellite",
                        ),
                        html.Div(
                            id="heatmap-container",
                            children=[
                                dcc.Graph(
                                    id="county-choropleth",
                                ),
                            ],
                        ),
                        
                    ],
                ),
                html.Div(
                    id="graph-container",
                    children=[
                        html.P(id="chart-selector", children="Seleccione una variable:"),
                        dcc.Dropdown(
                            options=select,
                            id="chart-dropdown",
                            value="PAIS",
                        ),
                        dcc.Graph(
                            id="selected-data",
                            figure=dict(
                                data=[dict(x=0, y=0)],
                                layout=dict(
                                    paper_bgcolor="#F4F4F8",
                                    plot_bgcolor="#F4F4F8",
                                    autofill=True,
                                    margin=dict(t=75, r=50, b=100, l=50),
                                ),
                            ),
                        ),
                    ],
                ),
            ],
        ),
    ],
)

@app.callback(
    Output("county-choropleth", "figure"),
    [Input("county-choropleth", "clickData"),
     Input("chart-dropdown", "value"),
     Input("select-map-chartdown", "value")],
)
def update_map(figure, selected_var, map_type):
    return plot_map(selected_var, map_type, preprocessed_df, 
        intervals_result['intervals'], intervals_result['bins'])

@app.callback(
    Output("selected-data", "figure"),
    [
        Input("chart-dropdown", "value")
    ],
)
def display_selected_data(chart_dropdown):
    if (chart_dropdown == 'DPTO'):
        chart_dropdown = 'CDGO_DPTO'
    elif (chart_dropdown == 'PAIS'):
        chart_dropdown = 'Pop_mean'
        
    if ('pop' in chart_dropdown.lower()):
        return plot_pie_chart(chart_dropdown, preprocessed_df)
    
    elif ('riesgo' in chart_dropdown.lower()):
        return plot_sunburst_chart(chart_dropdown, preprocessed_df)
    else:
        return plot_bar_chart(chart_dropdown, preprocessed_df)

if __name__ == "__main__":
    app.run_server(debug=True, port=8020)

Hi,

On the contrary, how can we be sure that the issue does not come from your code?

If you really do not want to share your code, try to isolate the data processing in a function and run a code profiler to gain a better understanding of what is going on.

Good luck!

Possible performance optimization steps include:

  • Geometry simplification (reduce number of points per polygon)
  • Property reduction (remove unused props from GeoJSON)
  • Serving the GeoJSON as a static asset
  • Compression (e.g. via geobuf)
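As a rough illustration of the property-reduction step, the sketch below drops every GeoJSON property the app does not read, using only the standard library. The helper name `strip_properties` and the sample feature are hypothetical; the kept keys (`ID`, `PAIS`, `DPTO`, `MUN`) are the ones the posted code accesses.

```python
import json


def strip_properties(feature_collection, keep):
    """Return a copy of a FeatureCollection that keeps only the listed properties."""
    slim_features = []
    for feat in feature_collection["features"]:
        props = feat.get("properties") or {}
        slim_features.append(
            {
                "type": "Feature",
                "geometry": feat["geometry"],
                "properties": {k: v for k, v in props.items() if k in keep},
            }
        )
    return {"type": "FeatureCollection", "features": slim_features}


# Hypothetical sample feature with two properties the app never uses.
sample = {
    "type": "FeatureCollection",
    "features": [
        {
            "type": "Feature",
            "geometry": {"type": "Point", "coordinates": [-91.38, 14.33]},
            "properties": {
                "ID": 1,
                "PAIS": "Guatemala",
                "DPTO": "Escuintla",
                "MUN": "Escuintla",
                "AREA_KM2": 123.4,
                "NOTES": "not used by the app",
            },
        }
    ],
}

slim = strip_properties(sample, keep={"ID", "PAIS", "DPTO", "MUN"})

# Compact separators also remove the whitespace json.dumps adds by default.
slim_json = json.dumps(slim, separators=(",", ":"))
```

Running this once as a preprocessing step, rather than at app start-up, keeps the smaller file on disk for every later launch.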

I do not feel like doing a full code review, but here are a few tips.

Disclaimer: all optimization work should come after profiling your code. I personally like to use line_profiler.
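For readers without line_profiler installed, a function-level profile with the standard library's `cProfile` gives a first idea of where the time goes. This is a swapped-in tool, not the line-by-line profiler mentioned above, and `slow_preprocess` is a hypothetical stand-in for a per-value conversion loop like the one in `preprocess_data`.

```python
import cProfile
import io
import pstats


def slow_preprocess(values):
    """Hypothetical stand-in: convert decimal-comma strings one value at a time."""
    return [float(str(v).replace(",", ".")) for v in values]


profiler = cProfile.Profile()
profiler.enable()
result = slow_preprocess(["1,5", "2,25", "3"] * 10_000)
profiler.disable()

# Collect the report in a string so it can be printed or logged.
buffer = io.StringIO()
stats = pstats.Stats(profiler, stream=buffer).sort_stats("cumulative")
stats.print_stats(5)  # show the five most expensive entries
report = buffer.getvalue()
print(report)
```

Wrapping the body of a callback the same way shows whether the time is spent in the Python code or elsewhere (for example in the browser, which a server-side profiler cannot see).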

  1. Clean up your code.

    • you have a bunch of unused imports: os, pathlib, cf, re, State (and three imports of plotly.graph_objects). This will not speed up your dashboard per se, but it would still be good to clean those up. Furthermore, you should keep your imports at the top of your file unless you really need to do otherwise (geojson, geojson_utils, plotly.graph_objects). You also have many unused variables.
    • regroup stuff that should go together (figure layout updates for example)
  2. Know your tools. Take a closer look at pd.read_csv. There are a bunch of options that would simplify your code throughout the entire script and even speed it up:

    • using decimal="," you may cut a large portion of your preprocessing function (e.g. stuff like float(str(value).replace(',', '.')))

    • you should take a look at pandas’ builtin categorical type instead of using LabelEncoder. You may specify dtype at loading time using the dtype argument

    • You are not using your data structures efficiently. In plot_map you use for loops to iterate over pandas series which is extremely inefficient. I am trying to keep this review as shallow as possible but it seems like you should take a look at pandas.cut for all intervals related things.

    • I am pretty sure that you can cut the code size of plot_pie_chart by half with df.sort_values (and the computing time even more). You are (for) looping over a dataframe to sort the values to recreate a sorted dataframe.

    • use .unique() instead of set(), and use NaN-related methods (such as isna) instead of checking one value at a time.
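The pandas tips above can be sketched on a tiny, made-up table. The sample data, the semicolon separator, and the four-color palette are assumptions for illustration only; the real CSV may be laid out differently.

```python
import io

import pandas as pd

# Hypothetical miniature of the CSV, with decimal commas and one missing value.
csv_text = """ID;PAIS;Pop_mean
1;Guatemala;10,5
2;Guatemala;20,0
3;Honduras;5,25
4;Honduras;
5;Belize;40,0
"""

# decimal="," parses "10,5" as 10.5 at load time; dtype="category" replaces LabelEncoder.
df = pd.read_csv(
    io.StringIO(csv_text),
    sep=";",
    decimal=",",
    dtype={"PAIS": "category"},
)

# Integer codes per country, if a numeric encoding is still needed.
country_codes = df["PAIS"].cat.codes

# Missing values are detected in one vectorized call instead of a per-value check.
n_missing = int(df["Pop_mean"].isna().sum())

# pd.cut assigns every value to one of len(palette) equal-width bins;
# using the colors as labels maps each row straight to its color (NaN stays NaN).
palette = ["#f2fffb", "#79ffd6", "#25a27b", "#10523e"]  # subset of DEFAULT_COLORSCALE
df["color"] = pd.cut(df["Pop_mean"], bins=len(palette), labels=palette)

# Per-country totals, sorted ascending, without an explicit loop.
totals = df.groupby("PAIS", observed=True)["Pop_mean"].sum().sort_values()
```

Note that `pd.cut` with an integer `bins` argument produces equal-width bins over the observed range, which is close to, but not identical with, the interval logic in `find_intervals`.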

Anyhow, I am not saying that fixing all of this will speed up your callback to under 1 s for sure (though it might), but it would be good nonetheless.

Good luck

Wow!! Thanks so much for your suggestions. I will take them into account to optimize my code.

In the end I found a solution by switching to Plotly Express, which uses WebGL and is faster.

Hi! Thanks for your reply. I did the first two suggestions, but I’m curious about the third one.

  • How can I do it? Is it enough to put the GeoJSON file in the assets folder?
  • How can this help the performance of my app? If you could give a brief explanation, I’d be grateful.

You create the GeoJSON file prior to launching the app (as a pre processing step) and place it in the assets folder. Then you pass the url to the file instead of the GeoJSON data itself.

When you provide the GeoJSON data as a static asset (rather than as in-memory data), it tends to reduce the data transfer time, in particular for subsequent requests, as the data can be cached by the browser. Hence the improvement in performance is in load time, not in the map performance itself (which is what is improved by moving to WebGL rendering).
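A minimal sketch of that preprocessing step, using only the standard library. The helper `export_geojson_asset` and the file name `municipios.json` are hypothetical, and a temporary directory stands in for the app's real assets folder. The returned path assumes Dash's default `/assets/` route; inside the app, `app.get_asset_url("municipios.json")` (the same call the posted code uses for the logo) would build the URL instead.

```python
import json
import tempfile
from pathlib import Path


def export_geojson_asset(feature_collection, assets_dir, filename):
    """Write the GeoJSON once, before the app starts, and return the URL path
    under which the default Dash assets route would serve it."""
    assets_path = Path(assets_dir)
    assets_path.mkdir(parents=True, exist_ok=True)
    target = assets_path / filename
    target.write_text(
        json.dumps(feature_collection, separators=(",", ":")), encoding="utf-8"
    )
    return f"/assets/{filename}"


# Hypothetical data; in the real project this would be the simplified GeoJSON.
geojson_data = {
    "type": "FeatureCollection",
    "features": [
        {
            "type": "Feature",
            "properties": {"ID": 1},
            "geometry": {"type": "Point", "coordinates": [-91.38, 14.33]},
        }
    ],
}

# A temporary directory stands in for the app's assets folder in this sketch.
assets_dir = Path(tempfile.mkdtemp()) / "assets"
geojson_url = export_geojson_asset(geojson_data, assets_dir, "municipios.json")
```

The URL string, rather than the parsed GeoJSON dict, is then what gets passed to whichever map component accepts a URL as its data source, so the browser fetches the file itself and can cache it.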