Softmax回归(多类分类模型)

目录

  • 1.对真实值类别编码:
  • 2.预测值:
  • 3.目标函数要求:
  • 4.使用Softmax模型将输出置信度Oi计算转换为输出匹配概率y^i:
  • 5.使用交叉熵作为损失函数:
  • 6.代码实现:

1.对真实值类别编码:

在这里插入图片描述

  • y为真实值,有且仅有一个位置值为1,该位置即为该元素真实类别

2.预测值:

在这里插入图片描述

  • Oi为该元素与类别i匹配的置信度

3.目标函数要求:

在这里插入图片描述

  • 对于正确类y的置信度Oy要远远大于其他非正确类的置信度Oi,才能使识别到的正确类与错误类具有更明显的差距

4.使用Softmax模型将输出置信度Oi计算转换为输出匹配概率y^i:

在这里插入图片描述

  • y^为n维向量,每个元素非负且和为1
  • y^i为元素与类别i匹配的概率

5.使用交叉熵作为损失函数:

在这里插入图片描述

  • L为真实概率y与预测概率y^的差距
  • 分类问题不关心非正确类的预测值,只关心正确类的预测值有多大

6.代码实现:

import sys
import os
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision import transforms
from torch.utils import data
from d2l import torch as d2l


os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

## 读取小批量数据
batch_size = 256
trans = transforms.ToTensor()
#train_iter, test_iter = common.load_fashion_mnist(batch_size) #无法翻墙的,可以参考这种方法取下载数据集
mnist_train  = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True) # 需要网络翻墙,这里数据集会自动下载到项目跟目录的/data目录下
mnist_test  = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True) # 需要网络翻墙,这里数据集会自动下载到项目跟目录的/data目录下
print(len(mnist_train))  # train_iter的长度是235;说明数据被分成了234组大小为256的数据加上最后一组大小不足256的数据
print('11111111')


## 展示部分数据
def get_fashion_mnist_labels(labels):  # @save
    """返回Fashion-MNIST数据集的文本标签。"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]


def show_fashion_mnist(images, labels):
    d2l.use_svg_display()
    # 这里的_表示我们忽略(不使用)的变量
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()


train_data, train_targets = next(iter(data.DataLoader(mnist_train, batch_size=18)))
#展示部分训练数据
show_fashion_mnist(train_data[0:10], train_targets[0:10])

# 初始化模型参数
num_inputs = 784
num_outputs = 10

W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)


# 定义模型
def softmax(X):
    X_exp = X.exp()
    partition = X_exp.sum(dim=1, keepdim=True)
    return X_exp / partition  # 这里应用了广播机制


def net(X):
    return softmax(torch.matmul(X.reshape(-1, num_inputs), W) + b)


# 定义损失函数
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = torch.LongTensor([0, 2])
y_hat.gather(1, y.view(-1, 1))


def cross_entropy(y_hat, y):
    return - torch.log(y_hat.gather(1, y.view(-1, 1)))


# 计算分类准确率
def accuracy(y_hat, y):
    return (y_hat.argmax(dim=1) == y).float().mean().item()


# 计算这个训练集的准确率
def evaluate_accuracy(data_iter, net):
    acc_sum, n = 0.0, 0
    for X, y in data_iter:
        acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
        n += y.shape[0]
    return acc_sum / n


num_epochs, lr = 10, 0.1


# 本函数已保存在d2lzh包中方便以后使用
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,
              params=None, lr=None, optimizer=None):
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for X, y in train_iter:
            y_hat = net(X)
            l = loss(y_hat, y).sum()

            # 梯度清零
            if params is not None and params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()

            l.backward()
            # 执行优化方法
            if optimizer is not None:
                optimizer.step()
            else:
                d2l.sgd(params, lr, batch_size)

            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))


# 训练模型
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, batch_size, [W, b], lr)

# 预测模型
for X, y in test_iter:
    break
true_labels = get_fashion_mnist_labels(y.numpy())
pred_labels = get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]
show_fashion_mnist(X[0:9], titles[0:9])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/321823.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

实战指南:如何在Spring Boot中无缝整合Dubbo【四】

欢迎来到我的博客,代码的世界里,每一行都是一个故事 实战指南:如何在Spring Boot中无缝整合Dubbo【四】 前言项目结构主项目(作为主pom)接口服务提供者properties文件实现类 服务消费者properties接口层 实现效果图 前言 微服务架构已经成为…

前端开发必备:掌握正则表达式,轻松应对复杂的表单验证

前言 在前端开发中,经常需要处理 URL 地址、校验手机号合法性、提取域名等。正则表达式是一种常用的工具。通过使用正则表达式,我们可以对用户输入进行有效的验证,确保数据的合法性和完整性。本文将介绍一些常见的正则表达式,帮助…

Meproc:简单高效的跨平台进程/任务管理工具

最近使用 Melang 语言写了一个 supervisor 相似服务Meproc来管理进程。 Meproc 有如下特性: 使用 HTTP API 管理控制 Meproc 来管理进程跨平台,支持 UNIX/Linux 、Mac 、Windows 等平台支持 cron 类定时调度任务支持简单的任务间依赖关系支持原生的协…

如何打赢稳定性之战?

文章目录 前言为什么总会出现问题呢?如何证明你的稳定性做的有效果?既是持久战,也是防御战1. 提前建筑好防御工事2. 以攻为守3. 找外部支援和配合 前言 随着23年年末期间,各大厂争先恐后的出现的各种线上故障,一时间“…

highlight.js 实现搜索关键词高亮效果 ,显示匹配数量及切换显示功能

先看效果&#xff1a; 更新&#xff1a;增加切换显示 折腾了老半天&#xff0c;记录一下 注意事项都写注释了 代码&#xff1a; <template><div class"absolute-lt wh-full overflow-hidden p-10"><div style"width: 200px"><el-…

学网络必懂的华为CSS堆叠技术

知识改变命运&#xff0c;技术就是要分享&#xff0c;有问题随时联系&#xff0c;免费答疑&#xff0c;欢迎联系&#xff01; 厦门微思网络​​​​​​https://www.xmws.cn 华为认证\华为HCIA-Datacom\华为HCIP-Datacom\华为HCIE-Datacom Linux\RHCE\RHCE 9.0\RHCA\ Oracle OC…

【Python数据分析系列】实现txt文件与列表(list)相互读写转换(源码+案例)

这是Python数据分析系列原创文章&#xff0c;我的第199篇原创文章。 一、问题 平时在做数据分析或者程序开发的时候&#xff0c;需要将中间的一些结果或最后的处理结果保存下来&#xff0c;比如保存为txt格式的文本文件&#xff0c;这就涉及列表与txt之间的一种读取和写入操作…

【LV13 DAY16 轮询与中断】

轮询实现按键实验 #include "exynos_4412.h"int main() {//GPX1_1设置为输入模式//GPX1.CONGPX1.CON & (~ (0XF<<4));while(1){if(!(GPX1.DAT&(1<<1))){printf("key pressed\n");while(!(GPX1.DAT&(1<<1)));}else{}}return…

1127: 矩阵乘积

题目描述 计算两个矩阵A和B的乘积。 输入 第一行三个正整数m、p和n&#xff0c;0<m,n,p<10&#xff0c;表示矩阵A是m行p列&#xff0c;矩阵B是p行n列&#xff1b; 接下来的m行是矩阵A的内容&#xff0c;每行p个整数&#xff0c;用空格隔开&#xff1b; 最后的p行是矩…

【期末不挂科-单片机考前速过系列P11】(第十一章:15题速过串行口的工作原理和应用)经典例题盘点(带图解析)

前言 大家好吖&#xff0c;欢迎来到 YY 滴单片机速过系列 &#xff0c;热烈欢迎&#xff01; 本章主要内容面向接触过单片机的老铁 主要内容含&#xff1a; 欢迎订阅 YY滴C专栏&#xff01;更多干货持续更新&#xff01;以下是传送门&#xff01; YY的《C》专栏YY的《C11》专栏…

【模型评估 02】ROC曲线

二值分类器&#xff08;Binary Classifier&#xff09;是机器学习领域中最常见也是应用最广泛的分类器。评价二值分类器的指标很多&#xff0c;比如precision、recall、F1 score、P-R曲线等。相比而言&#xff0c;ROC曲线有很多优点&#xff0c;经常作为评估而知分类器最重要的…

增强FAQ搜索引擎:发挥Elasticsearch中KNN的威力

英文原文地址&#xff1a;https://medium.com/nerd-for-tech/enhancing-faq-search-engines-harnessing-the-power-of-knn-in-elasticsearch-76076f670580 增强FAQ搜索引擎&#xff1a;发挥Elasticsearch中KNN的威力 2023 年 10 月 21 日 在一个快速准确的信息检索至关重要的…

Servlet JSP-实现简单的登录功能

本篇文章讲解如何使用Servlet-JSP-实现简单的登录功能。我们在进行Servlet和JSP实现简单登录功能的过程中&#xff0c;首先通过Eclipse创建了一个Maven项目&#xff0c;这为我们的Web应用提供了良好的项目管理和依赖管理。接下来&#xff0c;我们解决了新建项目时可能出现的报错…

CDN内容分发网络

1、CDN的含义 1.1 什么是CDN&#xff1f; CDN是内容分发网络&#xff08;Content Delivery Network&#xff09;的缩写。它是一种通过将内容部署到全球各地的服务器节点&#xff0c;使用户能够快速访问和下载内容的网络架构。 简单来说&#xff0c;CDN通过将内容分发到离用户更…

鸿蒙应用开发学习:改进小鱼动画实现按键一直按下时控制小鱼移动和限制小鱼移出屏幕

一、前言 近期我在学习鸿蒙应用开发&#xff0c;跟着B站UP主黑马程序员的视频教程做了一个小鱼动画应用&#xff0c;UP主提供的小鱼动画源代码仅仅实现了移动组件的功能&#xff0c;还存在一些问题&#xff0c;如默认进入页面是竖屏而页面适合横屏显示&#xff1b;真机测试发现…

【纯CSS特效源码】(二)精美的立体字

1.漂浮感立体 关键处&#xff1a; text-shadow:2px -2px white, -6px 6px gray;给字体添加了两层shadow&#xff0c;右上角白色提亮&#xff0c;左下角灰色阴影。 参数解释&#xff1a;例子中2px -2px white&#xff0c;代表右上角白色 第一个参数2px&#xff1a;正数表示从左…

C++ 完成Client分页显示log

分页显示t_log 1、获取用户的输入 1.1、写一个Input成员函数&#xff0c;处理输入进来的语句 std::string XClient::Input() {//清空缓冲//cin.ignore(4096, \n);string input "";for (;;){char a getchar();if (a < 0 || a \n || a \r)break;cout <<…

基于鸿蒙HarmonyOS 元服务开发一款公司运营应用(ArkTS API 9)

前言 最近基于Harmony OS最新版本开发了一个作品&#xff0c;本文来详细讲解一下&#xff0c;如何我是如何开发这个作品的。以及如何使用OpenHarmony&#xff0c;基于ArkTS&#xff0c;API 9来开发一个属于自己的元服务。 废话不多说&#xff0c;我的作品名称叫做Company Oper…

Spring Boot 中实现文件上传、下载、删除功能

&#x1f3c6;作者简介&#xff0c;普修罗双战士&#xff0c;一直追求不断学习和成长&#xff0c;在技术的道路上持续探索和实践。 &#x1f3c6;多年互联网行业从业经验&#xff0c;历任核心研发工程师&#xff0c;项目技术负责人。 &#x1f389;欢迎 &#x1f44d;点赞✍评论…

【RPC】序列化:对象怎么在网络中传输?

今天来聊下RPC框架中的序列化。在不同的场景下合理地选择序列化方式&#xff0c;对提升RPC框架整体的稳定性和性能是至关重要的。 一、为什么需要序列化&#xff1f; 首先&#xff0c;我们得知道什么是序列化与反序列化。 网络传输的数据必须是二进制数据&#xff0c;但调用…