通俗易懂之线性回归时序预测PyTorch实践

线性回归(Linear Regression)是机器学习中最基本且广泛应用的算法之一。它不仅作为入门学习的经典案例,也是许多复杂模型的基础。本文将全面介绍线性回归的原理、应用,并通过一段PyTorch代码进行实践演示,帮助读者深入理解这一重要概念。

线性回归概述

线性回归是一种用于预测因变量(目标变量)与一个或多个自变量(特征变量)之间关系的统计方法。其目标是在数据点之间找到一条最佳拟合直线,使得预测值与实际值之间的误差最小。

基本形式

  • 简单线性回归:只有一个自变量。
  • 多元线性回归:包含多个自变量。

本文将聚焦于简单线性回归,即仅考虑一个自变量的情况。

线性回归的数学原理

模型表达式

简单线性回归的模型表达式为:

y = w x + b y = wx + b y=wx+b

其中:

  • y y y 是预测值。
  • x x x 是输入特征。
  • w w w 是权重(斜率)。
  • b b b 是偏置(截距)。

损失函数

为了衡量模型预测值与实际值之间的差异,通常使用均方误差(Mean Squared Error, MSE)作为损失函数:

Loss = 1 2 ∑ i = 1 N ( y i pred − y i ) 2 \text{Loss} = \frac{1}{2} \sum_{i=1}^{N} (y_i^{\text{pred}} - y_i)^2 Loss=21i=1N(yipredyi)2

优化算法

线性回归常用的优化算法是梯度下降(Gradient Descent)。通过计算损失函数关于参数 w w w b b b 的梯度,迭代更新参数以最小化损失。

更新规则如下:

w : = w − η ∂ Loss ∂ w w := w - \eta \frac{\partial \text{Loss}}{\partial w} w:=wηwLoss

b : = b − η ∂ Loss ∂ b b := b - \eta \frac{\partial \text{Loss}}{\partial b} b:=bηbLoss

其中 η \eta η 是学习率。

应用场景

线性回归在多个领域有广泛应用,包括但不限于:

  • 经济学:预测经济指标,如GDP、通货膨胀率等。
  • 工程学:估计物理量之间的关系,如材料强度与应力。
  • 医疗:预测疾病发展趋势,如体重增长与健康指标。
  • 金融:股价预测、风险评估等。

PyTorch实现线性回归

接下来,我们将通过一段PyTorch代码实践线性回归,从数据生成、模型训练到可视化展示,全面演示线性回归的实现过程。代码参考《深度学习框架PyTorch入门与实践》一书的实现,为了感受线性回归的计算过程,代码并未直接调用python中已有的线性回归库。

代码解析

首先,我们导入必要的库并设置随机种子以确保结果可复现。

import torch as t
import matplotlib.pyplot as plt
from IPython import display

t.manual_seed(1000)
数据生成函数

定义一个函数 get_fake_data 来生成假数据,这些数据遵循线性关系 y = 2 x + 3 y = 2x + 3 y=2x+3 并添加了一定的噪声。

def get_fake_data(batch_size=8):
    x = t.randn(batch_size, 1, dtype=float) * 20  # 随机生成x,范围扩大到[-20, 20]
    y = x * 2 + (1 + t.randn(batch_size, 1, dtype=float)) * 3  # y = 2x + 3 + 噪声
    return x, y

调用该函数生成一批数据并进行可视化。

x, y = get_fake_data()

plt.figure()
plt.scatter(x, y)
plt.show()
参数初始化

随机初始化权重 w w w 和偏置 b b b,并设置学习率 l r lr lr

# 随机初始化参数
w = t.rand(1, 1, requires_grad=True, dtype=float)
b = t.zeros(1, 1, requires_grad=True, dtype=float)

lr = 0.00001
训练过程

通过1000次迭代,使用梯度下降法优化参数 w w w b b b

for i in range(1000):
    x, y = get_fake_data()

    y_pred = x.mm(w) + b.expand_as(y)  # 预测值

    loss = 0.5 * (y_pred - y) ** 2  # 均方误差
    loss = loss.sum()

    loss.backward()  # 反向传播计算梯度

    # 更新参数
    w.data.sub_(lr * w.grad.data)
    b.data.sub_(lr * b.grad.data)

    # 梯度清零
    w.grad.data.zero_()
    b.grad.data.zero_()

    # 每100次迭代可视化一次结果
    if i % 100 == 0:
        display.clear_output(wait=True)
        x_plot = t.arange(0, 20, dtype=float).view(-1, 1)
        y_plot = x_plot.mm(w) + b.expand_as(x_plot)
        plt.plot(x_plot.data, y_plot.data, label='Fitting Line')

        x2, y2 = get_fake_data(batch_size=20)
        plt.scatter(x2, y2, color='red', label='Data Points')

        plt.xlim(0, 20)
        plt.ylim(0, 41)
        plt.legend()
        plt.show()
        plt.pause(0.5)

可视化与训练过程

训练过程中,每隔100次迭代,会清除之前的输出,绘制当前拟合的直线与新生成的数据点。随着训练的进行,拟合线将逐渐接近真实的线性关系 y = 2 x + 3 y = 2x + 3 y=2x+3

以下是训练过程中的可视化效果示例:

在这里插入图片描述

注:实际运行代码时,图像会动态更新,展示拟合过程。

代码关键点解析

  1. 数据生成

    • 使用 torch.randn 生成标准正态分布的随机数,并通过线性变换获取 xy
    • 添加噪声使模型更贴近真实场景。
  2. 参数初始化

    • w 随机初始化,b 初始化为零。
    • requires_grad=True 表示在反向传播时需要计算梯度。
  3. 前向传播

    • 计算预测值 y_pred = x.mm(w) + b.expand_as(y)
    • 使用矩阵乘法 mm 实现线性变换。
  4. 损失计算

    • 采用均方误差损失函数。
    • loss.backward() 计算损失函数相对于参数的梯度。
  5. 参数更新

    • 使用学习率 lr 按梯度方向更新参数。
    • data.sub_ 进行原地更新,避免梯度计算图的干扰。
  6. 梯度清零

    • 每次参数更新后,需要清零梯度 w.grad.data.zero_()b.grad.data.zero_(),以防止梯度累积。
  7. 可视化

    • 使用 matplotlib 绘制拟合线和数据点。
    • display.clear_output(wait=True) 清除之前的图像,避免图形堆积。
    • plt.pause(0.5) 控制图像更新速度。

总结

本文从线性回归的基本概念出发,详细介绍了其数学原理和应用场景,并通过一段PyTorch代码演示了线性回归模型的实现过程。从数据生成、参数初始化、模型训练到结果可视化,全面展示了线性回归的实际应用。通过这种实例讲解,读者不仅能够理解线性回归的理论基础,还能掌握其在深度学习框架中的具体实现方法。

线性回归作为机器学习的基础模型,虽然简单,但其思想却深刻影响着更加复杂的算法和模型。在实际应用中,理解并掌握线性回归对于进一步学习和开发更加复杂的机器学习模型具有重要意义。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/950407.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

分布式主键ID生成方式-snowflake雪花算法

这里写自定义目录标题 一、业务场景二、技术选型1、UUID方案2、Leaf方案-美团(基于数据库自增id)3、Snowflake雪花算法方案 总结 一、业务场景 大量的业务数据需要保存到数据库中,原来的单库单表的方式扛不住大数据量、高并发,需…

在 C# 中显示动画 GIF 并在运行时更改它们

您可以通过将按钮、图片框、标签或其他控件的Image属性设置为 GIF 文件 来显示动画 GIF 。(如果您在窗体的BackgroundImage属性中显示一个,则不会获得动画。) 有几种方法可以在运行时更改 GIF。 首先,您可以将 GIF 添加为资源。…

【技术支持】安卓无线adb调试连接方式

Android 10 及更低版本,需要借助 USB 手机和电脑需连接在同一 WiFi 下;手机开启开发者选项和 USB 调试模式,并通过 USB 连接电脑(即adb devices可以查看到手机);设置手机的监听adb tcpip 5555;拔掉 USB 线…

【网络】计算机网络的分类 局域网 (LAN) 广域网 (WAN) 城域网 (MAN)个域网(PAN)

局域网是通过路由器接入广域网的 分布范围 局域网Local Area Network:小范围覆盖,速度高,延迟低(办公室,家庭,校园,网络) 广域网Wide Area Network 大范围覆盖,速度相对低,延迟高…

scanf:数据之舟的摆渡人,静卧输入港湾的诗意守候

大家好啊,我是小象٩(๑ω๑)۶ 我的博客:Xiao Xiangζั͡ޓއއ 很高兴见到大家,希望能够和大家一起交流学习,共同进步。* 这一节我们主要来学习scanf的基本用法,了解scanf返回值,懂得scanf占位符和赋值…

win10 gt520+p106双卡测试

安装391.35驱动失败,虽然gpuz和设备管理器显示正常但没有nvidia控制面板 重启进安全模式,ddu卸载,再次重启到安全模式,安装391.01驱动,显示3dvision安装失败,重启再看已经有nvidia控制面板了 修改p106注册表 AdapterType 1 EnableMsHybrid 1 计算机\HKEY_LOCAL_MACHINE\SYSTE…

C# OpenCV机器视觉:霍夫变换

在一个阳光灿烂得近乎放肆的午后,阿强的实验室就像被施了魔法的科学城堡,到处闪耀着神秘的科技光芒。阿强呢,像个即将踏上惊险征程的探险家,一屁股坐在那堆满奇奇怪怪设备的桌前,眼神中透露出按捺不住的兴奋劲儿&#…

【深度学习基础】线性神经网络 | 线性回归的简洁实现

【作者主页】Francek Chen 【专栏介绍】 ⌈ ⌈ ⌈PyTorch深度学习 ⌋ ⌋ ⌋ 深度学习 (DL, Deep Learning) 特指基于深层神经网络模型和方法的机器学习。它是在统计机器学习、人工神经网络等算法模型基础上,结合当代大数据和大算力的发展而发展出来的。深度学习最重…

工业级手持地面站(支持Android和IOS)技术详解!

一、硬件平台的选择 无人机遥控器为了支持Android和iOS系统,通常会选择高性能的处理器和操作系统作为硬件基础。例如,一些高端遥控器可能采用基于ARM架构的高性能处理器,这些处理器能够高效地运行Android或iOS操作系统,并提供足够…

CatLog的使用

一 CatLog的简介 1.1 作用 CAT(Central Application Tracking) 是基于 Java 开发的实时应用监控平台,为美团点评提供了全面的实时监控告警服务。 1.2 组成部分 1.2.1 Transaction 1.Transaction 适合记录跨越系统边界的程序访问行为&a…

vue elementui 大文件进度条下载

下载进度条 <el-card class"box-card" v-if"downloadProgress > 0"><div>正在下载文件...</div><el-progress :text-inside"true" :stroke-width"26" :percentage"downloadProgress" status"…

TensorRT-LLM中的MoE并行推理

2种并行方式&#xff1a; moe_tp_size&#xff1a;按照维度切分&#xff0c;每个GPU拥有所有Expert的一部分权重。 moe_ep_size: 按照Expert切分&#xff0c;每个GPU有用一部分Expert的所有权重。 二者可以搭配一起使用。 限制&#xff1a;二者的乘积&#xff0c;必须等于模…

计算机的错误计算(二百零五)

摘要 基于一位读者的问题&#xff0c;提出题目&#xff1a;能用数值计算证明 吗&#xff1f;请选用不同的点&#xff08;即差别大的数&#xff09;与不同的精度。实验表明&#xff0c;大模型理解了题意。但是&#xff0c;其推理能力值得商榷。 例1. 就摘要中问题&#xff0…

关于TCP/IP五层结构的理解

关于TCP/IP五层结构的理解 TCP/IP五层模型 是目前被广泛采用的一种模型,我们可以将 TCP / IP 模型看作是 OSI 七层模型的精简版本&#xff0c;由以下 5 层组成&#xff1a; 1. 应用层&#xff1a;应用层是体系结构中的最高层&#xff0c;定义了应用进程间通信和交互的规则。本…

Unity3D仿星露谷物语开发19之库存栏丢弃及交互道具

1、目标 从库存栏中把道具拖到游戏场景中&#xff0c;库存栏中道具数相应做减法或者删除道具。同时在库存栏中可以交换两个道具的位置。 2、UIInventorySlot设置Raycast属性 在UIInventorySlot中&#xff0c;我们只希望最外层的UIInventorySlot响应Raycast&#xff0c;他下面…

Sprint Boot教程之五十:Spring Boot JpaRepository 示例

Spring Boot JpaRepository 示例 Spring Boot建立在 Spring 之上&#xff0c;包含 Spring 的所有功能。由于其快速的生产就绪环境&#xff0c;使开发人员能够直接专注于逻辑&#xff0c;而不必费力配置和设置&#xff0c;因此如今它正成为开发人员的最爱。Spring Boot 是一个基…

C++ STL map和set的使用

序列式容器和关联式容器 想必大家已经接触过一些容器如&#xff1a;list&#xff0c;vector&#xff0c;deque&#xff0c;array&#xff0c;forward_list&#xff0c;string等&#xff0c;这些容器统称为系列容器。因为逻辑结构为线性的&#xff0c;两个位置的存储的值一般是…

人工智能及深度学习的一些题目(三)

1、【填空题】 使用RNNCTC模型进行语音识别&#xff0c;在产生预测输出时&#xff0c;对于输入的音频特征序列通过网络预测产生对应的字母序列&#xff0c;可以使用&#xff08; beamsearch &#xff09;算法进行最优路径搜索。 2、【填空题】 逻辑回归模型属于有监督学习中的&…

《C++11》右值引用深度解析:性能优化的秘密武器

C11引入了一个新的概念——右值引用&#xff0c;这是一个相当深奥且重要的概念。为了理解右值引用&#xff0c;我们需要先理解左值和右值的概念&#xff0c;然后再理解左值引用和右值引用。本文将详细解析这些概念&#xff0c;并通过实例进行说明&#xff0c;以揭示右值引用如何…

cp命令详解

&#x1f3dd;️专栏&#xff1a;计算机操作系统 &#x1f305;主页&#xff1a;猫咪-9527主页 “欲穷千里目&#xff0c;更上一层楼。会当凌绝顶&#xff0c;一览众山小。” 目录 1. 基本功能 2. 命令语法 3. 常用选项 4. 常见用法示例 4.1 复制单个文件 4.2 递归复制目录…