はじめに

毎年恒例、ハロウィーンの時期になるとうまい棒が600本届き、ハロウィーンパーティをします。 そして、今年はうまい棒40周年だそうです!
折角なので、大量のうまい棒をスマホで撮影し、画像認識をさせて、何味か当てるという「うまい棒パッケージ判定器」をつくってみました。

「うまい棒パッケージ判定器」は今回 pytorch を使用しました。

開発環境

  • pytorch 1.3.0
  • python3系
  • GPU(Geforce 980M)
  • Docker

データ

うまい棒の画像が300枚ほどです。すべてiPhone 8 Plusのカメラで撮りました。

味は以下の通りです。

  • チーズ
  • チキンカレー
  • コンポタージュ
  • 海老マヨ
  • 牛タン
  • 明太
  • 納豆
  • サラダ
  • たこ焼き
  • 照り焼き
  • とんかつ

全11種類となります。本当はラスク味もあったのですが、あまりにも大人気だったため、撮影する前に売り切れてしまいました(笑)

使用するネットワークモデル

VGG16を使います。

今回は torch_vision に入ってる VGG モデルを転移学習します。

コードで書くと以下のようになります。

import torch.nn as nn
from torchvision import models

net = models.vgg16(pretrained=True)
net.classifier[6] = nn.Linear(in_features=4096, out_features=11)

実装方法

データセットの実装

データのディレクトリは data/train/[各種味]/****.jpeg となっています。

# class MyDataset(data.Dataset)の実際にデータを作る部分です
img_path = self.file_list[index]

# 画像の事前処理
img = Image.open(img_path)
# self.transform は torchvision の transforms.Compose です
# データオーギュメントも transforms.Compose の中でやっています
# phase は学習か検証の2つを持っていて、検証ではデータオーギュメントは不要なので分岐しています
img_transformed = self.transform(img, self.phase)

pattern = './data/train/.+/'
# data/train/[各種味]/****.jpeg のディレクトリ構造から味のラベルを取る
label = re.match(pattern, img_path).group().replace('./data/train/', '').replace('/', '')

# ラベルを数値に変換します
if label == "cheeze":
    label = 0
elif label == "chikencurry":
    label = 1
elif label == "compota":
...

return img_transformed, label

データローダーの実装

pytorch の utils をそのまま使用しています。

# transform は標準化もしています
train_dataset = MyDataset(
    file_list=train_list, transform=ImageTransform(size, mean, std), phase='train')
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)

ネットワークモデルとフォワードの実装

VGG16の転移学習なのでフォワードの定義などは不要です。

net = models.vgg16(pretrained=True)
net.classifier[6] = nn.Linear(in_features=4096, out_features=11)
# out_features は味の数に応じて変えてください

損失関数の実装

クロスエントロピー誤差を使用します。

nn.CrossEntropyLoss()

学習と API サーバースクリプト作成

150epoch で回しました。 GPU を使用しているため、10分もかからずに終わると思います。 VGG16はそこまで重いネットワークモデルはないので CPU でも十分にできると思います。

今回は学習に使う環境を Docker で作っています。

FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04

ENV DEBIAN_FRONTEND=noninteractive

RUN apt-get update
RUN apt-get -y upgrade
RUN apt-get -y install python3
RUN apt-get -y install python3-pip
RUN apt-get -y install nano wget curl
RUN apt-get -y install python3-setuptools python3-wheel

#RUN python -m pip install -U pip
RUN pip3 install --upgrade pip

RUN pip3 install matplotlib pillow opencv-python

RUN apt-get update && apt-get install -y libopencv-dev
RUN pip3 install torch torchvision
RUN pip3 install jupyter
RUN pip3 install pandas
RUN pip3 install tqdm

ENV PYTHONIOENCODING utf-8

WORKDIR /workspace
COPY data data
COPY utils utils
COPY train.py .

出来上がったモデルを折角なので API サーバーにして、アプリなどで遊べるように組んでみました。 以下、サーバーのスクリプトになります。こちらは pytorch のオフィシャルサイトに書いてあるのを真似すると簡単に作れます。

app = Flask(__name__)
app.config['JSON_AS_ASCII'] = False

# モデル読み込み
model = models.vgg16(pretrained=False)
model.classifier[6] = nn.Linear(in_features=4096, out_features=11)
net_weights = torch.load('<保存したモデルの path>',
                         map_location={'cuda:0': 'cpu'})
model.load_state_dict(net_weights, strict=False)
model.eval()

# 実際に予測するメソッドです
def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model(tensor)
    num, label = outputs.max(1)
    return num, label, label_converter(label)

# id から日本語に変換するものです
def label_converter(int):
    labels = ["チーズ", "チキンカレー", "コンポタージュ", "海老マヨ", "牛タン", "明太", "納豆", "サラダ", "たこ焼き", "照り焼き", "とんかつ"]
    return labels[int]

@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        img_bytes = file.read()
        class_num, class_id, class_name = get_prediction(image_bytes=img_bytes)
        return jsonify({'umaibou_id': class_id.item(), 'umaibou_label': class_name, 'umaibou_acc': class_num.item()})

if __name__ == '__main__':
    app.run()

サーバーを建てて、 curl で叩くと以下のようになります。

$ curl -X POST -F file=@<うまい棒の画像データ> http://localhost:5000/predict
{"umaibou_acc":3.3103041648864746,"umaibou_id":0,"umaibou_label":"チーズ"}

最終的に9割ほどの精度はでたので、まぁ及第点かなというところでした。

おわりに

うまい棒40周年おめでとう!

(株)やおきんさんありがとう!

うまい棒は安く美味しいので食べ過ぎには注意しましょう!!