微调 Florence-2 - 微软的尖端视觉语言模型

Florence-2 是微软于 2024 年 6 月发布的一个基础视觉语言模型。该模型极具吸引力,因为它尺寸很小 (0.2B 及 0.7B) 且在各种计算机视觉和视觉语言任务上表现出色。

Florence 开箱即用支持多种类型的任务,包括: 看图说话、目标检测、OCR 等等。虽然覆盖面很广,但仍有可能你的任务或领域不在此列,也有可能你希望针对自己的任务更好地控制模型输出。此时,你就需要微调了!

本文,我们展示了一个在 DocVQA 上微调 Florence 的示例。尽管原文宣称 Florence 2 支持视觉问答 (VQA) 任务,但最终发布的模型并未包含 VQA 功能。因此,我们正好拿这个任务练练手,看看我们能做点什么!

预训练细节与模型架构

98ccecd0587f449e3674afd0908280f1.png
Florence-2 架构

无论执行什么样的计算机视觉任务,Florence-2 都会将其建模为序列到序列的任务。Florence-2 以图像和文本作为输入,并输出文本。模型结构比较简单: 用 DaViT 视觉编码器将图像转换为视觉嵌入,并用 BERT 将文本提示转换为文本和位置嵌入; 然后,生成的嵌入由标准编码器 - 解码器 transformer 架构进行处理,最终生成文本和位置词元。Florence-2 的优势并非源自其架构,而是源自海量的预训练数据集。作者指出,市面上领先的计算机视觉数据集通常所含信息有限 - WIT 仅有图文对,SA-1B仅有图像及相关分割掩码。因此,他们决定构建一个新的 FLD-5B 数据集,其中的每个图像都包含最广泛的信息 - 目标框、掩码、描述文本及标签。在创建数据集时,很大程度采用了自动化的过程,作者使用现成的专门任务模型,并用一组启发式规则及质检过程来清理所获得的结果。最终生成的用于预训练 Florence-2 模型的新数据集中包含了 1.26 亿张图像、超过 50 亿个标注。

SA-1Bhttps://ai.meta.com/datasets/segment-anything/

VQA 上的原始性能

我们尝试了各种方法来微调模型以使其适配 VQA (视觉问答) 任务的响应方式。迄今为止,我们发现最有效方法将其建模为图像区域描述任务,尽管其并不完全等同于 VQA 任务。看图说话任务虽然可以输出图像的描述性信息,但其不允许直接输入问题。

我们还测试了几个“不支持”的提示,例如 “<VQA>”、“<vqa>” 以及 “<Visual question answering>”。不幸的是,这些尝试的产生的结果都不可用。

微调后在 DocVQA 上的性能

我们使用 DocVQA 数据集的标准指标Levenshtein 相似度来测量性能。微调前,模型在验证集上的输出与标注的相似度为 0,因为模型输出与标注差异不小。对训练集进行 7 个 epoch 的微调后,验证集上的相似度得分提高到了 57.0。

Levenshtein 相似度https://en.wikipedia.org/wiki/Levenshtein_distance

我们创建了一个🤗 空间以演示微调后的模型。虽然该模型在 DocVQA 上表现良好,但在一般文档理解方面还有改进的空间。但我们仍然认为,它成功地完成了任务,展示了 Florence-2 对下游任务进行微调的潜力。我们建议大家使用The Cauldron数据集对 Florence-2 进行微调,大家可以在我们的 GitHub 页面上找到必要的代码。

  • 🤗 空间https://hf.co/spaces/andito/Florence-2-DocVQA

  • The Cauldronhttps://hf.co/datasets/HuggingFaceM4/the_cauldron

  • 我们的 GitHub 页面https://github.com/andimarafioti/florence2-finetuning

下图给出了微调前后的推理结果对比。你还可以至此处亲自试用模型。

模型试用地址https://hf.co/spaces/andito/Florence-2-DocVQA

266edf3c8245404862cb97421a8aa420.png
微调前后的结果

微调细节

由原文我们可以知道,基础模型在预训练时使用的 batch size 为 2048,大模型在预训练时使用的 batch size 为 3072。另外原文还说: 与冻结图像编码器相比,使用未冻结的图像编码器进行微调能带来性能改进。

我们在低资源的情况下进行了多组实验,以探索模型如何在更受限的条件下进行微调。我们冻结了视觉编码器,并在Colab的分别使用单张 A100 GPU (batch size 6) 、单张 T4 (batch size 1) 顺利完成微调。

Colab 链接https://colab.research.google.com/drive/1hKDrJ5AH_o7I95PtZ9__VlCTNAo1Gjpf?usp=sharing

与此同时,我们还对更多资源的情况进行了实验,以 batch size 64 对整个模型进行了微调。在配备 8 张 H100 GPU 的集群上该训练过程花费了 70 分钟。你可以在这里找到我们训得的模型。

模型地址https://hf.co/HuggingFaceM4/Florence-2-DocVQA

我们都发现 1e-6 的小学习率适合上述所有训练情形。如果学习率变大,模型将很快过拟合。

遛代码

如果你想复现我们的结果,可以在此处找到我们的 Colab 微调笔记本。下面,我们遛一遍在DocVQA上微调Florence-2-base-ft模型。

  • Colab 地址https://colab.research.google.com/drive/1hKDrJ5AH_o7I95PtZ9__VlCTNAo1Gjpf?usp=sharing

  • DocVQAhttps://hf.co/datasets/HuggingFaceM4/DocumentVQA

  • Florence-2-base-fthttps://hf.co/microsoft/Florence-2-base-ft

我们从安装依赖项开始。

!pip install -q datasets flash_attn timm einops

接着,从 Hugging Face Hub 加载 DocVQA 数据集。

import torch
from datasets import load_dataset

data = load_dataset("HuggingFaceM4/DocumentVQA")

我们可以使用 transformers 库中的 AutoModelForCausalLMAutoProcessor 类来加载模型和处理器,并设 trust_remote_code=True ,因为该模型尚未原生集成到 transformers 中,因此需要使用自定义代码。我们还会冻结视觉编码器,以降低微调成本。

from transformers import AutoModelForCausalLM, AutoProcessor
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-base-ft",
    trust_remote_code=True,
    revision='refs/pr/6'
).to(device)
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base-ft",
    trust_remote_code=True, revision='refs/pr/6')

for param in model.vision_tower.parameters():
  param.is_trainable = False

现在开始微调模型!我们构建一个训练 PyTorch 数据集,并为数据集中的每个问题添加 <DocVQA> 前缀。

import torch from torch.utils.data import Dataset

class DocVQADataset(Dataset):

    def __init__(self, data):
        self.data = data
        
    def __len__(self):
        return len(self.data)
        
    def __getitem__(self, idx):
        example = self.data[idx]
        question = "<DocVQA>" + example['question']
        first_answer = example['answers'][0]
        image = example['image'].convert("RGB")
        return question, first_answer, image

接着,构建数据整理器,从数据集样本构建训练 batch,以用于训练。在 40GB 内存的 A100 中,batch size 可设至 6。如果你在 T4 上进行训练,batch size 就只能是 1。

import os
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AdamW, get_scheduler

def collate_fn(batch):
    questions, answers, images = zip(*batch)
    inputs = processor(text=list(questions), images=list(images), return_tensors="pt", padding=True).to(device)
    return inputs, answers

train_dataset = DocVQADataset(data['train'])
val_dataset = DocVQADataset(data['validation'])
batch_size = 6
num_workers = 0

train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          collate_fn=collate_fn, num_workers=num_workers, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size,
                          collate_fn=collate_fn, num_workers=num_workers)

开始训练模型:

epochs = 7
optimizer = AdamW(model.parameters(), lr=1e-6)
num_training_steps = epochs * len(train_loader)

lr_scheduler = get_scheduler(name="linear", optimizer=optimizer,
                              num_warmup_steps=0, num_training_steps=num_training_steps,)

for epoch in range(epochs):
    model.train()
    train_loss = 0
    i = -1
    for inputs, answers in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{epochs}"):
        i += 1
        input_ids = inputs["input_ids"]
        pixel_values = inputs["pixel_values"]
        labels = processor.tokenizer(text=answers, return_tensors="pt", padding=True, return_token_type_ids=False).input_ids.to(device)
        outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        train_loss += loss.item()
    avg_train_loss = train_loss / len(train_loader)
    print(f"Average Training Loss: {avg_train_loss}")

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}/{epochs}"):
            inputs, answers = batch
            input_ids = inputs["input_ids"]
            pixel_values = inputs["pixel_values"]
            labels = processor.tokenizer(text=answers, return_tensors="pt", padding=True, return_token_type_ids=False).input_ids.to(device)
            outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=labels)
            loss = outputs.loss
            val_loss += loss.item()

      print(val_loss / len(val_loader))

你可以分别对模型和处理器调用 save_pretrained() 以保存它们。微调后的模型在此处,你还可以在此处找到其演示。

  • 模型链接https://hf.co/HuggingFaceM4/Florence-2-DocVQA

  • 示例地址https://hf.co/spaces/andito/Florence-2-DocVQA

    2cc9b35547030333924ea6481df35f9a.png  

    演示示例

总结

本文,我们展示了如何有效地针对自定义数据集微调 Florence-2,以在短时间内在全新任务上取得令人眼前一亮的性能。对于那些希望在设备上或在生产环境中经济高效地部署小模型的人来说,该做法特别有价值。我们鼓励开源社区利用这个微调教程,探索 Florence-2 在各种新任务中的巨大潜力!我们迫不及待地想在 🤗 Hub 上看到你的模型!

有用资源

  • 视觉语言模型详解https://hf.co/blog/zh/vlms

  • 微调 Colabhttps://colab.research.google.com/drive/1hKDrJ5AH_o7I95PtZ9__VlCTNAo1Gjpf?usp=sharing

  • 微调 Github 代码库https://github.com/andimarafioti/florence2-finetuning

  • Florence-2 推理 Notebookhttps://hf.co/microsoft/Florence-2-large/blob/main/sample_inference.ipynb

  • Florence-2 DocVQA 演示https://hf.co/spaces/andito/Florence-2-DocVQA

  • Florence-2 演示https://hf.co/spaces/gokaygo

感谢 Pedro Cuenca 对本文的审阅。


英文原文: https://hf.co/blog/finetune-florence2

原文作者: Andres Marafioti,Merve Noyan,Piotr Skalski

译者: Matrix Yao (姚伟峰),英特尔深度学习工程师,工作方向为 transformer-family 模型在各模态数据上的应用及大规模模型的训练推理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/802300.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

LRC软件、Adobe Lightroom Classic软件多版本下载+LRC教程

简介&#xff1a; Adobe Lightroom Classic&#xff08;简称LR&#xff09;是Adobe Creative Cloud大家庭中的一款专业的图片管理和编辑工具&#xff0c;用于专业摄影师、摄影爱好者以及所有不断优化数码影像的人等。其目标是以丰富的功能提供高效、一致的体验&#xff0c;帮助…

php基础: 三角形

包含&#xff1a;左三角、左上三角、右三角、右上三角、等腰三角、倒等腰三角。注意空格的数量&#xff0c;因为*号后面加了空格 /*** * 左三角形* param $n* return void*/ function triangleLeft($n){echo <pre>;for ($i 1; $i < $n; $i) {for ($j 1; $j < $i…

对服务器进行基本了解(二)

目录 一. 云服务器数据库 1.查看MYSQL版本 2.查看mysql的运行状态 3.运行mysql 4. 进入mysql的用户 5. 更改用户密码 6. 查找mysql端口号 7. 创建一个数据库 8. 查看用户 9. 查看数据库 10. 显示数据库的表 11. 修改用户的host 12. 对用户赋权 13. 开放指定端…

java.lang.IllegalArgumentException: Illegal character in path at index 40解决方案

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。喜欢通过博客创作的方式对所学的…

数据库内核研发学习之路(三)创建postgres内置函数

本章之前已经讲明白了我们的postgres如何进行编译安装&#xff0c;这是很重要的一步&#xff0c;接下来就是学会对postgres进行小的改动&#xff0c;然后保证依然能够顺利编译安装运行&#xff01; 本章续讲内容如何创建一个内置函数。 1、内置函数和用户自定义函数的区别 熟…

基于Python+Flask+SQLite的豆瓣电影可视化系统

FlaskMySQLEcharts 基于PythonFlaskSQLite的豆瓣电影可视化系统 Echarts 不支持登录注册&#xff0c;并且信息存储在数据库中 不含爬虫代码&#xff0c;或爬虫代码已失效 简介 基于PythonFlaskMySQL的豆瓣电影可视化系统&#xff0c;采用Echart构建图表&#xff0c;支持自定…

【数据结构】二叉树全攻略,从实现到应用详解

​ &#x1f48e;所属专栏&#xff1a;数据结构与算法学习 &#x1f48e; 欢迎大家互三&#xff1a;2的n次方_ ​ &#x1f341;1. 树形结构的介绍 树是一种非线性的数据结构&#xff0c;它是由n&#xff08;n>0&#xff09;个有限结点组成一个具有层次关系的集合。把它叫做…

动手学深度学习6.3 填充和步幅-笔记练习(PyTorch)

以下内容为结合李沐老师的课程和教材补充的学习笔记&#xff0c;以及对课后练习的一些思考&#xff0c;自留回顾&#xff0c;也供同学之人交流参考。 本节课程地址&#xff1a;填充和步幅_哔哩哔哩_bilibili 代码实现_哔哩哔哩_bilibili 本节教材地址&#xff1a;6.3. 填充和…

旗晟巡检机器人的应用场景有哪些?

巡检机器人作为现代科技的杰出成果&#xff0c;已广泛应用于各个关键场景。从危险的工业现场到至关重要的基础设施&#xff0c;它们的身影无处不在。它们以精准、高效、不知疲倦的特性&#xff0c;担当起保障生产、守护安全的重任&#xff0c;为行业发展注入新的活力。那么&…

【学习笔记】无人机系统(UAS)的连接、识别和跟踪(一)-3GPP TS 23.256 技术规范概述

3GPP TS 23.256 技术规范&#xff0c;主要定义了3GPP系统对无人机&#xff08;UAV&#xff09;的连接性、身份识别、跟踪及A2X&#xff08;Aircraft-to-Everything&#xff09;服务的支持。 3GPP TS 23.256 技术规范&#xff1a; 以下是文档的核心内容总结&#xff1a; UAV系…

深入理解PostgreSql域类型(Domain),灵活定义数据约束,让表结构设计更加严谨

在PostgreSQL中&#xff0c;域&#xff08;Domain&#xff09;是一种用户定义的数据类型&#xff0c;它基于系统内已存在的数据类型&#xff0c;并可以附加约束条件。使用域可以增强数据的完整性和一致性&#xff0c;因为它允许开发者对特定列设定更为具体的规则&#xff0c;比…

Mysql缓存调优的基本知识(附Demo)

目录 前言1. 配置2. 缓存3. 策略 前言 基本的知识推荐阅读&#xff1a; java框架 零基础从入门到精通的学习路线 附开源项目面经等&#xff08;超全&#xff09;Mysql优化高级篇&#xff08;全&#xff09;Mysql底层原理详细剖析常见面试题&#xff08;全&#xff09; MySQL…

集群架构-web服务器(接入负载均衡+数据库+会话保持redis)--15454核心配置详解

紧接着前面的集群架构深化—中小型公司&#xff08;拓展到大型公司业务&#xff09;–下面图简单回顾一下之前做的及故障核心知识总结&#xff08;等后期完全整理后&#xff0c;上传资源希望能帮大家&#xff09; web集群架构-接入负载均衡部署web02服务器等 web集群-搭建web0…

Abaqus基于CT断层扫描的三维重建插件CT2Model 3D

插件介绍 AbyssFish CT2Model 3D V1.0 插件可将采用X射线等方法获取的计算机断层扫描&#xff08;CT&#xff09;图像在Abaqus有限元软件内进行三维重建&#xff0c;进而高效获取可供模拟分析的有限元模型。插件可用于医学影像三维重构、混凝土细观三维重建、岩心数字化等领域…

数据结构-C语言-排序(2)

代码位置&#xff1a;test-c-2024: 对C语言习题代码的练习 (gitee.com) 一、前言&#xff1a; 1.1-排序定义&#xff1a; 排序就是将一组杂乱无章的数据按照一定的规律&#xff08;升序或降序&#xff09;组织起来。(注&#xff1a;我们这里的排序采用的都为升序) 1.2-排序分…

2-36 基于matlab的流行学习算法程序

基于matlab的流行学习算法程序。通过GUI的形式将MDS、PCA、ISOMAP、LLE、Hessian LLE、Laplacian、Dissusion MAP、LTSA八种算法。程序以可视化界面进行展示&#xff0c;可直接调用进行分析。多种案例举例说明八种方法优劣&#xff0c;并且可设置自己数据进行分析。程序已调通&…

STM32智能工业自动化监控系统教程

目录 引言环境准备智能工业自动化监控系统基础代码实现&#xff1a;实现智能工业自动化监控系统 4.1 数据采集模块 4.2 数据处理与控制模块 4.3 通信与网络系统实现 4.4 用户界面与数据可视化应用场景&#xff1a;工业自动化与管理问题解决方案与优化收尾与总结 1. 引言 智能…

百度人脸识别Windows C++离线sdk C#接入

百度人脸识别Windows C离线sdk C#接入 目录 说明 设计背景 • 场景特点&#xff1a; • 客户特点&#xff1a; • 核心需求&#xff1a; SDK 包结构 效果 代码 说明 自己根据SDK封装了动态库&#xff0c;然后C#调用。 功能接口 设计背景 • 场景特点&#xff1a; -…

【漏洞复现】深信服 行为感知系统 日志中心 c.php 远程命令执行

免责声明&#xff1a; 本文内容旨在提供有关特定漏洞或安全漏洞的信息&#xff0c;以帮助用户更好地了解可能存在的风险。公布此类信息的目的在于促进网络安全意识和技术进步&#xff0c;并非出于任何恶意目的。阅读者应该明白&#xff0c;在利用本文提到的漏洞信息或进行相关测…

PTA - 嵌套列表求和

使用递归函数对嵌套列表求和 函数接口定义&#xff1a; def sumtree(L) L是输入的嵌套列表。 裁判测试程序样例&#xff1a; /* 请在这里填写答案 */L eval(input()) print(sumtree(L)) # 调用函数 输入样例&#xff1a; 在这里给出一组输入。例如&#xff1a; [1,[2…