昇思25天学习打卡营第08天 | 模型训练

昇思25天学习打卡营第08天 | 模型训练

文章目录

  • 昇思25天学习打卡营第08天 | 模型训练
    • 超参数
    • 损失函数
    • 优化器
      • 优化过程
    • 训练与评估
    • 总结
    • 打卡

模型训练一般遵循四个步骤:

  1. 构建数据集
  2. 定义神经网络模型
  3. 定义超参数、损失函数和优化器
  4. 输入数据集进行训练和评估

构建数据集和网络模型在之前的内容在已经涉及,不再赘述。

超参数

超参数(Hyperparameters)是可以调整的参数,可以控制模型训练的过程。

深度学习模型多采用随机梯度下降算法SGD进行优化:
w t + 1 = w t − η 1 n ∑ x ∈ B ∇ l ( x , w t ) w_{t+1}=w_t- \eta\frac1n\sum_{x\in B}\nabla l(x,w_t) wt+1=wtηn1xBl(x,wt)
其中, η \eta η是学习率, n n n是batch大小,都是超参数,这两个参数是直接影响模型性能收敛的重要参数。
一般会定义三个超参数:

  • epoch:遍历数据集的次数
  • batch size:每个批次数据的大小。size 过小导致花费时间多,梯度震荡严重,不利于收敛;size 过大容易陷入局部极小值。
  • learning rate:学习率国小会导致收敛速度变慢;过大则可能会导致训练不收敛。

损失函数

损失函数用于评估模型预测值和目标值之间的误差。
常见的损失函数包括:

  • nn.MSELoss:均方误差,用于回归
  • nn.NLLLoss:负对数似然,用于分类
  • nn.CrossEntropyLoss:结合了nn.LogSoftmaxnn.NLLLoss,可以对logits进行归一化并计算预测误差
loss_fn = nn.CrossEntropyLoss()

优化器

优化器内部定义了模型参数的优化过程,所有的优化逻辑都封装在优化器对象中。

optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)

优化过程

通过自动微分获得的微分函数,计算参数对应的梯度,并传入优化器中,即可实现参数优化。

grads = grad_fn(inputs)
optimizer(grads)

训练与评估

遍历一次数据集被称为一轮(epoch),每轮执行训练时包含两个步骤:

  1. 训练:迭代训练数据集,并尝试收敛到最佳参数。
  2. 验证/测试:迭代测试数据集,检查模型性能是否提升。
# Define forward function
def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)
    return loss, logits

# Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

# Define function of one-step training
def train_step(data, label):
    (loss, _), grads = grad_fn(data, label)
    optimizer(grads)
    return loss

def train_loop(model, dataset):
    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

def test_loop(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

训练过程一般为:

loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(model, train_dataset)
    test_loop(model, test_dataset, loss_fn)

总结

这一节的内容对深度学习模型训练的一般过程进行了详细的介绍,从数据集构建到模型定义,接着定义超参数并选择合适的值,创建损失函数和优化器对象完成训练前的准备。通过封装一个模型调用和loss计算的前向计算函数并自动微分,在每个epoch中计算loss并优化参数,从而完成模型的训练。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/777903.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

东芝TB6560AHQ/AFG步进电机驱动IC:解锁卓越的电机控制性能

作为一名工程师,一直在寻找可靠且高效的组件来应用于你的项目中。东芝的TB6560AHQ/AFG步进电机驱动IC能够提供精准且多功能的电机控制,完全符合现代应用的高要求,保证高性能和易用性。在这篇文章中,我们将探讨TB6560AHQ/AFG的主要…

CentOS 7.9 停止维护(2024-6-30)后可用在线yum源 —— 筑梦之路

众所周知,centos 7 在2024年6月30日,生命周期结束,官方不再进行支持维护,而很多环境一时之间无法完全更新替换操作系统,因此对于yum源还是需要的,特别是对于互联网环境来说,在线yum源使用方便很…

直播预告 | VMware大规模迁移实战,HyperMotion助力业务高效迁移

2006年核高基专项启动,2022年国家79号文件要求2027年央国企100%完成信创改造……国家一系列信创改造政策的推动,让服务器虚拟化软件巨头VMware在中国的市场份额迅速缩水。 加之VMware永久授权的取消和部分软件组件销售策略的变更,导致VMware…

移动端UI风格营造舒适氛围

移动端UI风格营造舒适氛围

XXL-JOB中断信号感知

目录 背景 思路 实现逻辑 总结 背景 在使用xxl-job框架时,由于系统是由线程池去做异步逻辑,然后主线程等待,在控制台手动停止时,会出现异步线程不感知信号中断的场景,如下场景 而此时如果人工在控制台停止xxl-job执…

insert阻塞了insert?

一、发现问题 在arms监控页面看到某条insert语句的执行时长达到了431毫秒。 数据库中存在,insert语句受到了行锁阻塞,而阻塞的源头也在执行同样的insert语句,同样都是对表USERSYS_TASK_USER_LOG_TEMP01的插入操作,很是费解。 二…

idea创建的maven项目pom文件引入的坐标报红原因

如下所示 我们在引入某些依赖坐标的时候,即使点击了右上角的mavne刷新之后还是报红。 其实这是正常现象,实际上是我们的本地仓库当中没有这些依赖坐标,而idea就会通过报红来标记这些依赖来说明在我们的本地仓库是不存在的。 那有的同学就会…

ODOO17的邮件机制-系统自动推送修改密码的邮件

用户收到被要求重置密码的邮件: 我们来分析一下ODOO此邮件的工作机制: 1、邮件模板定义 2、渲染模板的函数: 3、调用此函数的机制: 当用户移除或增加了信任的设备(如电脑、手机端等),系统会自…

农业气象站:现代农业的守护者与引领者

随着科技的飞速发展,农业领域也在经历着前所未有的变革。在这一变革中,农业气象站以其独特的功能和作用,逐渐成为了现代农业的守护者与引领者。 农业气象站,顾名思义,是专门用于观测和记录农田气象要素的设施。这些气象…

轻松设置:服务器域名配置全攻略

目录 前置条件 在阅读本篇内容之前,请先确保以下物料已准备好: 一台公网服务器,服务正常运行申请完成的域名,在对应域名服务商后台正常DNS解析域名备案完成可选条件:有https访问请求时,需要申请SSL证书 …

Android在framework层添加自定义服务的流程

环境说明 ubuntu16.04android4.1java version “1.6.0_45”GNU Make 3.81gcc version 5.4.0 20160609 (Ubuntu 5.4.0-6ubuntu1~16.04.12) 可能有人会问,现在都2024了怎么还在用android4版本,早都过时了。确实,现在最新的都是Android13、And…

在Linux环境下搭建Redis服务结合内网穿透实现通过GUI工具远程管理数据库

文章目录 前言1. 安装Docker步骤2. 使用docker拉取redis镜像3. 启动redis容器4. 本地连接测试4.1 安装redis图形化界面工具4.2 使用RDM连接测试 5. 公网远程访问本地redis5.1 内网穿透工具安装5.2 创建远程连接公网地址5.3 使用固定TCP地址远程访问 前言 本文主要介绍如何在Li…

Python处理表格数据常用的 N+个操作

Python作为一种强大且易用的编程语言,其在数据处理方面表现尤为出色。特别是当我们面对大量的表格数据时,Python的各类库和工具可以极大地提高我们的工作效率。以下,我将详细介绍Python处理表格数据常用的操作。 首先,我们需要安…

【算法笔记自学】第 3 章 入门篇(1)——入门模拟

3.1简单模拟 自己写的题解 #include <stdio.h> #include <stdlib.h> int main() {int N;int num0;scanf("%d",&N);while(N!1){if(N%20){NN/2;}else{N(3*N1)/2;}num;}printf("%d",num);system("pause"); // 防止运行后自动退出&…

SpringBoot+OSS实现文件上传

创建spring boot项目 pom依赖 <dependency><groupId>com.aliyun.oss</groupId><artifactId>aliyun-sdk-oss</artifactId><version>3.17.4</version></dependency><dependency><groupId>javax.xml.bind</groupI…

Transformer前置知识:Seq2Seq模型

Seq2Seq model Seq2Seq&#xff08;Sequence to Sequence&#xff09;模型是一类用于将一个序列转换为另一个序列的深度学习模型&#xff0c;广泛应用于自然语言处理&#xff08;NLP&#xff09;任务&#xff0c;如机器翻译、文本摘要、对话生成等。Seq2Seq模型由编码器&#…

直播预告|飞思实验室暑期公益培训7月10日正式开启,报名从速!

01 培训背景 很荣幸地向大家宣布&#xff1a;卓翼飞思实验室将于7月10日正式开启为期两个月的暑期公益培训&#xff01;本次培训为线上直播&#xff0c;由中南大学计算机学院特聘副教授&#xff0c;RflySim平台总研发负责人戴训华副教授主讲。 培训将基于“RflySim—智能无人…

数据可视化之智慧农业的窗口与引擎

在科技日新月异的今天,农业作为国民经济的基础产业,正逐步向智能化、数字化转型。农业为主题的数据可视化大屏看板,作为这一转型过程中的重要工具,不仅为农业管理者提供了全面、实时的农田信息,还促进了农业资源的优化配置和农业生产效率的提升。本文将深入探讨农业数据可…

Git 运用小知识

1.Git添加未完善代码的解决方法 1.1 Git只是提交未推送 把未完善的代码提交到本地仓库 只需点击撤销提交&#xff0c;提交的未完善代码会被撤回 代码显示未提交状态 1.2 Git提交并推送 把未完善的代码提交并推送到远程仓库 点击【未完善提交并推送】的结点选择还原提交&#x…

最佳 iPhone 解锁软件工具,可免费下载用于电脑操作的

业内专业人士表示&#xff0c;如果您拥有 iPhone&#xff0c;您一定知道忘记锁屏密码会多么令人沮丧。由于 Apple 的安全功能强大&#xff0c;几乎不可能在没有密码或 Apple ID 的情况下访问锁定的 iPhone。 “当我忘记密码时&#xff0c;如何在没有密码的情况下解锁iPhone&am…