px.Scatter Color attribute with Categorical Pandas column

I have just upgraded my Anaconda environment (JupyterLab, Plotly, and pandas), and all of a sudden the color attribute in px.scatter, which had been working fine with a pandas Categorical dtype column, has broken. I haven't been able to figure out why. If I convert the column to string, it works fine. The error is a KeyError: in a filtered DataFrame, the key doesn't exist. When I use the full (unfiltered) DataFrame, the categorical column works fine as the color attribute. The error happens whether I filter the DataFrame inside the px.scatter call or create a new DataFrame outside the call, and also when I make a copy of the DataFrame outside the call. Thoughts?

Sample below (DataFrame and column names have been changed for simplicity):

Sample px.scatter call:

import plotly.express as px

fig = px.scatter(
    data_frame=df.loc[df['filter_col'] == filter],  # filter holds the value being selected
    x='Age',
    y='Size',
    color='Division',  # <-- Division is a Categorical data type in the dataframe
)
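Since converting the column to string works, another option that keeps the Categorical dtype (and its category order) may be to drop unused categories from the filtered frame before plotting. A minimal sketch, assuming the KeyError comes from a category that has no rows left after filtering; drop_unused_categories is a hypothetical helper and the data is made up:

```python
import pandas as pd


def drop_unused_categories(frame: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical helper: return a copy whose categorical columns list only observed values."""
    out = frame.copy()
    for col in out.select_dtypes(include="category").columns:
        # Unlike astype(str), this keeps the dtype and the original category order
        out[col] = out[col].cat.remove_unused_categories()
    return out


# Hypothetical data standing in for the real df
df = pd.DataFrame(
    {
        "filter_col": ["a", "a", "b", "b"],
        "Age": [30, 41, 25, 52],
        "Size": [1.2, 3.4, 2.2, 0.8],
        "Division": pd.Categorical(["Division1", "Division1", "Division2", "Division3"]),
    }
)

subset = drop_unused_categories(df.loc[df["filter_col"] == "b"])
print(list(subset["Division"].cat.categories))  # ['Division2', 'Division3']

# The cleaned frame would then replace the filtered one in the call, e.g.:
# fig = px.scatter(data_frame=subset, x="Age", y="Size", color="Division")
```

Plotly itself is not exercised here; the sketch only shows that the filtered column no longer declares Division1.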

Error:

File /opt/anaconda3/envs/jupyterlab/lib/python3.10/site-packages/plotly/express/_chart_types.py:66, in scatter(data_frame, x, y, color, symbol, size, hover_name, hover_data, custom_data, text, facet_row, facet_col, facet_col_wrap, facet_row_spacing, facet_col_spacing, error_x, error_x_minus, error_y, error_y_minus, animation_frame, animation_group, category_orders, labels, orientation, color_discrete_sequence, color_discrete_map, color_continuous_scale, range_color, color_continuous_midpoint, symbol_sequence, symbol_map, opacity, size_max, marginal_x, marginal_y, trendline, trendline_options, trendline_color_override, trendline_scope, log_x, log_y, range_x, range_y, render_mode, title, template, width, height)
     12 def scatter(
     13     data_frame=None,
     14     x=None,
   (...)
     60     height=None,
     61 ) -> go.Figure:
     62     """
     63     In a scatter plot, each row of `data_frame` is represented by a symbol
     64     mark in 2D space.
     65     """
---> 66     return make_figure(args=locals(), constructor=go.Scatter)

File /opt/anaconda3/envs/jupyterlab/lib/python3.10/site-packages/plotly/express/_core.py:2003, in make_figure(args, constructor, trace_patch, layout_patch)
   1999 trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config(
   2000     args, constructor, trace_patch, layout_patch
   2001 )
   2002 grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
-> 2003 groups, orders = get_groups_and_orders(args, grouper)
   2005 col_labels = []
   2006 row_labels = []

File /opt/anaconda3/envs/jupyterlab/lib/python3.10/site-packages/plotly/express/_core.py:1978, in get_groups_and_orders(args, grouper)
   1975                 g.insert(i, "")
   1976     full_sorted_group_names = [tuple(g) for g in full_sorted_group_names]
-> 1978     groups = {
   1979         sf: grouped.get_group(s if len(s) > 1 else s[0])
   1980         for sf, s in zip(full_sorted_group_names, sorted_group_names)
   1981     }
   1982 return groups, orders

File /opt/anaconda3/envs/jupyterlab/lib/python3.10/site-packages/plotly/express/_core.py:1979, in <dictcomp>(.0)
   1975                 g.insert(i, "")
   1976     full_sorted_group_names = [tuple(g) for g in full_sorted_group_names]
   1978     groups = {
-> 1979         sf: grouped.get_group(s if len(s) > 1 else s[0])
   1980         for sf, s in zip(full_sorted_group_names, sorted_group_names)
   1981     }
   1982 return groups, orders

File /opt/anaconda3/envs/jupyterlab/lib/python3.10/site-packages/pandas/core/groupby/groupby.py:747, in BaseGroupBy.get_group(self, name, obj)
    745 inds = self._get_index(name)
    746 if not len(inds):
--> 747     raise KeyError(name)
    749 return obj._take_with_is_copy(inds, axis=self.axis)

KeyError: 'Division1' **<-----This key doesn't exist in the filtered dataframe, but exists in the full dataframe**
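Filtering a DataFrame removes rows but not categories, so the filtered Division column still declares Division1 even though no row holds it. The traceback shows plotly calling grouped.get_group(...) once per group name, and pandas raises KeyError when a group has no rows. A pandas-only sketch of that behavior with hypothetical data; it assumes, without checking plotly's source, that plotly's group names include every declared category:

```python
import pandas as pd

# Hypothetical stand-in for the real data
df = pd.DataFrame(
    {
        "filter_col": ["a", "a", "b", "b"],
        "Division": pd.Categorical(["Division1", "Division1", "Division2", "Division3"]),
    }
)
filtered = df.loc[df["filter_col"] == "b"]

# Filtering drops rows, not categories: 'Division1' is still declared
print(list(filtered["Division"].cat.categories))
# ['Division1', 'Division2', 'Division3']

# observed=False (the pandas 1.x default) reports every declared category, even empty ones
print(filtered.groupby("Division", observed=False).size().to_dict())
# {'Division1': 0, 'Division2': 1, 'Division3': 1}

# observed=True groups only the categories that actually occur
print(filtered.groupby("Division", observed=True).size().to_dict())
# {'Division2': 1, 'Division3': 1}
```

If this is the cause, calling .cat.remove_unused_categories() on the filtered column (or converting it to string, as noted earlier) leaves only categories that have rows.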

Environment upgrades:

  • Python = 3.9.7 → 3.10.4
  • JupyterLab = 3.2.9 → 3.4.4
  • Plotly = 5.7 → 5.10
  • Pandas = 1.3.4 → 1.4.4