CNN(torch visionの構築済みモデル)による回帰

はじめに

画像による回帰、例えばラーメンの画像から食べログ評価を予測したり、人の顔を点数化したりなど様々な用途が考えられますが、日本語でも英語でも資料が少なく、何度か躓いたので備忘録としてまとめたいと思います。

モデルはtorchvisionのチュートリアルから、既存のクラス分類のモデルから弄るだけで動くようにします。

参考サイト

torchvisonのチュートリアル
https://medium.com/@benjamin.phillips22/simple-regression-with-neural-networks-in-pytorch-313f06910379

クラス分類からの変更点

基本的に画像による回帰は、クラス分類の機構をほとんどそのまま流用できます。
必要な変更点は以下の3点です。

  • Loss
  • 出力層
  • 正解データの付与

Loss

チュートリアルではLossがCrossEntropyになっているので、MSEやSmoothL1などの回帰に用いられるLossに変更します。

# Setup the loss fxn
criterion = nn.MSELoss()

出力層

チュートリアルをベースに弄りたいので、クラス数を1に設定する”だけ”。

# Number of classes in the dataset
num_classes = 1

ただし上記の変更だけだと出力層のノードが1024 -> 1 のようにかなり急激な減少になってしまうため、modelを以下のように変更します。(ResNetにおける一例)

model_ft.fc = nn.Sequential(nn.Linear(num_ftrs, 256),
              nn.LeakyReLU(),
              nn.Linear(256, 32),
              nn.LeakyReLU(),
              nn.Linear(32, 1))

正解データの付与

pytorchのDatasetという便利すぎるモジュールのおかげで、学習データをフォルダごとに分けておけば勝手にクラスのラベルを貼ってくれます。
しかし、今回行いたいのは回帰なのでfloatなりintなりの数字を持たせる必要があるので、Datasetを自作します。

*pytorchの自作データセットは結構情報が転がっているので詳細はそちらに任せます。

class Create_Datasets(Dataset):
    def __init__(self, path, data_transform):
        self.path = path
        self.df = self.create_csv(path)
        self.data_transform = data_transform

    def create_csv(self, path):
        image_path = []
        for file_name in glob(path + '/*.jpeg'):
            basename = os.path.basename(file_name)
            image_path.append(basename)

        df = pd.DataFrame(image_path, columns=['path'])

'''
好みの前処理
'''

        return df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        file = self.df['path'][i]
        score = np.array(self.df['good'][i])
        image = Image.open(os.path.join(self.path, file))
        image = self.data_transform(image)

        return image, score

個人的にハマった落とし穴

上記datasetで学習データを作成した際に、lossのbackwardでエラーが発生。
コードは動くがlossが下がらない等かなり時間を取られてしまいました。

RuntimeError: Found dtype Double but expected Float

解決策としてはlabelのdtypeをtorch.float32にすれば良い。

labels = labels.float().to(device)

終わりに

今回はtorchvisionのモデルをそのまま回帰に流用できるようにしました。
軽く動かしてみた感じでは、ResNetが回帰とは相性が良さそうです。
少しの変更で済む割には全然情報が見つからなかったので、まとめてみました。

おすすめの記事