Tip

This tutorial is available as a Jupyter notebook.

Spark DataFrame to MDS#

In this tutorial, we will demonstrate how to use the Streaming Spark converter to convert a Spark DataFrame to MDS format for use with a StreamingDataset. You can optionally pass a preprocessing job, such as a tokenizer, to the converter, which is useful when materializing the intermediate dataframe would be time consuming or require extra development.

Tutorial Covers#

  1. Installation of libraries

  2. Basic: Convert Spark DataFrame to MDS format.

  3. Advanced: Convert Spark DataFrame into tokenized format and convert to MDS format.

Setup#

Let’s start by installing mosaicml-streaming and some other needed packages.

[ ]:
%pip install --upgrade fsspec datasets transformers
[ ]:
%pip install mosaicml-streaming
[ ]:
%pip install pyspark==3.4.1
[ ]:
import os
import shutil
from typing import Any, Sequence, Dict, Iterable, Optional
from pyspark.sql import SparkSession
import pandas as pd
import numpy as np
from tempfile import mkdtemp
import datasets as hf_datasets
from transformers import AutoTokenizer, PreTrainedTokenizerBase

We’ll be using Streaming’s dataframe_to_mds() method, which converts a DataFrame into Streaming’s MDS format.

[ ]:
from streaming.base.converters import dataframe_to_mds

Basic: Convert Spark DataFrame to MDS format#

Steps:

  1. Create a synthetic NLP dataset.

  2. Store the dataset as a parquet file.

  3. Load the parquet file as a Spark DataFrame.

  4. Convert the Spark DataFrame to MDS format.

  5. Load the MDS dataset and inspect the output.

Create a Synthetic NLP dataset#

In this tutorial, we will be creating a synthetic number-saying dataset, i.e. converting numbers from digits to words, for example, the number 123 would be converted to one hundred twenty three. The numbers are generated sequentially.

Let’s make a short synthetic number-saying dataset class.

[ ]:
class NumberAndSayDataset:
    """Generate a synthetic number-saying dataset.

    Converts numbers from digits to words. Supports positive and negative numbers
    up to approximately 99 million.

    Args:
        num_samples (int): number of samples. Defaults to 100.
        column_names (list[str]): A list of feature and target column names. Defaults to ['number',
            'words'].
        seed (int): seed value for deterministic randomness.
    """

    ones = (
        'zero one two three four five six seven eight nine ten eleven twelve thirteen fourteen ' +
        'fifteen sixteen seventeen eighteen nineteen').split()

    tens = 'twenty thirty forty fifty sixty seventy eighty ninety'.split()

    def __init__(self,
                 num_samples: int = 100,
                 column_names: list[str] = ['number', 'words'],
                 seed: int = 987) -> None:
        self.num_samples = num_samples
        self.column_encodings = ['int', 'str']
        self.column_sizes = [8, None]
        self.column_names = column_names
        self._index = 0
        self.seed = seed

    def __len__(self) -> int:
        return self.num_samples

    def _say(self, i: int) -> list[str]:
        if i < 0:
            return ['negative'] + self._say(-i)
        elif i <= 19:
            return [self.ones[i]]
        elif i < 100:
            return [self.tens[i // 10 - 2]] + ([self.ones[i % 10]] if i % 10 else [])
        elif i < 1_000:
            return [self.ones[i // 100], 'hundred'] + (self._say(i % 100) if i % 100 else [])
        elif i < 1_000_000:
            return self._say(i // 1_000) + ['thousand'] + (self._say(i % 1_000) if i % 1_000 else [])
        elif i < 1_000_000_000:
            return self._say(i // 1_000_000) + ['million'] + (self._say(i % 1_000_000) if i % 1_000_000 else [])
        else:
            assert False

    def _get_number(self) -> int:
        sign = (np.random.random() < 0.8) * 2 - 1
        mag = 10**np.random.uniform(1, 4) - 10
        return sign * int(mag**2)

    def __iter__(self):
        return self

    def __next__(self) -> dict[str, Any]:
        if self._index >= self.num_samples:
            raise StopIteration
        number = self._get_number()
        words = ' '.join(self._say(number))
        self._index += 1
        return {
            self.column_names[0]: number,
            self.column_names[1]: words,
        }

    @property
    def seed(self) -> int:
        return self._seed

    @seed.setter
    def seed(self, value: int) -> None:
        self._seed = value  # pyright: ignore
        np.random.seed(self._seed)
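
As a quick, optional sanity check, we can iterate over a few samples to see what each record looks like.

[ ]:
preview = NumberAndSayDataset(num_samples=3)
for sample in preview:
    print(sample)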

Store the dataset as a parquet file#

[ ]:
# Create a temporary directory
local_dir = mkdtemp()

syn_dataset = NumberAndSayDataset()
df = pd.DataFrame.from_dict([record for record in syn_dataset])
df.to_parquet(os.path.join(local_dir, 'synthetic_dataset.parquet'))

Load the parquet file as a Spark DataFrame#

[ ]:
spark = SparkSession.builder.getOrCreate()
pdf = spark.read.parquet(os.path.join(local_dir, 'synthetic_dataset.parquet'))

Take a peek at the Spark DataFrame

[ ]:
pdf.show(5, truncate=False)
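
Optionally, confirm the schema as well, since the Spark column types need to map onto the MDS `columns` argument used below.

[ ]:
pdf.printSchema()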

Convert the Spark DataFrame to MDS format#

[ ]:
# Empty the MDS output directory
out_path = os.path.join(local_dir, 'mds')
shutil.rmtree(out_path, ignore_errors=True)

# Specify the mandatory MDSWriter arguments `out` and `columns`.
mds_kwargs = {'out': out_path, 'columns': {'number': 'int64', 'words':'str'}}

# Convert the dataset to MDS format. This divides the dataframe into 4 partitions, one per worker, and merges the `index.json` files from the 4 sub-directories into one in the parent directory.
dataframe_to_mds(pdf.repartition(4), merge_index=True, mds_kwargs=mds_kwargs)

Let’s check the file structure of the output MDS dataset. You can see four sub-directories and one index.json file. The index.json file contains metadata about all four sub-directories.

[ ]:
%ls {out_path}
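
We can also look at the contents of the merged index.json directly.

[ ]:
%cat {out_path + '/index.json'}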

Load the MDS dataset using StreamingDataset#

Here, we use StreamingDataset to load the MDS dataset and inspect it.

[ ]:
from torch.utils.data import DataLoader
import streaming
from streaming import StreamingDataset

# clean stale shared memory if any
streaming.base.util.clean_stale_shared_memory()

dataset = StreamingDataset(local=out_path, remote=None, batch_size=2, predownload=4)

dataloader = DataLoader(dataset, batch_size=2, num_workers=1)

for i, data in enumerate(dataloader):
    print(data)
    # Display only first 10 batches
    if i == 10:
        break
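
StreamingDataset can also be indexed like a map-style dataset. As a quick, optional check, we can look at the total number of samples and fetch a single one.

[ ]:
print(dataset.num_samples)
print(dataset[0])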

Advanced: Convert Spark DataFrame into tokenized format and convert to MDS format#

Steps:

  1. [Same as above] Create a synthetic NLP dataset.

  2. [Same as above] Store the dataset as a parquet file.

  3. [Same as above] Load the parquet file as a Spark DataFrame.

  4. Create a user defined function which modifies the dataframe.

  5. Convert the modified data to MDS format.

  6. Load the MDS dataset and inspect the output.

For steps 1-3, follow the steps detailed above.

Create a user defined function which modifies the dataframe#

The user defined function must be an iterable (generator) function that yields a dictionary whose keys are the column names and whose values are the outputs for those columns. For example, in this tutorial, the key is tokens and the value is the tokenized output in bytes. When an iterable function is supplied, the user is responsible for providing the correct columns argument; in the case below, it should be

columns={'tokens': 'bytes'}

where tokens is the key yielded by the udf_iterable and bytes is the format of the field, so that MDS chooses the proper encoding method.
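
To make the expected shape concrete, here is a minimal, hypothetical sketch (the function name and its output are illustrative only; the real tokenization function used in this tutorial follows below):

from typing import Dict, Iterable

def my_processing_fn(df: pd.DataFrame, **kwargs) -> Iterable[Dict[str, bytes]]:
    # Yield one dictionary per output sample. The keys must match the `columns`
    # mapping passed to dataframe_to_mds via mds_kwargs, e.g. {'tokens': 'bytes'}.
    for row in df.itertuples(index=False):
        yield {'tokens': np.asarray([len(row.words)], dtype=np.int64).tobytes()}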

Take a peek at the Spark DataFrame

[ ]:
pdf.show(5, truncate=False)

Convert the Spark DataFrame to MDS format#

This time we supply the user defined iterable function and its associated keyword arguments. For the purpose of demonstration, the tokenization function pandas_processing_fn is greatly simplified. For practical applications, you may want more involved preprocessing steps. For examples of dataset concatenation and more elaborate preprocessing, see MosaicML’s LLM Foundry.

[ ]:
from typing import Dict, Iterable

import datasets as hf_datasets
import numpy as np
import pandas as pd
from transformers import AutoTokenizer


def pandas_processing_fn(df: pd.DataFrame, **args) -> Iterable[Dict[str, bytes]]:
    """Tokenize the text in a pandas DataFrame and yield fixed-length token samples.

    Parameters:
    -----------
    df : pandas.DataFrame
        The input pandas DataFrame that needs to be processed.

    **args : keyword arguments
        Additional arguments controlling the tokenization, e.g. `tokenizer`,
        `concat_tokens`, and `split`.

    Yields:
    -------
    Dict[str, bytes]
        A dictionary with the single key 'tokens', whose value is a serialized
        sequence of `concat_tokens` token IDs.
    """
    hf_dataset = hf_datasets.Dataset.from_pandas(df=df, split=args['split'])
    tokenizer = AutoTokenizer.from_pretrained(args['tokenizer'])
    # We will enforce length below, so suppress warnings about sequences that are
    # too long for the model.
    tokenizer.model_max_length = int(1e30)
    max_length = args['concat_tokens']

    buffer = []
    for sample in hf_dataset:
        encoded = tokenizer(sample['words'], truncation=False, padding=False)
        buffer += encoded['input_ids']
        while len(buffer) >= max_length:
            concat_sample = buffer[:max_length]
            # This simplified example discards leftover tokens instead of wrapping them.
            buffer = []
            yield {
                # Convert to bytes to store in MDS binary format.
                'tokens': np.asarray(concat_sample).tobytes()
            }
[ ]:
# Empty the MDS output directory
out_path = os.path.join(local_dir, 'mds')
shutil.rmtree(out_path, ignore_errors=True)

# Provide the MDS keyword arguments. Ensure the `columns` field matches the output of the iterable function (the tokenizer in this example).
mds_kwargs = {'out': out_path, 'columns': {'tokens': 'bytes'}}

# Tokenizer arguments
udf_kwargs = {
    'concat_tokens': 4,
    'tokenizer': 'EleutherAI/gpt-neox-20b',
    'eos_text': '<|endoftext|>',
    'compression': 'zstd',
    'split': 'train',
    'no_wrap': False,
    'bos_text': '',
}

# Convert the dataset to MDS format. This fetches samples from the dataframe, tokenizes them, and then converts them to MDS format.
# It divides the dataframe into 4 partitions, one per worker, and merges the `index.json` files from the 4 sub-directories into one in the parent directory.
dataframe_to_mds(pdf.repartition(4), merge_index=True, mds_kwargs=mds_kwargs, udf_iterable=pandas_processing_fn, udf_kwargs=udf_kwargs)

Let’s check the file structure of the output MDS dataset. You can see four sub-directories and one index.json file. The index.json file contains metadata about all four sub-directories.

[ ]:
%ls {out_path}
[ ]:
%cat {out_path +'/index.json'}

Load the MDS dataset using StreamingDataset#

Here, we use StreamingDataset to load the MDS dataset and inspect it.

[ ]:
from torch.utils.data import DataLoader
import streaming
from streaming import StreamingDataset

# clean stale shared memory if any
streaming.base.util.clean_stale_shared_memory()

dataset = StreamingDataset(local=out_path, remote=None, batch_size=2, predownload=4)

dataloader = DataLoader(dataset, batch_size=2, num_workers=1)

for i, data in enumerate(dataloader):
    print(data)
    # Display only first 10 batches
    if i == 10:
        break
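
Each sample’s 'tokens' field is stored as raw bytes. As a quick, illustrative check (assuming the default int64 dtype produced by np.asarray(...).tobytes() above), we can decode one sample back into token IDs and text.

[ ]:
tokenizer = AutoTokenizer.from_pretrained(udf_kwargs['tokenizer'])
sample = dataset[0]
token_ids = np.frombuffer(sample['tokens'], dtype=np.int64)
print(token_ids)
print(tokenizer.decode(token_ids))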

Cleanup#

[ ]:
shutil.rmtree(out_path, ignore_errors=True)
shutil.rmtree(local_dir, ignore_errors=True)
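
If you are finished with Spark as well, you can optionally stop the session.

[ ]:
spark.stop()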

What next?#

You’ve now seen an in-depth look at how to convert a Spark DataFrame to MDS format and how to load the resulting MDS dataset for model training.