分类 classificaton

b10049398220446a9618e92493250956.png

1)什么是分类?

在此之前,我们一直使用的都是回归任务进行学习;这里我们将进一步学习什么是分类,我们先从训练模型的角度来看看二者的区别。

82a92f15506b468a9e2e416f68dd3d04.png

 对于回归来说,它所作的是对模型输入相应的特征,然后模型给出相应的输出,需要让模型的输出和实际的标签值越接近越好;而对于分类来说,同样的是将相应的特征输入模型,模型输出相应的类型。

1.1问题一:

分类模型的输出不像回归模型一样输出是一个特定的数值,所以对于分类模型来说我们可以根据将不同的类别使用不同的数值来代替。

例如:

class1 --- 1

class2 --- 2

class3 --- 3

这样当模型认为输入的特征和class1更符合的话就会输出1,class3更符合的话就会输出3。但是又会出现新的问题,采用以上编码方式是否会导致模型认为class1和class2更相似;class1和class3更加的不同呢?因为class1和class2的距离上更近,比如网络的输出值是1.49,其实就表示很大概率是class1或者class2,这样其实就隐含表示class1或者class2更近一点。

1.2问题二:

对于有些分类任务来说,采用以上编码方式是不会具有问题的。比如使用升高和体重来预测小朋友的年级,例如:一年级 --- 1、二年级 --- 2、三年级 --- 3。这样是没问题的,因为一年级和二年级这两个类别来说是相对更近的, 一年级和三年级这两个类别来说是相对更远的。但是对于有的分类任务来说,再编码的时候就会产生这样的问题,于是在编码的时候采用one-hot vector的方式进行编码。

例如:

 eq?class1%3D%5Cleft%20%5B%20%5Cbegin%7Bmatrix%7D%201%20%5C%5C%200%20%5C%5C%200%20%5C%5C%20%5Cend%7Bmatrix%7D%20%5Cright%20%5D

 eq?class2%3D%5Cleft%20%5B%20%5Cbegin%7Bmatrix%7D%200%20%5C%5C%201%20%5C%5C%200%20%5C%5C%20%5Cend%7Bmatrix%7D%20%5Cright%20%5D

 eq?class3%3D%5Cleft%20%5B%20%5Cbegin%7Bmatrix%7D%200%20%5C%5C%200%20%5C%5C%201%20%5C%5C%20%5Cend%7Bmatrix%7D%20%5Cright%20%5D

采用这种编码方式的话,就不会出现以上这种问题了,这样的话,他们之间的距离都是一样的啦。

1.3问题三:

在回归问题的时候,我们构建的神经网络只能输出一个数值,但是对于分类问题来说,要是采用one-hot的编码方式对类别进行编码,那么对于网络的输出就不能只有一个,所以网络的结构也必须改变。

所以如下图所示,只需要多输出两个就行。

58aaed9f9f3b4fd5bf2b507ae315ddf7.png

至此对于一个分类任务的模型我们已经构架完成了,只不过是对于回归问题进行了一些小小的改进,但是其实对于分类问题来说,还有一些不太一样的问题。


 我们来最终对比一下回归任务和分类任务,分类任务最终的输出要和实际的标签纸越接近越好;对于分类来说,其最终的输出也应该与实际的类别标签纸越接近越好。

fbfc23417b5245f09ddfac263253aa7d.png

可是实际在最后一步输出的过程中,即最后网络输出了y之后,对于分类问题来说会再加上一个softmax使其输出 eq?y%5E%7B%27%7D ,然后希望的是  eq?y%5E%7B%27%7D和实际的标签值越接近越好。

2)为什要加softmax?

简单理解即使,其实对于网络的输出的三个值是可以为任何值,但是在最终的标签我们是希望在零到一之间的,所以通过softmax就可以将网络输出的值规格化到零到一之间。

softmax的工作过程:

  • 对其所以的输入的y值(网络最后的输出)取exp,也就是分别计算e^{y_i}
  • e^{y_i}进行求和
  • 用分别计算得到的e^{y_i}比上所有e^{y_i}的和就得到了每个数值最后被规格化后的数值。

例子:

输入softmax的三个数值是3、1、-3。 

  • 取exp。得到e^3=20,e^1=2.7,e^{-3}=0.05
  • 求和。20 + 2.7 + 0.05 = 22.75
  • 归一化。\frac{0.05}{22.75}\approx 0,\frac{2.7}{22.75}\approx 0.12,\frac{20}{22.75}\approx 0.88

其实在实际的分类问题当中,当分类任务是两个类别的时候,我们更常用的是使用sigmoid函数来进行最后的归一化;但是其实也可以使用softmax,他们二者在二分类的使用上无本质的区别。

2.1sigmoid函数:

sigmoid函数原型表达式如下:

                                                              sigmoid = \frac{1}{1+e{-x}}

以输入x_1,x_2为例子。

  • sigmoid:output(x_1)=\frac{1}{1+e{-x_1}}
  • softmax:output(x_1)=\frac{e^{x_1}}{e^{x_1}+e^{x_2}}=\frac{1}{1+e^{x_2-x_1}}=\frac{1}{1+e^{-(x_1-x_2)}}
  • 对于二分类来说可以进一步将softmax写成:softmax =\frac{1}{1+e^{-z_1}}

由此可得对于二分类问题来说,其二者的公式无本质区别,即理论上来说,二者是没有任何区别的。

sigmoid和softmax函数的本质区别: 

sigmoid函数用于多标签分类问题,选取多个标签作为正确答案,它是将任意值归一化为[0-1]之间,并不是不同概率之间的相互关联

Softmax函数用于多分类问题,即从多个分类中选取一个正确答案。Softmax综合了所有输出值的归一化,因此得到的是不同概率之间的相互关联。 

转载来源:深度学习随笔——Softmax函数与Sigmoid函数的区别与联系 - 知乎

Sigmoid函数针对两点分布提出。神经网络的输出经过它的转换,可以将数值压缩到(0,1)之间,得到的结果可以理解成分类成目标类别的概率P,而不分类到该类别的概率是(1 - P),这也是典型的两点分布的形式。

Softmax函数本身针对多项分布提出,当类别数是2时,它退化为二项分布。而它和Sigmoid函数真正的区别就在——二项分布包含两个分类类别(姑且分别称为A和B),而两点分布其实是针对一个类别的概率分布,其对应的那个类别的分布直接由1-P得出。

简单点理解就是,Sigmoid函数,我们可以当作成它是对一个类别的“建模”,将该类别建模完成,另一个相对的类别就直接通过1减去得到。而softmax函数,是对两个类别建模,同样的,得到两个类别的概率之和是1。

3)分类问题的损失函数

分类问题的损失函数同样的根据距离来计算,可以和之前的回归问题一样使用MSE误差来计算损失函数。但是更常用是使用下图中的 Cross- entropy来计算误差

为什么选择Cross- entropy而不是Mean Square error

其实我们实际在使用pytorch进行构建网络实现分类的任务的时候,我们会发现找不到softmax,这是因为,我们再构建网络后,使用Cross-entropy来计算误差的时候,其会自动再网络的最后一层加上softmax,在 pytorch中Cross-entropy和softmax被内置成为了一个整体。

例子:

在某个训练过程中网络输出的三个数值分别是y_1,y_2,y_3,然后再经softmax处理得到了最后的输出结果,我们分别使用Cross- entropy和Mean Square error进行求损失,然后根据损失计算下一步该往哪里走。 

如上图所示得到了Cross- entropy和Mean Square error 的损失图,在图的右下角都是损失最小的地方,即y_1的值变大, y_2的值变小就可以使得损失值变小;在图的左上角都是损失最大的地方,即y_1的值变小, y_2的值变大会使得损失值变大。 

所以当使用Mean Square error 的时候,很有可能会到达梯度不变的点,导致训练不下去(一般的训练过程训练不下去),但是对于 Cross- entropy 来说,确是有梯度了,可以进行下去。 

所以可以得出对于损失函数的设计都会影响最后的一个训练优化过程。

4)pytorch实现分类代码

完整代码:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# 定义超参数
batch_size = 64
learning_rate = 0.001
num_epochs = 10

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 下载并加载 CIFAR-10 数据集
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# 定义 CNN 模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 实例化模型、定义损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 训练模型
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:    # 每100个mini-batch打印一次
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}], Loss: {running_loss / 100:.4f}')
            running_loss = 0.0

print('训练完成')

# 测试模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'测试准确率: {100 * correct / total:.2f}%')

接下来将一步一步剖析代码,我们先看网络架构。

定义SimpleCNN类,让其继承prtorch中的nn.Module,并且

# 定义 CNN 模型
class SimpleCNN(nn.Module): #SimpleCNN 继承自 nn.Module
    def __init__(self): #构造函数(初始化方法)
        super(SimpleCNN, self).__init__()  # 这一行代码调用了父类(nn.Module)的构造函数,确保在实例化SimpleCNN类时,也会执行父类的初始化操作。       

第一层网络:

self.conv1 = nn.Conv2d(3, 32, 3, padding=1)

CIFAR-10图像都是3通道(RGB)的,尺寸为32x32像素。因此每个图像的数据形状是3x32x32,所以对于32个卷积核

  • 输入通道数:3,对应于 RGB 图像的三个通道。
  • 输出通道数:32,卷积核的数量,生成32个特征图。
  • 卷积核大小:3x3。
  • Padding:1,保持输出和输入的宽高相同。

CIFAR-10图像为3x32x32,输入通道是3,其实就对应着3个特征图,即RGB三个特征图。经过32个卷积核采样后,每个卷积核都能得到一个特征图,也就是32个特征图,但是在不加padding的时候得到的特征图都是4x4的,加了padding之后,就能保证最后卷积后得到额特征是还是6x6的。

第二层网络:

self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
  • 输入通道数:32,来自前一层的输出。
  • 输出通道数:64。
  • 卷积核大小:3x3。
  • Padding:1。

 第三层网络——池化层

  • 池化大小:2x2。
  • 步幅:2,减少特征图尺寸,通常用于降采样。

self.pool = nn.MaxPool2d(2, 2)

全连接层

self.fc1 = nn.Linear(64 * 8 * 8, 512)

  • 输入大小:64 * 8 * 8,来自池化层展平后的特征数。
  • 输出大小:512,隐藏层的神经元个数。

self.fc2 = nn.Linear(512, 10)

  • 输入大小:512,来自前一层的输出。
  • 输出大小:10,对应于 CIFAR-10 数据集的10个类别。

self.relu = nn.ReLU()

ReLU:一种常用的激活函数,引入非线性,计算简单,能有效缓解梯度消失问题。

前向传播:

def forward(self, x):
    x = self.pool(self.relu(self.conv1(x)))
    x = self.pool(self.relu(self.conv2(x)))
    x = x.view(-1, 64 * 8 * 8)
    x = self.relu(self.fc1(x))
    x = self.fc2(x)
    return x

  • 卷积 + ReLU + 池化
    • 首先对输入应用卷积层和 ReLU 激活函数,然后进行最大池化。
  • 展平
    • 使用 view 方法将特征图展平成一维,便于连接到全连接层。
  • 全连接层 + ReLU
    • 输入到第一个全连接层并应用 ReLU 激活。
  • 输出层
    • 最后通过第二个全连接层得到输出。

这个简单的 CNN 通过卷积、激活、池化和全连接层的组合来提取图像特征并进行分类。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/909420.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

免费且强大的PDF转换工具——PDFgear

前言 PDFgear是一款不可或缺的PDF文件处理工具,凭借其强大的功能和多样的特点,它能帮助用户更快速、高效地编辑和处理PDF文件,显著提升工作效率 通过网盘分享的文件:pdf转换工具 链接: https://pan.baidu.com/s/1ap37H9tP6brqTgf…

sql中判断一个字段是否包含一个数据的方法有哪些?

目录 一、like模糊查询(like关键字) 二、locate(字符串,字段名) 三、 instr(字段名,字符串) 四、regexp_extract(subject, pattern, index) 以下是几种方法,使用hive来举例演示一下: -- 举例:创建一个…

STM32 + CubeMX + 硬件SPI + W5500 +TcpClient

这篇文章记录一下STM32W5500TCP_Client的调试过程,实现TCP客户端数据的接收与发送。 目录 一、W5500模块介绍二、Stm32CubeMx配置三、Keil代码编写1、添加W5500驱动代码到工程(添加方法不赘述,驱动代码可以在官网找)2、在工程中增…

C++笔试题之实现一个定时器

一.定时器(timer)的需求 1.执行定时任务的时,主线程不阻塞,所以timer必须至少持有一个线程用于执行定时任务 2.考虑到timer线程资源的合理利用,一个timer需要能够管理多个定时任务,所以timer要支持增删任务…

Halcon resistor.hedv 使用多个对焦级别提取深度

depth_from_focus * Extract depth using multiple focus levels * 使用多个对焦级别提取深度 Names : [] * 初始化一个空数组,用于存储图像名称 dev_close_window () * 关闭当前打开的图像窗口 for i : 1 to 10 by 1 * 循环开始,从1到10 …

区块链技术与应用-PKU 学习笔记

课程地址 资料: ETH-Security 区块链学习记录_比特币 BTC 密码学原理 比特币,又称加密货币(crypto-currency),它主要利用了密码学中的哈希函数(cryptographic hash function)的抗碰撞特性(collision resistance)和单向散列特性(hiding) …

Spark 的Standalone集群环境安装与测试

目录 一、Standalone 集群环境安装 (一)理解 Standalone 集群架构 (二)Standalone 集群部署 二、打开监控界面 (一)master监控界面 (二)日志服务监控界面 三、集群的测试 &a…

VLAN 高级技术 ——QinQ的配置

QinQ的概述: QinQ技术是一种扩展虚拟局域网(VLAN)数量空间的技术,通过在802.1Q标签报文的基础上再增加一层802.1Q的Tag来实现。以下是对QinQ技术的详细概述: QinQ技术的定义与背景 定义:QinQ&#xff08…

伍光和《自然地理学》电子书(含考研真题、课后习题、章节题库、模拟试题)

《自然地理学》(第4版)由伍光和、王乃昂、胡双熙、田连恕、张建明合著,于2018年11月出版。作为普通高等教育“十一五”国家级规划教材,本书不仅适用于高校地球科学各专业的基础课程,还可供环境、生态等有关科研、教学人…

迅为RK3588开发板Android多屏显示之多屏同显和多屏异显

迅为RK3588开发板是一款低功耗、高性能的处理器,适用于基于arm的PC和Edge计算设备、个人移动互联网设备等数字多媒体应用,RK3588支持8K视频编解码,内置GPU可以完全兼容OpenGLES 1.1、2.0和3.2。RK3588引入了新一代完全基于硬件的最大4800万像…

登录功能设计(php+mysql)

一 登录功能 1. 创建一个登录页面(login.php),包含一个表单,用户输入用户名和密码。 2. 在表单的提交事件中,使用PHP代码处理用户输入的用户名和密码。 3. 首先,连接MySQL数据库。然后&a…

vue--vueCLI

何为CLI ■ CLI是Command-Line Interface,俗称脚手架. ■ 使用Vue.js开发大型应用时,我们需要考虑代码目录结构、项目结构和部署、热加载、代码单元测试等事情。(vue 脚手架的作用), 而通过vue-cli即可:vue-cli 可以…

软件测试工程师面试整理 —— 编程与自动化!

在软件测试领域,编程与自动化是提升测试效率、覆盖率和可靠性的关键因素。掌握编程技术和自动化测试框架,能够帮助测试人员有效地执行大量重复性测试任务,并迅速反馈软件的质量状况。以下是编程与自动化在测试中的主要应用及相关技术介绍&…

宝顶白芽,慢生活的味觉盛宴

在快节奏的生活中,人们愈发向往那种悠然自得、返璞归真的生活方式。白茶,以其独特的韵味和清雅的风格,成为了现代人追求心灵宁静与生活品质的象征。而在众多白茶之中,竹叶青茶业出品的宝顶白芽以其甘甜醇爽的特质,成为…

安卓APP渗透安全测试

1.移动安全测试点分析 1.1主要测试 客户端 数据传输 服务端 l反编译 l二次打包 l组件安全 lWebview漏洞 l数据安全 l界面劫持 l数据备份风险 lDebug调试风险 l安全策略 l数据窃听 l中间人攻击 l信息泄露 l任意修改数据包 lSQL注入 l上传漏洞 l暴力破解 l逻辑漏洞 lXSS…

CentOS 7 安装 ntp,自动校准系统时间

1、安装 ntp yum install ntp 安装好后,ntp 会自动注册成为服务,服务名称为 ntpd 2、查看当前 ntpd 服务的状态 systemctl status ntpd 3、启动 ntpd 服务、查看 ntpd 服务的状态 systemctl start ntpdsystemctl status ntpd 4、设置 ntpd 服务开机启…

ESP-HaloPanel:用 ESP32-C2 打造超低成本智能家居面板

项目简介 在生活品质日益提升的今天,智能家居系统已经走进了千家万户,并逐渐成为现代生活的一部份。与此同时,一款设计精致、体积轻盈、操作简便的全屋智能家居控制面板,已经成为众多家庭的新宠。这种高效、直观的智能化的解决方…

如何用ChatGPT结合Python处理遥感数据

在科技飞速发展的时代,遥感数据的精准分析已经成为推动各行业智能决策的关键工具。从无人机监测农田到卫星数据支持气候研究,空天地遥感数据正以前所未有的方式为科研和商业带来深刻变革。然而,对于许多专业人士而言,如何高效地处…

TCP Analysis Flags 之 TCP Keep-Alive

前言 默认情况下,Wireshark 的 TCP 解析器会跟踪每个 TCP 会话的状态,并在检测到问题或潜在问题时提供额外的信息。在第一次打开捕获文件时,会对每个 TCP 数据包进行一次分析,数据包按照它们在数据包列表中出现的顺序进行处理。可…