深度学习练手小例子——cifar10数据集分类问题

CIFAR-10 是一个经典的计算机视觉数据集,广泛用于图像分类任务。它包含 10 个类别的 60,000 张彩色图像,每张图像的大小是 32x32 像素。数据集被分为 50,000 张训练图像和 10,000 张测试图像。每个类别包含 6,000 张图像,具体类别包括:

  • 飞机 (airplane)
  • 汽车 (automobile)
  • 鸟 (bird)
  • 猫 (cat)
  • 鹿 (deer)
  • 狗 (dog)
  • 青蛙 (frog)
  • 马 (horse)
  • 船 (ship)
  • 卡车 (truck)

CIFAR-10 是一个多类分类问题,目标是根据图像内容(例如,物体的形状、颜色等特征)预测图像所属的类别。图像分类模型(如卷积神经网络 CNN)常用于这个任务,通过学习图像的空间特征来做出预测。

来看看实现过程:

import torch
import torchvision.datasets
from torch.utils.data import DataLoader
from torch import nn

train_data = torchvision.datasets.CIFAR10(root="../input/cifar10-python",train=True,transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="../input/cifar10-python",train=False,transform=torchvision.transforms.ToTensor(),
                                          download=True)
print(f"train length: {len(train_data)}")
print(f"test length: {len(test_data)}")
Files already downloaded and verified
Files already downloaded and verified
train length: 50000
test length: 10000

找到了CIFAR10数据集并且导入进来,用了三个卷积层的网络模型来训练,进行了10轮训练。

train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)

class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3,32,5,1,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4,64),
            nn.Linear(64,10)
        )
    def forward(self,x):
        x = self.model(x)
        return x
mynet = CNN()
mynet = mynet.cuda()

loss_func = nn.CrossEntropyLoss().cuda()
learning_rate = 0.0001
optimizer = torch.optim.Adam(mynet.parameters(),lr=learning_rate)
total_train = 0
total_test = 0
epoch = 10

for i in range(epoch):
    print(f"----No.{i+1} training...-----")
    mynet.train()
    for data in train_dataloader:
        imgs, targets = data
        imgs = imgs.cuda()
        targets = targets.cuda()
        outputs = mynet(imgs)
        loss = loss_func(outputs,targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_train = total_train + 1
        if total_train % 100 == 0:
            print(f"训练次数:{total_train},loss:{loss.item()}")
    #测试
    mynet.eval()
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            imgs = imgs.cuda()
            targets = targets.cuda()
            outputs = mynet(imgs)
            loss = loss_func(outputs, targets)
            total_test_loss = total_test_loss + loss.item()
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy = total_accuracy + accuracy
    print(f"测试集的loss:{total_test_loss},准确率:{total_accuracy/len(test_data)}")
    torch.save(mynet, f'myCNN_{i+1}p.pth')
    print("模型保存成功")
----No.1 training...-----
训练次数:100,loss:2.0156445503234863
训练次数:200,loss:1.999146580696106
训练次数:300,loss:1.860052466392517
训练次数:400,loss:1.7510318756103516
训练次数:500,loss:1.7712416648864746
训练次数:600,loss:1.6994789838790894
训练次数:700,loss:1.7278780937194824
测试集的loss:257.74497163295746,准确率:0.41990000009536743
模型保存成功
----No.2 training...-----
训练次数:800,loss:1.515326976776123
训练次数:900,loss:1.485555648803711
训练次数:1000,loss:1.6138449907302856
训练次数:1100,loss:1.7650551795959473
训练次数:1200,loss:1.4380264282226562
训练次数:1300,loss:1.3843588829040527
训练次数:1400,loss:1.5849156379699707
训练次数:1500,loss:1.5038520097732544
测试集的loss:236.6359145641327,准确率:0.47110000252723694
模型保存成功
----No.3 training...-----
训练次数:1600,loss:1.4474828243255615
训练次数:1700,loss:1.4474865198135376
训练次数:1800,loss:1.7310973405838013
训练次数:1900,loss:1.5719612836837769
训练次数:2000,loss:1.6212022304534912
训练次数:2100,loss:1.2924069166183472
训练次数:2200,loss:1.256321907043457
训练次数:2300,loss:1.560215711593628
测试集的loss:221.27214550971985,准确率:0.5011000037193298
模型保存成功
----No.4 training...-----
训练次数:2400,loss:1.4557472467422485
训练次数:2500,loss:1.2620049715042114
训练次数:2600,loss:1.4703019857406616
训练次数:2700,loss:1.4131494760513306
训练次数:2800,loss:1.303225040435791
训练次数:2900,loss:1.4961038827896118
训练次数:3000,loss:1.2810102701187134
训练次数:3100,loss:1.337519645690918
测试集的loss:210.63251876831055,准确率:0.5252999663352966
模型保存成功
----No.5 training...-----
训练次数:3200,loss:1.1311390399932861
训练次数:3300,loss:1.2354803085327148
训练次数:3400,loss:1.2415772676467896
训练次数:3500,loss:1.4213279485702515
训练次数:3600,loss:1.4151396751403809
训练次数:3700,loss:1.2579320669174194
训练次数:3800,loss:1.201486349105835
训练次数:3900,loss:1.287066102027893
测试集的loss:202.65885722637177,准确率:0.5475999712944031
模型保存成功
----No.6 training...-----
训练次数:4000,loss:1.2759090662002563
训练次数:4100,loss:1.3534283638000488
训练次数:4200,loss:1.4388338327407837
训练次数:4300,loss:1.1126259565353394
训练次数:4400,loss:1.072700023651123
训练次数:4500,loss:1.2942607402801514
训练次数:4600,loss:1.3078550100326538
测试集的loss:195.93554836511612,准确率:0.5615000128746033
模型保存成功
----No.7 training...-----
训练次数:4700,loss:1.3510404825210571
训练次数:4800,loss:1.3887534141540527
训练次数:4900,loss:1.2628172636032104
训练次数:5000,loss:1.3063734769821167
训练次数:5100,loss:0.9366315007209778
训练次数:5200,loss:1.208983063697815
训练次数:5300,loss:1.0933520793914795
训练次数:5400,loss:1.2654058933258057
测试集的loss:190.015959918499,准确率:0.5735999941825867
模型保存成功
----No.8 training...-----
训练次数:5500,loss:1.1543941497802734
训练次数:5600,loss:1.0732381343841553
训练次数:5700,loss:1.179479718208313
训练次数:5800,loss:1.0669857263565063
训练次数:5900,loss:1.3145105838775635
训练次数:6000,loss:1.4563915729522705
训练次数:6100,loss:1.0026252269744873
训练次数:6200,loss:0.9769096374511719
测试集的loss:184.76930475234985,准确率:0.5831999778747559
模型保存成功
----No.9 training...-----
训练次数:6300,loss:1.2531676292419434
训练次数:6400,loss:1.0582406520843506
训练次数:6500,loss:1.467718482017517
训练次数:6600,loss:0.9885475635528564
训练次数:6700,loss:0.9887412190437317
训练次数:6800,loss:1.1251451969146729
训练次数:6900,loss:1.0831143856048584
训练次数:7000,loss:0.8735517263412476
测试集的loss:180.18007707595825,准确率:0.5949000120162964
模型保存成功
----No.10 training...-----
训练次数:7100,loss:1.1680148839950562
训练次数:7200,loss:0.9758849740028381
训练次数:7300,loss:1.1076891422271729
训练次数:7400,loss:0.8192071914672852
训练次数:7500,loss:1.2766807079315186
训练次数:7600,loss:1.2046217918395996
训练次数:7700,loss:0.8206453323364258
训练次数:7800,loss:1.1484739780426025
测试集的loss:176.2480058670044,准确率:0.6036999821662903
模型保存成功

拿网上下载的几张图片测试一下,注意路径

import torch
import torchvision
from PIL import Image
from torch import nn

# 10分类,分别为airplane'= 0 'automobile'= 1 'bird'= 2'cat'= 3 'deer'=  4 'dog'=  5 'frog'= 6 'horse'= 7 'ship'= 8 'truck'= 9
image_path = "/kaggle/input/testdata/bird.jpg"
image = Image.open(image_path)
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),
                                            torchvision.transforms.ToTensor()])
image = transform(image)
image = torch.reshape(image,(1,3,32,32))

class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3,32,5,1,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4,64),
            nn.Linear(64,10)
        )
    def forward(self,x):
        x = self.model(x)
        return x

model = torch.load("/kaggle/working/myCNN_10p.pth",map_location=torch.device('cpu'))
model.eval()
with torch.no_grad():
    output = model(image)
print(output.argmax(1))
tensor([2])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/962963.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

力扣257. 二叉树的所有路径(遍历思想解决)

Problem: 257. 二叉树的所有路径 文章目录 题目描述思路复杂度Code 题目描述 思路 遍历思想(利用二叉树的先序遍历) 利用先序遍历的思想,我门用一个List变量path记录当前先序遍历的节点,当遍历到根节点时,将其添加到另一个List变量res中&…

力扣第149场双周赛

文章目录 题目总览题目详解找到字符串中合法的相邻数字重新安排会议得到最多空余时间I 第149场双周赛 题目总览 找到字符串中合法的相邻数字 重新安排会议得到最多空余时间I 重新安排会议得到最多空余时间II 变成好标题的最少代价 题目详解 找到字符串中合法的相邻数字 思…

算法题(54):插入区间

审题: 需要我们把newinterval的区间与interval的区间合并起来,并返回合并后的二维数组地址 思路: 方法一:排序合并区间 我们可以先把newinterval插入到interval中,进行排序然后复用合并区间的代码 方法二:模…

网工_HDLC协议

2025.01.25:网工老姜学习笔记 第9节 HDLC协议 9.1 HDLC高级数据链路控制9.2 HDLC帧格式(*控制字段)9.2.1 信息帧(承载用户数据,0开头)9.2.2 监督帧(帮助信息可靠传输,10开头&#xf…

[免费]微信小程序智能商城系统(uniapp+Springboot后端+vue管理端)【论文+源码+SQL脚本】

大家好,我是java1234_小锋老师,看到一个不错的微信小程序智能商城系统(uniappSpringboot后端vue管理端),分享下哈。 项目视频演示 【免费】微信小程序智能商城系统(uniappSpringboot后端vue管理端) Java毕业设计_哔哩哔哩_bilibili 项目介绍…

nth_element函数——C++快速选择函数

目录 1. 函数原型 2. 功能描述 3. 算法原理 4. 时间复杂度 5. 空间复杂度 6. 使用示例 8. 注意事项 9. 自定义比较函数 11. 总结 nth_element 是 C 标准库中提供的一个算法&#xff0c;位于 <algorithm> 头文件中&#xff0c;用于部分排序序列。它的主要功能是将…

CF 581A.Vasya the Hipster(Java实现)

题目分析 红色袜子数量a&#xff0c;蓝色袜子数量b&#xff0c;题目是个潮哥儿&#xff0c;首先选择两种袜子混搭&#xff0c;搭不出来就纯色 思路分析 混搭数量取决于最小数量&#xff0c;剩余的纯色数量取决于哪个还有剩余且数量要/2 代码 import java.util.*;public class…

C基础寒假练习(6)

一、终端输入行数&#xff0c;打印倒金字塔 #include <stdio.h> int main() {int rows;printf("请输入倒金字塔的行数: ");scanf("%d", &rows);for (int i rows; i > 0; i--) {// 打印空格for (int j 0; j < rows - i; j) {printf(&qu…

Python在线编辑器

from flask import Flask, render_template, request, jsonify import sys from io import StringIO import contextlib import subprocess import importlib import threading import time import ast import reapp Flask(__name__)RESTRICTED_PACKAGES {tkinter: 抱歉&…

ASP.NET Core 中间件

目录 一、常见的内置中间件 二、自定义中间件 三、中间件的执行顺序 四、其他自动逸中间件案例 1. 身份验证中间件 2、跨域中间件&#xff08;CORS&#xff09; ASP.NET Core 中&#xff0c;中间件&#xff08;Middleware&#xff09;是处理 HTTP 请求和响应的组件链。你…

LevelDB 源码阅读:写入键值的工程实现和优化细节

读、写键值是 KV 数据库中最重要的两个操作&#xff0c;LevelDB 中提供了一个 Put 接口&#xff0c;用于写入键值对。使用方法很简单&#xff1a; leveldb::Status status leveldb::DB::Open(options, "./db", &db); status db->Put(leveldb::WriteOptions…

2007-2019年各省科学技术支出数据

2007-2019年各省科学技术支出数据 1、时间&#xff1a;2007-2019年 2、来源&#xff1a;国家统计局、统计年鉴 3、指标&#xff1a;行政区划代码、地区名称、年份、科学技术支出 4、范围&#xff1a;31省 5、指标解释&#xff1a;科学技术支出是指为促进科学研究、技术开发…

2025年1月22日(网络编程 udp)

系统信息&#xff1a; ubuntu 16.04LTS Raspberry Pi Zero 2W 系统版本&#xff1a; 2024-10-22-raspios-bullseye-armhf Python 版本&#xff1a;Python 3.9.2 已安装 pip3 支持拍摄 1080p 30 (1092*1080), 720p 60 (1280*720), 60/90 (640*480) 已安装 vim 已安装 git 学习…

如何对系统调用进行扩展?

扩展系统调用是操作系统开发中的一个重要任务。系统调用是用户程序与操作系统内核之间的接口,允许用户程序执行内核级操作(如文件操作、进程管理、内存管理等)。扩展系统调用通常包括以下几个步骤: 一、定义新系统调用 扩展系统调用首先需要定义新的系统调用的功能。系统…

当卷积神经网络遇上AI编译器:TVM自动调优深度解析

从铜线到指令&#xff1a;硬件如何"消化"卷积 在深度学习的世界里&#xff0c;卷积层就像人体中的毛细血管——数量庞大且至关重要。但鲜有人知&#xff0c;一个简单的3x3卷积在CPU上的执行路径&#xff0c;堪比北京地铁线路图般复杂。 卷积的数学本质 对于输入张…

深度学习的应用

目录 一、机器视觉 1.1 应用场景 1.2 常见的计算机视觉任务 1.2.1 图像分类 1.2.2 目标检测 1.2.3 图像分割 二、自然语言处理 三、推荐系统 3.1 常用的推荐系统算法实现方案 四、图像分类实验补充 4.1 CIFAR-100 数据集实验 实验代码 4.2 CIFAR-10 实验代码 深…

Flutter常用Widget小部件

小部件Widget是一个类&#xff0c;按照继承方式&#xff0c;分为无状态的StatelessWidget和有状态的StatefulWidget。 这里先创建一个简单的无状态的Text小部件。 Text文本Widget 文件&#xff1a;lib/app/app.dart。 import package:flutter/material.dart;class App exte…

mysqldump+-binlog增量备份

注意&#xff1a;二进制文件删除必须使用help purge 不可用rm -f 会崩 一、概念 增量备份&#xff1a;仅备份上次备份以后变化的数据 差异备份&#xff1a;仅备份上次完全备份以后变化的数据 完全备份&#xff1a;顾名思义&#xff0c;将数据完全备份 其中&#xff0c;…

智能园区管理系统助力企业安全与效率双提升的成功案例分析

内容概要 在当今迅速发展的商业环境中&#xff0c;企业面临着资产管理、风险控制和运营效率提高等多重挑战。为了应对这些挑战&#xff0c;智能园区管理系统应运而生&#xff0c;为企业提供了全新的解决方案。例如&#xff0c;快鲸智慧园区&#xff08;楼宇&#xff09;管理系…

洛谷 P10289 [GESP样题 八级] 小杨的旅游 C++ 完整题解

一、题目链接 P10289 [GESP样题 八级] 小杨的旅游 - 洛谷 二、题目大意 n个节点之间有n - 1条边&#xff0c;其中k个节点是传送门&#xff0c;任意两个传送门之间可以 以0单位地时间相互到达。问从u到v至少需要多少时间&#xff1f; 三、解题思路 输入不必多讲。 cin >> …