昇思学习打卡-9-ResNet50图像分类

文章目录

  • 网络介绍
  • 数据可视化
  • 部分网络实现
    • Building Block结构
    • BottleNet模块
  • 模型训练
  • 推理结果可视化
  • 学习总结
    • 优点
    • 不足

网络介绍

在ResNet网络提出之前,传统的卷积神经网络都是将一系列的卷积层和池化层堆叠得到的,但当网络堆叠到一定深度时,就会出现退化问题。

数据可视化

import matplotlib.pyplot as plt
import numpy as np

data_iter = next(dataset_train.create_dict_iterator())

images = data_iter["image"].asnumpy()
labels = data_iter["label"].asnumpy()
print(f"Image shape: {images.shape}, Label shape: {labels.shape}")

# 训练数据集中,前六张图片所对应的标签
print(f"Labels: {labels[:6]}")

classes = []

with open(data_dir + "/batches.meta.txt", "r") as f:
    for line in f:
        line = line.rstrip()
        if line:
            classes.append(line)

# 训练数据集的前六张图片
plt.figure()
for i in range(6):
    plt.subplot(2, 3, i + 1)
    image_trans = np.transpose(images[i], (1, 2, 0))
    mean = np.array([0.4914, 0.4822, 0.4465])
    std = np.array([0.2023, 0.1994, 0.2010])
    image_trans = std * image_trans + mean
    image_trans = np.clip(image_trans, 0, 1)
    plt.title(f"{classes[labels[i]]}")
    plt.imshow(image_trans)
    plt.axis("off")
plt.show()

在这里插入图片描述

部分网络实现

Building Block结构

from typing import Type, Union, List, Optional
import mindspore.nn as nn
from mindspore.common.initializer import Normal

# 初始化卷积层与BatchNorm的参数
weight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)

class ResidualBlockBase(nn.Cell):
    expansion: int = 1  # 最后一个卷积核数量与第一个卷积核数量相等

    def __init__(self, in_channel: int, out_channel: int,
                 stride: int = 1, norm: Optional[nn.Cell] = None,
                 down_sample: Optional[nn.Cell] = None) -> None:
        super(ResidualBlockBase, self).__init__()
        if not norm:
            self.norm = nn.BatchNorm2d(out_channel)
        else:
            self.norm = norm

        self.conv1 = nn.Conv2d(in_channel, out_channel,
                               kernel_size=3, stride=stride,
                               weight_init=weight_init)
        self.conv2 = nn.Conv2d(in_channel, out_channel,
                               kernel_size=3, weight_init=weight_init)
        self.relu = nn.ReLU()
        self.down_sample = down_sample

    def construct(self, x):
        """ResidualBlockBase construct."""
        identity = x  # shortcuts分支

        out = self.conv1(x)  # 主分支第一层:3*3卷积层
        out = self.norm(out)
        out = self.relu(out)
        out = self.conv2(out)  # 主分支第二层:3*3卷积层
        out = self.norm(out)

        if self.down_sample is not None:
            identity = self.down_sample(x)
        out += identity  # 输出为主分支与shortcuts之和
        out = self.relu(out)

        return out

BottleNet模块

class ResidualBlock(nn.Cell):
    expansion = 4  # 最后一个卷积核的数量是第一个卷积核数量的4倍

    def __init__(self, in_channel: int, out_channel: int,
                 stride: int = 1, down_sample: Optional[nn.Cell] = None) -> None:
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channel, out_channel,
                               kernel_size=1, weight_init=weight_init)
        self.norm1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(out_channel, out_channel,
                               kernel_size=3, stride=stride,
                               weight_init=weight_init)
        self.norm2 = nn.BatchNorm2d(out_channel)
        self.conv3 = nn.Conv2d(out_channel, out_channel * self.expansion,
                               kernel_size=1, weight_init=weight_init)
        self.norm3 = nn.BatchNorm2d(out_channel * self.expansion)

        self.relu = nn.ReLU()
        self.down_sample = down_sample

    def construct(self, x):

        identity = x  # shortscuts分支

        out = self.conv1(x)  # 主分支第一层:1*1卷积层
        out = self.norm1(out)
        out = self.relu(out)
        out = self.conv2(out)  # 主分支第二层:3*3卷积层
        out = self.norm2(out)
        out = self.relu(out)
        out = self.conv3(out)  # 主分支第三层:1*1卷积层
        out = self.norm3(out)

        if self.down_sample is not None:
            identity = self.down_sample(x)

        out += identity  # 输出为主分支与shortcuts之和
        out = self.relu(out)

        return out在这里插入代码片

模型训练

# 开始循环训练
print("Start Training Loop ...")

for epoch in range(num_epochs):
    curr_loss = train(data_loader_train, epoch)
    curr_acc = evaluate(data_loader_val)

    print("-" * 50)
    print("Epoch: [%3d/%3d], Average Train Loss: [%5.3f], Accuracy: [%5.3f]" % (
        epoch+1, num_epochs, curr_loss, curr_acc
    ))
    print("-" * 50)

    # 保存当前预测准确率最高的模型
    if curr_acc > best_acc:
        best_acc = curr_acc
        ms.save_checkpoint(network, best_ckpt_path)

print("=" * 80)
print(f"End of validation the best Accuracy is: {best_acc: 5.3f}, "
      f"save the best ckpt file in {best_ckpt_path}", flush=True)

推理结果可视化

import matplotlib.pyplot as plt


def visualize_model(best_ckpt_path, dataset_val):
    num_class = 10  # 对狼和狗图像进行二分类
    net = resnet50(num_class)
    # 加载模型参数
    param_dict = ms.load_checkpoint(best_ckpt_path)
    ms.load_param_into_net(net, param_dict)
    # 加载验证集的数据进行验证
    data = next(dataset_val.create_dict_iterator())
    images = data["image"]
    labels = data["label"]
    # 预测图像类别
    output = net(data['image'])
    pred = np.argmax(output.asnumpy(), axis=1)

    # 图像分类
    classes = []

    with open(data_dir + "/batches.meta.txt", "r") as f:
        for line in f:
            line = line.rstrip()
            if line:
                classes.append(line)

    # 显示图像及图像的预测值
    plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        # 若预测正确,显示为蓝色;若预测错误,显示为红色
        color = 'blue' if pred[i] == labels.asnumpy()[i] else 'red'
        plt.title('predict:{}'.format(classes[pred[i]]), color=color)
        picture_show = np.transpose(images.asnumpy()[i], (1, 2, 0))
        mean = np.array([0.4914, 0.4822, 0.4465])
        std = np.array([0.2023, 0.1994, 0.2010])
        picture_show = std * picture_show + mean
        picture_show = np.clip(picture_show, 0, 1)
        plt.imshow(picture_show)
        plt.axis('off')

    plt.show()


# 使用测试数据集进行验证
visualize_model(best_ckpt_path=best_ckpt_path, dataset_val=dataset_val)

在这里插入图片描述

学习总结

经学习,对resnet网络的优缺点进行总结,其中

优点

  • 残差学习:ResNet引入了残差学习框架,允许网络学习输入和输出之间的残差函数,而不是直接学习映射关系,这使得网络能够训练更深的架构。
  • 解决梯度消失问题:通过使用残差连接,ResNet缓解了深层网络训练中的梯度消失问题,使得网络可以有效地进行反向传播。
  • 简化训练过程:ResNet的架构简化了网络的训练过程,因为每一层只需要学习输入和输出的残差,而不是整个映射。
  • 提高性能:在多个视觉识别任务中,ResNet展示了其卓越的性能,包括ImageNet和COCO竞赛。
  • 易于扩展:ResNet的设计允许通过简单地增加层的数量来扩展网络的深度,而无需重新设计网络架构。
  • 泛化能力:ResNet在多种数据集和任务上展示了良好的泛化能力。

不足

  • 计算资源消耗:由于ResNet包含大量的层,它需要较多的计算资源和内存,这可能限制了其在资源受限的环境中的应用。
  • 过拟合风险:在数据量较小的情况下,ResNet由于其复杂性可能会过拟合,尤其是在没有足够正则化的情况下。
  • 训练时间:相比于较浅的网络,ResNet的训练时间更长,因为它有更多的参数和层需要优化。
  • 参数数量:ResNet的参数数量较多,这可能导致模型在某些任务上不够高效。
  • 对超参数敏感:ResNet的性能可能对超参数(如学习率、批量大小等)的选择较为敏感,需要仔细调整以获得最佳性能。

此章节学习到此结束,感谢昇思平台。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/780645.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

STM32崩溃问题排查

文章目录 前言1. 问题说明2. STM32(Cortex M4内核)的寄存器3. 崩溃问题分析3.1 崩溃信息的来源是哪里?3.2 崩溃信息中的每个关键字代表的含义3.3 利用崩溃信息去查找造成崩溃的点3.4 keil5中怎么根据地址找到问题点3.5 keil5上编译时怎么输出…

C++模板元编程(二)——完美转发

完美转发指的是函数模板可以将自己的参数“完美”地转发给内部调用的其它函数。所谓完美,即不仅能准确地转发参数的值,还能保证被转发参数的左、右值属性不变。 文章目录 场景旧的方法新的方法内部实现参考文献 场景 思考下面的代码: templ…

深度学习之网络构建

目标 选择合适的神经网络 卷积神经网络(CNN):我们处理图片、视频一般选择CNN 循环神经网络(RNN):我们处理时序数据一般选择RNN 超参数的设置 为什么训练的模型的错误率居高不下 如何调测出最优的超参数 …

【pytorch20】多分类问题

网络结构以及示例 该网络的输出不是一层或两层的,而是一个十层的代表有十分类 新建三个线性层,每个线性层都有w和b的tensor 首先输入维度是784,第一个维度是ch_out,第二个维度才是ch_in(由于后面要转置),没有经过softmax函数和…

C++ 引用——引用的本质

本质:引用的本质在c内部实现是一个指针常量 C推荐用引用技术,因为语法方便,引用本质是指针常量,但是所有的指针操作编译器都帮我们做了 示例: 运行结果:

C++初学者指南-4.诊断---valgrind

C初学者指南-4.诊断—Valgrind Valgrind(内存错误检测工具) 检测常见运行时错误 读/写释放的内存或不正确的堆栈区域使用未初始化的值不正确的内存释放,如双重释放滥用内存分配函数内存泄漏–非故意的内存消耗通常与程序逻辑缺陷有关&#xf…

水箱高低水位浮球液位开关工作原理

工作原理 水箱高低水位浮球液位开关是一种利用浮球随液位升降来实现液位控制的设备。其基本原理是浮球在液体的浮力作用下上下浮动,通过磁性作用驱动与之相连的磁簧开关的开合,从而实现液位的高低控制和报警。当液位升高时,浮球上浮&#xf…

cmake find_package 使用笔记

目录 1 find_package2 config mode2.1 搜索的文件名2.2 搜索路径 3 module mode3.1 搜索的文件名3.2 搜索路径 参考 1 find_package 这是官方文档 下面是学习总结: 首先是find_package的作用是什么?引入预编译的库。 find_package有两种模式&#xff1a…

C语言 指针和数组——指针和二维数组之间的关系

目录 换个角度看二维数组 指向二维数组的行指针 按行指针访问二维数组元素 再换一个角度看二维数组 按列指针访问二维数组元素 二维数组作函数参数 指向二维数组的行指针作函数参数 指向二维数组的列指针作函数参数​编辑 用const保护你传给函数的数据 小结 换个角度看…

使用antd的<Form/>组件获取富文本编辑器输入的数据

前端开发中,嵌入富文本编辑器时,可以通过富文本编辑器自身的事件处理函数将数据传输给后端。有时候,场景稍微复杂点,比如一个输入页面除了要保存富文本编辑器的内容到后端,可能还有一些其他输入组件获取到的数据也一并…

Win10安装MongoDB(详细版)

文章目录 1、安装MongoDB Server1.1. 下载1.2. 安装 2、手动安装MongoDB Compass(GUI可视工具)2.1. 下载2.2.安装 3、测试连接3.1.MongoDB Compass 连接3.2.使用Navicat连接 1、安装MongoDB Server 1.1. 下载 官网下载地址 https://www.mongodb.com/try/download/community …

『大模型笔记』《Pytorch实用教程》(第二版)

『大模型笔记』《Pytorch实用教程》(第二版) 文章目录 一. 《Pytorch实用教程》(第二版)1.1 上篇1.2 中篇1.3 下篇1.4 本书亮点1.5 本书内容及结构二. 参考文献🖥️ 配套代码(开源免费):https://github.com/TingsongYu/PyTorch-Tutorial-2nd📚 在线阅读(开源免费)…

nginx相关概念(反向代理、负载均衡)

1 Nginx 是什么 Nginx是一款轻量级的Web 服务器,其特点是占有内存少,并发能力强 2 Nginx 反向代理 正向代理代替客户端去发送请求反向代理代替服务端接受请求 2.1 正向代理 若客户端无法直接访问到目标服务器 server 则客户端需要配置代理服务器 pr…

正则表达式 先行断言 \S {} 示例

目录 数据准备一. 先行断言1.1 正向先行断言1.2 负向先行断言 二. 配合 {} 和 \S 使用2.1 匹配一个任意非空白字符2.2 匹配任意多个非空白字符2.3 匹配3个非空白字符2.4 匹配至少3个非空白字符2.5 匹配0~3个非空白字符 数据准备 ⏹文件1 0561-10 AAA 123 dfg 345 sss 0561-2…

【电路笔记】-AB类放大器

AB类放大器 文章目录 AB类放大器1、概述2、AB类放大器介绍3、AB类放大器效率4、偏置方法4.1 电压偏置4.2 分压网络4.3 电位器偏置4.4 二极管偏置5、二极管网络和电流源6、AB类放大器的电源分配7、总结1、概述 A类放大器提供非常好的输出线性度,这意味着可以忠实地再现信号,但…

数据结构之“队列”(全方位认识)

🌹个人主页🌹:喜欢草莓熊的bear 🌹专栏🌹:数据结构 前言 上期博客介绍了” 栈 “这个数据结构,他具有先进后出的特点。本期介绍“ 队列 ”这个数据结构,他具有先进先出的特点。 目录…

【无标题】ubuntu的安装和登录

1.安装ubuntu系统并允许root用户登录 一、安装ubuntu分为以下几个步骤: a、在官方网站下载镜像:https://ubuntu.com/ b、选择安装模式为:Try or install Ubuntu (体验或安装) c、选择体验系统或安装系统&#xff1a…

ComfyUI预处理器ControlNet简单介绍与使用(附件工作流)

简介 ControlNet 是一个很强的插件,提供了很多种图片的控制方式,有的可以控制画面的结构,有的可以控制人物的姿势,还有的可以控制图片的画风,这对于提高AI绘画的质量特别有用。接下来就演示几种热门常用的控制方式 1…

RAG理论:ES混合搜索BM25+kNN(cosine)以及归一化

接前一篇:RAG实践:ES混合搜索BM25+kNN(cosine) https://blog.csdn.net/Xin_101/article/details/140230948 本文主要讲解混合搜索相关理论以及计算推导过程, 包括BM25、kNN以及ES中使用混合搜索分数计算过程。 详细讲解: (1)ES中如何通过BM25计算关键词搜索分数; (2)…

【手写数据库内核组件】0201 哈希表hashtable的实战演练,多种非加密算法,hash桶的冲突处理,查找插入删除操作的代码实现

hash表原理与实战 ​专栏内容: postgresql使用入门基础手写数据库toadb并发编程 个人主页:我的主页 管理社区:开源数据库 座右铭:天行健,君子以自强不息;地势坤,君子以厚德载物. 文章目录 hash表…