- はじめに -
PyTorchのRust bindingsであるtch-rsを使って、画像認識APIを実装する時のメモ。
今回は非同期ランタイムのtokioと同じプロジェクト配下で開発されているaxumを利用する。
- axumによるHTTPサーバ構築 -
RustでHTTPサーバを立てるライブラリはいくつかある。現状日本語ではCyberAgent社の以下のブログが詳しい。
私自身にあまり選定ノウハウが無いので、今回はtokioから出ているaxumを利用する。
axumを利用してHTTPサーバを構築するにあたっては、repository内のexampleディレクトリに複数の実装サンプルが配置されている他、Tokioのreleaseにも簡単なkickstartが存在するので、そちらを見ながら開発を進めた。
hello world
Cargo.tomlを作成する。
[package] name = "rust-machine-learning-api-example" version = "0.1.0" authors = ["vaaaaanquish <6syun9@gmail.com>"] edition = "2018" [dependencies] axum = "0.2.2" tokio = { version = "1.0", features = ["full"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0"
localhostにpostリクエストを投げる事でjsonをやり取りするサンプルを書く。
use axum::{handler::post, Router, Json}; use serde::{Serialize, Deserialize}; use serde_json::{json, Value}; use std::net::SocketAddr; #[tokio::main] async fn main() { let app = Router::new().route("/", post(proc)); let addr = SocketAddr::from(([0, 0, 0, 0], 3000)); println!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } #[derive(Deserialize)] struct RequestJson { message: String, } #[derive(Serialize)] struct ResponseJson { message: String, } async fn proc(Json(payload): Json<RequestJson>) -> Json<Value> { Json(json!({ "message": payload.message + " world!" })) }
responseはimpl IntoResponseで実装されたものを返す事ができる。ドキュメント内のbuiliding responses節にString、HTML、Json、StatusCodeなどを返す実装イメージが掲載されているので参考にすると良い。routeやMiddlewareを付与する場合も同様に参照すると良い。
cargo runして、以下のhello文字列を投げると「hello world!」になって帰ってくる。
curl -X POST -H "Content-Type: application/json" -d '{"message":"hello"}' http://localhost:3000
base64による画像の受信
一旦無難にbase64で画像をやり取りする事を考える。Cargo.tomlに以下を追記する。
base64 = "0.13" image = "0.23"
先程のスクリプトのpayload.messageを読んでいた箇所をbase64へデコードし、画像として保存するよう変更してみる。
extern crate base64; extern crate image; ... let img_buffer = base64::decode(&payload.message).unwrap(); let img = image::load_from_memory(img_buffer.as_slice()).unwrap(); img.save("output.png").unwrap();
clientサイドとして、rustのロゴを取得してbase64エンコードした文字列を投げるPythonスクリプトを書いてみる。
import base64 import json import requests # require: pip install requests sample_image_response = requests.get('http://rust-lang.org/logos/rust-logo-128x128-blk.png') img = base64.b64encode(sample_image_response.content).decode('utf-8') res = requests.post('http://127.0.0.1:3000', data=json.dumps({'message': img}), headers={'content-type': 'application/json'})
output.pngとしてRustのロゴ画像がcargo runしているディレクトリにできれば良い。適宜組み替える。
ExtensionLayerによるstate管理
機械学習APIなのでMLモデルを一回読み込んでグローバルに扱いたい。axumではExtensionLayerという機能を用いて、stateを実装できる。
https://docs.rs/axum/0.2.3/axum/#sharing-state-with-handlers
ここでは試しにHashSetをstateとしてみる。
AddExtensionLayerを使って、先程のAPIを同名の画像は保存しないように改修してみる。
use axum::{handler::post, Router, Json, AddExtensionLayer, extract::Extension}; use serde::{Serialize, Deserialize}; use serde_json::{json, Value}; use std::net::SocketAddr; use std::sync::Arc; use tokio::sync::Mutex; use std::collections::HashSet; extern crate base64; extern crate image; struct DataState { set: Mutex<HashSet<String>> } #[tokio::main] async fn main() { let set = Mutex::new(HashSet::new()); let state = Arc::new(DataState { set }); let app = Router::new() .route("/", post(proc)) .layer(AddExtensionLayer::new(state)); let addr = SocketAddr::from(([0, 0, 0, 0], 3000)); println!("listening on {}", addr); axum::Server::bind(&addr) .serve(app.into_make_service()) .await .unwrap(); } #[derive(Deserialize)] struct RequestJson { name: String, img: String, } #[derive(Serialize)] struct ResponseJson { result: String, } async fn proc(Json(payload): Json<RequestJson>, Extension(state): Extension<Arc<DataState>>) -> Json<Value> { let img_buffer = base64::decode(&payload.img).unwrap(); let mut set = state.set.lock().await; let result; if set.contains(&payload.name) { result = "skip by duplicated"; } else { let img = image::load_from_memory(&img_buffer.as_slice()).unwrap(); img.save(&payload.name).unwrap(); set.insert(payload.name); result = "saved output image"; } Json(json!({ "result": result })) }
先程のPythonスクリプトにname keyを付与してpostしていく。
res = requests.post('http://127.0.0.1:3000', data=json.dumps({'img': img, 'name': name}), headers={'content-type': 'application/json'}) print(res.text)
名前が重複したItemの場合は保存処理が走らず「skip by duplicated」なる文字列が返ってくる。名前がまだHashSet内にない場合はlocalディレクトリに画像が保存され、「saved output image」なる文字列が返ってくるようになった。
インメモリなので一度サーバを落とすと消えてしまうが、機械学習モデルをインメモリに保持する用途であれば十分だろう。
- tch-rsによる推論 -
PyTorchのRust bindingsでpretrain済みのモデルを流用して、推論を行うサンプルを過去に公開している。
こちらを流用して、推論を行うstateを作成しAddExtensionLayerに流す実装を行う。
tch-rsをCargo.tomlに追加する
tch = "0.5.0"
Arc
... use tch::nn::ModuleT; use tch::vision::{resnet, imagenet}; extern crate tch; struct DnnModel { net: Mutex<Box<dyn ModuleT>> } #[tokio::main] async fn main() { let weights = std::path::Path::new("/resnet18.ot"); let mut vs = tch::nn::VarStore::new(tch::Device::Cpu); let net:Mutex<Box<(dyn ModuleT + 'static)>> = Mutex::new(Box::new(resnet::resnet18(&vs.root(), imagenet::CLASS_COUNT))); let _ = vs.load(weights); let state = Arc::new(DnnModel { net }); ...
RustのFutureは難しい部分がいくつかあり、私も把握しきれていないが、大まかな外枠は以下を見る事ですぐ把握できる。
zenn.dev
tech.uzabase.com
推論部分は一度画像を保存して読み込む形を取る。
... let net = state.net.lock().await; let img_buffer = base64::decode(&payload.img).unwrap(); let img = image::load_from_memory(&img_buffer.as_slice()).unwrap(); let _ = img.save("/tmp.jpeg"); let img_tensor = imagenet::load_image_and_resize224("/tmp.jpeg").unwrap(); let output = net .forward_t(&img_tensor.unsqueeze(0), false) .softmax(-1, tch::Kind::Float); let mut result = Vec::new(); for (probability, class) in imagenet::top(&output, 5).iter() { result.push(format!("{:50} {:5.2}%", class, 100.0 * probability)); } ...
ローカルに画像を保存せずメモリバッファ経由で実装する方法としてload_image_and_resize224_from_memoryが実装されているが、まだreleaseには至っていないようだ。もう少しでインメモリ上で推論が完結しそうである。
github.com
以下のRustロゴ画像を投げてみる
レスポンスは以下のようになった。
{ "result": [ "buckle 26.54%", "wall clock 5.34%", "digital watch 5.32%", "analog clock 4.14%", "digital clock 3.71%" ] }
Rustのロゴはバックルか時計からしい。まあ概ね良さそう。
同様の方法を利用して、PythonのPyTorchで学習したモデルをRust bindings上で再現し推論を行うAPIを作成できるだろう。今回はこの辺でおわる。
- おわりに -
手探りの部分もあったが何とかできた。
コードは雑多だが以下に公開している。コメントはよしなにください。