softmax回实战

1.数据集

MNIST数据集 (LeCun et al., 1998) 是图像分类中广泛使用的数据集之一,但作为基准数据集过于简单。 我们将使用类似但更复杂的Fashion-MNIST数据集 (Xiao et al., 2017)。

import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

d2l.use_svg_display()

读取数据集

我们可以通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中。

# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)

读取小批量

为了使我们在读取训练集和测试集时更容易,我们使用内置的数据迭代器,而不是从零开始创建。 回顾一下,在每次迭代中,数据加载器每次都会读取一小批量数据,大小为batch_size。 通过内置数据迭代器,我们可以随机打乱了所有样本,从而无偏见地读取小批量。

batch_size = 256

def get_dataloader_workers():  #@save
    """使用4个进程来读取数据"""
    return 4

train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=get_dataloader_workers())

整合所有组件

def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """下载Fashion-MNIST数据集,然后将其加载到内存中"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))

我们现在已经准备好使用Fashion-MNIST数据集,便于下面的章节调用来评估各种分类算法。

  • 数据迭代器是获得更高性能的关键组件。依靠实现良好的数据迭代器,利用高性能计算来避免减慢训练过程。

2.初始化模型参数

在softmax回归中,我们的输出与类别一样多。 因为我们的数据集有10个类别,所以网络输出维度为10。 因此,权重将构成一个784×10的矩阵, 偏置将构成一个1×10的行向量。

num_inputs = 784
num_outputs = 10

W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)

3.定义softmax操作

def softmax(X):
    X_exp = torch.exp(X)
    partition = X_exp.sum(1, keepdim=True)
    return X_exp / partition  # 这里应用了广播机制

正如上述代码,对于任何随机输入,我们将每个元素变成一个非负数。 此外,依据概率原理,每行总和为1。

X = torch.normal(0, 1, (2, 5))
X_prob = softmax(X)
X_prob, X_prob.sum(1)

(tensor([[0.1686, 0.4055, 0.0849, 0.1064, 0.2347],
         [0.0217, 0.2652, 0.6354, 0.0457, 0.0321]]),
 tensor([1.0000, 1.0000]))

4.定义模型

定义softmax操作后,我们可以实现softmax回归模型。 下面的代码定义了输入如何通过网络映射到输出。 注意,将数据传递到模型之前,我们使用reshape函数将每张原始图像展平为向量。

def net(X):
    return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)

5.定义损失函数

接下来,我们引入交叉熵损失函数。 这可能是深度学习中最常见的损失函数,因为目前分类问题的数量远远超过回归问题的数量。

def cross_entropy(y_hat, y):
    return - torch.log(y_hat[range(len(y_hat)), y])

小批量随机梯度下降来优化模型的损失函数,设置学习率为0.1

lr = 0.1

def updater(batch_size):
    return d2l.sgd([W, b], lr, batch_size)

6.分类精度

为了计算精度,我们执行以下操作。 首先,如果y_hat是矩阵,那么假定第二个维度存储每个类的预测分数。 我们使用argmax获得每行中最大元素的索引来获得预测类别。 然后我们将预测类别与真实y元素进行比较。 由于等式运算符“==”对数据类型很敏感, 因此我们将y_hat的数据类型转换为与y的数据类型一致。 结果是一个包含0(错)和1(对)的张量。 最后,我们求和会得到正确预测的数量。

def accuracy(y_hat, y):  #@save
    """计算预测正确的数量"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())

7.训练

在这里,我们重构训练过程的实现以使其可重复使用。 首先,我们定义一个函数来训练一个迭代周期。 请注意,updater是更新模型参数的常用函数,它接受批量大小作为参数。

def train_epoch_ch3(net, train_iter, loss, updater):  #@save
    """训练模型一个迭代周期(定义见第3章)"""
    # 将模型设置为训练模式
    if isinstance(net, torch.nn.Module):
        net.train()
    # 训练损失总和、训练准确度总和、样本数
    metric = Accumulator(3)
    for X, y in train_iter:
        # 计算梯度并更新参数
        y_hat = net(X)
        l = loss(y_hat, y)
        if isinstance(updater, torch.optim.Optimizer):
            # 使用PyTorch内置的优化器和损失函数
            updater.zero_grad()
            l.mean().backward()
            updater.step()
        else:
            # 使用定制的优化器和损失函数
            l.sum().backward()
            updater(X.shape[0])
        metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
    # 返回训练损失和训练精度
    return metric[0] / metric[2], metric[1] / metric[2]

在展示训练函数的实现之前,我们定义一个在动画中绘制数据的实用程序类Animator, 它能够简化本书其余部分的代码。

class Animator:  #@save
    """在动画中绘制数据"""
    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
                 figsize=(3.5, 2.5)):
        # 增量地绘制多条线
        if legend is None:
            legend = []
        d2l.use_svg_display()
        self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            self.axes = [self.axes, ]
        # 使用lambda函数捕获参数
        self.config_axes = lambda: d2l.set_axes(
            self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts

    def add(self, x, y):
        # 向图表中添加多个数据点
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes[0].cla()
        for x, y, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(x, y, fmt)
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)

接下来我们实现一个训练函数, 它会在train_iter访问到的训练数据集上训练一个模型net。 该训练函数将会运行多个迭代周期(由num_epochs指定)。 在每个迭代周期结束时,利用test_iter访问到的测试数据集对模型进行评估。 我们将利用Animator类来可视化训练进度。

def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):  #@save
    """训练模型(定义见第3章)"""
    animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],
                        legend=['train loss', 'train acc', 'test acc'])
    for epoch in range(num_epochs):
        train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
        test_acc = evaluate_accuracy(net, test_iter)
        animator.add(epoch + 1, train_metrics + (test_acc,))
    train_loss, train_acc = train_metrics
    assert train_loss < 0.5, train_loss
    assert train_acc <= 1 and train_acc > 0.7, train_acc
    assert test_acc <= 1 and test_acc > 0.7, test_acc

训练:

num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)

8.预测

现在训练已经完成,我们的模型已经准备好对图像进行分类预测。 给定一系列图像,我们将比较它们的实际标签(文本输出的第一行)和模型预测(文本输出的第二行)。

def predict_ch3(net, test_iter, n=6):  #@save
    """预测标签(定义见第3章)"""
    for X, y in test_iter:
        break
    trues = d2l.get_fashion_mnist_labels(y)
    preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))
    titles = [true +'\n' + pred for true, pred in zip(trues, preds)]
    d2l.show_images(
        X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])

predict_ch3(net, test_iter)

9.总结:

  • 借助softmax回归,我们可以训练多分类的模型。
  • 训练softmax回归循环模型与训练线性回归模型非常相似:先读取数据,再定义模型和损失函数,然后使用优化算法训练模型。大多数常见的深度学习模型都有类似的训练过程。

10.简洁实现:

10.1读取数据

import torch
from torch import nn
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

10.2定义模型,初始化参数

跟线性规划模型是一样的,只不过有十个输出

# PyTorch不会隐式地调整输入的形状。因此,
# 我们在线性层前定义了展平层(flatten),来调整网络输入的形状
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);

10.3定义损失函数

loss = nn.CrossEntropyLoss(reduction='none')

10.4定义优化器

我们使用学习率为0.1的小批量随机梯度下降作为优化算法

updater = torch.optim.SGD(net.parameters(), lr=0.1)

10.5训练

调用上面写好的函数,只不过这次的模型,损失函数,优化器都是pytorch内置的

num_epochs = 10
train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

10.6预测:

predict_ch3(net, test_iter)

10.7小结

  • 使用深度学习框架的高级API,我们可以更简洁地实现softmax回归。
  • 从计算的角度来看,实现softmax回归比较复杂。在许多情况下,深度学习框架在这些著名的技巧之外采取了额外的预防措施,来确保数值的稳定性。这使我们避免了在实践中从零开始编写模型时可能遇到的陷阱。
    数,优化器**都是pytorch内置的
num_epochs = 10
train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

10.6预测:

predict_ch3(net, test_iter)

10.7小结

  • 使用深度学习框架的高级API,我们可以更简洁地实现softmax回归。
  • 从计算的角度来看,实现softmax回归比较复杂。在许多情况下,深度学习框架在这些著名的技巧之外采取了额外的预防措施,来确保数值的稳定性。这使我们避免了在实践中从零开始编写模型时可能遇到的陷阱。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/336592.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

STM32标准库开发—软件I2C读取MPU6050

软件模拟I2C时序 初始化I2C引脚以及时钟 void MyI2C_Init(void) { RCC_APB2PeriphClockCmd(RCC_APB2Periph_GPIOB,ENABLE);GPIO_InitTypeDef GPIO_InitStruct;GPIO_InitStruct.GPIO_ModeGPIO_Mode_Out_OD;GPIO_InitStruct.GPIO_PinGPIO_Pin_10|GPIO_Pin_11;GPIO_InitStruct.G…

pearcmd文件包含漏洞

1.什么是pearcmd.php pecl是PHP中用于管理扩展而使用的命令行工具&#xff0c;而pear是pecl依赖的类库。在7.3及以前&#xff0c;pecl/pear是默认安装的&#xff1b;在7.4及以后&#xff0c;需要我们在编译PHP的时候指定--with-pear才会安装 不过&#xff0c;在Docker任意版本…

(菜鸟自学)Metasploit漏洞利用——ms08-067

&#xff08;菜鸟自学&#xff09;漏洞利用——ms08-067 漏洞简介利用nmapMSF软件对XP sp3系统进行渗透攻击设置exploit模块参数RHOSTRPORTSMBPIPEExploit Target 设置有效载荷查找可兼容的有效载荷 渗透测试VNC 漏洞简介 MS08-067 是指微软于2008年发布的一个安全漏洞&#x…

重学Java 10 面向对象

正是风雨欲来的时候&#xff0c;火却越烧越旺了 ——24.1.20 重点 1.为何使用面向对象思想编程 2.如何使用面向对象思想编程 3.何时使用面向对象思想编程 4.利用代码去描述世间万物的分类 5.在一个类中访问另外一个类中的成员 -> new对象 6.成员变量和局部变量的区别 一…

利用HTML+CSS+JS打造炫酷时钟网页的完整指南

引言 在现代Web开发中&#xff0c;制作一个引人注目的时钟网页是一种常见而令人愉悦的体验。本文将介绍如何使用HTML、CSS和JavaScript来创建一个炫酷的时钟网页&#xff0c;通过这个项目&#xff0c;你将学到如何结合这三种前端技术&#xff0c;制作一个动态且美观的时钟效果…

接口测试 02 -- JMeter入门到实战

前言 JM eter毕竟是做压测的工具&#xff0c;自动化这块还是有缺陷。 如果公司做一些简单的接口自动化&#xff0c;可以考虑使用JMeter快速完成&#xff0c;如果想做完善的接口自动化体系&#xff0c;建议还是基于Python来做。 为什么学习接口测试要先从JMeter开始&#xff1f;…

C语言数据结构——顺序表

&#xff08;图片由AI生成&#xff09; 0.前言 在程序设计的世界里&#xff0c;数据结构是非常重要的基础概念。本文将专注于C语言中的一种基本数据结构——顺序表。我们将从数据结构的基本概念讲起&#xff0c;逐步深入到顺序表的内部结构、分类&#xff0c;最后通过一个实…

Unity常用的优化技巧集锦

Unity性能优化是面试的时候经常被问道的一些内容&#xff0c;今天给大家分享一些常用的Unity的优化技巧和思路&#xff0c;方便大家遇到问题时候参考与学习。 对啦&#xff01;这里有个游戏开发交流小组里面聚集了一帮热爱学习游戏的零基础小白&#xff0c;也有一些正在从事游…

电脑pdf如何转换成word格式?用它实现pdf文件一键转换

pdf转word格式可以用于提取和重用pdf文档中的内容&#xff0c;有时候&#xff0c;我们可能需要引用或引用pdf文档中的一些段落、表格或数据&#xff0c;通过将pdf转换为可编辑的Word文档&#xff0c;可以轻松地复制和粘贴所需内容&#xff0c;节省我们的时间&#xff0c;那么如…

【守护工地安全】YOLOv8实现安全帽检测

学习《OpenCV应用开发&#xff1a;入门、进阶与工程化实践》一书 做真正的OpenCV开发者&#xff0c;从入门到入职&#xff0c;一步到位&#xff01; 数据集 该图像数据集包含8000张图像&#xff0c;两个类别分别是安全帽与人、以其中200多张图像为验证集&#xff0c;其余为训…

谁说知识库都是英文的 今天就来一个中文版的

1.安装 1.1创建目录 mkdir -p /opt/trilium-cn cd /opt/trilium-cn 1.2.编写docker-compose.yml文件 version: 3 services:trilium-cn:image: nriver/trilium-cnrestart: alwaysports:- "10012:8080"volumes:# 把同文件夹下的 trilium-data 目录映射到容器内- /opt…

基于JavaWeb+SSM+Vue基于微信小程序生鲜云订单零售系统的设计和实现

基于JavaWebSSMVue基于微信小程序生鲜云订单零售系统的设计和实现 滑到文末获取源码Lun文目录前言主要技术系统设计功能截图订阅经典源码专栏Java项目精品实战案例《500套》 源码获取 滑到文末获取源码 Lun文目录 目录 1系统概述 1 1.1 研究背景 1 1.2研究目的 1 1.3系统设计…

6.4.4释放音频

6.4.4释放音频 许多Flash动画里的音乐或歌曲非常好听&#xff0c;能不能在没有源文件的情况下把里面的声音文件取出来呢&#xff1f;利用Swf2VideoConverter2可以轻松做到这一点。 1&#xff0e;单击“添加”按钮&#xff0c;在弹出的下拉菜单中选择“添加文件”&#xff0c;…

java小项目:简单的收入明细记事本,超级简单(不涉及数据库,通过字符串来记录)

一、效果 二、代码 2.1 Acount类 package com.demo1;public class Acount {public static void main(String[] args) {String details "收支\t账户金额\t收支金额\t说 明\n"; //通过字符串来记录收入明细int balance 10000;boolean loopFlag true;//控制循…

iOS中的视频播放

引言 在数字时代&#xff0c;视频已成为我们日常沟通和娱乐不可或缺的一部分。从简短的社交媒体剪辑到全长电影&#xff0c;我们对流畅、高质量视频播放的需求从未如此强烈。对于开发者来说&#xff0c;为iOS用户提供无瑕疵的视频体验是一项挑战&#xff0c;也是一种艺术。本篇…

数据结构之list类

前言 list是列表类。从list 类开始&#xff0c;我们就要接触独属于 Python 的数据类型了。Python 简单、易用&#xff0c;很大一部分原因就是它对基础数据类型的设计各具特色又相辅相成。 话不多说&#xff0c;让我们开始学习第一个 Python 数据类型一list。 1. list的赋值 输…

手把手教你使用 VS Code 运行和调试 Python 程序

本文以 Ubuntu 系统为例&#xff0c;介绍如何在 VS Code 上配置 Python 的编程环境&#xff0c;并把 Python 程序运行、调试起来。由于 Python 是解释型语言&#xff0c;并且 VS Code 中提供了内置的调试器可用于调试 Python 代码&#xff0c;因此配置和操作流程比调试 C/C 代码…

SpringSecurity+JWT前后端分离架构登录认证

目录 1. 数据库设计 2. 代码设计 登录认证过滤器 认证成功处理器AuthenticationSuccessHandler 认证失败处理器AuthenticationFailureHandler AuthenticationEntryPoint配置 AccessDeniedHandler配置 UserDetailsService配置 Token校验过滤器 登录认证过滤器接口配置…

C++将信息输入到文件内

第一步检查文件是否打开&#xff0c;用到头文件&#xff1a; #include <fstream> #include <sstream> 文件打开的函数为 file.isopen() 信息输入到文件应该为 file << "" << value; 注意是file<< 如图 定义file ofstream f…

计算机组成原理 第一弹

ps&#xff1a;本文章的图片来源都是来自于湖科大教书匠高老师的视频&#xff0c;声明&#xff1a;仅供自己复习&#xff0c;里面加上了自己的理解 这里附上视频链接地址&#xff1a;1-2 计算机的发展_哔哩哔哩_bilibili ​​ 目录 &#x1f680;计算机系统 &#x1f680;计…