MATLAB深度学习(七)——ResNet残差网络

一、ResNet网络

        ResNet是深度残差网络的简称。其核心思想就是在,每两个网络层之间加入一个残差连接,缓解深层网络中的梯度消失问题

二、残差结构

        在多层神经网络模型里,设想一个包含诺干层自网络,子网络的函数用H(x)来表示,其中x是子网络的输入。残差学习是通过重新设定这个参数,让一个参数层表达一个残差函数                 ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        F(x)=H(x)-x

        因此这个子网络的输出y 就是 

        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​   y=F(x)+x

        其中 +x 的操作,是通过一个相当于恒等映射的跳跃连接来完成的,它将残差块的输入直接与输出连接,这就是残差结构。按照上面的结构递推,根据前向传播,第i个残差块的输出就是第i+1个残差块的输入

        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        x_{\ell+1}=F(x_\ell)+x_\ell

        根据递归公式,可以推导出

        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        x_L=x_\ell+\sum_{i=l}^{L-1}F(x_i)

        这里的L表示任意后续残差块,i是靠前块,那么公式就说明了总会有信号能从浅层到深层。

        从反向传播来看,根据上面的公式对 Xl 进行求导,可以得到

        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        \begin{aligned} \frac{\partial\mathcal{E}}{\partial x_\ell} & =\frac{\partial\mathcal{E}}{\partial x_L}\frac{\partial x_L}{\partial x_\ell} \\ & =\frac{\partial\mathcal{E}}{\partial x_L}\left(1+\frac{\partial}{\partial x_\ell}\sum_{i=l}^{L-1}F(x_i)\right) \\ & =\frac{\partial\mathcal{E}}{\partial x_L}+\frac{\partial\mathcal{E}}{\partial x_L}\frac{\partial}{\partial x_\ell}\sum_{i=l}^{L-1}F(x_i) \end{aligned}

        这里的 \mathcal{E} 是最小损失化函数。以上说明,浅层的梯度计算 \frac{\partial\mathcal{E}}{\partial x_{\ell}},总会直接加上上一个项 \frac{\partial\mathcal{E}}{\partial x_L}因为存在额外的一项,所以就想 F(xi)很小,总的梯度都不会消失。

三、基于ResNet识别实现步骤

        其主要步骤为1.加载图像数据,并将数据分为训练集合与验证集;2.加载MATLAB训练好的ResNet50;3.和Alexnet一样替换最后几层;4.按照网络配置调整图像数据;5.对网络进行训练

3.1 调整ResNet实现迁移学习

        针对ImageNet的数据任务,原本最后三层 FC SOFTMAX 输出层是针对1000个类别的物体进行识别,针对图像问题,继续调整这三层,首先冻结上面的层。将下面的三层进行替换。由于ResNet50需要输入的图像大小为224 * 224 *3.

unzip('MerchData.zip');
img_ds = imageDatastore('MerchData', ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

total_split = countEachLabel(img_ds); %返回一个包含每个标签和相应图像数量的表格

num_images = length(img_ds.Labels); %返回的图像的个数,5个类型都有15张照片
perm =  randperm(num_images,10);  %随机取出10个
figure
for i = 1:9
    subplot(3,3,i)
    imshow(imread(img_ds.Files{perm(i)})); %从图像数据集 (img_ds) 中读取一张图像并显示它
    %可以用alexnet的方法
end

test_idx = randperm(num_images,9); %随机取5张做样本
img_ds_Test = subset(img_ds,test_idx);
train_idx = setdiff(1:length(img_ds.Files),test_idx);
img_ds_Train = subset(img_ds,train_idx);



%% 步骤2:加载预训练好的网络

% 加载ResNet50网络(注:该网络需要提前下载,当输入下面命令时按要求下载即可)
net = resnet50;

%% 步骤3:对网络结构进行调整,替换最后几层

% 获取网络图结构
LayerGraph = layerGraph(net);
clear net;

% 确定训练数据中新冠图片标签类别数量:5类
numClasses = numel(categories(img_ds_Train.Labels));
disp(numClasses);

% 保留ResNet50倒数第三层之前的网络,并替换后3层
% 倒数第三层的全连接层,这里修改为5类
newLearnableLayer = fullyConnectedLayer(numClasses,...
    'Name','new_fc',...
    'WeightLearnRateFactor',10,...
'BiasLearnRateFactor',10);
%numClasses 分类任务。通过设置 WeightLearnRateFactor 和 BiasLearnRateFactor
%来控制学习率的调整,使得这些层在训练过程中能更快速地学习
% 分别替换最后3层:fc1000、softmax和分类输出层
LayerGraph = replaceLayer(LayerGraph,'fc1000',newLearnableLayer);
%替换了原网络中的 fc1000 层(ResNet50 中的全连接层)为 new_fc 层
% (即刚才定义的新的全连接层)。replaceLayer 函数通过层的名字来替换图层


newSoftmaxLayer = softmaxLayer('Name','new_softmax');
LayerGraph = replaceLayer(LayerGraph,'fc1000_softmax',newSoftmaxLayer);


newClassLayer = classificationLayer('Name','new_classoutput');
LayerGraph = replaceLayer(LayerGraph,'ClassificationLayer_fc1000',newClassLayer);

%% 步骤4:按照网络配置调整图像数据

% 输入图像格式转换,这里调用了自定义函数preprocess
img_ds_Train.ReadFcn = @(filename)preprocess(filename);
img_ds_Test.ReadFcn  = @(filename)preprocess(filename);

% 数据增强的参数
augmenter = imageDataAugmenter(...
    'RandRotation',[-5 5],...
    'RandXReflection',1,...
    'RandYReflection',1,...
    'RandXShear',[-0.05 0.05],...
    'RandYShear',[-0.05 0.05]);
% 将批量训练图像的大小调整为与输入层的大小相同
aug_img_ds_train = augmentedImageDatastore([224 224],img_ds_Train,'DataAugmentation',augmenter);
% 将批量测试图像的大小调整为与输入层的大小相同
aug_img_ds_test = augmentedImageDatastore([224 224],img_ds_Test);

%% 步骤5:对网络进行训练

% 对训练参数进行设置
options = trainingOptions('adam',...
    'MaxEpochs',10,...
    'MiniBatchSize',8,...
    'Shuffle','every-epoch',...
    'InitialLearnRate',1e-4,...
    'Verbose',false,...
    'Plots','training-progress',...
    'ExecutionEnvironment','cpu');

% 用训练图像对网络进行训练
netTransfer = trainNetwork(aug_img_ds_train,LayerGraph,options);

%% 步骤6:进行测试并查看结果

% 对训练好的网络采用验证数据集进行验证
[YPred,scores] = classify(netTransfer,aug_img_ds_test);

% 随机显示验证效果
idx = randperm(numel(img_ds_Test.Files),4);
figure
for i = 1:4
    subplot(2,2,i)
    I = readimage(img_ds_Test,idx(i));
    imshow(I)
    label = YPred(idx(i));
    title(string(label));
end

%% 计算分类准确率
YValidation = img_ds_Test.Labels;
accuracy = mean(YPred == YValidation);

%% 创建并显示混淆矩阵
figure
confusionchart(YValidation,YPred) 

实现效果如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/935240.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

go语言zero框架调用自己的安装的redis服务配置与使用

在 Go 语言中调用自己安装的 Redis 服务,可以分为几个步骤:从安装 Redis 服务到配置、启动 Redis,最后在 Go 代码中连接并使用 Redis。以下是详细的步骤: ## 1. 安装 Redis 服务 ### 1.1 在 Linux 系统上安装 Redis 假设你使用…

Cerebras 推出 CePO,填补推理与规划能力的关键空白

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

Google Cloud Database Option(数据库选项说明)

关系数据库 在关系数据库中,信息存储在表、行和列中,这通常最适合结构化数据。因此,它们用于数据结构不经常更改的应用程序。与大多数关系数据库交互时使用 SQL(结构化查询语言)。它们为数据提供 ACID 一致性模式&am…

ArcGIS将MultiPatch数据转换为Obj数据

文章目录 ArcGIS将MultiPatch数据转换为Obj数据1 效果2 技术路线2.1 Multipatch To Collada2.2 Collada To Obj3 代码实现4 附录4.1 环境4.2 一些坑ArcGIS将MultiPatch数据转换为Obj数据 1 效果 2 技术路线 MultiPatch --MultipatchToCollada–> Collada --Assimp–> O…

微信小程序5-图片实现点击动作和动态加载同类数据

搜索 微信小程序 “动物觅踪” 观看效果 感谢阅读,初学小白,有错指正。 一、功能描述 a. 原本想通过按钮加载背景图片,来实现一个可以点击的搜索button,但是遇到两个难点,一是按钮大小调整不方便(网上搜索…

使用nmap确定扫描目标

nmap可以通过IP、主机名、域名等指定单一目标,也可以使用IP范围、列表文件、等指定多个IP。 单一目标 IP nmap IP主机名 nmap hostname域名 nmap domainname,可以通过--dns-server指定dns服务器地址,也可以通过--system-dns指定使用操作系统…

【C++】关联存储结构容器-set(集合)详解

目录 一、基本概念 二、内部实现 三、常用操作 3.1 构造函数 3.2 插入操作 3.3 删除操作 3.4 查找操作 3.5 访问元素 3.6 容量操作 3.7 交换操作 四、特性 五、应用场景 结语 一、基本概念 set是C标准模板库(STL)中的一种关联容器&#xf…

ssm-day03 aoptx

AOP AOP指的是面向对象思想编程问题的一些补充和完善 soutp、soutv 解耦通俗理解就是把非核心代码剥出来,减少对业务功能代码的影响 设计模式是解决某些特定问题的最佳解决方案,后面一点要记得学这个!!! cxk唱跳哈…

谷粒商城—分布式高级①.md

1. ELASTICSEARCH 1、安装elastic search dokcer中安装elastic search (1)下载ealastic search和kibana docker pull elasticsearch:7.6.2 docker pull kibana:7.6.2(2)配置 mkdir -p /mydata/elasticsearch/config mkdir -p /mydata/elasticsearch/data echo "h…

数据仓库的性能问题及解决之道

随着数据量不断增长和业务复杂度逐渐攀升,数据处理效率面临巨大挑战。最典型的表现是面向分析型场景的数据仓库性能问题越来越突出,压力大、性能低,查询时间长甚至查不出来,跑批跑不完造成生产事故等问题时有发生。当数据仓库出现…

云和恩墨 zCloud 与华为云 GaussDB 完成兼容性互认证

近日,云和恩墨(北京)信息技术有限公司(以下简称:云和恩墨)的多元数据库智能管理平台 zCloud 与华为云计算技术有限公司(以下简称:华为云)的 GaussDB 数据库完成了兼容性互…

分布式专题(4)之MongoDB快速实战与基本原理

一、MongoDB介绍 1.1 什么是MongoDB MongoDB是一个文档数据库(以JSON为数据模型),由C语言编写,旨在为WEB应用提供可扩展的高性能存储解决方案。 MongoDB是一个介于关系数据库和非关系数据库之间的产品,是非关系数据库当中功能最丰富&#xf…

03篇--二值化与自适应二值化

二值化 定义 何为二值化?顾名思义,就是将图像中的像素值改为只有两种值,黑与白。此为二值化。 二值化操作的图像只能是灰度图,意思就是二值化也是一个二维数组,它与灰度图都属于单信道,仅能表示一种色调…

Linux下redis环境的搭建

1.redis的下载 redis官网下载redis的linux压缩包,官网地址:Redis下载 网盘链接: 通过网盘分享的文件:redis-5.0.4.tar.gz 链接: https://pan.baidu.com/s/1cz3ifYrDcHWZXmT1fNzBrQ?pwdehgj 提取码: ehgj 2.redis安装与配置 将包上传到 /…

印闪网络:阿里云数据库MongoDB版助力金融科技出海企业降本增效

客户背景 上海印闪网络科技有限公司,于2017年1月成立,投资方包括红杉资本等多家国际知名风投公司。公司业务聚焦东南亚普惠金融,常年稳居行业头部。创始团队来自腾讯,中国团队主要由运营、风控及产研人员组成,核心成员…

【有啥问啥】大语言模型Prompt中的“System指令”:深入剖析与误区澄清

大语言模型Prompt中的“System指令”:深入剖析与误区澄清 引言 在与大语言模型(LLM)交互时,“prompt”(提示符)这一概念已不再陌生。Prompt是引导模型生成特定类型文本的关键输入,决定了模型的…

【解决】Vue配置了端口号 发布项目仍会改变

1 梗概 这里记录Vue配置了端口号,npm run serve 发布运行仍会选择其他端口,一般是配置的端口号1 2 解决方案 网上有些教程说是由于 portfonder 版本问题,需要降低版本,可能这个的确是个解决方案。 还有说在package.json 中配置 &q…

uniapp-在windows上IOS真机运行(含开发证书申请流程)

前期准备 1、Itunes [Windows 32位 iTunes]下载地址、所有版本的iTunes下载地址 [Windows 64位 iTunes]下载地址、所有版本的iTunes下载地址 2、爱思助手 https://www.i4.cn/ 3、typeC转Usb接口 (物理意义的设备接口) 4、Mac电脑(用来生成证书&am…

网络层分析

网络访问层仍受到传输介质的性质和相关适配器的设备驱动程序的影响很大。网络层与网络适配器的硬件性质几乎是完全分离的。为什么是几乎呢?该层不仅负责发送和接受网络数据,还负责在彼此不直接连接的系统之间转发和路由分组。查找最佳路由并选择适当的网…

罗技键鼠更换新台式机无蓝牙通过接收器安装

优联驱动下载: http://support.logitech.com.cn/zh_cn/software/unifying (下载安装后按照步骤一步步操作,匹配后即可使用) 向京东客服反馈后提供的驱动下载安装连接 有问题欢迎评论沟通~