Multiple callback, error in arguments

My code is shown below. I have run into the following error:
TypeError: callback() takes from 3 to 5 positional arguments but 7 were given
The traceback points to the State argument of app.callback. I have tried a couple of ways of wrapping the inputs, outputs, and states in lists (square brackets), but nothing seems to work. The funny thing is that this code works in PyCharm, but in a Jupyter notebook/Colab the error comes up.

Create Data:

import numpy as np
import pandas as pd

# Synthetic demo data: 2000 readings taken every 5 minutes.
#   Wallclock  - timestamp of the reading
#   tcd        - linear ramp from 3434 to 3505 scaled by +/-10 % uniform noise
#   humidity   - linear ramp from 63 to 96
#   Capsule_ID - consecutive rows are split into 16 equal groups (125 rows
#                each) numbered from 2100015 upwards
# numpy is imported here because this snippet uses `np` and must be runnable
# on its own, before the visualisation code below.
df = pd.DataFrame(
    {
        "Wallclock": pd.date_range(
            "22-dec-2020 00:01:36", freq="5min", periods=2000
        ),
        "tcd": np.linspace(3434, 3505, 2000) * np.random.uniform(0.9, 1.1, 2000),
        "humidity": np.linspace(63, 96, 2000),
    }
).pipe(lambda d: d.assign(Capsule_ID=(d.index // (len(d) // 16)) + 2100015))

Visualisation Code:

import plotly.express as px  
import plotly.graph_objects as go
import numpy as np
import openpyxl
import dash 
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output, State
from dash.exceptions import PreventUpdate
from jupyter_dash import JupyterDash

# Dash application object; the layout and the callback below attach to it.
app = JupyterDash(__name__)

# Distinct capsule IDs present in the data, and the same IDs in ascending order.
capsuleID = df["Capsule_ID"].unique()
capsuleID_names = sorted(capsuleID)

# Dropdown options: one entry per capsule ID, followed by a "(Select All)"
# entry whose value is the sentinel string "All".
capsuleID_names_1 = [dict(label=cid, value=cid) for cid in capsuleID_names]
capsuleID_names_2 = [dict(label="(Select All)", value="All")]
capsuleID_names_all = [*capsuleID_names_1, *capsuleID_names_2]

def slider_fig(df):
    """Build the short range-slider figure shown above the main scatter plot.

    The rows of ``df`` are counted per ``Wallclock`` value and those counts
    are drawn as a scatter trace. The x-axis gets a visible range slider and
    no title; the y-axis gets an empty tick list and no title; the figure is
    125 px high with all margins set to zero.

    Args:
        df: DataFrame with a ``Wallclock`` column.

    Returns:
        The Plotly figure produced by ``px.scatter`` after the layout update.
    """
    rows_per_timestamp = df.groupby("Wallclock", as_index=False).size()
    figure = px.scatter(rows_per_timestamp, x="Wallclock", y="size")

    layout_options = {
        "xaxis": {"rangeslider": {"visible": True}, "title": None},
        "height": 125,
        "yaxis": {"tickmode": "array", "tickvals": [], "title": None},
        "margin": {"l": 0, "r": 0, "t": 0, "b": 0},
    }
    return figure.update_layout(**layout_options)

def _build_layout():
    """Assemble the page: title, capsule dropdown, range slider, scatter plot."""
    page_title = html.H1(
        "Relative Humidity vs TCD", style={"text-align": "center"}
    )

    # Multi-select dropdown of capsule IDs; starts with "(Select All)" chosen.
    capsule_dropdown = dcc.Dropdown(
        id="capsule_select",
        options=capsuleID_names_all,
        optionHeight=25,
        multi=True,
        searchable=True,
        placeholder="Please select...",
        clearable=True,
        value=["All"],
        style={"width": "100%"},
    )

    # Range-slider figure, initially built from the full data set.
    range_slider_graph = dcc.Graph(id="slider", figure=slider_fig(df))

    # Main scatter plot; its figure is supplied by the callback.
    scatter_container = html.Div([dcc.Graph(id="the_graph")])

    return html.Div(
        [page_title, capsule_dropdown, range_slider_graph, scatter_container]
    )


app.layout = _build_layout()

# -----------------------------------------------------------

# Outputs, inputs and states are grouped into three lists. This form is
# accepted by Dash releases whose ``callback()`` only takes
# ``(output, inputs, state, ...)`` as well as by newer ones; passing the six
# dependencies un-bracketed is what raises
# "callback() takes from 3 to 5 positional arguments but 7 were given".
@app.callback(
    [
        Output("the_graph", "figure"),
        Output("capsule_select", "value"),
        Output("slider", "figure"),
    ],
    [
        Input("capsule_select", "value"),
        Input("slider", "relayoutData"),
    ],
    [State("slider", "figure")],
)
def update_graph(capsule_chosen, slider, sfig):
    """Refresh the scatter plot, the dropdown selection and the slider figure.

    NOTE(review): ``capsule_select.value`` is both an Input and an Output of
    this callback. Support for that within a single callback is believed to
    need Dash >= 1.19 -- confirm against the installed version and upgrade
    ``dash`` / ``jupyter-dash`` in the notebook environment if necessary.

    Args:
        capsule_chosen: Current dropdown value: a list of capsule IDs and/or
            the sentinel ``"All"``. ``None`` is treated as an empty list.
        slider: ``relayoutData`` of the slider graph. When it contains the
            key ``"xaxis.range"`` the data is restricted to that range.
        sfig: Current slider figure, returned unchanged when a range is set.

    Returns:
        Tuple ``(scatter figure, dropdown value, slider figure)`` in the same
        order as the declared outputs.
    """
    # A cleared dropdown may deliver None; `"All" in None` would raise.
    capsule_chosen = capsule_chosen or []

    if "All" in capsule_chosen:
        # Expand the sentinel into the explicit list of every capsule ID.
        dropdown_values = capsuleID_names
        dff = df
    else:
        dropdown_values = capsule_chosen
        # Keep only rows whose capsule ID is among the selected ones.
        dff = df[df["Capsule_ID"].isin(capsule_chosen)]

    if slider and "xaxis.range" in slider:
        # A range was picked on the slider: keep readings inside it.
        dff = dff.loc[dff["Wallclock"].between(*slider["xaxis.range"])]
    else:
        # No range picked: rebuild the slider from the selected capsules.
        sfig = slider_fig(dff)

    scatterplot = px.scatter(
        data_frame=dff,
        x="tcd",
        y="humidity",
        hover_name="Wallclock",
    )
    scatterplot.update_traces(textposition="top center")

    return scatterplot, dropdown_values, sfig

# ------------------------------------------------------------------------------
# Start the app only when this file/cell is executed directly, not on import.
if __name__ == "__main__":
    # Alternative launch call, kept for reference:
    #     app.run_server(debug=True)
    # NOTE(review): mode="inline" presumably makes JupyterDash render the app
    # inside the notebook output cell -- confirm against the jupyter-dash docs.
    app.run_server(mode="inline")