TenorFlow多层感知机识别手写体

文章目录

  • 数据准备
  • 建立模型
          • 建立输入层 x
          • 建立隐藏层h1
          • 建立隐藏层h2
          • 建立输出层
  • 定义训练方式
          • 建立训练数据label真实值 placeholder
          • 定义loss function
          • 选择optimizer
  • 定义评估模型的准确率
          • 计算每一项数据是否正确预测
          • 将计算预测正确结果,加总平均
  • 开始训练
          • 画出误差执行结果
          • 画出准确率执行结果
  • 评估模型的准确率
  • 进行预测
  • 找出预测错误

GITHUB地址https://github.com/fz861062923/TensorFlow
注意下载数据连接的是外网,有一股神秘力量让你403

数据准备

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\h5py\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters


WARNING:tensorflow:From <ipython-input-1-2ee827ab903d>:4: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\base.py:252: _internal_retry.<locals>.wrap.<locals>.wrapped_fn (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please use urllib or similar directly.
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
print('train images     :', mnist.train.images.shape,
      'labels:'           , mnist.train.labels.shape)
print('validation images:', mnist.validation.images.shape,
      ' labels:'          , mnist.validation.labels.shape)
print('test images      :', mnist.test.images.shape,
      'labels:'           , mnist.test.labels.shape)
train images     : (55000, 784) labels: (55000, 10)
validation images: (5000, 784)  labels: (5000, 10)
test images      : (10000, 784) labels: (10000, 10)

建立模型

def layer(output_dim,input_dim,inputs, activation=None):#激活函数默认为None
    W = tf.Variable(tf.random_normal([input_dim, output_dim]))#以正态分布的随机数建立并且初始化权重W
    b = tf.Variable(tf.random_normal([1, output_dim]))
    XWb = tf.matmul(inputs, W) + b
    if activation is None:
        outputs = XWb
    else:
        outputs = activation(XWb)
    return outputs
建立输入层 x
x = tf.placeholder("float", [None, 784])
建立隐藏层h1
h1=layer(output_dim=1000,input_dim=784,
         inputs=x ,activation=tf.nn.relu)  

建立隐藏层h2
h2=layer(output_dim=1000,input_dim=1000,
         inputs=h1 ,activation=tf.nn.relu)  
建立输出层
y_predict=layer(output_dim=10,input_dim=1000,
                inputs=h2,activation=None)

定义训练方式

建立训练数据label真实值 placeholder
y_label = tf.placeholder("float", [None, 10])#训练数据的个数很多所以设置为None
定义loss function
# 深度学习模型的训练中使用交叉熵训练的效果比较好
loss_function = tf.reduce_mean(
                   tf.nn.softmax_cross_entropy_with_logits_v2
                       (logits=y_predict , 
                        labels=y_label))
选择optimizer
optimizer = tf.train.AdamOptimizer(learning_rate=0.001) \
                    .minimize(loss_function)
#使用Loss_function来计算误差,并且按照误差更新模型权重与偏差,使误差最小化

定义评估模型的准确率

计算每一项数据是否正确预测
correct_prediction = tf.equal(tf.argmax(y_label  , 1),
                              tf.argmax(y_predict, 1))#将one-hot encoding转化为1所在的位数,方便比较
将计算预测正确结果,加总平均
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

开始训练

trainEpochs = 15#执行15个训练周期
batchSize = 100#每一批的数量为100
totalBatchs = int(mnist.train.num_examples/batchSize)#计算每一个训练周期应该执行的次数
epoch_list=[];accuracy_list=[];loss_list=[];
from time import time
startTime=time()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for epoch in range(trainEpochs):
    #执行15个训练周期
    #每个训练周期执行550批次训练
    for i in range(totalBatchs):
        batch_x, batch_y = mnist.train.next_batch(batchSize)#用该函数批次读取数据
        sess.run(optimizer,feed_dict={x: batch_x,
                                      y_label: batch_y})
        
    #使用验证数据计算准确率
    loss,acc = sess.run([loss_function,accuracy],
                        feed_dict={x: mnist.validation.images, #验证数据的features
                                   y_label: mnist.validation.labels})#验证数据的label

    epoch_list.append(epoch)
    loss_list.append(loss);accuracy_list.append(acc)    
    
    print("Train Epoch:", '%02d' % (epoch+1), \
          "Loss=","{:.9f}".format(loss)," Accuracy=",acc)
    
duration =time()-startTime
print("Train Finished takes:",duration)        
Train Epoch: 01 Loss= 133.117172241  Accuracy= 0.9194
Train Epoch: 02 Loss= 88.949943542  Accuracy= 0.9392
Train Epoch: 03 Loss= 80.701606750  Accuracy= 0.9446
Train Epoch: 04 Loss= 72.045913696  Accuracy= 0.9506
Train Epoch: 05 Loss= 71.911483765  Accuracy= 0.9502
Train Epoch: 06 Loss= 63.642936707  Accuracy= 0.9558
Train Epoch: 07 Loss= 67.192626953  Accuracy= 0.9494
Train Epoch: 08 Loss= 55.959281921  Accuracy= 0.9618
Train Epoch: 09 Loss= 58.867351532  Accuracy= 0.9592
Train Epoch: 10 Loss= 61.904548645  Accuracy= 0.9612
Train Epoch: 11 Loss= 58.283069611  Accuracy= 0.9608
Train Epoch: 12 Loss= 54.332244873  Accuracy= 0.9646
Train Epoch: 13 Loss= 58.152175903  Accuracy= 0.9624
Train Epoch: 14 Loss= 51.552104950  Accuracy= 0.9688
Train Epoch: 15 Loss= 52.803482056  Accuracy= 0.9678
Train Finished takes: 545.0556836128235
画出误差执行结果
%matplotlib inline
import matplotlib.pyplot as plt
fig = plt.gcf()#获取当前的figure图
fig.set_size_inches(4,2)#设置图的大小
plt.plot(epoch_list, loss_list, label = 'loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['loss'], loc='upper left')
<matplotlib.legend.Legend at 0x1edb8d4c240>

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

画出准确率执行结果
plt.plot(epoch_list, accuracy_list,label="accuracy" )
fig = plt.gcf()
fig.set_size_inches(4,2)
plt.ylim(0.8,1)
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend()
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

评估模型的准确率

print("Accuracy:", sess.run(accuracy,
                           feed_dict={x: mnist.test.images, 
                                      y_label: mnist.test.labels}))
Accuracy: 0.9643

进行预测

prediction_result=sess.run(tf.argmax(y_predict,1),
                           feed_dict={x: mnist.test.images })
prediction_result[:10]
array([7, 2, 1, 0, 4, 1, 4, 9, 6, 9], dtype=int64)
import matplotlib.pyplot as plt
import numpy as np
def plot_images_labels_prediction(images,labels,
                                  prediction,idx,num=10):
    fig = plt.gcf()
    fig.set_size_inches(12, 14)
    if num>25: num=25 
    for i in range(0, num):
        ax=plt.subplot(5,5, 1+i)
        
        ax.imshow(np.reshape(images[idx],(28, 28)), 
                  cmap='binary')
            
        title= "label=" +str(np.argmax(labels[idx]))
        if len(prediction)>0:
            title+=",predict="+str(prediction[idx]) 
            
        ax.set_title(title,fontsize=10) 
        ax.set_xticks([]);ax.set_yticks([])        
        idx+=1 
    plt.show()
plot_images_labels_prediction(mnist.test.images,
                              mnist.test.labels,
                              prediction_result,0)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

y_predict_Onehot=sess.run(y_predict,
                          feed_dict={x: mnist.test.images })
y_predict_Onehot[8]
array([-6185.544  , -5329.589  ,  1897.1707 , -3942.7764 ,   347.9809 ,
        5513.258  ,  6735.7153 , -5088.5273 ,   649.2062 ,    69.50408],
      dtype=float32)

找出预测错误

for i in range(400):
    if prediction_result[i]!=np.argmax(mnist.test.labels[i]):
        print("i="+str(i)+"   label=",np.argmax(mnist.test.labels[i]),
              "predict=",prediction_result[i])
i=8   label= 5 predict= 6
i=18   label= 3 predict= 8
i=149   label= 2 predict= 4
i=151   label= 9 predict= 8
i=233   label= 8 predict= 7
i=241   label= 9 predict= 8
i=245   label= 3 predict= 5
i=247   label= 4 predict= 2
i=259   label= 6 predict= 0
i=320   label= 9 predict= 1
i=340   label= 5 predict= 3
i=381   label= 3 predict= 7
i=386   label= 6 predict= 5
sess.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/390705.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

微信网页版能够使用(会顶掉微信app的登陆)

一、文件结构 新建目录chrome新建icons&#xff0c;其中图片你自己找吧新建文件manifest.json新建文件wx-rules.json 二、文件内容 对应的png你们自己改下 1、manifest.json {"manifest_version": 3,"name": "wechat-need-web","author…

揭开Markdown的秘籍:引用|代码块|超链接

&#x1f308;个人主页&#xff1a;聆风吟 &#x1f525;系列专栏&#xff1a;Markdown指南、网络奇遇记 &#x1f516;少年有梦不应止于心动&#xff0c;更要付诸行动。 文章目录 &#x1f4cb;前言一. ⛳️Markdown 引用1.1 &#x1f514;引用1.2 &#x1f514;嵌套引用1.3 &…

网络原理-TCP_IP(6)

网络层 在复杂的网络环境中确定一个合适的路径. IP协议 与TCP协议并列,都是网络体系中最核心的协议. 基本概念 主机:配有IP地址,但是不进行路由控制的设备; 路由器:即配有IP地址,又能进行路由控制; 节点:主机和路由器的统称; 协议头格式 4位版本号(version):指定IP协议的版…

Vue的一些基础设置

1.浏览器控制台显示Vue 设置找到扩展&#xff0c;搜索Vue 下载这个 然后 点击扩展按钮 点击详细信息 选择这个&#xff0c;然后重启一下就好了 ——————————————————————————————————————————— 2.优化工程结构 src的components里要…

wayland(xdg_wm_base) + egl + opengles 纹理贴图进阶实例(四)

文章目录 前言一、使用gstreamer 获取 pattern 图片二、代码实例1. pattern 图片作为纹理数据源的代码实例1.1 基于opengles2.0 接口的 egl_wayland_texture2_1.c1.2 基于opengles3.0 接口的 egl_wayland_texture3_1.c2. xdg-shell-client-protocol.h 和 xdg-shell-protocol.c3…

IDEA中的神仙插件——Smart Input (自动切换输入法)

IDEA中的神仙插件——Smart Input &#xff08;自动切换输入法&#xff09; 设置 更多功能详见官方文档&#xff1a;Windows版SmartInput使用入门

面试经典150题——串联所有单词的子串(困难)

"Opportunities dont happen, you create them." ​ - Chris Grosser 1. 题目描述 2. 题目分析与解析 2.1 思路一——暴力求解 遇见这种可能刚开始没什么思路的问题&#xff0c;先试着按照人的思维来求解该题目。对于一个人来讲&#xff0c;我想要找到 s 字符串中…

Hive拉链表设计、实现、总结

水善利万物而不争&#xff0c;处众人之所恶&#xff0c;故几于道&#x1f4a6; 文章目录 环境介绍实现1. 初始化拉链表2. 后续拉链表数据的更新 总结彩蛋 - 想清空表的数据&#xff1a;转成内部表&#xff0c;清空数据后&#xff0c;再转成外部表&#xff0c;将分区目录删掉&am…

无心剑英译仓央嘉措《永在我心》

永在我心 Forever in My Heart 仓央嘉措 By Tsangyang Gyatso 这么多年 你一直在我心口幽居 我放下过天地 放下过万物 却从未放下过你 so many years slipped away you’ve been living in my heart I’ve dropped heaven and earth even dropped everything but never dr…

OWASP TOP10

OWASP TOP10 OWASP网址&#xff1a;http://ww.owasp.org.cn A01&#xff1a;失效的访问控制 例如&#xff1a;越权漏洞 案例1&#xff1a; 正常&#xff1a;每个人登录教务系统&#xff0c;只能查询自己的成绩信息 漏洞&#xff1a;张三登录后可以查看自己的成绩 例如&…

人工智能学习与实训笔记(五):神经网络之推荐系统处理

目录 ​​​​​​​七、智能推荐系统处理 7.1 常用的推荐系统算法 7.2 如何实现推荐​​​​​​​ 7.3 基于飞桨实现的电影推荐模型 7.3.1 电影数据类型 7.3.2 数据处理 7.3.4 数据读取器 7.3.4 网络构建 7.3.4.1用户特征提取 7.3.4.2 电影特征提取 7.3.4.3 相似度…

智能网卡(SmartNIC):增强网络性能

在当今的数字时代&#xff0c;网络性能和数据安全是各行各业面临的关键挑战。智能网卡是一项颠覆性的技术创新&#xff0c;对增强网络性能和加强数据安全性具有关键推动作用。本文旨在探讨智能网卡的工作原理及其在不同应用场景中的重要作用。 什么是智能网卡&#xff1f; 智…

Rust 基本环境安装

rust 基本介绍请看上一篇文章&#xff1a;rust 介绍 rustup 介绍 rustup 是 Rust 语言的安装器和版本管理工具。通过 rustup&#xff0c;可以轻松地安装 Rust 编译器&#xff08;rustc&#xff09;、标准库和文档。它也允许你切换不同的 Rust 版本或目标平台&#xff0c;以及…

太以假乱真了,大家小心

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

.NET Core MongoDB数据仓储和工作单元模式封装

前言 上一章我们把系统所需要的MongoDB集合设计好了&#xff0c;这一章我们的主要任务是使用.NET Core应用程序连接MongoDB并且封装MongoDB数据仓储和工作单元模式&#xff0c;因为本章内容涵盖的有点多关于仓储和工作单元的使用就放到下一章节中讲解了。仓储模式&#xff08;R…

SpringMVC速成(二)

文章目录 SpringMVC速成&#xff08;二&#xff09;1.SSM整合1.1 流程分析1.2 整合配置步骤1&#xff1a;创建Maven的web项目步骤2:添加依赖步骤3:创建项目包结构步骤4:创建SpringConfig配置类步骤5:创建JdbcConfig配置类步骤6:创建MybatisConfig配置类步骤7:创建jdbc.properti…

云计算基础-虚拟化概述

虚拟化概述 虚拟化是一种资源管理技术&#xff0c;能够将计算机的各种实体资源&#xff08;如CPU、内存、磁盘空间、网络适配器等&#xff09;予以抽象、转换后呈现出来并可供分割、组合为一个或多个逻辑上的资源。这种技术通过在计算机硬件上创建一个抽象层&#xff0c;将单台…

《UE5_C++多人TPS完整教程》学习笔记15 ——《P16 会话接口委托(Session Interface Delegates)》

本文为B站系列教学视频 《UE5_C多人TPS完整教程》 —— 《P16 会话接口委托&#xff08;Session Interface Delegates&#xff09;》 的学习笔记&#xff0c;该系列教学视频为 Udemy 课程 《Unreal Engine 5 C Multiplayer Shooter》 的中文字幕翻译版&#xff0c;UP主&#xf…

Matplotlib plt.scatter:从入门到精通,只需一篇文章!

Matplotlib plt.scatter&#xff1a;从入门到精通&#xff0c;只需一篇文章&#xff01;&#x1f680; 利用Matplotlib进行数据可视化示例 &#x1f335;文章目录&#x1f335; 一、plt.scatter入门&#xff1a;轻松迈出第一步 &#x1f463;二、进阶探索&#xff1a;plt.scatt…

大文件上传如何做断点续传?

文章目录 一、是什么分片上传断点续传 二、实现思路三、使用场景小结 参考文献 一、是什么 不管怎样简单的需求&#xff0c;在量级达到一定层次时&#xff0c;都会变得异常复杂 文件上传简单&#xff0c;文件变大就复杂 上传大文件时&#xff0c;以下几个变量会影响我们的用…