4.4. 演習:内視鏡画像ポリープ検出#

医療分野では、内視鏡画像や X 線画像を用いた病変部位の検出に関する研究が活発に進められています。最近では、これらの技術が内視鏡システムに組み込まれ、医師の診断を補助する製品として実用化されています。また、このような内視鏡画像からポリープを検出するシステムも、PyTorch のようなライブラリを活用することで、個人や小規模なチームでも開発することが可能です。本節では、内視鏡画像からポリープを検出するモデルを構築する方法について学びます。

4.4.1. 演習準備#

4.4.1.1. ライブラリ#

本節で利用するライブラリを読み込みます。NumPy、Pnadas、Matplotlib、Pillow(PIL)などのライブライは、モデルの性能や推論結果などの可視化に利用します。scikit-learn(sklearn)、PyTorch(torch)、torchvision、torchmetrics は機械学習関連のライブラリであり、モデルの構築、検証や推論などに利用します。

import os
import numpy as np
import pandas as pd
import PIL.Image
import matplotlib.pyplot as plt

import torch
import torchvision
import torchmetrics

print(f'torch v{torch.__version__}; torchvision v{torchvision.__version__}')
torch v2.5.1+cu121; torchvision v0.20.1+cu121

ライブラリの読み込み時に ImportErrorModuleNotFoundError が発生した場合は、該当するライブラリをインストールしてください。ライブラリのバージョンを揃える必要はありませんが、PyTorch(torch)および torchvision が上記のバージョンと異なる時、実行中に警告メッセージが現れたり、同じ結果にならなかったりする可能性があります。

4.4.1.2. データセット#

本節では、Kvasir データセット[1]を使用します。このデータセットは内視鏡画像を集めた医療用データセットで、Simula Research Laboratory にて公開されています。Kvasir データセットは研究および教育目的に限り利用可能で、それ以外の用途での使用は許可されていません[2]。データセットを扱う際は、利用規約を必ず遵守してください。

オリジナルの Kvasir データセットでは、セグメンテーション用の画像にはポリープを含むもののみが収録されています。しかし、実際の医療現場では、ポリープが存在しない健常者の画像も含まれることが一般的です。そこで、本演習では、オリジナルの Kvasir セグメンテーション用データから一部を抽出し、そこに Kvasir 分類用データから健常者の内視鏡画像を追加して、新しいデータセットを作成しました。

新しいデータセットは、訓練データ 120 枚、検証データ 30 枚、テストデータ 30 枚で構成されています。訓練データは 100 枚の画像がポリープを持つ画像であり、残りの 20 枚は健全な画像です。また、検証データとテストデータはそれぞれ 20 枚のポリープ画像と 10 枚の健全画像が含まれています。

../_images/kvasir_detection_dataset.jpg

Fig. 4.18 Kvasir データセットに含まれる各カテゴリのサンプル画像。#

Jupyter Notebook 上では、次のコマンドを実行することでダウンロードできます。

!wget https://dl.biopapyrus.jp/data/kvasirdet.zip
!unzip kvasirdet.zip

4.4.1.3. 画像前処理#

物体検出の問題では、画像と一緒に、検出対象の物体を囲むバウンディングボックスの座標とラベルをモデルに与え、学習させる必要があります。バウンディングボックスの座標とラベルは、一般的に COCO フォーマット(.json)や Pascal VOC フォーマット(.xml)などで保存されます。しかし、これらのフォーマットのままでは PyTorch で直接扱えないため、PyTorch が利用できる形式に変換する必要があります。

さらに、PyTorch では、画像にアノテーションがない場合にエラーが発生するため、ポリープが含まれない健常者の画像については特別な処理を行っています。具体的には、ダミーのバウンディングボックス([0, 0, 1, 1])を設定し、そのラベルを背景クラス(0)として扱うようにしています。この対応により、健常者の画像もエラーなく処理できるようになります。

class CocoDataset(torchvision.datasets.CocoDetection):
    def __init__(self, root, annFile):
        super(CocoDataset, self).__init__(root, annFile)
    
    def __getitem__(self, idx):
        img, target = super(CocoDataset, self).__getitem__(idx)
        
        boxes = []
        labels = []
        for obj in target:
            bbox = obj['bbox']
            bbox = [bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]]
            boxes.append(bbox)
            labels.append(obj['category_id'])
        if len(boxes) == 0:
            boxes = [[0, 0, 1, 1]]
            labels = [0]
        
        img = torchvision.transforms.functional.to_tensor(img)
        target = {
            'boxes': torch.as_tensor(boxes, dtype=torch.float32),
            'labels': torch.as_tensor(labels, dtype=torch.int64),
        }

        return img, target
    

なお、画像分類と同様に、畳み込みニューラルネットワークでは入力する画像のサイズを指定されたサイズに変更する必要があります。ただし、このサイズ変更は CocoDetection 内で行われるため、ここであらためて設定する必要はありません。

また、モデルを訓練する際には、画像の拡大縮小や平行移動、回転などのデータ拡張を行い、それに伴ってバウンディングボックスの座標も同じように再計算する必要があります。しかし、これらの処理を加えるとコードが複雑化し、全体の流れがわかりにくくなるため、本節ではデータ拡張の処理は省略しています。

4.4.1.4. 計算デバイス#

計算を行うデバイスを設定します。PyTorch が GPU を認識できる場合は GPU を利用し、認識できない場合は CPU を使用するように設定します。

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

4.4.2. モデル構築#

本節では、物体検出アーキテクチャとしてよく知られている Faster R-CNN を使用します。torchvision.models.detection で提供されているアーキテクチャは、車や人など 90 種類の一般的なオブジェクトを対象としています。これに対して、本節では、ポリープという 1 種類のオブジェクトのみを検出を目的としています。そのため、torchvision.models.detection から読み込んだアーキテクチャの出力層のユニット数を、検出対象の種類数に合わせる必要があります。

物体検出アーキテクチャでは、背景を一つのクラスとして扱うため、出力数を修正するとき、検出対象数に 1 を加えた値で修正します。例えば、ポリープ検出の場合は、出力層の数を 2 とします。この修正は、アーキテクチャを呼び出すたびに行う必要があり、手間がかかります。そこで、一連の処理を関数として定義してから利用します。

def fasterrcnn(num_classes, weights=None):
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights='DEFAULT')
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    num_classes = num_classes + 1  # class + background
    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
    if weights is not None:
        model.load_state_dict(torch.load(weights))
    return model

model = fasterrcnn(num_classes=1)
model.to(device)
Hide code cell output
FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): FrozenBatchNorm2d(256, eps=0.0)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(inplace=True)
        )
      )
      (layer2): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(128, eps=0.0)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(128, eps=0.0)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(512, eps=0.0)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): FrozenBatchNorm2d(512, eps=0.0)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(128, eps=0.0)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(128, eps=0.0)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(512, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(128, eps=0.0)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(128, eps=0.0)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(512, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (3): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(128, eps=0.0)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(128, eps=0.0)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(512, eps=0.0)
          (relu): ReLU(inplace=True)
        )
      )
      (layer3): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=0.0)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=0.0)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=0.0)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): FrozenBatchNorm2d(1024, eps=0.0)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=0.0)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=0.0)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=0.0)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=0.0)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (3): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=0.0)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=0.0)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (4): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=0.0)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=0.0)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (5): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=0.0)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=0.0)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=0.0)
          (relu): ReLU(inplace=True)
        )
      )
      (layer4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(512, eps=0.0)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(512, eps=0.0)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(2048, eps=0.0)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): FrozenBatchNorm2d(2048, eps=0.0)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(512, eps=0.0)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(512, eps=0.0)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(2048, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(512, eps=0.0)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(512, eps=0.0)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(2048, eps=0.0)
          (relu): ReLU(inplace=True)
        )
      )
    )
    (fpn): FeaturePyramidNetwork(
      (inner_blocks): ModuleList(
        (0): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (layer_blocks): ModuleList(
        (0-3): 4 x Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (extra_blocks): LastLevelMaxPool()
    )
  )
  (rpn): RegionProposalNetwork(
    (anchor_generator): AnchorGenerator()
    (head): RPNHead(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
        )
      )
      (cls_logits): Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1))
      (bbox_pred): Conv2d(256, 12, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (roi_heads): RoIHeads(
    (box_roi_pool): MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=(7, 7), sampling_ratio=2)
    (box_head): TwoMLPHead(
      (fc6): Linear(in_features=12544, out_features=1024, bias=True)
      (fc7): Linear(in_features=1024, out_features=1024, bias=True)
    )
    (box_predictor): FastRCNNPredictor(
      (cls_score): Linear(in_features=1024, out_features=2, bias=True)
      (bbox_pred): Linear(in_features=1024, out_features=8, bias=True)
    )
  )
)

4.4.3. モデル訓練#

モデルが学習データを効率よく学習できるようにするため、学習アルゴリズム(optimizer)、学習率(lr)、および学習率を調整するスケジューラ(lr_scheduler)を設定します。なお、画像分類では損失関数も合わせて定義していますが、物体検出では分類誤差を計算する損失関数とバウンディングボックスの座標の誤差を計算する損失関数の二種類を定義する必要があります。これらの関数はすでにモデルの中で定義されているため、ここであらためて定義する必要はありません。

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

次に、訓練データと検証データを読み込み、モデルが入力できる形式に整えます。

train_loader = torch.utils.data.DataLoader(
                    CocoDataset('kvasirdet/train', 'kvasirdet/train/bbox.json'),
                    batch_size=4, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))
valid_loader = torch.utils.data.DataLoader(
                    CocoDataset('kvasirdet/valid', 'kvasirdet/valid/bbox.json'),
                    batch_size=4, shuffle=False, collate_fn=lambda x: tuple(zip(*x)))
Hide code cell output
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!

準備が整ったら、訓練を開始します。訓練プロセスでは、訓練と検証を交互に繰り返します。訓練では、訓練データを使ってモデルのパラメータを更新し、その際の損失(誤差)を記録します。検証では、検証データを使ってモデルの予測性能(mAP)を計算し、その結果を記録します。

num_epochs = 10
metric_dict = []

for epoch in range(num_epochs):
    # training phase
    model.train()
    epoch_loss_dict = {}

    for images, targets in train_loader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        
        batch_loss_dict = model(images, targets)
        batch_tol_loss = 0
        for loss_type, loss_val in batch_loss_dict.items():
            batch_tol_loss += loss_val
            if loss_type in epoch_loss_dict:
                epoch_loss_dict[f'train_{loss_type}'] += loss_val.item()
            else:
                epoch_loss_dict[f'train_{loss_type}'] = loss_val.item()
                
        # update weights
        optimizer.zero_grad()
        batch_tol_loss.backward()
        optimizer.step()
    lr_scheduler.step()


    # validation phase
    model.eval()
    metric = torchmetrics.detection.mean_ap.MeanAveragePrecision()
    with torch.no_grad():
        for images, targets in valid_loader:
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            metric.update(model(images), targets)


    # record training loss
    epoch_loss_dict['train_loss_total'] = sum(epoch_loss_dict.values())
    metric_dict.append({k: v / len(train_loader) for k, v in epoch_loss_dict.items()})
    for k, v in metric.compute().items():
        if k != 'classes':
            metric_dict[-1][k] = v.item()
    metric_dict[-1]['epoch'] = epoch + 1

    print(metric_dict[-1])
Hide code cell output
{'train_loss_classifier': 0.0038107074797153473, 'train_loss_box_reg': 0.0034361253182093304, 'train_loss_objectness': 0.00041085205351312954, 'train_loss_rpn_box_reg': 0.00190495103597641, 'train_loss_total': 0.009562635887414217, 'map': 0.10639167577028275, 'map_50': 0.24018487334251404, 'map_75': 0.10837788134813309, 'map_small': 0.0, 'map_medium': -1.0, 'map_large': 0.2127833515405655, 'mar_1': 0.13500000536441803, 'mar_10': 0.26249998807907104, 'mar_100': 0.26249998807907104, 'mar_small': 0.0, 'mar_medium': -1.0, 'mar_large': 0.5249999761581421, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 1}
{'train_loss_classifier': 0.003062793364127477, 'train_loss_box_reg': 0.003503084679444631, 'train_loss_objectness': 0.0004165112661818663, 'train_loss_rpn_box_reg': 0.0003838776610791683, 'train_loss_total': 0.007366266970833143, 'map': 0.129903644323349, 'map_50': 0.2786444127559662, 'map_75': 0.09775569289922714, 'map_small': 0.0, 'map_medium': -1.0, 'map_large': 0.2598143517971039, 'mar_1': 0.11999999731779099, 'mar_10': 0.25999999046325684, 'mar_100': 0.2775000035762787, 'mar_small': 0.0, 'mar_medium': -1.0, 'mar_large': 0.5550000071525574, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 2}
{'train_loss_classifier': 0.002542121708393097, 'train_loss_box_reg': 0.003187131881713867, 'train_loss_objectness': 0.00017864097220202286, 'train_loss_rpn_box_reg': 0.00021028382082780202, 'train_loss_total': 0.006118178383136789, 'map': 0.14222273230552673, 'map_50': 0.3193133771419525, 'map_75': 0.07043728232383728, 'map_small': 0.0, 'map_medium': -1.0, 'map_large': 0.2845240831375122, 'mar_1': 0.1550000011920929, 'mar_10': 0.26499998569488525, 'mar_100': 0.26499998569488525, 'mar_small': 0.0, 'mar_medium': -1.0, 'mar_large': 0.5299999713897705, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 3}
{'train_loss_classifier': 0.0016972508281469345, 'train_loss_box_reg': 0.0020544350147247313, 'train_loss_objectness': 0.00025559401450057825, 'train_loss_rpn_box_reg': 0.0004213443025946617, 'train_loss_total': 0.004428624159966906, 'map': 0.19554124772548676, 'map_50': 0.39081814885139465, 'map_75': 0.19486503303050995, 'map_small': 0.0, 'map_medium': -1.0, 'map_large': 0.3910824954509735, 'mar_1': 0.23999999463558197, 'mar_10': 0.2849999964237213, 'mar_100': 0.2849999964237213, 'mar_small': 0.0, 'mar_medium': -1.0, 'mar_large': 0.5699999928474426, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 4}
{'train_loss_classifier': 0.0009562353913982709, 'train_loss_box_reg': 0.0011748326321442922, 'train_loss_objectness': 0.00035701123997569086, 'train_loss_rpn_box_reg': 0.00015609899225334326, 'train_loss_total': 0.002644178255771597, 'map': 0.19693368673324585, 'map_50': 0.3860412836074829, 'map_75': 0.17650237679481506, 'map_small': 0.0, 'map_medium': -1.0, 'map_large': 0.3938673734664917, 'mar_1': 0.26249998807907104, 'mar_10': 0.2775000035762787, 'mar_100': 0.2775000035762787, 'mar_small': 0.0, 'mar_medium': -1.0, 'mar_large': 0.5550000071525574, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 5}
{'train_loss_classifier': 0.0007708690439661344, 'train_loss_box_reg': 0.0008639856552084287, 'train_loss_objectness': 5.642389490579565e-05, 'train_loss_rpn_box_reg': 0.00028497669845819475, 'train_loss_total': 0.0019762552925385534, 'map': 0.19178402423858643, 'map_50': 0.377999484539032, 'map_75': 0.18770338594913483, 'map_small': 0.0, 'map_medium': -1.0, 'map_large': 0.38374125957489014, 'mar_1': 0.23499999940395355, 'mar_10': 0.2750000059604645, 'mar_100': 0.2750000059604645, 'mar_small': 0.0, 'mar_medium': -1.0, 'mar_large': 0.550000011920929, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 6}
{'train_loss_classifier': 0.0005954490974545479, 'train_loss_box_reg': 0.0007406117394566536, 'train_loss_objectness': 0.00020855683833360673, 'train_loss_rpn_box_reg': 0.0003224297116200129, 'train_loss_total': 0.0018670473868648211, 'map': 0.19659025967121124, 'map_50': 0.37948426604270935, 'map_75': 0.20666946470737457, 'map_small': 0.0, 'map_medium': -1.0, 'map_large': 0.3931805193424225, 'mar_1': 0.2224999964237213, 'mar_10': 0.29249998927116394, 'mar_100': 0.29249998927116394, 'mar_small': 0.0, 'mar_medium': -1.0, 'mar_large': 0.5849999785423279, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 7}
{'train_loss_classifier': 0.0004533529902497927, 'train_loss_box_reg': 0.0006452140708764394, 'train_loss_objectness': 1.1965065884093443e-05, 'train_loss_rpn_box_reg': 7.20159305880467e-05, 'train_loss_total': 0.0011825480575983723, 'map': 0.19579236209392548, 'map_50': 0.38376930356025696, 'map_75': 0.20205460488796234, 'map_small': 0.0, 'map_medium': -1.0, 'map_large': 0.39195314049720764, 'mar_1': 0.23250000178813934, 'mar_10': 0.2800000011920929, 'mar_100': 0.2800000011920929, 'mar_small': 0.0, 'mar_medium': -1.0, 'mar_large': 0.5600000023841858, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 8}
{'train_loss_classifier': 0.0005828270067771275, 'train_loss_box_reg': 0.0007636663814385732, 'train_loss_objectness': 4.412510897964239e-05, 'train_loss_rpn_box_reg': 0.0003515062853693962, 'train_loss_total': 0.0017421247825647394, 'map': 0.1982128918170929, 'map_50': 0.3865389823913574, 'map_75': 0.2007303684949875, 'map_small': 0.0, 'map_medium': -1.0, 'map_large': 0.3966187536716461, 'mar_1': 0.24250000715255737, 'mar_10': 0.2824999988079071, 'mar_100': 0.2824999988079071, 'mar_small': 0.0, 'mar_medium': -1.0, 'mar_large': 0.5649999976158142, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 9}
{'train_loss_classifier': 0.0005450807511806488, 'train_loss_box_reg': 0.0007525745779275895, 'train_loss_objectness': 0.00015938753883043926, 'train_loss_rpn_box_reg': 0.00011849292398740848, 'train_loss_total': 0.001575535791926086, 'map': 0.20113347470760345, 'map_50': 0.3854667842388153, 'map_75': 0.2113003432750702, 'map_small': 0.0, 'map_medium': -1.0, 'map_large': 0.4024599492549896, 'mar_1': 0.24250000715255737, 'mar_10': 0.28999999165534973, 'mar_100': 0.28999999165534973, 'mar_small': 0.0, 'mar_medium': -1.0, 'mar_large': 0.5799999833106995, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 10}

訓練データに対する損失と検証データに対する予測性能(mAP)を可視化し、訓練過程を評価します。

Hide code cell source
metric_dict = pd.DataFrame(metric_dict)

fig, ax = plt.subplots(1, 2)
ax[0].plot(metric_dict['epoch'], metric_dict['train_loss_total'], label='total')
ax[0].plot(metric_dict['epoch'], metric_dict['train_loss_classifier'], label='classification')
ax[0].plot(metric_dict['epoch'], metric_dict['train_loss_box_reg'], label='bbox')
ax[0].set_xlabel('epoch')
ax[0].set_ylabel('loss')
ax[0].set_title('Train')
ax[0].legend()
ax[1].plot(metric_dict['epoch'], metric_dict['map_50'])
ax[1].set_ylim(0, 1)
ax[1].set_xlabel('epoch')
ax[1].set_ylabel('mAP (50%)')
ax[1].set_title('Validation')
plt.tight_layout()
fig.show()
../_images/49b0f80de89ff632648ed52b02c7269bdd153bac5408bbae05734a6f568a54f3.png

可視化の結果から、エポック数が増えるにつれて訓練データに対する損失が継続的に減少していることが確認できます。10 エポック目においても訓練損失が減少し続ける傾向がまだ見られます。一方、検証データに対する検出性能(mAP 50%)は、5 エポックを過ぎたあたりでほぼ収束しているようです。ただし、値が と低く、十分とはいえません。このため、訓練エポック数をさらに増やして損失や検証性能の推移を詳しく観察するか、必要に応じてデータを増やして再訓練することが考えられます。ただし、本節では、時間の制約があるため、訓練はここで終了します。

次に、この手順を SSD や YOLO など、他の深層ニューラルネットワークのアーキテクチャに適用し、それぞれの検証性能を比較します。この比較により、データセットに最も適したアーキテクチャを選定します。ただし、本節では時間の関係で他のアーキテクチャを構築せず、上で構築した Faster R-CNN を最適なアーキテクチャとして扱い、次のステップに進みます。

次のステップでは、訓練サブセットと検証サブセットを統合し、最適と判断したアーキテクチャを最初から訓練します。その準備として、まず訓練サブセットと検証サブセットを結合します。

!mkdir kvasirdet/trainvalid/images
!cp kvasirdet/train/images/* kvasirdet/trainvalid/images
!cp kvasirdet/valid/images/* kvasirdet/trainvalid/images

モデル選択のために行われた訓練と検証の結果から、数エポックの訓練だけでも十分に高い予測性能を獲得できたことがわかったので、ここでは訓練サブセットと検証サブセットを統合したデータに対して 5 エポックだけ訓練させます。

# model
model = fasterrcnn(num_classes=1)
model.to(device)

# training parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

# training data
train_loader = torch.utils.data.DataLoader(
                            CocoDataset('kvasirdet/trainvalid', 'kvasirdet/trainvalid/bbox.json'),
                            batch_size=4, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))

# training
num_epochs = 5
metric_dict = []
for epoch in range(num_epochs):
    model.train()
    epoch_loss_dict = {}
    for images, targets in train_loader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        
        batch_loss_dict = model(images, targets)
        batch_tol_loss = 0
        for loss_type, loss_val in batch_loss_dict.items():
            batch_tol_loss += loss_val
            if loss_type in epoch_loss_dict:
                epoch_loss_dict[f'train_{loss_type}'] += loss_val.item()
            else:
                epoch_loss_dict[f'train_{loss_type}'] = loss_val.item()
        optimizer.zero_grad()
        batch_tol_loss.backward()
        optimizer.step()
    lr_scheduler.step()

    # record training loss
    epoch_loss_dict['train_loss_total'] = sum(epoch_loss_dict.values())
    metric_dict.append({k: v / len(train_loader) for k, v in epoch_loss_dict.items()})
    metric_dict[-1]['epoch'] = epoch + 1
    print(metric_dict[-1])
Hide code cell output
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
{'train_loss_classifier': 0.0007069695642904231, 'train_loss_box_reg': 0.0009978428286941427, 'train_loss_objectness': 9.190286264607781e-05, 'train_loss_rpn_box_reg': 0.0014187528898841457, 'train_loss_total': 0.0032154681455147894, 'epoch': 1}
{'train_loss_classifier': 0.0007008846947237065, 'train_loss_box_reg': 0.001675179522288473, 'train_loss_objectness': 2.304443018817294e-06, 'train_loss_rpn_box_reg': 1.5938838355635343e-05, 'train_loss_total': 0.0023943074983866323, 'epoch': 2}
{'train_loss_classifier': 0.00011188924116523642, 'train_loss_box_reg': 0.0, 'train_loss_objectness': 7.444352377206087e-05, 'train_loss_rpn_box_reg': 0.007048018668827258, 'train_loss_total': 0.007234351433764555, 'epoch': 3}
{'train_loss_classifier': 0.0003254322188073083, 'train_loss_box_reg': 0.0002108464556697168, 'train_loss_objectness': 2.2420080648244995e-05, 'train_loss_rpn_box_reg': 0.0002923292401981981, 'train_loss_total': 0.0008510279953234681, 'epoch': 4}
{'train_loss_classifier': 0.0003888232045267758, 'train_loss_box_reg': 0.0006029748504883365, 'train_loss_objectness': 8.524450669555288e-05, 'train_loss_rpn_box_reg': 0.00022346629320006621, 'train_loss_total': 0.0013005088549107313, 'epoch': 5}

訓練が完了したら、訓練済みモデルの重みをファイルに保存します。

model.to('cpu')
torch.save(model.state_dict(), 'kvasirdet.pth')

4.4.4. モデル評価#

最適なモデルが得られたら、次にテストデータを用いてモデルを詳細に評価します。

test_loader = torch.utils.data.DataLoader(
                    CocoDataset('kvasirdet/test', 'kvasirdet/test/bbox.json'),
                    batch_size=4, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))

model = fasterrcnn(num_classes=1, weights='kvasirdet.pth')
model.to(device)
model.eval()

metric = torchmetrics.detection.mean_ap.MeanAveragePrecision()
with torch.no_grad():
    for images, targets in test_loader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        metric.update(model(images), targets)

metrics = {}
for k, v in metric.compute().items():
    metrics[k] = v.cpu().detach().numpy().tolist()
Hide code cell output
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
metrics
{'map': 0.31070682406425476,
 'map_50': 0.43754759430885315,
 'map_75': 0.35496529936790466,
 'map_small': 0.0,
 'map_medium': -1.0,
 'map_large': 0.6214136481285095,
 'mar_1': 0.3149999976158142,
 'mar_10': 0.35749998688697815,
 'mar_100': 0.35749998688697815,
 'mar_small': 0.0,
 'mar_medium': -1.0,
 'mar_large': 0.7149999737739563,
 'map_per_class': -1.0,
 'mar_100_per_class': -1.0,
 'classes': [0, 1]}

ここで出力される指標について、map から始まる指標は mAP を表し、mar から始まる指標は mean average recall(平均再現率)を示します。mar はすべてのクラスに対する再現率を計算し、それらの平均を求めたものです。mAR 1 は、各画像に対してモデルが検出した物体のうち、最も信頼度の高い物体だけを利用して計算した再現率を表します。同様に、mAR 10 および mAR 100 は、モデルが検出した物体のうち、信頼度の高い 10 および 100 物体を利用して計算した再現率を表しています。

4.4.5. 推論#

推論時にも、訓練時と同じように torchvision モジュールからアーキテクチャを呼び出し、出力層のクラス数を設定します。その後、load_state_dict メソッドを使って、訓練済みの重みファイルをモデルにロードします。これらの操作はすべて fasterrcnn 関数で定義されているので、その関数を利用します。

model = fasterrcnn(num_classes=1, weights='kvasirdet.pth')
model.to(device)
model.eval()
Hide code cell output
FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): FrozenBatchNorm2d(256, eps=0.0)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(inplace=True)
        )
      )
      (layer2): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(128, eps=0.0)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(128, eps=0.0)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(512, eps=0.0)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): FrozenBatchNorm2d(512, eps=0.0)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(128, eps=0.0)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(128, eps=0.0)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(512, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(128, eps=0.0)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(128, eps=0.0)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(512, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (3): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(128, eps=0.0)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(128, eps=0.0)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(512, eps=0.0)
          (relu): ReLU(inplace=True)
        )
      )
      (layer3): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=0.0)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=0.0)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=0.0)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): FrozenBatchNorm2d(1024, eps=0.0)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=0.0)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=0.0)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=0.0)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=0.0)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (3): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=0.0)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=0.0)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (4): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=0.0)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=0.0)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (5): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(256, eps=0.0)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(256, eps=0.0)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(1024, eps=0.0)
          (relu): ReLU(inplace=True)
        )
      )
      (layer4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(512, eps=0.0)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(512, eps=0.0)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(2048, eps=0.0)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): FrozenBatchNorm2d(2048, eps=0.0)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(512, eps=0.0)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(512, eps=0.0)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(2048, eps=0.0)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(512, eps=0.0)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(512, eps=0.0)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(2048, eps=0.0)
          (relu): ReLU(inplace=True)
        )
      )
    )
    (fpn): FeaturePyramidNetwork(
      (inner_blocks): ModuleList(
        (0): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (layer_blocks): ModuleList(
        (0-3): 4 x Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (extra_blocks): LastLevelMaxPool()
    )
  )
  (rpn): RegionProposalNetwork(
    (anchor_generator): AnchorGenerator()
    (head): RPNHead(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
        )
      )
      (cls_logits): Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1))
      (bbox_pred): Conv2d(256, 12, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (roi_heads): RoIHeads(
    (box_roi_pool): MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=(7, 7), sampling_ratio=2)
    (box_head): TwoMLPHead(
      (fc6): Linear(in_features=12544, out_features=1024, bias=True)
      (fc7): Linear(in_features=1024, out_features=1024, bias=True)
    )
    (box_predictor): FastRCNNPredictor(
      (cls_score): Linear(in_features=1024, out_features=2, bias=True)
      (bbox_pred): Linear(in_features=1024, out_features=8, bias=True)
    )
  )
)

このモデルを利用して推論を行います。まず、1 枚の画像を指定し、PIL モジュールを用いて画像を開き、テンソル形式に変換した後、モデルに入力します。モデルは予測結果としてバウンディングボックスの座標(bboxes)、分類ラベル(labels)、および信頼スコア(scores)を出力します。ただし、信頼スコアが 0.5 未満のバウンディングボックスは採用せず、信頼スコアが高い結果のみを選択して利用します。

threshold = 0.5
image_path = 'kvasirdet/test/images/cju2lz8vqktne0993fuym6drw.jpg'

image = PIL.Image.open(image_path).convert('RGB')
input_tensor = torchvision.transforms.functional.to_tensor(image).unsqueeze(0).to(device)
    
with torch.no_grad():
    predictions = model(input_tensor)[0]
    
bboxes = predictions['boxes'][predictions['scores'] > threshold]
labels = predictions['labels'][predictions['scores'] > threshold]
scores = predictions['scores'][predictions['scores'] > threshold]

検出されたオブジェクトのバウンディングボックスを入力画像に描画します。その後、PIL および matplotlib ライブラリを使用して、画像とその検出結果を可視化します。

draw = PIL.ImageDraw.Draw(image)
for bbox, label, score in zip(bboxes, labels, scores):
    x1, y1, x2, y2 = bbox
    draw.rectangle(((x1, y1), (x2, y2)), outline="blue", width=3)
    draw.text((x1, y1 - 10), f"{label.item()} ({score:.2f})", fill="blue")
    
fig = plt.figure()
ax = fig.add_subplot()
ax.imshow(image)
ax.axis('off')
fig.show()
../_images/7b571d7553a2fea0b248ab9e503337af9860e46d647136a26fb3541e92cee5f0.png

次にこのモデルにもう一枚の画像を入力し、その推論結果を見てみましょう。

threshold = 0.5
image_path = 'kvasirdet/test/images/cju2hqt33lmra0988fr5ijv8j.jpg'

image = PIL.Image.open(image_path).convert('RGB')
input_tensor = torchvision.transforms.functional.to_tensor(image).unsqueeze(0).to(device)
    
with torch.no_grad():
    predictions = model(input_tensor)[0]
    
bboxes = predictions['boxes'][predictions['scores'] > threshold]
labels = predictions['labels'][predictions['scores'] > threshold]
scores = predictions['scores'][predictions['scores'] > threshold]

draw = PIL.ImageDraw.Draw(image)
for bbox, label, score in zip(bboxes, labels, scores):
    x1, y1, x2, y2 = bbox
    draw.rectangle(((x1, y1), (x2, y2)), outline="blue", width=3)
    draw.text((x1, y1 - 10), f"{label.item()} ({score:.2f})", fill="blue")
    
fig = plt.figure()
ax = fig.add_subplot()
ax.imshow(image)
ax.axis('off')
fig.show()
../_images/a7d6a4521b98c9b1bfe0fed426cf05ee588672ed5f8673d211edcbe14457967f.png