Hello, I’m new to the forum. I’m also new to Python, but I’ve managed to hack my way to some early success. I’ve developed a deep learning model using Keras and successfully deployed it in a Flask app. It contains a /predict endpoint that another application can query with data to get a prediction back. This Flask app simply runs as a console application, but I also wanted to pop up a web page that shows a real-time scatter plot of the input data. I was able to do this by using the Bokeh library and integrating it into my Flask app. Now, when the app launches, a web page with the plot opens; it is configured to periodically hit a new ‘/data’ endpoint in the Flask app to get the latest data and update the 2D scatter plot. This seems to be working reliably.
Now for the reason for my post: I would really like to improve on this by changing my 2D scatter plot to a 3D scatter plot. Unfortunately, Bokeh does not offer a 3D scatter plot out of the box. This brought me to Plotly and Dash. I know Plotly can do a 3D scatter plot. I’m looking for sample code, or advice on how I should go about doing this.
-
Should I keep the existing structure of a Flask app but somehow combine it with Plotly to take the place of Bokeh? This means Plotly would have to create the web page and periodically hit the /data endpoint of my Flask app.
-
Should I abandon Flask and change to a Dash app? If so, I would then need to deploy a Keras model in Dash and set up both the /predict and /data endpoints.
Below is my existing code with Flask and Bokeh. Does anyone have suggestions on how I should proceed? Even better, can anyone show me some sample code that would be applicable to what I’m trying to do?
Thank you.
import functools

import numpy as np
import tensorflow as tf
import keras
from keras.models import model_from_json
from flask import Flask, jsonify, make_response, request
from bokeh.plotting import figure, show
from bokeh.models import AjaxDataSource, CustomJS, Range1d
# Bokeh related code

# URL the plot page polls for new scatter points (served by the /data route).
# NOTE(review): plain http:// on port 443 is unusual (443 is normally HTTPS) — confirm.
DATA_URL = 'http://10.61.226.215:443/data'

# Convert the /data payload ({"points": [[x, y], ...]}, see the /data route)
# into the column dict {x: [...], y: [...]} that the AjaxDataSource consumes.
# The loop index is declared with `let`: an undeclared `i` becomes an implicit
# global and throws a ReferenceError if the callback is evaluated in strict mode.
adapter = CustomJS(code="""
    const result = {x: [], y: []}
    const pts = cb_data.response.points
    for (let i = 0; i < pts.length; i++) {
        result.x.push(pts[i][0])
        result.y.push(pts[i][1])
    }
    return result
""")

# Poll DATA_URL every 200 ms and feed the converted columns to the plot.
source = AjaxDataSource(data_url=DATA_URL,
                        polling_interval=200, adapter=adapter)

p = figure(plot_height=700, plot_width=1400,
           x_axis_label='Frequency', y_axis_label='Phase',
           background_fill_color="lightgrey",
           title="Scatter Plot of TOI")
# Fixed axis ranges. The y range matches the +/-2**15 phase scaling used for
# the model input; the x range presumably covers the channel index — confirm.
p.x_range = Range1d(0, 49)
p.y_range = Range1d(-2**15, 2**15)
p.circle('x', 'y', source=source, color='red', size=10)

# Flask related code
app = Flask(__name__)
def get_model(json_path='RNN_LSTM_128_Drop0p2.json',
              h5_path="RNN-010-0.983732-0.997119.h5"):
    """Load, weight and compile the Keras model into module globals.

    The model is built inside a fresh ``tf.Graph``. The compiled model is
    stored in the global ``model`` and the graph in the global ``g``, which
    ``predict`` re-enters before calling ``model.predict``.

    Args:
        json_path: Path to the model architecture saved as JSON.
        h5_path: Path to the weights file passed to ``load_weights``.
    """
    global model
    global g
    g = tf.Graph()
    with g.as_default():
        # Pull in the model we want to test. The context manager closes the
        # file even if reading it fails.
        with open(json_path, 'r') as json_file:
            loaded_model_json = json_file.read()
        model = model_from_json(loaded_model_json)
        # Load weights into the new model.
        model.load_weights(h5_path)
        print("Loaded model from disk")
        # Compile the loaded model.
        model.compile(loss=keras.losses.categorical_crossentropy,
                      optimizer=keras.optimizers.Adadelta(),
                      metrics=['accuracy'])
        print("Compiled Model")
    print(" * Recurrent Neural Network Trained Model Loaded Successfully")
def RNN_scale_input(X, channel_center=25, phase_scale=2**15):
    """Scale a batch of (channel, phase) sequences for the RNN.

    Args:
        X: 3-D array of shape ``(n_samples, n_timesteps, n_features)``.
            Feature 0 is the channel index and feature 1 the phase value.
        channel_center: Channel offset and divisor; with the default of 25,
            channels are mapped to roughly [-1, 1].
        phase_scale: Phase divisor; with the default of 2**15, phases are
            mapped to roughly [-1, 1].

    Returns:
        float32 array with the same shape as ``X``. Column 0 holds
        ``(X[..., 0] - channel_center) / channel_center``, column 1 holds
        ``X[..., 1] / phase_scale``, and any further feature columns are 0.
    """
    X = np.asarray(X)
    # Unpacking enforces the 3-D shape the model input requires.
    n_samp, n_t_samp, n_feat = X.shape
    X_scale = np.zeros(shape=(n_samp, n_t_samp, n_feat), dtype=np.float32)
    # Vectorized over all samples at once instead of a per-sample loop.
    X_scale[:, :, 0] = (X[:, :, 0] - channel_center) / channel_center
    X_scale[:, :, 1] = X[:, :, 1] / phase_scale
    return X_scale
# Define function to pre-process phase data
def pre_process_data(data, num_rows, num_cols):
    """Parse a bracketed, comma-separated number string into scaled model input.

    Args:
        data: String such as ``"[a0, b0, a1, b1, ...]"`` holding
            ``num_rows * num_cols`` numbers, read in row-major order.
        num_rows: Number of time steps.
        num_cols: Number of features per time step.

    Returns:
        float32 array of shape ``(num_rows, num_cols)`` produced by
        ``RNN_scale_input``.

    Raises:
        ValueError: If an element is not a number, or if the number of
            elements does not equal ``num_rows * num_cols``.
    """
    # Strip the opening and closing brackets from the input data string
    stripped = data.strip('[]')
    # Convert string into a list of floating point numbers
    try:
        values = [float(h) for h in stripped.split(',')]
    except ValueError as err:
        print("***ERROR***")
        print(stripped)
        # Fail here with a clear message instead of falling through to a
        # confusing reshape error on the unparsed string.
        raise ValueError(
            'could not parse phase data: {!r}'.format(stripped)) from err
    data_np = np.asarray(values).reshape(1, num_rows, num_cols)
    rnn_structure = RNN_scale_input(data_np)
    return rnn_structure[0]
def pre_process_plot_data(data, num_rows, num_cols):
    """Parse a bracketed, comma-separated integer string into a 2-D array.

    Args:
        data: String such as ``"[a0, b0, a1, b1, ...]"`` holding
            ``num_rows * num_cols`` integers, read in row-major order.
        num_rows: Number of rows (time steps).
        num_cols: Number of columns (features per time step).

    Returns:
        Integer array of shape ``(num_rows, num_cols)`` (unscaled values).

    Raises:
        ValueError: If an element is not an integer literal, or if the number
            of elements does not equal ``num_rows * num_cols``.
    """
    # Strip the opening and closing brackets from the input data string
    stripped = data.strip('[]')
    # Convert string into a list of integers.
    # NOTE(review): int() rejects float-formatted values such as "1.0", which
    # pre_process_data accepts for the same payload — confirm clients send ints.
    try:
        list_data = [int(k) for k in stripped.split(',')]
    except ValueError as err:
        print("***ERROR***")
        print(stripped)
        # Re-raise: continuing would hit an unbound `list_data` below.
        raise ValueError(
            'could not parse plot data: {!r}'.format(stripped)) from err
    return np.asarray(list_data).reshape(num_rows, num_cols)
# Model input geometry: time steps per sample, features per time step.
N_t = 64
n_features = 2

# Load and compile the model once, when the module is imported.
print(" * Loading Neural Network Trained Model...")
get_model()

# Tag of Interest: only /predict requests carrying this TagId update the plot.
TOI = "307401320416C2054B6E99D7"
def crossdomain(f):
    """Decorate a Flask view so its response carries permissive CORS headers.

    Presumably needed because the page opened by ``show(p)`` is not served by
    this Flask app, so its AJAX polls are cross-origin — confirm.

    ``functools.wraps`` preserves the view's ``__name__``. Flask uses that name
    as the default endpoint name. Without it, every decorated view would be
    registered as ``wrapped_function``, and a second one would collide.

    Args:
        f: The view function to wrap.

    Returns:
        The wrapped view. It returns ``f``'s response with
        ``Access-Control-*`` headers added.
    """
    @functools.wraps(f)
    def wrapped_function(*args, **kwargs):
        resp = make_response(f(*args, **kwargs))
        h = resp.headers
        h['Access-Control-Allow-Origin'] = '*'
        h['Access-Control-Allow-Methods'] = "GET, OPTIONS, POST"
        # Let browsers cache the preflight result for 21600 s (6 hours).
        h['Access-Control-Max-Age'] = str(21600)
        # Echo back whichever headers a preflight request asked permission for.
        requested_headers = request.headers.get('Access-Control-Request-Headers')
        if requested_headers:
            h['Access-Control-Allow-Headers'] = requested_headers
        return resp
    return wrapped_function
# Latest TOI points shown by the plot: /predict overwrites them and /data
# serves them. They start as N_t zeros so /data has something before the first TOI.
x, y = [0] * N_t, [0] * N_t
@app.route("/predict", methods=['POST'])
def predict():
    """Run the model on one request and, for the TOI, refresh the plot data.

    Expects a JSON body with ``TagId`` and ``PhaseInput``. ``PhaseInput`` is a
    bracketed, comma-separated string of ``N_t * n_features`` numbers.
    Responds with ``{"tag": TagId, "prediction": <element [0, 1] of the model output>}``.
    """
    global x
    global y
    message = request.get_json(force=True)
    TagId = message['TagId']
    PhaseData = message['PhaseInput']
    # Batch of one sample, shaped as the model expects.
    model_input = np.zeros(shape=(1, N_t, n_features), dtype='float32')
    model_input[0] = pre_process_data(PhaseData, N_t, n_features)
    # If TagId == TOI, publish the raw (unscaled) points for the /data route.
    if TagId == TOI:
        plot_data = pre_process_plot_data(PhaseData, N_t, n_features)
        # tolist() yields built-in ints; the standard JSON encoder used by
        # jsonify cannot serialise NumPy integer scalars.
        x = plot_data[:, 0].tolist()
        y = plot_data[:, 1].tolist()
    # Make prediction inside the graph the model was built in (see get_model).
    with g.as_default():
        model_output = model.predict(model_input)
    # Take the second element (location at index (0, 1)) as the float prediction
    prediction = model_output[0, 1].tolist()
    # Construct response to send back to client app
    response = {
        'tag': TagId,
        'prediction': prediction
    }
    return jsonify(response)
@app.route('/data', methods=['GET', 'OPTIONS', 'POST'])
@crossdomain
def data():
    """Serve the current TOI points as ``{"points": [[x0, y0], [x1, y1], ...]}``."""
    # x and y are only read here, so no `global` declaration is required.
    points = list(zip(x, y))
    return jsonify(points=points)
# show and run
# Open the live plot page; its AjaxDataSource polls the /data route.
show(p)

if __name__ == "__main__":
    # Without this call, running the script directly opens the plot and then
    # exits, so neither /predict nor /data is ever served. The guard leaves
    # import-based launchers (e.g. `flask run`) unaffected.
    # NOTE(review): app.run binds 127.0.0.1 by default, but the poll URL uses
    # 10.61.226.215 and /predict is called by another application. Pass
    # host=... if those clients are not on loopback. Plain HTTP on port 443
    # usually also needs elevated privileges — confirm.
    app.run(port=443)