Deep Learning Day 24 — J3-1 DenseNet in Practice and Explained

  • 🍨 This post is a learning-log entry from the 🔗 365-day deep learning training camp
  • 🍖 Original author: K同学啊 (tutoring and custom projects available)
  • 🚀 Source: K同学's study circle

Table of Contents

  • Preface
  • 1 My environment
  • 2 Implementing DenseNet in PyTorch
    • 2.1 Preparation
      • 2.1.1 Import libraries
      • 2.1.2 Set up the GPU (use the GPU if available, otherwise the CPU)
      • 2.1.3 Import the data
      • 2.1.4 Visualize the data
      • 2.1.5 Image transforms
      • 2.1.6 Split the dataset
      • 2.1.7 Load the data
      • 2.1.8 Inspect the data
    • 2.2 Build the densenet121 model
    • 2.3 Train the model
      • 2.3.1 Set hyperparameters
      • 2.3.2 Write the training function
      • 2.3.3 Write the test function
      • 2.3.4 Run the training
    • 2.4 Visualize the results
    • 2.5 Predict a specified image
    • 2.6 Evaluate the model
  • 3 Key concepts
    • 3.1 nn.Sequential vs nn.Module: differences and how to choose
      • 3.1.1 nn.Sequential
      • 3.1.2 nn.Module
      • 3.1.3 Comparison
      • 3.1.4 Summary
    • 3.2 Using OrderedDict in Python
  • Summary


Preface

Keywords: implementing DenseNet in PyTorch; nn.Sequential vs nn.Module, differences and how to choose; using OrderedDict in Python

1 My environment

  • OS: Windows 11
  • Language: Python 3.8.6
  • IDE: PyCharm 2020.2.3
  • Deep learning environment:
    torch == 1.9.1+cu111
    torchvision == 0.10.1+cu111
    TensorFlow 2.10.1
  • GPU: NVIDIA GeForce RTX 4070

2 Implementing DenseNet in PyTorch

2.1 Preparation

2.1.1 Import libraries


import torch
import torch.nn as nn
import time
import copy
from torchvision import transforms, datasets
from pathlib import Path
from PIL import Image
import torchsummary as summary
import torch.nn.functional as F
from collections import OrderedDict
import re
import torch.utils.model_zoo as model_zoo
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # render CJK text in labels correctly
plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly
plt.rcParams['figure.dpi'] = 100  # figure resolution
import warnings

warnings.filterwarnings('ignore')  # suppress warnings so they are not printed

2.1.2 Set up the GPU (use the GPU if available, otherwise the CPU)

"""前期准备-设置GPU"""
# 如果设备上支持GPU就使用GPU,否则使用CPU
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print("Using {} device".format(device))

Output

Using cuda device

2.1.3 Import the data

'''Preparation - import the data'''
data_dir = r"D:\DeepLearning\data\BreastCancer"
data_dir = Path(data_dir)

data_paths = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[-1] for path in data_paths]
print(classeNames)

Output

['.DS_Store', '0', '1']
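
Note that '.DS_Store' (a macOS metadata file) shows up among the class names, because glob('*') matches files as well as directories. datasets.ImageFolder below only treats directories as classes, but the name list itself can be cleaned with a small filter — a minimal sketch, reusing the same variable names:

# keep only sub-directories as class names, skipping stray files such as .DS_Store
data_paths = [p for p in data_dir.glob('*') if p.is_dir()]
classeNames = [p.name for p in data_paths]
print(classeNames)  # expected: ['0', '1']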

2.1.4 Visualize the data

'''Preparation - visualize the data'''
subfolder = Path(data_dir) / "1"
image_files = list(p.resolve() for p in subfolder.glob('*') if p.suffix in [".jpg", ".png", ".jpeg"])
plt.figure(figsize=(10, 6))
for i, image_file in enumerate(image_files[:12]):
    ax = plt.subplot(3, 4, i + 1)
    img = Image.open(str(image_file))
    plt.imshow(img)
    plt.axis("off")
# show the figure
plt.tight_layout()
plt.show()

[Figure: 12 sample images from class "1"]

2.1.5 Image transforms

'''Preparation - image transforms'''
total_datadir = data_dir

# For more on transforms.Compose, see: https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # resize every input image to a uniform size
    transforms.ToTensor(),  # convert a PIL Image or numpy.ndarray to a tensor and scale values to [0, 1]
    transforms.Normalize(  # standardize --> approximately a standard normal (Gaussian) distribution, which helps the model converge
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])  # mean=[0.485,0.456,0.406] and std=[0.229,0.224,0.225] were computed from random samples of the ImageNet dataset
])
total_data = datasets.ImageFolder(total_datadir, transform=train_transforms)
print(total_data)
print(total_data.class_to_idx)

Output

Dataset ImageFolder
    Number of datapoints: 13403
    Root location: D:\DeepLearning\data\BreastCancer
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=None)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )
{'0': 0, '1': 1}

2.1.6 Split the dataset

'''Preparation - split the dataset'''
train_size = int(0.8 * len(total_data))  # training-set size: 80% of the total length, cast to int
test_size = len(total_data) - train_size  # test-set size: the total length minus the training-set size
# torch.utils.data.random_split() randomly splits total_data into subsets of the given sizes ([train_size, test_size])
# and assigns the results to train_dataset and test_dataset.
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
print("train_dataset={}\ntest_dataset={}".format(train_dataset, test_dataset))
print("train_size={}\ntest_size={}".format(train_size, test_size))

Output

train_dataset=<torch.utils.data.dataset.Subset object at 0x000001AB3AD06BE0>
test_dataset=<torch.utils.data.dataset.Subset object at 0x000001AB3AD06B20>
train_size=10722
test_size=2681
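
random_split draws a fresh random permutation each run, so the split differs between runs. If reproducibility matters, a seeded generator can be passed — a minimal sketch; the seed value 42 is arbitrary:

train_dataset, test_dataset = torch.utils.data.random_split(
    total_data, [train_size, test_size],
    generator=torch.Generator().manual_seed(42))  # fix the RNG used for the split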

2.1.7 Load the data

'''Preparation - load the data'''
batch_size = 32

train_dl = torch.utils.data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=4)
test_dl = torch.utils.data.DataLoader(test_dataset,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=4)

2.1.8 Inspect the data

'''Preparation - inspect the data'''
for X, y in test_dl:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break

Output

Shape of X [N, C, H, W]:  torch.Size([32, 3, 224, 224])
Shape of y:  torch.Size([32]) torch.int64

2.2 Build the densenet121 model

"""构建DenseNet网络"""
# 这里我们采用了Pytorch的框架来实现DenseNet,
# 首先实现DenseBlock中的内部结构,这里是BN+ReLU+1×1Conv+BN+ReLU+3×3Conv结构,最后也加入dropout层用于训练过程。
class _DenseLayer(nn.Sequential):
    """Basic unit of DenseBlock (using bottleneck layer) """

    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * growth_rate,
                                           kernel_size=1, stride=1, bias=False))
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=False))
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return torch.cat([x, new_features], 1)


# The DenseBlock module: dense connectivity, so the number of input features grows linearly with depth:
class _DenseBlock(nn.Sequential):
    """DenseBlock """

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
            self.add_module('denselayer%d' % (i + 1), layer)


# The Transition layer: essentially a 1x1 convolution followed by average pooling:
class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))


# Finally, the DenseNet network itself:
class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 3 or 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
            (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
    """

    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=24, bn_size=4, compression=0.5, drop_rate=0,
                 num_classes=1000):
        super(DenseNet, self).__init__()

        # First Conv2d
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        ]))


        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers, num_features, bn_size, growth_rate, drop_rate)
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features += num_layers * growth_rate
            if i != len(block_config) - 1:
                transition = _Transition(num_input_features=num_features,
                                         num_output_features=int(num_features * compression))
                self.features.add_module('transition%d' % (i + 1), transition)
                num_features = int(num_features * compression)

        # Final bn+relu
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))
        self.features.add_module('relu5', nn.ReLU(inplace=True))

        # classification layer
        self.classifier = nn.Linear(num_features, num_classes)

        # params initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        # global average pooling; the fixed kernel size of 7 assumes 224x224 inputs (7x7 final feature maps)
        out = F.avg_pool2d(features, 7, stride=1).view(features.size(0), -1)
        out = self.classifier(out)
        return out



model_urls = {
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth'}


def densenet121(pretrained=False, **kwargs):
    """DenseNet121"""
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs)
    if pretrained:
        # '.'s are no longer allowed in module names, but previous _DenseLayer
        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
        # They are also in the checkpoints in model_urls. This pattern is used
        # to find such keys.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = model_zoo.load_url(model_urls['densenet121'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        model.load_state_dict(state_dict)
    return model

"""搭建densenet121模型"""
# model = densenet121().to(device)  
model = densenet121(True).to(device)  # 使用预训练模型
print(model)
print(summary.summary(model, (3, 224, 224)))  # 查看模型的参数量以及相关指标
    

Output

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
       BatchNorm2d-5           [-1, 64, 56, 56]             128
              ReLU-6           [-1, 64, 56, 56]               0
            Conv2d-7          [-1, 128, 56, 56]           8,192
       BatchNorm2d-8          [-1, 128, 56, 56]             256
              ReLU-9          [-1, 128, 56, 56]               0
           Conv2d-10           [-1, 32, 56, 56]          36,864
      BatchNorm2d-11           [-1, 96, 56, 56]             192
             ReLU-12           [-1, 96, 56, 56]               0
           Conv2d-13          [-1, 128, 56, 56]          12,288
      BatchNorm2d-14          [-1, 128, 56, 56]             256
             ReLU-15          [-1, 128, 56, 56]               0
           Conv2d-16           [-1, 32, 56, 56]          36,864
      BatchNorm2d-17          [-1, 128, 56, 56]             256
             ReLU-18          [-1, 128, 56, 56]               0
           Conv2d-19          [-1, 128, 56, 56]          16,384
      BatchNorm2d-20          [-1, 128, 56, 56]             256
             ReLU-21          [-1, 128, 56, 56]               0
           Conv2d-22           [-1, 32, 56, 56]          36,864
      BatchNorm2d-23          [-1, 160, 56, 56]             320
             ReLU-24          [-1, 160, 56, 56]               0
           Conv2d-25          [-1, 128, 56, 56]          20,480
      BatchNorm2d-26          [-1, 128, 56, 56]             256
             ReLU-27          [-1, 128, 56, 56]               0
           Conv2d-28           [-1, 32, 56, 56]          36,864
      BatchNorm2d-29          [-1, 192, 56, 56]             384
             ReLU-30          [-1, 192, 56, 56]               0
           Conv2d-31          [-1, 128, 56, 56]          24,576
      BatchNorm2d-32          [-1, 128, 56, 56]             256
             ReLU-33          [-1, 128, 56, 56]               0
           Conv2d-34           [-1, 32, 56, 56]          36,864
      BatchNorm2d-35          [-1, 224, 56, 56]             448
             ReLU-36          [-1, 224, 56, 56]               0
           Conv2d-37          [-1, 128, 56, 56]          28,672
      BatchNorm2d-38          [-1, 128, 56, 56]             256
             ReLU-39          [-1, 128, 56, 56]               0
           Conv2d-40           [-1, 32, 56, 56]          36,864
      BatchNorm2d-41          [-1, 256, 56, 56]             512
             ReLU-42          [-1, 256, 56, 56]               0
           Conv2d-43          [-1, 128, 56, 56]          32,768
        AvgPool2d-44          [-1, 128, 28, 28]               0
      BatchNorm2d-45          [-1, 128, 28, 28]             256
             ReLU-46          [-1, 128, 28, 28]               0
           Conv2d-47          [-1, 128, 28, 28]          16,384
      BatchNorm2d-48          [-1, 128, 28, 28]             256
             ReLU-49          [-1, 128, 28, 28]               0
           Conv2d-50           [-1, 32, 28, 28]          36,864
      BatchNorm2d-51          [-1, 160, 28, 28]             320
             ReLU-52          [-1, 160, 28, 28]               0
           Conv2d-53          [-1, 128, 28, 28]          20,480
      BatchNorm2d-54          [-1, 128, 28, 28]             256
             ReLU-55          [-1, 128, 28, 28]               0
           Conv2d-56           [-1, 32, 28, 28]          36,864
      BatchNorm2d-57          [-1, 192, 28, 28]             384
             ReLU-58          [-1, 192, 28, 28]               0
           Conv2d-59          [-1, 128, 28, 28]          24,576
      BatchNorm2d-60          [-1, 128, 28, 28]             256
             ReLU-61          [-1, 128, 28, 28]               0
           Conv2d-62           [-1, 32, 28, 28]          36,864
      BatchNorm2d-63          [-1, 224, 28, 28]             448
             ReLU-64          [-1, 224, 28, 28]               0
           Conv2d-65          [-1, 128, 28, 28]          28,672
      BatchNorm2d-66          [-1, 128, 28, 28]             256
             ReLU-67          [-1, 128, 28, 28]               0
           Conv2d-68           [-1, 32, 28, 28]          36,864
      BatchNorm2d-69          [-1, 256, 28, 28]             512
             ReLU-70          [-1, 256, 28, 28]               0
           Conv2d-71          [-1, 128, 28, 28]          32,768
      BatchNorm2d-72          [-1, 128, 28, 28]             256
             ReLU-73          [-1, 128, 28, 28]               0
           Conv2d-74           [-1, 32, 28, 28]          36,864
      BatchNorm2d-75          [-1, 288, 28, 28]             576
             ReLU-76          [-1, 288, 28, 28]               0
           Conv2d-77          [-1, 128, 28, 28]          36,864
      BatchNorm2d-78          [-1, 128, 28, 28]             256
             ReLU-79          [-1, 128, 28, 28]               0
           Conv2d-80           [-1, 32, 28, 28]          36,864
      BatchNorm2d-81          [-1, 320, 28, 28]             640
             ReLU-82          [-1, 320, 28, 28]               0
           Conv2d-83          [-1, 128, 28, 28]          40,960
      BatchNorm2d-84          [-1, 128, 28, 28]             256
             ReLU-85          [-1, 128, 28, 28]               0
           Conv2d-86           [-1, 32, 28, 28]          36,864
      BatchNorm2d-87          [-1, 352, 28, 28]             704
             ReLU-88          [-1, 352, 28, 28]               0
           Conv2d-89          [-1, 128, 28, 28]          45,056
      BatchNorm2d-90          [-1, 128, 28, 28]             256
             ReLU-91          [-1, 128, 28, 28]               0
           Conv2d-92           [-1, 32, 28, 28]          36,864
      BatchNorm2d-93          [-1, 384, 28, 28]             768
             ReLU-94          [-1, 384, 28, 28]               0
           Conv2d-95          [-1, 128, 28, 28]          49,152
      BatchNorm2d-96          [-1, 128, 28, 28]             256
             ReLU-97          [-1, 128, 28, 28]               0
           Conv2d-98           [-1, 32, 28, 28]          36,864
      BatchNorm2d-99          [-1, 416, 28, 28]             832
            ReLU-100          [-1, 416, 28, 28]               0
          Conv2d-101          [-1, 128, 28, 28]          53,248
     BatchNorm2d-102          [-1, 128, 28, 28]             256
            ReLU-103          [-1, 128, 28, 28]               0
          Conv2d-104           [-1, 32, 28, 28]          36,864
     BatchNorm2d-105          [-1, 448, 28, 28]             896
            ReLU-106          [-1, 448, 28, 28]               0
          Conv2d-107          [-1, 128, 28, 28]          57,344
     BatchNorm2d-108          [-1, 128, 28, 28]             256
            ReLU-109          [-1, 128, 28, 28]               0
          Conv2d-110           [-1, 32, 28, 28]          36,864
     BatchNorm2d-111          [-1, 480, 28, 28]             960
            ReLU-112          [-1, 480, 28, 28]               0
          Conv2d-113          [-1, 128, 28, 28]          61,440
     BatchNorm2d-114          [-1, 128, 28, 28]             256
            ReLU-115          [-1, 128, 28, 28]               0
          Conv2d-116           [-1, 32, 28, 28]          36,864
     BatchNorm2d-117          [-1, 512, 28, 28]           1,024
            ReLU-118          [-1, 512, 28, 28]               0
          Conv2d-119          [-1, 256, 28, 28]         131,072
       AvgPool2d-120          [-1, 256, 14, 14]               0
     BatchNorm2d-121          [-1, 256, 14, 14]             512
            ReLU-122          [-1, 256, 14, 14]               0
          Conv2d-123          [-1, 128, 14, 14]          32,768
     BatchNorm2d-124          [-1, 128, 14, 14]             256
            ReLU-125          [-1, 128, 14, 14]               0
          Conv2d-126           [-1, 32, 14, 14]          36,864
     BatchNorm2d-127          [-1, 288, 14, 14]             576
            ReLU-128          [-1, 288, 14, 14]               0
          Conv2d-129          [-1, 128, 14, 14]          36,864
     BatchNorm2d-130          [-1, 128, 14, 14]             256
            ReLU-131          [-1, 128, 14, 14]               0
          Conv2d-132           [-1, 32, 14, 14]          36,864
     BatchNorm2d-133          [-1, 320, 14, 14]             640
            ReLU-134          [-1, 320, 14, 14]               0
          Conv2d-135          [-1, 128, 14, 14]          40,960
     BatchNorm2d-136          [-1, 128, 14, 14]             256
            ReLU-137          [-1, 128, 14, 14]               0
          Conv2d-138           [-1, 32, 14, 14]          36,864
     BatchNorm2d-139          [-1, 352, 14, 14]             704
            ReLU-140          [-1, 352, 14, 14]               0
          Conv2d-141          [-1, 128, 14, 14]          45,056
     BatchNorm2d-142          [-1, 128, 14, 14]             256
            ReLU-143          [-1, 128, 14, 14]               0
          Conv2d-144           [-1, 32, 14, 14]          36,864
     BatchNorm2d-145          [-1, 384, 14, 14]             768
            ReLU-146          [-1, 384, 14, 14]               0
          Conv2d-147          [-1, 128, 14, 14]          49,152
     BatchNorm2d-148          [-1, 128, 14, 14]             256
            ReLU-149          [-1, 128, 14, 14]               0
          Conv2d-150           [-1, 32, 14, 14]          36,864
     BatchNorm2d-151          [-1, 416, 14, 14]             832
            ReLU-152          [-1, 416, 14, 14]               0
          Conv2d-153          [-1, 128, 14, 14]          53,248
     BatchNorm2d-154          [-1, 128, 14, 14]             256
            ReLU-155          [-1, 128, 14, 14]               0
          Conv2d-156           [-1, 32, 14, 14]          36,864
     BatchNorm2d-157          [-1, 448, 14, 14]             896
            ReLU-158          [-1, 448, 14, 14]               0
          Conv2d-159          [-1, 128, 14, 14]          57,344
     BatchNorm2d-160          [-1, 128, 14, 14]             256
            ReLU-161          [-1, 128, 14, 14]               0
          Conv2d-162           [-1, 32, 14, 14]          36,864
     BatchNorm2d-163          [-1, 480, 14, 14]             960
            ReLU-164          [-1, 480, 14, 14]               0
          Conv2d-165          [-1, 128, 14, 14]          61,440
     BatchNorm2d-166          [-1, 128, 14, 14]             256
            ReLU-167          [-1, 128, 14, 14]               0
          Conv2d-168           [-1, 32, 14, 14]          36,864
     BatchNorm2d-169          [-1, 512, 14, 14]           1,024
            ReLU-170          [-1, 512, 14, 14]               0
          Conv2d-171          [-1, 128, 14, 14]          65,536
     BatchNorm2d-172          [-1, 128, 14, 14]             256
            ReLU-173          [-1, 128, 14, 14]               0
          Conv2d-174           [-1, 32, 14, 14]          36,864
     BatchNorm2d-175          [-1, 544, 14, 14]           1,088
            ReLU-176          [-1, 544, 14, 14]               0
          Conv2d-177          [-1, 128, 14, 14]          69,632
     BatchNorm2d-178          [-1, 128, 14, 14]             256
            ReLU-179          [-1, 128, 14, 14]               0
          Conv2d-180           [-1, 32, 14, 14]          36,864
     BatchNorm2d-181          [-1, 576, 14, 14]           1,152
            ReLU-182          [-1, 576, 14, 14]               0
          Conv2d-183          [-1, 128, 14, 14]          73,728
     BatchNorm2d-184          [-1, 128, 14, 14]             256
            ReLU-185          [-1, 128, 14, 14]               0
          Conv2d-186           [-1, 32, 14, 14]          36,864
     BatchNorm2d-187          [-1, 608, 14, 14]           1,216
            ReLU-188          [-1, 608, 14, 14]               0
          Conv2d-189          [-1, 128, 14, 14]          77,824
     BatchNorm2d-190          [-1, 128, 14, 14]             256
            ReLU-191          [-1, 128, 14, 14]               0
          Conv2d-192           [-1, 32, 14, 14]          36,864
     BatchNorm2d-193          [-1, 640, 14, 14]           1,280
            ReLU-194          [-1, 640, 14, 14]               0
          Conv2d-195          [-1, 128, 14, 14]          81,920
     BatchNorm2d-196          [-1, 128, 14, 14]             256
            ReLU-197          [-1, 128, 14, 14]               0
          Conv2d-198           [-1, 32, 14, 14]          36,864
     BatchNorm2d-199          [-1, 672, 14, 14]           1,344
            ReLU-200          [-1, 672, 14, 14]               0
          Conv2d-201          [-1, 128, 14, 14]          86,016
     BatchNorm2d-202          [-1, 128, 14, 14]             256
            ReLU-203          [-1, 128, 14, 14]               0
          Conv2d-204           [-1, 32, 14, 14]          36,864
     BatchNorm2d-205          [-1, 704, 14, 14]           1,408
            ReLU-206          [-1, 704, 14, 14]               0
          Conv2d-207          [-1, 128, 14, 14]          90,112
     BatchNorm2d-208          [-1, 128, 14, 14]             256
            ReLU-209          [-1, 128, 14, 14]               0
          Conv2d-210           [-1, 32, 14, 14]          36,864
     BatchNorm2d-211          [-1, 736, 14, 14]           1,472
            ReLU-212          [-1, 736, 14, 14]               0
          Conv2d-213          [-1, 128, 14, 14]          94,208
     BatchNorm2d-214          [-1, 128, 14, 14]             256
            ReLU-215          [-1, 128, 14, 14]               0
          Conv2d-216           [-1, 32, 14, 14]          36,864
     BatchNorm2d-217          [-1, 768, 14, 14]           1,536
            ReLU-218          [-1, 768, 14, 14]               0
          Conv2d-219          [-1, 128, 14, 14]          98,304
     BatchNorm2d-220          [-1, 128, 14, 14]             256
            ReLU-221          [-1, 128, 14, 14]               0
          Conv2d-222           [-1, 32, 14, 14]          36,864
     BatchNorm2d-223          [-1, 800, 14, 14]           1,600
            ReLU-224          [-1, 800, 14, 14]               0
          Conv2d-225          [-1, 128, 14, 14]         102,400
     BatchNorm2d-226          [-1, 128, 14, 14]             256
            ReLU-227          [-1, 128, 14, 14]               0
          Conv2d-228           [-1, 32, 14, 14]          36,864
     BatchNorm2d-229          [-1, 832, 14, 14]           1,664
            ReLU-230          [-1, 832, 14, 14]               0
          Conv2d-231          [-1, 128, 14, 14]         106,496
     BatchNorm2d-232          [-1, 128, 14, 14]             256
            ReLU-233          [-1, 128, 14, 14]               0
          Conv2d-234           [-1, 32, 14, 14]          36,864
     BatchNorm2d-235          [-1, 864, 14, 14]           1,728
            ReLU-236          [-1, 864, 14, 14]               0
          Conv2d-237          [-1, 128, 14, 14]         110,592
     BatchNorm2d-238          [-1, 128, 14, 14]             256
            ReLU-239          [-1, 128, 14, 14]               0
          Conv2d-240           [-1, 32, 14, 14]          36,864
     BatchNorm2d-241          [-1, 896, 14, 14]           1,792
            ReLU-242          [-1, 896, 14, 14]               0
          Conv2d-243          [-1, 128, 14, 14]         114,688
     BatchNorm2d-244          [-1, 128, 14, 14]             256
            ReLU-245          [-1, 128, 14, 14]               0
          Conv2d-246           [-1, 32, 14, 14]          36,864
     BatchNorm2d-247          [-1, 928, 14, 14]           1,856
            ReLU-248          [-1, 928, 14, 14]               0
          Conv2d-249          [-1, 128, 14, 14]         118,784
     BatchNorm2d-250          [-1, 128, 14, 14]             256
            ReLU-251          [-1, 128, 14, 14]               0
          Conv2d-252           [-1, 32, 14, 14]          36,864
     BatchNorm2d-253          [-1, 960, 14, 14]           1,920
            ReLU-254          [-1, 960, 14, 14]               0
          Conv2d-255          [-1, 128, 14, 14]         122,880
     BatchNorm2d-256          [-1, 128, 14, 14]             256
            ReLU-257          [-1, 128, 14, 14]               0
          Conv2d-258           [-1, 32, 14, 14]          36,864
     BatchNorm2d-259          [-1, 992, 14, 14]           1,984
            ReLU-260          [-1, 992, 14, 14]               0
          Conv2d-261          [-1, 128, 14, 14]         126,976
     BatchNorm2d-262          [-1, 128, 14, 14]             256
            ReLU-263          [-1, 128, 14, 14]               0
          Conv2d-264           [-1, 32, 14, 14]          36,864
     BatchNorm2d-265         [-1, 1024, 14, 14]           2,048
            ReLU-266         [-1, 1024, 14, 14]               0
          Conv2d-267          [-1, 512, 14, 14]         524,288
       AvgPool2d-268            [-1, 512, 7, 7]               0
     BatchNorm2d-269            [-1, 512, 7, 7]           1,024
            ReLU-270            [-1, 512, 7, 7]               0
          Conv2d-271            [-1, 128, 7, 7]          65,536
     BatchNorm2d-272            [-1, 128, 7, 7]             256
            ReLU-273            [-1, 128, 7, 7]               0
          Conv2d-274             [-1, 32, 7, 7]          36,864
     BatchNorm2d-275            [-1, 544, 7, 7]           1,088
            ReLU-276            [-1, 544, 7, 7]               0
          Conv2d-277            [-1, 128, 7, 7]          69,632
     BatchNorm2d-278            [-1, 128, 7, 7]             256
            ReLU-279            [-1, 128, 7, 7]               0
          Conv2d-280             [-1, 32, 7, 7]          36,864
     BatchNorm2d-281            [-1, 576, 7, 7]           1,152
            ReLU-282            [-1, 576, 7, 7]               0
          Conv2d-283            [-1, 128, 7, 7]          73,728
     BatchNorm2d-284            [-1, 128, 7, 7]             256
            ReLU-285            [-1, 128, 7, 7]               0
          Conv2d-286             [-1, 32, 7, 7]          36,864
     BatchNorm2d-287            [-1, 608, 7, 7]           1,216
            ReLU-288            [-1, 608, 7, 7]               0
          Conv2d-289            [-1, 128, 7, 7]          77,824
     BatchNorm2d-290            [-1, 128, 7, 7]             256
            ReLU-291            [-1, 128, 7, 7]               0
          Conv2d-292             [-1, 32, 7, 7]          36,864
     BatchNorm2d-293            [-1, 640, 7, 7]           1,280
            ReLU-294            [-1, 640, 7, 7]               0
          Conv2d-295            [-1, 128, 7, 7]          81,920
     BatchNorm2d-296            [-1, 128, 7, 7]             256
            ReLU-297            [-1, 128, 7, 7]               0
          Conv2d-298             [-1, 32, 7, 7]          36,864
     BatchNorm2d-299            [-1, 672, 7, 7]           1,344
            ReLU-300            [-1, 672, 7, 7]               0
          Conv2d-301            [-1, 128, 7, 7]          86,016
     BatchNorm2d-302            [-1, 128, 7, 7]             256
            ReLU-303            [-1, 128, 7, 7]               0
          Conv2d-304             [-1, 32, 7, 7]          36,864
     BatchNorm2d-305            [-1, 704, 7, 7]           1,408
            ReLU-306            [-1, 704, 7, 7]               0
          Conv2d-307            [-1, 128, 7, 7]          90,112
     BatchNorm2d-308            [-1, 128, 7, 7]             256
            ReLU-309            [-1, 128, 7, 7]               0
          Conv2d-310             [-1, 32, 7, 7]          36,864
     BatchNorm2d-311            [-1, 736, 7, 7]           1,472
            ReLU-312            [-1, 736, 7, 7]               0
          Conv2d-313            [-1, 128, 7, 7]          94,208
     BatchNorm2d-314            [-1, 128, 7, 7]             256
            ReLU-315            [-1, 128, 7, 7]               0
          Conv2d-316             [-1, 32, 7, 7]          36,864
     BatchNorm2d-317            [-1, 768, 7, 7]           1,536
            ReLU-318            [-1, 768, 7, 7]               0
          Conv2d-319            [-1, 128, 7, 7]          98,304
     BatchNorm2d-320            [-1, 128, 7, 7]             256
            ReLU-321            [-1, 128, 7, 7]               0
          Conv2d-322             [-1, 32, 7, 7]          36,864
     BatchNorm2d-323            [-1, 800, 7, 7]           1,600
            ReLU-324            [-1, 800, 7, 7]               0
          Conv2d-325            [-1, 128, 7, 7]         102,400
     BatchNorm2d-326            [-1, 128, 7, 7]             256
            ReLU-327            [-1, 128, 7, 7]               0
          Conv2d-328             [-1, 32, 7, 7]          36,864
     BatchNorm2d-329            [-1, 832, 7, 7]           1,664
            ReLU-330            [-1, 832, 7, 7]               0
          Conv2d-331            [-1, 128, 7, 7]         106,496
     BatchNorm2d-332            [-1, 128, 7, 7]             256
            ReLU-333            [-1, 128, 7, 7]               0
          Conv2d-334             [-1, 32, 7, 7]          36,864
     BatchNorm2d-335            [-1, 864, 7, 7]           1,728
            ReLU-336            [-1, 864, 7, 7]               0
          Conv2d-337            [-1, 128, 7, 7]         110,592
     BatchNorm2d-338            [-1, 128, 7, 7]             256
            ReLU-339            [-1, 128, 7, 7]               0
          Conv2d-340             [-1, 32, 7, 7]          36,864
     BatchNorm2d-341            [-1, 896, 7, 7]           1,792
            ReLU-342            [-1, 896, 7, 7]               0
          Conv2d-343            [-1, 128, 7, 7]         114,688
     BatchNorm2d-344            [-1, 128, 7, 7]             256
            ReLU-345            [-1, 128, 7, 7]               0
          Conv2d-346             [-1, 32, 7, 7]          36,864
     BatchNorm2d-347            [-1, 928, 7, 7]           1,856
            ReLU-348            [-1, 928, 7, 7]               0
          Conv2d-349            [-1, 128, 7, 7]         118,784
     BatchNorm2d-350            [-1, 128, 7, 7]             256
            ReLU-351            [-1, 128, 7, 7]               0
          Conv2d-352             [-1, 32, 7, 7]          36,864
     BatchNorm2d-353            [-1, 960, 7, 7]           1,920
            ReLU-354            [-1, 960, 7, 7]               0
          Conv2d-355            [-1, 128, 7, 7]         122,880
     BatchNorm2d-356            [-1, 128, 7, 7]             256
            ReLU-357            [-1, 128, 7, 7]               0
          Conv2d-358             [-1, 32, 7, 7]          36,864
     BatchNorm2d-359            [-1, 992, 7, 7]           1,984
            ReLU-360            [-1, 992, 7, 7]               0
          Conv2d-361            [-1, 128, 7, 7]         126,976
     BatchNorm2d-362            [-1, 128, 7, 7]             256
            ReLU-363            [-1, 128, 7, 7]               0
          Conv2d-364             [-1, 32, 7, 7]          36,864
     BatchNorm2d-365           [-1, 1024, 7, 7]           2,048
            ReLU-366           [-1, 1024, 7, 7]               0
          Linear-367                 [-1, 1000]       1,025,000
================================================================
Total params: 7,978,856
Trainable params: 7,978,856
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 294.58
Params size (MB): 30.44
Estimated Total Size (MB): 325.59
----------------------------------------------------------------
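
One detail worth noting in the summary: the final classification head (Linear-367) still has 1000 outputs, because densenet121() keeps the default num_classes=1000 while this dataset has only two classes ('0' and '1'). Training still works, since CrossEntropyLoss never sees labels above 1, but a common alternative is to replace the head to match the dataset. A minimal sketch (optional, not part of the original run):

# Optional: adapt the pretrained 1000-way head to the 2-class dataset.
# model.classifier is the final nn.Linear defined in DenseNet above.
num_head_features = model.classifier.in_features
model.classifier = nn.Linear(num_head_features, 2).to(device)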

2.3 Train the model

2.3.1 Set hyperparameters

"""训练模型--设置超参数"""
loss_fn = nn.CrossEntropyLoss()  # 创建损失函数,计算实际输出和真实相差多少,交叉熵损失函数,事实上,它就是做图片分类任务时常用的损失函数
learn_rate = 1e-4  # 学习率
optimizer1 = torch.optim.SGD(model.parameters(), lr=learn_rate)# 作用是定义优化器,用来训练时候优化模型参数;其中,SGD表示随机梯度下降,用于控制实际输出y与真实y之间的相差有多大
optimizer2 = torch.optim.Adam(model.parameters(), lr=learn_rate)  
lr_opt = optimizer2
model_opt = optimizer2
# 调用官方动态学习率接口时使用2
lambda1 = lambda epoch : 0.92 ** (epoch // 4)
# optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)
scheduler = torch.optim.lr_scheduler.LambdaLR(lr_opt, lr_lambda=lambda1) #选定调整方法
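
The training loop in 2.3.4 references a commented-out adjust_learning_rate(lr_opt, epoch, learn_rate) that is never defined in this post. A minimal sketch of what such a manual schedule could look like, mirroring the 0.92 ** (epoch // 4) decay of lambda1 above (this is an assumption, not the author's original function):

def adjust_learning_rate(optimizer, epoch, start_lr):
    # manual counterpart of the LambdaLR schedule above:
    # decay the learning rate by 8% every 4 epochs
    lr = start_lr * 0.92 ** (epoch // 4)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr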

2.3.2 Write the training function

"""训练模型--编写训练函数"""
# 训练循环
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # 训练集的大小,一共60000张图片
    num_batches = len(dataloader)  # 批次数目,1875(60000/32)

    train_loss, train_acc = 0, 0  # 初始化训练损失和正确率

    for X, y in dataloader:  # 加载数据加载器,得到里面的 X(图片数据)和 y(真实标签)
        X, y = X.to(device), y.to(device) # 用于将数据存到显卡

        # 计算预测误差
        pred = model(X)  # 网络输出
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失

        # 反向传播
        optimizer.zero_grad()  # 清空过往梯度
        loss.backward()  # 反向传播,计算当前梯度
        optimizer.step()  # 根据梯度更新网络参数

        # 记录acc与loss
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc /= size
    train_loss /= num_batches

    return train_acc, train_loss

2.3.3 Write the test function

"""训练模型--编写测试函数"""
# 测试函数和训练函数大致相同,但是由于不进行梯度下降对网络权重进行更新,所以不需要传入优化器
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)  # 测试集的大小,一共10000张图片
    num_batches = len(dataloader)  # 批次数目,313(10000/32=312.5,向上取整)
    test_loss, test_acc = 0, 0

    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad(): # 测试时模型参数不用更新,所以 no_grad,整个模型参数正向推就ok,不反向更新参数
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            # 计算loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()#统计预测正确的个数

    test_acc /= size
    test_loss /= num_batches

    return test_acc, test_loss

2.3.4 Run the training

"""训练模型--正式训练"""
epochs = 10
train_loss = []
train_acc = []
test_loss = []
test_acc = []
best_test_acc=0

for epoch in range(epochs):
    milliseconds_t1 = int(time.time() * 1000)

    # 更新学习率(使用自定义学习率时使用)
    # adjust_learning_rate(lr_opt, epoch, learn_rate)

    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, model_opt)
    scheduler.step() # 更新学习率(调用官方动态学习率接口时使用)

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # 获取当前的学习率
    lr = lr_opt.state_dict()['param_groups'][0]['lr']

    milliseconds_t2 = int(time.time() * 1000)
    template = ('Epoch:{:2d}, duration:{}ms, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}, Lr:{:.2E}')
    if best_test_acc < epoch_test_acc:
        best_test_acc = epoch_test_acc
        #备份最好的模型
        best_model = copy.deepcopy(model)
        template = (
            'Epoch:{:2d}, duration:{}ms, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}, Lr:{:.2E},Update the best model')
    print(
        template.format(epoch + 1, milliseconds_t2-milliseconds_t1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss, lr))
# 保存最佳模型到文件中
PATH = './best_model.pth'  # 保存的参数文件名
torch.save(model.state_dict(), PATH)
print('Done')

Output (the highest test accuracy reached is Test_acc: 92.7%)

Epoch: 1, duration:74420ms, Train_acc:83.7%, Train_loss:0.902, Test_acc:85.8%,Test_loss:0.345, Lr:1.00E-04,Update the best model
Epoch: 2, duration:72587ms, Train_acc:86.4%, Train_loss:0.329, Test_acc:85.5%,Test_loss:0.343, Lr:1.00E-04
Epoch: 3, duration:72941ms, Train_acc:87.9%, Train_loss:0.292, Test_acc:89.2%,Test_loss:0.262, Lr:1.00E-04,Update the best model
Epoch: 4, duration:74155ms, Train_acc:88.8%, Train_loss:0.279, Test_acc:89.7%,Test_loss:0.248, Lr:1.00E-04,Update the best model
Epoch: 5, duration:75123ms, Train_acc:89.1%, Train_loss:0.265, Test_acc:89.0%,Test_loss:0.277, Lr:1.00E-04
Epoch: 6, duration:74381ms, Train_acc:89.6%, Train_loss:0.255, Test_acc:90.5%,Test_loss:0.249, Lr:1.00E-04,Update the best model
Epoch: 7, duration:73710ms, Train_acc:90.2%, Train_loss:0.243, Test_acc:84.1%,Test_loss:0.369, Lr:1.00E-04
Epoch: 8, duration:73995ms, Train_acc:90.7%, Train_loss:0.230, Test_acc:89.5%,Test_loss:0.250, Lr:1.00E-04
Epoch: 9, duration:73017ms, Train_acc:90.7%, Train_loss:0.223, Test_acc:89.3%,Test_loss:0.263, Lr:1.00E-04
Epoch:10, duration:73960ms, Train_acc:91.2%, Train_loss:0.224, Test_acc:91.6%,Test_loss:0.209, Lr:1.00E-04,Update the best model
Epoch:11, duration:74113ms, Train_acc:91.2%, Train_loss:0.219, Test_acc:90.5%,Test_loss:0.225, Lr:1.00E-04
Epoch:12, duration:73573ms, Train_acc:91.5%, Train_loss:0.213, Test_acc:88.5%,Test_loss:0.273, Lr:1.00E-04
Epoch:13, duration:73206ms, Train_acc:92.2%, Train_loss:0.202, Test_acc:85.1%,Test_loss:0.377, Lr:1.00E-04
Epoch:14, duration:73540ms, Train_acc:92.1%, Train_loss:0.195, Test_acc:91.2%,Test_loss:0.225, Lr:1.00E-04
Epoch:15, duration:73378ms, Train_acc:92.3%, Train_loss:0.192, Test_acc:87.6%,Test_loss:0.796, Lr:1.00E-04
Epoch:16, duration:73195ms, Train_acc:92.5%, Train_loss:0.187, Test_acc:92.5%,Test_loss:0.197, Lr:1.00E-04,Update the best model
Epoch:17, duration:73737ms, Train_acc:93.1%, Train_loss:0.174, Test_acc:92.7%,Test_loss:0.186, Lr:1.00E-04,Update the best model
Epoch:18, duration:73884ms, Train_acc:93.4%, Train_loss:0.171, Test_acc:80.6%,Test_loss:0.463, Lr:1.00E-04
Epoch:19, duration:73239ms, Train_acc:93.2%, Train_loss:0.168, Test_acc:91.2%,Test_loss:0.221, Lr:1.00E-04
Epoch:20, duration:73386ms, Train_acc:93.7%, Train_loss:0.159, Test_acc:92.5%,Test_loss:0.196, Lr:1.00E-04

2.4 Visualize the results

"""训练模型--结果可视化"""
epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

[Figure: training/validation accuracy and loss curves]

2.5 Predict a specified image

def predict_one_image(image_path, model, transform, classes):
    test_img = Image.open(image_path).convert('RGB')
    plt.imshow(test_img)  # display the image being predicted
    plt.show()

    test_img = transform(test_img)
    img = test_img.to(device).unsqueeze(0)

    model.eval()
    output = model(img)

    _, pred = torch.max(output, 1)
    pred_class = classes[pred]
    print(f'Predicted class: {pred_class}')

# load the saved parameters into the model
model.load_state_dict(torch.load(PATH, map_location=device))

"""Predict a specified image"""
classes = list(total_data.class_to_idx)
# predict one image from the dataset; image_files (collected from class "1" in the visualization step) is reused here
predict_one_image(image_path=str(image_files[0]),
                  model=model,
                  transform=train_transforms,
                  classes=classes)

[Figure: the image being predicted]

Output

Predicted class: 0

2.6 Evaluate the model

"""模型评估"""
best_model.eval()
epoch_test_acc, epoch_test_loss = test(test_dl, best_model, loss_fn)
# 查看是否与我们记录的最高准确率一致
print(epoch_test_acc, epoch_test_loss)


Output

0.9268929503916449 0.185508520431107
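
For a binary medical dataset, a per-class breakdown is often more informative than overall accuracy alone. A minimal sketch using only what is already defined above (test_dl, best_model, device):

# count per-class correct predictions on the test set
correct = torch.zeros(2)
total = torch.zeros(2)
best_model.eval()
with torch.no_grad():
    for imgs, target in test_dl:
        imgs, target = imgs.to(device), target.to(device)
        pred = best_model(imgs).argmax(1)
        for cls in range(2):
            mask = target == cls
            total[cls] += mask.sum().item()
            correct[cls] += (pred[mask] == cls).sum().item()
print('per-class accuracy:', (correct / total).tolist())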

3 Key concepts

3.1 nn.Sequential vs nn.Module: differences and how to choose

3.1.1 nn.Sequential

torch.nn.Sequential is a sequential container: modules are added to it in the order they are passed to the constructor. Alternatively, an ordered dict of modules can be passed in. To make this easier to follow, the official documentation gives some examples:

# Example: Sequential

model = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )

# Example: Sequential with OrderedDict
model = nn.Sequential(OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))

3.1.2 nn.Module

Now let's define the same kind of model with Module instead. Below is the template for using Module:

class NetworkName(nn.Module):
    def __init__(self, some_parameters):
        super(NetworkName, self).__init__()
        self.layer1 = nn.Linear(num_input, num_hidden)
        self.layer2 = nn.Sequential(...)
        ...

        # define the layers you need here

    def forward(self, x):  # define the forward pass
        x1 = self.layer1(x)
        x2 = self.layer2(x)
        x = x1 + x2
        ...
        return x

Note that a Module can itself contain Sequential blocks. Module is also very flexible, and that flexibility lives in forward: arbitrarily complex operations can be expressed there directly.

3.1.3 Comparison

For comparison, let's first build a small network the plain way.

class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.predict = torch.nn.Linear(n_hidden, n_output)

    def forward(self, x):
        x = F.relu(self.hidden(x))
        x = self.predict(x)
        return x
net1 = Net(1, 10, 1)

net2 = torch.nn.Sequential(
    torch.nn.Linear(1, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1)
)

Print both nets:

print(net1)
"""
Net (
  (hidden): Linear (1 -> 10)
  (predict): Linear (10 -> 1)
)
"""
print(net2)
"""
Sequential (
  (0): Linear (1 -> 10)
  (1): ReLU ()
  (2): Linear (10 -> 1)
)
"""

Notice that the printed torch.nn.Sequential includes the activation functions automatically.
In net1 the activation is invoked inside forward() rather than defined in __init__, so it does not show up when the network structure is printed.

Looking at the source of torch.nn.Sequential, its forward is:

def forward(self, input):
    for module in self:
        input = module(input)
    return input

As you can see, torch.nn.Sequential's forward simply propagates through the modules in order, so the operations it can express are limited.

3.1.4 Summary

With torch.nn.Module, we can shape the propagation process however the task requires (RNNs, skip connections, and so on), as the sketch below shows.
If you just need to build something quickly and don't require extra custom steps, use torch.nn.Sequential directly.
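
For example, a skip connection cannot be written as a plain Sequential, because forward has to reuse the block's input; with nn.Module it takes one extra line. A minimal sketch, reusing the nn and F imports from section 2.1.1:

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1))

    def forward(self, x):
        return F.relu(self.body(x) + x)  # the skip connection is what Sequential cannot express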

Reference: nn.Sequential vs nn.Module — differences and how to choose

3.2 Using OrderedDict in Python

Many people think of Python dictionaries as unordered, because they are stored by hash. The standard-library module collections, however, provides a subclass,

OrderedDict, which keeps the elements of a dictionary in order. Consider the following example:

import collections

print("Regular dictionary")
d = {}
d['a'] = 'A'
d['b'] = 'B'
d['c'] = 'C'
for k, v in d.items():
    print(k, v)

print("\nOrder dictionary")
d1 = collections.OrderedDict()
d1['a'] = 'A'
d1['b'] = 'B'
d1['c'] = 'C'
d1['1'] = '1'
d1['2'] = '2'
for k, v in d1.items():
    print(k, v)

Output:

Regular dictionary
a A
c C
b B

Order dictionary
a A
b B
c C
1 1
2 2

As you can see, although both store the same elements A, B, C, the OrderedDict keeps them in the order in which they were inserted, so its output comes out in insertion order.

Python also treats two OrderedDict objects with the same items but a different order as two different objects. See the example:

print('Regular dictionary:')
d2 = {}
d2['a'] = 'A'
d2['b'] = 'B'
d2['c'] = 'C'

d3 = {}
d3['c'] = 'C'
d3['a'] = 'A'
d3['b'] = 'B'

print(d2 == d3)

print('\nOrderedDict:')
d4 = collections.OrderedDict()
d4['a'] = 'A'
d4['b'] = 'B'
d4['c'] = 'C'

d5 = collections.OrderedDict()
d5['c'] = 'C'
d5['a'] = 'A'
d5['b'] = 'B'

print(d4 == d5)

Output:
Regular dictionary:
True

OrderedDict:
False

A few more examples:

dd = {'banana': 3, 'apple': 4, 'pear': 1, 'orange': 2}
# sort by key
kd = collections.OrderedDict(sorted(dd.items(), key=lambda t: t[0]))
print(kd)
# sort by value
vd = collections.OrderedDict(sorted(dd.items(), key=lambda t: t[1]))
print(vd)

# Output
OrderedDict([('apple', 4), ('banana', 3), ('orange', 2), ('pear', 1)])
OrderedDict([('pear', 1), ('orange', 2), ('banana', 3), ('apple', 4)])
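
One caveat for modern Python: since Python 3.7 (so including the Python 3.8.6 environment used in this post), the built-in dict also preserves insertion order, which is why the scrambled "Regular dictionary" output above only appears under older interpreters. OrderedDict still differs in two ways: its equality comparison is order-sensitive, and it offers reordering helpers such as move_to_end. A quick check:

from collections import OrderedDict

print({'a': 1, 'b': 2} == {'b': 2, 'a': 1})            # True: plain dict ignores order
print(OrderedDict(a=1, b=2) == OrderedDict(b=2, a=1))  # False: OrderedDict compares order

od = OrderedDict(banana=3, apple=4, pear=1)
od.move_to_end('banana')  # move an existing key to the end
print(list(od))           # ['apple', 'pear', 'banana']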

Summary

  The larger the dataset, the longer training takes. Increasing num_workers in the DataLoader (i.e. using more loader processes) can speed things up, but it may exhaust memory and raise errors such as "Couldn't open shared file mapping" or "Out of memory"; if that happens, try reducing num_workers. A sketch of this workaround follows below.
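
As a minimal sketch of that workaround (the right num_workers value is machine-dependent, not a fixed rule), num_workers=0 loads data in the main process and avoids the shared-memory mappings that trigger the errors above:

# fall back to single-process data loading if worker processes run out of memory
train_dl = torch.utils.data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=0)  # increase cautiously (e.g. 2) once stable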
