pytorch车牌识别

目录

  • 使用pytorch库中CNN模型进行图像识别
    • 收集数据集
    • 定义CNN模型
      • 卷积层
      • 池化层
      • 全连接层
    • CNN模型代码
    • 使用模型

在这里插入图片描述

使用pytorch库中CNN模型进行图像识别

收集数据集

可以去找开源的数据集或者自己手做一个
最终整合成 类别分类的图片文件
在这里插入图片描述

定义CNN模型

卷积层

功能:提取特征

概念

  1. 卷积层输入层通道数

如果输入数据是彩色图像,那么通常情况下,输入数据具有三个通道(红、绿、蓝),因此第一个卷积层的输入通道数应该为3。
如果输入数据是灰度图像,那么输入通道数通常为 1。

  1. 卷积层输出层通道数

卷积层的输出通道数控制着该层提取的特征的数量和复杂度。更多的输出通道意味着网络可以学习更多种类的特征,但过多的输出通道数会导致复杂度和过拟合。

池化层

功能:使卷积层的特征更加明显,对图像进行降维压缩(舍弃无关特征,避免过拟合),提高神经网络的泛华能力。
问题:

  1. 最大池化操作

最大池化操作是一种常用的池化操作,用于减少特征图的空间维度并保留最重要的特征信息
在这里插入图片描述

# 定义最大池化层,池化窗口大小为 2x2,步幅为 2
max_pool_layer = nn.MaxPool2d(kernel_size=2, stride=2)

全连接层

将特征进行整合,然后归一化,对各种分类情况都输入一个概率,根据概率进行分类

CNN模型代码

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Dataset
# 进度条工具
from tqdm import tqdm

# 数据集中的类别数
num_classes = len(os.listdir('./数据集'))
# 训练的轮数
num_epochs = 10
# 30次:['陕', '陕', 'U', 'U', '6', '6', '6', '6']
# 10次:['陕', 'A', 'D', '0', '6', '6', '6', '6']

# 一、定义数据预处理和数据加载器
transform = transforms.Compose([
    # 固定图像大小
    transforms.Resize((64, 64)),
    # 将图像转换为灰度图像
    transforms.Grayscale(),
    # 将图像转换为张量
    transforms.ToTensor(),
])
# 使用ImageFolder定义数据集,标签为序号
train_dataset = ImageFolder(root='./数据集', transform=transform)
# 数据加载器,每个批次包含32张图像
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)


# 二、定义 CNN 模型
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        # 卷积层1  1代表单通道,黑白;32代表输出通道;3代表3*3的卷积核, 1代表在最外围补一圈0
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        # 池化层1  最大池化操作,2代表尺寸减半
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # 卷积层2 ,32对于卷积层1的输出通道数
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # 全连接层 64输出通道数,16*16代表压缩后的尺寸,生成长度128向量
        self.fc1 = nn.Linear(64 * 16 * 16, 128)
        self.fc2 = nn.Linear(128, num_classes)

    # 前向传播 返回输出结果
    def forward(self, x):
        # 卷积1
        x = self.conv1(x)
        # 激活函数/激化函数 引入非线性变化,增强神经网络复杂性
        x = torch.relu(x)
        # 池化
        x = self.pool(x)
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 16 * 16)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# 三、初始化模型、损失函数和优化器
model = CNNModel()
criterion = nn.CrossEntropyLoss()
# 学习率一般设0.01
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 四、只要当主文件运行时候,才训练模型
if __name__ == "__main__":
    for epoch in range(num_epochs):
        running_loss = 0.0
        print(f'Epoch : {epoch + 1}/{num_epochs}')
        # 显示每轮的进度条
        for images, labels in tqdm(train_loader):
            #  将优化器中存储的之前计算的梯度归零
            optimizer.zero_grad()
            # 将输入图像数据 images 输入到模型中进行前向传播,得到模型的输出
            outputs = model(images)
            # 损失函数 criterion 计算模型 输出 与 真实标签 之间的损失值。
            loss = criterion(outputs, labels)
            # 对损失值进行反向传播,计算模型参数的梯度
            loss.backward()
            # 据优化算法(梯度下降)更新模型参数,最小化损失函数
            optimizer.step()
            running_loss += loss.item()

        # 输出每个 epoch 的平均损失
        epoch_loss = running_loss / len(train_loader)
        print(f'Epoch {epoch + 1} loss: {epoch_loss:.4f}')

        # 保存模型
        torch.save(model.state_dict(), 'cnn_model.pt')

使用模型

import torch
from PIL import Image
from torch.utils.data import dataset
from cnn_model import transform, train_dataset, CNNModel

# 加载整个模型
model = CNNModel()
# 将模型设置为评估模式
model.eval()
checkpoint = torch.load('./cnn_model.pt')
model.load_state_dict(checkpoint)


# 使用模型进行预测,识别单个文字图片
def predict_image(image_path):
    image = Image.open(image_path)
    # 转换图片格式
    image = transform(image)
    # 只进行前向传播
    with torch.no_grad():
        output = model(image)
    # ImageFolder输出的标签是文件序号,argmax找到张量output中的最大值
    predicted_idx = torch.argmax(output).item()
    print(predicted_idx)
    # 将输出转换成对应序号的文件名
    if predicted_idx < len(train_dataset.classes) :
        predicted_label = train_dataset.classes[predicted_idx]
        return predicted_label
    else:
        return "null"

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/542009.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

opencv基础图行展示

"""试用opencv创建画布并显示矩形框&#xff08;适用于目标检测图像可视化&#xff09; """ # 创建一个黑色的画布&#xff0c;图像格式(BGR) img np.zeros((512, 512, 3), np.uint8)# 画一个矩形&#xff1a;给定左上角和右下角坐标&#xff0…

Redis入门到通关之Hash命令

文章目录 ⛄介绍⛄命令⛄RedisTemplate API❄️❄️添加缓存❄️❄️设置过期时间(单独设置)❄️❄️添加一个Map集合❄️❄️提取所有的小key❄️❄️提取所有的value值❄️❄️根据key提取value值❄️❄️获取所有的键值对集合❄️❄️删除❄️❄️判断Hash中是否含有该值 ⛄…

文献阅读:猕猴的单个基底外侧杏仁核神经元表现出与额叶皮层不同的连接模式

文献介绍 「文献题目」 Single basolateral amygdala neurons in macaques exhibit distinct connectional motifs with frontal cortex 「研究团队」 Peter H. Rudebeck&#xff08;美国西奈山伊坎医学院&#xff09; 「发表时间」 2023-10-18 「发表期刊」 Neuron 「影响因…

Springboot+Vue项目-基于Java+MySQL的母婴商城系统(附源码+演示视频+LW)

大家好&#xff01;我是程序猿老A&#xff0c;感谢您阅读本文&#xff0c;欢迎一键三连哦。 &#x1f49e;当前专栏&#xff1a;Java毕业设计 精彩专栏推荐&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb; &#x1f380; Python毕业设计 &…

Ubuntu去除烦人的顶部【活动】按钮

文章目录 一、需求说明二、打开 extensions 网站三、安装 GNOME Shell 插件四、安装本地连接器五、安装 Hide Activities Button 插件六、最终效果七、卸载本地连接器命令参考 本文所使用的 Ubuntu 系统版本是 Ubuntu 22.04 ! 一、需求说明 使用 Ubuntu 的过程中&#xff0c;屏…

【大语言模型】应用:10分钟实现搜索引擎

本文利用20Newsgroup这个数据集作为Corpus(语料库)&#xff0c;用户可以通过搜索关键字来进行查询关联度最高的News&#xff0c;实现对文本的搜索引擎&#xff1a; 1. 导入数据集 from sklearn.datasets import fetch_20newsgroupsnewsgroups fetch_20newsgroups()print(fNu…

在Linux驱动中,如何确保中断上下文的正确保存和恢复?

大家好&#xff0c;今天给大家介绍在Linux驱动中&#xff0c;如何确保中断上下文的正确保存和恢复&#xff1f;&#xff0c;文章末尾附有分享大家一个资料包&#xff0c;差不多150多G。里面学习内容、面经、项目都比较新也比较全&#xff01;可进群免费领取。 在Linux驱动中&am…

AI图书推荐:如何在课堂上使用ChatGPT 进行教育

ChatGPT是一款强大的新型人工智能&#xff0c;已向公众免费开放。现在&#xff0c;各级别的教师、教授和指导员都能利用这款革命性新技术的力量来提升教育体验。 本书提供了一个易于理解的ChatGPT解释&#xff0c;并且更重要的是&#xff0c;详述了如何在课堂上以多种不同方式…

STM32利用软件I2C通讯读MPU6050的ID号

今天的读ID号是建立在上篇文章中有了底层的I2C通讯的6个基本时序来编写的。首先需要完成的就是MPU6050的初始化函数 然后就是编写 指定地址写函数&#xff1a; 一&#xff1a;开始 二&#xff1a;发送 从机地址读写位&#xff08;1&#xff1a;读 0&#xff1…

MySQL之索引失效、覆盖、前缀索引及单列、联合索引详细总结

索引失效 最左前缀法则 如果索引了多列(联合索引)&#xff0c;要遵守最左前缀法则&#xff0c;最左前缀法则指的是查询从索引的最左列开始&#xff0c;并且不跳过索引中的列。如果跳跃某一列&#xff0c;索引将部分失效&#xff08;后面的字段索引失效&#xff09;。 联合索…

第1章 计算机网络体系结构

王道学习 【考纲内容】 &#xff08;一&#xff09;计算机网络概述 计算机网络的概念、组成与功能&#xff1b;计算机网络的分类&#xff1b; 计算机网络的性能指标 &#xff08;二&#xff09;计算机网络体系结构与参考模型 计算机网络分层结…

权威Scrum敏捷开发企业级实训/敏捷开发培训课程

课程简介 Scrum是目前运用最为广泛的敏捷开发方法&#xff0c;是一个轻量级的项目管理和产品研发管理框架。 这是一个两天的实训课程&#xff0c;面向研发管理者、项目经理、产品经理、研发团队等&#xff0c;旨在帮助学员全面系统地学习Scrum和敏捷开发, 帮助企业快速启动敏…

Ansys Workbench拓扑优化教程

很基础。 前言 进行拓扑优化的好处在于可以简化结构&#xff0c;满足力学性能的同时简化结构。 如赵州桥的一大一小的拱&#xff0c;就可以用拓扑优化优化出来&#xff0c;可见一千四百多年以前古人的智慧是多么丰富。 步骤 大体的步骤是需要 1.先导入模型&#xff08;需…

练习题(2024/4/13)

1长度最小的子数组 给定一个含有 n 个正整数的数组和一个正整数 target 。 找出该数组中满足其总和大于等于 target 的长度最小的 连续 子数组 [numsl, numsl1, ..., numsr-1, numsr] &#xff0c;并返回其长度。如果不存在符合条件的子数组&#xff0c;返回 0 。 示例 1&am…

ThreadX:怎么确定一个线程应该开多少内存

ThreadX&#xff1a;如何确定线程的大小 在实时操作系统&#xff08;RTOS&#xff09;ThreadX中&#xff0c;线程的大小是一个重要的参数。这个参数决定了线程的堆栈大小&#xff0c;也就是线程可以使用的内存空间。那么&#xff0c;我们应该如何确定一个线程需要多大的字节呢…

Asterisk 21.2.0编译安装经常遇到的问题和解决办法之卸载pjsip

目录 会安装也要会卸载make uninstallldconfig 会安装也要会卸载 有些人就只会装。 最常见的场景就是需要卸载之前版本的pjproject。 一般来说&#xff0c;其他版本的 pjproject 会被作为静态链接库安装。这些库跟 Asterisk可能不兼容。 因此&#xff0c;在安装正确版本的pjpro…

语音识别(录音与语音播报)

语音识别&#xff08;录音与语音播报&#xff09; 简介 语音识别人工智能技术的应用领域非常广泛&#xff0c;常见的应用系统有&#xff1a;语音输入系统&#xff0c;相对于键盘输入方法&#xff0c;它更符合人的日常习惯&#xff0c;也更自然、更高效&#xff1b;语音控制系…

记录day1

1.早上 ①协同过滤算法 基于物品基于用户分别是如何实现的 相似度的计算方式 基于用户和基于物品的区别 实时性和新物品这里&#xff1a; 实时性指的是用户有新行为&#xff0c;这样基于物品就好&#xff0c;因为用户新行为了&#xff0c;用户矩阵不会变化&#xff0c;用户…

cron表达式使用手册

cron表达式 我们在使用定时调度任务的时候&#xff0c;最常用的就是cron表达式。通过cron表达式来指定任务在某个时间点或者周期性执行。 范围&#xff1a; 秒&#xff08;0-59&#xff09;&#xff08;可选&#xff09; 分&#xff08;0-59&#xff09; 时&#xff08;0-…

从零实现诗词GPT大模型:数据集介绍和预处理

本章将介绍该系列文章中使用的数据集&#xff0c;并且编写预处理代码&#xff0c;处理成咱们需要的格式。 一、数据集介绍 咱们使用的数据集名称是chinese-poetry&#xff0c;是一个在github上开源的中文诗词数据集&#xff0c;根据仓库中readme.md中的介绍&#xff0c;该数据…