Llama改进之——SwiGLU激活函数

引言

今天介绍LLAMA模型引入的关于激活函数的改进——SwiGLU1,该激活函数取得了不错的效果,得到了广泛地应用。

SwiGLU是GLU的一种变体,其中包含了GLU和Swish激活函数。

GLU

GLU(Gated Linear Units,门控线性单元)2引入了两个不同的线性层,其中一个首先经过sigmoid函数,其结果将和另一个线性层的输出进行逐元素相乘作为最终的输出:
GLU ( x , W , V , b , c ) = σ ( x W + b ) ⊗ ( x V + c ) (1) \text{GLU}(x,W,V,b,c) = \sigma(xW+b) \otimes (xV+c) \tag 1 GLU(x,W,V,b,c)=σ(xW+b)(xV+c)(1)
这里 W , V W,V W,V以及 b , c b,c b,c分别是这两个线性层的参数; σ ( x W + b ) \sigma(xW+b) σ(xW+b)作为门控,控制 x V + c xV+c xV+c的输出。

这里使用 σ \sigma σ作为激活函数,修改改激活函数得到的变体通常能带来更好的性能表现,比如SwiGLU修改激活函数为Swish。我们来看下Swish激活函数。

Swish

Swish3激活函数的形式为:
Swish β ( x ) = x σ ( β x ) (2) \text{Swish}_\beta(x) = x \sigma(\beta x) \tag 2 Swishβ(x)=xσ(βx)(2)
其中 σ ( x ) \sigma(x) σ(x)是Sigmoid函数; β \beta β是一个可学习的参数。

可以通过下面的代码画出Swish激活函数在不同参数 β \beta β下的图像:

import numpy as np
import matplotlib.pyplot as plt

def swish(x, beta):
  return x / (1 + np.exp(-beta*x))

x = np.linspace(-10, 10, 100)
betas = [0.1, 1.0, 10.0]

plt.figure(figsize=(10, 6))

for beta in betas:
    y = swish(x, beta)
    plt.plot(x, y, label=f'beta={beta}')

plt.legend()
plt.title('Swish Activation Function')
plt.xlabel('x')
plt.ylabel('f(x)')
plt.grid(True)
plt.show()

image-20240428224729925

可以看到3,当 β \beta β趋近于 0 0 0时,Swish函数趋近于线性函数 y = x 2 y=x^2 y=x2;当 β \beta β趋近于无穷大时,Swish函数趋近于ReLU函数;当 β \beta β取值为 1 1 1时,Swish函数是光滑且非单调的,等价于参考4中介绍的SiLU。

Swish与ReLU之间最显著的区别是当 x < 0 x < 0 x<0时Swish的非单调“凸起”3

SwiGLU

如前文所述,将公式(1)中GLU的激活函数改为Swish即变成了所谓的SwiGLU激活函数1
SwiGLU ( x , W , V ) = Swish β ( x W ) ⊗ ( x V ) (3) \text{SwiGLU}(x,W,V) = \text{Swish}_\beta(xW) \otimes (xV) \tag{3} SwiGLU(x,W,V)=Swishβ(xW)(xV)(3)
这里省略了偏置项。

代码实现

参考LLaMA,全连接层使用带有SwiGLU激活函数的FFN(Position-wise Feed-Forward Network)的公式如下1
FFN SwiGLU ( x , W , V , W 2 ) = ( Swish 1 ( x W ) ⊗ x V ) W 2 (4) \text{FFN}_{\text{SwiGLU}}(\pmb x,W,V,W_2) = (\text{Swish}_1(\pmb xW) \otimes \pmb xV)W_2 \tag 4 FFNSwiGLU(x,W,V,W2)=(Swish1(xW)xV)W2(4)
这里的Swish函数可以被SiLU函数替代:
SiLU ( x ) = x σ ( x ) \text{SiLU}(\pmb x) = \pmb x \sigma(\pmb x) SiLU(x)=xσ(x)
即:
FFN SwiGLU ( x , W , V , W 2 ) = ( SiLU ( x W ) ⊗ x V ) W 2 (5) \text{FFN}_{\text{SwiGLU}}(\pmb x,W,V,W_2) = (\text{SiLU}(\pmb xW) \otimes \pmb xV)W_2 \tag 5 FFNSwiGLU(x,W,V,W2)=(SiLU(xW)xV)W2(5)

import torch
from torch import nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int) -> None:
       	super().__init__()

        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, seq_len, hidden_size)
        # w1(x) -> (batch_size, seq_len, intermediate_size)
        # w1(x) -> (batch_size, seq_len, intermediate_size)
        # w2(*) -> (batch_size, seq_len, hidden_size)
    	return self.w2(F.silu(self.w1(x)) * self.w3(x))
            

这里w1,w2,w3分别对应公式(5)中的 W , W 2 , V W,W_2,V W,W2,V

注意维度,其中w1,w3x转换到维度intermediate_size,然后w2转换回hidden_size

参考


  1. [论文翻译]GLU Variants Improve Transformer ↩︎ ↩︎ ↩︎

  2. [论文笔记]Language Modeling with Gated Convolutional Networks ↩︎

  3. [论文笔记]SEARCHING FOR ACTIVATION FUNCTIONS ↩︎ ↩︎ ↩︎

  4. [论文笔记]GAUSSIAN ERROR LINEAR UNITS (GELUS) ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/592675.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Linux(openEuler、CentOS8)常用的IP修改方式(文本配置工具nmtui+配置文件+nmcli命令)

----本实验环境为openEuler系统<以server方式安装>&#xff08;CentOS类似&#xff0c;可参考本文&#xff09;---- 一、知识点 &#xff08;一&#xff09;文本配置工具nmtui(openEuler已预装) nmtui&#xff08;NetworkManager Text User Interface&#xff09;是一…

ZooKeeper以及DolphinScheduler的用法

目录 一、ZooKeeper的介绍 数据模型 ​编辑 操作使用 ①登录客户端 ​编辑 ②可以查看下面节点有哪些 ③创建新的节点&#xff0c;并指定数据 ④查看节点内的数据 ⑤、删除节点及数据 特殊点&#xff1a; 运行机制&#xff1a; 二、DolphinScheduler的介绍 架构&#…

计算机毕业设计Python+Spark知识图谱高考志愿推荐系统 高考数据分析 高考可视化 高考大数据 大数据毕业设计

毕业设计&#xff08;论文&#xff09;任务书 毕业设计&#xff08;论文&#xff09;题目&#xff1a; 基于大数据的高考志愿推荐系统 设计&#xff08;论文&#xff09;的主要内容与要求&#xff1a; 主要内容&#xff1a; 高…

贝叶斯回归

1. 贝叶斯推断的定义 简单来说&#xff0c;贝叶斯推断 (Bayesian inference) 就是结合“经验 (先验)”和“实践 (样本)”&#xff0c;得出“结论 (后 验)”。 2. 什么是先验&#xff1f; 贝叶斯推断把模型参数看作随机变量。在得到样本之前&#xff0c;根据主观经验和既有知…

巧记英语单词

页面 在输入框中填写英语单词的谐音 这样的话就进行了一次英语单词的记忆练习。 页面代码 <% layout(/layouts/default.html, {title: 英语单词管理, libs: [dataGrid]}){ %> <div class"main-content"><div class"box box-main">&l…

anaconda、cuda、tensorflow、pycharm环境安装

anaconda、cuda、tensorflow、pycharm环境安装 anaconda安装 anaconda官方下载地址 本文使用的是基于python3.9的anaconda 接下来跟着步骤安装&#xff1a; 检验conda是否成功安装 安装CUDA和cuDNN 提醒&#xff0c;CUDA和cuDNN两者必须版本对应&#xff0c;否者将会出错…

my-room-in-3d中的电脑,电视,桌面光带发光原理

1. my-room-in-3d中的电脑&#xff0c;电视&#xff0c;桌面光带发光原理 最近在github中&#xff0c;看到了这样的一个项目&#xff1b; 项目地址 我看到的时候&#xff0c;蛮好奇他这个光带时怎么做的。 最后发现&#xff0c;他是通过&#xff0c;加载一个 lightMap.jpg这个…

大型语言模型的新挑战:AMR语义表示的神秘力量

DeepVisionary 每日深度学习前沿科技推送&顶会论文&数学建模与科技信息前沿资讯分享&#xff0c;与你一起了解前沿科技知识&#xff01; 引言&#xff1a;AMR在大型语言模型中的作用 在自然语言处理&#xff08;NLP&#xff09;的领域中&#xff0c;抽象意义表示&…

查找算法与排序算法

查找算法 二分查找 (要求熟练) // C// 二分查找法&#xff08;递归实现&#xff09; int binarySearch(int *nums, int target, int left, int right) // left代表左边界&#xff0c;right代表右边界 {if (left > right) return -1; // 如果左边大于右边&#xff0c;那么…

esp8266与uno使用软串口通信

esp8266的d6和d5分别与uno的5和6管脚连接&#xff1a; uno程序&#xff1a; //uno #include <SoftwareSerial.h> SoftwareSerial s(5,6);//(RX,TX)void setup(){s.begin(9600);Serial.begin(9600); }void loop(){int data50;if (s.available() > 0) {char c s.read(…

【错题集-编程题】比那名居的桃子(滑动窗口 / 前缀和)

牛客对应题目链接&#xff1a;比那名居的桃子 (nowcoder.com) 一、分析题目 1、滑动窗口 由题意得&#xff0c;我们是要枚举所有大小为 k 的子数组&#xff0c;并且求出这段⼦数组中快乐值和羞耻度之和。因此&#xff0c;可以利用滑动窗口的思想&#xff0c;用两个变量维护大小…

【区块链】共识算法简介

共识算法简介 区块链三要素&#xff1a; 去中心化共识算法智能合约 共识算法作为区块链三大核心技术之一&#xff0c;其重要性不言而喻。今天就来简单介绍共识算法的基本知识。 最简单的解释&#xff0c;共识算法就是要让所有节点达成共识&#xff0c;保证少数服从多数&#x…

从零开始学AI绘画,万字Stable Diffusion终极教程(六)

【第6期】知识补充 欢迎来到SD的终极教程&#xff0c;这是我们的第六节课&#xff0c;也是最后一节课 这套课程分为六节课&#xff0c;会系统性的介绍sd的全部功能&#xff0c;让你打下坚实牢靠的基础 1.SD入门 2.关键词 3.Lora模型 4.图生图 5.controlnet 6.知识补充 …

初识C语言——第九天

ASCII定义 在 C 语言中&#xff0c;每个字符都对应一个 ASCII 码。ASCII 码是一个字符集&#xff0c;它定义了许多常用的字符对应的数字编码。这些编码可以表示为整数&#xff0c;也可以表示为字符类型。在 C 语言中&#xff0c;字符类型被定义为一个整数类型&#xff0c;它占…

C/C++开发,opencv-ml库学习,K近邻(KNN)应用

目录 一、k近邻算法 1.1 算法简介 1.2 opencv-k近邻算法 二、cv::ml::KNearest应用 2.1 数据集样本准备 2.2 KNearest应用 2.3 程序编译 2.4 main.cpp全代码 一、k近邻算法 1.1 算法简介 K近邻算法&#xff08;K-Nearest Neighbor&#xff0c;KNN&#xff09;基本原理是…

Vue按照顺序实现多级弹窗(附Demo)

目录 前言1. 单个弹窗2. 多级弹窗 前言 强化各个知识点&#xff0c;以实战融合&#xff0c;以下两个Demo从实战提取 1. 单个弹窗 部署按钮框以及确定的方法即可 截图如下所示&#xff1a; 以下Demo整体逻辑如下&#xff1a; 点击“生成周月计划”按钮会触发showWeekPlanDia…

FLIR LEPTON3.5 热像仪wifi 科研实验测温采集仪

点击查看详情!点击查看详情点击查看详情点击查看详情点击查看详情点击查看详情点击查看详情点击查看详情点击查看详情点击查看详情点击查看详情点击查看详情点击查看详情点击查看详情点击查看详情 1、描述 这是一款桌面科研实验测温热成像多功能热像记录仪&#xff0c;小巧轻便…

STM32微秒级别延时--F407--TIM1

基本配置&#xff1a; TIM1挂载在APB2总线上&#xff0c;150MHz经过15分频&#xff0c;得到10MHz计数频率&#xff0c;由于disable了自动重装载&#xff0c;所以只需要看下一次计数值是多少即可。 void TIM1_Delay_us(uint16_t us) //使用阻塞方式进行延时&#xff0c;ARR值不…

记录vue报错问题 in ./node_modules/axios/lib/platform/index.js

今天这个问题困扰了我许久 报错内容如下&#xff1a; 最初一直以为是我没装axios&#xff0c;又重新装了一次&#xff0c;后面才发现是axios版本原因&#xff0c;真的总是被版本的原因困住真的很烦 解决方法如下&#xff1a; 将axios的版本改为1.5.0 1、打开项目的文件夹“…

Linux命令--查找占磁盘空间最大的文件

原文网址&#xff1a;Linux命令--查找占磁盘空间最大的文件-CSDN博客 简介 本文介绍Linux怎样查找占磁盘空间最大的文件。 1.找到占空间最大的分区 命令 df -h 结果 2.查找分区里最大的文件 法1&#xff1a;直接查找最大的文件 sudo find my_folder -type f -exec du -…