【机器学习】基于tensorflow实现你的第一个DNN网络

博客导读:

《AI—工程篇》

AI智能体研发之路-工程篇(一):Docker助力AI智能体开发提效

AI智能体研发之路-工程篇(二):Dify智能体开发平台一键部署

AI智能体研发之路-工程篇(三):大模型推理服务框架Ollama一键部署

AI智能体研发之路-工程篇(四):大模型推理服务框架Xinference一键部署

AI智能体研发之路-工程篇(五):大模型推理服务框架LocalAI一键部署

《AI—模型篇》

AI智能体研发之路-模型篇(一):大模型训练框架LLaMA-Factory在国内网络环境下的安装、部署及使用

AI智能体研发之路-模型篇(二):DeepSeek-V2-Chat 训练与推理实战

AI智能体研发之路-模型篇(三):中文大模型开、闭源之争

AI智能体研发之路-模型篇(四):一文入门pytorch开发

AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比

目录

一、引言

二、tensorflow介绍

2.1 tensorflow历史

2.2 tensorflow特点

 2.3 tensorflow安装

三、tensorflow实战

3.1 引入依赖的tensorflow库

3.2 训练数据准备

3.3 创建三层DNN模型

3.4 编译模型、定义损失函数与优化器

3.5 启动训练,迭代收敛

3.6 模型评估

3.7 可以直接跑的代码 

四、总结


一、引言

上一篇AI智能体研发之路-模型篇(四):一文入门pytorch开发介绍如何使用pytorch实现一个简单的DNN网络,今天我们还是用同样的例子,看看使用tensorflow如何实现。

二、tensorflow介绍

2.1 tensorflow历史

TensorFlow由谷歌人工智能团队谷歌大脑(Google Brain)开发和维护,拥有包括TensorFlow Hub、TensorFlow Lite、TensorFlow Research Cloud在内的多个项目以及各类应用程序接口(Application Programming Interface, API)。自2015年11月9日起,TensorFlow依据阿帕奇授权协议(Apache 2.0 open source license)开放源代码。

2.2 tensorflow特点

深度学习时代,tensorflow在工业应用较为广泛,而pytorch更多应用于研究中。大模型时代,pytorch是很多项目的底层库,大有超过tensorflow的趋势。可谓并驾齐驱。

  • 生态系统更成熟:TensorFlow拥有一个庞大的社区和丰富的资源,包括大量的教程、预训练模型和工具,适合从初学者到专家的各个层次用户。
  • 生产部署友好:TensorFlow支持更多的平台和设备,包括移动设备和边缘设备,提供了TensorFlow Lite和TensorFlow.js等,便于模型的部署和优化。
  • 静态图与动态图的结合:虽然早期TensorFlow以静态图为主,但TensorFlow 2.x引入了Eager Execution,结合了动态图的易用性和静态图的高性能,同时保持了模型的可部署性。
  • Keras集成:TensorFlow内建了Keras,这是一个高级神经网络API,使得模型构建、训练和评估更加简洁直观。
  • TensorBoard:TensorFlow自带的可视化工具TensorBoard,便于可视化模型结构、训练过程中的损失和指标,帮助用户更好地理解和调试模型。
  • 广泛的工业应用支持:由于其成熟度和稳定性,TensorFlow在工业界得到了广泛的应用,特别是在大型企业中。

 2.3 tensorflow安装

与pytorch一样,还是采用conda创建环境,采用pip安装tensorflow包

1.建立名为pytrain,python版本为3.11的conda环境(这里与pytorch一样)

conda create -n pytrain python=3.11
conda activate pytrain

​  

 2.采用pip下载tensorflow以及机器学习常用的scikit-learn和numpy包

pip install tensorflow scikit-learn numpy  -i https://mirrors.cloud.tencent.com/pypi/simple

​ 

这里未指定版本,默认下载最新版本tensorflow-2.16.1以及其他tensorboard等生态包。 

三、tensorflow实战

 动手实现一个三层DNN网络:

3.1 引入依赖的tensorflow库

这里主要是tensorflow、keras、sklearn、numpy等

Keras是一个用于构建和训练深度学习模型的高级API,它设计得极其用户友好,支持快速实验。Keras可以运行在TensorFlow之上。

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import numpy as np

3.2 训练数据准备

这里采用numpy库进行数据随机生成

# 假设你已经有了特征数据 X 和标签数据 y
# X, y = ...  # 实际数据加载和预处理步骤
# 这里我们用随机数据作为示例
np.random.seed(0)
X = np.random.rand(1000, 1000)  # 1000个样本,每个样本1000个特征
y = np.random.randint(0, 2, size=(1000, 1))  # 二分类标签

# 数据预处理,标准化特征
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
  • 首先,采用numpy的random随机生成X矩阵(1000行样本*1000行特征)和y矩阵(1000行0或1的label)
  • 其次,采用sklearn库中的StandardScaler将X矩阵中的每个样本特征数值标准化(将每个特征都转换为正态分布,均值为0,标准差为1),这一步骤对于机器学习算法的性能至关重要,特别是那些对输入数据的尺度敏感的算法。
  • 最后,按照2:8的比例从数据中切分出测试机与训练集

3.3 创建三层DNN模型

采用keras.sequential类,顾名思义“按顺序的”由输入至输出编排神经网络

# 创建模型
model = Sequential([
    Dense(512, input_shape=(X_train.shape[1],)),  # 第一层
    Activation('relu'),
    Dense(512),  # 第二层
    Activation('relu'),
    Dense(1),  # 输出层
    Activation('sigmoid')  # 二分类使用sigmoid
])

 Sequential是Keras中用于构建深度学习模型的一个类,特别适合于构建线性的堆叠层模型。这种模型结构是层与层直接相连,没有复杂的拓扑结构,适合于解决如图像分类、文本分类等任务

特点

  • 线性堆叠:层按照添加的顺序堆叠,每一层只与前一层有连接。
  • 易于使用:适合初学者和快速原型设计,对于复杂的网络结构可能不够灵活。
  • 灵活性限制:对于需要多输入或多输出,或者层间有复杂连接的模型,应使用更高级的模型结构,如Functional API。

3.4 编译模型、定义损失函数与优化器

不同于pytorch的实例化模型对象,这里采用compile对模型进行编译。与pytorch相同点是都要定义损失函数和优化器,方法与技巧完全相同。

# 编译模型
model.compile(optimizer=Adam(learning_rate=0.001),
              loss=BinaryCrossentropy(),
              metrics=['accuracy'])
  • optimizer=Adam(learning_rate=0.001):这里选择了Adam作为优化器。Adam(Adaptive Moment Estimation)是一种常用的优化算法,它结合了RMSprop和Momentum的优点,能够自动调整学习率。通过设置learning_rate=0.001,可以控制模型学习的速度。学习率是训练过程中的一个重要超参数,影响模型收敛的速度和最终的性能。
  • loss=BinaryCrossentropy():损失函数设置为二元交叉熵(Binary Crossentropy)。这个损失函数适用于二分类问题,它衡量了模型预测的概率分布与实际标签之间的差异。在二分类任务中,正确选择损失函数对于模型的性能至关重要。
  • metrics=['accuracy']:指定评估模型性能的指标。这里使用的是准确率(accuracy),即分类正确的比例。在训练和验证过程中,除了损失值外,还会计算并显示这个指标,帮助我们了解模型的性能。

3.5 启动训练,迭代收敛

不同于pytorch需要写两个循环处理每一行样本,tensorflow直接采用fit方法对输入的特征样本矩阵以及label矩阵进行训练

tensorflow版:

# 训练模型
history = model.fit(X_train, y_train, epochs=100, 
                    validation_split=0.1,  # 使用10%的数据作为验证集
                    verbose=1)

pytorch版:

# 训练循环
num_epochs = 10
for epoch in range(num_epochs):
    model.train()  # 设置为训练模式
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(data_loader, 0):
        optimizer.zero_grad()  # 清零梯度
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()  # 反向传播
        optimizer.step()  # 更新权重

        running_loss += loss.item()
    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(data_loader)}')

对比来看,pytorch版的更加透明,有助于理解,tensorflow更加便捷 

运行后可以看到loss逐步收敛:​

3.6 模型评估

通过model.evaluate对模型进行评估,evaluate与fit的区别是只计算指标不进行模型更新

tensorflow版:

# 评估模型
loss, accuracy = model.evaluate(X_test, y_test, verbose=0)
print(f'Test loss: {loss}, Test accuracy: {accuracy}')

 pytorch版:

import torchmetrics # 导入torchmetrics

test_num_samples = 200  # 测试样本数
test_X_train = torch.randn(test_num_samples, input_size) 
test_y_train = torch.randint(0, output_size, (test_num_samples,))

# 数据加载
test_dataset = TensorDataset(test_X_train,test_y_train)
test_data_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)

# 在模型训练完成后进行评估
# 首先,我们需要确保模型在评估模式下
model.eval()

# 初始化准确率和召回率的计算器
accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=output_size)
recall = torchmetrics.Recall(task="multiclass", num_classes=output_size)

with torch.no_grad():  # 确保在评估时不进行梯度计算
    for inputs, labels in test_data_loader:
        outputs = model(inputs)
        preds = torch.softmax(outputs, dim=1)
        # 更新指标计算器
        accuracy.update(preds, labels)
        recall.update(preds, labels)

# 打印准确率和召回率
print(f'Accuracy: {accuracy.compute():.4f}')
print(f'Recall: {recall.compute():.4f}')

print('Evaluation finished.')

对比pytorch需要写一个循环,tensorflow.keras的封装更为简洁

运行后,可以输出模型的准确率与召回率,由于采用随机生成的测试数据且迭代轮数较少,具体数值不错参考,可以根据自己需要丰富数据。

3.7 可以直接跑的代码 

与上一篇AI智能体研发之路-模型篇(四):一文入门pytorch开发一样,附可以直接运行的代码,先跑起来,再一行行研究!

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import numpy as np

# 假设你已经有了特征数据 X 和标签数据 y
# X, y = ...  # 实际数据加载和预处理步骤
# 这里我们用随机数据作为示例
np.random.seed(0)
X = np.random.rand(1000, 1000)  # 1000个样本,每个样本1000个特征
y = np.random.randint(0, 2, size=(1000, 1))  # 二分类标签

# 数据预处理,标准化特征
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)

# 创建模型
model = Sequential([
    Dense(512, input_shape=(X_train.shape[1],)),  # 第一层
    Activation('relu'),
    Dense(512),  # 第二层
    Activation('relu'),
    Dense(1),  # 输出层
    Activation('sigmoid')  # 二分类使用sigmoid
])

# 编译模型
model.compile(optimizer=Adam(learning_rate=0.001),
              loss=BinaryCrossentropy(),
              metrics=['accuracy'])

# 训练模型
history = model.fit(X_train, y_train, epochs=10, 
                    validation_split=0.1,  # 使用10%的数据作为验证集
                    verbose=1)

# 评估模型
loss, accuracy = model.evaluate(X_test, y_test, verbose=0)
print(f'Test loss: {loss}, Test accuracy: {accuracy}')

四、总结

本文先对tensorflow深度学习框架历史、特点及安装方法进行介绍,接下来基于tensorflow带读者一步步开发一个简单的三层神经网络程序,最后附可执行的代码供读者进行测试学习。个人感觉tensorflow封装程度高于pytorch,网络结构也更加清晰,但pytorch更加透明。

喜欢的话期待您的关注、点赞、收藏,您的互动是对我最大的鼓励!

如果还有时间,可以看看我的其他文章:

《AI—工程篇》

AI智能体研发之路-工程篇(一):Docker助力AI智能体开发提效

AI智能体研发之路-工程篇(二):Dify智能体开发平台一键部署

AI智能体研发之路-工程篇(三):大模型推理服务框架Ollama一键部署

AI智能体研发之路-工程篇(四):大模型推理服务框架Xinference一键部署

AI智能体研发之路-工程篇(五):大模型推理服务框架LocalAI一键部署

《AI—模型篇》

AI智能体研发之路-模型篇(一):大模型训练框架LLaMA-Factory在国内网络环境下的安装、部署及使用

AI智能体研发之路-模型篇(二):DeepSeek-V2-Chat 训练与推理实战

AI智能体研发之路-模型篇(三):中文大模型开、闭源之争

AI智能体研发之路-模型篇(四):一文入门pytorch开发

AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/659729.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Facebook:社交世界的接口

在当今数字时代,社交媒体已经成为了人们生活中不可或缺的一部分,而Facebook作为其中的巨头之一,扮演着至关重要的角色。本文将带您深入探索Facebook这张社交世界的画卷,全面了解这个令人着迷的平台。 起源与历程 Facebook的故事始…

揭开神秘的“位移主题”面纱 no.16

Kafka中神秘的内部主题(Internal Topic)__consumer_offsets。 consumer_offsets在Kafka源码中有个更为正式的名字,叫*位移主题*,即Offsets Topic。为了方便今天的讨论,我将统一使用位移主题来指代consumer_offsets。需…

机器视觉分析在加油站安全中的应用:使用手机检测、打电话行为识别

在加油站等高危场所,禁止使用手机是为了防止潜在的火灾和爆炸风险。手机在使用过程中可能产生电火花,而在加油站这种易燃易爆环境中,任何电火花都可能引发严重的安全事故。因此,加油站禁止使用手机是保障安全生产的重要措施。基于…

如何让Google快速收录?

要让Google快速收录你的网站,可以考虑使用GSI服务,这是一种专门设计来加速网站被Google搜索引擎收录的服务,下面详细解释GSI服务的基本原理和具体好处: GSI服务通过一种名为GPC爬虫池的系统实现,这个系统是基于对Goog…

SQL注入攻击是什么?如何预防?

一、SQL注入攻击是什么? SQL注入攻击是一种利用Web应用程序中的安全漏洞,将恶意的SQL代码插入到数据库查询中的攻击方式。攻击者通过在Web应用程序的输入字段中插入恶意的SQL代码,然后在后台的数据库服务器上解析执行这些代码,从而…

对比方案:5款知识中台工具的优缺点详解

知识中台工具为企业和组织高效地组织、存储和分享知识,还能提升团队协作的效率。在选择搭建知识中台的工具时,了解工具的优缺点,有助于企业做出最佳决策。本文LookLook同学将对五款搭建知识中台的工具进行优缺点的简单介绍,帮助企…

单细胞分析(Signac): PBMC scATAC-seq 基因组区域可视化

引言 在本教学指南中,我们将探讨由10x Genomics公司提供的人类外周血单核细胞(PBMCs)的单细胞ATAC-seq数据集。 加载包 首先加载 Signac、Seurat 和我们将用于分析人类数据的其他一些包。 if (!requireNamespace("EnsDb.Hsapiens.v75&qu…

虹科Pico汽车示波器 | 免拆诊断案例 | 2017款吉利帝豪GL车发动机偶尔无法起动

故障现象  一辆2017款吉利帝豪GL车,搭载JLC-4G18发动机和手动变速器,累计行驶里程约为39.3万km。车主反映,该车发动机偶尔无法起动。故障发生频率比较频繁,冷机状态下故障比较容易出现。 故障诊断  接车后试车,故…

1Panel开源面板全平台下载总量突破500,000次!

截至2024年5月22日,FIT2CLOUD飞致云旗下开源项目——1Panel开源Linux服务器运维管理面板全平台下载总量突破500,000次!

Java培训后找不到工作,现在去培训嵌入式可行吗?

最近java 工作还是比较好找,不知道你是对薪资要求太高,还是因为其他原因,如果你真的面试了很多都还找不到工作,那么一定要知道找不到工作的原因是啥,一定不是因为java 太卷,你说那个行业,那个职…

socks5 如何让dns不被污染

问题 发现firefox浏览器代理设置成socks5后,查看ip是成功了,但是谷歌等海外的还是无法正常访问。 原因 主要原因是socks5连接虽然是成功了,但是dns还是走国内的,国内的dns解析都被污染了导致没法正常访问 解决 把设置里的 使…

工程文档CAD转换必备!快速将 DWG 转换到 PNG ~

Aspose.CAD 是一个独立的类库,以加强Java应用程序处理和渲染CAD图纸,而不需要AutoCAD或任何其他渲染工作流程。该CAD类库允许将DWG, DWT, DWF, DWFX, IFC, PLT, DGN, OBJ, STL, IGES, CFF2文件、布局和图层高质量地转换为PDF和光栅图像格式。 Aspose AP…

springboot学生就业信息管理系统-计算机毕业设计源码95340

摘 要 信息化社会内需要与之针对性的信息获取途径,但是途径的扩展基本上为人们所努力的方向,由于站在的角度存在偏差,人们经常能够获得不同类型信息,这也是技术最为难以攻克的课题。针对学生就业信息管理系统等问题,对…

Spire.PDF for .NET【文档操作】演示:将PDF 拆分为多个 PDF

Spire.PDF 完美支持将多页 PDF 拆分为单页。但是,更常见的情况是,您可能希望提取选定的页面范围并保存为新的 PDF 文档。在本文中,您将学习如何通过 Spire.PDF 在 C#、VB.NET 中根据页面范围拆分 PDF 文件。 Spire.PDF for .NET 是一款独立 …

Python-3.12.0文档解读-内置函数pow()详细说明+记忆策略+常用场景+巧妙用法+综合技巧

一个认为一切根源都是“自己不够强”的INTJ 个人主页:用哲学编程-CSDN博客专栏:每日一题——举一反三Python编程学习Python内置函数 Python-3.12.0文档解读 目录 详细说明 功能描述 参数 返回值 使用规则 示例代码 基本使用 模运算 变动记录…

CAD石墨烯生成器 V1.0 渊鱼

插件介绍 CAD石墨烯生成器插件可用于在AutoCAD软件内参数化建立石墨烯几何模型。插件建立石墨烯的球棍模型,可控制模型的尺寸、碳原子环的尺寸、原子直径、化学键直径,并可控制模型的起伏形态。插件生成的实体模型可进行修改或绘图渲染,用于…

Java多线程(02)—— 线程等待,线程安全

一、如何终止线程 终止线程就是要让 run 方法尽快执行结束 1. 手动创建标志位 可以通过在代码中手动创建标志位的方式,来作为 run 方法的执行结束条件; public static void main(String[] args) throws InterruptedException {boolean flag true;Thr…

快速下载极客时间课程

仅供学习,切勿商用 1. 下载 下载geektime-downloader,安装到指定文件夹,注意路径尽量不要出现汉字 不想去github上下载的可以直接下载文章顶部的软件安装包。 2. 执行命令 在安装geektime-downloader目录下,点击鼠标右键&…

C++之运算符重载

1、运算符重载 //Complex.h #ifndef _COMPLEX_H_ #define _COMPLEX_H_class Complex { public:Complex(int real_, int imag_);Complex();~Complex();Complex& Add(const Complex& other); void Display() const;Complex operator(const Complex& other);privat…

Linux中常见的基本指令(上)

目录 一、ls指令 1. ls 2. ls -l 3. ls -a 4.ls -F 二、qwd指令 三、cd指令 1. cd .. 2. cd / / / 3. cd ../ / / 4. cd ~ 5. cd - 五、mkdir指令 六、rmdir指令和rm指令 一、ls指令 语法 : ls [ 选项 ][ 目录或文件 ] 。 功能 :对于目录…