- はじめに -
NIPS 2016のSiamese Neural Networks for One-shot Image Recognitionを参考に、画像の距離学習を行う。
Siamese Networkは、各クラスの画像量にバラつきがあり、一部クラスが数枚しかない学習データでも上手く学習させられるネットワークである。
「特徴量同士の距離が近い画像」を探す事で、分類や検索といった問題を解くために利用できる。
やりたい事としては以下のような感じ。
本記事は、以下の記事で利用したcnn_finetuneライブラリを用いて、pnasnet5largeをpretrainとした、画像の距離学習のtrain, testを行うメモである。
また、最後にfood-101データセットを用いた結果の例を示す。
- train -
距離関数とSiamese Networkを定義し、trainする。
最も簡単な構成で、pre trainから実現する。
距離関数の設計
距離関数には、古典的なContrastiveLossを利用する。
より一般的には、この距離関数の設定がtest時の精度に大きく影響するため、対象となるデータやモデルの大きさに応じて、変更すると良い。
import torch class ContrastiveLoss(torch.nn.Module): def __init__(self, margin=1.0): super(ContrastiveLoss, self).__init__() self.margin = margin def forward(self, x0, x1, y): diff = x0 - x1 dist_sq = torch.sum(torch.pow(diff, 2), 1) dist = torch.sqrt(dist_sq) mdist = self.margin - dist dist = torch.clamp(mdist, min=0.0) loss = y * dist_sq + (1 - y) * torch.pow(dist, 2) loss = torch.sum(loss) / 2.0 / x0.size()[0] return loss
モデルの定義
前述した通り、cnn_finetuneライブラリを用いてpnasnet5largeのimagenet pretrainモデルを導入し、そのネットワークの中間層出力を用いてSiamese Networkを構築する。
上図のInput, hidden layerをpnasnetの特徴量抽出部分に変更し、後段の層も深くしてみる。
from cnn_finetune import make_model import torch.nn as nn resize = (256, 256) # 入力画像サイズ class Identity(nn.Module): def __init__(self): super(Identity, self).__init__() def forward(self, x): return x def make_pnas(): model = make_model('pnasnet5large', pretrained=True, input_size=resize) model.module._classifier = Identity() return model class SiameseNetwork(nn.Module): def __init__(self): super(SiameseNetwork, self).__init__() self.cnn = nn.Sequential( make_pnas(), nn.Linear(4320,500), nn.ReLU(inplace=True), nn.Linear(500, 10), nn.Linear(10, 2)) def forward(self, input1, input2): output1 = self.cnn(input1) output2 = self.cnn(input2) return output1, output2
SiameseNetworkは別のネットワークから2つの出力を行う形となる。
データ読み込み部の設計
学習のためのデータセットとしては、前回記事同様、以下のようなheader付きのhogehoge.csvなる「学習画像の名前(ImageName)」と「学習画像に対するラベル(ImageLabel)」がある想定。
ImageName,ImageLabel 0001.jpg,1 0002.jpg,3 0003.jpg,0 0004.jpg,3
画像の変形によるアップサンプリングを行いながら、等確率で「同じラベルの画像」「別ラベルの画像」を返す実装を以下に示す。
transformには、pretrainで利用されているImageNet画像の平均、分散値を利用する。
import torch import os import pandas as pd import pickle import torchvision.transforms as transforms from torch.utils.data import Dataset from PIL import Image resize = (256, 256) # 入力画像サイズ trans= [transforms.Resize(resize), transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3), transforms.RandomHorizontalFlip(), transforms.RandomAffine(0.3, shear=0.3), transforms.RandomRotation(degrees=30), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))] class MyDataSet(Dataset): def __init__(self, root_dir): self.transform = transforms.Compose(trans) self.train_df = pd.read_csv('hogehoge.csv') self.root_dir = './train' self.images = list(self.train_df.ImageName.unique()) self.labels = list(self.train_df.ImageLabel.unique()) def __len__(self): return len(self.images) def image_open(self,t): image = Image.open( os.path.join(self.root_dir, t) ) return image.convert('RGB') def __getitem__(self, idx): # labelに対して画像を選択 source_label = self.labels[idx] source_image_name = self.train_df.query('ImageLabel=="{}"'.format(source_label)).sample(1)['Image'].iloc[0] # labelに対して同じラベル、違うラベルをそれぞれ50%で返す if random.randint(0,100)<50: target_image_name = self.train_df.query('ImageLabel=="{}"'.format(source_label)).sample(1)['Image'].iloc[0] label = 1 else: target_image_name = self.train_df.query('ImageLabel!="{}"'.format(source_label)).sample(1)['Image'].iloc[0] label = 0 # 画像ロード source_image = self.image_open(source_image_name) target_image = self.image_open(target_image_name) return self.transform(target_image), self.transform(image), label kwargs = {'num_workers': 1, 'pin_memory': True} train_set = MyDataSet() train_loader = torch.utils.data.DataLoader( train_set, batch_size=32, shuffle=True, **kwargs)
実際は同じ画像を学習しないだとか、少ないラベルをなるべくサンプリングするだとか、現在のモデルの精度によって当たらないラベルを多めにサンプリングするなどすると良い。
optimizer
前回記事あまり考えずSGDを選択する。criterionには前述したContrastiveLossを利用する。
import torch.nn as nn import torch.optim as optim criterion = ContrastiveLoss() optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
train run
学習を回す。
from torch.autograd import Variable def train(epoch): model.train() for batch_idx, (x0, x1, labels) in enumerate(train_loader): labels = labels.float() # x0, x1, labels = x0.cuda(), x1.cuda(), labels.cuda() x0, x1, labels = Variable(x0), Variable(x1), Variable(labels) output1, output2 = model(x0, x1) loss = criterion(output1, output2, labels) optimizer.zero_grad() loss.backward() optimizer.step() torch.save(model.state_dict(), '../model/model-epoch-%s.pth' % epoch) model = SiameseNetwork() # model.cuda() for epoch in range(1, 100): train(epoch)
- test -
データが少なければ全てのデータを特徴量にして、全てのtrain, testデータに対して総当たりで距離を計算すれば良い。
高次元ベクトル検索ライブラリを利用して検索する。
もし保存したtrainのモデルを利用したい場合は、前回記事を参照。
データ読み込み部の設計
test用にデータを読み込むだけのクラスを定義する。
import os import pandas as pd import torchvision.transforms as transforms from torch.utils.data import Dataset from PIL import Image tran = [transforms.Resize(resize), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))] class MyDataSet(Dataset): def __init__(self, root_path, csv, pg=False): self.df = pd.read_csv(csv) self.root_path = root_path self.images = list(self.df.ImageName.unique()) self.transform = transforms.Compose(tran) def __len__(self): return len(self.images) def __getitem__(self, idx): image_name = self.images[idx] image = Image.open( os.path.join(self.root_path, image_name) ) return self.transform(image), image_name
test_set = MyDataSet('./test/', './test.csv') test_loader = torch.utils.data.DataLoader(test_set,batch_size=1, shuffle=False) train_set = MyDataSet('./train/', './train.csv') train_loader = torch.utils.data.DataLoader(train_set,batch_size=1, shuffle=False)
trainとtestどちらも特徴量化して、testの1画像をqueryとして、train内から近いものを探す形で、分類問題や検索問題を解く。
そのため、trainは全て特徴量にしておく。
test run
画像を特徴量にする。
もしデータが大量の場合は、DBに保存したり、後述する高次元ベクトル検索ライブラリを利用する。
import numpy as np import torch.nn as nn model = model.eval() aap2d = nn.AdaptiveAvgPool2d(output_size=1) # aap2d.cuda() # - test images - test_output = [] test_output_name = [] for batch_idx, (x,name) in enumerate(test_loader): # x = x.cuda() output = model.cnn[0].module.features(x) output = aap2d(output).squeeze() test_output.append(np.array(output.cpu().tolist())) test_output_name.append(name) # - train images - train_output = [] train_output_name = [] for batch_idx, (x,name) in enumerate(train_loader): # x = x.cuda() output = model.cnn[0].module.features(x) output = aap2d(output).squeeze() train_output.append(np.array(output.cpu().tolist())) train_output_name.append(name)
全てのデータから探索する
全てのデータ同士のユークリッド距離を計算してやるやつのサンプル。
kaggleなど全データが少ない時に使える。
for tx,tname in zip(test_output, test_output_name): dists = [] for y in train_output: dists.append(np.linalg.norm(tx-y)) # 小さい順に j = sorted(list(zip(dists, train_output_name)), key=lambda x: x[0]) print(tname[0], j[:10]) break
検索ライブラリに突っ込む
全部探索してたら時間がすごいかかるので、高次元ベクトル検索ライブラリを利用する方法がある。
検索では私は大体何も考えずにnmslibに突っ込む(導入が最も簡単)。
# pip install nmslib # https://github.com/nmslib/nmslib/tree/master/python_bindings import nmslib index = nmslib.init(method='hnsw', space='cosinesimil') index.addDataPointBatch(train_output) index.createIndex({'post': 2}, print_progress=True) ids, distances = index.knnQuery(test_output[0], k=19) print(ids) print(distances)
検索ライブラリの比較:qiita.com
その他、Yahoo! JAPAN社もNGTなるライブラリ出してるので要検討。
GitHub - yahoojapan/NGT: Neighborhood Graph and Tree for Indexing High-dimensional Data
food-101による例
一応例としてfood-101データセットを学習して、queryに対して近い画像を出したものを示す。
記事上部のドーナッツは上手くいった例で実際はこんな感じになる。
左がquery、近い順に左から画像が並んでいる。ドーナッツに関してはドーナッツが収集できており、似た画像が拾ってきていると言える。
2行目のアイスクリームをqueryにした場合、2番目の画像についたラベルはアップルパイであった。視覚的には似てるけど。
3行目はスパゲッティで検索しているが、実際に出てくる画像は「オムレツ、パエリア、リゾット、パスタ、ハンバーガー(イカリング)」となっている。気持ちはわかるけど。
まあ大体3行目のような感じで間違えるので、後処理なんかで色々やってやると良さそう。
food-101 datasetは画像の明暗が激しかったり、人がインスタに載せるような角度だったり、以下のような画像が混ざっていて厳しい所もあるので前処理も重要そうだという知見が得られた。上記は、左から「アイスクリーム」「アイスクリーム」「ハンバーガー」「ピザ」「ブレッドプディング」である。わからん。
パスタやステーキ、餃子、寿司、ティラミス、マカロン、…といった他物体や人が映る可能性の低いものもあるので、タスクに応じて、カテゴリを絞ったり細分化されている所をマージするなどすると良さそう。
- おわりに -
Pytorchのpretrainモデルを利用したSiamese Networkを構築した。
verification modelの拡張としてtripret lossを利用したり、partなモデルに拡張してより細かな物体同士の距離を用いたりできるので、いつか記事として書く。GitHubにも上げる。
特に他意はない。