在深度学习中常见的初始化操作

目录

截断正态分布来初始化张量

逐行代码解释

相关理论解释

截断正态分布函数

截断正态分布的定义

截断正态分布的作用

计算截断点的作用

具体步骤

正态分布的累积分布函数(CDF)

 正态分布的累积分布函数与误差函数的关系

示例计算

误差函数

应用:

定义:

误差函数的性质

Python 中的误差函数

总结


截断正态分布来初始化张量

import math
import warnings
import torch

def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    def norm_cdf(x):
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)
    with torch.no_grad():
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)
        tensor.uniform_(2 * l - 1, 2 * u - 1)
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        tensor.clamp_(min=a, max=b)
        return tensor

逐行代码解释

1、正态分布的累积分布函数(CDF)norm_cdf 函数计算标准正态分布的累积分布函数。

def norm_cdf(x):
    return (1. + math.erf(x / math.sqrt(2.))) / 2.

2、警告:检查均值是否在截断边界 [a, b] 的2个标准差范围内,如果不在,则发出警告。

if (mean < a - 2 * std) or (mean > b + 2 * std):
    warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                  "The distribution of values may be incorrect.",
                  stacklevel=2)

3、不跟踪梯度:以下代码块确保初始化时不跟踪梯度,这对于设置神经网络的初始权重很有用。

with torch.no_grad():
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)
    tensor.uniform_(2 * l - 1, 2 * u - 1)
    tensor.erfinv_()
    tensor.mul_(std * math.sqrt(2.))
    tensor.add_(mean)
    tensor.clamp_(min=a, max=b)
    return tensor
  • lu 是截断点 ab 处的累积分布函数值。
  • tensor.uniform_(2 * l - 1, 2 * u - 1) 用从指定范围的均匀分布生成的值初始化张量。
  • tensor.erfinv_() 对张量应用误差函数的逆函数。
  • tensor.mul_(std * math.sqrt(2.)) 将张量的值缩放到期望的标准差。
  • tensor.add_(mean) 将张量的值平移到期望的均值。
  • tensor.clamp_(min=a, max=b) 确保张量中的所有值都在指定的截断范围 [a, b] 之内。

相关理论解释

截断正态分布函数

截断正态分布的定义

        给定一个均值为 μ、标准差为 σ 的正态分布 N(μ,σ2),截断正态分布在区间 [a,b] 上的定义如下:

         其中,ϕ(x) 是正态分布的概率密度函数(PDF),Φ(x)是正态分布的累积分布函数(CDF)。

截断正态分布的作用
  1. 限制范围:确保生成的随机变量值在某个指定范围内,这对于物理约束或特定应用场景非常重要。
  2. 防止异常值:避免生成不合实际或有害的极端值,例如在神经网络权重初始化时防止极端值导致的训练不稳定。
计算截断点的作用

在实现截断正态分布时,我们需要计算截断点 ab 对应的累积分布函数值 l 和 u,以便生成满足截断条件的随机数。

具体步骤
  1. 标准化:将截断点 ab 标准化为标准正态分布中的值。

  2. 计算标准正态分布的 CDF:计算标准正态分布在标准化后的截断点 lu 处的累积分布函数值。注意:此处有一个性质,就是随机变量Φ(l)和Φ(u)是满足[0,1]的均匀分布。

  3. 转换为均匀分布:生成的均匀分布随机数在 [2Φ(l)−1,2Φ(u)−1] 区间内。

  4. 逆误差函数:将均匀分布的值通过逆误差函数转换为标准正态分布的值。

    tensor.erfinv()

  5. 缩放和平移:将标准正态分布的值缩放到所需的标准差,并平移到所需的均值。

  6. 截断:确保所有值都在 [a,b] 区间内。

正态分布的累积分布函数(CDF)

定义:用于计算正态分布从负无穷大到给定值 x的概率。具体而言,对于标准正态分布 N(0,1),CDF 表示为:

 正态分布的累积分布函数与误差函数的关系

在代码中,我们通过误差函数(erf)来计算标准正态分布的 CDF。误差函数与标准正态分布的 CDF 之间有如下关系:

 代码中的 norm_cdf 函数:

def norm_cdf(x):
    return (1. + math.erf(x / math.sqrt(2.))) / 2.

norm_cdf 函数的实现如下:

  1. 输入:函数接收一个参数 x,它是需要计算 CDF 的点。
  2. 计算误差函数math.erf(x / math.sqrt(2.)) 计算 \frac{x}{\sqrt{2}} 的误差函数值。
  3. 调整误差函数值:将误差函数的结果加 1,然后除以 2,得到标准正态分布在 x 点的 CDF 值。

以下是函数的具体步骤:

  1. math.erf(x / math.sqrt(2.))计算误差函数
  2. 1. + math.erf(x / math.sqrt(2.)):将误差函数的结果加 1。
  3. (1. + math.erf(x / math.sqrt(2.))) / 2.:结果除以 2 得到最终的 CDF 值。
示例计算

假设我们需要计算标准正态分布在 x=1处的 CDF 值:

import math

def norm_cdf(x):
    return (1. + math.erf(x / math.sqrt(2.))) / 2.

x = 1
cdf_value = norm_cdf(x)
print("CDF value at x = 1:", cdf_value)

 运行以上代码,会输出 x=1处的 CDF 值,即:

CDF value at x = 1: 0.8413447460685429

这意味着在标准正态分布中,小于等于 1 的值的概率大约为 0.8413。

误差函数

应用:

        数学上用于处理正态分布和概率问题的重要函数。误差函数用于计算某个值在标准正态分布中的概率,并且在统计学、概率论和许多应用数学领域中都有广泛应用。

定义:

         这个积分没有解析解,因此通常通过数值方法进行计算。

误差函数的性质
  • 对称性:误差函数是奇函数,即erf⁡(−x)=−erf⁡(x) 。
  • 值域:误差函数的值域在 −1 到 1 之间,即 −1≤erf⁡(x)≤1。
  • 边界值:当 x→∞ 时,erf⁡(x)→1;当 x→−∞时,erf⁡(−x)→−1。
Python 中的误差函数

在 Python 中,可以使用 math 模块中的 erf 函数来计算误差函数值。以下是一个示例:

import math

x = 1.0
erf_value = math.erf(x)
print("erf(1.0) =", erf_value)

运行结果是:

erf(1.0) = 0.8427007929497149

这意味着当x=1.0 时,erf(1.0)的值大约为 0.8427。

总结

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/637150.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

切换分支报错:Untracked Files Prevent Checkout

切换分支报错&#xff1a;Untracked Files Prevent Checkout 分支切换 Untracked Files Prevent Checkout 新起的项目在切换master分支到工作分支时&#xff0c;出现下图的问题&#xff1a; Untracked Files Prevent Checkout Move or commit them before checkout 网上的解决…

pip(包管理器) for Python

pip是什么 pip是Python的包安装程序&#xff0c;即python包管理器。您可以使用 pip 从Python包索引和其他索引安装包。 1. pip 安装 python 包 pip install 包名 例如&#xff1a;pip install pymssql &#xff1a; 使用pip安装数据库驱动包 pymssql 2.pip 卸载 python 包 pi…

数据结构2(初):顺序表和链表

目录 1、线性表 2、顺序表 2.1、概念及结构 2.2、顺序表的实现 2.3、顺序表的问题及思考 3、链表 3.1、链表的概念及结构 3.2、链表的分类 3.3、无头单向非循环链表的实现 3.4、带头双向循环链表的实现 4、顺序表和链表的区别和联系 1、线性表 线性表是n个具有相同特…

200+有趣的HTML前端游戏项目合集(5月17日更新,持续更新中)

&#x1f482; 个人网站:【 摸鱼游戏】【神级代码资源网站】【工具大全】&#x1f91f; 一站式轻松构建小程序、Web网站、移动应用&#xff1a;&#x1f449;注册地址&#x1f91f; 基于Web端打造的&#xff1a;&#x1f449;轻量化工具创作平台&#x1f485; 想寻找共同学习交…

工作干到抑郁了,要不要辞职?

在知乎上看到以为网友提问&#xff1a;工作干到抑郁&#xff0c;该不该辞职&#xff1f; 今天和大家聊聊这个话题&#xff0c;如果你也有类似的情况&#xff0c;希望这篇文章能帮到你。 熟悉瑶琴的朋友&#xff0c;都知道瑶琴在去年有一次裸辞的经历。离职前&#xff0c;严重的…

多台Centos快速区分,让Centos开机自动显示它的IP地址!

背景说明&#xff1a;当公司拥有多台Centos服务器&#xff0c;管理员很容易弄混淆导致不好区分&#xff0c;在这样的情况下我们可以写个简单脚本来实现开机自动显示它的IP地址&#xff0c;从而达到区分开来的结果&#xff01; 首先我们来开下效果&#xff0c;登录之前的 下面是…

【加密与解密(第四版)】第十八章笔记

第十八章 反跟踪技术 18.1 由BeginDebugged引发的蝴蝶效应 IsDebuggerPresent()函数读取当前进程PEB中的BeginDebugged标志 CheckRemoteDebuggerPresent() 反调试总结&#xff1a;https://bbs.kanxue.com/thread-225740.htm https://www.freebuf.com/articles/others-articl…

细胞冻存——让你的细胞“长生不老”

《星际穿越》电影中提到漫长的太空旅程中&#xff0c;宇航员可以进入休眠水床休眠&#xff0c;并自行设定唤醒时间。在《异形》《深空失忆》《三体》等科幻作品中&#xff0c;都出现此类技术。《三体》中&#xff0c;休眠后来成为人类最普遍的一项技术。技术上的人类低温休眠&a…

JavaEE-网络初识

文章目录 一、网络背景1.1 起源1.2 国内网络的发展 二、关键概念2.1 网络2.2 设备2.3 ip地址与端口号 三、协议3.1 协议分层3.2 OSI七层模型3.3 TCP/IP五层模型3.4 数据传输过程的简单叙述 一、网络背景 1.1 起源 在国外大概时上世纪70年代左右&#xff0c;网络就出现了&…

项目集成SkyWalking,基于k8s搭建

一、搭建SkyWalking 官方文档&#xff08;英文&#xff09;&#xff1a;skywalking/docs at master apache/skywalking 中文可以使用&#xff1a;GitHub - SkyAPM/document-cn-translation-of-skywalking: [已过期,请使用官网AI文档] The CN translation version of Apache…

【LeetCode:496. 下一个更大元素 I + 单调栈】

&#x1f680; 算法题 &#x1f680; &#x1f332; 算法刷题专栏 | 面试必备算法 | 面试高频算法 &#x1f340; &#x1f332; 越难的东西,越要努力坚持&#xff0c;因为它具有很高的价值&#xff0c;算法就是这样✨ &#x1f332; 作者简介&#xff1a;硕风和炜&#xff0c;…

深度学习——图像分类(CNN)—训练模型

训练模型 1.导入必要的库2.定义超参数3.读取训练和测试标签CSV文件4.确保标签是字符串类型5.显示两个数据框的前几行以了解它们的结构6.定义图像处理参数7.创建图像数据生成器8.设置目录路径9.创建训练和验证数据生成器10.构建模型11.编译模型12.训练模型并收集历史13.绘制损失…

【AD21】PCB板尺寸与层名称标注

PCB绘制完成后&#xff0c;需要给上级或生产制造商发送输出文件&#xff0c;输出文件中包含板尺寸标识和层标识可以方便工作的交接。 1. 板尺寸标识 首先板尺寸标识所在的层要在与板框不同的机械层&#xff0c;这里我选择机械5层。 点击放置->尺寸->线性尺寸 这里板尺…

微信小程序uniapp+django洗脚按摩足浴城消费系统springboot

原生wxml开发对Node、预编译器、webpack支持不好&#xff0c;影响开发效率和工程构建。所以都会用uniapp框架开发 前后端分离&#xff0c;后端给接口和API文档&#xff0c;注重前端,接近原生系统 使用Navicat或者其它工具&#xff0c;在mysql中创建对应名称的数据库&#xff0…

利用大模型构造数据集,并微调大模型

一、前言 目前大模型的微调方法有很多&#xff0c;而且大多可以在消费级显卡上进行&#xff0c;每个人都可以在自己的电脑上微调自己的大模型。 但是在微调时我们时常面对一个问题&#xff0c;就是数据集问题。网络上有许多开源数据集&#xff0c;但是很多时候我们并不想用这…

Gerchberg-Saxton (GS) 和混合输入输出(Hybrid Input-Output, HIO)算法

文章目录 1. 简介2. 算法描述3. 混合输入输出&#xff08;Hybrid Input-Output, HIO&#xff09;算法3.1 HIO算法步骤3.2 HIO算法的优势3.3 算法描述 4. 算法实现与对比5. 总结参考文献 1. 简介 Gerchberg-Saxton (GS) 算法是一种常用于相位恢复和光学成像的迭代算法。该算法最…

【抽代复习笔记】18-置换练习题(2)及两个重要定理

最近一直忙于学校的事情&#xff0c;好久没更新了&#xff0c;实在抱歉。接下来几期大概也会更得慢一些&#xff0c;望见谅。 练习4&#xff1a;写出4次对称群S4中所有置换。 解&#xff1a;由上一篇笔记结尾的定理我们知道&#xff0c;4次对称群的阶&#xff08;也就是所含元…

JSON的序列化与反序列化以及VSCode执行Run Code 报错

JSON JSON: JavaScript Object Notation JS对象简谱 , 是一种轻量级的数据交换格式。 JSON格式 { "name":"金苹果", "info":"种苹果" } 一个对象&#xff1a;由一个大括号表示.括号中通过键值对来描述对象的属性 (可以理解为, 大…

2024年 电工杯 (A题)大学生数学建模挑战赛 | 园区微电网风光储协调优化配置 | 数学建模完整代码解析

DeepVisionary 每日深度学习前沿科技推送&顶会论文&数学建模与科技信息前沿资讯分享&#xff0c;与你一起了解前沿科技知识&#xff01; 本次DeepVisionary带来的是电工杯的详细解读&#xff1a; 完整内容可以在文章末尾全文免费领取&阅读&#xff01; 问题重述…

MVS net笔记和理解

文章目录 传统的方法有什么缺陷吗&#xff1f;MVSnet深度的预估 传统的方法有什么缺陷吗&#xff1f; 传统的mvs算法它对图像的光照要求相对较高&#xff0c;但是在实际中要保证照片的光照效果很好是很难的。所以传统算法对镜面反射&#xff0c;白墙这种的重建效果就比较差。 …