32  训练 SegVeg 模型

本章介绂了如何配置 SegVeg 神经网络模型,并使用 VegAnn 数据集对模型进行训练。

32.1 载入需要的库

这些库包括 PyTorch Lightning、Segmentation Models PyTorch、OpenCV、Matplotlib 等。他们在这个项目中的功能如下:

  • PyTorch Lightning: 用于构建神经网络模型和训练循环。
  • Segmentation Models PyTorch: 提供了许多预训练的图像分割模型,如 U-net、DeepLabV3、PSPNet 等。
  • OpenCV: 用于图像处理和可视化。
  • Matplotlib: 用于绘制图表和图像。
  • 其他库: 用于数据处理、评估指标计算等。
# 载入需要的库
import pytorch_lightning as pl # 用于构建神经网络模型和训练循环
import torch # 用于构建神经网络模型和训练循环
import segmentation_models_pytorch as smp # 用于提供预训练的图像分割模型
import numpy as np # 用于数据处理
import cv2 # 用于图像处理和可视化
from segmentation_models_pytorch.encoders import get_preprocessing_fn # 用于数据预处理
import matplotlib.pyplot as plt # 用于绘制图表和图像
from typing import Dict, List # 用于类型提示

32.2 模型初始化

首先,配置一个名为 VegAnnModel 的 PyTorch Lightning 模型,用于训练 U-net 模型。这个模型包含以下几个部分:

  • __init__ 方法:初始化模型,包括选择模型架构、编码器名称、输入通道数、输出类别数等。
  • forward 方法:定义前向传播过程,包括图像预处理、模型推理和输出。
  • shared_step 方法:定义共享的训练/验证/测试步骤,包括计算损失、评估指标等。
  • shared_epoch_end 方法:定义共享的训练/验证/测试 epoch 结束方法,用于计算并输出评估指标。
  • training_step 方法:定义训练步骤,包括调用 shared_step 方法并保存输出。
  • on_train_epoch_end 方法:定义训练 epoch 结束方法,用于调用 shared_epoch_end 方法。
  • validation_step 方法:定义验证步骤,包括调用 shared_step 方法并保存输出。
  • on_validation_epoch_end 方法:定义验证 epoch 结束方法,用于调用 shared_epoch_end 方法。
  • test_step 方法:定义测试步骤,包括调用 shared_step 方法并保存输出。
  • on_test_epoch_end 方法:定义测试 epoch 结束方法,用于调用 shared_epoch_end 方法。
  • configure_optimizers 方法:定义优化器,这里使用 Adam 优化器。

另外,还定义了一个辅助函数:

  • colorTransform_VegGround 方法:定义一个颜色转换函数,用于将预测的掩膜可视化。
# Initialize the model
class VegAnnModel(pl.LightningModule):
    def __init__(self, arch: str, encoder_name: str, in_channels: int, out_classes: int, **kwargs):
        super().__init__()
        self.model = smp.create_model(
            arch,
            encoder_name=encoder_name,
            in_channels=in_channels,
            classes=out_classes,
            **kwargs,
        )

        # preprocessing parameteres for image
        params = smp.encoders.get_preprocessing_params(encoder_name)
        self.register_buffer("std", torch.tensor(params["std"]).view(1, 3, 1, 1))
        self.register_buffer("mean", torch.tensor(params["mean"]).view(1, 3, 1, 1))

        # for image segmentation dice loss could be the best first choice
        self.loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
        self.train_outputs, self.val_outputs, self.test_outputs = [], [], []

    def forward(self, image: torch.Tensor):
        # normalize image here #todo
        image = (image - self.mean) / self.std
        mask = self.model(image)
        return mask

    def shared_step(self, batch: Dict, stage: str):
        image = batch["image"]

        # Shape of the image should be (batch_size, num_channels, height, width)
        # if you work with grayscale images, expand channels dim to have [batch_size, 1, height, width]
        assert image.ndim == 4

        # Check that image dimensions are divisible by 32,
        # encoder and decoder connected by `skip connections` and usually encoder have 5 stages of
        # downsampling by factor 2 (2 ^ 5 = 32); e.g. if we have image with shape 65x65 we will have
        # following shapes of features in encoder and decoder: 84, 42, 21, 10, 5 -> 5, 10, 20, 40, 80
        # and we will get an error trying to concat these features
        h, w = image.shape[2:]
        assert h % 32 == 0 and w % 32 == 0

        mask = batch["mask"]

        # Shape of the mask should be [batch_size, num_classes, height, width]
        # for binary segmentation num_classes = 1
        assert mask.ndim == 4

        # Check that mask values in between 0 and 1, NOT 0 and 255 for binary segmentation
        assert mask.max() <= 1.0 and mask.min() >= 0

        logits_mask = self.forward(image)

        # Predicted mask contains logits, and loss_fn param `from_logits` is set to True
        loss = self.loss_fn(logits_mask, mask)

        # Lets compute metrics for some threshold
        # first convert mask values to probabilities, then
        # apply thresholding
        prob_mask = logits_mask.sigmoid()
        pred_mask = (prob_mask > 0.5).float()

        # We will compute IoU metric by two ways
        #   1. dataset-wise
        #   2. image-wise
        # but for now we just compute true positive, false positive, false negative and
        # true negative 'pixels' for each image and class
        # these values will be aggregated in the end of an epoch
        tp, fp, fn, tn = smp.metrics.get_stats(pred_mask.long(), mask.long(), mode="binary")

        return {
            "loss": loss,
            "tp": tp,
            "fp": fp,
            "fn": fn,
            "tn": tn,
        }

    def shared_epoch_end(self, outputs: List[Dict], stage: str):
        # aggregate step metics
        tp = torch.cat([x["tp"] for x in outputs])
        fp = torch.cat([x["fp"] for x in outputs])
        fn = torch.cat([x["fn"] for x in outputs])
        tn = torch.cat([x["tn"] for x in outputs])

        # per image IoU means that we first calculate IoU score for each image
        # and then compute mean over these scores
        per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")
        per_image_f1 = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro-imagewise")
        per_image_acc = smp.metrics.accuracy(tp, fp, fn, tn, reduction="micro-imagewise")
        # dataset IoU means that we aggregate intersection and union over whole dataset
        # and then compute IoU score. The difference between dataset_iou and per_image_iou scores
        # in this particular case will not be much, however for dataset
        # with "empty" images (images without target class) a large gap could be observed.
        # Empty images influence a lot on per_image_iou and much less on dataset_iou.
        dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
        dataset_f1 = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
        dataset_acc = smp.metrics.accuracy(tp, fp, fn, tn, reduction="micro")

        metrics = {
            f"{stage}_per_image_iou": per_image_iou,
            f"{stage}_dataset_iou": dataset_iou,
            f"{stage}_per_image_f1": per_image_f1,
            f"{stage}_dataset_f1": dataset_f1,
            f"{stage}_per_image_acc": per_image_acc,
            f"{stage}_dataset_acc": dataset_acc,
        }

        self.log_dict(metrics, prog_bar=True, sync_dist=True, rank_zero_only=True)

    def training_step(self, batch: Dict, batch_idx: int):
        step_outputs = self.shared_step(batch, "train")
        self.train_outputs.append(step_outputs)
        return step_outputs

    def on_train_epoch_end(self):
        self.shared_epoch_end(self.train_outputs, "train")
        self.train_outputs = []

    def validation_step(self, batch: Dict, batch_idx: int):
        step_outputs = self.shared_step(batch, "valid")
        self.val_outputs.append(step_outputs)
        return step_outputs

    def on_validation_epoch_end(self, *args, **kwargs):
        self.shared_epoch_end(self.val_outputs, "valid")
        self.val_outputs = []

    def test_step(self, batch: Dict, batch_idx: int):
        step_outputs = self.shared_step(batch, "test")
        self.test_outputs.append(step_outputs)
        return step_outputs

    def on_test_epoch_end(self):
        self.shared_epoch_end(self.test_outputs, "test")
        self.test_outputs = []

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0001)


def colorTransform_VegGround(im,X_true,alpha_vert,alpha_g):
    alpha = alpha_vert
    color = [97,65,38]
    # color = [x / 255 for x in color]
    image=np.copy(im)
    for c in range(3):
        image[:, :, c] =np.where(X_true == 0,image[:, :, c] *(1 - alpha) + alpha * color[c] ,image[:, :, c])
    alpha = alpha_g
    color = [34,139,34]
#    color = [x / 255 for x in color]
    for c in range(3):
        image[:, :, c] =np.where(X_true == 1,image[:, :, c] *(1 - alpha) + alpha * color[c] ,image[:, :, c])
    return image 

现在,我们可以使用 VegAnnModel 类初始化一个 U-net 模型。这个模型使用 ResNet34 作为编码器,输入通道数为 3(RGB 图像),输出类别数为 1(二值分割)。

# Initialize the model
model = VegAnnModel("Unet", "resnet34", in_channels=3, out_classes=1)

接下来,使用 torchinfo 可视化模型的结构。

from torchinfo import summary

# Show detailed model summary using torchinfo
summary(model, input_size=(1, 3, 512, 512), 
    col_names=["input_size", "output_size", "num_params", "kernel_size"],
    depth=4)
======================================================================================================================================================
Layer (type:depth-idx)                             Input Shape               Output Shape              Param #                   Kernel Shape
======================================================================================================================================================
VegAnnModel                                        [1, 3, 512, 512]          [1, 1, 512, 512]          --                        --
├─Unet: 1-1                                        [1, 3, 512, 512]          [1, 1, 512, 512]          --                        --
│    └─ResNetEncoder: 2-1                          [1, 3, 512, 512]          [1, 3, 512, 512]          --                        --
│    │    └─Conv2d: 3-1                            [1, 3, 512, 512]          [1, 64, 256, 256]         9,408                     [7, 7]
│    │    └─BatchNorm2d: 3-2                       [1, 64, 256, 256]         [1, 64, 256, 256]         128                       --
│    │    └─ReLU: 3-3                              [1, 64, 256, 256]         [1, 64, 256, 256]         --                        --
│    │    └─MaxPool2d: 3-4                         [1, 64, 256, 256]         [1, 64, 128, 128]         --                        3
│    │    └─Sequential: 3-5                        [1, 64, 128, 128]         [1, 64, 128, 128]         --                        --
│    │    │    └─BasicBlock: 4-1                   [1, 64, 128, 128]         [1, 64, 128, 128]         73,984                    --
│    │    │    └─BasicBlock: 4-2                   [1, 64, 128, 128]         [1, 64, 128, 128]         73,984                    --
│    │    │    └─BasicBlock: 4-3                   [1, 64, 128, 128]         [1, 64, 128, 128]         73,984                    --
│    │    └─Sequential: 3-6                        [1, 64, 128, 128]         [1, 128, 64, 64]          --                        --
│    │    │    └─BasicBlock: 4-4                   [1, 64, 128, 128]         [1, 128, 64, 64]          230,144                   --
│    │    │    └─BasicBlock: 4-5                   [1, 128, 64, 64]          [1, 128, 64, 64]          295,424                   --
│    │    │    └─BasicBlock: 4-6                   [1, 128, 64, 64]          [1, 128, 64, 64]          295,424                   --
│    │    │    └─BasicBlock: 4-7                   [1, 128, 64, 64]          [1, 128, 64, 64]          295,424                   --
│    │    └─Sequential: 3-7                        [1, 128, 64, 64]          [1, 256, 32, 32]          --                        --
│    │    │    └─BasicBlock: 4-8                   [1, 128, 64, 64]          [1, 256, 32, 32]          919,040                   --
│    │    │    └─BasicBlock: 4-9                   [1, 256, 32, 32]          [1, 256, 32, 32]          1,180,672                 --
│    │    │    └─BasicBlock: 4-10                  [1, 256, 32, 32]          [1, 256, 32, 32]          1,180,672                 --
│    │    │    └─BasicBlock: 4-11                  [1, 256, 32, 32]          [1, 256, 32, 32]          1,180,672                 --
│    │    │    └─BasicBlock: 4-12                  [1, 256, 32, 32]          [1, 256, 32, 32]          1,180,672                 --
│    │    │    └─BasicBlock: 4-13                  [1, 256, 32, 32]          [1, 256, 32, 32]          1,180,672                 --
│    │    └─Sequential: 3-8                        [1, 256, 32, 32]          [1, 512, 16, 16]          --                        --
│    │    │    └─BasicBlock: 4-14                  [1, 256, 32, 32]          [1, 512, 16, 16]          3,673,088                 --
│    │    │    └─BasicBlock: 4-15                  [1, 512, 16, 16]          [1, 512, 16, 16]          4,720,640                 --
│    │    │    └─BasicBlock: 4-16                  [1, 512, 16, 16]          [1, 512, 16, 16]          4,720,640                 --
│    └─UnetDecoder: 2-2                            [1, 3, 512, 512]          [1, 16, 512, 512]         --                        --
│    │    └─Identity: 3-9                          [1, 512, 16, 16]          [1, 512, 16, 16]          --                        --
│    │    └─ModuleList: 3-10                       --                        --                        --                        --
│    │    │    └─DecoderBlock: 4-17                [1, 512, 16, 16]          [1, 256, 32, 32]          2,360,320                 --
│    │    │    └─DecoderBlock: 4-18                [1, 256, 32, 32]          [1, 128, 64, 64]          590,336                   --
│    │    │    └─DecoderBlock: 4-19                [1, 128, 64, 64]          [1, 64, 128, 128]         147,712                   --
│    │    │    └─DecoderBlock: 4-20                [1, 64, 128, 128]         [1, 32, 256, 256]         46,208                    --
│    │    │    └─DecoderBlock: 4-21                [1, 32, 256, 256]         [1, 16, 512, 512]         6,976                     --
│    └─SegmentationHead: 2-3                       [1, 16, 512, 512]         [1, 1, 512, 512]          --                        --
│    │    └─Conv2d: 3-11                           [1, 16, 512, 512]         [1, 1, 512, 512]          145                       [3, 3]
│    │    └─Identity: 3-12                         [1, 1, 512, 512]          [1, 1, 512, 512]          --                        --
│    │    └─Activation: 3-13                       [1, 1, 512, 512]          [1, 1, 512, 512]          --                        --
│    │    │    └─Identity: 4-22                    [1, 1, 512, 512]          [1, 1, 512, 512]          --                        --
======================================================================================================================================================
Total params: 24,436,369
Trainable params: 24,436,369
Non-trainable params: 0
Total mult-adds (G): 31.26
======================================================================================================================================================
Input size (MB): 3.15
Forward/backward pass size (MB): 574.62
Params size (MB): 97.75
Estimated Total Size (MB): 675.51
======================================================================================================================================================

32.3 加载数据集

我们需要定义一个自定义的数据集类来处理从 Hugging Face 加载的数据。

from src.segveg import VegAnnDataset
from albumentations import Compose, Resize, Normalize, HorizontalFlip, RandomRotate90, ColorJitter, ToFloat
from torch.utils.data import DataLoader
from datasets import load_dataset

# 简化数据增强流程
transform = Compose([
    Resize(512, 512),
    HorizontalFlip(p=0.5),
    RandomRotate90(p=0.5),
    ColorJitter(brightness=0.2, contrast=0.2),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the VegAnn dataset
ds = load_dataset("simonMadec/VegAnn", split="train")

# sample small size of data to test
# ds = ds.shuffle(seed=42).select(range(50))

# 直接使用 datasets 内置的 train_test_split
split_ds = ds.train_test_split(test_size=0.2, seed=42)

# Create custom datasets
train_dataset = VegAnnDataset(split_ds['train'], transform=transform)
val_dataset = VegAnnDataset(split_ds['test'], transform=transform)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=0)

32.4 定义损失器和优化器

通过 Dice loss 函数和 Adam 优化器进行训练。

# 定义损失器和优化器
from torch import nn, optim

# Using Dice loss and Adam optimizer as specified
criterion = smp.losses.DiceLoss(mode='binary')
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

32.5 定义训练循环

下面,我们定义一个训练循环,用于训练 U-net 模型。这个训练循环包括以下几个部分:

  • train_model 函数:定义了训练循环,包括模型训练、验证、保存最佳模型等。
  • train_model 函数中的 wandb.init:初始化 W&B 项目,用于记录训练过程和结果。
  • train_model 函数中的 wandb.log:记录训练指标到 W&B。
  • train_model 函数中的 wandb.save:保存最佳模型到 W&B。
  • train_model 函数中的 torch.amp.GradScaler:启用混合精度训练和优化配置。
  • train_model 函数中的 torch.backends.cudnn.benchmark:启用 CuDNN 自动调优。
  • train_model 函数中的 torch.set_float32_matmul_precision:优化矩阵运算。
# 定义训练循环
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=3, device="mps"):
    # Move model to device
    device = torch.device(device)
    model = model.to(device)
    
    # 启用混合精度训练和优化配置
    scaler = torch.amp.GradScaler(enabled=(device.type == "cuda"))
    torch.backends.cudnn.benchmark = (device.type == 'cuda')
    torch.set_float32_matmul_precision('high')  # 优化矩阵运算
    import wandb  # 新增导入
    wandb.init(  # 初始化W&B
        project="veg-segmentation",
        config={
            "architecture": "U-Net",
            "encoder": "resnet34",
            "learning_rate": 0.001,
            "batch_size": 32,
            "epochs": num_epochs
        }
    )
    
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for batch in train_loader:
            # 显式转移数据到设备并添加内存监控
            images = batch["image"].to(device, non_blocking=True)
            masks = batch["mask"].to(device, non_blocking=True)
            
            optimizer.zero_grad(set_to_none=True)  # 更高效的梯度清零
            
            dtype = torch.float16 if device.type == "mps" else torch.float32
            with torch.amp.autocast(device_type=device.type, dtype=dtype, enabled=(device != "cpu")):
                outputs = model(images)
                loss = criterion(outputs, masks)
            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
            # 记录内存使用情况
            if device.type == 'cuda':
                wandb.log({
                    "gpu_mem_alloc": torch.cuda.memory_allocated() / 1e9,
                    "gpu_mem_reserved": torch.cuda.memory_reserved() / 1e9
                })

            running_loss += loss.item() * images.size(0)

        epoch_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}')
        
        # 记录训练指标到W&B
        wandb.log({
            "train_loss": epoch_loss,
            "learning_rate": scheduler.get_last_lr()[0]
        })

        # Validation
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                images = batch["image"].to(device)
                masks = batch["mask"].to(device)

                # 直接使用模型输出,不假设它有'out'键
                outputs = model(images)
                loss = criterion(outputs, masks)
                running_val_loss += loss.item() * images.size(0)

        val_loss = running_val_loss / len(val_loader.dataset)
        print(f'Epoch {epoch+1}/{num_epochs}, Validation Loss: {val_loss:.4f}')
        
        # 记录验证指标
        wandb.log({"val_loss": val_loss})

        # Save the best model
        # 保存最佳模型到W&B
        if val_loss < best_val_loss:
            wandb.save('best_model.pth')
            best_val_loss = val_loss
            torch.save({"state_dict": model.state_dict()}, 'best_model.pth')

        scheduler.step()

    print('Training complete.')

32.6 开始训练

首先,检查设备是否可用,然后运行预先定义好的训练循环。

# check device availability
if torch.mps.is_available:
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# print device
print(f'Found device: {device}')
Found device: mps

训练模型。

train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=5, device=device)

训练过程耗时较长,可以在 W&B 项目中查看训练过程和结果。

32.7 保存训练结果

# Save the model - 使用与前面一致的格式保存
torch.save({"state_dict": model.state_dict()}, 'data/segveg/best_model.pth')