【代码】python实现一个BP神经网络-原理讲解与代码展示

本文来自《老饼讲解-BP神经网络》https://www.bbbdata.com/

目录

  • 一、BP神经网络原理回顾
    • 1.1 BP神经网络的结构简单回顾
    • 1.2.BP神经网络的训练算法流程
  • 二、python实现BP神经网络代码
    • 2.1.数据介绍
    • 2.2.pytorch实现BP神经网络代码

在python中要如何使用代码实现一个BP神经网络呢?
在python中可以利用pytorch来实现BP神经网络,这是最简洁也是最常用的方法。
通过本文可以详细掌握怎么使用python的pytorch来实现一个BP神经网络。

一、BP神经网络原理回顾

1.1 BP神经网络的结构简单回顾

BP神经网络的结构如下:
BP神经网络结构图
BP神经网络由输入层、隐层、输出层组成,其中隐层可以是有多层的,整个网络以前馈式进行计算,也就是每层的输出作为下层的输入,不断套娃,直到输出层

每层的计算公式如下:
y = T ( W X + B ) y=T(WX+B) y=T(WX+B)
其中,
X:该层的输入
W:该层的权重
B:该层的阈值
T:该层的激活函数

1.2.BP神经网络的训练算法流程

梯度下降算法求解BP神经网络的流程如下:
梯度下降算法求解BP神经网络

一、先初始化一个解                                                 
二、迭代                                                                  
1. 计算所有w,b在当前处的梯度dw,db           
2. 将w,b往负梯度方向更新:                       
   w = w-lr*dw                       
   b = b-lr*db       
3. 判断是否满足退出条件,如果满足,则退出迭代

二、python实现BP神经网络代码

在python中只需要使用pytorch就可以简单实现BP神经网络,而且提供了丰富的训练算法。

2.1.数据介绍

为方便理解,不妨采用以下的简单数据:
在这里插入图片描述
上述即为sin函数在[-5,5]之间的20个采样数据

2.2.pytorch实现BP神经网络代码

下面展示在pytorch中实现BP神经网络的代码
特别说明:需要先安装pytorch包

import torch
import matplotlib.pyplot as plt 
torch.manual_seed(99)

# -----------计算网络输出:前馈式计算---------------
def forward(w1,b1,w2,b2,x):                                   
    return w2@torch.tanh(w1@x+b1)+b2

# -----------计算损失函数: 使用均方差--------------
def loss(y,py):
    return ((y-py)**2).mean()

# ------训练数据----------------
x = torch.linspace(-5,5,20).reshape(1,20)                      # 在[-5,5]之间生成20个数作为x
y = torch.sin(x)                                               # 模型的输出值y

#-----------训练模型------------------------
in_num  = x.shape[0]                                            # 输入个数
out_num = y.shape[0]                                            # 输出个数
hn  = 4                                                         # 隐节点个数
w1  = torch.randn([hn,in_num],requires_grad=True)               # 初始化输入层到隐层的权重w1
b1  = torch.randn([hn,1],requires_grad=True)                    # 初始化隐层的阈值b1
w2  = torch.randn([out_num,hn],requires_grad=True)              # 初始化隐层到输出层的权重w2
b2  = torch.randn([out_num,1],requires_grad=True)               # 初始化输出层的阈值b2

lr = 0.01                                                       # 学习率
for i in range(5000):                                           # 训练5000步
    py = forward(w1,b1,w2,b2,x)                                 # 计算网络的输出
    L = loss(y,py)                                              # 计算损失函数
    print('第',str(i),'轮:',L)                                 # 打印当前损失函数值
    L.backward()                                                # 用损失函数更新模型参数的梯度
    w1.data=w1.data-w1.grad*lr                                  # 更新模型系数w1
    b1.data=b1.data-b1.grad*lr                                  # 更新模型系数b1
    w2.data=w2.data-w2.grad*lr                                  # 更新模型系数w2
    b2.data=b2.data-b2.grad*lr                                  # 更新模型系数b2
    w1.grad.zero_()                                             # 清空w1梯度,以便下次backward
    b1.grad.zero_()                                             # 清空b1梯度,以便下次backward
    w2.grad.zero_()                                             # 清空w2梯度,以便下次backward
    b2.grad.zero_()                                             # 清空b2梯度,以便下次backward
px = torch.linspace(-5,5,100).reshape(1,100)                    # 测试数据,用于绘制网络的拟合曲线    
py = forward(w1,b1,w2,b2,px).detach().numpy()                   # 网络的预测值
plt.scatter(x, y)                                               # 绘制样本
plt.plot(px[0,:],py[0,:])                                       # 绘制拟合曲线  
print('w1:',w1)
print('b1:',b1)
print('w2:',w2)
print('b2:',b2)

运行结果如下:

.....                                            
第 4996 轮: tensor(0.0083, grad_fn=<MeanBackward0>)
第 4997 轮: tensor(0.0083, grad_fn=<MeanBackward0>)
第 4998 轮: tensor(0.0083, grad_fn=<MeanBackward0>)
第 4999 轮: tensor(0.0083, grad_fn=<MeanBackward0>)
w1: tensor([[ 0.1742],[-0.8133],[-0.6450],[-0.4054]],requires_grad=True)
b1: tensor([[ 0.8125],[0.0593],[-1.8776],[1.1220]],requires_grad=True)
w2: tensor([[-0.7753,-2.0142,1.1161,1.9635]],requires_grad=True)
b2: tensor([[0.1094]], requires_grad=True)   

运行结果
可以看到,模型根据训练数据,已经较好地拟合出sin函数曲线


相关链接:

《老饼讲解-机器学习》:老饼讲解-机器学习教程-通俗易懂
《老饼讲解-神经网络》:老饼讲解-matlab神经网络-通俗易懂
《老饼讲解-神经网络》:老饼讲解-深度学习-通俗易懂

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/744370.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

excel实现下拉筛选(超简单)

excel实现下拉筛选 引言1、需求&#xff1a;预警状态下的列 实现下拉筛选2、实现2.1、数据验证2.2、下拉筛选内容2.3、去掉预警状态单元格的下拉筛选 引言 通常&#xff0c;我们会单独新建一张sheet表 专门存每个列的下拉内容。下面我将专门建立一张名为代码表的sheet表来存放…

Nginx调度器

Nginx反向代理 反向代理架构 部署后端Web1服务器 部署后端Web2服务器 配置Nginx服务器&#xff0c;添加服务器池&#xff0c;实现反向代理功能 proxy主机安装nginx 修改/usr/local/nginx/conf/nginx.conf配置文件 重新加载配置 客户端使用火狐浏览器或curl多次访问p…

阿尔兹海默症-图像分类数据集

阿尔兹海默症-图像分类数据集 数据集&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/1gSUT74XrnHmg2Z11oZNd6A?pwdwphh 提取码&#xff1a;wphh 数据集信息介绍&#xff1a; 文件夹 健康 中的图片数量: 8000 文件夹 早期轻度认知障碍 中的图片数量: 10000 文件夹 …

https基础概念

目录 1.什么是https 1.1.https概念 1.2.一些术语 2.https如何加密 2.1.使用对称加密 2.2.使用非对称加密 1.什么是https 首先&#xff0c;https不是http的复数形式 1.1.https概念 &#xff08;1&#xff09;产生背景 为什么会出现https&#xff0c;是因为http协议在传…

Java 面试笔记 | Java 基础:线程池

前言 在日常的工作学习生活中&#xff0c;用一种好的方法去学习&#xff0c;可以更加有效&#xff0c;比如费曼学习法&#xff1a;将学到的知识用自己的组织的语言表达出来&#xff0c;如果能够清晰明白的向别人解释清楚&#xff0c;那么就说明你是真的懂了&#xff0c;学会了…

DataStructure.包装类简单认识泛型

包装类&简单认识泛型 【本节目标】1 包装类1.1 基本数据类型和对应的包装类1.2 装箱和拆箱1.3 自动装箱和自动拆箱 2 什么是泛型3 引出泛型3.1 语法 4 泛型类的使用4.1 语法4.2 示例4.3 类型推导(Type Inference) 5. 裸类型(Raw Type) &#xff08;了解&#xff09;5.1 说明…

C++初学者指南-2.输入和输出---文件输入和输出

C初学者指南-2.输入和输出—文件输入和输出 文章目录 C初学者指南-2.输入和输出---文件输入和输出1.写文本文件2.读文本文件3.打开关闭文件4.文件打开的模式 1.写文本文件 使用&#xff1a; std::ofstream&#xff08;输出文件流&#xff09; #include <fstream> // 文…

昂科烧录器支持KIOXIA铠侠的可编程只读存储器TH58NVG4S0HTAK0

芯片烧录行业领导者-昂科技术近日发布最新的烧录软件更新及新增支持的芯片型号列表&#xff0c;其中KIOXIA铠侠的电可擦除可编程只读存储器TH58NVG4S0HTAK0已经被昂科的通用烧录平台AP8000所支持。 TH58NVG4S0HTAK0是一个单一的3.3V 16Gbit&#xff08;18253611008位&#xff…

如何在Paperspace中使用 H100 单 GPU 和多 GPU 机器

本文将介绍如何在 DigitalOcean 的 Paperspace 平台上运行 NVIDIA H100 单卡 GPU 和 8 卡 GPU 。使用低成本的 GPU 云服务构建和扩展你的人工智能模型&#xff0c;即刻简化并加速你的机器学习工作。 如果需要了解 DigitalOcean 的 H100 GPU 云主机&#xff0c;请查看我们在 Dig…

【PA交易】BackTrader: 讨论下分析器和评测指标

前言 BackTrader的分析器主要使用的是analyzers模块&#xff0c;我们可以从Analyzers - Backtrader找到一个非常简单的示例。这个示例中使用方式很简单&#xff0c;其他分析器也可以通过如此简单封装方式进行装载。如果仅是复制粘贴官方教程&#xff0c;完全是制造互联网垃圾…

【前后端实现】AHP权重计算

AHP权重计算&#xff1a; 需求&#xff1a;前端记录矩阵维度、上三角值&#xff0c;后端构建比较矩阵、计算权重值并将结果返回给前端 比较矩阵构建 如果你想要根据上三角&#xff08;不包括对角线&#xff09;的值来构建对称矩阵&#xff0c;那么你可以稍作修改上述的generate…

Kivy tutorial 004: Making the GUI do stuff, binding to events

Kivy tutorial 004: Making the GUI do stuff, binding to events – Kivy Blog Central themes: Events and Kivy properties 中心主题&#xff1a;事件和kivy属性 We left the last tutorial with a calculator app GUI with some nice automatic behaviour, but which doe…

嵌入式C语言中常见寄存器的控制方法

使用C语言对寄存器赋值时,常常需要用到C语言的位操作方法。 把寄存器某位清零 假设a代表寄存器,且其中本来已有值。如果要把其中某一位清零且其它位不变,代码如下。 //定义一个变量 a = 1001 1111 b (二进制数)unsigned char a = 0x9f;//对 bit2 清零a &= ~(1<<…

实现批量自动化电商数据采集|商品详情页面|店铺商品信息|订单详情数据

电商数据采集是指通过技术手段获取电商平台上的商品信息、店铺信息和订单信息等数据。这些数据可以用于市场分析、竞品分析、用户行为分析等。 商品详情页面是指电商平台上展示商品详细信息的页面&#xff0c;包括商品名称、价格、图片、描述、评价等信息。通过采集商品详情页…

springboot+vue+mybatis门窗管理系统+PPT+论文+讲解+售后

如今社会上各行各业&#xff0c;都在用属于自己专用的软件来进行工作&#xff0c;互联网发展到这个时候&#xff0c;人们已经发现离不开了互联网。互联网的发展&#xff0c;离不开一些新的技术&#xff0c;而新技术的产生往往是为了解决现有问题而产生的。针对于仓库信息管理方…

[学习笔记] 禹神:一小时快速上手Electron笔记,附代码

课程地址 禹神&#xff1a;一小时快速上手Electron&#xff0c;前端Electron开发教程_哔哩哔哩_bilibili 笔记地址 https://github.com/sui5yue6/my-electron-app 进程通信 桌面软件 跨平台的桌面应用程序 chromium nodejs native api 流程模型 main主进程 .js文件 node…

攻克PS之路——Day1(A1-A8)

#暑假到了&#xff0c;作为可能是最后一个快乐的暑假&#xff0c;我打算学点技能来傍身&#xff0c;首先&#xff0c;开始PS之旅 这个帖子作为我跟着B站up主学习PS的记录吧&#xff0c;希望我可以坚持下去&#xff01; 学习的链接在这里&#xff1a;A02-PS软件安装&#xff0…

Qt | 子类化 QStyle(Qt自带图标大全)

01、简介 1、把绘制自定义部件外观的步骤大致分为三大板块,如下: ①、样式元素:即指定需要绘制的图形元素(比如焦点框、按钮,工具栏等)。样式元素使 用 QStyle 类中的一系列枚举(共有 11 个枚举)进行描述。 ②、样式选项:包含了需要绘制的图形元素的所有信息,比如包含…

【QCustomPlot实战系列】QCPGraph区域高亮

使用QCPDataSelection来设置选中的区域&#xff0c;并将QCPGraph的可选择区域设置成QCP::stMultipleDataRanges void AreaPieces::initCustomPlot(QCustomPlot *parentPlot) {QVector<double> x {0, 1, 2, 3, 4, 5, 6, 7, 8};QVector<double> y {200, 560, 750…

asp.net core反向代理

新建项目 新建空白的asp.net core web项目 安装Yarp.ReverseProxy包版本为2.2.0-preview.1.24266.1 编写代码 namespace YarpStu01;public class Program {public static void Main(string[] args){var builder WebApplication.CreateBuilder(args);builder.Services.AddRev…