AGI|教你用一部电影的时间训练个人专属Agent

目录

一、Agent如何工作?

二、Function Call 原理

三、开源模型工具调用微调

//Chat模型微调

//训练过程日志

//测试结果

//测试Tools

四、预训练模型微调

五、总结


Agent是一个超越简单文本生成的人工智能系统。它使用大型语言模型(LLM)作为其中央计算引擎,使其能够进行对话、执行任务、推理并显示一定程度的自主权。

一、Agent如何工作?

1、当用户给出一个任务task之后可以从memory中查询记录(可选),查询出的结果(如果有)给AgentLLM进行判断是否可复用,这里指的复用是针对时效性没那么高的任务,例如对过去时的数据“中国19-22年的出生及死亡人口数据”,但如果查询股票数据,天气这种对时效性有很高要求的任务则不适合复用。


2、Agent对任务实现的方式有很多,可以拆解任务、使用lCOT或REACT框架、SOP(Standard Operating Procedure)标准作业规程等等。其目的都是将一个复杂的任务分成n个可在one step内即可完成的子任务。


3、对于子任务,是否需要调用工具,如果无需调用工具则只需要进行一次推理即可;对于需要调用工具的子任务AgentLLM会根据任务描述调用一个或多个工具,根据工具返回结果判断是否可以更改任务状态。待所有的子任务都完成状态变更之后AgentLLM会对结果进行评估反思,判断当前任务是否已经完成。如果某些子任务因为种种原因无法完成,AgentLLM会采取别的方法完成此任务,重复以上步骤直到可以给出结果为止,当然这里的Loop需要设置最大重试次数避免死循环。


4、当AgentLLM判断可以完成任务后可以进行历史任务存储(可选)。长期记忆是将数据存储在数据库中,以便下次查询,短期记忆则保存在内存或缓存中,程序结束时释放。

二、Function Call 原理

在一些任务中我们希望LLM返回我们格式化的数据如json、xml等,function call则需要LLM返回特定的json格式,以OpenAI为例,需要提供工具的描述信息。

from openai import OpenAI
import json

client = OpenAI()



def get_current_weather(location, unit="fahrenheit"):
    """Get the current weather in a given location"""
    if "tokyo" in location.lower():
        return json.dumps({"location": "Tokyo", "temperature": "10", "unit": unit})
    elif "san francisco" in location.lower():
        return json.dumps({"location": "San Francisco", "temperature": "72", "unit": unit})
    elif "paris" in location.lower():
        return json.dumps({"location": "Paris", "temperature": "22", "unit": unit})
    else:
        return json.dumps({"location": location, "temperature": "unknown"})

def run_conversation():
    
    messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}]
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            },
        }
    ]
    response = client.chat.completions.create(
        model="gpt-3.5-turbo-1106",
        messages=messages,
        tools=tools,
        tool_choice="auto",
    )
    response_message = response.choices[0].message
    tool_calls = response_message.tool_calls
    
    if tool_calls:
        
        
        available_functions = {
            "get_current_weather": get_current_weather,
        }
        messages.append(response_message)
        
        for tool_call in tool_calls:
            function_name = tool_call.function.name
            function_to_call = available_functions[function_name]
            function_args = json.loads(tool_call.function.arguments)
            function_response = function_to_call(
                location=function_args.get("location"),
                unit=function_args.get("unit"),
            )
            messages.append(
                {
                    "tool_call_id": tool_call.id,
                    "role": "tool",
                    "name": function_name,
                    "content": function_response,
                }
            )
        second_response = client.chat.completions.create(
            model="gpt-3.5-turbo-1106",
            messages=messages,
        )
        return second_response
print(run_conversation())

在推理结果中可以拿到类似{"name": "get_current_weather", "params": {"location": "北京", "unit": "celsius"}}这样的json数据,这里有需要调用的工具名称以及参数信息,接下来只需要编写代码实现工具调用,将工具返回的结果构造成message加入到与LLM对话的上下文中即可实现工具调用。这里的难点在于对一个开源模型来说,如何根据任务以及提供的工具描述给出正确的工具名称以及正确的参数。

三、开源模型工具调用微调

  • 开源项目地址:LLaMA-Factory
  • 作者知乎最佳实践地址:单卡 3 小时训练专属大模型 Agent:基于 LLaMA Factory 实战

以下为复现实验数据过程记录

//Chat模型微调

模型Yi-6B-Chat硬件信息NVIDIA A100-SXM4-80GB
sft超参

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path /mnt/models/Yi-6B-Chat \
    --dataset glaive_toolcall \
    --template yi \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir yi_agent_checkopint \
    --lora_target all \
    --overwrite_cache \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 5e-4 \
    --num_train_epochs 3 \
    --plot_loss \
    --fp16

export model

python src/export_model.py \
    --model_name_or_path /mnt/models/Yi-6B-Chat \
    --adapter_name_or_path yi_agent_checkopint \
    --template yi \
    --finetuning_type lora \
    --export_dir Yi-Agent-6b-Chat \
    --export_size 2 \
    --export_legacy_format False

web demo


python src/web_demo.py --model_name_or_path Yi-Agent-6b-Chat --template yi

//训练过程日志

{'train_runtime': 7735.6787, 'train_samples_per_second': 3.878, 'train_steps_per_second': 0.242, 'train_loss': 0.3381453339894613, 'epoch': 3.0}
100%|████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [2:08:55<00:00, 4.13s/it]
[INFO|trainer.py:2889] 2024-01-25 13:39:49,599 >> Saving model checkpoint to yi_agent_checkopint
[INFO|tokenization_utils_base.py:2432] 2024-01-25 13:39:49,709 >> tokenizer config file saved in yi_agent_checkopint/tokenizer_config.json
[INFO|tokenization_utils_base.py:2441] 2024-01-25 13:39:49,709 >> Special tokens file saved in yi_agent_checkopint/special_tokens_map.json
***** train metrics *****
  epoch = 3.0
  train_loss = 0.3381
  train_runtime = 2:08:55.67
  train_samples_per_second = 3.878
  train_steps_per_second = 0.242
Figure saved: yi_agent_checkopint/training_loss.png
01/25/2024 13:39:49 - WARNING - llmtuner.extras.ploting - No metric eval_loss to plot.
[INFO|modelcard.py:452] 2024-01-25 13:39:49,848 >> Dropping the following result as it does not have all the necessary fields:
{'task': {'name': 'Causal Language Modeling', 'type': 'text-generation'}}

//测试结果

//测试Tools

[
    {
        "name": "get_province_list",
        "description": "获取省份ID",
        "parameters": {
            "type": "object",
            "properties": {}
        }
    },
    {
        "name": "get_cities_list",
        "description": "根据省份ID查询城市地区ID",
        "parameters": {
            "type": "object",
            "properties": {
                "province_id": {
                    "type": "string",
                    "description": "省份ID,可以通过调用get_province_list获取省份ID"
                }
            },
            "required": [
                "province_id"
            ]
        }
    },
    {
        "name": "get_history_weather",
        "description": "根据城市ID和日期查询历史天气信息,日期支持从2011-01-01开始。注:个别地区个别日期数据记录可能会不存在",
        "parameters": {
            "type": "object",
            "properties": {
                "city_id": {
                    "type": "string",
                    "description": "城市地区ID,可以通过调用get_cities_list获取城市地区ID"
                },
                "weather_date": {
                    "type": "string",
                    "description": "日期,格式:2017-07-15,日期不能大于等于今日日期"
                }
            },
            "required": [
                "city_id",
                "weather_date"
            ]
        }
    },
    {
        "name": "get_river_environment",
        "description": "查询地表水水质",
        "parameters": {
            "type": "object",
            "properties": {
                "page": {
                    "type": "integer",
                    "description": "第几页(默认1)"
                },
                "province": {
                    "type": "string",
                    "description": "省份,例:江苏省"
                },
                "river": {
                    "type": "string",
                    "description": "流域,例:海河流域"
                },
                "section": {
                    "type": "string",
                    "description": "断面名称,例:鼓楼外大街"
                }
            },
            "required": []
        }
    },
    {
        "name": "get_environment_air_pm",
        "description": "查询的城市PM2.5数据",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "城市名称的中文名称或拼音,如:上海 或 shanghai"
                }
            },
            "required": [
                "city"
            ]
        }
    },
    {
        "name": "get_toutiao_news",
        "description": "新闻列表查询",
        "parameters": {
            "type": "object",
            "properties": {
                "type": {
                    "type": "string",
                    "description": "支持类型 top(推荐,默认) guonei(国内) guoji(国际) yule(娱乐) tiyu(体育) junshi(军事) keji(科技) caijing(财经) youxi(游戏) qiche(汽车) jiankang(健康)"
                },
                "page": {
                    "type": "string",
                    "description": "当前页数, 默认1, 最大50"
                },
                "page_size": {
                    "type": "string",
                    "description": "每页返回条数, 默认30 , 最大30"
                },
                "is_filter": {
                    "type": "string",
                    "description": "是否只返回有内容详情的新闻, 1:是, 默认0"
                }
            },
            "required": []
        }
    },
    {
        "name": "chejian_query",
        "description": "根据车辆注册日期及类型,计算车辆的下次上线检验时间。本计算结果仅供参考。",
        "parameters": {
            "type": "object",
            "properties": {
                "type": {
                    "type": "string",
                    "description": "车辆类型, 3:9座(含)以下非营运小微型载客汽车(面包车除外) 4:摩托车 7:非营运大型轿车 1:营运车辆 2:货车、大中型客车 6:面包车 5:其他机动车"
                },
                "reg_date": {
                    "type": "string",
                    "description": "注册登记日期,格式:2022-11-02"
                },
                "iis_sg": {
                    "type": "integer",
                    "description": "事故情况(是否发生过致人伤亡事故或存在非法改装被依法处罚的交通违法),如是传1"
                }
            },
            "required": [
                "type",
                "reg_date"
            ]
        }
    },
    {
        "name": "loan_calc_query",
        "description": "公积金贷款计算器用于计算用户在申请公积金贷款时,选择等额本金和等额本息两种不同的还款方式后,每一期需偿还公积金贷款的月供,以及利息总额和还款总额。",
        "parameters": {
            "type": "object",
            "properties": {
                "money": {
                    "type": "integer",
                    "description": "贷款金额(0 < money <= 500),单位(万),如70表示70万;"
                },
                "year": {
                    "type": "integer",
                    "description": "贷款年限,单位(年),仅限输入 5、10、15、20、25、30"
                },
                "active": {
                    "type": "string",
                    "description": "贷款利率,默认3.25"
                }
            },
            "required": [
                "money",
                "year"
            ]
        }
    },
    {
        "name": "icp_query",
        "description": "网站icp备案查询",
        "parameters": {
            "type": "object",
            "properties": {
                "domainName": {
                    "type": "string",
                    "description": "获取的域名,如:juhe.cn"
                }
            },
            "required": [
                "domainName"
            ]
        }
    },
    {
        "name": "airport_query",
        "description": "获取全球机场三字码",
        "parameters": {
            "type": "object",
            "properties": {
                "airport": {
                    "type": "string",
                    "description": "关键词(可匹配城市机场的中英文名称、机场三字码)"
                },
                "page": {
                    "type": "integer",
                    "description": "页码(默认为1)"
                },
                "per_page": {
                    "type": "integer",
                    "description": "每页显示数量(默认为20,最大为100)"
                }
            },
            "required": [
                "airport"
            ]
        }
    },
    {
        "name": "aptabnormal_query",
        "description": "根据机场三字码查询国内机场不正常航班列表",
        "parameters": {
            "type": "object",
            "properties": {
                "airport": {
                    "type": "string",
                    "description": "机场三字码,字母大写(如:PEK),可通过airport_query获取三字码"
                }
            },
            "required": [
                "airport"
            ]
        }
    }
]

测试问题

请参考工具调用能力测试中的场景列(https://www.yuque.com/mrbun/sgr5h5/hsnz17g1a1wr6k2t#KmgD)

四、预训练模型微调

硬件信息NVIDIA-4090 24G 单卡
sft超参

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path /data/models/Yi-6B \
    --dataset glaive_toolcall,alpaca_gpt4_en,alpaca_gpt4_zh,oaast_sft_zh \
    --max_samples 8000 \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir yi_agent_checkopint \
    --lora_target all \
    --overwrite_cache \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 5e-5 \
    --num_train_epochs 2 \
    --plot_loss \
    --fp16 \
    --flash_attn

export model

python src/export_model.py \
    --model_name_or_path /data/models/Yi-6B \
    --adapter_name_or_path /data/projects/LLaMA-Factory/yi_agent_checkopint \
    --template default \
    --finetuning_type lora \
    --export_dir Yi-Agent-6B-Chat \
    --export_size 2 \
    --export_legacy_format False

web demo

python src/web_demo.py \
    --model_name_or_path Yi-Agent-6B-Chat \
    --template default

测试结果不再赘述。

五、总结

通过SFT微调后可以让原本不具备工具调用能力的模型实现工具调用。通过测试结果可以看出对于复杂场景的效果不是很好,单工具的场景正确率很高,测试的场景是中文场景,训练集中是英文,泛化效果也很不错,我正在准备以下类型数据集,如果有类似的数据集可以在下面贴出连接。

  • API参数描述中需要调用另外一个接口拿到的场景,例如天气查询中的城市id需要调用获取城市idAPI拿到。
  • 对于问题中参数信息不完整,主动抛出问题获取更详细参数信息的场景。
  • 多工具场景。

模型已发布到modelscope Yi-Agent-6B-Chat

作者:徐辉| 后端开发工程师

更多AI小知识欢迎关注“神州数码云基地”公众号,回复“AI与数字化转型”进入社群交流

版权声明:文章由神州数码武汉云基地团队实践整理输出,转载请注明出处。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/412602.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

CS_上线三层跨网段机器(完整过程还原)

以前讲过用cs_smb_beacon上线不出网机器&#xff0c;但是真实的网络拓扑肯定不止这么一层的网络&#xff01; 所以我就来搭建一个复杂一点的网络环境&#xff01;&#xff01; 当然了&#xff0c;这三台电脑之间都是不同的网段&#xff0c;&#xff08;但是同属于一个域环境&a…

MySQL数据库进阶第五篇(锁)

文章目录 一、锁的概述二、全局锁三、表级锁四、元数据锁&#xff08;meta data lock, MDL&#xff09;五、意向锁六、行级锁七、行锁&#xff08;Record Lock&#xff09;八、间隙锁&#xff08;Gap Lock&#xff09;九、临键锁&#xff08;Next-Key Lock&#xff09;十、锁总…

遇见数字孪生城市美好模样 国产GIS加速创新竞逐“新赛道”

在未来数字孪生城市中&#xff0c;每一座建筑、每一条道路、每一盏路灯、每一辆车&#xff0c;甚至每一个人&#xff0c;都有其对应的数字孪生体。这些数字孪生体在虚拟世界中形成了城市的细致模型&#xff0c;如同镜子一样&#xff0c;反映出城市的实时状态&#xff0c;同时也…

Datawhale-Sora技术原理分享

目录 Sora能力边界探索 Sora模型训练流程 Sora关键技术拆解 物理引擎的数据进行训练 个人思考与总结 参考 https://datawhaler.feishu.cn/file/KntHbV3QGoEPruxEql2c9lrsnOb

HTB pwn Dragon Army

逆向分析 程序使用了alloca函数扩大了栈区 此处可以泄露libc的地址 程序主要功能在下面 while ( 1 ){while ( 1 ){fflush(stdin);fflush(_bss_start);fprintf(_bss_start, "\n%sDragons: [%d/%d]%s\n\n", "\x1B[1;34m", v5, 13LL, "\x1B[1;37m"…

山海鲸可视化:数据分析师的明智之选

作为一名资深的数据分析师&#xff0c;我深知数据可视化在数据分析和决策制定中的重要性。在众多的数据可视化工具中&#xff0c;我最终选择了山海鲸可视化软件。本文就简单介绍几点我选择它的主要原因&#xff1a; 首先&#xff0c;山海鲸可视化软件提供了强大的数据整合能力…

MySQL-事务,properties文件解析,连接池

1.事务机制管理 1.1 Transaction事务机制管理 默认情况下是执行一条sql语句就保存一次&#xff0c;那么比如我们需要三条数据同时成功或同时失败就需要开启事务机制了。开启事务机制后执行过程中发生问题就会回滚到操作之前&#xff0c;相当于没有执行操作。 1.2 事务的特征 事…

sql-labs25-28a

一、环境 网上都有不过多阐述 二、sql-labs第25关 它说你的OR和and属于它,那就是过滤了OR和and 注入尝试 不用or和and进行爆破注入,很明显是有注入点的 ?id-1 union select 1,2,3-- 查看数据库 ok&#xff0c;此道题算是解了但是如果我们用了and了呢 ?id-1 and updatex…

消息中间件篇之Kafka-高性能设计

一、高性能设计 消息分区&#xff1a;不受单台服务器的限制&#xff0c;可以不受限的处理更多的数据。 顺序读写&#xff1a;磁盘顺序读写&#xff0c;提升读写效率。 页缓存&#xff1a;把磁盘中的数据缓存到内存中&#xff0c;把对磁盘的访问变为对内存的访问。 零拷贝&a…

C# 水排序 微信小游戏

来只 水排序谜题启发式搜索方法_水排序解法小程序-CSDN博客 大神的C语言转换成C# 语言&#xff0c;更多的请看原作者&#xff0c;这里直接贴C#代码 using System; using System.Collections.Generic; using System.Linq; using System.Text;namespace ConsoleApp2 {class Pro…

SImpAl

output matrix M&#xff0c;Curriculum using w ( x t , f , h ) w(x_t, f, h) w(xt​,f,h) 辅助信息 作者未提供代码

想露一手Linux命令,掌握这20个够用了!

中午好&#xff0c;我的网工朋友。 想做好网工&#xff0c;除了路由交换&#xff0c;掌握Linux操作系统也是很重要的。 因为很多服务器上都是用的Linux系统&#xff0c;要和服务器机器交互&#xff0c;就要通过Linux的相关命令。 正在找工作的朋友&#xff0c;如果面试时&am…

thinkphp6定时任务

这里主要是教没有用过定时任务没有头绪的朋友, 定时任务可以处理一些定时备份数据库等一系列操作, 具体根据自己的业务逻辑进行更改 直接上代码 首先, 是先在 tp 中的 command 方法中声明, 如果没有就自己新建一个, 代码如下 然后就是写你的业务逻辑 执行定时任务 方法写好了…

leetcode 2867. 统计树中的合法路径数目【筛质数+贡献法】

原题链接&#xff1a;2867. 统计树中的合法路径数目 题目描述&#xff1a; 给你一棵 n 个节点的无向树&#xff0c;节点编号为 1 到 n 。给你一个整数 n 和一个长度为 n - 1 的二维整数数组 edges &#xff0c;其中 edges[i] [ui, vi] 表示节点 ui 和 vi 在树中有一条边。 …

openai sora 只能根据文本生成视频?不,TA 是通用物理世界模拟器

视频生成模型作为世界模拟器 我们探索了在视频数据上进行大规模生成模型的训练。 具体来说&#xff0c;我们联合在可变持续时间、分辨率和长宽比的视频和图像上训练文本条件扩散模型。 我们利用了一个在视频和图像潜在编码的时空补丁上操作的变压器架构。 我们最大的模型So…

请求包的大小会影响Redis每秒处理请求数量

文章目录 &#x1f50a;博主介绍&#x1f964;本文内容压测规划客户端长连接数量对性能的影响请求包大小的影响Pipleline模式对Redis的影响 &#x1f4e2;文章总结&#x1f4e5;博主目标 &#x1f50a;博主介绍 &#x1f31f;我是廖志伟&#xff0c;一名Java开发工程师、Java领…

CG-0A 电子水尺可实现对水位数据的连续自动监测

CG-0A 电子水尺可实现对水位数据的连续自动监测产品概述 本产品是一种采用微处理器芯片为控制器&#xff0c;内置通讯电路的数字式水位传感器&#xff0c;具备高的可靠性及抗干扰性能。适用于江、河、湖、水库及蓄水池、水渠等处的水位测量使用。 本产品采用了生产工艺技术&…

springboot集成kafka快速入门demo

一、kafka介绍 Kafka是一种基于分布式发布-订阅消息系统的开源软件。 其目标是提供高吞吐量、低延迟、可扩展性和容错能力。 Kafka中将消息存储在可配置数量的分区中&#xff0c;以便实现横向扩展&#xff0c;并且支持多个生产者和消费者&#xff0c;具有良好的可靠性保证机制。…

【精选】Java面向对象进阶——静态内部类和局部内部类

&#x1f36c; 博主介绍&#x1f468;‍&#x1f393; 博主介绍&#xff1a;大家好&#xff0c;我是 hacker-routing &#xff0c;很高兴认识大家~ ✨主攻领域&#xff1a;【渗透领域】【应急响应】 【Java】 【VulnHub靶场复现】【面试分析】 &#x1f389;点赞➕评论➕收藏 …

SpringCloud有哪些组件

什么是SpringCloud&#xff1f; Spring Cloud是基于Spring Boot的分布式系统开发工具&#xff0c;它提供了一系列开箱即用的、针对分布式系统开发的特性和组件&#xff0c;用于帮助开发人员快速构建和管理云原生应用程序。 Spring Cloud的主要目标是解决分布式系统中的常见问题…