「ResNet-18」70 个犬种的图片分类

✅作者简介:人工智能专业本科在读,喜欢计算机与编程,写博客记录自己的学习历程。
🍎个人主页:小嗷犬的个人主页
🍊个人网站:小嗷犬的技术小站
🥭个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。


本文目录

    • 数据集与 Notebook
    • 环境准备
    • 数据集
    • 可视化
    • 模型
    • 预测
    • Loss 与评价指标


数据集与 Notebook

数据集:70 Dog Breeds-Image Data Set
Notebook:「ResNet-18」70 Dog Breeds-Image Classification


环境准备

import warnings
warnings.filterwarnings('ignore')

禁用警告,防止干扰。

!pip install lightning --quiet

安装 PyTorch Lightning。

import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_theme(style="darkgrid", font_scale=1.5, font="SimHei", rc={"axes.unicode_minus":False})

导入常用的库,设置绘图风格。

import torch
import torchmetrics
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models

导入 PyTorch 相关的库。

import lightning.pytorch as pl
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

导入 PyTorch Lightning 相关的库。

seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
pl.seed_everything(seed, workers=True)

设置随机种子。


数据集

batch_size = 64

设置批次大小。

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

设置数据集的预处理。

train_dataset = datasets.ImageFolder(root="/kaggle/input/70-dog-breedsimage-data-set/train", transform=train_transform)
val_dataset = datasets.ImageFolder(root="/kaggle/input/70-dog-breedsimage-data-set/valid", transform=test_transform)
test_dataset = datasets.ImageFolder(root="/kaggle/input/70-dog-breedsimage-data-set/test", transform=test_transform)

读取数据集。

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

加载数据集。


可视化

class_names = train_dataset.classes
class_count = [train_dataset.targets.count(i) for i in range(len(class_names))]
df = pd.DataFrame({"Class": class_names, "Count": class_count})

plt.figure(figsize=(12, 20), dpi=100)
sns.barplot(x="Count", y="Class", data=df)
plt.tight_layout()
plt.show()

绘制训练集的类别分布。

训练集的类别分布

plt.figure(figsize=(12, 20), dpi=100)
images, labels = next(iter(val_loader))
for i in range(8):
    ax = plt.subplot(8, 4, i + 1)
    plt.imshow(images[i].permute(1, 2, 0).numpy())
    plt.title(class_names[labels[i]])
    plt.axis("off")
plt.tight_layout()
plt.show()

绘制训练集的样本。

训练集的样本


模型

class LitModel(pl.LightningModule):
    def __init__(self, num_classes=1000):
        super().__init__()
        self.model = models.resnet18(weights="IMAGENET1K_V1")
        # for param in self.model.parameters():
        #     param.requires_grad = False
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.precision = torchmetrics.Precision(task="multiclass", average="macro", num_classes=num_classes)
        self.recall = torchmetrics.Recall(task="multiclass", average="macro", num_classes=num_classes)
        self.f1score = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        x = self.model(x)
        return x

    def configure_optimizers(self):
        optimizer = optim.Adam(
            self.parameters(), lr=0.001, betas=(0.9, 0.99), eps=1e-08, weight_decay=1e-5
        )
        return optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss, on_step=True, on_epoch=False, prog_bar=True, logger=True)
        self.log_dict(
            {
                "train_acc": self.accuracy(y_hat, y),
                "train_prec": self.precision(y_hat, y),
                "train_recall": self.recall(y_hat, y),
                "train_f1score": self.f1score(y_hat, y),
            },
            on_step=True,
            on_epoch=False,
            logger=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("val_loss", loss, on_step=False, on_epoch=True, logger=True)
        self.log_dict(
            {
                "val_acc": self.accuracy(y_hat, y),
                "val_prec": self.precision(y_hat, y),
                "val_recall": self.recall(y_hat, y),
                "val_f1score": self.f1score(y_hat, y),
            },
            on_step=False,
            on_epoch=True,
            logger=True,
        )

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        self.log_dict(
            {
                "test_acc": self.accuracy(y_hat, y),
                "test_prec": self.precision(y_hat, y),
                "test_recall": self.recall(y_hat, y),
                "test_f1score": self.f1score(y_hat, y),
            }
        )

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        x, y = batch
        y_hat = self(x)
        preds = torch.argmax(y_hat, dim=1)
        return preds

定义模型。

num_classes = len(class_names)
model = LitModel(num_classes=num_classes)
logger = CSVLogger("./")
early_stop_callback = EarlyStopping(
    monitor="val_loss", min_delta=0.00, patience=5, verbose=False, mode="min"
)
trainer = pl.Trainer(
    max_epochs=20,
    enable_progress_bar=True,
    logger=logger,
    callbacks=[early_stop_callback],
    deterministic=True,
)
trainer.fit(model, train_loader, val_loader)

训练模型。

trainer.test(model, val_loader)

测试模型。


预测

pred = trainer.predict(model, test_loader)
pred = torch.cat(pred, dim=0)
pred = pd.DataFrame(pred.numpy(), columns=["Class"])
pred["Class"] = pred["Class"].apply(lambda x: class_names[x])

plt.figure(figsize=(12, 20), dpi=100)
sns.countplot(y="Class", data=pred)
plt.tight_layout()
plt.show()

绘制预测结果的类别分布。

预测结果的类别分布


Loss 与评价指标

log_path = logger.log_dir + "/metrics.csv"
metrics = pd.read_csv(log_path)
x_name = "epoch"

plt.figure(figsize=(8, 6), dpi=100)
sns.lineplot(x=x_name, y="train_loss", data=metrics, label="Train Loss", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_loss", data=metrics, label="Valid Loss", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.tight_layout()
plt.show()


plt.figure(figsize=(14, 12), dpi=100)

plt.subplot(2,2,1)
sns.lineplot(x=x_name, y="train_acc", data=metrics, label="Train Accuracy", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_acc", data=metrics, label="Valid Accuracy", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("Accuracy")

plt.subplot(2,2,2)
sns.lineplot(x=x_name, y="train_prec", data=metrics, label="Train Precision", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_prec", data=metrics, label="Valid Precision", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("Precision")

plt.subplot(2,2,3)
sns.lineplot(x=x_name, y="train_recall", data=metrics, label="Train Recall", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_recall", data=metrics, label="Valid Recall", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("Recall")

plt.subplot(2,2,4)
sns.lineplot(x=x_name, y="train_f1score", data=metrics, label="Train F1-Score", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_f1score", data=metrics, label="Valid F1-Score", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("F1-Score")

plt.tight_layout()
plt.show()

绘制 Loss 与评价指标的变化。

Loss

评价指标

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/172458.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Oracle实时同步技术

Oracle数据库的价值 Oracle数据库是一种高度可靠、安全和强大的关系型数据库管理系统,它具有以下几个方面的价值: 可靠性和稳定性:Oracle数据库以其高度可靠性、稳定性和数据完整性而闻名于世。 安全性:Oracle数据库提供了一系列…

解决uncompyle6反编译报错KeyError

报错内容:KeyError:3.11.0 (这个是我自己的Python版本号) 主要原因就是你的Python版本没在它库文件支持的版本里,我们需要进行手动添加即可。 首先找到文件:magics.py 我的是在这个路径下:D:\…

P9232 [蓝桥杯 2023 省 A] 更小的数(区间DP)

求大数字某连续部分反转后&#xff0c;比原数字小的个数 思路&#xff1a;自前向后遍历 ai是位于数字第i位的数字 aj是位于数字第j位的数字&#xff08;i<j&#xff09; ai>aj f[ai][aj]1; ai<aj f[ai][aj]0; aiaj f[ai][aj]f…

宏电股份荣膺国家知识产权优势企业称号,科技创新与研发实力获国家级认可

近日&#xff0c;国家知识产权局公布了2023年度国家知识产权优势企业的评审结果&#xff0c;宏电股份凭借强大的技术实力和创新能力&#xff0c;荣获“国家知识产权优势企业”荣誉称号。这一荣誉是对宏电股份在技术创新和知识产权创造、运用、保护及管理方面的高度认可&#xf…

通明智云宣布完成数千万元A+轮融资, 引领云原生与信创两翼齐飞的应用交付解决方案

近日&#xff0c;通明智云&#xff08;北京&#xff09;科技有限公司&#xff08;简称&#xff1a;通明智云&#xff09;宣布完成数千万元A轮融资&#xff0c;由全聚合与信公投资联合投资&#xff0c;明论资本担任本轮融资独家财务顾问。本轮融资资金将主要用于NJet云原生应用引…

华为云IoT与OpenHarmony深度协同,加速设备上鸿即上云【云驻共创】

本次专题论坛探讨了华为云IoT与Open Harmony的深度协同、边缘屏蔽硬件差异、实现智慧隧道全方位智能化管理&#xff0c;以及华为云与Open Harmony生态的合作。同时也介绍了华为云物联网卡平台、HTTP2协议以及华为物联网在交通领域的应用。 一&#xff0e;华为云IoT与Open Harm…

Jmeter 自动化性能测试常见问题汇总

一、request 请求超时设置 timeout 超时时间是可以手动设置的&#xff0c;新建一个 http 请求&#xff0c;在“高级”设置中找到“超时”设置&#xff0c;设置连接、响应时间为2000ms。 1. 请求连接超时&#xff0c;连不上服务器。 现象&#xff1a; Jmeter表现形式为&…

18章总结—Swing程序设计

例题1 package admi; import java.awt.*; import javax.swing.*; public class JFreamTest { public static void main(String[] args) { JFrame jfnew JFrame(); jf.setTitle("创建一个JFrame窗体"); Container containerjf.getC…

ubuntu安装nvm

需求 在 virtualbox 虚拟机上运行的 ubuntu &#xff08;22.04.3&#xff09;里安装 nvm &#xff08;Node Version Manager&#xff09; 简述 官网文档 &#xff08;github地址&#xff09;上有提到两种安装方式&#xff0c;一种是直接 curl | wget 命令安装&#xff0c;一…

Royal TSX v6.0.1(远程管理软件)

MacOS远程管理软件哪款好&#xff1f;Royal TSX mac是一款功能非常强大适用于 Mac 的远程连接管理工具。兼容多种连接类型&#xff0c;比如&#xff1a;RDP、VNC、基于SSH连接的终端&#xff0c;SFTP/FTP/SCP或基于Web的连接管理&#xff0c;Royal TSX 都可以满足您的要求&…

LLM之Prompt(二):清华提出Prompt 对齐优化技术BPO

论文题目&#xff1a;《Black-Box Prompt Optimization: Aligning Large Language Models without Model Training》 论文链接&#xff1a;https://arxiv.org/abs/2311.04155 github地址&#xff1a;https://github.com/thu-coai/BPO BPO背景介绍 最近&#xff0c;大型语言模…

学习笔记—吴恩达《AI for everyone》

【写在前面】 学习视频来源&#xff1a;B站“GPT中英字幕课程资源”&#xff08;见图片水印&#xff09;。 此文是自学笔记&#xff0c;主要是截图视频课件中的一些知识点&#xff0c;只做自学使用。 一. AI 介绍 二. 机器学习 Machine Learning 三. 什么是数据 What is AI 四…

玩转大模型行业应用,且看盘古大模型全栈工程能力展身手【云驻共创】

AI技术在金融和工业领域的应用不断扩展&#xff0c;促进了金融行业的数字化转型和产业升级。AI提供了专属财富管家和工业范式的解决方案&#xff0c;在金融领域的应用包括风险评估和投资建议&#xff0c;而在工业领域的应用则涵盖了数据分析和机器人操作。与此同时&#xff0c;…

神经网络中BN层简介及位置分析

1. 简介 Batch Normalization是深度学习中常用的技巧&#xff0c;Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift (Ioffe and Szegedy, 2015) 第一次介绍了这个方法。 这个方法的命名&#xff0c;明明是Standardization, 非…

springcloud医院挂号预约系统源码

开发技术&#xff1a; jdk1.8&#xff0c;mysql5.7&#xff0c;nodejs&#xff0c;idea&#xff0c;vscode springcloud springboot mybatis vue elementui 功能介绍&#xff1a; 用户端&#xff1a; 登录注册 首页显示医生列表&#xff0c;医院简介&#xff0c;点击医生…

Stable Diffusion专场公开课

从SD原理、本地部署到其二次开发 分享时间&#xff1a;11月25日14&#xff1a;00-17&#xff1a;00 分享大纲 从扩散模型DDPM起步理解SD背后原理 SD的本地部署:在自己电脑上快速搭建、快速出图如何基于SD快速做二次开发(以七月的AIGC模特生成系统为例) 分享人简介 July&#…

手把手设计C语言版循环队列(力扣622:设计循环队列)

文章目录 前言描述分析力扣AC代码 力扣&#xff1a; 622.设计循环队列 前言 队列会出现“假溢出”现象&#xff0c;即队列的空间有限&#xff0c;队列是在头和尾进行操作的&#xff0c;当元素个数已经达到最大个数时&#xff0c;队尾已经在空间的最后面了&#xff0c;但是对头…

北邮22级信通院数电:Verilog-FPGA(0)怎么使用modelsim进行仿真?modelsim仿真教程一份请签收~

北邮22信通一枚~ 跟随课程进度更新北邮信通院数字系统设计的笔记、代码和文章 持续关注作者 迎接数电实验学习~ 获取更多文章&#xff0c;请访问专栏&#xff1a; 北邮22级信通院数电实验_青山如墨雨如画的博客-CSDN博客 最近很多uu问我怎么用quartus连接的modelsim软件进…

C#使用MaxMind.GeoIP2数据库查询当前ip地址

GeoLite2-City.mmdb下载 因为比较简单&#xff0c;直接上代码&#xff0c;代码展示获取ip地址的国家和城市信息 using MaxMind.GeoIP2; using MaxMind.GeoIP2.Model; using System; using System.Collections; using System.Collections.Generic; using System.Linq; using Sy…

事关Django的静态资源目录设置与静态资源文件引用(Django的setting.py中的三句静态资源(static)目录设置语句分别是什么作用?)

在Django的setting.py中常见的三句静态资源(static)目录设置语句如下&#xff1a; STATICFILES_DIRS [os.path.join(BASE_DIR, static_list)] # 注意这是一个列表,即可以有多个目录的路径 STATIC_ROOT os.path.join(BASE_DIR, static_root) STATIC_URL /static-url/本文介…