《昇思25天学习打卡营第2天|02快速入门》

课程目标

这节课准备再学习下训练模型的基本流程,因此还是选择快速入门课程。

整体流程

整体介绍下流程:

  1. 数据处理
  2. 构建网络模型
  3. 训练模型
  4. 保存模型
  5. 加载模型
    思路是比较清晰的,看来文档写的是比较连贯合理的。

数据处理

看数据也是手写体数据集的例子。
他们把数据都放存储了一份,可以通过设置获取到训练集合和测试集合。
构建了一个以64为一批的包:在这里插入图片描述
可以迭代获取到数据:
在这里插入图片描述
整体来说获取数据的部分还是比较清晰的。

网络构建

构建网络的方法和pytorch是比较接近的:
在这里插入图片描述
可以看出来,将数据先打平,然后放到全链接层,之后经过relu,再经过两个循环就构建好了网络。
模型的样子差不多是:
在这里插入图片描述

模型训练

在这里插入图片描述
通过截图可以看出来,损失函数和优化器都依次进行定义。注意这里使用的是交叉熵损失函数,所以要求的label是[batch_size],logits是[batch_size, num_class]。
损失函数的实现逻辑:

import numpy as np

def softmax(logits):
    exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
    probs = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
    return probs

def cross_entropy_loss(logits, labels):
    probs = softmax(logits)
    batch_size = logits.shape[0]
    
    # 取出正确类别的概率
    correct_log_probs = -np.log(probs[np.arange(batch_size), labels])
    
    # 计算平均损失
    loss = np.sum(correct_log_probs) / batch_size
    return loss

# 示例
logits = np.array([[2.0, 1.0, 0.1], [1.2, 0.9, 3.2], [0.5, 2.1, 0.3]])
labels = np.array([0, 2, 1])

loss = cross_entropy_loss(logits, labels)
print(f'Loss: {loss}')

在这个实现中:

softmax 函数对 logits 进行 softmax 操作。
cross_entropy_loss 函数计算交叉熵损失。
np.log 计算负对数概率。
np.arange(batch_size) 创建一个数组 [0, 1, 2, …, batch_size-1] 用于选择正确类别的概率。

通过最上面训练的代码也可以看出来,每一个step会进行一次计算优化器,获得loss。然后每100个step输出一次数据。
在整体的更上层,执行了3个epoch。
在这里插入图片描述

保存模型

在这里插入图片描述

加载模型

整体看着也挺简单的:
在这里插入图片描述

打卡

完结撒花,打卡。
在这里插入图片描述

总结

今天又过了一次,从构建数据到构建模型,和训练的整体过程都介绍完毕了。这里的模型很简单,所以训练的时候也很简单。如果是大语言模型的训练过程,需要使用到更复杂的处理逻辑,可能会依赖DeepSpeed进行并行训练。希望在接下来的学习中有机会接触到。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/798189.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

提高项目透明度:有效的跟踪软件

国内外主流的10款项目进度跟踪软件对比:PingCode、Worktile、Teambition、Tower、Asana、Trello、Jira、ClickUp、Notion、Liquid Planner。 在项目管理中,确保进度跟踪的准确性与效率是每位项目经理面临的主要挑战之一。选用合适的项目进度跟踪软件不仅…

800 元打造家庭版 SOC 安全运营中心

今天,我们开始一系列新的文章,将从独特而全面的角度探索网络安全世界,结合安全双方:红队和蓝队。 这种方法通常称为“紫队”,集成了进攻和防御技术,以提供对威胁和安全解决方案的全面了解。 在本系列的第一篇文章中,我们将指导您完成以 100 欧元约800元左右的预算创建…

Sentinel-1 Level 1数据处理的详细算法定义(三)

《Sentinel-1 Level 1数据处理的详细算法定义》文档定义和描述了Sentinel-1实现的Level 1处理算法和方程,以便生成Level 1产品。这些算法适用于Sentinel-1的Stripmap、Interferometric Wide-swath (IW)、Extra-wide-swath (EW)和Wave模式。 今天介绍的内容如下: Sentinel-1 L…

zookeeper的shell操作

一:启动拽库的shell命令行 zkCli.sh -server localhost:2181 退出:quit 二:查询所有的命令 help 三:查询对应的节点 --查询zk上的根节点 ls / ls /zookeeper 四:查询对应节点的节点信息(节点的元数据&a…

idea启动ssm项目详细教程

前言 今天碰到一个ssm的上古项目,项目没有使用内置的tomcat作为服务器容器,这个时候就需要自己单独设置tomcat容器。这让我想起了我刚入行时被外置tomcat配置支配的恐惧。现在我打算记录一下配置的过程,希望对后面的小伙伴有所帮助吧。 要求…

React学习笔记02-----

一、React简介 想实现页面的局部刷新,而不是整个网页的刷新。AJAXDOM可以实现局部刷新 1.特点 (1)虚拟DOM 开发者通过React来操作原生DOM,从而构建页面。 React通过虚拟DOM来实现,可以解决DOM的兼容性问题&#x…

UNIAPP_ReferenceError: TextEncoder is not defined 解决

错误信息 1、安装text-decoding npm install text-decoding2、main.js import { TextEncoder, TextDecoder } from text-decoding global.TextEncoder TextEncoder global.TextDecoder TextDecoder

MySQL 进阶(三)【SQL 优化】

1、SQL 优化 1.1、插入数据优化 1.1.1、Insert 优化 1、批量插入 插入多条数据时,不建议使用单条的插入语句,而是下面的批量插入: INSERT INTO tb_name VALUES (),(),(),...; 批量插入建议一次批量 500~100 条,如果数据量比…

【CSS in Depth 2 精译】2.6 CSS 自定义属性(即 CSS 变量)+ 2.7 本章小结

文章目录 2.6 自定义属性(即 CSS 变量)2.6.1 动态变更自定义属性 2.7 本章小结 当前内容所在位置 第一章 层叠、优先级与继承第二章 相对单位 2.1 相对单位的威力2.2 em 与 rem2.3 告别像素思维2.4 视口的相对单位2.5 无单位的数值与行高2.6 自定义属性 …

讲讲 JVM 的内存结构(附上Demo讲解)

讲讲 JVM 的内存结构 什么是 JVM 内存结构?线程私有程序计数器​虚拟机栈本地方法栈 线程共享堆​方法区​注意永久代​元空间​运行时常量池​直接内存​ 代码详解 什么是 JVM 内存结构? JVM内存结构分为5大区域,程序计数器、虚拟机栈、本地…

头歌---数组之Fibonacci数列

一、数组初始化几种方式 1.数组定义时,数组元素全部赋初值 2.部分数组赋初值 >>>>>前三个元素已知初值 >>>>>后三个元素系统自动赋初值为0 注意: 当定义数组时,如果未对它的元素指定过初值,对于内部局部数组…

【openwrt】Openwrt系统新增普通用户指南

文章目录 1 如何新增普通用户2 如何以普通用户权限运行服务3 普通用户如何访问root账户的ubus服务4 其他权限控制5 参考 Openwrt系统在默认情况下只提供一个 root账户,所有的服务都是以 root权限运行的,包括 WebUI也是通过root账户访问的,…

【C++航海王:追寻罗杰的编程之路】哈希的应用——位图 | 布隆过滤器

目录 1 -> 位图 1.1 -> 位图的概念 1.2 -> 位图的应用 2 -> 布隆过滤器 2.1 -> 布隆过滤器的提出 2.2 -> 布隆过滤器的概念 2.3 -> 布隆过滤器的插入 2.4 -> 布隆过滤器的查找 2.5 -> 布隆过滤器的删除 2.6 -> 布隆过滤器的优点 2.7…

视频监控汇聚平台LntonCVS视频集中存储平台解决负载均衡的方案

随着技术的进步和企业对监控需求的增加,视频监控系统规模不断扩大,接入大量设备已成常态化挑战。为应对这一挑战,视频汇聚系统LntonCVS视频融合平台凭借其卓越的高并发处理能力,为企业视频监控管理系统提供可靠的负载均衡服务保障…

6.Neo4j数据库备份

对neo4j数据进行备份、还原、迁移操作时,要关闭neo4j。 将neo4j作为服务使用进行安装: neo4j install-service 先执行上面的命令,才能执行 neo4j stop 数据备份 执行备份命令: neo4j-admin dump --databasegraph.db --to/ne…

C++的入门基础(二)

目录 引用的概念和定义引用的特性引用的使用const引用指针和引用的关系引用的实际作用inlinenullptr 引用的概念和定义 在语法上引用是给一个变量取别名,和这个变量共用同一块空间,并不会给引用开一块空间。 取别名就是一块空间有多个名字 类型& …

Docker基本管理1

Docker 概述 Docker是一个开源的应用容器引擎,基于go语言开发并遵循了apache2.0协议开源。 Docker是在Linux容器里运行应用的开源工具,是一种轻量级的“虚拟机”。 Docker 的容器技术可以在一台主机上轻松为任何应用创建一个轻量级的、可移植的、自给自…

Spring Web MVC入门(2)(请求1)

目录 请求 1.传递单个参数 2.传递多个参数 3.传递对象 4.后端参数重命名(后端参数映射) 非必传参数设置 5.传递数组 请求 访问不同的路径就是发送不同的请求.在发送请求时,可能会带一些参数,所以学习Spring的请求,主要是学习如何传递参数到后端及后端如何接收. 1.传递单…

Linux多线程编程-哲学家就餐问题详解与实现(C语言)

在哲学家就餐问题中,假设有五位哲学家围坐在圆桌前,每位哲学家需要进行思考和进餐两种活动。他们的思考不需要任何资源,但进餐需要使用两根筷子(左右两侧各一根)。筷子是共享资源,哲学家们在进行进餐时需要…

IDEA中Git常用操作及Git存储原理

Git简介与使用 Intro Git is a free and open source distributed version control system designed to handle everything from small to very large projects with speed and efficiency. Git是一款分布式版本控制系统(VSC),是团队合作开发…