57、通过EEG数据的SHAPE变化,揭开EEG-TCNet的黑匣子[看好了小子,我只教这一次]

之前在第18篇博客中对于EEG-TCNet这个处理EEG信号的sota模型进行了介绍,也给出了模型,目前也是全网对于EEG-TCNet浏览度最高的文章了,我觉得讲的已经很细致了,没想到还是有不少同学疑问,这也是全网缺少该模型pytorch代码的原因,因为pytorch中没有封装TCN模块,无法直接调用,而在Tensorflow中可直接调用,废话不多少,上菜:

EEG-TCNet模型图:

原论文EEG-TCNet结构图

模型结构分析:

1、BCI IV2a数据以4维数据输入,shape=(288,1,22,1000)

2、数据先经过一个完整的EEGNet结构(时间卷积+深度卷积+深度可分离),来处理这个4维数组

3、数据从EEGNet出来,进入到TCN块之前进行降维处理(TCN只能处理1维数组)

下面我们来看2a数据以(batch_size,1,22,1000)输入到EEG-TCNet中是如何改变shape的

我自己写的EEG-TCNet代码模型-结构图:(我自己画的,别盗图)

TCN块(膨胀因果卷积)分析:

代码编写:

Chomp1d(nn.Module):裁剪类

TemporalBlock(nn.Module):TCN主体类,调用Chomp1d(),在这个类使用的卷积是Conv1d

TemporalConvNet(nn.Module):调用TemporalBlock()TCN完全体类

TCNNet(nn.Module):调用TemporalConvNet(),降维,使得TCN完全体跑的通


讲上面这4个类,我要倒着讲,费点劲要:(为啥倒着讲?同学想想 0。0 )

input_data = batch_size,1,22,1000经过一个前3个block后,此时控制台输出shape = 32,8,1,31断点如下:

数据此时还是4维的,所以我们在这使用if来判断维度,给他降维度

1、Data = torch.rand(x.shape):生成一个空的和x的维度一致的张量数据,用来存储for循环TCN块裁剪的数据

2、空的张量数据也要送到GPU中,否则报错,因为此时X的数据都在GPU上

3、在x的第二维度channel = 1,进行for循环,通过self来调用类内的tcn_block对应的TCN方法,对x数据进行裁剪并提取数据,把这些数据(此时还是4维)送给张量data

4、x = data(乾坤大魔移!


tcn_block对应着咱们定义的TemporalConvNet() 完全体这个类,如下:

类里面调用了上面定义好的Chomp1d()这个裁剪的类

此时代码跑到了Chomp1d()里面,如下所示:

TCN之前的数据= 32,8,1,31

此时数据维度 = 32,8,40,这里代码自动的去掉了通道=1的维度,并+res这个对x下采样的数据

因为这里是for i in range(x.shape[2])的循环,此时i=0,x.shape[3] = 40,我们再进入下一个循环i=1看看

此时x.shape[3] = 49,所以就这样,在送到Chomp1d()进行裁剪时,x加上了res这个下采样特征数据,导致了x的数据量增加,我们规定了Chomp1d()中的chomp_size这个数值,只保留与原始数据总量相同的前chomp_size的这个数目,来最后送给Fc层做最后结果的输出

此时我们送给Fc的shape :

又变回原来的31个数据了,这事裁剪类的功劳!但此时前后的这个31数据是不同的,多了下采样的特征,所以TCNNet这个类实现了先降维再生维的神奇操作,使得代码流通,完事。

全部代码如下:

class Chomp1d(nn.Module):
    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :-self.chomp_size].contiguous()

class TemporalBlock(nn.Module):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp1 = Chomp1d(padding)
       
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                           stride=stride, padding=padding, dilation=dilation))
        self.chomp2 = Chomp1d(padding)
       
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chomp1,self.relu1, self.dropout1,
                                 self.conv2, self.chomp2,self.relu2, self.dropout2)
        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()
    #   self.init_weights()

    # def init_weights(self):
    #     self.conv1.weight.data.normal_(0, 0.01)
    #     self.conv2.weight.data.normal_(0, 0.01)
    #     if self.downsample is not None:
    #         self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)

class TemporalConvNet(nn.Module):
    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i-1]
            out_channels = num_channels[i]
            layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size, padding=(kernel_size-1) * dilation_size, 
                                     dropout=dropout)]

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

import numpy as np

class TCNNet(nn.Module):
     def __init__ (self,*args) -> None:
        super(TCNNet,self).__init__()
        if len(args) < 2:
            print('error')
            exit()
        else:
            num_inputs = args[0]
            num_channels = args[1]
            kernel_size = int(args[2][0])
            
        self.tcn_block =  TemporalConvNet(num_inputs,num_channels,kernel_size) 
        #self.tcn_block =  TemporalConvNet(num_inputs=self.F2,num_channels=[tcn_filters,tcn_filters],kernel_size=tcn_kernelSize) 
     def forward(self,x) :
        if len(x.shape) == 4:
            data = torch.rand(x.shape)
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
            data = data.to(device)
 
            for i in range(x.shape[2]):
              
                data[:,:,i,:] = self.tcn_block(x[:,:,i,:])
            x = data
        else:
            x = torch.squeeze(x,dim=2) 
            x = self.tcn_block(x)
        
        return x 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/554850.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

13、ESP32 深度睡眠

1、深度睡眠 ESP32 可以在不同的电源模式之间切换&#xff1a; 活动模式调制解调器睡眠模式浅睡眠模式深度睡眠模式休眠模式 在深度睡眠模式下&#xff0c;CPU 或 Wi-Fi 活动都不会发生&#xff0c;但超低功耗 &#xff08;ULP&#xff09; 协处理器仍可打开电源&#xff0c;R…

常见的数据结构

链表 链表&#xff1a;适用于插入删除多、读少的场景。 链表在新增、删除数据都比较容易&#xff0c;可以在 O(1) 的时间复杂度内完成。 但对于查找&#xff0c;不管是按照位置的查找还是按照数值条件的查找&#xff0c;都需要对全部数据进行遍历。这显然就是 O(n) 的时间复杂…

解锁流量密码:如何利用HubSpot打造高效的获客策略?(下)

在当今数字化时代&#xff0c;流量是企业成功的关键。HubSpot作为一款全面的营销自动化工具&#xff0c;为我们提供了强大的支持&#xff0c;帮助企业打造高效的流量获取策略。接下来&#xff0c;我们将从社交媒体与SEO优化、自动化营销流程、数据分析与效果评估以及流量获取策…

畅通工程(并查集)

//新生训练 #include <iostream> #include <algorithm> #include <bits/stdc.h> using namespace std;#define N 1010int fa[N];int Find(int x) {if (x ! fa[x]){fa[x] Find(fa[x]);}return fa[x]; }void Union(int x, int y) {int a Find(x);int b Find…

Linux读写文件

前言 学习了文件系统&#xff0c;就能理解为什么说Linux下一切皆文件。 语言层面的操作 在c语言的学习中我们可以使用fopen()函数对文件进行操作。 int main() {//FILE * fp fopen("./log.txt", "w");//FILE * fp fopen("./log.txt", "…

Compose 布局

文章目录 Compose 布局ColumnColumn属性使用 RowRow属性使用 BoxBox属性使用 ConstraintLayoutLazyColumnLazyColumn属性使用使用多类型使用粘性标题回到顶部 LazyRowLazyRow属性使用 LazyVerticalGridLazyVerticalGrid属性使用 Compose 布局 Column Compose中的”垂直线性布…

BlueNRG-X 原理图参数说明

1. 前言 为了让客户在原理图设计阶段少走弯路&#xff0c;我这里结合客户评估和设计阶段常遇到的问题&#xff0c;整理了一下 BlueNRG-1/-2 相关设计及注意事项以备客户解惑用。 2. BlueNRG-1/-2 的原理图参数说明及设计注意事项 2.1. BlueNRG-1/-2 原理图及参数如下&#x…

NFTScan | 04.08~04.14 NFT 市场热点汇总

欢迎来到由 NFT 基础设施 NFTScan 出品的 NFT 生态热点事件每周汇总。 周期&#xff1a;2024.04.08~ 2024.04.14 NFT Hot News 01/ 数据&#xff1a;Runestone 地板价突破 0.07 BTC&#xff0c;创历史新高 4 月 8 日&#xff0c;据数据显示&#xff0c;Runestone 地板价突破 …

【第七届openGauss技术文章征集】 openGauss新版本征文活动来啦!

活动背景 2024年3月30日&#xff0c;openGauss 6.0.0版本正式上线&#xff0c;该版本与之前版本特性功能保持兼容&#xff0c;在内核能力、DataPod三层资源池化架构、DataKit数据全生命周期管理平台、生态兼容性等方面全面增强。&#xff08;下方【点击原文】即可查看更多【新…

Android--ConnectivityManager使用

一、前言 Android10之后官方废弃了通过WifiManager连接WIFI的方式&#xff0c;现在要使用ConnectivityManager连接WIFI 二、连接WIFI public class MainActivity extends AppCompatActivity {private static final String TAG"lkx";Overrideprotected void onCrea…

【黑马头条】-day10热点文章定时计算-xxl-job

文章目录 1 今日内容1.1 需求分析1.2 解决方案1.3 定时计算1.4 定时任务方案对比 2 分布式任务调度3 xxl-job3.1 简介3.2 环境搭建3.2.1 配置maven3.2.2 源码说明 3.3 配置部署调度中心3.3.1 运行sql脚本3.3.2 修改配置application.properties3.3.3 启动引导类 3.4 docker配置x…

✌粤嵌—2024/4/18—旋转链表✌

代码实现&#xff1a; 方法一&#xff1a;在原链表中找到旋转之后的首尾&#xff0c;改变指向 /*** Definition for singly-linked list.* struct ListNode {* int val;* struct ListNode *next;* };*/ struct ListNode* rotateRight(struct ListNode *head, int k) {i…

电脑怎么创建加密磁盘?方法很简单!

在电脑中创建加密磁盘&#xff0c;可以有效地保护电脑数据&#xff0c;避免数据泄露。那么&#xff0c;电脑怎么创建加密磁盘呢&#xff1f;下面我们就一起来了解一下吧。 密盘创建方法 创建密盘建议使用超级秘密磁盘3000&#xff0c;软件界面简约&#xff0c;操作简单&#x…

秋招复习笔记——八股文部分:网络HTTP

常见面试题 基本概念 HTTP 是超文本传输协议&#xff0c;也就是HyperText Transfer Protocol。HTTP 协议是一个双向协议,是一个在计算机世界里专门用来在两点之间传输数据的约定和规范。是超越了普通文本的文本&#xff0c;它是文字、图片、视频等的混合体&#xff0c;最关键…

资源管理规范

踩坑经验&#xff1a; 对于IO及池化资源(文件、线程池、网络IO(HttpClient)、磁盘IO)&#xff0c;使用之后一定要及时回收&#xff0c;好借好还&#xff0c;再借不难。 稳定关闭方式 完成I/O操作后&#xff0c;应该关闭流以释放系统资源。可以使用finally块确保流被关闭&#…

学生故事|勇于创新,拒绝“一成不变”的设计

对于JIANG MANQI而言&#xff0c;室内设计一直是她钟爱的行业选择。 在没有进入莱佛士学习之前&#xff0c;MANQI受家人的影响&#xff0c;从小就对设计行业比较感兴趣。而她选择室内设计&#xff0c;是觉得室内设计是比较有前途的一个专业&#xff0c;随着人们生活品质的提高…

数字化转型-工具变量数据集

01、数据介绍 数字化转型是指企业或个人利用数字技术&#xff0c;如大数据、云计算、人工智能等&#xff0c;对其业务流程、运营模式、决策方式等进行全面、深入的变革&#xff0c;以提高效率、降低成本、提升质量、增强竞争力。在这个过程中&#xff0c;工具变量扮演着至关重…

SpringBoot 集成Nacos注册中心和配置中心-支持自动刷新配置

SpringBoot 集成Nacos注册中心和配置中心-支持自动刷新配置 本文介绍SpringBoot项目集成Nacos注册中心和配置中心的步骤&#xff0c;供各位参考使用 1、配置pom.xml 文件 在pom.xml文件中定义如下配置和引用依赖&#xff0c;如下所示&#xff1a; <properties><pr…

MATLAB求和函数

语法 S sum(A) S sum(A,“all”) S sum(A,dim) S sum(A,vecdim) S sum(,outtype) S sum(,nanflag) 说明 示例 S sum(A) 返回沿大小大于 1 的第一个数组维度计算的元素之和。 如果 A 是向量&#xff0c;则 sum(A) 返回元素之和。 如果 A 是矩阵&#xff0c;则 sum(A) 将…

如何在Linux CentOS部署宝塔面板并实现固定公网地址访问内网宝塔

文章目录 一、使用官网一键安装命令安装宝塔二、简单配置宝塔&#xff0c;内网穿透三、使用固定公网地址访问宝塔 宝塔面板作为建站运维工具&#xff0c;适合新手&#xff0c;简单好用。当我们在家里/公司搭建了宝塔&#xff0c;没有公网IP&#xff0c;但是想要在外也可以访问内…