昇思25天学习打卡营第5天 | 网络构建

目录

1.定义模型类

2.模型层

nn.Flatten

nn.Dense

nn.ReLU

nn.SequentialCell

nn.Softmax

3.模型参数

代码实现:

总结


神经网络模型是由神经网络层和Tensor操作构成的,

mindspore.nn提供了常见神经网络层的实现,

在MindSpore中,Cell类是构建所有网络的基类,也是网络的基本单元。

一个神经网络模型表示为一个Cell,它由不同的子Cell构成。

使用这样的嵌套结构,可以简单地使用面向对象编程的思维,对神经网络结构进行构建和管理

1.定义模型类

定义神经网络时,可以继承nn.Cell类,在__init__方法中进行子Cell的实例化和状态管理,在construct方法中实现Tensor操作。

construct意为神经网络(计算图)构建

构建完成后,实例化Network对象,并查看其结构:

三个全连接层(Dense)和两个ReLU激活函数的序列模型

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 512, weight_init="normal", bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 10, weight_init="normal", bias_init="zeros")
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()
print(model)

我们构造一个输入数据,直接调用模型,可以获得一个十维的Tensor输出,其包含每个类别的原始预测值。

model.construct()方法不可直接调用。

在此基础上,我们通过一个nn.Softmax层实例来获得预测概率。

X = ops.ones((1, 28, 28), mindspore.float32)
logits = model(X)
# print logits
logits
pred_probab = nn.Softmax(axis=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")

2.模型层

我们分解上面构造的神经网络模型中的每一层。首先我们构造一个shape为(3, 28, 28)的随机数据(3个28x28的图像),依次通过每一个神经网络层来观察其效果。

input_image = ops.ones((3, 28, 28), mindspore.float32)
print(input_image.shape)

nn.Flatten

实例化nn.Flatten层,将28x28的2D张量转换为784大小的连续数组。

nn.Dense

nn.Dense为全连接层,其使用权重和偏差对输入进行线性变换。

nn.ReLU

nn.ReLU层给网络中加入非线性的激活函数,帮助神经网络学习各种复杂的特征。

nn.SequentialCell

nn.SequentialCell是一个有序的Cell容器。输入Tensor将按照定义的顺序通过所有Cell。我们可以使用SequentialCell来快速组合构造一个神经网络模型。

nn.Softmax

最后使用nn.Softmax将神经网络最后一个全连接层返回的logits的值缩放为[0, 1],表示每个类别的预测概率。axis指定的维度数值和为1。

3.模型参数

网络内部神经网络层具有权重参数和偏置参数(如nn.Dense),这些参数会在训练过程中不断进行优化,可通过 model.parameters_and_names() 来获取参数名及对应的参数详情。

代码实现:

总结

构建网络,定义模型类时,要有这个框架,继承类,在他里面进行实例化和状态管理:

  1. class Network(nn.Cell): 定义了一个类,继承自 nn.Cell

  2. def __init__(self):  Network 类的构造函数,初始化类的属性。

  3. super().__init__(): 调用父类 nn.Cell 的构造函数。

  4. def construct(self, x): 定义了 Network 类的 construct 方法,它是MindSpore中定义模型前向传播逻辑的方法。参数 x 表示输入数据。

  5. x = self.flatten(x): 使用 self.flatten 层将输入数据 x 展平。

  6. logits = self.dense_relu_sequential(x): 将展平后的数据 x 通过 self.dense_relu_sequential 序列模型进行前向传播,得到模型的原始输出 logits。在分类任务中,logits 是模型的线性输出

  7. return logits: 返回模型的输出 logits

class Network(nn.Cell):
    def __init__(self):
        super().__init__()




def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits
  1. self.flatten = nn.Flatten(): 初始化一个 nn.Flatten 层,这个层用于将多维输入数据展平为一维数据。在处理图像数据时,通常需要将图像的二维数据(例如,28x28像素)展平为一维向量。

  2. self.dense_relu_sequential = nn.SequentialCell(...): 初始化一个序列模型,包含三个全连接层(nn.Dense)和两个ReLU激活函数(nn.ReLU)。这个序列模型的初始化与之前解释的相同。

预测的时候:

1. pred_probab = nn.Softmax(axis=1)(logits): 使用了 nn.Softmax 函数来将模型的输出 logits 转换为概率分布。

Softmax 函数通常用于多类分类问题的输出层,它可以将一个向量的元素转换为一个概率分布,使得所有元素的和为1。

参数 axis=1 表示 Softmax 函数将在第二个维度(通常是特征维度)上应用,即对于每个样本,将其对应的 logits 转换为概率。

2. y_pred = pred_probab.argmax(1): 这行代码使用了 argmax 函数来找到每个样本概率最高的类别索引。argmax 函数返回输入数组中最大元素的索引。在这里,它沿着第二个维度(即每个样本的概率分布)找到最大值的索引,这代表了模型预测的类别。

pred_probab = nn.Softmax(axis=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")

别的也没什么了吧~~~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/739448.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

从宏基因组中鉴定病毒序列(1)

Introduction 在环境微生物学和生态学研究中,宏基因组学(Metagenomics)技术的应用已经彻底改变了我们对微生物群落的理解。宏基因组学通过对环境样本中的全部遗传物质进行测序和分析,可以全面揭示微生物群落的组成、功能和相互作…

操作系统概论(二)

一、单项选择题(本大题共20小题,每小题1分,共20分) 在每小题列出的四个备选项中只有一个选项是符合题目要求的,请将其代码填写在题后的括号内。错选、多选或未选均无分。 1.操作员接口是操作系统为用户提供的使用计算机系统的手…

自产厂家将品牌入驻美国商超的详细流程及其显著优势

随着全球化的深入推进,越来越多的国内厂家开始寻求海外市场的拓展,其中美国商超成为了一个重要的目标市场。那么,国内厂家想要将产品入驻美国商超需要经历哪些详细流程呢?同时,这样的举措又有哪些显著优势呢?接下来,…

西部证券:1+1>2?

又一起券商收购拉开帷幕,证券业并购浪潮呼之欲出。 这次是——西部证券。 最近,西部证券公告称,因自身发展需要正在筹划收购国融证券控股权事项, 这是继“浙商国都”、“国联民生”、“华创太平洋”之后,今年券商并购…

HTML(16)——边距问题

清楚默认样式 很多标签都有默认的样式,往往我们不需要这些样式,就需要清楚默认样式 写法: 用通配符选择器,选择所有标签,清除所有内外边距选中所有的选择器清楚 *{ margin:0; padding:0; } 盒子模型——元素溢出 作…

Android CTS环境搭建

CTS即Compatibility Test Suite意为兼容性测试,是Google推出的Android平台兼容性测试机制。其目的是尽早发现不兼容性,并确保软件在整个开发过程中保持兼容性。只有通过CTS认证的设备才能合法的安装并使用Google market等Google应用。 搭建CTS测试环境需…

2008年 - 2021年 地级市-人口密度数据

人口密度是一个关键的人口统计指标,它反映了在一定地理范围内的人口分布情况。这个指标对于理解一个国家或地区的空间人口分布、资源分配、社会经济发展和城市规划等方面都具有重要意义。 人口密度的计算方法 人口密度是通过将一个地区的常住人口数除以其面积来计…

一文详解去噪扩散概率模型(DDPM)

节前,我们星球组织了一场算法岗技术&面试讨论会,邀请了一些互联网大厂朋友、参加社招和校招面试的同学。 针对算法岗技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何准备、面试常考点分享等热门话题进行了深入的讨论。 合集&#x…

恒远世达:把握现在,高考后逆袭,开启日本留学之路!

一年一度的高考已经落幕,马上就要出高考分数了,有人欢喜有人忧,奋斗学习了这么多年,就为了考上一所理想的大学,一旦没考上,心情会非常的低落。 在传统心态中,高考失利意味着人生重大失败&#…

VS Code SSH 远程连接服务器及坑点解决

背景 Linux服务器重装了一下,IP没有变化,结果VS Code再重连的时候就各种问题,导致把整个流程全部走了一遍,留个经验帖以备查看 SSH 首先确保Windows安装了ssh,通过cmd下ssh命令查看是否安装了。 没安装,…

CAD平台大模型场景显示性能优化分析总结

1.性能瓶颈原因 图元过于复杂 (1)图元内的三角形面片过多。对于CAD平台大场景,单帧三角面片数量达到5000万。 (2)图元的各种计算过多。 过多的图元。例如土建场景:将近20万的构件,绘制次数将…

原装GUVCL-T10GD韩国GENICOM光电二极管紫外线传感器原厂代理商

深圳市宏南科技有限公司是韩国GenUV公司的原厂代理商,所售紫外线传感器均来自于原始生产厂商直接供货,非第三方转售。 GUVCL-T10GD 韩国GENICOM光电二极管光传感器 / 低亮度 / 紫外线 UV-C传感器 GUVCL-T10GD 采用基于氮化铟的材料 肖特基型 光电二极管…

生产环境安装odoo

odoo可以在多平台运行,但是在生产环境下官方不建议在Windows平台部署。在Windows下可能不能很好的支持一服务多worker的形式,更推荐在Linux下部署。 常见的Linux如Ubuntu、Debian等Debian系或Redhat系都能执行官网的包安装。 地址:Download |…

C# Web控件与数据感应之数据返写

目录 关于数据返写 准备视图 范例运行环境 ControlInducingFieldName 方法 设计与实现 如何根据 ID 查找控件 FindControlEx 方法 调用示例 小结 关于数据返写 数据感应也即数据捆绑,是一种动态的,Web控件与数据源之间的交互,数据…

Docker(七)-Docker容器数据卷

1.容器数据卷是什么 卷就是目录或者文件,存在于一个或者多个容器中,由docker挂载到容器,不属于容器内(类似于笔记本电脑外的一个移动硬盘)。 卷的设计目的就是数据持久化,完全独立于容器的生存周期,因此Docker不会在容…

博图随机生成俄罗斯方块程序

一、程序结构 1.定义基础数据,俄罗斯方块图形共19中,使用WORD编码存储在数组内 2.添加随机生成int数值的FC函数块,生成1-19 的随机数 3.查找数组内图形显示在HMI画面上 二、程序 1.生成1-19 的随机数,并显示当前图形样式 2.生成按…

网页设计软件Bootstrap Studio6.7.1

Bootstrap Studio是一个适用于Windows的程序,允许您使用流行的fre***orca Bootstrap创建和原型网站。您可以将现成的组件拖动到工作区并直观地自定义它们。该程序生成干净和语义的PDF、CSS和JS代码,所有Web浏览器都支持这些代码。 Bootstrap Studio有一个漂亮而强大的界面,它…

Microsoft Edge浏览器安装crx拓展插件教程

1、首先打开edge浏览器,点击顶部地址栏。 2、在地址栏中输入"edge://flags/#extensions-on-edge-urls"并按下回车。2、在地址栏中输入"edge://flags/#extensions-on-edge-urls"并按下回车。 3、进入后,将图示选项改为“已禁用”。 …

邮件群发推送的方法技巧?有哪些注意事项?

邮件群发推送的策略如何实现?邮件推送怎么评估效果? 电子邮件营销是现代企业进行推广和沟通的重要工具。有效的邮件群发推送不仅能提高客户参与度,还能促进销售增长。AokSend将探讨一些关键的邮件群发推送方法和技巧,以帮助企业优…

现在本科录取率最高已达79%了。。。

郭震原创,手撸码字1035 你好,我是郭震 高考今天陆续出分,查了下去年高考本科录取率,排名第一的上海,已达到79.19%: 不知道诸位看到这个数字,有何感想? 1 本科含金量 1977年本科录取率…