Tensorflow2.0笔记 - 不使用layer方式,简单的MNIST训练

        本笔记不使用layer相关API,搭建一个三层的神经网络来训练MNIST数据集。

        前向传播和梯度更新都使用最基础的tensorflow API来做。

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets
import numpy as np

def load_mnist():
    path = r'./mnist.npz' #放置mnist.py的目录。注意斜杠
    f = np.load(path)
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']
    f.close()
    return (x_train, y_train), (x_test, y_test)

#加载mnist数据集
#X_train: [60000, 28, 28] 图片
#Y_train: [60000] 标签
#mnist数据集下载:https://blog.csdn.net/charles_neil/article/details/107851880
#                https://www.zhihu.com/question/56773355
(X_train,Y_train),(X_test,Y_test) = load_mnist()

#转换为tensor
#图片数据值转换到0-1
x = tf.convert_to_tensor(X_train, dtype=tf.float32) / 255.
y = tf.convert_to_tensor(Y_train, dtype=tf.int32)
print(x.shape,y.shape)
print(tf.reduce_min(x), tf.reduce_max(x))
print(tf.reduce_min(y), tf.reduce_max(y))

#数据集切分为多个batch
train_db = tf.data.Dataset.from_tensor_slices((x,y)).batch(128)
train_iter = iter(train_db)

sample = next(train_iter)
print(sample[0].shape, sample[1].shape)


#学习率
lr = 0.1
#用三个神经元,[b:784] => [b,256] => [b,128] => [b,10]
w1 = tf.Variable(tf.random.truncated_normal([784,256], stddev=0.1))
b1 = tf.Variable(tf.zeros([256]))
w2 = tf.Variable(tf.random.truncated_normal([256,128], stddev=0.1))
b2 = tf.Variable(tf.zeros([128]))
w3 = tf.Variable(tf.random.truncated_normal([128,10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]))

for epoch in range(10):
    print("[==================Epoch ", epoch, "========================]")
    for step, (x,y) in enumerate(train_db):
        x = tf.reshape(x, [-1, 28*28])
        #对标签进行onehot编码
        y_onehot = tf.one_hot(y, depth=10)
    
        with tf.GradientTape() as tape:
            #第一层,输入x [128,784]
            #x@w + b: [batch, 784] [784,256] + [256] => [batch,256]
            h1 = x@w1 + b1
            h1 = tf.nn.relu(h1)
            #第二层:[batch, 256] => [batch, 128]
            h2 = h1@w2 + b2
            h2 = tf.nn.relu(h2)
            #输出层:[batch,128] => [batch,10]
            out = h2@w3 + b3
        
            #计算损失
            #使用MSE: mean(sum(y - out)^2)
            loss = tf.reduce_mean(tf.square(y_onehot - out))
        #计算梯度
        grads = tape.gradient(loss, [w1,b1,w2,b2,w3,b3])
        #更新w和b: w = w - lr * w_grad
        w1.assign_sub(lr * grads[0])
        b1.assign_sub(lr * grads[1])
        w2.assign_sub(lr * grads[2])
        b2.assign_sub(lr * grads[3])
        w3.assign_sub(lr * grads[4])
        b3.assign_sub(lr * grads[5])
        
        if (step % 100 == 0):
            print("Batch:", step, "loss:", float(loss))

        运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/334880.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Modern C++ 条件变量

今天无意中看到一篇帖子&#xff0c;关于条件变量的&#xff0c;不过仔细看看发现它并达不到原本的目的。 程序如下&#xff0c;读者可以先想想他的本意&#xff0c;以及有没有问题&#xff1a; #include <iostream> #include <thread> #include <condition_v…

无刷电机学习-原理篇

一、无刷电机的优点 使用一项东西首先就要明白为什么要使用它&#xff0c;使用它有什么优点。与有刷电机相比无刷电机除了控制繁琐几乎全是优点。 1、应用范围广&#xff1a;家用电器&#xff08;冰箱空调压缩机、洗衣机、水泵等&#xff09;、汽车、航空航天、消费品工业自动…

STM32之002--软件安装 Keil

文章目录&#xff1a; 一、安装 Keil 二、注册 三、安装芯片支持包 一、安装 Keil 重点 1&#xff1a; 安装时&#xff0c;不能使用中文路径&#xff0c;否则无法正常使用!! 重点 2&#xff1a; 不要安装 V5.36 及以上的版本&#xff0c;其默认AC6编译器&#xff0c…

(二)基于wpr_simulation 的Ros机器人运动控制,gazebo仿真

一、创建工作空间 mkdir catkin_ws cd catkin_ws mkdir src cd src 二、下载wpr_simulation源码 git clone https://github.com/6-robot/wpr_simulation.git 三、编译 ~/catkin_make 目录下catkin_makesource devel/setup.bash 四、运行 roslaunch wpr_simulation wpb_s…

priority_queue的使用与模拟实现(容器适配器+stack与queue的模拟实现源码)

priority_queue的使用与模拟实现 引言&#xff08;容器适配器&#xff09;priority_queue的介绍与使用priority_queue介绍接口使用默认成员函数 size与emptytoppush与pop priority_queue的模拟实现构造函数size与emptytoppush与pop向上调整建堆与向下调整建堆向上调整建堆向下调…

ssh: connect to host github.com port 22: Connection refused

ssh: connect to host github.com port 22: Connection refused 问题现象 本文以Windows系统为例进行说明&#xff0c;在个人电脑上使用Git命令来操作GitHub上的项目&#xff0c;本来都很正常&#xff0c;突然某一天开始&#xff0c;会提示如下错误ssh: connect to host gith…

通讯录项目的实现以及动态顺序表(基于顺序表)

首先我们要知道什么是顺序表: 顺序表的底层结构是数组,对数组的封装,实现了常⽤的增删改查等接⼝,顺序表分为静态顺序表(使⽤定⻓数组存储元素)和动态顺序表(按需申请) 静态顺序表缺点: 空间给少了不够⽤,给多了造成空间浪费 拿出来我之前以及写好了的顺序表的代码:…

LeetCode、162. 寻找峰值【中等,最大值、二分】

文章目录 前言LeetCode、162. 寻找峰值【中等&#xff0c;最大值、二分】题目及类型思路及代码思路1&#xff1a;二分思路2&#xff1a;寻找最大值 资料获取 前言 博主介绍&#xff1a;✌目前全网粉丝2W&#xff0c;csdn博客专家、Java领域优质创作者&#xff0c;博客之星、阿…

微信小程序(七)navigator点击效果

注释很详细&#xff0c;直接上代码 上一篇 新增内容&#xff1a; 1.默认效果 2.无效果 3.激活效果 源码&#xff1a; index.wxml //如果 <navigator url"/pages/logs/logs">跳转到log页面&#xff08;默认&#xff09; </navigator><navigator url&q…

ctfshow命令执行(web29-web52)

目录 web29 web30 web31 web32 web33 web34 web35 web36 web37 web38 web39 web40 web41 web42 web43 web44 web45 web46 web47 web48 web49 web50 web51 web52 web29 <?php error_reporting(0); if(isset($_GET[c])){$c $_GET[c];if(!preg_match…

Rust基础语法1

所有权转移&#xff0c;Rust中没有垃圾收集器&#xff0c;使用所有权规则确保内存安全&#xff0c;所有权规则如下&#xff1a; 1、每个值在Rust中都有一个被称为其所有者&#xff08;owner&#xff09;的变量&#xff0c;值在任何时候只能有一个所有者。 2、当所有者离开作用域…

MacBookPro怎么数据恢复? mac电脑数据恢复?

使用电脑的用户都知道&#xff0c;被删除的文件一般都会经过回收站&#xff0c;想要恢复它直接点击“还原”就可以恢复到原始位置。mac电脑同理也是这样&#xff0c;但是“回收站”在mac电脑显示为“废纸篓”。 如果电脑回收站&#xff0c;或者是废纸篓里面的数据被清空了&…

【GitHub项目推荐--微软开源的可视化工具】【转载】

说到数据可视化&#xff0c;大家都很熟悉了&#xff0c;设计师、数据分析师、数据科学家等&#xff0c;都需要用各种方式各种途径做着数据可视化的工作.....当然许多程序员在工作中有时也需要用到一些数据可视化工具&#xff0c;如果工具用得好&#xff0c;就可以把原本枯燥凌乱…

springcloud Client端cloud-consumer-order80

文章目录 简介建立module修改pom修改yml主启动类把公共代码写在一个mudule 里面测试 简介 这个是和之前的8001相互配合端口测试 这里的80的用户测试端口。 代码在&#xff1a;GitHub 上&#xff1a;https://github.com/13thm/study_springcloud/tree/main/days2 建立module …

C++无锁队列的原理与实现

目录 1.无锁队列原理 1.1.队列操作模型 1.2.无锁队列简介 1.3.CAS操作 2.无锁队列方案 2.1.boost方案 2.2.ConcurrentQueue 2.3.Disruptor 3.无锁队列实现 3.1.环形缓冲区 3.2.单生产者单消费者 3.3.多生产者单消费者 3.4.RingBuffer实现 3.5.LockFreeQueue实现 …

django admin后台中进行多个手机号解密消耗时间对比

需求&#xff1a; 1 手机号在数据库中是使用rsa方式加密存储&#xff0c;后台查看中需要转换为明文&#xff0c;因为需要解密多个手机号&#xff0c;所以在后台查看中消耗时间3秒&#xff0c;希望通过多线程&#xff0c;多进程&#xff0c;异步方式来缩短时间 相关注意点&…

如何实现查找附近的人-GEO

背景 打开美团&#xff0c;可以通过自身定位查看附近的商品。打开社交软件&#xff0c;可以查看附近的人交友。打开滴滴&#xff0c;可以查看的附近的共享单车&#xff0c;那这些是如何实现&#xff1f; Redis GEO Redis GEO 主要用于存储地理位置信息&#xff0c;并对存储的…

Ubuntu系统Git的安装配置及使用笔记(更新中)

Ubuntu下Git的下载及配置 (1)、下载git 打开终端命令窗口,输入&#xff1a;sudo apt-get install git 提示&#xff1a;sudo命令是用来以其他身份来执行命令&#xff0c;预设的身份为root,使用sudo时必须先输入密码 (2)、可以使用命令git --version查看git的版本号 (3)、设置…

排序链表(LeetCode 148)

文章目录 1.问题描述2.难度等级3.热门指数4.解题思路参考文献 1.问题描述 给你链表的头结点 head &#xff0c;请将其按 升序 排列并返回 排序后的链表 。 示例 1&#xff1a; 输入&#xff1a;head [4,2,1,3] 输出&#xff1a;[1,2,3,4]示例 2&#xff1a; 输入&#xff…

git中合并分支时出现了代码冲突怎么办

目录 第一章、Git代码冲突介绍1.1&#xff09;什么是Git代码冲突①git merge命令介绍②代码冲突原因 1.2&#xff09;提示代码冲突的两种情况①本地不同分支的文件有差异时&#xff1a;②本地仓库和git远程仓库的文件有差异时&#xff1a; 1.3&#xff09;解决合并时的代码冲突…