# Source code for olympus.distributed.multigpu

import torch.distributed
import torch.nn
import sys

from olympus.utils.options import options

# The real ``sys.stdout``, saved while it is swapped for a silent writer on
# non-zero ranks; stays ``None`` as long as stdout has not been replaced.
_stdout = None

# Index (0 to N) of this process among the processes running in parallel.
# The sentinel -1 marks the launcher process, or a plain single-GPU training
# process that never joined a distributed group.
_rank: int = -1


def rank():
    """Return the rank recorded for the current process.

    ``-1`` denotes the launcher process or a single-GPU (non-distributed) run.
    """
    current_rank = _rank
    return current_rank
def set_rank(rank_):
    """Store ``rank_`` as this process's rank in the module-level state."""
    global _rank
    _rank = rank_
class NoOut:
    """Write-only sink that silently discards everything written to it.

    It is swapped in for ``sys.stdout`` on silenced ranks, so it mirrors the
    parts of the text-stream interface that printing code commonly relies on.
    """

    def write(self, string):
        """Discard ``string`` and report it as fully written.

        Text streams return the number of characters written; returning it
        here keeps callers that check the result working.
        """
        return len(string)

    def flush(self):
        """Nothing is buffered, so there is nothing to flush."""

    def isatty(self):
        """A discarding sink is never an interactive terminal."""
        # Libraries routinely probe ``sys.stdout.isatty()``; without this the
        # probe would raise AttributeError on silenced ranks.
        return False
def arguments(parser):
    """Register the distributed-training command line options on ``parser``.

    Adds ``--rank``, ``--dist-url`` and ``--world-size`` (in that order) and
    returns the same parser so the call can be chained.
    """
    option_specs = (
        ('--rank', dict(
            type=int, default=0, metavar='R',
            help='process rank')),
        ('--dist-url', dict(
            type=str, default=None, metavar='DIST_URL',
            help='distributed backend (nccl:tcp://localhost:8123)')),
        ('--world-size', dict(
            type=int, default=1, metavar='WS',
            help='Number of process running in parallel')),
    )
    for flag, settings in option_specs:
        parser.add_argument(flag, **settings)
    return parser
def enable_distributed_process(rank, dist_url, world_size, silence_stdout=options('distributed.noprint', True)):
    """Initialize distributed system, if all required arguments are met, do nothing if not

    Parameters
    ----------
    rank: Optional[int]
        Process rank; ``None`` skips initialization entirely
    dist_url: str
        ``<backend>:<init_method>`` string, e.g. ``nccl:tcp://localhost:8123``;
        the part before the first ``:`` is the backend
    world_size: int
        Number of processes running in parallel; nothing is initialized
        unless it is greater than 1
    silence_stdout: bool
        Silence the standard output for ranks superior to 0.
        NOTE: the default is computed from ``options`` once, when this module
        is imported, not on each call.

    Raises
    ------
    ValueError
        If ``world_size > 1`` but ``dist_url`` is missing or has no
        ``backend:`` prefix
    """
    global _stdout

    if rank is None or world_size <= 1:
        return

    # Validate before touching any global state: the ``--dist-url`` CLI
    # default is None, which previously surfaced as an opaque AttributeError
    # after the rank had already been recorded.
    if not dist_url or ':' not in dist_url:
        raise ValueError(
            'dist_url must look like "<backend>:<init_method>" '
            f'(e.g. "nccl:tcp://localhost:8123"), got {dist_url!r}')

    set_rank(rank)
    backend, url = dist_url.split(':', maxsplit=1)
    torch.distributed.init_process_group(
        backend=backend,
        init_method=url,
        rank=rank,
        world_size=world_size
    )

    # Keep a handle on the real stdout so it can be restored later, then mute
    # every rank except 0.
    if rank != 0 and silence_stdout:
        _stdout = sys.stdout
        sys.stdout = NoOut()
def data_parallel(model, device_ids=None, *args, **kwargs):
    """Wrap ``model`` in ``DistributedDataParallel`` when appropriate.

    The model is wrapped when ``device_ids`` is given or when this process has
    a distributed rank (``rank() != -1``); otherwise it is returned unchanged.
    Any extra positional/keyword arguments are forwarded to
    ``DistributedDataParallel``.
    """
    if device_ids is None:
        is_distributed = rank() != -1
        if not is_distributed:
            return model

    # NOTE(review): DistributedDataParallel expects an initialized process
    # group; passing ``device_ids`` in a non-distributed run (rank() == -1)
    # presumably fails inside torch — confirm this is intended.
    return torch.nn.parallel.DistributedDataParallel(model, device_ids, *args, **kwargs)