Hi!
I’m working on a project whose main goal is showing data from a CSV on a choropleth map. In the next example, the map is showing the cities of 4 countries from MiddleAmerica and the colours are related to the country which each city belongs to. As you can see, the app is similar to the example US Opioid Epidemic.
The problem is that the geometry is very very very complex and when I change the column to represent on the map, the app takes more than 8 seconds to refresh the view and plot the new data. I tried to simplify the geometry with mapshaper but it’s still the same.
I’m running the app in a container with Gunicorn and I tried to optimize it but the app performance didn’t get better. I don’t know if there’s something else which could help in this situation or if it’s an issue of the library.
I add the code below. Thanks for the help.
import os
import pathlib
import re
import plotly.graph_objs as graph_objs
import numpy as np
import dash
import dash_core_components as dcc
import dash_html_components as html
import pandas as pd
from dash.dependencies import Input, Output, State
import cufflinks as cf
from sklearn.preprocessing import LabelEncoder
import plotly.express as px
import plotly.graph_objects as go
MAPBOX_TOKEN = 'my_box_token'
mapbox_style = "mapbox://styles/plotlymapbox/cjvprkf3t1kns1cqjxuxmwixz"
DEFAULT_COLORSCALE = [
"#f2fffb",
"#bbffeb",
"#98ffe0",
"#79ffd6",
"#6df0c8",
"#69e7c0",
"#59dab2",
"#45d0a5",
"#31c194",
"#2bb489",
"#25a27b",
"#1e906d",
"#188463",
"#157658",
"#11684d",
"#10523e",
]
map_type = [{"label":"Basico", "value":"basic"},
{"label":"Calles", "value":"streets"},
{"label":"Outdoors", "value":"outdoors"},
{"label":"Claro", "value":"light"},
{"label":"Oscuro", "value":"dark"},
{"label":"Satélite", "value":"satellite"},
{"label":"Satélite street", "value":"satellite-streets"},
{"label":"Open Street Map", "value":"open-street-map"}
]
app = dash.Dash(
__name__,
meta_tags=[
{"name": "viewport", "content": "width=device-width, initial-scale=1.0"}
]
)
server = app.server
original_df = pd.read_csv("./data/csv/200923_PT1_A07_P04_BBDD_agregada.csv")
import geojson
with open("./data/geojson/0.005_reduced_data.json") as f:
geojson_data = geojson.load(f)
non_selected_cols = ['ID', 'ID_1', 'ID_12', 'CDGO_MUN', 'MUN', 'geometry']
select= []
for var in list(original_df.columns):
if (var not in non_selected_cols):
select.append({'label':var, 'value':var})
categorical_columns = ['PAIS', 'DPTO', 'MUN']
categorical_to_numeric = {'DPTO':'CDGO_DPTO', 'MUN':'CDGO_MUN'}
def get_centroids_from_polygons():
from geojson_utils import centroid
lat_list = []
lon_list = []
hover_list = []
for feat in geojson_data['features']:
hover_text = "<b>Pais: "+feat['properties']['PAIS'] + "</b> <br> Dpto: " + feat['properties']['DPTO'] + "<br> Mun: " + feat['properties']['MUN']
hover_list.append(hover_text)
if (feat['geometry']['type'] == 'MultiPolygon'):
multipolygons = feat['geometry']['coordinates']
lat_sum = 0
lon_sum = 0
for pol in multipolygons:
format_pol = {'coordinates':pol ,'type':'Polygon'}
multicentroids = centroid(format_pol)
lon_sum += multicentroids['coordinates'][0]
lat_sum += multicentroids['coordinates'][1]
lat_list.append(lat_sum/len(multipolygons))
lon_list.append(lon_sum/len(multipolygons))
else:
centroids = centroid(feat['geometry'])
lon_list.append(centroids['coordinates'][0])
lat_list.append(centroids['coordinates'][1])
centroids_df = pd.DataFrame()
centroids_df['Latitude'] = lat_list
centroids_df['Longitude'] = lon_list
centroids_df['Hover'] = hover_list
return centroids_df
def preprocess_data():
preprocessed_df = pd.DataFrame()
lb_make = LabelEncoder()
for column in original_df:
column_values = list(original_df[column].values)
if (column in categorical_columns):
encoding_values = lb_make.fit_transform(column_values)
preprocessed_df[column] = encoding_values
else:
preprocessed_values = []
for value in column_values:
if (np.isnan(float(str(value).replace(',','.')))):
preprocessed_values.append("NAN")
else:
preprocessed_v = float(str(value).replace(',','.'))
preprocessed_values.append(preprocessed_v)
preprocessed_df[column] = preprocessed_values
return preprocessed_df
def find_intervals(preprocessed_df):
bin_dict = {}
interval_dict = {}
numeric_columns = [column for column in list(original_df.columns) if column not in categorical_columns]
numeric_df = preprocessed_df.filter(numeric_columns)
for column in list(numeric_df.columns):
values = list(numeric_df[column].values)
values_without_nan = [value for value in values if value != 'NAN']
n_unique_values = len(np.unique(values_without_nan))
n_intervals = len(DEFAULT_COLORSCALE) if n_unique_values >= len(DEFAULT_COLORSCALE) else n_unique_values
first = int(min(values_without_nan))
second = max(values_without_nan) / n_intervals
interval_column = []
bins = []
for i in range(1, n_intervals+1):
interval_column.append((first, int(second*i) + 1))
bins.append(str(first)+"-"+str(int(second*i) + 1))
first = int(second*i) + 1
interval_dict[column] = interval_column
bin_dict[column] = bins
return {'bins':bin_dict, 'intervals':interval_dict}
def plot_map(selected_var, map_type, preprocessed_df, interval_dict, bin_dict):
colors = []
bins = []
values_to_plot = {}
if (selected_var in categorical_to_numeric):
selected_var = categorical_to_numeric[selected_var]
if (selected_var == 'PAIS'):
country_names = list(set(list(original_df['PAIS'].values)))
countries = list(preprocessed_df['PAIS'].values)
unique_countries = np.unique(countries)
sorted_countries = sorted(set(list(original_df['PAIS'].values)), key=list(original_df['PAIS'].values).index)
country_colors = {}
for i in range(0, len(unique_countries)):
country_colors[unique_countries[i]] = DEFAULT_COLORSCALE[i*len(unique_countries)]
bins.append(sorted_countries[i])
id_values = list(preprocessed_df['ID'].values)
for i in range(0, len(id_values)):
values_to_plot[id_values[i]] = countries[i]
colors.append(country_colors[countries[i]])
else:
id_values = list(preprocessed_df['ID'].values)
column_values = list(preprocessed_df[selected_var].values)
for i in range(0, len(id_values)):
if (column_values[i] != 'NAN'):
values_to_plot[id_values[i]] = column_values[i]
interval_list = interval_dict[selected_var]
bins = bin_dict[selected_var]
for value in list(values_to_plot.values()):
for interval in interval_list:
if (value >= interval[0] and value <= interval[1]):
index = interval_list.index(interval)
colors.append(DEFAULT_COLORSCALE[index])
data = [
dict(
lat=centroids_df["Latitude"],
lon=centroids_df["Longitude"],
text=centroids_df["Hover"],
type="scattermapbox",
hoverinfo="text",
marker=dict(size=5, opacity=0),
hoverlabel={'bgcolor':colors, 'font':{'size': 13, 'color':'black'}}
)
]
annotations = [
dict(
showarrow=False,
align="right",
text="<b></b>",
font=dict(color=DEFAULT_COLORSCALE[int(len(DEFAULT_COLORSCALE)/2)]),
bgcolor="#1f2630",
x=0.95,
y=0.95,
)
]
cm = dict(zip(bins, DEFAULT_COLORSCALE))
for i, bin in enumerate(reversed(bins)):
color = cm[bin]
annotations.append(
dict(
arrowcolor=color,
text=bin,
x=0.95,
y=0.85 - (i / 20),
ax=-60,
ay=0,
arrowwidth=5,
arrowhead=0,
bgcolor="#1f2630",
font=dict(color=DEFAULT_COLORSCALE[int(len(DEFAULT_COLORSCALE)/2)]),
)
)
sources=[{"type": "FeatureCollection", 'features': [feat]} \
for feat in geojson_data['features'] if float(feat['properties']['ID']) in values_to_plot]
layout = dict(
mapbox=dict(
style=map_type,
#style=mapbox_style,
uirevision=True,
accesstoken=MAPBOX_TOKEN,
layers=[dict(sourcetype = 'geojson',
source =sources[k],
below="water",
type = 'fill',
color = colors[k],
opacity=0.8
) for k in range(len(sources))],
center=dict(
lat=14.3333300,
lon=-91.3833300
),
zoom=4,
),
shapes=[
{
"type": "rect",
"xref": "paper",
"yref": "paper",
"x0": 0,
"y0": 0,
"x1": 1,
"y1": 1,
"line": {"width": 1, "color": "#B0BEC5"},
}
],
hovermode="closest",
margin=dict(r=0, l=0, t=0, b=0),
annotations=annotations,
)
figure = {"data": data, "layout": layout}
return figure
def plot_pie_chart(chart_dropdown, preprocessed_df):
countries = list(original_df['PAIS'].values)
var_values = list(preprocessed_df[chart_dropdown].values)
adds = {}
for i in range(0, len(var_values)):
if (countries[i] not in adds):
adds[countries[i]] = var_values[i]
else:
adds[countries[i]] += var_values[i]
sorted_adds = {k: v for k, v in sorted(adds.items(), key=lambda item: item[1])}
colours = {}
colour_index = 0
for key in sorted_adds:
colours[key] = DEFAULT_COLORSCALE[colour_index]
colour_index += 4
total_df = pd.DataFrame()
total_df['pais'] = list(sorted_adds.keys())
total_df['total_pop'] = list(sorted_adds.values())
fig = px.pie(total_df, values='total_pop', names='pais', color='pais', color_discrete_map=colours, title=chart_dropdown)
fig_layout = fig["layout"]
fig_data = fig["data"]
fig_data[0]["marker"]["line"]["width"] = 0
fig_data[0]["textposition"] = "outside"
fig_layout["paper_bgcolor"] = "#1f2630"
fig_layout["plot_bgcolor"] = "#1f2630" # Background graphic: black
fig_layout["font"]["color"] = DEFAULT_COLORSCALE[int(len(DEFAULT_COLORSCALE)/2)] # Text colour: green
fig_layout["title"]["font"]["color"] = DEFAULT_COLORSCALE[int(len(DEFAULT_COLORSCALE)/2)] # Font colour: green
fig_layout["xaxis"]["tickfont"]["color"] = DEFAULT_COLORSCALE[int(len(DEFAULT_COLORSCALE)/2)]
fig_layout["yaxis"]["visible"] = True
fig_layout["yaxis"]["tickfont"]["color"] = DEFAULT_COLORSCALE[int(len(DEFAULT_COLORSCALE)/2)]
fig_layout["xaxis"]["gridcolor"] = "#5b5b5b"
fig_layout["yaxis"]["gridcolor"] = "#5b5b5b"
fig_layout["margin"]["t"] = 75
fig_layout["margin"]["r"] = 50
fig_layout["margin"]["b"] = 100
fig_layout["margin"]["l"] = 50
fig.update_layout(showlegend=False)
return fig
def plot_sunburst_chart(chart_dropdown, preprocessed_df):
countries = list(original_df['PAIS'].values)
unique_countries = list(set(countries))
var_values = list(preprocessed_df[chart_dropdown].values)
unique_levels = list(set(var_values))
unique_levels.sort()
new_df = pd.DataFrame()
country_list = []
risk_list = []
total_list = []
total_adds = {}
for country in unique_countries:
country_df = original_df[original_df['PAIS'] == country]
country_values = list(country_df[chart_dropdown].values)
counts = {i:country_values.count(i) for i in country_values}
for level in unique_levels:
country_list.append(country)
risk_list.append(level)
total_list.append(counts[level] if level in counts else 0)
total_adds[country] = sum(total_list)
new_df['pais'] = country_list
new_df['nivel'] = risk_list
new_df['total'] = total_list
colours = {}
colour_index = 0
for key in sorted_adds:
colours[key] = DEFAULT_COLORSCALE[colour_index]
colour_index += 4
fig = px.sunburst(new_df, path=['pais', 'nivel'], values='total', color='pais',
title=chart_dropdown, color_discrete_map=colours,
custom_data=['pais'])
fig.update_traces(
go.Sunburst(hovertemplate= ' pais=%{customdata[0]} <br> municipios=%{value:,.0f}<br>'))
fig_layout = fig["layout"]
fig_data = fig["data"]
fig_data[0]["marker"]["line"]["width"] = 0
fig_layout["paper_bgcolor"] = "#1f2630"
fig_layout["plot_bgcolor"] = "#1f2630" # Background graphic: black
fig_layout["font"]["color"] = DEFAULT_COLORSCALE[int(len(DEFAULT_COLORSCALE)/2)]
fig_layout["title"]["font"]["color"] = DEFAULT_COLORSCALE[int(len(DEFAULT_COLORSCALE)/2)]
fig_layout["margin"]["t"] = 75
fig_layout["margin"]["r"] = 50
fig_layout["margin"]["b"] = 100
fig_layout["margin"]["l"] = 50
fig.update_layout(uniformtext=dict(minsize=10, mode='hide'))
return fig
def plot_bar_chart(chart_dropdown, preprocessed_df):
countries = list(original_df['PAIS'].values)
unique_countries = list(set(countries))
adds = {}
for country in unique_countries:
country_df = original_df[original_df['PAIS'] == country]
country_values = list(country_df[chart_dropdown])
float_country_values = []
for value in country_values:
float_country_values.append(value)
adds[country] = sum(float_country_values)
sorted_adds = {k: v for k, v in sorted(adds.items(), key=lambda item: item[1])}
colours = {}
colour_index = 0
for key in sorted_adds:
colours[key] = DEFAULT_COLORSCALE[colour_index]
colour_index += 4
import plotly.graph_objects as go
fig = go.Figure(data=[go.Bar(
x=list(sorted_adds.keys()),
y=list(sorted_adds.values()),
marker_color=list(colours.values()) # marker color can be a single color value or an iterable
)])
fig_layout = fig["layout"]
fig_data = fig["data"]
fig_data[0]["marker"]["opacity"] = 1
fig_data[0]["marker"]["line"]["width"] = 0
fig_data[0]["textposition"] = "outside"
fig_layout["paper_bgcolor"] = "#1f2630"
fig_layout["plot_bgcolor"] = "#1f2630"
fig_layout["font"]["color"] = DEFAULT_COLORSCALE[int(len(DEFAULT_COLORSCALE)/2)]
fig_layout["title"]["font"]["color"] = DEFAULT_COLORSCALE[int(len(DEFAULT_COLORSCALE)/2)]
fig_layout["xaxis"]["tickfont"]["color"] = DEFAULT_COLORSCALE[int(len(DEFAULT_COLORSCALE)/2)]
fig_layout["yaxis"]["visible"] = True
fig_layout["yaxis"]["tickfont"]["color"] = DEFAULT_COLORSCALE[int(len(DEFAULT_COLORSCALE)/2)]
fig_layout["xaxis"]["gridcolor"] = "#5b5b5b"
fig_layout["yaxis"]["gridcolor"] = "#5b5b5b"
fig_layout["margin"]["t"] = 75
fig_layout["margin"]["r"] = 50
fig_layout["margin"]["b"] = 100
fig_layout["margin"]["l"] = 50
return fig
centroids_df = get_centroids_from_polygons()
centroids_df.to_csv("./data/csv/centroids_df.csv")
preprocessed_df = preprocess_data()
preprocessed_df.to_csv("./data/csv/preprocessed_df.csv")
# Find intervals for numeric data
intervals_result = find_intervals(preprocessed_df)
app.layout = html.Div(
id="root",
children=[
html.Div(
id="header",
children=[
html.Img(id="logo", src=app.get_asset_url("dash-logo.png")),
html.H4(children="Datos de Centroamérica"),
html.P(
id="description",
children="Descripción???.",
),
],
),
html.Div(
id="app-container",
children=[
html.Div(
id="left-column",
children=[
html.P(id="select-map-type", children="Seleccione un tipo de mapa:"),
dcc.Dropdown(
options=map_type,
id="select-map-chartdown",
value="satellite",
),
html.Div(
id="heatmap-container",
children=[
dcc.Graph(
id="county-choropleth",
),
],
),
],
),
html.Div(
id="graph-container",
children=[
html.P(id="chart-selector", children="Seleccione una variable:"),
dcc.Dropdown(
options=select,
id="chart-dropdown",
value="PAIS",
),
dcc.Graph(
id="selected-data",
figure=dict(
data=[dict(x=0, y=0)],
layout=dict(
paper_bgcolor="#F4F4F8",
plot_bgcolor="#F4F4F8",
autofill=True,
margin=dict(t=75, r=50, b=100, l=50),
),
),
),
],
),
],
),
],
)
@app.callback(
Output("county-choropleth", "figure"),
[Input("county-choropleth", "clickData"),
Input("chart-dropdown", "value"),
Input("select-map-chartdown", "value")],
)
def update_map(figure, selected_var, map_type):
return plot_map(selected_var, map_type, preprocessed_df,
intervals_result['intervals'], intervals_result['bins'])
@app.callback(
Output("selected-data", "figure"),
[
Input("chart-dropdown", "value")
],
)
def display_selected_data(chart_dropdown):
if (chart_dropdown == 'DPTO'):
chart_dropdown = 'CDGO_DPTO'
elif (chart_dropdown == 'PAIS'):
chart_dropdown = 'Pop_mean'
if ('pop' in chart_dropdown.lower()):
return plot_pie_chart(chart_dropdown, preprocessed_df)
elif ('riesgo' in chart_dropdown.lower()):
return plot_sunburst_chart(chart_dropdown, preprocessed_df)
else:
return plot_bar_chart(chart_dropdown, preprocessed_df)
if __name__ == "__main__":
app.run_server(debug=True, port=8020)