2-5 softmax 回归的简洁实现

我们发现通过深度学习框架的高级API能够使实现线性回归变得更加容易。 同样,通过深度学习框架的高级API也能更方便地实现softmax回归模型。 本节如在上节中一样, 继续使用Fashion-MNIST数据集,并保持批量大小为256。

import torch
from torch import nn  # 通过pytorch的nn的module
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

初始化模型参数

softmax回归的输出层是一个全连接层。 因此,为了实现我们的模型, 我们只需在Sequential中添加一个带有10个输出的全连接层。 同样,在这里Sequential并不是必要的, 但它是实现深度模型的基础。 我们仍然以均值0和标准差0.01随机初始化权重。

# PyTorch不会隐式地调整输入的形状。因此,
# 我们在线性层前定义了展平层(flatten),来调整网络输入的形状
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10)) # 这里使用了PyTorch中的nn.Sequential来构建一个顺序容器,将层按顺序添加到网络中
# nn.Flatten():这一层的作用是将输入的数据展平成一维。假设输入的数据是一个28x28的二维图像,展平后将变成一个784(28*28)长度的一维向量。
# nn.Linear(784, 10):这是一个全连接层(线性层),输入大小为784(展平后的图像向量),输出大小为10(假设有10个类别)。

def init_weights(m):  # 这里定义了一个函数init_weights,用于初始化网络中的权重。
    if type(m) == nn.Linear: # 这行代码检查传入的层是否为nn.Linear类型,即全连接层。
        nn.init.normal_(m.weight, std=0.01) # 如果该层是全连接层,则使用nn.init.normal_方法将该层的权重初始化为均值为0,标准差为0.01的正态分布随机值。

net.apply(init_weights);  # net.apply方法会遍历网络中的每一层,并将init_weights函数应用到每一层上,完成权重的初始化。

在交叉熵损失函数中传递未归一化的预测,并同时softmax及其对数

loss = nn.CrossEntropyLoss(reduction='none')

在这里,我们使用学习率为0.1的小批量随机梯度下降作为优化算法。 这与我们在线性回归例子中的相同,这说明了优化器的普适性。

trainer = torch.optim.SGD(net.parameters(), lr=0.1)

训练

接下来我们调用之前定义的训练函数来训练模型。

num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

在这里插入图片描述


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/781285.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

论文复现-基于决策树算法构建银行贷款审批预测模型(金融风控场景)

作者Toby,来源公众号:Python风控模型,基于决策树算法构建银行贷款审批预测模型 目录 1.金融风控论文复现 2.项目背景介绍 3.决策树介绍 4.数据集介绍 5.合规风险提醒 6.技术工具 7.实验过程 7.1导入数据 7.2数据预处理 7.3数据可…

隔离级别-隔离级别中的锁协议、隔离级别类型、隔离级别的设置、隔离级别应用

一、引言 1、DBMS除了采用严格的两阶段封锁协议来保证并发事务的可串行化,实现事务的隔离性,也可允许用户选择一个可以保证应用程序正确执行并且能够使并发度最大的隔离性等级 2、通常用隔离级别来描述隔离性等级,以下将主要介绍ANSI 92标准…

LaTeX教程(014)-LaTeX文档结构(14)

LaTeX教程(014)- LaTeX \LaTeX LATE​X文档结构(14) 2.3.3 multitoc - 将目录设置为多栏 multitoc包的使用方法相当简单,只需要调用这个包,并将要设置为多栏(默认是双栏)的目录指定到包选项中即可。如\usepackage[toc]{multitoc},设置的就是…

GIT 使用相关技巧记录

目录 1、commit 用户信息变更 全局用户信息(没有特殊配置的情况下默认直接用全局信息) 特定仓库用户信息(只针对于当前项目) 方法一:修改config文件 方法二:命令方式 2、idea同一代码推向多个远端仓库…

如何在应用运行时定期监控内存使用情况

如何在应用运行时定期监控内存使用情况 在 iOS 应用开发中,实时监控内存使用情况对于优化性能和排查内存泄漏等问题非常重要。本文将介绍如何在应用运行时定期监控内存使用情况,使用 Swift 编写代码并结合必要的工具和库。 1. 创建桥接头文件 首先&…

k8s 部署 springboot 项目内存持续增长问题分析解决

写在前面 工作中遇到,请教公司前辈解决,简单整理记忆博文内容涉及一次 GC 问题的分析以及解决理解不足小伙伴帮忙指正 😃,生活加油 99%的焦虑都来自于虚度时间和没有好好做事,所以唯一的解决办法就是行动起来,认真做完…

STM32-USART

本内容基于江协科技STM32视频学习之后整理而得。 文章目录 1. 串口通信协议1.1 通信接口1.2 串口通信1.3 硬件电路1.4 电平标准1.5 串口参数及时序1.6 串口时序 2. USART串口通信2.1 USART简介2.2 USART框图2.3 USART基本结构2.4 数据帧2.5 数据帧-配置停止位2.6 起始位侦测2.…

大连外贸建站公司wordpress主题模板

Robonaut萝卜纳特WP外贸站模板 适合用于工业机器人公司出口做外贸搭建公司官方网站使用的WordPress模板。 https://www.jianzhanpress.com/?p7091 优衣裳WordPress外贸建站模板 简洁的wordpress外贸独立站模板,适合服装、衣服、制衣外贸公司搭建公司官方网站使用…

ByteTrack论文阅读笔记

目录 ByteTrack: Multi-Object Tracking by Associating Every Detection Box摘要INTRODUCTION — 简介BYTE算法BYTE算法用Python代码实现实验评测指标轻量模型的跟踪性能 总结SORT算法简介ByteTrack算法和SORT算法的区别 ByteTrack: Multi-Object Tracking by Associating Eve…

location匹配和rewrite重定向

目录 location 匹配 location匹配的分类和优先级 优先级细分 实际网站中的使用规则 1.用精确匹配来实现网站的首页 访问网站的首页 ( /) 2.用正则匹配来实现静态请求的页面和图片 匹配静态页面 访问图片或者指定的后缀名 3.用一般匹配转发.php…

【qt】TCP的监听 (设置服务器IP地址和端口号)

TCP监听是在自己的IP地址上进行的。 当一个TCP服务器程序启动时,它会绑定到一个特定的IP地址和一个端口号上,以便可以接收来自该IP地址和端口号的传入连接请求. 所以我们要先来获取主机的IP地址和设置端口号. 注意: 服务器程序无法任意设置IP地址&…

数据结构学生信息顺序表

主程序 #include "fun.h" int main(int argc, const char *argv[]) { seq_p Screate_seq(); stu data; printf("请问要输入几个学生的数据:"); int n; scanf("%d",&n); while(n--) { prin…

cloudflare tunnels tcp

这里是官网的说明Cloudflare Tunnel Cloudflare Zero Trust docs 根据实际情况安装环境 tunnels除了http,https协议是直接暴露公网,tcp是类似ssh端口转发。 在需要内网穿透的局域网找一条机子部署代理 我这边是window cloudflared tunnel login #生成一个身份校…

防火墙概述

1、防火墙 防火墙顾名思义就是防止火灾发生时,火势烧到其它区域,使用由防火材料砌的墙。在网络安全中,防火墙的作用就是保护本地网络不受到外部网络或恶意程序的伤害。 防火墙的核心任务是控制和防护,即通过安全策略识别流量并做…

【周末闲谈】AI“抢饭碗”?绝对不是危言耸听

AI是在帮助开发者还是取代他们? 在软件开发领域,生成式人工智能(AIGC)正在改变开发者的工作方式。无论是代码生成、错误检测还是自动化测试,AI工具正在成为开发者的得力助手。然而,这也引发了对开发者职业前景和技能需求变化的讨论。AI究竟是在帮助开发者还是取代他们?…

【论文阅读】-- Visual Analytics for Model Selection in Time Series Analysis

时间序列分析中模型选择的可视化分析 摘要1 引言2 相关工作3 问题表征3.1 Box-Jenkins 方法论3.2 ARIMA 和季节性 ARIMA 模型3.3 模型规范3.4 模型拟合3.5 模型诊断 4 需求分析5 VA 用于时间序列分析中的模型选择5.1 VA选型流程说明5.2 TiMoVA 原型5.2.1 实施选择5.2.2 图形用户…

【JavaSE复习】数据结构、集合

JavaSE 复习 1.数据结构1.1 查找1.1.1 基本查找1.1.2 二分查找1.1.3 插值查找1.1.4 斐波那契查找1.1.5 分块查找1.1.6 分块查找的扩展(无规律数据) 1.2 排序1.2.1 冒泡排序1.2.2 选择排序1.2.3 插入排序1.2.4 快速排序 2. 集合2.1 基础集合2.1.1 集合和数…

MyBatis中二级缓存的配置与实现原理

大家好,我是王有志,一个分享硬核 Java 技术的金融摸鱼侠,欢迎大家加入 Java 人自己的交流群“共同富裕的 Java 人”。 上一篇文章《MyBatis中一级缓存的配置与实现原理》中,我们已经掌握了 MyBatis 一级缓存的配置(虽然…

使用AOP思想实现开闭原则下的流水日志输出

主要实现思想: 通过实现Convert接口来抽取公共组件,获取想要的标准模型。 现在有两个订单场景,一个保存订单,一个为更新订单。构造如下的服务类: import org.springframework.stereotype.Service;Service public clas…

pwm 呼吸灯(如果灯一直亮或者一直灭)

(这个文章收藏在我的csdn keil文件夹下面) 如果这样设置预分频和计数周期,那么算出来的pwm频率如下 人眼看起来就只能是一直亮或者灭,因为pwm的频率太高了,但是必须是频率够高,才能实现呼吸灯的缓慢亮缓慢…