Overriding button action in animated scatter plot

Hi, I am building a Jupyter notebook that lets people explore and learn how waves and signals behave. One of the tasks I am trying to solve is animating a signal as it propagates through space. Most of the code already runs, but I am still trying to get the animation controls to behave correctly.

Ideally, people using the notebook should be able to play, pause and reset the animation. I have created a button for each of the three actions, and on their own they work well. The problem appears when the reset button is pressed before the animation has finished: the animation keeps plotting the traces for its remaining frames.

Is there a way to override a running animation with a button press? Essentially, I want a reset button that pauses the animation and then resets the data.

This is my code:

import numpy as np
import plotly.graph_objects as go  # current name of the older plotly.graph_objs alias

# parameters that define the wave
k = 0.1  # wavenumber, 1/m
a = 1  # amplitude, m
phi = 0  # phase shift, between 0 and 2*pi

# Fixed parameters for this solution
g = 9.81  # acceleration due to gravity, m/s^2
x = np.linspace(0, 100, 101)  # spatial coordinate, m
time = np.linspace(0, 3, 31)  # time coordinate, s

# The dispersion relation
w = np.sqrt(g * k)  # angular frequency, 1/second

# Color spectrum from blue to yellow
colors = [[i/len(time), "rgb({0:.0f}, {1:.0f}, {2:.0f})".format(255*i/len(time), 255*i/len(time), 255 - 255*i/len(time))] for i in range(len(time))]

# create a trace for each time step
# (label with the actual time in seconds, not the loop index)
traces = [go.Scatter(name='time: {:.1f} s'.format(time[i]),
                     line=dict(color=colors[i][1]))
          for i in range(len(time))]

# create one frame per time step; each frame fills in only trace ii
frames = [dict(data=[dict(type='scatter',
                          x=x,
                          y=a * np.cos(k * x - w * time[ii] + phi))],
               traces=[ii],
               ) for ii in range(len(time))]

# create the animation
layout = go.Layout(
    xaxis=dict(title='x/m'),
    yaxis=dict(title='Amplitude/m'),
    updatemenus=[dict(
        type='buttons',
        showactive=False,
        buttons=[dict(
            label='Play',
            method='animate',
            args=[None, {'frame': {'duration': 1, 'redraw': True},
                         'fromcurrent': False,
                         'transition': {'duration': 1},
                         'mode': 'immediate',
                         }]
        ),
            dict(
                label='Pause',
                method='animate',
                # animating to [None] in 'immediate' mode stops the running animation
                args=[[None], {'frame': {'duration': 3, 'redraw': False},
                               'mode': 'immediate', 'transition': {'duration': 30}}]
            ),
            dict(
                label='Reset',
                method='update',
                # clear every trace: one empty list of x and y values per trace
                args=[{'x': [[]] * len(time),
                       'y': [[]] * len(time)},
                      {}]
            )]

    )]

)

fig = go.Figure(data=traces, frames=frames, layout=layout)
fig.show()