MLX vs MPS vs CUDA:苹果新机器学习框架的基准测试

如果你是一个Mac用户和一个深度学习爱好者,你可能希望在某些时候Mac可以处理一些重型模型。苹果刚刚发布了MLX,一个在苹果芯片上高效运行机器学习模型的框架。

最近在PyTorch 1.12中引入MPS后端已经是一个大胆的步骤,但随着MLX的宣布,苹果还想在开源深度学习方面有更大的发展。

在本文中,我们将对这些新方法进行测试,在三种不同的Apple Silicon芯片和两个支持cuda的gpu上和传统CPU后端进行基准测试。

这里把基准测试集中在图卷积网络(GCN)模型上。这个模型主要由线性层组成,所以对于其他的模型也应该得到类似的结果。

创造环境

要为MLX构建环境,我们必须指定是使用i386还是arm架构。使用conda,可以使用:

 CONDA_SUBDIR=osx-arm64 conda create -n mlx python=3.10 numpy pytorch scipy requests -c conda-forge
 conda activate mlx

如果检查你的env是否实际使用了arm,下面命令的输出应该是arm,而不是i386(因为我们用的Apple Silicon):

 python -c "import platform; print(platform.processor())"

然后就是使用pip安装MLX:

 pip install mlx

GCN模型

GCN模型是图神经网络(GNN)的一种,它使用邻接矩阵(表示图结构)和节点特征。它通过收集邻近节点的信息来计算节点嵌入。每个节点获得其邻居特征的平均值。这种平均是通过将节点特征与标准化邻接矩阵相乘来完成的,并根据节点度进行调整。为了学习这个过程,特征首先通过线性层投射到嵌入空间中。

我们将使用MLX实现一个GCN层和一个GCN模型:

 import mlx.nn as nn
 
 class GCNLayer(nn.Module):
     def __init__(self, in_features, out_features, bias=True):
         super(GCNLayer, self).__init__()
         self.linear = nn.Linear(in_features, out_features, bias)
 
     def __call__(self, x, adj):
         x = self.linear(x)
         return adj @ x
 
 class GCN(nn.Module):
     def __init__(self, x_dim, h_dim, out_dim, nb_layers=2, dropout=0.5, bias=True):
         super(GCN, self).__init__()
 
         layer_sizes = [x_dim] + [h_dim] * nb_layers + [out_dim]
         self.gcn_layers = [
             GCNLayer(in_dim, out_dim, bias)
             for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:])
         ]
         self.dropout = nn.Dropout(p=dropout)
 
     def __call__(self, x, adj):
         for layer in self.gcn_layers[:-1]:
             x = nn.relu(layer(x, adj))
             x = self.dropout(x)
 
         x = self.gcn_layers[-1](x, adj)
         return x

可以看到,mlx的模型开发方式与tf2基本一样,都是调用

__call__

进行前向传播,其实torch也一样,只不过它自定义了一个forward函数。

下面就是训练

 gcn = GCN(
     x_dim=x.shape[-1],
     h_dim=args.hidden_dim,
     out_dim=args.nb_classes,
     nb_layers=args.nb_layers,
     dropout=args.dropout,
     bias=args.bias,
 )
 mx.eval(gcn.parameters())
 
 optimizer = optim.Adam(learning_rate=args.lr)
 loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)
 
 # Training loop
 for epoch in range(args.epochs):
 
     # Loss
     (loss, y_hat), grads = loss_and_grad_fn(
         gcn, x, adj, y, train_mask, args.weight_decay
     )
     optimizer.update(gcn, grads)
     mx.eval(gcn.parameters(), optimizer.state)
 
     # Validation
     val_loss = loss_fn(y_hat[val_mask], y[val_mask])
     val_acc = eval_fn(y_hat[val_mask], y[val_mask])

在MLX中,计算是惰性的,这意味着eval()通常用于在更新后实际计算新的模型参数。而另一个关键函数是nn.value_and_grad(),它生成一个计算参数损失的函数。第一个参数是保存当前参数的模型,第二个参数是用于前向传递和损失计算的可调用函数。它返回的函数接受与forward函数相同的参数(在本例中为forward_fn)。我们可以这样定义这个函数:

 def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
     y_hat = gcn(x, adj)
     loss = loss_fn(y_hat[train_mask], y[train_mask], weight_decay, gcn.parameters())
     return loss, y_hat

它仅仅包括计算前向传递和计算损失。Loss_fn()和eval_fn()定义如下:

 def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
     l = mx.mean(nn.losses.cross_entropy(y_hat, y))
 
     if weight_decay != 0.0:
         assert parameters != None, "Model parameters missing for L2 reg."
 
         l2_reg = sum(mx.sum(p[1] ** 2) for p in tree_flatten(parameters)).sqrt()
         return l + weight_decay * l2_reg
 
     return l
 
 def eval_fn(x, y):
     return mx.mean(mx.argmax(x, axis=1) == y)

损失函数是计算预测和标签之间的交叉熵,并包括L2正则化。由于L2正则化还不是内置特性,需要手动实现。

本文的完整代码:https://github.com/TristanBilot/mlx-GCN

可以看到除了一些细节函数调用的差别,基本的训练流程与pytorch和tf都很类似,但是这里的一个很好的事情是消除了显式地将对象分配给特定设备的需要,就像我们在PyTorch中经常使用.cuda()和.to(device)那样。这是因为苹果硅芯片的统一内存架构,所有变量共存于同一空间,也就是说消除了CPU和GPU之间缓慢的数据传输,这样也可以保证不会再出现与设备不匹配相关的烦人的运行时错误。

基准测试

我们将使用MLX与MPS, CPU和GPU设备进行比较。我们的测试平台是一个2层GCN模型,应用于Cora数据集,其中包括2708个节点和5429条边。

对于MLX, MPS和CPU测试,我们对M1 Pro, M2 Ultra和M3 Max进行基准测试。在两款NVIDIA V100 PCIe和V100 NVLINK上进行测试

MPS:比M1 Pro的CPU快2倍以上,在其他两个芯片上,与CPU相比有30-50%的改进。

MLX:比M1 Pro上的MPS快2.34倍。与MPS相比,M2 Ultra的性能提高了24%。在M3 Pro上MPS和MLX之间没有真正的改进。

CUDA V100 PCIe & NVLINK:只有23%和34%的速度比M3 Max与MLX,这里的原因可能是因为我们的模型比较小,所以发挥不出V100和NVLINK的优势(NVLINK主要GPU之间的数据传输大的情况下会有提高)。这也说明了苹果的统一内存架构的确可以消除CPU和GPU之间缓慢的数据传输。

总结

与CPU和MPS相比,MLX可以说是非常大的金币,在小数据量的情况下它甚至接近特斯拉V100的性能。也就是说我们可以使用MLX跑一些不是那么大的模型,比如一些表格数据。

从上面的基准测试也可以看到,现在可以利用苹果芯片的全部力量在本地运行深度学习模型(我一直认为MPS还没发挥苹果的优势,这回MPS已经证明了这一点)。

MLX刚刚发布就已经取得了惊人的影响力,并展示了巨大的潜力。相信未来几年开源社区的进一步增强,可以期待在不久的将来更强大的苹果芯片,将MLX的性能提升到一个全新的水平。

另外也说明了MPS(虽然也发布不久)还是有巨大的发展空间的,毕竟切换框架是一件很麻烦的事情,如果MPS能达到MLX 80%或者90%的速度,我想不会有人去换框架的。

最后说到框架,现在已经有了Pytorch,TF,JAX,现在又多了一个MLX。各种设备、各种后端包括:TPU(pytorch使用的XLA),CUDA,ROCM,现在又多了一个MPS。

https://avoid.overfit.cn/post/eb87d12f29eb4665adb43ad59fd3d64f

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/260360.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

神经科学与计算神经科学的蓬勃发展与未来趋势

导言 神经科学和计算神经科学是当前科学研究领域中备受关注的方向。本文将深入研究这两个领域的发展历程、遇到的问题、解决过程,以及未来的可用范围。我们还将关注在各国的应用现状以及未来的研究趋势,探讨如何在竞争中取胜,以及在哪些方面发…

PostgreSQL表中字段由字符串改为数组

需求:PostgreSQL数据库中的一张表的某些字段,之前存的是字符串,由于业务需求变更,需要存储多条数据,字段类型要改为数组,并保留原来的数据。 具体实现: 修改表中字段类型:把 etl_f…

全功能知识付费小程序系统源码是什么?有什么好处?

全功能知识付费小程序系统源码,是一个集课程管理、用户管理、支付管理、数据分析等于一体的综合性解决方案。它支持多种形式的课程内容,如视频、音频、图文等,满足不同用户的学习需求。同时,系统具备完善的支付功能,保…

Panoply查看nc文件的时间维

打开的是全球灌溉农田灌溉用水量遥感估算数据集(2011-2018),该文件以nc格式储存。nc格式文件就是一个多维的数据库。经纬度占了两维,可能还有时间维度,就是时空谱。 双击打开刚打开时只能看到2018年1月的灌溉数据 打…

快猫视频模板源码定制开发 苹果CMS 可打包成双端APP

苹果CMS快猫视频网站模板源码,可用于开发双端APP,后台支持自定义参数,包括会员升级页面、视频、演员、专题、收藏和会员系统等完整模块。还可以直接指定某个分类下的视频为免费专区,具备完善的卡密支付体系,无需人工管…

数合建模功能清单

在百度搜索 ”数合建模及可视化“,进入官网网站,有更详细的说明

k8s 中部署Jenkins

创建namespace apiVersion: v1 kind: Namespace metadata:name: jenkins创建pv以及pvc kind: PersistentVolume apiVersion: v1 metadata:name: jenkins-pv-volumenamespace: jenkinslabels:type: localapp: jenkins spec:#storageClassName: manualcapacity:storage: 5Giacc…

mysql部署 --(docker)

先查找MySQL 镜像 Docker search mysql ; 拉取mysql镜像,默认拉取最新的; 创建mysql容器,-p 代表端口映射,格式为 宿主机端口:容器运行端口 -e 代表添加环境变量,MYSQL_ROOT_PASSWORD是root用户…

ACE Tools环境配置指导

ACE Tools环境配置指导 简介: ACE Tools是一套为ArkUI-X应用开发者提供的命令行工具,支持在Windows/Ubuntu/macOS平台运行,用于构建OpenHarmony、HarmonyOS、Android和iOS平台的应用程序, 其功能包括开发环境检查,新建…

中缀表达式求值

中缀表达式是一种常见的数学表达式表示方法,它将操作符放在两个操作数的中间。例如,中缀表达式 "2 3" 表示将两个数相加。 中缀表达式求值的一般算法如下: 建立一个栈,用于存储操作数和运算符。从左到右扫描中缀表达…

如何做好买货查窜货这件事

窜货是大多数品牌都会遇到的渠道问题,店铺窜货是为了更多的利润空间,所以窜货还会伴随低价,治理窜货是品牌的义务,更是品牌的责任,品牌在管控渠道时应配合一套完整的控价流程,治理窜货也不例外,…

Vue 官方周报 #124 - 使用JSDoc记录组件属性

Hi &#x1f44b; 当你将鼠标悬停在IDE中的组件上时&#xff0c;显示组件属性所对应的描述&#xff0c;这个功能在开发过程中会很有用。你可以在传递给defineProps函数的TypeScript接口中使用JSDoc来实现这一点&#xff1a; MyComponent.vue <script setup lang"ts&…

Java 并发编程 —— Fork/Join 框架的原理详解

目录 一. 前言 二. 并发和并行 2.1. 并发 2.2. 并行 2.3. 分治法 三. ForkJoin 并行处理框架的理论 3.1. ForkJoin 框架概述 3.2. ForkJoin 框架原理 3.3. 工作窃取算法 四. ForkJoin 并行处理框架的实现 4.1. ForkJoinPool 类 4.2. ForkJoinWorkerThread 类 4.3.…

大华 DSS 城市安防数字监控系统 SQL 注入漏洞

漏洞简介 大华DSS数字监控系统itcBulletin接口对传入的数据没有预编译和充足的校验&#xff0c;导致该接口存在SQL注入漏洞&#xff0c;可通过注入漏洞获取数据库敏感信息。 资产测绘 app“dahua-DSS” 漏洞复现 POC: POST /portal/services/itcBulletin?wsdl HTTP/1.1 H…

图卷积神经网络发展

1. 图神经网络&#xff08;GNN&#xff09; 图神经网络的概念最早在2005年提出。2009年Franco博士在其论文 [2]中定义了图神经网络的理论基础。 本文中所提到的图均指图论中的图(Graph)。它是一种由若干个结点(Node)及连接两个结点的边(Edge)所构成的图形&#xff0c;用于刻画…

从源码到实践:深入了解鸿鹄电子招投标系统与电子招投标

在数字化采购领域&#xff0c;企业需要一个高效、透明和规范的管理系统。通过采用Spring Cloud、Spring Boot2、Mybatis等先进技术&#xff0c;我们打造了全过程数字化采购管理平台。该平台具备内外协同的能力&#xff0c;通过待办消息、招标公告、中标公告和信息发布等功能模块…

eslint严格规则检测问题报错

由于eslint严格规则检测问题报错时&#xff0c;更改两个文件package.json和vue.config.js package.json "no-unused-components": "off" vue.config.js lintOnSave: false 更改完成之后&#xff0c;重新run一下就OK了

DC电源模块在工业自动化中的关键应用案例分析

BOSHIDA DC电源模块在工业自动化中的关键应用案例分析 DC电源模块在工业自动化中有许多关键应用案例&#xff0c;以下是其中的一些&#xff1a; 1. 电机控制系统&#xff1a;在工业自动化中&#xff0c;电机控制是非常常见的应用。DC电源模块用于为电机提供稳定的直流电源&…

25考研指导规划建议(内含宝典级资料)

句句肺腑&#x1f4aa; &#x1f4cc;1.没有导师不喜欢要英语好的学生。四六级报名咨询查询链接 普通大学生至少要拿到四级证书四级真题和资料&#xff0c;有解析包括听力 &#x1f4cc;2.政治开始越早&#xff0c;离上岸越远 &#x1f4cc;3.每年7、8月是弃考高峰、11月再弃…

聪明高效能力广,AGI如何赋能内容管理?

文 | 智能相对论 作者 | 叶远风 毫无疑问&#xff0c;现在的大模型在技术比拼之外&#xff0c;如何通过产品化的方式走入到实际业务&#xff0c;是各厂商的着力点。 而一些一贯与数字化场景紧密融合的服务厂商&#xff0c;在大模型浪潮一开始就已经走在落地一线。 大数据基…