基本信息
原文链接:[2205.09671] A graph-transformer for whole slide image classification (arxiv.org)
源码:https://github.com/vkola-lab/tmi2022
提出了一种融合了基于图的WSI表示和用于处理病理图像的视觉转换器,称为GTP,以预测疾病等级。其中使用对比学习框架来生成特征提取器,计算单个WSI补丁的特征向量,表示图的节点,然后构建GTP框架。
正常与LUAD与LSCC:平均准确率= 91.2±2.5%
外部测试数据(TCGA)的平均准确度= 82.3±1.0%。
基本方法
1. Graph-Transformer
(1) 总体
点集G=(V,E),V为表示图像patch的节点集,E为V中表示两个图像patch是否相邻的节点间的边集。
我们将G的邻接矩阵记为A = [Aij],其中如果存在一条边(vi, v j)∈E,则Aij = 1,否则Aij = 0。一个图像patch必须与其他patch相连,并且相邻的patch最多可以被8个相邻的patch包围,因此A的每一行或每一列的和至少为1,最多为8。
一个图可以关联一个节点特征矩阵F,,其中每一行包含为一个图像patch计算的d维特征向量,即节点。
- 将每张整张幻灯片图像(WSI)划分为小块。
- 去除主要包含背景的patch,通过基于对比学习的patch embedding模块将剩余的patch嵌入到feature vector中。
- 使用特征向量构建图形,将每个选择的patch表示为一个节点,并在整个WSI上使用具有8节点邻接矩阵的节点构建图。使用Transformer将图形作为输入并预测wsi级别的类标签。
(2)patch的特征提取
使用对比学习来训练卷积神经网络(CNN),该网络通过潜在空间中的对比损失最大化同一图像patch的两个不同增强视图之间的一致性来产生嵌入表示。训练开始于将训练集中的wsi平铺成小块,并随机抽取一小批K个小块。对每个patch (p)应用两个不同的数据增强操作,得到两个增强的patch (pi和pj)。
数据增强的方法有:随机颜色失真、随机高斯模糊和随机裁剪
同一patch的两个增广patch对记为正对。对于一个小批量的K个patch,总共有2K个增强patch。给定一个正对,其他2K−1个增广patch被认为是负样本。随后,使用CNN从每个增广patch (pi, pj)中提取有代表性的嵌入向量(fi, fj)。然后通过projector将嵌入向量映射到潜在空间(zi, zj),在潜在空间(zi, zj)中应用对比学习损失。正对增广patch (i, j)的对比学习损失函数定义为:
a.图卷积层
为A的对称归一化邻接矩阵,M为图卷积层数,,是对角矩阵,且。Hm为第m层GC的输入,H1用节点特征矩阵f初始化。为图卷积层的可学习滤波器矩阵。其中Cm为输入维数,Cm+1为输出维数。
b.Transformer层
以邻接矩阵的形式对图卷积层使用位置嵌入。WSI图的邻接矩阵反映了节点之间的空间信息和连通性,在进行图卷积时保留了这些信息。
c.池化层
在图卷积层和变压器层之间添加了最小池化层,将输入数量从数千个节点减少到数百个节点。min-cut池化在保留相邻节点的局部信息的同时,减少了节点数量。
min-cut池化背后的思想是将min-cut问题通过具有自定义损失函数的池化层来实现。
通过最小化自定义损失,池化层学习在任何给定的图上找到min-cut聚类,并聚类以减小图的大小。
(3)构建图
GTP使用特征节点矩阵F和相邻矩阵A构造一个图来表示每个WSI。
特征节点矩阵F:通过对比学习训练的Resnet得到的d维嵌入向量fi,得到节点特征矩阵F = [f1;f2,……;fN], ,N为来自一个WSI的patch数。注意,N是可变的,因为不同的wsi包含不同数量的补丁。
相邻矩阵A:根据WSI上对应patch的空间位置在F中定义一对节点之间的边。如果补丁i是补丁j在WSI上的邻居,则GTP在节点i和节点j之间创建一条边,并设置Aij = 1, Aji = 1,否则Aij = 0, Aji = 0。
实验
1.实验设置
在20倍放大率下,对每个WSI进行裁剪,形成一个由512 × 512个无重叠斑块组成的袋,丢弃非组织面积> 50%的背景斑块。使用Resnet18作为CNN主干进行特征提取。
采用Adam优化器,初始学习率为0.0001,学习率调度采用余弦退火方案,小批量大小为512。我们使用一个图卷积层,设置Transformer层配置为L = 3, MLP大小= 128,D = 64, k = 8 (Eq.4, Eq.3)。GTP模型在150次迭代中以8个示例为批次进行训练。初始学习率设置为10−3,在第30步和第100步分别衰减为10−4和10−5。