pytorch集智-6手写数字加法机-迁移学习

1 概述

迁移学习概念:将已经训练好的识别某些信息的网络拿去经过训练识别另外不同类别的信息

优越性:提高了训练模型利用率,解决了数据缺失的问题(对于新的预测场景,不需要大量的数据,只需要少量数据即可实现训练,可用于数据点很少的场景)

如何实现:将训练好的一个网络拿来和另一个网络连起来去训练即可实现迁移

训练方式:按是否改变源网络参数可分两类,分别是可改变和不可改变

2 案例 南非贫困预测

2.1 背景

南非存在贫困,1990-2021贫困人口从56%下降到43%,但下降的贫困人口数量和国际人道主义援助资源并不对应,而且大量资金援助一定程度加剧了贫富差距。可以看下具体哪些地区需要援助

2.2 方法

一个方法:夜光光亮遥感数据和人类gdp相关性经实验可达0.8-0.9,但夜光遥感和贫富没太大相关性:夜间光照月亮表示该地区越富有,但越安并不表示该地区越贫穷,也可能无人居住。

另一个方法:光亮遥感数据无法准确预测地区贫穷程度,但卫星遥感数据大体可以做到,判定依据有街道混乱程度等。如果要用深度网络训练,还需要对卫星遥感数据的图片标注贫困程度。非洲能获取到的贫困数据很少,但深度网络需要的数据量很大

最终方法:用迁移学习,将前两种方法合起来,见下图

3 案例2

3.1 背景

任务:区分图像里动物是蚂蚁还是蜜蜂,像素均为224x224

难点:只有244个图像,样本太少不足训练大型卷积网络,准确率只有50%左右

3.2 解决方案

解决方案:resnet与模型迁移,即用已训练好的物体分类的网络加全连接用来区分蚂蚁与蜜蜂

resnet:残差网络,对物体分类有较高精度

3.3 代码实现

3.3.1 准备数据

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as pyplot
import time
import copy
import os

data_path = 'pytorch/jizhi/figure_plus/data'
image_size = 224

class TranNet():
    def __init__(self):
        super(TranNet, self).__init__()
        
        self.train_dataset = datasets.ImageFolder(os.path.join(data_path, 'train'), transforms.Compose([
            transforms.RandomSizedCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]))
        self.verify_dataset = datasets.ImageFolder(os.path.join(data_path, 'verify'), transforms.Compose([
            transforms.Scale(256),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]))
        self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=4, shuffle=True, num_workers=4)
        self.verify_loader = torch.utils.data.DataLoader(self.verify_dataset, batch_size=4, shuffle=True, num_workers=4)
        self.num_classes = len(self.train_dataset.classes)
    
    def exec(self):
        ...

def main():
    TranNet().exec()

if __name__ == '__main__':
    main()

3.3.2 模型迁移

    def exec(self):
        self.model_prepare()
        
    def model_prepare(self):
        net = models.resnet18(pretrained=True)
        
        # float net values
        num_features = net.fc.in_features
        net.fc = nn.Linear(num_features, 2)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)
        
        # fixed net values
        '''
        for param in net.parameters():
            param.requires_grad = False
        num_features = net.fc.in_features
        net.fc = nn.Linear(num_features, 2)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.fc.parameters(), lr = 0.001, momentum=0.9)
        '''

3.3.3 gpu加速

特点:gpu速度快,但内存低,所以尽量减少在gpu中存储的数据,只用来计算就好

    def model_prepare(self):
        # jusge whether GPU
        use_cuda = torch.cuda.is_available()
        dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
        itype = torch.cuda.LongTensor if use_cuda else torch.LongTensor
        
        net = models.resnet18(pretrained=True)
        net = net.cuda() if use_cuda else net

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/335778.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

STM32407用汇顶的GT911触摸芯片调试实盘

这个配置很关键 代码 #include "stm32f4xx.h" #include "GT9147.h" #include "Touch.h" #include "C_Touch_I2C.h" #include "usart.h" #include "delay.h" #include "LCD.h" #incl…

Java String基础学习

目录 1、String的构造方法 2、String内存模型 3、字符串的比较 4、字符串的练习 1、用户登录系统 2、遍历字符串 3、统计字符次数 4、拼接字符串 5、字符串的反转 6、金额转换 7、手机号屏蔽 * 8、身份证信息查看 9、敏感词替换 5、StringBuilder 1、概念及练习…

新手也能看懂的【前端自动化测试入门】!

前言 最近在网上搜索前端自动化测试相关的文档,但是发现网上的文章都是偏使用,没有把一些基础概念说清楚,导致后续一口气遇到一些karma、Jasmine、jest、Mocha、Chai、BDD等词汇的时候很容易一头雾水,这次一方面整理一下收获的知…

YOLOv8改进 | 进阶实战篇 | 利用YOLOv8进行视频划定区域目标统计计数

一、本文介绍 Hello,各位读者,最近会给大家发一些进阶实战的讲解,如何利用YOLOv8现有的一些功能进行一些实战, 让我们不仅会改进YOLOv8,也能够利用YOLOv8去做一些简单的小工作,后面我也会将这些功能利用PyQt或者是pyside2做一些小的界面给大家使用。 在开始之前给大家推…

解决Spring Boot跨域问题(配置JAVA类)

什么是跨域问题 跨域问题指的是不同端口之间,使用 ajax 无法相互调用的问题。跨域问题本质是浏览器的一种保护机制,它是为了保证用户的安全,防止恶意网站窃取数据。 比如前端用的端口号为8081,后端用的端口号为8080,后…

Linux下安装docker

1、查看系统版本 Docker支持64位版本的CentOS 7和CentOS 8及更高版本,它要求Linux内核版本不低于3.10。查看Linux版本的命令这里推荐两种:lsb_release -a或cat /etc/redhat-release。 显然,当前Linux系统为CentOS7。再查一下内核版本是否不低…

SpringBoot+dynamic-datasource实现多数据源(msyql、sqlserver、postgresql)手动切换

场景 SpringBootMybatisPlusdynamic-datasources实现连接Postgresql和mysql多数据源: SpringBootMybatisPlusdynamic-datasources实现连接Postgresql和mysql多数据源-CSDN博客 上面实现通过注解和配置文件的方式去进行多数据源操作。 如果业务需求,比…

Ubuntu安装最新版Docker和Docker-Compose

ubuntu环境搭建专栏🔗点击跳转 Ubuntu系统环境搭建(十)——Ubuntu安装最新版Docker和Docker Compose 文章目录 Ubuntu系统环境搭建(十)——Ubuntu安装最新版Docker和Docker Compose1.添加Docker库1.1 安装必要的证书并…

PostgreSQL命令大全

文章目录 连接与退出数据库操作表操作外键约束视图操作存储过程与函数权限管理事务管理查询优化与分析数据类型转换分区表操作复制与备份恢复 PostgreSQL是一个功能强大的开源关系型数据库管理系统,以下是一些基本且常用的命令按功能分类: 连接与退出 连…

【前端设计】流光按钮

欢迎来到前端设计专栏,本专栏收藏了一些好看且实用的前端作品,使用简单的html、css语法打造创意有趣的作品,为网站加入更多高级创意的元素。 css body{height: 100vh;display: flex;justify-content: center;align-items: center;background…

HarmonyOS鸿蒙学习基础篇 - 项目目录和文件介绍

├── hvigor //存储购置信息的文件,主要用于发布打包 ├── idea //开发工具相关配置可忽略 ├── AppScope //工程目录 全局公共资源存放路径 │ └── resources │ │ └── base │ │ │ └── element //常亮存放 │ │ │ …

Windows安装WSL2精简版教程

文章目录 一、安装WSL二、更改WSL的存放路径/备份WSL三、安装WSL Terminall四、WSL界面:xlaunch五、WSL1升级WSL2六、WSL2与VMware兼容问题七、更改手动导入的wsl的默认登录用户参考 一、安装WSL 步骤1 - 启用适用于 Linux 的 Windows 子系统: 需要先启…

格局打开!前端未死,只是要求变高了

本文笔者会从以下几个方面分享,希望能够帮助正在迷茫的前端小伙伴提供一点思路! 逛技术博客 不局限框架 全栈工程师兴起 关注前沿 写技术文章 录制前端视频 总结 2024年了,没工作的找到工作了吗?有工作的加薪了吗?加薪…

UI设计中的插画运用优势(上)

1. 插画是设计的原创性和艺术性的基础 无论是印刷品、品牌设计还是UI界面,更加风格化的插画能够将不同的风格和创意加入其中,在激烈的竞争中更容易因此脱颖而出。留下用户才有转化。 2. 插画是视觉触发器,瞬间传达大量信息 我们常说「一图胜千…

Unity3D学习之Unity基础

文章目录 1. 第一部分:MONO中的重要内容2. 延时函数2.1 什么是延时函数2.2 延时函数的使用2.2.1 延时重复函数2.2.2 取消延迟函数2.2.3 判断是否有延迟函数开启2.2.4 延迟函数和依附对象的关系 3 协同程序3.1 Unity中的多线程3.2 协同程序3.3 协程的使用3.3.1 关闭协…

Jetson Orin Nano安装OpenCV带cuda加速版本的全过程

安装过程 使用jetpack安装的jetson,自带了opencv,但是没有cuda加速的,输入opencv_version 使用jtop查看,可以确认自带的opencv是没用cuda的 卸载opencv,先查看有哪些包 pip3 list | grep opencv opencv-python 然后卸…

从规则到神经网络:机器翻译技术的演化之路

文章目录 从规则到神经网络:机器翻译技术的演化之路一、概述1. 机器翻译的历史与发展2. 神经机器翻译的兴起3. 技术对现代社会的影响 二、机器翻译的核心技术1. 规则基础的机器翻译(Rule-Based Machine Translation, RBMT)2. 统计机器翻译&am…

文献分享四:(基础)Pyroelectric drift of integrated-optical LiNbO3 modulators

文献来源——S. M. Kostritskii, Yu. N. Korkishko, V. A. Fedorov & A. V. Yatsenko (2021) Pyroelectric drift of integrated-optical LiNbO3 modulators, Ferroelectrics, 574:1, 170-178, DOI: 10.1080/00150193.2021.1888062 本文分析了我们之所以需要偏置电路的原因…

力扣刷MySQL-第六弹(详细讲解)

🎉欢迎您来到我的MySQL基础复习专栏 ☆* o(≧▽≦)o *☆哈喽~我是小小恶斯法克🍹 ✨博客主页:小小恶斯法克的博客 🎈该系列文章专栏:力扣刷题讲解-MySQL 🍹文章作者技术和水平很有限,如果文中出…

【Linux、C】服务程序的调度

这一节中用到的技术点有&#xff1a;信号、多进程、exec函数 exec函数族 exec的本质 exec是用参数中指定的程序替换了当前进程的正文段、数据段、堆和栈。 如果调用失败会继续运行当前进程 #include <stdio.h> #include <stdlib.h> #include <string.h> #in…