How to use the facet_row argument of plotly.express.imshow?

Consider the following code:

import pandas as pd
import numpy as np
import plotly.express as px

# Create the index for the data frame
x = np.linspace(-1,1, 6)
y = np.linspace(-1,1,6)
n_channel = [1, 2, 3, 4]

xx, yy = np.meshgrid(x, y)

zzz = np.random.randn(len(y)*len(n_channel),len(x))

df = pd.DataFrame(
	zzz,
	columns = pd.Index(x, name='x (m)'),
	index = pd.MultiIndex.from_product([y, n_channel], names=['y (m)', 'n_channel']),
)

print(df)

# This call fails with AttributeError: 'DataFrame' object has no attribute 'dims'
fig = px.imshow(
	df.reset_index('n_channel'),
	facet_col = 'n_channel',
)
fig.write_html(
	'plot.html',
	include_plotlyjs = 'cdn',
)
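Every row of df is labelled by a (y, n_channel) pair, so the same data can also be held as one 2-D image per channel. The sketch below rebuilds a frame with the same layout (random values, so the numbers will differ from the printout) and stacks the per-channel slices into a 3-D NumPy array of shape (channel, y, x). The name cube is only illustrative.

```python
import numpy as np
import pandas as pd

x = np.linspace(-1, 1, 6)
y = np.linspace(-1, 1, 6)
n_channel = [1, 2, 3, 4]

# Same layout as above: rows indexed by (y, n_channel), columns by x
df = pd.DataFrame(
    np.random.randn(len(y) * len(n_channel), len(x)),
    columns=pd.Index(x, name='x (m)'),
    index=pd.MultiIndex.from_product([y, n_channel], names=['y (m)', 'n_channel']),
)

# df.xs picks every row of one channel and drops the n_channel level,
# leaving a (y, x) table; np.stack puts the channels on a new leading axis
cube = np.stack([df.xs(n, level='n_channel').to_numpy() for n in n_channel])
print(cube.shape)  # (4, 6, 6)
```

Selecting each channel by label with df.xs, rather than reshaping df.to_numpy() directly, does not depend on the order in which the rows happen to be stored.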

The data frame looks like this:

x (m)                -1.0      -0.6      -0.2       0.2       0.6       1.0
y (m) n_channel                                                            
-1.0  1         -0.492584  0.599464  0.097405 -0.177793 -0.027311  1.468527
      2          0.202147  0.449809 -2.047460 -1.392223  0.245228  1.220419
      3          0.139111 -0.699596  1.754103 -0.141732 -1.494373 -0.003184
      4          0.124390  0.245113 -0.031949  1.938560  1.418563 -0.787295
-0.6  1          1.112547  0.307750 -1.206242 -0.739546  0.038905 -0.923485
      2         -0.900733 -1.094717  0.770876 -1.973305  2.677651  3.072124
      3         -0.279864 -1.341024  2.750811 -1.401604  0.929714  0.658087
      4         -1.038905 -1.038625  0.112878  1.112139 -0.799305 -0.934813
-0.2  1          0.332704  1.321129  0.241799 -1.100657 -0.927649 -1.928624
      2         -0.576210  0.257960 -0.196699 -0.245751  0.575648 -0.703353
      3         -0.549881 -1.208282  0.959120  1.852333  1.452697 -0.562802
      4         -0.433256 -0.339644 -1.636592 -1.022501 -0.614497  1.085253
 0.2  1          0.378474 -0.829495 -1.313322 -0.654698 -0.644115  2.175938
      2          0.567393 -0.340301  1.304942  0.197879  0.309288 -0.126187
      3          0.209954  0.161299 -0.362754 -0.328356 -0.106934 -0.238329
      4         -0.284447 -0.367920 -0.275830 -0.776649  0.656279  0.056389
 0.6  1          1.174153 -1.112658  1.245117 -0.395144  0.471050  0.165074
      2         -0.220246  1.063194  0.292873  0.266250 -0.175274  0.225985
      3          0.301462  0.737581  0.271691  0.936558  1.007112  1.857389
      4         -0.689441  3.369569  0.675700  0.077706  0.152062 -0.533258
 1.0  1          0.732183  0.041873  1.156681  0.841262 -0.984433  1.313900
      2          0.157533  0.723356 -0.786721  0.150939  0.164049 -0.351816
      3         -0.390037 -1.513096  0.255813 -1.365759  0.570145  1.630885
      4          0.318037 -1.103191  1.472340 -0.218038  0.990673 -1.565340

and I expect it to produce 4 heatmaps, each of them similar to this one:

[Image: expected heatmap for a single channel]

but instead I get AttributeError: 'DataFrame' object has no attribute 'dims'. If I do this instead:

for n in n_channel:
	fig = px.imshow(
		df.query(f'n_channel=={n}').reset_index('n_channel', drop=True),
	)
	fig.write_html(
		f'plot_{n}.html',
		include_plotlyjs = 'cdn',
	)

then I get the 4 plots, but as separate files and (of course) with unlinked axes.

Is it possible to use the facet_row (or facet_col) argument with one column of the data frame, as can be done with, e.g., px.scatter?