cuda从零开始手搓PB神经网络

cuda实现PB神经网络


基于上一篇的矩阵点乘,实现了矩阵的加减乘除、函数调用等。并且复用之前元编程里面写的梯度下降、Adam、NAdam优化方法。实现PB神经网络如下:

#ifndef __BP_NETWORK_HPP__
#define __BP_NETWORK_HPP__
#include "matrix.hpp"
#include "mat.hpp"
#include "update_methods.hpp"

template<typename activate_type, typename val_type_, template<typename> class update_type_tpl, typename init_type, int input_num_, int output_num_, int ... remain_layer>
struct bp_network
{
    constexpr static int input_num = input_num_;
    constexpr static int output_num = output_num_;
    using val_type = val_type_;

    using input_type = mat<input_num, 1, val_type>;
    using input_t_type = mat<1, input_num, val_type>;
    using output_type = mat<output_num, 1, val_type>;
    using weight_type = mat<output_num, input_num, val_type>;

    using forward_func = typename func_pair<activate_type>::forward_func;
    using backward_func = typename func_pair<activate_type>::backward_func;

    using next_node_type = typename bp_network<activate_type, val_type, update_type_tpl, init_type, output_num, remain_layer...>;
    using term_output_type = typename next_node_type::term_output_type;

    weight_type weight;
    update_type_tpl<weight_type> weight_update_method;
    output_type bias;
    update_type_tpl<output_type> bias_update_method;

    input_type pre_input;
    output_type pre_func_input;
    next_node_type next_node;

    bp_network():weight_update_method(), bias_update_method()
    {
        weight.template reset<init_type>();
        bias.template reset<init_type>();
        next_node = bp_network<activate_type, val_type, update_type_tpl, init_type, output_num, remain_layer...>();
    }

    auto forward(input_type& input)
    {
        output_type curr_output;
        pre_input = input;
        auto temp = weight.dot(input);
        pre_func_input = temp + bias;
        curr_output = pre_func_input.template activate<forward_func>();
        return next_node.forward(curr_output);
    }

    auto backward(term_output_type& delta, val_type lr)
    {
        output_type curr_delta = next_node.backward(delta, lr);
        curr_delta = pre_func_input.template activate<backward_func>() * curr_delta;
        auto ret = weight.t_dot(curr_delta);
        // 更新参数
        weight_type delta_weight = curr_delta.dot(pre_input.t());
        weight = weight_update_method.update(weight, delta_weight);
        bias = bias_update_method.update(bias, curr_delta);
        return ret;
    }   

    // 更新惯性量
    void update_inert()
    {
        weight_update_method.update_inert();
        bias_update_method.update_inert();
        next_node.update_inert();
    }

    void print()
    {
        weight.print();
        printf("-----------------\n");
        bias.print();
        printf("=================\n");
        next_node.print();
    }
};

template<typename activate_type, typename val_type_, template<typename> class update_type_tpl, typename init_type, int input_num_, int output_num_>
struct bp_network<activate_type, val_type_, update_type_tpl, init_type, input_num_, output_num_>
{
    constexpr static int input_num = input_num_;
    constexpr static int output_num = output_num_;
    using val_type = val_type_;

    using input_type = mat<input_num, 1, val_type>;
    using input_t_type = mat<1, input_num, val_type>;
    using output_type = mat<output_num, 1, val_type>;
    using weight_type = mat<output_num, input_num, val_type>;

    using forward_func = typename func_pair<activate_type>::forward_func;
    using backward_func = typename func_pair<activate_type>::backward_func;
    using term_output_type = typename output_type;
    using weight_update_type = typename update_type_tpl<weight_type>;
    using bias_update_type = typename update_type_tpl<output_type>;

    weight_type weight;
    weight_update_type weight_update;
    output_type bias;
    bias_update_type bias_update;

    output_type pre_func_input;
    input_type pre_input;

    bp_network():weight_update(), bias_update()
    {
        weight.template reset<init_type>();
        bias.template reset<init_type>();
    }

    auto forward(input_type& input)
    {
        pre_input = input;
        auto temp = weight.dot(input);
        pre_func_input = temp + bias;
        return pre_func_input.template activate<forward_func>();
    }

    auto backward(output_type& delta, val_type lr)
    {
        output_type curr_delta = pre_func_input.template activate<backward_func>() * delta;
        auto ret = weight.t_dot(curr_delta);
        // 更新参数
        weight_type delta_weight = curr_delta.dot(pre_input.t());
        weight = weight_update.update(weight, delta_weight);
        bias = bias_update.update(bias, curr_delta);
        return ret;
    }

    void update_inert()
    {
        weight_update.update_inert();
        bias_update.update_inert();
    }

    void print()
    {
        weight.print();
        printf("-----------------\n");
        bias.print();
        printf("*****************\n");
    }
};

#endif

下面实验一下我们的bp神经网络。

#include <chrono>
#include <thread>
#include "matrix.hpp"
#include "bp_network.hpp"
int main()
{
    constexpr int row_num = 32;
    constexpr int adj_num = 32;
    constexpr int col_num = 32;
    /*
    matrix_device_proxy<row_num, adj_num, double> A;
    eyes(A(), 2.0);
    matrix_device_proxy<adj_num, col_num, double> B;
    eyes(B(), 1.0);
    matrix_device_proxy<row_num, col_num, double> C;
    mat_dot<sigmoid>(A(), B(), C());
    print(type_cast(C()));

    auto A = mat<row_num, adj_num, double>::eyes(2.0);
    auto B = mat<adj_num, col_num, double>::eyes(1.0);
    auto C = A.dot(B);
    C = C + 1.0;
    C = sqrtl(C);
    C = C - 2.0;
    C = C * 3.0;
    C = C / 4.0;
    C.print();

    std::cout << "---------- D ----------" << std::endl;
    auto D = mat<row_num, col_num, double>::xavier_gaussian();
    D.print();
    std::cout << "---------- E ----------" << std::endl;
    auto E = mat<row_num, col_num, double>::xavier_mean();
    E.print();
    std::cout << "---------- F ----------" << std::endl;
    auto F = mat<row_num, col_num, double>::he_gaussian();
    F.print();
    std::cout << "---------- G ----------" << std::endl;
    auto G = mat<row_num, col_num, double>::he_mean();
    G.print();
    */
    bp_network<sigmoid, double, nadam, xavier_gaussian_type, row_num, adj_num, col_num> node;
    auto input = mat<row_num, 1, double>::ones(0.2);
    auto expect = mat<col_num, 1, double>::ones(0.4);

    int times = 8000;
    int update_inert_times = 100;
    int step = times / update_inert_times;
    // 计时开始
    auto start = std::chrono::high_resolution_clock::now();
    for (int i = 0; i < times; ++i)
    {
        auto output = node.forward(input);
        auto delta = (output - expect);
        node.backward(delta, 0.001);
        if (i == times - 1)
        {
            output.t().print();
        }

        if (i % step == 0 && i != 0)
        {
            node.update_inert();
        }

    }
    // 计时结束
    // 获取结束时间点
    auto end = std::chrono::high_resolution_clock::now();

    // 计算持续时间
    std::chrono::duration<double> duration = end - start;

    // 输出执行时间
    std::cout << "Execution time: " << duration.count() << " seconds" << std::endl;
    //node.print();
    cudaDeviceReset();
    return 0;
}

以上代码有个学习率lr没有地方设置哈,将来优化,见谅。执行结果如下:
在这里插入图片描述
可以看出,经过8000次的训练,这个使用sigmoid激活函数、NAdam优化、Xavier-Gaussian初始化的323232的PB能够将误差缩减到0.0001这个量级,而训练时间仅为8.54秒。还是相当给力的。
虽然这对于我的工作没有任何关系,但是我还是想搞一下。毕竟“越是没用的知识就越有用,越是有用的东西就越没用”。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/955228.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Django多线程爬虫:突破数据抓取瓶颈

Django框架以其高效、安全、可扩展性强等特点&#xff0c;在Web开发领域得到了广泛应用。同时&#xff0c;Python语言的多线程支持和丰富的库也为开发多线程爬虫提供了便利。将Django与多线程技术相结合&#xff0c;不仅可以利用Django的强大功能进行项目管理和数据存储&#x…

RabbitMQ前置概念

文章目录 1.AMQP协议是什么&#xff1f;2.rabbitmq端口介绍3.消息队列的作用和使用场景4.rabbitmq工作原理5.整体架构核心概念6.使用7.消费者消息推送限制&#xff08;work模型&#xff09;8.fanout交换机9.Direct交换机10.Topic交换机&#xff08;推荐&#xff09;11.声明队列…

Windows环境本地配置pyspark环境详细教程

目录 一、背景简记二、本地单机spark环境配置详细步骤第一步&#xff1a;python环境安装第二步&#xff1a;安装jdk及配置环境变量安装包下载安装环境变量配置 第三步&#xff1a;安装Spark安装包下载安装配置环境变量 第四步&#xff1a;安装hadoop安装包下载安装配置环境变量…

智能家居篇 一、Win10 VM虚拟机安装 Home Assistant 手把手教学

智能家居篇 一、Win10 VM虚拟机安装 Home Assistant 手把手教学 文章目录 [智能家居篇]( )一、Win10 VM虚拟机安装 Home Assistant 手把手教学 前言一.下载Vm版本的HomeAsistant安装包 二.打开Vmware选择新建虚拟机1.选择自定义高级2.选择16.x及以上3.选择稍后安装4.根据官网的…

神经网络基础-正则化方法

文章目录 1. 什么是正则化2. 正则化方法2.1 Dropout正则化2.2 批量归一化(BN层) 学习目标&#xff1a; 知道正则化的作用掌握随机失活 DropOut 策略知道 BN 层的作用 1. 什么是正则化 在设计机器学习算法时希望在新样本上的泛化能力强。许多机器学习算法都采用相关的策略来减小…

vmware虚拟机配置ubuntu 18.04(20.04)静态IP地址

VMware版本 &#xff1a;VMware Workstation 17 Pro ubuntu版本&#xff1a;ubuntu-18.04.4-desktop-amd64 主机环境 win11 1. 修改 VMware虚拟网络编辑器 打开vmware&#xff0c;点击顶部的“编辑"菜单&#xff0c;打开 ”虚拟化网络编辑器“ 。 选择更改设置&#…

前端【2】html添加样式、CSS选择器

一、为html添加样式的三种方法 1、内部样式 2、外部样式 3、行内样式 二、css的使用--css选择器 1、css基本选择器 元素选择器 属性选择器 id选择器 class/类选择器 通配符选择器 2、群组选择器-多方面筛选 3、关系选择器 后代选择器【包含选择器】 子元素选择器…

30分钟内搭建一个全能轻量级springboot 3.4 + 脚手架 <3>5分钟集成好druid并使用druid自带监控工具监控sql请求

快速导航 <1> 5分钟快速创建一个springboot web项目 <2> 5分钟集成好最新版本的开源swagger ui&#xff0c;并使用ui操作调用接口 <3> 5分钟集成好druid并使用druid自带监控工具监控sql请求 <4> 5分钟集成好mybatisplus并使用mybatisplus generator自…

仿射密码实验——Python实现(完整解析版)

文章目录 前言实验内容实验操作步骤1.编写主程序2.编写加密模块3.编写解密模块4.编写文件加解密模块 实验结果实验心得实验源码scirpt.pyusefile.py 前言 实验目的 1&#xff09;初步了解古典密码 2&#xff09;掌握仿射密码的实现 实验方法 根据下图仿射密码&#xff08;变换…

在 QNAP NAS中使用 Container Station 运行 Docker 的完整指南

QNAP 为用户提供了一个名为 Container Station 的应用&#xff0c;它在 QNAP NAS 上将 Docker 和 LXC 结合在一起&#xff0c;通过图形化界面&#xff0c;让用户更轻松地在 NAS 上管理容器。本文将带你一步步了解如何在 QNAP NAS 上安装和使用 Container Station&#xff0c;以…

基于vite+vue3+mapbox-gl从零搭建一个项目

下面是基于 Vite、Vue 3 和 Mapbox GL 从零搭建一个项目的完整步骤&#xff0c;包括环境搭建、依赖安装、配置和代码示例。 文章目录 1. 初始化项目2. 安装 mapbox-gl 依赖3. 配置 Mapbox Access Token4. 实现地图组件5. 在 App.vue 中使用地图组件6. 启动开发服务器7. 添加自定…

运维作业一

1、shell 脚本写出检测 /tmp/size.log 文件如果存在显示它的内容&#xff0c;不存在则创建一个文件将创建时间写入。 2、写一个 shel1 脚本,实现批量添加 20个用户,用户名为user01-20,密码为user 后面跟5个随机字符。 首先&#xff0c;获得随机字符&#xff0c;需下载pwgen&am…

【拒绝算法PUA】3065. 超过阈值的最少操作数 I

系列文章目录 【拒绝算法PUA】0x00-位运算 【拒绝算法PUA】0x01- 区间比较技巧 【拒绝算法PUA】0x02- 区间合并技巧 【拒绝算法PUA】0x03 - LeetCode 排序类型刷题 【拒绝算法PUA】LeetCode每日一题系列刷题汇总-2025年持续刷新中 C刷题技巧总结&#xff1a; [温习C/C]0x04 刷…

SSE部署后无法连接问题解决

1. 问题现象 通过域名访问 https://api-uat.sfxs.com/sse/subscribe?tokenBearer%20eyJUxMiJ9.eyJhY2NvdW50IjoiYWRtaWZ0NvZGUiOiIwMDEiLCJyb2xidXNlcm5hbWUiOiLotoXnuqfnrqHnkIblkZgifQ.tlz9N61Y4 一直无法正常连接 2. 问题解决 nginx.conf进行配置 server {location /ss…

JS宏进阶: 工厂函数与构造函数

一、构造函数 在JavaScript中&#xff0c;构造函数是一种用于创建和初始化对象的特殊函数。构造函数的名字通常以大写字母开头&#xff0c;以区分于普通函数。通过new关键字调用构造函数&#xff0c;可以创建一个新的实例对象&#xff0c;并自动执行构造函数内部的代码来初始化…

编译pytorch——cuda-toolkit-nvcc

链接 https://blog.csdn.net/wjinjie/article/details/108997692https://docs.nvidia.com/cuda/cuda-installation-guide-linux/#switching-between-driver-module-flavorshttps://forums.developer.nvidia.com/t/can-not-load-nvidia-drivers-on-ubuntu-22-10/239750https://…

光谱相机的光谱分辨率可以达到多少?

多光谱相机 多光谱相机的光谱分辨率相对较低&#xff0c;波段数一般在 10 到 20 个左右&#xff0c;光谱分辨率通常在几十纳米到几百纳米之间&#xff0c;如常见的多光谱相机光谱分辨率为 100nm 左右。 高光谱相机 一般的高光谱相机光谱分辨率可达 2.5nm 到 10nm 左右&#x…

案例 —— 怪物出水

目录 一&#xff0c;Ocean Setup 二&#xff0c;Water Setup 解算前设置 解算设置 解算后设置 三&#xff0c;Meshing Setup Meshing 剔除外围边界mesh 应用低频频谱变形并添加变形速度 为whitewater创建自定义的surface、vel 清理属性和组并缓存 四&#xff0c;Wh…

怎么在iPhone手机上使用便签进行记录?

宝子们&#xff0c;在这个快节奏的时代&#xff0c;灵感的火花总是一闪而过&#xff0c;待办事项也常常让人应接不暇。好在咱们的 iPhone手机便签超给力&#xff0c;能满足各种记录需求&#xff01;今天就来给大家分享一下&#xff0c;如何在 iPhone 手机上巧用便签&#xff0c…

2025宝塔API一键建站系统PHP源码

源码介绍 2025宝塔API一键建站系统PHP源码&#xff0c;对接自己的支付&#xff0c;虚拟主机也能搭建&#xff0c;小白式建站系统&#xff0c;基于宝塔面板搭建的建站系统&#xff0c;功能丰富&#xff0c;多款模板&#xff0c;每日更新 上传源码到服务器&#xff0c;浏览器访问…