pytorch 实现线性回归(深度学习)

一 查看原始函数

        y=2x+4.2

初始化

%matplotlib inline
import random
import torch
from d2l import torch as d2l

1.1 生成原始数据

def synthetic_data(w, b, num_examples):
    x = torch.normal(0, 1, (num_examples, len(w)))
    y = torch.matmul(x, w) + b
    print('x:', x)
    print('y:', y)
    y += torch.normal(0, 0.01, y.shape)  # 噪声
    return x, y.reshape((-1 , 1))
true_w = torch.tensor([2.])
true_b = 4.2
print(f'true_w: {true_w}, true_b: {true_b}')

features, labels = synthetic_data(true_w, true_b, 10)

1.2 数据转换

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):
        batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])
        yield features[batch_indices], labels[batch_indices]

batch_size = 10
for x, y in data_iter(batch_size, features, labels):
    print(f'x: {x}, \ny: {y}')

1.3 初始化权重

随机初始化,w使用 均值0,方差 0.01 的随机值, b 初始化为1

w = torch.normal(0, 0.01, size = (1,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
w, b

二 执行训练

查看训练过程中的 参数变化:

print(f'true_w: {true_w}, true_b: {true_b}')

def squared_loss(y_hat, y):
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

def linreg(x, w, b):
    return torch.matmul(x, w) + b

def sgd(params, lr, batch_size):
    with torch.no_grad():
        for param in params:
            # print('param:', param, 'param.grad:', param.grad)
            param -= lr * param.grad / batch_size
            param.grad.zero_()

lr = 0.03
num_epochs = 1000
for epoch in range(num_epochs):
    for x, y in data_iter(batch_size, features, labels):
        l = squared_loss(linreg(x, w, b), y)   # 计算总损失
        print('w:', w, 'b:', b)  # l:', l, '\n
        l.sum().backward()
        sgd([w, b], lr, batch_size)

 


三 测试梯度更新

初始化数据

%matplotlib inline
import random
import torch
from d2l import torch as d2l

def synthetic_data(w, b, num_examples):
    x = torch.normal(0, 1, (num_examples, len(w)))
    y = torch.matmul(x, w) + b
    print('x:', x)
    print('y:', y)
    y += torch.normal(0, 0.01, y.shape)  # 噪声
    return x, y.reshape((-1 , 1))

true_w = torch.tensor([2.])
true_b = 4.2
print(f'true_w: {true_w}, true_b: {true_b}')

features, labels = synthetic_data(true_w, true_b, 10)

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):
        batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])
        yield features[batch_indices], labels[batch_indices]

batch_size = 10
for x, y in data_iter(batch_size, features, labels):
    print(f'x: {x}, \ny: {y}')
    
w = torch.normal(0, 0.01, size = (1,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
w, b

3.1 测试更新

print(f'true_w: {true_w}, true_b: {true_b}')

def squared_loss(y_hat, y):
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

def linreg(x, w, b):
    return torch.matmul(x, w) + b

def sgd(params, lr, batch_size):
    with torch.no_grad():
        for param in params:
            print('param:', param, 'param.grad:', param.grad)
#             param -= lr * param.grad / batch_size
#             param.grad.zero_()

lr = 0.03
num_epochs = 2
for epoch in range(num_epochs):
    for x, y in data_iter(batch_size, features, labels):
        l = squared_loss(linreg(x, w, b), y)   # 计算总损失
        print(f'\nepoch: {epoch},w:', w, 'b:', b)  # l:', l, '\n
        l.sum().backward()  # 计算更新梯度
        sgd([w, b], lr, batch_size)

使用 l.sum().backward()  # 计算更新梯度:

不使用更新时:

print(f'true_w: {true_w}, true_b: {true_b}')

def squared_loss(y_hat, y):
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

def linreg(x, w, b):
    return torch.matmul(x, w) + b

def sgd(params, lr, batch_size):
    with torch.no_grad():
        for param in params:
            print('param:', param, 'param.grad:', param.grad)
#             param -= lr * param.grad / batch_size
#             param.grad.zero_()

lr = 0.03
num_epochs = 2
for epoch in range(num_epochs):
    for x, y in data_iter(batch_size, features, labels):
        l = squared_loss(linreg(x, w, b), y)   # 计算总损失
        print(f'\nepoch: {epoch},w:', w, 'b:', b)  # l:', l, '\n
        # l.sum().backward()  # 计算更新梯度
        sgd([w, b], lr, batch_size)
        
#     break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/392304.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

JavaWeb-JDBC-API详解

一、JDBC介绍 二、JDBC 快速入门 package com.itheima.jdbc;import java.sql.Connection; import java.sql.DriverManager; import java.sql.Statement;public class JDCBDemo {public static void main(String[] args) throws Exception {//1、注册驱动Class.forName("co…

django中事务和锁

目录 一:事务(Transactions) 二:锁 在Django中,事务和锁是数据库操作中的两个重要概念,它们用于确保数据的完整性和一致性。下面我将分别解释这两个概念在Django中的应用。 一:事务&#xff…

Code Composer Studio (CCS) - Breakpoint (断点)

Code Composer Studio [CCS] - Breakpoint [断点] 1. BreakpointReferences 1. Breakpoint 选中断点右键 -> Breakpoint Properties… Skip Count:跳过断点总数,在断点执行之前设置总数 Current Count:当前跳过断电累计值 References […

Ubuntu学习笔记-Ubuntu搭建禅道开源版及基本使用

文章目录 概述一、Ubuntu中安装1.1 复制下载安装包路径1.2 将安装包解压到ubuntu中1.3 启动服务1.4 设置开机自启动 二、禅道服务基本操作2.1 启动,停止,重启,查看服务状态2.2 开放端口2.3 访问和登录禅道 卜相机关 卜三命、相万生&#xff0…

第13章 网络 Page738~741 13.8.3 TCP/UDP简述

libcurl是C语言写成的网络编程工具库,asio是C写的网络编程的基础类型库 libcurl只用于客户端,asio既可以写客户端,也可以写服务端 libcurl实现了HTTP\FTP等应用层协议,但asio却只实现了传输层TCP/UDP等协议。 在学习http时介绍…

CSS概述 | CSS的引入方式 | 选择器

文章目录 1.CSS概述2.CSS的引入方式2.1.内部样式表2.2.行内样式表2.3.外部样式表 3.选择器 1.CSS概述 CSS,全称Cascading Style Sheets(层叠样式表),是一种用来设置HTML(或XML等)文档样式的语言。CSS的主要…

Code Composer Studio (CCS) - Current and Local Revision

Code Composer Studio [CCS] - Current and Local Revision References 鼠标放在文件内的任意位置,鼠标右键 -> Compare With -> Local History -> Revision Time. References [1] Yongqiang Cheng, https://yongqiang.blog.csdn.net/

vue-路由(六)

阅读文章你可以收获什么? 1 明白什么是单页应用 2 知道vue中的路由是什么 3 知道如何使用vueRouter这个路由插件 4 知道如何如何封装路由组件 5 知道vue中的声明式导航router-link的用法 6 知道vue中的编程式导航的使用 7 知道声明式导航和编程式导航式如何传…

【数据结构】18 二叉搜索树(查找,插入,删除)

定义 二叉搜索树也叫二叉排序树或者二叉查找树。它是一种对排序和查找都很有用的特殊二叉树。 一个二叉搜索树可以为空,如果它不为空,它将满足以下性质: 非空左子树的所有键值小于其根节点的键值非空右子树的所有键值都大于其根结点的键值左…

Rust 学习笔记 - 注释全解

前言 和其他编程语言一样,Rust 也提供了代码注释的功能,注释用于解释代码的作用和目的,帮助开发者理解代码的行为,编译器在编译时会忽略它们。 单行注释 单行注释以两个斜杠 (//) 开始,只影响它们后面直到行末的内容…

Java面向对象三大特征之封装

封装的作用和含义: 程序的设计要追求“高内聚,低耦合”。高内聚就是类的内部数据操作细节自己完成,不允许外部干涉;低耦合是仅暴露少量的方法给外部使用,尽量方便外部调用。 编程中封装的具体优点: 提高代…

Days 33 ElfBoard 固定CPU频率

ELF 1开发板选用的是主频800MHz NXP的i.MX6ULL处理器。根据实际的应用场景,如果需要降低CPU功耗,其中一种方法可以将CPU频率固定为节能模式,下面以这款开发板为例给小伙伴们介绍一下固定CPU频率的方法。 先来介绍一下与CPU频率相关的命令&…

关于umi ui图标未显示问题

使用ant design pro 时,安装了umi ui ,安装命令: yarn add umijs/preset-ui -D但是启动项目后,发现没有显示umi ui的图标 找了许多解决方案,发现 umi的版本问题,由于我使用的ant design pro官网最新版本&a…

Quantitative Analysis: PIM Chip Demands for LLAMA-7B inference

1 Architecture 如果将LLAMA-7B模型参数量化为4bit,则存储模型参数需要3.3GB。那么,至少PIM chip 的存储至少要4GB。 AiM单个bank为32MB,单个die 512MB,至少需要8个die的芯片。8个die集成在一个芯片上。 提供816bank级别的访存带…

解决IDEA的Project无法正常显示的问题

一、问题描述 打开IDEA,结果发现项目结构显示有问题: 二、解决办法 File -> Project Structure… -> Project Settings (选Modules),然后导入Module 结果: 补充: IDEA提示“The imported module settings a…

root MUSIC 算法补充说明

root MUSIC 算法补充说明 多项式求根root MUSIC 算法原理如何从 2 M − 2 2M-2 2M−2 个根中确定 K K K 个根从复数域上观察 2 M − 2 2M-2 2M−2 个根的分布 这篇笔记是上一篇关于 root MUSIC 笔记的补充。 多项式求根 要理解 root MUSIC 算法,需要理解多项式求…

面试题-01

1、JDK 和 JRE 和 JVM 分别是什么,有什么区别? JDK(Java Development Kit,Java 软件开发工具包) JDK(Java Development Kit):JDK 是 Java 开发⼯具包,包含了编写、编译…

社区居家养老新选择,全视通智慧方案让长者生活更安心

随着人口老龄化趋势加剧,养老问题已经成为社会各界关注的焦点。我国政府积极采取相关措施,加速推动养老服务业的健康发展。2023年5月,《城市居家适老化改造指导手册》发布,针对城市老年人居家适老化改造需求,提出了47项…

Linux线程(1)--线程的概念 | 线程控制

目录 前置知识 线程的概念 Linux中对线程的理解 重新定义进程与线程 重谈地址空间 线程的优缺点 线程的优点 线程的缺点 线程异常 线程的用途 Linux线程 VS 进程 线程控制 创建线程 线程等待 线程终止 线程ID的深入理解 前置知识 我们知道一个进程有属于自己的P…

python学习24

前言:相信看到这篇文章的小伙伴都或多或少有一些编程基础,懂得一些linux的基本命令了吧,本篇文章将带领大家服务器如何部署一个使用django框架开发的一个网站进行云服务器端的部署。 文章使用到的的工具 Python:一种编程语言&…