昇思第7天

模型训练

模型训练一般分为四个步骤:

构建数据集。
定义神经网络模型。
定义超参、损失函数及优化器。
输入数据集进行训练与评估。

  1. 数据集加载
import mindspore
from mindspore import nn
# 从 MindSpore 数据集包中导入 vision 和 transforms 模块。
# vision:包含处理图像数据的工具。
# transforms:包含数据转换的工具。
from mindspore.dataset import vision, transforms
# 从 MindSpore 数据集包中导入 MnistDataset 类,用于加载 MNIST 数据集。
from mindspore.dataset import MnistDataset
# 从 download 模块中导入 download 函数,用于下载数据集。
from download import download

# 指定数据集的 URL 地址。
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
      "notebook/datasets/MNIST_Data.zip"

# 使用 download 函数下载数据集并解压到当前目录。
path = download(url, "./", kind="zip", replace=True)

# 定义一个数据管道函数,接收数据集路径和批量大小作为参数。
def datapipe(path, batch_size):
    # 定义图像数据的转换操作列表。
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),       # 缩放图像像素值到 [0, 1] 范围。
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),  # 标准化图像数据。
        vision.HWC2CHW()                      # 转换图像格式从 HWC(高度、宽度、通道)到 CHW(通道、高度、宽度)。
    ]
    # 定义标签数据的转换操作,将标签转换为 int32 类型。
    label_transform = transforms.TypeCast(mindspore.int32)

    # 加载指定路径的数据集。
    dataset = MnistDataset(path)
    # 对数据集的图像应用转换操作。
    dataset = dataset.map(image_transforms, 'image')
    # 对数据集的标签应用转换操作。
    dataset = dataset.map(label_transform, 'label')
    # 将数据集分批,每批包含指定数量的样本。
    dataset = dataset.batch(batch_size)
    # 返回处理后的数据集。
    return dataset

# 创建训练数据集,批量大小为 64。
train_dataset = datapipe('MNIST_Data/train', batch_size=64)

# 创建测试数据集,批量大小为 64。
test_dataset = datapipe('MNIST_Data/test', batch_size=64)
  1. 构建神经网络
 # 定义一个神经网络类 Network,继承自 nn.Cell。
class Network(nn.Cell):
    # 在初始化方法中定义网络的结构。
    def __init__(self):
        # 调用父类的初始化方法。
        super().__init__()
        # 定义一个平坦化层,用于将输入的多维数据展开为一维。
        self.flatten = nn.Flatten()
        # 定义一个顺序容器 SequentialCell,其中包含多个层顺序连接。
        self.dense_relu_sequential = nn.SequentialCell(
            # 全连接层,将输入数据的尺寸从 28*28(即 784)转换为 512。
            nn.Dense(28*28, 512),
            # ReLU 激活函数。
            nn.ReLU(),
            # 全连接层,将输入数据的尺寸从 512 转换为 512。
            nn.Dense(512, 512),
            # ReLU 激活函数。
            nn.ReLU(),
            # 全连接层,将输入数据的尺寸从 512 转换为 10(对应于 10 个类别)。
            nn.Dense(512, 10)
        )

    # 定义前向传播方法,用于计算网络的输出。
    def construct(self, x):
        # 将输入数据平坦化。
        x = self.flatten(x)
        # 依次通过顺序容器中的各层,得到最终的输出 logits。
        logits = self.dense_relu_sequential(x)
        # 返回计算得到的 logits。
        return logits

# 创建一个 Network 类的实例,表示定义好的神经网络模型。
model = Network()

3.定义超参、损失函数及优化器。

# 定义训练的参数。
# 训练的轮数,即数据集将被遍历的次数。
epochs = 3
# 每个批次的大小,即一次训练中使用的样本数。
batch_size = 64
# 学习率,即模型参数在每次更新时调整的幅度。
learning_rate = 1e-2
# 定义训练的参数。
# 训练的轮数,即数据集将被遍历的次数。
epochs = 3
# 每个批次的大小,即一次训练中使用的样本数。
batch_size = 64
# 学习率,即模型参数在每次更新时调整的幅度。
learning_rate = 1e-2

# 定义损失函数,用于计算预测结果与实际标签之间的差异。
# 使用交叉熵损失函数(CrossEntropyLoss),这是分类问题中常用的损失函数。
loss_fn = nn.CrossEntropyLoss()

# 定义优化器,用于更新模型的参数。

# 使用随机梯度下降(SGD)优化器。
# model.trainable_params() 获取模型中所有需要训练的参数。
# learning_rate 指定优化器的学习率。
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)


4.训练与评估
训练

# 定义前向函数,用于计算模型输出和损失。
def forward_fn(data, label):
    # 使用模型计算预测值(logits)。
    logits = model(data)
    # 计算预测值与真实标签之间的损失。
    loss = loss_fn(logits, label)
    # 返回损失值和预测值。
    return loss, logits

# 获取梯度函数,用于计算损失相对于模型参数的梯度。
# mindspore.value_and_grad 会计算前向函数的值和梯度。
# forward_fn: 计算损失的前向函数。
# None: 不需要计算的额外输出。
# optimizer.parameters: 需要计算梯度的参数。
# has_aux=True: 表示前向函数返回多个值。
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

# 定义单步训练函数。
def train_step(data, label):
    # 计算损失和梯度。
    (loss, _), grads = grad_fn(data, label)
    # 使用优化器更新模型参数。
    optimizer(grads)
    # 返回当前步的损失值。
    return loss

# 定义训练循环函数。
def train_loop(model, dataset):
    # 获取数据集的大小(即批次的数量)。
    size = dataset.get_dataset_size()
    # 设置模型为训练模式。
    model.set_train()
    # 枚举数据集的每个批次。
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        # 执行单步训练,获取当前批次的损失值。
        loss = train_step(data, label)

        # 每 100 个批次打印一次损失值和当前批次编号。
        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

测试函数

# 定义测试循环函数,用于在测试集上评估模型的性能。
def test_loop(model, dataset, loss_fn):
    # 获取数据集的批次数量。
    num_batches = dataset.get_dataset_size()
    # 设置模型为评估模式。
    model.set_train(False)
    
    # 初始化总样本数、测试损失和正确预测数。
    total, test_loss, correct = 0, 0, 0
    
    # 枚举数据集的每个批次。
    for data, label in dataset.create_tuple_iterator():
        # 使用模型进行预测。
        pred = model(data)
        # 累加总样本数。
        total += len(data)
        # 累加测试损失。
        test_loss += loss_fn(pred, label).asnumpy()
        # 累加正确预测数。
        correct += (pred.argmax(1) == label).asnumpy().sum()
    
    # 计算平均损失。
    test_loss /= num_batches
    # 计算准确率。
    correct /= total
    
    # 打印测试结果,包括准确率和平均损失。
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

运行

# 定义损失函数和优化器。
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)

# 执行多个 epoch 的训练循环。
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    # 执行训练循环。
    train_loop(model, train_dataset)
    # 在测试集上进行评估。
    test_loop(model, test_dataset, loss_fn)

print("Done!")

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/767276.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

使用DC/AC电源模块时需要注意的事项

BOSHIDA 使用DC/AC电源模块时需要注意的事项 1. 仔细阅读和理解产品说明书:在使用DC/AC电源模块之前,应该仔细阅读和理解产品说明书,了解其性能特点、技术要求和使用方法,以确保正确使用和避免潜在的安全风险。 2. 选择适当的电…

MySQL 9.0 发布了!

从昨晚开始,在DBA群里大家就在讨论MySQL 9.0发布的事情,但是Release Note和官方文档都没有更新,所以今天早上一上班就赶紧瞅了下具体更新了哪些内容? 整体看来,基本没什么创新。下面是9.0新增或废弃的一些特性。 &…

Power Platform功能管理实战概述

Power Platform功能管理实战概述 Microsoft Power Platform是一个强大的低代码开发平台,它使组织能够自动化商业流程、开发自定义应用程序,并加强与客户的连接。该平台由四个主要组件组成:Power Apps、Power Automate、Power BI和Power Virt…

【探索Linux】P.36(传输层 —— TCP协议段格式)

阅读导航 引言一、TCP段的基本格式二、控制位详细介绍三、16位接收窗口大小⭕窗口大小的作用⭕窗口大小的限制⭕窗口缩放选项⭕窗口大小的更新⭕窗口大小与拥塞控制 四、紧急指针温馨提示 引言 在上一篇文章中,我们深入探讨了一种无连接的UDP协议,它以其…

Searchsploit漏洞利用搜索工具的介绍及使用

目录 0x00 介绍0x01 常用参数0x02 使用1. 在线搜索2. 使用步骤3. 使用实例 0x00 介绍 kali自带的,Searchsploit会通过本地的Exploit-db查找软件漏洞信息。 Exploit Database(https://gitlab.com/exploit-database/exploitdb)存储了大量的漏洞…

33.哀家要长脑子了!

憋说了,感觉好不容易长出来的脑子又缩回去了。。。 1.539. 最小时间差 - 力扣(LeetCode) 把所有时间排好序,然后计算两两之间的分钟差就好,但是要注意加上最后一个和第一个的判断,因为这个时间是按字典序来…

AI研究的主要推动力会是什么?ChatGPT团队研究科学家:算力成本下降

AI 研究发展的主要推动力是什么?在最近的一次演讲中,OpenAI 研究科学家 Hyung Won Chung 给出了自己的答案。 近日,斯坦福大学《CS25: Transformers United V4》课程迎来了一位我们熟悉的技术牛人:Hyung Won Chung。 Chung 是 O…

Hadoop-03-Hadoop集群 免密登录 超详细 3节点公网云 分发脚本 踩坑笔记 SSH免密 服务互通 集群搭建 开启ROOT

章节内容 上一节完成: HDFS集群XML的配置MapReduce集群XML的配置Yarn集群XML的配置统一权限DNS统一配置 背景介绍 这里是三台公网云服务器,每台 2C4G,搭建一个Hadoop的学习环境,供我学习。 之前已经在 VM 虚拟机上搭建过一次&…

Spring容器生命周期中如前置运行程序和后置运行程序

在Spring容器加入一个实现了BeanPostProcessor接口bean实例,重写postProcessBeforeInitialization、postProcessAfterInitialization方法,在方法里面写具体的实现,从而达到Spring容器在初如化前或销毁时执行预定的程序,方法如下&a…

深入浅出:npm常用命令详解与实践【保姆级教程】

大家好,我是CodeQi! 在我刚开始学习前端开发的时候,有一件事情让我特别头疼:管理和安装各种各样的依赖包。 那时候,我还不知道 npm 的存在,手动下载和管理这些库简直是噩梦。 后来,我终于接触到了 npm(Node Package Manager),它不仅帮我解决了依赖管理问题,还让我…

解决Visual Studio 一直弹出管理员身份运行问题(win10/11解决办法)

不知道大家是否有遇到这个问题 解决办法也很简单 找到启动文件 如果是快捷方式就继续打开文件位置 找到这个程序启动项 右键 选择 兼容性疑难解答(win11 则需要 按住 shift 右键) win10 解决办法 这样操作完后就可以了 win11解决办法按以下选择就行

深入理解策略梯度算法

策略梯度(Policy Gradient)算法是强化学习中的一种重要方法,通过优化策略以获得最大回报。本文将详细介绍策略梯度算法的基本原理,推导其数学公式,并提供具体的例子来指导其实现。 策略梯度算法的基本概念 在强化学习…

AI大模型时代来临:企业如何抢占先机?

AI大模型时代来临:企业如何抢占先机? 2023年,被誉为大模型元年,AI大模型的发展如同一股不可阻挡的潮流,正迅速改变着我们的工作和生活方式。从金融到医疗,从教育到制造业,AI大模型正以其强大的生成能力和智能分析,重塑着行业的未来。 智能化:企业核心能力的转变 企…

【CUDA】 归约 Reduction

Reduction Reduction算法从一组数值中产生单个数值。这个单个数值可以是所有元素中的总和、最大值、最小值等。 图1展示了一个求和Reduction的例子。 图1 线程层次结构 在Reduction算法中,线程的常见组织方式是为每个元素使用一个线程。下面将展示利用许多不同方…

AI-算力集群通往AGI

背景: 自GPT-4发布以来,全球AI能力的发展势头有放缓的迹象。 但这并不意味着Scaling Law失效,也不是因为训练数据不够,而是结结实实的遇到了算力瓶颈。 具体来说,GPT-4的训练算力约2e25 FLOP,近期发布的几个…

双曲方程初值问题的差分逼近(迎风格式)

稳定性: 数值例子 例一 例二 代码 % function chap4_hyperbolic_1st0rder_1D % test the upwind scheme for 1D hyperbolic equation % u_t + a*u_x = 0,0<x<L,O<t<T, % u(x,0) = |x-1|,0<X<L, % u(0,t) = 1% foundate = 2015-4-22’; % chgedate = 202…

刷代码随想录有感(124):动态规划——最长公共子序列

题干&#xff1a; 代码&#xff1a; class Solution { public:int findLength(vector<int>& nums1, vector<int>& nums2) {vector<vector<int>>dp(nums1.size() 1, vector<int>(nums2.size() 1, 0));int res 0;for(int i 1; i <…

买华为智驾,晚了肯定要后悔

文 | AUTO芯球 作者 | 雷慢 晚了就来不及了&#xff01; 你买华为系的车&#xff0c;薅羊毛真的要趁早。 华为ADS2.0高阶智驾正在慢慢恢复原价&#xff0c; 你看啊&#xff0c;就在昨天&#xff0c;华为宣布ADS智驾优惠后价格调到3万元&#xff0c; 只有6000元的优惠了。…

153. 寻找旋转排序数组中的最小值(中等)

153. 寻找旋转排序数组中的最小值 1. 题目描述2.详细题解3.代码实现3.1 Python3.2 Java 1. 题目描述 题目中转&#xff1a;153. 寻找旋转排序数组中的最小值 2.详细题解 如果不考虑 O ( l o g n ) O(log n) O(logn)的时间复杂度&#xff0c;直接 O ( n ) O(n) O(n)时间复杂…