【深度学习】GELU激活函数是什么?

torch.nn.GELU 模块在 PyTorch 中实现了高斯误差线性单元(GELU)激活函数。GELU 被用于许多深度学习模型中,包括Transformer,因为它相比传统的 ReLU(整流线性单元)函数能够更好地近似神经元的真实激活行为。

概述

  • 功能: 应用 GELU 激活函数。
  • 公式: GELU 激活函数可以表示为:
    GELU ( x ) = x ⋅ Φ ( x ) \text{GELU}(x) = x \cdot \Phi(x) GELU(x)=xΦ(x)
    其中 Φ ( x ) \Phi(x) Φ(x) 是标准正态分布的累积分布函数。

使用方法

要在神经网络中使用 torch.nn.GELU 模块,你可以简单地导入它并将其添加到模型的层中。以下是一个示例:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layer1 = nn.Linear(10, 20)
        self.gelu = nn.GELU()
        self.layer2 = nn.Linear(20, 10)

    def forward(self, x):
        x = self.layer1(x)
        x = self.gelu(x)
        x = self.layer2(x)
        return x

# 创建模型实例
model = MyModel()

# 创建一个随机输入张量
input_tensor = torch.randn(5, 10)

# 前向传播
output_tensor = model(input_tensor)
print(output_tensor)

解释

  • nn.Linear(10, 20): 一个线性层,输入大小为 10,输出大小为 20。
  • nn.GELU(): GELU 激活函数,应用于第一个线性层的输出。
  • nn.Linear(20, 10): 另一个线性层,输入大小为 20,输出大小为 10。

GELU 激活函数用于在模型中引入非线性,这有助于模型学习更复杂的模式。

GELU 的优点

  • 平滑近似: GELU 提供了一种比 ReLU 更平滑的神经元激活近似,这可以帮助训练的稳定性和收敛性。
  • 概率解释: 通过结合累积分布函数,GELU 以更有原则的方式考虑激活的概率,这可能在某些模型(尤其是自然语言处理 (NLP) 和计算机视觉 (CV) 中)带来更好的性能。

通过在你的 PyTorch 模型中使用 torch.nn.GELU,你可以利用这些优点来提高神经网络的性能和训练动态。

GELU(Gaussian Error Linear Unit)激活函数是在论文《Gaussian Error Linear Units (GELUs)》中提出的。这篇论文由 Dan Hendrycks 和 Kevin Gimpel 于 2016 年发表。

以下是使用 Python 和 Matplotlib 绘制 GELU 激活函数的函数曲线的代码:

import numpy as np
import matplotlib.pyplot as plt
from scipy.special import erf

# 定义 GELU 激活函数
def gelu(x):
    return 0.5 * x * (1 + erf(x / np.sqrt(2)))

# 生成 x 轴数据
x = np.linspace(-3, 3, 400)
# 计算 y 轴数据
y = gelu(x)

# 绘制 GELU 激活函数曲线
plt.figure(figsize=(8, 6))
plt.plot(x, y, label='GELU', color='blue')
plt.title('GELU Activation Function')
plt.xlabel('Input')
plt.ylabel('Output')
plt.legend()
plt.grid(True)
plt.show()

运行上述代码将生成一个展示 GELU 激活函数的曲线图:

在这里插入图片描述
优点:

  • 平滑的近似:

GELU 提供了比 ReLU 更平滑的激活函数,这有助于神经网络更稳定地训练并提高收敛性。ReLU 在负数区间完全关闭,而 GELU 会根据输入值的大小逐渐激活神经元。

  • 概率解释:

GELU 将标准正态分布的累积分布函数(CDF)结合到激活函数中,以一种更有原则的方式处理激活的概率。这种方法考虑了输入值的分布,使得神经网络可以更有效地处理不同范围的输入。

  • 更好的性能:

由于 GELU 函数的平滑性和概率解释,它在处理某些任务时(尤其是在自然语言处理 (NLP) 和计算机视觉 (CV) 任务中)表现出色。在这些任务中,GELU 激活函数可以提高模型的性能。

  • 渐进式变化:

相对于 ReLU 的硬边界(即大于零输出本身,小于零输出零),GELU 提供了一种更加渐进式的激活方式,使得小负值输入仍然能够产生一定的激活效果,这在某些情况下可以提高模型的表现。

GELU 反向传播的公式

GELU 激活函数的公式

GELU 激活函数定义为:
GELU ( x ) = x ⋅ Φ ( x ) \text{GELU}(x) = x \cdot \Phi(x) GELU(x)=xΦ(x)
其中 Φ ( x ) \Phi(x) Φ(x) 是标准正态分布的累积分布函数。 Φ ( x ) \Phi(x) Φ(x) 的表达式为:
Φ ( x ) = 1 2 ( 1 + erf ( x 2 ) ) \Phi(x) = \frac{1}{2} \left( 1 + \text{erf}\left( \frac{x}{\sqrt{2}} \right) \right) Φ(x)=21(1+erf(2 x))

GELU 的梯度公式

为了求 GELU 的梯度,我们需要对其进行求导。这里 erf ( x ) \text{erf}(x) erf(x) 是误差函数,定义为:
erf ( x ) = 2 π ∫ 0 x e − t 2   d t \text{erf}(x) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} \, dt erf(x)=π 20xet2dt

求导过程

GELU 的导数可以表示为:
d d x GELU ( x ) = d d x ( x ⋅ Φ ( x ) ) \frac{d}{dx} \text{GELU}(x) = \frac{d}{dx} \left( x \cdot \Phi(x) \right) dxdGELU(x)=dxd(xΦ(x))

根据乘积法则:
d d x ( x ⋅ Φ ( x ) ) = Φ ( x ) + x ⋅ d d x Φ ( x ) \frac{d}{dx} \left( x \cdot \Phi(x) \right) = \Phi(x) + x \cdot \frac{d}{dx} \Phi(x) dxd(xΦ(x))=Φ(x)+xdxdΦ(x)

我们需要对 Φ ( x ) \Phi(x) Φ(x) 进行求导:
d d x Φ ( x ) = d d x ( 1 2 ( 1 + erf ( x 2 ) ) ) \frac{d}{dx} \Phi(x) = \frac{d}{dx} \left( \frac{1}{2} \left( 1 + \text{erf}\left( \frac{x}{\sqrt{2}} \right) \right) \right) dxdΦ(x)=dxd(21(1+erf(2 x)))

由于常数部分导数为零,我们仅对 erf ( x 2 ) \text{erf}\left( \frac{x}{\sqrt{2}} \right) erf(2 x) 进行求导:
d d x erf ( x 2 ) = 2 π e − ( x 2 ) 2 ⋅ 1 2 = e − x 2 / 2 2 π \frac{d}{dx} \text{erf}\left( \frac{x}{\sqrt{2}} \right) = \frac{2}{\sqrt{\pi}} e^{-\left( \frac{x}{\sqrt{2}} \right)^2} \cdot \frac{1}{\sqrt{2}} = \frac{e^{-x^2/2}}{\sqrt{2\pi}} dxderf(2 x)=π 2e(2 x)22 1=2π ex2/2

所以:
d d x Φ ( x ) = 1 2 π e − x 2 / 2 \frac{d}{dx} \Phi(x) = \frac{1}{\sqrt{2\pi}} e^{-x^2/2} dxdΦ(x)=2π 1ex2/2

将其代入前面的公式,我们得到:
d d x GELU ( x ) = Φ ( x ) + x ⋅ 1 2 π e − x 2 / 2 \frac{d}{dx} \text{GELU}(x) = \Phi(x) + x \cdot \frac{1}{\sqrt{2\pi}} e^{-x^2/2} dxdGELU(x)=Φ(x)+x2π 1ex2/2

因此,GELU 的梯度为:
d d x GELU ( x ) = 1 2 ( 1 + erf ( x 2 ) ) + x ⋅ 1 2 π e − x 2 / 2 \frac{d}{dx} \text{GELU}(x) = \frac{1}{2} \left( 1 + \text{erf}\left( \frac{x}{\sqrt{2}} \right) \right) + x \cdot \frac{1}{\sqrt{2\pi}} e^{-x^2/2} dxdGELU(x)=21(1+erf(2 x))+x2π 1ex2/2

Python 代码绘制 GELU 梯度曲线

import numpy as np
import matplotlib.pyplot as plt
from scipy.special import erf, erfc

# 定义 GELU 激活函数
def gelu(x):
    return 0.5 * x * (1 + erf(x / np.sqrt(2)))

# 定义 GELU 激活函数的导数
def gelu_derivative(x):
    return 0.5 * (1 + erf(x / np.sqrt(2))) + (x * np.exp(-x**2 / 2)) / np.sqrt(2 * np.pi)

# 生成 x 轴数据
x = np.linspace(-3, 3, 400)
# 计算 y 轴数据
y = gelu(x)
# 计算 y' 轴数据
dy = gelu_derivative(x)

# 绘制 GELU 激活函数和梯度曲线
plt.figure(figsize=(8, 6))
plt.plot(x, y, label='GELU', color='blue')
plt.plot(x, dy, label='GELU Derivative', color='red', linestyle='dashed')
plt.title('GELU Activation Function and Its Derivative')
plt.xlabel('Input')
plt.ylabel('Output')
plt.legend()
plt.grid(True)
plt.show()

运行这段代码将生成一个展示 GELU 激活函数及其梯度的曲线图,有助于直观地理解 GELU 在反向传播中的行为:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/727568.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

ARM64汇编0B - 函数调用约定

建议先看《CSAPP》的3.7节,讲的很细。我们这里就直接看例子来分析了。 例子 static int func(int a, int b, int c, int d, int e, int f, int g, int h, int i) {printf("%s\n", "add all");int x a b;return a b c d e f g h i; …

Faiss:选择合适的索引Index

向量相似性搜索彻底改变了搜索领域。它允许我们高效地检索从GIF到文章等各种媒体,即使在处理十亿级别数据集时,也能在亚秒级时间内提供令人印象深刻的准确性。 然而,这种灵活性也带来了一个问题:如何知道哪种索引大小最适合我们的…

2-11 基于matlab的BP-Adaboost的强分类器分类预测

基于matlab的BP-Adaboost的强分类器分类预测,Adaboost是一种迭代分类算法,其在同一训练集采用不同方法训练不同分类器(弱分类器),并根据弱分类器的误差分配不同权重,然后将这些弱分类器组合成一个更强的最终…

check python checking for Python executable “python2“ in the PATH

背景: mac电脑升级后重新拉取老项目后安装node_module 和启动项目报错 gyp info using node-gyp3.8.0 gyp info using node14.18.0 | darwin | x64 gyp verb command rebuild [] gyp verb command clean [] gyp verb clean removing "build" directory …

Python基础教程(二十六):对接MongoDB

💝💝💝首先,欢迎各位来到我的博客,很高兴能够在这里和您见面!希望您在这里不仅可以有所收获,同时也能感受到一份轻松欢乐的氛围,祝你生活愉快! 💝&#x1f49…

视听分割相关论文阅读

1. End-to-End Referring Video Object Segmentation with Multimodal Transformers RVOS(视频中的参考对象分割)比RIS(图像中的参考对象分割)要困难得多,因为指代动作的文本表达通常无法从单个静态帧中正确推断出来。…

Web3 学习

之前学习 web3,走了不少弯路,最近看到了 hackquest,重新刷了一遍以太坊基础,感觉非常nice,而且完全免费,有需要的可以试试,链接hackquest.io。

劲爆!Kimi月之暗面可以接入微信,智能升级, 打造个性多Agent(二)

前言 在当今这个快速发展的AI时代,抖音推出了一个名为“扣子Coze”的工具,帮助用户快速、低门槛地搭建属于自己的AI机器人。本文将详细介绍如何使用扣子Coze配置自己的AI Agent,并展示其在多个平台上的应用。 如何使用多个Agent 搭建更加智…

【DKN: Deep Knowledge-Aware Network for News Recommendation】

DKN: Deep Knowledge-Aware Network for News Recommendation 摘要 在线新闻推荐系统旨在解决新闻信息爆炸的问题,为用户进行个性化推荐。 总体而言,新闻语言高度凝练,充满知识实体和常识。 然而,现有的方法并没有意识到这些外部…

[Kubernetes] etcd 单机和集群部署

文章目录 1.etcd基本概念2.etcd的基本知识3.etcd优势4.etcd单机部署4.1 linux部署4.2 windows部署4.3 docker安装etcd 5.etcd集群部署 1.etcd基本概念 etcd是一个高可用的分布式键值存储系统,是CoreOS(现在隶属于Red Hat)公司开发的一个开源…

Redis缓存的一些概念性问题

目录 缓存模型和思路 缓存更新策略 数据库和缓存不一致 缓存与数据库双写一致 缓存穿透 缓存雪崩 缓存击穿 速度快,好用,内存的读写性能远高于磁盘,缓存可以大大降低用户访问并发量带来的服务器读写压力 缓存模型和思路 标准的操作方式就是查询数据库之前先…

用户态协议栈02-arp reply实现

在上一节DODK的UDP收发中发送udp包的时候,需要向物理机的arp表中添加一个静态的arp记录。这在生产环境中显然是不可以的。在内核的协议栈中,会将自己的ip和mac在局域网中进行广播,并且记录其他电脑的ip和mac。在需要发送数据包的时候&#xf…

AD使用快捷键

1、如何实现元器件旋转45放置 在Preferences >> PCB Editor >> General中将Rotation Step(旋转的步进值)由90改为45,这样以后每次按空格键旋转器件时旋转角度为45。 2、显示网络、隐藏网络 N 3、对齐 2、设置DRC检查选项&#xf…

clickhouse学习笔记(四)库、表、分区相关DDL操作

目录 一、数据库操作 1、创建数据库 2、查询及选择数据库 3、删除数据库 二、数据表操作 1、创建表 2、删除表 3、基本操作 ①追加新字段 ②修改字段类型或默认值 ③修改字段注释 ④删除已有字段 ⑤移动数据表(重命名) ⑥清空表 三、默认值…

【前端技巧】css篇

利用counter实现计数器 counter-reset&#xff1a;为计数器设置名称&#xff0c;语法如下&#xff1a; counter-rese: <idntifier><integer>第一个参数为变量名称&#xff0c;第二个参数为初始值&#xff0c;默认为0 counter-increment&#xff1a;设置计数器增…

OpenGL绘制Bezier曲面

Bezier的定义 贝塞尔曲面是贝塞尔曲线在二维上的扩展。它由一组控制点定义,通过这些控制点生成光滑的曲面。贝塞尔曲面通常用两个参数 u u u和 v v v来表示,这两个参数的取值范围都在 [0, 1] 之间。 数学表示 P ( u , v ) = ∑ i = 0 n ∑ j = 0 m p i j ⋅ B i , n ( u ) ⋅…

YOLOv10改进 | Conv篇 |YOLOv10引入SPD-Conv卷积

1. SPD-Conv介绍 1.1 摘要:卷积神经网络(CNN)在图像分类和目标检测等许多计算机视觉任务中取得了巨大的成功。 然而,在图像分辨率较低或物体较小的更艰巨的任务中,它们的性能会迅速下降。 在本文中,我们指出,这源于现有 CNN 架构中一个有缺陷但常见的设计,即使用跨步卷…

MicroPython+ESP32 C3+ST7735S LCD屏 WIFI联网显示实时时间

案例地址&#xff1a;https://gitee.com/whltaoin_admin/MP_ESP32_ST7735S- 展示效果 ESP32LCD屏 WIFI联网并显示实时时间 TFT LCD模块参数介绍 名称&#xff1a;1.8 128*160 RGB_TFT驱动芯片&#xff1a;ST7735S ESP32 C3 参数介绍&#xff08;经典款&#xff09; 外观及…

Kafka基础教程

Kafka基础教程 资料来源&#xff1a;Apache Kafka - Introduction (tutorialspoint.com) Apache Kafka起源于LinkedIn&#xff0c;后来在2011年成为一个开源Apache项目&#xff0c;然后在2012年成为一流的Apache项目。Kafka是用Scala和Java编写的。Apache Kafka是基于发布-订…

leetcode:557. 反转字符串中的单词 III(python3解法)

难度&#xff1a;简单 给定一个字符串 s &#xff0c;你需要反转字符串中每个单词的字符顺序&#xff0c;同时仍保留空格和单词的初始顺序。 示例 1&#xff1a; 输入&#xff1a;s "Lets take LeetCode contest" 输出&#xff1a;"steL ekat edoCteeL tsetnoc…