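Published: March 7, 2023
Paper: https://arxiv.org/abs/2303.03667
Code: https://github.com/JierunChen/FasterNet

FasterNet-T0 is 2.8×, 3.3× and 2.4× faster than MobileViT-XXS on GPU, CPU and ARM processors respectively, while being 2.9% more accurate. The large FasterNet-L reaches an impressive 83.5% top-1 accuracy, on par with the emerging Swin-B. It also delivers 36% higher inference throughput on GPU and saves 37% compute time on CPU.

According to the authors, the core of FasterNet is the PConv (partial convolution) module. PConv reduces FLOPs by cutting redundant computation; like GhostNet, it assumes that the channels of a conv contain redundancy. It also reduces memory access cost (MAC), because most input channels pass straight through to the output. As a result, FasterNet achieves high throughput (FPS) on GPU and the lowest latency on CPU and ARM devices. This post therefore takes a closer look at how PConv is designed and implemented.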
1. Paper information
1.1 Module design
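Regular convolution and group convolution operate on every input channel. PConv instead applies a dense (regular) convolution to only a small fraction of the input channels and passes the remaining channels straight through to the output.

This greatly reduces computation. For example, if the input channels are split into 4 parts and only one part is convolved, the other 3 parts go directly to the next layer and FLOPs drop to 1/16 of a regular conv.

It also reduces memory access cost. For example, with C_in = 400 and only one quarter of the channels convolved, memory access is 100wh read + 100wh written = 200wh. A regular conv needs 400wh + 400wh = 800wh, so PConv's cost is 1/4 of it.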
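The following minimal Python sketch makes the arithmetic above concrete. It assumes:

- stride 1, "same" padding and C_in = C_out;
- FLOPs counted as multiply-accumulates;
- memory access counted as feature-map reads plus writes only, with weights ignored.

Following the paper, it also assumes the untouched channels need no memory access. The function names are my own and do not come from the FasterNet repo.

```python
# Rough cost model: dense k x k conv vs PConv (assumptions listed in the lead-in).

def conv_flops(h, w, k, c):
    """Multiply-accumulates of a dense k x k conv with c input and c output channels."""
    return h * w * k * k * c * c


def pconv_flops(h, w, k, c, n_div=4):
    """PConv convolves only c // n_div channels; the rest are passed through."""
    cp = c // n_div
    return h * w * k * k * cp * cp


def conv_mac(h, w, c):
    """Feature-map memory access: read c input channels, write c output channels."""
    return h * w * 2 * c


def pconv_mac(h, w, c, n_div=4):
    """Only the c // n_div convolved channels are read and written."""
    cp = c // n_div
    return h * w * 2 * cp


if __name__ == "__main__":
    h = w = 1  # per-pixel costs, i.e. the "wh" units used in the text
    c, k = 400, 3
    print("FLOPs ratio PConv/Conv:", pconv_flops(h, w, k, c) / conv_flops(h, w, k, c))
    print("MAC ratio PConv/Conv:", pconv_mac(h, w, c) / conv_mac(h, w, c))
```

With C = 400 this prints a FLOPs ratio of 0.0625 (1/16) and a memory-access ratio of 0.25 (200wh vs 800wh).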
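The PConv implementation is shown below. It is simply a split => conv => cat sequence; an in-place slicing variant is also provided for inference.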
```python
import torch
import torch.nn as nn
from torch import Tensor


class Partial_conv3(nn.Module):
    def __init__(self, dim, n_div, forward):
        super().__init__()
        self.dim_conv3 = dim // n_div              # channels that go through the 3x3 conv
        self.dim_untouched = dim - self.dim_conv3  # channels passed through unchanged
        self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)

        if forward == 'slicing':
            self.forward = self.forward_slicing
        elif forward == 'split_cat':
            self.forward = self.forward_split_cat
        else:
            raise NotImplementedError

    def forward_slicing(self, x: Tensor) -> Tensor:
        # only for inference
        x = x.clone()  # !!! Keep the original input intact for the residual connection later
        x[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :])
        return x

    def forward_split_cat(self, x: Tensor) -> Tensor:
        # for training/inference
        x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1)
        x1 = self.partial_conv3(x1)
        x = torch.cat((x1, x2), 1)
        return x
```
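The paper also discusses pairing PConv with a PWConv (pointwise convolution), or viewing it as part of a T-shaped Conv. At the code level, however, neither has anything to do with the PConv module itself. In the FasterNet Block, PConv is simply followed by Conv1x1 layers, and those Conv1x1 layers handle cross-channel information exchange.

1.2 Model architecture

The FasterNet architecture is shown below; PConv is only a small part of it. The authors combine PConv with Conv1x1 + BN + ReLU + Conv1x1 and a residual connection to form the FasterNet Block, which is the model's main building block. The model also borrows many ideas from ViT designs (such as PatchEmbed and the MLP), only without any Transformer modules.

PatchEmbed appears at the model's input layer. The "MLP" is simply the Conv1x1 + BN + ReLU + Conv1x1 that follows PConv.

The detailed configurations are shown below; there are T0, T1, T2, S, M and L variants.

- After the embedding layer, the input is already downsampled by 4 in each spatial dimension.
- Each subsequent Stage (a stack of FasterNet Blocks) only extracts features and does not change the resolution.
- A Merging layer (a 2×2, stride-2 conv followed by BN) between stages performs the downsampling.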
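The shape bookkeeping described above can be sketched in plain Python. The sketch uses the default hyperparameters from the FasterNet code in Section 2.4 (embed_dim = 96, patch_size = patch_stride = 4, patch_size2 = patch_stride2 = 2) and assumes a 224×224 input. The per-variant embedding widths (T0 through L) are not reproduced here, and the helper names are hypothetical.

```python
def conv_out(size, k, s):
    """Output size of an unpadded conv with kernel k and stride s."""
    return (size - k) // s + 1


def fasternet_stage_shapes(img=224, embed_dim=96, num_stages=4, patch=4, patch2=2):
    """(channels, height, width) of the feature map processed by each stage."""
    hw = conv_out(img, patch, patch)  # PatchEmbed: 4x4 patches, stride 4
    dim = embed_dim
    shapes = []
    for i in range(num_stages):
        shapes.append((dim, hw, hw))  # blocks inside a stage keep the shape unchanged
        if i < num_stages - 1:
            hw = conv_out(hw, patch2, patch2)  # PatchMerging: 2x2 conv, stride 2
            dim *= 2                           # ... which also doubles the channels
    return shapes


if __name__ == "__main__":
    print(fasternet_stage_shapes())
```

This prints `[(96, 56, 56), (192, 28, 28), (384, 14, 14), (768, 7, 7)]`. The first two shapes are the ones used in the operator comparison of Section 1.3.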
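1.3 Operator comparison

Operator performance comparison: the paper compares Conv, group convolution (GConv), depthwise convolution (DWConv) and PConv. The number of feature-map elements halves from one setting to the next; for example, 192×28×28 has half as many elements as 96×56×56. Only DWConv's FLOPs halve along with it; the FLOPs of the other operators stay the same.

Judging by FPS and latency, DWConv offers the best cost-performance and PConv comes second. Only on ARM (Cortex-A72, using a single thread) does PConv beat DWConv.

Note 1: with r = 1/4, PConv has the same FLOPs as a group convolution with 16 groups (GConv 1/16), but their memory access differs. GConv still reads and writes all c channels, whereas PConv only touches c/4 of them.

Note 2: DWConv here means a fully grouped convolution followed by a pointwise convolution. The first part is a 3×3 kernel with groups equal to the number of channels and only mixes spatial information; the second is a 1×1 kernel that mixes information across channels.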
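Here is a quick arithmetic check of the comparison above, using the two feature-map shapes from the text. It makes these assumptions:

- 3×3 kernels, stride 1, FLOPs counted as multiply-accumulates;
- GConv uses 16 groups;
- DWConv is counted as the depthwise 3×3 part only, which is the term whose FLOPs halve;
- memory access counts only feature-map reads and writes.

The helper names are illustrative.

```python
def op_flops(kind, c, h, w, k=3, groups=16, n_div=4):
    """Multiply-accumulates of a k x k spatial operator with C_in = C_out = c on an h x w map."""
    if kind == "conv":
        return h * w * k * k * c * c
    if kind == "gconv":   # group conv with `groups` groups
        return h * w * k * k * c * c // groups
    if kind == "dwconv":  # depthwise part only
        return h * w * k * k * c
    if kind == "pconv":   # dense conv on c // n_div channels, the rest untouched
        cp = c // n_div
        return h * w * k * k * cp * cp
    raise ValueError(f"unknown operator: {kind}")


def feature_mac(kind, c, h, w, n_div=4):
    """Feature-map reads + writes; PConv only touches its c // n_div channels."""
    channels = c // n_div if kind == "pconv" else c
    return h * w * 2 * channels


if __name__ == "__main__":
    for kind in ("conv", "gconv", "dwconv", "pconv"):
        a = op_flops(kind, 96, 56, 56)
        b = op_flops(kind, 192, 28, 28)
        print(f"{kind:7s} 96x56x56: {a:>11,}  192x28x28: {b:>11,}  "
              f"MAC@96x56x56: {feature_mac(kind, 96, 56, 56):,}")
```

Conv, GConv and PConv keep the same FLOPs across the two shapes, while DWConv halves. GConv (16 groups) and PConv (r = 1/4) have identical FLOPs, but GConv's feature-map memory access is 4× larger.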
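The authors also train each operator to fit the output of a regular Conv, and PConv reaches the lowest loss. Neither GConv nor PConv can exchange information across all channels, so each is followed by a PWConv; for a fair comparison, DWConv is likewise given a PWConv. The loss values differ by only 0.001 to 0.002, which in practice is no real difference. For scale, compare the numerical output differences that appear when fusing branches in re-parameterized convs such as ddb_conv (DBB) and RepConv.

Memory access cost comparison: Eq. (2) gives PConv's memory access and Eq. (3) that of a regular conv. Since c′ is 1/4 of c, PConv's memory access cost is 1/4 of a regular conv's. Both formulas assume the input and output channel counts are both c, hence the factor 2c; otherwise the term is (c_in + c_out).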
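Below are the FLOPs and memory-access formulas in the paper's notation. Here c_p = r·c is the number of convolved channels, the kernel is k×k, and the feature map is h×w. The PConv lines follow the paper. The regular-Conv line is my own count by analogy, and the ratio keeps only the dominant feature-map terms.

$$
\begin{aligned}
\mathrm{FLOPs}_{\mathrm{Conv}} &= h \times w \times k^2 \times c^2, \qquad
\mathrm{FLOPs}_{\mathrm{PConv}} = h \times w \times k^2 \times c_p^2 = r^2 \cdot \mathrm{FLOPs}_{\mathrm{Conv}} \\
\mathrm{MAC}_{\mathrm{PConv}} &= h \times w \times 2c_p + k^2 \times c_p^2 \approx h \times w \times 2c_p \\
\mathrm{MAC}_{\mathrm{Conv}} &\approx h \times w \times 2c \quad \left(\text{in general } h \times w \times (c_{in} + c_{out})\right) \\
\frac{\mathrm{MAC}_{\mathrm{PConv}}}{\mathrm{MAC}_{\mathrm{Conv}}} &\approx \frac{c_p}{c} = r = \frac{1}{4}, \qquad
\frac{\mathrm{FLOPs}_{\mathrm{PConv}}}{\mathrm{FLOPs}_{\mathrm{Conv}}} = r^2 = \frac{1}{16}
\end{aligned}
$$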
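1.4 Results

The overall comparison is shown below. FasterNet achieves the highest FPS on GPU and the lowest latency on CPU and ARM.

The following charts show that FasterNet achieves state-of-the-art performance among both lightweight and large models.

2. Code implementation and analysis

2.1 PConv code

The simplified PConv code is shown below. It convolves the first dim // n_div channels in place and leaves the rest untouched, which is equivalent to a simple split + conv + cat. In 2023 I tried something similar myself, replacing every conv with PConv, and it did not train to good results.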
```python
import torch.nn as nn
from torch import Tensor


class Partial_conv3(nn.Module):
    def __init__(self, dim, n_div, forward):  # `forward` is unused in this simplified version
        super().__init__()
        self.dim_conv3 = dim // n_div
        self.dim_untouched = dim - self.dim_conv3
        self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # only for inference (in-place slice assignment; the split/cat variant is used for training)
        x = x.clone()  # !!! Keep the original input intact for the residual connection later
        x[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :])
        return x
```
2.2 FasterNet Block code

The `spatial_mixing` attribute is the PConv layer.

The `mlp` attribute holds the non-PConv layers of the FasterNet Block.

Its `forward` method is:
```python
def forward(self, x: Tensor) -> Tensor:
    shortcut = x
    x = self.spatial_mixing(x)                   # PConv
    x = shortcut + self.drop_path(self.mlp(x))   # Conv1x1 -> BN -> act -> Conv1x1, plus residual
    return x
```
The full implementation is as follows:
```python
from typing import List

import torch
import torch.nn as nn
from timm.models.layers import DropPath
from torch import Tensor

# Partial_conv3 is the class defined in Section 1.1.


class MLPBlock(nn.Module):
    def __init__(self,
                 dim,
                 n_div,
                 mlp_ratio,
                 drop_path,
                 layer_scale_init_value,
                 act_layer,
                 norm_layer,
                 pconv_fw_type
                 ):
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.n_div = n_div

        mlp_hidden_dim = int(dim * mlp_ratio)
        mlp_layer: List[nn.Module] = [
            nn.Conv2d(dim, mlp_hidden_dim, 1, bias=False),
            norm_layer(mlp_hidden_dim),
            act_layer(),
            nn.Conv2d(mlp_hidden_dim, dim, 1, bias=False)
        ]
        self.mlp = nn.Sequential(*mlp_layer)

        self.spatial_mixing = Partial_conv3(
            dim,
            n_div,
            pconv_fw_type
        )

        if layer_scale_init_value > 0:
            self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            self.forward = self.forward_layer_scale
        else:
            self.forward = self.forward  # no-op: keeps the default forward()

    def forward(self, x: Tensor) -> Tensor:
        shortcut = x
        x = self.spatial_mixing(x)
        x = shortcut + self.drop_path(self.mlp(x))
        return x

    def forward_layer_scale(self, x: Tensor) -> Tensor:
        shortcut = x
        x = self.spatial_mixing(x)
        x = shortcut + self.drop_path(
            self.layer_scale.unsqueeze(-1).unsqueeze(-1) * self.mlp(x))
        return x
```
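As a sanity check on the MLPBlock structure, this plain-Python sketch counts the block's trainable parameters. It assumes bias-free convs and the default layer_scale_init_value = 0, so there is no layer-scale parameter. BatchNorm contributes a weight and a bias per channel; its running statistics are buffers, not parameters. The counts can be compared with the `spatial_mixing.partial_conv3` and `mlp` rows of the fvcore table in Section 3.2. The function is hypothetical and not part of the repo.

```python
def mlp_block_params(dim, n_div=4, mlp_ratio=2.0, k=3):
    """Trainable parameters of one MLPBlock (FasterNet Block), returned as (pconv, mlp)."""
    cp = dim // n_div
    hidden = int(dim * mlp_ratio)
    pconv = k * k * cp * cp   # Conv2d(cp, cp, 3, bias=False)
    mlp = (dim * hidden       # Conv2d(dim, hidden, 1, bias=False)
           + 2 * hidden       # BatchNorm2d(hidden): weight + bias
           + hidden * dim)    # Conv2d(hidden, dim, 1, bias=False)
    return pconv, mlp


if __name__ == "__main__":
    for dim in (96, 192, 384, 768):
        print(dim, mlp_block_params(dim))
```

For dim = 96 this gives (5184, 37248), matching the 5.184K and 37.248K entries in the table.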
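There is also a `BasicStage` class, which simply stacks several MLPBlocks (i.e. FasterNet Blocks).

2.3 PatchEmbed and PatchMerging

PatchEmbed is analogous to splitting the image into patches in ViT. A conv whose kernel size equals its stride (4) moves spatial information into the channel dimension.

PatchMerging uses a strided conv to reduce the feature-map resolution while increasing the number of channels.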
```python
import torch.nn as nn
from torch import Tensor


class PatchEmbed(nn.Module):
    def __init__(self, patch_size, patch_stride, in_chans, embed_dim, norm_layer):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, bias=False)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        x = self.norm(self.proj(x))
        return x


class PatchMerging(nn.Module):
    def __init__(self, patch_size2, patch_stride2, dim, norm_layer):
        super().__init__()
        self.reduction = nn.Conv2d(dim, 2 * dim, kernel_size=patch_size2, stride=patch_stride2, bias=False)
        if norm_layer is not None:
            self.norm = norm_layer(2 * dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        x = self.norm(self.reduction(x))
        return x
```
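This small sketch gives the parameter and FLOPs counts for these two layers. It assumes:

- a 256×256 input, the size used in Section 3.2;
- bias-free convs with no padding;
- FLOPs counted as one multiply-accumulate per weight per output pixel, which appears to be how fvcore counts convs.

Every PatchMerging layer costs the same FLOPs, because the input channels double while the spatial size quarters.

```python
def out_size(size, k, s):
    """Output size of an unpadded conv."""
    return (size - k) // s + 1


def conv_cost(c_in, c_out, k, s, size):
    """(params, MACs, output size) of a bias-free k x k conv with stride s on a size x size input."""
    params = c_in * c_out * k * k
    out = out_size(size, k, s)
    return params, params * out * out, out


if __name__ == "__main__":
    size, dim = 256, 96
    params, macs, size = conv_cost(3, dim, k=4, s=4, size=size)  # PatchEmbed.proj
    print(f"patch_embed.proj: {params} params, {macs / 1e6:.3f}M MACs, out {size}x{size}")
    for i in range(3):  # the three PatchMerging layers
        params, macs, size = conv_cost(dim, 2 * dim, k=2, s=2, size=size)
        print(f"merging {i}: {params} params, {macs / 1e6:.3f}M MACs, out {size}x{size}")
        dim *= 2
```

The results are 4608 parameters and 18.874M MACs for `patch_embed.proj`, and 75.497M MACs for each merging layer. Both agree with the `patch_embed.proj` and `stages.*.reduction` rows of the table in Section 3.2.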
2.4 Model code
```python
# Relies on the imports and classes defined in the full file (Section 2.5).
class FasterNet(nn.Module):
    def __init__(self,
                 in_chans=3,
                 num_classes=1000,
                 embed_dim=96,
                 depths=(1, 2, 8, 2),
                 mlp_ratio=2.,
                 n_div=4,
                 patch_size=4,
                 patch_stride=4,
                 patch_size2=2,  # for subsequent layers
                 patch_stride2=2,
                 patch_norm=True,
                 feature_dim=1280,
                 drop_path_rate=0.1,
                 layer_scale_init_value=0,
                 norm_layer='BN',
                 act_layer='RELU',
                 fork_feat=False,
                 init_cfg=None,
                 pretrained=None,
                 pconv_fw_type='split_cat',
                 **kwargs):
        super().__init__()

        if norm_layer == 'BN':
            norm_layer = nn.BatchNorm2d
        else:
            raise NotImplementedError

        if act_layer == 'GELU':
            act_layer = nn.GELU
        elif act_layer == 'RELU':
            act_layer = partial(nn.ReLU, inplace=True)
        else:
            raise NotImplementedError

        if not fork_feat:
            self.num_classes = num_classes
        self.num_stages = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_stages - 1))
        self.mlp_ratio = mlp_ratio
        self.depths = depths

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            patch_stride=patch_stride,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None
        )

        # stochastic depth decay rule
        dpr = [x.item()
               for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # build layers
        stages_list = []
        for i_stage in range(self.num_stages):
            stage = BasicStage(dim=int(embed_dim * 2 ** i_stage),
                               n_div=n_div,
                               depth=depths[i_stage],
                               mlp_ratio=self.mlp_ratio,
                               drop_path=dpr[sum(depths[:i_stage]):sum(depths[:i_stage + 1])],
                               layer_scale_init_value=layer_scale_init_value,
                               norm_layer=norm_layer,
                               act_layer=act_layer,
                               pconv_fw_type=pconv_fw_type
                               )
            stages_list.append(stage)

            # patch merging layer
            if i_stage < self.num_stages - 1:
                stages_list.append(
                    PatchMerging(patch_size2=patch_size2,
                                 patch_stride2=patch_stride2,
                                 dim=int(embed_dim * 2 ** i_stage),
                                 norm_layer=norm_layer)
                )

        self.stages = nn.Sequential(*stages_list)

        self.fork_feat = fork_feat

        if self.fork_feat:
            self.forward = self.forward_det
            # add a norm layer for each output
            self.out_indices = [0, 2, 4, 6]
            for i_emb, i_layer in enumerate(self.out_indices):
                if i_emb == 0 and os.environ.get('FORK_LAST3', None):
                    raise NotImplementedError
                else:
                    layer = norm_layer(int(embed_dim * 2 ** i_emb))
                layer_name = f'norm{i_layer}'
                self.add_module(layer_name, layer)
        else:
            self.forward = self.forward_cls
            # Classifier head
            self.avgpool_pre_head = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(self.num_features, feature_dim, 1, bias=False),
                act_layer()
            )
            self.head = nn.Linear(feature_dim, num_classes) \
                if num_classes > 0 else nn.Identity()

        self.apply(self.cls_init_weights)
        self.init_cfg = copy.deepcopy(init_cfg)
        if self.fork_feat and (self.init_cfg is not None or pretrained is not None):
            self.init_weights()

    def cls_init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.Conv1d, nn.Conv2d)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    # init for mmdetection by loading imagenet pre-trained weights
    def init_weights(self, pretrained=None):
        logger = get_root_logger()
        if self.init_cfg is None and pretrained is None:
            logger.warning(f'No pre-trained weights for '
                           f'{self.__class__.__name__}, '
                           f'training start from scratch')
        else:
            assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                  f'specify `Pretrained` in ' \
                                                  f'`init_cfg` in ' \
                                                  f'{self.__class__.__name__} '
            if self.init_cfg is not None:
                ckpt_path = self.init_cfg['checkpoint']
            elif pretrained is not None:
                ckpt_path = pretrained

            ckpt = _load_checkpoint(
                ckpt_path, logger=logger, map_location='cpu')
            if 'state_dict' in ckpt:
                _state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                _state_dict = ckpt['model']
            else:
                _state_dict = ckpt

            state_dict = _state_dict
            missing_keys, unexpected_keys = \
                self.load_state_dict(state_dict, False)

            # show for debug
            print('missing_keys: ', missing_keys)
            print('unexpected_keys: ', unexpected_keys)

    def forward_cls(self, x):
        # output only the features of last layer for image classification
        x = self.patch_embed(x)
        x = self.stages(x)
        x = self.avgpool_pre_head(x)  # B C 1 1
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x

    def forward_det(self, x: Tensor) -> Tensor:
        # output the features of four stages for dense prediction
        x = self.patch_embed(x)
        outs = []
        for idx, stage in enumerate(self.stages):
            x = stage(x)
            if self.fork_feat and idx in self.out_indices:
                norm_layer = getattr(self, f'norm{idx}')
                x_out = norm_layer(x)
                outs.append(x_out)
        return outs
```
2.5 Full model code

The full model code below is used only for the FLOPs analysis in Section 3.2.
```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.nn as nn
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from functools import partial
from typing import List
from torch import Tensor
import copy
import os

try:
    from mmdet.models.builder import BACKBONES as det_BACKBONES
    from mmdet.utils import get_root_logger
    from mmcv.runner import _load_checkpoint
    has_mmdet = True
except ImportError:
    print("If for detection, please install mmdetection first")
    has_mmdet = False


class Partial_conv3(nn.Module):
    def __init__(self, dim, n_div, forward):
        super().__init__()
        self.dim_conv3 = dim // n_div
        self.dim_untouched = dim - self.dim_conv3
        self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)

        if forward == 'slicing':
            self.forward = self.forward_slicing
        elif forward == 'split_cat':
            self.forward = self.forward_split_cat
        else:
            raise NotImplementedError

    def forward_slicing(self, x: Tensor) -> Tensor:
        # only for inference
        x = x.clone()  # !!! Keep the original input intact for the residual connection later
        x[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :])
        return x

    def forward_split_cat(self, x: Tensor) -> Tensor:
        # for training/inference
        x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1)
        x1 = self.partial_conv3(x1)
        x = torch.cat((x1, x2), 1)
        return x


class MLPBlock(nn.Module):
    def __init__(self,
                 dim,
                 n_div,
                 mlp_ratio,
                 drop_path,
                 layer_scale_init_value,
                 act_layer,
                 norm_layer,
                 pconv_fw_type
                 ):
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.n_div = n_div

        mlp_hidden_dim = int(dim * mlp_ratio)
        mlp_layer: List[nn.Module] = [
            nn.Conv2d(dim, mlp_hidden_dim, 1, bias=False),
            norm_layer(mlp_hidden_dim),
            act_layer(),
            nn.Conv2d(mlp_hidden_dim, dim, 1, bias=False)
        ]
        self.mlp = nn.Sequential(*mlp_layer)

        self.spatial_mixing = Partial_conv3(
            dim,
            n_div,
            pconv_fw_type
        )

        if layer_scale_init_value > 0:
            self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            self.forward = self.forward_layer_scale
        else:
            self.forward = self.forward  # no-op: keeps the default forward()

    def forward(self, x: Tensor) -> Tensor:
        shortcut = x
        x = self.spatial_mixing(x)
        x = shortcut + self.drop_path(self.mlp(x))
        return x

    def forward_layer_scale(self, x: Tensor) -> Tensor:
        shortcut = x
        x = self.spatial_mixing(x)
        x = shortcut + self.drop_path(
            self.layer_scale.unsqueeze(-1).unsqueeze(-1) * self.mlp(x))
        return x


class BasicStage(nn.Module):
    def __init__(self,
                 dim,
                 depth,
                 n_div,
                 mlp_ratio,
                 drop_path,
                 layer_scale_init_value,
                 norm_layer,
                 act_layer,
                 pconv_fw_type
                 ):
        super().__init__()
        blocks_list = [
            MLPBlock(
                dim=dim,
                n_div=n_div,
                mlp_ratio=mlp_ratio,
                drop_path=drop_path[i],
                layer_scale_init_value=layer_scale_init_value,
                norm_layer=norm_layer,
                act_layer=act_layer,
                pconv_fw_type=pconv_fw_type
            )
            for i in range(depth)
        ]
        self.blocks = nn.Sequential(*blocks_list)

    def forward(self, x: Tensor) -> Tensor:
        x = self.blocks(x)
        return x


class PatchEmbed(nn.Module):
    def __init__(self, patch_size, patch_stride, in_chans, embed_dim, norm_layer):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, bias=False)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        x = self.norm(self.proj(x))
        return x


class PatchMerging(nn.Module):
    def __init__(self, patch_size2, patch_stride2, dim, norm_layer):
        super().__init__()
        self.reduction = nn.Conv2d(dim, 2 * dim, kernel_size=patch_size2, stride=patch_stride2, bias=False)
        if norm_layer is not None:
            self.norm = norm_layer(2 * dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        x = self.norm(self.reduction(x))
        return x


class FasterNet(nn.Module):
    def __init__(self,
                 in_chans=3,
                 num_classes=1000,
                 embed_dim=96,
                 depths=(1, 2, 8, 2),
                 mlp_ratio=2.,
                 n_div=4,
                 patch_size=4,
                 patch_stride=4,
                 patch_size2=2,  # for subsequent layers
                 patch_stride2=2,
                 patch_norm=True,
                 feature_dim=1280,
                 drop_path_rate=0.1,
                 layer_scale_init_value=0,
                 norm_layer='BN',
                 act_layer='RELU',
                 fork_feat=False,
                 init_cfg=None,
                 pretrained=None,
                 pconv_fw_type='split_cat',
                 **kwargs):
        super().__init__()

        if norm_layer == 'BN':
            norm_layer = nn.BatchNorm2d
        else:
            raise NotImplementedError

        if act_layer == 'GELU':
            act_layer = nn.GELU
        elif act_layer == 'RELU':
            act_layer = partial(nn.ReLU, inplace=True)
        else:
            raise NotImplementedError

        if not fork_feat:
            self.num_classes = num_classes
        self.num_stages = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_stages - 1))
        self.mlp_ratio = mlp_ratio
        self.depths = depths

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            patch_stride=patch_stride,
            in_chans=in_chans,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None
        )

        # stochastic depth decay rule
        dpr = [x.item()
               for x in torch.linspace(0, drop_path_rate, sum(depths))]

        # build layers
        stages_list = []
        for i_stage in range(self.num_stages):
            stage = BasicStage(dim=int(embed_dim * 2 ** i_stage),
                               n_div=n_div,
                               depth=depths[i_stage],
                               mlp_ratio=self.mlp_ratio,
                               drop_path=dpr[sum(depths[:i_stage]):sum(depths[:i_stage + 1])],
                               layer_scale_init_value=layer_scale_init_value,
                               norm_layer=norm_layer,
                               act_layer=act_layer,
                               pconv_fw_type=pconv_fw_type
                               )
            stages_list.append(stage)

            # patch merging layer
            if i_stage < self.num_stages - 1:
                stages_list.append(
                    PatchMerging(patch_size2=patch_size2,
                                 patch_stride2=patch_stride2,
                                 dim=int(embed_dim * 2 ** i_stage),
                                 norm_layer=norm_layer)
                )

        self.stages = nn.Sequential(*stages_list)

        self.fork_feat = fork_feat

        if self.fork_feat:
            self.forward = self.forward_det
            # add a norm layer for each output
            self.out_indices = [0, 2, 4, 6]
            for i_emb, i_layer in enumerate(self.out_indices):
                if i_emb == 0 and os.environ.get('FORK_LAST3', None):
                    raise NotImplementedError
                else:
                    layer = norm_layer(int(embed_dim * 2 ** i_emb))
                layer_name = f'norm{i_layer}'
                self.add_module(layer_name, layer)
        else:
            self.forward = self.forward_cls
            # Classifier head
            self.avgpool_pre_head = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(self.num_features, feature_dim, 1, bias=False),
                act_layer()
            )
            self.head = nn.Linear(feature_dim, num_classes) \
                if num_classes > 0 else nn.Identity()

        self.apply(self.cls_init_weights)
        self.init_cfg = copy.deepcopy(init_cfg)
        if self.fork_feat and (self.init_cfg is not None or pretrained is not None):
            self.init_weights()

    def cls_init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.Conv1d, nn.Conv2d)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    # init for mmdetection by loading imagenet pre-trained weights
    def init_weights(self, pretrained=None):
        logger = get_root_logger()
        if self.init_cfg is None and pretrained is None:
            logger.warning(f'No pre-trained weights for '
                           f'{self.__class__.__name__}, '
                           f'training start from scratch')
        else:
            assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                  f'specify `Pretrained` in ' \
                                                  f'`init_cfg` in ' \
                                                  f'{self.__class__.__name__} '
            if self.init_cfg is not None:
                ckpt_path = self.init_cfg['checkpoint']
            elif pretrained is not None:
                ckpt_path = pretrained

            ckpt = _load_checkpoint(
                ckpt_path, logger=logger, map_location='cpu')
            if 'state_dict' in ckpt:
                _state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                _state_dict = ckpt['model']
            else:
                _state_dict = ckpt

            state_dict = _state_dict
            missing_keys, unexpected_keys = \
                self.load_state_dict(state_dict, False)

            # show for debug
            print('missing_keys: ', missing_keys)
            print('unexpected_keys: ', unexpected_keys)

    def forward_cls(self, x):
        # output only the features of last layer for image classification
        x = self.patch_embed(x)
        x = self.stages(x)
        x = self.avgpool_pre_head(x)  # B C 1 1
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x

    def forward_det(self, x: Tensor) -> Tensor:
        # output the features of four stages for dense prediction
        x = self.patch_embed(x)
        outs = []
        for idx, stage in enumerate(self.stages):
            x = stage(x)
            if self.fork_feat and idx in self.out_indices:
                norm_layer = getattr(self, f'norm{idx}')
                x_out = norm_layer(x)
                outs.append(x_out)
        return outs
```
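3. Further analysis

3.1 Can PConv replace Conv?

No. PConv is only a drop-in substitute for Conv when C_in equals C_out. It also mixes spatial information only within a subset of channels. Most channels are passed straight to the output, so much of the input data is carried unchanged into deeper layers. A dense convolution that connects all channels (such as a 1×1 conv) is therefore still needed for cross-channel information exchange.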
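Because PConv leaves most channels untouched, it needs a pointwise (1×1) conv afterwards for channel mixing. Below is a minimal sketch of the FLOPs of such a PConv + PWConv pair compared with a dense 3×3 conv. It assumes C_in = C_out = c, r = 1/4 and a single 1×1 conv of width c; this count is my own. FasterNet's actual MLP expands to 2c channels, so a real block costs more than this pair.

```python
def dense_conv_flops(h, w, c, k=3):
    """Multiply-accumulates of a regular k x k conv with C_in = C_out = c."""
    return h * w * k * k * c * c


def pconv_pwconv_flops(h, w, c, k=3, n_div=4):
    """PConv on c // n_div channels followed by a c -> c 1x1 conv for channel mixing."""
    cp = c // n_div
    return h * w * (k * k * cp * cp + c * c)


if __name__ == "__main__":
    c = 96
    ratio = pconv_pwconv_flops(1, 1, c) / dense_conv_flops(1, 1, c)
    print(f"PConv + PWConv costs {ratio:.1%} of a dense 3x3 conv at c={c}")
```

At c = 96 the pair costs 14400 vs 82944 multiply-accumulates per pixel, about 17.4% of the dense conv. Most of that cost comes from the 1×1 conv, not from PConv.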
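The paper's experiments never compare against a FasterNet whose PConv layers are replaced with regular Conv. FasterNet's advantage may therefore come mostly from its overall design, especially the PatchEmbed stem, which reduces the spatial size to 1/16 of the original. In other words, even with Conv in place of PConv, the model might still lead in accuracy and latency.

Likewise, there is no equivalent comparison with DWConv. Replacing PConv in FasterNet with DWConv might even bring a further performance gain. In the authors' own experiments, DWConv ran faster than PConv on GPU, and its fitting ability was on par with PConv's.

3.2 FLOPs distribution in FasterNet

Using the code below, I built a small FasterNet (one block per stage) and printed the FLOPs of every layer.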
```python
if __name__ == "__main__":
    from fvcore.nn import flop_count_table, FlopCountAnalysis, ActivationCountAnalysis

    model = FasterNet(depths=(1, 1, 1, 1))  # one FasterNet Block per stage
    x = torch.randn(1, 3, 256, 256)
    print(f'params: {sum(p.numel() for p in model.parameters())}')
    print(flop_count_table(FlopCountAnalysis(model, x), activations=ActivationCountAnalysis(model, x)))
    output = model(x)
    print(output.shape)
```
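The output is shown below. In each FasterNet Block, most of the FLOPs are in `blocks.0.mlp` (about 0.15G). `spatial_mixing.partial_conv3` (i.e. PConv) accounts for only about 12% of the block's computation, at 21.2M FLOPs.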
| module | #parameters or shape | #flops | #activations |
|:--------------------------------------------------|:-----------------------|:-----------|:---------------|
| model | 7.4M | 0.948G | 3.136M |
| patch_embed | 4.8K | 20.84M | 0.393M |
| patch_embed.proj | 4.608K | 18.874M | 0.393M |
| patch_embed.proj.weight | (96, 3, 4, 4) | | |
| patch_embed.norm | 0.192K | 1.966M | 0 |
| patch_embed.norm.weight | (96,) | | |
| patch_embed.norm.bias | (96,) | | |
| stages | 5.131M | 0.924G | 2.74M |
| stages.0.blocks.0 | 42.432K | 0.176G | 1.278M |
| stages.0.blocks.0.mlp | 37.248K | 0.155G | 1.18M |
| stages.0.blocks.0.spatial_mixing.partial_conv3 | 5.184K | 21.234M | 98.304K |
| stages.1 | 74.112K | 76.481M | 0.197M |
| stages.1.reduction | 73.728K | 75.497M | 0.197M |
| stages.1.norm | 0.384K | 0.983M | 0 |
| stages.2.blocks.0 | 0.169M | 0.174G | 0.639M |
| stages.2.blocks.0.mlp | 0.148M | 0.153G | 0.59M |
| stages.2.blocks.0.spatial_mixing.partial_conv3 | 20.736K | 21.234M | 49.152K |
| stages.3 | 0.296M | 75.989M | 98.304K |
| stages.3.reduction | 0.295M | 75.497M | 98.304K |
| stages.3.norm | 0.768K | 0.492M | 0 |
| stages.4.blocks.0 | 0.674M | 0.173G | 0.319M |
| stages.4.blocks.0.mlp | 0.591M | 0.152G | 0.295M |
| stages.4.blocks.0.spatial_mixing.partial_conv3 | 82.944K | 21.234M | 24.576K |
| stages.5 | 1.181M | 75.743M | 49.152K |
| stages.5.reduction | 1.18M | 75.497M | 49.152K |
| stages.5.norm | 1.536K | 0.246M | 0 |
| stages.6.blocks.0 | 2.694M | 0.173G | 0.16M |
| stages.6.blocks.0.mlp | 2.362M | 0.151G | 0.147M |
| stages.6.blocks.0.spatial_mixing.partial_conv3 | 0.332M | 21.234M | 12.288K |
| avgpool_pre_head | 0.983M | 1.032M | 1.28K |
| avgpool_pre_head.1 | 0.983M | 0.983M | 1.28K |
| avgpool_pre_head.1.weight | (1280, 768, 1, 1) | | |
| avgpool_pre_head.0 | | 49.152K | 0 |
| head | 1.281M | 1.28M | 1K |
| head.weight | (1000, 1280) | | |
| head.bias | (1000,) | | |
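The PConv share in the table can be reproduced by hand. The sketch below assumes the setup above (256×256 input, embed_dim = 96, one block per stage, n_div = 4, mlp_ratio = 2) and counts only conv multiply-accumulates. BatchNorm is ignored, so its totals are slightly below fvcore's block numbers.

```python
def block_conv_flops(img=256, embed_dim=96, num_stages=4, n_div=4, mlp_ratio=2, k=3):
    """Per-stage conv MACs of one FasterNet Block: (stage, dim, hw, pconv, mlp)."""
    hw, c = img // 4, embed_dim  # PatchEmbed has stride 4
    rows = []
    for stage in range(num_stages):
        cp = c // n_div
        hidden = int(c * mlp_ratio)
        pconv = hw * hw * k * k * cp * cp          # 3x3 conv on cp channels
        mlp = hw * hw * (c * hidden + hidden * c)  # two 1x1 convs: c -> hidden -> c
        rows.append((stage, c, hw, pconv, mlp))
        hw //= 2  # PatchMerging halves the resolution
        c *= 2    # and doubles the channels
    return rows


if __name__ == "__main__":
    for stage, c, hw, pconv, mlp in block_conv_flops():
        share = pconv / (pconv + mlp)
        print(f"stage {stage}: dim={c} {hw}x{hw} pconv={pconv / 1e6:.3f}M "
              f"mlp={mlp / 1e6:.3f}M pconv share={share:.1%}")
```

Every stage gives pconv = 21.234M and mlp = 150.995M, so PConv's share is about 12.3%. The per-stage cost is constant because the channel count doubles while the spatial area quarters.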
3.3 FLOPs change when replacing PConv with Conv

Replace the original Partial_conv3 class with the following code:
```python
class Partial_conv3(nn.Module):
    # Used only for this experiment: a full 3x3 conv over all `dim` channels.
    # `n_div` and `forward` are kept so the signature still matches MLPBlock's call.
    def __init__(self, dim, n_div, forward):
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, 3, 1, 1, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # No clone needed: the conv returns a new tensor, so the residual input stays intact
        x = self.conv(x)
        return x
```
Then run the same script again:
```python
if __name__ == "__main__":
    from fvcore.nn import flop_count_table, FlopCountAnalysis, ActivationCountAnalysis

    model = FasterNet(depths=(1, 1, 1, 1))  # one FasterNet Block per stage
    x = torch.randn(1, 3, 256, 256)
    print(f'params: {sum(p.numel() for p in model.parameters())}')
    print(flop_count_table(FlopCountAnalysis(model, x), activations=ActivationCountAnalysis(model, x)))
    output = model(x)
    print(output.shape)
```
Total FLOPs are now 2.22G, more than double the original 0.948G. In the new FasterNet Block, `spatial_mixing.conv` accounts for about 70% of the FLOPs (0.34G), 16 times the original PConv's 21.2M.
| module | #parameters or shape | #flops | #activations |
|:-----------------------------------------|:-----------------------|:-----------|:---------------|
| model | 14.009M | 2.222G | 3.689M |
| patch_embed | 4.8K | 20.84M | 0.393M |
| patch_embed.proj | 4.608K | 18.874M | 0.393M |
| patch_embed.proj.weight | (96, 3, 4, 4) | | |
| patch_embed.norm | 0.192K | 1.966M | 0 |
| patch_embed.norm.weight | (96,) | | |
| patch_embed.norm.bias | (96,) | | |
| stages | 11.74M | 2.199G | 3.293M |
| stages.0.blocks.0 | 0.12M | 0.495G | 1.573M |
| stages.0.blocks.0.mlp | 37.248K | 0.155G | 1.18M |
| stages.0.blocks.0.spatial_mixing.conv | 82.944K | 0.34G | 0.393M |
| stages.1 | 74.112K | 76.481M | 0.197M |
| stages.1.reduction | 73.728K | 75.497M | 0.197M |
| stages.1.norm | 0.384K | 0.983M | 0 |
| stages.2.blocks.0 | 0.48M | 0.493G | 0.786M |
| stages.2.blocks.0.mlp | 0.148M | 0.153G | 0.59M |
| stages.2.blocks.0.spatial_mixing.conv | 0.332M | 0.34G | 0.197M |
| stages.3 | 0.296M | 75.989M | 98.304K |
| stages.3.reduction | 0.295M | 75.497M | 98.304K |
| stages.3.norm | 0.768K | 0.492M | 0 |
| stages.4.blocks.0 | 1.918M | 0.492G | 0.393M |
| stages.4.blocks.0.mlp | 0.591M | 0.152G | 0.295M |
| stages.4.blocks.0.spatial_mixing.conv | 1.327M | 0.34G | 98.304K |
| stages.5 | 1.181M | 75.743M | 49.152K |
| stages.5.reduction | 1.18M | 75.497M | 49.152K |
| stages.5.norm | 1.536K | 0.246M | 0 |
| stages.6.blocks.0 | 7.671M | 0.491G | 0.197M |
| stages.6.blocks.0.mlp | 2.362M | 0.151G | 0.147M |
| stages.6.blocks.0.spatial_mixing.conv | 5.308M | 0.34G | 49.152K |
| avgpool_pre_head | 0.983M | 1.032M | 1.28K |
| avgpool_pre_head.1 | 0.983M | 0.983M | 1.28K |
| avgpool_pre_head.1.weight | (1280, 768, 1, 1) | | |
| avgpool_pre_head.0 | | 49.152K | 0 |
| head | 1.281M | 1.28M | 1K |
| head.weight | (1000, 1280) | | |
| head.bias | (1000,) | | |
torch.Size([1, 1000])
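The jump from 0.948G to 2.222G can also be checked analytically. Replacing each PConv (c/4 channels) with a full 3×3 conv multiplies its FLOPs by 4² = 16. The sketch below assumes the same setup as Section 3.2. Its only external number is the 0.948G total reported by fvcore for the PConv model.

```python
REPORTED_PCONV_MODEL_GFLOPS = 0.948  # total from the fvcore table in Section 3.2


def spatial_mixing_flops(img=256, embed_dim=96, num_stages=4, n_div=4, k=3):
    """(PConv MACs, full-conv MACs) of the spatial-mixing layer in each stage."""
    hw, c = img // 4, embed_dim  # PatchEmbed has stride 4
    rows = []
    for _ in range(num_stages):
        cp = c // n_div
        rows.append((hw * hw * k * k * cp * cp, hw * hw * k * k * c * c))
        hw //= 2  # PatchMerging halves the resolution
        c *= 2    # and doubles the channels
    return rows


def extra_flops_from_full_conv(rows):
    """FLOPs added by swapping every PConv for a full conv."""
    return sum(full - partial for partial, full in rows)


if __name__ == "__main__":
    rows = spatial_mixing_flops()
    for i, (partial, full) in enumerate(rows):
        print(f"stage {i}: pconv={partial / 1e6:.3f}M conv={full / 1e9:.3f}G ratio={full // partial}")
    extra = extra_flops_from_full_conv(rows)
    print(f"estimated total: {REPORTED_PCONV_MODEL_GFLOPS + extra / 1e9:.3f}G")
```

Each stage adds about 0.318G. Four stages add 1.274G, which brings 0.948G to an estimated 2.222G, matching the fvcore total above.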
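3.4 Overall conclusions

Based on the analysis in Sections 3.1 to 3.3, we cannot simply replace every conv in a model with PConv. PConv can, however, replace individual high-FLOPs convs in some layers. It is only one way of approximating Conv, and it works within FasterNet's overall architecture. Dropped directly into other models it will likely be a poor fit, because an extra PWConv layer is needed for channel mixing.

Still, FasterNet gives us a strong backbone. It achieves the fastest speed at its accuracy level among both lightweight and large models, and it can be used for image classification and object detection. In our own experiments it may be worth replacing PConv in FasterNet with DWConv, which might further strengthen the backbone. The authors did not run this comparison; it is even possible that they found PConv to be worse than DWConv and left those results out.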