5.2. 演習:レントゲン画像歯検出#
歯や歯茎の状態を確認するためにレントゲンを使用することは、歯科検査において欠かせません。もし、レントゲン画像から歯の領域を自動で検出し、その状態、たとえば虫歯の有無を判定できるようになれば、歯科医師の負担を軽減し、診断をより迅速かつ正確に行うことが可能です。このような自動化を実現するには、画像内の歯の領域を正確に分離するセグメンテーション技術が必要です。本節では、レントゲン画像を用いた歯のセグメンテーションの方法を学び、歯科検査を効率化する支援方法を考えていきます。
5.2.1. 演習準備#
5.2.1.1. ライブラリ#
本節で必要なライブラリを読み込みます。os、random、NumPy、Pandas、Matplotlib、Pillow(PIL)は訓練過程の可視化や推論結果の表示に利用します。scikit-image(skimage)はマスクから輪郭線を計算する際に使用します。さらに、torch、torchvision、torchmetrics はインスタンスセグメンテーションモデルの訓練、検証、推論に利用します。
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import PIL
import skimage
import torch
import torchvision
import torchvision.transforms.v2
import torchmetrics
print(f'torch v{torch.__version__}; torchvision v{torchvision.__version__}')
torch v2.5.1+cu121; torchvision v0.20.1+cu121
ライブラリの読み込み時に ImportError や ModuleNotFoundError が発生した場合は、該当するライブラリをインストールしてください。ライブラリのバージョンを揃える必要はありませんが、PyTorch(torch)および torchvision が上記のバージョンと異なる時、実行中に警告メッセージが現れたり、同じ結果にならなかったりする可能性があります。
5.2.1.2. データセット#
本章では、Kaggle にて CC0 ライセンスのもとで公開されている Teeth Segmentation on dental X-ray images を使用します。このデータセットは、歯のレントゲン画像からなるデータセットであり、歯を領域をポリゴン囲んだアノテーションが含まれています。アノテーションはポリゴンの座標を記載した数値データと画像として保存されたマスクの両方が用意されています。マスク画像は RGB カラー画像で、画像内の色が各歯の番号に対応しています。例えば、13 番目の歯は色が (R, G, B) = (1, 1, 1)、12 番目の歯は (2, 2, 2) のように、32 番目の歯までそれぞれ異なる色で対応付けられています[1]。歯の番号と色の対応関係は、obj_class_to_machine_color.json
ファイルに保存されています。歯の番号を区別して取り扱う際は、この対応表を使用してデータを取得する必要があります。

Fig. 5.3 Teeth Segmentation on dental X-ray images データセットのサンプル画像とマスク画像。#
オリジナルのデータセットはやや大きいため、本節では、オリジナルのデータセットから 80 枚の画像を抽出して、そのうち 60 枚を訓練データ、10 枚を検証データ、10 枚を検証データとして整理したものを利用します。Jupyter Notebook では、以下のコマンドを実行することで、データセットをダウンロードできます。
!wget https://dl.biopapyrus.jp/data/teethsegm.zip
!unzip teethsegm.zip
5.2.1.3. 前処理#
本節では、歯の番号を区別せずにインスタンスセグメンテーションを行います。そのために、前述のデータセットをインスタンスセグメンテーションの学習に利用できる形式に変換する処理(TeethDataset
)を定義します。
Teeth Segmentation on dental X-ray images データセットに含まれるマスク画像は RGB 画像で、例えば画素値が (1, 1, 1) の部分が 13 番目の歯、(2, 2, 2) の部分が 12 番目の歯に対応しています。TeethDataset
では、この RGB 画像を基に、まず画素値が (1, 1, 1) の部分を判定して 1 枚のバイナリマスクを作成し、次に画素値が (2, 2, 2) の部分について同様にバイナリマスクを作成します。この操作を繰り返し、32 本の歯に対応するマスクを生成します((mask == labels[:, None, None]).to(dtype=torch.uint8)
)。さらに、画像内に該当する歯が存在しない場合、その歯に対応するマスクの画素値はすべて 0 になります。このような無効なマスクを削除する処理(masks[has_tooth]
)も実装しています。最後に、生成したマスクを PyTorch で扱える形式に変換する処理を行います。
class TeethDataset(torch.utils.data.Dataset):
def __init__(self, root):
self.root = root
self.images, self.masks = self.__load_datasets(self.root)
self.transforms = torchvision.transforms.v2.Compose([
torchvision.transforms.v2.ToDtype(torch.float, scale=True),
torchvision.transforms.v2.ToPureTensor()
])
def __getitem__(self, idx):
image = torchvision.io.read_image(self.images[idx])
mask = torchvision.io.read_image(self.masks[idx])
# create labels, masks, bboxes for training from the original mask
labels = torch.tensor([_ for _ in range(1, 33)])
masks = (mask == labels[:, None, None]).to(dtype=torch.uint8)
has_tooth = [_.sum() > 0 for _ in masks]
labels = labels[has_tooth]
masks = masks[has_tooth]
boxes = torchvision.ops.boxes.masks_to_boxes(masks)
# convert teeth number to 1 (ignore the teeth number)
labels = torch.ones((len(labels), ), dtype=torch.int64)
# format image and annotation for training
image = torchvision.tv_tensors.Image(image)
target = {
'boxes': torchvision.tv_tensors.BoundingBoxes(boxes, format='XYXY', canvas_size=torchvision.transforms.v2.functional.get_size(image)),
'masks': torchvision.tv_tensors.Mask(masks),
'labels': labels,
}
image, target = self.transforms(image, target)
return image, target
def __len__(self):
return len(self.images)
def __load_datasets(self, root):
images = []
masks = []
for image_fpath in os.listdir(os.path.join(root, 'images')):
image_fpath = os.path.join(root, 'images', image_fpath)
if os.path.splitext(image_fpath)[1] == '.jpg':
image_fname = os.path.basename(image_fpath)
mask_fpath = os.path.join(root, 'masks', os.path.splitext(image_fname)[0] + '.png')
if os.path.exists(mask_fpath):
images.append(image_fpath)
masks.append(mask_fpath)
return images, masks
通常、モデルの訓練では、データ拡張として画像の拡大縮小、平行移動、回転などを適用することが一般的です。しかし、歯のレントゲン画像の場合、過度なデータ拡張を行うと、本来の画像情報から逸脱し、モデルの学習に悪影響を及ぼす可能性があります。そのため、適切なデータ拡張手法を慎重に選択することが重要です。
5.2.1.4. 計算デバイス#
PyTorch が GPU を認識できる場合は GPU を利用し、認識できない場合は CPU を利用するように計算デバイスを設定します。
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
5.2.2. モデル構築#
Mask R-CNN は高性能なインスタンスセグメンテーションを行うアーキテクチャで、torchvision.models モジュールに含まれています。しかし、torchvision.models で提供されている Mask R-CNN は、COCO データセット向けに設計されており、車や人などの 90 種類の一般的なオブジェクトを対象としています。そのため、そのままでは歯のインスタンスセグメンテーションに適用することができません。
歯のセグメンテーションに対応させるためには、Mask R-CNN の分類モジュールの出力層のユニット数を修正する必要があります。この修正はモデルを呼び出すたびに行う必要があり、手間がかかります。そこで、指定したクラス数に応じてアーキテクチャを生成し、必要に応じて修正を加えられるように、一連の処理を関数として定義します。なお、インスタンスセグメンテーションは物体検出と同様に背景を 1 つのクラスとして扱うため、出力層の数を修正する際には、検出対象のクラス数に 1 を加えた数にする必要があります。
def maskrcnn(num_classes, weights=None):
num_classes = num_classes + 1
model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights='DEFAULT')
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
model.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(in_features_mask, 256, num_classes)
if weights is not None:
model.load_state_dict(torch.load(weights))
return model
model = maskrcnn(num_classes=1)
model.to(device)
Show code cell output
MaskRCNN(
(transform): GeneralizedRCNNTransform(
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
Resize(min_size=(800,), max_size=1333, mode='bilinear')
)
(backbone): BackboneWithFPN(
(body): IntermediateLayerGetter(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): FrozenBatchNorm2d(64, eps=0.0)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): Bottleneck(
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(64, eps=0.0)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(64, eps=0.0)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(256, eps=0.0)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): FrozenBatchNorm2d(256, eps=0.0)
)
)
(1): Bottleneck(
(conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(64, eps=0.0)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(64, eps=0.0)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(256, eps=0.0)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(64, eps=0.0)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(64, eps=0.0)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(256, eps=0.0)
(relu): ReLU(inplace=True)
)
)
(layer2): Sequential(
(0): Bottleneck(
(conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(128, eps=0.0)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(128, eps=0.0)
(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(512, eps=0.0)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): FrozenBatchNorm2d(512, eps=0.0)
)
)
(1): Bottleneck(
(conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(128, eps=0.0)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(128, eps=0.0)
(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(512, eps=0.0)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(128, eps=0.0)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(128, eps=0.0)
(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(512, eps=0.0)
(relu): ReLU(inplace=True)
)
(3): Bottleneck(
(conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(128, eps=0.0)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(128, eps=0.0)
(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(512, eps=0.0)
(relu): ReLU(inplace=True)
)
)
(layer3): Sequential(
(0): Bottleneck(
(conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(256, eps=0.0)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(256, eps=0.0)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(1024, eps=0.0)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): FrozenBatchNorm2d(1024, eps=0.0)
)
)
(1): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(256, eps=0.0)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(256, eps=0.0)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(1024, eps=0.0)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(256, eps=0.0)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(256, eps=0.0)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(1024, eps=0.0)
(relu): ReLU(inplace=True)
)
(3): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(256, eps=0.0)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(256, eps=0.0)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(1024, eps=0.0)
(relu): ReLU(inplace=True)
)
(4): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(256, eps=0.0)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(256, eps=0.0)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(1024, eps=0.0)
(relu): ReLU(inplace=True)
)
(5): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(256, eps=0.0)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(256, eps=0.0)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(1024, eps=0.0)
(relu): ReLU(inplace=True)
)
)
(layer4): Sequential(
(0): Bottleneck(
(conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(512, eps=0.0)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(512, eps=0.0)
(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(2048, eps=0.0)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): FrozenBatchNorm2d(2048, eps=0.0)
)
)
(1): Bottleneck(
(conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(512, eps=0.0)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(512, eps=0.0)
(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(2048, eps=0.0)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(512, eps=0.0)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(512, eps=0.0)
(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(2048, eps=0.0)
(relu): ReLU(inplace=True)
)
)
)
(fpn): FeaturePyramidNetwork(
(inner_blocks): ModuleList(
(0): Conv2dNormActivation(
(0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
)
(1): Conv2dNormActivation(
(0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
)
(2): Conv2dNormActivation(
(0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
)
(3): Conv2dNormActivation(
(0): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
)
)
(layer_blocks): ModuleList(
(0-3): 4 x Conv2dNormActivation(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
(extra_blocks): LastLevelMaxPool()
)
)
(rpn): RegionProposalNetwork(
(anchor_generator): AnchorGenerator()
(head): RPNHead(
(conv): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
)
)
(cls_logits): Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1))
(bbox_pred): Conv2d(256, 12, kernel_size=(1, 1), stride=(1, 1))
)
)
(roi_heads): RoIHeads(
(box_roi_pool): MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=(7, 7), sampling_ratio=2)
(box_head): TwoMLPHead(
(fc6): Linear(in_features=12544, out_features=1024, bias=True)
(fc7): Linear(in_features=1024, out_features=1024, bias=True)
)
(box_predictor): FastRCNNPredictor(
(cls_score): Linear(in_features=1024, out_features=2, bias=True)
(bbox_pred): Linear(in_features=1024, out_features=8, bias=True)
)
(mask_roi_pool): MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=(14, 14), sampling_ratio=2)
(mask_head): MaskRCNNHeads(
(0): Conv2dNormActivation(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
)
(1): Conv2dNormActivation(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
)
(2): Conv2dNormActivation(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
)
(3): Conv2dNormActivation(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
)
)
(mask_predictor): MaskRCNNPredictor(
(conv5_mask): ConvTranspose2d(256, 256, kernel_size=(2, 2), stride=(2, 2))
(relu): ReLU(inplace=True)
(mask_fcn_logits): Conv2d(256, 2, kernel_size=(1, 1), stride=(1, 1))
)
)
)
5.2.3. モデル訓練#
モデルが学習データを効率的に学習できるように、学習アルゴリズム(optimizer
)、学習率(lr
)、および学習率を調整するスケジューラ(lr_scheduler
)を設定します。なお、インスタンスセグメンテーションでは、すべてのピクセルに対して予測結果(分類結果)とラベルの誤差を計算する損失関数を定義する必要がありますが、この損失関数はすでにモデル内で定義されているため、ここで新たに定義する必要はありません。
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
次に、訓練サブセットおよび検証サブセットを読み込み、モデルに入力できる形式に整えます。
train_loader = torch.utils.data.DataLoader(
TeethDataset('teethsegm/train'),
batch_size=4, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))
valid_loader = torch.utils.data.DataLoader(
TeethDataset('teethsegm/valid'),
batch_size=4, shuffle=False, collate_fn=lambda x: tuple(zip(*x)))
準備が整ったら、訓練を開始します。訓練プロセスでは、訓練と検証を交互に繰り返します。訓練では、訓練データを使ってモデルのパラメータを更新し、その際の損失(誤差)を記録します。検証では、検証データを使ってモデルの予測性能(mAP や IoU)を計算し、その結果を記録します。
num_epochs = 10
metric_dict = []
for epoch in range(num_epochs):
# training phase
model.train()
epoch_loss_dict = {}
for images, targets in train_loader:
images = [img.to(device) for img in images]
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
batch_loss_dict = model(images, targets)
batch_tol_loss = 0
for loss_type, loss_val in batch_loss_dict.items():
batch_tol_loss += loss_val
if loss_type in epoch_loss_dict:
epoch_loss_dict[f'train_{loss_type}'] += loss_val.item()
else:
epoch_loss_dict[f'train_{loss_type}'] = loss_val.item()
# update weights
optimizer.zero_grad()
batch_tol_loss.backward()
optimizer.step()
lr_scheduler.step()
# validation phase
model.eval()
metric = torchmetrics.detection.mean_ap.MeanAveragePrecision()
iou = 0
n_targets = 0
with torch.no_grad():
for images, targets in valid_loader:
images = [img.to(device) for img in images]
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
pred_targets = model(images)
metric.update(pred_targets, targets)
for i in range(len(targets)):
pred_mask = pred_targets[i]['masks'].squeeze(1).any(dim=0)
true_mask = targets[i]['masks'].any(dim=0)
iou += torchmetrics.functional.jaccard_index(pred_mask.unsqueeze(0), true_mask.unsqueeze(0), num_classes=1, task='binary')
n_targets += 1
# record training loss
epoch_loss_dict['train_total_loss'] = sum(epoch_loss_dict.values())
metric_dict.append({k: v / len(train_loader) for k, v in epoch_loss_dict.items()})
for k, v in metric.compute().items():
if k != 'classes':
metric_dict[-1][k] = v.item()
metric_dict[-1]['avg_iou'] = iou.item() / n_targets
metric_dict[-1]['epoch'] = epoch + 1
print(metric_dict[-1])
Show code cell output
{'train_loss_classifier': 0.011489534378051757, 'train_loss_box_reg': 0.03191184997558594, 'train_loss_mask': 0.021380738417307536, 'train_loss_objectness': 0.0016914131740729014, 'train_loss_rpn_box_reg': 0.0030600666999816895, 'train_total_loss': 0.06953360264499982, 'map': 0.5444761514663696, 'map_50': 0.9737210273742676, 'map_75': 0.5243497490882874, 'map_small': -1.0, 'map_medium': 0.37486690282821655, 'map_large': 0.5558559894561768, 'mar_1': 0.024163568392395973, 'mar_10': 0.25315985083580017, 'mar_100': 0.6152416467666626, 'mar_small': -1.0, 'mar_medium': 0.47058823704719543, 'mar_large': 0.625, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'avg_iou': 0.4445939540863037, 'epoch': 1}
{'train_loss_classifier': 0.010966811577479045, 'train_loss_box_reg': 0.024358453353246053, 'train_loss_mask': 0.016817092895507812, 'train_loss_objectness': 0.000496666847417752, 'train_loss_rpn_box_reg': 0.002907271683216095, 'train_total_loss': 0.05554629635686676, 'map': 0.5674804449081421, 'map_50': 0.9768766164779663, 'map_75': 0.6191827058792114, 'map_small': -1.0, 'map_medium': 0.4134511649608612, 'map_large': 0.5766996145248413, 'mar_1': 0.02490706369280815, 'mar_10': 0.2587360739707947, 'mar_100': 0.6368029713630676, 'mar_small': -1.0, 'mar_medium': 0.5411764979362488, 'mar_large': 0.6432539820671082, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'avg_iou': 0.5404828071594239, 'epoch': 2}
{'train_loss_classifier': 0.009296949704488118, 'train_loss_box_reg': 0.022947633266448976, 'train_loss_mask': 0.016431063413619995, 'train_loss_objectness': 0.00032265024880568185, 'train_loss_rpn_box_reg': 0.002596474935611089, 'train_total_loss': 0.05159477156897386, 'map': 0.5910195112228394, 'map_50': 0.988142728805542, 'map_75': 0.6513012051582336, 'map_small': -1.0, 'map_medium': 0.45653799176216125, 'map_large': 0.598625659942627, 'mar_1': 0.022676579654216766, 'mar_10': 0.2643122673034668, 'mar_100': 0.6605948209762573, 'mar_small': -1.0, 'mar_medium': 0.5588235259056091, 'mar_large': 0.6674603223800659, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'avg_iou': 0.6339510440826416, 'epoch': 3}
{'train_loss_classifier': 0.009638185302416483, 'train_loss_box_reg': 0.018773547808329263, 'train_loss_mask': 0.014078843593597411, 'train_loss_objectness': 0.0002855499275028706, 'train_loss_rpn_box_reg': 0.0017421250542004904, 'train_total_loss': 0.04451825168604652, 'map': 0.6261765360832214, 'map_50': 0.9887192845344543, 'map_75': 0.7262885570526123, 'map_small': -1.0, 'map_medium': 0.5095721483230591, 'map_large': 0.6344967484474182, 'mar_1': 0.02565055713057518, 'mar_10': 0.2773234248161316, 'mar_100': 0.6881040930747986, 'mar_small': -1.0, 'mar_medium': 0.5647059082984924, 'mar_large': 0.6964285969734192, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'avg_iou': 0.6112461566925049, 'epoch': 4}
{'train_loss_classifier': 0.008092211186885833, 'train_loss_box_reg': 0.017672230799992878, 'train_loss_mask': 0.014983957012494406, 'train_loss_objectness': 0.0001613290049135685, 'train_loss_rpn_box_reg': 0.002534709870815277, 'train_total_loss': 0.043444437875101966, 'map': 0.6290183663368225, 'map_50': 0.9883098006248474, 'map_75': 0.7401853203773499, 'map_small': -1.0, 'map_medium': 0.4878535866737366, 'map_large': 0.6373701691627502, 'mar_1': 0.02490706369280815, 'mar_10': 0.28141263127326965, 'mar_100': 0.6929367780685425, 'mar_small': -1.0, 'mar_medium': 0.5764706134796143, 'mar_large': 0.7007936239242554, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'avg_iou': 0.6161362171173096, 'epoch': 5}
{'train_loss_classifier': 0.006792613367239634, 'train_loss_box_reg': 0.01617384652296702, 'train_loss_mask': 0.012900521357854208, 'train_loss_objectness': 8.660277817398309e-05, 'train_loss_rpn_box_reg': 0.001977264260252317, 'train_total_loss': 0.037930848286487164, 'map': 0.6377891898155212, 'map_50': 0.9888526201248169, 'map_75': 0.7507622838020325, 'map_small': -1.0, 'map_medium': 0.48388487100601196, 'map_large': 0.6465470194816589, 'mar_1': 0.026394052430987358, 'mar_10': 0.2855018675327301, 'mar_100': 0.6959107518196106, 'mar_small': -1.0, 'mar_medium': 0.5352941155433655, 'mar_large': 0.7067460417747498, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'avg_iou': 0.6109446048736572, 'epoch': 6}
{'train_loss_classifier': 0.009918662905693054, 'train_loss_box_reg': 0.021767306327819824, 'train_loss_mask': 0.01655222276846568, 'train_loss_objectness': 0.00015166027781864008, 'train_loss_rpn_box_reg': 0.0019014857709407807, 'train_total_loss': 0.050291338050737974, 'map': 0.633118748664856, 'map_50': 0.9888139367103577, 'map_75': 0.7573145031929016, 'map_small': -1.0, 'map_medium': 0.46705272793769836, 'map_large': 0.6426718831062317, 'mar_1': 0.026394052430987358, 'mar_10': 0.28438660502433777, 'mar_100': 0.6933085322380066, 'mar_small': -1.0, 'mar_medium': 0.5529412031173706, 'mar_large': 0.7027778029441833, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'avg_iou': 0.6191488265991211, 'epoch': 7}
{'train_loss_classifier': 0.006997281809647878, 'train_loss_box_reg': 0.018330061435699464, 'train_loss_mask': 0.014295052488644917, 'train_loss_objectness': 8.228952065110206e-05, 'train_loss_rpn_box_reg': 0.0016385612388451895, 'train_total_loss': 0.04134324649348855, 'map': 0.6294244527816772, 'map_50': 0.9886727929115295, 'map_75': 0.7531455755233765, 'map_small': -1.0, 'map_medium': 0.47531214356422424, 'map_large': 0.6381085515022278, 'mar_1': 0.027137545868754387, 'mar_10': 0.282527893781662, 'mar_100': 0.6918215751647949, 'mar_small': -1.0, 'mar_medium': 0.5647059082984924, 'mar_large': 0.7003968358039856, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'avg_iou': 0.6212560176849365, 'epoch': 8}
{'train_loss_classifier': 0.009105870127677917, 'train_loss_box_reg': 0.019151902198791503, 'train_loss_mask': 0.014871043960253398, 'train_loss_objectness': 0.00025804553491373856, 'train_loss_rpn_box_reg': 0.0020078005890051525, 'train_total_loss': 0.04539466241064171, 'map': 0.6337465047836304, 'map_50': 0.9886727929115295, 'map_75': 0.7454661726951599, 'map_small': -1.0, 'map_medium': 0.47531214356422424, 'map_large': 0.6423608064651489, 'mar_1': 0.027137545868754387, 'mar_10': 0.28364312648773193, 'mar_100': 0.6940520405769348, 'mar_small': -1.0, 'mar_medium': 0.5647059082984924, 'mar_large': 0.7027778029441833, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'avg_iou': 0.6193825244903565, 'epoch': 9}
{'train_loss_classifier': 0.007122074564297994, 'train_loss_box_reg': 0.016642950971921287, 'train_loss_mask': 0.014895657698313395, 'train_loss_objectness': 0.0002251185787220796, 'train_loss_rpn_box_reg': 0.0017238457997639973, 'train_total_loss': 0.04060964761301875, 'map': 0.6320429444313049, 'map_50': 0.9886727929115295, 'map_75': 0.7427568435668945, 'map_small': -1.0, 'map_medium': 0.47531214356422424, 'map_large': 0.6406862735748291, 'mar_1': 0.027137545868754387, 'mar_10': 0.28438660502433777, 'mar_100': 0.6933085322380066, 'mar_small': -1.0, 'mar_medium': 0.5647059082984924, 'mar_large': 0.7019841074943542, 'map_per_class': -1.0, 'mar_100_per_class': -1.0, 'avg_iou': 0.619456672668457, 'epoch': 10}
訓練データに対する損失と検証データに対する予測性能(mAP)を可視化し、訓練過程を評価します。

可視化の結果から、エポック数が増えるにつれて訓練データに対する損失が継続的に減少し、5 エポック目から収束し始める傾向がみられました。一方で、検証データに対する検証性能(mAP および IoU)は最初の数エポックですでに高い値に達しいることがわかりました。訓練ではこれで十分と考えらえるので次のステップに進みます。
次に、この手順を他の深層ニューラルネットワークのアーキテクチャ(例えば U-Net など)に適用し、それぞれの検証性能を比較します。この比較により、データセットに最も適したアーキテクチャを選定します。ただし、本節はこの比較を行わずに、Mask R-CNN を最適なアーキテクチャとして扱い、次のステップに進みます。
次のステップでは、訓練サブセットと検証サブセットを統合し、最適と判断したアーキテクチャを最初から訓練します。その準備として、まず訓練サブセットと検証サブセットを結合します。
!mkdir -p teethsegm/trainvalid/images
!mkdir -p teethsegm/trainvalid/masks
!cp teethsegm/train/images/* teethsegm/trainvalid/images
!cp teethsegm/valid/images/* teethsegm/trainvalid/images
!cp teethsegm/train/masks/* teethsegm/trainvalid/masks
!cp teethsegm/valid/masks/* teethsegm/trainvalid/masks
次に、モデルの構築を行います。先ほど可視化した検証性能の推移グラフを確認した結果、数エポックの訓練で十分に高い予測性能を達成できることがわかりました。そこで、ここでは訓練サブセットと検証サブセットを統合したデータを用いて、5 エポックのみ訓練を行います。
# model
model = maskrcnn(num_classes=1)
model.to(device)
# training parameters
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
# training data
train_loader = torch.utils.data.DataLoader(
TeethDataset('teethsegm/trainvalid'),
batch_size=4, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))
# training
num_epochs = 5
metric_dict = []
for epoch in range(num_epochs):
model.train()
epoch_loss_dict = {}
for images, targets in train_loader:
images = [img.to(device) for img in images]
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
batch_loss_dict = model(images, targets)
batch_tol_loss = 0
for loss_type, loss_val in batch_loss_dict.items():
batch_tol_loss += loss_val
if loss_type in epoch_loss_dict:
epoch_loss_dict[f'train_{loss_type}'] += loss_val.item()
else:
epoch_loss_dict[f'train_{loss_type}'] = loss_val.item()
optimizer.zero_grad()
batch_tol_loss.backward()
optimizer.step()
lr_scheduler.step()
# record training loss
epoch_loss_dict['train_total_loss'] = sum(epoch_loss_dict.values())
metric_dict.append({k: v / len(train_loader) for k, v in epoch_loss_dict.items()})
metric_dict[-1]['epoch'] = epoch + 1
print(metric_dict[-1])
Show code cell output
{'train_loss_classifier': 0.0283333592944675, 'train_loss_box_reg': 0.046519974867502846, 'train_loss_mask': 0.025429550144407485, 'train_loss_objectness': 0.0019212358941634495, 'train_loss_rpn_box_reg': 0.005258884280920029, 'train_total_loss': 0.10746300448146132, 'epoch': 1}
{'train_loss_classifier': 0.024011474516656663, 'train_loss_box_reg': 0.04263102345996433, 'train_loss_mask': 0.021300279431872897, 'train_loss_objectness': 0.0028616411404477227, 'train_loss_rpn_box_reg': 0.0029303785413503647, 'train_total_loss': 0.09373479709029198, 'epoch': 2}
{'train_loss_classifier': 0.01876573430167304, 'train_loss_box_reg': 0.0398339761628045, 'train_loss_mask': 0.019809636804792616, 'train_loss_objectness': 0.0023387237969372007, 'train_loss_rpn_box_reg': 0.004181647466288673, 'train_total_loss': 0.08492971853249603, 'epoch': 3}
{'train_loss_classifier': 0.020585997237099543, 'train_loss_box_reg': 0.04305836227205065, 'train_loss_mask': 0.021457746624946594, 'train_loss_objectness': 0.003746054238743252, 'train_loss_rpn_box_reg': 0.003297751769423485, 'train_total_loss': 0.09214591214226352, 'epoch': 4}
{'train_loss_classifier': 0.01937776803970337, 'train_loss_box_reg': 0.04037020603815714, 'train_loss_mask': 0.019126496381229825, 'train_loss_objectness': 0.00240055699315336, 'train_loss_rpn_box_reg': 0.003422745813926061, 'train_total_loss': 0.08469777326616976, 'epoch': 5}
訓練が完了したら、訓練済みモデルの重みをファイルに保存します。
model.to('cpu')
torch.save(model.state_dict(), 'teethsegm.pth')
5.2.4. モデル評価#
最適なモデルが得られたら、次にテストデータを用いてモデルを詳細に評価します。
test_loader = torch.utils.data.DataLoader(
TeethDataset('teethsegm/test'),
batch_size=4, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))
model = maskrcnn(num_classes=1, weights='teethsegm.pth')
model.to(device)
model.eval()
metric = torchmetrics.detection.mean_ap.MeanAveragePrecision()
iou = 0
n_targets = 0
with torch.no_grad():
for images, targets in test_loader:
images = [img.to(device) for img in images]
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
pred_targets = model(images)
metric.update(pred_targets, targets)
for i in range(len(targets)):
pred_mask = pred_targets[i]['masks'].squeeze(1).any(dim=0)
true_mask = targets[i]['masks'].any(dim=0)
iou += torchmetrics.functional.jaccard_index(pred_mask.unsqueeze(0), true_mask.unsqueeze(0), num_classes=1, task='binary')
n_targets += 1
metrics = [{k: v.cpu().detach().numpy().tolist()} for k, v in metric.compute().items()]
metrics.append({'avg_iou': iou.item() / n_targets})
metrics
[{'map': 0.40518149733543396},
{'map_50': 0.8205856084823608},
{'map_75': 0.32684585452079773},
{'map_small': -1.0},
{'map_medium': 0.12812453508377075},
{'map_large': 0.4226655960083008},
{'mar_1': 0.020769231021404266},
{'mar_10': 0.2161538451910019},
{'mar_100': 0.5426923036575317},
{'mar_small': -1.0},
{'mar_medium': 0.4000000059604645},
{'mar_large': 0.5502024292945862},
{'map_per_class': -1.0},
{'mar_100_per_class': -1.0},
{'classes': 1},
{'avg_iou': 0.40285353660583495}]
5.2.5. 推論#
推論時にも、訓練時と同じように torchvision モジュールからアーキテクチャを呼び出し、出力層のクラス数を設定します。その後、load_state_dict
メソッドを使って、訓練済みの重みファイルをモデルにロードします。これらの操作はすべて maskrcnn
関数で定義されているので、その関数を利用します。
model = maskrcnn(num_classes=1, weights='teethsegm.pth')
model.to(device)
model.eval()
Show code cell output
MaskRCNN(
(transform): GeneralizedRCNNTransform(
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
Resize(min_size=(800,), max_size=1333, mode='bilinear')
)
(backbone): BackboneWithFPN(
(body): IntermediateLayerGetter(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): FrozenBatchNorm2d(64, eps=0.0)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): Bottleneck(
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(64, eps=0.0)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(64, eps=0.0)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(256, eps=0.0)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): FrozenBatchNorm2d(256, eps=0.0)
)
)
(1): Bottleneck(
(conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(64, eps=0.0)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(64, eps=0.0)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(256, eps=0.0)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(64, eps=0.0)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(64, eps=0.0)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(256, eps=0.0)
(relu): ReLU(inplace=True)
)
)
(layer2): Sequential(
(0): Bottleneck(
(conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(128, eps=0.0)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(128, eps=0.0)
(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(512, eps=0.0)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): FrozenBatchNorm2d(512, eps=0.0)
)
)
(1): Bottleneck(
(conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(128, eps=0.0)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(128, eps=0.0)
(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(512, eps=0.0)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(128, eps=0.0)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(128, eps=0.0)
(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(512, eps=0.0)
(relu): ReLU(inplace=True)
)
(3): Bottleneck(
(conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(128, eps=0.0)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(128, eps=0.0)
(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(512, eps=0.0)
(relu): ReLU(inplace=True)
)
)
(layer3): Sequential(
(0): Bottleneck(
(conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(256, eps=0.0)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(256, eps=0.0)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(1024, eps=0.0)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): FrozenBatchNorm2d(1024, eps=0.0)
)
)
(1): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(256, eps=0.0)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(256, eps=0.0)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(1024, eps=0.0)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(256, eps=0.0)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(256, eps=0.0)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(1024, eps=0.0)
(relu): ReLU(inplace=True)
)
(3): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(256, eps=0.0)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(256, eps=0.0)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(1024, eps=0.0)
(relu): ReLU(inplace=True)
)
(4): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(256, eps=0.0)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(256, eps=0.0)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(1024, eps=0.0)
(relu): ReLU(inplace=True)
)
(5): Bottleneck(
(conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(256, eps=0.0)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(256, eps=0.0)
(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(1024, eps=0.0)
(relu): ReLU(inplace=True)
)
)
(layer4): Sequential(
(0): Bottleneck(
(conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(512, eps=0.0)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(512, eps=0.0)
(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(2048, eps=0.0)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): FrozenBatchNorm2d(2048, eps=0.0)
)
)
(1): Bottleneck(
(conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(512, eps=0.0)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(512, eps=0.0)
(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(2048, eps=0.0)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d(512, eps=0.0)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d(512, eps=0.0)
(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d(2048, eps=0.0)
(relu): ReLU(inplace=True)
)
)
)
(fpn): FeaturePyramidNetwork(
(inner_blocks): ModuleList(
(0): Conv2dNormActivation(
(0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
)
(1): Conv2dNormActivation(
(0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
)
(2): Conv2dNormActivation(
(0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
)
(3): Conv2dNormActivation(
(0): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
)
)
(layer_blocks): ModuleList(
(0-3): 4 x Conv2dNormActivation(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
(extra_blocks): LastLevelMaxPool()
)
)
(rpn): RegionProposalNetwork(
(anchor_generator): AnchorGenerator()
(head): RPNHead(
(conv): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
)
)
(cls_logits): Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1))
(bbox_pred): Conv2d(256, 12, kernel_size=(1, 1), stride=(1, 1))
)
)
(roi_heads): RoIHeads(
(box_roi_pool): MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=(7, 7), sampling_ratio=2)
(box_head): TwoMLPHead(
(fc6): Linear(in_features=12544, out_features=1024, bias=True)
(fc7): Linear(in_features=1024, out_features=1024, bias=True)
)
(box_predictor): FastRCNNPredictor(
(cls_score): Linear(in_features=1024, out_features=2, bias=True)
(bbox_pred): Linear(in_features=1024, out_features=8, bias=True)
)
(mask_roi_pool): MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=(14, 14), sampling_ratio=2)
(mask_head): MaskRCNNHeads(
(0): Conv2dNormActivation(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
)
(1): Conv2dNormActivation(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
)
(2): Conv2dNormActivation(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
)
(3): Conv2dNormActivation(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
)
)
(mask_predictor): MaskRCNNPredictor(
(conv5_mask): ConvTranspose2d(256, 256, kernel_size=(2, 2), stride=(2, 2))
(relu): ReLU(inplace=True)
(mask_fcn_logits): Conv2d(256, 2, kernel_size=(1, 1), stride=(1, 1))
)
)
)
このモデルを利用して推論を行います。まず、1 枚の画像を指定し、PIL モジュールを用いて画像を開き、テンソル形式に変換した後、モデルに入力します。モデルは予測結果としてバウンディングボックスの座標(bboxes
)、マスク(masks
)、分類ラベル(labels
)、および信頼スコア(scores
)を出力します。ただし、信頼スコアが 0.5 未満のバウンディングボックスは採用せず、信頼スコアが高い結果のみを選択して利用します。
threshold = 0.5
image_path = 'teethsegm/test/images/13.jpg'
image = PIL.Image.open(image_path).convert('RGB')
input_tensor = torchvision.transforms.v2.functional.to_tensor(image).unsqueeze(0).to(device)
with torch.no_grad():
predictions = model(input_tensor)[0]
bboxes = predictions['boxes'][predictions['scores'] > threshold]
masks = predictions['masks'][predictions['scores'] > threshold]
labels = predictions['labels'][predictions['scores'] > threshold]
scores = predictions['scores'][predictions['scores'] > threshold]
検出されたオブジェクトのバウンディングボックスを入力画像に描画します。その後、PIL および matplotlib ライブラリを使用して、画像とその検出結果を可視化します。
image = PIL.Image.open(image_path).convert('RGBA')
overlay = PIL.Image.new('RGBA', image.size, (255, 255, 255, 0))
overlay_draw = PIL.ImageDraw.Draw(overlay)
for bbox, mask, label, score in zip(bboxes, masks, labels, scores):
col = tuple([random.randint(0, 255) for _ in range(3)]) + (128,)
# bbox
x1, y1, x2, y2 = bbox
draw = PIL.ImageDraw.Draw(image)
draw.rectangle(((x1, y1), (x2, y2)), outline=col[:3], width=3) # BBox is not transparent
draw.text((x1, y1 - 10), f'{label.item()} ({score:.2f})', fill=col[:3])
# mask
mask = mask.squeeze(0).cpu().numpy()
contours = skimage.measure.find_contours(mask, 0.5)
for contour in contours:
contour = np.flip(contour, axis=1).astype(int)
polygon = [tuple(point) for point in contour]
overlay_draw.polygon(polygon, fill=col)
blended_image = PIL.Image.alpha_composite(image, overlay)
fig = plt.figure()
ax = fig.add_subplot()
ax.imshow(blended_image)
ax.axis('off')
plt.show()
