【深度学习】线性回归的简洁实现

线性回归的简洁实现

在过去的几年里,出于对深度学习强烈的兴趣,许多公司、学者和业余爱好者开发了各种成熟的开源框架。
这些框架可以自动化基于梯度的学习算法中重复性的工作。
目前,我们只会运用:
(1)通过张量来进行数据存储和线性代数;
(2)通过自动微分来计算梯度。
实际上,由于数据迭代器、损失函数、优化器和神经网络层很常用,现代深度学习库也为我们实现了这些组件。

本节将介绍如何(通过使用深度学习框架来简洁地实现线性回归模型)。

生成数据集

我们首先[生成数据集]。

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l
true_w = torch.tensor([2, -3.4])#用于创建张量(Tensor)
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)

读取数据集

我们可以[调用框架中现有的API来读取数据]。
我们将featureslabels作为API的参数传递,并通过数据迭代器指定batch_size
此外,布尔值is_train表示是否希望数据迭代器对象在每个迭代周期内打乱数据。

def load_array(data_arrays, batch_size, is_train=True):  #@save
    """构造一个PyTorch数据迭代器"""
    dataset = data.TensorDataset(*data_arrays)
    #TensorDataset可以将多个张量组合成一个数据集对象
    return data.DataLoader(dataset, batch_size, shuffle=is_train)#用于批量加载和处理数据集
batch_size = 10
data_iter = load_array((features, labels), batch_size)

为了验证是否正常工作,让我们读取并打印第一个小批量样本。这里我们使用iter构造Python迭代器,并使用next从迭代器中获取第一项。

next(iter(data_iter))
iter(data_iter)

iter() 是 Python 的内置函数,它的作用是将一个可迭代对象(如列表、元组、DataLoader 等)转换为迭代器对象。迭代器是一种特殊的对象,它实现了 iter() 和 next() 方法,允许我们逐个访问可迭代对象中的元素。
在 PyTorch 里,DataLoader 是一个可迭代对象,它用于批量加载数据。通过iter(DataLoader) 就可以将 DataLoader 转换为迭代器,以便后续使用 next() 函数逐个获取批次数据。

next(iter(data_iter))

next() 也是 Python 的内置函数,它用于从迭代器中获取下一个元素。当调用 next(迭代器对象) 时,迭代器会返回其下一个元素,如果没有更多元素,会抛出 StopIteration 异常。

在这里插入图片描述

定义模型

当我们在实现线性回归时,我们明确定义了模型参数变量,并编写了计算的代码,这样通过基本的线性代数运算得到输出。
但是,如果模型变得更加复杂,且当我们几乎每天都需要实现模型时,自然会想简化这个过程。
这种情况类似于为自己的博客从零开始编写网页。做一两次是有益的,但如果每个新博客就需要工程师花一个月的时间重新开始编写网页,那并不高效。

对于标准深度学习模型,我们可以[使用框架的预定义好的层]。这使我们只需关注使用哪些层来构造模型,而不必关注层的实现细节。

我们首先定义一个模型变量net,它是一个Sequential类的实例。Sequential类将多个层串联在一起。
当给定输入数据时,Sequential实例将数据传入到第一层,然后将第一层的输出作为第二层的输入,以此类推。
在下面的例子中,我们的模型只包含一个层,因此实际上不需要Sequential。但是由于以后几乎所有的模型都是多层的,在这里使用Sequential会让你熟悉“标准的流水线”。

单层网络架构,这一单层被称为全连接层(fully-connected layer),因为它的每一个输入都通过矩阵-向量乘法得到它的每个输出。

# nn是神经网络的缩写
from torch import nn

net = nn.Sequential(nn.Linear(2, 1))

nn.Sequential 是 PyTorch 中的一个容器类,它可以按顺序将多个神经网络层组合在一起,形成一个完整的神经网络模型。使用 nn.Sequential 可以方便地定义一个简单的前馈神经网络,模型会按照添加层的顺序依次对输入数据进行处理。

nn.Linear 是 PyTorch 中的全连接层(也称为线性层),它实现了一个线性变换,其公式为: y = x A T + b y = xA^T + b y=xAT+b,其中 x x x 是输入数据, A A A 是权重矩阵, b b b 是偏置向量, y y y 是输出数据。
nn.Linear 类的构造函数为 nn.Linear(in_features, out_features, bias=True),其中:

  • in_features:输入特征的数量,即输入数据的维度。
  • out_features:输出特征的数量,即输出数据的维度。
  • bias:是否使用偏置项,默认为 True

(初始化模型参数)

在使用net之前,我们需要初始化模型参数。如在线性回归模型中的权重和偏置,深度学习框架通常有预定义的方法来初始化参数。
在这里,我们指定每个权重参数应该从均值为0、标准差为0.01的正态分布中随机采样,偏置参数将初始化为零。

正如我们在构造nn.Linear时指定输入和输出尺寸一样,现在我们能直接访问参数以设定它们的初始值。
我们通过net[0]选择网络中的第一个图层,然后使用weight.databias.data方法访问参数。
我们还可以使用替换方法normal_fill_来重写参数值。

net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)

weight全连接层的权重参数,它是一个 torch.Tensor 对象,代表了线性变换中的权重矩阵。

bias:全连接层的偏置参数,它也是一个 torch.Tensor 对象,代表了线性变换中的偏置向量。

data 属性返回的是其底层的普通张量(不包含梯度信息)。我们直接操作 data 可以避免在初始化过程中触发不必要的梯度计算。

normal_(0, 0.01):这是 PyTorch 张量的一个原地操作方法(方法名末尾带 _ 表示原地操作,会直接修改调用该方法的张量),作用是将张量中的元素用均值为 0、标准差为 0.01 的正态分布。也就是说,这行代码把全连接层的权重矩阵的所有元素初始化为从该正态分布中采样得到的值。
fill_(0):这也是一个原地操作方法,它会把偏置张量中的所有元素都填充为 0,也就是将全连接层的偏置向量初始化为零向量。

定义损失函数

[计算均方误差使用的是MSELoss类,也称为平方 L 2 L_2 L2范数]。默认情况下,它返回所有样本损失的平均值。

loss = nn.MSELoss()
向量的平方 L2 范数

对于一个 n n n 维向量 x = [ x 1 , x 2 , ⋯   , x n ] T \mathbf{x} = [x_1, x_2, \cdots, x_n]^T x=[x1,x2,,xn]T,其 L2 范数(也称为欧几里得范数)定义为向量各元素平方和的平方根,数学表达式为:
∥ x ∥ 2 = ∑ i = 1 n x i 2 \|\mathbf{x}\|_2 = \sqrt{\sum_{i=1}^{n} x_i^2} x2=i=1nxi2
而向量的平方 L2 范数则是 L2 范数的平方,即:
∥ x ∥ 2 2 = ∑ i = 1 n x i 2 \|\mathbf{x}\|_2^2 = \sum_{i=1}^{n} x_i^2 x22=i=1nxi2

矩阵的平方 L2 范数

对于一个 m × n m \times n m×n 的矩阵 A = [ a i j ] \mathbf{A} = [a_{ij}] A=[aij],其 Frobenius 范数(可以看作是矩阵的一种 L2 范数)定义为矩阵所有元素平方和的平方根,表达式为:
∥ A ∥ F = ∑ i = 1 m ∑ j = 1 n a i j 2 \|\mathbf{A}\|_F = \sqrt{\sum_{i=1}^{m} \sum_{j=1}^{n} a_{ij}^2} AF=i=1mj=1naij2
矩阵的平方 L2 范数(即平方 Frobenius 范数)为:
∥ A ∥ F 2 = ∑ i = 1 m ∑ j = 1 n a i j 2 \|\mathbf{A}\|_F^2 = \sum_{i=1}^{m} \sum_{j=1}^{n} a_{ij}^2 AF2=i=1mj=1naij2

定义优化算法

小批量随机梯度下降算法是一种优化神经网络的标准工具,PyTorch在optim模块中实现了该算法的许多变种。
当我们(实例化一个SGD实例)时,我们要指定优化的参数(可通过net.parameters()从我们的模型中获得)以及优化算法所需的超参数字典。小批量随机梯度下降只需要设置lr值,这里设置为0.03。

trainer = torch.optim.SGD(net.parameters(), lr=0.03)

net.parameters() 是 PyTorch 中 nn.Module 类(nn.Sequential 继承自 nn.Module)的一个方法,它会返回一个生成器(Python 中的迭代器对象),这个生成器会逐个产生模型 net 里所有可训练的参数。

训练

通过深度学习框架的高级API来实现我们的模型只需要相对较少的代码。
我们不必单独分配参数、不必定义我们的损失函数,也不必手动实现小批量随机梯度下降。
当我们需要更复杂的模型时,高级API的优势将大大增加。
当我们有了所有的基本组件,[训练过程代码与我们从零开始实现时所做的非常相似]。

回顾一下:在每个迭代周期里,我们将完整遍历一次数据集(train_data),
不停地从中获取一个小批量的输入和相应的标签。
对于每一个小批量,我们会进行以下步骤:

  • 通过调用net(X)生成预测并计算损失l(前向传播)。
  • 通过进行反向传播来计算梯度。
  • 通过调用优化器来更新模型参数。

为了更好的衡量训练效果,我们计算每个迭代周期后的损失,并打印它来监控训练过程。

num_epochs = 3
for epoch in range(num_epochs):
    for X, y in data_iter:
        l = loss(net(X) ,y)
        trainer.zero_grad()
        l.backward()
        trainer.step()
    l = loss(net(features), labels)
    print(f'epoch {epoch + 1}, loss {l:f}')

在这里插入图片描述
下面我们[比较生成数据集的真实参数和通过有限数据训练获得的模型参数]。
要访问参数,我们首先从net访问所需的层,然后读取该层的权重和偏置。正如在从零开始实现中一样,我们估计得到的参数与生成数据的真实参数非常接近。

w = net[0].weight.data
print('w的估计误差:', true_w - w.reshape(true_w.shape))
b = net[0].bias.data
print('b的估计误差:', true_b - b)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/960382.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Java 网络原理 ②-IP协议

这里是Themberfue 经过五节课的传输层协议的讲解,接下来我们将进入网络层协议——IP协议的讲解了~~~ IP协议 IP 相信大家在日常生活中或多或少都听过,你的IP地址是什么?192.168.0.1 ......✨IP 其实是个网络层协议,即互联网协议&…

PETSc源码分析:Nonlinear Solvers

本文结合PETSc源代码,总结PETSc中的非线性方程组求解器。 注1:限于研究水平,分析难免不当,欢迎批评指正。 注2:文章内容会不定期更新。 参考文献 Balay S. PETSc/TAO Users Manual, Revision 3.22. Argonne National …

嵌入式C语言:结构体的多态性之结构体中的void*万能指针

目录 一、void*指针在结构体中的应用 二、实现方式 2.1. 定义通用结构体 2.2. 定义具体结构体 2.3. 初始化和使用 三、应用场景 3.1. 内存管理函数 3.2. 泛型数据结构(链表) 3.3. 回调函数和函数指针 3.4. 跨语言调用或API接口(模拟…

反向代理模块。。

1 概念 1.1 反向代理概念 反向代理是指以代理服务器来接收客户端的请求,然后将请求转发给内部网络上的服务器,将从服务器上得到的结果返回给客户端,此时代理服务器对外表现为一个反向代理服务器。 对于客户端来说,反向代理就相当于…

构建旧系统:打造可维护系统的艺术

作者:来自 Elastic Saman Nourkhalaj 软件开发人员有很多不同的任务,但我们每个人都必须审查旧代码。无论是检查以前的版本还是查看过去某人如何解决问题,遗留代码都是工作的一部分。但是你是否曾经审查过以前的版本并感到沮丧并问 “谁编写了…

PAT (Basic Level) Practice 乙级1031-1040

制作不易,大家的点赞和关注就是我更新的动力! 由于这些题全是大一寒假刷的,部分还是c语言,部分的解题方法比较复杂,希望大家体谅。有问题欢迎大家在评论区讨论,有不足也请大家指出,谢谢大家&am…

BUUCTF 蜘蛛侠呀 1

BUUCTF:https://buuoj.cn/challenges 文章目录 题目描述:密文:解题思路:flag: 相关阅读 CTF Wiki Hello CTF NewStar CTF buuctf-蜘蛛侠呀 BUUCTF:蜘蛛侠呀 MISC(时间隐写)蜘蛛侠呀 题目描述&am…

面向长文本的多模型协作摘要架构:多LLM文本摘要方法

多LLM摘要框架在每轮对话中包含两个基本步骤:生成和评估。这些步骤在多LLM分散式摘要和集中式摘要中有所不同。在两种策略中,k个不同的LLM都会生成多样化的文本摘要。然而在评估阶段,多LLM集中式摘要方法使用单个LLM来评估摘要并选择最佳摘要,而分散式多LLM摘要则使用k个LLM进行…

c语言版贪吃蛇(Pro Max版)附源代码

1 背景 贪吃蛇是一款经典的电子游戏,最早出现在20世纪70年代的街机游戏中。游戏的核心玩法是玩家控制一条蛇在有限的空间内移动,通过吃食物来增长身体长度,同时避免撞到墙壁、障碍物或自身。随着蛇的长度增加,游戏难度逐渐提升。 …

AI软件外包需要注意什么 外包开发AI软件的关键因素是什么 如何选择AI外包开发语言

1. 定义目标与需求 首先,要明确你希望AI智能体做什么。是自动化任务、数据分析、自然语言处理,还是其他功能?明确目标可以帮助你选择合适的技术和方法。 2. 选择开发平台与工具 开发AI智能体的软件时,你需要选择适合的编程语言、…

分布式理解

分布式 如何理解分布式 狭义的分布是指,指多台PC在地理位置上分布在不同的地方。 分布式系统 分布式系**统:**多个能独立运行的计算机(称为结点)组成。各个结点利用计算机网络进行信息传递,从而实现共同的“目标或者任…

python学opencv|读取图像(四十七)使用cv2.bitwise_not()函数实现图像按位取反运算

【0】基础定义 按位与运算:两个等长度二进制数上下对齐,全1取1,其余取0。按位或运算:两个等长度二进制数上下对齐,有1取1,其余取0。 按位取反运算:一个二进制数,0变1,1变0。 【1】…

CVE-2023-38831 漏洞复现:win10 压缩包挂马攻击剖析

目录 前言 漏洞介绍 漏洞原理 产生条件 影响范围 防御措施 复现步骤 环境准备 具体操作 前言 在网络安全这片没有硝烟的战场上,新型漏洞如同隐匿的暗箭,时刻威胁着我们的数字生活。其中,CVE - 2023 - 38831 这个关联 Win10 压缩包挂…

链表排序--(奇数位是升序,偶数位是降序)

题目描述 对一个单链表进行排序,但这个链表有一个特殊的结构: 奇数位是升序:链表中位于奇数位置的节点是按升序排列的。例如,如果链表的第1个节点的值是1,第3个节点的值是3,第5个节点的值是5,那…

在无sudo权限Linux上安装 Ollama 并使用 DeepSeek-R1 模型

本教程将指导你如何在 Linux 系统上安装 Ollama(一个本地运行大型语言模型的工具),并加载 DeepSeek-R1 模型。DeepSeek-R1 是一个高性能的开源语言模型,适用于多种自然语言处理任务。 DeepSeek-R1 简介 DeepSeek-R1 是 DeepSeek …

arduino学习

一、log日志 只看自己 看指定 看错误日志 二、布局 重要:新建activity时需要的配置 若一个工程中有多个activity,需要修改开启activity属性、总容器标签、debug启动activity。下面流程内截图activity不一致,根据自己新建的activity配置&am…

obsidian插件——Metadata Hider

原本是要找导出图片时显示属性的插件,奈何还没找到,反而找到了可以隐藏属性的插件。唉,人生不如意,十之八九。 说一下功能: 这个插件可以把obsidian的文档属性放在右侧显示,或者决定只显示具体几项属性&a…

SimpleFOC STM32教程10|基于STM32F103+CubeMX,速度闭环控制(有电流环)

导言 SimpleFOC STM32教程09|基于STM32F103CubeMX,ADC采样相电流 如上图所示, 增加了电流环. 效果如下: 20250123-200906 RTT 如上图所示,三相占空比依然是马鞍波。当我用手去给电机施加阻力时,PID要维持目标转速&am…

Qt 5.14.2 学习记录 —— 이십일 Qt网络和音频

文章目录 1、UDP带有界面的Udp服务器(回显服务器) 2、TCP回显服务器 3、HTTP客户端4、音频 和Linux的网络一样,Qt封装了Linux的网络API,即Socket API。网络编程是在应用层写,需要传输层支持,传输层有UDP和T…

【C语言基础】编译并运行第一个C程序

博主未授权任何人或组织机构转载博主任何原创文章,感谢各位对原创的支持! 博主链接 博客内容主要围绕: 5G/6G协议讲解 高级C语言讲解 Rust语言讲解 文章目录 编译并运行第一个C程序一、编译上面的程序二、运行上面的程序…