昇思25天学习打卡营第5天|网络构建

 一、简介:

神经网络模型是由神经网络层和Tensor操作构成的,mindspore.nn提供了常见神经网络层的实现,在MindSpore中,Cell类是构建所有网络的基类(这个类和pytorch中的modul类是一样的作用),也是网络的基本单元。一个神经网络模型表示为一个Cell,它由不同的子Cell构成。使用这样的嵌套结构,可以简单地使用面向对象编程的思维,对神经网络结构进行构建和管理。

二、环境准备:

import mindspore
import time
from mindspore import nn, ops

没有下载mindspore的宝子,还是回看我的昇思25天学习打卡营第1天|快速入门-CSDN博客,先下载好再进行下面的操作。

三、神经网络搭建:

1、定义模型类:

我们首先要继承nn.Cell类,并再__init__方法中进行子Cell的实例化和管理,并再construct方法(和pytorch中的forward方法一致)中实现前向计算:

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 512, weight_init="normal", bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 10, weight_init="normal", bias_init="zeros")
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

# 实例化并打印
model = Network()
print(model)
print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())), "VertexGeek")

 ①self.flatten = nn.Flatten():创建一个Flatten层,并将其作为类的属性。Flatten层的作用是将输入的数据“压平”,即不管输入数据的原始形状如何,输出都将是沿着特定维度的连续数组。

② self.dense_relu_sequential = nn.SequentialCell(...):创建一个SequentialCell,它是一种特殊的Cell,可以顺序地执行其中包含的多个层。这个SequentialCell包含了三个全连接层(Dense),每个全连接层后面跟着一个ReLU激活函数层,除了最后一个全连接层:

  • 第一个nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"):这是一个全连接层,它接受28*28=784个输入,并产生512个输出。权重(weight_init)和偏置(bias_init)分别使用正态分布和零值进行初始化。

  • nn.ReLU():ReLU激活函数,其数学表达式为f(x) = max(0, x),即负值输出为零,正值保持不变。

  • 接下来的两个nn.Dense与对应的nn.ReLU层与第一个类似,它们分别接收512个输入并再次输出512个值,以及最终输出10个值,这可能对应于10个类别。

我们构造一个数据,并使用softmax预测其概率:

X = ops.ones((1, 28, 28), mindspore.float32)
logits = model(X)
# print logits
print(logits)

pred_probab = nn.Softmax(axis=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")
print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())), "VertexGeek")

2、模型层详解:

(1)nn.Flatten:

 nn.Flantten方法用于将输入数据“压平”,以便后续处理:

input_image = ops.ones((3, 28, 28), mindspore.float32)
print(input_image.shape)

flatten = nn.Flatten()
flat_image = flatten(input_image)
print(flat_image.shape)

print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())), "VertexGeek")

(2)nn.Dense:

 nn.Dense层作为全连接层,用于对输入的数据进行线性变换和处理:

layer1 = nn.Dense(in_channels=28*28, out_channels=20)
hidden1 = layer1(flat_image)
print(hidden1.shape)

print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())), "VertexGeek")

(3)nn.Relu:

nn.Relu是本次实验中使用的激活函数,用于对神经网络的权重进行处理,以缓解欠拟合和过拟合的发生,常见的激活函数处了Relu,还有:Sigmoid, Tanh等:

print(f"Before ReLU: {hidden1}\n\n")
hidden1 = nn.ReLU()(hidden1)
print(f"After ReLU: {hidden1}")

print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())), "VertexGeek")

 (4)nn.SequentialCell:

nn.SequentialCell和pytorch中的nn.Sequential的作用一样,用于存放dense全连接层和激活函数层的组合,以方便在前向计算中使用:

seq_modules = nn.SequentialCell(
    flatten,
    layer1,
    nn.ReLU(),
    nn.Dense(20, 10)
)

logits = seq_modules(input_image)
print(logits.shape)

print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())), "VertexGeek")

(5)nn.Softmax:

 nn.softmax方法将神经网络最后一个全连接层返回的logits的值缩放为[0, 1],表示每个类别的预测概率。axis指定的维度数值和为1。

softmax = nn.Softmax(axis=1)
pred_probab = softmax(logits)
print(pred_probab)
# argmax函数返回指定维度上最大值的索引
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")
print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())), "VertexGeek")

3、模型参数:

网络内部神经网络层具有权重参数和偏置参数(如nn.Dense),这些参数会在训练过程中不断进行优化,可通过 model.parameters_and_names() 来获取参数名及对应的参数详情。

print(f"Model structure: {model}\n\n")

for name, param in model.parameters_and_names():
    print(f"Layer: {name}\nSize: {param.shape}\nValues : {param[:2]} \n")
    
print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())), "VertexGeek")

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/736687.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

陀螺仪LSM6DSV16X与AI集成(7)----FIFO数据读取与配置

陀螺仪LSM6DSV16X与AI集成.7--检测自由落体 概述视频教学样品申请源码下载主要内容生成STM32CUBEMX串口配置IIC配置CS和SA0设置串口重定向参考程序初始换管脚获取ID复位操作BDU设置设置量程设置FIFO水印设置速率使用流模式设置FIFO时间戳批处理速率使能时间戳FIFO状态寄存器演示…

Mathtype7在Word2016中闪退(安装过6)

安装教程:https://blog.csdn.net/Little_pudding10/article/details/135465291 Mathtype7在Word2016中闪退是因为安装过Mathtype6,MathPage.wll和MathType Comm***.dotm),不会随着Mathtype的删除自动删除,而新版的Mathtype中的文件…

【从0实现React18】 (三) 初探reconciler 带你初步探寻React的核心逻辑

Reconciler 使React核心逻辑所在的模块,中文名叫协调器,协调(reconciler)就是diff算法的意思 reconciler有什么用? 在前端框架出现之前,通常会使用 jQuery 这样的库来开发页面。jQuery 是一个过程驱动的库,开发者需要…

Django 模版过滤器

Django模版过滤器是一个非常有用的功能,它允许我们在模版中处理数据。过滤器看起来像这样:{{ name|lower }},这将把变量name的值转换为小写。 1,创建应用 python manage.py startapp app5 2,注册应用 Test/Test/sett…

【git1】指令,commit,免密

文章目录 1.常用指令:git branch查看本地分支, -r查看远程分支, -a查看本地和远程,-v查看各分支最后一次提交, -D删除分支2.commit规范:git commit进入vi界面(进入前要git config core.editor vim设一下vi模…

Java项目:基于SSM框架实现的精品酒销售管理系统分前后台【ssm+B/S架构+源码+数据库+毕业论文】

一、项目简介 本项目是一套基于SSM框架实现的精品酒销售管理系统 包含:项目源码、数据库脚本等,该项目附带全部源码可作为毕设使用。 项目都经过严格调试,eclipse或者idea 确保可以运行! 该系统功能完善、界面美观、操作简单、功…

【break】大头哥哥做题

【break】大头哥哥做题 时间限制: 1000 ms 内存限制: 65536 KB 【题目描述】 【参考代码】 #include <iostream> using namespace std; int main(){ int sum 0;//求和int day 0;//天数 while(1){int a;cin>>a;if(a-1){break;//结束当前循环 }sum sum a; …

121.网络游戏逆向分析与漏洞攻防-邮件系统数据分析-邮件读取与发送界面设计

免责声明&#xff1a;内容仅供学习参考&#xff0c;请合法利用知识&#xff0c;禁止进行违法犯罪活动&#xff01; 如果看不懂、不知道现在做的什么&#xff0c;那就跟着做完看效果 现在的代码都是依据数据包来写的&#xff0c;如果看不懂代码&#xff0c;就说明没看懂数据包…

Nginx+Lua+Redis 实现Nginx301跳转配置管理

业务场景需求 long long ago&#xff1a; 在项目的运维过程中有一次SEO团队提出 网页的URL 中如果可以带上关键字&#xff0c;那么网页在各大搜索引擎中收录和排名有非常重大的突出优势&#xff08;~~SEO团队到底专不专业 ~~&#xff0c;此处不做置评&#xff09;&#xff0c;…

C/C++ strftime函数

目录 strftime()函数 函数原型 头文件 功能 返回值 参数 案例 结语 strftime()函数 函数原型 size_t strftime(char *s, size_t max, const char *format, const struct tm *tm); 头文件 #include <time.h> 功能 用于日期和时间格式化的函数&#xff0c;它允许你…

【算法】二叉树 - 理论基础

1.种类 1.1 满二叉树 只有度为0和2的节点&#xff0c;且度为0的节点都都在同一层。深度为k&#xff0c;有2^k-1个节点。 1.2 完全二叉树 在完全二叉树中&#xff0c;除了最底层节点可能没填满外&#xff0c;其余每层节点数都达到最大值&#xff0c;并且最下面一层的节点都…

【论文复现|智能算法改进】一种基于多策略改进的鲸鱼算法

目录 1.算法原理2.改进点3.结果展示4.参考文献5.代码获取 1.算法原理 SCI二区|鲸鱼优化算法&#xff08;WOA&#xff09;原理及实现【附完整Matlab代码】 2.改进点 混沌反向学习策略 将混沌映射和反向学习策略结合&#xff0c;形成混沌反向学习方法&#xff0c;通过该方 法…

Atcoder Beginner Contest 359

传送门 A - Count Takahashi 时间限制&#xff1a;2秒 内存限制&#xff1a;1024MB 分数&#xff1a;100分 问题描述 给定 N 个字符串。 第 i 个字符串 () 要么是 Takahashi 要么是 Aoki。 有多少个 i 使得 等于 Takahashi &#xff1f; 限制 N 是整数。每个…

Cyber Weekly #12

赛博新闻 1、Anthropic发布Claude 3.5 Sonnet 本周五&#xff08;6月21日&#xff09;凌晨&#xff0c;Anthropic宣布推出其最新的语言模型Claude 3.5 Sonnet&#xff0c;距离上次发布Claude3才过去3个月。Claude3.5拥有20万token的长上下文窗口&#xff0c;目前已经在Claude…

数据库断言

在数据库验证断言 目的&#xff1a;不能相信接口返回结果&#xff0c;通过到数据库检验可知接口返回结果是否真的正确 如何校验&#xff1a;代码与mymql建立网络连接&#xff0c;操作数据库&#xff0c;断开连接 代码&#xff1a;java操作数据库 pom文件配置依赖 步骤&…

15.树形虚拟列表实现(支持10000+以上的数据)el-tree(1万+数据页面卡死)

1.问题使用el-tree渲染的树形结构&#xff0c;当数据超过一万条以上的时候页面卡死 2.解决方法&#xff1a; 使用vue-easy-tree来实现树形虚拟列表&#xff0c;注意&#xff1a;vue-easy-tree需要设置高度 3.代码如下 <template><div class"ve-tree" st…

ARM32开发--WDGT看门狗

知不足而奋进 望远山而前行 目录 文章目录 前言 目标 内容 什么是看门狗 ARM中的看门狗 独立看门狗定时器 窗口看门狗定时器 独立看门狗FWDGT 初始化配置 喂狗 完整代码 窗口看门狗WWDGT 初始化配置 喂狗 完整代码 注意 总结 前言 嵌入式系统在如今的科技发…

OpenGL3.3_C++_Windows(18)

接口块&#xff1a; glsl彼此传输数据&#xff0c;通过in / out&#xff0c;当更多的变量&#xff0c;涉及数组和结构体接口块(Interface Block)类似struct&#xff0c;in / out 块名{……}实例名 Uniform缓冲对象&#xff1a; 首先理解uniform Object&#xff1a;负责向gl…

基于协方差信息的Massive MIMO信道估计算法性能研究

1. 引言 随着移动互联网不断发展&#xff0c;人们对通信的速率和可靠性的要求越来越高[1]。目前第四代移动通信系统已经逐渐商用&#xff0c;研究人员开始着手研究下一代移动通信系统相关技术[2][3]。在下一代移动通信系统中要求下行速率达到10Gbps&#xff0c;这就要求我们使…

Debian Linux安装minikubekubectl

minikube&kubectl minkube用于在本地开发环境中快速搭建一个单节点的Kubernetes集群,还有k3s&#xff0c;k3d&#xff0c;kind都是轻量级的k8skubectl是使用K8s API 与K8s集群的控制面进行通信的命令行工具 这里使用Debian Linux演示&#xff0c;其他系统安装见官网,首先…