线形回归与小批量梯度下降实例

1、准备数据集

import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

#########################################################################
#################准备若干个随机的x和y#####################################
#########################################################################
np.random.seed(100)     #使用random.seed,设置一个固定的随机种子,
data_size = 150         # 数据集大小
x_range = 5             # x的范围
iteration_count = 100   # 迭代次数

# np.random.rand 是 NumPy 库中的一个函数,用于生成一个给定形状的数组,
# 数组中的元素是从一个均匀分布的样本中抽取的,这个均匀分布是在半开区间 [0, 1) 上。
# 这意味着产生的随机数将大于等于0且小于1。
                                                                                 
#随机生成data_size个横坐标x,范围在0到x_range之间
x=x_range * np.random.rand(data_size,1)

#生成带有噪音y数据,基本分布在y=2x+6的附近
y=2*x + 6 + np.random.randn(data_size,1)*0.3

plt.scatter(x,y,marker='x',color='green')

#########################################################################
#################将训练数据转为张量#######################################
#########################################################################
#将训练数据转为张量
tensorX = torch.from_numpy(x).float()
tensorY = torch.from_numpy(y).float()

#使用TensorDataset,将tensorX和tensorY组成训练集
dataset = TensorDataset(tensorX,tensorY)

#使用DataLoader,构造随机的小批量数据
dataloader=DataLoader(dataset,
                      batch_size = 20, #每一个小批量的数据规模是20
                      shuffle =True )  #随机打乱数据的顺序
print("dataloader len =%d" %(len(dataloader)))

for index,(data,label) in enumerate(dataloader):
    print("index=%d num = %d"%(index,len(data)))

2、线性回归模型的训练思路

2.1 初始化参数

设置是模型参数:权重w和偏置b

初始化为随机值

设置 requires_grad=True,PyTorch 将记录这些张量的操作历史,用于后续的自动求导

2.2 循环训练(epoch

epoch 变量用于控制整个训练过程的迭代轮数

在机器学习和深度学习中,“epoch” 是一个常用的术语,指的是在整个数据集上完整地运行一次(即正向传播和反向传播)训练算法的过程。

定义:一个 epoch 是指训练过程中,训练集中每个样本都被使用过一次来更新模型的权重。
训练过程:在训练一个模型时,通常会将数据集分成多个批次(batches)。每个批次包含一定数量的样本。一个 epoch 完成意味着所有批次都已经过模型处理。
迭代与epoch:在一个 epoch 内,模型可能会多次迭代,每次迭代处理一个批次的数据。因此,一个 epoch 包含多个迭代(iterations)。
目的:通过多个 epochs 的训练,模型可以逐渐学习数据集中的模式,从而提高其性能。
数量:训练一个模型所需的 epochs 数量取决于多种因素,包括数据集的大小、模型的复杂度以及问题的难度。有时可能只需要几个 epochs,而有时可能需要数百甚至数千个 epochs。
监控:在训练过程中,通常会监控每个 epoch 的性能指标(如损失函数的值或准确率),以评估模型的学习进度。
过拟合与欠拟合:如果训练过多的 epochs,模型可能会过拟合(即模型学习到了数据中的噪声而非潜在的模式),而训练不足的 epochs 则可能导致欠拟合(即模型未能捕捉到数据中的关键模式)。

2.3 数据加载

内层循环通过 dataloader 遍历训练数据集的小批量数据。dataloader 是一个数据加载器,通常由 DataLoader 类创建,用于批量加载数据。

2.4 前向传播

假设 tensorX 是当前批次的数据

tensorY 是对应的真实标签

使用当前参数 w 和 b 计算预测值 h=w*tensorX +b。

2.5 计算损失

计算预测值 h 和真实值 tensorY 之间的均方误差(MSE),并保存到 loss

loss = torch.mean((h - tensorY) ** 2)

2.6 反向传播

调用 loss.backward() 进行反向传播,计算损失关于参数 w 和 b 的梯度

设置了 requires_grad=True,PyTorch 将记录这些张量的操作历史并自动求导

2.7 更新参数

使用梯度下降算法更新参数 w 和 b。学习率设置为0.01

w.data -= 0.01 * w.grad.data

b.data -= 0.01 * b.grad.data

沿着当前小批量计算的得到的梯度(导数)更新w和b

如果导数为0,则w、b保存不变

2.8 梯度清零

在每次迭代后,需要清空参数的梯度信息,以便下一次迭代计算

3、线性回归模型的实现

# 待送代的参数为w和b
w = torch.randn(1,requires_grad=True)
b = torch.randn(1,requires_grad=True)

#进入模型的循环迭代
for epoch in range(1,iteration_count):#代表了整个训练数据集的迭代轮数
    # 在一个迭代轮次中,以小批量的方式,使用dataloader对数据
    # batch_index表示当前遍历的批次
    # data和label表示这个批次的训练数据和标记
    for batch_index,(data, label)in enumerate(dataloader):
        h = tensorX * w + b #计算当前直线的预测值,保存到h
        
        #计算预测值h和真实值y之间的均方误差,保存到loss中
        loss=torch.mean((h-tensorY)**2)
        #计算代价1oss关于参数w和b的偏导数,设置了 requires_grad=True,PyTorch 将记录这些张量的操作历史并自动求导
        loss.backward()
        
        #进行梯度下降,沿着梯度的反方向,更新w和b的值
        #沿着当前小批量计算的得到的梯度(导数)更新w和b
        #如果导数为0(Δw,Δb为0),则w、b保存不变
        w.data -=0.01 * w.grad.data
        b.data -=0.01 * b.grad.data
        
        print("epoch(%d) batch(%d) lossΔw,Δb,w,b, = %.3lf,%.3lf,%.3lf,%.3lf,%.3lf," %(epoch,batch_index,loss.item(),w.grad.data,b.grad.data,w.data,b.data))
        
        #清空张量w和b中的梯度信息,为下一次迭代做准备
        w.grad.zero_()
        b.grad.zero_()
        
        #每次迭代,都打印当前迭代的轮数epoch
        #数据的批次batch idx和loss损失值
        
        

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/952738.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

消息队列使用中防止消息丢失的实战指南

消息队列使用中防止消息丢失的实战指南 在分布式系统架构里,消息队列起着举足轻重的作用,它异步解耦各个业务模块,提升系统整体的吞吐量与响应速度。但消息丢失问题,犹如一颗不定时炸弹,随时可能破坏系统的数据一致性…

【优选算法篇】:深入浅出位运算--性能优化的利器

✨感谢您阅读本篇文章,文章内容是个人学习笔记的整理,如果哪里有误的话还请您指正噢✨ ✨ 个人主页:余辉zmh–CSDN博客 ✨ 文章所属专栏:优选算法篇–CSDN博客 文章目录 一.位运算一.位运算概述二.常见的位运算操作符三.常见的位运…

创业AI Agents系统深度解析

Agents 近日,AI领域的知名公司Anthropic发布了一份题为《构建高效的智能代理》的报告。该报告基于Anthropic过去一年与多个团队合作构建大语言模型(LLM)智能代理系统的经验,为开发者及对该领域感兴趣的人士提供了宝贵的洞见。本文…

【Spring Boot】Spring 事务探秘:核心机制与应用场景解析

前言 🌟🌟本期讲解关于spring 事务介绍~~~ 🌈感兴趣的小伙伴看一看小编主页:GGBondlctrl-CSDN博客 🔥 你的点赞就是小编不断更新的最大动力 🎆那么废话不多说直…

centos7.6 安装nginx 1.21.3与配置ssl

1 安装依赖 yum -y install gcc zlib zlib-devel pcre-devel openssl openssl-devel2 下载Nginx wget http://nginx.org/download/nginx-1.21.3.tar.gz3 安装目录 mkdir -p /data/apps/nginx4 安装 4.1 创建用户 创建用户nginx使用的nginx用户。 #添加www组 # groupa…

夯实前端基础之HTML篇

知识点概览 HTML部分 1. DOM和BOM有什么区别? DOM(Document Object Model) 当网页被加载时,浏览器会创建页面的对象文档模型,HTML DOM 模型被结构化为对象树 用途: 主要用于网页内容的动态修改和交互&…

Elasticsearch:向量数据库基础设施类别的兴衰

过去几年,我一直在观察嵌入技术如何从大型科技公司的 “秘密武器” 转变为日常开发人员工具。接下来发生的事情 —— 向量数据库淘金热、RAG 炒作周期以及最终的修正 —— 教会了我们关于新技术如何在更广泛的生态系统中找到一席之地的宝贵经验。 更多有关向量搜索…

【华为云开发者学堂】基于华为云 CodeArts CCE 开发微服务电商平台

实验目的 通过完成本实验,在 CodeArts 平台完成基于微服务的应用开发,构建和部署。 ● 理解微服务应用架构和微服务模块组件 ● 掌握 CCE 平台创建基于公共镜像的应用的操作 ● 掌握 CodeArts 平台编译构建微服务应用的操作 ● 掌握 CodeArts 平台部署微…

计科高可用服务器架构实训(防火墙、双机热备,VRRP、MSTP、DHCP、OSPF)

一、项目介绍 需求分析: (1)总部和分部要求网络拓扑简单,方便维护,网络有扩展和冗余性; (2)总部分财务部,人事部,工程部,技术部,提供…

【C++入门】详解合集

目录 💕1.C中main函数内部———变量的访问顺序 💕2.命名空间域 namespace 💕3.命名空间域(代码示例)(不要跳) 💕4.多个命名空间域的内部重名 💕5.命名空间域的展开 …

预编译SQL

预编译SQL 预编译SQL是指在数据库应用程序中,SQL语句在执行之前已经通过某种机制(如预编译器)进行了解析、优化和准备,使得实际执行时可以直接使用优化后的执行计划,而不需要每次都重新解析和编译。这么说可能有一些抽…

qemu搭建虚拟的aarch64环境开发ebpf

一、背景 需求在嵌入式环境下进行交叉编译,学习ebpf相关技术,所以想搭建一个不依赖硬件环境的学习环境。 本文使用的环境版本: 宿主机: Ubuntu24.02 libbpf-bootstrap源码: https://github.com/libbpf/libbpf-boots…

深度学习从入门到实战——卷积神经网络原理解析及其应用

卷积神经网络CNN 卷积神经网络前言卷积神经网络卷积的填充方式卷积原理展示卷积计算量公式卷积核输出的大小计算感受野池化自适应均值化空洞卷积经典卷积神经网络参考 卷积神经网络 前言 为什么要使用卷积神经网络呢? 首先传统的MLP的有什么问题呢? - …

2015年西部数学奥林匹克几何试题

2015/G1 圆 ω 1 \omega_1 ω1​ 与圆 ω 2 \omega_2 ω2​ 内切于点 T T T. M M M, N N N 是圆 ω 1 \omega_1 ω1​ 上不同于 T T T 的不同两点. 圆 ω 2 \omega_2 ω2​ 的两条弦 A B AB AB, C D CD CD 分别过 M M M, N N N. 证明: 若线段 A C AC AC, B D BD …

《Spring Framework实战》14:4.1.4.5.自动装配合作者

欢迎观看《Spring Framework实战》视频教程 自动装配合作者 Spring容器可以自动连接协作bean之间的关系。您可以通过检查ApplicationContext的内容,让Spring自动为您的bean解析协作者(其他bean)。自动装配具有以下优点: 自动装配…

JVM之垃圾回收器概述(续)的详细解析

ParNew(并行) Par 是 Parallel 并行的缩写,New 是只能处理的是新生代 并行垃圾收集器在串行垃圾收集器的基础之上做了改进,采用复制算法,将单线程改为了多线程进行垃圾回收,可以缩短垃圾回收的时间 对于其他的行为(…

有一台服务器可以做哪些很酷的事情

有一台服务器可以做哪些很酷的事情 今天我也来简单分享一下,这几年来,我用云服务器做了哪些有趣的事情。 服务器推荐 1. 个人博客 拥有个人服务器,你可以完全掌控自己的网站或博客。 与使用第三方托管平台相比,你能自由选择网站…

灌区闸门自动化控制系统-精准渠道量测水-灌区现代化建设

项目背景 本项目聚焦于黑龙江某一灌区的现代化改造工程,该灌区覆盖广阔,灌溉面积高达7.5万亩,地域上跨越6个乡镇及涵盖17个村庄。项目核心在于通过全面的信息化建设,强力推动节水灌溉措施的实施,旨在显著提升农业用水的…

3.flask蓝图使用

构建一个目录结构 user_oper.py from flask import Blueprint, request, session, redirect, render_template import functools # 创建蓝图 user Blueprint(xkj, __name__)DATA_DICT {1: {"name": "张三", "age": 22, "gender": …

vue3学习日记1 - Pinia

最近发现职场前端用的框架大多为vue,所以最近也跟着黑马程序员vue3的课程进行学习,以下是我的学习记录 视频网址: Day2-02.Pinia-counter基础使用_哔哩哔哩_bilibili 学习日记: vue3学习日记1 - 环境搭建-CSDN博客 vue3学习日…