How can I get n_clicks for an individual trace in a plot?

I would like to use traces as a sort of button in my plot, for two reasons:

  • I need to be able to put the “button” in an exact location on the plot.
  • I need to be able to show/hide it easily using other callbacks.

I started to set up a basic example, but I’m running into two main issues:

  • I have no way to get n_clicks from a trace object. As far as I understand, n_clicks only exists on dash_html_components objects, so the only n_clicks I can read is the one on the Div that wraps the plot, which is not really what I want.
  • I also need to get the number of clicks specific to each trace.

In this example I can partly tell which trace was clicked, but the clickData value lingers: even when I click somewhere else in the plot, the click is counted as a click on the last trace that was clicked.

My example code is below. Is this a case where I probably need to build my own component? Thanks!

import json
from textwrap import dedent as d
import dash
import dash_core_components as dcc
import dash_html_components as html
import plotly.graph_objs as go
from dash.dependencies import Input, Output, State

# External CSS from the Dash user-guide examples. Presumably it supplies the
# 'row' / '<n> columns' grid classes used by the page layout; confirm.
external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']

app = dash.Dash(__name__, external_stylesheets=external_stylesheets)

# Shared inline style for the <pre> panels that echo callback payloads.
styles = dict(
    pre=dict(
        border='thin lightgrey solid',
        overflowX='scroll',
    ),
)

# Closed outline (the last vertex repeats the first), filled to itself so the
# triangle is drawn as a solid shape rather than just an edge.
triangle_trace = go.Scatter(
    name='triangle',
    mode='lines',
    fill='toself',
    x=[0, 1, 2, 0],
    y=[0, 2, 0, 0],
)

# Same idea for an axis-aligned rectangle spanning x 3..5, y 0.5..1.5.
rectangle_trace = go.Scatter(
    name='rectangle',
    mode='lines',
    fill='toself',
    x=[3, 3, 5, 5, 3],
    y=[0.5, 1.5, 1.5, 0.5, 0.5],
)

# 'event' makes clicks emit clickData without also turning them into
# selections (that would be 'event+select').
shape_lo = go.Layout(clickmode='event')

# Trace order fixes each trace's curveNumber: 0 = triangle, 1 = rectangle.
fig = go.Figure(layout=shape_lo, data=[triangle_trace, rectangle_trace])

# Page layout: the graph on top, a row of four panels that echo the graph's
# raw event payloads, then a full-width panel for the custom click summary
# produced by custom_click.
# NOTE: the Markdown strings are passed through textwrap.dedent, so their
# internal indentation (and any trailing spaces) is significant.
app.layout = html.Div([
    # The graph is wrapped in a Div so a callback can read the Div's
    # n_clicks ('graph-div' is an Input of custom_click).
    html.Div([
        dcc.Graph(
            id='basic-interactions',
            figure=fig
        )
    ], id='graph-div'),
    html.Div(className='row', children=[
        # Echo of the graph's hoverData (filled by display_hover_data).
        html.Div([
            dcc.Markdown(d("""
            **Hover Data**

            Mouse over values in the graph.
        """)),
            html.Pre(id='hover-data', style=styles['pre'])
        ], className='three columns'),

        # Echo of the graph's clickData (filled by display_click_data).
        html.Div([
            dcc.Markdown(d("""
            **Click Data**

            Click on points in the graph.
        """)),
            html.Pre(id='click-data', style=styles['pre']),
        ], className='three columns'),

        # Echo of the graph's selectedData (filled by display_selected_data).
        html.Div([
            dcc.Markdown(d("""
            **Selection Data**

            Choose the lasso or rectangle tool in the graph's menu
            bar and then select points in the graph.

            Note that if `layout.clickmode = 'event+select'`, selection data also 
            accumulates (or un-accumulates) selected data if you hold down the shift
            button while clicking.
        """)),
            html.Pre(id='selected-data', style=styles['pre']),
        ], className='three columns'),

        # Echo of the graph's relayoutData (filled by display_relayout_data).
        html.Div([
            dcc.Markdown(d("""
            **Zoom and Relayout Data**

            Click and drag on the graph to zoom or click on the zoom
            buttons in the graph's menu bar.
            Clicking on legend items will also fire
            this event.
        """)),
            html.Pre(id='relayout-data', style=styles['pre']),
        ], className='three columns')
    ]),
    # Full-width panel for the per-shape summary written by custom_click.
    html.Div(className='row', children=[
        html.Div([
            dcc.Markdown(d("""
             **CUSTOM CLICK DATA HERE**
         """)),
            html.Pre(id='custom-data', style=styles['pre']),
        ], className='twelve columns')
    ])
])


@app.callback(
    Output('custom-data', 'children'),
    [Input('basic-interactions', 'clickData'),
     Input('graph-div', 'n_clicks')])
def custom_click(clickData, n_clicks):
    """Describe which shape was last clicked and the wrapper's click count.

    Parameters
    ----------
    clickData : dict or None
        The graph's ``clickData``; ``None`` before any trace has been clicked
        (including the callback run Dash performs on page load).
    n_clicks : int or None
        Click counter of the ``graph-div`` wrapper; ``None`` until it is
        clicked for the first time.

    Returns
    -------
    str
        Text for the ``custom-data`` panel.
    """
    # curveNumber follows the trace order passed to go.Figure.
    shape_index = {
        0: 'the triangle',
        1: 'the rectangle'
    }
    try:
        shape = shape_index[clickData['points'][0]['curveNumber']]
    except (KeyError, IndexError, TypeError):
        # TypeError: clickData is None (nothing clicked yet / initial load).
        # KeyError / IndexError: a payload without a usable point, or a
        # curveNumber that is not one of the known shapes.
        shape = 'nothing'

    # Show 0 rather than "None" before the wrapper Div has been clicked.
    return 'You clicked on {}\ntotal click count: {}'.format(shape, n_clicks or 0)


@app.callback(Output('hover-data', 'children'),
              [Input('basic-interactions', 'hoverData')])
def display_hover_data(hoverData):
    """Pretty-print the graph's latest hoverData payload as indented JSON."""
    rendered = json.dumps(hoverData, indent=2)
    return rendered


@app.callback(Output('click-data', 'children'),
              [Input('basic-interactions', 'clickData')])
def display_click_data(clickData):
    """Pretty-print the graph's latest clickData payload as indented JSON."""
    rendered = json.dumps(clickData, indent=2)
    return rendered


@app.callback(Output('selected-data', 'children'),
              [Input('basic-interactions', 'selectedData')])
def display_selected_data(selectedData):
    """Pretty-print the graph's current selectedData payload as indented JSON."""
    rendered = json.dumps(selectedData, indent=2)
    return rendered


@app.callback(Output('relayout-data', 'children'),
              [Input('basic-interactions', 'relayoutData')])
def display_relayout_data(relayoutData):
    """Pretty-print the graph's latest relayoutData payload as indented JSON."""
    rendered = json.dumps(relayoutData, indent=2)
    return rendered


if __name__ == '__main__':
    # Run the Dash development server with debug mode switched on.
    run_options = {'debug': True}
    app.run_server(**run_options)

After experimenting with this (updated code below), I realized that using State lets me store a click count for each trace the way I want. It almost works; the only caveat is that I have to create a “background” trace (set to opacity=0) that covers the entire area of the plot. If I don’t, a click on the plot that misses the shape traces still counts toward the last trace that was clicked or hovered over.

This seems a little hacky, but it may be the only way to go. Any ideas on whether there’s a way to improve this?

import json
from textwrap import dedent as d
import dash
import dash_core_components as dcc
import dash_html_components as html
import plotly.graph_objs as go
from dash.dependencies import Input, Output, State

# External CSS from the Dash user-guide examples. Presumably it supplies the
# 'row' / '<n> columns' grid classes used by the page layout; confirm.
external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']

app = dash.Dash(__name__, external_stylesheets=external_stylesheets)

# Shared inline style for the <pre> panels that echo callback payloads.
styles = dict(
    pre=dict(
        border='thin lightgrey solid',
        overflowX='scroll',
    ),
)

# Closed, self-filled outline drawn as a solid triangle.
triangle_trace = go.Scatter(
    name='triangle',
    mode='lines',
    fill='toself',
    x=[0, 1, 2, 0],
    y=[0, 2, 0, 0],
)

# Closed, self-filled rectangle spanning x 3..5, y 0.5..1.5.
rectangle_trace = go.Scatter(
    name='rectangle',
    mode='lines',
    fill='toself',
    x=[3, 3, 5, 5, 3],
    y=[0.5, 1.5, 1.5, 0.5, 0.5],
)

# Invisible filled rectangle (no lines/markers, opacity 0) spanning
# x 0..5, y 0..2. It exists so clicks that miss the shapes still land on a
# trace and produce fresh clickData instead of reusing the previous one.
bg_trace = go.Scatter(
    name='background',
    mode='none',
    fill='toself',
    opacity=0,
    x=[0, 5, 5, 0, 0],
    y=[0, 0, 2, 2, 0],
)

# 'event' makes clicks emit clickData without also turning them into
# selections (that would be 'event+select').
shape_lo = go.Layout(clickmode='event')

# Trace order fixes curveNumber: 0 = background, 1 = triangle, 2 = rectangle.
fig = go.Figure(
    layout=shape_lo,
    data=[bg_trace, triangle_trace, rectangle_trace],
)

# Page layout: the graph on top, then three panels for hover data, click
# data and the per-shape click tally maintained by disp_dct.
# NOTE: the Markdown strings are passed through textwrap.dedent, so their
# internal indentation is significant.
app.layout = html.Div([
    # The graph is wrapped in a Div so a callback can read the Div's
    # n_clicks ('graph-div' is an Input of disp_dct).
    html.Div([
        dcc.Graph(
            id='basic-interactions',
            figure=fig
        )
    ], id='graph-div'),
    html.Div(className='row', children=[
        # Echo of the graph's hoverData (filled by display_hover_data).
        html.Div([
            dcc.Markdown(d("""
            **Hover Data**

            Mouse over values in the graph.
        """)),
            html.Pre(id='hover-data', style=styles['pre'])
        ], className='four columns'),

        # Echo of the graph's clickData (filled by display_click_data).
        html.Div([
            dcc.Markdown(d("""
            **Click Data**

            Click on points in the graph.
        """)),
            html.Pre(id='click-data', style=styles['pre']),
        ], className='four columns'),

        # Click tally. children='{}' is the initial JSON state; disp_dct
        # reads it back through State and writes the updated JSON here.
        html.Div([
            dcc.Markdown(d("""
            **CLICK TRACKING**
        """)),
            html.Pre(id='clicktrack-data', style=styles['pre'], children='{}'),
        ], className='four columns')
    ])
])


@app.callback(
    Output('clicktrack-data', 'children'),
    [Input('graph-div', 'n_clicks'),
     Input('basic-interactions', 'clickData')],
    [State('clicktrack-data', 'children')])
def disp_dct(n_clicks, clickData, dct_json):
    """Increment a per-shape click tally kept as JSON in the page.

    Parameters
    ----------
    n_clicks : int or None
        Click counter of the ``graph-div`` wrapper; ``None`` until the
        wrapper has been clicked.
    clickData : dict or None
        The graph's ``clickData``; ``None`` until a trace has been clicked.
    dct_json : str
        Current tally as a JSON object mapping shape label -> count (the
        layout initialises it to ``'{}'``).

    Returns
    -------
    str
        The updated tally as indented JSON.
    """
    if n_clicks is None and clickData is None:
        # Dash runs every callback once on page load with unset inputs as
        # None. No click has happened yet, so keep the stored tally instead
        # of recording a phantom 'nothing' click.
        return dct_json

    # curveNumber 0 is the invisible background trace, which is
    # deliberately absent here so it is counted as 'nothing'.
    shape_index = {
        1: 'the triangle',
        2: 'the rectangle'
    }
    try:
        shape = shape_index[clickData['points'][0]['curveNumber']]
    except (KeyError, IndexError, TypeError):
        # TypeError: clickData is None (the wrapper Div was clicked but no
        # trace has been). KeyError / IndexError: background trace,
        # unexpected curve, or a payload without points.
        shape = 'nothing'
    # NOTE: clickData keeps its last value between clicks, so a click that
    # hits no trace at all is attributed to the previously clicked trace;
    # the background trace is what limits this.

    dct = json.loads(dct_json) if dct_json else {}
    dct[shape] = dct.get(shape, 0) + 1

    return json.dumps(dct, indent=2)


@app.callback(Output('hover-data', 'children'),
              [Input('basic-interactions', 'hoverData')])
def display_hover_data(hoverData):
    """Pretty-print the graph's latest hoverData payload as indented JSON."""
    rendered = json.dumps(hoverData, indent=2)
    return rendered


@app.callback(Output('click-data', 'children'),
              [Input('basic-interactions', 'clickData')])
def display_click_data(clickData):
    """Pretty-print the graph's latest clickData payload as indented JSON."""
    rendered = json.dumps(clickData, indent=2)
    return rendered


if __name__ == '__main__':
    # Run the Dash development server with debug mode switched on.
    run_options = {'debug': True}
    app.run_server(**run_options)