基于Pytorch框架的深度学习EfficientNetV2神经网络中草药识别分类系统源码

 第一步:准备数据

5种中草药数据:self.class_indict = ["百合", "党参", "山魈", "枸杞", "槐花", "金银花"]

,总共有900张图片,每个文件夹单独放一种数据

第二步:搭建模型

本文选择一个EfficientNetV2网络,其原理介绍如下:

        该网络主要使用训练感知神经结构搜索缩放的组合;在EfficientNetV1的基础上,引入了Fused-MBConv到搜索空间中;引入渐进式学习策略自适应正则强度调整机制使得训练更快;进一步关注模型的推理速度训练速度

与EfficientV1相比,主要有以下不同:

  1. V2中除了使用MBConv模块外,还使用了Fused-MBConv模块
  2. V2中会使用较小的expansion ratio,在V1中基本都是6。这样的好处是能够减少内存访问开销
  3. V2中更偏向使用更小的kernel_size(3 x 3),在V1中很多5 x 5。优于3 x 3的感受野是比5 x 5小的,所以需要堆叠更多的层结构以增加感受野
  4. 移除了V1中最优一个步距为1的stage

第三步:训练代码

1)损失函数为:交叉熵损失函数

2)训练代码:

import os
import math
import argparse

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler

from model import efficientnetv2_s as create_model
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    print(args)
    print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
    tb_writer = SummaryWriter()
    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    img_size = {"s": [300, 384],  # train_size, val_size
                "m": [384, 480],
                "l": [384, 480]}
    num_model = "s"

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size[num_model][0]),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
        "val": transforms.Compose([transforms.Resize(img_size[num_model][1]),
                                   transforms.CenterCrop(img_size[num_model][1]),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}

    # 实例化训练数据集
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # 实例化验证数据集
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    # 如果存在预训练权重则载入
    model = create_model(num_classes=args.num_classes).to(device)
    if args.weights != "":
        if os.path.exists(args.weights):
            weights_dict = torch.load(args.weights, map_location=device)
            load_weights_dict = {k: v for k, v in weights_dict.items()
                                 if model.state_dict()[k].numel() == v.numel()}
            print(model.load_state_dict(load_weights_dict, strict=False))
        else:
            raise FileNotFoundError("not found weights file: {}".format(args.weights))

    # 是否冻结权重
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # 除head外,其他权重全部冻结
            if "head" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=1E-4)
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch)

        scheduler.step()

        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)

        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--lrf', type=float, default=0.01)

    # 数据集所在根目录
    # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str,
                        default=r"G:\demo\data\ChineseMedicine")

    # download model weights
    # 链接: https://pan.baidu.com/s/1uZX36rvrfEss-JGj4yfzbQ  密码: 5gu1
    parser.add_argument('--weights', type=str, default='./pre_efficientnetv2-s.pth',
                        help='initial weights path')
    parser.add_argument('--freeze-layers', type=bool, default=True)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)

第四步:统计正确率

第五步:搭建GUI界面

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码的下载路径(新窗口打开链接):基于Pytorch框架的深度学习EfficientNetV2神经网络中草药识别分类系统源码

有问题可以私信或者留言,有问必答

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/663826.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

初识Spring Boot:构建项目结构与组件解析

目录 前言 第一点:项目的结构 第二点:controller类的创建与使用(构造器) 第二点:service类的创建与使用(逻辑层) 第三点:Mapper类的创建与使用(数据操作) 总结 前言 在进行Sp…

HQChart使用教程100-uniapp如何在vue3运行微信小程序

HQChart使用教程100-uniapp如何在vue3运行微信小程序 症状原因分析解决思路解决步骤1. 修改vender.js2. 修改HQChartControl.js 完整实例HQChart代码地址 症状 HQChart插件在uniappvue3的项目编译成小程序以后, 运行会报错,见下图。 原因分析 查了下…

抖音太可怕了,我卸载了

这两天刷短视频,上瘾了,太可怕了。 自己最近一直在研究短视频制作,所以下载了抖音,说实话,我之前手机上并没有抖音,一直在用B站。 用了两天抖音,我发现,这玩意比刷B站还容易上瘾啊…

【深度学习-第6篇】使用python快速实现CNN多变量回归预测(使用pytorch框架)

上一篇我们讲了使用CNN进行分类的python代码: Mr.看海:【深度学习-第5篇】使用Python快速实现CNN分类(模式识别)任务,含一维、二维、三维数据演示案例(使用pytorch框架) 这一篇我们讲CNN的多变…

对网工的误解,早就不是一点半点了

号主:老杨丨11年资深网络工程师,更多网工提升干货,请关注公众号:网络工程师俱乐部 上午好,我的网工朋友 很多人对网工是有误解的,同为网工的我深有感受。 虽然我的阅历不如老杨总多,但也在这行…

开源与闭源 AI 模型:发展路径的比较与前瞻

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

解决 iOS 端小程序「saveVideoToPhotosAlbum:fail invalid video」问题

场景复现: const url https://mobvoi-digitalhuman-video-public.weta365.com/1788148372310446080.mp4uni.downloadFile({url,success: (res) > {uni.saveVideoToPhotosAlbum({filePath: res.tempFilePath,success: (res) > {console.log("res > &…

chap4 simple neural network

全连接神经网络 问题描述 利用numpy和pytorch搭建全连接神经网络。使用numpy实现此练习需要自己手动求导,而pytorch具有自动求导机制。 我们首先先手动算一下反向传播的过程,使用的模型和初始化权重、偏差和训练用的输入和输出值如下: 我…

R语言绘图 --- 折线图(Biorplot 开发日志 --- 1)

「写在前面」 在科研数据分析中我们会重复地绘制一些图形,如果代码管理不当经常就会忘记之前绘图的代码。于是我计划开发一个 R 包(Biorplot),用来管理自己 R 语言绘图的代码。本系列文章用于记录 Biorplot 包开发日志。 相关链接…

通过强化学习彻底改变大型数据集特征选择

文章目录 一、说明二、强化学习:特征选择的马尔可夫决策问题三、用于使用强化学习进行特征选择的 python 库3.1. 数据预处理3.2. 安装和导入FSRLearning库 四、结论和参考文献 一、说明 了解强化学习如何改变机器学习模型的特征选择。通过实际示例和专用的 Python 库…

Qt6.4.2基于CMake添加Qt3DCore模块报错

在文档中说明是添加 find_package(Qt6 REQUIRED COMPONENTS 3dcore) target_link_libraries(mytarget PRIVATE Qt6::3dcore)find_package是没有问题,但是target_link_libraries会报错,报拼写错误,无法链接上Qt6::3dcore 需要使用“3DCore”…

工厂如何最大化mes系统的价值

mes系统(Manufacturing Execution System)是现代工厂管理中的一个重要系统,它可以实现生产过程中的信息约束与控制,促进生产流程的跟踪和分析,提高生产效率及质量。 一、整合mes系统和erp系统 mes系统和erp系统是两个…

STM32 IIC协议

本文代码使用 HAL 库。 文章目录 前言一、什么是IIC协议二、IIC信号三、IIC协议的通讯时序1. 写操作2. 读操作 四、上拉电阻作用总结 前言 从这篇文章开始为大家介绍一些通信协议,包括 UART,SPI,IIC等。 UART串口通讯协议 SPI通信协议 一、…

【深度学习】YOLOv10实战:20行代码将笔记本摄像头改装成目标检测监控

目录 一、引言 二、YOLOv10视觉目标检测—原理概述 2.1 什么是YOLO 2.2 YOLO的网络结构 三、YOLOv10视觉目标检测—训练推理 3.1 YOLOv10安装 3.1.1 克隆项目 3.1.2 创建conda环境 3.1.3 下载并编译依赖 3.2 YOLOv10模型推理 3.2.1 模型下载 3.2.2 WebUI推理 …

微服务架构-微服务架构的挑战与微服务化的具体时机

目录 一、微服务架构的挑战 1.1 概述 1.2 服务拆分 1.3 开发挑战 1.4 测试挑战 1.4.1 开箱即用、一键部署的集成环境 1.4.2 测试场景和测试确定性 1.4.3 微服务相关的非功能测试 1.4.4 自动化测试 1.5 运维挑战 1.5.1 监控 1.5.2 部署 1.5.3 问题追查 1.5.4 依赖管…

chrome调试手机网页

前期准备 1、 PC端安装好chrmoe浏览器 2、 安卓手机安装好chrmoe浏览器 3、 数据线 原文地址:https://lengmo714.top/343880cb.html 手机打开调试模式 进入手机设置,找到开发者模式,然后启用USB调试 打开PC端chrome调试功能 1、点击chr…

视频汇聚平台EasyCVR对接GA/T 1400视图库:结构化数据(人员/人脸、车辆、物品)对象XMLSchema描述

在信息化浪潮席卷全球的背景下,公安信息化建设日益成为提升社会治理能力和维护社会稳定的关键手段。其中,GA/T 1400标准作为公安视频图像信息应用系统的核心规范,以其结构化数据处理与应用能力,为公安信息化建设注入了强大的动力。…

webpack5零基础入门-19HMR的应用

1.定义 HMR即HotModuleReplacement 开发时,当我们修改了其中一个模块的代码webpack默认会将所有模块重新打包编译,速度很慢所以我们需要做到修改摸个模块代码,只对这个模块的代码重新打包编译,其他模块不变,这样打包…

【excel】设置二级联动菜单

文章目录 【需求】在一级菜单选定后,二级菜单联动显示一级菜单下的可选项【步骤】step1 制作辅助列1.列转行2.在辅助列中匹配班级成员 之前做完了 【excel】设置可变下拉菜单(一级联动下拉菜单),开始做二级联动菜单。 【需求】在…

算法(六)计数排序

文章目录 计数排序技术排序简介算法实现 计数排序 技术排序简介 计数排序是利用数组下标来确定元素的正确位置的。 假定数组有10个整数,取值范围是0~10,可以根据这有限的范围,建立一个长度为11的数组。数组下标从0到10,元素初始…