# Source code for streaming.base.format.mds.writer

# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

""":class:`MDSWriter` converts a list of samples into binary `.mds` files that can be read as a :class:`MDSReader`."""

import json
from typing import Any, Dict, List, Optional

import numpy as np

from streaming.base.format.base.writer import JointWriter
from streaming.base.format.mds.encodings import (get_mds_encoded_size, get_mds_encodings,
                                                 is_mds_encoding, mds_encode)

__all__ = ['MDSWriter']


class MDSWriter(JointWriter):
    """Writes a streaming MDS dataset.

    Args:
        dirname (str): Local dataset directory.
        columns (Dict[str, str]): Sample columns.
        compression (str, optional): Optional compression or compression:level. Defaults to
            ``None``.
        hashes (List[str], optional): Optional list of hash algorithms to apply to shard files.
            Defaults to ``None``.
        size_limit (int, optional): Optional shard size limit, after which point to start a new
            shard. If None, puts everything in one shard. Defaults to ``1 << 26``.
    """

    format = 'mds'
    extra_bytes_per_sample = 4

    def __init__(self,
                 dirname: str,
                 columns: Dict[str, str],
                 compression: Optional[str] = None,
                 hashes: Optional[List[str]] = None,
                 size_limit: Optional[int] = 1 << 26) -> None:
        super().__init__(dirname, compression, hashes, size_limit, 0,
                         self.extra_bytes_per_sample)
        self.columns = columns
        self.column_names = []
        self.column_encodings = []
        self.column_sizes = []
        # Walk the columns in sorted name order so the serialized config (and
        # therefore every shard file) is deterministic regardless of dict order.
        for name in sorted(columns):
            encoding = columns[name]
            if not is_mds_encoding(encoding):
                raise TypeError(
                    f'MDSWriter passed column "{name}" with encoding "{encoding}" is unsupported. Supported encodings are {get_mds_encodings()}'
                )
            self.column_names.append(name)
            self.column_encodings.append(encoding)
            # None for variable-length encodings, a byte count for fixed-size ones.
            self.column_sizes.append(get_mds_encoded_size(encoding))

        # Serialize the shard config once up front; it is embedded in every shard.
        config_text = json.dumps(self.get_config(), sort_keys=True)
        self.config_data = config_text.encode('utf-8')
        # Per-shard overhead: 4 B sample count + 4 B trailing offset + config JSON.
        self.extra_bytes_per_shard = 4 + 4 + len(self.config_data)
        self._reset_cache()
[docs] def encode_sample(self, sample: Dict[str, Any]) -> bytes: """Encode a sample dict to bytes. Args: sample (Dict[str, Any]): Sample dict. Returns: bytes: Sample encoded as bytes. """ sizes = [] data = [] for key, encoding, size in zip(self.column_names, self.column_encodings, self.column_sizes): value = sample[key] datum = mds_encode(encoding, value) if size is None: size = len(datum) sizes.append(size) else: assert size == len(datum) data.append(datum) head = np.array(sizes, np.uint32).tobytes() body = b''.join(data) return head + body
[docs] def get_config(self) -> Dict[str, Any]: """Get object describing shard-writing configuration. Returns: Dict[str, Any]: JSON object. """ obj = super().get_config() obj.update({ 'column_names': self.column_names, 'column_encodings': self.column_encodings, 'column_sizes': self.column_sizes }) return obj
[docs] def encode_joint_shard(self) -> bytes: """Encode a joint shard out of the cached samples (single file). Returns: bytes: File data. """ num_samples = np.uint32(len(self.new_samples)) sizes = list(map(len, self.new_samples)) offsets = np.array([0] + sizes).cumsum().astype(np.uint32) offsets += len(num_samples.tobytes()) + len(offsets.tobytes()) + len(self.config_data) sample_data = b''.join(self.new_samples) return num_samples.tobytes() + offsets.tobytes() + self.config_data + sample_data