基于pytorch_lightning测试resnet18不同激活方式在CIFAR10数据集上的精度

基于pytorch_lightning测试resnet18不同激活方式在CIFAR10数据集上的精度

  • 一.曲线
    • 1.train_acc
    • 2.val_acc
    • 3.train_loss
    • 4.lr
  • 二.代码

本文介绍了如何基于pytorch_lightning测试resnet18不同激活方式在CIFAR10数据集上的精度
特别说明:
1.NoActive:没有任何激活函数
2.SparseActivation:只保留topk的激活,其余清零,topk通过训练得到[初衷是想让激活变得稀疏]
3.SelectiveActive:通过训练得到使用的激活函数
可参考的代码片段
1.pytorch_lightning 如何使用
2.pytorch如何替换激活函数
3.如何对自定义权值做衰减

一.曲线

1.train_acc

在这里插入图片描述

2.val_acc

在这里插入图片描述

3.train_loss

在这里插入图片描述

4.lr

在这里插入图片描述

二.代码

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import os
import numpy as np
from pytorch_lightning.loggers import TensorBoardLogger

#torch.set_float32_matmul_precision('medium')

class ResidualBlock(nn.Module):
    def __init__(self, inchannel, outchannel, stride=1):
        super(ResidualBlock, self).__init__()
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(outchannel)
        )
        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != outchannel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(outchannel)
            )
        self.act=nn.ReLU()

    def forward(self, x):
        out = self.left(x)
        out += self.shortcut(x)
        out = self.act(out)
        return out

class ResNet(nn.Module):
    def __init__(self, ResidualBlock, num_classes=10):
        super(ResNet, self).__init__()
        self.inchannel = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.layer1 = self.make_layer(ResidualBlock, 64,  2, stride=1)
        self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2)
        self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=2)
        self.fc = nn.Linear(512, num_classes)
        self.dropout=nn.Dropout(0.5)

    def make_layer(self, block, channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.inchannel, channels, stride))
            self.inchannel = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.dropout(out)
        out = self.fc(out)
        return out
  
class SparseActivation(nn.Module):
    act_array=[x.cuda() for x in [nn.ReLU(),
                nn.ReLU6(),
                nn.Sigmoid(),
                nn.Hardsigmoid(),
                nn.GELU(),
                nn.SiLU(),
                nn.Mish(),
                nn.LeakyReLU(),
                nn.Hardswish(),
                nn.PReLU(),
                nn.SELU(),
                nn.Softplus(),
                nn.Softsign()]]
                    
    def __init__(self,args):
        super(SparseActivation, self).__init__()
        self.input_weights = nn.Parameter(torch.randn(1)).cuda()
        self.act=SparseActivation.act_array
        self.act_weights = nn.Parameter(torch.randn(len(self.act))).cuda()
        self.args=args
        
    def forward(self, x):        
        
        index=self.args.act
        if index>=0:
            index=index-1
            if index==-1:
                prob=F.softmax(self.act_weights,dim=0)
                _, index = torch.topk(prob, 1, dim=0)
            x=self.act[index](x)
        
        if self.args.sparse==0:
            return x
            
        input=x.flatten(1)
        input_weights = torch.sigmoid(self.input_weights)        
        topk = input.size(1)*input_weights
        topk=topk.int()
        topk_vals, topk_indices = torch.topk(input, topk, dim=1)
        mask = torch.zeros_like(input).scatter(1, topk_indices, topk_vals)
        return mask.reshape(x.shape)
            
class LitNet(pl.LightningModule):
    def __init__(self, args):
        super(LitNet, self).__init__()
        self.save_hyperparameters()
        self.args = args
        self.resnet18 = ResNet(ResidualBlock)
        self.criterion = nn.CrossEntropyLoss()
        self.ws=[]
        self.replace_activation(self.resnet18,nn.ReLU, SparseActivation,self.ws)    
        
    def replace_activation(self,module, old_activation, new_activation,ws):
        for name, child in module.named_children():
            if isinstance(child, old_activation):
                op=new_activation(self.args)
                ws.append(op.input_weights)
                setattr(module, name,op)
            else:
                self.replace_activation(child, old_activation, new_activation,ws)        
        
    def forward(self, x):
        return self.resnet18(x)

    def on_train_epoch_start(self):
        self.train_total_loss=[]
        self.train_total_acc=[]

    def on_train_epoch_end(self):
        self.log('epoch_train_loss', np.mean(self.train_total_loss))
        self.log('epoch_train_acc', np.mean(self.train_total_acc)) 
        self.log("lr",self.optimizer.state_dict()['param_groups'][0]['lr'])
        
    def training_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = self.criterion(output, target)
        
        l2_reg = torch.tensor(0.).cuda()
        l2_lambda=0.001
        for param in self.ws:
            l2_reg += torch.norm(param+4)                    
        loss += l2_lambda * l2_reg        
        self.log('iter_train_loss', loss)

        _, predicted = torch.max(output.data, 1)
        correct = (predicted == target).sum()
        acc = 100. * correct / target.size(0)      
        self.train_total_loss.append(loss.item())
        self.train_total_acc.append(acc.item())
        
        return loss       

    def on_validation_epoch_start(self):
        self.val_total_loss=[]
        self.val_total_acc=[]

    def on_validation_epoch_end(self):
        self.log('epoch_val_loss', np.mean(self.val_total_loss))
        self.log('epoch_val_acc', np.mean(self.val_total_acc))

    def validation_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        _, predicted = torch.max(output.data, 1)
        correct = (predicted == target).sum()
        acc = 100. * correct / target.size(0)
        loss = self.criterion(output, target)        
        self.val_total_loss.append(loss.item())
        self.val_total_acc.append(acc.item())

    def test_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = self.criterion(output, target)
        self.log('test_loss', loss)
        return loss
        
    def configure_optimizers(self):
        self.optimizer = optim.SGD(self.parameters(), lr=self.args.lr, momentum=0.9,weight_decay=5e-4)
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer,step_size=10,gamma = 0.8)            
        return [self.optimizer],[self.scheduler]

class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4), 
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        self.train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
        self.test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size,shuffle=True,num_workers=2,persistent_workers=True)

    def val_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size,shuffle=False,num_workers=2,persistent_workers=True)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size)


def main():
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',help='input batch size for training (default: 64)')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',help='learning rate (default: 1.0)')
    parser.add_argument('--act', type=int, default=-1,help='learning rate (default: 1.0)')
    parser.add_argument('--sparse', type=int, default=0,help='learning rate (default: 1.0)')
    args = parser.parse_args()

    cifar10_data = CIFAR10DataModule(batch_size=args.batch_size)
    log_dir = "lightning_logs"
    
    
    args.sparse=0   #不开启稀疏
    args.act=0      #自适应激活
    model = LitNet(args)
    
    logger = TensorBoardLogger(save_dir=log_dir, name="SelectiveActive")    
    trainer = pl.Trainer(logger=logger,devices=1,max_epochs=args.epochs,val_check_interval=1.0,gradient_clip_val=0.9, gradient_clip_algorithm="value")
    trainer.fit(model, cifar10_data)    
    
    args.sparse=0     #不开启稀疏
    args.act=-1       #不用激活
    model = LitNet(args)    
    cifar10_data = CIFAR10DataModule(batch_size=args.batch_size)
    
    logger = TensorBoardLogger(save_dir=log_dir, name="NoActive")    
    trainer = pl.Trainer(logger=logger,devices=1,max_epochs=args.epochs,val_check_interval=1.0,gradient_clip_val=0.9, gradient_clip_algorithm="value")
    trainer.fit(model, cifar10_data)  
   
    args.sparse=1
    args.act=-1       #不用激活,开启稀疏
    model = LitNet(args)       
    
    logger = TensorBoardLogger(save_dir=log_dir, name="SparseActivation")    
    trainer = pl.Trainer(logger=logger,devices=1,max_epochs=args.epochs,val_check_interval=1.0,gradient_clip_val=0.9, gradient_clip_algorithm="value")
    trainer.fit(model, cifar10_data)  

    for idx,act_name in enumerate(SparseActivation.act_array):
        name=act_name.__class__.__name__
        print(name)
        
        args.act=idx+1
        args.sparse=0
        model = LitNet(args)     
        
        logger = TensorBoardLogger(save_dir=log_dir, name=name)    
        trainer = pl.Trainer(logger=logger,devices=1,max_epochs=args.epochs,val_check_interval=1.0,gradient_clip_val=0.9, gradient_clip_algorithm="value")
        trainer.fit(model, cifar10_data)

if __name__ == '__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/695718.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

机器学习--线性模型和非线性模型的区别?哪些模型是线性模型,哪些模型是非线性模型?

文章目录 引言线性模型和非线性模型的区别线性模型非线性模型 总结线性模型非线性模型 引言 在机器学习和统计学领域,模型的选择直接影响到预测的准确性和计算的效率。根据输入特征与输出变量之间关系的复杂程度,模型可以分为线性模型和非线性模型。线性…

C语言 | Leetcode C语言题解之第142题环形链表II

题目: 题解: struct ListNode* detectCycle(struct ListNode* head) {struct ListNode *slow head, *fast head;while (fast ! NULL) {slow slow->next;if (fast->next NULL) {return NULL;}fast fast->next->next;if (fast slow) {s…

Linux网络命令——tcpdump

tcpdump是Linux下的一个网络数据采集分析工具,也就是常说的抓包工具 tcpdump 核心参数 tcpdump [option] [proto] [dir] [type] 例如:$ tcpdump -i eth0 -nn -s0 -v port 80 option 可选参数: -i : 选择要捕获的接口,通常是以太…

插卡式仪器模块:音频分析模块(插卡式)

• 24 位分辨率 • 192 KHz 采样率 • 支持多种模拟音频信号的输入/输出 应用场景 • 音频信号分析:幅值、频率、信噪比、THD、THDN 等指标 • 模拟音频测试:耳机、麦克风、扬声器测试,串扰测 音频分析仪 输入阻抗10 TΩ10 TΩ输入范围3…

【C语言】宏详解(下卷)

前言 紧接上卷,我们继续来了解宏。 宏替换的规则 1.在调用宏时,首先对参数进行检查,看看是否包含任何由#define定义的符号。如果是,它们首先被替换。 2.替换文本随后被插入到程序中原来文本的位置。对于宏,参数名被他…

C++类与对象(拷贝与类的内存管理)

感谢大佬的光临各位,希望和大家一起进步,望得到你的三连,互三支持,一起进步 个人主页:LaNzikinh-CSDN博客 文章目录 前言一.对象的动态建立和释放二.多个对象的构造和析构三.深拷贝与浅拷贝四.C类的内存管理总结 前言 …

⌈ 传知代码 ⌋ 以思维链为线索推理隐含情感

💛前情提要💛 本文是传知代码平台中的相关前沿知识与技术的分享~ 接下来我们即将进入一个全新的空间,对技术有一个全新的视角~ 本文所涉及所有资源均在传知代码平台可获取 以下的内容一定会让你对AI 赋能时代有一个颠覆性的认识哦&#x…

基于改进YOLOv5的小目标检测 | 添加CBAM注意机制 + 更换Neck网络之BiFPN + 增加高分辨率检测头

前言:Hello大家好,我是小哥谈。本文针对图像中小目标难以检测的问题,提出了一种基于YOLOv5的改进模型。在主干网络中,加入CBAM注意力模块增强网络特征提取能力;在颈部网络部分,使用BiFPN结构替换PANet结构&…

Linux驱动应用编程(三)UART串口

本文目录 前述一、手册查看二、命令行调试串口1. 查看设备节点2. 使用stty命令设置串口3. 查看串口配置信息4. 调试串口 三、代码编写1. 常用API2. 例程●线程优化●poll优化●select优化(功能和poll一样) 前述 在开始实验前,请一定要检查测试…

【RabbitMQ】RabbitMQ配置与交换机学习

【RabbitMQ】RabbitMQ配置与交换机学习 文章目录 【RabbitMQ】RabbitMQ配置与交换机学习简介安装和部署1. 安装RabbitMQ2.创建virtual-host3. 添加依赖4.修改配置文件 WorkQueues模型1.编写消息发送测试类2.编写消息接收(监听)类3. 实现能者多劳 交换机F…

【深度学习】—— 神经网络介绍

神经网络介绍 本系列主要是吴恩达深度学习系列视频的笔记,传送门:https://www.coursera.org/deeplearning-ai 目录 神经网络介绍神经网络的应用深度学习兴起的原因 神经网络,全称人工神经网络(Artificial Neural Network&#xf…

25.逢七必过

上海市计算机学会竞赛平台 | YACSYACS 是由上海市计算机学会于2019年发起的活动,旨在激发青少年对学习人工智能与算法设计的热情与兴趣,提升青少年科学素养,引导青少年投身创新发现和科研实践活动。https://www.iai.sh.cn/problem/363 题目描述 逢七必过的游戏规则如下:对一…

Linux安装Docker | 使用国内镜像

环境 CentOS7 先确认能够上网 curl www.baidu.com返回该输出说明网络OK 步骤一:安装gcc 和 gcc-c yum -y install gccyum -y install gcc-c步骤二:安装Docker仓库 yum install -y yum-utils接下来配置yum的国内镜像 yum-config-manager --add-re…

激活乡村振兴新动能:推动农村产业融合发展,打造具有地方特色的美丽乡村,实现乡村全面振兴

目录 一、推动农村产业融合发展 1、农业产业链条的延伸 2、农业与旅游业的结合 二、挖掘地方特色,打造美丽乡村 1、保护和传承乡村文化 2、发展特色农业 三、加强基础设施建设,提升乡村品质 1、改善农村交通条件 2、提升农村水利设施 四、促进…

大数据湖一体化运营管理建设方案(49页PPT)

方案介绍: 本大数据湖一体化运营管理建设方案通过构建统一存储、高效处理、智能分析和安全管控的大数据湖平台,实现了企业数据的集中管理、快速处理和智能分析。该方案具有可扩展性、高性能、智能化、安全性和易用性等特点,能够为企业数字化…

水滴型锤片粉碎机:多功能粉碎利器

在现代工业生产中,粉碎机作为一种重要的机械设备,广泛应用于饲料、化工、木材等多个领域。其中,水滴型锤片粉碎机凭借其设计和粉碎能力,成为市场上的热门产品。 水滴型锤片粉碎机其设计灵感来源于水滴的形态。这种设计使得机器在…

vmware-17虚拟机安装教程,安装linux centos系统

下载VMware 1.进入VMware官网:https://www.vmware.com/sg/products/workstation-pro.html 2.向下翻找到,如下界面并点击“现在安装” 因官网更新页面出现误差,现提供vmware17安装包网盘链接如下: 链接:https://pan.b…

【SpringBoot + Vue 尚庭公寓实战】基本属性接口实现(七)

【SpringBoot Vue 尚庭公寓实战】基本属性接口实现(七) 文章目录 【SpringBoot Vue 尚庭公寓实战】基本属性接口实现(七)1、保存或更新属性名称2、保存或更新属性值3、查询全部属性名称和属性值列表4、根据ID删除属性名称5、根据…

freertos内核拓展DAY2(消息队列)

这节内容是信号量的基础,因为创建以及发送/等待信号量所调用的底层函数,就是创建/发送/接受消息队列时所用到的通用创建函数,这里先补充一下数据结构中关于队列的知识。 目录 1. 队列原理 1.1 顺序队列操作 1.2 循环队列操作 2.消息队列原…

N32G45XVL-STB之移植LVGL(lvgl-8.2.0)

目录 概述 1 软硬件介绍 1.1 软件版本信息 1.2 ST7796-LCD 1.3 MCU IO与LCD PIN对应关系 2 认识LVGL 2.1 LVGL官网 2.2 LVGL库文件下载 3 移植LVGL 3.1 准备移植文件 3.2 添加lvgl库文件到项目 3.2.1 src下的文件 3.2.2 examples下的文件 3.2.3 配置文件路径 3.2…