StreamingDataset

class streaming.StreamingDataset(local, remote=None, split=None, shuffle=False, predownload=100000, keep_zip=None, download_retry=2, download_timeout=60, validate_hash=None, shuffle_seed=9176, num_canonical_nodes=None, batch_size=None)[source]

A streaming PyTorch IterableDataset that is also resumable mid-epoch.

Checkpoints are represented in JSON as follows:

{
    'epoch': int,
    'sample_in_epoch': int,
    'shuffle_seed': int,
    'num_canonical_nodes': int,
}

Parameters
  • local (str) – Local dataset directory where shards are cached by split.

  • remote (str, optional) – Download shards from this remote path or directory. If None, this rank and worker’s partition of the dataset must all exist locally. Defaults to None.

  • split (str, optional) – Which dataset split to use, if any. Defaults to None.

  • shuffle (bool) – Whether to iterate over the samples in randomized order. Defaults to False.

  • predownload (int, optional) – Target number of samples ahead of the current sample whose shards to download while iterating. Defaults to 100_000.

  • keep_zip (bool, optional) – Whether to keep or delete the compressed file when decompressing downloaded shards. If set to None, keep if and only if remote is local. Defaults to None.

  • download_retry (int) – Number of download re-attempts before giving up. Defaults to 2.

  • download_timeout (float) – Number of seconds to wait for a shard to download before raising an exception. Defaults to 60.

  • validate_hash (str, optional) – Optional hash or checksum algorithm to use to validate shards. Defaults to None.

  • shuffle_seed (int) – Seed for deterministic data shuffling. Defaults to 9176.

  • num_canonical_nodes (int, optional) – Canonical number of nodes for shuffling with resumption. Defaults to None, which is interpreted as the number of nodes of the initial run.

  • batch_size (int, optional) – Batch size of the DataLoader using this dataset, which affects how the dataset is partitioned over the workers. Defaults to None.
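For orientation, here is a minimal usage sketch. The cache directory, bucket path, and batch size below are hypothetical placeholders, not values taken from this documentation:

from torch.utils.data import DataLoader

from streaming import StreamingDataset

# Hypothetical paths: substitute your own remote shard location and local cache.
dataset = StreamingDataset(local='/tmp/dataset-cache',
                           remote='s3://my-bucket/my-dataset',
                           split='train',
                           shuffle=True,
                           batch_size=32)

# Give the DataLoader the same batch size so partitioning lines up with loading.
loader = DataLoader(dataset, batch_size=32)

for batch in loader:
    ...  # training step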

load_state_dict(obj)[source]

Load a dict containing training state (called from non-worker process).

This is called on each copy of the dataset when resuming.

We simply save the state to shared memory for workers to pick up the next time __iter__ is called. We use shared memory because changes to this copy of the dataset would not be picked up by persistent workers.

Parameters

obj (Dict[str, Any]) – The state.
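As an illustrative sketch, resuming from a saved checkpoint might look like the following. The values are made up for the example and mirror the checkpoint format shown above:

# Hypothetical resumed state; in practice this comes from a saved checkpoint.
state = {
    'epoch': 2,
    'sample_in_epoch': 512,
    'shuffle_seed': 9176,
    'num_canonical_nodes': 8,
}

# The state is stashed in shared memory; workers apply it on the next __iter__.
dataset.load_state_dict(state)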

property next_epoch

Get property next_epoch.

Returns

int – Next epoch.

state_dict(num_samples, from_beginning)[source]

Get a dict containing training state (called from non-worker process).

This is called on rank zero.

Our stock StreamingDataLoader counts samples from the start of training (from_beginning=False). However, if you are always counting from the start of the epoch, set from_beginning=True.

Parameters
  • num_samples (int) – The number of samples processed so far in the current epoch.

  • from_beginning (bool) – Whether num_samples counts from the start of this epoch, as opposed to from the start of just this (potentially resumed) training run within the epoch.

Returns

Dict[str, Any] – The state.
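For example, a sketch of saving mid-epoch dataset state from the rank-zero process; samples_seen is a hypothetical counter maintained by your training loop:

import json

# Count samples since this run started iterating (the stock
# StreamingDataLoader convention), then persist the state as JSON.
state = dataset.state_dict(num_samples=samples_seen, from_beginning=False)
with open('dataset_state.json', 'w') as f:
    json.dump(state, f)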