Kerasで作った手書き文字認識をWebアプリにしてDockerコンテナにする
タイトルのとおりです。今更ながらではありますが機械学習に足を踏み入れたWebアプリを作ってみます。
機械学習に足を踏み入れかけの人をターゲットにしています、が投げやりなので不足点は公式ドキュメント等をご参照ください。
今回は機械学習をKerasで簡単に実装し、それを使ったWebアプリの作成を一通り行います。 また動作環境をDockerのコンテナにまとめ、どこでも使えるようなイメージにします。
MNISTとは
機械学習で言うところのHello Worldのようなものです。28×28の手書き文字画像が、どの数字が書かれているか予測する問題です。
学習用に28×28の画像と、答えの数字(0~9)が与えられます。
TF.Kerasで学習する
今回はそんなに細かいことをやるわけではないので、Kerasを使います。
Kerasは(TensorFlowなどに比べ)、ニューラルネットの記述に特化している、評価関数などの初期値がいい感じに定まっているという使いやすい点があります。
さておき細かい構成や実装については、各種書籍や記事にまとめられているので割愛し以下の構成で実装しました。
Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) (None, 28, 28, 1) 0 _________________________________________________________________ conv2d_1 (Conv2D) (None, 26, 26, 32) 320 _________________________________________________________________ conv2d_2 (Conv2D) (None, 24, 24, 64) 18496 _________________________________________________________________ max_pooling2d_1 (MaxPooling2 (None, 12, 12, 64) 0 _________________________________________________________________ dropout_1 (Dropout) (None, 12, 12, 64) 0 _________________________________________________________________ flatten_1 (Flatten) (None, 9216) 0 _________________________________________________________________ dense_1 (Dense) (None, 120) 1106040 _________________________________________________________________ dropout_2 (Dropout) (None, 120) 0 _________________________________________________________________ dense_2 (Dense) (None, 10) 1210 ================================================================= Total params: 1,126,066 Trainable params: 1,126,066 Non-trainable params: 0
28×28の画像を正規化し、学習させます。学習したモデルはあとのWebアプリで使用するので保存しておきます。
#%% インポート関連 import tensorflow as tf # tf.enable_eager_execution() print(tf.__version__) print(tf.test.is_built_with_cuda()) from tensorflow.python import keras print(keras.__version__) from tensorflow.python.keras.callbacks import EarlyStopping import numpy as np from IPython.display import display import matplotlib.pyplot as plt from PIL import Image %matplotlib inline np.set_printoptions(threshold=100) #%% データを読みこみ (x_train_src, y_train_src), (x_test_src, y_test_src) = keras.datasets.mnist.load_data() print(x_train_src.shape) print(y_train_src.shape) print(x_test_src.shape) print(y_test_src.shape) # channel last前提で処理 keras.backend.image_data_format() #%% numpy配列に変換 input_shape =(28,28,1) x_train = x_train_src.reshape(x_train_src.shape[0], 28, 28, 1) x_test = x_test_src.reshape(x_test_src.shape[0], 28, 28, 1) # テストデータを正規化 x_train = x_train / 255.0 x_test = x_test / 255.0 # 分類問題なのでone-hot enc y_train = keras.utils.to_categorical(y_train_src, 10) y_test = keras.utils.to_categorical(y_test_src, 10) print(x_train.shape) print(x_test.shape) # 画像を表示、arrは28x28x1の正規化されたもの def convert_image(arr, show=True, title="", w=28, h=28): img = Image.fromarray(arr.reshape(w,h) * 255.0) if show: plt.imshow(img) plt.title(title) return img def convert_images(srcs, length, show=True, cols=5, w=28, h=28): rows = int(length / cols + 1) dst = Image.new('1', (w * cols, h * rows)) for j in range(rows): for i in range(cols): ptr = i + j * cols img = convert_image(srcs[ptr], show=False, w=w, h=h) dst.paste(img, (i * w, j * h)) if show: plt.imshow(dst) return dst plt.subplot(1,2,1) convert_images(x_train, 50,) plt.subplot(1,2,2) convert_images(x_test, 50,) plt.show() #%% モデル構築・学習 def MNISTConvModel(input_shape, predicates_class_n): inputs = keras.layers.Input(shape=input_shape) x = keras.layers.Conv2D(32, kernel_size=(3,3), activation='relu')(inputs) x = keras.layers.Conv2D(64, kernel_size=(3,3), activation='relu')(x) x = keras.layers.MaxPooling2D(pool_size=(2,2))(x) x = keras.layers.Dropout(0.25)(x) x = keras.layers.Flatten()(x) # 2D(12*12*64) -> 1d(9216) x = keras.layers.Dense(120, activation='relu')(x) x = keras.layers.Dropout(0.5)(x) predicates = keras.layers.Dense(predicates_class_n, activation='softmax')(x) return keras.models.Model(inputs=inputs, outputs=predicates) model = MNISTConvModel(input_shape=input_shape, predicates_class_n=10) model.summary() # モデルをコンパイルして実行 batch_size = 128 epochs = 20 model.compile( loss=keras.losses.categorical_crossentropy, optimizer='adadelta', metrics=['accuracy'] ) tensorboard_cb = keras.callbacks.TensorBoard(log_dir="./tflogs/", histogram_freq=1) history = model.fit( x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=2, validation_data=(x_test, y_test), callbacks=[tensorboard_cb], ) #%% 学習結果の確認 plt.subplot(2,1,1) plt.plot(range(epochs), history.history['acc'], label='acc') plt.plot(range(epochs), history.history['val_acc'], label='val_acc') plt.legend(loc='center left', bbox_to_anchor=(1, 0.5)) plt.subplot(2,1,2) plt.plot(range(epochs), history.history['loss'], label='loss') plt.plot(range(epochs), history.history['val_loss'], label='val_loss') plt.legend(loc='center left', bbox_to_anchor=(1, 0.5)) plt.show() #%% 性能 scores = model.evaluate(x_test, y_test, verbose=2) print('loss', scores[0], 'accuracy', scores[1]) #%% モデルの保存 model.save('model.h5')
これで学習させると正答率99.3%でした。グラフを見ても10epochs付近で過学習に陥っているのでEarlyStopping入れても良かったかもしれません。
学習済みモデルを利用したREST APIサーバを作成する
モデルを変換したりKeras.jsを使う方法などがありますが、ユーザにモデルのダウンロードをさせるのは重荷なのでAPIとして提供します。
今回Pythonを使っているのでFlaskというライブラリを使用します。
下のようなコードhttpでjsonが返せるすぐれものです。
@app.route("/") def index(): return make_response(jsonify({"hello": "world"})
今回はkerasのモデルを読み込みpredictを使って予測値を返します。
# GPUは使わない import os os.environ["CUDA_VISIBLE_DEVICES"] = "-1" import time import tensorflow as tf from tensorflow.python import keras from flask import Flask, jsonify, abort, make_response, request, send_from_directory import numpy as np graph = tf.get_default_graph() model = None app = Flask(__name__) # 疎通確認 @app.route("/info") def index(): return make_response(jsonify({ "name": "mnist-cnn server", "time": time.ctime(), })) # 28*28の画像をPOSTで配列にして送ると、0~9の推論結果を返してくれる @app.route("/predict", methods=['POST']) def mnist(): data = request.json if data == None: return abort(400) src = data["src"] if (src == None) | (not isinstance(src, list)): return abort(400) src = np.array(src) # 正規化する src = src.astype('float32') / 255.0 src = src.reshape(-1,28,28,1) # 推論する with graph.as_default(): start = time.time() dst = model.predict(src) elapsed = time.time() - start return make_response(jsonify({ "predict" : dst.tolist(), "elapsed" : elapsed, })) # 静的ファイル公開 @app.route("/", defaults={"path": "index.html"}) @app.route("/<path:path>") def send_file(path): return send_from_directory("dist", path) if __name__ == '__main__': model = keras.models.load_model("./model.h5") app.run(host="0.0.0.0", port=3000, debug=True)
最後の行でhostを0.0.0.0にしないと外部からアクセスできないので注意します。
ついでにこのあと作成する静的ページもホスティングできるようにしています。
ここまでで、POSTすると学習済みモデルのpredictの結果が得られるようになりました。
手書きができるWebページを作成する
最後に手書きされたデータを先程のFlaskの/predict
に投げるWebページを作ります。
手書きにはhtml5のcanvasを使います。また作りやすくするためにvue.jsを読み込んで使います。
@touchmove
みたいな箇所はvue側のmethodsを呼び出してくれます。
v-model
は変数の双方向バインディング、{{ variable_name }}
はViewへの単方向バインディングです。
<!doctype html> <html lang="ja"> <head> <meta charset="UTF-8"> <title>MNIST CNN Demo</title> <link href="style.css" rel="stylesheet"> </head> <body> <div id="container"> <h3>{{ message }}</h3> <div id="canvas_container"> <canvas id="draw_canvas" width="28" height="28" @touchmove="touch_draw" @mousemove="drag_draw" @mouseup="predict" @touchend="predict"></canvas> </div> <input v-model="pen_size" type="range" min="1" max="100" step="1"> <span>Pen Size:{{ pen_size }}</span> <button @click="clear">Clear</button> </div> <script src="vue.min.js"></script> <script type="text/javascript" src="index.js"></script> </body> </html>
次に動作を書きます。.new Vueする際にelで指定した要素に対して適用されます。
dataにはバインドする変数を定義し、methodsに使用する関数を記述します。もっと複雑なロジックや状態遷移がある場合はvuexなども検討してもいいかもしれません。
ポイントはcanvasのドラッグやタッチした際の位置を修正することと、サイズと実際に表示されるサイズが異なるためその変換を行っています。
最後にctx.getImageData
を叩いて得られたデータから、Flaskに送る配列に変換しています。(一緒にRGBからGrayScale画像にしています)
あとはpredictの結果から一番近しい数字を表示して終わりです。
const container = new Vue({ el: '#container', data: { message: '0から9の数字を書いたら識別します!', pen_size: "50", is_debug: false, }, methods: { update_message: function(str) { this.message = str; }, clear: function() { const canvas = document.getElementById('draw_canvas'); const ctx = canvas.getContext('2d'); ctx.fillStyle = 'black'; ctx.fillRect(0, 0, canvas.width, canvas.height); this.update_message('また書いてね!'); }, touch_draw: function(e) { const rect = e.target.getBoundingClientRect(); // 気分でマルチタッチ対応してみる for(const t of e.touches) { const x = t.clientX - rect.left; const y = t.clientY - rect.top; this.draw(x, y); } }, drag_draw: function(e) { if(!e.buttons) return; const rect = e.target.getBoundingClientRect(); const x = e.clientX - rect.left; const y = e.clientY - rect.top; this.draw(x, y); }, draw: function(mx, my) { const canvas = document.getElementById('draw_canvas'); // 表示サイズとcanvasサイズは異なるので変換しておく const x = mx / canvas.clientWidth * canvas.width; const y = my / canvas.clientHeight * canvas.height; if (x < 0 || y < 0 || canvas.width < x || canvas.height < y) return; // 点を書く const ctx = canvas.getContext('2d'); const r = parseFloat(this.pen_size) / 100.0 * (canvas.width / 8); ctx.beginPath(); ctx.fillStyle = 'white'; ctx.arc(x, y, r, 0, Math.PI * 2, true); ctx.fill(); }, predict: function() { const canvas = document.getElementById('draw_canvas'); const ctx = canvas.getContext('2d'); // RGBA32 const img = ctx.getImageData(0,0,28,28).data; const length = img.length / 4; // とりあえず面倒なので加重平均とかはしない const src = []; for(let i = 0 ; i < length ; ++i) { const ptr = i * 4; src.push(Math.floor((img[ptr] + img[ptr + 1] + img[ptr + 2]) / 3.0)); } // flaskで作った推論機に投げる callback = this.update_message; // then内でthis参照させるのがかったるい fetch('/predict', { method: 'POST', headers: { 'Content-Type': 'application/json', }, body: JSON.stringify({'src': src }), }).then(function(res) { return res.json(); }).then(function(data) { // predict[1][20], elapsed[sec]が帰ってくるので適当に表示する const predict = data.predict[0]; let index = 0; for(let i = 0 ; i < predict.length ; ++i) { if (predict[index] < predict[i]) { index = i; } } callback(`たぶん${index}だと思う。(${Math.floor(predict[index] * 100)}% ${data.elapsed}[sec])`); }); // 確認用 // this.debug_print(src, null); }, debug_print: function(src, predict) { if (this.is_debug) { let debug = ""; for(let j = 0 ; j < 28 ; ++j) { for(let i = 0 ; i < 28 ; ++i) { debug += ` ${src[j * 28 + i].toString(16)} `.slice(-3); } debug += '\r\n'; } console.log(debug); console.log(predict); } } }, });
※面倒なのでhttpリクエストにfetchを使っていますが、vue公式としてはaxiosを推奨しています。
作成したKeras+FlaskアプリケーションをDockerコンテナにまとめる
システムのポータビリティを考え、コンテナで実行できるようにします
特にPythonのパッケージバージョン管理はいろいろと難があるので
pipのインストールやregistryなどにアップすることを考えてDockerfileからビルドします。
まず現在動作している環境のライブラリをrequirements.txtに出力します
$ pip freeze > requirements.txt
以下のようにパッケージ一覧が得られます。乱雑する場合は最小限動作する仮想環境を作り直してから同作業を行います。
absl-py==0.2.2 astor==0.7.1 bleach==1.5.0 certifi==2018.4.16 click==6.7 Flask==1.0.2 gast==0.2.0 grpcio==1.13.0 h5py==2.8.0 html5lib==0.9999999 itsdangerous==0.24 Jinja2==2.10 Markdown==2.6.11 MarkupSafe==1.0 numpy==1.14.5 protobuf==3.6.0 six==1.11.0 tensorflow==1.8.0 termcolor==1.1.0 Werkzeug==0.14.1 wincertstore==0.2
あとは、Dockerfileで公式のPythonパッケージを引っ張ってきてファイルのコピーとライブラリの追加を行います。
FROM python:3.5 # file copy COPY . /app WORKDIR /app # lib install RUN pip3 install --upgrade -r requirements.txt # run flask server EXPOSE 3000 CMD ["python", "mnist-server.py"]
EXPOSE 3000
はポート開放なので忘れずに入れておきます。
最後にコンテナをビルドします
$ docker build -t <image-name> .
実行したいときはポート開放と合わせて以下のコマンドを実行します
$ docker run -p 3000:3000 -it <image-name>
pythonで実行したときと同様に動けば成功です。
毎回ビルドと実行が面倒なのでdocker-compose.yml
ファイルも用意しておきます
mnist-server: build: ./ ports: - 3000:3000
これによってビルドと起動を$ docker-compose up
でできるようになります。
まとめ
今回は手書き文字認識をベースに、ポータビリティに優れたWebアプリ作成例を示しました。
環境構築の手間がなくなるだけかなり便利なのでぜひ試してください。
今回作成したアプリをgithubに上げておきます。