文章目录
- 前言
- 图片效果:
- 独立同分布效果
- 非独立同分布效果
- 一、参数
- 输入
- 输出
- 二、代码
- 可视化:
- 标签划分:
- 代码调用
前言
用于实现并控制联邦学习客户端之间数据集非独立同分布,并将效果可视化
图片效果:
独立同分布效果
- 对不同类别的分配效果可视化:
- 对不同客户端拥有的数据集的可视化:
非独立同分布效果
- 对不同类别的分配效果可视化:
- 对不同客户端拥有的数据集的可视化:
一、参数
输入
- classes:标签名称,列表类型
-示例:[‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’] - train_labels:数据集标签,列表类型
-示例:[6 9 9 … 9 1 1] - alpha:浓度参数,浮点数数值,(0,+∞)
- client_number:客户端数量,整型数值
输出
client_idcs:各客户端拥有的数据图片下标,列表类型
示例:
client_idcs=[array([ 29, 30, 35, ..., 9676, 9683, 9701]),
array([ 9171, 9181, 9193, ..., 20167, 20172, 20176]),
array([18920, 18925, 18935, ..., 29604, 29609, 29628]),
array([28887, 28897, 28912, ..., 38602, 38621, 38644]),
array([39601, 39606, 39619, ..., 49963, 49971, 49997])]
二、代码
可视化:
def draw_dataset(classes,labels,client_idcs,num_users):
#设置图片保存位置
# 构建save文件夹的路径
# 获取当前文件的父目录
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
save_dir = os.path.join(parent_dir, 'save/img')
# 如果save文件夹不存在,则创建它
if not os.path.exists(save_dir):
os.makedirs(save_dir)
file_path1 = os.path.join(save_dir, '1.png')
file_path2 = os.path.join(save_dir, '2.png')
# 展示不同label划分到不同client的情况
n_classes = 10#cifar10有10个类别
plt.figure(figsize=(12, 8))
plt.hist([labels[idc]for idc in client_idcs], stacked=True,
bins=np.arange(min(labels)-0.5, max(labels) + 1.5, 1),
label=["Client {}".format(i) for i in range(num_users)],
rwidth=0.5)
plt.xticks(np.arange(n_classes), classes)
plt.xlabel("Label type")
plt.ylabel("Number of samples")
plt.legend(loc="upper right")
plt.title("Display Label Distribution on Different Clients")
plt.savefig(file_path1)
# 展示不同client上的label分布
plt.figure(figsize=(12, 8))
label_distribution = [[] for _ in range(n_classes)]
for c_id, idc in enumerate(client_idcs):
for idx in idc:
label_distribution[labels[idx]].append(c_id)
plt.hist(label_distribution, stacked=True,
bins=np.arange(-0.5, num_users + 1.5, 1),
label=classes, rwidth=0.5)
plt.xticks(np.arange(num_users), ["Client %d" %
c_id for c_id in range(num_users)])
plt.xlabel("Client ID")
plt.ylabel("Number of samples")
plt.legend()
plt.title("Display Label Distribution on Different Clients")
plt.savefig(file_path2)
标签划分:
def dirichlet_split_noniid(classes,train_labels, alpha=100.0, client_number=5):
'''
参数为 alpha 的 Dirichlet 分布将数据索引划分为 n_clients 个子集
'''
# 总类别数
n_classes = train_labels.max()+1#也可以自己手动设置
label_distribution = np.random.dirichlet([alpha]*n_clients, n_classes)
# 记录每个类别对应的样本下标
# 返回二维数组
class_idcs = [np.argwhere(train_labels==y).flatten()
for y in range(n_classes)]
# 定义一个空列表作最后的返回值
client_idcs = [[] for _ in range(n_clients)]
# 记录N个client分别对应样本集合的索引
for c, fracs in zip(class_idcs, label_distribution):
# np.split按照比例将类别为k的样本划分为了N个子集
# for i, idcs 为遍历第i个client对应样本集合的索引
for i, idcs in enumerate(np.split(c, (np.cumsum(fracs)[:-1]*len(c)).astype(int))):
client_idcs[i] += [idcs]
client_idcs = [np.concatenate(idcs) for idcs in client_idcs]
draw_dataset(classes,train_labels,client_idcs, n_clients)
return client_idcs
代码调用
train_dataset = datasets.CIFAR10(data_dir, train=True, download=True, transform=trans_cifar10_train)
train_client_idcs = dirichlet_split_noniid(train_dataset.classes,np.array(train_dataset.targets),alpha=100.0,n_clients=5)