Pytorch线性回归教程

import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt

生成测试数据

# 长期趋势
def trend(time, slope=0):
    return slope * time

# 季节趋势
def seasonal_pattern(season_time):
    return np.where(season_time < 0.4,
                    np.cos(season_time * 2 * np.pi),
                    1 / np.exp(3 * season_time))
def seasonality(time, period, amplitude=1, phase=0):
    season_time = ((time + phase) % period) / period
    return amplitude * seasonal_pattern(season_time)

# 噪声
def noise(time, noise_level=1):
    return np.random.randn(len(time)) * noise_level
X = torch.arange(1, 1001)
# Y = 0.7 * X + 100 + torch.randn(X.size())
Y = trend(X, 0.3) + seasonality(X, period=365, amplitude=30) + noise(X, 15) + 200
X.shape, Y.shape
(torch.Size([1000]), torch.Size([1000]))
plt.plot(X.numpy(), Y.numpy());

对测试数据进行处理

# 模型的数据的类型需要是32位浮点型
X = X.type(torch.float32)
Y = Y.type(torch.float32)
X.dtype, Y.dtype
(torch.float32, torch.float32)
# 模型的数据需要进行归一化或者标准化,下面是归一化
X = (X - X.min()) / (X.max() - X.min())
Y = (Y - Y.min()) / (Y.max() - Y.min())
plt.plot(X.numpy(), Y.numpy());

定义模型和模型参数

# 线性模型只有两个参数斜率k,和偏置b
# 线性模型的方程为y = k * x + b
k = nn.Parameter(torch.rand(1, dtype=torch.float32))
b = nn.Parameter(torch.rand(1, dtype=torch.float32))
# 下面输出中的requires_grad=True 表示该参数需要计算梯度
# 梯度用于在反向传播中对参数进行优化,优化方法即梯度下降
k, b 
(Parameter containing:
 tensor([0.6231], requires_grad=True),
 Parameter containing:
 tensor([0.0044], requires_grad=True))
def linear_model(x):
    return k * x + b

梯度下降优化参数

# 可以通过改变学习率lr和epoch_num学习各自的用途
# 定义优化器,用于更新模型的参数,即传入的k和b
optimizer = torch.optim.SGD([k, b], lr=0.01)
# 损失函数,模型优化的目的
loss_func = nn.MSELoss()

# 每个epoch表示把全部的数据过一遍
epoch_num = 2000
for epoch in range(epoch_num):
    # 获取模型预测结果
    y_pred = linear_model(X)
    # 计算损失值
    loss = loss_func(y_pred, Y)
    # 将梯度设为0
    optimizer.zero_grad()
    # 反向传播,计算梯度
    loss.backward()
    # 执行梯度下降,优化参数
    optimizer.step()
k, b
(Parameter containing:
 tensor([0.8825], requires_grad=True),
 Parameter containing:
 tensor([0.0419], requires_grad=True))
# detach()函数用于将参数设置为不需要梯度
k2 = k.detach().numpy()[0]
b2 = b.detach().numpy()[0]

plt.plot(X, Y);
plt.plot(X, k2 * X + b2);

优化模型

class LinearModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.k = nn.Parameter(torch.rand(1, dtype=torch.float32))
        self.b = nn.Parameter(torch.rand(1, dtype=torch.float32))

    def forward(self, x):
        return self.k * x + self.b
model = LinearModel()
# 定义优化器,用于更新模型的参数,即传入的k和b
optimizer = torch.optim.SGD([k, b], lr=0.01)
# 损失函数,模型优化的目的
loss_func = nn.MSELoss()

epoch_num = 2000
for epoch in range(epoch_num):
    y_pred = model(X)
    loss = loss_func(y_pred, Y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
k, b
(Parameter containing:
 tensor([0.8825], requires_grad=True),
 Parameter containing:
 tensor([0.0419], requires_grad=True))
k2 = k.detach().numpy()[0]
b2 = b.detach().numpy()[0]

plt.plot(X, Y);
plt.plot(X, k2 * X + b2);

随机梯度下降

# 前面执行梯度下降时,我们是一次将全部的数据都传入模型
# 但在实际应用中,可能会由于数据太大,没法全部传入模型
# 因此,可以一次传入一部分数据,这便是随机梯度下降
# 随机梯度下降的核心是,梯度是期望。期望可使用小规模的样本近似估计。
model = LinearModel()
# 定义优化器,用于更新模型的参数,即传入的k和b
optimizer = torch.optim.SGD([k, b], lr=0.01)
# 损失函数,模型优化的目的
loss_func = nn.MSELoss()

# 每个epoch表示把全部的数据过一遍
epoch_num = 2000 
# iter_step表示在一个epoch内抽取几个小规模样本
iter_step = 10
# batch_size表示小规模样本的大小
batch_size = 100
for epoch in range(epoch_num):
    for i in range(iter_step):
        random_samples = torch.randint(X.size()[0], (batch_size, ))
        X_i, Y_i = X[random_samples], Y[random_samples]
        y_pred = model(X_i)
        loss = loss_func(y_pred, Y_i)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
k, b
(Parameter containing:
 tensor([0.8825], requires_grad=True),
 Parameter containing:
 tensor([0.0419], requires_grad=True))
k2 = k.detach().numpy()[0]
b2 = b.detach().numpy()[0]

plt.plot(X, Y);
plt.plot(X, k2 * X + b2);

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/225608.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

视频监控管理平台/智能监测/检测系统EasyCVR智能地铁监控方案,助力地铁高效运营

近日&#xff0c;关于全国44座城市开通地铁&#xff0c;却只有5座城市赚钱的新闻冲上热搜。地铁作为城市交通的重要枢纽&#xff0c;是人们出行必不可少的一种方式&#xff0c;但随着此篇新闻的爆出&#xff0c;大家也逐渐了解到城市运营的不易&#xff0c;那么&#xff0c;如何…

高速风筒解决方案,基于高性价比的普冉单片机开发

高速风筒也就是高速吹风机&#xff0c;与传统的吹风机相比&#xff0c;高速吹风机具有更强大的风力和更快的干燥速度&#xff0c;可以更快地干燥头发或其他物体表面的水分。它通常由一个电动机驱动&#xff0c;并通过旋转的叶片来产生气流。高速风筒广泛应用于个人护理、美容、…

Unity链接MySql数据库

一、连接准备 1. MySql.Data插件 Visual Studio中下载打开Visual Studio_项目_管理NuGet程序包在浏览中搜索MySql.Data并下载 2.MySql官网下载插件 前提已经安装mysql&#xff0c;然后到官网下载以下三个东西&#xff08;最好不要使用最新版本&#xff09; MySQL Connector…

如何从Git上拉取项目

1.Git的概念 Git是一个开源的分布式版本控制系统&#xff0c;可以有效、高速的处理从很小到非常大的项目版本管理。它实现多人协作的机制是利用clone命令将项目从远程库拉取到本地库&#xff0c;做完相应的操作后再利用push命令从本地库将项目提交至远程库。 2.Git的工作流程 …

【Java系列】详解多线程(一)

个人主页&#xff1a;兜里有颗棉花糖 欢迎 点赞&#x1f44d; 收藏✨ 留言✉ 加关注&#x1f493;本文由 兜里有颗棉花糖 原创 收录于专栏【Java系列专栏】 本专栏旨在分享学习Java的一点学习心得&#xff0c;欢迎大家在评论区交流讨论&#x1f48c; 目录 一、背景引入二、线程…

echarts绘制一个饼图

其他echarts&#xff1a; qecharts绘制一个柱状图&#xff0c;柱状折线图 效果图&#xff1a; 代码&#xff1a; <template><div class"wrapper"><div ref"pieChart1" id"pieChart1"></div><div ref"pieCha…

CCF编程能力等级认证GESP—C++1级—20230611

CCF编程能力等级认证GESP—C1级—20230611 单选题&#xff08;每题 2 分&#xff0c;共 30 分&#xff09;判断题&#xff08;每题 2 分&#xff0c;共 20 分&#xff09;编程题 (每题 25 分&#xff0c;共 50 分)时间规划累计相加 答案及解析单选题判断题编程题1编程题2 单选题…

docker安装配置prometheus+node_export+grafana

简介 Prometheus是一套开源的监控预警时间序列数据库的组合&#xff0c;Prometheus本身不具备收集监控数据功能&#xff0c;通过获取不同的export收集的数据&#xff0c;存储到时序数据库中。Grafana是一个跨平台的开源的分析和可视化工具&#xff0c;将采集过来的数据实现可视…

2023年6月21日 Go生态洞察:Go 1.21版发行候选版的深入分析

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

基于SpringBoot+Thymeleaf+Mybatis实现大学生创新创业管理系统(源码+数据库+项目运行指导文档)

一、项目简介 本项目是一套基于SpringBoot实现大学生创新创业管理系统&#xff0c;主要针对计算机相关专业的正在做bishe的学生和需要项目实战练习的Java学习者。 包含&#xff1a;项目源码、数据库脚本等&#xff0c;该项目可以直接作为bishe使用。 项目都经过严格调试&#…

局域网共享打印机设置,解决709、11B等一切共享问题

Win7、Win10、Win11添加共享打印机经常出现局域网共享打印机出错&#xff0c;常规的添加方式是SMB共享&#xff0c;一更新就出问题&#xff0c;报错0x00000709、0x0000011b等。 网上有许多方法&#xff0c;如卸载更新补丁&#xff0c;替换“win32spl.dll”文件&#xff0c;修改…

排查200M宽带下行速度低于100M问题(七十七)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 优质专栏&#xff1a;多媒…

sql注入 [GXYCTF2019]BabySQli1

打开题目 多次尝试以后我们发现存在一个admin的账号&#xff0c;但是密码我们不知道 我们尝试一下万能密码 admin or 11 -- q 报错 我们尝试bp抓一下包看看 看着很像编码 先去base32解码 再base64解码 得到 我们从这个sql语句中得到注入点为name 根据报错信息我们知道是…

Python语言基础知识(一)

文章目录 1、Python内置对象介绍2、标识符与变量3、数据类型—数字4、数据类型—字符串与字节串5、数据类型—列表、元组、字典、集合6、运算符和表达式7、运算符和表达式—算术运算符8、运算符和表达式—关系运算符9.1、运算符和表达式— 成员测试运算符in9.2、运算符和表达式…

vue+echarts实现桑吉图的效果

前言&#xff1a; 在我们项目使用图形的情况下&#xff0c;桑吉图算是冷门的图形了&#xff0c;但是它可以实现我们对多级数据之间数据流向更好的展示的需求&#xff0c;比如&#xff0c;我们实际数据流向中&#xff0c;具有1对多&#xff0c;多对多的情况下&#xff0c;如果用…

uni-app实现安卓原生态调用身份证阅读器读卡库读身份证和社保卡、银行卡、IC卡等功能

DONSEE系列多功能读写器Android Uniapp API接口规范V1.0.0 本项目Uniapp调用了身份证读卡器的库文件&#xff1a;DonseeDeviceLib-debug.aar&#xff0c;该库放到nativeplugins\donsee-card\android&#xff0c;然后会自动加载。SDK会自动检查是否拥有USB设备权限&#xff0c;…

关于物联网仪表ADW300 远传电表的详细介绍-安科瑞 蒋静

1概述 ADW300无线计量仪表主要用于计量低压网络的三相有功电能&#xff0c;具有体积小、精度高、功能丰富等优点&#xff0c;并且可选通讯方式多&#xff0c;可支持 RS485 通讯和 Lora、NB、4G、wifi 等无线通讯方式&#xff0c;增加了外置互感器的电流采样模式&#xff0c;从而…

数据结构与算法(五)回溯算法(Java)

目录 一、简介1.1 定义1.2 特性1.3 结点知识补充1.4 剪枝函数1.5 使用场景1.6 解空间1.7 实现模板 二、经典示例2.1 0-1 背包问题2.2 N皇后问题 一、简介 1.1 定义 回溯法&#xff08;back tracking&#xff09;是一种选优搜索法&#xff0c;又称为试探法&#xff0c;按选优条…

为什么 AWS 数据库不讲 HTAP

在 AWS re:Invent 2023 掌门人 Adam Selipsky 的 Keynote 上&#xff0c;数据库方面最重磅的主题是 Zero-ETL&#xff0c;从 TP 数据库 (RDS, Aurora, DynamoDB) 同步数据到 AP 数据库 (Redshift)。 Zero-ETL 是 AWS 在去年 re:invent 2022 上推出的概念&#xff0c;今年则继…

网络层之IP组播

学习的最大理由是想摆脱平庸&#xff0c;早一天就多一份人生的精彩&#xff1b;迟一天就多一天平庸的困扰。各位小伙伴&#xff0c;如果您&#xff1a; 想系统/深入学习某技术知识点… 一个人摸索学习很难坚持&#xff0c;想组团高效学习… 想写博客但无从下手&#xff0c;急需…