分布式机器学习(Parameter Server)

分布式机器学习中,参数服务器(Parameter Server)用于管理和共享模型参数,其基本思想是将模型参数存储在一个或多个中央服务器上,并通过网络将这些参数共享给参与训练的各个计算节点。每个计算节点可以从参数服务器中获取当前模型参数,并将计算结果返回给参数服务器进行更新。

为了保持模型一致性,通常采用下列两种方法:

  1. 将模型参数保存在一个集中的节点上,当一个计算节点要进行模型训练时,可从集中节点获取参数,进行模型训练,然后将更新后的模型推送回集中节点。由于所有计算节点都从同一个集中节点获取参数,因此可以保证模型一致性。
  2. 每个计算节点都保存模型参数的副本,因此要定期强制同步模型副本,每个计算节点使用自己的训练数据分区来训练本地模型副本。在每个训练迭代后,由于使用不同的输入数据进行训练,存储在不同计算节点上的模型副本可能会有所不同。因此,每一次训练迭代后插入一个全局同步的步骤,这将对不同计算节点上的参数进行平均,以便以完全分布式的方式保证模型的一致性,即All-Reduce范式

PS架构

在该架构中,包含两个角色:parameter server和worker

parameter server将被视为master节点在Master/Worker架构,而worker将充当计算节点负责模型训练

整个系统的工作流程分为4个阶段:

  1. Pull Weights: 所有worker从参数服务器获取权重参数
  2. Push Gradients: 每一个worker使用本地的训练数据训练本地模型,生成本地梯度,之后将梯度上传参数服务器
  3. Aggregate Gradients:收集到所有计算节点发送的梯度后,对梯度进行求和
  4. Model Update:计算出累加梯度,参数服务器使用这个累加梯度来更新位于集中服务器上的模型参数

可见,上述的Pull Weights和Push Gradients涉及到通信,首先对于Pull Weights来说,参数服务器同时向worker发送权重,这是一对多的通信模式,称为fan-out通信模式。假设每个节点(参数服务器和工作节点)的通信带宽都为1。假设在这个数据并行训练作业中有N个工作节点,由于集中式参数服务器需要同时将模型发送给N个工作节点,因此每个工作节点的发送带宽(BW)仅为1/N。另一方面,每个工作节点的接收带宽为1,远大于参数服务器的发送带宽1/N。因此,在拉取权重阶段,参数服务器端存在通信瓶颈。

对于Push Gradients来说,所有的worker并发地发送梯度给参数服务器,称为fan-in通信模式,参数服务器同样存在通信瓶颈。

基于上述讨论,通信瓶颈总是发生在参数服务器端,将通过负载均衡解决这个问题

将模型划分为N个参数服务器,每个参数服务器负责更新1/N的模型参数。实际上是将模型参数分片(sharded model)并存储在多个参数服务器上,可以缓解参数服务器一侧的网络瓶颈问题,使得参数服务器之间的通信负载减少,提高整体的通信效率。

代码实现

定义网络结构:如上定义了一个简单的CNN

实现参数服务器:

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        else:
            device = torch.device("cpu")
 
        self.conv1 = nn.Conv2d(1,32,3,1).to(device)
        self.dropout1 = nn.Dropout2d(0.5).to(device)
        self.conv2 = nn.Conv2d(32,64,3,1).to(device)
        self.dropout2 = nn.Dropout2d(0.75).to(device)
        self.fc1 = nn.Linear(9216,128).to(device)
        self.fc2 = nn.Linear(128,20).to(device)
        self.fc3 = nn.Linear(20,10).to(device)
 
    def forward(self,x):
        x = self.conv1(x)
        x = self.dropout1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.dropout2(x)
        x = F.max_pool2d(x,2)
        x = torch.flatten(x,1)
 
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
 
        output = F.log_softmax(x,dim=1)
 
        return output

如上定义了一个简单的CNN

实现参数服务器:

class ParamServer(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = Net()
 
        if torch.cuda.is_available():
            self.input_device = torch.device("cuda:0")
        else:
            self.input_device = torch.device("cpu")
 
        self.optimizer = optim.SGD(self.model.parameters(),lr=0.5)
 
    def get_weights(self):
        return self.model.state_dict()
 
    def update_model(self,grads):
        for para,grad in zip(self.model.parameters(),grads):
            para.grad = grad
 
        self.optimizer.step()
        self.optimizer.zero_grad()

get_weights获取权重参数,update_model更新模型,采用SGD优化器

实现worker:

class Worker(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = Net()
        if torch.cuda.is_available():
            self.input_device = torch.device("cuda:0")
        else:
            self.input_device = torch.device("cpu")
 
    def pull_weights(self,model_params):
        self.model.load_state_dict(model_params)
 
    def push_gradients(self,batch_idx,data,target):
        data,target = data.to(self.input_device),target.to(self.input_device)
        output = self.model(data)
        data.requires_grad = True
        loss = F.nll_loss(output,target)
        loss.backward()
        grads = []
 
        for layer in self.parameters():
            grad = layer.grad
            grads.append(grad)
 
        print(f"batch {batch_idx} training :: loss {loss.item()}")
 
        return grads

Pull_weights获取模型参数,push_gradients上传梯度

训练

训练数据集为MNIST

import torch
from torchvision import datasets,transforms
 
from network import Net
from worker import *
from server import *
 
train_loader = torch.utils.data.DataLoader(datasets.MNIST('./mnist_data', download=True, train=True,
               transform = transforms.Compose([transforms.ToTensor(),
               transforms.Normalize((0.1307,),(0.3081,))])),
               batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('./mnist_data', download=True, train=False,
              transform = transforms.Compose([transforms.ToTensor(),
              transforms.Normalize((0.1307,),(0.3081,))])),
              batch_size=128, shuffle=True)
 
def main():
    server = ParamServer()
    worker = Worker()
 
    for batch_idx, (data,target) in enumerate(train_loader):
        params = server.get_weights()
        worker.pull_weights(params)
        grads = worker.push_gradients(batch_idx,data,target)
        server.update_model(grads)
 
    print("Done Training")
 
if __name__ == "__main__":
    main()

来源:分布式机器学习(Parameter Server) - N3ptune - 博客园 (cnblogs.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/33448.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

架构基本概念和架构本质

什么是架构和架构本质 在软件行业,对于什么是架构,都有很多的争论,每个人都有自己的理解。此君说的架构和彼君理解的架构未必是一回事。因此我们在讨论架构之前,我们先讨论架构的概念定义,概念是人认识这个世界的基础…

UWB超宽带定位技术的原理及定位方法

uwb定位技术即超宽带技术,它是一种无载波通信技术,利用纳秒级的非正弦波窄脉冲传输数据,因此其所占的频谱范围很宽。传统的定位技术是根据信号强弱来判别物体位置,信号强弱受外界 影响较大,因此定位出的物体位置与实际…

Redis入门(4)-list

redis中list数据会按照插入顺序进行排序,其底层是一个无头结点的双向链表,因此表头和表尾的操作性能较高,但中间元素操作性能较差。 1.lpush key element [element ] 从表头插入元素 lpush nosql redis hbase lpush nosql mongdb2.lrange…

数据结构--单链表的插入删除

数据结构–单链表的插入&删除 目标 单链表的插入(位插、前插、后插) 单链表的删除 单链表的插入 按为序插入(带头结点) ListInsert(&L,i,e):插入操作。在表L中的第i个位置上插入指定元素e。 思路:找到第i-1个结点,将新结点插入其…

Mysql架构篇--Mysql(M-M) 主从同步

文章目录 前言一、M-M 介绍:二、M-M 搭建:1.Master1:1.1 my.cnf 参数配置:1.2 创建主从同步用户:1.3 开启复制: 2.Master2:2.1 my.cnf 参数配置:2.2 创建主从同步用户:2.…

Android12之ServiceManager::addService注册服务的本质(一百五十八)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 人生格言: 人生从来没有捷径,只有行动才是治疗恐惧和懒惰的唯一良药. 更多原创,欢迎关注:Android…

iOS多语言解决方案全面指南

本文以及相关工具和代码旨在为已上线的iOS项目提供一种快速支持多语言的解决方案。由于文案显示是通过hook实现的,因此对App的性能有一定影响;除了特殊场景的文案显示需要手动支持外,其他任务均已实现自动化。 本文中的部分脚本代码基于 Chat…

【软件设计师暴击考点】网络安全等杂项高频考点暴击系列

👨‍💻个人主页:元宇宙-秩沅 👨‍💻 hallo 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍💻 本文由 秩沅 原创 👨‍💻 收录于专栏:软件…

SpringBoot 日志文件:日志的作用?为什么要写日志?

文章目录 🎇前言1.日志长什么样子?2.自定义打印日志2.1 在程序中得到日志对象2.2 使用日志对象打印日志 3.日志级别3.1 日志级别的分类与使用3.2 日志级别有什么用呢?3.3 日志级别的设置 4.日志持久化保存5.更方便的日志输出5.1 添加 lombok …

android用java生成crc校验位

在串口通信中,经常会用到后两位生成crc校验位的情况。 下面是校验位生成方法: public static String getCRC(String data) {data data.replace(" ", "");int len data.length();if (!(len % 2 0)) {return "0000";}in…

服务器技术(三)--Nginx

Nginx介绍 Nginx是什么、适用场景 Nginx是一个高性能的HTTP和反向代理服务器,特点是占有内存少,并发能力强,事实上nginx的并发能力确实在同类型的网页服务器中表现较好。 Nginx专为性能优化而开发,性能是其最重要的考量&#xf…

3-css高级特效-1

01-平面转换 简介 作用:为元素添加动态效果,一般与过渡配合使用 概念:改变盒子在平面内的形态(位移、旋转、缩放、倾斜) 平面转换也叫 2D 转换,属性是 transform 平移 transform: translate(X轴移动距…

最新导则下生态环评报告编制要求与规范

根据生态环评内容庞杂、综合性强的特点,依据生态环评最新导则,将内容分为4大篇章(报告篇、制图篇、指数篇、综合篇)、10大专题(生态环评报告编制、土地利用图的制作、植被类型及植被覆盖度图的制作、物种适宜生境分布图的制作、生物多样性测定、生物量及…

基于matlab基于预训练的膨胀双流卷积神经网络的视频分类器执行活动识别(附源码)

一、前言 此示例首先展示了如何使用基于预训练的膨胀 3-D (I3D) 双流卷积神经网络的视频分类器执行活动识别,然后展示了如何使用迁移学习来训练此类视频分类器使用 RGB 和来自视频的光流数据 [1]。 基于视觉的活动识别涉及使用一组视频帧预…

STM32外设系列—DHT11

文章标题 一、DHT11简介二、数据手册分析2.1 接口说明2.2 串行通信说明2.2.1 单总线通信2.2.2 单总线传输数据位定义2.2.3 时序图 三、DHT11程序设计3.1 初始化GPIO3.2 发送起始信号3.3 接收一个字节数据3.4 接收温湿度信息并校准 四、总结 一、DHT11简介 DHT11是一款常用的数…

快速点特征直方图(FPFH)描述子提取

快速点特征直方图(Fast Point Feature Histograms,FPFH)介绍 快速点特征直方图(Fast Point Feature Histograms,FPFH)是一种基于点的描述子,用于描述点云数据中的局部几何信息。FPFH描述子是在…

浅尝kubernetes

浅尝kubernetes 前言:我们早学习一门技术之前并不需要从头到尾的详细的查看一遍,只需要看一看是什么?能干什么?怎么用?即可! 一、了解kubernetes Kubernetes 也称为 K8s,是用于自动部署、扩缩和…

【C/C++实现进程间通信 二】消息队列

文章目录 前情回顾思路源码Publisher.cppSubscriber.cpp 效果 前情回顾 上一期已经讲解过了进程的相关概念以及进程间通信的实现原理,下面仅展示消息传递机制实现1进程间通信的相关代码。 思路 /*本项目主要用于以消息传递机制的方式进行进程间通信的测试。1.主要…

Odoo16 微信公众号模块开发示例

Odoo16 微信公众号模块开发示例 本模块基于 aiohttp asyncio 进行异步微信公众号接口开发, 仅实现了部分 API 仅供学习参考,更完善的同步接口请参考:wechatpy 或 werobot,可用来替代 模块中的 wechat client。 业务需求 小程序中需要用户…

pdf文档多页内插入统一图片

常用来添加公司logo、签名、印章等等 概括来说就是插入同一个图片,然后复制在每一页(自动) 用的是福昕pdf阅读器 首先打开pdf: 点击图像标注功能: 在弹出窗口中选择浏览,点击需要插入的图片&#xff08…