基于 LLamafactory 的异步API高效调用实现与速度对比

文章目录

    • 背景
    • 摘要
    • 简介
    • 代码实现
    • 运行结果
    • 速度对比
      • 异步调用速度
      • 同步调用速度

背景

原先经常调用各家的闭源大模型的API,如果使用同步的方式调用,速度会很慢。为了加快 API 的调用速度,决定使用异步调用 API 的方式。

摘要

通过异步方式调用大语言模型 API的方法,相较于传统同步调用方式,异步调用速度提升了约 9.41 倍。利用 LLamafactory 原生数据集加载和自定义异步工具类 AsyncAPICall 实现批量数据推理,支持调用限速和断点恢复。

简介

本文编写的代码,支持原生的 llamafactory 的数据集导入方式。
推理速度远远快于同步的 API 调用方式。基于 langchain_openai.ChatOpenAI 的 invoke 方法实现异步调用。
下述代码的主要工作介绍如下:

  • 使用 LLamafactory 的原生方法加载 数据集;
  • 封装了异步调用工具类 AsyncAPICall,限制API的调用速度,逐块推理,避免程序崩溃导致所有数据丢失;

代码实现

async_call_api.py

# pip install langchain langchain_openai

import os
import sys
import json
import asyncio


import fire
from tqdm import tqdm
from dataclasses import dataclass
from aiolimiter import AsyncLimiter
from typing import List
import pandas as pd
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv

from llamafactory.hparams import get_train_args
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.data.loader import _get_merged_dataset

load_dotenv()


class AsyncLLM:
    def __init__(
        self,
        model: str = "gpt-3.5-turbo",
        base_url: str = "http://localhost:{}/v1/".format(
            os.environ.get("API_PORT", 8000)
        ),
        api_key: str = "{}".format(os.environ.get("API_KEY", "0")),
        num_per_second: int = 6,
        **kwargs,
    ):
        self.model = model
        self.base_url = base_url
        self.api_key = api_key
        self.num_per_second = num_per_second

        self.limiter = AsyncLimiter(self.num_per_second, 1)

        self.llm = ChatOpenAI(
            model=self.model, base_url=self.base_url, api_key=self.api_key, **kwargs
        )

    async def __call__(self, text):
        # 限速
        async with self.limiter:
            return await self.llm.ainvoke([text])


llm = AsyncLLM(
    base_url="http://localhost:{}/v1/".format(os.environ.get("API_PORT", 8000)),
    api_key="{}".format(os.environ.get("API_KEY", "0")),
    num_per_second=10,
)
llms = [llm]


@dataclass
class AsyncAPICall:
    uid: str = "0"

    @staticmethod
    async def _run_task_with_progress(task, pbar):
        result = await task
        pbar.update(1)
        return result

    @staticmethod
    def async_run(
        llms: List[AsyncLLM],
        data: List[str],
        keyword: str = "",
        output_dir: str = "output",
        chunk_size=500,
    ) -> List[str]:

        async def infer_chunk(llms: List[AsyncLLM], data: List):
            results = [llms[i % len(llms)](text) for i, text in enumerate(data)]
            with tqdm(total=len(results)) as pbar:
                results = await asyncio.gather(
                    *[
                        AsyncAPICall._run_task_with_progress(task, pbar)
                        for task in results
                    ]
                )
            return results

        idx = 0
        all_df = []
        file_exist_skip = False
        user_confirm = False

        while idx < len(data):
            file_path = os.path.join(output_dir, "tmp", f"{idx}.csv.temp")

            if os.path.exists(file_path):
                if not user_confirm:
                    while True:
                        user_response = input(
                            f"Find {file_path} file already exists. Do you want to skip them forever?\ny or Y to skip, n or N to rerun to overwrite: "
                        )
                        if user_response.lower() == "y":
                            user_confirm = True
                            file_exist_skip = True
                            break
                        elif user_response.lower() == "n":
                            user_confirm = True
                            file_exist_skip = False
                            break

                if file_exist_skip:
                    tmp_df = pd.read_csv(file_path)
                    all_df.append(tmp_df)
                    idx += chunk_size
                    continue

            tmp_data = data[idx : idx + chunk_size]
            loop = asyncio.get_event_loop()
            tmp_result = loop.run_until_complete(infer_chunk(llms=llms, data=tmp_data))
            tmp_result = [item.content for item in tmp_result]

            tmp_df = pd.DataFrame({"infer": tmp_result})

            if not os.path.exists(p := os.path.dirname(file_path)):
                os.makedirs(p, exist_ok=True)

            tmp_df.to_csv(file_path, index=False)
            all_df.append(tmp_df)
            idx += chunk_size

        all_df = pd.concat(all_df)
        return all_df["infer"]


def async_api_infer(
    model_name_or_path: str = "",
    eval_dataset: str = "",
    template: str = "",
    dataset_dir: str = "data",
    do_predict: bool = True,
    predict_with_generate: bool = True,
    max_samples: int = None,
    output_dir: str = "output",
    chunk_size=50,
):

    if len(sys.argv) == 1:
        model_args, data_args, training_args, finetuning_args, generating_args = (
            get_train_args(
                dict(
                    model_name_or_path=model_name_or_path,
                    dataset_dir=dataset_dir,
                    eval_dataset=eval_dataset,
                    template=template,
                    output_dir=output_dir,
                    do_predict=True,
                    predict_with_generate=True,
                    max_samples=max_samples,
                )
            )
        )
    else:
        model_args, data_args, training_args, finetuning_args, generating_args = (
            get_train_args()
        )

    dataset = _get_merged_dataset(
        data_args.eval_dataset, model_args, data_args, training_args, "sft"
    )

    labels = [item[0]["content"] for item in dataset["_response"]]
    prompts = [item[0]["content"] for item in dataset["_prompt"]]

    infers = AsyncAPICall.async_run(
        llms,
        prompts,
        chunk_size=chunk_size,
        output_dir=training_args.output_dir,
    )

    if not os.path.exists(training_args.output_dir):
        os.makedirs(training_args.output_dir, exist_ok=True)

    output_prediction_file = os.path.join(
        training_args.output_dir, "generated_predictions.jsonl"
    )

    with open(output_prediction_file, "w", encoding="utf-8") as writer:
        res: List[str] = []
        for text, pred, label in zip(prompts, infers, labels):
            res.append(
                json.dumps(
                    {"prompt": text, "predict": pred, "label": label},
                    ensure_ascii=False,
                )
            )
        writer.write("\n".join(res))


if __name__ == "__main__":
    fire.Fire(async_api_infer)

上述代码支持 yaml 文件:

async.yaml

## model
model_name_or_path: qwen/Qwen2.5-7B-Instruct

### method
do_predict: true

### dataset
dataset_dir: ../data/
eval_dataset: alpaca_zh_demo
template: qwen

### output
output_dir: output
# overwrite_output_dir: true


### eval
predict_with_generate: true
max_samples: 100

运行结果

使用下述代码运行:

python async_call_api.py async.yaml

异步调用的API推理结果:

{"prompt": "识别并解释给定列表中的两个科学理论:细胞理论和日心说。", "predict": "细胞理论和日心说是两个重要的科学理论,分别属于生物学和天文学领域。\n\n1. 细胞理论:细胞理论是生物学中的一个基本理论,它描述了所有生命体的基本结构和功能单位——细胞。细胞理论包含以下三个主要观点:\n\n   a. 所有生物体都是由一个或多个细胞组成的。\n   b. 细胞是生命活动的基本单位,负责维持生物体的生命活动。\n   c. 新细胞只能由已存在的细胞分裂产生。\n\n   细胞理论的提出和发展对于生物学研究具有重要意义,它揭示了生物体的基本结构和功能单位,为生物学研究提供了理论基础。\n\n2. 日心说:日心说是一种关于太阳系结构的理论,它描述了太阳在太阳系中的中心位置,而地球和其他行星则围绕太阳旋转。日心说最初由古希腊天文学家阿里斯塔克提出,但直到16世纪,哥白尼重新提出这一理论,并得到了进一步的发展和完善。日心说的提出颠覆了传统的地心说观念,为现代天文学的发展奠定了基础。\n\n这两个科学理论分别在生物学和天文学领域中具有重要的意义,它们揭示了生命体和宇宙的基本结构和运行规律。", "label": "细胞理论是生物科学的一个理论,它认为所有生命体都是由微小的基本单元——细胞所构成。这是生物学的一个基础理论,认为细胞是所有生物的基本结构和功能单位,所有的生物都是由一个或多个细胞组成,细胞只能通过细胞分裂产生新的细胞。这一理论由薛定谔、施瓦内和雪莱于1839年首次提出。\n\n日心说是指太阳是太阳系的中心,也就是说,行星围绕太阳旋转的理论。这个理论打破了传统的地心说观点,认为地球并不是宇宙的中心。日心说的提出者是尼古拉·哥白尼,他在16世纪初发表了他的著作《天体运行论》,阐述了太阳系行星围绕太阳运行的模型,为天文学的发展做出了巨大贡献。"}
{"prompt": "输入三支篮球队的名称并生成一个适当的口号。\n输入:俄克拉荷马城雷霆队,芝加哥公牛队,布鲁克林网队。", "predict": "\"雷霆激战,公牛荣耀,网破天荒,三队齐心创辉煌!\"", "label": "输出: \"雷霆,公牛和网队:各显神通,角逐群雄!\""}
{"prompt": "比较以下两位著名人物,使用特征列表并以句子形式提供输出。\n人物1:阿尔伯特·爱因斯坦\n人物2:斯蒂芬·霍金\n特征:智力、教育、对世界的影响", "predict": "人物1:阿尔伯特·爱因斯坦\n- 智力:极高,以其相对论理论对物理学产生了深远影响。\n- 教育:获得苏黎世联邦理工学院的学位,并在普林斯顿高等研究院担任教授。\n- 对世界的影响:他的相对论理论改变了我们对时间、空间和重力的理解,对现代科技和宇宙学产生了巨大影响。\n\n人物2:斯蒂芬·霍金\n- 智力:极高,尽管患有肌萎缩性脊髓侧索硬化症,但他在宇宙学和黑洞研究方面做出了重要贡献。\n- 教育:获得牛津大学和剑桥大学的学位,并在剑桥大学担任卢卡斯数学教授。\n- 对世界的影响:他使宇宙学和黑洞研究更加普及,通过《时间简史》等书籍向大众解释复杂的科学概念,激励了无数人对科学的兴趣。", "label": "阿尔伯特·爱因斯坦和斯蒂芬·霍金都是拥有极其出色智力的人物。两人都取得过非常高的教育成就,他们推进了科学发展并在世界范围内产生了深远的影响。爱因斯坦以其相对论和质能关系公式而闻名,而霍金以其关于黑洞和宇宙的发现而著称。两位科学家都以其深厚的学识和非凡的贡献影响了世界。"}

在输出结果中, predict 是大模型的推理结果。方便大家对比 predict 和 label,并评估大模型推理的精度。

请添加图片描述

为了避免大模型中途程序崩溃,把原始数据分块进行推理。这样即使程序中途崩溃,也能基于之前保存的分快数据继续推理,而不用重新开始推理。

速度对比

异步调用速度

下面是两个异步调用的进度条:

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:22<00:00, 2.27it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:22<00:00, 2.22it/s]

上述异步实验总共数据为100条数据,分块大小为50,故有2个进度条。100条速度44秒全部处理完成,平均处理速度 每秒处理2.2条数据。

同步调用速度

同步调用 LLM api 的代码很简单,如下所示:

infers = []
for prompt in tqdm(prompts):
    infers.append(llm.llm.invoke(prompt))

下面是同步调用的进度条:

100%|████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [06:54<00:00, 4.15s/it]

如果使用同步调用,100条数据,总共耗时 6分54秒,平均每条耗时4.15秒。

方法推理100条数据时间
同步6分54秒
异步44秒

对比之下,异步调用比同步调用快了大约 9.41 倍。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/928987.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

ceph的存储池管理

1 查看存储池信息 查看存储池的名称 [rootceph141ceph]# ceph osd pool ls .mgr查看存储池机器编号 [rootceph141ceph]# ceph osd pool ls 1 .mgr查看存储池的详细信息 [rootceph141ceph]# ceph osd pool ls detail pool 1 .mgr replicated size 3 min_size 2 crush_rule 0 ob…

一些常见网络安全术语

1、黑帽 为非法目的进行黑客攻击的人&#xff0c;通常是为了经济利益。他们进入安全网络以销毁&#xff0c;赎回&#xff0c;修改或窃取数据&#xff0c;或使网络无法用于授权用户。这个名字来源于这样一个事实&#xff1a;老式的黑白西部电影中的恶棍很容易被电影观众识别&…

Vue中使用ECharts图表中的阈值标记(附源码)

在数据处理和可视化领域&#xff0c;我们经常需要对一系列数据点进行分析。本文将介绍如何在给定的数据点中找到对应于特定Y值的X值&#xff0c;并设置标线起始点标记在ECharts图表中&#xff0c;效果图如下&#xff1a; 实现步骤 1、数据准备 let seriesData [// 提供日期…

Windows 11 如何配置node.js

一&#xff0c;官网下载 官网首页 下载最新LTS版本&#xff0c;比较稳定&#xff0c;如果想探索更新的版本去探索新的nodejs功能。 1. 下载完成后&#xff0c;双击运行程序&#xff0c;点击next 2. 勾选接受协议&#xff0c;点击next 3. 选择自己的安装路径&#xff08;默认是…

笔记本电脑usb接口没反应怎么办?原因及解决方法

笔记本电脑的USB接口是我们日常使用中非常频繁的一个功能&#xff0c;无论是数据传输、充电还是外接设备&#xff0c;都离不开它。然而&#xff0c;当USB接口突然没有反应时&#xff0c;这无疑会给我们的工作和学习带来不小的困扰。下面&#xff0c;我们就来探讨一下笔记本USB接…

计算机毕业设计hadoop+spark民宿推荐系统 民宿数据分析可视化大屏 民宿爬虫 民宿大数据 知识图谱 机器学习 大数据毕业设计

温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 作者简介&#xff1a;Java领…

视频监控集中管理方案设计:Liveweb视频汇聚方案技术特点与应用

随着科技的发展&#xff0c;视频监控平台在各个领域的应用越来越广泛。然而&#xff0c;当前的视频监控平台仍存在一些问题&#xff0c;如视频质量不高、监控范围有限、智能化程度不够等。这些问题不仅影响了监控效果&#xff0c;也制约了视频监控平台的发展。 为了解决这些问…

SpringBoot中@Import和@ImportResource和@PropertySource

1. Import Import注解是引入java类&#xff1a; 导入Configuration注解的配置类&#xff08;4.2版本之前只可以导入配置类&#xff0c;4.2版本之后也可以导入普通类&#xff09;导入ImportSelector的实现类导入ImportBeanDefinitionRegistrar的实现类 SpringBootApplication…

css栅格系统与多列

栅格系统 栅格系统是媒体查询的具体实现 栅格系统将页面自动分为12个格子&#xff0c;可以根据不同的类前缀实现不同的页面布局 多列

Unity 利用Button 组件辅助Scroll View 滚动

实现 创建枚举类ScrollDir 以区分滚动方向。每组两个按钮负责同方向上左右/上下滚动。 Update 中实时获取Scroll View 滚动条当前位置。 if (dir.Equals(ScrollDir.vertical)) {posCurrent scroll.verticalNormalizedPosition; } else if (dir.Equals(ScrollDir.horizontal)…

WEB安全 PHP学习

PHP基础 PHP编码显示问题 header ("Content-type: text/html; charsetgb2312"); header("Content-Type: text/html;charsetutf-8"); windows需要使用gbk编码显示 源码是 <?php header ("Content-type: text/html; charsetgb2312"); sys…

4. IO Stream

文章目录 一、相对论理解IO流二、汉语文学理解流三、图解IO流四、俩亲爹InputStream和OutputStream五、FileInputStream字节流读取文件六、FileOutputStream字节流写入文件七、buff缓冲复制文件1. 例一(无buff)2. 例二(有buff) 八、buffered字节缓冲流和装饰设计模式九、FileRe…

图数据库 | 12、图数据库架构设计——高性能计算架构

在传统类型的数据库架构设计中&#xff0c;通常不会单独介绍计算架构&#xff0c;一切都围绕存储引擎展开&#xff0c;毕竟存储架构是基础&#xff0c;尤其是在传统的基于磁盘存储的数据库架构设计中。 类似地&#xff0c;在图数据库架构设计中&#xff0c;项目就围绕存储的方…

YOLOv9改进,YOLOv9引入SAConv可切换空洞卷积,二次创新RepNCSPELAN4结构

摘要 作者提出的技术结合了递归特征金字塔和可切换空洞卷积,通过强化多尺度特征学习和自适应的空洞卷积,显著提升了目标检测的效果。 理论介绍 空洞卷积(Atrous Convolution)是一种可以在卷积操作中插入“空洞”来扩大感受野的技术,更有效地捕捉到图像中的大范围上下文…

【热门主题】000075 探索嵌入式硬件设计的奥秘

前言&#xff1a;哈喽&#xff0c;大家好&#xff0c;今天给大家分享一篇文章&#xff01;并提供具体代码帮助大家深入理解&#xff0c;彻底掌握&#xff01;创作不易&#xff0c;如果能帮助到大家或者给大家一些灵感和启发&#xff0c;欢迎收藏关注哦 &#x1f495; 目录 【热…

CanFestival移植到STM32 F4芯片(基于HAL库)

本文讲述如何通过简单操作就可以把CanFestival库移植到STM32 F4芯片上&#xff0c;作为Slave设备。使用启明欣欣的工控板来做实验。 一 硬件连接 观察CAN报文需要专门的设备&#xff0c;本人从某宝上买了一个兼容PCAN的开源小板子&#xff0c;二十几块钱&#xff0c;通过USB接…

洛谷P1827 [USACO3.4] 美国血统 American Heritage(c嘎嘎)

题目链接&#xff1a;P1827 [USACO3.4] 美国血统 American Heritage - 洛谷 | 计算机科学教育新生态 题目难度&#xff1a;普及 首先介绍下二叉树的遍历&#xff1a; 学过数据结构都知道二叉树有三种遍历&#xff1a; 1.前序遍历&#xff1a;根左右 2.中序遍历&#xff1a;左根…

工业—使用Flink处理Kafka中的数据_ProduceRecord2

使用 Flink 消费 Kafka 中 ProduceRecord 主题的数据,统计在已经检验的产品中,各设备每 5 分钟 生产产品总数,将结果存入HBase 中的 gyflinkresult:Produce5minAgg 表, rowkey“

JavaEE-经典多线程样例

文章目录 单例模式设计模式初步引入为何存在单例模式饿汉式单例模式饿汉式缺陷以及是否线程安全懒汉式单例模式基础懒汉式缺陷以及是否线程安全懒汉式单例模式的改进完整代码(变量volatile) 阻塞队列生产者消费者模型生产者消费者模型的案例以及优点请求与响应案例解耦合削峰填…

Web3的技术栈详解:解读区块链、智能合约与分布式存储

随着数字时代的不断发展&#xff0c;Web3作为下一代互联网的核心理念逐渐走进了大众视野。它承载着去中心化、用户主权以及更高效、更安全的网络环境的期望。Web3不再是由少数中心化机构主导的网络&#xff0c;而是通过一系列核心技术的支撑&#xff0c;给每个用户赋予了更多的…