Stimulator

機械学習とか好きな技術話とかエンジニア的な話とかを書く

PyTorchで学習済みモデルを元に自前画像をtrainしてtestするまで

- はじめに -

最初のステップとなる「学習済みのDeep Learningモデルをpre-train modelとして自分が用意した画像に対して学習」する時のメモ。

多分これが一番簡単だと思います。

 

- 準備 -

バージョンはtorch (0.4.1)、torchvision (0.2.1)の話をする。

pip install torch
pip install torchvision

学習済みモデルはpytorchの画像向けパッケージとなるtorchvisionでもサポートされている。
torchvisionで扱えるモデルは以下(2018/09/15 時点)

  • AlexNet
  • VGG
  • ResNet
  • SqueezeNet
  • DenseNet
  • Inception v3

参考:torchvision.models — PyTorch master documentation

 
最近はすごいスピードで他の高精度モデルや、仕組みの違う学習済みモデルが出てきてるので、pytorchのpretrainモデルを使う場合のサポートpackageを使うと良さそう。
以下のどちらでも良い。
GitHub - creafz/pytorch-cnn-finetune: Fine-tune pretrained Convolutional Neural Networks with PyTorch
GitHub - Cadene/pretrained-models.pytorch: Pretrained ConvNets for pytorch: NASNet, ResNeXt, ResNet, InceptionV4, InceptionResnetV2, Xception, DPN, etc.

pip install cnn_finetune
pip install pretrainedmodels

上記のtorchvisionに加えて以下が簡易に扱えるようになる(2018/09/15 時点)

  • ResNeXt
  • NASNet-A Large
  • NASNet-A Mobile
  • Inception-ResNet v2
  • Dual Path Networks
  • Inception v4
  • Xception
  • Squeeze-and-Excitation Networks
  • PNASNet-5-Large
  • PolyNet

モデルは「どのモデルがどんな感じの精度なん?」というのは以下READMEにimagenetでの精度比較表が載ってるので参考に。
https://github.com/Cadene/pretrained-models.pytorch#evaluation-on-imagenet

それぞれのモデルへのリンクも以下に存在する。
https://github.com/Cadene/pretrained-models.pytorch#documentation


 

- pretrainモデルで簡易に学習する -

cnn_finetuneの方がちょっとばかり楽できるので今回はcnn_finetuneベースで薦める。

分類するクラス数とモデルの入力となる画像サイズ、pretrained=Trueを指定して実行すると、学習済みデータがダウンロードされて読み込まれる。

from cnn_finetune import make_model
import torch

# cnn_futureを使う場合
model = make_model('pnasnet5large', num_classes=2, pretrained=True, input_size=(384, 384))
# pretrainedmodelsを使う場合
# model = pretrainedmodels.__dict__[model_name](num_classes=10, pretrained='imagenet')

# 'cuda' or 'cpu'
device = torch.device('cuda')
model = model.to(device)

  
学習のためのデータセットとしては、header付きのhogehoge.csvなる「学習画像の名前(ImageName)」と「学習画像に対するラベル(ImageLabel)」がある想定。こんなん。

ImageName,ImageLabel
0001.jpg,1
0002.jpg,1
0003.jpg,0
0004.jpg,1

また「学習画像は/hogehoge/train/なる配下に全て入っている」想定。

学習にはDataset、DataLoaderというクラスを利用する必要がある。
今回は雑に2クラスのデータセットを想定して書く。

# must: pip install pillow, pandas
from PIL import Image
from torch.utils.data import Dataset
import pandas as pd
import os
import torchvision.transforms as transforms


class MyDataSet(Dataset):
    def __init__(self, csv_path, root_dir):
        self.train_df = pd.read_csv(csv_path)
        self.root_dir = root_dir
        self.images = os.listdir(self.root_dir)
        self.transform = transforms.Compose([transforms.ToTensor()])
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        # 画像読み込み
        image_name = self.images[idx]
        image = Image.open( os.path.join(self.root_dir, image_name) )
        image = image.convert('RGB') # PyTorch 0.4以降
        # label (0 or 1)
        label = self.train_df.query('ImageName=="'+image_name+'"')['ImageLabel'].iloc[0]
        return self.transform(image), int(label)

train_set = MyDataSet('hogehoge_train.csv', '/hogehoge/train/')
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True)

シンプルに画像とラベルを返す__getitem__とデータの大きさを返す__len__を実装するだけです。

バッチサイズは画像の大きさに合わせて調整すると良いです。
デカすぎると皆大好き「RuntimeError: cuda runtime error: out of memory」になります。

 
Optimizerを選ぶ。学習済みモデル使うならSGDでええんちゃうんと思ってるけどベストプラクティスは謎。

import torch.nn as nn
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

 
これで後は学習回すだけ。

import datetime

def train(epoch):
    total_loss = 0
    total_size = 0
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        total_loss += loss.item()
        total_size += data.size(0)
        loss.backward()
        optimizer.step()
        if batch_idx % 1000 == 0:
            now = datetime.datetime.now()
            print('[{}] Train Epoch: {} [{}/{} ({:.0f}%)]\tAverage loss: {:.6f}'.format(
                now,
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), total_loss / total_size))

for epoch in range(1, 10 + 1):
    train(epoch)

以下examplesを参考にした。
https://github.com/creafz/pytorch-cnn-finetune/blob/master/examples/cifar10.py


 

- modelを保存する -

state_dictはモデルの構造だけ保存。
普通にsaveするとGPU等device関連情報も一緒に保存するため、別環境で動かす時面倒らしい。

torch.save(model.state_dict(), 'cnn_dict.model')
torch.save(model, 'cnn.model')

参考:https://pytorch.org/docs/master/notes/serialization.html


 

- predictする -

モデルを読み込む。
modelをそのままsaveした場合はloadで簡易に読み込めるが、state_dictした場合は以下のように。

import torch
from cnn_finetune import make_model

# モデル定義
model = make_model('pnasnet5large', num_classes=2, input_size=(384, 384))
# パラメータの読み込み
param = torch.load('cnn_dict.model')
model.load_state_dict(param)
# 評価モードにする
model = model.eval()

 
テスト時もDataLoaderが必要になるが、今回は上記train時に作成したMyDataSetクラスをそのまま使う。

test_set = MyDataSet('hogehoge_test.csv', '/hogehoge/test/')
test_loader = torch.utils.data.DataLoader(test_set, batch_size=32)

 
torch.no_gradとした上でmodelにデータを入力するだけ。
皆大好きclassification_reportを出す。

# must: pip install scikit-learn
from sklearn.metrics import classification_report

pred = []
Y = []
for i, (x,y) in enumerate(test_loader):
    with torch.no_grad():
        output = model(x)
    pred += [int(l.argmax()) for l in output]
    Y += [int(l) for l in y]

print(classification_report(Y, pred))

出力はSoftmax使えばクラス数分 [クラス1の予測値, クラス2の予測値, ...] となってるので、argmax取ってやれば予測クラスを出すことができる。


 

- おわりに -

最近インターン生にオススメされてPyTorch触り始めて「ええやん」ってなってるので書いた。

ちょっと複雑なモデル書く時の話や torch.distributed 使う話も気が向いたら書くと思うけど、TensorFlow資産(tensorbordとか)にも簡単に繋げられるし、分散時もバックエンド周りを意識しながら書きやすいので結構良い感じする。


 
追記:2018/09/16

以下の部分でtransforms内のメソッドを利用して入力正規化とかaugumentationも出来るのですが、今回省いています。

self.transform = transforms.Compose([transforms.ToTensor()])

そしたら「正規化はpretrainに合わせてやった方がいいのでは?」みたいな話がTwitterで発生しました。

結論としては、多分やったほうが良いみたいになったのですが、確証が今の所ないのでtransforms.Normalizeとかtransforms内の色々試して比較すべきみたいな感じです。

もしpretrainに合わせて正規化をしたい場合は以下にmean, stdが載っているので使うと良いと思います。
pretrained-models.pytorch/pretrainedmodels/models at master · Cadene/pretrained-models.pytorch · GitHub