CV05_深度学习模块之间的缝合教学(1)

1.1 在哪里缝

测试文件?(×)

训练文件?(×)

模型文件?(√)

1.2 骨干网络与模块缝合

以Vision Transformer为例,模型文件里有很多类,我们只在最后集大成的那个类里添加模块。

之后后,我们准备好我们要缝合的模块,比如SE Net模块,我们先建立一个测试文件测试能否跑通

import numpy as np
import torch
from torch import nn
from torch.nn import init

class SEAttention(nn.Module):
    # 初始化SE模块,channel为通道数,reduction为降维比率
    def __init__(self, channel=512, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # 自适应平均池化层,将特征图的空间维度压缩为1x1
        self.fc = nn.Sequential(  # 定义两个全连接层作为激励操作,通过降维和升维调整通道重要性
            nn.Linear(channel, channel // reduction, bias=False),  # 降维,减少参数数量和计算量
            nn.ReLU(inplace=True),  # ReLU激活函数,引入非线性
            nn.Linear(channel // reduction, channel, bias=False),  # 升维,恢复到原始通道数
            nn.Sigmoid()  # Sigmoid激活函数,输出每个通道的重要性系数
        )

    # 权重初始化方法
    def init_weights(self):
        for m in self.modules():  # 遍历模块中的所有子模块
            if isinstance(m, nn.Conv2d):  # 对于卷积层
                init.kaiming_normal_(m.weight, mode='fan_out')  # 使用Kaiming初始化方法初始化权重
                if m.bias is not None:
                    init.constant_(m.bias, 0)  # 如果有偏置项,则初始化为0
            elif isinstance(m, nn.BatchNorm2d):  # 对于批归一化层
                init.constant_(m.weight, 1)  # 权重初始化为1
                init.constant_(m.bias, 0)  # 偏置初始化为0
            elif isinstance(m, nn.Linear):  # 对于全连接层
                init.normal_(m.weight, std=0.001)  # 权重使用正态分布初始化
                if m.bias is not None:
                    init.constant_(m.bias, 0)  # 偏置初始化为0

    # 前向传播方法
    def forward(self, x):
        b, c, _, _ = x.size()  # 获取输入x的批量大小b和通道数c
        y = self.avg_pool(x).view(b, c)  # 通过自适应平均池化层后,调整形状以匹配全连接层的输入
        y = self.fc(y).view(b, c, 1, 1)  # 通过全连接层计算通道重要性,调整形状以匹配原始特征图的形状
        return x * y.expand_as(x)  # 将通道重要性系数应用到原始特征图上,进行特征重新校准

# 示例使用
if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)  # 随机生成一个输入特征图
    se = SEAttention(channel=512, reduction=8)  # 实例化SE模块,设置降维比率为8
    output = se(input)  # 将输入特征图通过SE模块进行处理
    print(output.shape)  # 打印处理后的特征图形状,验证SE模块的作用

打印处理后的形状,我们这里要注意,缝合模块时只需要注意第一维,也就是这个channel,要和骨干网络保持一致,只要你把输入输出的通道数对齐,那么这个通道数就可以缝合成功。

把模块复制进骨干网络中:

然后进行缝合,在缝合之前要先测试通道是否匹配,不然肯定报错。

如何验证通道数

我们找到骨干网络前向传播的部分,在你想加入这个模块地方print(x.shape)即可。运行训练文件:

放在最前面:

通道数为3(8为batch size)。

将模块添加进骨干网络

在骨干网络的init函数下添加:(ctrl+p可查看参数)通道数与之前查的对齐。

在前向传播中添加:

看看是否正常运行:

正常运行,说明模块缝合成功!

打印缝合后的模型结构

该操作在模型文件中进行。

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (2): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (3): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (4): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (5): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (6): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (7): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (8): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (9): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (10): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (11): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (pre_logits): Sequential(
    (fc): Linear(in_features=768, out_features=768, bias=True)
    (act): Tanh()
  )
  (head): Linear(in_features=768, out_features=21843, bias=True)
  (se): SEAttention(
    (avg_pool): AdaptiveAvgPool2d(output_size=1)
    (fc): Sequential(
      (0): Linear(in_features=3, out_features=0, bias=False)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=0, out_features=3, bias=False)
      (3): Sigmoid()
    )
  )
)

我们可以看到多了一个SEAttention,说明模块缝合进去了!

1.3 模块之间缝合

以SENet和ECA模块为例。

串联模块

方式1

同1.2。照猫画虎。(注意通道数保持一致)

打印模型结构:

ECAAttention(
  (gap): AdaptiveAvgPool2d(output_size=1)
  (conv): Conv1d(1, 1, kernel_size=(3,), stride=(1,), padding=(1,))
  (sigmoid): Sigmoid()
  (se): SEAttention(
    (avg_pool): AdaptiveAvgPool2d(output_size=1)
    (fc): Sequential(
      (0): Linear(in_features=64, out_features=4, bias=False)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=4, out_features=64, bias=False)
      (3): Sigmoid()
   )))

 方式2

我们定义一个串联函数,将模块之间串联起来:

实例化查看一下模型结构

输出结果:

torch.Size([1, 63, 64, 64]) torch.Size([1, 63, 64, 64])
Cascade(
  (se): SEAttention(
    (avg_pool): AdaptiveAvgPool2d(output_size=1)
    (fc): Sequential(
      (0): Linear(in_features=63, out_features=3, bias=False)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=3, out_features=63, bias=False)
      (3): Sigmoid()
    )
  )
  (eca): ECAAttention(
    (gap): AdaptiveAvgPool2d(output_size=1)
    (conv): Conv1d(1, 1, kernel_size=(63,), stride=(1,), padding=(31,))
    (sigmoid): Sigmoid()
  )
)

并联模块

对于并联模块,方法有很多种,两个两个模块输出的张量可以:

(1)逐元素相加(2)逐元素相乘(3)concat拼接(4)等等

输出结果:

torch.Size([1, 63, 64, 64]) torch.Size([1, 126, 64, 64])
Cascade(
  (se): SEAttention(
    (avg_pool): AdaptiveAvgPool2d(output_size=1)
    (fc): Sequential(
      (0): Linear(in_features=63, out_features=3, bias=False)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=3, out_features=63, bias=False)
      (3): Sigmoid()
    )
  )
  (eca): ECAAttention(
    (gap): AdaptiveAvgPool2d(output_size=1)
    (conv): Conv1d(1, 1, kernel_size=(63,), stride=(1,), padding=(31,))
    (sigmoid): Sigmoid()
  )
)

1.4 思考 

我们不要拘泥于只串联获并联,可以将二者结合,多个模块中,部分模块并联后又与其他模块串联,等等。。这种排列组合之后,总会有一个你想要的模型!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/790445.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Flutter——最详细(Table)网格、表格组件使用教程

背景 用于展示表格组件,可指定线宽、列宽、文字方向等属性 属性作用columnWidths列的宽度defaultVerticalAlignment网格内部组件摆放方向border网格样式修改children表格里面的组件textDirection文本排序方向 import package:flutter/material.dart;class CustomTa…

Mac 上安转文字转 SQL 利器 WrenAI

WrenAI 是一个开源的 Text-SQL 的工具,通过导入数据库结构,通过提问的方式生成 SQL。本文将讲述如何在 MacOS 上安装 WrenAI。要运行WrenAI,首先需要安装 Docker 桌面版。 下载 WrenAI https://github.com/Canner/WrenAI/releases/tag/0.7.…

开源流程表单设计器都有哪些值得一提的优势?

如果需要提质、增效、降本,不妨来了解下低代码技术平台、开源流程表单设计器的功能和优势特点。想要实现流程化办公,低代码技术平台是助力增效的理想工具。功能灵活、操作方便、好维修、可视化操作等优势都是其深受行业喜爱的优势特点。通过本文&#xf…

DDL也会有undo吗?模拟Oracle中DML、DDL与undo的关系,10046跟踪DDL语句

已经有两个月没有更新博客了,主要实在忙毕设和毕业的一些事情!这两个月也是非常的精彩呀,充分体会到了职场的和校园的不同,作为一名刚毕业就满 1 年工作经验的牛马人,在两个月期间经历了两次调岗、两次降薪&#xff0c…

一句歌词描述夏天

夏天总是带着一种奇特的魔力,既能让人沉醉在阳光和海浪的浪漫中,也能在炎热与燥热中让人心生烦闷。特别是在夏日里情绪低落时,那些可以抚平心情的歌曲显得尤为珍贵。音乐,这个神奇的存在,总能在最需要的时候带来心灵的…

使用AutoGPT构建智能体:从LSTM到Prompt编写实战教程001

如果报错,这里会有一个环境变量的设置需要设置上. 然后这一节我们来自己制作一个智能体,来感受一下,实际上现在,大模型还是可以做很多功能的. 可以看到上面是智能体的架构,之前也说过了, 上面这几个功能,如果用我们人类去操作,还是需要花些时间的,如果用大模型就快很多了. 以…

利用Python的sympy包求解一元多次方程

一元1次方程 import sympy as sp # 导入sympy包 x sp.Symbol(x) # 定义符号变量 f 2*x -8 # 定义要求解的一元1次方程 x sp.solve(f) # 调用solve函数求解方程 x[4]一元2次方程 import sympy as sp # 导入sympy包 x sp.Symbol(x) # 定义符号变量 f …

Nature Protocols:整合多组学并进行因果推理的系统框架

转载自:MetaAI 在生物学研究中,随着实验和计算技术的进步,生物系统研究产生了大量高通量数据。技术努力主要集中在提高吞吐量、降低成本和提升实验与计算效率。因此,整合不同类型组学数据,并通过关联分析识别关键因素…

[机器学习]-人工智能对程序员的深远影响——案例分析

机器学习和人工智能对未来程序员的深远影响 目录 机器学习和人工智能对未来程序员的深远影响1. **自动化编码任务**1.1 代码生成1.2 自动调试1.3 测试自动化 2. **提升开发效率**2.1 智能建议2.2 项目管理 3. **改变编程范式**3.1 数据驱动开发 4. **职业发展的新机遇**4.1 AI工…

大数据开发者:如何快速熟悉新公司的技术环境

目录 1. 了解系统架构实践建议:示例对话: 2. 了解领域模型实践建议:示例: 3. 了解代码结构实践建议:示例: 结语 作为一名大数据开发者,加入新公司后快速熟悉技术环境是一项重要而又具有挑战性的…

bev 之 fastBEV

前面我们提到bev 之 LSS, 知道视觉的BEV方案的主要痛点在于: 1、depth 的预测 2、图像特征到BEV特征之间的视图变换消耗大量计算 LSS 为什么需要D维深度 占据大量消耗的原因是LSS 对每个图像特征点引入深度D,即假设每个像素上存在可能的D维深度。也就是假设不同像…

C++ 栈-队列-优先级队列

目录 1 栈 2 队列 3 deque 介绍 4 优先级队列 5 反向迭代器 栈也是我们在C语言就模拟实现过的一种数据结构,在C中,栈其实和我们前面模拟实现过的string、vector等容器有一点区别,站起是不是容器,而是一种容器适配器,我…

Floyd判圈算法——寻找重复数(C++)

287. 寻找重复数 - 力扣(LeetCode) 题目描述 给定一个包含 n 1 个整数的数组 nums ,其数字都在 [1, n] 范围内(包括 1 和 n),可知至少存在一个重复的整数。假设 nums 只有 一个重复的整数 ,返…

python基础语法笔记(有C语言基础之后)

input()用于输入,其有返回值(即用户输入的值),默认返回字符串。括号里可放提示语句 一行代码若想分为多行来写,需要在每一行的末尾加上“\” 单个“/”表示数学中的除法,不会取整。“//”才会向下取整。 …

无人机之飞行规划与管理篇

无人机飞行规划与管理是确保无人机安全、高效且符合法规的运行的关键步骤。这一过程包括了对飞行任务的详细安排、航线的设定以及风险的评估和管理。下面简述这一过程的主要环节: 一、飞行目的和任务确定 在规划之初,必须明确无人机的飞行目的&#xf…

HTTPS理解

一个完整的HTTP连接 TCP三次握手接受窗口发送数据关闭连接 接受窗口是用来做什么呢? 它根据自身网络情况设置不同大小的值用来控制对方发送速度,避免对方发送太快,导致网络拥塞。 为什么TCP握手要三次? 1)确认双方的…

单片机中有FLASH为啥还需要EEROM?

在开始前刚好我有一些资料,是我根据网友给的问题精心整理了一份「单片机的资料从专业入门到高级教程」, 点个关注在评论区回复“888”之后私信回复“888”,全部无偿共享给大家!!! 一是EEPROM操作简单&…

JDK11中zgc垃圾回收器的探索

背景 垃圾回收器主要做的事情 自动跟踪和管理程序中创建的对象,确定哪些对象仍在使用,哪些对象已经不再使用。回收那些不再使用的对象所占用的内存空间,使得这部分内存可以被重新使用。 1.1 传统垃圾回收器 垃圾回收器简述优缺点应用场景…

typora 两边太宽,设置宽度

步骤: 查看目前使用主题类型 文件 —> 偏好设置 —> 外观 —> 打开主题文件夹 修改对应的主题:max-width

在Linux下使用Docker部署chirpstack

目录 一、前言 二、chirpstack 1、chirpstack是什么 2、chirpstack组件 3、为什么选择Docker部署 三、Linux下部署过程 四、web界面部署过程 一、前言 本篇文章我是在Linux下使用 Docker 进行部署chirpstack,chirpstack采用的是v4 版本,v4 版本 与…