动手学深度学习——softmax分类

1. 分类问题

回归与分类的区别:

  • 回归可以用于预测多少的问题, 比如"预测房屋被售出价格",它是个单值输出。
  • softmax可以用来预测分类问题,例如"某个图片中是猫、鸡还是狗?",这是一个多值输出,输出个数等于类别个数,输出的第i个值表示预测为第i类别的概率。

两者的区别在于是问多少还是问哪一个?

分类可以用来描述下面两个问题:

  1. 样本属于哪个类别
  2. 样本属于每个类别的概率

比较经典的分类问题有:

  1. MNIST数据集,手写数字识别,有0-9十个类别。
  2. ImageNet数据集,从一百万张图片中识别自然物体,有1000个类别。
  3. kaggle上的恶意软件类别识别。
  4. 区分淘宝商品的评论是正面还是负面评论。

2. 分类编码

由于自然语言表示的类别不方便运算,所以为了计算的需要,有必要对类别进行编码。

对于分类问题,最常用的编码方式为一位有效编码,也称为独热编码(one-hot encoding)。它可以表示为一个向量,长度等于类别数量,向量中只有一个特征为1,其它特征均为0。

这里我们以一个图像分类问题为例来讨论, 假设要预测一张图片是猫、鸡还是狗,那么我们对这三种类别进行一位有效编码的形式如下:

  • (1,0, 0)对应于“猫”
  • (0,1,0)对应于“鸡”
  • (0,0,1)对应于“狗”
  1. 正确类别对应的分量设置为1,其它所有分量均为0.
  2. 类别数量等于分量数量(这里的分量是指向量在具体一个维度上的值)

分类问题对模型的要求:正确类的置信度要远远大于非正确类的置信度,即Oy > Oi。

相比具体每个类别的预测值大小,我们更关心正确类别的预测值是否远大于其它非正确类别的预测值,只有这样,才能表明模型能真正区分出正确类别。

3. 网络架构

与线性回归一样,softmax回归也是一个单层神经网络。

接着上面的例子,假设每次输入是一个2*2的灰度图像,我们可以用一个标量表示每个像素值,每个图像对应四个特征[x1,x2,x3,x4]。

我们可以定义输出向量y=[o1,o2,o3], 其中o1、o2、o3分别表示输入i是猫、鸡、狗的预测值大小。

由于我们有4个特征和3个可能的输出类别, 我们将需要12个标量来表示权重w, 3个标量来表示偏置b。则每个类别的计算可以表示为:
在这里插入图片描述

由于计算每个输出o1、o2和o3取决于 所有输入x1、x2、x3和x4, 所以softmax回归的输出层也是全连接层。
在这里插入图片描述
如同线性回归一样,可以将计算公式简洁表示,o = Wx + b。这是将所有权重放到一个W矩阵中。 对于给定数据样本的特征x, 我们的输出y是由权重W与输入特征x进行矩阵-向量乘法再加上偏置b得到。

4. 输出概率化

对于分类问题,我们希望模型的输出yj可以视为它属于类别j的概率,然后只需要选择具有最大输出值的类别argmax(xj,yj) 作为我们的预测即可, 这样能同时方便人脑理解和算术运算。

例如,如果为猫、鸡和狗的概率分别为0.1、0.8和0.1, 因为0.8概率最大,所以我们预测的类别是2,在我们的例子中代表“鸡”。

这里之所以要进行标准化概率计算,而不直接将预测o作为输出,其原因在于将线性层的输出o视作概率会存在一些问题:

  1. 线性层输出没有限制输出数字的总和为1,不符合概率分布。
  2. 根据输入的不同,线性层的输出是可以为负值的,会影响我们的计算。

要将输出视为概率,我们必须保证以下两点:

  1. 在任何类别上的输出都是非负
  2. 所有类别的预测值总和为1。

而softmax函数则正好能够将未规范化的预测变换为非负数并且总和为1,同时让模型保持可导的性质。它的作法为:

  • 对每个未规范化的预测求幂(指数),这样可以确保输出非负
  • 让每个求幂后的结果除以它们的总和,就能确保最终输出的概率值总和为1
    在这里插入图片描述通过对输出向量o进行softmax运算后,预测值就是一个概率分布。

而真实的值经过独热编码后也符合这个特征,因为它也符合概率的特性:

  1. 非负数:只有0和1两种值;
  2. 和为1:只有一个值为1,其它均为0;

这样就得到两个概率:预测值概率和真实值概率。接下来,就可以比较两个概率来作为损失。

5. 损失计算

交叉熵损失:用来衡量两个概率分布之间的差异。

对于分类问题,我们不关心非正确类别的预测值,只关心对正确类的预测值置信度有多大。

假设模型对每个类别的预测概率分别是0.7、0.2和0.1,实际该样本属于第一个类别。交叉熵损失会根据模型对第一个类别的预测概率和实际概率来计算一个损失值。用数学表示如下:

H(p, q) = -Σ p(x) * log(q(x))
  • p(x)表示实际的概率分布,q(x)表示模型预测的概率分布。
  • 前面加负号的目的是为了保证交叉熵为正值。log(q(x))的值通常是小于0的(小于1时,对数为负数),p(x)是一个概率值,介于0和1之间。
  • 交叉熵越小,表示两个概率分布越接近,模型的预测效果越好。

可以把交叉熵H(P,Q)想象为“主观概率为Q
的观察者在看到根据概率P生成的数据时的意外程度”。 当P=Q时,这种意外程度降到最低。

训练的目的:最小化交叉熵来优化模型的参数,使得模型的预测结果更接近于实际标签。

由于真实值p(x)是一个独热编码向量,只有一项为1,其它项均为0,所以这里的交叉熵又可以简写成:
在这里插入图片描述
所以,对于分类问题来说,我们不关心非正确类别的预测值,只关心正确类别的预测值有多大。

而梯度则是预测概率与真实概率之间的差异,损失函数对输出o求导为:
在这里插入图片描述

softmax回归模型训练的目标:给出任何样本特征,我们可以预测每个输出类别的概率。 通常我们使用预测概率最高的类别作为输出类别。 如果预测与实际类别(标签)一致,则预测是正确的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/585912.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

用Docker 创建并运行一个MySQL容器

可以在DockerHub官网上荡:mysql - Official Image | Docker Hub 指令是:docker pull mysql; 因为文件比较大可能时间比较长,我是跟着黑马的课走的 课程提供的有文件,我就用已有的资源了。 在tmp目录里放入mysql.tar包 然后cd进去 输入指令:docker lo…

java技术栈快速复习05_基础运维(linux,git)

Linux知识总览 linux可以简单的理解成和window一样的操作系统。 Linux和Windows区别 Linux是严格区分大小写的;Linux中一切皆是文件;Linux中文件是没有后缀的,但是他有一些约定俗成的后缀;Windows下的软件一般是无法直接运行的Li…

管理能力学习笔记八:Will-Skill矩阵“盘“团队

如何把握带教中的“度”,才能在把事情做好的基础上,又能使员工获得成长呢? 需要做到 合理授权 & 适当辅导 如何做到? 通过使用 意愿-技能矩阵(Will-Skill Matrix) 辨别不同带教方法的适用情形,"盘"…

【Vue 2.x】学习vue之一基础部分

文章目录 Vue 一基础部分第一章1、git两个分支主分支子分支 使用方法方式1:采用命令的方式操作分支方式2:在idea中使用git的分支 向git远程仓库提交时忽略文件使用git时的一些冲突注意事项 2、Vue问题1:什么是Vue?问题2&#xff1…

网站升级提示:我用react+go重构了网站并记录了部署项目简要步骤

先贴出来地址,这是我网站的地址易查网 可能有细心的小伙伴们已经看到了,原来我的网站是这样式的 妥妥的phph5 改造 前端react框架 前段时间学习了react,正愁无处练手,就有人说我的网站很low,我感觉这正是一个好的机会&#xff…

纯血鸿蒙APP实战开发——发布图片评论

介绍 本示例将通过发布图片评论场景,介绍如何使用startAbilityForResult接口拉起相机拍照,并获取相机返回的数据。 效果图预览 使用说明 通过startAbilityForResult接口拉起相机,拍照后获取图片地址。 实现思路 创建CommentData类&#…

翻译: 什么是ChatGPT 通过图形化的方式来理解 Transformer 架构 深度学习二

在本章中,我们将深入探讨 网络的开始和 结束阶段发生的情况, 我将花大量时间回顾一些重要的背景知识,这些知识是熟悉Transformer的机器学习工程师的基础知识。 如果你已经熟悉背景知识,迫不及待地想了解更多,你可以跳到下一节,重点将放在Transformer的核心部分——注意…

nacos(docker部署)+springboot集成

文章目录 说明零nacos容器部署初始化配置高级配置部分访问权限控制命名空间设置新建配置文件 springboot配置nacos添加依赖编写测试controller 说明 nacos容器部署采用1Panel运维面板,进行部署操作,简化操作注意提前安装好1Panel和配置完成docker镜像加…

深入剖析Tomcat(五) 剖析Servlet容器并实现一个简易Context与Wrapper容器

上一章介绍了Tomcat的默认连接器,后续程序都会使用默认连接器。前面有讲过Catalina容器的两大块内容就是连接器与Servlet容器。不同于第二章的自定义丐版Servlet容器,这一章就来探讨下Catalina中的真正的Servlet容器究竟长啥样。 四种容器 在Catalina中…

Unity涂鸦纹理实现

文章目录 前言实现过程UV坐标和UI坐标对齐修改像素代码 前言 心血来潮实现下场景中提供一张纹理进行涂鸦的功能。 最终实现效果: 实现过程 UV坐标和UI坐标对齐 这里的纹理使用了UGUI的Canvas进行显示,所以这里使用一张RawImage。 因为Unity的视口坐标是以左下角…

【Excel】excel计算相关性系数R、纳什效率系数NSE、Kling-Gupta系数KGE

对于采用的数据: B2:B10958是观测值的所在范围 C2:C10958是模型计算值的所在范围 一、相关系数R是用来衡量两个变量之间线性关系强度和方向的统计量。在水文学和气象学中,常用的相关系数是皮尔逊相关系数(Pearson correlation coefficient&am…

Baidu Comate:“AI +”让软件研发更高效更安全

4月27日,百度副总裁陈洋出席由全国工商联主办的第64届德胜门大讲堂,并发表了《深化大模型技术创新与应用落地,护航大模型产业平稳健康发展》主题演讲。陈洋表示,“人工智能”成为催生新质生产力的重要引擎,对于企业而言…

线上线下收银一体化,新零售POS系统引领连锁门店数字化转型-亿发

在市场竞争日益激烈的背景下,没有哪个商家能够永远屹立不倒。随着互联网技术的快速发展,传统的线下门店面临着来自电商和新零售的新型挑战。实体零售和传统电商都需要进行变革,都需要实现线上线下的融合。 传统零售在客户消费之后就与商家失…

网络基础(1)网络编程套接字UDP

要完成网络编程首先要理解原IP和目的IP,这在上一节已经说明了。 也就是一台主机要进行通信必须要具有原IP和目的IP地址。 端口号 首先要知道进行网络通信的目的是要将信息从A主机送到B主机吗? 很显然不仅仅是。 例如唐僧要去到西天取真经&#xff0…

ES集群分布式查询原理

集群分布式查询 elasticsearch的查询分成两个阶段: scatter phase:分散阶段,coordinating node会把请求分发到每一个分片gather phase:聚集阶段,coordinating node汇总data node的搜索结果,并处理为最终结…

粘合/粘接/胶合聚酰亚胺PI材料使用UV胶,用的UV LED灯的波长范围及功率怎么选择?(三十九)

UV胶固化设备的UV LED波长范围是多少才能与UV胶匹配? UV胶固化设备的UV LED波长范围与UV胶的匹配性主要取决于所使用的UV胶的固化特性。不同的UV胶可能对UV光的波长有不同的要求。因此,要确定与UV胶匹配的UV LED波长范围,首先需要了解所使用的…

Transformer模型详解

Transformer模型实在论文《Attention Is All You Need》里面提出来的,用来生成文本的上下文编码,传统的上下问编码大多数是由RNN来完成的,不过,RNN存在两个缺点: 一、计算是顺序进行的,无法并行化&#xf…

C语言——每日一题(移除链表元素)

一.前言 今天在leetcode刷到了一道关于单链表的题。想着和大家分享一下。废话不多说,让我们开始今天的知识分享吧。 二.正文 1.1题目要求 1.2思路剖析 我们可以创建一个新的单链表,然后通过对原单链表的遍历,将数据不等于val的节点移到新…

【补充】图神经网络前传——图论

本文作为对图神经网络的补充。主要内容是看书。 仅包含Introduction to Graph Theory前五章以及其他相关书籍的相关内容(如果后续在实践中发现前五章不够,会补上剩余内容) 引入 什么是图? 如上图所示的路线图和电路图都可以使用…

Flink checkpoint 源码分析- Checkpoint barrier 传递源码分析

背景 在上一篇的博客里,大致介绍了flink checkpoint中的触发的大体流程,现在介绍一下触发之后下游的算子是如何做snapshot。 上一篇的文章: Flink checkpoint 源码分析- Flink Checkpoint 触发流程分析-CSDN博客 代码分析 1. 在SubtaskCheckpointCoo…