前言:偏标记学习,顾名思义,就是在训练数据集中,每个样本的标签不是完全确定的,而是由多个可能的标签组成的集合。这种学习范式更加贴近现实世界的场景,因为在很多情况下,我们无法为图像提供精确无误的单一标签。例如,在一张包含多种花卉的照片中,我们可能只知道照片中包含了哪些花卉,但不确定每朵花卉的具体位置。那么,如何将偏标记学习的思想融入到图像分类的实践中呢?这不仅是一个理论问题,更是一个具有现实意义的工程问题。在这篇文章中,我们将深入探讨偏标记学习与图像分类的结合点,分析其中的关键技术与挑战,并展望未来的发展方向。
本文所涉及所有资源均在传知代码平台可获取
目录
概述
演示效果
核心代码
写在最后
概述
随着深度神经网络技术的不断进步,对于标注数据在机器学习任务中的需求也在持续上升。然而,大规模的标注数据高度依赖于人力资源和标注者的专业技能。弱监督学习方法能够有效地解决这个问题,因为它并不需要完整且精确的数据标注。这篇论文集中探讨了一个关键的弱监督学习难题,即偏标记学习(Partial Label Learning),在这个问题中,每一个训练案例都与一组可能的标签有关,但只有其中一个标签是真实存在的,如下图所示:
本文复现论文 地址 ,提出的偏标记学习方法,该论文提出了一种渐进式真实标签识别方法,旨在训练过程中逐渐确定样本的真实标签。该论文所提出的方法获得了接近监督学习的性能,且与具体的网络结构、损失函数、随机优化算法无关,具体如下所示:
传统的监督学习常用交叉熵损失和随机梯度下降来优化深度神经网络。交叉熵损失定义如下:
其中, xx 表示样本特征;y=[y1,y2,…,yc]y=[y1,y2,…,yc] 表示样本标签,其为独热码,即除了真实标签对应维度值为 1,其余为零;fi(x;θ)fi(x;θ) 表示模型预测样本 xx 标签为 ii 的概率。该论文提出的方法使用一个软标签 y^=[y^1,y^2,…,y^c]y^=[y^1,y^2,…,y^c],其对任意 i∈[0,c]i∈[0,c] 满足 ∑iy^i=1∑iy^i=1 且 0≤y^i≤10≤y^i≤1。为了使用该软标签,论文根据候选标签集 ss 对软标签进行初始化:
为了渐进式地识别真实标签,算法在每次更新参数之前,根据预测结果为下轮训练使用的软标签赋值,其中,I(j∈s)=1I(j∈s)=1 当且仅当 j∈sj∈s 为真,否则 I(j∈s)=0I(j∈s)=0:
演示效果
解压附件压缩包并进入工作目录。如果是Linux系统,请使用如下命令:
unzip Proden-implemention.zip
cd Proden-implemention
代码的运行环境可通过如下命令进行配置:
pip install -r requirements.txt
运行如下命令以下载并解压数据集:
bash download.sh
如果希望在本地训练模型,请运行如下命令:
python main.py -c [你的配置文件路径] -r [选择下者之一:"train"、"test"、"infer"]
如果希望在线部署,请运行如下命令:
python main-flask.py
在 CIFAR-10[2] 数据集和 12 层的 ConvNet[3] 网络上训练了一份模型参数。为了测试其准确率,需要配置环境并运行main.py脚本,得到结果如下,由图可见,该算法在测试集上获得了 89.8% 的准确率:
进一步地,测试训练出的模型在真实图片上的预测结果。在线部署模型后,将一张轮船的图片输入,可以得到输出的预测类型为 “Ship”:
所使用的数据集(CIFAR-10)共包含十个类,示意图如下:
核心代码
下面这段代码实现了一个 Proden 模型的训练过程,Proden 是一种半监督学习算法,用于解决标注数据不足的问题,它通过生成偏标记来利用未标记数据进行训练,从而提高模型的泛化能力,具体来说,这段代码的主要步骤如下:
1)读取数据集。支持 CIFAR-10 和 CIFAR-100 两种数据集。
2)生成偏标记,通过 datasets.generate_partial_labels 函数生成偏标记,用于利用未标记数据进行训练。
3)计算数据的均值和方差,用于模型输入的标准化。
4)加载模型。支持 ResNet18 和 ConvNet 两种模型。
5)设置学习率等超参数,使用 SGD 优化器和 StepLR 学习率调度器。
6)进行训练。使用 DataLoader 加载训练数据集,对每个 batch 进行训练,计算交叉熵损失并更新模型参数。在每个 epoch 结束后,调整学习率。
7)如果 save 参数为 True,则保存模型。
其中,CE_loss 函数是一个交叉熵损失函数,用于计算模型预测值和真实标签之间的差异。这段代码的主要目的是训练一个 Proden 模型,以解决标注数据不足的问题,它使用半监督学习的思想,通过生成偏标记来利用未标记数据进行训练,从而提高模型的泛化能力,如下:
import models
import datasets
import torch
from torch.utils.data import DataLoader
import numpy as np
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torchvision.transforms as transforms
from tqdm import tqdm
def CE_loss(probs, targets):
"""交叉熵损失函数"""
loss = -torch.sum(targets * torch.log(probs), dim = -1)
loss_avg = torch.sum(loss)/probs.shape[0]
return loss_avg
class Proden:
def __init__(self, configs):
self.configs = configs
def train(self, save = False):
configs = self.configs
# 读取数据集
dataset_path = configs['dataset path']
if configs['dataset'] == 'CIFAR-10':
train_data, train_labels, test_data, test_labels = datasets.cifar10_read(dataset_path)
train_dataset = datasets.Cifar(train_data, train_labels)
test_dataset = datasets.Cifar(test_data, test_labels)
output_dimension = 10
elif configs['dataset'] == 'CIFAR-100':
train_data, train_labels, test_data, test_labels = datasets.cifar100_read(dataset_path)
train_dataset = datasets.Cifar(train_data, train_labels)
test_dataset = datasets.Cifar(test_data, test_labels)
output_dimension = 100
# 生成偏标记
partial_labels = datasets.generate_partial_labels(train_labels, configs['partial rate'])
train_dataset.load_partial_labels(partial_labels)
# 计算数据的均值和方差,用于模型输入的标准化
mean = [np.mean(train_data[:, i, :, :]) for i in range(3)]
std = [np.std(train_data[:, i, :, :]) for i in range(3)]
normalize = transforms.Normalize(mean, std)
# 设备:GPU或CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载模型
if configs['model'] == 'ResNet18':
model = models.ResNet18(output_dimension = output_dimension).to(device)
elif configs['model'] == 'ConvNet':
model = models.ConvNet(output_dimension = output_dimension).to(device)
# 设置学习率等超参数
lr = configs['learning rate']
weight_decay = configs['weight decay']
momentum = configs['momentum']
optimizer = optim.SGD(model.parameters(), lr = lr, weight_decay = weight_decay, momentum = momentum)
lr_step = configs['learning rate decay step']
lr_decay = configs['learning rate decay rate']
lr_scheduler = StepLR(optimizer, step_size=lr_step, gamma=lr_decay)
for epoch_id in range(configs['epoch count']):
# 训练模型
train_dataloader = DataLoader(train_dataset, batch_size = configs['batch size'], shuffle = True)
model.train()
for batch in tqdm(train_dataloader, desc='Training(Epoch %d)' % epoch_id, ascii=' 123456789#'):
ids = batch['ids']
# 标准化输入
data = normalize(batch['data'].to(device))
partial_labels = batch['partial_labels'].to(device)
targets = batch['targets'].to(device)
optimizer.zero_grad()
# 计算预测概率
logits = model(data)
probs = F.softmax(logits, dim=-1)
# 更新软标签
with torch.no_grad():
new_targets = F.normalize(probs * partial_labels, p=1, dim=-1)
train_dataset.targets[ids] = new_targets.cpu().numpy()
# 计算交叉熵损失
loss = CE_loss(probs, targets)
loss.backward()
# 更新模型参数
optimizer.step()
# 调整学习率
lr_scheduler.step()
写在最后
在深入探讨了偏标记学习与图像分类的交汇点后,我们不禁为这一领域的潜力与前景所震撼。偏标记学习为我们提供了一种处理现实世界中不确定性和模糊性的强大工具,而图像分类作为计算机视觉的核心任务之一,更是将这一工具的应用推向了新的高度。
通过本文的阐述,我们可以看到偏标记学习在图像分类中的广泛应用和显著效果。它不仅能够应对传统图像分类方法在面对复杂场景时的局限性,还能有效地利用带有部分标签的数据进行学习,提高了模型的泛化能力和鲁棒性。无论是在医学图像分析、安防监控还是社交媒体内容识别等领域,偏标记学习都展现出了其独特的优势和价值。
详细复现过程的项目源码、数据和预训练好的模型可从该文章下方附件获取。
【传知科技】关注有礼 公众号、抖音号、视频号