【深度学习 pytorch】迁移学习 (迁移ResNet18)

李宏毅深度学习笔记
《深度学习原理Pytorch实战》
https://blog.csdn.net/peter6768/article/details/135712687

迁移学习

实际应用中很多任务的数据的标注成本很高,无法获得充足的训练数据,这种情况可以使用迁移学习(transfer learning)。假设A、B是两个相关的任务,A任务有很多训练数据,就可以把从A任务中学习到的某些可以泛化知识迁移到B任务。

有了迁移学习技术,我们便可以将神经网络像软件模块一样进行拼装和重复利用。例如,我们可以将在大数据集上训练好的大型网络迁移到小数据集上,从而只需经过少量的训练就能达到良好的效果。我们也可以将两个神经网络同时迁移过来,组合成一个新的网络,这两个神经网络就像软件模块一样被组合了起来。

监督学习要求训练集和测试集上的数据具有相同的分布特性。两个数据集具有相同的分布就意味着,两者中猫和狗的比例(甚至胖猫、瘦猫和胖狗、瘦狗的分布比例)大体相同,这样才能保证模型在训练集中学习到稳定的特征,从而应用到测试集中。
而迁移学习则不同,它允许训练集和测试集的数据有不同的分布、目标,甚至领域。

迁移学习方式:

  • 预训练模式:将迁移过来的权重视作新网络的初始权重,但是它在训练的过程中会被梯度下降算法改变数值。
  • 固定值模式:迁移过来的部分网络在结构和权重上都保持固定的数值,训练过程仅针对迁移模块后面的全连接网络。

贫困预测场景

一个地区的遥感图像大体能够反映该地区的贫困状况,因为贫困地区的街道布置往往更加混乱。但是,要想训练一个深度卷积神经网络来预测贫困地区,除了需要大量输入图像,还需要对每一张图像进行贫困程度的标注。
由于非洲贫困地区可获得的贫困数据非常少,仅有大概600多个数据点,这对于训练一个大型的卷积神经网络来说远远不够(对于一个8层的卷积神经网络来说,训练数据的量级至少要达到数百万)。所以,我们需要对遥感图像进行手工标注。然而,这种标注工作量巨大,简直就是一个不可能完成的任务。

应用迁移学习技术,解决标注数据缺失问题的方法。首先,训练一个卷积神经网络,用遥感图像来预测夜光亮度;然后,将训练好的网络迁移到运用遥感图像预测贫困地区的任务中,这样即使训练数据仅有几百个,我们照样可以进行更准确的预测。
在这里插入图片描述

1、首先,他们使用预训练的方法,将用于图像分类的大型卷积神经网络VGGF(包含8个卷积层)迁移过来,作为初始分类网络。物体分类网络VGGF是经过干万张图像训练而成的,它已经学会了如何对图像进行特征提取,例如提取物体的边缘等。
2、其次,在预训练好的VGGF网络上,应用卫星遥感影像数据和夜光影像数据对其进行训练。该模型的输入为某地区的卫星遥感图像,输出为该地区夜间明暗程度的预测。由于夜光数据很容易获得,因此将它作为标签,可以轻松获得数十万个成对的训练数据。另外,当卷积神经网络尝试预测夜光时,它需要学会有效地从卫星遥感图像中提取特定的特征,例如街道、房屋屋顶、混凝土建筑等。这样学到的网络就是一个能从卫星遥感图像中有效提炼特征的特征提取器。
3、然后,我们将用于预测夜光的神经网络的卷积层迁移过来,拼接一个新的全连接网络,用于预测一个地区的贫困程度。在这一部分,我们将采取固定值迁移方法,仅训练全连接网部分。这样便可以应用仅有的数百个贫困数据来训练这个预测器。在这一步,我们相当于在原始图像中提取有关特征,据此预测贫困程度。

分类问题(蚂蚁还是蜜蜂)

一是这些图像极其复杂,人类肉眼都不太容易一下子区分出画面中是蚂蚁还是蜜蜂,简单的卷积神经网络无法应付这个分类任务
二是整个训练集仅有244个样本,这么小的数据量无法训川练大的卷积神经网络。

我们把ResNet18中的卷积模块作为特征提取层迁移过来,用于提取局部特征。同时,构建一个包含512个隐含节点的全连接层,后接两个节点的输出层,用于最后的分类输出,最终构建一个包含20层的深度网络。

数据加载

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as pyplot
import time
import copy
import os
 
data_path = 'data'
image_size = 224


class TranNet():
    def __init__(self):
        super(TranNet, self).__init__()
        
#加载的过程将会对图像进行如下增强操作:
#1.随机从原始图像中切下来一块224×224大小的区域
#2.随机水平翻转图像
#3.将图像的色彩数值标准化 
        self.train_dataset = datasets.ImageFolder(os.path.join(data_path, 'train'), transforms.Compose([
            transforms.RandomSizedCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]))
#加载校验数据集,对每个加载的数据进行如下处理:
#1.放大到256×256
#2.从中心区域切割下224×224大小的区域
#3.将图像的色彩数值标准化
        self.verify_dataset = datasets.ImageFolder(os.path.join(data_path, 'verify'), transforms.Compose([
            transforms.Scale(256),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]))`在这里插入代码片`
#创建相应的数据加载器
        self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=4, shuffle=True, num_workers=4)
        self.verify_loader = torch.utils.data.DataLoader(self.verify_dataset, batch_size=4, shuffle=True, num_workers=4)
        self.num_classes = len(self.train_dataset.classes)
net = models.resnet18(pretrained=True)

可以使用预训练的方式将这个网络迁移过来:

        net = models.resnet18(pretrained=True)
        num_features = net.fc.in_features
        net.fc = nn.Linear(num_features, 2)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)

其中,num_features存储了ResNet18最后的全连接层的输入神经元个数。事实上,以上代码所做的就是将原来的ResNet18最后两层全连接层替换成一个输出单元为2的全连接层,这就是net.fc。之后,我们按照普通的方法定义损失函数和优化器。因此,这个模型首先会利用ResNet预训练好的权重,提取输入图像中的重要特征,然后利用net.fc这个线性层,根据输入特征进行分类。

当使用固定值的方式进行迁移的时候,可以使用下列代码:

net = models.resnet18(pretrained=True)
for param in net.parameters():
	param.requires_grad=False
num_features = net.fc.in_features
net.fc = nn.Linear(num_features, 2)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)

gpu加速

use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
itype = torch.cuda.LongTensor if use_cuda else torch.LongTensor
net = models.resnet18(pretrained=True)
#如果存在GPU,就将网络加载到GPU上
net = net.cuda() if use_cuda else net

将训练数据加载到gpu上

#将数据复制出来,然后加载到GPU上
data,target=data.clone().detach().requires_grad(True),target.clone().detach()
if use_cuda:
  data,target=data.cuda(),target.cuda()

最后,我们使用.cpu()将GPU上的计算结果再次转回内存中:

#待计算完成后,需将数据放回CPU
loss=loss.cpu() if use_cuda else loss

训练

    def model_prepare(self):
        net = models.resnet18(pretrained=True)
        
        # jusge whether GPU
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
            itype = torch.cuda.LongTensor if use_cuda else torch.LongTensor
            net = net.cuda() if use_cuda else net
        
        # float net values
        num_features = net.fc.in_features
        net.fc = nn.Linear(num_features, 2)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)
        
        # fixed net values
        '''
        for param in net.parameters():
            param.requires_grad = False
        num_features = net.fc.in_features
        net.fc = nn.Linear(num_features, 2)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.fc.parameters(), lr = 0.001, momentum=0.9)
        '''
        record = []
        num_epochs = 20
        net.train(True) # open dropout  给网络模型做标记,说明模型在训练集上训练
        for epoch in range(num_epochs):
            train_rights = []
            train_losses = []
            for batch_index, (data, target) in enumerate(self.train_loader):
                data, target = data.clone().detach().requires_grad_(True), target.clone().detach()
                output = net(data) #完成一次预测
                loss = criterion(output, target)  #计算误差
                optimizer.zero_grad() #清空梯度
                loss.backward()  # 反向传播
                optimizer.step() #随机梯度下降
                right = rightness(output, target)  #计算准确率所需数值,返回数值为(正确样例数,总样本数)
                train_rights.append(right)
                train_losses.append(loss.data.numpy())
                if batch_index % 400 == 0:
                    verify_rights = []
                    for index, (data_v, target_v) in enumerate(self.verify_loader):
                        data_v, target_v = data_v.clone().detach(), target_v.clone().detach()
                        output_v = net(data_v)
                        right = rightness(output_v, target_v)
                        verify_rights.append(right)
                    verify_accu = sum([row[0] for row in verify_rights]) / sum([row[1] for row in verify_rights])
                    record.append((verify_accu))
                    print(f'verify data accu:{verify_accu}')
        # plot
        pyplot.figure(figsize=(8, 6))
        pyplot.plot(record)
        pyplot.xlabel('step')
        pyplot.ylabel('verify loss')
        pyplot.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/797086.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Dify中高质量索引模式时,通过线程池处理chunk过程

本文主要介绍了Dify中高质量索引模式时,如何通过线程池执行器来处理chunk的过程。源码位置:dify\api\core\indexing_runner.py\IndexingRunner._load。核心思想:假设一个数据集中有一个文档,该文档可以拆分为12个段(segment)。如果chunk_size=10,那么分为2批提交给线程池…

C语言——流程控制:if...else、switch...case

控制类语句: 逻辑运算符: 选择语句: if...else: if()括号内的内容终究会被转换成0,1,满足的话即为1,不满足的话为0。因此要注意,()括号内因为条件…

智慧城市3d数据可视化系统提升信息汇报的时效和精准度

在信息大爆炸的时代,数据的力量无可估量。而如何将这些数据以直观、高效的方式呈现出来,成为了一个亟待解决的问题。为此,我们推出了全新的3D可视化数据大屏系统,让数据“跃然屏上”,助力您洞察先机,决胜千…

【图解大数据技术】流式计算:Spark Streaming、Flink

【图解大数据技术】流式计算:Spark Streaming、Flink 批处理 VS 流式计算Spark StreamingFlinkFlink简介Flink入门案例Streaming Dataflow Flink架构Flink任务调度与执行task slot 和 task EventTime、Windows、WatermarksEventTimeWindowsWatermarks 批处理 VS 流式…

【我的机器学习之旅】打造个性化手写汉字识别系统

在当今这个数字化飞速发展的时代,人工智能(AI)已经渗透到我们生活的每一个角落,从智能家居到自动驾驶,无一不彰显着其强大的潜力和无限可能。作为一名计算机科学与技术专业的学生,我的毕业设计选择了一个既…

【java】力扣 合并k个升序链表

文章目录 题目链接题目描述思路代码 题目链接 23.合并k个升序链表 题目描述 给你一个链表数组,每个链表都已经按升序排列。 请你将所有链表合并到一个升序链表中,返回合并后的链表 思路 我在这个题里面用到了PriorityQueue(优先队列) 的知识 Prio…

网络安全----防御----防火墙nat以及智能选路

前面要求在前一篇博客 网络安全----防御----防火墙安全策略组网-CSDN博客 7,办公区设备可以通过电信链路和移动链路上网(多对多的NAT,并且需要保留一个公网IP不能用来转换) 8,分公司设备可以通过总公司的移动链路和电信链路访问到Dmz区的ht…

本人学习保存-macOS打开Navicat提示「“Navicat Premium”已损坏,无法打开。 你应该将它移到废纸篓。」的解决方法

新安装了macOS Ventura,打开Navicat Premium,发现会提示: “Navicat Premium”已损坏,无法打开。 你应该将它移到废纸篓。 遇到这种情况,千万别直接移到废纸篓,是有办法解决的。在这里记录一下解决方案。 …

据传 OpenAI秘密研发“Strawberry”项目

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

2、ASPX、.NAT(环境/框架)安全

ASPX、.NAT&#xff08;环境/框架&#xff09;安全 源自小迪安全b站公开课 1、搭建组合&#xff1a; WindowsIISaspxsqlserver .NAT基于windows C开发的框架/环境 对抗Java xx.dll <> xx.jar 关键源码封装在dll文件内。 2、.NAT配置调试-信息泄露 功能点&#xf…

阿里ChatSDK使用,开箱即用聊天框

介绍&#xff1a; 效果&#xff1a;智能助理 ChatSDK&#xff0c;是在ChatUI的基础上&#xff0c;结合阿里云智能客服的最佳实践&#xff0c;沉淀和总结出来的一个开箱即用的&#xff0c;可快速搭建智能对话机器人的框架。它简单易上手&#xff0c;通过简单的配置就能搭建出对…

C++ 入门基础:开启编程之旅

引言 C 是一种高效、灵活且功能强大的编程语言&#xff0c;广泛应用于系统软件、游戏开发、嵌入式系统、科学计算等多个领域。作为 C 语言的扩展&#xff0c;C 不仅继承了 C 语言的过程化编程特性&#xff0c;还增加了面向对象编程&#xff08;OOP&#xff09;的支持&#xff…

frp内网穿透ssh,tcp经过服务器慢速和p2p模式实现高速吃满上传带宽

ssh_server aliyun_server ssh_client 办公室 云服务器 家 在家里经过云服务器中转&#xff0c;很慢&#xff0c;但是很稳定 使用p2p穿透&#xff0c;速度可以直接拉满 ssh_server cc.ini # 连接服务器配置 [common] server_addr 1…

实用机器学习(快速入门)

前言 因为需要机器学习的助力&#xff0c;所以&#xff08;浅浅&#xff09;进修了一下。现在什么东西和AI结合一下感觉就好发文章了&#xff1b;我看了好多学习视频&#xff0c;发现机器学习实际上是数学&#xff0c;并不是常规的去学习代码什么的&#xff08;虽然代码也很简…

【学习笔记】无人机(UAV)在3GPP系统中的增强支持(五)-同时支持无人机和eMBB用户数据传输的用例

引言 本文是3GPP TR 22.829 V17.1.0技术报告&#xff0c;专注于无人机&#xff08;UAV&#xff09;在3GPP系统中的增强支持。文章提出了多个无人机应用场景&#xff0c;分析了相应的能力要求&#xff0c;并建议了新的服务级别要求和关键性能指标&#xff08;KPIs&#xff09;。…

昇思25天学习打卡营第13天 | ResNet50迁移学习

昇思25天学习打卡营第13天 | ResNet50迁移学习 文章目录 昇思25天学习打卡营第13天 | ResNet50迁移学习数据集加载数据集数据集可视化 模型训练固定特征 总结打卡 在实际应用场景中&#xff0c;由于训练数据集不足&#xff0c;很少从头开始训练整个网络。普遍做法是在一个非常大…

Qt下使用OpenCV的鼠标回调函数进行圆形/矩形/多边形的绘制

文章目录 前言一、设置imshow显示窗口二、绘制圆形三、绘制矩形四、绘制多边形五、示例完整代码总结 前言 本文主要讲述了在Qt下使用OpenCV的鼠标回调在OpenCV的namedWindow和imshow函数显示出来的界面上进行一些图形的绘制&#xff0c;并最终将绘制好的图形显示在QLabel上。示…

算法笔记——LCR

一.LCR 152. 验证二叉搜索树的后序遍历序列 题目描述&#xff1a; 给你一个二叉搜索树的后续遍历序列&#xff0c;让你判断该序列是否合法。 解题思路&#xff1a; 根据二叉搜索树的特性&#xff0c;二叉树搜索的每一个结点&#xff0c;大于左子树&#xff0c;小于右子树。…

到底哪些牌子的鼠标好?选择鼠标需要注意哪些问题?

鼠标的选择从外观材质、手感、配置到价格定位都不尽相同&#xff0c;消费者的选择也越来越多。一般在选择鼠标时&#xff0c;我们也会发现鼠标能够选择的品牌虽然众多&#xff0c;但是不同品牌下的鼠标在品质和款式上都是大不相同的&#xff0c;那么到底哪些牌子的鼠标好呢?我…

【Python】数据分析-Matplotlib绘图

数据分析 Jupyter Notebook Jupyter Notebook: 一款用于编程、文档、笔记和展示的软件。 启动命令&#xff1a; jupyter notebookMatplotlib 设置中文格式&#xff1a;plt.rcParams[font.sans-serif] [KaiTi] # 查看本地所有字体 import matplotlib.font_manager a sorted…