Dataset#

class streaming.Dataset(local, remote=None, split=None, shuffle=True, prefetch=100000, keep_zip=None, retry=2, timeout=60, hash=None, batch_size=None)[source]#

A sharded, streamed, iterable dataset.

Parameters
  • local (str) – Local dataset directory where shards are cached by split.

  • remote (str, optional) – Download shards from this remote path or directory. If None, this rank and worker's partition of the dataset must all exist locally. Defaults to None.

  • split (str, optional) – Which dataset split to use, if any. Defaults to None.

  • shuffle (bool) – Whether to iterate over the samples in randomized order. Defaults to True.

  • prefetch (int, optional) – Target number of samples remaining to prefetch while iterating. Defaults to 100_000.

  • keep_zip (bool, optional) – Whether to keep or delete the compressed file when decompressing downloaded shards. If None, keep the compressed file if and only if remote is local. Defaults to None.

  • retry (int) – Number of download re-attempts before giving up. Defaults to 2.

  • timeout (float) – Number of seconds to wait for a shard to download before raising an exception. Defaults to 60.

  • hash (str, optional) – Hash or checksum algorithm to use to validate shards. Defaults to None.

  • batch_size (int, optional) – Hint of the batch size that will be used on each device's DataLoader. Defaults to None.
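
For example, to validate each shard against its recorded checksum and hint the per-device batch size, a minimal sketch (the hash name assumes the dataset was written with sha1 hashes, as in the write example below; the batch size is illustrative):
>>> from streaming import Dataset
>>> dataset = Dataset(local='dirname', hash='sha1', batch_size=32)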

To write the dataset:
>>> import numpy as np
>>> from PIL import Image
>>> from uuid import uuid4
>>> from streaming import MDSWriter
>>> dirname = 'dirname'
>>> columns = {
...     'uuid': 'str',
...     'img': 'jpeg',
...     'clf': 'int'
... }
>>> compression = 'zstd'
>>> hashes = ['sha1', 'xxh64']
>>> samples = [
...     {
...         'uuid': str(uuid4()),
...         'img': Image.fromarray(np.random.randint(0, 256, (32, 48, 3), np.uint8)),
...         'clf': np.random.randint(10),
...     }
...     for _ in range(1000)
... ]
>>> with MDSWriter(dirname, columns, compression, hashes) as out:
...     for sample in samples:
...         out.write(sample)
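
The writer produces an index file plus one or more shard files under dirname; a quick sanity check (the listing is illustrative and varies with the shard size limit):
>>> import os
>>> sorted(os.listdir(dirname))  # e.g. ['index.json', 'shard.00000.mds.zstd']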

To read the dataset:
>>> from streaming import Dataset
>>> dataset = Dataset(dirname)
>>> for sample in dataset:
...     print(sample)
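
Because Dataset is iterable, it can be fed straight to a PyTorch DataLoader; a minimal sketch (assumes torch is installed, and the batch size and worker count are illustrative):
>>> from torch.utils.data import DataLoader
>>> loader = DataLoader(dataset, batch_size=32, num_workers=8)
>>> for batch in loader:
...     pass  # train on the batch here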

To read the dataset (with all optional arguments):
>>> from streaming import Dataset
>>> dataset = Dataset(local=dirname, remote=None, split=None, shuffle=True,
...                   prefetch=100_000, keep_zip=None, retry=2, timeout=60, hash=None,
...                   batch_size=None)
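
When the shards live remotely, pass remote and they are downloaded into local on demand; a hedged sketch (the S3 path and split name are hypothetical, and cloud credentials must already be configured):
>>> dataset = Dataset(local='/tmp/cache', remote='s3://my-bucket/my-dataset',
...                   split='train', shuffle=True)
>>> sample = next(iter(dataset))
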
download(num_processes=None)[source]#

Load all shards, downloading if not local (blocking).

Parameters

num_processes (int, optional) – Number of concurrent shard downloads (i.e., size of the process pool). If None, uses the number of CPUs. Defaults to None.
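
A hedged usage sketch: load every shard up front with a small process pool (shards would be downloaded here if a remote had been given; the pool size is illustrative):
>>> dataset = Dataset(local=dirname)
>>> dataset.download(num_processes=4)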