【一起撸个DL框架】5 实现:自适应线性单元

  • CSDN个人主页:清风莫追
  • 欢迎关注本专栏:《一起撸个DL框架》
  • GitHub获取源码:https://github.com/flying-forever/OurDL

文章目录

  • 5 实现:自适应线性单元🍇
    • 1 简介
    • 2 损失函数
      • 2.1 梯度下降法
      • 2.2 补充
    • 3 整理项目结构
    • 4 损失函数的实现
    • 5 修改节点类(Node)
    • 6 自适应线性单元

5 实现:自适应线性单元🍇

1 简介

上一篇:【一起撸个DL框架】4 反向传播求梯度

上一节我们实现了计算图的反向传播,可以求结果节点关于任意节点的梯度。下面我们将使用梯度来更新参数,实现一个简单的自适应线性单元

我们本次拟合的目标函数是一个简单的线性函数: y = 2 x + 1 y=2x+1 y=2x+1,通过随机数生成一些训练数据,将许多组x和对应的结果y值输入模型,但是并不告诉模型具体函数中的系数参数“2”和偏置参数“1”,看看模型能否通过数据“学习”到参数的值。

图1:自适应线性单元的计算图

2 损失函数

2.1 梯度下降法

损失是对模型好坏的评价指标,表示模型输出结果与正确答案(也称为标签)之间的差距。所以损失值越小就说明模型越准确,训练过程的目的便是最小化损失函数的值

自适应线性单元是一个回归任务,我们这里将使用绝对值损失,将模型输出与正确答案之间的差的绝对值作为损失函数的值,即 l o s s = ∣ l − a d d ∣ loss=|l-add| loss=ladd

评价指标有了,可是如何才能达标呢?或者说如何才能降低损失函数的值?计算图中有四个变量: x , w , b , l x,w,b,l x,w,b,l,而我们训练过程的任务是调整参数 w , b w,b w,b的值,以降低损失。因此训练过程中的自变量是w和b,而把x和l看作常量。此时损失函数是关于w和b的二元函数 l o s s = f ( w , b ) loss=f(w,b) loss=f(w,b),我们只需要求函数的梯度 ▽ f ( w , b ) = ( ∂ f ∂ w , ∂ f ∂ b ) \triangledown f(w,b)=(\frac{\partial f}{\partial w},\frac{\partial f}{\partial b}) f(w,b)=(wf,bf),则梯度的反方向就是函数下降最快的方向。沿着梯度的方向更新参数w和b的值,就可以降低损失。这就是经典的优化算法:梯度下降法

2.2 补充

关于损失和优化的概念,大家可能还是有些模糊。上面损失只讲到了一个输入x值对应的模型输出与实际结果之间的差距,但使用整个数据集的平均差距可能更容易理解,就像中学的线性回归

图2所示,改变直线的斜率w,将改变直线与数据点的贴近程度,即改变了损失函数loss的值。

在这里插入图片描述
图2:损失与参数更新示意图

参考: 【深度学习】3-从模型到学习的思路整理_清风莫追的博客-CSDN博客

3 整理项目结构

我们的小项目的代码也渐渐多起来了,好的目录结构将使它更加易于扩展。关于python包结构的知识大家可以自行去了解,大致目录结构如下:

- example
- ourdl
	- core
		- __init__.py
		- node.py
	- ops
		- __init__.py
		- loss.py
		- ops.py
	__init__.py

给这个简单框架的名字叫做OurDL,使用框架搭建的计算图等程序放在example目录下。在ourdl/core/node.py中存放了节点基类和变量类的定义,在ourdl/ops/下存放了运算节点的定义,包括损失函数和加法、乘法节点等。

4 损失函数的实现

/ourdl/ops/loss.py中,

from ..core import Node

class ValueLoss(Node):
    '''损失函数:作差取绝对值'''
    def compute(self):
        self.value = self.parent1.value - self.parent2.value
        self.flag = self.value > 0
        if not self.flag:
            self.value = -self.value
    def get_parent_grad(self, parent):
        a = 1 if self.flag else -1
        b = 1 if parent == self.parent1 else -1
        return a * b

其中compute()方法很显然就是对两个输入作差取绝对值;get_parent_grad()方法求本节点关于父节点的梯度。有绝对值如何求梯度?大家可以画一画绝对值函数的图像。

5 修改节点类(Node)

ourdl/core/node.py

class Node:
    pass  # 省略了一些方法的定义,大家可以查看上一篇文章

    def clear(self):
        '''递归清除父节点的值和梯度信息'''
        self.grad = None
        if self.parent1 is not None:  # 清空非变量节点的值
            self.value = None
        for parent in [self.parent1, self.parent2]:
            if parent is not None:
                parent.clear()
    def update(self, lr=0.001):
        '''根据本节点的梯度,更新本节点的值'''
        self.value -= lr * self.grad  # 减号表示梯度的反方向

我在节点类中新增了两个方法,其中clear()用于清除多余的节点值和梯度信息,因为当节点值或梯度已经存在时会直接返回结果而不会递归去求了(get_grad()forward()的代码)。update()有一个学习率参数lr,更新幅度太大可能导致参数值一直在目标值左右晃悠,无法收敛

6 自适应线性单元

/example/01_esay/自适应线性单元.py

import sys
sys.path.append('../..')
from ourdl.core import Varrible
from ourdl.ops import Mul, Add
from ourdl.ops.loss import ValueLoss

if __name__ == '__main__':
    # 搭建计算图
    x = Varrible()
    w = Varrible()
    mul = Mul(parent1=x, parent2=w)
    b = Varrible()
    add = Add(parent1=mul, parent2=b)
    label = Varrible()
    loss = ValueLoss(parent1=label, parent2=add)
    # 参数初始化
    w.set_value(0)
    b.set_value(0)
    # 生成训练数据
    import random
    data_x = [random.uniform(-10, 10) for i in range(10)]  # 按均匀分布生成[-10, 10]范围内的随机实数
    data_label = [2 * data_x_one + 1 for data_x_one in data_x]
    # 开始训练
    for i in range(len(data_x)):
        x.set_value(data_x[i])
        label.set_value(data_label[i])
        loss.forward()  # 前向传播 --> 求梯度会用到损失函数的值
        w.get_grad()
        b.get_grad()
        w.update(lr=0.05)
        b.update(lr=0.1)
        loss.clear()
        print("w:{:.2f}, b:{:.2f}".format(w.value, b.value))
    print("最终结果:{:.2f}x+{:.2f}".format(w.value, b.value))
    

运行结果:

w:0.13, b:0.10
w:0.36, b:0.20
w:0.58, b:0.10
w:0.74, b:0.00
w:1.13, b:0.10
w:1.43, b:0.20
w:1.62, b:0.30
w:1.94, b:0.20
w:1.50, b:0.30
w:1.87, b:0.40
最终结果:1.87x+0.40

上面自适应线性单元的训练,已经能够大致展现深度学习模型的训练流程:

  • 搭建模型 --> 初始化参数 --> 准备数据 --> 使用数据更新参数的值

我们这里参数只更新了10次,结果就已经大致接近了我们的目标函数 y = 2 x + 1 y=2x+1 y=2x+1。大家可以试试更改学习率lr,训练数据集的大小,观察运行结果会发生怎样的变化。(必备技能:调参)


下节预告:激活函数与计算图的非线性拟合能力

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/19209.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

DAD-DAS模型

DAD-DAS模型 文章目录 DAD-DAS模型[toc]1 产品服务:需求方程2 实际利率:费雪方程3 通货膨胀:菲利普斯方程4 预期通货膨胀:适应性预期5 货币政策规则:泰勒方程6 动态总供给-总需求方程(DAS-DAD)7 总供给冲击模拟 1 产品服务:需求方…

Elasticsearch:NLP 和 Elastic:入门

自然语言处理 (Natural Language Processing - NLP) 是人工智能 (AI) 的一个分支,专注于尽可能接近人类解释的理解人类语言,将计算语言学与统计、机器学习和深度学习模型相结合。 AI - Artificial Inteligence 人工智能ML - Machine Learning 机器学习DL…

永远不该忘记!科技才是硬道理,手中没有剑,跟有剑不用,是两回事

今天是全国防灾减灾日,距离2008年汶川大地震也已经过去15年了。但时至今日,看到那些图像视频资料,那种触及灵魂的疼痛仍是存在的,2008年的大地震在每个中国人身上都留下了无法抚平的伤疤。 2008年是所有中国人都无法忘记的一年&am…

Ims跟2/3G会议电话(Conference call)流程差异介绍

2/3G Conference call 合并(Merged)通话前,两路电话只能一路保持(Hold),一路通话(Active)。 主叫Merged操作,Hold的一路会变成Active,进入会议通话。 例如终端A跟C通话,再跟B通话,此时B就是Active状态,C从Active变成Hold状态。Merged进入会议通话后,C又从Hold变…

docker安装elasticsearch

前言 安装es么,也没什么难的,主要网上搜一搜,看看文档,但是走过的坑还是需要记录一下的 主要参考这三份文档: Running the Elastic Stack on Docker docker简易搭建ElasticSearch集群 Running Kibana on Docker …

Python-exe调用-控制台命令行执行-PyCharm刷新文件夹

文章目录 1.控制台命令行执行1.1.subprocess.Popen1.2.os.system()1.3.subprocess.getstatusoutput()1.4.os.popen() 2.PyCharm刷新文件夹3.作者答疑 1.控制台命令行执行 主要四种方式实现。 1.1.subprocess.Popen import os import subprocess cmd "project1.exe&qu…

只下载rpm包而不安装(用于内网虚拟机使用)

这里写目录标题 问题:解决:1. 安装yum-utils2. 下载rpm包3. 将rpm包拷贝到离线的虚拟机并安装 遇到的问题:1. error while loading shared libraries: libXXX.so.X: cannot open shared object file: No such file2. wrong ELF class: ELFCLA…

C++学习day--10 条件判断、分支

1、if语句 if 语句的三种形态 形态1&#xff1a;如果。。。那么。。。 #include <iostream> using namespace std; int main( void ) { int salary; cout << " 你月薪多少 ?" ; cin >> salary; if (salary < 20000) { cout <&…

【博客系统】页面设计(附完整源码)

&#x1f389;&#x1f389;&#x1f389;点进来你就是我的人了博主主页&#xff1a;&#x1f648;&#x1f648;&#x1f648;戳一戳,欢迎大佬指点! 欢迎志同道合的朋友一起加油喔&#x1f93a;&#x1f93a;&#x1f93a; 目录 一、页面介绍 二、预期效果 1、博客列表页效…

大项目准备(2)

目录 中国十大最具发展潜力城市 docker是什么&#xff1f;能介绍一下吗&#xff1f; 中国十大最具发展潜力城市 按照人随产业走、产业决定城市兴衰、规模经济和交通成本等区位因素决定产业布局的基本逻辑&#xff0c;我们在《中国城市发展潜力排名&#xff1a;2022》研究报告…

websocket

&#x1f449;websocket_菜鸟教程*…*的博客-CSDN博客 目录 1、什么是Socket&#xff1f;什么是WebSocket&#xff1f; 2、WebSocket的通信原理和机制 3、WebSocket技术出现之前&#xff0c;Web端实现即时通讯的方法有哪些&#xff1f; 4、一个简单的WebSocket聊天小例子 …

prometheus监控数据持久化

前置条件 1.规划两台主机安装prometheus # kubectl get nodes --show-labels | grep prometheus nm-foot-gxc-proms01 Ready worker 62d v1.23.6 beta.kubernetes.io/archamd64,beta.kubernetes.io/oslinux,kubernetes.io/archamd64,kubernetes.io…

5款办公必备的好软件,你值得拥有

随着网络信息技术的发展&#xff0c;越来越多的人在办公时需要用到电脑了。如果你想提高办公效率&#xff0c;那么就少不了工具的帮忙&#xff0c;今天给大家分享5款办公必备的好软件。 1.文件管理工具——TagSpaces TagSpaces 是一款开源的文件管理工具,它可以通过标签来组织…

Linux一学就会——系统文件I/O

Linux一学就会——系统文件I/O 有几种输出信息到显示器的方式 #include <stdio.h> #include <string.h> int main() {const char *msg "hello fwrite\n";fwrite(msg, strlen(msg), 1, stdout);printf("hello printf\n");fprintf(stdout, &q…

体验洞察 | 原来它才是最受欢迎的CX指标?

一直以来&#xff0c;企业都在试图追踪他们能否在整个客户旅程中始终如一地提供卓越的客户体验&#xff08;Customer Experience&#xff0c;简称“CX”&#xff09;&#xff0c;并通过多个CX指标&#xff0c;如NPS&#xff08;净推荐值&#xff09;、CSAT&#xff08;客户满意…

openGL 环境搭建

刚入坑&#xff0c;每个包、每个项目都得重新配一遍&#xff0c;实在烦人&#xff0c;由于网上已有很多教程&#xff0c;故在此只简要介绍。 比较通用的安装方法如下&#xff1a; 优先下载&#xff0c;对应vs版本&#xff0c;32位&#xff0c;已经编译好的库。如果下载的是源代…

Java 远程debug,IDEA 远程 Debug 调试

有时候我们需要进行远程的debug&#xff0c;本文研究如何进行远程debug&#xff0c;以及使用 IDEA 远程debug的过程中的细节。看完可以解决你的一些疑惑。 配置 远程debug的服务&#xff0c;以SpringBoot微服务为例。 首先&#xff0c;启动SpringBoot需要加上特定的参数。 …

网页端操作提示「msg.js」库简介

这段时间我正在完成我的第一本个人图书&#xff0c;期间做了很多的案例&#xff0c;最近需要在网页端完成一个关于「恶意文本检测」的案例&#xff0c;为了让该案例表现的更加易用简洁、对用户友好&#xff0c;我需要在页面中添加一些用户操作提示信息&#xff0c;比如「正在加…

最适合家用的洗地机哪个牌子好?2023洗地机推荐

洗地机是目前众多清洁工具中的热门之选&#xff0c;我身边很多朋友都选择了洗地机来处理家居清洁&#xff0c;一说一&#xff0c;洗地机可以处理干湿垃圾&#xff0c;还都有一键自清洁功能&#xff0c;用起来确实方便简单。不过&#xff0c;市面上的洗地机参差不齐&#xff0c;…

QT软件开发: 获取CPU序列号、硬盘序列号、主板序列号 (采用wmic命令)

[TOC](QT软件开发: 获取CPU序列号、硬盘序列号、主板序列号 (采用wmic命令)) [1] QT软件开发: 获取CPU序列号、硬盘序列号、主板序列号 (采用wmic命令) https://blog.51cto.com/xiaohaiwa/5380259 一、环境介绍 QT版本: 5.12.6 环境: win10 64位 编译器: MinGW 32 二、功…