详解FedAvg:联邦学习的开山之作

FedAvg:2017年 开山之作

论文地址:https://proceedings.mlr.press/v54/mcmahan17a/mcmahan17a.pdf
源码地址:https://github.com/shaoxiongji/federated-learning
针对的问题:移动设备中有大量的数据,但显然我们不能收集这些数据到云端以进行集中训练,所以引入了一种分布式的机器学习方法,即联邦学习Federal Learning。在FL中,server将全局模型下放给各client,client利用本地的数据去训练模型,并将训练后的权重上传到server,从而实现全局模型的更新。
论文贡献

  1. 提出了联邦学习这个研究方向,简单来说就是从分散的存储于各设备的数据中训练模型;
  2. 提出了FedAvg算法;
  3. 通过实验验证了FedAvg的可靠性;

总结一下就是,本文提出了FedAvg算法,这种算法融合了client上的局部随机梯度下降和server上的模型平均。作者用该算法做了不少实验,结果表明FedAvg对于unbalanced且non-iid的数据有很好的鲁棒性,并且使得在非数据中心存储的数据上进行深度网络训练所需的通信轮次减少了好几个数量级。
算法介绍

  1. 联邦随机梯度下降算法FedSGD

设定固定的学习率η,对K个client的数据计算损失梯度:
g k = ▽ F k ( w t ) g_k=\bigtriangledown F_k(w_t) gk=Fk(wt)
server将聚合每个服务器计算的梯度,以此来更新模型参数:
w t + 1 ← w t − η ∑ k = 1 K n k n g k = w t − η ▽ f ( w t ) w_{t+1}\leftarrow w_t-\eta\sum\limits_{k=1}^K\frac{n_k}{n}g_k=w_t-\eta\bigtriangledown f(w_t) wt+1wtηk=1Knnkgk=wtηf(wt)

  1. 联邦平均算法FedAvg:

在client进行局部模型的更新:
w t + 1 k ← w t − η g k w_{t+1}^k\leftarrow w_t-\eta g_k wt+1kwtηgk
server对每个client更新后的权重进行加权平均:
w t + 1 ← ∑ k = 1 K n k n w t + 1 k w_{t+1}\leftarrow \sum_{k=1}^K \frac{n_k}{n}w_{t+1}^k wt+1k=1Knnkwt+1k
注意,在这里每个client可以在本地独立地多次更新本地权重,然后将更好的权重参数发给server进行加权平均。这样做的好处是不用每更新一次就去聚合,这大大减少了通信量。
FedAvg的计算量与3个参数有关:

  • C:每轮训练选择client的比例,每一轮通信时只选择C*K个client;(K为client总数)
  • E:每个client更新本地权重时,在本地数据集上训练E轮;
  • B:client更新权重时,每次梯度下降所使用的数据量,即本地数据集的batch size;

对于一个拥有 n k n_k nk个数据样本的client,每轮通信本地参数的更新次数为:
u k = E × n k B u_k=E\times\frac{n_k}{B} uk=E×Bnk
所以我们可知,FedSGD只是FedAvg的一个特例,即当参数 E = 1 , B = ∞ E=1,B=\infty E=1B=时,FedAvg等价于FedSGD。注: B = ∞ B=\infty B=意味着batch size大小就是本地数据集大小。
下面为FedAvg的算法流程图:
FedAvg算法流程图
实验设计与实现
Q1:在训练伊始,需不需要对模型进行统一初始化?
image.png
可见,采用不同的初始化参数进行模型平均,模型性能比两个父模型都差(左图);而统一初始化后,对模型的平均可以显著减少整个训练集的loss,模型性能优于两个父模型(右图)。
该结论是实现FL的重要支持,在每一轮通信时,server有必要发布全局模型,使各client采用相同的参数在本地数据集上进行训练,可以有效减少loss。
Q2:数据集怎么设置?
原文中主要研究了MNIST数据集和一个莎士比亚作品集构建的数据集,但我们在这里主要关注MNIST数据集和Cifar-10数据集,这两个数据集也是以后FL领域工作最常用的。
在模型选择方面,作者选择了多层感知机MLP和卷积神经网络CNN。
在数据集划分方面,作者假设有100个client,对于MNIST数据集,进行了iid和non-iid两种划分:

  • MNIST-iid:数据随机打乱分给100个client,每个client得到600个样例;
  • MNIST-non-iid:按数字label将数据集划分为200个大小为300的碎片,每个client两个碎片,意味着每个client至多只能获得两种label的样例;

对于Cifar-10数据集,做了iid划分。
Q3:实验咋做的?
作者指出,相比于传统模式下训练模型时计算开销为主通信开销较小的情况,在FL中,通信开销才是大头,因此减少通信开销才是我们需要关注的,作者提出可以通过加大计算以减少训练模型所需的通信轮数。作者提出主要有两种方法:提高并行度、增加每个client的计算量
而FedAvg的计算量在前面我们也给出过,再来看一下:
u k = E × n k B u_k=E\times\frac{n_k}{B} uk=E×Bnk
提高并行度:固定参数E,对C和B进行讨论。注:此处C=0时,算法也会选择一个client参与,详见上面的算法流程图。
2NN测试集acc 97%,CNN测试集acc 99%所需的通信轮数

  • B = ∞ B=\infty B=时,增加client的比例C,效果提升的优势较小;
  • B = 10 B=10 B=10时,效果显著改善了,特别是在non-iid情况下;
  • B = 10 , C ≥ 10 B=10,C\geq10 B=10,C10时,收敛速度明显改进,当client到一定数量后,收敛速度增加也不明显了;

增加每个client的计算量:根据公式,可以通过增加E或者减小B实现。
对测试集到达期望acc所需的通信轮数

  • 每个通信轮次内增加更多的本地SGD可以显著降低通信成本;
  • 对于unbalanced-non-iid的莎士比亚数据集减少的通信轮数更多,推测可能某些client有相对较大的本地数据集,这种情况下增加了本地训练的价值;

Q4:FedAvg VS FedSGD?
image.png
蓝色实现即为FedSGD。由图可知,FedAvg相比FedSGD不仅降低通信轮数,还具有更高的测试精度。推测是平均模型产生了类似Dropout的正则化效益。
Q5:加大每个client的计算量会不会导致过拟合?
image.png
加大每个client的计算量(主要体现在加大E),确实可能导致训练损失停滞或发散。所以在实际应用时,在训练后期减少各client的E,或者在loss有震荡的苗头时即刻停止,这样做有助于收敛。
Q6:在Cifar-10数据集上的表现如何?
如下图所示:
image.png
image.png
针对第一张图的一点吐槽,你去拿分布式深度学习去pk单机上的深度学习,去比通信轮数,这不是太不公平了。。。
总结展望
作者证明了FL在实践中是可行的,能够用相对较少的通信轮数训练出高质量的模型。并且提出未来的一个方向就是通过差分隐私、安全多方技术等隐私保护技术去组合FL以提供隐私保护。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/691881.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

GPT-4o仅排第二!北大港大等6所高校联手,发布权威多模态大模型榜单!

多模态大模型视频分析能力榜单出炉: Gemini 1.5 Pro最强,GPT-4o仅排第二? 曾经红极一时的GPT-4V屈居第三。 3.5研究测试:hujiaoai.cn 4研究测试:askmanyai.cn Claude-3研究测试:hiclaude3.com 最近&#…

python代码中参数的默认值

python中的函数,可以给形参指定默认值。 带有默认值的参数,可以在调用的时候不传参。 如上图所示,在给函数设定形参的时候可以给函数形参设定默认值,当然默认参数的形参应该在非默认形参的后面。 如果在调用函数的时候&#xff…

【机器学习】因TensorFlow所适配的numpy版本不适配,用anaconda降低numpy的版本

目录 0 TensorFlow最高支持的numpy版本 1 激活你的环境(如果你正在使用特定的环境) 2 查找可用的NumPy版本 3 安装特定版本的NumPy 4. 验证安装 5.(可选)如果你更改了base环境 0 TensorFlow最高支持的numpy版本 要使用 …

测试基础11:测试用例设计方法-等价类划分

课程大纲 1、概述 1.1测试用例设计方法意义 穷举测试:每种输入都测一次。最完备,但不现实。 使用设计方法,用最少的数据(成本),实现最大的测试覆盖。 1.2常用设计方法 ①等价类划分 ②边界值分析 ③错误推…

SpringBoot+Vue网上购物商城系统(前后端分离)

技术栈 JavaSpringBootMavenMySQLMyBatisVueShiroElement-UI 系统角色对应功能 用户商家管理员 系统功能截图

【安装笔记-20240608-Linux-免费空间之三维主机免费空间】

安装笔记-系列文章目录 安装笔记-20240608-Linux-免费空间之三维主机免费空间 文章目录 安装笔记-系列文章目录安装笔记-20240608-Linux-免费空间之三维主机免费空间 前言一、软件介绍名称:三维主机免费空间主页官方介绍 二、安装步骤测试版本:openwrt-…

ROS学习记录:栅格地图格式

一、机器人导航所使用的地图数据,就是ROS导航软件包里的map_server节点在话题 /map 中发布的消息数据,消息类型是nav_msgs消息包中的OccupancyGrid,它的中文意思的占据栅格,是一种正方形小格子组成的地图。 二、对障碍物进行俯视&…

基于STM32智能小车

一、前置准备 前置知识:需要学习stm32,建议去b站看江科大的视频,讲的很详细,学完串口那一块就可以制作了,软件用的是Keil5,开发语言C语言,手机连接蓝牙模块软件是蓝牙调试器。 需要准备的器件…

const详解

关键字const用来定义常量,如果一个变量被const修饰,那么它的值就不能再被改变。 但是,可以通过取地址进行修改。 将const 在指针前进行修饰,那么就修饰指针所指向的变量。 但是指针变量可以被修改。 将const 在指针后进行修饰&am…

轻松连接远程服务器SecureCRT for Mac/Windows

SecureCRT是一款功能强大的终端仿真器和文件传输工具,专为网络管理员、开发人员和系统工程师设计。它支持SSH、Telnet、RDP和串口等多种协议,提供安全、高效的远程访问和管理体验。SecureCRT具有多窗口/多标签管理、自定义终端仿真、颜色方案优化等高级功…

30-unittest生成测试报告(HTMLTestRunner插件)

批量执行完测试用例后,为了更好的展示测试报告,最好是生成HTML格式的。本文使用第三方HTMLTestRunner插件生成测试报告。 一、导入HTMLTestRunner模块 这个模块下载不能通过pip安装,只能下载后手动导入,下载地址是:ht…

DOM型xss靶场实验

DOM型xss可以使用js去控制标签中的内容。 我使用的是一个在线的dom型xss平台&#xff0c;靶场链接&#xff1a;Challenges 第一关Ma Spaghet!&#xff1a; Ma Spaghet! 关卡 <h2 id"spaghet"></h2> <script>spaghet.innerHTML (new URL(locatio…

数字校园的优势有哪些

数字化时代下&#xff0c;数字校园已成为教育领域一股显著趋势。数字校园旨在借助信息技术工具对传统校园进行改造&#xff0c;提供全新的教学、管理和服务方式。那么&#xff0c;数字校园究竟具备何种优势&#xff1f;现从三个方面为您详细介绍。 首先&#xff0c;数字校园为教…

【NI国产替代】产线测试:数字万用表(DMM),功率分析仪,支持定制

数字万用表&#xff08;DMM&#xff09; • 6 位数字表显示 • 24 位分辨率 • 5S/s-250KS/s 采样率 • 电源和数字 I/O 均采用隔离抗噪技术 • 电压、电流、电阻、电感、电容的高精度测量 • 二极管/三极管测试 功率分析仪 0.8V-14V 的可调输出电压&#xff0c;最大连…

鸿业的【管立得】设计的地下管线BIM模型如何导入到图新地球

0序&#xff1a; 在城乡建设行业&#xff0c;不论是园区的建设还是整个区划的智慧城市应用&#xff0c;地下管线都是很重要的组成元素。地下管线的直接测绘成果是管点表、管线表&#xff0c;存档及交付的成果多数是CAD文件&#xff0c;在智慧城市、市政工程、三维GIS信息化平台…

linux系统——ping命令

ping命令可以用来判断对远端ip的连通性&#xff0c;可以加域名也可以加公共ip地址 这里发送出56字节&#xff0c;返回64字节

How to: Add and Customize the Ribbon Skin List and Skin Gallery

皮肤列表和皮肤库允许用户选择皮肤。本文介绍如何在功能区中显示“皮肤列表”或“皮肤库”并对其进行自定义。 DevExpress演示中心中的大多数应用程序都允许您选择皮肤。例如&#xff0c;运行XtraGrid演示并导航到皮肤功能区页面以更改当前皮肤。 在功能区UI中显示皮肤列表或…

多模态模型是什么意思(国内外的AI多模态有哪些)

在人工智能和机器学习的领域&#xff0c;我们经常会遇到一些专业术语&#xff0c;这些术语可能会让初学者感到困惑。其中&#xff0c;"多模态模型"就是这样一个概念。 什么是AI多模态。它是什么意思呢&#xff1f; 那么&#xff0c;多模态模型是什么意思呢&#xff1…

前端工程化:基于Vue.js 3.0的设计与实践

这里写目录标题 《前端工程化&#xff1a;基于Vue.js 3.0的设计与实践》书籍引言本书概述主要内容作者简介为什么选择这本书&#xff1f;结语 《前端工程化&#xff1a;基于Vue.js 3.0的设计与实践》书籍 够买连接—>https://item.jd.com/13952512.html 引言 在前端技术日…

MySQL基础_10.约束

文章目录 第一章、约束1.1 约束的定义1.2 非空约束1.3 唯一性约束1.4 主键约束1.5 自增列1.6 外键约束1.7 CHECK约束1.8 DEFAULT约束 第一章、约束 1.1 约束的定义 约束是对表中字段的限制。 约束按照作用范围可以分为&#xff1a;列级约束和表级约束 列级约束&#xff1a;声…