TensorFlow2实战-系列教程2:神经网络分类任务

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

1、Mnist数据集

下载mnist数据集:

%matplotlib inline
from pathlib import Path
import requests

DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
PATH.mkdir(parents=True, exist_ok=True)

URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"

if not (PATH / FILENAME).exists():
        content = requests.get(URL + FILENAME).content
        (PATH / FILENAME).open("wb").write(content)

制作数据:

import pickle
import gzip

with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

简单展示数据:

from matplotlib import pyplot
import numpy as np

pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)
print(y_train[0])

打印结果:

(50000, 784)
5

在这里插入图片描述

2、模型构建

在这里插入图片描述
在这里插入图片描述
输入为784神经元,经过隐层提取特征后为10个神经元,10个神经元的输出值经过softmax得到10个概率值,取出10个概率值中最高的一个就是神经网络的最后预测值

构建模型代码:

import tensorflow as tf
from tensorflow.keras import layers
model = tf.keras.Sequential()
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))

选择损失函数,损失函数是机器学习一个非常重要的部分,基本直接决定了这个算法的效果,这里是多分类任务,一般我们就直接选用多元交叉熵函数就好了:
TensorFlow损失函数API

编译模型:

model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
             loss=tf.keras.losses.SparseCategoricalCrossentropy(),
             metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  1. adam优化器,学习率为0.001
  2. 多元交叉熵损失函数
  3. 评价指标

模型训练:

model.fit(x_train, y_train, epochs=5, batch_size=64, validation_data=(x_valid, y_valid))

训练数据,训练标签,训练轮次,batch_size,验证集

打印结果:

Train on 50000 samples, validate on 10000 samples
Epoch 1/5 50000/50000  1s 29us
sample-loss: 115566 - sparse_categorical_accuracy: 0.1122 - val_loss: 364928.5786 - val_sparse_categorical_accuracy: 0.1064
Epoch 2/5 50000/50000 1s 21us
sample - loss: 837104 - sparse_categorical_accuracy: 0.1136 - val_loss: 1323287.7028 - val_sparse_categorical_accuracy: 0.1064
Epoch 3/5 50000/50000 1s 20us
sample - loss: 1892431 - sparse_categorical_accuracy: 0.1136 - val_loss: 2448062.2680 - val_sparse_categorical_accuracy: 0.1064
Epoch 4/5 50000/50000 1s 20us
sample - loss: 3131130 - sparse_categorical_accuracy: 0.1136 - val_loss: 3773744.5348 - val_sparse_categorical_accuracy: 0.1064
Epoch 5/5 50000/50000 1s 20us
sample - loss: 4527781 - sparse_categorical_accuracy: 0.1136 - val_loss: 5207194.3728 - val_sparse_categorical_accuracy: 0.1064
<tensorflow.python.keras.callbacks.History at 0x1d3eb9015f8>

模型保存:

model.save('Mnist_model.h5')

3、TensorFlow常用模块

3.1 Tensor格式转换

创建一组数据

import numpy as np
input_data = np.arange(16)
input_data

打印结果:
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])

转换成TensorFlow格式的数据:

dataset = tf.data.Dataset.from_tensor_slices(input_data)
for data in dataset:
    print (data)

将一个ndarray转换成
打印结果:
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)

tf.Tensor(14, shape=(), dtype=int32)
tf.Tensor(15, shape=(), dtype=int32)

3.2repeat操作

dataset = tf.data.Dataset.from_tensor_slices(input_data)
dataset = dataset.repeat(2)
for data in dataset:
    print (data)

tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)

tf.Tensor(14, shape=(), dtype=int32)
tf.Tensor(15, shape=(), dtype=int32)
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)

tf.Tensor(14, shape=(), dtype=int32)
tf.Tensor(15, shape=(), dtype=int32)

会将当前的输出重复一遍

3.3 batch操作

dataset = tf.data.Dataset.from_tensor_slices(input_data)
dataset = dataset.repeat(2).batch(4)
for data in dataset:
    print (data)

tf.Tensor([0 1 2 3], shape=(4,), dtype=int32)
tf.Tensor([4 5 6 7], shape=(4,), dtype=int32)
tf.Tensor([ 8 9 10 11], shape=(4,), dtype=int32)
tf.Tensor([12 13 14 15], shape=(4,), dtype=int32)
tf.Tensor([0 1 2 3], shape=(4,), dtype=int32)
tf.Tensor([4 5 6 7], shape=(4,), dtype=int32)
tf.Tensor([ 8 9 10 11], shape=(4,), dtype=int32)
tf.Tensor([12 13 14 15], shape=(4,), dtype=int32)

将原来的数据按照4个为一个批次

3.4 shuffle操作

dataset = tf.data.Dataset.from_tensor_slices(input_data).shuffle(buffer_size=10).batch(4)
for data in dataset:
    print (data)

tf.Tensor([ 9 8 11 3], shape=(4,), dtype=int32)
tf.Tensor([ 5 6 1 13], shape=(4,), dtype=int32)
tf.Tensor([14 15 4 2], shape=(4,), dtype=int32)
tf.Tensor([12 7 0 10], shape=(4,), dtype=int32)

shuffle操作,直接翻译过来就是洗牌,把当前的数据进行打乱操作
buffer_size=10,就是缓存10来进行打乱取数据

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/351610.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

shell常用命令,参数传递,函数,挂载磁盘

1、ls 功能&#xff1a;显示文件和目录的信息ls 以默认方式显示当前目录文件列表 ls -a 显示所有文件包括隐藏文件 ls -l 显示文件属性&#xff0c;包括大小&#xff0c;日期&#xff0c;符号连接&#xff0c;是否可读写及是否可执行 ls -lh 显示文件的大小&#xff0c;以…

BOSS 直聘:日增10亿数据的历史库,如何通过OceanBase节省70%存储成本?

BOSS 直聘是在全球范围内首创互联网“直聘”模式的在线招聘产品&#xff0c;目前已经成为了中国最大的招聘平台。本文谈到的 BOSS 直聘的业务场景主要是通过数据库对招聘过程中的聊天记录信息进行存储&#xff0c;数据量极大&#xff0c;且每天都有 5 亿到 10 亿的增量数据。和…

Linux入门攻坚——14、实战软件安装-搭建Python3.8环境-2

上一篇解决了openssl和pip问题&#xff0c;这一篇来解决sqlite问题 创建app时出现错误&#xff0c;模块_sqlite3找不到&#xff0c;查询sqlite相关的包&#xff1a; 在python2.6的lib-dynload路径下&#xff0c;有_sqlite3.so&#xff0c;这个应该就是Python需要的sqlite模块&a…

Redis 实际项目中的整合,记录各种用法

Redis缓存餐厅数据 我们来看主要的流程 很简单,就是在数据库和接口之间加了一层缓冲,在redis之前其实还可以加其他的缓存 例如 nginx的缓存 接下来,就是结合我的业务,来做缓存 我这里的业务逻辑是,按了分类的按钮,分别以不同的 分类为一组缓存数据 所以,这里的缓存粒度是分类…

一个人永远无法赚到认知以外的钱

做交付、做产品&#xff0c;从来都不是一件容易的事情。在营销和服务的过程中&#xff0c;我充分见证了人生百态。 前一段时间&#xff0c;小灰创建了一个自媒体陪伴群&#xff0c;群里邀请了各个领域的自媒体大佬&#xff0c;每周都会在群里进行分享&#xff0c;大家的学习气氛…

推荐家庭关系三姑六婆计算器微信小程序源码

亲戚关系计算器微信小程序源码是一款为避免遇到亲戚却不知道该怎么称呼时遇到的尴尬情况而开发的&#xff0c;由于社会节奏的快速发展&#xff0c;现在的关系不像以前一样经常联系和维护&#xff0c;导致了有些自己家的一些亲戚也疏远了很多。 演示地 址 &#xff1a; runrunco…

httprunnerV4.X的基本使用详解

目录 1、httprunner概述 1.1、httprunner的优点 2、httprunner的安装 3、基本命令的使用 3.1、生成脚手架 3.2、将har文件转换为测试用例文件 3.3、执行测试用例 3.4、为项目创建虚拟环境&#xff0c;然后安装httprunner库 3.4、执行测试用例生成测试报告 4、httprun…

北斗卫星为野外科考人员提供安全保障

北斗卫星为野外科考人员提供安全保障 自第二次青藏高原综合科学考察研究启动以来&#xff0c;青海不断提升科考服务保障能力&#xff0c;推动科考全程信息化&#xff0c;有效促进科考成果转化。 为保障科考人员的人身安全&#xff0c;青海省青藏科学考察服务中心开发了基于北…

【Docker】附录一:常见问题总结

作者主页&#xff1a; 正函数的个人主页 文章收录专栏&#xff1a; Docker 欢迎大家点赞 &#x1f44d; 收藏 ⭐ 加关注哦&#xff01; 常见问题总结 一、镜像相关 如何批量清理临时镜像文件&#xff1f; 答&#xff1a;可以使用 docker image prune 命令。 如何查看镜像支持…

9.异步爬虫

异步爬虫可以理解为非只单线程爬虫 我们下面做个例子&#xff0c;之前我们通过单线程爬取过梨视频 https://blog.csdn.net/potato123232/article/details/135672504 在保存视频的时候会慢一些&#xff0c;为了提升效率&#xff0c;我们使用异步爬虫爬取 目录 1 线程池 2 …

使用vs2022将.net8的应用程序发布为一个单独文件

在使用.NetCore3.1时&#xff0c;可以通过设置以下工程配置文本来将项目发布为一个单独的应用程序文件&#xff1a; <Project Sdk"Microsoft.NET.Sdk.WindowsDesktop"><PropertyGroup><TargetFramework>netcoreapp3.1</TargetFramework><…

Jmeter性能测试: 基于JDK 21 安装 Jmeter 5.6.3

目录 一、实验 1.环境 2.JDK下载 3.Jmeter下载 4.Windows安装JDK 21 5.Windows安装Jmeter 5.6.3 6.Linux安装JDK 21 7.Linux安装Jmeter 5.6.3 二、问题 1. Linux 的profile、bashrc、bash_profile文件有哪些区别 一、实验 1.环境 &#xff08;1&#xff09;主机 表…

C语言之指针的地址和指向的内容总结(八十四)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 优质专栏&#xff1a;多媒…

计算机网络:体系结构知识点汇总

文章目录 一、计算机网络概述1.1概念及功能1.2组成和分类1.3性能指标 二、体系结构与参考模型2.1分层结构、协议、接口、服务2.2OSI参考模型2.3TCP/IP参考模型 一、计算机网络概述 1.1概念及功能 计算机网络就是通过各个节点&#xff0c;这个节点包括终端的电脑&#xff0c;手…

opencv#32 可分离滤波

滤波的可分离性 就是将一个线性滤波变成多个线性滤波&#xff0c;这里面具体所指的是变成x方向的线性滤波和y方向的线性滤波。无论先做x方向的滤波还是y方向滤波&#xff0c;两者的叠加结果是一致的&#xff0c;这个性质取决于滤波操作是并行的&#xff0c;也就是每一个图像在滤…

研发日记,Matlab/Simulink避坑指南(六)——字节分割Bug

文章目录 前言 背景介绍 问题描述 分析排查 解决方案 总结归纳 前言 见《研发日记&#xff0c;Matlab/Simulink避坑指南&#xff08;一&#xff09;——Data Store Memory模块执行时序Bug》 见《研发日记&#xff0c;Matlab/Simulink避坑指南(二)——非对称数据溢出Bug》…

文旅项目包括什么?

文旅项目是指与文化和旅游相结合的项目&#xff0c;旨在通过提供丰富的文化体验和旅游服务来吸引游客&#xff0c;促进地方经济发展。 文旅项目通常包括多个方面&#xff0c;以下是对每块内容的详细介绍&#xff1a; 文化旅游景区&#xff1a;这类项目以展示人类文化和历史遗产…

单片机学习笔记---独立按键控制LED显示二进制

这节我们来实现独立按键的第三个功能&#xff0c;独立按键控制LED显示二进制 新创建一个工程文件&#xff0c;然后上来我们就要把基本框架写好&#xff0c;这是基本的习惯 老规矩&#xff0c;然后把Delay 1ms的代码复制过来 复制过来后改造一下&#xff1a; 把1ms删掉&#x…

近20k stars,GSYVideoPlayer一款优秀的视频播放器

近20k stars&#xff0c;GSYVideoPlayer一款优秀的视频播放器 引言 在现代社会中&#xff0c;视频已经成为人们获取信息和娱乐的重要形式。为了提供更好的观看体验&#xff0c;开发一款优秀的视频播放器变得至关重要。而GSYVideoPlayer作为一款功能强大、稳定可靠的视频播放器…

03 SB实战 -微头条之首页门户模块(跳转某页面自动展示所有信息+根据hid查询文章全文并用乐观锁修改阅读量)

1.1 自动展示所有信息 需求描述: 进入新闻首页portal/findAllType, 自动返回所有栏目名称和id 接口描述 url地址&#xff1a;portal/findAllTypes 请求方式&#xff1a;get 请求参数&#xff1a;无 响应数据&#xff1a; 成功 {"code":"200","mes…