How to draw a box to delete points in a 2D scatter plot

I created a 2D scatter plot and loaded a jpg image for each point using Plotly Dash. I can view a different image when the mouse hovers over each point. However, I would like to add more interactivity to the plot. Specifically, I want to draw a box around a cluster of points and remove those points, together with their associated jpg images, from the plot. Does anybody know how to do that?

Many thanks.

Hi @tianjuntommy,

You can use the Box Select and Lasso Select tools at the top right corner of your Plotly graph, then read the selectedData property of the dcc.Graph component to filter your dataframe. See the full documentation for selectedData here.

Also, see below a basic working example using the Iris dataset. It may not be the most elegant way, but it does the job.

from dash import Dash, dcc, html, Input, Output
import plotly.express as px

app = Dash(__name__)

df = px.data.iris()

fig = px.scatter(
        df, x="sepal_width", y="sepal_length", 
        color="species", size='petal_length', 
        hover_data=['petal_width'])

app.layout = html.Div([
    html.H4('Interactive scatter plot with Iris dataset'),
    dcc.Graph(id="scatter-plot", figure = fig),
    html.Button('Reset plot', id='reset-button', n_clicks=0)
])

@app.callback(Output('scatter-plot', 'figure'),
              Output('reset-button', 'n_clicks'),
              Input('reset-button', 'n_clicks'),
              Input('scatter-plot', 'selectedData'),
             prevent_initial_call = True)
def display_selected_data(n_clicks, selectedData):
    if n_clicks > 0:
        fig = px.scatter(
                    df, x="sepal_width", y="sepal_length", 
                    color="species", size='petal_length', 
                    hover_data=['petal_width'])
        return [fig, 0]
    else:
        # selectedData is None when a selection is cleared, so fall back to no points
        points = selectedData["points"] if selectedData else []
        x = [p["x"] for p in points]
        y = [p["y"] for p in points]
        # Collect the index of every row whose coordinates match a selected point
        rows_to_remove = []
        for i in range(len(x)):
            rows_to_remove.extend(df[(df['sepal_width'] == x[i]) & (df['sepal_length'] == y[i])].index.tolist())
        filtered_df = df.drop(rows_to_remove, axis = 0)
        
        fig = px.scatter(
                        filtered_df,
                        x="sepal_width", y="sepal_length", 
                        color="species", size='petal_length', 
                        hover_data=['petal_width'])
        return [fig, 0]

# Older Dash releases use app.run_server(debug=True) instead
app.run(debug=True)
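
A possible variation on the example above, as a minimal sketch rather than a definitive implementation: instead of matching selected points back to rows by comparing x and y values, each point can carry a row identifier. This assumes the figure is built with px.scatter(..., custom_data=["row_id"]), where row_id is a hypothetical column holding a unique id per row; each entry of selectedData["points"] then includes a customdata list. The records and the selectedData dictionary below are made up for illustration, and plain Python lists stand in for the dataframe.

```python
def selected_row_ids(selected_data):
    """Collect the row ids carried in each selected point's customdata."""
    if not selected_data:
        # selectedData is None when nothing is selected
        return []
    return [p["customdata"][0] for p in selected_data.get("points", [])]


# Hypothetical records standing in for dataframe rows
rows = [
    {"row_id": 0, "sepal_width": 3.5, "sepal_length": 5.1},
    {"row_id": 1, "sepal_width": 3.0, "sepal_length": 4.9},
    {"row_id": 2, "sepal_width": 3.2, "sepal_length": 4.7},
]

# Abridged shape of selectedData after selecting a single point
selected = {"points": [{"x": 3.0, "y": 4.9, "customdata": [1]}]}

ids_to_remove = set(selected_row_ids(selected))
kept = [r for r in rows if r["row_id"] not in ids_to_remove]
```

With a dataframe, the same ids could be used in a filter such as df[~df["row_id"].isin(ids_to_remove)], which avoids comparing floating-point coordinates.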

Hi @jhupiterz ,

Thank you for the example code. I followed your instructions and wrote similar code for my case, but I still have some problems:

  1. In the function display_selected_data, I can remove the selected data points from my lists MJD and DM. However, I would like to keep removing data points with further selections. I changed the line ‘MJD_filtered = list(set(MJD) - set(x))’ to ‘MJD = list(set(MJD) - set(x))’, but got the error ‘cannot access local variable ‘MJD’ where it is not associated with a value’. Do you know how to solve this?

  2. I also added a function display_hover so that the associated image file pops up when I hover the mouse over a data point. However, this means that every time I delete some data points, I also need to update the list of paths to the image files. Do you know how to keep the three lists, i.e. MJD, DM and cand_plots, consistent throughout the process of deleting data points?

Please see my code below:

import plotly.express as px
from dash import Dash, dcc, html, Input, Output, no_update, callback
import plotly.graph_objects as go
import base64
import json

fig = go.Figure(data=[go.Scatter(x=MJD, y=DM, mode="markers",)])
fig.update_traces(hoverinfo="none", hovertemplate=None)
fig.update_layout(xaxis=dict(title='MJD'),yaxis=dict(title='DM'))
app = Dash(__name__)
app.layout = html.Div([dcc.Graph(id="graph-basic-2", figure=fig, clear_on_unhover=True),
                       dcc.Tooltip(id="graph-tooltip"),
                       dcc.Store(id= 'data-selection'),
                       html.Button('Reset plot', id='reset-button', n_clicks=0)
                      ])
@app.callback(Output('graph-basic-2', 'figure'),
              Output('reset-button', 'n_clicks'),
              Input('reset-button', 'n_clicks'),
              Input('graph-basic-2', 'selectedData'),
             prevent_initial_call = True
             )
def display_selected_data(n_clicks, selectedData):
    if n_clicks > 0:
        fig = go.Figure(data=[go.Scatter(x=MJD, y=DM, mode="markers",)])
        return [fig, 0]
    else: 
        x = [x["x"] for x in selectedData["points"]]
        y = [x["y"] for x in selectedData["points"]]
        MJD_filtered = list(set(MJD) - set(x))
        DM_filtered = list(set(DM) - set(y))
        fig = go.Figure(data=[go.Scatter(x=MJD_filtered, y=DM_filtered, mode="markers",)])
        return [fig, 0]
@callback(
    Output("graph-tooltip", "show"),    
    Output("graph-tooltip", "bbox"),
    Output("graph-tooltip", "children"),
    Input("graph-basic-2", "hoverData"),
)
def display_hover(hoverData):
    if hoverData is None:
        return False, no_update, no_update
    pt = hoverData["points"][0]
    bbox = pt["bbox"]
    num = pt["pointNumber"]
    with open(cand_plots[num], "rb") as image_file:
        img_data = base64.b64encode(image_file.read())
        img_data = img_data.decode()
        img_data = "{}{}".format("data:image/jpg;base64, ", img_data)
    children = [
        html.Div([
            html.Img(src=img_data, style={"width": "100%"}),
        ], style={'width': '200px', 'white-space': 'normal'})
    ]

    return True, bbox, children
app.run(debug=True)
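
One possible way to approach both problems, offered as a sketch under assumptions rather than a tested fix for the app above. The error in point 1 is what Python raises when a name is assigned inside a function: MJD becomes local to display_selected_data, so set(MJD) on the right-hand side is read before it has a value. Separately, taking set differences drops duplicates and ordering, and filtering MJD and DM independently breaks the pairing between them. An alternative is to leave the three lists untouched and track only the original indices of the points still shown. The helper name remove_selected and the sample values below are hypothetical; the sketch relies on pointNumber being the position of a point within the trace as currently drawn.

```python
def remove_selected(remaining, selected_data):
    """Return the original indices still shown after deleting the selection.

    remaining: original indices of the points currently plotted, in trace order.
    selected_data: the graph's selectedData property (None if nothing is selected).
    """
    if not selected_data or not selected_data.get("points"):
        return list(remaining)
    # pointNumber is a position in the trace as it is currently drawn
    drop = {p["pointNumber"] for p in selected_data["points"]}
    return [idx for pos, idx in enumerate(remaining) if pos not in drop]


# Hypothetical stand-ins for the MJD, DM and cand_plots lists
MJD = [59000.1, 59000.2, 59000.3, 59000.4]
DM = [10.0, 20.0, 30.0, 40.0]
cand_plots = ["a.jpg", "b.jpg", "c.jpg", "d.jpg"]

remaining = list(range(len(MJD)))
# Two successive deletions, each expressed in positions of the current plot
remaining = remove_selected(remaining, {"points": [{"pointNumber": 1}]})
remaining = remove_selected(remaining, {"points": [{"pointNumber": 2}]})

# All three lists are sliced with the same indices, so they stay aligned
shown_mjd = [MJD[i] for i in remaining]
shown_dm = [DM[i] for i in remaining]
shown_plots = [cand_plots[i] for i in remaining]
```

In the Dash app, remaining could be kept in the dcc.Store that the layout already declares (read as a State and written as an Output of the selection callback), and display_hover would then open cand_plots[remaining[num]] instead of cand_plots[num].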