Source code for streaming.text.pile

# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""The Pile.

The Pile is an 825 GiB diverse, open-source language modelling dataset that consists of 22
smaller, high-quality datasets combined together.
"""

from typing import Any, Dict, Optional

from transformers.models.auto.tokenization_auto import AutoTokenizer

from streaming.base import StreamingDataset

__all__ = ['StreamingPile']


class StreamingPile(StreamingDataset):
    """Implementation of the Pile using StreamingDataset.

    Args:
        tokenizer_name (str): The name of the HuggingFace tokenizer to use to tokenize samples.
        max_seq_len (int): The max sequence length of each token sample.
        group_method (str): How to group text samples into token samples. Currently only
            supporting ``'truncate'``.
        local (str): Local dataset directory where shards are cached by split.
        remote (str, optional): Download shards from this remote path or directory. If None, this
            rank and worker's partition of the dataset must all exist locally. Defaults to
            ``None``.
        split (str, optional): Which dataset split to use, if any. Defaults to ``None``.
        shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
            ``False``.
        predownload (int, optional): Target number of samples ahead to download the shards of
            while iterating. Defaults to ``100_000``.
        keep_zip (bool, optional): Whether to keep or delete the compressed file when
            decompressing downloaded shards. If set to None, keep iff remote is local. Defaults
            to ``None``.
        download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
        download_timeout (float): Number of seconds to wait for a shard to download before
            raising an exception. Defaults to ``60``.
        validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
            shards. Defaults to ``None``.
        shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``.
        num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
            resumption. Defaults to ``None``, which is interpreted as the number of nodes of
            the initial run.
        batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset
            is partitioned over the workers. Defaults to ``None``.
    """

    def __init__(self,
                 tokenizer_name: str,
                 max_seq_len: int,
                 group_method: str,
                 local: str,
                 remote: Optional[str] = None,
                 split: Optional[str] = None,
                 shuffle: bool = False,
                 predownload: Optional[int] = 100_000,
                 keep_zip: Optional[bool] = None,
                 download_retry: int = 2,
                 download_timeout: float = 60,
                 validate_hash: Optional[str] = None,
                 shuffle_seed: int = 9176,
                 num_canonical_nodes: Optional[int] = None,
                 batch_size: Optional[int] = None) -> None:
        if group_method not in ['truncate']:
            raise ValueError('Only group_method="truncate" is supported at this time.')
        super().__init__(local, remote, split, shuffle, predownload, keep_zip, download_retry,
                         download_timeout, validate_hash, shuffle_seed, num_canonical_nodes,
                         batch_size)
        self.tokenizer_name = tokenizer_name
        self.max_seq_len = max_seq_len
        self.group_method = group_method

        # Build tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
        if self.tokenizer.pad_token is None:
            # Some tokenizers (e.g. GPT2 tokenizer) have no padding token which causes bugs
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _tokenize(self, text_sample: Dict[str, Any]):
        """Apply the tokenizer to a sample.

        Args:
            text_sample (Dict[str, Any]): Sample to tokenize.

        Returns:
            The tokenized sample.
        """
        if self.group_method == 'truncate':
            # Truncate to max_seq_len and pad shorter samples up to that length.
            truncation = True
            padding = 'max_length'
            max_length = self.max_seq_len
        else:
            truncation = False
            padding = False
            max_length = None
        return self.tokenizer(text_sample['text'],
                              truncation=truncation,
                              padding=padding,
                              max_length=max_length)

    def __getitem__(self, idx: int) -> Any:
        """Get sample by global index, blocking to load its shard if missing.

        Args:
            idx (int): Sample index.

        Returns:
            Any: Sample data.
        """
        text_sample = super().__getitem__(idx)
        token_sample = self._tokenize(text_sample)
        # Skip any token grouping, currently only supporting group_method='truncate'
        return token_sample