ToolLLM: Facilitating Large Language Models to Master 16000+ Real-world APIs


1. Motivation

Although today's LLMs show remarkable capabilities, they still fall short in understanding and using tools, i.e., invoking external APIs according to a user's instruction and intent. This is because existing instruction-tuning data are mostly built from plain language tasks and largely ignore tool-calling instructions.

Although open-source LLMs, e.g., LLaMA (Touvron et al., 2023a), have achieved versatile capabilities through instruction tuning (Taori et al., 2023; Chiang et al., 2023), they still lack the sophistication in performing higher-level tasks, such as appropriately interacting with tools (APIs) to fulfill complex human instructions.

This paper therefore sets out to give open-source LLMs the ability to use tools, and proposes ToolLLM.
Some prior work already improves tool use through instruction tuning, but several limitations remain:

  • Limited APIs: existing datasets lack the rich variety of APIs found in real scenarios;
  • Narrow task scenarios: most involve calling a single tool, whereas real scenarios may require several tools working together;
  • Weak planning and reasoning: existing work relies on CoT (Wei et al., 2023) or ReACT (Yao et al., 2022) for reasoning, which cannot fully elicit the capabilities stored in LLMs and therefore fails on complex instructions. Moreover, some works do not even execute the APIs to obtain real responses (Patil et al., 2023; Tang et al., 2023), even though real responses are important information for subsequent planning.

2. Method

The overall pipeline is shown below:
[Figure: overall ToolLLM pipeline — API collection, instruction generation, solution path annotation, training, and inference]
Data construction and model training:

  • First, a large number of APIs are collected from the RapidAPI Hub, then sampled and processed;
  • For each sampled API combination, ChatGPT is called to generate instructions, yielding (instruction, relevant APIs) samples, which are used to train the API Retriever;
  • Based on (instruction, relevant APIs), ChatGPT further annotates how these APIs are combined and in what order they are called, producing (instruction, solution path) samples; all of this together forms ToolBench, which is used to train the ToolLLaMA model.

Inference:

  • Given an instruction, first use the API Retriever to retrieve all potentially relevant APIs;
  • Then interact with RapidAPI over multiple rounds using these APIs to obtain call results, and finally derive the answer.
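A minimal sketch of this retrieve-then-act loop (the retriever, model, and RapidAPI wrapper here are hypothetical stand-ins; the real implementation is walked through in Section 3):

def answer_instruction(instruction, retriever, toolllama, call_rapidapi, max_rounds=10):
    # Hypothetical components: `retriever`, `toolllama`, and `call_rapidapi`
    # stand in for the API Retriever, the trained model, and the RapidAPI
    # service described above.
    apis = retriever.retrieve(instruction, top_k=5)          # candidate APIs
    history = [{"role": "user", "content": instruction}]
    for _ in range(max_rounds):
        step = toolllama.generate(history, functions=apis)   # thought + API + parameters
        if step["action"] == "Finish":
            return step["final_answer"]
        observation = call_rapidapi(step["action"], step["parameters"])  # real response
        history.append({"role": "function", "name": step["action"],
                        "content": observation})
    return None  # budget exhausted without an answer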

2.1 Data Construction

The paper first introduces ToolBench, which contains 16,464 APIs crawled from the RapidAPI Hub and then processed. Its statistics, compared with other benchmarks, are shown below:
[Table: ToolBench statistics compared with other benchmarks]
ToolBench is constructed in three steps:

  • API collection
  • Instruction generation
  • Solution path annotation

(1) API collection
The APIs come mainly from the RapidAPI Hub, which provides 49 coarse-grained categories (Categories):
[Figure: RapidAPI categories]
as well as more than 500 fine-grained collections (Collections):
[Figure: RapidAPI collections]
Each tool contains several APIs. For every tool, its name, description, URL, and all of its APIs are recorded; for every API, its name, description, HTTP method, required parameters, request body, an executable code snippet, and an example response are also collected. This metadata helps the LLM understand and use the APIs.
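For illustration, one API's metadata record might look roughly like this (the field names are illustrative, not the repo's exact schema):

api_meta = {
    "tool_name": "WeatherAPI",                     # hypothetical tool
    "tool_description": "Real-time weather data for any city.",
    "api_name": "current_weather",
    "api_description": "Return the current weather for a given city.",
    "method": "GET",
    "url": "https://weatherapi.example.p.rapidapi.com/current",
    "required_parameters": [{"name": "city", "type": "STRING"}],
    "optional_parameters": [{"name": "units", "type": "ENUM", "default": "metric"}],
    "code": "import requests\n...",                # executable snippet
    "example_response": {"temp_c": 21.5, "condition": "Sunny"},
}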
The RapidAPI hierarchy is illustrated below:
[Figure: RapidAPI hierarchy — categories/collections → tools → APIs]
(2) Instruction generation
ToolBench is built so that open-source LLMs can later be instruction-tuned to understand tools, so two aspects must be considered:

  • diversity: the instructions must cover varied tool-use scenarios so the model can generalize across tools;
  • multi-tool usage: matching real scenarios, a single query may require several tools working together.

First, APIs are randomly sampled from the full API set. Then a prompt is designed so that ChatGPT understands these APIs and generates two things:

  • instruction: an instruction that involves the sampled APIs;
  • relevant APIs: the subset of the sampled APIs that the instruction actually requires;

The prompt given to ChatGPT contains three parts:

  • description: a task description of the instruction-generation task;
  • document: the documentation of each sampled API;
  • ICL exemplars: from a pool of hand-written single-tool and multi-tool examples, 3 are sampled as in-context exemplars;

Instruction generation can thus be formalized: given a sampled API subset $\mathcal{S}^{\text{sub}}_N = \{\text{API}_1, \dots, \text{API}_N\}$, ChatGPT outputs instructions paired with their relevant APIs,

$$\{[\mathcal{S}^{\text{rel}}_1, \text{Inst}_1], \dots, [\mathcal{S}^{\text{rel}}_{N'}, \text{Inst}_{N'}]\}, \qquad \mathcal{S}^{\text{rel}}_* \subseteq \mathcal{S}^{\text{sub}}_N .$$
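A rough sketch of how such a generation prompt could be assembled and sent to ChatGPT (the prompt wording and input structures are assumptions; the call uses the legacy openai 0.x client and assumes openai.api_key is set):

import json
import random
import openai  # legacy 0.x client

# Hypothetical inputs: a task description, sampled API documents, and a pool of
# hand-written seed examples -- placeholders, not the paper's exact prompt.
TASK_DESCRIPTION = "Generate diverse user instructions that require the APIs below."
SEED_EXAMPLES = ["example 1 ...", "example 2 ...", "example 3 ...", "example 4 ..."]
sampled_apis = [{"tool_name": "WeatherAPI", "api_name": "current_weather",
                 "api_description": "Return the current weather for a given city."}]

def build_generation_prompt(task_description, apis, seed_examples):
    # description + API documents + 3 in-context exemplars, as listed above
    docs = "\n".join(json.dumps(api, ensure_ascii=False) for api in apis)
    exemplars = "\n\n".join(random.sample(seed_examples, 3))
    return f"{task_description}\n\nAPI documents:\n{docs}\n\nExamples:\n{exemplars}"

resp = openai.ChatCompletion.create(
    model="gpt-3.5-turbo-16k",
    messages=[{"role": "user", "content": build_generation_prompt(
        TASK_DESCRIPTION, sampled_apis, SEED_EXAMPLES)}],
)
print(resp["choices"][0]["message"]["content"])  # instructions + relevant APIs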

Tool sampling strategy:
single-tool: each tool contains several APIs, and all APIs of one tool directly form a combination; the resulting instructions are called single-tool instructions (I1).
multi-tool: tools cannot be combined arbitrarily, because some tools are completely unrelated. The RapidAPI hierarchy is therefore exploited: 2~5 tools belonging to the same Category or Collection are randomly combined, with at most 3 APIs picked per tool.

  • if the sampling space is a shared Category, the generated instructions are called intra-category multi-tool instructions (I2);
  • if the sampling space is a shared Collection, they are called intra-collection multi-tool instructions (I3).

These different sampling strategies give the generated instructions a degree of diversity.
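A sketch of the sampling strategies (assuming tools are plain dicts with an "apis" list, grouped by Category or Collection; this is not the repo's exact code):

import random

def sample_single_tool(tool):
    # I1: all APIs of a single tool form one combination
    return tool["apis"]

def sample_multi_tool(group_tools):
    # I2 / I3: 2~5 tools drawn from the same Category (I2) or Collection (I3),
    # keeping at most 3 APIs per tool
    k = random.randint(2, 5)
    tools = random.sample(group_tools, min(k, len(group_tools)))
    apis = []
    for t in tools:
        apis.extend(random.sample(t["apis"], min(3, len(t["apis"]))))
    return apis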
In total, nearly 200k (instruction, relevant APIs) samples were generated: 87,413, 84,815, and 25,251 instances for I1, I2, and I3, respectively.
Because these (instruction, relevant APIs) pairs were produced by ChatGPT after genuinely understanding the tools, they can be used to train the API Retriever.
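These pairs can train a dense retriever; below is a minimal sketch with sentence-transformers (the paper fine-tunes a Sentence-BERT-style retriever, but the base model, loss, and hyperparameters here are assumptions):

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

# Placeholder positive pairs; in practice they come from the ~200k
# (instruction, relevant API document) samples generated above.
pairs = [("I need today's weather in Paris and a motivational quote",
          "WeatherAPI current_weather: Return the current weather for a given city.")]

train_examples = [InputExample(texts=[inst, api_doc]) for inst, api_doc in pairs]
model = SentenceTransformer("bert-base-uncased")   # base encoder is an assumption
loader = DataLoader(train_examples, shuffle=True, batch_size=32)
loss = losses.MultipleNegativesRankingLoss(model)  # in-batch negatives
model.fit(train_objectives=[(loader, loss)], epochs=1)
model.save("api_retriever/")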

given an instruction, the API retriever recommends a set of relevant APIs, which are sent to ToolLLaMA for multi-round decision making to derive the final answer. Despite sifting through a large pool of APIs, the retriever exhibits remarkable retrieval precision, returning APIs closely aligned with the ground truth.

(3) Solution path annotation
Instructions and relevant APIs alone are far from enough: especially in the multi-tool setting, how the APIs are combined and in what order they are called is crucial. ChatGPT is therefore used again to annotate the combination and ordering of the API calls.

Given an instruction and several relevant APIs (with their names, descriptions, calling conventions, and other metadata), ChatGPT must produce a correct action sequence, carried out as a multi-round conversation.
In each round, ChatGPT picks the next API based on the previous round's API result; each action first states its reasoning, in the form "Thought: · · · , API Name: · · · , Parameters: · · · ".
In addition, two terminal functions are defined to end an annotation episode:

  • Finish with Final Answer: the API calls produced a substantive result, indicating success;
  • Finish by Giving Up: the API calls still fail after multiple attempts, or the instruction's requirements cannot be fulfilled.
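In the OpenAI function-call format, this terminal action can be sketched as a single "Finish" function with two return types (an approximation of the repo's definition; the exact field names may differ):

FINISH_FUNCTION = {
    "name": "Finish",
    "description": "End the task, either returning the final answer or giving up.",
    "parameters": {
        "type": "object",
        "properties": {
            "return_type": {"type": "string",
                            "enum": ["give_answer", "give_up_and_restart"]},
            "final_answer": {"type": "string",
                             "description": "required when return_type is give_answer"},
        },
        "required": ["return_type"],
    },
}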

On top of this setup, the paper proposes a decision-tree algorithm based on depth-first search (DFSDT), illustrated below:
[Figure: DFSDT compared with CoT/ReACT]
Compared with breadth-first search, depth-first search finds a qualifying path quickly, saving ChatGPT calls; a qualifying path is one that starts from the instruction and completes all the required API calls. (The full search implementation is walked through in Section 3.)

We perform DFSDT for all the generated instructions and only retain those passed solution paths.

In the end, 126,486 (instruction, solution path) pairs were obtained for training LLaMA.

2.2 ToolEval

The paper also proposes an evaluation framework for tool use, with two metrics:

  • Pass Rate: the proportion of instructions successfully completed within a limited budget. It measures whether an LLM can execute an instruction at all, and can be seen as a basic requirement for ideal tool use;
  • Win Rate: for each instruction, two solution paths are provided and ChatGPT judges which one is better. Typically one path is the tool-use path generated by the baseline model under test, while the other is a human-written path or the ChatGPT-annotated ground truth.

Through rigorous testing (details in appendix A.5), we find that ToolEval demonstrates a high agreement of 87.1% in pass rate and 80.3% in win rate with human annotators. This shows that ToolEval can reflect and represent human evaluation to a large extent.
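Both metrics reduce to simple counting over judge outputs; a minimal sketch (the judging itself is done by ChatGPT, so the label shapes below are assumptions):

def pass_rate(results):
    # results: one boolean per instruction, True if it was completed
    # within the limited budget
    return sum(results) / len(results)

def win_rate(judgements):
    # judgements: "win" / "lose" / "tie" labels from the ChatGPT judge,
    # comparing the candidate path against the reference path
    return sum(1 for j in judgements if j == "win") / len(judgements)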

2.3 ToolLLaMA

The 126,486 (instruction, solution path) pairs are used to train a LLaMA-2-7B model. At evaluation time, tool generalization is measured separately on I1, I2, and I3.
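Each annotated solution path is flattened into a multi-round conversation for supervised fine-tuning; a sketch of that conversion (message roles follow the OpenAI chat format, while the step field names are assumptions):

def path_to_messages(instruction, solution_path):
    # Flatten one annotated solution path into chat messages for SFT.
    messages = [{"role": "user", "content": instruction}]
    for step in solution_path:
        messages.append({
            "role": "assistant",
            "content": step["thought"],                    # "Thought: ..."
            "function_call": {"name": step["api_name"],
                              "arguments": step["parameters"]},
        })
        messages.append({"role": "function",
                         "name": step["api_name"],
                         "content": step["response"]})     # real API response
    return messages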
The experimental results are shown below:
[Table: main experimental results on ToolBench]

3. Reproduction

Assuming a trained ToolLLaMA model is available, the code below walks through how the model uses tools.
Given an instruction:

[
    {
        "query": "I'm planning a surprise party for my best friend, and I want to include meaningful quotes in the decorations. Can you provide me with random love, success, and motivation quotes? It would be great to have quotes that can celebrate love, success, and inspire everyone at the party. Thank you so much for your help!",
        "query_id": 82217
    }
]

Run the following code to perform tool use:

'''
Open-domain QA Pipeline
'''
import argparse
from toolbench.inference.Downstream_tasks.rapidapi import pipeline_runner


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--corpus_tsv_path', type=str, default="your_retrieval_corpus_path/", required=False, help='')
    parser.add_argument('--retrieval_model_path', type=str, default="your_model_path/", required=False, help='')
    parser.add_argument('--retrieved_api_nums', type=int, default=5, required=False, help='')
    parser.add_argument('--backbone_model', type=str, default="toolllama", required=False, help='chatgpt_function or davinci or toolllama')
    parser.add_argument('--openai_key', type=str, default="", required=False, help='openai key for chatgpt_function or davinci model')
    parser.add_argument('--model_path', type=str, default="your_model_path/", required=False, help='')
    parser.add_argument('--tool_root_dir', type=str, default="your_tools_path/", required=True, help='')
    parser.add_argument("--lora", action="store_true", help="Load lora model or not.")
    parser.add_argument('--lora_path', type=str, default="your_lora_path if lora", required=False, help='')
    parser.add_argument('--max_observation_length', type=int, default=1024, required=False, help='maximum observation length')
    parser.add_argument('--max_source_sequence_length', type=int, default=4096, required=False, help='original maximum model sequence length')
    parser.add_argument('--max_sequence_length', type=int, default=8192, required=False, help='maximum model sequence length')
    parser.add_argument('--observ_compress_method', type=str, default="truncate", choices=["truncate", "filter", "random"], required=False, help='maximum observation length')
    parser.add_argument('--method', type=str, default="CoT@1", required=False, help='method for answer generation: CoT@n,Reflexion@n,BFS,DFS,UCT_vote')
    parser.add_argument('--input_query_file', type=str, default="", required=False, help='input path')
    parser.add_argument('--output_answer_file', type=str, default="",required=False, help='output path')
    parser.add_argument('--toolbench_key', type=str, default="",required=False, help='your toolbench key to request rapidapi service')
    parser.add_argument('--rapidapi_key', type=str, default="",required=False, help='your rapidapi key to request rapidapi service')
    parser.add_argument('--use_rapidapi_key', action="store_true", help="To use customized rapidapi service or not.")
    parser.add_argument('--api_customization', action="store_true", help="To use customized api or not. NOT SUPPORTED currently under open domain setting.")
    
    args = parser.parse_args()

    # Run the pipeline; since only instructions are given, set add_retrieval=True to retrieve APIs first
    pipeline_runner = pipeline_runner(args, add_retrieval=True)
    pipeline_runner.run()

  1. Read the instruction data and build the task list
def generate_task_list(self):
    args = self.args
    query_dir = args.input_query_file # file containing the instructions to run
    answer_dir = args.output_answer_file
    if not os.path.exists(answer_dir):
        os.mkdir(answer_dir)
    method = args.method # e.g., DFS_woFilter_w2: use DFS to search for an API-call solution path
    backbone_model = self.get_backbone_model() # e.g., the trained ToolLLaMA model
    white_list = get_white_list(args.tool_root_dir) # load the whitelist of available tools
    task_list = []
    querys = json.load(open(query_dir, "r")) # load all instructions to evaluate
    for query_id, data_dict in enumerate(querys):
        # iterate over each instruction
        if "query_id" in data_dict:
            query_id = data_dict["query_id"]
        if "api_list" in data_dict:
            origin_tool_names = [standardize(cont["tool_name"]) for cont in data_dict["api_list"]]
            tool_des = contain(origin_tool_names,white_list)
            if tool_des == False:
                continue
            tool_des = [[cont["standard_tool_name"], cont["description"]] for cont in tool_des]
        else:
            tool_des = None
        # each instruction becomes one task, carrying the inference method, backbone model, etc.
        task_list.append((method, backbone_model, query_id, data_dict, args, answer_dir, tool_des))
    return task_list
  2. Inference
def run(self):
    # one task per instruction
    task_list = self.task_list
    random.seed(42)
    random.shuffle(task_list)
    print(f"total tasks: {len(task_list)}")
    new_task_list = []
    for task in task_list:
        out_dir_path = task[-2]
        query_id = task[2]
        output_file_path = os.path.join(out_dir_path,f"{query_id}_{self.args.method}.json")
        if not os.path.exists(output_file_path):
            new_task_list.append(task)
    task_list = new_task_list
    print(f"undo tasks: {len(task_list)}")
    if self.add_retrieval:
        # with only instructions given, call the retriever to fetch candidate APIs
        retriever = self.get_retriever()
    else:
        retriever = None
    for k, task in enumerate(task_list):
        print(f"process[{self.process_id}] doing task {k}/{len(task_list)}: real_task_id_{task[2]}")
        # run the API-call inference
        result = self.run_single_task(*task, retriever=retriever, process_id=self.process_id)
  3. Tool Retriever: dense vector retrieval that returns the top-K tools for an instruction.
import time
import pandas as pd
from sentence_transformers import SentenceTransformer, util
import json
import re
from toolbench.utils import standardize, standardize_category, change_name, process_retrieval_ducoment


class ToolRetriever:
    def __init__(self, corpus_tsv_path = "", model_path=""):
        self.corpus_tsv_path = corpus_tsv_path
        self.model_path = model_path
        self.corpus, self.corpus2tool = self.build_retrieval_corpus()
        self.embedder = self.build_retrieval_embedder()
        self.corpus_embeddings = self.build_corpus_embeddings()
        
    def build_retrieval_corpus(self):
        print("Building corpus...")
        documents_df = pd.read_csv(self.corpus_tsv_path, sep='\t')
        corpus, corpus2tool = process_retrieval_ducoment(documents_df)
        corpus_ids = list(corpus.keys())
        corpus = [corpus[cid] for cid in corpus_ids]
        return corpus, corpus2tool

    def build_retrieval_embedder(self):
        print("Building embedder...")
        embedder = SentenceTransformer(self.model_path)
        return embedder
    
    def build_corpus_embeddings(self):
        print("Building corpus embeddings with embedder...")
        corpus_embeddings = self.embedder.encode(self.corpus, convert_to_tensor=True)
        return corpus_embeddings

    def retrieving(self, query, top_k=5, excluded_tools={}):
        print("Retrieving...")
        start = time.time()
        query_embedding = self.embedder.encode(query, convert_to_tensor=True)
        hits = util.semantic_search(query_embedding, self.corpus_embeddings, top_k=10*top_k, score_function=util.cos_sim)
        retrieved_tools = []
        for rank, hit in enumerate(hits[0]):
            category, tool_name, api_name = self.corpus2tool[self.corpus[hit['corpus_id']]].split('\t') 
            category = standardize_category(category)
            tool_name = standardize(tool_name) # standardizing
            api_name = change_name(standardize(api_name)) # standardizing
            if category in excluded_tools:
                if tool_name in excluded_tools[category]:
                    top_k += 1
                    continue
            tmp_dict = {
                "category": category,
                "tool_name": tool_name,
                "api_name": api_name
            }
            retrieved_tools.append(tmp_dict)
        return retrieved_tools
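Usage is straightforward (the paths below are placeholders):

retriever = ToolRetriever(
    corpus_tsv_path="data/retrieval/corpus.tsv",   # placeholder path
    model_path="models/api_retriever/",            # placeholder path
)
tools = retriever.retrieving(
    "Get random love, success, and motivation quotes", top_k=5)
print(tools[0])  # {"category": ..., "tool_name": ..., "api_name": ...}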

Calling the ToolRetriever to perform retrieval:

def retrieve_rapidapi_tools(self, query, top_k, jsons_path):
    retrieved_tools = self.retriever.retrieving(query, top_k=top_k)
    query_json = {"api_list":[]}
    for tool_dict in retrieved_tools:
        if len(query_json["api_list"]) == top_k:
            break
        category = tool_dict["category"]
        tool_name = tool_dict["tool_name"]
        api_name = tool_dict["api_name"]
        if os.path.exists(jsons_path):
            if os.path.exists(os.path.join(jsons_path, category)):
                if os.path.exists(os.path.join(jsons_path, category, tool_name+".json")):
                    query_json["api_list"].append({
                        "category_name": category,
                        "tool_name": tool_name,
                        "api_name": api_name
                    })
    return query_json
  4. DFS: ToolLLaMA performs each round of reasoning while the search looks for a path that executes successfully
pattern = r".+_w(\d+)"
re_result = re.match(pattern,method)
assert re_result != None
width = int(re_result.group(1))
with_filter = True
if "woFilter" in method:
    with_filter = False
# instantiate the DFS tree search
chain = DFS_tree_search(llm=llm_forward, io_func=env,process_id=process_id, callbacks=callbacks)
# DFS now reasons together with the LLM, searching for one successful path
result = chain.start(
    single_chain_max_step=single_chain_max_step,
    tree_beam_size = width,
    max_query_count = max_query_count,
    answer=1,
    with_filter=with_filter)
import re
from Tree.Tree import my_tree, tree_node
from Prompts.ReAct_prompts import FORMAT_INSTRUCTIONS_SYSTEM_FUNCTION, FORMAT_INSTRUCTIONS_USER_FUNCTION
from Prompts.Tree_search_prompts import DIVERSITY_PROMPT
from Algorithms.base_search import base_search_method
from copy import deepcopy
from LLM_rank.rank_candidate import sum_based_rankn, rank2_subfix
import json
import random


class DFS_tree_search(base_search_method):

    def __init__(self, llm, io_func, process_id=0, callbacks=None):
        super(DFS_tree_search, self).__init__(
            llm, io_func, process_id, callbacks)
        """Depth-first search. 
        with_filter=True: Every time a child node is generated, choose the best multiple iterations to go.
        with_filter=False: Do as Preorder traversal.
        """
        self.io_func = io_func
        self.llm = llm
        self.process_id = process_id
        self.restart()

        self.callbacks = callbacks if callbacks is not None else []

    def restart(self):
        self.status = 0
        self.terminal_node = []
        self.give_up_node = []
        self.now_expand_num = 0
        self.query_count = 0
        self.total_tokens = 0

    def send_agent_chain_end(self, depth, agent_block_ids, chain_block_ids):
        for i in range(len(self.callbacks)):
            callback = self.callbacks[i]
            callback.on_chain_end(
                depth=depth,
                block_id=chain_block_ids[i]
            )
            if i < len(agent_block_ids):
                callback.on_agent_end(
                    depth=depth,
                    block_id=agent_block_ids[i]
                )

    def to_json(self, answer=False, process=True):

        if process:
            json_obj = {
                "win": self.status == 1,
                "tree": self.tree.to_json_recursive(),
                "forward_args": self.forward_args,
                "compare_candidates": [],
            }
            for node in self.terminal_node:
                if node.pruned == False:  # has answer
                    json_obj["compare_candidates"].append(
                        node.get_chain_result_from_this_node(use_messages=False))
        else:
            json_obj = {}

        if answer:
            json_obj["answer_generation"] = {
                "valid_data": False,
                "query_count": self.query_count,
                "total_tokens": self.total_tokens,
                "final_answer": "",
                "finish_type": "give_answer",
                "function": self.io_func.functions,
                "chain": [],
            }
            for node in self.terminal_node:
                if node.pruned == False:
                    json_obj["answer_generation"]["valid_data"] = True
                    json_obj["answer_generation"]["finish_type"] = "give_answer"
                    json_obj["answer_generation"]["final_answer"] = node.description
                    json_obj["answer_generation"]["train_messages"] = node.get_train_messages_from_this_node(
                    )
                    break
            # do not have final answer, look for give_up
            if json_obj["answer_generation"]["valid_data"] == False:
                if len(self.give_up_node) > 0:
                    random_pos = random.randint(0, len(self.give_up_node) - 1)
                    choose_give_up_node = self.give_up_node[random_pos]
                    json_obj["answer_generation"]["valid_data"] = True
                    json_obj["answer_generation"]["finish_type"] = "give_up"
                    json_obj["answer_generation"]["final_answer"] = choose_give_up_node.description
                    json_obj["answer_generation"]["train_messages"] = choose_give_up_node.get_train_messages_from_this_node()
        return json_obj

    def start(self, single_chain_max_step, tree_beam_size, max_query_count, answer=1, with_filter=True):
        """ single_chain_max_step: The maximum depth of the tree
            tree_beam_size: How many children nodes for one node are generated per layer
            answer = n means the Algo exits when find n "give_answer" nodes
            max_query_count: the Algo exits when OpenAI-query exists this value
            with_filter: This is the difference between normal DFS(with_filter=True) and DFSDT(with_filter=False). 
        """
        self.forward_args = locals()
        if "self" in self.forward_args.keys():
            self.forward_args.pop("self")
        self.tree = my_tree()
        self.tree.root.node_type = "Action Input"
        self.tree.root.io_state = deepcopy(self.io_func)

        system = FORMAT_INSTRUCTIONS_SYSTEM_FUNCTION
        system = system.replace("{task_description}",
                                self.io_func.task_description)
        self.tree.root.messages.append({"role": "system", "content": system})

        user = FORMAT_INSTRUCTIONS_USER_FUNCTION
        user = user.replace("{input_description}",
                            self.io_func.input_description)
        self.tree.root.messages.append({"role": "user", "content": user})

        return self.DFS(self.tree.root, single_chain_max_step, tree_beam_size, max_query_count, answer, with_filter)

    def DFS(self, now_node, single_chain_max_step, tree_beam_size, max_query_count, answer, with_filter=True):
        """Returns the number of grids to go back. When a child node of a node generates a final answer or give up, it should go back a few more grids
        In a sense, the larger this value is, the more diverse it is, and it is GreedySearch@n when it is enlarged to infinity.
        """

        # these two values control how far to backtrack; the algorithm degrades to CoT when they are infinite
        final_answer_back_length = 2
        prune_back_length = 2

        now_node.expand_num = self.now_expand_num
        self.now_expand_num += 1
        if now_node.get_depth() >= single_chain_max_step or now_node.pruned or now_node.is_terminal:
            if now_node.is_terminal:  # final answer
                self.status = 1
                self.terminal_node.append(now_node)
                return final_answer_back_length
            else:
                now_node.pruned = True
                if now_node.observation_code == 4:
                    self.give_up_node.append(now_node)
                    return prune_back_length
                else:
                    return 1

        next_tree_split_nodes = []
        for i in range(tree_beam_size):
            temp_now_node = now_node

            """If a node have children now, We will prompt the model to generate different nodes than all the existing nodes"""
            delete_former_diversity_message = False
            diversity_message = None
            if len(temp_now_node.children) > 0:

                former_candidates_des = ""
                js_list = []
                for k, child in enumerate(temp_now_node.children):
                    temp_node = child
                    while not temp_node.is_terminal and temp_node.node_type != "Action Input" and len(temp_node.children) > 0:
                        temp_node = temp_node.children[0]
                    if temp_node.node_type == "Action Input":
                        obj_dict = {
                            "name": temp_node.father.description,
                            "arguments": temp_node.description,
                            "function_output": temp_node.observation,
                            "mento-carlo-action-value": temp_node.compute_weight(),
                        }
                        js_list.append(obj_dict)

                if len(js_list) > 0:
                    former_candidates_des = former_candidates_des + \
                        f"{json.dumps(js_list,indent=2)}\n"
                    if temp_now_node.observation != "":
                        former_candidates_des = former_candidates_des + \
                            f"again, your former observation: {temp_now_node.observation}\n"
                    diverse_prompt = DIVERSITY_PROMPT
                    diverse_prompt = diverse_prompt.replace(
                        "{previous_candidate}", former_candidates_des)
                    diversity_message = {
                        "role": "user", "content": diverse_prompt}
                    temp_now_node.messages.append(diversity_message)

                    delete_former_diversity_message = True
            # on_chain_start
            now_depth = temp_now_node.get_depth() // 3
            chain_block_ids = [callback.on_chain_start(
                depth=now_depth,
                inputs=temp_now_node.messages
            ) for callback in self.callbacks]
            agent_block_ids = []
            self.llm.change_messages(temp_now_node.messages)
            # on_llm_start
            [callback.on_llm_start(
                depth=now_depth,
                messages=temp_now_node.messages
            ) for callback in self.callbacks]
            new_message, error_code, total_tokens = self.llm.parse(
                self.io_func.functions, process_id=self.process_id)
            # on_llm_end
            [callback.on_llm_end(
                depth=now_depth,
                response=new_message
            ) for callback in self.callbacks]
            self.query_count += 1
            self.total_tokens += total_tokens
            if self.query_count >= max_query_count:  # a big return value will cause the Algo to exit
                return 100000

            # We need to exclude the diversity_message, because it will influence child nodes
            if delete_former_diversity_message:
                temp_now_node.messages[-1]["valid"] = False

            # parse nodes from OpenAI-message like CoT method
            assert new_message["role"] == "assistant"
            if "content" in new_message.keys() and new_message["content"] != None:
                temp_node = tree_node()
                temp_node.node_type = "Thought"
                temp_node.description = new_message["content"]
                child_io_state = deepcopy(temp_now_node.io_state)
                child_io_state.retriever=None

                temp_node.io_state = child_io_state
                temp_node.is_terminal = child_io_state.check_success() != 0
                temp_node.messages = deepcopy(temp_now_node.messages)
                temp_node.father = temp_now_node
                temp_now_node.children.append(temp_node)
                temp_node.print(self.process_id)
                temp_now_node = temp_node

                if error_code != 0:
                    temp_now_node.observation_code = error_code
                    temp_now_node.pruned = True

            if "function_call" in new_message.keys():
                # on_agent_action
                agent_block_ids = [callback.on_agent_action(
                    depth=now_depth,
                    action=new_message["function_call"]["name"],
                    action_input=new_message["function_call"]["arguments"]
                ) for callback in self.callbacks]
                function_name = new_message["function_call"]["name"]
                temp_node = tree_node()
                temp_node.node_type = "Action"
                temp_node.description = function_name
                child_io_state = deepcopy(temp_now_node.io_state)
                child_io_state.retriever=None

                temp_node.io_state = child_io_state
                temp_node.is_terminal = child_io_state.check_success() != 0
                temp_node.messages = deepcopy(temp_now_node.messages)
                temp_node.father = temp_now_node
                temp_now_node.children.append(temp_node)

                temp_node.print(self.process_id)
                temp_now_node = temp_node

                function_input = new_message["function_call"]["arguments"]
                temp_node = tree_node()
                temp_node.node_type = "Action Input"
                temp_node.description = function_input
                child_io_state = deepcopy(temp_now_node.io_state)
                child_io_state.retriever=None
                
                # on_tool_start
                [callback.on_tool_start(
                    depth=now_depth,
                    tool_name=temp_now_node.description,
                    tool_input=function_input
                ) for callback in self.callbacks]
                observation, status = child_io_state.step(
                    action_name=temp_now_node.description, action_input=function_input)
                temp_node.observation = observation
                temp_node.observation_code = status

                temp_node.io_state = child_io_state
                temp_node.is_terminal = child_io_state.check_success() != 0
                temp_node.messages = deepcopy(temp_now_node.messages)
                temp_node.father = temp_now_node
                temp_now_node.children.append(temp_node)
                temp_node.print(self.process_id)
                temp_now_node = temp_node
                # on_tool_end
                [callback.on_tool_end(
                    depth=now_depth,
                    output=observation,
                    status=status
                ) for callback in self.callbacks]
                if status != 0:
                    # return code definitions can be found in Downstream_tasks/rapid_api
                    if status == 4:
                        temp_now_node.pruned = True
                    elif status == 1:  # hallucination api name
                        assert "function_call" in new_message.keys()
                        new_message["function_call"]["name"] = "invalid_hallucination_function_name"
                    elif status == 3:  # final answer
                        temp_now_node.is_terminal = True
                        temp_now_node.make_finish(final_answer_back_length)

            temp_now_node.messages.append(new_message)
            if temp_now_node.node_type == "Action Input":
                temp_now_node.messages.append({
                    "role": "function",
                    "name": new_message["function_call"]["name"],
                    "content": temp_now_node.observation,
                })
            return_value = None
            if not with_filter:  # DFSDT
                result = self.DFS(temp_now_node, single_chain_max_step,
                                  tree_beam_size, max_query_count, answer, with_filter)
                if len(self.terminal_node) >= answer:
                    return_value = 10000
                elif result > 1:
                    return_value = result-1

            else:

                next_tree_split_nodes.append(temp_now_node)
            self.send_agent_chain_end(
                now_depth, agent_block_ids, chain_block_ids)
            if return_value is not None:
                return return_value

        # Sort the generated next_tree_split_nodes nodes when normal DFS
        if len(next_tree_split_nodes) > 1:
            # When using normal DFS, if we have many child nodes, we will refer to LLM to compare and choose the best one to expand first
            # remember, this operator will cost extra OpenAI calls.
            LLM_rank_args = {
                "functions": self.io_func.functions,
                "process_id": self.process_id,
                "task_description": self.io_func.task_description,
                "rank_func": rank2_subfix,
            }
            scores, rank_query_count, total_tokens = sum_based_rankn(
                self.llm, LLM_rank_args=LLM_rank_args, candidates=next_tree_split_nodes)
            self.query_count += rank_query_count
            self.total_tokens += total_tokens
            for score, node in zip(scores, next_tree_split_nodes):
                node.prior_score = score
            zip_value = list(
                zip(next_tree_split_nodes, range(len(next_tree_split_nodes))))
            zip_value.sort(
                key=lambda x: x[0].prior_score, reverse=True)  # expand higher-scoring nodes first
            next_tree_split_nodes, filtered_order = zip(*zip_value)
            # if self.process_id == 0:
            #     print(f"score={scores}, filtered order: {filtered_order}")

        '''
        Choose one to expand
        '''
        for i in range(len(next_tree_split_nodes)):
            result = self.DFS(
                next_tree_split_nodes[i], single_chain_max_step, tree_beam_size, max_query_count, answer)
            if len(self.terminal_node) >= answer:
                return 10000
            elif result > 1:
                now_node.make_finish(2)
                return result - 1

        return 1

Note that DFS is also used when constructing ToolBench, where the LLM being called is ChatGPT; during evaluation it is ToolLLaMA instead, so ToolLLaMA's inference format must stay consistent with the OpenAI function-call message format.
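Concretely, ToolLLaMA's decoded text has to be parsed back into an OpenAI-style assistant message before it enters the DFS loop above; a sketch of such an adapter (the "Thought/Action/Action Input" layout is an assumption based on the annotation format in Section 2, not the repo's exact parser):

import re

def parse_toolllama_output(raw):
    # Convert ToolLLaMA's raw decoding into an OpenAI-style assistant message,
    # so the search code can treat both backbones uniformly.
    thought = re.search(r"Thought:\s*(.*?)\nAction:", raw, re.S)
    action = re.search(r"Action:\s*(.*?)\nAction Input:", raw, re.S)
    args = re.search(r"Action Input:\s*(.*)", raw, re.S)
    message = {"role": "assistant",
               "content": thought.group(1).strip() if thought else raw}
    if action and args:
        message["function_call"] = {"name": action.group(1).strip(),
                                    "arguments": args.group(1).strip()}
    return message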
