pytorch06:权重初始化

在这里插入图片描述

目录

  • 一、梯度消失和梯度爆炸
    • 1.1相关概念
    • 1.2 代码实现
    • 1.3 实验结果
    • 1.4 方差计算
    • 1.5 标准差计算
    • 1.6 控制网络层输出标准差为1
    • 1.7 带有激活函数的权重初始化
  • 二、Xavier方法与Kaiming方法
    • 2.1 Xavier初始化
    • 2.2 Kaiming初始化
    • 2.3 常见的初始化方法
  • 三、nn.init.calculate_gain

一、梯度消失和梯度爆炸

1.1相关概念

一个简易三层全连接神经网络图和神经元计算如下:
在这里插入图片描述
观察第二个隐藏层的权值的梯度是如何求取的,根据链式法则,可以得到如下计算公式,会发现w2的梯度依赖上一层的输出值H1;
在这里插入图片描述
当H1趋近于0的时候,W2的梯度也趋近于0;—>梯度消失
当H1趋近于无穷的时候,W2的梯度也趋近于无穷;—>梯度爆炸
在这里插入图片描述
一旦出现梯度消失或者梯度爆炸就会导致模型无法训练;

1.2 代码实现

import os
import torch
import random
import numpy as np
import torch.nn as nn
from common_tools import set_seed

set_seed(1)  # 设置随机种子


class MLP(nn.Module):
    def __init__(self, neural_num, layers):
        super(MLP, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)])
        self.neural_num = neural_num

    def forward(self, x):
        for (i, linear) in enumerate(self.linears):
            x = linear(x)
            # x = torch.relu(x)
            # x = torch.tanh(x)

            print("layer:{}, std:{}".format(i, x.std()))  # 打印当前值的标准差
            if torch.isnan(x.std()):  # 判断是什么时候标准差为nan
                print("output is nan in {} layers".format(i))
                break

        return x

    # 权值初始化函数
    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):  # 判断当前网络层是否是线性层,如果是就进行权值初始化
                nn.init.normal_(m.weight.data)  # normal: mean=0, 控制标准差std在1左右
                # nn.init.normal_(m.weight.data, std=np.sqrt(1 / self.neural_num))

                # =======这段代码的目的是通过均匀分布初始化并结合tanh激活函数的特性,为神经网络的某一层(线性层)初始化合适的权重
                # a = np.sqrt(6 / (self.neural_num + self.neural_num))
                # tanh_gain = nn.init.calculate_gain('tanh')
                # a *= tanh_gain
                # nn.init.uniform_(m.weight.data, -a, a)
                # 将权重矩阵的值初始化为在 [-a, a] 范围内均匀分布的随机数。这个范围是通过之前的计算和调整得到的,目的是使得权重初始化在一个合适的范围内

                # nn.init.xavier_uniform_(m.weight.data, gain=tanh_gain)

                # ================凯明初始化方法================
                # nn.init.normal_(m.weight.data, std=np.sqrt(2 / self.neural_num))  # 适合relu激活函数初始化 凯明初始化手动计算方法
                # nn.init.kaiming_normal_(m.weight.data)


# flag = 0
flag = 1

if flag:
    layer_nums = 100  # 100层线性层
    neural_nums = 256  # 每增加一层网络 标准差扩大根号n倍
    batch_size = 16

    net = MLP(neural_nums, layer_nums)
    print(net)
    net.initialize()

    inputs = torch.randn((batch_size, neural_nums))  # normal: mean=0, std=1

    output = net(inputs)
    print(output)

1.3 实验结果

这里的初始化使用的是标准正态分布normal: mean=0, 控制标准差std在1左右的方法;
在这里插入图片描述
当输出层达到33层后就会出现梯度爆炸,超出了数据精度可以表示的范围。

1.4 方差计算

在这里插入图片描述
1.期望的计算公式
2,3.是方差的计算公式
根据1,2,3,可以得出,x,y的方差计算公式,当x,y的期望值都为0的时候,x,y的方差等于x的方差乘以y的方差。

1.5 标准差计算

在这里插入图片描述
通过计算可以得出每增加一层网络,标准差增加 n \sqrt{n} n ,n也就是神经元的个数;
代码展示:

if flag:
    layer_nums = 100  # 100层线性层
    neural_nums = 256  # 神经元个数 每增加一层网络 标准差扩大根号n倍
    batch_size = 16

执行结果:
可以看出第一层标准差是15.95,第二次标准差在上一层的基础上再乘以 256 \sqrt{256} 256
在这里插入图片描述

1.6 控制网络层输出标准差为1

从1.5可以看出D(H)的大小有三个因素决定,分别是n、D(X)、D(w),所以只要保证这三者乘积为1,就可以保证D(H)的值为1;
在这里插入图片描述
当我们权值的标准差为 1 / n \sqrt{1/n} 1/n ,那么就能保证网络层每一层的输出标准差都为1;

代码实现:
在这里插入图片描述

输出结果:
在这里插入图片描述
通过输出结果可以发现,几乎每一层网络输出的标准差都为1.

1.7 带有激活函数的权重初始化

在forward函数里面添加tanh激活函数
在这里插入图片描述
执行结果:
增加tanh激活函数之后,随着网络层的增加,标准差越来越小,从而会导致梯度消失的现象,下面将说明Xavier方法与Kaiming方法是如何解决该问题。
在这里插入图片描述

二、Xavier方法与Kaiming方法

2.1 Xavier初始化

方差一致性:保持数据尺度维持在恰当范围,通常方差为1
激活函数:饱和函数,如Sigmoid,Tanh
Xavier初始化公式如下:
在这里插入图片描述

代码实现:
手动代码实现
在这里插入图片描述

直接使用pytorch提供的xavier_uniform_函数方法

nn.init.xavier_uniform_(m.weight.data, gain=tanh_gain)

执行结果:
在这里插入图片描述
可以看到,每一层的网络输出标准差都在0.6左右

2.2 Kaiming初始化

当我们使用带有权值初始化的relu激活函数时,输出结果如下,会发现标准差随着网络层的增加逐渐减小,Kaiming初始化解决了这一问题。
在这里插入图片描述
在这里插入图片描述

方差一致性:保持数据尺度维持在恰当范围,通常方差为1
激活函数:ReLU及其变种
公式如下:
在这里插入图片描述

代码实现:

# ================凯明初始化方法================
nn.init.normal_(m.weight.data, std=np.sqrt(2 / self.neural_num))  # 适合relu激活函数初始化 凯明初始化手动计算方法
# nn.init.kaiming_normal_(m.weight.data)  # 使用pytorch自带方法

输出结果:
在这里插入图片描述

2.3 常见的初始化方法

  1. Xavier均匀分布
  2. Xavier正态分布
  3. Kaiming均匀分布
  4. Kaiming正态分布
  5. 均匀分布
  6. 正态分布
  7. 常数分布
  8. 正交矩阵初始化
  9. 单位矩阵初始化
  10. 稀疏矩阵初始化

三、nn.init.calculate_gain

主要功能:计算激活函数的方差变化尺度(也就是输入数据的方差/经过激活函数之后的方差)
主要参数
• nonlinearity: 激活函数名称
• param: 激活函数的参数,如Leaky ReLU的negative_slop

代码实现:

flag = 1

if flag:
    x = torch.randn(10000)
    out = torch.tanh(x)

    gain = x.std() / out.std()  # 手动计算
    print('gain:{}'.format(gain))

    tanh_gain = nn.init.calculate_gain('tanh')  # pytorch自带函数
    print('tanh_gain in PyTorch:', tanh_gain)

输出结果:
在这里插入图片描述
总结:任何数据在经过tanh激活函数之后,方差缩小大约1.6倍。感兴趣的话也可以使用relu进行实验,最后我的到的结果方差尺度大约是1.4左右。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/289803.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

程序媛的mac修炼手册--MacOS系统更新升级史

啊,我这个口罩三年从未感染过新冠的天选免疫王,却被支原体击倒😷大意了,前几天去医院体检,刚检查完出医院就摘口罩了🤦大伙儿还是要注意戴口罩,保重身体啊!身体欠恙,就闲…

Excel 插件:ASAP Utilities Crack

ASAP Utilities是一款功能强大的 Excel 插件,填补了 Excel 的空白。在过去的 20 年里,我们的加载项已经发展成为世界上最受欢迎的 Microsoft Excel 加载项之一。 ASAP Utilities 中的功能数量(300 多个)可能看起来有点令人眼花缭乱…

Navicat 技术干货 | 聚合查询的介绍

基础 SQL 查询可以检索、插入、更新和删除记录,而聚合查询可通过提供求和、平均值或最大/最小值等的大型结果集,将数据库交互提升到一个新的水平。本文中,我们将探索聚合 SQL 查询的基础知识,并研究如何有效的利用他们来分析和汇总…

Unity中URP下的线性雾

文章目录 前言一、线性雾 雾效因子二、MixFog1、ComputeFogIntensity 雾效强度计算2、雾效颜色混合 lerp(fogColor, fragColor, fogIntensity); 前言 在之前的文章中,我们实现了URP下的雾效支持。 Unity中URP下的添加雾效支持 在上一篇文章中,我们解析了 URP 下统…

HCIA-Datacom题库(自己整理分类的)——其他网络协议【完】

(一)单选 下列属于链路状态协议的是? Direct static FTP OSPF 解析: FTP:文件传输协议 OSPF:链路状态路由协议 如下图所示的网络主机A通过Telnet登录到路由器A然后在远程的界面通过FTP获取路由器的配置文件&…

【已解决】Invalid bound statement (not found)

报错讯息 org.apache.ibatis.binding.BindingException: Invalid bound statement (not found): com.casey.mapper.SysRoleMapper.getUserRoleCode at org.apache.ibatis.binding.MapperMethod S q l C o m m a n d . < i n i t > ( M a p p e r M e t h o d . j a v a :…

iec104和iec61850

iec104和iec61850 IEC104 规约详细解读(一) 协议结构 IEC104 规约详细解读(二)交互流程以及协议解析 61850开发知识总结与分享【1】 Get the necesarry projects next to each other in the same directory; $ git clone https://github.com/robidev/iec61850_open_server.g…

【GoLang入门教程】Go语言几种标准库介绍(四)

编程语言的未来&#xff1f; 文章目录 编程语言的未来&#xff1f;前言几种库fmt库 (格式化操作)关键函数&#xff1a;示例 Go库标准库第三方库示例 html库(HTML 转义及模板系统)主要功能&#xff1a;示例 总结专栏集锦写在最后 前言 上一篇&#xff0c;我们介绍了debug、enco…

Vue2 - diff 原理(动图演示)

目录 1&#xff0c;diffdiff 的时间点 2&#xff0c;_update 函数3&#xff0c;_patch 函数&#xff08;进行 diff&#xff09;3.1&#xff0c;根节点比较3.2&#xff0c;子节点比较 4&#xff0c;key的问题举例1举例2 1&#xff0c;diff 解释&#xff1a;对比新旧虚拟DOM树&a…

很实用的ChatGPT网站—在线编程模块增补篇

很实用的ChatGPT网站&#xff08;http://chat-zh.com/&#xff09;——增补篇 今天介绍一个好兄弟开发的ChatGPT网站&#xff0c;网址[http://chat-zh.com/]。这个网站功能模块很多&#xff0c;包含生活、学习、医疗、法律、经济等很多方面。今天跟大家分享一下&#xff0c;新…

探索模块化神经网络在现代人工智能中的功效和应用

一、介绍 在快速发展的人工智能领域&#xff0c;模块化神经网络 (MNN) 已成为一项关键创新。与遵循整体方法的传统神经网络架构不同&#xff0c;MNN 采用分散式结构。本文深入探讨了 MNN 的基础知识、它们的优势、应用以及它们带来的挑战。 evertongomede 在人工智能领域&#…

【Java期末】学生成绩管理系统

诚接计算机专业编程任务(C语言、C、Python、Java、HTML、JavaScript、Vue等)10/15R&#xff0c;如有需要请私信我&#xff0c;或者加我的企鹅号&#xff1a;1404293476 本文资源下载地址&#xff1a;https://download.csdn.net/download/weixin_47040861/88697244 —————…

imgaug库指南(一):从入门到精通的【图像增强】之旅

文章目录 引言imgaug简介安装和导入imgaug代码示例imgaug的强大之处和用途小结结尾 引言 在深度学习和计算机视觉的世界里&#xff0c;数据是模型训练的基石&#xff0c;其质量与数量直接影响着模型的性能。然而&#xff0c;获取大量高质量的标注数据往往需要耗费大量的时间和…

Spark---RDD算子(单值类型Value)

文章目录 1.RDD算子介绍2.转换算子2.1 Value类型2.1.1 map2.1.2 mapPartitions2.1.3 mapPartitionsWithIndex2.1.4 flatMap2.1.5 glom2.1.6 groupBy2.1.7 filter2.1.8 sample2.1.9 distinct2.1.10 coalesce2.1.11 repartition2.1.12 sortBy 1.RDD算子介绍 RDD算子是用于对RDD进…

【数据结构】循环队列(数组实现)

目录 一、循环队列定义 怎么使一个数组在逻辑上呈“环状”呢&#xff1f; 二、循环队列与顺序队列的差异 1、存储方式: 2、操作方式: 3、空间利用率&#xff1a; 4、循环队列判断队空的方式&#xff1a; 5、循环队列判断队满的方式 完整测试代码及注释&#xff1a; 总…

Vue 中的 ref 与 reactive:让你的应用更具响应性(上)

&#x1f90d; 前端开发工程师&#xff08;主业&#xff09;、技术博主&#xff08;副业&#xff09;、已过CET6 &#x1f368; 阿珊和她的猫_CSDN个人主页 &#x1f560; 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 &#x1f35a; 蓝桥云课签约作者、已在蓝桥云…

用HTML的原生语法实现两个div子元素在同一行中排列

代码如下&#xff1a; <div id"level1" style"display: flex;"><div id"level2-1" style"display: inline-block; padding: 10px; border: 1px solid #ccc; margin: 5px;">这是第一个元素。</div><div id"…

计算机系统基础

C 语言相关内容省略&#xff0c;复习自用&#xff0c;仅供参考~ 概述 冯诺伊曼结构 存储程序工作方式&#xff1a;将事先编好的程序和原始数据送入主存后才能执行程序&#xff0c;程序被启动执行后&#xff0c;计算机能在不需要操作人员干预下自动完成逐条指令取出和执行的任…

解析为什么Go语言要使用[]rune而不是string来表示中文字符

众所周知&#xff0c;Go语言中有以下这些数据类型。但rune32这个go语言特有的数据类型&#xff0c;比较有意思却经常遭到忽视。所以今天探索学习一下这个数据类型的功能、用法。 Go基本数据类型 布尔&#xff1a;bool 字符串&#xff1a;string 整数&#xff1a; int int8 …

NNDL 作业13 优化算法3D可视化 [HBU]

老师作业原博客:【23-24 秋学期】NNDL 作业13 优化算法3D可视化-CSDN博客 NNDL 作业13 优化算法3D可视化-CSDN博客 编程实现优化算法&#xff0c;并3D可视化 1. 函数3D可视化 分别画出 和 的3D图 NNDL实验 优化算法3D轨迹 鱼书例题3D版_优化算法3d展示-CSDN博客 代码&#…