昇思25天学习打卡营第13天|应用实践之ResNet50迁移学习

基本介绍

        今日的应用实践的模型是计算机实践领域中十分出名的模型----ResNet模型。ResNet是一种残差网络结构,它通过引入“残差学习”的概念来解决随着网络深度增加时训练困难的问题,从而能够训练更深的网络结构。现很多网络极深的模型或多或少都受此影响。今日的主要任务是使用ResNet50进行迁移学习,所谓的迁移学习是不从头到尾训练一个模型,而是在一个特别大的数据集上面训练得到一个预训练模型,然后使用该模型的权重作为初始化参数,最后应用于特定的任务。对特定的任务来说也可以对模型进行训练,但是需要冻结模型的某些参数,一般只训练模型的分类器部分,不会训练特征提取部分。下面我们详细讲讲今日的应用实践。

数据集准备

        本次使用的数据集是来自ImageNet数据集中抽取出来的狼狗数据集,每个类别大概有120张训练图像与30张验证图像,数据集可通过华为云OBS和相关API直接下载,然后进行加载。可借助MindSpore提供的API对数据集进行加载,同时,因为数据集在送入模型之前需做些处理,这些可封装到一个函数内,具体代码如下:

def create_dataset_canidae(dataset_path, usage):
    """数据加载"""
    data_set = ds.ImageFolderDataset(dataset_path,
                                     num_parallel_workers=workers,
                                     shuffle=True,)

    # 数据增强操作
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    scale = 32

    if usage == "train":
        # Define map operations for training dataset
        trans = [
            vision.RandomCropDecodeResize(size=image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            vision.RandomHorizontalFlip(prob=0.5),
            vision.Normalize(mean=mean, std=std),
            vision.HWC2CHW()
        ]
    else:
        # Define map operations for inference dataset
        trans = [
            vision.Decode(),
            vision.Resize(image_size + scale),
            vision.CenterCrop(image_size),
            vision.Normalize(mean=mean, std=std),
            vision.HWC2CHW()
        ]


    # 数据映射操作
    data_set = data_set.map(
        operations=trans,
        input_columns='image',
        num_parallel_workers=workers)


    # 批量操作
    data_set = data_set.batch(batch_size)

    return data_set

通过上述操作后,我们可以很方便的调用数据集,无论是进行训练还是可视化,都可以。

模型搭建

        准备好数据集后,自然就是进行模型搭建,ResNet50是一个非常常见的模型,具体模型结构和不同AI框架下的代码都很容易获取,MindSpore官方也有相关的实现代码,我们直接使用,模型的代码如下:

class ResNet(nn.Cell):
    def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],
                 layer_nums: List[int], num_classes: int, input_channel: int) -> None:
        super(ResNet, self).__init__()

        self.relu = nn.ReLU()
        # 第一个卷积层,输入channel为3(彩色图像),输出channel为64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)
        self.norm = nn.BatchNorm2d(64)
        # 最大池化层,缩小图片的尺寸
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        # 各个残差网络结构块定义,
        self.layer1 = make_layer(64, block, 64, layer_nums[0])
        self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)
        self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)
        self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2)
        # 平均池化层
        self.avg_pool = nn.AvgPool2d()
        # flattern层
        self.flatten = nn.Flatten()
        # 全连接层
        self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)

    def construct(self, x):

        x = self.conv1(x)
        x = self.norm(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avg_pool(x)
        x = self.flatten(x)
        x = self.fc(x)

        return x


def _resnet(model_url: str, block: Type[Union[ResidualBlockBase, ResidualBlock]],
            layers: List[int], num_classes: int, pretrained: bool, pretrianed_ckpt: str,
            input_channel: int):
    model = ResNet(block, layers, num_classes, input_channel)

    if pretrained:
        # 加载预训练模型
        download(url=model_url, path=pretrianed_ckpt, replace=True)
        param_dict = load_checkpoint(pretrianed_ckpt)
        load_param_into_net(model, param_dict)

    return model


def resnet50(num_classes: int = 1000, pretrained: bool = False):
    "ResNet50模型"
    resnet50_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt"
    resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"
    return _resnet(resnet50_url, ResidualBlock, [3, 4, 6, 3], num_classes,
                   pretrained, resnet50_ckpt, 2048)

训练模型

        我们采用定特征进行训练,需要冻结除最后一层之外的所有网络层,最后一层其实就是一个分类层,前面是特征提取层,MindSpore可以通过设置 requires_grad == False 冻结参数,以便不在反向传播中计算梯度,具体代码如下:

import mindspore as ms
import matplotlib.pyplot as plt
import os
import time

net_work = resnet50(pretrained=True)

# 全连接层输入层的大小
in_channels = net_work.fc.in_channels
# 输出通道数大小为狼狗分类数2
head = nn.Dense(in_channels, 2)
# 重置全连接层
net_work.fc = head

# 平均池化层kernel size为7
avg_pool = nn.AvgPool2d(kernel_size=7)
# 重置平均池化层
net_work.avg_pool = avg_pool

# 冻结除最后一层外的所有参数
for param in net_work.get_parameters():
    if param.name not in ["fc.weight", "fc.bias"]:
        param.requires_grad = False

# 定义优化器和损失函数
opt = nn.Momentum(params=net_work.trainable_params(), learning_rate=lr, momentum=0.5)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')


def forward_fn(inputs, targets):
    logits = net_work(inputs)
    loss = loss_fn(logits, targets)

    return loss

grad_fn = ms.value_and_grad(forward_fn, None, opt.parameters)

def train_step(inputs, targets):
    loss, grads = grad_fn(inputs, targets)
    opt(grads)
    return loss

# 实例化模型
model1 = train.Model(net_work, loss_fn, opt, metrics={"Accuracy": train.Accuracy()})

一切准备妥当后,便可以进行训练,由于有预训练模型,模型的训练速度非常快,比从头到尾训练快了好几倍。训练了5轮,效果就非常好了:

模型可视化预测

        有了训练好的模型,自然要看看训练得好不好,除了看评价指标,最直观的就是实际使用一下模型,由于这是图像分类任务,所以将其可视化。这一整个流程的代码如下:

def visualize_model(best_ckpt_path, val_ds):
    net = resnet50()
    # 全连接层输入层的大小
    in_channels = net.fc.in_channels
    # 输出通道数大小为狼狗分类数2
    head = nn.Dense(in_channels, 2)
    # 重置全连接层
    net.fc = head
    # 平均池化层kernel size为7
    avg_pool = nn.AvgPool2d(kernel_size=7)
    # 重置平均池化层
    net.avg_pool = avg_pool
    # 加载模型参数
    param_dict = ms.load_checkpoint(best_ckpt_path)
    ms.load_param_into_net(net, param_dict)
    model = train.Model(net)
    # 加载验证集的数据进行验证
    data = next(val_ds.create_dict_iterator())
    images = data["image"].asnumpy()
    labels = data["label"].asnumpy()
    class_name = {0: "dogs", 1: "wolves"}
    # 预测图像类别
    output = model.predict(ms.Tensor(data['image']))
    pred = np.argmax(output.asnumpy(), axis=1)

    # 显示图像及图像的预测值
    plt.figure(figsize=(5, 5))
    for i in range(4):
        plt.subplot(2, 2, i + 1)
        # 若预测正确,显示为蓝色;若预测错误,显示为红色
        color = 'blue' if pred[i] == labels[i] else 'red'
        plt.title('predict:{}'.format(class_name[pred[i]]), color=color)
        picture_show = np.transpose(images[i], (1, 2, 0))
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        picture_show = std * picture_show + mean
        picture_show = np.clip(picture_show, 0, 1)
        plt.imshow(picture_show)
        plt.axis('off')

    plt.show()

调用该函数后,可视化结果如下:可以看出还是很准的

Jupyter运行情况

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/786466.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Gartner发布采用美国防部模型实施零信任的方法指南:七大支柱落地方法

零信任是网络安全计划的关键要素,但制定策略可能会很困难。安全和风险管理领导者应使用美国国防部模型的七大支柱以及 Gartner 研究来设计零信任策略。 战略规划假设 到 2026 年,10% 的大型企业将拥有全面、成熟且可衡量的零信任计划,而 202…

分享五款软件,成为高效生活的好助手

​ 给大家分享一些优秀的软件工具,是一件让人很愉悦的事情,今天继续带来5款优质软件。 1.图片放大——Bigjpg ​ Bigjpg是一款图片放大软件,采用先进的AI算法,能够在不损失图片质量的前提下,将低分辨率图片放大至所需尺寸。无论…

STM32CubeIDE离线汉化教程

按照网上的方法下载好ZIP文件后 帮助->安装新软件-> 按顺序选择文件 点击完成,后等待右下脚的精度条到位即可

浮动刹车盘和固定刹车盘有什么区别

在讨论刹车系统的设计理念时,浮动和固定刹车盘无疑是两个重要的分支。 它们各自拥有独特的设计哲学、工作原理以及适用场景,这些差异直接影响到了制动系统的性能和耐久性。 浮动刹车盘和固定刹车盘在设计和工作原理上有显著的区别,主要体现在…

Linux: 命令行参数和环境变量究竟是什么?

Linux: 命令行参数和环境变量究竟是什么? 一、命令行参数1.1 main函数参数意义1.2 命令行参数概念1.3 命令行参数实例 二、环境变量2.1 环境变量概念2.2 环境变量:PATH2.2.1 如何查看PATH中的内容2.2.2 如何让自己的可执行文件不带路径运行 2.3 环境变量…

自动化立体仓库设计步骤:7步

导语 大家好,我是社长,老K。专注分享智能制造和智能仓储物流等内容。 完整版文件和更多学习资料,请球友到知识星球【智能仓储物流技术研习社】自行下载 这份文件是关于自动化立体仓库设计步骤的详细指南,其核心内容包括以下几个阶…

如何禁用键盘上的特定键或快捷方式?这里有详细步骤

要禁用特定的键盘键或快捷键吗?微软官方应用程序Microsoft PowerToys使这项任务变得非常简单。以下是使用Microsoft PowerToys中的键盘管理器禁用特定键或快捷方式的快速指南。 如果你还没有安装Microsoft PowerToys 如果你的设备上没有安装Microsoft PowerToys&a…

新能源汽车推广 - 世媒讯软文发稿需要注意什么

在环保意识日益增强和政策支持的背景下,新能源汽车市场迎来了前所未有的发展机遇。对于新能源汽车企业而言,如何有效地推广产品成为了关键。而软文发稿作为一种重要的营销手段,能够通过内容的形式潜移默化地影响消费者的认知和决策。那么&…

由于找不到emp.dll无法运行游戏的多个有效解决方法分享

在玩游戏时候是否遇到过找不到emp.dll,无法继续执行代码问题无法打开游戏?那么这个emp.dll是什么呢?为什么会丢失,emp.dll丢失要怎么办?今天就给大家详细介绍一下emp.dll文件与emp.dll丢失的多个解决方法! 一、emp.dll…

【AI技术的未来之路】从模型到应用,跨越超级应用陷阱,迈向个性化智能体

💓 博客主页:倔强的石头的CSDN主页 📝Gitee主页:倔强的石头的gitee主页 ⏩ 文章专栏:《热点时事》 期待您的关注 ​ 目录 引言 一、AI技术应用场景探索: 二、避免超级应用陷阱的策略: 三、个…

PaddleVideo:Squeeze Time算法移植

参考PaddleVideo/docs/zh-CN/contribute/add_new_algorithm.md at develop PaddlePaddle/PaddleVideo GitHubAwesome video understanding toolkits based on PaddlePaddle. It supports video data annotation tools, lightweight RGB and skeleton based action recognitio…

[数仓]七、离线数仓(PrestoKylin即席查询)

第1章 Presto 1.1 Presto简介 1.1.1 Presto概念 1.1.2 Presto架构 1.1.4 Presto、Impala性能比较 Presto、Impala性能比较_presto和impala对比-CSDN博客 测试结论:Impala性能稍领先于Presto,但是Presto在数据源支持上非常丰富,包括Hive、图数据库、传统关系型数据库、Re…

Codeforces Round 956 F. array-value 【01Trie查询异或最小值】

题意 给定一个非负整数数组 a a a 对每个长度至少为 2 2 2 的子数组&#xff0c;定义其权值为&#xff1a;子数组内两两异或值最小值 即 b ⊂ a [ l , r ] , w ( b ) min ⁡ l ≤ i < j ≤ r { a i ⨁ a j } b \subset a[l, r], \quad w(b) \min_{l \leq i < j \le…

谷歌账号被停用怎么办?立刻申诉!申诉流程和经验、中英文申诉信模板

很多鞥有这两年新注册的Google账号或者购买的谷歌账号&#xff0c;在使用时可能都遇到过被停用的情况。极端的还有刚注册号&#xff0c;反手就被谷歌 停用了&#xff0c;或者被连续停用。 今天我们就来聊一聊&#xff0c;谷歌账号为什么会被停用&#xff0c;以及谷歌账号被停用…

走拼箱货必看海运拼箱的实用技巧

在国际海运运输中&#xff0c;海运拼箱适用于货物数量较少或体积不足以填满整个集装箱的情况。 海运拼箱货物通常由物流公司或货代进行组织和管理。多个货主的货物通过合理拼装&#xff0c;使集装箱空间得到充分利用。 那么&#xff0c;在海运拼箱和整柜有哪些不同&#xff0c…

淘宝商品历史价格查询(免费)

当前资料来源于网络&#xff0c;禁止用于商用&#xff0c;仅限于学习。 淘宝联盟里面就可以看到历史价格 并且没有加密 淘宝商品历史价格查询可以通过以下步骤进行&#xff1a; 先下载后&#xff0c;登录app注册账户 打开淘宝网站或淘宝手机App。在搜索框中输入你想要查询的商…

Qt 线程 QThread类详解

Qt 线程中QThread的使用 在进行桌面应用程序开发的时候&#xff0c; 假设应用程序在某些情况下需要处理比较复杂的逻辑&#xff0c; 如果只有一个线程去处理&#xff0c;就会导致窗口卡顿&#xff0c;无法处理用户的相关操作。这种情况下就需要使用多线程&#xff0c;其中一个…

亚马逊云科技EC2简明教程

&#x1f4a1; 完全适用于新手操作的Amazon EC2引导教程 简述 在亚马逊云科技中&#xff0c;存在多种计算服务&#xff0c;在此&#xff0c;我们将会着重讨论Amazon EC2(以下简称EC2)&#xff0c;EC2作为亚马逊云科技的明星产品、核心产品&#xff0c;是大多数开发者和企业用…

基于JAVA+SpringBoot+Vue的自动阅卷分析系统

✌全网粉丝20W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取项目下载方式&#x1f345; 一、项目背景介绍&#xff1a; 在当前教育评估体系中…

网络安全高级工具软件100套

1、 Nessus&#xff1a;最好的UNIX漏洞扫描工具 Nessus 是最好的免费网络漏洞扫描器&#xff0c;它可以运行于几乎所有的UNIX平台之上。它不止永久升级&#xff0c;还免费提供多达11000种插件&#xff08;但需要注册并接受EULA-acceptance–终端用户授权协议&#xff09;。 它…