动手学深度学习-线性神经网络-7softmax回归的简洁实现

目录

 初始化模型参数

重新审视Softmax的实现 

优化算法 

训练 

小结


在 线性回归的实现中, 我们发现通过深度学习框架的高级API能够使实现

线性回归变得更加容易。 同样,通过深度学习框架的高级API也能更方便地实现softmax回归模型。 本节如在上一节中一样, 继续使用Fashion-MNIST数据集,并保持批量大小为256。

import torch
from torch import nn
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

 初始化模型参数

如我们在 softmax回归所述, softmax回归的输出层是一个全连接层。 因此,为了实现我们的模型, 我们只需在Sequential中添加一个带有10个输出的全连接层。 同样,在这里Sequential并不是必要的, 但它是实现深度模型的基础。 我们仍然以均值0和标准差0.01随机初始化权重。

# PyTorch不会隐式地调整输入的形状。因此,
# 我们在线性层前定义了展平层(flatten),来调整网络输入的形状
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);

重新审视Softmax的实现 

在前面 一节的例子中, 我们计算了模型的输出,然后将此输出送入交叉熵损失。 从数学上讲,这是一件完全合理的事情。 然而,从计算角度来看,指数可能会造成数值稳定性问题。

回想一下,softmax函数y^j=exp⁡(oj)∑kexp⁡(ok), 其中y^j是预测的概率分布。 oj是未规范化的预测o的第j个元素。 如果ok中的一些数值非常大, 那么exp⁡(ok)可能大于数据类型容许的最大数字,即上溢(overflow)。 这将使分母或分子变为inf(无穷大), 最后得到的是0、infnan(不是数字)的y^j。 在这些情况下,我们无法得到一个明确定义的交叉熵值。

解决这个问题的一个技巧是: 在继续softmax计算之前,先从所有ok中减去max(ok)。 这里可以看到每个ok按常数进行的移动不会改变softmax的返回值:

 

在减法和规范化步骤之后,可能有些oj−max(ok)具有较大的负值。 由于精度受限,exp⁡(oj−max(ok))将有接近零的值,即下溢(underflow)。 这些值可能会四舍五入为零,使y^j为零, 并且使得log⁡(y^j)的值为-inf。 反向传播几步后,我们可能会发现自己面对一屏幕可怕的nan结果。

尽管我们要计算指数函数,但我们最终在计算交叉熵损失时会取它们的对数。 通过将softmax和交叉熵结合在一起,可以避免反向传播过程中可能会困扰我们的数值稳定性问题。 如下面的等式所示,我们避免计算exp⁡(oj−max(ok)), 而可以直接使用oj−max(ok),因为log⁡(exp⁡(⋅))被抵消了。

我们也希望保留传统的softmax函数,以备我们需要评估通过模型输出的概率。 但是,我们没有将softmax概率传递到损失函数中, 而是在交叉熵损失函数中传递未规范化的预测,并同时计算softmax及其对数, 这是一种类似“LogSumExp技巧”的聪明方式。 

loss = nn.CrossEntropyLoss(reduction='none')

优化算法 

在这里,我们使用学习率为0.1的小批量随机梯度下降作为优化算法。 这与我们在线性回归例子中的相同,这说明了优化器的普适性。 

trainer = torch.optim.SGD(net.parameters(), lr=0.1)

训练 

接下来我们调用上一节中 定义的训练函数来训练模型。 

num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

 

和以前一样,这个算法使结果收敛到一个相当高的精度,而且这次的代码比之前更精简了。 

小结

  • 使用深度学习框架的高级API,我们可以更简洁地实现softmax回归。

  • 从计算的角度来看,实现softmax回归比较复杂。在许多情况下,深度学习框架在这些著名的技巧之外采取了额外的预防措施,来确保数值的稳定性。这使我们避免了在实践中从零开始编写模型时可能遇到的陷阱。

 

 

 

 

 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/936899.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

人工智能原理实验四:智能算法与机器学习

一、实验目的 本实验课程是计算机、智能、物联网等专业学生的一门专业课程,通过实验,帮助学生更好地掌握人工智能相关概念、技术、原理、应用等;通过实验提高学生编写实验报告、总结实验结果的能力;使学生对智能程序、智能算法等…

【新界面】基于卷积神经网络的垃圾分类(Matlab)

基于CNN的垃圾识别与分类GUI【新界面】 有需要可直接联系我,基本都在在线,能秒回!可加我看演示视频,不懂可以远程教学 1.此项目设计包括两份完整的源代码,有GUI界面的代码和无GUI界面系统的代码。 (以下部…

网站访问的基础-HTTP超文本传输协议

BS架构 浏览器Browser⬅➡服务器Server 浏览器和服务器之间通过 IP 地址进行通信,实现数据的请求和传输。 例如,当用户在浏览器中访问一个网站时,浏览器会根据用户输入的网址(通过 DNS 解析得到服务器 IP 地址)向服…

【C++】递归填充矩阵的理论解析与实现

博客主页: [小ᶻ☡꙳ᵃⁱᵍᶜ꙳] 本文专栏: C 文章目录 💯前言💯问题描述💯递归实现💯参数解析函数参数详解填充顺序分析递归终止条件 💯示例解析第一层递归第二层递归第三层递归最终输出 &#x1f4af…

Git 仓库托管教程

git远程仓库 常用的远程仓库-->托管服务:github、码云、gitlab等 github需要魔法上网,速度较慢因为在国外且仅仅支持Git,如果不是Git项目是不支持的;码云--gitee国内的代码托管平台,服务器在国内速度快一些&#…

[创业之路-190]:《华为战略管理法-DSTE实战体系》-2-华为DSTE战略管理体系概要

目录 一、DSTE战略管理体系与BLM的关系 1、DSTE战略管理体系概述 2、BLM模型概述 3、DSTE与BLM的关系 二、重新认识流程 1. 流程就是业务本身,流程是业务过程的可视化: 2. 流程是业务最佳路径的经验教训总结: 3. 流程是战略知识资产、…

多智能体架构 Insight-V:针对长链视觉推理瓶颈

多智能体架构 Insight-V:针对长链视觉推理瓶颈 https://arxiv.org/abs/2411.14432 推理智能体与总结智能体协作完成任务,实现复杂视觉任务中的高效推理与总结。其中写了一小段,用迭代 DPO 算法,在每一轮训练中,模型会…

ASP.NET |日常开发中连接Oracle数据库详解

ASP.NET |日常开发中连接Oracle数据库详解 前言一、安装和配置 Oracle 数据访问组件1.1 安装ODP.NET(Oracle Data Provider for.NET):1.2 引用相关程序集: 二、配置连接字符串2.1 连接字符串的基本组成部分&#xff1a…

生成树协议STP工作步骤

第一步:选择根桥 优先级比较:首先比较优先级,优先级值越小的是根桥MAC地址比较:如果优先级相同,则比较MAC地址。MAC地址小的是根桥。 MAC地址比较的时候从左往右,一位一位去比 第二步:所有非根…

Redis是什么?Redis和MongoDB的区别在那里?

Redis介绍 Redis(Remote Dictionary Server)是一个开源的、基于内存的数据结构存储系统,它可以用作数据库、缓存和消息中间件。以下是关于Redis的详细介绍: 一、数据结构支持 字符串(String) 这是Redis最…

minio 分布式文件管理

一、minio 是什么? MinIO构建分布式文件系统,MinIO 是一个非常轻量的服务,可以很简单的和其他应用的结合使用,它兼容亚马逊 S3 云存储服务接口,非常适合于存储大容量非结构化的数据,例如图片、视频、日志文件、备份数…

【射频IC学习笔记】4 D类功率放大器PA电路设计/loadpull仿真/输出功率及效率PAE计算

一、功率放大器设计指标及电路结构 1. 设计指标 功率放大器的指标要求如下图所示采用D类的开关类型功率放大器,理论上开关类型的PA能够做到100%的效率,但实际上会有一些偏差。像D类功放并不适合高功率射频信号的输出,因为其在射频功率上面的…

【数据结构——查找】二叉排序树(头歌实践教学平台习题)【合集】

目录😋 任务描述 相关知识 测试说明 我的通关代码: 测试结果: 任务描述 本关任务:实现二叉排序树的基本算法。 相关知识 为了完成本关任务,你需要掌握:二叉树的创建、查找和删除算法。具体如下: (1)由…

Unity UGUI图片循环列表插件

效果展示: 下载链接:https://gf.bilibili.com/item/detail/1111843026 概述: LoopListView2 是一个与 UGUI ScrollRect 相同的游戏对象的组件。它可以帮助 UGUI ScrollRect 以高效率和节省内存的方式支持任意数量的项目。 对于具有10,000个…

5G学习笔记之SNPN系列之ID和广播消息

目录 1. 概述 2. SNPN ID 3. SNPN广播消息 1. 概述 SNPN:Stand-alone Non-Public Network,独立的非公共网络,由NPN独立运营,不依赖与PLMN网络。 SNPN不支持的5GS特性: 与EPS交互 emergency services when the UE acce…

(后序遍历 简单)leetcode 101翻转二叉树

将根结点的左右结点看作 两个树的根结点,后序遍历(从叶子结点从下往上遍历) 两个树边遍历边比较。 左节点就左右根的后序遍历 右根结点就右左根的后序遍历来写 后序遍历(从叶子结点从下往上遍历) /*** Definition …

通过ajax的jsonp方式实现跨域访问,并处理响应

一、场景描述 现有一个项目A,需要请求项目B的某个接口,并根据B接口响应结果A处理后续逻辑。 二、具体实现 1、前端 前端项目A发送请求,这里通过jsonp的方式实现跨域访问。 $.ajax({ url:http://10.10.2.256:8280/ssoCheck, //请求的u…

Goby AI 2.0 自动化编写 EXP | Mitel MiCollab 企业协作平台 npm-pwg 任意文件读取漏洞(CVE-2024-41713)

漏洞名称:Mitel MiCollab 企业协作平台 npm-pwg 任意文件读取漏洞(CVE-2024-41713) English Name:Mitel MiCollab /npm-pwg File Read Vulnerability (CVE-2024-41713) CVSS core: 6.8 漏洞描述: Mitel MiCollab 是加拿大 Mitel 公司推出…

现代密码学总结(上篇)

现代密码学总结 (v.1.0.0版本)之后会更新内容 基本说明: ∙ \bullet ∙如果 A A A是随机算法, y ← A ( x ) y\leftarrow A(x) y←A(x)表示输入为 x x x ,通过均匀选择 的随机带运行 A A A,并且将输出赋给 y y y。 ∙ \bullet …

VMware:CentOS 7.* 连不上网络

1、修改网络适配 2、修改网卡配置参数 cd /etc/sysconfig/network-scripts/ vi ifcfg-e33# 修改 ONBOOTyes 3、重启网卡 service network restart 直接虚拟机中【ping 宿主机】,能PING通说明centOS和宿主机网络通了,只要宿主机有网,则 Ce…