JAX深度学习库入门

JAX简介

https://www.bilibili.com/video/BV1Sb4y1b7rK/?spm_id_from=333.999.0.0&vd_source=b2549fdee562c700f2b1f3f49065201b

JAX is NumPy wiht Autograd , XLA and Composable (function) transformations, brought together for high-performance machine learning research on the CPU, GPU, and TPU.

  • JAX 可以进行异步调度。所以需要调用 .block_until_ready() 以确保计算已经实际发生。

  • XLA:XLA (Accelerated Linear Algebra)是Google为TensorFlow设计的一款编译器,主打JIT (Just-in-Time)编译和跨设备(CPU/GPU/TPU)执行,所以JAX介绍中凡是涉及到JIT、high-performance、CPU/GPU/TPU,都指的是XLA。使用XLA(一种加速线性代数计算的编译器)将Python和JAX代码JIT编译成优化的内核,可以在不同设备(例如gpu和tpu)上运行。而优化的内核是为高吞吐量设备(例如gpu和tpu)进行编译,它与主程序分离但可以被主程序调用。JIT编译可以用jax.jit()触发。XLA相对于JAX是一个更加底层的概念,JAX中的算子jax.lax看做是对XLA算子的Python封装。

from jax import lax
  • NumPy:NumPy就不用提了,Python生态下只要涉及到数据分析/机器学习/数值计算中对数组/tensor进行操作,都离不开它,不夸张的说,NumPy API已经成为了数组/tensor操作的半个工业标准,包括各家深度学习框架中对tensor操作的函数接口也都是尽量靠近NumPy,JAX则更夸张,jax.numpy重新实现一套了NumPy API ,让用户从NumPy无缝切入JAX,jax.numpy中的操作/算子是对更底层的jax.lax的封装,与 NumPy 数组不同,JAX 数组始终是不可变的,JAX 速度是 NumPy 的 150 倍以上:
from jax import numpy as jnp
  • Autograd:这里的Autograd是哈佛大学HIPS实验室在14年开始开发的一款自动微分框架,特点是可以对Python/NumPy函数进行高阶求导。自动微分框架除了可以应用于数值计算,它还是深度学习框架的核心,可惜的是,由于性能(纯Python,只有CPU版本)以及其他原因,autograd库并没有推广起来,但是它却实实在在启发到了后续的torch-autograd、Chainer以及PyTorch中的autograd模块。直接看个例子,一个简单的函数 f ( x ) f(x) f(x),顺便求一下一阶、二阶、三阶导函数:
    在这里插入图片描述
from autograd import grad

def f(x):
    return x**3 + 2*x

grad_f = grad(f)  # 一阶导函数
grad_grad_f = grad(grad_f)  # 两次grad组合,就是二阶导函数
grad_grad_grad_f = grad(grad_grad_f)  # 三次grad组合,就是三阶导函数
print(grad_f(2.), grad_grad_f(2.), grad_grad_grad_f(2.))
# 14.0 12.0 6.0
  • Composable (function) transformations (可组合的函数转换):composable transformations是JAX的核心,其实就是高阶函数 (Higher-order function),transformation的输入是Python函数,输出也是函数。JAX中经常用到的transformation主要有四个:

    • grad: reverse mode自动微分,用在深度学习中足够了
    • jit : JIT实时编译,调用XLA进行JIT编译,用于优化代码
    • vmap: vectorization/batching自动向量化/批处理,将函数扩展为支持批处理
    • pmap : parallelization并行化计算,轻松实现数据并行 (data parallelism),类似PyTorch的DistributedDataParallel
from jax import grad, jit, vmap, pmap

grad

grad只是JAX自动微分机制中最基本的一个transform,实际上JAX支持前向(forward-mode)自动微分后向(reverse-mode)自动微分以及二者的任意组合, 感兴趣的同学可以去查看jvp和vjp 的文档。考虑到常见的深度学习任务,grad绰绰有余, 其他transform这里就不介绍了,实际上是我没用过,压根没那个能力介绍。

from jax import numpy as jnp
from jax import grad

def f(x):
    return jnp.sum(x * x)  # 函数输出只能是标量

grad_f = grad(f)
grad_f(jnp.array([1, 2, 3.]))
# DeviceArray([2., 4., 6.], dtype=float32)

grad不但好用,而且数学上更直观,如果我们不局限在深度学习领域,从优化 (optimization)的角度看,大多数机器学习模型的学习都可以表示为: y ~ = f ( x ) ,   max ⁡ y p ( y ∣ x ) ,   max ⁡ y p ( x , y ) p ( x ) \tilde{y}=f(x), \ \max _{y} p(y \mid x), \ \max _{y} \frac{p(x, y)}{p(x)} y~=f(x), maxyp(yx), maxyp(x)p(x,y) 的一种

LR可以表示为 f ( x ) f(x) f(x) ,神经网络也可以表示为 f ( x ) f(x) f(x) ,损失函数是 l o s s = g ( f ( x ) , y ) loss=g(f(x),y) loss=g(f(x),y) ,如果用SGD算法来解决,需要计算参数的梯度,想一下高数课上我们是怎么做的,直接对损失函数求导函数 g r a d ( g ) grad(g) grad(g) ,然后代入 x x x ,现在 grad 用的就是这种方式。并且这种方式在数学上可以自然的泛化到高阶导数优化求解问题上。

jit

jit 是用户显式的调用XLA对代码进行优化(包括算子融合、内存优化等),执行时间可能缩短很多:

import numpy as np
from jax import numpy as jnp
from jax import jit

def norm(X):
    X = X - X.mean(0)
    return X / X.std(0)

norm_compiled = jit(norm)
X = jnp.array(np.random.rand(10000, 100))

%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()
# 585 µs ± 85.4 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 216 µs ± 12.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

# 好像提升不是很显著,再来看一个例子
from jax import random

key = random.PRNGKey(0)

def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jit(selu)
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()
%timeit selu_jit(x).block_until_ready()
# 1.06 ms ± 26.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# 187 µs ± 19.6 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

# 哦豁,效果还不错

vmap

vmap可以自动让函数支持batching,看个例子,原始函数表示的是向量-向量乘法,使用vmap可以得到矩阵-向量乘法的函数:

from jax import numpy as jnp
from jax import vmap


def vec_vec_dot(x, y):
    """vector-vector dot, ([a], [a]) -> []
    """
    return jnp.dot(x, y)

x = jnp.array([1,1,2])
y = jnp.array([2,1,1,])
vec_vec_dot(x, y)
# DeviceArray(5, dtype=int32)

mat_vec = vmap(vec_vec_dot, in_axes=(0, None), out_axes=0)  # ([b,a], [a]) -> [b]      (b is the mapped axis)
xx = jnp.array([[1,1,2], [1,1,2]])
mat_vec(xx, y)
# DeviceArray([5, 5], dtype=int32)

解释下vmap中的in_axesout_axees两个参数,前者表示对输入参数中哪一个的哪一维度进行batch扩充,这里(0, None)表示对x的第0维扩充,由原来的[a] -> [b,a]。后者表示对返回结果的哪一维度进行扩充,这里表示由原来的[] - > [b]

pmap

pmap让并行编程变的非常丝滑,可以用于数据并行训练,注意pmap包含了jit操作,下面我就在TPU v3-8 VM演示下:

import jax
from jax import numpy as jnp
from jax import pmap

jax.device_count()  # 8个core
# 8

jax.devices()
"""
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
"""

x = jnp.arange(8)
y = jnp.arange(8)

vmap(jnp.add)(x, y)
# DeviceArray([ 0,  2,  4,  6,  8, 10, 12, 14], dtype=int32)

pmap(jnp.add)(x, y)
# ShardedDeviceArray([ 0,  2,  4,  6,  8, 10, 12, 14], dtype=int32)

看到上面vmap和pmap执行后的区别没,一个返回数据类型是DeviceArray,一个则是SharedDeviceArray,后者表示数据分散在多个device中。

组合

上面介绍的transformation不仅仅可以单兵作战,最重要的是可以任意组合,比如

pmap(vamp(some_func))
jit(grad(grad(vmap(some_func))))

jaxpr

稍微聊一下transformation背后的故事,JAX中定义了一种中间表示语言(jaxpr),每个transformation的执行都分两步:

  • 先将原Python函数翻译为jaxpr,这个过程被称为"tracing"
  • 再对jaxpr进行transform (转换),可以将每个transformation看作一个独立的jaxpr interpreter,对于JAX中每个原子操作 (primitive)都有相应的转换规则

jaxpr的优势是语法简单,相比于直接对Python函数transform,对jaxpr进行transform容易得多。
在这里插入图片描述

Flax+JAX 实现 NN model

有了jax.numpyjax.gradjax.pmapjax.jit,现在就可以编写网络,实现训练过程了,但是想象下用NumPy实现一个ResNet,实现一个Transformer,能做,但是也太复杂了,但是我们可以使用Flax,一个基于JAX的NN library,来轻松实现网络训练流程。

https://zhuanlan.zhihu.com/p/544216783

Flax/JAX开发者大会

pmap 实现 数据并行

https://zhuanlan.zhihu.com/p/544349452

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/513908.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

hive词频统计---文件始终上传不来

目录 准备工作: 文件内容: 创建数据库及表 将文件上传到:上传到/user/hive/warehouse/db1.db/t_word目录下 hive里面查询,始终报错:(直接查询也是不行) 解决方案: 准备工作&am…

【Django学习笔记(三)】BootStrap介绍

BootStrap介绍 前言正文1、BootStrap 快速了解2、初识BootStrap2.1 下载地址2.2 创建目录2.3 引入BootStrap2.4 使用BootStrap 3、BootStrap 组件&样式3.1 导航条3.2 栅格系统3.3 container3.3.1 container3.3.2 container-fluid 3.4 面板3.5 媒体对象3.6 分页3.7 图标3.7.…

【协议篇:Http与Https】

1. Http 1.1 Http的定义 超文本传输协议(Hypertext Transfer Protocol,HTTP)是用于分布式、协作式和超媒体信息系统的应用层协议。它是互联网上最广泛应用的数据通信协议之一,尤其对于万维网(WWW)服务而言…

SQLite下一代查询规划器(十)

返回:SQLite—系列文章目录 上一篇:SQLite 查询优化器概述(九) 下一篇:SQLite—系列文章目录 1. 引言 “查询规划器”的任务是弄清楚 找出完成 SQL 语句的最佳算法或“查询计划”。 从 SQLite 版本 3.8.0 &am…

Redis的值有5种数据结构,不同数据结构的使用场景是什么?

文章目录 字符串缓存计数共享Session限速 哈希缓存 列表消息队列文章列表栈队列有限集合 集合标签抽奖社交需求 有序集合排行榜系统 字符串 缓存 (1)使用原生字符类型缓存 优点:简单直观,每个属性都支持更新操作 缺点&#xff1…

vue2源码解析——vue中如何进行依赖收集、响应式原理

vue每个组件实例vm都有一个渲染watcher。每个响应式对象的属性key都有一个dep对象。所谓的依赖收集,就是让每个属性记住它依赖的watcher。但是属性可能用在多个模板里,所以,一个属性可能对应多个watcher。因此,在vue2中&#xff0…

基于单片机的超声波测距仪设计_kaic

摘 要 如今社会持续深化转型,在人工智能领域,传感器采集外部数据,经过处理器对数 据运算和处理,从而实现相应的功能。比如自动驾驶技术中,超声波传感器应用广泛, 超声波是一种频率在 20khz 以上的声波&…

如何保护IP地址?安全匿名上网的方法

当互联网成为每个家庭的重要组成部分后,IP地址就成了你的虚拟地址。您的请求从该地址开始,然后 Internet 将消息发送回该地址。那么,您担心您的地址被泄露吗? 对于安全意识高或者某些业务需求的用户,如果您正在寻找保护…

Zabbix6 - Web管理网络拓扑/端口流量监控配置手册

Zabbix6 - Web管理网络拓扑/端口流量监控配置手册 概述: 1)Zabbix能监视各种网络参数,保证服务器系统的安全运营;并提供灵活的通知机制以让系统管理员快速定位/解决存在的各种问题。 Zabbix由两部分构成,Zabbix Server与可选组件Zabbix Agent。通过C/S模式采集数据,通过B…

WEB 工程路径

WEB 工程路径 相对路径 使用相对路径来解决, 一个非常重要的规则:页面所有的相对路径,在默认情况下,都会参考当前浏览器地址栏的路径 http://ip:port/工程名/ 资源来进行跳转。 相对路径带来的问题 如上图,若在a.h…

MySQL进阶-----前缀索引、单例与联合索引

目录 前言 一、前缀索引 1. 语法 2. 如何选择前缀长度 3. 前缀索引的查询流程 二、单列索引与联合索引 三、索引设计原则 前言 本期是MySQL进阶篇当中索引的最后一期内容,这里我们主要接着上一期继续讲解前缀索引、单例与联合索引。(上一期链接&…

02 Python进阶:CGI编程

什么是CGI CGI是通用网关接口(Common Gateway Interface)的缩写,它是一种标准协议,用于Web服务器执行外部程序或脚本与Web浏览器进行交互。通过CGI,Web服务器能够动态生成网页内容,处理用户提交的表单数据…

从零开始 使用OMNET++结合VEINS,INET和SUMO的联合仿真

背景知识 当我们探索未来的交通系统和智能交通解决方案时,车辆到一切(Vehicle-to-Everything, V2X)通信技术显得尤为重要。V2X是指在车辆与车辆(V2V)、车辆与基础设施(V2I)、车辆与行人&#x…

年薪50w的网络安全工程师是如何炼成的?

前言 一、什么是网络安全 网络安全可以基于攻击和防御视角来分类,我们经常听到的 “红队”、“渗透测试” 等就是研究攻击技术,而“蓝队”、“安全运营”、“安全运维”则研究防御技术。 无论网络、Web、移动、桌面、云等哪个领域,都有攻与防…

“三电”打不过极氪007、比亚迪海豹?我还是选小米SU7 Pro

文 | AUTO芯球 作者 | 雷歌 我真是越来越烦小米汽车的这帮脑残粉了。 小米SU7热闹的发布会过后, 友商们开始在价格上狙击小米SU7。 这两天, 问界M7降价2万,22万多就能提, 极氪001现在送1.6万的选装配置, 银河E…

ChatGPT 上线新功能:DALL·E 可以编辑图片了

ChatGPT 上线新功能:DALLE 可以编辑图片了。可以对生成的图片内容进行修改、添加和删除。 前几天看到消息说还在内测中,今天就体验上了。 这是官方文档:https://help.openai.com/en/articles/9055440-editing-your-images-with-dall-e 界面…

【fastadmin】脚本模式下,日志钩子函数执行出现死循环,导致内存溢出奔溃

问题出现原因是想对项目中error级别的日志,接入钉钉告警,方便查看 于是使用钩子方法,日志写入完成后,自动调用自定义的告警方法中 1、在application/tags.php 中添加log_write_done > [app\\common\\behavior\\Common, ],2、在…

GPTs构建广告文案Agent(只需要一个网址链接即可生成文案及配图)

在大家已经有账号的前提下,我们来看看怎么做!!! 进入GPTs的编辑界面 如下图: 如何配置呢? Name:给我们的GPTs起个名字。Description:简单介绍一下,我们创建的GPTs是…

家用洗地机选购指南,哪款洗地机性价比高且用户评价好?

在当今快节奏的生活中,无论是商业场所还是家庭,清洁环境都是我们追求的目标之一。而一台高性价比的洗地机在频繁清洁地面时尤为重要。市面上的洗地机种类繁多,如何选择适合自己的洗地机成为了一个挑战。那么,到底哪款牌子的洗地机…

隐语SecretFlow实训营-第9讲:隐语多方安全计算在安全核对的行业实践

业务背景:安全核对产生的土壤 行业背景: 隐私计算技术,实现数据可用不可见、可用不可得。 产品方案:从试点到规模化的路 基于隐语SCQL的数据比对应用:风洞隐私安全核对。 支持 1 to N的规模化核对。 特色功能&…