SpikingJelly笔记之梯度替代

文章目录

  • 前言
  • 一、梯度替代
  • 二、网络结构
  • 三、MNIST分类
    • 1、单步模式
    • 2、多步模式
  • 总结


前言

在SpikingJelly使用梯度替代训练SNN,构建单层全连接SNN实现MNIST分类任务。


一、梯度替代

1、梯度替代:

阶跃函数不可微,无法进行反向传播

g ( x ) = { 1 , x ≥ 0 0 , x < 0 g(x) = \left\{\begin{matrix} 1,&\quad x\ge 0\\ 0,&\quad x<0\\ \end{matrix}\right. g(x)={1,0,x0x<0 , ,\quad\quad\quad g ′ ( x ) = { + ∞ , x = 0 0 , x ≠ 0 g^{\prime}(x) = \left\{\begin{matrix} +∞&,\quad x= 0\\ 0&,\quad x\neq0\\ \end{matrix}\right. g(x)={+0,x=0,x=0

前向传播使用阶跃函数,反向传播使用替代函数

2、梯度替代函数:

来源:spikingjelly.activation_based.surrogate package

①Sigmoid:surrogate.Sigmoid(alpha=4.0, spiking=True)

g ( x ) = s i g m o i d ( α x ) = 1 1 + e − α x g(x) = sigmoid(\alpha x)=\frac{1}{1+e^{-\alpha x}} g(x)=sigmoid(αx)=1+eαx1

g ′ ( x ) = α ∗ s i g m o i d ( α x ) ∗ ( 1 − s i g m o i d ( α x ) ) g^{\prime}(x) = \alpha*sigmoid(\alpha x)*(1-sigmoid(\alpha x)) g(x)=αsigmoid(αx)(1sigmoid(αx))

②ATan:surrogate.ATan(alpha=2.0, spiking=True)

g ( x ) = 1 π a r c t a n ( π 2 α x ) + 1 2 g(x) = \frac{1}{\pi}arctan(\frac{\pi}{2}\alpha x)+\frac{1}{2} g(x)=π1arctan(2παx)+21

g ′ ( x ) = α 2 ( 1 + ( π 2 α x ) 2 ) g^{\prime}(x) = \frac{\alpha}{2(1+(\frac{\pi}{2}\alpha x)^2)} g(x)=2(1+(2παx)2)α

③SoftSign:surrogate.SoftSign(alpha=2.0, spiking=True)

g ( x ) = 1 2 ( α x 1 + ∣ α x ∣ + 1 ) g(x) = \frac{1}{2}(\frac{\alpha x}{1+|\alpha x|}+1) g(x)=21(1+αxαx+1)

g ′ ( x ) = α 2 ( 1 + ∣ α x ∣ 2 ) g^{\prime}(x) = \frac{\alpha}{2(1+|\alpha x|^2)} g(x)=2(1+αx2)α

④LeakyKReLU:surrogate.LeakyKReLU(spiking=True, leak: float=0.0, k: float=1.0)

g ( x ) = { k ∗ x , x ≥ 0 l e a k ∗ x , x < 0 g(x) = \left\{\begin{matrix} k*x,&\quad x\ge 0\\ leak*x,&\quad x<0\\ \end{matrix}\right. g(x)={kx,leakx,x0x<0 , ,\quad\quad\quad g ′ ( x ) = { k , x ≥ 0 l e a k , x < 0 g^{\prime}(x) = \left\{\begin{matrix} k&,\quad x\ge 0\\ leak&,\quad x<0\\ \end{matrix}\right. g(x)={kleak,x0,x<0

二、网络结构

使用神经元层替代激活函数

1、ANN

nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 10, bias=False),
    nn.Softmax()
    )

2、SNN

nn.Sequential(
    layer.Flatten(),
    layer.Linear(28 * 28, 10, bias=False),
    neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan())
    )

三、MNIST分类

1、单步模式

(1)导入库

import time
import numpy as np
from matplotlib import pyplot as plt
import torch
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader
from spikingjelly.activation_based import neuron, encoding,\
    functional, surrogate, layer, monitor
from spikingjelly import visualizing
from load_mnist import load_mnist

(2)构建数据加载器

将numpy数据封装成DataLoader

使用Pytorch自带的数据集会更方便

def To_loader(x_train, y_train, x_test, y_test, batch_size):
    # 转为张量
    x_train = torch.from_numpy(x_train.astype(np.float32))
    y_train = torch.from_numpy(y_train.astype(np.float32))
    x_test = torch.from_numpy(x_test.astype(np.float32))
    y_test = torch.from_numpy(y_test.astype(np.float32))
    # 数据集封装
    train_dataset = TensorDataset(x_train, y_train)
    test_dataset = TensorDataset(x_test, y_test)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              shuffle=True)
    test_loader = DataLoader(dataset=test_dataset,
                              batch_size=batch_size,
                              shuffle=True)
    return train_dataset, test_dataset, train_loader, test_loader

(3)构建SNN模型

将LIF神经元层当作激活函数使用

使用ATan作为梯度替代函数进行反向传播

class SNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Sequential(
            layer.Linear(784, 10, bias=False),
            neuron.LIFNode(tau=2.0,
                           decay_input=True,
                           v_threshold=1.0,
                           v_reset=0.0,
                           surrogate_function=surrogate.ATan(),
                           step_mode='s',
                           store_v_seq=False)
            )
    def forward(self, x):
        return self.layer(x)

(4)训练参数

使用泊松编码器对输入进行编码

取10000个样本进行训练

epoch_num = 10
batch_size = 256
T = 50
lr = 0.001
encoder = encoding.PoissonEncoder() # 泊松编码器
model = SNN() # 单层SNN
loss_function = nn.MSELoss() # 均方误差
optimizer = optim.Adam(model.parameters(), lr) # Adam优化器
x_train, y_train, x_test, y_test = \
    load_mnist(normalize=True, flatten=False, one_hot_label=True)
train_dataset, test_dataset, train_loader, test_loader =\
    To_loader(x_train[:10000], y_train[:10000], x_test, y_test, batch_size)

(5)迭代训练

①取一段时间的平均发放率作为输出

②损失函数采用交叉熵或均方差,使对应神经元fout→1,其他神经元fout→0

③每批训练后重置网络状态

④每轮训练后测试准确率

start_time = time.time()
loss_train_list = []
acc_train_list = []
acc_test_list = []
for epoch in range(epoch_num):
    print('Epoch:%s'%(epoch+1))
    # 模型训练
    loss_train = 0
    acc_train = 0
    for x, y in train_loader:
        f_out = torch.zeros((y.shape[0], 10)) # 输出频率
        # 前向计算,逐步传播
        for t in range(T):
            encoded_x = encoder(x.reshape(-1, 784))
            f_out += model(encoded_x)
        f_out /= T
        # 反向传播
        loss = loss_function(f_out, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # 计算损失值与准确率
        loss_train += loss.item()
        acc_train += (f_out.argmax(1) == y.argmax(1)).sum().item()
        # 清除状态
        functional.reset_net(model)
    acc_train /= len(train_dataset)
    loss_train_list.append(loss_train)
    acc_train_list.append(acc_train)
    print('loss_train:', loss_train)
    print('acc_train:{:.2%}:'.format(acc_train))
    # 模型测试
    with torch.no_grad():
        acc_test = 0
        for x, y in test_loader:
            f_out = torch.zeros((y.shape[0], 10))
            # 逐步传播
            for t in range(T):
                encoded_x = encoder(x.reshape(-1,784))
                f_out += model(encoded_x)
            f_out /= T
            loss = loss_function(f_out, y)
            acc_test += (f_out.argmax(1) == y.argmax(1)).sum().item()
            functional.reset_net(model)
        acc_test /= len(test_dataset)
        acc_test_list.append(acc_test)
        print('acc_test:{:.2%}'.format(acc_test))
end_time = time.time()
print('Time:{:.1f}s'.format(end_time - start_time))

训练结果:

Epoch:10
loss_train: 0.8223596904426813
acc_train:91.10%
acc_test:90.24%
Time:123.3s

(6)显示损失值与准确率变化

fig1 = plt.figure(1, figsize=(12, 6))
ax1 = fig1.add_subplot(2, 2, 1)
ax1.plot(loss_train_list, 'r-')
ax1.set_title('loss_train')
ax2 = fig1.add_subplot(2, 2, 2)
ax2.plot(acc_train_list, 'b-')
ax2.set_title('acc_train')
ax3 = fig1.add_subplot(2, 1, 2)
ax3.plot(acc_test_list, 'b-')
ax3.set_title('acc_test')
plt.show()

训练结果:

(7)结果预测

选取一个数据,观察各神经元的膜电位变化与输出情况

# 设置监视器
for m in model.modules():
    if isinstance(m, neuron.LIFNode):
        m.store_v_seq = True
monitor_o = monitor.OutputMonitor(model, neuron.LIFNode)
monitor_v = monitor.AttributeMonitor('v',
                                      pre_forward=False,
                                      net=model,
                                      instance=neuron.LIFNode)
print('model:', model)
print('monitor_v:', monitor_v.monitored_layers)
print('monitor_o:', monitor_o.monitored_layers)
# 选择一组输入
x, y = test_dataset[0]
f_out = torch.zeros((y.shape[0], 10))
with torch.no_grad():
    # 逐步传播
    for t in range(T):
        encoded_x = encoder(x.reshape(-1,784))
        f_out += model(encoded_x)
    functional.reset_net(model)
    label = y.argmax().item()
    pred = f_out.argmax().item()
print('label:{},predict:{}'.format(label, pred))
# 膜电位与输出可视化
# 膜电位变化
dpi = 100
figsize = (6, 4)
# 合并列表中的张量,删除多余维度,删除梯度信息
v_list = torch.stack(monitor_v['layer.1']).squeeze().detach()
visualizing.plot_2d_heatmap(array=v_list.numpy(),
                            title='Membrane Potentials',
                            xlabel='Simulating Step',
                            ylabel='Neuron Index',
                            int_x_ticks=True,
                            x_max=T,
                            figsize=figsize,
                            dpi=dpi)
# 神经元输出
s_list = torch.stack(monitor_o['layer.1']).squeeze().detach()
visualizing.plot_1d_spikes(spikes=s_list.numpy(),
                            title='Out Spikes',
                            xlabel='Simulating Step',
                            ylabel='Neuron Index',
                            figsize=figsize,
                            dpi=dpi)

预测结果:

model: SNN(
  (layer): Sequential(
    (0): Linear(in_features=784, out_features=10, bias=False)
    (1): LIFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s, backend=torch, tau=2.0
      (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
  )
)
monitor_v: ['layer.1']
monitor_o: ['layer.1']
label:7,predict:7

膜电位变化:

神经元输出:

2、多步模式

将单步模式改为多步模式,需要修改以下部分:

(1)将神经元层的步进模式由’s’改为’m’

neuron.LIFNode(tau=2.0,
               decay_input=True,
               v_threshold=1.0,
               v_reset=0.0,
               surrogate_function=surrogate.ATan(),
               step_mode='m',
               store_v_seq=False)

(2)一次将所有时间步的数据全部输入

            encoded_x = encoder(x).repeat(T,1,1))
            f_out += model(encoded_x).sum(axis=0)
            f_out /= T

(3)修改监视器监视的变量

monitor_v = monitor.AttributeMonitor('v_seq',
                                      pre_forward=False,
                                      net=model,
                                      instance=neuron.LIFNode)

输出情况:

①训练结果

Epoch:10
loss_train: 0.8167978068813682
acc_train:91.06%:
acc_test:89.78%:
Time:145.1s

②网络结构

model: SNN(
  (layer): Sequential(
    (0): Linear(in_features=784, out_features=10, bias=False)
    (1): LIFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch, tau=2.0
      (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
  )
)
monitor_v: ['layer.1']
monitor_o: ['layer.1']
label:7,predict:7

③膜电位变化

④神经元输出:


总结

使用梯度替代法进行反向传播时,使用可微的激活函数替代,避免脉冲的不可微;

使用编码器将输入编码为1/0脉冲序列;

将神经元层代替激活函数;

“在正确构建网络的情况下,逐层传播的并行度更大,速度更快”。但在此逐步传播比逐层传播略快一些。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/577432.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

miniTry:Python实现web搜索(全自动+程序操控)

声明&#xff1a;本问给出了全部代码--可以复现--亲测有效 :) [ 代码为图片--> 强制自己去敲一次 又不多] 1.打开网站&#xff1a; 2.利用id去定位到我们要进行输入的内容&#xff08;bing可以直接进行搜索&#xff0c;而csdn需要登录&#xff0c;所以我们用csdn做演示&…

HODL、FUD、FOMO 等其他比特币俚语是什么意思?

作者&#xff1a;Paxful Team 1、FOMO&#xff08;惧怕错失机会&#xff09; FOMO 是惧怕错失机会的缩写&#xff0c;可用于日常生活。它指的是当其他人都在谈论比特币时&#xff0c;产生的购买比特币的紧迫感。 2、Shill&#xff08;不断推广吹捧&#xff09; Shilling 是指…

linux支持vGPU方案

1&#xff0c;查询gpu型号&#xff1a;lspci | grep "NVIDIA\|VGA" PCI Devices 2&#xff0c;下载驱动 官方驱动 | NVIDIA 3&#xff0c;安装 sudo sh NVIDIA-Linux-x86_64-440.118.02.run -no-x-check -no-nouveau-check -no-opengl-files参数说明&#xff1a; …

自定义View-旋转变色圆角三角形的绘制

本文字数&#xff1a;3151字 预计阅读时间&#xff1a;20分钟 在现代设计中&#xff0c;动效图在APP的UI界面中所起到的作用无疑是显著的。相比于静态的界面&#xff0c;动效更符合人类的自然认知体系&#xff0c;它有效地降低了用户的认知负载&#xff0c;UI动效俨然已经成为了…

汽车新四化,会发生什么?

北京国际汽车展览会正如火如荼地进行中,作为国内外汽车行业瞩目的盛会&#xff0c;众多车企纷纷亮出了自家的“杀手锏”。 这场汽车的盛宴不仅集中展示了众多汽车品牌的最新技术和产品&#xff0c;更深刻体现了汽车新四化的发展趋势。汽车新四化&#xff0c;即电动化、网联化、…

DS进阶:AVL树和红黑树

一、AVL树 1.1 AVL树的概念 二叉搜索树&#xff08;BST&#xff09;虽可以缩短查找的效率&#xff0c;但如果数据有序或接近有序二叉搜索树将退化为单支树&#xff0c;查找元素相当于在顺序表中搜索元素&#xff0c;效率低下。因此&#xff0c;两位俄罗斯的数学家G.M.Adelson-…

使用Keil移植工程时修改单片机型号参数

系列文章目录 STM32单片机系列专栏 C语言术语和结构总结专栏 当使用Keil对STM32系列单片机开发时&#xff0c;如果使用的是库函数&#xff0c;那么不同型号单片机的工程项目文件是可以直接移植的。只需要按照下面的步骤修改对应的芯片&#xff0c;就可以直接将工程移植过去&a…

JVM垃圾收集器--分区收集器

G1收集器 属性 G1&#xff08;Garbage-First Garbage Collector&#xff09;在 JDK 1.7 时引入&#xff0c;在 JDK 9 时取代 CMS 成为了默认的垃圾收集器。G1 有五个属性&#xff1a;分代、增量、并行、标记整理、STW。 分代 G1收集器 将内部分为多个大小相等的区域&#x…

Java8 Stream常见用法

Stream流的常见用法&#xff1a; 1.利用stream流特性把数组转list集合 //定义一个数组Integer[] array {5,2,1,6,4,3};//通过stream特性把数组转list集合List<Integer> list Arrays.stream(array).collect(Collectors.toList());//打印结果System.out.println(list);…

全球数据爬取的解决方案-国外数据爬取

引言 随着经济的持续低迷和对外贸易的需求扩大&#xff0c;各个公司为了更好的了解海外客户情况&#xff0c;最简单直接的办法就是从全球收集公共的网络数据。 无论是海外电商用户的消费习惯还是训练自己的通用人工智能chatgpt&#xff0c;都是需要海量和多种类型数据的支持。…

【Linux】进程间通信(共享内存、消息队列、信号量)

一、System V —— 共享内存&#xff08;详解&#xff09; 共享内存区是最快的 IPC 形式。一旦这样的内存映射到共享它的进程的地址空间&#xff0c;这些进程间数据传递不再涉及到内核&#xff0c;换句话说&#xff0c;就是进程不再通过执行进入内核的系统调用来传递彼此的数…

[NSSCTF]prize_p5

前言 之前就学过反序列化的字符串逃逸 但是没怎么做题 补一下窟窿 题目 <?phperror_reporting(0);class catalogue{public $class;public $data;public function __construct(){$this->class "error";$this->data "hacker";}public functi…

HTTP网络协议的请求方法,具体详解(2024-04-26)

1、HTTP 即超文本传输协议&#xff0c;是一种实现客户端和服务器之间通信的响应协议&#xff0c;它是用作客户端和服务器之间的请求 根据 HTTP 标准&#xff0c;HTTP 请求可以使用多种请求方法。 2、方法分类 HTTP1.0 定义了三种请求方法&#xff1a; GET, POST 和 HEAD 方…

表情识别 | 卷积神经网络(CNN)人脸表情识别(Matlab)

表情识别 | 卷积神经网络(CNN)人脸表情识别&#xff08;Matlab&#xff09; 目录 表情识别 | 卷积神经网络(CNN)人脸表情识别&#xff08;Matlab&#xff09;预测效果基本介绍程序设计参考资料 预测效果 基本介绍 Matlab使用卷积神经网络(CNN)&#xff0c;进行人脸表情情绪识别…

论文解读:(CAVPT)Dual Modality Prompt Tuning for Vision-Language Pre-Trained Model

v1文章名字&#xff1a;Dual Modality Prompt Tuning for Vision-Language Pre-Trained Model v2文章名字&#xff1a;Class-Aware Visual Prompt Tuning for Vision-Language Pre-Trained Model 文章汇总 对该文的改进&#xff1a;论文解读&#xff1a;(VPT)Visual Prompt …

STM32H7 HSE时钟的使用方法介绍

目录 概述 1 STM32H750 HSE时钟介绍 2 使用STM32Cube创建Project 3 认识HSE时钟 3.1 HSE时钟的特性 3.2 HSE的典型应用电路 4 STM32Cube中配置时钟 4.1 时钟需求 4.2 配置参数 4.2.1 使能外围资源 4.2.2 使用STM32Cube注意项 4.2.3 配置参数 5 总结 概述 本文主要…

ESP-IDF编译系统详解(1)

接前一篇文章&#xff1a;VSCode ESP-IDF安装与配置全过程 本文内容主要参考&#xff1a; 《ESP32-C3物联网工程开发实战》 —— 乐鑫科技 编著 特此致谢&#xff01; 前文已经详述了ESP-IDF开发环境的搭建&#xff0c;包括ESP-IDF的下载与安装完整流程&#xff0c;以及VSCo…

【Docker】Docker 实践(一):在 Docker 中部署第一个应用

Docker 实践&#xff08;一&#xff09;&#xff1a;在 Docker 中部署第一个应用 1.使用 YUM 方式安装 Docker2.验证 Docker 环境3.在 Docker 中部署第一个应用3.1 小插曲&#xff1a;docker pull 报 missing signature key 错误3.2 重新安装 Nginx 1.使用 YUM 方式安装 Docker…

MySQL中脏读与幻读

一般对于我们的业务系统去访问数据库而言&#xff0c;它往往是多个线程并发执行多个事务的&#xff0c;对于数据库而言&#xff0c;它会有多个事务同时执行&#xff0c;可能这多个事务还会同时更新和查询同一条数据&#xff0c;所以这里会有一些问题需要数据库来解决 我们来看…

IoT Scenario: Smart Retail System-Multiple Sources and Multiple Terminals

物联网/大数据可视化领域发文可以联系&#xff1a;nascimsina.com IoT Scenario: Smart Retail System Overview The use of IoT in the retail industry enhances customer experiences, optimizes inventory management, and provides valuable insights into consumer beh…