Source code for streaming.base.shuffle.py1e

# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""Shuffling algorithm that shuffles by randomly placing shard samples in expanded ranges.

This algorithm has more balanced downloading and a lower minimum cache limit than ``py1b`` and
``py1br``, but also slightly lower shuffle quality. The range the samples from each shard can cover
is determined by ``shuffle_block_size``.
"""

import warnings

import numpy as np
from numpy.typing import NDArray

from streaming.base.shuffle.py1s import divide_spans


def get_shuffle_py1e(shard_sizes: NDArray[np.int64],
                     num_canonical_nodes: int,
                     seed: int,
                     epoch: int,
                     block_size: int = 1 << 18) -> NDArray[np.int64]:
    """Get the shuffled global ordering of samples for an epoch.

    The assignment of shards to nodes is fixed across epochs, but each grouping of shards is
    processed concurrently in a different order by each node's workers each epoch.

    Args:
        shard_sizes (NDArray[np.int64]): Number of samples contained in each shard, in order.
        num_canonical_nodes (int): Number of canonical nodes.
        seed (int): Base random seed, which is held constant over an entire training run.
        epoch (int): Current epoch, which is added to the seed to get a different deterministic
            shuffle each epoch.
        block_size (int): Unit of shuffle, used to set the bounds of the uniform shifts applied
            to each shard's sample positions. Defaults to ``1 << 18``.

    Returns:
        NDArray[np.int64]: 1:1 mapping of sample ID to shuffled sample ID.
    """
    # Create each shard's sample ID span (begin, end excl).
    spans = []
    num_samples = 0
    for shard_size in shard_sizes:
        span = num_samples, num_samples + shard_size
        spans.append(span)
        num_samples += shard_size

    # Generate the initial ordering of shards, which is fixed over an entire training run.
    run_rng = np.random.default_rng(seed)
    run_rng.shuffle(spans)

    # Break the shard spans at canonical node boundaries.
    # The super_spans are the indices of spans that correspond to each canonical node.
    spans, super_spans = divide_spans(spans, num_samples, num_canonical_nodes)

    # Shuffle the span ordering within each canonical node uniquely to this epoch.
    epoch_rng = np.random.default_rng(seed + epoch)
    for begin, end in super_spans:
        # Retrieve the spans (shard parts) associated with this canonical node.
        part = spans[begin:end]
        epoch_rng.shuffle(part)  # pyright: ignore
        spans[begin:end] = part

    # Populate the global sample ID mapping, shuffling within each span.
    ids = np.empty(num_samples, np.int64)
    offset = 0
    warn_user = False
    # Iterate through each canonical node's spans.
    # We don't want samples crossing canonical node boundaries.
    for cn_begin, cn_end in super_spans:
        cn_spans = spans[cn_begin:cn_end]
        cn_span_sizes = np.array([end - begin for begin, end in cn_spans])
        num_cn_samples = cn_span_sizes.sum()

        # The spans of a canonical node are shuffled, so their sample IDs are not contiguous.
        # Gather the correct sample IDs for the current canonical node, shuffling within each
        # span as we go.
        cn_samples = np.empty(num_cn_samples, np.int64)
        samples_inserted = 0
        for begin, end in cn_spans:
            # Insert this span's shuffled samples into the cn_samples array.
            cn_span_samples = np.arange(begin, end)
            epoch_rng.shuffle(cn_span_samples)
            cn_samples[samples_inserted:samples_inserted + (end - begin)] = cn_span_samples
            samples_inserted += (end - begin)

        # Iterate over each span and shift its sample positions by sampling from a uniform
        # distribution.
        cn_sample_offset = 0
        sample_positions = np.arange(num_cn_samples).astype(np.float64)
        for span_size in cn_span_sizes:
            # Sample the block size uniformly in a fixed range centered around block_size.
            # This helps ensure that when training across a large number of nodes, downloads
            # are more balanced.
            rand_block_size = epoch_rng.integers(int(0.75 * block_size), int(1.25 * block_size))

            # The maximum shift on each side of the span is (rand_block_size - span_size) / 2.
            # This ensures that the span's samples are only found in a window of at most
            # rand_block_size.
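            # For example (hypothetical numbers): with rand_block_size = 100 and span_size = 40,
            # the cutoff below is 30, so each sample may move up to 30 positions left or right,
            # and the whole span stays inside a 40 + 2 * 30 = 100 sample window.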
            cutoff = (rand_block_size - span_size) / 2

            # If cutoff is negative, the span size is greater than rand_block_size, so we set
            # cutoff to 0 (no shuffling for this span) and warn the user later.
            if cutoff < 0:
                cutoff = 0
                warn_user = True

            # Make sure the lower bound of the range doesn't cross the start of the canonical
            # node.
            lower_bound = max(-cutoff, -cn_sample_offset)
            # Make sure the upper bound of the range doesn't cross the end of the canonical node.
            upper_bound = min(cutoff, num_cn_samples - cn_sample_offset - span_size)
            # Sample shifts from a uniform distribution with the bounds calculated above.
            shifts = epoch_rng.uniform(low=lower_bound, high=upper_bound, size=span_size)

            # Add the shifts to this shard's sample positions.
            sample_positions[cn_sample_offset:cn_sample_offset + span_size] += shifts

            # Update the sample offset for the next shard.
            cn_sample_offset += span_size

        # Get the indices that would sort the sample_positions array.
        sort_indices = np.argsort(sample_positions)

        # Apply the sorting to the samples for our canonical node.
        cn_samples = cn_samples[sort_indices]

        # Assign the newly shuffled samples to the global ids array.
        ids[offset:offset + num_cn_samples] = cn_samples
        offset += num_cn_samples

    # If warn_user is true, the shift block size was smaller than some span (shard part) size,
    # which means that span was not shuffled with the other spans, so warn the user.
    if warn_user:
        warnings.warn('Shuffle block size was smaller than shard size for some shards. This '
                      'will result in these shards not being shuffled with other shards. Set '
                      'shuffle_block_size to a larger value for a higher quality shuffle.')

    return ids
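

# A minimal usage sketch (not part of the original module): the shard sizes, node count, seed,
# and block size below are arbitrary demo values, and running this file directly assumes the
# ``streaming`` package is installed so the ``divide_spans`` import above resolves.
if __name__ == '__main__':
    # Three shards of 4, 3, and 5 samples; a tiny block size makes the jitter visible at this
    # scale.
    demo_shard_sizes = np.array([4, 3, 5], dtype=np.int64)
    shuffled = get_shuffle_py1e(demo_shard_sizes,
                                num_canonical_nodes=2,
                                seed=42,
                                epoch=0,
                                block_size=8)
    # The result is a permutation of all 12 sample IDs, deterministic for a given seed and epoch.
    assert sorted(shuffled.tolist()) == list(range(12))
    print(shuffled)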