基于 PyTorch 的树叶分类任务:从数据准备到模型训练与测试

基于 PyTorch 的树叶分类任务:从数据准备到模型训练与测试


1. 引言

在计算机视觉领域,图像分类是一个经典的任务。本文将详细介绍如何使用 PyTorch 实现一个树叶分类任务。我们将从数据准备开始,逐步构建模型、训练模型,并在测试集上进行预测,最终生成提交文件。
在这里插入图片描述


2. 环境准备

首先,确保已安装以下 Python 库:

pip install torch torchvision pandas d2l
  • torch:PyTorch 核心库。
  • torchvision:提供计算机视觉相关的工具。
  • pandas:用于处理 CSV 文件。
  • d2l:深度学习工具库,提供辅助函数。

3. 数据准备

竞赛链接:https://www.kaggle.com/competitions/classify-leaves/leaderboard?tab=public

3.1 数据集结构

假设数据集位于 classify-leaves 目录下,包含以下文件:

classify-leaves/
├── train.csv
├── test.csv
├── images/
    ├── image1.jpg
    ├── image2.jpg
    ...
  • train.csv:包含训练图像的路径和标签。
  • test.csv:包含测试图像的路径。

3.2 数据加载与预处理

import os
import pandas as pd
import random

imgpath = "classify-leaves"
trainlist = pd.read_csv(f"{imgpath}/train.csv")
num2name = list(trainlist["label"].value_counts().index)
random.shuffle(num2name)
name2num = {}
for i in range(len(num2name)):
    name2num[num2name[i]] = i
  • num2name:获取所有类别标签,并按类别数量排序。
  • name2num:将类别名称映射到数字编号。

4. 自定义数据集类

为了加载数据,我们需要定义一个自定义数据集类 Leaf_data

from torch.utils.data import Dataset
from d2l import torch as d2l

class Leaf_data(Dataset):
    def __init__(self, path, train, transform=lambda x: x):
        super().__init__()
        self.path = path
        self.transform = transform
        self.train = train
        if train:
            self.datalist = pd.read_csv(f"{path}/train.csv")
        else:
            self.datalist = pd.read_csv(f"{path}/test.csv")

    def __getitem__(self, index):
        res = ()
        tmplist = self.datalist.iloc[index, :]
        for i in tmplist.index:
            if i == "image":
                res += (self.transform(d2l.Image.open(f"{self.path}/{tmplist[i]}")),)
            else:
                res += (name2num[tmplist[i]],)
        if len(res) < 2:
            res += (tmplist[i],)
        return res

    def __len__(self):
        return len(self.datalist)
  • __getitem__:根据索引返回一个样本,包括图像和标签。
  • __len__:返回数据集的长度。

5. 模型定义与初始化

我们使用预训练的 ResNet34 模型,并修改最后一层以适应分类任务:

import torch
import torchvision
from torch import nn

def init_weight(m):
    if type(m) in [nn.Linear, nn.Conv2d]:
        nn.init.xavier_normal_(m.weight)

net = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)
net.fc = nn.Linear(in_features=512, out_features=len(name2num), bias=True)
net.fc.apply(init_weight)
net.to(try_gpu())
  • init_weight:使用 Xavier 初始化方法初始化全连接层的权重。
  • net:加载预训练的 ResNet34 模型,并修改最后一层全连接层。

6. 训练过程

6.1 优化器与损失函数

lr = 1e-4
parames = [parame for name, parame in net.named_parameters() if name not in ["fc.weight", "fc.bias"]]
trainer = torch.optim.Adam([{"params": parames}, {"params": net.fc.parameters(), "lr": lr * 10}], lr=lr)
LR_con = torch.optim.lr_scheduler.CosineAnnealingLR(trainer, 1, 0)
loss = nn.CrossEntropyLoss(reduction='none')
  • trainer:使用 Adam 优化器,对全连接层使用更高的学习率。
  • LR_con:使用余弦退火学习率调度器。
  • loss:使用交叉熵损失函数。

6.2 训练函数


def train_batch(features, labels, net, loss, trainer, device):
    # 将数据移动到指定设备(如 GPU)
    features, labels = features.to(device), labels.to(device)
    
    # 前向传播
    outputs = net(features)
    l = loss(outputs, labels).mean()  # 计算损失
    
    # 反向传播和优化
    trainer.zero_grad()  # 梯度清零
    l.backward()         # 反向传播
    trainer.step()      # 更新参数
    
    # 计算准确率
    acc = (outputs.argmax(dim=1) == labels).float().mean()
    
    return l.item(), acc.item()

def train(train_data, test_data, net, loss, trainer, num_epochs, device=try_gpu()):
    best_acc = 0
    timer = d2l.Timer()
    plot = d2l.Animator(xlabel="epoch", xlim=[1, num_epochs], legend=['train loss', 'train acc', 'test loss'], ylim=[0, 1])
    for epoch in range(num_epochs):
        metric = d2l.Accumulator(4)
        for i, (features, labels) in enumerate(train_data):
            timer.start()
            l, acc = train_batch(features, labels, net, loss, trainer, device)
            metric.add(l, acc, labels.shape[0], labels.numel())
            timer.stop()
        test_acc = d2l.evaluate_accuracy_gpu(net, test_data, device=device)
        if test_acc > best_acc:
            save_model(net)
            best_acc = test_acc
        plot.add(epoch + 1, (metric[0] / metric[2], metric[1] / metric[3], test_acc))
        print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')
    print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on {str(device)}')
    print(f"best acc {best_acc}")
    return metric[0] / metric[2], metric[1] / metric[3], test_acc
  • train:训练模型,记录损失和准确率,并在验证集上评估模型。

7. 测试与结果保存

在测试集上进行预测,并保存结果到 CSV 文件:

net.load_state_dict(torch.load(model_path))
augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize(224),
    torchvision.transforms.ToTensor(), norm
])
test_data = Leaf_data(imgpath, False, augs)
test_dataloader = Data.DataLoader(test_data, batch_size=64, shuffle=False)
res = pd.DataFrame(columns=["image", "label"], index=range(len(test_data)))
net = net.cpu()
count = 0
for X, y in test_dataloader:
    preds = net(X).detach().argmax(dim=-1).numpy()
    preds = pd.DataFrame(y, index=map(lambda x: num2name[x], preds))
    preds.loc[:, 1] = preds.index
    preds.index = range(count, count + len(y))
    res.iloc[preds.index] = preds
    count += len(y)
    print(f"loaded {count}/{len(test_data)} datas")
res.to_csv('./submission.csv', index=False)
  • test_dataloader:加载测试数据。
  • res:保存预测结果到 CSV 文件。

8. 总结

本文详细介绍了如何使用 PyTorch 实现一个树叶分类任务,包括数据准备、模型定义、训练、验证和测试。通过本文,您可以掌握以下技能:

  1. 自定义数据集类的实现。
  2. 使用预训练模型进行迁移学习。
  3. 训练模型并保存最佳模型。
  4. 在测试集上进行预测并生成提交文件。

希望本文对您有所帮助!如果有任何问题,欢迎在评论区留言讨论。😊

完整代码

import os
import torch
from torch.utils import data as Data
import torchvision
from torch import nn
from d2l import torch as d2l
import pandas as pd
import random

# 数据准备
imgpath = "classify-leaves"
trainlist = pd.read_csv(f"{imgpath}/train.csv")
num2name = list(trainlist["label"].value_counts().index)
random.shuffle(num2name)
name2num = {}
for i in range(len(num2name)):
    name2num[num2name[i]] = i

# GPU 检查
def try_gpu():
    if torch.cuda.device_count() > 0:
        return torch.device('cuda')
    return torch.device('cpu')

# 模型保存路径
model_dir = './models'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
model_path = os.path.join(model_dir, 'pre_res_model.ckpt')

def save_model(net):
    torch.save(net.state_dict(), model_path)

# 自定义数据集类
class Leaf_data(Data.Dataset):
    def __init__(self, path, train, transform=lambda x: x):
        super().__init__()
        self.path = path
        self.transform = transform
        self.train = train
        if train:
            self.datalist = pd.read_csv(f"{path}/train.csv")
        else:
            self.datalist = pd.read_csv(f"{path}/test.csv")

    def __getitem__(self, index):
        res = ()
        tmplist = self.datalist.iloc[index, :]
        for i in tmplist.index:
            if i == "image":
                res += (self.transform(d2l.Image.open(f"{self.path}/{tmplist[i]}")),)
            else:
                res += (name2num[tmplist[i]],)
        if len(res) < 2:
            res += (tmplist[i],)
        return res

    def __len__(self):
        return len(self.datalist)

def train_batch(features, labels, net, loss, trainer, device):
    # 将数据移动到指定设备(如 GPU)
    features, labels = features.to(device), labels.to(device)
    
    # 前向传播
    outputs = net(features)
    l = loss(outputs, labels).mean()  # 计算损失
    
    # 反向传播和优化
    trainer.zero_grad()  # 梯度清零
    l.backward()         # 反向传播
    trainer.step()      # 更新参数
    
    # 计算准确率
    acc = (outputs.argmax(dim=1) == labels).float().mean()
    
    return l.item(), acc.item()


# 训练函数
def train(train_data, test_data, net, loss, trainer, num_epochs, device=try_gpu()):
    best_acc = 0
    timer = d2l.Timer()
    plot = d2l.Animator(xlabel="epoch", xlim=[1, num_epochs], legend=['train loss', 'train acc', 'test loss'], ylim=[0, 1])
    for epoch in range(num_epochs):
        metric = d2l.Accumulator(4)
        for i, (features, labels) in enumerate(train_data):
            timer.start()
            l, acc = train_batch(features, labels, net, loss, trainer, device)
            metric.add(l, acc, labels.shape[0], labels.numel())
            timer.stop()
        test_acc = d2l.evaluate_accuracy_gpu(net, test_data, device=device)
        if test_acc > best_acc:
            save_model(net)
            best_acc = test_acc
        plot.add(epoch + 1, (metric[0] / metric[2], metric[1] / metric[3], test_acc))
        print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')
    print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on {str(device)}')
    print(f"best acc {best_acc}")
    return metric[0] / metric[2], metric[1] / metric[3], test_acc

# 模型初始化
def init_weight(m):
    if type(m) in [nn.Linear, nn.Conv2d]:
        nn.init.xavier_normal_(m.weight)

net = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)
net.fc = nn.Linear(in_features=512, out_features=len(name2num), bias=True)
net.fc.apply(init_weight)
net.to(try_gpu())

# 优化器和损失函数
lr = 1e-4
parames = [parame for name, parame in net.named_parameters() if name not in ["fc.weight", "fc.bias"]]
trainer = torch.optim.Adam([{"params": parames}, {"params": net.fc.parameters(), "lr": lr * 10}], lr=lr)
LR_con = torch.optim.lr_scheduler.CosineAnnealingLR(trainer, 1, 0)
loss = nn.CrossEntropyLoss(reduction='none')

# 数据增强和数据加载
batch = 64
num_epochs = 10
norm = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize(224),
    torchvision.transforms.RandomHorizontalFlip(p=0.5),
    torchvision.transforms.ToTensor(), norm
])
train_data, valid_data = Data.random_split(
    dataset=Leaf_data(imgpath, True, augs),
    lengths=[0.8, 0.2]
)
train_dataloder = Data.DataLoader(train_data, batch, True)
valid_dataloder = Data.DataLoader(valid_data, batch, True)

# 训练模型
train(train_dataloder, valid_dataloder, net, loss, trainer, num_epochs)

# 测试模型
net.load_state_dict(torch.load(model_path))
augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize(224),
    torchvision.transforms.ToTensor(), norm
])
test_data = Leaf_data(imgpath, False, augs)
test_dataloader = Data.DataLoader(test_data, batch_size=64, shuffle=False)
res = pd.DataFrame(columns=["image", "label"], index=range(len(test_data)))
net = net.cpu()
count = 0
for X, y in test_dataloader:
    preds = net(X).detach().argmax(dim=-1).numpy()
    preds = pd.DataFrame(y, index=map(lambda x: num2name[x], preds))
    preds.loc[:, 1] = preds.index
    preds.index = range(count, count + len(y))
    res.iloc[preds.index] = preds
    count += len(y)
    print(f"loaded {count}/{len(test_data)} datas")
res.to_csv('./submission.csv', index=False)

参考链接:

  • PyTorch 官方文档
  • torchvision 官方文档
  • d2l 深度学习工具库

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/968102.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

11vue3实战-----封装缓存工具

11vue3实战-----封装缓存工具 1.背景2.pinia的持久化思路3.以localStorage为例解决问题4.封装缓存工具 1.背景 在上一章节&#xff0c;实现登录功能时候&#xff0c;当账号密码正确&#xff0c;身份验证成功之后&#xff0c;把用户信息保存起来&#xff0c;是用的pinia。然而p…

vue中使用高德地图自定义掩膜背景结合threejs

技术架构 vue3高德地图2.0threejs 代码步骤 这里我们就用合肥市为主要的地区&#xff0c;将其他地区扣除&#xff0c;首先使用高德的webapi的DistrictSearch功能&#xff0c;使用该功能之前记得检查一下初始化的时候是否添加到plugins中&#xff0c;然后搜索合肥市的行政数据…

02、QLExpress从入门到放弃,相关API和文档

QLExpress从入门到放弃,相关API和文档 一、属性开关 public class ExpressRunner {private boolean isTrace;private boolean isShortCircuit;private boolean isPrecise; }/*** 是否需要高精度计算*/ private boolean isPrecise false;高精度计算在会计财务中非常重要&…

二、OSG学习笔记-入门开发

前一章节&#xff1a;一、OSG学习笔记-编译开发环境-CSDN博客https://blog.csdn.net/weixin_36323170/article/details/145513691 一、环境配置 1、VS需要配置头文件路径如下图&#xff1a;&#xff08;$(OSG_INCLUDE)&#xff09; 这里的OSG_INCLUDE,为环境变量名&#xff0…

C++ Primer 语句作用域

欢迎阅读我的 【CPrimer】专栏 专栏简介&#xff1a;本专栏主要面向C初学者&#xff0c;解释C的一些基本概念和基础语言特性&#xff0c;涉及C标准库的用法&#xff0c;面向对象特性&#xff0c;泛型特性高级用法。通过使用标准库中定义的抽象设施&#xff0c;使你更加适应高级…

Windows逆向工程入门之汇编开发框架解析

公开视频 -> 链接点击跳转公开课程博客首页 -> ​​​链接点击跳转博客主页 目录 环境搭建与配置 Visual Studio配置 X86汇编基础框架 基本程序框架 数据定义与内存访问 过程&#xff08;函数&#xff09;定义 汇编框架解析 代码主体解析 完整代码执行 代码逻…

Android ndk兼容 64bit so报错

1、报错logcat如下 2025-01-13 11:34:41.963 4687-4687 DEBUG pid-4687 A #01 pc 00000000000063b8 /system/lib64/liblog.so (__android_log_default_aborter16) (BuildId: 467c2038cdfa767245f9280e657fdb85) 2025…

工业路由器物联网应用,智慧环保环境数据监测

在智慧环保环境数据监测中工业路由器能连接各类分散的传感器&#xff0c;实现多源环境数据集中采集&#xff0c;并通过多种通信网络稳定传输至数据中心或云平台。 工作人员借助工业路由器可远程监控设备状态与环境数据&#xff0c;还能远程配置传感器参数。远程控制设置数据阈…

QT修仙笔记 事件大圆满 闹钟大成

学习笔记 牛客刷题 闹钟 时钟显示 通过 QTimer 每秒更新一次 QLCDNumber 显示的当前时间&#xff0c;格式为 hh:mm:ss&#xff0c;实现实时时钟显示。 闹钟设置 使用 QDateTimeEdit 让用户设置闹钟时间&#xff0c;可通过日历选择日期&#xff0c;设置范围为当前时间到未来 …

MapReduce到底是个啥?

在聊 MapReduce 之前不妨先看个例子&#xff1a;假设某短视频平台日活用户大约在7000万左右&#xff0c;若平均每一个用户产生3条行为日志&#xff1a;点赞、转发、收藏&#xff1b;这样就是两亿条行为日志&#xff0c;再假设每条日志大小为100个字节&#xff0c;那么一天就会产…

拯救者Y9000P双系统ubuntu22.04安装4070显卡驱动

拯救者Y9000P双系统ubuntu22.04安装4070显卡驱动 1. 前情&#xff1a; 1TB的硬盘&#xff0c;分了120G作ubuntu22.04。/boot: 300MB, / : 40GB, /home: 75G, 其余作swap area。 2. 一开始按这个教程&#xff1a;对我无效 https://blog.csdn.net/Eric_xkk/article/details/1…

Redis 数据类型 List 列表

列表类型是⽤来存储多个有序的字符串&#xff0c;如下图所⽰&#xff0c;a、b、c、d、e 五个元素从左到右组成了⼀个有序的列表&#xff0c;列表中的每个字符串称为元素&#xff08;element&#xff09;&#xff0c;⼀个列表最多可以存储 2^32 - 1个元素。在 Redis 中&#xff…

【devops】Macos 轻量化docker解决方案 orbstack | 不用Docker Desktop启动docker服务

一、orbstack OrbStack is the fast, light, and easy way to run Docker containers and Linux machines. It’s a supercharged WSL and Docker Desktop alternative, all in one easy-to-use app. 二、orbstack 的可视化

RabbitMQ消息队列 发送和接受

步骤 1: 安装 RabbitMQ 首先&#xff0c;需要安装 RabbitMQ&#xff0c;并确保它在运行中。 下载erlang语言包OTP。官网地址&#xff1a;Downloads - Erlang/OTP Rabbitmq官网下载地址&#xff1a;Downloading and Installing RabbitMQ — RabbitMQ 安装MQ注意事项&#xf…

2025最新版Node.js下载安装~保姆级教程

1. node中文官网地址&#xff1a;http://nodejs.cn/download/ 2.打开node官网下载压缩包&#xff1a; 根据操作系统不同选择不同版本&#xff08;win7系统建议安装v12.x&#xff09; 我这里选择最新版win 64位 3.安装node ①点击对话框中的“Next”&#xff0c;勾选同意后点…

Spring Boot 3.4 中 MockMvcTester 的新特性解析

引言 在 Spring Boot 3.4 版本中&#xff0c;引入了一个全新的 MockMvcTester 类&#xff0c;使 MockMvc 测试可以直接支持 AssertJ 断言。本文将深入探讨这一新特性&#xff0c;分析它如何优化 MockMvc 测试并提升测试的可读性。 Spring MVC 示例 为了演示 MockMvcTester 的…

WEB攻防-文件下载文件读取文件删除目录遍历目录穿越

目录 一、文件下载漏洞 1.1 文件下载案例&#xff08;黑盒角度&#xff09; 1.2 文件读取案例&#xff08;黑盒角度&#xff09; 二、文件删除 三、目录遍历与目录穿越 四、审计分析-文件下载漏洞-XHCMS 五、审计分析-文件读取漏洞-MetInfo-函数搜索 六、审计分析-…

01.Docker 概述

Docker 概述 1. Docker 的主要目标2. 使用Docker 容器化封装应用程序的意义3. 容器和虚拟机技术比较4. 容器和虚拟机表现比较5. Docker 的组成6. Namespace7. Control groups8. 容器管理工具9. docker 的优缺点10. 容器的相关技术 docker 官网: http://www.docker.com 帮助文档…

IDEA中常见问题汇总

&#x1f353; 简介&#xff1a;java系列技术分享(&#x1f449;持续更新中…&#x1f525;) &#x1f353; 初衷:一起学习、一起进步、坚持不懈 &#x1f353; 如果文章内容有误与您的想法不一致,欢迎大家在评论区指正&#x1f64f; &#x1f353; 希望这篇文章对你有所帮助,欢…

基于蜘蛛蜂优化算法的无人机集群三维路径规划Matlab实现

代码下载&#xff1a;私信博主回复基于蜘蛛蜂优化算法的无人机集群三维路径规划Matlab实现 《基于蜘蛛蜂优化算法的无人机集群三维路径规划》 摘要 本研究针对无人机集群三维路径规划问题&#xff0c;提出了一种基于蜘蛛蜂优化算法的解决方案。以5个无人机构成的集群为研究对…