第八章 模型篇:transfer learning for computer vision

参考教程:
transfer-learning
transfer-learning tutorial

文章目录

  • transfer learning
    • 对卷积网络进行finetune
    • 把卷积网络作为特征提取器
    • 何时、如何进行fine tune
  • 代码示例
    • 加载数据集
    • 构建模型
      • fine-tune 模型
      • 模型作为feature extractor
    • 定义train_loop和test_loop
    • 定义超参数,开始训练
    • 结果可视化

transfer learning

很少会有人从头开始训练一个卷积神经网络,因为并不是所有人都有机会接触到大量的数据。常用的选择是在一个非常大的模型上预训练一个模型,然后用这个模型为基础,或者固定它的参数用作特征提取,来完成特定的任务。

对卷积网络进行finetune

进行transfer-learning的一个方法是在基于大数据训练的模型上进行fine-tune。可以选择对模型的每一个层都进行fine-tune,也可以选择freeze特定的层(一般是比较浅的层)而只对模型的较深的层进行fine-tune。理论支持是,模型的浅层通常是一些通用的特征,比如edge或者colo blob,这些特征可以应用于多种类型的任务,而高层的特征则会更倾向于用于训练的原始数据集中的数据特点,因为不太能泛化到新数据上去。

把卷积网络作为特征提取器

将ConvNet作为一个特征提取器,通常是去掉它最后一个用于分类的全连接层,把剩余的层用来提取新数据的特征。你可以在该特征提取器后加上你自己的head,比如分类head或者回归head,用于完成你自己的任务。

何时、如何进行fine tune

使用哪种方法有多种因素决定,最主要的因素是你的新数据集的大小和它与原始数据集的相似度。

  • 当你的新数据集很小,并和原始数据集比较相似时。
    因为你的数据集很小,所以从过拟合的角度出发,不推荐在卷积网络上进行fine-tune。又因为你的数据和原始数据比较相似,所以卷积网络提取的高层特征和你的数据也是相关的。因此你可以直接卷积网络当作特征提取器,在此基础上训练一个线性分类器。
  • 当你的新数据集很大,并和原始数据集比较相似时。
    新数据集很大时,我们可以对整个网络进行fine-tune,因为我们不太会有过拟合的风险。
  • 当你的新数据集很小,并和原始数据集不太相似时。
    因为你的数据集很小,我们还是推荐只训练一个线性的分类器。但是新数据和原始数据又不相似,所以不建议在网络顶端接上新的分类器,因为网络顶端包含很多的dataset-specific的特征,所以更推荐的是从浅层网络的一个位置出发接上一个分类器。
  • 当你的新数据集很大,并和原始数据集不太相似时。
    因为你的数据集很大,我们仍然选择对整个网络进行fine-tune。因为通常情况下以一个pretrained-model对模型进行初始化的效果比随机初始化要好。

代码示例

我们使用与第四章 模型篇:模型训练与示例一样的流程进行模型训练。

加载数据集

首先是加载数据集,方便起见我们直接使用torchvision中的cifar10数据进行训练。

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

training_data = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=transform
)


test_data = datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

train_dataloader = DataLoader(training_data, batch_size = 64)
test_dataloader = DataLoader(test_data, batch_size = 64)

使用官方提供的代码对我们的dataset进行可视化,注意训练时使用的batchsize为64,这里可视化时为了方便暂时使用了batchsize=4。
在这里插入图片描述

构建模型

在第四章中我们用了自定义的model。在这里我们使用预训练好的模型,并对模型结构进行修改。

transfer-learning对模型的处理有两种,一种是fine-tune整个模型,一种是将模型作为feature-extractor。第二种和第一种的区别是,模型中的部分层被freeze,不在训练过程中更新。

fine-tune 模型

model_ft = models.resnet18(weights = 'IMAGENET1K_V1')
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 10) # 因为cifar10是十分类,所以输出这里为10

模型作为feature extractor

model_conv = torchvision.models.resnet18(weights='IMAGENET1K_V1')
for param in model_conv.parameters():
    param.requires_grad = False  # requires_grad 设为False,不随训练更新

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 10)

定义train_loop和test_loop

这两个部分直接参考第四章的代码就可以,复制过来直接使用。

# 训练过程的每个epoch的操作,代码来自pytorch_tutorial
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    # Set the model to training mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        optimizer.zero_grad() # 重置梯度计算
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward() # 反向传播计算梯度
        optimizer.step() # 调整模型参数
        

        if batch % 10 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

定义超参数,开始训练

全都准备好以后,我们定义一下要使用的优化器和loss,和一些别的超参数,就可以开始训练了。

learning_rate = 1e-3
momentum=0.9
epochs = 20

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate,momentum=momentum)

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model_ft, loss_fn, optimizer)
    test_loop(test_dataloader, model_ft, loss_fn)
print("Done!")

因为是在个人pc跑的,所以就随便放一个效果。。。。。
在这里插入图片描述

结果可视化

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/31840.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Docker常见使用

Docker常见使用 1、Docker安装 ## 下载阿里源repo文件 $ curl -o /etc/yum.repos.d/Centos-7.repo http://mirrors.aliyun.com/repo/Centos-7.repo $ curl -o /etc/yum.repos.d/docker-ce.repo http://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.repo$ yum clean …

工业相机的镜头如何选择?

相机的镜头如何计算,如果看公式的话,需要知道相机sensor的尺寸,相元的尺寸,计算起来数据也比较复杂,下面教大家一个简单的方法,就是如何借助镜头计算工具来使用。 巴斯勒相机的镜头选型地址 工业镜头选型…

操作系统之死锁详解

本文已收录于专栏 《自考》 目录 背景介绍死锁的前提死锁的概念死锁的分类死锁的产生原因条件 死锁的解决预防避免检测与恢复 死锁的实现总结提升 背景介绍 最近一直在做操作系统的测试题,在做题的过程中发现有很多地方涉及到了关于死锁的知识点。今天就回归课本来自…

哈工大计算机网络课程网络层协议详解之:网络地址转换NAT

哈工大计算机网络课程网络层协议详解之:网络地址转换NAT 文章目录 哈工大计算机网络课程网络层协议详解之:网络地址转换NAT网络地址转换(NAT)NAT实现原理NAT穿透问题NAT穿透问题的解决方案 上一节中,我们在DHCP协议中介…

【人脸检测——基于机器学习4】HOG特征

前言 HOG特征的全称是Histograms of Oriented Gradients,基于HOG特征的人脸识别算法主要包括HOG特征提取和目标检测,该算法的流程图如下图所示。本文主要讲HOG特征提取。 HOG特征的组成 Cell:将一幅图片划分为若干个cell(如上图绿色框所示),每个cell为8*8像素 Block:选…

【力扣刷题 | 第十四天】

目录 前言: 7. 整数反转 - 力扣(LeetCode) 面试题 16.05. 阶乘尾数 - 力扣(LeetCode) 总结; 前言: 今天仍然是无固定类型刷题, 7. 整数反转 - 力扣(LeetCode) 给你…

Android大图加载优化方案,避免程序OOM

我们在编写Android程序的时候经常要用到许多图片,不同图片总是会有不同的形状、不同的大小,但在大多数情况下,这些图片都会大于我们程序所需要的大小。比如微博长图,海报等等。所以我们就要对图片进行局部显示。 大图加载基本需求…

Redis入门(5)-set

Redis中set的元素具有无序性与不可重复性 1.sadd key member[member] 添加元素,若元素存在返回0若不存在则添加 sadd DB mysql oracle sadd DB mysql sadd DB db22.smembers key 查看set中所有元素 smembers DB3.sismember key member 判断元素在set中是否存…

GIS 功能模块实现

文章目录 1. GIS 模块流程图2. 网页端地图缓存的实现3. GIS 图形操作功能实现1 )地图漫游2 )对象删除3 )选择复制属性查看 GIS 基本功能模块主要是在表现层开发的,是在OpenLayers 开发框架提供的接口上,通过Geo Server…

智芯MCU软件开发环境搭建

智芯MCU软件开发环境搭建 目录 智芯MCU软件开发环境搭建前言1 软件安装2 编译环境3 烧录环境4 新建工程结束语 前言 智芯科技的MCU主要应用于汽车行业,属于车规级的MCU,目前上市的MCU型号较少,相关资料也不多,所以这里出一期开发…

uniapp实现tab切换可以滚动的效果

实现效果 当 tab 切换的内容很多时,需要用到滚动,希望在点击 tab 的时候可以自动滑动到对应的tab下 知识点 scrollIntoView:该scrollIntoView()方法将调用它的元素滚动到浏览器窗口的可见区域。 语法 element.scrollIntoView&#xff08…

【kubernetes】部署controller-manager与kube-scheduler

前言:二进制部署kubernetes集群在企业应用中扮演着非常重要的角色。无论是集群升级,还是证书设置有效期都非常方便,也是从事云原生相关工作从入门到精通不得不迈过的坎。通过本系列文章,你将从虚拟机准备开始,到使用二进制方式从零到一搭建起安全稳定的高可用kubernetes集…

记录正式环境测试环境【RedHat7编译升级redis7.0.9】--有关报错及解决

记录正式环境&测试环境【RedHat7 编译升级redis7.0.9】--有关报错及解决 🔻 一、报错详情1.1 ⛳ 写在前面1.2 ⛳ 报错11.3 ⛳ 报错21.4 ⛳ 安装redis1.5 ⛳ 版本检查 🔻 二、⛳ 总结 🔻 一、报错详情 1.1 ⛳ 写在前面 🍁 升级…

王道计算机网络学习笔记(3)——数据链路层

前言 文章中的内容来自B站王道考研计算机网络课程,想要完整学习的可以到B站官方看完整版。 三:数据链路层 3.1:数据链路层功能概述 结点:主机、路由器 链路:网络中两个结点之间的物理通道,链路的传输介…

【DeepLearning】Ubuntu中深度学习环境配置完整流程

Ubuntu中深度学习环境配置完整流程 1 显卡驱动2 cuda3 cuDNN4 torch5 torchvision 1 显卡驱动 支持 cuda 的所有显卡型号: Link 查询显卡型号 lspci -nn | grep VGA即 Vendor ID:Device ID 为 10de:21c4,在浏览器或者 Link 中搜索。 填写显卡信息: Link 选择要下载…

数据结构——快速排序的介绍

快速排序 快速排序是霍尔(Hoare)于1962年提出的一种二叉树结构的交换排序方法。快速排序是一种常用的排序算法,其基本思想是通过选择一个元素作为"基准值",将待排序序列分割成两个子序列,其中一个子序列的元素都小于等于基准值&am…

SpringBoot集成WebSocket实现消息实时推送(提供Gitee源码)

前言:在最近的工作当中,客户反应需要实时接收消息提醒,这个功能虽然不大,但不过也用到了一些新的技术,于是我这边写一个关于我如何实现这个功能、编写、测试到部署服务器,归纳到这篇博客中进行总结。 目录 …

【计算机网络自顶向下】计算机网络期末自测题(一)

前言 “(学不懂一点) (阴暗的爬行)(尖叫)(扭曲)(阴暗的爬行)(尖叫)(扭曲)(阴暗的爬行)(尖叫&#…

LeetCode·1262. 可被三整除的最大和·贪心

作者:小迅 链接:https://leetcode.cn/problems/greatest-sum-divisible-by-three/solutions/2314049/tan-xin-zhu-shi-chao-ji-xiang-xi-by-xun-r0n76/ 来源:力扣(LeetCode) 著作权归作者所有。商业转载请联系作者获得…

vscode 调试

目录 准备 GDB 调试方法 问题 准备 然后点击 文件-打开文件夹,找到创建的代码路径,确定后,在左侧的资源管理器可以看到代码文件。 第一次运行需要安装 c 的扩展,在扩展页面中,安装 C/C 编译注意一定要加上 -g 指令…