基于卷积神经网络的交通标志识别(pytorch,opencv,yolov5)

文章目录

  • 数据集介绍:
  • resnet18模型代码
  • 加载数据集(Dataset与Dataloader)
  • 模型训练
  • 训练准确率及损失函数:
  • resnet18交通标志分类源码
  • yolov5检测与识别(交通标志)

本文共包含两部分,
第一部分是用resnet18对交通标志分类,仅仅只是交通标志分类
文末附有yolov5和resnet18结合的源码,yolov5复制检测交通标志位置,然后使用resnet18对交通标志进行分类。

数据集介绍:

本文使用的数据集共有6000多张,共包含58个类别。部分数据集如下:

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

resnet18模型代码

使用pytorch自带的resnet18模型,代码如下:

from torchvision import models
import torch.nn as nn

#加载resnet18模型
net=models.resnet18(weights=None)
#因为分类个数为58,所以需要修改模型最后一层全连接层
net.fc=nn.Linear(in_features=512, out_features=58, bias=True)
# print(net)

加载数据集(Dataset与Dataloader)

from torch.utils.data import Dataset,DataLoader
import numpy as np
import cv2
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from PIL import Image
import os
from torchvision import transforms
import torch
import random

a=[]
class Mydata(Dataset):
    def __init__(self,lines,train=True):
        super(Mydata, self).__init__()
        self.lines=lines
        random.shuffle(self.lines)
        self.train=train


    def __len__(self):
        return len(self.lines)
    def __getitem__(self, index):

        txts=self.lines[index].strip().split(';')
        src_path='pic/'+txts[0]
        w=int(txts[1])
        h=int(txts[2])
        x1=int(txts[3])
        y1=int(txts[4])
        x2=int(txts[5])
        y2=int(txts[6])

        new_x1=random.randint(0,x1)
        new_y1=random.randint(0,y1)
        new_x2=random.randint(x2,w-1)
        new_y2=random.randint(y2,h-1)

        lab=int(txts[7])
        # if lab in a:
        #     pass
        # else:a.append(lab)
        #
        # a.sort()
        # print(len(a))
        # print(a)
        img = Image.open(src_path)
        img=np.array(img)[...,:3]
        img=img[new_y1:new_y2,new_x1:new_x2]


        #数据增强
        if self.train:
            img=self.get_random_data(img)
        else:
            img = cv2.resize(img, (128, 128))
        # cv2.imshow('img',img[...,::-1])
        # cv2.waitKey(0)

        #归一化
        img=(img/255.0).astype('float32')
        img=np.transpose(img,(2,0,1))

        img=torch.from_numpy(img)
        return img,lab
    def get_random_data(self,img):
        seq = iaa.Sequential([
            # iaa.Flipud(0.5),  # flip up and down (vertical)
            # iaa.Fliplr(0.5),  # flip left and right (horizontal)
            iaa.Multiply((0.8, 1.2)),  # change brightness, doesn't affect BBs(bounding boxes)
            iaa.GaussianBlur(sigma=(0, 1.0)),  # 标准差为0到3之间的值

            iaa.Crop(percent=(0, 0.2)),
            iaa.Affine(
                translate_px={"x": (0,15), "y": (0,15)},  # 平移
                scale=(0.8, 1.2),  # 尺度变换
                rotate=(-20, 20),
                mode='constant',
                cval=(125)
                ),
            iaa.Resize(128)
        ])

        img= seq(image=img)
        return img
if __name__ == '__main__':
    lines=open('data.txt','r').readlines()
    my=Mydata(lines=lines,train=True)
    myloader=DataLoader(dataset=my,batch_size=3,shuffle=False)

    for i,j in myloader:
        print(i.shape,j.shape)

模型训练

经过60个epoch训练后,模型准确率基本上达到百分百

from mymodel import net
from myDataset import Mydata
import random
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torch
from tqdm import tqdm
import matplotlib.pylab as plt

batch_size=32
Epoch=60
lr=0.001

lines=open('data.txt','r').readlines()
random.shuffle(lines)
val_lines=random.sample(lines,int(len(lines)*0.1))
train_lines=list(set(lines)-set(val_lines))


train_data=Mydata(lines=train_lines)
val_data=Mydata(lines=val_lines,train=False)
train_loader=DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)
val_loader=DataLoader(dataset=val_data,batch_size=batch_size,shuffle=False)

num_train   = len(train_lines)
epoch_step  = num_train // batch_size
BCE_loss     = nn.CrossEntropyLoss()
optimizer  = optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999))
lr_scheduler  = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)
#获取学习率函数
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']
#计算准确率函数
def metric_func(pred,lab):
    _,index=torch.max(pred,dim=-1)
    acc=torch.where(index==lab,1.,0.).mean()
    return acc
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net=net.to(device)
#设置损失函数
loss_fun     = nn.CrossEntropyLoss()

if __name__ == '__main__':

    T_acc=[]
    V_acc=[]
    T_loss=[]
    V_loss=[]

    # 设置迭代次数200次
    epoch_step = num_train // batch_size
    for epoch in range(1, Epoch + 1):
        net.train()

        total_loss = 0
        loss_sum = 0.0
        train_acc_sum=0.0

        with tqdm(total=epoch_step, desc=f'Epoch {epoch}/{Epoch}', postfix=dict, mininterval=0.3) as pbar:
            for step, (features, labels) in enumerate(train_loader, 1):
                features = features.to(device)
                labels = labels.to(device)
                batch_size = labels.size()[0]

                optimizer.zero_grad()
                predictions = net(features)
                loss = loss_fun(predictions, labels)
                loss.backward()
                optimizer.step()
                total_loss += loss
                train_acc = metric_func(predictions, labels)
                train_acc_sum+=train_acc
                pbar.set_postfix(**{'loss': total_loss.item() / (step),
                                    "acc":train_acc_sum.item()/(step),
                                    'lr': get_lr(optimizer)})
                pbar.update(1)
        T_acc.append(train_acc_sum.item()/(step))
        T_loss.append(total_loss.item() / (step))
        # 验证
        net.eval()
        val_acc_sum = 0
        val_loss_sum=0
        for val_step, (features, labels) in enumerate(val_loader, 1):
            with torch.no_grad():
                features = features.to(device)
                labels = labels.to(device)
                predictions = net(features)

                val_metric = metric_func(predictions, labels)
                loss=loss_fun(predictions,labels)
            val_acc_sum += val_metric.item()
            val_loss_sum+=loss.item()
        print('val_acc=%.4f' % (val_acc_sum / val_step))
        V_acc.append(round(val_acc_sum / val_step,2))
        V_loss.append(val_loss_sum/val_step)

        # 保存模型
        if (epoch) % 2 == 0:
            torch.save(net.state_dict(), 'logs/Epoch%d-Loss%.4f_.pth' % (
                epoch, total_loss / (epoch_step + 1)))

        lr_scheduler.step()

    plt.figure()
    plt.plot(T_acc,'r')
    plt.plot(V_acc,'b')
    plt.title('Training and validation Acc')
    plt.xlabel("Epochs")
    plt.ylabel("Acc")
    plt.legend(["Train_acc", "Val_acc"])
    # plt.show()
    plt.savefig("ACC.png")

    plt.figure()
    plt.plot(T_loss, 'r')
    plt.plot(V_loss, 'b')
    plt.title('Training and validation loss')
    plt.xlabel("Epochs")
    plt.ylabel("loss")
    plt.legend(["Train_loss", "Val_loss"])

    plt.savefig("LOSS.png")
    plt.show()


训练准确率及损失函数:

准确率:

在这里插入图片描述
损失函数:
在这里插入图片描述

resnet18交通标志分类源码

(包含训练,预测代码,准确率,损失函结果图像,数据集等):
下载地址:

yolov5检测与识别(交通标志)

前面是使用resnet18网络对交通标志分类,只是单单的分类,无法从一张完整的全局图像中检测交通标志位置。对此,首先使用yolov5从全局图像中检测交通标志的位置,只是检测没有分类,然后再使用前面训练好的resnet18模型对交通标志分类。其效果如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/641092.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

力扣刷题---2206. 将数组划分成相等数对【简单】

题目描述🍗 给你一个整数数组 nums ,它包含 2 * n 个整数。 你需要将 nums 划分成 n 个数对,满足: 每个元素 只属于一个 数对。 同一数对中的元素 相等 。 如果可以将 nums 划分成 n 个数对,请你返回 true &#xf…

高开高走的续作,可不止《庆余年2》

说起最近霸屏的影视剧,莫过于《庆余年2》。火爆全网的讨论度总归是没有辜负观众们五年的等待,在五月的影视市场独占鳌头已成定局。张若昀、陈道明、李沁等一众演员稳定发挥,剧情节奏随着故事发展渐入佳境,评分一路高涨。 对影视作…

【网络安全】社会工程学攻击与防范

一、社会工程学概述 1、社会工程学的定义 通过利用人们的心理弱点、本能反应、好奇心、信任、贪婪等一些心理陷阱进行的诸如欺骗、伤害、信息盗取、利益谋取等对社会及人类带来危害的行为或方法。 当网络恶意攻击者无法通过纯粹的计算机技术达到目的时,高超的情商…

统计信号处理基础 习题解答10-4

题目: 重复习题10.3,但条件PDF变为: 以及均匀先验。如果非常大,这样先验知识很少,则会出现什么情况。 解答: 如果记 那么,根据条件独立性质,得到: 其中,&am…

Web3 游戏平台 Creo Engine 销毁代币总量的20%,以促进长远发展

Creo Engine 5月16日进行了第三次代币销毁,这次的销毁占代币总量的 20%。一共销毁了2亿 $CERO 代币,市场价值接近 2000 万美元。 Creo Engine 致力于连接世界、为玩家提供一站式游戏中心,并提升 Web3 游戏体验。 Creo Engine 发布于2022年&am…

Python学习---基于TCP协议的网络通信程序案例

TCP简介: ●TCP 面向连接、可靠的、基于字节流的传输控制协议 ●TCP的特点 ○面向连接 ○可靠传输 ■应答机制 ■超时重传 ■错误校验 ■流量管控 ●TCP通信模型 TCP严格区分客户…

运维Tips | Linux系统文件命令执行时inode表如何变化?

[ 知识是人生的灯塔,只有不断学习,才能照亮前行的道路。 ] 大家好,我是【WeiyiGeek/唯一极客】一个正在向全栈工程师(SecDevOps)前进的技术爱好者 作者微信:WeiyiGeeker 公众号/知识星球:全栈工程师修炼指南 主页博…

Linux下自旋锁的学习使用

前言 前面我们讲到定时器的使用,本篇讲下自旋锁的使用。想第一时间看我的文章的话可以点击公众号主页右上角有个设为星标,以免错过好文。本文源码采用Linux内核5.10 自旋锁简介 自旋锁是Linux内核里最常用的锁之一,自旋锁的概念很简单,就是…

误差反向传播简介与实现

误差反向传播 导语计算图反向传播链式法则 反向传播结构加法节点乘法节点 实现简单层加法乘法 激活函数层实现ReLUSigmoid Affine/Softmax层实现Affine基础版批版本 Softmax-with-Loss 误差反向传播实现梯度确认总结参考文献 导语 书上在前一章介绍了随机梯度下降法进行参数与…

ubuntu20.04 开机自动挂载外加硬盘

文章目录 一、问题描述二、操作1. 查找新添盘符2. 格式化硬盘文件系统3. 挂载硬盘4. 开机自动挂载5. 取消挂载6. 查看挂载的硬盘信息 一、问题描述 因电脑使用一段时间后自身硬盘不足,需外加硬盘使得电脑自动识别加载。 二、操作 1. 查找新添盘符 sudo blkid自己…

linux中的arch命令使用

arch 显示当前主机的硬件架构类型 概要 arch [OPTION]...主要用途 打印机器架构信息;arch 命令输出结果有:i386、i486、i586、alpha、sparc、arm、m68k、mips、ppc、i686等。 选项 --help 显示帮助信息并退出。 --version 显示版本信息并…

windows7的ie11降级到ie8

重点是要在程序管理窗口中“查看已安装的更新”打开当前系统中已安装更新列表,找到两个IE11的更新(见下图“卸载文件“)并卸载掉,这样windows功能中的ie11才会变成ie8. 打开控制面板 进入面板,点击程序,进…

hot100 -- 回溯(上)

目录 🍞科普 🌼全排列 AC DFS 🚩子集 AC DFS 🎂电话号码的字母组合 AC DFS 🌼组合总和 AC DFS 🍞科普 忘记 dfs 的,先看看这个👇 DFS(深度优先搜索&#xf…

ping 探测网段哪些地址被用

#!/bin/bash# 遍历192.168.3.1到192.168.3.254 for i in {1..254} doip"192.168.3.$i"# 对每个IP地址进行三次ping操作if ping -c 3 -W 1 $ip > /dev/null 2>&1thenecho "$ip: yes"fi done$ sh test.sh 192.168.3.1: yes 192.168.3.95: yes 192.…

MySQL中的sql语句

MySQL中的sql语句 DML、 DDL、 DCL DML(Data Manipulation Language),用于对数据库中的数据进行操作,包括插入、查询、更新和删除数据等操作。常见的 DML 命令包括 SELECT(查询)、INSERT(插入)、UPDATE&a…

Windows系统安装dlib及face_recognition搭建人脸识别环境

关于face_recognition face_recognition被称为世界上最简洁的人脸识别库,借助face_recognition库,我们可以使用Python和命令行提取、识别、操作人脸。 face_recognition的人脸识别是基于业内领先的C开源库 dlib中的深度学习模型,用Labeled …

数据结构——栈(详细分析)

目录 🍉引言 🍉栈的本质和特点 🍈栈的基本操作 🍈栈的特点 🍍后进先出 🍍操作受限 🍍动态调整 🍈栈的优缺点 🍍优点 🍍缺点 🍉栈的应用…

STM32F407VET6 学习笔记4:DAC数模转换功能的配置

今日继续学习使用嘉立创的 立创梁山派天空星,芯片是 STM32F407VET6 使用库函数编程 最近突然发现很久没有接触过单片机的AD转换功能了,之前还是学习51单片机时学习驱动PCF8591芯片实现AD转换功能的,还从未在STM32平台上进行过相关的实验经验…

解决go install 网络问题

rootiZbp1hiqzlhh6w05gloffgZ:~# go install mvdan.cc/garblelatest go: mvdan.cc/garblelatest: module mvdan.cc/garble: Get "https://proxy.golang.org/mvdan.cc/garble/v/list": dial tcp 172.217.160.81:443: i/o timeout解决方法 更换阿里代理 rootiZbp1hiq…

保障餐饮场所安全:定期送检可燃气体报警器

在餐饮行业,火灾隐患一直备受关注。餐厅、茶饮店等场所常常使用燃气设备,而这些设备带来了潜在的安全隐患。 为了及时发现并预防可燃气体泄漏,可燃气体报警器的定期送检显得尤为重要。那么,为什么可燃气体报警器需要定期送检呢&a…