Week J8: Hands-On with Inception v1

>- **🍨 This post is a learning-log entry for the [🔗365天深度学习训练营 (365-day deep learning training camp)]**
>- **🍖 Original author: [K同学啊]**

📌 This week's tasks:
1. Understand how the computational cost of the convolutional layers in Figure 2 is calculated (🏐 background -> computing conv-layer FLOPs; my derivation is there — try working it out by hand first, then check).
2. Study the parallel convolution structure and the 1x1 kernel (the key point).
3. Write the corresponding PyTorch code from the architecture diagram, and use Inception v1 to classify monkeypox images.

🏡 My environment:

  • Language: Python 3.8
  • Editor: Jupyter Notebook
  • Deep learning framework: PyTorch
    • torch==2.3.1+cu118
    • torchvision==0.18.1+cu118

I. Inception v1

Inception v1 paper: Going deeper with convolutions

1. Background

GoogLeNet debuted at the 2014 ILSVRC, where it won first place; that version is now usually called Inception V1. The network is 22 layers deep with roughly 5M parameters. VGGNet, from the same period, performed comparably, but with a far larger parameter count.

The Inception Module is the core building block of Inception V1. It introduces parallel convolution branches, so that different kinds of features can be extracted within the same layer, as shown in the figure below.

Stacking such blocks to deepen the network improves performance, but at the cost of heavy computation (many parameters). To address this, the Inception Module borrows the Network-in-Network idea and uses 1x1 convolutions for dimensionality reduction (which also indirectly deepens the network), cutting both the parameter count and the computation, as shown in figure (b) above.

A worked example: suppose the previous layer outputs 100x100x128. A conv layer with 256 5x5 kernels (stride=1, pad=2) turns this into 100x100x256, and has 5x5x128x256+256 parameters. If the input instead first passes through 32 1x1 kernels (which reduce the channel count while leaving the spatial size unchanged) and then through 256 5x5 kernels, the output is still 100x100x256, but the parameters drop to 128x1x1x32+32 + 32x5x5x256+256 — about a quarter of the original. The multiply-accumulate count falls from about 8.192×10^9 to about 2.048×10^9; see my training-camp post on computing conv-layer FLOPs for the full derivation.
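A quick sanity check of that arithmetic (my addition; note the 2.048×10^9 figure counts only the 5x5 stage, while the bottleneck adds a small 1x1 term on top):

H = W = 100
C_in, C_mid, C_out, k = 128, 32, 256, 5

#direct 5x5 convolution: weights + biases
direct_params = k * k * C_in * C_out + C_out          #819,456
direct_macs   = k * k * C_in * C_out * H * W          #8.192e9

#1x1 bottleneck (128 -> 32), then 5x5 convolution (32 -> 256)
bneck_params = (C_in * C_mid + C_mid) + (k * k * C_mid * C_out + C_out)   #209,184
bneck_macs   = C_in * C_mid * H * W + k * k * C_mid * C_out * H * W       #~2.089e9

print(direct_params / bneck_params)                #~3.9x fewer parameters
print(f"{direct_macs:.3e} -> {bneck_macs:.3e}")    #MAC counts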

The role of the 1x1 kernel: above all, it reduces the channel count of the input feature map, shrinking the network's parameter count and computation.

In summary, an Inception Module is built from four basic units: a 1x1 convolution, a 3x3 convolution, a 5x5 convolution, and a 3x3 max pooling. Their outputs are concatenated along the channel dimension. Kernels of different sizes provide different receptive fields, so the module extracts and fuses information at multiple scales, yielding a richer representation of the image. That is the core idea of the Inception Module.

2. Architecture

The Inception v1 network implemented here is structured as follows:

Note: the original network also adds two auxiliary classifier branches, for two reasons. First, they counter vanishing gradients by feeding gradient back in from the middle of the network during backpropagation (if any layer in the chain had a zero derivative, the whole chain product would be zero). Second, taking an intermediate layer's output for classification acts as a form of model ensembling. At test time these two auxiliary softmax branches are removed. Later architectures rarely use this trick, so you can safely skip it and focus on the parallel convolution structure and the 1x1 kernels; a sketch of one auxiliary head follows for reference.
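For completeness, here is a minimal sketch of one auxiliary head as described in the paper (this replication drops both branches, so the class below is illustrative only; in GoogLeNet it is attached after inception4a and inception4d):

import torch
import torch.nn as nn
import torch.nn.functional as F

class InceptionAux(nn.Module):
    #auxiliary classifier head: avg-pool -> 1x1 conv -> two FC layers
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.pool = nn.AvgPool2d(kernel_size=5, stride=3)        #14x14 -> 4x4
        self.conv = nn.Conv2d(in_channels, 128, kernel_size=1)   #channel reduction
        self.fc1 = nn.Linear(128 * 4 * 4, 1024)
        self.fc2 = nn.Linear(1024, num_classes)
        self.dropout = nn.Dropout(0.7)                           #70% dropout, per the paper

    def forward(self, x):
        x = F.relu(self.conv(self.pool(x)))
        x = torch.flatten(x, 1)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)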

The detailed network structure is as follows:

II. Preparation

1. Set up the GPU

Use the GPU if one is available; otherwise fall back to the CPU.

import warnings
warnings.filterwarnings("ignore")

import torch
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")  #device strings are lowercase; "CPU" would raise an error
device

Output:

device(type='cuda')

2. Load the data

Also count the number of images in the dataset.

import pathlib

data_dir=r'D:\THE MNIST DATABASE\P4-data'
data_dir=pathlib.Path(data_dir)

image_count=len(list(data_dir.glob('*/*')))
image_count

Output:

2142

3. Inspect the classes

data_paths=list(data_dir.glob('*'))
classNames=[path.name for path in data_paths]  #each subfolder name is a class name
classNames

Output:

['Monkeypox', 'Others']

4. Preview random images

Randomly sample 20 images from the dataset for a quick look.

import PIL,random
import matplotlib.pyplot as plt
from PIL import Image

data_paths2=list(data_dir.glob('*/*'))
plt.figure(figsize=(20,4))
for i in range(20):
    plt.subplot(2,10,i+1)
    plt.axis('off')
    image=random.choice(data_paths2) #pick one image at random
    plt.title(image.parts[-2])  #the parent folder name is the class label
    plt.imshow(Image.open(str(image)))  #display the image

Output:

5. Preprocess the images

from torchvision import transforms,datasets

train_transforms=transforms.Compose([
    transforms.Resize([224,224]), #resize every image to the same size
    transforms.RandomHorizontalFlip(), #random horizontal flip (augmentation)
    transforms.ToTensor(),  #convert to a tensor
    transforms.Normalize(  #normalize with the ImageNet statistics, which helps convergence
        mean=[0.485,0.456,0.406],
        std=[0.229,0.224,0.225]
    )
])

total_data=datasets.ImageFolder(
    r'D:\THE MNIST DATABASE\P4-data',
    transform=train_transforms
)
total_data

Output:

Dataset ImageFolder
    Number of datapoints: 2142
    Root location: D:\THE MNIST DATABASE\P4-data
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
               RandomHorizontalFlip(p=0.5)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

Print the class-to-index mapping:

total_data.class_to_idx

Output:

{'Monkeypox': 0, 'Others': 1}
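The mean/std passed to Normalize above are the standard ImageNet statistics. If you would rather use statistics from this dataset, a quick way to estimate them (my addition, not part of the original post):

import torch
from torchvision import transforms,datasets

plain=transforms.Compose([transforms.Resize([224,224]),transforms.ToTensor()])
ds=datasets.ImageFolder(r'D:\THE MNIST DATABASE\P4-data',transform=plain)
loader=torch.utils.data.DataLoader(ds,batch_size=64)

n,mean,sq=0,torch.zeros(3),torch.zeros(3)
for x,_ in loader:
    b=x.size(0)
    x=x.view(b,3,-1)
    mean+=x.mean(dim=2).sum(dim=0)      #per-channel mean of each image
    sq+=(x**2).mean(dim=2).sum(dim=0)   #per-channel mean of squares
    n+=b
mean/=n
std=(sq/n-mean**2).sqrt()
print(mean,std)   #feed these into transforms.Normalize if desired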

6. Split the dataset

train_size=int(0.8*len(total_data))
test_size=len(total_data)-train_size

train_dataset,test_dataset=torch.utils.data.random_split(
    total_data,
    [train_size,test_size]
)
train_dataset,test_dataset

Output:

(<torch.utils.data.dataset.Subset at 0x241dec0e950>,
 <torch.utils.data.dataset.Subset at 0x241deef32d0>)

Check the sizes of the training and test sets:

train_size,test_size

Output:

(1713, 429)
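One note: random_split draws a fresh random permutation on every run. For a reproducible split (my addition; any seed works), pass a seeded generator:

train_dataset,test_dataset=torch.utils.data.random_split(
    total_data,
    [train_size,test_size],
    generator=torch.Generator().manual_seed(42)  #fixed seed -> identical split across runs
)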

7. Create the DataLoaders

batch_size=16
train_dl=torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1
)
test_dl=torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1
)

Take a look at one batch from the training set:

for x,y in train_dl:
    print("Shape of x [N,C,H,W]:",x.shape)
    print("Shape of y:",y.shape,y.dtype)
    break

Output:

Shape of x [N,C,H,W]: torch.Size([16, 3, 224, 224])
Shape of y: torch.Size([16]) torch.int64
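A small caveat of my own: with num_workers>0, PyTorch spawns worker processes. Inside Jupyter the loop above works as written, but if you ever run the same code as a plain .py script on Windows, the DataLoader iteration must sit under a main guard:

if __name__=='__main__':   #required on Windows when num_workers>0 in a script
    for x,y in train_dl:
        print("Shape of x [N,C,H,W]:",x.shape)
        print("Shape of y:",y.shape,y.dtype)
        break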

III. Reproducing the Inception V1 Model in PyTorch

Build the Inception V1 network following the architecture and parameters in the table above.

1. Build the model

The two auxiliary branches are omitted here; only the main path is reproduced.

Define a class named inception_block, inheriting from nn.Module; it contains all the layers and parameters of a single Inception v1 module.

import torch.nn as nn
import torch.nn.functional as F

class inception_block(nn.Module):
    def __init__(self,in_channels,ch1x1,ch3x3red,ch3x3,ch5x5red,ch5x5,pool_proj):
        super(inception_block,self).__init__()
        
        # 1x1 conv branch
        self.branch1=nn.Sequential(
            nn.Conv2d(in_channels,ch1x1,kernel_size=1),
            nn.BatchNorm2d(ch1x1),
            nn.ReLU(inplace=True)
        )
        
        # 1x1 conv -> 3x3 conv branch
        self.branch2=nn.Sequential(
            nn.Conv2d(in_channels,ch3x3red,kernel_size=1),
            nn.BatchNorm2d(ch3x3red),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch3x3red,ch3x3,kernel_size=3,padding=1),
            nn.BatchNorm2d(ch3x3),
            nn.ReLU(inplace=True)
        )
        
        # 1x1 conv -> 5x5 conv branch
        self.branch3=nn.Sequential(
            nn.Conv2d(in_channels,ch5x5red,kernel_size=1),
            nn.BatchNorm2d(ch5x5red),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch5x5red,ch5x5,kernel_size=5,padding=2),
            nn.BatchNorm2d(ch5x5),
            nn.ReLU(inplace=True)
        )
        
        # 3x3 max pooling -> 1x1 conv branch
        self.branch4=nn.Sequential(
            nn.MaxPool2d(kernel_size=3,stride=1,padding=1),
            nn.Conv2d(in_channels,pool_proj,kernel_size=1),
            nn.BatchNorm2d(pool_proj),
            nn.ReLU(inplace=True)
        )
        
    def forward(self,x):
        #compute forward pass through all branches and concatenate the output feature maps
        branch1_output=self.branch1(x)
        branch2_output=self.branch2(x)
        branch3_output=self.branch3(x)
        branch4_output=self.branch4(x)
        
        outputs=[branch1_output,branch2_output,branch3_output,branch4_output]
        return torch.cat(outputs,1)

In the __init__ method, four branches are defined:

(1) branch1: a 1x1 convolution;
(2) branch2: a 1x1 convolution followed by a 3x3 convolution;
(3) branch3: a 1x1 convolution followed by a 5x5 convolution;
(4) branch4: a 3x3 max pooling followed by a 1x1 convolution.
Each branch stacks convolution, batch-normalization, and activation layers. These are all standard PyTorch modules: nn.Conv2d, nn.BatchNorm2d, and nn.ReLU define the convolution, batch normalization, and ReLU activation respectively.

In the forward method, the input is passed through all four branches, and the resulting output feature maps are concatenated along the channel dimension; the concatenated tensor is returned.
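A quick sanity check (my addition): with the 3a configuration used later, a 192-channel input should come out with 64+128+32+32 = 256 channels while keeping its spatial size.

#verify the channel bookkeeping of one block (torch was imported earlier)
block=inception_block(192,64,96,128,16,32,32)
dummy=torch.randn(1,192,28,28)
print(block(dummy).shape)   #expected: torch.Size([1, 256, 28, 28])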

Next, we define the full Inception v1 model, using nn.Sequential to combine multiple Inception modules with the other layers.

class InceptionV1(nn.Module):
    def __init__(self,num_classes=1000):
        super(InceptionV1,self).__init__()
        
        self.conv1=nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3)
        self.maxpool1=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
        self.conv2=nn.Conv2d(64,64,kernel_size=1,stride=1,padding=0)
        self.conv3=nn.Conv2d(64,192,kernel_size=3,stride=1,padding=1)
        self.maxpool2=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
        
        self.inception3a=inception_block(192,64,96,128,16,32,32)
        self.inception3b=inception_block(256,128,128,192,32,96,64)
        self.maxpool3=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
        
        self.inception4a=inception_block(480,192,96,208,16,48,64)
        self.inception4b=inception_block(512,160,112,224,24,64,64)
        self.inception4c=inception_block(512,128,128,256,24,64,64)
        self.inception4d=inception_block(512,112,114,288,32,64,64)  #note: the paper uses 144 (not 114) for the 3x3 reduce; 114 matches the summary below
        self.inception4e=inception_block(528,256,160,320,32,128,128)
        self.maxpool4=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
        
        self.inception5a=inception_block(832,256,160,320,32,128,128)
        
        self.inception5b=nn.Sequential(
            inception_block(832,384,192,384,48,128,128),
            nn.AvgPool2d(kernel_size=7,stride=1,padding=0),
            nn.Dropout(0.4)
        )
        
        #fully connected classifier head
        self.classifier=nn.Sequential(
            nn.Linear(in_features=1024,out_features=1024),
            nn.ReLU(),
            nn.Linear(in_features=1024,out_features=num_classes),
            nn.Softmax(dim=1)
        )
        
    def forward(self,x):
        x=self.conv1(x)
        x=F.relu(x)
        x=self.maxpool1(x)
        x=self.conv2(x)
        x=F.relu(x)
        x=self.conv3(x)
        x=F.relu(x)
        x=self.maxpool2(x)
        
        x=self.inception3a(x)
        x=self.inception3b(x)
        x=self.maxpool3(x)
        
        x=self.inception4a(x)
        x=self.inception4b(x)
        x=self.inception4c(x)
        x=self.inception4d(x)
        x=self.inception4e(x)
        x=self.maxpool4(x)
        
        x=self.inception5a(x)
        x=self.inception5b(x)
        
        x=torch.flatten(x,start_dim=1)
        x=self.classifier(x)
        return x
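One design note of my own, not from the original post: the classifier ends with nn.Softmax, while the training section below uses nn.CrossEntropyLoss, which already applies log-softmax internally. Training still works, as the results show, but the doubled softmax flattens gradients and puts a floor under the reported loss (about 0.31 for two classes even at 100% accuracy, which is why the losses below plateau near 0.35). A common variant lets the network output raw logits:

#sketch: a logits-only head to use inside InceptionV1.__init__;
#keep nn.CrossEntropyLoss as the loss and apply softmax only when
#probabilities are actually needed
self.classifier=nn.Sequential(
    nn.Linear(in_features=1024,out_features=1024),
    nn.ReLU(),
    nn.Linear(in_features=1024,out_features=num_classes)
    #no nn.Softmax here: CrossEntropyLoss expects logits
)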

2. Print the model summary

#report the parameter count and other statistics
import torchsummary

#instantiate the model and move it to the GPU
model=InceptionV1(num_classes=2).to(device)

#display the network structure
torchsummary.summary(model,(3,224,224))
print(model)

Output:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,472
         MaxPool2d-2           [-1, 64, 56, 56]               0
            Conv2d-3           [-1, 64, 56, 56]           4,160
            Conv2d-4          [-1, 192, 56, 56]         110,784
         MaxPool2d-5          [-1, 192, 28, 28]               0
            Conv2d-6           [-1, 64, 28, 28]          12,352
       BatchNorm2d-7           [-1, 64, 28, 28]             128
              ReLU-8           [-1, 64, 28, 28]               0
            Conv2d-9           [-1, 96, 28, 28]          18,528
      BatchNorm2d-10           [-1, 96, 28, 28]             192
             ReLU-11           [-1, 96, 28, 28]               0
           Conv2d-12          [-1, 128, 28, 28]         110,720
      BatchNorm2d-13          [-1, 128, 28, 28]             256
             ReLU-14          [-1, 128, 28, 28]               0
           Conv2d-15           [-1, 16, 28, 28]           3,088
      BatchNorm2d-16           [-1, 16, 28, 28]              32
             ReLU-17           [-1, 16, 28, 28]               0
           Conv2d-18           [-1, 32, 28, 28]          12,832
      BatchNorm2d-19           [-1, 32, 28, 28]              64
             ReLU-20           [-1, 32, 28, 28]               0
        MaxPool2d-21          [-1, 192, 28, 28]               0
           Conv2d-22           [-1, 32, 28, 28]           6,176
      BatchNorm2d-23           [-1, 32, 28, 28]              64
             ReLU-24           [-1, 32, 28, 28]               0
  inception_block-25          [-1, 256, 28, 28]               0
           Conv2d-26          [-1, 128, 28, 28]          32,896
      BatchNorm2d-27          [-1, 128, 28, 28]             256
             ReLU-28          [-1, 128, 28, 28]               0
           Conv2d-29          [-1, 128, 28, 28]          32,896
      BatchNorm2d-30          [-1, 128, 28, 28]             256
             ReLU-31          [-1, 128, 28, 28]               0
           Conv2d-32          [-1, 192, 28, 28]         221,376
      BatchNorm2d-33          [-1, 192, 28, 28]             384
             ReLU-34          [-1, 192, 28, 28]               0
           Conv2d-35           [-1, 32, 28, 28]           8,224
      BatchNorm2d-36           [-1, 32, 28, 28]              64
             ReLU-37           [-1, 32, 28, 28]               0
           Conv2d-38           [-1, 96, 28, 28]          76,896
      BatchNorm2d-39           [-1, 96, 28, 28]             192
             ReLU-40           [-1, 96, 28, 28]               0
        MaxPool2d-41          [-1, 256, 28, 28]               0
           Conv2d-42           [-1, 64, 28, 28]          16,448
      BatchNorm2d-43           [-1, 64, 28, 28]             128
             ReLU-44           [-1, 64, 28, 28]               0
  inception_block-45          [-1, 480, 28, 28]               0
        MaxPool2d-46          [-1, 480, 14, 14]               0
           Conv2d-47          [-1, 192, 14, 14]          92,352
      BatchNorm2d-48          [-1, 192, 14, 14]             384
             ReLU-49          [-1, 192, 14, 14]               0
           Conv2d-50           [-1, 96, 14, 14]          46,176
      BatchNorm2d-51           [-1, 96, 14, 14]             192
             ReLU-52           [-1, 96, 14, 14]               0
           Conv2d-53          [-1, 208, 14, 14]         179,920
      BatchNorm2d-54          [-1, 208, 14, 14]             416
             ReLU-55          [-1, 208, 14, 14]               0
           Conv2d-56           [-1, 16, 14, 14]           7,696
      BatchNorm2d-57           [-1, 16, 14, 14]              32
             ReLU-58           [-1, 16, 14, 14]               0
           Conv2d-59           [-1, 48, 14, 14]          19,248
      BatchNorm2d-60           [-1, 48, 14, 14]              96
             ReLU-61           [-1, 48, 14, 14]               0
        MaxPool2d-62          [-1, 480, 14, 14]               0
           Conv2d-63           [-1, 64, 14, 14]          30,784
      BatchNorm2d-64           [-1, 64, 14, 14]             128
             ReLU-65           [-1, 64, 14, 14]               0
  inception_block-66          [-1, 512, 14, 14]               0
           Conv2d-67          [-1, 160, 14, 14]          82,080
      BatchNorm2d-68          [-1, 160, 14, 14]             320
             ReLU-69          [-1, 160, 14, 14]               0
           Conv2d-70          [-1, 112, 14, 14]          57,456
      BatchNorm2d-71          [-1, 112, 14, 14]             224
             ReLU-72          [-1, 112, 14, 14]               0
           Conv2d-73          [-1, 224, 14, 14]         226,016
      BatchNorm2d-74          [-1, 224, 14, 14]             448
             ReLU-75          [-1, 224, 14, 14]               0
           Conv2d-76           [-1, 24, 14, 14]          12,312
      BatchNorm2d-77           [-1, 24, 14, 14]              48
             ReLU-78           [-1, 24, 14, 14]               0
           Conv2d-79           [-1, 64, 14, 14]          38,464
      BatchNorm2d-80           [-1, 64, 14, 14]             128
             ReLU-81           [-1, 64, 14, 14]               0
        MaxPool2d-82          [-1, 512, 14, 14]               0
           Conv2d-83           [-1, 64, 14, 14]          32,832
      BatchNorm2d-84           [-1, 64, 14, 14]             128
             ReLU-85           [-1, 64, 14, 14]               0
  inception_block-86          [-1, 512, 14, 14]               0
           Conv2d-87          [-1, 128, 14, 14]          65,664
      BatchNorm2d-88          [-1, 128, 14, 14]             256
             ReLU-89          [-1, 128, 14, 14]               0
           Conv2d-90          [-1, 128, 14, 14]          65,664
      BatchNorm2d-91          [-1, 128, 14, 14]             256
             ReLU-92          [-1, 128, 14, 14]               0
           Conv2d-93          [-1, 256, 14, 14]         295,168
      BatchNorm2d-94          [-1, 256, 14, 14]             512
             ReLU-95          [-1, 256, 14, 14]               0
           Conv2d-96           [-1, 24, 14, 14]          12,312
      BatchNorm2d-97           [-1, 24, 14, 14]              48
             ReLU-98           [-1, 24, 14, 14]               0
           Conv2d-99           [-1, 64, 14, 14]          38,464
     BatchNorm2d-100           [-1, 64, 14, 14]             128
            ReLU-101           [-1, 64, 14, 14]               0
       MaxPool2d-102          [-1, 512, 14, 14]               0
          Conv2d-103           [-1, 64, 14, 14]          32,832
     BatchNorm2d-104           [-1, 64, 14, 14]             128
            ReLU-105           [-1, 64, 14, 14]               0
 inception_block-106          [-1, 512, 14, 14]               0
          Conv2d-107          [-1, 112, 14, 14]          57,456
     BatchNorm2d-108          [-1, 112, 14, 14]             224
            ReLU-109          [-1, 112, 14, 14]               0
          Conv2d-110          [-1, 114, 14, 14]          58,482
     BatchNorm2d-111          [-1, 114, 14, 14]             228
            ReLU-112          [-1, 114, 14, 14]               0
          Conv2d-113          [-1, 288, 14, 14]         295,776
     BatchNorm2d-114          [-1, 288, 14, 14]             576
            ReLU-115          [-1, 288, 14, 14]               0
          Conv2d-116           [-1, 32, 14, 14]          16,416
     BatchNorm2d-117           [-1, 32, 14, 14]              64
            ReLU-118           [-1, 32, 14, 14]               0
          Conv2d-119           [-1, 64, 14, 14]          51,264
     BatchNorm2d-120           [-1, 64, 14, 14]             128
            ReLU-121           [-1, 64, 14, 14]               0
       MaxPool2d-122          [-1, 512, 14, 14]               0
          Conv2d-123           [-1, 64, 14, 14]          32,832
     BatchNorm2d-124           [-1, 64, 14, 14]             128
            ReLU-125           [-1, 64, 14, 14]               0
 inception_block-126          [-1, 528, 14, 14]               0
          Conv2d-127          [-1, 256, 14, 14]         135,424
     BatchNorm2d-128          [-1, 256, 14, 14]             512
            ReLU-129          [-1, 256, 14, 14]               0
          Conv2d-130          [-1, 160, 14, 14]          84,640
     BatchNorm2d-131          [-1, 160, 14, 14]             320
            ReLU-132          [-1, 160, 14, 14]               0
          Conv2d-133          [-1, 320, 14, 14]         461,120
     BatchNorm2d-134          [-1, 320, 14, 14]             640
            ReLU-135          [-1, 320, 14, 14]               0
          Conv2d-136           [-1, 32, 14, 14]          16,928
     BatchNorm2d-137           [-1, 32, 14, 14]              64
            ReLU-138           [-1, 32, 14, 14]               0
          Conv2d-139          [-1, 128, 14, 14]         102,528
     BatchNorm2d-140          [-1, 128, 14, 14]             256
            ReLU-141          [-1, 128, 14, 14]               0
       MaxPool2d-142          [-1, 528, 14, 14]               0
          Conv2d-143          [-1, 128, 14, 14]          67,712
     BatchNorm2d-144          [-1, 128, 14, 14]             256
            ReLU-145          [-1, 128, 14, 14]               0
 inception_block-146          [-1, 832, 14, 14]               0
       MaxPool2d-147            [-1, 832, 7, 7]               0
          Conv2d-148            [-1, 256, 7, 7]         213,248
     BatchNorm2d-149            [-1, 256, 7, 7]             512
            ReLU-150            [-1, 256, 7, 7]               0
          Conv2d-151            [-1, 160, 7, 7]         133,280
     BatchNorm2d-152            [-1, 160, 7, 7]             320
            ReLU-153            [-1, 160, 7, 7]               0
          Conv2d-154            [-1, 320, 7, 7]         461,120
     BatchNorm2d-155            [-1, 320, 7, 7]             640
            ReLU-156            [-1, 320, 7, 7]               0
          Conv2d-157             [-1, 32, 7, 7]          26,656
     BatchNorm2d-158             [-1, 32, 7, 7]              64
            ReLU-159             [-1, 32, 7, 7]               0
          Conv2d-160            [-1, 128, 7, 7]         102,528
     BatchNorm2d-161            [-1, 128, 7, 7]             256
            ReLU-162            [-1, 128, 7, 7]               0
       MaxPool2d-163            [-1, 832, 7, 7]               0
          Conv2d-164            [-1, 128, 7, 7]         106,624
     BatchNorm2d-165            [-1, 128, 7, 7]             256
            ReLU-166            [-1, 128, 7, 7]               0
 inception_block-167            [-1, 832, 7, 7]               0
          Conv2d-168            [-1, 384, 7, 7]         319,872
     BatchNorm2d-169            [-1, 384, 7, 7]             768
            ReLU-170            [-1, 384, 7, 7]               0
          Conv2d-171            [-1, 192, 7, 7]         159,936
     BatchNorm2d-172            [-1, 192, 7, 7]             384
            ReLU-173            [-1, 192, 7, 7]               0
          Conv2d-174            [-1, 384, 7, 7]         663,936
     BatchNorm2d-175            [-1, 384, 7, 7]             768
            ReLU-176            [-1, 384, 7, 7]               0
          Conv2d-177             [-1, 48, 7, 7]          39,984
     BatchNorm2d-178             [-1, 48, 7, 7]              96
            ReLU-179             [-1, 48, 7, 7]               0
          Conv2d-180            [-1, 128, 7, 7]         153,728
     BatchNorm2d-181            [-1, 128, 7, 7]             256
            ReLU-182            [-1, 128, 7, 7]               0
       MaxPool2d-183            [-1, 832, 7, 7]               0
          Conv2d-184            [-1, 128, 7, 7]         106,624
     BatchNorm2d-185            [-1, 128, 7, 7]             256
            ReLU-186            [-1, 128, 7, 7]               0
 inception_block-187           [-1, 1024, 7, 7]               0
       AvgPool2d-188           [-1, 1024, 1, 1]               0
         Dropout-189           [-1, 1024, 1, 1]               0
          Linear-190                 [-1, 1024]       1,049,600
            ReLU-191                 [-1, 1024]               0
          Linear-192                    [-1, 2]           2,050
         Softmax-193                    [-1, 2]               0
================================================================
Total params: 6,945,912
Trainable params: 6,945,912
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 69.48
Params size (MB): 26.50
Estimated Total Size (MB): 96.55
----------------------------------------------------------------
InceptionV1(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (conv2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
  (conv3): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (maxpool2): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (inception3a): inception_block(
    (branch1): Sequential(
      (0): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (branch2): Sequential(
      (0): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch3): Sequential(
      (0): Conv2d(192, 16, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (1): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
  )
  (inception3b): inception_block(
    (branch1): Sequential(
      (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (branch2): Sequential(
      (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch3): Sequential(
      (0): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
  )
  (maxpool3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (inception4a): inception_block(
    (branch1): Sequential(
      (0): Conv2d(480, 192, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (branch2): Sequential(
      (0): Conv2d(480, 96, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(96, 208, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(208, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch3): Sequential(
      (0): Conv2d(480, 16, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(16, 48, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (1): Conv2d(480, 64, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
  )
  (inception4b): inception_block(
    (branch1): Sequential(
      (0): Conv2d(512, 160, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (branch2): Sequential(
      (0): Conv2d(512, 112, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(112, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch3): Sequential(
      (0): Conv2d(512, 24, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(24, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (1): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
  )
  (inception4c): inception_block(
    (branch1): Sequential(
      (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (branch2): Sequential(
      (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch3): Sequential(
      (0): Conv2d(512, 24, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(24, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (1): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
  )
  (inception4d): inception_block(
    (branch1): Sequential(
      (0): Conv2d(512, 112, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (branch2): Sequential(
      (0): Conv2d(512, 114, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(114, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(114, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch3): Sequential(
      (0): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (1): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
  )
  (inception4e): inception_block(
    (branch1): Sequential(
      (0): Conv2d(528, 256, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (branch2): Sequential(
      (0): Conv2d(528, 160, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(160, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch3): Sequential(
      (0): Conv2d(528, 32, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (1): Conv2d(528, 128, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
  )
  (maxpool4): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (inception5a): inception_block(
    (branch1): Sequential(
      (0): Conv2d(832, 256, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (branch2): Sequential(
      (0): Conv2d(832, 160, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(160, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch3): Sequential(
      (0): Conv2d(832, 32, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
      (1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1))
      (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU(inplace=True)
    )
  )
  (inception5b): Sequential(
    (0): inception_block(
      (branch1): Sequential(
        (0): Conv2d(832, 384, kernel_size=(1, 1), stride=(1, 1))
        (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (branch2): Sequential(
        (0): Conv2d(832, 192, kernel_size=(1, 1), stride=(1, 1))
        (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
      (branch3): Sequential(
        (0): Conv2d(832, 48, kernel_size=(1, 1), stride=(1, 1))
        (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(48, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
      (branch4): Sequential(
        (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
        (1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1))
        (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
      )
    )
    (1): AvgPool2d(kernel_size=7, stride=1, padding=0)
    (2): Dropout(p=0.4, inplace=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=1024, out_features=1024, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1024, out_features=2, bias=True)
    (3): Softmax(dim=1)
  )
)

IV. Training the Model

1. Write the training function

def train(dataloader,model,loss_fn,optimizer):
    size=len(dataloader.dataset)  #size of the training set
    num_batches=len(dataloader)  #number of batches
    
    train_loss,train_acc=0,0  #initialize training loss and accuracy
    
    for x,y in dataloader:  #fetch a batch of images and labels
        x,y=x.to(device),y.to(device)
        
        #compute the prediction error
        pred=model(x)  #network output
        loss=loss_fn(pred,y)  #loss between the prediction and the ground truth
        
        #backpropagation
        optimizer.zero_grad()  #zero the gradients
        loss.backward()  #backpropagate
        optimizer.step()  #update the weights
        
        #accumulate accuracy and loss
        train_acc+=(pred.argmax(1)==y).type(torch.float).sum().item()
        train_loss+=loss.item()
        
    train_acc/=size
    train_loss/=num_batches
    
    return train_acc,train_loss

2. Write the test function

The test function largely mirrors the training function, but since no gradient descent is performed to update the weights, no optimizer needs to be passed in.

#test function
def test(dataloader,model,loss_fn):
    size=len(dataloader.dataset) #size of the test set
    num_batches=len(dataloader)  #number of batches
    test_loss,test_acc=0,0
    
    #no training happens here, so disable gradient tracking to save memory
    with torch.no_grad():
        for imgs,target in dataloader:
            imgs,target=imgs.to(device),target.to(device)
            
            #compute the loss
            target_pred=model(imgs)
            loss=loss_fn(target_pred,target)
            
            test_loss+=loss.item()
            test_acc+=(target_pred.argmax(1)==target).type(torch.float).sum().item()
            
    test_acc/=size
    test_loss/=num_batches
    
    return test_acc,test_loss

3. Train

import copy
optimizer=torch.optim.Adam(model.parameters(),lr=1e-4)  #create the optimizer and set the learning rate
loss_fn=nn.CrossEntropyLoss()  #create the loss function
 
epochs=100
 
train_loss=[]
train_acc=[]
test_loss=[]
test_acc=[]
 
best_acc=0  #track the best test accuracy seen so far, used to select the best model
 
for epoch in range(epochs):
    
    model.train()
    epoch_train_acc,epoch_train_loss=train(train_dl,model,loss_fn,optimizer)
    
    model.eval()
    epoch_test_acc,epoch_test_loss=test(test_dl,model,loss_fn)
    
    #keep a copy of the best model in J8_model
    if epoch_test_acc>best_acc:
        best_acc=epoch_test_acc
        J8_model=copy.deepcopy(model)
        
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    #get the current learning rate
    lr=optimizer.state_dict()['param_groups'][0]['lr']
    
    template=('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},Lr:{:.2E}')
    print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,
                          epoch_test_acc*100,epoch_test_loss,lr))
    
#save the best model to disk
PATH=r'D:\THE MNIST DATABASE\J-series\J8_model.pth'
torch.save(J8_model.state_dict(),PATH)  #save the best copy, not the final epoch's weights
print('Done')

Output:

Epoch: 1,Train_acc:66.4%,Train_loss:0.633,Test_acc:61.5%,Test_loss:0.656,Lr:1.00E-04
Epoch: 2,Train_acc:67.3%,Train_loss:0.616,Test_acc:68.1%,Test_loss:0.624,Lr:1.00E-04
Epoch: 3,Train_acc:69.9%,Train_loss:0.591,Test_acc:68.5%,Test_loss:0.613,Lr:1.00E-04
Epoch: 4,Train_acc:73.3%,Train_loss:0.574,Test_acc:68.1%,Test_loss:0.605,Lr:1.00E-04
Epoch: 5,Train_acc:75.3%,Train_loss:0.556,Test_acc:65.0%,Test_loss:0.651,Lr:1.00E-04
Epoch: 6,Train_acc:74.9%,Train_loss:0.551,Test_acc:73.0%,Test_loss:0.563,Lr:1.00E-04
Epoch: 7,Train_acc:77.4%,Train_loss:0.529,Test_acc:80.4%,Test_loss:0.500,Lr:1.00E-04
Epoch: 8,Train_acc:80.2%,Train_loss:0.503,Test_acc:78.1%,Test_loss:0.518,Lr:1.00E-04
Epoch: 9,Train_acc:79.6%,Train_loss:0.509,Test_acc:77.4%,Test_loss:0.533,Lr:1.00E-04
Epoch:10,Train_acc:80.0%,Train_loss:0.504,Test_acc:77.6%,Test_loss:0.528,Lr:1.00E-04
Epoch:11,Train_acc:85.4%,Train_loss:0.456,Test_acc:86.0%,Test_loss:0.448,Lr:1.00E-04
Epoch:12,Train_acc:84.0%,Train_loss:0.466,Test_acc:84.4%,Test_loss:0.469,Lr:1.00E-04
Epoch:13,Train_acc:84.1%,Train_loss:0.467,Test_acc:80.0%,Test_loss:0.511,Lr:1.00E-04
Epoch:14,Train_acc:85.8%,Train_loss:0.449,Test_acc:77.4%,Test_loss:0.534,Lr:1.00E-04
Epoch:15,Train_acc:85.5%,Train_loss:0.460,Test_acc:83.2%,Test_loss:0.471,Lr:1.00E-04
Epoch:16,Train_acc:84.8%,Train_loss:0.464,Test_acc:81.8%,Test_loss:0.490,Lr:1.00E-04
Epoch:17,Train_acc:84.2%,Train_loss:0.467,Test_acc:76.5%,Test_loss:0.540,Lr:1.00E-04
Epoch:18,Train_acc:83.7%,Train_loss:0.471,Test_acc:84.6%,Test_loss:0.467,Lr:1.00E-04
Epoch:19,Train_acc:87.3%,Train_loss:0.436,Test_acc:88.3%,Test_loss:0.427,Lr:1.00E-04
Epoch:20,Train_acc:86.2%,Train_loss:0.454,Test_acc:82.5%,Test_loss:0.487,Lr:1.00E-04
Epoch:21,Train_acc:86.2%,Train_loss:0.444,Test_acc:87.9%,Test_loss:0.430,Lr:1.00E-04
Epoch:22,Train_acc:88.3%,Train_loss:0.437,Test_acc:87.9%,Test_loss:0.436,Lr:1.00E-04
Epoch:23,Train_acc:89.2%,Train_loss:0.423,Test_acc:86.7%,Test_loss:0.442,Lr:1.00E-04
Epoch:24,Train_acc:89.6%,Train_loss:0.424,Test_acc:89.0%,Test_loss:0.421,Lr:1.00E-04
Epoch:25,Train_acc:90.6%,Train_loss:0.405,Test_acc:91.6%,Test_loss:0.402,Lr:1.00E-04
Epoch:26,Train_acc:90.7%,Train_loss:0.413,Test_acc:89.7%,Test_loss:0.412,Lr:1.00E-04
Epoch:27,Train_acc:90.5%,Train_loss:0.406,Test_acc:87.6%,Test_loss:0.431,Lr:1.00E-04
Epoch:28,Train_acc:87.6%,Train_loss:0.427,Test_acc:86.0%,Test_loss:0.451,Lr:1.00E-04
Epoch:29,Train_acc:89.1%,Train_loss:0.417,Test_acc:89.0%,Test_loss:0.421,Lr:1.00E-04
Epoch:30,Train_acc:91.5%,Train_loss:0.393,Test_acc:90.7%,Test_loss:0.406,Lr:1.00E-04
Epoch:31,Train_acc:92.1%,Train_loss:0.395,Test_acc:88.1%,Test_loss:0.427,Lr:1.00E-04
Epoch:32,Train_acc:93.0%,Train_loss:0.385,Test_acc:88.8%,Test_loss:0.418,Lr:1.00E-04
Epoch:33,Train_acc:91.4%,Train_loss:0.397,Test_acc:91.1%,Test_loss:0.402,Lr:1.00E-04
Epoch:34,Train_acc:92.5%,Train_loss:0.385,Test_acc:88.8%,Test_loss:0.425,Lr:1.00E-04
Epoch:35,Train_acc:91.8%,Train_loss:0.400,Test_acc:92.1%,Test_loss:0.391,Lr:1.00E-04
Epoch:36,Train_acc:91.9%,Train_loss:0.396,Test_acc:92.1%,Test_loss:0.390,Lr:1.00E-04
Epoch:37,Train_acc:90.5%,Train_loss:0.409,Test_acc:90.0%,Test_loss:0.413,Lr:1.00E-04
Epoch:38,Train_acc:93.1%,Train_loss:0.381,Test_acc:86.0%,Test_loss:0.444,Lr:1.00E-04
Epoch:39,Train_acc:93.1%,Train_loss:0.381,Test_acc:93.7%,Test_loss:0.379,Lr:1.00E-04
Epoch:40,Train_acc:93.5%,Train_loss:0.387,Test_acc:93.0%,Test_loss:0.381,Lr:1.00E-04
Epoch:41,Train_acc:94.1%,Train_loss:0.379,Test_acc:91.8%,Test_loss:0.394,Lr:1.00E-04
Epoch:42,Train_acc:93.6%,Train_loss:0.377,Test_acc:93.2%,Test_loss:0.381,Lr:1.00E-04
Epoch:43,Train_acc:93.9%,Train_loss:0.380,Test_acc:92.5%,Test_loss:0.384,Lr:1.00E-04
Epoch:44,Train_acc:93.9%,Train_loss:0.381,Test_acc:92.5%,Test_loss:0.384,Lr:1.00E-04
Epoch:45,Train_acc:89.6%,Train_loss:0.413,Test_acc:92.3%,Test_loss:0.388,Lr:1.00E-04
Epoch:46,Train_acc:91.7%,Train_loss:0.395,Test_acc:90.9%,Test_loss:0.401,Lr:1.00E-04
Epoch:47,Train_acc:93.4%,Train_loss:0.387,Test_acc:90.4%,Test_loss:0.407,Lr:1.00E-04
Epoch:48,Train_acc:93.6%,Train_loss:0.375,Test_acc:92.1%,Test_loss:0.388,Lr:1.00E-04
Epoch:49,Train_acc:93.8%,Train_loss:0.375,Test_acc:94.9%,Test_loss:0.367,Lr:1.00E-04
Epoch:50,Train_acc:94.2%,Train_loss:0.379,Test_acc:92.1%,Test_loss:0.390,Lr:1.00E-04
Epoch:51,Train_acc:93.6%,Train_loss:0.385,Test_acc:92.5%,Test_loss:0.383,Lr:1.00E-04
Epoch:52,Train_acc:93.9%,Train_loss:0.380,Test_acc:90.0%,Test_loss:0.414,Lr:1.00E-04
Epoch:53,Train_acc:93.2%,Train_loss:0.378,Test_acc:90.2%,Test_loss:0.412,Lr:1.00E-04
Epoch:54,Train_acc:92.6%,Train_loss:0.393,Test_acc:85.5%,Test_loss:0.454,Lr:1.00E-04
Epoch:55,Train_acc:91.9%,Train_loss:0.397,Test_acc:91.6%,Test_loss:0.398,Lr:1.00E-04
Epoch:56,Train_acc:94.3%,Train_loss:0.368,Test_acc:92.5%,Test_loss:0.384,Lr:1.00E-04
Epoch:57,Train_acc:95.7%,Train_loss:0.354,Test_acc:93.5%,Test_loss:0.379,Lr:1.00E-04
Epoch:58,Train_acc:94.2%,Train_loss:0.380,Test_acc:93.5%,Test_loss:0.377,Lr:1.00E-04
Epoch:59,Train_acc:95.3%,Train_loss:0.361,Test_acc:93.0%,Test_loss:0.381,Lr:1.00E-04
Epoch:60,Train_acc:92.5%,Train_loss:0.385,Test_acc:90.0%,Test_loss:0.412,Lr:1.00E-04
Epoch:61,Train_acc:94.7%,Train_loss:0.372,Test_acc:95.1%,Test_loss:0.362,Lr:1.00E-04
Epoch:62,Train_acc:95.1%,Train_loss:0.369,Test_acc:90.2%,Test_loss:0.408,Lr:1.00E-04
Epoch:63,Train_acc:94.2%,Train_loss:0.380,Test_acc:92.8%,Test_loss:0.385,Lr:1.00E-04
Epoch:64,Train_acc:94.7%,Train_loss:0.374,Test_acc:91.8%,Test_loss:0.395,Lr:1.00E-04
Epoch:65,Train_acc:96.3%,Train_loss:0.359,Test_acc:94.2%,Test_loss:0.372,Lr:1.00E-04
Epoch:66,Train_acc:95.1%,Train_loss:0.361,Test_acc:92.1%,Test_loss:0.392,Lr:1.00E-04
Epoch:67,Train_acc:95.8%,Train_loss:0.354,Test_acc:92.3%,Test_loss:0.390,Lr:1.00E-04
Epoch:68,Train_acc:95.2%,Train_loss:0.358,Test_acc:93.7%,Test_loss:0.373,Lr:1.00E-04
Epoch:69,Train_acc:95.3%,Train_loss:0.360,Test_acc:93.7%,Test_loss:0.377,Lr:1.00E-04
Epoch:70,Train_acc:95.2%,Train_loss:0.368,Test_acc:93.2%,Test_loss:0.377,Lr:1.00E-04
Epoch:71,Train_acc:94.9%,Train_loss:0.373,Test_acc:91.8%,Test_loss:0.396,Lr:1.00E-04
Epoch:72,Train_acc:94.0%,Train_loss:0.371,Test_acc:94.2%,Test_loss:0.368,Lr:1.00E-04
Epoch:73,Train_acc:95.5%,Train_loss:0.356,Test_acc:93.2%,Test_loss:0.377,Lr:1.00E-04
Epoch:74,Train_acc:96.0%,Train_loss:0.362,Test_acc:93.9%,Test_loss:0.372,Lr:1.00E-04
Epoch:75,Train_acc:95.3%,Train_loss:0.368,Test_acc:93.2%,Test_loss:0.379,Lr:1.00E-04
Epoch:76,Train_acc:94.8%,Train_loss:0.368,Test_acc:91.4%,Test_loss:0.397,Lr:1.00E-04
Epoch:77,Train_acc:91.1%,Train_loss:0.398,Test_acc:90.9%,Test_loss:0.404,Lr:1.00E-04
Epoch:78,Train_acc:92.6%,Train_loss:0.388,Test_acc:90.0%,Test_loss:0.407,Lr:1.00E-04
Epoch:79,Train_acc:94.9%,Train_loss:0.362,Test_acc:94.9%,Test_loss:0.362,Lr:1.00E-04
Epoch:80,Train_acc:93.2%,Train_loss:0.380,Test_acc:91.4%,Test_loss:0.398,Lr:1.00E-04
Epoch:81,Train_acc:94.3%,Train_loss:0.368,Test_acc:93.7%,Test_loss:0.373,Lr:1.00E-04
Epoch:82,Train_acc:94.8%,Train_loss:0.363,Test_acc:93.2%,Test_loss:0.378,Lr:1.00E-04
Epoch:83,Train_acc:95.6%,Train_loss:0.364,Test_acc:88.1%,Test_loss:0.425,Lr:1.00E-04
Epoch:84,Train_acc:92.6%,Train_loss:0.386,Test_acc:92.3%,Test_loss:0.389,Lr:1.00E-04
Epoch:85,Train_acc:94.3%,Train_loss:0.377,Test_acc:91.4%,Test_loss:0.397,Lr:1.00E-04
Epoch:86,Train_acc:96.1%,Train_loss:0.359,Test_acc:96.3%,Test_loss:0.350,Lr:1.00E-04
Epoch:87,Train_acc:95.7%,Train_loss:0.363,Test_acc:94.2%,Test_loss:0.369,Lr:1.00E-04
Epoch:88,Train_acc:96.0%,Train_loss:0.360,Test_acc:93.2%,Test_loss:0.381,Lr:1.00E-04
Epoch:89,Train_acc:95.9%,Train_loss:0.353,Test_acc:93.7%,Test_loss:0.374,Lr:1.00E-04
Epoch:90,Train_acc:95.4%,Train_loss:0.361,Test_acc:93.2%,Test_loss:0.375,Lr:1.00E-04
Epoch:91,Train_acc:96.5%,Train_loss:0.347,Test_acc:95.1%,Test_loss:0.358,Lr:1.00E-04
Epoch:92,Train_acc:97.1%,Train_loss:0.342,Test_acc:94.9%,Test_loss:0.362,Lr:1.00E-04
Epoch:93,Train_acc:95.7%,Train_loss:0.363,Test_acc:93.7%,Test_loss:0.371,Lr:1.00E-04
Epoch:94,Train_acc:94.9%,Train_loss:0.361,Test_acc:95.1%,Test_loss:0.360,Lr:1.00E-04
Epoch:95,Train_acc:96.1%,Train_loss:0.360,Test_acc:95.1%,Test_loss:0.362,Lr:1.00E-04
Epoch:96,Train_acc:96.9%,Train_loss:0.352,Test_acc:93.7%,Test_loss:0.376,Lr:1.00E-04
Epoch:97,Train_acc:95.0%,Train_loss:0.367,Test_acc:93.7%,Test_loss:0.376,Lr:1.00E-04
Epoch:98,Train_acc:95.0%,Train_loss:0.371,Test_acc:95.1%,Test_loss:0.360,Lr:1.00E-04
Epoch:99,Train_acc:96.9%,Train_loss:0.352,Test_acc:95.6%,Test_loss:0.359,Lr:1.00E-04
Epoch:100,Train_acc:94.8%,Train_loss:0.374,Test_acc:93.7%,Test_loss:0.372,Lr:1.00E-04
Done
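To reuse the saved weights later (a minimal sketch of my own; same path as above):

#reload the best weights into a fresh model instance
J8_model=InceptionV1(num_classes=2).to(device)
J8_model.load_state_dict(torch.load(r'D:\THE MNIST DATABASE\J-series\J8_model.pth',map_location=device))
J8_model.eval()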

V. Visualizing the Results

1. Loss and accuracy curves

import matplotlib.pyplot as plt
#hide warnings
import warnings
warnings.filterwarnings("ignore")   #ignore warning messages
plt.rcParams['font.sans-serif']=['SimHei']   #use SimHei so Chinese labels render correctly
plt.rcParams['axes.unicode_minus']=False   #render minus signs correctly
plt.rcParams['figure.dpi']=300   #figure resolution
 
epochs_range=range(epochs)
plt.figure(figsize=(12,3))
 
plt.subplot(1,2,1)
plt.plot(epochs_range,train_acc,label='Training Accuracy')
plt.plot(epochs_range,test_acc,label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
 
plt.subplot(1,2,2)
plt.plot(epochs_range,train_loss,label='Training Loss')
plt.plot(epochs_range,test_loss,label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

Output:

2. Predict a specified image

from PIL import Image
 
classes=list(total_data.class_to_idx)
 
def predict_one_image(image_path,model,transform,classes):
    
    test_img=Image.open(image_path).convert('RGB')
    plt.imshow(test_img)   #show the image being predicted
    
    test_img=transform(test_img)
    img=test_img.to(device).unsqueeze(0)
    
    model.eval()
    output=model(img)
    
    _,pred=torch.max(output,1)
    pred_class=classes[pred]
    print(f'Predicted class: {pred_class}')

Run a prediction:

#predict one image from the training set
predict_one_image(image_path=r'D:\THE MNIST DATABASE\P4-data\Monkeypox\M01_02_00.jpg',
                  model=model,transform=train_transforms,classes=classes)

Output:

Predicted class: Monkeypox
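Note that predict_one_image was called with train_transforms, which includes RandomHorizontalFlip, so inference is not strictly deterministic. A deterministic alternative (my addition) simply drops the augmentation:

#deterministic preprocessing for inference (no random flip)
test_transforms=transforms.Compose([
    transforms.Resize([224,224]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])
])

predict_one_image(image_path=r'D:\THE MNIST DATABASE\P4-data\Monkeypox\M01_02_00.jpg',
                  model=model,transform=test_transforms,classes=classes)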

3. Evaluate the best model

J8_model.eval()
epoch_test_acc,epoch_test_loss=test(test_dl,J8_model,loss_fn)
epoch_test_acc,epoch_test_loss

Output:

(0.9627039627039627, 0.34947266622825907)

VI. Reflections

Working through this project, I experienced the full process of building an Inception V1 model in PyTorch, which deepened my understanding of the architecture. I also experimented with the learning rate during training: in practice, both 1e-7 and 1e-5 gave worse results than 1e-4.
