Chapter 4.4:Adding shortcut connections

4 Implementing a GPT model from Scratch To Generate Text

4.4 Adding shortcut connections

  • 接下来,让我们讨论 shortcut connections(快捷连接)背后的概念,也称为 skip connections(跳跃连接)或 residual connections(残差连接)。最初,在残差网络(ResNet)中提出,用于缓解梯度消失问题。梯度消失问题指的是梯度(在训练过程中指导权重更新)在向后传播通过各层时逐渐变小,导致难以有效训练较早的层,如下图所示

    对比一个由 5 层组成的深度神经网络,左侧没有快捷连接,右侧带有快捷连接。快捷连接涉及将某一层的输入与其输出相加,从而有效地创建一条绕过某些层的替代路径

    工作原理

    1. 创建更短的梯度路径,跳过中间层。
    2. 通过将某一层的输出与后面某一层的输出相加实现。
  • 前向方法中添加快捷连接

    class ExampleDeepNeuralNetwork(nn.Module):
        def __init__(self, layer_sizes, use_shortcut):
            super().__init__()
            self.use_shortcut = use_shortcut
            self.layers = nn.ModuleList([
                nn.Sequential(nn.Linear(layer_sizes[0], layer_sizes[1]), GELU()),
                nn.Sequential(nn.Linear(layer_sizes[1], layer_sizes[2]), GELU()),
                nn.Sequential(nn.Linear(layer_sizes[2], layer_sizes[3]), GELU()),
                nn.Sequential(nn.Linear(layer_sizes[3], layer_sizes[4]), GELU()),
                nn.Sequential(nn.Linear(layer_sizes[4], layer_sizes[5]), GELU())
            ])
    
        def forward(self, x):
            for layer in self.layers:
                # Compute the output of the current layer
                layer_output = layer(x)
                # Check if shortcut can be applied
                if self.use_shortcut and x.shape == layer_output.shape:
                    x = x + layer_output
                else:
                    x = layer_output
            return x
    

    该代码实现了一个包含 5 层的深度神经网络,每层由一个 Linear layer(线性层)和一个 GELU activation function(GELU 激活函数)组成。在前向传播过程中,我们迭代地将输入传递到各层,如果 self.use_shortcut 属性设置为 True,则可以选择性地添加上面图中的 shortcut connections(快捷连接)。

    def print_gradients(model, x):
        # Forward pass
        output = model(x)
        target = torch.tensor([[0.]])
    
        # Calculate loss based on how close the target
        # and output are
        loss = nn.MSELoss()
        loss = loss(output, target)
        
        # Backward pass to calculate the gradients
        loss.backward()
    
        for name, param in model.named_parameters():
            if 'weight' in name:
                # Print the mean absolute gradient of the weights
                print(f"{name} has gradient mean of {param.grad.abs().mean().item()}")
    

    接下来,我们实现一个计算模型向后传递中的梯度的函数:

    def print_gradients(model, x):
        # Forward pass
        output = model(x)
        target = torch.tensor([[0.]])
    
        # Calculate loss based on how close the target
        # and output are
        loss = nn.MSELoss()
        loss = loss(output, target)
        
        # Backward pass to calculate the gradients
        loss.backward()
    
        for name, param in model.named_parameters():
            if 'weight' in name:
                # Print the mean absolute gradient of the weights
                print(f"{name} has gradient mean of {param.grad.abs().mean().item()}")
    

    在前面的代码中,我们定义了一个损失函数来计算模型输出与目标值(如 0)的接近程度,并通过调用 loss.backward() 自动计算每一层的损失梯度。使用 model.named_parameters() 可以遍历权重参数,例如对于 3×3 的权重矩阵,计算其 3×3 梯度值的平均绝对梯度,从而得到每一层的单一梯度值,便于比较各层梯度。.backward() 方法的优势在于自动完成梯度计算,无需手动实现数学过程,极大地简化了深度神经网络的训练和使用。

    接下来我们首先打印没有使用shortcut的梯度

    # 未使用shortcut
    layer_sizes = [3, 3, 3, 3, 3, 1]  
    sample_input = torch.tensor([[1., 0., -1.]])
    
    torch.manual_seed(123)
    model_without_shortcut = ExampleDeepNeuralNetwork(
        layer_sizes, use_shortcut=False
    )
    print_gradients(model_without_shortcut, sample_input)
    
    """输出"""
    layers.0.0.weight has gradient mean of 0.00020173587836325169
    layers.1.0.weight has gradient mean of 0.0001201116101583466
    layers.2.0.weight has gradient mean of 0.0007152041653171182
    layers.3.0.weight has gradient mean of 0.001398873864673078
    layers.4.0.weight has gradient mean of 0.005049646366387606
    

    接着打印使用shortcut的梯度

    # 使用shortcut
    torch.manual_seed(123)
    model_with_shortcut = ExampleDeepNeuralNetwork(
        layer_sizes, use_shortcut=True
    )
    print_gradients(model_with_shortcut, sample_input)
    
    """输出"""
    layers.0.0.weight has gradient mean of 0.22169792652130127
    layers.1.0.weight has gradient mean of 0.20694106817245483
    layers.2.0.weight has gradient mean of 0.32896995544433594
    layers.3.0.weight has gradient mean of 0.2665732502937317
    layers.4.0.weight has gradient mean of 1.3258541822433472
    

    根据上面的输出结果可以看出,shortcut connections(快捷连接)防止了梯度在早期层(如 layer.0)中消失,确保梯度的有效传播。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/950049.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Web渗透测试之XSS跨站脚本 原理 出现的原因 出现的位置 测试的方法 危害 防御手段 面试题 一篇文章给你说的明明白白

目录 XSS介绍的原理和说明 Cross Site Scripting 钓鱼 XSS攻击原理 XSS漏洞出现的原因: XSS产生的原因分析 XSS出现位置: XSS测试方法 XSS的危害 防御手段: 其它防御 面试题: 备注: XSS介绍的原理和说明 嵌入在客户…

热门数据手套对比,应用方向有何不同?

AI与人形机器人是目前市场中大热的两个新行业。在人形机器人或拟人仿真机器人制造与开发中动作捕捉技术的融入是必不可少的,通过将动捕数据与先进的AI大数据训练技术相结合,不仅能够省去枯燥乏味的动作编程过程大幅减少训练时间,还可以使训练…

dbt Semantic Layer 详细教程-1 :总体概述

dbt 语义模型提供语言描述方式快速定义业务指标。本文介绍语义模型作用和意义,以及语义模型的组成部分,后面会继续介绍如何定义语义模型,基于语义模型定义指标,如何通过MetricFlow(语义层框架)能够构建用于…

JAVA:探讨 CopyOnWriteArrayList 的详细指南

1、简述 在 Java 的并发编程中,CopyOnWriteArrayList 是一种特殊的线程安全的集合类。它位于 java.util.concurrent 包中,主要用于在并发读写场景下提供稳定的性能。与传统的 ArrayList 不同,CopyOnWriteArrayList 通过在每次修改时创建一个…

简单编程实现QT程序黑色主题显示

代码如下 int main(int argc, char *argv[]) {QApplication a(argc, argv);//QSurfaceFormat::setDefaultFormat(QVTKOpenGLStereoWidget::defaultFormat());QPalette darkpalette;a.setStyle(QStyleFactory::create("Fusion"));darkpalette.setColor(QPalette::Wind…

沁恒CH32V208GBU6外设PWM:注意分辨时钟使能函数RCC_APB2PeriphClockCmd;PWM模式1和模式2的区别;PWM动态开启和关闭

从事嵌入式单片机的工作算是符合我个人兴趣爱好的,当面对一个新的芯片我即想把芯片尽快搞懂完成项目赚钱,也想着能够把自己遇到的坑和注意事项记录下来,即方便自己后面查阅也可以分享给大家,这是一种冲动,但是这个或许并不是原厂希望的,尽管这样有可能会牺牲一些时间也有哪天原…

飞书企业消息实践

一、飞书自带的消息机器人限制 频控策略 - 服务端 API - 飞书开放平台 自定义机器人的频率控制和普通应用不同,为单租户单机器人 100 次/分钟,5 次/秒。建议发送消息尽量避开诸如 10:00、17:30 等整点及半点时间,否则可能出现因系统压力导致…

0107作业

思维导图 练习: 要求在堆区连续申请5个int的大小空间用于存储5名学生的成绩&#xff0c;分别完成空间的申请、成绩的录入、升序 排序、 成绩输出函数以及空间释放函数&#xff0c;并在主程序中完成测试 要求使用new和delete完成 #include <iostream>using namespace std…

以C++为基础快速了解C#

using System: - using 关键字用于在程序中包含 System 命名空间。 一个程序一般有多个 using 语句, 相当于C的 using namespace std; C# 是大小写敏感的。 所有的语句和表达式必须以分号&#xff08;;&#xff09;结尾。 程序的执行从 Main 方法开始。 与 Java 不同的是&#…

面试题:并发与并行的区别?

并发&#xff08;Concurrency&#xff09;和并行&#xff08;Parallelism&#xff09;是计算机科学中两个相关但不同的概念&#xff0c;它们都涉及到同时处理多个任务&#xff0c;但在实现方式和效果上有显著的区别。理解这两者的区别对于编写高效的多任务程序非常重要。 并发&…

面向对象分析和设计OOA/D,UML,GRASP

目录 什么是分析和设计&#xff1f; 什么是面向对象的分析和设计&#xff1f; 迭代开发 UML 用例图 交互图 基于职责驱动设计 GRASP 常见设计原则 什么是分析和设计&#xff1f; 分析&#xff0c;强调是对问题和需求的调查研究&#xff0c;不是解决方案。例如&#x…

MySQL使用navicat新增触发器

找到要新增触发器的表&#xff0c;然后点击设计&#xff0c;找到触发器标签。 根据实际需要&#xff0c;填写相关内容&#xff0c;操作完毕&#xff0c;点击保存按钮。 在右侧的预览界面&#xff0c;可以看到新生成的触发器脚本

Anthropic 的人工智能 Claude 表现优于 ChatGPT

在人工智能领域&#xff0c;竞争一直激烈&#xff0c;尤其是在自然语言处理&#xff08;NLP&#xff09;技术的发展中&#xff0c;多个公司都在争夺市场的主导地位。OpenAI的ChatGPT和Anthropic的Claude是目前最具影响力的两款对话型AI产品&#xff0c;它们都能够理解并生成自然…

《罪恶装备-奋战》官方中文学习版

《罪恶装备 -奋战-》是Arc System Works开发的格斗游戏《罪恶装备》系列的第二十五部作品 [1]&#xff0c;男主角索尔历时25年的故事就此画上句号&#xff0c;而罪恶装备的故事却并未结束。 《罪恶装备-奋战》官方中文版 https://pan.xunlei.com/s/VODWAm1Dv-ZWVvvmUMflgbbxA1…

期末概率论总结提纲(仅适用于本校,看文中说明)

文章目录 说明A选择题1.硬币2.两个事件的关系 与或非3.概率和为14.概率密度 均匀分布5.联合分布率求未知参数6.联合分布率求未知参数7.什么是统计量&#xff08;记忆即可&#xff09;8.矩估计量9.117页12题10.显著水平阿尔法&#xff08;背公式就完了&#xff09; 判断题11.事件…

流程图(四)利用python绘制漏斗图

流程图&#xff08;四&#xff09;利用python绘制漏斗图 漏斗图&#xff08;Funnel Chart&#xff09;简介 漏斗图经常用于展示生产经营各环节的关键数值变化&#xff0c;以较高的头部开始&#xff0c;较低的底部结束&#xff0c;可视化呈现各环节的转化效率与变动大小。一般重…

继承(5)

大家好&#xff0c;今天我们继续来学习继承的相关知识&#xff0c;来看看子类构造方法&#xff08;也叫做构造器&#xff09;是如何做的。 1.6 子类构造方法 父子父子,先有父再有子,即:子类对象构选时,需要先调用基类构造方法,然后执行子类的构造方法 ★此时虽然执行了父类的…

Vue框架主要用来做什么?Vue框架的好处和特性.

在快速发展的互联网时代&#xff0c;前端开发技术的变革日新月异&#xff0c;为开发者带来了前所未有的机遇与挑战。Vue.js&#xff0c;作为前端开发领域的一颗璀璨新星&#xff0c;以其轻量级、高效灵活的特性&#xff0c;赢得了广大开发者的青睐。本文将深入探讨Vue框架的主要…

搭建Hadoop分布式集群

软件和操作系统版本 Hadoop框架是采用Java语言编写&#xff0c;需要java环境&#xff08;jvm&#xff09; JDK版本&#xff1a;JDK8版本 &#xff0c;本次使用的是 Java: jdk-8u431-linux-x64.tar.gz Hadoop: hadoop-3.3.6.tar.gz 三台Linux虚拟节点: CentOS-7-x86_64-DVD-2…

电子应用设计方案87:智能AI收纳箱系统设计

智能 AI 收纳箱系统设计 一、引言 智能 AI 收纳箱系统旨在为用户提供更高效、便捷和智能的物品收纳与管理解决方案&#xff0c;通过融合人工智能技术和创新设计&#xff0c;提升用户的生活品质和物品整理效率。 二、系统概述 1. 系统目标 - 实现物品的自动分类和整理&#xf…