深度学习:基于MindSpore实现ResNet50中药分拣

ResNet基本介绍

ResNet(Residual Network)是一种深度神经网络架构,由微软研究院的Kaiming He等人在2015年提出,并且在ILSVRC 2015竞赛中取得了很好的成绩。ResNet主要解决了随着网络深度增加而出现的退化问题,即当网络变得非常深时,训练误差和验证误差可能会开始上升,这并不是因为过拟合,而是由于深层网络难以优化。

ResNet的核心思想是引入了残差学习框架来简化许多层网络的训练。通过构建“跳跃连接”或称“捷径连接”,允许一层直接与更深层相连接。这种设计让网络可以学习到一个残差函数,相对于传统的学习未参考输入x的目标映射H(x),残差块学习的是F(x) = H(x) - x。这样的结构使得信息能够跨过多层流动,从而缓解了梯度消失/爆炸的问题,也使得训练更深的网络成为可能。

残差块

残差块(Residual Block)是构成ResNet(Residual Network)的核心组件,它通过引入所谓的“跳跃连接”或“捷径连接”来解决深层网络训练中的梯度消失/爆炸问题。这种设计允许信息在多层之间直接传递,从而帮助优化非常深的神经网络。

残差块的核心思想是让网络去学习输入和输出之间的差异,也就是所谓的“残差”,而不是直接学习原始的输入到输出的映射。用数学语言来说,就是:

  • 传统方法:网络尝试学习H(x),即从输入x直接得到输出H(x)。
  • 残差方法:网络学习的是F(x) = H(x) - x,即输入x与期望输出H(x)之间的差值。然后,实际的输出是由原输入加上这个差值得到,即H(x) = F(x) + x。

为什么这样做有效?

  1. 简化学习任务:相比于学习复杂的非线性映射H(x),学习残差F(x)可能要简单得多。特别是当H(x)接近于x时(例如,对于一些层来说不需要对输入进行太多改变),此时F(x)可以趋近于0,表示这些层只需学会恒等映射即可,这比学习复杂的变换要容易很多。

  2. 缓解梯度问题:由于信息可以通过跳跃连接直接传递给后面的层,因此即使在深层网络中,早期层的信息也不会因为经过多层而完全丢失,从而减轻了梯度消失/爆炸的影响。

  3. 提高模型性能:通过允许网络更深,同时保持良好的可训练性,ResNet能够实现更高的准确率,并且在多个视觉任务上取得了当时最佳的结果。

一个基本的残差块结构如图所示:

(图源:7.6. 残差网络(ResNet) — 动手学深度学习 2.0.0 documentation)

一个典型的残差块包含以下几个部分:

  1. 卷积层:通常是一个或多个3x3的卷积层。这些层执行常规的特征提取任务。
  2. 批量归一化层(Batch Normalization, BN):每个卷积层后紧跟着一个批量归一化层,以加速收敛过程并提高模型的泛化能力。
  3. ReLU激活函数:除了最后一个卷积层之后,其余卷积层之后都会应用ReLU激活函数来引入非线性。
  4. 跳跃连接:这是残差块最独特的部分。它将输入直接加到经过上述处理后的输出上。如果输入和输出具有相同的尺寸(即相同数量的通道数和空间维度),则可以直接相加;如果它们不同,则需要使用额外的1x1卷积层对输入进行调整,以便能够正确地与输出相加。

当残差块的输入和输出特征图的通道数不同时,直接相加是不可行的。为了在这种情况下实现跳跃连接,需要对输入进行适当的变换,使得它与输出具有相同的维度。通常的做法是使用1x1卷积层来调整输入的通道数,将channels_in转换成channels_out。这个1x1卷积层不会改变输入的高度和宽度,只改变通道数。如右图所示。 

(图源:7.6. 残差网络(ResNet) — 动手学深度学习 2.0.0 documentation)

BuildingBlock和Bottleneck

Building Block和Bottleneck是两种不同类型的残差块结构,它们的主要区别在于内部的卷积层配置以及计算效率。

BuildingBlock

Building Block是最简单的残差块形式,通常用于较浅的ResNet模型中,比如ResNet-18和ResNet-34。它的基本结构包括:

  • 两个3x3卷积层。
  • 每个卷积层之后接一个批量归一化(Batch Normalization, BN)层。
  • 第一个卷积层之后有一个ReLU激活函数。
  • 跳跃连接(skip connection),将输入直接加到第二个卷积层的输出上。
  • 最后一个ReLU激活函数位于跳跃连接之后。

(图源:基于MindSpore实现ResNet50中药分拣_哔哩哔哩_bilibili ) 

这样的结构简单且有效,但随着网络深度的增加,计算成本也会显著上升,因为每个3x3卷积层都会对大量的特征图进行操作。 

Bottleneck

Bottleneck结构被设计用于更深的ResNet模型,例如ResNet-50、ResNet-101和ResNet-152。它通过引入瓶颈(bottleneck)机制来减少计算量,同时保持或增强模型性能。Bottleneck的基本结构如下:

  • 一个1x1卷积层,用于减少通道数(降维),这被称为瓶颈。
  • 一个3x3卷积层,这个卷积层在较少的通道上执行卷积运算。
  • 另一个1x1卷积层,用于恢复原来的通道数(升维)。
  • 每个卷积层后面都跟着BN层。
  • 在第一个1x1卷积层和3x3卷积层之间有一个ReLU激活函数。
  • 跳跃连接,将原始输入直接加到最后一个1x1卷积层的输出上。
  • 最后一个ReLU激活函数位于跳跃连接之后。

(图源:基于MindSpore实现ResNet50中药分拣_哔哩哔哩_bilibili ) 

Bottleneck使得可以在不大幅增加计算负担的情况下堆叠更多的层,从而构建更深的网络。 

各类ResNet模型架构如下所示:

(图源:基于MindSpore实现ResNet50中药分拣_哔哩哔哩_bilibili )

数据集

本案例使用“中药炮制饮片”数据集,该数据集由程度中医药大学提供,共包含中药炮制饮片的3个品种,分为:蒲黄、山楂、王不留行,每个品种有4个炮制状态:生品、不及、适中、太过。其中每类包含500张图片共12类,图片尺寸为4K,图片格式为jpg。

数据集可视化如下:

(图源:基于MindSpore实现ResNet50中药分拣_哔哩哔哩_bilibili )

基于MindSpore实现ResNet50中药分拣

数据下载及预处理

from download import download
import os

url = "https://mindspore-courses.obs.cn-north-4.myhuaweicloud.com/deep%20learning/AI%2BX/data/zhongyiyao.zip"
# 创建的是调试任务,url修改为数据集上传生成的url链接
if not os.path.exists("dataset"):
    download(url, "dataset", kind="zip")

因图片原尺寸为4K过大,因此需要将其Resize至指定尺寸

from PIL import Image
import numpy as np
data_dir = "dataset/zhongyiyao/zhongyiyao"
new_data_path = "dataset1/zhongyiyao"
if not os.path.exists(new_data_path):
    for path in ['train','test']:
        data_path = data_dir + "/" + path
        classes = os.listdir(data_path)
        for (i,class_name) in enumerate(classes):
            floder_path =  data_path+"/"+class_name
            print(f"正在处理{floder_path}...")
            for image_name in os.listdir(floder_path):
                try:
                    image = Image.open(floder_path + "/" + image_name)
                    image = image.resize((1000,1000))
                    target_dir = new_data_path+"/"+path+"/"+class_name
                    if not os.path.exists(target_dir):
                        os.makedirs(target_dir)
                    if not os.path.exists(target_dir+"/"+image_name):
                        image.save(target_dir+"/"+image_name)
                except:
                    pass     

将数据集划分为训练集、测试集、验证集

from sklearn.model_selection import train_test_split
import shutil

def split_data(X, y, test_size=0.2, val_size=0.2, random_state=42):
    """
    This function splits the data into training, validation, and test sets.
    """
    # Split the data into training and test sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)

    # Split the training data into training and validation sets
    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=val_size/(1-test_size), random_state=random_state)

    return X_train, X_val, X_test, y_train, y_val, y_test


data_dir = "dataset1/zhongyiyao"
floders = os.listdir(data_dir)
target = ['train','test','valid']
if set(floders) == set(target):
    # 如果已经划分则跳过
    pass
elif 'train' in floders:
    # 如果已经划分了train,test,那么只需要从train里边划分出valid
    floders = os.listdir(data_dir)
    new_data_dir = os.path.join(data_dir,'train')
    classes = os.listdir(new_data_dir)
    if '.ipynb_checkpoints' in classes:
        classes.remove('.ipynb_checkpoints')
    imgs = []
    labels = []
    for (i,class_name) in enumerate(classes):
        new_path =  new_data_dir+"/"+class_name
        for image_name in os.listdir(new_path):
            imgs.append(image_name)
            labels.append(class_name)
    imgs_train,imgs_val,labels_train,labels_val = X_train, X_test, y_train, y_test = train_test_split(imgs, labels, test_size=0.2, random_state=42)
    print("划分训练集图片数:",len(imgs_train))
    print("划分验证集图片数:",len(imgs_val))
    target_data_dir = os.path.join(data_dir,'valid')
    if not os.path.exists(target_data_dir):
        os.mkdir(target_data_dir)
    for (img,label) in zip(imgs_val,labels_val):
        source_path = os.path.join(data_dir,'train',label)
        target_path = os.path.join(data_dir,'valid',label)
        if not os.path.exists(target_path):
            os.mkdir(target_path)
        source_img = os.path.join(source_path,img)
        target_img = os.path.join(target_path,img)
        shutil.move(source_img,target_img)
else:
    phones = os.listdir(data_dir)
    imgs = []
    labels = []
    for phone in phones:
        phone_data_dir = os.path.join(data_dir,phone)
        yaowu_list = os.listdir(phone_data_dir)
        for yaowu in yaowu_list:
            yaowu_data_dir = os.path.join(phone_data_dir,yaowu)
            chengdu_list = os.listdir(yaowu_data_dir)
            for chengdu in chengdu_list:
                chengdu_data_dir = os.path.join(yaowu_data_dir,chengdu)
                for img in os.listdir(chengdu_data_dir):
                    imgs.append(img)
                    label = ' '.join([phone,yaowu,chengdu])
                    labels.append(label)
    imgs_train, imgs_val, imgs_test, labels_train, labels_val, labels_test = split_data(imgs, labels, test_size=0.2, val_size=0.2, random_state=42)
    img_label_tuple_list = [(imgs_train,labels_train),(imgs_val,labels_val),(imgs_test,labels_test)]
    for (i,split) in enumerate(spilits):
        target_data_dir = os.path.join(data_dir,split)
        if not os.path.exists(target_data_dir):
            os.mkdir(target_data_dir)
        imgs_list,labels_list = img_label_tuple_list[i]
        for (img,label) in zip(imgs_list,labels_list):
            label_split = label.split(' ')
            source_img = os.path.join(data_dir,label_split[0],label_split[1],label_split[2],img)
            target_img_dir = os.path.join(target_data_dir,label_split[1]+"_"+label_split[2])
            if not os.path.exists(target_img_dir):
                os.mkdir(target_img_dir)
            target_img = os.path.join(target_img_dir,img)
            shutil.move(source_img,target_img)
    

定义数据加载方式

from mindspore.dataset import GeneratorDataset
import mindspore.dataset.vision as vision
import mindspore.dataset.transforms as transforms
from mindspore import dtype as mstype
# 注意没有使用Mindspore提供的ImageFloder进行加载,原因是调试任务中'.ipynb_checkpoints'缓存文件夹会被当作类文件夹进行识别,导致数据集加载错误
class Iterable:
    def __init__(self,data_path):
        self._data = []
        self._label = []
        self._error_list = []
        if data_path.endswith(('JPG','jpg','png','PNG')):
            # 用作推理,所以没有label
            image = Image.open(data_path)
            self._data.append(image)
            self._label.append(0)
        else:
            classes = os.listdir(data_path)
            if '.ipynb_checkpoints' in classes:
                classes.remove('.ipynb_checkpoints')
            for (i,class_name) in enumerate(classes):
                new_path =  data_path+"/"+class_name
                for image_name in os.listdir(new_path):
                    try:
                        image = Image.open(new_path + "/" + image_name)
                        self._data.append(image)
                        self._label.append(i)
                    except:
                        pass
                

    def __getitem__(self, index):
        return self._data[index], self._label[index]

    def __len__(self):
        return len(self._data)
    
    def get_error_list(self,):
        return self._error_list
    
def create_dataset_zhongyao(dataset_dir,usage,resize,batch_size,workers):
    data = Iterable(dataset_dir)
    data_set = GeneratorDataset(data,column_names=['image','label'])
    trans = []
    if usage == "train":
        trans += [
            vision.RandomCrop(700, (4, 4, 4, 4)),
            # 这里随机裁剪尺度可以设置
            vision.RandomHorizontalFlip(prob=0.5)
        ]

    trans += [
        vision.Resize((resize,resize)),
        vision.Rescale(1.0 / 255.0, 0.0),
        vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        vision.HWC2CHW()
    ]

    target_trans = transforms.TypeCast(mstype.int32)
    # 数据映射操作
    data_set = data_set.map(
        operations=trans,
        input_columns='image',
        num_parallel_workers=workers)

    data_set = data_set.map(
        operations=target_trans,
        input_columns='label',
        num_parallel_workers=workers)

    # 批量操作
    data_set = data_set.batch(batch_size,drop_remainder=True)

    return data_set

加载数据

import mindspore as ms
import random
data_dir = "dataset1/zhongyiyao"
train_dir = data_dir+"/"+"train"
valid_dir = data_dir+"/"+"valid"
test_dir = data_dir+"/"+"test"
batch_size = 32 # 批量大小
image_size = 224 # 训练图像空间大小
workers = 4 # 并行线程个数
num_classes = 12 # 分类数量


# 设置随机种子,使得模型结果复现
seed = 42
ms.set_seed(seed)
np.random.seed(seed)
random.seed(seed)

dataset_train = create_dataset_zhongyao(dataset_dir=train_dir,
                                       usage="train",
                                       resize=image_size,
                                       batch_size=batch_size,
                                       workers=workers)
step_size_train = dataset_train.get_dataset_size()

dataset_val = create_dataset_zhongyao(dataset_dir=valid_dir,
                                     usage="valid",
                                     resize=image_size,
                                     batch_size=batch_size,
                                     workers=workers)
dataset_test = create_dataset_zhongyao(dataset_dir=test_dir,
                                     usage="test",
                                     resize=image_size,
                                     batch_size=batch_size,
                                     workers=workers)
step_size_val = dataset_val.get_dataset_size()

print(f'训练集数据:{dataset_train.get_dataset_size()*batch_size}\n')
print(f'验证集数据:{dataset_val.get_dataset_size()*batch_size}\n')
print(f'测试集数据:{dataset_test.get_dataset_size()*batch_size}\n')

实现标签映射

#index_label的映射
index_label_dict = {}
classes = os.listdir(train_dir)
if '.ipynb_checkpoints' in classes:
    classes.remove('.ipynb_checkpoints')
for i,label in enumerate(classes):
    index_label_dict[i] = label
label2chin = {'ph_sp':'蒲黄-生品',  'ph_bj':'蒲黄-不及', 'ph_sz':'蒲黄-适中', 'ph_tg':'蒲黄-太过', 'sz_sp':'山楂-生品',
              'sz_bj':'山楂-不及', 'sz_sz':'山楂-适中', 'sz_tg':'山楂-太过', 'wblx_sp':'王不留行-生品', 'wblx_bj':'王不留行-不及',
              'wblx_sz':'王不留行-适中', 'wblx_tg':'王不留行-太过'}
index_label_dict

定义ResNet50

残差块定义

class ResidualBlock(nn.Cell):
    expansion = 4  # 最后一个卷积核的数量是第一个卷积核数量的4倍

    def __init__(self, in_channel: int, out_channel: int,
                 stride: int = 1, down_sample: Optional[nn.Cell] = None) -> None:
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channel, out_channel,
                               kernel_size=1, weight_init=weight_init)
        self.norm1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(out_channel, out_channel,
                               kernel_size=3, stride=stride,
                               weight_init=weight_init)
        self.norm2 = nn.BatchNorm2d(out_channel)
        self.conv3 = nn.Conv2d(out_channel, out_channel * self.expansion,
                               kernel_size=1, weight_init=weight_init)
        self.norm3 = nn.BatchNorm2d(out_channel * self.expansion)

        self.relu = nn.ReLU()
        self.down_sample = down_sample

    def construct(self, x):

        identity = x  # shortscuts分支

        out = self.conv1(x)  # 主分支第一层:1*1卷积层
        out = self.norm1(out)
        out = self.relu(out)
        out = self.conv2(out)  # 主分支第二层:3*3卷积层
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv3(out)  # 主分支第三层:1*1卷积层
        out = self.norm3(out)

        if self.down_sample is not None:
            identity = self.down_sample(x)

        out += identity  # 输出为主分支与shortcuts之和
        out = self.relu(out)

        return out

创建残差块的函数 

def make_layer(last_out_channel, block: Type[Union[ResidualBlockBase, ResidualBlock]],
               channel: int, block_nums: int, stride: int = 1):
    down_sample = None  # shortcuts分支


    if stride != 1 or last_out_channel != channel * block.expansion:

        down_sample = nn.SequentialCell([
            nn.Conv2d(last_out_channel, channel * block.expansion,
                      kernel_size=1, stride=stride, weight_init=weight_init),
            nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)
        ])

    layers = []
    layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample))

    in_channel = channel * block.expansion
    # 堆叠残差网络
    for _ in range(1, block_nums):

        layers.append(block(in_channel, channel))

    return nn.SequentialCell(layers)

ResNet定义

from mindspore import load_checkpoint, load_param_into_net
from mindspore import ops


class ResNet(nn.Cell):
    def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],
                 layer_nums: List[int], num_classes: int, input_channel: int) -> None:
        super(ResNet, self).__init__()

        self.relu = nn.ReLU()
        # 第一个卷积层,输入channel为3(彩色图像),输出channel为64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)
        self.norm = nn.BatchNorm2d(64)
        # 最大池化层,缩小图片的尺寸
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        # 各个残差网络结构块定义
        self.layer1 = make_layer(64, block, 64, layer_nums[0])
        self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)
        self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)
        self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2)
        # 平均池化层
        self.avg_pool = ops.ReduceMean(keep_dims=True)
        # self.avg_pool = nn.AvgPool2d()
        
        # flattern层
        self.flatten = nn.Flatten()
        # 全连接层
        self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)

    def construct(self, x):
        
        x = self.conv1(x)
        x = self.norm(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avg_pool(x,(2,3))
        
        x = self.flatten(x)
        x = self.fc(x)

        return x

使用预训练的ResNet微调进行预测

def _resnet(model_url: str, block: Type[Union[ResidualBlockBase, ResidualBlock]],
            layers: List[int], num_classes: int, pretrained: bool, pretrained_ckpt: str,
            input_channel: int):
    model = ResNet(block, layers, num_classes, input_channel)

    if pretrained:
        # 加载预训练模型
        download(url=model_url, path=pretrained_ckpt)
        param_dict = load_checkpoint(pretrained_ckpt)
        load_param_into_net(model, param_dict)

    return model


def resnet50(num_classes: int = 1000, pretrained: bool = False):
    resnet50_url = "https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/models/application/resnet50_224_new.ckpt"
    resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"
    return _resnet(resnet50_url, ResidualBlock, [3, 4, 6, 3], num_classes,
                   pretrained, resnet50_ckpt, 2048)
network = resnet50(pretrained=True)
num_class = 12
# 全连接层输入层的大小
in_channel = network.fc.in_channels
fc = nn.Dense(in_channels=in_channel, out_channels=num_class)
# 重置全连接层
network.fc = fc

for param in network.get_parameters():
    param.requires_grad = True

模型训练与推理

num_epochs = 50
# early stopping
patience = 5
lr = nn.cosine_decay_lr(min_lr=0.00001, max_lr=0.001, total_step=step_size_train * num_epochs,
                        step_per_epoch=step_size_train, decay_epoch=num_epochs)
opt = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=0.9)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
model = ms.Model(network, loss_fn, opt, metrics={'acc'})

# 最佳模型存储路径
best_acc = 0
best_ckpt_dir = "./BestCheckpoint"
best_ckpt_path = "./BestCheckpoint/resnet50-best.ckpt"

def train_loop(model, dataset, loss_fn, optimizer):
    # Define forward function
    def forward_fn(data, label):
        logits = model(data)
        loss = loss_fn(logits, label)
        return loss, logits

    # Get gradient function
    grad_fn = ms.ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

    # Define function of one-step training
    def train_step(data, label):
        (loss, _), grads = grad_fn(data, label)
        loss = ops.depend(loss, optimizer(grads))
        return loss
    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0 or batch == step_size_train - 1:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

from sklearn.metrics import classification_report

def test_loop(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    y_true = []
    y_pred = []
    for data, label in dataset.create_tuple_iterator():
        y_true.extend(label.asnumpy().tolist())
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        y_pred.extend(pred.argmax(1).asnumpy().tolist())
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    print(classification_report(y_true,y_pred,target_names= list(index_label_dict.values()),digits=3))
    return correct,test_loss

开始训练 

no_improvement_count = 0
acc_list = []
loss_list = []
stop_epoch = num_epochs
for t in range(num_epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(network, dataset_train, loss_fn, opt)
    acc,loss = test_loop(network, dataset_val, loss_fn)
    acc_list.append(acc)
    loss_list.append(loss)
    if acc > best_acc:
        best_acc = acc
        if not os.path.exists(best_ckpt_dir):
            os.mkdir(best_ckpt_dir)
        ms.save_checkpoint(network, best_ckpt_path)
        no_improvement_count = 0
    else:
        no_improvement_count += 1
        if no_improvement_count > patience:
            print('Early stopping triggered. Restoring best weights...')
            stop_epoch = t
            break 
print("Done!")

模型推理

import matplotlib.pyplot as plt

num_class = 12  # 
net = resnet50(num_class)
best_ckpt_path = 'BestCheckpoint/resnet50-best.ckpt'
# 加载模型参数
param_dict = ms.load_checkpoint(best_ckpt_path)
ms.load_param_into_net(net, param_dict)
model = ms.Model(net)
image_size = 224
workers = 1

def visualize_model(dataset_test):
    # 加载验证集的数据进行验证
    data = next(dataset_test.create_tuple_iterator())
    # print(data)
    images = data[0].asnumpy()
    labels = data[1].asnumpy()
    # 预测图像类别
    output = model.predict(ms.Tensor(data[0]))
    pred = np.argmax(output.asnumpy(), axis=1)

    # 显示图像及图像的预测值
    plt.figure(figsize=(10, 6))
    for i in range(6):
        plt.subplot(2, 3, i+1)
        color = 'blue' if pred[i] == labels[i] else 'red'
        plt.title('predict:{}  actual:{}'.format(index_label_dict[pred[i]],index_label_dict[labels[i]]), color=color)
        picture_show = np.transpose(images[i], (1, 2, 0))
        mean = np.array([0.4914, 0.4822, 0.4465])
        std = np.array([0.2023, 0.1994, 0.2010])
        picture_show = std * picture_show + mean
        picture_show = np.clip(picture_show, 0, 1)
        plt.imshow(picture_show)
        plt.axis('off')

    plt.show()

visualize_model(dataset_val)

更多内容可参考MindSpore官方教程:

基于MindSpore实现ResNet50中药分拣_哔哩哔哩_bilibili 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/888084.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

数据结构与算法——动态规划算法简析

1.初步了解动态规划 由于本篇博客属于动态规划的初阶学习,所以大多都是简单的表示,更深层次的学术用语会在之后深度学习动态规划之后出现,本文主要是带各位了解一下动态规划的大致框架 1.1状态表示 通常的我们会开辟一个dp数组来存储需要表示…

015 品牌关联分类

文章目录 后端CategoryBrandEntity.javaCategoryBrandController.javaCategoryBrandServiceImpl.javaCategoryServiceImpl.javaBrandServiceImpl.java删除 npm install pubsub-jsnpm install --save pubsub-js这个错误是由于在尝试安装 pubsub-js 时,npm 发现了项目…

数据结构(栈和队列的实现)

1. 栈(Stack) 1.1 栈的概念与结构 栈是一种特殊的线性表,其只允许固定的一段插入和删除操作;进行数据插入和删除的一段叫做栈顶,另一端叫栈底;栈中的元素符合后进先出LIFO(Last In First Out&…

C++——模拟实现vector

1.查看vector的源代码 2.模拟实现迭代器 #pragma oncenamespace jxy {//模板尽量不要分离编译template <class T>class vector{public:typedef T* iterator;//typedef会受到访问限定符的限制typedef const T* const_iterator;//const迭代器是指向的对象不能修改&#xf…

透明物体的投射和接收阴影

1、让透明度测试Shader投射阴影 &#xff08;1&#xff09;同样我们使用FallBack的形式投射阴影&#xff0c;但是需要注意的是&#xff0c;FallBack的内容为&#xff1a;Transparent / Cutout / VertexLit&#xff0c;该默认Shader中会把裁剪后的物体深度信息写入到 阴影映射纹…

毕业设计_基于springboot+ssm+bootstrap的旅游管理系统【源码+SQL+教程+可运行】【41001】.zip

毕业设计_基于springbootssmbootstrap的旅游管理系统【源码SQL教程可运行】【41001】.zip 下载地址&#xff1a; https://download.csdn.net/download/qq_24428851/89828190 管理系统 url: http://localhost:8080/managerLoginPageuser: admin password: 123 用户门户网站…

【设计模式-解释模式】

定义 解释器模式是一种行为设计模式&#xff0c;用于定义一种语言的文法&#xff0c;并提供一个解释器来处理该语言的句子。它通过为每个语法规则定义一个类&#xff0c;使得可以将复杂的表达式逐步解析和求值。这种模式适用于需要解析和执行语法规则的场景。 UML图 组成角色…

SPDK从安装到运行hello_world示例程序

SPDK从安装到运行示例程序 #mermaid-svg-dwdwvhrJiTcgTkVf {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-dwdwvhrJiTcgTkVf .error-icon{fill:#552222;}#mermaid-svg-dwdwvhrJiTcgTkVf .error-text{fill:#552222;s…

android compose ScrollableTabRow indicator 指示器设置宽度

.requiredWidth(30.dp) Box(modifier Modifier.background(Color.LightGray).fillMaxWidth()) {ScrollableTabRow(selectedTabIndex selectedTabIndex, // 默认选中第一个标签containerColor ColorPageBg,edgePadding 1.dp, // 内容与边缘的距离indicator { tabPositions…

【本地缓存】Java 中的 4 种本地缓存

目录 1、手写一个简单的本地缓存1.1、封装缓存实体类1.2、创建缓存工具类1.3、测试 2、Guava Cache2.1、Guava Cache 简介2.2、入门案例2.2.1、引入 POM 依赖2.2.2、创建 LoadingCache 缓存 2.3、Guava Cache 的优劣势和适用场景 3、Caffeine3.1、Caffeine 简介3.2、对比 Guava…

图的基本概念 - 离散数学系列(五)

目录 1. 图的定义 节点与边 2. 度与路径 节点的度 路径与圈 3. 图的连通性 连通图与非连通图 强连通与弱连通 连通分量 4. 实际应用场景 1. 社交网络 2. 城市交通系统 3. 网络结构 5. 例题与练习 例题1&#xff1a;节点的度 例题2&#xff1a;判断连通性 练习题…

linux基础 超级笔记

1.Linux系统的组成 Linux系统内核&#xff1a;提供系统最核心的功能&#xff0c;如软硬件和资源调度。 系统及应用程序&#xff1a;文件、任务管理器。 2.Linux发行版 通过修改内核代码自行集成系统程序&#xff0c;即封装。比如Ubuntu和centos这种。不过基础命令是完全相…

【瑞昱RTL8763E】刷屏

1 显示界面填充 用户创建的各个界面在 rtk_gui group 中。各界面中 icon[]表对界面进行描述&#xff0c;表中的每个元素代表一 个显示元素&#xff0c;可以是背景、小图标、字符等&#xff0c;UI_WidgetTypeDef 结构体含义如下&#xff1a; typedef struct _UI_WidgetTypeDef …

vite学习教程03、vite+vue2打包配置

文章目录 前言一、修改vite.config.js二、配置文件资源/路径提示三、测试打包参考文章资料获取 前言 博主介绍&#xff1a;✌目前全网粉丝3W&#xff0c;csdn博客专家、Java领域优质创作者&#xff0c;博客之星、阿里云平台优质作者、专注于Java后端技术领域。 涵盖技术内容&…

【深度强化学习基础】(一)基本概念

【深度强化学习基础】&#xff08;一&#xff09;基本概念 一、概率论基础知识二、强化学习领域术语三、强化学习中两个随机性的来源&#xff1a;四、rewards以及returns五、Value Functions1.Action-Value Function Q π ( s , a ) Q_\pi(s,a) Qπ​(s,a)2.State-Value Funct…

【高等数学学习记录】函数的极限

一、知识点 &#xff08;一&#xff09;知识结构 #mermaid-svg-Dz0Ns0FflWSBWY50 {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-Dz0Ns0FflWSBWY50 .error-icon{fill:#552222;}#mermaid-svg-Dz0Ns0FflWSBWY50 .erro…

影刀---如何进行自动化操作

本文不是广告&#xff0c;没有人给我宣传费&#xff0c;只是单纯的觉得这个软件很好用 感谢大家的多多支持哦 本文 1.基本概念与操作&#xff08;非标准下拉框和上传下载&#xff09;非标准对话框的操作上传对话框、下载的对话框、提示的对话框 2.综合案例3.找不到元素怎么办&a…

Leecode刷题之路第12天之整数转罗马数字

题目出处 12-整数转罗马数字-题目出处 题目描述 个人解法 思路&#xff1a; todo 代码示例&#xff1a;&#xff08;Java&#xff09; todo复杂度分析 todo 官方解法 12-整数转罗马数字-官方解法 方法1&#xff1a;模拟 思路&#xff1a; 代码示例&#xff1a;&#xff08…

class 032 位图

这篇文章是看了“左程云”老师在b站上的讲解之后写的, 自己感觉已经能理解了, 所以就将整个过程写下来了。 这个是“左程云”老师个人空间的b站的链接, 数据结构与算法讲的很好很好, 希望大家可以多多支持左程云老师, 真心推荐. 左程云的个人空间-左程云个人主页-哔哩哔哩视频…

SpringBoot项目:前后端打包与部署(使用 Maven)

文章目录 IDEA后端打包与部署&#xff08;使用 Maven&#xff09;1. 确保 Maven 已安装&#xff0c;并引入 pom 插件2. 清理并安装项目3. 定位生成的 JAR 包和配置文件4. 创建部署文件夹5. 上传到服务器 前端打包与部署&#xff08;使用 npm&#xff09;1. 确保 Node.js 和 npm…