Week-T11-优化器对比试验

文章目录

  • 一、准备环境
  • 二、准备数据
  • 三、搭建训练网络
  • 三、训练模型
    • (1)VSCode训练情况:
    • (2)`jupyter notebook`训练情况:
  • 四、模型评估 & 模型预测
    • 1、绘制Accuracy-Loss图
    • 2、显示model2的预测效果
  • 五、总结
    • 1、`plt.savefig("./数据展示.jpg")`保存的图片在文件夹内打开是空白的,如下图所示:
    • 2. 优化器是什么?包括哪些?

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

本文主要探究不同优化器、以及不同参数配置对模型的影响,最终对Adam、SGD优化器进行比较,并绘制比较结果。

使用的数据集为咖啡豆数据集,共有四类。

优化器常用的有Adam、SGD。优化器的归纳将放在文末的总结部分。

本文将使用Adam优化器的模型命名为"model1",使用SGD优化器的模型命名为"model2",然后根据模型训练结果绘制各自的Accuracy-Loss图。比较得出,在运行环境、epoch次数相同、模型结构相同等条件下,Adam优化器的整体情况要优于SGD优化器。

一、准备环境

# 1. 设置环境
import sys
import tensorflow as tf
from datetime import datetime

from tensorflow          import keras
import matplotlib.pyplot as plt
import pandas            as pd
import numpy             as np
import warnings,os,PIL,pathlib

print("---------------------1.配置环境------------------")
print("Start time: ", datetime.today())
print("tensorflow version: " + tf.__version__)
print("Python version: " + sys.version)

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")
        # 打印显卡信息,确认GPU可用
    print("GPU: " + gpus)
else:
    print("Using CPU")

warnings.filterwarnings("ignore")             #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False    # 用来正常显示负号

在这里插入图片描述

Q1: VSCode虚拟环境安装pandas
在这里插入图片描述

二、准备数据

# 2.导入数据
# 本次使用咖啡豆数据集(共4类)
print("---------------------2.1 从本地读取数据------------------")
data_dir    = "D:/jupyter notebook/DL-100-days/datasets/coffebeans-data"
data_dir    = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:",image_count)

batch_size = 16
img_height = 336
img_width  = 336

"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
print("---------------------2.2 划分训练数据------------------")
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)

"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
print("---------------------2.3 划分验证数据------------------")
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)

print("---------------------2.4 打印数据类别 && 数据的shape------------------")
class_names = train_ds.class_names
print(class_names)

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

print("---------------------2.5 配置数据集------------------")
AUTOTUNE = tf.data.AUTOTUNE

def train_preprocessing(image,label):
    return (image/255.0,label)

train_ds = (
    train_ds.cache()
    .shuffle(1000)
    .map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)           # 在image_dataset_from_directory处已经设置了batch_size
    .prefetch(buffer_size=AUTOTUNE)
)

val_ds = (
    val_ds.cache()
    .shuffle(1000)
    .map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)         # 在image_dataset_from_directory处已经设置了batch_size
    .prefetch(buffer_size=AUTOTUNE)
)

print("---------------------2.6 数据可视化,显示部分样本图片------------------")
plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")

for images, labels in train_ds.take(1):
    for i in range(15):
        plt.subplot(4, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)

        # 显示图片
        plt.imshow(images[i])
        # 显示标签
        plt.xlabel(class_names[labels[i]-1])

plt.show()
plt.savefig("./数据展示.jpg")

在这里插入图片描述
在这里插入图片描述

Q2:plt.savefig("./数据展示.jpg")保存的图片在文件夹内打开是空白的

三、搭建训练网络

print("---------------------3. 搭建训练网络,此处预训练模型调用VGG-16官方模型------------------")
# 自定义一个创建模型的函数,形参是优化器类型,预训练模型是VGG-16,但屏蔽了自带的训练部分以及顶层,然后对输出进行处理
# 在此处创建了两个网络,拥有不同的优化器类型
from tensorflow.keras.layers import Dropout,Dense,BatchNormalization
from tensorflow.keras.models import Model

def create_model(optimizer='adam'):
    # 加载预训练模型
    vgg16_base_model = tf.keras.applications.vgg16.VGG16(weights='imagenet',
                                                                include_top=False,
                                                                input_shape=(img_width, img_height, 3),
                                                                pooling='avg')
    for layer in vgg16_base_model.layers:
        layer.trainable = False

    X = vgg16_base_model.output
    
    X = Dense(170, activation='relu')(X)
    X = BatchNormalization()(X)
    X = Dropout(0.5)(X)

    output = Dense(len(class_names), activation='softmax')(X)
    vgg16_model = Model(inputs=vgg16_base_model.input, outputs=output)

    vgg16_model.compile(optimizer=optimizer,
                        loss='sparse_categorical_crossentropy',
                        metrics=['accuracy'])
    return vgg16_model

model1 = create_model(optimizer=tf.keras.optimizers.Adam())
model2 = create_model(optimizer=tf.keras.optimizers.SGD())
model2.summary()

在这里插入图片描述

三、训练模型

print("---------------------4.启动训练,epoch==50------------------")
# try:加入早停试一下,一个epoch跑完要220s,时间还是有点久
NO_EPOCHS = 50

history_model1  = model1.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)
history_model2  = model2.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)

(1)VSCode训练情况:

model1.fit():Adam优化器
在这里插入图片描述
model2.fit():SGD优化器
在这里插入图片描述

(2)jupyter notebook训练情况:

model1.fit():即Adam优化器
在这里插入图片描述
model2.fit():即SGD优化器
在这里插入图片描述

四、模型评估 & 模型预测

1、绘制Accuracy-Loss图

print("---------------------5.1 模型评估,绘制Accuracy-Loss图------------------")
from matplotlib.ticker import MultipleLocator
plt.rcParams['savefig.dpi'] = 300 #图片像素
plt.rcParams['figure.dpi']  = 300 #分辨率

acc1     = history_model1.history['accuracy']
acc2     = history_model2.history['accuracy']
val_acc1 = history_model1.history['val_accuracy']
val_acc2 = history_model2.history['val_accuracy']

loss1     = history_model1.history['loss']
loss2     = history_model2.history['loss']
val_loss1 = history_model1.history['val_loss']
val_loss2 = history_model2.history['val_loss']

epochs_range = range(len(acc1))

plt.figure(figsize=(16, 4))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, acc1, label='Training Accuracy-Adam')
plt.plot(epochs_range, acc2, label='Training Accuracy-SGD')
plt.plot(epochs_range, val_acc1, label='Validation Accuracy-Adam')
plt.plot(epochs_range, val_acc2, label='Validation Accuracy-SGD')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss1, label='Training Loss-Adam')
plt.plot(epochs_range, loss2, label='Training Loss-SGD')
plt.plot(epochs_range, val_loss1, label='Validation Loss-Adam')
plt.plot(epochs_range, val_loss2, label='Validation Loss-SGD')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
   
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))
plt.savefig("./Accuracy-Loss图.jpg")
plt.show()

plt.show()显示的图片:
请添加图片描述
比较Accuracy图表,可以看出训练时Adam优化器的表现要稍优于SGD优化器,而验证时则相反。

Q: VSCode绘制出来的图咋这么奇怪?
改变plt.savefig("./Accuracy-Loss图.jpg")的位置后所保存的图片,比直接plt.show()的图片比例要好些。
在这里插入图片描述

2、显示model2的预测效果

print("---------------------5.2 模型预测------------------")
def test_accuracy_report(model):
    score = model.evaluate(val_ds, verbose=0)
    print('Loss function: %s, accuracy:' % score[0], score[1])
    
test_accuracy_report(model2)

VSCode环境下的预测结果:
在这里插入图片描述
jupyter notebook环境下的预测结果:
在这里插入图片描述

五、总结

1、plt.savefig("./数据展示.jpg")保存的图片在文件夹内打开是空白的,如下图所示:

在这里插入图片描述
将保存的语句放在plt.show()之前,因为plt.show()之后会默认打开一个空白画板。

2. 优化器是什么?包括哪些?

(参考文章也是来自训练营文章)

优化器是什么?

  • 优化器是一种算法,它在模型优化过程中,动态地调整梯度的大小和方向,使模型能够收敛到更好的位置,或者用更快的速度进行收敛。
  • 各类优化器方法总结如下:
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/197395.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C++ 泛型编程,函数模版和类模版

1.泛型编程 泛型编程:编写与类型无关的通用代码,是代码复用的一种手段。模板是泛型编程的基础 就比如说活字印刷术,就是提供一个模具,然后根据模具来印刷出不同的字。 泛型编程跟着类似,提供一个模版,根据这…

五、Microsoft群集服务(MSCS)环境的搭建

一、【目的】 学会利用Windows Server布置群集环境。 二、【设备】 FreeNAS11.2,Windows Server 2019 三、【要求】 学会利用Windows Server布置群集环境,掌握处理问题的能力。 配置表: 节点公网IP(public)内网IP(private)群集IP(clust…

欧拉LINUX 23.09版本上安装ORACLE 19c

前面解决了在RHEL9上安装ORACLE 19C的问题后,发现龙蜥 LINUX23 上可以安装ORACLE19C,网上搜了一下,欧拉 linux 22.03 上,没有成功安装ORACLE 19c 的先例,23.09就更不用说了,但看到的错误,不外服都是缺 libp…

玻色量子真机测试完整报告

​ 真机测试 2023年 2023.8 量子计算突破云渲染资源调度!真机测试完整报告公开! 2023.8 量子计算突破金融信用评分!真机测试完整报告公开! 组合优化问题专题 2023年 2023.7 玻色量子“揭秘”之旅行商问题与Ising建模 2023.…

OPENWRT解决配置pppoe后无法光猫路由管理界面

一、新建一个wan口 二、设置流量转发 设置完成后保存应用即可

pdf加密文件解密(pdf文件解密小工具)

工具放在文章末尾! 1.pdf文件加密后会有很多使用权限的限制很不方便,只要是为了pdf的数据不被二次利用,未加密的pdf功能都是可以正常使用的 2.加密后的pdf使用权限会被限制部分 3.工具只能解决pdf编辑等加密情况,不能解决文件打…

耗时一个星期整理的APP自动化测试工具大全

在本篇文章中,将给大家推荐14款日常工作中经常用到的测试开发工具神器,涵盖了自动化测试、APP性能测试、稳定性测试、抓包工具等。 一、UI自动化测试工具 1. uiautomator2 openatx开源的ui自动化工具,支持Android和iOS。主要面向的编程语言…

神器!使用 patchworklib 库进行多图排版真棒啊

如果想把多个图合并放在一个图里,如图,该如何实现 好在R语言 和 Python 都有对应的解决方案, 分别是patchwork包和patchworklib库。 推介1 我们打造了《100个超强算法模型》,特点:从0到1轻松学习,原理、…

与 PCIe 相比,CXL为何低延迟高带宽?

文章目录 前言1. LatencyPCIE 生产者消费则模型结论Flit 包PCIE/CXL.ioCXL.cace & .mem总结 2. BandWidth常见开销CXL.IO Link efficiencyPCIe Link efficiencyCXL.IO bandwidthCXL.mem/.cache bandwidth 参考 前言 CXL 规范里没有具体描述与PCIe 相比低延时高带宽的原因&…

只会在终端使用Python运行代码?这些高级用法了解了解

大部分同学在终端使用Python可能只是简单的执行代码,但其实结合一些Python内置模块或第三方库可以实现更高级且便捷的用法,一起看看吧 插播,更多文字总结指南实用工具科技前沿动态第一时间更新在公粽号【啥都会一点的研究生】 代码Benchmar…

Linux下基于MPI的hello程序设计

Linux下基于MPI的hello程序设计 一、MPICH并行计算库安装实验环境部署创建SSH信任连接,实现免密钥互相连接node1安装MPICH 3.4配置NFS注意(一定要先看)环境测试 二、HELLO WORLD并行程序设计 一、MPICH并行计算库安装 在Linux环境下安装MPICH执行环境,配…

详细介绍如何使用深度学习自动车牌(ALPR)识别-含(数据集+源码下载)

深度学习一直是现代世界发展最快的技术之一。深度学习已经成为我们日常生活的一部分,从语音助手到自动驾驶汽车,它无处不在。其中一种应用程序是自动车牌识别 (ALPR)。顾名思义,ALPR是一项利用人工智能和深度学习的力量自动检测和识别车辆车牌字符的技术。这篇博文将重点讨论…

11.28 知识回顾(Web框架、路由控制、视图层)

一、 web 框架 1.1 web框架是什么? 别人帮咱们写了一些基础代码------》我们只需要在固定的位置写固定的代码--》就能实现一个web应用 Web框架(Web framework)是一种开发框架,用来支持动态网站、网络应用和网络服务的开发。这大多…

Windows环境下的JDK安装与环境配置

一、JDK下载 1、打开Oracle官方网站下载页 Java Downloads | Oracle 中国 2、选择Java archive页,在版本列表中选择需要下载的版本 3、选择系统环境对应的版本,点击对应的下载按钮,弹出技术许可勾选框 4、勾选Oracle技术许可协议 5、输入Or…

python服装电商系统vue购物商城django-pycharm毕业设计项目推荐

系统面向的使用群体为商家和消费者,商家和消费者所承担的功能各不相同,所对象的权限也各不相同。对于消费者和商家设计的功能如下: 对于消费者设计了五大功能模块: (1) 商品信息:用户可在商品…

【shell】文本三剑客之sed详解

目录 一、sed简介(行编辑器) 二、基本用法 三、sed脚本格式(匹配地址 脚本命令) 1、不给地址,那么就是针对全文处理 2、单地址,表示#,指定的行,$表示最后一行,/pattt…

bat脚本执行py文件

天行健,君子以自强不息;地势坤,君子以厚德载物。 每个人都有惰性,但不断学习是好好生活的根本,共勉! 文章均为学习整理笔记,分享记录为主,如有错误请指正,共同学习进步。…

IntelliJ IDEA 中有什么让你相见恨晚的技巧

一、条件断点 循环中经常用到这个技巧,比如:遍历1个大List的过程中,想让断点停在某个特定值。 参考上图,在断点的位置,右击断点旁边的小红点,会出来一个界面,在Condition这里填入断点条件即可&…

当TinyMCE富文本编辑器遇到Vue3+nuxt+ts项目,分享引入成功案例及过程中踩的那些坑

文章目录 前言遇到的坑插入上传图片插件上传图片请求与返回值处理本地文件引入报错解决源码 前言 如果你的前端项目技术栈使用的是Vue3nuxtts,并且老大让你集成一下那个传说中非常丝滑的TinyMCE富文本编辑器,那么恭喜你和我一样中大奖了。 网上找了好久…

数据结构 / day01 作业

1.定义结构体数组存储5个学生的信息:姓名,年龄,性别 定义函数实现输入,要求形参使用结构体指针接收 函数实现5个学生年龄排序(注意对年龄排序时,交换的是所有信息) 定义函数实现输出,要求形参使用结构体…