# Source code for streaming.base.format.json.writer

# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

""":class:`JSONWriter` converts a list of samples into binary `.json` files that can be read as a :class:`JSONReader`."""

import json
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from streaming.base.format.base.writer import SplitWriter
from streaming.base.format.json.encodings import is_json_encoded, is_json_encoding

__all__ = ['JSONWriter']


class JSONWriter(SplitWriter):
    r"""Writes a streaming JSON dataset.

    Args:
        dirname (str): Local dataset directory.
        columns (Dict[str, str]): Sample columns.
        compression (str, optional): Optional compression or compression:level. Defaults to
            ``None``.
        hashes (List[str], optional): Optional list of hash algorithms to apply to shard files.
            Defaults to ``None``.
        size_limit (int, optional): Optional shard size limit, after which point to start a new
            shard. If None, puts everything in one shard. Defaults to ``None``.
        newline (str): Newline character inserted between samples. Defaults to ``\\n``.
    """

    format = 'json'

    def __init__(self,
                 dirname: str,
                 columns: Dict[str, str],
                 compression: Optional[str] = None,
                 hashes: Optional[List[str]] = None,
                 size_limit: Optional[int] = 1 << 26,
                 newline: str = '\n') -> None:
        super().__init__(dirname, compression, hashes, size_limit)
        # Reject unknown column encodings up front, before any samples are written.
        for encoding in columns.values():
            assert is_json_encoding(encoding)
        self.columns = columns
        self.newline = newline

    def encode_sample(self, sample: Dict[str, Any]) -> bytes:
        """Encode a sample dict to bytes.

        Args:
            sample (Dict[str, Any]): Sample dict.

        Returns:
            bytes: Sample encoded as bytes.
        """
        # Check each column value against its declared encoding.
        for key, encoding in self.columns.items():
            assert is_json_encoded(encoding, sample[key])
        # Keep only the declared columns; sort keys for a deterministic serialization.
        obj = {key: sample[key] for key in self.columns}
        line = json.dumps(obj, sort_keys=True) + self.newline
        return line.encode('utf-8')

    def get_config(self) -> Dict[str, Any]:
        """Get object describing shard-writing configuration.

        Returns:
            Dict[str, Any]: JSON object.
        """
        config = super().get_config()
        config['columns'] = self.columns
        config['newline'] = self.newline
        return config

    def encode_split_shard(self) -> Tuple[bytes, bytes]:
        """Encode a split shard out of the cached samples (data, meta files).

        Returns:
            Tuple[bytes, bytes]: Data file, meta file.
        """
        samples = self.new_samples
        data = b''.join(samples)
        # Meta layout: uint32 sample count, uint32 cumulative byte offsets
        # (leading 0 sentinel included), then the JSON-serialized config.
        count = np.uint32(len(samples))
        offsets = np.cumsum([0] + [len(s) for s in samples]).astype(np.uint32)
        config_text = json.dumps(self.get_config(), sort_keys=True)
        meta = count.tobytes() + offsets.tobytes() + config_text.encode('utf-8')
        return data, meta