3.6. 蛾の種同定#
種同定は、多様性調査において欠かせない作業の一つです。人工知能を活用した種同定システムを使用することで、市民が参加する多様性調査ネットワークに役立てることができ、全国規模の生態調査がより簡単に行えるようになります。本演習では、インターネットで集めた 50 種類の蛾の画像を使って、蛾の種同定モデルを構築する方法を学びます。
3.6.1. 演習準備#
3.6.1.1. ライブラリ#
本節で利用するライブラリを読み込みます。ライブラリの読み込み時に ImportError や ModuleNotFoundError が発生した場合は、該当するライブラリをインストールしてください。pytorch_grad_cam を読み込むときに ModuleNotFoundError が発生した場合は、grad-cam パッケージをインストールしてください。
# image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import PIL.Image
# machine learning
import sklearn.metrics
import torch
import torchvision
# grad-CAM visualization
import cv2
import pytorch_grad_cam
print(f'torch v{torch.__version__}; torchvision v{torchvision.__version__}')
torch v2.8.0+cu126; torchvision v0.23.0+cu126
ライブラリのインポート時に、ImportError や ModuleNotFoundError が発生した場合、必要なライブラリがインストールされていない可能性があります。該当するライブラリをインストールすることで、エラーなく実行できるようになります。
3.6.1.2. データセット#
本演習では、Kaggle において CC0 ライセンスのもとで配布された MOTHS IMAGE DATASET-CLASSIFICATION を利用します。このデータセットにはインターネットで集められた 50 種類の蛾の画像が 224 × 224 の正方形に変更した画像が用意されて、訓練サブセット、検証サブセットおよびテストサブセットに分けられています(Fig. 3.14)。訓練サブセットには種によって枚数が異なるが 100 枚から 170 枚前後の画像が含まれています。検証サブセットよびテストサブセットについては、それぞれの種には 5 枚の画像が含まれています。

Fig. 3.14 MOTHS IMAGE DATASET CLASSIFICATION に含まれる 50 種の蛾の画像例。#
Kaggle で配布されているオリジナルのデータセットについて、画像を保存しているフォルダ名にスペースが含まれています。スペースが含まれると操作ミスが起こりやすいため、本演習ではスペースをすべてアンダーバー「_」に置換したものを利用します。
修正済みのデータセットは、Jupyter Notebook 上では、次のコマンドを実行することでダウンロードできます。
!wget https://dl.biopapyrus.jp/data/moths.zip
!unzip moths.zip
3.6.1.3. 画像前処理#
物体分類用のニューラルネットワークアーキテクチャの多くは、正方形の画像を入力として受け取るように設計されています。例えば、本節で使用する ResNet-18 では、224×224 ピクセルの正方形画像を入力として想定しています[1]。また、PyTorch ではすべてのデータをテンソル形式で扱う必要があります。そのため、畳み込みニューラルネットワークに画像を入力する前に、画像サイズを適切に調整し、テンソル型に変換するなどの前処理を行う必要があります。
これらの前処理は次のように定義します。なお、本節で使用する画像データは、すでにネットワークの入力サイズに合わせて調整されているため、画像サイズを変更する必要はありません。
transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
画像データは通常、0 から 255 の範囲の整数値で表現されていますが、前処理の段階でこれを正規化します。正規化により、画像データの値は平均約 0.5、標準偏差約 0.23 の範囲に変換され、モデルの学習を効率的に進めることができます。なお、正規化の際に平均を 0.50、分散を 0.23 のような切りの良い数値にしない理由は、これから利用する torchvision.models が提供する訓練済みモデルが、特定の数値(例えば平均 0.485、標準偏差 0.229)で訓練されているためです。そのため、この訓練済みモデルに合わせて正規化を行います。
なお、独自のアーキテクチャをゼロから訓練する際に、これらの正規化パラメータを訓練データの特徴に応じて決めたりする必要があります。
3.6.1.4. 計算デバイス#
計算デバイスを設定します。PyTorch が GPU を認識できる場合は GPU を利用し、認識できない場合は CPU を利用します。
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
3.6.2. モデル構築#
本節では、物体分類を行うアーキテクチャとして ResNet18 を利用します。ResNet-18 は、深層学習のアーキテクチャの中でも比較的層が浅く、パラメーターの数も少ないため、訓練時間が速いという特徴があります。ResNet18 のアーキテクチャは、torchvision.models モジュールで提供されています。
なお、torchvision.models モジュールで提供されている ResNet-18 は、飛行機や車、人などの 1000 種類の一般的な物体を分類することを前提に設計されています。これに対して、本節では、画像を 50 種類の蛾のな目に分類するため、ResNet-18 の出力層のクラス数を 50 に修正する必要があります。この修正作業はモデルを呼び出すたびに行う必要があり、手間がかかります。そこで、一連の処理を関数化してから利用します。
def resnet18(num_classes, weights=None):
model = torchvision.models.resnet18(weights='DEFAULT')
in_features = model.fc.in_features
model.fc = torch.nn.Linear(in_features, num_classes)
if weights is not None:
model.load_state_dict(torch.load(weights))
return model
model = resnet18(50)
3.6.3. モデル訓練#
モデルが学習データを効率よく学習できるようにするため、損失関数(criterion
)、学習アルゴリズム(optimizer
)、学習率(lr
)、および学習率を調整するスケジューラ(lr_scheduler
)を設定します。
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
次に、訓練データと検証データを読み込み、モデルが入力できる形式に整えます。
train_dataset = torchvision.datasets.ImageFolder('moths/train', transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
valid_dataset = torchvision.datasets.ImageFolder('moths/valid', transform=transform)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=32, shuffle=False)
準備が整ったら、訓練を開始します。訓練プロセスでは、訓練と検証を交互に繰り返します。訓練では、訓練データを使ってモデルのパラメータを更新し、その際の損失(誤差)を記録します。検証では、検証データを使ってモデルの予測性能(正解率)を計算し、その結果を記録します。このサイクルを繰り返すことで、モデルの精度を少しずつ向上させていきます。
model.to(device)
num_epochs = 10
metric_dict = []
for epoch in range(num_epochs):
# training phase
model.train()
running_loss = 0.0
n_correct_train = 0
n_train_samples = 0
for images, labels in train_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
_, predicted_labels = torch.max(outputs.data, 1)
n_correct_train += torch.sum(predicted_labels == labels).item()
n_train_samples += labels.size(0)
running_loss += loss.item() / len(train_loader)
optimizer.zero_grad()
loss.backward()
optimizer.step()
lr_scheduler.step()
# validation phase
model.eval()
n_correct_valid = 0
n_valid_samples = 0
with torch.no_grad():
for images, labels in valid_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted_labels = torch.max(outputs.data, 1)
n_correct_valid += torch.sum(predicted_labels == labels).item()
n_valid_samples += labels.size(0)
metric_dict.append({
'epoch': epoch + 1,
'train_loss': running_loss,
'train_acc': n_correct_train / n_train_samples,
'valid_acc': n_correct_valid / n_valid_samples
})
print(metric_dict[-1])
Show code cell output
{'epoch': 1, 'train_loss': 0.982824449936549, 'train_acc': 0.7283864680495614, 'valid_acc': 0.88}
{'epoch': 2, 'train_loss': 0.4401725761095681, 'train_acc': 0.8701099819017124, 'valid_acc': 0.868}
{'epoch': 3, 'train_loss': 0.3206855146255757, 'train_acc': 0.9053320339690937, 'valid_acc': 0.916}
{'epoch': 4, 'train_loss': 0.1097999247618847, 'train_acc': 0.9685368230544341, 'valid_acc': 0.964}
{'epoch': 5, 'train_loss': 0.05271018845960498, 'train_acc': 0.9859390226924684, 'valid_acc': 0.968}
{'epoch': 6, 'train_loss': 0.034956350437261995, 'train_acc': 0.9908116385911179, 'valid_acc': 0.968}
{'epoch': 7, 'train_loss': 0.027692981462718706, 'train_acc': 0.994431296115829, 'valid_acc': 0.968}
{'epoch': 8, 'train_loss': 0.025436467875002144, 'train_acc': 0.9955450368926633, 'valid_acc': 0.968}
{'epoch': 9, 'train_loss': 0.02546513455712962, 'train_acc': 0.9948489489071418, 'valid_acc': 0.968}
{'epoch': 10, 'train_loss': 0.02320980512537064, 'train_acc': 0.995962689683976, 'valid_acc': 0.964}
訓練後、学習時の損失および検証性能の変化を可視化し、訓練が正しく行われたかどうかを確認します。
metric_dict = pd.DataFrame(metric_dict)
fig, ax = plt.subplots(1, 2)
ax[0].plot(metric_dict['epoch'], metric_dict['train_loss'])
ax[0].set_xlabel('epoch')
ax[0].set_ylabel('loss')
ax[0].set_title('Train')
ax[1].plot(metric_dict['epoch'], metric_dict['valid_acc'])
ax[1].set_ylim(0, 1)
ax[1].set_xlabel('epoch')
ax[1].set_ylabel('accuracy')
ax[1].set_title('Validation')
plt.tight_layout()
fig.show()

可視化の結果から、エポックが増えると訓練データに対する損失が減少し、7 エポック前後で収束し始めたように見えます。また、検証データに対する分類性能(正解率)は、4 エポック目からすでにほぼ最大値に達したように見えます。
次に、上の手順を様々な深層ニューラルネットワークのアーキテクチャ(DenseNet や Inception など)に対して行います。それぞれのアーキテクチャの検証性能を見比べ、このデータセットにとって最適なアーキテクチャを決めます。ただし、本節ではモデル(アーキテクチャ)選択を行わずに、ResNet18 を最適なアーキテクチャとして採用し、次のステップに進みます。
次のステップでは、訓練サブセットと検証サブセットを統合し、最適なアーキテクチャを最初から訓練します。
!mkdir moths/trainvalid
!cp -r moths/train/* moths/trainvalid
!cp -r moths/valid/* moths/trainvalid
最適なモデルを選択する段階で、数エポックの訓練だけでも十分に高い予測性能を獲得できたことがわかったので、ここでは訓練サブセットと検証サブセットを統合したデータに対して 5 エポックだけ訓練させます。
# model
model = resnet18(50)
model.to(device)
# training params
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.00003)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
# training data
train_dataset = torchvision.datasets.ImageFolder('moths/trainvalid', transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)
# training
num_epochs = 5
metric_dict = []
for epoch in range(num_epochs):
model.train()
running_loss = 0.0
n_correct_train = 0
n_train_samples = 0
for images, labels in train_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
_, predicted_labels = torch.max(outputs.data, 1)
n_correct_train += torch.sum(predicted_labels == labels).item()
n_train_samples += labels.size(0)
running_loss += loss.item() / len(train_loader)
optimizer.zero_grad()
loss.backward()
optimizer.step()
lr_scheduler.step()
metric_dict.append({
'epoch': epoch + 1,
'train_loss': running_loss,
'train_acc': n_correct_train / n_train_samples,
})
print(metric_dict[-1])
Show code cell output
{'epoch': 1, 'train_loss': 1.6524445450273897, 'train_acc': 0.7120947127673887}
{'epoch': 2, 'train_loss': 0.458397384088046, 'train_acc': 0.9245257634871519}
{'epoch': 3, 'train_loss': 0.241675884266614, 'train_acc': 0.9593703753531548}
{'epoch': 4, 'train_loss': 0.1488520490105756, 'train_acc': 0.979281582133728}
{'epoch': 5, 'train_loss': 0.13098409236859399, 'train_acc': 0.9815686802098749}
訓練が完了したら、訓練済みモデルの重みをファイルに保存します。
model = model.to('cpu')
torch.save(model.state_dict(), 'moths.pth')
3.6.4. モデル評価#
最適なモデルが得られたら、次にテストデータを用いてモデルを詳細に評価します。例えば、正解率の他に、適合率や再現率や F1 スコアなどを計算し、モデルを総合的に評価します。まず、テストデータをモデルに入力し、その予測結果を取得します。
test_dataset = torchvision.datasets.ImageFolder('moths/test', transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=False)
model = resnet18(50, 'moths.pth')
model.to(device)
model.eval()
pred_labels = []
true_labels = []
with torch.no_grad():
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
_, _labels = torch.max(outputs.data, 1)
#print(_labels)
pred_labels.extend(_labels.cpu().detach().numpy().tolist())
true_labels.extend(labels.cpu().detach().numpy().tolist())
pred_labels = [test_dataset.classes[_] for _ in pred_labels]
true_labels = [test_dataset.classes[_] for _ in true_labels]
次に、予測結果とラベルを比較し、混同行列を作成します。これにより、間違いやすいカテゴリを特定することができます。
cm = sklearn.metrics.confusion_matrix(true_labels, pred_labels)
print(cm)
[[5 0 0 ... 0 0 0]
[0 5 0 ... 0 0 0]
[0 0 5 ... 0 0 0]
...
[0 0 0 ... 5 0 0]
[0 0 0 ... 0 5 0]
[0 0 0 ... 0 0 5]]
それぞれのクラスに対する適合率、再現率、F1 スコアなどは、scikit-learn ライブラリを利用して計算します。
pd.DataFrame(sklearn.metrics.classification_report(true_labels, pred_labels, output_dict=True))
ARCIGERA_FLOWER_MOTH | ATLAS_MOTH | BANDED_TIGER_MOTH | BIRD_CHERRY_ERMINE_MOTH | BLACK_RUSTIC_MOTH | BLAIRS_MOCHA | BLOTCHED_EMERALD_MOTH | BLUE_BORDERED_CARPET_MOTH | CINNABAR_MOTH | CLEARWING_MOTH | ... | SIXSPOT_BURNET_MOTH | SQUARE_SPOT_RUSTIC_MOTH | SUSSEX_EMERALD_MOTH | SWALLOW_TAILED_MOTH | VESTAL_MOTH | WHITE_LINED_SPHINX_MOTH | WHITE_SPOTTED_SABLE_MOTH | accuracy | macro avg | weighted avg | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
precision | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.000000 | ... | 1.0 | 0.833333 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 0.988 | 0.990000 | 0.990000 |
recall | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 0.800000 | ... | 1.0 | 1.000000 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 0.988 | 0.988000 | 0.988000 |
f1-score | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 0.888889 | ... | 1.0 | 0.909091 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 0.988 | 0.987879 | 0.987879 |
support | 5.0 | 5.0 | 5.0 | 5.0 | 5.0 | 5.0 | 5.0 | 5.0 | 5.0 | 5.000000 | ... | 5.0 | 5.000000 | 5.0 | 5.0 | 5.0 | 5.0 | 5.0 | 0.988 | 250.000000 | 250.000000 |
4 rows × 53 columns
3.6.5. 推論#
推論を行う際には、訓練や評価時と同様に、torchvision.models モジュールから ResNet18 のアーキテクチャを読み込み、出力層のクラス数を設定します。その後、load_state_dict
メソッドを使用して訓練済みの重みファイルをモデルにロードします。これらの処理はすでに関数化(resnet18
)されているため、その関数を利用して簡単に実行できます。
labels = [_ for _ in test_dataset.classes]
model = resnet18(50, 'moths.pth')
model.to(device)
model.eval()
Show code cell output
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=50, bias=True)
)
このモデルを使用して推論を行います。まず、ARCIGERA_FLOWER_MOTH に分類されている画像を 1 枚指定し、訓練時と同じ前処理を行います。その後、前処理した画像をモデルに入力し、モデルから予測結果が出力されます。
image_path = 'moths/test/ARCIGERA_FLOWER_MOTH/1.jpg'
image = PIL.Image.open(image_path).convert('RGB')
input_tensor = transform(image).unsqueeze(0).to(device)
with torch.no_grad():
score = model(input_tensor)[0]
print(labels)
print(score)
['ARCIGERA_FLOWER_MOTH', 'ATLAS_MOTH', 'BANDED_TIGER_MOTH', 'BIRD_CHERRY_ERMINE_MOTH', 'BLACK_RUSTIC_MOTH', 'BLAIRS_MOCHA', 'BLOTCHED_EMERALD_MOTH', 'BLUE_BORDERED_CARPET_MOTH', 'CINNABAR_MOTH', 'CLEARWING_MOTH', 'COMET_MOTH', 'DEATHS_HEAD_HAWK_MOTH', 'ELEPHANT_HAWK_MOTH', 'EMPEROR_GUM_MOTH', 'EMPEROR_MOTH', 'EYED_HAWK_MOTH', 'FIERY_CLEARWING_MOTH', 'GARDEN_TIGER_MOTH', 'HERCULES_MOTH', 'HORNET_MOTH', 'HUMMING_BIRD_HAWK_MOTH', 'IO_MOTH', 'JULY_BELLE_MOTH', 'KENTISH_GLORY_MOTH', 'LACE_BORDER_MOTH', 'LEOPARD_MOTH', 'LUNA_MOTH', 'MADAGASCAN_SUNSET_MOTH', 'MAGPIE_MOTH', 'MUSLIN_MOTH', 'OLEANDER_HAWK_MOTH', 'OWL_MOTH', 'PALPITA_VITREALIS_MOTH', 'PEACH_BLOSSOM_MOTH', 'PLUME_MOTH', 'POLYPHEMUS_MOTH', 'PURPLE_BORDERED_GOLD_MOTH', 'RED_NECKED_FOOTMAN_MOTH', 'REGAL_MOTH', 'ROSY_MAPLE_MOTH', 'ROSY_UNDERWING_MOTH', 'RUSTY_DOT_PEARL_MOTH', 'SCHORCHED_WING_MOTH', 'SIXSPOT_BURNET_MOTH', 'SQUARE_SPOT_RUSTIC_MOTH', 'SUSSEX_EMERALD_MOTH', 'SWALLOW_TAILED_MOTH', 'VESTAL_MOTH', 'WHITE_LINED_SPHINX_MOTH', 'WHITE_SPOTTED_SABLE_MOTH']
tensor([ 8.8410, -0.5224, -1.4381, -2.1876, -1.8390, -1.7094, -2.9583, 1.0024,
-0.5546, 0.6770, -2.7842, -0.0259, 0.6690, -1.7822, -0.2781, 2.2137,
-0.0312, -0.1761, 1.3376, -3.7415, 0.7613, 0.0265, -0.1928, 3.0125,
-2.5570, -1.0619, -0.5637, -1.9761, -1.0519, -0.9716, -0.2027, -2.7283,
-2.6698, -0.6537, -2.3282, 0.4160, 0.1297, -2.8068, -2.2458, -1.0727,
-0.7748, -3.1409, -0.1399, -2.0539, -0.9778, -2.0506, -2.6850, -2.0104,
-1.7093, -0.6895], device='cuda:0')
このように、モデルの出力値は実数値として得られます。最も高い数値を持つクラスが予測結果(予測ラベル)となります。しかし、このままでは出力値が理解しにくいため、予測値を合計が 1.0 になるように修正します。これにより、出力値が確率として解釈できるようになります。
prob = torch.softmax(score, axis=0).cpu().detach().numpy()
print(prob)
[9.9151975e-01 8.5074062e-05 3.4052464e-05 1.6092994e-05 2.2805256e-05
2.5960553e-05 7.4459272e-06 3.9087143e-04 8.2378312e-05 2.8229150e-04
8.8620454e-06 1.3978175e-04 2.8004477e-04 2.4136934e-05 1.0862217e-04
1.3124426e-03 1.3904068e-04 1.2028446e-04 5.4651452e-04 3.4023617e-06
3.0713723e-04 1.4730156e-04 1.1829478e-04 2.9172748e-03 1.1121752e-05
4.9604521e-05 8.1634578e-05 1.9881765e-05 5.0099530e-05 5.4289871e-05
1.1712776e-04 9.3715025e-06 9.9354284e-06 7.4609721e-05 1.3981594e-05
2.1745544e-04 1.6330411e-04 8.6633554e-06 1.5182204e-05 4.9070393e-05
6.6098342e-05 6.2033241e-06 1.2472016e-04 1.8394012e-05 5.3953245e-05
1.8456109e-05 9.7856300e-06 1.9211771e-05 2.5963500e-05 7.1987175e-05]
ソフトマックス関数を使用して出力を変換することにより、確率のような値が得られます。この変換によって、各クラスに対する予測確率が計算されます。必要に応じて、Pandas などのライブラリを利用して、予測結果を見やすい形式に整えます。
output = pd.DataFrame({
'class': labels,
'probability': prob
})
output
class | probability | |
---|---|---|
0 | ARCIGERA_FLOWER_MOTH | 0.991520 |
1 | ATLAS_MOTH | 0.000085 |
2 | BANDED_TIGER_MOTH | 0.000034 |
3 | BIRD_CHERRY_ERMINE_MOTH | 0.000016 |
4 | BLACK_RUSTIC_MOTH | 0.000023 |
5 | BLAIRS_MOCHA | 0.000026 |
6 | BLOTCHED_EMERALD_MOTH | 0.000007 |
7 | BLUE_BORDERED_CARPET_MOTH | 0.000391 |
8 | CINNABAR_MOTH | 0.000082 |
9 | CLEARWING_MOTH | 0.000282 |
10 | COMET_MOTH | 0.000009 |
11 | DEATHS_HEAD_HAWK_MOTH | 0.000140 |
12 | ELEPHANT_HAWK_MOTH | 0.000280 |
13 | EMPEROR_GUM_MOTH | 0.000024 |
14 | EMPEROR_MOTH | 0.000109 |
15 | EYED_HAWK_MOTH | 0.001312 |
16 | FIERY_CLEARWING_MOTH | 0.000139 |
17 | GARDEN_TIGER_MOTH | 0.000120 |
18 | HERCULES_MOTH | 0.000547 |
19 | HORNET_MOTH | 0.000003 |
20 | HUMMING_BIRD_HAWK_MOTH | 0.000307 |
21 | IO_MOTH | 0.000147 |
22 | JULY_BELLE_MOTH | 0.000118 |
23 | KENTISH_GLORY_MOTH | 0.002917 |
24 | LACE_BORDER_MOTH | 0.000011 |
25 | LEOPARD_MOTH | 0.000050 |
26 | LUNA_MOTH | 0.000082 |
27 | MADAGASCAN_SUNSET_MOTH | 0.000020 |
28 | MAGPIE_MOTH | 0.000050 |
29 | MUSLIN_MOTH | 0.000054 |
30 | OLEANDER_HAWK_MOTH | 0.000117 |
31 | OWL_MOTH | 0.000009 |
32 | PALPITA_VITREALIS_MOTH | 0.000010 |
33 | PEACH_BLOSSOM_MOTH | 0.000075 |
34 | PLUME_MOTH | 0.000014 |
35 | POLYPHEMUS_MOTH | 0.000217 |
36 | PURPLE_BORDERED_GOLD_MOTH | 0.000163 |
37 | RED_NECKED_FOOTMAN_MOTH | 0.000009 |
38 | REGAL_MOTH | 0.000015 |
39 | ROSY_MAPLE_MOTH | 0.000049 |
40 | ROSY_UNDERWING_MOTH | 0.000066 |
41 | RUSTY_DOT_PEARL_MOTH | 0.000006 |
42 | SCHORCHED_WING_MOTH | 0.000125 |
43 | SIXSPOT_BURNET_MOTH | 0.000018 |
44 | SQUARE_SPOT_RUSTIC_MOTH | 0.000054 |
45 | SUSSEX_EMERALD_MOTH | 0.000018 |
46 | SWALLOW_TAILED_MOTH | 0.000010 |
47 | VESTAL_MOTH | 0.000019 |
48 | WHITE_LINED_SPHINX_MOTH | 0.000026 |
49 | WHITE_SPOTTED_SABLE_MOTH | 0.000072 |
もう一例を見てみましょう。LUNA_MOTH に分類された画像をモデルに入力し、推論を行います。
image_path = 'moths/test/LUNA_MOTH/5.jpg'
image = PIL.Image.open(image_path).convert('RGB')
input_tensor = transform(image).unsqueeze(0).to(device)
with torch.no_grad():
score = model(input_tensor)[0]
pd.DataFrame({
'class': labels,
'probability': torch.softmax(score, axis=0).cpu().detach().numpy()
})
class | probability | |
---|---|---|
0 | ARCIGERA_FLOWER_MOTH | 0.000048 |
1 | ATLAS_MOTH | 0.000100 |
2 | BANDED_TIGER_MOTH | 0.000015 |
3 | BIRD_CHERRY_ERMINE_MOTH | 0.000018 |
4 | BLACK_RUSTIC_MOTH | 0.000010 |
5 | BLAIRS_MOCHA | 0.000006 |
6 | BLOTCHED_EMERALD_MOTH | 0.000054 |
7 | BLUE_BORDERED_CARPET_MOTH | 0.000049 |
8 | CINNABAR_MOTH | 0.000038 |
9 | CLEARWING_MOTH | 0.000063 |
10 | COMET_MOTH | 0.001164 |
11 | DEATHS_HEAD_HAWK_MOTH | 0.000023 |
12 | ELEPHANT_HAWK_MOTH | 0.000089 |
13 | EMPEROR_GUM_MOTH | 0.000098 |
14 | EMPEROR_MOTH | 0.000009 |
15 | EYED_HAWK_MOTH | 0.000011 |
16 | FIERY_CLEARWING_MOTH | 0.000012 |
17 | GARDEN_TIGER_MOTH | 0.000049 |
18 | HERCULES_MOTH | 0.000053 |
19 | HORNET_MOTH | 0.000011 |
20 | HUMMING_BIRD_HAWK_MOTH | 0.000020 |
21 | IO_MOTH | 0.000024 |
22 | JULY_BELLE_MOTH | 0.000035 |
23 | KENTISH_GLORY_MOTH | 0.000023 |
24 | LACE_BORDER_MOTH | 0.000023 |
25 | LEOPARD_MOTH | 0.000034 |
26 | LUNA_MOTH | 0.997339 |
27 | MADAGASCAN_SUNSET_MOTH | 0.000049 |
28 | MAGPIE_MOTH | 0.000016 |
29 | MUSLIN_MOTH | 0.000017 |
30 | OLEANDER_HAWK_MOTH | 0.000040 |
31 | OWL_MOTH | 0.000023 |
32 | PALPITA_VITREALIS_MOTH | 0.000007 |
33 | PEACH_BLOSSOM_MOTH | 0.000035 |
34 | PLUME_MOTH | 0.000046 |
35 | POLYPHEMUS_MOTH | 0.000030 |
36 | PURPLE_BORDERED_GOLD_MOTH | 0.000023 |
37 | RED_NECKED_FOOTMAN_MOTH | 0.000052 |
38 | REGAL_MOTH | 0.000013 |
39 | ROSY_MAPLE_MOTH | 0.000061 |
40 | ROSY_UNDERWING_MOTH | 0.000007 |
41 | RUSTY_DOT_PEARL_MOTH | 0.000002 |
42 | SCHORCHED_WING_MOTH | 0.000009 |
43 | SIXSPOT_BURNET_MOTH | 0.000012 |
44 | SQUARE_SPOT_RUSTIC_MOTH | 0.000006 |
45 | SUSSEX_EMERALD_MOTH | 0.000037 |
46 | SWALLOW_TAILED_MOTH | 0.000076 |
47 | VESTAL_MOTH | 0.000014 |
48 | WHITE_LINED_SPHINX_MOTH | 0.000004 |
49 | WHITE_SPOTTED_SABLE_MOTH | 0.000003 |
3.6.6. 分類根拠の可視化#
畳み込みニューラルネットワーク(CNN)を利用した分類において、畳み込み層で抽出された特徴マップが分類に寄与しています。そのため、最後の畳み込み層で得られた特徴マップと、それぞれの特徴量に対する重みを用いて可視化を行うことで、モデルの分類根拠を明らかにすることができます。
ここでは、いくつかの可視化アルゴリズムを用いて、モデルの判断根拠を可視化します。可視化には Python の grad-cam パッケージが必要です。必要に応じてインストールし、grad-cam のチュートリアルを参考にして、Grad-CAM を計算し、可視化するための関数を定義します。
def viz(image_path):
# load models
labels = [_ for _ in test_dataset.classes]
model = resnet18(50, 'moths.pth')
model.to(device)
model.eval()
# load image
image = PIL.Image.open(image_path).convert('RGB')
input_tensor = transform(image).unsqueeze(0).to(device)
rgb_img = cv2.imread(image_path, 1)[:, :, ::-1]
rgb_img = np.float32(np.array(image)) / 255
# Grad-CAM
with pytorch_grad_cam.GradCAM(model=model, target_layers=[model.layer4]) as cam:
cam.batch_size = 32
grayscale_cam = cam(input_tensor=input_tensor, targets=None,aug_smooth=True, eigen_smooth=True)
grayscale_cam = grayscale_cam[0, :]
cam_image = pytorch_grad_cam.utils.image.show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
prob = torch.softmax(cam.outputs[0], axis=0).cpu().detach().numpy()
gb_model = pytorch_grad_cam.GuidedBackpropReLUModel(model=model, device=device)
gb = gb_model(input_tensor, target_category=None)
cam_mask = np.stack([grayscale_cam, grayscale_cam, grayscale_cam], axis=-1)
cam_gb = pytorch_grad_cam.utils.image.deprocess_image(cam_mask * gb)
gb = pytorch_grad_cam.utils.image.deprocess_image(gb)
# plot
fig, ax = plt.subplots(2, 2)
ax[0, 0].imshow(rgb_img)
ax[0, 0].axis('off')
ax[0, 0].set_title('Original Image', fontsize=16)
ax[0, 1].imshow(cam_image)
ax[0, 1].axis('off')
ax[0, 1].set_title('Grad-CAM', fontsize=16)
ax[1, 0].imshow(gb)
ax[1, 0].axis('off')
ax[1, 0].set_title('Guided Backpropagation', fontsize=16)
ax[1, 1].imshow(cam_gb)
ax[1, 1].axis('off')
ax[1, 1].set_title('Guided Grad-CAM', fontsize=16)
fig.show()
次に、いくつかの画像をこの関数に代入し、モデルの予測結果とその判断根拠を可視化します。これにより、モデルがどの部分に注目して分類を行ったのかが視覚的に示されます。
viz('moths/test/ARCIGERA_FLOWER_MOTH/1.jpg')

viz('moths/test/LUNA_MOTH/5.jpg')
