# Copyright 2022-2024 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0
"""Information about nodes, ranks, and workers."""
from typing import Any, Dict, Tuple
from torch.utils.data import get_worker_info
from typing_extensions import Self
from streaming.base import distributed as dist
class World:
    """Information about the nodes, ranks and workers of this run.

    .. warning::
        Be careful as to whether this object was initialized in a worker (if workers are used)
        or in a rank (which will claim one worker per rank).

    .. warning::
        In this World object, the counts (num_nodes, num_ranks, num_workers) are global -- not
        to be confused with DataLoader num_workers, which is per rank.

    Nodes are all assumed to contain the same number of devices (via local_world_size).

    Nodes:
        - node / num_nodes
        - is_multinode

    Ranks:
        - rank / num_ranks
        - rank_of_node / ranks_per_node

    Workers:
        - worker / num_workers
        - worker_of_node / workers_per_node
        - worker_of_rank / workers_per_rank
        - is_leader
        - is_local_leader

    Args:
        num_nodes (int): Number of nodes.
        ranks_per_node (int): Number of ranks on each node.
        workers_per_rank (int): Number of workers belonging to each rank.
        worker (int): Global index of this worker across all nodes and ranks.
    """

    def __init__(
        self,
        num_nodes: int,
        ranks_per_node: int,
        workers_per_rank: int,
        worker: int,
    ) -> None:
        workers_per_node = ranks_per_node * workers_per_rank

        # Node level. Attribute assignment order is kept stable because to_json() exposes
        # self.__dict__ directly.
        self.node = worker // workers_per_node
        self.num_nodes = num_nodes
        self.is_multinode = 1 < num_nodes

        # Rank level.
        self.rank = worker // workers_per_rank
        self.num_ranks = num_nodes * ranks_per_node
        self.rank_of_node = self.rank % ranks_per_node
        self.ranks_per_node = ranks_per_node

        # Worker level.
        self.worker = worker
        self.num_workers = num_nodes * workers_per_node
        self.worker_of_node = worker % workers_per_node
        self.workers_per_node = workers_per_node
        self.worker_of_rank = worker % workers_per_rank
        self.workers_per_rank = workers_per_rank
        # Global leader is worker 0; a local leader is worker 0 of its node.
        self.is_leader = not worker
        self.is_local_leader = not self.worker_of_node

    def to_json(self) -> Dict[str, Any]:
        """Get a JSON version of this config.

        Returns:
            Dict[str, Any]: JSON config (a shallow copy of this object's attributes).
        """
        return dict(self.__dict__)

    @classmethod
    def _get_worker_info(cls) -> Tuple[int, int]:
        """Get worker info, or default to 0 of 1.

        When ``get_worker_info()`` yields nothing (e.g. outside a DataLoader worker process),
        the caller is treated as the sole worker of its rank.

        Returns:
            Tuple[int, int]: Worker ID out of how many workers.
        """
        info = get_worker_info()
        if info:
            return info.id, info.num_workers
        return 0, 1

    @classmethod
    def detect(cls) -> Self:
        """Detect the world state.

        Returns:
            Self: A new World state object according to dist and get_worker_info().
        """
        rank = dist.get_rank()
        ranks_per_node = dist.get_local_world_size()
        num_nodes = dist.get_world_size() // ranks_per_node
        worker_of_rank, workers_per_rank = cls._get_worker_info()
        worker = rank * workers_per_rank + worker_of_rank
        return cls(num_nodes, ranks_per_node, workers_per_rank, worker)

    def copy(self) -> Self:
        """Get a copy of this world state.

        Returns:
            Self: A new copy with the same state (same class as ``self``).
        """
        return type(self)(
            num_nodes=self.num_nodes,
            ranks_per_node=self.ranks_per_node,
            workers_per_rank=self.workers_per_rank,
            worker=self.worker,
        )

    def replicate(self, replication: int) -> Self:
        """Get a copy of this world state with the given replication factor.

        Each group of ``replication`` consecutive ranks is collapsed into a single rank of the
        new World, so that those devices see the same samples. Workers per rank are unchanged.

        Args:
            replication (int): Replication factor -- how many consecutive devices that should see
                the same samples.

        Returns:
            Self: A new sample replication version of this World state object.

        Raises:
            ValueError: If ``replication`` is not positive, or does not divide the number of ranks.
        """
        if replication <= 0:
            raise ValueError('Replication factor must be positive.')
        if self.num_ranks % replication:
            raise ValueError('World size must be divisible by your replication factor.')

        rank = self.rank // replication  # Floor divide our rank.
        num_ranks = self.num_ranks // replication  # Evenly divide ranks (checked above).
        worker = rank * self.workers_per_rank + self.worker_of_rank  # Derive worker.

        # Fewest nodes that can hold the replicated ranks without exceeding the original ranks
        # per node (ceil divide).
        num_nodes = (num_ranks + self.ranks_per_node - 1) // self.ranks_per_node
        # The new World requires num_nodes * ranks_per_node == num_ranks; otherwise the last
        # ranks would end up with rank >= num_ranks and node >= num_nodes. Grow num_nodes until
        # it divides num_ranks (terminates: num_ranks divides itself). Layouts that already
        # divide evenly are unchanged.
        while num_ranks % num_nodes:
            num_nodes += 1
        ranks_per_node = num_ranks // num_nodes

        return type(self)(
            num_nodes=num_nodes,
            ranks_per_node=ranks_per_node,
            workers_per_rank=self.workers_per_rank,
            worker=worker,
        )

    def detect_workers(self) -> Self:
        """Get a copy of this world state with the worker information newly detected.

        Node and rank layout are kept; only the workers per rank and this worker's index are
        re-derived from ``get_worker_info()``.

        Returns:
            Self: A new workers-newly-detected version of this World state object.
        """
        worker_of_rank, workers_per_rank = self._get_worker_info()
        worker = self.rank * workers_per_rank + worker_of_rank
        return type(self)(
            num_nodes=self.num_nodes,
            ranks_per_node=self.ranks_per_node,
            workers_per_rank=workers_per_rank,
            worker=worker,
        )