
Spine Segmentation with UNet


As we collect more data every day, artificial intelligence (AI) is increasingly being applied in healthcare. One key application of AI in healthcare is diagnosis, where it helps with decision-making, management, automation, and more.

The spine is an essential part of the musculoskeletal system: it supports the body and its organ structures, and plays a major role in our mobility and in transferring loads. It also protects the spinal cord from injuries and mechanical shocks caused by impacts.

Vertebra labelling and segmentation are two fundamental tasks in an automated spine processing pipeline.

Reliable and accurate spine image processing is expected to assist clinical decision support systems for diagnosis, surgery planning, and population-based analysis of spine and bone health. Designing automated algorithms for spine processing is challenging, mainly because of the considerable variation in anatomy and acquisition protocols, and because of the severe shortage of publicly available data.

In this blog post, I will focus only on segmenting the spine in a given dataset of CT scans. Labelling every single vertebra and performing further diagnosis are not covered here and can be done as a continuation of this task.

Spine (vertebral) segmentation is a crucial step in all applications that automatically quantify spinal morphology and pathology.

With the advent of deep learning, large and diverse data has become a primary, sought-after resource for a task like this on computed tomography (CT) scans. However, a large-scale public dataset had long been missing.

VerSe is a large, multi-detector, multi-site CT spine dataset consisting of 374 scans from 355 patients. There is a dataset for both 2019 and 2020. In this post, I merge the two datasets into one to benefit from more data.

The data is provided under the CC BY-SA 4.0 license, so it is fully open source.

NIfTI (Neuroimaging Informatics Technology Initiative) is a file format for neuroimaging. NIfTI files are very commonly used in imaging informatics for neuroscience and even neuroradiology research. Each NIfTI file can contain data with up to 7 dimensions, plus metadata, and supports multiple data types.

The first three dimensions define the three spatial axes x, y, and z, while the fourth dimension defines time points t. The remaining dimensions (the fifth through the seventh) are reserved for other uses; the fifth dimension, however, can still have predefined uses, such as storing voxel-specific distribution parameters or holding vector-based data.
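As a quick illustration, here is a minimal sketch of inspecting those dimensions with NiBabel (the file path is a hypothetical local VerSe scan):

import nibabel as nib

# hypothetical path to a locally downloaded VerSe scan
img = nib.load("sub-verse521_dir-ax_ct.nii.gz")
print(img.shape)                     # spatial shape, e.g. (x, y, z) for a 3D CT volume
print(img.header['dim'])             # first entry = number of dimensions used, then the sizes of dims 1-7
print(img.header.get_data_dtype())   # data type of the stored voxels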

ITK-SNAP is a software application for segmenting structures in 3D medical images. It is open-source software that can be installed on different platforms. I used it to visualize NIfTI files in 3D views and to load and overlay 3D masks on top of the raw images. I highly recommend it for this task.

Computed tomography (CT) is an X-ray imaging procedure in which X-rays are aimed at the patient and rotated rapidly around the body. The signals collected by the machine are stored in a computer to generate cross-sectional images of the body, also known as "slices".

These slices are called tomographic images and contain more detailed information than conventional X-rays. A series of slices can be digitally "stacked" together to form a 3D image of the patient, which makes it easier to identify and locate basic structures as well as possible tumors or abnormalities.

The steps are as follows. First, download the 2019 and 2020 datasets.

Then, merge the two datasets into common training, validation, and test folders. The next step is to read the CT-scan images and convert each slice of every scan into a series of PNG raw images and masks. Later, I took the UNet model from this GitHub repository and trained it. (A sketch of the folder-merging step is shown below; the rest is covered by the full code further down.)
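The merging step is not part of the code listing below, so here is only a minimal sketch of what it could look like, assuming the 2019 and 2020 archives were extracted into hypothetical folders named verse19/ and verse20/ that share the same rawdata/derivatives layout:

import glob
import os
import shutil

# hypothetical source folders for the two releases, and one merged target folder
sources = ["./verse19/dataset-verse19training", "./verse20/dataset-verse20training"]
target = "./verse_19_20_training"

for source in sources:
    for sub_dir in ["rawdata", "derivatives"]:
        for file_path in glob.glob(os.path.join(source, sub_dir, "*", "*")):
            # keep the per-subject folder structure while copying into the merged folder
            relative_path = os.path.relpath(file_path, source)
            destination = os.path.join(target, relative_path)
            os.makedirs(os.path.dirname(destination), exist_ok=True)
            shutil.copy2(file_path, destination)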

Data understanding: Before starting data processing and training, I wanted to load a few NIfTI files to become more familiar with their 3D data structure, to be able to visualize them, and to extract metadata from the images.

After downloading the VerSe datasets, I opened one .nii.gz file. By reading a single file and looking at one particular slice of the CT scan, I was able to use NumPy's transpose function to view that slice in three different views: axial, sagittal, and coronal.

After getting more familiar with the raw images and being able to extract a single slice from the raw 3D image, it was time to look at the mask file for the same slice.

As you can see in the image below, the mask slice can be overlaid on top of the raw image slice. The reason we see a gradient of colors here is that the mask files not only contain the defined region of each vertebra, they also carry different labels (shown in different colors), i.e. the number or label of each vertebra. For a better understanding of vertebra labelling, you can refer to this page. (A quick way to list the labels present in a mask is sketched below.)
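To see exactly which vertebra labels are present in a mask, one quick check (a minimal sketch, assuming a hypothetical local mask path) is to list the unique values of the mask volume:

import numpy as np
import nibabel as nib

# hypothetical path to a locally downloaded VerSe mask file
mask_data = nib.load("sub-verse521_dir-ax_seg-vert_msk.nii.gz").get_fdata()
labels = np.unique(mask_data)
print(labels)  # 0 is the background; every other value is the label of one vertebra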

Data preparation:

The data preparation task is to generate image slices from every 3D CT-scan file, for both the raw images and the mask files.

It starts by reading the raw and mask images in '.nii.gz' format with the NiBabel library and converting them into NumPy arrays. It then goes through each 3D image, checks its viewing orientation, and tries to convert most of the images to the sagittal view.

Next, I generate a PNG file from each slice and store it in 'L' mode, i.e. as grayscale values. For this task we do not need to generate RGB images.

For this task, the UNet architecture is used so that semantic segmentation can be applied to the dataset. For a better understanding of UNet and semantic segmentation, I recommend this blog post.

PyTorch and torchvision were used for this task. As mentioned, this repository provides a good PyTorch implementation of UNet, and I have used some of its code.

Since we are dealing with NIfTI files, and to be able to read them in Python, the NiBabel library is used. NiBabel is a Python library for reading and writing some common medical and neuroimaging file formats, such as NIfTI files.

Dice score: To evaluate how well our model performs on the semantic segmentation task, we can use the Dice score. The Dice coefficient is 2 times the area of overlap (between the predicted mask region and the ground-truth mask region) divided by the total number of pixels in both images.
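As a small worked example (two hypothetical 3x3 binary masks), the Dice coefficient can be computed directly with NumPy:

import numpy as np

# hypothetical predicted and ground-truth binary masks
pred = np.array([[0, 1, 1],
                 [0, 1, 0],
                 [0, 0, 0]])
truth = np.array([[0, 1, 1],
                  [0, 1, 0],
                  [0, 1, 0]])

overlap = np.logical_and(pred == 1, truth == 1).sum()   # 3 overlapping pixels
dice = 2.0 * overlap / (pred.sum() + truth.sum())       # 2*3 / (3 + 4) ≈ 0.857
print(dice)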

Training: First I defined the UNet class, then the PyTorch Dataset class, which includes reading and preprocessing the images. Preprocessing consists of loading the PNG files, resizing them all to one size (250x250 in this case), converting them to NumPy arrays, and then converting those to PyTorch tensors.

By calling the dataset class (VerSeDataset) we can prepare the data within the batches I defined. To make sure the mapping between the raw images and the mask images is correct, I call next(iter(valid_dataloader)) to get the next item in a batch and visualize it.

Later, the model is defined as model = UNet(n_channels=1, n_classes=1). The number of channels is 1 because the images are grayscale rather than RGB; if your images are RGB, you can change n_channels to 3. The number of classes is 1 because there is only a single class deciding whether a pixel is part of the spine or not. If your problem is multi-class segmentation, you can set the number of classes to however many classes you have (see the sketch below).
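For example (a hypothetical sketch of how the constructor arguments would change; only the first variant is used in this post):

model = UNet(n_channels=1, n_classes=1)    # grayscale input, single spine / not-spine class (used here)
# model = UNet(n_channels=3, n_classes=1)  # RGB input, still a single foreground class
# model = UNet(n_channels=1, n_classes=5)  # grayscale input, hypothetical 5-class segmentation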

Later, the model is trained. For each batch, the loss value is computed first and the parameters are updated through backpropagation. Then all the batches are passed through again, this time only computing the loss on the validation dataset and storing the loss values. Next, the train and validation loss values are plotted to keep track of the model's performance.

After saving the model, I could take one of the images, pass it to the trained model, and receive a predicted mask image. By plotting the three images side by side, the raw image, the true mask, and the predicted mask, the results can be evaluated visually.

As the figure above shows, the model performs very well on both the sagittal and axial views, since the predicted masks are very similar to the ground-truth mask regions.

The full code:

Author: Mazi Boustani

Date: December 24, 2021

Purpose: Train a UNet model with PyTorch to segment the spine using the VerSe dataset

import numpy as np

import pandas as pd

import os

from os import listdir

from os.path import splitext

import glob

import shutil

import random

from pathlib import Path

from PIL import Image

from tqdm import tqdm

import matplotlib.pyplot as plt

%matplotlib inline

try:
    import nibabel as nib
except:
    raise ImportError('Install NIBABEL')

import torch

import torch.nn as nn

from torch import Tensor

import torch.nn.functional as F

from torch import optim

import torchvision.transforms as T

from torch.utils.data import DataLoader, random_split

from torch.utils.data import Dataset

# set folder paths for train and validation data

data_folder_path = "/Users/mazi/Projects/other/CT/data"

train_data = data_folder_path + "/verse_19_20_training/"

validation_data = data_folder_path + "/verse_19_20_validation/"

Data understanding

# get one image to load

train_data_raw_image = train_data + "/rawdata/sub-verse521/sub-verse521_dir-ax_ct.nii.gz"

one_image = nib.load(train_data_raw_image)

# look at image shape

print(one_image.shape)

# look at image header. To understand header please refer to: https://brainder.org/2012/09/23/the-nifti-file-format/

print(one_image.header)

# look at the raw data

one_image_data = one_image.get_fdata()

print(one_image_data)

# Visualize one image in three different angles

one_image_data_axial = one_image_data

# change the view

one_image_data_sagittal = np.transpose(one_image_data, [2,1,0])

one_image_data_sagittal = np.flip(one_image_data_sagittal, axis=0)

# change the view

one_image_data_coronal = np.transpose(one_image_data, [2,0,1])

one_image_data_coronal = np.flip(one_image_data_coronal, axis=0)

fig, ax = plt.subplots(1, 3, figsize = (60, 60))

ax[0].imshow(one_image_data_axial[:,:,10], cmap ='bone')

ax[0].set_title("Axial view", fontsize=60)

ax[1].imshow(one_image_data_sagittal[:,:,260], cmap ='bone')

ax[1].set_title("Sagittal view", fontsize=60)

ax[2].imshow(one_image_data_coronal[:,:,200], cmap ='bone')

ax[2].set_title("Coronal view", fontsize=60)

plt.show()

# Overlay a mask on top of raw image (one slice of CT-scan)

train_data_mask_image = train_data + "derivatives/sub-verse521/sub-verse521_dir-ax_seg-vert_msk.nii.gz"

train_data_mask_image = nib.load(train_data_mask_image).get_fdata()

plt.figure(figsize=(10,10))

rotated_raw = np.transpose(one_image_data, [2,1,0])

rotated_raw = np.flip(rotated_raw, axis=0)

plt.imshow(rotated_raw[:,:,260], cmap ='bone', interpolation='none')

train_data_mask_image[train_data_mask_image == 0 ] = np.nan

rotated_mask = np.transpose(train_data_mask_image, [2,1,0])

rotated_mask = np.flip(rotated_mask, axis=0)

plt.imshow(rotated_mask[:,:,260], cmap ='cool')

Preprocess the data

# Set paths to store processed train and validation raw images and masks

processed_train = "./processed_train/"

processed_validation = "./processed_validation/"

processed_train_raw_images = processed_train + "raw_images/"

processed_train_masks = processed_train + "masks/"

processed_validation_raw_images = processed_validation + "raw_images/"

processed_validation_masks = processed_validation + "masks/"

# Read all 2019 and 2020 raw files, both train and validation

raw_train_files = glob.glob(os.path.join(train_data, 'rawdata/*/*nii.gz'))

raw_validation_files = glob.glob(os.path.join(validation_data, 'rawdata/*/*nii.gz'))

print("Raw images count train: {0}, validation: {1}".format(len(raw_train_files), len(raw_validation_files)))

# Read all 2019 and 2020 derivatives files, both train and validation

masks_train_files = glob.glob(os.path.join(train_data, 'derivatives/*/*nii.gz'))

masks_validation_files = glob.glob(os.path.join(validation_data, 'derivatives/*/*nii.gz'))

print("Masks images count train: {0}, validation: {1}".format(len(masks_train_files), len(masks_validation_files)))

def read_file(nii_file):
    '''
    Read .nii.gz file.

    Args:
        nii_file (str): a file path.

    Return:
        3D numpy array of CT image data.
    '''
    return np.asanyarray(nib.load(nii_file).dataobj)

def save_file(raw_data, label_data, file_name, index, output_raw_file_path, output_label_file_path):
    '''
    Save one slice of raw and label data as PNG files.

    Args:
        raw_data (array): 2D numpy array of raw image data.
        label_data (array): 2D numpy array of label image data.
        file_name (str): file name.
        index (int): slice of CT image.
        output_raw_file_path (str): Path to all raw files.
        output_label_file_path (str): Path to all mask files.
    '''
    # replace all non-zero pixels with 1
    label_data = np.where(label_data > 0, 1, label_data)

    unique_values = np.unique(label_data)
    # if the slice has any pixel with value 1, it is a positive datapoint
    if len(unique_values) > 1:
        raw_file_name = "{0}{1}_{2}.png".format(output_raw_file_path, file_name, index)
        im = Image.fromarray(raw_data)
        im = im.convert("L")
        im.save(raw_file_name)

        label_file_name = "{0}{1}_{2}.png".format(output_label_file_path, file_name, index)
        im = Image.fromarray(label_data)
        im = im.convert("L")
        im.save(label_file_name)

def is_diagonal(matrix):
    '''
    Check if the given matrix is diagonal or not.

    Args:
        matrix (np array): numpy array
    '''
    for i in range(0, 3):
        for j in range(0, 3):
            if (i != j) and (matrix[i][j] != 0):
                return False
    return True

def generate_data(raw_file, label_file, file_name, output_raw_file_path, output_label_file_path):
    '''
    Main function to read each raw and label file and generate a series of images
    per slice.

    Args:
        raw_file (str): path to raw file.
        label_file (str): path to label file.
        file_name (str): file name.
        output_raw_file_path (str): Path to all raw files.
        output_label_file_path (str): Path to all mask files.
    '''
    # Skip some slices: adjacent slices can be very similar to each other and
    # would generate redundant data.
    skip_slice = 3
    continue_it = True

    raw_data = read_file(raw_file)
    label_data = read_file(label_file)

    if "split" in raw_file:
        continue_it = False

    affine = nib.load(raw_file).affine
    if is_diagonal(affine[:3, :3]):
        transposed_raw_data = np.transpose(raw_data, [2, 1, 0])
        transposed_raw_data = np.flip(transposed_raw_data)
        transposed_label_data = np.transpose(label_data, [2, 1, 0])
        transposed_label_data = np.flip(transposed_label_data)
    else:
        transposed_raw_data = np.rot90(raw_data)
        transposed_raw_data = np.flip(transposed_raw_data)
        transposed_label_data = np.rot90(label_data)
        transposed_label_data = np.flip(transposed_label_data)

    if continue_it:
        if transposed_raw_data.shape:
            slice_count = transposed_raw_data.shape[-1]
            print("File name: ", file_name, " - Slice count: ", slice_count)
            # skip some slices
            for each_slice in range(1, slice_count, skip_slice):
                save_file(transposed_raw_data[:, :, each_slice],
                          transposed_label_data[:, :, each_slice],
                          file_name,
                          each_slice,
                          output_raw_file_path,
                          output_label_file_path)

# Loop over raw images and masks and generate 'PNG' images.
print("Processing started.")
for each_raw_file in raw_train_files:
    raw_file_name = each_raw_file.split("/")[-1].split("_ct.nii.gz")[0]
    for each_mask_file in masks_train_files:
        if raw_file_name in each_mask_file.split("/")[-1]:
            generate_data(each_raw_file, each_mask_file, raw_file_name, processed_train_raw_images, processed_train_masks)
print("Processing train data done.")

# Loop over raw images and masks and generate 'PNG' images.
for each_raw_file in raw_validation_files:
    raw_file_name = each_raw_file.split("/")[-1].split("_ct.nii.gz")[0]
    for each_mask_file in masks_validation_files:
        if raw_file_name in each_mask_file.split("/")[-1]:
            generate_data(each_raw_file, each_mask_file, raw_file_name, processed_validation_raw_images, processed_validation_masks)
print("Processing validation data done.")

Training

# Define model parameters

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# image size to convert to

IMAGE_HEIGHT = 250

IMAGE_WIDTH = 250

LEARNING_RATE = 1e-4

BATCH_SIZE = 10

EPOCHS = 10

NUM_WORKERS = 8

# Set the device

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# UNet model parts

# Source code: https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

# Defining UNet architecture

# Source code: https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

# Define PyTorch dataset class

# This class will access the images and masks, preprocess them for training and validation

class VerSeDataset(Dataset):
    def __init__(self, raw_images_path, masks_path, images_name):
        self.raw_images_path = raw_images_path
        self.masks_path = masks_path
        self.images_name = images_name

    def __len__(self):
        return len(self.images_name)

    def __getitem__(self, index):
        # get image and mask for a given index
        img_path = os.path.join(self.raw_images_path, self.images_name[index])
        mask_path = os.path.join(self.masks_path, self.images_name[index])

        # read the image and mask
        image = Image.open(img_path)
        mask = Image.open(mask_path)

        # resize the image and change its shape to (1, image_width, image_height)
        w, h = image.size
        image = image.resize((w, h), resample=Image.BICUBIC)
        image = T.Resize(size=(IMAGE_WIDTH, IMAGE_HEIGHT))(image)
        image_ndarray = np.asarray(image)
        image_ndarray = image_ndarray.reshape(1, image_ndarray.shape[0], image_ndarray.shape[1])

        # resize the mask. Mask shape is (image_width, image_height)
        mask = mask.resize((w, h), resample=Image.NEAREST)
        mask = T.Resize(size=(IMAGE_WIDTH, IMAGE_HEIGHT))(mask)
        mask_ndarray = np.asarray(mask)

        return {
            'image': torch.as_tensor(image_ndarray.copy()).float().contiguous(),
            'mask': torch.as_tensor(mask_ndarray.copy()).float().contiguous()
        }

# Get path for all images and masks

train_images_paths = os.listdir(processed_train_raw_images)

train_masks_paths = os.listdir(processed_train_masks)

validation_images_paths = os.listdir(processed_validation_raw_images)

validation_masks_paths = os.listdir(processed_validation_masks)

# Load both images and masks data

train_data = VerSeDataset(processed_train_raw_images, processed_train_masks, train_images_paths)

valid_data = VerSeDataset(processed_validation_raw_images, processed_validation_masks, validation_images_paths)

# Create PyTorch DataLoader

train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

valid_dataloader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False)

# Looking at one image and mask from one batch just to check them visually

next_image = next(iter(valid_dataloader))

fig, ax = plt.subplots(1, 2, figsize = (60, 60))

ax[0].imshow(next_image['image'][0][0,:,:], cmap ='bone')

ax[0].set_title("Raw image", fontsize=60)

ax[1].imshow(next_image['mask'][0][:,:], cmap ='bone')

ax[1].set_title("Mask image", fontsize=60)

plt.show()

# Defining Dice loss class

# Source code: https://www.kaggle.com/bigironsphere/loss-function-library-keras-pytorch

class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        inputs = torch.sigmoid(inputs)

        # flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

        # subtract from 1 to turn the Dice coefficient into a loss
        return 1 - dice

# Define model as UNet

model = UNet(n_channels=1, n_classes=1)

model.to(device=device)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Train and validate

train_loss = []

val_loss = []

for epoch in range(EPOCHS):
    # training pass
    model.train()
    train_running_loss = 0.0
    counter = 0
    with tqdm(total=len(train_data), desc=f'Epoch {epoch + 1}/{EPOCHS}', unit='img') as pbar:
        for batch in train_dataloader:
            counter += 1
            image = batch['image'].to(DEVICE)
            mask = batch['mask'].to(DEVICE)
            optimizer.zero_grad()
            outputs = model(image)
            outputs = outputs.squeeze(1)
            loss = DiceLoss()(outputs, mask)
            train_running_loss += loss.item()
            loss.backward()
            optimizer.step()
            pbar.update(image.shape[0])
            pbar.set_postfix(**{'loss (batch)': loss.item()})
    train_loss.append(train_running_loss / counter)

    # validation pass: only compute and store the loss
    model.eval()
    valid_running_loss = 0.0
    counter = 0
    with torch.no_grad():
        for i, data in enumerate(valid_dataloader):
            counter += 1
            image = data['image'].to(DEVICE)
            mask = data['mask'].to(DEVICE)
            outputs = model(image)
            outputs = outputs.squeeze(1)
            loss = DiceLoss()(outputs, mask)
            valid_running_loss += loss.item()
    val_loss.append(valid_running_loss / counter)

Epoch 1/10: 100%|██████████| 4790/4790 [4:00:34<00:00, 3.01s/img, loss (batch)=0.385]

Epoch 2/10: 100%|██████████| 4790/4790 [4:00:02<00:00, 3.01s/img, loss (batch)=0.268]

Epoch 3/10: 100%|██████████| 4790/4790 [3:57:30<00:00, 2.98s/img, loss (batch)=0.152]

Epoch 4/10: 100%|██████████| 4790/4790 [3:57:05<00:00, 2.97s/img, loss (batch)=0.105]

Epoch 5/10: 100%|██████████| 4790/4790 [4:08:29<00:00, 3.11s/img, loss (batch)=0.103]

Epoch 6/10: 100%|██████████| 4790/4790 [4:04:12<00:00, 3.06s/img, loss (batch)=0.0874]

Epoch 7/10: 100%|██████████| 4790/4790 [4:02:00<00:00, 3.03s/img, loss (batch)=0.0759]

Epoch 8/10: 100%|██████████| 4790/4790 [3:58:32<00:00, 2.99s/img, loss (batch)=0.0655]

Epoch 9/10: 100%|██████████| 4790/4790 [4:00:47<00:00, 3.02s/img, loss (batch)=0.0644]

Epoch 10/10: 100%|██████████| 4790/4790 [4:08:54<00:00, 3.12s/img, loss (batch)=0.0604]

# Plot train vs validation loss

plt.figure(figsize=(10, 7))

plt.plot(train_loss, color="orange", label='train loss')

plt.plot(val_loss, color="red", label='validation loss')

plt.xlabel("Epochs")

plt.ylabel("Loss")

plt.legend()

plt.show()

# Save the trained model

torch.save({
    'epoch': EPOCHS,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, "./unet_model.pth")
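For completeness, a minimal sketch of how this checkpoint could be restored later for inference (the dictionary keys and path match those used when saving above):

# restore the saved checkpoint into a fresh UNet instance
checkpoint = torch.load("./unet_model.pth", map_location=device)
restored_model = UNet(n_channels=1, n_classes=1)
restored_model.load_state_dict(checkpoint['model_state_dict'])
restored_model.to(device=device)
restored_model.eval()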

# Visually look at one prediction

next_image = next(iter(valid_dataloader))

# do predict

outputs = model(next_image['image'].float())

outputs = outputs.detach().cpu()

loss = DiceLoss()(outputs, next_image['mask'])

print("Dice Score: ", 1- loss.item())

outputs[outputs<=0.0] = 0

outputs[outputs>0.0] = 1.0

# plot all three images

fig, ax = plt.subplots(1, 3, figsize = (60, 60))

ax[0].imshow(next_image['image'][0][0,:,:], cmap ='bone')

ax[0].set_title("Raw Image", fontsize=60)

ax[1].imshow(next_image['mask'][0][:,:], cmap ='bone')

ax[1].set_title("True Mask", fontsize=60)

ax[2].imshow(outputs[0,0,:,:], cmap ='bone')

ax[2].set_title("Predicted Mask", fontsize=60)

plt.show()

Future work: This task could also be done with a 3D UNet, which might be a better way to learn the structure of the spine.

Because we have a label for each vertebra's mask region, we could go further and perform multi-class mask segmentation. Also, the model performs best when the image view is sagittal, so converting all slices to the sagittal view might give better results. (A rough sketch of what a multi-class setup could look like is given below.)
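As a rough illustration only (this variant is not part of the pipeline above, and the class count is a placeholder), the changes would mostly be in keeping the original vertebra labels instead of binarizing them in save_file, and in the model's n_classes and the loss:

# Hypothetical multi-class sketch. Assumes masks keep their original vertebra labels
# (integer values 0..N) instead of being collapsed to 0/1 in save_file.
NUM_CLASSES = 25  # placeholder: background + number of vertebra labels kept
multi_model = UNet(n_channels=1, n_classes=NUM_CLASSES).to(device)
multi_criterion = nn.CrossEntropyLoss()  # expects raw logits and integer class labels

# inside the training loop, the mask would be a LongTensor of class indices:
# outputs = multi_model(image)                  # shape: (B, NUM_CLASSES, H, W)
# loss = multi_criterion(outputs, mask.long())  # mask shape: (B, H, W)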

Thanks for reading!
