Heatmap is slow for large data arrays

My problem is the following. I am trying to use heatmap to plot 2D data, but I typically have an array size of 2000x1000.
The following takes approx 30 seconds to load in my Jupyter notebook, which is too slow for my use:

from plotly.offline import init_notebook_mode, iplot
init_notebook_mode(connected=True)
import numpy as np
trace = dict(type='heatmap', z=np.random.normal(size=2000*1000).reshape(1000,2000).astype(np.float64))
data=[trace]
iplot(data)

I have tried adding hoverinfo='skip' to disable the hover info, but that did not help.
In comparison, using imshow in matplotlib is instantaneous.
(I don’t know if the fact that I’m trying to do this in a Jupyter notebook, or that I’m using the offline mode, makes any difference, but I cannot do without either of these.)

My guess is that plotly is drawing a small square/rectangle for each pixel and that this is the bottleneck, while imshow uses some shortcut via image compression. However, this is actually a feature of Heatmap that I need, since not all my pixels necessarily have the same size (imshow assumes the data is regularly spaced, and matplotlib’s pcolormesh, which does handle irregular pixels, is also very slow).
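For irregular pixels, a Plotly heatmap can be given cell edges instead of centres: when `x` has one more entry than `z` has columns, the values are treated as pixel boundaries (the resampling code later in this thread relies on this). A minimal sketch of deriving such edges from non-uniform centres follows; the helper name `edges_from_centres` is hypothetical, and placing the two outer edges half a local spacing beyond the first and last centre is an assumption.

```python
import numpy as np


def edges_from_centres(centres):
    """Return len(centres)+1 cell edges for (possibly non-uniform) 1D centres.

    Interior edges are midpoints between neighbouring centres; the two outer
    edges are placed half a local spacing beyond the first and last centre.
    """
    c = np.asarray(centres, dtype=np.float64)
    if c.size < 2:
        raise ValueError("need at least two centres")
    mid = 0.5 * (c[1:] + c[:-1])
    first = c[0] - (mid[0] - c[0])
    last = c[-1] + (c[-1] - mid[-1])
    return np.concatenate([[first], mid, [last]])


# Non-uniform centres: the last pixel is much wider than the others
xc = np.array([0.5, 1.5, 2.5, 6.0])
xe = edges_from_centres(xc)
print(xe)
```

The resulting `xe` can be passed as `x` to a heatmap whose `z` has `len(xe) - 1` columns.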

My question is then the following: is there a way to display a low-resolution version of my data (say every 10th data point) and then, when the user zooms in, gradually recover the full resolution by re-plotting higher-resolution data on every zoom? In other words, is it possible to keep the resolution of the displayed image constant, independent of the zoom level (up to a maximum, obviously), say always 256x256?
I would need some sort of event listener that grabs the current x and y range of the displayed area and resamples my original data array accordingly.
I can see that this may be possible from JS, but it seems to be more of a challenge from Python.
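The resampling half of this idea can be sketched in plain NumPy: given the current axis ranges, find the visible index window with `np.searchsorted`, then pick a stride so that at most `max_pixels` samples are kept along each axis. This is a minimal sketch assuming sorted 1D pixel-centre coordinates; `decimate_view` and `max_pixels` are hypothetical names, and plain striding (keeping every k-th point) discards data rather than averaging it.

```python
import numpy as np


def decimate_view(x, y, z, x_range, y_range, max_pixels=256):
    """Return (x_sub, y_sub, z_sub) for the visible window, strided so that
    neither axis has more than max_pixels samples.

    x, y: sorted 1D pixel-centre coordinates; z has shape (len(y), len(x)).
    """
    # Index window of the samples that fall inside the requested ranges
    i0 = np.searchsorted(x, x_range[0], side="left")
    i1 = np.searchsorted(x, x_range[1], side="right")
    j0 = np.searchsorted(y, y_range[0], side="left")
    j1 = np.searchsorted(y, y_range[1], side="right")
    # Stride: keep every k-th sample so the count stays <= max_pixels
    kx = max(1, int(np.ceil((i1 - i0) / max_pixels)))
    ky = max(1, int(np.ceil((j1 - j0) / max_pixels)))
    return x[i0:i1:kx], y[j0:j1:ky], z[j0:j1:ky, i0:i1:kx]


x = np.arange(2000, dtype=np.float64)
y = np.arange(1000, dtype=np.float64)
z = np.random.normal(size=(1000, 2000))

# Fully zoomed out: heavily decimated
xs, ys, zs = decimate_view(x, y, z, (0, 1999), (0, 999))
print(zs.shape)
# Zoomed in on a 100x50 window: full resolution
xs, ys, zs = decimate_view(x, y, z, (500, 599), (200, 249))
print(zs.shape)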

I thought about having a slider from ipywidgets and using that as a zoom but this would only allow me to zoom around a fixed point on the image, and I would like to be able to zoom on any area, using the classic plotly rectangle zoom tool.

Any help would be very appreciated.
Thanks!

Hi @nvaytet,

Welcome to the forums! There are two high-level things that can be slowing down a plot like this.

  1. Serialization can be slow, which causes a long delay before the plot is initially displayed. This is the process of turning that 2-million-element array into a JSON form that can be processed by Plotly.js. We have a more efficient serialization protocol implemented for displaying FigureWidget figures in the Jupyter notebook, but unfortunately it is only implemented for 1D arrays at the moment.
  2. Interaction can be slow, which causes the pan/zoom/hover to be sluggish after the figure is loaded. This usually happens for non-WebGL accelerated trace types that draw individual SVG elements for each marker.
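To get a feel for the serialization cost in point 1, one can measure the plain JSON encoding of the array. This is only a rough proxy that assumes the values end up as decimal text; Plotly's own encoder differs in detail, and `json_payload_size` is a hypothetical helper.

```python
import json
import time

import numpy as np


def json_payload_size(z):
    """Return (bytes, seconds) for encoding z as a nested JSON list."""
    t0 = time.perf_counter()
    text = json.dumps(np.asarray(z).tolist())
    return len(text.encode("utf-8")), time.perf_counter() - t0


z = np.random.normal(size=(1000, 2000))
nbytes, seconds = json_payload_size(z)
# Each float64 written as decimal text takes far more than the 8 bytes it
# occupies in the raw binary buffer
print(f"JSON: {nbytes / 1e6:.1f} MB in {seconds:.2f} s; raw: {z.nbytes / 1e6:.1f} MB")
```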

Using the WebGL accelerated heatmap trace (https://plot.ly/python/heatmap-webgl/#create-a-heatmapgl-from-an-image) should help somewhat if you’re running into the second problem above, but it won’t help with the first.

Regarding the dynamic re-sampling idea, I’d recommend studying the Datashader case study (https://plot.ly/python/change-callbacks-datashader/). This approach sends the data from Python to JavaScript as a base64 encoded png image, which is much more efficient. It also installs a callback on pan/zoom events to re-aggregate the dataset using Datashader and then update the image.
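The base64 PNG route can be sketched without Datashader. The encoder below writes an 8-bit greyscale PNG using only the standard library plus NumPy and returns a `data:` URI, which is the kind of string a layout image `source` accepts. It is a minimal sketch: the linear scaling to 0-255, the assumption that the array contains no NaNs, and the helper name `array_to_png_data_uri` are all choices made here, and in practice PIL's `Image.save` does the same job.

```python
import base64
import struct
import zlib

import numpy as np


def _chunk(tag, data):
    # PNG chunk: length, type, data, CRC32 of type+data
    return (struct.pack(">I", len(data)) + tag + data
            + struct.pack(">I", zlib.crc32(tag + data) & 0xFFFFFFFF))


def array_to_png_data_uri(z, zmin=None, zmax=None):
    """Encode a 2D array (no NaNs) as an 8-bit greyscale PNG data URI.

    Values are scaled linearly from [zmin, zmax] to [0, 255]. Row 0 of z
    becomes the top row of the image.
    """
    z = np.asarray(z, dtype=np.float64)
    zmin = z.min() if zmin is None else zmin
    zmax = z.max() if zmax is None else zmax
    scale = 255.0 / (zmax - zmin) if zmax > zmin else 0.0
    pix = np.clip((z - zmin) * scale, 0, 255).astype(np.uint8)
    height, width = pix.shape
    # Each scanline starts with filter type 0 (no filtering)
    raw = np.hstack([np.zeros((height, 1), np.uint8), pix]).tobytes()
    png = (b"\x89PNG\r\n\x1a\n"
           + _chunk(b"IHDR", struct.pack(">IIBBBBB", width, height, 8, 0, 0, 0, 0))
           + _chunk(b"IDAT", zlib.compress(raw, 6))
           + _chunk(b"IEND", b""))
    return "data:image/png;base64," + base64.b64encode(png).decode("ascii")


uri = array_to_png_data_uri(np.random.normal(size=(1000, 2000)))
print(len(uri))
```

The returned string could be used as the `source` of a `go.layout.Image` in place of a PIL image.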

Hope that helps give you some ideas!
-Jon

Hi @jmmease,

Thanks so much for your reply.

The problem I was having was both with the initial generation time and interaction.
I had tried to use heatmapgl but that didn’t help with generation time and also did not seem to help much with the interaction.

I had a look at the Datashader case study you suggested, and in there I found exactly what I was looking for: a way to install a callback from pan/zoom to a function.
Based on this, and after a few more tricks, I finally managed to obtain what I wanted without needing PNG images or the Datashader module, using just plotly.
I can now make a large image and plot a low-resolution version of it, which is refined as I zoom in, all the way down to the original data.
I post my solution below, in case anyone else finds it useful.

# In[ ]:


from plotly.offline import init_notebook_mode, iplot
init_notebook_mode(connected=True)
import plotly.graph_objs as go
import numpy as np


# In[ ]:


# Define image maximum resolution to be displayed
res = 128
# Define full resolution image
N = 2000
M = 1000
# xx and yy are the edges of the pixels in the full image
xx = np.arange(N+1, dtype=np.float64)
yy = np.arange(M+1, dtype=np.float64)
x, y = np.meshgrid(xx, yy)
b = N/40.0
c = M/2.0
# Evaluate on the (M, N) grid of pixels (drop the extra edge row and column)
r = np.sqrt(((x[:-1, :-1]-c)/b)**2 + ((y[:-1, :-1]-c)/b)**2)
# zz is the heatmap values
zz = np.sin(r)
# xc and yc are the pixels centers
xc = 0.5 * (xx[1:] + xx[:-1])
yc = 0.5 * (yy[1:] + yy[:-1])
# Store the limits of the full resolution array for the colorbar
cmin = np.amin(zz)
cmax = np.amax(zz)


# In[ ]:


def resample_image(x_range, y_range):
    # Find indices of xx and yy that are shown in current range
    x_in_range = np.where(np.logical_and(xx >= x_range[0], xx <= x_range[1]))
    y_in_range = np.where(np.logical_and(yy >= y_range[0], yy <= y_range[1]))

    # xmin, xmax... here are array indices, not float coordinates
    xmin = x_in_range[0][0]
    xmax = x_in_range[0][-1]
    ymin = y_in_range[0][0]
    ymax = y_in_range[0][-1]
    # Here we perform a trick so that the edges of the displayed image are not greyed out.
    # If the zoom area slices a pixel in half, only the pixels fully inside the view area would
    # be shown, leaving an empty strip between the last pixel edge and the edge of the view frame.
    # So we extend the selected area by one pixel on each side, as long as it stays inside the
    # global limits of the full-resolution array.
    xmin -= int(xmin > 0)
    xmax += int(xmax < len(xx)-1)
    ymin -= int(ymin > 0)
    ymax += int(ymax < len(yy)-1)

    # Local coordinate arrays
    xxx = xx[xmin:xmax+1]
    yyy = yy[ymin:ymax+1]

    # Count the number of pixels in the current view
    nx_view = xmax-xmin
    ny_view = ymax-ymin

    # Define x and y edges for histogramming
    # If the number of pixels in the view area is larger than the max allowed resolution
    # we create some custom pixels
    if nx_view > res:
        xe = np.linspace(xxx[0],xxx[-1],res)
    else:
        xe = xxx
    if ny_view > res:
        ye = np.linspace(yyy[0],yyy[-1],res)
    else:
        ye = yyy

    # Optimize if no re-sampling is required: use the full-resolution pixels directly
    if (nx_view <= res) and (ny_view <= res):
        z1 = zz[ymin:ymax,xmin:xmax]
    else:
        xg, yg = np.meshgrid(xc[xmin:xmax], yc[ymin:ymax])
        xv = np.ravel(xg)
        yv = np.ravel(yg)
        zv = np.ravel(zz[ymin:ymax,xmin:xmax])
        # Histogram the data to make a low-resolution image
        # Using weights in the second histogram allows us to then do z1/z0 to obtain the
        # averaged data inside the coarse pixels
        z0, yedges1, xedges1 = np.histogram2d(yv, xv, bins=(ye,xe))
        z1, yedges1, xedges1 = np.histogram2d(yv, xv, bins=(ye,xe), weights=zv)
        z1 /= z0

    # Here we perform another trick. If we simply plot the local arrays in plotly, the reset axes
    # (home) functionality will be lost, because plotly will now think that the data that exists
    # is only the small window shown after a zoom. So we add a one-pixel padding area to the local
    # z array. That padding extends from the edges of the initial full-resolution array
    # (e.g. x=0, y=0) up to the edge of the view area. These large (and probably elongated) pixels
    # add very little data and will not show in the view area, but they allow plotly to recover the
    # full axes limits if we double-click on the plot.
    if xmin > 0:
        xe = np.concatenate([xx[0:1], xe])
    if xmax < len(xx)-1:
        xe = np.concatenate([xe, xx[-1:]])
    if ymin > 0:
        ye = np.concatenate([yy[0:1], ye])
    if ymax < len(yy)-1:
        ye = np.concatenate([ye, yy[-1:]])
    imin = int(xmin>0)
    imax = int(xmax<(len(xx)-1))
    jmin = int(ymin>0)
    jmax = int(ymax<(len(yy)-1))

    # the local z array
    zzz = np.zeros([len(ye)-1, len(xe)-1])
    zzz[jmin:len(ye)-jmax-1,imin:len(xe)-imax-1] = z1
    return xe, ye, zzz
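As an alternative to building a meshgrid and calling `np.histogram2d` twice, a block average over integer index boundaries can be computed with `np.add.reduceat`, which avoids materialising coordinate grids for the whole view. This is a sketch of a swapped-in technique, not the method used above; `block_mean` is a hypothetical helper, and it assumes the coarse cell boundaries are integer pixel indices rather than arbitrary float edges.

```python
import numpy as np


def block_mean(z, n_out_y, n_out_x):
    """Average z (ny, nx) into at most (n_out_y, n_out_x) coarse cells.

    Returns (row_starts, col_starts, zc) where row_starts/col_starts are the
    integer indices at which each coarse cell begins.
    """
    ny, nx = z.shape
    # Integer start index of each coarse cell
    rows = np.unique(np.linspace(0, ny, min(n_out_y, ny) + 1).astype(int)[:-1])
    cols = np.unique(np.linspace(0, nx, min(n_out_x, nx) + 1).astype(int)[:-1])
    # Sum over blocks of rows, then over blocks of columns
    sums = np.add.reduceat(np.add.reduceat(z, rows, axis=0), cols, axis=1)
    # Number of fine pixels in each coarse cell
    heights = np.diff(np.append(rows, ny))
    widths = np.diff(np.append(cols, nx))
    return rows, cols, sums / np.outer(heights, widths)


zz = np.sin(np.hypot(*np.meshgrid(np.arange(2000.0), np.arange(1000.0))) / 50.0)
rows, cols, zc = block_mean(zz, 128, 128)
print(zc.shape)
```

For a full view with the pixel-edge array `xx` from the code above, the coarse x edges would be `np.append(xx[cols], xx[-1])`, and likewise for y.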


# In[ ]:


# Make an initial low-resolution sampling of the image for plotting
x_init, y_init, z_init = resample_image([xx[0], xx[-1]], [yy[0], yy[-1]])
trace = dict(type='heatmap', x=x_init, y=y_init, z=z_init, zmin=cmin, zmax=cmax)
data=[trace]
f = go.FigureWidget(data=data)


# In[ ]:


# The function bound to the on_change callback
def update_image(layout, x_range, y_range):
    x_upd, y_upd, z_upd = resample_image(x_range, y_range)
    # Using f.update allows us here to update all x, y and z at the same time
    # We also apply the global colorbar limits to avoid the autoscaling of the colorbar as we zoom in
    # and out
    f.update({'data': [{'type':'heatmap', 'x':x_upd, 'y':y_upd, 'z':z_upd, 'zmin':cmin, 'zmax':cmax}]})


# In[ ]:


# Add a callback to update the view area
f.layout.on_change(update_image, 'xaxis.range', 'yaxis.range')


# In[ ]:


# Plot the Figure
f

Hi @nvaytet,

Wow, that is really cool. Thanks for sharing!

-Jon

Hi @jmmease,

Here is a new update on improving the performance of heatmap plots (by the way, I can’t tell you enough what an amazing job you guys are doing with plotly; you can basically do anything with it!)

Learning from the “Zoom on static images” example (https://plot.ly/python/images/#zoom-on-static-images), I am now displaying the heatmap as a background image, thus removing the need for the cumbersome code in my previous post. I am using the PIL library to convert the numpy array to an image.
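If pulling in matplotlib only for its colormap is undesirable, the array-to-RGBA step can be sketched with a small colour table and `np.interp`. The five anchor colours below are approximate Viridis stops chosen for illustration (an assumption, not matplotlib's exact table), and `apply_colormap` is a hypothetical helper. The rows are flipped with `np.flipud` because image row 0 is drawn at the top, whereas a Plotly heatmap draws row 0 at the bottom.

```python
import numpy as np

# Approximate Viridis anchor colours (RGB, 0-255) at evenly spaced positions
VIRIDIS_STOPS = np.array([
    [68, 1, 84],
    [59, 82, 139],
    [33, 145, 140],
    [94, 201, 98],
    [253, 231, 37],
], dtype=np.float64)


def apply_colormap(z, vmin, vmax, stops=VIRIDIS_STOPS):
    """Map a 2D array to an (ny, nx, 4) uint8 RGBA array by linear interpolation."""
    t = np.clip((np.asarray(z, dtype=np.float64) - vmin) / (vmax - vmin), 0.0, 1.0)
    pos = np.linspace(0.0, 1.0, len(stops))
    rgba = np.empty(t.shape + (4,), dtype=np.uint8)
    for ch in range(3):
        rgba[..., ch] = np.round(np.interp(t, pos, stops[:, ch])).astype(np.uint8)
    rgba[..., 3] = 255  # fully opaque
    # Flip vertically so row 0 of z ends up at the bottom of the image
    return np.flipud(rgba)


a = np.sin(np.hypot(*np.meshgrid(np.arange(1000.0), np.arange(500.0))) / 50.0)
rgba = apply_colormap(a, a.min(), a.max())
print(rgba.shape)
```

`Image.fromarray(apply_colormap(a, amin, amax))` could then stand in for the matplotlib lines in the code below.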

import numpy as np

N = 1000
M = 500
xx = np.arange(N, dtype=np.float64)
yy = np.arange(M, dtype=np.float64)
x, y = np.meshgrid(xx, yy)
b = N/20.0
c = M/2.0
r = np.sqrt(((x-c)/b)**2 + ((y-c)/b)**2)
a = np.sin(r)

# Limits
xmin = xx[0]
xmax = xx[-1]
ymin = yy[0]
ymax = yy[-1]
amin = np.amin(a)
amax = np.amax(a)

from PIL import Image
from matplotlib import cm
from matplotlib.colors import Normalize

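# Map the data to RGBA colours with matplotlib's colormap machinery
cNorm = Normalize(vmin=amin, vmax=amax)
scalarMap = cm.ScalarMappable(norm=cNorm, cmap='viridis')
seg_colors = scalarMap.to_rgba(a)
# Image row 0 is drawn at the top, so flip vertically to keep y increasing upwards
img = Image.fromarray(np.uint8(np.flipud(seg_colors) * 255))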

# Now the plotly code
import plotly.graph_objects as go

# Create figure
fig = go.Figure()

# Constants
img_width = 900
img_height = 600

# Add invisible scatter trace.
# This trace is added to help the autoresize logic work.
# We also add a color to the scatter points so we can have a colorbar next to our image
fig.add_trace(
    go.Scatter(
        x=[xmin, xmax],
        y=[ymin, ymax],
        mode="markers",
        marker={"color": [amin, amax],
                "colorscale": 'Viridis',
                "showscale": True,
                "colorbar": {"title": {"text": "Counts", "side": "right"}},
                "opacity": 0
               }
    )
)

# Add image
fig.update_layout(
    images=[go.layout.Image(
        x=xmin,
        sizex=xmax-xmin,
        y=ymax,
        sizey=ymax-ymin,
        xref="x",
        yref="y",
        opacity=1.0,
        layer="below",
        sizing="stretch",
        source=img)]
)

# Configure other layout
fig.update_layout(
    xaxis=dict(showgrid=False, zeroline=False, range=[xmin, xmax]),
    yaxis=dict(showgrid=False, zeroline=False, range=[ymin, ymax]),
    width=img_width,
    height=img_height,
)

fig.show()
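One detail of the image placement above: `xx` and `yy` are pixel centres, so an image spanning `xmin` to `xmax` is squeezed into a span one pixel narrower than the data, and its pixels are offset from the data coordinates by up to half a pixel. A small sketch that computes edge-aligned `x`, `y`, `sizex` and `sizey` values follows; `image_extent` is a hypothetical helper, and it assumes uniform spacing and the layout-image defaults `xanchor='left'` and `yanchor='top'`.

```python
import numpy as np


def image_extent(xc, yc):
    """Return layout-image placement for uniformly spaced pixel centres xc, yc.

    (x, y) is the top-left corner of the image in data units.
    """
    dx = (xc[-1] - xc[0]) / (len(xc) - 1)
    dy = (yc[-1] - yc[0]) / (len(yc) - 1)
    return dict(
        x=xc[0] - 0.5 * dx,       # left edge of the first column
        sizex=len(xc) * dx,       # total width of all columns
        y=yc[-1] + 0.5 * dy,      # top edge of the last row
        sizey=len(yc) * dy,       # total height of all rows
    )


extent = image_extent(np.arange(1000.0), np.arange(500.0))
print(extent)
```

These values can be unpacked into `go.layout.Image(**extent, ...)`, with axis ranges of `[extent['x'], extent['x'] + extent['sizex']]` and `[extent['y'] - extent['sizey'], extent['y']]` if the image should fill the plot exactly.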

Hope someone finds this useful.
