32 Training the SegVeg Model
This chapter shows how to configure the SegVeg neural-network model and train it on the VegAnn dataset.
32.1 Loading the Required Libraries
The libraries used here include PyTorch Lightning, Segmentation Models PyTorch, OpenCV, and Matplotlib. Their roles in this project are:
- PyTorch Lightning: builds the neural-network model and the training loop.
- Segmentation Models PyTorch: provides pretrained image-segmentation models such as U-Net, DeepLabV3, and PSPNet.
- OpenCV: image processing and visualization.
- Matplotlib: plotting charts and images.
- Other libraries: data handling, metric computation, and so on.
# Load the required libraries
import pytorch_lightning as pl  # model definition and training loop
import torch  # tensors, autograd, and optimizers
import segmentation_models_pytorch as smp  # pretrained image-segmentation models
import numpy as np  # numerical data handling
import cv2  # image processing and visualization
from segmentation_models_pytorch.encoders import get_preprocessing_fn  # data preprocessing
import matplotlib.pyplot as plt  # plotting charts and images
from typing import Dict, List  # type hints
32.2 Model Initialization
First, we configure a PyTorch Lightning module named VegAnnModel for training a U-Net model. It consists of the following parts:
- __init__ method: initializes the model, including the architecture, encoder name, number of input channels, and number of output classes.
- forward method: defines the forward pass, including image normalization, model inference, and the output mask.
- shared_step method: defines the step shared by training/validation/testing, including computing the loss and the metric statistics.
- shared_epoch_end method: defines the epoch-end logic shared by training/validation/testing, aggregating and logging the metrics.
- training_step method: defines the training step; it calls shared_step and stores the outputs.
- on_train_epoch_end method: defines the end-of-training-epoch hook, which calls shared_epoch_end.
- validation_step method: defines the validation step; it calls shared_step and stores the outputs.
- on_validation_epoch_end method: defines the end-of-validation-epoch hook, which calls shared_epoch_end.
- test_step method: defines the test step; it calls shared_step and stores the outputs.
- on_test_epoch_end method: defines the end-of-test-epoch hook, which calls shared_epoch_end.
- configure_optimizers method: defines the optimizer; here we use Adam.
In addition, a helper function is defined:
- colorTransform_VegGround: a color-transform function used to visualize predicted masks by overlaying them on the input image (a usage sketch follows the code below).
# Initialize the model
class VegAnnModel(pl.LightningModule):
    def __init__(self, arch: str, encoder_name: str, in_channels: int, out_classes: int, **kwargs):
        super().__init__()
        self.model = smp.create_model(
            arch,
            encoder_name=encoder_name,
            in_channels=in_channels,
            classes=out_classes,
            **kwargs,
        )

        # preprocessing parameters for image
        params = smp.encoders.get_preprocessing_params(encoder_name)
        self.register_buffer("std", torch.tensor(params["std"]).view(1, 3, 1, 1))
        self.register_buffer("mean", torch.tensor(params["mean"]).view(1, 3, 1, 1))

        # for image segmentation dice loss could be the best first choice
        self.loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)

        self.train_outputs, self.val_outputs, self.test_outputs = [], [], []

    def forward(self, image: torch.Tensor):
        # normalize image here
        image = (image - self.mean) / self.std
        mask = self.model(image)
        return mask

    def shared_step(self, batch: Dict, stage: str):
        image = batch["image"]

        # Shape of the image should be (batch_size, num_channels, height, width)
        # if you work with grayscale images, expand channels dim to have [batch_size, 1, height, width]
        assert image.ndim == 4

        # Check that image dimensions are divisible by 32,
        # encoder and decoder connected by `skip connections` and usually encoder have 5 stages of
        # downsampling by factor 2 (2 ^ 5 = 32); e.g. if we have image with shape 65x65 we will have
        # following shapes of features in encoder and decoder: 84, 42, 21, 10, 5 -> 5, 10, 20, 40, 80
        # and we will get an error trying to concat these features
        h, w = image.shape[2:]
        assert h % 32 == 0 and w % 32 == 0

        mask = batch["mask"]

        # Shape of the mask should be [batch_size, num_classes, height, width]
        # for binary segmentation num_classes = 1
        assert mask.ndim == 4

        # Check that mask values are between 0 and 1, NOT 0 and 255, for binary segmentation
        assert mask.max() <= 1.0 and mask.min() >= 0

        logits_mask = self.forward(image)

        # Predicted mask contains logits, and loss_fn param `from_logits` is set to True
        loss = self.loss_fn(logits_mask, mask)

        # Let's compute metrics for some threshold:
        # first convert mask values to probabilities, then apply thresholding
        prob_mask = logits_mask.sigmoid()
        pred_mask = (prob_mask > 0.5).float()

        # We will compute the IoU metric in two ways:
        # 1. dataset-wise
        # 2. image-wise
        # but for now we just compute true positive, false positive, false negative and
        # true negative 'pixels' for each image and class;
        # these values will be aggregated at the end of an epoch
        tp, fp, fn, tn = smp.metrics.get_stats(pred_mask.long(), mask.long(), mode="binary")

        return {
            "loss": loss,
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "tn": tn,
        }

    def shared_epoch_end(self, outputs: List[Dict], stage: str):
        # aggregate step metrics
        tp = torch.cat([x["tp"] for x in outputs])
        fp = torch.cat([x["fp"] for x in outputs])
        fn = torch.cat([x["fn"] for x in outputs])
        tn = torch.cat([x["tn"] for x in outputs])

        # per-image IoU means that we first calculate the IoU score for each image
        # and then compute the mean over these scores
        per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")
        per_image_f1 = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro-imagewise")
        per_image_acc = smp.metrics.accuracy(tp, fp, fn, tn, reduction="micro-imagewise")

        # dataset IoU means that we aggregate intersection and union over the whole dataset
        # and then compute the IoU score. The difference between dataset_iou and per_image_iou
        # in this particular case will not be much, however for datasets
        # with "empty" images (images without the target class) a large gap could be observed.
        # Empty images influence per_image_iou a lot and dataset_iou much less.
        dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
        dataset_f1 = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
        dataset_acc = smp.metrics.accuracy(tp, fp, fn, tn, reduction="micro")

        metrics = {
            f"{stage}_per_image_iou": per_image_iou,
            f"{stage}_dataset_iou": dataset_iou,
            f"{stage}_per_image_f1": per_image_f1,
            f"{stage}_dataset_f1": dataset_f1,
            f"{stage}_per_image_acc": per_image_acc,
            f"{stage}_dataset_acc": dataset_acc,
        }

        self.log_dict(metrics, prog_bar=True, sync_dist=True, rank_zero_only=True)

    def training_step(self, batch: Dict, batch_idx: int):
        step_outputs = self.shared_step(batch, "train")
        self.train_outputs.append(step_outputs)
        return step_outputs

    def on_train_epoch_end(self):
        self.shared_epoch_end(self.train_outputs, "train")
        self.train_outputs = []

    def validation_step(self, batch: Dict, batch_idx: int):
        step_outputs = self.shared_step(batch, "valid")
        self.val_outputs.append(step_outputs)
        return step_outputs

    def on_validation_epoch_end(self, *args, **kwargs):
        self.shared_epoch_end(self.val_outputs, "valid")
        self.val_outputs = []

    def test_step(self, batch: Dict, batch_idx: int):
        step_outputs = self.shared_step(batch, "test")
        self.test_outputs.append(step_outputs)
        return step_outputs

    def on_test_epoch_end(self):
        self.shared_epoch_end(self.test_outputs, "test")
        self.test_outputs = []

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0001)


def colorTransform_VegGround(im, X_true, alpha_vert, alpha_g):
    # Blend class colors over the image: background (class 0) in brown, vegetation (class 1) in green
    alpha = alpha_vert
    color = [97, 65, 38]
    # color = [x / 255 for x in color]
    image = np.copy(im)
    for c in range(3):
        image[:, :, c] = np.where(X_true == 0, image[:, :, c] * (1 - alpha) + alpha * color[c], image[:, :, c])
    alpha = alpha_g
    color = [34, 139, 34]
    # color = [x / 255 for x in color]
    for c in range(3):
        image[:, :, c] = np.where(X_true == 1, image[:, :, c] * (1 - alpha) + alpha * color[c], image[:, :, c])
    return image
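As a usage illustration, colorTransform_VegGround can overlay a predicted binary mask on the original image, tinting soil/background brown and vegetation green. A minimal sketch, assuming an RGB image array im (H x W x 3, uint8) and a same-sized binary prediction pred — both are hypothetical inputs, not variables defined in this chapter:
# Visualize a prediction by blending class colors over the image (sketch; `im` and `pred` are assumed inputs)
overlay = colorTransform_VegGround(im, pred, alpha_vert=0.4, alpha_g=0.4)
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1); plt.imshow(im); plt.title("Image"); plt.axis("off")
plt.subplot(1, 2, 2); plt.imshow(overlay.astype(np.uint8)); plt.title("Prediction overlay"); plt.axis("off")
plt.show()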
Now we can use the VegAnnModel class to initialize a U-Net model, with ResNet34 as the encoder, 3 input channels (RGB images), and 1 output class (binary segmentation).
# Initialize the model
model = VegAnnModel("Unet", "resnet34", in_channels=3, out_classes=1)
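As an optional sanity check, a dummy 512 x 512 RGB tensor can be pushed through the model to confirm the output shape; the output is a logits mask, so a sigmoid and a 0.5 threshold turn it into a binary prediction (a quick sketch, values are random and only the shapes matter):
# Quick shape check with a dummy input
with torch.no_grad():
    dummy = torch.rand(1, 3, 512, 512)
    logits = model(dummy)
print(logits.shape)                 # torch.Size([1, 1, 512, 512])
probs = logits.sigmoid()            # convert logits to probabilities
binary = (probs > 0.5).float()      # threshold to a binary mask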
Next, we use torchinfo to inspect the model structure.
from torchinfo import summary
# Show detailed model summary using torchinfo
summary(model, input_size=(1, 3, 512, 512),
        col_names=["input_size", "output_size", "num_params", "kernel_size"],
        depth=4)
======================================================================================================================================================
Layer (type:depth-idx) Input Shape Output Shape Param # Kernel Shape
======================================================================================================================================================
VegAnnModel [1, 3, 512, 512] [1, 1, 512, 512] -- --
├─Unet: 1-1 [1, 3, 512, 512] [1, 1, 512, 512] -- --
│ └─ResNetEncoder: 2-1 [1, 3, 512, 512] [1, 3, 512, 512] -- --
│ │ └─Conv2d: 3-1 [1, 3, 512, 512] [1, 64, 256, 256] 9,408 [7, 7]
│ │ └─BatchNorm2d: 3-2 [1, 64, 256, 256] [1, 64, 256, 256] 128 --
│ │ └─ReLU: 3-3 [1, 64, 256, 256] [1, 64, 256, 256] -- --
│ │ └─MaxPool2d: 3-4 [1, 64, 256, 256] [1, 64, 128, 128] -- 3
│ │ └─Sequential: 3-5 [1, 64, 128, 128] [1, 64, 128, 128] -- --
│ │ │ └─BasicBlock: 4-1 [1, 64, 128, 128] [1, 64, 128, 128] 73,984 --
│ │ │ └─BasicBlock: 4-2 [1, 64, 128, 128] [1, 64, 128, 128] 73,984 --
│ │ │ └─BasicBlock: 4-3 [1, 64, 128, 128] [1, 64, 128, 128] 73,984 --
│ │ └─Sequential: 3-6 [1, 64, 128, 128] [1, 128, 64, 64] -- --
│ │ │ └─BasicBlock: 4-4 [1, 64, 128, 128] [1, 128, 64, 64] 230,144 --
│ │ │ └─BasicBlock: 4-5 [1, 128, 64, 64] [1, 128, 64, 64] 295,424 --
│ │ │ └─BasicBlock: 4-6 [1, 128, 64, 64] [1, 128, 64, 64] 295,424 --
│ │ │ └─BasicBlock: 4-7 [1, 128, 64, 64] [1, 128, 64, 64] 295,424 --
│ │ └─Sequential: 3-7 [1, 128, 64, 64] [1, 256, 32, 32] -- --
│ │ │ └─BasicBlock: 4-8 [1, 128, 64, 64] [1, 256, 32, 32] 919,040 --
│ │ │ └─BasicBlock: 4-9 [1, 256, 32, 32] [1, 256, 32, 32] 1,180,672 --
│ │ │ └─BasicBlock: 4-10 [1, 256, 32, 32] [1, 256, 32, 32] 1,180,672 --
│ │ │ └─BasicBlock: 4-11 [1, 256, 32, 32] [1, 256, 32, 32] 1,180,672 --
│ │ │ └─BasicBlock: 4-12 [1, 256, 32, 32] [1, 256, 32, 32] 1,180,672 --
│ │ │ └─BasicBlock: 4-13 [1, 256, 32, 32] [1, 256, 32, 32] 1,180,672 --
│ │ └─Sequential: 3-8 [1, 256, 32, 32] [1, 512, 16, 16] -- --
│ │ │ └─BasicBlock: 4-14 [1, 256, 32, 32] [1, 512, 16, 16] 3,673,088 --
│ │ │ └─BasicBlock: 4-15 [1, 512, 16, 16] [1, 512, 16, 16] 4,720,640 --
│ │ │ └─BasicBlock: 4-16 [1, 512, 16, 16] [1, 512, 16, 16] 4,720,640 --
│ └─UnetDecoder: 2-2 [1, 3, 512, 512] [1, 16, 512, 512] -- --
│ │ └─Identity: 3-9 [1, 512, 16, 16] [1, 512, 16, 16] -- --
│ │ └─ModuleList: 3-10 -- -- -- --
│ │ │ └─DecoderBlock: 4-17 [1, 512, 16, 16] [1, 256, 32, 32] 2,360,320 --
│ │ │ └─DecoderBlock: 4-18 [1, 256, 32, 32] [1, 128, 64, 64] 590,336 --
│ │ │ └─DecoderBlock: 4-19 [1, 128, 64, 64] [1, 64, 128, 128] 147,712 --
│ │ │ └─DecoderBlock: 4-20 [1, 64, 128, 128] [1, 32, 256, 256] 46,208 --
│ │ │ └─DecoderBlock: 4-21 [1, 32, 256, 256] [1, 16, 512, 512] 6,976 --
│ └─SegmentationHead: 2-3 [1, 16, 512, 512] [1, 1, 512, 512] -- --
│ │ └─Conv2d: 3-11 [1, 16, 512, 512] [1, 1, 512, 512] 145 [3, 3]
│ │ └─Identity: 3-12 [1, 1, 512, 512] [1, 1, 512, 512] -- --
│ │ └─Activation: 3-13 [1, 1, 512, 512] [1, 1, 512, 512] -- --
│ │ │ └─Identity: 4-22 [1, 1, 512, 512] [1, 1, 512, 512] -- --
======================================================================================================================================================
Total params: 24,436,369
Trainable params: 24,436,369
Non-trainable params: 0
Total mult-adds (G): 31.26
======================================================================================================================================================
Input size (MB): 3.15
Forward/backward pass size (MB): 574.62
Params size (MB): 97.75
Estimated Total Size (MB): 675.51
======================================================================================================================================================
32.3 Loading the Dataset
We need a custom dataset class to handle the data loaded from Hugging Face; it is imported below as VegAnnDataset from src.segveg.
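The actual VegAnnDataset implementation lives in src.segveg and is not reproduced in this chapter. For orientation, here is a minimal sketch of what such a wrapper might look like; the record field names ("image", "mask"), the PIL image type, and the class name VegAnnDatasetSketch are all assumptions for illustration, not the real implementation:
# Hypothetical sketch of a VegAnn dataset wrapper (assumed field names "image" and "mask")
import numpy as np
import torch
from torch.utils.data import Dataset

class VegAnnDatasetSketch(Dataset):
    def __init__(self, hf_split, transform=None):
        self.ds = hf_split          # a Hugging Face datasets split
        self.transform = transform  # an albumentations Compose, or None

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        record = self.ds[idx]
        image = np.array(record["image"].convert("RGB"))           # HWC, uint8 (assumes a PIL image)
        mask = (np.array(record["mask"]) > 0).astype("float32")    # binary 0/1 mask
        if self.transform is not None:
            out = self.transform(image=image, mask=mask)
            image, mask = out["image"], out["mask"]
        # image to CHW float tensor, mask gets an explicit channel dimension
        image = torch.from_numpy(image).permute(2, 0, 1).float()
        mask = torch.from_numpy(mask).unsqueeze(0).float()
        return {"image": image, "mask": mask}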
from src.segveg import VegAnnDataset
from albumentations import Compose, Resize, Normalize, HorizontalFlip, RandomRotate90, ColorJitter, ToFloat
from torch.utils.data import DataLoader
from datasets import load_dataset

# Simplified data-augmentation pipeline
transform = Compose([
    Resize(512, 512),
    HorizontalFlip(p=0.5),
    RandomRotate90(p=0.5),
    ColorJitter(brightness=0.2, contrast=0.2),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the VegAnn dataset
ds = load_dataset("simonMadec/VegAnn", split="train")

# sample a small subset of the data for quick tests
# ds = ds.shuffle(seed=42).select(range(50))

# Use the train_test_split built into datasets
split_ds = ds.train_test_split(test_size=0.2, seed=42)

# Create custom datasets
train_dataset = VegAnnDataset(split_ds['train'], transform=transform)
val_dataset = VegAnnDataset(split_ds['test'], transform=transform)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=0)
32.4 Defining the Loss Function and Optimizer
Training uses the Dice loss with the Adam optimizer and a step learning-rate scheduler.
# Define the loss function and optimizer
from torch import nn, optim

# Using Dice loss and Adam optimizer as specified
criterion = smp.losses.DiceLoss(mode='binary')
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
32.5 Defining the Training Loop
Next, we define a training loop for the U-Net model. It includes the following parts:
- train_model function: defines the training loop, including model training, validation, and saving of the best model.
- wandb.init inside train_model: initializes the W&B project used to log the training process and results.
- wandb.log inside train_model: logs training metrics to W&B.
- wandb.save inside train_model: saves the best model checkpoint to W&B.
- torch.amp.GradScaler inside train_model: enables mixed-precision training.
- torch.backends.cudnn.benchmark inside train_model: enables cuDNN auto-tuning.
- torch.set_float32_matmul_precision inside train_model: optimizes matrix multiplications.
# Define the training loop
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=3, device="mps"):
    # Move model to device
    device = torch.device(device)
    model = model.to(device)

    # Enable mixed-precision training and performance settings
    scaler = torch.amp.GradScaler(enabled=(device.type == "cuda"))
    torch.backends.cudnn.benchmark = (device.type == 'cuda')
    torch.set_float32_matmul_precision('high')  # optimize matrix multiplications

    import wandb  # Weights & Biases for experiment tracking

    # Initialize W&B
    wandb.init(
        project="veg-segmentation",
        config={
            "architecture": "U-Net",
            "encoder": "resnet34",
            "learning_rate": 0.001,
            "batch_size": 32,
            "epochs": num_epochs
        }
    )

    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for batch in train_loader:
            # Move the data to the device explicitly
            images = batch["image"].to(device, non_blocking=True)
            masks = batch["mask"].to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)  # more efficient gradient reset

            dtype = torch.float16 if device.type == "mps" else torch.float32
            with torch.amp.autocast(device_type=device.type, dtype=dtype, enabled=(device.type != "cpu")):
                outputs = model(images)
                loss = criterion(outputs, masks)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Log GPU memory usage
            if device.type == 'cuda':
                wandb.log({
                    "gpu_mem_alloc": torch.cuda.memory_allocated() / 1e9,
                    "gpu_mem_reserved": torch.cuda.memory_reserved() / 1e9
                })

            running_loss += loss.item() * images.size(0)

        epoch_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}')

        # Log training metrics to W&B
        wandb.log({
            "train_loss": epoch_loss,
            "learning_rate": scheduler.get_last_lr()[0]
        })

        # Validation
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                images = batch["image"].to(device)
                masks = batch["mask"].to(device)

                # Use the model output directly; do not assume it is a dict with an 'out' key
                outputs = model(images)
                loss = criterion(outputs, masks)
                running_val_loss += loss.item() * images.size(0)

        val_loss = running_val_loss / len(val_loader.dataset)
        print(f'Epoch {epoch+1}/{num_epochs}, Validation Loss: {val_loss:.4f}')

        # Log validation metrics
        wandb.log({"val_loss": val_loss})

        # Save the best model checkpoint and register it with W&B
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({"state_dict": model.state_dict()}, 'best_model.pth')
            wandb.save('best_model.pth')

        scheduler.step()

    print('Training complete.')
32.6 Starting the Training
First, check which device is available, then run the training loop defined above.
# check device availability
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# print device
print(f'Found device: {device}')
Found device: mps
Train the model.
train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=5, device=device)
Training takes a fairly long time; the progress and results can be monitored in the W&B project.
32.7 Saving the Training Results
# Save the model - use the same checkpoint format as above
torch.save({"state_dict": model.state_dict()}, 'data/segveg/best_model.pth')