240622_昇思学习打卡-Day4-ResNet50迁移学习

240622_昇思学习打卡-Day4-ResNet50迁移学习

我们对事物的认知都是一点一点积累出来的,往往借助已经认识过的东西,可以更好地理解和认识新的有关联的东西。比如一个人会骑自行车,我们让他去骑摩托车他也很快就能学会,比如已经学会C++,现在让他去学python他也很容易就能理解。这种情况我们一般称为举一反三。反言之,我们从原始部落找出来一个人(仅作举例),指着摩托车让他骑,可能是一件特别难的事,因为他对这个领域没有丝毫的认知和理解,在实现这件事上就会特别困难。

映射到神经网络上也是一样的道理,如果我们在训练时不导入预训练权重,他就像一个没有见过现代社会的原始人,学任何东西都特别慢,学习成本特别高,但如果我们导入了相似模型结构下针对别的任务的训练权重(比如训练识别自行车),用来训练识别摩托车,我们只需要改变网络最后的分类层,即可得到比较好的训练效果,可以大大缩小模型训练的时间。

原理是我在这么多层神经网络的训练下,已经明白了轱辘(车轮)长什么样,把手长什么样,最后的分类层只是区分出来什么是自行车,你现在给我一堆摩托的照片,我就可以去寻找两者的相似处,我对轱辘和把手的认知就不用从0开始重新学习,只需要进行微调,比如摩托车的轱辘比自行车大一点,摩托车的车把手比自行车大一点。基于以前已经学习到的东西,可以大大缩小训练成本。

迁移学习就是这个道理。我们在神经网络技术的发展中,针对不同的任务,不可能每个网络都从0开始训练,那样需要的数据集及成本都是不可承受的。

本文以ResNet50迁移学习为例展开讲解,在MindSpore架构下运行。

数据准备

下载数据集

本文用到狗与狼分类数据集,使用download接口下载,也可自行下载放在项目当前目录下

from download import download

dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/intermediate/Canidae_data.zip"

download(dataset_url, "./datasets-Canidae", kind="zip", replace=True)

数据目录结构如下:

text
datasets-Canidae/data/
└── Canidae
    ├── train
    │   ├── dogs
    │   └── wolves
    └── val
        ├── dogs
        └── wolves

首先定义一些超参数:

batch_size = 18                             # 批量大小
image_size = 224                            # 训练图像空间大小
num_epochs = 5                             # 训练周期数
lr = 0.001                                  # 学习率
momentum = 0.9                              # 动量
workers = 4                                 # 并行线程个数

加载数据集以及做一些数据增强(本文所用的狼狗数据集属于ImageNet数据集,其典型mean和std值分别为[0.485, 0.456, 0.406]和[0.229, 0.224, 0.225],所以代码中直接使用):

# 导入MindSpore库,用于深度学习框架
import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision

# 定义训练数据集和验证数据集的路径
# 数据集目录路径
data_path_train = "./datasets-Canidae/data/Canidae/train/"
data_path_val = "./datasets-Canidae/data/Canidae/val/"

# 定义函数,用于创建Canidae分类任务的训练集或验证集
# 参数dataset_path: 数据集路径,usage: 数据集的用途,"train"或"val"
# 返回处理后的数据集
# 创建训练数据集
def create_dataset_canidae(dataset_path, usage):
    """数据加载"""
    # 初始化ImageFolderDataset,使用多线程并打乱数据顺序
    # 使用mindspore.dataset.ImageFolderDataset接口来加载数据集
    data_set = ds.ImageFolderDataset(dataset_path,
                                     num_parallel_workers=workers,
                                     shuffle=True,)

    # 定义数据预处理的参数
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
    scale = 32

    # 根据数据集的用途(训练或验证)选择不同的数据增强操作
    if usage == "train":
        # 训练集的数据增强操作,包括随机裁剪、水平翻转、归一化等
        # Define map operations for training dataset
        trans = [
            vision.RandomCropDecodeResize(size=image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            vision.RandomHorizontalFlip(prob=0.5),
            vision.Normalize(mean=mean, std=std),
            vision.HWC2CHW()
        ]
    else:
        # 验证集的数据增强操作,主要包括解码、缩放、中心裁剪、归一化等
        # Define map operations for inference dataset
        trans = [
            vision.Decode(),
            vision.Resize(image_size + scale),
            vision.CenterCrop(image_size),
            vision.Normalize(mean=mean, std=std),
            vision.HWC2CHW()
        ]

    # 对数据集应用预处理操作
    # 数据映射操作
    data_set = data_set.map(
        operations=trans,
        input_columns='image',
        num_parallel_workers=workers)

    # 将数据集分批处理,指定批大小
    # 批量操作
    data_set = data_set.batch(batch_size)

    return data_set

# 创建训练数据集和验证数据集,并获取每个数据集的步长(即数据集的大小)
dataset_train = create_dataset_canidae(data_path_train, "train")
step_size_train = dataset_train.get_dataset_size()

dataset_val = create_dataset_canidae(data_path_val, "val")
step_size_val = dataset_val.get_dataset_size()

数据集可视化

mindspore.dataset.ImageFolderDataset接口中加载的训练数据集返回值为字典,用户可通过 create_dict_iterator 接口创建数据迭代器,使用 next 迭代访问数据集。前面 batch_size 设为18,所以使用 next 一次可获取18个图像及标签数据。

# 获取训练数据集的第一个批次数据,是18张图像及标签数据。
data = next(dataset_train.create_dict_iterator())
images = data["image"]
labels = data["label"]

print("Tensor of image", images.shape)
print("Labels:", labels)

'''
执行结果为
Tensor of image (18, 3, 224, 224)    # 意思是这一批有18张3通道(RGB通道)的长224宽224的图像
Labels: [1 1 0 1 1 0 1 0 0 0 0 0 0 0 0 1 0 1]    # 因为该任务是一个二分类任务,所以类别只有简单的0和1
'''

目前拿到的数据我们可以先看看长什么样,展示图像及标题,标题为对应的label名称

# 导入matplotlib.pyplot库用于绘图
import matplotlib.pyplot as plt
# 导入numpy库用于处理数组
import numpy as np

# 定义一个字典,映射类别编号到类别名称
# class_name对应label,按文件夹字符串从小到大的顺序标记label
class_name = {0: "dogs", 1: "wolves"}

# 创建一个5x5大小的画布
plt.figure(figsize=(5, 5))
# 循环遍历4个图像
for i in range(4):
    # 获取当前图像的数据和标签
    # 获取图像及其对应的label
    data_image = images[i].asnumpy()
    data_label = labels[i]
    # 将图像数据从HWC格式转换为RGB格式
    # 处理图像供展示使用
    data_image = np.transpose(data_image, (1, 2, 0))
    # 对图像进行预处理,包括归一化和颜色空间转换
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    data_image = std * data_image + mean
    data_image = np.clip(data_image, 0, 1)
    # 在画布上创建子图,并显示当前图像
    # 显示图像
    plt.subplot(2, 2, i+1)
    plt.imshow(data_image)
    # 设置子图标题为图像的类别名称
    plt.title(class_name[int(labels[i].asnumpy())])
    # 关闭子图的坐标轴显示
    plt.axis("off")

# 显示画布上的所有图像
plt.show()

image-20240622221036216

训练模型

from typing import Type, Union, List, Optional
from mindspore import nn, train
from mindspore.common.initializer import Normal


weight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)

今日就先写这些,明天学完之后汇总。

不知道为什么我在这篇里面点编辑然后发表之后变成了一篇新的博客,那这里就附上后续的博客地址吧
240622_昇思学习打卡-Day4-5-ResNet50迁移学习

打卡图片:
在这里插入图片描述







本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/753151.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

电脑提醒事项怎么显示在桌面

在繁忙的工作节奏中,我们经常会面临多项任务同时进行的情况。为了确保不遗漏任何重要事务,设置电脑提醒事项就显得尤为重要。想象一下,当你正忙于一个项目时,电脑屏幕突然弹出一个提醒,告诉你接下来的会议时间&#xf…

梦想CAD二次开发

1.mxdraw简介 mxdraw是一个HTML5 Canvas JavaScript框架,它在THREE.js的基础上扩展开发,为用户提供了一套在前端绘图更为方便,快捷,高效率的解决方案,mxdraw的实质为一个前端二维绘图平台。你可以使用mxdraw在画布上绘…

实力认可!安全狗受聘成为福建省网信系统2024年度网络安全技术支撑单位

6月6日,福建省委网信办组织召开福建省网信系统2024年度网络安全技术支撑单位座谈会。 作为国内云原生安全领导厂商,安全狗也受邀出席此次活动。 省委宣传部副部长、省委网信办主任、省互联网信息办公室主任张远出席会议并颁发支撑单位证书。安全狗凭借出…

如何用Vue3和Plotly.js绘制动态3D图表?

本文由ScriptEcho平台提供技术支持 项目地址:传送门 Plotly.js: 使用Vue.js动态加载数据并绘制图表 应用场景 在数据可视化应用中,需要将数据动态加载到图表中并进行实时更新。本文将展示如何使用Plotly.js和Vue.js实现这一功能,从加载外…

java基于ssm+jsp 电子商城系统

1管理员功能模块 管理员登录,通过填写用户名、密码进行登录,如图1所示。 图1管理员登录界面图 管理员登录进入电子商城系统可以查看个人中心、用户管理、医生管理、药品信息管理、线上诊疗管理、医生信息管理、管理员管理、论坛管理、系统管理、订单管…

snowflake 不再是个数据仓库公司了

标题先上结论,为啥这么认为,且听接下来道来。 snowflake 非常成功,开创了云数仓先河,至今在数仓架构上也是相对比较先进的,国内一堆模仿的公司,传统上我们会认为 snowflake 肯定是一家数据仓库公司。不过最…

智能工业网络,需要何种工业以太网交换机作为支撑?

随着工业企业数字化及信息化的进一步深化升级,工业领域相关控制及信息系统的业务类型不断增加、复杂性不断提升,工业控制网络与工业信息网络也呈现融合趋势,具备支持多业务、多协议、多厂商设备和数据的互联互通、共网承载以及高质量传输能力…

vue-router的学习

vue-router 基于Vue做单页面应用时,vue-router可以帮我们将url地址和组件绑定,在我们切换路由组件时,url改变,页面无需从服务端重新加载,即不用刷新,打个比方,我们用美团外卖点外卖时&#xff…

一文带你彻底搞懂设计模式之单例模式!!由浅入深,图文并茂,超超超详细的单例模式讲解!!

一文带你彻底搞懂设计模式之单例模式! 一、什么是单例模式?单例模式分类饿汉式创建单例对象懒汉式创建单例对象 多问一个为什么? 二、为什么要有单例模式?使用单例模式的原因单例模式的应用场景 三、多线程下的单例模式饿汉式懒汉…

【高级篇】InnoDB引擎深入:核心机制与实战优化(十五)

引言 在探索了MySQL集群与分布式技术之后,我们进入了数据库引擎的核心地带——InnoDB。作为MySQL的默认存储引擎,InnoDB凭借其对事务的支持、行级锁定、高效的恢复机制以及复杂的内存管理,成为众多应用场景的首选。本章,我们将深入InnoDB的内部机制,透彻理解锁管理、事务…

【C++】动态内存管理new和delete

文章目录 一、C的内存管理方式二、new和delete的用法1.操作内置类型2.操作自定义内置类型 三、new和delete的底层实现1.operator new和operator delete函数2.new和delete的实现原理 四、定位new表达式五、malloc/free和new/delete的区别 一、C的内存管理方式 之前在C语言的动态…

找不到vcomp140.dll怎么办,总结多种解决方法

​在日常使用电脑的过程中,我们可能会遇到一些错误提示,其中之一就是“vcomp140.dll丢失”。那么,vcomp140.dll是什么?它为什么会丢失?丢失后对电脑有什么影响?又该如何解决呢?本文将详细介绍vc…

leetcode 动态规划(基础版)单词拆分

题目: 题解: 一种可行的dp做法是基于完全背包问题,将s看成是一个背包,wordDict看作是物品,然后往s中放入物品判断最终是否可以变为给定的s即可。这道题和上一题都用到了在dp如何枚举连续子串和状态表示:枚…

一个 API 客户端和一份 TS 学习手册

第75期: Insomnia:超好看的 API 客户端 项目介绍: 一款适用于 GraphQL、REST、WebSockets 和 gRPC 的开源 API 客户端,颜值超高。 跨平台,支持 Mac、Windows 和 Linux。但不支持网页版,需要下载客户端。…

如何借助ai(文心一言)获取tushare的数据

1. 准备工作 确保已安装python ,安装Tushare库 和文心一言的地址(文心一言): 注册Tushare账号并获取Token:在Tushare官方网站注册账号,并获取个人Token。如下 tushare地址:(点击即…

PD快充诱骗芯片工作原理,USB-C充电器出不来电压是什么原因?

一般使用Type-C接口的充电器基本上都是采用新的快充协议——PD快充协议,它不同于以前的USB-A的QC协议,这种协议,默认是没有快充电压输出的,VBUS和GND是0V。 所以,我们可以使用电阻的方式(电流小&#xff09…

【Apache Doris】如何实现高并发点查?(原理+实践全析)

【Apache Doris】如何实现高并发点查?(原理实践全析) 一、背景说明二、原理介绍三、环境信息四、Jmeter初始化五、参数预调六、用例准备七、高并发实测八、影响因素九、总结 本文主要分享 Apache Doris 是如何实现高并发点查的,以…

突破SaaS产品运营困境:多渠道运营如何集中管理?

随着数字化时代的到来,SaaS(软件即服务)产品已成为企业日常运营不可或缺的工具。然而,在竞争激烈的市场环境下,SaaS产品运营越来越重视多渠道、多平台布局,以更广泛地触及潜在用户,然而&#xf…

《昇思25天学习打卡营第10天 | 昇思MindSporeFCN图像语义分割》

第10天 本节学习了FCN图像语义分割。全卷积网络是用于图像语义分割的一种框架。FCN是首个端到端(end to end)进行像素级(pixel level)预测的全卷积网络。FCN有两大明显的优点:一是可以接受任意大小的输入图像&#xff…

2024年跨境电商关键数据统计:市场规模将达到1.976万亿美元

预计2024年跨境电商消费市场规模将达到1.976万亿美元,占全球网上销售总额的31.2%。这一数据无疑展示了跨境电商市场的巨大潜力和迅猛增长趋势。 全球跨境电商的现状与未来 现状 2023年,全球跨境电商市场规模预计达到1.56万亿美元,占全球电子…