深度学习手写字符识别:训练模型

说明

本篇博客主要是跟着B站中国计量大学杨老师的视频实战深度学习手写字符识别。
第一个深度学习实例手写字符识别

深度学习环境配置

可以参考下篇博客,网上也有很多教程,很容易搭建好深度学习的环境。
Windows11搭建GPU版本PyTorch环境详细过程

数据集

手写字符识别用到的数据集是MNIST数据集(Mixed National Institute of Standards and Technology database);MNIST是一个用来训练各种图像处理系统二进制图像数据集,广泛应用到机器学习中的训练和测试。
作为一个入门级的计算机视觉数据集,发布20多年来,它已经被无数机器学习入门者应用无数遍,是最受欢迎的深度学习数据集之一。

序号说明
发布方National Institute of Standards and Technology(美国国家标准技术研究所,简称NIST)
发布时间1998
背景该数据集的论文想要证明在模式识别问题上,基于CNN的方法可以取代之前的基于手工特征的方法,所以作者创建了一个手写数字的数据集,以手写数字识别作为例子证明CNN在模式识别问题上的优越性。
简介MNIST数据集是从NIST的两个手写数字数据集:Special Database 3 和Special Database 1中分别取出部分图像,并经过一些图像处理后得到的。MNIST数据集共有70000张图像,其中训练集60000张,测试集10000张。所有图像都是28×28的灰度图像,每张图像包含一个手写数字。

跟着视频跑源码

  1. 下载源码:mivlab/AI_course (github.com)
  2. 下载数据集:https://opendatalab.com/MNIST;网上下载的地址比较多,也可以直接下载B站中国计量大学杨老师的百度网盘位置里的MNIST。

运行源码

  1. 在Pycharm中打开AI_course项目,运行classify_pytorch文件目录里train_mnist.py的Python文件。
    在这里插入图片描述
    train_mnist.py具体的源码如下:
import torch
import math
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms, models
import argparse
import os
from torch.utils.data import DataLoader

from dataloader import mnist_loader as ml
from models.cnn import Net
from toonnx import to_onnx


parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--datapath', required=True, help='data path')
parser.add_argument('--batch_size', type=int, default=256, help='training batch size')
parser.add_argument('--epochs', type=int, default=300, help='number of epochs to train')
parser.add_argument('--use_cuda', default=False, help='using CUDA for training')

args = parser.parse_args()
args.cuda = args.use_cuda and torch.cuda.is_available()
if args.cuda:
    torch.backends.cudnn.benchmark = True


def train():
    os.makedirs('./output', exist_ok=True)
    if True: #not os.path.exists('output/total.txt'):
        ml.image_list(args.datapath, 'output/total.txt')
        ml.shuffle_split('output/total.txt', 'output/train.txt', 'output/val.txt')

    train_data = ml.MyDataset(txt='output/train.txt', transform=transforms.ToTensor())
    val_data = ml.MyDataset(txt='output/val.txt', transform=transforms.ToTensor())
    train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(dataset=val_data, batch_size=args.batch_size)

    model = Net(10)
    #model = models.vgg16(num_classes=10)
    #model = models.resnet18(num_classes=10)  # 调用内置模型
    #model.load_state_dict(torch.load('./output/params_10.pth'))
    #from torchsummary import summary
    #summary(model, (3, 28, 28))

    if args.cuda:
        print('training with cuda')
        model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-3)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [20, 30], 0.1)
    loss_func = nn.CrossEntropyLoss()

    for epoch in range(args.epochs):
        # training-----------------------------------
        model.train()
        train_loss = 0
        train_acc = 0
        for batch, (batch_x, batch_y) in enumerate(train_loader):
            if args.cuda:
                batch_x, batch_y = Variable(batch_x.cuda()), Variable(batch_y.cuda())
            else:
                batch_x, batch_y = Variable(batch_x), Variable(batch_y)
            out = model(batch_x)  # 256x3x28x28  out 256x10
            loss = loss_func(out, batch_y)
            train_loss += loss.item()
            pred = torch.max(out, 1)[1]
            train_correct = (pred == batch_y).sum()
            train_acc += train_correct.item()
            print('epoch: %2d/%d batch %3d/%d  Train Loss: %.3f, Acc: %.3f'
                  % (epoch + 1, args.epochs, batch, math.ceil(len(train_data) / args.batch_size),
                     loss.item(), train_correct.item() / len(batch_x)))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()  # 更新learning rate
        print('Train Loss: %.6f, Acc: %.3f' % (train_loss / (math.ceil(len(train_data)/args.batch_size)),
                                               train_acc / (len(train_data))))

        # evaluation--------------------------------
        model.eval()
        eval_loss = 0
        eval_acc = 0
        for batch_x, batch_y in val_loader:
            if args.cuda:
                batch_x, batch_y = Variable(batch_x.cuda()), Variable(batch_y.cuda())
            else:
                batch_x, batch_y = Variable(batch_x), Variable(batch_y)

            out = model(batch_x)
            loss = loss_func(out, batch_y)
            eval_loss += loss.item()
            pred = torch.max(out, 1)[1]
            num_correct = (pred == batch_y).sum()
            eval_acc += num_correct.item()
        print('Val Loss: %.6f, Acc: %.3f' % (eval_loss / (math.ceil(len(val_data)/args.batch_size)),
                                             eval_acc / (len(val_data))))
        # 保存模型。每隔多少帧存模型,此处可修改------------
        if (epoch + 1) % 1 == 0:
            # torch.save(model, 'output/model_' + str(epoch+1) + '.pth')
            torch.save(model.state_dict(), 'output/params_' + str(epoch + 1) + '.pth')
            #to_onnx(model, 3, 28, 28, 'params.onnx')

if __name__ == '__main__':
    train()

  1. 报错:没有cv2,即没有安装OpenCV库。
    在这里插入图片描述
  2. 安装OpenCV库,可以命令行安装,也可以Pycharm中安装。
  • 命令行激活虚拟环境:conda activate deeplearning
  • 命令行安装: pip install opencv-python(也可以Pycharm中下载,可能上梯子安装更快)
    在这里插入图片描述
  1. 再次运行,出现如下图提示,表明需要将下载好的数据集配置到configure中。
    在这里插入图片描述
  2. 加载下载好的数据集,即--datapath=数据集的路径
    在这里插入图片描述
  3. 点击“Run”,开始训练,损失和准确率在一直更新,持续训练,直到模型完成,未改动源码的情况下,训练时间可能需要较长。
    在这里插入图片描述
  4. 在小编的拯救者笔记本电脑上持续训练了10小时才完成最终的模型训练,可以看到训练损失已经很低了,准确度很高水平。
    在这里插入图片描述
  5. 在项目中output文件夹中可以看到已经训练好了很多模型;后面可以利用模型进行推理了。
    在这里插入图片描述

参考

https://zhuanlan.zhihu.com/p/681236488

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/368907.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【数据分析】Excel中的常用函数公式总结

目录 0 引用方式0.1 相对引用0.2 绝对引用0.3 混合引用0.4 3D引用0.5 命名引用 1 基础函数1.1 加法、减法、乘法和除法1.2 平均数1.3 求和1.4 最大值和最小值 2 文本函数2.1 合并单元格内容2.2 查找2.3 替换 3 逻辑函数3.1 IF函数3.2 AND和OR函数3.3 IFERROR函数 4 统计函数4.1…

java设计模式:策略模式

在平常的开发工作中,经常会用到不同的设计模式,合理的使用设计模式,可以提高开发效率,提高代码质量,提高代码的可拓展性和维护性。今天来聊聊策略模式。 策略模式是一种行为型设计模式,运行时可以根据需求动…

分布式session 笔记

概念 解决方案‘ 复制 session同步&#xff0c;让集群下的服务器进行session同步&#xff0c;一种传统的服务器集群session管理机制&#xff0c;常用于服务器不多的集群环境。<br /> 集群下&#xff0c;进行session同步的服务器的session数据是相同的&#xff0c;…

vulhub中spring的CVE-2022-22947漏洞复现

Spring Cloud Gateway是Spring中的一个API网关。其3.1.0及3.0.6版本&#xff08;包含&#xff09;以前存在一处SpEL表达式注入漏洞&#xff0c;当攻击者可以访问Actuator API的情况下&#xff0c;将可以利用该漏洞执行任意命令。 参考链接&#xff1a; https://tanzu.vmware.c…

图论练习1

内容&#xff1a;&#xff0c;拆点&#xff0c;分层&#xff0c;传递&#xff0c;带限制的最小生成树 [HNOI2015]菜肴制作 题目链接 题目大意 有个限制&#xff0c;号菜肴在号前完成在满足限制的条件下&#xff0c;按照出菜( 是为了满足的限制 ) 解题思路 由限制&#xf…

寒假 day1

1、请简述栈区和堆区的区别? 2、有一个整形数组:int arr[](数组的值由外部输入决定)&#xff0c;一个整型变量: x(也 由外部输入决定)。要求: 1)删除数组中与x的值相等的元素 2)不得创建新的数组 3)最多只允许使用单层循环 4)无需考虑超出新数组长度后面的元素&#xff0c;所以…

2024美赛数学建模D题思路分析 - 大湖区水资源问题

1 赛题 问题D&#xff1a;大湖区水资源问题 背景 美国和加拿大的五大湖是世界上最大的淡水湖群。这五个湖泊和连接的水道构成了一个巨大的流域&#xff0c;其中包含了这两个国家的许多大城市地区&#xff0c;气候和局部天气条件不同。 这些湖泊的水被用于许多用途&#xff0…

【数据分享】1929-2023年全球站点的逐日降雪深度数据(Shp\Excel\免费获取)

气象数据是在各项研究中都经常使用的数据&#xff0c;气象指标包括气温、风速、降水、能见度等指标&#xff0c;说到气象数据&#xff0c;最详细的气象数据是具体到气象监测站点的数据&#xff01; 之前我们分享过1929-2023年全球气象站点的逐日平均气温数据、逐日最高气温数据…

二维图像生成 3D 场景:nerfstudio 帮你简化流程 | 开源日报 No.164

nerfstudio-project/nerfstudio Stars: 7.7k License: Apache-2.0 nerfstudio 是一个友好的 NeRFs 协作工作室。 该项目旨在简化创建、训练和测试 NeRFs 的端到端流程&#xff0c;支持更模块化的 NeRFs 实现&#xff0c;并提供了简单的 API。 其主要功能和优势包括&#xff1…

DS:经典算法OJ题(2)

创作不易&#xff0c;友友们给个三连吧&#xff01;&#xff01; 一、旋转数组&#xff08;力扣&#xff09; 经典算法OJ题&#xff1a;旋转数组 思路1&#xff1a;每次挪动1位&#xff0c;右旋k次 时间复杂度&#xff1a;o(N^2) 右旋最好情况&#xff1a;k是n的倍数…

AI大模型专题:企业大模型市场厂商评估报告:滴普科技

今天分享的是AI大模型系列深度研究报告&#xff1a;《AI大模型专题&#xff1a;企业大模型市场厂商评估报告&#xff1a;滴普科技》。 &#xff08;报告出品方&#xff1a;滴普科技&#xff09; 报告共计&#xff1a;22页 研究范围定义 大模型是指通过在海量数据上依托强大算…

ctfshow web-77

开启环境: 先直接用伪协议获取 flag 位置。 c?><?php $anew DirectoryIterator("glob:///*"); foreach($a as $f) {echo($f->__toString(). );} exit(0); ?> 发现 flag36x.txt 文件。同时根目录下还有 readflag&#xff0c;估计需要调用 readflag 获…

海外YouTube视频点赞刷单悬赏任务投资理财源码/tiktok国际版刷单理财

测试环境&#xff1a;Linux系统CentOS7.6、宝塔、PHP7.3、MySQL5.7&#xff0c;根目录public&#xff0c;伪静态Laravel5&#xff0c;开启SSL证书 前端&#xff1a;修改网站的默认文档 index.html 为第一个&#xff0c; index.php 改成第二个 &#xff0c;或者前端访问 index.…

【问题解决】VSCode1.86.0版+拓展Remote-SSHv0.108 无法连接到VSCode服务器(VSCode无法远程连接到Linux)

作者在下午五点左右照常通过VSCode远程连接Ubuntu主机编写代码&#xff0c;突然在一次重启VSCode后&#xff0c;不管如何尝试都连接不到Linux主机&#xff0c;首先自己尝试了重启电脑&#xff0c;无效&#xff1b;然后尝试通过其他软件&#xff08;XShell和Xftp&#xff09;远程…

赎金信[简单]

优质博文&#xff1a;IT-BLOG-CN 一、题目 给你两个字符串&#xff1a;ransomNote和magazine&#xff0c;判断ransomNote能不能由magazine里面的字符构成。如果可以&#xff0c;返回true&#xff1b;否则返回false。magazine中的每个字符只能在ransomNote中使用一次。 示例 …

从 20 多套 MySQL 到 1 套 TiDB丨骏伯网络综合运营管理平台应用实践

原文来源&#xff1a; https://tidb.net/blog/a38c72a4 本文作者&#xff1a;骏伯网络 唐帆&#xff0c;PingCAP 贺美存 骏伯网络简介 广州骏伯网络是一家以数据驱动的科技公司&#xff0c;聚焦移动互联网营销服务&#xff0c;坚持以客户为中心&#xff0c;深耕 APP、运营…

第01课:自动驾驶概述

文章目录 1、无人驾驶行业概述什么是无人驾驶智慧出行大趋势无人驾驶能解决什么问题行业趋势无人驾驶的发展历程探索阶段&#xff08;2004年以前&#xff09;发展阶段&#xff08;2004年-2016年&#xff09;成熟阶段&#xff08;2016年以后&#xff09; 2、无人驾驶技术路径无人…

uniapp开发一个交流社区小程序

uniapp开发一个交流社区小程序 假期的时候简单学了一下uniapp&#xff0c;想开发一款类似百度贴吧的交流社区来练练手。本篇文章主要记录开发过程&#xff0c;文末附上项目地址。 主要需要开发以下几个页面。 信息页面热榜页面用户主页用户信息页 信息页面 该页面的功能主要…

Leetcode的AC指南 —— 栈与队列 :150. 逆波兰表达式求值

摘要&#xff1a; **Leetcode的AC指南 —— 栈与队列 &#xff1a;150. 逆波兰表达式求值 **。题目介绍&#xff1a;给你一个字符串数组 tokens &#xff0c;表示一个根据 逆波兰表示法 表示的算术表达式。 请你计算该表达式。返回一个表示表达式值的整数。 文章目录 一、题目…

微信小程序(三十一)本地同步存储API

注释很详细&#xff0c;直接上代码 上一篇 新增内容&#xff1a; 1.存储数据 2.读取数据 3.删除数据 4.清空数据 源码&#xff1a; index.wxml <!-- 列表渲染基础写法&#xff0c;不明白的看上一篇 --> <view class"students"><view class"item…