基于飞桨paddle的极简方案构建手写数字识别模型测试代码

基于飞桨paddle的极简方案构建手写数字识别模型测试代码
在这里插入图片描述
原始测试图片为255X252的图片
因为是极简方案采用的是线性回归模型,所以预测结果数字不一致
本次预测的数字是 [[3]]
测试结果:

PS E:\project\python> & D:/Python39/python.exe e:/project/python/MNIST.py
10.0.0
2.4.2
图像数据形状和对应数据为: (28, 28)
图像标签形状和对应数据为: (1,) [5]

打印第一个batch的第一个图像,对应标签数字为[5]
epoch_id: 0, batch_id: 0, loss is: [34.4626]
epoch_id: 0, batch_id: 1000, loss is: [7.599941]
epoch_id: 0, batch_id: 2000, loss is: [4.583123]
epoch_id: 0, batch_id: 3000, loss is: [2.8974648]
epoch_id: 1, batch_id: 0, loss is: [3.610869]
epoch_id: 1, batch_id: 1000, loss is: [5.6290216]
epoch_id: 1, batch_id: 2000, loss is: [1.9465038]
epoch_id: 1, batch_id: 3000, loss is: [2.1046467]
epoch_id: 7, batch_id: 2000, loss is: [4.63013]
epoch_id: 7, batch_id: 3000, loss is: [4.4638147]
epoch_id: 8, batch_id: 0, loss is: [3.0043283]
epoch_id: 8, batch_id: 1000, loss is: [1.633965]
epoch_id: 8, batch_id: 2000, loss is: [3.1906333]
epoch_id: 8, batch_id: 3000, loss is: [2.4461133]
epoch_id: 9, batch_id: 0, loss is: [3.9595613]
epoch_id: 9, batch_id: 1000, loss is: [1.3417265]
epoch_id: 9, batch_id: 2000, loss is: [2.3505783]
epoch_id: 9, batch_id: 3000, loss is: [2.0194921]
原始图像shape:  (252, 255)
采样后图片shape:  (28, 28)
result Tensor(shape=[1, 1], dtype=float32, place=Place(cpu), stop_gradient=False,
       [[3.94108272]])
本次预测的数字是 [[3]]
PS E:\project\python>

测试代码如下所示:

#加载飞桨和相关类库
import paddle
from paddle.nn import Linear
import paddle.nn.functional as F
import os
import numpy as np
import matplotlib.pyplot as plt
# 导入图像读取第三方库
from PIL import Image,ImageFilter
print(Image.__version__)    #10.0.0
#原来是在pillow的10.0.0版本中,ANTIALIAS方法被删除了,使用新的方法即可Image.LANCZOS
#或降级版本为9.5.0,安装pip install Pillow==9.5.0
print(paddle.__version__)   #2.4.2

#飞桨提供了多个封装好的数据集API,涵盖计算机视觉、自然语言处理、推荐系统等多个领域,
# 帮助读者快速完成深度学习任务。
# 如在手写数字识别任务中,
# 通过paddle.vision.datasets.MNIST可以直接获取处理好的MNIST训练集、测试集,
# 飞桨API支持如下常见的学术数据集:
'''
mnist
cifar
Conll05
imdb
imikolov
movielens
sentiment
uci_housing
wmt14
wmt16
'''

#数据处理
# 设置数据读取器,API自动读取MNIST数据训练集
train_dataset = paddle.vision.datasets.MNIST(mode='train')

train_data0 = np.array(train_dataset[0][0])
train_label_0 = np.array(train_dataset[0][1])

# 显示第一batch的第一个图像
'''
import matplotlib.pyplot as plt
plt.figure("Image") # 图像窗口名称
plt.figure(figsize=(2,2))
plt.imshow(train_data0, cmap=plt.cm.binary)
plt.axis('on') # 关掉坐标轴为 off
plt.title('image') # 图像题目
plt.show()
'''

print("图像数据形状和对应数据为:", train_data0.shape)                          #(28, 28)
print("图像标签形状和对应数据为:", train_label_0.shape, train_label_0)         #(1,) [5]
print("\n打印第一个batch的第一个图像,对应标签数字为{}".format(train_label_0))   # [5]

#飞桨将维度是28×28的手写数字图像转成向量形式存储,
# 因此使用飞桨数据加载器读取到的手写数字图像是长度为784(28×28)的向量。

#模型设计
#模型的输入为784维(28×28)数据,输出为1维数据,

# 定义mnist数据识别网络结构,同房价预测网络
#===========================================
class MNIST(paddle.nn.Layer):
    def __init__(self):
        super(MNIST, self).__init__()
        
        # 定义一层全连接层,输出维度是1
        self.fc = paddle.nn.Linear(in_features=784, out_features=1)
        
    # 定义网络结构的前向计算过程
    def forward(self, inputs):
        outputs = self.fc(inputs)
        return outputs
#===========================================

#训练配置
# 声明网络结构
model = MNIST()
def train(model):
    # 启动训练模式
    model.train()
    # 加载训练集 batch_size 设为 16
    train_loader = paddle.io.DataLoader(paddle.vision.datasets.MNIST(mode='train'), 
                                        batch_size=16, 
                                        shuffle=True)
    # 定义优化器,使用随机梯度下降SGD优化器,学习率设置为0.001
    opt = paddle.optimizer.SGD(learning_rate=0.001, parameters=model.parameters())
#===========================================
# 图像归一化函数,将数据范围为[0, 255]的图像归一化到[0, 1]
def norm_img(img):
    # 验证传入数据格式是否正确,img的shape为[batch_size, 28, 28]
    assert len(img.shape) == 3
    batch_size, img_h, img_w = img.shape[0], img.shape[1], img.shape[2]
    # 归一化图像数据
    img = img / 255
    # 将图像形式reshape为[batch_size, 784]
    img = paddle.reshape(img, [batch_size, img_h*img_w])
    
    return img  
#===========================================   
import paddle
# 确保从paddle.vision.datasets.MNIST中加载的图像数据是np.ndarray类型
paddle.vision.set_image_backend('cv2')

# 声明网络结构
model = MNIST()
#===========================================
def run(model):
    # 启动训练模式
    model.train()
    # 加载训练集 batch_size 设为 16
    train_loader = paddle.io.DataLoader(paddle.vision.datasets.MNIST(mode='train'), 
                                        batch_size=16, 
                                        shuffle=True)
    # 定义优化器,使用随机梯度下降SGD优化器,学习率设置为0.001
    opt = paddle.optimizer.SGD(learning_rate=0.001, parameters=model.parameters())
    EPOCH_NUM = 10
    for epoch in range(EPOCH_NUM):
        for batch_id, data in enumerate(train_loader()):
            images = norm_img(data[0]).astype('float32')
            labels = data[1].astype('float32')
            
            #前向计算的过程
            predicts = model(images)
            
            # 计算损失
            loss = F.square_error_cost(predicts, labels)
            avg_loss = paddle.mean(loss)
            
            #每训练了1000批次的数据,打印下当前Loss的情况
            if batch_id % 1000 == 0:
                print("epoch_id: {}, batch_id: {}, loss is: {}".format(epoch, batch_id, avg_loss.numpy()))
            
            #后向传播,更新参数的过程
            avg_loss.backward()
            opt.step()
            opt.clear_grad()
#===========================================
#调用训练            
run(model)
paddle.save(model.state_dict(), './mnist.pdparams')  

#模型测试

#===========================================
def showImage(im):
  #img_path = 'example_0.jpg'
  # 读取原始图像并显示
  #im = Image.open('example_0.jpg')
  plt.imshow(im)
  plt.show()
  # 将原始图像转为灰度图
  im = im.convert('L')
  print('原始图像shape: ', np.array(im).shape)
  # 使用Image.ANTIALIAS方式采样原始图片
  im = im.resize((28, 28), Image.LANCZOS)
  plt.imshow(im)
  plt.show()
  print("采样后图片shape: ", np.array(im).shape)
#===========================================
im = Image.open('example_0.jpg')
showImage(im)

# 读取一张本地的样例图片,转变成模型输入的格式
#=========================================== 
def load_image(img_path):
    # 从img_path中读取图像,并转为灰度图
    im = Image.open(img_path).convert('L')
    # print(np.array(im))
    im = im.resize((28, 28), Image.LANCZOS)
    im = np.array(im).reshape(1, -1).astype(np.float32)
    # 图像归一化,保持和数据集的数据范围一致
    im = 1 - im / 255
    return im
#=========================================== 
# 定义预测过程
def test():
  model = MNIST()
  params_file_path = 'mnist.pdparams'
  img_path = 'example_0.jpg'
  # 加载模型参数
  param_dict = paddle.load(params_file_path)
  model.load_dict(param_dict)
  # 灌入数据
  model.eval()
  tensor_img = load_image(img_path)  
  result = model(paddle.to_tensor(tensor_img))
  print('result',result)
  #  预测输出取整,即为预测的数字,打印结果
  print("本次预测的数字是", result.numpy().astype('int32'))
#=========================================== 
test(); 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/48973.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

性能提升,SpringBoot 3.2虚拟线程来了

spring boot 3.2 会提供默认支持,必须Java19。 在以往的项目中,我们面临了这样一种情况:我们收到了数千个认证请求。为了确保安全性,我们依靠第三方系统发送短信 OTP 进行验证。然而,有时候第三方系统花费的时间比预期…

Tesseract开源的OCR工具及python pytesseract安装使用

一 、介绍 Tesseract是一款由Google赞助的开源OCR。 pytesseract是python包装器,它为可执行文件提供了pythonic API。 Tesseract 已经有 30 年历史,开始它是惠普实验室的一款专利软件,在2005年后由Google接手并进一步开发和完善。Tesseract支…

高压放大器模块的作用是什么呢

高压放大器模块是一种集成了高压放大器芯片、控制电路、保护电路等多种元件和功能的模块化设备。它可以将输入信号进行放大处理,并输出到负载上,具有高性能、高可靠性、高稳定性等优点。下面安泰电子将详细介绍高压放大器模块的作用: 信号放大…

如何学好Java并调整学习过程中的心态:学习之路的秘诀

文章目录 第一步:建立坚实的基础实例分析:选择合适的学习路径 第二步:选择合适的学习资源实例分析:参与编程社区 第三步:动手实践实例分析:开发个人项目 调整学习过程中的心态1. 不怕失败2. 持续学习3. 寻求…

外观模式-提供统一入口

在办理流量套餐的时候,我们可以去官网查查询套餐,找出符合我们需求的套餐,然后进行办理。官网是充斥着各种各样的套餐,如果我们一个个去查看这些套餐,耗费的时间很多。而且在办理套餐的时候,步骤也比较多。…

Linux工具【1】(编辑器vim、编译器gcc与g++)

vim详解 引言vimVim的三种模式及模式切换普通模式下操作底行模式下操作 gcc与ggcc的使用(g类似)预编译编译汇编链接静态库与动态库 总结 引言 vim(vi improved)编辑器是从 vi 发展出来的一个文本编辑器。 代码补全、编译及错误跳…

03-树1 树的同构(c++)

03-树1 树的同构 给定两棵树T1和T2。如果T1可以通过若干次左右孩子互换就变成T2,则我们称两棵树是“同构”的。例如图1给出的两棵树就是同构的,因为我们把其中一棵树的结点A、B、G的左右孩子互换后,就得到另外一棵树。而图2就不是同构的。 现…

MySQL优化(面试)

文章目录 通信优化查询缓存语法解析及查询优化器查询优化器的策略 性能优化建议数据类型优化索引优化 优化关联查询优化limit分页对于varchar end mysql查询过程: 客户端向MySQL服务器发送一条查询请求服务器首先检查查询缓存,如果命中缓存,则立刻返回存…

基于Javaweb实现ATM机系统开发实战(十五)退卡和转账跳转实现

首先创建一个servlet接受和处理请求: package com.atm.servlet;import javax.servlet.*; import javax.servlet.http.*; import javax.servlet.annotation.*; import java.io.IOException;//用户退出 WebServlet("/logout") public class ExitServlet ex…

CSDN浏览如何解决

一、对于平时我们苦恼csdn数据不够好看 当面试等各个场合需要我们装*或者秀技术无法拿出亮眼的时候,刚好我闲时间编译的在线模块适合你 二、如何操作(虚拟平台我已给大家放到最后直接使用即可) 重点:pc端必须拥有python环境 win…

JGIT获取远程仓库、本地仓库提交版本号

https://www.freesion.com/article/50181381474/ JGIT获取远程仓库、本地仓库提交版本号 一、环境搭建二、项目结构二、代码部分 GitUtils.javaGitInfoAtom.java三、运行结果&#xff1a;总结 一、环境搭建 Maven依赖导入 <dependency><groupId>org.eclipse.jg…

图像滤波器

图像噪声 • 图像噪声是图像在获取或是传输过程中受到随机信号干扰&#xff0c;妨碍人们对图像理解及分析处理 的信号。 • 图像噪声的产生来自图像获取中的环境条件和传感元器件自身的质量&#xff0c;图像在传输过程中产 生图像噪声的主要因素是所用的传输信道受到了噪声…

【深度学习】基于图形的机器学习:概述

一、说明 图神经网络&#xff08;GNN&#xff09;在数据科学和机器学习中越来越受到关注&#xff0c;但在专家圈之外仍然知之甚少。为了掌握这种令人兴奋的方法&#xff0c;我们必须从更广泛的图形机器学习&#xff08;GML&#xff09;领域开始。许多在线资源谈论GNN和GML&…

DP83867IS SGMII eye diagram问题调试记录

1. 前言 使用的是带CPU的DP83867IS,通过SGMII接口 从PHY到CPU的眼图看起来很好 而从CPU到PHY的眼图很差 2. 问题梳理 (1)能证实SGMII道有100欧姆的阻抗吗? (2)能不能做一个误码率测试来看看眼图是否仍然是可以接受的? (3)因为从PHY到CPU的眼睛是好的,可能有一个…

Oracle 最高安全架构

​在当今世界中&#xff0c;数据库是存储敏感信息的宝贵资料库&#xff0c;攻击者总是在寻找目标。这导致网络安全威胁的增加&#xff0c;因此有必要采取适当的保护措施。Oracle Maximum Security Architecture&#xff08;MSA&#xff09;就是一种提供数据库端到端安全的解决方…

MYSQL 主从复制

在读多写少的网络环境下&#xff0c;MySQL 如何优化数据查询方案 假如说一个电商平台 到双十一了 大量的读写操作 如果不做点什么的话 平台就被冲烂了 那我们要怎么办呢? 你或许会想 林北直接一个redis缓存 帮数据库度过难关 这个操作实际上是不行的 因为应用缓存的原则之一…

【开发环境】Windows下搭建TVM编译器

关于搭建TVM编译器的官方文档&#xff1a;Install from Source — tvm 0.14.dev0 documentation (apache.org) 1. 安装Anaconda 首先我们需要安装Anaconda&#xff0c;因为其中包含着我们所需要的各类依赖&#xff1a; 进入Anaconda官网https://www.anaconda.com/products/d…

【Spring Cloud Alibaba】Sentinel运行原理

文章目录 前言1、基本原理2、SphU.entry()2.1、StringResourceWrapper2.2、Entry 3、entry.exit()4、Context 前言 本文基于sentinel-1.8.0版本 Sentinel 是面向分布式服务架构的流量控制组件&#xff0c;主要以流量为切入点&#xff0c;从限流、流量整形、熔断降级、系统负载保…

DoIP学习笔记系列:导航篇

文章目录 1. 前言2. 导航3. 参考资料 1. 前言 DoIP学习笔记系列是一整套基于网络的诊断协议学习笔记&#xff0c;非常适合对有UDS基础但对DoIP没有实战经验的小伙伴参考&#xff0c;通过源协议讲解&#xff0c;企标讲解&#xff0c;测试需求讲解&#xff0c;测试用例讲解&…

STM32CubeMX配置STM32G031多通道ADC采集(HAL库开发)

时钟配置HSI主频配置64M 勾选打开8个通道的ADC 使能连续转换模式 配置好串口&#xff0c;选择异步模式配置好需要的开发环境并获取代码 修改main.c 串口重定向 #include "stdio.h" int fputc(int ch, FILE *f) {HAL_UART_Transmit(&huart1, (uint8_t *)&ch…