从零实现CLIP模型

1. 引言

CLIP代表语言图像对比预训练模型,是OpenAI于2021年开发的一个深度学习模型。CLIP模型中图像和文本嵌入共享相同的潜在特征空间,从而能够在两种模式之间直接进行对比学习。这是通过训练模型使相关的图像和文本更紧密地结合在一起,同时将不相关的图像在特征空间距离分开来实现的。

闲话少说,我们直接开始吧!

2. 相关应用

关于CLIP模型的一些应用总结如下:

  • 图像分类和检索:CLIP可以通过将图像与自然语言文本描述关联起来进而可用于图像分类任务。它允许更通用和灵活的图像检索系统,用户可以使用文本查询来在数据库里搜索图像。

  • 内容调节:CLIP可用于通过分析图像和附带文本来识别和过滤不适当或有害的内容,从而调节在线平台上的展示内容。

3. 核心思想

CLIP模型旨在预测一个batchN×N个潜在(img,text)配对具体哪些是实际匹配的。为了实现这一点,CLIP通过图像编码器和文本编码器的联合训练建立了一个多模态嵌入空间。CLIP的损失函数旨在最大化批处理中N个真实配对的图像和文本嵌入之间的余弦相似性,同时最小化N²−N个错误配对的余弦相似度。以下是伪代码(取自原始论文),概述了CLIP的核心实现。
在这里插入图片描述
接着我们将伪代码中每一行的逐步描述,将其转化为使用PyTorch来实现。

4. 网络结构

在进行代码实现之前,我们先来简单回顾下clip模型具体的网络结构:
在这里插入图片描述

ClIP模型使用两种独立的网络结构来作为图像编码和文本编码的主干,其中:

  • image_encoder:负责编码图像的神经网络主干(eg,ResNetVision Transformer等)。
  • text_encoder:表示负责编码文本信息的神经网络架构(eg,CBOWBERT等)。

原始CLIP模型是从零开始训练的,而没有使用预训练的权重来初始化图像编码器和文本编码器,因为它们用于训练其CLIP模型的数据集体量很大(4亿个图像-文本对)。在这篇博客文章的例子中,我们将采取一些不同的做法。我们将从resnet(用于图像)和distilbert(用于文本)模型的预训练权重开始初始化这些部分。

5. 数据输入

该模型每个批次以n个图像和文本对作为输入,其中:

  • I[n,h,w,c]:表示对齐的图像的小批次输入,其中n是batch大小,h是图像高度,w是图像宽度,c是通道数。
  • T[n,l]:表示对齐文本的小批次输入,其中n是batch大小,l是文本序列的长度。

我们的实现中,我们默认batch的大小为128,如下所示:
在这里插入图片描述

6. 特征提取

关于文本和图像的特征提取,这里使用resnet34distilbert来分别提取图像和文本的特征,如下:

  • I_f = image_encoder(I) : 从图像编码器中获取的图像特征表示I_fI_f的大小为[n,d_I],其中d_I是图像特征的维度。
  • T_f=text_encoder(T):从文本编码器中获取的文本特征表示T_fT_f的大小为[n,d_T],其中d_T是文本特征的维度。

在本文实现中,相应的代码如下:

# for encoding images
I_f = models.resnet34(pretrained=True)      
# for encoding captions
T_f= AutoModel.from_pretrained("distilbert-base-multilingual-cased") 

7. 特征映射

接着,我们将相应的文本和图像特征,映射到同一嵌入特征空间,如下:

  • W_i[d_i,d_e]:表示用于将图像特征i_f映射到嵌入特征空间i_e的投影矩阵。W_i的形状大小是[d_i,d_e],其中d_e表示的是联合嵌入特征空间的维度。
  • W_t[d_t,d_e]:表示用于将文本特征t_f映射到相同嵌入空间t_e的投影矩阵。W_t的形状大小是[d_t,d_e]

投影操作可以使用具有两个线性层的神经网络进行编码,其权重是学习的投影矩阵。在大多数情况下,投影权重是唯一可以在新数据集上需要训练的权重。此外,投影层在对齐图像和文本嵌入的尺寸方面发挥着至关重要的作用,确保它们具有相同的维度。

相应的代码实现如下:
在这里插入图片描述

8. 组合

在上一节中,我们将文本和图像特征分别统一到相同的维度,接着我们将上述相关组件进行整合:

  • I_e = l2_normalize(np.dot(I_f, W_i), axis=1) :在联合嵌入空间I_e中嵌入并归一化图像特征
  • T_e = l2_normalize(np.dot(T_f, W_t), axis=1) :在联合嵌入空间T_e中嵌入并归一化文本特征

接着我们使用以下Pytorch代码来描述图像和文本数据的处理次序。首先,相应的数据通过基本编码器进行处理,然后通过投影层进行处理。最后,为两种模态特征进行嵌入归一化化并返回。如下:

在这里插入图片描述

9. 余弦相似度

接着在嵌入空间,我们来计算文本图像嵌入特征的相似度:

  • logits = np.dot(I_e, T_e.T) * np.exp(t):用以计算图像和文本对在联合嵌入空间的特征余弦相似度,通过可学习的参数t进行缩放。

在我们的例子中,我们考虑暂不使用参数t,代码如下:

logits = T_e @ T_e.T

10. 损失函数

CLIP使用对比损失用以将相关图像和文本在嵌入特征空间拉近,同时将不相关的图像和文本距离拉远。

  • labels = np.arange(n): 用以生成表示batch索引的真值标签。
  • loss_i = cross_entropy_loss(logits, labels, axis=0):用以计算图像特征和真值标签的损失
  • loss_t = cross_entropy_loss(logits, labels, axis=1):用以计算文本特征和真值标签的损失
  • loss = (loss_i + loss_t)/2:计算图像和文本损失的加权平均值。

代码实现如下:

在这里插入图片描述

11. 构建完整模型

将所有不同的部件组合在一起,最终的自定义CLIP模型如下所示:

在这里插入图片描述

12. 构建数据集

我们的自定义CLIP模型将使用flickr30k数据集进行训练。该数据集包括31000多张图像,每张图像至少有5个独立的人工生成文本描述。在这个例子中,我们将为每个图像使用两个标题,总共有62000个图像和文本对用于训练。 代码实现如下:
在这里插入图片描述
上述模型关键常数包括用于学习表示特征空间的维度embed_dim, 用于transformer特征维度的transformer_embed_dim和用于文本输入长度的max_len。所选的text_modeldistilbert base multilanguage-cased。用以训练的模型的epoch为3,同时batch_size的大小为128,这些常数将用于模型构建和训练。如下所示:
在这里插入图片描述

13. 数据集测试用例

DataLoader是为训练期间的高效迭代而设置的,提供图像文本对的迭代访问。调用代码如下:

# Create the DataLoader
clip_dataloader = DataLoader(flickr30k_custom_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

以下是数据集中一个批次中的图像文本对的示例:

import numpy as np
import matplotlib.pyplot as plt
# Create an iterator from the dataloader
data_iter = iter(clip_dataloader)

# Get one batch
batch = next(data_iter)

image = batch["image"][0]  # get one image from the batch
caption = batch["caption"][0]  # get one text from the batch

# Convert the image tensor to a NumPy array and permute dimensions
image_np = np.transpose(image.numpy(), (1, 2, 0))

# Display the image and caption
plt.imshow(image_np)
plt.title(f"Caption: {caption}")
plt.show()

运行结果如下:
在这里插入图片描述

14. 优化器选择

此外,我们还需要指定在整个训练过程中需要优化的参数。上文中我们已经固定了文本和图像编码器的特征提取层,那么只有与投影层相关的参数才会在新的数据集上进行训练。

# Create an instance of your model
model = CustomModel().to(device)

# Define optimizer
optimizer = torch.optim.Adam([
    {'params': model.vision_encoder.parameters()},
    {'params': model.caption_encoder.parameters()}
], lr=model.lr)

15. 模型训练

我们使用Tesla T4的GPU机器进行3个epoch的训练,相应的训练代码如下:
在这里插入图片描述

执行上述训练代码,可以得到训练过程如下:
在这里插入图片描述

16. 总结

总之,这篇博客文章探讨了CLIP模型,揭示了其广泛应用的潜力。随着我们对CLIP应用的了解,很明显,它的影响远远超出了最初的预期,为不同领域的创新解决方案铺平了道路。

您学废了嘛?

完整代码:戳我

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/299678.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于原子搜索算法优化的Elman神经网络数据预测 - 附代码

基于原子搜索算法优化的Elman神经网络数据预测 - 附代码 文章目录 基于原子搜索算法优化的Elman神经网络数据预测 - 附代码1.Elman 神经网络结构2.Elman 神经用络学习过程3.电力负荷预测概述3.1 模型建立 4.基于原子搜索优化的Elman网络5.测试结果6.参考文献7.Matlab代码 摘要&…

Linux:nginx设置网站https

http和https的区别 http: 80 https: 443 这种协议比http协议要安全,因为传输数据是经过加密的 HTTPS简介 HTTPS其实是有两部分组成:HTTP SSL / TLS,也就是在HTTP上又加了一层处理加密信息的模块。服务端和客户端的信息传输都会通过…

在IDEA中使用git分支进行开发然后合并到Master分支,2022.1.x版本

在实际开发过程中,为了避免因为在开发中出现的问题以及方便发布版本,如果是多版本发布的情况相下,我们通常需要采用分支进行开发,这个时候,我们就需要了解git分支的相关知识点了,本篇博客也是博主在实际公司…

【SpringCloud】之配置中心(进阶使用)

🎉🎉欢迎来到我的CSDN主页!🎉🎉 🏅我是君易--鑨,一个在CSDN分享笔记的博主。📚📚 🌟推荐给大家我的博客专栏《SpringCloud开发之远程消费》。🎯&a…

MMFF-NET:多层次多尺度特征融合的弱光图像增强网络

这是我去年的工作,我录用的第一篇SCI,很拉,3区。今年中科院新版分区,变成4区了。很遗憾。后面会持续给大家更新我的文章以及我的内容。硕士阶段的东西几乎创新点都很差。 但是对于初学者我希望它有一定的参考价值。 文章链接&am…

利用Type类来获得字段名称(Unity C#中的反射)

使用Type类以前需要引用反射的命名空间: using System.Reflection; 以下是完整代码: public class ReflectionDemo : MonoBehaviour {void Start(){A a new A();B b new B();A[] abArraynew A[] { a, b };foreach(A v in abArray){Type t v.GetTyp…

不带控制器打包exe,转pdf文件时失败的原因

加了注释的两条代码后,控制器会显示一个docx转pdf的进度条。这个进度条需要控制器的实现,如果转exe不带控制器的话,当点击转换为pdf的按钮就会导致程序出错和闪退。 __init__.py文件的入口

分布式事务理论及Seata实践

分布式事务简介 事务是指作为单个逻辑工作单元执行的一系列操作,要么完全地执行,要么完全地不执行。 事务处理可以确保除非事务性单元内的所有操作都成功完成,否则不会永久更新面向数据的资源。事务的四个特征(ACID) …

FineBI实战项目一(3):Kettle实现ETL到数据仓库

目前,finebi_shop_bi 中是没有任何数据的,是一个空的数据库。而后续我们的所有数据分析都将在该数据库中进行。我们第一件事情就是要将 「finebi_shop」数据库中的所有表抽取到「finebi_shop_bi」数据库中。要抽取并装载数据到「finebi_shop_bi」中&…

超维空间M1无人机使用说明书——51、ROS无人机使用AR二维码识别与定位

引言:二维码识别与定位是指ROS通过创建AR标签并且对AR标签进行识别,标签可以由自己任意创建,具体方法会在文中给出,摄像头可以通过识别AR标签大小和姿态获取到标签对应的ID和位置等信息,实现识别与定位 注意&#xff…

Qt/QML编程学习之心得:Linux下Thread线程创建(26)

GUI设计中经常为了不将界面卡死,会用到线程Thread,而作为GUI设计工具,Qt也提供了一个这样的类,即QThread。 QThread对象管理程序中的一个控制线程。线程QThread开始在run()中执行。默认情况下,run()通过调用exec()启动事件循环,并在线程内运行Qt事件循环。 也可以通过…

Camtasia2024苹果Mac电脑版(屏幕录制剪辑软件)

Camtasia Mac2024免费版是一款由TechSmith公司官方进行汉化推出的最新版本,借助Camtasia,您可以轻松记录屏幕并创建优美,专业的视频。记录所有内容-您的整个屏幕或只是一个窗口。或者,添加您已经拥有的视频,图像&#…

python 文件

open """ def open(file: FileDescriptorOrPath, //路径mode: OpenTextMode "r", //设置打开文件的模式 r 以只读方式打开文件。文件的指针将会放在文件的开头。这是默认模式。 w 打开一个文件只用写入。如果该文件已存在则打开文件&#…

一文讲透Python数据分析可视化之直方图(柱状图)

直方图(Histogram)又称柱状图,是一种统计报告图,由一系列高度不等的纵向条纹或线段表示数据分布的情况。一般用横轴表示数据类型,纵轴表示分布情况。通过绘制直方图可以较为直观地传递有关数据的变化信息,使…

【Python从入门到进阶】46、58同城Scrapy项目案例介绍

接上篇《45、Scrapy框架核心组件介绍》 上一篇我们学习了Scrapy框架的核心组件的使用。本篇我们进入实战第一篇,以58同城的Scrapy项目案例,结合实际再次巩固一下项目结构以及代码逻辑的用法。 一、案例网站介绍 58同城是一个生活服务类平台&#xff0c…

msckf_vio在ubuntu20.04中的编译

1.新建catkin workspace文件夹,并在其中新建src文件夹,并将源码clone至src内。 源码地址:https://github.com/KumarRobotics/msckf_vio 目录层级示意如下,build和devel不必新建,后续指令会自动新建。 2. 在编译之前…

java CAS

CAS 在高并发场景,可以使用加锁 或者CAS来保证原子性,但是加锁是很重量级的操作,CAS类似于乐观锁CAS ( Compare and swap )比较并交换,是实现并发算法时常用到的技术,包含三个操作数&#xff1…

LVGL的List控件的触摸按键和实体按键的处理

在LVGL的List控件使用过程中,虽然通过触摸按键选择item,但是有些场景需要实体按键选取item,但是LVGL 的V8.3中没有像Emwin那样有函数选择list item的函数。LVGL中List引入了Group的概念,把列表项都添加到同一个group中。然后通过更…

Linux Capabilities 基础概念与基本使用

目录 1. Linux capabilities 是什么? 2. capabilities 的赋予和继承 线程的 capabilities Permitted* 允许 Effective* 有效 Inheritable* 遗传 Bounding(集合) Ambient 文件的 capabilities Permitted Inheritable Effective 3…

2.4 DEVICE GLOBAL MEMORY AND DATA TRANSFER

在当前的CUDA系统中,设备通常是带有自己的动态随机存取存储器(DRAM)的硬件卡。例如,NVIDIA GTX1080具有高达8 GB的DRAM,称为全局内存。我们将互换使用全局内存和设备内存这两个术语。为了在设备上执行内核,…