MM-LLM:使用Llava类构建图文多模态大模型实践

在这里插入图片描述
多模态大模型的结构如上,llava是用两层MLP作为连接器。该模式也是后续很多工作的基础。

本文主要参考了https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/train_llava的工作,最初是在b站看到的,讲解的很细致。

基础模型

大语言模型:Qwen2-1.5B-Instruct
视觉模型:clip-vit-large-patch14-336
连接器:MLP
框架:llava模型

1.LLM的处理

下载模型权重到本地后,修改Qwen2-1.5B-Instruct/tokenizer_config.json的added_tokens_decoder的值,添加

"151646": {
      "content": "<image>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }

additional_special_tokens添加 "<image>"

2.初始化llava模型

# 模型权重路径
modify_qwen_tokenizer_dir = "autodl-tmp/Qwen2-1.5B-Instruct"
clip_model_name_or_path = (
    "autodl-tmp/clip-vit-large-patch14-336"
)

# 加载qwen2
qwen_tokenizer = AutoTokenizer.from_pretrained(modify_qwen_tokenizer_dir)
qwen_model = AutoModelForCausalLM.from_pretrained(
                                            modify_qwen_tokenizer_dir, 
                                            device_map='cuda:0', 
                                            torch_dtype=torch.bfloat16
                                            )


# 加载clip
clip_model = AutoModel.from_pretrained(clip_model_name_or_path, device_map="cuda:0")
processor = AutoProcessor.from_pretrained(clip_model_name_or_path)

# 将clip模型和llm_model模型的config拿出来,初始化一个llava model
# Initializing a CLIP-vision config
vision_config = clip_model.vision_model.config
# Initializing a Llama config
text_config = qwen_model.config
# Initializing a Llava llava-1.5-7b style configuration
configuration = LlavaConfig(vision_config, text_config)
# Initializing a model from the llava-1.5-7b style configuration
model = LlavaForConditionalGeneration(configuration)

输出:

LlavaForConditionalGeneration(
  (vision_tower): CLIPVisionModel(
    (vision_model): CLIPVisionTransformer(
      (embeddings): CLIPVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
        (position_embedding): Embedding(577, 1024)
      )
      (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (encoder): CLIPEncoder(
        (layers): ModuleList(
          (0-23): 24 x CLIPEncoderLayer(
            (self_attn): CLIPAttention(
              (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (mlp): CLIPMLP(
              (activation_fn): QuickGELUActivation()
              (fc1): Linear(in_features=1024, out_features=4096, bias=True)
              (fc2): Linear(in_features=4096, out_features=1024, bias=True)
            )
            (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          )
        )
      )
      (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    )
  )
  (multi_modal_projector): LlavaMultiModalProjector(
    (linear_1): Linear(in_features=1024, out_features=1536, bias=True)
    (act): GELUActivation()
    (linear_2): Linear(in_features=1536, out_features=1536, bias=True)
  )
  (language_model): Qwen2ForCausalLM(
    (model): Qwen2Model(
      (embed_tokens): Embedding(151936, 1536)
      (layers): ModuleList(
        (0-27): 28 x Qwen2DecoderLayer(
          (self_attn): Qwen2SdpaAttention(
            (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
            (k_proj): Linear(in_features=1536, out_features=256, bias=True)
            (v_proj): Linear(in_features=1536, out_features=256, bias=True)
            (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
            (rotary_emb): Qwen2RotaryEmbedding()
          )
          (mlp): Qwen2MLP(
            (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
            (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
            (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): Qwen2RMSNorm()
          (post_attention_layernorm): Qwen2RMSNorm()
        )
      )
      (norm): Qwen2RMSNorm()
    )
    (lm_head): Linear(in_features=1536, out_features=151936, bias=False)
  )
)

这样得到了llava模型的结构,但是旧有的权重参数还没迁移过来,要将其移动到新model里。

# 权重复制
model.vision_tower.vision_model = clip_model.vision_model
model.language_model = qwen_model

然后保存到本地,注意要将autodl-tmp/processor的preprocessor_config.json复制到autodl-tmp/vlm_1

# 保存模型
model.save_pretrained("autodl-tmp/vlm_1")
qwen_tokenizer.save_pretrained("autodl-tmp/vlm_1")
processor.save_pretrained("autodl-tmp/processor")

3.数据集加载代码

采用该数据集:https://huggingface.co/datasets/OpenGVLab/ShareGPT-4o

主要代码:

class LlavaDataset(Dataset):
    def __init__(self, dataset_dir: str) -> None:
        super().__init__()

        self.chat_data, self.image_dir = self.build_dataset(dataset_dir)

    def build_dataset(self, data_dir: str) -> Tuple[List[Dict], Path]:
        # 得到对话文件和图像文件的路径
        data_dir = Path(data_dir) # 父文件夹路径
        chat_file = data_dir.joinpath("final_data.jsonl") # 对话文件
        image_dir = data_dir.joinpath("image") # 图像文件夹
        # 读取为记录,转为dict
        chat_data = pd.read_json(chat_file, lines=True).to_dict(orient="records")

        return chat_data, image_dir

    def __len__(self):
        return len(self.chat_data)

    def __getitem__(self, index) -> Tuple[str, str, Path]:
        # 根据索引定位到记录
        cur_data = self.chat_data[index] # 定位

        conversations = cur_data.get("conversations") # 字典格式获取到对话记录

        human_input = conversations[0].get("value") # 查询
        chatbot_output = conversations[1].get("value") # 回复
        image_path = self.image_dir.joinpath(cur_data.get("image")) # 图片的路径,由图片文件夹+图片名构成

        return human_input, chatbot_output, image_path

4.训练

使用deepseed训练,主要代码

def train():

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model, processor = load_model_processor(model_args)
    data_collator = TrainLLavaModelCollator(processor, -100)
    train_dataset = load_dataset(data_args)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=None,
        data_collator=data_collator,
    )
    trainer.train()
    trainer.save_state()
    trainer.save_model(output_dir=training_args.output_dir)

5.推理

没有训练的模型进行推理的结果:

很抱歉,我无法看到或描述图片,因为我是一个文本生成模型,无法处理图像。如果您需要帮助,可以提供文字描述,我会尽力帮助您。

训练后的模型推理:

The image depicts a scene of a person sitting on a chair with their
legs crossed. The person is wearing a white shirt and dark blue jeans.
The person’s hair is styled in a messy, tousled manner, which adds to
the casual and relaxed atmosphere of the image. The person’s eyes are
closed, and they appear to be in a state of deep thought or
contemplation.

In the background, there is a small, white, rectangular object that
appears to be a piece of paper or a piece of writing. The object is
positioned in a manner that suggests it might be part of a document or
a note. The background is a light beige color, which contrasts with
the person’s clothing and the white object.

The chair is a wooden chair with a simple design, featuring a single
armrest and a backrest. The chair is positioned on a dark wooden
floor, which adds to the overall casual and comfortable feel of the
scene. The floor is also light beige, which complements the background
and the person’s clothing.

The lighting in the image is soft and diffused, giving the scene a
warm and inviting atmosphere. The person’s posture suggests they are
in a relaxed position, possibly after a long day or a moment of
reflection.

In summary, the image captures a person sitting on a chair with their
legs crossed, wearing casual clothing, and in a relaxed position. The
background includes a small white object, and the lighting is soft and
diffused, creating a warm and inviting atmosphere.

我仅仅训练了三轮,使用了不到300条数据。虽然结果不是很好,但是可以看出来是有成效的。
在这里插入图片描述

在我查找的多模态大模型实现中性价比是最高的,不用重写LLM的forward函数什么的。

相关代码放在https://github.com/stay-leave/enhance_llm。

参考:
https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/train_llava
https://github.com/OpenGVLab/InternVL/blob/main/internvl_chat
https://github.com/AviSoori1x/seemore
https://github.com/alexander-moore/vlm
https://github.com/WatchTower-Liu/VLM-learning

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/765938.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

夏日编码狂欢:GitCode x DevUI挑战赛,点燃你的创造力

在这个创新驱动的时代&#xff0c;开源技术已成为推动全球软件开发进步的强大引擎&#xff0c;技术的边界正被全球开发者的集体智慧不断拓展。 在这个充满活力的夏日&#xff0c;开源社区迎来了一场全新的活动——由GitCode携手华为云DevUI精心打造的开源共创挑战赛。这不仅是…

【瑞吉外卖 | day01】项目介绍+后台登录退出功能

文章目录 瑞吉外卖 — day011. 所需知识2. 软件开发整体介绍2.1 软件开发流程2.2 角色分工2.3 软件环境 3. 瑞吉外卖项目介绍3.1 项目介绍3.2 产品原型展示3.3 技术选型3.4 功能架构3.5 角色 4. 开发环境搭建4.1 数据库环境搭建4.2 Maven项目构建 5. 后台系统登录功能5.1 创建需…

The Sandbox 通过创作者挑战推动社区参与

游戏开发者并不是每天都有机会让自己的作品赢得大奖。但在 The Sandbox&#xff0c;这已经成为一种趋势&#xff01;首届 "创作者挑战 "让顶尖创作者将 150 万 SAND 捧回家。现在&#xff0c;我们要带着另一个巨额奖池回来了&#xff01; 关于首届创作者挑战&#xf…

代理IP的10大误区:区分事实与虚构

在当今的数字时代&#xff0c;代理已成为在线环境不可或缺的一部分。它们的用途广泛&#xff0c;从增强在线隐私到绕过地理限制。然而&#xff0c;尽管代理无处不在&#xff0c;但仍存在许多围绕代理的误解。在本博客中&#xff0c;我们将探讨和消除一些最常见的代理误解&#…

昇思25天学习打卡营第7天|函数式自动微分

函数式自动微分 概念函数与计算图微分函数与梯度计算自定义神经网络梯度计算参考 概念 神经网络的训练主要使用反向传播算法&#xff0c;模型预测值&#xff08;logits&#xff09;与正确标签&#xff08;label&#xff09;送入损失函数&#xff08;loss function&#xff09;…

这几类热销品被Ozon限制销售,ozon还有什么产品好卖?

OZON是俄罗斯最大的B2C电商平台&#xff0c;占据俄罗斯电商市场份额的62%&#xff0c;日均订单量高达37万单&#xff0c;拥有超过1600万的活跃用户。ozon平台对中国卖家招商的产品品类涵盖了多个领域&#xff0c;但近日Ozon官方发布将对这三大类目实行销售限制&#xff0c;一起…

DNS访问百度

DNS&#xff0c;英文全称是 domain name system&#xff0c;域名解析系统&#xff0c;它的作用也很明确&#xff0c;就是域名和 IP 相互映射。 假设你要查询 baidu.com 的 IP 地址: 首先会查找浏览器的缓存,看看是否能找到 baidu.com 对应的IP地址&#xff0c;找到就直接返回&…

【热门会议|见刊快】2024年管理创新与教育国际会议 (ICMIE 2024)

2024年管理创新与教育国际会议 (ICMIE 2024) 2024 International Conference on Management Innovation and Education 【重要信息】 大会地点&#xff1a;洛阳 大会官网&#xff1a;http://www.icicmie.com 投稿邮箱&#xff1a;icicpsssub-conf.com 【注意&#xff1a;稿将稿…

工厂方法模式:概念与应用

目录 工厂方法模式工厂方法模式结构工厂方法适合的应用场景工厂方法模式的优缺点练手题目题目描述输入描述输出描述**提示信息**解题&#xff1a; 工厂方法模式 工厂方法模式是一种创建型设计模式&#xff0c; 其在父类中提供一个创建对象的方法&#xff0c; 允许子类决定实例…

苹果电脑废纸篓数据被清空了,有什么方法可以恢复吗?

使用电脑的用户都知道&#xff0c;被删除的文件一般都会经过回收站&#xff0c;想要恢复它直接点击“还原”就可以恢复到原始位置。mac电脑同理也是这样&#xff0c;但是“回收站”在mac电脑显示为“废纸篓”。 苹果电脑废纸篓数据被清空了&#xff0c;有什么方法可以恢复吗&am…

页面速度是如何影响SEO的?

搜索引擎使用复杂的算法来衡量您网站的重要方面&#xff0c;以决定是否向您发送流量。 搜索引擎使用您网站的小元素来确定您网站的质量和真实性&#xff0c;然后此操作将转化为您的网页在搜索引擎结果页面 中出现的位置。提高您在 SERP 中的排名的过程称为搜索引擎优化 (SEO)。…

在 Mac 上使用 本地 LLM 文本终结

我们可使用本地大型语言模型&#xff0c;如Mistral、Llama等&#xff0c;来给文本做总结&#xff0c;相比在线的 Kimi &#xff0c;ChatGPT&#xff0c; 我们不用担心数据泄露&#xff0c;因为整个操作都是在本地电脑完成的。 我们用 ollama 举例 首先安装 ollama https://ol…

从零搭建Prometheus到Grafana告警推送

目录 一、Prometheus源码安装和动态更新配置 二、Prometheus操作面板和常见配置 三、Prometheus常用监控组件exporter配置 3.1 exporter是什么 3.2 有哪些exporter 3.3 exporter怎么用 3.4 实战 node_exporter ​3.5 其它exporter都怎么用 四、Promethus整合新版Sprin…

数据结构常见图算法

深度优先搜索 时间复杂度 领接矩阵表示 O( n2) 领接表表示 O(n+e) 空间复杂度 O(e) DFS与回溯法类似,一条路径走到底后需要返回上一步,搜索第二条路径。在树的遍历中,首先一直访问到最深的节点,然后回溯到它的父节点,遍历另一条路径,直到遍历完所有节点…

怎样在《语文世界》期刊上发表论文?

怎样在《语文世界》期刊上发表论文&#xff1f; 《语文世界》知网国家级 1.5-2版 2500字符左右 正常收25年4-6月版面 可加急24年内&#xff08;初中&#xff0c;高中&#xff0c;中职&#xff0c;高职&#xff0c;大学均可&#xff0c;操作周期2个月左右&#xff09; 《语文世…

【CH32V305FBP6】USBD HS 虚拟串口分析

文章目录 前言分析端点 0USBHS_UIS_TOKEN_OUT 端点 2USBHS_UIS_TOKEN_OUTUSBHS_UIS_TOKEN_IN 前言 虚拟串口&#xff0c;端口 3 单向上报&#xff0c;端口 2 双向收发。 分析 端点 0 USBHS_UIS_TOKEN_OUT 设置串口参数&#xff1a; 判断 USBHS_SetupReqCode CDC_SET_LIN…

解锁应用商店新玩法:Xinstall渠道包,让你的App推广效率飙升

在移动应用竞争日益激烈的今天&#xff0c;如何在众多应用商店中脱颖而出&#xff0c;实现精准推广与高效获客&#xff0c;成为每位App开发者与广告主的共同追求。幸运的是&#xff0c;Xinstall作为一款一站式App全渠道统计服务商&#xff0c;以其专业的渠道包解决方案&#xf…

Yi-1.5 9B Chat 上线Amazon SageMaker JumpStart

你是否对简单的API调用大模型感到不满足&#xff1f;是否因为无法亲自部署属于自己的大模型而烦恼&#xff1f; 好消息来了&#xff0c;Amazon SageMaker JumpStart 初体验 CloudLab实验上线啦&#xff01; 本实验将以零一万物最新发布的中文基础模型 Yi-1.5 9B Chat 为例&am…

如何指定Microsoft Print To PDF的输出路径

在上一篇文章中&#xff0c;介绍了三种将文件转换为PDF的方式。默认情况下&#xff0c;在Microsoft Print To PDF的首选项里&#xff0c;是看不到输出路径的设置的。 需要一点小小的手段。 运行输入 control 打开控制面板&#xff0c;选择硬件和声音下的查看设备和打印机 找到…

在卷积神经网络(CNN)中为什么可以使用多个较小的卷积核替代一个较大的卷积核,以达到相同的感受野

在卷积神经网络&#xff08;CNN&#xff09;中为什么可以使用多个较小的卷积核替代一个较大的卷积核&#xff0c;以达到相同的感受野 flyfish 在卷积神经网络&#xff08;CNN&#xff09;中&#xff0c;可以使用多个较小的卷积核替代一个较大的卷积核&#xff0c;以达到相同的…