用Pytorch实现线性回归(Linear Regression with Pytorch)

使用pytorch写神经网络的第一步就是需要准备好数据集,设计模型(用于计算y_hat(y的预测值)),构造损失函数和优化器(使用PyTorch API),写训练周期(前馈(算loss)+反馈(算梯度)+更新(更新权重))

一:准备数据

现在使用mini-batch的方式,X和Y为3x1(可以变,但是x和y要相同)的矩阵形式。

从代码中也可以看出来,x和y都是3x1的矩阵。

二:设计模型(构造计算图)

此处使用了一个仿射模型(在pytorch中叫做线性单元)

在我们设计的例子中,我们需要设置权重w的数值,和偏置量b。

那w和b的形状(几x几的矩阵),是由y_hat和x来共同确定。

之后将y_hat和y放入loss函数中进行计算,得出loss的值(一定是一个标量)。

看下模型设计的代码:

#需要继承自module ,因为module中有很多方法我们需要使用
class LinearModel(torch.nn.Module):
    def __init__(self): #构造函数 在初始化对象时默认调用的函数
        super(LinearModel,self).__init__() #super调用父类的构造
        self.linear = torch.nn.Linear(1,1) #构造一个对象 linear Unit中的w和b(linear来自父类,可以自动反向传播)
    
    def forward(self,x): #前馈需要进行的计算 发现没有backword模块,因为Module中自动根据计算图实现backword过程
        y_pred = self.linear(x)
        return y_pred

model = LinearModel() #实例化 在之后既可以使用model(x)将x传入forword中的x,求得y_pred

其中torch.nn.Linear 的使用方法如下 

三:构造loss和optimizer

此处我们使用MSEloss,需要的参事时y_hat和y,就可以求出loss。

代码如下:

criterion = torch.nn.MSELoss(size_average=False)

我们使用SGD优化器(不会构建计算图),代码如下

optimizer = torch.optim.SGD(model.parameters(),lr=0.01)

四:训练过程

for epoch in range(100):
    y_pred = model(x_data)  #先计算出y_hat
    loss = criterion(y_pred,y_data) #再计算出loss
    print(epoch,loss.item()) 
    
    optimizer.zero_grad()#在反馈前将梯度清0
    loss.backward()#反馈
    optimizer.step()#更新

最后打印一些相关内容

# w b
print('w=',model.linear.weight.item())
print('b=',model.linear.weight.item())

#Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred=',y_test.data)

发现当range为1000时,已经达到了我们的预期。

五:整体流程

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/802671.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

FPGA资源容量

Kintex™ 7 https://www.amd.com/zh-tw/products/adaptive-socs-and-fpgas/fpga/kintex-7.html#product-table AMD Zynq™ 7000 SoC https://www.amd.com/en/products/adaptive-socs-and-fpgas/soc/zynq-7000.html#product-table AMD Zynq™ UltraScale™ RFSoC 第一代 AMD Z…

浅说区间dp(下)

文章目录 环形区间dp例题[NOI1995] 石子合并题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1 提示思路 [NOIP2006 提高组] 能量项链题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1 提示思路 [NOIP2001 提高组] 数的划分题目描述输入格式输出格式样例 #1样例输…

AI大模型加持的新一代网络舆情系统——“速途观澜”舆情感知引擎发布上线

**近日,AI和数据驱动的新媒体生态服务商速途网络,发布上线企业声誉管理智能服务平台“速途观澜”。**该平台融合了速途最新研发升级的“观澜舆情感知引擎”,一款以大数据和AI为底座的网络舆情态势感知系统,这是速途在产品创新研发…

orcad导出pdf 缺少title block

在OrCAD中导出PDF时没有Title Block 最后确认问题在这里: 要勾选上Title Block Visible下面的print

共建特色基地 协同互促育人

作为芯片和集成电路、人工智能、智能网联车等临港重点产业布局的知识密集型相关企业,核心技术人才和技术骨干是公司参与全球竞争的重要核心竞争力之一。 知从科技通过不断的创新和规范,在深化产教融合、校企合作、“双师型”、联合办学协同育人、产业人…

Kafka Producer发送消息流程之分区器和数据收集器

文章目录 1. Partitioner分区器2. 自定义分区器3. RecordAccumulator数据收集器 1. Partitioner分区器 clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java,中doSend方法,记录了生产者将消息发送的流程,其中有一步…

【XSS】

文章目录 0x01 简介0x02 XSS Payload用法XSS攻击平台及调试JavaScript 0x03 XSS构造技巧XSS漏洞防御策略 跨站脚本攻击,Cross Site Script。(重点在于脚本script) 分类 反射型、存储型DOM型 漏洞原理:通过插入script篡改“HTML”…

字节码编程bytebuddy之通过Advice动态修改方法参数值

写在前面 本文看下如何通过bytebuddy的advice切面技术来动态修改方法入参值。 1:程序 首先定义premain: package com.dahuyou.change.method.param;//import net.bytebuddy.agent.builder.AgentBuilder; import net.bytebuddy.agent.builder.AgentBu…

Java web从入门到精通 (第 2版)中文电子版

前言 《Java Web从入门到精通(第2版)》共分21章,包括Java Web应用开发概述、HTML与CSS网页开发基础、JavaScript脚本语言、搭建开发环境、JavaBean技术、Servlet技术、过滤器和监听器、Hibernate高级应用、Java Web的数据库操作、EL&#xf…

Linux 上 TTY 的起源

注:机翻,未校对。 What is a TTY on Linux? (and How to Use the tty Command) What does the tty command do? It prints the name of the terminal you’re using. TTY stands for “teletypewriter.” What’s the story behind the name of the co…

每日一题,力扣leetcode Hot100之49. 字母异位词分组

该题用哈希表解答,具有统一特征的作为哈希表的键名,然后满足要求的作为值 解法一: 我们将每个字符串进行排序,如果排序后的结果相同,则可以认为是字母异位词,我们将排序后的结果作为哈希表的key&#xff…

智能听诊器:宠物健康监测的革新者

宠物健康护理领域迎来了一项激动人心的技术革新——智能听诊器。这款创新设备以其卓越的精确度和用户友好的操作,为宠物主人提供了一种全新的健康监测方法。 使用智能听诊器时,只需将其放置在宠物身上,它便能立即捕捉到宠物胸腔的微小振动。…

S274多功能可编程RTU在智慧水务远程水质检测系统中的应用案例

钡铼第四代RTU S274作为一款多功能可编程的无线工业物联网数据监测采集控制短信报警终端,为智慧水务领域提供了强大的技术支持和解决方案。 技术概述与特点 钡铼S274基于UCOSII嵌入式实时操作系统,支持多种通信协议包括短信和MQTT,能够接入…

2024嘶吼网络安全产业图谱(高清完整版)

在数字化和智能化浪潮的推动下,网络安全产业正处于一个快速变革的时期。从传统的防御手段和被动的威胁应对,到如今主动预防和智能检测技术的普及,网络安全领域的焦点和需求正不断演进。为了更好的理解当前网络安全产业现状和未来发展方向&…

jeecgboot项目不知道什么原因启动出来8080端口后就不下去了,要等上10多分钟才出来接口地址等正常情况

因为这个项目license问题无法开源,更多技术支持与服务请加入我的知识星球。 1、项目中途不知道什么原因,就出现下面情况 具体如下: 2024-07-15 15:08:15.767 [main] [34mINFO [0;39m [36mliquibase.changelog:30[0;39m - Reading from jeec…

【LeetCode】十七、并查集

文章目录 1、并查集Union Find2、并查集find的优化:路径压缩 Quick find3、并查集union的优化:权重标记 1、并查集Union Find 并查集,一种树形的数据结构,处理不相交的两个集合的合并与查询问题。 【参考:&#x1f4…

优化 Java 数据结构选择与使用,提升程序性能与可维护性

优化 Java 数据结构选择与使用,提升程序性能与可维护性 引言 在软件开发中,数据结构的选择是影响程序性能、内存使用以及代码可维护性的关键因素之一。Java 作为一门广泛使用的编程语言,提供了丰富的内置数据结构,如数组、链表、…

python用selenium网页模拟时xpath无法定位元素解决方法2

有时我们在使用python selenium xpath时,无法定位元素,红字显示no such element。上一篇文章写了1种情况,是包含iframe的,详见https://blog.csdn.net/Sixth5/article/details/140342929。 本篇写第2种情况,就是xpath定…

嵌入式人工智能(9-基于树莓派4B的PWM-LED呼吸灯)

1、PWM简介 (1)、什么是PWM 脉冲宽度调制(PWM),是英文“Pulse Width Modulation”的缩写,简称脉宽调制,是在具有惯性的系统中利用微处理器的数字输出来对模拟电路进行控制的一种非常有效的技术,广泛应用在从测量、通信到功率控制…