TensorFlow 简单的二分类神经网络的训练和应用流程

展示了一个简单的二分类神经网络的训练和应用流程。主要步骤包括:

1. 数据准备与预处理

2. 构建模型

3. 编译模型

4. 训练模型

5. 评估模型

6. 模型应用与部署

加载和应用已训练的模型


1. 数据准备与预处理

在本例中,数据准备是通过两个 Numpy 数组来完成的:

  • x:输入特征,形状为 (8, 2),包含 8 个数据点,每个数据点有 2 个特征。
  • y:标签,形状为 (8,),包含对应的 0 或 1 标签,表示每个输入点的类别。
x = np.array([[1, 1], [1, -1], [-1, 1], [-1, -1], [0.7, 0.7], [0.7, -0.7], [-0.7, -0.7], [-0.7, 0.7]])
y = np.array([1, 1, 1, 1, 0, 0, 0, 0])

2. 构建模型

使用 Keras 的 Sequential 模型来构建神经网络。此模型包含两个全连接层(Dense 层):

  • 第一个 Dense 层有 3 个单位,激活函数是 Sigmoid。
  • 第二个 Dense 层有 1 个单位,激活函数是 Sigmoid,输出层的激活函数将模型输出的值映射到 0 到 1 之间,适合二分类任务。
l1 = tf.keras.layers.Dense(units=3, activation='sigmoid')
l2 = tf.keras.layers.Dense(units=1, activation='sigmoid')
model = tf.keras.Sequential([l1, l2])

3. 编译模型

在编译阶段,我们选择了优化器、损失函数和评估指标:

  • 优化器:SGD(随机梯度下降),学习率设置为 0.9。
  • 损失函数:binary_crossentropy,适用于二分类任务。
  • 评估指标:accuracy,表示训练过程中对分类准确率的衡量。
sgd = tf.keras.optimizers.SGD(learning_rate=0.9)
model.compile(optimizer=sgd, loss='binary_crossentropy', metrics=['accuracy'])

4. 训练模型

通过 model.fit() 函数来训练模型。我们传入训练数据 x 和标签 y,并设置训练的 epoch 数量为 2000。

model.fit(x, y, epochs=2000)

5. 评估模型

在此示例中,评估部分通过训练后的 model 来进行,并没有显式写出 evaluate() 函数。评估通常是在训练之后,通过测试集或验证集对模型性能进行评估,具体可以使用 model.evaluate() 来查看损失和准确度。

6. 模型应用与部署

训练完成后,我们保存了训练好的模型。保存后的模型可以被加载和应用于新的数据集。

model.save('my_model.keras')  # 保存模型

7.加载和应用已训练的模型

加载保存的模型,并用其对新数据进行预测。model.predict() 方法返回的是预测的概率值,我们通过设置阈值(如 0.9)将其转换为类别(0 或 1)。

model = tf.keras.models.load_model('my_model.keras')  # 加载模型
nx = np.array([[2, 2], [0.1, 0.1], [1.1, 1.2], [0.3, 0.3]])  # 新的输入数据
predictions = model.predict(nx)  # 获取预测结果
print(predictions)  # 输出概率

# 将概率转化为类别
predicted_classes = (predictions > 0.9).astype(int)
print(predicted_classes)  # 输出最终的类别预测

8.完整代码
test.py 训练模型

import tensorflow as tf
import numpy as np
# 创建int32类型的0维张量,即标量
l1=tf.keras.layers.Dense(units=3,activation='sigmoid')
l2=tf.keras.layers.Dense(units=1,activation='sigmoid')
model=tf.keras.Sequential([l1,l2])
sgd = tf.keras.optimizers.SGD(learning_rate=0.9)
model.compile(optimizer=sgd, loss='binary_crossentropy', metrics=['accuracy'])
x=np.array([[1,1],[1,-1],[-1,1],[-1,-1],[0.7,0.7],[0.7,-0.7],[-0.7,-0.7],[-0.7,0.7]])
y=np.array([1,1,1,1,0,0,0,0])
model.fit(x,y,epochs=2000)
# 保存训练好的模型(Keras 格式)
model.save('my_model.keras')

 test2.py加载模型并进行预测:

import tensorflow as tf
import numpy as np

# 加载训练好的模型
model = tf.keras.models.load_model('my_model.keras')

# 预测数据
nx = np.array([[2, 2], [0.1, 0.1], [1.1, 1.2], [0.3, 0.3]])

# 获取预测结果
predictions = model.predict(nx)

# 输出预测结果
print(predictions)

# 如果需要将概率转化为类别(0或1)
predicted_classes = (predictions > 0.9).astype(int)

# 输出最终的类别预测
print(predicted_classes)

9.视频分享


初识TensorFlow 
https://v.douyin.com/ifG2mmLH/
复制此链接,打开Dou音搜索,直接观看视频!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/963090.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【B站保姆级视频教程:Jetson配置YOLOv11环境(四)cuda cudnn tensorrt配置】

Jetson配置YOLOv11环境(4)cuda cudnn tensorrt配置 文章目录 0. 简介1. cuda配置:添加cuda环境变量2. cudnn配置3. TensorRT Python环境配置3.1 系统自带Python环境中的TensorRT配置3.2 Conda 虚拟Python环境中的TensorRT配置 0. 简介 官方镜…

Python安居客二手小区数据爬取(2025年)

目录 2025年安居客二手小区数据爬取观察目标网页观察详情页数据准备工作:安装装备就像打游戏代码详解:每行代码都是你的小兵完整代码大放送爬取结果 2025年安居客二手小区数据爬取 这段时间需要爬取安居客二手小区数据,看了一下相关教程基本…

Electron使用WebAassembly实现CRC-8 MAXIM校验

Electron使用WebAssembly实现CRC-8 MAXIM校验 将C/C语言代码,经由WebAssembly编译为库函数,可以在JS语言环境进行调用。这里介绍在Electron工具环境使用WebAssembly调用CRC-8 MAXIM格式校验的方式。 CRC-8 MAXIM校验函数WebAssebly源文件 C语言实现CR…

DeepSeek-R1:通过强化学习激励大型语言模型(LLMs)的推理能力

摘要 我们推出了第一代推理模型:DeepSeek-R1-Zero和DeepSeek-R1。DeepSeek-R1-Zero是一个未经监督微调(SFT)作为初步步骤,而是通过大规模强化学习(RL)训练的模型,展现出卓越的推理能力。通过强…

pytorch基于FastText实现词嵌入

FastText 是 Facebook AI Research 提出的 改进版 Word2Vec,可以: ✅ 利用 n-grams 处理未登录词 比 Word2Vec 更快、更准确 适用于中文等形态丰富的语言 完整的 PyTorch FastText 代码(基于中文语料),包含&#xff1…

【hot100】刷题记录(8)-矩阵置零

题目描述: 给定一个 m x n 的矩阵,如果一个元素为 0 ,则将其所在行和列的所有元素都设为 0 。请使用 原地 算法。 示例 1: 输入:matrix [[1,1,1],[1,0,1],[1,1,1]] 输出:[[1,0,1],[0,0,0],[1,0,1]]示例 2…

PyTorch框架——基于深度学习YOLOv8神经网络学生课堂行为检测识别系统

基于YOLOv8深度学习的学生课堂行为检测识别系统,其能识别三种学生课堂行为:names: [举手, 读书, 写字] 具体图片见如下: 第一步:YOLOv8介绍 YOLOv8 是 ultralytics 公司在 2023 年 1月 10 号开源的 YOLOv5 的下一个重大更新版本…

Doki Doki Mods Maker小指南

-*- 做都做了,那就做到底吧。 -*- 前言: 项目的话,在莫盘里,在贴吧原帖下我有发具体地址。 这里是Doki Doki Mods Maker,是用来做DDLC Mods的小工具。 说是“Mods”,实则不然,这个是我从零仿…

nodejs:express + js-mdict 网页查询英汉词典

向 DeepSeek R1 提问: 我想写一个Web 前端网页,后台用 nodejs js-mdict, 实现在线查询英语单词 1. 项目结构 首先,创建一个项目目录,结构如下: mydict-app/ ├── public/ │ ├── index.html │ ├── st…

LabVIEW纤维集合体微电流测试仪

LabVIEW开发纤维集合体微电流测试仪。该设备精确测量纤维材料在特定电压下的电流变化,以分析纤维的结构、老化及回潮率等属性,对于纤维材料的科学研究及质量控制具有重要意义。 ​ 项目背景 在纤维材料的研究与应用中,电学性能是评估其性能…

dfs枚举问题

碎碎念:要开始刷算法题备战蓝桥杯了,一切的开头一定是dfs 定义 枚举问题就是咱数学上学到的,从n个数里面选m个数,有三种题型(来自Acwing) 从 1∼n 这 n个整数中随机选取任意多个,输出所有可能的选择方案。 把 1∼n这…

SOME/IP--协议英文原文讲解3

前言 SOME/IP协议越来越多的用于汽车电子行业中,关于协议详细完全的中文资料却没有,所以我将结合工作经验并对照英文原版协议做一系列的文章。基本分三大块: 1. SOME/IP协议讲解 2. SOME/IP-SD协议讲解 3. python/C举例调试讲解 Note: Thi…

leetcode——二叉树的中序遍历(java)

给定一个二叉树的根节点 root ,返回 它的 中序 遍历 。 示例 1: 输入:root [1,null,2,3] 输出:[1,3,2] 示例 2: 输入:root [] 输出:[] 示例 3: 输入:root [1] 输出…

91,【7】 攻防世界 web fileclude

进入靶场 <?php // 包含 flag.php 文件 include("flag.php");// 以高亮语法显示当前文件&#xff08;即包含这段代码的 PHP 文件&#xff09;的内容 // 方便查看当前代码结构和逻辑&#xff0c;常用于调试或给解题者提示代码信息 highlight_file(__FILE__);// 检…

Microsoft Power BI:融合 AI 的文本分析

Microsoft Power BI 是微软推出的一款功能强大的商业智能工具&#xff0c;旨在帮助用户从各种数据源中提取、分析和可视化数据&#xff0c;以支持业务决策和洞察。以下是关于 Power BI 的深度介绍&#xff1a; 1. 核心功能与特点 Power BI 提供了全面的数据分析和可视化功能&…

海外问卷调查,最常用到的渠道查有什么特殊之处

市场调研&#xff0c;包含市场调查和市场研究两个步骤&#xff0c;是企业和机构根据经营方向而做出的决策问题&#xff0c;最终通过海外问卷调查中的渠道查&#xff0c;来系统地设计、收集、记录、整理、分析、研究市场反馈的工作流程。 市场调研的工作流程包括&#xff1a;确…

《苍穹外卖》项目学习记录-Day10来单提醒

type&#xff1a;用来标识消息的类型&#xff0c;比如说type1表示来单提醒&#xff0c;type2表示客户催单。 orderId&#xff1a;表示订单id&#xff0c;因为不管是来单提醒还是客户催单&#xff0c;这一次提醒都对应一个订单。是用户下了某个单或者催促某个订单&#xff0c;这…

【全栈】SprintBoot+vue3迷你商城(10)

【全栈】SprintBootvue3迷你商城&#xff08;10&#xff09; 往期的文章都在这里啦&#xff0c;大家有兴趣可以看一下 后端部分&#xff1a; 【全栈】SprintBootvue3迷你商城&#xff08;1&#xff09; 【全栈】SprintBootvue3迷你商城&#xff08;2&#xff09; 【全栈】Sp…

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】1.27 线性代数王国:矩阵分解实战指南

1.27 线性代数王国&#xff1a;矩阵分解实战指南 #mermaid-svg-JWrp2JAP9qkdS2A7 {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-JWrp2JAP9qkdS2A7 .error-icon{fill:#552222;}#mermaid-svg-JWrp2JAP9qkdS2A7 .erro…

【愚公系列】《循序渐进Vue.js 3.x前端开发实践》030-自定义组件的插槽Mixin

标题详情作者简介愚公搬代码头衔华为云特约编辑&#xff0c;华为云云享专家&#xff0c;华为开发者专家&#xff0c;华为产品云测专家&#xff0c;CSDN博客专家&#xff0c;CSDN商业化专家&#xff0c;阿里云专家博主&#xff0c;阿里云签约作者&#xff0c;腾讯云优秀博主&…