「MobileNet V3」70 个犬种的图片分类

✅作者简介:人工智能专业本科在读,喜欢计算机与编程,写博客记录自己的学习历程。
🍎个人主页:小嗷犬的个人主页
🍊个人网站:小嗷犬的技术小站
🥭个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。


本文目录

    • 数据集与 Notebook
    • 环境准备
    • 数据集
    • 可视化
    • 模型
    • 预测
    • Loss 与评价指标


数据集与 Notebook

数据集:70 Dog Breeds-Image Data Set
Notebook:「MobileNet V3」70 Dog Breeds-Image Classification


环境准备

import warnings
warnings.filterwarnings('ignore')

禁用警告,防止干扰。

!pip install lightning --quiet

安装 PyTorch Lightning。

import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_theme(style="darkgrid", font_scale=1.5, font="SimHei", rc={"axes.unicode_minus":False})

导入常用的库,设置绘图风格。

import torch
import torchmetrics
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models

导入 PyTorch 相关的库。

import lightning.pytorch as pl
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

导入 PyTorch Lightning 相关的库。

seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
pl.seed_everything(seed, workers=True)

设置随机种子。


数据集

batch_size = 64

设置批次大小。

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

设置数据集的预处理。

train_dataset = datasets.ImageFolder(root="/kaggle/input/70-dog-breedsimage-data-set/train", transform=train_transform)
val_dataset = datasets.ImageFolder(root="/kaggle/input/70-dog-breedsimage-data-set/valid", transform=test_transform)
test_dataset = datasets.ImageFolder(root="/kaggle/input/70-dog-breedsimage-data-set/test", transform=test_transform)

读取数据集。

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

加载数据集。


可视化

class_names = train_dataset.classes
class_count = [train_dataset.targets.count(i) for i in range(len(class_names))]
df = pd.DataFrame({"Class": class_names, "Count": class_count})

plt.figure(figsize=(12, 20), dpi=100)
sns.barplot(x="Count", y="Class", data=df)
plt.tight_layout()
plt.show()

绘制训练集的类别分布。

训练集的类别分布

plt.figure(figsize=(12, 20), dpi=100)
images, labels = next(iter(val_loader))
for i in range(8):
    ax = plt.subplot(8, 4, i + 1)
    plt.imshow(images[i].permute(1, 2, 0).numpy())
    plt.title(class_names[labels[i]])
    plt.axis("off")
plt.tight_layout()
plt.show()

绘制训练集的样本。

训练集的样本


模型

class LitModel(pl.LightningModule):
    def __init__(self, num_classes=1000):
        super().__init__()
        self.model = models.mobilenet_v3_large(weights="IMAGENET1K_V2")
        # for param in self.model.parameters():
        #     param.requires_grad = False
        self.model.classifier[3] = nn.Linear(self.model.classifier[3].in_features, num_classes)
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.precision = torchmetrics.Precision(task="multiclass", average="macro", num_classes=num_classes)
        self.recall = torchmetrics.Recall(task="multiclass", average="macro", num_classes=num_classes)
        self.f1score = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        x = self.model(x)
        return x

    def configure_optimizers(self):
        optimizer = optim.Adam(
            self.parameters(), lr=0.001, betas=(0.9, 0.99), eps=1e-08, weight_decay=1e-5
        )
        return optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss, on_step=True, on_epoch=False, prog_bar=True, logger=True)
        self.log_dict(
            {
                "train_acc": self.accuracy(y_hat, y),
                "train_prec": self.precision(y_hat, y),
                "train_recall": self.recall(y_hat, y),
                "train_f1score": self.f1score(y_hat, y),
            },
            on_step=True,
            on_epoch=False,
            logger=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("val_loss", loss, on_step=False, on_epoch=True, logger=True)
        self.log_dict(
            {
                "val_acc": self.accuracy(y_hat, y),
                "val_prec": self.precision(y_hat, y),
                "val_recall": self.recall(y_hat, y),
                "val_f1score": self.f1score(y_hat, y),
            },
            on_step=False,
            on_epoch=True,
            logger=True,
        )

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        self.log_dict(
            {
                "test_acc": self.accuracy(y_hat, y),
                "test_prec": self.precision(y_hat, y),
                "test_recall": self.recall(y_hat, y),
                "test_f1score": self.f1score(y_hat, y),
            }
        )

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        x, y = batch
        y_hat = self(x)
        preds = torch.argmax(y_hat, dim=1)
        return preds

定义模型。

num_classes = len(class_names)
model = LitModel(num_classes=num_classes)
logger = CSVLogger("./")
early_stop_callback = EarlyStopping(
    monitor="val_loss", min_delta=0.00, patience=5, verbose=False, mode="min"
)
trainer = pl.Trainer(
    max_epochs=20,
    enable_progress_bar=True,
    logger=logger,
    callbacks=[early_stop_callback],
    deterministic=True,
)
trainer.fit(model, train_loader, val_loader)

训练模型。

trainer.test(model, val_loader)

测试模型。


预测

pred = trainer.predict(model, test_loader)
pred = torch.cat(pred, dim=0)
pred = pd.DataFrame(pred.numpy(), columns=["Class"])
pred["Class"] = pred["Class"].apply(lambda x: class_names[x])

plt.figure(figsize=(12, 20), dpi=100)
sns.countplot(y="Class", data=pred)
plt.tight_layout()
plt.show()

绘制预测结果的类别分布。

预测结果的类别分布


Loss 与评价指标

log_path = logger.log_dir + "/metrics.csv"
metrics = pd.read_csv(log_path)
x_name = "epoch"

plt.figure(figsize=(8, 6), dpi=100)
sns.lineplot(x=x_name, y="train_loss", data=metrics, label="Train Loss", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_loss", data=metrics, label="Valid Loss", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.tight_layout()
plt.show()


plt.figure(figsize=(14, 12), dpi=100)

plt.subplot(2,2,1)
sns.lineplot(x=x_name, y="train_acc", data=metrics, label="Train Accuracy", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_acc", data=metrics, label="Valid Accuracy", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("Accuracy")

plt.subplot(2,2,2)
sns.lineplot(x=x_name, y="train_prec", data=metrics, label="Train Precision", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_prec", data=metrics, label="Valid Precision", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("Precision")

plt.subplot(2,2,3)
sns.lineplot(x=x_name, y="train_recall", data=metrics, label="Train Recall", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_recall", data=metrics, label="Valid Recall", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("Recall")

plt.subplot(2,2,4)
sns.lineplot(x=x_name, y="train_f1score", data=metrics, label="Train F1-Score", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="val_f1score", data=metrics, label="Valid F1-Score", linewidth=2, marker="X", markersize=12)
plt.xlabel("Epoch")
plt.ylabel("F1-Score")

plt.tight_layout()
plt.show()

绘制 Loss 与评价指标的变化。

Loss

评价指标

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/173775.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C 语言获取文件绝对路径

示例代码 1&#xff0c;不包含根目录绝对路径&#xff1a; #include <stdlib.h> #include <stdio.h>int main(void) {char *fileName "/Dev/test.txt";char *abs_path _fullpath(NULL, fileName, 0);printf("The absolute path is: %s\n", a…

【JavaEE初阶】 JavaScript基础语法——贰

文章目录 &#x1f332;条件语句&#x1f6a9;if 语句&#x1f6a9;三元表达式&#x1f6a9;switch&#x1f6a9;循环语句&#x1f388;while 循环&#x1f388;continue&#x1f388;break&#x1f388;for 循环 &#x1f340;数组&#x1f6a9;创建数组&#x1f6a9;获取数组…

2014年2月24日 Go生态洞察:FOSDEM 2014上的Go演讲精选

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

【教程】cpp转python Nanobind 实践 加速轻量版 pythonbind11

主要是尝试一下把c这边的函数封装打包给python用&#xff0c;选择nanobind的原因是&#xff1a;1. 优化速度快&#xff0c;2. 生成二进制包小&#xff0c;不过pythonbind11是更为广泛知道的&#xff0c;nanobind也是pythonbind11作者后续做的&#xff0c;可以查看作者写的 why …

多项式求和

题目描述 给定程序中 fun 函数的功能是&#xff1a;求出以下分数序列的前 n 项之和&#xff0c;并通过函数值返回 main 函数。 输入格式 输入参数。 输出格式 计算公式返回的结果。 输入输出样例 输入1 5 输出1 8.391667 python解&#xff1a; def fun(n):a1b2s0for…

【数据结构】HashMap 和 HashSet

目录 1.哈希表概念 2冲突 2.1概念 2.2 冲突-避免 2.3冲突-避免-哈希函数设计 2.4 冲突-避免-负载因子调节 ​编辑 2.5 冲突-解决-开散列/哈希桶 2.5冲突严重时的解决办法 3.实现 4.性能分析 5.与Java集合类的关系 1.哈希表概念 在顺序结构中&#xff0c;元素关键码和存…

风电场数字孪生-升压站BIM三维模型-obj格式

简介&#xff1a; 风电场中的升压站三维模型&#xff0c;obj格式&#xff0c;采用BIM技术建模&#xff0c;可应用于风电场三维数字孪生领域&#xff0c;用于对升压站进行漫游浏览&#xff1b;三维可视化场景应用&#xff1b;风电场三维设计模型。 下载地址 风电场数字孪生-升…

2019ICPC南京站

A A Hard Problem 题意&#xff1a;给定一个正整数 n &#xff0c;你需要找出最小整数 k&#xff0c;满足&#xff1a;从{1,2,⋯,n}中任意选择长度为k的子集&#xff0c;存在两个不同的整数 u,v∈T, 且 u 是 v 的因数。 思路&#xff1a;打表找规律 #include <bits/std…

微服务实战系列之加密RSA

前言 在这个时代&#xff0c;我们选择的人生目标已丰富多彩&#xff0c;秉持的人生态度也千差万别&#xff1a; 除了吃喝玩乐&#xff0c;还有科技探索&#xff1b; 除了CityWalk&#xff0c;还有“BookWalk”&#xff1b; 除了走遍中国&#xff0c;还有走遍世界&#xff1b; …

源码安装Apache

一、下载Apache,源码安装Apache #下载 [rootlocalhost opt]# wget -c https://mirrors.aliyun.com/apache/httpd/httpd-2.4.58.tar.gz [rootlocalhost opt]# ls httpd-2.4.58.tar.gz [rootlocalhost opt]# tar -xf httpd-2.4.58.tar.gz [rootlocalhost opt]# ls httpd-2.4.58…

怎么让NetCore接口支持Json参数

项目&#xff1a;NetCore Web API 接口支持Json参数需要安装Newtonsoft.Json.Linq和Microsoft.AspNetCore.Mvc.NewtonsoftJson Program代码 //支持json需要安装Microsoft.AspNetCore.Mvc.NewtonsoftJson using Newtonsoft.Json.Serialization;var builder WebApplication.Cr…

机器学习8:在病马数据集上进行算法比较(ROC曲线与AUC)

ROC曲线与AUC。使用不同的迭代次数&#xff08;基模型数量&#xff09;进行 Adaboost 模型训练&#xff0c;并记录每个模型的真阳性率和假阳性率&#xff0c;并绘制每个模型对应的 ROC 曲线&#xff0c;比较模型性能&#xff0c;输出 AUC 值最高的模型的迭代次数和 ROC 曲线。 …

力扣1038. 从二叉搜索树到更大和树(java,树的中序遍历解法)

Problem: 1038. 从二叉搜索树到更大和树 文章目录 题目描述思路解题方法复杂度Code 题目描述 给定一个二叉搜索树 root (BST)&#xff0c;请将它的每个节点的值替换成树中大于或者等于该节点值的所有节点值之和。 提醒一下&#xff0c; 二叉搜索树 满足下列约束条件&#xff…

ssh远程连接不了虚拟机ubuntu

直奔主题 1. 确保linux安装了ssh2.查看网络适配器是否启用3.连接成功 1. 确保linux安装了ssh sudo apt-get install openssh-server2.查看网络适配器是否启用 3.连接成功

高德地图点击搜索触发输入提示

减少调用次数&#xff0c;不用每输入一次调用一次&#xff0c;输入完后再触发搜索 效果图&#xff1a; ![Alt](https://img-home.csdnimg.cn/images/20220524100510.png dom结构 <div class"seach"><van-searchshow-actionv-model"addressVal"…

CentOS用nginx搭建文件下载服务器

Nginx 是开源、高性能、高可靠的 Web 和反向代理服务器&#xff0c;而且支持热部署&#xff0c;几乎可以做到 7 * 24 小时不间断运行&#xff0c;即使运行几个月也不需要重新启动。在工作中&#xff0c;我们经常会用到需要搭建文件服务器的情况&#xff0c;这里就以在linux下搭…

debian 12 配置

1. 修改apt源 修改apt源为http版本 # 默认注释了源码镜像以提高 apt update 速度&#xff0c;如有需要可自行取消注释 deb http://mirrors.tuna.tsinghua.edu.cn/debian/ bookworm main contrib non-free non-free-firmware # deb-src http://mirrors.tuna.tsinghua.edu.cn/d…

【SA8295P 源码分析】132 - GMSL2 协议分析 之 GPIO/SPI/I2C/UART 等通迅控制协议带宽消耗计算

【SA8295P 源码分析】132 - GMSL2 协议分析 之 GPIO/SPI/I2C/UART 等通迅控制协议带宽消耗计算 一、GPIO 透传带宽消耗计算二、SPI 通迅带宽消耗计算三、I2C 通迅带宽消耗计算四、UART 通迅带宽消耗计算系列文章汇总见:《【SA8295P 源码分析】00 - 系列文章链接汇总》 本文链接…

【经验之谈·高频PCB电路设计常见的66个问题】

文章目录 1、如何选择PCB 板材&#xff1f;2、如何避免高频干扰&#xff1f;3、在高速设计中&#xff0c;如何解决信号的完整性问题&#xff1f;4、差分布线方式是如何实现的&#xff1f;5、对于只有一个输出端的时钟信号线&#xff0c;如何实现差分布线&#xff1f;6、接收端差…

安装向量数据库milvus及其Attu

前置条件安装docker compose 在宿主机上创建文件目录 mkdir -p /home/sunyuhua/milvus/db mkdir -p /home/sunyuhua/milvus/conf mkdir -p /home/sunyuhua/milvus/etcd下载docker-compose.yml wget https://github.com/milvus-io/milvus/releases/download/v2.2.11/milvus-s…