MindSpore Learning Notes 6-06 Computer Vision: Vision Transformer Image Classification

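Abstract:

        Notes on using the MindSpore AI framework to train, validate, and run inference with a ViT model for image classification on ImageNet data. Covers environment setup, dataset download, dataset loading, model analysis and construction, and model training and inference.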

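I. Introduction

1. The ViT model

Vision Transformer

A model based on the self-attention structure

Self-Attention

        Transformer models

                can be scaled to models with more than 100B parameters

Application domains

        Natural language processing

        Computer vision

Does not rely on convolution operations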

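2. Model structure

Main structure of the ViT model

From bottom to top:

At the bottom, the input data

        The original image is split into multiple patches (image blocks)

                each 2D patch (ignoring the channel dimension) is flattened into a 1D vector

In the middle, a backbone based on the Encoder part of the Transformer

        Multi-head Attention structure

        the order of some components is adjusted

                Normalization sits in a different position

At the top, the stacked Blocks are followed by a fully connected Head

An extra class vector (class token) is added to the input

The output is the classification result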

II. Environment setup

Make sure Python and MindSpore are installed.

```
%%capture captured_output
# The environment comes with mindspore==2.2.14 preinstalled; to use another version, change the version number below
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# Show the installed mindspore version
!pip show mindspore
```

Output:

Name: mindspore
Version: 2.2.14
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by: 

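III. Data preparation

1. Download and extract the dataset

Download source

http://image-net.org

ImageNet dataset

The dataset used in this example is a subset selected from ImageNet.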

```python
from download import download

# ImageNet subset prepared for this tutorial
dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"

# Download and extract into the current directory
path = download(dataset_url, path, kind="zip", replace=True)
```

Output:

Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip (489.1 MB)

file_sizes: 100%|█████████████████████████████| 513M/513M [00:02<00:00, 228MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./

2. Dataset directory structure

./dataset/
    ├── ILSVRC2012_devkit_t12.tar.gz
    ├── train/
    ├── infer/
    └── val/

3. Load the dataset

```python
import os

import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms


data_path = './dataset/'
# ImageNet channel statistics, scaled by 255 because decoded images are uint8 in [0, 255]
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)

trans_train = [
    # Decode, crop a random area/aspect ratio, and resize to 224 x 224 in one op
    transforms.RandomCropDecodeResize(size=224,
                                      scale=(0.08, 1.0),
                                      ratio=(0.75, 1.333)),
    transforms.RandomHorizontalFlip(prob=0.5),
    transforms.Normalize(mean=mean, std=std),
    # HWC -> CHW, the layout expected by nn.Conv2d
    transforms.HWC2CHW()
]

dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)
```
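The mean and std values above are the usual ImageNet channel statistics multiplied by 255, because Normalize runs on the decoded uint8 image (values 0 to 255) rather than on a [0, 1] tensor. A minimal pure-Python sketch of the per-channel arithmetic for a single pixel (an illustration only, not MindSpore's implementation; `normalize_pixel` is a hypothetical helper):

```python
# ImageNet statistics in [0, 1], rescaled to the uint8 range [0, 255]
MEAN = [0.485 * 255, 0.456 * 255, 0.406 * 255]
STD = [0.229 * 255, 0.224 * 255, 0.225 * 255]


def normalize_pixel(rgb):
    """Normalize one RGB pixel channel by channel: (value - mean) / std."""
    return [(v - m) / s for v, m, s in zip(rgb, MEAN, STD)]


# A pixel equal to the dataset mean maps to zero in every channel
print(normalize_pixel(MEAN))
# A pure white pixel ends up a little above 2 in every channel
print([round(x, 3) for x in normalize_pixel([255, 255, 255])])
```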

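IV. Model analysis

1. Basic principles of the Transformer

The Transformer model

An encoder-decoder architecture based on the attention mechanism

Model architecture diagram:

Composed of multiple Encoder and Decoder modules

Detailed Encoder and Decoder diagram:

The Encoder and Decoder consist of

Multi-Head Attention layers

    based on the Self-Attention mechanism

    made of several Self-Attention modules running in parallel

Feed Forward layers

Normalization layers

Residual Connections (the "Add" in the diagram)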

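2. The Attention module

The core idea of Self-Attention

Learn a weight for each word of the input sequence:

        given a query vector (Query),

        compute the similarity or relevance between the Query and each Key

                to obtain the attention distribution,

                i.e. the weight coefficient of the Value belonging to each Key;

        then take the weighted sum of the Values to obtain the final attention output.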
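The query/key/value steps above can be illustrated for a single query in plain Python. This is a toy sketch with made-up vectors; the 1/sqrt(d) scaling described below is left out so that the three steps (similarity, softmax, weighted sum) stay visible. `attend` is a hypothetical helper.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query, keys, values):
    """Score the query against every key, turn the scores into weights,
    and return the weights together with the weighted sum of the values."""
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]  # dot-product similarity
    weights = softmax(scores)                                          # attention distribution
    dim = len(values[0])
    output = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]
    return weights, output


keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0]]
weights, output = attend([2.0, 0.0], keys, values)
print(weights, output)  # the first key matches the query better, so its value dominates
```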

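Self-Attention mechanism:

(1) The initial input vectors

pass through an Embedding layer,

        are mapped to dim x 3,

        and are split into three vectors:

                Q (Query)

                K (Key)

                V (Value)

The input is a sequence of 1D vectors (x1, x2, x3)

Each 1D vector is mapped to Q, K and V by an Embedding layer

        the three use different embedding matrices

        whose parameters are learned

Relationships between vectors

can be computed from the Q, K and V matrices:

the dot product of two of them (Q and K) gives the weights,

and the third (V) carries the weighted sum.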

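(2) What makes self-attention "self"

Q, K and V all come from the input itself

The self-attention process

        extracts the relationships and features among vectors at different positions of the input,

        expressed as how closely the vectors at different positions are related,

                namely the result of applying Softmax to the product of Q and K

Computing the weights among the Q, K and V vectors:

        take the dot product of Q and K,

        divide by the square root of the dimension (the per-head dimension d_k),

        and apply Softmax over all the results.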
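Written as a formula, the steps above are the scaled dot-product attention of the original Transformer. Here d_k is the per-head dimension, which the Attention code below stores as `head_dim` and turns into `self.scale = head_dim ** -0.5`. For the ViT-B/16 defaults (dim 768, 12 heads) the arithmetic works out as:

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V,
\qquad d_k = \frac{\mathrm{dim}}{\mathrm{num\_heads}} = \frac{768}{12} = 64,
\qquad \frac{1}{\sqrt{d_k}} = \frac{1}{8} = 0.125
```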

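(3) Global self-attention

The V vectors are combined with the Softmax results of Q and K

        as a weighted sum

Each group of Q, K and V produces one output vector

The output combines the current vector with the other vectors according to their relevance weights

The complete Self-Attention process: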

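Multi-head attention

The vectors processed by Self-Attention are split into several Heads

        which are processed in parallel

        while keeping the total number of parameters unchanged

The same query, key and value are mapped into a high-dimensional space (Q, K, V),

        then into different subspaces (Q_0, K_0, V_0), ...

        self-attention is computed separately in each subspace,

        and finally the attention information from the different subspaces is merged.

For the same input vector,

multiple attention mechanisms can run in parallel,

making fuller use of the vector's features during processing.

In the figure below, a_i and a_j are obtained by splitting the same vector.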
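A rough pure-Python sketch of the head splitting: each dim-sized vector is cut into num_heads slices of head_dim, attention is computed within each slice, and the per-head results are written back side by side, which is the same as concatenating them. To stay short it assumes Q = K = V = the input, i.e. it omits the learned qkv and output projections that the real Attention class uses; `multi_head_self_attention` is a hypothetical helper, not a MindSpore API.

```python
import math


def softmax(xs):
    """Numerically stable softmax."""
    m = max(xs)
    e = [math.exp(x - m) for x in xs]
    s = sum(e)
    return [v / s for v in e]


def multi_head_self_attention(x, num_heads):
    """x: list of n token vectors of length dim. Uses Q = K = V = x (no projections)."""
    n, dim = len(x), len(x[0])
    assert dim % num_heads == 0, "dim must be divisible by num_heads"
    head_dim = dim // num_heads
    scale = 1.0 / math.sqrt(head_dim)
    out = [[0.0] * dim for _ in range(n)]
    for h in range(num_heads):
        lo, hi = h * head_dim, (h + 1) * head_dim
        sub = [row[lo:hi] for row in x]  # this head's subspace slice
        for i in range(n):
            scores = [scale * sum(a * b for a, b in zip(sub[i], sub[j])) for j in range(n)]
            w = softmax(scores)
            for d in range(head_dim):
                # Writing into this head's own slice == concatenating the heads
                out[i][lo + d] = sum(w[j] * sub[j][d] for j in range(n))
    return out


tokens = [[1.0, 0.0, 0.5, -0.5], [0.0, 1.0, -0.5, 0.5], [1.0, 1.0, 0.0, 0.0]]
result = multi_head_self_attention(tokens, num_heads=2)
print(len(result), len(result[0]))  # same number of tokens, same dim as the input
```

Because splitting only reinterprets the same dim-sized vectors, the projection weights (dim x 3·dim for qkv) do not depend on num_heads, which is why the parameter count stays unchanged.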

The Multi-Head Attention code is as follows:

```python
import mindspore as ms
from mindspore import nn, ops

class Attention(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0):
        super(Attention, self).__init__()

        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Scaling factor 1 / sqrt(d_k)
        self.scale = ms.Tensor(head_dim ** -0.5)

        # One Dense layer produces Q, K and V together (dim -> 3 * dim)
        self.qkv = nn.Dense(dim, dim * 3)
        self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)
        self.out = nn.Dense(dim, dim)
        self.out_drop = nn.Dropout(p=1.0-keep_prob)
        self.attn_matmul_v = ops.BatchMatMul()
        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)
        self.softmax = nn.Softmax(axis=-1)

    def construct(self, x):
        """Attention construct."""
        b, n, c = x.shape                           # (batch, tokens, dim)
        qkv = self.qkv(x)                           # (b, n, 3 * c)
        qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))
        qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))   # (3, b, heads, n, head_dim)
        q, k, v = ops.unstack(qkv, axis=0)          # each (b, heads, n, head_dim)
        attn = self.q_matmul_k(q, k)                # Q K^T: (b, heads, n, n)
        attn = ops.mul(attn, self.scale)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        out = self.attn_matmul_v(attn, v)           # (b, heads, n, head_dim)
        out = ops.transpose(out, (0, 2, 1, 3))      # (b, n, heads, head_dim)
        out = ops.reshape(out, (b, n, c))           # merge the heads back into dim
        out = self.out(out)
        out = self.out_drop(out)

        return out
```
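The reshape/transpose pair in `construct` is the least obvious part of the code above. This small hypothetical helper traces the shapes in plain Python, assuming a batch of 2 and the ViT-B/16 sizes (197 tokens, dim 768, 12 heads):

```python
def permute_shape(shape, perm):
    """Shape produced by a transpose with permutation `perm`:
    output axis i takes the size of input axis perm[i]."""
    return tuple(shape[p] for p in perm)


b, n, dim, heads = 2, 197, 768, 12
qkv_shape = (b, n, 3, heads, dim // heads)          # after ops.reshape
qkv_t = permute_shape(qkv_shape, (2, 0, 3, 1, 4))   # after ops.transpose: (3, b, heads, n, head_dim)
out_shape = (b, heads, n, dim // heads)             # attn @ v
out_t = permute_shape(out_shape, (0, 2, 1, 3))      # before merging heads: (b, n, heads, head_dim)
print(qkv_t, out_t)
```

Moving the "3" axis to the front is what lets `ops.unstack(qkv, axis=0)` hand back Q, K and V separately, and putting heads before tokens makes the batched matmul operate per head.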

Transformer Encoder

Several structures are combined to form the basic Transformer block:

Self-Attention

Feed Forward

Residual Connection

Code for the Feed Forward and Residual Connection structures:

```python
from typing import Optional, Dict

class FeedForward(nn.Cell):
    """Two-layer MLP: Dense -> activation -> Dropout -> Dense -> Dropout."""
    def __init__(self,
                 in_features: int,
                 hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None,
                 activation: nn.Cell = nn.GELU,
                 keep_prob: float = 1.0):
        super(FeedForward, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dense1 = nn.Dense(in_features, hidden_features)
        self.activation = activation()
        self.dense2 = nn.Dense(hidden_features, out_features)
        self.dropout = nn.Dropout(p=1.0-keep_prob)

    def construct(self, x):
        """Feed Forward construct."""
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.dense2(x)
        x = self.dropout(x)

        return x

class ResidualCell(nn.Cell):
    """Wraps a cell with a skip connection: output = cell(x) + x."""
    def __init__(self, cell):
        super(ResidualCell, self).__init__()
        self.cell = cell

    def construct(self, x):
        """ResidualCell construct."""
        return self.cell(x) + x
```

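Building the TransformerEncoder part of ViT from Self-Attention:

How ViT's Transformer differs from the original

Normalization is placed before Self-Attention and Feed Forward (pre-norm)

The rest of the structure is unchanged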
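To make the ordering difference concrete, here is a pure-Python sketch contrasting the original post-norm block with ViT's pre-norm block. Assumptions: LayerNorm without its learnable scale and shift, and a toy `double` function standing in for attention or feed-forward; all names are hypothetical.

```python
import math


def layer_norm(x, eps=1e-5):
    """LayerNorm over one vector, without learnable scale/shift (gamma=1, beta=0)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def add(a, b):
    return [u + v for u, v in zip(a, b)]


def post_norm_block(x, sublayer):
    """Original Transformer: Norm(x + Sublayer(x))."""
    return layer_norm(add(x, sublayer(x)))


def pre_norm_block(x, sublayer):
    """ViT: x + Sublayer(Norm(x)); the residual path itself is never normalized."""
    return add(x, sublayer(layer_norm(x)))


def double(v):
    """Stand-in for attention / feed-forward."""
    return [2 * t for t in v]


x = [1.0, 2.0, 3.0]
print(post_norm_block(x, double))
print(pre_norm_block(x, double))
```

In the pre-norm form the input reaches the output unchanged through the addition, which is one common explanation for why it trains stably when many blocks are stacked.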

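Transformer structure diagram

Several sub-encoders are stacked to build the model's encoder

The ViT hyperparameter num_layers

        determines the number of stacked layers

The Residual Connection and Normalization structures

keep information from degrading as it passes through many layers

and improve the model's generalization ability

The TransformerEncoder structure combined with a multilayer perceptron (MLP)

forms the backbone of the ViT model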

```python
class TransformerEncoder(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_layers: int,
                 num_heads: int,
                 mlp_dim: int,
                 keep_prob: float = 1.,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,  # accepted but unused in this simplified version
                 activation: nn.Cell = nn.GELU,
                 norm: nn.Cell = nn.LayerNorm):
        super(TransformerEncoder, self).__init__()
        layers = []

        for _ in range(num_layers):
            normalization1 = norm((dim,))
            normalization2 = norm((dim,))
            attention = Attention(dim=dim,
                                  num_heads=num_heads,
                                  keep_prob=keep_prob,
                                  attention_keep_prob=attention_keep_prob)

            feedforward = FeedForward(in_features=dim,
                                      hidden_features=mlp_dim,
                                      activation=activation,
                                      keep_prob=keep_prob)

            # Pre-norm block: x + Attn(Norm(x)), then x + FFN(Norm(x))
            layers.append(
                nn.SequentialCell([
                    ResidualCell(nn.SequentialCell([normalization1, attention])),
                    ResidualCell(nn.SequentialCell([normalization2, feedforward]))
                ])
            )
        self.layers = nn.SequentialCell(layers)

    def construct(self, x):
        """Transformer construct."""
        return self.layers(x)
```

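ViT model input

The traditional Transformer

processes word vectors in natural language processing

(Word Embedding or Word Vector).

Word vectors form a stack of 1D vectors,

while an image is a stack of 2D matrices.

When multi-head attention processes a stack of 1D word vectors, it extracts the relationships between them, i.e. the contextual semantics.

In the ViT model:

A convolution splits each channel of the input image into patches of 16 x 16 pixels

        A 224 x 224 input image, after the convolution,

                yields 14 x 14 = 196 patches,

                each 16 x 16 pixels in size

Each patch matrix is flattened into a 1D vector

to approximate the effect of a stack of word vectors

        each 16 x 16 patch (16 x 16 x 3 = 768 values over three channels) becomes one vector of length embed_dim = 768, giving a sequence of 196 vectors

This is the first processing step an image goes through when it enters the network.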
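The splitting and flattening described above can be sketched in plain Python for one channel. In the actual PatchEmbedding code, a Conv2d with kernel_size = stride = patch_size performs this together with the linear projection to embed_dim in a single step; `patchify` is only a hypothetical illustration.

```python
def patchify(image, patch_size):
    """Split an H x W single-channel image (list of rows) into non-overlapping
    patch_size x patch_size patches, each flattened row-major into a 1D list."""
    h, w = len(image), len(image[0])
    assert h % patch_size == 0 and w % patch_size == 0
    patches = []
    for top in range(0, h, patch_size):
        for left in range(0, w, patch_size):
            patches.append([image[top + r][left + c]
                            for r in range(patch_size)
                            for c in range(patch_size)])
    return patches


img = [[r * 4 + c for c in range(4)] for r in range(4)]  # 4 x 4 toy image
print(patchify(img, 2))

# ViT-B/16 sizes: 224 x 224 image, 16 x 16 patches
big = [[0] * 224 for _ in range(224)]
p = patchify(big, 16)
print(len(p), len(p[0]))  # number of patches, values per patch (one channel)
```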

Patch Embedding code:

```python
class PatchEmbedding(nn.Cell):
    MIN_NUM_PATCHES = 4

    def __init__(self,
                 image_size: int = 224,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 input_channels: int = 3):
        super(PatchEmbedding, self).__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2  # defaults: (224 // 16) ** 2 = 196
        # kernel_size == stride == patch_size: each output position covers exactly one patch
        self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)

    def construct(self, x):
        """Patch Embedding construct."""
        x = self.conv(x)                    # defaults: (b, 768, 14, 14)
        b, c, h, w = x.shape
        x = ops.reshape(x, (b, c, h * w))   # (b, 768, 196)
        x = ops.transpose(x, (0, 2, 1))     # (b, 196, 768): one token per patch

        return x
```
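A shape trace through `PatchEmbedding.construct` for a batch of 16 images, using the standard convolution output-size formula without padding. Because 224 is a multiple of 16, the patches tile the image exactly and no padding comes into play. The helper names are hypothetical.

```python
def conv_out(size, kernel, stride, padding=0):
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def patch_embedding_shapes(b, c, h, w, patch_size=16, embed_dim=768):
    """Tensor shapes after each step of PatchEmbedding.construct."""
    oh = conv_out(h, patch_size, patch_size)
    ow = conv_out(w, patch_size, patch_size)
    after_conv = (b, embed_dim, oh, ow)        # self.conv
    after_reshape = (b, embed_dim, oh * ow)    # ops.reshape
    after_transpose = (b, oh * ow, embed_dim)  # ops.transpose(x, (0, 2, 1))
    return after_conv, after_reshape, after_transpose


for shape in patch_embedding_shapes(16, 3, 224, 224):
    print(shape)
```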

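After the input image is split into patches,

        it goes through two further steps:

                class_embedding and pos_embedding.

class_embedding borrows the idea BERT uses for text classification:

a class token is prepended in front of the word vectors.

The 196 patch vectors plus the class_embedding give a sequence of 197 vectors.

class_embedding is a learnable parameter.

As the network trains, the output at the first position of the sequence (the class token) comes to determine the final predicted class.

The input consists of 14 x 14 = 196 patch tokens,

but only the class-token output is used for classification (x[:, 0] in the code below).

pos_embedding is also a set of learnable parameters

        added to the patch matrix

There are 4 options for pos_embedding;

        a 1D pos_embedding is used here.

        Because class_embedding is concatenated before pos_embedding is added,

        pos_embedding has one more row than the number of flattened patches (num_patches + 1).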
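A pure-Python sketch of the two steps, prepending the class token and then adding the position embeddings, with a toy embedding size of 8 instead of 768 and random values standing in for the learned parameters. `add_cls_and_pos` is a hypothetical helper.

```python
import random


def add_cls_and_pos(patch_tokens, cls_token, pos_embedding):
    """Prepend the class token, then add position embeddings element-wise."""
    seq = [cls_token] + patch_tokens  # length num_patches + 1
    assert len(seq) == len(pos_embedding), "pos_embedding needs num_patches + 1 rows"
    return [[t + p for t, p in zip(tok, pos)] for tok, pos in zip(seq, pos_embedding)]


num_patches, dim = 196, 8  # small dim for illustration only
rng = random.Random(0)
patches = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(num_patches)]
cls = [0.0] * dim
pos = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(num_patches + 1)]
x = add_cls_and_pos(patches, cls, pos)
print(len(x), len(x[0]))  # 196 patch tokens + 1 class token
```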

V. Building the complete ViT

Code for building the ViT model

```python
from mindspore.common.initializer import Normal
from mindspore.common.initializer import initializer
from mindspore import Parameter


def init(init_type, shape, dtype, name, requires_grad):
    """Create a Parameter filled by the given initializer."""
    initial = initializer(init_type, shape, dtype).init_data()
    return Parameter(initial, name=name, requires_grad=requires_grad)


class ViT(nn.Cell):
    def __init__(self,
                 image_size: int = 224,
                 input_channels: int = 3,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 num_layers: int = 12,
                 num_heads: int = 12,
                 num_classes: int = 1000,
                 mlp_dim: int = 3072,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: Optional[nn.Cell] = nn.LayerNorm,
                 pool: str = 'cls') -> None:
        super(ViT, self).__init__()

        self.patch_embedding = PatchEmbedding(image_size=image_size,
                                              patch_size=patch_size,
                                              embed_dim=embed_dim,
                                              input_channels=input_channels)
        num_patches = self.patch_embedding.num_patches

        # Learnable class token, prepended to the patch sequence
        self.cls_token = init(init_type=Normal(sigma=1.0),
                              shape=(1, 1, embed_dim),
                              dtype=ms.float32,
                              name='cls',
                              requires_grad=True)

        # Learnable 1D position embedding: one row per patch plus one for the class token
        self.pos_embedding = init(init_type=Normal(sigma=1.0),
                                  shape=(1, num_patches + 1, embed_dim),
                                  dtype=ms.float32,
                                  name='pos_embedding',
                                  requires_grad=True)

        self.pool = pool
        self.pos_dropout = nn.Dropout(p=1.0-keep_prob)
        self.norm = norm((embed_dim,))
        self.transformer = TransformerEncoder(dim=embed_dim,
                                              num_layers=num_layers,
                                              num_heads=num_heads,
                                              mlp_dim=mlp_dim,
                                              keep_prob=keep_prob,
                                              attention_keep_prob=attention_keep_prob,
                                              drop_path_keep_prob=drop_path_keep_prob,
                                              activation=activation,
                                              norm=norm)
        self.dropout = nn.Dropout(p=1.0-keep_prob)
        self.dense = nn.Dense(embed_dim, num_classes)  # classification head

    def construct(self, x):
        """ViT construct."""
        x = self.patch_embedding(x)                    # (b, 196, 768) with defaults
        cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))
        x = ops.concat((cls_tokens, x), axis=1)        # (b, 197, 768)
        x += self.pos_embedding

        x = self.pos_dropout(x)
        x = self.transformer(x)
        x = self.norm(x)
        x = x[:, 0]                                    # keep only the class-token output
        if self.training:
            x = self.dropout(x)
        x = self.dense(x)                              # (b, num_classes)

        return x
```
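As a sanity check on the architecture, the parameter count can be derived by hand from the layer shapes in the code (assuming each LayerNorm has a scale and a shift, and every Dense/Conv2d layer has a bias). For the default ViT-B/16 configuration this comes to 86,567,656 parameters, about 86.6M; num_heads does not affect the total. `vit_param_count` is a hypothetical helper.

```python
def dense(n_in, n_out):
    """Weights + bias of a fully connected layer."""
    return n_in * n_out + n_out


def vit_param_count(image_size=224, patch_size=16, in_ch=3, dim=768,
                    layers=12, mlp_dim=3072, num_classes=1000):
    num_patches = (image_size // patch_size) ** 2
    patch_embed = in_ch * patch_size * patch_size * dim + dim  # Conv2d weights + bias
    cls_token = dim
    pos_embed = (num_patches + 1) * dim
    block = (dense(dim, 3 * dim)      # qkv
             + dense(dim, dim)        # attention output projection
             + 2 * (2 * dim)          # two LayerNorms (gamma, beta)
             + dense(dim, mlp_dim)    # FeedForward dense1
             + dense(mlp_dim, dim))   # FeedForward dense2
    final_norm = 2 * dim
    head = dense(dim, num_classes)
    return patch_embed + cls_token + pos_embed + layers * block + final_norm + head


print(vit_param_count())
```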

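The overall flow is shown in the figure below:

VI. Model training and inference

1. Model training

Before training starts,

set the loss function,

        the optimizer,

        and the callbacks,

and adjust epoch_size as needed.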
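The loss used in the training code, CrossEntropySmooth, replaces the one-hot target with a smoothed one: `on_value = 1 - smooth_factor` for the true class and `off_value = smooth_factor / (num_classes - 1)` for every other class. A pure-Python sketch of that target and the resulting cross entropy (helper names are hypothetical):

```python
import math


def smoothed_one_hot(label, num_classes, smooth_factor):
    """Target used by CrossEntropySmooth: on_value for the true class, off_value elsewhere."""
    on_value = 1.0 - smooth_factor
    off_value = smooth_factor / (num_classes - 1)
    return [on_value if i == label else off_value for i in range(num_classes)]


def softmax_cross_entropy(logits, target):
    """Cross entropy between softmax(logits) and a target distribution."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return -sum(t * (z - log_sum) for t, z in zip(target, logits))


def entropy(p):
    return -sum(q * math.log(q) for q in p if q > 0)


target = smoothed_one_hot(label=2, num_classes=5, smooth_factor=0.1)
print(target)
print(softmax_cross_entropy([0.0, 0.0, 5.0, 0.0, 0.0], target))

# Lower bound of the loss for the training setup (1000 classes, smooth_factor=0.1)
print(entropy(smoothed_one_hot(0, 1000, 0.1)))
```

Cross entropy can never fall below the entropy of its target, so with 1000 classes and smooth_factor=0.1 the loss is bounded below by about 1.02. Loss values of 1.2 to 1.5 in the training log are therefore closer to the floor than they would be with hard labels.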

```python
from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore import train

# define hyperparameters
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()  # batches per epoch

# construct model
network = ViT()

# load the pretrained ViT-B/16 checkpoint
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"

vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)

# define learning rate: cosine decay from 5e-5 towards 0 over 10 epochs
lr = nn.cosine_decay_lr(min_lr=float(0),
                        max_lr=0.00005,
                        total_step=epoch_size * step_size,
                        step_per_epoch=step_size,
                        decay_epoch=10)

# define optimizer; Adam has no momentum argument, the value is used as beta1
network_opt = nn.Adam(network.trainable_params(), lr, beta1=momentum)


# define loss function: softmax cross entropy with label smoothing
class CrossEntropySmooth(LossBase):
    """CrossEntropy with optional label smoothing."""

    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = ops.OneHot()
        self.sparse = sparse
        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
        self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            # integer labels -> smoothed one-hot targets
            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss


network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)

# save a checkpoint at the end of every epoch
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)

# initialize model
# "Ascend + mixed precision" can improve performance
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")

# train model
model.train(epoch_size,
            dataset_train,
            callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],
            dataset_sink_mode=False,)
```
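The learning-rate list produced by `nn.cosine_decay_lr` can be approximated in plain Python. This sketch follows the formula given in the MindSpore documentation (epoch index = step // step_per_epoch, capped at decay_epoch); treat it as an illustration rather than the library code. The 125 steps per epoch are taken from the training log below.

```python
import math


def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
    """Per-step LR: min_lr + 0.5 * (max_lr - min_lr) * (1 + cos(pi * epoch / decay_epoch)),
    with epoch = step // step_per_epoch, so the rate is constant within an epoch."""
    lrs = []
    for i in range(total_step):
        current_epoch = min(i // step_per_epoch, decay_epoch)
        lrs.append(min_lr + 0.5 * (max_lr - min_lr) *
                   (1 + math.cos(math.pi * current_epoch / decay_epoch)))
    return lrs


lr = cosine_decay_lr(0.0, 0.00005, total_step=10 * 125, step_per_epoch=125, decay_epoch=10)
print(lr[0], lr[125 * 5], lr[-1])
```

Since the epoch index only reaches 9 in a 10-epoch run, the last epoch still trains with a small nonzero rate (about 1.2e-6) rather than min_lr.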

Output:

Downloading data from https://download-mindspore.osinfra.cn/vision/classification/vit_b_16_224.ckpt (330.2 MB)

file_sizes: 100%|████████████████████████████| 346M/346M [00:26<00:00, 13.2MB/s]
Successfully downloaded file to ./ckpt/vit_b_16_224.ckpt
epoch: 1 step: 125, loss is 1.4842896
Train epoch time: 275011.631 ms, per step time: 2200.093 ms
epoch: 2 step: 125, loss is 1.3481578
Train epoch time: 23961.255 ms, per step time: 191.690 ms
epoch: 3 step: 125, loss is 1.3990085
Train epoch time: 24217.701 ms, per step time: 193.742 ms
epoch: 4 step: 125, loss is 1.1687485
Train epoch time: 23769.989 ms, per step time: 190.160 ms
epoch: 5 step: 125, loss is 1.209775
Train epoch time: 23603.390 ms, per step time: 188.827 ms
epoch: 6 step: 125, loss is 1.3151006
Train epoch time: 23977.132 ms, per step time: 191.817 ms
epoch: 7 step: 125, loss is 1.4682239
Train epoch time: 23898.189 ms, per step time: 191.186 ms
epoch: 8 step: 125, loss is 1.2927357
Train epoch time: 23681.583 ms, per step time: 189.453 ms
epoch: 9 step: 125, loss is 1.5348746
Train epoch time: 23521.045 ms, per step time: 188.168 ms
epoch: 10 step: 125, loss is 1.3726548
Train epoch time: 23719.398 ms, per step time: 189.755 ms

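2. Model validation

Model validation uses

the ImageFolderDataset interface to read the dataset,

the CrossEntropySmooth interface to instantiate the loss function,

and Model and related interfaces to compile the model.

Steps:

Data preprocessing

Define the ViT network

Load the pretrained model parameters

Set the loss function

Set the evaluation metrics

        Top_1_Accuracy: a prediction counts as correct if the highest-scoring class is the label

        Top_5_Accuracy: a prediction counts as correct if the label is among the 5 highest-scoring classes

        the higher both values, the more accurate the model

Compile the model

Evaluate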
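Top-1 and Top-5 accuracy differ only in k. A minimal pure-Python version of top-k accuracy (not the MindSpore metric classes themselves), applied to three made-up predictions:

```python
def top_k_accuracy(logits_batch, labels, k):
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    hits = 0
    for logits, label in zip(logits_batch, labels):
        top_k = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
        hits += label in top_k
    return hits / len(labels)


logits = [[0.1, 0.7, 0.2],   # highest score: class 1
          [0.5, 0.3, 0.2],   # highest score: class 0
          [0.2, 0.3, 0.5]]   # highest score: class 2
labels = [1, 1, 0]
print(top_k_accuracy(logits, labels, k=1), top_k_accuracy(logits, labels, k=2))
```

Top-k accuracy can only rise as k grows, which is why the Top_5 value reported below is higher than the Top_1 value.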

```python
dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)

trans_val = [
    transforms.Decode(),
    transforms.Resize(224 + 32),   # shorter side -> 256, aspect ratio kept
    transforms.CenterCrop(224),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)

# construct model
network = ViT()

# load ckpt: this reloads the downloaded pretrained weights; to evaluate the
# fine-tuned model instead, load one of the checkpoints saved under ./ViT
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)

network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)

# define metric
eval_metrics = {'Top_1_Accuracy': train.Top1CategoricalAccuracy(),
                'Top_5_Accuracy': train.Top5CategoricalAccuracy()}

if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O0")

# evaluate model
result = model.eval(dataset_val)
print(result)
```

Output:

{'Top_1_Accuracy': 0.7495, 'Top_5_Accuracy': 0.928}

3. Model inference

Preprocess the inference images:

resize

normalize

so that they match the training input.
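The validation and inference pipelines resize differently. `Resize(224 + 32)` with a single int resizes the shorter side to 256 and keeps the aspect ratio, then `CenterCrop(224)` cuts out the middle; inference uses `Resize([224, 224])`, which stretches the image to exactly 224 x 224. A pure-Python sketch of the validation geometry for a hypothetical 375 x 500 image (the helpers are hypothetical, and MindSpore's internal rounding may differ by a pixel):

```python
def resize_shorter_side(h, w, target):
    """New (h, w) when the shorter side is resized to `target`, keeping the aspect ratio."""
    if h <= w:
        return target, int(w * target / h)
    return int(h * target / w), target


def center_crop_box(h, w, size):
    """(top, left, bottom, right) of a centered size x size crop."""
    top = (h - size) // 2
    left = (w - size) // 2
    return top, left, top + size, left + size


h, w = resize_shorter_side(375, 500, 224 + 32)   # validation: Resize(256)
print((h, w), center_crop_box(h, w, 224))         # then CenterCrop(224)
```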

```python
dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)

trans_infer = [
    transforms.Decode(),
    transforms.Resize([224, 224]),   # resize directly to 224 x 224
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_infer = dataset_infer.map(operations=trans_infer,
                                  input_columns=["image"],
                                  num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)
```

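Model inference

Call the model's predict method,

use index2label to get the corresponding label,

and use the custom show_result function to write the result onto the image.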
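The prediction step reduces to an argmax over the logits followed by a dictionary lookup. `index2label` sorts the WordNet IDs, which lines up with ImageFolderDataset assigning class indices to folder names in sorted order. A toy sketch with a hypothetical three-class mapping standing in for the 1000-entry dictionary:

```python
def predict_label(logits, index_to_name):
    """Return {class_index: class_name} for the highest-scoring class."""
    best = max(range(len(logits)), key=lambda i: logits[i])
    return {best: index_to_name[best]}


# Hypothetical mapping, not read from meta.mat
mapping = {0: "tench", 1: "goldfish", 2: "great white shark"}
print(predict_label([0.3, 2.1, -0.4], mapping))
```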

```python
import os
import pathlib
import cv2
import numpy as np
from PIL import Image
from enum import Enum
from scipy import io


class Color(Enum):
    """Define enum color (tuples in OpenCV's BGR order)."""
    red = (0, 0, 255)
    green = (0, 255, 0)
    blue = (255, 0, 0)
    cyan = (255, 255, 0)
    yellow = (0, 255, 255)
    magenta = (255, 0, 255)
    white = (255, 255, 255)
    black = (0, 0, 0)


def check_file_exist(file_name: str):
    """check_file_exist."""
    if not os.path.isfile(file_name):
        raise FileNotFoundError(f"File `{file_name}` does not exist.")


def color_val(color):
    """Convert a color name, Color, tuple, int or ndarray into a 3-tuple."""
    if isinstance(color, str):
        return Color[color].value
    if isinstance(color, Color):
        return color.value
    if isinstance(color, tuple):
        assert len(color) == 3
        for channel in color:
            assert 0 <= channel <= 255
        return color
    if isinstance(color, int):
        assert 0 <= color <= 255
        return color, color, color
    if isinstance(color, np.ndarray):
        assert color.ndim == 1 and color.size == 3
        assert np.all((color >= 0) & (color <= 255))
        color = color.astype(np.uint8)
        return tuple(color)
    raise TypeError(f'Invalid type for color: {type(color)}')


def imread(image, mode=None):
    """Read an image from a path (returns an ndarray when `mode` is given) or pass an ndarray through."""
    if isinstance(image, pathlib.Path):
        image = str(image)

    if isinstance(image, np.ndarray):
        pass
    elif isinstance(image, str):
        check_file_exist(image)
        image = Image.open(image)
        if mode:
            image = np.array(image.convert(mode))
    else:
        raise TypeError("Image must be a `ndarray`, `str` or Path object.")

    return image


def imwrite(image, image_path, auto_mkdir=True):
    """Save an ndarray image, creating the parent directory if needed."""
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(image_path))
        if dir_name != '':
            dir_name = os.path.expanduser(dir_name)
            os.makedirs(dir_name, mode=0o777, exist_ok=True)  # octal permissions

    image = Image.fromarray(image)
    image.save(image_path)


def imshow(img, win_name='', wait_time=0):
    """imshow"""
    cv2.imshow(win_name, imread(img))
    if wait_time == 0:  # prevent from hanging if windows was closed
        while True:
            ret = cv2.waitKey(1)

            closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
            # if user closed window or if some key pressed
            if closed or ret != -1:
                break
    else:
        ret = cv2.waitKey(wait_time)


def show_result(img: str,
                result: Dict[int, float],
                text_color: str = 'green',
                font_scale: float = 0.5,
                row_width: int = 20,
                show: bool = False,
                win_name: str = '',
                wait_time: int = 0,
                out_file: Optional[str] = None) -> None:
    """Mark the prediction results on the picture."""
    img = imread(img, mode="RGB")
    img = img.copy()
    x, y = 0, row_width
    text_color = color_val(text_color)
    for k, v in result.items():
        if isinstance(v, float):
            v = f'{v:.2f}'
        label_text = f'{k}: {v}'
        cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
                    font_scale, text_color)
        y += row_width
    if out_file:
        show = False
        imwrite(img, out_file)

    if show:
        imshow(img, win_name, wait_time)


def index2label():
    """Dictionary output for image numbers and categories of the ImageNet dataset."""
    metafile = os.path.join(data_path, "ILSVRC2012_devkit_t12/data/meta.mat")
    meta = io.loadmat(metafile, squeeze_me=True)['synsets']

    # keep only leaf synsets (the 1000 classification classes)
    nums_children = list(zip(*meta))[4]
    meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]

    _, wnids, classes = list(zip(*meta))[:3]
    clssname = [tuple(clss.split(', ')) for clss in classes]
    wnid2class = {wnid: clss for wnid, clss in zip(wnids, clssname)}
    wind2class_name = sorted(wnid2class.items(), key=lambda x: x[0])

    mapping = {}
    for index, (_, class_name) in enumerate(wind2class_name):
        mapping[index] = class_name[0]
    return mapping


# Build the index -> class-name mapping once, outside the loop
mapping = index2label()

# Read data for inference
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
    image = image["image"]
    image = ms.Tensor(image)
    prob = model.predict(image)
    label = int(np.argmax(prob.asnumpy(), axis=1)[0])  # batch size is 1
    output = {label: mapping[label]}
    print(output)
    # Note: the image path below is hard-coded to one sample in ./dataset/infer
    show_result(img="./dataset/infer/n01440764/ILSVRC2012_test_00000279.JPEG",
                result=output,
                out_file="./dataset/infer/ILSVRC2012_test_00000279.JPEG")
```

Output:

{236: 'Doberman'}

After inference completes,

the annotated result image can be found in the inference folder (./dataset/infer/).
