基于Mindspore,通过Resnet50迁移学习实现猫十二分类

使用平台介绍

使用平台:启智AI协作平台
使用数据集:百度猫十二分类

数据集介绍

有cat_12_train和cat_12_test和train_list.txt
train_list.txt内有每张图片所对应的标签

Minspore部分操作科普

数据集加载

Mindspore加载图片数据集就直接调整成这种格式就行,然后可以用这个函数加载,自动生成两个列,一列是图片,一列是标签;ImageFolderDataset函数会自动读取和处理数据集,标签就是文件夹的名称
在这里插入图片描述
在这里插入图片描述

数据处理

map函数里可以一键进行处理和映射,定义好数据处理函数,直接把路径和标签Map处理,后面可以带上是训练集还是测试集的标签;,要处理图片就指定 input_columns 参数为image,这个是前面数据读取形成的;
本项目这里用的是ImageFolderDataset,可以自动生成图片对应数据的两个列,要是相对数据处理就设置为数据列名即可
在这里插入图片描述

数据批处理和重复

batch就是批处理,把数据分成指定数量的一个个批次
在这里插入图片描述
最后repeat对数据进行重复
在这里插入图片描述

整体数据处理流程

就是读取数据形成数据和标签对应的列(读取数据函数有很多),然后定义数据预处理函数,在map函数里一键映射,指定要处理的列一键处理,最后对数据进行批次划分,就拿到可以放进训练网络函数的规范数据集了。当然使用时候还要用create_tuple_iterator或者create_dict_iterator函数形成可迭代的数据集。

迁移学习

在迁移学习中,固定特征训练和模型微调都是常用的技术

固定特征训练

在源任务上训练一个模型,并将其应用到目标任务上。在这个过程中,模型的特征提取器是固定的,只对输出层进行调整。这种方法可以利用源任务中已经学习到的特征,从而减少目标任务的训练时间和数据需求。固定特征训练通常适用于目标任务和源任务具有相似的特征空间,并且目标任务的数据量较小的情况。

模型微调

使用预训练模型作为初始模型,并在目标任务的数据集上进行进步训练。在微调阶段,可以根据目标任务的数据和特定要求调整模型的参数使其适应目标任务。模型微调的主要目的是通过在目标任务上的有限训练来调整预训练模型,以取得更好的性能。

举例

以训练一个猫、狗分类器为例,固定特征训练是指在一个大型的猫狗图片数据集上训练一个通用的图像识别模型,然后将该模型应用于特定的猫、狗分类任务。在这个过程中,我们只需要调整模型的输出层,使其能够正确地对猫和狗进行分类。而模型微调则是指使用一个已经在大型数据集上训练好的通用图像识别模型,然后在特定的猫、狗图片数据集上进行进一步训练,以优化模型的性能。在这个过程中我们可以根据猫、狗图片的特点来调整模型的参数,使其能够更好地识别猫和狗。

数据处理过程

由于本项目采用的数据集是百度所提供的猫十二分类,要使用ImageFolderDataset的话,形式不太匹配

现有形式

cat_12_train和cat_12_test里面都是一张一张的图片
train_list.txt内有每张图片所对应的标签
在这里插入图片描述

处理后形式

train和val文件夹内分别有十二个子文件夹,代表12类猫,每个子文件夹内又有一张张的图片
在这里插入图片描述在这里插入图片描述

处理代码

这里有相关代码进行自动划分,但是对于训练集和测试集的划分,我直接采用了手动操作,也可以用代码来实现的;

# 处理异常图片
dir_lit = os.listdir('./work/cat_12_train/')
# dir_lit为一个列表,里面是一张张图片的名称
for list in dir_lit:
# list是图片名称,这里的操作是把这个图片形成一个个的路径
    img_path=os.path.join('./work/cat_12_train/',list)
    print(img_path)
    # 如果不是RGB那就转换为RGB
    img=Image.open(img_path)
    if img.mode != 'RGB':
        img = img.convert('RGB')
        img.save(img_path)

dir_lit = os.listdir('./work/cat_12_test/')
for list in dir_lit:
    img_path=os.path.join('./work/cat_12_test/',list)
    img=Image.open(img_path)
    if img.mode != 'RGB':
        img = img.convert('RGB')
        img.save(img_path)
# 整理数据格式
# 创建12个文件夹分别对应标签
path='./work/MyDataset/'
for i in range(12):
    if not os.path.exists(path+str(i)):
        os.mkdir(path+str(i))
    else:
        continue
#读取每一行
with open(f'./work/train_list.txt','r')as f:
    img_path=f.readlines()
    print(img_path)
    # 里面是一个个的'cat_12_train/8GOkTtqw7E6IHZx4olYnhzvXLCiRsUfM.jpg\t0\n'

# 把对应文件放到对应标签文件夹下
for img in img_path:
    # 拿取每一张图片路径
    # print(img)
    img_src= img.split('\t')[0]
    # img_src为一个个图片路径
    rel_src= img_src.split('cat_12_train/')[1]
    # rel_src为图片名称
    img_label = img.split('\t')[1]
    img_label = img_label.split('\n')[0]
    # img_label为图片标签
    print(img_src)
    print(rel_src)
    print(img_label)
    # os.system(f'cp ./work/{img_src} ./work/MyDataset/{img_label}/{rel_src}')
    shutil.copy(f'./work/{img_src}',f'./work/MyDataset/{img_label}/{rel_src}')
print('图片处理完毕')

整体代码

# 解压上传的数据集压缩包并查看数据集结构
!unzip MyDataset.zip -d data/
import os
print(os.listdir("data"))
print(os.listdir("data/train"))
print(os.listdir('data/train/1'))

输出

['val', 'train']
['9', '10', '5', '8', '11', '2', '7', '3', '0', '1', '6', '4']
['DKkQylbgdrWRjYap63MCJe0UBLhcHXPm.jpg', 'k9HWNaG2Z1wUKAOYdSDu7vRr4xBqmTCV.jpg', 'LfIoOrSNvKHQzsGtm5eMZc0lRuBXhTP6.jpg', '0esFjXNqc5xbMmUaJkRVPwQorWlu3LvA.jpg', 'PMoFIabq0W9U2wETZr7yf4JLYdBxv6hQ.jpg', 'QlUX4zHfPZ3LxRDswqm5FeMbnTWNaj6g.jpg', '7E9oOUcQjkLMvpAtNymHCRSqFfdGVDK4.jpg', 'pznq7EivBH9LwrNysIWxgGeomTlOP8cZ.jpg', '3Ndv9X6uTgzFtnoA01VECIBPj7xqlewG.jpg', 'o1g6adKmS4lBDw2F5buAYnetUWh7xXGz.jpg', '7QZTYlspK2fqdJUwjC0HDmOFrM5W4PX9.jpg', 'qgimvDE8Zaf4PJ32dkNhwVy5nxATOtrX.jpg', 'RJnThakSOGUzeFBdigXAm2NsL8jyYvu3.jpg', 'Ig61xq3ME78fdCRTDWhaKkcyuOQj24PG.jpg', 'tAdqSefI0DohNuU6wgVyPca7Qz5lYTOH.jpg', 'pTe31kYFqwyOGmV50sbhgoLQ9KcjJaxd.jpg', 'HUuwb4gRqoPWD3Lvrsa9hVcQ7FSfOT8t.jpg', 'cqkJDEpWiwS69UxFtKMPRgb4mXhj1LAs.jpg', '2pjBVbqF30cUTvRIYtsCGfgwPKOJz4ua.jpg', 'e7f1iucltpVQTXroFR96xawm2BDZYnNG.jpg', 'tWcMpXTe78zo2ikhbUOqPud6VJ5RfSEw.jpg', 'gcTxA5NwLztvWr7YPCMnDjdFyfqoa2uK.jpg', 'qbKjsR05lrFVYfLChtMGD7im36cUgAnE.jpg', 'S4hfUR5kOj7CXr826Gxa9t3bEBPioJq1.jpg', 'A8PMtHzoyFE0WgjpZ2qUYbduL4T9arxN.jpg', '8E0bSi3h1aVy2cNWpgOsKvxZCQtzqkLU.jpg', 'oDc9XxipzfBjUAEl0hOmyd4PNr5v8IsG.jpg', 'AVGJoCPsX3L6I5Y2M94kEO7vNHmt1Ure.jpg', 'SqyF8c0Rak1NedXpYvlI9TsVwzOhtGZJ.jpg', '04Iv3QNKtu2DAfRTgs9XZwBMb1Cm7l6P.jpg', '94NwrzYLo8iMtagTR3SfPGHmWvZXbyUx.jpg', 'oiU4YjnhNpI3JWagx8SuTCktA6qZXGRH.jpg', 'e416wAERYOQ7NutUJDcIVFk2oPWpC3q8.jpg', '9I3enpUrZ5xD2TvRAOFt4S7lBfVMdqsJ.jpg', 'oLNGFnUPmQhxOkdbv37HwSj8uql4z1sp.jpg', 'nHfDoId8SXKzMt1weR7bJlaWPcNx32yv.jpg', 'dzV1Psxncp6H8g4KWhX3mbrTfqwuLaNv.jpg', 'OW6e1GbpNsfmxFvLQKMnIByX4hDcS2io.jpg', 'MOGw0PDqjmnLdViez26b7WY3hU85vatH.jpg', '3R5BWakTdG2hKjJoiNxg0pr61LsDqSuM.jpg', '41OaVziAEuenpKqv7LCYMPsGH6BkQotD.jpg', 'fRbdkW0GDAhBjpTVeonPItycEia45Ns9.jpg', 'RsYG3VJi7NTXptoPQvWKhcFaDqIe4EkC.jpg', 'jWPXtA07yYrcRxNBUkwC9dpuSs2M463e.jpg', 'bqRVATEuI4x3kJSO9DitWjYms8KoG2Lz.jpg', 'OCXPGzodQsZHRnfMFaBkqW9hKYxA4glr.jpg', 'hAPzcCeE04sSadDIFB627ipyOWgjX8mq.jpg', 'PIbktpOd2DqHwVceLzUyE7CmohMjNA91.jpg', 'adAjP14SXL6vVJ95TRrMIYDiHl7BUqbF.jpg', 'hTxYnXrQ35vwKL1NEMSIot02djHy8i6D.jpg', 'kwuLVmg7n9I4iEOzMQC1NxJfvX6Bhoqs.jpg', '78WTn5auMmQshIZi2qDAYdR1oKcwEzfk.jpg', 'hOEPm8o26CBptkD5yT3fsgbM4dRuVZzw.jpg', 'PIpXbRiu68dm2s1DfxHJGAYLegOUMzoB.jpg', '8NxvitwMaCpsuEQT17nDXzFR5gAZ4rfc.jpg', 'JgxcdpvW7f6lKMShPjHFeZ2RDX13UCiN.jpg', 'BhHLRN0QTWOwl3UEG9J17XScni2P5gVe.jpg', '7Gw2o8LJTF4ecZI6nl1WuDrsAOSfQPaq.jpg', 't6xZhQkD2jWCOi1r3fK7T9slGHVbwgNd.jpg', 'gbfyYtlWaAO4iUCPK2cFVkoQX9MmwJTI.jpg', '9PQic6o4VyZ13pLAYu2avFWSbJRz57fC.jpg', 'Cf2Z3j6hYliVOduEvK5NJp4yba0wSGcA.jpg', '2QvYgMIzELXH4Fy8GNDBaPS3W0tVZ5xq.jpg', 'oFXrWl80gRMenqPbG2uZv5wk4KmHjaNd.jpg', 'l1NsjeJKdvFimRgh6IEZuqfxCw5An8o4.jpg', 'tvWgSwN9m3BZ5qOXjK7LexVCIn6F4AHk.jpg', 'ahcTZUbOmJsloVt8vGMjwPqIXd0x9iy1.jpg', 'hp0nNWXar81lB3eSYE9kGcdDL4tJfuwC.jpg', 'HNehmorRIS6M2iDj48gZL9OKva7Xck1n.jpg', '0GX4YKdcwBi15lTpR7ExWO2ZagseoVNI.jpg', 'b57SiGKYPaE1DrfxJtVeQdlOAUojLZhR.jpg', 'f4gLhHjyKdxTumna7F5pGWPqVIRY01et.jpg', 'AHxQ1GFgRLs802diT3VIlwOoqkW9Sar4.jpg', '09i4DcyrWktZb1naHFEpL5elhG3CvYxu.jpg', 'AieqOGKQC7fgDayW9kLuJ5mUHx4XP2I6.jpg', 'MwZIekE7oxPtRpTHQVf4l6qA2iC1zWLh.jpg', '6M8xAZBdQLkuDcF14HEz53J0IboiPfUa.jpg', 'eD4gfaQTFdoWUCnhPj6YmIZBl5AMxNik.jpg', 'PwBJK7rZHDhq46ynYoj9Saxip0IldMV2.jpg', 'neHaTbwPkdVmoOA3JyWxR7Lh92NDpzEf.jpg', 'E38k6xhQFYKALn4tDlwiPBfpdCygeSNs.jpg', 'N1VpzqjoRmPZCQ5KasAv72TwtMDFrG0d.jpg', 'RvXKbfDuF4W6exgVcInE3SktJa9LzBj8.jpg', 'aMoxSymjdiUwbJ6k5NzGR09uILQ4sEc3.jpg', 'l1WfIcvOZk9jAn2xQwtSCEgY5XhRyoFB.jpg', '0IWfLUGk53iHt9SElNzKsBCDwuMjpPbR.jpg', '80vskwDtCRAz9iWYjnrhIGfeXdUZx16b.jpg', '0WglfKCD5Gu2LqI1msTSZa6orBO7XAz9.jpg', 'os5kaDubPM1hY7f6gRrSOZqNQFEAU89v.jpg', 'm2azqs0NGPDdjR8rUTxWF4covLE5ikQZ.jpg', 'siKAzUrV9eykjlCQ0odZEhnW7FIgTuLm.jpg', 'jZWaf0ne2R1pDo9hTBkCbA8YOq3LlQ5x.jpg', 'Kny8zFiIt4vxNSO6g7Lu9kGfdVJoqPC3.jpg', 'LTMkHx9w2nfsRiZec3bEVtmujpv7qS1y.jpg', 'oZin4PuwTet39xWCYhUBfvlzGyISb5DV.jpg', 'oJ4HWQkZDvta8rUyinRu9fVNs3BX1Kj7.jpg', 'mUp6082yMQghXY13OtvxabTrNSEeiu4B.jpg', 'RZWpn9jGxcKSUb3Y56fVMQHlJhEIeNiA.jpg', 'L86JQlekn9Ko01TbXHYMFImdZ2upxg5h.jpg', 'I0YcgXB97QL6MtHlU3p1znqWdCGOD4mo.jpg', 'IG9NFCfMybKYiQhquOd5H3DjwlSakW6U.jpg', 'SZYosxl4cHRWyT3h5JFqNjGdnIag6907.jpg', '4dMVtGvRJbrjK6X17STZ9Lx3kgeEioqp.jpg', 'q9YrDFK73Jfv15SHpTWelGAIwnBxt0iE.jpg', 'spNU7J8uk6BXiAyQErHegYMzjOaFR2qV.jpg', 'p1ji352o6vhd8l0Q4uNVRZrIgkSLnfBq.jpg', '7WQ0ByMPtJAdZ8h5OkveLi3ScuU6bIY1.jpg', 'hajCi0GDlVP2ONg6FeSWrvubQ34ozwkx.jpg', 'RKLDkUwmFg0Oj5tPeIs31y7J8AQZ9dni.jpg', 'hWOAp1EV6nJzYxaHt03T8GPNe7ujUiF9.jpg', 'B0a2VHnwQv1byMDTlEiOJXxI7Scs9Zjf.jpg', 'gMzOoyTrGniBj1vxN0AeD9VQsFHU7aKR.jpg', 'CiBq0GVawv1rdYyLDjcWoIXP6SKbzH8F.jpg', 'G71cYNEBD6shJLkgVzwb52m8oRluKUHS.jpg', '7IdLnFCb3a25cKNV6tXuYi1fe0hJQMOU.jpg', 'jsThJuVYQxUKSz3btXdA5q8M1O9Cioaw.jpg', '2OpyK1cm85obujwEMqGWNv9V7PnQfJ3U.jpg', 'DLIZr5TjPepd8csioJXMbYHk94RmKx6v.jpg', 'Q80DEFkGlxJj2qR37t4ZKpY6zMdvuIyr.jpg', 'EFxXsVJ0qHkomcBhnLfY96W5U4yOliQG.jpg', 'okw9N05dAnsxgW2IuQy7eGhz1iLOqVrJ.jpg', 'obRL95fxtP1uCNBwQiTjWsdqUvgp2Z43.jpg', '6jTZ5sfCpGwJWIK3DaYQvixLbNt48nHr.jpg', 'gLqBoG3ah0AHXIYWS7dFTt6pxDw8snQv.jpg', '1lzs3kM8NiILvcgDYtn6fdCoSeXauJ5P.jpg', 'FEuyDnKSIJ1a6UtY5LB2rGpRqOm3xP9Z.jpg', 'OKvn2uJmWQi4R9Cs8B7fxbkZtoczq31Y.jpg', 'ByYKkZHb8omRPcvfe5GzXsxOQ3DlLuUq.jpg', 'SKoiaj8C3UGyvJQXh5zWwrxNmYkdEHqn.jpg', '2cKUvXCjm5HNWksY1b4ioIgdSFqyMtEJ.jpg', 'kJK9OA3hXpMWeUY7cifvrz0BItn1VS2T.jpg', 't1DnLxSZXwWTgeJsyE02lrjHfdM35po8.jpg', '5C76eISyb3vmPZuMYcARHU8aFQrBWf1k.jpg', 'puBcg8Fh6tXs27doz1aAIl4L0iVYC3wE.jpg', 'MV5C7YmuzG1LyZplFXvqOQkW4JStjcNP.jpg', 'ruleKNQvzwqmy5sn9MDd7I2RUJjVCWh8.jpg', 'l9Z3gPwjC5HbhINcfVO8dnz1qAxBrJkU.jpg', 'H9BcFOo8UI3jX2CyW0mzxn7agJNAsZQS.jpg', 'jHUJE37YZOGAXInPmyCSp9f0o4uvRe5W.jpg', 'I8jNkAVgZ1yqDw5K9b0Wm4rETfiGBcUF.jpg', 'kKzQrE6GjfpeFhsXx2Ddu9YaTHc3PUbB.jpg', 'jla5O2TkVhefr07XDLpMEonuG6yJWgYd.jpg', 'Km9BZsaSUoxQ4VArcXYyHThIDRbq2t7l.jpg', 'fBp0Yor4EQtWkM7I3TsnNHLXuvCFacjS.jpg']

超参数设置

batch_size = 18                             # 批量大小
image_size = 224                            # 训练图像空间大小
num_epochs = 10                             # 训练周期数
lr = 0.001                                  # 学习率
momentum = 0.9                              # 动量
workers = 4                                 # 并行线程个数

数据预处理

import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision

# 数据集目录路径
data_path_train = "data/train/"
data_path_val = "data/val/"

# 创建训练数据集

def create_dataset_canidae(dataset_path, usage):
    """数据加载"""
    data_set = ds.ImageFolderDataset(dataset_path,
                                     num_parallel_workers=workers,
                                     shuffle=True,)
    # 数据增强操作
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    scale = 32
    if usage == "train":
        # Define map operations for training dataset
        trans = [
            vision.RandomCropDecodeResize(size=image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            vision.RandomHorizontalFlip(prob=0.5),
            vision.Normalize(mean=mean, std=std),
            vision.HWC2CHW()
        ]
    else:
        # Define map operations for inference dataset
        trans = [
            vision.Decode(),
            vision.Resize(image_size + scale),
            vision.CenterCrop(image_size),
            vision.Normalize(mean=mean, std=std),
            vision.HWC2CHW()
        ]

    # 数据映射操作
    data_set = data_set.map(
        operations=trans,
        input_columns='image',
        num_parallel_workers=workers)
    # 批量操作
    data_set = data_set.batch(batch_size)
    return data_set


dataset_train = create_dataset_canidae(data_path_train, "train")
step_size_train = dataset_train.get_dataset_size()
dataset_val = create_dataset_canidae(data_path_val, "val")
step_size_val = dataset_val.get_dataset_size()
print(step_size_train)
print(step_size_val)
data = next(dataset_val.create_dict_iterator())
images = data["image"]
labels = data["label"]
print("Tensor of image", images.shape)
print("Labels:", labels)

输出

96
24
Tensor of image (18, 3, 224, 224)
Labels: [ 1  2  4  3  0 10  5  4 11  9 11  6  7  1 11  5  1  3]

数据集可视化查看

import matplotlib.pyplot as plt
import numpy as np

# class_name对应label,按文件夹字符串从小到大的顺序标记label
class_name = {0: "0", 1: "1",2: "2", 3: "3",4: "4", 5: "5",6: "6", 7: "7",8: "8", 9: "9",10: "10", 11: "11",12: "12"}

plt.figure(figsize=(5, 5))
for i in range(4):
    # 获取图像及其对应的label
    data_image = images[i].asnumpy()
    data_label = labels[i]
    # 处理图像供展示使用
    data_image = np.transpose(data_image, (1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    data_image = std * data_image + mean
    data_image = np.clip(data_image, 0, 1)
    # 显示图像
    plt.subplot(2, 2, i+1)
    plt.imshow(data_image)
    plt.title(class_name[int(labels[i].asnumpy())])
    plt.axis("off")

plt.show()

在这里插入图片描述

网络结构搭建

from typing import Type, Union, List, Optional
from mindspore import nn, train
from mindspore.common.initializer import Normal


weight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)
class ResidualBlockBase(nn.Cell):
    expansion: int = 1  # 最后一个卷积核数量与第一个卷积核数量相等

    def __init__(self, in_channel: int, out_channel: int,
                 stride: int = 1, norm: Optional[nn.Cell] = None,
                 down_sample: Optional[nn.Cell] = None) -> None:
        super(ResidualBlockBase, self).__init__()
        if not norm:
            self.norm = nn.BatchNorm2d(out_channel)
        else:
            self.norm = norm

        self.conv1 = nn.Conv2d(in_channel, out_channel,
                               kernel_size=3, stride=stride,
                               weight_init=weight_init)
        self.conv2 = nn.Conv2d(in_channel, out_channel,
                               kernel_size=3, weight_init=weight_init)
        self.relu = nn.ReLU()
        self.down_sample = down_sample

    def construct(self, x):
        """ResidualBlockBase construct."""
        identity = x  # shortcuts分支

        out = self.conv1(x)  # 主分支第一层:3*3卷积层
        out = self.norm(out)
        out = self.relu(out)
        out = self.conv2(out)  # 主分支第二层:3*3卷积层
        out = self.norm(out)

        if self.down_sample is not None:
            identity = self.down_sample(x)
        out += identity  # 输出为主分支与shortcuts之和
        out = self.relu(out)

        return out
class ResidualBlock(nn.Cell):
    expansion = 4  # 最后一个卷积核的数量是第一个卷积核数量的4倍

    def __init__(self, in_channel: int, out_channel: int,
                 stride: int = 1, down_sample: Optional[nn.Cell] = None) -> None:
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channel, out_channel,
                               kernel_size=1, weight_init=weight_init)
        self.norm1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(out_channel, out_channel,
                               kernel_size=3, stride=stride,
                               weight_init=weight_init)
        self.norm2 = nn.BatchNorm2d(out_channel)
        self.conv3 = nn.Conv2d(out_channel, out_channel * self.expansion,
                               kernel_size=1, weight_init=weight_init)
        self.norm3 = nn.BatchNorm2d(out_channel * self.expansion)

        self.relu = nn.ReLU()
        self.down_sample = down_sample

    def construct(self, x):

        identity = x  # shortscuts分支

        out = self.conv1(x)  # 主分支第一层:1*1卷积层
        out = self.norm1(out)
        out = self.relu(out)
        out = self.conv2(out)  # 主分支第二层:3*3卷积层
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv3(out)  # 主分支第三层:1*1卷积层
        out = self.norm3(out)

        if self.down_sample is not None:
            identity = self.down_sample(x)

        out += identity  # 输出为主分支与shortcuts之和
        out = self.relu(out)

        return out
def make_layer(last_out_channel, block: Type[Union[ResidualBlockBase, ResidualBlock]],
               channel: int, block_nums: int, stride: int = 1):
    down_sample = None  # shortcuts分支


    if stride != 1 or last_out_channel != channel * block.expansion:

        down_sample = nn.SequentialCell([
            nn.Conv2d(last_out_channel, channel * block.expansion,
                      kernel_size=1, stride=stride, weight_init=weight_init),
            nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)
        ])

    layers = []
    layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample))

    in_channel = channel * block.expansion
    # 堆叠残差网络
    for _ in range(1, block_nums):

        layers.append(block(in_channel, channel))

    return nn.SequentialCell(layers)
from mindspore import load_checkpoint, load_param_into_net


class ResNet(nn.Cell):
    def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],
                 layer_nums: List[int], num_classes: int, input_channel: int) -> None:
        super(ResNet, self).__init__()

        self.relu = nn.ReLU()
        # 第一个卷积层,输入channel为3(彩色图像),输出channel为64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)
        self.norm = nn.BatchNorm2d(64)
        # 最大池化层,缩小图片的尺寸
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        # 各个残差网络结构块定义,
        self.layer1 = make_layer(64, block, 64, layer_nums[0])
        self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)
        self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)
        self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2)
        # 平均池化层
        self.avg_pool = nn.AvgPool2d()
        # flattern层
        self.flatten = nn.Flatten()
        # 全连接层
        self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)

    def construct(self, x):

        x = self.conv1(x)
        x = self.norm(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avg_pool(x)
        x = self.flatten(x)
        x = self.fc(x)

        return x


def _resnet(model_url: str, block: Type[Union[ResidualBlockBase, ResidualBlock]],
            layers: List[int], num_classes: int, pretrained: bool, pretrianed_ckpt: str,
            input_channel: int):
    model = ResNet(block, layers, num_classes, input_channel)

    if pretrained:
        # 加载预训练模型
        # download(url=model_url, path=pretrianed_ckpt, replace=True)
        param_dict = load_checkpoint(pretrianed_ckpt)
        load_param_into_net(model, param_dict)

    return model


def resnet50(num_classes: int = 1000, pretrained: bool = False):
    "ResNet50模型"
    resnet50_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt"
    resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"
    return _resnet(resnet50_url, ResidualBlock, [3, 4, 6, 3], num_classes,
                   pretrained, resnet50_ckpt, 2048)

形式一:模型微调

模型训练
from mindspore import nn, train
from mindspore.nn import Loss, Accuracy
!pip install download
import mindspore as ms
from download import download
network = resnet50(pretrained=True)

# 全连接层输入层的大小
in_channels = network.fc.in_channels
# 输出通道数大小为狼狗分类数2
head = nn.Dense(in_channels, 12)
# 重置全连接层
network.fc = head

# 平均池化层kernel size为7
avg_pool = nn.AvgPool2d(kernel_size=7)
# 重置平均池化层
network.avg_pool = avg_pool

import mindspore as ms
import mindspore

# 定义优化器和损失函数
opt = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=momentum)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

# 实例化模型
model = train.Model(network, loss_fn, opt, metrics={"Accuracy": Accuracy()})

def forward_fn(inputs, targets):

    logits = network(inputs)
    loss = loss_fn(logits, targets)

    return loss

grad_fn = mindspore.ops.value_and_grad(forward_fn, None, opt.parameters)

def train_step(inputs, targets):

    loss, grads = grad_fn(inputs, targets)
    opt(grads)

    return loss
    
# 创建迭代器
data_loader_train = dataset_train.create_tuple_iterator(num_epochs=num_epochs)
# 最佳模型保存路径
best_ckpt_dir = "./BestCheckpoint"
best_ckpt_path = "./BestCheckpoint/resnet50-best.ckpt"
import os
import time

# 开始循环训练
print("Start Training Loop ...")

best_acc = 0

for epoch in range(num_epochs):
    losses = []
    network.set_train()

    epoch_start = time.time()

    # 为每轮训练读入数据
    for i, (images, labels) in enumerate(data_loader_train):
        labels = labels.astype(ms.int32)
        loss = train_step(images, labels)
        losses.append(loss)

    # 每个epoch结束后,验证准确率

    acc = model.eval(dataset_val)['Accuracy']

    epoch_end = time.time()
    epoch_seconds = (epoch_end - epoch_start) * 1000
    step_seconds = epoch_seconds/step_size_train

    print("-" * 20)
    print("Epoch: [%3d/%3d], Average Train Loss: [%5.3f], Accuracy: [%5.3f]" % (
        epoch+1, num_epochs, sum(losses)/len(losses), acc
    ))
    print("epoch time: %5.3f ms, per step time: %5.3f ms" % (
        epoch_seconds, step_seconds
    ))

    if acc > best_acc:
        best_acc = acc
        if not os.path.exists(best_ckpt_dir):
            os.mkdir(best_ckpt_dir)
        ms.save_checkpoint(network, best_ckpt_path)

print("=" * 80)
print(f"End of validation the best Accuracy is: {best_acc: 5.3f}, "
      f"save the best ckpt file in {best_ckpt_path}", flush=True)

输出

Start Training Loop ...
--------------------
Epoch: [  1/ 10], Average Train Loss: [1.774], Accuracy: [0.838]
epoch time: 60892.337 ms, per step time: 634.295 ms
--------------------
Epoch: [  2/ 10], Average Train Loss: [0.762], Accuracy: [0.905]
epoch time: 8745.406 ms, per step time: 91.098 ms
--------------------
Epoch: [  3/ 10], Average Train Loss: [0.568], Accuracy: [0.921]
epoch time: 8449.129 ms, per step time: 88.012 ms
--------------------
Epoch: [  4/ 10], Average Train Loss: [0.508], Accuracy: [0.910]
epoch time: 8199.763 ms, per step time: 85.414 ms
--------------------
Epoch: [  5/ 10], Average Train Loss: [0.459], Accuracy: [0.900]
epoch time: 7856.060 ms, per step time: 81.834 ms
--------------------
Epoch: [  6/ 10], Average Train Loss: [0.405], Accuracy: [0.931]
epoch time: 8138.927 ms, per step time: 84.780 ms
--------------------
Epoch: [  7/ 10], Average Train Loss: [0.368], Accuracy: [0.919]
epoch time: 8333.523 ms, per step time: 86.808 ms
--------------------
Epoch: [  8/ 10], Average Train Loss: [0.354], Accuracy: [0.912]
epoch time: 8271.008 ms, per step time: 86.156 ms
--------------------
Epoch: [  9/ 10], Average Train Loss: [0.338], Accuracy: [0.928]
epoch time: 8457.969 ms, per step time: 88.104 ms
--------------------
Epoch: [ 10/ 10], Average Train Loss: [0.338], Accuracy: [0.907]
epoch time: 8183.743 ms, per step time: 85.247 ms
================================================================================
End of validation the best Accuracy is:  0.931, save the best ckpt file in ./BestCheckpoint/resnet50-best.ckpt
模型评估
import matplotlib.pyplot as plt
import mindspore as ms

def visualize_model(best_ckpt_path, val_ds):
    net = resnet50()
    # 全连接层输入层的大小
    in_channels = net.fc.in_channels
    # 输出通道数大小为分类数12
    head = nn.Dense(in_channels, 12)
    # 重置全连接层
    net.fc = head
    # 平均池化层kernel size为7
    avg_pool = nn.AvgPool2d(kernel_size=7)
    # 重置平均池化层
    net.avg_pool = avg_pool
    # 加载模型参数
    param_dict = ms.load_checkpoint(best_ckpt_path)
    ms.load_param_into_net(net, param_dict)
    model = train.Model(net)
    #print(net)
    # 加载验证集的数据进行验证
    data = next(val_ds.create_dict_iterator())
    images = data["image"].asnumpy()
    print(type(images))
    print(images.shape)
    #print(images)
    labels = data["label"].asnumpy()
    #print(labels)
    class_name = {0: "0", 1: "1",2: "2", 3: "3",4: "4", 5: "5",6: "6", 7: "7",8: "8", 9: "9",10: "10", 11: "11",12: "12"}
    # 预测图像类别
    data_pre=ms.Tensor(data["image"])
    print(data_pre.shape)
    print(type(data_pre))
    output = model.predict(data_pre)
    
    #print(output)
    pred = np.argmax(output.asnumpy(), axis=1)

    # 显示图像及图像的预测值
    plt.figure(figsize=(5, 5))
    for i in range(4):
        plt.subplot(2, 2, i + 1)
        # 若预测正确,显示为蓝色;若预测错误,显示为红色
        color = 'blue' if pred[i] == labels[i] else 'red'
        plt.title('predict:{}'.format(class_name[pred[i]]), color=color)
        picture_show = np.transpose(images[i], (1, 2, 0))
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        picture_show = std * picture_show + mean
        picture_show = np.clip(picture_show, 0, 1)
        plt.imshow(picture_show)
        plt.axis('off')

    plt.show()
visualize_model('BestCheckpoint/resnet50-best.ckpt', dataset_val)

输出
在这里插入图片描述

形式二:固定特征训练

模型训练
net_work = resnet50(pretrained=True)
# 全连接层输入层的大小
in_channels = net_work.fc.in_channels
# 输出通道数大小为分类数12
head = nn.Dense(in_channels, 12)
# 重置全连接层
net_work.fc = head
# 平均池化层kernel size为7
avg_pool = nn.AvgPool2d(kernel_size=7)
# 重置平均池化层
net_work.avg_pool = avg_pool
# 冻结除最后一层外的所有参数
for param in net_work.get_parameters():
    if param.name not in ["fc.weight", "fc.bias"]:
        param.requires_grad = False
# 定义优化器和损失函数
opt = nn.Momentum(params=net_work.trainable_params(), learning_rate=lr, momentum=0.5)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
def forward_fn(inputs, targets):
    logits = net_work(inputs)
    loss = loss_fn(logits, targets)
    return loss
grad_fn = ms.ops.value_and_grad(forward_fn, None, opt.parameters)
def train_step(inputs, targets):
    loss, grads = grad_fn(inputs, targets)
    opt(grads)
    return loss
# 实例化模型
model1 = train.Model(net_work, loss_fn, opt, metrics={"Accuracy": Accuracy()})
dataset_train = create_dataset_canidae(data_path_train, "train")
step_size_train = dataset_train.get_dataset_size()
dataset_val = create_dataset_canidae(data_path_val, "val")
step_size_val = dataset_val.get_dataset_size()
num_epochs = 10
# 创建迭代器
data_loader_train = dataset_train.create_tuple_iterator(num_epochs=num_epochs)
data_loader_val = dataset_val.create_tuple_iterator(num_epochs=num_epochs)
best_ckpt_dir = "./BestCheckpoint"
best_ckpt_path = "./BestCheckpoint/resnet50-best-freezing-param.ckpt"
# 开始循环训练
print("Start Training Loop ...")
best_acc = 0
for epoch in range(num_epochs):
    losses = []
    net_work.set_train()
    epoch_start = time.time()
    # 为每轮训练读入数据
    for i, (images, labels) in enumerate(data_loader_train):
        labels = labels.astype(ms.int32)
        loss = train_step(images, labels)
        losses.append(loss)
    # 每个epoch结束后,验证准确率
    acc = model1.eval(dataset_val)['Accuracy']
    epoch_end = time.time()
    epoch_seconds = (epoch_end - epoch_start) * 1000
    step_seconds = epoch_seconds/step_size_train
    print("-" * 20)
    print("Epoch: [%3d/%3d], Average Train Loss: [%5.3f], Accuracy: [%5.3f]" % (
        epoch+1, num_epochs, sum(losses)/len(losses), acc
    ))
    print("epoch time: %5.3f ms, per step time: %5.3f ms" % (
        epoch_seconds, step_seconds
    ))
    if acc > best_acc:
        best_acc = acc
        if not os.path.exists(best_ckpt_dir):
            os.mkdir(best_ckpt_dir)
        ms.save_checkpoint(net_work, best_ckpt_path)
print("=" * 80)
print(f"End of validation the best Accuracy is: {best_acc: 5.3f}, "
      f"save the best ckpt file in {best_ckpt_path}", flush=True)
模型评估
visualize_model(best_ckpt_path, dataset_val)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/434745.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

uniapp H5 $el.querySelectorAll is not a function

在监听是否在可视区域遇到问题&#xff08;网页端&#xff09; 解决方案 <view class"container"> ...省略 业务代码... </view>参考 &#xff1a; https://blog.csdn.net/qq_18841969/article/details/134620559

狼人杀 (狼人) 个人理解玩法

今天 我们来说说 狼人杀游戏 每个板子都有的一个角色 狼人 因为 动物园板子 平民被换成了 羊驼 所以 狼人也是唯一一个所有板子都有的角色 狼人的技能非常简单 每天晚上 可以袭击一名玩家 如果没有特殊情况 被袭击的玩家天亮时会直接出局 特殊情况包括 比较典型的有 守卫的盾…

Matter 笔记1-环境准备,编译

不要远程登录Ubuntu输入以下命令&#xff0c;原因&#xff1a;ubuntu/linux上的http代理设置 1. 准备 1.1 工具 Ubuntu 22.04 LTSClash 里General的端口设置到ubuntu 的网络设置里 1.2 代码 这里使用芯科整理过的代码 git clone https://github.com/SiliconLabs/matter.…

Redis线程模型解析

引言 Redis是一个高性能的键值对&#xff08;key-value&#xff09;内存数据库&#xff0c;以其卓越的读写速度和灵活的数据类型而广受欢迎。在Redis 6.0之前的版本中&#xff0c;它采用的是一种独特的单线程模型来处理客户端的请求。尽管单线程在概念上似乎限制了其扩展性和并…

实验笔记之——Gaussian Splatting SLAM配置与测试

之前博客对基于3DGS的SLAM进行了调研 学习笔记之——3D Gaussian Splatting及其在SLAM与自动驾驶上的应用调研_3d gaussian splatting slam-CSDN博客文章浏览阅读3.2k次&#xff0c;点赞40次&#xff0c;收藏58次。论文主页3D Gaussian Splatting是最近NeRF方面的突破性工作&a…

腾讯云服务器运行yum检测超级慢问题

公司使用腾讯云服务器。最近买的几台服务器使用yum命令安装或 更新软件特别慢。如下图&#xff1a; 从图中看出网络速度极慢。 大约要等5-10分钟检测和更新配置完毕&#xff0c;进入到软件下载界面下载软件速度就快了。 琢磨了一下&#xff0c;连接慢并不是连接不上。查看yum…

第 125 场 LeetCode 双周赛题解

A 超过阈值的最少操作数 I 排序然后查找第一个大于等于 k 的元素所在的位置 class Solution { public:int minOperations(vector<int> &nums, int k) {sort(nums.begin(), nums.end());return lower_bound(nums.begin(), nums.end(), k) - nums.begin();} };B 超过阈…

后台组件-语言包

<groupId>org.qlm</groupId><artifactId>qlm-language</artifactId><version>1.0-SNAPSHOT</version> 平台提供多语言支持&#xff0c;以上为语言包&#xff0c;提供后台多语言支持。首批实现&#xff1a; public class LanguageConstan…

《操作系统真相还原》读书笔记二:环境搭建 xshell连接virtualbox

修改 sshd_config 使用 vi /etc/ssh/sshd_config命令进入sshd服务配置&#xff0c;键盘输入i进行编辑&#xff0c;将监听端口、监听地址前的 # 号去除&#xff0c;开启允许远程登录&#xff0c;开启使用用户名密码来作为连接验证。修改完成&#xff0c;按一下Esc&#xff0c;输…

【Godot4自学手册】第二十节增加游戏的打击感,镜头震颤、冻结帧和死亡特效

这节我主要学习增加游戏的打击感。我们通过镜头震颤、冻结帧、增加攻击点特效&#xff0c;增加死亡。开始了。 一、添加攻击点特效 增加攻击点特效就是&#xff0c;在攻击敌人时&#xff0c;会在敌人受击点显示一个受击动画。 1.添加动画。 第一步先做个受击点动画。切换到…

【数据结构】堆的TopK问题

大家好&#xff0c;我是苏貝&#xff0c;本篇博客带大家了解堆的TopK问题&#xff0c;如果你觉得我写的还不错的话&#xff0c;可以给我一个赞&#x1f44d;吗&#xff0c;感谢❤️ 目录 一. 前言二. TopK三. 代码 一. 前言 TOP-K问题&#xff1a;即求数据结合中前K个最大的元…

鸿蒙Harmony应用开发—ArkTS声明式开发(通用属性:背景设置)

设置组件的背景样式。 说明&#xff1a; 从API Version 7开始支持。后续版本如有新增内容&#xff0c;则采用上角标单独标记该内容的起始版本。 background10 background(builder: CustomBuilder, options?: { align?: Alignment }) 设置组件背景。 系统能力&#xff1a; …

外包干了2年,技术退步明显

先说一下自己的情况&#xff0c;研究生&#xff0c;19年进入广州某软件公司&#xff0c;干了接近4年的功能测试&#xff0c;今年年初&#xff0c;感觉自己不能够在这样下去了&#xff0c;长时间呆在一个舒适的环境会让一个人堕落!而我已经在一个企业干了四年的功能测试&#xf…

CCF CSP 202006-4 1246 (100分详细题解,矩阵乘法+快速幂)

202006-4 1246 &#xff08;100分详细题解&#xff0c;矩阵乘法快速幂&#xff09; 可以先看下csp官方的思路讲解&#xff0c;大思路是状态转移&#xff0c;先看下面s<2的情况 1 -> 2 2 -> 4 4 -> 16 6 -> 6 64 4 16 -> 26&#xff08;不考虑2&#xff0c;6…

“计算机界三大神书”之一 ——SICP【赠书活动】

目录 前言一、SICP简介二、改编本SCIP JS福利总结 前言 《计算机程序的构造和解释》&#xff08;Structure and Interpretation of Computer Programs&#xff0c;简记为SICP&#xff09;是MIT的基础课教材&#xff0c;出版后引起计算机教育界的广泛关注&#xff0c;对推动全世…

Java开发必须掌握,java高级工程师面试类的加载

前言 这些算法&#xff0c;都是小编一点一点看的大佬们的方法&#xff0c;自己积累的. 如果有什么描述的不对 点击领取2024完整开源项目《一线大厂Java面试题解析后端开发学习笔记最新架构讲解视频实战项目源码讲义》 的地方还望大佬赐教 多交流才能进步&#xff0c;加油&#…

Apache POI 解析和处理Excel

摘要&#xff1a;由于开发需要批量导入Excel中的数据&#xff0c;使用了Apache POI库&#xff0c;记录下使用过程 1. 背景 Java 中操作 Excel 文件的库常用的有Apache POI 和阿里巴巴的 EasyExcel 。Apache POI 是一个功能比较全面的 Java 库&#xff0c;适合处理复杂的 Offi…

errno 和 strerror函数

今天写了一个很简单的代码&#xff0c;编译时没啥错误和警告&#xff08;主要编译选项没开启警告&#xff09;&#xff0c;然后运行时居然 segmentation fault&#xff0c;把我给看傻了&#xff0c;代码如下&#xff1a; #include <stdio.h> #include <stdlib.h> …

2024护网面试题精选(一)

0x00.基础漏洞篇 00-TOP10漏洞 1.SQL注入 2.失效的身份认证和会话管理 3.跨站脚本攻击XSS 4.直接引用不安全的对象 5.安全配置错误 6.敏感信息泄露 7.缺少功能级的访问控制 8.跨站请求伪造CSRF 9.实验含有已知漏洞的组件 10.未验证的重定向和转发 01-SQL注入漏洞 …

PackagesNotFoundError:学习利用报错信息找到解决方法

反思&#xff1a;之前看到报错经常是直接复制报错信息去网上搜&#xff0c;但很多情况下报错信息里其实就给出了解决方案 报错信息&#xff1a; Collecting package metadata (current_repodata.json): done Solving environment: unsuccessful initial attempt using frozen …