原论文:https://arxiv.org/abs/2110.09103
源码:https://github.com/unilight/LDNet
直接步入正题~~~
一、ESA_blcok模块
1、PPM模块
class PPM(nn.Module):
    """Pyramid Pooling Module.

    Pools the input feature map to several fixed output resolutions and
    concatenates the flattened results along the token axis. For
    pooling_sizes (1, 3, 5) that yields 1*1 + 3*3 + 5*5 = 35 tokens, so
    a (b, c, h, w) input becomes (b, c, 35). Channels are preserved.
    """

    def __init__(self, pooling_sizes=(1, 3, 5)):
        super().__init__()
        self.layer = nn.ModuleList(
            [nn.AdaptiveAvgPool2d(output_size=(size, size)) for size in pooling_sizes]
        )

    def forward(self, feat):
        batch, channels = feat.shape[:2]
        # Each pooled map is flattened to (b, c, size*size).
        pooled = [pool(feat).view(batch, channels, -1) for pool in self.layer]
        # Concatenate token axes: (b, c, sum(size*size)).
        return torch.cat(pooled, dim=-1)
2、ESA_layer模块
class ESA_layer(nn.Module):
    """Efficient self-attention layer.

    Queries come from every spatial position; keys/values are first
    pyramid-pooled (PPM) down to 35 tokens, so the attention matrix is
    (h*w, 35) instead of (h*w, h*w).

    Input:  x of shape (b, dim, h, w)
    Output: tokens of shape (b, h*w, dim)
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        # Skip the output projection only when one head already matches dim.
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, kernel_size=1, stride=1, padding=0, bias=False)
        self.ppm = PPM(pooling_sizes=(1, 3, 5))
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # x: (b, c, h, w)
        b, c, h, w = x.shape
        # Split the 1x1-conv projection into q, k, v along channels.
        q, k, v = self.to_qkv(x).chunk(3, dim=1)  # each (b, inner_dim, h, w)
        q = rearrange(q, 'b (head d) h w -> b head (h w) d', head=self.heads)
        # Pool keys/values spatially: (b, inner_dim, 35).
        k, v = self.ppm(k), self.ppm(v)
        k = rearrange(k, 'b (head d) n -> b head n d', head=self.heads)  # (b, head, 35, d)
        v = rearrange(v, 'b (head d) n -> b head n d', head=self.heads)  # (b, head, 35, d)
        # BUGFIX: the original reassigned `b` (the batch size) to the
        # logits tensor here; use distinct names to avoid shadowing.
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # (b, head, h*w, 35)
        attn = self.attend(dots)
        out = torch.matmul(attn, v)  # (b, head, h*w, d)
        out = rearrange(out, 'b head n d -> b n (head d)')  # (b, h*w, inner_dim)
        return self.to_out(out)  # (b, h*w, dim)
3、ESA_blcok模块
class PreNorm(nn.Module):
    """Pre-normalization wrapper: LayerNorm over the last dim, then fn."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn  # the wrapped module, e.g. FeedForward

    def forward(self, x, **kwargs):
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)
class FeedForward(nn.Module):
    """Two-layer MLP: dim -> hidden_dim (GELU, dropout) -> dim (dropout).

    The layer ordering inside ``self.net`` matches the original so that
    checkpoint state_dict keys are unchanged.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Shape is preserved: (..., dim) -> (..., dim).
        return self.net(x)
class ESA_blcok(nn.Module):
    """Efficient self-attention block.

    ESA layer followed by a pre-norm FeedForward, each with a residual
    connection. Input and output are both (b, dim, h, w).
    """

    def __init__(self, dim, heads=8, dim_head=64, mlp_dim=512, dropout=0.):
        super().__init__()
        self.ESAlayer = ESA_layer(dim, heads=heads, dim_head=dim_head, dropout=dropout)
        self.ff = PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))

    def forward(self, x):
        _, _, h, w = x.shape
        # Flatten the spatial grid into tokens: (b, h*w, dim).
        tokens = rearrange(x, 'b c h w -> b (h w) c')
        tokens = self.ESAlayer(x) + tokens   # attention + residual
        tokens = self.ff(tokens) + tokens    # feed-forward + residual
        # Restore the spatial layout: (b, dim, h, w).
        return rearrange(tokens, 'b (h w) c -> b c h w', h=h)
二、LCA_blcok模块
1、MaskAveragePooling模块
def MaskAveragePooling(x, mask):
    """Masked global average pooling.

    Computes, per channel, sum(x * sigmoid(mask)) / (sum(sigmoid(mask)) + eps),
    i.e. a spatial mean of `x` weighted by the soft mask.

    Args:
        x:    feature map, shape (b, c, h, w).
        mask: mask logits broadcastable against x, e.g. (b, 1, h, w);
              sigmoid is applied inside this function.

    Returns:
        Tensor of shape (b, c, 1).
    """
    mask = torch.sigmoid(mask)
    b, c, h, w = x.shape  # h, w reused below (original recomputed them redundantly)
    eps = 0.0005
    x_mask = x * mask
    # Soft area of the mask; eps guards against an all-zero mask.
    area = F.avg_pool2d(mask, (h, w)) * h * w + eps  # (b, 1, 1, 1)
    # avg * h * w recovers the masked sum; divide by area for the weighted mean.
    x_feat = F.avg_pool2d(x_mask, (h, w)) * h * w / area  # (b, c, 1, 1)
    return x_feat.view(b, c, -1)  # (b, c, 1)
2、LCA_layer模块
class LCA_layer(nn.Module):
    """Lesion-aware cross-attention layer.

    Queries come from every spatial position of `x`; keys and values are
    collapsed by MaskAveragePooling into a single mask-weighted token, so
    every position attends to one lesion-aware summary vector.

    Input:  x (b, dim, h, w), mask logits (broadcastable to x)
    Output: tokens of shape (b, h*w, dim)
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        # Skip the output projection only when one head already matches dim.
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, kernel_size=1, stride=1, padding=0, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x, mask):
        q, k, v = self.to_qkv(x).chunk(3, dim=1)  # each (b, inner_dim, h, w)
        q = rearrange(q, 'b (head d) h w -> b head (h w) d', head=self.heads)
        # Collapse k/v to one mask-weighted token each: (b, inner_dim, 1).
        k = rearrange(MaskAveragePooling(k, mask), 'b (head d) n -> b head n d', head=self.heads)
        v = rearrange(MaskAveragePooling(v, mask), 'b (head d) n -> b head n d', head=self.heads)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # (b, head, h*w, 1)
        attn = self.attend(dots)   # softmax over a single key -> weights of 1
        out = torch.matmul(attn, v)                         # (b, head, h*w, d)
        out = rearrange(out, 'b head n d -> b n (head d)')  # (b, h*w, inner_dim)
        return self.to_out(out)  # (b, h*w, dim)
3、LCA_blcok模块
class LCA_blcok(nn.Module):
    """Lesion-aware cross-attention block.

    LCA layer followed by a pre-norm FeedForward, each with a residual
    connection. Input and output are both (b, dim, h, w).
    """

    def __init__(self, dim, heads=8, dim_head=64, mlp_dim=512, dropout=0.):
        super().__init__()
        self.LCAlayer = LCA_layer(dim, heads=heads, dim_head=dim_head, dropout=dropout)
        self.ff = PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))

    def forward(self, x, mask):
        _, _, h, w = x.shape
        # Flatten the spatial grid into tokens: (b, h*w, dim).
        tokens = rearrange(x, 'b c h w -> b (h w) c')
        tokens = self.LCAlayer(x, mask) + tokens   # cross-attention + residual
        tokens = self.ff(tokens) + tokens          # feed-forward + residual
        # Restore the spatial layout: (b, dim, h, w).
        return rearrange(tokens, 'b (h w) c -> b c h w', h=h)
三、HeadUpdator模块
class DecoderBlock(nn.Module):
    """Decoder stage: two ConvBlocks through a 4x channel bottleneck,
    then 2x bilinear upsampling."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(DecoderBlock, self).__init__()
        bottleneck = in_channels // 4  # 4x channel reduction between convs
        self.conv1 = ConvBlock(in_channels, bottleneck, kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv2 = ConvBlock(bottleneck, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x):
        return self.upsample(self.conv2(self.conv1(x)))
class HeadUpdator(nn.Module):
    """Refines per-class dynamic conv kernels ("heads") from decoder features.

    A prediction-weighted feature is assembled per class and fused with
    the current head through gated linear transforms.

    Shapes (B=batch, N=num classes, C=channels, K=conv kernel size):
        feat: (B, C, H, W)
        head: (B, N, C, K, K)
        pred: (B, N, H/2, W/2) -- upsampled 2x here to match feat
    Returns the updated head of shape (B, N, C, K, K).
    """

    def __init__(self, in_channels=64, feat_channels=64, out_channels=None, conv_kernel_size=1):
        super(HeadUpdator, self).__init__()
        self.conv_kernel_size = conv_kernel_size
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.out_channels = out_channels if out_channels else in_channels
        # With the defaults, in == feat == out == 64.
        self.num_in = self.feat_channels
        self.num_out = self.feat_channels
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        # BUGFIX: the original passed a stray positional `1` to several
        # nn.Linear calls; that slot is `bias` and 1 is truthy, so behavior
        # is identical with the argument removed -- but the intent is clearer.
        self.pred_transform_layer = nn.Linear(self.in_channels, self.num_in + self.num_out)
        self.head_transform_layer = nn.Linear(self.in_channels, self.num_in + self.num_out)
        self.pred_gate = nn.Linear(self.num_in, self.feat_channels)
        self.head_gate = nn.Linear(self.num_in, self.feat_channels)
        self.pred_norm_in = nn.LayerNorm(self.feat_channels)
        self.head_norm_in = nn.LayerNorm(self.feat_channels)
        self.pred_norm_out = nn.LayerNorm(self.feat_channels)
        self.head_norm_out = nn.LayerNorm(self.feat_channels)
        self.fc_layer = nn.Linear(self.feat_channels, self.out_channels)
        self.fc_norm = nn.LayerNorm(self.feat_channels)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, feat, head, pred):
        """feat: (B, C, H, W); head: (B, N, C, K, K); pred: (B, N, H/2, W/2)."""
        bs, num_classes = head.shape[:2]
        # Bring the prediction up to feat's resolution and squash to [0, 1].
        pred = self.upsample(pred)  # (B, N, H, W)
        pred = torch.sigmoid(pred)
        # Head feature assembly: prediction-weighted sum over space -> (B, N, C).
        assemble_feat = torch.einsum('bnhw,bchw->bnc', pred, feat)
        # (B, N, C, K, K) -> (B, N, C, K*K) -> (B, N, K*K, C)
        head = head.reshape(bs, num_classes, self.in_channels, -1).permute(0, 1, 3, 2)
        # Update head:
        #   assemble_feat, head -> linear transform -> pred_feat, head_feat
        #   each split into an _in part (builds the gate) and an _out part
        #   update_head = pred_gate * pred_feat_out + head_gate * head_feat_out
        # (B, N, C) -> (B*N, C)
        assemble_feat = assemble_feat.reshape(-1, self.in_channels)
        bs_num = assemble_feat.size(0)  # B*N
        # (B*N, C) -> (B*N, in+out)
        pred_feat = self.pred_transform_layer(assemble_feat)
        # First num_in columns / last num_out columns.
        pred_feat_in = pred_feat[:, :self.num_in].view(-1, self.feat_channels)
        pred_feat_out = pred_feat[:, -self.num_out:].view(-1, self.feat_channels)
        # (B, N, K*K, C) -> (B*N, K*K, C) -> (B*N, K*K, in+out)
        head_feat = self.head_transform_layer(
            head.reshape(bs_num, -1, self.in_channels))
        head_feat_in = head_feat[..., :self.num_in]     # (B*N, K*K, in)
        head_feat_out = head_feat[..., -self.num_out:]  # (B*N, K*K, out)
        # (B*N, K*K, in) * (B*N, 1, in) -> (B*N, K*K, in)
        gate_feat = head_feat_in * pred_feat_in.unsqueeze(-2)
        # Sigmoid gates deciding how much of each source to keep.
        head_gate = self.head_norm_in(self.head_gate(gate_feat))
        pred_gate = self.pred_norm_in(self.pred_gate(gate_feat))
        head_gate = torch.sigmoid(head_gate)
        pred_gate = torch.sigmoid(pred_gate)
        head_feat_out = self.head_norm_out(head_feat_out)  # (B*N, K*K, out)
        pred_feat_out = self.pred_norm_out(pred_feat_out)  # (B*N, out)
        # Gated blend of the two sources: (B*N, K*K, feat).
        update_head = pred_gate * pred_feat_out.unsqueeze(-2) + head_gate * head_feat_out
        update_head = self.fc_layer(update_head)
        update_head = self.fc_norm(update_head)
        update_head = self.activation(update_head)
        # (B*N, K*K, C) -> (B, N, K*K, C) -> (B, N, C, K, K)
        update_head = update_head.reshape(bs, num_classes, -1, self.feat_channels)
        update_head = update_head.permute(0, 1, 3, 2).reshape(
            bs, num_classes, self.feat_channels, self.conv_kernel_size, self.conv_kernel_size)
        return update_head
四、整体网络结构
class LDNet(nn.Module):
    """LDNet overall network.

    Res2Net-50 encoder, cascaded decoders with ESA (self-attention on
    encoder features) and LCA (cross-attention guided by the previous
    prediction) blocks. A global context vector generates per-class
    dynamic conv kernels ("heads") that each decoder stage refines via
    HeadUpdator and re-applies, yielding one prediction per stage.

    NOTE(review): the LCA path feeds the previous prediction as a mask
    into MaskAveragePooling, where it is broadcast-multiplied with the
    feature map — this appears to require a single-channel mask, so
    num_classes other than 1 likely errors out (as the author's tip
    below the code states). TODO confirm.
    """

    def __init__(self, num_classes=3, unified_channels=64, conv_kernel_size=1):
        super(LDNet, self).__init__()
        self.num_classes = num_classes
        self.conv_kernel_size = conv_kernel_size
        self.unified_channels = unified_channels
        res2net = res2net50_v1b_26w_4s(pretrained=True)
        # Encoder: Res2Net-50 stem plus its four residual stages.
        self.encoder1_conv = res2net.conv1
        self.encoder1_bn = res2net.bn1
        self.encoder1_relu = res2net.relu
        self.maxpool = res2net.maxpool
        self.encoder2 = res2net.layer1
        self.encoder3 = res2net.layer2
        self.encoder4 = res2net.layer3
        self.encoder5 = res2net.layer4
        # 1x1 convs reducing encoder channel counts before fusion.
        self.reduce2 = nn.Conv2d(256, 64, 1)
        self.reduce3 = nn.Conv2d(512, 128, 1)
        self.reduce4 = nn.Conv2d(1024, 256, 1)
        self.reduce5 = nn.Conv2d(2048, 512, 1)
        # Decoder stages; in_channels = LCA output + ESA (skip) output.
        self.decoder5 = DecoderBlock(in_channels=512, out_channels=512)
        self.decoder4 = DecoderBlock(in_channels=512+256, out_channels=256)
        self.decoder3 = DecoderBlock(in_channels=256+128, out_channels=128)
        self.decoder2 = DecoderBlock(in_channels=128+64, out_channels=64)
        self.decoder1 = DecoderBlock(in_channels=64+64, out_channels=64)
        self.gobal_average_pool = nn.Sequential(
            # GroupNorm keeps the input shape; it only normalizes per group.
            nn.GroupNorm(16, 512),  # 512 channels split into 16 groups
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),  # adaptive average pooling to 1x1
        )
        # self.gobal_average_pool = nn.AdaptiveAvgPool2d(1)
        # Maps the 512-d global context to N*C*K*K dynamic-kernel weights.
        self.generate_head = nn.Linear(512, self.num_classes*self.unified_channels*self.conv_kernel_size*self.conv_kernel_size)
        # self.pred_head = nn.Conv2d(64, self.num_classes, self.conv_kernel_size)
        # One HeadUpdator per decoder stage.
        self.headUpdators = nn.ModuleList([HeadUpdator(), HeadUpdator(), HeadUpdator(), HeadUpdator()])
        # Unify each stage's features to 64 channels (so one head fits all).
        self.unify1 = nn.Conv2d(64, 64, 1)
        self.unify2 = nn.Conv2d(64, 64, 1)
        self.unify3 = nn.Conv2d(128, 64, 1)
        self.unify4 = nn.Conv2d(256, 64, 1)
        self.unify5 = nn.Conv2d(512, 64, 1)
        # Efficient self-attention blocks (applied to encoder skips).
        self.esa1 = ESA_blcok(dim=64)
        self.esa2 = ESA_blcok(dim=64)
        self.esa3 = ESA_blcok(dim=128)
        self.esa4 = ESA_blcok(dim=256)
        # self.esa5 = ESA_blcok(dim=512)
        # Lesion-aware cross-attention blocks (applied to decoder outputs).
        self.lca1 = LCA_blcok(dim=64)
        self.lca2 = LCA_blcok(dim=128)
        self.lca3 = LCA_blcok(dim=256)
        self.lca4 = LCA_blcok(dim=512)
        # Stage order: deepest (4) first, shallowest (1) last.
        self.decoderList = nn.ModuleList([self.decoder4, self.decoder3, self.decoder2, self.decoder1])
        self.unifyList = nn.ModuleList([self.unify4, self.unify3, self.unify2, self.unify1])
        self.esaList = nn.ModuleList([self.esa4, self.esa3, self.esa2, self.esa1])
        self.lcaList = nn.ModuleList([self.lca4, self.lca3, self.lca2, self.lca1])

    def forward(self, x):
        # Shape comments assume a 224x224x3 input.
        bs = x.shape[0]
        e1_ = self.encoder1_conv(x)  # 112*112*64
        e1_ = self.encoder1_bn(e1_)
        e1_ = self.encoder1_relu(e1_)
        e1_pool_ = self.maxpool(e1_)  # 56*56*64
        e2_ = self.encoder2(e1_pool_)  # 56*56*256
        e3_ = self.encoder3(e2_)  # 28*28*512
        e4_ = self.encoder4(e3_)  # 14*14*1024
        e5_ = self.encoder5(e4_)  # 7*7*2048
        # Channel-reduced skip features.
        e1 = e1_
        e2 = self.reduce2(e2_)  # 56*56*64
        e3 = self.reduce3(e3_)  # 28*28*128
        e4 = self.reduce4(e4_)  # 14*14*256
        e5 = self.reduce5(e5_)  # 7*7*512
        # e5 = self.esa5(e5)
        d5 = self.decoder5(e5)  # 7*7*512 -> 14*14*512
        feat5 = self.unify5(d5)  # 14*14*64
        decoder_out = [d5]
        encoder_out = [e4, e3, e2, e1]
        """
        B = batch size (bs)
        N = number of classes (num_classes)
        C = feature channels
        K = conv kernel size
        """
        # Global context: (B, 512, 1, 1) -> (B, 512).
        gobal_context = self.gobal_average_pool(e5)
        gobal_context = gobal_context.reshape(bs, -1)
        # Generate the initial heads: (B, N*C*K*K) -> (B, N, C, K, K).
        head = self.generate_head(gobal_context)
        head = head.reshape(bs, self.num_classes, self.unified_channels, self.conv_kernel_size, self.conv_kernel_size)
        # Apply each sample's own head as conv weights (dynamic convolution);
        # F.conv2d takes per-sample weights, hence the loop over the batch.
        pred = []
        for t in range(bs):
            pred.append(F.conv2d(
                feat5[t:t+1],
                head[t],
                padding=int(self.conv_kernel_size // 2)))
        pred = torch.cat(pred, dim=0)  # (B, N, 14, 14)
        H, W = feat5.shape[-2:]  # H=14, W=14
        # (B, N, H, W)
        pred = pred.reshape(bs, self.num_classes, H, W)
        stage_out = [pred]
        # feats collects the unified 64-channel feature of each stage:
        # feats = [feat4, feat3, feat2, feat1], each (B, C, H, W).
        feats = []
        # The four lists run deepest-to-shallowest:
        #   decoderList = [decoder4..decoder1], unifyList = [unify4..unify1]
        #   esaList = [esa4..esa1], lcaList = [lca4..lca1]
        #   encoder_out = [e4, e3, e2, e1]
        for i in range(4):
            # ESA on the skip feature, e.g. i=0: (B, 256, 14, 14) -> same shape.
            esa_out = self.esaList[i](encoder_out[i])
            # LCA on the previous decoder output guided by the last prediction,
            # e.g. i=0: d5 (B, 512, 14, 14) + pred (B, N, 14, 14) -> (B, 512, 14, 14).
            lca_out = self.lcaList[i](decoder_out[-1], stage_out[-1])
            comb = torch.cat([lca_out, esa_out], dim=1)  # e.g. (B, 512+256, 14, 14)
            d = self.decoderList[i](comb)  # e.g. (B, 256, 28, 28)
            decoder_out.append(d)
            feat = self.unifyList[i](d)  # (B, 64, 28, 28)
            feats.append(feat)
            # Refine the heads with the new stage feature and last prediction.
            head = self.headUpdators[i](feats[i], head, pred)  # (B, N, 64, 1, 1)
            pred = []
            for j in range(bs):
                pred.append(F.conv2d(
                    feats[i][j:j+1],
                    head[j],
                    padding=int(self.conv_kernel_size // 2)))
            pred = torch.cat(pred, dim=0)  # (B, N, 28, 28) at i=0
            H, W = feats[i].shape[-2:]
            pred = pred.reshape(bs, self.num_classes, H, W)
            stage_out.append(pred)
        # Reverse so the finest (last computed) prediction comes first.
        stage_out.reverse()
        # return stage_out[0], stage_out[1], stage_out[2], stage_out[3], stage_out[4]
        return torch.sigmoid(stage_out[0]), torch.sigmoid(stage_out[1]), torch.sigmoid(stage_out[2]), \
            torch.sigmoid(stage_out[3]), torch.sigmoid(stage_out[4])
tips:虽然有涉及到num_classes参数,但num_classes只能为1,为其他数时会报错!(原因:LCA_blcok 把上一阶段的预测图当作 mask 传入 MaskAveragePooling,在其中与特征图逐元素相乘,只有单通道的 mask 才能正确广播。)