神经网络学习小记录74——Pytorch 设置随机种子Seed来保证训练结果唯一

神经网络学习小记录74——Pytorch 设置随机种子Seed来保证训练结果唯一

  • 学习前言
  • 为什么每次训练结果不同
  • 什么是随机种子
  • 训练中设置随机种子

学习前言

好多同学每次训练结果不同,最大的指标可能会差到3-4%这样,这是因为随机种子没有设定导致的,我们一起看看怎么设定吧。
在这里插入图片描述

为什么每次训练结果不同

模型训练中存在很多随机值,最常见的有:
1、随机权重,网络有些部分的权重没有预训练,它的值则是随机初始化的,每次随机初始化不同会导致结果不同。
2、随机数据增强,一般来讲网络训练会进行数据增强,特别是少量数据的情况下,数据增强一般会随机变化光照、对比度、扭曲等,也会导致结果不同。
3、随机数据读取,喂入训练数据的顺序也会影响结果。
……
应该还有别的随机值,这里不一一列出,这些随机都很容易影响网络的训练结果。

如果能够固定权重、固定数据增强情况、固定数据读取顺序,网络理论上每一次独立训练的结果都是一样的。

什么是随机种子

随机种子(Random Seed)是计算机专业术语。一般计算机的随机数都是伪随机数,以一个真随机数(种子)作为初始条件,然后用一定的算法不停迭代产生随机数。

按照这个理解,我们如果可以设置最初的 真随机数(种子),那么后面出现的随机数将会是固定序列。

以random库为例,我们使用如下的代码,前两次为随机生成,后两次为设置随机数生成器种子后生成。

import random

# 生成随机整数
print("第一次随机生成")
print(random.randint(1,100))
print(random.randint(1,100))

# 生成随机整数
print("第二次随机生成")
print(random.randint(1,100))
print(random.randint(1,100))

# 设置随机数生成器种子
random.seed(11)

# 生成随机整数
print("第一次设定种子后随机生成")
print(random.randint(1,100))
print(random.randint(1,100))

# 重置随机数生成器种子
random.seed(11)

# 生成随机整数
print("第二次设定种子后随机生成")
print(random.randint(1,100))
print(random.randint(1,100))

结果如下,前两次随机生成的序列不同,后两次设定种子后随机生成的序列相同:

第一次随机生成
66
37
第二次随机生成
93
56
第一次设定种子后随机生成
58
72
第二次设定种子后随机生成
58
72

训练中设置随机种子

一般训练会用到多个库包含有关random的内容。

在pytorch构建的网络中,一般都是使用下面三个库来获得随机数,我们需要对三个库都设置随机种子:
1、torch库;
2、numpy库;
3、random库。

在这里写了一个函数:

#---------------------------------------------------#
#   设置种子
#---------------------------------------------------#
def seed_everything(seed=11):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

这里面写到了cuda、cudnn这类gpu才会用到的东西,实测发现cpu版本的pytorch也可以正常运行。

torch.backends.cudnn.deterministic=True用于保证CUDA 卷积运算的结果确定。
torch.backends.cudnn.benchmark=False是用于保证数据变化的情况下,减少网络效率的变化。为True的话容易降低网络效率。

只需要在所有初始化前,调用该seed初始化函数即可。

另外,Pytorch一般使用Dataloader来加载数据,Dataloader一般会使用多worker加载多进程来加载数据,此时我们需要使用Dataloader自带的worker_init_fn函数初始化Dataloader启动的多进程,这样才能保证多进程数据加载时数据的确定性。

#---------------------------------------------------#
#   设置Dataloader的种子
#---------------------------------------------------#
def worker_init_fn(worker_id, rank, seed):
    worker_seed = rank + seed
    random.seed(worker_seed)
    np.random.seed(worker_seed)
    torch.manual_seed(worker_seed)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/35994.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

删除链表的倒数第 N 个结点——力扣19

题目描述 法一)计算链表长度 class Solution { public:int getLength(ListNode* head){int len0;while(head){len;head head->next;}return len;}ListNode* removeNthFromEnd(ListNode* head, int n) {int len getLength(head);ListNode* dummy new ListNode …

【Hello mysql】 mysql的约束

Mysql专栏:Mysql 本篇博客简介:介绍mysql的约束 mysql的约束 表的约束空属性默认值列描述zerofill主键自增长唯一键外键总结 表的约束 为什么要有约束? 我们在收集一些数据的时候会要求该数据必须存在 比如说像是国家在登记公民信息的时候身…

【Linux | Shell】Linux 安全系统 —— 用户、组、文件权限 - 阅读笔记

目录 一、Linux 的安全性1.1 /etc/passwd 文件1.2 /etc/shadow 文件1.3 添加新用户 —— useradd1.4 删除用户 —— userdel1.5 修改用户 —— usermod、passwd、chpasswd 二、使用 Linux 组2.1 /etc/group 文件2.2 创建新组 —— groupadd2.3 修改组 —— groupmod 三、理解文…

Swagger-Bootstrap-UI

Swagger-Bootstrap-UI 是一个为 Swagger 提供美观、易用的界面展示和增强功能的开源项目。它通过自定义样式和交互,提供了更好的文档展示和交互体验,包括美化的界面、接口测试工具、在线调试、文档导出等功能。 更高阶的有Knife4j,Knife4j是一个集Swagg…

免费 Selenium各大浏览器驱动【谷歌chrme、火狐Firefox、IE浏览器】

aardio群 625494397 废话不多说 直接开整! 竟然还有脸收费 服了 下载对应版本的浏览器驱动 目标网址 应用场景 Selenium库涉及到 安装selenium库 pip install selenium-i https://mirrors.aliyun.com/pypi/simple/下载对应浏览器驱动 https://registry.npmmirror.c…

水电站运行数据3D可视化展示方便管理运维

水电站是现代能源体系中的重要组成部分,对于保障国家能源安全和经济发展具有重要的意义。然而,由于水电站的建设和管理涉及到大量的技术和专业知识,许多人对水电站的运行和维护存在许多疑惑和困惑。为了解决这些问题,我们引入了全…

TCP 重传机制 滑动窗口 流量控制 拥塞控制 学习总结!

大家好,我是三叔,很高兴这期又和大家见面了,一个奋斗在互联网的打工人。 这篇博客主要记录 tcp 的基础知识的学习总结相关内容!备注:图片均来自作者:小林哥 tcp 的三大特性 tcp 大家都不陌生&#xff0c…

Work20230705

//main.c #include "uart4.h" extern void printf(const char *fmt, ...); void delay_ms(int ms) {int i,j;for(i 0; i < ms;i)for (j 0; j < 1800; j); }int main() {while(1){//将获取到的字符1发送到终端//hal_put_char(hal_get_char()1);hal_put_string…

chatGPT之100个例子-从体验到精通

简介 本博文演示了100个chatGPT在各行各业的使用例子,全部看完有助于培养chatGPT解决问题的思维。 在人工智能时代,智能软件并不会淘汰人类,淘汰人类的是会使用人工智能的人! 我们直接使用openAI官方chatGPT,生动演示了chatGPT的一些妙用! 请仔细看完,一定会有收获! 每…

[期末网页作业]-精仿华为官网10个网页(html+css+js)

经过漫长的期末考试季节&#xff0c;我成功地完成了一个华为官网的仿写项目&#xff0c;并且非常高兴地与大家分享。这个项目包含了10个页面&#xff0c;每一个页面都经过了精心的设计和努力的填充。 首先&#xff0c;我注重了页面的整体布局与设计。借鉴了华为官网的风格&…

apache php mysql python 环境部署与离线安装deb包

文章目录 1.背景介绍2. 主要涉及操作2.1 安装系统&#xff1a;2.2 apache mysql php安装2.3 配置2.4 python相关库安装 3. 操作记录3.1 软件安装3.2 读取文件内容后进行文件内容抽取3.3 执行以上的sh脚本3.4 所学3.5 打包发送 4. 参考文献 1.背景介绍 使用的系统为ubuntu18.04…

vscode安装+配置+使用+调试【保姆级教程】

1. VScode是什么 Visual Studio Code简称VS Code&#xff0c;是一款跨平台的、免费且开源的现代轻量级代码编辑器&#xff0c;支持几乎主流开发语言的语法高亮、智能代码补全、自定义快捷键、括号匹配和颜色区分、代码片段提示、代码对比等特性&#xff0c;也拥有对git的开箱即…

架构师进阶之路 - 微服务怎么划分

目录 微服务划分目标 业务、技术、团队导向规划服务 领域检查 依赖DAG检查 分布式事务检查 性能分布检查 稳定&#xff08;易变&#xff09;性检查 调用链检查 微服务划分目标 我们常说服务的合理划分是微服务成功的重中之重&#xff0c;一个合理的服务划分应该符合一下…

Linux性能优化实践——CPU上下文

CPU上下文切换 Linux是一个多任务操作系统&#xff0c;它支持远大于CPU数量的任务同时运行。这些任务不是真正意义上的并行运行&#xff0c;而是系统在短时间内&#xff0c;将CPU轮流分配给它们&#xff0c;造成任务同时运行的错觉。 CPU需要知道任务从哪里加载&#xff0c;从…

Linux调优–I/O 调度器

Linux 的 I/O 调度器是一个以块式 I/O 访问存储卷的进程&#xff0c;有时也叫磁盘调度器。Linux I/O 调度器的工作机制是控制块设备的请求队列&#xff1a;确定队列中哪些 I/O 的优先级更高以及何时下发 I/O 到块设备&#xff0c;以此来减少磁盘寻道时间&#xff0c;从而提高系…

QC51XX---I2c使用

QCC51XX---系统学习目录_嵌入式学习_force的博客-CSDN博客 今天我们聊聊I2C的使用。在开发过程中多多少少会加入一些外围传感器可以给用户带来更好的使用体验。例如,利用接近传感器识别戴上耳机和取下耳机,从而去自动暂停播放音乐,或接听转移通话。又例如,用触摸或敲击替代…

软件安全技术复习内容

软件安全技术 边复习边写的&#xff0c;有错误及时指正第一章 软件安全概述零日漏洞安全威胁分类CIA安全基本属性PDRR模型软件安全的主要方法和技术基本方法主要技术 第二章 软件漏洞概述概念软件漏洞成因分析软件漏洞分类基于漏洞成因的分类基于漏洞利用位置的分类基于威胁类型…

基于Stable Diffusion的2D游戏关卡生成【实战】

接下来的几篇文章将与常规主题有所不同&#xff08;这是在从事通用机器人技术的职业中吸取的教训&#xff09;。 相反&#xff0c;我决定利用我的一些新空闲时间 1 边做边学&#xff0c;并使用所有酷孩子都在谈论的一些很酷的新 ML。 推荐&#xff1a;用 NSDT设计器 快速搭建可…

第五章 运输层【计算机网络】

第五章 运输层【计算机网络】 前言推荐第五章 运输层5.1运输层协议概述5.1.1 进程之间的通信5.1.2运输层的两个主要协议5.1.3运输层的端口 5.2用户数据报协议UDP5.2.1UDP概述5.2.2UDP的首部格式 5.3传输控制协议TCP概述5.3.1TCP最主要的特点5.3.2TCP的连接 5.4可靠传输的工作原…

Node中express路由基本使用

1.路由的基本使用 //引入express const expressrequire("express") //创建路由 const appexpress() //规定路由的请求方法 app.get(/,(req,res)>{res.end("hello express") }) //启动路由&#xff0c;并且端口为9000 app.listen(9000,()>{console.l…