【动手学深度学习Pytorch】2. Softmax回归代码

零实现

        导入所需要的包:

import torch
from IPython import display
from d2l import torch as d2l

        定义数据集参数、模型参数:

batch_size = 256 # 每次随机读取256张图片
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
# 将展平每个图片将其视为长度为784的向量,数据集存在10个类别
num_inputs = 784
num_outputs = 10
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)

        实现Softmax操作:

# 实现Softmax
def softmax(X):
    X_exp = torch.exp(X)
    partition = X_exp.sum(1, keepdim=True) #列数为特征数,行数为样本数
    return X_exp / partition #广播机制

# 尝试进行Softmax操作
X = torch.normal(0, 1, (2,5))
X_prob = softmax(X)
X_prob, X_prob.sum(1)

# 实现Softmax回归模型
def net(X):
    return softmax(torch.matmul(X.reshape(-1,W.shape[0]),W)+b)

        定义交叉熵函数:

# 创建一个数据y_hat,其中包含2个样本在3个类别的预测概率,使用y作为y_hat中概率的索引
y = torch.tensor([0,2])
y_hat = torch.tensor([[0.1, 0.3, 0.6],[0.3, 0.2, 0.5]])
y_hat[[0, 1], y]
# 交叉熵函数
def cross_entropy(y_hat, y):
    return -torch.log(y_hat[range(len(y_hat)),y])
cross_entropy(y_hat, y)

        将预测类别于真实元素进行比较:

torch.argmax(input, dim=None, keepdim=False):用于返回指定维度中最大值的索引。通常用于分类任务中从预测输出中找到概率最大的类别

.dtype:.dtype 是张量的属性,用于返回该张量的 数据类型 (data type)。每个张量都有一个数据类型,用于定义其中存储元素的类型,例如浮点数、整数或布尔值。

tensor.type(dtype=None):不传入参数时,返回一个字符串,表示张量的类型;传入参数时,返回一个新的张量,该张量的类型与指定类型匹配。

x = torch.tensor([1.0, 2.0, 3.0])  # 默认 float32 类型
print(x.type())  # 输出: torch.FloatTensor

x_int = x.type(torch.int64)
print(x_int)         # 输出: tensor([1, 2, 3])
print(x_int.type())  # 输出: torch.LongTensor (int64 的别名)

net.eval():设置为评估模式。

def accuracy(y_hat, y):#计算预测争取的数量
    # 判断 y_hat 是否为多维张量(例如二维)
    if len(y_hat.shape)>1 and y_hat.shape[1] > 1:
        # 如果是多类别分类(第二维大于 1),通过argmax获取每行中概率或分数最大的类别索引
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype)==y  # 比较预测结果和真实标签是否相等
    return float(cmp.type(y.dtype).sum()) # 返回预测正确的总数量

accuracy(y_hat, y) / len(y)

def evaluate_accuracy(net, data_iter):#计算在指定数据集上的模型精度
    # 如果是 PyTorch 模型,设置为评估模式
    if isinstance(net, torch.nn.Module):
        net.eval() 
    metric = Accumulator(2)  # 初始化累加器,存储 [正确预测数, 总样本数]
    for X, y in data_iter:
        metric.add(accuracy(net(X), y), y.numel()) # 累加每批数据的预测结果
    return metric[0] / metric[1]  # 返回精度:正确预测数 / 总样本数

        Accumulator实例:

class Accumulator: #在n个变量上累加
    def __init__(self, n):
        self.data = [0.0] * n
    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]
    def reset(self):
        self.data = [0.0] * len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]

evaluate_accuracy(net, test_iter)

        定义训练过程: 

net.train():设置为训练模式。

torch.optim.Optimizer.step():用于执行模型参数更新基于之前计算好的梯度(通过反向传播获得),按照优化算法的规则调整模型参数的值,以最小化损失函数。

def train_epoch_ch3(net, train_iter, loss, updater):
    if isinstance(net, torch.nn.Module):
        net.train()
    metric = Accumulator(3)
    for X, y in train_iter:
        y_hat = net(X)
        l = loss(y_hat, y) #计算损失
        if isinstance(updater, torch.optim.Optimizer):
            updater.zero_grad() # 清除梯度
            l.backward() # 反向传播计算梯度
            updater.step() # 根据梯度更新模型参数
            metric.add(
                float(l) * len(y),  # 累加当前批次的损失
                accuracy(y_hat, y),  # 累加当前批次的正确预测数
                y.size().numel())  # 累加当前批次的样本数
        else: # 如果是自定义优化器
            l.sum().backward()
            updater(X.shape[0]) # 自定义的更新函数,可能需要批次大小作为参数
            metric.add(float(l.sum()), 
                       accuracy(y_hat),
                       y.numel())
    return metric[0] / metric[2], metric[1] / metric[2]

        定义一个在动画中绘制数据的实用程序类:

class Animator: #实时观看在训练过程中的变化
    # 初始化绘图环境,包括图表的设置、标签、坐标轴范围、曲线样式等。
    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-','m--','g-','r:'),nrows=1,ncols=1,
                 figsize=(3.5, 2,5)):
        if legend is None:
            legend = []
        d2l.use_svg_display()
        self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols ==1:
            self.axes = [self.axes,]
        self.config_axes = lambda:d2l.set_axes(self.axes[0],
                                              xlabel, ylabel,
                                              xlim, ylim,
                                              xscale, yscale,
                                              legend)
        self.X, self.Y, self.fmt = None, None, fmts

    def add(self, x, y):
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)

        训练函数: 

# 训练函数
def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):
        # 进行可视化
        animator = Aminator(xlabel='epoch', xlim=[1, num_epochs], 
                            ylim=[0.3,],
                           legend=['train loss','train acc','test acc'])
        for epoch in range(num_epochs):
            train_metrics = train_epoch_ch2(net, train_iter, loss, updater)
            test_acc = evaluate_accuracy(net, test_iter)
            animator.add(epoch+1, train_metrics+(test_acc,))
        train_loss, train_acc = train_metrics

# 小批量随机梯度下降来优化训练算法
lr = 0.1
def updater(batch_size):
    return d2l.sgd([W,b],lr,batch_size)

num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater(10))

 简洁实现

        导入所需要的包:

import torch
from IPython import display
from d2l import torch as d2l

        初始化数据集、模型参数、损失函数以及训练优化算法:网络加入高斯噪声,增强泛化性。

torch.nn.init.normal_(tensor, mean=0.0, std=1.0):正态分布(高斯分布)随机初始化张量的值

nn.Sequential(*modules):用于将多个模块(如线性层、激活函数等)按顺序组合成一个模型。适合简单的前向计算场景。

nn.Flatten(start_dim=1, end_dim=-1):将输入张量展平成二维张量,适用于线性层输入。

nn.Linear(in_features, out_features, bias=True):实现一个线性层(全连接层)

nn.CrossEntropyLoss(weight=None, ignore_index=-100, reduction='mean'):计算分类任务中的交叉熵损失(适用于多分类问题)。
torch.optim.SGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False):实现随机梯度下降(SGD)优化算法,用于更新模型参数。

net.parameters():返回模型的可训练参数的迭代器。

batch_size = 256 # 每次随机读取256张图片
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

net = nn.Sequential(nn.Flatten(),nn.Linear(784, 100))
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);

loss = nn.CrossEntropyLoss()

trainer = torch.optim.SGD(net.parameters(),lr=0.1)

        用之前定义的训练函数训练模型:

num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater(10))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/917916.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

51单片机基础05 实时时钟-思路及代码参考2、3

目录 一、思路二 1、原理图 2、代码 二、思路三 1、原理图 2、代码 一、思路二 所有设定功能相关的操作均在矩阵键盘进行实现&#xff0c;并在定时器中扫描、计数等 1、原理图 2、代码 #include <AT89X52.h> //调用51单片机的头文件 //------------------…

Notepad++的完美替代

由于Notepad的作者曾发表过可能在开发者代码中植入恶意软件的言论&#xff0c;他备受指责。在此&#xff0c;我向大家推荐一个Notepad的完美替代品——NotepadNext和Notepad--。 1、NotepadNext NotepadNext的特点&#xff1a; 1、跨平台兼容性 NotepadNext基于Electron或Qt…

Python | Leetcode Python题解之第564题数组嵌套

题目&#xff1a; 题解&#xff1a; class Solution:def arrayNesting(self, nums: List[int]) -> int:ans, n 0, len(nums)for i in range(n):cnt 0while nums[i] < n:num nums[i]nums[i] ni numcnt 1ans max(ans, cnt)return ans

面试经典 150 题:20、2、228、122

20. 有效的括号 参考代码 #include <stack>class Solution { public:bool isValid(string s) {if(s.size() < 2){ //特判&#xff1a;空字符串和一个字符的情况return false;}bool flag true;stack<char> st; //栈for(int i0; i<s.size(); i){if(s[i] ( |…

使用vscode+expo+Android夜神模拟器运行react-native项目

1.进入夜神模拟器安装路径下的bin目录 2.输入命令&#xff0c;连接Android Studio 启动夜神模拟器后&#xff0c; 打开安装目录的bin文件夹执行下面的命令&#xff0c;只需执行一次&#xff09; nox_adb.exe connect 127.0.0.1:62001adb connect 127.0.0.1:62001 3.运行项目…

【STM32】USB 简要驱动软件架构图

STM32 USB 软件架构比较复杂&#xff0c;建议去看 UM 1734 或者 st wiki STM32 USB call graph STM32 USB Device Library files organization Reference [1]: https://wiki.stmicroelectronics.cn/stm32mcu/wiki/Introduction_to_USB_with_STM32 [2]: UM1734

鸿蒙中如何实现图片拉伸效果

2024年10月22日&#xff0c;华为发布会上&#xff0c;推出鸿蒙5.0。现在加入恰逢时机&#xff0c;你&#xff0c;我皆是鸿蒙时代合伙人。无论为了学习技术&#xff0c;还是为了谋福利&#xff0c;在鸿蒙的浩瀚海洋中分到一杯羹。现在学习鸿蒙正当时。 一文了解鸿蒙中图片拉伸的…

VUE+SPRINGBOOT实现邮箱注册、重置密码、登录功能

随着互联网的发展&#xff0c;网站用户的管理、触达、消息通知成为一个网站设计是否合理的重要标志。目前主流互联网公司都支持手机验证码注册、登录。但是手机短信作为服务端网站是需要付出运营商通信成本的&#xff0c;而邮箱的注册、登录、重置密码&#xff0c;无疑成为了这…

网络基础(4)传输层

既然是传输层首先就要明确实在层状结构的哪里,除开物理层之外分成了四层协议: 到这里上层(应用层)的使用已经没有问题&#xff0c;之前使用的套接字都是在应用层的。 再说端口号 到一个主机收到一个报文的时候&#xff0c;这个报文中一定存在这个报文需要到的主机的ip号。如果…

web——sqliabs靶场——第六关——报错注入和布尔盲注

这一关还是使用报错注入和布尔盲注 一. 判断是否有sql注入 二. 判断注入的类型 是双引号的注入类型。 3.报错注入的检测 可以使用sql报错注入 4.查看库名 5. 查看表名 6.查看字段名 7. 查具体字段的内容 结束 布尔盲注 结束

网络基础 - 网段划分篇

我们知道&#xff0c;IP 地址(IPv4 地址)由 “网络标识(网络地址)” 和 “主机标识(主机地址)” 两部分组成&#xff0c;例如 192.168.128.10/24&#xff0c;其中的 “/24” 表示从第 1 位开始到多少位属于网络标识&#xff0c;那么&#xff0c;剩余位就属于主机标识了&#xf…

【AI图像生成网站Golang】JWT认证与令牌桶算法

AI图像生成网站 目录 一、项目介绍 二、雪花算法 三、JWT认证与令牌桶算法 四、项目架构 五、图床上传与图像生成API搭建 六、项目测试与调试(等待更新) 三、JWT认证与令牌桶算法 在现代后端开发中&#xff0c;用户认证和接口限流是确保系统安全性和性能的两大关键要素…

TR3:Pytorch复现Transformer

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 一、实验目的 从整体上把握Transformer模型&#xff0c;明白它是个什么东西&#xff0c;可以干嘛读懂Transformer的复现代码 二、实验环境 语言环境&#xff1…

数据分布之指数分布(sample database classicmodels _No.10)

数据分布之指数分布&#xff08;sample database classicmodels _No.10&#xff09; 准备工作&#xff0c;可以去下载 classicmodels 数据库具体如下 点击&#xff1a;classicmodels 也可以去 下面我的博客资源下载 https://download.csdn.net/download/tomxjc/88685970 文章…

无人机动力系统测试-实测数据与CFD模拟仿真数据关联对比分析

我们经常被问到这样的问题&#xff1a;“我们计划运行 CFD 仿真&#xff0c;我们还需要对电机和螺旋桨进行实验测试吗&#xff1f;我们可能有偏见&#xff0c;但我们的答案始终是肯定的&#xff0c;而且有充分的理由。我们自己执行了大量的 CFD 仿真&#xff0c;但我们承认&…

MinIO 的 S3 over RDMA 计划: 为高速人工智能数据基础设施设定对象存储新标准

随着 AI 和机器学习的需求不断加速&#xff0c;数据中心网络正在迅速发展以跟上步伐。对于许多企业来说&#xff0c;400GbE 甚至 800GbE 正在成为标准选择&#xff0c;因为数据密集型和时间敏感型 AI 工作负载需要高速、低延迟的数据传输。用于大型语言处理、实时分析和计算机视…

游戏引擎学习第13天

视频参考:https://www.bilibili.com/video/BV1QQUaYMEEz/ 改代码的地方尽量一张图说清楚吧,懒得浪费时间 game.h #pragma once #include <cmath> #include <cstdint> #include <malloc.h>#define internal static // 用于定义内翻译单元内部函数 #…

十分钟学会html超文本标记语言

前言 本次学习的是在b站up主泷羽sec课程有感而发&#xff0c;如涉及侵权马上删除文章。 笔记的只是方便各位师傅学习知识&#xff0c;以下网站只涉及学习内容&#xff0c;其他的都与本人无关&#xff0c;切莫逾越法律红线&#xff0c;否则后果自负。 &#xff01;&#xff01;…

【Linux系统编程】第四十七弹---深入探索:POSIX信号量与基于环形队列的生产消费模型实现

✨个人主页&#xff1a; 熬夜学编程的小林 &#x1f497;系列专栏&#xff1a; 【C语言详解】 【数据结构详解】【C详解】【Linux系统编程】 目录 1、POSIX信号量 2、基于环形队列的生产消费模型 2.1、代码实现 2.1.1、RingQueue基本结构 2.1.2、PV操作 2.1.3、构造析构…

除了 TON, 哪些公链在争夺 Telegram 用户?数据表现如何?

作者&#xff1a;Stella L (stellafootprint.network) 在 2024 年&#xff0c;区块链游戏大规模采用迎来了一个意想不到的催化剂&#xff1a;Telegram。随着各大公链争相布局这个拥有海量用户基础的即时通讯平台&#xff0c;一个核心问题浮出水面&#xff1a;这种用户获取策略…