Deploying Keras Model in web app with real time plots using plotly

Hello, I’m new to the forum. I’m also new to Python, but I’ve managed to hack my way to some early success. I’ve developed a deep learning model using Keras and successfully deployed it in a Flask app. It contains a /predict endpoint that another application can query with data to get a prediction back. The Flask app simply runs as a console application, but I also wanted to pop up a web page that shows a real-time scatter plot of the input data. I was able to do this by using the Bokeh library alongside my Flask app. Now, when the app launches, a web page with the plot opens, and it is configured to poll a new ‘/data’ endpoint in the Flask app every 200 ms to get the latest data and update the 2D scatter plot. This seems to be working reliably.
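For context, here is roughly what a client request to /predict looks like. This is only a minimal sketch: the helper name `build_predict_payload` and the sample values are made up for illustration, while the `TagId` and `PhaseInput` keys and the bracketed, comma-separated string format come from the server code further down.

```python
import json

N_T, N_FEATURES = 64, 2  # same values as N_t and n_features on the server


def build_predict_payload(tag_id, samples):
    """Build the JSON body for /predict (hypothetical helper).

    samples is a list of (channel, phase) integer pairs. The server expects
    them flattened into a single bracketed, comma-separated string.
    """
    if len(samples) != N_T:
        raise ValueError(f"expected {N_T} samples, got {len(samples)}")
    flat = [int(v) for pair in samples for v in pair]
    phase_input = "[" + ",".join(str(v) for v in flat) + "]"
    return json.dumps({"TagId": tag_id, "PhaseInput": phase_input})


# Made-up example data: channel cycles over 0..49, phase is a simple ramp
samples = [(i % 50, 100 * i) for i in range(N_T)]
body = build_predict_payload("307401320416C2054B6E99D7", samples)
```

The resulting `body` string is what the other application would POST to /predict.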

Now for the reason for my post. I would really like to improve upon this by changing my 2d scatter plot to a 3d scatter plot. Unfortunately, bokeh does not offer a 3d scatter plot out of the box. This brought me to plotly and dash. I know plotly can do a 3d scatter plot. I’m looking for sample code, or advice on how I should go about doing this.

  1. Should I keep the existing structure of a Flask app but somehow combine that with Plotly to take the place of Bokeh? This means Plotly would have to create the web page and periodically hit the /data endpoint of my Flask app.

  2. Should I abandon Flask and change to a Dash app? If so, I would then need to deploy a Keras model in Dash and set up both the /predict and /data endpoints.
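To make option 1 a bit more concrete, this is the kind of thing I imagine on the Flask side: a function that turns the latest points into a Plotly figure description as a plain dict, which a /data-style endpoint could return as JSON for Plotly.js to draw. It is only a sketch under assumptions. The helper name `make_scatter3d_figure` is made up, and so is the third axis (I use the sample index as z purely as a placeholder); the axis labels and ranges are copied from my Bokeh plot.

```python
import json


def make_scatter3d_figure(points):
    """Turn [(x, y, z), ...] into a Plotly figure spec (plain dict).

    Hypothetical helper: z stands for whatever third quantity should be
    plotted; here it is assumed to be the sample index.
    """
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    zs = [p[2] for p in points]
    return {
        "data": [{
            "type": "scatter3d",
            "mode": "markers",
            "x": xs,
            "y": ys,
            "z": zs,
            "marker": {"size": 4, "color": "red"},
        }],
        "layout": {
            "title": {"text": "Scatter Plot of TOI"},
            "scene": {
                "xaxis": {"title": {"text": "Frequency"}, "range": [0, 49]},
                "yaxis": {"title": {"text": "Phase"}, "range": [-2**15, 2**15]},
                "zaxis": {"title": {"text": "Sample index"}},  # assumed axis
            },
        },
    }


# Made-up points: (channel, phase, sample index)
fig = make_scatter3d_figure([(1, 200, 0), (3, -400, 1), (7, 150, 2)])
fig_json = json.dumps(fig)  # what the endpoint would send to the browser
```

On the page, a timer could fetch this JSON and pass its `data` and `layout` to `Plotly.react`, which would play the role that `AjaxDataSource` plays in the Bokeh version.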

Below is my existing code with Flask and Bokeh. Does anyone have any suggestions on how I should proceed? Even better, can anyone show me some sample code that would be applicable to what I’m trying to do?

Thank you.

import numpy as np
import tensorflow as tf
import keras
from keras.models import model_from_json
from flask import Flask, jsonify, make_response, request

from bokeh.plotting import figure, show
from bokeh.models import AjaxDataSource, CustomJS, Range1d

# Bokeh related code

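# Convert the JSON response {points: [[x, y], ...]} into column data
adapter = CustomJS(code="""
    const result = {x: [], y: []}
    const pts = cb_data.response.points
    for (let i = 0; i < pts.length; i++) {
        result.x.push(pts[i][0])
        result.y.push(pts[i][1])
    }
    return result
""")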

source = AjaxDataSource(data_url='',
                        polling_interval=200, adapter=adapter)

p = figure(plot_height=700, plot_width=1400, x_axis_label='Frequency', y_axis_label='Phase', background_fill_color="lightgrey",
           title="Scatter Plot of TOI")
p.x_range = Range1d(0, 49)
p.y_range = Range1d(-2**15, 2**15)
p.circle('x', 'y', source=source, color='red', size=10)

# Flask related code

app = Flask(__name__)

def get_model():
    global model
    global g

    json_path = 'RNN_LSTM_128_Drop0p2.json'
    h5_path = "RNN-010-0.983732-0.997119.h5"
    g = tf.Graph()
    with g.as_default():
        # Pull in the model we want to test
        with open(json_path, 'r') as json_file:
            loaded_model_json = json_file.read()
        model = model_from_json(loaded_model_json)
        # load weights into new model
        model.load_weights(h5_path)
        print("Loaded model from disk")
        # Compile the loaded model
        print("Compiled Model")
    print(" * Recurrent Neural Network Trained Model Loaded Successfully")

def RNN_scale_input(X):
    N_samp = X.shape[0]
    N_t_samp = X.shape[1]
    n_feat = X.shape[2]

    X_scale = np.zeros(shape=(N_samp, N_t_samp, n_feat), dtype=np.float32)
    for h in range(N_samp):
        # Scale the channel index to be between -1 and 1
        X_scale[h][:, 0] = (X[h][:, 0] - 25) / 25
        # Scale the phase value to be between -1 and 1
        X_scale[h][:, 1] = X[h][:, 1] / 2**15
    return X_scale

# Define function to pre-process phase data
def pre_process_data(data, num_rows, num_cols):
    # Strip the opening and closing brackets from the input data string
    data = data.strip('[]')
    # Convert string into a list of floating point numbers
    data = [float(h) for h in data.split(',')]
    data_np = np.asarray(data).reshape(1, num_rows, num_cols)
    rnn_structure = RNN_scale_input(data_np)
    return rnn_structure[0]

def pre_process_plot_data(data, num_rows, num_cols):
    # Strip the opening and closing brackets from the input data string
    data = data.strip('[]')
    # Convert string into a list of integers
    list_data = [int(k) for k in data.split(',')]
    plot_data_np = np.asarray(list_data).reshape(num_rows, num_cols)
    return plot_data_np

N_t, n_features = 64, 2
print(" * Loading Neural Network Trained Model...")
get_model()

# Define Tag of Interest to plot
TOI = "307401320416C2054B6E99D7"

def crossdomain(f):
    def wrapped_function(*args, **kwargs):
        resp = make_response(f(*args, **kwargs))
        h = resp.headers
        h['Access-Control-Allow-Origin'] = '*'
        h['Access-Control-Allow-Methods'] = "GET, OPTIONS, POST"
        h['Access-Control-Max-Age'] = str(21600)
        requested_headers = request.headers.get('Access-Control-Request-Headers')
        if requested_headers:
            h['Access-Control-Allow-Headers'] = requested_headers
        return resp
    return wrapped_function

x = [0]*N_t
y = [0]*N_t

@app.route("/predict", methods=['POST'])
def predict():
    global x
    global y
    message = request.get_json(force=True)
    TagId = message['TagId']
    PhaseData = message['PhaseInput']
    model_input = np.zeros(shape=(1, N_t, n_features), dtype='float32')
    model_input[0] = pre_process_data(PhaseData, N_t, n_features)
    # If TagId == TOI, plot the data
    if TagId == TOI:
        # Update data in plot
        plot_data = pre_process_plot_data(PhaseData, N_t, n_features)
        # tolist() converts the NumPy integers to plain Python ints for jsonify
        x = plot_data[:, 0].tolist()
        y = plot_data[:, 1].tolist()

    # Make prediction with model
    with g.as_default():
        model_output = model.predict(model_input)
    # Take the second element (location at index (0, 1)) as the float prediction
    pred_val = model_output[0, 1]
    prediction = pred_val.tolist()
    # Construct response to send back to client app
    response = {
        'tag': TagId,
        'prediction': prediction
    }
    return jsonify(response)

@app.route('/data', methods=['GET', 'OPTIONS', 'POST'])
@crossdomain
def data():
    global x
    global y
    return jsonify(points=list(zip(x, y)))

# show and run
show(p)
app.run()