31  数据预处理

本章介绍如何使用 Huggingface Datasets 加载 VegAnn 数据集,并对数据进行预处理。我们将展示如何绘制示例图像和掩膜,以及如何配置神经网络模型。

31.1 数据集简介

VegAnn 是一个包含 3,775 张多作物 RGB 图像的集合,旨在增强作物植被分割研究。

这些数据是由包括 Arvalis、INRAe、东京大学、昆士兰大学、NEON 和 EOLAB 等机构的合作提供的。

这些图像涵盖了不同的物候阶段,并在各种照明条件下使用不同的系统和平台捕获。通过聚合来自不同项目和机构的子数据集,VegAnn 代表了广泛的测量条件、作物种类和发育阶段。

  • VegAnn 数据集包含 3775 张图片
  • 图片尺寸为 512*512 像素
  • 对应的二值掩膜中,0 表示土壤和作物残留物(背景),255 表示植被(前景)
  • 该数据集包含 26 种以上作物物种,各物种的图片数量分布不均匀
  • VegAnn 数据集由使用不同采集系统和配置拍摄的户外图像编译而成

VegAnn 项目的数据集可以在 Huggingface 上访问。请参阅以下链接:https://huggingface.co/datasets/simonMadec/VegAnn

# 使用 Huggingface Datasets 加载数据集
from datasets import load_dataset

# VegAnn 只有 train 集
ds = load_dataset("simonMadec/VegAnn", split="train")

print(ds)
Dataset({
    features: ['image', 'mask', 'System', 'Orientation', 'latitude', 'longitude', 'date', 'LocAcc', 'Species', 'Owner', 'Dataset-Name', 'TVT-split1', 'TVT-split2', 'TVT-split3', 'TVT-split4', 'TVT-split5'],
    num_rows: 3775
})

31.2 数据字段

数据集中的每一条记录包含以下字段:

  • id: 每个图像补丁的唯一标识符【注:可将索引设为 id】。
  • System: 用于获取照片的成像系统(例如,手持相机、DHP、无人机)。
  • Orientation: 图像捕捉时相机的方向(例如,正下方,45 度)。
  • latitudelongitude: 拍摄图像的地理坐标。
  • date: 图像获取日期。
  • LocAcc: 位置准确性标志(1 表示高准确性,0 表示低或不确定的准确性)。
  • Species: 图像中展示的作物种类(例如,小麦、玉米、大豆)。
  • Owner: 提供图像的机构或实体(例如,Arvalis,INRAe)。
  • Dataset-Name: 图像来源的子数据集或项目(例如,Phenomobile,Easypcc)。
  • TVT-split1TVT-split5: 表示训练/验证/测试划分配置的字段,便于各种实验设置。
# 统计 ds 中各个 Dataset-Name 的数量
from collections import Counter

dataset_names = ds['Dataset-Name']
counts = Counter(dataset_names)

# Print sorted counts
for name, count in sorted(counts.items()):
    print(f"{name}: {count}")
CAN-Eye: 948
Easypcc: 583
INVITA: 500
Literal: 938
P2S2: 499
Phenomobile: 225
web: 82
Note

给图片添加 id

# 给每一行添加 'id',其值为行号
ds = ds.map(lambda example, idx: { "id": idx, **example}, with_indices=True, num_proc=10)

在 Huggingface Datasets 中,map 方法可以用于对数据集进行转换。在上面的代码中,我们使用 map 方法为数据集的每一行添加一个 id 字段,其值为行号。num_proc 参数表示使用的进程数,可以加快数据处理速度。

**example 语法用于将 example 字典中的所有键值对解包为关键字参数。这样可以将 id 字段添加到每个示例中,而不会影响其他字段。

31.3 绘制示例图像

首先,我们绘制 VegAnn 数据集的一些信息Figure 31.1

# 绘制 VegAnn 数据集的信息
import matplotlib.pyplot as plt

# Get the first image and its metadata
sample = ds[0]
image_data = sample['image']

# Create figure with two subplots
fig, ax = plt.subplots(figsize=(6, 6))

# Display the image
ax.imshow(image_data)

# Add metadata as text annotations in red
metadata_text = f"ID: {0}\n"
metadata_text += f"System: {sample['System']}\n"
metadata_text += f"Orientation: {sample['Orientation']}\n"
metadata_text += f"Location: ({sample['latitude']}, {sample['longitude']})\n"
metadata_text += f"Date: {sample['date']}\n"
metadata_text += f"LocAcc: {sample['LocAcc']}\n"
metadata_text += f"Species: {sample['Species']}\n"
metadata_text += f"Owner: {sample['Owner']}\n"
metadata_text += f"Dataset: {sample['Dataset-Name']}"

# Add text to the plot in red color at bottom left
ax.text(40, 460, metadata_text, color='red', fontsize=10, 
    bbox=dict(facecolor='white', alpha=0.7))

# Set title and remove axis
ax.set_title("Sample Image with Metadata")
ax.axis('off')

plt.tight_layout()
plt.show()
Figure 31.1: VegAnn 数据集的示例图像和元数据

此外,我们还可以绘制 VegAnn 数据集的示例图像和对应掩膜,以便更好地了解数据集的内容Figure 31.2

import matplotlib.pyplot as plt

# 绘制示例图像 
fig, axs = plt.subplots(3, 4, figsize=(7, 6))

for i in range(3):
    for j in range(2):
        idx = i * 2 + j
        # 获取图像和掩膜
        image = ds[idx]['image']
        mask = ds[idx]['mask']
        
        # 绘制原始图像
        axs[i,j*2].imshow(image)
        axs[i,j*2].set_title(f"Image {idx}")
        axs[i,j*2].axis('off')
        
        # 绘制掩膜
        axs[i,j*2+1].imshow(mask, cmap='binary')
        axs[i,j*2+1].set_title(f"Mask {idx}")
        axs[i,j*2+1].axis('off')

fig.suptitle("Sample Images and Masks")
plt.tight_layout()
plt.show()
Figure 31.2: VegAnn 数据集的示例图像和对应掩膜

31.4 扩展数据集类

我们需要扩展 Dataset,定义一个自定义的数据集类来处理从 Hugging Face 加载的数据。

# 定义 VegAnnDataset 类
from torch.utils.data import Dataset
import numpy as np
import torchvision.transforms as transforms

class VegAnnDataset(Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = np.array(item['image'])
        mask = np.array(item['mask'])

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        image = transforms.ToTensor()(image)
        mask = torch.tensor(mask, dtype=torch.long).unsqueeze(0)

        # 返回字典格式,与VegAnnModel的shared_step方法期望的输入格式匹配
        return {"id": idx, "image": image, "mask": mask}

31.5 定义数据增强

我们使用 albumentations 库来定义数据增强的方法。这里我们使用 ResizeNormalize 方法。

# 定义数据增强方法
from albumentations import Compose, Resize, Normalize, HorizontalFlip, RandomRotate90, ColorJitter, ToFloat

# 简化数据增强流程
transform = Compose([
    Resize(512, 512),
    HorizontalFlip(p=0.5),
    RandomRotate90(p=0.5),
    ColorJitter(brightness=0.2, contrast=0.2),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

Compose:这是一个容器,它将多个变换组合在一起,并按顺序应用它们。在这个例子中,Compose 包含了多个变换:

  • Resize(512, 512):将输入图像的尺寸调整为 512x512 像素。这一步骤确保所有输入图像具有相同的尺寸,以便于后续处理。
  • HorizontalFlip(p=0.5):以 50% 的概率对图像进行水平翻转。这有助于增加数据的多样性,从而提高模型的泛化能力。
  • RandomRotate90(p=0.5):以 50% 的概率对图像进行随机旋转 90 度。这有助于模型学习不同角度的特征。
  • ColorJitter(brightness=0.2, contrast=0.2):对图像进行颜色抖动,包括亮度和对比度。这有助于模型学习不同光照条件下的特征。
  • Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):对图像进行归一化处理,使用给定的均值和标准差。这些值通常是基于 ImageNet 数据集计算得出的,适用于大多数自然图像。

图像预处理是计算机视觉任务中的常见步骤,有助于提高模型性能。通过调整图像尺寸和归一化处理,可以使得模型更好地学习图像特征,从而提高模型的训练效果和泛化能力。

Tip

在 Conda 中安装 albumentations 可以使用 conda-forge 源:

conda install -c conda-forge albumentations

对同一张图做数据增强,查看效果。

# Load the first image and mask from the dataset
sample = ds[0]
image = np.array(sample['image'])
mask = np.array(sample['mask'])

# Create a figure with a 4x4 grid (8 pairs of image+mask)
fig, axs = plt.subplots(4, 4, figsize=(8, 8))
fig.suptitle('Data Augmentation Examples', fontsize=12)

# Generate 16 augmented pairs and display them
for i in range(8):
    row = i // 2
    col = (i % 2) * 2  # Each sample uses 2 columns (image + mask)
    
    # Apply augmentation
    augmented = transform(image=image, mask=mask)
    aug_image = augmented['image']
    aug_mask = augmented['mask']

    # un-normalize image
    std = [0.229, 0.224, 0.225]
    mean = [0.485, 0.456, 0.406]
    aug_image = (aug_image * std) + mean
    aug_image = np.clip(aug_image, 0, 1)

    # Plot image
    axs[row, col].imshow(aug_image)
    axs[row, col].axis('off')
    
    # Plot mask
    axs[row, col+1].imshow(aug_mask, cmap='spring')
    axs[row, col+1].axis('off')

# Adjust layout and display
plt.tight_layout()
plt.show()

Note

反归一化

因为 matplotlib.pyplot.imshow() 需要 RGB 数据在 [0,1](float)或 [0,255](int)范围内,而数据增强处理后数值超出了这个范围(如:最小值 -2.1179,最大值 2.64)。所以需要对图像进行反归一化处理,将数据还原到 [0,1] 范围内。

反归一化有两步操作:

  • 乘以标准差 stdaug_image = aug_image * 0.229
  • 加上均值 meanaug_image = aug_image + 0.485

为了确保数据在 [0,1] 范围内,使用 np.clip() 方法进行截断处理:

  • 数据截断处理aug_image = aug_image.clip(0, 1)

31.6 配置数据加载器

这里,我们将数据集分为训练集和验证集,并创建相应的数据加载器。

from torch.utils.data import DataLoader

# 直接使用 datasets 内置的 train_test_split
split_ds = ds.train_test_split(test_size=0.2, seed=42)

# Create custom datasets
train_dataset = VegAnnDataset(split_ds['train'], transform=transform)
val_dataset = VegAnnDataset(split_ds['test'], transform=transform)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=0)
Note

DataLoader 类是 PyTorch 中用于加载数据的实用工具。它可以自动对数据进行批处理、打乱和多线程加载,从而提高训练效率。

要从中读取数据,不能使用 train_loader[0] 这样的方式,而是使用迭代器 iter(train_loader)next(iter(train_loader))

在 Jupyter Notebook 中,将 num_workers 设置为 0 可以避免一些问题(因为 multiprocessing 可能会导致 worker 进程异常退出,陷入无尽的等待)。在生产环境中(*.py),可以根据需要调整 num_workers 的值。

使用 DataLoader 加载一个 batch,显示其中的一些图片。

import torch
import numpy as np
import matplotlib.pyplot as plt

# 取一个 batch
batch = next(iter(train_loader))

# 获取图像和掩码
images = batch['image'][:16]  # 取前 16 张
masks = batch['mask'][:16]

# 反归一化参数
std = np.array([0.229, 0.224, 0.225])
mean = np.array([0.485, 0.456, 0.406])

# 创建 4x4 子图
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
fig.suptitle('Sample Images and Masks', fontsize=12)

# 使用 zip 函数遍历图像和掩码
for ax, img, mask in zip(axes.flat, images, masks):
    # 将张量转换为 NumPy 数组 (H, W, C)
    img = img.permute(1, 2, 0).numpy()
    
    # 反归一化
    img = img * std + mean
    img = np.clip(img, 0, 1)

    # 处理 mask
    mask = mask.squeeze().numpy()
    mask = np.ma.masked_where(mask == 0, mask)

    ax.imshow(img)
    ax.imshow(mask, cmap='spring', alpha=0.4)
    ax.axis('off')

plt.tight_layout()
plt.show()

31.7 小结

本章介绍了如何使用 Huggingface Datasets 加载 VegAnn 数据集,并对数据进行预处理。我们展示了如何绘制示例图像和掩膜,以及如何配置神经网络模型。接下来,我们将使用VegAnn 数据集对 SegVeg 模型进行训练。