基于Pytorch框架的深度学习Vision Transformer神经网络蝴蝶分类识别系统源码

 第一步:准备数据

6种蝴蝶数据:self.class_indict = ["曙凤蝶", "麝凤蝶", "多姿麝凤蝶", "旖凤蝶", "红珠凤蝶", "热斑凤蝶"],总共有900张图片,每个文件夹单独放一种数据

第二步:搭建模型

本文选择一个Vision Transformer网络,其原理介绍如下:

Vision Transformer(ViT)是一种基于Transformer架构的深度学习模型,用于图像识别和计算机视觉任务。与传统的卷积神经网络(CNN)不同,ViT直接将图像视为一个序列化的输入,并利用自注意力机制来处理图像中的像素关系。

ViT通过将图像分成一系列的图块(patches),并将每个图块转换为向量表示作为输入序列。然后,这些向量将通过多层的Transformer编码器进行处理,其中包含了自注意力机制和前馈神经网络层。这样可以捕捉到图像中不同位置的上下文依赖关系。最后,通过对Transformer编码器输出进行分类或回归,可以完成特定的视觉任务。

Vit model结构图
Vit的模型结构如下图所示。vit是将图像块应用于transformer。CNN是以滑窗的思想用卷积核在图像上进行卷积得到特征图。为了可以使图像仿照NLP的输入序列,我们可以先将图像分成块(patch),再将这些图像块进行平铺后输入到网络中(这样就变成了图像序列),然后通过transformer进行特征提取,最后再通过MLP对这些特征进行分类【其实就可以理解为在以往的CNN分类任务中,将backbone替换为transformer】。

第三步:训练代码

1)损失函数为:交叉熵损失函数

2)训练代码:

import os
import math
import argparse

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms


from my_dataset import MyDataSet
from vit_model import vit_base_patch16_224_in21k as create_model
from utils import read_split_data, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    tb_writer = SummaryWriter()

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}

    # 实例化训练数据集
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # 实例化验证数据集
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    model = create_model(num_classes=args.num_classes, has_logits=False).to(device)

    if args.weights != "":
        assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
        weights_dict = torch.load(args.weights, map_location=device)
        # 删除不需要的权重
        del_keys = ['head.weight', 'head.bias'] if model.has_logits \
            else ['pre_logits.fc.weight', 'pre_logits.fc.bias', 'head.weight', 'head.bias']
        for k in del_keys:
            del weights_dict[k]
        print(model.load_state_dict(weights_dict, strict=False))

    if args.freeze_layers:
        for name, para in model.named_parameters():
            # 除head, pre_logits外,其他权重全部冻结
            if "head" not in name and "pre_logits" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=5E-5)
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch)

        scheduler.step()

        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)

        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=6)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--lrf', type=float, default=0.01)

    # 数据集所在根目录
    # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str,
                        default=r"G:\demo\data\Butterfly20")
    parser.add_argument('--model-name', default='', help='create model name')

    # 预训练权重路径,如果不想载入就设置为空字符
    parser.add_argument('--weights', type=str, default='./vit_base_patch16_224_in21k.pth',
                        help='initial weights path')
    # 是否冻结权重
    parser.add_argument('--freeze-layers', type=bool, default=True)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)

第四步:统计正确率

第五步:搭建GUI界面

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码的下载路径(新窗口打开链接):基于Pytorch框架的深度学习Vision Transformer神经网络蝴蝶分类识别系统源码

有问题可以私信或者留言,有问必答

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/730261.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

使用自签名 TLS 将 Dremio 连接到 MinIO

Dremio 是一个开源的分布式分析引擎,为数据探索、转换和协作提供简单的自助服务界面。Dremio 的架构建立在 Apache Arrow(一种高性能列式内存格式)之上,并利用 Parquet 文件格式实现高效存储。有关 Dremio 的更多信息,…

跑通并使用Yolo v5的源代码并进行训练—目标检测

跑通并使用Yolo v5的源代码并进行训练 摘要:yolo作为目标检测计算机视觉领域的核心网络模型,虽然到24年已经出到了v10的版本,但也很有必要对之前的核心版本v5版本进行进一步的学习。在学习yolo v5的时候因为缺少论文所以要从源代码入手来体验…

Eureka 学习笔记(1)

一 、contextInitialized() eureka-core里面,监听器的执行初始化的方法,是contextInitialized()方法,这个方法就是整个eureka-server启动初始化的一个入口。 Overridepublic void contextInitialized(ServletContextEvent event) {try {init…

生产实习Day9 ---- Scala介绍

文章目录 Scala:融合面向对象与函数式编程的强大语言引言Scala与Java的互操作性Scala在大数据处理中的应用Scala的并发编程Scala的学习资源和社区结论 Scala:融合面向对象与函数式编程的强大语言 引言 Scala,全称Scalable Language&#xff…

教你开发一个适合外贸的消息群发工具!

在全球化日益加速的今天,外贸业务已经成为许多企业不可或缺的一部分,而在外贸业务中,高效的消息群发工具则扮演着至关重要的角色。 它能够帮助企业快速、准确地传达产品信息、促销活动等重要内容,从而提升业务效率和客户满意度&a…

项目经验——交通行业数据可视化大屏、HMI设计

交通行业数据大屏、HMI设计时要的注意点:清晰可读、简洁直观、适配性强。颜色对比度满足WCAG标准,深色背景减少干扰,实时展示交通数据,支持有线网络控制内容更新,保障驾驶安全与决策效率。

Linux企业 集群批量管理-秘钥认证

集群批量管理-秘钥认证 概述 管理更加轻松:两个节点,通过秘钥认证形成进行访问,不需要输入密码,单向服务要求(应用场景): 一些服务在使用前要求我们做秘钥认证 手动写批量管理脚本名字&#x…

A800显卡驱动安装(使用deb安装)

重新安装显卡驱动,查阅了资料将过程记录如下: 1.下载deb安装包 打开nvidia官网查找对应的驱动版本,A800所在的选项卡位置如图: 点击查找后下载得到的是nvidia-driver-local-repo-ubuntu2004-550.90.07_1.0-1_amd64.deb安装包 2.…

猫头虎分享:IPython的使用技巧整理

🐯 猫头虎分享:IPython的使用技巧整理 关于猫头虎 大家好,我是猫头虎,别名猫头虎博主,擅长的技术领域包括云原生、前端、后端、运维和AI。我的博客主要分享技术教程、bug解决思路、开发工具教程、前沿科技资讯、产品…

【学一点儿前端】单页面点击前进或后退按钮导致的内存泄露问题(history.listen监听器清除)

今天测试分配了一个比较奇怪的问题,在单页面应用中,反复点击“上一步”和“下一步”按钮时,界面表现出逐渐变得卡顿。为分析这一问题,我用Chrome的性能监控工具进行了浏览器性能录制。结果显示,每次点击“上一步”按钮…

【2024最新华为OD-C/D卷试题汇总】[支持在线评测] 任务安排问题(200分) - 三语言AC题解(Python/Java/Cpp)

🍭 大家好这里是清隆学长 ,一枚热爱算法的程序员 ✨ 本系列打算持续跟新华为OD-C/D卷的三语言AC题解 💻 ACM银牌🥈| 多次AK大厂笔试 | 编程一对一辅导 👏 感谢大家的订阅➕ 和 喜欢💗 📎在线评测链接 https://app5938.acapp.acwing.com.cn/contest/2/problem/OD…

cron.timezone

系统 date 数据库 show timezone插件 show cron.timezonealter system set cron.timezonePRC;show cron.timezone

前端新手小白的Vue3入坑指南

昨天有同学说想暑假在家学一学Vue3,问我有没有什么好的文档,我给他找了一些,然后顺带着,自己也写一篇吧,希望可以给新手小白们一些指引,Vue3欢迎你。 目录 1 项目安装 1.1 初始化项目 1.2 安装初始化依…

CDGA|数据治理要点是数据稳定、规范、安全,就像盖楼盘一样

在数字化浪潮汹涌的时代,数据已经成为企业运营和社会发展的核心驱动力。如同高楼大厦需要稳固的地基和规范的施工流程,数据治理同样需要确保数据的稳定性、规范性和安全性,以构建坚实可靠的数据大厦。 数据治理的首要任务是确保数据的稳定性 …

一文读懂过零检测电路的作用、电路原理图及应用

过零检测电路是一种常见的应用,其中运算放大器用作比较器。它通常用于跟踪正弦波形的变化,例如过零电压从正到负或从负到正。它还可以用作方波发生器。过零检测电路有许多应用,例如标记信号发生器、相位计和频率计。#过零检测电路#可以采用多…

人工智能与生物信息组学 || 2. 非编码 RNA 与疾病关联分析 || 2.2 非编码 miRNA 与疾病关联关系预测

非编码 miRNA 与疾病关联关系预测 越来越多的研究表明,一个复杂疾病通常经由多个 miRNA 协同调控,一个 miRNA 通常参与多个疾病的发生发展过程。因此,预测 miRNA 与疾病的关联关系成为一个当前的研究热点。下面我们将探讨一种 miRNA 和疾病关…

spring-gateway配置说明

在开发过程中遇到的一些配置问题,记录下来以供参考 spring-gateway版本是2.2.9-release,使用的spring cloud dependence 是 Hoxton.SR12 在依赖eureka 服务发现并自动将发现服务器加入到router中的时候,需要指定对应的服务进行添加,根据文档…

对比 Axios 和 Fetch:选择最适合的 HTTP 请求方法

在前端开发中,处理 HTTP 请求是一个常见且重要的任务。JavaScript 提供了多种方式来发送网络请求,其中最受欢迎的两种方式分别就是 Fetch API 和 Axios。尽管两者都能完成同样的任务,即从客户端向服务器发送请求并接收响应,但它们…

学校教育为什么要选择SOLIDWORKS教育版?

在数字化和智能化时代,学校教育正面临着挑战与机遇。为了培养具备创新能力和实践技能的新时代人才,学校教育需要引入先进的教学工具和资源。SOLIDWORKS教育版作为一款专为教育和培训目的而设计的软件,以其全方面的功能、友好的用户界面、丰富…

[信号与系统]模拟域中的一阶低通滤波器和二阶滤波器

前言 不是学电子出身的,这里很多东西是问了朋友… 模拟域中的一阶低通滤波器传递函数 模拟域中的一阶低通滤波器的传递函数可以表示为: H ( s ) 1 s ω c H(s) \frac{1}{s \omega_c} H(s)sωc​1​ 这是因为一阶低通滤波器的设计目标是允许低频信…