Stimulator

機械学習とか好きな技術話とかエンジニア的な話とかを書く

mambaやripのinstallが何故早いのか調べたメモ

- はじめに -

最近、PythonのパッケージインストーラーであるpipをRustで書き直したripというツールが公開された。

github.com

ripのREADME.mdには、flaskを題材に依存解決とインストールが1秒で終わるようなgifが貼られている。

この速さは一体どこから来ているのか調べた。

 

- 宣伝 -

来週開催の技術書典15で「エムスリーテックブック5」が出ます。

私の内容は「自作Python Package Manager入門」で、CLIツールの作り方から始まって40ページでPyPIの仕様やその背景となっている要素を把握しながら、lock、install、run、build、uploadといったサブコマンドを実装してPackage Manager開発者になろうという内容です。Python開発経験2年くらいあれば作れると思います。

本誌ではまさかの「いろんな言語のパッケージマネージャ比べてみた」というパッケージマネージャネタ被りが起きており、こちらでは同僚がCargoやnpm、poetry、go mod、pnpm、yarnといったツールの内部実装を比較しています。本記事と同様にPackage Manager完全理解者の道を歩む事ができる本になっています。

techbookfest.org

# オンライン開催

  • 会期:2023/11/11 (土) 〜2023/11/26(日)
  • 会場:技術書典オンラインマーケット

# オフライン開催

12日は家庭の事情で午後からにはなりますが多分会場に居ますのでよろしくお願いします

- ripの成り立ち -

ripは、prefix-devというOrganization配下にある。
このprefixという会社は、実はAnacondaやその周辺ツールと関係があるため、まずその成り立ちから書いておく。

Anaconda

Anacondaは、知っての通りAnacondaリポジトリや周辺のcondaエコシステムに影響を及ぼす偉大な企業である。
あまり特別な情報はないが、会社のHistoryを見ていて2012年に出来た企業だと知って「思ったより若い」って思った。
もっとこうずっとあるイメージだった。

mamba-org

2022年、condaのエコシステムに大きな影響を与えた、「Mambaプロジェクト」というものがある。
フランスにjupyterやConda-forgeの開発者が集まるQuantStackという会社があり、そこに所属していた開発者@wuoulfが主軸になったOSSプロジェクトである。QuantStackの他の開発者も多くMambaプロジェクトに参画している。
BloombergやNumFOCUSなどから出資を受けているプロジェクトでもある。

Anacondaリポジトリやconda-forgeなどのcondaエコシステムへのアクセスは、CLIツール「conda」が長く利用されてきた。condaは大部分がPythonで書かれたツールであるのに対し、Mambaプロジェクトではcondaの互換性を保った形で多くをC++で再実装したツール「mamba」を開発、提供している。mambaはcondaに比べ、依存解決やパッケージのbuild方法の変更、並列化を行う事でも高速化されており、CLIツールとして表面だけ見てもより手軽にcondaエコシステムにアクセスできるようになっている。

github.com

そもそもMambaプロジェクトは、先に挙げたような便利な代替CLIツールの再実装やそれらの高速化とは別に、以下のような問題を解決するために立ち上がったプロジェクトであり、CLIツールはその結果の1つにあたる*1

故にMambaプロジェクトでは、CLIツールのconda代替「mamba」以外に、conda packageホスティングサーバの「Quetz」mamba内で高速にPackage buildするためのconda-build代替「boa」などをそれぞれ並行して開発している。
非常に大きいプロジェクトである。

中でもCLIツールmambaの実装が、condaエコシステムユーザにとって大きな恩恵をもたらしているという話である。
内部で利用される依存解決アルゴリズム等を含むコア実装「libmamba」も、QuantStackらにより開発された後、NumPy、SciPy、Jupyter、Matplotlib、scikit-learnなどPythonやAI/ML関連のライブラリのコア開発者らが所属するコンサルティング企業Quansightによって、condaに移植される事になる。これにより、condaはv23.10.0から大幅に高速化した。
www.anaconda.com

prefix.dev

先述のconda-forgeのコア開発者でもありmambaの発案者でもある@wuoulfは、QuantStackでのmambaプロジェクト後に「prefix.dev」という会社をドイツで立ち上げる。パッケージマネージャー開発を主軸とする会社であり、毎年開催されるパッケージマネージャーの国際カンファレンス「PackageCon」の主催企業でもある。

prefix.devは、conda packageの思想を拡張した、cross-platformmulti-languageに対応した高速かつ軽量なpackage managerである「pixi」を推して開発している。

世間一般ではcondaに関する言及の殆どがPythonやRに関するものであるため「condaエコシステムはPythonのためのもの」のような誤解があるが、conda packageの仕組みはaptやrpm等と変わらない。そのため、実際はどんな言語、どんなツールでもパッケージとして配布出来るし、installの仕組みを作る事ができる。prefix.devでは、conda package形式の高速なインデックスのホスティングhuggingface.coのような実行環境まで開発しており、本気で様々な言語のPackage Managementをpixiで飲み込んでやろうと企む気持ちが伺える。これらのバックエンドは、先に示したMambaプロジェクトで開発されたツールや関連の技術が使われている。
prefix.dev

当然、prefix.devはmambaやlibmamba、conda関連のツールの開発も牽引し続けている。
その他にも、prefix.devはパッケージマネージャーに関わるツールをRustで開発している。
著名な依存解決アルゴリズムbacktrackingのRust実装「resolvelib-rs」、libsolv(CDCL)のRust実装「resolvo」、condaエコシステムとのAPIのやり取りをRustでwrapした「Rattler」、そして今回の本題の1つ「rip」など、精力的に新しいツールを開発している。

mambaが全体的にC++で書かれており、そこから得られた知見を活かして、多くをRustで書き直そう、というスタイルのようだ。イカしてる。本当に頑張って欲しい。

- condaがinstallで行うこと -

話はロジック面に入る。先に示した通り、conda packageはPython以外でも扱えるようなフォーマットであり、condaエコシステムないしconda packageの構造は、PythonPyPIのものとはかなり違う*2。中身はさておき、Package Managerが気にするべきは、そのメタデータ取得方法になる。

Pythonパッケージの場合は多くの場合METADATAファイルを読みパッケージ情報を取得するが、conda packageの場合はrepodata.json(またはcurrent_repodata.json)である。repodata.jsonを含むファイルの構成は以下のようになっている。

https://docs.conda.io/projects/conda-build/en/stable/concepts/generating-index.html より

この構成のリポジトリに対して、condaが行うinstall作業は以下のようになる。

  1. 関連するrepodata.jsonファイルを全てダウンロードしてメモリ上に乗せる*3
    • repodata.json.bz2のような圧縮形式にもアクセスできるので必要あれば展開する
    • repodata.jsonの中にdependsがあり依存関係情報が入っている
  2. すべての依存関係の中から環境で使用される可能性のあるパッケージを絞り込む
    • 例: cudaが必要な場合はそうでない物を削る
  3. SATとしてSATソルバを再帰的に実行する
    • SATソルバを呼び出す前に優先度決定、pruningを毎回実行する
      • 例: 満たす必要のなくなった依存関係を先んじて削ってしまう
      • 例: 可能な限り最新バージョンを使う
      • 例: pruningとして削除された関連ファイルが多いパッケージを優先する
      • パッケージ内の優先度をパッケージ開発者が決めれたり*4もする
  4. SATソルバの結果得られたパッケージをinstallする
    • install_scriptを走らせる
    • 必要に応じてhard link/soft linkが使い分けられる

SATソルバは、元々PicoSATだったものが、CryptoMiniSatになり、現在はlibmamba(libsolv, CDCLの拡張実装)となっている。
この場合、速度において重要なのは、SATソルバとCaching戦略と依存解決を行う優先度の3つになる。この辺りはcondaの長い歴史の中で磨かれてきたものがあり、優先度ロジックは以下の「Running the Solver」にまとまっている。1つ1つはあまり難しいものではないので以下記事を参照して欲しい。
www.anaconda.com

mambaは、condaの積み重ねてきたテクニックを踏まえつつ、大きく再実装を実施することで速度面を改善している。

- mambaでの速度改善 -

mambaの大掛かりなC++採用において、依存解決ロジック周辺の変更は速度に大きく影響している。
condaは、PicoSATのPython wrapperであるpycosat、CryptoMiniSat(msoos/cryptominisat)のPython wrapperであるpycryptosatを依存解決に利用していた。先のcondaのinstallのロジックの説明の通り、condaの実装はSATソルバと他ロジックを複数回行き来しており、依存解決はC++、優先度決定やCachingの戦略はPythonという形になっている。Pythonとのやり取りがボトルネックになるため、mambaではrepodata.jsonを取得後、前章②の段階からC++に情報を渡す実装となっている。これにより「複数回のSATソルバ呼び出し」が無くなった。また、C++を直接扱えた事で、libsolv、libarchive、libcurlなどの強力なC++資産をそのまま扱う実装にもなっている。「優先度決定等のためのPythonC++間のオブジェクトの行き来」も大幅に無くなり、これらが速度に影響を与えている*5

依存解決のロジックは、openSUSEなどで利用、開発されている信頼と実績のあるlibsolvを叩く形に変わった。
アルゴリズムは変わったが、その際のベンチマーク自体には大きな変化は現れていないようなので、上記のC++再実装によるリファクタ効果が大きかったように感じられる。

全体感としては、開発者の@wuoulfによる記事やPackageConの動画があるため、そちらを参照するとよい。
wolfv.medium.com
www.youtube.com


また、repodata.json.zstといったストリーム可能な圧縮形式を用いてパッケージ情報取得を高速化している。ファイルのダウンロードにおいても、dnfなどでも使われるlibrepoをガッツリC++で書き直した「powerloader」も開発している。これにより、ファイル取得の並列化やrestart可能な分割ダウンロード、zchunkなどをサポートしている。powerloaderを利用して、repodata.jsonが更新されても更新されたbitのみ取得する方法も実装されており、ファイルダウンロードの側面からも速度が改善されている*6

 

- ripに応用されたこと -

ここまでで得られたテクニックを利用し、Rustでまるっと書き直したのがripにあたる。
Rustで書くことで、asyncで高速に依存解決とinstallができるようにな実装になっている。依存解決はlibsolv(CDCL)のRust実装「resolvo」をprefix.devが自前で作り利用している。これによって、PythonC++な部分が無くなり*7簡素で高速な実装になっている。前述したpixiにもresolvoが使われており、issue上ではpubgrub-rs作者との交流もあり、PubGrub*8導入などさらなる進化の余地を考えられているといった所だろう。

ロジック面のポイントとして、resolvoに入っているIncremental solving*9という考え方が高速化に繋がっている*10*11
Pythonパッケージにおいては正確なメタデータPyPI APIから返ってこないため、METADATAファイルに何らかの形でアクセスする必要がある。これがいかんせん高コストである。なので優先度付きキューにメタデータ取得処理を詰め込んで、Solverは非同期的にPackageの情報を取得、追加しながら探索を行う。故にIncremental。また、その際の探索パッケージの優先度決定はPubGrubのdecision makingに似た考え方を採用しており、最新バージョンを優先しながらもなるべく依存パッケージが少なくなる方向性に向かう*12。この実装により、Solverがメタデータ取得等で殆ど止まる事なく依存解決を行う事が出来ている。またファイルの分割ダウンロード等も実装されており、resolvelib-rsやpowerloader等を作成した経験、pubgrub-rsへの貢献を経て作成された依存解決ライブラリである事が伺える。

注意点として、Pythonパッケージのフォーマットはsdist/bdistの2種類があるが、ripはsdistに対応していない。
condaエコシステムと違い、PyPIの依存解決が高コストな理由うちの1つにsdistフォーマットがある。近年ではbdist(wheel)が十分広まりつつあるので、比較的新しいバージョンを指定すれば正常かつ高速に動作するかもしれない。一方でsdist未対応につき、installできないパッケージやバージョンが存在するという欠点にも繋がっている。また、ベンチマークでsdistが依存関係に入るものと比較する場合、それを実際使わずともsdist対応パッケージマネージャーは勝てない要素が多くなるので、公正なベンチマークが求められる。この辺りが今後asyncと絡んだ時にどうなるかポイントで、実際ripがsdistに対応した時の速度がどのようになるか未知数だと思われる。

- おわりに -

今回、mambaやripについて調べた。
元々conda周り知ってないとなと思って調べてあったが、エムスリー エンジニアリングフェローの@SassaHeroが気になっていたので「技術書典の熱も余ってるし書くか〜」と思い書いておいた。


コードは読んでいるが、condaをユーザとして使い倒しているわけではないので、もし間違いがあればこのブログのリンクと一緒に参考リンクをXで呟いておいて欲しい。


Rustのパッケージマネージャーといえば一時期ryeが話題になったりもしたが、ryeはRustとはいえ中身は殆ど既存のPython資産を叩いているもので、実際既存ツールと同じ問題に行き着いている様子が伺えていた。もっとRust寄りのパッケージマネージャーだとhuakがあるが、ずっとWIP状態である*13。しかしながらripの精力的な開発を見ていると、JSやCLIツール群がそうであるように、PythonのPackage Manager周辺にもRustが増えそうな感じがしてくる。依存解決などの難しい部分をprefix.devがRust化しているというのが明らかに大きい。Linterには最近よく使われる所でRuffがあるし、env環境マネージャーもyenのように挑戦者が居る。必要最低限のツールが出揃いつつある中で、シレッとprefix.devがこのまま全てRustなPython Package Managerを出してしまう気もする。
今後もPyPA等が公に導入することはなかなか無いだろうけど、こういったconda系統を経由してサードパーティとしてRustが増えていく感じがあるのは、Pythonの多様なユースケースの結果と捉えると面白い。


なお、こういったPyPI周辺のお話を日本語で知りたい場合は「PyPI APIメタデータ取得はどうなっているのか」「sdist/bdistとはどういう歴史で生まれた何なのか」「どうPythonパッケージをインストールするのが正解か」「これからこの問題はどのようになっていくのか」を技術書典の本の方に書いたので参照されたい。

techbookfest.org


是非どうぞ

*1:Mambaプロジェクト創設者らのブログより

*2:conda installのドキュメントもしくはconda開発者の2015年スライドが全体感を掴むのに良い

*3:fetching

*4:track-features

*5:実装は読んだが私が実際にベンチマークした訳ではないので実態は不明

*6:libmamba vs classic — conda-libmamba-solverが詳しい

*7:正確にはPythonのパッケージ情報を扱うpackagingだけvendoringされているが

*8:PythonのPackage Managerを深く知るためのリンク集を参照して欲しい

*9:#349

*10:公式のブログが後日出るらしい https://prefix.dev/blog/introducing_rip#step-2-make-the-solver-lazy

*11:こちらも詳しい https://github.com/pypa/pip/issues/7406#issuecomment-583989243 https://github.com/pubgrub-rs/pubgrub/issues/138

*12:solver/decision*.rs

*13:PyPAメンバーの@uranusjrがシレッと公開していたmoltも今見たら更新されていなかった…

Google Cloud Champion Innovator になりました

お知らせ

Google Cloud Champion Innovator になりました。

https://developers.google.com/profile/u/108992975007665801883

Cloud AI/ML領域です。

昨年度、ありがたい事にGoogle Cloud OnAirやGoogle Cloud Innovators Hive at Next ’22に登壇させて頂き、その中での活動で推薦頂きました。

1e100.4watcher365.dev
cloud.google.com

それら以外にもMLOpsイベントの開催など、広く活動していた事が良かったみたいです。

国内だとK_Ryuichirouさんとも一緒です。

検索してみたところpolar3130さんの記事しか見当たらなかったので、そもそもの認知度を上げるために記事にしておきました。

polar3130.hatenablog.com

真摯に技術と向き合い続けていきたいなと改めて思う次第です。
 

LLMの登場で、AI/ML業界も一変していく事になりそうです。

形は様々ですが、人、組織、技術、業界全ての側面で推進に貢献出来ればと思っています。

イベント等、是非皆さんとご一緒出来ればと思っていますので、今後とも何卒よろしくお願いします。

河合俊典 (@vaaaaanquish)

最適輸送本イベントに寄せて学ぶ

はじめに

Forkwell Libraryという書籍の著者が登壇するイベントにて、最適輸送の理論とアルゴリズム (機械学習プロフェッショナルシリーズ) の佐藤さん(@joisino_)と話す時間を頂いた。
forkwell.connpass.com

スライド
動画

その時に事前に学んだメモの公開と、当日のイベントの肌感を残す。

最適輸送の理論とアルゴリズム

MLPシリーズの書籍

最適輸送の理論的な背景から応用まで書かれている。
私個人としては、幾何や統計、測度についてお気持ちレイヤーまで分かる、機械学習コンピュータサイエンスなら少しわかる、くらいの私でもちゃんと読めるように、深く入り込まず難しい所に例示を出して優しく書いてくれている上、応用事例まで付いてくるありがたい本。


サポートページとしてGitHubリポジトリも用意されている。
github.com
書籍内にあるアルゴリズム、最適輸送での画像操作事例など、jupyter notebookで各サンプルが動かせるようになっている。
一通り動かしたが、数式で分からない所は大体この例示で解決する気がする。ありがたい。


ざっくりこんな感じになっている。

1章 各種定義
2章 最適化問題としての定式化
    輸送計画とリサイクル業者のイメージ、双対問題
    ボールによる物理的解釈
    最適輸送問題の疎性
    組合せ最適化、線形計画法、最小費用流問題(minimum flow cost problem)
3章 エントロピー正則化、シンクホーンアルゴリズム
    エントロピー正則化つき問題は強凸、最適解が一意に定まる、微分可能
    シンクホーンアルゴリズムは「ソフトなC変換」であり「行列スケーリング」
    行列計算なのでGPU+畳み込みで高速化できる
    シンクホーンは大域収束性、計算量
4章 GANと最適輸送
5章 スライス法、1次になおして貪欲法
    カーネル的、木による階層クラスタリング
6章 KL、JSダイバージェンス、MMD、ワッサースタイン距離
7,8章 不均衡最適輸送→ワッサースタイン重心
9章 グロモフワッサースタインで2つの異なる分布を比較

ML屋視点でしんどいのは2章の確率測度が関わる証明と3章の計算量の証明辺り。
イベント内で佐藤さんも「証明は全て読み込む必要はないので是非読み進めてもらって」と仰っていたので、そこを抜ければ4章以降は線形代数とMLでよく使われる知識があればスッと読み進める事が出来る。

7章で「不均衡最適輸送って応用があるのねふむふむ…」くらいに思ってたら、恐ろしくスムーズな流れで9章で「あれ、最適輸送めっちゃ便利じゃん…」ってなる。

ML屋がツールとしての最適輸送を学ぶ上で最良だと思う(宣伝)。

事前学習

何に使われているか。

「2つの分布の距離として使う、比較する」「2つの分布の中間状態を捉える」というツールと捉える事ができる。


Wasserstein GANでMLから着目を浴びたとされているが、WGANの話は書籍を読むと流れが把握できる。


つまり色々使える。
KLダイバージェンスの強い版、教師なしアラインメントな損失関数と捉える事もできる。
直近スケーリング則より良い学習データをサンプルした方が良いという動きもあり、分布を利用したSamplerにもなり得る。
複数のDNNモデルのタスクを複合的に扱ったり、埋め込み空間としての活用も進んでいきそう。

 

何が嬉しくて使われているのか

端的にMLにとって嬉しい点をまとめる

  • 複数の問題を並列に扱える
  • データが少ない、教師なしでも扱える
  • コストが微分可能
  • 分布の全体の状態、内部の状態を上手く扱える
  • 2つの分布に重なりが無くても近似できる
    • 数理最適化では計算量が大きくなる
    • 最適輸送におけるシンクホーンアルゴリズムやスライス法が良い
      • これらが行列計算なのでGPU上で扱える

事前、並行して読むと良いもの

大まかにまとめてくれているもの

  • 最適輸送の解き方, Ryuma Sato
    • 最適輸送の解き方
    • 今回のイベントスピーカーでもある佐藤さんのスライド
    • それぞれの手法のメリデメが端的かつ分かりやすくまとめられている

イベント当日のQ&A

いくつか良かったQ&Aのメモ、順不同

  • KLダイバージェンスと比較していたが、f-divergence、Bregman divergence、integral probability metricsと比較すると
    • 最適輸送はIPMの一種、クロスエントロピーはf-divergenceの一種
    • 最適輸送はカスタマイズ性と一般性が良いよ
    • 6章読むといいぞ!
  • 競プロで使える?
    • 最小費用流問題なので知っておいて損はないよ
    • ICPCの後ろの方とかでは出てくるよ!
  • 最適輸送で分かってないトピックはある?
    • 連続分布をサンプリングしたら点の集合が得られるが、その時に精度良く近似になっているかとか
    • 分布の最悪ケースとか
      • 最悪ケースとは?
      • いわゆる意地悪な分布に対してどこまで出来るのか、定式化
  • 時系列では扱える?
    • 時系列でDTWっぽいことをやる、も研究されてるよ
    • 物理とか生物では、細胞の動きのモデリング等で大昔から使われている
  • 広義の拡散モデルと言えたりする?
    • 拡散モデルとの繋がりは指摘されている
    • フォッカー・プランク方程式 (Fokker–Planck equation)とかで検索すると良いよ
  • 最適輸送という訳が「輸送コスト最適化」のようなイメージを植え付けるのでは?「分布の比較」としての名前を付けるとしたら?
    • 難しい、そもそも英語がOptimal Transportなので
    • 古い論文ではOptimal Transportionだった、MLブームにつれてOptimal Transportに
  • 理論をプログラムに落とすの難しい、という人が多そうだけどどう思う?
    • シンクホーンなど理論の割にはコードが簡単になるのが人気の理由の1つではないか
    • 最適輸送に関しては簡単というのが印象
  • 基礎、応用で読むべき本はある?
    • 「最適輸送の理論とアルゴリズム」をまず読んでみて欲しい
    • Computational Optimal Transport
      • ちょっと広い範囲、応用も含まれている
      • 公式pdfがarXivに公開されてるので是非
    • Optimal transport, old and new
      • 数学者が書いた1000ページある本
      • 数学的な細かい定義、どういう条件が揃ったら使えるか等

おわりに

Forkwellの人がお手上げになり、最適輸送本の回はモデレータ出来ないかもという事で1週間程前にお話を頂いて、最適輸送をツールとして何となく知っているくらいから一気に学習したので、学部のゼミの気分だった。

難しい。

イベントでは、かなり優しく最適輸送の事例を紹介して頂いて、ありがたかった。

ありがて〜
  

Rustでグラフをplotするライブラリのまとめ

- はじめに -

Rustでグラフを描画したいと思った時に調べたクレートとその実装、機能のまとめた時のメモ。

現状はplottersを使っておけば間違いなさそうだが、目的によっては機能で選択する場合もありそう。


 

- 前提知識 -

グラフの描画までの機能としては、matplotlibのようにaxisやviewを構造体として持っているライブラリもあれば、受け取った配列をそのままgnuplotスクリプトに変換するライブラリもある。
詳細は後述するが、当然この構造に依存してインターフェースが変わったり、出来ること出来ないことがある。


plotを想定したグラフデータの出力方法は大きく3つに分かれる。
SVG等を通して画像ファイルとして出力する方法、jsやwasmやhtmlテンプレートエンジンを利用してHTMLベースで出力する方法、テキストベース(アスキーアート)として表示する方法である。

また、Jupyter NotebookのRust Kernelとして現状開発が継続しているものにevcxrというライブラリがあり、こちらに出力する事が出来るかも差別化の点に入る。
github.com


OpenCV等の画像処理系ライブラリを用いてもグラフの描画はもちろん行えるが、今回はグラフ描画を軸としたライブラリの調査であり対象とはしない。

 

- グラフ描画クレートざっくりまとめ -

2021/09/21時点での大まかな実装とライブラリをまとめる

plotters

A rust drawing library for high quality data plotting for both WASM and native, statically and realtimely 🦀 📈🚀
latest commit: 2021/09/17, star: 1.5K
github.com

以下参考に成り得る文献

plotly

Plotly for Rust
latest commit: 2021/07/15, star: 467
github.com

plotlib

Data plotting library for Rust
latest commit: 2021/02/01, star: 335
github.com

  • 非常にmatplotlibを意識したであろう実装になっている
  • 開発は滞り気味
  • Vecやndarrayに対応
  • 自前でaxisやviewの構造体を持っている
  • textでの描画、svgクレートを使った画像での描画に対応
  • matplotlibに似た思想のAPIを持つ
    • matplotlibにおけるfigure、axesがview, plotに当たる

以下参考に成り得る文献

poloto

A simple 2D plotting library that outputs graphs to SVG that can be styled using CSS.
latest commit: 2021/09/17, star: 28
github.com

以下参考に成り得る文献

rustplotlib

A pure Rust visualization library inspired by D3.js
latest commit: 2021/07/13, star: 1116
github.com

  • D3.jsをまるっとrustで書き直している
  • 描画がかなり綺麗な印象
  • 最後はsvgクレートでSVGに書き出している (.to_svg)
  • 発想としてはかなり壮大なプロジェクトだが、更新はしばらく止まっていそう
  • multiviewなどに未対応だが今後開発されるかは実装を見る限り微妙そう
    • 対応するplot形式を沢山作る方針っぽい

RustGnuplot

A Rust library for drawing plots, powered by Gnuplot
latest commit: 2021/09/01, star: 324
github.com

以下参考に成り得る文献

preexplorer

Externalize easily the plotting process from Rust to gnuplot.
latest commit: 2021/09/06, star: 4
github.com

vega_lite_4.rs

rust api for vega-lite v4
latest commit: 2021/01/22, star: 7
github.com

  • Pythonで言う所のAltair
  • Vega Lite(vega-lite.js)にJsonAPIがあるのでそれを叩くための実装を用意したもの
  • nalgebraやndarray、rulinalg等の主要な行列ライブラリに対応している
  • showtaを作っている人と同じ
    • https://github.com/procyon-rs/showata
      • HTMLを生成するためのツール
      • jupyter notebook上に描画する事を目的としている
      • tableと画像をHTMLに変換するためのツール
    • showtaを経由してevcxrで表示できる
  • version4に対応したもので、vega_lite_3.rsも存在する

dataplotlib

Scientific plotting library for Rust
latest commit: 2017/10/14, star: 57
github.com

chord_rs

Rust crate for creating beautiful interactive Chord Diagrams.
latest commit: 2021/01/07, star: 22
github.com

  • Chord Diagramsを描画するためだけのクレート
  • Chord PROなるAPIを叩くclientであり、描画機構については分からない

- アスキーアート系のクレート -

plotlib等でも対応しているが、CLIなどで扱えるようにtext形式でplotするクレートがいくつかある。以下簡単に。

plotという強気の命名がここにある

- 記事外で参考になりそうな記事 -

- おわりに -

まとめた。

Rustでデータ分析する所までやるユーザはあまりいなさそうで、定常分析や監視に使うならHTMLレンダリングは筋が良さそう。
一方plottersが一番活発に開発されているので、何を選択しましょうかという感じ。

 

axumとtch-rsでRustの画像認識APIを作る

- はじめに -

PyTorchのRust bindingsであるtch-rsを使って、画像認識APIを実装する時のメモ。

今回は非同期ランタイムのtokioと同じプロジェクト配下で開発されているaxumを利用する。


 

- axumによるHTTPサーバ構築 -


RustでHTTPサーバを立てるライブラリはいくつかある。現状日本語ではCyberAgent社の以下のブログが詳しい。

developers.cyberagent.co.jp

私自身にあまり選定ノウハウが無いので、今回はtokioから出ているaxumを利用する。

axumを利用してHTTPサーバを構築するにあたっては、repository内のexampleディレクトリに複数の実装サンプルが配置されている他、Tokioのreleaseにも簡単なkickstartが存在するので、そちらを見ながら開発を進めた。

github.com

 

hello world

Cargo.tomlを作成する。

[package]
name = "rust-machine-learning-api-example"
version = "0.1.0"
authors = ["vaaaaanquish <6syun9@gmail.com>"]
edition = "2018"

[dependencies]
axum = "0.2.2"
tokio = { version = "1.0", features = ["full"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"

localhostにpostリクエストを投げる事でjsonをやり取りするサンプルを書く。

use axum::{handler::post, Router, Json};
use serde::{Serialize, Deserialize};
use serde_json::{json, Value};
use std::net::SocketAddr;

#[tokio::main]
async fn main() {
    let app = Router::new().route("/", post(proc));
    let addr = SocketAddr::from(([0, 0, 0, 0], 3000));
    println!("listening on {}", addr);
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await
        .unwrap();
}

#[derive(Deserialize)]
struct RequestJson {
    message: String,
}

#[derive(Serialize)]
struct ResponseJson {
    message: String,
}

async fn proc(Json(payload): Json<RequestJson>) -> Json<Value> {
    Json(json!({ "message": payload.message + " world!" }))
}

responseはimpl IntoResponseで実装されたものを返す事ができる。ドキュメント内のbuiliding responses節にString、HTML、Json、StatusCodeなどを返す実装イメージが掲載されているので参考にすると良い。routeやMiddlewareを付与する場合も同様に参照すると良い。

cargo runして、以下のhello文字列を投げると「hello world!」になって帰ってくる。

curl -X POST -H "Content-Type: application/json" -d '{"message":"hello"}' http://localhost:3000

 

base64による画像の受信

一旦無難にbase64で画像をやり取りする事を考える。Cargo.tomlに以下を追記する。

base64 = "0.13"
image = "0.23"

先程のスクリプトのpayload.messageを読んでいた箇所をbase64へデコードし、画像として保存するよう変更してみる。

extern crate base64;
extern crate image;

...

    let img_buffer = base64::decode(&payload.message).unwrap();
    let img = image::load_from_memory(img_buffer.as_slice()).unwrap();
    img.save("output.png").unwrap();

clientサイドとして、rustのロゴを取得してbase64エンコードした文字列を投げるPythonスクリプトを書いてみる。

import base64
import json
import requests      # require: pip install requests

sample_image_response = requests.get('http://rust-lang.org/logos/rust-logo-128x128-blk.png')
img = base64.b64encode(sample_image_response.content).decode('utf-8')
res = requests.post('http://127.0.0.1:3000', data=json.dumps({'message': img}), headers={'content-type': 'application/json'})

output.pngとしてRustのロゴ画像がcargo runしているディレクトリにできれば良い。適宜組み替える。

f:id:vaaaaaanquish:20210907135228p:plain
output.png (rust-lang.org/logos/より)

 

ExtensionLayerによるstate管理

機械学習APIなのでMLモデルを一回読み込んでグローバルに扱いたい。axumではExtensionLayerという機能を用いて、stateを実装できる。

https://docs.rs/axum/0.2.3/axum/#sharing-state-with-handlers

ここでは試しにHashSetをstateとしてみる。
AddExtensionLayerを使って、先程のAPIを同名の画像は保存しないように改修してみる。

use axum::{handler::post, Router, Json, AddExtensionLayer, extract::Extension};
use serde::{Serialize, Deserialize};
use serde_json::{json, Value};
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::Mutex;
use std::collections::HashSet;

extern crate base64;
extern crate image;

struct DataState {
    set: Mutex<HashSet<String>>
}

#[tokio::main]
async fn main() {
    let set = Mutex::new(HashSet::new());
    let state = Arc::new(DataState { set });

    let app = Router::new()
        .route("/", post(proc))
        .layer(AddExtensionLayer::new(state));

    let addr = SocketAddr::from(([0, 0, 0, 0], 3000));
    println!("listening on {}", addr);
    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await
        .unwrap();
}

#[derive(Deserialize)]
struct RequestJson {
    name: String,
    img: String,
}

#[derive(Serialize)]
struct ResponseJson {
    result: String,
}

async fn proc(Json(payload): Json<RequestJson>, Extension(state): Extension<Arc<DataState>>) -> Json<Value> {
    let img_buffer = base64::decode(&payload.img).unwrap();
    let mut set = state.set.lock().await;

    let result;
    if set.contains(&payload.name) {
        result = "skip by duplicated";
    } else {
        let img = image::load_from_memory(&img_buffer.as_slice()).unwrap();
        img.save(&payload.name).unwrap();
        set.insert(payload.name);
        result = "saved output image";
    }
    Json(json!({ "result": result }))
}

先程のPythonスクリプトにname keyを付与してpostしていく。

res = requests.post('http://127.0.0.1:3000', data=json.dumps({'img': img, 'name': name}), headers={'content-type': 'application/json'})
print(res.text)

名前が重複したItemの場合は保存処理が走らず「skip by duplicated」なる文字列が返ってくる。名前がまだHashSet内にない場合はlocalディレクトリに画像が保存され、「saved output image」なる文字列が返ってくるようになった。

インメモリなので一度サーバを落とすと消えてしまうが、機械学習モデルをインメモリに保持する用途であれば十分だろう。

 

- tch-rsによる推論 -

PyTorchのRust bindingsでpretrain済みのモデルを流用して、推論を行うサンプルを過去に公開している。

github.com

こちらを流用して、推論を行うstateを作成しAddExtensionLayerに流す実装を行う。

tch-rsをCargo.tomlに追加する

tch = "0.5.0"

Arc>で囲むようにモデルのstructを定義する

...
use tch::nn::ModuleT;
use tch::vision::{resnet, imagenet};

extern crate tch;

struct DnnModel {
    net: Mutex<Box<dyn ModuleT>>
}

#[tokio::main]
async fn main() {
    let weights = std::path::Path::new("/resnet18.ot"); 
    let mut vs = tch::nn::VarStore::new(tch::Device::Cpu);
    let net:Mutex<Box<(dyn ModuleT + 'static)>> = Mutex::new(Box::new(resnet::resnet18(&vs.root(), imagenet::CLASS_COUNT)));
    let _ = vs.load(weights);
    let state = Arc::new(DnnModel { net });
...

RustのFutureは難しい部分がいくつかあり、私も把握しきれていないが、大まかな外枠は以下を見る事ですぐ把握できる。
zenn.dev
tech.uzabase.com

 
推論部分は一度画像を保存して読み込む形を取る。

...
    let net = state.net.lock().await;
    let img_buffer = base64::decode(&payload.img).unwrap();
    let img = image::load_from_memory(&img_buffer.as_slice()).unwrap();

    let _ = img.save("/tmp.jpeg");
    let img_tensor = imagenet::load_image_and_resize224("/tmp.jpeg").unwrap();
    let output = net
        .forward_t(&img_tensor.unsqueeze(0), false)
        .softmax(-1, tch::Kind::Float);

    let mut result = Vec::new();
    for (probability, class) in imagenet::top(&output, 5).iter() {
        result.push(format!("{:50} {:5.2}%", class, 100.0 * probability));
    }
...

ローカルに画像を保存せずメモリバッファ経由で実装する方法としてload_image_and_resize224_from_memoryが実装されているが、まだreleaseには至っていないようだ。もう少しでインメモリ上で推論が完結しそうである。
github.com


以下のRustロゴ画像を投げてみる

f:id:vaaaaaanquish:20210907135228p:plain
rust logo (rust-lang.org/logos/より)

レスポンスは以下のようになった。

 {
  "result": [
    "buckle 26.54%",
    "wall clock 5.34%",
    "digital watch  5.32%",
    "analog clock 4.14%",
    "digital clock 3.71%"
  ]
}

Rustのロゴはバックルか時計からしい。まあ概ね良さそう。

同様の方法を利用して、PythonのPyTorchで学習したモデルをRust bindings上で再現し推論を行うAPIを作成できるだろう。今回はこの辺でおわる。

 

- おわりに -

手探りの部分もあったが何とかできた。

コードは雑多だが以下に公開している。コメントはよしなにください。

github.com


 

Rustでlabel propagationを実装した

- はじめに -

教師あり学習アルゴリズムの1種であるlabel propagationをRustで実装し、クレートとして公開した。

github.com

本記事は、label propationの実装と検証を行った際のメモである。

 

- label propagationとは -

label propagationは、transductive learningの枠組みの1つでもあり、グラフ構造を利用した機械学習アルゴリズムである。

 
ラベルがあるデータ、ラベルのないデータ、それらを繋ぐエッジがある状態で、ラベルのないデータに付くラベルを推定する事が解きたいタスクとなる。
最もシンプルな実タスクとして例示すると「文書データ等で一部のデータにはラベルがあるが一部欠損している所を推定したい」「ユーザとアイテム、それらを繋ぐPV等のエッジがあり、アイテムにのみラベルがある状態でユーザにもラベル付けを行いたい」といった状況が想定できる。

近年ではCVPR 2019でEmbeddingによる距離をノードとして画像ラベルを推定して利用する手法*1が採択されるなどしており、汎用的なアルゴリズムの1つである。近いワードとしては、tag recommendationなどがあり、PageRankアルゴリズムを利用した手法*2やCollaborative filteringを拡張する手法*3が提案されている他、Content baseな方法もまた考えられる。

実際エムスリーではtag propagationを利用したtag伝搬を用いてユーザのタグ付けを行い、様々な配信のセグメント分けや分析に利用している*4。ハイパーパラメータが少なく、グラフ生成部及び内部の行列計算手前までをオンライン化する事ができ、汎用性が高く安定した結果を得られる所が良いところである。

 
label propagationの問題設計は、 (x_{n}, y_{n})をラベル付きデータにY_{N}=y_{1},...y{n}のC個のラベルが付与されていた時、そこから観測できないuのデータに紐付いたY_{U}=y_{n+1},...y_{n+u}を推定する事にある。データ間の重みwは、古典的にユークリッド距離dとハイパーパラメータ\alphaを用いて簡素に以下のように表現される。

 w_{ij} = exp \bigg( - \frac{d_{ij}^{2}}{\alpha^{2}} \bigg) = exp  \bigg( - \frac{ \sum_{d=1}^{D} ( x_{i}^{d} - x_{j}^{d} )^{2} }{\alpha^{2}} \bigg)

これは最も簡素な例で、距離に関しても時に離散的な距離であったりDNNのEmbeddingから得られる距離であったりする。wを作るためには、(n+u)\dot(n+u)の確率遷移行列を作ってやればよい。

行列の最適化のためのアプローチは、いくつか方法があるが、概ね以下が詳しい。

ベースは、グラフ上で隣接するノードは同じラベルを持つ可能性が高い、という所に基づいて設計した目的関数を最小化することでweight行列を最適化する。「隣接ノードが同じラベル」の閾値をパラメータや推論によってコントロールする拡張が主である。

 
label propagationは、Pythonではsklearn内にも実装されており、簡易に呼び出す事ができる。
sklearn.semi_supervised.LabelPropagation — scikit-learn 0.24.2 documentation

 
よりグラフィカルな解説は以下が参考になる。オススメ。


 

- Rustによる実装 -

先に示した通り、確率遷移行列を作って最小化できれば良いので、行列演算を行う事になる。

今回はndarrayを利用して実装している。rust/ndarrayのドキュメント内にnumpyからの移行のススメがあるので、基本的にはここを参照すると良い。

docs.rs

   
numpyにはadvanced-indexingという機能がある。

Indexing — NumPy v1.21 Manual
こういうやつ

x = np.array([0, 1])
y = np.array([[0, 0], [0, 0], [0, 0]])

y[x] = 1

# array([[1, 1], 
#             [1, 1], 
#             [0, 0]])


rustのndarrayでは、現状実装されていないのでslice_mutで指定インデックスごとにスライスを作ってfillterによる代入を行う必要がある。

for i in x {
    y.slice_mut(s![*i, ..]).fill(1);
}

 
機械学習で行列を扱う時は大体スパースな事が多く、実装としてsparse matrixを使う事が多い。現状ndarrayにはsparse matrixに類似するものは実装されていなさそう。同じく行列演算を趣旨としたnalgebraにはnalgebra_sparse::csr::CsrMatrixがあるが、こちらはdot積などが実装されていない。
なのでArrayBaseで押し切る実装になってしまった。メモリに優しくない。linfaなど、一部ライブラリで独自にsparse matrixを実装しているものもあるが、クレート依存が激しい。

以下のクレートを試してみてはという助言を貰ったので検証中ではある。
github.com
この辺何か良い方法があるんだろうか。知っている人居れば教えて欲しい。

上記以外はdot積と行列変換が扱えれば良いのでndarrayで十分実装できる。

 

- 検証 -

irisデータセットを利用して、一部のラベルを欠損、各データのユークリッド距離をエッジと考えて、label propagationにより欠損ラベルを推論する。

公開したlabel-propagation-rsには、label propagationの派生アルゴリズムとして、LGCとCAMLPを実装しており、検証にはCAMLPを利用した。

Rustにおけるsklearnのような立ち位置になるライブラリであるsmartcoreよりirisデータセットを読み込んで行列を作る。閾値としてユークリッド距離の逆数が0.5以下になっている場合はエッジを繋がないものとする。

...
    let iris = iris::load_dataset();

    let node = (0..iris.num_samples).collect::<Array<usize, _>>();
    let mut label = Array::from_shape_vec(iris.num_samples, iris.target.iter().map(|x| *x as usize).collect())?;
    let mut graph = Array::<f32, _>::zeros((iris.num_samples, iris.num_samples));

    let data = Array::from_shape_vec((iris.num_samples, iris.num_features), iris.data)?;
    for i in 0..iris.num_samples {
        for j in 0..iris.num_samples {
            if i != j {
                let weight = 1. / (*&data.slice(s![i, ..]).sq_l2_dist(&data.slice(s![j, ..]))? + 1.);  // reciprocal
                if weight > 0.5 {
                    graph[[i, j]] = weight;
                }
            }
        }
    }
...

ざっくり10個ターゲットを選んで、ノードに付与されたラベルを0にする。irisのラベルは0,1,2のどれかなので「ランダムにあるラベルが0になってしまった」という状況になる。

...
    let target_num = 10;
    let mut rng = thread_rng();
    let target = (0..iris.num_samples).choose_multiple(&mut rng, target_num).iter().map(|x| *x).collect::<Array<usize, _>>();
    for i in &target {
        label[*i] = 0;
    }
...

モデルを学習させて、上記で0にしたターゲットのlabelを推定する。

...
    let mut model = CAMLP::new(graph).iter(100).beta(0.1);
    model.fit(&node, &label)?;
    let result = model.predict_proba(&target);

    for (i, x) in target.iter().enumerate() {
        println!("node: {:?}, label: {:?}, result: {:?}", *x, iris.target[*x], result.slice(s![i, ..]).argmax()?);
    }
...

結果は以下のようになった。

node: 0, label: 0.0, result: 0
node: 14, label: 0.0, result: 0
node: 67, label: 1.0, result: 0
node: 118, label: 2.0, result: 2
node: 43, label: 0.0, result: 0
node: 144, label: 2.0, result: 2
node: 91, label: 1.0, result: 1
node: 137, label: 2.0, result: 2
node: 49, label: 0.0, result: 0
node: 62, label: 1.0, result: 1

node 67のみ、真のラベルが1に対して推論ラベルが0となってしまっているが、それ以外は正解している。良い感じ。実際どういったデータで各metricでどの程度の精度が出るかはこれから検証していく。

上記の検証コードはexample内にある。

label-propagation-rs/examples at main · vaaaaanquish/label-propagation-rs · GitHub

 

- おわりに -

label propationの実装と検証を行い、クレートとして公開した。

まずは動く所までという感じ。

できればどこかでPythonとの比較をやりたい。

 

*1: A. Iscen, G. Tolias, Y. Avrithis, O. Chum. "Label Propagation for Deep Semi-supervised Learning", CVPR 2019 https://openaccess.thecvf.com/content_CVPR_2019/papers/Iscen_Label_Propagation_for_Deep_Semi-Supervised_Learning_CVPR_2019_paper.pdf, github: https://github.com/ahmetius/LP-DeepSSL

*2:Heung-Nam Kim and Abdulmotaleb El Saddik. 2011. Personalized PageRank vectors for tag recommendations: inside FolkRank. In Proceedings of the fifth ACM conference on Recommender systems (RecSys '11). Association for Computing Machinery, New York, NY, USA, 45–52. DOI:https://doi.org/10.1145/2043932.2043945

*3:Kim, Heung-Nam, et al. "Collaborative filtering based on collaborative tagging for enhancing the quality of recommendation." Electronic Commerce Research and Applications 9.1 (2010): 73-83. https://www.sciencedirect.com/science/article/pii/S1567422309000544

*4:エムスリーにおけるグラフ構造を用いたユーザ興味のタグ付け - Speaker Deck

Pure Rustな近似最近傍探索ライブラリhoraを用いた画像検索を実装する

f:id:vaaaaaanquish:20210810063410p:plain

- はじめに -

本記事は、近似最近傍探索(ANN: Approximate Nearest Neighbor)による画像検索をRustを用いて実装した際のメモである。

画像からの特徴量抽出にTensorFlow Rust bindings、ANNのインデックス管理にRustライブラリであるhoraを利用した。

RustとANNの現状および、実装について触れる。

 

 

- RustとANN -

Rustの機械学習関連クレート、事例をまとめたリポジトリがある。

github.com

この中でも、ANN関連のクレートは充実している。利用する場合は以下のようなクレートが候補になる。

* Enet4/faiss-rs
* lerouxrgd/ngt-rs
* rust-cv/hnsw
* hora-search/hora
* InstantDomain/instant-distance
* granne/granne
* qdrant/qdrant

Pythonでもしばしば利用されるfacebook researchのfaiss、Yahoo!のNGTのrust bindingsは強く候補に上がる。C++からGPUが触れる点から利用だけならfaissが活用しやすいだろう。

 
他にPure Rustで機能が充実しているクレートにhoraがある。
github.com

horaには、PythonJavascriptJavaのbindingsがあるだけでなく、Pure Rustである事でWebAssembly化などもサポートしている。
また、インデキシングアルゴリズムとして多く利用されているHNSWIndex以外にグラフベースのSatellite System Graph*1、直積量子化を行うProduct Quantization Inverted File*2が実装されており、開発が継続されている数少ないクレートである。
一部SIMDによる高速化が図られている(https://github.com/rust-lang/packed_simdによるもの)。

(horaの由来は「小さな恋の歌」とREADMEに書いてあるが、どういう経路で知られたのかよくわからない)

今回は、画像検索のwasm化を目指し、horaを利用する。
画像検索がwasm化する事で、API経由で行われていた画像検索の一部がエッジデバイス上で処理できる可能性などの幅が出る事を期待する。
例えば、ネット環境を扱えないや工場やサーバセンター、病院であったり、個人情報の観点でスマフォやカメラの外に出せない画像をその場で類似画像検索にかける事ができる可能性である。

 
画像特徴を抽出する部分でもwasm化を目指すため、wasmの利用実績が多いTensorFlowを利用する。

TensorFlowにはRust bindingsが存在する。
github.com

今回はこちらを利用してモデルを作成し、wasm化する。
他にもDNNのライブラリはいくつかあるが、開発が活発でないか、PyTorchのRust bindingsは現在中間層の出力を受け取る方法がないなど、機能的に難しい場合が多かった。

(実験時に作成したPyTorchのRust bindingsでpretrain modelのpredictを実行するdockerなども公開している https://github.com/vaaaaanquish/tch-rs-pretrain-example-docker

 

- pretrainモデルによる特徴量化 -

TensorFlow 2.xでのRustとPythonの相互運用に関する以下の記事を参考にした。

TensorFlow 2.xでのRustとPython

import tensorflow as tf
from keras.models import Model
from tensorflow.python.framework.convert_to_constants import \
    convert_variables_to_constants_v2

# pretrainモデルの読み込み
model = tf.keras.applications.resnet50.ResNet50(weights='imagenet')

# 中間層の出力を得るモデルにする
embedding_model = Model(inputs=model.layers[0].input, outputs=model.layers[-2].output)

# tf.functionに変換しpbファイルとしてgraphを保存できる状態にする
resnet = tf.TensorSpec(embedding_model.input_shape, tf.float32, name="resnet")
concrete_function = tf.function(lambda x: embedding_model(x)).get_concrete_function(resnet)
frozen_model = convert_variables_to_constants_v2(concrete_function)

# fileをdumpする
tf.io.write_graph(frozen_model.graph, '/app/model', "model.pb", as_text=False)

Rustのbindingsから読み込み、画像ファイルを特徴量に変換する。

// モデルファイルを読み込み、セッションを作る
let mut graph = Graph::new();
let mut proto = Vec::new();
File::open("model/model.pb")?.read_to_end(&mut proto)?;
graph.import_graph_def(&proto, &ImportGraphDefOptions::new())?;
let session = Session::new(&SessionOptions::new(), &graph)?;

// 入力画像を読み込み、リサイズしてTensorに変換する
let img = ImageReader::open("./img/example.jpeg")?.decode()?;
let resized_img = img.resize_exact(224 as u32, 224 as u32, FilterType::Lanczos3);
let img_vec: Vec<f32> = resized_img.to_rgb8().to_vec().iter().map(|x| *x as f32).collect();
let x = Tensor::new(&[1, 224, 224, 3]).with_values(&img_vec)?;

// DNNに入力する
let mut args = SessionRunArgs::new();
args.add_feed(&graph.operation_by_name_required("resnet")?, 0, &x);
let output = args.request_fetch(&graph.operation_by_name_required("Identity")?, 0);
session.run(&mut args)?;

// check result
let output_tensor: Tensor<f32> = args.fetch(output)?;
let output_array: Vec<f32> = output_tensor.iter().map(|x| x.clone()).collect();
println!("{:?}", output_array);

出力として、特徴量vectorが得られる。

 

- 画像特徴のインデックスと検索 -

horaを利用して画像検索を行う。

// init index
let mut index = hora::index::hnsw_idx::HNSWIndex::<f32, usize>::new(2048, &hora::index::hnsw_params::HNSWParams::<f32>::default(),);

// 特定ディレクトリの画像ファイルをインデックス
let paths = fs::read_dir("img")?;
let mut file_map = HashMap::new();
for (i, path) in paths.into_iter().enumerate() {
    let file_path = path?.path();
    let path_str = file_path.to_str();
    if path_str.is_some() {
        file_map.insert(i, path_str.unwrap().to_string().clone());  // ファイル一覧を作成
        let emb_vec = emb.convert_from_img(path_str.unwrap())?;     // 画像特徴を得るメソッド
        index.add(emb_vec.as_slice(), i)?;                          // インデックス
    }
}
index.build(hora::core::metrics::Metric::Euclidean).unwrap();

// 画像をqueryとして検索
let query_image = &file_map[&100]
let emb_vec_target = emb.convert_from_img(&query_image.to_string())?;  // 画像特徴を得るメソッド
let result = index.search(emb_vec_target.as_slice(), 10);              // 特徴量をqueryとし検索
println!("neighbor images by query: {:?}", query_image);
for r in result {
    println!("{:?}", &file_map[&r]);
}

これらのコードは以下に公開している。

また、上記にはfood-101データセットを用いたインデキシングのサンプルが配置してあるため、今回はそちらを利用して検索の動作確認を行った。

www.tensorflow.org

 

- 検索結果 -

query画像をランダムに選択してTop5の画像を目視でチェックする。

f:id:vaaaaaanquish:20210810060405p:plain
餃子queryとTop5
f:id:vaaaaaanquish:20210810061939p:plain
ラーメンqueryとTop5

餃子は1つだけ間違えて寿司を引いてきているが概ね良さそう。

カテゴリを利用した精度測定などが考えられるが今回はここまで。

- おわりに -

Rustによる画像検索を実装し、動作を確認できた。

エッジデバイスやスマフォ上での画像検索が出来るようになってくると、インデックスファイルを小さくしても精度が保てるモデルの研究が出てきたりするかもなと妄想することができた。

コードは以下に公開した。
github.com

wasm化した上での画像検索は出来てはいるので次はそちらを書く。