Python深度学习实战-基于tensorflow原生代码搭建BP神经网络实现分类任务(附源码和实现效果)

实现功能

前面两篇文章分别介绍了两种搭建神经网络模型的方法,一种是基于tensorflow的keras框架,另一种是继承父类自定义class类,本篇文章将编写原生代码搭建BP神经网络。

实现代码

import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# 加载鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target

# 数据预处理
scaler = StandardScaler()
X = scaler.fit_transform(X)

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 设置超参数
learning_rate = 0.001
num_epochs = 100
batch_size = 32

# 定义输入和输出的维度
input_dim = X.shape[1]
output_dim = len(set(y))

# 定义权重和偏置项
W1 = tf.Variable(tf.random.normal(shape=(input_dim, 64), dtype=tf.float64))
b1 = tf.Variable(tf.zeros(shape=(64,), dtype=tf.float64))
W2 = tf.Variable(tf.random.normal(shape=(64, 64), dtype=tf.float64))
b2 = tf.Variable(tf.zeros(shape=(64,), dtype=tf.float64))
W3 = tf.Variable(tf.random.normal(shape=(64, output_dim), dtype=tf.float64))
b3 = tf.Variable(tf.zeros(shape=(output_dim,), dtype=tf.float64))


# 定义前向传播函数
def forward_pass(X):
    X = tf.cast(X, tf.float64)
    h1 = tf.nn.relu(tf.matmul(X, W1) + b1)
    h2 = tf.nn.relu(tf.matmul(h1, W2) + b2)
    logits = tf.matmul(h2, W3) + b3
    return logits

# 定义损失函数
def loss_fn(logits, labels):
    return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))

# 定义优化器
optimizer = tf.optimizers.Adam(learning_rate)

# 定义准确率指标
accuracy_metric = tf.metrics.SparseCategoricalAccuracy()

# 定义训练步骤
def train_step(inputs, labels):
    with tf.GradientTape() as tape:
        logits = forward_pass(inputs)
        loss_value = loss_fn(logits, labels)
    gradients = tape.gradient(loss_value, [W1, b1, W2, b2, W3, b3])
    optimizer.apply_gradients(zip(gradients, [W1, b1, W2, b2, W3, b3]))
    accuracy_metric(labels, logits)
    return loss_value

# 进行训练
for epoch in range(num_epochs):
    epoch_loss = 0.0
    accuracy_metric.reset_states()

    for batch_start in range(0, len(X_train), batch_size):
        batch_end = batch_start + batch_size
        batch_X = X_train[batch_start:batch_end]
        batch_y = y_train[batch_start:batch_end]

        loss = train_step(batch_X, batch_y)
        epoch_loss += loss

    train_loss = epoch_loss / (len(X_train) // batch_size)
    train_accuracy = accuracy_metric.result()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.4f}")

# 进行评估
logits = forward_pass(X_test)
test_loss = loss_fn(logits, y_test)
test_accuracy = accuracy_metric(y_test, logits)

print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

实现效果

本人读研期间发表5篇SCI数据挖掘相关论文,现在某研究院从事数据挖掘相关科研工作,对数据挖掘有一定认知和理解,会结合自身科研实践经历不定期分享关于python、机器学习、深度学习基础知识与案例。

致力于只做原创,以最简单的方式理解和学习,关注我一起交流成长。

邀请三个朋友关注V订阅号:数据杂坛,即可在后台联系我获取相关数据集和源码,送有关数据分析、数据挖掘、机器学习、深度学习相关的电子书籍。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/107226.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

在CentOS 7中手工打造和运行xml文件配置的Servlet,然后使用curl、浏览器、telnet等三种工具各自测试

下载Openjdk并配置环境变量 https://jdk.java.net/java-se-ri/11-MR2是官网下载Openjdk 11的地方。 sudo wget https://download.java.net/openjdk/jdk11.0.0.1/ri/openjdk-11.0.0.1_linux-x64_bin.tar.gz下载openjdk 11。 sudo mkdir -p /usr/openjdk11创建目录&#xff…

一张图系列 - “kv cache“

我觉得回答这个问题需要知道3个知识点: 1、multi-head-attention是如何计算的?attention的数学公式? kv cache是如何存储和传递的? 2、kv cache 的原理步骤是什么?为什么降低了消耗? 3、kv cache 代码模…

C++:stl中set(multiset)和map(multimap)的介绍和使用

本文主要从概念、常用接口和使用方法方面介绍set(multiset)和map(multimap)。 目录 一、概念介绍 1.关联式容器 2.键值对 3. 树形结构的关联式容器 二、set和multiset 1.set的介绍 2.set使用 1. set模板参数列表 2. set构造 3. set迭代器 4. set容量 5. set修改操…

正则表达式包含数字和字符匹配

至少6位。 pattern : (?.[0-9])(?.[A-Za-z])[0-9A-Za-z]{6,} 正则表达式中的“?”是一个正向预查字符,它的意思是匹配前一个字符出现的最少一次。具体来说,当一个匹配出现时,它会检查前一个字符是否符合要求,如果符合&#xf…

使用一个Series序列减去另一个Series序列Series.subtract()

【小白从小学Python、C、Java】 【计算机等考500强证书考研】 【Python-数据分析】 求两个序列中对应位置 的各元素的差 a.subtract(b) [太阳]选择题 关于以下代码的说法中正确的是? import pandas as pd a pd.Series([1,2,3]) print("【显示】a:\n",a) b pd.Seri…

Windows下安装Anaconda、Pycharm以及iflycode插件图解

目录 一、下载Anaconda、Pycharm以及iflycode插件 二、创建相关文件夹 三、Pycharm社区版安装详细步骤 四、Anaconda安装详细步骤 五、配置Pycharm 六、安装iflycode插件 Anaconda是一款集成的Python环境,anaconda可以看做Python的一个集成安装,安…

人工智能基础_机器学习007_高斯分布_概率计算_最小二乘法推导_得出损失函数---人工智能工作笔记0047

这个不分也是挺难的,但是之前有详细的,解释了,之前的文章中有, 那么这里会简单提一下,然后,继续向下学习 首先我们要知道高斯分布,也就是,正太分布, 这个可以预测x在多少的时候,概率最大 要知道在概率分布这个,高斯分布公式中,u代表平均值,然后西格玛代表标准差,知道了 这两个…

由于找不到emp.dll无法继续执行此代码问题的五个解决方法

在玩游戏的过程中,我们常常会遇到一些错误提示,其中最常见的就是“找不到emp.dll”,这个问题我们的游戏无法启动运行。本文将分享我在解决这一问题过程中的方法,希望能对遇到类似问题的玩家有所帮助。 emp.dll是一个动态链接库文件…

Typecho 添加 Emoji 表情报错「解决方案」

Typecho 添加 Emoji 表情报错 文章目录 Typecho 添加 Emoji 表情报错前言Emoji 表情utf8mb4 与 UTF8 解决方案[1] 数据库编码更改[2] 数据库配置文件更改 前言 Typecho 添加 Emoji 表情不支持,报错 Database Query Error Emoji 表情 Emoji 就是表情符号&#xff0c…

SiC器件概念

来源:A SiC Trench MOSFET concept offering improved channel mobility and high reliability SiC MOSFET设计挑战 虽然碳化硅的使用由于是一种宽带隙材料而具有许多优点,但与硅也存在一些值得注意的差异,这导致在制造基于4H-SiC多晶型的Si…

vscode markdown 使用技巧 -- 如何快速打出一个Tab 或多个空格

背景描述: 我在使用VSCode,这玩意很好用,但是,有一个缺点是,我想使用Tab来做一些对齐,但是我发现在VSCode中,无论是Tab还是多个空格,最终显示出来的都是一个空格 使用代码可以实现打…

虹科 | 解决方案 | 汽车示波器 学校教学方案

虹科Pico汽车示波器是基于PC的设备,特别适用于大课堂的教学、备课以及与师生的互动交流。老师展现讲解波形数据,让学生直观形象地理解汽车的工作原理 高效备课 课前实测,采集波形数据,轻松截图与标注,制作优美的课件&…

【psychopy】【脑与认知科学】认知过程中的面孔识别加工

目录 实验描述 实验思路 python实现 实验描述 现有的文献认为,人们对倒置的面孔、模糊的面孔等可能会出现加工时长增加、准确率下降的问题,现请你设计一个相关实验,判断不同的面孔是否会出现上述现象。请按照认知科学要求,画…

1819_ChibiOS的互斥信号与条件变量

全部学习汇总: GreyZhang/g_ChibiOS: I found a new RTOS called ChibiOS and it seems interesting! (github.com) 1. 关于会吃信号与条件变量的全局配置提供了4个配置信息,分别是互斥信号的使能、互斥信号的递归支持、条件变量的使能、条件变量的超时使…

10.28总结

目录 一.发布作业 二.写作业 三.批改作业 一.发布作业 点击简答题时———listvie<String>题目列表会新增一个题目 保存该题时———— 获取TextArea的文本,为list当前选中的对象赋值 发布日期不能为过往日期&#xff0c;截止日期不能晚于发布日期。——为发布日期设置…

Linux系统编程_网络编程:字节序、socket、serverclient、ftp 云盘

1. 网络编程概述&#xff08;444.1&#xff09; TCP/UDP对比 TCP 面向连接&#xff08;如打电话要先拨号建立连接&#xff09;&#xff1b;UDP 是无连接的&#xff0c;即发送数据之前不需要建立连接TCP 提供可靠的服务。也就是说&#xff0c;通过 TCP 连接传送的数据&#xf…

C#使用mysql-connector-net驱动连接mariadb报错

给树莓派用最新的官方OS重刷了一下&#xff0c;并且用apt install mariadb-server装上“mysql”作为我的测试服务器。然后神奇的事情发生了&#xff0c;之前用得好好的程序突然就报错了&#xff0c;经过排查&#xff0c;发现在连接数据库的Open阶段就报错了。写了个最单纯的Con…

Wpf 使用 Prism 实战开发Day01

一.开发环境准备 1. VisualStudio 2022 2. .NET SDK 7.0 3. Prism 版本 8.1.97 以上环境&#xff0c;如有新的版本&#xff0c;可自行选择安装新的版本为主 二.创建Wpf项目 1.项目的名称:MyToDo 项目名称:这里只是记录学习&#xff0c;所以随便命名都无所谓,只要觉得合理就…

水声功率放大器的应用场景是什么

水声功率放大器是一种专门用于水声信号处理和传输的设备&#xff0c;通过放大水声信号的功率&#xff0c;以实现远距离传播和提高信号的清晰度和可辨识度。下面是关于水声功率放大器应用场景的详细解释&#xff1a; 水声通信&#xff1a;水声通信是一种在水下进行声波传输的技术…

sharepoint2016-2019升级到sharepoint订阅版

一、升级前准备&#xff1a; 要建立新的sharepoint订阅版环境&#xff0c;需求如下&#xff1a; 1.单服务器硬件需求CPU 4核&#xff0c;内存24G以上&#xff0c;硬盘300G&#xff08;根据要迁移的数量来扩容大小等&#xff09;&#xff1b; 2.操作系统需要windows server 20…