PyTorch 的 dataloader 如何读取变长数据

最近在做一个新的声学模型,其中遇到一个点就是每个sentence的长度不一样的花,直接用dataloader的读取是有问题的。查了下中文资料,大家大多数这个问题都是趋于用torch.nn.utils.rnn.PackedSequence来打包的,这个在dataloader里面其实就不太适用,pytorch论坛上提到用dataloader的collate_fn来处理的,所以想写个资料总结下 。

pytorch里面dataset的工作逻辑:

pytorch的数据载入主要是这么几个逻辑,从底层一步步来讲,我用h5矩阵,图片和音频三个方面来举例,首先是逻辑层次是,首先把data装进用torch.utils.data.Dataset装进一个dataset的对象里面,然后在把dataset这个对象传递给一个torch.utils.data.DataLoader

dataset的工作逻辑

数据集的切分一般在dataset这个对象上做处理,支持随机切分等,详见torch.utils.data - PyTorch master documentation,一般来讲,我都是写一个torch.utils.data.Dataset的子类,里面就三个成员函数,初始化,长度和读取,一般在读取你自己定义的读取方法,我习惯的是h5矩阵的话,就读一段(子矩阵),图片就是一张图,或者一段音频。

这里面有个很关键的点,就是dataset的逻辑是一次读一个item,最好不要在dataset层面一次slice一段,slice这个层面的事情交给dataloader来做,原因我一会说

记住dataset的逻辑在于装和item读取,预处理,其他都不要做。

dataloader的工作逻辑

dataloader层面主要就是slice读取数据,shuffle也是在这个层面来做。

dataloader有几个关键点,很多地方都零零碎碎的提到过,我总结下,

  1. 是稀松平常的batch_size, sampler, shuffle这几个稀松平常的不提,shuffle是在dataset的item层面做混洗,
  2. 注意,num_workers是一个多线程的读取,当batchsize>1的时候,多线程读取item, 然后各个item调用一个collate_fn合并成新的tensor,其中h5依然是个坑,anaconda安装的h5是不支持多线程的,请参考并行 HDF5 和 h5py安装并行h5,至于num*_*worker以及pin_memoru的具体使用,参考云梦:Pytorch 提速指南,不重复造轮子。
  3. 关于这个collate\fn是重点,当开启多线程了一个,多线程先后读取了dataset里面batch_size个item以后,生成了一个list,里面每个元素就是batchsize个item,然后用collatefn合并,如果没有指定的collatefns的话,就直接合并成一个高一维的tensor。

collatefns的工作逻辑

coolatefns的输入是个list,长度为batchsize,其中各个元素是各个item,函数的目的就是合并。

当各个item变长时,不指定collatefns合并就会报错,懒人方法就是把在dataset里面的读取函数把tensor加到最长,就可以直接merge。

当使用collatefns时,pytorch论坛上有人写了一个函数,我贴过来,大家配合注释看看:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def pad_tensor(vec, pad, dim):
"""
args:
vec - tensor to pad
pad - the size to pad to
dim - dimension to pad

return:
a new tensor padded to 'pad' in dimension 'dim'
"""
pad_size = list(vec.shape)
pad_size[dim] = pad - vec.size(dim)
return torch.cat([vec, torch.zeros(*pad_size)], dim=dim)


class PadCollate:
"""
a variant of callate_fn that pads according to the longest sequence in
a batch of sequences
"""

def __init__(self, dim=0):
"""
args:
dim - the dimension to be padded (dimension of time in sequences)
"""
self.dim = dim

def pad_collate(self, batch):
"""
args:
batch - list of (tensor, label)

reutrn:
xs - a tensor of all examples in 'batch' after padding
ys - a LongTensor of all labels in batch
"""
# find longest sequence
max_len = max(map(lambda x: x[0].shape[self.dim], batch))
# pad according to max_len
batch = map(lambda (x, y):
(pad_tensor(x, pad=max_len, dim=self.dim), y), batch)
# stack all
xs = torch.stack(map(lambda x: x[0], batch), dim=0)
ys = torch.LongTensor(map(lambda x: x[1], batch))
return xs, ys

def __call__(self, batch):
return self.pad_collate(batch)

调用使用:

1
train_loader = DataLoader(ds, ..., collate_fn=PadCollate(dim=0))

来源:DataLoader for various length of data

对于读取了以后的数据,在rnn中的工作逻辑,pytorch的文档也提到过

total_length is useful to implement the packsequence->recurrentnetwork->unpacksequence pattern in a Module wrapped in DataParallel. See this FAQ sectionfor details.

来源:torch.nn - PyTorch master documentation

关于读取到了的padding的变长数据,如何pack,请参考 @尹相楠 的:

尹相楠:PyTorch 训练 RNN 时,序列长度不固定怎么办?

本文转载自 知乎 Charlie的语音处理实验室

原文链接:https://zhuanlan.zhihu.com/p/60129684