【AI炼丹术】写深度学习代码的一些心得体会

写深度学习代码的一些心得体会

  • 体会1
  • 体会2
  • 体会3
  • 总结
  • 内容来源

一般情况下,拿到一批数据之后,首先会根据任务先用领域内经典的Model作为baseline跑通,然后再在这个框架内加入自己设计的Model,微调代码以及修改一些超参数即可。总体流程参考如下:

  1. 先写dataset部分,包括数据的读取、预处理、增广等操作,将数据集准备好。
  2. 然后model部分baseline无需修改,proposed是自行设计,定义模型的结构和参数,建立模型架构。
  3. 最后是train部分,这里调用所有的类实现训练:包括定义模型,模型包裹;获取dataloader;定义loss,优化器,学习率,定义early stoping策略;保存模型权重,保存日志。

当然,文无定法。这个顺序并不是固定不变的,也可以根据具体情况作出相应的调整。例如,当你的数据集已经准备好了,可以直接开始定义模型,然后再定义训练过程;或者在进行模型训练之前,先进行数据集的分析和可视化等操作。

体会1

源自:作者三四但不犹豫
对于图像任务:

  1. 顺序上,先写dataset部分,检查基本的transform,再搭model,构建head和loss,就可以把一个基础的、可以跑的网络就能跑起来了(这点很重要);
  2. 可视化很重要,如果是本地开发机,善用cv.imshow直观、便捷地可视化处理的结果;
  3. 一个基础的train/inference流程跑通后,分别构建1 张、10 张的数据用于debug,确保任意改动后,可以overfit;
  4. 调试代码阶段避免随机性、避免数据增强,一定用tensorboard之类的工具观察 loss 下降是否合理;
  5. 一般数据集最好处理成coco的格式,我的任务跟传统任务不太一样,但也尽量仿照coco来设计,写dataset的时候可以参考开源实现;
  6. 善用开源框架,比如Open-MMLab,Detectron2之类的,好处是方便实验,在框架里写不容易出现难以察觉的bug,坏处是开源框架为了适配各种网络,代码复杂程度会高一点,建议从第一版入手了解框架,然后基于最新的一边阅读一边开发。

体会2

源自:捡到一束光
先给结论:以写了两三年pytorch代码的经验而言,比较好的顺序是先写model,再写dataset,最后写train。在讨论码组件的具体顺序前,先分析每一个组件背后的目的和逻辑。

  • model构成了整个深度学习训练与推断系统骨架,也确定了整个AI模型的输入和输出格式
    • 对于视觉任务,模型架构多为卷积神经网络或是最新的ViT模型
    • 对于NLP任务,模型架构多为Transformer以及Bert
    • 对于时间序列预测,模型架构多为RNNLSTM

不同的model对应了不同的数据输入格式,如ResNet一般是输入多通道二维矩阵,而ViT则需要输入带有位置信息的图像patchs。确定了用什么样的model后,数据的输入格式也就确定下来。根据确定的输入格式,我们才能构建对应的dataset。

  • dataset构建了整个AI模型的输入与输出格式。

    • 在写作dataset组件时,我们需要考虑数据的存储位置与存储方式,如数据是否是分布式存储的,模型是否要在多机多卡的情况下运行,读写速度是否存在瓶颈,如果机械硬盘带来了读写瓶颈则需要将数据预加载进内存等。
    • 在写dataset组件时,我们也要反向微调model组件。例如,确定了分布式训练的数据读写后,需要用nn.DataParallel或者nn.DistributedDataParallel等模块包裹model,使模型能够在多机多卡上运行。
    • 此外,dataset组件的写作也会影响训练策略,这也为构建train组件做了铺垫。比如根据显存大小,我们需要确定相应的BatchSize,而BatchSize则直接影响学习率的大小。再比如根据数据的分布情况,我们需要选择不同的采样策略进行Feature Balance,而这也会体现在训练策略中。
  • train构建了模型的训练策略以及评估方法,它是最重要也是最复杂的组件。先构建model与dataset可以添加限制,减少train组件的复杂度。

    • 在train组件中,我们需要根据训练环境(单机多卡,多机多卡或是联邦学习)确定模型更新的策略,以及确定训练总时长epochs,优化器的类型,学习率的大小与衰减策略,参数的初始化方法,模型损失函数
    • 此外,为了对抗过拟合,提升泛化性,还需要引入合适的正则化方法,如Dropout,BatchNorm,L2-Regularization,Data Augmentation等。
    • 有些提升泛化性能的方法可以直接在train组件中实现(如添加L2-Reg,Mixup),有些则需要添加进model中(如Dropout与BatchNorm),还有些需要添加进dataset中(如Data Augmentation)。。

此外,train还需要记录训练过程的一些重要信息,并将这些信息可视化出来,比如在每个epoch上记录训练集的平均损失以及测试集精度,并将这些信息写入tensorboard,然后在网页端实时监控。在构建train组件中,我们需要随时根据模型表现进行参数微调,并根据结果改进model和dataset两个组件。
tensorboard

体会3

源自:芙兰朵露
作为data driven的学科,不同的AI model适合不同的数据类型,选择用哪个模型是基于你的数据长什么样来决定的。初学者知道用CNN处理图片,用RNN处理时间序列/语言,但这些都是最基础的工作,真正体现水平的是根据数据的性质来选择合适的细分模型。比如稀疏图像需要用Sparse CNN,语言Transformer效果比较好,但对某些特殊的时间序列RNN也有奇效。

接下来还有很多技术细节,比如需不需要数据增强?需不需要标签平滑?需不需要残差链接?需不需要多loss,如果需要如何平衡?需不需要解释模型?我甚至没有提到超参数,因为超参数是锦上添花而不是雪中送炭。只要没有明确的信息瓶颈,超参数对模型的影响是很小的。

上面提到的这些问题不需要全想明白,但心里要大致有个谱,至少也要知道这些问题是可能影响你的训练结果的,这其实需要相当的阅读和积累。这样之后出了问题才知道去哪里debug。

然后就可以开始写了。这些问题想明白之后,其实先写哪个part已经不重要了,因为你的心中已经有了一个picture,先把这个picture给sketch下来,然后开始跑,第一遍效果肯定不好,但你要根据输出的结果大致判断哪个部分出了问题,然后针对性地去改进。这一步真的没什么好办法,很多时候其实是直觉,做多了自然就知道了。训练模型-发现问题-修改模型-再训练,就像炼丹一样,经过无数遍的抟炼,才能得到最后的金丹。

其实洋洋洒洒说了这么多,本质不过是几个字:解决问题的能力making things to work几乎是机器学习中最重要的能力了,而这种能力就是在日常的积累和训练中反复磨练出来的,成功的路上没有捷径

总结

单纯就个人习惯而言,先写model,确保model的结果没有错误,调试正确。然后写dataset,并调试输出正确。之后写损失函数,并调试正确。最后写train训练代码,推理代码。

内容来源

  1. 写深度学习代码是先写model还是dataset还是train呢,有个一般化的顺序吗?
  2. A Recipe for Training Neural Networks

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/14991.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

RocketMQ第一节(MQ的初步了解)

目录 1:什么是消息队列 2:MQ的基础模型 3:MQ的作用 3.1:MQ用来解耦 3.2: 削峰填谷 4:MQ怎么选 1:什么是消息队列 MQ全称是Message Queue (消息队列),是消息传输中间件&#xf…

huggingface下载的.arrow数据集读取与使用说明

1.数据下载方式:load_dataset 将数据集下载到本地:(此处下载的是一个物体目标检测的数据集) from datasets import load_dataset # 下载的数据集名称, model_name keremberke/plane-detection # 数据集保存的路径 save_path da…

mac十大必备软件排行榜 mac垃圾清理软件哪个好

刚拿到全新的mac电脑却不知道该怎么使用?首先应该装什么软件呢?如果你有同样的疑惑,今天这篇文章一定不要错过。接下来小编为大家介绍mac十大必备软件排行榜,以及mac垃圾清理软件哪个好。 一、mac十大必备软件排行榜 1.CleanMyM…

停车场管理系统的设计与实现_kaic

目 录 1 概 述 1.1研究背景 1.2研究现状 1.3研究内容 2 相关技术简介 2.1 JSP技术 2.2 JAVA技术 2.3 MYSQL数据库 2.4 B/S结构 3 系统需求分析 3.1 系统可行性分析 3.1.1 操作可行性 3.1.2 经济可行性 3.1.3 技术可行性 3.2 系统性能分析 3.3系统流程分析 3.3.1注册流程 3.3.…

智慧园区数字化转型下的移动App发展

随着智慧城市的建设和智慧园区的崛起,智慧园区数字一体化建设成为园区发展的重心,当然数字转型离不开移动应用的整合服务。 在过去的几年中,智慧园区移动应用已经发展成为园区管理和服务的重要手段之一,为企业和员工提供了更加便…

Machine Learning-Ex6(吴恩达课后习题)Support Vector Machines

目录 1. Support Vector Machines 1.1 Example Dataset 1 1.2 SVM with Gaussian Kernels 1.2.1 Gaussian Kernel 1.2.2 Example Dataset 2 1.2.3 Example Dataset 3 2. Spam Classification 2.1 Preprocessing Emails 2.1.1 Vocabulary List 2.2 Extracting Feature…

安卓GB28181-2022 RTP over TCP

使用TCP传输RTP包,GB28181-2016和GB28181-2022 都是按IETF RFC4571来的。使用TCP发送RTP包,前面加个16位无符号长度字段就好(网络字节序)。具体定义格式如下: 需要注意的是LENGTH值可以是0,0的话表示空包; 另外UDP传输RTP包&#…

第二届易派客工业品展圆满落幕 3天超7万人次观展

4月15日,第二届易派客工业品展览会在苏州国际博览中心成功闭幕,展会期间共7.4万人次观展。展会以“绿色•智造•融通•赋能”为主题,为参展企业衔接供需、共享商机、共促发展提供平台,推动工业企业数字化转型、致力供应链优化升级…

blast的-max_target_seqs?

Shah, N., Nute, M.G., Warnow, T., and Pop, M. (2018). Misunderstood parameter of NCBI BLAST impacts the correctness of bioinformatics workflows. Bioinformatics. 杂志Bioinformatics以letter to the editor的形式刊发了来自美国马里兰大学计算机系的Nidhi Shah等人…

powershell搞定烦人的Windows Defender

0x00 Windows Defender真烦 最近装了不少虚拟机,发现目前较新版本的windows Defender是真的烦,关了一段时间后,自己又打开。特别是装了域控后的winserver 2016,半都关不掉,做个实验是真烦。 顺手去查了下如何使用pow…

ThinkPHP模型操作上

ThinkPHP模型操作上 前言模型一、创建模型二、模型操作 总结 前言 在mvc架构中,模型的解释是写逻辑代码的地方,其实还可以这样理解,就是一串操作写在一个模型类中,就是你要完成某一项功能,将这个功能的代码写在一个mod…

2023年产业基金研究报告

第一章 行业概况 1.1 概述 产业基金,又称为产业投资基金,是一种由政府、企业、金融机构等出资设立的,专门用于支持和促进特定产业发展的投资基金。产业基金通常以股权投资和长期投资为主,旨在推动产业结构升级、促进科技创新、提…

基于ResNet-attention的负荷预测

一、attention机制 注意力模型最近几年在深度学习各个领域被广泛使用,无论是图像处理、语音识别还是自然语言处理的各种不同类型的任务中,都很容易遇到注意力模型的身影。从注意力模型的命名方式看,很明显其借鉴了人类的注意力机制。我们来看…

融云 CTO 岑裕:出海技术前沿探索和排「坑」实践

在本文中,你将看到以下内容: 全球通信网络在接入点、链路加速、服务商、协议等层面的动态演进; 进入到具体市场,禁运国、跨国拦截、区域一致性差等细节“坑点”如何应对; 融云如何从技术侧帮助开发者应对本地化用户体…

Hive与HBase的区别及应用场景

目录: 零、前言一、定义二、区别三、应用场景 零、前言 在学大数据分析的过程中,Hive和HBase是两个非常重要的内容,对于初学者而言容易混淆。所以比较两者区别,能够帮助我们对这两个组件有一个清晰的认识和定位。那么,…

一篇文章看懂MySQL的多表连接(包含左/右/全外连接)

MySQL的多表查询 这是第二次学习多表查询,关于左右连接还是不是很熟悉,因此重新看一下。小目标:一篇文章看懂多表查询!! 这篇博客是跟着宋红康老师学习的,点击此处查看视频,关于数据库我放在了…

大神们分享STM32的学习方法

单片机用处这么广,尤其是STM32生态这么火!如何快速上手学习呢? 第一:你要考虑的是,要用STM32实现什么 为什么使用STM32而不是8051? 是因为51的频率太低,无法满足计算需求?是51的管脚太少,无法…

云HIS(二级医院,乡镇医院,民营医院,标准化HIS医院信息管理系统源码)

传统 HIS(基于医院信息系统) 和云 HIS(基于云计算的医院信息系统)各有优缺点,选择哪种系统需要根据具体情况进行权衡。 传统 HIS 系统通常由医院自行开发和维护,适用于医院内部信息化程度较高、数据安全性…

【软件测试】第1章 软件测试概述

系列文章目录 文章目录 系列文章目录前言第1章 软件测试概述1.1 软件、软件危机和软件工程1.1.1 基本概念1.1.2 软件工程的目标及其一般开发过程1.1.3 软件过程模型 1.2 软件缺陷与软件故障1.2.1 基本概念1.2.2 典型案例 1.3 软件测试的概念1.3.1 软件测试的定义1.3.2 软件测试…

计算机程序安装及使用须知_kaic

安装及使用须知 1 数据库建模程序的使用 本文件夹中的“PowerDesigner建模”目录下包含三个可运行文件TMS1.cdm,TMS.cdm,TMS.pdm分别为TMS系统的实体关系简图、实体关系图和数据库模型,使用PowerDesigner集成开发环境打开任意一个文件即可运…