rwkv模型lora微调之accelerate和deepspeed训练加速

       

目录

一、rwkv模型简介

二、lora原理简介

三、rwkv-lora微调

1、数据整理

2、环境搭建

a、Dockerfile编写

b、制造镜像

c、容器启动

3、训练代码修改

四、模型推理

1、模型推理

2、lora权重合并

3、推理web服务

五、总结


        由于业务采用的ChatGLM模型推理成本太大了,希望降低模型推理成本。因此对rwkv_1.5B模型进行了预研和业务领域的验证。为了快速验证,采用了lora+accelerate+deepspeed的训练方式。微调的过程中对rwkv模型认识更加深刻,同时对于docker训练环境搭建也更加熟悉了。这篇博客就分享一下这次微调中的一些实践,主要是关于训练流程拉通和rwkv模型在业务领域的一些结论。

一、rwkv模型简介

                rwkv模型是国人研发的一个非常优秀的模型,采用RNN架构代码目前主流的attention机制的transformer架构,在时间复杂度和空间复杂度都减少比较多的情况下,还能取得非常不错的效果,在各个榜单都有上榜。

       ​​

      上图是rwkv模型语言建模的架构,可以看到舍弃了attention机制,采用time mix 和channel mix模块。 

二、lora原理简介

      论文LoRA: Low-Rank Adaptation of Large Language Models 开发了一种方法,专为微调大模型减小显存。如下图:

       

   

对于一个参数,在微调的时候不直接微调W,而是把W通过低秩分解为两个小矩阵B和A的乘积,然后学习更新B和A,从而达到减少参数量和梯度等,同时保证模型lora微调后的效果和全参数微调效果相当。实现的时候会在BAx乘以一个系数,一般是lora_alpha/lora_rank的比值,注意lora_rank越大可学习的参数越多,显存占用就越多。

实践一般采用peft来实现对模型的linear层进行weight分解,使用方法如下:

model初始化
......
peft_config = LoraConfig(
        peft_type="LORA",
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=args.target_modules.split(","),
    )
model = get_peft_model(model, peft_config)
......
model训练和保存
model_state_dict = lora.lora_state_dict(model)
torch.save(path,model_state_dict )

三、rwkv-lora微调

        rwkv的微调主要的重点内容在于数据的整理(整理成模型可训练的格式)、训练环境的搭建、训练代码的修改和最后的模型效果评估,其中至于怎么样微调才能获得比较好的效果,本文不予讨论。由于rwkv支持2中数据格式,一种是question+answer拼接,另外一种是instruction+input+response拼接;目前1.5B,rwkv开源了v4和v5版本的权重,因此这里会做4次实验,用相同的业务数据构成训练集和测试集,使用不用的权重和数据指令拼接方式进行实验。

1、数据整理

qa指令拼接——适合做问答类

{"text": "Question: 问题\n\nAnswer: 答案"}

iir指令拼接——适合做阅读理解问答

{"text": "Instruction:基于专业背景的知识问题\n\nInput:专业领域的资料背景知识内容\n\nResponse:基于上述的专业回答"}

其中Instruction 是指示,Input 是需要操作的数据(注意Input可以为空),Response是答案

我们的业务数据

{"context": "姓名:未知\n服务时间:晚上23点\n联系方式:未知\n地址:广东省深圳市龙岗区南湾街道康桥花园\n空调品牌:卡萨帝\n空调样式:挂机\n是否5匹:10匹\n故障类型:异味\n\n坐席:空调发生什么故障了,不制热、不制冷、不开机还是其他问题?\n客户:其他不故障现象\n\n以上海尔导航场景收集的要素信息以及坐席和客户的一轮对话,你是要素抽取的专家,请根据坐席和客户的对话,更新上述要素结果,对话中未提及到的要素,保持原样结果输出,“空调品牌”取值范围是“卡萨帝”、“海尔”、“统帅”、“小超人”,“空调样式”取值范围是“柜机”、“挂机”、“嵌入机”、“中央空调”,“是否5匹”取值范围是“5匹以上”、“5匹以下”,“故障类型”取值范围是“不制冷”、“不制热”、“机器制热效果差”、“机器制冷效果差”、“机器着火”、“遥控器故障”、“无法关机”、“噪音大”、“温度不能调整”、“外观伤”、“频繁开停机”、“显示屏乱码跳屏”、“机器报故障”、“室内机漏水”、“连接管未包扎好”、“送风强度”、“异味”、“漏电”、“不通电”、“不启动”、“按键失灵”、“出风异常”、“显示屏异常”、“不停机”、“不除霜”、“排水管问题”、“空调漏气/漏氟”、“购买配件”、“自动开/关机”\n请给出要素抽取结果", "target": "姓名:未知\n\n服务时间:晚上23点\n\n联系方式:未知\n\n地址:广东省深圳市龙岗区南湾街道康桥花园\n\n空调品牌:卡萨帝\n\n空调样式:挂机\n\n是否5匹:10匹\n\n故障类型:其它故障"}

qa拼接后的形式:

{"text": "Question:姓名:未知\n服务时间:晚上23点\n联系方式:未知\n地址:广东省深圳市龙岗区南湾街道康桥花园\n空调品牌:卡萨帝\n空调样式:挂机\n是否5匹:10匹\n故障类型:异味\n\n坐席:空调发生什么故障了,不制热、不制冷、不开机还是其他问题?\n客户:其他不故障现象\n\n以上是海尔导航场景收集的要素信息以及坐席和客户的一轮对话,你是要素抽取的专家,请根据坐席和客户的对话,更新上述要素结果,对话中未提及到的要素,无需输出,若所有要素在对话中均未提到,请直接输出“无效对话”,空调品牌取值范围是“卡萨帝”、“海尔”、“统帅”、“小超人”;空调样式取值范围是“柜机”、“挂机”、“嵌入机”、“中央空调”;是否5匹取值范围是“5匹”、“5匹以上”、“5匹以下”、“10匹”;故障类型取值范围是“不制冷”、“不制热”、“机器制热效果差”、“机器制冷效果差”、“机器着火”、“遥控器故障”、“无法关机”、“噪音大”、“温度不能调整”、“外观伤”、“频繁开停机”、“显示屏乱码跳屏”、“机器报故障”、“室内机漏水”、“连接管未包扎好”、“送风强度”、“异味”、“漏电”、“不通电”、“不启动”、“按键失灵”、“显示屏异常”、“不停机”、“不除霜”、“空调漏气/漏氟”、“购买配件”、“自动开/关机”、“出风异常”、“排水管问题”、“其他故障”\n请给出要素抽取结果\n\nAnswer:故障类型:其它故障"}

iir拼接后的形式:

{"text": "Instruction:以上是海尔导航场景收集的要素信息以及坐席和客户的一轮对话,你是要素抽取的专家,请根据坐席和客户的对话,更新上述要素结果,对话中未提及到的要素,无需输出,若所有要素在对话中均未提到,请直接输出“无效对话”,空调品牌取值范围是“卡萨帝”、“海尔”、“统帅”、“小超人”;空调样式取值范围是“柜机”、“挂机”、“嵌入机”、“中央空调”;是否5匹取值范围是“5匹”、“5匹以上”、“5匹以下”、“10匹”;故障类型取值范围是“不制冷”、“不制热”、“机器制热效果差”、“机器制冷效果差”、“机器着火”、“遥控器故障”、“无法关机”、“噪音大”、“温度不能调整”、“外观伤”、“频繁开停机”、“显示屏乱码跳屏”、“机器报故障”、“室内机漏水”、“连接管未包扎好”、“送风强度”、“异味”、“漏电”、“不通电”、“不启动”、“按键失灵”、“显示屏异常”、“不停机”、“不除霜”、“空调漏气/漏氟”、“购买配件”、“自动开/关机”、“出风异常”、“排水管问题”、“其他故障”\n请给出要素抽取结果\n\nInput:姓名:未知\n服务时间:晚上23点\n联系方式:未知\n地址:广东省深圳市龙岗区南湾街道康桥花园\n空调品牌:卡萨帝\n空调样式:挂机\n是否5匹:10匹\n故障类型:异味\n\n坐席:空调发生什么故障了,不制热、不制冷、不开机还是其他问题?\n客户:其他不故障现象\n\nResponse:故障类型:其它故障"}

2、环境搭建

        官方代码库指定的环境直接安装就好了,不过安装的过程中要注意机器的显卡驱动一定要比安装的cuda版本要高,并且cuda版本的算力不能低于显卡的算力(大多数情况下,显卡是支持一定的低版本的cuda和torch的);torch的版本要和cuda的版本一致,比如4090显卡安装了12.0的显卡驱动,安装了cuda11.8,那么torch也要安装cuda11.8的版本 torch2.0_cu118。rwkv有自己实现的cuda算子需要python调用C++和nvcc来编译作为torch的扩展,所以要严格匹配版本,不然会报显卡算力过高和torch版本不匹配,cuda和torch版本不匹配等错误。C++编译的时候还需要完整的libso库文件,由于本人使用的机器多人使用,不好升级libso库文件——错误操作可能会导致linux系统出错。稳妥起见直接使用docker来搭建训练环境,并且在docker中训练。物理机器上安装docker,编写dockerfile后,制作镜像,启动容器然后训练就OK了。

a、Dockerfile编写
##build 镜像
#docker build -t  images_name(images_name:tag) -f ./Dockerfile .
##运行容器  --gpus all 宿主机上的显卡可用  --ipc host  代表与宿主机器共享命名空间,即让Docker容器和宿主机器使用同一个进程ID命名空间和信号命名空间,从而实现进程间通信的能力
## --network host docker 使用本机的IP和端口
#docker run -d -it --name my_container --gpus all --network host --ipc host  images_name(id)

#cuda toolkit共享的库,涵盖了运行环境的最小集合如动态库等,但没有cuda的编译工具nvcc
#FROM nvidia/cuda:11.8.0-runtime-ubuntu22.04

#基于runtime,添加了编译工具链、调试工具、头文件、静态库,用于从源码编译cuda应用,是有nvcc的
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04

WORKDIR /rwkv
# Set up time zone.
ENV TZ=Asia/Shanghai
RUN  ln -snf /usr/share/zoneinfo/$TZ /etc/localtime

ENV STAGE_DIR=/tmp
RUN mkdir -p ${STAGE_DIR}


RUN  apt-get update && \
        apt-get install -y --no-install-recommends \
         software-properties-common build-essential autotools-dev \
        nfs-common pdsh \
        cmake g++ gcc \
        curl wget vim tmux emacs less unzip \
        htop iftop iotop ca-certificates openssh-client openssh-server \
        rsync iputils-ping net-tools

RUN  apt-get update && \
         apt-get install -y --no-install-recommends \
        libsndfile-dev \
        libcupti-dev \
        libjpeg-dev \
        libpng-dev \
        screen \
        libaio-dev


#从源码安装python
RUN apt install unzip wget build-essential zlib1g-dev libncurses5-dev libgdbm-dev libnss3-dev libssl-dev libsqlite3-dev libreadline-dev libffi-dev curl libbz2-dev pkg-config make -y
RUN apt-get install liblzma-dev -y
#RUN wget https://www.python.org/ftp/python/3.10.10/Python-3.10.10.tar.xz
COPY Python-3.10.10.tar.xz ./
RUN tar xf Python-3.10.10.tar.xz
RUN cd Python-3.10.10 && ./configure --enable-optimizations && make altinstall && cd .. && rm -fr *
RUN python3.10 -m pip install torch==2.0.0 --index-url https://download.pytorch.org/whl/cu118


WORKDIR /rwkv
COPY requirements.txt ./
#RUN python3.10 -m pip install -r requirements.txt
#RUN python3.10 -m pip install --upgrade pip && python3.10 -m pip install -i  https://mirrors.aliyun.com/pypi/simple -r requirements.txt
RUN  python3.10 -m pip install -i  https://mirrors.aliyun.com/pypi/simple -r requirements.txt
# 拷贝所有nue文件
COPY . ./

        注意python可以提前现在源码,然后上传到服务器再制作镜像;cuda docker 一定要拉取devel版本,runtime版本会精简,不安装nvcc等编译工具,python安装一些第三方库会依赖nvcc编译工具的。其他的都没有什么了,一切正常编写即可。

b、制造镜像
docker build -t  images_name(images_name:tag) -f ./Dockerfile .

这个耗时比较久,一个是镜像、已经库文件安装,还有数据、代码等copy。

c、容器启动
docker run -d -it --name my_container --gpus all --network host --ipc host  images_name(id)

        关注的地方是--gpus 一定要是all,这样容器才能使用物理机上的所有显卡;--network host保证docker使用物理机的ip和端口,可以通过改ip访问docker内的服务;--ipc host让Docker容器和宿主机器使用同一个进程ID命名空间和信号命名空间,从而实现进程间通信的能力——跑分布式训练必须选项,因为多进程中的子进程要和主进程进行通信,传输梯度等信息。

3、训练代码修改

        原始的训练代码是不支持lora和accelerate的,这里我们修改为支持lora以及accelerate的形式。同时由于采用分布式训练,目前可以使用deepspeed来做,而accelerate也支持deepspeed的插件形式(和直接使用deepspeed来做分布式训练稍有不同,直接使用deepspeed对系统的各种库libso要求的比较严格,之前使用deepspeed一直没有成功过)。代码主体结构如下:

from accelerate import Accelerator, DeepSpeedPlugin
from peft import get_peft_model, LoraConfig, TaskType
import loralib as lora

#初始化分布式环境
accumulate_step = 4
mixed_precision = 'bf16'
deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=accumulate_step)
accelerator = Accelerator(mixed_precision=mixed_precision, gradient_accumulation_steps=accumulate_step, deepspeed_plugin=deepspeed_plugin)
device = accelerator.device

......
......
model = RWKV(args)

#lora设置,设置模型的那些参数使用lora以及其他的一些参数。
peft_config = LoraConfig(
        peft_type="LORA",
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=args.target_modules.split(","),
    )
model = get_peft_model(model, peft_config)
......
#模型、优化器、数据加载器等用accelerate包装一下。
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer,train_dataloader)
......
for epoch in range(int(args.epoch_count)):
    for step, batch in enumerate(t := tqdm(train_dataloader, ncols=100)):
         model(batch)
         ......
         accelerator.backward(loss)
         optimizer.step()
         lr_scheduler.step()
         optimizer.zero_grad()

分布式环境的初始化以及lora参数的设置,针对rwkv模型lora设置如下:

lora_rank=16
lora_alpha=32
lora_dropout=0.1
target_modules=emb,key,value,receptance,output,head

完整的训练代码如下(其他的部分自行完成,代码修改自rwkv_LM中的rwkv-v4neo):

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os, warnings, math, sys, time
import numpy as np
import torch
from torch.utils.data import DataLoader
import logging
from transformers import get_linear_schedule_with_warmup
from argparse import ArgumentParser
logging.basicConfig(level=logging.INFO)
import os
import sys
sys.path.append(os.getcwd())
def script_method(fn, _rcb=None):
    return fn

def script(obj, optimize=True, _frames_up=0, _rcb=None):
    return obj

import torch.jit

script_method1 = torch.jit.script_method
script1 = torch.jit.script
torch.jit.script_method = script_method
torch.jit.script = script

from torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn as nn

from torch.utils.data import DataLoader
import gc

import psutil
import traceback
from tqdm import tqdm
import numpy as np

from accelerate import Accelerator, DeepSpeedPlugin
from torch.utils.data import Dataset, IterableDataset
import random
import json
from collections import defaultdict

import threading
from tokenizer import build_tokenizer
from datetime import datetime
from peft import get_peft_model, LoraConfig, TaskType
import loralib as lora

accumulate_step = 4
mixed_precision = 'bf16'
deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=accumulate_step)
accelerator = Accelerator(mixed_precision=mixed_precision, gradient_accumulation_steps=accumulate_step, deepspeed_plugin=deepspeed_plugin)
device = accelerator.device

def b2mb(x):
    return int(x / 2 ** 20)

class TorchTracemalloc:
    def __enter__(self):
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
        self.begin = torch.cuda.memory_allocated()
        self.process = psutil.Process()

        self.cpu_begin = self.cpu_mem_used()
        self.peak_monitoring = True
        peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
        peak_monitor_thread.daemon = True
        peak_monitor_thread.start()
        return self

    def cpu_mem_used(self):
        """get resident set size memory for the current process"""
        return self.process.memory_info().rss

    def peak_monitor_func(self):
        self.cpu_peak = -1

        while True:
            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)

            # can't sleep or will not catch the peak right (this comment is here on purpose)
            # time.sleep(0.001) # 1msec

            if not self.peak_monitoring:
                break

    def __exit__(self, *exc):
        self.peak_monitoring = False

        gc.collect()
        torch.cuda.empty_cache()
        self.end = torch.cuda.memory_allocated()
        self.peak = torch.cuda.max_memory_allocated()
        self.used = b2mb(self.end - self.begin)
        self.peaked = b2mb(self.peak - self.begin)

        self.cpu_end = self.cpu_mem_used()
        self.cpu_used = b2mb(self.cpu_end - self.cpu_begin)
        self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin)
        # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")

def collate_fn(batch):
    tokens, labels, domains = zip(*batch)
    input_ids = torch.nn.utils.rnn.pad_sequence(tokens,batch_first=True,padding_value=0)
    labels = torch.nn.utils.rnn.pad_sequence(labels,batch_first=True,padding_value=-100)
    domains = torch.stack(domains)
    return {"input_ids": input_ids, "labels": labels, "domains":domains}

idx2domain = {}
domain2idx = {}
# 所有数据全部加载 batch内采样
class DataReader(Dataset):
    def __init__(self,tokenizer, file_list, sample_ratios, domain_names, max_token, args):
        self.args = args
        self.tokenizer = tokenizer
        file_list = file_list.split(",")
        sample_ratios = list(map(float, sample_ratios.split(",")))
        domain_names = domain_names.split(",")
        assert len(file_list) == len(sample_ratios) and len(file_list) == len(domain_names)
        self.file_list = file_list
        self.domain_names = domain_names
        self.max_token = max_token
        self.sample_ratios = sample_ratios
        self.sum_ratio = sum(sample_ratios)
        print("self.sum_ratio: ",self.sum_ratio)
        assert self.sum_ratio <= 1.0
        self.cum_ratios = [sum(sample_ratios[:i + 1]) for i in range(len(sample_ratios))]
        print("file_list: {}, sample_ratios: {} cum_ratios:{}".format(file_list, sample_ratios, self.cum_ratios))
        self.domain2num = defaultdict(int)
        self.common_datas = {}
        for i in range(len(file_list)):
            domain2idx[domain_names[i]] = i
            idx2domain[i] = domain_names[i]
            self.common_datas[domain_names[i]] = self.loaddata_convert_token_to_ids(domain_names[i], file_list[i])
            print(file_list[i], len(self.common_datas[domain_names[i]]))
        print("domain2num:{}".format(self.domain2num))
        self.train_data = []
        self.index = 0
        self.epoch = 0
        self.train_length = 4000
        self.train_step = 1000

    def loaddata_convert_token_to_ids(self, domain_name, file_path):
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        domain_idx = domain2idx[domain_name]
        all_datas = []
        for line in tqdm(lines[0:], desc=f"read{file_path}",ncols=100):
            text = json.loads(line)["text"]

            text = text.split('\n\n')
            q = '\n\n'.join(text[0:3]) + "Answer:"
            a = '\n\n'.join(text[3:])
            a = a.replace('Answer:',"")

            q_ids = self.tokenizer.tokenize(q)
            a_ids = self.tokenizer.tokenize(a)
            ids = q_ids + a_ids
            ids.append(self.tokenizer.eod)
            if len(ids) > 2:
                if len(ids) > self.max_token:
                    # 大于最大长度的数据丢弃掉
                    continue
                else:
                    labels = [-100] * len(q_ids) + a_ids + [self.tokenizer.eod]
                    assert len(ids) == len(labels), " len(ids) != len(labels)"
                    input_ids = torch.as_tensor(ids[:-1], dtype=torch.long)
                    labels = torch.as_tensor(labels[1:], dtype=torch.long)
                    domain_idx = torch.as_tensor(domain_idx, dtype=torch.long)
                    all_datas.append((input_ids, labels, domain_idx))
        print(f"{file_path}--{len(all_datas)}")
        self.domain2num[domain_name] += 1

        return all_datas


    def __getitem__(self, item):
        if len(self.train_data) == 0:
            time_str = datetime.now().strftime("%Y-%m-%d-%H_%M_%S")
            print("=============={}==============".format(time_str))
            for k, v in self.common_datas.items():
                if k in ['friso','kongtiao','qa','other']:
                    self.train_data.extend(v)
                else:
                    split_count = len(v)//20
                    epoch = self.epoch % 20
                    temp = v[epoch*split_count:(epoch+1)*split_count]
                    # temp = random.choices(v, k=split_count)
                    self.train_data.extend(temp)
            print(f"len(self.train_data) {len(self.train_data)} epoch {self.epoch}")

        if self.index < self.train_step:
            self.index += 1
            if item >= len(self.train_data):
                item = random.randint(0,len(self.train_data)-1)
            input_ids, labels, domain_idx = self.train_data[item]
            return input_ids, labels, domain_idx
        else:
            self.epoch += 1
            self.index = 0
            self.train_data = []
            for k, v in self.common_datas.items():
                if k in ['friso','kongtiao','qa','other']:
                    self.train_data.extend(v)
                else:
                    split_count = len(v)//20
                    epoch = self.epoch % 20
                    temp = v[epoch*split_count:(epoch+1)*split_count]
                    # temp = random.choices(v, k=split_count)
                    self.train_data.extend(temp)
            print(f"len(self.train_data) {len(self.train_data)} epoch {self.epoch}")
            self.index += 1
            if item >= len(self.train_data):
                item = random.randint(0, len(self.train_data) - 1)
            input_ids, labels, domain_idx = self.train_data[item]
            return input_ids, labels, domain_idx

    def __len__(self):
        # return 910000
        return self.train_length

if __name__ == "__main__":
    parser = ArgumentParser()

    parser.add_argument("--file_list", default="", type=str)
    parser.add_argument("--sample_ratios", default="utf-8", type=str)
    parser.add_argument("--domain_names", default="", type=str)
    parser.add_argument("--use_owndatareader", default="1", type=str)
    parser.add_argument("--logdir", default="", type=str)
    parser.add_argument("--datadir", default="", type=str)
    parser.add_argument("--save_step",default=50000,type=int)

    # lora
    parser.add_argument("--lora_rank", default=16, type=int)
    parser.add_argument("--lora_alpha", default=32, type=int)
    parser.add_argument("--lora_dropout", default=0.1, type=float)
    parser.add_argument("--target_modules", default="emb,key,value,receptance,output,head", type=str)

    parser.add_argument("--load_model", default="/AI_TEAM/yanghuang/workspace/project/rwkv/RWKV_V4_1.5B/RWKV-4-World-CHNtuned-1.5B-v1-20230620-ctx4096.pth", type=str)  # full path, with .pth
    parser.add_argument("--wandb", default="", type=str)  # wandb project name. if "" then don't use wandb
    parser.add_argument("--proj_dir", default="out", type=str)
    parser.add_argument("--random_seed", default="-1", type=int)

    parser.add_argument("--data_file", default="", type=str)
    parser.add_argument("--data_type", default="utf-8", type=str)
    parser.add_argument("--vocab_size", default=65536, type=int)  # vocab_size = 0 means auto (for char-level LM and .txt data)

    parser.add_argument("--ctx_len", default=2560, type=int)
    parser.add_argument("--epoch_steps", default=1000, type=int)  # a mini "epoch" has [epoch_steps] steps
    parser.add_argument("--epoch_count", default=500, type=int)  # train for this many "epochs". will continue afterwards with lr = lr_final
    parser.add_argument("--epoch_begin", default=0, type=int)  # if you load a model trained for x "epochs", set epoch_begin = x
    parser.add_argument("--epoch_save", default=5, type=int)  # save the model every [epoch_save] "epochs"

    parser.add_argument("--micro_bsz", default=12, type=int)  # micro batch size (batch size per GPU)
    parser.add_argument("--n_layer", default=24, type=int)
    parser.add_argument("--n_embd", default=2048, type=int)
    parser.add_argument("--dim_att", default=0, type=int)
    parser.add_argument("--dim_ffn", default=0, type=int)
    parser.add_argument("--pre_ffn", default=0, type=int)  # replace first att layer by ffn (sometimes better)
    parser.add_argument("--head_qk", default=0, type=int)  # my headQK trick
    parser.add_argument("--tiny_att_dim", default=0, type=int)  # tiny attention dim
    parser.add_argument("--tiny_att_layer", default=-999, type=int)  # tiny attention @ which layer

    parser.add_argument("--lr_init", default=6e-4, type=float)  # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
    parser.add_argument("--lr_final", default=1e-5, type=float)
    parser.add_argument("--warmup_steps", default=-1, type=int)  # try 50 if you load a model
    parser.add_argument("--beta1", default=0.9, type=float)
    parser.add_argument("--beta2", default=0.99, type=float)  # use 0.999 when your model is close to convergence
    parser.add_argument("--adam_eps", default=1e-8, type=float)
    parser.add_argument("--grad_cp", default=0, type=int)  # gradient checkpt: saves VRAM, but slower
    parser.add_argument("--dropout", default=0, type=float) # try 0.01 / 0.02 / 0.05 / 0.1
    parser.add_argument("--weight_decay", default=0, type=float) # try 0.1 / 0.01 / 0.001
    parser.add_argument("--weight_decay_final", default=-1, type=float)

    parser.add_argument("--my_pile_version", default=1, type=int)  # my special pile version
    parser.add_argument("--my_pile_stage", default=0, type=int)  # my special pile mode
    parser.add_argument("--my_pile_shift", default=-1, type=int)  # my special pile mode - text shift
    parser.add_argument("--my_pile_edecay", default=0, type=int)
    parser.add_argument("--layerwise_lr", default=1, type=int)  # layerwise lr for faster convergence (but slower it/s)
    parser.add_argument("--ds_bucket_mb", default=200, type=int)  # deepspeed bucket size in MB. 200 seems enough
    # parser.add_argument("--cuda_cleanup", default=0, type=int)  # extra cuda cleanup (sometimes helpful)

    parser.add_argument("--my_img_version", default=0, type=str)
    parser.add_argument("--my_img_size", default=0, type=int)
    parser.add_argument("--my_img_bit", default=0, type=int)
    parser.add_argument("--my_img_clip", default='x', type=str)
    parser.add_argument("--my_img_clip_scale", default=1, type=float)
    parser.add_argument("--my_img_l1_scale", default=0, type=float)
    parser.add_argument("--my_img_encoder", default='x', type=str)
    # parser.add_argument("--my_img_noise_scale", default=0, type=float)
    parser.add_argument("--my_sample_len", default=0, type=int)
    parser.add_argument("--my_ffn_shift", default=1, type=int)
    parser.add_argument("--my_att_shift", default=1, type=int)
    parser.add_argument("--head_size_a", default=64, type=int) # can try larger values for larger models
    parser.add_argument("--head_size_divisor", default=8, type=int)
    parser.add_argument("--my_pos_emb", default=0, type=int)
    parser.add_argument("--load_partial", default=0, type=int)
    parser.add_argument("--magic_prime", default=0, type=int)
    parser.add_argument("--my_qa_mask", default=0, type=int)
    parser.add_argument("--my_random_steps", default=0, type=int)
    parser.add_argument("--my_testing", default='', type=str)
    parser.add_argument("--my_exit", default=99999999, type=int)
    parser.add_argument("--my_exit_tokens", default=0, type=int)

    args = parser.parse_args()
    summary_writer = SummaryWriter(args.logdir)
    print(args)
    ########################################################################################################

    np.set_printoptions(precision=4, suppress=True, linewidth=200)
    warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")
    warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")
    # os.environ["WDS_SHOW_SEED"] = "1"

    args.my_timestamp = datetime.today().strftime("%Y-%m-%d-%H-%M-%S")
    args.enable_checkpointing = False
    args.replace_sampler_ddp = False
    args.logger = False
    args.gradient_clip_val = 1.0
    args.num_sanity_val_steps = 0
    args.check_val_every_n_epoch = int(1e20)
    args.log_every_n_steps = int(1e20)
    args.max_epochs = -1  # continue forever
    args.betas = (args.beta1, args.beta2)
    args.real_bsz = args.micro_bsz
    os.environ["RWKV_T_MAX"] = str(args.ctx_len)
    os.environ["RWKV_MY_TESTING"] = args.my_testing
    os.environ["RWKV_HEAD_SIZE_A"] = str(args.head_size_a)
    if args.dim_att <= 0:
        args.dim_att = args.n_embd
    if args.dim_ffn <= 0:
        if 'r3' in args.my_testing:
            args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)
        else:
            args.dim_ffn = args.n_embd * 4

    if args.data_type == "wds_img":
        args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
        args.proj_dir = f"{args.proj_dir}-{args.run_name}"
    else:
        args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"

    if accelerator.is_main_process and not os.path.exists(args.proj_dir):
        os.makedirs(args.proj_dir)

    if args.my_pile_stage > 0:
        magic_prime_bak = args.magic_prime

        if args.my_pile_version == 1:
            if args.ctx_len == 1024:
                args.magic_prime = 324331313
            elif args.ctx_len == 2048:
                args.magic_prime = 162165671
            elif args.ctx_len == 4096:
                args.magic_prime = 81082817
            elif args.ctx_len == 8192:
                args.magic_prime = 40541399
        else:
            if args.ctx_len == 1024:
                args.magic_prime = 1670239709
            elif args.ctx_len == 2048:
                args.magic_prime = 835119767
            elif args.ctx_len == 4096:
                args.magic_prime = 417559889
            elif args.ctx_len == 6144:
                args.magic_prime = 278373239
            elif args.ctx_len == 8192:
                args.magic_prime = 208779911
        if args.my_pile_shift < 0:
            args.my_pile_shift = 0

        if magic_prime_bak > 0:
            args.magic_prime = magic_prime_bak
        if args.my_qa_mask == 2:
            args.epoch_count = 2 * args.magic_prime // 40320
        else:
            args.epoch_count = args.magic_prime // 40320

        args.epoch_steps = 40320 // args.real_bsz
        assert args.epoch_steps * args.real_bsz == 40320
        # if args.my_pile_stage == 2:
        #     assert args.lr_final == args.lr_init
        if args.my_pile_stage >= 2:  # find latest saved model
            list_p = []
            for p in os.listdir(args.proj_dir):
                if p.startswith("rwkv") and p.endswith(".pth"):
                    p = ((p.split("-"))[1].split("."))[0]
                    if p != "final":
                        if p == "init":
                            p = -1
                        else:
                            p = int(p)
                        list_p += [p]
            list_p.sort()
            max_p = list_p[-1]
            if len(list_p) > 1:
                args.my_pile_prev_p = list_p[-2]  # in case max_p is corrupted
            if max_p == -1:
                args.load_model = f"{args.proj_dir}/rwkv-init.pth"
            else:
                args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
                if args.warmup_steps < 0:
                    if args.my_pile_stage == 2:
                        args.warmup_steps = 10
                    else:
                        args.warmup_steps = 30
            args.epoch_begin = max_p + 1

    samples_per_epoch = args.epoch_steps * args.real_bsz
    tokens_per_epoch = samples_per_epoch * args.ctx_len


    assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]

    args.precision = "bf16"
    assert args.precision in ["fp32", "tf32", "fp16", "bf16"]
    os.environ["RWKV_FLOAT_MODE"] = args.precision
    # os.environ["RWKV_JIT_ON"] = "1"
    os.environ["RWKV_JIT_ON"] = "0"

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    if args.precision == "fp32":
        torch.backends.cudnn.allow_tf32 = False
        torch.backends.cuda.matmul.allow_tf32 = False
    else:
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True

    args.precision = "bf16"

    if args.data_type == 'wds_img':
        from src.model_img import RWKV_IMG
        model = RWKV_IMG(args)
    else:
        from src.model import RWKV
        model = RWKV(args)

    try:
        load_dict = torch.load(args.load_model, map_location="cpu")
        load_keys = list(load_dict.keys())
        for k in load_keys:
            if k.startswith('_forward_module.'):
                load_dict[k.replace('_forward_module.','')] = load_dict[k]
                del load_dict[k]
    except:
        if args.my_pile_stage >= 2:  # try again using another checkpoint
            max_p = args.my_pile_prev_p
            if max_p == -1:
                args.load_model = f"{args.proj_dir}/rwkv-init.pth"
            else:
                args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
            args.epoch_begin = max_p + 1
            load_dict = torch.load(args.load_model, map_location="cpu")

    model.load_state_dict(load_dict)

    peft_config = LoraConfig(
        peft_type="LORA",
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=args.target_modules.split(","),
    )
    model = get_peft_model(model, peft_config)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr_init)

    tokenizer_type = "RWKVTokenizer"
    vocab_file = "./json2binidx/rwkv_vocab_v20230424.txt"
    tokenizer = build_tokenizer(tokenizer_type, vocab_file)
    train_data = DataReader(tokenizer, args.file_list, args.sample_ratios, args.domain_names, args.ctx_len, args)
    # train_data = DataReader( tokenizer, args.ctx_len, args.datadir, read_file_count=2)

    train_dataloader = DataLoader(dataset=train_data, collate_fn=collate_fn, shuffle=True, batch_size=args.micro_bsz)
    print(f"已经加载完了数据:{len(train_dataloader)}条")

    warm_up_ratio = 0.1
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=int(len(train_dataloader) / accumulate_step * warm_up_ratio),
        num_training_steps=(int(len(train_dataloader) / accumulate_step) * args.epoch_count),
    )
    model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
    print(f"已经加载完了数据:{len(train_dataloader)}条")

    loss_fct = nn.CrossEntropyLoss()
    global_step = 0

    domain2globalstep = {k: 0 for k in domain2idx}

    for epoch in range(int(args.epoch_count)):
        name2loss = {k: 0 for k in domain2idx}
        domain2step = {k: 0 for k in domain2idx}
        print("name2loss",name2loss)
        total_loss = 0
        mean_loss = 0
        domain2num = {k: 0 for k in domain2idx}
        with TorchTracemalloc() as tracemalloc:
            model.to(device).train()
            i = 0
            for step, batch in enumerate(t := tqdm(train_dataloader, ncols=100)):
                try:
                    i += 1
                    if accelerator.is_main_process and i % args.save_step == 0:
                        model_state_dict = lora.lora_state_dict(accelerator.unwrap_model(model))
                        save_path = os.path.join(args.proj_dir, f"rwkv-epoch{epoch}_step{i}_lora.pt")
                        accelerator.save(model_state_dict, save_path)

                    labels = batch['labels']
                    domains = batch['domains']
                    input_ids = batch['input_ids']
                    lm_logits = model(input_ids)

                    shift_logits = lm_logits.contiguous()
                    shift_labels = labels.contiguous()

                    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

                    accelerator.backward(loss)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()
                    if i % 50 == 0:
                        torch.cuda.empty_cache()
                    loss_detach = loss.detach().cpu().float()

                    total_loss += loss_detach
                    time_str = datetime.now().strftime("%Y-%m-%d-%H_%M_%S")
                    des_train = f"{time_str} shape:{input_ids.shape[1]} loss: {loss_detach}"
                    for domian_name, domian_idx in domain2idx.items():
                        select_idx = domains == domian_idx
                        select_shift_logits = shift_logits[select_idx]
                        select_shift_labels = shift_labels[select_idx]
                        loss_domain = 0
                        if len(select_shift_labels) > 0:
                            domain2num[domian_name] += len(select_shift_labels)
                            loss_domain = loss_fct(select_shift_logits.view(-1, select_shift_logits.size(-1)),
                                                   select_shift_labels.view(-1)).detach().cpu().float()
                            domain2globalstep[domian_name] += 1
                            domain2step[domian_name] += 1
                            name2loss[domian_name] += loss_domain
                            summary_writer.add_scalar(f"train_step/{domian_name}", loss_domain, domain2globalstep[domian_name])
                        des_train += f" {domian_name}: {loss_domain}"
                        # domain2loss_detach[domian_name] = loss_domain
                    t.set_description(des_train)
                    # t.set_postfix(des_train)
                    if accelerator.is_main_process:
                        summary_writer.add_scalar(f"train_step/total_loss", loss_detach, global_step)
                    global_step += 1
                except Exception as e:
                    print(str(e))
                    print(traceback.format_exc())
                    print("oom", batch['input_ids'].shape)
                    optimizer.zero_grad()
                    torch.cuda.empty_cache()

        mean_loss = total_loss / (step + 1)
        for k in name2loss:
            name2loss[k] = name2loss[k] / (domain2step[k] + 1)
            if accelerator.is_main_process:
                summary_writer.add_scalar(f"train/{k}", name2loss[k], epoch)


        s = ""
        s_num = ""
        for k, v in name2loss.items():
            s += f" {k}_loss={v}"
            s_num += f" {k}_num={domain2num[k]}"

        train_epoch_loss = total_loss
        train_mean_epoch_loss = mean_loss
        train_ppl = torch.exp(train_epoch_loss)
        time_str = datetime.now().strftime("%Y-%m-%d-%H_%M_%S")
        accelerator.print(
            f"{time_str}  epoch={epoch}: train_ppl={train_ppl} train_epoch_loss={train_epoch_loss} train_mean_epoch_loss={train_mean_epoch_loss}")
        accelerator.print(s)
        accelerator.print(s_num)
        accelerator.wait_for_everyone()

accelerate联合deepspeed启动的时候需要配置文件:

compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero3_save_16bit_model: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'yes'
dynamo_backend: 'yes'
fsdp_config: {}
machine_rank: 0
main_training_function: main
megatron_lm_config: {}
mixed_precision: fp16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
use_cpu: true
main_process_port: 20667

主要关注num_processes,要和使用的显卡数量一致。

训练启动脚本,使用CUDA_VISIBLE_DEVICES指定机器上使用的显卡;nohup后台启动;accelerate launch 启动accelerate;--config_file 配置文件设置以及deepspeed的配置等

CUDA_VISIBLE_DEVICES=1,2,4,5 nohup  accelerate launch --config_file accelerate_ds_zero3_cpu_offload_config.yaml  train_accelerator_deepspeed_lora_v1.py \
--load_model /AI_TEAM/yanghuang/workspace/project/rwkv/RWKV_V4_1.5B/RWKV-4-World-CHNtuned-1.5B-v1-20230620-ctx4096.pth
......
......

采用lora以及2张4090来训练,只需要几分钟就可以训练好一个epoch,显存占用也非常友好:

四、模型推理

1、模型推理

模型推理使用rwkv第三方库来实现,核心逻辑如下:

from rwkv.model import RWKV
from rwkv.utils import PIPELINE
model = RWKV(model='./rwkv.pth', strategy='cuda bf16')
model.eval()
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")

out_tokens = []
out_last = 0
out_str = ''
occurrence = {}
state = None
token = None
for i in range(max_length):
    tokens = pipeline.encode(ctx) if i == 0 else [token]
    out, state = pipeline.model.forward(tokens, state)
    for n in occurrence:
        out[n] -= (0.4 + occurrence[n] * 0.4)  # repetition penalty

    token = pipeline.sample_logits(out, temperature=1.0, top_p=0.0)
    if token == 0:
        break  # exit when 'endoftext'

    out_tokens += [token]
    occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)
    tmp = pipeline.decode(out_tokens[out_last:])

    if ('\ufffd' not in tmp) and (not tmp.endswith('\n')):
        # print(tmp, end='', flush=True)
        out_str += tmp
        out_last = i + 1
return out_str

同时由于采用lora训练因此需要把lora权重合并到原始的权重上,方可使用上述方式进行模型加载和推理

2、lora权重合并

lora权重合并到原始权重,依据公式直接实现,代码如下:

def merge_lora_weights():
    rwkv_path = "RWKV-4-World-CHNtuned-1.5B-v1-20230620-ctx4096.pth"
    lora_path = "./lora.pt"
    print("lora_path: ",lora_path)
    model_weight = torch.load(rwkv_path, map_location='cpu')
    lora_model = torch.load(lora_path,  map_location='cpu')
    for k, v in tqdm(model_weight.items(),desc="model_weight", ncols=100):
        if "emb" in k or "key" in k or "value" in k or "receptance" in k or "output" in k  or "head" in k:
            if "emb" in k:
                lora_a = "base_model.model." + k.replace(".weight", ".lora_embedding_A.default")
                lora_b = "base_model.model." + k.replace(".weight", ".lora_embedding_B.default")
                device = v.device
                w_a = lora_model[lora_a].T
                w_b = lora_model[lora_b].T
                w = torch.mm(w_a, w_b).cpu()
                new_w = v.cpu() + 2 * w
                model_weight[k] = new_w.to(device)
            elif "weight" in k:
                lora_a = "base_model.model." + k.replace(".weight", ".lora_A.default.weight")
                lora_b = "base_model.model." + k.replace(".weight", ".lora_B.default.weight")
                device = v.device
                w_a = lora_model[lora_a]
                w_b = lora_model[lora_b]
                w = torch.mm(w_b, w_a).cpu()
                # w = torch.mm(w_b, w_a)
                new_w = v.cpu() + 2 * w
                model_weight[k] = new_w.to(device)
            else:
                model_weight[k] = v
        else:
            model_weight[k] = v
    rwkv_lora_path = "./rwkv.pth"
    torch.save(model_weight,rwkv_lora_path)
    print("merge_lora_weights finished!")

3、推理web服务

一般都是需要提供web接口,采用aiohttp来做异步web接口,把上述模型推理和lora权重合并功能逻辑集成到web服务程序中:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import asyncio
import json
import logging.handlers
import os
import socket
import time

import aiohttp

from aiohttp import web

import torch
from argparse import ArgumentParser
from tqdm import tqdm

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '1'
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS

# logger
log_level = logging.DEBUG

logger = logging.getLogger(__name__)
logger.setLevel(log_level)

formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(filename)s:%(lineno)s %(message)s')

stream_handler = logging.StreamHandler()
stream_handler.setLevel(log_level)
stream_handler.setFormatter(formatter)

os.makedirs('./log', exist_ok=True)
file_handler = logging.handlers.RotatingFileHandler(filename='log/server.log', maxBytes=10 << 20, backupCount=5,encoding='utf8')
file_handler.setLevel(log_level)
file_handler.setFormatter(formatter)

logger.addHandler(stream_handler)
logger.addHandler(file_handler)

#
NODE_NAME = 'general.rwkv.loratest_20231010'
NODE_NAME_2 = 'general.chat.hydiversity_20231010'
print(NODE_NAME)
print(NODE_NAME_2)
NUS = '心跳IP:端口'


async def heart_beat(ip, port):
    data_dic = {
        'method': 'heartbeat',
        'params': {
            'data': [
                {
                    'nodename': NODE_NAME,
                    'addrip': ip + ':' + str(port),
                    'type': 'transparent'
                },
                {
                    'nodename': NODE_NAME_2,
                    'addrip': ip + ':' + str(port),
                    'type': 'transparent'
                }
            ]
        }
    }
    send_data = json.dumps(data_dic)

    client = aiohttp.ClientSession()
    while True:
        try:
            await client.post(f'http://{NUS}/heartbeat', data=send_data)
        except Exception as e:
            logger.error(f'send heartbeat fail: {e}')
        await asyncio.sleep(1)


class TimeMeasure:
    def __init__(self, desc=''):
        self.start = 0
        self.desc = desc

    def __enter__(self):
        self.start = time.time()
        logger.info(f'{self.desc} start')

    def __exit__(self, exc_type, exc_val, exc_tb):
        end = time.time()
        cost_s = end - self.start
        if cost_s > 10:
            cost_s = round(cost_s, 2)
            logger.info(f'{self.desc} end, cost : {cost_s}s')
        else:
            cost_ms = round(cost_s * 1000, 2)
            logger.info(f'{self.desc} end, cost : {cost_ms}ms')


def build_fail_resp(id_: int, code: int, msg: str):
    return web.json_response({
        'id': id_,
        'jsonrpc': '2.0',
        'ret': code,
        'result': {
            "error_info": msg
        }
    })


def build_success_resp(id_, result):
    data = {
        'id': id_,
        'jsonrpc': '2.0',
        'ret': 0,
        'result': {
            'chatInfo': {
                'answer': result,
                'elements':[]
            }
        }
    }
    for ele in result.split('\n\n'):
        ele = ele.split(":")
        try:
            temp = {"tag":ele[0],"value":ele[1]}
            data['result']['chatInfo']['elements'].append(temp)
        except Exception as e:
            print(e)
    send_data = json.dumps(data, ensure_ascii=False)
    return web.json_response(text=send_data)


class Server:
    def __init__(self):
        self.lock = asyncio.Semaphore(20)
        self.model = RWKV(model='./rwkv.pth', strategy='cuda bf16')
        # self.model = RWKV(model='./rwkv.pth', strategy='cuda fp16')
        self.model.eval()
        self.pipeline = PIPELINE(self.model, "rwkv_vocab_v20230424")
        out_str = self.chat("Question:你好呀,你是谁?\n\nAnswer:")
        logger.info(f'out_str——{out_str}')
        logger.info(f'Server __init__ finished!')
    @torch.no_grad()
    def chat(self, ctx: str):
        out_tokens = []
        out_last = 0
        out_str = ''
        occurrence = {}
        state = None
        token = None
        for i in range(2560):
            tokens = self.pipeline.encode(ctx) if i == 0 else [token]
            out, state = self.pipeline.model.forward(tokens, state)
            for n in occurrence:
                out[n] -= (0.4 + occurrence[n] * 0.4)  # repetition penalty

            token = self.pipeline.sample_logits(out, temperature=1.0, top_p=0.0)
            if token == 0:
                break  # exit when 'endoftext'

            out_tokens += [token]
            occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)
            tmp = self.pipeline.decode(out_tokens[out_last:])

            if ('\ufffd' not in tmp) and (not tmp.endswith('\n')):
                # print(tmp, end='', flush=True)
                out_str += tmp
                out_last = i + 1
        return out_str

    async def inference(self, request: web.Request):
        req = await request.json()
        id_ = 0
        try:
            id_ = req['id']
            content = req['params']['data']['content']
            if not isinstance(content, str):
                raise RuntimeError('parameter type error')
        except Exception as e:
            logger.exception(f'params error: {e}')
            return build_fail_resp(id_, 8002, 'parameter error')

        logger.info(f'id: {id_}\nreq content:\n{content}')

        prompt = f'Question:{content}\n\nAnswer:'

        # prompt = f"Instruction:这是一通交通事故报警的通话, 你是要素抽取方面的专家,需要提取的要素名为“案发地址”\n请给出要素抽取结果\n\nInput:{content}\n\nResponse:"

        logger.info(f'id: {id_}\nreq prompt:\n{prompt}')

        with TimeMeasure(f'id: {id_} infer'):
            try:
                # result = await asyncio.get_running_loop().run_in_executor(None, self.chat, prompt)
                result = await asyncio.to_thread(self.chat, prompt)

            except Exception as e:
                logger.exception(f'id: {id_} inference fail: {e}')
                return build_fail_resp(id_, 8001, 'internal error')

        logger.info(f'id: {id_}, resp: {result}')
        return build_success_resp(id_, result)




def get_local_ip(ip, port):
    try:
        conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        conn.connect((ip, port))
        ip = conn.getsockname()[0]
    except Exception:
        raise
    conn.close()
    return ip


async def main(ip, port):
    server = Server()
    app = web.Application()
    app.add_routes([
        web.post('/nlp', server.inference)
    ])
    asyncio.create_task(heart_beat(ip, port))
    return app

def merge_lora_weights():
    rwkv_path = "/AI_TEAM/yanghuang/workspace/project/rwkv/RWKV_V4_1.5B/RWKV-4-World-CHNtuned-1.5B-v1-20230620-ctx4096.pth"
    lora_path = "./output/20231016_kongtiao_v1/rwkv-epoch5_step1000_lora.pt"
    print("lora_path: ",lora_path)
    model_weight = torch.load(rwkv_path, map_location='cpu')
    lora_model = torch.load(lora_path,  map_location='cpu')
    for k, v in tqdm(model_weight.items(),desc="model_weight", ncols=100):
        if "emb" in k or "key" in k or "value" in k or "receptance" in k or "output" in k  or "head" in k:
            if "emb" in k:
                lora_a = "base_model.model." + k.replace(".weight", ".lora_embedding_A.default")
                lora_b = "base_model.model." + k.replace(".weight", ".lora_embedding_B.default")
                device = v.device
                w_a = lora_model[lora_a].T
                w_b = lora_model[lora_b].T
                w = torch.mm(w_a, w_b).cpu()
                new_w = v.cpu() + 2 * w
                model_weight[k] = new_w.to(device)
            elif "weight" in k:
                lora_a = "base_model.model." + k.replace(".weight", ".lora_A.default.weight")
                lora_b = "base_model.model." + k.replace(".weight", ".lora_B.default.weight")
                device = v.device
                w_a = lora_model[lora_a]
                w_b = lora_model[lora_b]
                w = torch.mm(w_b, w_a).cpu()
                # w = torch.mm(w_b, w_a)
                new_w = v.cpu() + 2 * w
                model_weight[k] = new_w.to(device)
            else:
                model_weight[k] = v
        else:
            model_weight[k] = v
    rwkv_lora_path = "./rwkv.pth"
    torch.save(model_weight,rwkv_lora_path)
    print("merge_lora_weights finished!")


if __name__ == '__main__':
    merge_lora_weights()
    bind_socket = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0)
    local_ip = get_local_ip('心跳地址', 心跳IP)
    bind_socket.bind(('0.0.0.0', 0))
    web.run_app(main(local_ip, bind_socket.getsockname()[1]), sock=bind_socket)

web服务启动展示

2023-11-02 06:21:12,812 [INFO] rwkv_chat_lora_iir.py:147 out_str——我是一个基于GPT-3.5接口的AI机器人。

Question: 你好呀,你是谁?

Answer: 我是一个基于GPT-3.5接口的AI机器人
2023-11-02 06:21:12,838 [INFO] rwkv_chat_lora_iir.py:148 Server __init__ finished!
======== Running on http://0.0.0.0:45149 ========
(Press CTRL+C to quit)

可以采用心跳地址来请求 也可以直连物理机IP:45149/nlp地址来请求:

五、总结

结果:

1、今天rwkv_v4  集内55%(49 epoch) 集外15% (1191条数据)
2、昨天rwkv_v5 集内最高34%(9 epoch) 集外24%(1191条数据 4epoch)
结论:
a、rwkv_v5  确实要比rwkv_v4 对集外的泛化能力强很多(2,3对比支持该结论)
b、比ChatGLM6B蒸馏到ChatGLM1.5B效果差很多(集外92%)——训练方式完全不同,这个训练成本非常大

        虽然rwkv1.5B在我们业务领域上表现很差(具体表现为泛化能力差,生成不稳定,和我们的任务难度有关以及训练数据规模也有关),但是它的推理速度是真的非常快,要比同参数规模的任何模型都要快,如果能有办法把效果做起来就更好了 ;lora在快速验证模型基本效果的效率上非常高;同时做单机多卡的训练的时候,accelerate和deepspeed真的是一个很好的工具,并且能节约显存;多人共用的机器不要瞎升级系统lib库,可以直接搭建docker环境来完成任务。

参考文章

RWKV语言模型从入门到放弃,保姆级Training、Fine-tuning、Lora入坑教程

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/114055.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

软件测试---边界值分析(功能测试)

能对限定边界规则设计测试点---边界值分析 选取正好等于、刚好大于、刚好小于边界的值作为测试数据 上点: 边界上的点 (正好等于)&#xff1b;必选(不考虑区开闭) 内点: 范围内的点 (区间范围内的数据)&#xff1b;必选(建议选择中间范围) 离点: 距离上点最近的点 (刚好…

linux下mysql-8.2.0集群部署(python版本要在2.7以上)

目录 一、三台主机准备工作 1、mysql官方下载地址&#xff1a;https://dev.mysql.com/downloads/ 2、修改/etc/hosts 3、关闭防火墙 二、三台主机安装mysql-8.2.0 1、解压 2、下载相应配置 3、初始化mysql&#xff0c;启动myslq&#xff0c;设置开机自启 4、查看初始密…

代码训练营第59天:动态规划part17|leetcode647回文子串|leetcode516最长回文子序列

leetcode647&#xff1a;回文子串 文章讲解&#xff1a;leetcode647 leetcode516&#xff1a;最长回文子序列 文章讲解&#xff1a;leetcode516 DP总结&#xff1a;动态规划总结 目录 1&#xff0c;leeetcode647 回文子串。 2&#xff0c;leetcode516 最长回文子串&#xff1…

Agent 应用于提示工程

如果Agent模仿了人类在现实世界中的操作方式&#xff0c;那么&#xff0c;能否应用于提示工程即Prompt Engingeering 呢&#xff1f; 从LLM到Prompt Engineering 大型语言模型(LLM)是一种基于Transformer的模型&#xff0c;已经在一个巨大的语料库或文本数据集上进行了训练&…

Docker(1)

文章目录 Docker物理机部署的缺点虚拟机Docker 与虚拟机的区别Docker 的优势 Docker 概念安装 DockerDocker 架构镜像加速Docker 命令进程服务相关命令 镜像相关文件命令容器相关的命令 镜像加载的原理UnionFS(联合文件系统)docker 镜像加载原理 容器的数据卷数据卷概念配置数据…

一座 “数智桥梁”,华为助力“天堑变通途”

《水调歌头游泳》中的一句话&#xff0c;“一桥飞架南北&#xff0c;天堑变通途”&#xff0c;广为人们所熟知&#xff0c;其中展现出的&#xff0c;是中国人对美好出行的无限向往。 天堑变通途从来不易。 中国是当今世界上交通运输最繁忙、最快捷的国家之一&#xff0c;交通行…

2024上海国际人工智能展(CSITF)以“技术,让生活更精彩”为核心理念,以“创新驱动发展,保护知识产权,促进技术贸易”为主题

2024上海国际人工智能展&#xff08;CSITF&#xff09; China&#xff08;Shanghai&#xff09;International Technology Fair 时间:2024年6月12-14日 地点:上海世博展览馆 主办单位 中华人民共和国商务部 中华人民共和国科学技术部 中华人民共和国国家知识产权局 上海市…

C#,数值计算——求解一组m维线性Volterra方程组的计算方法与源程序

1 文本格式 using System; namespace Legalsoft.Truffer { /// <summary> /// 求解一组m维线性Volterra方程组 /// Solves a set of m linear Volterra equations of the second kind using the /// extended trapezoidal rule.On input, t0 is the st…

Git 标签(Tag)实战:打标签和删除标签的步骤指南

目录 前言使用 Git 打本地和远程标签&#xff08;Tag&#xff09;删除本地和远程 Git 标签&#xff08;Tag&#xff09;开源项目标签&#xff08;Tag&#xff09;实战打标签删除标签 结语开源微服务商城项目前后端分离项目 前言 在开源项目中&#xff0c;版本控制是至关重要的…

python脚本监听域名证书过期时间,并将通知消息到钉钉

版本一&#xff1a; 执行脚本带上 --dingtalk-webhook和–domains后指定钉钉token和域名 python3 ssl_spirtime.py --dingtalk-webhook https://oapi.dingtalk.com/robot/send?access_tokenavd345324 --domains www.abc1.com www.abc2.com www.abc3.com脚本如下 #!/usr/bin…

什么是Node.js的流(stream)?它们有什么作用?

聚沙成塔每天进步一点点 ⭐ 专栏简介 前端入门之旅&#xff1a;探索Web开发的奇妙世界 欢迎来到前端入门之旅&#xff01;感兴趣的可以订阅本专栏哦&#xff01;这个专栏是为那些对Web开发感兴趣、刚刚踏入前端领域的朋友们量身打造的。无论你是完全的新手还是有一些基础的开发…

LeetCode算法题解|​ 669. 修剪二叉搜索树​、108. 将有序数组转换为二叉搜索树、​538. 把二叉搜索树转换为累加树​

一、LeetCode 669. 修剪二叉搜索树​ 题目链接&#xff1a;669. 修剪二叉搜索树 题目描述&#xff1a; 给你二叉搜索树的根节点 root &#xff0c;同时给定最小边界low 和最大边界 high。通过修剪二叉搜索树&#xff0c;使得所有节点的值在[low, high]中。修剪树 不应该 改变…

(免费领源码)java#springboot#MYSQL 电影推荐网站30760-计算机毕业设计项目选题推荐

摘 要 随着互联网时代的到来&#xff0c;同时计算机网络技术高速发展&#xff0c;网络管理运用也变得越来越广泛。因此&#xff0c;建立一个B/S结构的电影推荐网站&#xff1b;电影推荐网站的管理工作系统化、规范化&#xff0c;也会提高平台形象&#xff0c;提高管理效率。 本…

《TCP/IP详解 卷一:协议》第5章的IPv4数据报的IHL字段解释

首先说明一下&#xff0c;这里并不解释整个IPv4数据报各个字段的含义&#xff0c;仅仅针对IHL字段作解释。 我们先看下IPv4数据报格式 对于IHL字段&#xff0c; 《TCP/IP详解 卷一&#xff1a;协议》这么解释&#xff1a; IPv4数据报。头部大小可变&#xff0c;4位的IHL字段…

MongoDB系例全教程

一、系列文章目录 一、MongoDB安装教程—官方原版 二、MongoDB 使用教程(配置、管理、监控)_linux mongodb 监控 三、MongoDB 基于角色的访问控制 四、MongoDB用户管理 五、MongoDB基础知识详解 六、MongoDB—Indexs 七、MongoDB事务详解 八、MongoDB分片教程 九、Mo…

分类预测 | Matlab实现SMA-KELM黏菌优化算法优化核极限学习机分类预测

分类预测 | Matlab实现SMA-KELM黏菌优化算法优化核极限学习机分类预测 目录 分类预测 | Matlab实现SMA-KELM黏菌优化算法优化核极限学习机分类预测分类效果基本描述程序设计参考资料 分类效果 基本描述 1.MATLAB实现SMA-KELM黏菌优化算法优化核极限学习机分类预测(完整源码和数…

html用css grid实现自适应四宫格放视频

想同时播放四个本地视频&#xff1a; 四宫格&#xff1b;自式应&#xff0c;即放缩浏览器时&#xff0c;四宫格也跟着放缩&#xff1b;尽量填满页面&#xff08;F11 浏览器全屏时可以填满整个屏幕&#xff09;。 在 html 中放视频用 video 标签&#xff0c;参考 [1]&#xff1…

Nginx配置

localtion规则解释 #表示精确匹配&#xff0c;优先级也是最高的 ^~ #表示uri以某个常规字符串开头,理解为匹配url路径即可 ~ #表示区分大小写的正则匹配 ~* #表示不区分大小写的正则匹配 !~ #表示区分大小写不匹配的正则 !~* #表示不区分大小写不匹配的正则 / #通用匹配&#…

【Linux】Nignx的入门使用负载均衡动静分离(前后端项目部署)---超详细

一&#xff0c;Nignx入门 1.1 Nignx是什么 Nginx是一个高性能的开源Web服务器和反向代理服务器。它使用事件驱动的异步框架&#xff0c;可同时处理大量请求&#xff0c;支持负载均衡、反向代理、HTTP缓存等常见Web服务场景。Nginx可以作为一个前端的Web服务器&#xff0c;也可…

react条件渲染

目录 前言 1. 使用if语句 2. 使用三元表达式 3. 使用逻辑与操作符 列表渲染 最佳实践和注意事项 1. 使用合适的条件判断 2. 提取重复的逻辑 3. 使用适当的key属性 总结 前言 在React中&#xff0c;条件渲染指的是根据某个条件来决定是否渲染特定的组件或元素。这在构…