从 0 手撸一个 pytorch

背景介绍

最近抽空看了下 Andrej Karpathy 的视频教程 building micrograd,教程的质量很高。教程不需要任何前置机器学习基础,只需要有高中水平的数学基础即可。整个教程从 0 到 1 手撸了一个类 pytorch 的机器学习库 micrograd,核心代码不到 100 行。虽然为了简化没有实现复杂的矩阵运算,但是对于理解 pytorch 的设计思想有很大帮助。

动手实践

为了验证 micrograd 的可用性,先基于 micrograd 实现了简单的线性回归算法。

首先构造出数据集,我使用随机数作为 x,通过线性回归确定结果后增加必要的噪声,对应的构造方法如下所示:

import numpy as np

def get_train_dataset(num_samples, noise):
    x = np.random.rand(num_samples)

    y = 4 * x + 3 + np.random.normal(0, noise, num_samples)

    return x.tolist(), y.tolist()

可以看到最终期望的结果为 y = 4 * x + 3

接下来实现训练流程,线性回归的模型的初始值都使用随机值,持续跟踪训练过程中损失值与对应参数的变化,实现如下所示:

import numpy as np
from micrograd.engine import Value

def zero_grad(w, b):
    w.grad = 0
    b.grad = 0

def step(w, b, learning_rate):
    w.data -= learning_rate * w.grad
    b.data -=  learning_rate * b.grad

def train_loop():
    dataset_x, dataset_y = get_train_dataset(10, 0.01)
    learning_rate = 0.1
    w = Value(np.random.rand())
    b = Value(np.random.rand())
    epoch = 40

    print(f"Init w {w.data}, b {b.data}")
    for idx in range(epoch):
        loss = 0

        for x, y in zip(dataset_x, dataset_y):
            x_value, y_value = Value(x), Value(y)
            y_pred = x_value * w + b
            current_loss = (y_value - y_pred) ** 2
            loss += current_loss.data

            zero_grad(w, b)
            current_loss.backward()
            step(w, b, learning_rate)

        print(f"Epoch {idx} got loss: {loss}, w {w.data}, b {b.data}")

上面的实现中使用 zero_grad() 方法重置参数的梯度,使用 step() 方法实际更新模型参数,训练流程就实现在 train_loop() 中。最终结果如下所示:
在这里插入图片描述
可以看到经过 40 轮训练后,损失值从最初的 55.69 下降至 0.0016,而参数 w, b 也接近期望的目标。从实践结果来看,micrograd 确实能实现简单模型的训练。

通过上面的实践来看,micrograd 最核心的就是 Value,按照 Andrej Karpathy 的说法,不到 100 行实现的 Value 就已经完成的 pytorch 中的 Tensor 90% 的功能了,除了这部分核心功能之外,pytorch 更多的是做了效率上的优化。

流程梳理

在机器学习中,模型训练都是基于 梯度下降 来更新模型的。模型训练的过程一般分为前向传播和反向传播:

  • 前向传播会根据训练数据确定对应的损失值,对应于上面的实现如下:
x_value, y_value = Value(x), Value(y)
y_pred = x_value * w + b
current_loss = (y_value - y_pred) ** 2

前向传播就是根据模型确定预测值 y_pred, 基于 MSE 确定损失值 (y - y_pred)^2。前向传播相对容易理解。

  • 反向传播就是根据确定的损失值进行模型参数的调整,从而降低损失值,对应的实现就是:
zero_grad(w, b)
current_loss.backward()
step(w, b, learning_rate)

上面最核心的功能就是调用 current_loss.backward() 确定各个参数对应的梯度,然后在 step() 方法中对参数的值进行更新。

参数更新的方案是相对明确,就是减去梯度与学习率之积实现。因此主要关注如何确定参数的梯度。梯度的计算存在如下所示的关注点:

  1. 数学运算各个元素对应的梯度如何计算,这部分就是微积分中导数的计算;
  2. 链式法则;
  3. 复杂模型中包含上亿参数,如何确定参数各自的梯度;

实现细节

micrograd 最核心的实现位于 engine.py,主要关注 Value 类的实现。

初始化过程

关注初始化过程可以看到 Value 中包含的元素,实现如下:

def __init__(self, data, _children=(), _op=''):
    self.data = data
    self.grad = 0
    self._backward = lambda: None
    self._prev = set(_children)
    self._op = _op # the op that produced this node, for graphviz / debugging / etc

初始化阶段可以看到 Value 中最重要的两个参数,data 保存的是元素中的原始数据,grad 保存的是当前元素对应的梯度。

_backward() 方法保存的是反向传播的方法,用于计算反向传播的梯度

_prev 保存的是当前节点前置的节点,比如 y = w * x 中节点 y 对应的 _prev 保存的是 wx。通过不断的获取 _prev 节点,即可还原完整的运算链路。

数学运算支持

Value 中支持了不同的数学运算,首先以加法为例,实现如下所示:

def __add__(self, other):
    other = other if isinstance(other, Value) else Value(other)

    # 加法运算得到结果,同样是 Value 元素

    out = Value(self.data + other.data, (self, other), '+')

    # 加法反向传播函数

    def _backward():
        self.grad += out.grad
        other.grad += out.grad
    out._backward = _backward

    return out

前向传播计算的实现比较简单,直接基于 data 进行计算,通过加法运算生成了结果 out。同时将参与运算的元素 selfother 保存至 self._prev 中,方便还原运算链路。

out 对应的反向传播的方法 _backward() 是基于链式法则实现。举例如下:

c = a + b

那么 ∂l/∂a = ∂l/∂c * ∂c/∂a,而 ∂c/∂a = 1,因此 ∂l/∂a = ∂l/∂c,因此加法中元素的梯度就等于其结果的梯度。

那么为什么实现是 self.grad += out.grad 而不是 self.grad = out.grad 呢,因为单个元素涉及多个运算链路时,梯度是不同链路确定的梯度之和。

这个也带来一个隐患,每次重新计算梯度之前,需要将原有的梯度重置为 0。对应于上面的 zero_grad() 的实现。了解 pytorch 应该也会注意到 pytorch 训练过程中也存在类似情况。

同样来查看乘法运算,对应的实现如下:


def __mul__(self, other):
    other = other if isinstance(other, Value) else Value(other)
    out = Value(self.data * other.data, (self, other), '*')

    def _backward():
        self.grad += other.data * out.grad
        other.grad += self.data * out.grad
    out._backward = _backward

    return out

主要关注反向传播的实现,可以看到同样是链路法则的推演,举例如下:

c = a * b

那么 ∂l/∂a = ∂l/∂c * ∂c/∂a,而 ∂c/∂a = b, 因此 ∂l/∂a = ∂l/∂c * b, 因此就可以理解上面的实现了。

反向传播

通过上面的运算过程可以看到,通过不断保存其前置元素至 self._prev 中,可以构建出完整的运算链路图。而在运算过程中,元素反向传播计算的梯度的方法 _backward() 也被确定。因此反向传播就是从后往前调用 _backward() 来实现的:


def backward(self):

    topo = []
    visited = set()
    # 根据前置元素的关系构建拓扑排序的元素列表,保证最终调用时是从后往前的

    def build_topo(v):
        if v not in visited:
            visited.add(v)
            for child in v._prev:
                build_topo(child)
            topo.append(v)
    build_topo(self)

    # 最后元素的梯度为 1, 依次计算前置元素的梯度

    self.grad = 1
    for v in reversed(topo):
        v._backward()

最终反向传播就是调用 _backward() 即可确定各个元素的梯度。

总结

通过上面的流程可以很容易理解机器学习模型训练框架的设计方案,这一套流程也完全适用于 pytorch,可以帮助更好地理解 pytorch 的训练流程。整体总结下实现思路:

  1. 前向传播过程中会逐层计算运行结果,并确定结果与运算元素梯度之前的关系,在结果元素梯度确定后就可以确定运算元素的梯度;
  2. 反向传播就是按照从后往前依次确认各个元素的梯度,方便后续根据梯度更新元素对应的值;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/657927.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SAP PP学习笔记 - 错误 CX_SLD_API_EXCEPTION - Job dump is not fully saved (too big)

我这个错误是跑完MRP,然后在MD04查看在库/所有量一览, 点计划手配(Planned order 计划订单)生成 制造指图(Production order 生产订单), 到目前这几步都OK,然后在制造指图界面点保…

Linux之sshpass命令

介绍 sshpass是一个工具,用于通过SSH连接到远程服务器时自动输入密码。它允许您在命令行中指定密码,以便在建立SSH连接时自动进行身份验证。 安装 # 以centos为例 yum install sshpass -y 使用方法 sshpass [-f filename | -d num | -p password | …

C++笔试强训day35

目录 1.奇数位丢弃 2.求和 3.计算字符串的编辑距离 1.奇数位丢弃 链接https://www.nowcoder.com/practice/196141ecd6eb401da3111748d30e9141?tpId128&tqId33775&ru/exam/oj 数据量不大&#xff0c;可以直接进行模拟&#xff1a; #include <iostream> #incl…

MQTT 5.0 报文解析 06:AUTH

欢迎阅读 MQTT 5.0 报文系列 的最后一篇文章。在上一篇中&#xff0c;我们已经介绍了 MQTT 5.0 的 DISCONNECT 报文。现在&#xff0c;我们将介绍 MQTT 中的最后一个控制报文&#xff1a;AUTH。 MQTT 5.0 引入了增强认证特性&#xff0c;它使 MQTT 除了简单密码认证和 Token 认…

没有可用软件包 docker-ce。 错误:无须任何处理

特么的各种百度查看&#xff0c;全是一些废话&#xff01;&#xff01;&#xff01;centos7安装不上docker&#xff0c;都是老的代码了&#xff1a; yum install docker-ce 解决方案&#xff1a; # CentOS yum install docker-io

Jetson Orin Nano v6.0 + tensorflow2.15.0+nv24.05 GPU版本安装

Jetson Orin Nano v6.0 tensorflow2.15.0nv24.05 GPU版本安装 1. 源由2. 步骤2.1 Step1&#xff1a;系统安装2.2 Step2: nvidia-jetpack安装2.3 Step3&#xff1a;jtop安装2.4 Step4&#xff1a;h5py安装2.5 Step5&#xff1a;tensorflow安装2.6 Step6&#xff1a;jupyterlab安…

Windows搭建Nginx代理本地盘的文件(共享路径或本地路径)

文章目录 Windows搭建Nginx代理本地盘的文件 - 前言需求背景挂载网络共享路径检查连接状态下载Nginx编辑 Nginx 配置文件启动 Nginx检测Nginx是否成功启动使用方法远程共享路径示例本地文件示例 测试 Windows搭建Nginx代理本地盘的文件 - 前言 在开发过程中&#xff0c;确保文…

广东省保健食品行业协会批复成为“世界酒中国菜”活动指导单位

广东省保健食品行业协会正式批复成为“世界酒中国菜”系列活动指导单位&#xff0c;共促餐饮文化交流发展 近日&#xff0c;广东省保健食品行业协会正式批复荐酒师国际认证&#xff08;广州&#xff09;有限公司&#xff0c;成为备受瞩目的“世界酒中国菜”系列活动的指导单位…

8.2 数组遍历访问

本节必须掌握的知识点&#xff1a; 示例三十 代码分析 汇编解析 在上一节中介绍了数组相关的概念&#xff0c;而在本节中将介绍数组的使用。 8.2.1 示例三十 ■访问数组 示例代码三十 ●第一步&#xff1a;分析需求&#xff0c;设计程序…

基于C++11实现的手写线程池

在实际的项目中&#xff0c;使用线程池是非常广泛的&#xff0c;所以最近学习了线程池的开发&#xff0c;在此做一个总结。 源码&#xff1a;https://github.com/Cheeron955/Handwriting-threadpool-based-on-C-17 项目介绍 项目分为两个部分&#xff0c;在初版的时候&#x…

STM32——定时器

一、简介 *定时器可以对输入的时钟进行计数&#xff0c;并在计数值达到设定值时触发中断 *16位计数器、预分频器、自动重装寄存器的时基单元&#xff0c;在72MHz计数时钟下可以实现最大59.65s的定时 *不仅具备基本的定时中断功能&#xff0c;而且还包含内外时钟源选择、输入…

ubuntu使用oh my zsh美化终端

ubuntu使用oh my zsh美化终端 文章目录 ubuntu使用oh my zsh美化终端1. 安装zsh和oh my zsh2. 修改zsh主题3. 安装zsh插件4. 将.bashrc移植到.zshrcReference 1. 安装zsh和oh my zsh 首先安装zsh sudo apt install zsh然后查看本地有哪些shell可以使用 cat /etc/shells 将默…

平方回文数-第13届蓝桥杯选拔赛Python真题精选

[导读]&#xff1a;超平老师的Scratch蓝桥杯真题解读系列在推出之后&#xff0c;受到了广大老师和家长的好评&#xff0c;非常感谢各位的认可和厚爱。作为回馈&#xff0c;超平老师计划推出《Python蓝桥杯真题解析100讲》&#xff0c;这是解读系列的第73讲。 平方回文数&#…

监控云安全的9个方法和措施

如今&#xff0c;很多企业致力于提高云计算安全指标的可见性&#xff0c;这是由于云计算的安全性与本地部署的安全性根本不同&#xff0c;并且随着企业将应用程序、服务和数据移动到新环境&#xff0c;需要不同的实践。检测云的云检测就显得极其重要。 如今&#xff0c;很多企业…

windows tomcat服务注册和卸载

首页解压tomcat压缩包&#xff0c;然后进入tomcat bin目录&#xff0c;在此目录通过cmd进入窗口&#xff0c; 1&#xff1a;tomcat服务注册 执行命令&#xff1a;service.bat install tomcat8.5.100 命令执行成功后&#xff0c;会在注册服务列表出现这个服务&#xff0c;如果…

打造爆款活动:确定目标受众与吸引策略的实战指南

身为一名文案策划经理&#xff0c;我深知在活动策划的海洋中&#xff0c;确定目标受众并设计出能触动他们心弦的策略是何等重要。 通过以下步骤&#xff0c;你可以更准确地确定目标受众&#xff0c;并制定出有效的吸引策略&#xff0c;确保活动的成功&#xff1a; 明确活动目…

Unity【入门】环境搭建、界面基础、工作原理

Unity环境搭建、界面基础、工作原理 Unity环境搭建 文章目录 Unity环境搭建1、Unity引擎概念1、什么是游戏引擎2、游戏引擎对于我们的意义3、如何学习游戏引擎 2、软件下载和安装3、新工程和工程文件夹 Unity界面基础1、Scene场景和Hierarchy层级窗口1、窗口布局2、Hierarchy层…

企业如何实现数据采集分析展示一体化

在当今数字化时代&#xff0c;企业越来越依赖于数据的力量来驱动决策和创新。通过全量实时采集各类数据&#xff0c;并利用智能化工具进行信息处理&#xff0c;企业能够借助大数据分析平台深入挖掘数据背后的价值&#xff0c;从而为企业发展注入新动力。 一、企业痛点 随着数字…

基于单片机智能防触电装置的研究与设计

摘 要 &#xff1a; 针对潮湿天气下配电线路附近易发生触电事故等问题 &#xff0c; 对单片机的控制算法进行了研究 &#xff0c; 设 计 了 一 种 基 于 单片机的野外智能防触电装置。 首先建立了该装置的整体结构框架 &#xff0c; 再分别进行硬件设计和软件流程分析 &#xf…

水电表远程抄表:智能化时代的能源管理新方式

1.行业背景与界定 水电表远程抄表&#xff0c;是随着物联网技术发展&#xff0c;完成的一种新型的能源计量管理方式。主要是通过无线传输技术&#xff0c;如GPRS、NB-IoT、LoRa等&#xff0c;将水电表的信息实时传输到云服务器&#xff0c;进而取代了传统人工当场抄水表。这种…