前言
本文主要介绍pytorch里面批数据的处理方法,以及这个算法的效果是什么样的。具体就是要弄明白这个批数据选取的算法是在干什么,不会涉及到网络的训练。
from torch.utils.data import DataLoader, TensorDataset
主要实现就是上面的数据集和数据载入两个类来实现该算法功能,这里只要求会调用接口就够了。
一、生成数据集
import torch
from torch.utils.data import DataLoader, TensorDataset
# 准备数据集与定义batch_size
batch_size = 8
x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)
print(x)
print(y)
输出:
二、将训练数据进行batch处理
# 将训练数据放入torch的数据集
train_dataset = TensorDataset(x, y)
# 载入batch批次选取数据规则
train_loader = DataLoader(train_dataset,
batch_size=batch_size,
shuffle=True, # True表示每一个epoch都打乱抽取
num_workers=2 # 定义工作线程个数
)
三、epoch训练
# 训练模型
epochs = 3
for epoch in range(epochs):
# 每一个epoch表示将整个数据集所有数据都训练一遍
for step,(batch_x, batch_y) in enumerate(train_loader):
# training......
# 这里用enumerate是为了让你更加情况观察,batch的逻辑是怎么样的
# 实际中只要 for batch_x,batch_y in train_loader就可以了
print('Epoch:',epoch,'| Step:',step,'| batch x:',batch_x.data.numpy(),'| batch y:',batch_y.data.numpy())
# 测试模型(略)
输出:
【注】:可以看到每一个epoch将所有样本点都涉及到了一次,并且还是打乱顺序了的。
下面看看将shuffle=False不打乱顺序会发生什么:
【注】:可以看到每一个epoch,都是相同的结果,可想而知这样训练效果肯定没有打乱的好。
注意到,上半batch=5,恰好将样本总数10均分为2分,那么要是不能均分会发生什么,下面将batch=8,看看会发生什么。
可以看到直接将不够的组就直接剩下的了。
总结
后面我们会经常用到这种batch和epoch的训练方法。