源码分析之blip2的ITC和ITM的具体实现

引言:很久之前读blip2,对ITC和ITM大致有个印象,一个对比学习,一个图文匹配的二分类,咋一听好像没什么难理解的,最近好好看了一下源码,觉得实现上很巧妙,值得与诸君共享

这里小编没有一句一句分析,直接源码+注释,觉得这样看比较方便,因为只分析ITC和ITM,所以这里只放了blip2里面的Blip2Qformer的forward函数内容,如有出入,还请各位小伙伴留言斧正!

Image-text Contrastive

###============== Image-text Contrastive ===================###
    
    """
    因为在多张卡上训练,所以这里需要将所有卡上的图像特征收集起来,维度为[batch_size*num_gpu, num_query_tokens, embed_dim],
    其中,num_query_tokens是视觉tokens数量,embed_dim是维度
    """
    image_feats_all = concat_all_gather(
        image_feats
    )  # [batch_size*num_gpu, num_query_tokens, embed_dim]

    # 文本这一步操作与上述同理
    text_feat_all = concat_all_gather(text_feat)  # [batch_size*num_gpu, embed_dim]

    """
    求图像与所有文本的相似度
    这里image_feats.unsqueeze(1)之后的维度是[batch_size,1, num_query_tokens, embed_dim]
    text_feat_all.unsqueeze(-1)之后的维度是[batch_size*num_gpu, embed_dim,1]
    为了求每个图像跟所有文本的相似度,图像特征[batch_size,1, num_query_tokens, embed_dim]第2个维度会被广播到batch_size*num_gpu变成[batch_size*,batch_size*num_gpu, num_query_tokens, embed_dim]
    然后矩阵乘法会沿着image_feats和text_feat_all最后两个维度进行相乘,embed_dim维度相乘消失,所以得到的结果为[batch_size,batch_size*num_gpu, num_query_tokens,1]
    相乘之后的结果再squeeze()就得到了[batch_size,batch_size*num_gpu, num_query_tokens]
    """
    sim_q2t = torch.matmul(
        image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)
    ).squeeze()  # [batch_size, batch_size*num_gpu, num_query_tokens]

    """
    max(-1)表示在最后一个维度上,寻找最大值
    也就是说,对每个图像到文本的相似度,选取所有num_query_tokens中的最大值,sim_i2t最终的维度为[batch_size, batch_size*num_gpu]
    """
    sim_i2t, _ = sim_q2t.max(-1)

    # 通过温度参数self.temp进行相似度的缩放控制
    sim_i2t = sim_i2t / self.temp

    """
    求文本与所有图像的相似度
    text_feat.unsqueeze(1).unsqueeze(1)之后的维度为[batch_size,1,1,embed_dim]
    image_feats_all.permute(0, 2, 1)交换后面两个维度之后的特征维度为[batch_size*num_gpu, embed_dim, num_query_token]
    同理,文本特征[batch_size,1,1,embed_dim]会广播第2个维度到batch_size*num_gpu,变成[batch_size,batch_size*num_gpu,1,embed_dim]
    然后最后两个维度做矩阵乘法得到[batch_size,batch_size*num_gpu,1,num_query_token]
    squeeze()之后的特征为[batch_size,batch_size*num_gpu,num_query_token]
    """
    sim_t2q = torch.matmul(
        text_feat.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1)
    ).squeeze()

    # 对每个文本到图像的相似度,选取所有num_query_tokens中的最大值,sim_i2t最终的维度为[batch_size, batch_size*num_gpu]
    sim_t2i, _ = sim_t2q.max(-1)
    sim_t2i = sim_t2i / self.temp

    rank = dist.get_rank()
    bs = image.size(0)

    """
    torch.linspace(start, end, steps, dtype=int)的作用是生成从 start 到 end 之间的 steps 个数值,并返回一个 1D 张量
    这里用来生成多 GPU 训练中的标签(targets)索引,targets维度维[batch_size]
    每个 GPU 进程(或 rank)负责处理自己的 batch,并为它分配唯一的索引序列
    """
    targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(
        image.device
    )

    if "image_id" in samples.keys():  # coco retrieval finetuning
        # 对于包含图像 ID 的样本,使用基于相似度的目标分布计算损失
        image_ids = samples["image_id"].view(-1, 1)
        image_ids_all = concat_all_gather(image_ids)
        pos_idx = torch.eq(image_ids, image_ids_all.t()).float()
        sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)
        sim_targets = 0.9 * sim_targets + 0.1 * torch.ones_like(sim_targets) / sim_targets.size(1)

        loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1) * sim_targets, dim=1).mean()
        loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1) * sim_targets, dim=1).mean()
        loss_itc = (loss_t2i + loss_i2t) / 2
    else:
        """
        否则,使用交叉熵计算损失
        sim_i2t维度为[batch_size, batch_size*num_gpu],targets维度维[batch_size]
        对于sim_i2t每个batch,targets都有唯一一个在 0 到 batch_size * num_gpu - 1 之间真实值,因此可以计算交叉熵,从而达到让正例更接近,负例更远的效果
        """
        loss_itc = (
                           F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
                           + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
                   ) / 2

Image-text Matching

###============== Image-text Matching ===================###
    # 同上述
    text_input_ids_world = concat_all_gather(text_tokens.input_ids)
    text_attention_mask_world = concat_all_gather(text_tokens.attention_mask)
    image_embeds_world = all_gather_with_grad(image_embeds)

    with torch.no_grad():
        # 当有image_id时,作者把相似度矩阵里面image_ids相匹配的都mask掉了,即在后面计算的时候忽略样本自身的匹配
        if "image_id" in samples.keys():
            mask = torch.eq(image_ids, image_ids_all.t())
            sim_t2i.masked_fill_(mask, -10000)
            sim_i2t.masked_fill_(mask, -10000)
        else:
            # 与上面同理,将当前 GPU 进程处理的样本的索引范围填充为 -10000,即在后面计算的时候忽略样本自身的匹配
            sim_t2i[:, rank * bs: rank * bs + bs].fill_diagonal_(-10000)
            sim_i2t[:, rank * bs: rank * bs + bs].fill_diagonal_(-10000)

        # 被masked的值和被fill_diagonal_(-10000),经过softmax之后都会接近于0
        weights_t2i = F.softmax(sim_t2i, dim=1)
        weights_i2t = F.softmax(sim_i2t, dim=1)

    # 为每个文本选择一个负样本图像
    image_embeds_neg = []
    for b in range(bs):
        """
        对每个batch的数据随机选择一个负样本
        torch.multinomial从给定的概率分布中进行多项式分布抽样
        weights_t2i[b]中值大的数,被采样的概率就大,上述对sim_t2i自身样本进行mask就是为了这里自身样本作为正样本不会被选择
        """
        neg_idx = torch.multinomial(weights_t2i[b], 1).item()
        image_embeds_neg.append(image_embeds_world[neg_idx])
    image_embeds_neg = torch.stack(image_embeds_neg, dim=0)

    # 为每个图像选择一个负样本文本
    text_ids_neg = []
    text_atts_neg = []
    for b in range(bs):
        neg_idx = torch.multinomial(weights_i2t[b], 1).item()
        text_ids_neg.append(text_input_ids_world[neg_idx])
        text_atts_neg.append(text_attention_mask_world[neg_idx])

    text_ids_neg = torch.stack(text_ids_neg, dim=0)
    text_atts_neg = torch.stack(text_atts_neg, dim=0)

    """
    这一步很妙!
    将文本的两个正样本一个负样本进行拼接,为后续二分类做准备
    至于为什么这么拼接,后面你就知道了
    """
    text_ids_all = torch.cat(
        [text_tokens.input_ids, text_tokens.input_ids, text_ids_neg], dim=0
    )  # pos, pos, neg
    text_atts_all = torch.cat(
        [text_tokens.attention_mask, text_tokens.attention_mask, text_atts_neg],
        dim=0,
    )

    # 这一步是对query_tokens进行一些处理
    query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
    query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to(
        image.device
    )
    attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)

    """
    将图像的两个正样本一个负样本进行拼接,为后续二分类做准备
    注意:文本拼接的顺序是:正样本,正样本,负样本
    图像拼接的顺序是:正样本,负样本,正样本
    它们只有第一个位置都是正样本,也即第一个位置是一对匹配的正例,后面两个位置都是一正一负是不匹配的,这样我们就可以通过判断它们匹不匹配来进行二分类学习,妙哉!
    """
    image_embeds_all = torch.cat(
        [image_embeds, image_embeds_neg, image_embeds], dim=0
    )  # pos, neg, pos
    image_atts_all = torch.ones(image_embeds_all.size()[:-1], dtype=torch.long).to(
        image.device
    )

    # 将拼接后的文本特征,图像特征以及相应的query_tokens输入到bert中进行分类预测
    output_itm = self.Qformer.bert(
        text_ids_all,
        query_embeds=query_tokens_itm,
        attention_mask=attention_mask_all,
        encoder_hidden_states=image_embeds_all,
        encoder_attention_mask=image_atts_all,
        return_dict=True,
    )

    # 取分类预测的结果
    vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :]
    vl_output = self.itm_head(vl_embeddings)
    logits = vl_output.mean(dim=1)

    # 生成对应的真实标签,只有第一个batch文本对是匹配的,所以第一个batch的标签设置为1,其他都是0
    itm_labels = torch.cat(
        [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],
        dim=0,
    ).to(image.device)

    # 将预测结果和真实值输入到交叉熵损失函数中,进行二分类损失计算
    loss_itm = F.cross_entropy(logits, itm_labels)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/888111.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

在不支持WSL2的Windows环境下安装Redis并添加环境变量的方法

如果系统版本支持 WSL 2 可跳过本教程。使用官网提供的教程即可 官网教程 查看是否支持 WSL 2 如果不支持或者觉得麻烦可以按照下面的方式安装 下载 点击打开下载地址 下载 zip 文件即可 安装 将下载的 zip 文件解压到自己想要解压的地方即可。(注意&#x…

sqli-labs less-17密码重置报错注入

密码重置报错植入 来到首页面我们看到页面提示【password reset】,说明这是更改密码的注入,也就是说我们知道一个账户名,修改他的密码,所以我们可以在passwd处进行注入。 闭合方式 添加单引号 有报错 可以知道闭合方式为单引号…

Leetcode—76. 最小覆盖子串【困难】

2024每日刷题&#xff08;167&#xff09; Leetcode—76. 最小覆盖子串 C实现代码 class Solution { public:string minWindow(string s, string t) {int bestL -1;int l 0, r 0;vector<int> cnt(128);for(const char c: t) {cnt[c];}int require t.length();int m…

OJ在线评测系统 微服务 用分布式消息队列 RabbitMQ 解耦判题服务和题目服务 手搓交换机和队列 实现项目异步化

消息队列解耦 项目异步化 分布式消息队列 分布式消息队列是一种用于异步通信的系统&#xff0c;它允许不同的应用程序或服务之间传递消息。消息队列的核心理念是将消息存储在一个队列中&#xff0c;发送方可以将消息发送到队列&#xff0c;而接收方则可以在适当的时候从队列中…

安卓如何实现双击触摸唤醒点亮屏幕功能-Android framework实战开发

背景 经常有学员朋友在群里问到一个目前市场上常见的功能&#xff1a; 手机待机时候双击屏幕可以唤醒点亮手机屏幕功能 如何实现这个功能&#xff0c;经常有同学在群里求助&#xff0c;今天就刚好来讨论一下这个待机时候双击触摸唤醒点亮屏幕的功能的实现方案。 功能核心方案设…

【微服务】服务注册与发现 - Eureka(day3)

CAP理论 P是分区容错性。简单来说&#xff0c;分区容错性表示分布式服务中一个节点挂掉了&#xff0c;并不影响其他节点对外提供服务。也就是一台服务器出错了&#xff0c;仍然可以对外进行响应&#xff0c;不会因为某一台服务器出错而导致所有的请求都无法响应。综上所述&…

dwceqos网络驱动性能优化

文章介绍 本文会分享一些在QNX系统下对io-pkt-v6-hc驱动模块cpu loading过高问题优化的经验&#xff0c;以及一些调优debug的方法。这些优化措施实施之后可以降低io-pkt-v6-hc在高负载的情况下的cpu loading。本文的调优是基于synopsys公司的dwceqos模块&#xff0c;理论上方法…

【Android 源码分析】Activity生命周期之onPause

忽然有一天&#xff0c;我想要做一件事&#xff1a;去代码中去验证那些曾经被“灌输”的理论。                                                                                  – 服装…

【STM32 HAL库】MPU6050 DMP库移植 与 自检失败的处理

【STM32 HAL库】MPU6050 DMP库移植 与 自检失败的处理 本文参考移植步骤文件配置代码修改inv_mpu.cinv_mpu.hinv_mpu_dmp_motion_driver.c 使用 自检失败怎么处理ret -1改正DEBUG过程 ret -9改正DEBUG过程 本文参考 B站 CSDN 移植步骤 文件配置 新建一个 dmp 文件夹 并将…

【Linux】进程地址空间、环境变量:从理论到实践(三)

&#x1f308; 个人主页&#xff1a;Zfox_ &#x1f525; 系列专栏&#xff1a;Linux 目录 &#x1f680; 前言一&#xff1a;&#x1f525; 环境变量 &#x1f95d; 基本概念&#x1f95d; 常见环境变量&#x1f95d; 查看环境变量方法 二&#xff1a;&#x1f525; 测试 &…

Nat. Commun.:飞秒激光书写受蚂蚁启发的可重构微型机器人集体

背景介绍生物在各种环境中的集体行为十分普遍&#xff0c;它们能够自发有序地完成单个个体难以完成的任务。目前&#xff0c;生物集体的形成主要分为两大类。第一类生物个体之间没有直接接触&#xff0c;如蜜蜂、鱼和鸟类&#xff0c;这导致这些集体不稳定&#xff0c;容易受到…

Linux网络编程 -- 网络基础

本文主要介绍网络的一些基础概念&#xff0c;不涉及具体的操作原理&#xff0c;旨在构建对网络的基础认识。 1、网络的早期发展历程 20世纪50年代 在这一时期&#xff0c;计算机主机非常昂贵&#xff0c;而通信线路和设备相对便宜。为了共享计算机主机资源和进行信息的综合处…

基于图像的3D动物重建与生成

一、背景与目标 3D-Fauna 是一款用于基于图像和视频进行四足动物3D重建与生成的开源方案。自然界展示了复杂的相似性与多样性,该方法通过学习来自网上图片的四足动物的3D形态,能够从单张图片生成可动画化的带有纹理的3D网格模型。其最终目标是通过大量扩展现有的解决方案,实…

数据库(MySQL):使用命令从零开始在Navicat创建一个数据库及其数据表(一).创建基础表

一. 使用工具和命令 1.1 使用的工具 Navicat Premium 17 &#xff1a;“Navicat”是一套可创建多个连接的数据库管理工具。 MySQL版本8.0.39 。 1.2 使用的命令 Navicat中使用的命令 命令命令解释SHOW DATABASES&#xff1b;展示所有的数据库CREATE DATABASE 数据库名称; 创…

基于STM32的数字温度传感器设计与实现

引言 STM32 是由意法半导体&#xff08;STMicroelectronics&#xff09;开发的基于 ARM Cortex-M 内核的微控制器系列&#xff0c;以其强大的处理能力、丰富的外设接口和低功耗著称&#xff0c;广泛应用于嵌入式系统设计中。在这篇文章中&#xff0c;我们将介绍如何基于 STM32…

深度学习:基于MindSpore实现ResNet50中药分拣

ResNet基本介绍 ResNet&#xff08;Residual Network&#xff09;是一种深度神经网络架构&#xff0c;由微软研究院的Kaiming He等人在2015年提出&#xff0c;并且在ILSVRC 2015竞赛中取得了很好的成绩。ResNet主要解决了随着网络深度增加而出现的退化问题&#xff0c;即当网络…

数据结构与算法——动态规划算法简析

1.初步了解动态规划 由于本篇博客属于动态规划的初阶学习&#xff0c;所以大多都是简单的表示&#xff0c;更深层次的学术用语会在之后深度学习动态规划之后出现&#xff0c;本文主要是带各位了解一下动态规划的大致框架 1.1状态表示 通常的我们会开辟一个dp数组来存储需要表示…

015 品牌关联分类

文章目录 后端CategoryBrandEntity.javaCategoryBrandController.javaCategoryBrandServiceImpl.javaCategoryServiceImpl.javaBrandServiceImpl.java删除 npm install pubsub-jsnpm install --save pubsub-js这个错误是由于在尝试安装 pubsub-js 时&#xff0c;npm 发现了项目…

数据结构(栈和队列的实现)

1. 栈&#xff08;Stack&#xff09; 1.1 栈的概念与结构 栈是一种特殊的线性表&#xff0c;其只允许固定的一段插入和删除操作&#xff1b;进行数据插入和删除的一段叫做栈顶&#xff0c;另一端叫栈底&#xff1b;栈中的元素符合后进先出LIFO&#xff08;Last In First Out&…

C++——模拟实现vector

1.查看vector的源代码 2.模拟实现迭代器 #pragma oncenamespace jxy {//模板尽量不要分离编译template <class T>class vector{public:typedef T* iterator;//typedef会受到访问限定符的限制typedef const T* const_iterator;//const迭代器是指向的对象不能修改&#xf…