动手学习RAG: moka-ai/m3e 模型微调deepspeed与对比学习

  • 动手学习RAG: 向量模型
  • 动手学习RAG: moka-ai/m3e 模型微调deepspeed与对比学习
  • 动手学习RAG:迟交互模型colbert微调实践 bge-m3

1. 环境准备

pip install transformers
pip install open-retrievals
  • 注意安装时是pip install open-retrievals,但调用时只需要import retrievals
  • 欢迎关注最新的更新 https://github.com/LongxingTan/open-retrievals

2. 使用M3E模型

from retrievals import AutoModelForEmbedding

embedder = AutoModelForEmbedding.from_pretrained('moka-ai/m3e-base', pooling_method='mean')
embedder

请添加图片描述

sentences = [
    '* Moka 此文本嵌入模型由 MokaAI 训练并开源,训练脚本使用 uniem',
    '* Massive 此文本嵌入模型通过**千万级**的中文句对数据集进行训练',
    '* Mixed 此文本嵌入模型支持中英双语的同质文本相似度计算,异质文本检索等功能,未来还会支持代码检索,ALL in one'
]

embeddings = embedder.encode(sentences)

for sentence, embedding in zip(sentences, embeddings):
    print("Sentence:", sentence)
    print("Embedding:", embedding)
    print("")

请添加图片描述

3. deepspeed 微调m3e模型

数据仍然采用之前介绍的t2-ranking数据集

  • deepspeed配置保存为 ds_zero2_no_offload.json. 不过虽然设置了zero2,这里我只用了一张卡. 但deepspeed也很容易扩展到多卡,或多机多卡
    • 关于deepspeed的分布式设置,可参考Tranformer分布式特辑
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 100,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1e-10
    },

    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 1e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 1e8,
        "contiguous_gradients": true
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}

这里稍微修改了open-retrievals这里的代码,主要是修改了导入为包的导入,而不是相对引用。保存文件为embed.py

"""Embedding fine tune pipeline"""

import logging
import os
import pickle
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed

from retrievals import (
    EncodeCollator,
    EncodeDataset,
    PairCollator,
    RetrievalTrainDataset,
    TripletCollator,
)
from retrievals.losses import AutoLoss, InfoNCE, SimCSE, TripletLoss
from retrievals.models.embedding_auto import AutoModelForEmbedding
from retrievals.trainer import RetrievalTrainer

# os.environ["WANDB_LOG_MODEL"] = "false"
logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    causal_lm: bool = field(default=False, metadata={'help': "Whether the model is a causal lm or not"})
    lora_path: Optional[str] = field(default=None, metadata={'help': "Lora adapter save path"})


@dataclass
class DataArguments:
    data_name_or_path: str = field(default=None, metadata={"help": "Path to train data"})
    train_group_size: int = field(default=2)
    unfold_each_positive: bool = field(default=False)
    query_max_length: int = field(
        default=32,
        metadata={
            "help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    document_max_length: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    query_instruction: str = field(default=None, metadata={"help": "instruction for query"})
    document_instruction: str = field(default=None, metadata={"help": "instruction for document"})
    query_key: str = field(default=None)
    positive_key: str = field(default='positive')
    negative_key: str = field(default='negative')
    is_query: bool = field(default=False)
    encoding_save_file: str = field(default='embed.pkl')

    def __post_init__(self):
        # self.data_name_or_path = 'json'
        self.dataset_split = 'train'
        self.dataset_language = 'default'

        if self.data_name_or_path is not None:
            if not os.path.isfile(self.data_name_or_path) and not os.path.isdir(self.data_name_or_path):
                info = self.data_name_or_path.split('/')
                self.dataset_split = info[-1] if len(info) == 3 else 'train'
                self.data_name_or_path = "/".join(info[:-1]) if len(info) == 3 else '/'.join(info)
                self.dataset_language = 'default'
                if ':' in self.data_name_or_path:
                    self.data_name_or_path, self.dataset_language = self.data_name_or_path.split(':')


@dataclass
class RetrieverTrainingArguments(TrainingArguments):
    train_type: str = field(default='pairwise', metadata={'help': "train type of point, pair, or list"})
    negatives_cross_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
    temperature: Optional[float] = field(default=0.02)
    fix_position_embedding: bool = field(
        default=False, metadata={"help": "Freeze the parameters of position embeddings"}
    )
    pooling_method: str = field(default='cls', metadata={"help": "the pooling method, should be cls or mean"})
    normalized: bool = field(default=True)
    loss_fn: str = field(default='infonce')
    use_inbatch_negative: bool = field(default=True, metadata={"help": "use documents in the same batch as negatives"})
    remove_unused_columns: bool = field(default=False)
    use_lora: bool = field(default=False)
    use_bnb_config: bool = field(default=False)
    do_encode: bool = field(default=False, metadata={"help": "run the encoding loop"})
    report_to: Optional[List[str]] = field(
        default="none", metadata={"help": "The list of integrations to report the results and logs to."}
    )


def main():
    parser = HfArgumentParser((ModelArguments, DataArguments, RetrieverTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)
    logger.info("Model parameters %s", model_args)
    logger.info("Data parameters %s", data_args)

    set_seed(training_args.seed)

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=False,
    )
    if training_args.use_bnb_config:
        from transformers import BitsAndBytesConfig

        logger.info('Use quantization bnb config')
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    else:
        quantization_config = None

    if training_args.do_train:
        model = AutoModelForEmbedding.from_pretrained(
            model_name_or_path=model_args.model_name_or_path,
            pooling_method=training_args.pooling_method,
            use_lora=training_args.use_lora,
            quantization_config=quantization_config,
        )

        loss_fn = AutoLoss(
            loss_name=training_args.loss_fn,
            loss_kwargs={
                'use_inbatch_negative': training_args.use_inbatch_negative,
                'temperature': training_args.temperature,
            },
        )

        model = model.set_train_type(
            "pairwise",
            loss_fn=loss_fn,
        )

        train_dataset = RetrievalTrainDataset(
            args=data_args,
            tokenizer=tokenizer,
            positive_key=data_args.positive_key,
            negative_key=data_args.negative_key,
        )
        logger.info(f"Total training examples: {len(train_dataset)}")

        trainer = RetrievalTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            data_collator=TripletCollator(
                tokenizer,
                query_max_length=data_args.query_max_length,
                document_max_length=data_args.document_max_length,
                positive_key=data_args.positive_key,
                negative_key=data_args.negative_key,
            ),
        )

        Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)

        trainer.train()
        # trainer.save_model(training_args.output_dir)
        model.save_pretrained(training_args.output_dir)

        if trainer.is_world_process_zero():
            tokenizer.save_pretrained(training_args.output_dir)

    if training_args.do_encode:
        model = AutoModelForEmbedding.from_pretrained(
            model_name_or_path=model_args.model_name_or_path,
            pooling_method=training_args.pooling_method,
            use_lora=training_args.use_lora,
            quantization_config=quantization_config,
            lora_path=model_args.lora_path,
        )

        max_length = data_args.query_max_length if data_args.is_query else data_args.document_max_length
        logger.info(f'Encoding will be saved in {training_args.output_dir}')

        encode_dataset = EncodeDataset(args=data_args, tokenizer=tokenizer, max_length=max_length, text_key='text')
        logger.info(f"Number of train samples: {len(encode_dataset)}, max_length: {max_length}")

        encode_loader = DataLoader(
            encode_dataset,
            batch_size=training_args.per_device_eval_batch_size,
            collate_fn=EncodeCollator(tokenizer, max_length=max_length, padding='max_length'),
            shuffle=False,
            drop_last=False,
            num_workers=training_args.dataloader_num_workers,
        )

        embeddings = model.encode(encode_loader, show_progress_bar=True, convert_to_numpy=True)
        lookup_indices = list(range(len(encode_dataset)))

        with open(os.path.join(training_args.output_dir, data_args.encoding_save_file), 'wb') as f:
            pickle.dump((embeddings, lookup_indices), f)


if __name__ == "__main__":
    main()

  • 最终调用文件 shell run.sh
MODEL_NAME="moka-ai/m3e-base"

TRAIN_DATA="/root/kag101/src/open-retrievals/t2/t2_ranking.jsonl"
OUTPUT_DIR="/root/kag101/src/open-retrievals/t2/ft_out"


# loss_fn: infonce, simcse

deepspeed -m --include localhost:0 embed.py \
  --deepspeed ds_zero2_no_offload.json \
  --output_dir $OUTPUT_DIR \
  --overwrite_output_dir \
  --model_name_or_path $MODEL_NAME \
  --do_train \
  --data_name_or_path $TRAIN_DATA \
  --positive_key positive \
  --negative_key negative \
  --pooling_method mean \
  --loss_fn infonce \
  --use_lora False \
  --query_instruction "" \
  --document_instruction "" \
  --learning_rate 3e-5 \
  --fp16 \
  --num_train_epochs 5 \
  --per_device_train_batch_size 32 \
  --dataloader_drop_last True \
  --query_max_length 64 \
  --document_max_length 256 \
  --train_group_size 4 \
  --logging_steps 100 \
  --temperature 0.02 \
  --save_total_limit 1 \
  --use_inbatch_negative false

请添加图片描述

4. 测试

微调前性能 c-mteb t2-ranking score

请添加图片描述

微调后性能

请添加图片描述

采用infoNCE损失函数,没有加in-batch negative,而关注的是困难负样本,经过微调map从0.654提升至0.692,mrr从0.754提升至0.805

对比一下非deepspeed而是直接torchrun的微调

  • map略低,mrr略高。猜测是因为deepspeed中设置的一些auto会和直接跑并不完全一样
    请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/877631.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Mac】系统环境配置

常用工具 Navicat PJ版本:this Host切换器 SwitchHosts termius 一款好用的Linux服务器连接工具: termius 小飞机 dddd:🪜 Git mac安装git有好多种方式,自带的xcode或者通过Homebrew来安装,本文的…

人工智能开发实战matplotlib库应用基础

内容导读 matplotlib简介绘制直方图绘制撒点图 一、matplotlib简介 matplotlib是一个Python 2D绘图库,它以多种硬拷贝格式和跨平台的交互式环境生成高质量的图形。 matplotlib 尝试使容易的事情变得更容易,使困难的事情变得可能。 我们只需几行代码…

Qt ORM模块使用说明

附源码:QxOrm是一个C库资源-CSDN文库 使用说明 把QyOrm文件夹拷贝到自己的工程项目下, 在自己项目里的Pro文件里添加include($$PWD/QyOrm/QyOrm.pri)就能使用了 示例test_qyorm.h写了表的定义,Test_QyOrm_Main.cpp中写了所有支持的功能的例子: 通过自动表单添加…

C++ ——string的模拟实现

目录 前言 浅记 1. reserve(扩容) 2. push_back(尾插) 3. iterator(迭代器) 4. append(尾插一个字符串) 5. insert 5.1 按pos位插入一个字符 5.2 按pos位插入一个字符串 …

CleanClip for Mac 剪切板 粘贴工具 历史记录 安装(保姆级教程,新手小白轻松上手)

CleanClip:革新macOS剪贴板管理体验 目录 功能概览 多格式历史记录保存智能搜索功能快速复制操作拖拽功能 安装指南 前期准备安装步骤 配置与使用 功能概览 多格式历史记录保存 CleanClip支持保存文本、图片、文件等多种格式的复制历史记录,为用户提…

C语言 | Leetcode C语言接雨水II

题目: 题解: typedef struct{int row;int column;int height; } Element;struct Pri_Queue; typedef struct Pri_Queue *P_Pri_Queue; typedef Element Datatype;struct Pri_Queue{int n;Datatype *pri_qu; };/*优先队列插入*/ P_Pri_Queue add_pri_que…

通信工程学习:什么是SNI业务节点接口

SNI:业务节点接口 SNI业务节点接口,全称Service Node Interface,是接入网(AN)和一个业务节点(SN)之间的接口,位于接入网的业务侧。这一接口在通信网络中扮演着重要的角色&#xff0c…

几种手段mfc140u.dll丢失的解决方法,了解mfc140u.dll

在使用Windows操作系统时,许多用户可能会遇到“找不到mfc140u.dll”或“mfc140u.dll未找到”的错误提示。这个错误通常是由于该文件丢失或损坏所致。本文将详细介绍mfc140u.dll文件的作用、丢失的原因及其解决方法,帮助您快速恢复系统的正常运行。 一、m…

2024年华为9月4日秋招笔试真题题解

2024年华为0904秋招笔试真题 二叉树消消乐好友推荐系统维修工力扣上类似的题--K站中转内最便宜的航班 二叉树消消乐 题目描述 给定原始二叉树和参照二叉树(输入的二叉树均为满二叉树,二叉树节点的值范围为[1,1000],二叉树的深度不超过1000)&#xff0c…

Maven 解析:打造高效、可靠的软件工程

Apache Maven【简称maven】 是一个用于 Java 项目的构建自动化工具, 通过提供一组规则来管理项目的构建、依赖关系和文档。 1.Pre-预备知识: 1.1.Maven是什么? [by Maven是什么?有什么作用?Maven的核心内容简述_ma…

简化登录流程,助力应用建立用户体系

随着智能手机和移动应用的普及,用户需要在不同的应用中注册和登录账号,传统的账号注册和登录流程需要用户输入用户名和密码,这不仅繁琐而且容易造成用户流失。 华为账号服务(Account Kit)提供简单、快速、安全的登录功能,让用户快…

Zabbix自定义监控项与触发器

当我们需要获取某台主机上的数据时,直接利用 zabbix 提供的模板可以很方便的获得需要的数据,但是有些特别的数据,利用这些现有的模板或监控项是无法实现的,例如网站状态信息的监控、mysql数据库主从状态等信息。这是就需要自己定义键值和监控…

Java许可政策再变,Oracle JDK 17 免费期将结束!

原文地址:https://www.infoworld.com/article/3478122/get-ready-for-more-java-licensing-changes.html Oracle JDK 17的许可协议将于9月变更回Oracle Technology Network License Agreement,这将迫使用户重新评估他们的使用策略。 有句老话说&#xf…

小程序组件间通信

文章目录 父传子子传父获取组件实例兄弟通信 父传子 知识点: 父组件如果需要向子组件传递指定属性的数据,在 WXML 中需要使用数据绑定的方式 与普通的 WXML 模板类似,使用数据绑定,这样就可以向子组件的属性传递动态数据。 父…

java实际开发——数据库存储金额时用什么数据类型?(MySQL、PostgreSQL)

目录 java开发时金额用的数据类型——BigDecimal MySQL存储金额数据时用的数据类型是——decimal PostgreSQL存储金额数据时用的数据类型是——decimal 或 money java开发时金额用的数据类型——BigDecimal https://blog.csdn.net/Jilit_jilit/article/details/142180903?…

YOLOv5/v8 + 双目相机测距

yolov5/v8双目相机测距的代码,需要相机标定 可以训练自己的模型并检测测距,都是python代码 已多次实验,代码无报错。 非常适合做类似的双目课题! 相机用的是汇博视捷的双目相机,具体型号见下图。 用的yolov5是6.1版本的…

ubuntu2204安装kvm

ubuntu2204安装kvm 前言一、检测硬件是否支持二、安装软件三、创建/管理虚拟机1、创建存储池2、qemu创建镜像3、xml文件运行虚拟机1、范文2、xml文件创建虚机3、创建虚机 4、克隆虚机5、创建快照6、脚本创建VNC连接 四、创建集群1、安装glusterfs2、加入集群删除节点 3、 创建卷…

深度剖析iOS渲染

iOS App 图形图像渲染的基本流程: 1.CPU:完成对象的创建和销毁、对象属性的调整、布局计算、文本的计算和排版、图片的格式转换和解码、图像的绘制。 2.GPU:GPU拿到CPU计算好的显示内容,完成纹理的渲染, 渲染完成后将渲…

基于YOLO深度学习和百度AI接口的手势识别与控制项目

基于YOLO深度学习和百度AI接口的手势识别与控制项目 项目描述 本项目旨在开发一个手势识别与控制系统,该系统能够通过摄像头捕捉用户的手势,并通过YOLO深度学习模型或调用百度AI接口进行手势识别。识别到的手势可以用来控制计算机界面的操作&#xff0…

使用 Elastic 和 LM Studio 的 Herding Llama 3.1

作者:来自 Elastic Charles Davison, Julian Khalifa 最新的 LM Studio 0.3 更新使 Elastic 的安全 AI Assistant 能够更轻松、更快速地与 LM Studio 托管模型一起运行。在这篇博客中,Elastic 和 LM Studio 团队将向你展示如何在几分钟内开始使用。如果你…