読者です 読者をやめる 読者になる 読者になる

Stimulator

機械学習とか好きな技術話とかエンジニア的な話とかを書く

Active Object Localization with Deep Reinforcementを作っている

- はじめに -


この記事はDeep Learning Advent Calendar 2015の24日目の記事です.


Deep Q-Network(DQN)がNIPSで発表されてから*1はや2年.

DQNは, 深層強化学習として一分野を確立し, 機械学習分野自体の活発さ, Deep Learningの話題性も相まって, 怒涛の勢いで新しい研究成果が発表されています.*2

この記事では, DQNを物体認識タスクに応用したActive Object Localization with Deep Reinforcement Learning[ PDF ]について解説, 実装を行っていこうと思います.*3

(12/24現在まだ思うように実装が出来てませんすいません)

 

- 背景 -

そもそも物体認識タスクとは, 画像から特定の物体を検出するタスクの事を指します.

例えば, 私が昨年kazoo04 Advent Calendarで行ったものがそれですね.
この時の記事では, 画像の中からkazoo04という特定の物体(人物)を検出しています.

この記事で私は, kazoo04かどうか分類する学習器にかける前に, 様々な大きさの窓をスライドし, その枠内を入力としています(sliding window, Exhaustive Search).
この手法は, かなり以前から物体認識タスクで多く用いられていましたが, 元の画像サイズや窓を移動させる幅によって, かなりの計算時間がかかってしまう問題がありました.

またConvolutional Neural Networks(CNN)という画像認識に強いDeep Learningの手法が流行し, 学習器の性能が格段に向上しました.
過去のブログの記事にもしていますが, それまでSHIFTやHOGのような特徴量抽出を挟む事で実現していた認識処理を学習器1つで行えるようになりました.

こうのように機械学習による画像認識の精度が高まって行く中で, Exhaustive Searchで計算時間を使っていてはリアルタイム認識なんかは無理だよねという流れが出てきました. また, sliding windowのスケールの違いによって入力も違うため, 誤認識が発生するという問題もありました.

そこで, Exhaustive Searchのスケールによる誤認識を減らすための画像処理手法や計算量を減らす手法*4, 物体検出に対して効率的な手法*5が出てきたり, CNN以外でもsliding windowの欠点を補うようなRandom Forest的手法*6が出てきたりしました.


そして, 次の大きな成果としてR-CNNというモデルが現れました. R-CNN*7はGirshickらが2014年に提案した手法で, 先に物体が入る窓を推定*8し, その窓(window, bounding box)を入力としています.

CNN部分では, その入力を分類する学習に加えて, そのbox自体を矩形回帰するように学習させる事で, 物体の場所検出とその分類を同時に行う事ができるようになります.
分類と回帰を同じ学習器でも行うという点でも, CNN, ニューラルネットワークの強みを活かした手法です.

R-CNNは当初, ネットワークの大きさ等から認識に時間がかかるという問題がありましたが, Fast R-CNN*9のような改良手法が提案され, 今回用いてるPascal VOC(http://host.robots.ox.ac.uk/pascal/VOC/)というデータセットにおいてもかなりの結果を出しています.

近年では上記のようなCNNで抽出した特徴量をRNNにつなげることで, センテンス表現の学習を行う手法(Image Captioning*10 )等も発表されています.


今回のDQNを用いた手法は, 今までの趣向とは少し異なっており, トップダウンな探索によって物体の位置を検出するアルゴリズムになっています.
f:id:vaaaaaanquish:20151225005241p:plain

最初に画像の大きな領域を入力とし, 動的にその入力領域を変化させていきます. 動的な探索において強化学習(Q-learning)な技術が使われています.
また, マルコフ決定過程(MDP)に基づいた動的探索ステップを複数回行う事によって, その回数だけ複数の物体を検出する事も可能にしています.

 

- Q-learningな部分 -

一般的なQ学習の要領を用います. 以下に行動と報酬を示します.

  • 行動

行動は8つのActionと終端条件(trigger)に分かれています.
Actionは, 入力範囲の上下左右の移動と拡大縮小です.
f:id:vaaaaaanquish:20151225000420p:plain

また, 全てのActionは以下の2式によって制御することができます.

{
\alpha_{w}=\alpha * (x_{2}-x_{1})
}
{
\alpha_{h}=\alpha * (y_{2}-y_{1})
}

{\alpha}は幅を制御するパラメータで, 論文中では{\alpha=0.2}として固定の値を用いています.
{\alpha}が大きければ探索が雑になり, 小さければ時間がかかるという事が感覚的にもわかると思います.

また, Triggerは1つの探索終了を示します.
inhibition-of-return(IoR)*11を参考に, 探索が終了した時点で, box内に十字のマークを挿入します. これは, 次の探索で同じ領域がゴールになることを防ぐためです.
f:id:vaaaaaanquish:20151225000813p:plain
報酬の定義を基に複数回繰り返す事で, 複数の物体を認識する事を可能にしています.

  • 報酬

報酬関数にはIoU(Intersection-over-Union)を用います. IoUは, boxである{b}(box)に対して, 目的となる領域{g}(ground truth box)がどれだけ含まれているかとなります.

{IoU(b,g)=area(b \cap g)/area(b \cup g)}

IoUを用いて, 状態{s}において行動{a}を行って状態{s^{'}}に遷移する時の報酬関数{R}は, 以下のように定義されます.

{R_{a}(s,s^{'})=sign(IoU(b^{'},g)-IoU(b,g)) }

ある状態のIoUから次の状態へのIoUの差ですね.
またこの値は正負がbinaryで制御されます.

{\\
R_{\omega}(s,s^{'}) = \cases{
\eta & if IoU(b,g)\(\leq \tau\) \cr
\\-\eta & otherwise \cr
}
}

しきい値{\tau}を超えていれば{+\eta}, 無ければ{\\-\eta}の報酬という形です.
論文中では{\eta=3.0, \tau=0.6}に設定されています. {\eta, \tau}は経験則によるものが大きく, {\tau}はデカすぎるとなかなか達成できないので0.6という感じみたいです.

  • 学習手法

学習では, パラメータを初期化した後, 目標となる{g}に対して率直に+-で進むよう行動を選択していきます.複数目的があった場合は, 内1つがランダムに選択されます.

しかし, すべて貪欲に動いていれは汎化性能が上がらないため, 全てのtraining画像に対して学習を終えた後, ε-greedy法を用いた学習を行い探索します.

論文中では, ε-greedyなTrainingを15epoch分回しますが, 最初の5epochで{ε}を1~0.1に線形に下がるよう設定しているようです.

また, boxのスタート地点は4隅から, 全体の75%のサイズで始めます.

 

- CNNな部分 -

CNNのネットワーク構成は, 以下のようになっています.
f:id:vaaaaaanquish:20151225001924p:plain

実際にQ-learningを適応しているのは後ろ3層のみです.
これについては論文中でも言及されており, 前層のpre-trainingによって学習収束速度が向上すること, 全体を学習するにはさらに大きなデータセットが必要と考えられる事などから, 今後の研究課題であるとされています.
ちなみにpre-training層は分類器としてVOCデータセット学習したCNNの特徴抽出部分を用いています.

入力は224*224に正規化された画像のベクトル.
出力はActionとTriggerを含む9ユニットです.
出力で強化学習におけるQtableを再現するイメージです.

NNの部分は誤差逆伝搬法(back propagation)による最適化, Dropoutによる正則化を用いています.

また, 過去10Actionをbinary形式で保存したaction historyと呼ばれるユニットをQ-Network以前の層に挿入しています.
これにより短期的に良い行動を学習する事が可能となり, 精度にして3%前後の向上が見られるようです.


Deep Q-Networkの学習機構については, 日本語であれば次の記事が分かりやすいかと思います.
DQNの生い立ち + Deep Q-NetworkをChainerで書いた - Qiita

 

- 実装 -

こんな感じなので少しまって

*1:Playing Atari with Deep Reinforcement Learning - http://arxiv.org/pdf/1312.5602.pdf

*2:自分もここ1ヶ月くらいで本腰入れて調査した程度なので詳しくはないです

*3:この記事で用いてる画像は論文中から引用したものです

*4:http://www.kyb.mpg.de/fileadmin/user_upload/files/publications/pdfs/pdf5070.pdf

*5:http://www.cv-foundation.org/openaccess/content_cvpr_2015/papers/Gonzalez-Garcia_An_Active_Search_2015_CVPR_paper.pdf

*6:http://www.habe-lab.org/habe/pdf/2011SSII_SIHF.pdf

*7:http://arxiv.org/abs/1311.2524

*8:bjectness. どちらかというとコンピュータビジョンな技術が多い. Selective Search(最初のR-CNN), BING等様々な手法がある

*9:http://arxiv.org/abs/1504.08083

*10:Deep Visual-Semantic Alignments for Generating Image Descriptions, http://cs.stanford.edu/people/karpathy/deepimagesent/

*11:http://www.cnbc.cmu.edu/~tai/readings/tom/itti_attention.pdf