看图识药,python开发实现基于VisionTransformer的119种中草药图像识别系统

中药药材图像识别相关的实践在前面的系列博文中已经有了相应的实践了,感兴趣的话可以自行移步阅读即可,每篇文章的侧重点不同:

《python基于轻量级GhostNet模型开发构建23种常见中草药图像识别系统》

《基于轻量级MnasNet模型开发构建40种常见中草药图像识别系统》

《基于ResNet模型的908种超大规模中草药图像识别系统》

本文的核心思想是想要应用实践VIT(Vision Transformer)来开发构建图像识别系统,首先看下实例效果:

Vision Transformer(ViT)是一种基于自注意力机制的视觉模型,用于图像分类和其他计算机视觉任务。它是由Dosovitskiy等人在2020年提出的,将Transformer模型成功应用于图像领域。

ViT的构建原理如下:

  1. 输入图像划分为固定大小的图像块(或称为“补丁”),并通过一个线性变换将每个图像块映射为一个向量。这些向量组成了输入序列。

  2. 使用位置编码将位置信息引入输入序列。位置编码是一个学习的过程,用于为每个输入位置提供相对和绝对位置信息。

  3. 输入序列首先通过多头注意力(Multi-Head Attention)模块进行处理。多头注意力允许模型在不同的表示子空间中学习关注不同的图像特征。

  4. 在多头注意力模块中,每个补丁向量都与其他补丁向量进行交互,并计算其自注意力得分。这些得分表示了补丁之间的相关性,模型可以根据这些得分对不同补丁的重要性进行加权。

  5. 通过加权和补丁向量的线性组合,得到了每个补丁向量的新表示。这个表示包含了该补丁与其他补丁的相关性信息。

  6. 接下来,通过一个前馈神经网络(Feed-Forward Network)对每个补丁向量的新表示进行非线性变换,以更好地捕捉图像特征。

  7. 经过多个注意力和前馈神经网络堆叠的层,最终得到了一个编码了整个图像信息的向量序列。

  8. 为了进行图像分类,可以使用一个全局平均池化层(Global Average Pooling)将向量序列转换为一个固定长度的向量表示。然后,可以通过一个全连接层将这个向量映射到不同类别的概率分布。

总体来说,Vision Transformer通过将图像划分为补丁并利用自注意力机制对这些补丁进行交互,实现了对图像特征的学习和编码。相较于传统的卷积神经网络,ViT不需要使用卷积操作,而是完全基于自注意力机制进行图像特征的建模。

本文使用到的数据集来源于网络数据采集与人工处理,主要是收集了常见的100多种中药药材,数据集加载解析处理实现如下:

# 加载解析创建数据集
if not os.path.exists("dataset.json"):
    train_dataset = []
    test_dataset = []
    all_dataset = []
    classes_list = os.listdir(datasetDir)
    classes_list.sort()
    num_classes = len(classes_list)
    if not os.path.exists("labels.json"):
        with open("labels.json","w") as f:
            f.write(json.dumps(classes_list))
    print("classes_list: ", classes_list)
    for one_label in os.listdir(datasetDir):
        oneDir = datasetDir + one_label + "/"
        for one_pic in os.listdir(oneDir):
            one_path = oneDir + one_pic
            one_ind = classes_list.index(one_label)
            all_dataset.append([one_ind, one_path])
    train_ratio = 0.90
    train_num = int(train_ratio * len(all_dataset))
    all_inds = list(range(len(all_dataset)))
    train_inds = random.sample(all_inds, train_num)
    test_inds = [one for one in all_inds if one not in train_inds]
    for one_ind in train_inds:
        train_dataset.append(all_dataset[one_ind])
    for one_ind in test_inds:
        test_dataset.append(all_dataset[one_ind])

简单看下实例数据:

【艾叶】

【陈皮】

【党参】

【何首乌】

基础的vit实现如下:

import tensorflow as tf
from tensorflow.keras import layers

def create_vision_transformer(input_shape, num_classes, num_layers, d_model, num_heads, mlp_dim, dropout_rate):
    inputs = tf.keras.Input(shape=input_shape)
    x = layers.Conv2D(filters=d_model, kernel_size=1)(inputs)
    x = layers.Reshape((-1, d_model))(x)
    x = layers.LayerNormalization(epsilon=1e-6)(x)

    # Patch embeddings
    num_patches = x.shape[1]
    patch_size = x.shape[2]
    x = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)(x, x)
    x = layers.Add()([x, inputs])
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = layers.Conv1D(filters=d_model, kernel_size=1)(x)
    x = layers.LayerNormalization(epsilon=1e-6)(x)

    # Transformer Encoder layers
    for _ in range(num_layers):
        # Attention and MLP block
        y = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)(x, x)
        y = layers.Add()([y, x])
        y = layers.LayerNormalization(epsilon=1e-6)(y)
        y = layers.Conv1D(filters=mlp_dim, kernel_size=1, activation="relu")(y)
        y = layers.Conv1D(filters=d_model, kernel_size=1)(y)
        y = layers.Add()([y, x])
        x = layers.LayerNormalization(epsilon=1e-6)(y)

    # Classification head
    x = layers.GlobalAveragePooling1D()(x)
    x = layers.Dropout(rate=dropout_rate)(x)
    x = layers.Dense(units=num_classes, activation="softmax")(x)

    model = tf.keras.Model(inputs=inputs, outputs=x)
    return model

# Example usage
input_shape = (224, 224, 3)
num_classes = 1000
num_layers = 12
d_model = 512
num_heads = 8
mlp_dim = 2048
dropout_rate = 0.1

model = create_vision_transformer(input_shape, num_classes, num_layers, d_model, num_heads, mlp_dim, dropout_rate)
model.summary()

默认100次epoch的迭代计算,等待训练完成后对整体训练过程进行可视化,如下所示:
【准确率曲线】

【loss曲线】

可视化推理实例如下所示:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/240241.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Day58力扣打卡

打卡记录 下一个更大元素 IV&#xff08;单调栈 x2&#xff09; 链接 class Solution:def secondGreaterElement(self, nums: List[int]) -> List[int]:ans [-1] * len(nums)s []t []for i, x in enumerate(nums):while t and nums[t[-1]] < x:ans[t.pop()] x # t…

CSS的三大特性(层叠性、继承性、优先级---------很重要)

CSS 有三个非常重要的三个特性&#xff1a;层叠性、继承性、优先级。 层叠性 场景&#xff1a;相同选择器给设置相同的样式&#xff0c;此时一个样式就会覆盖&#xff08;层叠&#xff09;另一个冲突的样式。层叠性主要解决样式冲突 的问题 原则&#xff1a;  样式冲突&am…

nodejs微信小程序+python+PHP的外卖数据分析-计算机毕业设计推荐django

目 录 摘 要 I ABSTRACT II 目 录 II 第1章 绪论 1 1.1背景及意义 1 1.2 国内外研究概况 1 1.3 研究的内容 1 第2章 相关技术 3 2.1 nodejs简介 4 2.2 express框架介绍 6 2.4 MySQL数据库 4 第3章 系统分析 5 3.1 需求分析 5 3.2 系统可行性分析 5 3.2.1技术可行性&#xff1a;…

Ubuntu编译文件安装SNMP服务

net-snmp源码下载 http://www.net-snmp.org/download.html 编译步骤 指定参数编译 ./configure --prefix/root/snmpd --with-default-snmp-version"2" --with-logfile"/var/log/snmpd.log" --with-persistent-directory"/var/net-snmp" --wi…

应用程序映射的 5 个安全优势

现代企业依靠无数的软件应用程序来执行日常运营。这些应用程序相互连接并协同工作以提供所需的服务。了解这些应用程序如何相互交互以及底层基础设施对于任何组织都至关重要。这就是应用程序映射概念的用武之地。 顾名思义&#xff0c;应用程序映射是创建应用程序架构&#xf…

[NAND Flash 3.2] 3D NAND 工艺与发展前沿

依公知及经验整理&#xff0c;原创保护&#xff0c;禁止转载。 专栏 《深入理解NAND Flash》 全文 6200 字&#xff0c;​2023.12.12 更新 1. 导论 1.1 何为 3D NAND? 3D NAND, 也叫做 Sumsung V-NAND, 是一种高密度闪存。 以前&#xff0c;把NAND闪存颗粒&#xff0c;直接…

互斥锁的原理

互斥锁&#xff08;Mutex&#xff0c;全称Mutual Exclusion&#xff09;是一种同步机制&#xff0c;用于确保在任意时刻&#xff0c;只有一个线程可以访问共享资源&#xff0c;从而防止数据竞争和不一致性。互斥锁的基本思想是在进入临界区之前&#xff0c;先获取锁&#xff1b…

Python和Beautiful Soup爬虫助力提取文本内容

大家好&#xff0c;网络爬虫是一项非常抢手的技能&#xff0c;收集、分析和清洗数据是数据科学项目中最重要的部分。今天介绍如何从链接中爬取高质量文本内容&#xff0c;我们使用迭代&#xff0c;从大约700个链接中进行网络爬取。如果想直接跳转到代码部分&#xff0c;可以在下…

【JVM从入门到实战】(四)类的生命周期

什么是类的生命周期 类的生命周期描述了一个类加载、连接、初始化、使用、卸载的整个过程 一个类完整的生命周期如下&#xff1a; 加载阶段 加载阶段第一步是类加载器根据类的全限定名通过不同的渠道以二进制流的方式获取字节码信息。 程序员可以使用Java代码拓展的不同的渠道…

计算机视觉 基于视频识别场景了解GluonCV深度学习工具包

一、简述 GluonCV 提供计算机视觉领域最先进 (SOTA) 深度学习算法的实现。它旨在帮助工程师、研究人员和学生快速制作产品原型、验证新想法并学习计算机视觉。 GluonCV: a Deep Learning Toolkit for Computer Vision — gluoncv 0.11.0 documentationhttps://cv.gluon.ai/con…

FFmpeg的AVIOPROBE

文章目录 定义 可能你一直有疑问&#xff0c;ffmpeg的avformat是怎么提前知道码流是编码格式或者容器&#xff1f;恭喜你&#xff0c;看到这里&#xff0c;你找到答案了&#xff0c;在这里&#xff0c;ffmpeg通过这些probe函数来提前获取码流的编码格式。 看到下面的avs2_prob…

单变量线性回归的机器学习代码

本文为学习吴恩达版本机器学习教程的代码整理&#xff0c;使用的数据集为https://github.com/fengdu78/Coursera-ML-AndrewNg-Notes/blob/f2757f85b99a2b800f4c2e3e9ea967d9e17dfbd8/code/ex1-linear%20regression/ex1data1.txt 将数据集和py代码放到同一目录中&#xff0c;使…

多阶段构建:精妙优化Docker镜像大小和性能

在容器化应用的世界中&#xff0c;Docker镜像大小和性能优化是至关重要的。多阶段构建是一项强大的技术&#xff0c;通过精心设计Dockerfile&#xff0c;可以在构建镜像时去除不必要的组件&#xff0c;从而显著减小镜像大小&#xff0c;提高性能。本文章将深入讨论多阶段构建的…

Knowledge Graph知识图谱—9. Data Quality and Linking

9. Data Quality and Linking 9.1 How well are the linked open data in practice? Linked Open Vocabularies(LOV) project – analyze usage of vocabularies 9.2 Quality Linked Data Conformance vs. Quality Conformance: – i.e., following standards and best prac…

语音验证码可以用在哪些方面?

电商行业 由于行业竞争日渐激烈&#xff0c;大多数电商采用补贴用户的营销方式抢占市场&#xff0c;其中&#xff0c;新用户补贴较为常见&#xff0c;随之也衍生出一部分恶意刷单的人群。 而语音验证码在一定程度上保证了一个号码对应一个账号&#xff0c;大大增强了刷单难度…

分页存储管理

页框和页面 将内存空间分为一个个大小相等的分区 (比如:每个分区4KB)&#xff0c;每个分区就是一个“页框”(页框页内存块物理块物理页面)。每个页框有一个编号&#xff0c;即“页框号”(页框号页帧号内存块号物理块号物理页号)&#xff0c;页框号从0开始。 为了将各个进程的数…

JS基础源码之手写模拟new

JS基础源码之手写模拟new 手写模拟new初步实现最终实现 手写模拟new new 运算符创建一个用户定义的对象类型的实例或具有构造函数的内置对象类型之一。 我们先看看new实现了哪些功能&#xff1a; function Person (name,age){this.name name;this.age age;this.habit Games;…

2023全国职业院校技能大赛信息安全管理与评估正式赛(模块三CTF)

全国职业院校技能大赛高等职业教育组信息安全管理与评估 \任务书\ 模块三 网络安全渗透、理论技能与职业素养 极安云科专注技能竞赛&#xff0c;包含网络建设与运维和信息安全管理与评估两大赛项&#xff0c;及各大CTF&#xff0c;基于两大赛项提供全面的系统性培训&#xf…

目标检测锚框

目标检测锚框 最开始呢&#xff0c;我们需要先介绍一下框&#xff0c;先学会一下怎么画框 导入所需要的包 from PIL import Image import d2lzh_pytorch as d2l import numpy as np import math import torch展示一下本次实验我们用到的图像&#xff0c;猫狗 d2l.set_figsiz…

智能优化算法应用:基于帝国主义竞争算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于帝国主义竞争算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于帝国主义竞争算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.帝国主义竞争算法4.实验参数设定…