昇思MindSpore学习总结五——网络构建

1、网络构建

        神经网络模型是由神经网络层和Tensor操作构成的,mindspore.nn提供了常见神经网络层的实现,在MindSpore中,Cell类是构建所有网络的基类,也是网络的基本单元。一个神经网络模型表示为一个Cell,它由不同的子Cell构成。使用这        样的嵌套结构,可以简单地使用面向对象编程的思维,对神经网络结构进行构建和管理。

mindspore.nn.Cell类——mindspore.nn.Cell(auto_prefix=Trueflags=None)

MindSpore中神经网络的基本构成单元。模型或神经网络层应当继承该基类。

mindspore.nn 中神经网络层也是Cell的子类,如 mindspore.nn.Conv2d 、 mindspore.nn.ReLU 等。Cell在GRAPH_MODE(静态图模式)下将编译为一张计算图,在PYNATIVE_MODE(动态图模式)下作为神经网络的基础模块。

【参数】

  • auto_prefix (bool,可选) - 是否自动为Cell及其子Cell生成NameSpace。该参数同时会影响 Cell 中权重参数的名称。如果设置为 True ,则自动给权重参数的名称添加前缀,否则不添加前缀。通常情况下,骨干网络应设置为 True ,否则会产生重名问题。用于训练骨干网络的优化器、 mindspore.nn.TrainOneStepCell 等,应设置为 False ,否则骨干网络的权重参数名会被误改。默认值: True 。

  • flags (dict,可选) - Cell的配置信息,目前用于绑定Cell和数据集。用户也通过该参数自定义Cell属性。默认值: None 。

下面构建一个用于Mnist数据集分类的神经网络模型 。

2、定义模型

        当我们定义神经网络时,可以继承nn.Cell类,在__init__方法中进行子Cell的实例化和状态管理,在construct方法中实现Tensor操作。

import mindspore
from mindspore import nn, ops

construct意为神经网络(计算图)构建。 

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 512, weight_init="normal", bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 10, weight_init="normal", bias_init="zeros")
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

 构建完成后,实例化Network对象,并查看其结构。

model = Network()
print(model)

         构造一个输入数据,直接调用模型,可以获得一个十维的Tensor输出,其包含每个类别的原始预测值。

X = ops.ones((1, 28, 28), mindspore.float32)
logits = model(X)
# print logits
logits

 通过一个nn.Softmax层实例来获得预测概率。

pred_probab = nn.Softmax(axis=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")

 3、模型层

        分解上节构造的神经网络模型中的每一层。首先我们构造一个shape为(3, 28, 28)的随机数据(3个28x28的图像),依次通过每一个神经网络层来观察其效果。

input_image = ops.ones((3, 28, 28), mindspore.float32)
print(input_image.shape)

 3.1 nn.Flatten层

mindspore.nn.Flatten(start_dim=1end_dim=- 1)

沿着从 start_dim 到 end_dim 的维度,对输入Tensor进行展平。

【参数】

  • start_dim (int, 可选) - 要展平的第一个维度。默认值: 1 。

  • end_dim (int, 可选) - 要展平的最后一个维度。默认值: -1 。

输入——>x (Tensor) - 要展平的输入Tensor,输出——>Tensor。如果没有维度被展平,返回原始的输入 input,否则返回展平后的Tensor。如果 input 是零维Tensor,将会返回一个一维Tensor。

将28x28的2D张量转换为784大小的连续数组。

flatten = nn.Flatten()
flat_image = flatten(input_image)
print(flat_image.shape)

 3.2 nn.Dense层

mindspore.nn.Dense(in_channelsout_channelsweight_init=Nonebias_init=Nonehas_bias=Trueactivation=Nonedtype=mstype.float32)

全连接层。使用权重和偏差对输入进行线性变换。

适用于输入的密集连接层。公式如下:

outputs=activation(X∗kernel+bias)

其中 𝑋 是输入Tensor, activation 是激活函数, kernel 是一个权重矩阵,其数据类型与 𝑋 相同, bias 是一个偏置向量,其数据类型与 𝑋 相同(仅当has_bias为True时)。

【参数】

  • in_channels (int) - Dense层输入Tensor的空间维度。

  • out_channels (int) - Dense层输出Tensor的空间维度。

  • weight_init (Union[Tensor, str, Initializer, numbers.Number]) - 权重参数的初始化方法。数据类型与 x 相同。str的值引用自函数 initializer。默认值:None ,权重使用HeUniform初始化。

  • bias_init (Union[Tensor, str, Initializer, numbers.Number]) - 偏置参数的初始化方法。数据类型与 x 相同。str的值引用自函数 initializer。默认值:None ,偏差使用Uniform初始化。

  • has_bias (bool) - 是否使用偏置向量 bias 。默认值: True 。

  • activation (Union[str, Cell, Primitive, None]) - 应用于全连接层输出的激活函数。可指定激活函数名,如’relu’,或具体激活函数,如 mindspore.nn.ReLU 。默认值: None 。

  • dtype (mindspore.dtype) - Parameter的数据类型。默认值: mstype.float32 。

输入——>x (Tensor) - shape为 (∗,𝑖𝑛_𝑐ℎ𝑎𝑛𝑛𝑒𝑙𝑠) 的Tensor。参数中的 in_channels 应等于输入中的 𝑖𝑛_𝑐ℎ𝑎𝑛𝑛𝑒𝑙𝑠 

输出——>shape为 (∗,𝑜𝑢𝑡_𝑐ℎ𝑎𝑛𝑛𝑒𝑙𝑠) 的Tensor。

layer1 = nn.Dense(in_channels=28*28, out_channels=20)
hidden1 = layer1(flat_image)
print(hidden1.shape)

 3.3 nn.ReLU

         mindspore.nn.ReLU

逐元素计算ReLU(Rectified Linear Unit activation function)修正线性单元激活函数。

ReLU(𝑖𝑛𝑝𝑢𝑡)=(𝑖𝑛𝑝𝑢𝑡)+=max(0,𝑖𝑛𝑝𝑢𝑡),

逐元素求 max(0,𝑖𝑛𝑝𝑢𝑡) 。

print(f"Before ReLU: {hidden1}\n\n")
hidden1 = nn.ReLU()(hidden1)
print(f"After ReLU: {hidden1}")

3.4 nn.SequentialCell 

        mindspore.nn.SequentialCell(*args)

构造Cell顺序容器。关于Cell的介绍,可参考 Cell。

SequentialCell将按照传入List的顺序依次将Cell添加。此外,也支持OrderedDict作为构造器传入。

【参数】

  • args (list, OrderedDict) - 仅包含Cell子类的列表或有序字典。

seq_modules = nn.SequentialCell(
    flatten,
    layer1,
    nn.ReLU(),
    nn.Dense(20, 10)
)

logits = seq_modules(input_image)
print(logits.shape)

 3.5 nn.Softmax

        mindspore.nn.Softmax(axis=- 1)

逐元素计算Softmax激活函数,它是二分类函数 mindspore.nn.Sigmoid 在多分类上的推广,目的是将多分类的结果以概率的形式展现出来。

对输入Tensor在轴 axis 上的元素计算其指数函数值,然后归一化到[0, 1]范围,总和为1。

Softmax定义为:

其中, 𝑖𝑛𝑝𝑢𝑡𝑖 是输入Tensor在轴 axis 上的第 𝑖 个元素。

【参数】

  • axis (int,可选) - 指定Softmax运算的轴axis,假设输入 input 的维度为input.ndim,则 axis 的范围为 [-input.ndim, input.ndim) ,-1表示最后一个维度。默认值: -1 。

        最后使用nn.Softmax将神经网络最后一个全连接层返回的logits的值缩放为[0, 1],表示每个类别的预测概率。axis指定的维度数值和为1。

softmax = nn.Softmax(axis=1)
pred_probab = softmax(logits)

4、模型参数

         网络内部神经网络层具有权重参数和偏置参数(如nn.Dense),这些参数会在训练过程中不断进行优化,可通过 model.parameters_and_names() 来获取参数名及对应的参数详情。

print(f"Model structure: {model}\n\n")

for name, param in model.parameters_and_names():
    print(f"Layer: {name}\nSize: {param.shape}\nValues : {param[:2]} \n")

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/757321.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【机器学习300问】136、C4.5虽然改善了ID3决策树算法的部分缺点,但还是有不足,请问还有更好的算法吗?CART算法构建决策树

一、C4.5算法仍存在的不足 (1)计算效率不高 C4.5使用的信息增益率计算涉及熵的对数计算,特别是当属性值数量大时,计算成本较高。 (2)处理连续数值属性不够高效 ID3算法只能处理离散属性,需要预…

STM32CUBEMX配置USB虚拟串口

STM32CUBEMX配置USB虚拟串口 cubemx上默认配置即可。 外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 配置完后生成工程,主要就是要知道串口的收发接口就行了。 发送:CDC_Transmit_FS(),同时记得包含头文件#include “…

转运机器人:智能物流的得力助手

在物流行业,转运机器人已经成为提高转运效率、降低成本的重要工具。而富唯智能转运机器人凭借其出色的性能和智能化的设计,成为了众多企业的得力助手。 富唯智能转运机器人采用了先进的AMR控制系统,可以一体化控制移动机器人并实现与产线设备…

电脑提示vcomp140.dll丢失的几种有效的解决方法,轻松搞定dll问题

在电脑使用过程中,我们可能会遇到一些错误提示,其中之一就是找不到vcomp140.dll。那么,究竟什么是vcomp140.dll呢?为什么会出现找不到vcomp140.dll的情况呢?本文将从vcomp140.dll的定义、常见原因、对电脑的影响以及解…

go语言DAY7 字典Map 指针 结构体 函数

Go中Map底层原理剖析_go map底层实现-CSDN博客 目录 Map 键值对key,value 注意: map唯一确定的key值通过哈希运算得出哈希值 一、 map的声明及初始化: 二、 map的增删改查操作: 三、 map的赋值操作与切片对比: 四、 通用所有…

最佳学习率和Batch Size缩放中的激增现象

前言 《Surge Phenomenon in Optimal Learning Rate and Batch Size Scaling》原文地址GitHub项目地址Some-Paper-CN。本项目是译者在学习长时间序列预测、CV、NLP和机器学习过程中精读的一些论文,并对其进行了中文翻译。还有部分最佳示例教程。如果有帮助到大家&a…

C语言学习记录(十二)——指针与数组及字符串

文章目录 前言一、指针和数组二、指针和二维数组**行指针(数组指针)** 三、 字符指针和字符串四、指针数组 前言 一个学习嵌入式的小白~ 有问题评论区或私信指出~ 提示:以下是本篇文章正文内容,下面案例可供参考 一、指针和数组 在C语言中 &#xff0…

网盘挂载系统-知识资源系统-私域内容展示系统

系统介绍: 存储:一共支持约30款云盘存储,其中包括主流的(百度网盘、阿里云盘、夸克云盘、迅雷云盘、蓝奏云、天翼云盘),部分展示 以及特别的(一刻相册、对象存储、又拍云存储、SFTP、MEGA 网盘…

锁机制 -- 概述篇

锁机制 1、概述 ​  加锁是为了解决并发场景下,多个线程对同一资源同时进行操作,而导致同一线程多次操作出现结果不唯一的情况(一次操作包含多条指令)。结果不唯一发生的原因在于指令的错乱,前提条件是多线程环境及…

原子变量原理剖析

一、原子操作 原子操作保证指令以原子的方式执行,执行过程不被打断。先看一个实例,如下所示,如果thread_func_a和thread_func_b同时运行,执行完成后,i的值是多少? // test.c static int i 0;void thread…

013、MongoDB常用操作命令与高级特性深度解析

目录 MongoDB常用操作命令与高级特性深度解析 1. 数据库操作的深入探讨 1.1 数据库管理 1.1.1 数据库统计信息 1.1.2 数据库修复 1.1.3 数据库用户管理 1.2 数据库事务 2. 集合操作的高级特性 2.1 固定集合(Capped Collections) 2.2 集合验证(Schema Validation) 2.…

自组装mid360便捷化bag包采集设备

一、问题一:电脑太重,换nuc 采集mid360数据的过程中,发现了头疼的问题,得一手拿着电脑,一手拿着mid360来采集,实在是累胳膊。因此,网购了一个intel nuc, 具体型号是12wshi5000华尔街峡谷nuc12i…

Python私教张大鹏 PyWebIO通过事件回调实现表格的编辑和删除功能

从上面可以看出,PyWebIO把交互分成了输入和输出两部分:输入函数为阻塞式调用,会在用户浏览器上显示一个表单,在用户提交表单之前输入函数将不会返回;输出函数将内容实时输出至浏览器。这种交互方式和控制台程序是一致的…

在Ubuntu 18.04.6 LTS 交叉编译生成Windows 11下的gdb 8.1.1

1. 安装mingw sudo apt-get install mingw-w64 2. 下载 gdb 8.1.1源码 https://ftp.gnu.org/gnu/gdb/gdb-8.1.1.tar.gz 解压命令 tar -xf gdb-8.1.1.tar.gz 进入目录,创建build目录: hq@hq:~/gdb-8.1.1/build$ 执行配置 ../confi

视频云计算的未来发展趋势:智能化、个性化与云端协同助力智慧城市安防监控

随着信息技术的飞速发展,云计算作为一种全新的服务模式,正在改变我们处理数据和信息的方式。而视频云计算技术,作为云计算领域的一个重要分支,以其独特的优势,正在逐步渗透到我们生活的各个领域。 一、视频云计算技术…

[leetcode hot 150]第一百二十二题,买卖股票的最佳时机Ⅱ

题目: 给你一个整数数组 prices ,其中 prices[i] 表示某支股票第 i 天的价格。 在每一天,你可以决定是否购买和/或出售股票。你在任何时候 最多 只能持有 一股 股票。你也可以先购买,然后在 同一天 出售。 返回 你能获得的 最大…

javaScript利用indexOf()查找字符串的某个字符出现的位置

1 创建字符串 2 利用indexof()查询字符串的字符 3 利用while循环判断indexOf是否等于-1,不等于-1就打印一次并且索引号1去查下一个字符 //创建字符串var str1234567812311231;var indexstr.indexOf(1);//查询该字符while(index !-1)//indexOf()没有查到会返回-1{…

企业本地大模型用Ollama+Open WebUI+Stable Diffusion可视化问答及画图

最近在尝试搭建公司内部用户的大模型,可视化回答,并让它能画图出来, 主要包括四块: Ollama 管理和下载各个模型的工具Open WebUI 友好的对话界面Stable Diffusion 绘图工具Docker 部署在容器里,提高效率以上运行环境Win10, Ollama,SD直接装在windows10下, 然后安装Docker…

Linux中彩色打印

看之前关注下公众号呗 第1部分:引言 1.1 Python在文本处理中的重要性 Python作为一种广泛使用的高级编程语言,以其简洁的语法和强大的功能在文本处理领域占有一席之地。无论是数据清洗、自动化脚本编写,还是复杂的文本分析,Py…

甄选范文“论云上自动化运维及其应用”,软考高级论文,系统架构设计师论文

论文真题 云上自动化运维是传统IT运维和DevOps的延伸,通过云原生架构实现运维的再进化。云上自动化运维可以有效帮助企业降低IT运维成本,提升系统的灵活度,以及系统的交付速度,增强系统的可靠性,构建更加安全、可信、开放的业务平台。 请围绕“云上自动化运维及其应用”…