【从零开始实现联邦学习】

1. 环境配置如下

  • python3.7
  • pip install torch
  • pip install torchvision

2. 代码如下

原书的代码存在一点bug,现已被作者修复
Client端代码如下

import torch.utils.data

class Client(object):
    def __init__(self,conf,model,train_dataset,id=1):
        self.conf = conf                        # 配置文件
        self.local_model = model                # 客户端本地模型
        self.client_id = id                     # 客户端ID
        self.train_dataset = train_dataset      #客户端本地数据集
        all_range = list(range(len(self.train_dataset)))
        data_len = int(len(self.train_dataset)/self.conf['no_models'])
        indices = all_range[id*data_len:(id+1)*data_len]
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=conf["batch_size"],
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices)
        )

    def local_train(self, model):
        for name,param in model.state_dict().items():
            # 客户端首先用服务器端下发的全局模型覆盖本地模型
            self.local_model.state_dict()[name].copy_(param.clone())
        # 定义最优化函数器,用于本地模型训练
        optimizer = torch.optim.SGD(
            self.local_model.parameters(),
            lr=self.conf['lr'],
            momentum=self.conf['momentum']
        )
        # 本地模型训练
        self.local_model.train()
        for e in range(self.conf['local_epochs']):
            for batch_id,batch in enumerate(self.train_loader):
                data, target = batch
                if torch.cuda.is_available():
                    self.local_model.cuda()
                    data = data.cuda()
                    target = target.cuda()
                optimizer.zero_grad()
                output = self.local_model(data)
                loss = torch.nn.functional.cross_entropy(output, target)
                loss.backward()
                optimizer.step()
            print("Epoch %d done." % e)
        diff = dict()
        for name,data in  self.local_model.state_dict().items():
            diff[name] = (data - model.state_dict()[name])
        return diff

Server端代码如下

import torch.utils.data
import torchvision.datasets as datasets
from torchvision import models
from torchvision.transforms import transforms

# 服务端
class Server(object):
    def __init__(self, conf, eval_dataset):
        self.conf = conf
        # 服务器端的模型
        self.global_model = models.get_model(self.conf["model_name"])
        self.eval_loader = torch.utils.data.DataLoader(eval_dataset,batch_size=self.conf["batch_size"],shuffle=True)
        self.accuracy_history = []  # 保存accuracy的数组
        self.loss_history = []  # 保存loss的数组

    # 聚合各个服务器上传的信息
    def model_aggregate(self, weight_accumulator):
        # weight_accumulator存储了每一个客户端的上传参数变化值
        for name,data in self.global_model.state_dict().items():
            update_per_layer = weight_accumulator[name] * self.conf['lambda']
            if data.type() != update_per_layer.type():
                data.add_(update_per_layer.to(torch.int64))
            else:
                data.add_(update_per_layer)
    # 定义模型评估函数
    def model_eval(self):
        self.global_model.eval()
        total_loss = 0.0
        correct = 0
        dataset_size = 0
        for batch_id,batch in  enumerate(self.eval_loader):
            data,target = batch
            dataset_size += data.size()[0]
            if torch.cuda.is_available():
                self.global_model.cuda()
                data = data.cuda()
                target = target.cuda()
            output = self.global_model(data)
            # 把损失值聚合起来
            total_loss += torch.nn.functional.cross_entropy(output,target,reduction='sum').item()
            # 获取最大的对数概率的索引值
            pred = output.data.max(1)[1]
            correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
        # 计算准确率
        acc = 100.0 * (float(correct) / float(dataset_size))
        # 计算损失值
        total_l = total_loss / dataset_size
        # 将accuracy和loss保存到数组中
        self.accuracy_history.append(acc)
        self.loss_history.append(total_l)
        return acc,total_l

    def save_results_to_file(self):
        # 将accuracy和loss保存到文件中
        with open("fed_accuracy_history.txt", "w") as f:
            for acc in self.accuracy_history:
                f.write("{:.2f}\n".format(acc))

        with open("fed_loss_history.txt", "w") as f:
            for loss in self.loss_history:
                f.write("{:.4f}\n".format(loss))

聚合代码如下

import json
import random

import torch

from MyDataLoader import get_dataset
from chapter3.Client import Client
from chapter3.Server import Server

with open("conf.json",'r') as f:
    conf = json.load(f)

# 接下来分别定义一个服务端对象和多个客户端对象,用来模拟横向联邦训练场景

train_datasets,eval_datasets = get_dataset("./data/",conf["type"])
server = Server(conf,eval_datasets)
clients = []
# 创建多个客户端
for c in range(conf["no_models"]):
    clients.append(Client(conf,server.global_model,train_datasets,c))
# 每一轮迭代,服务端会从当前的客户端集合中随机挑选一部分参与本轮迭代训练,被选中的客户端调用本地训练接口local_train进行本地训练,
# 最后服务器调用模型聚合函数model——aggregate来更新全局模型,代码如下所示:
for e in range(conf["global_epochs"]):
    # 采样k个客户端参与本轮联邦训练
    candidates = random.sample(clients,conf['k'])
    # 初始化weight_accumulator并在GPU上(如果可用)
    weight_accumulator = {}
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    for name,params in server.global_model.state_dict().items():
        # 在指定设备上创建并初始化weight_accumulator中的张量
        weight_accumulator[name] = torch.zeros_like(params).to(device)

    for c in candidates:
        # 确保本地训练后的模型差异在正确设备上
        diff = c.local_train(server.global_model)
        for name,params in server.global_model.state_dict().items():
            weight_accumulator[name].add_(diff[name])

    server.model_aggregate(weight_accumulator)
    acc,loss = server.model_eval()
    print("Epoch %d ,acc:%f,loss: %f\n" % (e,acc,loss))

server.save_results_to_file()

数据集的加载

import torch.utils.data
import torchvision.datasets as datasets
from torchvision import models
from torchvision.transforms import transforms


def get_dataset(dir, name):
    if name == 'mnist':
        train_dataset = datasets.MINST(dir, train=True, download=True,transform=transforms.ToTensor())
        eval_dataset = datasets.MINST(dir, train=False, transform=transforms.ToTensor())
    elif name=='cifar':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914,0.4822,0.4465),(0.2023,0.1994,0.2010)),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        train_dataset = datasets.CIFAR10(dir, train=True, download=True, transform=transform_train)
        eval_dataset = datasets.CIFAR10(dir, train=False, transform=transform_test)
    return train_dataset, eval_dataset

配置文件如下

{
  "model_name" : "resnet18",

  "// comment1": "客户端的个数",
  "no_models" : 10,
  "type" : "cifar",
  "global_epochs" : 10,
  "local_epochs" : 3,

  "// comment2": "每一轮中挑选的机器数",
  "k" : 6,
  "batch_size" : 32,
  "lr" : 0.001,
  "momentum" : 0.0001,
  "lambda" : 0.1
}

3.结果如下

可以看到联邦学习的效果还是不如集中式学习,也有可能是因为我迭代的轮次不够。
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/752993.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【系统架构设计师】七、信息安全技术基础知识(网络安全技术|网络与信息安全风险|网络安全协议)

目录 一、网络安全技术 1.1 防火墙 1.2 入侵检测系统IDS 1.3 入侵防御系统IPS 1.4 杀毒软件 1.5 蜜罐系统 二、网络与信息安全风险 三、网络安全协议 四、相关推荐 五、历年真题练习 一、网络安全技术 1.1 防火墙 防火墙是在内部网络和外部因特网之间增加的一道安全…

使用自定义的shiro密码匹配器CredentialsMatcher完成密码验证

今天突然想研究一下shiro怎么匹配用户的密码。 我们使用shiro的API登录时,会先创建一个令牌对象,而经常用的令牌对象是UsernamePasswordToken,把用户输入的用户名和密码作为参数构建一个UsernamePasswordToken,然后通过Subject.l…

宏集物联网工控屏通过 S7 ETH 协议采集西门子 1200 PLC 数据

前言 为了实现和西门子PLC的数据交互,宏集物联网HMI集成了S7 PPI、S7 MPI、S7 Optimized、S7 ETH等多个驱动来适配西门子200、300、400、1200、1500、LOGO等系列PLC。 本文主要介绍宏集物联网HMI如何通过S7 ETH协议采集西门子1200 PLC的数据,文中详细介…

办公软件WPS与Office的区别

临近计算机考试很多同学在纠结我是报wps好?还是ms office好?下面就来详细说说。 1、wps属于国内金山公司的办公软件,里面包含word、Excel和PPT。考试是2021年开始的! 2、MS(Microsoft 微软) office属于美…

网易游戏如何基于 Apache Doris 构建全新湖仓一体架构

导读:随着网易游戏品类及产品的快速发展,游戏数据分析场景面临着越来越多的挑战,为了保证系统性能和 SLA,要求引入新的组件来解决特定业务场景问题。为此,网易游戏引入 Apache Doris 构建了全新的湖仓一体架构。经过不…

精益生产转型攻略:如何平稳过渡,避免业务震荡?

在当今快速变化的市场环境中,越来越多的企业开始关注并尝试实施精益生产,以提升生产效率、降低成本并增强竞争力。然而,转型并非一蹴而就,如何在确保精益生产实施效果的同时,又避免对企业的现有业务流程和组织结构产生…

【C++进阶9】异常

一、C语言传统的处理错误的方式 终止程序,如assert 如发生内存错误,除0错误时就会终止程序返回错误码 需要程序员自己去查找对应的错误 z如系统的很多库的接口函数都是通 过把错误码放到errno中,表示错误 二、C异常概念 异常:函…

anaconda卸载过程中出现fail to run pre-unistall报错

问题: 在使用Uninstall-Anaconda3.exe卸载程序时,出现报错: 解决方案: 把文件夹移动到C盘用户文件夹后再运行卸载程序。即可正常运行程序。

ping 出现的结果判断

ICMP协议发送包的时候 常见的ping反馈结果: 连接建立成功,Reply from 目标地址 目标主机不可达,Destination host unreachable 直接不能出交换机,到达不了交换机 请求时间超时,Request timed out 服务器到交换机…

一名HR,在招聘嵌入式开发岗位,为什么感觉一年比一年难?

在开始前刚好我有一些资料,是我根据网友给的问题精心整理了一份「嵌入式的资料从专业入门到高级教程」, 点个关注在评论区回复“888”之后私信回复“888”,全部无偿共享给大家!!! 1.嵌入式学用不一致, 高…

MySQL基础查询与复杂查询

基础查询 1、查询用户信息,仅显示用户的姓名与手机号,用中文显示列名。中文显示姓名列与手机号列。 2、根据商品名称进行模糊查询,模糊查询需要可以走索引,需要给出explain语句。使用explain测试给出的查询语句,需要显…

如何把mkv转成mp4?介绍一下将mkv转成MP4的几种方法

如何把mkv转成mp4?如果你有一个MKV格式的视频文件,但是需要将其转换为MP4格式以便更广泛地在各种设备和平台上播放和共享,你可以通过进行简单的文件格式转换来实现。转换MKV到MP4格式可以提供更好的兼容性,并确保你的视频文件能够…

vue2(vue-cli3x[vue.config.js])使用cesium新版(1.117.0)配置过程

看来很多解决方法都没有办法,最后终于。呜呜呜呜 这里我用的是vue-cli去搭建的项目的vue2 项目,其实不建议用vue2搭配cesium。因为目前cesium停止了对vue2的版本更新,现在默认安装都是vue3版本,因此需要控制版本,否则…

Rockchip RK3588 - Rockchip Linux Recovery rkupdate升级

---------------------------------------------------------------------------------------------------------------------------- 开发板 :ArmSoM-Sige7开发板eMMC :64GBLPDDR4 :8GB 显示屏 :15.6英寸HDMI接口显示屏u-boot &a…

【产品经理】订单处理9-台账库存管理

在订单处理过程中,台账库存的具体设计怎么做? 在订单处理过程中,分配仓库成功后要扣除仓库库存并计算商品缺货情况,仓库库存就是台账库存。 1,台账库存是针对某个仓库的库存,且台账库存只计算此商品SKU的库…

小马搬运物品-第13届蓝桥杯省赛Python真题精选

[导读]:超平老师的Scratch蓝桥杯真题解读系列在推出之后,受到了广大老师和家长的好评,非常感谢各位的认可和厚爱。作为回馈,超平老师计划推出《Python蓝桥杯真题解析100讲》,这是解读系列的第89讲。 小马搬运物品&…

C#/.NET量化开发实现财富自由【4】实现EMA、MACD技术指标的计算

听说大A又回到了2950点以下,对于量化交易来说,可能这些都不是事儿。例如,你可以预判到大A到顶了,你可能早就跑路了。判断逃顶还是抄底,最简单的方式就是判断是否顶背离还是底背离,例如通过MACD,…

华为5288 V5服务器安装BCLinux8U4手记

本文记录了华为5288 V5服务器安装BCLinux8U4操作系统的过程。 一、系统环境 1、服务器 华为FusionServer Pro 5288 V5服务器 2、操作系统 BCLinux-R8-U4-Server-x86_64-220725.iso 官网下载地址 sha256sum:1d31d3b8e02279e89965bd3bea61f14c65b9d32ad2ab6d4eb…

MySQL数据库基础练习系列——教务管理系统

项目名称与项目简介 教务管理系统是一个旨在帮助学校或教育机构管理教务活动的软件系统。它涵盖了学生信息管理、教师信息管理、课程管理、成绩管理以及相关的报表生成等功能。通过该系统,学校可以更加高效地处理教务数据,提升教学质量和管理水平。 1.…

uniapp获取证书秘钥、Android App备案获取公钥、签名MD5值

一、 uniapp获取证书秘钥 打开uniapp开发者中心下载证书打开cmd输入以下这段代码,下载提供查看到的密钥证书密码就可以了!下载证书在 java 环境下运行才可以 // your_alias 换成 证书详情中的别名,your_keystore.keystore 改成自己的证书文件…