整体架构
平时使用 pytorch 加载数据时大概是这样的:
import numpy as np
from torch.utils.data import Dataset, DataLoader
class ExampleDataset(Dataset):
def __init__(self):
self.data = [1, 2, 3, 4, 5]
def __getitem__(self, idx):
return self.data[idx]
def __len__(self):
return len(self.data)
def collate_fn(batch):
return np.array(batch)
dataset = ExampleDataset() # create the dataset
dataloader = DataLoader(
dataset=dataset,
batch_size=2,
shuffle=True,
num_workers=4,
collate_fn=collate_fn
)
for datapoint in dataloader:
print(datapoint)
- 继承
Dataset
类,定义一个迭代器,包含两个魔法方法:__getitem__(self, idx)
和__len__(self)
,分别实现如何获取一条数据和如何设定数据长度; - 定义
collate_fn
函数,设定如何组织一个 batch; - 实例化
Dataset
,并和collate_fn
一起传入DataLoader
,参数batch_size
设置批大小、shuffle
设置是否打乱、num_workers
设置并行加载数据的进程数。
然而,背后到底干了什么,我们不清楚,甚至遇到 DataLoader
的如 sampler
、batch_sampler
、worker_init_fn
的其他参数,就会懵逼。那就看一看官方文档,了解一下 torch.utils.data
是如何工作的。
上图是数据加载的整体框架图,官网说 DataLoader
组合了 dataset
和 sampler
,多个 workers
根据 dataset
提供的数据副本和sampler
提供的 keys
并行地加载数据,并通过 collate_fn
组成 batch
供用户迭代。需要注意的有:
- 每个
worker
持有数据的一个副本,故占用内存 “主线程内存 * num_workers
”; - 即使用户不提供
sampler
对象 (通常不提供),DataLoader
也会根据shuffle
参数创建一个默认的sampler
对象;一旦提供了,其前路的shuffle
参数不能为True
(不提供就好); - 即使用户不提供
batch_sampler
对象 (通常不提供),DataLoader
也会根据batch_sampler, drop_last
参数创建一个默认的batch_sampler
对象;一旦提供了,其前路的shuffle, drop_last
不能为True
,batch_size
必须为 1 1 1,sampler
必须为 None,因为创建BatchSampler
时已经有了这些参数;
本质上是把创建batch_sampler
的活拉出来由用户在DataLoader
外自定义地做了。
Dataset
分为两种:map-style 和 iterable-style。前者的数据可通过 [idx or key]
访问,后者的数据只能通过迭代器 next
一个个访问。所以上面架构中的采样器是对于 map-style 数据集说的;iterable-style 的数据集的访问顺序由迭代器决定。
Sampler
torch.utils.data.Sampler
的子类或 Iterable
,两个例子:
class AccedingSequenceLengthSampler(tu_data.Sampler[int]):
def __init__(self, data: List[str]) -> None:
super().__init__()
self.data = data
def __len__(self) -> int:
return len(self.data)
def __iter__(self) -> Iterator[int]:
"""
:return: 实现了按数据长短顺序访问数据集
"""
sizes = torch.tensor([len(x) for x in self.data])
yield from torch.argsort(sizes).tolist()
class AccedingSequenceLengthBatchSampler(tu_data.Sampler[List[int]]):
def __init__(self, data: List[str], batch_size: int) -> None:
super().__init__()
self.data = data
self.batch_size = batch_size
def __len__(self) -> int:
return (len(self.data) + self.batch_size - 1) // self.batch_size
def __iter__(self) -> Iterator[List[int]]:
sizes = torch.tensor([len(x) for x in self.data])
for batch in torch.chunk(torch.argsort(sizes), len(self)): # 按块遍历
yield batch.tolist()
Batch
batch_sampler
提供一批下标,取得一批数据后由 collate_fn
将这批数据整合:
if collate_fn is None:
if self._auto_collation:
collate_fn = _utils.collate.default_collate
else: # self.batch_sampler is None: (batch_size is None) and (batch_sampler is None)
collate_fn = _utils.collate.default_convert
分两种情况:
- automatic batching is disabled:调用
default_convert
函数简单地将 NumPy arrays 转化为 PyTorch Tensor; - automatic batching is enabled:调用
default_collate
函数,转化会变得复杂一点:
from torch.utils import data as tu_data
import collections
# %% Example with a batch of `int`s:
tu_data.default_collate([0, 1, 2, 3])
# tensor([0, 1, 2, 3])
# %% Example with a batch of `str`s:
tu_data.default_collate(['a', 'b', 'c'])
# ['a', 'b', 'c']
# %% Example with `Map` inside the batch:
tu_data.default_collate([
{'A': 0, 'B': 1},
{'A': 100, 'B': 100}
])
# {'A': tensor([0, 100]), 'B': tensor([1, 100])}, 同 key 的合并了
# %% Example with `NamedTuple` inside the batch:
Point = collections.namedtuple('Point', ['x', 'y'])
tu_data.default_collate([Point(0, 0), Point(1, 1)])
# Point(x=tensor([0, 1]), y=tensor([0, 1])), 同 name 的合并了, 大概和 dict 一样吧
# %% Example with `Tuple` inside the batch:
tu_data.default_collate([(0, 1), (2, 3)])
# [tensor([0, 2]), tensor([1, 3])], 对 list 内部执行 collate
# %% Example with `List` inside the batch:
tu_data.default_collate([[0, 1], [2, 3]]) # [tensor([0, 2]), tensor([1, 3])], 对 list 内部执行 collate, 并没有变成二维 tensor
Multi-process Data Loading
dataset
, collate_fn
, and worker_init_fn
are passed to each worker,大概能说明 batch 是在子进程内部合成的。
有一个需要注意的地方是内存增长问题,当 __get_item__(self, key)
访问数据时,由于 Python 对象的 refcount 机制,数据会不断地复制,从而内存爆炸。但这里说解决 number of workers * size of parent process 问题,就不追究了,反正尽量用 numpy 或 pytorch tensor 吧。
iterable-style datasets 的随机性