最近在做一个新的声学模型,其中遇到一个点就是每个sentence的长度不一样的花,直接用dataloader的读取是有问题的。查了下中文资料,大家大多数这个问题都是趋于用torch.nn.utils.rnn.PackedSequence
来打包的,这个在dataloader里面其实就不太适用,pytorch论坛上提到用dataloader的collate_fn来处理的,所以想写个资料总结下 。
pytorch里面dataset的工作逻辑:
pytorch的数据载入主要是这么几个逻辑,从底层一步步来讲,我用h5矩阵,图片和音频三个方面来举例,首先是逻辑层次是,首先把data装进用torch.utils.data.Dataset
装进一个dataset的对象里面,然后在把dataset这个对象传递给一个torch.utils.data.DataLoader
dataset的工作逻辑
数据集的切分一般在dataset这个对象上做处理,支持随机切分等,详见torch.utils.data - PyTorch master documentation,一般来讲,我都是写一个torch.utils.data.Dataset的子类,里面就三个成员函数,初始化,长度和读取,一般在读取你自己定义的读取方法,我习惯的是h5矩阵的话,就读一段(子矩阵),图片就是一张图,或者一段音频。
这里面有个很关键的点,就是dataset的逻辑是一次读一个item,最好不要在dataset层面一次slice一段,slice这个层面的事情交给dataloader来做,原因我一会说。
记住dataset的逻辑在于装和item读取,预处理,其他都不要做。
dataloader的工作逻辑
dataloader层面主要就是slice读取数据,shuffle也是在这个层面来做。
dataloader有几个关键点,很多地方都零零碎碎的提到过,我总结下,
- 是稀松平常的batch_size, sampler, shuffle这几个稀松平常的不提,shuffle是在dataset的item层面做混洗,
- 注意,num_workers是一个多线程的读取,当batchsize>1的时候,多线程读取item, 然后各个item调用一个collate_fn合并成新的tensor,其中h5依然是个坑,anaconda安装的h5是不支持多线程的,请参考并行 HDF5 和 h5py安装并行h5,至于num*_*worker以及pin_memoru的具体使用,参考云梦:Pytorch 提速指南,不重复造轮子。
- 关于这个collate\fn是重点,当开启多线程了一个,多线程先后读取了dataset里面batch_size个item以后,生成了一个list,里面每个元素就是batchsize个item,然后用collatefn合并,如果没有指定的collatefns的话,就直接合并成一个高一维的tensor。
collatefns的工作逻辑
coolatefns的输入是个list,长度为batchsize,其中各个元素是各个item,函数的目的就是合并。
当各个item变长时,不指定collatefns合并就会报错,懒人方法就是把在dataset里面的读取函数把tensor加到最长,就可以直接merge。
当使用collatefns时,pytorch论坛上有人写了一个函数,我贴过来,大家配合注释看看:
1 | def pad_tensor(vec, pad, dim): |
调用使用:
1 | train_loader = DataLoader(ds, ..., collate_fn=PadCollate(dim=0)) |
来源:DataLoader for various length of data
对于读取了以后的数据,在rnn中的工作逻辑,pytorch的文档也提到过
total_length
is useful to implement thepacksequence->recurrentnetwork->unpacksequence
pattern in aModule
wrapped inDataParallel
. See this FAQ sectionfor details.
来源:torch.nn - PyTorch master documentation
关于读取到了的padding的变长数据,如何pack,请参考 @尹相楠 的:
尹相楠:PyTorch 训练 RNN 时,序列长度不固定怎么办?
本文转载自 知乎 Charlie的语音处理实验室