查看TensorFlow已训模型的结构和网络参数

文章目录

    • 概要
    • 流程

概要

通过以下实例,你将学会如何查看神经网络结构并打印出训练参数。

流程

  • 准备一个简易的二分类数据集,并编写一个单层的神经网络
train_data = np.array([[1, 2, 3, 4, 5], 
                       [7, 7, 2, 4, 10], 
                       [1, 9, 3, 6, 5], 
                       [6, 7, 8, 9, 10]])

train_label = np.array([1, 0, 1, 0])  #标签与样本一一对齐


""" 定义一个单层的神经网络 """
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(1, activation=None)
])
  • 编译,训练,并保存模型
model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    optimizer='adam'
)
model.fit(train_data,
          train_label,
          epochs=2750)

tf.saved_model.save(model, "model_dir")  #保存到当前目录中,目录名为model_dir
  • 模型保存形式

模型节点和矩阵参数集中保存在 .data-00000-of-00001和 .index文件中,利用这两个文件中创建CheckpointReader对象。

  • 利用模型的Checkpoint对象查看模型结构和参数

Checkpoint对象存储了模型中所有可tracable追踪的对象,并记录保存着这些对象的参数及名称。可通过 tf.train.load_checkpoint()方法获得一个CheckpointReader对象,该对象可以读取Checkpoint内的所有信息。

"""  最后的variables是.data-00000-of-00001和 .index文件去掉后缀后的表达形式,
     从而统一代表着这两个文件"""
save_path = './model_dir/variables/variables'  # 

reader = tf.train.load_checkpoint(save_path)  # 得到CheckpointReader

"""  打印Checkpoint中存储的所有参数名和参数shape """
for variable_name, variable_shape in reader.get_variable_to_shape_map().items():
    print(f'{variable_name} : {variable_shape}')
 

optimizer/_variables/2/.ATTRIBUTES/VARIABLE_VALUE : [5, 1]
optimizer/_iterations/.ATTRIBUTES/VARIABLE_VALUE : []
_CHECKPOINTABLE_OBJECT_GRAPH : []
keras_api/metrics/0/count/.ATTRIBUTES/VARIABLE_VALUE : []
keras_api/metrics/0/total/.ATTRIBUTES/VARIABLE_VALUE : []
layer_with_weights-0/bias/.ATTRIBUTES/VARIABLE_VALUE : [1]
layer_with_weights-0/kernel/.ATTRIBUTES/VARIABLE_VALUE : [5, 1]
optimizer/_variables/1/.ATTRIBUTES/VARIABLE_VALUE : [5, 1]
optimizer/_learning_rate/.ATTRIBUTES/VARIABLE_VALUE : []
optimizer/_variables/3/.ATTRIBUTES/VARIABLE_VALUE : [1]
optimizer/_variables/4/.ATTRIBUTES/VARIABLE_VALUE : [1]

其中Dense层的权重参数和偏差bias的显示信息为,

layer_with_weights-0/bias/.ATTRIBUTES/VARIABLE_VALUE : [1]
layer_with_weights-0/kernel/.ATTRIBUTES/VARIABLE_VALUE : [5, 1]

接着利用刚刚打印出的参数名即可查看其参数值,

print(reader.get_tensor('layer_with_weights-0/kernel/.ATTRIBUTES/VARIABLE_VALUE'))
print(reader.get_tensor("layer_with_weights-0/bias/.ATTRIBUTES/VARIABLE_VALUE"))


[[-1.7741445 ]
 [-0.07314294]
 [-0.07213379]
 [ 1.1694099 ]
 [-0.36803177]]

[1.7487208]

  • 验证
model = tf.saved_model.load('model_dir')
print(model([[1, 2, 3, 4, 5]]))
output = -1.7741445 - 2*0.07314294 - 3*0.07213379 + 4*1.1694099 - 5*0.36803177+1.7487208
print(output)


tf.Tensor([[2.4493697]], shape=(1, 1), dtype=float32)

2.4493698000000004

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/532569.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MySQL高级(索引结构Hash,为什么InnoDB存储引擎选择使用B+tree索引结构?)

目录 1、Hash索引结构 2、Hash索引特点 3、存储引擎支持 4、为什么InnoDB存储引擎选择使用Btree索引结构? 1、Hash索引结构 哈希索引就是采用一定的hash算法,将键值换算成新的hash值,映射到对应的槽位上,然后存储在hash表中。 如…

吴恩达机器学习-异常检测(Anomaly Detection)

在本练习中,您将实现异常检测算法,并将其应用于检测网络上出现故障的服务器。 文章目录 1-包2-异常检测2.1问题陈述2.2数据集2.3高斯分布2.2.1高斯实现的估计参数:2.2.2选择阈值𝜖 2.4高维数据集 1-包 首先,让我们运…

脑电放大 LM386

LM386介绍 LM386 是一种音频集成功放,具有自身功耗低、电压增益可调整电源电压范围大、外接元件少和总谐波失真小等优点,广泛应用于录音机和收音机之中。 电源电压 4-12V 或 5-18V(LM386N-4);静态消耗电流为 4mA;电压增益为20-200dB;在引脚1和8开路时&a…

Android开发基础:事件传递 基于监听器的事件处理 基于回调的事件处理

目录 一,事件传递机制 1.事件传递机制的三个方法 (1)public boolean dispatchTouchEvent(MotionEvent event) (2)public boolean onInterceptTouchEvent(MotionEvent event&…

【C++题解】1601. 挖胡萝卜

问题:1601. 挖胡萝卜 类型:基本运算、小数运算 题目描述: 小兔朱迪挖了 x 个胡萝卜,狐狸尼克挖到胡萝卜数量是小兔挖到的 3 倍,小羊肖恩挖到胡萝卜的数量比狐狸尼克少 8 个。 请你编程计算一下狐狸尼克和小羊肖恩分别…

时间系列预测总结

转载自:https://mp.weixin.qq.com/s/B1eh4IcHTnEdv2y0l4MCog 拥有一种可靠的方法来预测和预测未来事件一直是人类的愿望。在数字时代,我们拥有丰富的信息,尤其是时间序列数据。 时间序列是指基于时间刻度维度(天、月、年等&…

Mybatis plus 使用通用枚举

说明&#xff1a;mybatis plus 使用枚举可实现数据库存入时指定值保存&#xff0c; 读取时指定值展示&#xff08;返给前端&#xff09; 可通过继承IEnum<T>、 EnumValue实现 1、引包 <dependency><groupId>mysql</groupId><artifactId>mysql-…

java基础语法(16)| 集合

前言 Hello,大家好!很开心与你们在这里相遇,我是一个喜欢文字、喜欢有趣的灵魂、喜欢探索一切有趣事物的女孩,想与你们共同学习、探索关于IT的相关知识,希望我们可以一路陪伴~ 1. 集合概述 什么是集合 集合:集合是java中提供的一种容器,可以用来存储多个数据,并且可以存…

每天五分钟深度学习PyTorch:面对Tensorflow,为何我选择PyTorch

这篇专栏文章不是为了挑起tenserflow和pytorch中哪个更好&#xff0c;众所周知tensorflow诞生以来&#xff0c;已经成为最流行的深度学习框架&#xff0c;可以说github中大多数的深度学习代码实现是以tensorflow实现的&#xff0c;也就是说资源众多&#xff0c;社区强大&#x…

自动化测试十大必备(背)面试题!【含答案精讲】

&#x1f525; 交流讨论&#xff1a;欢迎加入我们一起学习&#xff01; &#x1f525; 资源分享&#xff1a;耗时200小时精选的「软件测试」资料包 &#x1f525; 教程推荐&#xff1a;火遍全网的《软件测试》教程 &#x1f4e2;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1…

OJ 最大奖励 C Python【贪心算法】【动态规划】

又接触到贪心算法啦&#xff0c;这道题有两种算法思路&#xff0c;我用两个语言来写了一下&#xff0c;这也涉及到了一些动态规划的思路 一.从最后一个时间枚举&#xff0c;找到在这个时间内可以完成的最大分值的题 注意点&#xff1a; 1.数组下标从1开始记录表示第几个时间…

渲染农场实时画面怎么设置?云渲染农场实时预览效果查看

许多用户在使用渲染农场服务时&#xff0c;常常难以找到查看实时渲染画面的功能。由于渲染是一个时间消耗较大的任务&#xff0c;如果最终结果与预期不符&#xff0c;可能会对整个工作流程产生负面影响。因此&#xff0c;渲染平台若能提供实时预览渲染进度和效果的功能&#xf…

冯喜运:4.10晚间黄金原油走势分析

黄金消息技术面分析&#xff1a;美国CPI年率创半年新高&#xff0c;美国3月未季调CPI年率录得3.5%&#xff0c;高于预期的3.4%水平&#xff0c;为2023年9月以来最高水平。美国CPI高于预期&#xff0c;现货黄金短线下挫16美元。日线当前的指标macd依旧属于金叉放量运行&#xff…

Spring与SpringBoot的区别

Spring是一个开源的Java应用程序框架&#xff0c;旨在简化企业级Java应用程序的开发。它提供了一个轻量级的容器&#xff0c;用于管理应用程序中的各个组件&#xff08;如依赖注入、AOP等&#xff09;&#xff0c;并提供了丰富的功能和模块&#xff0c;用于处理数据库访问、事务…

提醒|2024年CSC国家公派访问学者项目开始网申(附常见申报问题解答)

留学基金委&#xff08;CSC&#xff09;2024年国家公派高级研究学者、访问学者项目网上申报时间为4月10日—4月30日。为此&#xff0c;知识人网小编提醒申请者及时申报。本文我们将常见申报问题汇总解答&#xff0c;以帮助申请者顺利完成CSC申报工作&#xff0c;并预祝红榜题名…

python pygame事件与事件处理

本期是接上期python pygame库的略学内容最后一个步骤&#xff0c;游戏与玩家交互的内容。 一、什么是事件 游戏需要与玩家交互&#xff0c;因此它必须能够接收玩家的操作&#xff0c;并根据玩家的不同操作做出有针对性的响应。程序开发中将玩家会对游戏进行的操作称为事件&…

rk3588开发板上安装ssh服务

目的&#xff1a;实现远程访问和控制&#xff0c;其他主机远程控制rk3588 方法及操作步骤&#xff1a; 1&#xff09;安装&#xff1a;sudo apt install openssh-server 2&#xff09; 查看运行状态 sudo systemctl status ssh 其它主机远程连接该开发板的ip和端口22即可

MVP模式

1、创建数据库表单对应的实体类。 package com.mvp.model; //Model(模型)&#xff0c;数据库表单对应的实体类。 public class Word {private int id;private String engName;private String chiVal;private String lastUsedTime;private int usedTimes;private String create…

【华为笔试题汇总】2024-04-10-华为春招笔试题-三语言题解(Python/Java/Cpp)

&#x1f36d; 大家好这里是KK爱Coding &#xff0c;一枚热爱算法的程序员 ✨ 本系列打算持续跟新华为近期的春秋招笔试题汇总&#xff5e; &#x1f4bb; ACM银牌&#x1f948;| 多次AK大厂笔试 &#xff5c; 编程一对一辅导 &#x1f44f; 感谢大家的订阅➕ 和 喜欢&#x1f…

uniapp 轮播列表一排展示3个,左右滑动,滑动到中间放大

一、效果展示 二、代码实现 1.html代码&#xff1a; <!-- 轮播 --><view class"heade"><swiper class"swiper" display-multiple-items3 circulartrue previous-margin1rpx next-margin1rpxcurrent0 change"swiperChange">&l…