Pytorch Webdataset初体验
最近都在用pytorch,虽然pytorch很多东西都比tensorflow舒服,但是在data pipeline
方面还是tensorflow比较有优势,缺乏一个紧凑压缩的record的读取方法,虽然可以用DALI,但是之前用了一下还是不够灵活。最近在pytorch博客中发现了一个Webdataset
,因此就尝试一下。
介绍
他的方法是将所有的样本压缩到tar文件中,使用名字作为样本的key
,比如样本A
可以包含A.jpg,A.json
等等,读取的时候根据key
一次性将所有的样本元素全部读取到dict
中,之后我们可以随意的map
,灵活性还是比较大的。
结论
经过测试之后,速度较以前的确有所提升,并且读取的速度比较稳定。不过也有几个不太方便的地方:
无法得知数据集的长度 因为是
tar
文件,构建数据集时无法得知整体长度,所以需要显式的指定。不像
tfrecord
,无法对一个tar
文件进行多线程读取。pytorch
中的dataloader
可以指定多个worker
进行读取,但是如果tar
文件没有进行分片的话,就不会起作用,必须要将tar
文件先进行分片才行。不过就算不分片,速度也比原来的多线程读取要快、要稳定。无法进行
concat
等等操作 这个没有办法,毕竟tensorflow
的dataset
也没有这个功能。
例子
制作分片的数据集
from pathlib import Path |
读取分片的数据集
def get_pattern_and_total_num(root, stage='train'): |