CVPR2022人脸识别Partial FC论文及代码学习笔记

论文链接:https://openaccess.thecvf.com/content/CVPR2022/papers/An_Killing_Two_Birds_With_One_Stone_Efficient_and_Robust_Training_CVPR_2022_paper.pdf

代码链接:insightface/recognition/arcface_torch at master · deepinsight/insightface · GitHub

背景

使用基于百万规模的数据集和基于margin的softmax损失函数来学习区分性的embeddings是当前人脸识别的SOTA方法。然而,全连接层的内存和计算成本随着训练集中ID数量的增加而线性增加。此外,大规模训练数据存在类间冲突(同一个人被分成不同ID)和长尾分布的问题。

传统FC

将传统的FC层应用在大规模的数据集上时,存在以下缺陷:

1、gradient confusion under interclass conflict

WebFace42M里有很多不同类别对之间的余弦相似度大于0.4,这表明类间冲突仍然存在于这些清洗过的数据集中。直接优化的话会导致gradient confusion(同一个人的特征非常相似却要掰成两个ID)

2、centers of tail classes undergo too many passive updates

每个iteration都优化图片数量很少的id,可能会导致负优化

3、the storage and calculation of the FC layer can easily exceed current GPU capabilities

PartialFC

在训练期间仍然维护所有类别中心,但只随机采样一小部分负类别中心来计算基于margin的softmax损失,而不是在每次迭代中使用所有负类别中心。更具体地说,首先从每个GPU收集embeddings和标签,然后将组合的特征和标签分布到所有GPU。为了平衡每个GPU的内存使用和计算成本,为每个GPU设置了一个内存缓冲区(下面代码中的perm)。内存缓冲区的大小由类别总数和负类别中心的采样率决定。在每个GPU上,首先通过标签选择正类中心并放入缓冲区,然后随机选择一小部分负类中心(负类中心的数量为self.sample_rate * self.num_local)填充缓冲区的其余部分,

def sample(self, labels, index_positive):
    """
    This functions will change the value of labels
    Parameters:
    -----------
    labels: torch.Tensor
        pass
    index_positive: torch.Tensor
        pass
    optimizer: torch.optim.Optimizer
        pass
    """
    with torch.no_grad():
        positive = torch.unique(labels[index_positive], sorted=True).cuda()
        if self.num_sample - positive.size(0) >= 0:
            perm = torch.rand(size=[self.num_local]).cuda()
            perm[positive] = 2.0
            index = torch.topk(perm, k=self.num_sample)[1].cuda()
            index = index.sort()[0].cuda()
        else:
            index = positive
        self.weight_index = index

        labels[index_positive] = torch.searchsorted(index, labels[index_positive])

    return self.weight[self.weight_index]

随后,使用选出的样本中心去与特征相乘并计算基于margin的softmax损失。

PFC在DDP框架下的流程图如下图所示,

整体代码如下,

class PartialFC_V2(torch.nn.Module):
    """
    https://arxiv.org/abs/2203.15565
    A distributed sparsely updating variant of the FC layer, named Partial FC (PFC).
    When sample rate less than 1, in each iteration, positive class centers and a random subset of
    negative class centers are selected to compute the margin-based softmax loss, all class
    centers are still maintained throughout the whole training process, but only a subset is
    selected and updated in each iteration.
    .. note::
        When sample rate equal to 1, Partial FC is equal to model parallelism(default sample rate is 1).
    Example:
    --------
    >>> module_pfc = PartialFC(embedding_size=512, num_classes=8000000, sample_rate=0.2)
    >>> for img, labels in data_loader:
    >>>     embeddings = net(img)
    >>>     loss = module_pfc(embeddings, labels)
    >>>     loss.backward()
    >>>     optimizer.step()
    """
    _version = 2

    def __init__(
        self,
        margin_loss: Callable,
        embedding_size: int,
        num_classes: int,
        sample_rate: float = 1.0,
        fp16: bool = False,
    ):
        """
        Paramenters:
        -----------
        embedding_size: int
            The dimension of embedding, required
        num_classes: int
            Total number of classes, required
        sample_rate: float
            The rate of negative centers participating in the calculation, default is 1.0.
        """
        super(PartialFC_V2, self).__init__()
        assert (
            distributed.is_initialized()
        ), "must initialize distributed before create this"
        self.rank = distributed.get_rank()
        self.world_size = distributed.get_world_size()

        self.dist_cross_entropy = DistCrossEntropy()
        self.embedding_size = embedding_size
        self.sample_rate: float = sample_rate
        self.fp16 = fp16
        self.num_local: int = num_classes // self.world_size + int(
            self.rank < num_classes % self.world_size
        )
        self.class_start: int = num_classes // self.world_size * self.rank + min(
            self.rank, num_classes % self.world_size
        )
        self.num_sample: int = int(self.sample_rate * self.num_local)
        self.last_batch_size: int = 0

        self.is_updated: bool = True
        self.init_weight_update: bool = True
        self.weight = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))

        # margin_loss
        if isinstance(margin_loss, Callable):
            self.margin_softmax = margin_loss
        else:
            raise

    def sample(self, labels, index_positive):
        """
            This functions will change the value of labels
            Parameters:
            -----------
            labels: torch.Tensor
                pass
            index_positive: torch.Tensor
                pass
            optimizer: torch.optim.Optimizer
                pass
        """
        with torch.no_grad():
            positive = torch.unique(labels[index_positive], sorted=True).cuda()
            if self.num_sample - positive.size(0) >= 0:
                perm = torch.rand(size=[self.num_local]).cuda()
                perm[positive] = 2.0
                index = torch.topk(perm, k=self.num_sample)[1].cuda()
                index = index.sort()[0].cuda()
            else:
                index = positive
            self.weight_index = index

            labels[index_positive] = torch.searchsorted(index, labels[index_positive])

        return self.weight[self.weight_index]

    def forward(
        self,
        local_embeddings: torch.Tensor,
        local_labels: torch.Tensor,
    ):
        """
        Parameters:
        ----------
        local_embeddings: torch.Tensor
            feature embeddings on each GPU(Rank).
        local_labels: torch.Tensor
            labels on each GPU(Rank).
        Returns:
        -------
        loss: torch.Tensor
            pass
        """
        local_labels.squeeze_()
        local_labels = local_labels.long()

        batch_size = local_embeddings.size(0)
        if self.last_batch_size == 0:
            self.last_batch_size = batch_size
        assert self.last_batch_size == batch_size, (
            f"last batch size do not equal current batch size: {self.last_batch_size} vs {batch_size}")

        _gather_embeddings = [
            torch.zeros((batch_size, self.embedding_size)).cuda()
            for _ in range(self.world_size)
        ]
        _gather_labels = [
            torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)
        ]
        _list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
        distributed.all_gather(_gather_labels, local_labels)

        embeddings = torch.cat(_list_embeddings)
        labels = torch.cat(_gather_labels)
        
        ## 选出落在本进程对应的类别范围内的数据
        labels = labels.view(-1, 1)
        index_positive = (self.class_start <= labels) & (
            labels < self.class_start + self.num_local
        )
        ## 标签不在本类别段的, 将其类别标签设为-1
        labels[~index_positive] = -1
        ## 将类别ID平移到原点(因为不同进程都会初始化对应的self.weight, 若不平移回去, 则label与self.weight中的index会对应不上)
        labels[index_positive] -= self.class_start

        if self.sample_rate < 1:
            weight = self.sample(labels, index_positive)
        else:
            weight = self.weight

        with torch.cuda.amp.autocast(self.fp16):
            norm_embeddings = normalize(embeddings)
            norm_weight_activated = normalize(weight)
            logits = linear(norm_embeddings, norm_weight_activated)
        if self.fp16:
            logits = logits.float()
        logits = logits.clamp(-1, 1)

        logits = self.margin_softmax(logits, labels)
        loss = self.dist_cross_entropy(logits, labels)
        return loss

实验结果

将PFC替换掉传统FC后,模型在WebFace(包括4m、12m、42m)上的性能会有所提升,

 消融实验的结果如下,

与SOTA方法的性能对比如下, 

结论与讨论

结论

作者提出了一种用于在大规模数据集上训练人脸识别模型的方法——Partial FC (PFC)。在PFC的每次迭代中,仅选择一小部分类别中心来计算基于边际的softmax损失,这样可以显著减少类间冲突的概率、尾类中心的被动更新频率以及计算需求。通过广泛的实验,作者验证了所提出的PFC的有效性、鲁棒性和高效性。

局限性

尽管在WebFace上训练的PFC模型在高质量测试集上取得了不错的结果,但在人脸分辨率较低或低光照条件下拍摄的人脸上,PFC模型的表现可能较差。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/626140.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

leetcode——链表的中间节点

876. 链表的中间结点 - 力扣&#xff08;LeetCode&#xff09; 链表的中间节点是一个简单的链表OJ。我们要返回中间节点有两种情况&#xff1a;节点数为奇数和节点数是偶数。如果是奇数则直接返回中间节点&#xff0c;如果是偶数则返回第二个中间节点。 这道题的解题思路是&a…

【JS面试题】this

this取什么值&#xff0c;是在函数执行的时候确定的&#xff0c;不是在函数定义的时候确定的&#xff01; this的6种使用场景&#xff1a; ① 在普通函数中使用&#xff1a;返回window对象 ② 使用call apply bind 调用&#xff1a;绑定的是哪个对象就返回哪个对象 ③ 在对象…

LeetCode2390从字符串中移除星号

题目描述 给你一个包含若干星号 * 的字符串 s 。在一步操作中&#xff0c;你可以&#xff1a;选中 s 中的一个星号。移除星号 左侧 最近的那个 非星号 字符&#xff0c;并移除该星号自身。返回移除 所有 星号之后的字符串。注意&#xff1a;生成的输入保证总是可以执行题面中描…

电子邮箱是什么?怎么申请一个电子邮箱?

电子邮箱是我们沟通的工具&#xff0c;细分为免费版电子邮箱和付费版电子邮箱。怎么申请一个属于自己的电子邮箱&#xff1f;今天小编就分享一下电子邮箱注册教程&#xff0c;手把手教您注册一个电子邮箱。 一、电子邮箱的定义 电子邮箱&#xff0c;简称邮箱&#xff0c;是一…

【Java基础】权限修饰符

一个java文件中只能有一个被public修饰的类&#xff0c;且该类名与java文件的名字一样 同一个类同一个包不同包有继承不同包无继承private✔❌❌❌默认✔✔❌❌protected✔✔✔❌public✔✔✔✔

景源畅信数字:抖音热门赛道有哪些?

抖音&#xff0c;作为当下流行的短视频平台&#xff0c;吸引了无数用户和创作者。热门赛道&#xff0c;即平台上受关注度高、活跃用户多的内容领域&#xff0c;是许多内容创作者关注的焦点。这些赛道不仅反映了用户的兴趣偏好&#xff0c;也指引着创作的方向。 一、美食制作与分…

产品新说:应急定界 | 如何在运维/技术支持领域中应对突发故障?

一、简介 应急定界的方案旨在帮助运维人员以业务故障驱动为起点&#xff0c;第一时间的快速恢复业务。该场景的条件基础是通过构建一体化监控告警平台&#xff0c;纳管应用与基础组件&#xff0c;提供业务系统监测、及时告警、排查分析能。通过告警、指标、日志、链路等重要运…

C语言中数组与指针的区别

一. 简介 本文学习了 C语言中数组与指针的区别。这样的话&#xff0c;可以在编写C代码时规避掉出错的问题。 二. C语言中数组与指针的区别 1. 数组 定义字符串数组时&#xff0c;必须让编译器知道需要多少空间。 一种方法是用足够空间的数组存储字符串。例如如下&#xf…

多表查询练习题

1、创建好数据库 create database text use text --学生表 (students) CREATE TABLE students ( student_id INT PRIMARY KEY, name VARCHAR(50), age INT, major VARCHAR(50) );--课程表 (courses) CREATE TABLE courses ( course_id INT PRIMARY KEY, course_name V…

Linux基础之进程-进程状态

目录 一、进程状态 1.1 什么是进程状态 1.2 运行状态 1.2 阻塞状态 1.3 挂起状态 二、Linux操作系统上具体的进程状态 2.1 状态 2.2 R 和 S 状态的查看 2.3 后台进程和前台进程 2.4 休眠状态和深度休眠状态 一、进程状态 1.1 什么是进程状态 首先我们知道我们的操作系…

Java学习47-Java 流(Stream)、文件(File)和IO - 其他流的使用

1.标准输入流System.in/标准输出流System.out System.in : 标准的输入流&#xff0c;默认从键盘输入 System.out: 标准的输出流&#xff0c;默认从显示器输出(理解为控制台输出) System.setOut()方法和 System.setIn()方法&#xff08;结合下面介绍的打印流举例&#xff09; …

灵活的静态存储控制器 (FSMC)的介绍(STM32F4)

目录 概述 1 认识FSMC 1.1 应用介绍 1.2 FSMC的主要功能 1.2.1 FSMC用途 1.2.2 FSMC的功能 2 FSMC的框架结构 2.1 AHB 接口 2.1.1 AHB 接口的Fault 2.1.2 支持的存储器和事务 2.2 外部器件地址映射 3 地址映射 3.1 NOR/PSRAM地址映射 3.2 NAND/PC卡地址映射 概述…

ctfshow web入门 php反序列化 web267--web270

web267 查看源代码发现这三个页面 然后发现登录页面直接admin/admin登录成功 然后看到了 ///backdoor/shell unserialize(base64_decode($_GET[code]))EXP <?php namespace yii\rest{class IndexAction{public $checkAccess;public $id;public function __construct(){…

定时器的理论和使用

文章目录 一、定时器理论1.1定时器创建和使用 二、定时器实践2.1周期触发定时器2.2按键消抖 一、定时器理论 定时器是一种允许在特定时间间隔后或在将来的某个时间点调用回调函数的机制。对于需要周期性任务或延迟执行任务的嵌入式应用程序特别有用。 软件定时器&#xff1a; …

MySQL表的基本操作

表 创建表 comment是添加一个注释 语法&#xff1a; 说明&#xff1a; field 表示列名 datatype 表示列的类型 character set 字符集&#xff0c;如果没有指定字符集&#xff0c;则以所在数据库的字符集为准 collate 校验规则&#xff0c;如果没有指定校验规则&#xff0c;则…

知识图谱 | 语义网络写入图形数据库(含jdk和neo4j的安装过程)

Hi&#xff0c;大家好&#xff0c;我是半亩花海。本文主要介绍如何使用 Neo4j 图数据库呈现语义网络&#xff0c;并通过 Python 将语义网络的数据写入数据库。具体步骤包括识别知识中的节点和关系&#xff0c;将其转化为图数据库的节点和边&#xff0c;最后通过代码实现数据的写…

两数相加 - (LeetCode)

前言 今天无意间看到LeetCode的一道“两数相加”的算法题&#xff0c;第一次接触链表ListNode&#xff0c;ListNode结构如下&#xff1a; public class ListNode {int val;ListNode next;ListNode() {}ListNode(int val) {this.val val;}ListNode(int val, ListNode next) {…

使用TimeSum教你打造一套最牛的知识笔记管理系统!

从用户使用场景进行介绍软件的使用&#xff1a; 一、用户需求&#xff1a; 我需要一款软件记录我每天&#xff1a; 干了啥事有啥输出&#xff08;文档&#xff09;需要时间统计&#xff0c;后续会复盘记录的内容有好的逻辑关系需要有日历进行展示。 二、软件使用介绍&#xf…

【UE5.1 角色练习】01-使用小白人蓝图控制商城角色移动

目录 效果 步骤 一、导入资源 二、控制角色移动 三、更换角色移动动作 效果 步骤 一、导入资源 新建一个工程&#xff0c;然后在虚幻商城中将角色动画的相关资源加入工程&#xff0c;这里使用的是“动画初学者内容包”和“MCO Mocap Basics” 将我们要控制的角色添加进…

在idea中使用vue

一、安装node.js 1、在node.js官网&#xff08;下载 | Node.js 中文网&#xff09;上下载适合自己电脑版本的node.js压缩包 2、下载完成后进行解压并安装&#xff0c;一定要记住自己的安装路径 一直点击next即可&#xff0c;这部选第一个 3、安装成功后&#xff0c;按住winR输入…