第P7周:咖啡豆识别(VGG-16复现)

>- **🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/rbOOmire8OocQ90QM78DRA) 中的学习记录博客**
>- **🍖 原作者:[K同学啊 | 接辅导、项目定制](https://mtyjkh.blog.csdn.net/)**

一、前期工作

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import os,PIL,pathlib,warnings

warnings.filterwarnings("ignore")             #忽略警告信息

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device



2. 导入数据

import os,PIL,random,pathlib

data_dir = './7-data/'
data_dir = pathlib.Path(data_dir)

data_paths  = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[1] for path in data_paths]
classeNames

# 关于transforms.Compose的更多介绍可以参考:https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸
    # transforms.RandomHorizontalFlip(), # 随机水平翻转
    transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
    transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])

test_transform = transforms.Compose([
    transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸
    transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
    transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])

total_data = datasets.ImageFolder("./7-data/",transform=train_transforms)
total_data

3. 划分数据集





二、手动搭建VGG-16模型

VGG-16结构说明:

●13个卷积层(Convolutional Layer),分别用blockX_convX表示
●3个全连接层(Fully connected Layer),分别用fcX与predictions表示
●5个池化层(Pool layer),分别用blockX_pool表示

VGG-16包含了16个隐藏层(13个卷积层和3个全连接层),故称为VGG-16

这里,我制作了一个视频来展示VGG-16的传播过程

play

0:00/0:22

倍速

volumeUp

fullscreen

fullscreen

VGG-16网络动画展示

FC7

FE8

FC6

7*7*512

1*1*512

1*1*1000

1*1*512

CONV5

CONV4

14*14*512

CONV3

28*28*512

CONV2

56*56*256

112*112*128

CONVOLUTION+RELU

K同学啊制作

MAX POOLING

CONVI

百度/谷歌/微信搜索:K同学啊

224*224*64

FULLY CONNECTED+RELU

image.png


1. 搭建模型



2. 查看模型详情


三、 训练模型
1. 编写训练函数


2. 编写测试函数

测试函数和训练函数大致相同,但是由于不进行梯度下降对网络权重进行更新,所以不需要传入优化器


3. 正式训练

model.train()、model.eval()训练营往期文章中有详细的介绍。

📌如果将优化器换成 SGD 会发生什么呢?请自行探索接下来发生的诡异事件的原因。



四、 结果可视化
1. Loss与Accuracy图

 

TRAINING AND VALIDATION ACCURACY

TRAINING AND

VALIDATION LOSS

1.4

1.0

TRAINING LOSS

1.2

TEST LOSS

1.0

0.8

0.8

0.6

0.6

0.4

0.4

0.2

TRAINING ACCURACY

TEST ACCURACY

0.0

75

35

15

35

25

20

40

30

40

30

output_29_0.png


2. 指定图片进行预测
 

Python复制代码

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

from PIL import Image

classes = list(total_data.class_to_idx)

def predict_one_image(image_path, model, transform, classes):

test_img = Image.open(image_path).convert('RGB')

plt.imshow(test_img) # 展示预测的图片

test_img = transform(test_img)

img = test_img.to(device).unsqueeze(0)

model.eval()

output = model(img)

_,pred = torch.max(output,1)

pred_class = classes[pred]

print(f'预测结果是:{pred_class}')

Python复制代码

1

2

3

4

5

# 预测训练集中的某张照片

predict_one_image(image_path='./7-data/Dark/dark (1).png',

model=model,

transform=train_transforms,

classes=classes)

Plain Text复制代码

1

预测结果是:Dark

25

50

75

100

125

150

175

200

50

100

150

200

output_32_1.png


3. 模型评估
 

Python复制代码

1

2

best_model.eval()

epoch_test_acc, epoch_test_loss = test(test_dl, best_model, loss_fn)

Python复制代码

1

epoch_test_acc, epoch_test_loss

Plain Text复制代码

1

(0.9916666666666667, 0.035762640996836126)

Python复制代码

1

2

# 查看是否与我们记录的最高准确率一致

epoch_test_acc

Plain Text复制代码

1

0.9916666666666667

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/248112.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【ET8框架入门】1.运行指南

主要学习网址 论坛地址为:https://et-framework.cn Git地址为:GitHub - egametang/ET: Unity3D Client And C# Server Framework 官方QQ群 : 474643097 多线程多进程框架设计-01预告_哔哩哔哩_bilibili 项目检出 检出项目切换到release8.0分支 Git…

网络安全——基于Snort的入侵检测实验

一、实验目的要求: 二、实验设备与环境: 三、实验原理: 四、实验步骤: 五、实验现象、结果记录及整理: 六、分析讨论与思考题解答: 七、实验截图: 一、实验目的要求: 1、掌握…

visdom使用小技巧

visdom常用于深度学习迭代过程的可视化。在代码中,需要在相应位置增加如下代码: import visdomvis visdom.Visdom(envudomain_accuracy) #设置的环境名称vis.line(Xnp.array([epoch]), Ynp.array([domain_accuracy]),windomain_accuracy,updateappend,…

持续集成交付CICD:Jenkins使用基于SaltStack的CD流水线部署前后端应用

目录 一、实验 1.Jenkins使用基于SaltStack的CD流水线部署后端应用 2.Jenkins使用基于SaltStack的CD流水线部署前端应用 一、实验 1.Jenkins使用基于SaltStack的CD流水线部署后端应用 (1)GitLab添加Token (2)Jenkins添加凭据 (3&#xf…

Modbus转Profinet网关配置步骤

Modbus转Profinet网关(XD-MDPN100/200)是用于将Modbus协议和Profinet协议进行转换并进行通迅的设备。Modbus转Profinet网关(XD-MDPN100/200)无论是新项目还是改造项目都可轻松配置完成通迅互联。 正确的安装和配置对于确保设备的正…

重新思考U-Net在医学超声图像分割中的应用(NU-net)

Rethinking the Unpretentious U-net for Medical Ultrasound Image Segmentation 摘要: 乳腺肿瘤分割是帮助我们描述和定位肿瘤区域的关键步骤之一。然而,乳腺肿瘤形态多变、边界模糊、强度分布相似,给乳腺肿瘤的准确分割带来了挑战。近年…

matter模组有无源测试事例

测试一款matter模组的硬件性能 1.1 天线阻抗、电压驻波比测试 主要测试:PCB板载天线设计效率及板材PCB铜面的平整度等 1.2 模组有源数据测试 主要测试:模组的阻抗匹配、频偏等情况 1.3 模组传输能量精度 主要测试:矢量误差等数据 1.4 模…

多家高校设立“鸿蒙班”;ChatGPT 成为 Nature 年度十大人物丨 RTE 开发者日报 Vol.107

开发者朋友们大家好: 这里是 「RTE 开发者日报」 ,每天和大家一起看新闻、聊八卦。我们的社区编辑团队会整理分享 RTE (Real Time Engagement) 领域内「有话题的 新闻 」、「有态度的 观点 」、「有意思的 数据 」、「有思考的 文…

C# WPF上位机开发(利用tcp/ip网络访问plc)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 c# wpf如果是用来开发非标上位机的,那么和plc的通信肯定是少不了的。而且,大部分plc都支持modbus协议,所以这个…

焊盘:十字连接VS全覆盖 铺铜

在铺铜规则中,焊盘连接方式有两种: 十字连接 优点:较好焊接:因铺铜面积减少,温度下降速度降低,较好焊接,不易虚焊。 缺点:载流能力较弱:铺铜面积↓ → 载流能力↓全连接…

IDEA:TomEE 主目录/基目录出错

错误示例: 错误分析:创建Tomcat项目时,错误地创建了TomcatEE项目,重新创建即可。

【Java 集合】ConcurrentLinkedQueue

在日常中, 我们用到的数据结构有很多: 数组, 链表, 树等, 而在这些结构中, 还有一个叫做队列的存在。 和其他的集合相同, Java 原生提供了不同的实现。 而如果我们需要一个线程安全的队列的话, 可以基于实际的场景进行选择, 比如基于数组实现同时操作上会阻塞的 ArrayBlockingQ…

BRAS应用场景及测试方法

什么是BRAS? 宽带接入服务器(Broadband Remote Access Server,简称BRAS)是面向宽带网络应用的新型接入网关,它位于骨干网的边缘层,可以完成用户带宽的IP/ATM网的数据接入。 宽带接入服务器(BRAS&#xff0…

STL中sort的底层实现

文章目录 1、源码分析2、算法优化3、总结 在讲解STL中sort的底层原理之前,先引申出这样几个问题? ①STL中sort的底层是采用哪种或者哪几种排序? ②STL中sort会导致栈溢出吗? ③快速排序的时间复杂度是不稳定的 l o g 2 n log_2n l…

基于Java (spring-boot)的二手物品交易平台

​ 一、项目介绍 1、管理员功能 (1)登录:管理员能够根据账号访问系统。 (2)用户管理:管理员可以添加、删除、修改用户信息,查看用户列表,对用户进行管理和控制。例如&#xff0c…

死锁的预防、避免、检测和消除

一、预防死锁 1. 破坏互斥条件 2. 破坏不剥夺条件 3.破坏请求和保持条件 4.破坏循环等待条件 二、避免死锁 避免死锁的一种方法是使用银行家算法,它涉及到安全序列的概念。银行家算法是一种资源分配和死锁避免的算法,它确保系统能够分配资源而不会导致死…

十五、YARN辅助架构

1、学习内容 (1)了解什么是代理服务器 (2)了解什么是历史服务器 2、辅助架构 (1)辅助架构的由来 对于YARN架构来讲,除了ResourceManager集群资源总管家、NodeManager单机资源管家两个核心角…

『番外篇三』Swift “乱弹”之带索引遍历异步序列(AsyncSequence)

概览 在 Swift 开发中,我们往往在遍历集合元素的同时希望获得元素对应的索引。在本课中,我们将向小伙伴们展示除 enumerated() 方法之外的几种实现思路。在玩转普通集合之后,我们将用“魔法棒”进一步搞定异步序列带索引遍历的实现。 在本篇博主中,您将学到以下内容: 概…

Makefile基础使用与原理

一、基本概念 通常我们编写好代码后,都需要编译,只是这些操作是由IDE来完成,我们只需要点击一个编译按钮。当项目工程越来越庞大,存在几十个甚至更多的文件的时候,你使用的不是IDE工具,而是命令行&#xf…

基于JavaWeb+SSM+Vue微信小程序的移动学习平台系统的设计和实现

基于JavaWebSSMVue微信小程序的移动学习平台系统的设计和实现 源码获取入口Lun文目录前言主要技术系统设计功能截图订阅经典源码专栏Java项目精品实战案例《500套》 源码获取 源码获取入口 Lun文目录 第1章 绪论 1 1.1 课题背景 1 1.2 课题意义 1 1.3 研究内容 2 第2章 开发环…