Spacing/Outline between pixels in px.imshow

Hello!

I’m trying to create a heatmap with Plotly for use in Dash, and decided to convert my data into a pivot table and display it as a heatmap using px.imshow.
I chose px.imshow over go.Heatmap because it keeps the pixel/cell sizes consistent regardless of the figure margins or the amount of data shown, which go.Heatmap doesn’t seem to be capable of.
The downside of not using go.Heatmap is that I seem to lose the convenient xgap and ygap options, which make it easier to tell similar-looking pixels apart.


This is my current code and output using px.imshow:

import plotly.express as px

# Size the figure at 30 px per row and per column
h = len(count_pivot.index) * 30
w = len(count_pivot.columns) * 30

fig = px.imshow(count_pivot, x=count_pivot.columns, y=count_pivot.index)
fig.update_xaxes(tickmode='linear', showgrid=False)
fig.update_yaxes(tickmode='linear', showgrid=False)
fig.update_layout({
    'plot_bgcolor': 'rgba(255, 255, 255, 1)',
    'paper_bgcolor': 'rgba(255, 255, 255, 1)',
}, height=h, width=w)
fig.show()


What I want, using xgap and ygap in go.Heatmap:


Is there a way to incorporate this same functionality into px.imshow either as gaps/paddings between pixels or as simple white outlines around each pixel?

Thank you!

Never mind! I just needed to run a couple of for loops that call fig.add_shape to draw the rectangles that make up the grid.

I took hypervalent’s idea and elaborated on it. Here’s code that visualizes a pandas.DataFrame named hm.

import plotly.express as px

# hm is an existing pandas.DataFrame
cell_size = 35         # px per cell
row_title_width = 200  # extra horizontal room for the row labels
width = cell_size*len(hm.columns) + row_title_width
height = cell_size*len(hm.index)

fig = px.imshow(hm, width=width, height=height, color_continuous_scale='RdYlBu_r')

# Vertical white lines along the right edge of each column
for i in range(len(hm.columns)):
    fig.add_shape(type="line", x0=0.5 + i, y0=-0.5, x1=0.5 + i, y1=len(hm.index) - 0.5, line=dict(color="white", width=2))

# Horizontal white lines along the lower edge of each row
for i in range(len(hm.index)):
    fig.add_shape(type="line", x0=-0.5, y0=0.5 + i, x1=len(hm.columns) - 0.5, y1=0.5 + i, line=dict(color="white", width=2))

fig.show()

You can use the xgap and ygap parameters in fig.update_traces. For 2D data, px.imshow draws a heatmap trace by default, so these options are still available.