go.Figure Scatter, tonexty not in plot

This one was interesting. Here is what I changed:

  • Make sure that the x-values are sorted. That was the main thing. If you zoom into the center of a chart created with your initial function, you will see that the lines zigzag. When plotting a go.Scatter() with mode="markers" the order of the points does not matter, but with mode="lines" it does, because the points are connected in array order.
  • Order the traces so that tonexty fills the intended area: a trace with fill="tonexty" fills to the trace added immediately before it.
  • Change some trace names to make the legend clearer.
  • Make changes so that NumPy arrays can be used as input.
  • Change the axis titles and the figure title.

Here is the changed function:

from sklearn.linear_model import LinearRegression
import plotly.graph_objects as go
from scipy import stats
import numpy as np

def _sort_by_x(x, y):
    """Return ``x`` and ``y`` as flat NumPy arrays, reordered so ``x`` is ascending.

    Raises:
        ValueError: if ``x`` and ``y`` hold different numbers of values.
    """
    x = np.ravel(x)
    y = np.ravel(y)
    if x.size != y.size:
        raise ValueError(
            f"x and y must have the same length, got {x.size} and {y.size}"
        )
    # A stable sort keeps points with equal x in their input order.
    order = np.argsort(x, kind="stable")
    return x[order], y[order]


def _band_half_widths(x, y, y_hat, conf):
    """Return ``(ci, pi)``: half-widths of the confidence and prediction bands.

    ``y_hat`` are the fitted values of a straight line (slope + intercept)
    fitted to ``x``/``y`` by least squares; ``conf`` is the two-sided
    confidence level. Both returned arrays are aligned with ``x``.

    Raises:
        ValueError: if ``conf`` is not strictly between 0 and 1, there are
            fewer than 3 points, or all ``x`` values are identical.
    """
    if not 0.0 < conf < 1.0:
        raise ValueError(f"conf must be strictly between 0 and 1, got {conf!r}")
    n = x.size
    # Slope and intercept are both estimated, which uses up two degrees of freedom.
    dof = n - 2
    if dof < 1:
        raise ValueError(f"need at least 3 points to estimate the bands, got {n}")
    dx2 = (x - np.mean(x)) ** 2
    sxx = np.sum(dx2)
    if sxx == 0:
        raise ValueError("x values must not all be identical")

    # Two-sided interval: (1 - conf) / 2 of the probability goes in each tail.
    t = stats.t.ppf((1 + conf) / 2, dof)
    se = np.sqrt(np.sum((y - y_hat) ** 2) / dof)
    leverage = 1 / n + dx2 / sxx
    ci = t * se * np.sqrt(leverage)
    pi = t * se * np.sqrt(1 + leverage)
    return ci, pi


def _axis_range(values):
    """Return an axis ``[low, high]`` spanning 0 and all ``values`` with ~5% headroom."""
    lo = min(0.0, float(np.min(values)))
    hi = max(0.0, float(np.max(values)))
    pad = (hi - lo) / 20 or 1.0  # fall back to 1.0 so an all-zero axis is not empty
    return [lo - pad if lo < 0 else lo, hi + pad]


def scatter_bands(x, y, conf, x_title="x", y_title="y"):
    """Plot ``y`` against ``x`` with a fitted line plus confidence and prediction bands.

    Args:
        x, y: Paired samples of equal length (array-like). At least three
            points are needed, and the ``x`` values must not all be equal.
        conf: Two-sided confidence level strictly between 0 and 1, e.g. 0.95.
        x_title, y_title: Axis titles. The figure title is
            ``"<x_title> by <y_title>"``.

    Returns:
        A ``plotly.graph_objects.Figure``.

    Raises:
        ValueError: for mismatched lengths, too few points, constant ``x``,
            or an out-of-range ``conf``.
    """
    # Lines connect points in array order, so sort by x to avoid a zig-zagged line.
    x, y = _sort_by_x(x, y)

    model = LinearRegression()
    model.fit(x.reshape(-1, 1), y)
    y_hat = model.predict(x.reshape(-1, 1))

    ci, pi = _band_half_widths(x, y, y_hat, conf)
    pct = f"{conf * 100:g}"

    ff = go.Figure()

    ff.add_trace(
        go.Scatter(
            x=x,
            y=y_hat,
            mode="lines",
            name="Expected",
            line={"color": "red", "width": 1.5},
        )
    )

    # The lower CI limit must be added immediately before the upper one:
    # fill="tonexty" fills to the trace added just before it.
    ff.add_trace(
        go.Scatter(
            x=x,
            y=y_hat - ci,
            mode="lines",
            name="Lower CI Band Limit",
            line={"color": "red", "width": 0.5},
            showlegend=False,
        )
    )

    ff.add_trace(
        go.Scatter(
            x=x,
            y=y_hat + ci,
            mode="lines",
            name=f"{pct}% confidence interval   ",
            line={"color": "red", "width": 0.5},
            fill="tonexty",
            fillcolor="rgba(255,0,0,0.5)",
            showlegend=True,
        )
    )

    ff.add_trace(
        go.Scatter(
            x=x,
            y=y,
            mode="markers",
            marker={"color": "rgba(4, 217, 255, 1)"},
            name="Scores",
        )
    )

    # Prediction interval (drawn as dashed lines, no fill).
    pi_line = {"color": "orange", "width": 1, "dash": "dash"}
    ff.add_trace(
        go.Scatter(
            x=x,
            y=y_hat + pi,
            mode="lines",
            name=f"{pct}% prediction interval",
            line=pi_line,
        )
    )
    ff.add_trace(
        go.Scatter(
            x=x,
            y=y_hat - pi,
            mode="lines",
            name="Lower PI Band Limit",
            line=pi_line,
            showlegend=False,
        )
    )

    # Include the bands in the y range so the dashed PI lines are not clipped.
    y_extent = np.concatenate([y, y_hat - pi, y_hat + pi])
    ff.update_layout(
        {
            "plot_bgcolor": "rgba(0, 0, 0, 1.0)",
            "paper_bgcolor": "rgba(226, 220, 216, 1.0)",
            "font": {"family": "Lato", "color": "#121212"},
            "title": {"text": f"{x_title} by {y_title}", "x": 0.5, "font": {"size": 28}},
            "legend": {
                "title": {"text": " <b>Variables</b>"},
                "x": 0.01,
                "y": 0.97,
                "bgcolor": "#eafbff",
                "traceorder": "normal",
            },
            "xaxis": {
                "title": {"text": x_title},
                "range": _axis_range(x),
                "gridcolor": "#121212",
                "gridwidth": 10.0,
                "zeroline": False,
            },
            "yaxis": {
                "title": {"text": y_title},
                "range": _axis_range(y_extent),
                "gridcolor": "#121212",
                "gridwidth": 10.0,
                "zeroline": False,
            },
        }
    )

    ff.update_layout(height=700)
    return ff

Usage:

# Demo: only runs when this file is executed as a script, not on import.
if __name__ == "__main__":
    # A seeded Generator gives reproducible data without touching global RNG state.
    rng = np.random.default_rng(42)
    x = rng.integers(0, 1000, 1000)
    y = rng.integers(0, 1000, 1000)

    fig = scatter_bands(x, y, 0.95)
    fig.show()

This creates:


mrep tonexty