深度学习(三)

5.Functional API 搭建神经网络模型
5.1利用Functional API编写宽深神经网络模型进行手写数字识别
import numpy as np

import pandas as pd

import matplotlib.pyplot as plt

from sklearn.datasets import load_iris

from sklearn.model_selection import train_test_split

from tensorflow.keras.layers import Input, Dense, concatenate

from tensorflow.keras.models import Model



iris = load_iris()



x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=23)

X_train, X_valid, y_train, y_valid = train_test_split(x_train, y_train, test_size=0.2, random_state=12)



print(X_valid.shape)

print(X_train.shape)



inputs = Input(shape=X_train.shape[1:])

hidden1 = Dense(300, activation="relu")(inputs)

hidden2 = Dense(100, activation="relu")(hidden1)

concat = concatenate([inputs, hidden2])

output = Dense(10, activation="softmax")(concat)

model_wide_deep = Model(inputs=inputs, outputs=output)

iris = load_iris():加载iris数据集,这是一个常用的多类分类数据集,包含了150个样本,每个样本有4个特征,属于3个不同的类别。

x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=23):将iris数据集分割为训练集和测试集。test_size=0.2表示测试集的大小为原始数据的20%,random_state=23是一个随机种子,确保分割的可重复性。

X_train, X_valid, y_train, y_valid = train_test_split(x_train, y_train, test_size=0.2, random_state=12):进一步将训练集分割为训练集和验证集。同样,test_size=0.2表示验证集的大小为分割后训练数据的20%,random_state=12确保分割的可重复性。

print(X_valid.shape):打印验证集的特征数据的形状。

print(X_train.shape):打印新的训练集的特征数据的形状。

inputs = Input(shape=X_train.shape[1:]):定义模型的输入层,shape=X_train.shape[1:]指定输入的形状,由于X_train是一个二维数组,shape[1:]表示除了第一维(样本数量)之外的所有维度。

hidden1 = Dense(300, activation="relu")(inputs):定义第一个隐藏层,它有300个神经元,并使用ReLU激活函数。

hidden2 = Dense(100, activation="relu")(hidden1):定义第二个隐藏层,它有100个神经元,并使用ReLU激活函数。

concat = concatenate([inputs, hidden2]):将输入层和第二个隐藏层的输出拼接起来,形成更宽的网络。

output = Dense(10, activation="softmax")(concat):定义输出层,它有10个神经元(对应于3个类别和一个额外的神经元,这是常见的做法),并使用softmax激活函数输出概率分布。

model_wide_deep = Model(inputs=inputs, outputs=output):创建一个Keras模型,将输入层和输出层连接起来。

使用scikit-learn库中的load_iris函数来加载iris数据集,然后使用train_test_split函数将数据集分割为训练集和测试集,以及进一步的训练集和验证集。接着,它定义了一个宽深网络(wide and deep network)模型,其中包含了输入层、两个隐藏层和一个输出层。

model_wide_deep.summary()

model_wide_deep.compile(loss="sparse_categorical_crossentropy",optimizer="sgd",metrics=["accuracy"])

h = model_wide_deep.fit(X_train, y_train, batch_size=32, epochs=30,validation_data=(X_valid, y_valid))

# 绘图

pd.DataFrame(h.history).plot(figsize=(8,5))

plt.grid(True)

plt.gca().set_ylim(0, 1)

plt.show()

# 使用 model_wide_deep 评估测试集

test_loss, test_accuracy = model_wide_deep.evaluate(x_test, y_test, batch_size=32)



print(f"Test Loss: {test_loss}")

print(f"Test Accuracy: {test_accuracy}")

6.SubClassing API 搭建神经网络模型

以前馈全连接神经网络手写数字识别为例

import numpy as np

import pandas as pd

import matplotlib.pyplot as plt

from sklearn.datasets import load_iris

from sklearn.model_selection import train_test_split

from tensorflow.keras.layers import Input, Dense, concatenate

from tensorflow.keras.models import Model

from tensorflow.keras import backend as K



# 加载数据集

iris = load_iris()

X = iris.data

y = iris.target



# 分割数据集

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=23)

X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, test_size=0.2, random_state=12)



# 打印验证集和训练集的形状

print(X_valid.shape)

print(X_train.shape)



# 定义 Model_sub_fn 类

class Model_sub_fn(Model):

    def __init__(self, units_1, units_2, units_out, activation="relu"):

        super(Model_sub_fn, self).__init__()

        self.hidden1 = Dense(units_1, activation=activation)

        self.hidden2 = Dense(units_2, activation=activation)

        self.main_output = Dense(units_out, activation="softmax")



    def call(self, inputs):

        x = self.hidden1(inputs)

        x = self.hidden2(x)

        return self.main_output(x)

定义了一个名为Model_sub_fn的类,该类继承自tensorflow.keras.Model。这个类用于创建一个简单的神经网络模型,它包含两个隐藏层和一个输出层。

class Model_sub_fn(Model)定义一个名为Model_sub_fn的类,它继承自tensorflow.keras.Model。这意味着Model_sub_fn类可以访问和继承Model类的所有属性和方法。

def __init__(self, units_1, units_2, units_out, activation="relu"):定义类的构造函数__init__,它接受四个参数:units_1(第一个隐藏层的神经元数量)、units_2(第二个隐藏层的神经元数量)、units_out(输出层的神经元数量)和activation(激活函数类型,默认为ReLU)。

super(Model_sub_fn, self).__init__():调用父类的构造函数,这是继承自Model类的标准做法。

self.hidden1 = Dense(units_1, activation=activation):定义第一个隐藏层,它有units_1个神经元,并使用activation作为激活函数。

self.hidden2 = Dense(units_2, activation=activation):定义第二个隐藏层,它有units_2个神经元,并使用activation作为激活函数。

self.main_output = Dense(units_out, activation="softmax"):定义输出层,它有units_out个神经元,并使用softmax作为激活函数。

def call(self, inputs):定义call方法,这是所有Keras模型必须定义的方法,它用于前向传播。在这个方法中,输入数据通过两个隐藏层,最后通过输出层。

x = self.hidden1(inputs):将输入数据通过第一个隐藏层。

x = self.hidden2(x):将第一个隐藏层的输出通过第二个隐藏层。

return self.main_output(x):将第二个隐藏层的输出通过输出层,并返回结果。

model_sub_fn = Model_sub_fn(units_1=64, units_2=32, units_out=3)



# 创建 Model_sub_fn 实例

model_sub_fn = Model_sub_fn(300, 100, 3, activation="relu")  # 假设输出层有3个单元,因为Iris数据集有3个类别



# 编译模型

model_sub_fn.compile(loss="sparse_categorical_crossentropy",optimizer="sgd",metrics=["accuracy"])



# 训练模型

history = model_sub_fn.fit(X_train, y_train, batch_size=32, epochs=30, validation_data=(X_valid, y_valid))

model_sub_fn.summary()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/687170.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SVNCloud 与 Navicat和IDEA的连接

文章目录 SVNCloud 配置Navicat访问云端数据库与IDEA Java jdbc 的连接 SVNCloud 配置 访问网址:SVN注册账号,进入mysql区域: 数据库管理->创建数据库,输入数据库名称和密码,注意,这里的数据库名称实际…

vue 如何制作一个跟随窗口大小变化而变化的组件

vue 如何制作一个跟随窗口大小变化而变化的组件 像下图中展示的那些统计数件就是跟随窗口变化而变化的,而且是几乎等比缩放的。 实现原理 只简略说一下原理。 pinia 中记录一个窗口变化的高度值给要变化的组件添加一个高度值组件内部所有关于长度距离的值都通过这…

笔记 | 软件工程04:软件项目管理

1 软件项目及其特点 1.1 什么是项目 1.2 项目特点 1.3 影响项目成功的因素 1.4 什么是软件项目 针对软件这一特定产品和服务的项目努力开展“软件开发活动",(理解:软件项目是一种活动) 1.5 软件项目的特点 1.6 军用软件项目的特点 2 …

水库安全监测系统:智慧水文动态监测系统

TH-SW2水库安全监测系统,作为一款智慧水文动态监测系统,其在现代水利管理中扮演着至关重要的角色。该系统通过集成先进的数据采集、传输、处理和分析技术,为水库的安全运行提供了强有力的技术支撑。 水库安全监测系统是一种用于实时监测和记…

matplotlib绘制三维曲面图时遇到的问题及解决方法

在使用 Matplotlib 绘制三维曲面图时,可能会遇到一些常见的问题。今天我将全程详细讲解下遇到问题并且找到应对方法的全部过程,希望能帮助大家。 1、问题背景 在使用 matplotlib 绘制三维曲面图时,遇到了一个问题。代码如下: im…

Faiss框架使用与FaissRetriever实现

Faiss是一个由Facebook AI Research开发的库,用于高效相似性搜索和稠密向量聚类。它为机器学习和深度学习中的向量检索问题提供了一种高效的解决方案,特别是在处理大规模数据集时。Faiss支持多种索引类型,包括基于量化的索引、基于聚类的索引…

Ubuntu系统的k8s常见的错误和解决的问题

K8s配置的时候出现的常见问题 Q1: master节点kubectl get nodes 出现的错误 或者 解决方法&#xff1a; cat <<EOF >> /root/.bashrc export KUBECONFIG/etc/kubernetes/admin.conf EOFsource /root/.bashrc重新执行 kubectl get nodes 记得需要查看一下自己的…

倒计时 3 天!立即预约苹果 WWDC24 直播;RLAIF-V 大规模多模态偏好数据集上线,有效减少不同 MLLMs 幻觉现象

6 月 3 日-6 月 7 日&#xff0c;hyper.ai 官网更新速览&#xff1a; 优质公共数据集&#xff1a;10 个 优质教程精选&#xff1a;2 个 社区文章精选&#xff1a;3 篇 热门百科词条&#xff1a;5 条 6-7 月截稿顶会&#xff1a;5 个 访问官网&#xff1a;hyper.ai 公共数…

37. 【Java教程】序列化与反序列化

上一小节我们学习了 Java 的输入输出流&#xff0c;有了这些前置知识点&#xff0c;我们就可以学习 Java 的序列化了。本小节将介绍什么是序列化、什么是反序列化、序列化有什么作用&#xff0c;如何实现序列化与反序列化&#xff0c;Serializable 接口介绍&#xff0c;常用序列…

【JavaEE精炼宝库】多线程(4)深度理解死锁、内存可见性、volatile关键字、wait、notify

目录 一、死锁 1.1 出现死锁的常见场景&#xff1a; 1.2 产生死锁的后果&#xff1a; 1.3 如何避免死锁&#xff1a; 二、内存可见性 2.1 由内存可见性产生的经典案例&#xff1a; 2.2 volatile 关键字&#xff1a; 2.2.1 volatile 用法&#xff1a; 2.2.2 volatile 不…

C++中的stack和queue

C中的stack和queue 一丶stack1. stack的介绍2. stack的使用3. stack的模拟实现 二丶queue1. queue的介绍2. queue的使用3. queue的模拟实现 一丶stack 1. stack的介绍 stack的文档介绍 关于stack&#xff1a; 1. stack是一种容器适配器&#xff0c;专门用在具有后进先出操作的…

ROS socketcan_bridge使用说明

ROS socketcan_bridge使用说明&#xff08;以ubuntu20.04为例&#xff09; socketcan_bridge是什么 ROS针对socketcan提供了三个层次的驱动库&#xff0c;分别是ros_canopen&#xff0c;socketcan_bridge和socketcan_interface。 socketcan_interface&#xff1a; 功能&#x…

20240607在Toybrick的TB-RK3588开发板的Android12下适配IMX415摄像头和ov50c40

20240607在Toybrick的TB-RK3588开发板的Android12下适配IMX415摄像头和ov50c40 2024/6/7 11:42 【4K/8K摄像头发热量巨大&#xff0c;请做好散热措施&#xff0c;最好使用散热片鼓风机模式&#xff01;】 结论&#xff1a;欢迎您入坑。 Toybrick的TB-RK3588开发板的技术支持不…

STM32—按键控制LED(定时器)

目录 1 、 电路构成及原理图 2 、编写实现代码 main.c exit.c 3、代码讲解 4、烧录到开发板调试、验证代码 5、检验效果 此笔记基于朗峰 STM32F103 系列全集成开发板的记录。 1 、 电路构成及原理图 EXTI&#xff08;External interrupt/event controller&#xff…

机器视觉——物块分拣

项目进行到第四天&#xff0c;我们学到了很多&#xff0c;可以进行实操。 首先我们利用相机软件进行采图 然后导入代码里面 完整代码 dev_get_window (WindowHandle) list_image_files (采图, default, [], ImageFiles) for Index : 0 to |ImageFiles| - 1 by 1read_image (Im…

上BFT,是你的首选

上BFT&#xff0c;是你的首选 如果你想要找最智能的机器人&#xff0c;想要找品牌最全或者想要咨询专业的解决方案&#xff0c;一定不要错过BFT机器人采购站。BFT致力于为广大用户提供品质卓越、技术先进的机器人产品。 BFT里面机器人多种多样&#xff0c;不管您是想要工业机器…

取证工作: SysTools MailXaminer, 用强大功能辅助电子邮件调查工作的每一步

天津鸿萌科贸发展有限公司是 SysTools 系列软件的授权代理商。 SysTools MailXaminer 电子邮件取证软件将调查工作分为五个阶段&#xff1a;邮件加载、预览、搜索、分析及导出。软件对调查工作的每一阶段都提供了现代高级功能&#xff0c;以帮助数字取证调查员根据其具体要求对…

知乎知+广告推广开户充值的返点政策是怎样?

如何让您的品牌精准触达目标受众&#xff0c;实现高效传播与转化&#xff0c;成为了每一位市场人面临的挑战。为此&#xff0c;云衔科技作为业界领先的数字营销解决方案提供商&#xff0c;正式宣布全面支持知乎知广告开户及一站式代运营服务&#xff0c;旨在帮助各行业客户在知…

珠海鸿瑞毛利率持续下滑:核心产品销量大降,偿债能力偏弱

《港湾商业观察》黄懿 日前&#xff0c;珠海市鸿瑞信息技术股份有限公司&#xff08;下称“珠海鸿瑞”&#xff09;收到了北京证券交易所发出的第三轮审核问询函。 此前&#xff0c;2020年11月&#xff0c;珠海鸿瑞曾向深交所报送上市申请。IPO申请文件获受理后&#xff0c;珠…

用互斥锁解决缓存击穿

我先说一下正常的业务流程&#xff1a;需要查询店铺数据&#xff0c;我们会先从redis中查询&#xff0c;判断是否能命中&#xff0c;若命中说明redis中有需要的数据就直接返回&#xff1b;没有命中就需要去mysql数据库查询&#xff0c;在数据库中查到了就返回数据并把该数据存入…
最新文章