Meta Llama 3 前馈层

Meta Llama 3 前馈层

flyfish

在这里插入图片描述
图片来自论文 http://arxiv.org/pdf/2304.13712

因为树根是Transformer,所以这里会将 Llama 3 与Transformer比较下

Transformer的前馈层

在Transformer模型中,每个编码器和解码器层中都包含一个前馈神经网络(Feed-Forward Neural Network, FFN)。前馈神经网络的作用是对经过自注意力机制处理后的输出进行进一步的非线性变换和特征提取。

前馈神经网络的结构

每个前馈神经网络由两个线性变换和一个激活函数组成。具体结构如下:
FFN ( x ) = Linear 2 ( ReLU ( Linear 1 ( x ) ) ) \text{FFN}(x) = \text{Linear}_2(\text{ReLU}(\text{Linear}_1(x))) FFN(x)=Linear2(ReLU(Linear1(x)))

前馈神经网络的激活函数

Transformer中的前馈神经网络使用的激活函数是ReLU(Rectified Linear Unit)。ReLU的定义如下: ReLU ( x ) = max ⁡ ( 0 , x ) \text{ReLU}(x) = \max(0, x) ReLU(x)=max(0,x)

前馈神经网络的优缺点

优点
非线性特征提取:
ReLU激活函数引入非线性,使得前馈神经网络能够提取和表示输入数据中的复杂特征。

计算效率:
ReLU激活函数的计算非常简单,只需比较输入是否大于零,因此计算效率很高。
缓解梯度消失问题:

相较于传统的激活函数(如Sigmoid或Tanh),ReLU可以缓解梯度消失问题,特别是在深层神经网络中。

缺点
ReLU的死亡问题(Dead ReLU Problem):
当输入为负时,ReLU的输出恒为零。如果大量的神经元在训练过程中输出恒为零,它们将不会对模型的学习做出贡献。

参数选择:
前馈神经网络的隐藏层维度 d_ff 的选择需要经验和实验调整,过大或过小都会影响模型的性能和计算效率。

代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # 前馈神经网络的计算过程
        out = F.relu(self.linear1(x))
        out = self.linear2(self.dropout(out))
        
        # 残差连接和层归一化
        out = self.norm(x + out)
        return out

# 定义模型参数
d_model = 512
d_ff = 2048
dropout = 0.1

# 创建前馈神经网络层
ffn = FeedForward(d_model, d_ff, dropout)

# 创建示例输入张量 (batch_size, seq_length, d_model)
batch_size = 32
seq_length = 10
input_tensor = torch.randn(batch_size, seq_length, d_model)

# 执行前向传播
output = ffn(input_tensor)

print("Output shape:", output.shape)

输出

Output shape: torch.Size([32, 10, 512])

self.linear1:第一个线性层,将输入从 d_model 维度变换到 d_ff 维度。
self.dropout:Dropout 层,用于在训练过程中随机丢弃一些神经元,防止过拟合。
self.linear2:第二个线性层,将隐藏层的输出从 d_ff 维度变换回 d_model 维度。
self.norm:层归一化,用于规范化输入,增加模型的稳定性。

LLama3的前馈神经网络实现

源码是

class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
        )
        self.w3 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

改造

不使用FairScale

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float] = None,
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

# 示例用法
dim = 512
hidden_dim = 2048
multiple_of = 64
ffn_dim_multiplier = 1.0

ffn = FeedForward(dim, hidden_dim, multiple_of, ffn_dim_multiplier)

# 创建示例输入张量 (batch_size, seq_length, dim)
batch_size = 32
seq_length = 10
input_tensor = torch.randn(batch_size, seq_length, dim)

# 执行前向传播
output = ffn(input_tensor)

print("Output shape:", output.shape)

输出

Output shape: torch.Size([32, 10, 512])

类定义:
FeedForward 类继承自 nn.Module,这是PyTorch中的基本模块类。

构造函数:
init 方法初始化前馈神经网络层的各个参数:
dim:输入和输出的特征维度。
hidden_dim:隐藏层的维度。
multiple_of:隐藏层维度的倍数,用于确保隐藏层的维度是某个值的整数倍。
ffn_dim_multiplier:一个可选的乘数,用于调整隐藏层的维度。

隐藏层维度计算:
hidden_dim 初始设定为 2/3 的原始隐藏层维度。
如果提供了 ffn_dim_multiplier,则乘以该值调整 hidden_dim。
使用 multiple_of 来确保 hidden_dim 是其整数倍

hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

线性层定义:
self.w1、self.w2 和 self.w3 都是 nn.Linear 层,分别用于不同的线性变换:
self.w1:将输入从 dim 维度映射到 hidden_dim 维度。
self.w2:将中间层从 hidden_dim 维度映射回 dim 维度。
self.w3:将输入从 dim 维度再次映射到 hidden_dim 维度。

前向传播:
在 forward 方法中,定义了前向传播的计算过程

return self.w2(F.silu(self.w1(x)) * self.w3(x))

self.w1(x):对输入 x 进行第一次线性变换。
F.silu(self.w1(x)):对线性变换的结果应用SiLU激活函数。
self.w3(x):对输入 x 进行第二次线性变换。
F.silu(self.w1(x)) * self.w3(x):将激活后的结果与 self.w3(x) 的结果相乘。
self.w2(…):将乘积结果通过 self.w2 线性变换映射回原始维度 dim。

LLama3和Transformer中的FFN的比较

在这里插入图片描述

相同点

基本结构:
都包含两层线性变换和一个激活函数。
都使用残差连接和归一化来增强模型的稳定性。

激活函数:
虽然实现有所不同,但都通过激活函数(如ReLU或SiLU)引入非线性。

不同点
激活函数:
Transformer使用ReLU激活函数,而LLama3使用SiLU(Swish Linear Unit),定义为:
SiLU ( x ) = x ⋅ sigmoid ( x ) \text{SiLU}(x) = x \cdot \text{sigmoid}(x) SiLU(x)=xsigmoid(x)

线性层组合:
Transformer的FFN是两个顺序的线性层:

out = self.linear2(F.relu(self.linear1(x)))

LLama3的FFN则是两个线性层的组合,包含一个乘积操作:

return self.w2(F.silu(self.w1(x)) * self.w3(x))

维度调整:
LLama3中加入了一个可选的维度调整因子和倍数约束,以确保隐藏层维度符合某些特定的需求。

ReLU:
简单且高效,适合大多数应用。
但在负输入时,输出恒为零,可能导致部分神经元在训练过程中“死亡”。
SiLU:
平滑的非线性转换,梯度在输入的正负范围内都能够有效传播。
相比ReLU,更能捕获输入的细微变化,但计算复杂度略高。

ReLU和SiLU 可视化比较

import matplotlib.pyplot as plt
import numpy as np

# 定义ReLU和SiLU函数
def relu(x):
    return np.maximum(0, x)

def silu(x):
    return x / (1 + np.exp(-x))

# 创建输入数据
x = np.linspace(-10, 10, 400)

# 计算ReLU和SiLU的输出
y_relu = relu(x)
y_silu = silu(x)

# 绘图
plt.figure(figsize=(10, 5))

plt.subplot(1, 2, 1)
plt.plot(x, y_relu, label='ReLU')
plt.title('ReLU Activation Function')
plt.xlabel('Input')
plt.ylabel('Output')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(x, y_silu, label='SiLU', color='orange')
plt.title('SiLU Activation Function')
plt.xlabel('Input')
plt.ylabel('Output')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

在这里插入图片描述

官方的llama 3 的代码使用FairScale ,上面的代码为了分析流程,没有使用FairScale。

FairScale 简单介绍下
FairScale是一个由Facebook AI Research(FAIR)团队开发的用于PyTorch的开源库,旨在简化大规模深度学习模型的训练和推理。FairScale提供了多种优化工具和模块,帮助研究人员和工程师更高效地进行分布式训练和模型并行化。

FairScale的主要功能
分布式数据并行(Distributed Data Parallel, DDP):
提供增强的分布式数据并行功能,相比于PyTorch自带的DDP模块,FairScale的DDP具有更高的灵活性和性能优化。

分布式模型并行(Distributed Model Parallel, DMP):
允许将模型的不同部分分布到多个设备上,从而使得超大规模模型的训练成为可能。

梯度检查点(Gradient Checkpointing):
通过在反向传播过程中保存和重用部分计算结果,减少显存占用,从而训练更大规模的模型。

优化器状态并行(Optimizer State Sharding):
将优化器的状态分片到多个设备上,从而降低单个设备的显存需求。

张量并行(Tensor Parallelism):
支持在多个设备间并行执行张量计算,进一步提升大规模模型的训练效率

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/703467.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MySQL-子查询(DQL 结束)

054-where后面使用子查询 什么是子查询 select语句中嵌套select语句就叫做子查询。select语句可以嵌套在哪里? where后面、from后面、select后面都是可以的。 select ..(select).. from ..(select).. where ..(select)..where后面使用子查询 案例:找…

C++中stack和queue

前言 在 C 中,stack(栈)和 queue(队列)是两种常用的容器适配器,分别用于管理数据的后进先出(LIFO)和先进先出(FIFO)访问模式。本文将详细介绍这两种数据结构的…

C#开源软件:OneNote组件oneMore轻松打造自己的公众号编辑器

OneMore是一款为Microsoft OneNote设计的插件,它提供了许多扩展功能来增强OneNote的使用体验。 插件功能概述: OneMore插件拥有多达一百多个扩展功能,这些功能覆盖了笔记编辑、搜索、导出等多个方面,为用户提供了更加便捷和高效的…

【项目实战】--云备份系统

1、云备份认识 自动将本地计算机上指定文件夹中需要备份的文件上传备份到服务器中。并且能够随时通过浏览器进行查看并且下载,其中下载过程支持断点续传功能,而服务器也会对上传为文件进行热点管理,将非热点文件进行压缩存储,节省…

openGauss 6.0.0 一主二备集群安装及使用zcbus实现Oracle到openGauss的数据同步

一、前言 openGauss 6.0.0-RC1是openGauss 2024年3月发布的创新版本,该版本生命周期为0.5年。根据openGauss官网介绍,6.0.0-RC1与之前的版本特性功能保持兼容,另外,在和之前版本兼容的基础上增加了很多新功能,比如分区表性能优化…

Java集合自测题

文章目录 一、说说 List , Set , Map 三者的区别?二、List , Set , Map 在 Java 中分别由哪些对应的实现类?底层的数据结构?三、有哪些集合是线程不安全的?怎么解决呢?四、HashMap 查询,删除的时间复杂度五…

autosleep框架设计与实现

在低功耗系统中,autosleep是一个较小的模块,是低功耗主流程的入口。在Linux内核中,autosleep是休眠流程的触发点和入口点,PM Core的休眠流程入口pm_suspend()就是被autosleep的睡眠工作队列调用而进入休眠的。 该功能的支持受宏…

MyBatis 参数上的处理的细节内容

1. MyBatis 参数上的处理的细节内容 文章目录 1. MyBatis 参数上的处理的细节内容2. MyBatis 参数上的处理3. 准备工作4. 单个(一个)参数4.1 单个(一个)简单类型作为参数4.2 单个(一个) Map集合 作为参数4.3 单个(一个) 实体类POJO作为参数 5. 多个参数5.1 Param注解(命名参数)…

计算机相关专业的探讨

目录 一、计算机相关专业是否仍是“万金油”选择 二、计算机行业的未来发展态势 三、从专业与个人的匹配度判断选择计算机相关专业 四、对于高考生的建议 一、计算机相关专业是否仍是“万金油”选择 计算机相关专业在过去很长一段时间内确实被视为“万金油”专业&#xff0…

算法训练营day06--242.有效的字母异位词+349. 两个数组的交集+202. 快乐数+1. 两数之和

一、242.有效的字母异位词 题目链接:https://leetcode.cn/problems/valid-anagram/description/ 文章讲解:https://programmercarl.com/0242.%E6%9C%89%E6%95%88%E7%9A%84%E5%AD%97%E6%AF%8D%E5%BC%82%E4%BD%8D%E8%AF%8D.html 视频讲解:http…

电视剧推荐

1、《春色寄情人》 2、《唐朝诡事录》 3、《南来北往》 4、《与凤行》 5、《利剑玫瑰》 6、《承欢记》

【教程】使用立创EDA打开JSON格式的PCB及原理图

这里写目录标题 一、将PCB和原理图放同一文件夹二、打开嘉立创EDA并导入.zip文件三、选择.zip文件并选择 “导入文件并提取库” 一、将PCB和原理图放同一文件夹 并打包成.zip文件 二、打开嘉立创EDA并导入.zip文件 嘉立创 我这里用的网页端,客户端下载页面拉到…

html的网页制作代码分享

<!-- prj_8_2.html --> <!DOCTYPE html> <html lang "EN"><head><meta charset"utf-8" /><title>页面布局设计</title><style type "text/css">*{padding: 0px;margin:0px;}#header{back…

Spring中IOC容器

IoC IOC容器 IoC是一种设计思想&#xff0c;面向对象编程 Spring通过IoC管理所有Java对象的实例化和初始化&#xff0c;控制对象之间依赖关系 将IoC容器管理的Java对象称为Spring Bean&#xff0c;与new创建的对象没有区别 控制反转&#xff08;IoC Inversion of Controle&a…

世优科技AI数字人多模态交互系统“世优波塔”正式发布

2024年6月6日&#xff0c;世优科技“波塔发布会”在北京举办&#xff0c;本次发布会上&#xff0c;世优科技以全新的“波塔”产品诠释了更高效、更智能、更全面的AI数字人产品及软硬件全场景解决方案&#xff0c;实现了世优品牌、产品和价值的全面跃迁。来自行业协会、数字产业…

大众点评全国丽人POI采集225万家-2024年5月底

大众点评全国丽人POI采集225万家-2024年5月底 店铺POI点位示例&#xff1a; 店铺id Hav6zIYtzhyyopIZ 店铺名称 防屏蔽 十分制服务评分 8.9 十分制环境评分 8.9 十分制划算评分 8.9 人均价格 210 评价数量 19935 店铺地址 建北一支路观音桥步行街红鼎国际A座9-9 店铺…

英伟达GPU对比分析:A100、A800、H100与H800

在当今技术迅速发展的时代&#xff0c;英伟达的GPU产品线提供了多种高性能选项&#xff0c;以满足不同类型的工作负载需求。本文将对英伟达的四种GPU型号——A100、A800、H100和H800进行深入对比分析&#xff0c;探讨它们在性能、架构、应用场景等方面的差异&#xff0c;以帮助…

LIN 入门(1)

1、概述 LIN 是什么 LIN 是 Local Interconnect Network 的缩写&#xff0c;是基于 UART/SCI(Universal Asynchronous Receiver-Transmitter / Serial Communication Interface&#xff0c;通用异步收发器/串行通信接口)的低成本串行通信协议。可用于汽车、家电、办 公设备等…

代码随想录-二叉树 | 111 二叉树的最小深度

代码随想录-二叉树 | 111 二叉树的最小深度 LeetCode 111 二叉树的最小深度解题思路代码难点总结 LeetCode 111 二叉树的最小深度 题目链接 代码随想录 题目描述 给定一个二叉树&#xff0c;找出其最小深度。 最小深度是从根节点到最近叶子节点的最短路径上的节点数量。 说…

讯飞星火模型-语音转文字实现

目录 项目结构 准备音频 接口Demo 准备代码&#xff08;完整修改后&#xff09; 测试提取中文文字代码 结果 下载链接&#xff1a; 这是上周打算试试&#xff0c;提取视频文字之后&#xff0c;制作视频字幕&#xff0c;从而想用大模型来实现&#xff0c;基本的demo可以在…