LeMeViT: An Efficient ViT with Learnable Meta Tokens

This paper proposes learnable meta tokens to form sparse tokens, which effectively capture the key information while improving inference speed. Technically, the meta tokens are first initialized from the image tokens via cross-attention. Dual Cross-Attention (DCA) is then proposed to promote information exchange between image tokens and meta tokens, where the two alternately serve as the query and key (value) tokens in a dual-branch structure, significantly reducing the computational complexity compared with self-attention. By using DCA in the early stages, where the visual tokens are dense, hierarchical LeMeViT models of different sizes are obtained. Experimental results on classification and dense prediction tasks show that, compared with the baseline, LeMeViT achieves a notable 1.7x speedup with fewer parameters and competitive performance, striking a better trade-off between efficiency and performance.

Existing methods usually reduce the number of image tokens within the current block by downsampling or clustering, which relies on strong priors or is unfriendly to parallel computation. LeMeViT instead learns meta tokens to sparsely represent the dense image tokens. The meta tokens exchange information with the image tokens in an end-to-end manner through computation-efficient dual cross-attention blocks, promoting information flow across stages.
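
To make the complexity advantage concrete, here is a rough back-of-the-envelope sketch (my own illustration, not code from the paper): standard self-attention over N image tokens costs on the order of N²·d multiply-accumulates for the attention matrices, while the two branches of dual cross-attention between N image tokens and M meta tokens cost on the order of 2·N·M·d, which is far smaller when M is much smaller than N.

def self_attn_macs(N, d):
    # Q @ K^T and attn @ V over N tokens: two N x N x d matrix products
    return 2 * N * N * d

def dual_cross_attn_macs(N, M, d):
    # two cross-attention branches, each with an N x M attention matrix
    return 2 * (2 * N * M * d)

# e.g. 56 x 56 = 3136 image tokens, 16 meta tokens, head dim 32
print(self_attn_macs(3136, 32))            # ~6.3e8 MACs
print(dual_cross_attn_macs(3136, 16, 32))  # ~6.4e6 MACs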


Overall structure of LeMeViT: the network is built from three kinds of attention blocks, arranged from left to right as the cross-attention block, the dual cross-attention block, and the standard attention block.

Implementation in code:
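
The snippets below are excerpted from the model file. To keep them self-contained, the following imports and back-end flags are assumed (a sketch based on the layout of the official repository; the optional back-ends and the fallback for checkpoint_wrapper are my assumptions and may differ in your environment):

import math
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg

# Optional attention back-ends; the has_* flags gate which implementation is used.
try:
    import xformers.ops as xops
    has_xformers = True
except ImportError:
    has_xformers = False

try:
    from flash_attn import (flash_attn_func, flash_attn_qkvpacked_func,
                            flash_attn_kvpacked_func)
    has_flash_attn = True
except ImportError:
    has_flash_attn = False

# F.scaled_dot_product_attention is available from PyTorch 2.0 onwards
has_torchfunc = hasattr(F, "scaled_dot_product_attention")

# checkpoint_wrapper is only needed when use_checkpoint_stages is non-empty (assumption)
try:
    from fairscale.nn.checkpoint import checkpoint_wrapper
except ImportError:
    checkpoint_wrapper = lambda module: module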

def scaled_dot_product_attention(q, k, v, scale=None):
    """Custom Scaled-Dot Product Attention
        dim (B h N d)
    """
    _,_,_,dim = q.shape
    scale = scale or dim**(-0.5)
    attn = q @ k.transpose(-1,-2) * scale
    attn = attn.softmax(dim=-1)
    x = attn @ v
    return x
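
A quick shape check of this helper (a minimal sketch): the tensors follow the (B, h, N, d) layout from the docstring, and queries and keys may have different sequence lengths, which is exactly what the cross-attention variants below rely on.

q = torch.randn(2, 4, 196, 32)  # image-token queries: (B, h, N, d)
k = torch.randn(2, 4, 16, 32)   # meta-token keys:     (B, h, M, d)
v = torch.randn(2, 4, 16, 32)   # meta-token values:   (B, h, M, d)
out = scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 4, 196, 32])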


class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x):
        """
        x: NHWC tensor
        """
        x = x.permute(0, 3, 1, 2) #NCHW
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1) #NHWC

        return x

class Attention(nn.Module):
    """Patch-to-Cluster Attention Layer"""
    
    def __init__(
        self,
        dim,
        num_heads,
        attn_drop=0.0,
        proj_drop=0.0,
        **kwargs,
    ):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} not divisible by num_heads {num_heads}"
        self.num_heads = num_heads

        self.use_xformers = has_xformers and (dim // num_heads) % 32 == 0

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.attn_drop = attn_drop

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.attn_viz = nn.Identity() 

    def forward(self, x):
        if self.use_xformers:
            q = self.q(x)  # B N C
            k = self.k(x)  # B N C
            v = self.v(x)
            q = rearrange(q, "B N (h d) -> B N h d", h=self.num_heads)
            k = rearrange(k, "B N (h d) -> B N h d", h=self.num_heads)
            v = rearrange(v, "B N (h d) -> B N h d", h=self.num_heads)

            x = xops.memory_efficient_attention(q, k, v)  # B N h d
            x = rearrange(x, "B N h d -> B N (h d)")

            x = self.proj(x)
        else:
            x = rearrange(x, "B N C -> N B C")

            x, attn = F.multi_head_attention_forward(
                query=x,
                key=x,
                value=x,
                embed_dim_to_check=x.shape[-1],
                num_heads=self.num_heads,
                q_proj_weight=self.q.weight,
                k_proj_weight=self.k.weight,
                v_proj_weight=self.v.weight,
                in_proj_weight=None,
                in_proj_bias=torch.cat([self.q.bias, self.k.bias, self.v.bias]),
                bias_k=None,
                bias_v=None,
                add_zero_attn=False,
                dropout_p=self.attn_drop,
                out_proj_weight=self.proj.weight,
                out_proj_bias=self.proj.bias,
                use_separate_proj_weight=True,
                training=self.training,
                need_weights=not self.training,  # for visualization
                average_attn_weights=False,
            )

            x = rearrange(x, "N B C -> B N C")

            if not self.training:
                attn = self.attn_viz(attn)

        x = self.proj_drop(x)

        return x

class StandardAttention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        scale = None,
        bias = False,
        attn_drop=0.0,
        proj_drop=0.0,
        **kwargs,
    ):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} not divisible by num_heads {num_heads}"
        self.num_heads = num_heads
        
        self.use_flash_attn = has_flash_attn
        self.use_xformers = has_xformers and (dim // num_heads) % 32 == 0
        self.use_torchfunc = has_torchfunc

        self.qkv = nn.Linear(dim, 3 * dim)
        self.attn_drop = attn_drop

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.attn_viz = nn.Identity() 
        self.scale = dim**0.5
        
    # @get_local('attn_map')
    def forward(self, x):
        if self.use_flash_attn:
            qkv = self.qkv(x)
            qkv = rearrange(qkv, "B N (x h d) -> B N x h d", x=3, h=self.num_heads).contiguous()
            x = flash_attn_qkvpacked_func(qkv)
            x = rearrange(x, "B N h d -> B N (h d)").contiguous()
            x = self.proj(x)
        elif self.use_xformers:
            qkv = self.qkv(x)
            qkv = rearrange(qkv, "B N (x h d) -> x B N h d", x=3, h=self.num_heads).contiguous()
            q, k, v = qkv[0], qkv[1], qkv[2]
            x = xops.memory_efficient_attention(q, k, v)  # B N h d
            x = rearrange(x, "B N h d -> B N (h d)").contiguous()
            x = self.proj(x)
        elif self.use_torchfunc:
            qkv = self.qkv(x)
            qkv = rearrange(qkv, "B N (x h d) -> x B h N d", x=3, h=self.num_heads).contiguous()
            q, k, v = qkv[0], qkv[1], qkv[2]
            x = F.scaled_dot_product_attention(q, k, v)  # B N h d
            x = rearrange(x, "B h N d -> B N (h d)").contiguous()
            x = self.proj(x)
        else:
            qkv = self.qkv(x)
            qkv = rearrange(qkv, "B N (x h d) -> x B h N d", x=3, h=self.num_heads).contiguous()
            q, k, v = qkv[0], qkv[1], qkv[2]
            x = scaled_dot_product_attention(q, k, v)  # B N h d
            x = rearrange(x, "B h N d -> B N (h d)").contiguous()
            x = self.proj(x)
        # with torch.no_grad():
        #     attn = (q @ k.transpose(-2, -1)) * self.scale
        #     attn_map = attn.softmax(dim=-1)
        #     # print("Standard:", attn_map)
        return x


class DualCrossAttention(nn.Module):
    
    def __init__(
        self,
        dim,
        num_heads,
        scale = None,
        bias = False,
        attn_drop=0.0,
        proj_drop=0.0,
        **kwargs,
    ):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} not divisible by num_heads {num_heads}"
        self.num_heads = num_heads
        self.scale = scale or dim**(-0.5)

        self.use_flash_attn = has_flash_attn
        self.use_xformers = has_xformers and (dim // num_heads) % 32 == 0
        self.use_torchfunc = has_torchfunc

        self.qkv1 = nn.Linear(dim, 3 * dim)
        self.qkv2 = nn.Linear(dim, 3 * dim)
        self.attn_drop = nn.Dropout(attn_drop)

        self.proj_x = nn.Linear(dim, dim)
        self.proj_c = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.attn_viz = nn.Identity() 
        
    # @get_local('attn_map')
    def forward(self, x, c):
        B, N, C = x.shape        
        B, M, _ = c.shape 
        scale_x = math.log(M, N) * self.scale
        scale_c = math.log(N, N) * self.scale
        
        if self.use_flash_attn:
            qkv1 = self.qkv1(x)
            qkv1 = rearrange(qkv1, "B N (x h d) -> B N x h d", x=3, h=self.num_heads).contiguous()
            qkv2 = self.qkv2(c)
            qkv2 = rearrange(qkv2, "B M (x h d) -> B M x h d", x=3, h=self.num_heads).contiguous()
            
            q1, kv1 = qkv1[:,:,0], qkv1[:,:,1:]
            q2, kv2 = qkv2[:,:,0], qkv2[:,:,1:]
            
            x = flash_attn_kvpacked_func(q1, kv2, softmax_scale=scale_x)
            x = rearrange(x, "B N h d -> B N (h d)").contiguous()
            x = self.proj_x(x)
            c = flash_attn_kvpacked_func(q2, kv1, softmax_scale=scale_c)
            c = rearrange(c, "B M h d -> B M (h d)").contiguous()
            c = self.proj_c(c)
        elif self.use_xformers:
            qkv1 = self.qkv1(x)
            qkv1 = rearrange(qkv1, "B N (x h d) -> x B N h d", x=3, h=self.num_heads).contiguous()
            qkv2 = self.qkv2(c)
            qkv2 = rearrange(qkv2, "B M (x h d) -> x B M h d", x=3, h=self.num_heads).contiguous()
            
            q1, k1, v1 = qkv1[0], qkv1[1], qkv1[2]
            q2, k2, v2 = qkv2[0], qkv2[1], qkv2[2]
            
            x = xops.memory_efficient_attention(q1, k2, v2, scale=scale_x)  # B N h d
            x = rearrange(x, "B N h d -> B N (h d)").contiguous()
            x = self.proj_x(x)
            c = xops.memory_efficient_attention(q2, k1, v1, scale=scale_c)  # B N h d
            c = rearrange(c, "B M h d -> B M (h d)").contiguous()
            c = self.proj_c(c)
        elif self.use_torchfunc:
            qkv1 = self.qkv1(x)
            qkv1 = rearrange(qkv1, "B N (x h d) -> x B h N d", x=3, h=self.num_heads).contiguous()
            qkv2 = self.qkv2(c)
            qkv2 = rearrange(qkv2, "B M (x h d) -> x B h M d", x=3, h=self.num_heads).contiguous()
            
            q1, k1, v1 = qkv1[0], qkv1[1], qkv1[2]
            q2, k2, v2 = qkv2[0], qkv2[1], qkv2[2]
            
            x = F.scaled_dot_product_attention(q1, k2, v2)  # B N h d
            x = rearrange(x, "B h N d -> B N (h d)").contiguous()
            x = self.proj_x(x)
            c = F.scaled_dot_product_attention(q2, k1, v1)  # B N h d
            c = rearrange(c, "B h M d -> B M (h d)").contiguous()
            c = self.proj_c(c)
        else:
            qkv1 = self.qkv1(x)
            qkv1 = rearrange(qkv1, "B N (x h d) -> x B h N d", x=3, h=self.num_heads).contiguous()
            qkv2 = self.qkv2(c)
            qkv2 = rearrange(qkv2, "B M (x h d) -> x B h M d", x=3, h=self.num_heads).contiguous()
            
            q1, k1, v1 = qkv1[0], qkv1[1], qkv1[2]
            q2, k2, v2 = qkv2[0], qkv2[1], qkv2[2]
            
            x = scaled_dot_product_attention(q1, k2, v2, scale=scale_x)  # B N h d
            x = rearrange(x, "B h N d -> B N (h d)").contiguous()
            x = self.proj_x(x)
            c = scaled_dot_product_attention(q2, k1, v1, scale=scale_c)  # B N h d
            c = rearrange(c, "B h M d -> B M (h d)").contiguous()
            c = self.proj_c(c)
        # with torch.no_grad():   
        #     # q1 = rearrange(q1, "B h M d -> B M (h d)").contiguous()
        #     # k2 = rearrange(k2, "B h M d -> B M (h d)").contiguous()
        #     attn = (q1 @ k2.transpose(-2, -1)) * scale_x
        #     attn_map = attn.softmax(dim=-1)
        #     # print("Mix:", attn_map)
        return x, c
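
A small sanity check of the dual branch (a sketch, assuming the optional flash-attn/xformers back-ends are unavailable so the torch or pure-PyTorch fallback is used); each token set keeps its own length:

dca = DualCrossAttention(dim=64, num_heads=2)
x = torch.randn(2, 196, 64)  # image tokens (B, N, C)
c = torch.randn(2, 16, 64)   # meta tokens  (B, M, C)
x_out, c_out = dca(x, c)
print(x_out.shape, c_out.shape)  # torch.Size([2, 196, 64]) torch.Size([2, 16, 64])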
    
class DualCrossAttention_v2(nn.Module):
    
    def __init__(
        self,
        dim,
        num_heads,
        scale = None,
        bias = False,
        attn_drop=0.0,
        proj_drop=0.0,
        **kwargs,
    ):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} not divisible by num_heads {num_heads}"
        self.num_heads = num_heads
        self.scale = scale or dim**(-0.5)

        self.use_flash_attn = has_flash_attn
        self.use_xformers = has_xformers and (dim // num_heads) % 32 == 0
        self.use_torchfunc = has_torchfunc
        
        self.qv1 = nn.Linear(dim, 2 * dim)
        self.kv2 = nn.Linear(dim, 2 * dim)
        self.attn_drop = nn.Dropout(attn_drop)

        self.proj_x = nn.Linear(dim, dim)
        self.proj_c = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.attn_viz = nn.Identity() 

    def forward(self, x, c):
        B, N, C = x.shape        
        B, M, _ = c.shape 
        scale_x = math.log(M, N) * self.scale
        scale_c = math.log(N, N) * self.scale
        
        if self.use_flash_attn:
            qv1 = self.qv1(x)
            qv1 = rearrange(qv1, "B N (x h d) -> B N x h d", x=2, h=self.num_heads).contiguous()
            kv2 = self.kv2(c)
            kv2 = rearrange(kv2, "B M (x h d) -> B M x h d", x=2, h=self.num_heads).contiguous()
            
            q, v1 = qv1[:,:,0], qv1[:,:,1]
            k, v2 = kv2[:,:,0], kv2[:,:,1]
            
            x = flash_attn_func(q, k, v2, softmax_scale=scale_x)
            x = rearrange(x, "B N h d -> B N (h d)").contiguous()
            x = self.proj_x(x)
            c = flash_attn_func(k, q, v1, softmax_scale=scale_c)
            c = rearrange(c, "B M h d -> B M (h d)").contiguous()
            c = self.proj_c(c)
        elif self.use_xformers:
            qv1 = self.qv1(x)
            qv1 = rearrange(qv1, "B N (x h d) -> x B h N d", x=2, h=self.num_heads).contiguous()
            kv2 = self.kv2(c)
            kv2 = rearrange(kv2, "B M (x h d) -> x B h M d", x=2, h=self.num_heads).contiguous()
            
            q, v1 = qv1[0], qv1[1]
            k, v2 = kv2[0], kv2[1]
            
            x = xops.memory_efficient_attention(q, k, v2, scale=scale_x)
            x = rearrange(x, "B h N d -> B N (h d)").contiguous()
            x = self.proj_x(x)
            c = xops.memory_efficient_attention(k, q, v1, scale=scale_c)
            c = rearrange(c, "B h M d -> B M (h d)").contiguous()
            c = self.proj_c(c)
        elif self.use_torchfunc:
            qv1 = self.qv1(x)
            qv1 = rearrange(qv1, "B N (x h d) -> x B h N d", x=2, h=self.num_heads).contiguous()
            kv2 = self.kv2(c)
            kv2 = rearrange(kv2, "B M (x h d) -> x B h M d", x=2, h=self.num_heads).contiguous()
            
            q, v1 = qv1[0], qv1[1]
            k, v2 = kv2[0], kv2[1]

            x = F.scaled_dot_product_attention(q, k, v2)
            x = rearrange(x, "B h N d -> B N (h d)").contiguous()
            x = self.proj_x(x)
            c = F.scaled_dot_product_attention(k, q, v1)
            c = rearrange(c, "B h M d -> B M (h d)").contiguous()
            c = self.proj_c(c)
        else:
            qv1 = self.qv1(x)
            qv1 = rearrange(qv1, "B N (x h d) -> x B h N d", x=2, h=self.num_heads).contiguous()
            kv2 = self.kv2(c)
            kv2 = rearrange(kv2, "B M (x h d) -> x B h M d", x=2, h=self.num_heads).contiguous()
            
            q, v1 = qv1[0], qv1[1]
            k, v2 = kv2[0], kv2[1]

            x = scaled_dot_product_attention(q, k, v2)
            x = rearrange(x, "B h N d -> B N (h d)").contiguous()
            x = self.proj_x(x)
            c = scaled_dot_product_attention(k, q, v1)
            c = rearrange(c, "B h M d -> B M (h d)").contiguous()
            c = self.proj_c(c)
        return x, c

class CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        scale = None,
        bias = False,
        attn_drop=0.0,
        proj_drop=0.0,
        **kwargs,
    ):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} not divisible by num_heads {num_heads}"
        self.num_heads = num_heads

        self.use_flash_attn = has_flash_attn
        self.use_xformers = has_xformers and (dim // num_heads) % 32 == 0
        self.use_torchfunc = has_torchfunc

        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, 2 * dim)
        self.attn_drop = attn_drop

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.attn_viz = nn.Identity() 


    def forward(self, x, c):
        B, N, C = x.shape        
        B, M, _ = c.shape 
        
        if self.use_flash_attn:
            q = self.q(c)
            kv = self.kv(x)
            q = rearrange(q, "B M (h d) -> B M h d", h=self.num_heads).contiguous()
            kv = rearrange(kv, "B N (x h d) -> B N x h d", x=2, h=self.num_heads).contiguous()
            
            c = flash_attn_kvpacked_func(q, kv)
            c = rearrange(c, "B M h d -> B M (h d)").contiguous()
            c = self.proj(c)
        elif self.use_xformers:
            q = self.q(c)
            kv = self.kv(x)
            q = rearrange(q, "B M (h d) -> B M h d", h=self.num_heads).contiguous()
            kv = rearrange(kv, "B N (x h d) -> x B N h d", x=2, h=self.num_heads).contiguous()
            k, v = kv[0], kv[1]
            
            c = xops.memory_efficient_attention(q, k, v)
            c = rearrange(c, "B M h d -> B M (h d)").contiguous()
            c = self.proj(c)
        elif self.use_torchfunc:
            q = self.q(c)
            kv = self.kv(x)
            q = rearrange(q, "B M (h d) -> B h M d", h=self.num_heads).contiguous()
            kv = rearrange(kv, "B N (x h d) -> x B h N d", x=2, h=self.num_heads).contiguous()
            k, v = kv[0], kv[1]
            
            c = F.scaled_dot_product_attention(q, k, v)
            c = rearrange(c, "B h M d -> B M (h d)").contiguous()
            c = self.proj(c)
        else:
            q = self.q(c)
            kv = self.kv(x)
            q = rearrange(q, "B M (h d) -> B h M d", h=self.num_heads).contiguous()
            kv = rearrange(kv, "B N (x h d) -> x B h N d", x=2, h=self.num_heads).contiguous()
            k, v = kv[0], kv[1]
            
            c = scaled_dot_product_attention(q, k, v)
            c = rearrange(c, "B h M d -> B M (h d)").contiguous()
            c = self.proj(c)
        return c
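
In LeMeViT this block is used in the first stage to initialize the meta tokens from the dense image tokens: the meta tokens act as queries while the image tokens provide keys and values. A minimal usage sketch (again assuming the PyTorch fallback path):

ca = CrossAttention(dim=64, num_heads=2)
x = torch.randn(2, 3136, 64)  # dense image tokens from the stem (B, N, C)
c = torch.randn(2, 16, 64)    # learnable meta tokens (B, M, C)
c = ca(x, c)
print(c.shape)  # torch.Size([2, 16, 64])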


class LeMeBlock(nn.Module):
    def __init__(self, dim, 
                 attn_drop, proj_drop, drop_path=0., attn_type=None,
                 layer_scale_init_value=-1, num_heads=8, qk_dim=None, mlp_ratio=4, mlp_dwconv=False,
                 cpe_ks=3, pre_norm=True):
        super().__init__()
        qk_dim = qk_dim or dim

        # modules
        if cpe_ks > 0:
            self.pos_embed = nn.Conv2d(dim, dim,  kernel_size=cpe_ks, padding=1, groups=dim)
        else:
            self.pos_embed = lambda x: 0
        self.norm1 = nn.LayerNorm(dim, eps=1e-6) # important to avoid attention collapsing
        
        self.attn_type = attn_type
        if attn_type == "D":
            self.attn = DualCrossAttention(dim=dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=proj_drop)
        elif attn_type == "D2":
            self.attn = DualCrossAttention_v2(dim=dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=proj_drop)
        elif attn_type == "S" or attn_type == None:
            self.attn = StandardAttention(dim=dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=proj_drop)
        elif attn_type == "C":
            self.attn = CrossAttention(dim=dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=proj_drop)
            
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = nn.Sequential(nn.Linear(dim, int(mlp_ratio*dim)),
                                 DWConv(int(mlp_ratio*dim)) if mlp_dwconv else nn.Identity(),
                                 nn.GELU(),
                                 nn.Linear(int(mlp_ratio*dim), dim)
                                )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # tricks: layer scale & pre_norm/post_norm
        if layer_scale_init_value > 0:
            self.use_layer_scale = True
            self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones((1,1,dim)), requires_grad=True)
            self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones((1,1,dim)), requires_grad=True)
        else:
            self.use_layer_scale = False
        self.pre_norm = pre_norm
            
    def forward_with_xc(self, x, c):

        _, C, H, W = x.shape
        # conv pos embedding
        x = x + self.pos_embed(x)
        # permute to NHWC tensor for attention & mlp
        x = rearrange(x, "N C H W -> N (H W) C")
        # x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)

        # attention & mlp
        if self.pre_norm:
            if self.use_layer_scale:
                _x, _c = self.attn(self.norm1(x), self.norm1(c))
                x = x + self.drop_path(self.gamma1 * _x) # (N, H, W, C)
                x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x))) # (N, H, W, C)
                c = c + self.drop_path(self.gamma1 * _c) # (N, H, W, C)
                c = c + self.drop_path(self.gamma2 * self.mlp(self.norm2(c))) # (N, H, W, C)
            else:
                _x, _c = self.attn(self.norm1(x), self.norm1(c))
                x = x + self.drop_path(_x) # (N, H, W, C)
                x = x + self.drop_path(self.mlp(self.norm2(x))) # (N, H, W, C)
                c = c + self.drop_path(_c) # (N, H, W, C)
                c = c + self.drop_path(self.mlp(self.norm2(c))) # (N, H, W, C)
        else: # https://kexue.fm/archives/9009
            if self.use_layer_scale:
                _x, _c = self.attn(x,c)
                x = self.norm1(x + self.drop_path(self.gamma1 * _x)) # (N, H, W, C)
                x = self.norm2(x + self.drop_path(self.gamma2 * self.mlp(x))) # (N, H, W, C)
                c = self.norm1(c + self.drop_path(self.gamma1 * _c)) # (N, H, W, C)
                c = self.norm2(c + self.drop_path(self.gamma2 * self.mlp(c))) # (N, H, W, C)
            else:
                _x, _c = self.attn(x,c)
                x = self.norm1(x + self.drop_path(_x)) # (N, H, W, C)
                x = self.norm2(x + self.drop_path(self.mlp(x))) # (N, H, W, C)
                c = self.norm1(c + self.drop_path(_c)) # (N, H, W, C)
                c = self.norm2(c + self.drop_path(self.mlp(c))) # (N, H, W, C)
                
        x = rearrange(x, "N (H W) C -> N C H W",H=H,W=W)
        # permute back
        # x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
        return x, c

    def forward_with_c(self, x, c):
        
        _, C, H, W = x.shape
        _x = x
        # conv pos embedding
        x = x + self.pos_embed(x)
        # permute to NHWC tensor for attention & mlp
        x = rearrange(x, "N C H W -> N (H W) C")
        # x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)

        # attention & mlp
        if self.pre_norm:
            if self.use_layer_scale:
                c = c + self.drop_path(self.gamma1 * self.attn(self.norm1(x), self.norm1(c))) # (N, H, W, C)
                c = c + self.drop_path(self.gamma2 * self.mlp(self.norm2(c))) # (N, H, W, C)
            else:
                c = c + self.drop_path(self.attn(self.norm1(x),self.norm1(c))) # (N, H, W, C)
                c = c + self.drop_path(self.mlp(self.norm2(c))) # (N, H, W, C)
        else: # https://kexue.fm/archives/9009
            if self.use_layer_scale:
                c = self.norm1(c + self.drop_path(self.gamma1 * self.attn(x,c))) # (N, H, W, C)
                c = self.norm2(c + self.drop_path(self.gamma2 * self.mlp(c))) # (N, H, W, C)
            else:
                c = self.norm1(c + self.drop_path(self.attn(x,c))) # (N, H, W, C)
                c = self.norm2(c + self.drop_path(self.mlp(c))) # (N, H, W, C)

        x = _x
        # permute back
        # x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
        return x, c

    def forward_with_x(self, x, c):
        
        _, C, H, W = x.shape
        # conv pos embedding
        x = x + self.pos_embed(x)
        # permute to NHWC tensor for attention & mlp
        x = rearrange(x, "N C H W -> N (H W) C")
        # x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)

        # attention & mlp
        if self.pre_norm:
            if self.use_layer_scale:
                x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x))) # (N, H, W, C)
                x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x))) # (N, H, W, C)
                c = c + self.drop_path(self.gamma1 * self.attn(self.norm1(c))) # (N, H, W, C)
                c = c + self.drop_path(self.gamma2 * self.mlp(self.norm2(c))) # (N, H, W, C)
            else:
                x = x + self.drop_path(self.attn(self.norm1(x))) # (N, H, W, C)
                x = x + self.drop_path(self.mlp(self.norm2(x))) # (N, H, W, C)
                c = c + self.drop_path(self.attn(self.norm1(c))) # (N, H, W, C)
                c = c + self.drop_path(self.mlp(self.norm2(c))) # (N, H, W, C)
        else: # https://kexue.fm/archives/9009
            if self.use_layer_scale:
                x = self.norm1(x + self.drop_path(self.gamma1 * self.attn(x))) # (N, H, W, C)
                x = self.norm2(x + self.drop_path(self.gamma2 * self.mlp(x))) # (N, H, W, C)
                c = self.norm1(c + self.drop_path(self.gamma1 * self.attn(c))) # (N, H, W, C)
                c = self.norm2(c + self.drop_path(self.gamma2 * self.mlp(c))) # (N, H, W, C)
            else:
                x = self.norm1(x + self.drop_path(self.attn(x))) # (N, H, W, C)
                x = self.norm2(x + self.drop_path(self.mlp(x))) # (N, H, W, C)
                c = self.norm1(c + self.drop_path(self.attn(c))) # (N, H, W, C)
                c = self.norm2(c + self.drop_path(self.mlp(c))) # (N, H, W, C)
        x = rearrange(x, "N (H W) C -> N C H W",H=H,W=W)
        # permute back
        # x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
        return x, c
    
    def forward(self, x, c):
        if self.attn_type == "D" or self.attn_type == "D2":
            return self.forward_with_xc(x,c)
        elif self.attn_type == "S":
            return self.forward_with_x(x,c)
        elif self.attn_type == "C":
            return self.forward_with_c(x,c)
        else:
            raise NotImplementedError("Attention type does not exist")
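
A usage sketch for a single block (my own example, not from the repository): the image tokens enter as an NCHW feature map and the meta tokens as a (B, M, C) sequence, and both shapes are preserved by a dual cross-attention block:

block = LeMeBlock(dim=64, attn_drop=0.0, proj_drop=0.0, attn_type="D", num_heads=2)
x = torch.randn(2, 64, 14, 14)  # image feature map (N, C, H, W)
c = torch.randn(2, 16, 64)      # meta tokens (N, M, C)
x, c = block(x, c)
print(x.shape, c.shape)  # torch.Size([2, 64, 14, 14]) torch.Size([2, 16, 64])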


class LeMeViT(nn.Module):
    def __init__(self, 
                 depth=[2, 3, 4, 8, 3], 
                 in_chans=3, 
                 num_classes=1000, 
                 embed_dim=[64, 64, 128, 320, 512], 
                 head_dim=64, 
                 mlp_ratios=[4, 4, 4, 4, 4], 
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop=0., 
                 drop_path_rate=0.,
                 # <<<------
                 attn_type=["C","D","D","S","S"],
                 queries_len=128,
                 qk_dims=None,
                 cpe_ks=3,
                 pre_norm=True,
                 mlp_dwconv=False,
                 representation_size=None,
                 layer_scale_init_value=-1,
                 use_checkpoint_stages=[],
                 # ------>>>
                 ):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        qk_dims = qk_dims or embed_dim
        
        self.num_stages = len(attn_type)
        
        ############ downsample layers (patch embeddings) ######################
        self.downsample_layers = nn.ModuleList()
        # NOTE: uniformer uses two 3*3 conv, while in many other transformers this is one 7*7 conv 
        stem = nn.Sequential(
            nn.Conv2d(in_chans, embed_dim[0] // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(embed_dim[0] // 2),
            nn.GELU(),
            nn.Conv2d(embed_dim[0] // 2, embed_dim[0], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(embed_dim[0]),
        )

        if use_checkpoint_stages:
            stem = checkpoint_wrapper(stem)
        self.downsample_layers.append(stem)

        for i in range(self.num_stages-1):
            if attn_type[i] == "C":
                downsample_layer = nn.Identity()
            else:
                downsample_layer = nn.Sequential(
                    nn.Conv2d(embed_dim[i], embed_dim[i+1], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
                    nn.BatchNorm2d(embed_dim[i+1])
                )
            if use_checkpoint_stages:
                downsample_layer = checkpoint_wrapper(downsample_layer)
            self.downsample_layers.append(downsample_layer)
        ##########################################################################


        #TODO: maybe remove last LN
        self.queries_len = queries_len
        self.meta_tokens = nn.Parameter(torch.randn(self.queries_len ,embed_dim[0]), requires_grad=True) 
        
        self.meta_token_downsample = nn.ModuleList()
        meta_token_downsample = nn.Sequential(
            nn.Linear(embed_dim[0], embed_dim[0] * 4),
            nn.LayerNorm(embed_dim[0] * 4),
            nn.GELU(),
            nn.Linear(embed_dim[0] * 4, embed_dim[0]),
            nn.LayerNorm(embed_dim[0])
        )
        self.meta_token_downsample.append(meta_token_downsample)
        for i in range(self.num_stages-1):
            meta_token_downsample = nn.Sequential(
                nn.Linear(embed_dim[i], embed_dim[i] * 4),
                nn.LayerNorm(embed_dim[i] * 4),
                nn.GELU(),
                nn.Linear(embed_dim[i] * 4, embed_dim[i+1]),
                nn.LayerNorm(embed_dim[i+1])
            )
            self.meta_token_downsample.append(meta_token_downsample)

        
        self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
        nheads= [dim // head_dim for dim in qk_dims]
        dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))] 
        cur = 0
        for i in range(self.num_stages):
            stage = nn.ModuleList(
                [LeMeBlock(dim=embed_dim[i], 
                           attn_drop=attn_drop, proj_drop=drop_rate,
                           drop_path=dp_rates[cur + j],
                           attn_type=attn_type[i],
                           layer_scale_init_value=layer_scale_init_value,
                           num_heads=nheads[i],
                           qk_dim=qk_dims[i],
                           mlp_ratio=mlp_ratios[i],
                           mlp_dwconv=mlp_dwconv,
                           cpe_ks=cpe_ks,
                           pre_norm=pre_norm
                    ) for j in range(depth[i])],
            )
            if i in use_checkpoint_stages:
                stage = checkpoint_wrapper(stage)
            self.stages.append(stage)
            cur += depth[i]

        ##########################################################################
        self.norm = nn.BatchNorm2d(embed_dim[-1])
        self.norm_c = nn.LayerNorm(embed_dim[-1])
        # Representation layer
        if representation_size:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head
        self.head = nn.Linear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x, c):
        for i in range(self.num_stages): 
            x = self.downsample_layers[i](x)
            c = self.meta_token_downsample[i](c)
            for j, block in enumerate(self.stages[i]):
                x, c = block(x, c)
        x = self.norm(x)
        x = self.pre_logits(x)
        
        c = self.norm_c(c)
        c = self.pre_logits(c)

        # x = x.flatten(2).mean(-1,keepdim=True)
        # c = c.transpose(-2,-1).contiguous().mean(-1,keepdim=True)
        # x = torch.concat([x,c],dim=-1).mean(-1)

        x = x.flatten(2).mean(-1)
        c = c.transpose(-2,-1).contiguous().mean(-1)
        x = x + c

        return x

    def forward(self, x):
        B, _, H, W = x.shape 
        c = self.meta_tokens.repeat(B,1,1)
        x = self.forward_features(x, c)
        x = self.head(x)
        return x

On top of this overall architecture, the authors design three model sizes, Tiny, Small, and Base, obtained by adjusting the number of blocks per stage and the feature dimensions; the remaining configuration is shared across all variants. The head dimension of each attention is set to 32, the MLP expansion ratio to 4, and the conditional positional encoding kernel size to 3. The length of the meta tokens is set to 16.

@register_model
def lemevit_tiny(pretrained=False, pretrained_cfg=None,
                  pretrained_cfg_overlay=None, **kwargs):
    model = LeMeViT(
        depth=[1, 2, 2, 8, 2],
        embed_dim=[64, 64, 128, 192, 320], 
        head_dim=32,
        mlp_ratios=[4, 4, 4, 4, 4],
        attn_type=["C","D","D","S","S"],
        queries_len=16,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.,
        qk_dims=None,
        cpe_ks=3,
        pre_norm=True,
        mlp_dwconv=False,
        representation_size=None,
        layer_scale_init_value=-1,
        use_checkpoint_stages=[],
        **kwargs)
    model.default_cfg = _cfg()

    if pretrained:
        checkpoint = torch.load(pretrained, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])

    return model


@register_model
def lemevit_small(pretrained=False, pretrained_cfg=None,
                  pretrained_cfg_overlay=None, **kwargs):
    model = LeMeViT(
        depth=[1, 2, 2, 6, 2],
        embed_dim=[96, 96, 192, 320, 384], 
        head_dim=32,
        mlp_ratios=[4, 4, 4, 4, 4],
        attn_type=["C","D","D","S","S"],
        queries_len=16,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.,
        qk_dims=None,
        cpe_ks=3,
        pre_norm=True,
        mlp_dwconv=False,
        representation_size=None,
        layer_scale_init_value=-1,
        use_checkpoint_stages=[],
        **kwargs)
    model.default_cfg = _cfg()

    if pretrained:
        checkpoint = torch.load(pretrained, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])

    return model


@register_model
def lemevit_base(pretrained=False, pretrained_cfg=None,
                  pretrained_cfg_overlay=None, **kwargs):
    model = LeMeViT(
        depth=[2, 4, 4, 18, 4],
        embed_dim=[96, 96, 192, 384, 512], 
        head_dim=32,
        mlp_ratios=[4, 4, 4, 4, 4],
        attn_type=["C","D","D","S","S"],
        queries_len=16,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.,
        qk_dims=None,
        cpe_ks=3,
        pre_norm=True,
        mlp_dwconv=False,
        representation_size=None,
        layer_scale_init_value=-1,
        use_checkpoint_stages=[],
        **kwargs)
    model.default_cfg = _cfg()

    if pretrained:
        checkpoint = torch.load(pretrained, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])

    return model
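
A quick smoke test of a registered variant (a sketch with random weights; on CPU this relies on the PyTorch attention fallback):

model = lemevit_tiny()
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])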

Performance test

Now let's give the model a quick try on image classification; I pick the tiny and small variants. The experiments use my self-built grape dataset with multiple lighting conditions: 4 classes under 3 lighting conditions, 12 subsets in total, split into training and test sets at a ratio of 8:1.

My training code:

import json
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from timm.utils import accuracy, AverageMeter, ModelEma
from sklearn.metrics import classification_report
from timm.data.mixup import Mixup
from timm.loss import SoftTargetCrossEntropy
from models.lemevit import lemevit_small_v2
from torch.autograd import Variable
from torchvision import datasets
torch.backends.cudnn.benchmark = False
import warnings
warnings.filterwarnings("ignore")
os.environ['CUDA_VISIBLE_DEVICES']="0,1"
import pandas as pd
from torchvision.transforms import RandAugment
# Training loop
def train(model, device, train_loader, optimizer, epoch,model_ema):
    model.train()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device, non_blocking=True), Variable(target).to(device,non_blocking=True)
        samples, targets = mixup_fn(data, target)
        optimizer.zero_grad()
        if use_amp:
            # run the forward pass on the mixed samples inside autocast, and unscale before gradient clipping
            with torch.cuda.amp.autocast():
                output = model(samples)
                loss = torch.nan_to_num(criterion_train(output, targets))
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRAD)
            scaler.step(optimizer)
            scaler.update()
        else:
            output = model(samples)
            loss = criterion_train(output, targets)
            loss.backward()
            optimizer.step()

        if model_ema is not None:
            model_ema.update(model)
        torch.cuda.synchronize()
        lr = optimizer.state_dict()['param_groups'][0]['lr']
        loss_meter.update(loss.item(), target.size(0))
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        acc1_meter.update(acc1.item(), target.size(0))
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLR:{:.9f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                       100. * (batch_idx + 1) / len(train_loader), loss.item(), lr))
    ave_loss =loss_meter.avg
    acc = acc1_meter.avg
    print('epoch:{}\tloss:{:.2f}\tacc:{:.2f}'.format(epoch, ave_loss, acc))
    return ave_loss, acc


# Validation loop
@torch.no_grad()
def val(model, device, test_loader):
    global Best_ACC
    model.eval()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    val_list = []
    pred_list = []

    for data, target in test_loader:
        for t in target:
            val_list.append(t.data.item())
        data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
        output = model(data)
        loss = criterion_val(output, target)
        _, pred = torch.max(output.data, 1)
        for p in pred:
            pred_list.append(p.data.item())
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))
    acc = acc1_meter.avg
    print('\nVal set: Average loss: {:.4f}\tAcc1:{:.3f}%\tAcc5:{:.3f}%\n'.format(
        loss_meter.avg, acc, acc5_meter.avg))
    if acc > Best_ACC:
        if isinstance(model, torch.nn.DataParallel):
            torch.save(model.module, file_dir + '/' + 'best.pth')
        else:
            torch.save(model, file_dir + '/' + 'best.pth')
        Best_ACC = acc
    if isinstance(model, torch.nn.DataParallel):
        state = {

            'epoch': epoch,
            'state_dict': model.module.state_dict(),
            'Best_ACC': Best_ACC
        }
        if use_ema:
            state['state_dict_ema'] = model.module.state_dict()
        torch.save(state, file_dir + "/" + 'model_' + str(epoch) + '_' + str(round(acc, 3)) + '.pth')
    else:
        state = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'Best_ACC': Best_ACC
        }
        if use_ema:
            state['state_dict_ema'] = model.state_dict()
        torch.save(state, file_dir + "/" + 'model_' + str(epoch) + '_' + str(round(acc, 3)) + '.pth')
    return val_list, pred_list, loss_meter.avg, acc


def seed_everything(seed=0):
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


if __name__ == '__main__':
    file_dir = 'checkpoints/LEMEVIT-small/'
    os.makedirs(file_dir, exist_ok=True)

    model_lr = 1e-3
    BATCH_SIZE = 16
    EPOCHS = 50
    DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    use_amp =True
    use_dp = True
    classes = 4
    resume =None
    CLIP_GRAD = 5.0
    Best_ACC = 0
    use_ema=False
    model_ema_decay=0.9995
    start_epoch=1
    seed=1
    seed_everything(seed)
    transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std= [0.5, 0.5, 0.5])

    ])
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std= [0.5, 0.5, 0.5])
    ])
    mixup_fn = Mixup(
        mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
        prob=0.1, switch_prob=0.5, mode='batch',
        label_smoothing=0.1, num_classes=classes)

    dataset_train = datasets.ImageFolder('dataset/train', transform=transform)
    dataset_test = datasets.ImageFolder("dataset/val", transform=transform_test)
    with open('class.txt', 'w') as file:
        file.write(str(dataset_train.class_to_idx))
    with open('class.json', 'w', encoding='utf-8') as file:
        file.write(json.dumps(dataset_train.class_to_idx))
    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True,drop_last=True)
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

    criterion_train = SoftTargetCrossEntropy()
    criterion_val = torch.nn.CrossEntropyLoss()
    model_ft = lemevit_small_v2(pretrained=False)
    num_fr=model_ft.head.in_features
    model_ft.head =nn.Linear(num_fr,classes)
    print(model_ft)
    if resume:
        model=torch.load(resume)
        print(model['state_dict'].keys())
        model_ft.load_state_dict(model['state_dict'],strict = False)
        Best_ACC=model['Best_ACC']
        start_epoch=model['epoch']+1
    model_ft.to(DEVICE)
    optimizer = optim.AdamW(model_ft.parameters(),lr=model_lr)
    cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=40, eta_min=5e-8)
    if use_amp:
        scaler = torch.cuda.amp.GradScaler()
    if torch.cuda.device_count() > 1 and use_dp:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model_ft = torch.nn.DataParallel(model_ft)
    if use_ema:
        model_ema = ModelEma(
            model_ft,
            decay=model_ema_decay,
            device=DEVICE,
            resume=resume)
    else:
        model_ema=None

    # Training and validation
    is_set_lr = False
    log_dir = {}
    train_loss_list, val_loss_list, train_acc_list, val_acc_list, epoch_list = [], [], [], [], []
    epoch_info = []
    if resume and os.path.isfile(file_dir+"result.json"):
        with open(file_dir+'result.json', 'r', encoding='utf-8') as file:
            logs = json.load(file)
            train_acc_list = logs['train_acc']
            train_loss_list = logs['train_loss']
            val_acc_list = logs['val_acc']
            val_loss_list = logs['val_loss']
            epoch_list = logs['epoch_list']
    for epoch in range(start_epoch, EPOCHS + 1):
        epoch_list.append(epoch)
        log_dir['epoch_list'] = epoch_list
        train_loss, train_acc = train(model_ft, DEVICE, train_loader, optimizer, epoch,model_ema)
        train_loss_list.append(train_loss)
        train_acc_list.append(train_acc)
        log_dir['train_acc'] = train_acc_list
        log_dir['train_loss'] = train_loss_list
        if use_ema:
            val_list, pred_list, val_loss, val_acc = val(model_ema.ema, DEVICE, test_loader)
        else:
            val_list, pred_list, val_loss, val_acc = val(model_ft, DEVICE, test_loader)
        val_loss_list.append(val_loss)
        val_acc_list.append(val_acc)
        log_dir['val_acc'] = val_acc_list
        log_dir['val_loss'] = val_loss_list
        log_dir['best_acc'] = Best_ACC
        with open(file_dir + '/result.json', 'w', encoding='utf-8') as file:
            file.write(json.dumps(log_dir))
        print(classification_report(val_list, pred_list, target_names=list(dataset_train.class_to_idx)))
        epoch_info.append({
            'epoch': epoch,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc
        })
        df = pd.DataFrame(epoch_info)
        df.to_excel(file_dir + "/epoch_info.xlsx",index=False)
        with open('epoch_info.txt', 'w') as f:
            for epoch_data in epoch_info:
                f.write(f"Epoch: {epoch_data['epoch']}\n")
                f.write(f"Train Loss: {epoch_data['train_loss']}\n")
                f.write(f"Train Acc: {epoch_data['train_acc']}\n")
                f.write(f"Val Loss: {epoch_data['val_loss']}\n")
                f.write(f"Val Acc: {epoch_data['val_acc']}\n")
                f.write("\n")
        if epoch < 600:
            cosine_schedule.step()
        else:
            if not is_set_lr:
                for param_group in optimizer.param_groups:
                    param_group["lr"] = 1e-6
                    is_set_lr = True
        fig = plt.figure(1)
        plt.plot(epoch_list, train_loss_list, 'r-', label=u'Train Loss')
        # show the legend
        plt.plot(epoch_list, val_loss_list, 'b-', label=u'Val Loss')
        plt.legend(["Train Loss", "Val Loss"], loc="upper right")
        plt.xlabel(u'epoch')
        plt.ylabel(u'loss')
        plt.title('Model Loss ')
        plt.savefig(file_dir + "/loss.png")
        plt.close(1)
        fig2 = plt.figure(2)
        plt.plot(epoch_list, train_acc_list, 'g-', label=u'Train Acc')
        plt.plot(epoch_list, val_acc_list, 'y-', label=u'Val Acc')
        plt.legend(["Train Acc", "Val Acc"], loc="lower right")
        plt.title("Model Acc")
        plt.ylabel("acc")
        plt.xlabel("epoch")
        plt.savefig(file_dir + "/acc.png")
        plt.close(2)

Test code:

import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import numpy as np
from sklearn.metrics import recall_score, precision_score, f1_score, accuracy_score

# Check whether a GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define a custom dataset
class CustomDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(os.listdir(root_dir))
        self.image_paths = []
        self.labels = []

        for idx, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            for img_name in os.listdir(class_dir):
                img_path = os.path.join(class_dir, img_name)
                self.image_paths.append(img_path)
                self.labels.append(idx)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

# Image transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Load the dataset
dataset_root = 'dataset/sunlight'
dataset = CustomDataset(root_dir=dataset_root, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

# Load the model
model_path = 'checkpoints/LEMEVIT-small/best.pth'
model = torch.load(model_path)
model.to(device)
model.eval()

# Prediction function
def predict(model, dataloader, device):
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    return np.array(all_preds), np.array(all_labels)

# Run predictions
predictions, true_labels = predict(model, dataloader, device)

# Compute the metrics
recall = recall_score(true_labels, predictions, average='macro')
precision = precision_score(true_labels, predictions, average='macro')
f1 = f1_score(true_labels, predictions, average='macro')
accuracy = accuracy_score(true_labels, predictions)

# Get the class names
class_names = sorted(os.listdir(dataset_root))

print(f'Class names: {class_names}')
print(f'Recall: {recall:.4f}')
print(f'Precision: {precision:.4f}')
print(f'F1 Score: {f1:.4f}')
print(f'Accuracy: {accuracy:.4f}')

The results are as follows:

From the experimental results, LeMeViT-tiny and LeMeViT-small behave noticeably differently under different lighting conditions. LeMeViT-tiny scores slightly higher than LeMeViT-small on every metric in all conditions, and is especially stable under shadow and normal lighting. Under direct sunlight both models degrade, but LeMeViT-tiny still keeps relatively high precision and recall. Overall, LeMeViT-tiny adapts better to lighting changes and shows higher robustness. Of course, these results only hold for my dataset; feel free to run your own experiments.

Model repository: https://github.com/ViTAE-Transformer/LeMeViT

That's all for this article!
