在向量数据中存储多模态数据,通过文字搜索图片,Chroma 支持文字和图片,通过 OpenClip 模型对文字以及图片做 Embedding。本文通过 Chroma 实现一个文字搜索图片的功能。
OpenClip
CLIP(Contrastive Language-Image Pretraining,对比语言-图像预训练)是由OpenAI开发的一种模型,它结合了自然语言处理(NLP)和计算机视觉(CV)来理解和关联文本和视觉数据。CLIP旨在从大量的互联网数据中学习,并能够执行各种任务,例如零样本图像分类、图像到文本搜索和文本到图像搜索,而无需特定任务的数据集。CLIP 有以下特性
-
对比学习:CLIP使用对比学习方法,模型通过区分匹配和不匹配的图像和文本对进行训练。这意味着它学习将图像与其对应的文本描述对齐,并区分不相关的对。
-
双分支架构:CLIP包含两个分支:一个用于处理图像,另一个用于处理文本。这些分支通常基于深度学习架构,例如用于图像的Vision Transformers(ViT)或ResNet,用于文本的基于Transformer的模型(如GPT)。
-
联合嵌入空间:模型将图像和文本投影到共享的嵌入空间中。在训练过程中,它最大化匹配图像-文本对的嵌入相似性,最小化不匹配对的嵌入相似性。
-
零样本学习:CLIP的一个重要优势是其零样本学习能力。这意味着它可以通过利用类别的文本描述来对在训练中未见过的类别的图像进行分类。
OpenClip 是 Open AI CLIP 的开源实现。
数据准备
本文使用魔搭的数据集 tany0699/dailytags
import os
from datasets import load_dataset
from matplotlib import pyplot as plt
import matplotlib.image as mpimg
from PIL import Image
IMAGE_FOLDER = "images"
N_IMAGES = 20
# For plotting
plot_cols = 5
plot_rows = N_IMAGES // plot_cols
fig, axes = plt.subplots(plot_rows, plot_cols, figsize=(plot_rows*2, plot_cols*2))
axes = axes.flatten()
# Write the images to a folder
dataset_iter = iter(dataset)
os.makedirs(IMAGE_FOLDER, exist_ok=True)
for i in range(N_IMAGES):
image = Image.open(next(dataset_iter)['image:FILE'])
axes[i].imshow(image)
axes[i].axis("off")
image.save(f"images/{i}.jpg")
plt.tight_layout()
plt.show()
安装依赖
安装 Chroma 和 OpenClip
!pip install chromadb
!pip install open_clip_torch
搜索图片
- 启动 Chroma
import chromadb
client = chromadb.Client()
- 初始化 Embedding Model
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
from chromadb.utils.data_loaders import ImageLoader
embedding_function = OpenCLIPEmbeddingFunction()
image_loader = ImageLoader()
- 创建 Chroma 集合
collection = client.create_collection(
name='multimodal_collection',
embedding_function=embedding_function,
data_loader=image_loader)
- 初始化数据
# Get the uris to the images
image_uris = sorted([os.path.join(IMAGE_FOLDER, image_name) for image_name in os.listdir(IMAGE_FOLDER)])
ids = [str(i) for i in range(len(image_uris))]
collection.add(ids=ids, uris=image_uris)
- 查询
retrieved = collection.query(query_texts=["bird"], include=['data', 'distances'], n_results=3)
for img in retrieved['data'][0]:
print(retrieved['distances'])
plt.imshow(img)
plt.axis("off")
plt.show()
总结
Chroma 多模态测试下来,效果还是不错,但是目前只支持英文。