CLIP模型微调简明指南

CLIP 等多模态模型通过将图像等复杂对象与易于理解、生成和解析的文本描述联系起来,开辟了新的 AI 用例。但是,像 CLIP 这样的现成模型可能无法代表特定领域中常见的数据,在这种情况下,可能需要进行微调以使模型适应该领域。

这篇文章展示了如何根据《纽约客》杂志的卡通图像和这些卡通的笑话标题微调 CLIP 模型。它基于 capcon,这是一个与《纽约客》卡通比赛相关的各种任务的数据集。其中一项任务是拍摄一张卡通图像并从可能的标题列表中预测合适的标题。让我们看看如何为这项任务微调 CLIP。

NSDT工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 - REVIT导出3D模型插件 - 3D模型语义搜索引擎 - AI模型在线查看 - Three.js虚拟轴心开发包 - 3D模型在线减面 - STL模型在线切割 

1、数据

数据托管在 gs://datachain-demo/newyorker_caption_contest 上并公开提供,它包含两个部分:

  • images:图像,一个 JPEG 文件文件夹,每个文件代表一张卡通图像。
  • new_yorker_meta.parquet:包含图像元数据的 parquet 文件,包括图像的多种标题选择和正确的标题选择。

为了处理这些数据,我们将使用开源库 datachain,它有助于将此类非结构化数据整理成更结构化的格式(免责声明:我帮助开发了 datachain)。本文中使用的所有代码都可以在 GitHub 上的 Jupyter Notebook 中找到,或者你可以在 Colab 中运行它。

首先,我们从源中读取图像和元数据,然后根据文件名(在元数据中作为一列提供)将它们连接起来:

from datachain import C, DataChain
from datachain.sql.functions import path

img_dc = DataChain.from_storage("gs://datachain-demo/newyorker_caption_contest/images", type="image", anon=True)
meta_dc = DataChain.from_parquet("gs://datachain-demo/newyorker_caption_contest/new_yorker_meta.parquet")
dc = img_dc.mutate(filename=path.name(C("file.path"))).merge(meta_dc, on="filename")

代码首先从目录中的图像创建一个数据集 img_dc,存储每个文件的基本信息,稍后我们将使用这些信息读取图像。然后,它从元数据的 parquet 文件中创建数据集 meta_dc。最后,它根据图像文件名合并这两个数据集。

img_dc 包含一个 file.path 列,其中包含文件的完整路径,而 img_dc.mutate(filename=path.name(C("file.path"))) 仅提取该路径的最后一部分,该部分与 meta_dc 中 filename 列的内容相匹配。合并后的 dc 数据集包含每个图像的文件信息和元数据。

我们可以通过像这样过滤和收集数据来查看数据样本:

sample = dc.filter(C("file.path").endswith("/371.jpeg")).limit(1)
sample_results = list(sample.collect("file", "caption_choices", "label"))

这会将数据限制为以 /371.jpeg 结尾的图像,并仅收集“file”、“caption_choices”、“label”列。结果输出包括一个 ImageFile(见下文)、一个可能的标题列表和一个正确标题字母选择的标签。由于每个图像有多行,并且标题选择不同,因此您最终可能会得到略有不同的结果。

[(ImageFile(source='gs://datachain-demo', path='newyorker_caption_contest/images/371.jpeg', size=25555, version='1719848719616822', etag='CLaWgOCXhocDEAE=', is_latest=True, last_modified=datetime.datetime(2024, 7, 1, 15, 45, 19, 669000, tzinfo=datetime.timezone.utc), location=None, vtype=''),
  ["I feel like we've gotten a little soft, Lex.",
   "Hold on, the Senate Committee on Women's Health is getting out.",
   "I know a specialist, but he's in prison.",
   'Six rounds. Nine lives. You do the math.',
   'Growth has exceeded our projections.'],
  'D')]

我们可以使用 ImageFile 对象的 read() 方法从中获取图像本身,如果你按照笔记本中的说明操作,您可以亲眼看到。在这个示例中,我们有一幅老鼠用枪指着猫的卡通画,正确的标题是选项 D,上面写着“六发子弹。九条命。你算算吧。”

2、应用基础 CLIP 模型

我们可以将 CLIP 应用于这些数据,以预测每个标题的可能性。这类似于 CLIP 的基本架构,它使用对比学习来获取图像并从一批文本标题中辨别出最可能的标题(反之亦然)。

在训练期间,CLIP 将批量图像-文本对作为输入,每个图像都映射到其文本标题。对于每个批次,CLIP 计算每个图像与批次中每个文本的余弦相似度,这样它不仅具有匹配的相似度,还具有每个不匹配的图像-文本对的相似度(见下图)。

然后,它将其视为分类问题,其中匹配被视为正确标签,不匹配被视为不正确的标签。在推理过程中,这可以用作零样本预测器,方法是输入图像和一批标题,CLIP 将为此返回每个标题的概率。

要深入了解 CLIP,请参阅 OpenAI 原始帖子,或者 Chip Huyen 在此处对其工作原理进行了很好的总结。

对于卡通数据集,我们可以输入样本图像和标题选项,以返回每个选项正确匹配的概率。代码如下:

import clip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
image = example[0].read()
image = preprocess(image).unsqueeze(0).to(device)
text = clip.tokenize(example[1]).to(device)
logits_per_image, logits_per_text = model(image, text)
logits_per_image.softmax(dim=1)[0]

首先,我们将 ViT-B/32 预训练模型和图像预处理器加载到设备上。然后,我们将图像转换为预期的张量输入,并对文本标题进行标记以执行相同操作。接下来,我们在这些转换后的输入上运行模型,以获取图像与每个文本的 logit 相似度分数,最后运行 softmax 函数以获取每个文本标题的相对概率。

输出显示 CLIP 已经可以自信地预测此示例的正确标题,因为标题 D(第四个标题)的概率为 0.9844(如果您自己尝试,您的示例中可能会有不同的标题选择,这可能会导致不同的结果):

tensor([0.0047, 0.0013, 0.0029, 0.9844, 0.0067], grad_fn=<SelectBackward0>)

3、创建训练数据集

现在我们知道如何应用 CLIP 来预测字幕,我们可以构建一个训练数据集来微调模型。让我们获取随机 10 幅图像的相似度分数(您可以将其增加到更大的尺寸,但在这里我们将保持较小尺寸,以便于在笔记本电脑 CPU 上快速跟进)。以下是执行此操作的代码:

from datachain.torch import clip_similarity_scores

train_dc = dc.shuffle().limit(10).save("newyorker_caption_contest_train")
train_dc = train_dc.map(
    func=lambda img_file, txt: clip_similarity_scores(img_file.read(), txt, model, preprocess, clip.tokenize, prob=True)[0],
    params=["file", "caption_choices"],
    output={"scores": list[float]}
)

首先,我们从数据集中随机抽取并保存 10 张图像。然后,我们使用 map() 方法将函数应用于每条记录,并将结果保存为新列。我们使用实用函数 clip_similarity_scores,该函数在一行中执行上一节中的步骤以获取字幕概率。`map()` 函数的输入由 params=["file", "caption_choices"] 定义,输出列由 output={"scores": list[float]} 定义。

对于训练,我们还需要正确字幕的基本事实,因此我们再次使用 map() 计算每条记录的正确字幕的索引,以及该字幕的 CLIP 概率,以便我们了解基线 CLIP 的表现如何:

import string

def label_ind(label):
    return string.ascii_uppercase.index(label)
def label_prob(scores, label_ind):
    return scores[label_ind]
train_dc = (
    train_dc.map(label_ind, params=["label"], output={"label_ind": int})
    .map(label_prob, params=["scores", "label_ind"], output={"label_prob": float})
)
train_dc = train_dc.save()

我们可以运行 train_dc.avg("label_prob") 来获取训练样本正确标题的平均概率。平均值将取决于训练数据集中的随机样本,但您应该看到比上面的样本图像低得多的值,因此其他图像似乎不太容易让基线 CLIP 正确预测。

4、微调

要微调 CLIP,我们需要创建一个 train() 函数来循环训练数据并更新模型:

def train(loader, model, optimizer, epochs=5):
    if device == "cuda":
        model = model.float()
    loss_func = torch.nn.CrossEntropyLoss()
    
    for epoch in range(epochs):
        total_loss = 0
        for images, texts, labels in loader:
            optimizer.zero_grad()
            batch_loss = 0
            for image, text, label in zip(images, texts, labels):
                image = image.to(device).unsqueeze(0)
                text = text.to(device)
                label = label.to(device).unsqueeze(0)
                logits_per_image, logits_per_text = model(image, text)
                batch_loss += loss_func(logits_per_image, label)
            batch_loss.backward()
            optimizer.step()
            batch_loss = batch_loss.item()
            total_loss += batch_loss
        print(f"loss for epoch {epoch}: {total_loss}")    

对于每对图像与文本标题的配对,该函数都会计算 logit 相似度得分,使用正确的标签索引应用损失函数,并执行反向传递以更新模型。

这与基本 CLIP 的工作方式非常相似,除了一个区别。基本 CLIP 期望每个批次都包含图像-文本对,其中每幅图像都有一个对应的文本,并且 CLIP 必须从批次中的其他样本中获取不正确的文本以进行对比学习(参见上图)。对于卡通数据集,每幅图像不仅已经具有相应的正确文本标题,而且还具有多个不正确的文本标题。因此,上面的函数不依赖批次中的其他样本进行对比学习,而是仅依赖于为该图像提供的文本标题选择。

要将训练数据输入此函数,我们需要生成一个 PyTorch 数据集和数据加载器,并将加载器与优化器一起传递给 train() 函数:

from torch.utils.data import DataLoader

ds = train_dc.select("file", "caption_choices", "label_ind").to_pytorch(
    transform=preprocess,
    tokenizer=clip.tokenize,
)
loader = DataLoader(ds, batch_size=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
train(loader, model, optimizer)

上面的代码选择了训练所需的列“file”、“caption_choices”、“label_ind”,然后使用 CLIP 预处理器和标记器调用 to_pytorch(),这将返回一个包含预处理后的图像张量、标记化文本和标签索引的 PyTorch IterableDataset。接下来,代码创建一个 PyTorch DataLoader 和优化器,并将它们传递给 train() 以开始训练。

由于我们使用了一个很小的数据集,我们可以很快看到模型适合样本,并且损失显著减少:

loss for epoch 0: 5.243085099384018
loss for epoch 1: 6.937912189641793e-05
loss for epoch 2: 0.0006402461804100312
loss for epoch 3: 0.0009484810252615716
loss for epoch 4: 0.00019728825191123178

这应该引起人们对过度拟合的警惕,但对于本练习来说,看到 train() 正在做我们期望的事情很有用:从训练数据集中学习正确的标题。我们可以通过使用微调模型计算训练数据中每张图片正确标题的预测概率来确认:

train_dc = train_dc.map(
    func=lambda img_file, txt: clip_similarity_scores(img_file.read(), txt, model, preprocess, clip.tokenize, prob=True)[0],
    params=["file", "caption_choices"],
    output={"scores_fine_tune": list[float]}
)


train_dc = train_dc.map(label_prob, params=["scores_fine_tune", "label_ind"], output={"label_prob_fine_tune": float})

上述代码与微调之前用于计算概率的代码相同。运行 train_dc.avg("label_prob_fine_tune") 输出平均预测概率 >0.99,因此看起来微调按预期工作。

5、结束语

这是一个人工示例,但希望可以让你了解如何微调 CLIP。为了以更稳健的方式解决预测正确标题的任务,你需要获取更大的样本,并根据训练期间未见过的图像和文本的保留样本进行评估。

尝试这样做时,你可能会发现 CLIP 在推广到标题预测问题方面表现不佳,这并不奇怪,因为 CLIP 是为了理解图像内容而不是理解笑话而构建的。CLIP 依赖​​于相对简单的文本编码器,可能值得探索用于该任务的不同文本编码器。

这超出了微调和这篇文章的范围,但现在你已经知道如何训练 CLIP,您可以尝试这个想法,或者提出自己的想法,了解如何将 CLIP 应用于你的多模态用例。


原文链接:CLIP微调简明教程 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/884288.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

8.使用 VSCode 过程中的英语积累 - Help 菜单(每一次重点积累 5 个单词)

前言 学习可以不局限于传统的书籍和课堂&#xff0c;各种生活的元素也都可以做为我们的学习对象&#xff0c;本文将利用 VSCode 页面上的各种英文元素来做英语的积累&#xff0c;如此做有 3 大利 这些软件在我们工作中是时时刻刻接触的&#xff0c;借此做英语积累再合适不过&a…

【Java】虚拟机(JVM)内存模型全解析

目录 一、运行时数据区域划分 版本的差异&#xff1a; 二、程序计数器 程序计数器主要作用 三、Java虚拟机 1. 虚拟机运行原理 2. 活动栈被弹出的方式 3. 虚拟机栈可能产生的错误 4. 虚拟机栈的大小 四、本地方法栈 五、堆 1. 堆区的组成&#xff1a;新生代老生代 …

计算机前沿技术-人工智能算法-大语言模型-最新论文阅读-2024-09-22

计算机前沿技术-人工智能算法-大语言模型-最新论文阅读-2024-09-22 引言: 全球最热销的国产游戏-《黑神话: 悟空》不仅给世界各地玩家们带来愉悦&#xff0c;而且对计算机人工智能研究也带来新的思考。在本期的论文速读中&#xff0c;我们带来一篇关于视觉语言模型&#xff0…

深度解析与解决方案:U盘有盘符但无法打开的困境

引言&#xff1a;U盘困境初现 在日常工作与生活中&#xff0c;U盘作为便携式存储设备&#xff0c;扮演着数据传输与备份的重要角色。然而&#xff0c;不少用户会遇到这样一个棘手问题&#xff1a;U盘在插入电脑后能够正常显示盘符&#xff0c;但尝试打开时却遭遇拒绝访问或提示…

运维,36岁,正在经历中年危机,零基础入门到精通,收藏这一篇就够了

我今年36岁&#xff0c;运维经理&#xff0c;985硕士毕业&#xff0c;目前正在经历中年危机&#xff0c;真的很焦虑&#xff0c;对未来充满担忧。不知道这样的日子还会持续多久&#xff0c;突然很想把这些年的经历记录下来&#xff0c;那就从今天开始吧。 先说一下我的中年危机…

华为LTC流程架构分享

文末附LTC流程管理PPT下载链接~ 前面笔者分享了华为LTC流程相关PPT&#xff0c;应读者需求&#xff0c;今天从架构角度进行再次与读者共同学习下LTC流程架构。 华为LTC流程架构是一个全面且集成的业务流程体系&#xff0c;从线索发现开始&#xff0c;直至收回现金&#xff0c…

浅谈Agent智能体

Agent智能体无疑是24年最为火爆的话题之一&#xff0c;那么什么是Agent智能体&#xff1f;有什么作用&#xff1f;为什么需要Agent智能体&#xff1f; 用下边一张图简单说明一下 每日进步一点点

气膜健身馆:提升运动体验与健康的理想选择—轻空间

近年来&#xff0c;气膜健身馆作为一种新兴的运动场所&#xff0c;正逐渐受到越来越多健身爱好者的青睐。这种独特的建筑形式不仅提供了良好的运动环境&#xff0c;更在健康和运动表现上展现出诸多优势。 优越的空气质量 气膜结构的核心技术通过内外气压差形成稳定的气膜&#…

C++ 9.27

作业&#xff1a; 将之前实现的顺序表、栈、队列都更改成模板类 Stack #include <iostream> using namespace std; template <typename T> class Stack { private: T* arr; // 存储栈元素的数组 int top; // 栈顶索引 int capacity; // 栈的…

【高频SQL基础50题】6-10

目录 1.上级经理已离职的公司员工 2.修复表中的名字 3. 寻找用户推荐人 4.产品销售分析 I 5.平均售价 1.上级经理已离职的公司员工 子查询。 先根据薪水大小查询&#xff0c;再根据manager_id查询该员工是否存在&#xff0c;最后做排序。 # Write your MySQL query st…

Proteus-7.8sp2安装

目录 一、D盘新建空文件夹&#xff0c;名为Proteus。 二、安装软件 三、破解 四、汉化 五、卸载软件 一、D盘新建空文件夹&#xff0c;名为Proteus。 二、安装软件 1.双击P7.8sp2.exe 2.next 三、破解 1.双击 Proteus Pro 7.8 SP2破解 1.0.exe 2. 升级 打开软件&#x…

网站建设中,营销型网站与普通网站有什么区别

营销型网站与普通网站在建站目的、交互设计以及结构优化等方面存在区别。以下是具体分析&#xff1a; 建站目的 营销型网站&#xff1a;以销售和转化为主要目标&#xff0c;通过专业的市场分析和策划来吸引潜在客户&#xff0c;并促使其采取购买行动。普通网站&#xff1a;通常…

8610 顺序查找

### 思路 1. **创建顺序表**&#xff1a;从输入中读取元素个数和元素值&#xff0c;构造顺序表。 2. **顺序查找**&#xff1a;在顺序表中依次查找关键字&#xff0c;找到则返回位置&#xff0c;否则返回0。 ### 伪代码 1. **创建顺序表**&#xff1a; - 动态分配存储空间。…

C. Cards Partition 【Codeforces Round 975 (Div. 2)】

C. Cards Partition 思路&#xff1a; 可以O(n)直接判断&#xff0c;牌组从大到小依次遍历即可。 不要用二分答案&#xff0c;因为答案不一定是单调的 代码: #include <bits/stdc.h> #define endl \n #define int long long #define pb push_back #define pii pair<…

Verilog基础:时序调度中的竞争(四)(描述时序逻辑时使用非阻塞赋值)

相关阅读 Verilog基础https://blog.csdn.net/weixin_45791458/category_12263729.html?spm1001.2014.3001.5482 作为一个硬件描述语言&#xff0c;Verilog HDL常常需要使用语句描述并行执行的电路&#xff0c;但其实在仿真器的底层&#xff0c;这些并行执行的语句是有先后顺序…

重头开始嵌入式第四十四天(硬件 ARM裸机开发)

目录 裸机开发 一、开发背景 二、开发特点 三、开发流程 四、应用领域 使用的软件硬件 软件&#xff1a;keil 硬件&#xff1a;三星S3C2440A JTAG 开发原理 ​编辑 开发步骤 ​编辑 点亮小灯 按键控制亮灭 裸机开发 ARM 裸机开发是指在没有操作系统的情况…

信号处理: Block Pending Handler 与 SIGKILL/SIGSTOP 实验

1. 信号处理机制的 “三张表” kill -l &#xff1a;前 31 个信号为系统标准信号。 block pending handler 三张表保存在每个进程的进程控制块 —— pcb 中&#xff0c;它们分别对应了某一信号的阻塞状态、待处理状态以及处理方式。 block &#xff1a;通过 sigset_t 类型实现&…

【补充】倒易点阵基本性质

&#xff08;1&#xff09;任意倒易矢量 r h k l ∗ h a ∗ k b ∗ l c ∗ \mathbf{r}_{hkl}^* h\mathbf{a^*} k\mathbf{b^*} l\mathbf{c^*} rhkl∗​ha∗kb∗lc∗必然垂直于正空间中的(hkl)晶面。 正空间中的(hkl)晶面的法向是[hkl]&#xff0c;和坐标轴的交点为A、B、…

Steam黑神话悟空禁止更新进入游戏的解决方案

首先打开该网站&#xff1a;https://steamdb.info/app/2358720/ 2358720即为游戏ID 网页下翻&#xff0c;找到更新历史&#xff1a;https://steamdb.info/app/2358720/history/ 然后在Steam的steamapps下&#xff0c;找到后缀为2358720的文件&#xff0c;右击记事本打开 将St…

解决银河麒麟V10向日葵远程连接断开问题

解决银河麒麟V10向日葵远程连接断开问题 方法一&#xff1a;重启系统方法二&#xff1a;执行xhost 命令 &#x1f496;The Begin&#x1f496;点点关注&#xff0c;收藏不迷路&#x1f496; 当你在银河麒麟桌面操作系统V10上使用向日葵进行远程连接时&#xff0c;如果遇到频繁断…