利用Pytorch预训练模型进行图像分类

Use Pre-trained models for Image Classification.

# This post is rectified on the base of https://learnopencv.com/pytorch-for-beginners-image-classification-using-pre-trained-models/

# And we have re-orginaized the code script.

预训练模型(Pre-trained models)是在ImageNet等大型基准数据集上训练的神经网络模型。深度学习社区从这些开源模型中受益匪浅。此外,预训练模型也是计算机视觉研究取得快速进展的一个重要因素。其他研究人员和从业人员可以使用这些最先进的模型,而不是从头开始重新训练。

# Here are some examples of classic pre-trained models.

在这里插入图片描述

在详细介绍如何使用预训练模型进行图像分类之前,我们先来看看有哪些预训练模型。我们将以 AlexNet 和 ResNet101 为例进行讨论。这两个网络都在 ImageNet 数据集上训练过。

ImageNet 数据集拥有超过 1400 万张由斯坦福大学维护的图像。它被广泛用于各种与图像相关的深度学习项目。这些图像属于不同的类别或标签。预训练模型(如 AlexNet 和 ResNet101)的目的是将图像作为输入并预测其类别。

这里的 "预训练 "是指,深度学习架构 AlexNet 和 ResNet101 已经在某个(庞大的)数据集上进行过训练,因此带有由此产生的权重和偏差。架构与权重和偏置之间的区别应该非常明显,因为我们将在下一节看到,TorchVision 同时拥有架构和预训练模型。

1.1 Model Inference Process

由于我们将重点讨论如何使用预先训练好的模型来预测输入的类别(标签),因此我们也来讨论一下其中涉及的过程。这个过程被称为模型推理。整个过程包括以下主要步骤:

(1) 读取输入图像;
(2) 对图像进行转换;例如resize、center crop、normalization等;
(3) 前向传递:使用预训练的模型权重来获得输出向量,而输出向量中的每个元素都描述了模型对于输入图像属于特定类别的置信度预测结果;
(4) 预测结果:基于获得的置信度分数,显示预测结果。

1.2 Loading Pre-Trained Network using TorchVision

# [Optinal Step]
# %pip install torchvision
# Load necessary packages.
from PIL import Image
import torch
import torchvision
from torchvision import models
from torchvision import transforms

print(torch.__version__)
print(torchvision.__version__)
2.0.0
0.15.0
# Check the different models and architectures available to us.
dir(models)
['AlexNet',
 'AlexNet_Weights',
 'ConvNeXt',
 'ConvNeXt_Base_Weights',
 'ConvNeXt_Large_Weights',
 'ConvNeXt_Small_Weights',
 'ConvNeXt_Tiny_Weights',
 'DenseNet',
 'DenseNet121_Weights',
 'DenseNet161_Weights',
 'DenseNet169_Weights',
 'DenseNet201_Weights',
 'EfficientNet',
 'EfficientNet_B0_Weights',
 'EfficientNet_B1_Weights',
 'EfficientNet_B2_Weights',
 'EfficientNet_B3_Weights',
 'EfficientNet_B4_Weights',
 'EfficientNet_B5_Weights',
 'EfficientNet_B6_Weights',
 'EfficientNet_B7_Weights',
 'EfficientNet_V2_L_Weights',
 'EfficientNet_V2_M_Weights',
 'EfficientNet_V2_S_Weights',
 'GoogLeNet',
 'GoogLeNetOutputs',
 'GoogLeNet_Weights',
 'Inception3',
 'InceptionOutputs',
 'Inception_V3_Weights',
 'MNASNet',
 'MNASNet0_5_Weights',
 'MNASNet0_75_Weights',
 'MNASNet1_0_Weights',
 'MNASNet1_3_Weights',
 'MaxVit',
 'MaxVit_T_Weights',
 'MobileNetV2',
 'MobileNetV3',
 'MobileNet_V2_Weights',
 'MobileNet_V3_Large_Weights',
 'MobileNet_V3_Small_Weights',
 'RegNet',
 'RegNet_X_16GF_Weights',
 'RegNet_X_1_6GF_Weights',
 'RegNet_X_32GF_Weights',
 'RegNet_X_3_2GF_Weights',
 'RegNet_X_400MF_Weights',
 'RegNet_X_800MF_Weights',
 'RegNet_X_8GF_Weights',
 'RegNet_Y_128GF_Weights',
 'RegNet_Y_16GF_Weights',
 'RegNet_Y_1_6GF_Weights',
 'RegNet_Y_32GF_Weights',
 'RegNet_Y_3_2GF_Weights',
 'RegNet_Y_400MF_Weights',
 'RegNet_Y_800MF_Weights',
 'RegNet_Y_8GF_Weights',
 'ResNeXt101_32X8D_Weights',
 'ResNeXt101_64X4D_Weights',
 'ResNeXt50_32X4D_Weights',
 'ResNet',
 'ResNet101_Weights',
 'ResNet152_Weights',
 'ResNet18_Weights',
 'ResNet34_Weights',
 'ResNet50_Weights',
 'ShuffleNetV2',
 'ShuffleNet_V2_X0_5_Weights',
 'ShuffleNet_V2_X1_0_Weights',
 'ShuffleNet_V2_X1_5_Weights',
 'ShuffleNet_V2_X2_0_Weights',
 'SqueezeNet',
 'SqueezeNet1_0_Weights',
 'SqueezeNet1_1_Weights',
 'SwinTransformer',
 'Swin_B_Weights',
 'Swin_S_Weights',
 'Swin_T_Weights',
 'Swin_V2_B_Weights',
 'Swin_V2_S_Weights',
 'Swin_V2_T_Weights',
 'VGG',
 'VGG11_BN_Weights',
 'VGG11_Weights',
 'VGG13_BN_Weights',
 'VGG13_Weights',
 'VGG16_BN_Weights',
 'VGG16_Weights',
 'VGG19_BN_Weights',
 'VGG19_Weights',
 'ViT_B_16_Weights',
 'ViT_B_32_Weights',
 'ViT_H_14_Weights',
 'ViT_L_16_Weights',
 'ViT_L_32_Weights',
 'VisionTransformer',
 'Weights',
 'WeightsEnum',
 'Wide_ResNet101_2_Weights',
 'Wide_ResNet50_2_Weights',
 '_GoogLeNetOutputs',
 '_InceptionOutputs',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__path__',
 '__spec__',
 '_api',
 '_meta',
 '_utils',
 'alexnet',
 'convnext',
 'convnext_base',
 'convnext_large',
 'convnext_small',
 'convnext_tiny',
 'densenet',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'detection',
 'efficientnet',
 'efficientnet_b0',
 'efficientnet_b1',
 'efficientnet_b2',
 'efficientnet_b3',
 'efficientnet_b4',
 'efficientnet_b5',
 'efficientnet_b6',
 'efficientnet_b7',
 'efficientnet_v2_l',
 'efficientnet_v2_m',
 'efficientnet_v2_s',
 'get_model',
 'get_model_builder',
 'get_model_weights',
 'get_weight',
 'googlenet',
 'inception',
 'inception_v3',
 'list_models',
 'maxvit',
 'maxvit_t',
 'mnasnet',
 'mnasnet0_5',
 'mnasnet0_75',
 'mnasnet1_0',
 'mnasnet1_3',
 'mobilenet',
 'mobilenet_v2',
 'mobilenet_v3_large',
 'mobilenet_v3_small',
 'mobilenetv2',
 'mobilenetv3',
 'optical_flow',
 'quantization',
 'regnet',
 'regnet_x_16gf',
 'regnet_x_1_6gf',
 'regnet_x_32gf',
 'regnet_x_3_2gf',
 'regnet_x_400mf',
 'regnet_x_800mf',
 'regnet_x_8gf',
 'regnet_y_128gf',
 'regnet_y_16gf',
 'regnet_y_1_6gf',
 'regnet_y_32gf',
 'regnet_y_3_2gf',
 'regnet_y_400mf',
 'regnet_y_800mf',
 'regnet_y_8gf',
 'resnet',
 'resnet101',
 'resnet152',
 'resnet18',
 'resnet34',
 'resnet50',
 'resnext101_32x8d',
 'resnext101_64x4d',
 'resnext50_32x4d',
 'segmentation',
 'shufflenet_v2_x0_5',
 'shufflenet_v2_x1_0',
 'shufflenet_v2_x1_5',
 'shufflenet_v2_x2_0',
 'shufflenetv2',
 'squeezenet',
 'squeezenet1_0',
 'squeezenet1_1',
 'swin_b',
 'swin_s',
 'swin_t',
 'swin_transformer',
 'swin_v2_b',
 'swin_v2_s',
 'swin_v2_t',
 'vgg',
 'vgg11',
 'vgg11_bn',
 'vgg13',
 'vgg13_bn',
 'vgg16',
 'vgg16_bn',
 'vgg19',
 'vgg19_bn',
 'video',
 'vision_transformer',
 'vit_b_16',
 'vit_b_32',
 'vit_h_14',
 'vit_l_16',
 'vit_l_32',
 'wide_resnet101_2',
 'wide_resnet50_2']

以AlexNet为例,我们可以看到还有一个名称为alexnet的条目。其中,大写的名称是Python类(AlexNet),而alexnet是一个便于操作的函数(convenience function),用于从AlexNet类返回实例化的模型。

这些方便函数也可以有不同的参数集,例如:densenet121、densenet161、densenet169以及densenet201,都是DenseNet的实例,但层数分别为121,161,169和201.

1.3. Using AlexNet for Image Classification

AlexnetNet是图像识别领域早期的一个突破性网络结构,相关文章可以参考Understanding Alexnet。该网络架构如下:

在这里插入图片描述

Step 1: Load the pre-trained model

# Create an instance of the network.
alexnet = models.alexnet(pretrained=True)
/home/wsl_ubuntu/anaconda3/envs/xy_trans/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/home/wsl_ubuntu/anaconda3/envs/xy_trans/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=AlexNet_Weights.IMAGENET1K_V1`. You can also use `weights=AlexNet_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
# Note: Pytorch模型的扩展名通常为.pt或.pth
# Check the model details.
print(alexnet)
AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

Step 2: Specify image transformations

# Use transforms to compose all the data transformations.
transform = transforms.Compose([
    transforms.Resize(256), 
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])])     # Three numbers for RGB Channels.
# transforms.Resize: Resize the input images to 256x256 pixels.
# transforms.CenterCrop: Crop the image to 224×224 pixels about the center.
# transforms.Normalize: Normalize the image by setting its mean and standard deviation to the specified values.
# transforms.ToTensor: Convert the image to Pytorch tensor datatype.

Step 3: Load the input image and pre-process it.

# Download image
# !wget https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg -O dog.jpg
img = Image.open("./dog.jpg")
img

在这里插入图片描述

# Pre-process the image.
trans_img = transform(img)

img_batch = torch.unsqueeze(trans_img, 0)

Step 4: Model Inference

# Set the model to eval model.
alexnet.eval()

out = alexnet(img_batch)
print(out.shape)
torch.Size([1, 1000])
# Download classes text file
!wget https://raw.githubusercontent.com/Lasagne/Recipes/master/examples/resnet50/imagenet_classes.txt
--2023-12-14 21:30:09--  https://raw.githubusercontent.com/Lasagne/Recipes/master/examples/resnet50/imagenet_classes.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 0.0.0.0, ::
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|0.0.0.0|:443... failed: Connection refused.
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|::|:443... failed: Connection refused.
# Load labels.
with open('imagenet_classes.txt') as f:
    classes = [line.strip() for line in f.readlines()]
# Find out the maximum score.
_, index = torch.max(out, 1)
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
print(classes[index[0]], percentage[index[0]].item())
Labrador retriever 41.58513259887695
# The model predicts the image to be of a Labrador Retriever with a 41.58% confidence.
_, indices = torch.sort(out, descending=True)
[(classes[idx], percentage[idx].item()) for idx in indices[0][:5]]
[('Labrador retriever', 41.58513259887695),
 ('golden retriever', 16.59164810180664),
 ('Saluki, gazelle hound', 16.286897659301758),
 ('whippet', 2.8539111614227295),
 ('Ibizan hound, Ibizan Podenco', 2.39247727394104)]

1.4. Using ResNet for Image Classification

# Load the resnet101 model.
resnet = models.resnet101(pretrained=True)

# Set the model to eval mode.
resnet.eval()

# carry out model inference.
out = resnet(img_batch)

# Print the top 5 classes predicted by the model.
_, indices = torch.sort(out, descending=True)
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
[(classes[idx], percentage[idx].item()) for idx in indices[0][:5]]
Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /home/wsl_ubuntu/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth
100%|██████████| 171M/171M [00:51<00:00, 3.47MB/s] 





[('Labrador retriever', 48.255577087402344),
 ('dingo, warrigal, warragal, Canis dingo', 7.900773048400879),
 ('golden retriever', 6.91691780090332),
 ('Eskimo dog, husky', 3.6434390544891357),
 ('bull mastiff', 3.046128273010254)]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/245048.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

回溯热门问题

关卡名 回溯热门问题 我会了✔️ 内容 1.组合总和问题 ✔️ 2.分割回文串问题 ✔️ 3.子集问题 ✔️ 4.排列问题 ✔️ 5.字母全排列问题 ✔️ 6.单词搜索 ✔️ 1. 组合总和问题 LeetCode39题目要求&#xff1a;给你一个无重复元素的整数数组candidates和一个目标整数 ta…

数据结构学习 12字母迷宫

dfs 回溯 剪枝 这个题和dfs有关&#xff0c;但是我之前没有接触过&#xff0c;我看了这一篇很好的文章&#xff0c;看完之后写的答案。 我觉得很好的总结&#xff1a; dfs模板 int check(参数) {if(满足条件)return 1;return 0; }void dfs(int step) {判断边界{相应操作}尝试…

自考 00023高等数学考点整理

空间直角坐标系 右手法则 向量 点到点的距离 点到直线的距离点到平面的距离向量平行向量垂直向量投影向量数乘 a*b axb(行列式计算)直线夹角、直线与平面夹角平面点法式方程空间直角坐标系 右手法则向量数量积、向量积 平行四边形法则、三角形法则 第二章 多元函数 微分学…

VS2022 将项目打包,导出为exe运行

我有一个在 VS2022 上开发的程序&#xff0c;基于.net 6框架, 想打包成 .exe程序&#xff0c;以在另一个没有安装VS的机器上运行&#xff0c;另一个机器是Win7系统&#xff0c;上面安装了.net 6框架。 虽然网上很多教程&#xff0c;需要安装Project Installer&#xff0c;配置A…

从零开始创建一个项目,springBoot+mybatisPlus+mysql+swagger+maven

一&#xff0c;前提 从零开始创建一个项目&#xff0c;绑定了数据库 用到的技术栈&#xff1a;springBootmybatisPlusmysqlswaggermaven 二&#xff0c;创建项目步骤 1&#xff0c;创建项目 创建出来的项目结构如图所示 2&#xff0c;修改配置文件 因为我比较习惯yml语言&…

算法:最小生成树

文章目录 生成树Kruskal算法Prim算法 本篇总结的是最小生成树算法 生成树 连通图中的每一棵生成树&#xff0c;都是原图的一个极大无环子图&#xff0c;即&#xff1a;从其中删去任何一条边&#xff0c;生成树就不在连通&#xff1b;反之&#xff0c;在其中引入任何一条新边&…

路由器交换机配置备份工具

本文主要介绍fast-backup 2.0软件的使用&#xff0c;fast-backup 2.0是可以在任何Windows系统上运行的网络运维软件&#xff0c;帮助运维人员减少大量重复的交换机等设备的配置下载工作&#xff0c;支持的厂商有华为和华三的网络设备和安全设备。 功能特性&#xff1a; 支持S…

提升数据分析效率:Amazon S3 Express One Zone数据湖实战教程

前言 什么是 Amazon S3&#xff1f;什么是 S3 Express One Zone&#xff1f;实现概述 技术架构组件实现步骤概览 第一步&#xff1a;构建数据湖的基础第二步&#xff1a;选择并查看数据集第三步&#xff1a;在 Athena 中搭建架构第四步&#xff1a;数据转换与优化第五步&#x…

数组笔试题解析(下)

数组面试题解析 字符数组 &#xff08;一&#xff09; 我们上一篇文章学习了一维数组的面试题解析内容和字符数组的部分内容&#xff0c;我们这篇文章讲解一下字符数组和指针剩余面试题的解析内容&#xff0c;那现在&#xff0c;我们开始吧。 我们继续看一组字符数组的面试…

binkw32.dll丢失怎么办?这5个方法都可以解决binkw32.dll丢失问题

binkw32.dll文件是什么&#xff1f; binkw32.dll是一个动态链接库文件&#xff0c;它是Windows操作系统中的一个重要组件。它包含了许多用于处理多媒体文件的函数和资源&#xff0c;如视频、音频等。当我们在电脑上打开或播放某些多媒体文件时&#xff0c;系统会调用binkw32.d…

【算法】滑动窗口

目录 基本思想 应用场景 应用实例 总结 基本思想 滑动窗口&#xff0c;也叫尺取法&#xff0c;就是不断的调节子序列的起始位置和终止位置&#xff0c;从而得出我们要想的结果&#xff0c;可以用来解决一些查找满足一定条件的连续区间的性质&#xff08;长度等&#xff09;…

【活动回顾】Databend 云数仓与 Databend Playground 扩展组件介绍

2023 年 12 月 7 日&#xff0c;作为 KubeSphere 的合作伙伴&#xff0c;Databend 荣幸地受邀参与了 KubeSphere 社区主办的云原生技术直播活动。本次活动的核心议题为「Databend 云数仓与 Databend Playground 扩展组件介绍」&#xff0c;此次分享由 Databend Labs 的研发工程…

Vue3-08-条件渲染-v-if 的基本使用

v-if 是什么 v-if 一个指令&#xff0c; 它是用来根据条件表达式&#xff0c;进行选择性地【展示】/【不展示】html元素的。比如 &#xff1a; 有一个按钮A&#xff0c;当条件为真时&#xff0c;展示该按钮&#xff1b;条件为假时&#xff0c;不展示该按钮。与 js 中的 条件判…

如何部署Portainer容器管理工具+cpolar内网穿透实现公网访问管理界面

文章目录 前言1. 部署Portainer2. 本地访问Portainer3. Linux 安装cpolar4. 配置Portainer 公网访问地址5. 公网远程访问Portainer6. 固定Portainer公网地址 前言 本文主要介绍如何本地安装Portainer并结合内网穿透工具实现任意浏览器远程访问管理界面。Portainer 是一个轻量级…

一文5000字从0到1构建高效的接口自动化测试框架思路

在选择接口测试自动化框架时&#xff0c;需要根据团队的技术栈和项目需求来综合考虑。对于测试团队来说&#xff0c;使用Python相关的测试框架更为便捷。无论选择哪种框架&#xff0c;重要的是确保 框架功能完备&#xff0c;易于维护和扩展&#xff0c;提高测试效率和准确性。…

挺进云存储,天翼云全新一代XSSD勇立潮头

引言&#xff1a;自研高性能分布式存储引擎LAVA&#xff0c;实现云硬盘持续创新获得新突。 【全球云观察 &#xff5c; 科技热点关注】 作为算力基础设施的基石&#xff0c;云存储的发展一直备受公有云厂商所重视&#xff0c;对拉动云厂商营收规模带来重要价值&#xff0c;就…

山海鲸开发者:展现数据可视化在各领域的无限可能

作为一名山海鲸可视化软件的内部开发者&#xff0c;我对这款软件投入了大量的经历以及含有深深的情感。下面&#xff0c;我从这款软件应用场景下手&#xff0c;带大家探秘这款软件的多种可能性以及我们的用心。 首先&#xff0c;从行业角度来看&#xff0c;山海鲸可视化软件可以…

06.迪米特法则(Demeter Principle)

明 嘉靖四十年 江南织造总局 小黄门唯唯诺诺的听完了镇守太监杨金水的训斥&#xff0c;赶忙回答&#xff1a;“知道了&#xff0c;干爹&#xff01;” “知道什么&#xff1f;&#xff01;&#xff01;” 杨金水打断了他的话&#xff0c;眼神突然变得凌厉起来&#xff1a; “有…

椋鸟C语言笔记#26:数据在内存中的存储(大小端字节序)、浮点数的存储(IEEE754)

萌新的学习笔记&#xff0c;写错了恳请斧正。 目录 大小端字节序 什么是大小端 写一个判断大小端的程序 浮点数在内存中的存储&#xff08;IEEE 754规则&#xff09; 引入 存储规则解释 读取规则解释 1.阶码不全为0或全为1&#xff08;规格化数&#xff09; 2.阶码全为…

鸿蒙系统走向独立,高校设立“鸿蒙班”,鸿蒙人才紧缺!

近日&#xff0c;华为以及鸿蒙系软件厂商都在积极培养鸿蒙开发人才&#xff0c;产学联动、产教融合是重要的一条路径。目前已有23家985高校、46家211高校已开设或即将开设HarmonyOS相关课程。 一位鸿蒙生态内部人士表示&#xff0c;目前鸿蒙开发人才比较紧缺&#xff0c;而安卓…