【LLM多模态】LLava模型结构和训练过程 | CLIP模型

note

  • CLIP使用了对比学习的方法,即通过正样本(匹配的图像-文本对)和负样本(不匹配的图像-文本对)来训练模型。在训练过程中,模型会尝试最大化正样本对的相似度(比如通过计算余弦相似度),同时最小化负样本对的相似度。
    • CLIP模型在zero-shot学习中表现强大,可以直接用于zero-shot推理,比如将猫的图片emb后,将猪狗猫等类的文本描述也分别emb,计算图片和类别emb之间的相似度,从而进行分类。
    • CLIP-ViT-L/14模型的14表示每个patch的分辨率为14X14,比如在224x224像素的图像上,总共有(224 / 14) x (224 / 14) = 16 x 16 = 256个patch。
  • LLaVA的模型结构非常简单,就是CLIP+LLM(Vicuna,LLaMA结构),利用Vison Encoder将图片转换为[N=1, grid_H x grid_W, hidden_dim]的feature map,然后接一个插值层Projection W,将图像特征和文本特征进行维度对齐。经过Projection后,得到[N=1, grid_H x grid_W=image_seqlen, emb_dim]。然后将 image token embedding和text token embedding合并到一起,作为语言模型的输入,生成描述的文本。
  • 与InstructBLIP或Qwen-VL在数亿甚至数十几亿的图像文本配对数据上训练的、专门设计的视觉重新采样器相比,LLaVA用的是最简单的LMM架构设计,只需要在600K个图像-文本对上,训练一个简单的完全连接映射层即可。

文章目录

  • note
  • CLIP模型
    • CLIP 损失函数
    • CLIP实践
  • LLava模型
    • 论文贡献
    • 模型结构
    • LLaVA两阶段训练
  • LLava 1.5模型
  • LLava 1.6模型
  • Reference

CLIP模型

链接:https://arxiv.org/pdf/2103.00020.pdf

CLIP模型是一个双塔结构,包括一个文本编码器Text Encoder和一个图像编码器Image Encoder。训练数据集的形式为(image, text),对于每个正确匹配的image和text,text是对image的一句正确描述。CLIP模型需要对(image, text)的数据对进行预测,即(image, text)匹配的为1,不匹配的为0。

  • Text Encoder: 对于每个句子, 将其编码成一个隐向量, T i T_i Ti 维度 ( 1 , 512 ) ; N (1,512) ; \mathrm{N} (1,512);N 个句子, 因此有 T 1 T_1 T1 T N T_N TN ,即[N, 512]
  • Image Encoder: 对于每张img, 将其编码成一个隐向量, l i l_i li 维度 ( 1 , 512 ) ; N (1,512) ; \mathrm{N} (1,512);N 张图, 因此有 l 1 l_1 l1 l N l_N lN, 即 [ N , 512 ] [\mathrm{N}, 512] [N,512]

在这里插入图片描述
由于Text Encoder和Image Encoder最后都是输出[N,512]的Tensor,因此可以很方便地计算images和texts两两之间的相似度。CLIP可以选在ResNet或ViT作为Backbone。实验表明,ViT的效果要好于ResNet。

CLIP 损失函数

CLIP采用对称损失函数,简单来说,就是对相似度矩阵,分别从行方向和列方向计算loss,最后取两者的平均。
在这里插入图片描述

  1. 图像到文本(Image-to-Text)

    • 对于每一个图像,模型尝试找出与之对应的文本描述。模型计算该图像特征向量与所有文本特征向量的相似度(通常使用点积),并通过softmax函数将这些相似度转换为概率分布。模型的目标是使得与当前图像真实对应的文本的概率最大化。这通常通过最小化交叉熵损失来实现,其中正样本是图像对应的真实文本。
  2. 文本到图像(Text-to-Image)

    • 对于每一个文本描述,模型尝试找出与之对应的图像。这个过程与图像到文本非常相似,但方向相反。模型计算该文本特征向量与所有图像特征向量的相似度,并通过softmax函数转换为概率分布。模型的目标是使得与当前文本真实对应的图像的概率最大化。这也是通过最小化交叉熵损失来实现,其中正样本是文本对应的真实图像。

数学表示和损失函数实现:
给定批量中有 N N N 个图像和文本对, 损失函数由两部分交叉嫡组成:
L = 1 2 N ( ∑ i = 1 N − log ⁡ e s i , i / τ ∑ j = 1 N e s i , j / τ + ∑ i = 1 N − log ⁡ e s i , i / τ ∑ j = 1 N e s j , i / τ ) L=\frac{1}{2 N}\left(\sum_{i=1}^N-\log \frac{e^{s_{i, i} / \tau}}{\sum_{j=1}^N e^{s_{i, j} / \tau}}+\sum_{i=1}^N-\log \frac{e^{s_{i, i} / \tau}}{\sum_{j=1}^N e^{s_{j, i} / \tau}}\right) L=2N1(i=1Nlogj=1Nesi,j/τesi,i/τ+i=1Nlogj=1Nesj,i/τesi,i/τ)

其中, s i , j s_{i, j} si,j 是图像 i i i 和文本 j j j 的特征向量的点积, τ \tau τ 是一个温度参数。

伪代码:

# image_encoder - ResNet or Vision Transformer
# text_encoder - CBOW or Text Transformer
# I[n, h, w, c] - minibatch of aligned images
# T[n, l] - minibatch of aligned texts
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_t, d_e] - learned proj of text to embed
# t - learned temperature parameter
# extract feature representations of each modality
I_f = image_encoder(I) #[n, d_i]
T_f = text_encoder(T) #[n, d_t]
# joint multimodal embedding [n, d_e]
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)
# scaled pairwise cosine similarities [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)
# symmetric loss function
labels = np.arange(n)
# 图像到文本的损失函数,第0维度即图片的行维度
loss_i = cross_entropy_loss(logits, labels, axis=0)
# 文本到图像的损失函数
loss_t = cross_entropy_loss(logits, labels, axis=1)
loss = (loss_i + loss_t)/2

CLIP实践

$ conda install --yes -c pytorch pytorch torchvision cudatoolkit
$ pip install ftfy regex tqdm
$ pip install git+https://github.com/openai/CLIP.git

import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    print("image_features shape:", image_features.shape) # [1, 512]
    text_features = model.encode_text(text)
    print("text_features shape:", text_features.shape) # [3, 512]
    
    logits_per_image, logits_per_text = model(image, text)
    print("logits_per_image shape:", logits_per_image.shape) # [1, 3]
    print("logits_per_text shape:", logits_per_text.shape) # [3, 1]
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]
print("      Label: {}".format(["a diagram", "a dog", "a cat"]))

LLava模型

论文:https://arxiv.org/pdf/2304.08485.pdf
项目:https://llava-vl.github.io/
在这里插入图片描述

论文贡献

  • 多模态指令数据。当下关键的挑战之一是缺乏视觉与语言组成的指令数据。本文提出了一个数据重组方式,使用 ChatGPT/GPT-4 将图像 - 文本对转换为适当的指令格式;
  • 大型多模态模型。研究者通过连接 CLIP 的开源视觉编码器和语言解码器 LLaMA,开发了一个大型多模态模型(LMM)—— LLaVA,并在生成的视觉 - 语言指令数据上进行端到端微调。实证研究验证了将生成的数据用于 LMM 进行 instruction-tuning 的有效性,并为构建遵循视觉 agent 的通用指令提供了较为实用的技巧。使用 GPT-4,本文在 Science QA 这个多模态推理数据集上实现了最先进的性能。
  • 开源。研究者向公众发布了以下资产:生成的多模式指令数据、用于数据生成和模型训练的代码库、模型检查点和可视化聊天演示。

模型结构

LLaVA模型的架构,是将一个预训练的视觉编码器(CLIP ViT-L/14)与一个大规模语言模型(Vicuna)连接在一起。

这两个模型通过一个简单的映射矩阵连接,这个矩阵负责将视觉和语言特征对齐或转换,以便在一个统一的空间内对它们进行操作。在多模态指令跟随数据集上,LLaVA表现出色,跟GPT-4相比,分数达到了85.1%。在Science QA上,LLaVA的准确率刷新了纪录,达到92.53%。

与InstructBLIP或Qwen-VL在数亿甚至数十几亿的图像文本配对数据上训练的、专门设计的视觉重新采样器相比,LLaVA用的是最简单的LMM架构设计,只需要在600K个图像-文本对上,训练一个简单的完全连接映射层即可。
在这里插入图片描述
对于输入图像 X_v,本文使用预训练的 CLIP 视觉编码器 ViT-L/14 进行处理,得到视觉特征 Z_v=g (X_v)。实验中使用的是最后一个 Transformer 层之前和之后的网格特征。本文使用一个简单的线性层来将图像特征连接到单词嵌入空间中。具体而言,应用可训练投影矩阵 W 将 Z_v 转换为语言嵌入标记 H_v,H_v 具有与语言模型中的单词嵌入空间相同的维度:
H v = W ⋅ Z v , with  Z v = g ( X v ) \mathbf{H}_{\mathrm{v}}=\mathbf{W} \cdot \mathbf{Z}_{\mathrm{v}} \text {, with } \mathbf{Z}_{\mathrm{v}}=g\left(\mathbf{X}_{\mathrm{v}}\right) Hv=WZv, with Zv=g(Xv)

在这里插入图片描述

在LLaVA中,Vision Encoder使用的是CLIP-ViT-L/14,并且,需要注意的是,LLaVA使用最后一层Transformer之前或之后的grid features作为图像表示,而不是CLIP最后的输出层。

总结:LLaVA的模型结构非常简单,就是CLIP+LLM(Vicuna,LLaMA结构),利用Vison Encoder将图片转换为[N=1, grid_H x grid_W, hidden_dim]的feature map,然后接一个插值层Projection W,将图像特征和文本特征进行维度对齐。经过Projection后,得到[N=1, grid_H x grid_W=image_seqlen, emb_dim]。然后将 image token embedding和text token embedding合并到一起,作为语言模型的输入,生成描述的文本。

LLaVA两阶段训练

阶段一:特征对齐预训练。由于从CLIP提取的特征与word embedding不在同一个语义表达空间,因此,需要通过预训练,将image token embedding对齐到text word embedding的语义表达空间。这个阶段冻结Vision Encoder和LLM模型的权重参数,只训练插值层Projection W的权重。

在这里插入图片描述

阶段二:端到端训练。这个阶段,依然冻结Vision Encoder的权重,训练过程中同时更新插值层Projection W和LLM语言模型的权重,训练考虑Multimodal Chatbot和Science QA两种典型的任务。
在这里插入图片描述

LLava 1.5模型

论文:https://arxiv.org/pdf/2310.03744.pdf
LLaVA 1.5和LLaVA在模型架构上基本一致,对LLM模型和插值层做了修改,但是模型效果逐渐开始炸裂:
在这里插入图片描述

  • LLM模型:LLM语言模型升级为Vicuna v1.5 13B,语言模型参数量更大,效果更好
  • Connector:也就是插值层,由原来的单个线性层替换为MLP层(多层线性层叠加)
  • Vision Encoder: 输入图像分辨率由224增大为336,使用CLIP ViT-L/336px,对图像细节理解能力更强
  • 更高质量的数据

LLava 1.6模型

链接:https://llava-vl.github.io/blog/2024-01-30-llava-next/
2024年1月30日,LLaVA发布了1.6版本,模型效果又得到了进一步提升,增强了推理,OCR和世界知识的能力。模型参数量来到了34B,比1.5版本的13B有了巨大的提升,同时模型效果在各项指标都直接有10个点以上的提升。主要改变是:

  • Vision Encoder分辨率:支持更大的分辨率,包括672x672, 336x1344, 1344x336 几种分辨率的输入,并且支持通过图片裁切,编码,合并来实现。
    在这里插入图片描述
  • LLM模型参数量大升级:由LLaVA 1.5的13B参数,增加到最多34B参数。
  • OCR能力提升:更好的推理和OCR能力:通过修改指令数据集实现
  • 更好的视觉对话:在一些场景下,拥有更好的世界知识

Reference

[1] 浙大校友开源多模态大模型LLaVA-1.5,130亿参数8个A100一天训完
[2] [LLaVA系列]CLIP/LLaVA/LLaVA1.5/VILA笔记
[3] 从零实现CLIP模型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/635483.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

单细胞分析(Signac): PBMC scATAC-seq 聚类

引言 在本教学指南中,我们将探讨由10x Genomics公司提供的人类外周血单核细胞(PBMCs)的单细胞ATAC-seq数据集。 加载包 首先加载 Signac、Seurat 和我们将用于分析人类数据的其他一些包。 if (!requireNamespace("EnsDb.Hsapiens.v75&qu…

HTTP3

HTTP 状态码:描述了这次HTTP请求是否成功,以及失败的原因。 他们用相应的状态码来描述异常的发现。 常见的状态码 1.200 OK 访问成功。 2.404 NOT Found 客户端请求的资源在服务器这边不存在 URL:ip端口路径查询字符串 3.403 Forbid…

SQL刷题笔记day1

1题目 我的代码: select * from employees order by hire_date desc limit 2,1 标准代码: select * from employees where hire_date (select distinct hire_date from employees order by hire_date desc limit 2,1) 复盘:因为按照入…

vue3插槽solt 使用

背景增加组件的复用性,个人体验组件化还是react 方便。 Vue插槽solt如何传递具名插槽的数据给子组件? 一、solt 原理 知其然知其所以然 Vue的插槽(slots)是一种分发内容的机制,允许你在组件模板中定义可插入的内容…

ARP基本原理

相关概念 ARP报文 ARP报文分为ARP请求报文和ARP应答报文,报文格式如图1所示。 图1 ARP报文格式 Ethernet Address of destination(0–31)和Ethernet Address of destination(32–47)分别表示Ethernet Address of dest…

Mendix 版本 10.10 发布 – 跨平台的功能

​本月,我们将发布遍布整个平台的许多功能,以改善所有用户的生活。Studio Pro 包含多项生活质量改进,例如性能和 Epics/Jira 集成!除此之外,还有一些不错的小部件、MxConnect和AI更新。以及App Insights, Mendix Cloud…

2024年5月19日优雅草蜻蜓K知识付费系统旗舰版v1.0.9进度更新

v1.1.0更新 v1.1.0更新 2024年5月19日优雅草蜻蜓K知识付费系统旗舰版v1.0.9进度更新,首页体育栏目完善新增用户发布页面 开发进度 首页体育栏目完善 新增用户发布页面 新增用户登录完善 新增学习课程页面完善-过往课程数据完成 去掉其他三方登录,新增…

文件的读写

文件操作&#xff1a; 1.打开文件 2.读/写-----操作文件 test.c------写&#xff08;输出&#xff09;------->文件 test.c<------读&#xff08;输入&#xff09;--------文件 文件名&#xff1a;文件路径文件名主干文件后缀 文件指针&#xff1a;FILE* pf;//文件指…

2024年5月24日 十二生肖 今日运势

小运播报&#xff1a;2024年5月24日&#xff0c;星期五&#xff0c;农历四月十七 &#xff08;甲辰年己巳月戊子日&#xff09;&#xff0c;法定工作日。 红榜生肖&#xff1a;龙、牛、猴 需要注意&#xff1a;兔、羊、马 喜神方位&#xff1a;东南方 财神方位&#xff1a;…

在windows中使用wsl下的unbuntu环境

1 unbuntu下载编译环境 编译环境安装命令&#xff1a; sudo apt install gdb sudo apt install gcc sudo apt install g 2 使用vscode正常打开项目&#xff0c;在window中打开的项目&#xff08;官方推荐将项目放在linux中的home目录&#xff09; 但在windows中也可以使用&a…

数据结构---树,二叉树的简单概念介绍、堆和堆排序

树 树的概念和结构 结构 在我们将堆之前&#xff0c;我们先来了解一下我们的树。 我们的堆是属于树里面的一种&#xff0c; 树是一种非线性结构&#xff0c;是一种一对多的一种结构&#xff0c;也就是我们的一个节点可能有多个后继节点&#xff0c;当然也可以只有一个或者没…

Mac安装配置maven

文章目录 前言一、Maven介绍二、下载三、安装四、配置环境变量五、验证六、maven配置配置本地仓库配置远程仓库 七、maven的常用命令总结 前言 Maven是一个项目管理和构建自动化工具&#xff0c;主要服务于基于Java的软件项目。是我们WEB开发过程中必不可少的工具。 一、Maven…

支持北斗卫星授时的工厂NTP电子钟,打造智能化工厂时间环境

在当今数字化、智能化的工业时代&#xff0c;时间的精准性和同步性对于工厂的高效运转至关重要。支持北斗卫星授时的工厂 NTP 电子钟的出现&#xff0c;为打造智能化工厂时间环境提供了有力支撑。 一、应用原因 工厂是一个复杂且高度协作的生产环境&#xff0c;精确一致的时间是…

linux 错误记录(三)

这里的内核源码路径&#xff1a; cd /usr/src/linux-headers-5.4.0-150-generic/ 内核版本&#xff1a; $ uname -r 5.4.0-150-generic 错误现象 ./include/uapi/asm-generic/int-ll64.h:12:10: fatal error: asm/bitsperlong.h: No such file or directory 搜索后是有的 …

深度学习之基于Tensorflow图像分类模型生成与读取

欢迎大家点赞、收藏、关注、评论啦 &#xff0c;由于篇幅有限&#xff0c;只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 一、项目背景与意义 图像分类是计算机视觉领域的重要任务之一&#xff0c;也是深度学习技术的一个核心应用。通过构…

数据库操作(函数)

函数是一段可以直接被另外一段程序调用的程序或代码 一。字符串函数 1.concat(s1,s1....sn)&#xff1a;字符串拼接&#xff0c;将s1&#xff0c;s2&#xff0c;sn拼接为一个字符串 例如&#xff1a; select concat("hello","world"); 2.lower(str&…

Google的MLP-MIXer的复现(pytorch实现)

Google的MLP-MIXer的复现&#xff08;pytorch实现&#xff09; 该模型原论文实现用的jax框架实现&#xff0c;先贴出原论文的代码实现&#xff1a; # Copyright 2024 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may …

【杂七杂八】Huawei Gt runner手表系统降级

文章目录 Step1&#xff1a;下载安装修改版华为运动与健康Step2&#xff1a;在APP里进行配置Step3&#xff1a;更新固件(时间会很长) 目前在使用用鸿蒙4 111版本的手表系统&#xff0c;但是感觉睡眠检测和运动心率检测一言难尽&#xff0c;于是想到是否能回退到以前的版本&…

从需求角度介绍PasteSpider(K8S平替部署工具适合于任何开发语言)

你是否被K8S的强大而吸引&#xff0c;我相信一部分人是被那复杂的配置和各种专业知识而劝退&#xff0c;应该还有一部分人是因为K8S太吃资源而放手&#xff01; 这里介绍一款平替工具PasteSpider&#xff0c;PasteSpider是一款使用c#编写的linux容器部署工具(使用PasteSpider和…

算法学习:快速排序

&#x1f525; 个人主页&#xff1a;空白诗 文章目录 &#x1f680; 引言&#x1f4cc; 快速排序算法核心思想1. 选择基准值&#xff08;Pivot&#xff09;2. 分区操作&#xff08;Partitioning&#xff09;3. 递归排序子序列 &#x1f4cc; JavaScript 实现1. 快速排序主函数2…