实验13 使用预训练resnet18实现CIFAR-10分类

1.数据预处理

首先利用函数transforms.Compose定义了一个预处理函数transform,里面定义了两种操作,一个是将图像转换为Tensor,一个是对图像进行标准化。然后利用函数torchvision.datasets.CIFAR10下载数据集,这个函数有四个常见的初始化参数:root为数据存储的路径,如果数据已经下载,会直接从这个路径加载数据。train如果为True,表示加载训练集,train如果为False,加载测试集。download如果设置为True,表示如果本地不存在数据集,会自动从互联网上下载。transform指定一个转换函数,对数据进行预处理和数据增强等操作。所以下载训练集train_full时,train赋值为True,下载测试集时,train赋值为False。之后对下载的训练集train_full进行划分,先规定指定的大小,然后利用random_split进行划分,最后就是创建Dataloader,batch_size设为64,得到train_loader,val_loader,test_loader。

代码:

# 检查是否有可用的 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 数据预处理和增强
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为Tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 图像标准化
])

# 下载 CIFAR-10 数据集
train_full = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# 划分训练集(40,000)和验证集(10,000)
train_size = int(0.8 * len(train_full))  # 80% 用于训练
val_size = len(train_full) - train_size  # 剩余 20% 用于验证
train_data, val_data = random_split(train_full, [train_size, val_size])

# 创建 DataLoader
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64, shuffle=False)
test_loader = DataLoader(test, batch_size=64, shuffle=False)

2.模型构建

模型构建就比较简单,直接使用使用pytorch定义的库函数,只有一行代码:

model = models.resnet18(pretrained=False),pretrained=False表示不使用在Imagenet上预训练的权重,pretrained=True表示使用在Imagenet上预训练的权重。因为这个模型是训练Imagenet构建的模型,要想让这个模型适应新任务,需要获取最后一层的输入特征数,然后利用一个全连接层将输出改为10。

代码:

# 初始化 ResNet-18 模型
model = models.resnet18(pretrained=True)
# 修改最后一层(全连接层),适应新的任务
num_ftrs = model.fc.in_features  # 获取最后一层的输入特征数
model.fc = torch.nn.Linear(num_ftrs, 10)  # 将输出改为 10 个类别(例如 CIFAR-10)

3.模型训练

创建Runner类,管理训练、评估、测试和预测过程。还是之前的一套东西,首先是一个init函数,用于初始化数据集、损失函数、优化器等。train函数用于计算在训练集上的loss,并反向传播更新参数。evaluate函数用于计算在验证集上的损失,不用反向传播更新模型的参数,同时根据evaluate函数得到的损失判断是否保存最优模型,利用state_dict函数保存最优模型。test函数首先加载最优模型,然后在测试集计算最优模型的准确率。predict函数预测某个图像属于某个类别的概率,虽然resnet最后一层没有softmax,但是也可以根据最后一层得到的10个logits(未经过归一化的原始输出)取最大来判断图像属于某一类(因为这10个值也是有大小关系的,softmax函数不会修改这10个值的大小关系)。

定义学习率=0.01、批次大小=30、损失函数为交叉熵损失nn.CrossEntropyLoss()、优化器为Adam。

实例化Runner,调用train函数,开始训练。

代码:

class Runner:
    def __init__(self, model, train_loader, val_loader, test_loader, criterion, optimizer, device):
        self.model = model.to(device)  # 将模型移到GPU
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.best_model = None
        self.best_val_loss = float('inf')
        self.train_losses = []  # 存储训练损失
        self.val_losses = []  # 存储验证损失

    def train(self, epochs=10):
        for epoch in range(epochs):
            self.model.train()
            running_loss = 0.0

            for inputs, labels in self.train_loader:
                # 将数据移到GPU
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                self.optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()

                running_loss += loss.item()

            # 计算平均训练损失
            train_loss = running_loss / len(self.train_loader)
            self.train_losses.append(train_loss)

            # 计算验证集上的损失
            val_loss = self.evaluate()
            self.val_losses.append(val_loss)

            print(f'Epoch [{epoch + 1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

            # 如果验证集上的损失最小,保存模型
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.best_model = self.model.state_dict()

    def evaluate(self):
        self.model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, labels in self.val_loader:
                # 将数据移到GPU
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                val_loss += loss.item()
        return val_loss / len(self.val_loader)

    def test(self):
        self.model.load_state_dict(self.best_model)
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in self.test_loader:
                # 将数据移到GPU
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                outputs = self.model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        test_accuracy = correct / total
        print(f'Test Accuracy: {test_accuracy:.4f}')

    def predict(self, image):
        self.model.eval()
        image = image.to(self.device)  # 将图像移到GPU
        with torch.no_grad():
            output = self.model(image)
            _, predicted = torch.max(output, 1)
            return predicted.item()

    def visualize_and_predict(self, index):
        """
        针对训练集中的某一张图片进行预测,并可视化图片。
        :param index: 训练集中的图片索引
        """
        # 获取训练集中的第 index 张图片
        image, label = self.train_loader.dataset[index]

        # 将图像移到GPU(如果需要)
        image = image.unsqueeze(0).to(self.device)  # 增加一个维度作为batch size

        # 可视化图像
        plt.imshow(image.cpu().squeeze().numpy(), cmap='gray')  # 假设是灰度图,若是彩色图像要调整
        plt.title(f"True Label: {label}")
        plt.show()

        # 预测该图片的类别
        predicted_label = self.predict(image)
        print(f"Predicted Label: {predicted_label}")
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# 实例化Runner类
runner = Runner(model, train_loader, val_loader, test_loader, criterion, optimizer, device)

# 训练模型
runner.train(epochs=30)
# 绘制损失曲线
plt.figure(figsize=(10, 6))
plt.plot(runner.train_losses, label='Train Loss')
plt.plot(runner.val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss Curve')
plt.legend()
plt.grid()
plt.show()

4.模型评价

调用test函数,计算在测试集上的准确率。

代码:

# 在最优模型上评估测试集准确率
runner.test()

5.模型预测

在训练集任意选取一个图像,获取图像的image和标签label,因为图像已经经过了transform的变换,所以这个图像不需要transform,只需要添加一个维度1作为batch_size,可视化图像和真实标签,然后调用predict函数进行预测,输出真实类别。

代码:

# CIFAR-10 是 RGB 图像,确保正确显示
# 将 Tensor 转换为 numpy 数组并调整维度顺序为 HWC (Height, Width, Channels)
image_np = image.numpy().transpose((1, 2, 0))  # 从 CHW 转为 HWC

# 可视化图像
plt.imshow(image_np)
plt.title(f"True Label: {label}")
plt.show()

# 直接将图像传递给预测函数,不再需要 transform
# 但是要确保图像传入时是正确的 batch size 形状,即增加一个 batch 维度
image_transformed = image.unsqueeze(0).to(device)  # 增加一个维度作为 batch size

# 预测该图片的类别
predicted_label = runner.predict(image_transformed)
print(f"Predicted Label: {predicted_label}")

6.实验结果与分析

不使用预训练权重的损失变化、准确率和预测结果

使用预训练权重的损失变化、准确率和预测结果

通过观察损失变化,我们发现两个模型在训练集上的loss一直在减小,说明模型的参数一直在更新。但是在验证集上的损失一开始是下降的,但是后来不断增大,我觉得是因为模型过拟合了。但是可以发现在没有预训练权重上的最优验证损失是比有预训练权重的模型上的最优验证损失大的。通过保存最优模型,在最优模型上计算准确率,发现在没有预训练权重的模型得到的准确率是0.7332,在使用预训练权重的模型得到的准确率是0.7431。

结论:通过对比在验证集上的最优验证损失和在测试集上的准确率,得到结论使用了预训练的模型效果要更好。

7.总结与心得体会

总结:

1.预训练模型:

预训练模型是指在一个大规模数据集上(如 ImageNet、COCO 等)经过训练的模型。这个模型已经学习到了一些通用的特征,比如图像中的边缘、纹理、颜色、形状等,或者文本中的语法、词汇关系等。这些特征是从数据中自动学习的,并且在很多不同的任务中都有用。

例子:

在图像分类任务中,ResNet、VGG、Inception 等深度神经网络在 ImageNet 上经过训练后,它们可以识别成千上万种不同的物体。由于这些物体特征具有广泛的普适性,我们可以将这些模型用于其他图像分类任务(例如 Cifar-10、Cifar-100),而无需从头开始训练。

在自然语言处理(NLP)中,像 BERT、GPT 等预训练语言模型已经在大量的文本数据上训练过,学习了丰富的语言知识。因此,我们可以将这些模型应用于文本分类、情感分析、问答等任务。

预训练模型的优势:

节省计算资源:训练深度神经网络需要大量的计算资源和时间,尤其是在大规模数据集上。通过使用预训练模型,用户可以避免从零开始训练,直接利用现成的知识。

提高效果:预训练模型已经学习到了一些通用的特征,可以加速学习过程,并且通常能够取得比从头开始训练更好的效果。
2. 迁移学习(Transfer Learning)

迁移学习是一种利用在一个任务上学到的知识,来帮助在另一个相关任务上进行学习的技术。换句话说,它将一个任务中的学习成果迁移到另一个任务中,特别是在目标任务的数据较少时。

迁移学习的核心思想是:如果一个模型在某个任务上已经学到了一些有用的特征,那么这些特征可以迁移到另一个任务上,帮助模型更好地学习。

迁移学习的典型流程:

模型加载:加载一个在大数据集上预训练的模型(如 ResNet、VGG、BERT 等)。

模型微调:对模型的部分层进行微调,或者只训练新添加的层(如分类层)。

应用于新任务:将经过微调的模型应用于新的、可能较小的数据集。

迁移学习的类型

迁移学习有多种不同的方式,常见的有以下几种:

微调(Fine-Tuning):使用预训练模型的权重,并对某些层或整个模型进行微调,以适应新的任务和数据。

通常会冻结前几层(因为它们学习的是通用特征),只训练后几层(专门针对当前任务)。

特征提取(Feature Extraction):使用预训练模型的特征提取能力,将前几层的权重固定,不更新,仅训练新加的全连接层或输出层。

零-shot 学习:在一些任务中,预训练模型被直接应用到目标任务,而不进行微调,特别是当目标任务的标注数据非常少时。

迁移学习的应用:

计算机视觉:在一个大规模的数据集(如 ImageNet)上训练的模型可以用于许多不同的图像分类任务,例如识别猫、狗、车、飞机等物体,或者在医疗影像、无人驾驶等领域中应用。

自然语言处理(NLP):例如,BERT 和 GPT 等模型可以在情感分析、命名实体识别、机器翻译等任务上进行迁移学习。

3. 预训练模型和迁移学习的关系

预训练模型和迁移学习是紧密相关的。迁移学习通常依赖于预训练模型,使用在一个任务中学到的知识来帮助另一个任务。在迁移学习中,预训练模型提供了一个良好的起点,减少了从头开始训练的难度和所需的数据量。

预训练模型与迁移学习的关系:

预训练模型是迁移学习的基础,因为迁移学习的一个关键步骤是使用已经在其他任务上训练好的模型。

迁移学习则是使用这些预训练模型的技术,它通过微调或特征提取等方式,将预训练模型的知识应用到新任务中。

使用torchvision.datasets的常见参数:

root:数据存储的路径。如果数据已经下载,它会直接从该路径加载数据。

train:如果设置为 True,加载训练集;如果设置为 False,加载测试集。

download:如果设置为 True,如果本地不存在数据集,它会自动从互联网上下载。

transform:指定一个转换函数,对数据进行预处理和数据增强等操作。

transforms.Compose 是 torchvision.transforms 模块中的一个函数,用于将多个图像预处理操作组合成一个复合操作。在神经网络训练中,常常需要对输入图像进行多种预处理,例如将图像转换为张量(Tensor)、标准化、数据增强等。transforms.Compose 允许你将这些操作按顺序组合在一起,并一次性应用于输入图像。

心得体会:

这个实验直接调用预训练的resnet18进行CIFAR-10数据集的分类,因为这个模型是在Imagenet数据集上训练得到的,所以适用于新的任务需要微调模型。通过对比没有预训练权重的模型和有预训练权重的模型的训练效果,发现还是有预训练权重得到的结果比较好,因为预训练模型已经学习到了一些通用的特征,可以加速学习过程,通常能够取得比从头开始训练更好的效果。在实际应用中在理解模型内部实现的基础上,直接调用高层API是一个不错的选择,可以减少代码量。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/928033.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Wwise SoundBanks内存优化

1.更换音频格式为Vorbis 2.停用多余的音频&#xff0c;如Random Container的随机脚步声数量降为2个 3.背景音乐勾选“Stream”。这样就让音频从硬盘流送到Wwise&#xff0c;而不是保存在内存当中&#xff0c;也就节省了内存 4.设置最大发声数Max Voice Instances 5.设置音频…

【测试工具JMeter篇】JMeter性能测试入门级教程(六):JMeter中实现参数化的几种方式

一、参数化的定义 什么是参数化&#xff1f;从字面上去理解的话&#xff0c;就是事先准备好数据&#xff08;广义上来说&#xff0c;可以是具体的数据值&#xff0c;也可以是数据生成规则&#xff09;&#xff0c;而非在脚本中固化&#xff0c;脚本执行时从准备好的数据中取值。…

2024年11月份实时获取地图边界数据方法,省市区县街道多级联动【附实时geoJson数据下载】

首先&#xff0c;来看下效果图 在线体验地址&#xff1a;https://geojson.hxkj.vip&#xff0c;并提供实时geoJson数据文件下载 可下载的数据包含省级geojson行政边界数据、市级geojson行政边界数据、区/县级geojson行政边界数据、省市区县街道行政编码四级联动数据&#xff0…

【力扣】—— 二叉树的前序遍历、字典序最小回文串

Hi~&#xff01;这里是奋斗的明志&#xff0c;很荣幸您能阅读我的文章&#xff0c;诚请评论指点&#xff0c;欢迎欢迎 ~~ &#x1f331;&#x1f331;个人主页&#xff1a;奋斗的明志 &#x1f331;&#x1f331;所属专栏&#xff1a;数据结构 &#x1f4da;本系列文章为个人学…

电脑显示没信号显示屏不亮怎么办?电脑没信号解决方法

电脑没信号显示屏不亮这种故障的原因可能有多种&#xff0c;例如显示器的供电、连接、设置等问题&#xff0c;或者电脑的显卡、内存、硬盘、主板等硬件问题。所以我们想要解决这个问题&#xff0c;也是需要多方面排除找到具体原因然后进行修复。下面将为大家介绍一些常见的电脑…

【docker】Windows11创建Ubuntu-desktop并使用VNC完成远程访问

【docker】Windows11创建Ubuntu-desktop并使用VNC完成远程访问 文章目录 【docker】Windows11创建Ubuntu-desktop并使用VNC完成远程访问前言创建Ubuntu容器下载镜像运行容器连接容器 搭建容器XFCE桌面环境安装ubuntu桌面 总结 前言 docker ubuntu容器在深度学习领域的使用过程…

歇一歇,写写段子

无聊的日子都在写段子1.0 中学的时候喜欢看意林之类的杂志&#xff0c; 里面的作者用乱七八糟的理由跑去旅游&#xff0c;然后说“阻碍你脚步的永远只有逃离的勇气和对生活的热爱”&#xff0c; 我觉得太对了&#xff0c;可惜 12306 付款方式里没有勇气和热爱&#xff0c;不…

1203论文速读

1、Hierarchical Stochastic Block Model for Community Detection in Multiplex Networks∗ &#xff08;多层网络社区检测的层次随机块模型 &#xff09; 全文总结&#xff1a;本文提出了一种新颖的贝叶斯模型&#xff0c;称为分层随机块模型&#xff08;HSBM&#xff09;&a…

双向长短期记忆(Bi-LSTM)神经网络介绍

长短期记忆(Long Short-Term Memory, LSTM)神经网络&#xff1a; 1.是Hochreiter和Schmidhuber设计的循环神经网络(Recurrent Neural Network, RNN)的改进版本。LSTM模型借鉴了人类大脑的选择性输入和选择性遗忘机制&#xff0c;获取序列中的关键信息&#xff0c;遗忘和当前预测…

.NET 9 中 LINQ 新增功能实现过程

本文介绍了.NET 9中LINQ新增功能&#xff0c;包括CountBy、AggregateBy和Index方法,并提供了相关代码示例和输出结果&#xff0c;感兴趣的朋友跟随我一起看看吧 LINQ 介绍 语言集成查询 (LINQ) 是一系列直接将查询功能集成到 C# 语言的技术统称。 数据查询历来都表示为简单的…

解决PowerPoint的流程图图标中输入文字位置偏下的问题

解决PowerPoint的流程图图标中输入文字位置偏下的问题 背景 在PowerPoint中&#xff0c;插入流程图形状&#xff0c;并在其内部输入中文字符&#xff0c;是很常规的操作。然而&#xff0c;有时输入文本发现文本整体偏下&#xff0c;靠近流程图下侧。 症状 文字位置偏下的效…

C++基础:list的基本使用

文章目录 1.基本构造和插入删除基本构造和尾插数据迭代器的分类内置排序sort任意位置插入删除 2.链表的合并,去重和剪切链表的合并链表去重链表的剪切 list的本质就是带头双向循环列表 1.基本构造和插入删除 基本构造和尾插数据 与之前vector的方法相同直接调用即可 迭代器的分…

SpringBoot中实现EasyExcel实现动态表头导入(完整版)

前言 最近在写项目的时候有一个需求&#xff0c;就是实现动态表头的导入&#xff0c;那时候我自己也不知道动态表头导入是什么&#xff0c;查询了大量的网站和资料&#xff0c;终于了解了动态表头导入是什么。 一、准备工作 确保项目中引入了处理 Excel 文件的相关库&#xff…

亚马逊云(AWS)使用root用户登录

最近在AWS新开了服务器&#xff08;EC2&#xff09;&#xff0c;用于学习&#xff0c;遇到一个问题就是默认是用ec2-user用户登录&#xff0c;也需要密钥对。 既然是学习用的服务器&#xff0c;还是想直接用root登录&#xff0c;下面开始修改&#xff1a; 操作系统是&#xff1…

基于Java Springboot武汉市公交路线查询APP且微信小程序

一、作品包含 源码数据库设计文档万字PPT全套环境和工具资源部署教程 二、项目技术 前端技术&#xff1a;Html、Css、Js、Vue、Element-ui 数据库&#xff1a;MySQL 后端技术&#xff1a;Java、Spring Boot、MyBatis 三、运行环境 开发工具&#xff1a;IDEA/eclipse 微信…

【C++】数组

1.概述 所谓数组&#xff0c;就是一个集合&#xff0c;该集合里面存放了相同类型的数据元素。 数组特点&#xff1a; &#xff08;1&#xff09;数组中的每个数据元素都是相同的数据类型。 &#xff08;2&#xff09;数组是有连续的内存空间组成的。 2、一维数组 2.1维数组定…

WPF中的VisualState(视觉状态)

以前在设置控件样式或自定义控件时&#xff0c;都是使用触发器来进行样式更改。触发器可以在属性值发生更改时启动操作。 像这样&#xff1a; <Style TargetType"ListBoxItem"><Setter Property"Opacity" Value"0.5" /><Setter …

day04【入门】MySQL学习(1)

目前的学习进度&#xff0c;如上图所示。从晚上开始学习MySQL数据库啦。 目录 1、数据库简介 2、数据集连接及准备工作 3、sql 语言中的注释 4、MySQL中常用数据类型 5、数据库中元素 6、创建表 7、insert插入记录 8、select查询 9、update修改数据 10、delete删除、t…

redis核心命令全局命令 + redis 常见的数据结构 + redis单线程模型

文章目录 一. 核心命令1. set2. get 二. 全局命令1. keys2. exists3. del4. expire5. ttl6. type 三. redis 常见的数据结构及内部编码四. redis单线程模型 一. 核心命令 1. set set key value key 和 value 都是string类型的 对于key value, 不需要加上引号, 就是表示字符串…

Dromara WarmFlow工作流动态指定办理人

Dromara WarmFlow工作流动态指定办理人 背景&#xff1a; 审批任务的办理人&#xff0c;通常是在流程设计器中预先设定好办理人&#xff0c;那如果想要在办理过程中指定办理人呢&#xff1f; 那不得不提一下本次的主角&#xff0c;来自Dromara组织的WarmFlow工作流&#xff0…