第J6周:ResNeXt-50实战解析(pytorch版)

>- **🍨 本文为[🔗365天深度学习训练营]中的学习记录博客**
>- **🍖 原作者:[K同学啊]**

任务:
●阅读ResNeXt论文,了解作者的构建思路
●对比我们之前介绍的ResNet50V2、DenseNet算法
●使用ResNeXt-50算法完成猴痘病识别

🏡 我的环境:

  • 语言环境:Python3.8
  • 编译器:Jupyter Notebook
  • 深度学习环境:Pytorch
    • torch==2.3.1+cu118
    • torchvision==0.18.1+cu118

       本文完全根据 第J6周:ResNeXt-50实战解析(TensorFlow版)中的内容转换为pytorch版本,所以前述性的内容不在一一重复,仅就pytorch版本中的内容进行叙述。

一、 前期准备

1. 设置GPU

如果设备上支持GPU就使用GPU,否则使用CPU

import warnings
warnings.filterwarnings("ignore")

import torch
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

运行结果:

device(type='cuda')

2. 导入数据

同时查看数据集中图片的数量

import pathlib

data_dir=r'D:\THE MNIST DATABASE\P4-data'
data_dir=pathlib.Path(data_dir)

image_count=len(list(data_dir.glob('*/*')))
print("图片总数为:",image_count)

运行结果:

图片总数为: 2142

3. 查看数据集分类

data_paths=list(data_dir.glob('*'))
classNames=[str(path).split("\\")[3] for path in data_paths]
classNames

运行结果:

['Monkeypox', 'Others']

4. 随机查看图片

随机抽取数据集中的20张图片进行查看

import PIL,random
import matplotlib.pyplot as plt
from PIL import Image

plt.rcParams['font.sans-serif']=['SimHei']  #用来正常显示中文标签
plt.rcParams['axes.unicode_minus']=False  #用来正常显示负号

data_paths2=list(data_dir.glob('*/*'))
plt.figure(figsize=(10,4))
for i in range(10):
    plt.subplot(2,5,i+1)
    plt.axis("off")
    image=random.choice(data_paths2)  #随机选择一个图片
    plt.title(image.parts[-2])  #通过glob对象取出他的文件夹名称,即分类名
    plt.imshow(Image.open(str(image)))  #显示图片

运行结果:

5. 图片预处理 

import torchvision.transforms as transforms
from torchvision import transforms,datasets

train_transforms=transforms.Compose([
    transforms.Resize([224,224]), #将图片统一尺寸
    transforms.RandomHorizontalFlip(), #将图片随机水平翻转
    transforms.RandomRotation(0.2), #将图片按照0.2的弧度值随机翻转
    transforms.ToTensor(), #将图片转换为tensor
    transforms.Normalize(  #标准化处理-->转换为正态分布,使模型更容易收敛
        mean=[0.485,0.456,0.406],
        std=[0.229,0.224,0.225]
    )
])

total_data=datasets.ImageFolder(
    r'D:\THE MNIST DATABASE\P4-data',
    transform=train_transforms
)
total_data

运行结果:

Dataset ImageFolder
    Number of datapoints: 2142
    Root location: D:\THE MNIST DATABASE\P4-data
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=True)
               RandomHorizontalFlip(p=0.5)
               RandomRotation(degrees=[-0.2, 0.2], interpolation=nearest, expand=False, fill=0)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

将数据集分类情况进行映射输出:

total_data.class_to_idx

运行结果:

{'Monkeypox': 0, 'Others': 1}

6. 划分数据集

train_size=int(0.8*len(total_data))
test_size=len(total_data)-train_size

train_dataset,test_dataset=torch.utils.data.random_split(
    total_data,
    [train_size,test_size]
)
train_dataset,test_dataset

运行结果:

(<torch.utils.data.dataset.Subset at 0x207565a54d0>,
 <torch.utils.data.dataset.Subset at 0x2075514cf90>)

查看训练集和测试集的数据数量:

train_size,test_size

运行结果:

(1713, 429)

7. 加载数据集

batch_size=16
train_dl=torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1
)
test_dl=torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1
)

查看测试集的情况:

for x,y in train_dl:
    print("Shape of x [N,C,H,W]:",x.shape)
    print("Shape of y:",y.shape,y.dtype)
    break

运行结果:

Shape of x [N,C,H,W]: torch.Size([16, 3, 224, 224])
Shape of y: torch.Size([16]) torch.int64

二、搭建模型

1. 创建卷积块

import torch.nn as nn
import torch.nn.functional as F
class BN_Conv2d(nn.Module):
    """
    BN_CONV_RELU
    """
    def __init__(self,in_channels,out_channels,kernel_size,stride,
                 padding,dilation=1,groups=1,bias=False):
        super(BN_Conv2d,self).__init__()
        self.seq=nn.Sequential(
            nn.Conv2d(in_channels,out_channels,kernel_size=kernel_size,stride=stride,
                      padding=padding,dilation=dilation,groups=groups,bias=bias),
            nn.BatchNorm2d(out_channels)
        )
    
    def forward(self,x):
        return F.relu(self.seq(x))

2. 创建block

class ResNeXt_Block(nn.Module):
    """
    ResNeXt block with group convolutions
    """
    def __init__(self,in_channnls,cardinality,group_depth,stride):
        super(ResNeXt_Block,self).__init__()
        self.group_channels=cardinality*group_depth
        self.conv1=BN_Conv2d(in_channnls,self.group_channels,1,stride=1,padding=0)
        self.conv2=BN_Conv2d(self.group_channels,self.group_channels,3,
                             stride=stride,padding=1,groups=cardinality)
        self.conv3=nn.Conv2d(self.group_channels,self.group_channels*2,1,stride=1,padding=0)
        self.bn=nn.BatchNorm2d(self.group_channels*2)
        self.short_cut=nn.Sequential(
            nn.Conv2d(in_channnls,self.group_channels*2,1,stride,0,bias=False),
            nn.BatchNorm2d(self.group_channels*2)
        )
        
    def forward(self,x):
        out=self.conv1(x)
        out=self.conv2(out)
        out=self.bn(self.conv3(out))
        out+=self.short_cut(x)
        return F.relu(out)

3. 搭建ResNeXt 模型

class ResNeXt(nn.Module):
    """
    ResNeXt builder
    """
    
    def __init__(self,layers:object,cardinality,group_depth,num_classes):
        super(ResNeXt,self).__init__()
        self.cardinality=cardinality
        self.channels=64
        self.conv1=BN_Conv2d(3,self.channels,7,stride=2,padding=3)
        d1=group_depth
        self.conv2=self.__make_layers(d1,layers[0],stride=1)
        d2=d1*2
        self.conv3=self.__make_layers(d2,layers[1],stride=2)
        d3=d2*2
        self.conv4=self.__make_layers(d3,layers[2],stride=2)
        d4=d3*2
        self.conv5=self.__make_layers(d4,layers[3],stride=2)
        self.fc=nn.Linear(self.channels,num_classes)  #224*224  input size
        
    def __make_layers(self,d,blocks,stride):
        strides=[stride]+[1]*(blocks-1)
        layers=[]
        for stride in strides:
            layers.append(ResNeXt_Block(self.channels,self.cardinality,d,stride))
            self.channels=self.cardinality*d*2
        return nn.Sequential(*layers)
    
    def forward(self,x):
        out=self.conv1(x)
        out=F.max_pool2d(out,3,2,1)
        out=self.conv2(out)
        out=self.conv3(out)
        out=self.conv4(out)
        out=self.conv5(out)
        out=F.avg_pool2d(out,7)
        out=out.view(out.size(0),-1)
        out=F.softmax(self.fc(out),dim=1)
        return out

4. 查看 ResNeXt-50 模型的参数

model=ResNeXt([3,4,6,3],32,4,4)
model.to(device)

#统计模型参数量以及其他指标
import torchsummary as summary
summary.summary(model,(3,224,224))

运行结果:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
         BN_Conv2d-3         [-1, 64, 112, 112]               0
            Conv2d-4          [-1, 128, 56, 56]           8,192
       BatchNorm2d-5          [-1, 128, 56, 56]             256
         BN_Conv2d-6          [-1, 128, 56, 56]               0
            Conv2d-7          [-1, 128, 56, 56]           4,608
       BatchNorm2d-8          [-1, 128, 56, 56]             256
         BN_Conv2d-9          [-1, 128, 56, 56]               0
           Conv2d-10          [-1, 256, 56, 56]          33,024
      BatchNorm2d-11          [-1, 256, 56, 56]             512
           Conv2d-12          [-1, 256, 56, 56]          16,384
      BatchNorm2d-13          [-1, 256, 56, 56]             512
    ResNeXt_Block-14          [-1, 256, 56, 56]               0
           Conv2d-15          [-1, 128, 56, 56]          32,768
      BatchNorm2d-16          [-1, 128, 56, 56]             256
        BN_Conv2d-17          [-1, 128, 56, 56]               0
           Conv2d-18          [-1, 128, 56, 56]           4,608
      BatchNorm2d-19          [-1, 128, 56, 56]             256
        BN_Conv2d-20          [-1, 128, 56, 56]               0
           Conv2d-21          [-1, 256, 56, 56]          33,024
      BatchNorm2d-22          [-1, 256, 56, 56]             512
           Conv2d-23          [-1, 256, 56, 56]          65,536
      BatchNorm2d-24          [-1, 256, 56, 56]             512
    ResNeXt_Block-25          [-1, 256, 56, 56]               0
           Conv2d-26          [-1, 128, 56, 56]          32,768
      BatchNorm2d-27          [-1, 128, 56, 56]             256
        BN_Conv2d-28          [-1, 128, 56, 56]               0
           Conv2d-29          [-1, 128, 56, 56]           4,608
      BatchNorm2d-30          [-1, 128, 56, 56]             256
        BN_Conv2d-31          [-1, 128, 56, 56]               0
           Conv2d-32          [-1, 256, 56, 56]          33,024
      BatchNorm2d-33          [-1, 256, 56, 56]             512
           Conv2d-34          [-1, 256, 56, 56]          65,536
      BatchNorm2d-35          [-1, 256, 56, 56]             512
    ResNeXt_Block-36          [-1, 256, 56, 56]               0
           Conv2d-37          [-1, 256, 56, 56]          65,536
      BatchNorm2d-38          [-1, 256, 56, 56]             512
        BN_Conv2d-39          [-1, 256, 56, 56]               0
           Conv2d-40          [-1, 256, 28, 28]          18,432
      BatchNorm2d-41          [-1, 256, 28, 28]             512
        BN_Conv2d-42          [-1, 256, 28, 28]               0
           Conv2d-43          [-1, 512, 28, 28]         131,584
      BatchNorm2d-44          [-1, 512, 28, 28]           1,024
           Conv2d-45          [-1, 512, 28, 28]         131,072
      BatchNorm2d-46          [-1, 512, 28, 28]           1,024
    ResNeXt_Block-47          [-1, 512, 28, 28]               0
           Conv2d-48          [-1, 256, 28, 28]         131,072
      BatchNorm2d-49          [-1, 256, 28, 28]             512
        BN_Conv2d-50          [-1, 256, 28, 28]               0
           Conv2d-51          [-1, 256, 28, 28]          18,432
      BatchNorm2d-52          [-1, 256, 28, 28]             512
        BN_Conv2d-53          [-1, 256, 28, 28]               0
           Conv2d-54          [-1, 512, 28, 28]         131,584
      BatchNorm2d-55          [-1, 512, 28, 28]           1,024
           Conv2d-56          [-1, 512, 28, 28]         262,144
      BatchNorm2d-57          [-1, 512, 28, 28]           1,024
    ResNeXt_Block-58          [-1, 512, 28, 28]               0
           Conv2d-59          [-1, 256, 28, 28]         131,072
      BatchNorm2d-60          [-1, 256, 28, 28]             512
        BN_Conv2d-61          [-1, 256, 28, 28]               0
           Conv2d-62          [-1, 256, 28, 28]          18,432
      BatchNorm2d-63          [-1, 256, 28, 28]             512
        BN_Conv2d-64          [-1, 256, 28, 28]               0
           Conv2d-65          [-1, 512, 28, 28]         131,584
      BatchNorm2d-66          [-1, 512, 28, 28]           1,024
           Conv2d-67          [-1, 512, 28, 28]         262,144
      BatchNorm2d-68          [-1, 512, 28, 28]           1,024
    ResNeXt_Block-69          [-1, 512, 28, 28]               0
           Conv2d-70          [-1, 256, 28, 28]         131,072
      BatchNorm2d-71          [-1, 256, 28, 28]             512
        BN_Conv2d-72          [-1, 256, 28, 28]               0
           Conv2d-73          [-1, 256, 28, 28]          18,432
      BatchNorm2d-74          [-1, 256, 28, 28]             512
        BN_Conv2d-75          [-1, 256, 28, 28]               0
           Conv2d-76          [-1, 512, 28, 28]         131,584
      BatchNorm2d-77          [-1, 512, 28, 28]           1,024
           Conv2d-78          [-1, 512, 28, 28]         262,144
      BatchNorm2d-79          [-1, 512, 28, 28]           1,024
    ResNeXt_Block-80          [-1, 512, 28, 28]               0
           Conv2d-81          [-1, 512, 28, 28]         262,144
      BatchNorm2d-82          [-1, 512, 28, 28]           1,024
        BN_Conv2d-83          [-1, 512, 28, 28]               0
           Conv2d-84          [-1, 512, 14, 14]          73,728
      BatchNorm2d-85          [-1, 512, 14, 14]           1,024
        BN_Conv2d-86          [-1, 512, 14, 14]               0
           Conv2d-87         [-1, 1024, 14, 14]         525,312
      BatchNorm2d-88         [-1, 1024, 14, 14]           2,048
           Conv2d-89         [-1, 1024, 14, 14]         524,288
      BatchNorm2d-90         [-1, 1024, 14, 14]           2,048
    ResNeXt_Block-91         [-1, 1024, 14, 14]               0
           Conv2d-92          [-1, 512, 14, 14]         524,288
      BatchNorm2d-93          [-1, 512, 14, 14]           1,024
        BN_Conv2d-94          [-1, 512, 14, 14]               0
           Conv2d-95          [-1, 512, 14, 14]          73,728
      BatchNorm2d-96          [-1, 512, 14, 14]           1,024
        BN_Conv2d-97          [-1, 512, 14, 14]               0
           Conv2d-98         [-1, 1024, 14, 14]         525,312
      BatchNorm2d-99         [-1, 1024, 14, 14]           2,048
          Conv2d-100         [-1, 1024, 14, 14]       1,048,576
     BatchNorm2d-101         [-1, 1024, 14, 14]           2,048
   ResNeXt_Block-102         [-1, 1024, 14, 14]               0
          Conv2d-103          [-1, 512, 14, 14]         524,288
     BatchNorm2d-104          [-1, 512, 14, 14]           1,024
       BN_Conv2d-105          [-1, 512, 14, 14]               0
          Conv2d-106          [-1, 512, 14, 14]          73,728
     BatchNorm2d-107          [-1, 512, 14, 14]           1,024
       BN_Conv2d-108          [-1, 512, 14, 14]               0
          Conv2d-109         [-1, 1024, 14, 14]         525,312
     BatchNorm2d-110         [-1, 1024, 14, 14]           2,048
          Conv2d-111         [-1, 1024, 14, 14]       1,048,576
     BatchNorm2d-112         [-1, 1024, 14, 14]           2,048
   ResNeXt_Block-113         [-1, 1024, 14, 14]               0
          Conv2d-114          [-1, 512, 14, 14]         524,288
     BatchNorm2d-115          [-1, 512, 14, 14]           1,024
       BN_Conv2d-116          [-1, 512, 14, 14]               0
          Conv2d-117          [-1, 512, 14, 14]          73,728
     BatchNorm2d-118          [-1, 512, 14, 14]           1,024
       BN_Conv2d-119          [-1, 512, 14, 14]               0
          Conv2d-120         [-1, 1024, 14, 14]         525,312
     BatchNorm2d-121         [-1, 1024, 14, 14]           2,048
          Conv2d-122         [-1, 1024, 14, 14]       1,048,576
     BatchNorm2d-123         [-1, 1024, 14, 14]           2,048
   ResNeXt_Block-124         [-1, 1024, 14, 14]               0
          Conv2d-125          [-1, 512, 14, 14]         524,288
     BatchNorm2d-126          [-1, 512, 14, 14]           1,024
       BN_Conv2d-127          [-1, 512, 14, 14]               0
          Conv2d-128          [-1, 512, 14, 14]          73,728
     BatchNorm2d-129          [-1, 512, 14, 14]           1,024
       BN_Conv2d-130          [-1, 512, 14, 14]               0
          Conv2d-131         [-1, 1024, 14, 14]         525,312
     BatchNorm2d-132         [-1, 1024, 14, 14]           2,048
          Conv2d-133         [-1, 1024, 14, 14]       1,048,576
     BatchNorm2d-134         [-1, 1024, 14, 14]           2,048
   ResNeXt_Block-135         [-1, 1024, 14, 14]               0
          Conv2d-136          [-1, 512, 14, 14]         524,288
     BatchNorm2d-137          [-1, 512, 14, 14]           1,024
       BN_Conv2d-138          [-1, 512, 14, 14]               0
          Conv2d-139          [-1, 512, 14, 14]          73,728
     BatchNorm2d-140          [-1, 512, 14, 14]           1,024
       BN_Conv2d-141          [-1, 512, 14, 14]               0
          Conv2d-142         [-1, 1024, 14, 14]         525,312
     BatchNorm2d-143         [-1, 1024, 14, 14]           2,048
          Conv2d-144         [-1, 1024, 14, 14]       1,048,576
     BatchNorm2d-145         [-1, 1024, 14, 14]           2,048
   ResNeXt_Block-146         [-1, 1024, 14, 14]               0
          Conv2d-147         [-1, 1024, 14, 14]       1,048,576
     BatchNorm2d-148         [-1, 1024, 14, 14]           2,048
       BN_Conv2d-149         [-1, 1024, 14, 14]               0
          Conv2d-150           [-1, 1024, 7, 7]         294,912
     BatchNorm2d-151           [-1, 1024, 7, 7]           2,048
       BN_Conv2d-152           [-1, 1024, 7, 7]               0
          Conv2d-153           [-1, 2048, 7, 7]       2,099,200
     BatchNorm2d-154           [-1, 2048, 7, 7]           4,096
          Conv2d-155           [-1, 2048, 7, 7]       2,097,152
     BatchNorm2d-156           [-1, 2048, 7, 7]           4,096
   ResNeXt_Block-157           [-1, 2048, 7, 7]               0
          Conv2d-158           [-1, 1024, 7, 7]       2,097,152
     BatchNorm2d-159           [-1, 1024, 7, 7]           2,048
       BN_Conv2d-160           [-1, 1024, 7, 7]               0
          Conv2d-161           [-1, 1024, 7, 7]         294,912
     BatchNorm2d-162           [-1, 1024, 7, 7]           2,048
       BN_Conv2d-163           [-1, 1024, 7, 7]               0
          Conv2d-164           [-1, 2048, 7, 7]       2,099,200
     BatchNorm2d-165           [-1, 2048, 7, 7]           4,096
          Conv2d-166           [-1, 2048, 7, 7]       4,194,304
     BatchNorm2d-167           [-1, 2048, 7, 7]           4,096
   ResNeXt_Block-168           [-1, 2048, 7, 7]               0
          Conv2d-169           [-1, 1024, 7, 7]       2,097,152
     BatchNorm2d-170           [-1, 1024, 7, 7]           2,048
       BN_Conv2d-171           [-1, 1024, 7, 7]               0
          Conv2d-172           [-1, 1024, 7, 7]         294,912
     BatchNorm2d-173           [-1, 1024, 7, 7]           2,048
       BN_Conv2d-174           [-1, 1024, 7, 7]               0
          Conv2d-175           [-1, 2048, 7, 7]       2,099,200
     BatchNorm2d-176           [-1, 2048, 7, 7]           4,096
          Conv2d-177           [-1, 2048, 7, 7]       4,194,304
     BatchNorm2d-178           [-1, 2048, 7, 7]           4,096
   ResNeXt_Block-179           [-1, 2048, 7, 7]               0
          Linear-180                    [-1, 4]           8,196
================================================================
Total params: 37,574,724
Trainable params: 37,574,724
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 379.37
Params size (MB): 143.34
Estimated Total Size (MB): 523.28
----------------------------------------------------------------

三、 训练模型

1. 编写训练函数

#训练循环
def train(dataloader,model,loss_fn,optimizer):
    size=len(dataloader.dataset) #训练集的大小
    num_batches=len(dataloader)  #批次数目,(size/batch_size,向上取整)
    
    train_loss,train_acc=0,0  #初始化训练损失和正确率
    
    for x,y in dataloader:  #获取图片及其标签
        x,y=x.to(device),y.to(device)
        
        #计算预测误差
        pred=model(x)  #网络输出
        loss=loss_fn(pred,y) #计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失
        
        #反向传播
        optimizer.zero_grad() #grad属性归零
        loss.backward()  #反向传播
        optimizer.step()  #每一步自动更新
        
        #记录acc与loss
        train_acc+=(pred.argmax(1)==y).type(torch.float).sum().item()
        train_loss+=loss.item()
        
    train_acc/=size
    train_loss/=num_batches
    
    return train_acc,train_loss

2. 编写测试函数

测试函数和训练函数大致相同,但是由于不进行梯度下降对网络权重进行更新,所以不需要传入优化器

def test(dataloader,model,loss_fn):
    size=len(dataloader.dataset) #测试集的大小
    num_batches=len(dataloader)  #批次数目
    test_loss,test_acc=0,0
    
    #当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs,target in dataloader:
            imgs,target=imgs.to(device),target.to(device)
            
            #计算loss
            target_pred=model(imgs)
            loss=loss_fn(target_pred,target)
            
            test_loss+=loss.item()
            test_acc+=(target_pred.argmax(1)==target).type(torch.float).sum().item()
            
    test_acc/=size
    test_loss/=num_batches
    
    return test_acc,test_loss

3. 正式训练

import copy
optimizer=torch.optim.Adam(model.parameters(),lr=1e-4)  #创建优化器,并设置学习率
loss_fn=nn.CrossEntropyLoss()  #创建损失函数 

epochs=100

train_loss=[]
train_acc=[]
test_loss=[]
test_acc=[]

best_acc=0  #设置一个最佳准确率,作为最佳模型的判别指标

for epoch in range(epochs):
    
    model.train()
    epoch_train_acc,epoch_train_loss=train(train_dl,model,loss_fn,optimizer)
    
    model.eval()
    epoch_test_acc,epoch_test_loss=test(test_dl,model,loss_fn)
    
    #保存最佳模型到J6_model
    if epoch_test_acc>best_acc:
        best_acc=epoch_test_acc
        J6_model=copy.deepcopy(model)
        
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    #获取当前学习率
    lr=optimizer.state_dict()['param_groups'][0]['lr']
    
    template=('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},Lr:{:.2E}')
    print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,
                          epoch_test_acc*100,epoch_test_loss,lr))
    
#保存最佳模型到文件中
PATH=r'D:\THE MNIST DATABASE\J-series\J6_model.pth'
torch.save(model.state_dict(),PATH)

print('Done')

运行结果:

Epoch: 1,Train_acc:57.0%,Train_loss:1.159,Test_acc:59.0%,Test_loss:1.152,Lr:1.00E-04
Epoch: 2,Train_acc:59.7%,Train_loss:1.133,Test_acc:64.6%,Test_loss:1.089,Lr:1.00E-04
Epoch: 3,Train_acc:64.0%,Train_loss:1.097,Test_acc:62.0%,Test_loss:1.117,Lr:1.00E-04
Epoch: 4,Train_acc:63.9%,Train_loss:1.095,Test_acc:63.9%,Test_loss:1.096,Lr:1.00E-04
Epoch: 5,Train_acc:64.2%,Train_loss:1.100,Test_acc:68.1%,Test_loss:1.067,Lr:1.00E-04
Epoch: 6,Train_acc:64.6%,Train_loss:1.094,Test_acc:61.5%,Test_loss:1.132,Lr:1.00E-04
Epoch: 7,Train_acc:65.5%,Train_loss:1.077,Test_acc:70.4%,Test_loss:1.032,Lr:1.00E-04
Epoch: 8,Train_acc:65.2%,Train_loss:1.088,Test_acc:66.4%,Test_loss:1.072,Lr:1.00E-04
Epoch: 9,Train_acc:67.2%,Train_loss:1.064,Test_acc:74.4%,Test_loss:1.008,Lr:1.00E-04
Epoch:10,Train_acc:66.1%,Train_loss:1.080,Test_acc:68.5%,Test_loss:1.052,Lr:1.00E-04
Epoch:11,Train_acc:65.6%,Train_loss:1.078,Test_acc:69.9%,Test_loss:1.040,Lr:1.00E-04
Epoch:12,Train_acc:66.9%,Train_loss:1.062,Test_acc:76.2%,Test_loss:0.982,Lr:1.00E-04
Epoch:13,Train_acc:65.9%,Train_loss:1.077,Test_acc:74.1%,Test_loss:1.002,Lr:1.00E-04
Epoch:14,Train_acc:65.5%,Train_loss:1.084,Test_acc:59.4%,Test_loss:1.144,Lr:1.00E-04
Epoch:15,Train_acc:62.5%,Train_loss:1.113,Test_acc:56.9%,Test_loss:1.171,Lr:1.00E-04
Epoch:16,Train_acc:66.5%,Train_loss:1.069,Test_acc:67.4%,Test_loss:1.065,Lr:1.00E-04
Epoch:17,Train_acc:68.0%,Train_loss:1.054,Test_acc:73.9%,Test_loss:1.005,Lr:1.00E-04
Epoch:18,Train_acc:67.5%,Train_loss:1.052,Test_acc:73.9%,Test_loss:0.989,Lr:1.00E-04
Epoch:19,Train_acc:68.6%,Train_loss:1.048,Test_acc:67.8%,Test_loss:1.049,Lr:1.00E-04
Epoch:20,Train_acc:70.0%,Train_loss:1.035,Test_acc:70.2%,Test_loss:1.033,Lr:1.00E-04
Epoch:21,Train_acc:70.6%,Train_loss:1.040,Test_acc:62.9%,Test_loss:1.107,Lr:1.00E-04
Epoch:22,Train_acc:71.0%,Train_loss:1.023,Test_acc:71.3%,Test_loss:1.036,Lr:1.00E-04
Epoch:23,Train_acc:72.5%,Train_loss:1.014,Test_acc:76.0%,Test_loss:0.981,Lr:1.00E-04
Epoch:24,Train_acc:70.9%,Train_loss:1.035,Test_acc:75.3%,Test_loss:0.993,Lr:1.00E-04
Epoch:25,Train_acc:72.5%,Train_loss:1.012,Test_acc:76.7%,Test_loss:0.974,Lr:1.00E-04
Epoch:26,Train_acc:70.8%,Train_loss:1.028,Test_acc:72.7%,Test_loss:1.004,Lr:1.00E-04
Epoch:27,Train_acc:72.7%,Train_loss:1.009,Test_acc:73.2%,Test_loss:1.011,Lr:1.00E-04
Epoch:28,Train_acc:73.8%,Train_loss:1.006,Test_acc:75.3%,Test_loss:0.991,Lr:1.00E-04
Epoch:29,Train_acc:74.5%,Train_loss:0.992,Test_acc:74.6%,Test_loss:0.986,Lr:1.00E-04
Epoch:30,Train_acc:73.3%,Train_loss:1.005,Test_acc:73.2%,Test_loss:1.004,Lr:1.00E-04
Epoch:31,Train_acc:75.7%,Train_loss:0.993,Test_acc:77.4%,Test_loss:0.968,Lr:1.00E-04
Epoch:32,Train_acc:74.6%,Train_loss:0.989,Test_acc:72.3%,Test_loss:1.016,Lr:1.00E-04
Epoch:33,Train_acc:76.6%,Train_loss:0.973,Test_acc:70.2%,Test_loss:1.042,Lr:1.00E-04
Epoch:34,Train_acc:75.2%,Train_loss:0.982,Test_acc:74.6%,Test_loss:0.992,Lr:1.00E-04
Epoch:35,Train_acc:71.5%,Train_loss:1.018,Test_acc:77.6%,Test_loss:0.977,Lr:1.00E-04
Epoch:36,Train_acc:74.4%,Train_loss:1.006,Test_acc:76.7%,Test_loss:0.973,Lr:1.00E-04
Epoch:37,Train_acc:72.0%,Train_loss:1.012,Test_acc:76.9%,Test_loss:0.978,Lr:1.00E-04
Epoch:38,Train_acc:71.5%,Train_loss:1.030,Test_acc:72.7%,Test_loss:1.017,Lr:1.00E-04
Epoch:39,Train_acc:75.1%,Train_loss:0.987,Test_acc:76.5%,Test_loss:0.979,Lr:1.00E-04
Epoch:40,Train_acc:75.4%,Train_loss:0.989,Test_acc:75.8%,Test_loss:0.979,Lr:1.00E-04
Epoch:41,Train_acc:78.1%,Train_loss:0.968,Test_acc:77.9%,Test_loss:0.963,Lr:1.00E-04
Epoch:42,Train_acc:77.2%,Train_loss:0.977,Test_acc:74.4%,Test_loss:0.987,Lr:1.00E-04
Epoch:43,Train_acc:77.9%,Train_loss:0.968,Test_acc:73.7%,Test_loss:0.994,Lr:1.00E-04
Epoch:44,Train_acc:79.1%,Train_loss:0.954,Test_acc:78.8%,Test_loss:0.953,Lr:1.00E-04
Epoch:45,Train_acc:79.6%,Train_loss:0.950,Test_acc:79.3%,Test_loss:0.949,Lr:1.00E-04
Epoch:46,Train_acc:80.2%,Train_loss:0.938,Test_acc:79.0%,Test_loss:0.948,Lr:1.00E-04
Epoch:47,Train_acc:80.6%,Train_loss:0.943,Test_acc:78.3%,Test_loss:0.962,Lr:1.00E-04
Epoch:48,Train_acc:75.9%,Train_loss:0.982,Test_acc:73.0%,Test_loss:1.013,Lr:1.00E-04
Epoch:49,Train_acc:77.3%,Train_loss:0.966,Test_acc:76.2%,Test_loss:0.977,Lr:1.00E-04
Epoch:50,Train_acc:79.9%,Train_loss:0.947,Test_acc:74.4%,Test_loss:0.991,Lr:1.00E-04
Epoch:51,Train_acc:80.4%,Train_loss:0.944,Test_acc:75.1%,Test_loss:0.986,Lr:1.00E-04
Epoch:52,Train_acc:79.2%,Train_loss:0.953,Test_acc:77.2%,Test_loss:0.970,Lr:1.00E-04
Epoch:53,Train_acc:80.0%,Train_loss:0.939,Test_acc:78.8%,Test_loss:0.951,Lr:1.00E-04
Epoch:54,Train_acc:79.0%,Train_loss:0.954,Test_acc:80.2%,Test_loss:0.944,Lr:1.00E-04
Epoch:55,Train_acc:82.7%,Train_loss:0.923,Test_acc:79.0%,Test_loss:0.945,Lr:1.00E-04
Epoch:56,Train_acc:81.9%,Train_loss:0.926,Test_acc:80.0%,Test_loss:0.939,Lr:1.00E-04
Epoch:57,Train_acc:82.8%,Train_loss:0.915,Test_acc:76.2%,Test_loss:0.973,Lr:1.00E-04
Epoch:58,Train_acc:81.7%,Train_loss:0.926,Test_acc:82.8%,Test_loss:0.918,Lr:1.00E-04
Epoch:59,Train_acc:83.2%,Train_loss:0.918,Test_acc:81.4%,Test_loss:0.931,Lr:1.00E-04
Epoch:60,Train_acc:82.5%,Train_loss:0.916,Test_acc:81.4%,Test_loss:0.926,Lr:1.00E-04
Epoch:61,Train_acc:79.6%,Train_loss:0.950,Test_acc:78.8%,Test_loss:0.946,Lr:1.00E-04
Epoch:62,Train_acc:83.4%,Train_loss:0.914,Test_acc:80.2%,Test_loss:0.940,Lr:1.00E-04
Epoch:63,Train_acc:86.0%,Train_loss:0.893,Test_acc:80.2%,Test_loss:0.940,Lr:1.00E-04
Epoch:64,Train_acc:84.1%,Train_loss:0.899,Test_acc:80.9%,Test_loss:0.921,Lr:1.00E-04
Epoch:65,Train_acc:84.2%,Train_loss:0.905,Test_acc:82.1%,Test_loss:0.917,Lr:1.00E-04
Epoch:66,Train_acc:85.5%,Train_loss:0.894,Test_acc:80.9%,Test_loss:0.934,Lr:1.00E-04
Epoch:67,Train_acc:83.7%,Train_loss:0.913,Test_acc:80.0%,Test_loss:0.942,Lr:1.00E-04
Epoch:68,Train_acc:83.4%,Train_loss:0.907,Test_acc:81.8%,Test_loss:0.913,Lr:1.00E-04
Epoch:69,Train_acc:85.2%,Train_loss:0.892,Test_acc:81.8%,Test_loss:0.926,Lr:1.00E-04
Epoch:70,Train_acc:86.1%,Train_loss:0.884,Test_acc:82.1%,Test_loss:0.928,Lr:1.00E-04
Epoch:71,Train_acc:82.5%,Train_loss:0.918,Test_acc:81.4%,Test_loss:0.929,Lr:1.00E-04
Epoch:72,Train_acc:85.9%,Train_loss:0.892,Test_acc:81.6%,Test_loss:0.920,Lr:1.00E-04
Epoch:73,Train_acc:85.2%,Train_loss:0.893,Test_acc:79.3%,Test_loss:0.944,Lr:1.00E-04
Epoch:74,Train_acc:87.2%,Train_loss:0.875,Test_acc:85.8%,Test_loss:0.884,Lr:1.00E-04
Epoch:75,Train_acc:86.7%,Train_loss:0.876,Test_acc:84.8%,Test_loss:0.893,Lr:1.00E-04
Epoch:76,Train_acc:86.5%,Train_loss:0.875,Test_acc:83.4%,Test_loss:0.903,Lr:1.00E-04
Epoch:77,Train_acc:87.0%,Train_loss:0.878,Test_acc:85.8%,Test_loss:0.884,Lr:1.00E-04
Epoch:78,Train_acc:88.3%,Train_loss:0.861,Test_acc:86.0%,Test_loss:0.888,Lr:1.00E-04
Epoch:79,Train_acc:87.2%,Train_loss:0.869,Test_acc:86.0%,Test_loss:0.883,Lr:1.00E-04
Epoch:80,Train_acc:87.1%,Train_loss:0.877,Test_acc:85.8%,Test_loss:0.886,Lr:1.00E-04
Epoch:81,Train_acc:88.4%,Train_loss:0.859,Test_acc:82.8%,Test_loss:0.913,Lr:1.00E-04
Epoch:82,Train_acc:88.9%,Train_loss:0.851,Test_acc:85.8%,Test_loss:0.878,Lr:1.00E-04
Epoch:83,Train_acc:88.4%,Train_loss:0.859,Test_acc:84.8%,Test_loss:0.893,Lr:1.00E-04
Epoch:84,Train_acc:89.0%,Train_loss:0.860,Test_acc:84.1%,Test_loss:0.900,Lr:1.00E-04
Epoch:85,Train_acc:89.9%,Train_loss:0.850,Test_acc:84.1%,Test_loss:0.899,Lr:1.00E-04
Epoch:86,Train_acc:89.5%,Train_loss:0.850,Test_acc:83.0%,Test_loss:0.913,Lr:1.00E-04
Epoch:87,Train_acc:88.7%,Train_loss:0.854,Test_acc:86.0%,Test_loss:0.885,Lr:1.00E-04
Epoch:88,Train_acc:91.2%,Train_loss:0.837,Test_acc:80.9%,Test_loss:0.928,Lr:1.00E-04
Epoch:89,Train_acc:91.7%,Train_loss:0.831,Test_acc:86.0%,Test_loss:0.883,Lr:1.00E-04
Epoch:90,Train_acc:87.4%,Train_loss:0.863,Test_acc:84.1%,Test_loss:0.900,Lr:1.00E-04
Epoch:91,Train_acc:90.1%,Train_loss:0.851,Test_acc:86.2%,Test_loss:0.878,Lr:1.00E-04
Epoch:92,Train_acc:88.3%,Train_loss:0.855,Test_acc:86.7%,Test_loss:0.871,Lr:1.00E-04
Epoch:93,Train_acc:90.5%,Train_loss:0.844,Test_acc:85.8%,Test_loss:0.884,Lr:1.00E-04
Epoch:94,Train_acc:92.4%,Train_loss:0.821,Test_acc:85.3%,Test_loss:0.881,Lr:1.00E-04
Epoch:95,Train_acc:91.4%,Train_loss:0.835,Test_acc:86.2%,Test_loss:0.878,Lr:1.00E-04
Epoch:96,Train_acc:92.2%,Train_loss:0.829,Test_acc:82.3%,Test_loss:0.917,Lr:1.00E-04
Epoch:97,Train_acc:90.0%,Train_loss:0.848,Test_acc:83.2%,Test_loss:0.913,Lr:1.00E-04
Epoch:98,Train_acc:90.8%,Train_loss:0.836,Test_acc:87.9%,Test_loss:0.868,Lr:1.00E-04
Epoch:99,Train_acc:89.6%,Train_loss:0.848,Test_acc:83.7%,Test_loss:0.908,Lr:1.00E-04
Epoch:100,Train_acc:91.0%,Train_loss:0.832,Test_acc:86.2%,Test_loss:0.881,Lr:1.00E-04
Done

四、 结果可视化

1. Loss与Accuracy图

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")   #忽略警告信息
plt.rcParams['font.sans-serif']=['SimHei']   #正常显示中文标签
plt.rcParams['axes.unicode_minus']=False   #正常显示负号
plt.rcParams['figure.dpi']=300   #分辨率
 
epochs_range=range(epochs)
plt.figure(figsize=(12,3))
 
plt.subplot(1,2,1)
plt.plot(epochs_range,train_acc,label='Training Accuracy')
plt.plot(epochs_range,test_acc,label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
 
plt.subplot(1,2,2)
plt.plot(epochs_range,train_loss,label='Training Loss')
plt.plot(epochs_range,test_loss,label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

运行结果:

2. 指定图片进行预测 

from PIL import Image
 
classes=list(total_data.class_to_idx)
 
def predict_one_image(image_path,model,transform,classes):
    
    test_img=Image.open(image_path).convert('RGB')
    plt.imshow(test_img)   #展示预测的图片
    
    test_img=transform(test_img)
    img=test_img.to(device).unsqueeze(0)
    
    model.eval()
    output=model(img)
    
    _,pred=torch.max(output,1)
    pred_class=classes[pred]
    print(f'预测结果是:{pred_class}')

预测图片:

#预测训练集中的某张照片
predict_one_image(image_path=r'D:\THE MNIST DATABASE\P4-data\Others\NM01_01_00.jpg',
                  model=model,transform=train_transforms,classes=classes)

运行结果:

预测结果是:Others

3. 模型评估

J6_model.eval()
epoch_test_acc,epoch_test_loss=test(test_dl,J6_model,loss_fn)
epoch_test_acc,epoch_test_loss

五、心得体会

在pytorch环境下手动搭建了ResNeXt-50模型,深刻理解了该模型的构造原理,对该模型有了更深层次的感悟。但模型训练结果没有达到最为理想的状态,时间原因不再做调整,在今后的测试中可以尝试调整学习率等查看结果是否有较好变化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/896972.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于Java+SpringBoot+Vue的古典舞在线交流平台的设计与实现

基于JavaSpringBootVue的古典舞在线交流平台的设计与实现 前言 ✌全网粉丝20W,csdn特邀作者、博客专家、CSDN[新星计划]导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末附源码下载链接&a…

科研类型PPT的制作技巧

目录 科研类型PPT的制作技巧 荣誉: 首页:ppt开头结尾 小标题 重点标记:加粗红色下划线 使用三线表 图片,文本排版 一、明确目的与受众分析 二、基础设计原则 三、内容组织与呈现 四、绘图与模型制作 五、其他注意事项 科研类型PPT的制作技巧 荣誉: 首页:ppt开…

spark读取parquet文件

源码 parquet文件读取的入口是FileSourceScanExec&#xff0c;用parquet文件生成对应的RDD 非bucket文件所以走createNonBucketedReadRDD方法。 createNonBucketedReadRDD 过程&#xff1a; 确定文件分割参数 openCostInBytes4M 相关参数spark.sql.files.openCostInBytes4M…

Vue 上传图片前 裁剪图片

一. 使用的技术 vue-cropper 文档&#xff1a;vue-cropper | A simple picture clipping plugin for vue 二. 安装 npm install vue-cropper 或 yarn add vue-cropper 三. 引入 在使用页面中引用 import { VueCropper } from vue-cropper; 四. 使用 配置项&#xff1…

运动爱好者不可错过的双十一特惠,2024年度最火运动装备大推荐

随着健康意识的日益增强&#xff0c;越来越多的人加入到了运动的行列中。无论是追求速度与激情的跑步爱好者&#xff0c;还是享受汗水与肌肉碰撞的健身房常客&#xff0c;亦或是喜欢在自然中寻找乐趣的户外探险家&#xff0c;一款合适的运动装备总是能让人在运动过程中事半功倍…

【MySQL】索引和事务

目录 &#x1f334;索引 &#x1f6a9;概念 &#x1f6a9;索引的作用 &#x1f6a9;索引的使用场景 &#x1f6a9;索引的使用 &#x1f3c0;查看索引 &#x1f3c0;创建索引 &#x1f3c0;删除索引 &#x1f384;索引的底层数据结构 &#x1f6a9;引入B树(B-树) &am…

【含开题报告+文档+PPT+源码】基于SpringBoot和Vue的编程学习系统

开题报告 随着信息技术的迅猛发展和数字化转型的深入推进&#xff0c;编程技能已经成为现代社会中不可或缺的一项基本能力。无论是软件开发、数据分析还是人工智能等领域&#xff0c;编程都扮演着至关重要的角色。因此&#xff0c;培养和提高编程技能对于个人职业发展和社会创…

Python Numpy 实现神经网络自动训练:反向传播与激活函数的应用详解

Python Numpy 实现神经网络自动训练&#xff1a;反向传播与激活函数的应用详解 这篇文章介绍了如何使用 Python 的 Numpy 库来实现神经网络的自动训练&#xff0c;重点展示了反向传播算法和激活函数的应用。反向传播是神经网络训练的核心&#xff0c;能够通过计算梯度来优化模…

文献阅读:通过深度神经网络联合建模多个切片构建3D整体生物体空间图谱

文献介绍 文献题目&#xff1a; 通过深度神经网络联合建模多个切片构建3D整体生物体空间图谱 研究团队&#xff1a; 杨灿&#xff08;香港科技大学&#xff09;、吴若昊&#xff08;香港科技大学&#xff09; 发表时间&#xff1a; 2023-10-19 发表期刊&#xff1a; Nature M…

01 漫画解说-图片框的分割

to 查找最佳的轮廓模式 import cv2 as cv import numpy as np from matplotlib import pyplot as pltimg cv.imread(data/test02.png,0) ret,thresh1 cv.threshold(img,127,255,cv.THRESH_BINARY) ret,thresh2 cv.threshold(img,127,255,cv.THRESH_BINARY_INV) ret,thres…

搭建代购系统时如何保证商品信息的真实性和可靠性

搭建代购系统时&#xff0c;可从以下几个方面保证商品信息的真实性和可靠性&#xff1a; 一、供应商管理&#xff1a; 严格筛选供应商&#xff1a;对供应商进行全面的背景调查&#xff0c;包括其经营资质、信誉记录、行业口碑等。只选择与正规、有良好信誉的供应商合作&#…

LINUX1.2

1.一切都是一个文件 &#xff08;硬盘&#xff09; 2.系统小型 轻量型&#xff0c;300个包 3.避免令人困惑的用户界面 ------------------> 就是没有复杂的图形界面 4.不在乎后缀名&#xff0c;有没有都无所谓&#xff0c;不是通过后缀名来定义文件的类型&#xff08;win…

JSON 注入攻击 API

文章目录 JSON 注入攻击 API"注入所有东西"是"聪明的"发生了什么? 什么是 JSON 注入?为什么解析器是问题所在解析不一致 JSON 解析器互操作性中的安全问题处理重复密钥的方式不一致按键碰撞响应不一致JSON 序列化(反序列化)中的不一致 好的。JSON 解析器…

免费开源AI助手,颠覆你的数字生活体验

Apt Full作为一款开源且完全免费的软件&#xff0c;除了强大的自然语言处理能力&#xff0c;Apt Full还能够对图像和视频进行一系列复杂的AI增强处理&#xff0c;只需简单几步即可实现专业级的效果。 在图像处理方面&#xff0c;Apt Full提供了一套全面的AI工具&#xff0c;包…

springboot 同时上传文件和JSON对象

控制器代码 PostMapping("/upload") public ResponseEntity<String> handleFileUpload(RequestPart("file") MultipartFile file,RequestPart("user") User user) {// 处理文件和用户信息return ResponseEntity.ok("File and user i…

【MATLAB实例】批量提取.csv数据并根据变量名筛选

【MATLAB实例】批量提取.csv数据并根据变量名筛选 准备&#xff1a;数据说明MATLAB批量提取参考 准备&#xff1a;数据说明 .csv数据如下&#xff1a; 打开某表格数据&#xff0c;如下&#xff1a;&#xff08;需要说明的是此数据含表头&#xff09; 需求说明&#xff1a;需…

升级Unity后产生的Objects内存泄露现象

1&#xff09;升级Unity后产生的Objects内存泄露现象 2&#xff09;能否使用OnDemandRendering API来显示帧率 3&#xff09;Unity闪退问题 4&#xff09;配置表堆内存如何优化 这是第405篇UWA技术知识分享的推送&#xff0c;精选了UWA社区的热门话题&#xff0c;涵盖了UWA问答…

中航资本:大幅加仓!社保基金重仓股曝光

跟着上市公司三季报布满宣告&#xff0c;社保基金2024年三季度末重仓股及持股改变状况浮出水面。 Wind数据闪现&#xff0c;到10月21日&#xff0c;已有191家上市公司宣告了2024年三季报&#xff0c;其间有34家上市公司的前十大流通股东中呈现了社保基金的身影&#xff0c;社保…

从零开始学PHP之变量作用域数据类型

一、数据类型 上篇文章提到了数据类型&#xff0c;在PHP中支持以下几种类型 String &#xff08;字符串&#xff09;Integer&#xff08;整型&#xff09;Float &#xff08;浮点型&#xff09;Boolean&#xff08;布尔型&#xff09;Array&#xff08;数组&#xff09;Objec…

天锐绿盾 vs Ping32:企业级加密软件大比拼

在信息安全日益重要的今天&#xff0c;企业级加密软件成为了企业保护敏感数据的得力助手。在众多加密软件中&#xff0c;天锐绿盾与Ping32凭借各自的优势&#xff0c;赢得了市场的广泛认可。那么&#xff0c;这两款软件究竟有何异同&#xff1f;哪款更适合您的企业呢&#xff1…