深度学习之加宽全连接

1.Functional API 搭建神经网络模型

1.1.利用Functional API编写宽深神经网络模型进行手写数字识别

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()

x_train,x_test,y_train,y_test = train_test_split(iris.data,iris.target,test_size=0.2,random_state=23)

X_train,X_valid,y_train,y_valid = train_test_split(x_train,y_train,test_size=0.2,random_state=12)

import tensorflow as tf
from tensorflow import keras

#创建了一个输入层inputs,其形状与X_train的形状(除了第一个维度,即样本数量)相同。
inputs = keras.layers.Input(shape=X_train.shape[1:])
#添加了一个隐藏层hidden1,它包含300个神经元,并使用ReLU激活函数。
hidden1 = keras.layers.Dense(300,activation='relu')(inputs)
hidden2 = keras.layers.Dense(100,activation='relu')(hidden1)
concat = keras.layers.concatenate([inputs,hidden2])
output = keras.layers.Dense(10,activation='softmax')(concat)

model_fun_WideDeep = keras.models.Model(inputs=[inputs],outputs=[output])

model_fun_WideDeep.summary()  #观察神经网络的整体情况

运行结果:
在这里插入图片描述
:在这个模型中,inputs层代表了宽度的特征,而hidden1和hidden2层代表了深度的特征。通过将输入层和第二个隐藏层的输出拼接在一起,模型结合了宽度和深度的特征,以提高模型的预测能力。

#利用.compile().fit()函数对数据进行编译与训练
model_fun_WideDeep.compile(loss='sparse_categorical_crossentropy',optimizer='sgd',metrics=["accuracy"])
h=model_fun_WideDeep.fit(X_train,y_train,batch_size=32,epochs=30,validation_data=(X_valid,y_valid))

运行结果:
在这里插入图片描述
:optimizer=‘sgd’:这是模型训练时使用的优化器。sgd代表随机梯度下降(Stochastic Gradient Descent),它是一种简单的优化算法,用于在训练过程中更新模型的权重。

1.2.利用Functional API编写多输入神经网络模型进行手写数字识别

X_train_A, X_train_B = X_train[:, :200], X_train[:, 100:]
X_valid_A, X_valid_B = X_valid[:, :200], X_valid[:, 100:]

# 定义输入层
input_A = keras.layers.Input(shape=X_train_A.shape[1])
input_B = keras.layers.Input(shape=X_train_B.shape[1])

# 定义隐藏层和输出层
hidden1 = keras.layers.Dense(300, activation='relu')(input_B)
hidden2 = keras.layers.Dense(100, activation='relu')(hidden1)
concat = keras.layers.concatenate([input_A, hidden2])
output = keras.layers.Dense(10, activation='softmax')(concat)

# 创建模型
model_fun_MulIn = keras.models.Model(inputs=[input_A, input_B], outputs=[output])

# 编译模型
model_fun_MulIn.compile(loss='sparse_categorical_crossentropy',optimizer='sgd',metrics=["accuracy"])

# 训练模型
h=model_fun_MulIn.fit((X_train_A,X_train_B),y_train,batch_size=32,epochs=10, validation_data=((X_valid_A, X_valid_B), y_valid))

运行结果:
在这里插入图片描述
:这段代码将训练数据集X_train和验证数据集X_valid分割成两个子集,分别是X_train_A和X_train_B以及X_valid_A和X_valid_B。每个子集都包含原始数据集的一部分特征。

pd.DataFrame(h.history).plot(figsize=(8, 5))
plt.grid(True)
plt.gca().set_ylim(0, 1)
plt.show()

运行结果:
在这里插入图片描述

2.SubClassing API 搭建神经网络模型

2.1.构建继承keras.model.Model的子类来搭建神经网络模型

class Model_sub_fnn(keras.models.Model):
    def __init__(self, units_1=300, units_2=100, units_out=10, activation='relu'):
        super().__init__()
        self.hidden1 = keras.layers.Dense(units_1, activation=activation)
        self.hidden2 = keras.layers.Dense(units_2, activation=activation)
        self.main_output = keras.layers.Dense(units_out, activation='softmax')
    def call(self, data):
        hidden1 = self.hidden1(data)
        hidden2 = self.hidden2(hidden1)
        main_output = self.main_output(hidden2)
        return main_output

#初始化模型
model_sub_fnn = Model_sub_fnn()

#通过在初始化中传递参数改变模型元素默认值
model_sub_fnn2 = Model_sub_fnn(300, 100, 10, activation='relu')

#编译模型与训练
model_sub_fnn.compile(loss='sparse_categorical_crossentropy',optimizer='sgd',metrics=["accuracy"])
h= model_sub_fnn.fit(X_train,y_train,batch_size=32,epochs=30,validation_data = (X_valid,y_valid))

model_sub_fnn.summary()

运行结果:
在这里插入图片描述
在这里插入图片描述

pd.DataFrame(h.history).plot(figsize=(8, 5))
plt.grid(True)
plt.gca().set_ylim(0, 1)
plt.show()

运行结果:
在这里插入图片描述
:在这段代码中,我们定义了一个名为Model_sub_fnn的Keras模型类,并在之后实例化了两个模型实例,分别命名为model_sub_fnn和model_sub_fnn2。这两个模型实例的参数值有所不同,这可以通过在实例化时传递不同的参数值来实现。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/672408.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【免费Web系列】JavaWeb实战项目案例三

这是Web第一天的课程大家可以传送过去学习 http://t.csdnimg.cn/K547r 部门管理开发 1. 删除部门 1.1 需求分析 删除部门数据。在点击 "删除" 按钮,会根据ID删除部门数据。 了解了需求之后,我们再看看接口文档中,关于删除部门…

还没搞懂作用域、执行上下文、变量提升?看这篇就够啦

前言 📫 大家好,我是南木元元,热爱技术和分享,欢迎大家交流,一起学习进步! 🍅 个人主页:南木元元 目录 作用域(Scope) 全局作用域 函数作用域 块级作用域…

编译选项导致的结构体字节参数异常

文章目录 前言问题描述原因分析问题解决总结 前言 在构建编译工程时,会有一些对应的编译配置选项,不同的编译器,会有对应的配置项。本文介绍GHS工程中编译选项配置不对应导致的异常。 问题描述 在S32K3集成工程中,核1的INP_SWC…

【并发程序设计】15.信号灯(信号量)

15.信号灯(信号量) Linux中的信号灯即信号量是一种用于进程间同步或互斥的机制,它主要用于控制对共享资源的访问。 在Linux系统中,信号灯作为一种进程间通信(IPC)的方式,与其他如管道、FIFO或共享内存等IPC方式不同&…

c++ 哈希 unordered_map unordered_set 的学习

1. unordered 系列 在 c98 中, STL 提供了底层是红黑树结构的一系列关联式容器,set 和 map 的查询效率可以达到 log2N,红黑树最差的情况也只是需要比较红黑树的高度次,当节点数量非常多时,查找一个节点还需要比较几十…

护肤品美妆商城小程序的作用是什么

经营美妆的方式多种多样,商场街边、电商平台、微商等,无论厂商品牌还是经销商批发零售都有大量目标群体,客户在哪里商家就应该在哪里,私域生意模式,商家需要线上多渠道获客转化和提高营收。 运用【雨科】平台搭建护肤…

PatchEmbed

PatchEmbed 是用于计算机视觉任务的神经网络层,特别是在Vision Transformer (ViT) 模型中使用。它负责将输入的图像分割成固定大小的图像块(patches),并将这些图像块线性嵌入到高维空间中。这是Vision Transformer处理图像的方式&…

JVM虚拟机性能监控工具

命令行工具 jps 虚拟机进程状况查询工具 jps(JVM Process Status Tool),可以列出正在运行的虚拟机进程,并显示虚拟机执行主类名称或者jar文件名,还有这些进程的本地虚拟机唯一ID(LVMID,Local Virtual Machine Identifier)。 # …

Vue.js 与 TypeScript(1) :项目配置、props标注类型、emits标注类型

像 TypeScript 这样的类型系统可以在编译时通过静态分析检测出很多常见错误。这减少了生产环境中的运行时错误,也让我们在重构大型项目的时候更有信心。通过 IDE 中基于类型的自动补全,TypeScript 还改善了开发体验和效率。 一、项目配置 在使用 npm cr…

AI 网页解锁器,用于网页抓取一切 | 最快的验证码解决服务

想象一下,解锁互联网的全部潜力,数据自由流动,没有任何障碍阻挡你获取所需信息。在网络爬虫的世界里,这个梦想常常会遇到障碍:CAPTCHA和反机器人措施,这些措施旨在保护网站免受自动化访问的侵害。但如果有一…

蓝桥杯软件测试-十五届模拟赛2期题目解析

十五届蓝桥杯《软件测试》模拟赛2期题目解析 PS 需要第十五界蓝桥杯模拟赛2期功能测试模板、单元测试被测代码、自动化测试被测代码请加🐧:1940787338 备注:15界蓝桥杯省赛软件测试模拟赛2期 题目1:功能测试题目 1(测试用例&…

60 关于 SegmentFault 的一些场景 (1)

前言 呵呵 此问题主要是来自于 帖子 月经结贴 -- 《Segmentation Fault in Linux》 这里主要也是 结合了作者的相关 case, 来做的一些 调试分享 当然 很多的情况还是 蛮有意思 本文主要问题如下 1. 访问可执行文件中的 只读数据 2. 访问不存在的虚拟地址 3. 访问内核地址…

【机器学习】基于OpenCV和TensorFlow的MobileNetV2模型的物种识别与个体相似度分析

在计算机视觉领域,物种识别和图像相似度比较是两个重要的研究方向。本文通过结合深度学习和图像处理技术,基于OpenCV和TensorFlow的MobileNetV2的预训练模型模,实现物种识别和个体相似度分析。本文详细介绍该实验过程并提供相关代码。 一、名…

Python代码:二十六、反转列表

1、题目 描述 小明有一个列表记录了各个朋友的喜欢的数字,num [3, 5, 9, 0, 1, 9, 0, 3],请你帮他创建列表,然后使用reverse函数将列表反转输出。 输入描述: 无 输出描述: 第一行输出创建好的原始的列表&#x…

typescript --object对象类型

ts中的object const obj new Object()Object 这里的Object是Object类型,而不是JavaScript内置的Object构造函数。 这里的Object是一种类型,而Object()构造函数表示一个值。 Object()构造函数的ts代码 interface ObjectConstructor{readonly prototyp…

【JavaEE】JVM中垃圾回收机制详解

一.垃圾回收的基本概念 1.什么是垃圾回收机制. JVM(Java虚拟机)垃圾回收机制是Java内存管理的重要组成部分,它负责自动回收程序中不再使用的对象所占用的内存空间。这样可以有效地防止内存泄漏和内存溢出问题,提高程序的稳定性和…

电脑死机问题排查

情况描述:2024年6月2日下午16:04分电脑突然花屏死机,此情况之前遇到过三次,认为是腾讯会议录屏和系统自带录屏软件冲突导致。 报错信息:应用程序-特定 权限设置并未向在应用程序容器 不可用 SID (不可用)中运行的地址…

GPT-4o有点坑

GPT-4o有点坑 0. 前言1. GPT-4o简介2. GPT-4o带来的好处2.1 可以上传图片和文件2.2 更丰富的功能以及插件 3. "坑"的地方3.1 使用时间短3.2 GPT-4o变懒了 4. 总结 0. 前言 原本不想对GPT-4o的内容来进行评论的,但是看了相关的评论一直在说:技…

全国水系数据(更新到2024年5月)

上海市水系数据地图可视化 水系数据线图层(小河/溪流、江/河、运河、下水道/排水管) 水系数据面数据(水域、水库、河岸、湿地) 水系数据字段说明 可视化预览 北京市水系可视化 上海市水系可视化 广州市水系可视化 深圳市水系可视化…

Gin的快速入门和搭建

文章目录 Go的工程工程架构技术选型 Gin入门 Go的工程 基于Go生态,构建一个支持内容管理,内容加工、内容分发的内容库系统。 内容管理:增删改查内容加工:例如内容审核、推荐等内容分发:将内容可以推到不同的业务线 …