动手学深度学习(二)---线性回归

文章目录

    • 1.线性回归从0实现
    • 2.线性回归简洁实现
    • 【相关方法】
      • torch.normal()

1.线性回归从0实现

从0开始实现整个方法,包括数据流水线、模型、损失函数和小批量随机梯度下降优化器
(1)导入需要的包

% matplotlib inline
import random
import torch
from d2l import torch as d2l  #pip install d2l torch torchvision

(2)构造数据
在这里插入图片描述

def synthetic_data(w,b,num_examples):
    """生成y = Xw + b + 噪声"""
    X = torch.normal(0, 1, (num_examples, len(w)))  #生成均值为0,方差为1的随机数X,大小为:num个样本,列数为len(w)
    Y = torch.matmul(X, w) + b
    Y += torch.normal(0, 0.01, Y.shape) # 噪音
    return X, Y.reshape((-1,1))  #将X,Y以列向量的形式返回

true_w = torch.tensor([2,-3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)
print('features:',features[0], '\nlabel:',labels[0])

在这里插入图片描述

(3)定义一个data_iter函数,该函数接收批量大小、特征矩阵和标签向量作为输入,生成大小为batch_size的小批量

def data_iter(batch_size, features, labels):
    num_examples = len(features) #样本数量
    indices = list(range(num_examples))  # 每个样本的index,0到num_example-1,存成list
    # print(indices)
    # 随机读取数据
    random.shuffle(indices)  # 打乱indices,即打乱index,以做到随机访问每一个样本
    # print(indices)
    # 从0到num_examples-1,步长为batch_size
    for i in range(0, num_examples, batch_size):
        batch_indices = torch.tensor(indices[i:min(i + batch_size,num_examples)])
        # yield:以(x,y)的形式返回
        yield features[batch_indices],labels[batch_indices]

batch_size = 10
for X, y in data_iter(batch_size, features, labels):
    print(X, '\n', y)
    break

(4)定义损失函数

def squared_loss(y_hat, y):
    """均方损失"""
    return (y_hat - y.reshape(y_hat.shape))**2 / 2

(5)定义优化算法

def sgd(params, lr, batch_size):
    """小批量随机梯度下降"""
    with torch.no_grad( ):
        for param in params:
            param -= lr * param.grad /batch_size
            param.grad.zero_()

(6)训练过程

lr = 0.03  #学习率
num_epochs = 3
net = linreg
loss = squared_loss

for epoch in range(num_epochs):
    for X,y in data_iter(batch_size, features, labels):
        l = loss(net(X, w, b), y)
        l.sum().backward()
        sgd([w,b], lr, batch_size)
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)
        print(f'epoch {epoch+1},loss {float(train_l.mean()):f}')
epoch 1,loss 0.037877
epoch 2,loss 0.000149
epoch 3,loss 0.000053

2.线性回归简洁实现

(1)使用深度学习框架实现线性回归模型,生成数据集

import numpy as np
import torch
from torch.utils import data  # 处理数据的相关模块
from d2l import torch as d2l

true_w = torch.tensor([2,-3.4])
true_b = 4.2
# 通过人工数据生成的函数合成所需要的数据
features, labels = d2l.synthetic_data(true_w, true_b, 100)

(2)调用框架中现有的API来读取数据

def load_array(data_arrays, batch_size, is_train=True):
    """构造一个PyTorch数据迭代器"""
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset,batch_size,shuffle=is_train)

batch_size = 10
data_iter = load_array((features, labels), batch_size)

next(iter(data_iter))
[tensor([[-2.6414, -0.0434],
         [-0.4269, -0.6939],
         [-1.0043,  1.0792],
         [-0.7633,  0.5864],
         [-1.0033, -0.0062],
         [ 0.2628, -0.9052],
         [-1.2950, -0.1154],
         [ 1.1647, -1.1407],
         [ 1.4055, -1.1912],
         [ 0.4733,  1.5539]]),
 tensor([[-0.9327],
         [ 5.7303],
         [-1.4759],
         [ 0.6839],
         [ 2.2185],
         [ 7.8047],
         [ 2.0032],
         [10.4156],
         [11.0725],
         [-0.1272]])]

(3)使用框架预定义好的层

from torch import nn

net = nn.Sequential(nn.Linear(2,1))  # 输入维度为2,输出维度为1

(4)初始化模型参数

net[0].weight.data.normal_(0, 0.01) #使用正态分布替换data的值
net[0].bias.data.fill_(0)
tensor([0.])

(5)计算均方误差使用MSELoss类

Loss = nn.MSELoss()

(6)实例化SGD

trainer = torch.optim.SGD(net.parameters(), lr=0.03)

(7)训练

num_epochs = 3
for epoch in range(num_epochs):
    for X, y in data_iter:
        l = Loss(net(X), y)
        trainer.zero_grad()
        l.backward()
        trainer.step()
    l = Loss(net(features), labels)
    print(f'epoch {epoch+1}, loss {l:f}')
epoch 1, loss 0.000105
epoch 2, loss 0.000105
epoch 3, loss 0.000105

【相关方法】

torch.normal()

torch.normal(mean, std, generator=None, out=None)

返回从单独的正态分布中提取的随机数的张量,该正态分布的均值是mean,标准差是std

import torch
x = torch.normal(0,1,(2,2))
print(x)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/174939.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

GCANet

2019、中科大港科、有代码 Chen D, He M, Fan Q, et al. Gated context aggregation network for image dehazing and deraining[C]//2019 IEEE winter conference on applications of computer vision (WACV). IEEE, 2019: 1375-1383. GitHub - cddlyf/GCANet: Implementation…

力扣每日一题-美化数组的最少删除数-2023.11.21

力扣每日一题:美化数组的最少删除数 开篇 今天的力扣每日一题居然写出来了,好开心,迫不及待地把题目分享出来,希望你也能把它狠狠拿下。 题目链接: 2216.美化数组的最少删除数 题目描述 代码思路 创建一个list集合来保存数组&a…

【Python3】【力扣题】338. 比特位计数

【力扣题】题目描述: 题解:从0到n的整数,逐一统计二进制中1的个数,记录在一个新列表中。 【Python3】代码: 1、解题思路:Python函数。 知识点:bin(...):转为二进制字符串&#xff…

Ubuntu环境下基于libxl库文件使用C++实现对表格的操作

功能 表格不存在则创建后再进行操作创建sheet添加新的工作表在sheet中增加数据设置单元格样式 相关配置 下载地址:libxl选择 LibXL for Linux 4.2.0 i386 x64 armhf aarch64 安装配置 1,使用 tar zxvf 文件名.tar.gz 进行文件解压2,创…

【电路笔记】-电源电压

电源电压 文章目录 电源电压1、概述1.1 交流发电机1.2 电池1.3 理想电压源1.4 实际电压源1.5 连接规则 2、相关源2.1 压控电压源 (VCVS)2.2 电流控制电压源 (CCVS) 3、总结 在本文中,我们详细介绍了称为电源电压的重要电子元件的架构、功能和使用。 我们首先提出理想…

达索系统3DEXPERIENCE WORKS 2024 Fabrication新功能

当发现产品的制造环节,以及因产品模型本身的设计而导致制造环节存在不合理性,从而导致加工制造成本增加。 快速判断,轻松协作 在达索系统3DEXPERIENCE WORKS 2024中我们可以快速的判断产品的可制造性,以及快速与前端设计沟通协作…

目标文件(ELF格式)

1.linux中有三类目标文件 **(1)可重定位目标文件(.o或者.a):**包含二进制代码和数据,其形式可以和其他目标文件进行合并,创建一个可执行目标文件。(.a文件是由很多个.o文件的集合&a…

【设备树添加节点】

节点结束位置都需要加分号 of_iomap 完成映射 of_property_read_u32_array of_property_read_string of_fine_node_by_path

C++实战学习笔记

文章目录 erase()uniquevector的insert()std::string::npos erase() (1)erase(pos,n); 删除从pos开始的n个字符,比如erase(0,1)就是删除第一个字符 (2)erase(position);删除position处的一个字符(position是个string类…

使用yum安装jdk,并配置环境变量

写在前面: yum安装的jdk,默认不用配置环境变量就可以java -version显示结果,但是却不能在系统环境变量中查看到JAVA_HOME,由于其他应用需要这个环境变量,比如hadoop,所以这里说明如何进行java的相关环境变量配置 1. y…

6.Gin 路由详解 - GET POST 请求以及参数获取示例

6.Gin 路由详解 - GET POST 请求以及参数获取示例 GET POST 请求以及参数获取示例 Get 请求:获取 Quary 参数 // 获取query参数示例:GET /user?uid20&namejack&page1 r.GET("/user", func(c *gin.Context) {// 获取参数// Query获取参…

Spring-IOC-Spring6和JUnit5集成

1、父工程pom.xml <properties><maven.compiler.source>17</maven.compiler.source><maven.compiler.target>17</maven.compiler.target><project.build.sourceEncoding>UTF-8</project.build.sourceEncoding><spring.version>…

Linux驱动开发笔记(四):设备驱动介绍、熟悉杂项设备驱动和ubuntu开发杂项设备Demo

若该文为原创文章&#xff0c;转载请注明原文出处 本文章博客地址&#xff1a;https://hpzwl.blog.csdn.net/article/details/134533533 红胖子网络科技博文大全&#xff1a;开发技术集合&#xff08;包含Qt实用技术、树莓派、三维、OpenCV、OpenGL、ffmpeg、OSG、单片机、软硬…

LVS+Keepalived 高可用群集

一、一.Keepalived工具介绍 专为LVS和HA设计的一款健康检查工具 • 支持故障自动切换&#xff08;Failover&#xff09; • 支持节点健康状态检查&#xff08;Health Checking&#xff09; • 官方网站&#xff1a;http://www.keepalived.org/ 二、Keepalived工作原理 • …

淘宝商品详情接口,商品属性接口,商品信息查询,商品详细信息接口,h5详情,淘宝APP详情

淘宝商品详情API接口可以使用淘宝开放平台提供的SDK或API来获取。这些接口可以用于获取商品的详细信息&#xff0c;如标题、价格、描述、图片等。 以下是使用淘宝开放平台API获取商品详情的步骤&#xff1a; 注册淘宝开放平台账号&#xff0c;并创建应用&#xff0c;获取应用…

【JavaEE初阶】 JavaScript基础语法——壹

文章目录 &#x1f38b;初识JavaScript&#x1f6a9;JavaScript 是什么&#x1f6a9;JavaScript 和 HTML 和 CSS 之间的关系&#x1f6a9;JavaScript 运行过程&#x1f6a9;JavaScript 的组成 &#x1f38d;前置知识&#x1f6a9;第一个JS程序&#x1f6a9;JavaScript 的书写形…

vue3 uniapp h5 安卓和iOS开发适配踩坑记录

font-size适配屏幕大小及iOS和安卓状态栏及安全距离的处理 App.vue <script setup lang"ts"> import { onLaunch, onShow, onHide } from "dcloudio/uni-app"; import ./main.scss onLaunch(() > {console.log("App Launch");var wid…

leetcode刷题详解——粉刷房子

1. 题目链接&#xff1a;LCR 091. 粉刷房子 2. 题目描述&#xff1a; 假如有一排房子&#xff0c;共 n 个&#xff0c;每个房子可以被粉刷成红色、蓝色或者绿色这三种颜色中的一种&#xff0c;你需要粉刷所有的房子并且使其相邻的两个房子颜色不能相同。 当然&#xff0c;因为…

六大排序详讲(直接插入排序+希尔排序+选择排序+堆排序+冒泡排序+快速排序)

文章目录 排序一、 排序的概念1.排序&#xff1a;2.稳定性&#xff1a;3.内部排序&#xff1a;4.外部排序&#xff1a; 二、插入排序1.直接插入排序2.希尔排序 三、选择排序1.直接选择排序方法一方法二直接插入排序和直接排序的区别 2.堆排序 四、交换排序1.冒泡排序2.快速排序…

Threejs_07 环境、透明度、纹理、ao、光照等贴图的渲染

老陈打码 继续学习老陈threejs 支持&#xff01;&#xff01;&#xff01;&#xff01;&#xff01;&#xff01; 下面用到的所有图片、资源、hdr文件都是老陈打码的原资源 链接&#xff1a;https://pan.baidu.com/s/1WWWHgekCIH7OnjI7S_3ZtQ 提取码&#xff1a;6666 Thre…