深度学习笔记11-优化器对比实验(Tensorflow)

  • 🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

目录

一、导入数据并检查

二、配置数据集

三、数据可视化

四、构建模型

五、训练模型

六、模型对比评估

七、总结


一、导入数据并检查

import pathlib,PIL
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签

data_dir    = pathlib.Path("./T6")
image_count = len(list(data_dir.glob('*/*')))
batch_size = 16
img_height = 336
img_width  = 336
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)

class_names = train_ds.class_names
print(class_names)

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

二、配置数据集

AUTOTUNE = tf.data.AUTOTUNE
#归一化处理
def train_preprocessing(image,label):
    return (image/255.0,label)

train_ds = (
    train_ds.cache()
    .shuffle(1000)
    .map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)           # 在image_dataset_from_directory处已经设置了batch_size
    .prefetch(buffer_size=AUTOTUNE)
)

val_ds = (
    val_ds.cache()
    .shuffle(1000)
    .map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)         # 在image_dataset_from_directory处已经设置了batch_size
    .prefetch(buffer_size=AUTOTUNE)
)

三、数据可视化

plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")

for images, labels in train_ds.take(1):
    for i in range(15):
        plt.subplot(4, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        # 显示图片
        plt.imshow(images[i])
        # 显示标签
        plt.xlabel(class_names[labels[i]-1])

plt.show()

四、构建模型

from tensorflow.keras.layers import Dropout,Dense,BatchNormalization
from tensorflow.keras.models import Model

def create_model(optimizer='adam'):
    # 加载预训练模型
    vgg16_base_model = tf.keras.applications.vgg16.VGG16(weights='imagenet',
                                                                include_top=False,#不包含顶层的全连接层
                                                                input_shape=(img_width, img_height, 3),
                                                                pooling='avg')#平均池化层替代顶层的全连接层
    for layer in vgg16_base_model.layers:
        layer.trainable = False  #将 trainable属性设置为 False 意味着在训练过程中,这些层的权重不会更新

    X = vgg16_base_model.output
    
    X = Dense(170, activation='relu')(X)
    X = BatchNormalization()(X)
    X = Dropout(0.5)(X)

    output = Dense(len(class_names), activation='softmax')(X)#神经元数量等于类别数
    vgg16_model = Model(inputs=vgg16_base_model.input, outputs=output)

    vgg16_model.compile(optimizer=optimizer,
                        loss='sparse_categorical_crossentropy',
                        metrics=['accuracy'])
    return vgg16_model

model1 = create_model(optimizer=tf.keras.optimizers.Adam())
model2 = create_model(optimizer=tf.keras.optimizers.SGD())#随机梯度下降(SGD)优化器的
model2.summary()

五、训练模型

NO_EPOCHS = 20

history_model1  = model1.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)
history_model2  = model2.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)

六、模型对比评估

from matplotlib.ticker import MultipleLocator
plt.rcParams['savefig.dpi'] = 300 #图片像素
plt.rcParams['figure.dpi']  = 300 #分辨率

acc1     = history_model1.history['accuracy']
acc2     = history_model2.history['accuracy']
val_acc1 = history_model1.history['val_accuracy']
val_acc2 = history_model2.history['val_accuracy']

loss1     = history_model1.history['loss']
loss2     = history_model2.history['loss']
val_loss1 = history_model1.history['val_loss']
val_loss2 = history_model2.history['val_loss']

epochs_range = range(len(acc1))

plt.figure(figsize=(16, 4))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, acc1, label='Training Accuracy-Adam')
plt.plot(epochs_range, acc2, label='Training Accuracy-SGD')
plt.plot(epochs_range, val_acc1, label='Validation Accuracy-Adam')
plt.plot(epochs_range, val_acc2, label='Validation Accuracy-SGD')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss1, label='Training Loss-Adam')
plt.plot(epochs_range, loss2, label='Training Loss-SGD')
plt.plot(epochs_range, val_loss1, label='Validation Loss-Adam')
plt.plot(epochs_range, val_loss2, label='Validation Loss-SGD')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
   
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))

plt.show()

可以看出,在这个实例中,Adam优化器的效果优于SGD优化器

七、总结

      通过本次实验,学会了比较不同优化器(Adam和SGD)在训练过程中的性能表现,可视化训练过程的损失曲线和准确率等指标。这是一项非常重要的技能,在研究论文中,可以通过这些优化方法可以提高工作量。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/952467.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

JavaEE之定时器及自我实现

在生活当中,有很多事情,我们不是立马就去做,而是在规定了时间之后,在到该时间时,再去执行,比如:闹钟、定时关机等等,在程序的世界中,有些代码也不是立刻执行,…

Qt学习笔记第81到90讲

第81讲 串口调试助手实现自动发送 为这个名叫“定时发送”的QCheckBox编写槽函数。 想要做出定时发送的效果,必须引入QT框架下的毫秒级定时器QTimer,查阅手册了解详情。 在widget.h内添加新的私有成员变量: QTimer *timer; 在widget类的构造…

【LeetCode】力扣刷题热题100道(16-20题)附源码 容器 子数组 数组 连续序列 三数之和(C++)

目录 1.盛最多水的容器 2.和为K的子数组 3.最大子数组和 4.最长连续序列 5.三数之和 1.盛最多水的容器 给定一个长度为 n 的整数数组 height 。有 n 条垂线,第 i 条线的两个端点是 (i, 0) 和 (i, height[i]) 。 找出其中的两条线,使得它们与 x 轴…

AI多模态技术介绍:视觉语言模型(VLMs)指南

本文作者:AIGCmagic社区 刘一手 AI多模态全栈学习路线 在本文中,我们将探讨用于开发视觉语言模型(Vision Language Models,以下简称VLMs)的架构、评估策略和主流数据集,以及该领域的关键挑战和未来趋势。通…

jenkins入门13--pipeline

Jenkins-pipeline(1)-基础 为什么要使用pipeline 代码:pipeline 以代码的形式实现,通过被捡入源代码控制, 使团队能够编译,审查和迭代其cd流程 可连续性:jenkins 重启 或者中断后都不会影响pipeline job 停顿&#x…

【线性代数】通俗理解特征向量与特征值

这一块在线性代数中属于重点且较难理解的内容,下面仅个人学习过程中的体会,错误之处欢迎指出,有更简洁易懂的理解方式也欢迎留言学习。 文章目录 概念计算几何直观理解意义PS.适用 概念 矩阵本身就是一个线性变换,对一个空间中的…

SQL多表联查、自定义函数(字符串分割split)、xml格式输出

记录一个报表的统计,大概内容如下: 多表联查涉及的报表有:房间表、买家表、合同表、交易表、费用表、修改记录表 注意:本项目数据库使用的是sqlserver(mssql),非mysql。 难点1:业主信息&#…

python学opencv|读取图像(三十)使用cv2.getAffineTransform()函数倾斜拉伸图像

【1】引言 前序已经学习了如何平移和旋转缩放图像,相关文章链接为: python学opencv|读取图像(二十七)使用cv2.warpAffine()函数平移图像-CSDN博客 python学opencv|读取图像(二十八&#xff0…

C语言数据结构与算法(排序)详细版

大家好,欢迎来到“干货”小仓库!! 很高兴在CSDN这个大家庭与大家相识,希望能在这里与大家共同进步,共同收获更好的自己!!无人扶我青云志,我自踏雪至山巅!!&am…

【竞技宝】CS2:HLTV2024选手排名TOP4-NiKo

北京时间2025年1月11日,HLTV年度选手排名正在持续公布中,今日凌晨正式公布了今年的TOP4选手为G2(目前已转为至Falcons)战队的NiKo。 选手简介 NiKo是一名来自波黑的CS职业选手,现年26岁。作为DOTA2饱负盛名的职业选手,NiKo在CS1.6时代就已经开始征战职业赛场。2012年,年仅15岁…

IOS界面传值-OC

1、页面跳转 由 ViewController 页面跳转至 NextViewController 页面 &#xff08;1&#xff09;ViewController ViewController.h #import <UIKit/UIKit.h>interface ViewController : UIViewControllerend ViewController.m #import "ViewController.h" …

树的模拟实现

一.链式前向星 所谓链式前向星&#xff0c;就是用链表的方式实现树。其中的链表是用数组模拟实现的链表。 首先我们需要创建一个足够大的数组h&#xff0c;作为所有结点的哨兵位。创建两个足够大的数组e和ne&#xff0c;一个作为数据域&#xff0c;一个作为指针域。创建一个变…

【ArcGIS微课1000例】0138:ArcGIS栅格数据每个像元值转为Excel文本进行统计分析、做图表

本文讲述在ArcGIS中,以globeland30数据为例,将栅格数据每个像元值转为Excel文本,便于在Excel中进行统计分析。 文章目录 一、加载globeland30数据二、栅格转点三、像元值提取至点四、Excel打开一、加载globeland30数据 打开配套实验数据包中的0138.rar中的tif格式栅格土地覆…

Redis集群模式下主从复制和哨兵模式

Redis主从复制是由一个Redis服务器或实例(主节点)来控制一个Redis服务器或实例(从节点),从节点从主节点获取数据更新数据 集群模式下主从数据复制过程 从服务器连接到主服务器,发送SYNC命令。主服务器接收到SYNC命令后,开始执行BGSAVE命令生成RDB文件。主服务器BGSAVE执…

高难度下的一闪---白金ACT游戏设计思想的一点体会

1、以前光环的开发者好像提出过一个理论&#xff0c;大意是游戏要让玩家保持30秒的循环&#xff0c; 持续下去。大意跟后来的心流接近。 2、根据我自身的开发体会&#xff0c;想要保持正回路&#xff0c;并不容易。 一个是要保持适当的挑战性&#xff0c;毫无难度的低幼式玩法…

页面滚动下拉时,元素变为fixed浮动,上拉到顶部时恢复原状,js代码以视频示例

页面滚动下拉时,元素变为fixed浮动js代码 以视频示例 <style>video{width:100%;height:auto}.div2,#float1{position:fixed;_position:absolute;top:45px;right:0; z-index:250;}button{float:right;display:block;margin:5px} </style><section id"abou…

算法题(32):三数之和

审题&#xff1a; 需要我们找到满足以下三个条件的所有三元组&#xff0c;并存在二维数组中返回 1.三个元素相加为0 2.三个元素的下标不可相同 3.三元组的元素不可完全相同 思路&#xff1a; 混乱的数据不利于进行操作&#xff0c;所以我们先进行排序 我们可以采取枚举的方法进…

科研绘图系列:R语言绘制Y轴截断分组柱状图(y-axis break bar plot)

禁止商业或二改转载,仅供自学使用,侵权必究,如需截取部分内容请后台联系作者! 文章目录 介绍特点意义加载R包数据下载导入数据数据预处理画图输出总结系统信息介绍 Y轴截断分组柱状图是一种特殊的柱状图,其特点是Y轴的刻度被截断,即在某个范围内省略了部分刻度。这种图表…

PHP民宿酒店预订系统小程序源码

&#x1f3e1;民宿酒店预订系统 基于ThinkPHPuniappuView框架精心构建的多门店民宿酒店预订管理系统&#xff0c;能够迅速为您搭建起专属的、功能全面且操作便捷的民宿酒店预订小程序。 该系统不仅涵盖了预订、退房、WIFI连接、用户反馈、周边信息展示等核心功能&#xff0c;更…

前端 图片上鼠标画矩形框,标注文字,任意删除

效果&#xff1a; 页面描述&#xff1a; 对给定的几张图片&#xff0c;每张能用鼠标在图上画框&#xff0c;标注相关文字&#xff0c;框的颜色和文字内容能自定义改变&#xff0c;能删除任意画过的框。 实现思路&#xff1a; 1、对给定的这几张图片&#xff0c;用分页器绑定…