基于tensorflow的咖啡豆识别

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

一、前期工作

1. 设置GPU

import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")
    print("GPU is available")

2. 导入数据

from tensorflow       import keras
from tensorflow.keras import layers,models
import numpy             as np
import matplotlib.pyplot as plt
import os,PIL,pathlib

data_dir = "F:/host/Data/咖啡豆识别数据/"
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*.png')))

print("图片总数为:",image_count)

在这里插入图片描述

二、数据预处理

1. 加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset

batch_size = 8
img_height = 224
img_width = 224
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.1,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
class_names = train_ds.class_names
print(class_names)

在这里插入图片描述

2. 可视化数据

plt.figure(figsize=(10, 4))  # 图形的宽为10高为5

for images, labels in train_ds.take(1):
    for i in range(8):
        
        ax = plt.subplot(2, 4, i + 1)  

        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        
        plt.axis("off")

在这里插入图片描述

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

在这里插入图片描述

3. 配置数据集

  • shuffle() :打乱数据,关于此函数的详细介绍可以参考:https://zhuanlan.zhihu.com/p/42417456
  • prefetch() :预取数据,加速运行,其详细介绍可以参考我前两篇文章,里面都有讲解。
  • cache() :将数据集缓存到内存当中,加速运行
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)

train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
val_ds   = val_ds.map(lambda x, y: (normalization_layer(x), y)) 
image_batch, labels_batch = next(iter(val_ds))
first_image = image_batch[0]

# 查看归一化后的数据
print(np.min(first_image), np.max(first_image))

三、构建VGG-16网络

from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout

def VGG16(nb_classes, input_shape):
    # 输入层
    input_tensor = Input(shape=input_shape)
    # 卷积层1
    x = Conv2D(64, (3,3), activation='relu', padding='same',name='block1_conv1')(input_tensor)
    x = Conv2D(64, (3,3), activation='relu', padding='same',name='block1_conv2')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block1_pool')(x)
    # 卷积层2
    x = Conv2D(128, (3,3), activation='relu', padding='same',name='block2_conv1')(x)
    x = Conv2D(128, (3,3), activation='relu', padding='same',name='block2_conv2')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block2_pool')(x)
    # 卷积层3
    x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv1')(x)
    x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv2')(x)
    x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv3')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block3_pool')(x)
    # 卷积层4
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv1')(x)
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv2')(x)
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv3')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block4_pool')(x)
    # 卷积层5
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv1')(x)
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv2')(x)
    x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv3')(x)
    x = MaxPooling2D((2,2), strides=(2,2), name = 'block5_pool')(x)
    # 展平层
    x = Flatten()(x)
    # 全连接层1
    x = Dense(4096, activation='relu',name='fc1')(x)
    # 全连接层2
    x = Dense(4096, activation='relu',name='fc2')(x)
    # 输出层
    output_tensor = Dense(nb_classes, activation='softmax',name='predictions')(x)
    # 创建模型
    model = Model(input_tensor, output_tensor)
    return model

# 创建模型
model=VGG16(len(class_names), (img_width, img_height, 3))

# 打印模型结构
model.summary()

在这里插入图片描述

3. 网络结构图

关于卷积的相关知识可以参考文章:https://mtyjkh.blog.csdn.net/article/details/114278995

结构说明:

  • 13个卷积层(Convolutional Layer),分别用blockX_convX表示
  • 3个全连接层(Fully connected Layer),分别用fcX与predictions表示
  • 5个池化层(Pool layer),分别用blockX_pool表示

VGG-16包含了16个隐藏层(13个卷积层和3个全连接层),故称为VGG-16

在这里插入图片描述

四、编译

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数(loss):用于衡量模型在训练期间的准确率。
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
# 设置初始学习率
initial_learning_rate = 1e-4

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate, 
        decay_steps=30,      # 敲黑板!!!这里是指 steps,不是指epochs
        decay_rate=0.92,     # lr经过一次衰减就会变成 decay_rate*lr
        staircase=True)

# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=initial_learning_rate)

model.compile(optimizer=opt,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

五、训练模型

epochs = 20

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)

在这里插入图片描述

六、可视化结果

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

七、个人小结

在本次咖啡豆识别项目中,我们通过设置GPU、导入并预处理数据、构建深度学习模型,以及对模型进行训练和评估,实现了对咖啡豆图像的自动识别。整个过程涵盖了数据加载与可视化、数据集配置、模型构建与优化等关键步骤,最终显著提升了图像分类的准确性,同时也加深了我们对深度学习技术的实践理解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/657603.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Redhat9 LAMP安全配置方案及测试

目录 数据库主机 安装Mariadb数据库服务 设置mariadb开机自动启动 Php主机 部署Apache服务器 设置apache服务开机自启 安装php 安装 phpMyAdmin 打开测试机 更新软件包列表: 首先,确保你的软件包列表是最新的。打开终端并输入以下命令&#xf…

电脑如何远程访问?

【天联】的使用场景 电脑远程访问在现代科技的发展中扮演了重要的角色。对于企业和个人用户来说,远程访问的便利性提供了许多机会和可能性。作为一种高效的工具,【天联】具有广泛的应用场景,可以实现异地统一管理、协同办公以及远程数据采集…

咖啡看书休闲时光404错误页面源码

源码介绍 咖啡看书休闲时光404错误页面源码,源码由HTMLCSSJS组成,记事本打开源码文件可以进行内容文字之类的修改,双击html文件可以本地运行效果,也可以上传到服务器里面,重定向这个界面 源码效果 源码下载 咖啡看书…

WEB安全:Content Security Policy (CSP) 详解

Content Security Policy (CSP) 是一种强大的网页安全机制,用于防止跨站脚本 (XSS) 和其他注入攻击。通过设置一系列的内容安全策略,CSP 可以限制网页可以加载的资源,从而保护用户数据和网站的安全性。 什么是 XSS 攻击? 跨站脚本攻击 (XSS) 是一种常见的安全漏洞,攻击者…

2024年JAVA、C++、Pyhton学哪种语言更容易进国央企?

对于不同编程语言在进入国有企业的观点大体是正确的,不过在实际选择时还需考虑一些因素。我这里有一套编程入门教程,不仅包含了详细的视频讲解,项目实战。如果你渴望学习编程,不妨点个关注,给个评论222,私信…

“2024南京智博会”共同探索智能科技产业创新发展新路径

随着全球数字化浪潮的深入推进,智慧城市、物联网与大数据等领域的发展成为推动经济社会发展的重要力量。在这样的背景下,2024南京国际智慧城市、物联网、大数据博览会(南京智博会)的举办,无疑为国内外企业提供了一个绝…

如何成为一名合格的JAVA程序员?

如何成为一名称职的Java编程人员?你一定不能错过的两本书。 第一本《Java核心技术速学版(第3版)》! 1.经典Java作品《Java核心技术》的速学版本,降低学习门槛,帮助读者更容易学习Java,更快地把…

基于ssm的微信小程序的居民健康监测系统

采用技术 基于ssm的微信小程序的居民健康监测系统的设计与实现~ 开发语言:Java 数据库:MySQL 技术:SpringMVCMyBatis 工具:IDEA/Ecilpse、Navicat、Maven 页面展示效果 后端页面 用户信息管理 健康科普管理 公告管理 论坛…

C++线程任务队列模型

功能描述 实现一个任务队列,用于任务的执行 任务队列 任务队列可以添加、删除任务,实现对任务的管理添加任务后,任务队列可以开始执行任务队列执行任务方式为串行执行 任务 任务执行需要持续一段10s内随机的时间,执行过程通过…

每天五分钟深度学习:如何使用计算图来反向计算参数的导数?

本文重点 在上一个课程中,我们使用一个例子来计算函数J,也就相当于前向传播的过程,本节课程我们将学习如何使用计算图计算函数J的导数。相当于反向传播的过程。 计算J对v的导数,dJ/dv3 计算J对a的导数,dJ/da&#xf…

JVM学习-字节码指令集(一)

概述 Java字节码对于虚拟机,好像汇编语言对于计算机,属于基本执行指令Java虚拟机的指令由一个字节长度的,代表某种特定操作含义 的数字(称为操作码Opcode)以及跟随其后的零至多个代表此操作所需参数(操作数,Operands)而构成&…

如何实时掌握手机号状态的API利器分析

在移动互联网的时代,手机号码不仅是通信的连接点,也是用户身份的关键识别。手机状态查询API 通过提供实时的手机号码状态查询服务,协助企业和组织更有效地管理用户信息,提升服务流程。 手机状态查询API 通过与电信运营商的数据库进…

UE5 使用外置摄像头进行拍照并保存到本地

连接外置摄像头功能:https://docs.unrealengine.com/4.27/zh-CN/WorkingWithMedia/IntegratingMedia/MediaFramework/HowTo/UsingWebCams/ 核心功能:UE4 相机拍照功能(图片保存)_ue 移动端保存图片-CSDN博客 思路是: …

《python编程从入门到实践》day41

# 昨日知识点回顾 用户注销、注册,限制访问,新主题关联到当前用户 # 今日知识点学习 第20章 设置应用程序的样式并部署 20.1 设置项目“学习笔记”的样式 20.1.1 应用程序django-bootstrap4 # settings.py ---snip--- INSTALLED_APPS [# 我的应用程序…

免费,Python蓝桥杯等级考试真题--第14级(含答案解析和代码)

Python蓝桥杯等级考试真题–第14级 一、 选择题 答案:B 解析:键为‘B’对应的值为602,故答案为B。 答案:A 解析:字典的符合为花括号,先键后值,故答案为A。 答案:C 解析&#xff1a…

磁盘管理后续——盘符漂移问题解决

之前格式化磁盘安装了文件系统,且对磁盘做了相应的挂载,但是服务器重启后挂载信息可能有问题,或者出现盘符漂移、盘符变化、盘符错乱等故障,具体是dev/sda, sdb, sdc 等等在某些情况下会混乱掉 比如sda变成了sdb或者sdc变成了sdb等…

100个 Unity小游戏系列七 -Unity 抽奖游戏专题五 刮刮乐游戏

一、演示效果 二、知识点讲解 2.1 布局 void CreateItems(){var rewardLists LuckyManager.Instance.CalculateRewardId(rewardDatas, Random.Range(4, 5));reward_data_list reward_data_list ?? new List<RewardData>();reward_data_list.Clear();for (int i 0; …

ADS基础教程17 - 创建含参子图

设计加密保护IP 一、引言二、参数设计 一、引言 将一个子图内部元器件的参数设置成可以在外部进行修改的参数&#xff0c;能够使得封装的子图更加灵活和通用。 二、参数设计 (1)打开一个子图&#xff0c;在菜单栏中选择File–>Design Parameters… (2)弹出的对话框中&am…

国产PS插件新选择;StartAI平替中的佼佼者!

前言 在设计的世界里&#xff0c;每一个细节都至关重要。设计师们常常面临时间紧迫、创意受限、工具复杂等挑战。Photoshop虽强大&#xff0c;但繁琐的操作和高昂的成本往往令人望而却步。今天我就为大家介绍一款PSAI插件——StartAI&#xff0c;一款专为Photoshop设计的国产A…

Django配置

后端开发&#xff1a; python 解释器、 pycharm 社区版、 navicate 、 mysql(phpstudy) 前段开发&#xff1a; vs code 、 google 浏览器 django 项目配置 配置项目启动方式 创建模型 创建一个应用 在应用中创建模型类 根据模型类生成数据表 创建应用 创建模型类 …