动手学深度学习-3.2 线性回归的从0开始

以下是代码的逐段解析及其实际作用:


1. 环境设置与库导入

%matplotlib inline
import random
import torch
from d2l import torch as d2l
  • 作用
    • %matplotlib inline:在 Jupyter Notebook 中内嵌显示 matplotlib 图形。
    • random:生成随机索引用于数据打乱。
    • torch:PyTorch 深度学习框架。
    • d2l:《动手学深度学习》提供的工具函数库(如绘图工具)。

2. 生成合成数据

假设真实权重向量为 w true ∈ R n \mathbf{w}_{\text{true}} \in \mathbb{R}^n wtrueRn,偏置为 b true b_{\text{true}} btrue,噪声为高斯分布 ϵ ∼ N ( 0 , σ 2 ) \epsilon \sim \mathcal{N}(0, \sigma^2) ϵN(0,σ2),则合成数据生成公式为:
y = X w true + b true + ϵ \mathbf{y} = \mathbf{X} \mathbf{w}_{\text{true}} + b_{\text{true}} + \epsilon y=Xwtrue+btrue+ϵ
其中:

  • X ∈ R m × n \mathbf{X} \in \mathbb{R}^{m \times n} XRm×n:输入特征矩阵( m m m 个样本, n n n 个特征)。
  • w true ∈ R n \mathbf{w}_{\text{true}} \in \mathbb{R}^n wtrueRn:真实权重向量。
  • ϵ ∈ R m \epsilon \in \mathbb{R}^m ϵRm:噪声向量。
def synthetic_data(w, b, num_examples):  #@save
    """生成y=Xw+b+噪声"""
    X = torch.normal(0, 1, (num_examples, len(w)))  # 生成标准正态分布的输入特征 num_examples行,len(w)列
    y = torch.matmul(X, w) + b                      # 计算线性输出 y = Xw + b
    y += torch.normal(0, 0.01, y.shape)             # 添加高斯噪声
    return X, y.reshape((-1, 1))                    # y行数不定(值为-1,列数为1)

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

生成的函数是一个二维线性回归模型,其数学表达式为:

y = w 1 x 1 + w 2 x 2 + b + ϵ y = w_1 x_1 + w_2 x_2 + b + \epsilon y=w1x1+w2x2+b+ϵ

其中:

  • 权重 w = [ w 1 , w 2 ] = [ 2 , − 3.4 ] \mathbf{w} = [w_1, w_2] = [2, -3.4] w=[w1,w2]=[2,3.4],由 true_w 定义。
  • 偏置 b = 4.2 b = 4.2 b=4.2,由 true_b 定义。
  • 噪声 ϵ ∼ N ( 0 , 0.0 1 2 ) \epsilon \sim \mathcal{N}(0, 0.01^2) ϵN(0,0.012),即均值为 0、标准差为 0.01 的高斯噪声。

展开为标量形式:
y i = 2 ⋅ x i 1 − 3.4 ⋅ x i 2 + 4.2 + ϵ i ( i = 1 , 2 , … , 1000 ) y_i = 2 \cdot x_{i1} - 3.4 \cdot x_{i2} + 4.2 + \epsilon_i \quad (i = 1, 2, \dots, 1000) yi=2xi13.4xi2+4.2+ϵi(i=1,2,,1000)


3. 数据可视化

d2l.set_figsize()
d2l.plt.scatter(features[:, (1)].detach().numpy(), labels.detach().numpy(), 1);
  • 绘制第二个特征(features[:,1] => n行第1列)与标签 labels 的散点图。

4. 定义数据迭代器

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)  # 打乱索引顺序
    for i in range(0, num_examples, batch_size):
        batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])
        yield features[batch_indices], labels[batch_indices]  # 生成小批量数据
  • 作用
    • 将数据集按 batch_size 划分为小批量,并随机打乱顺序。
    • 使用生成器 (yield) 逐批返回数据,避免一次性加载全部数据到内存。

5. 初始化模型参数

w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
  • 初始化w和b的值
    • w:从均值为 0、标准差为 0.01 的正态分布中初始化权重,启用梯度追踪。
    • b:初始化为 0 的偏置,启用梯度追踪。
    • 参数需梯度追踪以支持反向传播。

6. 定义模型、损失函数和优化器

def linreg(X, w, b):  #@save
    """线性回归模型"""
    return torch.matmul(X, w) + b

def squared_loss(y_hat, y):  #@save
    """均方损失"""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2  # 除以2便于梯度计算

def sgd(params, lr, batch_size):  #@save
    """小批量随机梯度下降"""
    with torch.no_grad():  # 禁用梯度计算
        for param in params:
            param -= lr * param.grad / batch_size  # 参数更新
            param.grad.zero_()                     # 梯度清零
  • linreg:模型预测值 y ^ \hat{\mathbf{y}} y^ 的矩阵形式为:
    y ^ = X w + b \hat{\mathbf{y}} = \mathbf{X} \mathbf{w} + b y^=Xw+b
    其中:

    • w ∈ R n \mathbf{w} \in \mathbb{R}^n wRn:待学习的权重向量。
    • b ∈ R b \in \mathbb{R} bR:待学习的偏置。
  • squared_loss:损失函数的矩阵形式为:
    L = 1 2 ∥ y ^ − y ∥ 2 L = \frac{1}{2} \| \hat{\mathbf{y}} - \mathbf{y} \|^2 L=21y^y2

    L ( w , b ) = 1 2 m ∥ X w + b − y ∥ 2 L(\mathbf{w}, b) = \frac{1}{2m} \| \mathbf{X} \mathbf{w} + b - \mathbf{y} \|^2 L(w,b)=2m1Xw+by2
    展开后:
    L ( w , b ) = 1 2 m ( X w + b 1 − y ) ⊤ ( X w + b 1 − y ) L(\mathbf{w}, b) = \frac{1}{2m} (\mathbf{X} \mathbf{w} + b \mathbf{1} - \mathbf{y})^\top (\mathbf{X} \mathbf{w} + b \mathbf{1} - \mathbf{y}) L(w,b)=2m1(Xw+b1y)(Xw+b1y)

  • sgd:小批量随机梯度下降优化器,

    • 对权重 w \mathbf{w} w 的梯度
      ∇ w L = 1 m X ⊤ ( X w + b 1 − y ) \nabla_{\mathbf{w}} L = \frac{1}{m} \mathbf{X}^\top (\mathbf{X} \mathbf{w} + b \mathbf{1} - \mathbf{y}) wL=m1X(Xw+b1y)

    • 对偏置 b b b 的梯度
      ∇ b L = 1 m 1 ⊤ ( X w + b 1 − y ) , 1 为单位列向量 \nabla_{b} L = \frac{1}{m} \mathbf{1}^\top (\mathbf{X} \mathbf{w} + b \mathbf{1} - \mathbf{y}),\mathbf{1} 为单位列向量 bL=m11(Xw+b1y)1为单位列向量

    • 使用学习率 η \eta η,参数更新公式为:
      w ← w − η ∇ w L b ← b − η ∇ b L \mathbf{w} \leftarrow \mathbf{w} - \eta \nabla_{\mathbf{w}} L\\ b \leftarrow b - \eta \nabla_{b} L wwηwLbbηbL


7. 训练循环

lr = 0.03
num_epochs = 3
batch_size = 10  # 需补充定义(原代码未显式定义)

for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):
        l = loss(net(X, w, b), y)  # 计算小批量损失
        l.sum().backward()         # 反向传播计算梯度
        sgd([w, b], lr, batch_size) # 更新参数
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)
        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')
  • 作用

    • 外层循环:遍历训练轮次 (num_epochs)。
    • 内层循环:按小批量遍历数据,计算损失并反向传播。
    • l.sum().backward():将小批量损失求和后反向传播,计算梯度。
    • sgd:根据梯度更新参数,梯度需除以 batch_size 以保持学习率一致性。
    • 每个 epoch 结束后,计算并打印整体训练损失。
    • mean()函数计算平均值
  • 梯度下降

  l.sum().backward()  # 反向传播计算梯度
  sgd([w, b], lr, batch_size)  # 更新参数
  • 小批量梯度计算公式:
    ∇ w L batch = 1 batch_size X batch ⊤ ( X batch w + b − y batch ) \nabla_{\mathbf{w}} L_{\text{batch}} = \frac{1}{\text{batch\_size}} \mathbf{X}_{\text{batch}}^\top (\mathbf{X}_{\text{batch}} \mathbf{w} + b - \mathbf{y}_{\text{batch}}) wLbatch=batch_size1Xbatch(Xbatchw+bybatch)
    ∇ b L batch = 1 batch_size 1 ⊤ ( X batch w + b − y batch ) \nabla_{b} L_{\text{batch}} = \frac{1}{\text{batch\_size}} \mathbf{1}^\top (\mathbf{X}_{\text{batch}} \mathbf{w} + b - \mathbf{y}_{\text{batch}}) bLbatch=batch_size11(Xbatchw+bybatch)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/964086.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【汽车电子软件架构】AutoSAR从放弃到入门专栏导读

本文是汽车电子软件架构:AutoSAR从放弃到入门专栏的导读篇。文章延续专栏文章的一贯作风,从概念与定义入手,希望读者能对AutoSAR架构有一个整体的认识,然后对专栏涉及的文章进行分类与链接。本文首先从AutoSAR汽车软件架构的概念&…

八、Spring Boot 日志详解

目录 一、日志的用途 二、日志使用 2.1 打印日志 2.1.1 在程序中获取日志对象 2.1.2 使用日志对象打印日志 2.2、日志框架介绍 2.2.1 门面模式(外观模式) 2.2.2 门面模式的实现 2.2.3 SLF4J 框架介绍 2.3 日志格式的说明 2.4 日志级别 2.4.1 日志级别的分类 2.4.2…

【Linux】24.进程信号(1)

文章目录 1. 信号入门1.1 进程与信号的相关知识1.2 技术应用角度的信号1.3 注意1.4 信号概念1.5 信号处理常见方式概览 2. 产生信号2.1 通过终端按键产生信号2.2 调用系统函数向进程发信号2.3 由软件条件产生信号2.4 硬件异常产生信号2.5 信号保存 3. 阻塞信号3.1 信号其他相关…

[Proteus仿真]基于51单片机的智能温控系统

[Proteus仿真]基于51单片机的智能温控系统 基于51单片机的智能温控系统:DS18B20精准测温LCD1602双屏显示三键设置上下限声光报警,支持温度校准、抗干扰设计、阈值记忆。 一.仿真原理图 ​​ 二.模块介绍 温度采集模块(DS18B20&#xff0…

Windows下怎么安装FFFmpeg呢?

在Windows下使用Open-webui报错,说Couldnt find ffmpeg or avconv,解决open-webui报错Couldn‘t find ffmpeg or avconv-CSDN博客于是尝试解决问题,那么Windows下怎么安装FFFmpeg呢? 尝试了两种方法。 第一种方法pip安装(失败&…

C基础寒假练习(2)

一、输出3-100以内的完美数&#xff0c;(完美数&#xff1a;因子和(因子不包含自身)数本身 #include <stdio.h>// 函数声明 int isPerfectNumber(int num);int main() {printf("3-100以内的完美数有:\n");for (int i 3; i < 100; i){if (isPerfectNumber…

【智力测试——二分、前缀和、乘法逆元、组合计数】

题目 代码 #include <bits/stdc.h> using namespace std; using ll long long; const int mod 1e9 7; const int N 1e5 10; int r[N], c[N], f[2 * N]; int nr[N], nc[N], nn, nm; int cntr[N], cntc[N]; int n, m, t;void init(int n) {f[0] f[1] 1;for (int i …

Vue-el挂载点

目录 一、Vue中的el挂载点是什么&#xff1f;二、Vue实例的作用范围是什么呢&#xff1f;三、Vue中的el是否可以挂载哪些选择器&#xff1f;四、el是否可以设置其他的dom元素呢&#xff1f; 一、Vue中的el挂载点是什么&#xff1f; el是用来设置Vue实例挂载&#xff08;管理&a…

c语言练习【实现终端功能、dup2实现文件拷贝、read write文件加载到链表】

练习1&#xff1a;实现终端功能 请实现一个终端的功能&#xff0c;注意需要带有cd功能 #include <stdio.h> #include <stdlib.h> #include <string.h> #include <unistd.h> #include <sys/types.h> #include <sys/wait.h>#define MAX_CM…

MySQL数据库环境搭建

下载MySQL 官网&#xff1a;https://downloads.mysql.com/archives/installer/ 下载社区版就行了。 安装流程 看b站大佬的视频吧&#xff1a;https://www.bilibili.com/video/BV12q4y1477i/?spm_id_from333.337.search-card.all.click&vd_source37dfd298d2133f3e1f3e3c…

1.2 基于深度学习的底层视觉技术

文章目录 高层视觉任务与底层视觉任务深度神经网络相对于传统方法的优势 高层视觉任务与底层视觉任务 计算机视觉中的任务包含高层视觉任务&#xff0c;底层视觉任务。高层视觉任务是处理语义级别相关的任务&#xff0c;例如图像分类、目标检测、图像分割等。底层视觉任务处理与…

YOLOV11-1:YoloV11-安装和CLI方式训练模型

YoloV11-安装和CLI方式训练模型 1.安装和运行1.1安装的基础环境1.2安装yolo相关组件1.3命令行方式使用1.3.1 训练1.3.2 预测 本文介绍yoloV11的安装和命令行接口 1.安装和运行 1.1安装的基础环境 GPU环境&#xff0c;其中CUDA是12.4版本 1.2安装yolo相关组件 # 克隆github…

后盾人JS -- 原型

没有原型的对象 也有没有原型的对象 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document<…

【NEXT】网络编程——上传文件(不限于jpg/png/pdf/txt/doc等),或请求参数值是file类型时,调用在线服务接口

最近在使用华为AI平台ModelArts训练自己的图像识别模型&#xff0c;并部署了在线服务接口。供给客户端&#xff08;如&#xff1a;鸿蒙APP/元服务&#xff09;调用。 import核心能力&#xff1a; import { http } from kit.NetworkKit; import { fileIo } from kit.CoreFileK…

游戏引擎 Unity - Unity 下载与安装

Unity Unity 首次发布于 2005 年&#xff0c;属于 Unity Technologies Unity 使用的开发技术有&#xff1a;C# Unity 的适用平台&#xff1a;PC、主机、移动设备、VR / AR、Web 等 Unity 的适用领域&#xff1a;开发中等画质中小型项目 Unity 适合初学者或需要快速上手的开…

结构体DMA串口接收比特错位

发送&#xff1a; 显示&#xff1a; uint16_t接收时候会比特错位。

在线知识库的构建策略提升组织信息管理效率与决策能力

内容概要 在线知识库作为现代企业信息管理的重要组成部分&#xff0c;具有显著的定义与重要性。它不仅为组织提供了一个集中存储与管理知识的平台&#xff0c;还能够有效提升信息检索的效率&#xff0c;促进知识的创新和利用。通过这样的知识库&#xff0c;企业可以更好地应对…

e2studio开发RA2E1(4)----GPIO输出

e2studio开发RA2E1.4--GPIO输出 概述视频教学样品申请硬件准备参考程序源码下载新建工程工程模板保存工程路径芯片配置工程模板选择时钟设置GPIO口配置R_IOPORT_PinWrite()函数原型R_IOPORT_PortWrite()函数原型代码 概述 本篇文章主要介绍如何使用e2studio对瑞萨单片机进行GP…

【基于SprintBoot+Mybatis+Mysql】电脑商城项目之用户登录

&#x1f9f8;安清h&#xff1a;个人主页 &#x1f3a5;个人专栏&#xff1a;【Spring篇】【计算机网络】【Mybatis篇】 &#x1f6a6;作者简介&#xff1a;一个有趣爱睡觉的intp&#xff0c;期待和更多人分享自己所学知识的真诚大学生。 目录 &#x1f3af;1.登录-持久层 &…

51单片机看门狗系统

在 STC89C52 单片机中&#xff0c;看门狗控制寄存器的固定地址为 0xE1。此地址由芯片厂商在硬件设计时确定&#xff0c;但是它在头文件中并未给出&#xff0c;因此在使用看门狗系统时需要声明下这个特殊功能寄存器 sfr WDT_CONTR 0xE1; 本案将用一个小灯的工作状况来展示看门…