【传知代码】偏标记学习+图像分类(论文复现)

前言:偏标记学习,顾名思义,就是在训练数据集中,每个样本的标签不是完全确定的,而是由多个可能的标签组成的集合。这种学习范式更加贴近现实世界的场景,因为在很多情况下,我们无法为图像提供精确无误的单一标签。例如,在一张包含多种花卉的照片中,我们可能只知道照片中包含了哪些花卉,但不确定每朵花卉的具体位置。那么,如何将偏标记学习的思想融入到图像分类的实践中呢?这不仅是一个理论问题,更是一个具有现实意义的工程问题。在这篇文章中,我们将深入探讨偏标记学习与图像分类的结合点,分析其中的关键技术与挑战,并展望未来的发展方向。

本文所涉及所有资源均在传知代码平台可获取

目录

概述

演示效果

核心代码

写在最后


概述

        随着深度神经网络技术的不断进步,对于标注数据在机器学习任务中的需求也在持续上升。然而,大规模的标注数据高度依赖于人力资源和标注者的专业技能。弱监督学习方法能够有效地解决这个问题,因为它并不需要完整且精确的数据标注。这篇论文集中探讨了一个关键的弱监督学习难题,即偏标记学习(Partial Label Learning),在这个问题中,每一个训练案例都与一组可能的标签有关,但只有其中一个标签是真实存在的,如下图所示:

本文复现论文 地址 ,提出的偏标记学习方法,该论文提出了一种渐进式真实标签识别方法,旨在训练过程中逐渐确定样本的真实标签。该论文所提出的方法获得了接近监督学习的性能,且与具体的网络结构、损失函数、随机优化算法无关,具体如下所示:

传统的监督学习常用交叉熵损失和随机梯度下降来优化深度神经网络。交叉熵损失定义如下:

其中, xx 表示样本特征;y=[y1,y2,…,yc]y=[y1​,y2​,…,yc​] 表示样本标签,其为独热码,即除了真实标签对应维度值为 1,其余为零;fi(x;θ)fi​(x;θ) 表示模型预测样本 xx 标签为 ii 的概率。该论文提出的方法使用一个软标签 y^=[y^1,y^2,…,y^c]y^​=[y^​1​,y^​2​,…,y^​c​],其对任意 i∈[0,c]i∈[0,c] 满足 ∑iy^i=1∑i​y^​i​=1 且 0≤y^i≤10≤y^​i​≤1。为了使用该软标签,论文根据候选标签集 ss 对软标签进行初始化:

为了渐进式地识别真实标签,算法在每次更新参数之前,根据预测结果为下轮训练使用的软标签赋值,其中,I(j∈s)=1I(j∈s)=1 当且仅当 j∈sj∈s 为真,否则 I(j∈s)=0I(j∈s)=0:

演示效果

解压附件压缩包并进入工作目录。如果是Linux系统,请使用如下命令:

unzip Proden-implemention.zip
cd Proden-implemention

代码的运行环境可通过如下命令进行配置:

pip install -r requirements.txt

运行如下命令以下载并解压数据集:

bash download.sh

如果希望在本地训练模型,请运行如下命令:

python main.py -c [你的配置文件路径] -r [选择下者之一:"train"、"test"、"infer"]

如果希望在线部署,请运行如下命令:

python main-flask.py

在 CIFAR-10[2] 数据集和 12 层的 ConvNet[3] 网络上训练了一份模型参数。为了测试其准确率,需要配置环境并运行main.py脚本,得到结果如下,由图可见,该算法在测试集上获得了 89.8% 的准确率:

进一步地,测试训练出的模型在真实图片上的预测结果。在线部署模型后,将一张轮船的图片输入,可以得到输出的预测类型为 “Ship”:

所使用的数据集(CIFAR-10)共包含十个类,示意图如下:

核心代码

下面这段代码实现了一个 Proden 模型的训练过程,Proden 是一种半监督学习算法,用于解决标注数据不足的问题,它通过生成偏标记来利用未标记数据进行训练,从而提高模型的泛化能力,具体来说,这段代码的主要步骤如下:

1)读取数据集。支持 CIFAR-10 和 CIFAR-100 两种数据集。

2)生成偏标记,通过 datasets.generate_partial_labels 函数生成偏标记,用于利用未标记数据进行训练。

3)计算数据的均值和方差,用于模型输入的标准化。

4)加载模型。支持 ResNet18 和 ConvNet 两种模型。

5)设置学习率等超参数,使用 SGD 优化器和 StepLR 学习率调度器。

6)进行训练。使用 DataLoader 加载训练数据集,对每个 batch 进行训练,计算交叉熵损失并更新模型参数。在每个 epoch 结束后,调整学习率。

7)如果 save 参数为 True,则保存模型。

其中,CE_loss 函数是一个交叉熵损失函数,用于计算模型预测值和真实标签之间的差异。这段代码的主要目的是训练一个 Proden 模型,以解决标注数据不足的问题,它使用半监督学习的思想,通过生成偏标记来利用未标记数据进行训练,从而提高模型的泛化能力,如下:
 

import models
import datasets
import torch
from torch.utils.data import DataLoader
import numpy as np
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torchvision.transforms as transforms
from tqdm import tqdm

def CE_loss(probs, targets):
    """交叉熵损失函数"""
    loss = -torch.sum(targets * torch.log(probs), dim = -1)
    loss_avg = torch.sum(loss)/probs.shape[0]
    return loss_avg

class Proden:
    def __init__(self, configs):
        self.configs = configs
    
    def train(self, save = False):
        configs = self.configs
        # 读取数据集
        dataset_path = configs['dataset path']
        if configs['dataset'] == 'CIFAR-10':
            train_data, train_labels, test_data, test_labels = datasets.cifar10_read(dataset_path)
            train_dataset = datasets.Cifar(train_data, train_labels)
            test_dataset = datasets.Cifar(test_data, test_labels)
            output_dimension = 10
        elif configs['dataset'] == 'CIFAR-100':
            train_data, train_labels, test_data, test_labels = datasets.cifar100_read(dataset_path)
            train_dataset = datasets.Cifar(train_data, train_labels)
            test_dataset = datasets.Cifar(test_data, test_labels)
            output_dimension = 100
        # 生成偏标记
        partial_labels = datasets.generate_partial_labels(train_labels, configs['partial rate'])
        train_dataset.load_partial_labels(partial_labels)
        # 计算数据的均值和方差,用于模型输入的标准化
        mean = [np.mean(train_data[:, i, :, :]) for i in range(3)]
        std = [np.std(train_data[:, i, :, :]) for i in range(3)]
        normalize = transforms.Normalize(mean, std)
        # 设备:GPU或CPU
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # 加载模型
        if configs['model'] == 'ResNet18':
            model = models.ResNet18(output_dimension = output_dimension).to(device)
        elif configs['model'] == 'ConvNet':
            model = models.ConvNet(output_dimension = output_dimension).to(device)
        # 设置学习率等超参数
        lr = configs['learning rate']
        weight_decay = configs['weight decay']
        momentum = configs['momentum']
        optimizer = optim.SGD(model.parameters(), lr = lr, weight_decay = weight_decay, momentum = momentum)
        lr_step = configs['learning rate decay step']
        lr_decay = configs['learning rate decay rate']
        lr_scheduler = StepLR(optimizer, step_size=lr_step, gamma=lr_decay)
        for epoch_id in range(configs['epoch count']):
            # 训练模型
            train_dataloader = DataLoader(train_dataset, batch_size = configs['batch size'], shuffle = True)
            model.train()
            for batch in tqdm(train_dataloader, desc='Training(Epoch %d)' % epoch_id, ascii=' 123456789#'):
                ids = batch['ids']
                # 标准化输入
                data = normalize(batch['data'].to(device))
                partial_labels = batch['partial_labels'].to(device)
                targets = batch['targets'].to(device)
                optimizer.zero_grad()
                # 计算预测概率
                logits = model(data)
                probs = F.softmax(logits, dim=-1)
                # 更新软标签
                with torch.no_grad():
                    new_targets = F.normalize(probs * partial_labels, p=1, dim=-1)
                    train_dataset.targets[ids] = new_targets.cpu().numpy()
                # 计算交叉熵损失
                loss = CE_loss(probs, targets)
                loss.backward()
                # 更新模型参数
                optimizer.step()
            # 调整学习率
            lr_scheduler.step()

写在最后

        在深入探讨了偏标记学习与图像分类的交汇点后,我们不禁为这一领域的潜力与前景所震撼。偏标记学习为我们提供了一种处理现实世界中不确定性和模糊性的强大工具,而图像分类作为计算机视觉的核心任务之一,更是将这一工具的应用推向了新的高度。

        通过本文的阐述,我们可以看到偏标记学习在图像分类中的广泛应用和显著效果。它不仅能够应对传统图像分类方法在面对复杂场景时的局限性,还能有效地利用带有部分标签的数据进行学习,提高了模型的泛化能力和鲁棒性。无论是在医学图像分析、安防监控还是社交媒体内容识别等领域,偏标记学习都展现出了其独特的优势和价值。

详细复现过程的项目源码、数据和预训练好的模型可从该文章下方附件获取。

【传知科技】关注有礼     公众号、抖音号、视频号

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/680720.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

可的哥(Codigger)推出Monaco编辑器插件,提升编程体验

Monaco编辑器,作为业界领先的代码编辑器,在编程体验中发挥着不可或缺的重要作用,能够在多种编程语言和开发环境中表现出色,为开发者提供高效、便捷的编程环境。可的哥(Codigger)在应用商店上线Monaco编辑器…

在618集中上新,蕉下、VVC们为何押注拼多多?

编辑|Ray 自前两年崛起的防晒产品,今年依旧热度不减。 头部品牌蕉下,2020年入驻拼多多,如今年销售额已过亿元。而自去年起重点押注拼多多的时尚防晒品牌VVC,很快销量翻番。这两家公司,不约而同在618之前上…

设备巡检系统是如何实现一次操作闭环管理的

设备巡检系统通过一系列功能设计,实现了从任务分配到问题处理的一次操作闭环管理。以下是具体的实现方式: 一、多类型任务无感操作 任务识别与整合:系统能够自动识别各种巡检、重大危险源排查及现场检修等任务的类型和优先级,并…

企业全面管理解决方案:基于Java技术的ERP管理系统源码

功能模块与描述: ERP首页: 销售统计与采购统计:实时展示销售和采购金额的统计数据。折线图统计:通过图表直观展示销售和采购趋势。 采购管理: 采购订单管理:处理采购订单的搜索、新增、导出等。采购入库与退…

进程同步的基本元素

目录 临界资源 临界区 信号量机制 整形信号量 记录型信号量 AND信号量 信号量集 信号量的应用 实现进程互斥 实现前驱关系 管程机制 总结 临界资源 I/O设备属于临界资源。著名的生产者-消费者问题就是关于临界资源的争夺产生的进程同步的问题。 生产者-消费者 描…

产品经理:做好有效的客户需求分析

需求分析是产品开发过程中的重要环节,它直接决定了产品是否能够满足市场需求和用户期望。通过深入了解客户需求,产品经理可以确保产品功能的设计符合用户的实际需求,从而提高产品的用户满意度和市场竞争力。 一、识别用户需求 识别用户需求…

mysql用户管理知识点

1、权限表 1.1、user表 1.1.1、用户列 Host、User、Password分别表示主机名、用户名、密码 1.1.2、权限列 决定了用户的权限,描述了在全局范围内允许对数据和数据库进行操作。 1.1.3、安全列 安全列有6个字段,其中两个是ssl相关的,2个是x509相…

虚拟仿真实训平台如何与不同专业进行融合?

虚拟仿真实训平台根据跨专业实训教学和职业培训的不同特点,兼顾实训课程设计的专业性和兼容性,根据不同专业特性确定虚拟仿真实训教学内容,研发虚拟仿真实训教学资源,优化人才培养方案和职业培训方案,改革实训教学体系…

游戏陪玩系统源码线上陪玩软件开发电竞陪练小程序陪玩APP

思维导图 规则说明 支持陪玩官,一级和二级,系统配置设置的是初始化佣金,可以单个设置某个人的佣金比例,分销模块只涉及到下单交易模块,其他的不参与,邀请的用户被下单,即可获得收益。 理解规则…

【面试笔记】C++ 软件开发工程师,智驾研发方向(非算法)

文章目录 1. 前言2. 基础问题2.1 什么是C++中的类?如何定义和实例化一个类?2.2 请解释C++中的继承和多态性。2.3 什么是虚函数?为什么在基类中使用虚函数?2.4 解释封装、继承和多态的概念,并提供相应的代码示例。2.5 如何处理内存泄漏问题?提供一些常见的内存管理技术。2…

LabVIEW冲击响应谱分析系统

LabVIEW冲击响应谱分析系统 开发了一种基于LabVIEW开发的冲击响应谱分析系统,该系统主要用于分析在短时间内高量级输入力作用下装备的响应。通过改进的递归数字滤波法和样条函数法进行冲击响应谱的计算,实现了冲击有效持续时间的自动提取和响应谱的精准…

13.56MHz电动车NFC刷卡解锁

随着电动车市场的快速发展,车主对车辆的智能化和便捷性的要求也在不断提升。仪表盘作为电动车的重要组成部分,不仅需要提供基本的行驶信息,还需要具备智能交互功能。 基于13.56MHz频率的NFC(近场通信)技术为电动车仪表…

李国武:六西格玛绿带项目的实施过程中可能遇到哪些问题?

作为六西格玛管理体系中的中坚力量,绿带项目在企业的转型升级中扮演着举足轻重的角色。然而,在实施六西格玛绿带项目的过程中,企业往往会遭遇一系列挑战。具体如深圳天行健企业管理咨询公司下文所述: 首先,人才与知识的…

雷士大路灯有必要买吗?雷士、书客、孩视宝护眼落地灯实测PK!

面对市面上众多的护眼大路灯品牌,其中雷士、书客和孩视宝这几款大路灯受到了广泛的青睐,也是热度比较高的几款产品,正是因为这么多款大路灯,很多伙伴在看到文章推荐后很纠结,不知道如何选择,也有一部分伙伴…

微信小游戏性能优化解决方案全新发布

小游戏凭借其简单易上手、玩法多样、互动性强的特点,迅速在市场中崭露头角。MMO、ARPG、卡牌等游戏类型也纷纷入局。玩家对启动时间长、发热、加载缓慢、闪退等问题也越来越敏感。 为了突破这些性能瓶颈,UWA全新发布了针对微信小游戏的性能优化解决方案…

水库大坝安全监测系统打通监控数据“最后一公里”

一、概述 我国有水库8万座左右,其中土石坝多数,病险水库占水库也很多。众所周知,水库在防洪、兴利上具有重要的调节作用,如何保证水库安全,及合理有效的利用水资源,是水利建设者需要探讨的主要内容。科学技…

OpenCV学习(4.2) 图像的几何变换

1.目标 学习将不同的几何变换应用到图像上,如平移、旋转、仿射变换等。你会看到这些函数: cv.getPerspectiveTransform 2.缩放 缩放是调整图片的大小。 OpenCV 使用 cv.resize() 函数进行调整。可以手动指定图像的大小,也可以指定比例因子。可以使用不…

【python】成功解决“ModuleNotFoundError: No module named ‘gensim’”错误的全面指南

成功解决“ModuleNotFoundError: No module named ‘gensim’”错误的全面指南 在Python编程中,尤其是进行文本挖掘和自然语言处理(NLP)时,gensim库是一个常用的工具,用于主题建模、文档相似度计算、词向量表示&#x…

【教程】使用 Tailchat 搭建团队内部聊天平台,Slack 的下一个替代品!

前言 多人协作,私有聊天一直是团队协作的关键点,现在有很多专注于团队协作的应用和平台,比如飞书、企业微信和Slack等。这期教程将带你手把手的搭建一个在线的团队协作向聊天室,希望对你有所帮助! 本期聊天室使用TailChat作为服务…

服务器数据恢复—raid5阵列上分配的卷被删除后重建如何恢复被删除卷的数据?

服务器存储数据恢复环境: 某品牌FlexStorage P5730服务器存储,存储中有一组由24块硬盘组建的RAID5阵列,包括1块热备硬盘。 服务器存储故障: 存储中的2个卷被删除,删除之后重建了一个新卷。需要恢复之前删除的一个卷的数…