Pytorch指定数据加载器使用子进程

torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,num_workers=4, pin_memory=True)

num_workers 参数是 DataLoader 类的一个参数,它指定了数据加载器使用的子进程数量。通过增加 num_workers 的数量,可以并行地读取和预处理数据,从而提高数据加载的速度。

通常情况下,增加 num_workers 的数量可以提高数据加载的效率,因为它可以使数据加载和预处理工作在多个进程中同时进行。然而,当 num_workers 的数量超过一定阈值时,增加更多的进程可能不会再带来更多的性能提升,甚至可能会导致性能下降。

这是因为增加 num_workers 的数量也会增加进程间通信的开销。当 num_workers 的数量过多时,进程间通信的开销可能会超过并行化所带来的收益,从而导致性能下降。

此外,还需要考虑到计算机硬件的限制。如果你的计算机 CPU 核心数量有限,增加 num_workers 的数量也可能会导致性能下降,因为每个进程需要占用 CPU 核心资源。

因此,对于 num_workers 参数的设置,需要根据具体情况进行调整和优化。通常情况下,一个合理的 num_workers 值应该在 2 到 8 之间,具体取决于你的计算机硬件配置和数据集大小等因素。在实际应用中,可以通过尝试不同的 num_workers 值来找到最优的配置。

综上所述,当 num_workers 的值从 4 增加到 8 时,如果你的计算机硬件配置和数据集大小等因素没有发生变化,那么两者之间的性能差异可能会很小,或者甚至没有显著差异。

测试代码如下

import torch
import torchvision
import matplotlib.pyplot as plt
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import time

if __name__ == '__main__':
    mp.freeze_support()
    train_on_gpu = torch.cuda.is_available()
    if not train_on_gpu:
        print('CUDA is not available. Training on CPU...')
    else:
        print('CUDA is available! Training on GPU...')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    batch_size = 4
    # 设置数据预处理的转换
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((512,512)),  # 调整图像大小为 224x224
        torchvision.transforms.ToTensor(),  # 转换为张量
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 归一化
    ])
    dataset = torchvision.datasets.ImageFolder('C:\\Users\\ASUS\\PycharmProjects\\pythonProject1\\cats_and_dogs_train',
                                                     transform=transform)


    val_ratio = 0.2
    val_size = int(len(dataset) * val_ratio)
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    train_dataset = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,num_workers=4, pin_memory=True)
    val_dataset = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True,num_workers=4, pin_memory=True)

    model = models.resnet18()

    num_classes = 2
    for param in model.parameters():
        param.requires_grad = False

    model.fc = nn.Sequential(
        nn.Dropout(),
        nn.Linear(model.fc.in_features, num_classes),
        nn.LogSoftmax(dim=1)
    )
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss().to(device)
    model.to(device)

    filename = "recognize_cats_and_dogs.pt"

    def save_checkpoint(epoch, model, optimizer, filename):
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
        }
        torch.save(checkpoint, filename)

    num_epochs = 3
    train_loss = []
    for epoch in range(num_epochs):
        running_loss = 0
        correct = 0
        total = 0
        epoch_start_time = time.time()
        for i, (inputs, labels) in enumerate(train_dataset):
            # 将数据放到设备上
            inputs, labels = inputs.to(device), labels.to(device)
            # 前向计算
            outputs = model(inputs)
            # 计算损失和梯度
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            # 更新模型参数
            optimizer.step()
            # 记录损失和准确率
            running_loss += loss.item()
            train_loss.append(loss.item())
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
        accuracy_train = 100 * correct / total
        # 在测试集上计算准确率
        with torch.no_grad():
            running_loss_test = 0
            correct_test = 0
            total_test = 0
            for inputs, labels in val_dataset:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                running_loss_test += loss.item()

                _, predicted = torch.max(outputs.data, 1)
                correct_test += (predicted == labels).sum().item()
                total_test += labels.size(0)
            accuracy_test = 100 * correct_test / total_test
            # 输出每个 epoch 的损失和准确率
        epoch_end_time = time.time()
        epoch_time = epoch_end_time - epoch_start_time
        print("Epoch [{}/{}], Time: {:.4f}s, Loss: {:.4f}, Train Accuracy: {:.2f}%, Loss: {:.4f}, Test Accuracy: {:.2f}%"
              .format(epoch + 1, num_epochs,epoch_time,running_loss / len(val_dataset),
                      accuracy_train, running_loss_test / len(val_dataset), accuracy_test))
        save_checkpoint(epoch, model, optimizer, filename)

    plt.plot(train_loss, label='Train Loss')
    # 添加图例和标签
    plt.legend()
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training Loss')

    # 显示图形
    plt.show()

不同num_workers的结果如下

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/103543.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

maven-default-http-blocker (http://0.0.0.0/): Blocked mirror for repositories

前言 略 说明 新设备上安装了mvn 3.8.5,编译新项目出错: [ERROR] Non-resolvable parent POM for com.admin.project:1.0: Could not transfer artifact com.extend.parent:pom:1.6.9 from/to maven-default-http-blocker (http://0.0.0.0/): Bl…

建联合作1000+达人,如何高效管理?

随着社交媒体的发展,达人营销已成为品牌营销重要的方式之一,甚至可以说是必选项。 对于很多品牌商家来说,一次合作几百个不同类型、不同社媒平台的达人,已屡见不鲜。在电商大促前、主推单品爆品时,同时合作上千个达人&…

jenkins实践篇(1)——基于分支的自动发布

问题背景 想起初来公司时,我们还是在发布机上直接执行发布脚本来运行和部署服务,并且正式环境和测试环境的脚本都在一起,直接手动操作脚本时存在比较大的风险就是将环境部署错误,并且当时脚本部署逻辑还没有检测机制,…

vscode Coder Runner 运行C++

1. 设置Code Runner 2. 防止输入读不到,把在终端运行勾上。 3. 设置minw/bin的环境变量 安装mingw教程:https://blog.csdn.net/fancy_male/article/details/133992000 4. 见图

数据结构:选择题+编程题(每日一练)

目录 选择题: 题一: 题二: 题三: 题四: 题五: 编程题: 题一:单值二叉树 思路一: 题二:二叉树的最大深度 思路一: 本人实力有限可能对…

Go学习第八章——面向“对象”编程(结构体与方法)

Go面向“对象”编程(结构体与方法) 1 结构体1.1 快速入门1.2 内存解析1.3 创建结构体四种方法1.4 注意事项和使用细节 2 方法2.1 方法的声明和调用2.2 快速入门案例2.3 调用机制和传参原理2.4 注意事项和细节2.5 方法和函数区别 3 工厂模式 Golang语言面…

服装服饰小程序商城的作用是什么

随着数字化转型节奏加快,各个行业都在加紧线上平台渠道的建设,而对于服装业来讲,在线商城无疑是**的选择,一方面可以去除平台的限制,摆脱佣金烦恼,还可搭建自有会员体系、私域流量池,持续转化变…

使用ruoyi框架遇到的问题修改记录

使用ruoyi框架遇到的问题修改记录 文章目录 使用ruoyi框架遇到的问题修改记录上传后文件名改变上传时设置单多文件及其他选项附件显示文件名,点击下载附件直接显示图片表格固定列查询数据库作为下拉选项值字典使用加入json递归注解,防止无限递归内存溢出…

从零开始学习wpsjs

1.这是一个简单的wpsjs学习文档,我是边学习wpsjs边记录学习的,希望对您的学习有所帮助 开发事项: 全局安装wpsjs:npm install -g wpsjsWpsjs create HelloWps 安装wps npm install -g wpsjs 新建一个wps加载项 输入命令wpsjs create HelloW…

塔式服务器介绍

大家都知道服务器分为机架式服务器、刀片式服务器、塔式服务器三类,今天小编就分别讲一讲这三种服务器,第三篇先来讲一讲塔式服务器的介绍。 塔式服务器定义:塔式服务器的外观和普通电脑差不多,直立放置。机箱比较大,服…

药房商城小程序便捷购药体验,找药买药一键操作

在我们的生活中,偶尔会遇到身体不适或需要常见药品的情况。然而,传统的购药方式往往会浪费大量的时间和精力。幸运的是,现在有了药房商城小程序,仅需一键操作,您就能享受到便捷、高效的购药体验。 药房商城小程序是用于…

聊聊分布式架构10——Zookeeper入门详解

目录 01ZooKeeper的ZAB协议 ZAB协议概念 ZAB协议基本模式 消息广播 崩溃恢复 选举出新的Leader服务器 数据同步 02Zookeeper的核心 ZooKeeper 的核心特点 ZooKeeper 的核心组件 选举算法概述 服务器启动时的Leader选举 服务器运行期间的Leader选举 03ZooKeeper的…

Netty核心源码剖析

Netty 线程模型 Netty高并发高性能架构设计精髓 主从Reactor线程模型NIO多路复用非阻塞无锁串行化设计思想支持高性能序列化协议零拷贝(直接内存的使用)ByteBuf内存池设计灵活的TCP参数配置能力并发优化 无锁串行化设计思想 在大多数场景下,并行多线程处理可以提…

安卓使用android studio跨进程通信之AIDL

我写这篇文章不想从最基础的介绍开始,我直接上步骤吧. 1.创建服务端 1.1:创建服务端项目:我的as版本比较高,页面就是这样的 1.2:创建AIDL文件,右键项目,选中aidl aidl名字可以自定义也可以默认 basicTypes是自带的,可以删掉,也可以不删,然后把你自己所需的接口写上去 1.3:创建…

CSS 两栏布局

目录 CSS两栏布局(左列定宽,右列自适应宽) 方法一:浮动margin 方法二:定位margin 方法三:浮动BFC 方法四:Flex布局 方法五:able布局 CSS两栏布局(左列不定宽&#…

uniapp 自定义导航栏

自定义导航栏 修改 pages.json 在 pages.json 中将 navigateionStyle 设为 custom 新建 systemInfo.js systemInfo.js 用来获取当前设备的机型系统信息,放在 common 目录下 /*** 此 js 文件管理关于当前设备的机型系统信息*/ const systemInfo function() {/***…

信息检索与数据挖掘 | 【实验】排名检索模型

文章目录 📚实验内容📚相关概念📚实验步骤🐇分词预处理🐇构建倒排索引表🐇计算query和各个文档的相似度🐇queries预处理及检索函数🔥对输入的文本进行词法分析和标准化处理&#x1f…

从0开始在Vscode中搭建Vue2/3项目详细步骤

1.安装node.js:Node.js下载安装及环境配置教程【超详细】_nodejs下载_WHF__的博客-CSDN博客 node.js自带npm,无需单独安装。 验证: node -v npm -v 2.先简单创建一个空文件夹,vscode进入该文件夹,并打开终端。 3.安装cnpm&…

【Gensim概念】03/3 NLP玩转 word2vec

第三部分 对象函数 八 word2vec对象函数 该对象本质上包含单词和嵌入之间的映射。训练后,可以直接使用它以各种方式查询这些嵌入。有关示例,请参阅模块级别文档字符串。 类型 KeyedVectors 1) add_lifecycle_event(event_name, log_level2…

OpenCV视频车流量识别详解与实践

视频车流量识别基本思想是使用背景消去算法将运动物体从图片中提取出来,消除噪声识别运动物体轮廓,最后,在固定区域统计筛选出来符合条件的轮廓。 基于统计背景模型的视频运动目标检测技术: 背景获取:需要在场景存在…