用ResNet50+Qwen2-VL-2B-Instruct+LoRA模仿Diffusion-VLA的论文思路,在3090显卡上训练和测试成功

想一步步的实现Diffusion VLA论文的思路,不过论文的图像的输入用DINOv2进行特征提取的,我先把这个部分换成ResNet50。

老铁们,直接上代码:

from PIL import Image
import torch
import torchvision.models as models
from torch import nn
from datasets import Dataset
from modelscope import snapshot_download, AutoTokenizer
from swanlab.integration.transformers import SwanLabCallback
from qwen_vl_utils import process_vision_info
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
from transformers import (
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
)
import swanlab
import json
from torchvision import transforms
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.models as models

class CustomResNet(nn.Module):
    def __init__(self, output_size=(256, 1176)):
        super(CustomResNet, self).__init__()
        
        # 预训练的 ResNet 模型
        resnet = models.resnet50(pretrained=True)
        
        # 去掉 ResNet 的最后全连接层和池化层
        self.features = nn.Sequential(*list(resnet.children())[:-2])  # 去掉最后的FC层和AvgPool层
        
        # 自定义的卷积层,调整步幅和padding来控制尺寸
        self.conv1 = nn.Conv2d(2048, 2048, kernel_size=3, stride=1, padding=1)  # 保持大小
        self.conv2 = nn.Conv2d(2048, 2048, kernel_size=3, stride=1, padding=1)  # 保持大小
        self.conv3 = nn.Conv2d(2048, 2048, kernel_size=3, stride=1, padding=1)  # 保持大小
        
        # 上采样层,用于增加特征图的尺寸
        self.upconv1 = nn.ConvTranspose2d(2048, 2048, kernel_size=4, stride=4, padding=0)  # 上采样
        self.upconv2 = nn.ConvTranspose2d(2048, 2048, kernel_size=4, stride=4, padding=0)  # 上采样
        
        # 最终卷积层将特征图变为单通道输出(灰度图)
        self.final_conv = nn.Conv2d(2048, 1, kernel_size=1)  # 输出单通道

    def forward(self, x):
        # 获取ResNet的特征图
        x = self.features(x)
        
        # 经过卷积层
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        
        # 上采样阶段:增加特征图的尺寸
        x = self.upconv1(x)  # 上采样1
        x = self.upconv2(x)  # 上采样2
        
        # 使用插值进行微调输出尺寸
        x = F.interpolate(x, size=(256, 1176), mode='bilinear', align_corners=False)
        
        # 通过最后的卷积层输出(单通道)
        x = self.final_conv(x)  # 通过最后的卷积层输出
        
        return x

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

# 创建模型并移动到设备上
model_ResNet = CustomResNet(output_size=(256, 1176)).to(device)

# 定义图像预处理过程
image_transform = transforms.Compose([
    transforms.Resize((800, 800)),  # 确保图像大小一致(通常为224x224)
    transforms.ToTensor(),  # 转换为Tensor并标准化
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 标准化
])

def extract_resnet_features(image_path):
    """
    使用ResNet提取图像特征
    """
    image = Image.open(image_path).convert("RGB")  # 加载图像并转换为RGB
    image_tensor = image_transform(image).unsqueeze(0).to('cuda')  # 添加batch维度并转换为cuda Tensor
    # features = resnet_extractor(image_tensor)  # 从ResNet提取特征    
    features = model_ResNet(image_tensor)

    return features

def process_func(example):
    """
    将数据集进行预处理,加入ResNet特征提取
    """
    MAX_LENGTH = 8192
    input_ids, attention_mask, labels = [], [], []
    conversation = example["conversations"]
    input_content = conversation[0]["value"]
    output_content = conversation[1]["value"]
    file_path = input_content.split("<|vision_start|>")[1].split("<|vision_end|>")[0]  # 获取图像路径
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": f"{file_path}",
                    "resized_height": 224,  # 确保图像尺寸为224x224
                    "resized_width": 224,
                },
                {"type": "text", "text": "COCO Yes:"},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )  # 获取文本
    image_inputs, video_inputs = process_vision_info(messages)  # 获取数据数据(预处理过)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )

    # print("inputs['pixel_values'] shape: ", inputs['pixel_values'].shape)

    # 提取图像特征
    image_tensor = extract_resnet_features(file_path)  # 从图像路径提取特征
    # print("image_tensor shape: ", image_tensor.shape)
    inputs['pixel_values'] = image_tensor[0,0,:,:]  # 替换图像特征为ResNet特征

    inputs = {key: value.tolist() for key, value in inputs.items()}  # tensor -> list,为了方便拼接
    instruction = inputs

    response = tokenizer(f"{output_content}", add_special_tokens=False)


    input_ids = (
            instruction["input_ids"][0] + response["input_ids"] + [tokenizer.pad_token_id]
    )

    attention_mask = instruction["attention_mask"][0] + response["attention_mask"] + [1]
    labels = (
            [-100] * len(instruction["input_ids"][0])
            + response["input_ids"]
            + [tokenizer.pad_token_id]
    )
    if len(input_ids) > MAX_LENGTH:  # 做一个截断
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]

    input_ids = torch.tensor(input_ids)
    attention_mask = torch.tensor(attention_mask)
    labels = torch.tensor(labels)
    inputs['pixel_values'] = torch.tensor(inputs['pixel_values'])
    inputs['image_grid_thw'] = torch.tensor(inputs['image_grid_thw']).squeeze(0)  # 由(1,h,w)变换为(h,w)
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels,
            "pixel_values": inputs['pixel_values'], "image_grid_thw": inputs['image_grid_thw']}


def predict(messages, model):
    # 准备推理
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")

    # 生成输出
    generated_ids = model.generate(**inputs, max_new_tokens=128)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    
    return output_text[0]


# 在modelscope上下载Qwen2-VL模型到本地目录下
model_dir = snapshot_download("Qwen/Qwen2-VL-2B-Instruct", cache_dir="./", revision="master")

# 使用Transformers加载模型权重
tokenizer = AutoTokenizer.from_pretrained("./Qwen/Qwen2-VL-2B-Instruct/", use_fast=False, trust_remote_code=True)
processor = AutoProcessor.from_pretrained("./Qwen/Qwen2-VL-2B-Instruct")

# 加载模型
model = Qwen2VLForConditionalGeneration.from_pretrained("./Qwen/Qwen2-VL-2B-Instruct/", device_map="cuda", torch_dtype=torch.bfloat16, trust_remote_code=True,)
model.enable_input_require_grads()  # 开启梯度检查点时,要执行该方法
model.config.use_cache = False

# 处理数据集:读取json文件
# 拆分成训练集和测试集,保存为data_vl_train.json和data_vl_test.json
train_json_path = "data_vl.json"
with open(train_json_path, 'r') as f:
    data = json.load(f)
    train_data = data[:-4]
    test_data = data[-4:]

with open("data_vl_train.json", "w") as f:
    json.dump(train_data, f)

with open("data_vl_test.json", "w") as f:
    json.dump(test_data, f)

train_ds = Dataset.from_json("data_vl_train.json")
train_dataset = train_ds.map(process_func)

# 配置LoRA
config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    inference_mode=False,  # 训练模式
    r=4, #64,  # Lora 秩
    lora_alpha= 1, #16,  # Lora alaph,具体作用参见 Lora 原理
    lora_dropout=0.05,  # Dropout 比例
    bias="none",
)

# 获取LoRA模型
peft_model = get_peft_model(model, config)

# 配置训练参数
args = TrainingArguments(
    output_dir="./output/Qwen2-VL-2B",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    logging_steps=10,
    logging_first_step=5,
    num_train_epochs=2,
    save_steps=100,
    learning_rate=1e-4,
    save_on_each_node=True,
    gradient_checkpointing=True,
    report_to="none",
)

# 配置Trainer
trainer = Trainer(
    model=peft_model,
    args=args,
    train_dataset=train_dataset,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)

# 开启模型训练
trainer.train()


# ====================测试模式===================
# 配置测试参数
val_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    inference_mode=True,  # 训练模式
    r=4,#64,  # Lora 秩
    lora_alpha=1,#16,  # Lora alaph,具体作用参见 Lora 原理
    lora_dropout=0.05,  # Dropout 比例
    bias="none",
)

# 获取测试模型
val_peft_model = PeftModel.from_pretrained(model, model_id="./output/Qwen2-VL-2B/checkpoint-992", config=val_config)

# 读取测试数据
with open("data_vl_test.json", "r") as f:
    test_dataset = json.load(f)

test_image_list = []
for item in test_dataset:
    input_image_prompt = item["conversations"][0]["value"]
    # 去掉前后的<|vision_start|>和<|vision_end|>
    origin_image_path = input_image_prompt.split("<|vision_start|>")[1].split("<|vision_end|>")[0]
    
    messages = [{
        "role": "user", 
        "content": [
            {
            "type": "image", 
            "image": origin_image_path
            },
            {
            "type": "text",
            "text": "COCO Yes:"
            }
        ]}]
    
    response = predict(messages, val_peft_model)
    messages.append({"role": "assistant", "content": f"{response}"})
    print(messages[-1])

    test_image_list.append(swanlab.Image(origin_image_path, caption=response))

我在3090显卡(24G显存)运行的结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/949878.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Spring Boot 项目自定义加解密实现配置文件的加密

在Spring Boot项目中&#xff0c; 可以结合Jasypt 快速实现对配置文件中的部分属性进行加密。 完整的介绍参照&#xff1a; Spring Boot Jasypt 实现application.yml 属性加密的快速示例 但是作为一个技术强迫症&#xff0c;总是想着从底层开始实现属性的加解密&#xff0c;…

A/B实验之置信检验(一):如何避免误判 (I类) 和漏报 (II类)

假设检验的依据&#xff1a;如何避免误判和漏报 A/B实验系列相关文章&#xff08;置顶&#xff09; 1.A/B实验之置信检验&#xff08;一&#xff09;&#xff1a;如何避免误判和漏报 2.A/B实验之置信检验&#xff08;二&#xff09;&#xff1a;置信检验精要 引言 在数据驱动…

每日一题:链表中环的入口结点

文章目录 判断链表环的入口节点描述数据范围&#xff1a;复杂度要求&#xff1a;输入输出 示例代码实现思路解析注意事项&#xff1a; 判断链表环的入口节点 描述 给定一个链表&#xff0c;判断该链表是否存在环。如果存在环&#xff0c;返回环的入口节点&#xff1b;如果不存…

深度学习blog-Meanshift均值漂移算法-最大熵模型

均值漂移&#xff08;Mean Shift&#xff09;是一种无监督的聚类算法&#xff0c;广泛应用于数据挖掘和计算机视觉任务。它通过移动样本点到其近邻的均值位置来寻找数据的高密度区域&#xff0c;最终形成聚类。 均值漂移算法原理 均值漂移算法的核心思想是通过滑动窗口&#…

51c自动驾驶~合集45

我自己的原文哦~ https://blog.51cto.com/whaosoft/13020031 #运动控制和规划控制需要掌握的技术栈~ 各大垃圾家电造车厂又要开始了~~~​ 1、ROS的通信方式 李是Lyapunov的李&#xff1a;谈谈ROS的通信机制 话题通信和服务通信&#xff0c;其中话题通信是通过发布和订阅…

Python基于jieba和wordcloud绘制词云图

【Cesium】自定义材质&#xff0c;添加带有方向的滚动路线 &#x1f356; 前言&#x1f3b6;一、实现过程✨二、代码展示&#x1f3c0;三、运行结果&#x1f3c6;四、知识点提示 &#x1f356; 前言 Python基于jieba和wordcloud绘制词云图 &#x1f3b6;一、实现过程 读取文本…

计算机网络与服务器

目录 架构体系及相关知识 三层架构&#xff1a; 四层架构&#xff1a; 常见的应用的模式&#xff1a; OSI模型 分层 数据链路层 TCP/IP模型 TCP和UDP都是传输层的协议 TCP三次握手、四次次分手 URL&HTTP协议详解 网址URL 结构化 报文行 报文头 空行 报文体…

Cursor实现go项目配置并实现仓库Gin项目运行

✅作者简介&#xff1a;大家好&#xff0c;我是 Meteors., 向往着更加简洁高效的代码写法与编程方式&#xff0c;持续分享Java技术内容。 &#x1f34e;个人主页&#xff1a;Meteors.的博客 &#x1f49e;当前专栏&#xff1a;知识备份 ✨特色专栏&#xff1a;知识分享 &#x…

141.环形链表 142.环形链表II

141.环形链表 & 142.环形链表II 141.环形链表 思路&#xff1a;快慢指针 or 哈希表 快慢指针代码&#xff1a; class Solution { public:bool hasCycle(ListNode *head) {if(headnullptr||head->nextnullptr)return false;ListNode *fasthead->next; //不能设置成…

信用租赁系统助力企业实现免押金租赁新模式

内容概要 在现代商业环境中&#xff0c;信用租赁正在迅速崛起。通过结合大数据与区块链技术&#xff0c;信用租赁系统彻底改变了传统的租赁流程。什么是信用租赁呢&#xff1f;简单说&#xff0c;就是不需要押金&#xff0c;你也能够租到你想要的物品&#xff0c;这对企业和消…

el-select下拉框在弹框里面错位

问题出现 Element Plus 是一个基于 Vue 3 的组件库&#xff0c;el-select 是其中一个用于选择器的组件。在 el-select 组件中&#xff0c;teleported 属性用于控制下拉菜单的渲染位置。 解决方法 teleported 属性「element-plus」 popper-append-to-body属性「element」 ‌…

IO进程day1

一、思维导图

力扣-21-合并两个有序链表

思路&#xff1a; 因为是升序的两个链表&#xff0c;我们可以进行数据域比大小&#xff0c;然后把p3(自己创建的)的指针域指向小的那个 注&#xff1a;一定要先判断两个指针为0的情况

人工智能的发展领域之GPU加速计算的应用概述、架构介绍与教学过程

文章目录 一、架构介绍GPU算力平台概述优势与特点 二、注册与登录账号注册流程GPU服务器类型配置选择指南内存和存储容量网络带宽CPU配置 三、创建实例实例创建步骤镜像选择与设置 四、连接实例SSH连接方法远程桌面配置 一、架构介绍 GPU算力平台概述 一个专注于GPU加速计算的…

QT实现 端口扫描暂停和继续功能 3

上篇QT给端口扫描工程增加线程2-CSDN博客 为按钮pushButton_Stop添加clicked事件&#xff0c;功能为暂停扫描&#xff0c;并在暂停后显示继续按钮&#xff0c;点击继续按钮之后继续扫描 1.更新UI 添加继续按钮 点击转到槽则会自动声明 2. 更新 MainWindow.h 需要新增的部分…

汽车微处理器安全机制以及测试介绍

本文介绍了三类汽车微处理器安全机制&#xff1a;硬件类、软件类和混合类&#xff0c;旨在提高系统的可靠性和安全性。硬件类安全机制包括逻辑内建自测试&#xff08;Logic-BIST&#xff09;、三重模块冗余&#xff08;TMR&#xff09;、内存内建自测试&#xff08;Memory-BIST…

【Azure Redis 缓存】Azure Redis 遇见的连接不上问题和数据丢失的情况解答

问题描述 PHP应用再连接Azure Redis服务时&#xff0c;出现Connection Timed out。当通过升级提高Azure Redis的性能时候&#xff0c;发现之前的数据丢失了。 image.png 问题解答 当Redis服务出现Timeout的情况时&#xff0c;可以从Redis服务的指标(Metrics)开始查看&#xff0…

python学习笔记—15—数据容器之列表

1. 数据容器 列表(list)、元组(tuple)、字符串(str)、集合(set)、字典(dict) 2. 列表 (1) 定义 tmp_list ["super", "carry", "doinb"] print(f"tmp_list {tmp_list}, tmp_list type is {type(tmp_list)}") tmp_list1 ["doi…

记录一次面试中被问到的问题 (HR面)

文章目录 一、你对公司的了解多少二、为什么对这个岗位感兴趣三、不能说的离职原因四、离职原因高情商回复五、你的核心优势是什么六、你认为你比其他面试候选人的优势是什么七、不要提及情感 一、你对公司的了解多少 准备要点&#xff1a; 在面试前&#xff0c;对公司进行充分…

VLMs之Agent之CogAgent:《CogAgent: A Visual Language Model for GUI Agents》翻译与解读

VLMs之Agent之CogAgent&#xff1a;《CogAgent: A Visual Language Model for GUI Agents》翻译与解读 导读&#xff1a;这篇论文介绍了CogAgent&#xff0c;一个专注于图形用户界面 (GUI) 理解和导航的视觉语言模型 (VLM)。这篇论文提出了一种新的视觉语言模型 CogAgent&#…