4.5. 演習:血液細胞計数#

顕微鏡画像の解析に深層学習を活用することで、作業の効率化が期待されています。例えば、基礎研究において、顕微鏡を使って細胞を種類ごとに分類し、その数を調べる作業が行われています。本節では、このような物体検出と個数カウントの実例として、顕微鏡で撮影した血液細胞の画像から赤血球、白血球、血小板を検出し、それぞれの個数を数える方法を学びます。

4.5.1. 演習準備#

4.5.1.1. ライブラリ#

本節で利用するライブラリを読み込みます。NumPy、Pnadas、Matplotlib、Pillow(PIL)などのライブライは、モデルの性能や推論結果などの可視化に利用します。scikit-learn(sklearn)、PyTorch(torch)、torchvision、torchmetrics は機械学習関連のライブラリであり、モデルの構築、検証や推論などに利用します。

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import PIL.Image

import torch
import torchvision
import torchmetrics

print(f'torch v{torch.__version__}; torchvision v{torchvision.__version__}')
torch v2.5.1+cu121; torchvision v0.20.1+cu121

ライブラリの読み込み時に、ImportErrorModuleNotFoundError が発生した場合、必要なライブラリがインストールされていない可能性があります。該当ライブラリをインストールしてください。

4.5.1.2. データセット#

本節では、BCCD (blood cell count and detection) データセットを利用します。このデータセットは顕微鏡を利用して撮影された血液細胞の画像で、主に赤血球(red blood cell; RBC)、白血球(white blood cell; WBC)、血小板(platelet)を含みます。BCCD データセットは MIT ライセンスのもとで GitHub BCCD_Dataset で公開されており、著作権表示を行うことで自由に利用することができます。

../_images/bccd_dataset.jpg

Fig. 4.19 BCCD データセットに含まれる顕微鏡画像のサンプル。#

本節では、オリジナルの BCCD データセットを再整理して作成したデータセットを利用します。本節で利用するデータセットは訓練、検証、そしてテストの 3 つのサブセットからなり、それぞれのサブセットにおいて各カテゴリに 50 枚、20 枚、20 枚の画像が含まれています。

Jupyter Notebook では、次のコマンドを実行することで、データをダウンロードできます。

!wget https://medDL.biopapyrus.jp/datasets/bccd.zip
!unzip bccd.zip

4.5.1.3. 前処理#

物体検出の問題では、画像と一緒に、検出対象の物体を囲むバウンディングボックスの座標とラベルをモデルに与え、学習させる必要があります。バウンディングボックスの座標とラベルは、一般的に COCO フォーマット(.json)や Pascal VOC フォーマット(.xml)などで保存されます。しかし、これらのフォーマットのままでは PyTorch で直接扱えないため、PyTorch が利用できる形式に変換する必要があります。

class CocoDataset(torchvision.datasets.CocoDetection):
    def __init__(self, root, annFile):
        super(CocoDataset, self).__init__(root, annFile)
    
    def __getitem__(self, idx):
        img, target = super(CocoDataset, self).__getitem__(idx)
        
        boxes = []
        labels = []
        for obj in target:
            bbox = obj['bbox']
            bbox = [bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]]
            boxes.append(bbox)
            labels.append(obj['category_id'])
        
        img = torchvision.transforms.functional.to_tensor(img)
        target = {
            'boxes': torch.as_tensor(boxes, dtype=torch.float32),
            'labels': torch.as_tensor(labels, dtype=torch.int64),
        }

        return img, target
    

なお、画像分類と同様に、畳み込みニューラルネットワークでは入力する画像のサイズを指定されたサイズに変更する必要があります。ただし、このサイズ変更は CocoDetection 内で行われるため、ここであらためて設定する必要はありません。

また、モデルを訓練する際には、画像の拡大縮小や平行移動、回転などのデータ拡張を行い、それに伴ってバウンディングボックスの座標も同じように再計算する必要があります。しかし、これらの処理を加えるとコードが複雑化し、全体の流れがわかりにくくなるため、本節ではデータ拡張の処理は省略しています。

4.5.1.4. 計算デバイス#

PyTorch が GPU を認識できる場合は GPU を利用し、認識できない場合は CPU を利用するように計算デバイスを設定します。

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

4.5.2. モデル構築#

SSD(Single Shot MultiBox Detector)は、高性能な物体検出アーキテクチャであり、torchvision.models モジュールを通じて利用できます。しかし、このアーキテクチャは COCO データセット向けに設計されており、車や人などの 90 種類の一般的なオブジェクトを検出するための構造になっています。そのため、そのままでは血液細胞(赤血球、白血球、血小板)の検出には対応していません。

血液細胞の検出を行うには、このアーキテクチャの分類部分の出力数を、検出対象となる 3 カテゴリ(赤血球、白血球、血小板)に対応させる必要があります。この修正はモデルを呼び出すたびに同じ作業を繰り返す必要があり、手間がかかります。そこで、指定したクラス数に応じてアーキテクチャを生成し、必要に応じて修正を加えられるよう、一連の処理を関数として定義します。

def ssd(num_classes, weights=None):
    model = torchvision.models.detection.ssd300_vgg16(weights='DEFAULT')
    
    in_channels = torchvision.models.detection._utils.retrieve_out_channels(model.backbone, (300, 300))
    num_anchors = model.anchor_generator.num_anchors_per_location()
    model.head.classification_head = torchvision.models.detection.ssd.SSDClassificationHead(
        in_channels=in_channels,
        num_anchors=num_anchors,
        num_classes=num_classes + 1,
    )
    model.transform.min_size = (300,)
    model.transform.max_size = 3000
    if weights is not None:
        model.load_state_dict(torch.load(weights))
    return model

model = ssd(num_classes=3)
model.to(device)
Hide code cell output
SSD(
  (backbone): SSDFeatureExtractorVGG(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU(inplace=True)
      (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (15): ReLU(inplace=True)
      (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
      (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (18): ReLU(inplace=True)
      (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (20): ReLU(inplace=True)
      (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (22): ReLU(inplace=True)
    )
    (extra): ModuleList(
      (0): Sequential(
        (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): ReLU(inplace=True)
        (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): ReLU(inplace=True)
        (5): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (6): ReLU(inplace=True)
        (7): Sequential(
          (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
          (1): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6))
          (2): ReLU(inplace=True)
          (3): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
          (4): ReLU(inplace=True)
        )
      )
      (1): Sequential(
        (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (3): ReLU(inplace=True)
      )
      (2): Sequential(
        (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (3): ReLU(inplace=True)
      )
      (3-4): 2 x Sequential(
        (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
        (3): ReLU(inplace=True)
      )
    )
  )
  (anchor_generator): DefaultBoxGenerator(aspect_ratios=[[2], [2, 3], [2, 3], [2, 3], [2], [2]], clip=True, scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05], steps=[8, 16, 32, 64, 100, 300])
  (head): SSDHead(
    (classification_head): SSDClassificationHead(
      (module_list): ModuleList(
        (0): Conv2d(512, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Conv2d(1024, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): Conv2d(512, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): Conv2d(256, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4-5): 2 x Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (regression_head): SSDRegressionHead(
      (module_list): ModuleList(
        (0): Conv2d(512, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Conv2d(1024, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): Conv2d(512, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): Conv2d(256, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4-5): 2 x Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
  )
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.48235, 0.45882, 0.40784], std=[0.00392156862745098, 0.00392156862745098, 0.00392156862745098])
      Resize(min_size=(300,), max_size=3000, mode='bilinear')
  )
)

4.5.3. モデル訓練#

モデルが学習データを効率よく学習できるようにするため、学習アルゴリズム(optimizer)、学習率(lr)、および学習率を調整するスケジューラ(lr_scheduler)を設定します。

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

次に、訓練データと検証データを読み込み、モデルが入力できる形式に整えます。

train_loader = torch.utils.data.DataLoader(
                    CocoDataset('bccd', 'bccd/train.bbox.json'),
                    batch_size=4, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))
valid_loader = torch.utils.data.DataLoader(
                    CocoDataset('bccd', 'bccd/valid.bbox.json'),
                    batch_size=4, shuffle=False, collate_fn=lambda x: tuple(zip(*x)))
Hide code cell output
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!

準備が整ったら、訓練を開始します。訓練プロセスでは、訓練と検証を交互に繰り返します。訓練では、訓練データを使ってモデルのパラメータを更新し、その際の損失(誤差)を記録します。検証では、検証データを使ってモデルの予測性能(mAP)を計算し、その結果を記録します。

num_epochs = 10
metric_dict = []

for epoch in range(num_epochs):
    # training phase
    model.train()
    epoch_loss_dict = {}

    for images, targets in train_loader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        
        batch_loss_dict = model(images, targets)
        batch_tol_loss = 0
        for loss_type, loss_val in batch_loss_dict.items():
            batch_tol_loss += loss_val
            if loss_type in epoch_loss_dict:
                epoch_loss_dict[f'train_{loss_type}'] += loss_val.item()
            else:
                epoch_loss_dict[f'train_{loss_type}'] = loss_val.item()
                
        # update weights
        optimizer.zero_grad()
        batch_tol_loss.backward()
        optimizer.step()
    lr_scheduler.step()

    # validation phase
    model.eval()
    metric = torchmetrics.detection.mean_ap.MeanAveragePrecision()
    with torch.no_grad():
        for images, targets in valid_loader:
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            metric.update(model(images), targets)


    # record training loss
    epoch_loss_dict['train_loss_total'] = sum(epoch_loss_dict.values())
    metric_dict.append({k: v / len(train_loader) for k, v in epoch_loss_dict.items()})
    for k, v in metric.compute().items():
        if k != 'classes':
            metric_dict[-1][k] = v.item()
    metric_dict[-1]['epoch'] = epoch + 1

    print(metric_dict[-1])
Hide code cell output
{'train_bbox_regression': 0.07170875255878155, 'train_classification': 0.34223009989811826, 'train_loss_total': 0.41393885245689976, 'map': 0.08750460296869278, 'map_50': 0.19256165623664856, 'map_75': 0.06499755382537842, 'map_small': 0.0, 'map_medium': 0.07740246504545212, 'map_large': 0.17151948809623718, 'mar_1': 0.03004535101354122, 'mar_10': 0.13114511966705322, 'mar_100': 0.23691609501838684, 'mar_small': 0.0, 'mar_medium': 0.2142857164144516, 'mar_large': 0.37119048833847046, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 1}
{'train_bbox_regression': 0.0620613327393165, 'train_classification': 0.2838489275712233, 'train_loss_total': 0.34591026031053984, 'map': 0.16306035220623016, 'map_50': 0.3797067701816559, 'map_75': 0.11639369279146194, 'map_small': 0.0, 'map_medium': 0.17574787139892578, 'map_large': 0.2550510764122009, 'mar_1': 0.07227890938520432, 'mar_10': 0.20452381670475006, 'mar_100': 0.2934126853942871, 'mar_small': 0.0, 'mar_medium': 0.2380952388048172, 'mar_large': 0.4660119116306305, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 2}
{'train_bbox_regression': 0.047172688520871676, 'train_classification': 0.24223883335406965, 'train_loss_total': 0.2894115218749413, 'map': 0.3410320281982422, 'map_50': 0.5738868117332458, 'map_75': 0.3590058386325836, 'map_small': 0.0, 'map_medium': 0.19936926662921906, 'map_large': 0.5243248343467712, 'mar_1': 0.22447845339775085, 'mar_10': 0.3275056779384613, 'mar_100': 0.3992403745651245, 'mar_small': 0.0, 'mar_medium': 0.24920634925365448, 'mar_large': 0.6192262172698975, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 3}
{'train_bbox_regression': 0.05474379429450402, 'train_classification': 0.23787819422208345, 'train_loss_total': 0.2926219885165875, 'map': 0.358513742685318, 'map_50': 0.5643796324729919, 'map_75': 0.39468538761138916, 'map_small': 0.0, 'map_medium': 0.19884093105793, 'map_large': 0.5573099851608276, 'mar_1': 0.24163265526294708, 'mar_10': 0.35766440629959106, 'mar_100': 0.4185487627983093, 'mar_small': 0.0, 'mar_medium': 0.25952380895614624, 'mar_large': 0.645297646522522, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 4}
{'train_bbox_regression': 0.048000803360572226, 'train_classification': 0.23422791407658503, 'train_loss_total': 0.2822287174371573, 'map': 0.3628511428833008, 'map_50': 0.5761359930038452, 'map_75': 0.40079978108406067, 'map_small': 0.0, 'map_medium': 0.20737437903881073, 'map_large': 0.5610173344612122, 'mar_1': 0.24163265526294708, 'mar_10': 0.353730171918869, 'mar_100': 0.4279954731464386, 'mar_small': 0.0, 'mar_medium': 0.2750396728515625, 'mar_large': 0.6470833420753479, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 5}
{'train_bbox_regression': 0.03632826759265019, 'train_classification': 0.2112285173856295, 'train_loss_total': 0.2475567849782797, 'map': 0.36678045988082886, 'map_50': 0.5795672535896301, 'map_75': 0.3980797231197357, 'map_small': 0.0, 'map_medium': 0.20977972447872162, 'map_large': 0.5677263736724854, 'mar_1': 0.24341270327568054, 'mar_10': 0.35766440629959106, 'mar_100': 0.4323832094669342, 'mar_small': 0.0, 'mar_medium': 0.28059524297714233, 'mar_large': 0.6525595188140869, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 6}
{'train_bbox_regression': 0.03765265070475065, 'train_classification': 0.2195291519165039, 'train_loss_total': 0.25718180262125456, 'map': 0.3674083352088928, 'map_50': 0.5775684714317322, 'map_75': 0.39822250604629517, 'map_small': 0.0, 'map_medium': 0.21054765582084656, 'map_large': 0.5676042437553406, 'mar_1': 0.24485260248184204, 'mar_10': 0.3596712052822113, 'mar_100': 0.4328027069568634, 'mar_small': 0.0, 'mar_medium': 0.27861112356185913, 'mar_large': 0.6532738208770752, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 7}
{'train_bbox_regression': 0.022788462730554435, 'train_classification': 0.20725158544687125, 'train_loss_total': 0.23004004817742568, 'map': 0.36692485213279724, 'map_50': 0.577145516872406, 'map_75': 0.3978608250617981, 'map_small': 0.0, 'map_medium': 0.20821280777454376, 'map_large': 0.567468523979187, 'mar_1': 0.24485260248184204, 'mar_10': 0.3564966022968292, 'mar_100': 0.43799546360969543, 'mar_small': 0.0, 'mar_medium': 0.2859523892402649, 'mar_large': 0.6532738208770752, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 8}
{'train_bbox_regression': 0.03986583306239201, 'train_classification': 0.22025251388549805, 'train_loss_total': 0.26011834694789004, 'map': 0.36657676100730896, 'map_50': 0.5764498114585876, 'map_75': 0.3978538513183594, 'map_small': 0.0, 'map_medium': 0.2066395878791809, 'map_large': 0.5680896043777466, 'mar_1': 0.24507936835289001, 'mar_10': 0.35638323426246643, 'mar_100': 0.43765532970428467, 'mar_small': 0.0, 'mar_medium': 0.28357142210006714, 'mar_large': 0.6541666388511658, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 9}
{'train_bbox_regression': 0.03415831923484802, 'train_classification': 0.22326860061058632, 'train_loss_total': 0.2574269198454343, 'map': 0.36664795875549316, 'map_50': 0.5764819979667664, 'map_75': 0.3978741765022278, 'map_small': 0.0, 'map_medium': 0.20670509338378906, 'map_large': 0.5682852864265442, 'mar_1': 0.24507936835289001, 'mar_10': 0.35638323426246643, 'mar_100': 0.43776869773864746, 'mar_small': 0.0, 'mar_medium': 0.28357142210006714, 'mar_large': 0.6544643044471741, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'epoch': 10}

訓練後、訓練中の損失および検証性能の変化を可視化し、訓練が正しく行われたかどうかを確認します。

Hide code cell source
metric_dict = pd.DataFrame(metric_dict)

fig, ax = plt.subplots(1, 2)
ax[0].plot(metric_dict['epoch'], metric_dict['train_loss_total'], label='total')
ax[0].plot(metric_dict['epoch'], metric_dict['train_classification'], label='classification')
ax[0].plot(metric_dict['epoch'], metric_dict['train_bbox_regression'], label='bbox')
ax[0].set_xlabel('epoch')
ax[0].set_ylabel('loss')
ax[0].set_title('Train')
ax[0].legend()
ax[1].plot(metric_dict['epoch'], metric_dict['map_50'])
ax[1].set_ylim(0, 1)
ax[1].set_xlabel('epoch')
ax[1].set_ylabel('mAP (50%)')
ax[1].set_title('Validation')
plt.tight_layout()
fig.show()
../_images/ce04c26c2921281b92692e44fe13d5a30f15be3877e6b19e3e89866d2d9303fa.png

訓練データに対する損失はエポック数の増加に伴って徐々に減少しており、一部で上下するものの、後半では収束しているように見えます。また、検証データに対する検出性能(mAP 50%)は、4 エポック目以降安定して 0.577 前後で推移しています。この結果から、さらなる性能改善を目指す場合には、データセットを増やしたりする必要があるかもしれません。ただし、本節では時間の関係上、ここで訓練を終了します。

次に、この手順を Faster R-CNN や YOLO など、他の深層ニューラルネットワークアーキテクチャに適用し、それぞれの検証性能を比較して、このデータセットに最適なアーキテクチャを選定するプロセスを考えます。ただし、本節では時間の都合上、他のアーキテクチャを構築せず、既に構築した SSD を最適なアーキテクチャと仮定して次のステップに進みます。

次のステップでは、訓練サブセットと検証サブセットを統合したデータセット(trainvalid.bbox.json)を用いて、最適と判断したアーキテクチャを最初から訓練します。

# model
model = ssd(num_classes=3)
model.to(device)

# training parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

# training data
train_loader = torch.utils.data.DataLoader(
                    CocoDataset('bccd', 'bccd/trainvalid.bbox.json'),
                    batch_size=4, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))

# training
num_epochs = 5
metric_dict = []
for epoch in range(num_epochs):
    model.train()
    epoch_loss_dict = {}
    for images, targets in train_loader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        
        batch_loss_dict = model(images, targets)
        batch_tol_loss = 0
        for loss_type, loss_val in batch_loss_dict.items():
            batch_tol_loss += loss_val
            if loss_type in epoch_loss_dict:
                epoch_loss_dict[f'train_{loss_type}'] += loss_val.item()
            else:
                epoch_loss_dict[f'train_{loss_type}'] = loss_val.item()
                
        # update weights
        optimizer.zero_grad()
        batch_tol_loss.backward()
        optimizer.step()
    lr_scheduler.step()

    # record training loss
    epoch_loss_dict['train_loss_total'] = sum(epoch_loss_dict.values())
    metric_dict.append({k: v / len(train_loader) for k, v in epoch_loss_dict.items()})
    metric_dict[-1]['epoch'] = epoch + 1

    print(metric_dict[-1])
Hide code cell output
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
{'train_bbox_regression': 0.05104579859309726, 'train_classification': 0.24258189731174046, 'train_loss_total': 0.2936276959048377, 'epoch': 1}
{'train_bbox_regression': 0.036310940980911255, 'train_classification': 0.1893100208706326, 'train_loss_total': 0.22562096185154384, 'epoch': 2}
{'train_bbox_regression': 0.04107703765233358, 'train_classification': 0.14839735296037462, 'train_loss_total': 0.1894743906127082, 'epoch': 3}
{'train_bbox_regression': 0.02335899571577708, 'train_classification': 0.1372472577624851, 'train_loss_total': 0.16060625347826216, 'epoch': 4}
{'train_bbox_regression': 0.0300109154648251, 'train_classification': 0.14002909925248888, 'train_loss_total': 0.17004001471731398, 'epoch': 5}

訓練が完了したら、訓練済みモデルの重みをファイルに保存します。

model.to('cpu')
torch.save(model.state_dict(), 'bccd.pth')

4.5.4. モデル評価#

最適なモデルが得られたら、次にテストデータを用いてモデルを詳細に評価します。

test_loader = torch.utils.data.DataLoader(
                    CocoDataset('bccd', 'bccd/test.bbox.json'),
                    batch_size=4, shuffle=False, collate_fn=lambda x: tuple(zip(*x)))

model = ssd(num_classes=3, weights='bccd.pth')
model.to(device)
model.eval()

n_cells = [[], [], []]
n_pred_cells = [[], [], []]

metric = torchmetrics.detection.mean_ap.MeanAveragePrecision()
with torch.no_grad():
    for images, targets in test_loader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        outputs = model(images)

        for target, output in zip(targets, outputs):
            n_cells_ = [0, 0, 0]
            n_pred_cells_ = [0, 0, 0]
            for label in target['labels']:
                n_cells_[label.item() - 1] += 1
            for label, score in zip(output['labels'], output['scores']):
                if score > 0.5:
                    n_pred_cells_[label.item() - 1] += 1
            for i in range(3):
                n_cells[i].append(n_cells_[i])
                n_pred_cells[i].append(n_pred_cells_[i])
        metric.update(outputs, targets)

n_cells = pd.DataFrame(n_cells, index=['WBC', 'RBC', 'Platelets']).T
n_pred_cells = pd.DataFrame(n_pred_cells, index=['WBC', 'RBC', 'Platelets']).T

metrics = {}
for k, v in metric.compute().items():
    metrics[k] = v.cpu().detach().numpy().tolist()
Hide code cell output
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
metrics
{'map': 0.3857402503490448,
 'map_50': 0.6347993016242981,
 'map_75': 0.4144221544265747,
 'map_small': -1.0,
 'map_medium': 0.40806663036346436,
 'map_large': 0.5765005946159363,
 'mar_1': 0.23079825937747955,
 'mar_10': 0.5017107129096985,
 'mar_100': 0.6220964789390564,
 'mar_small': -1.0,
 'mar_medium': 0.5318182110786438,
 'mar_large': 0.6575040221214294,
 'map_per_class': -1.0,
 'mar_100_per_class': -1.0,
 'classes': [1, 2, 3]}

テストデータに対する検出性能を確認したところ、mAP(50%)が 0.635 程度であり、まだ改善の余地があるかもしれません。時間の関係上、モデルのチューニングをここで終了します。次に、モデルが予測した細胞の数と、実際にアノテーションされた細胞の数を比較するために、画像上で結果を可視化します。

fig, ax = plt.subplots()
for i in range(3):
    ax.scatter(n_cells.iloc[:, i], n_pred_cells.iloc[:, i], label=n_cells.columns[i], alpha=0.5)
ax.set_xlabel('Number of Cells')
ax.set_ylabel('Number of Predicted Cells')
ax.set_aspect('equal')
ax.legend()
fig.show()
../_images/03d7d61c5f771973234bf919be81f028e9357d68e3c525d2a8a412e99d4717ed.png

可視化の結果、白血球と赤血球の予測数はアノテーションされた数と似た数となることを確認できます。しかし、血小板についてはほとんど予測できていないことがわかります。これをより正確に評価するには、RMSE(平方平均二乗誤差)などの指標を用いる必要がありますが、ここで省略します。

血小板の予測が困難だった理由の一つは、SSD が持つ特徴と考えられます。SSD は大きな物体を高速かつ高精度で検出するのに優れている一方で、血小板のような小さな物体の検出は得意ではありません。また、今回使用したデータセットでは、血小板のアノテーション数が非常に少なかったことも影響していると考えられます。その結果、SSD は血小板のような小さな物体の特徴を十分に学習できなかった可能性があります。このように、大きな物体と小さな物体が混在する画像で、特に小さな物体のアノテーション数が著しく少ない場合、SSD は必ずしも最適な選択ではないと考えられます。

4.5.5. 推論#

推論時にも、訓練時と同じように torchvision モジュールからアーキテクチャを呼び出し、出力層のクラス数を設定します。その後、load_state_dict メソッドを使って、訓練済みの重みファイルをモデルにロードします。これらの操作はすべて ssd 関数で定義されているので、その関数を利用します。

model = ssd(num_classes=3, weights='bccd.pth')
model.to(device)
model.eval()
Hide code cell output
SSD(
  (backbone): SSDFeatureExtractorVGG(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU(inplace=True)
      (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (15): ReLU(inplace=True)
      (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
      (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (18): ReLU(inplace=True)
      (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (20): ReLU(inplace=True)
      (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (22): ReLU(inplace=True)
    )
    (extra): ModuleList(
      (0): Sequential(
        (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): ReLU(inplace=True)
        (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): ReLU(inplace=True)
        (5): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (6): ReLU(inplace=True)
        (7): Sequential(
          (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
          (1): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6))
          (2): ReLU(inplace=True)
          (3): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
          (4): ReLU(inplace=True)
        )
      )
      (1): Sequential(
        (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (3): ReLU(inplace=True)
      )
      (2): Sequential(
        (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (3): ReLU(inplace=True)
      )
      (3-4): 2 x Sequential(
        (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
        (3): ReLU(inplace=True)
      )
    )
  )
  (anchor_generator): DefaultBoxGenerator(aspect_ratios=[[2], [2, 3], [2, 3], [2, 3], [2], [2]], clip=True, scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05], steps=[8, 16, 32, 64, 100, 300])
  (head): SSDHead(
    (classification_head): SSDClassificationHead(
      (module_list): ModuleList(
        (0): Conv2d(512, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Conv2d(1024, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): Conv2d(512, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): Conv2d(256, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4-5): 2 x Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (regression_head): SSDRegressionHead(
      (module_list): ModuleList(
        (0): Conv2d(512, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Conv2d(1024, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): Conv2d(512, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): Conv2d(256, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4-5): 2 x Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
  )
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.48235, 0.45882, 0.40784], std=[0.00392156862745098, 0.00392156862745098, 0.00392156862745098])
      Resize(min_size=(300,), max_size=3000, mode='bilinear')
  )
)

このモデルを利用して推論を行います。まず、1 枚の画像を指定し、PIL モジュールを用いて画像を開き、テンソル形式に変換した後、モデルに入力します。モデルは予測結果としてバウンディングボックスの座標(bboxes)、分類ラベル(labels)、および信頼スコア(scores)を出力します。ただし、信頼スコアが 0.5 未満のバウンディングボックスは採用せず、信頼スコアが高い結果のみを選択して利用します。

def predict(image_path, threshold=0.5):
    image = PIL.Image.open(image_path).convert('RGB')
    input_tensor = torchvision.transforms.functional.to_tensor(image).unsqueeze(0).to(device)
        
    with torch.no_grad():
        predictions = model(input_tensor)[0]
        
    bboxes = predictions['boxes'][predictions['scores'] > threshold]
    labels = predictions['labels'][predictions['scores'] > threshold]
    scores = predictions['scores'][predictions['scores'] > threshold]
    bboxes = [_.cpu().detach().numpy().tolist() for _ in bboxes]
    labels = [int(_.cpu().detach().item()) for _ in labels]
    scores = [float(_.cpu().detach().item()) for _ in scores]
    
    return bboxes, labels, scores

bboxes, labels, scores = predict('bccd/images/BloodImage_00033.jpg')

for bbox, label, score in zip(bboxes, labels, scores):
    print([bbox, label, score])
[[132.56802368164062, 107.75836181640625, 289.7909240722656, 249.7099151611328], 1, 0.8317584991455078]
[[84.72344970703125, 221.9060821533203, 194.07650756835938, 329.6766052246094], 2, 0.8022224307060242]
[[408.1998291015625, 211.26560974121094, 519.9180908203125, 315.0968933105469], 2, 0.8018974661827087]
[[38.22060775756836, 21.28352165222168, 155.07687377929688, 123.92396545410156], 2, 0.774657130241394]
[[292.3138427734375, 245.04052734375, 393.79595947265625, 356.3401794433594], 2, 0.7442697882652283]
[[192.3541717529297, 293.60784912109375, 305.8379211425781, 404.2905578613281], 2, 0.6082244515419006]
[[201.6402587890625, 4.537786960601807, 299.6002502441406, 90.0793685913086], 2, 0.5860577821731567]
[[548.2664184570312, 34.37702178955078, 638.9638061523438, 135.6253662109375], 2, 0.5618970990180969]
[[569.5496215820312, 330.07330322265625, 639.9425048828125, 433.29376220703125], 2, 0.5243276357650757]
[[338.2408752441406, 76.50111389160156, 460.2703552246094, 189.35848999023438], 2, 0.516722559928894]
[[307.004150390625, 4.214175701141357, 420.1590270996094, 90.2580337524414], 2, 0.5131307244300842]

検出されたオブジェクトのバウンディングボックスを入力画像に描画します。その後、PIL および matplotlib ライブラリを使用して、画像とその検出結果を可視化します。

def viz(image_path):
    bboxes, labels, scores = predict(image_path)
    meta_info = [
        {}, # background
        {'class': 'WBC', 'col': '#0094cd'},
        {'class': 'RBC', 'col': '#d51317'},
        {'class': 'Platelets', 'col': '#007b3d'}
    ]
    
    image = PIL.Image.open(image_path).convert('RGB')
    draw = PIL.ImageDraw.Draw(image)
    for bbox, label, score in zip(bboxes, labels, scores):
        x1, y1, x2, y2 = bbox
        draw.rectangle(((x1, y1), (x2, y2)), outline=meta_info[label]['col'], width=3)
        draw.text((x1, y1 - 10), f'{meta_info[label]["class"]} ({score:.2f})', fill=meta_info[label]['col'])
        
    fig = plt.figure()
    ax = fig.add_subplot()
    ax.imshow(image)
    ax.axis('off')
    fig.show()
viz('bccd/images/BloodImage_00033.jpg')
../_images/601f400959b25cab8be95bb2ba215324a6601f1fe34637b5cccf95229873ce4f.png