Matlab多输入单输出之倾斜手写数字识别

本本主要介绍使用matlab构建多输入单输出的网络架构,来实现倾斜的手写数字识别,使用concatenationLayer来拼接特征,实现网络输入多个特征。

1.加载训练数据

加载数据:手写数字的图像、真实数字标签和数字顺时针旋转的角度。

load DigitsDataTrain

网络的输入数据类型需要是datastore,使用 arrayDatastore 将三个普通矩阵变为datastore,最后再使用combine合并。

dsX1Train = arrayDatastore(XTrain,IterationDimension=4);
dsX2Train = arrayDatastore(anglesTrain);
dsTTrain = arrayDatastore(labelsTrain);
dsTrain = combine(dsX1Train,dsX2Train,dsTTrain);

显示20个随机训练图像:

numObservationsTrain = numel(labelsTrain);
idx = randperm(numObservationsTrain,20);

figure
tiledlayout("flow");
for i = 1:numel(idx)
    nexttile
    imshow(XTrain(:,:,:,idx(i)))
    title("Angle: " + anglesTrain(idx(i)))
end

图片

2.设计网络架构

设计如下的网络结构:

图片

  • 对于图像输入,指定一个大小与输入数据匹配的图像输入层。

  • 对于特征输入,指定一个大小与输入特征数量匹配的特征输入层。

  • 对于图像输入分支,进行卷积、批归一化和ReLU层块,其中卷积层有16个5×5滤波器。

  • 为了将批归一化层的输出转换为特征向量,需要用一个大小为50的全连接层。

  • 要将第一个全连接层的输出与特征输入连接起来,使用flatten layer将全连接层中的 "SSCB" (空间、空间、通道、批处理)输出展平,使其具有 "CB" 格式。

  • 沿第一维度(channel维度)将平坦层的输出与特征输入连接起来

  • 对于分类输出,包括一个输出大小与类数匹配的全连接层,然后是softmax层。

创建一个空的神经网络:

net = dlnetwork;

创建一个网络主分支,并将其添加到网络中:

[h,w,numChannels,numObservations] = size(XTrain);
numFeatures = 1;
classNames = categories(labelsTrain);
numClasses = numel(classNames);

imageInputSize = [h w numChannels];
filterSize = 5;
numFilters = 16;

layers = [
    imageInputLayer(imageInputSize,Normalization="none")
    convolution2dLayer(filterSize,numFilters)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(50)
    flattenLayer
    concatenationLayer(1,2,Name="cat")
    fullyConnectedLayer(numClasses)
    softmaxLayer];

net = addLayers(net,layers);

将feature input layer添加到网络中,并将其连接到 concatenation layer的第二个输入:

featInput = featureInputLayer(numFeatures,Name="features");
net = addLayers(net,featInput);
net = connectLayers(net,"features","cat/in2");

在绘图中可视化网络:

figure
plot(net)

3.训练网络

使用SGDM优化器进行训练,训练15个epochs,以0.01的学习率进行训练,在图表中显示训练进度并监控accuracy指标,不显示详细输出。

options = trainingOptions("sgdm", ...
    MaxEpochs=15, ...
    InitialLearnRate=0.01, ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=0);

使用 trainnet 函数训练神经网络,对于分类使用交叉熵损失。

net = trainnet(dsTrain,net,"crossentropy",options);

图片

4.测试网络

通过将测试集上的预测与真实标签进行比较来测试网络的分类准确性,加载测试数据:

load DigitsDataTest

使用 minibatchpredict 函数进行预测,并使用 scores2label 函数将分数转换为标签。

scores = minibatchpredict(net,XTest,anglesTest);
YTest = scores2label(scores,classNames);

在混淆图中可视化预测:

figure
confusionchart(labelsTest,YTest)

图片

评估分类准确性:

accuracy = mean(YTest == labelsTest)

accuracy = 0.9878

查看一些预测的图像:

idx = randperm(size(XTest,4),9);
figure
tiledlayout(3,3)
for i = 1:9
    nexttile
    I = XTest(:,:,:,idx(i));
    imshow(I)

    label = string(YTest(idx(i)));
    title("Predicted Label: " + label)
end

图片

5.不用角度特征训练和测试网络

% 网络设计
net_without_feature = dlnetwork;
layers = [
    imageInputLayer(imageInputSize,Normalization="none")
    convolution2dLayer(filterSize,numFilters)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];

net_without_feature = addLayers(net_without_feature,layers);
% 网络训练
options = trainingOptions("sgdm", ...
    MaxEpochs=15, ...
    InitialLearnRate=0.01, ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=0);

dsTrain_without_feature = combine(dsX1Train,dsTTrain);

net_without_feature = trainnet(dsTrain_without_feature,net_without_feature,"crossentropy",options);

图片

% 在混淆矩阵中可视化预测。
scores = minibatchpredict(net_without_feature,XTest);
YTest = scores2label(scores,classNames);
figure
confusionchart(labelsTest,YTest)

图片

% 评估分类准确性。
accuracy = mean(YTest == labelsTest)

accuracy = 0.9858

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/918107.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

pytest结合allure做接口自动化

这是一个采用pytest框架,结合allure完成接口自动化测试的项目,最后采用allure生成直观美观的测试报告,由于添加了allure的特性,使得测试报告覆盖的内容更全面和阅读起来更方便。 1. 使用pytest构建测试框架,首先配置好…

【无人机设计与控制】基于MATLAB的四旋翼无人机PID双闭环控制研究

摘要 本文基于MATLAB/Simulink环境,对四旋翼无人机进行了PID双闭环控制设计与仿真研究。通过分析四旋翼无人机的动力学模型与运动学模型,建立了姿态和位置双闭环控制系统,以实现无人机的稳定飞行与精确轨迹跟踪。仿真实验验证了该控制策略的…

强大的正则表达式——Easy

进入题目界面输入难度1后,让我们输入正则表达式(regex): 目前不清楚题目要求,先去下载附件查看情况: import re import random# pip install libscrc import libscrcallowed_chars "0123456789()|*&q…

pytest | 框架的简单使用

这里写目录标题 单个文件测试方法执行测试套件的子集测试名称的子字符串根据应用的标记进行选择 其他常见的测试命令 pytest框架的使用示例 pytest将运行当前目录及其子目录中test_*.py或 *_test.py 形式的所有 文件 文件内的函数名称可以test* 或者test_* 开头 单个文件测试…

【安卓恶意软件检测-论文】DroidEvoler:自我进化的 Android 恶意软件检测系统

DroidEvolver:自我进化的 Android 恶意软件检测系统 摘要 鉴于Android框架的频繁变化和Android恶意软件的不断演变,随着时间的推移以有效且可扩展的方式检测恶意软件具有挑战性。为了应对这一挑战,我们提出了DroidEvolver,这是一…

Vulnhub靶场 Billu_b0x 练习

目录 0x00 准备0x01 主机信息收集0x02 站点信息收集0x03 漏洞查找与利用1. 文件包含2. SQL注入3. 文件上传4. 反弹shell5. 提权(思路1:ssh)6. 提权(思路2:内核)7. 补充 0x04 总结 0x00 准备 下载链接&#…

LabVIEW弧焊参数测控系统

在现代制造业中,焊接技术作为关键的生产工艺之一,其质量直接影响到最终产品的性能与稳定性。焊接过程中,电流、电压等焊接参数的精确控制是保证焊接质量的核心。基于LabVIEW开发的弧焊参数测控系统,通过实时监控和控制焊接过程中关…

CentOS网络配置

上一篇文章:VMware Workstation安装Centos系统 在CentOS系统中进行网络配置是确保系统能够顺畅接入网络的重要步骤。本文将详细介绍如何配置静态IP地址、网关、DNS等关键网络参数,以帮助需要的人快速掌握CentOS网络配置的基本方法和技巧。通过遵循本文的…

低速接口项目之串口Uart开发(一)——串口UART

本节目录 一、串口UART 二、串口协议 三、串口硬件 四、往期文章链接本节内容 一、串口UART 串口UART,通用异步收发传输器(Universal Asynchronnous Receiver / Transmitter),一种异步收发传输器,全双工传输。数据发送时,将并行…

Uni-APP+Vue3+鸿蒙 开发菜鸟流程

参考文档 文档中心 运行和发行 | uni-app官网 AppGallery Connect DCloud开发者中心 环境要求 Vue3jdk 17 Java Downloads | Oracle 中国 【鸿蒙开发工具内置jdk17,本地不使用17会报jdk版本不一致问题】 开发工具 HBuilderDevEco Studio【目前只下载这一个就…

SQL 外连接

1 外连接 外连接是一种用于结合两个或多个表的方式,返回至少一个表中的所有记录。 左外连接 LEFT JOIN,左表为驱动表,右表为从表。返回驱动表的所有记录以及从表中的匹配记录。如果从表没有匹配,则结果中从表的部分为NULL。 右…

笔记|M芯片MAC (arm64) docker上使用 export / import / commit 构建amd64镜像

很简单的起因,我的东西最终需要跑在amd64上,但是因为mac的架构师arm64,所以直接构建好的代码是没办法跨平台运行的。直接在arm64上pull下来的docker镜像也都是arm64架构。 检查镜像架构: docker inspect 8135f475e221 | grep Arc…

SAP+Internet主题HTML样式选择

SAP目前只支持三种HTML样式选择: 样式一 背景色:深色,蓝 特点:适中型排列,与SAP界面排列相同,富含UI特征,整齐美观 URL地址:http://cn1000-sap-01.sc.com:8000/sap/bc/gui/sap/it…

使用 Qt 实现基于海康相机的图像采集和显示系统(不使用外部视觉库,如Halcon\OpenCv)[工程源码联系博主索要]

本文将梳理一个不借助外部视觉库(如 OpenCV/Halcon)的海康相机图像采集和显示 Demo。该程序直接使用 Qt GUI 来显示图像。通过海康 MVS SDK 实现相机的连接、参数设置、图像采集和异常处理等功能,并通过 Qt 界面展示操作结果。 1. 功能概述 …

C# 异步Task异常处理和堆栈追踪显示

Task的问题 在C#中异步Task是一个很方便的语法,经常用在处理异步,例如需要下载等待等方法中,不用函数跳转,代码阅读性大大提高,深受大家喜欢。 但是有时候发现我们的异步函数可能出现了报错,但是异常又没…

31.3 XOR压缩和相关的prometheus源码解读

本节重点介绍 : xor 压缩value原理xor压缩过程讲解xor压缩prometheus源码解读xor 压缩效果 xor 压缩value原理 原理:时序数据库相邻点变化不大,采用异或压缩float64的前缀和后缀0个数 xor压缩过程讲解 第一个值使用原始点存储计算和前面的值的xor 如果XOR值为0&…

游戏引擎学习第16天

视频参考:https://www.bilibili.com/video/BV1mEUCY8EiC/ 这些字幕讨论了编译器警告的概念以及如何在编译过程中启用和处理警告。以下是字幕的内容摘要: 警告的定义:警告是编译器用来告诉你某些地方可能存在问题,尽管编译器不强制要求你修复…

解析煤矿一张图

解析煤矿一张图 ​ 煤矿一张图是指通过数字化、智能化技术将煤矿的各项信息、数据和资源进行集中展示和管理,形成一个综合的可视化平台。这一平台将矿井的地理信息、设备状态、人员位置、安全生产、环境监测等信息整合成一个统一的“图形”,以便于管理者…

Python学习27天

字典 dict{one:1,two:2,three:3} # 遍历1: # 先取出Key for key in dict:# 取出Key对应的valueprint(f"key:{key}---value:{dict[key]}")#遍历2,依次取出value for value in dict.values():print(value)# 遍历3:依次取出key,value …

【伪造检测】Noise Based Deepfake Detection via Multi-Head Relative-Interaction

一、研究动机 [!note] 动机:目前基于噪声的检测是利用Photo Response Non-Uniformity (PRNU)实现的,这是一种由于相机感光传感器而造成的缺陷噪声,主要用图像的源识别,在伪造检测的任务中并没有很好的表现。因此在文中提出了一种基…