使用deepspeed小记

1. 减少显存占用的历程忠告

医学图像经常很大,所以训练模型有时候会有难度,但是现在找到了很多减少显存的方法。
不知道为什么,使用transformers的trainer库确确实实会减少显存的占用,即使没有使用deepspeed,占用的显存也会减少。

别自己造轮子
我之前也使用过 LoRA,自己也设计过,非常非常建议千万不要自己去写LoRA,很浪费时间,设计很费时间,同时检验模型LoRA的有效性也很浪费时间,权重的融合也很浪费时间,尽量使用其他已经写好的LoRA。

我推荐使用transformers集成模型和训练集,只需要写一个dataset和collate_fn,最多再多写一个Trainer的computer_loss,模型就可以自然而然的搞定。效率最高最有效。

2. Deepspeed方便快捷

在这里插入图片描述
使用 deepspeed 的流程是最短的

2.1 如果warning,需要加载一些库

moudle ava
moudle load compiler/gcc/7.3.1
moudle load cuda/7/11.8

由于deepspeed进行编译实际上是将GPU的一些指令重新编译,让CPU执行,同时还要符合CUDA的计算结构,能和GPU交互,所以GCC编译,CUDA编译都要符合版本要求

2.2 编写Trainer的python文件

建议使用transformers的trainer函数,这样很多json文件可以直接设置auto,同时还方便指定json配置文件。
同时要注意,这里可能会要求你加入 args,设置一个 local_rank 全局管控。
TrainingArguments 指定 ds_config.json 文件

import argparse
import sys
def parse_agrs():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=-1, help="Local rank. Necessary for using the torch.distributed.launch utility.")
    return args

args = parse_agrs()

training_args = TrainingArguments(
    output_dir='./checkpoint/Eff_R2GenCMN_base',
    num_train_epochs=1000,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=10,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./checkpoint/Eff_R2GenCMN_base/output_logs',
    logging_steps=10,
    save_strategy='steps',  # 添加保存策略为每一定步骤保存一次
    save_steps=100,  # 每100步保存一次模型
    save_total_limit=5,  # 最多保存5个模型
    report_to="none",
    fp16=True,  # 启用混合精度训练
    deepspeed='./ds_config.json',
)

tokenizer = Tokenizer()
args = parse_agrs()
model = R2GenCMN(args, tokenizer)
dataset_train = Dataset(xlsx_file="./dataset/train_dataset.xlsx")
dataset_test = Dataset(xlsx_file="./dataset/test_dataset.xlsx")
trainer = MyTrainer(
    model=model,  # 使用的模型
    args=training_args,  # 训练参数
    train_dataset=dataset_train,  # 训练数据集
    eval_dataset=dataset_test,  # 验证数据集
    data_collator=collate_fn,
    # 可能需要定义compute_metrics函数来计算评估指标
)

2.3 编写ds_config文件

编写ds_config文件的目的就是简介python文件,同时更改参数方便,减少大脑记忆负担,便于使用。
ds_config.json 文件脚本通常是 通用的, batch如果写auto,deepspeed会根据显卡给你 自动设置batch 大小
这里只是设置了

stage2的

{
    "bfloat16": {
        "enabled": false
    },
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 2e8,
        "contiguous_gradients": true
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "steps_per_print": 1e5
}

或者使用stage3

{
    "bfloat16": {
        "enabled": false
    },
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_fp16_weights_on_model_save": true
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 1e5,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}

2.4 运行程序

最终deepspeed运行就可以了
这里的warning实际上没有影响模型的运行,是重新编译。

deepspeed train.py
[2024-04-02 12:04:43,112] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-04-02 12:05:48,493] [WARNING] [runner.py:196:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
[2024-04-02 12:05:48,493] [INFO] [runner.py:555:main] cmd = /public/home/v-yumy/anaconda3/envs/llava2/bin/python -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMF19 --master_addr=127.0.0.1 --master_port=29500 --enable_each_rank_log=None transformer_train.py
[2024-04-02 12:05:51,627] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-04-02 12:05:55,944] [INFO] [launch.py:145:main] WORLD INFO DICT: {'localhost': [0]}
[2024-04-02 12:05:55,944] [INFO] [launch.py:151:main] nnodes=1, num_local_procs=1, node_rank=0
[2024-04-02 12:05:55,944] [INFO] [launch.py:162:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0]})
[2024-04-02 12:05:55,944] [INFO] [launch.py:163:main] dist_world_size=1
[2024-04-02 12:05:55,944] [INFO] [launch.py:165:main] Setting CUDA_VISIBLE_DEVICES=0
[2024-04-02 12:06:29,136] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-04-02 12:06:31,519] [WARNING] [comm.py:152:init_deepspeed_backend] NCCL backend in DeepSpeed not yet implemented
[2024-04-02 12:06:31,519] [INFO] [comm.py:594:init_distributed] cdb=None
[2024-04-02 12:06:31,519] [INFO] [comm.py:625:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.742 seconds.
Prefix dict has been built successfully.
EfficientNet: replace first conv
EncoderDecoder 的Transformer 是 base
EncoderDecoder 是 base
视觉特征,不进行预训练
 [WARNING]  cpu_adam cuda is missing or is incompatible with installed torch, only cpu ops can be compiled!
Using /public/home/v-yumy/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...
Emitting ninja build file /public/home/v-yumy/.cache/torch_extensions/py310_cu117/cpu_adam/build.ninja...
Building extension module cpu_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module cpu_adam...
Time to load cpu_adam op: 0.7046074867248535 seconds
Rank: 0 partition count [1] and sizes[(42770360, False)] 
{'loss': 6.7285, 'learning_rate': 1.6730270909663467e-05, 'epoch': 0.02}                                                                                                                                   
{'loss': 6.0535, 'learning_rate': 2.3254658315702903e-05, 'epoch': 0.05}                                                                                                                                   
{'loss': 5.598, 'learning_rate': 2.6809450068309278e-05, 'epoch': 0.07}                                                                                                                                    
{'loss': 5.2824, 'learning_rate': 2.9266416338062584e-05, 'epoch': 0.1}                                                                                                                                    
{'loss': 5.0738, 'learning_rate': 3.114597855245884e-05, 'epoch': 0.12}                                                                                                                                    
{'loss': 4.8191, 'learning_rate': 3.266853634404809e-05, 'epoch': 0.15}                                                                                                                                    
{'loss': 4.5336, 'learning_rate': 3.3948300828875964e-05, 'epoch': 0.17}       

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/512918.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【QT学习】4.浮动窗口

结果&#xff1a; 代码&#xff1a; //制作核心控件&#xff1a;文本编辑框QTextEdit* pTextEditnew QTextEdit;//制作浮动控件connect(pMenu1,&QMenu::triggered,[](QAction* pAction){qDebug()<<pAction->text()<<endl;if(pAction->text()"浮动…

若依框架时间比较的坑(DATE_FORMAT)

背景 - 想做生日的比较 若依自带的比较 <if test"params.beginTime ! null and params.beginTime ! "><!-- 开始时间检索 -->AND date_format(u.create_time,%y%m%d) > date_format(#{params.beginTime},%y%m%d)</if><if test"params…

XXLJob中GLUE模式实现在线编写java/shell/python/php/nodejs/powerShell---SpringCloud工作笔记202

1.起因: 之前就一直想实现类似的功能,今天总于找到有可以参考的东西了,这个思路可以帮助实现这种功能. 2.获得灵感 就是:我想实现通过在线编写代码,来扩展我们平台的能力,这样随着业务的扩展,不用我们每次都修改了代码,再去部署,这样就比较麻烦,今天偶尔发现,对于xxljob来说.有…

QA测试开发工程师面试题满分问答8: mysql数据库的索引定义、用途和使用场景

MySQL数据库索引是一种数据结构&#xff0c;用于提高数据库的查询效率。索引是基于表中的一个或多个列构建的&#xff0c;它们允许数据库系统快速定位和访问表中的特定数据&#xff0c;而无需扫描整个表。 索引的定义 在MySQL中&#xff0c;可以使用CREATE INDEX语句定义索引…

计算机网络|谢希仁版|数据链路层

数据链路层 数据链路层研究的是什么&#xff1f;数据链路层的几个共同问题数据链路与链路帧通信规程 三个基本问题封装成帧透明传输差错检测可靠传输 点对点协议PPPPPP协议应满足的需求PPP协议的组成PPP协议帧的格式各字段的意义字节填充零比特填充PPP协议的工作状态 使用广播信…

TSINGSEE青犀推出河道/河湖/水域治理视频AI智能解决方案

一、方案背景 “十四五”时期&#xff0c;在面源污染防治等方面实现突破&#xff0c;实现主要水污染排放总量持续减少&#xff0c;水生态环境持续改善等任务艰巨。进一步完善流域综合治理体系&#xff0c;提升流域水环境综合治理能力和水平&#xff0c;更好适应新阶段发展需求…

【Unity灶台】食品加工系统模型搭建

&#x1f468;‍&#x1f4bb;个人主页&#xff1a;元宇宙-秩沅 &#x1f468;‍&#x1f4bb; hallo 欢迎 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍&#x1f4bb; 本文由 秩沅 原创 &#x1f468;‍&#x1f4bb; 收录于专栏&#xff1a;uni…

去班味的尽头是风险管理

运维工程师的“班味”是从风险管理就加重的。 什么是班味呢&#xff1f;指的是打工人身上特有的疲惫气质&#xff0c;面色憔悴、双目无神和腰酸背痛都是“班味”的显著表现。习惯性回复“收到&#xff0c;马上来”、不自觉唉声叹气、下班也提不起精神等症状&#xff0c;则说明…

CTF下加载CTFtraining题库以管理员身份导入 [HCTF 2018]WarmUp,之后以参赛者身份完成解题全过程

-------------------搭建CTFd------------------------------ 给大家介绍一个本地搭建比较好用的CTF比赛平台&#xff1a;CTFD。 CTFd是一个Capture The Flag框架&#xff0c;侧重于易用性和可定制性。它提供了运行CTF所需的一切&#xff0c;并且可以使用插件和主题轻松进行自…

手写SpringBoot(二)之动态切换Servlet容器

系列文章目录 手写SpringBoot&#xff08;一&#xff09;之简易版SpringBoot 手写SpringBoot&#xff08;二&#xff09;之动态切换Servlet容器 手写SpringBoot&#xff08;三&#xff09;之自动配置 手写SpringBoot&#xff08;四&#xff09;之bean动态加载 手写SpringBoot&…

如何使用VNC+Cpolar实现Windows电脑公网远程控制Ubuntu系统桌面

文章目录 前言1. VisualSVN安装与配置2. VisualSVN Server管理界面配置3. 安装cpolar内网穿透3.1 注册账号3.2 下载cpolar客户端3.3 登录cpolar web ui管理界面3.4 创建公网地址 4. 固定公网地址访问 前言 SVN 是 subversion 的缩写&#xff0c;是一个开放源代码的版本控制系统…

生存分析笔记

生存分析&#xff08;英语&#xff1a;Survival analysis&#xff09;是指根据试验或调查得到的数据对生物或人的生存时间进行分析和推断&#xff0c;研究生存时间和结局与众多影响因素间关系及其程度大小的方法&#xff0c;也称生存率分析或存活率分析&#xff0c;例如生物有机…

2023年第十四届蓝桥杯 - 省赛 - C/C++大学A组 - B.有奖问答

Idea 一共 30 道题&#xff0c;得分情况为 0 ~ 100 分。 创建一个 30 行 100 列的 dp 数组&#xff0c;dp[i][j] 表示做完第 i 题&#xff0c;得分为 j 的方案数。 Code Python dp [[0 for _ in range(100)] for _ in range(31)] # dp[i][j] 表示做完第 i 题得分为 j 的…

Scala第十八章节(Iterable集合、Seq集合、Set集合、Map集合以及统计字符个数案例)

Scala第十八章节 章节目标 掌握Iterable集合相关内容.掌握Seq集合相关内容.掌握Set集合相关内容.掌握Map集合相关内容.掌握统计字符个数案例. 1. Iterable 1.1 概述 Iterable代表一个可以迭代的集合, 它继承了Traversable特质, 同时也是其他集合的父特质. 最重要的是, 它定…

UE4_破碎插件的蓝图节点_Apply Radius Damage

一、知识点 Apply Radius Damage:破碎组件所带的蓝图节点。 二、使用方法&#xff1a; 1、设置——插件&#xff0c;搜索destruction&#xff0c;找到 Apex Destruction&#xff0c;勾选已启用。重启虚幻编辑器。 2、这样右键操作就有创建可破坏的网格体菜单&#xff0c;将do…

基于51单片机多路温度检测—8路、4路

基于51单片机多路温度检测 &#xff08;仿真&#xff0b;程序&#xff0b;原理图&#xff09; 功能介绍 具体功能&#xff1a; 1.通过对多路DS18B20温度传感器的数据采集&#xff1b; 2.温度采集并将数值显示在LCD显示屏上&#xff1b; 3.可通过按键设置温度报警值&#xff…

如果查看iPhone的GPU

摘要 了解你的显卡对于在电脑上玩现代图形要求高的游戏非常重要。本文介绍了如何轻松查看你的显卡型号以及为什么显卡在玩电脑游戏时如此关键。 引言 随着电脑游戏的发展&#xff0c;现代游戏对硬件性能的要求越来越高。十年前发布的显卡已经无法满足当前游戏的需求。因此&…

《书生·浦语大模型全链路开源开放体系》学习笔记

书生浦语大模型全链路开源开放体系-学习笔记 大模型成为发展通用人工智能的重要途径专用模型通用大模型 书生大模型开源历程InternLM2回归语言建模的本质主要亮点性能全方位提升强大的内生计算能力 从模型到应用典型流程全链条开源开放体系数据数据集获取预训练微调XTuner 评测…

函数参数缺省和内联函数【C++】

文章目录 函数参数缺省函数参数缺省的条件和要求 内联函数内联函数的工作原理内联函数的定义方法内联函数的要求解决方法&#xff1a;直接在.h中定义内联函数的函数体 内联函数再Debug模式下默认是不展开的 函数参数缺省 顾名思义&#xff1a;可以少传一个/多个参数给函数&…

基于springboot+vue+Mysql的企业客户信息反馈平台

开发语言&#xff1a;Java框架&#xff1a;springbootJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包&#xff1a;…