Customized Selection Under Color-Grouping Highlights Most Recently Selected Group Only

Hi!

I’m trying to create an interactive scatterplot-matrix using Plotly Express and Dash.
One of the components I want to customize is the box-selection - in particular, I want to be able to use group-semantics to, for example, highlight the union of the two most recent selections.

I’ve actually been able to implement this, but for some reason, only points that belong the the most recently selected groups are highlighted, even though in the figure’s data, they are listed under each group’s ‘selectedpoints’.

Am I doing something wrong, or is this behavior intended?

Here’s a video of the behavior: https://youtu.be/TXlg4yL6bTw

And my code (I’ve highlighted the relevant parts with a bunch of #-symbols):

Code
import numpy as np

import plotly.express as px

import dash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output

oldPoints = []


def Splom_app(df, axes_options, color_by_options):
    """
    This function returns a dash-app that shown an interactive SPLOM with the following functionality:
    A fully linked, brushable SPLOM.
    A dropdown-menu to select the SPLOM's basis.
    A dropdown-menu to select the selection semantic (e.g. union, complement).
    A slider to adjust the alpha/opacity of the points.

    Attributes:
        df (pd.DataFrame): The pandas-Dataframe that holds all of the relevant data.
        axes_options (List<List<String>>): The bases that the user can choose from.
                                           The Strings represent the corresponding column-names of the dataframe.
                                           By default, the first list is chosen the basis.
        color_by_options (List<String>): The attributes that the user can choose to color the points by.
                                         The Strings represent the corresponding column-names of the dataframe.
                                         By default, the first attribute is chosen.
    """
    # initialize the dashboard
    app = dash.Dash()

    # the entries of the options of the dropdown-component of dash have to be formatted very specifically
    # this allows the user to just give a list of lists for the axes-options
    # the values of the dropdown have to be hashable, so another list of the actual df-columns has to be initialized
    dropdown_axes_options = []
    axes_list = []
    i = 0
    for option in axes_options:
        option_line = []
        for attribute in option:
            option_line.append({'label': attribute, 'values': df[attribute]})
        axes_list.append(option_line)
        dropdown_axes_options.append({'label': ', '.join(option), 'value': i})
        i = i + 1

    dropdown_color_options = []
    for attribute in color_by_options:
        dropdown_color_options.append({'label': attribute, 'value': attribute})

    dropdown_semantic_options = []
    for option in [
            'replace', 'complement', 'union', 'intersection', 'set-minus'
    ]:
        dropdown_semantic_options.append({'label': option, 'value': option})

    # use the first value of the options by default
    def_axes = dropdown_axes_options[0].get('value')
    def_color = dropdown_color_options[0].get('value')

#  # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
    # create the SPLOM using plotly express
    fig = px.scatter_matrix(df,
                            dimensions=axes_options[0],
                            color=def_color,
                            opacity=0.6)
#  # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

    # this prevents changes resetting the layout/selection
    fig.layout.selectionrevision = True
    fig.layout.uirevision = True

    # define the app-layout
    app.layout = html.Div([
        dcc.Dropdown(id='axes',
                     options=dropdown_axes_options,
                     multi=False,
                     value=def_axes),
        dcc.Dropdown(id='color_by',
                     options=dropdown_color_options,
                     multi=False,
                     value=def_color),
        dcc.Slider(id='alpha', min=0.1, max=1.0, value=0.6, step=0.1),
        dcc.Dropdown(id='selection_semantic',
                     options=dropdown_semantic_options,
                     multi=False,
                     value=dropdown_semantic_options[0].get('value')),
        dcc.Graph(id='SPLOM')
    ])

    # oldPoints, the previous selection, is a global variable
    global oldPoints
    oldPoints = []
    for group in fig.data:
        oldPoints.append(list(range(len(group.dimensions[0].values))))

    # all points should be selected by default
    for i, group in enumerate(fig.data):
        group.update(selectedpoints=oldPoints[i],
                     selected=dict(marker=dict(opacity=0.6)),
                     unselected=dict(marker=dict(opacity=0.15, color='grey')))

    # define the callback that triggers on updates
    @app.callback(Output(component_id='SPLOM', component_property='figure'), [
        Input(component_id='axes', component_property='value'),
        Input(component_id='alpha', component_property='value'),
        Input('SPLOM', 'selectedData'),
        Input('color_by', 'value'),
        Input('selection_semantic', 'value')
    ])
    def update(axes_selected, alpha, selectedData, color_by,
               selection_semantic):
        ### get the component's id that triggered the update
        # get the name of the component that triggered the callback,
        # from https://dash.plotly.com/advanced-callbacks
        component = dash.callback_context.triggered[0]['prop_id'].split('.')[0]

        # update the corresponding part of the visualization
        if component == 'axes':
            fig.layout.uirevision = False
            fig.layout.selectionrevision = True
            axes = axes_list[axes_selected]
            fig.update_traces(dimensions=axes, overwrite=True)

        elif component == 'alpha':
            fig.layout.uirevision = True
            fig.layout.selectionrevision = True
            for group in fig.data:
                group.update(selected=dict(marker=dict(opacity=alpha)),
                             unselected=dict(marker=dict(opacity=alpha / 3.0)))

#  # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
        elif component == 'SPLOM':
            fig.layout.uirevision = True
            fig.layout.selectionrevision = True

            # oldPoints is a global variable
            global oldPoints
            global group_dict

            ### compute the selected points under the given selection semantic
            # calculate each group (given by its curve-number) individually
            selectedPoints = []
            for group in fig.data:
                selectedPoints.append([])
            for point in selectedData['points']:
                selectedPoints[point['curveNumber']].append(
                    point['pointNumber'])

            for i, group in enumerate(fig.data):
                # selection semantic is complement of the current selection
                if selection_semantic == 'complement':
                    all_points = list(range(len(group.dimensions[0].values)))
                    selectedPoints[i] = [
                        item for item in all_points
                        if item not in selectedPoints[i]
                    ]

                # selection semantic is union of the current selection and the previous one
                elif selection_semantic == 'union':
                    selectedPoints[i] = selectedPoints[i] + oldPoints[i]

                # selection semantic is intersection of the current selection and the previous one
                elif selection_semantic == 'intersection':
                    selectedPoints[i] = [
                        item for item in selectedPoints[i]
                        if item in oldPoints[i]
                    ]

                # selection semantic is the previous selection without the current selection
                elif selection_semantic == 'set-minus':
                    selectedPoints[i] = [
                        item for item in oldPoints[i]
                        if item not in selectedPoints[i]
                    ]
                    # if after a set-minus operation there are no points selected, select all
                    length = 0
                    for points in selectedPoints:
                        length = length + len(points)
                    if length == 0:
                        for i in range(len(selectedPoints)):
                            selectedPoints = list(
                                range(len(fig.data[i].dimensions[0].values)))

            # update the selections of each group
            for i, group in enumerate(fig.data):
                factor = group['name']
                group.update(selectedpoints=selectedPoints[i], overwrite=True)

            # save the current selection for future reference
            oldPoints = selectedPoints
#  # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

        return fig

    return app

Thanks!