结构图,先将输入的图像进行通道拆分为组GConv1,每个GConv1再拆分Feature,每个GConv1的Feature进行合并GConv2,输出Output
输入图像x,拆分为groups个组,每隔组的通道数为channels_per_group
batch_size, num_channels, height, width = x.size()
channels_per_group = num_channels // groups
进行变换
# reshape
# [batch_size, num_channels, height, width] -> [batch_size, groups, channels_per_group, height, width]
x = x.view(batch_size, groups, channels_per_group, height, width)
再将1和2的维度进行调换 ,就实现了Feature到GConv2
x = torch.transpose(x, 1, 2).contiguous()
全部代码
def channel_shuffle(x: Tensor, groups: int) -> Tensor:
batch_size, num_channels, height, width = x.size()
channels_per_group = num_channels // groups
# reshape
# [batch_size, num_channels, height, width] -> [batch_size, groups, channels_per_group, height, width]
x = x.view(batch_size, groups, channels_per_group, height, width)
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batch_size, -1, height, width)
return x
参考视频:
8.2 使用Pytorch搭建ShuffleNetv2_哔哩哔哩_bilibili