Transformers 加速的一些常用技巧

Transformers 是一个强大的架构,但模型因其采用的自注意力机制,虽然能够有效地处理序列数据并捕获长距离依赖关系,但同时也容易导致在训练过程中出现OOM(Out of Memory,内存不足)或者达到GPU的运行时限制。

主要是因为

  1. 参数数量庞大:Transformer模型通常包含大量的参数,尤其是在模型层面进行扩展时(例如,增加层数或头数)。这些参数需要大量的内存来存储权重和梯度。
  2. 自注意力计算:自注意力机制需要对输入序列的每个元素与其他所有元素计算其相互关系,导致计算复杂度和内存需求随着输入长度的增加而显著增加。对于非常长的序列,这一点尤其突出。
  3. 激活和中间状态存储:在训练过程中,需要存储前向传播中的中间激活状态,以便于反向传播时使用。这增加了额外的内存负担。

为了解决这些问题,我们今天来总结以下一些常用的加速策略

固定长度填充

在处理文本数据时,由于文本序列的长度可能各不相同,但许多机器学习模型(尤其是基于Transformer的模型)需要输入数据具有固定的尺寸,因此需要对文本序列进行固定长度填充(padding)。

在使用Transformer模型时,填充部分不应影响到模型的学习。因此通常需要使用注意力掩码(attention mask)来指示模型在自注意力计算时忽略这些填充位置。通过这种固定长度填充和相应的处理方法,可以使得基于Transformer的模型能够有效地处理不同长度的序列数据。在实际应用中,这种方法是处理文本输入的常见策略。

 def fixed_pad_sequences(sequences, max_length, padding_value=0):
     padded_sequences = []
     for sequence in sequences:
         if len(sequence) >= max_length:
             padded_sequence = sequence[:max_length]  # Trim the sequence if it exceeds max_length
         else:
             padding = [padding_value] * (max_length - len(sequence))  # Calculate padding
             padded_sequence = sequence + padding  # Pad the sequence
         padded_sequences.append(padded_sequence)
     return padded_sequences

这种方式会将所有的序列填充成一个长度,这样虽然长度相同了,但是因为序列的实际大小本来就不同,同一批次很可能出现有很多填充的情况,所以就出现了动态填充策略。

动态填充是在每个批处理中动态填充输入序列到最大长度。与固定长度填充不同,在固定长度填充中,所有序列都被填充以匹配整个数据集中最长序列的长度,动态填充根据该批中最长序列的长度单独填充每个批中的序列。

这样虽然每个批次的长度是不同的,但是批次内部的长度是相同的,可以加快处理速度。

 def pad_sequences_dynamic(sequences, padding_value=0):
     max_length = max(len(seq) for seq in sequences)  # Find the maximum length in the sequences
     padded_sequences = []
     for sequence in sequences:
         padding = [padding_value] * (max_length - len(sequence))  # Calculate padding
         padded_sequence = sequence + padding  # Pad the sequence
         padded_sequences.append(padded_sequence)
     return padded_sequences

等长匹配

等长匹配是在训练或推理过程中将长度相近的序列分组成批处理的过程。等长匹配通过基于序列长度将数据集划分为桶,然后从这些桶中采样批次来实现的。

从上图可以看到,通过等长匹配的策略,减少了填充量,这样也可以加速计算

 def uniform_length_batching(sequences, batch_size, padding_value=0):
     # Sort sequences based on their lengths
     sequences.sort(key=len)
     
     # Divide sequences into buckets based on length
     buckets = [sequences[i:i+batch_size] for i in range(0, len(sequences), batch_size)]
     
     # Pad sequences within each bucket to the length of the longest sequence in the bucket
     padded_batches = []
     for bucket in buckets:
         max_length = len(bucket[-1])  # Get the length of the longest sequence in the bucket
         padded_bucket = []
         for sequence in bucket:
             padding = [padding_value] * (max_length - len(sequence))  # Calculate padding
             padded_sequence = sequence + padding  # Pad the sequence
             padded_bucket.append(padded_sequence)
         padded_batches.append(padded_bucket)
     
     return padded_batches

自动混合精度

自动混合精度(AMP)是一种通过使用单精度(float32)和半精度(float16)算法的组合来加速深度学习模型训练的技术。它利用了现代gpu的功能,与float32相比,使用float16数据类型可以更快地执行计算,同时使用更少的内存。

 import torch
 from torch.cuda.amp import autocast, GradScaler
 
 # Define your model
 model = YourModel()
 
 # Define optimizer and loss function
 optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
 criterion = torch.nn.CrossEntropyLoss()
 
 # Create a GradScaler object for gradient scaling
 scaler = GradScaler()
 
 # Inside the training loop
 for inputs, targets in dataloader:
     # Clear previous gradients
     optimizer.zero_grad()
     
     # Cast inputs and targets to the appropriate device
     inputs, targets = inputs.to(device), targets.to(device)
     
     # Enable autocasting for forward pass
     with autocast():
         # Forward pass
         outputs = model(inputs)
         loss = criterion(outputs, targets)
     
     # Backward pass
     # Scale the loss value
     scaler.scale(loss).backward()
     
     # Update model parameters
     scaler.step(optimizer)
     
     # Update the scale for next iteration
     scaler.update()

AMP在训练过程中动态调整计算精度,允许模型在大多数计算中使用float16,同时自动将某些计算提升为float32,以防止下流或溢出等数值不稳定问题。

Fp16 vs Fp32

双精度(FP64)消耗64位。符号值为1位,指数值为11位,有效精度为52位。

单精度(FP32)消耗32位。符号值为1位,指数值为8位,有效精度为23位。

半精度(FP16)消耗16位。符号值为1位,指数值为5位,有效精度为10位。

所以Fp16可以提高内存节省,并可以大大提高模型训练的速度。考虑到Fp16的优势和它在模型使用方面的主导区域,它非常适合推理任务。但是fp16会产生数值精度的损失,导致计算或存储的值不准确,考虑到这些值的精度至关重要。

另外就是这种优化师针对于分类任务的,对于回归这种需要精确数值的任务Fp16的表现并不好。

总结

以上这些方法,可以在一定程度上缓解内存不足和计算资源的限制,但是对于大型的模型我们还是需要一个强大的GPU。

https://avoid.overfit.cn/post/7240bee210cd408a90ca04279830040e

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/619826.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

RS8521XF功能和参数介绍及PDF资料

RS8521XF功能和参数介绍及PDF资料-公司新闻-配芯易-深圳市亚泰盈科电子有限公司 品牌: RUNIC(润石) 封装: SOT-23-5 描述: 零漂移,1uV Vos,350KHz,1通道运放 放大器数: 单路 增益带宽积(GBP): 350kHz 输入偏置电流(Ib): 1pA 压摆率(SR): 0.17V/us RS8521XF 是一款高性能的运算…

OBS插件--图片或视频源随着声音动态缩放

音效动态调整 应用此插件的源可以根据音效声音的高低进行动态的缩放。在本例中,我们在当前的场景里面添加了一个小喇叭的图片源,在这个图片源上引用这款滤镜插件,然后将VLC视频源的音效用于此插件的音效。设置完成后,场景中的小喇…

基于GD32的简易数字示波器(4)- 软件_GD32的keil5环境和串口下载

这期记录的是项目实战,做一个简易的数字示波器。 教程来源于嘉立创,帖子主要做学习记录,方便以后查看。 本期主要介绍GD32的keil5环境和串口下载。详细教程可观看下方链接。 软件-第1讲-工程模板新建_哔哩哔哩_bilibili 2.1 开发环境搭建 …

FebHost:什么是新西兰.NZ域名?

.nz域名是新西兰的顶级域名(TLD)。它是专门分配给新西兰的国家代码顶级域(ccTLD)。以 .nz 结尾的域名用于标识与新西兰或其居民相关的网站、电子邮件地址和其他在线资源。 .nz 域名由新西兰域名委员会管理。负责监督 .nz 域名的注…

[AutoSar]BSW_Diagnostic_005 bootloader跳转流程和配置

目录 关键词平台说明一、调度时序图总览二、P2*Server 和P2Server三、配置四、bootloader 设计 关键词 嵌入式、C语言、autosar、OS、BSW、UDS、diagnostic 平台说明 项目ValueOSautosar OSautosar厂商vector , EB芯片厂商TI 英飞凌编程语言C,C编译器…

2023愚人杯 )————被遗忘的反序列化

<?php# 当前目录中有一个txt文件哦 error_reporting(0); show_source(__FILE__); include("check.php");class EeE{public $text;public $eeee;public function __wakeup(){if ($this->text "aaaa"){echo lcfirst($this->text);}}public functi…

ansible -playbook运维工具、语法、数据结构、命令用法、触发器、角色

目录 配置文件 基本语法规则&#xff1a; YAML支持的数据结构 playbook核心元素 ansible-playbook用法&#xff1a; 触发器 特点&#xff1a; 角色&#xff1a; 习题&#xff1a; 配置文件 playbook配置文件使用yaml语法&#xff0c;YAML 是一门标记性语言,专门用来写配…

目标检测——印度车辆数据集

引言 亲爱的读者们&#xff0c;您是否在寻找某个特定的数据集&#xff0c;用于研究或项目实践&#xff1f;欢迎您在评论区留言&#xff0c;或者通过公众号私信告诉我&#xff0c;您想要的数据集的类型主题。小编会竭尽全力为您寻找&#xff0c;并在找到后第一时间与您分享。 …

C++中的complex

在 C 中&#xff0c;std::complex 是一个模板类&#xff0c;用于表示和操作复数。这个类是标准模板库&#xff08;STL&#xff09;的一部分&#xff0c;包含在 头文件中。std::complex 提供了一套丰富的功能&#xff0c;包括基本的算术运算、比较运算、数学函数等&#xff0c;使…

【微信小程序开发】flex布局在小程序开发项目中的应用详解

✨✨ 欢迎大家来到景天科技苑✨✨ &#x1f388;&#x1f388; 养成好习惯&#xff0c;先赞后看哦~&#x1f388;&#x1f388; &#x1f3c6; 作者简介&#xff1a;景天科技苑 &#x1f3c6;《头衔》&#xff1a;大厂架构师&#xff0c;华为云开发者社区专家博主&#xff0c;…

Hive JSON数据处理

Hive JSON数据处理 JSON&#xff08;JavaScript Object Notation&#xff09;文件格式是一种轻量级的数据交换格式&#xff0c;用于存储和传输结构化的数据。它基于JavaScript的语法&#xff0c;但是可以被多种编程语言所支持和解析&#xff0c;因此被广泛应用于各种场景。 J…

Spring Security 复盘

1、什么Spring Security&#xff1f; Spring Security 是一种强大的框架&#xff0c;它在 Spring 生态系统中扮演着保护应用安全的关键角色。Spring Security 基于 Spring 框架&#xff0c;提供了一套 Web 应用安全性的完整解决方案。 2、认证 和 授权 1.什么是认证&#xff1…

Python 全栈系列244 nginx upstream 负载均衡 踩坑日记

说明 最初是因为租用算力机(Python 全栈系列242 踩坑记录:租用算力机完成任务)&#xff0c;所以想着做一个负载均衡&#xff0c;然后多开一些服务&#xff0c;把配置写在nginx里面就好了。 一开始租用了一个3080起了一个服务&#xff0c;后来觉得速度不够快&#xff0c;再起了…

如何解决IntelliJ IDEA中pom.xml依赖项引发的安全漏洞黄线警告问题

背景 在开发过程中&#xff0c;当我们在pom.xml文件中添加依赖项时&#xff0c;经常会发现IntelliJ IDEA报出黄色警告线条&#xff0c;提示存在潜在的安全漏洞。警告的具体展现形式如下&#xff1a; 解决方案 首先&#xff0c;打开设置菜单界面&#xff0c;接着选择编辑器选…

【图解计算机网络】http1.1,http2.0,http3.0

http1.1&#xff0c;http2.0&#xff0c;http3.0 http1.1长连接管道传输缺点 http2.0头部压缩二进制格式并发传输服务端推送缺点 http3.0无队头阻塞快速建立连接连接迁移 http1.1 长连接 在http1.0的时候&#xff0c;一次http请求就要建立一次TCP连接&#xff0c;这一次的htt…

Vue3项目Easy云盘(二):文件列表+新建目录+文件重命名+文件上传

一、文件列表 1.封装全局组件Table.vue 因为Main.vue等都会用到文件列表table&#xff0c;所以直接封装成组件。 src/components/Table.vue <template><!-- 表格 --><div><el-tableref"dataTable":data"dataSource.list || []":h…

基于AIoTedge+ThingsKit物联网平台,实现办公室人员进出AI统计

在AIoT时代&#xff0c;智能办公已成为提升企业效率的关键。本期文章将带你了解如何利用AIoTedge结合ThingsKit物联网平台&#xff0c;实现办公室人员进出的智能统计。这不是简单的技术堆砌&#xff0c;而是一场关于AI与IoT融合的实战演示。&#x1f31f; 提示&#xff1a;AIoT…

【Linux】Centos7安装部署unimrcp,搭建MRCP服务器

yum install libtool yum install libtool-ltdl-devel yum install libsofia-sip-ua find / -name libsofia-sip-ua.so.0 2>/dev/null # 设置环境变量&#xff1a;如果库文件存在但不在默认搜索路径中&#xff0c;你可以通过设置 LD_LIBRARY_PATH 环境变量来告诉系统在哪…

Github学习

1.Git与Github 区别: Git是一个分布式版本控制系统&#xff0c;简单的说就是一个软件&#xff0c;用于记录一个或若干个文件内容变化&#xff0c;以便将来查阅特点版本修订情况的软件。 Github是一个为用户提高Git服务的网站&#xff0c;简单说就是一个可以放代码的地方。Gi…

数字化社会的引擎:揭示Facebook的影响力

在当今数字化社会中&#xff0c;社交媒体平台扮演着至关重要的角色&#xff0c;而Facebook作为其中的巨头之一&#xff0c;其影响力不可忽视。本文将深入探讨Facebook的影响力&#xff0c;从多个角度揭示其在数字化社会中的引擎作用。 1. 社交互动的核心平台 Facebook作为社交…