pytorch常用的模块函数汇总(1)

目录

torch:核心库,包含张量操作、数学函数等基本功能

torch.nn:神经网络模块,包括各种层、损失函数和优化器等

torch.optim:优化算法模块,提供了各种优化器,如随机梯度下降 (SGD)、Adam、RMSprop 等。

torch.autograd:自动求导模块,用于计算张量的梯度


torch:核心库,包含张量操作、数学函数等基本功能

  1. torch.tensor(): 创建张量
  2. torch.zeros()torch.ones(): 创建全零或全一张量
  3. torch.rand(): 创建随机张量
  4. torch.from_numpy(): 从 NumPy 数组创建张量
  5. torch.add()torch.sub()torch.mul()torch.div(): 加法、减法、乘法、除法

  6. torch.mm()torch.matmul(): 矩阵乘法

  7. torch.exp()torch.log()torch.sin()torch.cos(): 指数、对数、正弦、余弦等数学函数

  8. torch.index_select(): 按索引选取张量的子集

  9. torch.masked_select(): 根据掩码选取张量的子集

    切片操作:类似 Python 中的列表切片操作,如 tensor[2:5]
  10. torch.view()torch.reshape(): 改变张量的形状

  11. torch.squeeze()torch.unsqueeze(): 压缩或扩展张量的维度

  12. torch.mean()torch.sum()torch.max()torch.min(): 计算张量均值、和、最大值、最小值等

  13. torch.broadcast_tensors(): 对张量进行广播操作

  14. torch.cat(): 拼接张量

  15. torch.stack(): 堆叠张量

  16. torch.split(): 分割张量

torch.nn:神经网络模块,包括各种层、损失函数和优化器等

  • 神经网络层

    • torch.nn.Linear(in_features, out_features): 全连接层,进行线性变换。
    • torch.nn.Conv2d(in_channels, out_channels, kernel_size): 2D卷积层。
    • torch.nn.MaxPool2d(kernel_size): 2D 最大池化层。
    • torch.nn.ReLU(): ReLU 激活函数。
    • torch.nn.Sigmoid(): Sigmoid 激活函数。
    • torch.nn.Dropout(p): Dropout 层,用于防止过拟合。

备注:Sigmoid 激活函数是一种常用的非线性激活函数,其作用可以总结如下:

将输入映射到 (0, 1) 范围内:输出范围在 0 到 1 之间,可以将任意实数输入映射到 0 到 1 之间。这种特性在某些情况下很有用,比如对于二分类任务,Sigmoid 函数的输出可以被解释为样本属于正类的概率。

引入非线性变换: Sigmoid 函数是一种非线性函数,可以引入神经网络的非线性变换能力,使得神经网络可以学习更加复杂的模式和关系。在深度神经网络中,非线性激活函数的使用可以帮助神经网络学习非线性模式,提高网络的表达能力。

输出平滑且连续: Sigmoid 函数具有平滑的 S 形曲线,在定义域内都是可导的,这使得在反向传播算法中计算梯度变得相对容易。这一点对于神经网络的训练至关重要。

  • 损失函数

torch.nn.CrossEntropyLoss(): 交叉熵损失函数,常用于多分类问题。

交叉熵损失函数用于衡量两个概率分布之间的差异,通常用于多分类任务中。在神经网络的多分类任务中,输入模型的输出是一个概率分布,表示每个类别的预测概率,而交叉熵损失函数则用于比较这个预测概率分布与实际标签的分布之间的差异。

torch.nn.CrossEntropyLoss() 来计算交叉熵损失函数,它会自动将模型的输出通过 Softmax 函数转换为概率分布,并计算交叉熵损失。

torch.nn.MSELoss(): 均方误差损失函数,常用于回归问题。

均方误差损失函数用于衡量模型输出与实际目标之间的差异,通常在回归任务中使用。该损失函数计算预测值与真实值之间的平方差,并将所有样本的平方差求平均得到最终的损失值。

  • 优化器

    • torch.optim.SGD(model.parameters(), lr=learning_rate): 随机梯度下降优化器。
    • torch.optim.Adam(model.parameters(), lr=learning_rate): Adam 优化器。
  • 模型定义相关

    • torch.nn.Module: 所有神经网络模型的基类,需要继承这个类。
    • model.forward(input_tensor): 定义前向传播。
  • 数据处理相关

    • torch.utils.data.Dataset: PyTorch 数据集的基类,需要自定义数据集时使用。
    • torch.utils.data.DataLoader(dataset, batch_size, shuffle): 数据加载器,用于批量加载数据。
  • torch.optim:优化算法模块,提供了各种优化器,如随机梯度下降 (SGD)、Adam、RMSprop 等。

  • 优化器(Optimizer)类

    • torch.optim.SGD(params, lr=0.01, momentum=0, weight_decay=0):随机梯度下降优化器,实现了带动量的随机梯度下降
    • torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):Adam 优化器,结合了动量方法和 RMSProp 方法,通常在深度学习中表现良好。
    • torch.optim.Adagrad(params, lr=0.01, lr_decay=0, weight_decay=0):Adagrad 优化器,自适应地为参数分配学习率。它根据参数的历史梯度信息对每个参数的学习率进行调整。这意味着对于不同的参数,Adagrad可以为其分配不同的学习率,从而更好地适应参数的更新需求
    • torch.optim.RMSprop(params, lr=0.01, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False):RMSprop 优化器,有效地解决了 Adagrad 学习率下降较快的问题(RMSprop对梯度平方项进行指数加权平均)。

备注:

1.在优化算法中,momentum(动量)是一种用于加速模型训练的技巧。动量项的引入旨在解决标准随机梯度下降在训练过程中可能遇到的震荡和收敛速度慢的问题。

动量项的引入可以帮助优化算法在参数更新时更好地利用之前的更新方向,从而在一定程度上减少参数更新的波动,加快收敛速度,并有助于跳出局部极小值。具体来说,动量项在参数更新时会考虑之前的更新方向,并对当前的更新方向进行一定程度的调整。

在 PyTorch 的 torch.optim.SGD 中,动量可以通过设置 momentum 参数来控制。通常情况下,动量的取值范围在 0 到 1 之间,常见的默认取值为 0.9。当动量设为0.9时,每次迭代,都会保留上一次速度的 90%,并使用当前梯度微调最终的更新方向。

总结来说,动量项的引入可以提高随机梯度下降的稳定性和收敛速度,有助于在训练神经网络时更快地找到较优的解。

2. 

在优化算法中,weight decay(权重衰减)是一种用于控制模型参数更新的正则化技术。权重衰减通过在优化过程中对参数进行惩罚,防止其取值过大,从而有助于降低过拟合的风险。

具体来说,在SGD中的weight_decay参数是对模型的权重进行L2正则化,即在计算梯度时额外增加一个关于参数的惩罚项。这个惩罚项会使得优化算法更倾向于选择较小的权重值,从而降低模型的复杂度,减少过拟合的风险。

在PyTorch中,torch.optim.SGD中的weight_decay参数用于控制权重衰减的程度。通常情况下,weight_decay的取值为一个小的正数,比如 0.001 或 0.0001。设置了weight_decay之后,在计算梯度时会额外考虑到权重的惩罚项,从而影响参数的更新方式。

总结来说,权重衰减是一种正则化技术,通过对模型参数的惩罚来控制模型的复杂度,减少过拟合的风险,提高模型的泛化能力。

3.

  • L1正则化:L1正则化会给模型的损失函数添加一个关于权重绝对值的惩罚项,即L1范数(权重的绝对值之和)。在梯度下降过程中,L1正则化会导致部分权重直接变为0,因此可以实现稀疏性,有特征选择的效果。L1正则化倾向于产生稀疏的权重矩阵,可以用于特征选择和降维。

L1正则化的惩罚项是模型权重的L1范数,即权重的绝对值之和。在优化过程中,为了最小化损失函数并减少正则化项的影响,优化算法会尝试将权重调整到较小的值。由于L1正则化的几何形状在坐标轴上拐角处就会与坐标轴相交,这就导致了在坐标轴上许多点都是对称的,因此在这些点上的梯度不唯一。这意味着在这些对称点上,优化算法更有可能将权重调整为0,从而导致稀疏性。

  • L2正则化:L2正则化会给模型的损失函数添加一个关于权重平方的惩罚项,即L2范数的平方(权重的平方和)。在梯度下降过程中,L2正则化会使得权重都变得比较小,但不会直接导致稀疏性。L2正则化对异常值比较敏感,因为它会平方每个权重,使得异常值对损失函数的影响更大。

  • 总的来说,L1正则化和L2正则化都是常用的正则化技术,它们在模型训练过程中都有助于控制模型的复杂度,减少过拟合的风险。选择使用哪种正则化方法通常取决于具体的问题和数据特点,以及对模型稀疏性的需求。在实际应用中,有时也会将L1和L2正则化结合起来,形成弹性网络正则化(Elastic Net regularization),以兼顾两种正则化方法的优势。

4. 

betas 是 Adam 算法中的两个超参数之一,它控制了梯度的一阶矩估计和二阶矩估计的指数衰减率。betas 是一个长度为2的元组,通常形式为 (beta1, beta2)。在 Adam 算法中,beta1 控制了一阶矩估计(梯度的均值)的衰减率,beta2 控制了二阶矩估计(梯度的平方的均值)的衰减率。

通常情况下,beta1 的默认值为 0.9,beta2 的默认值为 0.999。这意味着在每次迭代中,一阶矩估计将保留当前梯度的 90%,而二阶矩估计将保留当前梯度的平方的 99.9%。这些衰减率的选择使得 Adam 算法能够在训练过程中自适应地调整学习率,并对梯度的变化做出快速或缓慢的响应,从而更有效地更新模型参数。

总之,betas 参数在 Adam 算法中起着调节梯度一阶和二阶矩估计衰减率的作用,通过合理设置 betas 可以影响算法的收敛性和稳定性。

  • 调整学习率的函数

    • torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)根据给定的函数 lr_lambda 调整学习率。
    • torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1):每个 step_size 个 epoch 将学习率降低为原来的 gamma 倍。
    • torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1):在指定的里程碑上将学习率降低为原来的 gamma 倍。
  • 其他常用函数

    • zero_grad():用于将模型参数的梯度清零,通常在每个 batch 后调用。
    • step(closure):用于执行单步优化器的更新,需要传入一个闭包函数 closure
    • state_dict() 和 load_state_dict():用于保存和加载优化器的状态字典,方便恢复训练。
  • torch.autograd:自动求导模块,用于计算张量的梯度

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/493924.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Maven的pom.xml中resources标签的用法

spring-boot-starter-parent-2.4.1.pom文件中resources标签内容如下&#xff1a; <build><resources><resource><directory>${basedir}/src/main/resources</directory><filtering>true</filtering><includes><include>…

考研数学|张宇《1000题》太难了,根本刷不动?怎么破!

即使一直在看张宇的课程&#xff0c;但在做1000题时仍然感到困难。这其实是许多考生会出现的问题&#xff0c;所以不用担心&#xff0c;希望看完这篇文章能对你有帮助。 首先是理论与实践的差距。听课时&#xff0c;你可能是在接受知识&#xff0c;而做题则需要将这些知识应用…

鸿蒙开发之ArkUI组件常用组件文本输入

TextInput、TextArea是输入框组件&#xff0c;通常用于响应用户的输入操作&#xff0c;比如评论区的输入、聊天框的输入、表格的输入等&#xff0c;也可以结合其它组件构建功能页面&#xff0c;例如登录注册页面。 TextInput为单行输入框、TextArea为多行输入框 TextArea 多行…

阿里云服务器租用价格表(最新CPU/内存/带宽/磁盘收费标准)

阿里云服务器一个月多少钱&#xff1f;最便宜5元1个月。阿里云轻量应用服务器2核2G3M配置61元一年&#xff0c;折合5元一个月&#xff0c;2核4G服务器30元3个月&#xff0c;2核2G3M带宽服务器99元12个月&#xff0c;轻量应用服务器2核4G4M带宽165元12个月&#xff0c;4核16G服务…

如何用Python操作xlsx文件并绘制折线图!

​大家好&#xff0c;数据分析在现代社会越来越重要&#xff0c;而Excel作为数据分析的利器&#xff0c;几乎人手一份。但是&#xff0c;Excel的操作有时候略显繁琐&#xff0c;更是感觉无从下手。 你知道吗&#xff1f;Python这个神奇的工具不仅能帮你处理海量的数据&#xf…

【GPU系列】选择最适合的 CUDA 版本以提高系统性能

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

c语言中动态内存管理

说到内存&#xff0c;大家一定都知道。但是有一种函数可以实现动态内存管理&#xff0c;下面大家一起学习。 文章目录 一、为什么要有动态内存管理&#xff1f;二、malloc 和 free1.malloc2.free 三、calloc 和 realloc1.calloc2.realloc3.常见的动态内存的错误3.1对NULL指针的…

【SpringBoot框架篇】37.使用gRPC实现远程服务调用

文章目录 RPC简介gPRC简介protobuf1.文件编写规范2.字段类型3.定义服务(Services) 在Spring Boot中使用grpc1.父工程pom配置2.grpc-api模块2.1.pom配置2.2.proto文件编写2.3.把proto文件编译成class文件 3.grpc-server模块3.1.pom文件和application.yaml3.2.实现grpc-api模块的…

Linux——信号概念与信号产生方式

目录 一、概念 二、前台进程与后台进程 1.ctrlc 2.ctrlz 三、信号的产生方式 1.键盘输入产生信号 2.系统调用发送信号 2.1 kill()函数 2.2 raise()函数 2.3 abort()函数 3.异常导致信号产生 3.1 除0异常 3.2 段错误异常 4.软件条件产生信号 4.1 管道 4.2 闹钟…

最新可用免费VPS云服务器整理汇总

随着云计算技术的不断发展&#xff0c;越来越多的个人和企业开始关注和使用VPS云服务器。VPS云服务器以其高度的灵活性、可定制性和安全性&#xff0c;成为了一种受欢迎的服务器解决方案。然而&#xff0c;对于初学者或者预算有限的用户来说&#xff0c;如何选择合适的免费VPS云…

ZYNQ学习之Ubuntu系统的简单设置与文本编辑

基本都是摘抄正点原子的文章&#xff1a;<领航者 ZYNQ 之嵌入式Linux 开发指南 V3.2.pdf&#xff0c;因初次学习&#xff0c;仅作学习摘录之用&#xff0c;有不懂之处后续会继续更新~ 一、Ubuntu的简单操作 1.1 切换拼音输入法 Ubuntu 自带的拼音输入法&#xff0c;有两种…

ADAS多传感器后融合算法解析-下篇

ADAS多传感器后融合算法解析-下篇 在ADAS多传感器后融合(上)中我们介绍了后融合的接口、策略。本文将主要介绍后融合的实现流程、难点及注意事项。 附赠自动驾驶学习资料和量产经验&#xff1a;链接 二、后融合处理流程 如下图为基本RC后融合系统流程图&#xff0c;接下来将…

day 36 贪心算法 part05● 435. 无重叠区间 ● 763.划分字母区间 ● 56. 合并区间

一遍过。首先把区间按左端点排序&#xff0c;然后右端点有两种情况。 假设是a区间&#xff0c;b区间。。。这样排列的顺序&#xff0c;那么 假设a[1]>b[0],如果a[1]>b[1]&#xff0c;就应该以b[1]为准&#xff0c;否则以a[1]为准。 class Solution { public:static bo…

argocd部署

一、前言 ArgoCD 是一个开源的、持续交付工具&#xff0c;用于自动化部署应用程序到 Kubernetes 集群。它基于 GitOps 理念&#xff0c;通过使用 Git 作为单一的源头来管理应用程序的配置和部署状态&#xff0c;argocd会定时监控git仓库中的yaml配置文件&#xff0c;当git仓库中…

验证码/数组元素的复制.java

1&#xff0c;验证码 题目&#xff1a;定义方法实现随机产生一个5位的验证码&#xff0c;前面四位是大写或小写的英文字母&#xff0c;最后一位是数字 分析&#xff1a;定义一个包含所有大小写字母的数组&#xff0c;然后对数组随机抽取4个索引&#xff0c;将索引对应的字符拼…

JSON Web Token 入门教程

JSON Web Token&#xff08;JWT&#xff09;是一种可以在多方之间安全共享数据的开放标准&#xff0c;JWT 数据经过编码和数字签名生成&#xff0c;可以确保其真实性&#xff0c;也因此 JWT 通常用于身份认证。这篇文章会介绍什么是 JWT&#xff0c;JWT 的应用场景以及组成结构…

46秒AI生成真人视频爆火,遭在线打假「换口型、声音」

ChatGPT狂飙160天&#xff0c;世界已经不是之前的样子。 新建了人工智能中文站https://ai.weoknow.com 每天给大家更新可用的国内可用chatGPT资源 发布在https://it.weoknow.com 更多资源欢迎关注 是炒作还是真正的 AI 视频能力进化&#xff1f; AI 生成视频已经发展到这个程…

rabbitmq集群问题排查

blowcode-test-redis04、blowcode-test-redis05、blowcode-test-redis06 这3个节点搭建的rabbitmq集群&#xff0c;04是主节点。 某次分别观察3个节点的管理页面&#xff0c;先都只能看到自己的节点是正常的绿色状态&#xff0c;猜测节点都各自为政了。 下图是05节点成功加入0…

iOS_convert point or rect 坐标和布局转换+判断

文章目录 1. 坐标转换2. 布局转换3. 包含、相交 如&#xff1a;有3个色块 let view1 UIView(frame: CGRect(x: 100.0, y: 100.0, width: 300.0, height: 300.0)) view1.backgroundColor UIColor.cyan self.view.addSubview(view1)let view2 UIView(frame: CGRect(x: 50.0, …

Redis面试题-缓存穿透,缓存击穿,缓存雪崩

1、穿透: 两边都不存在&#xff08;皇帝的新装&#xff09; &#xff08;黑名单&#xff09; &#xff08;布隆过滤器&#xff09; 解释&#xff1a;请求的数据既不在Redis中也不在数据库中&#xff0c;这时我们创建一个黑名单来存储该数据&#xff0c;下次再有类似的请求进来…