超大规模分类(三):KNN softmax

传统的分类损失计算输入数据和每个类别中心的距离,来优化模型的训练。KNN softmax通过选择和输入数据最相关的top-K个类别,仅计算输入数据和top-K个类别中心的距离,以减小计算量。

![[Pasted image 20250109103750.png]]

KNN softmax首次诞生于达摩院机器智能技术实验室发表的SIGKDD 2020年《Large-Scale Training System for 100-Million Classification at Alibaba》

简单说下论文作者:

  • Pan Pan:潘攀,拍立淘创始人,著有《深度学习图像搜索与识别》
  • Liuyihan Song、Kang Zhao、Yiming Chen、Yingya Zhang均来自拍立淘团队
  • Yinghui Xu:徐盈辉,徐盈辉-复旦大学人工智能创新与产业(AI³)研究院 (fudan.edu.cn)
  • Rong Jin:金榕(阿里巴巴原副总裁、达摩副院长)_百度百科 (baidu.com)

问题建模

一个常见的图像分类任务整体流程如下:
![[Pasted image 20250109105251.png]]

输入图像 x i x_i xi送入Convolutional Feature Learning模块,提取图像表征 f x i ∈ R 1 × D f_{x_i}\in \mathbb{R}^{1\times D} fxiR1×D(其中 D D D表示维度),再通过Fully Connected Layer,将图像表征维度 f x i f_{x_i} fxi映射到类别数 C C C上,紧接着通过Softmax Function获取 [ 0 , 1 ] [0,1] [0,1]的概率值,计算分类损失。

我们来进行公式化定义,

(1)图像表征 f x i f_{x_i} fxi通过Fully Connected Layer将维度映射到类别数 C C C,可以建模成: f x i W ∈ R C f_{x_i}W \in \mathbb{R}^C fxiWRC,其中 W ∈ R D × C W \in \mathbb{R}^{D\times C} WRD×C。一般情况下,Fully Connected Layer会有偏置 b b b,将偏置 b b b设置为0。

(2)通过Softmax Function获取 [ 0 , 1 ] [0,1] [0,1]的概率值,得到 f x i W ∑ j e x p ( f x i W j ) \frac{f_{x_i}W}{\sum_j{exp(f_{x_i}W_j)}} jexp(fxiWj)fxiW,其中 W j ∈ R D × 1 W_j \in \mathbb{R}^{D\times 1} WjRD×1,表示第 j j j列数据,也指类别表征

(3)分类损失的定义为: L = − log ⁡ ( e x p ( f x i W y i ) ∑ j e x p ( f x i W j ) ) = − log ⁡ ( e x p ( ∥ f x i ∥ ⋅ ∥ W y i ∥ ⋅ c o s ( θ y j ) ) ∑ j e x p ( ∥ f x i ∥ ⋅ ∥ W j ∥ ⋅ c o s ( θ j ) ) ) (1) \begin{equation}\begin{aligned} L&=-\log\left(\frac{exp(f_{x_i}W_{y_i})}{\sum_j{exp(f_{x_i}W_j)}}\right)\\ &=-\log\left(\frac{exp(\|f_{x_i}\|\cdot \|W_{y_i}\|\cdot cos(\theta_{y_j}))}{\sum_j{exp(\|f_{x_i}\|\cdot\|W_j\|\cdot cos(\theta_{j}))}}\right)\\ \end{aligned} \end{equation}\tag{1} L=log(jexp(fxiWj)exp(fxiWyi))=log(jexp(fxiWjcos(θj))exp(fxiWyicos(θyj)))(1),其中 y i y_i yi指的是输入图像 x i x_i xi对应的类别下标,等式上下成立的原因是向量的内积公式 a ⋅ b = ∥ a ∥ ⋅ ∥ b ∥ ⋅ cos ⁡ θ \mathbf{a} \cdot \mathbf{b} = \|\mathbf{a}\| \cdot \|\mathbf{b}\| \cdot \cos\theta ab=abcosθ

(4)在常规实践中,图像表征 f x i f_{x_i} fxi和类别表征 W j W_j Wj一般都事先归一化好,仅需要考虑两个表征间的余弦距离。同时,需要乘上一个缩放因子,用于控制训练的激进程度,例如
L = − log ⁡ ( e x p ( α ⋅ c o s ( θ y j ) ) ∑ j e x p ( α ⋅ c o s ( θ j ) ) ) = − log ⁡ ( e x p ( α ⋅ f x i ∥ f x i ∥ ⋅ W y i ∥ W y i ∥ ) ∑ j e x p ( α ⋅ f x i ∥ f x i ∥ ⋅ W j ∥ W j ∥ ) = − log ⁡ ( e x p ( α ⋅ f x i n o r m ⋅ W i n o r m ) ∑ j e x p ( α ⋅ f x i n o r m ⋅ W j n o r m ) (2) \begin{equation}\begin{aligned} L&=-\log\left(\frac{exp(\alpha \cdot cos(\theta_{y_j}))}{\sum_j{exp(\alpha \cdot cos(\theta_{j}))}}\right)\\ &=-\log\left(\frac{exp(\alpha \cdot \frac{f_{x_i}}{\|f_{x_i}\|}\cdot \frac{W_{y_i}}{\|W_{y_i}\|})}{\sum_j{exp(\alpha \cdot \frac{f_{x_i}}{\|f_{x_i}\|}\cdot \frac{W_{j}}{\|W_{j}\|}}}\right)\\ &=-\log\left(\frac{exp(\alpha \cdot f_{x_i}^{norm} \cdot W_{_i}^{norm})}{\sum_j{exp(\alpha \cdot f_{x_i}^{norm}\cdot W_{j}^{norm}}}\right)\\ \end{aligned} \end{equation}\tag{2} L=log(jexp(αcos(θj))exp(αcos(θyj)))=log jexp(αfxifxiWjWjexp(αfxifxiWyiWyi) =log(jexp(αfxinormWjnormexp(αfxinormWinorm))(2)
,这个就是CLIP用的损失函数的形式了。

KNN softmax

全连接层的模型并行

如果特征维度是512维,分类1个亿的全连接层参数有 512 × 100000000 = 5.12 ∗ 1 0 10 512\times 100000000=5.12*10^{10} 512×100000000=5.121010。若参数存储形式为fp32,即1个参数需要4个字节,那么占用的显存为 5.12 × 1 0 10 ∗ 4 1024 × 1024 × 1024 = 191.1 G B \frac{5.12\times 10^{10}*4}{1024\times 1024\times 1024}=191.1GB 1024×1024×10245.12×10104=191.1GB

很显然,单块显卡装不下。于是,本文将全连接层参数均分到每一块显卡上。假设我们有256块V100显卡,每块显卡只需要装 191.1 G B 256 = 0.74 G B \frac{191.1 GB}{256}=0.74GB 256191.1GB=0.74GB,很显然,每块显卡的负担小得多了。

![[Pasted image 20250109210809.png]]
做法如上图所示,包括数据并行和模型并行。

  • 数据并行指的是Convolutional Feature Learning模块参数复制到每块GPU上,只有数据均分成 N N N份,送入不同GPU中。
  • 模型并行特指全连接层参数均分成 N N N份,存储到不同GPU中。
    具体流程如下:
    (1)数据均分成 N N N份,送到不同GPU中。
    (2)每块GPU上,通过Convolutional Feature Learning模块提取图像表征,再执行all-gather操作,将不同GPU的表征汇聚到每一块GPU上。(假设有3块GPU,每块GPU提取了 R 2 × 512 \mathbb{R}^{2\times 512} R2×512表征,执行all-gather操作后,将3块GPU的表征汇聚起来,分发到所有GPU上,每块GPU提取的表征变为 R 6 × 512 \mathbb{R}^{6\times 512} R6×512
    (3)第 i i i块GPU将图像表征送到第 i i i份全连接层参数上
    (4)执行分布式softmax计算,以及损失的计算
    (5)每块GPU参数反向传播,在反向传播至Convolutional Feature Learning模块前,汇聚梯度,再进一步向前传播。
    (6)参数更新时,第 i i i份全连接层参数仅通过第 i i i块GPU的梯度进行更新;Convolutional Feature Learning模块则通过全GPU的梯度进行更新。

尽管做了全连接层的模型并行,但是全连接层的计算量级实在太大,越80%的训练时间消耗在全连接层的操作上(全连接层前向传播,softmax前向传播,softmax反向传播,全连接层反向传播)

top-K类别选择

在公式(2)中,有 L = − log ⁡ ( e x p ( α ⋅ f x i n o r m ⋅ W y i n o r m ) ∑ j e x p ( α ⋅ f x i n o r m ⋅ W j n o r m ) L=-\log\left(\frac{exp(\alpha \cdot f_{x_i}^{norm} \cdot W_{y_i}^{norm})}{\sum_j{exp(\alpha \cdot f_{x_i}^{norm}\cdot W_{j}^{norm}}}\right) L=log(jexp(αfxinormWjnormexp(αfxinormWyinorm)),分类损失需要计算输入表征 f x i n o r m f_{x_i}^{norm} fxinorm和所有类别表征的余弦距离。由于类别数特别大,计算难度特别高,所以选择从中挑选 K K K个类别,进行分母的计算。

这是一个典型的检索场景,文中利用输入数据类别 y i y_i yi的类别表征 W y i W_{y_i} Wyi去检索所有类别中心表征,得到top-K个相似度最高的类别,用于分类损失的分母计算。

分布式KNN图构建

KNN图的建立可以理解为:给定query集合,以及doc集合,建立每个query到doc内最相近top-k个样本的关系。

在1亿类别分类场景,query和doc集合都等于1亿类别,建KNN图流程就特指:将每1个类别中心作为query,检索1亿个类别中心内,最相似的top-k个类别中心,构成 1 亿 × k 1亿\times k 1亿×k的相似度矩阵。

大规模检索场景常用的策略为ANN检索(Approximate Nearest Neighbor,近似最近邻检索)。但作者发现ANN对召回影响较大,导致损失偏差较大,效果不好,推荐采用暴力检索(brute-force)。

暴力检索不影响召回率,但很耗时,所以无法每个iteration更新一次,本文是每隔一个epoch更新一次KNN图。

因为模型并行,已经将全连接层均分到每块GPU上,建立KNN图是需要考虑该因素。传统的建图策略是:将所有GPU上的类别表征聚合到每块GPU上,得到完整的doc集合。计算每块GPU上的类别表征与完整doc集合的相似度矩阵,很显然,对显存消耗很高。

采用分布式建图,策略为:假设将GPU(id=0)作为query,计算KNN图,流程有:

  • 在GPU(id=0)上,计算query到GPU(id=0)上类别表征的top-k,结果传播到GPU(id=1)上
  • 在GPU(id=1)上,计算query到GPU(id=1)上类别表征的top-k,结果传播到GPU(id=2)上
  • 最后,将最终结果返回到GPU(id=0)上
    这样的处理方式对显存消耗非常小,并且GPU间的通信量也少。在这里插入图片描述

具体实现时,类别中心的存储由fp32改为fp16,并且采用TensorCore进行相似度计算加速(较原方法能加速3倍)。fp16的精度低于fp32,为平衡速度和效果,首先用fp16精度从全类别中心里搜top- k ′ k^{'} k,再利用fp32精度从top- k ′ k^{'} k中搜出top-k。

经过上述一通操作,1亿类别中心的KNN建图时间仅需0.75h。

采用和全连接层模型并行类似的策略,将KNN图按照query维度均分到每块GPU上,平均每块GPU仅需承担 372 G B / 256 = 1.45 G B 372GB/256=1.45GB 372GB/256=1.45GB,在可承受范围内。

效果比较

分别用了1百万类、1千万类、1亿类的数据进行训练,统计分类准确率和吞吐量,结果如下:
分类准确率:![[Pasted image 20250110095143.png]]

  • selective softmax:分母中通过Hashing Forest来选择k个类别,未采用KNN方式选择
  • MACH:一种加速策略,速度快,但效果不好
  • Full Softmax指的是分类损失中,分布用全类别表征计算得到
    吞吐量:![[Pasted image 20250110095201.png]]

,表明KNN Softmax能够有效提升吞吐量,类别越多,提升幅度越大。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/953631.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MySQL素材怎么导入Navicat???

不管用什么方法都要先关掉MySQL服务,并且提前备份数据! 1.有sql文件时候。 打开navicat,运行sql文件 然后点击后面三个点,选中要运行的sql文件,开始。 鼠标右键刷新一下,就能看到sql文件中的表了 2.没有s…

程序员独立开发竞品分析:确定网站使用什么建站系统

要确定一个网站使用的建站系统,可以通过以下几种方法尝试分析: 查看页面源代码: 打开网站,右键点击页面并选择“查看页面源代码”。在代码中查找一些常见的建站系统标志,例如: WordPress 的迹象&#xff1a…

Linux(Centos7)安装Mysql/Redis/MinIO

安装Mysql 安装Redis 搜索Redis最先版本所在的在线安装yum库 查看以上两个组件是否是开机自启 安装MinIO 开源的对象存储服务,存储非结构化数据,兼容亚马逊S3协议。 minio --help #查询命令帮助minio --server --help #查询--server帮助minio serve…

【DB-GPT】开启数据库交互新篇章的技术探索与实践

一、引言:AI原生数据应用开发的挑战与机遇 在数字化转型的浪潮中,企业对于智能化应用的需求日益增长。然而,传统的数据应用开发方式面临着诸多挑战,如技术栈复杂、开发周期长、成本高昂、难以维护等。这些问题限制了智能化应用的…

解决aerich init -t xx 报错ModuleNotFoundError: No module named ‘tomli_w‘

今天在学习fastapi的时候,发现一款数据库迁移工具,通过这个工具可以根据模型类来对数据库做出改变。 随跟着学: 在执行 aerich init -t settings.TORTOISE_ORM的时候, 彼其娘之。。 报了一些错误: Traceback (most recent ca…

.NET Core NPOI 导出图片到Excel指定单元格并自适应宽度

NPOI:支持xlsx,.xls,版本>2.5.3 XLS:HSSFWorkbook,主要前缀HSS, XLSX:XSSFWorkbook,主要前缀XSS,using NPOI.XSSF.UserModel; 1、导出Excel添加图片效果&#xff0…

浅谈云计算07 | 云安全机制

浅谈云计算安全机制:全方位守护云端世界 一、引言二、加密技术:数据的隐形护盾三、散列机制:数据完整性的忠诚卫士四、数字签名:数据来源与真伪的鉴定专家五、公钥基础设施(PKI):信任的基石六、…

Unity 2d描边基于SpriteRender,高性能的描边解决方案

目标 以Unity默认渲染管线为例,打造不需要图片内边距,描边平滑,高性能的描边解决方案 前言 在2d游戏中经常需要给2d对象添加描边,来突出强调2d对象 当你去网上查找2d描边shader,移植到项目里面,大概率会…

Uniapp仿ChatGPT Stream流式输出(非Websocket)

Uniapp仿ChatGPT Stream流式输出(非Websocket) 前言:流式输出可以使用websocket也可以使用stream来实现EventSource是 HTML5 中的一个接口,用于接收服务器发送的事件流(Server - Sent Events,SSE&#xff…

黑马linux入门笔记(01)初始Linux Linux基础命令 用户和权限 实用操作

B站 黑马程序员 的视频 BV1n84y1i7td 黑马程序员新版Linux零基础快速入门到精通,全涵盖linux系统知识、常用软件环境部署、Shell脚本、云平台实践、大数据集群项目实战等 增强自控力 冥想慢呼吸绿色锻炼充分休息减少决策次数优先做重要的事情(早晨)融入强自控群控…

当你不小心使用了MySQL的保留字作为字段名而导致你的SQL语法解析错误该怎么办!

问题举例: 你在尝试更新一个名为 desc 的字段时遇到了 SQL 语法错误。原因是 desc 是 MySQL 的保留字,通常用于表示 ORDER BY 子句中的降序(DESC),因此直接使用 desc 作为字段名会导致 SQL 解析错误。如下图&#xff…

excel设置好的可选择列数据后,如何快速输入到单元格中?

当设置好列的【数据】-【数据有效性】-【序列】后,在单元格中输入可选择数据的开头,就会提示出对应的可选择数据,然后,按一下键盘上的【↓】键,再按回车,即可快速输入到单元格中。

2025封禁指定国家ip-安装xtables-addons记录

如何安装和使用 安装lux仓库(该仓库包含xtables-addons所需的依赖环境) # wget http://repo.iotti.biz/CentOS/7/noarch/lux-release-7-1.noarch.rpm # rpm -ivh lux-release-7-1.noarch.rpm 安装xtables-addons。注意:必须先安装kmod-xtables-addons,再…

力扣 岛屿数量

从某个点找,不断找相邻位置。 题目 岛屿中被“0”隔开后 ,是每一小块状的“1”,本题在问有多少块。可以用dfs进行搜索,遍历每一个点,把每一个点的上下左右做搜索检测,当检测到就标记为“0”表示已访问过&a…

ssl与ipsec的区别

ssl与ipsec的区别: 1.概念不同: ssl是一种安全协议,可通过Internet安全地发送信息,而ipsec即Internet协议安全性,是为Internet协议提供安全性的一组协议。 2.所在层不同: ssl在传输和应用层中工作,而ipsec在Interne…

基于大语言模型的组合优化

摘要:组合优化(Combinatorial Optimization, CO)对于提高工程应用的效率和性能至关重要。随着问题规模的增大和依赖关系的复杂化,找到最优解变得极具挑战性。在处理现实世界的工程问题时,基于纯数学推理的算法存在局限…

批量识别图片型PDF指定区域内容识别保存表格+PDF批量改名:技术难题与项目实战总结

相关项目实战: 一、引言 在当今数字化办公环境中,批量处理PDF文件中的表格数据并进行改名是一项常见但具有挑战性的任务。无论是从大量的财务报销凭证、学术研究报告还是项目文档中提取表格信息,都可能遇到各种各样的技术难题。 二、批量提…

open3d+opencv实现矩形框裁剪点云操作(C++)

👑主页:吾名招财 👓简介:工科学硕,研究方向机器视觉,爱好较广泛… ​💫签名:面朝大海,春暖花开! open3dopencv实现矩形框裁剪点云操作(C&#xff…

LabVIEW光流跟踪算法

1. 光流跟踪算法的概述 光流(Optical Flow)是一种图像处理技术,用于估算图像中像素点的运动。通过比较连续帧图像,光流算法可以分析图像中的运动信息,广泛用于目标跟踪、运动检测和视频处理等场景。该示例使用了NI Vi…

django在线考试系统

Django在线考试系统是一种基于Django框架开发的在线考试平台,它提供了完整的在线考试解决方案。 一、系统概述 Django在线考试系统旨在为用户提供便捷、高效的在线考试环境,满足教育机构、企业、个人等不同场景下的考试需求。通过该系统,用…