softmax回归实战-分类

1.数据集

MNIST数据集 (LeCun et al., 1998) 是图像分类中广泛使用的数据集之一,但作为基准数据集过于简单。 我们将使用类似但更复杂的Fashion-MNIST数据集 (Xiao et al., 2017)。

import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

d2l.use_svg_display()

读取数据集

我们可以通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中。

# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)

读取小批量

为了使我们在读取训练集和测试集时更容易,我们使用内置的数据迭代器,而不是从零开始创建。 回顾一下,在每次迭代中,数据加载器每次都会读取一小批量数据,大小为batch_size。 通过内置数据迭代器,我们可以随机打乱了所有样本,从而无偏见地读取小批量。

batch_size = 256

def get_dataloader_workers():  #@save
    """使用4个进程来读取数据"""
    return 4

train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=get_dataloader_workers())

整合所有组件

def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """下载Fashion-MNIST数据集,然后将其加载到内存中"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))

我们现在已经准备好使用Fashion-MNIST数据集,便于下面的章节调用来评估各种分类算法。

  • 数据迭代器是获得更高性能的关键组件。依靠实现良好的数据迭代器,利用高性能计算来避免减慢训练过程。

2.初始化模型参数

在softmax回归中,我们的输出与类别一样多。 因为我们的数据集有10个类别,所以网络输出维度为10。 因此,权重将构成一个784×10的矩阵, 偏置将构成一个1×10的行向量。

num_inputs = 784
num_outputs = 10

W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)

3.定义softmax操作

def softmax(X):
    X_exp = torch.exp(X)
    partition = X_exp.sum(1, keepdim=True)
    return X_exp / partition  # 这里应用了广播机制

正如上述代码,对于任何随机输入,我们将每个元素变成一个非负数。 此外,依据概率原理,每行总和为1。

X = torch.normal(0, 1, (2, 5))
X_prob = softmax(X)
X_prob, X_prob.sum(1)

(tensor([[0.1686, 0.4055, 0.0849, 0.1064, 0.2347],
         [0.0217, 0.2652, 0.6354, 0.0457, 0.0321]]),
 tensor([1.0000, 1.0000]))

4.定义模型

定义softmax操作后,我们可以实现softmax回归模型。 下面的代码定义了输入如何通过网络映射到输出。 注意,将数据传递到模型之前,我们使用reshape函数将每张原始图像展平为向量。

def net(X):
    return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)

5.定义损失函数

接下来,我们引入交叉熵损失函数。 这可能是深度学习中最常见的损失函数,因为目前分类问题的数量远远超过回归问题的数量。

def cross_entropy(y_hat, y):
    return - torch.log(y_hat[range(len(y_hat)), y])

小批量随机梯度下降来优化模型的损失函数,设置学习率为0.1

lr = 0.1

def updater(batch_size):
    return d2l.sgd([W, b], lr, batch_size)

6.分类精度

为了计算精度,我们执行以下操作。 首先,如果y_hat是矩阵,那么假定第二个维度存储每个类的预测分数。 我们使用argmax获得每行中最大元素的索引来获得预测类别。 然后我们将预测类别与真实y元素进行比较。 由于等式运算符“==”对数据类型很敏感, 因此我们将y_hat的数据类型转换为与y的数据类型一致。 结果是一个包含0(错)和1(对)的张量。 最后,我们求和会得到正确预测的数量。

def accuracy(y_hat, y):  #@save
    """计算预测正确的数量"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())

7.训练

在这里,我们重构训练过程的实现以使其可重复使用。 首先,我们定义一个函数来训练一个迭代周期。 请注意,updater是更新模型参数的常用函数,它接受批量大小作为参数。

def train_epoch_ch3(net, train_iter, loss, updater):  #@save
    """训练模型一个迭代周期(定义见第3章)"""
    # 将模型设置为训练模式
    if isinstance(net, torch.nn.Module):
        net.train()
    # 训练损失总和、训练准确度总和、样本数
    metric = Accumulator(3)
    for X, y in train_iter:
        # 计算梯度并更新参数
        y_hat = net(X)
        l = loss(y_hat, y)
        if isinstance(updater, torch.optim.Optimizer):
            # 使用PyTorch内置的优化器和损失函数
            updater.zero_grad()
            l.mean().backward()
            updater.step()
        else:
            # 使用定制的优化器和损失函数
            l.sum().backward()
            updater(X.shape[0])
        metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
    # 返回训练损失和训练精度
    return metric[0] / metric[2], metric[1] / metric[2]

在展示训练函数的实现之前,我们定义一个在动画中绘制数据的实用程序类Animator, 它能够简化本书其余部分的代码。

class Animator:  #@save
    """在动画中绘制数据"""
    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
                 figsize=(3.5, 2.5)):
        # 增量地绘制多条线
        if legend is None:
            legend = []
        d2l.use_svg_display()
        self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            self.axes = [self.axes, ]
        # 使用lambda函数捕获参数
        self.config_axes = lambda: d2l.set_axes(
            self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts

    def add(self, x, y):
        # 向图表中添加多个数据点
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes[0].cla()
        for x, y, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(x, y, fmt)
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)

接下来我们实现一个训练函数, 它会在train_iter访问到的训练数据集上训练一个模型net。 该训练函数将会运行多个迭代周期(由num_epochs指定)。 在每个迭代周期结束时,利用test_iter访问到的测试数据集对模型进行评估。 我们将利用Animator类来可视化训练进度。

def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):  #@save
    """训练模型(定义见第3章)"""
    animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],
                        legend=['train loss', 'train acc', 'test acc'])
    for epoch in range(num_epochs):
        train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
        test_acc = evaluate_accuracy(net, test_iter)
        animator.add(epoch + 1, train_metrics + (test_acc,))
    train_loss, train_acc = train_metrics
    assert train_loss < 0.5, train_loss
    assert train_acc <= 1 and train_acc > 0.7, train_acc
    assert test_acc <= 1 and test_acc > 0.7, test_acc

训练:

num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)

8.预测

现在训练已经完成,我们的模型已经准备好对图像进行分类预测。 给定一系列图像,我们将比较它们的实际标签(文本输出的第一行)和模型预测(文本输出的第二行)。

def predict_ch3(net, test_iter, n=6):  #@save
    """预测标签(定义见第3章)"""
    for X, y in test_iter:
        break
    trues = d2l.get_fashion_mnist_labels(y)
    preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))
    titles = [true +'\n' + pred for true, pred in zip(trues, preds)]
    d2l.show_images(
        X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])

predict_ch3(net, test_iter)

9.总结:

  • 借助softmax回归,我们可以训练多分类的模型。
  • 训练softmax回归循环模型与训练线性回归模型非常相似:先读取数据,再定义模型和损失函数,然后使用优化算法训练模型。大多数常见的深度学习模型都有类似的训练过程。

10.简洁实现:

10.1读取数据

import torch
from torch import nn
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

10.2定义模型,初始化参数

跟线性规划模型是一样的,只不过有十个输出

# PyTorch不会隐式地调整输入的形状。因此,
# 我们在线性层前定义了展平层(flatten),来调整网络输入的形状
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);

10.3定义损失函数

loss = nn.CrossEntropyLoss(reduction='none')

10.4定义优化器

我们使用学习率为0.1的小批量随机梯度下降作为优化算法

updater = torch.optim.SGD(net.parameters(), lr=0.1)

10.5训练

调用上面写好的函数,只不过这次的模型,损失函数,优化器都是pytorch内置的

num_epochs = 10
train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

10.6预测:

predict_ch3(net, test_iter)

10.7小结

  • 使用深度学习框架的高级API,我们可以更简洁地实现softmax回归。
  • 从计算的角度来看,实现softmax回归比较复杂。在许多情况下,深度学习框架在这些著名的技巧之外采取了额外的预防措施,来确保数值的稳定性。这使我们避免了在实践中从零开始编写模型时可能遇到的陷阱。
    数,优化器**都是pytorch内置的
num_epochs = 10
train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

10.6预测:

predict_ch3(net, test_iter)

10.7小结

  • 使用深度学习框架的高级API,我们可以更简洁地实现softmax回归。
  • 从计算的角度来看,实现softmax回归比较复杂。在许多情况下,深度学习框架在这些著名的技巧之外采取了额外的预防措施,来确保数值的稳定性。这使我们避免了在实践中从零开始编写模型时可能遇到的陷阱。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/344411.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

javaspringbootmysql开放实验室管理系统03361-计算机毕业设计项目选题推荐(附源码)

摘 要 随着社会的发展&#xff0c;社会的方方面面都在利用信息化时代的优势。互联网的优势和普及使得各种系统的开发成为必需。 本文以实际运用为开发背景&#xff0c;运用软件工程原理和开发方法&#xff0c;它主要是使用动态网页开发技术java作为系统的开发语言&#xff0c;M…

如何把openwrt的ipk软件包安装到ubuntu上

前提&#xff1a;都是arm64的架构的软件包。 下载openwrt的ipk软件包 1. 从https://pkgs.org/ 查找下载软件包&#xff1a; 本文以swconfig软件包为例&#xff0c;下载swconfig和相关的依赖软件包&#xff1a; swconfig_12_aarch64_cortex-a72.ipk libuci20130104_2021-10-2…

eNSP学习——交换机基础

目录 原理概述 实验目的 实验步骤 实验内容 实验拓扑 实验步骤 基础配置 配置交换机双工模式 配置接口速率 思考题 原理概述 交换机之间通过以太网电接口对接时需要协商一些接口参数&#xff0c;比如速率、双工模式等。   接口速率&#xff1a;指的是交换机接口每秒钟传…

【操作系统】实验五 添加内核模块

&#x1f57a;作者&#xff1a; 主页 我的专栏C语言从0到1探秘C数据结构从0到1探秘Linux &#x1f618;欢迎关注&#xff1a;&#x1f44d;点赞&#x1f64c;收藏✍️留言 &#x1f3c7;码字不易&#xff0c;你的&#x1f44d;点赞&#x1f64c;收藏❤️关注对我真的很重要&…

apkpure下载Google Play中APP的APK安装包

比如Google Play上有个应用叫iSmart DV&#xff0c;想下载APK&#xff0c;怎么做呢 打开apkpure(https://apkpure.net/)&#xff0c;输入对应的app名称即可下载

Making Large Language Models Perform Better in Knowledge Graph Completion论文阅读

文章目录 摘要1.问题的提出引出当前研究的不足与问题KGC方法LLM幻觉现象解决方案 2.数据集和模型构建数据集模型方法基线方法任务模型方法基于LLM的KGC的知识前缀适配器知识前缀适配器 与其他结构信息引入方法对比 3.实验结果与分析结果分析&#xff1a;可移植性实验&#xff1…

Kafka-服务端-KafkaController

Broker能够处理来自KafkaController的LeaderAndIsrRequest、StopReplicaRequest、UpdateMetadataRequest等请求。 在Kafka集群的多个Broker中&#xff0c;有一个Broker会被选举为Controller Leader,负责管理整个集群中所有的分区和副本的状态。 例如&#xff1a;当某分区的Le…

第92讲:MySQL主从复制集群故障排查思路汇总

文章目录 1.从库I/O线程处于Connecting状态2.从库I/O线程处于No状态3.从库SQL线程处于No状态 1.从库I/O线程处于Connecting状态 从库的I/O线程处于Connection连接中的状态&#xff0c;一般都是连接不上主库导致&#xff1a; 可能由于网络不通&#xff0c;防火墙的干扰导致从库…

MongoDB系列之一文总结索引

概述 分类 索引的分类&#xff1a; 按照索引包含的字段数量&#xff0c;可分为单键索引&#xff08;单字段索引&#xff09;和组合索引&#xff08;联合索引、复合索引&#xff09;按照索引字段的类型&#xff0c;可以分为主键索引和非主键索引按照索引节点与物理记录的对应…

2024免费mathtype7.4.4安装注册步骤教程

数学建模中对公式的编辑有很高的要求&#xff0c;mathtype是一款专业的数学公式编辑工具&#xff0c;能够帮助用户在各种文档中插入复杂的数学公式和符号。 一 Mathtype 的下载安装 1.1 安装前须知 解压和安装前&#xff0c;需要将电脑的杀毒软件或者防火墙关掉&#xff0c;如…

python222网站实战(SpringBoot+SpringSecurity+MybatisPlus+thymeleaf+layui)-后台管理主页面实现

锋哥原创的SpringbootLayui python222网站实战&#xff1a; python222网站实战课程视频教程&#xff08;SpringBootPython爬虫实战&#xff09; ( 火爆连载更新中... )_哔哩哔哩_bilibilipython222网站实战课程视频教程&#xff08;SpringBootPython爬虫实战&#xff09; ( 火…

vue3跨域请求及一些常用配置

在使用vue3开发的时候&#xff0c;总免不了做一些基础的配置。比如跨域配置&#xff0c;一些常用函数的封装等等。接下来&#xff0c;我就做一些自己在在开发中所运用到一些常用配置。 一、跨域配置 其实&#xff0c;对于跨域配置&#xff0c;我之前的博文中也有说过&#xff0…

Linux的常见指令和基本操作演绎【复习篇章一】

文章目录 前言下载安装 XShellXShell 下的复制粘贴热键操作01.ls指令tree 02.cd指令03.touch指令04.mkdir指令&#xff08;重要&#xff09;&#xff1a;05.rmdir指令 && rm 指令&#xff08;重要&#xff09;06.组合07.man指令&#xff08;重要&#xff09;&#xff1…

Packet Tracer - VLAN 间路由练习

地址分配表 设备 接口 IP 地址 子网掩码 默认网关 R1 G0/0 172.17.25.2 255.255.255.252 不适用 R1 G0/1.10 172.17.10.1 255.255.255.0 不适用 R1 G0/1.20 172.17.20.1 255.255.255.0 不适用 R1 G0/1.30 172.17.30.1 255.255.255.0 不适用 R1 G0/1.…

前景贴纸类特效SDK,面向企业的技术解决方案

随着数字媒体技术的快速发展&#xff0c;视频内容在社交媒体、广告、教育等领域的应用越来越广泛。为了增加视频的吸引力和趣味性&#xff0c;许多企业开始寻求在视频中添加特效和贴纸。美摄科技的前景贴纸类特效SDK为企业提供了一种高效、灵活的解决方案&#xff0c;满足不同的…

R语言VRPM包绘制多种模型的彩色列线图

列线图&#xff0c;又称诺莫图&#xff08;Nomogram&#xff09;&#xff0c;它是建立在回归分析的基础上&#xff0c;使用多个临床指标或者生物属性&#xff0c;然后采用带有分数高低的线段&#xff0c;从而达到设置的目的&#xff1a;基于多个变量的值预测一定的临床结局或者…

生命在于折腾——WeChat机器人的研究和探索

一、前言 2022年&#xff0c;我玩过原神&#xff0c;当时看到了云崽的QQ机器人&#xff0c;很是感兴趣&#xff0c;支持各种插件&#xff0c;查询游戏内角色相关信息&#xff0c;当时我也自己写了几个插件&#xff0c;也看到很多大佬编写的好玩的插件&#xff0c;后来因为QQ不…

微信聊天记录生成词云

目录 前置准备一、获取微信聊天记录&#xff08;一&#xff09;配置MuMu模拟器&#xff08;二&#xff09;微信数据备份与恢复&#xff08;三&#xff09;获取微信聊天记录文件至电脑&#xff08;四&#xff09;获取EnMicroMsg.db的密钥&#xff08;五&#xff09;使用SQLciphe…

详解线性分组码(linear code)

目录 一. 介绍 二. 线性分组码 三. 生成矩阵 四. 对偶编码 五. 校验矩阵 六. 陪集编码 七. 小结 一. 介绍 Low-density parity-check&#xff0c;简称LDPC码&#xff0c;翻译为低密度奇偶校验码。 我们所熟悉的LDPC码就是一个典型的线性分组码&#xff08;linear bloc…

2023年度AI盘点 AIGC|AGI|ChatGPT|人工智能大模型

前言 「作者主页」&#xff1a;雪碧有白泡泡 「个人网站」&#xff1a;雪碧的个人网站 2023年是人工智能大语言模型大爆发的一年&#xff0c;一些概念和英文缩写也在这一年里集中出现&#xff0c;很容易混淆&#xff0c;甚至把人搞懵。 文章目录 前言01 《ChatGPT 驱动软件开…