>- **🍨 This article is a learning-record post from the [🔗365天深度学习训练营 (365-day deep learning training camp)]**
>- **🍖 Original author: [K同学啊]**
📌 This week's tasks:
1. Understand how the computation cost of the convolutional layer in Figure 2 is calculated (🏐 background knowledge -> "computing a conv layer's operation count"; my derivation is included, but try deriving it by hand before checking it)
2. Study the parallel convolutional structure and the 1x1 convolution kernel (the key point)
3. Write the corresponding PyTorch code from the model architecture diagram and use Inception v1 to perform monkeypox recognition
🏡 My environment:
- Language: Python 3.8
- Editor: Jupyter Notebook
- Deep learning framework: PyTorch
- torch==2.3.1+cu118
- torchvision==0.18.1+cu118
I. Inception v1
Inception v1 paper: Going deeper with convolutions
1. Background
GoogLeNet first appeared in the 2014 ILSVRC competition, where it took first place. That version is usually referred to as Inception V1. Inception V1 is 22 layers deep with about 5M parameters. VGGNet, from the same period, performed comparably to Inception V1 but with a far larger parameter count.
The Inception Module is the core building block of Inception V1. It introduced a parallel structure of convolutional layers, so that different features can be extracted within the same layer, as shown in the figure below.
Increasing network depth by stacking such modules improves performance, but the computational cost (the parameter count) becomes a problem. To mitigate this, the Inception Module borrows the Network-in-Network idea and uses 1x1 convolution kernels for dimensionality reduction (which also indirectly deepens the network), cutting both the parameter count and the computation, as shown in figure (b) above.
A worked example: suppose the previous layer outputs 100x100x128 and is fed through a convolutional layer with 256 5x5 kernels (stride=1, pad=2). The output is 100x100x256, and that layer has 5x5x128x256+256 = 819,456 parameters. If the previous output instead first passes through a layer with 32 1x1 kernels (the 1x1 convolution reduces the channel count while leaving the feature-map size unchanged) and then through the layer with 256 5x5 kernels, the final output is still 100x100x256, but the parameter count drops to 128x1x1x32+32 + 32x5x5x256+256 = 209,184, roughly a quarter of the original. The multiplication count likewise drops, from about 100×100×256×5×5×128 ≈ 8.2×10⁹ to about 100×100×32×1×1×128 + 100×100×256×5×5×32 ≈ 2.1×10⁹. For a more detailed walkthrough, see my training-camp article "Computing a Conv Layer's Operation Count".
The role of the 1x1 kernel: its main purpose is to reduce the channel count of the input feature map, shrinking the network's parameter count and computation.
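To make the arithmetic above concrete, here is a small sketch that reproduces both the parameter counts and the multiplication counts (the shapes follow the worked example; nothing here comes from the original run):
# Sanity-check the 1x1 bottleneck arithmetic: H=W=100, 128 in-channels,
# 256 out-channels, bottleneck width 32
H,W,C_in,C_out,C_mid=100,100,128,256,32
# direct 5x5 convolution
params_direct=5*5*C_in*C_out+C_out                          # 819,456
mults_direct=H*W*C_out*5*5*C_in                             # ~8.19e9
# 1x1 bottleneck followed by 5x5 convolution
params_bneck=(1*1*C_in*C_mid+C_mid)+(5*5*C_mid*C_out+C_out) # 209,184
mults_bneck=H*W*C_mid*1*1*C_in+H*W*C_out*5*5*C_mid          # ~2.09e9
print(params_direct,params_bneck,round(params_direct/params_bneck,2)) # ratio ≈ 3.92
print(f"{mults_direct:.2e} -> {mults_bneck:.2e}")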
In the end, the Inception Module consists of four basic units: a 1x1 convolution, a 3x3 convolution, a 5x5 convolution, and a 3x3 max pooling. The outputs of the four units are concatenated along the channel dimension. Kernels of different sizes provide receptive fields of different sizes, so the module extracts and fuses information from the image at multiple scales, yielding a better representation. That is the core idea of the Inception Module.
2. Architecture
The Inception v1 network structure implemented here is shown below:
Note: the original network also adds two auxiliary branches, which serve two purposes. First, they help against vanishing gradients by feeding gradient back from intermediate depths (in the chain rule, a single zero derivative zeroes out the whole product). Second, they use an intermediate layer's output for classification, acting as a form of model ensembling. At test time these two auxiliary softmax branches are removed. Since later models rarely adopt this technique, you can skip it and focus on the parallel convolutional structure and the 1x1 convolution.
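For reference, here is a minimal sketch of such an auxiliary head (structure assumed from the paper; the reproduction below omits it entirely):
import torch
import torch.nn as nn
class InceptionAux(nn.Module):
    # auxiliary classifier head from the GoogLeNet paper (a sketch, not used below)
    def __init__(self,in_channels,num_classes):
        super().__init__()
        self.avgpool=nn.AdaptiveAvgPool2d((4,4))  # paper: 5x5 avg pool, stride 3 -> 4x4
        self.conv=nn.Conv2d(in_channels,128,kernel_size=1)  # 1x1 conv for reduction
        self.fc1=nn.Linear(128*4*4,1024)
        self.fc2=nn.Linear(1024,num_classes)
        self.relu=nn.ReLU(inplace=True)
        self.dropout=nn.Dropout(0.7)  # the paper uses 70% dropout here
    def forward(self,x):
        x=self.relu(self.conv(self.avgpool(x)))
        x=torch.flatten(x,1)
        x=self.dropout(self.relu(self.fc1(x)))
        return self.fc2(x)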
The detailed network structure is shown below:
II. Preparation
1. Set up the GPU
Use the GPU if the device supports one; otherwise fall back to the CPU.
import warnings
warnings.filterwarnings("ignore")
import torch
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
Output:
device(type='cuda')
2. Import the data
Also count the number of images in the dataset.
import pathlib
data_dir=r'D:\THE MNIST DATABASE\P4-data'
data_dir=pathlib.Path(data_dir)
image_count=len(list(data_dir.glob('*/*')))
image_count
Output:
2142
3. Inspect the dataset classes
data_paths=list(data_dir.glob('*'))
classNames=[path.name for path in data_paths] #the folder name on the path is the class name
classNames
Output:
['Monkeypox', 'Others']
4. View random images
Randomly sample 20 images from the dataset for inspection.
import PIL,random
import matplotlib.pyplot as plt
from PIL import Image
data_paths2=list(data_dir.glob('*/*'))
plt.figure(figsize=(20,4))
for i in range(20):
    plt.subplot(2,10,i+1)
    plt.axis('off')
    image=random.choice(data_paths2) #randomly pick an image
    plt.title(image.parts[-2]) #the parent folder's name on the path is the class name
    plt.imshow(Image.open(str(image))) #display the image
Output:
5. Image preprocessing
from torchvision import transforms,datasets
train_transforms=transforms.Compose([
    transforms.Resize([224,224]), #resize all images to a uniform size
    transforms.RandomHorizontalFlip(), #randomly flip images horizontally
    transforms.ToTensor(), #convert images to tensors
    transforms.Normalize( #normalize --> map toward a standard distribution so the model converges more easily
        mean=[0.485,0.456,0.406],
        std=[0.229,0.224,0.225]
    )
])
total_data=datasets.ImageFolder(
r'D:\THE MNIST DATABASE\P4-data',
transform=train_transforms
)
total_data
Output:
Dataset ImageFolder
Number of datapoints: 2142
Root location: D:\THE MNIST DATABASE\P4-data
StandardTransform
Transform: Compose(
Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
RandomHorizontalFlip(p=0.5)
ToTensor()
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
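One caveat: because the dataset is split only after ImageFolder is built, the test subset below also passes through RandomHorizontalFlip. The run in this post keeps it that way; a common refinement (a sketch, not applied here) is a deterministic transform for evaluation:
# Sketch: a flip-free transform for validation/testing (not used in this run)
test_transforms=transforms.Compose([
    transforms.Resize([224,224]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406],
                         std=[0.229,0.224,0.225])
])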
Print the dataset's class-to-index mapping:
total_data.class_to_idx
Output:
{'Monkeypox': 0, 'Others': 1}
6. Split the dataset
train_size=int(0.8*len(total_data))
test_size=len(total_data)-train_size
train_dataset,test_dataset=torch.utils.data.random_split(
total_data,
[train_size,test_size]
)
train_dataset,test_dataset
Output:
(<torch.utils.data.dataset.Subset at 0x241dec0e950>,
<torch.utils.data.dataset.Subset at 0x241deef32d0>)
Check the sizes of the training and test sets:
train_size,test_size
Output:
(1713, 429)
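Note that random_split draws a different split on every run. For a reproducible split you can pass a seeded generator (a sketch; the run above did not do this):
# Sketch: reproducible 80/20 split with a fixed seed
g=torch.Generator().manual_seed(42)
train_dataset,test_dataset=torch.utils.data.random_split(
    total_data,[train_size,test_size],generator=g
)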
7. Load the datasets
batch_size=16
train_dl=torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=1
)
test_dl=torch.utils.data.DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=1
)
Check one batch from the training loader:
for x,y in train_dl:
print("Shape of x [N,C,H,W]:",x.shape)
print("Shape of y:",y.shape,y.dtype)
break
Output:
Shape of x [N,C,H,W]: torch.Size([16, 3, 224, 224])
Shape of y: torch.Size([16]) torch.int64
III. Reproducing the Inception V1 Model in PyTorch
Build the Inception V1 network following the architecture and parameters in the table above.
1. Build the model
The two auxiliary branches are dropped here; only the main path is reproduced.
Define a class named inception_block that inherits from nn.Module; it contains all the layers and parameters of one Inception module.
import torch.nn as nn
import torch.nn.functional as F
class inception_block(nn.Module):
def __init__(self,in_channels,ch1x1,ch3x3red,ch3x3,ch5x5red,ch5x5,pool_proj):
super(inception_block,self).__init__()
# 1x1 conv branch
self.branch1=nn.Sequential(
nn.Conv2d(in_channels,ch1x1,kernel_size=1),
nn.BatchNorm2d(ch1x1),
nn.ReLU(inplace=True)
)
# 1x1 conv -> 3x3 conv branch
self.branch2=nn.Sequential(
nn.Conv2d(in_channels,ch3x3red,kernel_size=1),
nn.BatchNorm2d(ch3x3red),
nn.ReLU(inplace=True),
nn.Conv2d(ch3x3red,ch3x3,kernel_size=3,padding=1),
nn.BatchNorm2d(ch3x3),
nn.ReLU(inplace=True)
)
# 1x1 conv -> 5x5 conv branch
self.branch3=nn.Sequential(
nn.Conv2d(in_channels,ch5x5red,kernel_size=1),
nn.BatchNorm2d(ch5x5red),
nn.ReLU(inplace=True),
nn.Conv2d(ch5x5red,ch5x5,kernel_size=5,padding=2),
nn.BatchNorm2d(ch5x5),
nn.ReLU(inplace=True)
)
# 3x3 max pooling -> 1x1 conv branch
self.branch4=nn.Sequential(
nn.MaxPool2d(kernel_size=3,stride=1,padding=1),
nn.Conv2d(in_channels,pool_proj,kernel_size=1),
nn.BatchNorm2d(pool_proj),
nn.ReLU(inplace=True)
)
def forward(self,x):
#compute forward pass through all branches and concatenate the output feature maps
branch1_output=self.branch1(x)
branch2_output=self.branch2(x)
branch3_output=self.branch3(x)
branch4_output=self.branch4(x)
outputs=[branch1_output,branch2_output,branch3_output,branch4_output]
return torch.cat(outputs,1)
In the __init__ method, four branches are defined:
(1) branch1: a 1x1 convolution;
(2) branch2: a 1x1 convolution followed by a 3x3 convolution;
(3) branch3: a 1x1 convolution followed by a 5x5 convolution;
(4) branch4: a 3x3 max pooling followed by a 1x1 convolution.
Each branch is built from convolution, batch-normalization, and activation layers, all standard PyTorch layers: nn.Conv2d, nn.BatchNorm2d, and nn.ReLU respectively.
In the forward method, the input is passed through all four branches and their output feature maps are concatenated along the channel dimension; the concatenated result is returned.
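A quick way to confirm the concatenation is to push a dummy tensor through one block (a sketch; the channel counts match inception3a below, and 64+128+32+32 = 256):
# Sketch: verify the output channel count of a single Inception block
block=inception_block(192,64,96,128,16,32,32)
dummy=torch.randn(1,192,28,28)
print(block(dummy).shape) # expected: torch.Size([1, 256, 28, 28])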
Next, we define the Inception v1 model itself, using nn.Sequential to combine the Inception modules with the other layers.
class InceptionV1(nn.Module):
def __init__(self,num_classes=1000):
super(InceptionV1,self).__init__()
self.conv1=nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3)
self.maxpool1=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
self.conv2=nn.Conv2d(64,64,kernel_size=1,stride=1,padding=0)
self.conv3=nn.Conv2d(64,192,kernel_size=3,stride=1,padding=1)
self.maxpool2=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
self.inception3a=inception_block(192,64,96,128,16,32,32)
self.inception3b=inception_block(256,128,128,192,32,96,64)
self.maxpool3=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
self.inception4a=inception_block(480,192,96,208,16,48,64)
self.inception4b=inception_block(512,160,112,224,24,64,64)
self.inception4c=inception_block(512,128,128,256,24,64,64)
        self.inception4d=inception_block(512,112,114,288,32,64,64) # note: the original paper uses 144 (not 114) for the 3x3 reduce here; 114 is kept to match the summary output below
self.inception4e=inception_block(528,256,160,320,32,128,128)
self.maxpool4=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
self.inception5a=inception_block(832,256,160,320,32,128,128)
self.inception5b=nn.Sequential(
inception_block(832,384,192,384,48,128,128),
nn.AvgPool2d(kernel_size=7,stride=1,padding=0),
nn.Dropout(0.4)
)
        #fully connected layers for classification
self.classifier=nn.Sequential(
nn.Linear(in_features=1024,out_features=1024),
nn.ReLU(),
nn.Linear(in_features=1024,out_features=num_classes),
            nn.Softmax(dim=1) # note: nn.CrossEntropyLoss (used below) already applies log-softmax internally, so this layer is normally omitted; kept to match the original run
)
def forward(self,x):
x=self.conv1(x)
x=F.relu(x)
x=self.maxpool1(x)
x=self.conv2(x)
x=F.relu(x)
x=self.conv3(x)
x=F.relu(x)
x=self.maxpool2(x)
x=self.inception3a(x)
x=self.inception3b(x)
x=self.maxpool3(x)
x=self.inception4a(x)
x=self.inception4b(x)
x=self.inception4c(x)
x=self.inception4d(x)
x=self.inception4e(x)
x=self.maxpool4(x)
x=self.inception5a(x)
x=self.inception5b(x)
x=torch.flatten(x,start_dim=1)
x=self.classifier(x)
return x
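Because AvgPool2d(kernel_size=7) expects the 7x7 feature map produced from a 224x224 input, the model is tied to that input size. A quick forward-pass check (a sketch):
# Sketch: sanity-check the forward pass with a 224x224 input
net=InceptionV1(num_classes=2)
out=net(torch.randn(2,3,224,224))
print(out.shape)      # expected: torch.Size([2, 2])
print(out.sum(dim=1)) # each row sums to 1 because of the final Softmax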
2. Print the model summary
#summarize the parameter count and other statistics
import torchsummary
#instantiate the model and move it to the GPU
model=InceptionV1(num_classes=2).to(device)
#display the network structure
torchsummary.summary(model,(3,224,224))
print(model)
Output:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 112, 112] 9,472
MaxPool2d-2 [-1, 64, 56, 56] 0
Conv2d-3 [-1, 64, 56, 56] 4,160
Conv2d-4 [-1, 192, 56, 56] 110,784
MaxPool2d-5 [-1, 192, 28, 28] 0
Conv2d-6 [-1, 64, 28, 28] 12,352
BatchNorm2d-7 [-1, 64, 28, 28] 128
ReLU-8 [-1, 64, 28, 28] 0
Conv2d-9 [-1, 96, 28, 28] 18,528
BatchNorm2d-10 [-1, 96, 28, 28] 192
ReLU-11 [-1, 96, 28, 28] 0
Conv2d-12 [-1, 128, 28, 28] 110,720
BatchNorm2d-13 [-1, 128, 28, 28] 256
ReLU-14 [-1, 128, 28, 28] 0
Conv2d-15 [-1, 16, 28, 28] 3,088
BatchNorm2d-16 [-1, 16, 28, 28] 32
ReLU-17 [-1, 16, 28, 28] 0
Conv2d-18 [-1, 32, 28, 28] 12,832
BatchNorm2d-19 [-1, 32, 28, 28] 64
ReLU-20 [-1, 32, 28, 28] 0
MaxPool2d-21 [-1, 192, 28, 28] 0
Conv2d-22 [-1, 32, 28, 28] 6,176
BatchNorm2d-23 [-1, 32, 28, 28] 64
ReLU-24 [-1, 32, 28, 28] 0
inception_block-25 [-1, 256, 28, 28] 0
Conv2d-26 [-1, 128, 28, 28] 32,896
BatchNorm2d-27 [-1, 128, 28, 28] 256
ReLU-28 [-1, 128, 28, 28] 0
Conv2d-29 [-1, 128, 28, 28] 32,896
BatchNorm2d-30 [-1, 128, 28, 28] 256
ReLU-31 [-1, 128, 28, 28] 0
Conv2d-32 [-1, 192, 28, 28] 221,376
BatchNorm2d-33 [-1, 192, 28, 28] 384
ReLU-34 [-1, 192, 28, 28] 0
Conv2d-35 [-1, 32, 28, 28] 8,224
BatchNorm2d-36 [-1, 32, 28, 28] 64
ReLU-37 [-1, 32, 28, 28] 0
Conv2d-38 [-1, 96, 28, 28] 76,896
BatchNorm2d-39 [-1, 96, 28, 28] 192
ReLU-40 [-1, 96, 28, 28] 0
MaxPool2d-41 [-1, 256, 28, 28] 0
Conv2d-42 [-1, 64, 28, 28] 16,448
BatchNorm2d-43 [-1, 64, 28, 28] 128
ReLU-44 [-1, 64, 28, 28] 0
inception_block-45 [-1, 480, 28, 28] 0
MaxPool2d-46 [-1, 480, 14, 14] 0
Conv2d-47 [-1, 192, 14, 14] 92,352
BatchNorm2d-48 [-1, 192, 14, 14] 384
ReLU-49 [-1, 192, 14, 14] 0
Conv2d-50 [-1, 96, 14, 14] 46,176
BatchNorm2d-51 [-1, 96, 14, 14] 192
ReLU-52 [-1, 96, 14, 14] 0
Conv2d-53 [-1, 208, 14, 14] 179,920
BatchNorm2d-54 [-1, 208, 14, 14] 416
ReLU-55 [-1, 208, 14, 14] 0
Conv2d-56 [-1, 16, 14, 14] 7,696
BatchNorm2d-57 [-1, 16, 14, 14] 32
ReLU-58 [-1, 16, 14, 14] 0
Conv2d-59 [-1, 48, 14, 14] 19,248
BatchNorm2d-60 [-1, 48, 14, 14] 96
ReLU-61 [-1, 48, 14, 14] 0
MaxPool2d-62 [-1, 480, 14, 14] 0
Conv2d-63 [-1, 64, 14, 14] 30,784
BatchNorm2d-64 [-1, 64, 14, 14] 128
ReLU-65 [-1, 64, 14, 14] 0
inception_block-66 [-1, 512, 14, 14] 0
Conv2d-67 [-1, 160, 14, 14] 82,080
BatchNorm2d-68 [-1, 160, 14, 14] 320
ReLU-69 [-1, 160, 14, 14] 0
Conv2d-70 [-1, 112, 14, 14] 57,456
BatchNorm2d-71 [-1, 112, 14, 14] 224
ReLU-72 [-1, 112, 14, 14] 0
Conv2d-73 [-1, 224, 14, 14] 226,016
BatchNorm2d-74 [-1, 224, 14, 14] 448
ReLU-75 [-1, 224, 14, 14] 0
Conv2d-76 [-1, 24, 14, 14] 12,312
BatchNorm2d-77 [-1, 24, 14, 14] 48
ReLU-78 [-1, 24, 14, 14] 0
Conv2d-79 [-1, 64, 14, 14] 38,464
BatchNorm2d-80 [-1, 64, 14, 14] 128
ReLU-81 [-1, 64, 14, 14] 0
MaxPool2d-82 [-1, 512, 14, 14] 0
Conv2d-83 [-1, 64, 14, 14] 32,832
BatchNorm2d-84 [-1, 64, 14, 14] 128
ReLU-85 [-1, 64, 14, 14] 0
inception_block-86 [-1, 512, 14, 14] 0
Conv2d-87 [-1, 128, 14, 14] 65,664
BatchNorm2d-88 [-1, 128, 14, 14] 256
ReLU-89 [-1, 128, 14, 14] 0
Conv2d-90 [-1, 128, 14, 14] 65,664
BatchNorm2d-91 [-1, 128, 14, 14] 256
ReLU-92 [-1, 128, 14, 14] 0
Conv2d-93 [-1, 256, 14, 14] 295,168
BatchNorm2d-94 [-1, 256, 14, 14] 512
ReLU-95 [-1, 256, 14, 14] 0
Conv2d-96 [-1, 24, 14, 14] 12,312
BatchNorm2d-97 [-1, 24, 14, 14] 48
ReLU-98 [-1, 24, 14, 14] 0
Conv2d-99 [-1, 64, 14, 14] 38,464
BatchNorm2d-100 [-1, 64, 14, 14] 128
ReLU-101 [-1, 64, 14, 14] 0
MaxPool2d-102 [-1, 512, 14, 14] 0
Conv2d-103 [-1, 64, 14, 14] 32,832
BatchNorm2d-104 [-1, 64, 14, 14] 128
ReLU-105 [-1, 64, 14, 14] 0
inception_block-106 [-1, 512, 14, 14] 0
Conv2d-107 [-1, 112, 14, 14] 57,456
BatchNorm2d-108 [-1, 112, 14, 14] 224
ReLU-109 [-1, 112, 14, 14] 0
Conv2d-110 [-1, 114, 14, 14] 58,482
BatchNorm2d-111 [-1, 114, 14, 14] 228
ReLU-112 [-1, 114, 14, 14] 0
Conv2d-113 [-1, 288, 14, 14] 295,776
BatchNorm2d-114 [-1, 288, 14, 14] 576
ReLU-115 [-1, 288, 14, 14] 0
Conv2d-116 [-1, 32, 14, 14] 16,416
BatchNorm2d-117 [-1, 32, 14, 14] 64
ReLU-118 [-1, 32, 14, 14] 0
Conv2d-119 [-1, 64, 14, 14] 51,264
BatchNorm2d-120 [-1, 64, 14, 14] 128
ReLU-121 [-1, 64, 14, 14] 0
MaxPool2d-122 [-1, 512, 14, 14] 0
Conv2d-123 [-1, 64, 14, 14] 32,832
BatchNorm2d-124 [-1, 64, 14, 14] 128
ReLU-125 [-1, 64, 14, 14] 0
inception_block-126 [-1, 528, 14, 14] 0
Conv2d-127 [-1, 256, 14, 14] 135,424
BatchNorm2d-128 [-1, 256, 14, 14] 512
ReLU-129 [-1, 256, 14, 14] 0
Conv2d-130 [-1, 160, 14, 14] 84,640
BatchNorm2d-131 [-1, 160, 14, 14] 320
ReLU-132 [-1, 160, 14, 14] 0
Conv2d-133 [-1, 320, 14, 14] 461,120
BatchNorm2d-134 [-1, 320, 14, 14] 640
ReLU-135 [-1, 320, 14, 14] 0
Conv2d-136 [-1, 32, 14, 14] 16,928
BatchNorm2d-137 [-1, 32, 14, 14] 64
ReLU-138 [-1, 32, 14, 14] 0
Conv2d-139 [-1, 128, 14, 14] 102,528
BatchNorm2d-140 [-1, 128, 14, 14] 256
ReLU-141 [-1, 128, 14, 14] 0
MaxPool2d-142 [-1, 528, 14, 14] 0
Conv2d-143 [-1, 128, 14, 14] 67,712
BatchNorm2d-144 [-1, 128, 14, 14] 256
ReLU-145 [-1, 128, 14, 14] 0
inception_block-146 [-1, 832, 14, 14] 0
MaxPool2d-147 [-1, 832, 7, 7] 0
Conv2d-148 [-1, 256, 7, 7] 213,248
BatchNorm2d-149 [-1, 256, 7, 7] 512
ReLU-150 [-1, 256, 7, 7] 0
Conv2d-151 [-1, 160, 7, 7] 133,280
BatchNorm2d-152 [-1, 160, 7, 7] 320
ReLU-153 [-1, 160, 7, 7] 0
Conv2d-154 [-1, 320, 7, 7] 461,120
BatchNorm2d-155 [-1, 320, 7, 7] 640
ReLU-156 [-1, 320, 7, 7] 0
Conv2d-157 [-1, 32, 7, 7] 26,656
BatchNorm2d-158 [-1, 32, 7, 7] 64
ReLU-159 [-1, 32, 7, 7] 0
Conv2d-160 [-1, 128, 7, 7] 102,528
BatchNorm2d-161 [-1, 128, 7, 7] 256
ReLU-162 [-1, 128, 7, 7] 0
MaxPool2d-163 [-1, 832, 7, 7] 0
Conv2d-164 [-1, 128, 7, 7] 106,624
BatchNorm2d-165 [-1, 128, 7, 7] 256
ReLU-166 [-1, 128, 7, 7] 0
inception_block-167 [-1, 832, 7, 7] 0
Conv2d-168 [-1, 384, 7, 7] 319,872
BatchNorm2d-169 [-1, 384, 7, 7] 768
ReLU-170 [-1, 384, 7, 7] 0
Conv2d-171 [-1, 192, 7, 7] 159,936
BatchNorm2d-172 [-1, 192, 7, 7] 384
ReLU-173 [-1, 192, 7, 7] 0
Conv2d-174 [-1, 384, 7, 7] 663,936
BatchNorm2d-175 [-1, 384, 7, 7] 768
ReLU-176 [-1, 384, 7, 7] 0
Conv2d-177 [-1, 48, 7, 7] 39,984
BatchNorm2d-178 [-1, 48, 7, 7] 96
ReLU-179 [-1, 48, 7, 7] 0
Conv2d-180 [-1, 128, 7, 7] 153,728
BatchNorm2d-181 [-1, 128, 7, 7] 256
ReLU-182 [-1, 128, 7, 7] 0
MaxPool2d-183 [-1, 832, 7, 7] 0
Conv2d-184 [-1, 128, 7, 7] 106,624
BatchNorm2d-185 [-1, 128, 7, 7] 256
ReLU-186 [-1, 128, 7, 7] 0
inception_block-187 [-1, 1024, 7, 7] 0
AvgPool2d-188 [-1, 1024, 1, 1] 0
Dropout-189 [-1, 1024, 1, 1] 0
Linear-190 [-1, 1024] 1,049,600
ReLU-191 [-1, 1024] 0
Linear-192 [-1, 2] 2,050
Softmax-193 [-1, 2] 0
================================================================
Total params: 6,945,912
Trainable params: 6,945,912
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 69.48
Params size (MB): 26.50
Estimated Total Size (MB): 96.55
----------------------------------------------------------------
InceptionV1(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
(maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(conv2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(conv3): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(maxpool2): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(inception3a): inception_block(
(branch1): Sequential(
(0): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(branch2): Sequential(
(0): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
(branch3): Sequential(
(0): Conv2d(192, 16, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
(branch4): Sequential(
(0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
(1): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1))
(2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
)
)
(inception3b): inception_block(
(branch1): Sequential(
(0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(branch2): Sequential(
(0): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
(branch3): Sequential(
(0): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(32, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(4): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
(branch4): Sequential(
(0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
(1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
(2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
)
)
(maxpool3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(inception4a): inception_block(
(branch1): Sequential(
(0): Conv2d(480, 192, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(branch2): Sequential(
(0): Conv2d(480, 96, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(96, 208, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(208, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
(branch3): Sequential(
(0): Conv2d(480, 16, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(16, 48, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(4): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
(branch4): Sequential(
(0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
(1): Conv2d(480, 64, kernel_size=(1, 1), stride=(1, 1))
(2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
)
)
(inception4b): inception_block(
(branch1): Sequential(
(0): Conv2d(512, 160, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(branch2): Sequential(
(0): Conv2d(512, 112, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(112, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
(branch3): Sequential(
(0): Conv2d(512, 24, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(24, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
(branch4): Sequential(
(0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
(1): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))
(2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
)
)
(inception4c): inception_block(
(branch1): Sequential(
(0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(branch2): Sequential(
(0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
(branch3): Sequential(
(0): Conv2d(512, 24, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(24, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
(branch4): Sequential(
(0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
(1): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))
(2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
)
)
(inception4d): inception_block(
(branch1): Sequential(
(0): Conv2d(512, 112, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(branch2): Sequential(
(0): Conv2d(512, 114, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(114, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(114, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
(branch3): Sequential(
(0): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
(branch4): Sequential(
(0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
(1): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1))
(2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
)
)
(inception4e): inception_block(
(branch1): Sequential(
(0): Conv2d(528, 256, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(branch2): Sequential(
(0): Conv2d(528, 160, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(160, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
(branch3): Sequential(
(0): Conv2d(528, 32, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(32, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
(branch4): Sequential(
(0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
(1): Conv2d(528, 128, kernel_size=(1, 1), stride=(1, 1))
(2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
)
)
(maxpool4): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(inception5a): inception_block(
(branch1): Sequential(
(0): Conv2d(832, 256, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(branch2): Sequential(
(0): Conv2d(832, 160, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(160, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
(branch3): Sequential(
(0): Conv2d(832, 32, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(32, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
(branch4): Sequential(
(0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
(1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1))
(2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
)
)
(inception5b): Sequential(
(0): inception_block(
(branch1): Sequential(
(0): Conv2d(832, 384, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
(branch2): Sequential(
(0): Conv2d(832, 192, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
(branch3): Sequential(
(0): Conv2d(832, 48, kernel_size=(1, 1), stride=(1, 1))
(1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(48, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
)
(branch4): Sequential(
(0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
(1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1))
(2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
)
)
(1): AvgPool2d(kernel_size=7, stride=1, padding=0)
(2): Dropout(p=0.4, inplace=False)
)
(classifier): Sequential(
(0): Linear(in_features=1024, out_features=1024, bias=True)
(1): ReLU()
(2): Linear(in_features=1024, out_features=2, bias=True)
(3): Softmax(dim=1)
)
)
IV. Train the Model
1. Write the training function
def train(dataloader,model,loss_fn,optimizer):
    size=len(dataloader.dataset) #size of the training set
    num_batches=len(dataloader) #number of batches
    train_loss,train_acc=0,0 #initialize training loss and accuracy
    for x,y in dataloader: #fetch images and their labels
        x,y=x.to(device),y.to(device)
        #compute the prediction error
        pred=model(x) #network output
        loss=loss_fn(pred,y) #the gap between the network output and the ground truth is the loss
        #backpropagation
        optimizer.zero_grad() #zero out the gradients
        loss.backward() #backpropagate
        optimizer.step() #update the weights
        #record acc and loss
        train_acc+=(pred.argmax(1)==y).type(torch.float).sum().item()
        train_loss+=loss.item()
    train_acc/=size
    train_loss/=num_batches
    return train_acc,train_loss
2. Write the test function
The test function is largely the same as the training function, but since no gradient descent is performed to update the network weights, no optimizer needs to be passed in.
#test function
def test(dataloader,model,loss_fn):
    size=len(dataloader.dataset) #size of the test set
    num_batches=len(dataloader) #number of batches
    test_loss,test_acc=0,0
    #stop gradient tracking when not training, saving compute and memory
    with torch.no_grad():
        for imgs,target in dataloader:
            imgs,target=imgs.to(device),target.to(device)
            #compute loss
            target_pred=model(imgs)
            loss=loss_fn(target_pred,target)
            test_loss+=loss.item()
            test_acc+=(target_pred.argmax(1)==target).type(torch.float).sum().item()
    test_acc/=size
    test_loss/=num_batches
    return test_acc,test_loss
3. Run the training
import copy
optimizer=torch.optim.Adam(model.parameters(),lr=1e-4) #create the optimizer and set the learning rate
loss_fn=nn.CrossEntropyLoss() #create the loss function
epochs=100
train_loss=[]
train_acc=[]
test_loss=[]
test_acc=[]
best_acc=0 #track the best accuracy as the criterion for the best model
for epoch in range(epochs):
    model.train()
    epoch_train_acc,epoch_train_loss=train(train_dl,model,loss_fn,optimizer)
    model.eval()
    epoch_test_acc,epoch_test_loss=test(test_dl,model,loss_fn)
    #keep a copy of the best model in J8_model
    if epoch_test_acc>best_acc:
        best_acc=epoch_test_acc
        J8_model=copy.deepcopy(model)
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    #get the current learning rate
    lr=optimizer.state_dict()['param_groups'][0]['lr']
    template=('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},Lr:{:.2E}')
    print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,
                          epoch_test_acc*100,epoch_test_loss,lr))
#save the best model to a file (the original code saved the final model's weights here; saving J8_model matches the stated intent)
PATH=r'D:\THE MNIST DATABASE\J-series\J8_model.pth'
torch.save(J8_model.state_dict(),PATH)
print('Done')
Output:
Epoch: 1,Train_acc:66.4%,Train_loss:0.633,Test_acc:61.5%,Test_loss:0.656,Lr:1.00E-04
Epoch: 2,Train_acc:67.3%,Train_loss:0.616,Test_acc:68.1%,Test_loss:0.624,Lr:1.00E-04
Epoch: 3,Train_acc:69.9%,Train_loss:0.591,Test_acc:68.5%,Test_loss:0.613,Lr:1.00E-04
Epoch: 4,Train_acc:73.3%,Train_loss:0.574,Test_acc:68.1%,Test_loss:0.605,Lr:1.00E-04
Epoch: 5,Train_acc:75.3%,Train_loss:0.556,Test_acc:65.0%,Test_loss:0.651,Lr:1.00E-04
Epoch: 6,Train_acc:74.9%,Train_loss:0.551,Test_acc:73.0%,Test_loss:0.563,Lr:1.00E-04
Epoch: 7,Train_acc:77.4%,Train_loss:0.529,Test_acc:80.4%,Test_loss:0.500,Lr:1.00E-04
Epoch: 8,Train_acc:80.2%,Train_loss:0.503,Test_acc:78.1%,Test_loss:0.518,Lr:1.00E-04
Epoch: 9,Train_acc:79.6%,Train_loss:0.509,Test_acc:77.4%,Test_loss:0.533,Lr:1.00E-04
Epoch:10,Train_acc:80.0%,Train_loss:0.504,Test_acc:77.6%,Test_loss:0.528,Lr:1.00E-04
Epoch:11,Train_acc:85.4%,Train_loss:0.456,Test_acc:86.0%,Test_loss:0.448,Lr:1.00E-04
Epoch:12,Train_acc:84.0%,Train_loss:0.466,Test_acc:84.4%,Test_loss:0.469,Lr:1.00E-04
Epoch:13,Train_acc:84.1%,Train_loss:0.467,Test_acc:80.0%,Test_loss:0.511,Lr:1.00E-04
Epoch:14,Train_acc:85.8%,Train_loss:0.449,Test_acc:77.4%,Test_loss:0.534,Lr:1.00E-04
Epoch:15,Train_acc:85.5%,Train_loss:0.460,Test_acc:83.2%,Test_loss:0.471,Lr:1.00E-04
Epoch:16,Train_acc:84.8%,Train_loss:0.464,Test_acc:81.8%,Test_loss:0.490,Lr:1.00E-04
Epoch:17,Train_acc:84.2%,Train_loss:0.467,Test_acc:76.5%,Test_loss:0.540,Lr:1.00E-04
Epoch:18,Train_acc:83.7%,Train_loss:0.471,Test_acc:84.6%,Test_loss:0.467,Lr:1.00E-04
Epoch:19,Train_acc:87.3%,Train_loss:0.436,Test_acc:88.3%,Test_loss:0.427,Lr:1.00E-04
Epoch:20,Train_acc:86.2%,Train_loss:0.454,Test_acc:82.5%,Test_loss:0.487,Lr:1.00E-04
Epoch:21,Train_acc:86.2%,Train_loss:0.444,Test_acc:87.9%,Test_loss:0.430,Lr:1.00E-04
Epoch:22,Train_acc:88.3%,Train_loss:0.437,Test_acc:87.9%,Test_loss:0.436,Lr:1.00E-04
Epoch:23,Train_acc:89.2%,Train_loss:0.423,Test_acc:86.7%,Test_loss:0.442,Lr:1.00E-04
Epoch:24,Train_acc:89.6%,Train_loss:0.424,Test_acc:89.0%,Test_loss:0.421,Lr:1.00E-04
Epoch:25,Train_acc:90.6%,Train_loss:0.405,Test_acc:91.6%,Test_loss:0.402,Lr:1.00E-04
Epoch:26,Train_acc:90.7%,Train_loss:0.413,Test_acc:89.7%,Test_loss:0.412,Lr:1.00E-04
Epoch:27,Train_acc:90.5%,Train_loss:0.406,Test_acc:87.6%,Test_loss:0.431,Lr:1.00E-04
Epoch:28,Train_acc:87.6%,Train_loss:0.427,Test_acc:86.0%,Test_loss:0.451,Lr:1.00E-04
Epoch:29,Train_acc:89.1%,Train_loss:0.417,Test_acc:89.0%,Test_loss:0.421,Lr:1.00E-04
Epoch:30,Train_acc:91.5%,Train_loss:0.393,Test_acc:90.7%,Test_loss:0.406,Lr:1.00E-04
Epoch:31,Train_acc:92.1%,Train_loss:0.395,Test_acc:88.1%,Test_loss:0.427,Lr:1.00E-04
Epoch:32,Train_acc:93.0%,Train_loss:0.385,Test_acc:88.8%,Test_loss:0.418,Lr:1.00E-04
Epoch:33,Train_acc:91.4%,Train_loss:0.397,Test_acc:91.1%,Test_loss:0.402,Lr:1.00E-04
Epoch:34,Train_acc:92.5%,Train_loss:0.385,Test_acc:88.8%,Test_loss:0.425,Lr:1.00E-04
Epoch:35,Train_acc:91.8%,Train_loss:0.400,Test_acc:92.1%,Test_loss:0.391,Lr:1.00E-04
Epoch:36,Train_acc:91.9%,Train_loss:0.396,Test_acc:92.1%,Test_loss:0.390,Lr:1.00E-04
Epoch:37,Train_acc:90.5%,Train_loss:0.409,Test_acc:90.0%,Test_loss:0.413,Lr:1.00E-04
Epoch:38,Train_acc:93.1%,Train_loss:0.381,Test_acc:86.0%,Test_loss:0.444,Lr:1.00E-04
Epoch:39,Train_acc:93.1%,Train_loss:0.381,Test_acc:93.7%,Test_loss:0.379,Lr:1.00E-04
Epoch:40,Train_acc:93.5%,Train_loss:0.387,Test_acc:93.0%,Test_loss:0.381,Lr:1.00E-04
Epoch:41,Train_acc:94.1%,Train_loss:0.379,Test_acc:91.8%,Test_loss:0.394,Lr:1.00E-04
Epoch:42,Train_acc:93.6%,Train_loss:0.377,Test_acc:93.2%,Test_loss:0.381,Lr:1.00E-04
Epoch:43,Train_acc:93.9%,Train_loss:0.380,Test_acc:92.5%,Test_loss:0.384,Lr:1.00E-04
Epoch:44,Train_acc:93.9%,Train_loss:0.381,Test_acc:92.5%,Test_loss:0.384,Lr:1.00E-04
Epoch:45,Train_acc:89.6%,Train_loss:0.413,Test_acc:92.3%,Test_loss:0.388,Lr:1.00E-04
Epoch:46,Train_acc:91.7%,Train_loss:0.395,Test_acc:90.9%,Test_loss:0.401,Lr:1.00E-04
Epoch:47,Train_acc:93.4%,Train_loss:0.387,Test_acc:90.4%,Test_loss:0.407,Lr:1.00E-04
Epoch:48,Train_acc:93.6%,Train_loss:0.375,Test_acc:92.1%,Test_loss:0.388,Lr:1.00E-04
Epoch:49,Train_acc:93.8%,Train_loss:0.375,Test_acc:94.9%,Test_loss:0.367,Lr:1.00E-04
Epoch:50,Train_acc:94.2%,Train_loss:0.379,Test_acc:92.1%,Test_loss:0.390,Lr:1.00E-04
Epoch:51,Train_acc:93.6%,Train_loss:0.385,Test_acc:92.5%,Test_loss:0.383,Lr:1.00E-04
Epoch:52,Train_acc:93.9%,Train_loss:0.380,Test_acc:90.0%,Test_loss:0.414,Lr:1.00E-04
Epoch:53,Train_acc:93.2%,Train_loss:0.378,Test_acc:90.2%,Test_loss:0.412,Lr:1.00E-04
Epoch:54,Train_acc:92.6%,Train_loss:0.393,Test_acc:85.5%,Test_loss:0.454,Lr:1.00E-04
Epoch:55,Train_acc:91.9%,Train_loss:0.397,Test_acc:91.6%,Test_loss:0.398,Lr:1.00E-04
Epoch:56,Train_acc:94.3%,Train_loss:0.368,Test_acc:92.5%,Test_loss:0.384,Lr:1.00E-04
Epoch:57,Train_acc:95.7%,Train_loss:0.354,Test_acc:93.5%,Test_loss:0.379,Lr:1.00E-04
Epoch:58,Train_acc:94.2%,Train_loss:0.380,Test_acc:93.5%,Test_loss:0.377,Lr:1.00E-04
Epoch:59,Train_acc:95.3%,Train_loss:0.361,Test_acc:93.0%,Test_loss:0.381,Lr:1.00E-04
Epoch:60,Train_acc:92.5%,Train_loss:0.385,Test_acc:90.0%,Test_loss:0.412,Lr:1.00E-04
Epoch:61,Train_acc:94.7%,Train_loss:0.372,Test_acc:95.1%,Test_loss:0.362,Lr:1.00E-04
Epoch:62,Train_acc:95.1%,Train_loss:0.369,Test_acc:90.2%,Test_loss:0.408,Lr:1.00E-04
Epoch:63,Train_acc:94.2%,Train_loss:0.380,Test_acc:92.8%,Test_loss:0.385,Lr:1.00E-04
Epoch:64,Train_acc:94.7%,Train_loss:0.374,Test_acc:91.8%,Test_loss:0.395,Lr:1.00E-04
Epoch:65,Train_acc:96.3%,Train_loss:0.359,Test_acc:94.2%,Test_loss:0.372,Lr:1.00E-04
Epoch:66,Train_acc:95.1%,Train_loss:0.361,Test_acc:92.1%,Test_loss:0.392,Lr:1.00E-04
Epoch:67,Train_acc:95.8%,Train_loss:0.354,Test_acc:92.3%,Test_loss:0.390,Lr:1.00E-04
Epoch:68,Train_acc:95.2%,Train_loss:0.358,Test_acc:93.7%,Test_loss:0.373,Lr:1.00E-04
Epoch:69,Train_acc:95.3%,Train_loss:0.360,Test_acc:93.7%,Test_loss:0.377,Lr:1.00E-04
Epoch:70,Train_acc:95.2%,Train_loss:0.368,Test_acc:93.2%,Test_loss:0.377,Lr:1.00E-04
Epoch:71,Train_acc:94.9%,Train_loss:0.373,Test_acc:91.8%,Test_loss:0.396,Lr:1.00E-04
Epoch:72,Train_acc:94.0%,Train_loss:0.371,Test_acc:94.2%,Test_loss:0.368,Lr:1.00E-04
Epoch:73,Train_acc:95.5%,Train_loss:0.356,Test_acc:93.2%,Test_loss:0.377,Lr:1.00E-04
Epoch:74,Train_acc:96.0%,Train_loss:0.362,Test_acc:93.9%,Test_loss:0.372,Lr:1.00E-04
Epoch:75,Train_acc:95.3%,Train_loss:0.368,Test_acc:93.2%,Test_loss:0.379,Lr:1.00E-04
Epoch:76,Train_acc:94.8%,Train_loss:0.368,Test_acc:91.4%,Test_loss:0.397,Lr:1.00E-04
Epoch:77,Train_acc:91.1%,Train_loss:0.398,Test_acc:90.9%,Test_loss:0.404,Lr:1.00E-04
Epoch:78,Train_acc:92.6%,Train_loss:0.388,Test_acc:90.0%,Test_loss:0.407,Lr:1.00E-04
Epoch:79,Train_acc:94.9%,Train_loss:0.362,Test_acc:94.9%,Test_loss:0.362,Lr:1.00E-04
Epoch:80,Train_acc:93.2%,Train_loss:0.380,Test_acc:91.4%,Test_loss:0.398,Lr:1.00E-04
Epoch:81,Train_acc:94.3%,Train_loss:0.368,Test_acc:93.7%,Test_loss:0.373,Lr:1.00E-04
Epoch:82,Train_acc:94.8%,Train_loss:0.363,Test_acc:93.2%,Test_loss:0.378,Lr:1.00E-04
Epoch:83,Train_acc:95.6%,Train_loss:0.364,Test_acc:88.1%,Test_loss:0.425,Lr:1.00E-04
Epoch:84,Train_acc:92.6%,Train_loss:0.386,Test_acc:92.3%,Test_loss:0.389,Lr:1.00E-04
Epoch:85,Train_acc:94.3%,Train_loss:0.377,Test_acc:91.4%,Test_loss:0.397,Lr:1.00E-04
Epoch:86,Train_acc:96.1%,Train_loss:0.359,Test_acc:96.3%,Test_loss:0.350,Lr:1.00E-04
Epoch:87,Train_acc:95.7%,Train_loss:0.363,Test_acc:94.2%,Test_loss:0.369,Lr:1.00E-04
Epoch:88,Train_acc:96.0%,Train_loss:0.360,Test_acc:93.2%,Test_loss:0.381,Lr:1.00E-04
Epoch:89,Train_acc:95.9%,Train_loss:0.353,Test_acc:93.7%,Test_loss:0.374,Lr:1.00E-04
Epoch:90,Train_acc:95.4%,Train_loss:0.361,Test_acc:93.2%,Test_loss:0.375,Lr:1.00E-04
Epoch:91,Train_acc:96.5%,Train_loss:0.347,Test_acc:95.1%,Test_loss:0.358,Lr:1.00E-04
Epoch:92,Train_acc:97.1%,Train_loss:0.342,Test_acc:94.9%,Test_loss:0.362,Lr:1.00E-04
Epoch:93,Train_acc:95.7%,Train_loss:0.363,Test_acc:93.7%,Test_loss:0.371,Lr:1.00E-04
Epoch:94,Train_acc:94.9%,Train_loss:0.361,Test_acc:95.1%,Test_loss:0.360,Lr:1.00E-04
Epoch:95,Train_acc:96.1%,Train_loss:0.360,Test_acc:95.1%,Test_loss:0.362,Lr:1.00E-04
Epoch:96,Train_acc:96.9%,Train_loss:0.352,Test_acc:93.7%,Test_loss:0.376,Lr:1.00E-04
Epoch:97,Train_acc:95.0%,Train_loss:0.367,Test_acc:93.7%,Test_loss:0.376,Lr:1.00E-04
Epoch:98,Train_acc:95.0%,Train_loss:0.371,Test_acc:95.1%,Test_loss:0.360,Lr:1.00E-04
Epoch:99,Train_acc:96.9%,Train_loss:0.352,Test_acc:95.6%,Test_loss:0.359,Lr:1.00E-04
Epoch:100,Train_acc:94.8%,Train_loss:0.374,Test_acc:93.7%,Test_loss:0.372,Lr:1.00E-04
Done
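To reuse the saved weights later, load them into a freshly constructed model (a sketch; the path matches the save call above):
# Sketch: restore the saved weights for inference
best_model=InceptionV1(num_classes=2).to(device)
best_model.load_state_dict(torch.load(r'D:\THE MNIST DATABASE\J-series\J8_model.pth',map_location=device))
best_model.eval()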
V. Visualize the Results
1. Loss and accuracy curves
import matplotlib.pyplot as plt
#hide warnings
import warnings
warnings.filterwarnings("ignore") #ignore warning messages
plt.rcParams['font.sans-serif']=['SimHei'] #render CJK labels correctly
plt.rcParams['axes.unicode_minus']=False #render minus signs correctly
plt.rcParams['figure.dpi']=300 #resolution
epochs_range=range(epochs)
plt.figure(figsize=(12,3))
plt.subplot(1,2,1)
plt.plot(epochs_range,train_acc,label='Training Accuracy')
plt.plot(epochs_range,test_acc,label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1,2,2)
plt.plot(epochs_range,train_loss,label='Training Loss')
plt.plot(epochs_range,test_loss,label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
Output:
2. Predict a specified image
from PIL import Image
classes=list(total_data.class_to_idx)
def predict_one_image(image_path,model,transform,classes):
    test_img=Image.open(image_path).convert('RGB')
    plt.imshow(test_img) #show the image being predicted
    test_img=transform(test_img)
    img=test_img.to(device).unsqueeze(0)
    model.eval()
    output=model(img)
    _,pred=torch.max(output,1)
    pred_class=classes[pred]
    print(f'Predicted class: {pred_class}')
Run a prediction:
#predict one image from the training set
predict_one_image(image_path=r'D:\THE MNIST DATABASE\P4-data\Monkeypox\M01_02_00.jpg',
model=model,transform=train_transforms,classes=classes)
Output:
Predicted class: Monkeypox
3. Evaluate the model
J8_model.eval()
epoch_test_acc,epoch_test_loss=test(test_dl,J8_model,loss_fn)
epoch_test_acc,epoch_test_loss
Output:
(0.9627039627039627, 0.34947266622825907)
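Beyond a single accuracy number, a confusion matrix gives a per-class view (a sketch, assuming scikit-learn is installed; not part of the original run):
# Sketch: per-class evaluation of the best model on the test loader
from sklearn.metrics import confusion_matrix,classification_report
all_preds,all_targets=[],[]
J8_model.eval()
with torch.no_grad():
    for imgs,target in test_dl:
        all_preds.append(J8_model(imgs.to(device)).argmax(1).cpu())
        all_targets.append(target)
all_preds=torch.cat(all_preds).numpy()
all_targets=torch.cat(all_targets).numpy()
print(confusion_matrix(all_targets,all_preds))
print(classification_report(all_targets,all_preds,target_names=classNames))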
VI. Takeaways
Through this project I worked through building an Inception V1 model in a PyTorch environment, which deepened my understanding of the model. I also experimented with the learning rate during training: in practice, learning rates of 1e-7 and 1e-5 gave worse results than 1e-4.