【动手学深度学习Pytorch】1. 线性回归代码

零实现

        导入所需要的包:

# %matplotlib inline
import random
import torch
from d2l import torch as d2l
import matplotlib.pyplot as plt
import matplotlib
import os

        构造人造数据集:假设w=[2, -3.4],b=4.2,存在随机噪音(均值为0,方差为0.001的正态分布噪声),函数拟合为y = w^{T}X + b + n。在构造数据集的过程中,首先X为正态分布(均值为0,方差为1,样本数/行数为num_examples,列数为len(w))

torch.normal(mean, std, *, generator=None, out=None):生成指定输出尺寸的正态分布随机数张量

torch.mv():矩阵和向量的乘积,此处X为矩阵,w为向量

def synthetic_data(w, b, num_examples):
    X = torch.normal(0, 1, (num_examples, len(w))) #均值为0方差为1的随机数,样本数,列数
    y = torch.mv(X, w) + b #y关于x的公式
    y += torch.normal(0, 0.001, y.shape) # 加入噪声项
    return X, y.reshape((-1,1)) #做成列向量返回
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

        查看数据集样本分布:

matplotlib.pyplot.scatter(x, y, s=None, c=None, marker=None, cmap=None, norm=None, vmin=None, vmax=None, alpha=None, linewidths=None, *, edgecolors=None, plotnonfinite=False, data=None, **kwargs):

        x,y:长度相同的数组,也就是我们即将绘制散点图的数据点,输入数据。

        s:点的大小,默认 20,也可以是个数组,数组每个参数为对应点的大小。

        c:点的颜色,默认蓝色 'b',也可以是个 RGB 或 RGBA 二维行数组。

        marker:点的样式,默认小圆圈 'o'。

        cmap:Colormap,默认 None,标量或者是一个 colormap 的名字,只有 c 是一个浮点数数组的时才使用。如果没有申明就是 image.cmap。

        norm:Normalize,默认 None,数据亮度在 0-1 之间,只有 c 是一个浮点数的数组的时才使用。

        vmin,vmax:亮度设置,在 norm 参数存在时会忽略。

        alpha:透明度设置,0-1 之间,默认 None,即不透明。

        linewidths:标记点的长度。

        edgecolors:颜色或颜色序列,默认为 'face',可选值有 'face', 'none', None。

        plotnonfinite:布尔值,设置是否使用非限定的 c ( inf, -inf 或 nan) 绘制点。

        **kwargs:其他参数。

detach():允许我们从计算图中分离出张量。当对一个张量调用detach()方法时,它会创建一个新的张量,这个新张量与原始张量共享数据,但它不再参与计算图的任何操作,对分离后的张量进行的任何操作都不会影响原始张量,也不会在计算图中留下任何痕迹。

plt.scatter(features[:,(1)].detach().numpy(),labels.detach().numpy(),1);
plt.show()

        遍历数据集,输出数据集内容:

len(): 返回对象(字符、列表、元组等)长度或项目个数(此处是张量的行数)

list(): 将元组转换为列表

range():创建一个整数列表

shuffle(): 随机打乱列表

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples)) #生成样本索引
    random.shuffle(indices) #样本随机读取没有特定顺序
    # 进行batch划分
    for i in range(0, num_examples, batch_size): #从i开始到i+batchsize
        batch_indices =  torch.tensor(indices[i:min(i + batch_size, num_examples)])
        # 截取切片:开始位置为i,结束位置为min函数的返回值
        # 返回值为i+batch_size和num_examples的值比较小的那个
        yield features[batch_indices], labels[batch_indices] #产生随机顺序的特征&标号

batch_size = 10

for X, y in data_iter(batch_size, features, labels):
    print(X, '\n', y)
    break

 

      定义参数、模型、损失函数以及优化算法:

torch.mutual():矩阵相乘

with torch.no_grad():所有计算得出的tensor的requires_grad都自动设置为False,不会进行自动求导

grad.zero_():将梯度置零(不然会发生累计的情况)

# 定义初始化模型参数
w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
# 定义模型
def linreg(X, w, b):
    return torch.matmul(X, w) + b
# 定义损失函数
def squared_loss(y_hat, y):
    return (y_hat - y.reshape(y_hat.shape))**2/2
# 定义优化算法
def sgd(params, lr, batch_size):
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()

        定义训练过程:

# 训练过程
lr = 0.01
num_epochs = 10
net = linreg
loss = squared_loss

for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):
        l = loss(net(X, w, b), y)
        l.sum().backward()
        sgd([w,b], lr, batch_size)
    with torch.no_grad():
        train_1= loss(net(features, w, b), labels)
        print(f'epoch{epoch + 1}, loss{float(train_1.mean()):f}')

简介实现

        导入所需要的包:

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l
import matplotlib.pyplot as plt

        创建人造数据集:

data.TensorDataset():将数据进行封装

data.DataLoader():将数据分批次处理

iter():获取列表的迭代器

next():获取下一个值

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b,1000)
def load_array(data_arrays, batch_size, is_train=True):
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

batch_size = 10
data_iter = load_array((features, labels), batch_size)

next(iter(data_iter))

初始化模型、模型参数、loss: 

nn.Sequential():实现模型层结构的简单排序

torch.optim.SGD():定义优化算法

torch.optim.SGD().step():进行模型的更新

# 使用框架的预定义好的层
from torch import nn
net = nn.Sequential(nn.Linear(2,1))
# 初始化模型参数
net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)
# 计算均方误差使用的是MSELoss类
loss = nn.MSELoss()
trainer = torch.optim.SGD(net.parameters(),lr=0.01)

        定义训练过程:

num_epochs = 3
for epoch in range(num_epochs):
    for X, y in data_iter:
        l = loss(net(X), y)
        trainer.zero_grad()
        l.backward()
        trainer.step()
    l = loss(net(features), labels)
    print(f'epoch{epoch + 1}, loss{1:f}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/917791.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【数据结构】树——顺序存储二叉树

写在前面 在学习数据结构前,我们早就听说大名鼎鼎的树,例如什么什么手撕红黑树大佬呀,那这篇笔记不才就深入浅出的介绍二叉树。 文章目录 写在前面一、树的概念及结构1.1、数的相关概念1.2、数的表示1.3 树在实际中的运用(表示文…

Linux常用命令,持续更新钟

在Linux系统中,你可以使用多种命令来拷贝和移动文件及目录。以下是常用的几个命令及其用法: 一、拷贝文件或目录 cp 命令 cp 命令用于拷贝文件或目录。 拷贝文件: cp source_file destination_file 例如: cp file1.txt /hom…

计算机视觉中的双边滤波:经典案例与Python代码解析

🌟 计算机视觉中的双边滤波:经典案例与Python代码解析 🚀 Hey小伙伴们!今天我们要聊的是计算机视觉中的一个重要技术——双边滤波。双边滤波是一种非线性滤波方法,主要用于图像去噪和平滑,同时保留图像的边…

Ubuntu 22.04 上快速搭建 Samba 文件共享服务器

Samba 简介 Samba 是一个开源软件,它扮演着不同操作系统间沟通的桥梁。通过实现 SMB(Server Message Block)协议,Samba 让文件和打印服务在 Windows、Linux 和 macOS 之间自由流动。 以下是 Samba 的特点: 跨平台兼…

在MATLAB中实现自适应滤波算法

自适应滤波算法是一种根据信号特性自动调整滤波参数的数字信号处理方法,其可以有效处理噪声干扰和信号畸变问题。在许多实时数据处理系统中,自适应滤波算法得到了广泛应用。在MATLAB中,可以使用多种方法实现自适应滤波算法。本文将介绍自适应…

AWTK-WIDGET-WEB-VIEW 发布

awtk-widget-web-view 是通过 webview 提供的接口,实现的 AWTK 自定义控件,使得 AWTK 可以方便的显示 web 页面。 项目网址: https://gitee.com/zlgopen/awtk-widget-web-view webview 提供了一个跨平台的 webview 接口,是一个非…

使用Web Workers提升JavaScript的并行处理能力

💓 博客主页:瑕疵的CSDN主页 📝 Gitee主页:瑕疵的gitee主页 ⏩ 文章专栏:《热点资讯》 使用Web Workers提升JavaScript的并行处理能力 使用Web Workers提升JavaScript的并行处理能力 使用Web Workers提升JavaScript的…

shell编程--传参与数学运算

探讨一下如何向shell脚本传递参数。 脚本传参 首先用vim创建一个脚本。 vim 脚本.sh 可以理解为其他编程语言的标准输出,例如C语言的输出%d,表标准输出数字。 用echo 执行文件名称是:$0 echo 第一个参数是:$1 echo 传递参数作为字符串显…

头歌-本关任务:使用GmSSL命令行,生成SM2私钥并对文件进行签名验证(第二关)。

第一关在网上找到了,但第二关没找到,在这里做一下补充:) 如果想认真学的话可以看看文档 国密SM2椭圆曲线密码标准http://gmssl.org/docs/sm2.html 内容为 GuetPython 的明文文件msg.txt 私钥sm2.pem 公钥sm2Pub.pem 使用sm2utl对msg.txt进行签名&…

elementui el-table中给表头 el-table-column 加一个鼠标移入提示说明

前言 在使用el-table 表格中有些表格的表头需要加入一些提示&#xff0c;鼠标移入则出现提示&#xff0c;非常实用&#xff0c;我是通过el-table中的el-tooltip实现的&#xff0c;以下的效果预览 代码实现 <el-table ref"multipleTable" :data"data"…

在kile 5中一个新工程的创建

这两天博主学习到了在kile5中创建一个工程&#xff0c;当然博主不会忘了小伙伴们的&#xff0c;这就和你们分享。 本次创建以STM32F103C8为例 创建过程&#xff1a; 1首先创建文件 名字随意&#xff0c;但也不要太随意&#xff0c;因为是外国软件&#xff0c;所以多少对中文…

PortSwigger WEB缓存欺骗

一、Web缓存概述 Web 缓存是位于源服务器和用户之间的系统。当客户端请求静态资源时&#xff0c;请求首先被定向到缓存。如果缓存不包含资源的副本&#xff08;称为缓存未命中&#xff09;&#xff0c;则请求将转发到源服务器&#xff0c;该服务器将处理并响应请求。然后&#…

nginx源码安装配置ssl域名

nginx源码安装 下载 wget http://nginx.org/download/nginx-1.24.0.tar.gz 解压 tar -zxvf nginx-1.24.0.tar.gz 下载openssl apt install openssl 安装nginx cd nginx-1.24.0 sudo apt-get install libpcre3 libpcre3-dev ./configure --prefix=/home/nginx24 --with-http_ss…

如何用WordPress和Shopify提升SEO表现?

选择合适的建站程序对于SEO优化非常重要。目前&#xff0c;WordPress和Shopify是两种备受推崇的建站平台&#xff0c;各有优势。 WordPress最大的优点是灵活性。它支持大量SEO插件&#xff0c;帮助你调整元标签、生成站点地图、优化内容结构等。这些功能让你能够轻松地提升网站…

RHCE的学习(20)

变量5种赋值方式 shell中变量赋值5种方式&#xff0c;其中采用name10的方法称A 直接赋值 nameB read命令 read v1C 使用命令行参数 &#xff08;$1 $2 $3 ..&#xff09; name$1D 使用命令的输入 username$(whoami)E 从文件读取 #cut -d : -f1 /etc/passwd > /user.listfor…

llama factory lora 微调 qwen2.5 7B Instruct模型

项目背景 甲方提供一台三卡4080显卡 需要进行qwen2.5 7b Instruct模型进行微调。以下为整体设计。 要使用 LLaMA-Factory 对 Qwen2.5 7B Instruct模型 进行 LoRA&#xff08;Low-Rank Adapters&#xff09;微调&#xff0c;流程与之前提到的 Qwen2 7B Instruct 模型类似。LoRA …

Python酷库之旅-第三方库Pandas(221)

目录 一、用法精讲 1036、pandas.DatetimeIndex.to_pydatetime方法 1036-1、语法 1036-2、参数 1036-3、功能 1036-4、返回值 1036-5、说明 1036-6、用法 1036-6-1、数据准备 1036-6-2、代码示例 1036-6-3、结果输出 1037、pandas.DatetimeIndex.to_series方法 10…

108. UE5 GAS RPG 实现地图名称更新和加载关卡

在这一篇里&#xff0c;我们将实现对存档的删除功能&#xff0c;在删除时会有弹框确认。接着实现获取玩家的等级和地图名称和存档位置&#xff0c;我们可以通过存档进入游戏&#xff0c;玩家在游戏中可以在存档点存储存档。 实现删除存档 删除存档需要一个弹框确认&#xff0…

后台管理系统(开箱即用)

很久没有更新博客了&#xff0c;给大家带上一波福利吧,大佬勿扰 现在市面上流行的后台管理模板很多,若依,芋道等,可是这些框架对我们来说可能会有点重,所以我自己从0到1写了一个后台管理模板,你们使用时候可扩展性也会更高 项目主要功能: 成员管理&#xff0c;部门管理&#…

【ubuntu】Geogebra

Geogebra 几何作图工具 是一款跨平台的几何作图工具软件&#xff0c; 目前已经覆盖了&#xff0c; windows&#xff0c;android&#xff0c; mac, linux 等操作系统。 Ubuntu 现状 Ubuntu 自带应用市场 Ubuntu 自带应用市场目前只有 Geogebra 4.0 版本&#xff0c; 不能画立…