SelectedData on 3D Scatterplot

I’m trying to return `selectedData` when clicking on points of a scatter plot, but it only works with the 2D scatter plot. Below is my code:

import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import dash
import dash_html_components as html
import dash_core_components as dcc
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc
from dash.exceptions import PreventUpdate

def _read_embedding(csv_url):
    """Download one embedding table and cast its 'clusters' column to str.

    Casting to str makes plotly express treat the cluster ids as discrete
    categories when they are used as the ``color`` column.
    """
    frame = pd.read_csv(csv_url)
    frame['clusters'] = frame['clusters'].astype(str)
    return frame


# The two 2D/3D embeddings offered by the first radio group.
UMAPdf = _read_embedding('https://raw.githubusercontent.com/hoatranobita/Test-data/main/UMAPdf.csv')
TSNEdf = _read_embedding('https://raw.githubusercontent.com/hoatranobita/Test-data/main/TSNEdf.csv')

# suppress_callback_exceptions lets callbacks reference ids that are not yet
# present in the initial layout.
app = dash.Dash(
    __name__,
    external_stylesheets=[dbc.themes.LUX],
    suppress_callback_exceptions=True,
)

def _layout_width(size):
    """Return the Bootstrap column width spec used throughout the layout."""
    return {'size': size, 'offset': 0, 'order': 1}


def _layout_radio_col(component_id, choices, default):
    """Build a centred column holding an inline radio group whose labels equal its values."""
    return dbc.Col([
        dbc.RadioItems(
            options=[{"label": choice, "value": choice} for choice in choices],
            value=default,
            id=component_id,
            inline=True,
        ),
    ], width=_layout_width(2), style={'text-align': 'center'})


def _layout_axis_col(title, dropdown_id):
    """Build a column with an axis caption next to an (initially empty) single-select dropdown."""
    caption = dbc.Col([
        html.H6(title, style={'padding-top': 10, 'padding-right': 2}),
    ], width=_layout_width(4))
    picker = dbc.Col([
        dcc.Dropdown(id=dropdown_id,
                     options=[],
                     value=[],
                     multi=False,
                     disabled=False,
                     clearable=False,
                     searchable=True),
    ], width=_layout_width(8), style={'text-align': 'center'})
    return dbc.Col([dbc.Row([caption, picker])],
                   width=_layout_width(2), style={'text-align': 'center'})


def _layout_toolbar(specs):
    """Build a right-aligned row of (label, button id, download id) export buttons."""
    controls = []
    for label, button_id, download_id in specs:
        controls.append(dbc.Button(label, size="sm", className="me-1", id=button_id, color="secondary"))
        controls.append(dcc.Download(id=download_id))
    return dbc.Row([
        dbc.Col(controls, width=_layout_width(12), style={'text-align': 'right'}),
    ])


# Top control bar: embedding choice, 2D/3D switch, and the three axis pickers.
_layout_controls = dbc.Row([
    _layout_radio_col("radioitems-input_1", ['UMAP', 'tSNE'], 'UMAP'),
    _layout_radio_col("radioitems-input_2", ['2D', '3D'], '2D'),
    _layout_axis_col('X_AXIS', "x_axis"),
    _layout_axis_col('Y_AXIS', "y_axis"),
    _layout_axis_col('Z_AXIS', "z_axis"),
], className='p-2 align-items-stretch')

# Left card: the scatter plot. Its selectedData is pre-seeded with one point
# so the violin panel has an Id to show on first load.
_layout_scatter_graph = dcc.Graph(
    id='scatter_chart',
    figure={},
    style={'height': '450px'},
    selectedData={'points': [{'hovertext': 'X24_CTGACACAATGC'}]},
)
_layout_scatter_card = dbc.Card([
    dbc.CardBody([
        _layout_toolbar([("SVG", 'btn_1', 'download_1'),
                         ("HTML", 'btn_2', 'download_2'),
                         ("csv", 'btn_3', 'download_3')]),
        dbc.Row([
            dbc.Col([html.Div(id='chart_title')],
                    width=_layout_width(12), style={'text-align': 'center'}),
        ]),
        dbc.Row([
            dbc.Col([
                dcc.Loading(children=[_layout_scatter_graph], color='#119DFF', type='dot'),
            ], width=_layout_width(12)),
        ]),
    ])
], className='h-100 text-left')

# Right card: the panel driven by the scatter plot's selection.
_layout_violin_card = dbc.Card([
    dbc.CardBody([
        _layout_toolbar([("SVG", 'btn_4', 'download_4'),
                         ("HTML", 'btn_5', 'download_5')]),
        dbc.Row([
            dbc.Col([
                html.Span("Sample Violin Plot for Samples", style={'text-align': 'center'}),
                dcc.Loading(children=[html.Div(id='violin_chart')], color='#119DFF', type='dot'),
            ], width=_layout_width(12), style={'text-align': 'center'}),
        ]),
    ])
], className='h-100 text-left')

app.layout = html.Div([
    _layout_controls,
    dbc.Row([
        dbc.Col([_layout_scatter_card], xs=6),
        dbc.Col([
            dbc.Row([
                dbc.Col([_layout_violin_card]),
            ]),
        ]),
    ]),
])


@app.callback([Output('x_axis', 'options'),
               Output('x_axis', 'value'),
               Output('y_axis', 'options'),
               Output('y_axis', 'value'),
               Output('z_axis', 'options'),
               Output('z_axis', 'value')],
              [Input('radioitems-input_1', 'value')])
def update_axis(radioitems_input_1):
    """Populate the x/y/z axis dropdowns for the chosen embedding.

    Every dropdown gets the same option list (one entry per embedding
    column), and each defaults to the first, second and third column
    respectively. An unrecognised choice yields None.
    """
    embedding_columns = (
        ('UMAP', ('UMAP1', 'UMAP2', 'UMAP3')),
        ('tSNE', ('TSNE1', 'TSNE2', 'TSNE3')),
    )
    for choice, columns in embedding_columns:
        if radioitems_input_1 != choice:
            continue
        shared_options = [{'label': col, 'value': col} for col in columns]
        first, second, third = columns
        return shared_options, first, shared_options, second, shared_options, third
    return None

@app.callback([Output('scatter_chart', 'figure'),
               Output('chart_title', 'children')],
              [Input('radioitems-input_1', 'value'),
               Input('radioitems-input_2', 'value'),
               Input('x_axis', 'value'),
               Input('y_axis', 'value'),
               Input('z_axis', 'value')])
def update_scatter_chart(radioitems_input_1, radioitems_input_2, x_axis, y_axis, z_axis):
    """Draw the 2D or 3D scatter plot of the chosen embedding plus its title.

    Parameters
    ----------
    radioitems_input_1 : str
        Embedding choice, 'UMAP' or 'tSNE'.
    radioitems_input_2 : str
        Dimensionality, '2D' or '3D'.
    x_axis, y_axis, z_axis : str
        Column names for each axis; ``z_axis`` is only used in 3D.

    Returns
    -------
    tuple
        ``(figure, title_span)`` for the ``scatter_chart`` figure and the
        ``chart_title`` children.

    Raises
    ------
    PreventUpdate
        When the embedding/dimension choice is not recognised, or when a
        required axis is not yet a valid column of the chosen table (the axis
        dropdowns start with ``value=[]`` in the layout).
    """
    tables = {'tSNE': (TSNEdf, 'TSNE'), 'UMAP': (UMAPdf, 'UMAP')}
    if radioitems_input_1 not in tables or radioitems_input_2 not in ('2D', '3D'):
        raise PreventUpdate
    df, title_name = tables[radioitems_input_1]
    is_3d = radioitems_input_2 == '3D'

    needed_axes = (x_axis, y_axis, z_axis) if is_3d else (x_axis, y_axis)
    # isinstance first: `[] in df.columns` would raise on the unhashable list.
    if not all(isinstance(axis, str) and axis in df.columns for axis in needed_axes):
        raise PreventUpdate

    shared_layout = dict(template='plotly_white',
                         margin=dict(l=0, r=0, t=0, b=0),
                         clickmode='event+select',
                         hovermode='closest')
    if is_3d:
        # NOTE: as reported in this thread, 3D scenes offer no box/lasso
        # selection, so clicks here are seen via clickData, not selectedData.
        fig = px.scatter_3d(df, x=df[x_axis], y=df[y_axis], z=df[z_axis],
                            color='clusters',
                            hover_name='Id')
        fig.update_layout(legend_traceorder="normal", **shared_layout)
        fig.update_traces(marker_size=2)
    else:
        fig = px.scatter(df, x=df[x_axis], y=df[y_axis],
                         color='clusters',
                         hover_name='Id')
        fig.update_layout(**shared_layout)

    # Yields '2D TSNE', '3D TSNE', '2D UMAP' or '3D UMAP'.
    title = f'{radioitems_input_2} {title_name}'
    return fig, html.Span(title, style={'text-align': 'center'})

@app.callback(Output('violin_chart', 'children'),
              [Input('scatter_chart', 'selectedData')])
def update_violin_chart(selectedData):
    """Show the Ids (hover names) of the currently selected scatter points.

    Parameters
    ----------
    selectedData : dict or None
        The ``selectedData`` of ``scatter_chart``. It is None once the user
        clears the selection, so it must not be indexed unconditionally.

    Returns
    -------
    html.Span
        ``'ID Name [...]'`` listing the selected Ids. Falls back to the
        default Id when nothing usable is selected.
    """
    points = (selectedData or {}).get('points') or []
    id_name = [p['hovertext'] for p in points if 'hovertext' in p]
    if not id_name:
        # Nothing selected (None, no points, or no hover names): use the default Id.
        id_name = 'X24_CTGACACAATGC'

    return html.Span(f'ID Name {id_name}')


if __name__ == "__main__":
    # Serve the dashboard locally on port 1119 with debug mode off.
    app.run_server(port=1119, debug=False)

I’m curious: is there any way to return a point when clicking on a 3D scatter plot?

Thank you.

Hi @hoatran, which information do you need to extract from the point? In the past, I used the `clickData` property of `dcc.Graph()` to extract the coordinates, for example. I am not sure which information is included in `clickData`, though.

1 Like

Hi @AIMPED ,

I need to return the points’ names; in my case, that is the `Id` column of the dataframe.

@jinnyzor: Sorry to bother you, but do you have experience with this? My code works with the 2D scatter plot but not with the 3D scatter plot.

Hello @hoatran,

@AIMPED is correct. It looks like selection is not supported for these 3D charts, but here is an alternative that also listens to `clickData`:

@app.callback(Output('violin_chart', 'children'),
              [Input('scatter_chart', 'selectedData'),
               Input('scatter_chart', 'clickData')])
def update_violin_chart(selectedData, clickData):
    """Show the Ids of the selected (2D) or clicked (3D) scatter points.

    3D scatter plots do not produce ``selectedData``, so ``clickData`` is used
    as well. Whichever input actually fired wins. Otherwise a stale click
    would mask every later box/lasso selection.

    Parameters
    ----------
    selectedData, clickData : dict or None
        The corresponding ``scatter_chart`` properties; either may be None.

    Returns
    -------
    html.Span
        ``'ID Name [...]'`` listing the Ids, or the default Id if none.
    """
    fired = {t['prop_id'] for t in dash.callback_context.triggered}
    selected_points = (selectedData or {}).get('points') or []
    clicked_points = (clickData or {}).get('points') or []

    if 'scatter_chart.selectedData' in fired and selected_points:
        # 2D with clickmode 'event+select': selectedData accumulates
        # shift-clicks, so prefer it over the single clicked point.
        points = selected_points
    elif 'scatter_chart.clickData' in fired:
        # 3D plots only ever report clicks.
        points = clicked_points
    else:
        # Initial load or a cleared selection.
        points = selected_points

    id_name = [p['hovertext'] for p in points if 'hovertext' in p]
    if not id_name:
        id_name = 'X24_CTGACACAATGC'

    return html.Span(f'ID Name {id_name}')

And yes, I have encountered similar issues with pie charts. :wink:

1 Like

@jinnyzor: Sorry for the late response. The problem is that I want to return multiple points, so `clickData` may not be the best way.

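You can accumulate successive `clickData` values in a `dcc.Store` (e.g. append each clicked point’s `Id` to a list kept in the store and read that list in the violin callback). I don’t know of a way to select multiple points in a single go.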