迁移学习实现图片分类任务

导入工具包

import time
import os

import numpy as np
from tqdm import tqdm

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
%matplotlib inline

# 忽略烦人的红色提示
import warnings
warnings.filterwarnings("ignore")

获取计算硬件

# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

图片预处理

from torchvision import transforms

# 训练集图像预处理:缩放裁剪、图像增强、转 Tensor、归一化
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                     ])

# 测试集图像预处理-RCTN:缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

这里对train训练集和text集的处理不同,几个transforms的操作通过compose进行整合。

载入图片分类数据集

# 数据集文件夹路径
dataset_dir = 'fruit30_split'

train_path = os.path.join(dataset_dir, 'train')
test_path = os.path.join(dataset_dir, 'val')
print('训练集路径', train_path)
print('测试集路径', test_path)

from torchvision import datasets

# 载入训练集
train_dataset = datasets.ImageFolder(train_path, train_transform)

# 载入测试集
test_dataset = datasets.ImageFolder(test_path, test_transform)


print('训练集图像数量', len(train_dataset))
print('类别个数', len(train_dataset.classes))
print('各类别名称', train_dataset.classes)

print('测试集图像数量', len(test_dataset))
print('类别个数', len(test_dataset.classes))
print('各类别名称', test_dataset.classes)

datasets下的ImageFolder,可以直接构建数据集。

类别与索引号一一对应

class_names = train_dataset.classes
n_class = len(class_names)


# 映射关系:类别 到 索引号
train_dataset.class_to_idx

定义数据加载器Dataloader,dataloader用于给模型喂数据。

from torch.utils.data import DataLoader

BATCH_SIZE = 32

# 训练集的数据加载器
train_loader = DataLoader(train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=4
                         )

# 测试集的数据加载器
test_loader = DataLoader(test_dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=False,
                         num_workers=4
                        )

查看一个batch的图像与标注

# DataLoader 是 python生成器,每次调用返回一个 batch 的数据
images, labels = next(iter(train_loader))

images. Shape
#torch.Size([32, 3, 224, 224])
labels
#tensor([11, 19,  3, 25, 29, 13, 21, 18, 11,  1, 13, 15, 13,  0, 15, 25,  0,  7,11, 10,  9,  6, 26,  2, 11, 10, 29, 29, 15,  8, 19,  8])

迁移学习范式

导入训练所用的工具包

from torchvision import models
import torch.optim as optim
model = models.resnet18(pretrained=True) # 载入预训练模型
# 修改全连接层,使得全连接层的输出与当前数据集类别数对应
# 新建的层默认 requires_grad=True
model.fc = nn.Linear(model.fc.in_features, n_class)
model.fc
Linear(in_features=512, out_features=30, bias=True)
# 只微调训练最后一层全连接层的参数,其它层冻结
optimizer = optim.Adam(model.fc.parameters())

采用第一种迁移学习的方式,优化器采用的是Adam的优化器。

训练配置

model = model.to(device)

# 交叉熵损失函数
criterion = nn.CrossEntropyLoss() 

# 训练轮次 Epoch
EPOCHS = 20

模拟一个batch的训练

这里着重注意反向传播三部曲

# 反向传播“三部曲”
optimizer.zero_grad() # 清除梯度
loss.backward() # 反向传播
optimizer.step() # 优化更新

 运行完整训练

# 遍历每个 EPOCH
for epoch in tqdm(range(EPOCHS)):

    model. Train() #每次开始前将模型设置为训练模式

    for images, labels in train_loader:  # 获取训练集的一个 batch,包含数据和标注
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)           # 前向预测,获得当前 batch 的预测结果
        loss = criterion(outputs, labels) # 比较预测结果和标注,计算当前 batch 的交叉熵损失函数
        
        optimizer.zero_grad()
        loss.backward()                   # 损失函数对神经网络权重反向传播求梯度
        optimizer.step()                  # 优化更新神经网络权重

在测试集上进行初步测试

model.eval() #模型设置为测试模式
with torch.no_grad(): #不再回传梯度
    correct = 0
    total = 0
    for images, labels in tqdm(test_loader): # 获取测试集的一个 batch,包含数据和标注
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)              # 前向预测,获得当前 batch 的预测置信度
        _, preds = torch.max(outputs, 1)     # 获得最大置信度对应的类别,作为预测结果
        total += labels.size(0)
        correct += (preds == labels).sum()   # 预测正确样本个数,如果预测类别等于标注类别

    print('测试集上的准确率为 {:.3f} %'.format(100 * correct / total))

保存模型

torch.save(model, 'checkpoint/fruit30_pytorch_C1.pth')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/362192.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

11.Ubuntu

目录 1. 什么是Ubuntu 1.1. 概述 1.2. Ubuntu版本简介 1.2.1. 桌面版 1.2.2. 服务器版 2. 部署系统 2.1. 新建虚拟机 2.2. 安装系统 2.3. 部署后的设置 2.3.1. 设置root密码 2.3.2. 关闭防火墙 2.3.3. 启用允许root进行ssh 2.3.4. 安装所需软件 2.3.5. 制作快照 …

企业图纸数据防泄密方案(图纸防泄密的几个实用方法)

在当今的商业环境中,企业图纸数据的重要性日益凸显。无论是产品设计、研发、生产还是项目管理,图纸都承载着企业的核心知识产权和商业机密。 然而,随着信息技术的迅猛发展,企业图纸数据泄密事件屡屡发生,给企业带来了…

【20240131】USB相机(查看设备列表、打开设备)

USB相机采集 1、v4l2查看设备列表2、查看具体设备信息3、在桌面打开USB相机 1、v4l2查看设备列表 打开终端,输入:v4l2-ctl --list-devices usb设备在Webcam: Webcam栏,分别是video9和video10,下一步:确定哪一个是接入…

存内计算——发展史与近期成果

存内计算的概念早在上个世纪就已经被提出,但当时的人们寄希望于通过优化处理器设计以及工艺制程的升级,来获得性能和能效比的提升,存内计算的研究仅停留在理论阶段。随着大数据时代的到来,存内计算由于其结构特点以及摩尔定律的“…

机器学习 低代码 ML:PyCaret 的使用

✅作者简介:人工智能专业本科在读,喜欢计算机与编程,写博客记录自己的学习历程。 🍎个人主页:小嗷犬的个人主页 🍊个人网站:小嗷犬的技术小站 🥭个人信条:为天地立心&…

大数加法-----蓝桥杯

题目描述 以字符串的形式读入两个数字&#xff0c;编写一个函数计算它们的和&#xff0c;以字符串形式返回。 数据范围: s.length,tlength < 100000&#xff0c;字符串仅由0~9构成要求: 时间复杂度 O(n)。 结构示例1&#xff1a; 结果示例2&#xff1a; 代码展示&#xff…

在中国,大多数做机器视觉工程师的人的出路是什么?

工程师一直号称是靠技术吃饭&#xff0c;越老越吃香。显然我们机器视觉工程师归属于工程师。 可是这种是一种低级趣味的思维力度。 试问现在程序员难道不是越老越吃香吗&#xff1f; 律师难道不是吗&#xff1f; 银行家难道不是吗&#xff1f; 现在整体经济下滑情况&#xff0…

C++ 菱形继承和虚拟菱形继承

菱形继承和虚拟菱形继承 菱形继承1. 概念2. 产生的问题 虚拟菱形继承1.1 使用1.2 原理 菱形继承 1. 概念 菱形继承是多继承的一个特殊情况&#xff0c;多继承是指一个子类类继承了两个或以上的直接父类&#xff0c;而菱形继承问题的产生是因为该子类的父类&#xff0c;继承了…

超好看的前端特效HTML特效、CSS特效、JS特效(第一期)

超好看的前端特效 1. 粒子组成文字动画特效 文件组成&#xff1a; base.css import url(https://fonts.googleapis.com/css?familyAbrilFatface|Raleway:300,400,900);.coidea-header {position: fixed;display: block;width: 96%;width: calc( 100% - 32px );height: 40px;m…

面试题:Feign第一次调用为什么会很慢?

文章目录 前言Ribbon是如何进行负载的RibbonClientConfigurationZoneAwareLoadBalancerRibbon负载均衡策略Ribbon-eager-load&#xff08;饥饿加载&#xff09;模式开启Ribbon饥饿加载总结 前言 首先要了解 Feign 是如何进行远程调用的&#xff0c;这里面包括&#xff0c;注册…

Fiddler修改https请求与响应 bug修复变灰了选不了等 Fiddle对夜神模拟器抓包设置

不要修改别人的东西&#xff0c;不要修改别人的东西&#xff0c;不要修改别人的东西 只用于自己的网站&#xff0c;自己安全调试。 fiddler修改https请求 1、打到要改的请求 2、替换请求内容 3、开启捕获。操作产生请求。 4、fiddler里查看请求或响应数据 &#xff0c;确认成…

ubuntu20配置mysql8

首先更新软件包索引运行 sudo apt update命令。然后运行 sudo apt install mysql-server安装MySQL服务器。 安装完成后&#xff0c;MySQL服务将作为systemd服务自动启动。你可以运行 sudo systemctl status mysql命令验证MySQL服务器是否正在运行。 连接MySQL 当MySQL安装…

操作系统基础:进程同步【下】

&#x1f308;个人主页&#xff1a;godspeed_lucip &#x1f525; 系列专栏&#xff1a;OS从基础到进阶 1 进程同步⛵1.1 吸烟者问题✈️1.1.1 问题描述✈️1.1.2 问题分析1.1.2.1 关系分析&#xff08;确定同步、互斥关系&#xff09;1.1.2.2 整理思路&#xff08;确定PV操作的…

Golang语言异常机制解析:错误策略与优雅处理

前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。点击跳转到网站https://www.captainbed.cn/kitie。 前言 作为开发者来说&#xff0c;我们没办法保证程序在运行过程中永远不会出现异常&#xff0c;对于异常…

K8S网络

一、介绍 k8s不提供网络通信&#xff0c;提供了CNI接口(Container Network Interface&#xff0c;容器网络接口)&#xff0c;由CNI插件实现完成。 1.1 Pod通信 1.1.1 同一节点Pod通信 Pod通过虚拟Ethernet接口对&#xff08;Veth Pair&#xff09;与外部通信&#xff0c;Veth…

Unity_颜色空间GammaLinear

Unity_颜色空间Gamma&Linear Unity颜色空间的选择对于效果的影响具体有多大&#xff1f; 在ProjectSetting -> Player -> OtherSetting -> Rendering设置下的颜色空间选项卡选择颜色空间进行设置&#xff1a; 太深奥的解释一时半会看不懂&#xff0c;找见一个粗…

jsp 样衣申请与归还管理系统Myeclipse开发mysql数据库web结构java编程计算机网页项目

一、源码特点 JSP 样衣申请与归还管理系统是一套完善的java web信息管理系统&#xff0c;对理解JSP java编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境 为TOMCAT7.0,Myeclipse8.5开发&#xff0c;数据库为My…

地理空间分析10——空间数据分析中的地理编码与Python

目录 写在开头1. 地理编码基础1.1 地理编码的基本原理1.1.1 坐标系统1.1.2 地名解析1.1.3 编码算法1.2 Python中使用地理编码的基础知识1.2.1 百度地图API1.2.2 高德地图API1.2.3 腾讯地图API1.3 Python中实现代码2. 逆地理编码2.1 利用Python进行逆地理编码2.1.1 获取高德地图…

地毯填补问题

地毯填补问题 题目描述 相传在一个古老的阿拉伯国家里&#xff0c;有一座宫殿。宫殿里有个四四方方的格子迷宫&#xff0c;国王选择驸马的方法非常特殊&#xff0c;也非常简单&#xff1a;公主就站在其中一个方格子上&#xff0c;只要谁能用地毯将除公主站立的地方外的所有地…

使用最大边界相关算法处理文章自动摘要

一、需求背景 对于博客或者文章来说&#xff0c;摘要是普遍性的需求。但是我们不可能让作者自己手动填写摘要或者直接暴力截取文章的部分段落作为摘要&#xff0c;这样既不符合逻辑又不具有代表性&#xff0c;那么&#xff0c;是否有相关的算法或者数学理论能够完成这个需求呢&…