LLM: AI Mathematical Olympiad (下)

文章目录

  • 一、SC-TIR策略(工具整合推理)
  • 二、SC-TIR原理
  • 三、避免过拟合
  • 四、代码分析
    • 1、Main函数
    • 2、SC-TIR control flow
    • 3、Extract answer
    • 4、Execute completion
  • 总结


本文较长分成两个部分分析 | ू•ૅω•́)ᵎᵎᵎ
第一部分:预备知识介绍和数据准备
第二部分:推理策略与代码分析

第二部分来了 。:.゚ヽ(。◕‿◕。)ノ゚.:。+゚, 这部分主要介绍作者的推理策略,个人觉得LLM在推理上的策略设计可以让模型有着质的改变

一、SC-TIR策略(工具整合推理)

保证模型输出和评测的稳定,作者设计自一致性工具整合推理 (SC-TIR) 来抑制高波动,流程如下图:
请添加图片描述
具体流程如下


1)将每道题复制 N 次以生成 vLLM 的一个 batch。N 可以看成多数投票时的候选数量。
2)对这 N 个输入进行采样解码,直至生成完整的 Python 代码块。
3)执行每个 Python 代码块并将其输出串接在代码后面,包括栈回溯 (如有)。
4)重复 M 次以生成 N 个、深度为 M 的生成,允许模型使用栈回溯自纠正代码错误。如果某个样本无法生成合理的输出 (如,生成了不完整的代码块),就删除之。
5)对候选解答进行后处理,并使用多数投票来选择最终答案


二、SC-TIR原理

TIR伪代码如下
请添加图片描述
1)首先初始化问题文本
2)循环k次进行答案生成zi
3)如果带有标记的answer在zi中,直接解析并返回答案
4)如果带有python标记的代码块不在zi中,继续下一次生成
5)如果带有python标记的代码块在zi中,使用python解释器执行,将上一次的问题ci-1, 当前生成文本zi, python执行结果ri进行拼接,执行下一次生成


从伪代码可以发现,作者没有直接计算pθ (y | x),TIR会从一个辅助潜在变量、生成的CoT计划的跟踪和Python源代码中联合得出一个答案和一系列样本。该序列记为z。因此,从TIR中得出的生成器对应于pθ (y, z | x)的样本。因此,为了有效地计算pθ (y | x),需要将生成的迹线边缘化,这可以通过以下求和来实现:

请添加图片描述

在潜在变量建模的背景下,这个方程通常被称为边际似然,当Z很大时,通常在计算上是不可行的,就像这里的情况一样。这种情况在实践中经常发生,存在各种近似策略。最值得注意的是,在边缘化LLM推理轨迹的背景下,Wang等人[2023]提出了自一致性(SC),它通过从pθ (y, z | x)中提取有限数量的n个样本y,然后应用多数投票程序,近似于最大后验(MAP)决策规则的边缘化和应用
请添加图片描述

在SC-TIR的情况下,从TIR中生成n个样本,然后应用一个过滤器F,它去除Y支持之外的病态响应,最后应用自一致多数投票

请添加图片描述


作者使用的 N=48,M=4 。因为增加任一参数的数值并不会提高性能,所以我们就选择了这两个最小值以保证满足时间限制。实际上,该算法通过工具整合推理增强了 CoT 的自一致性 (如下所示)SC-TIR 算法产生了更稳健的结果。请添加图片描述

三、避免过拟合

为了指导模型选择,作者使用了四个内部验证集来衡量模型在不同难度的数学题上的性能。为了避免基础模型中潜在的数据污染,从 AMC12 (2022、2023) 和 AIME (2022、2023、2024) 中选择题目以创建两个内部验证数据集:

AMC (83 道题): 我们选择了 AMC12 22、AMC12 23 的所有题目,并保留了那些结果为整数的题目。最终生成的数据集包含 83 道题。该验证集旨在模拟 Kaggle 上的私有测试集,因为我们从竞赛描述中知道题目难度大于等于这个级别。我们发现我们的模型可以解答大约 60-65% 的题目。为了测量波动,每次评估时,我们使用 5-10 个不同的种子,使用我们的 SC-TIR 算法通常会看到大约 1-3% 的波动。

AIME (90 道题): 我们选择了 AIME 22、AIME 23 以及 AIME 24 的所有题目来度量我们模型解决难题的表现如何,并观测最常见的错误模式。同上,每次评估,我们使用 5-10 个种子进行以测量波动。

由于 AMC/AIME 验证集规模较小,与公开排行榜类似,这些数据集上的模型性能容易受噪声的影响。为了更好地评估模型的性能,我们还使用 MATH 测试集的子集 (含 5,000 道题) 对其进行了评估。我们仅保留答案为整数的题目,以简化多数投票并模拟奥赛评估。因此,我们又多了两个验证集: MATH 4 级 (754 道题) ,MATH 5 级 (721 道题)

通过使用这四个验证集,我们能够在不同的训练阶段选择最有潜力的模型,并缩小超参的选择范围。我们发现,对本 AIMO 赛程而言,将小型但具代表性的验证集与较大的验证集相结合是有用的,因为每个提交都受到抽样随机性的影响。

最终模型评测结果
请添加图片描述

四、代码分析

https://www.kaggle.com/code/lewtun/numina-1st-place-solution/notebook#Python-REPL-and-code-execution-utilities

1、Main函数

基本流程如下:
1)循环每个问题并进行tokenizer化处理
2)对相同问题采样num_samples次(变成一个batch),并构建一个数据集格式用于后续处理
3)循环n次生成,对输入的文本先后进行 generate_batched(LLM生成处理),process_code(python处理)
4)在LLM生成函数中,每个问题会被赋予多个属性,对于无法获取答案的问题,其中prune属性会变为true,用于过滤
5)过滤输出,解析答案,投票获取最终答案

核心部分

    for test, submission in tqdm(iter_test, desc="Solving problems"): 
    	
    	# 处理问题格式,从apply_template函数来看,并没有做什么特殊的处理
        problem = apply_template({"prompt": test.problem.values[0]}, tokenizer=vllm.get_tokenizer(), prompt="{}")
        
        print(f"=== INPUT FOR PROBLEM ID {test.id.values[0]} ===\n{problem}\n")
		
		# Dataset.from_list 从given list中创建一个dataset
		# 将一个problem循环 num_samples, 表示一个采样里面有n个相同的问题
        samples = Dataset.from_list([
            {
                "text": problem["text"],
                "gen_texts": problem["text"],
                "should_prune": False,
                "model_answers": "-1",
                "has_code": True,
            }
            for _ in range(config.num_samples)
        ])
        completed = []
		
		# 循环n次生成
        for step in range(config.num_generations):
        	
        	#samples是个dataset对象,分别执行generated_batched 和 process_code函数
        	# process_code 函数就是SC-TIR
			# SC-TIR会修改sample里面的属性,判断是否should prune,has_code, model_answers
            samples = samples.map(
                generate_batched,
                batch_size=128,
                batched=True,
                fn_kwargs={"vllm": vllm, "sampling_params": sampling_params},
                load_from_cache_file=False,
            )
            samples = samples.map(
                process_code,
                num_proc=num_procs,
                load_from_cache_file=False,
                fn_kwargs={"restart_on_fail": config.restart_on_fail, "last_step": step == (config.num_generations - 1)},
            )
            done = samples.filter(lambda x: x["should_prune"] is True, load_from_cache_file=False)
            if len(done):
                completed.append(done)
			# 不断迭代,直到should_prune 为True或者 完成for循环
            samples = samples.filter(lambda x: x["should_prune"] is False, load_from_cache_file=False)  

		
        completed.append(samples)
        samples = concatenate_datasets(completed)
        candidates = samples["model_answers"]
        print(f"=== CANDIDATE ANSWERS ({len(candidates)}) ===\n{candidates}\n")
		#拿到所有正常答案
        filtered = filter_answers(candidates)
        print(f"=== FILTERED ANSWERS ({len(filtered)}) ===\n{filtered}\n")
        # 投票
        majority = get_majority_vote(filtered)
        print(f"=== MAJORITY ANSWER (mod 1000) ===\n{majority}\n")
        submission["answer"] = majority
        env.predict(submission)
        test["model_answer"] = majority
        final_answers.append(test)
    if not config.is_submission:
        answers = env.df.merge(pd.concat(final_answers))
        answers["correct"] = answers["ground_truth"].astype(int) == answers["model_answer"].astype(int)
        print("Accuracy", answers["correct"].astype(int).mean())

2、SC-TIR control flow

**基本流程如下:
1)首先基于正则匹配找到python代码
2)判断是否存在code block(不存在进行prune,存在执行block), 是否重启(重启保持原文本不变)
3)判断是否存在answer(存在直接解析,不存在执行python代码拿到结果和原始文本拼接,更新属性)
**


def process_code(sample, restart_on_fail, last_step, check_last_n_chars=100):
    gen_text = sample["gen_texts"]
	
	# 正则匹配 找到 ```python -----   ```
    num_python_blocks = len(re.findall(r"```python(.*?)```", gen_text, re.DOTALL))
    region_to_check = gen_text[-check_last_n_chars:]
    if num_python_blocks == 0:
        if restart_on_fail:
            print("no code has ever been generated, RESTARTING")
            sample["gen_texts"] = sample["text"]
        else:
            print("no code has ever been generated, STOP")
            sample["should_prune"] = True
            sample["has_code"] = False
        return sample
     # 没有output标志 ,但是有answer标志和 boxed标志
    if not gen_text.endswith("```output\n") and ("answer is" in region_to_check or "\\boxed" in region_to_check):
        num_output_blocks = len(re.findall(r"```output(.*?)```", gen_text, re.DOTALL))
        if num_output_blocks == 0:
            print("The model hallucinated the code answer")
            sample["should_prune"] = True
            return sample
        if "boxed" in region_to_check:
            try:
                answer = normalize_answer(extract_boxed_answer(region_to_check))
            except Exception:
                answer = "-1"
        else:
            answer = normalize_answer(region_to_check)
        sample["model_answers"] = answer
        return sample
    if last_step:
        return sample
	
	# gen_text 不存在output mark
    if not gen_text.endswith("```output\n"):
        print("warning: output block not found: ", gen_text[-40:])
        if restart_on_fail:
            sample["gen_texts"] = sample["text"]
        else:
            sample["should_prune"] = True
        return sample
      
      ### gen_text 存在 output标记,且存在python block, 执行python
    code_result, _ = postprocess_completion(gen_text, return_status=True, last_code_block=True)
    truncation_limit = 200
    if len(code_result) > truncation_limit:
        code_result = code_result[:truncation_limit] + " ... (output truncated)"
    ### 这里应该就是COT技术了
    sample["gen_texts"] = gen_text + f"{code_result}\n```"
    return sample

3、Extract answer

这部分代码就是提取答案,数据在定义时候会用一个\boxed{} 或者 \fbox{} ,这个函数就是在找到{}里面的内容。


def extract_boxed_answer(text):
    def last_boxed_only_string(text):
        idx = text.rfind("\\boxed")
        if idx < 0:
            idx = text.rfind("\\fbox")
            if idx < 0:
                return None
        i = idx
        right_brace_idx = None
        num_left_braces_open = 0
        while i < len(text):
            if text[i] == "{":
                num_left_braces_open += 1
            if text[i] == "}":
                num_left_braces_open -= 1
                if num_left_braces_open == 0:
                    right_brace_idx = i
                    break
            i += 1
        if right_brace_idx is None:
            return None
        return text[idx : right_brace_idx + 1]

4、Execute completion


def execute_completion(executor, completion, return_status, last_code_block):
    executions = re.findall(r"```python(.*?)```", completion, re.DOTALL)
    if len(executions) == 0:
        return completion, False if return_status else completion
    if last_code_block:
        executions = [executions[-1]]
    outputs = []
    successes = []
    for code in executions:
        success = False
        for lib in ("subprocess", "venv"):
            if lib in code:
                output = f"{lib} is not allowed"
                outputs.append(output)
                successes.append(success)
                continue
        try:
            success, output = executor(code)
        except TimeoutError as e:
            print("Code timed out")
            output = e
        if not success and not return_status:
            output = ""
        outputs.append(output)
        successes.append(success)
    output = str(outputs[-1]).strip()
    success = successes[-1]
    if return_status:
        return output, success
    return output

总结

整个AI Mathematical Olympiad 项目大概就这样,其实真正核心的point就是数据的丰富多样性,在这个数据驱动的时代,丰富的数据比起模型设计更加重要。作者收集数据的想法和推理的策略都很值得学习。这个推理策略还是很受启发的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/921304.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

06、Spring AOP

在我们接下来聊Spring AOP之前我们先了解一下设计模式中的代理模式。 一、代理模式 代理模式是23种设计模式中的一种,它属于结构型设计模式。 对于代理模式的理解: 程序中对象A与对象B无法直接交互,如:有人要找某个公司的老总得先打前台登记传达程序中某个功能需要在原基…

前端vue调试样式方法

1.选中要修改的下拉框&#xff0c;找到对应的标签的class样式 2.在浏览器中添加width宽度样式覆盖原有的样式&#xff0c;如果生效后说明class对了&#xff0c;则到vue页面的strye中添加覆盖样式 <style> :deep(.el-select){width: 180px; } </style>3.寻找自定义…

刷写树莓派系统

一. 树莓派做smb文件服务器参考视频 click here 二. 在官网上下载适合自己树莓派型号的镜像文件 三. 使用官方的riprpi-imager刷写软件 可以自动将TF卡(micro sd卡)格式化为适合树莓派系统运行的文件格式Fat32。就无需自己手动格式化进行刷写。 四. 出现文件验证失败 先把…

Python中的XGBOOST算法实现详解

文章目录 Python中的XGBOOST算法实现详解一、引言二、XGBoost算法原理与Python实现1、XGBoost算法原理1.1、目标函数的二阶泰勒展开 2、Python实现XGBoost2.1、环境准备2.2、导入库2.3、数据准备2.4、数据拆分2.5、模型训练2.6、模型预测与评估2.7、特征重要性可视化 三、总结 …

android 使用MediaPlayer实现音乐播放--权限请求

在Android应用中&#xff0c;获取本地音乐文件的权限是实现音乐扫描功能的关键步骤之一。随着Android版本的不断更新&#xff0c;从Android 6.0&#xff08;API级别23&#xff09;开始&#xff0c;应用需要动态请求权限&#xff0c;而到了android 13以上需要的权限又做了进一步…

Docker 容器化开发 应用

Docker 常用命令 存储 - 目录挂载 存储 卷映射 自定义网络 Docker Compose语法 Dockerfile - 制作镜像 镜像分层机制 完结

Python爬虫案例八:抓取597招聘网信息并用xlutils进行excel数据的保存

excel保存数据的三种方式&#xff1a; 1、pandas保存excel数据&#xff0c;后缀名为xlsx; 举例&#xff1a; import pandas as pddic {姓名: [张三, 李四, 王五, 赵六],年龄: [18, 19, 20, 21],住址: [广州, 青岛, 南京, 重庆] } dic_file pd.DataFrame(dic) dic_file…

【Unity How】Unity中如何实现物体的匀速往返移动

直接上代码 using UnityEngine;public class CubeBouncePingPong : MonoBehaviour {[Header("移动参数")][Tooltip("移动速度")]public float moveSpeed 2f; // 控制移动的速度[Tooltip("最大移动距离")]public float maxDistance 5f; // 最大…

面向对象-接口的使用

1. 接口的概述 为什么有接口&#xff1f; 借口是一种规则&#xff0c;对于继承而言&#xff0c;部分子类之间有共同的方法&#xff0c;为了约束方法的使用&#xff0c;使用接口。 接口的应用&#xff1a; 接口不是一类事物&#xff0c;它是对行为的抽象。 2. 接口的定义和使…

理论结合实践:用Umami构建网站分析系统

个人博客地址&#xff08;欢迎大家访问&#xff09;&#xff1a;理论结合实践&#xff1a;用Umami构建网站分析系统 1. 引言 网站统计分析是一种通过收集、处理和分析网站数据来评估网站性能、用户行为和流量来源的综合方法。通过分析用户访问模式、页面浏览量、访问时长、用户…

【AI最前线】DP双像素sensor相关的AI算法全集:深度估计、图像去模糊去雨去雾恢复、图像重建、自动对焦

Dual Pixel 简介 双像素是成像系统的感光元器件中单帧同时生成的图像&#xff1a;通过双像素可以实现&#xff1a;深度估计、图像去模糊去雨去雾恢复、图像重建 成像原理来源如上&#xff0c;也有遮罩等方式的pd生成&#xff0c;如图双像素视图可以看到光圈的不同一半&#x…

sysbench压测DM的高可用切换测试

一、配置集群 1. 配置svc.conf [rootlocalhost dm]# cat /etc/dm_svc.conf TIME_ZONE(480) LANGUAGE(CN)DM(192.168.112.139:5236,192.168.112.140:5236) [DM] LOGIN_MODE(1) SWITCH_TIME(300) SWITCH_INTERVAL(200)二、编译sysbench 2.1 配置环境变量 [dmdba~]# vi ~/.bas…

高性能linux服务器运维实战小结 性能调优工具

性能指标 进程指标 进程关系 父进程创子进程时&#xff0c;调fork系统调用。调用时&#xff0c;父给子获取一个进程描述符&#xff0c;并设置新的pid&#xff0c;同事复制父进程的进程描述符给子进程&#xff0c;此时不会复制父进程地址空间&#xff0c;而是父子用相同地址空…

pcb元器件选型与焊接测试时的一些个人经验

元件选型 在嘉立创生成bom表&#xff0c;对照bom表买 1、买电容时有50V或者100V是它的耐压值&#xff0c;注意耐压值 2、在买1117等降压芯片时注意它降压后的固定输出&#xff0c;有那种可调降压比如如下&#xff0c;别买错了 贴片元件焊接 我建议先薄薄的在引脚上涂上锡膏…

【zookeeper03】消息队列与微服务之zookeeper集群部署

ZooKeeper 集群部署 1.ZooKeeper 集群介绍 ZooKeeper集群用于解决单点和单机性能及数据高可用等问题。 集群结构 Zookeeper集群基于Master/Slave的模型 处于主要地位负责处理写操作)的主机称为Leader节点&#xff0c;处于次要地位主要负责处理读操作的主机称为 follower 节点…

C 语言复习总结记录三

C 语言复习总结记录三 一 函数的定义 维基百科中对函数的定义&#xff1a;子程序 在计算机科学中&#xff0c;子程序&#xff08;英语&#xff1a;Subroutine, procedure, function, routine, method, subprogram, callable unit&#xff09;&#xff0c;是一个大型程序中的…

MYSQL——多表设计以及数据库中三种关系模型

大致介绍数据库中三种关系模型 一对多&#xff08;1:N&#xff09; 定义&#xff1a; 一个实体可以与另一个实体的多个实例相关联&#xff0c;而后者只能与前者的一个实例相关联。 例子&#xff1a; 学生和课程的关系。 学生&#xff08;1&#xff09;&#xff1a;每个学生…

OpenCV和Qt坐标系不一致问题

“ OpenCV和QT坐标系导致绘图精度下降问题。” OpenCV和Qt常用的坐标系都是笛卡尔坐标系&#xff0c;但是细微处有些不同。 01 — OpenCV坐标系 OpenCV是图像处理库&#xff0c;是以图像像素为一个坐标位置&#xff0c;即一个像素对应一个坐标&#xff0c;所以其坐标系也叫图像…

实验四:构建园区网(OSPF 动态路由)

目录 一、实验简介 二、实验目的 三、实验需求 四、实验拓扑 五、实验步骤 1、在 eNSP 中部署网络 2、设计全网 IP 地址 3、配置二层交换机 4、配置路由交换机并测试通信 5、配置路由接口地址 6、配置 OSPF 动态路由&#xff0c;实现全网互通 一、实验简介 使用路由…

《剖析 Spring 原理:深入源码的旅程(二)》

六、Spring 的 Bean 注入与装配 Spring 的 Bean 注入与装配的方式有很多种&#xff0c;可以通过 xml、get set 方式、构造函数或者注解等。简单易用的方式就是使用 Spring 的注解&#xff0c;Spring 提供了大量的注解方式&#xff0c;如 Autowired、Qualifier 等。Spring 还支持…