当CV遇上transformer(三)Clip模型及源码分析

当CV遇上transformer(三)Clip模型及源码分析

  • 2020年10月,Dosovitskiy首次将纯Transformer的网络结构应用于图像分类任务中(ViT),并取得了当时最优的分类效果,其研究成果是Transformer完全替代标准卷积的首次尝试。随着谷歌提出ViT之后,一大批的vision transformer的工作席卷计算机视觉任务。
  • Open AI在2021年1月份发布DALL-E和CLIP,其中DALL-E是基于文本来生成图像的模型,而CLIP是用文本作为监督信号来训练可迁移的视觉模型,这两个工作也像ViT一样带动了一波新的研究高潮。
    • OpenAI的CLIP模型显著地改变了研究人员处理多模态数据的方式,但是CLIP只是一个开始。
    • 从预训练数据到训练方法和对比损失函数的细节,CLIP家族在过去几年中取得了令人难以置信的进步。
      • ALIGN缩放噪声文本(https://arxiv.org/abs/2102.05918)
      • K-LITE增强外部知识(https://arxiv.org/abs/2204.09222)
      • OpenCLIP研究缩放定律(https://arxiv.org/abs/2212.07143)
      • MetaCLIP优化数据管理(https://arxiv.org/abs/2309.16671)
      • DFN增强数据质量(https://arxiv.org/abs/2309.17425)
  • 虽然原始论文中只对用CLIP进行zero-shot分类做了实验,但其实CLIP的应用价值远不止此。依据李沐团队的CLIP改进工作串讲总结如下(CLIP 改进工作串讲(上)【论文精读·42】):
    • 语义分割
      • LSeg与 CLIP 实现 zero-shot 的方式类似,它通过类别 prompt 作为文本输入,然后计算相似度,也实现了zero-shot 语义分割(https://arxiv.org/pdf/2201.03546)。
      • GroupViT与CLIP类似,利用图像文本对进行无监督的训练,从而让模型进行简单的分割任务(https://arxiv.org/pdf/2202.11094)。
    • 目标检测
      • ViLD把CLIP模型当作teacher去蒸馏作者设计的网络,从而达到利用zero-shot去做目标检测的目的(https://arxiv.org/pdf/2104.13921)。
      • GLIP创造性地将目标检测任务转换为短语定位任务。即对待任意一张训练图片,把标签用句号隔开,拼接成一句话。再结合伪标价签的技术来扩增数据(https://arxiv.org/pdf/2112.03857)。
      • GLIPv2大概思想与GLIP差不多,框架上也很相似,只不过带入了更多的任务和数据(https://arxiv.org/pdf/2206.05836)。
    • 图像生成
      • CLIPasso就是给模型一张真实的照片,模型就能还给他一张最简形式的简笔画。利用CLIP的强大能力,从速写和图像中提炼语义概念 ,将速写定义为一组贝兹曲线(贝塞尔曲线)。然后用一个可微调光栅化器直接针对基于CLIP的感知损失,优化曲线参数(https://arxiv.org/pdf/2202.05822)。
      • CLIP4Clip利用CLIP模型去做视频里的video-text retrival(视频-文本跨模态检索)。因为CLIP天生就很适合做检索工作,它就是在算图像和文本之间的相似性,根据相似性来做ranking、mathcing、retrieve各种类似的任务(https://arxiv.org/pdf/2104.08860)。
      • Action CLIIP利用CLIP研究视频理解动作识别(https://arxiv.org/pdf/2109.08472)。
  • 今天,我们来了解下能够实现zero-shot分类的CLIP模型。
  • 论文链接:Learning Transferable Visual Models From Natural Language Supervision
  • 官方源码:GitHub - openai/CLIP: CLIP

1 CLIP模型的整体架构

1.1 CLIP简述

  • 在计算机视觉领域,最常采用的迁移学习方式就是先在一个较大规模的数据集如ImageNet上预训练,然后在具体的下游任务上再进行微调,这里的预训练是基于有监督训练的,需要大量的数据标注,因此成本较高
  • 在NLP领域,自监督预训练使用十分广泛。在BERT中,以一定比例 mask 掉输入文本中的一些部分,让模型去预测这批被 mask 掉的内容。这样,利用数据本身就可以作为监督(模型要预测的目标来源于数据本身,并非人工构造),无需复杂的人工标注。
  • 不过,我们之前讲过何凯明大神提出的MAE模型,以一定比例随机 mask 掉图片中的一些图像块(patch),然后重建这些部分的像素值,实现了CV领域的自监督预训练。虽然MAE模型提出的时间比CLIP模型晚,但之前也有MoCo和SimCLR等方法实现了CV领域的自监督预训练。不过,对于自监督模型,代理任务往往是辅助来进行表征学习,在迁移到其它数据集时也需要加上新的分类器来进行有监督训练。
  • CLIP模型是Open AI提出的,Open AI喜欢将一切都GPT化(就是做生成式模型)。OpenAI的GPT系列模型,相较于BERT模型,可以直接zero-shot迁移到下游任务。因此CLIP模型,不仅仅可以进行自监督预训练,更重要的是还能实现zero-shot分类
  • **与CV中常用的先预训练然后微调不同,CLIP可以直接实现zero-shot的图像分类,即不需要任何训练数据,就能在某个具体下游任务上实现分类,**这也是CLIP亮点和强大之处。

1.2 CLIP模型架构

1.2.1 CLIP模型架构详解

那么,CLIP模型是如何实现zero-shot分类的呢?

  • CLIP是一种基于对比学习的多模态模型,CLIP的训练数据是文本-图像对:一张图像和它对应的文本描述,通过对比学习,模型能够学习到文本-图像对的匹配关系。

  • 如下图所示,CLIP包括两个模型:Text EncoderImage Encoder,其中Text Encoder用来提取文本的特征,可以采用NLP中常用的text transformer模型;而Image Encoder用来提取图像的特征,可以采用常用CNN模型或者vision transformer。

  • 预训练过程(下图左半部分)

    • 对于一个包含N个文本-图像对的训练batch,将N个文本特征和N个图像特征两两组合,CLIP模型会预测出 N 2 N^2 N2个可能的文本-图像对的相似度,这里的相似度直接计算文本特征和图像特征的余弦相似性(cosine similarity),即下图所示的矩阵;
    • 这里共有N个正样本,即真正属于一对的文本和图像(矩阵中的对角线元素),而剩余的 N 2 − N N^2−N N2N个文本-图像对为负样本;
    • 那么CLIP的训练目标就是最大N个正样本的相似度,同时最小化 N 2 − N N^2−N N2N个负样本的相似度。
  • 推理过程(下图右半部分)

    • 训练后的CLIP其实是两个模型,除了视觉模型外,还有一个文本模型
    • 用CLIP实现zero-shot分类很简单:
      • 根据任务的分类标签构建每个类别的描述文本:A photo of {object},然后将这些文本送入Text Encoder得到对应的文本特征,如果类别数目为N,那么将得到N个文本特征;
      • 将要预测的图像送入Image Encoder得到图像特征,然后与N个文本特征计算缩放的余弦相似度,然后选择相似度最大的文本对应的类别作为图像分类预测结果。
      • 进一步地,我们也可以对得到的余弦相似度计算softmax,得到每个预测类别的概率值。

在这里插入图片描述

1.2.2 利用CLIP模型进行图片zero-shot分类

CLIP模型已经开源,HuggingFace中transformers库中也集成了这个模型,我们可以先利用transformers库,进行简单的zero-shot图片分类测试:

我们先用一个常见的物体-键盘进行分类:
在这里插入图片描述

from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# 0、准备一张测试图像
image = Image.open('./keyboard.png')
print('image', image)

# 1、加载预训练模型
model_path = '/root/autodl-fs/models/clip-vit-base-patch32'

model = CLIPModel.from_pretrained(model_path)
processor = CLIPProcessor.from_pretrained(model_path)

# 2、相关词选项
text = ["a photo of a computer", "a photo of a mouse", "a photo of a keyboard", "a photo of a cellphone"]

# 3、模型预测
inputs = processor(text=text, images=image, return_tensors='pt', padding=True)
outputs = model(**inputs)

# 4、预测结果归一化
logits_per_image = outputs.logits_per_image
probs = logits_per_image.softmax(dim=1)

# 5、打印结果
probs = probs.detach().numpy().tolist()

for i in range(len(text)):
    print(text[i], ':', probs[0][i])
a photo of a computer  : 0.009659518487751484
a photo of a mouse     : 0.000540732522495091
a photo of a keyboard  : 0.9897673726081848    # 可以看到键盘的概率最高
a photo of a cellphone : 3.2318232115358114e-05

我们换一张动漫天使图片,并改变prompt进行分类

在这里插入图片描述

# 2、更改相关词选项
# text = ["a photo of a computer", "a photo of a mouse", "a photo of a keyboard", "a photo of a cellphone"]
text = ["a photo of a angle", "a photo of a ghost", "a photo of a cat", "a photo of a airplane"]
# 可以看到angle概率最大,而传统的ResNet在ImageNet预训练后只能分1000类
a photo of a angle    : 0.9529269933700562   
a photo of a ghost    : 0.029617968946695328
a photo of a cat      : 0.00526232598349452
a photo of a airplane : 0.012192689813673496

1.3 CLIP模型的训练、prompt

CLIP模型的训练

  • 其实之前已经有研究用文本来作为监督信号来训练视觉模型,但这些方法难以实现较高的性能,作者认为主要原因就是数据集规模太小。因此为了训练CLIP,OpenAI从互联网收集了4个亿的文本-图像对,论文称之为WebImageText,实现了大力出奇迹。

  • CLIP虽然是多模态模型,但它主要是用来训练可迁移的视觉模型。论文中Text Encoder固定选择一个包含63M参数的text transformer模型

  • 而Image Encoder采用了两种的不同的架构

    • 一是常用的CNN架构ResNet。其中ResNet包含5个不同大小的模型:ResNet50ResNet101RN50x4RN50x16RNx64,后面三个模型是按照EfficientNet缩放规则对ResNet分别增大4x,16x和64x得到。
    • 二是基于transformer的ViT。ViT选择3个不同大小的模型:ViT-B/32ViT-B/16ViT-L/14
    • 所有的模型都训练32个epochs,训练过程采用了一个很大的batch size:32768
      • 由于数据量较大,最大的ResNet模型RN50x64需要在592个V100卡上训练18天
      • 而最大ViT模型ViT-L/14需要在256张V100卡上训练12天,可见训练CLIP需要耗费很大的资源,不过OpenAI是不愁计算资源的公司。
      • 对于ViT-L/14,还在336的分辨率下额外finetune了一个epoch来增强性能,论文发现这个模型效果最好,记为ViT-L/14@336,论文中进行对比实验的CLIP模型也采用这个。
  • CLIP模型实现的伪代码如下:

    在这里插入图片描述

prompt

  • prompt engineering是最近NLP领域比较火的一个研究,核心是通过构建合适prompt(提示)来使预训练模型能够直接应用到下游任务,这和之前的预训练+微调不同。
  • 前面的例子我们采用A photo of {label},但其实也有其它选择,比如我们也可以直接用类别标签。但是如果直接采用类别标签作为文本描述,那么很多文本就是一个单词,缺少具体的上下文,而且也和CLIP的训练数据不太一致,效果上会不如采用A photo of {label}

1.4 部分实验结果

CLIP论文仅正文就长达27页,在30多个数据集上进行了大量实验,这里选择几个进行展示。

1.4.1 和ResNet50对比

  • 下图对比了zero-shot CLIP和ResNet50 linear probing(ImageNet数据上预训练,在加上线性分类层进行finetune)在27个数据集上表现。
  • 其中在16个数据集上CLIP可以超过ResNet50。但是在一些特别的,复杂的或者抽象的数据集上CLIP表现较差,比如卫星图像分类,淋巴结转移检测,在合成场景中计数等。
  • 下图CLIP表现较差的竟然还有MNIST数据集,分类准确度只有88%。通过对CLIP训练数据进行分析,作者发现4亿的训练数据中基本上没有和MNIST比较相似的数据,所以这对CLIP来说就属于域外数据了,表现较差就比较容易理解了。
  • CLIP的zero-shot性能虽然和有监督的ResNet50相当,但是ResNet50距离SOTA效果还很远,作者估计要达到SOTA的效果,CLIP还需要增加1000x的计算量。

在这里插入图片描述

1.4.2 CLIP的few-shot性能

  • 可以看到CLIP的zero-shot和最好的模型(BiT-M)在16-shot下的性能相当;
  • CLIP在16-shot下效果有进一步的提升;
  • 虽然CLIP在few-shot实验中随着样本量增加性能有提升,但是1-shot和2-shot性能比zero-shot还差,这个作者认为主要是CLIP的训练和常规的有监督训练存在一定的差异造成的。

在这里插入图片描述

1.4.3 CLIP在自然分布漂移上表现更鲁棒

  • 论文发现CLIP在自然分布漂移上表现更鲁棒
  • 比如在ImageNet-A数据集上,ResNet101性能只有2.7%,而CLIP能达到77.1%。

在这里插入图片描述

2 CLIP模型代码分析

  • 官方源码:https://github.com/openai/CLIP
  • 值得注意的是,OpenAI并没有开源CLIP的训练代码,因此我们仅使用OpenAI官方的代码来解析推理过程。
  • 如果对训练部分感兴趣,可以参考非官方实现版本:mlfoundations/open_clip: An open source implementation of CLIP

2.1 CLIP推理代码

在这里插入图片描述

  • 我们使用下面代码,进行CLIP前向推理,clip.load会自动下载相关模型权重。
import clip
import torch
from PIL import Image

if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('loading model ...')
    model, preprocess = clip.load("ViT-B/32", device=device, download_root='/root/autodl-fs/models/clip_vit')

    # 图像预处理,input是clip的架构图,预处理后shape=(1, 3, 224, 224)
    image = preprocess(Image.open("./CLIP.png")).unsqueeze(0).to(device)
    # 文本预处理,预处理后shape=(3, 77)
    text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

    with torch.no_grad():
        # logits_per_image shape = (1, 3)
        logits_per_image, logits_per_text = model(image, text)  # 输入模型,执行前向推理
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()  # softmax归一化

    print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]

2.1.1 图像预处理

  • clip.load会返回preprocess,其实质是下面_transform代码所示
  • 图像预处理有:Resize、裁剪、转为RGB、ToTensor以及Normalize。
# clip/clip.py
def _transform(n_px):
    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])

2.1.2 文本预处理

  • 每一个句子会padding到77的长度,超过的会截断
# clip/clip.py
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]

    result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result

2.2 CLIP模型部分

2.2.1 模型初始化

# clip/model.py
class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int, # 图像被编码的维度与文本相同
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            # 图像编码方式一:使用ResNet结构
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            # 图像编码方式一:使用ViT结构
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )
        # 文本编码使用Transformer
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        # 对单词进行embedding
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        # 可学习的位置编码
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        # LayerNorm
        self.ln_final = LayerNorm(transformer_width)
        # text映射
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        # 可学习的logit缩放因子
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        
        # 权重初始化
        self.initialize_parameters()

2.2.2 forward函数

  • 首先对输入图像和文本,分别编码
  • 然后进行特征归一化,最后计算图像与文本的相似度
    # clip/model.py
    
    def forward(self, image, text):
        # 1、对输入图像和文本,分别编码
        image_features = self.encode_image(image)  # 最终的image_features:(1,512)
        text_features = self.encode_text(text)     # 最终的text_features :(3,512)

        # 2、特征归一化
        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True) 
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # 3、求解图像与文本的相似度,获得图像与文本匹配结果
        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()   # 对logit进行缩放
        # 特征相乘获得相似度 [batch_img, batch_text]
        # logits_per_image shape = [1, 3]
        logits_per_image = logit_scale * image_features @ text_features.t() 
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text

2.2.3 图像编码

  • 这里讲解利用VisionTransformer进行图像编码
  • 核心就是利用卷积将一张图像转变为一个序列,可以参考:当CV遇上transformer(一)ViT模型
    # clip/model.py
    def encode_image(self, image):
        return self.visual(image.type(self.dtype))  # 转成fp16

2.2.4 文本编码

    # clip/model.py
    def encode_text(self, text):
        # 1、对tokenized text进行编码
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
        # 2、加上位置编码
        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        # 3、多层ResidualAttentionBlock
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        # 4、layer_norm
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # x.shape = [3, 77, 512]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        # [batch_size=3, transformer.width=512] @ [512,512] = [3,512]
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        
        return x

3 利用CLIP预训练模型实现以图搜图

  • 可以使用ResNet或者clip实现对图片的特征抽取
  • 首先,需要利用模型对图片库中图像进行特征提取,并将结果保存起来
  • 当想要搜寻一张图片的相似图片时,需要计算这张图片和图片向量库中的相似度,获取topk个相似图片。
import glob
import os
import numpy as np
from PIL import Image

import torch
import argparse
import timm
import torchvision
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel
"""图像搜索引擎"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_mean_std_of_dataset(dataset_dir):
    """统计数据集的均值和标准差"""
    train_files = glob.glob(os.path.join(dataset_dir, "*.jpg"))

    print(f'total {len(train_files)} files for training')
    result = []
    # 遍历所有的图片
    for file in train_files:
        img = Image.open(file).convert('RGB')
        img = np.array(img).astype(np.uint8)
        # 像素缩放到0-1
        img = img / 255.
        result.append(img)

    # result shape = [BS, H, W, C]
    # 对每个通道求均值和标准差
    mean = np.mean(result, axis=(0, 1, 2))
    std = np.std(result, axis=(0, 1, 2))
    print(f'mean = {mean}, std = {std}')
    return mean, std


def get_args():
    parser = argparse.ArgumentParser(description='Image Search Task')
    parser.add_argument('--input_size', type=int, default=128, help='images input size')
    parser.add_argument('--dataset_dir', default="/root/autodl-fs/data/fruit20/dataset/train", help='images path')
    parser.add_argument('--test_image_dir', default="/root/autodl-fs/data/fruit20/dataset/val", help='test images path')
    parser.add_argument('--save_dir', default="output_dir", help='path to save')
    parser.add_argument('--model_name', default="clip", help='model name: renet50 or renet152 or clip')
    parser.add_argument('--feature_dict_file', default="corpus_feature_dict.npy",
                        help='filename where to save image representations')
    parser.add_argument('--topk', type=int, default=7, help='k most similar images')
    parser.add_argument('--mode', default="extract", help='extract or predict')
    args = parser.parse_args()
    return args

def extract_feature_by_clip(model, preprocess, image_file_path):
    # 1、读取图像并进行预处理
    image = Image.open(image_file_path)
    inputs = preprocess(images=image, return_tensors='pt')
    # 2、将输入传递给模型获取图像的特征
    with torch.no_grad():
        features = model.get_image_features(**inputs)
        vec_features = features.squeeze().cpu().numpy()

        return vec_features

def extract_feature_single(args, model, image_file_path):
    image_rgb = Image.open(image_file_path).convert('RGB')
    image = image_rgb.resize(args.input_size, args.input_size)
    image = torchvision.transforms.ToTensor()(image)
    image = torchvision.transforms.Normalize(mean=[0.47, 0.43, 0.32], std=[0.37, 0.36, 0.34])(image).unsqueeze(0)

    with torch.no_grad():
        features = model.forward_features(image)
        vec_features = model.global_pool(features)
        vec_features = vec_features.squeeze().cpu().numpy()

        return vec_features


def extract_features(args, model, img_path, preprocess):
    all_vectors = {}
    train_files_path = glob.glob(os.path.join(img_path, "*.jpg"))
    train_files_path += glob.glob(os.path.join(img_path, "*.png"))
    for image_file_path in tqdm(train_files_path):
        if args.model_name == "clip":
            # 1、通过clip提取特征
            all_vectors[image_file_path] = extract_feature_by_clip(model, preprocess, image_file_path)
        else:
            # 2、通过ResNet提取特征
            all_vectors[image_file_path] = extract_feature_single(args, model, image_file_path)
    # 将提取出的图像特征保存起来
    os.makedirs(f"./{args.save_dir}/{args.model_name}", exist_ok=True)
    np.save(f"{args.save_dir}/{args.model_name}/{args.feature_dict_file}", all_vectors)
    return all_vectors


def get_similar_matrix(vectors_dict):
    """计算给定向量字典中各个向量之间的相似度,其中相似度的计算采用了向量之间的余弦相似度"""
    # 1、每行代表一个向量
    v = np.array(list(vectors_dict.values())) # [NUM, H]
    # 2、计算相似度矩阵的分子部分
    numerator = np.matmul(v, v.T) # [NUM, NUM]
    # 3、计算相似度矩阵的分母部分(计算每对向量之间的范数乘积)
    denominator = np.matmul(
        np.linalg.norm(v, axis=1, keepdims=True),
        np.linalg.norm(v, axis=1, keepdims=True).T
    ) # [NUM, NUM]
    # 4、到相似度矩阵 sim,其中 sim[i, j] 表示向量 i 和向量 j 之间的相似度
    sim = numerator / denominator
    keys = list(vectors_dict.keys())
    return sim, keys


if __name__ == '__main__':
    args = get_args()

    model = None
    processor = None
    if args.model_name != "clip":
        # 利用renet50 or renet152作为特征抽取器
        model = timm.create_model(args.model_name, pretrained=True)
        model.eval()
    else:
        # 加载openai clip预训练模型
        model_path = '/root/autodl-fs/models/clip-vit-base-patch32'
        model = CLIPModel.from_pretrained(model_path)
        processor = CLIPProcessor.from_pretrained(model_path)

    if args.mode == "extract":
        # 1、利用预训练模型抽取特征,并保存下来
        print(f'use pretrained model {args.model_name} to extract features')
        extract_features(args, model, img_path=args.dataset_dir, preprocess=processor)
    else:
        # 2、以图搜图
        print(f'use pretrained model {args.model_name} to search {args.topk} similar images from corpus')
        test_images = glob.glob(os.path.join(args.test_image_dir, "*.jpg"))
        test_images += glob.glob(os.path.join(args.test_image_dir, "*.png"))

        # 2-1加载图像向量
        all_vectors = np.load(f"./{args.save_dir}/{args.model_name}/{args.feature_dict_file}", allow_pickle=True)
        all_vectors = all_vectors.item()

        # 2-2 提取搜索图像的图像特征
        for image_file_path in tqdm(test_images):
            print(f'reading {image_file_path} ......')
            if args.model_name == "clip":
                all_vectors[image_file_path] = extract_feature_by_clip(model, processor, image_file_path)
            else:
                all_vectors[image_file_path] = extract_feature_single(args, model, image_file_path)
        # 2-3 获取相似度矩阵及相似图片路径
        sims, keys = get_similar_matrix(all_vectors)

        # 2-4 获取topk个相似图片
        result = {}
        for image_file in tqdm(test_images):
            index = keys.index(image_file)
            sim_vec = sims[index]
            indexs = np.argsort(sim_vec)[::-1][1:args.topk]
            sim_imgs, sim_socres = [], []
            for ind in indexs:
                sim_imgs.append(keys[ind])
                sim_socres.append(sim_vec[ind])

            result[image_file] = (sim_imgs, sim_socres)
        print(result)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/632017.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

FebHost:为什么企业需要注册保加利亚.BG域名?

在当今全球化的商业环境中&#xff0c;对于与保加利亚市场息息相关的企业而言&#xff0c;选择合适的域名至关重要。.BG域名作为企业在线身份的重要组成部分&#xff0c;提供了多重利好&#xff0c;成为业内不容忽视的战略资源。 首先&#xff0c;地域标识性强是.BG域名的一大…

ClassificationPrimitive 内部原理

ClassificationPrimitive 内部原理 发明 ClassificationPrimitive的真是个天才。其原理是利用 webgl 的模板缓冲区实现。 渲染两次, 首先是绘制模板, 然后绘制真正的内容。 示意图: function createClass() {const { program, uniforms } WebGLProgram.buildPrograms(gl, …

CST电磁仿真软件什么是Schematic?三维模型和电路协同仿真【小白必学教程】

什么是Schematic? 使用CST Design Studio进行的各种分析&#xff01; Schematic 进行三维仿真时&#xff0c;有时需要将3D模型和电路图放在一起进行仿真分析。比如需要天线和匹配电路协同仿真&#xff0c;两者构成完整的电路图可以系统地分析In/0ut特性。按下3D工作界面下方…

Spring Security实现用户认证一:简单示例

Spring Security实现用户认证一&#xff1a;简单示例 1 原理1.1 用户认证怎么进行和保存的&#xff1f;认证流程SecurityContext保存 2 创建简单的登录认证示例2.1 pom.xml依赖添加2.2 application.yaml配置2.3 创建WebSecurityConfig配置类2.4 测试 1 原理 Spring Security是…

React 第三十八章 React 中的位运算

位运算是一种计算机编程中常用的操作&#xff0c;它直接对二进制位进行操作。二进制&#xff0c;指的就是以二为底的一种计数方式&#xff0c;常见的还有八进制、十进制、十六进制。 十进制0123456789101112131415二进制0000000100100011010001010110011110001001101010111100…

【面试干货】 两个有序数组的合并排序

【面试干货】 两个有序数组的合并排序 1、实现思想2、代码实现 &#x1f496;The Begin&#x1f496;点点关注&#xff0c;收藏不迷路&#x1f496; 1、实现思想 使用两个指针分别指向两个数组的起始位置&#xff0c;然后逐个比较两个指针所指向的元素&#xff0c;将较小的元素…

【全开源】场地预定小程序支持微信小程序+微信公众号+H5

XYvenue是基于FastAdminUniApp开发的多场馆场地预定小程序&#xff0c;提供运动场馆运营解决方案&#xff0c;适用于体育馆、羽毛球馆、兵乒球馆、篮球馆、网球馆等场馆。 功能特性 1、场馆管理 可添加多个预约场馆&#xff0c;小程序端切换场馆显示。 2、场地管理 可添加多…

C语言如何删除表中指定位置的结点?

一、问题 如何删除链表中指定位置的结点&#xff1f; 二、解答 删除链表中指定的结点&#xff0c;就像是排好队的⼩朋友⼿牵着⼿&#xff0c;将其中⼀个⼩朋友从队伍中分出来&#xff0c;只需将这个⼩朋友的双⼿从两边松开。 删除结点有两种情况&#xff1a; &#xff08;1&am…

怎么删除pdf中的某一页?五种高效删除方法

怎么删除pdf中的某一页&#xff1f;PDF文件是我们在工作中经常需要处理的一类文件&#xff0c;它的格式很稳定&#xff0c;不易修改。但是&#xff0c;有时候我们可能需要对PDF文件进行编辑&#xff0c;比如删除其中的某一页。本文将为你介绍五种高效的方法&#xff0c;帮助你轻…

python 脚本压缩文件linux 正常,windows 文件夹/文件名称 被加上了上级文件夹名

场景&#xff1a; php 在调用python 脚本&#xff0c;进行文件压缩&#xff08;因为php的压缩大文件总是超时&#xff09;&#xff0c;linux/mac 环境文件/文件夹名压缩前后一致&#xff0c;windows 压缩后 文件/文件夹名被改变为 上级 文件夹原名 原因&#xff1a; window…

短视频批量剪辑,智能素材文案生成,多账号授权私信回复与矩阵发布素材功能合集系统,短视频矩阵助手源码搭建部署源码开源部署方案。

目录 一、短视频矩阵助手系统是什么&#xff1f; 二、短视频矩阵助手系统可以为企业解决什么问题&#xff1f; 短视频矩阵助手可以解决哪些问题&#xff1f; 三、短视频矩阵助手系统功能有哪些&#xff1f; 四、总结 一、短视频矩阵助手系统是什么&#xff1f; 短视频矩阵…

提升LED显示屏散热效能的五大策略

在现代生活中&#xff0c;LED显示屏已成为不可或缺的信息展示工具&#xff0c;其广泛应用于商业广告、公共信息发布、舞台表演等多个领域。然而&#xff0c;随着LED显示屏的长时间运行&#xff0c;散热问题逐渐凸显&#xff0c;不仅影响设备的稳定性和寿命&#xff0c;还可能导…

Python实战开发及案例分析(25)—— 爬山算法

爬山算法&#xff08;Hill Climbing&#xff09;是一种启发式搜索算法&#xff0c;常用于解决优化问题。它的核心思想是从一个初始解开始&#xff0c;不断朝着增益最大的方向移动&#xff0c;直到达到局部最优解。 实现步骤 从初始解开始。在当前解的邻域中找到一个更好的解。…

Java入门基础学习笔记26——break,continue

跳转关键字&#xff1a; break&#xff1a; 跳出并结束当前所在循环的执行。 continue&#xff1a; 用于跳出当前循环中的当次执行&#xff0c;直接进入循环中的下一次执行。 package cn.ensource.loop;public class BreakContinueDemo8 {public static void main(String[] a…

AI大语言模型在公共服务中的应用实例

随着计算机技术的飞速发展&#xff0c;人工智能已经成为了当今科技领域的热门话题。从早期的图灵测试到现在的深度学习和神经网络&#xff0c;人工智能已经取得了令人瞩目的成就。特别是近年来&#xff0c;大数据、云计算、高性能计算等技术的发展为人工智能的研究提供了更加广…

怎么做微信预约链接_微信预约新风尚

在快节奏的现代生活中&#xff0c;我们都渴望找到一种既方便又高效的方式来处理日常事务。无论是预约看病、预约美容&#xff0c;还是预约一场心仪的讲座或活动&#xff0c;我们都希望能够一键搞定&#xff0c;省时省力。今天&#xff0c;就让我来为大家揭秘如何制作一个微信预…

Facebook海外企业户/海外企业三不限户稳定性怎么样?

Facebook是做跨境电商卖家最有效的营销工具之一&#xff0c;不过相对的在Facebook上的广告竞争也会越来越激烈。目前外贸行业发展迅速。Facebook作为每天拥有30亿人口的活跃网络平台&#xff0c;约占全球网络用户的30%。平均来说&#xff0c;它的用户愿意每天花60分钟在平台上浏…

美港通正规股票交易市场人民币突然拉升,市场开启“大风车”模式?

查查配今天上午,市场又开启了“大风车”模式,多个热点轮番拉升。 一则关于地产行业利好的小作文流出,地产产业链上午爆发,租售同权、房地产服务、房地产开发等板块大涨,光大嘉宝、天地源等个股涨停。万科A涨超4%。 美港通证券以其专业的服务和较低的管理费用在市场中受到不少…

【上海生物发酵展精选展商】三门峡市高瑞生物技术有限公司

三门峡市高瑞生物技术有限公司注册成立于2017年2月23日&#xff0c;经营范围是微生物培养基原材料制造、销售。2017年度因场地搬迁、异地重建&#xff0c;公司由“三门峡市高山生物制品有限公司”更名为“三门峡市高瑞生物技术有限公司”。 该公司具有20余年丰富经验的微生物培…