Source code for streaming.text.c4

# Copyright 2022 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

"""C4 (Colossal Cleaned Common Crawl) dataset.

This dataset is a colossal, cleaned version of Common Crawl's web crawl corpus, based on the
`Common Crawl <https://commoncrawl.org>`_ dataset.
"""

from typing import Any, Dict, Iterator, Optional

from transformers.models.auto.tokenization_auto import AutoTokenizer

from streaming.base import StreamingDataset

__all__ = ['StreamingC4']


class StreamingC4(StreamingDataset):
    """Implementation of the C4 (Colossal Cleaned Common Crawl) dataset using StreamingDataset.

    Args:
        tokenizer_name (str): The name of the HuggingFace tokenizer to use to tokenize samples.
        max_seq_len (int): The max sequence length of each token sample.
        group_method (str): How to group text samples into token samples. Supports ``'truncate'``
            or ``'concat'``.
        local (str): Local dataset directory where shards are cached by split.
        remote (str, optional): Download shards from this remote path or directory. If ``None``,
            this rank and worker's partition of the dataset must all exist locally. Defaults to
            ``None``.
        split (str, optional): Which dataset split to use, if any. Defaults to ``None``.
        shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
            ``False``.
        predownload (int, optional): Target number of samples ahead to download the shards of
            while iterating. Defaults to ``100_000``.
        keep_zip (bool, optional): Whether to keep or delete the compressed file when
            decompressing downloaded shards. If ``None``, keep if and only if remote is local.
            Defaults to ``None``.
        download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
        download_timeout (float): Number of seconds to wait for a shard to download before
            raising an exception. Defaults to ``60``.
        validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
            shards. Defaults to ``None``.
        shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``.
        num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
            resumption. Defaults to ``None``, which is interpreted as the number of nodes of the
            initial run.
        batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset
            is partitioned over the workers. Defaults to ``None``.
    """

    def __init__(self,
                 tokenizer_name: str,
                 max_seq_len: int,
                 group_method: str,
                 local: str,
                 remote: Optional[str] = None,
                 split: Optional[str] = None,
                 shuffle: bool = False,
                 predownload: Optional[int] = 100_000,
                 keep_zip: Optional[bool] = None,
                 download_retry: int = 2,
                 download_timeout: float = 60,
                 validate_hash: Optional[str] = None,
                 shuffle_seed: int = 9176,
                 num_canonical_nodes: Optional[int] = None,
                 batch_size: Optional[int] = None):
        if group_method not in {'truncate', 'concat'}:
            raise ValueError(
                f"group_method='{group_method}' must be one of ['truncate', 'concat'].")

        super().__init__(local, remote, split, shuffle, predownload, keep_zip, download_retry,
                         download_timeout, validate_hash, shuffle_seed, num_canonical_nodes,
                         batch_size)
        self.tokenizer_name = tokenizer_name
        self.max_seq_len = max_seq_len
        self.group_method = group_method

        # Build tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
        if self.tokenizer.pad_token is None:
            # Some tokenizers (e.g. the GPT2 tokenizer) have no padding token, which causes bugs.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Suppress warnings when using group_method='concat' and no truncation.
        self.tokenizer.model_max_length = int(1e30)

    def _tokenize(self, text_sample: Dict[str, Any]):
        """Apply the tokenizer to a sample.

        Args:
            text_sample (Dict[str, Any]): Sample to tokenize.
        """
        if self.group_method == 'truncate':
            truncation = True
            padding = 'max_length'
            max_length = self.max_seq_len
        elif self.group_method == 'concat':
            truncation = False
            padding = False
            max_length = None
        else:
            raise ValueError(f'Got unknown group_method={self.group_method}.')
        return self.tokenizer(text_sample['text'],
                              truncation=truncation,
                              padding=padding,
                              max_length=max_length)

    def __getitem__(self, idx: int) -> Any:
        """Get sample by global index, blocking to load its shard if missing.

        Args:
            idx (int): Sample index.

        Returns:
            Any: Sample data.
        """
        text_sample = super().__getitem__(idx)
        token_sample = self._tokenize(text_sample)
        # No token grouping here; for group_method='concat', grouping happens in __iter__.
        return token_sample

    def __iter__(self) -> Iterator[Any]:
        """Iterate over samples.

        Because concatenating samples has custom behavior, it requires extending the parent
        iterator. For ``group_method='truncate'``, simply yield each token sample. For
        ``group_method='concat'``, keep fetching token samples until the buffer fills up to
        ``max_seq_len``, then yield fixed-length chunks.

        Yields:
            Iterator[Any]: Sample iterator.
        """
        if self.group_method == 'truncate':
            yield from super().__iter__()
        elif self.group_method == 'concat':
            buffer = {}
            while True:
                iterator = super().__iter__()
                for sample in iterator:
                    for k, v in sample.items():
                        buffer[k] = buffer.get(k, []) + v
                    while len(buffer['input_ids']) >= self.max_seq_len:
                        concat_sample = {}
                        for k, v in buffer.items():
                            concat_sample[k] = v[:self.max_seq_len]
                            buffer[k] = v[self.max_seq_len:]
                        yield concat_sample
        else:
            raise ValueError(f'Got unknown group_method={self.group_method}.')

    def __len__(self) -> Optional[int]:
        """Number of samples in the dataset.

        For ``group_method='truncate'``, return the number of samples. For
        ``group_method='concat'``, the iterator repeats forever, so there is no defined length.

        Returns:
            Optional[int]: Number of samples, or ``None`` if undefined.
        """
        if self.group_method == 'truncate':
            return super().__len__()
        elif self.group_method == 'concat':
            return None
        else:
            raise ValueError(f'Got unknown group_method={self.group_method}.')