模型训练-保存训练数据

1.目的

找到一个可运行的代码,可以每个epoch打印训练数据,但是不会保存。因为在改进模型需要这些训练数据进行对比,所以需要将每个epoch的训练数据保存下来,写到一个文件中。

2.解决方案

直接问ChatGPT,提示词如下:

做一个深度学习的图像分类任务,现在需要将每个epoch的训练数据存到一个文件,用python来实现

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import json

# 定义一个简单的网络
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(32 * 28 * 28, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

# 设置超参数
batch_size = 64
num_epochs = 10
learning_rate = 0.01

# 创建数据集和数据加载器
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='data/', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# 创建网络、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# 文件名
output_file = "training_data.json"

# 开始训练
training_data = []

for epoch in range(num_epochs):
    epoch_loss = 0.0
    correct = 0
    total = 0

    # 训练一个 epoch
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

    # 计算平均损失和准确度
    avg_loss = epoch_loss / total
    accuracy = correct / total

    # 将 epoch 的数据添加到列表中
    epoch_data = {
        "epoch": epoch + 1,
        "avg_loss": avg_loss,
        "accuracy": accuracy
    }
    training_data.append(epoch_data)

    # 打印 epoch 的训练结果
    print(f"Epoch {epoch + 1}: Avg Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

# 将所有 epoch 的训练数据保存到文件中
with open(output_file, 'w') as f:
    json.dump(training_data, f)

print(f"Training data saved to {output_file}")

由于缺少数据集,也懒得去找一个数据集,就将自己py文件中训练部分摘出来给ChatGPT,让它在训练方法中实现这个功能。 

def train():
    data_format(os.path.join(config.root_path, './data/train.txt'),
    os.path.join(config.root_path, './data/data'), os.path.join(config.root_path, './data/train.json'))
    data = read_from_file(config.train_data_path, config.data_dir, config.only)
    train_data, val_data = train_val_split(data)
    train_loader = processor(train_data, config.train_params)
    val_loader = processor(val_data, config.val_params)

    best_acc = 0
    epoch = config.epoch
    for e in range(epoch):
        print('-' * 20 + ' ' + 'Epoch ' + str(e+1) + ' ' + '-' * 20)
        # 训练模型
        tloss, tloss_list = trainer.train(train_loader)
        print('Train Loss: {}'.format(tloss))
        # writer.add_scalar('Training/loss', tloss, e)
        # 验证模型
        vloss, vacc = trainer.valid(val_loader)
        print('Valid Loss: {}'.format(vloss))
        print('Valid Acc: {}'.format(vacc))
        # writer.add_scalar('Validation/loss', vloss, e)
        # writer.add_scalar('Validacc/acc', vacc, e)
        # 保存训练数据
        training_data = {
            "epoch": e + 1,
            "train_loss": tloss,
            "valid_loss": vloss,
            "valid_acc": vacc
        }
        with open('training_data.json', 'a') as f:
            json.dump(training_data, f)
            f.write('\n')
        print("数据保存完成")
        # 保存最佳模型
        if vacc > best_acc:
            best_acc = vacc
            save_model(config.output_path, config.fuse_model_type, model)
            print('Update best model!')


     
    print('-' * 20 + ' ' + 'Training Finished' + ' ' + '-' * 20)
    print('Best Validation Accuracy: {}'.format(best_acc))

在我的代码中具体加入的是下列几行代码

# 保存训练数据
training_data = {
    "epoch": e + 1,
    "train_loss": tloss,
    "valid_loss": vloss,
    "valid_acc": vacc
}
with open('training_data.json', 'a') as f:
    json.dump(training_data, f)
    f.write('\n')
print("数据保存完成")

 

代码意思如下: 

  1. with open('training_data.json', 'a') as f:: 打开名为 'training_data.json' 的文件,以追加模式 'a',并将其赋给变量 f。如果文件不存在,将会创建一个新文件。
  2. json.dump(training_data, f): 将变量 training_data 中的数据以 JSON 格式写入到文件 f 中。这个操作会将 training_data 中的内容转换成 JSON 格式,并写入到文件中。
  3. f.write('\n'): 写入一个换行符 \n 到文件 f 中,确保每次写入 JSON 数据后都有一个新的空行,使得每个 JSON 对象都独占一行,便于后续处理。

这段代码的作用是将变量 training_data 中的数据以 JSON 格式写入到文件 'training_data.json' 中,并确保每次写入后都有一个换行符分隔。

3.结果

可以在每个epoch训练完成后,将训练损失,验证损失和验证准确率保存在training_data.json文件中。

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/535538.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

AtCoder ABC248 A-D题解

比赛链接:ABC348 Problem A: 签到。 #include <bits/stdc.h> using namespace std; int main(){int N;cin>>N;for(int i1;i<N;i){if(i%30)cout<<x<<endl;elsecout<<o<<endl;}return 0; } Problem B: 枚举即可。 #include <bit…

51蓝桥杯之DS18B20

DS18B20 基础知识 代码流程实现 将官方提供例程文件添加到工程中 添加onewire.c文件到keil4里面 一些代码补充知识 代码 #include "reg52.h" #include "onewire.h" #include "absacc.h" unsigned char num[10]{0xc0,0xf9,0xa4,0xb0,0x99,…

【域适应】基于深度域适应MMD损失的典型四分类任务实现

关于 MMD &#xff08;maximum mean discrepancy&#xff09;是用来衡量两组数据分布之间相似度的度量。一般地&#xff0c;如果两组数据分布相似&#xff0c;那么MMD 损失就相对较小&#xff0c;说明两组数据/特征处于相似的特征空间中。基于这个想法&#xff0c;对于源域和目…

24/04/11总结

IO流(First edition): IO流&#xff1a;用于读入写出文件中的数据 流的方向&#xff08;输入指拿出来,输出指写进去) 输入流:读取 输出流:写出 操作文件类型 字节流:所有类型文件 字符流:纯文本 字节流: InputStream的子类:FileInputStream:操作本地文件的字节输入流 OutputSt…

工作流引擎常见API(以camunda为例)

在Camunda中&#xff0c;API的继承关系主要体现在各个服务接口之间。以下是Camunda中一些常见服务接口的继承关系&#xff1a; ProcessEngineServices 接口&#xff1a; RepositoryService&#xff1a; 负责管理流程定义和部署。RuntimeService&#xff1a; 负责管理流程实例的…

2023年通用人工智能AGI等级保护白皮书

今天分享的是人工智能专题系列深度研究报告&#xff1a;《人工智能专题&#xff1a;2023年通用人工智能AGI等级保护白皮书》。 通用人工智能发展现状 本章主要介绍通用人工智能的基本情况&#xff0c;包括其发展历史、现状以及组成架构等内容。本文还将通过从技术角度出发来分…

PTA 应急救援站选址(floyd+打印路径)

大学城虎溪社区有很多居民小区&#xff0c;居民小区道路图是连通的。现要在该社区新建一个应急救援站&#xff0c;且该应急救援站要和某个小区建在一起。为了使应急救援最快速&#xff0c;经各部门商量决定&#xff1a;应急救援站建好后&#xff0c;离应急救援站最远的小区到应…

大话设计模式之命令模式

命令模式是一种行为型设计模式&#xff0c;它将请求或操作封装成一个对象&#xff0c;从而允许客户端参数化操作。这意味着客户端将一个请求封装为一个对象&#xff0c;这样可以将请求的参数化、队列化和记录日志&#xff0c;以及支持可撤销的操作。 命令模式主要由以下几个角…

kaggle 泰坦尼克号1(根据男女性存活率)

kaggle竞赛 泰坦尼克号 流程 下载kaggle数据集导入所要使用的包引入kaggle的数据集csv文件查看数据集的大小和长度去除冗余数据建立特征工程导出结果csv文件 1.下载kaggle数据集 2.导入所要使用的包 import pandas as pd import numpy as np import matplotlib.pyplot as …

leetcode328.奇偶链表

1. 题目描述 在线练习 2. 解题思路 这道题&#xff0c;官方给的是中等难度。其实是一道基础题&#xff0c;大家应该都可以写得出来。 题目中给的示例可以清楚的看到&#xff0c;合并前后的奇偶链的各自包含的节点的顺序是不变的&#xff0c;我们基本可以确定使用尾插法来合并…

Ansys Mechanical | 软件介绍:业界一流的有限元求解器

Ansys Mechanical 有限元分析软件 Ansys Mechanical 是业界一流的有限元求解器&#xff0c;具有结构、热学、声学、瞬态和非线性功能&#xff0c;可帮助改进建模。 ​ 软件概览 Ansys Mechanical 创建了一个使用有限元仿真分析软件&#xff08;FEA&#xff09;进行结构分析…

猝不及防 CCF-B ICPP 2024投稿延期至4月22日提交摘要 机会来了别错过

会议之眼 快讯 第53届ICPP&#xff08;International Conference on Parallel Processing&#xff09;即国际并行处理会议将于 2024年 8月12日-15日在瑞典哥特兰岛举行&#xff01;ICPP是世界上最古老的连续举办的并行计算计算机科学会议之一。它是学术界、工业界和政府的研究…

欢迎加入PenPad Season 2 ,获得勋章以及海量 Scroll 生态权益

PenPad 是 Scroll 生态中的首个 LaunchPad 平台&#xff0c;该平台继承了 Scroll 生态的技术优势&#xff0c;具备包括隐私在内的系列特点&#xff0c;同时且也被认为是 Scroll 生态最重要的价值入口之一。Penpad 与 Scroll 官方始终保持着合作&#xff0c;同时该项目自启动以来…

你一定不能错过的多模态大模型!阿里千问开源Qwen-VL!具备图文解读等能力

1. Qwen-VL简介 1.1. 介绍 Qwen-VL的多语言视觉语言模型系列,基于Qwen-7B语言模型。该模型通过视觉编码器和位置感知的视觉语言适配器,赋予语言模型视觉理解能力。 Qwen-VL采用了三阶段的训练流程,并在多个视觉语言理解基准测试中取得了领先的成绩。该模型支持多语言、多图…

办公室电脑监控软件哪个最好用

办公室电脑监控软件哪个最好用 办公室监控软件主要用于帮助企业管理员监控员工在工作时间内的电脑使用情况&#xff0c;以提高工作效率、保障数据安全、遵守合规要求和维护良好的工作秩序。以下是一些推荐的办公室监控软件。 1、安企神 (1) 强大的监控功能&#xff1a;域智盾…

【文献分享】机器学习 + 分子动力学(LAMMPS 输入文件)+ 第一性原理 + 热学性质 + 动力学性质

分享一篇关于机器学习 分子动力学 第一性原理 热学性质 动力学性质的文章。 感谢论文的原作者&#xff01; 关键词&#xff1a; 1. Machine learning, 2. Deep potential, 3. Molecular dynamics 4. Molten salts 5. Thermophysical properties 6. Phase diagram 主…

并查集加训

1.模板 #include<iostream> using namespace std; const int N 1e4 10; int p[N]; int n, m;int fd(int x){if(x ! p[x]){p[x] fd(p[x]);}return p[x]; }int main(){scanf("%d%d", &n, &m);for(int i 1; i < n; i){p[i] i;}int z, x, y;while(…

nvm更新node版本

1、nvm安装和管理多个 Node.js 版本&#xff1a;NVM 允许用户在计算机上同时安装多个不同版本的 Node.js。这使得开发人员可以轻松地在不同的项目中使用不同的 Node.js 版本&#xff0c;而无需手动安装或卸载。 2、nvm切换 Node.js 版本&#xff1a;通过 NVM&#xff0c;用户可…

一辆新能源汽车需要多少颗传感器?

随着科技的发展和环保意识的日益提高&#xff0c;新能源汽车&#xff08;包括纯电动汽车、混合动力汽车等&#xff09;在全球范围内越来越受到欢迎。这些汽车不仅减少了碳排放&#xff0c;还推动了汽车产业的创新。然而&#xff0c;这些高科技汽车的背后&#xff0c;隐藏着许多…

git lfs 大文件管理

简介 git-lfs 是 Git Large File Storage 的缩写&#xff0c;是 Git 的一个扩展&#xff0c;用于处理大文件的版本控制。 它允许你有效地管理和存储大型二进制文件&#xff0c;而不会使 Git 仓库变得过大和不稳定。以下是一些与 git-lfs 相关的常见命令和解释&#xff1a; 常…