【Python】科研代码学习:七 TrainingArguments,Trainer

【Python】科研代码学习:七 TrainingArguments,Trainer

  • TrainingArguments
    • 重要的方法
  • Trainer
    • 重要的方法
    • 使用 Trainer 的简单例子

TrainingArguments

  • HF官网API:Training
    众所周知,推理是一个大头,训练是另一个大头
    之前的很多内容,都是为训练这里做了一个小铺垫
    如何快速有效地调用代码,训练大模型,才是重中之重(不然学那么多HF库感觉怪吃苦的)
  • 首先看训练参数,再看训练器吧。
    首先,它的头文件是 transformers.TrainingArguments
    再看它源码的参数,我勒个去,太多了吧。
    ※ 我这里挑重要的讲解,全部请看API去。
  • output_dir (str)设置模型输出预测,或者中继点 (checkpoints) 的输出目录。模型训练到一半,肯定需要有中继点文件的嘛,就相当于游戏存档有很多一样,防止跑一半直接程序炸了,还要从头训练
  • overwrite_output_dir (bool, optional, defaults to False):把这个参数设置成 True,就会覆盖其中 output_dir 中的文档。一般在从中继点继续训练时需要这么用
  • do_train (bool, optional, defaults to False):指明我在做训练集的训练任务
  • do_eval (bool, optional):指明我在做验证集的评估任务
  • do_predict (bool, optional, defaults to False) :指明我在做测试集的预测任务
  • evaluation_strategy :评估策略:训练时不评估 / 每 eval_steps 步评估,或者每 epoch 评估
"no": No evaluation is done during training.
"steps": Evaluation is done (and logged) every eval_steps.
"epoch": Evaluation is done at the end of each epoch.
  • per_device_train_batch_size :训练时每张卡的batch大小,默认为8
    per_device_eval_batch_size :评估时每张卡的batch大小,默认为8
  • learning_rate (float, optional, defaults to 5e-5):学习率,里面使用的是 AdamW optimizer
    其他相应的 AdamW Optimizer 的参数还有:
    weight_decay adam_beta1adam_beta2adam_epsilon
  • num_train_epochs:训练的 epoch 个数,默认为3,可以设置小数。
  • lr_scheduler_type:具体作用要查看 transformers 里的 Scheduler 是干什么用的
  • warmup_ratiowarmup_steps :让一开始的学习率从0逐渐升到 learning_rate 用的
  • logging_dir :设置 logging 输出的文档
    除此之外还有一些和 logging相关的参数:
    logging_strategy ,logging_first_step ,logging_steps ,logging_nan_inf_filter 设置日志的策略
  • 与保存模型中继文件相关的参数:
    save_strategy :不保存中继文件 / 每 epoch 保存 / 每 save_steps 步保存
"no": No save is done during training.
"epoch": Save is done at the end of each epoch.
"steps": Save is done every save_steps.

save_steps :如果是整数,表示多少步保存一次;小数,则是按照总训练步,多少比例之后保存一次
save_total_limit :最多中继文件的保存上限,如果超过上限,会先把最旧的那个中继文件删了再保存新的
save_safetensors :使用 savetensor来存储和加载 tensors,默认为 True
push_to_hub :是否保存到 HF hub

  • use_cpu (bool, optional, defaults to False):是否用 cpu 训练
  • seed (int, optional, defaults to 42) :训练的种子,方便复现和可重复实验
  • data_seed :数据采样的种子
  • 数据精读相关的一些参数:
    FP32、TF32、FP16、BF16、FP8、FP4、NF4、INT8
    bf16 (bool, optional, defaults to False)fp16 (bool, optional, defaults to False)
    tf32 (bool, optional)
  • run_name :展示在 wandb and mlflow logging 中的描述
  • load_best_model_at_end (bool, optional, defaults to False):是否保存效果最好的中继点作为最终模型,与 save_total_limit 有些交互操作
    如果上述设置成 True 的话,考虑 metric_for_best_model ,即如何评估效果最好。默认为 loss 即损失最小
    如果你修改了 metric_for_best_model 的话,考虑 greater_is_better ,即指标越大越好还是越小越好
  • 一些加速相关的参数,貌似都比较麻烦
    fsdp
    fsdp_config
    deepspeed
    accelerator_config
  • optim :设置 optimizer,默认为 adamw_torch
    也可以设置成 adamw_hf, adamw_torch, adamw_torch_fused, adamw_apex_fused, adamw_anyprecision or adafactor.
  • resume_from_checkpoint :传入中继点文件的目录,从中继点继续训练

重要的方法

  • ※ 那我怎么访问或者修改上述参数呢?
    由于这个需要实例化,所以我们需要使用OO的方法修改
    下面讲一下其中重要的方法
  • set_dataloader:设置 dataloader
    在这里插入图片描述
from transformers import TrainingArguments

args = TrainingArguments("working_dir")
args = args.set_dataloader(train_batch_size=16, eval_batch_size=64)
args.per_device_train_batch_size
  • 设置 logging 相关的参数
    在这里插入图片描述
  • 设置 optimizer
    在这里插入图片描述
  • 设置保存策略
    在这里插入图片描述
  • 设置训练策略
    在这里插入图片描述
  • 设置评估策略
    在这里插入图片描述
  • 设置测试策略
    在这里插入图片描述

Trainer

  • 终于到大头了。Trainer 是主要用 pt 训练的,主要支持 GPUs (NVIDIA GPUs, AMD GPUs)/ TPUs
  • 看下源码,它要的东西不少,讲下重要参数:
  • model:要么是 transformers.PretrainedModel 类型的,要么是简单的 torch.nn.Module 类型的
  • argsTrainingArguments 类型的训练参数。如果不提供的话,默认使用 output_dir/tmp_trainer 里面的那个训练参数
  • data_collator DataCollator 类型参数,给训练集或验证集做数据分批和预处理用的,如果没有tokenizer默认使用 default_data_collator,否则默认使用 DataCollatorWithPadding (Will default to default_data_collator() if no tokenizer is provided, an instance of DataCollatorWithPadding otherwise.)
  • train_dataset (torch.utils.data.Dataset or torch.utils.data.IterableDataset, optional) :提供训练的数据集,当然也可以是 Datasets 类型的数据
  • eval_dataset :类似的验证集的数据集
  • tokenizer :提供 tokenizer 分词器
  • compute_metrics :验证集使用时候的计算指标,具体得参考 EvalPrediction 类型
  • optimizers :可以提供 Tuple(optimizer, scheduler)。默认使用 AdamW 以及 get_linear_schedule_with_warmup() controlled by args
    在这里插入图片描述

重要的方法

  • compute_loss:设置如何计算损失
    在这里插入图片描述
  • train:设置训练集训练任务,第一个参数可以设置是否从中继点开始训练
    在这里插入图片描述
  • evaluate:设置验证集评估任务,需要提供验证集
    在这里插入图片描述
  • predict:设置测试集任务
    在这里插入图片描述
  • save_model:保存模型参数到 output_dir在这里插入图片描述
  • training_step:设置每一个训练的 step,把一个batch的输入经过了何种操作,得到一个 torch.Tensor
    在这里插入图片描述

使用 Trainer 的简单例子

  • 主要就是加载一些参数,传进去即可
    模型、训练参数、训练集、验证集、计算指标
    调用训练方法 .train()
    最后保存模型即可 .save_model()
from transformers import (
    Trainer,
    )
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()
trainer.save_model(outputdir="./xxx")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/447569.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【PyTorch实战演练】深入剖析MTCNN(多任务级联卷积神经网络)并使用30行代码实现人脸识别

文章目录 0. 前言1. 级联神经网络介绍2. MTCNN介绍2.1 MTCNN提出背景2.2 MTCNN结构 3. MTCNN PyTorch实战3.1 facenet_pytorch库中的MTCNN3.2 识别图像数据3.3 人脸识别3.4 关键点定位 0. 前言 按照国际惯例,首先声明:本文只是我自己学习的理解&#xff…

【Qt学习笔记】(二)--第一个程序“Hello World”(学习Qt中程序的运行、发布、编译过程)

声明:本人水平有限,博客可能存在部分错误的地方,请广大读者谅解并向本人反馈错误。    因为我个人对Qt也是有一些需求,所以开设本专栏进行学习,希望大家可以一起学习,共同进步。   这篇博客将从一个 He…

算法刷题Day1 | 704.二分查找、27.移除元素

目录 0 引言1 二分查找1.1 我的解题1.2 修改后1.3 总结 2 移除元素2.1 暴力求解2.2 双指针法(快慢指针) 🙋‍♂️ 作者:海码007📜 专栏:算法专栏💥 标题:代码随想录算法训练营第一天…

Vue.js+SpringBoot开发大病保险管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 系统配置维护2.2 系统参保管理2.3 大病保险管理2.4 大病登记管理2.5 保险审核管理 三、系统详细设计3.1 系统整体配置功能设计3.2 大病人员模块设计3.3 大病保险模块设计3.4 大病登记模块设计3.5 保险审核模块设计 四、…

MySQL三种日志

一、undo log(回滚日志) 1.作用: (1)保证了事物的原子性 (2)通过read view和undo log实现mvcc多版本并发控制 2.在事务提交前,记录更新前的数据到undo log里,回滚的时候读…

数据可视化助力林业智能管理

数据可视化是当下科技发展中的一项重要工具,它在各行各业都展现了强大的应用价值。在智慧林业领域,数据可视化更是发挥了独特的作用,为林业管理和生态保护提供了有效的支持和解决方案。下面我就以可视化从业者的角度,来简单聊聊这…

四节点/八节点四边形单元悬臂梁Matlab有限元编程 | 平面单元 | Matlab源码 | 理论文本

专栏导读 作者简介:工学博士,高级工程师,专注于工业软件算法研究本文已收录于专栏:《有限元编程从入门到精通》本专栏旨在提供 1.以案例的形式讲解各类有限元问题的程序实现,并提供所有案例完整源码;2.单元…

关于yolov8的DFL模块(pytorch以及tensorrt)

先看代码 class DFL(nn.Module):"""Integral module of Distribution Focal Loss (DFL).Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391"""def __init__(self, c116):"""Initialize a convo…

嵌入式C语言(六)

对齐这个事情在内核中可不是个什么小事,内核中涉及到内存方面的都需要非常的谨慎。 上一篇我们知道了可以通过__attribute__来声明属性,也知道了section这个属性,这篇我们来看看关于内存对齐使用的两个属性–>aligned和packed 地址对齐&…

Altium Designer如何对走线模式进行切换

AD软件提供了比较智能的走线模式切换功能,可以根据个人习惯进行切换,能有效的提高了PCB设计效率。 点击界面右上角系统参数的图标 或者在pcb界面中使用快捷键OP进入到优选项界面,然后选中 PCB Editor-Interactive Routing,在布线…

ubuntu設定QGC獲取pixhawk Mini4(PX4 Mini 4) 的imu信息

ubuntu20.04 QGC使用v4.3.0的版本 飛控pixhawk Mini4 飛控上只使用一條micro USB連接電腦,沒有其他線 安裝命令 sudo apt-get remove modemmanager -y sudo apt install gstreamer1.0-plugins-bad gstreamer1.0-libav gstreamer1.0-gl -y sudo apt install libf…

Vue:纯前端实现文件拖拽上传

先看一下拖拽相关的事件:dragover、dragenter drop和dragleave 。 dragover事件:当被拖动的元素在一个可放置目标上方时,该事件会被触发。 通常,我们会使用event.preventDefault()方法来取消浏览器默认的拖放行为,以便…

amv是什么文件格式?如何播放amv视频?

AMV文件格式源自于中国公司Actions Semiconductor,最初作为其MP4播放器中使用的专有视频格式。产生于数码媒体发展的需求下,AMV格式为小屏幕便携设备提供了一种高度压缩的视频存储方案。 AMV文件格式的主要特性与使用场景 AMV格式以其独特的特性在小尺寸…

复合式统计图绘制方法(7)

复合式统计图绘制方法(7) 常用的统计图有条形图、柱形图、折线图、曲线图、饼图、环形图、扇形图。 前几类图比较容易绘制,饼图环形图绘制较难。 在统计图的应用方面,有时候有两个关联的统计学的样本值要用统计图来表达&#xff0…

运动想象 (MI) 迁移学习系列 (5) : SSMT

运动想象迁移学习系列:SSMT 0. 引言1. 主要贡献2. 网络结构3. 算法4. 补充4.1 为什么设置一种新的适配器?4.2 动态加权融合机制究竟是干啥的? 5. 实验结果6. 总结欢迎来稿 论文地址:https://link.springer.com/article/10.1007/s11517-024-0…

天府锋巢直播产业基地:直播带岗,成都直播基地奔向产业化

天府锋巢直播产业基地位于成都市天府新区科学城板块,是一座集直播带岗、电商孵化、产业培训、供应链整合等多功能于一体的现代化全域直播产业基地。近年来,随着成都直播产业的蓬勃发展,成都积极响应市场需求,致力于打造出西部地区…

linux进程间通信-共享内存

一、共享内存是什么 在Linux系统中,共享内存是一种IPC(进程间通信)方式,它可以让多个进程在物理内存中共享一段内存区域。 这种共享内存区域被映射到多个进程的虚拟地址空间中,使得多个进程可以直接访问同一段物理内存…

【Python可视化系列】一文教你绘制雷达图(源码)

这是我的第234篇原创文章。 一、引言 雷达图是以从同一点开始的轴上表示的三个或更多个定量变量的二维图表的形式显示多变量数据的图形方法,也称为蜘蛛图或星形图。雷达图通常用于综合分析多个指标,具有完整,清晰和直观的优点。通常由多个等…

Constrained Iterative LQR 自动驾驶中使用的经典控制算法

Motion planning 运动规划在自动驾驶领域是一个比较有挑战的部分。它既要接受来自上层的行为理解和决策的输出,也要考虑一个包含道路结构和感知所检测到的所有障碍物状态的动态世界模型。最终生成一个满足安全性和可行性约束并且具有理想驾驶体验的轨迹。 通常,motion plann…

遥感影像植被波谱特征总结

植被跟太阳辐射的相互关系有别于其他物质&#xff0c;如裸土、水体等&#xff0c;比如植被的“红边”现象&#xff0c;即在<700nm附近强吸收&#xff0c;>700nm高反射。很多因素影响植被对太阳辐射的吸收和反射&#xff0c;包括波长、水分含量、色素、养分、碳等。 研究…