StreamingDataset
- class streaming.StreamingDataset(*, local, remote=None, split=None, shuffle=False, predownload=100000, keep_zip=None, download_retry=2, download_timeout=60, validate_hash=None, shuffle_seed=9176, num_canonical_nodes=None, batch_size=None, partition_algo='orig', shuffle_algo='py2s')
A streaming PyTorch IterableDataset that is also resumable mid-epoch.
Checkpoints are represented in JSON as follows:
- {
      'epoch': int,
      'sample_in_epoch': int,
      'shuffle_seed': int,
      'num_canonical_nodes': int,
  }
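For example, a checkpoint taken partway through the second epoch might look like the following (the values are illustrative):

    {
        'epoch': 1,
        'sample_in_epoch': 5000,
        'shuffle_seed': 9176,
        'num_canonical_nodes': 8
    }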
- Parameters
  local (str) – Local dataset directory where shards are cached by split.
  remote (str, optional) – Download shards from this remote path or directory. If None, this rank and worker's partition of the dataset must all exist locally. Defaults to None.
  split (str, optional) – Which dataset split to use, if any. Defaults to None.
  shuffle (bool) – Whether to iterate over the samples in randomized order. Defaults to False.
  predownload (int, optional) – Target number of samples ahead to download the shards of while iterating. Defaults to 100_000.
  keep_zip (bool, optional) – Whether to keep or delete the compressed file when decompressing downloaded shards. If set to None, keep if and only if remote is local. Defaults to None.
  download_retry (int) – Number of download re-attempts before giving up. Defaults to 2.
  download_timeout (float) – Number of seconds to wait for a shard to download before raising an exception. Defaults to 60.
  validate_hash (str, optional) – Optional hash or checksum algorithm to use to validate shards. Defaults to None.
  shuffle_seed (int) – Seed for deterministic data shuffling. Defaults to 9176.
  num_canonical_nodes (int, optional) – Canonical number of nodes for shuffling with resumption. Defaults to None, which is interpreted as the number of nodes of the initial run.
  batch_size (int, optional) – Batch size of its DataLoader, which affects how the dataset is partitioned over the workers. Defaults to None.
  partition_algo (str) – Which partitioning algorithm to use. Defaults to 'orig'.
  shuffle_algo (str) – Which shuffling algorithm to use. Defaults to 'py2s'.
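A minimal construction sketch follows; the remote bucket path and local cache directory are illustrative:

    from torch.utils.data import DataLoader
    from streaming import StreamingDataset

    # Stream shards from remote storage, caching them locally
    # (paths are illustrative).
    dataset = StreamingDataset(local='/tmp/dataset_cache',
                               remote='s3://my-bucket/my-dataset',
                               split='train',
                               shuffle=True,
                               batch_size=32)

    # batch_size passed to the dataset should match the DataLoader's
    # batch size so samples are partitioned correctly over workers.
    loader = DataLoader(dataset, batch_size=32)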
- load_state_dict(obj)
Load a dict containing training state (called from non-worker process).
This is called on each copy of the dataset when resuming.
We just save the state to shared memory for workers to pick up when __iter__ is next called. We use shared memory because changes to this copy of the dataset would not be picked up by persistent workers.
- Parameters
obj (Dict[str, Any]) – The state.
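For example, resuming from a previously saved state might look like this (the file path is illustrative, and dataset is a StreamingDataset instance constructed as above):

    import json

    # Restore mid-epoch position before iterating again (path illustrative).
    with open('/tmp/dataset_state.json') as f:
        dataset.load_state_dict(json.load(f))

    # The saved position takes effect the next time __iter__ is called.
    for sample in dataset:
        ...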
- property next_epoch
Get property next_epoch.
- Returns
int – Next epoch.
- state_dict(num_samples, from_beginning)
Get a dict containing training state (called from non-worker process).
This is called on rank zero.
Our stock StreamingDataLoader counts samples from the start of training (from_beginning=False). However, if you are always counting from the start of the epoch, set from_beginning=True.
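A minimal checkpointing sketch on rank zero; the sample count and file path are illustrative:

    import json

    # Capture the dataset's position after, say, 5000 samples seen since
    # the start of training (num_samples is illustrative).
    state = dataset.state_dict(num_samples=5000, from_beginning=False)

    # Persist alongside the model checkpoint for later resumption via
    # load_state_dict (path illustrative).
    with open('/tmp/dataset_state.json', 'w') as f:
        json.dump(state, f)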