【神经网络】python实现神经网络(一)——数据集获取

一.概述

        在文章【机器学习】一个例子带你了解神经网络是什么中,我们大致了解神经网络的正向信息传导、反向传导以及学习过程的大致流程,现在我们正式开始进行代码的实现,首先我们来实现第一步的运算过程模拟讲解:正向传导。本次代码实现将以“手写数字识别”为例子。

二.测试训练数据集的获取

        首先我们需要通过官网获取到手写数字识别数据集,数据集一共分为四个部分,分别是训练集的图片(六万张)、训练集的标签、测试集的图片(一万张)以及测试集的标签。所以我们在代码中可以使用键值表示对应的key-value:

url_base = 'http://yann.lecun.com/exdb/mnist/'
key_file = {
    'train_img':'train-images-idx3-ubyte.gz',
    'train_label':'train-labels-idx1-ubyte.gz',
    'test_img':'t10k-images-idx3-ubyte.gz',
    'test_label':'t10k-labels-idx1-ubyte.gz'
}

        同时,我们需要将下载的文件保存到与代码同一级目录下:

dataset_dir = os.path.dirname(os.path.abspath(__file__))

        下载部分十分简单么,就不在此赘述,需要注意的是代码使用了python的urlretrieve函数,该函数需要使用头文件urllib.request,需要自行下载:

def download_mnist():
    for filename in key_file.values():
        file_path = dataset_dir + "/" + filename

        if os.path.exists(file_path):
            return

        print("Downloading " + filename + " ... ")
        urllib.request.urlretrieve(url_base + filename, file_path)
        print("Done")

三.测试训练数据集的加载

        下载完数据集后,我们需要将其加载到我们的程序中以供后续的使用,首先是判断一下我们是否已经下载过数据集,如果没有下载,则先进行下载操作,再执行其他步骤:

    if not os.path.exists(save_file) :
        download_mnist()
        dataset = _convert_numpy()
        print("Creating pickle file ...")
        with open(save_file, 'wb') as f:
            pickle.dump(dataset, f, -1)
        print("Done!")

        以上代码有个需要注意的地方,因为下载完数据集之后无法直接给到python使用,所以还需要对数据进行格式处理,处理成python可以识别的格式,这一步交由函数_convert_numpy实现:

def _convert_numpy():    
    dataset = {}
    dataset['train_img'] = _load_img(key_file['train_img'])
    dataset['train_label'] = _load_label(key_file['train_label'])
    dataset['test_img'] = _load_img(key_file['test_img'])
    dataset['test_label'] = _load_label(key_file['test_label'])

    return dataset

       其中,_load_img函数负责处理图片数据:

def _load_img(file_name):
    file_path = dataset_dir + "\\MNIST\\" + file_name

    print("Converting " + file_name + " to NumPy Array ...")
    with gzip.open(file_path, 'rb') as f:
        data = np.frombuffer(f.read(), np.uint8, offset=16)
    data = data.reshape(-1, img_size)
    print("Done")

    return data

        其中,_load_label函数负责处理标签数据:

def _load_label(file_name):
    file_path = dataset_dir + "\\MNIST\\" + file_name

    print("Converting " + file_name + " to NumPy Array ...")
    with gzip.open(file_path, 'rb') as f:
        labels = np.frombuffer(f.read(), np.uint8, offset=8)
    print("Done")

    return labels

        函数中使用到的都是一些python常用的函数,所以具体作用不在赘述,可自行查询。介绍完_convert_numpy函数,我们继续回到数据集加载函数本身,为了方便后续数据集的批量调用等操作,我们需要在加载数据后对其进行进一步的数据清洗整理等预处理,分别为数据归一化(normalize)、图像展开(flatten)以及图像标签对应(one_hot_label),先将三个功能代码贴上,然后我们再详细讲解各个功能的具体作用:

    with open(save_file,'rb') as f:
        dataset = pickle.load(f)

        if normalize:
            for key in ['train_img','test_img']:
                dataset[key] = dataset[key].astype(np.float32)

        if not flatten:
            for key in ('train_img', 'test_img'):
                dataset[key] = dataset[key].reshape(-1, 1, 28, 28)

        if one_hot_label:
            dataset['train_label'] = _change_one_hot_label(dataset['train_label'])
            dataset['test_label'] = _change_one_hot_label(dataset['test_label'])

3.1.数据归一化(normalize)

        数据归一化normalize如果设置为True,可以将输入图像归一化为0.0~1.0 的值。如果将该参数设置为False,则输入图像的像素会保持原来的0~255。函数实现是使用了python函数中的astype功能将数据,用于将数据集指定字段的数据转换为 float32 类型,常见于深度学习模型输入前的数据预处理。

dataset[key] = dataset[key].astype(np.float32)

3.2.图像展开(flatten)

        图像展开flatten用于设置是否展开输入图像使其变成一维数组。如果将该参数设置为False,则输入图像为1 × 28 × 28 的三维数组;若设置为True,则输入图像会保存为由784 个元素构成的一维数组。函数实现也只是使用到深度学习中常用的reshape函数:

 dataset[key] = dataset[key].reshape(-1, 1, 28, 28)

3.3.图像标签对应(one_hot_label)

        图像标签对应one_hot_label用于设置是否将标签保存为onehot表示(one-hot representation)。one-hot 表示是仅正确解标签为1,其余皆为0 的数组,就像[0,0,1,0,0,0,0,0,0,0]这样。当one_hot_label为False时,就是像7、2这样简单保存正确解标签,函数_change_one_hot_label的实现如下:

def _change_one_hot_label(X):
    T = np.zeros((X.size, 10))
    for idx, row in enumerate(T):
        row[X[idx]] = 1

    return T

        以上即为测试训练数据集加载函数的全部内容,我们将在下面正式调用一下看看是否能够正常工作,在此贴上函数全文:

ef load_mnist(normalize=True, flatten=True, one_hot_label=False):
    if not os.path.exists(save_file) :
        download_mnist()
        dataset = _convert_numpy()
        print("Creating pickle file ...")
        with open(save_file, 'wb') as f:
            pickle.dump(dataset, f, -1)
        print("Done!")

    with open(save_file,'rb') as f:
        dataset = pickle.load(f)

        if normalize:
            for key in ['train_img','test_img']:
                dataset[key] = dataset[key].astype(np.float32)

        if not flatten:
            for key in ('train_img', 'test_img'):
                dataset[key] = dataset[key].reshape(-1, 1, 28, 28)

        if one_hot_label:
            dataset['train_label'] = _change_one_hot_label(dataset['train_label'])
            dataset['test_label'] = _change_one_hot_label(dataset['test_label'])

    return (dataset['train_img'],dataset['train_label']),(dataset['test_img'],dataset['test_label'])

四.测试训练数据集的使用测试

        我们可以加载数据集并且查看到各个数据集的形状:

(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True,normalize=False)
# 输出各个数据的形状
print(x_train.shape) # (60000, 784)
print(t_train.shape) # (60000,)
print(x_test.shape) # (10000, 784)
print(t_test.shape) # (10000,)

        根据输出我们可以看到,训练集图片有六万张,每张图片有784各像素(28*28),训练集标签和照片数量一样(那是肯定的),测试集图片和标签数量比训练集的少,主要用来验证模型学习后的正确性。

        我们甚至还能随机从数据集中抽取一张照片查看一下实际样子,具体实现如下:

def img_show(img):
pil_img = Image.fromarray(np.uint8(img))
pil_img.show()
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True,normalize=False)
img = x_train[0]
label = t_train[0]
print(label) # 5
print(img.shape) # (784,)
img = img.reshape(28, 28) # 把图像的形状变成原来的尺寸
print(img.shape) # (28, 28)
img_show(img)

        输出的图片如图下所示:

        在后面的文章中,我们将开始正式步入主题,讲解神经网络如何学习,各层次之间如何传递数值,如何反向传导,计算损失,又在重新学习,最终实现传入一张手写数字就能自动识别出具体的数字的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/983527.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Linux】冯诺依曼体系与操作系统理解

🌟🌟作者主页:ephemerals__ 🌟🌟所属专栏:Linux 目录 前言 一、冯诺依曼体系结构 二、操作系统 1. 操作系统的概念 2. 操作系统存在的意义 3. 操作系统的管理方式 4. 补充:理解系统调用…

HTML-网页介绍

一、网页 1.什么是网页: 网站是指在因特网上根据一定的规则,使用 HTML 等制作的用于展示特定内容相关的网页集合。 网页是网站中的一“页”,通常是 HTML 格式的文件,它要通过浏览器来阅读。 网页是构成网站的基本元素&#xf…

STM32——GPIO介绍

GPIO(General-Purpose IO ports,通用输入/输出接口)模块是STM32的外设接口的核心部分,用于感知外界信号(输入模式)和控制外部设备(输出模式),支持多种工作模式和配置选项。 1、GPIO 基本结构 STM32F407 的每个 GPIO 引脚均可独立配置,主要特性包括: 9 组 GPIO 端口…

字节码是由什么组成的?

Java字节码是Java程序编译后的中间产物,它是一种二进制格式的代码,可以在Java虚拟机(JVM)上运行。理解字节码的组成有助于我们更好地理解Java程序的运行机制。 1. Java字节码是什么? 定义 Java字节码是Java源代码经过…

链表算法题目

1.两数相加 两个非空链表,分别表示两个整数,只不过是反着存储的,即先存储低位在存储高位。要求计算这两个链表所表示数的和,然后再以相同的表示方式将结果表示出来。如示例一:两个数分别是342和465,和为807…

blender学习25.3.8

【04-进阶篇】Blender材质及灯光Cycle渲染&后期_哔哩哔哩_bilibili 注意的问题 这一节有一个大重点就是你得打开显卡的渲染,否则cpu直接跑满然后渲染的还十分慢 在这里你要打开GPU计算,但是这还不够 左上角编辑,偏好设置,系…

【godot4.4】布局函数库Layouts

概述 为了方便编写一些自定义容器和控件、节点时方便元素布局,所以编写了一套布局的求取函数,统一放置在一个名为Layouts的静态函数库中。 本文介绍我自定义的一些布局计算和实现以及函数编写的思路,并提供完整的函数库代码(持续…

Windows下配置Conda环境路径

问题描述: 安装好Conda之后,创建好自己的虚拟环境,同时下载并安装了Pycharm,但在Pycharm中找不到自己使用Conda创建好的虚拟环境。显示“Conda executable is not found” 解决办法(依次尝试以下) 起初怀…

OpenHarmony子系统开发编译构建指导

OpenHarmony子系统开发编译构建指导 概述 OpenHarmony编译子系统是以GN和Ninja构建为基座,对构建和配置粒度进行部件化抽象、对内建模块进行功能增强、对业务模块进行功能扩展的系统,该系统提供以下基本功能: 以部件为最小粒度拼装产品和独…

leetcode日记(80)复原IP地址

只能说之前动态规划做多了,看到就想到动态规划,然后想想其实完全不需要,回溯法就行了。 一开始用了很多莫名其妙的代码,写的很复杂……(主要因为最后不能加‘.’)其实想想只要最后加入vector时去掉最后一个…

LINUX网络基础 [五] - HTTP协议

目录 HTTP协议 预备知识 认识 URL 认识 urlencode 和 urldecode HTTP协议格式 HTTP请求协议格式 HTTP响应协议格式 HTTP的方法 HTTP的状态码 ​编辑HTTP常见Header HTTP实现代码 HttpServer.hpp HttpServer.cpp Socket.hpp log.hpp Makefile Web根目录 H…

【A2DP】SBC 编解码器互操作性要求详解

目录 一、SBC编解码器互操作性概述 二、编解码器特定信息元素(Codec Specific Information Elements) 2.1 采样频率(Sampling Frequency) 2.2 声道模式(Channel Mode) 2.3 块长度(Block Length) 2.4 子带数量(Subbands) 2.5 分配方法(Allocation Method) 2…

电脑内存智能监控清理,优化性能的实用软件

软件介绍 Memory cleaner是一款内存清理软件。功能很强,效果很不错。 Memory cleaner会在内存用量超出80%时,自动执行“裁剪进程工作集”“清理系统缓存”以及“用全部可能的方法清理内存”等操作,以此来优化电脑性能。 同时,我…

基于multisim的花样彩灯循环控制电路设计与仿真

1 课程设计的任务与要求 (一)、设计内容: 设计一个8路移存型彩灯控制器,基本要求: 1. 8路彩灯能演示至少三种花型(花型自拟); 2. 彩灯用发光二极管LED模拟; 3. 选做…

Axure常用变量及使用方法详解

点击下载《Axure常用变量及使用方法详解.pdf》 摘要 Axure RP 作为一款领先的前端原型设计工具,提供了全面的 变量 和 函数 系统,以支持复杂的交互设计和动态内容展示。本文将从专业角度详细解析 Axure 中的 全局变量、中继器数据集变量/函数、元件变量…

MySql的安装及数据库的基本操作命令

1.MySQL的安装 1.1进入MySQL官方网站 1.2点击下载 1.3下拉选择MySQL社区版 1.4选择你需要下载的版本及其安装的系统和下载方式 直接安装以及压缩包 建议选择8.4.4LST LST表明此版本为长期支持版 新手建议选择红框勾选的安装方式 1.5 安装包下载完毕之后点击安装 2.数据库…

树莓派5首次开机保姆级教程(无显示器通过VNC连接树莓派桌面)

第一次开机详细步骤 步骤一:树莓派系统烧录1 搜索打开烧录软件“Raspberry Pi Imager”2 选择合适的设备、系统、SD卡3 烧录配置选项 步骤二:SSH远程树莓派1 树莓派插电2 网络连接(有线或无线)3 确定树莓派IP地址 步骤三&#xff…

分布式锁—7.Curator的分布式锁

大纲 1.Curator的可重入锁的源码 2.Curator的非可重入锁的源码 3.Curator的可重入读写锁的源码 4.Curator的MultiLock源码 5.Curator的Semaphore源码 1.Curator的可重入锁的源码 (1)InterProcessMutex获取分布式锁 (2)InterProcessMutex的初始化 (3)InterProcessMutex.…

电脑睡眠智能管控:定时、依状态灵活调整,多模式随心选

软件介绍 今天要给大家介绍一款十分实用的软件——DontSleep,从名字就能看出,它的核心功能之一是阻止电脑进入睡眠状态,不过它的能耐可远不止于此,还具备强大的定时以及根据电脑实时状况灵活调整的功能。 这款软件采用绿色单文件…

装饰器模式--RequestWrapper、请求流request无法被重复读取

目录 前言一、场景二、原因分析三、解决四、更多 前言 曾经遇见这么一段代码,能看出来是把request又重新包装了一下,核心信息都不会改变 后面了解到这叫 装饰器模式(Decorator Pattern) :也称为包装模式(Wrapper Pat…