前言
多实例学习是由监督学习演变而来的,我们都知道,监督学习在训练的时候是一个实例(或者说一个样本、一条训练数据)对应一个确定的标签。而多实例的特点就是,我们在训练的时候的输入是多个实例对应一个确定的标签,我们把这多个实例看做一个包,每个包有一个确定的标签,包是有标签的而包里面的实例是没有标签的,多实例的目的就是希望从这种粗粒度的带有标签的包中学习到细粒度的表示,进而进行一系列的任务,比如说分类等。
多个实例只对应一个标签的这种情况,我们也称为弱标签情况,弱标签的问题在医学成像中尤其明显(例如,计算病理学、乳房x光检查或CT肺筛查),其中图像通常由单一标签(良性/恶性)或感兴趣区域(ROI)大致给出。
举个栗子:
- 把一张X检查图像分成多个部分,从这多个部分中我们判断是否包含病灶(往往只有一小部分包含病灶),进一步判断患者是否患病,这也可以看做一个多实例,每个部分是一个实例,整张图片包含多个实例可以看做一个包
- CT序列很明显是弱标签的,每个患者有一系列CT切片,从这些切片中判断患者是否患病,那这一套CT序列就是一个包,每张CT切片就是一个实例
所以,多实例的一个挑战是发现关键实例,即触发包标签的实例。
Attention-based Deep Multiple Instance Learning
要解决的问题
在这篇文章里作者主要以图片的二分类为例,每个包对应一个标签label y={0,1},包里包含多个多个实例(即有多张图片),并且实例的数量是可变的(每个包里面的实例数是不一样的),实例与实例之间具有排列不变性,即没有顺序也没有依赖性。
eg:下述是论文所有数据集中一个包的里面的图片实例,这个包的标签是9,我们通过注意力进行可视化之后,发现数字为9的图片占比更大
之前的方法:
(1)实例级别的方法:通过实例级别的分类器,返回每个实例的分数,然后通过MIL pooling来聚合单个分数得到最好的那个实例结果
缺点:由于单个标签是未知的,因此实例级分类器可能被训练不足,这可能会给最终的预测带来额外的错误
优点:实例级方法提供了一个分数,可用于查找关键实例,即触发包标签的实例,可解释性较好
(2)嵌入级别的方法:将实例映射到一个低维的嵌入,再通过MIL池获得与包中的实例数无关的包表示。最后把这个包表示放到分类器中得到进一步结果
缺点:可解释性不太好
优点:嵌入级方法确定了袋子的联合表示,因此它不会给袋级分类器引入额外的偏差。
论文方法
①一个由神经网络参数化的MIL模型来获取输入实例的低维嵌入
②通过注意力机制来对实例进行加权平均,代替之前的平均实例池化和最大实例池化。作者认为,基于注意力的MIL pooling是一个灵活和自适应的MIL池,可以通过适应任务和数据来获得更好的结果。理想情况下,这种MIL池也应该是可解释的。
源码解析
class Attention(nn.Module):
def __init__(self):
super(Attention, self).__init__()
self.M = 500
self.L = 128
self.ATTENTION_BRANCHES = 1
# 二维特征提取
self.feature_extractor_part1 = nn.Sequential(
nn.Conv2d(1, 20, kernel_size=5),
nn.ReLU(),
nn.MaxPool2d(2, stride=2),
nn.Conv2d(20, 50, kernel_size=5),
nn.ReLU(),
nn.MaxPool2d(2, stride=2)
)
# 对之前提取的特征进行降维和非线性变换
self.feature_extractor_part2 = nn.Sequential(
nn.Linear(50 * 4 * 4, self.M),
nn.ReLU(),
)
# 通过注意力机制来计算每个实例的得分
self.attention = nn.Sequential(
nn.Linear(self.M, self.L), # matrix V
nn.Tanh(),
nn.Linear(self.L, self.ATTENTION_BRANCHES) # matrix w (or vector w if self.ATTENTION_BRANCHES==1)
)
# 分类器进行分类
self.classifier = nn.Sequential(
nn.Linear(self.M*self.ATTENTION_BRANCHES, 1),
nn.Sigmoid()
)
def forward(self, x):
"""
x:输入的包,维度为[batch, MLI_nums, channel, height, width],第一个是batch,第二个是包里的实例数,这里是变化的,第三个是图片通道数,最后两个维度是高和宽
"""
x = x.squeeze(0) # 第一个维度压缩,[MLI_nums, channel, height, width]
H = self.feature_extractor_part1(x) # 提取特征之后的维度为[MLI_nums, 50, 4, 4]
H = H.view(-1, 50 * 4 * 4) # 改变形状[MLI_nums,800],适应下一次输入
H = self.feature_extractor_part2(H) # 这里对特征进行了降维[MLI_nums,500]
# 使用注意力机制自适应的计算每个实例的分数
A = self.attention(H) # [MLI_nums,1]
A = torch.transpose(A, 1, 0) # 转换维度,[1,,MLI_nums]
A = F.softmax(A, dim=1) # 对实例分数进行归一化
Z = torch.mm(A, H) # 将归一化后的实例分数与提取的实例特征进行矩阵乘法运算,最后得到一个最终的包表示[1,500]
Y_prob = self.classifier(Z) # 分类器进行分类
Y_hat = torch.ge(Y_prob, 0.5).float()
return Y_prob, Y_hat, A
论文源码:https://github.com/AMLab-Amsterdam/AttentionDeepMIL