1. 官方给出的代码
旷视科技在自己的开源GitHub上给出的channel shuffle相关代码如下图所示:
分析上图中的代码,旷视科技将channel shuffle这个操作视为一个函数,函数传入的参数是输入张量x,x的shape为(batchsize, num_channels, height, width)。
首先对输入张量x使用data.size() 方法进行解包,从输入张量 x
中提取批量大小、通道数、高度和宽度。
使用assert函数检查 group【分组个数】是否能够整除num_channels,若不能够整除,则函数运行到此处抛出AssertionError
异常;若能够整除,则正常运行。
每个小组的通道数group_channels 为总体通道数num_channels除分组个数group。
接下来的三行代码均对x操作,我们一步一步来剖析:
首先经过:
x = x.reshape(batchsize, group_channels, self.group, height, width)
这表示x需要经过reshape操作,将num_channels分为group个组,每个组中的通道数为group_channels。
在经过如下操作:
x = x.permute(0, 2, 1, 3, 4)
这表示要将x的第1个维度与第2个维度进行互换,也就是说,可以理解为在这里对x经历了转置操作。
- 重新排列维度,使得维度的顺序变为:
- 维度 0:批量大小保持不变。
- 维度 2:将组的维度移到第二位。
- 维度 1:将每组的通道维度移到后面。
- 维度 3 和 4:高度和宽度保持不变。
- 这一步的作用是将不同组的通道位置互换,从而实现通道间的信息交互。
然后再经过如下操作:
x = x.reshape(batchsize, num_channels, height, width)
将重排后的张量重塑回 (batchsize, num_channels, height, width) 原始形状。
最后借助 return x 返回channel shuffle后的张量。
总结
该方法实现了channel shuffle的过程,通过将通道分组、重排和恢复形状来增强通道间的信息交互,通常用于提升轻量级网络的性能。channel shuffle有助于使模型更好地利用特征共享,提高整体表现。
2. 喂入测试张量进行测试【图例分析】
假设输入张量的shape为:(1, 12, 1, 1) group=3
首先通过以下代码构建输入张量,使用unsqueeze函数是为了给一维张量进行扩维,使之符合输入张量的shape。
对官方代码小修小改,得到独立可运行的channel_shuffle函数,如下图所示:
以图说明上述代码: