Stimulator

機械学習とか好きな技術話とかエンジニア的な話とかを書く

axumとtch-rsでRustの画像認識APIを作る

- はじめに -

PyTorchのRust bindingsであるtch-rsを使って、画像認識APIを実装する時のメモ。

今回は非同期ランタイムのtokioと同じプロジェクト配下で開発されているaxumを利用する。


 

- axumによるHTTPサーバ構築 -


RustでHTTPサーバを立てるライブラリはいくつかある。現状日本語ではCyberAgent社の以下のブログが詳しい。

developers.cyberagent.co.jp

私自身にあまり選定ノウハウが無いので、今回はtokioから出ているaxumを利用する。

axumを利用してHTTPサーバを構築するにあたっては、repository内のexampleディレクトリに複数の実装サンプルが配置されている他、Tokioのreleaseにも簡単なkickstartが存在するので、そちらを見ながら開発を進めた。

github.com

 

hello world

Cargo.tomlを作成する。

[package]
name = "rust-machine-learning-api-example"
version = "0.1.0"
authors = ["vaaaaanquish <6syun9@gmail.com>"]
edition = "2018"

[dependencies]
axum = "0.2.2"
tokio = { version = "1.0", features = ["full"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"

localhostにpostリクエストを投げる事でjsonをやり取りするサンプルを書く。

use axum::{handler::post, Router, Json};
use serde::{Serialize, Deserialize};
use serde_json::{json, Value};
use std::net::SocketAddr;

#[tokio::main]
async fn main() {
    let app = Router::new().route("/", post(proc));
    let addr = SocketAddr::from(([0, 0, 0, 0], 3000));
    println!("listening on {}", addr);
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await
        .unwrap();
}

#[derive(Deserialize)]
struct RequestJson {
    message: String,
}

#[derive(Serialize)]
struct ResponseJson {
    message: String,
}

async fn proc(Json(payload): Json<RequestJson>) -> Json<Value> {
    Json(json!({ "message": payload.message + " world!" }))
}

responseはimpl IntoResponseで実装されたものを返す事ができる。ドキュメント内のbuiliding responses節にString、HTML、Json、StatusCodeなどを返す実装イメージが掲載されているので参考にすると良い。routeやMiddlewareを付与する場合も同様に参照すると良い。

cargo runして、以下のhello文字列を投げると「hello world!」になって帰ってくる。

curl -X POST -H "Content-Type: application/json" -d '{"message":"hello"}' http://localhost:3000

 

base64による画像の受信

一旦無難にbase64で画像をやり取りする事を考える。Cargo.tomlに以下を追記する。

base64 = "0.13"
image = "0.23"

先程のスクリプトのpayload.messageを読んでいた箇所をbase64へデコードし、画像として保存するよう変更してみる。

extern crate base64;
extern crate image;

...

    let img_buffer = base64::decode(&payload.message).unwrap();
    let img = image::load_from_memory(img_buffer.as_slice()).unwrap();
    img.save("output.png").unwrap();

clientサイドとして、rustのロゴを取得してbase64エンコードした文字列を投げるPythonスクリプトを書いてみる。

import base64
import json
import requests      # require: pip install requests

sample_image_response = requests.get('http://rust-lang.org/logos/rust-logo-128x128-blk.png')
img = base64.b64encode(sample_image_response.content).decode('utf-8')
res = requests.post('http://127.0.0.1:3000', data=json.dumps({'message': img}), headers={'content-type': 'application/json'})

output.pngとしてRustのロゴ画像がcargo runしているディレクトリにできれば良い。適宜組み替える。

f:id:vaaaaaanquish:20210907135228p:plain
output.png (rust-lang.org/logos/より)

 

ExtensionLayerによるstate管理

機械学習APIなのでMLモデルを一回読み込んでグローバルに扱いたい。axumではExtensionLayerという機能を用いて、stateを実装できる。

https://docs.rs/axum/0.2.3/axum/#sharing-state-with-handlers

ここでは試しにHashSetをstateとしてみる。
AddExtensionLayerを使って、先程のAPIを同名の画像は保存しないように改修してみる。

use axum::{handler::post, Router, Json, AddExtensionLayer, extract::Extension};
use serde::{Serialize, Deserialize};
use serde_json::{json, Value};
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::Mutex;
use std::collections::HashSet;

extern crate base64;
extern crate image;

struct DataState {
    set: Mutex<HashSet<String>>
}

#[tokio::main]
async fn main() {
    let set = Mutex::new(HashSet::new());
    let state = Arc::new(DataState { set });

    let app = Router::new()
        .route("/", post(proc))
        .layer(AddExtensionLayer::new(state));

    let addr = SocketAddr::from(([0, 0, 0, 0], 3000));
    println!("listening on {}", addr);
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await
        .unwrap();
}

#[derive(Deserialize)]
struct RequestJson {
    name: String,
    img: String,
}

#[derive(Serialize)]
struct ResponseJson {
    result: String,
}

async fn proc(Json(payload): Json<RequestJson>, Extension(state): Extension<Arc<DataState>>) -> Json<Value> {
    let img_buffer = base64::decode(&payload.img).unwrap();
    let mut set = state.set.lock().await;

    let result;
    if set.contains(&payload.name) {
        result = "skip by duplicated";
    } else {
        let img = image::load_from_memory(&img_buffer.as_slice()).unwrap();
        img.save(&payload.name).unwrap();
        set.insert(payload.name);
        result = "saved output image";
    }
    Json(json!({ "result": result }))
}

先程のPythonスクリプトにname keyを付与してpostしていく。

res = requests.post('http://127.0.0.1:3000', data=json.dumps({'img': img, 'name': name}), headers={'content-type': 'application/json'})
print(res.text)

名前が重複したItemの場合は保存処理が走らず「skip by duplicated」なる文字列が返ってくる。名前がまだHashSet内にない場合はlocalディレクトリに画像が保存され、「saved output image」なる文字列が返ってくるようになった。

インメモリなので一度サーバを落とすと消えてしまうが、機械学習モデルをインメモリに保持する用途であれば十分だろう。

 

- tch-rsによる推論 -

PyTorchのRust bindingsでpretrain済みのモデルを流用して、推論を行うサンプルを過去に公開している。

github.com

こちらを流用して、推論を行うstateを作成しAddExtensionLayerに流す実装を行う。

tch-rsをCargo.tomlに追加する

tch = "0.5.0"

Arc>で囲むようにモデルのstructを定義する

...
use tch::nn::ModuleT;
use tch::vision::{resnet, imagenet};

extern crate tch;

struct DnnModel {
    net: Mutex<Box<dyn ModuleT>>
}

#[tokio::main]
async fn main() {
    let weights = std::path::Path::new("/resnet18.ot"); 
    let mut vs = tch::nn::VarStore::new(tch::Device::Cpu);
    let net:Mutex<Box<(dyn ModuleT + 'static)>> = Mutex::new(Box::new(resnet::resnet18(&vs.root(), imagenet::CLASS_COUNT)));
    let _ = vs.load(weights);
    let state = Arc::new(DnnModel { net });
...

RustのFutureは難しい部分がいくつかあり、私も把握しきれていないが、大まかな外枠は以下を見る事ですぐ把握できる。
zenn.dev
tech.uzabase.com

 
推論部分は一度画像を保存して読み込む形を取る。

...
    let net = state.net.lock().await;
    let img_buffer = base64::decode(&payload.img).unwrap();
    let img = image::load_from_memory(&img_buffer.as_slice()).unwrap();

    let _ = img.save("/tmp.jpeg");
    let img_tensor = imagenet::load_image_and_resize224("/tmp.jpeg").unwrap();
    let output = net
        .forward_t(&img_tensor.unsqueeze(0), false)
        .softmax(-1, tch::Kind::Float);

    let mut result = Vec::new();
    for (probability, class) in imagenet::top(&output, 5).iter() {
        result.push(format!("{:50} {:5.2}%", class, 100.0 * probability));
    }
...

ローカルに画像を保存せずメモリバッファ経由で実装する方法としてload_image_and_resize224_from_memoryが実装されているが、まだreleaseには至っていないようだ。もう少しでインメモリ上で推論が完結しそうである。
github.com


以下のRustロゴ画像を投げてみる

f:id:vaaaaaanquish:20210907135228p:plain
rust logo (rust-lang.org/logos/より)

レスポンスは以下のようになった。

 {
  "result": [
    "buckle 26.54%",
    "wall clock 5.34%",
    "digital watch  5.32%",
    "analog clock 4.14%",
    "digital clock 3.71%"
  ]
}

Rustのロゴはバックルか時計からしい。まあ概ね良さそう。

同様の方法を利用して、PythonのPyTorchで学習したモデルをRust bindings上で再現し推論を行うAPIを作成できるだろう。今回はこの辺でおわる。

 

- おわりに -

手探りの部分もあったが何とかできた。

コードは雑多だが以下に公開している。コメントはよしなにください。

github.com