On button event in Python class

I am trying to write a short Gaussian process regression example using the plotly Graph object library. I have a class which is designed to construct an interactive plot with selection buttons for training data in your GPR model. However, then on_click method seems to not result in a change to my plot and I am confused as to why. There is also no output from the process so it is challenging to debug. I have left the code below, it should work completely once run.

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process import kernels


import numpy as np
import plotly.graph_objects as go
import plotly.express as px

import matplotlib.pyplot as plt

import copy

def noisy_sine(x: np.ndarray):
    """
    Produce a noisy sine function.

    Parameters
    ----------
    x : np.ndarray
            x data at which to compute the
            function.
    """
    return np.sin(x) + 0.1 * np.random.normal(size=x.shape)

x = np.linspace(0, 10, 100)
y = noisy_sine(x)

class GPTrainWidget:
    """
    Class for an interactive GPR training
    widget.
    
    Attributes
    ----------
    train_data : np.ndarray
            Train data list that is updated
            on click events.
    unselected_colour : str
            Colour of points when not selected.
    selected_colour : str
            Colour of points when selected.
    """
    train_data: list
    train_indices: list
    unselected_colour = 'White'
    selected_colour = "#e28413"
    
    def __init__(
        self, 
        kernel: kernels.Kernel, 
        training_ds: np.ndarray,
        train_points: int = 10,
        
    ):
        """
        Constructor for the widget.
        
        Parameter
        ---------
        kernel : kernels.Kernel
                Kernel to use in the GPR construction.
        training_ds : np.ndarray (n_points, 2)
                Raw data of the function.
        train_points : int
                Number of train points to display on the plot.
        """        
        self.kernel = kernel
        self.train_points = train_points
        self.training_ds = training_ds
        
        self.model = GaussianProcessRegressor(
            kernel=kernel, random_state=0
        )
        
        # Build the bubbles
        selected_bubbles = np.linspace(
            0, 
            len(self.training_ds), 
            self.train_points,
            endpoint=False,
            dtype=int
        )
        self.train_bubbles = np.take(
            self.training_ds, selected_bubbles, axis=0
        )
        
        self.bubble_scatter = go.Scatter(
            x=self.train_bubbles[:, 0], 
            y=self.train_bubbles[:, 1],
            mode="markers", 
            marker=dict(
                size=10, line_width=2
            )
        )
        self.bubble_colors = [self.unselected_colour] * self.train_points
        self._update_train_bubbles()
        
    def _update_train_bubbles(self):
        """
        Update the bubble scatter plot.
        """
        self.bubble_scatter.marker.line.width = 1.5
        self.bubble_scatter.marker.line.color = "Black"
        self.bubble_scatter.marker.opacity = 1
        self.bubble_scatter.marker.color = self.bubble_colors
    
    def _on_click_callback(self, trace, points, selector):
        """
        On click callback for the widget.
        
        When the user clicks on the empty training data,
        a new GP is fit and displayed with updated priors.
        
        Parameters
        ----------
        All parameters of this class are internally handled
        by plotly and therefore not documented here.
        """
        # Select the important parameters
        colours = list(new_data.marker.color)
        sizes = list(new_data.marker.size)
        
        # Check point selection
        index = points.point_inds
        if index in self.train_indices:
            other_index = self.train_indices.index(index)
            self.train_indices.remove(index)
            self.bubble_colors[index[0]] = self.unselected_colour
            self.train_data.pop(other_index)
        else:
            self.train_indices.append(index)
            self.train_data.append([points.xs[0], points.ys[0]])
            self.bubble_colors[index[0]] = self.selected_colour        
        
        # Update model fit
        self.fit_gpr()        
                
        # Update plot fig
        self._redraw_figure()
        
    def _redraw_figure(self):
        """
        Redraw the figure.
        """
        self._update_train_bubbles()  # update bubble colours
        prior_plots = self.create_fit_figure()
        prior_plots.append(self.bubble_scatter)

        fig = go.FigureWidget(prior_plots)
        bubble_scatter = fig.data[-1]
        bubble_scatter.on_click(self._on_click_callback)

        fig.update_layout(showlegend=False)
        fig.layout.hovermode = 'closest'
        return fig
    
    def create_fit_figure(self):
        """
        Draw new priors and plot them on the figure.
        
        Returns
        -------
        prior_figure
        """
        n_samples=10
        x = np.linspace(0, 10, 100).reshape(-1, 1)

        y_mean, y_std = self.model.predict(x, return_std=True)
        y_samples = self.model.sample_y(x, n_samples)
        
        prior_plots = []
        for prior_draw in y_samples.T:
            prior_plots.append(
                go.Scatter(
                    x=x.reshape(-1), 
                    y=prior_draw, 
                    mode="lines",
                    hoverinfo='skip'
                )
            )
            
        prior_plots.append(
            go.Scatter(
                x=x.reshape(-1), 
                y=y_mean, 
                mode="lines", 
                line=dict(color="black", width=4),
                hoverinfo='skip'
            )
        )
            
        return prior_plots

    def fit_gpr(self):
        """
        Fit a new GPR model with the updated data.
        """
        x_train = np.array(self.train_data)[:, 0].reshape(-1, 1)
        y_train = np.array(self.train_data)[:, 1]
        
        # Train the model in place.
        self.model.fit(x_train, y_train)
    
    def __call__(self):
        """
        Run the app.
        """
        fig = self._redraw_figure()
        return fig

kernel = kernels.RBF(
            length_scale=1.0, length_scale_bounds=(1e-1, 10.0)
        )
train_ds = np.stack((x, y), axis=1)

gpr_example = GPTrainWidget(
    kernel=kernel, training_ds=train_ds, train_points=10
)
fig = gpr_example()
fig.show()

The idea is that when you click on the white bubbles, a GPR model is trained and the updated mean function re-plotted. However, when the buttons are clicked, nothing takes place. Any help would be much appreciated.