深度学习与神经网络Pytorch版 3.2 线性回归从零开始实现 1.生成数据集

3.2 线性回归从零开始实现

目录

3.2 线性回归从零开始实现

一 ,简介

1. 原理

2. 步骤

3. 优缺点

4. 应用场景

二 ,代码展现

1. 生成数据集(完整代码)

2. 各个函数解析

2.1 torch.normal()函数

2.2 torch.matmul()函数

2.3 d2l.plt.scatter()函数

三 ,总结


一 ,简介

1. 原理

深度学习线性回归的原理是基于神经网络和线性回归的结合。它使用神经网络来构建一个复杂的非线性模型,同时保持线性回归的简单性和可解释性。

在深度学习线性回归中,通常使用全连接神经网络(Fully Connected Neural Network)作为基础结构。输入数据经过一系列的线性变换和非线性激活函数,最终输出预测结果。与传统的线性回归不同,深度学习线性回归可以自动学习特征之间的复杂交互和组合,而不需要手动选择或设计特征。

深度学习线性回归的训练过程与传统神经网络的训练过程类似,使用梯度下降算法优化模型的参数,以最小化预测误差(如均方误差)。在训练过程中,通过反向传播算法计算梯度,并使用优化器(如Adam、SGD等)更新权重和偏置项。

深度学习线性回归的优点是可以处理高维、复杂的非线性数据,并且具有自动特征选择和组合的能力。然而,与传统的线性回归相比,深度学习线性回归需要更多的参数和计算资源,并且可能更容易过拟合。因此,在选择是否使用深度学习线性回归时,需要根据具体问题和数据集的特点进行权衡。

2. 步骤

线性回归从零开始实现步骤包括以下内容:

  1. 导入必要的库:在Python中,需要导入numpy库来处理数据和计算,以及matplotlib库来绘制数据和结果。
  2. 生成数据集:根据实际问题,可以使用随机数生成器生成一组训练数据集,包括输入特征和对应的标签。也可以使用真实数据集进行训练和测试。
  3. 初始化模型参数:为模型权重和偏置项设置初始值,这些初始值可以是随机数或基于先验知识的值。
  4. 定义模型:根据线性回归模型的公式,可以使用numpy的矩阵运算来构建模型。模型可以表示为y = w * x + b,其中x是输入特征,y是对应的标签,w是权重,b是偏置项。
  5. 计算损失函数:损失函数用于衡量模型的预测值与真实值之间的差距。对于线性回归问题,常用的损失函数是均方误差(MSE)。
  6. 执行梯度下降算法:梯度下降算法用于更新模型的参数以最小化损失函数。在每一步迭代中,根据梯度下降公式计算参数的更新方向和步长,并更新参数的值。
  7. 训练:重复执行步骤5和6,直到达到预设的迭代次数或损失函数达到可接受的值。
  8. 评估模型:使用测试数据集评估模型的性能,计算模型的预测值与真实值之间的误差或准确率等指标。
  9. 优化和调整:根据评估结果对模型进行调整,例如调整参数、增加特征或使用正则化等方法来提高模型的性能。
  10. 应用模型进行预测:将新数据输入到模型中进行预测,得到预测结果。

以上是线性回归从零开始实现的基本步骤,具体实现细节可能会根据问题和数据集的不同而有所差异。

3. 优缺点

线性回归的优点:

  1. 简单易行:线性回归模型简单易懂,实现起来也相对容易。
  2. 计算效率高:由于模型简单,计算复杂度较低,因此在线性回归中,无论是训练还是预测,计算速度都比较快。
  3. 可解释性强:线性回归模型可以给出每个特征的权重,这有助于理解特征对目标变量的影响程度。
  4. 适合处理线性关系:线性回归适合处理因变量和自变量之间存在线性关系的情况。
  5. 模型稳定性好:线性回归模型相对稳定,对异常值和噪声的鲁棒性较好。

然而,线性回归也存在一些缺点:

  1. 假设限制:线性回归基于一些假设,如误差项的独立性、同方差性、无序列相关性和常数方差等。在实际应用中,这些假设可能不成立,导致模型误判。
  2. 欠拟合与过拟合:如果线性模型过于简单(即过于欠拟合),它可能无法捕获数据的复杂模式;而如果模型过于复杂(即过拟合),它可能会捕获到数据中的噪声和无关紧要的信息。
  3. 无法处理非线性关系:对于非线性关系的数据,线性回归可能无法给出很好的预测。
  4. 对异常值敏感:如果数据集中存在异常值,线性回归模型的预测结果可能会受到影响。
  5. 特征选择困难:对于特征之间的交互和特征选择,线性回归模型可能会遇到困难。

4. 应用场景

线性回归的应用场景包括但不限于:

  1. 预测:当因变量是连续变量,并且与其影响因素有线性关系时,可以用线性回归进行建模。例如,预测信用卡用户的生命周期价值,可以基于用户所在小区的平均收入、年龄、学历、收入等因素进行线性回归建模。
  2. 模型解释:当需要理解自变量与因变量之间的关系时,可以通过建立线性回归模型,例如决策树、线性回归等模型,以自变量作为输入变量,以因变量作为目标变量进行建模,以此了解黑盒模型的运作机制,并对其作出解释。
  3. 全量实验效果评估:全量实验评估是指当在时间点时,对全量用户加入干预策略,然后评估策略所带来的影响。进行评估时,核心是要剥离其他因素,对实验效果进行评估,线性回归就能解决这个问题。
  4. AB实验:在AB实验中,假定有两组无差异的用户群体和,以作为实验组对其施加策略干预,作为对照组不采取施加任何策略,来评估实验对观测变量的影响。可以通过t或z检验来得到结果,当然也可以建立线性回归模型 ,为是否为实验组的哑变量(当策略变多时,也可为分类变量),通过检验参数的显著性即可得到策略的效果。
  5. 预测疾病发生概率:医院可以根据患者的病历数据(如体检指标、药物复用情况、平时的饮食习惯等)来预测某种疾病发生的概率。
  6. 预测用户支付转化率:网站可以根据访问的历史数据(包括新用户的注册量、老用户的活跃度、网站内容的更新频率等)来预测用户的支付转化率。

以上只是部分应用场景,线性回归模型的应用非常广泛,具体应用取决于数据的特征和业务需求。

二 ,代码展现

1. 生成数据集(完整代码)

# 线性回归从零开始实现
# 生成数据集

# 导入必要的库
import matplotlib.pyplot as plt
import random
import torch
from d2l import torch as d2l


# 定义一个生成合成数据的函数
def synthetic_data(w, b, num_examples):    # 函数参数包括权重w、偏置b和数据点数量num_examples
    # 生成y=Xw+b+噪声满足线性关系y=Xw+b的数据,并添加噪声
    X = torch.normal(0, 1, (num_examples, len(w)))  # 创建一个形状为(num_examples, len(w))的张量X,元素值为从标准正态分布中抽取的随机数
    y = torch.matmul(X, w) + b  # 使用矩阵乘法计算y的值,y = X * w + b
    y += torch.normal(0, 0.01, y.shape)  # 在y的值上添加从标准正态分布中抽取的随机噪声,噪声的标准差为0.01
    return X, y.reshape((-1, 1))  # 返回X和y。y被重新整形为(-1, 1)的形状,这是因为matplotlib在绘图时需要这样的形状


# 定义真实的权重和偏置值
true_w = torch.tensor([2, -3.4])  # 真实的权重w为[2, -3.4]的张量
true_b = 4.2  # 真实的偏置b为4.2的标量

# 使用上面定义的函数生成数据集
features, labels = synthetic_data(true_w, true_b, 1000)  # 生成1000个数据点作为训练或测试样本,特征为X,标签为y(即labels)

print('features:', features[0],'\nlabel:', labels[0])
d2l.set_figsize()
d2l.plt.scatter(features[:, (1)].detach().numpy(), labels.detach().numpy(), 1) 
# 这行代码也是从d2l库中调用的。它使用散点图来可视化特征和标签。#
# features[:, (1)].detach().numpy()选取了所有数据点的第二个特征(索引为1,因为索引是从0开始的)并转换为NumPy数组。
# #.detach()是PyTorch中的方法,用于从计算图中分离张量,这样张量就不会追踪其历史计算,这在进行绘图等操作时是很有用的。
# labels.detach().numpy()将标签转换为NumPy数组。这里的1表示散点的大小。
plt.show()

2. 各个函数解析

2.1 torch.normal()函数
normal(mean, std, *, generator=None, out=None)

参数说明

  • mean (Tensor): 每个输出元素的均值。它是一个张量,其中包含各个分布的均值。
  • std (Tensor): 每个输出元素的标准差。它也是一个张量,其中包含各个分布的标准差。
  • *: 表示后面的参数是关键字参数。
  • generator: 可选参数,一个伪随机数生成器。
  • out: 可选参数,输出张量。

注意事项

  1. meanstd的形状不必匹配,但它们的元素总数必须相同。如果形状不匹配,将使用mean的形状作为返回输出张量的形状。
  2. 如果std是一个CUDA tensor,该函数将同步其设备与CPU。
2.2 torch.matmul()函数
matmul(input, other, *, out=None) -> Tensor

参数说明:

  • input (Tensor): 输入张量。
  • other (Tensor): 另一个张量。
  • *: 表示后面的参数是关键字参数。
  • out (Tensor, optional): 可选参数,输出张量。

行为取决于张量的维度如下:

  • 如果两个张量都是一维的,返回点积(标量)。
  • 如果两个参数都是二维的,返回矩阵-矩阵乘积。
  • 如果第一个参数是一维的,而第二个参数是二维的,为了矩阵乘法,向其维度添加一个1。矩阵乘法之后,添加的维度被移除。
  • 如果第一个参数是二维的,而第二个参数是一维的,返回矩阵-向量乘积。
  • 如果两个参数都至少是一维的,并且至少有一个参数是N维的(其中N>2),则返回批处理矩阵乘法。如果第一个参数是一维的,为了批处理矩阵乘法,向其维度添加一个1,然后在批处理矩阵乘法之后移除它。如果第二个参数是一维的,为了批处理矩阵乘法,向其维度添加一个1,然后在批处理矩阵乘法之后移除它。非矩阵(即批处理)维度是广播的(因此必须可广播)。例如,如果input是一个(j × 1 × n × n)张量,而other是一个(k × n × n)张量,则out将是一个(j × k × n × n)张量。
2.3 d2l.plt.scatter()函数
scatter(x, y, s=None, c=None, marker=None, cmap=None, norm=None, vmin=None, vmax=None, alpha=None, linewidths=None, *, edgecolors=None, plotnonfinite=False, data=None, **kwargs)

参数说明:

  • x, y:这些是您要在散点图中表示的数据点的x和y坐标。

  • s:散点的面积,以像素为单位。这通常用于根据数据点的值进行大小调整。

  • c:用于颜色映射的单个值或数组,通常表示颜色或数据点的值。

  • marker:散点的形状。例如,'o'表示圆形,'.'表示点,','表示像素等。

  • cmap:颜色映射对象或名称。这决定了如何根据c参数的值映射颜色。

  • norm:用于映射到给定范围的归一化对象。这通常与cmap一起使用,以控制颜色映射的范围。

  • vmin, vmax:这些参数指定了归一化对象的下限和上限。它们与norm一起使用来控制颜色映射的范围。

  • alpha:散点的透明度。值范围从0(完全透明)到1(完全不透明)。

  • linewidths:用于绘制边框线的宽度。当不为None时,这会使散点变为带边框的圆圈。

  • edgecolors:用于边框线的颜色。这可以是单一的颜色或颜色数组,与数据点一一对应。

  • plotnonfinite:如果为True,则非有限数值的数据点将被绘制。默认为False。

  • data:提供给所有数据的原始数据的字典。这通常在传递给函数的数据不是直接参数时使用。

  • kwargs:其他关键字参数将传递给collections.PathCollection的构造函数,允许您自定义散点图的其他方面。例如,您可以指定label来在图例中标识这些点等。

三 ,总结

 这段代码的主要目的是生成数据集,并使用散点图可视化其特征和标签。通过这种方式,可以直观地观察到数据分布和特征之间的关系。此外,代码还演示了如何使用PyTorch进行矩阵运算和NumPy数组转换,以及如何使用d2l库中的函数进行绘图操作。

        之后我会更新,线性回归的读取数据集,初始化模型参数,定义模型,定义模型,定义损失函数,定义优化算法,训练等步骤。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/362712.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【教学类-44-54】20240201 德彪钢笔行书(实线字体)制作的数字描字帖

作品展示 背景需求: 找到了两款适合做数字描字贴的字体 【教学类-44-03】20240111阿拉伯数字字帖的字体(三)——德彪钢笔行书(实线字体)和print dashed(虚线字体)-CSDN博客文章浏览阅读1.1k次…

【HarmonyOS】鸿蒙开发之HTTP网络请求——第5章

HTTP网络请求封装 network/request.ets import { configInterface } from ./type import http from ohos.net.http import { getToken } from ../utils/storage//网络请求封装 export const request (config:configInterface)>{let httpRequest:http.HttpRequest http.c…

༺༽༾ཊ—Unity之-01-工厂方法模式—ཏ༿༼༻

首先创建一个项目, 在这个初始界面我们需要做一些准备工作, 建基础通用文件夹, 创建一个Plane 重置后 缩放100倍 加一个颜色, 任务:使用工厂方法模式 创建 飞船模型, 首先资源商店下载飞船模型&#xff0c…

二进制安全虚拟机Protostar靶场(5)堆的简单介绍以及实战 heap0

前言 这是一个系列文章,之前已经介绍过一些二进制安全的基础知识,这里就不过多重复提及,不熟悉的同学可以去看看我之前写的文章 什么是堆 堆是动态内存分配的区域,程序在运行时用来分配内存。它与栈不同,栈用于静态…

【Vue3+Vite】Vue3视图渲染技术 快速学习 第二期

文章目录 一、模版语法1.1 插值表达式和文本渲染1.1.1 插值表达式 语法1.1.2 文本渲染 语法 1.2 Attribute属性渲染1.3 事件的绑定 二、响应式基础2.1 响应式需求案例2.2 响应式实现关键字ref2.3 响应式实现关键字reactive2.4 扩展响应式关键字toRefs 和 toRef 三、条件和列表渲…

农业植保无人机行业研究:预计2025年市场规模可达115亿元

农业植保无人机行业市场投资前景现状如何?农业植保无人机市场,包括无人机自身技术、性能标准和植保标准。农业植保无人机应用植保机喷洒农药对我国而言,不仅具有很大的经济价值,还具有社会价值:农业植保机作业不仅有超高的工作效…

并网逆变器学习笔记8---平衡桥(独立中线模块)控制

参考文献:《带独立中线模块的三相四线制逆变器中线电压脉动抑制方法》---赵文心 一、独立中线模块的三相四线拓扑 独立中线模块是控制中线电压恒为母线一半,同时为零序电流ineu提供通路。不平衡负载的零序电流会导致中线电压脉动,因此需要控制…

【Android 字节码插桩】Gradle插件基础 Transform API的使用

前言 啪~我给大家开个会(手机扔桌子上) 什么叫做 客户无感的数据脱敏!? 师爷给翻译翻译什么叫做客户无感的数据脱敏? 什么特么的叫做客户无感数据脱敏? 举个栗子~ 客户端Sdk新升级了一个版本,增…

UnityShader(九)Unity中的基础光照(下)

目录 标准光照模型 自发光 高光反射 (1)Phong模型 (2)Blinn模型 漫反射 环境光 逐顶点还是逐像素 逐像素光照 逐顶点光照 总结 标准光照模型 光照模型有许多种,但在早期游戏引擎中,往往只使用一…

linux -- 并发 -- 并发来源与简单的解决并发的手段

互斥与同步 当多个执行路径并发执行时,确保对共享资源的访问安全是驱动程序员不得不面对的问题 互斥:对资源的排他性访问 同步:对进程执行的先后顺序做出妥善的安排 一些概念: 临界区:对共享的资源进行访问的代码片段…

1、缓存击穿背后的问题

当面试官问:你知道什么是缓存击穿吗,你们是如何解决的? 首先我们要了解什么是缓存击穿?以及缓存击穿会引发什么问题? 缓存击穿就是redis中的热点数据过期,缓存失效,导致大量的请求直接打到数据…

【免费分享】数据可视化-银行动态实时大屏监管系统,含源码

一、动态效果展示 1. 动态实时更新数据效果图 ​ 2. 鼠标右键切换主题 二、确定需求方案 1. 屏幕分辨率 这个案例的分辨率是16:9,最常用的的宽屏比。 根据电脑分辨率屏幕自适应显示,F11全屏查看; 2. 部署方式 B/S方式:支持…

使用了不受支持的协议。 ERR_SSL_VERSION_OR_CIPHER_MISMATCH的问题解决办法

windwos 2008 R2 使用IIS部署的项目申请使用https协议的时候,通过安全加密协议访问网站提示不受支持的协议 错误原因分析 这种错误通常表示客户端和服务器之间存在协议版本或加密套件不兼容导致在SSL(Secure Socket Layer) 1.协议版本不兼容&…

壹[1],Xamarin开发环境配置

1,环境 VS2022 注: 1,本来计划使用AndroidStudio,但是也是一堆莫名的配置让人搞得很神伤,还是回归C#。 2,MAUI操作类似,但是很多错误解来解去,且调试起来很卡。 3,最…

哪个牌子的头戴式耳机好?推荐性价比高的头戴式耳机品牌

随着科技的不断发展,耳机市场也呈现出百花齐放的态势,从高端的奢侈品牌到亲民的平价品牌,各种款式、功能的耳机层出不穷,而头戴式耳机作为其中的一员,凭借其优秀的音质和降噪功能,受到了广大用户的喜爱&…

C++文件操作(2)

文件操作(2) 1.二进制模式读取文本文件2.使用二进制读写其他类型内容3.fstream类4.文件的随机存取文件指针的获取文件指针的移动 1.二进制模式读取文本文件 用二进制方式打开文本存储的文件时,也可以读取其中的内容,因为文本文件…

Flask 入门3:Flask 请求上下文与请求

1. 前言 Flask 在处理请求与响应的过程: 首先我们从浏览器发送一个请求到服务端,由 Flask 接收了这个请求以后,这个请求将会由路由系统接收。然后在路由系统中,还可以挂入一些 “勾子”,在进入我们的 viewFunction …

【C++】开源:Windows图形库EasyX配置与使用

😏★,:.☆( ̄▽ ̄)/$:.★ 😏 这篇文章主要介绍Windows图形库EasyX配置与使用。 无专精则不能成,无涉猎则不能通。——梁启超 欢迎来到我的博客,一起学习,共同进步。 喜欢的朋友可以关注一下&#…

✅Redis 常见数据类型和应用场景(详解)

Redis 提供了丰富的数据类型,常见的有五种:String(字符串),Hash(哈希),List(列表),Set(集合)、Zset(有序集合&…

揭开时间序列的神秘面纱:特征工程的力量

目录 写在开头1. 什么是特征工程?1.1 特征工程的定义和基本概念1.2 特征工程在传统机器学习中的应用1.3 时间序列领域中特征工程的独特挑战和需求3. 时间序列数据的特征工程技术2.1 数据清洗和预处理2.1.1 缺失值处理2.1.2 异常值检测与处理2.2 时间特征的提取2.2.1 时间戳解析…