24  Detecting Pathogenic Bacteria with the Model

The source code released with the paper (“Csho33/Bacteria-ID: Source Code and Demos for "Rapid Identification of Pathogenic Bacteria Using Raman Spectroscopy and Deep Learning",” n.d.) provides the weight files of three pretrained models:

  1. pretrained_model.ckpt
  2. clinical_pretrained_model.ckpt
  3. finetuned_model.ckpt

We will use the third, fine-tuned set of weights to predict on the test data and compare the results with those reported in the original paper.

24.1 Loading the Pretrained Model

First, rebuild the model with the same hyperparameters and load the weights.

from resnet import ResNet
import os
import torch

# CNN parameters
layers = 6
hidden_size = 100
block_size = 2
hidden_sizes = [hidden_size] * layers
num_blocks = [block_size] * layers
input_dim = 1000
in_channels = 64
n_classes = 30 # 30 isolate-level classes (strains)


# Rebuild the model before loading the trained weights
cnn = ResNet(hidden_sizes, num_blocks, input_dim=input_dim,
                in_channels=in_channels, n_classes=n_classes)

# Select the device for computation
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Load the fine-tuned weights (map_location keeps tensors on the CPU while loading)
cnn.load_state_dict(torch.load('./finetuned_model.ckpt', 
        map_location=lambda storage, loc: storage))

# Move the model to the selected device
cnn.to(device)
ResNet(
  (conv1): Conv1d(1, 64, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
  (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (encoder): Sequential(
    (0): Sequential(
      (0): ResidualBlock(
        (conv1): Conv1d(64, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential(
          (0): Conv1d(64, 100, kernel_size=(1,), stride=(1,), bias=False)
          (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): ResidualBlock(
        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
    )
    (1): Sequential(
      (0): ResidualBlock(
        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(2,), padding=(2,), bias=False)
        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential(
          (0): Conv1d(100, 100, kernel_size=(1,), stride=(2,), bias=False)
          (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): ResidualBlock(
        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
    )
    (2): Sequential(
      (0): ResidualBlock(
        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(2,), padding=(2,), bias=False)
        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential(
          (0): Conv1d(100, 100, kernel_size=(1,), stride=(2,), bias=False)
          (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): ResidualBlock(
        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
    )
    (3): Sequential(
      (0): ResidualBlock(
        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(2,), padding=(2,), bias=False)
        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential(
          (0): Conv1d(100, 100, kernel_size=(1,), stride=(2,), bias=False)
          (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): ResidualBlock(
        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
    )
    (4): Sequential(
      (0): ResidualBlock(
        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(2,), padding=(2,), bias=False)
        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential(
          (0): Conv1d(100, 100, kernel_size=(1,), stride=(2,), bias=False)
          (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): ResidualBlock(
        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
    )
    (5): Sequential(
      (0): ResidualBlock(
        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(2,), padding=(2,), bias=False)
        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential(
          (0): Conv1d(100, 100, kernel_size=(1,), stride=(2,), bias=False)
          (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): ResidualBlock(
        (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
        (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
        (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
    )
  )
  (linear): Linear(in_features=3200, out_features=30, bias=True)
)
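
As a quick sanity check after loading the weights, we can count the model's trainable parameters (a small sketch):

# Count the trainable parameters of the reconstructed model
n_params = sum(p.numel() for p in cnn.parameters() if p.requires_grad)
print('Trainable parameters: {:,}'.format(n_params))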

24.2 Dissecting the Model Structure

The torchviz library visualizes the computation graph of a PyTorch model. The make_dot function builds a graph of every operation and tensor in the model; for large models this graph becomes very complex.

from torchviz import make_dot

y = cnn(torch.randn(4, 1, 1000).to(device))  # pass a random batch through the model
make_dot(y, params=dict(cnn.named_parameters()))
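
If the inline graph is unwieldy, the Digraph object returned by make_dot can be written to disk instead (a sketch; the filename cnn_graph is arbitrary and the graphviz binaries must be installed):

# Render the computation graph to a PNG file via graphviz
dot = make_dot(y, params=dict(cnn.named_parameters()))
dot.render('cnn_graph', format='png')  # writes cnn_graph.png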

The full network graph is too large to display comfortably, so let us look at the description in the paper instead.

Figure 1: CNN architecture (block diagram; see the paper)

The CNN architecture is adapted from the Resnet architecture37 that has been widely successful across a range of computer vision tasks. It consists of an initial convolution layer followed by 6 residual layers and a final fully connected classification layer — a block diagram can be seen in Fig. 1. The residual layers contain shortcut connections between the input and output of each residual block, allowing for better gradient propagation and stable training (refer to reference 37 for details). Each residual layer contains 4 convolutional layers, so the total depth of the network is 26 layers. The initial convolution layer has 64 convolutional filters, while each of the hidden layers has 100 filters. These architecture hyperparameters were selected via grid search using one training and validation split on the isolate classification task. We also experimented with simple MLP (multi-layer perceptron) and CNN architectures but found that the Resnet-based architecture performed best.

In short: the architecture is a 26-layer 1D ResNet adapted from reference 37, consisting of an initial convolution layer with 64 filters, six residual layers with 100 filters each (two residual blocks, i.e. four convolutional layers, per residual layer), and a final fully connected classification layer. The residual layers contain shortcut connections between the input and output of each residual block, which improve gradient propagation and stabilize training. These hyperparameters were selected by grid search on one training/validation split of the isolate classification task, and the ResNet-based design outperformed simple MLP and CNN baselines.

In PyTorch, .named_modules() returns an iterator that recursively yields every module in the model, including the model itself and all of its submodules, which may include many internal layers you are not interested in. To print only the main levels, you can check the module's type, or whether its name contains a separator (a dot usually separates nested submodule names): a name without a dot belongs to a top-level child, and the empty name is the model itself. The code below prints only the root module, i.e. the empty name.

for name, module in cnn.named_modules():
    # The empty name is the model itself; names without a dot are top-level children; dotted names are nested submodules.
    if name == '':
        print(module)
(Output: the full ResNet structure, identical to the printout in Section 24.1, omitted here.)
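
If only the direct (top-level) children are needed, PyTorch's named_children() avoids the name filtering altogether:

# named_children() yields only the direct child modules
for name, module in cnn.named_children():
    print(name, '->', type(module).__name__)
# conv1 -> Conv1d
# bn1 -> BatchNorm1d
# encoder -> Sequential
# linear -> Linear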

The structure printed above describes a one-dimensional convolutional ResNet designed for 1D data. The network consists of multiple residual blocks, each composed of two convolutional layers plus a shortcut connection.

24.2.1 Architecture Overview

  1. Initial convolution and batch normalization layers

    • conv1: a convolution layer mapping 1 input channel (typically single-channel signal data) to 64 output channels, with kernel size 5, stride 1, and padding 2.
    • bn1: batch normalization over the 64 output channels.
  2. Encoder (encoder)

    • encoder is a stack of 6 Sequential modules, each containing two residual blocks (ResidualBlock).
  3. Residual blocks (ResidualBlock)

    • Each residual block contains two convolutional layers and a shortcut connection (a minimal sketch follows this list):
      • conv1 and conv2: convolution layers with kernel size 5, padding 2, and no bias.
      • bn1 and bn2: batch normalization after each convolution layer.
      • shortcut: when the input and output differ in channel count or stride, a convolution plus batch normalization adjusts the shape; otherwise it is the identity.
  4. Linear layer (linear)

    • linear: maps the encoder's output features to 30 outputs, used here for classification.
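
The actual ResidualBlock class lives in the repository's resnet.py; a minimal sketch consistent with the printed structure might look like this (the ReLU placement is our assumption, following standard ResNet practice):

import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    # Sketch of a 1D residual block matching the printed structure
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=5,
                               stride=stride, padding=2, bias=False)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=5,
                               stride=1, padding=2, bias=False)
        self.bn2 = nn.BatchNorm1d(out_channels)
        # A 1x1 convolution on the shortcut only when the shape changes
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm1d(out_channels),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + self.shortcut(x)  # residual connection
        return F.relu(out)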

24.2.2 Detailed Structure

  • First convolution and batch normalization layers

    (conv1): Conv1d(1, 64, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
    (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  • Encoder (encoder)

    • Each Sequential contains two residual blocks. The convolution and batch normalization layers inside a residual block are configured as follows:

      (0): Sequential(
        (0): ResidualBlock(
          (conv1): Conv1d(64, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
          (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
          (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (shortcut): Sequential(
            (0): Conv1d(64, 100, kernel_size=(1,), stride=(1,), bias=False)
            (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): ResidualBlock(
          (conv1): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
          (bn1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
          (bn2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (shortcut): Sequential()
        )
      )
  • Final linear layer

    (linear): Linear(in_features=3200, out_features=30, bias=True)

24.2.2.1 Key Points

  1. Residual connections: shortcut connections mitigate the vanishing-gradient problem and make deeper networks trainable.
  2. Convolutional layers: stacked convolution layers, mostly with kernel size 5, extract the features.
  3. Batch normalization: a batch normalization layer after every convolution improves training stability and speed.
  4. Linear layer: the final linear layer maps the features to 30 outputs for classification or other tasks.

This ResNet variant is a fairly deep one-dimensional convolutional network, well suited to sequence or time-series data, with strong feature extraction and classification capability.

24.3 Making Predictions with the Model

Now we use the trained model to make predictions and report the per-strain accuracy. The number should be close to the 82.2% reported in Fig. 2 of the paper, but it will not match exactly because the fine-tuning dataset was randomly sampled during fine-tuning.

import numpy as np

# Load the test data
X = np.load('./data/raman/X_test.npy')
y = np.load('./data/raman/y_test.npy')

# Print the data shapes
print(X.shape, y.shape)
(3000, 1000) (3000,)

In this example we do not use a DataLoader; instead we convert the entire dataset X into a tensor and pass it through the model in a single forward call. A batched alternative is sketched after this block.

cnn.eval()

X_tensor = torch.tensor(X, dtype=torch.float32)
X_tensor = X_tensor.unsqueeze(1)  # add a channel dimension: (3000, 1000) -> (3000, 1, 1000)
X_tensor = X_tensor.to(device)

with torch.no_grad():
    preds = cnn(X_tensor)
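
For larger test sets, a batched version with DataLoader keeps device memory bounded; a minimal sketch that yields equivalent predictions:

from torch.utils.data import DataLoader, TensorDataset

# Batched inference: useful when the test set does not fit on the device
dataset = TensorDataset(torch.tensor(X, dtype=torch.float32).unsqueeze(1))
loader = DataLoader(dataset, batch_size=256, shuffle=False)

batches = []
with torch.no_grad():
    for (xb,) in loader:
        batches.append(cnn(xb.to(device)).cpu())
preds_batched = torch.cat(batches)
print(preds_batched.shape)  # torch.Size([3000, 30])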

Print the prediction accuracy.

# Compute accuracy: compare argmax predictions with the ground-truth labels
y_hat = preds.argmax(dim=1).cpu().numpy()
acc = (y_hat == y).mean()
print('Accuracy: {:0.1f}%'.format(100*acc))
Accuracy: 75.9%
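
The paper's figure reports accuracy per strain; a small sketch that breaks our result down by true class (the strain names are loaded in the next section):

# Accuracy within each true class (strain)
for c in np.unique(y):
    mask = (y == c)
    print('class {:2d}: {:5.1f}%'.format(int(c), 100 * (y_hat[mask] == c).mean()))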

24.4 Plotting the Confusion Matrix

24.4.1 Reading the Strain Names

The config.py file defines the strain names. We read those definitions and use them to redraw the confusion matrix.

import config

# Display order of the strains (as used in the paper's figure)
order = config.ORDER

# Mapping from class index to strain name
strains = config.STRAINS

# Print the display order
print(order)

# Print the strain names
print(strains)
[16, 17, 14, 18, 15, 20, 21, 24, 23, 26, 27, 28, 29, 25, 6, 7, 5, 3, 4, 9, 10, 2, 8, 11, 22, 19, 12, 13, 0, 1]
{0: 'C. albicans', 1: 'C. glabrata', 2: 'K. aerogenes', 3: 'E. coli 1', 4: 'E. coli 2', 5: 'E. faecium', 6: 'E. faecalis 1', 7: 'E. faecalis 2', 8: 'E. cloacae', 9: 'K. pneumoniae 1', 10: 'K. pneumoniae 2', 11: 'P. mirabilis', 12: 'P. aeruginosa 1', 13: 'P. aeruginosa 2', 14: 'MSSA 1', 15: 'MSSA 3', 16: 'MRSA 1 (isogenic)', 17: 'MRSA 2', 18: 'MSSA 2', 19: 'S. enterica', 20: 'S. epidermidis', 21: 'S. lugdunensis', 22: 'S. marcescens', 23: 'S. pneumoniae 2', 24: 'S. pneumoniae 1', 25: 'S. sanguinis', 26: 'Group A Strep.', 27: 'Group B Strep.', 28: 'Group C Strep.', 29: 'Group G Strep.'}

The numeric class labels in y and y_hat are arranged according to order and then converted to the strain names in STRAINS before drawing the confusion matrix.

from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# Compute the confusion matrix with rows/columns in the paper's display order
conf_matrix = confusion_matrix(y, y_hat, labels=order)

# Strain names in display order
label_names = [strains[i] for i in order]

# Plot the confusion matrix with strain names
plt.figure(figsize=(10, 8))

# Create the heatmap
ax = sns.heatmap(conf_matrix, 
            annot=True, 
            fmt='d', 
            cmap='YlGnBu',
            xticklabels=label_names,
            yticklabels=label_names)

# Move the x-axis labels to the top
ax.xaxis.set_ticks_position('top')
ax.xaxis.set_label_position('top')

plt.xticks(rotation=45, ha='left')
plt.yticks(rotation=0)
plt.xlabel('Predicted')
plt.ylabel('True')

# Adjust the layout so labels are not cut off
plt.tight_layout()
plt.show()
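
The paper's heatmap shows percentages rather than raw counts. scikit-learn (0.22+) can row-normalize the confusion matrix directly; a sketch:

# Row-normalized confusion matrix: each row sums to 1 over the true class
conf_norm = confusion_matrix(y, y_hat, labels=order, normalize='true')

plt.figure(figsize=(10, 8))
sns.heatmap(100 * conf_norm, cmap='YlGnBu',
            xticklabels=label_names, yticklabels=label_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.tight_layout()
plt.show()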

24.5 Prediction Details

24.5.1 Switching Between Modes

In PyTorch, you can check whether a model is currently in training or evaluation mode through its .training attribute. This boolean is True in training mode and False in evaluation mode (that is, during inference).

Calling model.eval() switches the model to evaluation mode, disabling training-specific behavior in layers such as Dropout and BatchNorm. Correspondingly, model.train() switches it back to training mode.

In practice, it is important to call model.eval() before evaluating, validating, or testing a model in order to get correct predictions.
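
A quick check of the flag as the mode is toggled:

# Inspect the .training flag as the mode changes
cnn.train()
print(cnn.training)   # True: training mode
cnn.eval()
print(cnn.training)   # False: evaluation (inference) mode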

# Switch the model to evaluation mode
cnn.eval()
(Output: cnn.eval() returns the model itself, so the full ResNet structure is printed again; identical to the printout above and omitted here.)

24.5.2 The Model's Output Format

In PyTorch, a model's submodules can be traversed with named_modules(), which returns an iterator over the names and module objects of all submodules. To inspect the last five named modules, convert the iterator to a list and take the last five entries.

The code below prints the names and structures of the last five modules. If the model has fewer than five submodules, the code still works and simply returns all of them.

import torch.nn as nn

def print_last_five_modules(model):
    # model can be an instance of any class derived from nn.Module

    # Collect all named modules into a list
    named_modules_list = list(model.named_modules())

    # Take the last five named modules
    last_five_named_modules = named_modules_list[-5:]

    # Print their names and structures
    for name, module in last_five_named_modules:
        print(name, '->', module)

print_last_five_modules(cnn)
encoder.5.1.bn1 -> BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
encoder.5.1.conv2 -> Conv1d(100, 100, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
encoder.5.1.bn2 -> BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
encoder.5.1.shortcut -> Sequential()
linear -> Linear(in_features=3200, out_features=30, bias=True)

The last module of the network is named linear. Linear(in_features=3200, out_features=30, bias=True) declares a linear layer (also called a fully connected or dense layer), which applies a linear mapping to the input features. The parameters mean the following:

  • in_features=3200: the number of input features; the layer expects each input to be 3200-dimensional. If this were the first layer of a network, each input sample would have to be a 1D tensor with 3200 elements; here, it means the preceding layer outputs 3200 features (the sketch after this list verifies where 3200 comes from).

  • out_features=30: the number of output features; this layer outputs 30 values. Whatever the input dimensionality, after this linear transformation each sample becomes a 1D tensor with 30 elements.

  • bias=True: the layer includes a bias term, one learnable value per output feature, added after the linear transformation to increase the model's flexibility. With bias=False the layer would have no bias.
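
Where does in_features=3200 come from? The input length is 1000, and encoder stages 1 through 5 each downsample once with a stride-2 convolution; applying the Conv1d output-length formula five times leaves 32 positions across 100 channels:

# Conv1d output length: floor((L_in + 2*padding - kernel_size) / stride) + 1
L = 1000
for _ in range(5):  # encoder stages 1-5 each contain one stride-2 convolution
    L = (L + 2 * 2 - 5) // 2 + 1
print(L, L * 100)  # 32 positions x 100 channels = 3200 flattened features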

# Generate a random input
test = torch.randn(1, 1, 1000).to(device)

# Print the output shape
print(cnn(test).shape)
torch.Size([1, 30])

24.5.3 Extracting the Predictions

The raw predictions preds form a tensor of shape (3000, 30), holding each sample's raw score for each of the 30 classes. Note that these are unnormalized scores, not probabilities; all of the printed values below are negative.

# Print the shape of the predictions
print(preds.shape)

# Print the scores for one sample
print(preds[1, :])
torch.Size([3000, 30])
tensor([ -2.3633, -16.7902, -31.0268, -36.2345, -46.9973, -19.5463, -16.5988,
        -36.1822, -33.4462, -34.4168, -19.3850, -45.2024, -36.0918, -27.8640,
        -21.4089, -48.9349, -37.9030, -30.3979, -41.8903, -18.3937, -24.8710,
        -32.6808, -31.7885, -18.2239, -14.3602, -33.7739, -23.3159, -28.2033,
        -40.5117, -41.5156], device='mps:0')
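
If probabilities are needed, apply a softmax over the class dimension (softmax is shift-invariant, so this works whether the raw outputs are logits or log-probabilities):

import torch.nn.functional as F

# Convert the raw scores into probabilities; each row then sums to 1
probs = F.softmax(preds, dim=1)
print(probs[1].sum().item())  # ~1.0
print(probs[1].max().item())  # probability assigned to the most likely class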

Use torch.argmax to obtain the predicted classes. argmax returns the index of the maximum value, not the maximum value itself.

# Predicted class indices
y_hat = preds.argmax(dim=1).cpu().numpy()

# Print the predictions
print(y_hat)
[ 0  0  0 ... 29 29 29]