How to use Keras inside a web server
I’ve received a lot of questions about this topic, so I decided to write a tutorial.
Problem
The problem is that TensorFlow sessions cannot be shared across processes. So you cannot simply keep a reference to your model in the web server and call methods on it from its workers. Also, Keras will not tell you what’s going on, because it doesn’t know: TensorFlow will just block without raising an error.
Solution
Your model should live in a single process, and every worker should go through that process to use it. Easy enough?
Let’s first create a class that will handle our model. It should be thread-safe, because multiple workers will be accessing it.
from keras.applications import VGG16
from multiprocessing import Lock
import numpy as np

class KerasModel:
    def __init__(self):
        self.mutex = Lock()
        self.model = None

    def initialize(self):
        """Initialize our model"""
        self.model = VGG16()
        # Dummy compile
        self.model.compile('sgd', 'mse')

    def predict(self, arr):
        """This method uses VGG16 to predict an ImageNet class.

        arr: Numpy array, the input image (should be of shape (224,224,3))
        returns: A distribution over all 1000 ImageNet classes.
        """
        if arr.shape != (224, 224, 3):
            raise ValueError('The image provided is not right.')
        with self.mutex:
            # With the mutex, we can now predict!
            return self.model.predict_on_batch(arr[np.newaxis, ...])[0]
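To make the indexing in predict concrete, here is a minimal sketch that uses only NumPy. The zero image and the fake_output array are made-up stand-ins for a real photo and for the network's output; nothing here calls Keras.

```python
import numpy as np

# Stand-in for a decoded 224x224 RGB image (assumption: all zeros).
image = np.zeros((224, 224, 3), dtype=np.float32)

# np.newaxis adds the leading batch dimension the network expects.
batch = image[np.newaxis, ...]
print(batch.shape)  # (1, 224, 224, 3)

# Stand-in for the network output: one row of 1000 scores per image.
fake_output = np.zeros((batch.shape[0], 1000), dtype=np.float32)
fake_output[0, 7] = 1.0

# The trailing [0] in predict picks the scores of our single image.
single = fake_output[0]
print(single.shape, int(np.argmax(single)))  # (1000,) 7
```

The caller therefore gets back a flat vector of 1000 scores rather than a batch of size one.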
Manager
I’ll use a multiprocessing.Manager to interact with our model. A manager runs in its own process, and its purpose is to do exactly this kind of thing.
Basically, a Manager is a simple server that provides methods to access its components. In our case, our model is our component.
To let the Manager use our model, we’ll need to register it.
from multiprocessing.managers import BaseManager

# Dummy class
class KerasManager(BaseManager):
    pass

KerasManager.register('KerasModel', KerasModel)
That’s it!
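If you want to see what the Manager does without loading Keras, here is a small sketch using only the standard library. FakeModel and FakeManager are hypothetical stand-ins for KerasModel and KerasManager, and the "prediction" is just a multiplication. It should show that the object lives in the manager's process and that we only hold a proxy to it.

```python
import os
from multiprocessing import Lock
from multiprocessing.managers import BaseManager


class FakeModel:
    """Hypothetical stand-in for KerasModel (no Keras required)."""

    def __init__(self):
        self.mutex = Lock()
        self.calls = 0

    def predict(self, x):
        # Serialize access, the same way KerasModel.predict does.
        with self.mutex:
            self.calls += 1
            return x * 2

    def owner_pid(self):
        # PID of the process that actually holds this object.
        return os.getpid()

    def call_count(self):
        # The default proxy exposes methods, not attributes,
        # so state is read through a method.
        return self.calls


class FakeManager(BaseManager):
    pass


FakeManager.register('FakeModel', FakeModel)

if __name__ == '__main__':
    manager = FakeManager()
    manager.start()  # Starts the server process
    try:
        model = manager.FakeModel()  # A proxy, not the object itself
        print('our pid:          ', os.getpid())
        print('model process pid:', model.owner_pid())
        print('prediction:', model.predict(21))
        print('calls so far:', model.call_count())
    finally:
        manager.shutdown()
```

The two PIDs printed should differ: every method call on the proxy is forwarded to the manager's process, which is the only place the object exists.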
Web server
We can now build our web server! I’ll use Flask, because it’s super easy to use, but any framework will do.
Let’s initialize our manager:
from keras_model import KerasManager
manager = KerasManager()
manager.start() # Important to start our server!
keras_model = manager.KerasModel()
keras_model.initialize() # Important to initialize our network!
It is now ready to be used inside the web server.
import cv2
import numpy as np
from flask import Flask, request

app = Flask(__name__)

@app.route('/hello', methods=['POST'])
def hello():
    # np.frombuffer replaces the deprecated np.fromstring for raw bytes.
    data = np.frombuffer(request.files['file'].read(), np.uint8)
    # IMREAD_COLOR always yields 3 channels, which is what predict expects.
    img = cv2.imdecode(data, cv2.IMREAD_COLOR)
    img = cv2.resize(img, (224, 224))
    # With the Manager, we call our model!
    pred = keras_model.predict(img)
    # Get the associated ImageNet class
    return imagenet_classes[np.argmax(pred)]

if __name__ == '__main__':
    app.run()
In this snippet, imagenet_classes is the dict that we can find here.
Our web server has only one endpoint, /hello, which accepts POST requests carrying a file under the key 'file'.
The file is decoded with OpenCV and resized. We then feed the network and send back the predicted class.
You can now try it with your favorite HTTP client, such as Postman.
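If you would rather test from Python, here is a rough sketch of a client that needs only the standard library. The helper names build_multipart and classify are mine, not part of the tutorial's repo, and the URL assumes Flask's default development address (127.0.0.1:5000); adjust it to your setup.

```python
import uuid
import urllib.request


def build_multipart(field, filename, data):
    """Encode one file as a multipart/form-data body.

    Returns the body bytes and the matching Content-Type header value.
    """
    boundary = uuid.uuid4().hex
    head = (
        f'--{boundary}\r\n'
        f'Content-Disposition: form-data; name="{field}"; '
        f'filename="{filename}"\r\n'
        'Content-Type: application/octet-stream\r\n\r\n'
    ).encode()
    tail = f'\r\n--{boundary}--\r\n'.encode()
    return head + data + tail, f'multipart/form-data; boundary={boundary}'


def classify(path, url='http://127.0.0.1:5000/hello'):
    """POST the image at `path` under the key 'file' and return the reply."""
    with open(path, 'rb') as f:
        body, content_type = build_multipart('file', 'image.jpg', f.read())
    req = urllib.request.Request(
        url, data=body, method='POST',
        headers={'Content-Type': content_type},
    )
    with urllib.request.urlopen(req) as resp:
        return resp.read().decode()
```

With the server running, something like print(classify('cat.jpg')) should print the predicted class name, where cat.jpg is any image on your disk.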
Conclusion
As we can see, it’s fairly easy to use a Keras model inside a web server!
The code can be found on my repo: https://github.com/Dref360/tuto_keras_web