【动手学深度学习】softmax回归的简洁实现详情

目录

🌊1. 研究目的

🌊2. 研究准备

🌊3. 研究内容

🌍3.1 softmax回归的简洁实现

🌍3.2 基础练习

🌊4. 研究体会


🌊1. 研究目的

  • 理解softmax回归的原理和基本实现方式;
  • 学习如何从零开始实现softmax回归,并了解其关键步骤;
  • 通过简洁实现softmax回归,掌握使用现有深度学习框架的能力;
  • 探索softmax回归在分类问题中的应用,并评估其性能。

🌊2. 研究准备

  • 根据GPU安装pytorch版本实现GPU运行研究代码;
  • 配置环境用来运行 Python、Jupyter Notebook和相关库等相关库。

🌊3. 研究内容

启动jupyter notebook,使用新增的pytorch环境新建ipynb文件,为了检查环境配置是否合理,输入import torch以及torch.cuda.is_available() ,若返回TRUE则说明研究环境配置正确,若返回False但可以正确导入torch则说明pytorch配置成功,但研究运行是在CPU进行的,结果如下:


🌍3.1 softmax回归的简洁实现

完成softmax回归的简洁实现的研究代码及练习内容如下:

导入必要库及模型:

import torch
from torch import nn
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

初始化模型参数

# PyTorch不会隐式地调整输入的形状。因此,
# 我们在线性层前定义了展平层(flatten),来调整网络输入的形状
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)
        
net.apply(init_weights)

重新审视Softmax的实现

loss = nn.CrossEntropyLoss(reduction='mean')  # 将reduction设置为'mean'或'sum'

优化算法

trainer = torch.optim.SGD(net.parameters(), lr=0.1)

训练

num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)


🌍3.2 基础练习

1.尝试调整超参数,例如批量大小、迭代周期数和学习率,并查看结果。

在这个示例中,我将批量大小调整为128,迭代周期数调整为20,学习率调整为0.01。

import torch
from torch import nn
from d2l import torch as d2l

# 超参数调整
batch_size = 128  # 调整批量大小
num_epochs = 20  # 调整迭代周期数
learning_rate = 0.01  # 调整学习率

train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
net.apply(init_weights)

loss = nn.CrossEntropyLoss(reduction='mean')

trainer = torch.optim.SGD(net.parameters(), lr=learning_rate)

d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

2.增加迭代周期的数量。为什么测试精度会在一段时间后降低?我们怎么解决这个问题?

当增加迭代周期的数量时,训练过程会继续进行更多的迭代,模型会有更多的机会学习训练数据中的模式和特征。通常情况下,增加迭代周期数量可以提高模型的训练精度。然而,如果过度训练,测试精度可能会在一段时间后开始降低。

这种情况被称为"过拟合"(overfitting)。过拟合发生时,模型在训练数据上表现得很好,但在新数据(测试数据)上表现较差。过拟合是由于模型过于复杂,过度记住了训练数据中的噪声和细节,而无法泛化到新数据。

为了解决过拟合问题,可以尝试以下几种方法:

  • 提前停止(Early Stopping):在训练过程中,跟踪训练误差和测试误差。一旦测试误差开始上升,就停止训练。这样可以防止模型过度拟合训练数据。
  • 正则化(Regularization):通过向损失函数添加正则化项,限制模型参数的大小,防止过度拟合。常见的正则化方法包括L1正则化和L2正则化。
  • 数据增强(Data Augmentation):通过对训练数据进行随机变换(如旋转、翻转、缩放等),增加训练样本的多样性,有助于提高模型的泛化能力。
  • 减小模型复杂度:减少模型的层数、节点数或参数量,使其更简单。简化模型可以降低过拟合的风险。
  • 使用更多的训练数据:增加训练数据量可以减少过拟合的可能性,因为模型将有更多的样本进行学习。

通过组合使用这些方法,可以有效地解决过拟合问题并提高模型的泛化能力。


🌊4. 研究体会

通过这次研究,我深入学习了softmax回归模型,理解了它的原理和基本实现方式。开始了解softmax回归的背景和用途,它在多类别分类问题中的应用广泛;学习了如何从零开始实现softmax回归,并掌握了其中的关键步骤。

通过简洁实现softmax回归,更加熟悉了深度学习框架的使用。可以通过几行代码完成模型的定义、数据的加载和训练过程。还学会了使用框架提供的工具来评估模型的性能,如计算准确率和绘制混淆矩阵。这使能够更方便地对模型进行调试和优化,以获得更好的分类结果。

最后,通过实验探索了softmax回归在分类问题中的应用,并评估了其性能。使用了一些真实的数据集,如MNIST手写数字数据集,来进行实验。在实验中,将数据集划分为训练集和测试集,用训练集来训练模型,然后用测试集来评估模型的性能。

在从零开始实现的实验中,对模型的性能进行了一些调优,比如调整学习率和迭代次数。观察到随着迭代次数的增加,模型的训练损失逐渐下降,同时在测试集上的准确率也在提升。这证明了的模型在一定程度上学习到了数据的规律,并能够泛化到新的样本。而在简洁实现的实验中,由于深度学习框架的优化算法和自动求导功能,模型的训练速度明显快于从零开始实现。同时,框架提供了更多的网络结构和调优方法,使能够更加灵活地构建和调整模型。在简洁实现中,我还尝试了一些不同的模型结构,比如加入隐藏层或使用更复杂的优化算法,以探索更高效的模型设计。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/676156.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

prometheus+alertmanager+webhook钉钉机器人告警

版本:centos7.9 python3.9.5 alertmanager0.25.0 prometheus2.46.0 安装alertmanager prometheus 配置webhook # 解压: tar -xvf alertmanager-0.25.0.linux-amd64.tar.gz tar -xvf prometheus-2.46.0.linux-amd64.tar.gz mv alertmanager-0.25.0.linu…

分享毕业论文要怎么写以及写论文工具推荐

毕业论文的写作是一个系统且需要深度研究的过程。以下将分步介绍毕业论文的写作方法,并推荐一些实用的写作工具。 毕业论文写作方法 选题: 确定研究方向和目标,选择具体且有一定研究价值的课题。建议选择应用型题目,结合理论和实…

【HarmonyOS】鸿蒙系统中应用权限等级介绍、定义、申请授权讲解

【HarmonyOS】鸿蒙系统中应用权限等级介绍、定义、申请授权讲解 针对权限等级,相对于主体来说,会有不同的细分概念。 一、权限APL等级: 首先在鸿蒙系统中,对于权限本身,分为三个等级:normal,s…

【JAVA WEB实用与优化技巧】如何使用本地.bat/.sh脚本快速将服务发布到测试环境?

文章目录 普通方式的springboot 使用docker打包发布【手动构建镜像模式】1. maven 打包可运行jar包2.手动打包镜像3.运行容器 全自动化本地命令发布到远程服务的方式配置ssh信任公钥获取公钥git 获取公钥方式: 桌面右键 -> open git gui here -> help -> show SSH key…

【数据库】MySQL表的操作

目录 一.创建表 二.查看表 三.修改表 四.删除表 一.创建表 基本语法: CREATE TABLE table_name(field1 datatype,field2 datatype,field3 datatype) character set 字符集 collate 校验规则 engine 储存引擎field表示列名 datatype表示列的类型 charatcer se…

初识C++ · 模拟实现stack和Queue

目录 前言: 1 Stack 1.1 双端队列 2 Queue 前言: 经历了list三个自定义类型的洗礼,来个简单的放松放松,即栈和队列: 文档记录的,栈和队列是一种容器适配器,它们不属于stl,但是它…

什么是云渲染?怎么使用呢?手把手教你

云渲染是一种利用云计算技术进行图形渲染的服务。它允许用户将渲染任务提交到云端服务器,由远程的计算机集群资源进行渲染操作,完成后再将渲染结果返回给用户。 云渲染技术的优势在于它可以提高渲染效率和质量,支持多任务同时加速渲染&#…

一个被无数人忽视的好项目

这个项目相信大家都在各大短视频平台见过,之前被我忽视了,当时的我自以为它是一时的热度,很快就会凉凉。但现在却生生被打脸了,这其实是非常好变现且流量也很大的一个好项目。 到底是什么好项目呢,没错,就…

[MYSQL]合作过至少三次的演员和导演

ActorDirector 表: ---------------------- | Column Name | Type | ---------------------- | actor_id | int | | director_id | int | | timestamp | int | ---------------------- timestamp 是这张表的主键(具有唯一值的列).编写解决方案…

黑马程序员——Spring框架——day04——SpringMVC基础

目录: SpringMVC简介 背景SpringMVC概述技术体系定位快速入门 目的需求步骤代码实操测试工具 PostMan简介PostMan安装PostMan使用知识点总结请求与参数处理 请求路径 环境准备问题分析解决方式请求方式 环境准备技术分析参数 基本数据类型POJO嵌套POJO数组集合&…

基于卷积神经网络(CNN)的深度迁移学习在声发射(AE)监测螺栓连接状况的应用

螺栓结构在工业中用于组装部件,它们在多种机械系统中扮演着关键角色。确保这些连接结构的健康状态对于航空航天、汽车和建筑等各个行业至关重要,因为螺栓连接的故障可能导致重大的安全风险、经济损失、性能下降和监管合规问题。 在早期阶段检测到螺栓松动…

四、利用启发式算法进行特定数据集的残差网络结构搜索【框架+源码】

背景:工作之后干的事情跟算法关联甚少,整理下读书期间的负责和参与的work,再熟悉学习下。 边熟悉边整理喽~ CV Tradictional workCV AI based work机械臂视觉抓取项目机器学习全流程 Pipeline训练平台OCR生产线喷码识别三维重建(SfM)ROS机器人…

springboot项目通过jar包部署到服务器

1. 将springboot项目打包成jar包 方式一:IDEA为例 出现 BUILD SUCCESS 证明打包成功,自动生成了 target 目录, jar 包就在目录里边 方式二:命令行(得配置好maven环境变量) 切换到项目目录下,…

springboot管理的各依赖版本查看

找一个springboot相关的依赖,比如这里我找mybatis 鼠标点击artifactId名称,图中蓝色字段,跳转到springboot依赖(鼠标悬停在上面变成蓝色表示可点击跳转), 点击spring-boot-dependencites,跳转到…

基于FPGA的SystemVerilog练习

文章目录 一、认识SystemVerilogSystemVerilog的语言特性SystemVerilog的应用领域SystemVerilog的优势SystemVerilog的未来发展方向 二、流水灯代码流水灯部分testbench仿真文件 三、用systemVerilog实现超声波测距计时器测距部分led部分数码管部分采样部分顶层文件引脚绑定效果…

QT入门知识回顾

1 QT简介 1.1 Qt模块: Qt Core模块: 是QT类库的核心,所有其他模块都依赖这个模块 Qt Gui模块: 提供GUI程序的基本功能 Qt Network模块:提供跨平台的网络功能 Qt Widgets模块:提供创建用户界面的功能 1.2Qt的signal/slot机制 任何一个类只要类体前部书写 Q_OBJ…

TH方程学习 (6)

一、内容介绍 本节旨在使用优化算法的方法,旨在利用最小的燃耗实现目标的交会,变量为目标的转移时间。整个问题描述为: 本节拟采取粒子群优化的算法,matlab中自带的粒子群函数为particleswarm,其用法不详细介绍&#…

LeetCode:环形链表II

文章收录于LeetCode专栏 LeetCode地址 环形链表II 题目 给定一个链表,返回链表开始入环的第一个节点。如果链表无环,则返回null。   为了表示给定链表中的环,我们使用整数pos来表示链表尾连接到链表中的位置(索引从0开始&#…

C++青少年简明教程:数组

C青少年简明教程:数组 C数组是一种存储固定大小连续元素的数据结构。数组中的每个元素都有一个索引,通过索引可以访问或修改数组中的元素。 在C中,数组中的元素数据类型必须一致。数组是一个连续的内存区域,用于存储相同类型的元…

std::shared_ptr,reset()函数

感慨&#xff1a;不深入阅读源代码&#xff0c;真的心虚&#xff0c;也用不好。 上代码&#xff1a; class A01 { public://std::weak_ptr<B0> b_ptr;int data{ 1234 };~A01() {std::cout << "A01 deleted\n";}void Print() { std::cout << &quo…