Consider the following code:
```python
import pandas as pd
import numpy as np
import plotly.express as px
# Create the index for the data frame
x = np.linspace(-1, 1, 6)
y = np.linspace(-1, 1, 6)
n_channel = [1, 2, 3, 4]
xx, yy = np.meshgrid(x, y)
zzz = np.random.randn(len(y)*len(n_channel), len(x))
df = pd.DataFrame(
    zzz,
    columns = pd.Index(x, name='x (m)'),
    index = pd.MultiIndex.from_product([y, n_channel], names=['y (m)', 'n_channel']),
)
print(df)
fig = px.imshow(
    df.reset_index('n_channel'),
    facet_col = 'n_channel',
)
fig.write_html(
    'plot.html',
    include_plotlyjs = 'cdn',
)
```
The data frame looks like this:
```text
x (m)                 -1.0      -0.6      -0.2       0.2       0.6       1.0
y (m) n_channel                                                            
-1.0  1          -0.492584  0.599464  0.097405 -0.177793 -0.027311  1.468527
      2           0.202147  0.449809 -2.047460 -1.392223  0.245228  1.220419
      3           0.139111 -0.699596  1.754103 -0.141732 -1.494373 -0.003184
      4           0.124390  0.245113 -0.031949  1.938560  1.418563 -0.787295
-0.6  1           1.112547  0.307750 -1.206242 -0.739546  0.038905 -0.923485
      2          -0.900733 -1.094717  0.770876 -1.973305  2.677651  3.072124
      3          -0.279864 -1.341024  2.750811 -1.401604  0.929714  0.658087
      4          -1.038905 -1.038625  0.112878  1.112139 -0.799305 -0.934813
-0.2  1           0.332704  1.321129  0.241799 -1.100657 -0.927649 -1.928624
      2          -0.576210  0.257960 -0.196699 -0.245751  0.575648 -0.703353
      3          -0.549881 -1.208282  0.959120  1.852333  1.452697 -0.562802
      4          -0.433256 -0.339644 -1.636592 -1.022501 -0.614497  1.085253
0.2   1           0.378474 -0.829495 -1.313322 -0.654698 -0.644115  2.175938
      2           0.567393 -0.340301  1.304942  0.197879  0.309288 -0.126187
      3           0.209954  0.161299 -0.362754 -0.328356 -0.106934 -0.238329
      4          -0.284447 -0.367920 -0.275830 -0.776649  0.656279  0.056389
0.6   1           1.174153 -1.112658  1.245117 -0.395144  0.471050  0.165074
      2          -0.220246  1.063194  0.292873  0.266250 -0.175274  0.225985
      3           0.301462  0.737581  0.271691  0.936558  1.007112  1.857389
      4          -0.689441  3.369569  0.675700  0.077706  0.152062 -0.533258
1.0   1           0.732183  0.041873  1.156681  0.841262 -0.984433  1.313900
      2           0.157533  0.723356 -0.786721  0.150939  0.164049 -0.351816
      3          -0.390037 -1.513096  0.255813 -1.365759  0.570145  1.630885
      4           0.318037 -1.103191  1.472340 -0.218038  0.990673 -1.565340
```
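Because the index was built with `MultiIndex.from_product([y, n_channel])`, the rows are ordered y-major and channel-minor. Assuming that ordering is kept, the values can be reshaped into a `(n_channel, y, x)` cube, which is the layout that `px.imshow(cube, facet_col=0)` slices along. A sketch using only pandas and NumPy, rebuilding a frame of the same shape with random values:

```python
import numpy as np
import pandas as pd

# Frame with the same layout as in the question (random values).
x = np.linspace(-1, 1, 6)
y = np.linspace(-1, 1, 6)
n_channel = [1, 2, 3, 4]
df = pd.DataFrame(
    np.random.randn(len(y) * len(n_channel), len(x)),
    columns=pd.Index(x, name='x (m)'),
    index=pd.MultiIndex.from_product([y, n_channel], names=['y (m)', 'n_channel']),
)

# Row r corresponds to (y index, channel index) = divmod(r, len(n_channel)),
# so reshape to [y, channel, x], then move the channel axis to the front.
cube = df.to_numpy().reshape(len(y), len(n_channel), len(x)).transpose(1, 0, 2)

# cube[i] is now the 2D (y, x) image for channel n_channel[i].
```

If the row order is not guaranteed, `np.stack([df.xs(n, level='n_channel').to_numpy() for n in n_channel])` builds the same cube by selecting each channel explicitly.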
and I expect it to produce 4 heatmaps, one per n_channel value, each of them an x (m) vs. y (m) heatmap like the single-channel ones produced by the loop below,
but instead I get `AttributeError: 'DataFrame' object has no attribute 'dims'`. If I instead do it like this:
```python
for n in n_channel:
    fig = px.imshow(
        df.query(f'n_channel=={n}').reset_index('n_channel', drop=True),
    )
    fig.write_html(
        f'plot_{n}.html',
        include_plotlyjs = 'cdn',
    )
```
then I do get the 4 plots, but as separate figures and (of course) without linked axes.
Is it possible to use the facet_row (or facet_col) argument based on one column, as can be done with e.g. px.scatter?