Pytorch线性回归

使用pytorch来重现线性模型的过程,构造神经网络module,构造损失函数loss,构造随机梯度下降的优化器sgd。

一 revise

首先确定我们的模型,我们希望完成的目标就是得到较小的loss,所以我们就需要一个标量值的loss。

那其实在上一部分的内容就提到了tensor,loss,backward的使用,其实这个就是我们利用pytorch给我们的功能了。

二 pytorch fashion

pytorch写神经网络的第一步就是要准备数据集(有构造数据集的工具),设计一个模型计算y_head,构造损失函数和优化器,写训练的周期(前馈反馈更新)。

2.1准备数据

我们之前准备数据,就是直接使用列表。

x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0]

In pytorch , the computational graph is in mini-batch fashion,so X and Y are 3x1 Tensor.

但是现在我们希望它不在是一个列表或者一个向量,我们希望它可以成为一个3X1的矩阵。当然公式依然还是使用之前的y_head = w*x +b 。图中的y_pred 就是我们说的y_head。

大家可能注意到我们上面说计算图使用小批量的方式(mini batch),这其实就是把x的三个数据同时放在一起,同时进行计算。

现在我们有了3x1矩阵的x了,那我们就可以得到3x1矩阵的y。因为存在广播机制,看图中w和b,其实也会变成3x1的矩阵。

当然loss函数的公式也是和之前不变为(y_head - y)**2。此时loss和y也会变成3x1的矩阵。

说了这么多,无非就是要把数据处理成矩阵的形式那使用tensor的方式就是:

import torch
x_data = torch.Tensor([[1.0],[2.0],[3.0]])
y_data = torch.Tensor([[2.0],[4.0],[6.0]])

总结来说,使用mini batch构造数据集的时候,我们就是需要x和y都是矩阵的形式。

2.2 构造计算图

最开始计算梯度手工计算,后来我们构建计算图可以自动把梯度求出来,之后就可以进行优化了。  

因此我们在准备好数据的基础上,下一步就是构造出计算图。我还是使用y_head = x*w + b这样一个函数(放仿射模型)。我们通过这个构建出计算图的一部分叫线性单元。

那其实我们现在进行计算的都是矩阵,既然计算矩阵就需要确定w的大小和b的大小。就需要从z和x的维度来确定w和b的维度。

接下来我们需要把y_head放到loss函数中进行计算,上面我们也说到,loss必须是一个标量(只有是标量的情况下才可以backward)。现在y_head是一个矩阵形式,且y也是,得到的loss应该也是矩阵形式,此时我们就需要对loss进行适当的修改,通常会对loss内的数进行求和,使其变成一个标量。

class LinearModel(torch.nn.Module):
    def __init__(self): #初始化
        super(LinearModel,self).__init__()
        self.linear = torch.nn.Linear(1,1)
    
    def forward(self,x): #前馈
        y_pred = self.linear(x)
        return y_pred

model = LinearModel()

大家可以看见,我们在类中没有写反馈的函数,这是由于model构造出的对象会自动根据计算图去实现backward。

(1)构造函数的super不用变。

torch.nn.Linear 使pytorch中的一个类,这个类里面的对象包括了权重w和偏置b,就可以直接完成下面的整个运算。Linear也是源于Model,所以他也可以进行自动的反向传播。

class torch.nn.linea(in_features,out_features,bias=True) 所做的计算就是y=ax+b。其中in_features表示的就是输入的x是几维的,out_features表示的就是输出的是几维的。那我们在mini batch中矩阵的行表示的使各个样本的值,那此时不难猜出矩阵的列表示的就是feature。bias为True表示需要偏置量,默认为True。

下面两个计算公式都可以,注意w在矩阵乘法的位置。

(2)forward(self,x)

 y_pred = self.linear(x) 这一步其实就是在计算我们的y_head。其实看上面的定义也可以看出来。

(3)最后将模型进行实例化,供我们后面使用。

2.3构造损失函数和优化器

(1)损失函数

我们还是使用MSE损失函数,此时也还是需要构建计算图,整体过程就是在拿到y_head的前提下,对y进行计算,得到loss的值,此时loss还是一个矩阵的形式,最后还需要对其进行变成标量。

class torch.nn.MSELoss(size_average=True,reduce=True) 

其中size_average是是否求均值。reduce是是否进行降维,也就是是否进行求和。

现在这个size_average被废除了,大家可以看下列代码实现求均值的True or False。

#criterion = torch.nn.MSELoss(size_average=True)
#改为:
criterion = torch.nn.MSELoss(reduction='mean')

#criterion = torch.nn.MSELoss(size_average=False)
#改为:
criterion = torch.nn.BCELoss(reduction='sum')

因此对于criterion这个对象需要的参数是(y_head,y)

(2)优化器

optimizer = torch.optim.SGD(model.parameters(),lr=0.01)

优化器是不会建立计算图的。

model.parameters() 这个指的是权重。简单来说就是model中其实没有定义权重,model里存在我们上面定义的成员linear,linear中包含两个权重w和b。现在就需要告诉优化器哪些tensor是需要优化的,哪些是用于梯度下降的。因此在使用SGD这个模型时,想找权重,直接model.parameters。可以把model中所有的参数全部找到。

lr时学习率,一般会给一个固定的值。可以在不同的部分使用不同的学习率。

2.4训练过程

for epoch in range(100):
    y_pred = model(x_data)  #先计算出y_head
    loss = criterion(y_pred,y_data) #再计算出loss
    print(epoch,loss.item()) 
    
    optimizer.zero_grad()#在反馈前将梯度清0
    loss.backward()#反馈
    optimizer.step()#更新

整个过程其实就先算y_head,再计算loss,随后backward,更新。

最后就是打印一下权重信息。

# w b
print('w=',model.linear.weight.item())
print('b=',model.linear.weight.item())

#Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred=',y_test.data)

对于我们给出的x_data和y_data,其实最终最好的情况是w=2 b=0。

分别是100次epoch和1000次epoch的结果,看得出来1000次更接近于我们想要的值。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/668835.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Linux入门攻坚——24、BIND编译安装、Telnet和OpenSSH

BIND编译安装 对于没有rpm包,需要源代码编译安装。 1、下载源代码:bind-9.12.2-P1.tar.gz,解压:tar -xf bind-9.12.2-P1.tar.gz 2、完善环境: 1)增加用户组named:groupadd -g 53 named 2&…

Multipass虚拟机磁盘扩容

Multipass 是一个用于轻松创建和管理 Ubuntu 虚拟机的工具,特别适合开发环境。要使用 Multipass 扩大虚拟机的磁盘容量,你需要经历几个步骤,因为 Multipass 自身并不直接提供图形界面来调整磁盘大小。不过,你可以通过结合 Multipa…

程序员上岸指南

如果你还在996,大小周,感觉身体被掏空,那么你可以看看下面这篇文章,我特意搜集了一些苦逼程序员的上岸教程。 人生真的就是做几道选择题,选错了,忙也是瞎忙。选对了,躺着都能赢。总的来说&#…

MQTT之使用mosquitto

1、下载并安装mosquitto 参考:04 Windows下mosquitto安装_mosquitto-1.6.9-install-windows-x64 windowsserver系-CSDN博客 2、启动 2.1添加用户 .\mosquitto_passwd -c pwfile.example user1 报错信息如下: Error: Unable to open file C:\Program…

Go-Admin后台管理系统源码(GO+VUE)编译与部署

1.克隆源码: # Get backend code git clone https://github.com/go-admin-team/go-admin.git# Get the front-end code git clone https://github.com/go-admin-team/go-admin-ui.git3.下载并安装GO开发环境: 3.编译管理后台后端 # Enter the go-admin backend project cd ./…

深入解析智慧互联网医院系统源码:医院小程序开发的架构到实现

本篇文章,小编将深入解析智慧互联网医院系统的源码,重点探讨医院小程序开发的架构和实现,旨在为相关开发人员提供指导和参考。 一、架构设计 智慧互联网医院系统的架构设计是整个开发过程的核心,直接影响到系统的性能、扩展性和维…

IO流(1)

定义:存取和读取数据的解决方案 作用:用于读写数据(本地文件、网络) 分类: 一种是:输出流和输入流。 一种是:字节流和字符流。 字节流 字节流——FileOutputStream(字节输出流&…

MoeCTF 2022 usb

直接找 URB的第一个输入协议 我们需要提取的数据 HID Data 提取过滤器 tshark -r usb.pcapng -Y "usb.src\"2.2.1\"" -T json >1.json 拿 usbhid.data 字段 tshark -r usb.pcapng -Y "usb.src\"2.2.1\"" -T json -e usbhid.data …

【记录】打印|用浏览器生成证件照打印PDF,打印在任意尺寸的纸上(简单无损!)

以前我打印证件照的时候,我总是在网上找在线证件照转换或者别的什么。但是我今天突然就琢磨了一下,用 PDF 打印应该也可以直接打印出来,然后就琢磨出来了,这么一条路大家可以参考一下。我觉得比在线转换成一张 a4 纸要方便的多&am…

Git常用命令1

1、设置用户签名 ①基本语法: git config --global user.name 用户名 git config --global user.email 邮箱 ②实际操作 ③查询是否设置成功 cat ~/.gitconfig 注:签名的作用是区分不同操作者身份。用户的签名信息在每一个版本的提交…

2.1 OpenCV随手简记(二)

为后续项目学习做准备,我们需要了解LinuxOpenCV、Mediapipe、ROS、QT等知识。 一、图像显示与保存 1、基本原理 1.1 图像像素存储形式 首先得了解下图像在计算机中存储形式:(为了方便画图,每列像素值都写一样了)。对于只有黑白颜色的灰度…

OSM历史10年(2014-2024)全国数据下载(路网、建筑物、POI、水系、地表覆盖利用······)

点击下方全系列课程学习 点击学习—>ArcGIS全系列实战视频教程——9个单一课程组合系列直播回放 零、前沿 这次向大家介绍一下OSM(OpenStreetMap)十年历史数据(2014—2014)的下载方法。当然我们也下载好分享给大家&#xff…

如何上传模型素材创建3D漫游作品?

一、进入3D空间漫游互动工具编辑器 进入720云官网-点击“开始创作”-选择3D空间漫游-进入到作品创建页面。 二、上传模型及素材,创建生成3D空间漫游模型 1.创建3D空间作品:您可以选择新建空白作品或使用720云提供的预设空间模板,本篇主要介绍…

顶级商圈词汇,你听过哪些?

能把复杂的事情说到简单才是真正的高手,堆砌时髦名词的人,大多是被别人影响后的盲从之辈,高认知的人不扯这些虚头巴脑的东西。 顶级商圈词汇,你听过哪些? 生命周期,价值转化,完善逻辑业模式&a…

Elasticsearch 认证模拟题 - 4

一、题目 生成快照,或快照生命周期 1.1 考点 快照生命周期(最好通过界面化配置)创建仓库创建快照 (因为这个需要部署共享文件,所以这个我就在虚拟机上简单操作一下) 注: 部署共享文件系统可…

Java Web基础知识(Servlet、Cookie、Session、Filter、Listener)

文章目录 Servlet什么是Servlet?Servlet的生命周期ServletConfig对象ServletContext对象Servlet请求转发和重定向Servlet请求转发(forward和include)Servlet重定向(redirect)重定向和转发的区别? 读取文件、下载文件 …

攻防对抗少丢分,爱加密帮您筑起第二防线

应用程序通常处理和存储大量的敏感数据,如用户个人信息、财务信息、商业数据、国家数据等,用户量越大的应用程序,其需要存储和保护的用户数据越多。因此应用层长期是攻击方的核心目标,传统应用安全依靠防火墙(FireWall)、入侵检测…

一周学会Django5 Python Web开发 - Django5内置Admin系统二次开发

锋哥原创的Python Web开发 Django5视频教程: 2024版 Django5 Python web开发 视频教程(无废话版) 玩命更新中~_哔哩哔哩_bilibili2024版 Django5 Python web开发 视频教程(无废话版) 玩命更新中~共计56条视频,包括:2024版 Django5 Python we…

规则引擎Drools,基于mysql实现动态加载部署

文章目录 一、使用1、参考资料2、引包3、创建规则实体类4、实现drools动态规则5、模拟数据库,实现规则的CRUD6、创建控制层7、测试规则的动态添加(1)添加规则(2)修改规则(3)删除规则 8、模拟2个…

01PCB设计概述

PCB设计概述 EDA electronic design automatic 电子设计自动化(利用计算机来实现电子设计) 分为 : 微电子(芯片设计)、硬件板卡(PCB设计) 画原理图、画PCB布线 要会绘制原理图库、和封装图库 元…