联邦学习实验复现—MNISIT IID实验 pytorch

联邦学习论文复现🚀

        在精度的联邦学习的论文之后打算进一步开展写一个联邦学习的基础代码,用于开展之后的相关研究,首先就是复现一下论文中最基础也是最经典的MNIST IID(独立同分布划分) 数据集。然后由于这个联邦学习的论文是谷歌发的,所以官方的代码好像是Tensorflow的,然后为了方便后续的研究我就又自己写了一个pytroch版本的。

记得把代码中的路径都换成自己的


前置文章:联邦学习论文逐句精度:https://blog.csdn.net/chrnhao/article/details/1427517006


文章目录

  • 联邦学习论文复现🚀
  • 0.预处理流程&项目文件结构
  • 1.获取MNIST数据集
  • 2.处理测试集 Create_test_datasets.py
  • 3.划分客户端样本(训练集数据) Create_client_datasets.py
  • 4.构建多客户端Dataloader init_clients.py
  • 5.模型代码 CNN.py
  • 6.训练代码 train.py
    • 6.1 导入库部分
    • 6.2 联邦学习超参数初始化
    • 6.3 Dataloader 类
    • 6.4 固定随机种子
    • 6.5 初始化获取所有客户端的dataloader
    • 6.6 加载并处理测试集
    • 6.7 初始化中心服务器模型和损失函数并迁移到GPU上
    • 6.8 client_update 客户端更新函数✨
    • 6.9 train 中心服务器训练函数✨
    • 6.10 测试代码部分
    • 6.11 开始训练&绘图&保存训练结果
  • 7.训练结果
  • 8.结果对比 plot_compare_curve.py
  • 9.结束

0.预处理流程&项目文件结构


在这里插入图片描述


在这里插入图片描述

  • Client_datasets:保存所有客户端的处理好的非图像训练集数据
  • mnisit_test:存储测试集图片数据
  • mnisit_test:存储训练集图片数据
  • Test_dataset:保存测试集的处理好的非图像测试集数据
  • Train_result:保存每种超参数组合训练后得到的准确率曲线的结果
  • CNN.py:CNN模型文件
  • Create_client_datasets.py:将训练集划分为多个客户端的样本
  • Create_test_datasets.py:处理测试集图片,构建测试集数据
  • init_clients:工具代码,将所有客户端的数据都处理成dataloader
  • plot_compare_cureve.py:绘制结果比较曲线
  • train.py:训练代码

1.获取MNIST数据集


https://www.kaggle.com/datasets/hojjatk/mnist-dataset

在这里插入图片描述

  • train-images-idx3-ubyte.gz: training set images (9912422 bytes)
  • train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
  • t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
  • t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)

        MNIST数据集原始的数据是二级制格式的文件,需要通过代码将其转换成图片格式,之后才能进行后一步的处理,然后具体转换的代码,参考了这个Github代码,感谢博主。
https://www.kaggle.com/datasets/hojjatk/mnist-dataset
在这里插入图片描述

训练集和测试集都需要使用代码进行转换,训练集和测试集分别得到10类手写数字的10个文件夹,下面的图是GitHub博主的图。

在这里插入图片描述
我自己将转换出的代码分别存在了mnist_trainmnist_test中。
在这里插入图片描述

2.处理测试集 Create_test_datasets.py


将测试集的中的图片用opencv读取,归一化并打包成npy文件。
在这里插入图片描述

import os
import cv2
import numpy as np


MNIST_dir_test = r'C:\Users\Administrator\Desktop\Federated\mnist_test'
MNIST_test_list = os.listdir(MNIST_dir_test)

data_list = []
label_list = []

for label in MNIST_test_list:
    label_path = os.path.join(MNIST_dir_test, label)
    print(label, len(os.listdir(label_path)))
    for image in os.listdir(label_path):
        image_path = os.path.join(label_path, image)
        image = cv2.imread(image_path, 0)/255
        data_list.append([image])
        label_list.append(int(label))

np.save('Test_dataset\MNIST_test_data.npy', data_list)
np.save('Test_dataset\MNIST_test_label.npy', label_list)

3.划分客户端样本(训练集数据) Create_client_datasets.py


        首先在联邦学习中,原论文是将mnist的训练集一共60000张图片,划分到了100个客户端中,每个600张图片,这里有一个问题,就是0-9,在数据集中虽然总数是6000,但是每个类别的个数不是正好6000,但是划分IID数据集,又需要将每个客户端上的数据的分布是相同的,也就是类别数量是均匀的,为了应对这种情况,我采用了以下方案。

  • 第一步:首先读取每个客户端的数组图片并归一化,保存到一个列表中;
  • 第二步:用每个类别的列表的长度除以100,得到每个类别数量除以100得到的商(这里回去我翻了一下小学知识),(假设一类样本数量是5978,则除以100后商为59,则表明他可以均匀的给100个客户端每个客户端100个样本,会有一点剩余);
  • 第三步:获得每个类别列表长度的数量除以100得到的余数,这个余数就是每一个类别剩余的样本数量;
  • 第四步:首先把每个类别的的剩余样本拿出来留着之后补空;
  • 第五步:先将每个类别能均匀分配的样本分配到各个客户端中,然后将剩余样本再顺序填补到每个客户端的数据集中。

这里略微有一点点复杂,想进一步理解的话需要单步运行一下。

import os
import numpy as np
import cv2

# 创建100个客户端的文件夹
# for i in range(1, 101):
#     os.makedirs(os.path.join('Client_datasets', f'client_{i}'), exist_ok=True)

# 获取每个类别数据的图像数据
number_0 = [[cv2.imread(os.path.join("mnist_train/0", i), 0) / 255] for i in os.listdir("mnist_train/0")]
number_1 = [[cv2.imread(os.path.join("mnist_train/1", i), 0) / 255] for i in os.listdir("mnist_train/1")]
number_2 = [[cv2.imread(os.path.join("mnist_train/2", i), 0) / 255] for i in os.listdir("mnist_train/2")]
number_3 = [[cv2.imread(os.path.join("mnist_train/3", i), 0) / 255] for i in os.listdir("mnist_train/3")]
number_4 = [[cv2.imread(os.path.join("mnist_train/4", i), 0) / 255] for i in os.listdir("mnist_train/4")]
number_5 = [[cv2.imread(os.path.join("mnist_train/5", i), 0) / 255] for i in os.listdir("mnist_train/5")]
number_6 = [[cv2.imread(os.path.join("mnist_train/6", i), 0) / 255] for i in os.listdir("mnist_train/6")]
number_7 = [[cv2.imread(os.path.join("mnist_train/7", i), 0) / 255] for i in os.listdir("mnist_train/7")]
number_8 = [[cv2.imread(os.path.join("mnist_train/8", i), 0) / 255] for i in os.listdir("mnist_train/8")]
number_9 = [[cv2.imread(os.path.join("mnist_train/9", i), 0) / 255] for i in os.listdir("mnist_train/9")]

# 每个类别的样本总数除以100个客户端
first_round_number = [len(number_0) // 100, len(number_1) // 100, len(number_2) // 100, len(number_3) // 100,
                      len(number_4) // 100, len(number_5) // 100, len(number_6) // 100, len(number_7) // 100,
                      len(number_8) // 100, len(number_9) // 100]

# 每个类别剩余的样本数量
remain_number = [len(number_0) % 100, len(number_1) % 100,
                 len(number_2) % 100, len(number_3) % 100,
                 len(number_4) % 100, len(number_5) % 100,
                 len(number_6) % 100, len(number_7) % 100,
                 len(number_8) % 100, len(number_9) % 100]

# 获得所有客户端数据集构成的列表
number_list = [number_0, number_1, number_2, number_3, number_4, number_5, number_6, number_7, number_8, number_9]

# 处理剩余数据
remain_data = (number_0[-remain_number[0]:] + number_1[-remain_number[1]:] + number_2[-remain_number[2]:] +
               number_3[-remain_number[3]:] + number_4[-remain_number[4]:] + number_5[-remain_number[5]:] +
               number_6[-remain_number[6]:] + number_7[-remain_number[7]:] + number_8[-remain_number[8]:] +
               number_9[-remain_number[9]:])

# 剩余数据的标签
remain_label = ([0] * remain_number[0] + [1] * remain_number[1] + [2] * remain_number[2] + [3] * remain_number[3] +
                [4] * remain_number[4] + [5] * remain_number[5] + [6] * remain_number[6] + [7] * remain_number[7] +
                [8] * remain_number[8] + [9] * remain_number[9])

# 开始构建100个客户端的数据
for i in range(100):
    data = []
    label = []
    for index, j in enumerate(number_list):
        data += j[i*first_round_number[index]:(i+1)*first_round_number[index]]
        label += [index] * first_round_number[index]

    # 将剩余的数据再补充分配到每个数据集当中
    data += remain_data[i*len(remain_data)//100:(i+1)*len(remain_data)//100]
    label += remain_label[i*len(remain_label)//100:(i+1)*len(remain_label)//100]

    # 缓存每个数据集
    print(i+1,np.shape(data),np.shape(label))
    
    # 记得换成自己的路径
    np.save(rf"C:\Users\Administrator\Desktop\Federated\Client_datasets\client_{i+1}\data.npy", data, allow_pickle=True)
    np.save(rf"C:\Users\Administrator\Desktop\Federated\Client_datasets\client_{i+1}\label.npy", label, allow_pickle=True)

4.构建多客户端Dataloader init_clients.py

        在完成数据集的划分之后,想要在一个电脑上模拟使用多个客户端的数据进行训练,对于使用pytorch框架而言,我能想到的就生成n个dataloader,n就是客户端样本的数量,然后将生成的n个dataloader都存到一个列表里传给客户端,然后客户端如果选择到了用哪个客户端的数据,就用哪个客户端的数据来训练。所以就有了一下的pytorch的Dataloader的代码版本,基本逻辑呢就是,读取所有的客户端文件夹中的数据,然后都封装成pytorch的dataloder,这里正好就可以把联邦学习中的每个客户端本地训练的参数量B传进去,联邦学习里论文的B就是clients_dataloader中的batch_size参数。这个代码作为一个脚本在项目文件夹的目录中,我命名为了init_clients.py

import os
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

class CustomDataset(Dataset):
    def __init__(self, data):
        self.len = len(data)
        self.x_data = torch.from_numpy(np.array(list(map(lambda x: x[0], data)), dtype=np.float32))
        self.y_data = torch.from_numpy(np.array(list(map(lambda x: x[-1], data)))).squeeze().long()

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len


def clients_dataloader(batch_size=60):
    dataloader_list = []

    dir_path = os.listdir("Client_datasets")
    for dir in dir_path:
        data_path = os.path.join("Client_datasets", dir, "data.npy")
        label_path = os.path.join("Client_datasets", dir, "label.npy")

        data = np.load(data_path)
        label = np.load(label_path)

        dataset = [[i, j] for i, j in zip(data, label)]
        dataloader = DataLoader(CustomDataset(dataset), shuffle=True, batch_size=batch_size)
        dataloader_list.append(dataloader)

    return dataloader_list


if __name__ == '__main__':
    dataloaders = clients_dataloader()
    if dataloaders:
        print(f"Loaded {len(dataloaders)} client dataloaders.")
    else:
        print("No dataloaders loaded.")

在这里插入图片描述

5.模型代码 CNN.py


        根据联邦学习论文,我选了MNISIT数据集中的CNN模型进行复现,也是非常简单,两个卷积层,然后跟着两个最大池化层,最后接一个线性层,然后再加个ReLU。

import torch

class CNN(torch.nn.Module):
    def __init__(self, in_channels=1, classes=10):
        super(CNN, self).__init__()
        self.conv1 = torch.nn.Conv2d(in_channels, 32, kernel_size=5)
        self.max_pool1 = torch.nn.MaxPool2d(2, 2)

        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=5)
        self.max_pool2 = torch.nn.MaxPool2d(2, 2)
        self.linear = torch.nn.Linear(64 * 4 * 4, classes)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.conv1(x)
        x = self.max_pool1(x)

        x = self.conv2(x)
        x = self.max_pool2(x)

        x = x.view(x.size(0), -1)
        x = self.linear(x)
        x = self.relu(x)
        return x


if __name__ == '__main__':
    model = CNN(3, 10)
    print(model)
    x = torch.randn((1, 3, 28, 28))
    y = model(x)
    print(y.shape)
    print(y)

6.训练代码 train.py


训练代码的内容确实是先对来说难一些,我这里先给完整代码,然后我再逐个部分进行解释。

import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from init_clients import clients_dataloader
import random
from CNN import CNN
# 解决出现libiomp5md.dll缺导致无法绘图的错误
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 

C = 0.1  #参与训练的客户端的比例
E = 5     #每个客户端本地训练的轮数
B = 600    #每个客户的BatchSize大小 B>=600 等效于 论文中B=∞

client_num = 100

# Test_dataset class
class Dataset(Dataset):
    def __init__(self, data):
        self.len = len(data)
        self.x_data = torch.from_numpy(np.array(list(map(lambda x: x[0], data)), dtype=np.float32))
        self.y_data = torch.from_numpy(np.array(list(map(lambda x: x[-1], data)))).squeeze().long()

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

# 设置随机种子
def set_random_seed(seed_value=100):
    random.seed(seed_value)         # Fixed Python built-in random generator
    np.random.seed(seed_value)      # Fixed NumPy random generator
    torch.manual_seed(seed_value)   # Fixed PyTorch random generator

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)  # If multiple GPUs are used

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_random_seed(100)

# 初始化客户端dataloader 用100个客户端的数据生成Dataloader
client_dataloader_list = clients_dataloader(B)

# 加载测试数据集
data_test = np.load('Test_dataset/MNIST_test_data.npy')
label_test = np.load('Test_dataset/MNIST_test_label.npy')
test_dataset = [[i, j] for i, j in zip(data_test, label_test)]

# 将测试集数据处理成Dataloder
Test_dataset = Dataset(test_dataset)
testloader = DataLoader(Test_dataset, shuffle=True, batch_size=256)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 初始化模型和损失函数,并将模型和损失函数移到GPU上
model = CNN(in_channels=1, classes=10)
model.to(device)
criterion = torch.nn.CrossEntropyLoss()
criterion.to(device)

def client_update(client_num, E, model_parameter):
    '''
    客户端更新参数
    :param client_num: 选择到的客户端的编号
    :param E: 本地训练的epoch轮数
    :param model_parameter: 中心服务器发给客户端的模型参数
    :return: 在该客户端上训练好的模型参数
    '''

    dataloader = client_dataloader_list[client_num]     # 获取选择到客户端的dataloder
    client_model = CNN().to(device)                     # 加载一个空模型
    client_model.load_state_dict(model_parameter)       # 加载中心服务器发的模型参数
    client_model.train()                                # 将模型变为训练模式

    # optimizer = torch.optim.SGD(client_model.parameters(), lr=0.01, momentum=0.9)
    optimizer = torch.optim.SGD(client_model.parameters(), lr=0.01)

    for i in range(E):
        correct = 0
        total = 0

        for data, label in dataloader:
            train_data_value, train_data_label = data.to(device), label.to(device)
            train_data_label_pred = client_model(train_data_value)

            loss = criterion(train_data_label_pred, train_data_label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(train_data_label_pred, 1)
            total += train_data_label.size(0)
            correct += (predicted == train_data_label).sum().item()

        accuracy = 100 * correct / total
        print(f'Client:{client_num},Epoch {i+1}/{E}, Accuracy: {accuracy:.2f}%, Loss: {loss.item()}')

    # 返回客户端训练模型的模型参数
    return client_model.state_dict()

def train(client_num, C, E):
    model.train()                               # 将客户端模型变为训练模型
    send_model_parameter = model.state_dict()   # 然后获取即将分发给各个客户端的模型的权重

    random_numbers = random.sample(range(client_num), int(client_num*C))        # 按照比例随机选择出用于本轮训练的模型的索引编号
    client_return_parameter_list = []                                           # 初始化一个列表用与保存每个客户端本地训练好之后返回的模型参数

    for client in random_numbers:                                               # 遍历所有的选择到的客户端的编号(索引号)
        model_parameter = client_update(client, E=E, model_parameter=send_model_parameter)      # 返回每个客户端训练好的模型权重
        client_return_parameter_list.append(model_parameter)                                    # 将每一个客户端的模型权重加载到列表中

    # 先生成一个参数都为0的模型参数的参数字典,用于之后将客户端返回的模型参数都加到改字典上
    aggregated_model_parameter = {key: torch.zeros_like(value, dtype=torch.float32) for key, value in send_model_parameter.items()}

    # 将所有客户端模型的权重都加权求和
    for client_param in client_return_parameter_list:
        for key in aggregated_model_parameter:
            aggregated_model_parameter[key] += client_param[key] * (1 / int(client_num*C))

    # 将求和好的权重加载给中心服务器模型,用于下一轮的发送
    model.load_state_dict(aggregated_model_parameter)
    return model

def test():
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for testdata in testloader:
            test_data_value, test_data_label = testdata
            test_data_value, test_data_label = test_data_value.to(device), test_data_label.to(device)
            test_data_label_pred = model(test_data_value)
            _, test_predicted = torch.max(test_data_label_pred.data, dim=1)
            test_total += test_data_label.size(0)
            test_correct += (test_predicted == test_data_label).sum().item()
    test_acc = round(100 * test_correct / test_total, 3)
    print(f'Test Accuracy: {test_acc:.2f}%')
    return test_acc

if __name__ == '__main__':
    test_accuracies = []
    epochs = 1000
    for i in range(epochs):
        print(f"Epoch:{i}——",end='')
        train(client_num, C, E)
        test_acc = test()
        test_accuracies.append(test_acc)

    # Plotting the test accuracy curve
    plt.plot(range(1, epochs + 1), test_accuracies, marker='o', linestyle='-', color='b')
    plt.xlabel('Epoch')
    plt.ylabel('Test Accuracy (%)')
    plt.title(f'Test Accuracy vs. Epoch (C={C}, E={E}, B={B})')
    plt.grid(True)
    plt.show()

    np.save(rf'Train_result/C{C}B{B}E{E}.npy', test_accuracies)


6.1 导入库部分

        除了其他常规部分,这里值得提的是from init_clients import clients_dataloader这句是需要导入我们自己写的代码,from CNN import CNN这部分是导入自己的模型。

import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from init_clients import clients_dataloader
import random
from CNN import CNN
# 解决出现libiomp5md.dll缺导致无法绘图的错误
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 

6.2 联邦学习超参数初始化

        这B,C,E三个参数对应的就是联邦学习里最重要的三个超参数。然后这里需要多说的是,这里为什么要初始化客户端数量,因为联邦学习在整个客户端的模型权重部分每个客户端提供的模型参数需要乘以该客户端的数据数量除以全部客户端所持有的数据数量的比例的一个权重,而在该任务中,由于每个客户端的所持有的数据量都一致,所以乘以的权重数量为(1 / int(client_num*C))也就是选了少个客户端,权重就是多少个客户端分之1。

C = 0.1  #参与训练的客户端的比例
E = 5     #每个客户端本地训练的轮数
B = 600    #每个客户的BatchSize大小 B>=600 等效于 论文中B=∞
client_num = 100

6.3 Dataloader 类

这个是给测试集数据生成dataloader的一个Dataset类,和之前划分客户端的时候用的类一样,这个属于pytorch基础部分。

# Test_dataset class
class Dataset(Dataset):
    def __init__(self, data):
        self.len = len(data)
        self.x_data = torch.from_numpy(np.array(list(map(lambda x: x[0], data)), dtype=np.float32))
        self.y_data = torch.from_numpy(np.array(list(map(lambda x: x[-1], data)))).squeeze().long()

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

6.4 固定随机种子

在联邦学习论文中有专门的论述,使用相同的随机种子初始化的模型使用不同的客户端数据集进行训练,然后加和得到的新模型会比单独两个模型对与测试集的效果更好。

# 设置随机种子
def set_random_seed(seed_value=100):
    random.seed(seed_value)         # Fixed Python built-in random generator
    np.random.seed(seed_value)      # Fixed NumPy random generator
    torch.manual_seed(seed_value)   # Fixed PyTorch random generator

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)  # If multiple GPUs are used

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_random_seed(100)

6.5 初始化获取所有客户端的dataloader

这个就是我们之前写的代码,B就是联邦学习论文中的每个客户端的使用的本地的mini_batch的大小,这个代码返回的client_dataloader_list,是包含所有客户端的数据处理成的dataloader的列表。

# 初始化客户端dataloader 用100个客户端的数据生成Dataloader
client_dataloader_list = clients_dataloader(B)

6.6 加载并处理测试集

这步就是将测试集的数据和标签打包到一起,然后输入Dataset然后实例化出一个Test_dataset,然后再生成一个dataloader,这里dataloader的shuffle,和batch_size,都不是很重要,都不会影响最终的结果,正常填一差不多合适的值就行。

# 加载测试数据集
data_test = np.load('Test_dataset/MNIST_test_data.npy')
label_test = np.load('Test_dataset/MNIST_test_label.npy')
test_dataset = [[i, j] for i, j in zip(data_test, label_test)]

# 将测试集数据处理成Dataloder
Test_dataset = Dataset(test_dataset)
testloader = DataLoader(Test_dataset, shuffle=True, batch_size=256)

6.7 初始化中心服务器模型和损失函数并迁移到GPU上

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = CNN(in_channels=1, classes=10)
model.to(device)
criterion = torch.nn.CrossEntropyLoss()
criterion.to(device)

6.8 client_update 客户端更新函数✨

这个逻辑说起来比较复杂,我用GPT帮我解释下,然后我再微调一下,具体的解释在代码的下面。

def client_update(client_num, E, model_parameter):
    '''
    客户端更新参数
    :param client_num: 选择到的客户端的编号
    :param E: 本地训练的epoch轮数
    :param model_parameter: 中心服务器发给客户端的模型参数
    :return: 在该客户端上训练好的模型参数
    '''

    dataloader = client_dataloader_list[client_num]     # 获取选择到客户端的dataloder
    client_model = CNN().to(device)                     # 加载一个空模型
    client_model.load_state_dict(model_parameter)       # 加载中心服务器发的模型参数
    client_model.train()                                # 将模型变为训练模式

    # optimizer = torch.optim.SGD(client_model.parameters(), lr=0.01, momentum=0.9)
    optimizer = torch.optim.SGD(client_model.parameters(), lr=0.01)

    for i in range(E):
        correct = 0
        total = 0

        for data, label in dataloader:
            train_data_value, train_data_label = data.to(device), label.to(device)
            train_data_label_pred = client_model(train_data_value)

            loss = criterion(train_data_label_pred, train_data_label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(train_data_label_pred, 1)
            total += train_data_label.size(0)
            correct += (predicted == train_data_label).sum().item()

        accuracy = 100 * correct / total
        print(f'Client:{client_num},Epoch {i+1}/{E}, Accuracy: {accuracy:.2f}%, Loss: {loss.item()}')

    # 返回客户端训练模型的模型参数
    return client_model.state_dict()

        这段代码是一个客户端在联邦学习过程中更新模型参数的逻辑实现。联邦学习是一种分布式的机器学习方法,其中多个客户端独立地在本地数据上训练模型,然后将更新后的模型参数发送到中心服务器进行聚合,从而保护数据隐私。让我们逐行详细解释这段代码。

# 客户端更新参数函数
def client_update(client_num, E, model_parameter):
    '''
    客户端更新参数
    :param client_num: 选择到的客户端的编号
    :param E: 本地训练的epoch轮数
    :param model_parameter: 中心服务器发给客户端的模型参数
    :return: 在该客户端上训练好的模型参数
    '''

这段代码定义了一个函数 client_update,它用于更新客户端模型的参数。函数的参数说明如下:

  • client_num: 选择到的客户端的编号,即具体是哪一个客户端。
  • E: 本地训练的 epoch 轮数,即在每个客户端上训练多少次。
  • model_parameter: 中心服务器发送给客户端的初始模型参数。

函数的目的是在该客户端上使用本地数据训练模型,然后返回训练后的模型参数。

    dataloader = client_dataloader_list[client_num]     # 获取选择到客户端的dataloder

client_dataloader_list 中获取特定客户端的 dataloaderclient_dataloader_list 是一个列表,其中存储了每个客户端的数据加载器。dataloader 用于提供该客户端的训练数据。

    client_model = CNN().to(device)                     # 加载一个空模型

创建一个新的空模型,使用名为 CNN 的神经网络结构(这里假设 CNN 是一个定义好的卷积神经网络)。然后将模型转移到指定的计算设备(device),可能是 GPU 或 CPU 上。

    client_model.load_state_dict(model_parameter)       # 加载中心服务器发的模型参数

将中心服务器发送过来的 model_parameter 加载到模型中。这样,客户端的模型从中心服务器的初始模型参数开始训练。

    client_model.train()                                # 将模型变为训练模式

将模型设置为训练模式。这在 PyTorch 中是必要的,因为模型在训练和评估模式下的行为有所不同(例如,批归一化和 dropout 层的行为)。

    # optimizer = torch.optim.SGD(client_model.parameters(), lr=0.01, momentum=0.9)
    optimizer = torch.optim.SGD(client_model.parameters(), lr=0.01)

这里定义了一个优化器,用于更新模型参数。使用的是随机梯度下降(SGD)优化器,学习率 lr0.01。代码注释掉了包含 momentum 参数的版本,momentum 可以加速收敛,但在此处没有使用。

    for i in range(E):
        correct = 0
        total = 0

开始一个循环,循环次数为 E,即 epoch 的数量。每次 epoch 代表对整个训练数据集进行一次完整的训练。初始化变量 correcttotal,用于统计模型的预测准确率。

        for data, label in dataloader:
            train_data_value, train_data_label = data.to(device), label.to(device)
            train_data_label_pred = client_model(train_data_value)

遍历该客户端的 dataloader,逐批次地获取数据和标签。将数据和标签移到指定设备(GPU 或 CPU)上,然后用模型对输入数据 train_data_value 进行预测,得到 train_data_label_pred

            loss = criterion(train_data_label_pred, train_data_label)

计算损失,criterion 是损失函数,用于评估模型预测与真实标签之间的差异。

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

这是标准的反向传播与梯度更新步骤:

  • optimizer.zero_grad():在每次反向传播之前,将所有参数的梯度置为零,避免梯度累积。
  • loss.backward():计算损失相对于模型参数的梯度。
  • optimizer.step():根据计算出的梯度更新模型参数。
            _, predicted = torch.max(train_data_label_pred, 1)
            total += train_data_label.size(0)
            correct += (predicted == train_data_label).sum().item()

这里进行预测的评估:

  • torch.max(train_data_label_pred, 1):找出每个样本的最大值的索引,即模型预测的类别。
  • total 用于累计总的样本数量。
  • correct 用于累计正确预测的数量,即预测值与真实标签相同的样本数量。
        accuracy = 100 * correct / total
        print(f'Client:{client_num},Epoch {i+1}/{E}, Accuracy: {accuracy:.2f}%, Loss: {loss.item()}')

计算本地训练的准确率,并输出每个 epoch 的信息,包括客户端编号、当前 epoch、准确率和损失值。

    # 返回客户端训练模型的模型参数
    return client_model.state_dict()

函数返回更新后的模型参数,即客户端在本地数据上训练完成后的模型参数(以 state_dict 的形式返回)。

  • 该函数用于实现联邦学习中客户端的本地模型训练过程。
  • 中心服务器发送模型参数到客户端,客户端用这些参数作为初始模型进行本地数据训练。
  • 客户端训练完成后,将更新后的模型参数返回给服务器,以便进一步聚合

6.9 train 中心服务器训练函数✨

中心服务器的代码是整个联邦学习代码核心的核心,也是我自己调试了最长时间的部分,整体也是让GPT给解释下。

def train(client_num, C, E):
    model.train()                               # 将客户端模型变为训练模型
    send_model_parameter = model.state_dict()   # 然后获取即将分发给各个客户端的模型的权重

    random_numbers = random.sample(range(client_num), int(client_num*C))        # 按照比例随机选择出用于本轮训练的模型的索引编号
    client_return_parameter_list = []                                           # 初始化一个列表用与保存每个客户端本地训练好之后返回的模型参数

    for client in random_numbers:                                               # 遍历所有的选择到的客户端的编号(索引号)
        model_parameter = client_update(client, E=E, model_parameter=send_model_parameter)      # 返回每个客户端训练好的模型权重
        client_return_parameter_list.append(model_parameter)                                    # 将每一个客户端的模型权重加载到列表中

    # 先生成一个参数都为0的模型参数的参数字典,用于之后将客户端返回的模型参数都加到改字典上
    aggregated_model_parameter = {key: torch.zeros_like(value, dtype=torch.float32) for key, value in send_model_parameter.items()}

    # 将所有客户端模型的权重都加权求和
    for client_param in client_return_parameter_list:
        for key in aggregated_model_parameter:
            aggregated_model_parameter[key] += client_param[key] * (1 / int(client_num*C))

    # 将求和好的权重加载给中心服务器模型,用于下一轮的发送
    model.load_state_dict(aggregated_model_parameter)
    return model

这段代码是中心服务器在联邦学习过程中分发模型参数并聚合客户端返回的模型参数的实现。联邦学习的目标是让多个客户端使用本地数据独立训练模型,然后服务器聚合客户端的更新,从而提升整体模型性能并保持数据隐私。下面我们逐行详细解释这段代码。

# 中心服务器训练模型的函数
def train(client_num, C, E):
    model.train()                               # 将客户端模型变为训练模型
    send_model_parameter = model.state_dict()   # 然后获取即将分发给各个客户端的模型的权重

该函数名为 train,用于执行联邦学习中的模型聚合过程。

  • client_num: 表示总客户端数量。
  • C: 表示客户端参与比例,决定每轮训练中选择的客户端数。
  • E: 本地训练的 epoch 轮数。

首先,将中心服务器模型设为训练模式 (model.train()),然后通过 model.state_dict() 获取中心服务器模型的当前参数,这些参数将被分发给客户端。

    random_numbers = random.sample(range(client_num), int(client_num*C))        # 按照比例随机选择出用于本轮训练的模型的索引编号
    client_return_parameter_list = []                                           # 初始化一个列表用于保存每个客户端本地训练好之后返回的模型参数
  • random_numbers 使用 random.sample() 随机选择一定数量的客户端(client_num * C)来参与本轮的训练。client_num * C 是根据参与比例选择的客户端数量。
  • client_return_parameter_list 初始化为空列表,用于保存每个客户端本地训练后返回的模型参数。
    for client in random_numbers:                                               # 遍历所有的选择到的客户端的编号(索引号)
        model_parameter = client_update(client, E=E, model_parameter=send_model_parameter)      # 返回每个客户端训练好的模型权重
        client_return_parameter_list.append(model_parameter)                                    # 将每一个客户端的模型权重加载到列表中
  • 对于每个被选择的客户端,调用 client_update 函数来进行本地训练,传入客户端编号、epoch 数量以及要发送的模型参数。
  • client_update 返回训练后的模型参数,将这些参数添加到 client_return_parameter_list 中。
    # 先生成一个参数都为0的模型参数的参数字典,用于之后将客户端返回的模型参数都加到该字典上
    aggregated_model_parameter = {key: torch.zeros_like(value, dtype=torch.float32) for key, value in send_model_parameter.items()}
  • 初始化一个 aggregated_model_parameter 字典,包含所有模型参数的键,且每个键的值都初始化为与对应原始模型参数相同形状的零张量。这个字典将用于累加各客户端返回的模型参数。
    # 将所有客户端模型的权重都加权求和
    for client_param in client_return_parameter_list:
        for key in aggregated_model_parameter:
            aggregated_model_parameter[key] += client_param[key] * (1 / int(client_num*C))
  • 遍历 client_return_parameter_list,对每个客户端返回的模型参数进行聚合。
  • 对于每个参数键,将所有客户端的相应参数值按比例累加,比例为每个客户端权重的平均值(1 / (client_num * C))。
    # 将求和好的权重加载给中心服务器模型,用于下一轮的发送
    model.load_state_dict(aggregated_model_parameter)
    return model
  • 将聚合后的模型参数加载到中心服务器的模型中,为下一轮联邦学习做准备。

  • 最后返回更新后的中心服务器模型。

  • 该函数用于联邦学习中的中心服务器模型训练,通过分发初始模型参数、收集客户端的更新并聚合这些更新来提升全局模型的性能。

  • 在每一轮中,随机选择一部分客户端对其本地数据进行训练,然后聚合各客户端返回的模型参数。

  • 聚合后,更新中心服务器的模型参数以用于下一轮联邦学习。

6.10 测试代码部分

这个就是基础的计算测试集准确率的代码部分,也是让GPT解释一下。

def test():
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for testdata in testloader:
            test_data_value, test_data_label = testdata
            test_data_value, test_data_label = test_data_value.to(device), test_data_label.to(device)
            test_data_label_pred = model(test_data_value)
            _, test_predicted = torch.max(test_data_label_pred.data, dim=1)
            test_total += test_data_label.size(0)
            test_correct += (test_predicted == test_data_label).sum().item()
    test_acc = round(100 * test_correct / test_total, 3)
    print(f'Test Accuracy: {test_acc:.2f}%')
    return test_acc
  • 模型设为评估模式
    model.eval()                               # 将模型设为评估模式

model.eval() 用于将模型设置为评估模式。这在 PyTorch 中很重要,因为评估模式会影响像 dropout 和 batch normalization 这样的层,使其在推理阶段使用训练时的参数而不是随机性。

    test_correct = 0                           # 初始化正确预测数量
    test_total = 0                             # 初始化总测试样本数量

初始化两个变量 test_correcttest_total,分别用于记录正确预测的样本数量和测试样本的总数量。

    with torch.no_grad():                      # 在测试时不需要计算梯度,提升效率

with torch.no_grad() 用于禁用梯度计算,以减少内存消耗和加快推理速度,因为在测试和推理阶段不需要反向传播。

        for testdata in testloader:            # 遍历所有测试数据
            test_data_value, test_data_label = testdata
            test_data_value, test_data_label = test_data_value.to(device), test_data_label.to(device)

遍历测试数据集 testloader,获取每个批次的测试数据和对应标签。然后将这些数据和标签移到计算设备(如 GPU)上。

            test_data_label_pred = model(test_data_value)   # 使用模型对测试数据进行预测
            _, test_predicted = torch.max(test_data_label_pred.data, dim=1)  # 获取预测值的最大概率索引

使用模型对测试数据进行预测,得到每个样本的预测结果 test_data_label_pred。使用 torch.max() 找到预测结果中每个样本的最大概率的索引,即模型预测的类别。

            test_total += test_data_label.size(0)          # 累加测试样本的数量
            test_correct += (test_predicted == test_data_label).sum().item()  # 统计正确预测的数量
  • test_total 累加当前批次的测试样本数量。
  • test_correct 通过比较预测值和真实标签,统计正确预测的样本数量。
    test_acc = round(100 * test_correct / test_total, 3)  # 计算测试集的准确率
    print(f'Test Accuracy: {test_acc:.2f}%')              # 打印测试集的准确率
    return test_acc                                       # 返回测试集的准确率

计算测试集的准确率,并将其四舍五入到小数点后三位。最后打印并返回测试准确率。

  • 该函数用于联邦学习中的中心服务器对聚合后的模型进行测试,以评估其性能。
  • 测试过程中,模型被设置为评估模式,禁用梯度计算以提高效率。
  • 函数通过遍历整个测试数据集来统计预测的正确性,并计算最终的准确率。

6.11 开始训练&绘图&保存训练结果

这里由于我们的随机种子的固定的所以我们是按照不同的超参数配置来命名保存训练文件的。

if __name__ == '__main__':
    test_accuracies = []
    epochs = 1000
    for i in range(epochs):
        train(client_num, C, E)
        test_acc = test()
        print(f"Epoch:{i}——", end='')
        test_accuracies.append(test_acc)

    # Plotting the test accuracy curve
    plt.plot(range(1, epochs + 1), test_accuracies, marker='o', linestyle='-', color='b')
    plt.xlabel('Epoch')
    plt.ylabel('Test Accuracy (%)')
    plt.title(f'Test Accuracy vs. Epoch (C={C}, E={E}, B={B})')
    plt.grid(True)
    plt.show()

    np.save(rf'Train_result/C{C}B{B}E{E}.npy', test_accuracies)

在这里插入图片描述

7.训练结果


训练之后的结果如下:可以把client_updat 中的打印每个客户端每轮训练准确率这块注释掉,不然一致刷刷打印,看不清还占用不少时间。

在这里插入图片描述

8.结果对比 plot_compare_curve.py


import matplotlib.pyplot as plt
import numpy as np

line1 = np.load("Train_result/C0.1B10E5.npy")
line2 = np.load("Train_result/C0.1B50E5.npy")
line3 = np.load("Train_result/C0.1B100E5.npy")
line4 = np.load("Train_result/C0.1B600E1.npy")
line5 = np.load("Train_result/C0.1B600E5.npy")
line6 = np.load("Train_result/C0.2B600E5.npy")
line7 = np.load("Train_result/C1B600E5.npy")


plt.plot(line1, linewidth=3.0,label='C=0.1,B=10,E=5')
plt.plot(line2, linewidth=3.0,label='C=0.1,B=50,E=5')
plt.plot(line3, linewidth=3.0,label='C=0.1,B100,E=5')
plt.plot(line4, linewidth=3.0,label='C=0.1,B=∞,E=1')
plt.plot(line5, linewidth=3.0,label='C=0.1,B=∞,E=5')
plt.plot(line6, linewidth=3.0,label='C=0.2,B=∞,E=5')
plt.plot(line7, linewidth=3.0,label='C=1,B=∞,E=5')

plt.grid()
plt.legend()

plt.xlabel("Tran epoch")
plt.ylabel("Accuracy")
plt.title("Federated Learning")

plt.show()

在这里插入图片描述
在这里插入图片描述

9.结束


这个代码从精度联邦学习论文到复现也是花了不少时间,然后如果不想复制粘贴了手懒了的话,可以到CSDN推广的公众号浩浩的科研笔记中buy,如果是学弟学妹的话,给我私信我也会发你一份。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/894943.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

html基础小练习

需求&#xff1a;实现如上图网页 <!DOCTYPE html> <html><head><meta charset"utf-8"><title>注册界面</title></head><body><div><form action""><table width"400"><t…

Java21虚拟线程:我的锁去哪儿了?

0 前言 最近的文章中&#xff0c;我们详细介绍了当我们迁移到 Java 21 并将代际 ZGC 作为默认垃圾收集器时&#xff0c;我们的工作负载是如何受益的。虚拟线程是我们在这次迁移中兴奋采用的另一个特性。 对虚拟线程新手&#xff0c;它们被描述为“轻量级线程&#xff0c;大大…

推荐10 个令人惊叹的 Python 自动化脚本!

/01/ 剪贴板管理器 你是否曾发现自己忙于处理多个文本片段&#xff0c;而忘记了自己复制了什么&#xff1f;有没有想过有一个工具可以记录你一天中复制的所有内容&#xff1f; 这个自动化脚本可以监控你复制的所有内容&#xff0c;将复制的每个文本无缝地存储在一个时尚的图形…

C++ 在项目中使用GDB

一&#xff1a;GDB 的 TUI 模式使用 GDB的 TUI &#xff08;Text User Interface&#xff09;模式提供了一种图形化的调试体验&#xff0c;允许在终端中同时显示源代码&#xff0c;寄存器和汇编代码等信息&#xff0c;下面是GDB TUI的基本操作和快捷键 1. 显示源代码窗口&…

【C++】类的默认成员函数:深入剖析与应用(上)

&#x1f600;在上一篇文章中我们初步了解了C的基础概念&#xff0c;现在我们进行对C类的默认成员函数进行更加深入的理解&#xff01; &#x1f449;【C新手入门指南&#xff1a;从基础概念到实践之路】 目录 &#x1f4af;前言 &#x1f4af;构造函数 一、构造函数的定义…

常见TCP/IP协议基础——计算机网络

目录 前言常见协议基础常见协议-基于TCP的应用层协议常见协议-基于UDP的应用层协议常见协议-网络层协议习题自测1.邮件发送协议2.接收邮件协议端口3.建立连接4.层次对应关系5.FTP服务器端口 前言 本笔记为备考软件设计师时的重点知识点笔记&#xff0c;关于常见TCP/IP协议基础…

【飞腾加固服务器】全国产化解决方案:飞腾FT2000+/64核,赋能关键任务保驾护航

在信息安全和自主可控的时代背景下&#xff0c;国产化设备的需求与日俱增&#xff0c;尤其是在国防、航空航天、能源和其他关键行业。高可靠性和极端环境设计的国产加固服务器&#xff0c;搭载强大的飞腾FT2000/64核处理器&#xff0c;全面满足国产自主可控的严苛要求。 性能强…

Python案例小练习——小计算器

文章目录 前言一、代码展示二、运行展示 前言 这是用python实现一个简单的计器。 一、代码展示 def calculate(num1, op, num2):if op "":return float(num1) float(num2)elif op "-":return float(num1) - float(num2)elif op "*":return…

stable diffusion安装ai绘画真人动漫win中文版软件

前言 所有的AI设计工具&#xff0c;安装包、模型和插件&#xff0c;都已经整理好了&#xff0c;&#x1f447;获取~ Stable Diffusion&#xff08;简称SD&#xff09;&#xff0c;是通过数学算法实现文本输入&#xff0c;图像输出的开源软件&#xff01; 引用维基百科&#x…

expect工具

一.expect工具介绍 在写脚本的过程当中不可避免的需要去写交互式命令 那么如何让交互式命令在脚本中自动执行&#xff1f; 使用expect工具 作用&#xff1a;捕获交互式的输出&#xff0c;自动执行交互式命令 如上图所示&#xff0c;可以使用expect工具去捕获交互式命令的提…

什么是大数据分析:定义、优缺点、应用、机遇和风险

大数据分析的概念已经成为我们社会不可或缺的一部分。众多公司和机构已经开发了大数据应用程序&#xff0c;取得了不同程度的成功。社交媒体平台和传感器等技术正在以前所未有的速度生成数据&#xff0c;就像一条装配线。如今&#xff0c;几乎所有东西都是物联网的一部分&#…

[Xshell] Xshell的下载安装使用及连接linux过程 详解(附下载链接)

前言 Xshell.zip 链接&#xff1a;https://pan.quark.cn/s/5d9d1836fafc 提取码&#xff1a;SPn7 安装 下载后解压得到文件 安装路径不要有中文 打开文件 注意&#xff01;360等软件会拦截创建注册表的行为&#xff0c;需要全部允许、同意。或者退出360以后再安装。 在“绿化…

vscode pylance怎么识别通过sys.path.append引入的库

问题 假如我有一个Python项目 - root_path -- moduleA ---- fileA.py -- moduleB ---- fileB.py# fileAimport sys sys.path.append(moduleB)import fileB # vscode pylance找不到&#xff0c;因为sys.path.append(moduleB)是动态添加的print(fileB)结果 代码正常运行但是vs…

【北京迅为】《STM32MP157开发板嵌入式开发指南》- 第五十四章 Pinctrl 子系统和 GPIO 子系统

iTOP-STM32MP157开发板采用ST推出的双核cortex-A7单核cortex-M4异构处理器&#xff0c;既可用Linux、又可以用于STM32单片机开发。开发板采用核心板底板结构&#xff0c;主频650M、1G内存、8G存储&#xff0c;核心板采用工业级板对板连接器&#xff0c;高可靠&#xff0c;牢固耐…

基于百度智能体开发爱情三十六计

基于百度智能体开发爱情三十六计 文章目录 基于百度智能体开发爱情三十六计1. 爱情三十六计智能体2. 三十六计开发创意3. 智能体开发实践3.1 基础配置3.2 进阶配置3.3 调优心得3.4可能会遇到的问题 4. 为什么选择文心智能体平台 1. 爱情三十六计智能体 爱情三十六计 是一款基于…

Kaggle竞赛——森林覆盖类型分类

目录 1. 竞赛简要2. 数据分析2.1 特征类型统计2.2 四个荒野区域数据分析2.3 连续特征分析2.4 离散特征分析2.5 特征相关性热图2.6 特征间的散点关系图 3. 特征工程3.1 特征组合3.2 连续特征标准化 4. 模型搭建4.1 模型定义4.2 绘制混淆矩阵和ROC曲线4.3 模型对比与选择 5. 测试…

从0-1实战演练后台管理系统 (3)还在寻找优秀的后台管理系统?Pure Admin 源码及目录结构带你一探究竟!

一、获取源码: 从-gitee-上拉取从 Gitee 上拉取 1、完整版前端代码 git clone https://gitee.com/yiming_chang/vue-pure-admin.git2、国际化精简版前端代码 git clone -b i18n https://gitee.com/yiming_chang/pure-admin-thin.git3、非国际化精简版前端代码 git clone ht…

【Vue】Vue扫盲(七)如何使用Vue脚手架进行模块化开发及遇到的问题(cmd中无法识别vue命令、vue init webpack 命令执行失败)

上篇文章&#xff1a; Vue】Vue扫盲&#xff08;六&#xff09;关于 Vue 项目运行以及文件关系和关联的详细介绍 文章目录 一、安装 相关工具二、处理相关问题问题一&#xff1a;vue -v 提示 vue不是内部或外部命令&#xff0c;也不是可运行的程序或批处理文件。问题二&#xf…

wifi、热点密码破解 - python

乐子脚本&#xff0c;有点小慢&#xff0c;试过多线程&#xff0c;系统 wifi 连接太慢了&#xff0c;需要时间确认&#xff0c;多线程的话系统根本反应不过来。 也就可以试试破解别人的热点&#xff0c;一般都是 123456 这样的傻鸟口令 # coding:utf-8 import pywifi from pyw…

el-table修改指定列字体颜色 ,覆盖划过行的高亮显示文字颜色

修改指定列字体颜色 ,覆盖划过行的高亮显示文字颜色 代码如下&#xff1a; <div class"c1"><el-table:data"tableData"striperow-class-name"custom-table-row"style"width:100%"cell-mouse-enter"lightFn"cell-…