常见的多标签分类方法是同时生成多个标签的logits,然后接一个sigmoid激活函数做二分类。该方法简单直接,但忽略了标签之间的相关性。虽然业界针对该问题提出了很多解决思路,但大多是任务特定,通用性不强,也不够优雅。
Transformer decoder倒是可以序列输出多个标签,但却加入了位置偏差。而标签之间是没有位置关系的,谁先谁后无所谓,只要输出全就行。这样也导致数据集不好构造。
C-Tran
General Multi-label Image Classification with Transformers 这篇论文提供了新思路,类似BERT的MLM预训练任务:通过在输入端对多个标签做随机mask,然后预测被mask的标签,从而强制模型去学习标签之间的依赖关系:
模型细节:
- Label Embeddings: 可学习的参数矩阵,由模型隐式学习到标签的语义信息和标签间依赖。有点像DETR的query。
- State Embeddings: 控制标签的mask比例,这样就跟标签学习实现了解耦,也方便在推理阶段注入全比例mask
实验结果
不说了,全是sota:
- 旷视用gcn来建模多标签方法(被C-Tran超越了,建模思路可以学习):Multi-Label Image Recognition with Graph Convolutional Networks