# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0
"""The Pile.
The Pile is an 825 GiB diverse, open source language modelling data set that consists of 22 smaller,
high-quality datasets combined together.
"""
from typing import Any, Dict, Optional
from transformers.models.auto.tokenization_auto import AutoTokenizer
from streaming.base import StreamingDataset
__all__ = ['StreamingPile']
class StreamingPile(StreamingDataset):
    """Implementation of the Pile using StreamingDataset.

    Each raw text sample is tokenized on access with a HuggingFace tokenizer, then truncated and
    padded to ``max_seq_len`` tokens.

    Args:
        tokenizer_name (str): The name of the HuggingFace tokenizer to use to tokenize samples.
        max_seq_len (int): The max sequence length of each token sample.
        group_method (str): How to group text samples into token samples. Currently only supporting
            ``'truncate'``.
        local (str): Local dataset directory where shards are cached by split.
        remote (str, optional): Download shards from this remote path or directory. If None, this
            rank and worker's partition of the dataset must all exist locally. Defaults to
            ``None``.
        split (str, optional): Which dataset split to use, if any. Defaults to ``None``.
        shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
            ``False``.
        predownload (int, optional): Target number of samples ahead to download the shards of while
            iterating. Defaults to ``100_000``.
        keep_zip (bool, optional): Whether to keep or delete the compressed file when
            decompressing downloaded shards. If set to None, keep iff remote is local. Defaults to
            ``None``.
        download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
        download_timeout (float): Number of seconds to wait for a shard to download before raising
            an exception. Defaults to ``60``.
        validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
            shards. Defaults to ``None``.
        shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``.
        num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
            resumption. Defaults to ``None``, which is interpreted as the number of nodes of the
            initial run.
        batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
            partitioned over the workers. Defaults to ``None``.

    Raises:
        ValueError: If ``group_method`` is anything other than ``'truncate'``.
    """

    def __init__(self,
                 tokenizer_name: str,
                 max_seq_len: int,
                 group_method: str,
                 local: str,
                 remote: Optional[str] = None,
                 split: Optional[str] = None,
                 shuffle: bool = False,
                 predownload: Optional[int] = 100_000,
                 keep_zip: Optional[bool] = None,
                 download_retry: int = 2,
                 download_timeout: float = 60,
                 validate_hash: Optional[str] = None,
                 shuffle_seed: int = 9176,
                 num_canonical_nodes: Optional[int] = None,
                 batch_size: Optional[int] = None) -> None:
        # Reject unsupported grouping up front, before the base class is initialized.
        if group_method not in ('truncate',):
            raise ValueError('Only group_method="truncate" is supported at this time.')

        super().__init__(local, remote, split, shuffle, predownload, keep_zip, download_retry,
                         download_timeout, validate_hash, shuffle_seed, num_canonical_nodes,
                         batch_size)
        self.tokenizer_name = tokenizer_name
        self.max_seq_len = max_seq_len
        self.group_method = group_method

        # Build tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
        if self.tokenizer.pad_token is None:
            # Some tokenizers (e.g. GPT2 tokenizer) have no padding token which causes bugs;
            # ``_tokenize`` pads to ``max_length``, so fall back to the EOS token.
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _tokenize(self, text_sample: Dict[str, Any]) -> Any:
        """Apply the tokenizer to a sample.

        Args:
            text_sample (Dict[str, Any]): Sample to tokenize. Its ``'text'`` field is what gets
                passed to the tokenizer.

        Returns:
            Any: Whatever ``self.tokenizer`` returns for the sample's text.
        """
        if self.group_method == 'truncate':
            # Cut and pad every sample to exactly ``max_seq_len`` tokens.
            truncation = True
            padding = 'max_length'
            max_length = self.max_seq_len
        else:
            # Not reachable through ``__init__`` validation; only hit if ``group_method`` is
            # reassigned afterwards. Tokenize without truncation or padding in that case.
            truncation = False
            padding = False
            max_length = None
        return self.tokenizer(text_sample['text'],
                              truncation=truncation,
                              padding=padding,
                              max_length=max_length)

    def __getitem__(self, idx: int) -> Any:
        """Get sample by global index, blocking to load its shard if missing.

        The raw sample is fetched from the parent class and then tokenized.

        Args:
            idx (int): Sample index.

        Returns:
            Any: Sample data.
        """
        text_sample = super().__getitem__(idx)
        token_sample = self._tokenize(text_sample)
        # Skip any token grouping, currently only supporting group_method='truncate'
        return token_sample