StarNet实战:使用StarNet实现图像分类任务(一)

文章目录

  • 摘要
  • 安装包
    • 安装timm
  • 数据增强Cutout和Mixup
  • EMA
  • 项目结构
  • 计算mean和std
  • 生成数据集

摘要

https://arxiv.org/pdf/2403.19967
论文主要集中在介绍和分析一种新兴的学习范式——星操作(Star Operation),这是一种通过元素级乘法融合不同子空间特征的方法,通过元素级乘法(类似于“星”形符号的乘法操作)将不同子空间的特征进行融合,从而在多个研究领域中展现出出色的性能和效率。

星操作在自然语言处理(NLP)和计算机视觉(CV)等多个领域中都得到了成功应用。例如,在自然语言处理中,Monarch Mixer、Mamba、Hyena Hierarchy和GLU等模型都采用了星操作;在计算机视觉中,FocalNet、HorNet和VAN等模型也利用了星操作进行特征融合。

尽管星操作在多个领域中都取得了显著成果,但其背后的基本原理尚未得到全面分析和验证。StarNet通过深入探究星操作的细节,发现星操作具有将输入映射到极高维、非线性特征空间的能力。这种映射方式与传统增加网络宽度的方法不同,而是通过跨通道特征对乘实现了一种类似于多项式核函数的非线性高维映射。

当将星操作融入神经网络并堆叠多层时,每一层都使隐含的维度复杂度呈指数级增长。这种高效的特征融合方式使得星操作能够在紧凑的特征空间内实现近乎无限的维度,从而极大地提高了模型的表示能力和性能。

在这里插入图片描述

本文使用StarNet模型实现图像分类任务,模型选择starnet_s1,在植物幼苗分类任务ACC达到了95%+。

在这里插入图片描述
在这里插入图片描述

通过这篇文章能让你学到:

  1. 如何使用数据增强,包括transforms的增强、CutOut、MixUp、CutMix等增强手段?
  2. 如何实现StarNet模型实现训练?
  3. 如何使用pytorch自带混合精度?
  4. 如何使用梯度裁剪防止梯度爆炸?
  5. 如何使用DP多显卡训练?
  6. 如何绘制loss和acc曲线?
  7. 如何生成val的测评报告?
  8. 如何编写测试脚本测试测试集?
  9. 如何使用余弦退火策略调整学习率?
  10. 如何使用AverageMeter类统计ACC和loss等自定义变量?
  11. 如何理解和统计ACC1和ACC5?
  12. 如何使用EMA?

如果基础薄弱,对上面的这些功能难以理解可以看我的专栏:经典主干网络精讲与实战
这个专栏,从零开始时,一步一步的讲解这些,让大家更容易接受。

安装包

安装timm

使用pip就行,命令:

pip install timm

mixup增强和EMA用到了timm

数据增强Cutout和Mixup

为了提高成绩我在代码中加入Cutout和Mixup这两种增强方式。实现这两种增强需要安装torchtoolbox。安装命令:

pip install torchtoolbox

Cutout实现,在transforms中。

from torchtoolbox.transform import Cutout
# 数据预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    Cutout(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

])

需要导入包:from timm.data.mixup import Mixup,

定义Mixup,和SoftTargetCrossEntropy

  mixup_fn = Mixup(
    mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
    prob=0.1, switch_prob=0.5, mode='batch',
    label_smoothing=0.1, num_classes=12)
 criterion_train = SoftTargetCrossEntropy()

Mixup 是一种在图像分类任务中常用的数据增强技术,它通过将两张图像以及其对应的标签进行线性组合来生成新的数据和标签。
参数详解:

mixup_alpha (float): mixup alpha 值,如果 > 0,则 mixup 处于活动状态。

cutmix_alpha (float):cutmix alpha 值,如果 > 0,cutmix 处于活动状态。

cutmix_minmax (List[float]):cutmix 最小/最大图像比率,cutmix 处于活动状态,如果不是 None,则使用这个 vs alpha。

如果设置了 cutmix_minmax 则cutmix_alpha 默认为1.0

prob (float): 每批次或元素应用 mixup 或 cutmix 的概率。

switch_prob (float): 当两者都处于活动状态时切换cutmix 和mixup 的概率 。

mode (str): 如何应用 mixup/cutmix 参数(每个’batch’,‘pair’(元素对),‘elem’(元素)。

correct_lam (bool): 当 cutmix bbox 被图像边框剪裁时应用。 lambda 校正

label_smoothing (float):将标签平滑应用于混合目标张量。

num_classes (int): 目标的类数。

EMA

EMA(Exponential Moving Average)是指数移动平均值。在深度学习中的做法是保存历史的一份参数,在一定训练阶段后,拿历史的参数给目前学习的参数做一次平滑。具体实现如下:


import logging
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn

_logger = logging.getLogger(__name__)

class ModelEma:
    def __init__(self, model, decay=0.9999, device='', resume=''):
        # make a copy of the model for accumulating moving average of weights
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if device:
            self.ema.to(device=device)
        self.ema_has_module = hasattr(self.ema, 'module')
        if resume:
            self._load_checkpoint(resume)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def _load_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        assert isinstance(checkpoint, dict)
        if 'state_dict_ema' in checkpoint:
            new_state_dict = OrderedDict()
            for k, v in checkpoint['state_dict_ema'].items():
                # ema model may have been wrapped by DataParallel, and need module prefix
                if self.ema_has_module:
                    name = 'module.' + k if not k.startswith('module') else k
                else:
                    name = k
                new_state_dict[name] = v
            self.ema.load_state_dict(new_state_dict)
            _logger.info("Loaded state_dict_ema")
        else:
            _logger.warning("Failed to find state_dict_ema, starting from loaded model weights")

    def update(self, model):
        # correct a mismatch in state dict keys
        needs_module = hasattr(model, 'module') and not self.ema_has_module
        with torch.no_grad():
            msd = model.state_dict()
            for k, ema_v in self.ema.state_dict().items():
                if needs_module:
                    k = 'module.' + k
                model_v = msd[k].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)

加入到模型中。

#初始化
if use_ema:
     model_ema = ModelEma(
            model_ft,
            decay=model_ema_decay,
            device='cpu',
            resume=resume)

# 训练过程中,更新完参数后,同步update shadow weights
def train():
    optimizer.step()
    if model_ema is not None:
        model_ema.update(model)


# 将model_ema传入验证函数中
val(model_ema.ema, DEVICE, test_loader)

针对没有预训练的模型,容易出现EMA不上分的情况,这点大家要注意啊!

项目结构

StarNet_Demo
├─data1
│  ├─Black-grass
│  ├─Charlock
│  ├─Cleavers
│  ├─Common Chickweed
│  ├─Common wheat
│  ├─Fat Hen
│  ├─Loose Silky-bent
│  ├─Maize
│  ├─Scentless Mayweed
│  ├─Shepherds Purse
│  ├─Small-flowered Cranesbill
│  └─Sugar beet
├─models
│  └─starnet.py
├─mean_std.py
├─makedata.py
├─train.py
└─test.py

mean_std.py:计算mean和std的值。
makedata.py:生成数据集。
train.py:训练StarNet模型
models:来源官方代码。

计算mean和std

为了使模型更加快速的收敛,我们需要计算出mean和std的值,新建mean_std.py,插入代码:

from torchvision.datasets import ImageFolder
import torch
from torchvision import transforms

def get_mean_and_std(train_data):
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=0,
        pin_memory=True)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    for X, _ in train_loader:
        for d in range(3):
            mean[d] += X[:, d, :, :].mean()
            std[d] += X[:, d, :, :].std()
    mean.div_(len(train_data))
    std.div_(len(train_data))
    return list(mean.numpy()), list(std.numpy())

if __name__ == '__main__':
    train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())
    print(get_mean_and_std(train_dataset))

数据集结构:

image-20220221153058619

运行结果:

([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])

把这个结果记录下来,后面要用!

生成数据集

我们整理还的图像分类的数据集结构是这样的

data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loose Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Purse
├─Small-flowered Cranesbill
└─Sugar beet

pytorch和keras默认加载方式是ImageNet数据集格式,格式是

├─data
│  ├─val
│  │   ├─Black-grass
│  │   ├─Charlock
│  │   ├─Cleavers
│  │   ├─Common Chickweed
│  │   ├─Common wheat
│  │   ├─Fat Hen
│  │   ├─Loose Silky-bent
│  │   ├─Maize
│  │   ├─Scentless Mayweed
│  │   ├─Shepherds Purse
│  │   ├─Small-flowered Cranesbill
│  │   └─Sugar beet
│  └─train
│      ├─Black-grass
│      ├─Charlock
│      ├─Cleavers
│      ├─Common Chickweed
│      ├─Common wheat
│      ├─Fat Hen
│      ├─Loose Silky-bent
│      ├─Maize
│      ├─Scentless Mayweed
│      ├─Shepherds Purse
│      ├─Small-flowered Cranesbill
│      └─Sugar beet

新增格式转化脚本makedata.py,插入代码:

import glob
import os
import shutil

image_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if os.path.exists(file_dir):
    print('true')
    #os.rmdir(file_dir)
    shutil.rmtree(file_dir)#删除再建立
    os.makedirs(file_dir)
else:
    os.makedirs(file_dir)

from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for file in trainval_files:
    file_class=file.replace("\\","/").split('/')[-2]
    file_name=file.replace("\\","/").split('/')[-1]
    file_class=os.path.join(train_root,file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)

for file in val_files:
    file_class=file.replace("\\","/").split('/')[-2]
    file_name=file.replace("\\","/").split('/')[-1]
    file_class=os.path.join(val_root,file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)

完成上面的内容就可以开启训练和测试了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/713438.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

[大模型]XVERSE-7B-chat langchain 接入

XVERSE-7B-Chat为XVERSE-7B模型对齐后的版本。 XVERSE-7B 是由深圳元象科技自主研发的支持多语言的大语言模型(Large Language Model),参数规模为 70 亿,主要特点如下: 模型结构:XVERSE-7B 使用主流 Deco…

echarts学习:通过图例事件实现选中后控制多条折线的显隐

1.问题描述 我在工作中遇到了这样一个需求:我们都知道点击echarts折线图的图例,是可以控制折线的显隐的。我现在希望点击某一个图例可以改变多条折线的显隐。 例如在下面这张图中,我将“xxx水位”和“yyy水位”分为一组;将“xxx…

521. 最长特殊序列 Ⅰ(Rust单百解法-脑筋急转弯)

题目 给你两个字符串 a 和 b,请返回 这两个字符串中 最长的特殊序列 的长度。如果不存在,则返回 -1 。 「最长特殊序列」 定义如下:该序列为 某字符串独有的最长 子序列 (即不能是其他字符串的子序列) 。 字符串 s …

【云原生】docker swarm 使用详解

目录 一、前言 二、容器集群管理问题 2.1 docker集群管理问题概述 2.1.1 docker为什么需要容器部署 2.2 docker容器集群管理面临的挑战 三、docker集群部署与管理解决方案 四、Docker Swarm概述 4.1 Docker Swarm是什么 4.1.1 Docker Swarm架构图 4.1.2 Docker Swarm几…

【MySQL】在CentOS环境下安装MySQL

目录 一、卸载残留环境 二、获取官方yum源 三、安装yum源 四、安装MySQL 五、启动MySQL 一、卸载残留环境 输入 ps axj | grep mysql 查看是否存在正在运行的MySQL服务 如果有,则先输入 systemctl stop mysqld 来关闭服务 然后输入 rpm -qa | grep mysql 查看…

Docker MySQL Shutting down mysqld

6月初至6月15日发现MySQL无故停机多次,导致系统无法使用。接下来各种日志查看,排查原因。先附上一份Docker种MySQL的日志的截图。 一、根据Docker的日志初步估计是数据库内存飙升,从而被系统杀掉进程 查询Linux系统日志,在宿主机…

Python武器库开发-武器库篇之Mongodb未授权漏洞扫描器(五十六)

Python武器库开发-武器库篇之Mongodb未授权漏洞扫描器(五十六) MongoDB 未授权访问漏洞简介以及危害 MongoDB是一款非常受欢迎的开源NoSQL数据库,广泛应用于各种Web应用和移动应用中。然而,由于默认配置的不当或者管理员的疏忽,导致不少Mon…

常用串口助手推荐

串口助手作为嵌入式软件工程师最常用的工具,相信大部分的同学都不陌生,这里就不介绍它的使用啦,介绍介绍有哪些好用的款。感兴趣的小伙伴也可以自己去写一个串口助手。 一、SSCOM5.13.1 站内下载资源: https://download.csdn.n…

51单片机STC89C52RC——2.2 独立按键控制LED亮灭Plus

目的 当独立K1按键按一下(立即松开),LED D1点亮。再按一下K1(立即松开)LED D1熄灭。 与前一节《51单片机STC89C52RC——2.1 独立按键控制LED亮灭》当独立K1按键按下时LED D1 点亮,松开D1熄灭 效果不一…

PyTorch 张量数据类型

【数据类型】Python 与 PyTorch 常见数据类型对应: 用 a.type() 获取数据类型,用 isinstance(a, 目标类型) 进行类型合法化检测 >>> import torch >>> a torch.randn(2,3) >>> a tensor([[-1.7818, -0.2472, -2.0684],[ 0.…

单片机与DHT11温湿度检测设计

本次设计是采用STC89C54单片机加上低成本的温湿度模块DHT11构成的温湿度检测系统。设计主要由硬件与软件两部分设计构成。硬件方面包括单片机STC89C54、温湿度模块DHT11、显示模块LCD1602、电池电源、I2C存储器以及控制按键等5个部分。此系统完全基于单片机最小系统并进行一定的…

Open vSwitch 中 vswitchd 事件上报

一、数据包转发流程与 vswitchd 事件上报 Open vSwitch 的数据包转发流程如下图所示: 在数据包的转发流程中,提到过慢速路径的概念:即当数据包在内核空间无法完全处理时,会产生 upcall 调用,将数据包从内核空间转发到用…

XGBoost预测及调参过程(+变量重要性)--血友病计数数据

所使用的数据是血友病数据,如有需要,可在主页资源处获取,数据信息如下: 读取数据及数据集区分 数据预处理及区分数据集代码如下(详细预处理说明见上篇文章--随机森林): import pandas as pd im…

RPG游戏完整指南

环境:unity2021urp 本教程教大家如何使用Unity创建一个RPG游戏,玩家可以在城镇场景中进行导航并寻找战斗,并在战斗中遇到不同类型的敌人。玩家可以向敌人施加不同的动作,如:常规攻击和撤离。这会是一个十分有趣的体验。…

AI时代新爬虫:网站自动转LLM数据,firecrawl深度玩法解读

在大模型的时代,爬虫技术也有了很多新的发展,最近出现了专门针对大模型来提取网站信息的爬虫,一键将网页内容转换为LLM-ready的数据。今天我们介绍其中的开源热门代表:firecrawl。 firecrawl 是什么 FireCrawl是一款创新的爬虫工…

数据资产治理与数据质量提升:构建完善的数据治理体系,确保数据资产的高质量与准确性

一、引言 随着信息技术的迅猛发展,数据已经成为企业和社会发展的重要资产。然而,数据资产的有效治理与数据质量的提升,是企业实现数字化转型、提升竞争力的关键。本文旨在探讨数据资产治理与数据质量提升的重要性,并提出构建完善…

开源高效API管理工具:RAP

RAP:简化API开发,提升团队协作效率- 精选真开源,释放新价值。 概览 RAP(RESTful API Project)是一个开源的API管理工具,由阿里巴巴团队开发并维护。它旨在帮助前后端开发人员通过一个统一的平台来设计、开…

Linux 按键输入实验

Linux 按键输入实验 1、添加 pinctrl 节点 首先修改在设备树里面添加关于按键的节点。I.MX6U-ALPHA 开发板上的 KEY 使用了 UART1_CTS_B 这个 PIN,打开 imx6ull-alientekemmc.dts,在 iomuxc 节点的 imx6ul-evk 子节点下创建一个名为“pinctrl_key”的子…

华为机考入门python3--(36)牛客36-字符串加密

分类:字符串 知识点: 判断一个元素是否在集合中 if char not in key_set 计算字母差 index ord(char) - ord(a) 题目来自【牛客】 # 生成加密表 def generate_cipher_table(key):key_set set()cipher_table ""# 去重for char in k…

TVBOX 最新版下载+视频源教程

下载链接 wx 搜索 Geek 前端 发送电视资源进行获取 操作教程