Pytorch | 从零构建GoogleNet对CIFAR10进行分类

Pytorch | 从零构建Vgg对CIFAR10进行分类

  • CIFAR10数据集
  • GoogleNet
      • 网络结构特点
      • 网络整体架构
      • 特征图尺寸变化
      • 应用与影响
  • GoogleNet结构代码详解
    • 结构代码
    • 代码详解
      • Inception 类
        • 初始化方法
        • 前向传播 forward
      • GoogleNet 类
        • 初始化方法
        • 前向传播 forward
  • 训练和测试
    • 训练代码train.py
    • 测试代码test.py
    • 训练过程和测试结果
  • 代码汇总
    • googlenet.py
    • train.py
    • test.py

前面文章我们构建了AlexNet、Vgg对CIFAR10进行分类:
Pytorch | 从零构建AlexNet对CIFAR10进行分类
Pytorch | 从零构建Vgg对CIFAR10进行分类
这篇文章我们来构建GoogleNet.

CIFAR10数据集

CIFAR-10数据集是由加拿大高级研究所(CIFAR)收集整理的用于图像识别研究的常用数据集,基本信息如下:

  • 数据规模:该数据集包含60,000张彩色图像,分为10个不同的类别,每个类别有6,000张图像。通常将其中50,000张作为训练集,用于模型的训练;10,000张作为测试集,用于评估模型的性能。
  • 图像尺寸:所有图像的尺寸均为32×32像素,这相对较小的尺寸使得模型在处理该数据集时能够相对快速地进行训练和推理,但也增加了图像分类的难度。
  • 类别内容:涵盖了飞机(plane)、汽车(car)、鸟(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)这10个不同的类别,这些类别都是现实世界中常见的物体,具有一定的代表性。

下面是一些示例样本:
在这里插入图片描述

GoogleNet

GoogleNet是由Google团队在2014年提出的一种深度卷积神经网络架构,以下是对它的详细介绍:

网络结构特点

  • Inception模块:这是GoogleNet的核心创新点。Inception模块通过并行使用不同大小的卷积核(如1×1、3×3、5×5)和池化操作,然后将它们的结果在通道维度上进行拼接,从而可以同时提取不同尺度的特征。例如,1×1卷积核可以用于在不改变特征图尺寸的情况下进行降维或升维,减少计算量;3×3和5×5卷积核则可以捕捉不同感受野的特征。
  • 深度和宽度:GoogleNet网络很深,共有22层,但它的参数量却比同层次的一些网络少很多,这得益于Inception模块的高效设计。同时,网络的宽度也较大,能够学习到丰富的特征表示。
  • 辅助分类器:为了缓解梯度消失问题,GoogleNet在网络中间层添加了两个辅助分类器。这些辅助分类器在训练过程中与主分类器一起进行反向传播,帮助梯度更好地传播到浅层网络,加快训练速度并提高模型的泛化能力。在测试时,辅助分类器的结果会被加权融合到主分类器的结果中。

网络整体架构

  • 输入层:接收大小为 H × W × 3 H×W×3 H×W×3的图像数据,其中 H H H W W W表示图像的高度和宽度,3表示图像的RGB通道数。

  • 卷积层和池化层:网络的前面几层主要由卷积层和池化层组成,用于提取图像的基本特征。这些层逐渐降低图像的分辨率,同时增加特征图的通道数。

  • Inception模块组:网络的主体部分由多个Inception模块组构成,每个模块组包含多个Inception模块。随着网络的深入,Inception模块的输出通道数逐渐增加,以学习更高级的特征。
    在这里插入图片描述

  • 池化层和全连接层:在Inception模块组之后,网络通过一个平均池化层将特征图的尺寸缩小到1×1,然后将其展平并连接到一个全连接层,最后通过一个Softmax层输出分类结果。

特征图尺寸变化

假设输入图像的尺寸为 32 × 32 × 3 32×32×3 32×32×3,以下是在网络前向传播过程中特征图尺寸的大致变化:

  1. 第一层卷积:使用 7 × 7 7×7 7×7的卷积核,步长为2,填充为3,经过卷积后特征图尺寸变为 16 × 16 × 64 16×16×64 16×16×64
  2. 最大池化层:使用 3 × 3 3×3 3×3的池化核,步长为2,经过池化后特征图尺寸变为 8 × 8 × 64 8×8×64 8×8×64
  3. Inception模块组:在Inception模块组中,特征图的尺寸会根据不同的卷积和池化操作而变化。例如,在一些Inception模块中,使用 1 × 1 1×1 1×1 3 × 3 3×3 3×3 5 × 5 5×5 5×5的卷积核以及 3 × 3 3×3 3×3的池化操作,特征图的尺寸可能会在 8 × 8 8×8 8×8 4 × 4 4×4 4×4等之间变化,而通道数会逐渐增加。
  4. 平均池化层:在网络的最后,使用一个平均池化层将特征图的尺寸变为 1 × 1 × 1024 1×1×1024 1×1×1024

应用与影响

  • 图像分类:GoogleNet在图像分类任务上取得了非常好的效果,在ILSVRC 2014图像分类竞赛中获得了冠军。它能够准确地识别各种自然图像中的物体类别,如猫、狗、汽车、飞机等。
  • 目标检测:GoogleNet也可以应用于目标检测任务,通过在网络中添加一些额外的检测层和算法,可以实现对图像中物体的定位和检测。
  • 后续研究基础:GoogleNet的成功推动了深度学习领域的发展,其Inception模块的设计思想为后来的许多网络架构提供了灵感,如Inception系列的后续版本以及其他一些基于多分支结构的网络。

GoogleNet结构代码详解

结构代码

import torch
import torch.nn as nn


class Inception(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3reduc, ch3x3, ch5x5reduc, ch5x5, pool_proj):
        super().__init__()
        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_channels, ch1x1, kernel_size=1),
            nn.BatchNorm2d(ch1x1),
            nn.ReLU(inplace=True)
        )

        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_channels, ch3x3reduc, kernel_size=1),
            nn.BatchNorm2d(ch3x3reduc),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch3x3reduc, ch3x3, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

        self.branch5x5 = nn.Sequential(
            nn.Conv2d(in_channels, ch5x5reduc, kernel_size=1),
            nn.BatchNorm2d(ch5x5reduc),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch5x5reduc, ch5x5, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch5x5),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch5x5, ch5x5, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch5x5),
            nn.ReLU(inplace=True)
        )

        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, pool_proj, kernel_size=1),
            nn.BatchNorm2d(pool_proj),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        branch3x3 = self.branch3x3(x)
        branch5x5 = self.branch5x5(x)
        branch_pool = self.branch_pool(x)

        return torch.cat([branch1x1, branch3x3, branch5x5, branch_pool], 1)


class GoogleNet(nn.Module):
    def __init__(self, num_classes):
        super(GoogleNet, self).__init__()
        self.prelayers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True)
        )
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.prelayers(x)
        x = self.maxpool2(x)

        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool3(x)

        x = self.inception4a(x)
        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        x = self.inception4e(x)
        x = self.maxpool4(x)

        x = self.inception5a(x)
        x = self.inception5b(x)

        x = self.avgpool(x)
        x = self.dropout(x)
        x = x.view(x.size()[0], -1)
        x = self.fc(x)

        return x

代码详解

以下是对上述代码的详细解释,这段Python代码使用PyTorch库构建了经典的GoogleNet(Inception v1)网络结构,用于图像分类任务,以下从不同部分展开介绍:

Inception 类

这个类定义了GoogleNet中的Inception模块,它的作用是通过不同尺寸的卷积核等操作来并行提取特征,然后将这些特征在通道维度上进行拼接。

初始化方法
  • 参数说明
    • in_channels:输入特征图的通道数,即输入数据的深度维度。
    • ch1x1ch3x3reducch3x3ch5x5reducch5x5pool_proj:分别对应不同分支中卷积操作涉及的通道数等参数,用于配置每个分支的结构。
  • 网络结构构建
    • self.branch1x1:构建了一个包含1×1卷积、批归一化(BatchNorm)和ReLU激活函数的顺序结构。1×1卷积用于在不改变特征图尺寸的情况下调整通道数,批归一化有助于加速训练和提高模型稳定性,ReLU激活函数引入非线性变换。
    • self.branch3x3:先是一个1×1卷积进行通道数的降维(减少计算量),接着经过批归一化和ReLU激活,然后是一个3×3卷积(通过padding=1保证特征图尺寸不变),最后再接ReLU激活。
    • self.branch5x5:结构相对更复杂些,先是1×1卷积和批归一化、ReLU激活,然后连续两个3×3卷积(都通过合适的padding保证尺寸不变),中间穿插批归一化和ReLU激活,用于提取更复杂的特征。
    • self.branch_pool:先进行最大池化(MaxPool2d,通过特定参数设置保证尺寸基本不变),然后接1×1卷积来调整通道数,再进行批归一化和ReLU激活。
前向传播 forward
  • 接收输入张量x,分别将其传入上述四个分支结构中,得到四个分支的输出branch1x1branch3x3branch5x5branch_pool
  • 最后通过torch.cat函数沿着通道维度(维度1,即参数中的1)将这四个分支的输出特征图拼接在一起,作为整个Inception模块的输出。

GoogleNet 类

这是整个网络的主体类,将多个Inception模块以及其他必要的层组合起来构建完整的GoogleNet架构。

初始化方法
  • 参数说明
    • num_classes:表示分类任务的类别数量,用于最终全连接层输出对应数量的类别预测结果。
  • 网络结构构建
    • self.prelayers:由一系列的卷积、批归一化和ReLU激活函数组成的顺序结构,用于对输入图像进行初步的特征提取,逐步将输入的3通道(对应RGB图像)特征图转换为192通道的特征图。
    • self.maxpool2:一个最大池化层,用于下采样,减小特征图尺寸,同时增大感受野,步长为2,按一定的padding设置来控制输出尺寸。
    • 接下来依次定义了多个Inception模块,如self.inception3aself.inception3b等,它们的输入通道数和各分支的配置参数不同,随着网络的深入逐渐提取更高级、更复杂的特征,并且中间穿插了几个最大池化层(self.maxpool3self.maxpool4等)进行下采样操作。
    • self.avgpool:自适应平均池化层,将不同尺寸的特征图转换为固定大小(这里是1×1)的特征图,方便后续的全连接层处理。
    • self.dropout:引入Dropout层,概率设置为0.4,在训练过程中随机丢弃部分神经元连接,防止过拟合。
    • self.fc:全连接层,将经过前面处理后的特征映射到指定的num_classes个类别上,用于最终的分类预测。
前向传播 forward
  • 首先将输入x传入self.prelayers进行初步特征提取,然后经过self.maxpool2下采样。
  • 接着依次将特征图传入各个Inception模块,并穿插经过最大池化层进行下采样,不断提取和整合特征。
  • 经过最后的Inception模块后,特征图通过self.avgpool进行平均池化,再经过self.dropout进行随机失活处理,然后通过x.view函数将特征图展平成一维向量(方便全连接层处理),最后传入self.fc全连接层得到最终的分类预测结果并返回。

训练和测试

训练代码train.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from models import *
import matplotlib.pyplot as plt

import ssl
ssl._create_default_https_context = ssl._create_unverified_context

# 定义数据预处理操作
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])

# 加载CIFAR10训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 实例化模型
model_name = 'GoogleNet'
if model_name == 'AlexNet':
    model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':
    model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':
    model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':
    model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':
    model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':
    model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':
    model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':
    model = GoogleNet(num_classes=10).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练轮次
epochs = 15

def train(model, trainloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / len(trainloader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

if __name__ == "__main__":
    loss_history, acc_history = [], []
    for epoch in range(epochs):
        train_loss, train_acc = train(model, trainloader, criterion, optimizer, device)
        print(f'Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        loss_history.append(train_loss)
        acc_history.append(train_acc)
        # 保存模型权重,每5轮次保存到weights文件夹下
        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), f'weights/{model_name}_epoch_{epoch + 1}.pth')
    
    # 绘制损失曲线
    plt.plot(range(1, epochs+1), loss_history, label='Loss', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Curve')
    plt.legend()
    plt.savefig(f'results\\{model_name}_train_loss_curve.png')
    plt.close()

    # 绘制准确率曲线
    plt.plot(range(1, epochs+1), acc_history, label='Accuracy', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Training Accuracy Curve')
    plt.legend()
    plt.savefig(f'results\\{model_name}_train_acc_curve.png')
    plt.close()

测试代码test.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import *

import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# 定义数据预处理操作
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])

# 加载CIFAR10测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 实例化模型
model_name = 'GoogleNet'
if model_name == 'AlexNet':
    model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':
    model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':
    model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':
    model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':
    model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':
    model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':
    model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':
    model = GoogleNet(num_classes=10).to(device)

criterion = nn.CrossEntropyLoss()

# 加载模型权重
weights_path = f"weights/{model_name}_epoch_15.pth"  
model.load_state_dict(torch.load(weights_path, map_location=device))

def test(model, testloader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / len(testloader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

if __name__ == "__main__":
    test_loss, test_acc = test(model, testloader, criterion, device)
    print(f"================{model_name} Test================")
    print(f"Load Model Weights From: {weights_path}")
    print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')

训练过程和测试结果

训练过程损失函数变化曲线:
在这里插入图片描述

训练过程准确率变化曲线:
在这里插入图片描述

测试结果:
在这里插入图片描述

代码汇总

项目github地址
项目结构:

|--data
|--models
	|--__init__.py
	|--googlenet.py
	|--...
|--results
|--weights
|--train.py
|--test.py

googlenet.py

import torch
import torch.nn as nn


class Inception(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3reduc, ch3x3, ch5x5reduc, ch5x5, pool_proj):
        super().__init__()
        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_channels, ch1x1, kernel_size=1),
            nn.BatchNorm2d(ch1x1),
            nn.ReLU(inplace=True)
        )

        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_channels, ch3x3reduc, kernel_size=1),
            nn.BatchNorm2d(ch3x3reduc),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch3x3reduc, ch3x3, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

        self.branch5x5 = nn.Sequential(
            nn.Conv2d(in_channels, ch5x5reduc, kernel_size=1),
            nn.BatchNorm2d(ch5x5reduc),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch5x5reduc, ch5x5, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch5x5),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch5x5, ch5x5, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch5x5),
            nn.ReLU(inplace=True)
        )

        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, pool_proj, kernel_size=1),
            nn.BatchNorm2d(pool_proj),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        branch3x3 = self.branch3x3(x)
        branch5x5 = self.branch5x5(x)
        branch_pool = self.branch_pool(x)

        return torch.cat([branch1x1, branch3x3, branch5x5, branch_pool], 1)


class GoogleNet(nn.Module):
    def __init__(self, num_classes):
        super(GoogleNet, self).__init__()
        self.prelayers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True)
        )
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.prelayers(x)
        x = self.maxpool2(x)

        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool3(x)

        x = self.inception4a(x)
        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        x = self.inception4e(x)
        x = self.maxpool4(x)

        x = self.inception5a(x)
        x = self.inception5b(x)

        x = self.avgpool(x)
        x = self.dropout(x)
        x = x.view(x.size()[0], -1)
        x = self.fc(x)

        return x

train.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from models import *
import matplotlib.pyplot as plt

import ssl
ssl._create_default_https_context = ssl._create_unverified_context

# 定义数据预处理操作
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])

# 加载CIFAR10训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 实例化模型
model_name = 'GoogleNet'
if model_name == 'AlexNet':
    model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':
    model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':
    model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':
    model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':
    model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':
    model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':
    model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':
    model = GoogleNet(num_classes=10).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练轮次
epochs = 15

def train(model, trainloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / len(trainloader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

if __name__ == "__main__":
    loss_history, acc_history = [], []
    for epoch in range(epochs):
        train_loss, train_acc = train(model, trainloader, criterion, optimizer, device)
        print(f'Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        loss_history.append(train_loss)
        acc_history.append(train_acc)
        # 保存模型权重,每5轮次保存到weights文件夹下
        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), f'weights/{model_name}_epoch_{epoch + 1}.pth')
    
    # 绘制损失曲线
    plt.plot(range(1, epochs+1), loss_history, label='Loss', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Curve')
    plt.legend()
    plt.savefig(f'results\\{model_name}_train_loss_curve.png')
    plt.close()

    # 绘制准确率曲线
    plt.plot(range(1, epochs+1), acc_history, label='Accuracy', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Training Accuracy Curve')
    plt.legend()
    plt.savefig(f'results\\{model_name}_train_acc_curve.png')
    plt.close()

test.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import *

import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# 定义数据预处理操作
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])

# 加载CIFAR10测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 实例化模型
model_name = 'GoogleNet'
if model_name == 'AlexNet':
    model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':
    model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':
    model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':
    model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':
    model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':
    model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':
    model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':
    model = GoogleNet(num_classes=10).to(device)

criterion = nn.CrossEntropyLoss()

# 加载模型权重
weights_path = f"weights/{model_name}_epoch_15.pth"  
model.load_state_dict(torch.load(weights_path, map_location=device))

def test(model, testloader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / len(testloader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

if __name__ == "__main__":
    test_loss, test_acc = test(model, testloader, criterion, device)
    print(f"================{model_name} Test================")
    print(f"Load Model Weights From: {weights_path}")
    print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/938367.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

简单了解一下 Go 语言的构建约束?

​构建约束是一种在 Go 语言中控制源文件编译条件的方法,它可以让您指定某些文件只在特定的操作系统、架构、编译器或 Go 版本下编译,而在其他环境中自动忽略。这样可以方便您针对不同的平台或场景编写不同的代码,实现条件编译的功能。 构建…

12.17双向链表,循环链表

循环单向链表 1.头文件test.h #ifndef __TEST_H_ #define __TEST_H_#include<stdio.h> #include<stdlib.h>typedef struct node {union{int len;int data;};struct node *next; }looplink,*looplinkPtr;//创建 looplinkPtr create();//判空 int empty(); //申请…

图的最小生成树(C++实现图【3】)

目录 1.最小生成树 1.1 Kruskal算法 代码部分 1.2 Prim算法 代码部分 1.最小生成树 连通图中的每一棵生成树&#xff0c;都是原图的一个极大无环子图&#xff0c;即&#xff1a;从其中删去任何一条边&#xff0c;生成树就不在连通&#xff1b;反之&#xff0c;在其中引入任何一…

解决电脑网速慢问题:硬件检查与软件设置指南

电脑网速慢是许多用户在使用过程中常见的问题&#xff0c;它不仅会降低工作效率&#xff0c;还可能影响娱乐体验。导致电脑网速慢的原因多种多样&#xff0c;包括硬件问题、软件设置和网络环境等。本文将从不同角度分析这些原因&#xff0c;并提供提高电脑网速的方法。 一、检查…

Python-基于Pygame的小游戏(贪吃蛇)(一)

前言:贪吃蛇是一款经典的电子游戏&#xff0c;最早可以追溯到1976年的街机游戏Blockade。随着诺基亚手机的普及&#xff0c;贪吃蛇游戏在1990年代变得广为人知。它是一款休闲益智类游戏&#xff0c;适合所有年龄段的玩家&#xff0c;其最初为单机模式&#xff0c;后来随着技术发…

MySQL表的增删改查(2)

1.数据库约束 1)约束类型 not null指定某列不能存储null值unique保证某列的每一行必须有唯一值default规定没有给列赋值时的默认值primary keynot null和unique的结合,一张表里只能有一个,作为身份标识的数据foreign key保证一个表的数据匹配另一个表中的值的参照完整性check…

职场人如何提升职业技能?

职场人如何提升职业技能&#xff1f; 在职场中&#xff0c;每个人都像是一名航行在广阔大海上的水手&#xff0c;面对着不断变化的风浪和挑战。要想在这片职场海洋中稳步前行&#xff0c;甚至脱颖而出&#xff0c;提升职业技能是必不可少的。那么&#xff0c;职场人究竟该如何…

IVE Model 2.0.2运行报错:Error launching application × could not locate Java runtime

在windows电脑上运行IVE Model 2.0.2程序的时候弹窗报错: could not locate Java runtime 一、原因分析 第一次安装的时候,很确定自己的JDK环境安装是没有问题,但是运行仍然会报错,由于软件没有说明使用什么版本的JDK只能挨个尝试,换了几个版本仍然不行,忽然想到,这个软…

模型训练篇 | 关于常见的10种数据标注工具介绍

前言:Hello大家好,我是小哥谈。数据标注工具是一种用于标记和分类数字图像、音频、视频或文本等数据集的工具。数据标注工具可以自动或手动标记数据集中的对象、人脸、物体、文字等,以便机器学习模型能够理解和识别这些数据。数据标注工具通常由开发者或数据标注团队开发和使…

Linux应用开发————mysql数据库

数据库概述 什么是数据库(database)? 数据库是一种数据管理的管理软件&#xff0c;它的作用是为了有效管理数据&#xff0c;形成一个尽可能无几余的数据集合&#xff0c;并能提供接口&#xff0c;方便用户使用。 数据库能用来干什么? 顾名思义&#xff0c;仓库就是用来保存东…

c++理解(三)

本文主要探讨c相关知识。 模板是对类型参数化 函数模板特化不是模板函数重载 allocator(空间配置器):内存开辟释放,对象构造析构 优先调用对象成员方法实现的运算符重载函数,其次全局作用域找 迭代器遍历访问元素,调用erase&#xff0c;insert方法后&#xff0c;当前位置到容器…

动态规划——最长公共子序列

文章目录 概要整体流程问题描述递推公式由来两个序列的最后一位相等两个序列的最后一位不等左图右图 表格填写dp 表格定义递推公式填表过程填表过程解析最终结果 小结 概要 动态规划相关知识 求解最长的公共子序列 整体流程 问题定义与区分&#xff1a;理解最长公共子串与最…

Node的学习以及学习通过Node书写接口并简单操作数据库

Node的学习 Node的基础上述是关于Node的一些基础&#xff0c;总结的还行&#xff1b; 利用Node书写接口并操作数据库 1. 初始化项目 创建新的项目文件夹&#xff0c;并初始化 package.json mkdir my-backend cd my-backend npm init -y2. 安装必要的依赖 安装Express.js&…

arXiv-2024 | NavAgent:基于多尺度城市街道视图融合的无人机视觉语言导航

作者&#xff1a;Youzhi Liu, Fanglong Yao*, Yuanchang Yue, Guangluan Xu, Xian Sun, Kun Fu 单位&#xff1a;中国科学院大学电子电气与通信工程学院&#xff0c;中国科学院空天信息创新研究院网络信息系统技术重点实验室 原文链接&#xff1a;NavAgent: Multi-scale Urba…

易语言鼠标轨迹算法(游戏防检测算法)

一.简介 鼠标轨迹算法是一种模拟人类鼠标操作的程序&#xff0c;它能够模拟出自然而真实的鼠标移动路径。 鼠标轨迹算法的底层实现采用C/C语言&#xff0c;原因在于C/C提供了高性能的执行能力和直接访问操作系统底层资源的能力。 鼠标轨迹算法具有以下优势&#xff1a; 模拟…

Three.js材质纹理扩散过渡

Three.js材质纹理扩散过渡 import * as THREE from "three"; import { ThreeHelper } from "/src/ThreeHelper"; import { LoadGLTF, MethodBaseSceneSet } from "/src/ThreeHelper/decorators"; import { MainScreen } from "/src/compone…

apache-tomcat-6.0.44.exe Win10

apache-tomcat-6.0.44.exe Win10

赫布定律 | 机器学习 / 反向传播 / 经验 / 习惯

注&#xff1a;本文为 “赫布定律” 相关文章合辑。 未整理。 赫布定律 Hebb‘s law 馥墨轩 2021 年 03 月 13 日 00:03 1 赫布集合的基本定义 唐纳德・赫布&#xff08;Donald Hebb&#xff09;在 1949 年出版了《行为的组织》&#xff08;The Organization of Behavior&a…

uni-app实现小程序、H5图片轮播预览、双指缩放、双击放大、单击还原、滑动切换功能

前言 这次的标题有点长&#xff0c;主要是想要表述的功能点有点多&#xff1b; 简单做一下需求描述 产品要求在商品详情页的头部轮播图部分&#xff0c;可以单击预览大图&#xff0c;同时在预览界面可以双指放大缩小图片并且可以移动查看图片&#xff0c;双击放大&#xff0…

杭州乘云联合信通院发布《云计算智能化可观测性能力成熟度模型》

原文地址&#xff1a;杭州乘云联合中国信通院等单位正式发布《云计算智能化可观测性能力成熟度模型》标准 2024年12月3日&#xff0c;由全球数字经济大会组委会主办、中国信通院承办的 2024全球数字经济大会 云AI计算创新发展大会&#xff08;2024 Cloud AI Compute Ignite&…