4.6. 演習:小麦穂検出#

穂数や開花数、結実数などを数えることは、フェノタイピングにおいて欠かせない作業である。しかし、実験圃場全体を対象としたフェノタイピングには高いコストがかかる。そのため、現在では深層学習による画像認識を活用したフェノタイピングが一般的となっている。本節では、小麦の穂を検出し、計数を行う例を示す。

4.6.1. 演習準備#

4.6.1.1. ライブラリ#

本節で利用するライブラリを読み込みます。NumPy、Pnadas、Matplotlib、Pillow(PIL)などのライブライは、モデルの性能や推論結果などの可視化に利用します。scikit-learn(sklearn)、PyTorch(torch)、torchvision、torchmetrics は機械学習関連のライブラリであり、モデルの構築、検証や推論などに利用します。

import os
import numpy as np
import pandas as pd
import PIL.Image
import matplotlib.pyplot as plt

import torch
import torchvision
import torchmetrics

print(f'torch v{torch.__version__}; torchvision v{torchvision.__version__}')
torch v2.5.1+cu121; torchvision v0.20.1+cu121

ライブラリの読み込み時に ImportErrorModuleNotFoundError が発生した場合は、該当するライブラリをインストールしてください。ライブラリのバージョンを揃える必要はありませんが、PyTorch(torch)および torchvision が上記のバージョンと異なる時、実行中に警告メッセージが現れたり、同じ結果にならなかったりする可能性があります。

4.6.1.2. データセット#

本節では、Global WHEAT Dataset[1] を使用します。このデータセットには、さまざまな圃場で撮影された画像が含まれており、それらは 1024×1024 ピクセルの正方形に切り取られています。世界各地で栽培されている小麦の品種を幅広くカバーしており、成熟度も異なる多様な画像が含まれているため、実践的なデータセットとなっています。各画像には、小麦の穂が存在する位置のバウンディングボックスの座標情報が CSV 形式で提供されています。

../_images/gwhd_dataset.jpg

Fig. 4.20 Global WHEAT Dataset に含まれる画像のサンプル。#

オリジナルのデータセットは非常に多くの画像を含んでいるため、本演習では、限られた時間内で訓練やテストを実施できるように、訓練データ 200 枚、検証データ 100 枚、テストデータ 100 枚に調整したサブセットを使用します。また、このサブセットのアノテーションデータは、PyTorch が利用しやすいように CSV から COCO フォーマットに変更してあります。このサブセットは、Jupyter Notebook 上で次のコマンドを実行することでダウンロードできます。

!wget https://dl.biopapyrus.jp/data/gwhd.zip
!unzip gwhd.zip

4.6.1.3. 画像前処理#

物体検出のタスクでは、画像とともに、検出対象の物体を囲むバウンディングボックスの座標とラベル(ここでは小麦の「穂」)をモデルに与え、学習させる必要があります。本節では、画像とバウンディングボックスの座標、およびそのラベルを適切に対応づけるための前処理コードを作成します。この前処理により、画像と COCO フォーマットのアノテーションが対応づけられ、PyTorch に入力できる形式となるため、学習をスムーズに行うことが可能になります。

なお、PyTorch では、アノテーションのない画像を処理する際にエラーが発生するため、小麦の穂が含まれていない画像に対しては特別な処理を施します。具体的には、ダミーのバウンディングボックス([0, 0, 1, 1])を設定し、そのラベルを背景クラス(0)として扱うことで、エラーを回避します。この対応により、対象物が存在しない画像も問題なく処理できるようになります。

class CocoDataset(torchvision.datasets.CocoDetection):
    def __init__(self, root, annFile):
        super(CocoDataset, self).__init__(root, annFile)
    
    def __getitem__(self, idx):
        img, target = super(CocoDataset, self).__getitem__(idx)
        
        boxes = []
        labels = []
        for obj in target:
            bbox = obj['bbox']
            bbox = [bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]]
            boxes.append(bbox)
            labels.append(obj['category_id'])
        if len(boxes) == 0:
            boxes = [[0, 0, 1, 1]]
            labels = [0]
        
        img = torchvision.transforms.functional.to_tensor(img)
        target = {
            'boxes': torch.as_tensor(boxes, dtype=torch.float32),
            'labels': torch.as_tensor(labels, dtype=torch.int64),
        }

        return img, target
    

ここでは、画像分類と同様に、畳み込みニューラルネットワーク(CNN)に入力する画像のサイズを、指定されたサイズに変更する必要があります。一般的な物体検出モデルでは、長方形の画像を入力として扱うことが多いですが、今回のデータセットは正方形の画像で構成されています。そのため、独自のコードを用いて適切なサイズ調整を行うことで、より高い性能が期待できます。しかし、この作業は煩雑であるため、本節では CocoDetection クラスに実装されているデフォルトの機能を利用するにとどめます。

また、モデルを訓練する際には、画像の拡大縮小や平行移動、回転などのデータ拡張を行う必要があります。これに伴い、バウンディングボックスの座標も同様に再計算しなければなりません。しかし、これらの処理を追加するとコードが複雑になり、全体の流れがわかりにくくなるため、本節ではデータ拡張の処理は省略します。

4.6.1.4. 計算デバイス#

計算を行うデバイスを設定します。PyTorch が GPU を認識できる場合は GPU を利用し、認識できない場合は CPU を使用するように設定します。

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

4.6.2. モデル構築#

本節では、物体検出アーキテクチャとしてよく知られている Faster R-CNN を使用します。torchvision.models.detection で提供されているアーキテクチャは、車や人など 90 種類の一般的なオブジェクトを対象としています。これに対して、本節では、小麦の穂という 1 種類のオブジェクトのみを検出を目的としています。そのため、torchvision.models.detection から読み込んだアーキテクチャの出力層のユニット数を、検出対象の種類数に合わせる必要があります。

物体検出アーキテクチャでは、背景を一つのクラスとして扱うため、出力数を修正するとき、検出対象数に 1 を加えた値で修正します。例えば、穂検出の場合は、出力層の数を 2 とします。この修正は、アーキテクチャを呼び出すたびに行う必要があり、手間がかかります。そこで、一連の処理を関数として定義してから利用します。

def fasterrcnn(num_classes, weights=None):
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights='DEFAULT')
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    num_classes = num_classes + 1  # class + background
    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
    if weights is not None:
        model.load_state_dict(torch.load(weights))
    return model

model = fasterrcnn(num_classes=1)
model.to(device)
Hide code cell output
FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): FrozenBatchNorm2d(256, eps=0.0)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(inplace=True)
        )
      )
      (layer2): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(128, eps=0.0)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(128, eps=0.0)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(512, eps=0.0)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): FrozenBatchNorm2d(512, eps=0.0)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(128, eps=0.0)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(128, eps=0.0)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(512, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(128, eps=0.0)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(128, eps=0.0)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(512, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (3): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(128, eps=0.0)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(128, eps=0.0)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(512, eps=0.0)
          (relu): ReLU(inplace=True)
        )
      )
      (layer3): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=0.0)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=0.0)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=0.0)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): FrozenBatchNorm2d(1024, eps=0.0)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=0.0)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=0.0)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=0.0)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=0.0)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (3): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=0.0)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=0.0)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (4): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=0.0)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=0.0)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (5): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=0.0)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=0.0)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=0.0)
          (relu): ReLU(inplace=True)
        )
      )
      (layer4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(512, eps=0.0)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(512, eps=0.0)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(2048, eps=0.0)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): FrozenBatchNorm2d(2048, eps=0.0)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(512, eps=0.0)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(512, eps=0.0)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(2048, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(512, eps=0.0)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(512, eps=0.0)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(2048, eps=0.0)
          (relu): ReLU(inplace=True)
        )
      )
    )
    (fpn): FeaturePyramidNetwork(
      (inner_blocks): ModuleList(
        (0): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (layer_blocks): ModuleList(
        (0-3): 4 x Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (extra_blocks): LastLevelMaxPool()
    )
  )
  (rpn): RegionProposalNetwork(
    (anchor_generator): AnchorGenerator()
    (head): RPNHead(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
        )
      )
      (cls_logits): Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1))
      (bbox_pred): Conv2d(256, 12, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (roi_heads): RoIHeads(
    (box_roi_pool): MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=(7, 7), sampling_ratio=2)
    (box_head): TwoMLPHead(
      (fc6): Linear(in_features=12544, out_features=1024, bias=True)
      (fc7): Linear(in_features=1024, out_features=1024, bias=True)
    )
    (box_predictor): FastRCNNPredictor(
      (cls_score): Linear(in_features=1024, out_features=2, bias=True)
      (bbox_pred): Linear(in_features=1024, out_features=8, bias=True)
    )
  )
)

4.6.3. モデル訓練#

モデルが学習データを効率よく学習できるようにするため、学習アルゴリズム(optimizer)、学習率(lr)、および学習率を調整するスケジューラ(lr_scheduler)を設定します。なお、画像分類では損失関数も合わせて定義していますが、物体検出では分類誤差を計算する損失関数とバウンディングボックスの座標の誤差を計算する損失関数の二種類を定義する必要があります。これらの関数はすでにモデルの中で定義されているため、ここであらためて定義する必要はありません。

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-5)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.01)

次に、訓練データと検証データを読み込み、モデルが入力できる形式に整えます。

train_loader = torch.utils.data.DataLoader(
                    CocoDataset('gwhd', 'gwhd/train.json'),
                    batch_size=4, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))

valid_loader = torch.utils.data.DataLoader(
                    CocoDataset('gwhd', 'gwhd/valid.json'),
                    batch_size=4, shuffle=False, collate_fn=lambda x: tuple(zip(*x)))
Hide code cell output
loading annotations into memory...
Done (t=0.02s)
creating index...
index created!
loading annotations into memory...
Done (t=0.01s)
creating index...
index created!

準備が整ったら、訓練を開始します。訓練プロセスでは、訓練と検証を交互に繰り返します。訓練では、訓練データを使ってモデルのパラメータを更新し、その際の損失(誤差)を記録します。検証では、検証データを使ってモデルの予測性能(mAP)を計算し、その結果を記録します。

num_epochs = 10
metric_dict = []

for epoch in range(num_epochs):
    # training phase
    model.train()
    epoch_loss_dict = {}

    for images, targets in train_loader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        
        batch_loss_dict = model(images, targets)
        batch_tol_loss = 0
        for loss_type, loss_val in batch_loss_dict.items():
            batch_tol_loss += loss_val
            if loss_type in epoch_loss_dict:
                epoch_loss_dict[f'train_{loss_type}'] += loss_val.item()
            else:
                epoch_loss_dict[f'train_{loss_type}'] = loss_val.item()
                
        # update weights
        optimizer.zero_grad()
        batch_tol_loss.backward()
        optimizer.step()
    lr_scheduler.step()


    # validation phase
    model.eval()
    metric = torchmetrics.detection.mean_ap.MeanAveragePrecision()
    with torch.no_grad():
        for images, targets in valid_loader:
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            metric.update(model(images), targets)


    # record training loss
    epoch_loss_dict['train_loss_total'] = sum(epoch_loss_dict.values())
    metric_dict.append({k: v / len(train_loader) for k, v in epoch_loss_dict.items()})
    for k, v in metric.compute().items():
        if k != 'classes':
            metric_dict[-1][k] = v.item()
    metric_dict[-1]['epoch'] = epoch + 1

    print(metric_dict[-1])
Hide code cell output
{'train_loss_classifier': 0.007636570334434509, 'train_loss_box_reg': 0.014173473119735719, 'train_loss_objectness': 0.002357144504785538, 'train_loss_rpn_box_reg': 0.0017590907216072083, 'train_loss_total': 0.025926278680562975, 'map': 0.09538678079843521, 'map_50': 0.26854586601257324, 'map_75': 0.034807998687028885, 'map_small': 0.0, 'map_medium': 0.19908052682876587, 'map_large': 0.16562525928020477, 'mar_1': 0.0043641431257128716, 'mar_10': 0.03734720125794411, 'mar_100': 0.16331760585308075, 'mar_small': 0.0, 'mar_medium': 0.33003395795822144, 'mar_large': 0.3345758318901062, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 1}
{'train_loss_classifier': 0.006604284644126892, 'train_loss_box_reg': 0.01260679006576538, 'train_loss_objectness': 0.0025429564714431765, 'train_loss_rpn_box_reg': 0.001986919343471527, 'train_loss_total': 0.023740950524806976, 'map': 0.14333924651145935, 'map_50': 0.35522961616516113, 'map_75': 0.08604297041893005, 'map_small': 0.0024294499307870865, 'map_medium': 0.28656691312789917, 'map_large': 0.31338533759117126, 'mar_1': 0.005146901123225689, 'mar_10': 0.05058974772691727, 'mar_100': 0.1964186131954193, 'mar_small': 0.012500000186264515, 'mar_medium': 0.38346827030181885, 'mar_large': 0.4668380320072174, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 2}
{'train_loss_classifier': 0.006086929440498352, 'train_loss_box_reg': 0.011673849821090699, 'train_loss_objectness': 0.0021947240829467774, 'train_loss_rpn_box_reg': 0.0019063720107078553, 'train_loss_total': 0.021861875355243685, 'map': 0.16921494901180267, 'map_50': 0.39122071862220764, 'map_75': 0.11149682104587555, 'map_small': 0.007756911683827639, 'map_medium': 0.3375093638896942, 'map_large': 0.37576717138290405, 'mar_1': 0.006111945025622845, 'mar_10': 0.05509328842163086, 'mar_100': 0.21585889160633087, 'mar_small': 0.02410714328289032, 'mar_medium': 0.4261164665222168, 'mar_large': 0.48894602060317993, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 3}
{'train_loss_classifier': 0.006054455637931824, 'train_loss_box_reg': 0.011459314823150634, 'train_loss_objectness': 0.0013985633850097656, 'train_loss_rpn_box_reg': 0.0017762410640716553, 'train_loss_total': 0.02068857491016388, 'map': 0.17258816957473755, 'map_50': 0.3973223865032196, 'map_75': 0.11937520653009415, 'map_small': 0.00966841820627451, 'map_medium': 0.3388715386390686, 'map_large': 0.40536069869995117, 'mar_1': 0.006519407965242863, 'mar_10': 0.057377226650714874, 'mar_100': 0.2192794382572174, 'mar_small': 0.02857142873108387, 'mar_medium': 0.4262470602989197, 'mar_large': 0.5278920531272888, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 4}
{'train_loss_classifier': 0.004160727560520172, 'train_loss_box_reg': 0.008839853405952454, 'train_loss_objectness': 0.0006757990270853043, 'train_loss_rpn_box_reg': 0.0008076342195272446, 'train_loss_total': 0.014484014213085174, 'map': 0.18541501462459564, 'map_50': 0.41510728001594543, 'map_75': 0.13624610006809235, 'map_small': 0.01184830255806446, 'map_medium': 0.37072476744651794, 'map_large': 0.39347076416015625, 'mar_1': 0.006369289942085743, 'mar_10': 0.057420115917921066, 'mar_100': 0.22870469093322754, 'mar_small': 0.03660714253783226, 'mar_medium': 0.46147820353507996, 'mar_large': 0.4672236442565918, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 5}
{'train_loss_classifier': 0.005996693372726441, 'train_loss_box_reg': 0.011288018226623535, 'train_loss_objectness': 0.0009498541057109832, 'train_loss_rpn_box_reg': 0.001507614850997925, 'train_loss_total': 0.019742180556058884, 'map': 0.18947400152683258, 'map_50': 0.415609210729599, 'map_75': 0.14540135860443115, 'map_small': 0.01535609271377325, 'map_medium': 0.3792743980884552, 'map_large': 0.40887030959129333, 'mar_1': 0.006540853530168533, 'mar_10': 0.05859961360692978, 'mar_100': 0.23332618176937103, 'mar_small': 0.03928571566939354, 'mar_medium': 0.46787673234939575, 'mar_large': 0.49074551463127136, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 6}
{'train_loss_classifier': 0.005074151754379272, 'train_loss_box_reg': 0.009185189008712768, 'train_loss_objectness': 0.0009762082993984223, 'train_loss_rpn_box_reg': 0.0014452695846557616, 'train_loss_total': 0.016680818647146226, 'map': 0.19019567966461182, 'map_50': 0.4149739444255829, 'map_75': 0.14536364376544952, 'map_small': 0.015378776006400585, 'map_medium': 0.38042521476745605, 'map_large': 0.4131540060043335, 'mar_1': 0.006573021877557039, 'mar_10': 0.0589856319129467, 'mar_100': 0.23399099707603455, 'mar_small': 0.03839285671710968, 'mar_medium': 0.46711936593055725, 'mar_large': 0.5025706887245178, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 7}
{'train_loss_classifier': 0.005067222118377686, 'train_loss_box_reg': 0.009892233610153199, 'train_loss_objectness': 0.0008133697509765625, 'train_loss_rpn_box_reg': 0.0012730754911899566, 'train_loss_total': 0.017045900970697404, 'map': 0.19052116572856903, 'map_50': 0.4151192009449005, 'map_75': 0.14504939317703247, 'map_small': 0.01841896027326584, 'map_medium': 0.3792853057384491, 'map_large': 0.41893377900123596, 'mar_1': 0.0066695259883999825, 'mar_10': 0.05901779979467392, 'mar_100': 0.2342483401298523, 'mar_small': 0.03928571566939354, 'mar_medium': 0.466100811958313, 'mar_large': 0.5104113221168518, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 8}
{'train_loss_classifier': 0.004147751331329346, 'train_loss_box_reg': 0.009677772521972655, 'train_loss_objectness': 0.000982421189546585, 'train_loss_rpn_box_reg': 0.0015559527277946472, 'train_loss_total': 0.016363897770643236, 'map': 0.19065935909748077, 'map_50': 0.4154583811759949, 'map_75': 0.14555609226226807, 'map_small': 0.015842221677303314, 'map_medium': 0.379228800535202, 'map_large': 0.42176172137260437, 'mar_1': 0.0066695259883999825, 'mar_10': 0.05923225358128548, 'mar_100': 0.23425905406475067, 'mar_small': 0.03750000149011612, 'mar_medium': 0.46542179584503174, 'mar_large': 0.5141388177871704, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 9}
{'train_loss_classifier': 0.005217180252075195, 'train_loss_box_reg': 0.010225982666015624, 'train_loss_objectness': 0.0009424474090337753, 'train_loss_rpn_box_reg': 0.0012347762286663054, 'train_loss_total': 0.0176203865557909, 'map': 0.19093823432922363, 'map_50': 0.415356308221817, 'map_75': 0.14672257006168365, 'map_small': 0.01612910069525242, 'map_medium': 0.37957215309143066, 'map_large': 0.4250395596027374, 'mar_1': 0.006648080423474312, 'mar_10': 0.05936092510819435, 'mar_100': 0.23462362587451935, 'mar_small': 0.03839285671710968, 'mar_medium': 0.4656568169593811, 'mar_large': 0.5170950889587402, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 10}

訓練データに対する損失と検証データに対する予測性能(mAP)を可視化し、訓練過程を評価します。

Hide code cell source
metric_dict = pd.DataFrame(metric_dict)

fig, ax = plt.subplots(1, 2)
ax[0].plot(metric_dict['epoch'], metric_dict['train_loss_total'], label='total')
ax[0].plot(metric_dict['epoch'], metric_dict['train_loss_classifier'], label='classification')
ax[0].plot(metric_dict['epoch'], metric_dict['train_loss_box_reg'], label='bbox')
ax[0].set_xlabel('epoch')
ax[0].set_ylabel('loss')
ax[0].set_title('Train')
ax[0].legend()
ax[1].plot(metric_dict['epoch'], metric_dict['map_50'])
ax[1].set_ylim(0, 1)
ax[1].set_xlabel('epoch')
ax[1].set_ylabel('mAP (50%)')
ax[1].set_title('Validation')
plt.tight_layout()
fig.show()
../_images/238259f352015c6d474c6fa1b0736bf5d3995e01b8c73a7d26684043cd76354a.png

可視化の結果から、エポック数が増えるにつれて訓練データに対する損失が継続的に減少していることが確認できます。10 エポック目においても訓練損失が減少し続ける傾向がまだ見られます。一方、検証データに対する検出性能(mAP 50%)は、5 エポックを過ぎたあたりでほぼ収束しているようです。ただし、値が と低く、十分とはいえません。このため、訓練エポック数をさらに増やして損失や検証性能の推移を詳しく観察するか、必要に応じてデータを増やして再訓練することが考えられます。ただし、本節では、時間の制約があるため、訓練はここで終了します。

次に、この手順を SSD や YOLO など、他の深層ニューラルネットワークのアーキテクチャに適用し、それぞれの検証性能を比較します。この比較により、データセットに最も適したアーキテクチャを選定します。ただし、本節では時間の関係で他のアーキテクチャを構築せず、上で構築した Faster R-CNN を最適なアーキテクチャとして扱い、次のステップに進みます。

次のステップでは、訓練サブセットと検証サブセットを統合し、最適と判断したアーキテクチャを最初から訓練します。

モデル選択のために行われた訓練と検証の結果から、数エポックの訓練だけでも十分に高い予測性能を獲得できたことがわかったので、ここでは訓練サブセットと検証サブセットを統合したデータに対して 5 エポックだけ訓練させます。

# model
model = fasterrcnn(num_classes=1)
model.to(device)

# training parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

# training data
train_loader = torch.utils.data.DataLoader(
                            CocoDataset('gwhd', 'gwhd/trainvalid.json'),
                            batch_size=4, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))

# training
num_epochs = 5
metric_dict = []
for epoch in range(num_epochs):
    model.train()
    epoch_loss_dict = {}
    for images, targets in train_loader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        
        batch_loss_dict = model(images, targets)
        batch_tol_loss = 0
        for loss_type, loss_val in batch_loss_dict.items():
            batch_tol_loss += loss_val
            if loss_type in epoch_loss_dict:
                epoch_loss_dict[f'train_{loss_type}'] += loss_val.item()
            else:
                epoch_loss_dict[f'train_{loss_type}'] = loss_val.item()
        optimizer.zero_grad()
        batch_tol_loss.backward()
        optimizer.step()
    lr_scheduler.step()

    # record training loss
    epoch_loss_dict['train_loss_total'] = sum(epoch_loss_dict.values())
    metric_dict.append({k: v / len(train_loader) for k, v in epoch_loss_dict.items()})
    metric_dict[-1]['epoch'] = epoch + 1
    print(metric_dict[-1])
Hide code cell output
loading annotations into memory...
Done (t=0.03s)
creating index...
index created!
{'train_loss_classifier': 0.002623879313468933, 'train_loss_box_reg': 0.0062036569913228356, 'train_loss_objectness': 0.0007894185682137807, 'train_loss_rpn_box_reg': 0.0009187627832094828, 'train_loss_total': 0.010535717656215032, 'epoch': 1}
{'train_loss_classifier': 0.0031152212619781493, 'train_loss_box_reg': 0.006234282255172729, 'train_loss_objectness': 0.0003873119751612345, 'train_loss_rpn_box_reg': 0.0009469439586003622, 'train_loss_total': 0.010683759450912475, 'epoch': 2}
{'train_loss_classifier': 0.004101479053497314, 'train_loss_box_reg': 0.007303280830383301, 'train_loss_objectness': 0.0008268887797991435, 'train_loss_rpn_box_reg': 0.0012875079115231831, 'train_loss_total': 0.013519156575202942, 'epoch': 3}
{'train_loss_classifier': 0.0019704099496205647, 'train_loss_box_reg': 0.004701109727223715, 'train_loss_objectness': 0.0002920643985271454, 'train_loss_rpn_box_reg': 0.0028775211175282797, 'train_loss_total': 0.009841105192899704, 'epoch': 4}
{'train_loss_classifier': 0.0029249056180318195, 'train_loss_box_reg': 0.006188910007476806, 'train_loss_objectness': 0.00038986667990684507, 'train_loss_rpn_box_reg': 0.0009572247664133708, 'train_loss_total': 0.010460907071828842, 'epoch': 5}

訓練が完了したら、訓練済みモデルの重みをファイルに保存します。

model.to('cpu')
torch.save(model.state_dict(), 'gwhd.pth')

4.6.4. モデル評価#

最適なモデルが得られたら、次にテストデータを用いて詳細な評価を行います。ここでは、物体検出で一般的に用いられる評価指標である mAP を計算し、さらに各画像に対して予測された穂の数と実際の穂の数も記録します。なお、予測スコアが 0.5 を超える領域を予測領域として、閾値を設定しています。

test_loader = torch.utils.data.DataLoader(
                    CocoDataset('gwhd', 'gwhd/test.json'),
                    batch_size=4, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))

model = fasterrcnn(num_classes=1, weights='gwhd.pth')
model.to(device)
model.eval()

n_gt = []
n_predicted = []
metric = torchmetrics.detection.mean_ap.MeanAveragePrecision()
with torch.no_grad():
    for images, targets in test_loader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        outputs = model(images)

        for target, output in zip(targets, outputs):
            n_gt.append(len(target['labels']))
            _n = 0
            for score in output['scores']:
                if score > 0.5:
                    _n += 1
            n_predicted.append(_n)
        
        metric.update(outputs, targets)

metrics = {}
for k, v in metric.compute().items():
    metrics[k] = v.cpu().detach().numpy().tolist()
Hide code cell output
loading annotations into memory...
Done (t=0.01s)
creating index...
index created!
metrics
{'map': 0.2296687662601471,
 'map_50': 0.44304972887039185,
 'map_75': 0.210895836353302,
 'map_small': 0.08700288087129593,
 'map_medium': 0.4606533348560333,
 'map_large': 0.4755059480667114,
 'mar_1': 0.00795561820268631,
 'mar_10': 0.07247403264045715,
 'mar_100': 0.2655571401119232,
 'mar_small': 0.10000000149011612,
 'mar_medium': 0.5302915573120117,
 'mar_large': 0.5479434728622437,
 'map_per_class': -1.0,
 'mar_100_per_class': -1.0,
 'classes': [0, 1]}

ここで出力される指標について、map から始まる指標は mAP を表し、mar から始まる指標は mean average recall(平均再現率)を示します。mar はすべてのクラスに対する再現率を計算し、それらの平均を求めたものです。mAR 1 は、各画像に対してモデルが検出した物体のうち、最も信頼度の高い物体だけを利用して計算した再現率を表します。同様に、mAR 10 および mAR 100 は、モデルが検出した物体のうち、信頼度の高い 10 および 100 物体を利用して計算した再現率を表しています。

次に、予測した穂の数と実際の穂の数を散布図で可視化して評価します。

fig, ax = plt.subplots()
ax.scatter(n_gt, n_predicted, alpha=0.5)

mn = min(min(n_gt), min(n_predicted))
mx = max(max(n_gt), max(n_predicted))
ax.plot([mn, mx], [mn, mx], '--', label='y = x')

ax.set_xlabel('number of spikes (groundtruth)')
ax.set_ylabel('number of predicted spikes')
ax.set_aspect('equal')
ax.legend()
fig.show()
../_images/a99a700b9760be5591298abd456a6826bac500874aa0167cb84eee4c40d46adc.png

可視化の結果、予測数と実際の数が完全に一致する画像はほとんどありませんでしたが、ほとんどの画像においてほぼ正確に穂を検出できたことが確認できました。

4.6.5. 推論#

推論時にも、訓練時と同じように torchvision モジュールからアーキテクチャを呼び出し、出力層のクラス数を設定します。その後、load_state_dict メソッドを使って、訓練済みの重みファイルをモデルにロードします。これらの操作はすべて fasterrcnn 関数で定義されているので、その関数を利用します。

model = fasterrcnn(num_classes=1, weights='gwhd.pth')
model.to(device)
model.eval()
Hide code cell output
FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): FrozenBatchNorm2d(256, eps=0.0)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(inplace=True)
        )
      )
      (layer2): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(128, eps=0.0)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(128, eps=0.0)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(512, eps=0.0)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): FrozenBatchNorm2d(512, eps=0.0)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(128, eps=0.0)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(128, eps=0.0)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(512, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(128, eps=0.0)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(128, eps=0.0)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(512, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (3): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(128, eps=0.0)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(128, eps=0.0)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(512, eps=0.0)
          (relu): ReLU(inplace=True)
        )
      )
      (layer3): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=0.0)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=0.0)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=0.0)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): FrozenBatchNorm2d(1024, eps=0.0)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=0.0)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=0.0)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=0.0)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=0.0)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (3): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=0.0)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=0.0)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (4): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=0.0)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=0.0)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (5): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=0.0)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=0.0)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=0.0)
          (relu): ReLU(inplace=True)
        )
      )
      (layer4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(512, eps=0.0)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(512, eps=0.0)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(2048, eps=0.0)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): FrozenBatchNorm2d(2048, eps=0.0)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(512, eps=0.0)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(512, eps=0.0)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(2048, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(512, eps=0.0)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(512, eps=0.0)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(2048, eps=0.0)
          (relu): ReLU(inplace=True)
        )
      )
    )
    (fpn): FeaturePyramidNetwork(
      (inner_blocks): ModuleList(
        (0): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (layer_blocks): ModuleList(
        (0-3): 4 x Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (extra_blocks): LastLevelMaxPool()
    )
  )
  (rpn): RegionProposalNetwork(
    (anchor_generator): AnchorGenerator()
    (head): RPNHead(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
        )
      )
      (cls_logits): Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1))
      (bbox_pred): Conv2d(256, 12, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (roi_heads): RoIHeads(
    (box_roi_pool): MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=(7, 7), sampling_ratio=2)
    (box_head): TwoMLPHead(
      (fc6): Linear(in_features=12544, out_features=1024, bias=True)
      (fc7): Linear(in_features=1024, out_features=1024, bias=True)
    )
    (box_predictor): FastRCNNPredictor(
      (cls_score): Linear(in_features=1024, out_features=2, bias=True)
      (bbox_pred): Linear(in_features=1024, out_features=8, bias=True)
    )
  )
)

このモデルを利用して推論を行います。まず、1 枚の画像を指定し、PIL モジュールを用いて画像を開き、テンソル形式に変換した後、モデルに入力します。モデルは予測結果としてバウンディングボックスの座標(bboxes)、分類ラベル(labels)、および信頼スコア(scores)を出力します。ただし、信頼スコアが 0.5 未満のバウンディングボックスは採用せず、信頼スコアが高い結果のみを選択して利用します。

threshold = 0.5
image_path = 'gwhd/images/fda86ae9a.jpg'

image = PIL.Image.open(image_path).convert('RGB')
input_tensor = torchvision.transforms.functional.to_tensor(image).unsqueeze(0).to(device)
    
with torch.no_grad():
    predictions = model(input_tensor)[0]
    
bboxes = predictions['boxes'][predictions['scores'] > threshold]
labels = predictions['labels'][predictions['scores'] > threshold]
scores = predictions['scores'][predictions['scores'] > threshold]

検出されたオブジェクトのバウンディングボックスを入力画像に描画します。その後、PIL および matplotlib ライブラリを使用して、画像とその検出結果を可視化します。

draw = PIL.ImageDraw.Draw(image)
for bbox, label, score in zip(bboxes, labels, scores):
    x1, y1, x2, y2 = bbox
    draw.rectangle(((x1, y1), (x2, y2)), outline="blue", width=3)
    draw.text((x1, y1 - 10), f"{label.item()} ({score:.2f})", fill="blue")
    
fig = plt.figure()
ax = fig.add_subplot()
ax.imshow(image)
ax.axis('off')
fig.show()
../_images/a3d21a955ec93854ad6709d2bb309a44d0b0ec2bd54c75309cb27d75f0f688ec.png

次にこのモデルにもう一枚の画像を入力し、その推論結果を見てみましょう。

threshold = 0.5
image_path = 'gwhd/images/cb34f7509.jpg'

image = PIL.Image.open(image_path).convert('RGB')
input_tensor = torchvision.transforms.functional.to_tensor(image).unsqueeze(0).to(device)
    
with torch.no_grad():
    predictions = model(input_tensor)[0]
    
bboxes = predictions['boxes'][predictions['scores'] > threshold]
labels = predictions['labels'][predictions['scores'] > threshold]
scores = predictions['scores'][predictions['scores'] > threshold]

draw = PIL.ImageDraw.Draw(image)
for bbox, label, score in zip(bboxes, labels, scores):
    x1, y1, x2, y2 = bbox
    draw.rectangle(((x1, y1), (x2, y2)), outline="blue", width=3)
    draw.text((x1, y1 - 10), f"{label.item()} ({score:.2f})", fill="blue")
    
fig = plt.figure()
ax = fig.add_subplot()
ax.imshow(image)
ax.axis('off')
fig.show()
../_images/807395021a782a3e5cf6ad84180065091cedc8df4e08d9d7e93185c8556a6de4.png
threshold = 0.5
image_path = 'gwhd/images/d3b3b5628.jpg'

image = PIL.Image.open(image_path).convert('RGB')
input_tensor = torchvision.transforms.functional.to_tensor(image).unsqueeze(0).to(device)
    
with torch.no_grad():
    predictions = model(input_tensor)[0]
    
bboxes = predictions['boxes'][predictions['scores'] > threshold]
labels = predictions['labels'][predictions['scores'] > threshold]
scores = predictions['scores'][predictions['scores'] > threshold]

draw = PIL.ImageDraw.Draw(image)
for bbox, label, score in zip(bboxes, labels, scores):
    x1, y1, x2, y2 = bbox
    draw.rectangle(((x1, y1), (x2, y2)), outline="blue", width=3)
    draw.text((x1, y1 - 10), f"{label.item()} ({score:.2f})", fill="blue")
    
fig = plt.figure()
ax = fig.add_subplot()
ax.imshow(image)
ax.axis('off')
fig.show()
../_images/a49a25e70b5aa9b1c893336ad2417ddefb3adecd2edc98086c786b5f1a24fdde.png