【PyTorch】softmax回归

文章目录

  • 1. 模型与代码实现
    • 1.1. 模型
    • 1.2. 代码实现
  • 2. Q&A
    • 2.1. 运行过程中出现以下警告:
    • 2.2. 定义的神经网络中的nn.Flatten()的作用是什么?
    • 2.3. num_workers有什么作用?它的值怎么确定?

1. 模型与代码实现

1.1. 模型

  • 背景
    在分类问题中,模型的输出层是全连接层,每个类别对应一个输出。我们希望模型的输出 y ^ j \hat{y}_j y^j可以视为属于类 j j j的概率,然后选择具有最大输出值的类别作为我们的预测。
    softmax函数能够将未规范化的输出变换为非负数并且总和为1,同时让模型保持可导的性质,而且不会改变未规范化的输出之间的大小次序。
  • softmax函数
    y ^ = s o f t m a x ( o ) \mathbf{\hat{y}}=\mathrm{softmax}(\mathbf{o}) y^=softmax(o)其中 y ^ j = e x p ( o j ) ∑ k e x p ( o k ) \hat{y}_j=\frac{\mathrm{exp}({o_j})}{\sum_{k}\mathrm{exp}({o_k})} y^j=kexp(ok)exp(oj)
  • softmax是一个非线性函数,但softmax回归的输出仍然由输入特征的仿射变换决定,因此,softmax回归是一个线性模型
  • 为了避免将softmax的输出直接送入交叉熵损失造成的数值稳定性问题,需要将softmax和交叉熵损失结合在一起,具体做法是:不将softmax概率传递到损失函数中, 而是在交叉熵损失函数中传递未规范化的输出,并同时计算softmax及其对数。因此定义交叉熵损失函数时也进行了softmax运算

1.2. 代码实现

import torch
from torchvision.datasets import FashionMNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
from tensorboardX import SummaryWriter

def load_dataset(batch_size, num_workers):
    """加载数据集"""
    root = "./dataset"
    transform = transforms.Compose([transforms.ToTensor()])
    mnist_train = FashionMNIST(
        root=root, 
        train=True, 
        transform=transform, 
        download=True
    )
    mnist_test = FashionMNIST(
        root=root, 
        train=False, 
        transform=transform, 
        download=True
    )
    dataloader_train = DataLoader(
        mnist_train, 
        batch_size, 
        shuffle=True, 
        num_workers=num_workers
    )
    dataloader_test = DataLoader(
        mnist_test, 
        batch_size, 
        shuffle=False,
        num_workers=num_workers
    )
    return dataloader_train, dataloader_test

def init_network(net):
    """初始化模型参数"""
    def init_weights(m):
        if type(m) == nn.Linear:
            nn.init.normal_(m.weight, mean=0, std=0.01)
            nn.init.constant_(m.bias, val=0)
    if isinstance(net, nn.Module):
        net.apply(init_weights)

class Accumulator:
    """在n个变量上累加"""
    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
    
def accuracy(y_hat, y):
    """计算预测正确的数量"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())

def train(net, dataloader_train, criterion, optimizer, device):
    """训练模型"""
    if isinstance(net, nn.Module):
        net.train()
    train_metrics = Accumulator(3)  # 训练损失总和、训练准确度总和、样本数
    for X, y in dataloader_train:
        X, y = X.to(device), y.to(device)
        y_hat = net(X)
        loss = criterion(y_hat, y)
        optimizer.zero_grad()
        loss.mean().backward()
        optimizer.step()
        train_metrics.add(float(loss.sum()), accuracy(y_hat, y), y.numel())
    train_loss = train_metrics[0] / train_metrics[2]
    train_acc = train_metrics[1] / train_metrics[2]
    return train_loss, train_acc

def test(net, dataloader_test, device):
    """测试模型"""
    if isinstance(net, nn.Module):
        net.eval()
    with torch.no_grad():    
        test_metrics = Accumulator(2)   # 测试准确度总和、样本数
        for X, y in dataloader_test:
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            test_metrics.add(accuracy(y_hat, y), y.numel())
        test_acc = test_metrics[0] / test_metrics[1]
        return test_acc
    

if __name__ == "__main__":
    # 全局参数设置
    batch_size = 256
    num_workers = 4
    num_epochs = 20
    learning_rate = 0.1
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 创建记录器
    writer = SummaryWriter()

    # 加载数据集
    dataloader_train, dataloader_test = load_dataset(batch_size, num_workers)

    # 定义神经网络
    net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10)).to(device)

    # 初始化神经网络
    init_network(net)

    # 定义损失函数
    criterion = nn.CrossEntropyLoss(reduction='none')

    # 定义优化器
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        train_loss, train_acc = train(net, dataloader_train, criterion, optimizer, device)
        test_acc = test(net, dataloader_test, device)
        writer.add_scalars("metrics", {
            'train_loss': train_loss, 
            'train_acc': train_acc, 
            'test_acc': test_acc
            }, 
            epoch)
    writer.close()   

输出结果:
softmax回归

2. Q&A

2.1. 运行过程中出现以下警告:

UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at …\torch\csrc\utils\tensor_numpy.cpp:180.)
return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)

该警告的大致意思是给定的NumPy数组不可写,并且PyTorch不支持不可写的张量。这意味着你可以使用张量写入底层(假定不可写)NumPy数组。在将数组转换为张量之前,可能需要复制数组以保护其数据或使其可写。在本程序的其余部分,此类警告将被抑制。因此需要修改C:\Users\%UserName%\anaconda3\envs\%conda_env_name%\lib\site-packages\torchvision\datasets\mnist.py的第498行,将return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)中的False改成True

2.2. 定义的神经网络中的nn.Flatten()的作用是什么?

net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10)).to(device)

nn.Flatten()的作用是将图像数据张量展成一维,方便输入后续的全连接层。

2.3. num_workers有什么作用?它的值怎么确定?

num_workers表示加载batch数据的进程数,num_workers=0时只有主进程去加载batch数据。要实现多进程加载数据,加载函数一定要位于if __name__ == "__main__"下。一般开始是将num_workers设置为等于计算机上的CPU内核数量,在此基础上,尝试增加或减少num_workers的值,选择训练速度高时的值。查看CPU内核数量的方法:“任务管理器 > 性能 > CPU”。
CPU内核数

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/225478.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

用C语言实现链栈的基本操作

#include <stdio.h> #include <malloc.h> #define ElemType char//相当于ElemType等同于char类型 //链式结构 数据域指针域 typedef struct LinkStackNode//定义一个链栈的结构体类型 {ElemType data;//ElemType是链栈的元素类型&#xff0c;代表数据域struct Lin…

Let‘s Encrypt安全证书的步骤及使用-基于centos9, 包括工具certbot安装及使用,获取apache、nginx、iis等服务器安全证书

Lets Encrypt 介绍 Lets Encrypt 是一个免费、自动化的证书颁发机构&#xff08;CA&#xff0c;Certificate Authority&#xff09;&#xff0c;致力于为网站提供免费的SSL/TLS证书。以下是关于Lets Encrypt安全证书的详细介绍&#xff1a; 特点和背景&#xff1a; 免费&#…

zedbox 实现配置 yolov5

Stereolabs 的 ZED Box 是一款由 NVIDIA Jetson™ 提供支持的紧凑型网关&#xff0c;用于在移动和现场情况下挑战 AIoT&#xff08;物联网&#xff09;。它旨在将空间感知和理解带到边缘&#xff0c;并聚合、处理和分析来自 3D 传感器和设备的数据。 相关内容 1.win10下 cud…

超声波清洗机有用吗,超声波清洗机怎么选?热门超声波清洗机测评

用过超声波清洗机的朋友都知道&#xff0c;超声波清洗机比自己手动清洗眼镜的好处在于&#xff0c;它能够更加快速、高效的清洗干净眼镜等物件&#xff01;通过超声波振频的工作原理对物体的污渍进行一个脱离剥落&#xff0c;不接触物件的清洗方式也保证了物件清洗的安全性&…

leetcode刷题:611.有效三角形的个数(双指针实现)

题目地址&#xff1a;有效三角形的个数 解决此题时&#xff0c;首先需要知道的是如何判断三个数字是否能够构成三角形。 我们知道&#xff0c;三角形任意两边之和都大于第三边。所以判断三个数字是否能构成三角形需要进行三次比较&#xff08;最基础的思路&#xff09; 方法一…

LangChain的函数,工具和代理(五):Tools Routing

关于langchain的函数、工具、代理系列的博客我之前已经写了四篇&#xff0c;还没有看过的朋友请先看一下&#xff0c;这样便于对后续博客内容的理解&#xff1a; LangChain的函数&#xff0c;工具和代理(一)&#xff1a;OpenAI的函数调用 LangChain的函数&#xff0c;工具和代…

【改进YOLOv8】融合Gold-YOLO的车辆未礼让行人检测系统

1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 研究背景与意义 随着交通工具的普及和道路交通量的增加&#xff0c;交通安全问题日益凸显。尤其是车辆未礼让行人的情况频繁发生&#xff0c;给行人的生命安全带来了严重威胁。因…

从零开始:同城O2O外卖APP的技术开发指南

随着互联网的迅速发展&#xff0c;O2O&#xff08;OnlinetoOffline&#xff09;模式在各个行业都取得了巨大成功&#xff0c;而同城外卖APP更是成为人们生活中不可或缺的一部分。本文将从零开始&#xff0c;为您提供一份同城O2O外卖APP的技术开发指南&#xff0c;让您能够深入了…

OpenCV交叉编译

1.下载代码解压 tar -zxvf opencv-4.8.1.tar.gz cd cd opencv-4.8.1 sudo mkdir chmod 777 build cd build 2.配置交叉编译工具 根据自己的板子进行修改 -D CMAKE_C_COMPILERaarch64-mix210-linux-gcc -D CMAKE_CXX_COMPILERaarch64-mix210-linux-g 3.cmake生成makefi…

1.4 场景设计精要

一、场景主题确定 设计游戏场景首先明确游戏发生的时间地点等时代背景。通过对玩家动线的设计&#xff0c;功能模型的合理布局构建出场景的基本骨架。利用光影效果和色彩变化烘托场景氛围。 市场上常见的主题场景&#xff1a;剑侠、科幻、废墟、魔幻等 二、场景风格确定 大类分…

C51汇编程序

目录 一.C51的数据类型和存储类型 1.数据类型&#xff1a; 2.C51的扩展数据类型&#xff1a; 3.数据存储类型 4.数据存储模式 二.特殊功能寄存器及其位变量定义 1.特殊功能寄存器的C51定义 2.位变量的C51定义 三.C51语言的绝对地址访问 1.绝对宏 2._at_关键字 一.C5…

Linux CentOS本地部署SQL Server数据库结合cpolar内网穿透实现公网访问

&#x1f308;个人主页&#xff1a;聆风吟 &#x1f525;系列专栏&#xff1a;数据结构、Cpolar杂谈 &#x1f516;少年有梦不应止于心动&#xff0c;更要付诸行动。 文章目录 &#x1f4cb;前言一. 安装sql server二. 局域网测试连接三. 安装cpolar内网穿透四. 将sqlserver映射…

创业6个月裤衩都赔掉了;2023生成式AI年度大事表;AI工程师的自我修养;LLM开发者成长计划;OpenAI LLM入门课程 | ShowMeAI日报

&#x1f440;日报&周刊合集 | &#x1f3a1;生产力工具与行业应用大全 | &#x1f9e1; 点赞关注评论拜托啦&#xff01; &#x1f440; 黄家驹AI演唱「直到世界尽头」&#xff0c;是科技前进也是青春回望~ https://www.bilibili.com/video/BV1CG411i7MV 最近几天&#xf…

2022 RedisDays 内容揭秘

上个月&#xff0c;Redis举办了3场线上会议&#xff0c;分别介绍了即将正式发布的Redis 7中包括的重要更新的内容&#xff0c;还有Redis完全重写的RedisJSON 2.0模块&#xff0c;和新发布的Redis Stack模块。除此之外&#xff0c;在此次线上会议中还介绍了现代化的软件架构与Re…

pyside6详细笔记

文章目录 主要模组简介绍安装与环境配置安装配置QtDesignerPyUICPyRCC基础了解元对象系统对象模型重要的类Qt 对象:身份?值?对象树与所有状态概述QObjects 的构造/销毁顺序继承关系图Qt 命名空间模块简介QWidget窗口的创建在PyQt中使用qrc/rcc资源系统Qt 资源系统简介qrc 文件…

从Java8升级到Java17,特色优化点

从Java8升级到Java17&#xff0c;特色优化点 一、局部变量类型推断二、switch表达式三、文本块四、Records五、模式匹配instanceof六、密封类七、NullPointerException 从Java 8 到 Java 20&#xff0c;Java 已经走过了漫长的道路&#xff0c;自 Java 8 以来&#xff0c;Java 生…

预赛->省赛->国赛 我的全国软件测试大赛之旅

学习推荐 Web 功能测试&#xff1a;Javaselenium3 web自动化测试实战 性能测试&#xff1a;看慕测官方的视频&#xff0c;这里会用就行&#xff0c;不用学太多 自己根据视频写的&#xff1a;Web自动测试常用代码(Java版) Web没啥难的&#xff0c;主要拼手速&#xff0c;其…

出错:I/O文件读取JAVA

I/O文件读取 /** author:xiaowang* date:2023/12/6* demand:读取java1班的数据* * */ package homework;import java.io.FileNotFoundException; import java.io.FileReader; import java.io.IOException;public class FileReadTest {public static void main(String[] args) …

Windows系统上如何搭建Linux操作系统

一、准备工作 1&#xff0c;VMware安装包 2&#xff0c;Centos IOS镜像 3&#xff0c;finalshell安装包 阿里云盘下载地址&#xff1a; https://www.alipan.com/s/uSQsWn15E3W 二&#xff0c;VMware安装 1&#xff0c;新建虚拟机 2&#xff0c;选择下一步 3&#xff0c;…

小航助学题库白名单竞赛考级蓝桥杯等考scratch(14级)(含题库教师学生账号)

需要在线模拟训练的题库账号请点击 小航助学编程在线模拟试卷系统&#xff08;含题库答题软件账号&#xff09; 需要在线模拟训练的题库账号请点击 小航助学编程在线模拟试卷系统&#xff08;含题库答题软件账号&#xff09;