深度学习(五)softmax 回归之:分类算法介绍,如何加载 Fashion-MINIST 数据集

Softmax 回归

基本原理

回归和分类,是两种深度学习常用方法。回归是对连续的预测(比如我预测根据过去开奖列表下次双色球号),分类是预测离散的类别(手写语音识别,图片识别)。

1699720169075

现在我们已经对回归的处理有一定的理解了,如何过渡到分类呢?

假设我们有 n 类,首先我们要编码这些类让他们变成数据。所有类变成一个列向量。

y = [ y 1 , y 2 , . . . y n ] T y=[y_1,y_2,...y_n]^T y=[y1,y2,...yn]T

有一个数据属于第 i 类,那么他的列向量就是:

y = [ 0 , 0 , . . . , 1 , . . . , 0 , 0 ] T y=[0,0,...,1,...,0,0]^T y=[0,0,...,1,...,0,0]T

也就是只有他所在的那个类的元素=1.

可以用均方损失训练,通过概率判断最终选用哪一个。

Softmax 回归就是一种分类方式(回归问题在多分类上的推广)。首先确定输入特征数和输出类别数。比如上图中我们有4个特征和3个可能的类别,那么计算各自概率的公式包括3个线性回归:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

可以看出 Softmax 是全连接的单层神经网络。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们让所有输出结果归一化后,从中选择出最大可能的,置信度最高的分类结果。

image-20231112100423488

采用 e 的指数可以让值全变为非负。

用真实的概率向量-我们预测得到的概率向量就是损失。真实值就是只有一个1的列向量。

交叉熵损失:

image-20231112101259670

可见**分类问题,我们不关心对非正确的预测值,只关心正确预测值是否足够大。**因为正确值是只有一个元素为1的列向量。

常用的损失函数

L2 Loss:均方损失。

image-20231112101555142

L1 Loss:绝对值损失。

image-20231112101829868

L2 梯度是一条倾斜直线,对于梯度下降算法等更为合适;L1 是一个跳变,梯度要么 -1 要么 1. 如图是 L1 L2 的梯度。

image-20231112102551104

我们可以结合两者,得到一个新的损失函数(鲁棒损失 Huber Robust):

KaTeX parse error: {equation} can be used only in display mode.

image-20231112102721527

图像分类数据集

MINIST 是一个常用图像分类数据集,但是过于简单。后来的 upgrade 版叫 Fashion-MINIST(服装分类).

首先,我们研究研究怎么加载训练数据集,以便后面测试算法用。

# 导包
%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

d2l.use_svg_display()

d2l.use_svg_display()

# 下载数据集并读取到内存
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)		# 训练数据集
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)	# 测试数据集用于评估性能

# 定义函数用于返回对应索引的标签
def get_fashion_mnist_labels(labels):  #@save
    """返回Fashion-MNIST数据集的文本标签"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

# 图像可视化,让结果看着更直观,比如下面那个绿色图的样子
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    """绘制图像列表"""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # 图片张量
            ax.imshow(img.numpy())
        else:
            # PIL图片
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

# 我们先读一点数据集看看啥样的
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

1699980345931

# 通过内置数据加载器读取一批量数据,自动随机打乱读取,不需要我们自己定义
batch_size = 256

def get_dataloader_workers():  #@save
    """使用4个进程来读取数据"""
    return 4

train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=get_dataloader_workers())

测量以上用时基本2-3s。

总结整合以上数据读取过程,代码如下:

def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """下载Fashion-MNIST数据集,然后将其加载到内存中"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))

加载图像还可以调整其大小。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/158368.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

redis运维(九)字符串(二)字符串过期时间

一 字符串过期时间 细节点: 注意命令的入参和返回值 ① 再谈过期时间 说明: 设置key的同时并且设置过期时间,是一个原子操作 ② ttl 检查过期时间 ③ persist 删除过期时间 ④ redis 删除过期key的机制 ⑤ 惰性删除 惰性理解:让过期…

自动驾驶-BEV感知综述

BEV感知综述 随着自动驾驶传感器配置多模态化、多源化,将多源信息在unified View下表达变得更加关键。BEV视角下构建的local map对于多源信息融合及理解更加直观简洁,同时对于后续规划控制模块任务的开展也更为方便。BEV感知的核心问题是: …

[Linux版本Debian系统]安装cuda 和对应的cudnn以cuda 12.0为例

写在前面 先检查自己有没有安装使用wget的命令,没有的话输入下面命令安装: apt-get install wget -y查看gcc的安装 sudo apt install gcc #安装gcc gcc --version #查看gcc是否安装成功 #若上述命令不成功使用下面的命令尝试之后再执行上面…

长短期记忆(LSTM)与RNN的比较:突破性的序列训练技术

长短期记忆(Long short-term memory, LSTM)是一种特殊的RNN,主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。简单来说,就是相比普通的RNN,LSTM能够在更长的序列中有更好的表现。 Why LSTM提出的动机是为了解…

【powershell】入门和示例

▒ 目录 ▒ 🛫 导读开发环境 1️⃣ 简介用途IDE解决此系统上禁止运行脚本 2️⃣ 语法3️⃣ 实战数据库备份执行循环拷贝文件夹 🛬 文章小结📖 参考资料 🛫 导读 开发环境 版本号描述文章日期2023-11-17操作系统Win10 - 22H21904…

Java JVM虚拟机

加载字节码文件.class 1字节一般为8位 字节码结构: 第一部分 4字节 cafebaby 第二部分 版本号 00 00 00 32, 第三部分 常量数量 count 第四部分常量池 常量类型表示: 继承关系改变 1.1以后 后面是属性方法 等参数 通过javap 反编译class ,javap xx.class javap -c xxx.…

【Redis】springboot整合redis(模拟短信注册)

要保证redis的服务器处于打开状态 上一篇: 基于session的模拟短信注册 https://blog.csdn.net/m0_67930426/article/details/134420531 整个流程是,前端点击获取验证码这个按钮,后端拿到这个请求,通过RandomUtil 工具类的方法生…

.nc格式文件的显示及特殊裁剪方式

最近我们遇到一个nc格式的文件,需要将它做成报告插图,bing搜索一番以后,了解到nc的全名为NetCDF(network Common Data Form),是一种网络通用数据格式,广泛用于大气科学、水文、海洋学、环境模拟、地球物理等诸多领域。…

【超好用的工具库】hutool-all工具库的基本使用

简介(可不看): hutool-all是一个Java工具库,提供了许多实用的工具类和方法,用于简化Java开发过程中的常见任务。它包含了各种模块,涵盖了字符串操作、日期时间处理、加密解密、文件操作、网络通信、图片处…

指针传2(续集)

近期的天气是真的冷啊,老铁们一定要照顾好自己呀,注意防寒保暖,没有你们我怎么活啊! 上次的指针2的末尾,给大家分享了两个有趣的代码,今天就先来讲一讲那两个代码: 两个有趣的代码:…

Logrotate日志切割工具的应用与配置

Logrotate日志切割工具的应用与配置,以下是公司生产环境亲测,跳了不少的坑,最后已经部署到生产了,可放心使用 简介 Logrotate是一个在Unix和类Unix系统(如Linux)上用于管理日志文件的实用程序。它可以帮助…

官宣定档 | 3大主题论坛重磅行业颁奖,CGT Asia 2024第五届亚洲细胞与基因治疗创新峰会特色亮点抢先看

细胞与基因治疗代表着未来医学发展的趋势,随着技术的不断更新与发展与支持政策的持续推出,细胞与基因治疗产业的希望被无限扩大,自第一批细胞治疗与基因治疗产品上市到如今,行业已经进入快车道,步入高速发展期&#xf…

如何确保消息不会丢失

本篇文章大家还可以通过浏览我的博客阅读。如何确保消息不会丢失 - 胤凯 (oyto.github.io)很多人刚开始接触消息队列的时候&#xff0c;最经常遇到的一个问题就是丢消息了。<!--more-->对于大部分业务来说&#xff0c;丢消息意味着丢数据&#xff0c;是完全无法接受的。 …

C语言--给定一行字符串,获取其中最长单词【图文详解】

一.问题描述 给定一行字符串,获取其中最长单词。 比如&#xff1a;给定一行字符串&#xff1a; hello wo shi xiao xiao su 输出&#xff1a;hello 二.题目分析 “打擂台算法”&#xff0c;具体内容小伙伴们可以参考前面的内容。 三.代码实现 char* MaxWord(const char* str)…

CMakeLists.txt基础指令与cmake-gui生成VS项目的步骤

简介 本博客主要介绍cmake的基本指令&#xff0c;同时&#xff0c;很多使用Visual Studio小白从Gitbub下载项目源码后&#xff0c;看到CMakeLists.txt&#xff0c;不知道如何使用Visual Studio编译源码&#xff1b;针对以上问题&#xff0c;做一下简单操作与解释&#xff0c;方…

c语言-数据结构-堆

目录 一、二叉树 1、二叉树的概念 2、完全二叉树和满二叉树 3、完全二叉树的顺序存储 二、堆 2、堆的概念与结构 3、堆的创建及初始化 4、堆的插入&#xff08;小堆&#xff09; 5、堆的删除 6、显示堆顶元素 7、显示堆里的元素个数 8、测试堆的各个功能 9、 实现堆…

零代码编程:用ChatGPT批量转换多个视频文件夹到音频并自动移动文件夹

有很多个视频文件夹&#xff1a; 要全部转成音频&#xff0c;然后复制到另一个文件夹。 在ChatGPT中输入如下提示词&#xff1a; 你是一个Python编程专家&#xff0c;要完成一个批量将Mp4视频转为Mp3音频的任务&#xff0c;具体步骤如下&#xff1a; 打开文件夹&#xff1a;…

机器学习 天气识别

>- **&#x1f368; 本文为[&#x1f517;365天深度学习训练营](https://mp.weixin.qq.com/s/Nb93582M_5usednAKp_Jtw) 中的学习记录博客** >- **&#x1f356; 原作者&#xff1a;[K同学啊 | 接辅导、项目定制](https://mtyjkh.blog.csdn.net/)** >- **&#x1f680;…

matlab层次分析法模型及相关语言基础

发现更多计算机知识&#xff0c;欢迎访问Cr不是铬的个人网站 代码放在最后面! 这篇文章是学习层次分析法模型的笔记。 1.什么时候用层次分析法 层次分析法是建模比赛中最基础的模型之一&#xff0c;其主要用于解决评价类问题&#xff08;例如&#xff1a;选择哪种方案最好、…

Mysql数据库 16.SQL语言 数据库事务

一、数据库事务 数据库事务介绍——要么全部成功要么全部失败 我们把完成特定的业务的多个数据库DML操作步骤称之为一个事务 事务——就是完成同一个业务的多个DML操作 例&#xff1a; 数据库事务四大特性 原子性&#xff08;A&#xff09;&#xff1a;一个事务中的多个D…