Center Loss 和 ArcFace Loss 笔记

一、Center Loss

1. 定义

Center Loss 旨在最小化类内特征的离散程度,通过约束样本特征与其类别中心之间的距离,提高类内特征的聚合性。

2. 公式

对于样本 xi​ 和其类别yi​,Center Loss 的公式为:

  • xi​: 当前样本的特征向量(通常来自网络的最后一层)。
  • Cyi: 类别 yi​ 的特征中心。
  • m: 样本数量。

3. 作用

  • 减小类内样本的特征分布范围。
  • 提高分类模型对相似类别样本的区分能力。

4. 实现

import torch
import torch.nn as nn

class CenterLoss(nn.Module):
    def __init__(self, num_classes, feat_dim, weight=1.0):
        """
        :param num_classes: 类别数量
        :param feat_dim: 特征向量维度
        :param weight: 损失的权重
        """
        super(CenterLoss, self).__init__()
        self.weight = weight
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))  # 初始化类别中心

    def forward(self, features, labels):
        """
        :param features: 网络输出的特征向量 (batch_size, feat_dim)
        :param labels: 样本对应的类别标签 (batch_size,)
        """
        centers = self.centers[labels]  # 获取对应标签的中心
        loss = torch.sum((features - centers) ** 2, dim=1).mean()  # 欧几里得距离平方和
        return self.weight * loss

5. 结合 Cross-Entropy Loss

Center Loss 与交叉熵损失结合,联合优化网络:

center_loss = CenterLoss(num_classes=10, feat_dim=512)
cross_entropy_loss = nn.CrossEntropyLoss()

# 训练时
features, logits = model(input_data)
loss_ce = cross_entropy_loss(logits, labels)
loss_center = center_loss(features, labels)

total_loss = loss_ce + 0.1 * loss_center  # 合并损失

二、ArcFace Loss

1. 定义

ArcFace Loss 是基于角度的损失函数,用于增强特征的判别性。通过在角度空间引入额外的边际约束,强迫同类样本之间更加接近,而不同类样本之间更加远离。

2. 公式

ArcFace Loss 的公式为:

  • θ: 特征和分类权重之间的角度。
  • m: 边际(margin)。

最终损失使用交叉熵计算:

  • s: 缩放因子,用于平衡模型的学习难度。

3. 作用

  • 强化特征的角度判别能力,使得分类更加鲁棒。
  • 在人脸识别任务中,显著提高模型的性能。

4. 实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ArcFaceLoss(nn.Module):
    def __init__(self, in_features, out_features, s=30.0, m=0.50):
        """
        :param in_features: 特征向量维度
        :param out_features: 类别数量
        :param s: 缩放因子
        :param m: 边际约束
        """
        super(ArcFaceLoss, self).__init__()
        self.s = s
        self.m = m
        self.weight = nn.Parameter(torch.randn(out_features, in_features))  # 分类权重

    def forward(self, embeddings, labels):
        # Normalize embeddings and weight
        embeddings = F.normalize(embeddings, p=2, dim=1)
        weight = F.normalize(self.weight, p=2, dim=1)

        # Cosine similarity
        cosine = F.linear(embeddings, weight)

        # Add margin
        phi = cosine - self.m
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, labels.view(-1, 1), 1)
        cosine_with_margin = one_hot * phi + (1 - one_hot) * cosine

        # Scale
        logits = self.s * cosine_with_margin
        loss = F.cross_entropy(logits, labels)
        return loss

解释:

        ArcFaceLoss在最后一层网络,输入是上一层的输出特征值x,初始化当前层的w权重。

cos(角度)=w×x/|w|×|x|,由于ArcLoss会对w和x进行归一化到和为1的概率值。所以|w|×|x|=1。则推导出cos(角度)=w×x,那么真实标签位置给角度+m则让角度变大了,cos值变小。w×x变小,输出的预测为真实标签的概率变低。让模型更难训练,那么在一遍又一遍的模型读取图片提取特征的过程中,会让模型逐渐的将真实标签位置的w×x值变大==cos(角度+m)变大,那么角度就会变的更小。只有角度更小的时候,cos余弦相似度才会大,从而让模型认为这个类别是真实的类别。

所以arcloss主要加入了一个m,增大角度,让模型更难训练,让模型把角度变的更小,从而让w的值调整的更加让类间距增大。

简而言之:加入m的值,让真实类和其他类相似度更高,让模型更难训练。迫使模型为了让真实和其他类相似度更低,而让w权重的值更合理。

三、对比分析

四、如何选择

  • 如果任务需要提升类内特征的聚合性(如样本分布紧密性),优先考虑 Center Loss
  • 如果任务需要增强类间特征的判别能力(如人脸识别),优先选择 ArcFace Loss
  • 可以同时使用两者,将特征聚合和判别性结合,提高模型的鲁棒性。

五、推荐学习资源

  1. ArcFace: Additive Angular Margin Loss for Deep Face Recognition (论文)
  2. Center Loss: A Discriminative Feature Learning Approach for Deep Face Recognition (论文)
  3. PyTorch 官方文档

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/951832.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Maui】动态菜单实现(绑定数据视图)

前言 .NET 多平台应用 UI (.NET MAUI) 是一个跨平台框架,用于使用 C# 和 XAML 创建本机移动和桌面应用。 使用 .NET MAUI,可从单个共享代码库开发可在 Android、iOS、macOS 和 Windows 上运行的应用。 .NET MAUI 是一款开放源代码应用,是 X…

【json】

JSON JSON是一种轻量级的,按照指定的格式去组织和封装数据的数据交互格式。 本质上是一个带有特定格式的字符串(py打印json时认定为str类型) 在各个编程语言中流通的数据格式,负责不同编程语言中的数据传递和交互,类似于计算机普通话 python与json关系及相互转换…

51单片机——中断(重点)

学习51单片机的重点及难点主要有中断、定时器、串口等内容,这部分内容一定要认真掌握,这部分没有学好就不能说学会了51单片机 1、中断系统 1.1 概念 中断是为使单片机具有对外部或内部随机发生的事件实时处理而设置的,中断功能的存在&#…

易支付二次元网站源码及部署教程

易支付二次元网站源码及部署教程 引言 在当今数字化时代,二次元文化逐渐成为年轻人生活中不可或缺的一部分。为了满足这一庞大用户群体的需求,搭建一个二次元主题网站显得尤为重要。本文将为您详细介绍易支付二次元网站源码的特点及其部署教程&#xf…

开源生成式物理引擎Genesis,可模拟世界万物

这是生成大模型时代 —— 它们能生成文本、图像、音频、视频、3D 对象…… 而如果将所有这些组合到一起,我们可能会得到一个世界! 现在,不管是 LeCun 正在探索的世界模型,还是李飞飞想要攻克的空间智能,又或是其他研究…

【fly-iot飞凡物联】(19):开源飞凡物联项目重启,使用go重写后端代码,感兴趣的小伙伴可以一起参加,使用apache协议开源,招募感兴趣的小伙伴!!

目录 前言fly-iot飞凡物联,感兴趣的小伙伴可以一起参加,使用apache协议开源使用go重写后端代码 前言 fly-iot飞凡物联专栏: https://blog.csdn.net/freewebsys/category_12219758.html fly-iot飞凡物联,感兴趣的小伙伴可以一起参…

用于与多个数据库聊天的智能 SQL 代理问答和 RAG 系统(3) —— 基于 LangChain 框架的文档检索与问答功能以及RAG Tool的使用

介绍基于 LangChain 框架的文档检索与问答功能,目标是通过查询存储的向量数据库(VectorDB),为用户的问题检索相关内容,并生成自然语言的答案。以下是代码逻辑的详细解析: 代码结构与功能 初始化环境与加载…

消息中间件类型介绍

消息中间件是一种在分布式系统中用于实现消息传递的软件架构模式。它能够在不同的系统或应用之间异步地传输数据,实现系统的解耦、提高系统的可扩展性和可靠性。以下是几种常见的消息中间件类型及其介绍: 1.RabbitMQ 特点: • 基于AMQP&#…

uniapp使用scss mixin抽离css常用的公共样式

1、编写通用scss样式文件 // 通用 Flex Mixin mixin flex($direction: row, $justify: flex-start, $align: stretch, $wrap: nowrap) {display: flex;flex-direction: $direction;justify-content: $justify;align-items: $align;flex-wrap: $wrap; }// 水平居中 mixin flex-…

Matlab Steger算法提取条纹中心线(亚像素位置)

文章目录 一、简介二、实现代码三、实现效果参考文献一、简介 Steger 算法是一种常用的图像边缘检测算法,可以用于提取图像中的中心线或边缘信息。它的理论假设是:条纹的亮度是按照高斯分布呈现的,即中心亮两侧渐暗。 其计算过程如下所述: 1、首先,我们需要计算每个点Hess…

PySide6 Qt for Python Qt Quick参考网址

Qt QML BOOK: 《Qt for Python》 -Building an Application https://www.qt.io/product/qt6/qml-book/ch19-python-build-app#signals-and-slots Qt for Python:与C版本的差异即BUG处理(常见的DLL文件确实的问题等) Qt for Pyt…

【大数据】Apache Superset:可视化开源架构

Apache Superset是什么 Apache Superset 是一个开源的现代化数据可视化和数据探索平台,主要用于帮助用户以交互式的方式分析和展示数据。有不少丰富的可视化组件,可以将数据从多种数据源(如 SQL 数据库、数据仓库、NoSQL 数据库等&#xff0…

ELK实战(最详细)

一、什么是ELK ELK是三个产品的简称:ElasticSearch(简称ES) 、Logstash 、Kibana 。其中: ElasticSearch:是一个开源分布式搜索引擎Logstash :是一个数据收集引擎,支持日志搜集、分析、过滤,支持大量数据…

汽车物资拍卖系统架构与功能分析

2015工作至今,10年资深全栈工程师,CTO,擅长带团队、攻克各种技术难题、研发各类软件产品,我的代码态度:代码虐我千百遍,我待代码如初恋,我的工作态度:极致,责任&#xff…

利用 Python 爬虫从义乌购根据关键词获取商品列表

在当今数字化商业时代,数据是企业获取竞争优势的关键。对于从事国际贸易的商家而言,能够及时、准确地获取商品信息至关重要。义乌购作为知名的国际贸易批发平台,汇集了海量的商品资源。通过 Python 爬虫技术,我们可以高效地从义乌…

HDFS编程 - 使用HDFS Java API进行文件操作

文章目录 前言一、创建hdfs-demo项目1. 在idea上创建maven项目2. 导入hadoop相关依赖 二、常用 HDFS Java API1. 简介2. 获取文件系统实例3. 创建目录4. 创建文件4.1 创建文件并写入数据4.2 创建新空白文件 5. 查看文件内容6. 查看目录下的文件或目录信息6.1 查看指定目录下的文…

直流无刷电机控制(FOC):电流模式

目录 概述 1 系统框架结构 1.1 硬件模块介绍 1.2 硬件实物图 1.3 引脚接口定义 2 代码实现 2.1 软件架构 2.2 电流检测函数 3 电流环功能实现 3.1 代码实现 3.2 测试代码实现 4 测试 概述 本文主要介绍基于DengFOC的库函数,实现直流无刷电机控制&#x…

51单片机——串口通信(重点)

1、通信 通信的方式可以分为多种,按照数据传送方式可分为串行通信和并行通信; 按照通信的数据同步方式,可分为异步通信和同步通信; 按照数据的传输方向又可分为单工、半双工和全双工通信 1.1 通信速率 衡量通信性能的一个非常…

如何在 Linux、MacOS 以及 Windows 中打开控制面板

控制面板不仅仅是一系列图标和菜单的集合;它是通往优化个人计算体验的大门。通过它,用户可以轻松调整从外观到性能的各种参数,确保他们的电脑能够完美地适应自己的需求。无论是想要提升系统安全性、管理硬件设备,还是简单地改变桌…

浅谈弱电系统RVVP和RVSP电缆的区别(

1、RVVP 1.1RVVP电缆定义? RVVP电缆抗干扰软电缆、屏蔽电缆、信号电缆、控制电缆(名字很多),学名:铜芯-聚氯乙烯绝缘-屏蔽聚氯乙烯护套-软电缆。 1.2RVVP执行标准 主要执行标准为JB/T8734.5-2016,部…