使用TensorFlow训练深度学习模型实战(下)

大家好,本文接TensorFlow训练深度学习模型的上半部分继续进行讲述,下面将介绍有关定义深度学习模型、训练模型和评估模型的内容。

定义深度学习模型

数据准备完成后,下一步是使用TensorFlow搭建神经网络模型,搭建模型有两个选项:

可以使用各种层,包括Dense、Conv2D和LSTM,从头开始搭建模型。这些层定义了模型的架构及数据流经过它的方式,可基于TensorFlow Hub提供的预训练模型搭建模型。这些模型已经在大型数据集上进行了训练,并可以在特定数据集上进行微调,以达到在较短的训练时间内达到较高的准确度。

可以根据TensorFlow Hub中的预训练模型来建立模型。这些模型已经在大型数据集上进行了训练,并且可以在你的特定数据集上进行微调,以达到较少的训练时间,达到较高的准确性。

  • 从头开始定义深度学习模型

TensorFlow中的tf.keras.Sequential函数允许我们逐层定义神经网络模型,我们可以选择各种层,如Dense、Conv2D和LSTM,来搭建定制的模型架构。以下是示例: 

# 定义模型架构
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10)
])

在这个示例中,我们定义了一个模型,包含以下六个层(4个隐藏层):

  1. Conv2D层,具有32个过滤器,3x3的内核大小和ReLU激活。此层以形状为(28,28,1)的输入图像作为输入。

  2. MaxPooling2D层,具有默认的2x2池大小。此层对从上一层获得的特征映射进行下采样。

  3. Flatten层,将2D特征映射展平为1D向量。

  4. Dense层,具有128个神经元和ReLU激活。此层对展平的特征映射执行完全连接操作。

  5. Dropout层,在训练期间随机丢弃50%的连接以防止过拟合。

  6. Dense层,具有十个神经元,无激活函数。此层表示模型的输出层,神经元的数量对应于分类任务中的类别数目。

这个模型遵循典型的卷积神经网络架构,包括多个卷积层和池化层,以及一个或多个全连接层。

  • 从预训练模型定义深度学习模型 

利用TensorFlow Hub提供的预训练模型可能是一个不错的选择,因为它们已经在大量的数据集上进行了训练,可以帮助在减少训练时间的同时实现高准确度。在实现任何这些模型之前,让我们先了解一些TensorFlow Hub提供的常见预训练模型。

  1. VGG:The Visual Geometry Group(VGG)模型是由牛津大学开发的。这些模型广泛用于图像分类任务,并在各种基准数据集上取得了最先进的结果。

  2. ResNet:The Residual Network(ResNet)模型是由微软研究院开发的。这些模型具有独特的架构,可以训练非常深的神经网络(高达1000层)。

  3. Inception:Inception模型是由Google开发的。这些模型具有独特的架构,使用不同尺度的多个并行卷积,Inception模型广泛用于目标检测和图像分类任务。

  4. MobileNet:MobileNet模型是由Google开发的。这些模型具有针对移动设备和嵌入式设备进行优化的独特架构,MobileNet模型广泛用于移动设备上的图像分类和目标检测任务。

可以通过向预训练模型添加额外层并在特定数据集上训练模型来应用迁移学习。与从头开始训练模型相比,这种技术可以节省大量时间和计算资源。但是,在选择预训练模型并将数据集转换为该格式以确保兼容之前,了解预训练模型所需的输入格式非常重要。

在这个示例中,MobileNet模型被作为基本模型使用。在使用基本模型之前,检查模型所需的格式非常重要, 在本示例中,格式为(224,224,3)。然而,MNIST数据集是一个灰度图像,大小为(28,28,1),其中单个值表示像素的亮度。图像大小也比所需的格式要小得多。因此,需要重新调整数据集。以下是调整大小的主要思路:

使用image.resize函数将图像调整为所需的大小。该函数使用双线性插值来保留原始图像中的信息,同时将其调整为新大小。因此,此步骤可以将原始形状(28,28,1)调整为(224,224,1)的形状。

使用image.grayscale_to_rgb函数将图像转换为新的RGB图像,通过将单个灰度通道复制到新的RGB图像的所有三个通道中,从而将原始形状(224,224,1)调整为(224,224,3)的形状。

# 调整输入图像的大小为224x224,并将其转换为三通道的RGB图像
X_train = tf.image.grayscale_to_rgb(tf.image.resize(X_train, [224, 224]))
X_test = tf.image.grayscale_to_rgb(tf.image.resize(X_test, [224, 224]))

 现在让我们基于MobileNet模型定义我们的模型:

# 加载MobileNet模型,不包括顶层
base_model = MobileNet(include_top=False, input_shape=(224, 224, 3))

# 添加一个全局平均池化层和一个全连接输出层
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dropout(0.5)(x)
x = Dense(10, activation='softmax')(x)

# 将基础模型和新层结合起来,创建完整的模型
model = tf.keras.models.Model(inputs=base_model.input, outputs=x)

# 冻结基础模型中的各层
for layer in base_model.layers:
    layer.trainable = False

在上面的示例中,我们定义了一个模型,如下所示:

  1. 使用MobileNet()定义基本模型

  2. GlobalAveragePooling2D层,使用基本模型的最后一个卷积层的输出,计算每个特征映射的平均值,从而得到一个固定长度的向量,总结了特征映射中的空间信息。

  3. Dropout层,在训练期间随机丢弃50%的连接以防止过拟合。

  4. Dense层,使用十个单元的完全连接层和softmax激活。它接收来自上一层的输出并生成覆盖十个可能类别的概率分布。

编译和训练模型 

在创建模型之后,必须通过指定在训练期间使用的损失函数、优化器和指标来编译它。以下是一个编译模型的示例代码:

# 编译该模型
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

由于这是一个多分类问题,因此此示例代码使用了稀疏交叉熵损失函数,我们使用的是Adam优化器和准确率指标。

在训练模型之后可以在测试集上评估它,以查看它在未见过的数据上的表现如何,以下是一个评估模型的示例代码:

# 在测试数据上评估该模型
test_loss, test_acc = model.evaluate(X_test, y_test)
print('Test loss: ', test_loss)
print('Test Acc: ', test_acc)

 在此示例代码中,我们在测试集上评估模型,并输出测试损失和准确率。

进行预测

一旦训练和评估了模型,就可以使用它来预测新数据。以下是一个进行预测的示例代码:

# 对新数据进行预测
y_pred = model.predict(X_test)
y_pred_labels = np.argmax(y_pred, axis=1)
print(y_pred_labels)

在此示例代码中,我们在模型上使用predict()方法对整个测试集进行预测。

如果我们想要预测单个图像并返回预测标签与真实标签,那么就需要对Keras模型的predict()方法进行更改。因为Keras模型的predict()方法期望输入数据形式为一批图像,而我们想要传递单个图像给predict()方法,所以需要将其重新调整为批次大小为1。

def predict_and_compare(model, X_test, y_test, index):
    # 从X_test中获取给定索引的例子
    example = X_test[index]

    # 将例子重塑为预期的输入形状
    example = np.reshape(example, (1, 28, 28, 1))

    # 预测这个例子的标签
    y_pred = model.predict(example)

    # 将预测的概率转换为类别标签
    y_pred_label = np.argmax(y_pred, axis=1)[0]

    # 使用索引从y_test获取真实标签
    y_test_array = y_test.values
    # Get the label for the first example in the test set 
    y_true = y_test_array[index]

    # 输出预测的和真实的标签
    print("Predicted label:", y_pred_label)
    print("True label:", y_true)
    
    # 返回预测的和真实的标签
    return y_pred_label, y_true
  
# 预测并比较测试集中第一个例子的标签
y_pred_label, y_true = predict_and_compare(model, X_test, y_test, 0)

在上面的示例中,我们通过添加一个额外的维度来代表批次大小,从而将输入图像从(28,28,1)调整为(1,28,28,1)。这样,我们就可以传递单个图像给predict()方法,并获得该图像的预测结果。当我们调用上面的函数时,可以自定义要预测的图像:

 这就是在TensorFlow中实现深度学习的步骤。当然,这只是一个基本示例。你可以搭建具有更多层、不同类型的层和不同超参数的更复杂的模型,以便在数据集上获得更好的性能。

综上,本文我们演示了如何对数据进行预处理、搭建和训练模型、在单独的测试集上评估其性能以及使用简单的卷积神经网络(CNN)进行图像分类的预测,通过学习可以获得如何在TensorFlow中构建深度学习模型以及如何将这些概念应用于真实世界数据集的理解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/46955.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Java-运算符

目录 一、什么是运算符 二、算术运算符 1.基本四则运算符:加减乘除(、-、*、/、%) 2.增量运算符(、-、*、%) 3.自增、自减运算符(、--) 三、关系运算符 四、逻辑运算符 1.逻辑与 && …

Vue前端渲染blob二进制对象图片的方法

近期做开发,联调接口。接口返回的是一张图片,是对二进制图片处理并渲染,特此记录一下。 本文章是转载文章,原文章:Vue前端处理blob二进制对象图片的方法 接口response是下图 显然,获取到的是一堆乱码&…

【Docker】安全及日志管理

目录 一、Docker 安全及日志管理1.1 Docker 容器与虚拟机的区别1. 隔离与共享2. 性能与损耗 1.2Docker 存在的安全问题1.Docker 自身漏洞2.Docker 源码问题 1.3 Docker 架构缺陷与安全机制1. 容器之间的局域网攻击2. DDoS 攻击耗尽资源3. 有漏洞的系统调用4. 共享root用户权限 …

网安第二天笔记

ssh 22端口 账号密码登陆、证书登录 smtp 25端口 邮件协议 DNS 53 DHCP 67 68端口 四个包 1.DHCP服务器:服务器管理IP地址池和配置参数 2.客户端请求:发送DHCP广播请求,discover消息 3.DHCP服务器回应:收到discover会回复offer…

PostMan+Jmeter+QTP工具介绍及安装

目录 一、PostMan介绍​编辑 二、下载安装 三、Postman与Jmeter的区别 一、开发语言区别: 二、使用范围区别: 三、使用区别: 四、Jmeter安装 附一个详细的Jmeter按照新手使用教程,感谢作者,亲测有效。 五、Jme…

Linux_CentOS_7.9部署Docker以及镜像加速配置等实操验证全过程手册

前言:实操之前大家应该熟悉一个新的名词DevOps 俗称开发即运维、新一代开发工程师(Development和Operations的组合词)是一组过程、方法与系统的统称,用于促进开发(应用程序/软件工程)、技术运营和质量保障&…

Angular:动态依赖注入和静态依赖注入

问题描述: 自己写的服务依赖注入到组件时候是直接在构造器内初始化的。 直到看见代码中某大哥写的 private injector: Injector 动态依赖注入和静态依赖注入 在 Angular 中,使用构造函数注入的方式将服务注入到组件中是一种静态依赖注入的方式。这种方…

Python显示循环代码的进度条

目录 1. tqdm库 2. alive_progress库 3. progressbar库 1. tqdm库 tqdm是一个快速,可扩展的Python进度条,可以在Python长循环中添加一个进度提示信息 import time from tqdm import trangefor i in trange(100):# do somethingtime.sleep(0.5) 2. a…

微服务——Nacos配置管理

目录 Nacos配置管理——实现配置管理 配置管理实践 Nacos配置管理——微服务配置拉取 Nacos配置管理——配置热更新 方式一: ​编辑 方式二(推荐方式): Nacos配置管理——多环境配置共享 优先级问题 Nacos配置管理——nacos集群搭建 总结​编辑 Nacos配置管理——实现配置管…

【ADS】导入CMOS衬底文件+使用coilsys生成电感

新建工程经常忘记怎么操作,简记防遗忘。 操作步骤 1.unzip file2.原理图仿真3.Layout加载衬底文件4.使用coilsys生成电感 1.unzip file designKits-》unzip,选择对应库的压缩包,我这里是(TSMC_CRN65GP_v2.zip)。 为了…

BHQ 1Mal,BHQ-1 Maleimide,BHQ1马来酰亚胺,黑洞猝灭剂

资料编辑|陕西新研博美生物科技有限公司小编MISSwu​ PART1----产品描述: BHQ-1 Maleimide黑洞猝灭剂-1(BHQ-1)被归类为暗猝灭剂,该淬灭剂能够将一定距离内荧光基团发出的光全部吸收,实现对荧光信号的淬灭,所以可得到更强的特异性…

【App管理04-Bug修正 Objective-C语言】

一、咱们刚才已经把这个给大家做完了吧 1.这个Label怎么显示到上面去了, 我们现在是把它加到我们的控制器的View里面吧 我们看一下这个坐标是怎么算的,来,我们找一个坐标, 咱们的坐标,是不是用这个View的frame,减的吧 来,咱们在这里,输出一下这个Frame,看一下吧 在…

2023年第三届能源、电力与电气工程国际会议 (CoEEPE 2023)

会议简介 Brief Introduction 2023年第三届能源、电力与电气工程国际会议(CoEEPE 2023) 会议时间:2023年11月22日-24日 召开地点:澳大利亚墨尔本 大会官网:www.coeepe.org 2023年第三届能源、电力与电气工程国际会议(CoEEPE 2023)由安徽大学、…

Linux:Linux的发展史和作用有哪些?

文章目录 Linux是什么?Linux的开源特征为什么要学习Linux?Linux的应用场景有哪些? Linux是什么? 简单来说,Linux就是操作系统,它和Windows等软件一样,都只是操作系统,并无区别 Linu…

如何将ubuntu LTS升级为Pro

LTS支持周期是5年; Pro支持周期是10年。 Ubuntu Pro专业版笔记 步骤: 打开“软件和更新” 可以看到最右侧的标签是Ubuntu Pro。 在没有升级之前,如果使用下面两步: sudo apt updatesudo apt upgrade 出现如下提示&#xff…

vue 实现课程表甘特图

1 封装 components <template><view style"padding-bottom: 100rpx;"><view class"header flex_sb"><div class"header_list flex" v-for"(item,index) in timeList" :key"item.value"><div&g…

微信小程序——页面跳转方法和场景用法总结

✅作者简介&#xff1a;2022年博客新星 第八。热爱国学的Java后端开发者&#xff0c;修心和技术同步精进。 &#x1f34e;个人主页&#xff1a;Java Fans的博客 &#x1f34a;个人信条&#xff1a;不迁怒&#xff0c;不贰过。小知识&#xff0c;大智慧。 &#x1f49e;当前专栏…

如何在APP开发中实现多平台兼容?

随着移动互联网的发展&#xff0c;各大 APP平台也如雨后春笋般崛起&#xff0c;现在市面上主流的 APP平台基本都已经支持 IOS、 Android、 WP等多个平台的开发&#xff0c;但对于一些只想专注于做一款 APP的企业来说&#xff0c;往往只能选择在一个平台开发一个 APP&#xff0c…

opencv hand openpose

使用opencv c 来调用caffemodel 使用opencv 得dnn 模块调用 caffemodel得程序&#xff0c;图片自己输入就行&#xff0c;不做过多得解释&#xff0c;看代码清单。 定义手指关节点 const int POSE_PAIRS[20][2] { {0,1}, {1,2}, {2,3}, {3,4}, // thumb {0,5}, {5,6}, {6,7}…

索引能优化查询,那么谈谈索引的优点和缺点?索引原理

面试必备&#xff1a;索引能优化查询&#xff0c;那么谈谈索引的优点和缺点?索引原理 简述&#xff1a; 优点: 第一&#xff0c;通过创建唯一性索引&#xff0c;可以保证数据库表中每一行数据的唯一性。 第二&#xff0c;可以大大加快数据的检索速度&#xff0c;这也是创建…