MLP实现fashion_mnist数据集分类(1)-模型构建、训练、保存与加载(tensorflow)

1、查看tensorflow版本

import tensorflow as tf

print('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

在这里插入图片描述

2、fashion_mnist数据集下载与展示

(train_image,train_label),(test_image,test_label) = tf.keras.datasets.fashion_mnist.load_data()
print(train_image.shape)
print(train_label.shape)
print(test_image.shape)
print(test_label.shape)

在这里插入图片描述

import matplotlib.pyplot as plt
# plt.imshow(train_image[0])  # 此处为啥是彩色的?

def plot_images_lables(images,labels,start_idx,num=5):
    fig = plt.gcf()
    fig.set_size_inches(12,14)
    for i in range(num):
        ax = plt.subplot(1,num,1+i)
        ax.imshow(images[start_idx+i],cmap='binary')
        title = 'label=' + str(labels[start_idx+i])
        ax.set_title(title,fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()
plot_images_lables(train_image,train_label,0,5)
# plot_images_lables(test_image,test_label,0,5)

在这里插入图片描述

3、数据预处理

X_train,X_test = tf.cast(train_image/255.0,tf.float32),tf.cast(test_image/255.0,tf.float32) # 归一化
y_train,y_test = train_label,test_label # 此处对y没有做onehot处理,需要使用稀疏交叉损失函数

4、模型构建

from keras import Sequential
from keras.layers import Flatten,Dense,Dropout
from keras import Input

model = Sequential()
model.add(Input(shape=(28,28)))
model.add(Flatten())
model.add(Dense(units=256,kernel_initializer='normal',activation='relu'))
model.add(Dropout(rate=0.1))
model.add(Dense(units=64,kernel_initializer='normal',activation='relu'))
model.add(Dropout(rate=0.1))
model.add(Dense(units=10,kernel_initializer='normal',activation='softmax'))
model.summary()

在这里插入图片描述

5、模型配置

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['acc'])

6、模型训练

H = model.fit(x=X_train,
              y=y_train,
              validation_split=0.2,
              # validation_data=(X_test,y_test),
              epochs=10,
              batch_size=128,
              verbose=1)

在这里插入图片描述

plt.plot(H.epoch, H.history['loss'], label='loss')
plt.plot(H.epoch, H.history['val_loss'], label='val_loss')
plt.legend()

在这里插入图片描述

plt.plot(H.epoch, H.history['acc'], label='acc')
plt.plot(H.epoch, H.history['val_acc'], label='val_acc')
plt.legend()

在这里插入图片描述

7、模型评估

model.evaluate(X_test,y_test)

在这里插入图片描述

8、模型预测

import numpy as np
import matplotlib.pyplot as plt

def pred_plot_images_lables(images,labels,start_idx,num=5):
    # 预测
    res = model.predict(images[start_idx:start_idx+num])
    res = np.argmax(res,axis=1)

    # 画图
    fig = plt.gcf()
    fig.set_size_inches(12,14)
    for i in range(num):
        ax = plt.subplot(1,num,1+i)
        ax.imshow(images[start_idx+i],cmap='binary')
        title = 'label=' + str(labels[start_idx+i]) + ', pred=' + str(res[i])
        ax.set_title(title,fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()
pred_plot_images_lables(X_test,y_test,0,5)

在这里插入图片描述

9、模型保存与加载

import numpy as np

tf.keras.models.save_model(model,"model.keras")
loaded_model = tf.keras.models.load_model("model.keras")
# assert np.allclose(model.predict(X_test[:5]), loaded_model.predict(X_test[:5]))
print(np.argmax(model.predict(X_test[:5]),axis=1))
print(np.argmax(loaded_model.predict(X_test[:5]),axis=1))

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/595894.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

如何使git提交的时候忽略一些特殊文件?

认识.gitignore文件 在生成远程仓库的时候我们会看到这样一个选项: 这个.gitignore文件有啥用呢? .gotignore文件是Git版本控制系统中的一个特殊文件。用来指定哪些文件或者目录不被Git追踪或者提交到版本库中。也就意味着,如果我们有一些文…

怎么通过Java语言实现远程控制无人售货柜

怎么通过Java语言实现远程控制无人售货柜呢? 本文描述了使用Java语言调用HTTP接口,实现控制无人售货柜,独立控制售货柜、格子柜的柜门。 可选用产品:可根据实际场景需求,选择对应的规格 序号设备名称厂商1智能WiFi控…

使用 Postman 实现 API 自动化测试

背景介绍 相信大部分开发人员和测试人员对 postman 都十分熟悉,对于开发人员和测试人员而言,使用 postman 来编写和保存测试用例会是一种比较方便和熟悉的方式。但 postman 本身是一个图形化软件,相对较难或较麻烦(如使用 RPA&am…

低功耗UPF设计的经典案列分享

案例1 分享个例子,景芯A72低功耗设计,DBG domain的isolation为何用VDDS_maia_noncpu供电而不是TOP的VDD? 答:因为dbg的上一级是noncpu,noncpu下面分成dbg和两个tbnk。 案例2 景芯A72的低功耗,请问&#…

精品干货 | 数据中台与数据仓库建设(免费下载)

【1】关注本公众号,转发当前文章到微信朋友圈 【2】私信发送 数据中台与数据仓库建设 【3】获取本方案PDF下载链接,直接下载即可。 如需下载本方案PPT/WORD原格式,请加入微信扫描以下方案驿站知识星球,获取上万份PPT/WORD解决方…

零基础入门学习Python第二阶01生成式(推导式),数据结构

Python语言进阶 重要知识点 生成式(推导式)的用法 prices {AAPL: 191.88,GOOG: 1186.96,IBM: 149.24,ORCL: 48.44,ACN: 166.89,FB: 208.09,SYMC: 21.29}# 用股票价格大于100元的股票构造一个新的字典prices2 {key: value for key, value in prices.i…

小微公司可用的开源ERP系统

项目介绍 华夏ERP是基于SpringBoot框架和SaaS模式的企业资源规划(ERP)软件,旨在为中小企业提供开源且易用的ERP解决方案。它专注于提供进销存、财务和生产功能,涵盖了零售管理、采购管理、销售管理、仓库管理、财务管理、报表查询…

Unreal Engine插件打包技巧

打开UE工程,点击编辑,选择插件,点击"打包"按钮,选择输出目录UE4.26版本打包提示需要VS2017问题解决 1)用记事本打开文件【UE4对应版本安装目录\Epic Games\UE_4.26\Engine\Build\BatchFiles\RunUAT.bat】 2&…

单元测试配置

检查 vendor 目录下 是否有bin目录, bin目录下是否有 phpunit 文件 没有安装 composer require —dev phpunit/phpunit 确认版本是 PHPUnit 9.6.7配置IDE配置php解释器点击绿色箭头,运行测试查看效果备注: 单步调试需要安装 xdebug

5月6号作业

申请该结构体数组,容量为5,初始化5个学生的信息 使用fprintf将数组中的5个学生信息,保存到文件中去 下一次程序运行的时候,使用fscanf,将文件中的5个学生信息,写入(加载)到数组中去,并直接输出学…

Mysql索引失效情况

索引失效的情况 这是正常查询情况,满足最左前缀,先查有先度高的索引。 1. 注意这里最后一种情况,这里和上面只查询 name 小米科技 的命中情况一样。说明索引部分丢失! 2. 这里第二条sql中的,status > 1 就是范围查…

性能超越!新模型Dragoman打造高质量英译乌翻译系统,打败现有SOTA模型

DeepVisionary 每日深度学习前沿科技推送&顶会论文分享,与你一起了解前沿深度学习信息! 引言:探索乌克兰语的机器翻译挑战 在当今全球化迅速发展的背景下,机器翻译技术已成为沟通世界各地文化和语言的重要桥梁。尽管如此&…

【Axure高保真原型】动态伸缩信息架构图

今天和大家分享动态伸缩信息架构图的原型模板,我们可以通过点击加减按钮来展开或收起子内容,具体效果可以点击下方视频观看或者打开预览地址来体验 【原型效果】 【Axure高保真原型】动态伸缩信息架构图 【原型预览含下载地址】 https://axhub.im/ax9/…

Python从0到100(二十):文件读写和文件操作

一、文件的打开和关闭 有了文件系统可以非常方便的通过文件来读写数据;在Python中要实现文件操作是非常简单的。我们可以使用Python内置的open函数来打开文件,在使用open函数时,我们可以通过函数的参数指定文件名、操作模式和字符编码等信息…

关于蓝队应急响应工具箱意见征集

前言 征集一下各位师傅的意见,没用过的师傅可以去以往的文章下载使用: 下载地址(有个小小改动,去除了必要的python环境,使其占用空间更小): [护网必备]知攻善防实验室蓝队应急响应工具箱v202…

自动化运维工具---Ansible

一 Puppet Puppet是历史悠久的运维工具之一。它是一种基础架构即代码(laC)工具,使用户可以定义其基础 架构所需的状态,并使系统自动化以实现相同状态。 Puppet可监视用户的所有系统,并防止任何偏离已定义状态的情况。从简单的工作流程自动…

pytest教程-36-钩子函数-pytest_collection_start

领取资料,咨询答疑,请➕wei: June__Go 上一小节我们学习了pytest_unconfigure钩子函数的使用方法,本小节我们讲解一下pytest_collection_start钩子函数的使用方法。 pytest_collection_start(session) 是一个 pytest 钩子函数,…

YOLO-World环境搭建推理测试

一、引子 CV做了这么多年,大多是在固定的数据集上训练,微调,测试。突然想起来一句话,I have a dream!就是能不能不用再固定训练集上捣腾,也就是所谓的开放词汇目标检测(OVD)。偶尔翻…

new mars3d.control.MapSplit({实现点击卷帘两侧添加不同图层弹出不同的popup

new mars3d.control.MapSplit({实现点击卷帘两侧添加不同图层弹出不同的popup效果: 左侧: 右侧: 说明:mars3d的3.7.12以上版本才支持该效果。 示例链接: 功能示例(Vue版) | Mars3D三维可视化平台 | 火星科技 相关代…

C++进阶:AVL树

AVL树的概念 二叉搜索树虽可以缩短查找的效率,但 如果数据有序或接近有序二叉搜索树将退化为单支树,查 找元素相当于在顺序表中搜索元素,效率低下 。因此,两位俄罗斯的数学家 G.M. A delson- V elskii 和 E.M. L andis 在 1962 …