当CV遇上transformer(三)Clip模型及源码分析
- 2020年10月,Dosovitskiy首次将纯Transformer的网络结构应用于图像分类任务中(ViT),并取得了当时最优的分类效果,其研究成果是Transformer完全替代标准卷积的首次尝试。随着谷歌提出ViT之后,一大批的vision transformer的工作席卷计算机视觉任务。
- Open AI在2021年1月份发布DALL-E和CLIP,其中DALL-E是基于文本来生成图像的模型,而CLIP是用文本作为监督信号来训练可迁移的视觉模型,这两个工作也像ViT一样带动了一波新的研究高潮。
- OpenAI的CLIP模型显著地改变了研究人员处理多模态数据的方式,但是CLIP只是一个开始。
- 从预训练数据到训练方法和对比损失函数的细节,CLIP家族在过去几年中取得了令人难以置信的进步。
- ALIGN缩放噪声文本(https://arxiv.org/abs/2102.05918)
- K-LITE增强外部知识(https://arxiv.org/abs/2204.09222)
- OpenCLIP研究缩放定律(https://arxiv.org/abs/2212.07143)
- MetaCLIP优化数据管理(https://arxiv.org/abs/2309.16671)
- DFN增强数据质量(https://arxiv.org/abs/2309.17425)
- 虽然原始论文中只对用CLIP进行zero-shot分类做了实验,但其实CLIP的应用价值远不止此。依据李沐团队的CLIP改进工作串讲总结如下(CLIP 改进工作串讲(上)【论文精读·42】):
- 语义分割
- LSeg与 CLIP 实现 zero-shot 的方式类似,它通过类别 prompt 作为文本输入,然后计算相似度,也实现了zero-shot 语义分割(https://arxiv.org/pdf/2201.03546)。
- GroupViT与CLIP类似,利用图像文本对进行无监督的训练,从而让模型进行简单的分割任务(https://arxiv.org/pdf/2202.11094)。
- 目标检测
- ViLD把CLIP模型当作teacher去蒸馏作者设计的网络,从而达到利用zero-shot去做目标检测的目的(https://arxiv.org/pdf/2104.13921)。
- GLIP创造性地将目标检测任务转换为短语定位任务。即对待任意一张训练图片,把标签用句号隔开,拼接成一句话。再结合伪标价签的技术来扩增数据(https://arxiv.org/pdf/2112.03857)。
- GLIPv2大概思想与GLIP差不多,框架上也很相似,只不过带入了更多的任务和数据(https://arxiv.org/pdf/2206.05836)。
- 图像生成
- CLIPasso就是给模型一张真实的照片,模型就能还给他一张最简形式的简笔画。利用CLIP的强大能力,从速写和图像中提炼语义概念 ,将速写定义为一组贝兹曲线(贝塞尔曲线)。然后用一个可微调光栅化器直接针对基于CLIP的感知损失,优化曲线参数(https://arxiv.org/pdf/2202.05822)。
- CLIP4Clip利用CLIP模型去做视频里的video-text retrival(视频-文本跨模态检索)。因为CLIP天生就很适合做检索工作,它就是在算图像和文本之间的相似性,根据相似性来做ranking、mathcing、retrieve各种类似的任务(https://arxiv.org/pdf/2104.08860)。
- Action CLIIP利用CLIP研究视频理解动作识别(https://arxiv.org/pdf/2109.08472)。
- 语义分割
- 今天,我们来了解下能够实现zero-shot分类的CLIP模型。
- 论文链接:Learning Transferable Visual Models From Natural Language Supervision
- 官方源码:GitHub - openai/CLIP: CLIP
1 CLIP模型的整体架构
1.1 CLIP简述
- 在计算机视觉领域,最常采用的迁移学习方式就是先在一个较大规模的数据集如ImageNet上预训练,然后在具体的下游任务上再进行微调,这里的预训练是基于
有监督训练的,需要大量的数据标注,因此成本较高
。 - 在NLP领域,自监督预训练使用十分广泛。在BERT中,以一定比例 mask 掉输入文本中的一些部分,让模型去预测这批被 mask 掉的内容。这样,利用数据本身就可以作为监督(
模型要预测的目标来源于数据本身,并非人工构造
),无需复杂的人工标注。 - 不过,我们之前讲过何凯明大神提出的MAE模型,以一定比例随机 mask 掉图片中的一些图像块(patch),然后重建这些部分的像素值,实现了CV领域的自监督预训练。虽然MAE模型提出的时间比CLIP模型晚,但之前也有MoCo和SimCLR等方法实现了CV领域的自监督预训练。不过,对于自监督模型,代理任务往往是辅助来进行表征学习,在迁移到其它数据集时也需要加上新的分类器来进行有监督训练。
- CLIP模型是Open AI提出的,Open AI喜欢将一切都GPT化(
就是做生成式模型
)。OpenAI的GPT系列模型,相较于BERT模型,可以直接zero-shot迁移到下游任务。因此CLIP模型,不仅仅可以进行自监督预训练,更重要的是还能实现zero-shot分类
。 - **与CV中常用的先预训练然后微调不同,CLIP可以直接实现zero-shot的图像分类,即不需要任何训练数据,就能在某个具体下游任务上实现分类,**这也是CLIP亮点和强大之处。
1.2 CLIP模型架构
1.2.1 CLIP模型架构详解
那么,CLIP模型是如何实现zero-shot分类的呢?
-
CLIP是一种基于对比学习的多模态模型,CLIP的训练数据是文本-图像对:
一张图像和它对应的文本描述
,通过对比学习,模型能够学习到文本-图像对的匹配关系。 -
如下图所示,CLIP包括两个模型:Text Encoder和Image Encoder,其中Text Encoder用来提取文本的特征,可以采用NLP中常用的text transformer模型;而Image Encoder用来提取图像的特征,可以采用常用CNN模型或者vision transformer。
-
预训练过程(下图左半部分)
- 对于一个包含N个文本-图像对的训练batch,将N个文本特征和N个图像特征两两组合,CLIP模型会预测出 N 2 N^2 N2个可能的文本-图像对的相似度,这里的相似度直接计算文本特征和图像特征的余弦相似性(cosine similarity),即下图所示的矩阵;
- 这里共有N个正样本,即真正属于一对的文本和图像(
矩阵中的对角线元素
),而剩余的 N 2 − N N^2−N N2−N个文本-图像对为负样本; - 那么CLIP的训练目标就是最大N个正样本的相似度,同时最小化 N 2 − N N^2−N N2−N个负样本的相似度。
-
推理过程(下图右半部分)
- 训练后的CLIP其实是两个模型,除了视觉模型外,还有一个文本模型
- 用CLIP实现zero-shot分类很简单:
- 根据任务的分类标签构建每个类别的描述文本:
A photo of {object}
,然后将这些文本送入Text Encoder得到对应的文本特征,如果类别数目为N,那么将得到N个文本特征; - 将要预测的图像送入Image Encoder得到图像特征,然后与N个文本特征计算缩放的余弦相似度,然后选择相似度最大的文本对应的类别作为图像分类预测结果。
- 进一步地,我们也可以对得到的余弦相似度计算softmax,得到每个预测类别的概率值。
- 根据任务的分类标签构建每个类别的描述文本:
1.2.2 利用CLIP模型进行图片zero-shot分类
CLIP模型已经开源,HuggingFace中transformers库中也集成了这个模型,我们可以先利用transformers库,进行简单的zero-shot图片分类测试:
我们先用一个常见的物体-键盘进行分类:
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
# 0、准备一张测试图像
image = Image.open('./keyboard.png')
print('image', image)
# 1、加载预训练模型
model_path = '/root/autodl-fs/models/clip-vit-base-patch32'
model = CLIPModel.from_pretrained(model_path)
processor = CLIPProcessor.from_pretrained(model_path)
# 2、相关词选项
text = ["a photo of a computer", "a photo of a mouse", "a photo of a keyboard", "a photo of a cellphone"]
# 3、模型预测
inputs = processor(text=text, images=image, return_tensors='pt', padding=True)
outputs = model(**inputs)
# 4、预测结果归一化
logits_per_image = outputs.logits_per_image
probs = logits_per_image.softmax(dim=1)
# 5、打印结果
probs = probs.detach().numpy().tolist()
for i in range(len(text)):
print(text[i], ':', probs[0][i])
a photo of a computer : 0.009659518487751484
a photo of a mouse : 0.000540732522495091
a photo of a keyboard : 0.9897673726081848 # 可以看到键盘的概率最高
a photo of a cellphone : 3.2318232115358114e-05
我们换一张动漫天使图片,并改变prompt进行分类
:
# 2、更改相关词选项
# text = ["a photo of a computer", "a photo of a mouse", "a photo of a keyboard", "a photo of a cellphone"]
text = ["a photo of a angle", "a photo of a ghost", "a photo of a cat", "a photo of a airplane"]
# 可以看到angle概率最大,而传统的ResNet在ImageNet预训练后只能分1000类
a photo of a angle : 0.9529269933700562
a photo of a ghost : 0.029617968946695328
a photo of a cat : 0.00526232598349452
a photo of a airplane : 0.012192689813673496
1.3 CLIP模型的训练、prompt
CLIP模型的训练
-
其实之前已经有研究用文本来作为监督信号来训练视觉模型,但这些方法难以实现较高的性能,作者认为主要原因就是数据集规模太小。因此为了训练CLIP,OpenAI从互联网收集了
4个亿的文本-图像对
,论文称之为WebImageText,实现了大力出奇迹。 -
CLIP虽然是多模态模型,但它主要是用来训练可迁移的视觉模型。论文中Text Encoder固定选择一个包含63M参数的text transformer模型
-
而Image Encoder采用了两种的不同的架构
- 一是常用的CNN架构ResNet。其中ResNet包含5个不同大小的模型:ResNet50,ResNet101,RN50x4,RN50x16和RNx64,后面三个模型是按照EfficientNet缩放规则对ResNet分别增大4x,16x和64x得到。
- 二是基于transformer的ViT。ViT选择3个不同大小的模型:ViT-B/32,ViT-B/16和ViT-L/14。
- 所有的模型都训练32个epochs,训练过程采用了一个很大的batch size:32768。
- 由于数据量较大,最大的ResNet模型RN50x64需要在592个V100卡上训练18天
- 而最大ViT模型ViT-L/14需要在256张V100卡上训练12天,可见训练CLIP需要耗费很大的资源,不过OpenAI是不愁计算资源的公司。
- 对于ViT-L/14,还在336的分辨率下额外finetune了一个epoch来增强性能,论文发现这个模型效果最好,记为ViT-L/14@336,论文中进行对比实验的CLIP模型也采用这个。
-
CLIP模型实现的伪代码如下:
prompt
prompt engineering
是最近NLP领域比较火的一个研究,核心是通过构建合适prompt(提示)来使预训练模型能够直接应用到下游任务,这和之前的预训练+微调
不同。- 前面的例子我们采用
A photo of {label}
,但其实也有其它选择,比如我们也可以直接用类别标签。但是如果直接采用类别标签作为文本描述,那么很多文本就是一个单词,缺少具体的上下文,而且也和CLIP的训练数据不太一致,效果上会不如采用A photo of {label}
。
1.4 部分实验结果
CLIP论文仅正文就长达27页,在30多个数据集上进行了大量实验,这里选择几个进行展示。
1.4.1 和ResNet50对比
- 下图对比了zero-shot CLIP和ResNet50 linear probing(ImageNet数据上预训练,在加上线性分类层进行finetune)在27个数据集上表现。
- 其中在16个数据集上CLIP可以超过ResNet50。但是在一些特别的,复杂的或者抽象的数据集上CLIP表现较差,比如卫星图像分类,淋巴结转移检测,在合成场景中计数等。
- 下图CLIP表现较差的竟然还有MNIST数据集,分类准确度只有88%。通过对CLIP训练数据进行分析,作者发现4亿的训练数据中基本上没有和MNIST比较相似的数据,所以这对CLIP来说就属于域外数据了,表现较差就比较容易理解了。
- CLIP的zero-shot性能虽然和有监督的ResNet50相当,但是ResNet50距离SOTA效果还很远,作者估计要达到SOTA的效果,CLIP还需要增加1000x的计算量。
1.4.2 CLIP的few-shot性能
- 可以看到CLIP的zero-shot和最好的模型(BiT-M)在16-shot下的性能相当;
- CLIP在16-shot下效果有进一步的提升;
- 虽然CLIP在few-shot实验中随着样本量增加性能有提升,但是1-shot和2-shot性能比zero-shot还差,这个作者认为主要是CLIP的训练和常规的有监督训练存在一定的差异造成的。
1.4.3 CLIP在自然分布漂移上表现更鲁棒
- 论文发现CLIP在自然分布漂移上表现更鲁棒
- 比如在ImageNet-A数据集上,ResNet101性能只有2.7%,而CLIP能达到77.1%。
2 CLIP模型代码分析
- 官方源码:https://github.com/openai/CLIP
- 值得注意的是,OpenAI并没有开源CLIP的训练代码,因此我们仅使用OpenAI官方的代码来解析推理过程。
- 如果对训练部分感兴趣,可以参考非官方实现版本:mlfoundations/open_clip: An open source implementation of CLIP
2.1 CLIP推理代码
- 我们使用下面代码,进行CLIP前向推理,clip.load会自动下载相关模型权重。
import clip
import torch
from PIL import Image
if __name__ == '__main__':
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('loading model ...')
model, preprocess = clip.load("ViT-B/32", device=device, download_root='/root/autodl-fs/models/clip_vit')
# 图像预处理,input是clip的架构图,预处理后shape=(1, 3, 224, 224)
image = preprocess(Image.open("./CLIP.png")).unsqueeze(0).to(device)
# 文本预处理,预处理后shape=(3, 77)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)
with torch.no_grad():
# logits_per_image shape = (1, 3)
logits_per_image, logits_per_text = model(image, text) # 输入模型,执行前向推理
probs = logits_per_image.softmax(dim=-1).cpu().numpy() # softmax归一化
print("Label probs:", probs) # prints: [[0.9927937 0.00421068 0.00299572]]
2.1.1 图像预处理
- clip.load会返回preprocess,其实质是下面_transform代码所示
- 图像预处理有:Resize、裁剪、转为RGB、ToTensor以及Normalize。
# clip/clip.py
def _transform(n_px):
return Compose([
Resize(n_px, interpolation=BICUBIC),
CenterCrop(n_px),
_convert_image_to_rgb,
ToTensor(),
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
2.1.2 文本预处理
- 每一个句子会padding到77的长度,超过的会截断
# clip/clip.py
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
sot_token = _tokenizer.encoder["<|startoftext|>"]
eot_token = _tokenizer.encoder["<|endoftext|>"]
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
if truncate:
tokens = tokens[:context_length]
tokens[-1] = eot_token
else:
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
result[i, :len(tokens)] = torch.tensor(tokens)
return result
2.2 CLIP模型部分
2.2.1 模型初始化
# clip/model.py
class CLIP(nn.Module):
def __init__(self,
embed_dim: int, # 图像被编码的维度与文本相同
# vision
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int
):
super().__init__()
self.context_length = context_length
if isinstance(vision_layers, (tuple, list)):
# 图像编码方式一:使用ResNet结构
vision_heads = vision_width * 32 // 64
self.visual = ModifiedResNet(
layers=vision_layers,
output_dim=embed_dim,
heads=vision_heads,
input_resolution=image_resolution,
width=vision_width
)
else:
# 图像编码方式一:使用ViT结构
vision_heads = vision_width // 64
self.visual = VisionTransformer(
input_resolution=image_resolution,
patch_size=vision_patch_size,
width=vision_width,
layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim
)
# 文本编码使用Transformer
self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask()
)
self.vocab_size = vocab_size
# 对单词进行embedding
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
# 可学习的位置编码
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
# LayerNorm
self.ln_final = LayerNorm(transformer_width)
# text映射
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
# 可学习的logit缩放因子
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
# 权重初始化
self.initialize_parameters()
2.2.2 forward函数
- 首先对输入图像和文本,分别编码
- 然后进行特征归一化,最后计算图像与文本的相似度
# clip/model.py
def forward(self, image, text):
# 1、对输入图像和文本,分别编码
image_features = self.encode_image(image) # 最终的image_features:(1,512)
text_features = self.encode_text(text) # 最终的text_features :(3,512)
# 2、特征归一化
# normalized features
image_features = image_features / image_features.norm(dim=1, keepdim=True)
text_features = text_features / text_features.norm(dim=1, keepdim=True)
# 3、求解图像与文本的相似度,获得图像与文本匹配结果
# cosine similarity as logits
logit_scale = self.logit_scale.exp() # 对logit进行缩放
# 特征相乘获得相似度 [batch_img, batch_text]
# logits_per_image shape = [1, 3]
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
# shape = [global_batch_size, global_batch_size]
return logits_per_image, logits_per_text
2.2.3 图像编码
- 这里讲解利用VisionTransformer进行图像编码
- 核心就是
利用卷积将一张图像转变为一个序列
,可以参考:当CV遇上transformer(一)ViT模型
# clip/model.py
def encode_image(self, image):
return self.visual(image.type(self.dtype)) # 转成fp16
2.2.4 文本编码
# clip/model.py
def encode_text(self, text):
# 1、对tokenized text进行编码
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
# 2、加上位置编码
x = x + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
# 3、多层ResidualAttentionBlock
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
# 4、layer_norm
x = self.ln_final(x).type(self.dtype)
# x.shape = [batch_size, n_ctx, transformer.width]
# x.shape = [3, 77, 512]
# take features from the eot embedding (eot_token is the highest number in each sequence)
# [batch_size=3, transformer.width=512] @ [512,512] = [3,512]
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return x
3 利用CLIP预训练模型实现以图搜图
- 可以使用ResNet或者clip实现对图片的特征抽取
- 首先,需要利用模型对图片库中图像进行特征提取,并将结果保存起来
- 当想要搜寻一张图片的相似图片时,需要计算这张图片和图片向量库中的相似度,获取topk个相似图片。
import glob
import os
import numpy as np
from PIL import Image
import torch
import argparse
import timm
import torchvision
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel
"""图像搜索引擎"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_mean_std_of_dataset(dataset_dir):
"""统计数据集的均值和标准差"""
train_files = glob.glob(os.path.join(dataset_dir, "*.jpg"))
print(f'total {len(train_files)} files for training')
result = []
# 遍历所有的图片
for file in train_files:
img = Image.open(file).convert('RGB')
img = np.array(img).astype(np.uint8)
# 像素缩放到0-1
img = img / 255.
result.append(img)
# result shape = [BS, H, W, C]
# 对每个通道求均值和标准差
mean = np.mean(result, axis=(0, 1, 2))
std = np.std(result, axis=(0, 1, 2))
print(f'mean = {mean}, std = {std}')
return mean, std
def get_args():
parser = argparse.ArgumentParser(description='Image Search Task')
parser.add_argument('--input_size', type=int, default=128, help='images input size')
parser.add_argument('--dataset_dir', default="/root/autodl-fs/data/fruit20/dataset/train", help='images path')
parser.add_argument('--test_image_dir', default="/root/autodl-fs/data/fruit20/dataset/val", help='test images path')
parser.add_argument('--save_dir', default="output_dir", help='path to save')
parser.add_argument('--model_name', default="clip", help='model name: renet50 or renet152 or clip')
parser.add_argument('--feature_dict_file', default="corpus_feature_dict.npy",
help='filename where to save image representations')
parser.add_argument('--topk', type=int, default=7, help='k most similar images')
parser.add_argument('--mode', default="extract", help='extract or predict')
args = parser.parse_args()
return args
def extract_feature_by_clip(model, preprocess, image_file_path):
# 1、读取图像并进行预处理
image = Image.open(image_file_path)
inputs = preprocess(images=image, return_tensors='pt')
# 2、将输入传递给模型获取图像的特征
with torch.no_grad():
features = model.get_image_features(**inputs)
vec_features = features.squeeze().cpu().numpy()
return vec_features
def extract_feature_single(args, model, image_file_path):
image_rgb = Image.open(image_file_path).convert('RGB')
image = image_rgb.resize(args.input_size, args.input_size)
image = torchvision.transforms.ToTensor()(image)
image = torchvision.transforms.Normalize(mean=[0.47, 0.43, 0.32], std=[0.37, 0.36, 0.34])(image).unsqueeze(0)
with torch.no_grad():
features = model.forward_features(image)
vec_features = model.global_pool(features)
vec_features = vec_features.squeeze().cpu().numpy()
return vec_features
def extract_features(args, model, img_path, preprocess):
all_vectors = {}
train_files_path = glob.glob(os.path.join(img_path, "*.jpg"))
train_files_path += glob.glob(os.path.join(img_path, "*.png"))
for image_file_path in tqdm(train_files_path):
if args.model_name == "clip":
# 1、通过clip提取特征
all_vectors[image_file_path] = extract_feature_by_clip(model, preprocess, image_file_path)
else:
# 2、通过ResNet提取特征
all_vectors[image_file_path] = extract_feature_single(args, model, image_file_path)
# 将提取出的图像特征保存起来
os.makedirs(f"./{args.save_dir}/{args.model_name}", exist_ok=True)
np.save(f"{args.save_dir}/{args.model_name}/{args.feature_dict_file}", all_vectors)
return all_vectors
def get_similar_matrix(vectors_dict):
"""计算给定向量字典中各个向量之间的相似度,其中相似度的计算采用了向量之间的余弦相似度"""
# 1、每行代表一个向量
v = np.array(list(vectors_dict.values())) # [NUM, H]
# 2、计算相似度矩阵的分子部分
numerator = np.matmul(v, v.T) # [NUM, NUM]
# 3、计算相似度矩阵的分母部分(计算每对向量之间的范数乘积)
denominator = np.matmul(
np.linalg.norm(v, axis=1, keepdims=True),
np.linalg.norm(v, axis=1, keepdims=True).T
) # [NUM, NUM]
# 4、到相似度矩阵 sim,其中 sim[i, j] 表示向量 i 和向量 j 之间的相似度
sim = numerator / denominator
keys = list(vectors_dict.keys())
return sim, keys
if __name__ == '__main__':
args = get_args()
model = None
processor = None
if args.model_name != "clip":
# 利用renet50 or renet152作为特征抽取器
model = timm.create_model(args.model_name, pretrained=True)
model.eval()
else:
# 加载openai clip预训练模型
model_path = '/root/autodl-fs/models/clip-vit-base-patch32'
model = CLIPModel.from_pretrained(model_path)
processor = CLIPProcessor.from_pretrained(model_path)
if args.mode == "extract":
# 1、利用预训练模型抽取特征,并保存下来
print(f'use pretrained model {args.model_name} to extract features')
extract_features(args, model, img_path=args.dataset_dir, preprocess=processor)
else:
# 2、以图搜图
print(f'use pretrained model {args.model_name} to search {args.topk} similar images from corpus')
test_images = glob.glob(os.path.join(args.test_image_dir, "*.jpg"))
test_images += glob.glob(os.path.join(args.test_image_dir, "*.png"))
# 2-1加载图像向量
all_vectors = np.load(f"./{args.save_dir}/{args.model_name}/{args.feature_dict_file}", allow_pickle=True)
all_vectors = all_vectors.item()
# 2-2 提取搜索图像的图像特征
for image_file_path in tqdm(test_images):
print(f'reading {image_file_path} ......')
if args.model_name == "clip":
all_vectors[image_file_path] = extract_feature_by_clip(model, processor, image_file_path)
else:
all_vectors[image_file_path] = extract_feature_single(args, model, image_file_path)
# 2-3 获取相似度矩阵及相似图片路径
sims, keys = get_similar_matrix(all_vectors)
# 2-4 获取topk个相似图片
result = {}
for image_file in tqdm(test_images):
index = keys.index(image_file)
sim_vec = sims[index]
indexs = np.argsort(sim_vec)[::-1][1:args.topk]
sim_imgs, sim_socres = [], []
for ind in indexs:
sim_imgs.append(keys[ind])
sim_socres.append(sim_vec[ind])
result[image_file] = (sim_imgs, sim_socres)
print(result)