深度学习-Tensorflow使用Keras进行模型训练

本文以FasionMNIST/加州房价数据集为例,介绍KerasAPI进行分类问题/回归问题模型训练的方法

Tensorflow版本

Tensorflow和keara都需要2.0及以上版本

import tensorflow as tf
from tensorflow import keras
print(tf.__version__)
print(keras.__version__)

分类MLP构建

数据集

fashion MNIST dataset,数据集有60000张衣服鞋子的图片,大小为28X28。

Keras可以通过keras.datasets方法来下载主流的数据集,fashion MNIST dataset已经区分了训练集(50000张)和测试集(10000张), 但最好从训练集中划出一部分作为验证集(Validation set)

import tensorflow as tf
from tensorflow import keras
fashion_mnist = keras.datasets.fashion_mnist
(X_train_full, y_train_full), (X_test, y_test) = fashion_mnist.load_data()

# 由于数据是最大值为255的灰度图像,除于255进行归一化
X_valid, X_train = X_train_full[:5000] / 255., X_train_full[5000:] / 255.
y_valid, y_train = y_train_full[:5000], y_train_full[5000:]
X_test = X_test / 255.

模型训练

定义网络层

Sequential API方法

Sequential网络层定义有两种方法,一种是add layers方法,有点笨拙

model = keras.models.Sequential()
model.add(keras.layers.Flatten(input_shape=[28, 28]))
model.add(keras.layers.Dense(300, activation="relu"))
model.add(keras.layers.Dense(100, activation="relu"))
model.add(keras.layers.Dense(10, activation="softmax"))

另外一种整体括号内定义网络层,推荐这种方法

model = keras.models.Sequential([
    keras.layers.Flatten(input_shape=[28, 28]),
    keras.layers.Dense(300, activation="relu"),
    keras.layers.Dense(100, activation="relu"),
    keras.layers.Dense(10, activation="softmax")
])

解释:

  1. Sequential是Keras最简单的网络结构定义方式,按照从上到下(代码文本)讲单层的网络连接起来

  1. Flatten的作用是将28X28的数组转化为1D的格式

  1. 第4/5行分别定义了两个隐藏层,大小为300/100个神经元,激活函数未relu

  1. 第6行定义输出层,由于是多分类问题,使用了“softmax”作为激活函数。

  • 使用model_summary()可以查看模型的结构

  • 使用model.get_layer(), layer.get_weight()可以查看模型的参数

定义损失函数、优化器等

model.compile(loss="sparse_categorical_crossentropy",
              optimizer="sgd",
              metrics=["accuracy"])

模型训练

和sk-learn一样调用fit()方法

history = model.fit(X_train, y_train, epochs=30,
                    validation_data=(X_valid, y_valid))

可以方便的运用history数据进行训练过程的可视化

import pandas as pd
pd.DataFrame(history.history).plot(figsize=(8, 5))
plt.grid(True)
plt.gca().set_ylim(0, 1)

测试集评估

model.evaluate(X_test, y_test)

模型保存

测试集上能达到要求就可以进行模型保存了,也可以将其部署到生产环境。

# kears模型的格式为h5
model.save("my_keras_FasionMNist_model.h5")

模型预测

# 载入模型
model = keras.models.load_model("my_keras_FasionMNist_model.h5")

X_new = X_test[:1] # 以一个样本为例子,输入是28X28的数组
y_proba = model.predict(X_new) # 由于是10分类问题,softmax输出结果是一个在10分类上概率和为1的10维数组
y_proba.round(2)

y_pred = np.argmax(model.predict(X_new), axis=-1)# 求概率和最大,axis=-1为在每个10维度数组内求
y_pred

回归MLP构建

以加州房价预测数据集为例,模型开发和分类问题类似,就不分模块一一介绍了。

主要区别:

  1. 由于输出房价是1维数据,只需要一个神经元,且训练label也是房价,因此不需要激活函数

  1. loss function为MSE

  1. 由于数据集比较简单,为防止过拟合,仅适用一个hidden layer,且神经元的个数设置为更低的值

from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# load数据集
housing = fetch_california_housing()

X_train_full, X_test, y_train_full, y_test = train_test_split(housing.data, housing.target, random_state=42)
X_train, X_valid, y_train, y_valid = train_test_split(X_train_full, y_train_full, random_state=42)

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_valid = scaler.transform(X_valid)
X_test = scaler.transform(X_test)

# 模型定义和训练
model = keras.models.Sequential([
    keras.layers.Dense(30, activation="relu", input_shape=X_train.shape[1:]),
    keras.layers.Dense(1)
])
model.compile(loss="mean_squared_error", optimizer=keras.optimizers.SGD(lr=1e-3))
history = model.fit(X_train, y_train, epochs=20, validation_data=(X_valid, y_valid))

#模型训练epoch过程MSE可视化
plt.plot(pd.DataFrame(history.history))
plt.grid(True)
plt.gca().set_ylim(0, 1)
plt.show()

mse_test = model.evaluate(X_test, y_test)

X_new = X_test[:3]
y_pred = model.predict(X_new)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/1441.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

AI_Papers周刊:第六期

CV - 计算机视觉 | ML - 机器学习 | RL - 强化学习 | NLP 自然语言处理 2023.03.13—2023.03.19 文摘词云 Top Papers Subjects: cs.CL 1.UPRISE: Universal Prompt Retrieval for Improving Zero-Shot Evaluation 标题:UPRISE:改进零样本评估…

要是早看到这篇文章,你起码少走3年弯路,20年老程序员的忠告

文章目录前言一、程序员的薪资是怎么样的?二、我现在的情况适合做程序员吗?三、大学期间到底应该学些什么?四、工作还是考研?五、总结前言 我是龙叔,一名工作了20多年的退休老程序员。 如果你在工作之前看到这篇文章…

【AI大比拼】文心一言 VS ChatGPT-4

摘要:本文将对比分析两款知名的 AI 对话引擎:文心一言和 OpenAI 的 ChatGPT,通过实际案例让大家对这两款对话引擎有更深入的了解,以便大家选择合适的 AI 对话引擎。 亲爱的 CSDN 朋友们,大家好!近年来&…

libcurl库访问人工智能平台之人脸识别

一、前言上一篇文章我们调用libcurl库去访问了百度,访问的是http协议的百度云主页。那么现在我们要基于翔云人工智能平台来实现人脸识别,具体的操作大概就是我们在linux下调用libcurl库去访问翔云人工智能平台,然后实现我们想要的两张人脸图片…

FPGA纯verilog实现RIFFA的PCIE通信,提供工程源码和软件驱动

目录1、前言2、RIFFA简介RIFFA概述RIFFA架构RIFFA驱动3、vivado工程详解4、上板调试验证并演示5、福利:工程代码的获取1、前言 PCIE是目前速率很高的外部板卡与CPU通信的方案之一,广泛应用于电脑主板与外部板卡的通讯,PCIE协议极其复杂&…

【Linux】基本指令介绍

前言从今天开始,我们一起来学习Linux的相关知识,今天先来介绍怎么登录Linux,并且介绍一些Linux的基本指令。使用 XShell 远程登录 Linux很多同学的 Linux 启动进入图形化的桌面. 这个东西大家以后就可以忘记了. 以后的工作中 没有机会 使用图…

蓝桥杯刷题冲刺 | 倒计时21天

作者:指针不指南吗 专栏:蓝桥杯倒计时冲刺 🐾马上就要蓝桥杯了,最后的这几天尤为重要,不可懈怠哦🐾 文章目录1.迷宫1.迷宫 题目 链接: 迷宫 - 蓝桥云课 (lanqiao.cn) 本题为填空题,只…

Three.js——learn02

Three.js——learn02Three.js——learn02通过轨道控制器查看物体OrbitControls核心代码index2.htmlindex.cssindex2.jsresult添加辅助器1.坐标轴辅助器AxesHelper核心代码完整代码2.箭头辅助器ArrowHelper核心代码完整代码3.相机视锥体辅助器CameraHelper核心代码完整代码Three…

近期投简历、找日常实习的一些碎碎念(大二---测试岗)

嘿嘿嘿,我又回来了,相信不少兄弟已经发现我似乎已经断更了好久,哈哈,我是尝试去找实习,投简历面试去了。 先说一下背景。 目录 背景 求职进行中 简历 投递和沟通 收获和感受 背景 博主,大二软件工程…

Arthas工具的基本使用

介绍 Arthas 是Alibaba开源的Java诊断工具,深受开发者喜爱。在线排查问题,无需重启;动态跟踪Java代码;实时监控JVM状态。Arthas支持JDK 6,支持Linux/Mac/Windows,采用命令行交互模式,同时提供丰…

Python截图自动化工具

1、展示部分源码(写的比较乱,哈哈) 2、功能展示 1)首页 2)按钮截图(用于自动翻页) 3)保存位置按钮(选择图片保存的位置) 4)重复次数,就是要截取多少次 5)定位截屏(截取的内容&#x…

[数据分析与可视化] Python绘制数据地图1-GeoPandas入门指北

本文主要介绍GeoPandas的基本使用方法,以绘制简单的地图。GeoPandas是一个Python开源项目,旨在提供丰富而简单的地理空间数据处理接口。GeoPandas扩展了Pandas的数据类型,并使用matplotlib进行绘图。GeoPandas官方仓库地址为:GeoP…

尚融宝06-ECMAScript基本介绍和使用

目录 一、ECMAScript 1、ECMA 2、ECMAScript 3、什么是 ECMA-262 4、ECMA-262 历史 5、ECMAScript 和 JavaScript 的关系 二、基本语法 1、let声明变量 2、const声明常量 3、解构赋值 4、模板字符串 5、声明对象简写 6、定义方法简写 7、参数的默认值 8、对象拓…

QT常用位置函数区别

目录1、引言2、实验代码3、位置函数3.1 x()3.2 y()3.3 frame()3.4 pos()3.5 geometry()3.6 width()3.7 height()3.8 rect()3.9 size()1、引言 QT有众多图形绘制函数,包括x()、y()、frame()、pos()、geometry()、width()、height()、rect()、size(),它们…

【Java学习笔记】多线程与线程池

多线程与线程池一、多线程安全与应用1、程序、进程与线程的关系2、创建多线程的三种方式(1)继承Thread类创建线程【不推荐】(2)实现Runnable接口创建线程(3)Callable接口创建线程3、线程的生命周期4、初识线…

基础入门 HTTP数据包Postman构造请求方法请求头修改状态码判断

文章目录数据-方法&头部&状态码请求requestResponse状态码案例-文件探针&登录爆破工具-Postman自构造使用数据-方法&头部&状态码 请求request 1、常规请求-Get 2、用户登录-Post •get:向特定资源发出请求(请求指定页面信息&#x…

为什么这么NB?huatuo革命Unity热更新

最近huatuo(华佗)热更新解决方案火爆了unity开发圈,起初我觉得热更新嘛,不就是内置一个脚本解释器脚本语言开发,如xLua, ILRuntime, puerts。Huatuo又能玩出什么花样,凭什么会这么NB,引起了那么多程序员的关注与称赞呢&#xff1f…

单片机——IIC协议与24C02

1、基础知识 1.1、IIC串行总线的组成及工作原理 I2C总线只有两根双向信号线。一根是数据线SDA,另一根是时钟线SCL。 1.2、I2C总线的数据传输 I2C总线进行数据传送时,时钟信号为高电平期间,数据线上的数据必须保持稳定,只有在时钟…

Linux实操之进程管理

文章目录一、基本介绍二、显示系统执行的进程基本介绍三、ps详解四、终止进程kill和killall介绍:●基本语法常用选项五、查看进程树pstree基本语法常用选项一、基本介绍 1.在LINUX中,每个执行的程序都称为一个进程。每一个进程都分配一个ID号(pid,进程号…

【SCL】实现简单算法--冒泡排序

使用SCL语言实现一个冒泡排序的简单算法 文章目录 目录 文章目录 前言 二、实现排序 1.读取存储器地址(PEEK)指令 2.编写程序 总结 前言 本文我们来一起使用SCL来实现一个简单的算法——冒泡排序;它可以对少量数据进行从小到大或从大到小排序…