003 FeedForward前馈层

一、环境

本文使用环境为:

  • Windows10
  • Python 3.9.17
  • torch 1.13.1+cu117
  • torchvision 0.14.1+cu117

二、前馈层原理

Transformer模型中的前馈层(Feed Forward Layer)是其关键组件之一,对于模型的性能起着重要作用。下面将用900字对Transformer的前馈层进行详细讲解。前馈层在网络中位置,如下图:

2.1 前馈层的基本概念

前馈层,又称为全连接层(Fully Connected Layer)或密集层(Dense Layer),是神经网络中最基础的组件之一。在Transformer模型中,前馈层通常出现在每一个Transformer编码器(Encoder)和解码器(Decoder)中的每一个自注意力(Self-Attention)层之后。

2.2 前馈层的结构

  1. 线性变换:前馈层的输入首先经过一个线性变换,将输入映射到一个高维空间。这个线性变换通常由一个权重矩阵和一个偏置向量实现。
  2. 激活函数:经过线性变换后,输入会通过一个激活函数,增加模型的非线性表达能力。常用的激活函数包括ReLU、Sigmoid、Tanh等。
  3. 线性变换与输出:经过激活函数后,再进行一次线性变换,将高维空间的特征映射回原始空间,得到前馈层的输出。

2.3 前馈层的作用

  1. 特征提取:前馈层通过线性变换和激活函数,可以提取输入数据的复杂特征,有助于模型更好地理解数据。
  2. 增加模型复杂度:通过增加前馈层的隐藏单元数量,可以增加模型的复杂度,提高模型的表达能力。
  3. 缓解梯度消失问题:在深度神经网络中,梯度消失是一个常见的问题。前馈层中的激活函数可以缓解这个问题,使得模型能够更好地训练。

三、完整代码

前馈层接受自注意力子层的输出作为输入,并通过一个带有 Relu 激活函数的两层全连接网络对输入进行更加复杂的非线性变换。实验证明,这一非线性变换会对模型最终的性能产生十分重要的影响。

其中 W1, b1,W2, b2 表示前馈子层的参数。实验结果表明,增大前馈子层隐状态的维度有利于提
升最终翻译结果的质量,因此,前馈子层隐状态的维度一般比自注意力子层要大。使用 Pytorch 实现的前馈层参考代码如下:

import torch.nn as nn
import torch.nn.functional as F
import torch

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff=2048, dropout = 0.1):
        super().__init__()
        # d_ff 默认设置为 2048
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)
    def forward(self, x):
        x = self.dropout(F.relu(self.linear_1(x)))
        x = self.linear_2(x)
        return x



batch_size = 32
seq_len = 64
d_model = 512
ff = FeedForward(d_model=d_model)
x = torch.randn(batch_size, seq_len, d_model) # 32 x 64 x 512
ff.eval()
print(ff(x).shape) # torch.Size([32, 64, 512])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/241672.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

cpp:1:10: fatal error: opencv2/core.hpp: 没有那个文件或目录

前言&#xff1a; 我按照官网方法安装了opencv&#xff0c;运行的也是官网的测试代码&#xff1a; #include <opencv2/core.hpp> #include <opencv2/highgui.hpp> using namespace cv; int main() {printf("hello world")return 0; }

boost1.55 安装使用教程 windows

第一步 &#xff1a;首先在boost官网上下载库压缩包 添加链接描述 选择自己需要的版本进行下载 解压后执行booststrap.bat 用来生成创建b2.exe 和bjam.exe 拓展&#xff1a;.\b2 --help 了解一下有哪些参数可以配置 默认b2.exe编译后&#xff0c;链接到项目如果出现如下错误…

VLAN基本原理

目录 一、VLAN概念及优势 &#xff08;一&#xff09;基本理念 &#xff08;二&#xff09;VLAN的特点 二、VLAN ID 种类、范围及用途 &#xff08;一&#xff09;静态VLAN &#xff08;二&#xff09;动态VLAN &#xff08;三&#xff09;VLAN三种端口类型 &#xff0…

深入理解Java虚拟机---类加载机制

类加载机制 什么是类加载机制类加载的时机类加载的过程加载验证文件格式验证元数据验证字节码验证符号引用验证 准备解析初始化 类加载器双亲委派模型 什么是类加载机制 虚拟机把描述类的数据从 Class 文件加载到内存&#xff0c;并对数据进行校验、转换解析和初始化&#xff…

centOS安装bochsXshell连接centos启动可视化界面

centOS安装bochs 参考&#xff1a;https://blog.csdn.net/muzi_since/article/details/102559187 首先安装依赖环境&#xff1a; yum install gtk2 gtk2-devel yum install libXt libXt-devel yum install libXpm libXpm-devel yum install SDL SDL-devel yum install libXr…

已解决:No goals have been specified for this build. You must specify a vali

[ERROR] No goals have been specified for this build. You must specify a valiTOC 完整报错 No goals have been specified for this build. You must specify a valid lifecycle phase or a goal in the format : or :[:]:. Available lifecycle phases are: pre-clean, c…

6. Service详解

6. Service详解 文章目录 6. Service详解6.1 Service介绍6.2 Service类型6.3 Service使用6.3.1 实验环境准备6.3.2 ClusterIP类型的Service6.3.3 HeadLess类型的Service6.3.3.1 deployment和statefulset区别6.3.3.2 statefulset deployment 区别 6.3.4 NodePort类型的Service6.…

Trace 在多线程异步体系下传递

JAVA 线程异步常见的实现方式有&#xff1a; new ThreadExecutorService 当然还有其他的&#xff0c;比如fork-join&#xff0c;这些下文会有提及&#xff0c;下面主要针对这两种场景结合 DDTrace 和 Springboot 下进行实践。 引入 DDTrace sdk <properties><java.…

湖农大邀请赛shell_rce漏洞复现

湖农大邀请赛 shell_rce 复现 在 2023 年湖南农业大学邀请赛的线上初赛中&#xff0c;有一道 shell_rce 题&#xff0c;本文将复现该题。 题目内容&#xff0c;打开即是代码&#xff1a; <?phpclass shell{public $exp;public function __destruct(){$str preg_replace…

Shopify怎么避免被封店?封店原因有哪些?

市场研究的一份报告显示&#xff0c;全球跨境电子商务市场预计到2028年将达到30422亿美元&#xff0c;其中&#xff0c;亚太地区是最大的跨境电商市场&#xff0c;据海关统计数据&#xff0c;近五年来&#xff0c;我国跨境电商进出口增长近10倍。跨境电商业务新的增长风口已经到…

图像去噪——PMRID训练自己数据集及推理测试(详细图文教程)

目录 一、源码包准备二、数据集准备2.1 提取数据集名称2.2 .txt报错问题2.2.1 正确格式2.2.2 错误格式 三、修改配置参数四、训练及保存模型权重4.1 训练4.2 保存模型权重文件 五、模型推理测试5.1 导入测试集5.2 测试5.3 测试结果5.3.1 测试场景15.3.2 测试场景2 5.4 推理速度…

jsp+servlet+图书交流平台 有filter过滤器

在线图书推荐与交流平台 随着数字化的进展和人们对持续学习的追求&#xff0c;在线资源变得越来越受欢迎。对于众多读者来说&#xff0c;找到合适的书籍和与其他读者交流阅读体验是非常有价值的。为了满足这一需求&#xff0c;我们提出了一个在线图书推荐与交流平台的设计。此…

千梦网创:赚钱就是服侍好双爹

“爹啊&#xff0c;我没钱用啦&#xff0c;给我啃一下。” 想赚钱&#xff0c;最快的方式就是啃爹。 不管你做什么项目&#xff0c;同行永远都是我们的爹。 跟着爹走&#xff0c;有吃有喝不用愁。 跟着老爹走&#xff0c;蛋花汤里加骨头。 小时候父亲总是把我们高高的举过…

查询mysql服务器当前时区设置、session当前时区设置

使用命令SELECT global.time_zone;可以查询mysql服务器的当前时区设置&#xff0c;例如&#xff1a; 使用命令SELECT session.time_zone;可以查询session的当前时区设置&#xff0c;例如&#xff1a;

Vue 3 开发中遇到的问题及解决方案(fix bug)

开发环境&#xff1a;mac系统&#xff0c;node版本&#xff1a; 16.15.0 版本兼容问题 vite v3.2.4 building for development... hasInjectionContext is not exported by node_modules/pinia/node_modules/vue-demi/lib/index.mjs, imported by node_modules/pinia/dist/pini…

【算法题】冠亚军排名,奖牌榜排名(js)

解法&#xff1a; function solution(lines) {const list [];for (let i 0; i < lines.length; i) {const line lines[i];const [country, gold, silver, bronze] line.split(" ");list.push({country,gold: gold - 0,silver: silver - 0,bronze: bronze - 0…

Java 数据结构篇-用数组、堆实现优先级队列

&#x1f525;博客主页&#xff1a; 【小扳_-CSDN博客】 ❤感谢大家点赞&#x1f44d;收藏⭐评论✍ 文章目录 1.0 优先级队列说明 2.0 用数组实现优先级队列 3.0 无序数组实现优先级队列 3.1 无序数组实现优先级队列 - 入队列 offer(E value) 3.2 无序数组实现优先级队列 - 出…

RocketMQ可视化工具 打包遇到的yarn intall 问题

文章目录 RocketMQ可视化工具1.github上下载2.修改参数3.运行4.打包5.出错6.解决7.重试8.再解决9.很奇怪运行没错&#xff0c;但是测试错啦10.不想深究&#xff0c;直接跳过测试11.展示成功 RocketMQ可视化工具 1.github上下载 下载地址 https://github.com/apache/rocketmq-…

深度学习的目标检测算法综述

信息记录材料 2022年10月 第23卷第10期 【摘要】目标检测是深度学习的一个重要应用&#xff0c;目前在智能驾驶、工业检测相关领域都获得应用&#xff0c;具有重要的现实意义。本文对基于深度学习目标检测算法原理和应用情况进行简述&#xff0c;首先介绍结合区域提取和卷积神经…

Corona最新渲染器Corona11详解,附送下载地址

近日&#xff0c;Corona进行了大版本更新&#xff0c;发布了最新的Corona11。这次更新&#xff0c;包含众多新功能和新修复&#xff0c;借助 Corona 11 用户可将作品提升到更高的创作水准&#xff0c;更真实可感的视觉水平。 那么更新了那些呢&#xff1f;一起来看看吧&#x…