动手学深度学习(Pytorch版)代码实践 -计算机视觉-37微调

37微调

在这里插入图片描述

import os
import torch
import torchvision
from torch import nn
import liliPytorch as lp
import matplotlib.pyplot as plt
from d2l import torch as d2l

# 获取数据集
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip',
                         'fba480ffa8aa7e0febbb511d181409f899b9baa5')

data_dir = d2l.download_extract('hotdog')
#Downloading ../data\hotdog.zip from http://d2l-data.s3-accelerate.amazonaws.com/hotdog.zip...

# 分别读取训练和测试数据集中的所有图像文件
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))
# ImageFolder 会递归地读取指定目录下的所有图像文件。
# print(train_imgs.classes)#一个类名列表 # ['hotdog', 'not-hotdog']
# print(train_imgs.class_to_idx) # 一个字典,类名映射到类索引 # {'hotdog': 0, 'not-hotdog': 1}
# print(train_imgs.imgs) # 一个包含所有图像路径和对应类索引的列表
# 例如:[('../data\\hotdog\\train\\hotdog\\0.png', 0), ('../data\\hotdog\\train\\hotdog\\1.png', 0)
#       , ('../data\\hotdog\\train\\not-hotdog\\999.png', 1)]
# 显示了前8个正类样本图片和最后8张负类样本图片

# hotdogs = [train_imgs[i][0] for i in range(8)] #train_imgs[i] 返回一个元组 (image, label),
# # 其中 image 是图像张量,label 是对应的标签。[0] 只提取图像张量。

# not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)] # 索引从 -1 到 -8

# d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4)
# plt.show() # 显示图片

# 使用RGB通道的均值和标准差,以标准化每个通道
normalize = torchvision.transforms.Normalize(
    [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

train_augs = torchvision.transforms.Compose([
    #从图像中裁切随机大小和随机长宽比的区域,然后将该区域缩放为224 * 224
    torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    normalize])

test_augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize([256, 256]),
    torchvision.transforms.CenterCrop(224), # 裁剪中央224 * 224
    torchvision.transforms.ToTensor(),
    normalize])

# 定义和初始化模型
# 使用在ImageNet数据集上预训练的ResNet-18作为源模型
pretrained_net = torchvision.models.resnet18(pretrained=True)

# 源模型实例包含许多特征层和一个输出层fc
print(pretrained_net.fc)
# Linear(in_features=512, out_features=1000, bias=True)

finetune_net = pretrained_net
# 改变输出层fc
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
# 参数初始化
nn.init.xavier_uniform_(finetune_net.fc.weight)


def train_batch_ch13(net, X, y, loss, trainer, devices):
    """使用多GPU训练一个小批量数据。
    参数:
    net: 神经网络模型。
    X: 输入数据,张量或张量列表。
    y: 标签数据。
    loss: 损失函数。
    trainer: 优化器。
    devices: GPU设备列表。
    返回:
    train_loss_sum: 当前批次的训练损失和。
    train_acc_sum: 当前批次的训练准确度和。
    """
    # 如果输入数据X是列表类型
    if isinstance(X, list):
        # 将列表中的每个张量移动到第一个GPU设备
        X = [x.to(devices[0]) for x in X]
    else:
        X = X.to(devices[0])# 如果X不是列表,直接将X移动到第一个GPU设备
    y = y.to(devices[0])# 将标签数据y移动到第一个GPU设备
    net.train() # 设置网络为训练模式
    trainer.zero_grad()# 梯度清零
    pred = net(X) # 前向传播,计算预测值
    l = loss(pred, y) # 计算损失
    l.sum().backward()# 反向传播,计算梯度
    trainer.step() # 更新模型参数
    train_loss_sum = l.sum()# 计算当前批次的总损失
    train_acc_sum = d2l.accuracy(pred, y)# 计算当前批次的总准确度
    return train_loss_sum, train_acc_sum# 返回训练损失和与准确度和


def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
               devices=d2l.try_all_gpus()):
    """训练模型在多GPU
    参数:
    net: 神经网络模型。
    train_iter: 训练数据集的迭代器。
    test_iter: 测试数据集的迭代器。
    loss: 损失函数。
    trainer: 优化器。
    num_epochs: 训练的轮数。
    devices: GPU设备列表,默认使用所有可用的GPU。
    """
    # 初始化计时器和训练批次数
    timer, num_batches = d2l.Timer(), len(train_iter)
    # 初始化动画器,用于实时绘制训练和测试指标
    animator = lp.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                           legend=['train loss', 'train acc', 'test acc'])
    # 将模型封装成 DataParallel 模式以支持多GPU训练,并将其移动到第一个GPU设备
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    # 训练循环,遍历每个epoch
    for epoch in range(num_epochs):
        # 初始化指标累加器,metric[0]表示总损失,metric[1]表示总准确度,
        # metric[2]表示样本数量,metric[3]表示标签数量
        metric = lp.Accumulator(4)
        # 遍历训练数据集
        for i, (features, labels) in enumerate(train_iter):
            timer.start()  # 开始计时
            # 训练一个小批量数据,并获取损失和准确度
            l, acc = train_batch_ch13(net, features, labels, loss, trainer, devices)
            metric.add(l, acc, labels.shape[0], labels.numel())   # 更新指标累加器
            timer.stop()  # 停止计时
            # 每训练完五分之一的批次或者是最后一个批次时,更新动画器
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[2], metric[1] / metric[3], None))
        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter) # 在测试数据集上评估模型准确度
        animator.add(epoch + 1, (None, None, test_acc))# 更新动画器
    # 打印最终的训练损失、训练准确度和测试准确度
    print(f'loss {metric[0] / metric[2]:.3f}, train acc '
          f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')
    # 打印每秒处理的样本数和使用的GPU设备信息
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on '
          f'{str(devices)}')


def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=5,param_group=True):
    """
    参数:
    net: 神经网络模型。
    learning_rate: 学习率。
    batch_size: 每个小批量的大小,默认为128。
    num_epochs: 训练的轮数,默认为5。
    param_group: 是否对不同层使用不同的学习率,默认为True。
    """
    train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'train'), transform=train_augs),
        batch_size=batch_size, shuffle=True)  # 创建训练数据集的迭代器
    
    test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
        os.path.join(data_dir, 'test'), transform=test_augs),
        batch_size=batch_size)  # 创建测试数据集的迭代器

    devices = d2l.try_all_gpus()  # 获取所有可用的GPU设备
    loss = nn.CrossEntropyLoss(reduction="none")   # 定义损失函数
    # 如果使用参数组
    if param_group:
        # 获取除最后全连接层外的所有参数
        # 列表params_1x,包含除最后一层全连接层外的所有参数。
        params_1x = [param for name, param in net.named_parameters()
                     if name not in ["fc.weight", "fc.bias"]]
        # 定义优化器,分别为不同的参数组设置不同的学习率
        trainer = torch.optim.SGD([{'params': params_1x},
                                   {'params': net.fc.parameters(),
                                    'lr': learning_rate * 10}],
                                  lr=learning_rate, weight_decay=0.001)
    else:
        # 如果不使用参数组,为所有参数设置相同的学习率
        trainer = torch.optim.SGD(net.parameters(), lr=learning_rate,
                                  weight_decay=0.001)
    # 调用训练函数,开始训练
    train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)
    
train_fine_tuning(finetune_net, 5e-5)
# loss 0.211, train acc 0.927, test acc 0.938
# 456.7 examples/sec on [device(type='cuda', index=0)]


"""
为了进行比较,我们定义了一个相同的模型,但是将其所有模型参数初始化为随机值。
由于整个模型需要从头开始训练,因此我们需要使用更大的学习率。
"""
scratch_net = torchvision.models.resnet18()
scratch_net.fc = nn.Linear(scratch_net.fc.in_features, 2)
train_fine_tuning(scratch_net, 5e-4, param_group=False)
# loss 0.338, train acc 0.842, test acc 0.859
# 457.7 examples/sec on [device(type='cuda', index=0)]

plt.show() #显示图片 

预训练resnet18模型运行效果:

在这里插入图片描述

初始化resnet18模型运行效果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/737298.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

手撕RPC——前言

手撕RPC——前言 一、RPC是什么?二、为什么会出现RPC三、RPC的原理3.1 RPC是如何做到透明化远程服务调用?3.2 如何实现传输消息的编解码? 一、RPC是什么? RPC(Remote Procedure Call,远程过程调用&#xff…

RealityCheck™电机监测和预测性维护模型

RealityCheck™电机 一个附加的软件工具箱,可实现条件监测和预测性维护功能,而无需依赖额外的传感器。相反,它使用来自电机控制过程的电子信息作为振动和其他传感器的代理。凭借其先进的信号处理和机器学习(ML)模型,RealityCheck …

示例:推荐一个应用Adorner做的表单对话框

一、目的:开发过程中经常会修改和查看一个Model的数据,一般情况下会自定义一个控件或Window去显示Model数据,但这种数据如果比较多会增加很多开发工作,本文介绍一种通用的方式,应用表达Form控件去简化处理,…

ARM裸机:基础了解

ARM的几种版本号 ARM内核版本号 ARMv7 ARM SoC版本号 Cortex-A8 芯片型号 S5PV210 ARM型号的发展历程 m microcontroller微控制器 就是单片机 a application应用级处理器 就是手机、平板、电脑的CPU r realtime实时处理器 响应速度快,主要用在工业、航天等领域 soc 、cpu、…

Elasticsearch:智能 RAG,获取周围分块(二)

在之前的文章 “Elasticsearch:智能 RAG,获取周围分块(一) ” 里,它介绍了如何实现智能 RAG,获取周围分块。在那个文章里有一个 notebook。为了方便在本地部署的开发者能够顺利的运行那里的 notebook。在本…

如何在 Mac 上清空硬盘后恢复丢失的数据?

如果您不小心从 Mac 硬盘上删除了重要文件,您可能会感到非常沮丧。但您仍然可以找回丢失的信息。将 Mac 想象成一个大盒子,里面装着所有东西。丢弃某样东西就像撕掉盒子上的标签:房间现在可以放新东西了,但旧东西仍然在那里&#…

文华财经T8自动化交易程序策略模型指标公式源码

文华财经T8自动化交易程序策略模型指标公式源码: //定义变量 //资金管理与仓位控制 8CS:INITMONEY;//初始资金 8QY:MONEYTOT;//实际权益 8QY1:MIN(MA(8QY,5*R),MA(8QY,2*R)); FXBL:N1; DBKS:8QY1*N1;//计算单笔允许亏损额度 BZDKS:MAX(AA-BB,N*1T)*UNIT; SZDKS:MAX…

已解决ApplicationException异常的正确解决方法,亲测有效!!!

已解决ApplicationException异常的正确解决方法,亲测有效!!! 目录 问题分析 出现问题的场景 报错原因 解决思路 解决方法 分析错误日志 检查业务逻辑 验证输入数据 确认服务器端资源的可用性 增加对特殊业务情况的处理…

能正常执行但是 cion 标红/没有字段提示

ctrl q 退出 clion 找到工程根目录,删除隐藏文件 .idea 再重新打开 clion 标红消失,同时再次输入函数/类属性,出现字段提示 clion 的智能提示方案存储在 .idea 文件中,如果工程能够正常编译执行,那么说明是智能提示…

InfoMasker :新型反窃听系统,保护语音隐私

随着智能手机、智能音箱等设备的普及,人们越来越担心自己的谈话内容被窃听。由于这些设备通常是黑盒的,攻击者可能利用、篡改或配置这些设备进行窃听。借助自动语音识别 (ASR) 系统,攻击者可以从窃听的录音中提取受害者的个人信息&#xff0c…

如何搭建饥荒服务器

《饥荒》是由Klei Entertainment开发的一款动作冒险类求生游戏,于2013年4月23日在PC上发行,2015年7月9日在iOS发布口袋版。游戏讲述的是关于一名科学家被恶魔传送到了一个神秘的世界,玩家将在这个异世界生存并逃出这个异世界的故事。《饥荒》…

华为数通——ACL

ACL基本介绍 ACL:访问控制列表,通过端口对数据流进行过滤,ACL判别依据是五元组:源IP地址,源端口,目的IP地址,目的端口、协议。(ACL工作于OSI模型第三层,是路由器和三层交换机接口的…

2.超声波测距模块

1.简介 2.超声波的时序图 3.基于51单片机实现的代码 #include "reg52.h" #include "intrins.h" sbit led1P3^7;//小于10,led1亮,led2灭 sbit led2P3^6;//否则,led1灭,led2亮 sbit trigP1^5; sbit echo…

基于51单片机抽奖系统

基于51单片机抽奖系统 (仿真+程序) 功能介绍 具体功能: 1.利用5片74HC495对单片机的IO进行串并转换,进而控制5个1位数码管; 2.采用一个独立按键用于抽奖系统的启停控制; 3.8位拨码开关是用…

地推利器Xinstall:全方位二维码统计,打造高效地推策略,轻松掌握市场脉搏!

在移动互联网时代,地推作为一种传统的推广方式,依然占据着重要的地位。然而,随着市场竞争的加剧,地推也面临着诸多挑战,如如何有效监测下载来源、解决填码和人工登记的繁琐、避免重复打包和iOS限制、以及如何准确考核推…

Linux基础二

目录 一,tail查看文件尾部指令 二,date显示日期指令 三,cal查看日历指令 四,find搜索指令 五,grep 查找指令 六,> 和>> 重定向输出指令 七, | 管道指令 八,&&逻辑控…

让你的Python代码更简洁:一篇文章带你了解Python列表推导式

文章目录 📖 介绍 📖🏡 演示环境 🏡📒 列表推导式 📒📝 语法📝 条件筛选📝 多重循环📝 列表推导式的优点📝 使用场景📝 示例代码🎯 示例1🎯 示例2⚓️ 相关链接 ⚓️📖 介绍 📖 在Python编程中,列表推导式是一种强大且高效的语法,它允许你用…

江协科技51单片机学习- p14 调试LCD1602显示屏

前言: 本文是根据哔哩哔哩网站上“江协科技51单片机”视频的学习笔记,在这里会记录下江协科技51单片机开发板的配套视频教程所作的实验和学习笔记内容。本文大量引用了江协科技51单片机教学视频和链接中的内容。 引用: 51单片机入门教程-2…

YouTube API接口:一键获取Playlist视频合集信息

核心功能介绍 在视频内容日益繁荣的今天,YouTube作为全球领先的视频分享平台,为内容创作者、品牌商家以及数据分析师提供了丰富的视频资源。其中,Playlist视频合集作为YouTube上的一种特色内容形式,深受用户喜爱。为了更好地满足…

cpolar:通过脚本自动更新主机名称和端口号进行内网穿透【免费版】

cpolar 的免费版经常会重新分配 HostName 和 Port,总是手动修改太过麻烦,分享一下自动更新配置文件并进行内网穿透的方法。 文章目录 配置 ssh config编写脚本获取 csrf_token打开登陆界面SafariChrome 设置别名 假设你已经配置好了服务器端的 cpolar。 …