【深度学习-图像识别】使用fastai对Caltech101数据集进行图像多分类(50行以内的代码就可达到很高准确率)


文章目录

  • 前言
    • fastai介绍
      • 数据集介绍
  • 一、环境准备
  • 二、数据集处理
    • 1.数据目录结构
    • 2.导入依赖项
    • 2.读入数据
    • 3.模型构建
      • 3.1 寻找合适的学习率
      • 3.2 模型调优
    • 4.模型保存与应用
  • 总结
      • 人工智能-图像识别 系列文章目录


前言

fastai介绍

fastai 是一个深度学习库,它为从业人员提供了高级组件,可以快速、轻松地在标准深度学习领域提供最先进的结果,并为研究人员提供了低级组件,可以混合和匹配以构建新的方法。以解耦抽象的方式表达了许多深度学习和数据处理技术的通用底层模式。
fastai 有两个主要的设计目标:易于使用、快速高效,同时具有很强的可破解性和可配置性。它建立在提供可组合构件的低级应用程序接口的层次结构之上。这样,如果用户想重写部分高级应用程序接口或添加特定行为以满足自己的需求,就不必学习如何使用最底层的应用程序接口。
在这里插入图片描述

数据集介绍

下载链接
Caltech101国内下载地址
Caltech101

Caltech101数据集内部有 101 个类别的物体图片。每个类别约有 40 至 800 张图片。大多数类别约有 50 张图片。每张图片的大小大约为 300 x 200 像素。并且作者还标注了这些图片中每个物体的轮廓,这些都包含在 "Annotations.tar "中。还有一个 MATLAB 脚本 "show_annotations.m "可以查看注释。

Collected in September 2003 by Fei-Fei Li, Marco Andreetto, and
Marc’Aurelio Ranzato。

一、环境准备

这里展示使用GPU进行训练的环境搭建,只用CPU也可以进行训练,只是训练时间比较慢。
首先安装Anaconda,通过conda安装我们需要的包

 conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
 conda install -c nvidia fastai anaconda

详情可见第一篇文章。

二、数据集处理

1.数据目录结构

├───data_iamge
│   ├───101_ObjectCategories
│   │   ├───accordion
│   │   ├───airplanes
│   │   ├───anchor
│   │   ├───ant
│   │   ├───BACKGROUND_Google
│   │   ├───barrel
│   │   ├───bass
│   │   ├───beaver
│   │   ├───binocular
│   │   ├───bonsai
│   │   ├───brain
│   │   ├───brontosaurus
...

2.导入依赖项

from fastai import *
from fastai.vision.all import *
from fastai.metrics import error_rate

import os
#from keras.utils import plot_model
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

查看环境以及版本信息,cuda.is_available()判断是否可以用GPU。

print(torch.cuda.is_available())
print(torch.__version__)
print(torch.version.cuda)
print(torch.backends.cudnn.version())

True
2.0.1
11.8
8700

'''SEED Everything'''
def seed_everything(SEED=42):
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.benchmark = True # keep True if all the input have same size.
SEED=42
seed_everything(SEED=SEED)
'''SEED Everything'''

2.读入数据

代码如下(示例):

path='./data_image/101_ObjectCategories/'
image_rsize=224
item_tfms = [Resize((image_rsize,image_rsize))]
data = ImageDataLoaders.from_folder(path, train = '.', valid_pct=0.2,
                                   size=image_rsize,
                                  item_tfms=item_tfms)
data.show_batch(figsize=(7,6))

在这里插入图片描述

3.模型构建

这里使用预训练模型resnet101,这是一个非常优秀的残差网络模型。
这些残差网络更容易优化,并且可以从显着增加的深度中获得准确性。
这些残差网络的集合在 ImageNet 测试集上实现了 3.57% 的误差。该结果在ILSVRC 1分类任务中获得第一名。

learn = cnn_learner(data, models.resnet101, model_dir='./model', path = Path("."))

3.1 寻找合适的学习率

learn.lr_find()

在这里插入图片描述

接下来使用fit_one_cycle方法用更小的学习率进一步训练。fit_one_cycle使用的是一种周期性学习率,从较小的学习率开始学习,缓慢提高至较高的学习率,然后再慢慢下降,周而复始,每个周期的长度略微缩短,在训练的最后部分,允许学习率比之前的最小值降得更低。这不仅可以加速训练,还有助于防止模型落入损失平面的陡峭区域,使模型更倾向于寻找更平坦的极小值,从而缓解过拟合现象。

lr1 = 1e-3
lr2 = 1e-1
epoch	train_loss	valid_loss	time
0	1.417713	1.648756	00:45
1	3.097069	9.964518	00:43
2	5.385355	5.347832	00:44
3	4.194504	12.162844	00:44
4	2.985504	3.486863	00:43
5	2.152388	22.297184	00:43
6	1.295905	3.554162	00:43
7	0.630879	9.193820	00:43
8	0.361619	49.334236	00:43
9	0.255115	9.832499	00:43

3.2 模型调优

unfreeze
在fastai课程中使用的是预训练模型,模型卷积层的权重已经提前在ImageNet
上训练好了,在使用的时候一般只需要在预训练模型最后一层卷积层后添加自定义的全连接层即可。卷积层默认是freeze的,即在训练阶段进行反向传播时不会更新卷积层的权重,只会更新全连接层的权重。在训练几个epoch之后,全连接层的权重已经训练的差不多了,但accuracy还没有达到你的要求,这时你可以调用unfreeze然后再进行训练,这样在进行反向传播时便会更新卷积层的权重(一般不会对卷积层权重进行较大的更新,只会进行一点点的微调,越靠前的卷积层调整的幅度越小,所以有了differential
learning rate 这一想法)

precompute
当precompute=True时,会提前计算出每一个训练样本(不包括增强样本)在预训练模型最后一层卷积层的activation,
并将其缓存下来,之后在训练阶段进行前向传播的时候,直接将precompute 的activation 作为后面全连接层(FC
Layer)的输入,这样便省去前面卷积层进行前向传播的计算量,减少训练所需时间(这种优势在epoch比较大的时候能够显著0提高训练速度)。当precompute=False时,则不会提前计算训练样本的activation,每一个epoch都需要重新将训练样本+增强样本(前提是进行了增强操作)进行卷积层的前向传播,然后进行反向传播更新对应的权重。

learn.unfreeze()
learn.show_results()

在这里插入图片描述
从展示的部分训练结果可以看出,只有一张图被预测错误了,其他的都是正确的。

4.模型保存与应用

最后我们可以将模型保存下来,并且对验证集的图片的类别进行预测。

learn.export(Path("./model/export.pkl"))
from PIL import Image
img = Image.open(path+'ant/image_0001.jpg')
image_rsize=224
# Resize the image to 224x224
img_resized = img.resize((image_rsize,image_rsize))
pred, pred_idx, probs = learn.predict(img_resized)
im_t = cast(array(img_resized), TensorImage)
# Print the predicted label and probability
print(f"Predicted label: {pred}, probability: {probs[pred_idx]:.4f}")
img

在这里插入图片描述

总结

epoch	train_loss	valid_loss	time
0	1.030772	979.477417	00:52
1	1.074642	86.289436	00:52
2	0.553576	0.457210	00:52
3	0.302997	0.546438	00:52
4	0.176070	0.596845	00:52

我们借助fastai训练了resnet101模型,对 101 个类别的图像数据集进行了分类。
使用基于pytorch的fastai库,使用resnet模型和有101个类别的Caltech101图像数据集,训练了一个高准确率的多分类的深度学习模型,能够对101个类别的图像大数据集进行准确的图像类别识别。
使用简洁高效的代码,借助GPU提升训练速度(也可以使用CPU训练,本项目会自动识别硬件),首先数据集进行预处理,然后对模型进行训练,并将模型保存为pkl格式,最后对测试集的图像的类别进行预测。

可见,使用fastai进行图像多分类是非常简便的,所使用的代码行数非常少却能达到很高的准确率,而且借助GPU训练速度非常快。

这里将全部的代码和图片数据集打包起来了,方便大家复现。
开箱即用,欢迎下载:
使用fastai对Caltech101数据集进行图像多分类

单独下载数据集
Caltech101数据集 2023完整版 增加了更多图片


人工智能-图像识别 系列文章目录

  1. 环境搭建: pytorch以及fastai安装,配置GPU训练环境 待更新。。。
  2. 使用fastai对Caltech101数据集进行图像多分类(50行以内的代码就可达到很高准确率)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/85439.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

项目解决问题

红外 没接收到红外信号时, 会有杂波干扰 STC单片机 STC的串口要用一个定时器作为波特率发生器 开定时器2需要 开定时器0 1 要ET0 1 ET11打开 串口有时候和定时器有冲突 串口发送函数放定时器中断函数中,时间太少可能会导致一直卡在定时器中AUXR | 0x…

计算机视觉:比SAM快50倍的分割一切视觉模型FastSAM

目录 引言 1 FastSAM介绍 1.1 FastSAM诞生 1.2 模型算法 1.3 实验结果 2 FastSAM运行环境构建 2.1 conda环境构建 2.2 运行环境安装 2.3 模型下载 3 FastSAM运行 3.1 命令行运行 3.1.1 Everything mode 3.1.2 Text prompt 3.1.3 Box prompt (xywh) 3.1.4 Points p…

科技云报道:云计算下半场,公有云市场生变,私有云风景独好

科技云报道原创。 大数据、云计算、人工智能,组成了恢弘的万亿级科技市场。这三个领域,无论远观近观,都如此性感和魅力,让一代又一代创业者为之杀伐攻略。 然而高手过招往往一瞬之间便已胜负知晓,云计算市场的巨幕甫…

Threejs学习05——球缓冲几何体背景贴图和环境贴图

实现随机多个三角形随机位置随机颜色展示效果 这是一个非常简单基础的threejs的学习应用!本节主要学习的是球面缓冲几何体的贴图部分,这里有环境贴图以及背景贴图,这样可以有一种身临其境的效果!这里环境贴图用的是一个.hdr的文件…

opencv 进阶16-基于FAST特征和BRIEF描述符的ORB(图像匹配)

在计算机视觉领域,从图像中提取和匹配特征的能力对于对象识别、图像拼接和相机定位等任务至关重要。实现这一目标的一种流行方法是 ORB(Oriented FAST and Rotated Brief)特征检测器和描述符。ORB 由 Ethan Rublee 等人开发,结合了…

IDEA中导入多module的Maven项目无法识别module的解决办法

首先举个栗子 这是正常的多module工程(spring cloud项目) 正常工程.png 这是导入出现问题的多module工程 导入出现问题的工程.png 原因: 出现该问题,是由于打开工程的时候IDEA只编译了最外层的pom.xml文件,而内部的…

【分享】华为设备登录安全配置案例

微思网络www.xmws.cn,2002年成立,专业IT认证培训21年,面向全国招生! 微 信 号 咨 询: xmws-IT 华为HCIA试听课程:超级实用,华为VRP系统文件详解【视频教学】华为VRP系统文件详解 华为HCIA试听课…

机器学习在大数据分析中的应用

文章目录 机器学习在大数据分析中的原理机器学习在大数据分析中的应用示例预测销售趋势客户细分和个性化营销 机器学习在大数据分析中的前景和挑战前景挑战 总结 🎉欢迎来到AIGC人工智能专栏~探索机器学习在大数据分析中的应用 ☆* o(≧▽≦)o *☆嗨~我是IT陈寒&…

Android SDK 上手指南||第四章 应用程序结构

第四章 应用程序结构 本教程将主要以探索与了解为主要目的,但后续的系列文章则将进一步带大家深入学习如何创建用户界面、响应用户交互操作以及利用Java编排应用逻辑。我们将专注于大家刚刚开始接触Android开发时最常遇到的项目内容,但也会同时涉及一部…

SpringBoot内嵌Tomcat连接池分析

文章目录 1 Tomcat连接池1.1 简介1.2 架构图1.2.1 JDK线程池架构图1.2.2 Tomcat线程架构 1.3 核心参数1.3.1 AcceptCount1.3.2 MaxConnections1.3.3 MinSpareThread/MaxThread1.3.4 MaxKeepAliveRequests1.3.5 ConnectionTimeout1.3.6 KeepAliveTimeout 1.4 核心内部线程1.4.1 …

shell脚本免交互

一.Here Document免交互 1.免交互概述 使用I/O重定向的方式将命令列表提供给交互式程序 是一种标准输入&#xff0c;只能接收正确的指令或命令 2.格式&#xff1a; 命令 <<标记 ....... 内容 #标记之间是传入内容 ....... 标记 注意事项 标记可以使用任意的合法…

“深度学习”学习日记:Tensorflow实现VGG每一个卷积层的可视化

2023.8.19 深度学习的卷积对于初学者是非常抽象&#xff0c;当时在入门学习的时候直接劝退一大班人&#xff0c;还好我坚持了下来。可视化时用到的图片&#xff08;我们学校的一角&#xff01;&#xff01;&#xff01;&#xff09;以下展示了一个卷积和一次Relu的变化 作者使…

2023国赛数学建模思路 - 案例:最短时间生产计划安排

文章目录 0 赛题思路1 模型描述2 实例2.1 问题描述2.2 数学模型2.2.1 模型流程2.2.2 符号约定2.2.3 求解模型 2.3 相关代码2.4 模型求解结果 建模资料 0 赛题思路 &#xff08;赛题出来以后第一时间在CSDN分享&#xff09; https://blog.csdn.net/dc_sinor?typeblog 最短时…

网络安全—黑客—自学笔记

想自学网络安全&#xff08;黑客技术&#xff09;首先你得了解什么是网络安全&#xff01;什么是黑客&#xff01; 网络安全可以基于攻击和防御视角来分类&#xff0c;我们经常听到的 “红队”、“渗透测试” 等就是研究攻击技术&#xff0c;而“蓝队”、“安全运营”、“安全…

Mac Flutter web环境搭建

获取 Flutter SDK 下载以下安装包来获取最新的 stable Flutter SDK将文件解压到目标路径, 比如: cd ~/development $ unzip ~/Downloads/flutter_macos_3.13.0-stable.zip 配置 flutter 的 PATH 环境变量&#xff1a; export PATH"$PATH:pwd/flutter/bin" // 这个命…

无类别域间路由(Classless Inter-Domain Routing, CIDR):理解IP网络和子网划分(传统的IP地址类ABCDE:分类网络)

文章目录 无类别域间路由&#xff08;CIDR&#xff09;&#xff1a;理解IP网络和子网划分引言传统的IP地址类关于“IP地址的浪费” IP地址与CIDRIP地址概述网络号与主机号CIDR记法&#xff08;网络 网络地址/子网掩码&#xff09;网络和广播地址 CIDR的优势减少路由表项缓解IP…

Android 面试之Glide做了哪些优化?

前言 Glide可以说是最常用的图片加载框架了&#xff0c;Glide链式调用使用方便&#xff0c;性能上也可以满足大多数场景的使用&#xff0c;Glide源码与原理也是面试中的常客。 但是Glide的源码内容比较多&#xff0c;想要学习它的源码往往千头万绪&#xff0c;一时抓不住重点.…

0101读写分离测试-jdbc-shardingsphere-中间件

文章目录 1 前言2、创建SpringBoot程序2.1、创建项目2.2、添加依赖2.3、生成实体类、service与Mapper1.5、配置读写分离 2、测试2.1、读写分离测试2.2、事务测试2.3、负载均衡测试 结语 1 前言 shardingshpere-jdbc定位为轻量级 Java 框架&#xff0c;在 Java 的 JDBC 层提供的…

文心一言 VS 讯飞星火 VS chatgpt (79)-- 算法导论7.4 4题

四、请用go语言&#xff0c;证明:RANDOMIZED-QUICKSORT 期望运行时间是 Ω(nlgn)。 文心一言&#xff1a; 首先&#xff0c;为了证明RANDOMIZED-QUICKSORT的期望运行时间是Ω(nlg n)&#xff0c;我们需要证明在最坏的情况下&#xff0c;该算法的运行时间是O(nlg n)。然后&…

elelementui组件

一、按钮 1、按钮样式 使用type、plain、round和circle属性来定义 Button 的样式。 2、主要代码 <el-row><el-button>默认按钮</el-button><el-button type"primary">主要按钮</el-button><el-button type"success">…