0 简单的图像分类

本文主要针对交通标识图片进行分类,包含62类,这个就是当前科大讯飞比赛,目前准确率在0.94左右,难点如下:

1 类别不均衡,有得种类图片2百多,有个只有10个不到;

2 像素大小不同,导致有的图片很清晰,有的很模糊;

直接上代码:

import os
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split

from torchvision import models, datasets, transforms
import torch.utils.data as tud
import numpy as np
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from PIL import Image
import matplotlib.pyplot as plt
import warnings
import pandas as pd
from torch.utils.data import random_split

warnings.filterwarnings("ignore")

# 检测能否使用GPU
print(#labels
torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
)

device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
n_classes = 62  # 几种分类的
preteain = False  # 是否下载使用训练参数 有网true 没网false
epoches = 10  # 训练的轮次
traindataset = datasets.ImageFolder(root='../all/data/train_set/', transform=transforms.Compose([
    transforms.Resize((224,224)),
    #transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 
]))


# 分割比例:比如80%的数据用于训练,20%用于验证
train_val_ratio = 0.8
train_size = int(len(traindataset) * train_val_ratio)
val_size = len(traindataset) - train_size
train_dataset, val_dataset = random_split(traindataset, [train_size, val_size])


classes = traindataset.classes
print(classes)
 
model = models.resnext50_32x4d(pretrained=preteain)
#model = models.resnet34(pretrained=preteain)

if preteain == True:
    for param in model.parameters():
        param.requires_grad = False
        
model.fc = nn.Linear(in_features=2048, out_features=n_classes, bias=True)
model = model.to(device)
 
 
def train_model(model, train_loader, loss_fn, optimizer, epoch):
    model.train()
    total_loss = 0.
    total_corrects = 0.
    total = 0.
    for idx, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        preds = outputs.argmax(dim=1)
        total_corrects += torch.sum(preds.eq(labels))
        total_loss += loss.item() * inputs.size(0)
        total += labels.size(0)
    total_loss = total_loss / total
    acc = 100 * total_corrects / total
    print("轮次:%4d|训练集损失:%.5f|训练集准确率:%6.2f%%" % (epoch + 1, total_loss, acc))
    return total_loss, acc
 
 
def test_model(model, test_loader, loss_fn, optimizer, epoch):
    model.train()
    total_loss = 0.
    total_corrects = 0.
    total = 0.
    with torch.no_grad():
        for idx, (inputs, labels) in enumerate(test_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            preds = outputs.argmax(dim=1)
            total += labels.size(0)
            total_loss += loss.item() * inputs.size(0)
            total_corrects += torch.sum(preds.eq(labels))
 
        loss = total_loss / total
        accuracy = 100 * total_corrects / total
        print("轮次:%4d|测试集损失:%.5f|测试集准确率:%6.2f%%" % (epoch + 1, loss, accuracy))
        return loss, accuracy
 
 
loss_fn = nn.CrossEntropyLoss().to(device)

optimizer = optim.Adam(model.parameters(), lr=0.0001)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(val_dataset, batch_size=32, shuffle=True)
for epoch in range(0, epoches):
    loss1, acc1 = train_model(model, train_loader, loss_fn, optimizer, epoch)
    loss2, acc2 = test_model(model, test_loader, loss_fn, optimizer, epoch)

模型预测:

sub = pd.read_csv("../all/data/example.csv")
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

model.eval()
for path in os.listdir("../all/data/test_set/"):
    try:
        img = Image.open("../all/data/test_set/"+path)
        img_p = transform(img).unsqueeze(0).to(device)
        output = model(img_p)
        pred = output.argmax(dim=1).item()
        if img.size[0] * img.size[1]<2000:
            plt.imshow(img)
            plt.show()
        p = 100 * nn.Softmax(dim=1)(output).detach().cpu().numpy()[0]
        sub.loc[sub['ImageID'] == path,'label'] = classes[pred]
        print(f'{path} size = {img.size}, 该图像预测类别为:', classes[pred])
    except:
        print(f'error {path}')
sub.loc[sub['ImageID']=='e57471de-6527-4b9b-90a8-4f1d93909216.png','label'] = 'Under Construction'
sub.loc[sub['ImageID']=='ff38d59e-9a11-41e4-901b-67097bb0e960.png','label'] = 'Keep Left'
sub.columns = ['ImageID','Sign Name']
label_map = pd.read_excel("../all/data/label_map.xlsx")
sub_all = pd.merge(left=sub,right=label_map,on='Sign Name',how='left')
#sub_all[['ImageID','label']].to_csv('./sub_resnet34_add_img_ratio_drop_dire.csv',index=False)

个人的心得:

1 如何进行图片增强,图片增强应该注意什么(方向问题);总结一些transforms,数据增强的方式 - 代码天地

2 模型大小如何进行选择;

更新:

6.17: 昨天我尝试使用更多的增强技术,直接上到0.98;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/727062.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

超越Llama3,多模态比肩GPT4V:GLM-4智能体,新一代语言处理利器

在人工智能领域&#xff0c;自然语言处理技术一直备受关注。就在昨日,今年备受关注的国内AI公司北京智谱AI发布了第四代 GLM 系列开源模型&#xff1a;GLM-4-9B。这是一个集成了先进自然语言处理技术的创新平台&#xff0c;它凭借清华大学KEG实验室提出的GLM模型结构&#xff0…

18V-180V降12V500mA恒压模块WT5118

18V-180V降12V500mA恒压模块WT5118 WT5118是一款能够将输入电压范围从18V至180V降低至12V并保持恒定输出电流500mA的恒压模块。 WT5118 是一款专为开关电源设计的集成了 180V 高电压 MOSFET 的 DC-DC 控制器。这个设备具备内置高压启动和自供电功能&#xff0c;能够满足快速启…

OpenAI 推出“模型规范”:塑造责任制的人工智能的框架

为了提升人工智能开发的责任性和透明度&#xff0c;OpenAI 最近发布了一份名为“模型规范”的初步草案。这份文件首次明确了其 API 和 ChatGPT 模型行为的指导原则&#xff0c;并通过博客形式对外公布。 OpenAI 在博客中解释说&#xff1a;“我们之所以发布此文档&#xff0c;…

「Python-docx 专栏」docx 设置页面边距、页眉页脚高度

本文目录 前言一、docx 页面边距在哪里二、对 <w:pgMar> 的详细说明1、上边距的说明2、右边距的说明3、下边距的说明4、左边距的说明5、页眉高度的说明6、页脚高度的说明三、设置 docx 页边距、页眉页脚高度1、完整代码2、代码执行效果图四、补充一些内容1、页面边距的两…

图解Sieve of Eratosthenes(埃拉托斯特尼筛法)算法求解素数个数

1.素数的定义 素数又称质数。质数是指在大于1的自然数中&#xff0c;除了1和它本身以外不再有其他因数的自然数。一个大于1的自然数&#xff0c;除了1和它自身外&#xff0c;不能被其他自然数整除的数叫做质数&#xff1b;否则称为合数&#xff08;规定1既不是质数也不是合数&…

【绝对有用】刚刚开通的GPT-4o计算这种数学题目出现问题了

欢迎关注如何解决以上问题的方法&#xff1a;查看个人简介中的链接的具体解决方案

Nvidia Isaac Sim搭建仿真环境 入门教程 2024(4)

Nvidia Isaac Sim 入门教程 2024 版权信息 Copyright 2023-2024 Herman YeAuromix. All rights reserved.This course and all of its associated content, including but not limited to text, images, videos, and any other materials, are protected by copyright law. …

【iOS】#include、#import、@class、@import

文章目录 #include#importclassimport总结 #include #include是c\c中的预处理器指令&#xff0c;用于包含头文件的内容 但是使用#include可能会出现重复包含文件的问题&#xff0c;因此需要使用&#xff08;#ifndef/#define/#endif&#xff09;。 #import //导入系统头文件…

候选键的确定方法-如何判断属性集U的子集K是否为候选键、如何找到关系模式的候选键

一、候选键的定义 在关系模式R(U,F)中&#xff0c;若&#xff0c;且K满足&#xff0c;则K为关系模式R的候选键 关系模式R的候选键必须满足以下两个条件&#xff1a; &#xff08;1&#xff09;必须是属性集U的子集 &#xff08;2&#xff09;完全函数决定属性集U 二、如何…

【网络安全的神秘世界】已解决Failed to start proxy service on 127.0.0.1:8080

&#x1f31d;博客主页&#xff1a;泥菩萨 &#x1f496;专栏&#xff1a;Linux探索之旅 | 网络安全的神秘世界 | 专接本 | 每天学会一个渗透测试工具 解决burpsuite无法在 127.0.0.1&#xff1a;8080 上启动代理服务端口被占用以及抓不到本地包的问题 Burpsuite无法启动proxy…

令人震撼的人类智慧的科学领域-AI技术

AI&#xff0c;全称为人工智能&#xff08;Artificial Intelligence&#xff09;&#xff0c;是一门致力于让机器模仿人类智慧的科学领域。其核心技术涵盖了机器学习、自然语言处理、计算机视觉及专家系统等多个方面。AI旨在开发能够感知环境、进行逻辑推理、自主学习并做出决策…

Python数据可视化:直方图、核密度估计图、箱线图、累积分布函数图

本文使用数据来源自2023年数学建模国赛C题&#xff0c;以附件1、附件2数据为基础&#xff0c;通过excel的数据透视表等功能重新汇总了一份新的数据表&#xff0c;从中截取了一部分数据为例用于绘制图表。绘制的图表包括一维直方图、一维核密度估计图、二维直方图、二维核密度估…

godot所有2D节点介绍

五十个2D节点介绍 2D节点介绍 前言一、Node2D二、sprite2D三、AnimatedSprite2D四、Camera2D五、PhysicsBody2D六、 RigidBody2D七、CharacterBody2D八、StaticBody2D九、joint2D十、DampedSpringJoint2D十一、GrooveJoint2D十二、PinJoint2D十三、Area2D十四、AnimatableBody2…

cloud_enum:一款针对不同平台云环境安全的OSINT工具

关于cloud_enum cloud_enum是一款功能强大的云环境安全OSINT工具&#xff0c;该工具支持AWS、Azure和Google Cloud三种不同的云环境&#xff0c;旨在帮助广大研究人员枚举目标云环境中的公共资源&#xff0c;并尝试寻找其中潜在的安全威胁。 功能介绍 当前版本的cloud_enum支…

kettle实时增量同步mysql数据

** 本文主要介绍运用kettle实时增量同步mysql数据 ** Debezium介绍 官网地址&#xff1a;https://debezium.io/documentation/ Debezium是一个开源项目&#xff0c;为捕获数据更改(Capture Data Change,CDC)提供了一个低延迟的流式处理平台&#xff0c;通过安装配置Debeziu…

[面试题]RabbitMQ

[面试题]Java【基础】[面试题]Java【虚拟机】[面试题]Java【并发】[面试题]Java【集合】[面试题]MySQL[面试题]Maven[面试题]Spring Boot[面试题]Spring Cloud[面试题]Spring MVC[面试题]Spring[面试题]MyBatis[面试题]Nginx[面试题]缓存[面试题]Redis[面试题]消息队列[面试题]…

科普文章:怎么远程监控电脑屏幕?三种监控电脑屏幕的方法

远程监控公司电脑屏幕是一项重要的管理手段&#xff0c;它不仅有助于提升工作效率&#xff0c;还能确保公司信息安全和合规性。随着远程办公的普及&#xff0c;这一需求变得日益重要。下面我将详细介绍几种实现远程监控公司电脑屏幕的方法&#xff0c;以及实施过程中需要注意的…

网络安全 DVWA通关指南 SQL Injection(SQL注入)

DVWA SQL Injection 文章目录 DVWA SQL InjectionLowMediumHighImpossible SQL注入漏洞基本原理 Web应用程序对用户输入的数据校验处理不严或者根本没有校验&#xff0c;致使用户可以拼接执行SQL命令。 可能导致数据泄露或数据破坏&#xff0c;缺乏可审计性&#xff0c;甚至导致…

机器学习案例|使用机器学习轻松预测信用卡坏账风险,极大程度降低损失

01、案例说明 对于模型的参数&#xff0c;除了使用系统的设定值之外&#xff0c;可以进行再进一步的优化而得到更好的结果。RM提供了几种参数优化的方法&#xff0c;能够让整体模型的效率提高。而其使用的概念&#xff0c;仍然是使用计算机强大的计算能力&#xff0c;对于不同…

01 Shell 编程规范与变量

目录 1.1 Shell脚本概述 1.1.1 Shell的作用 1.1.2 编写第一个Shell脚本 1.1.3 重定向与管道操作 1. 重定向操作 1. 重定向输出 2. 重定向输入 3. 错误重定向 2. 管代操作 1.2 Shell变量的作用、类型 1.2.1 自定义变量 1. 定义新的变量 2. 查看和引用变量的值 3. 变量赋值的特…