神经网络系列---权重初始化方法


文章目录

    • 权重初始化方法
      • Xavier初始化(Xavier initialization)
      • Kaiming初始化,也称为He初始化
      • LeCun 初始化
      • 正态分布与均匀分布
      • Orthogonal Initialization
      • Sparse Initialization
      • n_in和n_out
      • 代码实现


权重初始化方法

Xavier初始化(Xavier initialization)

是一种用于初始化神经网络权重的方法,也称为Glorot初始化。更有效地传播信号并减少梯度消失或梯度爆炸的问题。适用于激活函数为tanh或sigmoid的情况。

Xavier初始化的计算方法如下:

  1. Glorot(或 Xavier)初始化
    • 适用于激活函数如sigmoid和tanh。
    • 初始化公式: σ = 2 n in + n out \sigma = \sqrt{\frac{2}{n_{\text{in}} + n_{\text{out}}}} σ=nin+nout2
      其中, n in n_{\text{in}} nin 是输入单元数, n out n_{\text{out}} nout 是输出单元数。

对于单个神经元的权重w,从均匀分布正态分布中随机采样,具体取决于所选择的激活函数:

在这里插入图片描述

  1. 如果使用tanh激活函数,从均匀分布采样:
    • 采样范围:[-sqrt(6 / (n_in + n_out)), sqrt(6 / (n_in + n_out))]
    • 其中n_in是上一层的输入节点数量,n_out是当前层的输出节点数量。

在这里插入图片描述

  1. 如果使用sigmoid激活函数,从正态分布采样:
    • 均值:0
    • 方差:sqrt(2 / (n_in + n_out))
    • 其中n_in是上一层的输入节点数量,n_out是当前层的输出节点数量。

Kaiming初始化,也称为He初始化

  1. He 初始化
    • 适用于ReLU及其变种(如LeakyReLU)激活函数。
    • 初始化公式: σ = 2 n in \sigma = \sqrt{\frac{2}{n_{\text{in}}}} σ=nin2

这种初始化方法主要用于修正线性单元(Rectified Linear Units,ReLU)激活函数的神经网络。

与Xavier初始化适用于tanh和sigmoid等S型激活函数不同,Kaiming初始化专门针对ReLU激活函数的特性进行优化。ReLU是一个常用的非线性激活函数,它在输入大于零时保持不变,在输入小于等于零时输出为零。

Kaiming初始化的计算方法如下:

对于单个神经元的权重w,从均匀分布或正态分布中随机采样,具体取决于所选择的激活函数:

  1. 如果使用ReLU激活函数,从正态分布采样:

    • 均值:0
    • 方差:sqrt(2 / n_in)
    • 其中n_in是上一层的输入节点数量。
  2. 对于带有ReLU激活的卷积层,可以使用相同的初始化方法,只是需要考虑卷积层的输入通道数量(即n_in)。

LeCun 初始化

  • 适用于Sigmoid激活函数。
  • 初始化公式: σ = 1 n in \sigma = \sqrt{\frac{1}{n_{\text{in}}}} σ=nin1

正态分布与均匀分布

  • 使用较小的标准差(如0.01)从正态分布中采样权重。
  • 使用较小的范围(如-0.01到0.01)从均匀分布中采样权重。

Orthogonal Initialization

  • 使用正交矩阵初始化权重。这种初始化方法对于某些任务和模型架构可能很有益。

Sparse Initialization

  • 将大部分权重初始化为0,只初始化一小部分非零的权重。

n_in和n_out

n_inn_out分别表示神经网络层的输入节点数量和输出节点数量。这些节点也称为神经元,它们是网络的基本组成部分。

  • n_in:代表上一层(前一层)的节点数量,也就是当前层的输入数量。在神经网络中,每个神经元都会接收来自上一层所有节点的输入,这些输入被加权和后传递给当前神经元的激活函数。因此,n_in指的是上一层与当前层之间的连接数量。

  • n_out:代表当前层的节点数量,也就是当前层的输出数量。每个神经元会将经过激活函数处理后的结果传递给下一层所有节点,形成下一层的输入。因此,n_out指的是当前层与下一层之间的连接数量。

代码实现

#include <iostream>
#include <Eigen/Dense>
#include <random>
#include <cmath>

Eigen::MatrixXd glorotInitialize(int rows, int cols);
Eigen::MatrixXd heInitialize(int rows, int cols);
Eigen::MatrixXd lecunInitialize(int rows, int cols);
Eigen::MatrixXd normalDistributionInitialize(int rows, int cols, double std_dev=0.01);
Eigen::MatrixXd uniformDistributionInitialize(int rows, int cols, double limit=0.01);
Eigen::MatrixXd orthogonalInitialize(int rows, int cols);
// Sparse Initialization需要额外参数来确定稀疏度,这里我们使用一个简化版本,指定一个非零的权重数。
Eigen::MatrixXd sparseInitialize(int rows, int cols, int nonZeroCount);



//1. **Glorot (Xavier) Initialization**:

Eigen::MatrixXd glorotInitialize(int rows, int cols) {
    std::random_device rd;
    std::mt19937 gen(rd());
    double limit = sqrt(6.0 / (rows + cols));
    std::uniform_real_distribution<> dis(-limit, limit);

    Eigen::MatrixXd matrix(rows, cols);
    for(int i = 0; i < rows; i++) {
        for(int j = 0; j < cols; j++) {
            matrix(i, j) = dis(gen);
        }
    }
    return matrix;
}


//**He Initialization**:

Eigen::MatrixXd heInitialize(int rows, int cols) {
    std::random_device rd;
    std::mt19937 gen(rd());
    double std_dev = sqrt(2.0 / rows);
    std::normal_distribution<> dis(0, std_dev);

    Eigen::MatrixXd matrix(rows, cols);
    for(int i = 0; i < rows; i++) {
        for(int j = 0; j < cols; j++) {
            matrix(i, j) = dis(gen);
        }
    }
    return matrix;
}


//3. **LeCun Initialization**:

Eigen::MatrixXd lecunInitialize(int rows, int cols) {
    std::random_device rd;
    std::mt19937 gen(rd());
    double std_dev = sqrt(1.0 / rows);
    std::normal_distribution<> dis(0, std_dev);

    Eigen::MatrixXd matrix(rows, cols);
    for(int i = 0; i < rows; i++) {
        for(int j = 0; j < cols; j++) {
            matrix(i, j) = dis(gen);
        }
    }
    return matrix;
}


//4. **Normal Distribution Initialization**:

Eigen::MatrixXd normalDistributionInitialize(int rows, int cols, double std_dev) {
    std::random_device rd;
    std::mt19937 gen(rd());
    std::normal_distribution<> dis(0, std_dev);

    Eigen::MatrixXd matrix(rows, cols);
    for(int i = 0; i < rows; i++) {
        for(int j = 0; j < cols; j++) {
            matrix(i, j) = dis(gen);
        }
    }
    return matrix;
}

//5. **Uniform Distribution Initialization**:

Eigen::MatrixXd uniformDistributionInitialize(int rows, int cols, double limit) {
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<> dis(-limit, limit);

    Eigen::MatrixXd matrix(rows, cols);
    for(int i = 0; i < rows; i++) {
        for(int j = 0; j < cols; j++) {
            matrix(i, j) = dis(gen);
        }
    }
    return matrix;
}


//6. **Orthogonal Initialization**:
Eigen::MatrixXd orthogonalInitialize(int rows, int cols) {
    // 创建一个随机矩阵
    std::random_device rd;
    std::mt19937 gen(rd());
    std::normal_distribution<> dis(0, 1);

    Eigen::MatrixXd randomMatrix(rows, cols);
    for(int i = 0; i < rows; i++) {
        for(int j = 0; j < cols; j++) {
            randomMatrix(i, j) = dis(gen);
        }
    }

    // 使用QR分解获得正交矩阵
    Eigen::HouseholderQR<Eigen::MatrixXd> qr(randomMatrix);
    Eigen::MatrixXd orthogonalMatrix = qr.householderQ();

    // 如果您需要一个具有特定维度的正交矩阵(例如rows != cols),您可以选择一个子矩阵
    return orthogonalMatrix.block(0, 0, rows, cols);
}


//7. **Sparse Initialization**:

Eigen::MatrixXd sparseInitialize(int rows, int cols, int nonZeroCount) {
    Eigen::MatrixXd matrix = Eigen::MatrixXd::Zero(rows, cols);
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<> dis(-1, 1);

    for(int i = 0; i < nonZeroCount; i++) {
        int r = rand() % rows;
        int c = rand() % cols;
        matrix(r, c) = dis(gen);
    }

    return matrix;
}
int main() {
    int rows = 5;
    int cols = 5;

    // Glorot Initialization
    Eigen::MatrixXd weights_glorot = glorotInitialize(rows, cols);
    std::cout << "Glorot Initialized Weights:" << std::endl << weights_glorot << std::endl << std::endl;

    // He Initialization
    Eigen::MatrixXd weights_he = heInitialize(rows, cols);
    std::cout << "He Initialized Weights:" << std::endl << weights_he << std::endl << std::endl;

    // LeCun Initialization
    Eigen::MatrixXd weights_lecun = lecunInitialize(rows, cols);
    std::cout << "LeCun Initialized Weights:" << std::endl << weights_lecun << std::endl << std::endl;

    // Normal Distribution Initialization
    Eigen::MatrixXd weights_normal = normalDistributionInitialize(rows, cols);
    std::cout << "Normal Distribution Initialized Weights:" << std::endl << weights_normal << std::endl << std::endl;

    // Uniform Distribution Initialization
    Eigen::MatrixXd weights_uniform = uniformDistributionInitialize(rows, cols);
    std::cout << "Uniform Distribution Initialized Weights:" << std::endl << weights_uniform << std::endl << std::endl;

    // Sparse Initialization
    int nonZeroCount = 10; // As an example, set 10 weights to non-zero values
    Eigen::MatrixXd weights_sparse = sparseInitialize(rows, cols, nonZeroCount);
    std::cout << "Sparse Initialized Weights with " << nonZeroCount << " non-zero values:" << std::endl << weights_sparse << std::endl;

    return 0;
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/411424.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

ETL数据仓库的使用方式

一、ETL的过程 在 ETL 过程中&#xff0c;数据从源系统中抽取&#xff08;Extract&#xff09;&#xff0c;经过各种转换&#xff08;Transform&#xff09;操作&#xff0c;最后加载&#xff08;Load&#xff09;到目标数据仓库中。以下是 ETL 数仓流程的基本步骤&#xff1a…

工作中常见问题总结

工作中常见错误清单 1、springboot实现无数据库启动 问题 springboot往往是作为b/s系统的server端的架子来使用&#xff0c;但是有些时候&#xff0c;是作为静默的server&#xff0c;并没有界面和数据库&#xff0c;但是springboot默认是链接数据库的&#xff0c;如何解决这个…

sql-labs第46关 order by盲注

sql-labs第46关 order by盲注 来到了第46关进入关卡发现让我们输入的参数为sort&#xff0c;我们输入?sort1尝试&#xff1a; 输入?sort2,3,发现表格按照顺序进行排列输出&#xff0c;明显是使用了order by相关的函数。 我们将参数变成1进行尝试&#xff0c;就会报错&…

C++基础学习——哈希表的封装

目录 ​编辑 一&#xff0c;实现一个可封装的哈希表 1&#xff0c;哈希表的节点 2&#xff0c;哈希表的成员 3&#xff0c;哈希表成员方法的实现 4&#xff0c;迭代器的实现 5&#xff0c;在哈希表中加入迭代器 二&#xff0c;封装哈希表 1&#xff0c;unorder_map封装 2…

【hot100】跟着小王一起刷leetcode -- 49. 字母异位词分组

【【hot100】跟着小王一起刷leetcode -- 49. 字母异位词分组 49. 字母异位词分组题目解读解题思路代码实现 总结 49. 字母异位词分组 题目解读 49. 字母异位词分组 ok&#xff0c;兄弟们&#xff0c;咱们来看看这道题&#xff0c;很明显哈&#xff0c;这里的关键词是字母异位…

2024生物发酵展全面进行-飞翔泵业制造

参展企业介绍 江苏飞翔泵业制造有限公司始建于上世纪八十年代&#xff0c;二00一年根据《公司法》组建江苏飞翔泵业制造有限公司。公司集科研、设计、生产、经营、服务为一体&#xff0c;企业性质为有限责任公司。现为中国石化集团公司物资资源一级网络成员厂&#xff0c;中国石…

备考2025年考研数学(一)真题练习和解析——填空题

今天距离2025年考研预计还有10个月的时间&#xff0c;看起来挺长&#xff0c;但是对于备考2025年考研的同学来说&#xff0c;必须用好每一天。为了帮助大家提升考研数学一的成绩&#xff0c;我收集整理了1987-2024年的考研数学一的真题和解析&#xff0c;并把2015-2024年十年的…

抖店是怎么运营做起来的?抖音小店每天都需要做什么?新手必看!

大家好&#xff0c;我是电商花花。 很多人疑惑开抖音小店之后&#xff0c;选好商品上架之后每天都需要做什么&#xff1f; 不少新手在开了抖音小店之后每天除了选品之后就不知道要做些什么了。 今天给大家分享一下我们每天做抖音小店的工作内容有哪些&#xff0c;如果你是新…

matlab生成模拟的通信信号

matlab中rand函数生成均匀随机分布的随机数&#xff0c;randn生成正态分布的随机数&#xff1b; matlab来模拟一个通信信号&#xff1b; 通信信号通过信道时&#xff0c;研究时认为它会被叠加上服从正态分布的噪声&#xff1b; 先生成随机信号模拟要传输的信号&#xff0c;s…

【一】【SQL】表的增删查改(部分)

表之“增”操作 建表的操作 mysql> create table students(-> id int unsigned primary key auto_increment,-> sn int unsigned unique key,-> name varchar(20) not null,-> qq varchar(32) unique key-> ); Query OK, 0 rows affected (0.03 sec)mysql&g…

试一下newb,还是有错误呀

解题&#xff1a;原式&#xff1d; 2. 在递增的等比数列 ( a n ) (a_n) (an​)中&#xff0c;若 ( a 3 − a 1 5 2 ) (a_3 - a_1 \frac{5}{2}) (a3​−a1​25​), ( a 2 3 ) (a_2 3) (a2​3), 则公比 (q) A. ( 4 3 ) ( \frac{4}{3} ) (34​) B. ( 3 2 ) ( \frac{3}{2} …

创建vue3项目(基础)

首先打开自己的目录文件输入指令cmd 出现命令行工具 输入指令vue create 项目名称 按回车 选择第三个自己配置 根据需求选择 回车 选择自己需要的版本 出现这个 一直按回车大约5下或者6下 创建完毕 结束 感谢观看

Flutter(二):Row、Column 布局

MaterialApp 对于 MaterialApp&#xff0c;组件提供了一些默认的属性&#xff0c;如AppBar、标题、背景颜色等&#xff0c;你可以默认使用它们 import package:flutter/material.dart;void main() {runApp(const App()); }class App extends StatelessWidget {const App({super…

《名作欣赏》期刊投稿方向投稿邮箱

《名作欣赏》杂志是由国家新闻出版总署批准的正规学术期刊。文学是用语言塑造形象反映社会生活的一种语言艺术&#xff0c;是自觉、独立而又面向整个社会的艺术&#xff0c;是文化中极具强烈感染力的重要组成部分&#xff0c;亦是对美的体现。在现代社会&#xff0c;物欲横流&a…

用matlab实现交通分布预测方法——增长系数法

前言 对于我的第一篇文章&#xff1a;Matlab实现交通分布预测方法 —— 增长系数法 | 平均增长率法、底特律法、福莱特法&#xff0c;有不少同学私信我询问关于如何在 matlab 中调用函数、拆分代码以及需要大量调用的问题。于是我便想着把内容做一些优化&#xff0c;放在这篇文…

【MySQL】MySQL从0到0.9 - 持续更新ing

MySQL SQL 基础 DDL 语句 show databases; show tables; desc table_name; # 获取表信息 show create table 表名; // 查询指定表的建表语句 数据类型 char(10) 不满10个会用空格填充&#xff0c;性能好一点 varchar(10) 变长字符串&#xff0c;性能差一点 CREATE TABLE tabl…

nginx设置缓存时间

一、设置缓存时间 当网页数据返回给客户端后&#xff0c;可针对静态网页设置缓存时间&#xff0c;在配置文件内的http段内server段添加location&#xff0c;更改字段expires 1d来实现&#xff1a;避免重复请求&#xff0c;加快访问速度 第一步&#xff1a;修改主配置文件 #修…

[NOIP2011 普及组] 数字反转

AC代码&#xff1a; #include<iostream>using namespace std;int main() {long long n;cin >> n;long long temp n;long long sum 0;while(temp ! 0){int c temp % 10;sum sum * 10 c;temp temp / 10;}printf("%lld",sum);return 0; }

Spring 中 ApplicationContext 和 BeanFactory 的区别有哪些

先看一张类图&#xff1a; 区别&#xff1a; 1&#xff1a;包目录不同&#xff1a; spring-beans.jar 中 org.springframework.beans.factory.BeanFactory spring-context.jar 中 org.springframework.context.ApplicationContext 2&#xff1a;国际化&#xff1a; BeanFacto…

APP被针对攻击了,要怎么解决

随着APP行业的兴起&#xff0c;游戏公司异军突起&#xff0c;不管是在控证还是攻击方面都是属于最复杂的一个场面&#xff0c;游戏APP逐渐成为DDOS流量攻击的“重灾区”。没有提前做好了解就盲目进军游戏APP行业&#xff0c;一旦被攻击就会让公司束手无策。那么&#xff0c;刚上…