搭建一个简单的深度神经网络

目录

一、引入所需要的库

二、制作数据集

三、搭建神经网络

四、训练网络

五、测试网络

本博客实验环境为jupyter

一、引入所需要的库

torch库是核心,其中torch.nn 提供了搭建网络所需的所有组件,nn即神经网络。matplotlib类似与matlab,其中pyplot用于进行数据可视化,如绘制图表、曲线等。%matplotlib inline: 这是IPython(Jupyter Notebook)的魔法命令,用于在Notebook中直接显示Matplotlib绘制的图表,而不是弹出一个新窗口显示。

import torch 
import torch.nn as nn 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
%matplotlib inline

# 展示高清图 
from matplotlib_inline import backend_inline #导入Matplotlib库中的backend_inline模块,用于控制图表的显示方式。
backend_inline.set_matplotlib_formats('svg') #设置Matplotlib图表的显示格式为SVG格式,SVG格式的图表在显示时具有高清晰度,适合用于展示精细的图形。

二、制作数据集

主要任务是读取数据集,划分为训练集和测试集,一定要随机划分。

读取的数据集中共760组数据,共8个输入特征,1个输出特征。

其中第一列是索引,从0开始,70%为训练集,30%为测试集。

#读取数据
df = pd.read_csv('Data.csv', index_col=0)#之前的pandas库中有介绍到,即df为读取后的对象,以第一列为索引    
arr = df.values #转化为numpy数组              
arr = arr.astype(np.float32)#转化为深度学习常用的单精度浮点类型    
ts = torch.tensor(arr)#转化为张量tensor         
ts = ts.to('cuda')#送到cuda设备上即gpu上计算             

# 划分训练集与测试集 
train_size = int(len(ts) * 0.7) #训练集的大小为百分之七十          
test_size = len(ts) - train_size #测试集的大小为百分之三十         
ts = ts[ torch.randperm( ts.size(0) ) , : ] #随机打乱数据集中样本的顺序    
train_Data = ts[ : train_size , : ] #将前百分之七十行给训练集       
test_Data = ts[ train_size : , : ]  #将百分之七十后的行给测试集        

三、搭建神经网络

主要是构建DNN类,需要对python的类定义有较为深入的理解能力。

class DNN(nn.Module): #定义了一个名为DNN的PyTorch模型类,该类继承自nn.Module类,表示这是一个神经网络模型。
 
    def __init__(self): #定义了模型的初始化方法
        ''' 搭建神经网络各层 ''' 
        super(DNN,self).__init__() #调用父类的初始化方法,确保模型的其他部分也能够被正确初始化。
        self.net = nn.Sequential(            # 按顺序搭建各层 
            nn.Linear(8, 32), nn.Sigmoid(),   # 第1层:全连接层 ,是一个包含32个神经元的全连接层,输入特征数为8(表示输入数据维度为8),并使用Sigmoid激活函数。
            nn.Linear(32, 8), nn.Sigmoid(),   # 第2层:全连接层 ,是一个包含8个神经元的全连接层,输入特征数为32,同样使用Sigmoid激活函数。
            nn.Linear(8, 4), nn.Sigmoid(),    # 第3层:全连接层 ,是一个包含4个神经元的全连接层,输入特征数为8,同样使用Sigmoid激活函数。
            nn.Linear(4, 1), nn.Sigmoid()    # 第4层:全连接层 ,是一个包含1个神经元的全连接层,输入特征数为4,同样使用Sigmoid激活函数。这是模型的输出层。
        ) 
 
    def forward(self, x): 
        ''' 前向传播 ''' 
        y = self.net(x)    # 将输入数据x通过模型定义的神经网络结构self.net进行前向传播计算,得到输出y。
        return y        # y即输出数据
 model = DNN().to('cuda:0')    # 创建子类的实例,并搬到GPU上 

这个代码可以当做模板,其中需要修改的部分为网络层的搭建,输入特征,中间层,输出特征一般都要为2的n次幂。

这就是该实例的各层。

四、训练网络

通过前向传播,反向传播等操作,本质上是不断调整权重和偏置。

loss_fn = nn.BCELoss(reduction='mean')#选择二元交叉熵损失函数作为模型的损失函数,其中reduction='mean'表示采用平均损失值作为最终的损失值。
learning_rate = 0.005    # 设置学习率 
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 
# 训练网络 
epochs = 5000     #设置训练的总轮数为5000。
losses = []        # 记录损失函数变化的列表 
 
# 给训练集划分输入与输出 
X = train_Data[ : , : -1 ]                # 前8列为输入特征 
Y = train_Data[ : , -1 ].reshape((-1,1))    # 后1列为输出特征
for epoch in range(epochs):     #对于每个epoch进行循环。
    Pred = model(X)              #通过模型进行一次前向传播,得到模型的预测结果Pred。
    loss = loss_fn(Pred, Y)        #计算模型预测结果与实际标签之间的损失值。
    losses.append(loss.item())     #将当前轮次的损失值记录到losses列表中。
    optimizer.zero_grad()        #清空上一轮的梯度信息。
    loss.backward()             #进行反向传播,计算梯度。
    optimizer.step()             #根据优化算法更新模型的参数,完成一轮训练。

#绘制损失函数随训练轮次变化的图像,用于可视化训练过程中损失值的变化。
Fig = plt.figure() 
plt.plot(range(epochs), losses) 
plt.ylabel('loss') 
plt.xlabel('epoch') 
plt.show() 

生成结果为

可以发现随着训练的进行,loss开始减少。

五、测试网络

通过用训练好的模型对测试集进行测试,由于只有一个输出特征为0或者1,将大于0.5的置为1,小于0.5的置为0,可以类比成可能性从0到1。

# 测试网络 
# 给测试集划分输入与输出 
X = test_Data[ : , : -1 ]                
Y = test_Data[ : , -1 ].reshape((-1,1))    
with torch.no_grad():    #进入上下文管理器,表示接下来的计算不会被记录在计算图中,因此不会影响梯度的计算。
    Pred = model(X)     
    Pred[Pred>=0.5] = 1 
    Pred[Pred<0.5] = 0 
    correct = torch.sum( (Pred == Y).all(1) )    #统计预测正确的样本数,使用.all(1)表示在第1维度(即行)上进行比较,得到一个布尔张量,再进行求和操作。
    total = Y.size(0)   #获取试集样本总数。
    print(f'测试集精准度: {100*correct/total} %') 

一般精准度得百分之八九十才合格哦,所以精度不高很有可能是训练集或者环境的问题,所以训练前一定要做好准备工作,因为训练一个模型要花费很久时间。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/701736.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

136G全国1m土地覆盖数据

数据是GIS的血液&#xff01; 我们现在为你分享一个136G的全国土地覆盖数据&#xff0c;该数据的分辨率为1米&#xff0c;你可以在文末查看领取方法。 数据概况 为满足对更精确地感知国土面积的需求&#xff0c;来自武汉大学和中国地质⼤学的团队利用低成本的激光扫描仪&…

SAP CS01/CS02/CS03 BOM创建维护删除BAPI使用及增强改造

BOM创建维护删除相关BAPI的使用代码参考示例&#xff0c;客户电脑只能远程桌面&#xff0c;代码没法复制粘贴出来&#xff0c;只能贴图。 创建及修改BAPI: CSAP_MAT_BOM_MAINTAIN。 删除BAPI: CSAP_MAT_BOM_DELETE。 改造BAPI: CSAP_MAT_BOM_MAINTAIN 改造点1&#xff1a;拷…

[LitCTF 2023]Virginia(变异凯撒)

题目&#xff1a; 首先利用网站进行维吉尼亚解密&#xff08;第一段和第二段密钥不同&#xff0c;两段无关&#xff0c;可只解密第二段&#xff09; 第二段解密结果&#xff1a; Please send this message to those people who mean something to you,to those who have touch…

雪球产品可能要远离普通人了

最近有消息说&#xff0c;在年初发生大规模敲入事件后&#xff0c;雪球产品的购买门槛可能从300w提升至1000w。 那么在这个时间&#xff0c;了解一下雪球产品到底是什么&#xff0c;运行原理是什么。 第一种 经典雪球 经典雪球比较容易理解&#xff0c;设定好了固定的敲出条件…

Java从放弃到继续放弃

并发编程 为什么需要多线程&#xff1f; 由于硬件的发展&#xff0c;CPU的核数增多&#xff0c;如果仍然使用单线程对CPU资源会造成浪费。同时&#xff0c;单线程也会出现阻塞的问题。所以&#xff0c;选择向多线程转变。 多线程的使用使得程序能够并行计算&#xff0c;提高计…

HTTPS请求阶段图解分析

HTTPS请求阶段分析 请求阶段分析 请求阶段分析 一个完整、无任何缓存、未复用连接的 HTTPS 请求需要经过以下几个阶段&#xff1a; DNS 域名解析、TCP 握手、SSL 握手、服务器处理、内容传输。 一个 HTTPS 请求共需要 5 个 RTT 1 RTT&#xff08;域名解析&#xff09; 1 RTT…

当你的AirPods没声音或有异常时怎么办?这里提供几个解决办法

你的AirPods听起来比以前安静吗?一个AirPod的声音比另一个大吗?这里有一些技巧可以帮助你提高音量,消除与音量相关的问题。 调整AirPods的音量 在我们深入研究修复AirPods的更技术性方法之前,让我们先检查一下它们是否处于最佳音量水平,并在你尝试更改音量时正确响应。如…

C语言入门系列:变量 —— 存储数据的容器

1. 变量及其作用 想象一下&#xff0c;你正在编写一个程序来记录学生的成绩。你需要一个地方来暂时存放这些成绩&#xff0c;这就是变量的用武之地。 变量不仅可以存储数值&#xff0c;还能保存字符、布尔值等多种数据类型&#xff0c;使你的程序能够处理多样化的信息。 代码…

vue3中使用defineProps、defineEmits和defineExpose

一、defineProps 父组件通过 v-bind 绑定一个数据&#xff0c;然后子组件通过 defineProps 接受传过来的值。 父组件&#xff1a; <template><StudyDefineProps :title"title" /> </template><script setup lang"ts"> import { r…

如何设置EVM/IMA密钥即如何生成签名密钥和证书

为生成和使用签名密钥来配置 IMA&#xff0c;以下是详细步骤&#xff0c;包括如何生成签名密钥和证书。 1. 生成签名密钥和证书 首先&#xff0c;我们需要一个私钥和一个自签名证书。我们将使用 OpenSSL 来生成这些密钥和证书。 生成私钥 openssl genpkey -algorithm RSA -…

k8s之HPA,命名空间资源限制

一、HPA 的相关知识 HPA&#xff08;Horizontal Pod Autoscaling&#xff09;Pod 水平自动伸缩&#xff0c;Kubernetes 有一个 HPA 的资源&#xff0c;HPA 可以根据 CPU 利用率自动伸缩一个 Replication Controller、Deployment 或者Replica Set 中的 Pod 数量。 &#xff08;…

idea ctrl+shift+f 全局搜索失效的解決方法

一定是輸入法的問題&#xff0c;而且是簡繁輸入快捷鍵的問題 你看我按了ctrlshiftf 之后直接打出来繁体字了。微软拼音的解决方法如下&#xff1a; 打开微软拼音输入法设置 关闭它

【稳定检索/投稿优惠】2024年心理健康与社会科学国际会议(MHSS 2024)

2024 International Conference on Mental Health and Social Sciences 2024年心理健康与社会科学国际会议 【会议信息】 会议简称&#xff1a;MHSS 2024截稿时间&#xff1a;点击查看大会地点&#xff1a;中国三亚会议官网&#xff1a;www.icmhss.com会议邮箱&#xff1a;mhs…

刷代码随想录有感(101):动态规划——有障碍的最短路径

题干&#xff1a; 代码&#xff1a; class Solution { public:int uniquePathsWithObstacles(vector<vector<int>>& obstacleGrid) {int m obstacleGrid.size();int n obstacleGrid[0].size();if(obstacleGrid[0][0] 1 || obstacleGrid[m - 1][n - 1] 1)r…

python-docx初探——如何用python新建world—添加段落、标题、表格

python-docx初探&#x1f680; 在项目中需要用到使用代码来编写一些结构化的文档&#xff0c;所以这里就需要涉及到一些需要用代码写world的一些工作&#xff0c;经过简单了解&#xff0c;python操作world最主要使用的就是python-docx文档&#xff0c;所以这次就先学一下这个库…

便民智慧小程序源码系统 同城信息+商家联盟+生活电商 功能强大 带完整的安装代码包以及搭建部署教程

系统概述 便民智慧小程序源码系统是一个高度集成化的本地化服务平台解决方案&#xff0c;它融合了同城信息发布、商家联盟管理和生活电商平台三大核心模块&#xff0c;旨在打造一个全方位、多维度的生活服务生态系统。该系统采用先进的前后端分离架构&#xff0c;支持快速响应…

Cartographer学习笔记

Cartographer 是一个跨多个平台和传感器配置提供 2D 和 3D 实时同步定位和地图绘制 (SLAM) 的系统。 1. 文件关系 2. 代码框架 common: 定义了基本数据结构和一些工具的使用接口。例如&#xff0c;四舍五入取整的函数、时间转化相关的一些函数、数值计算的函数、互斥锁工具等…

变量的基本原理

目录 注意&#xff1a; 程序中 号的使用 数据类型 string类 变量相当于内存中一个数据存储空间的表示&#xff0c;你可以把变量看做是一个房间的门牌号&#xff0c;通过门牌号我们可以找到房 间&#xff0c;而通过变量名可以访问到变量(值)。 int age 30; double score …

iFlyCode:AI智能编程助手引领未来软件开发新趋势

体验地址 在当前软件行业飞速发展的背景下&#xff0c;开发效率和代码质量成为了衡量软件工程师工作效能的两大关键指标。为了应对日益增长的市场需求和紧迫的发布时间&#xff0c;科大讯飞推出了iFlyCode2.0——一款集AI技术于一身的智能编程助手&#xff0c;旨在引领未来软件…

Linux C语言:指针的运算

一、指针的算术运算 1、指针运算 指针运算是以指针所存放的地址作为运算量而进行的指针运算的实质就是地址的计算 2、指针的算数运算 指针加上整数&#xff0c;指针减去整数, 指针递增&#xff0c;指针递减和两个指针相减。 指针加减一个n的运算: px n px - n 移动步长…