神经网络模型预训练

根据神经网络各个层的计算逻辑用程序实现相关的计算,主要是:前向传播计算、反向传播计算、损失计算、精确度计算等,并提供保存超参数到文件中。

# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
from DeepLearn_Base.common.functions import *
from DeepLearn_Base.common.gradient import numerical_gradient
import pickle

# 三层神经网络处理类(两层隐藏层+1层输出层)
class ThreeLayerNet:

    # input_size:输入层神经元数量,灰度图像的三维表示: 1 * 28 * 28 = 784
    # output_size: 输出层神经元数量,10,表示10个数字
    # hidden_size:第一层隐藏层神经元数量,50
    # second_hidden_size:第二层隐藏层神经元数量,100
    # weight_init_std:权重初始化
    def __init__(self, input_size, hidden_size, output_size, second_hidden_size, weight_init_std=0.01):
        # 初始化权重
        self.params = {}
        self.params['W1'] = weight_init_std * np.random.randn(input_size, hidden_size)
        self.params['b1'] = np.zeros(hidden_size)
        self.params['W2'] = weight_init_std * np.random.randn(hidden_size, second_hidden_size)
        self.params['b2'] = np.zeros(second_hidden_size)
        self.params['W3'] = weight_init_std * np.random.randn(second_hidden_size, output_size)
        self.params['b3'] = np.zeros(output_size)

    # 执行预测
    def predict(self, x):
        W1, W2, W3 = self.params['W1'], self.params['W2'], self.params['W3']
        b1, b2, b3 = self.params['b1'], self.params['b2'], self.params['b3']
    
        # 隐藏层第一层
        a1 = np.dot(x, W1) + b1
        z1 = sigmoid(a1)

        # 隐藏层第二层
        a2 = np.dot(z1, W2) + b2
        z2 = sigmoid(a2)
        
        # 输出层
        a3 = np.dot(z2, W3) + b3
        y = softmax(a3)
        
        return y
        
    # x:输入数据, t:监督数据
    def loss(self, x, t):
        y = self.predict(x)
        
        return cross_entropy_error(y, t)
    
    # 精确度计算
    def accuracy(self, x, t):
        y = self.predict(x)
        y = np.argmax(y, axis=1)
        t = np.argmax(t, axis=1)
        
        accuracy = np.sum(y == t) / float(x.shape[0])
        return accuracy
        
    # 梯度计算
    def gradient(self, x, t):
        W1, W2, W3 = self.params['W1'], self.params['W2'], self.params['W3']
        b1, b2, b3 = self.params['b1'], self.params['b2'], self.params['b3']
        grads = {}
        
        batch_num = x.shape[0]
        
        # forward
        # 隐藏层第一层
        a1 = np.dot(x, W1) + b1
        z1 = sigmoid(a1)

        # 隐藏层第二层
        a2 = np.dot(z1, W2) + b2
        z2 = sigmoid(a2)
        
        # 输出层
        a3 = np.dot(z2, W3) + b3
        y = softmax(a3)
        
        # backward
        # 两层隐藏层计算梯度
        # 输出层梯度: Loss与输出的导数,分类场景下,等于预测值-真实值
        # 权重梯度: 隐藏层输出的转置 * 损失函数梯度
        dy = (y - t) / batch_num
        grads['W3'] = np.dot(z2.T, dy)
        grads['b3'] = np.sum(dy, axis=0)
        
        # 反向传播到隐藏层
        # 隐藏层梯度:Loss与输出的导数 * 输出层权重的转置
        da2 = np.dot(dy, W3.T)
        dz2 = sigmoid_grad(a2) * da2
        grads['W2'] = np.dot(z1.T, dz2)
        grads['b2'] = np.sum(dz2, axis=0)

        da1 = np.dot(da2, W2.T)
        dz1 = sigmoid_grad(a1) * da1
        grads['W1'] = np.dot(x.T, dz1)
        grads['b1'] = np.sum(dz1, axis=0)

        return grads
    
    # 保存参数到文件
    def save_params(self, file_name="params.pkl"):
        params = {}
        for key, val in self.params.items():
            params[key] = val
        with open(file_name, 'wb') as f:
            pickle.dump(params, f)

预训练实现

读取MNIST训练数据集,总共有60000个。每次从60000个训练数据中随机取出100个数据 (图像数据和正确解标签数据)。然后,对这个包含100笔数据的批数据求梯度,使用随机梯度下降法(SGD)更新参数。这里,梯度法的更新次数(循环的次数)为10000。每更新一次,都对训练数据计算损失函数的值,并把该值添加到数组中。

# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
import numpy as np
import matplotlib.pyplot as plt
from DeepLearn_Base.dataset.mnist import load_mnist
from three_layer_net import ThreeLayerNet

# 读入数据
# x_train.sharp 60000 * 784
# t_train.sharp 60000 * 10
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)

network = ThreeLayerNet(input_size=784, hidden_size=50, second_hidden_size=100, output_size=10)

iters_num = 10000  # 适当设定循环的次数
# 训练集大小 60000
train_size = x_train.shape[0]
batch_size = 100
learning_rate = 0.1

train_loss_list = []
train_acc_list = []
test_acc_list = []

# 每批次迭代数量:600
iter_per_epoch = max(train_size / batch_size, 1)

for i in range(iters_num):
    # 从训练集中选取100个为一批次进行训练
    batch_mask = np.random.choice(train_size, batch_size)
    x_batch = x_train[batch_mask]
    t_batch = t_train[batch_mask]
    
    # 更新超参数梯度
    grad = network.gradient(x_batch, t_batch)
    
    # 更新超参数W,b
	# 基于SGD算法更新梯度,上面是随机选择的批数据处理,因此更新时,也是随即更新梯度
    for key in ('W1', 'b1', 'W2', 'b2', 'W3', 'b3'):
        network.params[key] -= learning_rate * grad[key]
    
    loss = network.loss(x_batch, t_batch)
    train_loss_list.append(loss)
    
    if i % iter_per_epoch == 0:
        train_acc = network.accuracy(x_train, t_train)
        test_acc = network.accuracy(x_test, t_test)
        train_acc_list.append(train_acc)
        test_acc_list.append(test_acc)
        print("train acc, test acc | " + str(train_acc) + ", " + str(test_acc))

# 绘制图形
markers = {'train': 'o', 'test': 's'}
x = np.arange(len(train_acc_list))
plt.plot(x, train_acc_list, label='train acc')
plt.plot(x, test_acc_list, label='test acc', linestyle='--')
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.ylim(0, 1.0)
plt.legend(loc='lower right')
plt.show()

# 输出到文件保存参宿后
network.save_params("E:\\workcode\\code\\DeepLearn_Base\\ch04\\myparams.pkl")

用图像来表示这个损失函数的值的推移,如图所示;并保存最终的超参数到pkl文件

应用自训练超参数

将之前用于预测图像文字中使用的超参数文件替换为自己预训练生成的pkl参数文件,并执行代码,打印出精确度。
这是基于默认的超参数进行推理后的精确度:

替换超参数文件,进行图像识别推理

 

精确度上涨了0.01,因此选择合适的梯度更新超参数,是保证推理精确度好坏的关键。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/212426.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Redis对象

Redis根据基本数据结构构建了自己的一套对象系统。主要包括字符串对象、列表对象、哈希对象、集合对象和有序集合对象 同时不同的对象都有属于自己的一些特定的redis指令集,而且每种对象也包括多种编码类型,和实现方式。 Redis对象结构 struct redisOb…

SSM SpringBoot vue社团事务管理系统

SSM SpringBoot vue社团事务管理系统 系统功能 登录 个人中心 人员信息管理 考勤信息管理 空闲时间管理 现金日记账管理 经费预算管理 物品租借管理 会议信息管理 活动信息管理 项目任务管理 公告通知管理 物资信息管理 开发环境和技术 开发语言:Java 使用框架:…

华为1+x网络系统建设与运维(中级)-练习题2

一.设备命令 LSW1 [Huawei]sys LSW1 同理可得,给所有设备改名 二.VLAN LSW1 [LSW1]vlan ba 10 20 [LSW1]int g0/0/1 [LSW1-GigabitEthernet0/0/1]port link-type trunk [LSW1-GigabitEthernet0/0/1]port trunk allow-pass vlan 10 20 [LSW1-GigabitEthernet0/0/1]in…

【Casbin】一篇文章入门Casbin

Casbin Casbin模型基础(PERM)Policy定义Request定义MatchersEffect ACL模型RBAC模型Go语言实战使用前先下载casbin包新建一个Casbin enforcer判断是否能通过增加Policy删除Policy更新Policy获取Policy Casbin 权限管理在几乎每个系统中都是必备的模块。…

Spring Cloud Gateway与spring-cloud-circuitbreaker集成与理解

官方文档地址 本文以 spring-cloud2021版本为例子 spring-cloud-gateway文档地址: https://spring.io/projects/spring-cloud-gateway#overview spring-cloud-circuitbreaker文档地址: https://spring.io/projects/spring-cloud-circuitbreaker 两者…

【LeetCode:1423. 可获得的最大点数 | 滑动窗口】

🚀 算法题 🚀 🌲 算法刷题专栏 | 面试必备算法 | 面试高频算法 🍀 🌲 越难的东西,越要努力坚持,因为它具有很高的价值,算法就是这样✨ 🌲 作者简介:硕风和炜,…

matlab 计算两点云之间的放缩倍数

目录 一、算法原理1、原理概述2、参考文献二、代码实现三、结果展示本文由CSDN点云侠原创,原文链接。如果你不是在点云侠的博客中看到该文章,那么此处便是不要脸的爬虫与GPT。 一、算法原理 1、原理概述 放缩倍数即尺度参数,尺度参数可由2个公共点在不同坐标系下的距离之…

leetcode:225. 用队列实现栈

一、题目 链接:225. 用队列实现栈 - 力扣(LeetCode) 函数原型: typedef struct { } MyStack; MyStack* myStackCreate() void myStackPush(MyStack* obj, int x) int myStackPop(MyStack* obj) int myStackTop(MyStack* obj) …

【ArcGIS Pro微课1000例】0041:Pro强大的定位搜索功能、定位窗格、地图上查找地点

一谈到搜索,你是不是还停留在矢量数据的属性表中呢?今天给大家介绍ArcGIS Pro中定位搜索强大功能的使用,可以基于在线地图、矢量数据等多种数据源,进行地址、地名、道路、坐标等的查找。 文章目录 一、定位工具介绍二、在线地图搜索三、本地矢量数据搜索四、无地图搜索五、…

如何使用vue组件

目录 1:组件之间的父子关系 2:使用组件的三个步骤 3:components组件的是私有子组件 4:在main.js文件中使用Vue.component全局注册组件 1:组件之间的父子关系 一:首先封装好的组件是不存在任何的关系的…

scrapy-redis

一、什么是scrapy-redis Scrapy-Redis 是 Scrapy 框架的一个扩展,它提供了对 Redis 数据库的支持,用于实现分布式爬取。通过使用 Scrapy-Redis,你可以将多个 Scrapy 进程连接到同一个 Redis 服务器,共享任务队列和去重集&#xf…

人工智能和网络安全:坏与好

人工智能似乎可以并且已经被用来帮助网络犯罪和网络攻击的各个方面。 人工智能可以用来令人信服地模仿真人的声音。人工智能工具可以帮助诈骗者制作更好、语法正确的网络钓鱼消息(而糟糕的语法往往会暴露出漏洞),并将其翻译成多种语言&…

SSM整合 spring-mybaits配置文件——设置数据库字段名驼峰命名规则

一、简介:mybatis是支持属性使用驼峰的命名 如下java代码 public class Role {private Integer id;private String roleName;private String roleKey;private Integer orderNum;private Integer roleType;private String remark;...省略set,get方法 } 列名是有下…

数据结构和算法-线索二叉树中的线索化和在线索二叉树中找前驱后继

线索二叉树的概念 找到某个节点得按照遍历得到的序列开始遍历才能遍历全部节点,非常繁琐 中序线索二叉树 线索二叉树的存储结构 先序线索二叉树 后序线索二叉树 三种线索二叉树的对比 即对应前驱后后继判断标准不同 小结 二叉树的线索化 用土办法找中序前驱 当…

「Verilog学习笔记」自动贩售机1

专栏前言 本专栏的内容主要是记录本人学习Verilog过程中的一些知识点,刷题网站用的是牛客网 自动贩售机中可能存在的几种金额:0,0.5,1,1.5,2,2.5,3。然后直接将其作为状态机的几种状…

13.字符串处理函数——输入输出

文章目录 前言一、题目描述 二、解题 程序运行代码 三、总结 前言 本系列为字符串处理函数编程题&#xff0c;点滴成长&#xff0c;一起逆袭。 一、题目描述 二、解题 程序运行代码 #include<stdio.h> #include<string.h> int main() {char str[10];printf(&q…

Ubuntu22.04无需命令行安装中文输入法

概要&#xff1a;Ubuntu22.04安装完成后&#xff0c;只需在设置中点点点即可完成中文输入法的安装&#xff0c;无需命令行。 一、安装中文语言包 1、点击屏幕右上角&#xff0c;如下图所示。 2、点击设置 3、选择地区与语言&#xff0c;点击管理已安装的语言 4、点击安装 5、输…

nodejs微信小程序+python+PHP药品招标采购系统的设计与实现-计算机毕业设计推荐MySQL

目 录 摘 要 I ABSTRACT II 目 录 II 第1章 绪论 1 1.1背景及意义 1 1.2 国内外研究概况 1 1.3 研究的内容 1 第2章 相关技术 3 2.1 nodejs简介 4 2.2 express框架介绍 6 2.4 MySQL数据库 4 第3章 系统分析 5 3.1 需求分析 5 3.2 系统可行性分析 5 3.2.1技术可行性&#xff1a;…

二维码设备安全巡检手指口述应用

二维码设备安全巡检手指口述应用 在安全管理中&#xff0c;”手指口述”工作法是通过心想、眼看、手指、口述等一系列行为&#xff0c;对工作过程中的每一道工序进行确认&#xff0c;使“人”的注意力和“物”的可靠性达到高度统一&#xff0c;从而达到避免违章、消除隐患、杜绝…

JAVA代码优化:记录日志

登录中的一条日志记录代码&#xff1a; //异步任务管理器&#xff08;详见文章异步任务管理器&#xff09; //me() 初始化线程池 AsyncManager.me() .execute( //异步工厂记录登录信息 AsyncFactory.recordLogininfor ( //使用者姓名 username, //常量登录失败public static …