昇思25天学习打卡营第13天 | ResNet50迁移学习

昇思25天学习打卡营第13天 | ResNet50迁移学习

文章目录

  • 昇思25天学习打卡营第13天 | ResNet50迁移学习
    • 数据集
      • 加载数据集
      • 数据集可视化
    • 模型训练
      • 固定特征
    • 总结
    • 打卡

在实际应用场景中,由于训练数据集不足,很少从头开始训练整个网络。普遍做法是在一个非常大的基础数据集上训练得到一个 预训练模型,然后使用该模型来初始化网络的权重参数或作为固定特征提取器应用于特定的任务。

数据集

使用ResNet50在狗与狼分类数据集上进行训练,数据集中图片来自ImageNet,每个分类大约120张训练图像和30张验证图像。

from download import download

dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/intermediate/Canidae_data.zip"

download(dataset_url, "./datasets-Canidae", kind="zip", replace=True)

加载数据集

使用mindspore.dataset.ImageFolderDataset接口来加载数据集,并进行数据增强::

batch_size = 18                             # 批量大小
image_size = 224                            # 训练图像空间大小
num_epochs = 5                             # 训练周期数
lr = 0.001                                  # 学习率
momentum = 0.9                              # 动量
workers = 4                                 # 并行线程个数

import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision

# 数据集目录路径
data_path_train = "./datasets-Canidae/data/Canidae/train/"
data_path_val = "./datasets-Canidae/data/Canidae/val/"

# 创建训练数据集

def create_dataset_canidae(dataset_path, usage):
    """数据加载"""
    data_set = ds.ImageFolderDataset(dataset_path,
                                     num_parallel_workers=workers,
                                     shuffle=True,)

    # 数据增强操作
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    scale = 32

    if usage == "train":
        # Define map operations for training dataset
        trans = [
            vision.RandomCropDecodeResize(size=image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            vision.RandomHorizontalFlip(prob=0.5),
            vision.Normalize(mean=mean, std=std),
            vision.HWC2CHW()
        ]
    else:
        # Define map operations for inference dataset
        trans = [
            vision.Decode(),
            vision.Resize(image_size + scale),
            vision.CenterCrop(image_size),
            vision.Normalize(mean=mean, std=std),
            vision.HWC2CHW()
        ]


    # 数据映射操作
    data_set = data_set.map(
        operations=trans,
        input_columns='image',
        num_parallel_workers=workers)


    # 批量操作
    data_set = data_set.batch(batch_size)

    return data_set


dataset_train = create_dataset_canidae(data_path_train, "train")
step_size_train = dataset_train.get_dataset_size()

dataset_val = create_dataset_canidae(data_path_val, "val")
step_size_val = dataset_val.get_dataset_size()

数据集可视化

通过next迭代访问数据集,一次获取batch_size个图像及标签:

data = next(dataset_train.create_dict_iterator())
images = data["image"]
labels = data["label"]

print("Tensor of image", images.shape)
print("Labels:", labels)

import matplotlib.pyplot as plt
import numpy as np

# class_name对应label,按文件夹字符串从小到大的顺序标记label
class_name = {0: "dogs", 1: "wolves"}

plt.figure(figsize=(5, 5))
for i in range(4):
    # 获取图像及其对应的label
    data_image = images[i].asnumpy()
    data_label = labels[i]
    # 处理图像供展示使用
    data_image = np.transpose(data_image, (1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    data_image = std * data_image + mean
    data_image = np.clip(data_image, 0, 1)
    # 显示图像
    plt.subplot(2, 2, i+1)
    plt.imshow(data_image)
    plt.title(class_name[int(labels[i].asnumpy())])
    plt.axis("off")

plt.show()

模型训练

使用ResNet50网络,通过将pretrained设置为True来加载预训练模型:

class ResNet(nn.Cell):
	 def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],layer_nums: List[int], num_classes: int, input_channel: int) -> None:
	 	# ...
	 def construct(self, x):
	 	# ...

	def _resnet(model_url: str, block: Type[Union[ResidualBlockBase, ResidualBlock]],
            layers: List[int], num_classes: int, pretrained: bool, pretrianed_ckpt: str,
            input_channel: int):
    model = ResNet(block, layers, num_classes, input_channel)

    if pretrained:
        # 加载预训练模型
        download(url=model_url, path=pretrianed_ckpt, replace=True)
        param_dict = load_checkpoint(pretrianed_ckpt)
        load_param_into_net(model, param_dict)

    return model


def resnet50(num_classes: int = 1000, pretrained: bool = False):
    "ResNet50模型"
    resnet50_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt"
    resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"
    return _resnet(resnet50_url, ResidualBlock, [3, 4, 6, 3], num_classes,
                   pretrained, resnet50_ckpt, 2048)    

通过resnet50接口创建网络模型,如果设置pretrained=True,则下载并加载预训练模型到ResNet50中。

固定特征

使用固定特征进行训练时,需要冻结除最后一层之外的所有网络层。通过设置requires_grad=False冻结参数,使其不在反向传播中计算梯度。

import mindspore as ms
import matplotlib.pyplot as plt
import os
import time

net_work = resnet50(pretrained=True)

# 全连接层输入层的大小
in_channels = net_work.fc.in_channels
# 输出通道数大小为狼狗分类数2
head = nn.Dense(in_channels, 2)
# 重置全连接层
net_work.fc = head

# 平均池化层kernel size为7
avg_pool = nn.AvgPool2d(kernel_size=7)
# 重置平均池化层
net_work.avg_pool = avg_pool

# 冻结除最后一层外的所有参数
for param in net_work.get_parameters():
    if param.name not in ["fc.weight", "fc.bias"]:
        param.requires_grad = False

# 定义优化器和损失函数
opt = nn.Momentum(params=net_work.trainable_params(), learning_rate=lr, momentum=0.5)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')


def forward_fn(inputs, targets):
    logits = net_work(inputs)
    loss = loss_fn(logits, targets)

    return loss

grad_fn = ms.value_and_grad(forward_fn, None, opt.parameters)

def train_step(inputs, targets):
    loss, grads = grad_fn(inputs, targets)
    opt(grads)
    return loss

# 实例化模型
model1 = train.Model(net_work, loss_fn, opt, metrics={"Accuracy": train.Accuracy()})

总结

这一小节对预训练模型引入的原因进行了说明,通过加载预训练模型的参数到ResNet50模型中进行参数初始化,从而加速网络的训练过程。可以通过设置参数requires_grad=False来冻结参数,使其作为固定特征,不参与梯度计算与参数优化。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/797059.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Qt下使用OpenCV的鼠标回调函数进行圆形/矩形/多边形的绘制

文章目录 前言一、设置imshow显示窗口二、绘制圆形三、绘制矩形四、绘制多边形五、示例完整代码总结 前言 本文主要讲述了在Qt下使用OpenCV的鼠标回调在OpenCV的namedWindow和imshow函数显示出来的界面上进行一些图形的绘制,并最终将绘制好的图形显示在QLabel上。示…

算法笔记——LCR

一.LCR 152. 验证二叉搜索树的后序遍历序列 题目描述: 给你一个二叉搜索树的后续遍历序列,让你判断该序列是否合法。 解题思路: 根据二叉搜索树的特性,二叉树搜索的每一个结点,大于左子树,小于右子树。…

到底哪些牌子的鼠标好?选择鼠标需要注意哪些问题?

鼠标的选择从外观材质、手感、配置到价格定位都不尽相同,消费者的选择也越来越多。一般在选择鼠标时,我们也会发现鼠标能够选择的品牌虽然众多,但是不同品牌下的鼠标在品质和款式上都是大不相同的,那么到底哪些牌子的鼠标好呢?我…

【Python】数据分析-Matplotlib绘图

数据分析 Jupyter Notebook Jupyter Notebook: 一款用于编程、文档、笔记和展示的软件。 启动命令: jupyter notebookMatplotlib 设置中文格式:plt.rcParams[font.sans-serif] [KaiTi] # 查看本地所有字体 import matplotlib.font_manager a sorted…

uniapp使用多列布局显示图片,一行两列

完整代码&#xff1a; <script setup>const src "https://qiniu-web-assets.dcloud.net.cn/unidoc/zh/shuijiao.jpg" </script><template><view class"content"><view class"img-list"><image :src"src…

zabbix web页面添加对nginx监控

1.nginx安装zabbix-agent2,并修改配置文件中server地址为zabbix-server的地址 ]# egrep ^Server|^Hostname /etc/zabbix/zabbix_agent2.conf Server172.16.1.162 ServerActive172.16.1.162 Hostnameweb01 2.zabbix web页面上进行添加客户端 3.默认的nginx监控模板中的状态模块…

基于Opencv的裂缝检测及测量

最终效果如下: 不仅标出了裂纹位置,还标出了裂纹的尺寸 原图如下: 核心原理就是基于opencv的图片处理及轮廓查找,具体逻辑看代码,话不多说上代码: # 在一张图片上检测圆 import cv2 import numpy as npdef detect_circle(img):"""在一张图片上检测圆img…

Prometheus 云原生 - 微服务监控报警系统 (Promethus、Grafana、Node_Exporter)部署、简单使用

目录 开始 Prometheus 介绍 基本原理 组件介绍 下文部署组件的工作方式 Prometheus 生态安装&#xff08;Mac&#xff09; 安装 prometheus 安装 grafana 安装 node_exporter Prometheus 生态安装&#xff08;Docker&#xff09; 安装 prometheus 安装 Grafana 安装…

mysql-联合查询

一.联合查询的概念 .对于unio查询,就是把多次查询的结果合并起来,形成一个新的查询果集。 SELECT 字段列表 FROM 表A... UNION[ALL] SELECT 字段列表 FROM 表B...&#xff0c; 二.将薪资低于5000的员工,和年龄大于50岁的员工全部查询出来 select * from emp where salary&…

【数智化CIO展】沃太能源CIO陈丽:AI 浪潮下的中国企业数智化转型机遇与挑战...

陈丽 本文由沃太能源CIO陈丽投递并参与由数据猿联合上海大数据联盟共同推出的《2024中国数智化转型升级优秀CIO》榜单/奖项评选。 大数据产业创新服务媒体 ——聚焦数据 改变商业 在当今飞速发展的数字时代&#xff0c;中国企业正面临着前所未有的变革机遇和挑战。“中国企业数…

02:项目二:感应开关盖垃圾桶

感应开关盖垃圾桶 1、PWM开发SG901.1、怎样通过C51单片机输出PWM波&#xff1f;1.2、通过定时器输出PWM波来控制SG90 2、超声波测距模块的使用3、感应开关盖垃圾桶 需要材料&#xff1a; 1、SG90舵机模块 2、HC-SR04超声波模块 3、震动传感器 4、蜂鸣器 5、若干杜邦线 1、PWM开…

win10 docker-compose搭建ELK日志收集

elk的威名大家都知道&#xff0c;以前前司有专门的人维护&#xff0c;现在换了环境&#xff0c;实在不想上服务器看&#xff0c;所以就摸索下自己搭建&#xff0c;由于现场服务器是需要类似向日葵那样连接&#xff0c;我还是把日志弄回来&#xff0c;自己本地filebeat上传到es中…

WPF学习(2) -- 样式基础

一、代码 <Window x:Class"学习.MainWindow"xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x"http://schemas.microsoft.com/winfx/2006/xaml"xmlns:d"http://schemas.microsoft.com/expression/blend/2008&…

jenkins系列-02.配置jenkins

首先&#xff1a;我们要给jenkins配备jdkmaven: 从上一节我们知道 ~/dockerV/jenkins/jenkins/data目录 就是 容器中jenkins的home目录 所以把jdkmaven 放在当前宿主机上的 ~/dockerV/jenkins/jenkins/data目录下即可 容器内&#xff1a; 开始配置jenkins: 注意是在jenkins…

护网HW面试常问——组件中间件框架漏洞(包含流量特征)

apache&iis&nginx中间件解析漏洞 参考我之前的文章&#xff1a;护网HW面试—apache&iis&nginx中间件解析漏洞篇-CSDN博客 log4j2 漏洞原理&#xff1a; 该漏洞主要是由于日志在打印时当遇到${后&#xff0c;以:号作为分割&#xff0c;将表达式内容分割成两部…

时间轮算法理解、Kafka实现

概述 TimingWheel&#xff0c;时间轮&#xff0c;简单理解就是一种用来存储若干个定时任务的环状队列&#xff08;或数组&#xff09;&#xff0c;工作原理和钟表的表盘类似。 关于环形队列&#xff0c;请参考环形队列。 时间轮由两个部分组成&#xff0c;一个环状数组&…

韦东山嵌入式linux系列-驱动设计的思想(面向对象/分层/分离)

1 面向对象 字符设备驱动程序抽象出一个 file_operations 结构体&#xff1b; 我们写的程序针对硬件部分抽象出 led_operations 结构体。 2 分层 上下分层&#xff0c;比如我们前面写的 LED 驱动程序就分为 2 层&#xff1a; ① 上层实现硬件无关的操作&#xff0c;比如注册…

WPF学习(3) -- 控件模板

一、操作过程 二、代码 <Window x:Class"学习.MainWindow"xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x"http://schemas.microsoft.com/winfx/2006/xaml"xmlns:d"http://schemas.microsoft.com/expressio…

css基础(1)

CSS CCS Syntax CSS 规则由选择器和声明块组成。 CSS选择器 CSS选择器用于查找想要设置样式的HTML元素 一般选择器分为五类 Simple selectors (select elements based on name, id, class) 简单选择器&#xff08;根据名称、id、类选择元素&#xff09; //页面上的所有 …

WEB前端03-CSS3基础

CSS3基础 1.CSS基本概念 CSS是Cascading Style Sheets&#xff08;层叠样式表&#xff09;的缩写&#xff0c;它是一种对Web文档添加样式的简单机制&#xff0c;是一种表现HTML或XML等文件外观样式的计算机语言&#xff0c;是一种网页排版和布局设计的技术。 CSS的特点 纯C…