【Pytorch和Keras】使用transformer库进行图像分类

目录

    • 一、环境准备
    • 二、基于Pytorch的预训练模型
      • 1、准备数据集
      • 2、加载预训练模型
      • 3、 使用pytorch进行模型构建
    • 三、基于keras的预训练模型
    • 四、模型测试
    • 五、参考

 现在大多数的模型都会上传到huggface平台进行统一的管理,transformer库能关联到huggface中对应的模型,并且提供简洁的transformer模型调用,这大大提高了开发人员的开发效率。本博客主要利用transformer库实现一个简单的模型微调,以进行图像分类的任务。


一、环境准备

 使用终端命令行安装对应的第三方包,具体安装命令输入如下:

pip install transformers datasets evaluate

二、基于Pytorch的预训练模型

  由于下面这些内容需要在huggface上申请账号权限,才能进行模型和数据集加载,如果之前有从huggface上拉取模型和数据集的经验,可以略过,如果没有配置过,可以参考笔者之前的文章https://blog.csdn.net/qq_40734883/article/details/143922095,然后直接申请Write权限就可以。

在这里插入图片描述
  后续所有涉及到的数据集food101和transformer模型都需要参考上述文章进行直接下载,才能运行整个程序,或者在google的colab直接运行。

  如果在google的colab上运行,请提前设置好电脑的GPU资源,同时加入huggface登录代码,具体如下:

from huggingface_hub import notebook_login
notebook_login()

  运行之后会提示进行token输入,按之前获取到的token输入即可。

1、准备数据集

  这里以food101数据集作为微调数据集,在imagenet-21k上训练完成的transformer模型(vit-base-patch16-224)进行优化

from datasets import load_dataset

food = load_dataset("food101", split="train[:5000]")


# 划分数据集,训练集:测试集=8:2,food有两个键:一个train,一个test
food = food.train_test_split(test_size=0.2)  

# 标签转换
labels = food["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

id2label为通过id访问标签的字典,后续会使用到。

2、加载预训练模型

from transformers import AutoImageProcessor

checkpoint = "google/vit-base-patch16-224-in21k"   # ImageNet-21k上的预训练模型
image_processor = AutoImageProcessor.from_pretrained(checkpoint)  # 从huggface拉取并加载模型

3、 使用pytorch进行模型构建

from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

# 数据预处理操作定义
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
size = (image_processor.size["shortest_edge"]
	    if "shortest_edge" in image_processor.size
	    else (image_processor.size["height"], image_processor.size["width"])
)
_transforms = Compose([RandomResizedCrop(size), 
					   ToTensor(), 
					   normalize])

# 对原始数据进行RGB及字典化
def transforms(examples):
    examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples

food = food.with_transform(transforms)
# 验证
import evaluate
# 指定验证过程中的评价指标-准确率
accuracy = evaluate.load("accuracy")

import numpy as np
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

  训练设置和运行,具体输入代码如下:

# 整合训练中的数据,以便在模型训练或评估过程中使用
from transformers import DefaultDataCollator
data_collator = DefaultDataCollator()

from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

# 初始化模型
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)

# 设置模型优化参数
training_args = TrainingArguments(
    output_dir="my_awesome_food_model",
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=True,
)

# 初始化训练实例
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=food["train"],
    eval_dataset=food["test"],
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)

trainer.train()  # 开始训练

trainer.push_to_hub()  # 推送到huggfacehub

  经过上述设置训练完成之后,会将模型微调结果推送到huggface平台,如果不想推送,可以不运行相关的命令行,并且training_args中的push_to_hub=False

  训练结果如下图所示:

在这里插入图片描述
  默认需要选择是否关联wandb,如果不想选择,直接根据设置提示跳过即可。

  如果选择了推送到huggfacehub(trainer.push_to_hub() )的话,在个人的huggface上会有一个名为my_awesome_food_model的模型,里面包含了模型训练的各个参数设置和测试结果。

在这里插入图片描述


三、基于keras的预训练模型

  使用transflow的keras API 进行模型的搭建,具体代码如下:

from transformers import create_optimizer

# 超参数设置
batch_size = 16
num_epochs = 5
num_train_steps = len(food["train"]) * num_epochs
learning_rate = 3e-5
weight_decay_rate = 0.01

# 定义优化方式和策略
optimizer, lr_schedule = create_optimizer(init_lr=learning_rate, 
										  num_train_steps=num_train_steps, 
										  weight_decay_rate=weight_decay_rate, 
										  num_warmup_steps=0)

# 定义分类器
from transformers import TFAutoModelForImageClassification
model = TFAutoModelForImageClassification.from_pretrained(checkpoint, id2label=id2label, label2id=label2id)


# converting our train dataset to tf.data.Dataset
tf_train_dataset = food["train"].to_tf_dataset(columns="pixel_values", label_cols="label", shuffle=True, batch_size=batch_size, collate_fn=data_collator)

# converting our test dataset to tf.data.Dataset
tf_eval_dataset = food["test"].to_tf_dataset(columns="pixel_values", label_cols="label", shuffle=False, batch_size=batch_size, collate_fn=data_collator)

# 定义损失函数
from tensorflow.keras.losses import SparseCategoricalCrossentropy
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model.compile(optimizer=optimizer, loss=loss)


from transformers.keras_callbacks import KerasMetricCallback, PushToHubCallback
# 定义验证指标
metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_eval_dataset)
# 推送到huggface回调函数
push_to_hub_callback = PushToHubCallback(output_dir="food_classifier", 
                                         tokenizer=image_processor, 
                                         save_strategy="no")
callbacks = [metric_callback, push_to_hub_callback]

# 开始训练
model.fit(tf_train_dataset, validation_data=tf_eval_dataset, epochs=num_epochs, callbacks=callbacks)

四、模型测试

 这里使用微调好的模型在food101上找一张验证图像进行简单的验证测试,具体代码如下:

# 验证food中验证集的某一张图像
ds = load_dataset("food101", split="validation[-5:-1]")
image = ds["image"][-1]

# visualize image
import matplotlib.pyplot as plt
plt.imshow(image)
plt.axis('off')  
plt.show()

  测试图像如下所示:
在这里插入图片描述


from transformers import pipeline
# initialize classifier instance
classifier = pipeline("image-classification", model="my_awesome_food_model")
classifier(image)

from transformers import AutoImageProcessor
import torch
# load pre-trained image processor
image_processor = AutoImageProcessor.from_pretrained("my_awesome_food_model")
inputs = image_processor(image, return_tensors="pt")


from transformers import AutoModelForImageClassification
# laod pre-trained model
model = AutoModelForImageClassification.from_pretrained("my_awesome_food_model")
with torch.no_grad():
    logits = model(**inputs).logits

# 输出测试结果
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])

  输出结果如下所示:

Device set to use cuda:0
[{'label': 'ramen', 'score': 0.9517934918403625},
 {'label': 'bruschetta', 'score': 0.7566707730293274},
 {'label': 'hamburger', 'score': 0.7004948854446411},
 {'label': 'chicken_wings', 'score': 0.6275856494903564},
 {'label': 'prime_rib', 'score': 0.5991673469543457}]

  预测结果为:ramen

  释义:“ramen”一词源于日语“ラーメン”,是“拉面”的意思。它进一步追溯至汉语“拉面”,是一种起源于中国、流行于日本及其他东亚地区的面条食品。在日本,拉面通常由小麦面粉制成的面条,搭配肉汤和各种配料,如叉烧、鸡蛋、蔬菜等。

五、参考

[1] https://huggingface.co/docs/transformers/main/tasks/image_classification

[2] https://github.com/huggingface/transformers/blob/main/docs/source/en/installation.md

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/962892.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

c语言进阶(简单的函数 数组 指针 预处理 文件 结构体)

c语言补充 格式 void函数头 {} 中的是函数体 sum函数名 &#xff08;&#xff09; 参数表 #include <stdio.h>void sum(int begin, int end) {int i;int sum 0;for (i begin ; i < end ; i) {sum i;}printf("%d到%d的和是%d\n", begin, end, sum); …

比较器使用

1 比较器是什么 2 使用比较器的注意事项 输出是否要接上拉电阻 2.1 推挽结构&#xff08;可以不用加上拉电阻&#xff09; 2.2 开漏结构&#xff08;OD,OC&#xff09;需要加上拉电阻 2.3 LM393 2.4 推挽和开漏 3 比较器的注意事项 输出是否要接上拉电阻 3.1 比较器内部有吸…

爬虫基础(三)Session和Cookie讲解

目录 一、前备知识点 &#xff08;1&#xff09;静态网页 &#xff08;2&#xff09;动态网页 &#xff08;3&#xff09;无状态HTTP 二、Session和Cookie 三、Session 四、Cookie &#xff08;1&#xff09;维持过程 &#xff08;2&#xff09;结构 正式开始说 Sessi…

项目升级Sass版本或升级Element Plus版本遇到的问题

项目升级Sass版本或升级Element Plus版本遇到的问题 如果项目有需求需要用到高版本的Element Plus组件&#xff0c;则需要升级相对应的sass版本&#xff0c;Element 文档中有提示&#xff0c;2.8.5及以后得版本&#xff0c;sass最低支持的版本为1.79.0&#xff0c;所升级sass、…

数据结构-Stack和栈

1.栈 1.1什么是栈 栈是一种特殊的线性表&#xff0c;只允许在固定的一段进行插入和删除操作&#xff0c;进行插入和删除操作的一段称为栈顶&#xff0c;另一端称为栈底。 栈中的数据元素遵顼后进先出LIFO&#xff08;Last In First Out&#xff09;的原则&#xff0c;就像一…

stm32硬件实现与w25qxx通信

使用的型号为stm32f103c8t6与w25q64。 STM32CubeMX配置与引脚衔接 根据stm32f103c8t6引脚手册&#xff0c;采用B12-B15四个引脚与W25Q64连接&#xff0c;实现SPI通信。 W25Q64SCK&#xff08;CLK&#xff09;PB13MOSI&#xff08;DI&#xff09;PB15MISO(DO)PB14CS&#xff08…

04树 + 堆 + 优先队列 + 图(D1_树(D8_B*树(B*)))

目录 一、基本介绍 二、相同思想和策略 三、不同的方式的磁盘空间利用 四、知识小结 一、基本介绍 B*树是Btree的变体&#xff0c;在B树的基础上(所有的叶子结点中包含了全部关键字的信息&#xff0c;及指向含有 这些关键字记录的指针)&#xff0c; B*树中非根和非叶子结…

Hot100之哈希

1两数之和 题目 思路解析 解法1--两次循环 解法2--哈希表一次循环 代码 解法1--两次循环 class Solution {public int[] twoSum(int[] nums, int target) {int nums1[] new int[2];int length nums.length;for (int i 0; i < length; i) {for (int j i 1; j < …

Autosar-以太网是怎么运行的?(原理部分)

写在前面&#xff1a; 入行一段时间了&#xff0c;基于个人理解整理一些东西&#xff0c;如有错误&#xff0c;欢迎各位大佬评论区指正&#xff01;&#xff01;&#xff01; 1.TCP/IP协议详解 TCP/IP协议包含了一系列的协议&#xff0c;也叫TCP/IP协议族&#xff08;TCP/IP P…

2025年大数据毕业设计选题推荐:数据分析与可视化 数据挖掘

目录 前言 毕设选题 开题指导建议 更多精选选题 选题帮助 最后 前言 大家好,这里是海浪学长毕设专题! 大四是整个大学期间最忙碌的时光&#xff0c;一边要忙着准备考研、考公、考教资或者实习为毕业后面临的升学就业做准备,一边要为毕业设计耗费大量精力。学长给大家整理…

如何在vs2022中处理python下无法输出中文问题

1.如何在vs2022中处理python下无法输出中文问题 进入vs界面时---工具菜单---自定义----在自定义窗口下选中”命令”页面----在菜单栏内----选择文件----再点击添加命令----左侧栏下滑找到文件-----在右侧往下拉找到并点击高级保存选项----再点击确定。 此时VS工程页面上会出现…

SSRF 漏洞利用 Redis 实战全解析:原理、攻击与防范

目录 前言 SSRF 漏洞深度剖析 Redis&#xff1a;强大的内存数据库 Redis 产生漏洞的原因 SSRF 漏洞利用 Redis 实战步骤 准备环境 下载安装 Redis 配置漏洞环境 启动 Redis 攻击机远程连接 Redis 利用 Redis 写 Webshell 防范措施 前言 在网络安全领域&#xff0…

UniApp开发的微信小程序主包过大问题及解决方案 编译小程序时Node-modules被打入主包

欢迎关注 『开发必备』 专栏,专注于解决你在开发过程中遇到的各种问题,帮你快速找到解决方案,节省大量调试时间。内容持续更新中,保证每篇都值得收藏! UniApp开发的微信小程序主包过大问题及解决方案 在使用UniApp开发微信小程序时,很多开发者都会遇到一个问题:打包后,…

Diffusion--人工智能领域的革命性技术

在人工智能领域&#xff0c;“diffusion”一词通常指的是“扩散模型”&#xff08;Diffusion Models&#xff09;&#xff0c;其全称为“Denoising Diffusion Probabilistic Models”&#xff08;DDPMs&#xff09;。扩散模型是一类生成式模型&#xff0c;它通过逐步去噪的方式…

升级到Mac15.1后pod install报错

升级Mac后&#xff0c;Flutter项目里的ios项目运行 pod install报错&#xff0c; 遇到这种问题&#xff0c;不要着急去百度&#xff0c;大概看一下报错信息&#xff0c;每个人遇到的问题都不一样。 别人的解决方法并不一定适合你&#xff1b; 下面是报错信息&#xff1a; #…

基于 oneM2M 标准的空气质量监测系统的互操作性

论文标题 英文标题&#xff1a; Interoperability of Air Quality Monitoring Systems through the oneM2M Standard 中文标题&#xff1a; 基于 oneM2M 标准的空气质量监测系统的互操作性 作者信息 Jonnar Danielle Diosana, Gabriel Angelo Limlingan, Danielle Bryan Sor…

利用Muduo库实现简单且健壮的Echo服务器

一、muduo网络库主要提供了两个类&#xff1a; TcpServer&#xff1a;用于编写服务器程序 TcpClient&#xff1a;用于编写客户端程序 二、三个重要的链接库&#xff1a; libmuduo_net、libmuduo_base、libpthread 三、muduo库底层就是epoll线程池&#xff0c;其好处是…

四.3 Redis 五大数据类型/结构的详细说明/详细使用( hash 哈希表数据类型详解和使用)

四.3 Redis 五大数据类型/结构的详细说明/详细使用&#xff08; hash 哈希表数据类型详解和使用&#xff09; 文章目录 四.3 Redis 五大数据类型/结构的详细说明/详细使用&#xff08; hash 哈希表数据类型详解和使用&#xff09;2.hash 哈希表常用指令(详细讲解说明)2.1 hset …

苍穹外卖第一天

角色分工 技术选型 pojo子模块 nginx反向代理 MD5密码加密

动态规划DP 背包问题 完全背包问题(题目分析+C++完整代码)

概览检索 动态规划DP 概览&#xff08;点击链接跳转&#xff09; 动态规划DP 背包问题 概览&#xff08;点击链接跳转&#xff09; 完全背包问题 原题链接 AcWiing 3. 完全背包问题 题目描述 有 N种物品和一个容量是 V的背包&#xff0c;每种物品都有无限件可用。 第 i种物…