动手学深度学习(Pytorch版)代码实践 -计算机视觉-38实战Kaggle比赛:图像分类 (CIFAR-10)

38实战Kaggle比赛:图像分类 (CIFAR-10)

比赛链接:CIFAR-10 - Object Recognition in Images | Kaggle

导入包
import os
import glob
import pandas as pd
import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from torch import nn
from d2l import torch as d2l
import liliPytorch as lp
import csv
预处理:数据集分析
# 获取精简数据集
#@save
d2l.DATA_HUB['cifar10_tiny'] = (d2l.DATA_URL + 'kaggle_cifar10_tiny.zip',
                                '2068874e4b9a9f0fb07ebe0ad2b29754449ccacd')
# 如果使用完整的Kaggle竞赛的数据集,设置demo为False
demo = True
if demo:
    data_dir = d2l.download_extract('cifar10_tiny')
else:
    data_dir = '../data/cifar-10/'


train_path = '../data/kaggle_cifar10_tiny/train.csv'
file_path = '../data/kaggle_cifar10_tiny/'

# 读取数据
train_data = pd.read_csv(train_path)
# 查看数据
print(train_data['label'].value_counts())
# """
# label
# automobile    112
# frog          107
# truck         103
# horse         102
# airplane      102
# deer           99
# bird           99
# ship           99
# cat            92
# dog            85
# """
1.数据处理与加载
train_path = '../data/kaggle_cifar10_tiny/train.csv'
test_path = '../data/kaggle_cifar10_tiny/test.csv'
file_path = '../data/kaggle_cifar10_tiny/'

# 统计label种类,并排序
cifar_labels = sorted(list(set(train_data['label'])))
# 将label对应编号
labels_to_num = dict(zip(cifar_labels, range(len(cifar_labels))))
# print(labels_to_num)
"""
{'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 
'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
"""
# 将编号对应label,用于后续预测
num_to_labels = {value : key for key, value in labels_to_num.items()}
# print(num_to_labels)
"""
{0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer', 
5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
"""


def get_image_filenames(folder_path, extensions=['.png', '.jpg', '.jpeg']):
    # 获取指定文件夹中的所有图片文件
    image_files = []
    for ext in extensions:
        image_files.extend(glob.glob(os.path.join(folder_path, f'*{ext}')))
    # 返回图片文件名列表
    return [os.path.basename(image) for image in image_files]

def save_filenames_to_csv(filenames, csv_path):
    with open(csv_path, mode='w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        # 写入CSV的第一行
        writer.writerow(['id'])
        # 写入每个文件名
        for filename in filenames:
            writer.writerow([filename])

# 获取测试图片名
test_images_path = '../data/kaggle_cifar10_tiny/test'
image_filenames = get_image_filenames(test_images_path)
# 保存到CSV文件
save_filenames_to_csv(image_filenames, file_path + 'test.csv') 


class CifarDataset(Dataset):
    def __init__(self, csv_path, file_path, mode='train', valid_ratio=0.2, resize_height=224, resize_width=224):
        """
        初始化 LeavesDataset 对象。
        参数:
            csv_path (str): 包含图像路径和标签的 CSV 文件路径。
            file_path (str): 图像文件所在目录的路径。
            mode (str, optional): 数据集的模式。可以是 'train', 'valid' 或 'test'。默认值为 'train'。
            valid_ratio (float, optional): 用于验证的数据比例。默认值为 0.2。
            resize_height (int, optional): 调整图像高度的大小。默认值为 224。
            resize_width (int, optional): 调整图像宽度的大小。默认值为 224。
        """
        # 存储图像调整大小的高度和宽度
        self.resize_height = resize_height
        self.resize_width = resize_width
        
        # 存储图像文件路径和模式(train/valid/test)
        if mode == 'train' or mode == 'valid':
            self.file_path = file_path + 'train/'
        else:
             self.file_path = file_path + 'test/'
        self.mode = mode
        
        # 读取包含图像路径和标签的 CSV 文件
        self.data_info = pd.read_csv(csv_path, header=0)
        
        # 获取样本总数
        self.data_len = len(self.data_info.index)
        
        # 计算训练集样本数
        self.train_len = int(self.data_len * (1 - valid_ratio))

        # 根据模式处理数据
        if self.mode == 'train':
            # 训练模式下的图像和标签
            self.train_img = np.asarray(self.data_info.iloc[0:self.train_len, 0])
            self.train_label = np.asarray(self.data_info.iloc[0:self.train_len, 1])
            self.image_arr = self.train_img
            self.label_arr = self.train_label
        elif self.mode == 'valid':
            # 验证模式下的图像和标签
            self.valid_img = np.asarray(self.data_info.iloc[self.train_len:, 0])
            self.valid_label = np.asarray(self.data_info.iloc[self.train_len:, 1])
            self.image_arr = self.valid_img
            self.label_arr = self.valid_label
        elif self.mode == 'test':
            # 测试模式下的图像
            self.test_img = np.asarray(self.data_info.iloc[:, 0])
            self.image_arr = self.test_img

        # 获取图像数组的长度
        self.len_image = len(self.image_arr)
        print(f'扫描所有 {mode} 数据,共 {self.len_image} 张图像')

    def __getitem__(self, idx):
        """
        获取指定索引的图像和标签。

        参数:
            idx (int): 标签文本对应编号的索引

        返回:
            如果是测试模式,返回图像张量;
            否则返回图像张量和标签。
        """
        # 打开图像文件
        if self.mode == 'test':
             self.img = Image.open(self.file_path + str(self.image_arr[idx]))
        else :
            self.img = Image.open(self.file_path + str(self.image_arr[idx]) + '.png')

        if self.mode == 'train':
            # 训练模式下的数据增强
            trans =torchvision.transforms.Compose([
                torchvision.transforms.Resize((self.resize_height, self.resize_width)),
                torchvision.transforms.RandomHorizontalFlip(p=0.5),
                torchvision.transforms.RandomVerticalFlip(p=0.5),
                torchvision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0),ratio=(1.0, 1.0)),
                torchvision.transforms.RandomRotation(degrees=30),
                # torchvision.transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
                # torchvision.transforms.RandomResizedCrop(size=self.resize_height, scale=(0.8, 1.0)),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
            self.img = trans(self.img)
        else:
            # 验证和测试模式下的简单处理
            trans = torchvision.transforms.Compose([
                torchvision.transforms.Resize((self.resize_height, self.resize_width)),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
            self.img = trans(self.img)
        
        if self.mode == 'test':
            return self.img
        else:
            # 获取标签文本对应的编号
            self.label = labels_to_num[self.label_arr[idx]]
            return self.img, self.label

    def __call__(self, idx):
        """
        使对象可以像函数一样被调用。
        
        参数:
            idx (int):标签文本对应编号的索引
            
        返回:
            调用 __getitem__ 方法并返回结果。
        """
        return self.__getitem__(idx)

    def __len__(self):
        """
        获取数据集的长度。
        
        返回:
            数据集中图像的数量。
        """
        return self.len_image
    

train_dataset = CifarDataset(train_path,file_path, mode='train', valid_ratio=0.1, resize_height=40, resize_width=40)
valid_dataset = CifarDataset(train_path, file_path, mode='valid',valid_ratio=0.1, resize_height=40, resize_width=40)
test_dataset = CifarDataset(test_path, file_path, mode='test',valid_ratio=0.1, resize_height=40, resize_width=40)


batch_size = 32 
train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=0)
valid_iter = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=0)
test_iter = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=0)
2.模型训练
def train_batch(net, X, y, loss, trainer, devices):
    """使用多GPU训练一个小批量数据。
    参数:
    net: 神经网络模型。
    X: 输入数据,张量或张量列表。
    y: 标签数据。
    loss: 损失函数。
    trainer: 优化器。
    devices: GPU设备列表。
    返回:
    train_loss_sum: 当前批次的训练损失和。
    train_acc_sum: 当前批次的训练准确度和。
    """
    # 如果输入数据X是列表类型
    if isinstance(X, list):
        # 将列表中的每个张量移动到第一个GPU设备
        X = [x.to(devices[0]) for x in X]
    else:
        X = X.to(devices[0])# 如果X不是列表,直接将X移动到第一个GPU设备
    y = y.to(devices[0])# 将标签数据y移动到第一个GPU设备
    net.train() # 设置网络为训练模式
    trainer.zero_grad()# 梯度清零
    pred = net(X) # 前向传播,计算预测值
    l = loss(pred, y) # 计算损失
    l.sum().backward()# 反向传播,计算梯度
    trainer.step() # 更新模型参数
    train_loss_sum = l.sum()# 计算当前批次的总损失
    train_acc_sum = d2l.accuracy(pred, y)# 计算当前批次的总准确度
    return train_loss_sum, train_acc_sum# 返回训练损失和与准确度和

def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay,param_group=True):

    # trainer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9,weight_decay=wd)
    trainer = torch.optim.Adam(net.parameters(), lr=lr,weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)
    loss = nn.CrossEntropyLoss(reduction="none")
    num_batches, timer = len(train_iter), d2l.Timer()
    legend = ['train loss', 'train acc']
    if valid_iter is not None:
        legend.append('valid acc')
    animator = lp.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=legend)
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    for epoch in range(num_epochs):
        net.train()
        metric = lp.Accumulator(3)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = train_batch(net, features, labels,loss, trainer, devices)
            metric.add(l, acc, labels.shape[0])
            timer.stop()
            train_l = metric[0] / metric[2] # 计算训练损失
            train_acc = metric[1] / metric[2] # 计算训练准确率
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,(train_l , train_acc,None))
        if valid_iter is not None:
            valid_acc = d2l.evaluate_accuracy_gpu(net, valid_iter)
            animator.add(epoch + 1, (None, None, valid_acc))
        scheduler.step()
        print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
              f'valid_acc {valid_acc:.3f}')
        
    measures = (f'train loss {metric[0] / metric[2]:.3f}, '
                f'train acc {metric[1] / metric[2]:.3f}')
    if valid_iter is not None:
        measures += f', valid acc {valid_acc:.3f}'
    print(measures + f'\n{metric[2] * num_epochs / timer.sum():.1f}'
          f' examples/sec on {str(devices)}')
3.定义超参数
# 定义模型
net = d2l.resnet18(len(cifar_labels),3)
devices, num_epochs, lr, wd = d2l.try_all_gpus(), 100, 3e-4, 5e-4
lr_period, lr_decay = 4, 0.9
train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)
plt.show()
# train loss 0.153, train acc 0.955, valid acc 0.469
# 873.5 examples/sec on [device(type='cuda', index=0)]

在这里插入图片描述

4.模型预测
# 针对测试集进行分类预测
def predict(net, data_loader, devices):
    """
    使用模型进行预测

    参数:
        net (torch.nn.Module): 要进行预测的模型
        data_loader (torch.utils.data.DataLoader): 数据加载器,用于提供待预测的数据
        devices (list): 计算设备列表(CPU或GPU)

    返回:
        all_preds (list): 包含所有预测结果的列表
    """
    all_preds = []  # 存储所有预测结果
    net.to(devices[0])  # 将模型移动到指定设备
    net.eval()  # 设置模型为评估模式
    with torch.no_grad():  # 在不需要计算梯度的上下文中进行
        for X in data_loader:  # 遍历数据加载器
            X = X.to(devices[0])  # 将数据移动到指定设备
            outputs = net(X)  # 前向传播,计算模型输出
            _, preds = torch.max(outputs, 1)  # 获取预测结果
            all_preds.extend(preds.cpu().numpy())  # 将预测结果添加到列表中
    return all_preds  # 返回所有预测结果

# 调用预测函数
predictions = predict(net, test_iter, devices)
# 映射预测结果到标签
mapped_predictions = [num_to_labels[int(i)] for i in predictions]
# 读取测试数据
test_data = pd.read_csv(test_path)
# 将预测结果添加到测试数据中
test_data['label'] = pd.Series(mapped_predictions)
# 创建提交文件
submission = pd.concat([test_data['id'], test_data['label']], axis=1)
# 保存提交文件
submission.to_csv(file_path + 'submission.csv', index=False)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/761571.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Mac多线程下载管理器:Neat Download Manage 最新版

Neat Download Manager(NDM)是一款功能强大的下载管理软件,它可以帮助用户更有效地管理和下载网络资源。这款软件支持多种浏览器和协议,可以提升下载速度,恢复中断的下载任务,以及自动化下载过程。在使用任…

如何学好AI绘画?点这里有答案!

前言 地狱难度的求职模式下,“掌握一门技术”的那部分求职者,远比其他人更有竞争力;而拥有出色技术和技能的设计师、以及未来想做设计师的小伙伴们,怎么才能更好实现工作自由? 只有两个字:学习。 学习新…

【Go】excelize库实现excel导入导出封装(四),导出时自定义某一列或多列的单元格样式

大家好,这里是符华~ 查看前三篇: 【Go】excelize库实现excel导入导出封装(一),自定义导出样式、隔行背景色、自适应行高、动态导出指定列、动态更改表头 【Go】excelize库实现excel导入导出封装(二&…

uniapp运行到小程序Vue.use注册全局组件不起作用

真想吐槽一下小程序,uniapp运行到小程序使用Vue.use注册全局组件根本不起作用,也不报错,这只是其中一个问题,其他还有很多问题,比如vue中正常使用的没问题的语法,运行到小程序就不行,又是包太大…

第一后裔延迟高怎么办?快速降低第一后裔延迟

第一后裔/The First Descendant一款射击游戏,融合了刷宝、角色扮演、团队合作、剧情等元素,让每个玩家都能在自己的角度上,找到切入点,并不断地成长,一步步解开后裔身上隐藏的秘密。近期该作正式上线,很多玩…

如何选择适合您业务需求的多语言跨境电商系统源码

随着互联网技术的飞速发展和全球市场的日益融合,多语言跨境电商已经成为许多企业进军国际市场的重要战略。在这个竞争激烈的时代,拥有一个适合自己业务需求的多语言跨境电商系统源码至关重要。本篇文章将为您揭秘如何选择适合您业务需求的多语言跨境电商…

接口自动化测试-项目实战

什么是接口自动化测试:使用工具或代码代替人对接口进行测试 测试项目结构(python包) 1、接口api包 2、script:业务脚本 3、data:数据 4、config.py :配置文件 5、reporter:报告 错误问题: 1、未打印任何东西。添加pip ins…

浅谈定时器之JSR223 定时器

浅谈定时器之JSR223 定时器 JSR223 定时器作为JMeter提供的众多定时器之一,以其高度的灵活性和可编程性脱颖而出,允许用户通过脚本自定义延时逻辑。本文将详细介绍JSR223定时器的特性和使用方法。 JSR223 定时器简介 JSR223 定时器利用了Java平台的JS…

家政小程序的开发,带动市场快速发展,提高家政服务质量

当下生活水平逐渐提高,也增加了年轻人的工作压力,同时老龄化也在日益增加,使得大众对家政的需求日益提高,能力、服务质量高的家政人员能够有效提高大众的生活幸福指数。 但是,传统的家政服务模式存在着效率低、用户与…

Unity 解包工具(AssetStudio/UtinyRipper)

文章目录 1.UtinyRipper2.AssetStudio 1.UtinyRipper 官方地址: https://github.com/mafaca/UtinyRipper/ 下载步骤: 2.AssetStudio 官方地址: https://github.com/Perfare/AssetStudio 下载步骤:

2024百元蓝牙耳机哪个好?2024性价比最高的蓝牙耳机推荐

2024想要在百元左右找到一款好用的性价比高的蓝牙耳机,确实是个不小的挑战。市场上各种耳机品牌和型号琳琅满目,各有各的特点。你可能会疑惑,如何才能在预算内挑选到一款性价比高、音质好的耳机呢?这篇文章将为你提供一些选购百元…

【SpringBoot Web框架实战教程】06 SpringBoot 整合 Druid

不积跬步,无以至千里;不积小流,无以成江海。大家好,我是闲鹤,微信:xxh_1459,十多年开发、架构经验,先后在华为、迅雷服役过,也在高校从事教学3年;目前已创业了…

【Mac】王国保卫战:起源 for mac(塔防策略游戏)游戏介绍和安装教程

游戏介绍 《王国保卫战:起源》(Kingdom: Origins)是一款策略塔防游戏,其核心玩法融合了塔防、策略管理和资源管理元素。游戏的主要目标是在一个开放的像素化世界中建立和管理自己的王国,并抵御夜晚来袭的怪物入侵。 …

华为仓颉语言体验:一个简单的socket服务端实现

前言 由于仓颉目前是内测状态, 不能展示仓颉的详细信息,但是华为仓颉官网的公共文档的内容是可以公开的。 我相信有不少喜欢编程的朋友都申请了内测,但是一些编程初学者应该和我一样,处于摸索阶段。所以,我这里把我测…

如果对方没做幂等!记一次生产订单重复的反思

最近公司公司的旧系统中发现了一个bug。业务部门反馈,尽管用户只支付了一年的服务费用,系统却将有效期增加了两年。 原因分析: 到底是什么原因呢? 经过日志分析,发现消息队列(MQ)向第三方服务发…

想用AI高端算力训练模型?试试英智BayStone平台

随着生成式人工智能的迅猛增长,各大公司纷纷推出强大的 AI产品以提升自身核心竞争力,对于依赖基础模型进行推理训练,同时需要高级基础设施的人工智能初创企业,急需使用高端智算算力来加速模型训练与产品研发创新。 算力是否充足&…

HiBit Uninstaller:软件批量卸载,一触即得

名人说:莫道谗言如浪深,莫言迁客似沙沉。 ——刘禹锡《浪淘沙》 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 目录 一、软件介绍1、HiBit Uninstaller2、核心功能 二、下载安装1、下载2、安装 …

【Sklearn-驯化】一文从基础帮你搞懂svm算法做分类和回归的原理以及实践

【Sklearn-驯化】一文从基础帮你搞懂svm算法做分类和回归的原理以及实践 本次修炼方法请往下查看 🌈 欢迎莅临我的个人主页 👈这里是我工作、学习、实践 IT领域、真诚分享 踩坑集合,智慧小天地! 🎇 免费获取相关内容文…

【运维】如何在Ubuntu中设置一个内存守护进程来确保内存不会溢出

文章目录 前言增加守护进程1. 编写监控脚本2. 创建 systemd 服务文件3. 启动并启用服务4. 验证服务是否运行注意事项 如何修改守护进程1. 修改监控脚本2. 重新加载并重启服务3. 验证服务是否运行总结 如何设置一个日志文件来查看信息1. 修改监控脚本以记录日志方法一&#xff1…

antd DatePicker日期选择框限制最多选择一年

实现效果 实现逻辑 import React, { useState } from react;const ParentComponent () > {const [dates, setDates] useState(null);const disabledDate (current) > {if (!dates) {return false;}const tooLate dates[0] && current.diff(dates[0], days) &…