Tip

This tutorial is available as a Jupyter notebook.

Open in Colab

FaceSynthetics#

Why wait for your data to download when you can stream it instead? Let’s see how to do so with MosaicML Streaming and Composer.

Streaming is useful for multi-node setups where workers don’t have persistent storage and each element of the dataset must be downloaded exactly once. Composer is a library for training neural networks better, faster, and cheaper. In this tutorial, we’ll demonstrate a streaming approach to loading our datasets, using Microsoft’s FaceSynthetics dataset as an example, and we’ll use Composer for model training.

Tutorial Goals and Concepts Covered#

The goal of this tutorial is to show how to prepare the dataset and use Streaming data loading to train a model. It consists of a few steps:

  1. Obtaining the dataset

  2. Preparing the dataset for streaming

  3. Streaming the dataset to the local machine

  4. Training a model using these datasets

Let’s get started!

Setup#

Let’s start by making sure the right packages are installed and imported.

First, let’s make sure we’ve installed our dependencies; note that mmcv-full will take some time to unpack. To speed things up, we have included mmcv, mmsegmentation and many other useful computer vision libraries in the mosaicml/pytorch_vision Docker Image. We need Composer for model training and Streaming for streaming the dataset.

[ ]:
%pip install mmsegmentation mmcv mmcv-full

%pip install mosaicml
# To install from source instead of the last release, comment the command above and uncomment the following one.
# %pip install git+https://github.com/mosaicml/composer.git

%pip install mosaicml-streaming
# To install from source instead of the last release, comment the command above and uncomment the following one.
# %pip install git+https://github.com/mosaicml/streaming.git

# (Optional) To upload a streaming dataset to an AWS S3 bucket
%pip install awscli
[ ]:
import os
import time
import torch
import struct
import shutil
import requests

from PIL import Image
from io import BytesIO
from zipfile import ZipFile
from torch.utils.data import DataLoader
from typing import Iterator, Tuple, Dict
from torchvision import transforms as tf

We’ll be using Streaming’s dataset writer, MDSWriter, which writes the dataset in the Streaming MDS format; Composer’s DeepLabV3 model, which should perform well even on this small hundred-image dataset; and StreamingDataset to load the streaming dataset during training.

[ ]:
from streaming import MDSWriter, StreamingDataset
from composer.models.deeplabv3 import composer_deeplabv3
[ ]:
from composer import Trainer
from composer.optim import DecoupledAdamW

Global settings#

For this tutorial, it makes the most sense to organize our global settings here rather than distribute them throughout the cells in which they’re used.

[ ]:
# the location of our dataset
in_root = "./dataset"

# the location of the "remote" streaming dataset (`sds`).
# Upload `out_root` to your cloud storage provider of choice.
out_root = "./sds"
out_train = "./sds/train"
out_test = "./sds/test"

# the location to download the streaming dataset during training
local = './local'
local_train = './local/train'
local_test = './local/test'

# toggle shuffling in dataloader
shuffle_train = True
shuffle_test = False

# number of classes (possible pixel values) in the annotation images
num_classes = 20

# shard size limit, in bytes
size_limit = 1 << 25

# show a progress bar while downloading
use_tqdm = True

# fraction of the data to use for training (the remainder is used for testing)
training_ratio = 0.9

# training batch size
batch_size = 2 # this is the smallest batch size possible,
               # increase this if your machine can handle it.

# training hardware parameters
device = "gpu" if torch.cuda.is_available() else "cpu"

# number of training epochs
train_epochs = "3ep" # increase the number of epochs for greater accuracy

# number of images in the dataset (training + test)
num_images = 100 # can be 100, 1_000, or 100_000

# location to download the dataset zip file
dataset_archive = "./dataset.zip"

# remote dataset URL
URL = f"https://facesyntheticspubwedata.blob.core.windows.net/iccv-2021/dataset_{num_images}.zip"

# Hashing algorithms to use for dataset shards
hashes = ['sha1', 'xxh64']
[ ]:
# upload location for the dataset splits (change this to upload to a different location, for example, an AWS S3 bucket)
upload_location = None

if upload_location is None:
    upload_train_location = None
    upload_test_location = None
else:
    upload_train_location = os.path.join(upload_location, 'train')
    upload_test_location = os.path.join(upload_location, 'test')
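
For example, to upload the shards to S3, you could set upload_location to a bucket prefix you own before running the cell above. The bucket name below is just a placeholder:

[ ]:
# Hypothetical example: point `upload_location` at an S3 prefix you control.
# upload_location = "s3://<your-bucket>/facesynthetics"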

Getting the dataset#

[ ]:
if not os.path.exists(dataset_archive):
    response = requests.get(URL)
    with open(dataset_archive, "wb") as dataset_file:
        dataset_file.write(response.content)

    with ZipFile(dataset_archive, 'r') as myzip:
        myzip.extractall(in_root)

Next, we’ll convert the raw images and annotations into binary streaming dataset files.

Preparing the dataset#

The dataset consists of a directory of images with names in the form 123456.png, 123456_seg.png, and 123456_ldmks.png. For this example, we’ll use the images together with their segmentation annotations as labels, and ignore the landmark files for now.

[ ]:
def each(dirname: str, start_ix: int = 0, end_ix: int = num_images) -> Iterator[Dict[str, bytes]]:
    """Yield raw image/annotation byte pairs for samples in [start_ix, end_ix)."""
    for i in range(start_ix, end_ix):
        image = '%s/%06d.png' % (dirname, i)
        annotation = '%s/%06d_seg.png' % (dirname, i)

        with open(image, 'rb') as x, open(annotation, 'rb') as y:
            yield {
                'x': x.read(),
                'y': y.read(),
            }

Below, we’ll set up the logic for writing our starting dataset to files that can be read using a streaming dataloader.

For more information on the MDSWriter check out the API reference.

[ ]:
def write_datasets() -> None:
    fields = {'x': 'png', 'y': 'png'}

    num_training_images = int(num_images * training_ratio)

    start_ix, end_ix = 0, num_training_images
    with MDSWriter(out=out_train, columns=fields, hashes=hashes, size_limit=size_limit) as out:
        for sample in each(in_root, start_ix, end_ix):
            out.write(sample)
    start_ix, end_ix = end_ix, num_images
    with MDSWriter(out=out_test, columns=fields, hashes=hashes, size_limit=size_limit) as out:
        for sample in each(in_root, start_ix, end_ix):
            out.write(sample)

Now that we’ve written the datasets to out_root, we can upload them to a cloud storage provider and stream them from there during training.

[ ]:
remote_train = upload_train_location or out_train # replace this with your URL for cloud streaming
remote_test  = upload_test_location or out_test
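
As an aside, MDSWriter can also upload shards as it writes them: its out argument accepts a remote path or a (local, remote) tuple in addition to a local directory. Here is a minimal, commented-out sketch, assuming upload_train_location points at writable cloud storage:

[ ]:
# Sketch: write train shards locally and upload them to remote storage in one pass.
# with MDSWriter(out=(out_train, upload_train_location), columns={'x': 'png', 'y': 'png'},
#                hashes=hashes, size_limit=size_limit) as out:
#     for sample in each(in_root, 0, int(num_images * training_ratio)):
#         out.write(sample)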

Loading the Data#

We extend StreamingDataset to deserialize the image bytes and convert the segmentation annotations into integer class labels, remapping pixel value 255 to class 19 so that labels fall in the range [0, num_classes).

For more information on the StreamingDataset class check out the API reference.

[ ]:
class FaceSynthetics(StreamingDataset):
    def __init__(self,
                 remote: str,
                 local: str,
                 shuffle: bool,
                 batch_size: int,
                ) -> None:
        super().__init__(local=local, remote=remote, shuffle=shuffle, batch_size=batch_size)

    def __getitem__(self, i: int) -> Tuple[torch.Tensor, torch.Tensor]:
        obj = super().__getitem__(i)
        # Decode the image to a float tensor and the annotation to an integer label map
        x = tf.functional.to_tensor(obj['x'])
        y = tf.functional.pil_to_tensor(obj['y'])[0].to(torch.int64)
        y[y == 255] = 19  # remap pixel value 255 to class 19 so labels lie in [0, num_classes)
        return x, y

Putting It All Together#

We’re now ready to write the streamable dataset. Let’s do that if we haven’t already.

[ ]:
if not os.path.exists(out_train):
    write_datasets()

(Optional) Upload the streaming dataset to an AWS S3 bucket of your choice. Uncomment the line below if you have set upload_location to an S3 bucket URI.

[ ]:
# !aws s3 cp $out_root $upload_location --recursive

Once that’s done, we can instantiate our streaming datasets and wrap them in standard dataloaders for training!

[ ]:
dataset_train = FaceSynthetics(remote_train, local_train, shuffle_train, batch_size=batch_size)
dataset_test  = FaceSynthetics(remote_test, local_test, shuffle_test, batch_size=batch_size)

train_dataloader = DataLoader(dataset_train, batch_size=batch_size)
test_dataloader = DataLoader(dataset_test, batch_size=batch_size)
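
Before training, it can be worth pulling a single sample to confirm that streaming and decoding work end to end. A quick, optional sanity check:

[ ]:
# Optional sanity check: fetch one sample and inspect shapes and label range.
x, y = dataset_train[0]
print(x.shape, x.dtype)            # (3, H, W) float image tensor
print(y.shape, y.dtype)            # (H, W) int64 label map
print(int(y.min()), int(y.max()))  # labels should lie within [0, num_classes)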

Train with the Streaming Dataloaders#

Now all that’s left to do is train! Doing so with Composer should look pretty familiar by now.

[ ]:
# Create a DeepLabV3 model, and an optimizer for it
model = composer_deeplabv3(
    num_classes=num_classes,
    backbone_arch='resnet101',
    backbone_weights='IMAGENET1K_V2',
    sync_bn=False)
optimizer = DecoupledAdamW(model.parameters(), lr=1e-3)

# Create a trainer object with our model, optimizer, and streaming dataloaders
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataloader,
    max_duration=train_epochs,
    optimizers=optimizer,
    device=device
)

# Train!
start_time = time.perf_counter()
trainer.fit()
end_time = time.perf_counter()
print(f"It took {end_time - start_time:0.4f} seconds to train")

Cleanup#

That’s it. No need to hang on to the files created by the tutorial…

[ ]:
shutil.rmtree(out_root, ignore_errors=True)
shutil.rmtree(in_root, ignore_errors=True)
shutil.rmtree(local, ignore_errors=True)
if os.path.exists(dataset_archive):
    os.remove(dataset_archive)

What next?#

You’ve now seen an in-depth look at how to prepare and use streaming datasets with Composer.

To continue learning about Streaming, please continue to explore our examples!

Come get involved with MosaicML!#

We’d love for you to get involved with the MosaicML community in any of these ways:

Star Streaming on GitHub#

Help make others aware of our work by starring Streaming on GitHub.

Join the MosaicML Slack#

Head on over to the MosaicML Slack to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!

Contribute to Streaming#

Is there a bug you noticed or a feature you’d like? File an issue or make a pull request!