【深度学习】(7)--神经网络之保存最优模型

文章目录

  • 保存最优模型
    • 一、两种保存方法
      • 1. 保存模型参数
      • 2. 保存完整模型
    • 二、迭代模型
  • 总结

保存最优模型

我们在迭代模型训练时,随着次数初始的增多,模型的准确率会逐渐的上升,但是同时也随着迭代次数越来越多,由于模型会开始学习到训练数据中的噪声或非共性特征,发生过拟合现象,使得模型的准确率会上下震荡甚至于下降。

本篇就是介绍我们如何在进行那么多次迭代之中,找到训练最好效果时,模型的参数或完整模型。也方便以后使用模型时直接使用。

一、两种保存方法

我们知道,一个模型到底好不好,主要体现在对测试集数据结果上的表现,所以我们的方法主要从测试集入手,计算每次迭代测试集数据的准确率,取到准确率最大时对应的模型和参数

那么,我们该如何保存模型和参数呢?介绍一个小东西:

  • 文件拓展名pt\pth,t7,使用pt\pth或t7作为模型文件扩展名,保存模型的整个状态(包括模型架构和参数)或仅保存模型的参数(即状态字典,state_dict)。

1. 保存模型参数

方法

torch.save(model.state_dict(),path)
# model.state_dict()是一个从参数名称映射到参数张量的字典对象,它包含了模型的所有权重和偏置项
# path为创建的保存模型的文件

通过比较每一次迭代准确率的大小,取准确率最大时模型的参数

best_acc = 0
"""-----测试集-----"""
def test(dataloader,model,loss_fn):
    global best_acc
    size = len(dataloader.dataset) # 总数据大小
    num_batches = len(dataloader) # 划分的小批次数量
    model.eval()
    test_loss,correct = 0,0
    with torch.no_grad():
        for x,y in dataloader:
            x,y = x.to(device),y.to(device)
            pred = model.forward(x)
            test_loss += loss_fn(pred,y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item() # 预测正确的个数
    test_loss /= num_batches
    correct /= size
    correct = round(correct, 4)
    print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")

    # 保存最优模型的方法(文件扩展名一般:pt\pth,t7)
    if correct > best_acc:
        best_acc = correct
    # 1. 保存模型参数方法:torch.save(model.state_dict(),path)  (w,b)
        print(model.state_dict().keys()) # 输出模型参数名称cnn
        torch.save(model.state_dict(),"best.pth") 

2. 保存完整模型

方法

torch.save(model,path)
# 直接得到整个模型

依旧是通过比较每一次迭代准确率的大小,但是取准确率最大时的整个模型

def test(dataloader,model,loss_fn):
    global best_acc
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss,correct = 0,0
    with torch.no_grad():
        for x,y in dataloader:
            x,y = x.to(device),y.to(device)
            pred = model.forward(x)
            test_loss += loss_fn(pred,y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            a = (pred.argmax(1) == y)
            b = (pred.argmax(1) == y).type(torch.float)
    test_loss /= num_batches
    correct /= size
    correct = round(correct, 4)
    print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")

# 保存最优模型的方法(文件扩展名一般:pt\pth,t7)
    if correct > best_acc:
        best_acc = correct
    # 2. 保存完整模型(w,b,模型cnn)
        torch.save(model,"best1.pt")

二、迭代模型

接下来就要迭代模型,得到最优的模型:

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.001,weight_decay=0.0001)

epochs = 150
# training_data、test_data:数据预处理好的数据
train_dataloader = DataLoader(training_data,batch_size=64,shuffle=True)
test_dataloader = DataLoader(test_data,batch_size=64,shuffle=True)
for t in range(epochs):
    print(f"Epoch {t+1} \n-------------------------")
    train(train_dataloader,model,loss_fn,optimizer)
    test(test_dataloader,model,loss_fn)
print("Done!")

在每轮数据迭代后,project工程栏中的best1.ptbest.pth文件中模型会随着迭代及时更新,迭代结束后,文件中保存的就是最优模型以及最优的模型参数。

在这里插入图片描述

总结

本篇介绍了:

  1. 为什么随着迭代次数越来越多,模型的准确率会上下震荡甚至于下降。—> 过拟合
  2. pt\pth,t7三个扩展名,用于保存完整模型或者模型参数。
  3. 模型的好坏,通过体现在测试集的结果上。
  4. 保存最优模型的两种方法:保存模型参数和保存完整模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/886078.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

STM32 软件触发ADC采集

0.91寸OLED屏幕大小的音频频谱,炫酷! STM32另一个很少人知道的的功能——时钟监测 晶振与软件的关系(深度理解) STM32单片机一种另类的IO初始化方法 ADC是一个十分重要的功能,几乎任何一款单片机都会包含这个功能&a…

C++ 游戏开发

C游戏开发 C 是一种高效、灵活且功能强大的编程语言,因其性能和控制能力而在游戏开发中被广泛应用。许多著名的游戏引擎,如 Unreal Engine、CryEngine 和 Godot 等,都依赖于 C 进行核心开发。本文将详细介绍 C 在游戏开发中的应用&#xff0…

网页钓鱼---钓鱼网页的制作与危害上线

免责声明: 本文仅供了解攻击方攻击手法反制使用,切勿用于非法用途 前言 欢迎来到我的博客 个人主页:北岭敲键盘的荒漠猫-CSDN博客 基础能力要求 如果仅仅造个页面的话,不需要什么基础。但是如果想要让这个钓鱼页面能够真正实现攻击,得要求制…

Android 简单实现联系人列表+字母索引联动效果

效果如上图。 Main Ideas 左右两个列表左列表展示人员数据,含有姓氏首字母的 header item右列表是一个全由姓氏首字母组成的索引列表,点击某个item,展示一个气泡组件(它会自动延时关闭), 左列表滚动并显示与点击的索引列表item …

Meta首款多模态Llama 3.2开源:支持图像推理,还有可在手机上运行的版本 | LeetTalk Daily...

“LeetTalk Daily”,每日科技前沿,由LeetTools AI精心筛选,为您带来最新鲜、最具洞察力的科技新闻。 Meta最近推出的Llama Stack的发布标志着一个重要的里程碑。这一新技术的推出不仅为开发者提供了强大的多模态能力,还为企业和初…

基于单片机的多路温度检测系统

**单片机设计介绍,基于单片机CAN总线的多路温度检测系统设计 文章目录 前言概要功能设计设计思路 软件设计效果图 程序设计程序 前言 💗博主介绍:✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师,一名热衷于单片机技术探…

详细介绍:API 和 SPI 的区别

文章目录 Java SPI (Service Provider Interface) 和 API (Application Programming Interface) 的区别详解目录1. 定义和目的1.1 API (Application Programming Interface)1.2 SPI (Service Provider Interface) 2. 使用场景2.1 API 的应用场景2.2 SPI 的应用场景 3. 加载和调…

Python的异步编程

什么是协程? 协程不是计算机系统提供,程序员人为创造。 协程也可以被称为微线程,是一种用户态内的上下文切换技术。简而言之,其实就是通过一个线程实现代码块相互切换执行。 实现协程有那么几种方法: greenlet&…

Qt/C++ 解决调用国密SM3,SM4加密解密字符串HEX,BASE64格式转换和PKCS5Padding字符串填充相关问题

项目中遇到了需要与JAVA WEB接口使用SM3,SM4加密数据对接的需求,于是简单了解了下SM3与SM4加密算法在C环境下的实现。并使用Qt/C还原了在线SM3国密加密工具和在线SM4国密加密解密工具网页的示例功能的实现 目录导读 前言SM3算法简介SM4算法简介 实现示例字符串HEX,B…

慢病中医药膳养生食疗管理微信小程序、基于微信小程序的慢病中医药膳养生食疗管理系统设计与实现、中医药膳养生食疗管理微信小程序的开发与应用(源码+文档+定制)

博主介绍: ✌我是阿龙,一名专注于Java技术领域的程序员,全网拥有10W粉丝。作为CSDN特邀作者、博客专家、新星计划导师,我在计算机毕业设计开发方面积累了丰富的经验。同时,我也是掘金、华为云、阿里云、InfoQ等平台…

计算机网络:计算机网络体系结构 —— 专用术语总结

文章目录 专用术语实体协议服务服务访问点 SAP 服务原语 SP 协议数据单元 PDU服务数据单元 SDU 专用术语 实体 实体是指任何可以发送或接收信息的硬件或软件进程 对等实体是指通信双方处于相同层次中的实体,如通信双方应用层的浏览器进程和 Web 服务器进程。 协…

docker 部署 Seatunnel 和 Seatunnel Web

docker 部署 Seatunnel 和 Seatunnel Web 说明: 部署方式前置条件,已经在宿主机上运行成功运行文件采用挂载宿主机目录的方式部署SeaTunnel Engine 采用的是混合模式集群 编写Dockerfile并打包镜像 Seatunnel FROM openjdk:8 WORKDIR /opt/seatunne…

【在Linux世界中追寻伟大的One Piece】System V共享内存

目录 1 -> System V共享内存 1.1 -> 共享内存数据结构 1.2 -> 共享内存函数 1.2.1 -> shmget函数 1.2.2 -> shmot函数 1.2.3 -> shmdt函数 1.2.4 -> shmctl函数 1.3 -> 实例代码 2 -> System V消息队列 3 -> System V信号量 1 -> Sy…

基于两分支卷积和 Transformer 的轻量级多尺度特征融合超分辨率网络 !

当前的单图像超分辨率(SISR)算法有两种主要的深度学习模型,一种是基于卷积神经网络(CNN)的模型,另一种是基于Transformer的模型。前者利用不同卷积核大小的卷积层堆叠来设计模型,使得模型能够更…

OpenFeign微服务部署

一.开启nacos 和redis 1.查看nacos和redis是否启动 docker ps2.查看是否安装nacos和redis docker ps -a3.启动nacos和redis docker start nacos docker start redis-6379 docker ps 二.使用SpringSession共享例子 这里的两个例子在我的一个博客有创建过程&#xff0c…

rtmp协议转websocketflv的去队列积压

websocket server的优点 websocket server的好处:WebSocket 服务器能够实现实时的数据推送,服务器可以主动向客户端发送数据 1 不需要客户端不断轮询。 2 不需要实现httpserver跨域。 在需要修改协议的时候比较灵活,我们发送数据的时候比较…

Linux云计算 |【第四阶段】RDBMS1-DAY3

主要内容: 子查询(单行单列、多行单列、单行多列、多行多列)、分页查询limit、联合查询union、插入语句、修改语句、删除语句 一、子查询 子查询就是指的在一个完整的查询语句之中,嵌套若干个不同功能的小查询,从而一…

安宝特案例 | 某知名日系汽车制造厂,借助AR实现智慧化转型

案例介绍 在全球制造业加速数字化的背景下,工厂的生产管理与设备维护效率愈发重要。 某知名日系汽车制造厂当前面临着设备的实时监控、故障维护,以及跨地域的管理协作等挑战,由于场地分散和突发状况的不可预知性,传统方式已无法…

大模型部署——NVIDIA NIM 和 LangChain 如何彻底改变 AI 集成和性能

DigiOps与人工智能 人工智能已经从一个未来主义的想法变成了改变全球行业的强大力量。人工智能驱动的解决方案正在改变医疗保健、金融、制造和零售等行业的企业运营方式。它们不仅提高了效率和准确性,还增强了决策能力。人工智能的价值不断增长,这从它处…

Html 转为 MarkDown

在 RAG 中,通常需要将 HTML 转为 Markdown,有很多第三方 API 都支持 HTML 的转换,本文使用一个代码文档的例子 https://www.joinquant.com/help/api/help#name:Stock,将聚宽 API 转为 Markdown。本文通过两种方式进行实现,使用收费和开源的解决方案。聚宽 API 格式转为 Ma…