TensorFlow深度学习实战(7)——分类任务详解

TensorFlow深度学习实战(7)——分类任务详解

    • 0. 前言
    • 1. 分类任务
      • 1.1 分类任务简介
      • 1.2 分类与回归的区别
    • 2. 逻辑回归
    • 3. 使用 TensorFlow 实现逻辑回归
    • 小结
    • 系列链接

0. 前言

分类任务 (Classification Task) 是机器学习中的一种监督学习问题,其目的是将输入数据(特征向量)映射到离散的类别标签。广泛应用于如文本分类、图像识别、垃圾邮件检测、医学诊断等多种领域。

1. 分类任务

1.1 分类任务简介

分类任务的目标是通过训练数据学习一个模型,使得对于新的输入数据能够预测其所属的类别。输入数据是模型的自变量,通常是特征向量 x = [ x 1 , x 2 , … , x n ] ) {x} = [x_1, x_2, \dots, x_n]) x=[x1,x2,,xn]),其中 n n n 是特征的维度,每个特征可能是连续值(如温度、年龄)或离散值(如颜色、性别)。输出是一个类别标签,表示每个输入数据点的所属类别,对于二分类任务,输出标签通常为 01;而对于多分类任务,标签的数量可以是多个类别,例如 0123 等。
根据类别的数量不同,可以将分类任务归为不同类型:

  • 二分类 (Binary Classification):输出只有两个类别,例如“是”与“否”
  • 多分类 (Multiclass Classification):输出包含多个类别标签,适用于每个样本属于多个可能类别中的一个的任务,例如“猫”、“狗”、“狮子”、“大象”等
  • 多标签分类 (Multilabel Classification):与传统的单一类别分类不同,每个样本可以同时属于多个类别

1.2 分类与回归的区别

回归和分类任务之间的区别:

  • 在分类任务中,数据被分成不同的类别,而在回归中,目标是根据给定的数据得到一个连续值。例如,识别手写数字的任务属于分类任务,所有的手写数字都属于 09 之间的某个数字;而根据不同的输入变量预测房屋价格则属于回归任务
  • 在分类任务中,模型的目标是找到分隔不同类别的决策边界;而在回归任务中,模型的目标是逼近一个适合输入输出关系的函数。

分类和回归任务的不同之处如下图所示。在分类中,我们需要找到分隔类别的线(或平面,或超平面)。在回归中,目标是找到一条(或一个平面,或一个超平面)拟合给定输入与输出关系的线。

分类与回归

2. 逻辑回归

逻辑回归 (Logistic regression) 用于确定事件发生的概率。通常,事件表示为分类的因变量。事件发生的概率使用 sigmoid (或logit )函数表示:
P ^ ( Y ^ = 1 ∣ X = x ) = 1 1 + e − ( b + W T x ) \hat P(\hat Y=1|X=x)=\frac 1{1+e^{-(b+W^Tx)}} P^(Y^=1∣X=x)=1+e(b+WTx)1
目标是估计权重 W = { w 1 , w 2 , . . . , w n } W=\{w_1,w_2,...,w_n\} W={w1,w2,...,wn} 和偏置项 b b b。在逻辑回归中,系数可以使用最大似然估计或随机梯度下降来估计。如果 p p p 是输入数据样本的总数,损失通常定义为交叉熵项:
l o s s = ∑ i = 1 p Y i l o g ( Y ^ i ) + ( 1 − Y i ) l o g ( 1 − Y ^ i ) loss=\sum_{i=1}^pY_ilog(\hat Y_i)+(1-Y_i)log(1-\hat Y_i) loss=i=1pYilog(Y^i)+(1Yi)log(1Y^i)
逻辑回归用于分类问题。例如,在分析医疗数据时,我们可以使用逻辑回归来分类一个人是否患有癌症。如果输出的分类变量具有两个或多个,可以使用多分类逻辑回归。对于多分类逻辑回归,交叉熵损失函数可以改写为:
l o s s = ∑ i = 1 p ∑ j = 1 k Y i j l o g Y ^ i j loss=\sum_{i=1}^p\sum_{j=1}^kY_{ij}log\hat Y_{ij} loss=i=1pj=1kYijlogY^ij
其中 k k k 是类别总数。了解了逻辑回归的原理后,接下来,将其应用于具体实践中。

3. 使用 TensorFlow 实现逻辑回归

接下来,使用 TensorFlow 实现逻辑回归,对 MNIST 手写数字进行分类。MNIST 数据集包含手写数字的图像,每个图像都有一个标签值(介于 09 之间)标注图像中的数字值。因此,属于多类别分类问题。
为了实现逻辑回归,构建一个仅包含一个全连接层的模型。输出中的每个类别由一个神经元表示,由于我们有 10 个类别,输出层的神经元数为 10。逻辑回归中使用的概率函数类似于 sigmoid 激活函数,因此,模型使用 sigmoid 激活。接下来,构建模型。

(1) 首先,导入所需库。由于全连接层接收的输入为一维数据,因此使用 Flatten 层,用于将 MNIST 数据集中的 28 x 28 二维输入图像调整为一个包含 784 个元素的一维数组:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import tensorflow.keras as K
from tensorflow.keras.layers import Dense, Flatten

(2)tensorflow.keras 数据集中获取 MNIST 输入数据:

((train_data, train_labels),(test_data, test_labels)) = tf.keras.datasets.mnist.load_data()

(3) 对数据进行预处理。对图像进行归一化,MNIST 数据集的图像是灰度图像,每个像素的强度值介于 0255 之间,将其除以 255,使数值范围在 01 之间:

train_data = train_data/np.float32(255)
train_labels = train_labels.astype(np.int32)  
test_data = test_data/np.float32(255)
test_labels = test_labels.astype(np.int32)

(4) 定义模型,模型只有一个具有 10 个神经元的 Dense 层,输入大小为 784,从模型摘要的输出中可以看到,只有 Dense 层具有可训练的参数:

model = K.Sequential([
                      # Dense(64,  activation='relu'),
                      # Dense(32,  activation='relu'),
                      Flatten(input_shape=(28, 28)),
                      Dense(10, activation='sigmoid')
])
print(model.summary())

模型架构

(5) 因为标签是整数值,因此使用 SparseCategoricalCrossentropy 损失函数,设置 logits 参数为 True。选择 Adam 优化器,此外,定义准确率作为在训练过程中需要记录的指标。模型训练 50epochs,使用 80:20 的比例拆分训练-验证集:

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
history = model.fit(x=train_data,y=train_labels, epochs=50, verbose=1, validation_split=0.2)

(6) 绘制损失曲线观察模型性能表现。可以看到随着 epoch 的增加,训练损失降低的同时,验证损失逐渐增加,因此模型出现过拟合,可以通过添加隐藏层来改善模型性能:

plt.plot(history.history['loss'], label='loss')
plt.plot(history.history['val_loss'], label='val_loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

训练过程监测

(7) 为了更好地理解结果,构建两个实用函数,用于可视化手写数字以及模型输出的 10 个神经元的概率:

predictions = model.predict(test_data)

def plot_image(i, predictions_array, true_label, img):
    true_label, img = true_label[i], img[i]
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])

    plt.imshow(img, cmap=plt.cm.binary)

    predicted_label = np.argmax(predictions_array)
    if predicted_label == true_label:
        color = 'blue'
    else:
        color = 'red'

    plt.xlabel("Pred {} Conf: {:2.0f}% True ({})".format(predicted_label,
                                    100*np.max(predictions_array),
                                    true_label),
                                    color=color)

def plot_value_array(i, predictions_array, true_label):
    true_label = true_label[i]
    plt.grid(False)
    plt.xticks(range(10))
    plt.yticks([])
    thisplot = plt.bar(range(10), predictions_array, color="#777777")
    plt.ylim([0, 1])
    predicted_label = np.argmax(predictions_array)

    thisplot[predicted_label].set_color('red')
    thisplot[true_label].set_color('blue')

(8) 绘制预测结果:

i = 56
plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plot_image(i, predictions[i], test_labels, test_data)
plt.subplot(1,2,2)
plot_value_array(i, predictions[i],  test_labels)
plt.show()

左侧的图像是手写数字图像,图像下方显示了预测的标签、预测的置信度以及真实标签。右侧的图像显示了 10 个神经元输出的概率(逻辑输出),可以看到代表数字 4 的神经元具有最高的概率:

训练结果

(9) 为了保持逻辑回归的特性,以上代码仅使用了一个包含 sigmoid 激活函数的 Dense 层。为了获得更好的性能,可以添加 Dense 层并使用 softmax 作为最终的激活函数,以下模型在验证数据集上能够达到 97% 的准确率:

better_model = K.Sequential([
                      Flatten(input_shape=(28, 28)),
                      Dense(128,  activation='relu'),
                      #Dense(64,  activation='relu'),
                      Dense(10, activation='softmax')
])
better_model.summary()

better_model.compile(optimizer='adam',
                     loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                     metrics=['accuracy'])

history = better_model.fit(x=train_data,y=train_labels, epochs=10, verbose=1, validation_split=0.2)

plt.plot(history.history['loss'], label='loss')
plt.plot(history.history['val_loss'], label='val_loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

predictions = better_model.predict(test_data)
i = 0
plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plot_image(i, predictions[i], test_labels, test_data)
plt.subplot(1,2,2)
plot_value_array(i, predictions[i],  test_labels)
plt.show()

模型预测

我们可以尝试添加更多隐藏层,或者修改隐藏层中神经元的数量,或者修改优化器,以更好地理解这些参数对模型性能的影响。

小结

分类任务是机器学习中最常见的任务之一,广泛应用于各个领域。成功的分类任务不仅需要选择合适的算法,还需要对数据进行深入的预处理和特征工程。在本节中,我们首先介绍了分类任务及其与回归任务的区别,然后介绍了用于分类任务的逻辑回归技术,并使用 TensorFlow 实现了逻辑回归模型。

系列链接

TensorFlow深度学习实战(1)——神经网络与模型训练过程详解
TensorFlow深度学习实战(2)——使用TensorFlow构建神经网络
TensorFlow深度学习实战(3)——深度学习中常用激活函数详解
TensorFlow深度学习实战(4)——正则化技术详解
TensorFlow深度学习实战(5)——神经网络性能优化技术详解
TensorFlow深度学习实战(6)——回归分析详解

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/966278.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

国产编辑器EverEdit - 查找功能详解

1 查找功能详解 1.1 应用场景 查找关键词应该是整个文本编辑/阅读活动中,操作频度非常高的一项,用好查找功能,不仅可以可以搜索到关键字,还可以帮助用户高效完成一些特定操作。 1.2 基础功能 1.2.1 基础查找功能 选择主菜单查…

5分钟了解回归测试

1. 什么是回归测试(Regression Testing) 回归测试是一个系统的质量控制过程,用于验证最近对软件的更改或更新是否无意中引入了新错误或对以前的功能方面产生了负面影响(比如你在家中安装了新的空调系统,发现虽然新的空…

【AI】卷积神经网络CNN

不定期更新,建议关注收藏点赞。 目录 零碎小组件经验总结早期的CNN 最重要的模型架构无非是cnn 或 transformer 零碎小组件 全连接神经网络 目前已经被替代。 每个神经元都有参与,但由于数据中的特征点变化大,全连接神经网络把所有数据特征都…

企业FTP替代升级,实现传输大文件提升100倍!

随着信息技术的飞速发展,网络安全环境也变得越来越复杂。在这种背景下,传统的FTP(文件传输协议)已经很难满足现代企业对文件传输的需求了。FTP虽然用起来简单,但它的局限性和安全漏洞让它在面对高效、安全的数据交换时…

LabVIEW铅酸蓄电池测试系统

本文介绍了基于LabVIEW的通用飞机铅酸蓄电池测试系统的设计与实现。系统通过模块化设计,利用多点传感器采集与高效的数据处理技术,显著提高了蓄电池测试的准确性和效率。 ​ 项目背景 随着通用航空的快速发展,对飞机铅酸蓄电池的测试需求也…

Lecture8 | LPV VXGI SSAO SSDO

Review: Lecture 7 | Lecture 8 LPV (Light Propagation Volumes) Light Propagation Volumes(LPV)-孤岛惊魂CryEngine引进的技术 LPV做GI快|好 大体步骤: Step1.Generation of Radiance Point Set Scene Representation 生成辐射点集的场景表示:辐射…

0012—数组

存取一组数据,使用数组。 数组是一组相同类型元素的集合。 要存储1-10的数字,怎么存储? C语言中给了数组的定义:一组相同类型元素的集合。 创建一个空间创建一组数: 一、数组的定义 int arr[10] {1,2,3,4,5,6,7,8,…

AI绘画社区:解锁艺术共创的无限可能(9/10)

AI 绘画:不只是技术,更是社交新潮流 在科技飞速发展的今天,AI 绘画早已不再仅仅是一项孤立的技术,它正以惊人的速度融入我们的社交生活,成为艺术爱好者们交流互动的全新方式,构建起一个充满活力与创意的社…

让office集成deepseek,支持office和WPS办公软件!(体验感受)

导读 AIGC:AIGC是一种新的人工智能技术,它的全称是Artificial Intelligence Generative Content,即人工智能生成内容。 它是一种基于机器学习和自然语言处理的技术,能够自动产生文本、图像、音频等多种类型的内容。这些内容可以是新闻文章、…

c++ template-3

第 7 章 按值传递还是按引用传递 从一开始,C就提供了按值传递(call-by-value)和按引用传递(call-by-reference)两种参数传递方式,但是具体该怎么选择,有时并不容易确定:通常对复杂类…

使用springAI实现图片相识度搜索

类似的功能:淘宝拍照识别商品。图片相识度匹配 实现方式:其实很简单,用springai 将图片转换为向量数据,然后搜索就是先把需要搜索的图片转位向量再用向量数据去向量数据库搜索 但是springai现在不支持多模态嵌入数据库。做了一些…

私有化部署DeepSeek并SpringBoot集成使用(附UI界面使用教程-支持语音、图片)

私有化部署DeepSeek并SpringBoot集成使用(附UI界面使用教程-支持语音、图片) windows部署ollama Ollama 是一个开源框架,专为在本地机器上便捷部署和运行大型语言模型(LLM)而设计 下载ollama 下载地址(…

半导体制造工艺讲解

目录 一、半导体制造工艺的概述 二、单晶硅片的制造 1.单晶硅的制造 2.晶棒的切割、研磨 3.晶棒的切片、倒角和打磨 4.晶圆的检测和清洗 三、晶圆制造 1.氧化与涂胶 2.光刻与显影 3.刻蚀与脱胶 4.掺杂与退火 5.薄膜沉积、金属化和晶圆减薄 6.MOSFET在晶圆表面的形…

正则表达式的简单介绍 + regex_match使用

正则表达式 正则表达式(Regular Expression,简称 regex)是一种用于匹配字符串的模式。它由一系列字符和特殊符号组成,用于描述、匹配一系列符合某个句法规则的字符串。正则表达式广泛应用于文本搜索、替换、验证等场景。 它的主…

AnythingLLM开发者接口API测试

《Win10OllamaAnythingLLMDeepSeek构建本地多人访问知识库》见上一篇文章,本文在上篇基础上进行。 1.生成本地API 密钥 2.打开API测试页面(http://localhost:3001/api/docs/) 就可以在页面测试API了 2.测试获取用户接口(/v1/admin/users) 3…

TypeScript 中的类:面向对象编程的基础

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

二级C语言题解:矩阵主、反对角线元素之和,二分法求方程根,处理字符串中 * 号

目录 一、程序填空📝 --- 矩阵主、反对角线元素之和 题目📃 分析🧐 二、程序修改🛠️ --- 二分法求方程根 题目📃 分析🧐 三、程序设计💻 --- 处理字符串中 * 号 题目&#x1f…

Qt 支持的动画格式对比,Lottie/APNG/GIF/WEBP

Qt版本:6.7.2 , QML 一,Lottie 在qml中使用LottieAnimation即可,但有三个问题: 1.动画加载中报错: 如果图片(.json)本身存在不支持的effect 或shape type等,效果并不好&#xff1a…

SpringCloud - Nacos注册/配置中心

前言 该博客为Nacos学习笔记,主要目的是为了帮助后期快速复习使用 学习视频:7小快速通关SpringCloud 辅助文档:SpringCloud快速通关 一、简介 Nacos官网:https://nacos.io/docs/next/quickstart/quick-start/ Nacos /nɑ:kəʊ…

老游戏回顾:TL2

TL2是一部ARPG游戏,是TL的续作游戏,由位于美国西雅图的Runic Games开发,游戏于2012年9月20日上市,简体中文版于2013年4月10日在国内上市。 2有非常独特的艺术风格,这些在1中就已经形成,经过升级将使这款游…