|
- # Copyright (c) Facebook, Inc. and its affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- # author: adefossez
-
- import logging
- import os
-
- import torch
- from torch.utils.data.distributed import DistributedSampler
- from torch.utils.data import DataLoader, Subset
- from torch.nn.parallel.distributed import DistributedDataParallel
-
logger = logging.getLogger(__name__)

# Distributed state of this process. The defaults describe a single,
# non-distributed process; `init` overwrites both when DDP is enabled.
rank, world_size = 0, 1
-
-
def init(args):
    """Initialize distributed data-parallel (DDP) state from ``args``.

    When ``args.ddp`` is truthy, ``args.rank`` and ``args.world_size`` are
    copied into the module-level ``rank`` and ``world_size``. If the resulting
    ``world_size`` is 1, the function returns without touching CUDA or
    ``torch.distributed``. Otherwise it selects CUDA device ``rank`` and
    creates the default process group. The group uses backend
    ``args.ddp_backend`` and rendezvouses through the file
    ``args.rendezvous_file``, given as an absolute ``file://`` URL.

    :param args: namespace-like object with attributes ``ddp``, ``rank``,
        ``world_size``, ``ddp_backend`` and ``rendezvous_file``. The last two
        are not read when the resulting world size is 1.
    :raises ValueError: if ``args.ddp`` is set but ``args.rank`` or
        ``args.world_size`` is ``None``.
    """
    global rank, world_size
    if args.ddp:
        # Explicit check rather than `assert`, which `python -O` strips out.
        if args.rank is None or args.world_size is None:
            raise ValueError(
                "Distributed training requires both rank and world_size, "
                f"got rank={args.rank!r}, world_size={args.world_size!r}")
        rank = args.rank
        world_size = args.world_size
    if world_size == 1:
        return
    # NOTE(review): the global rank doubles as the CUDA device index, which
    # presumably assumes a single node with one process per GPU -- confirm.
    torch.cuda.set_device(rank)
    torch.distributed.init_process_group(
        backend=args.ddp_backend,
        init_method='file://' + os.path.abspath(args.rendezvous_file),
        world_size=world_size,
        rank=rank)
    logger.debug("Distributed rendezvous went well, rank %d/%d", rank, world_size)
-
-
def average(metrics, count=1.):
    """Average a sequence of metrics across all processes.

    ``metrics`` is converted to a float32 CUDA tensor, so it is presumably a
    flat sequence of numbers. Each worker's values are weighted by ``count``.
    The result is ``sum_i(count_i * metrics_i) / sum_i(count_i)`` over all
    workers, returned as a Python list.

    When not distributed (``world_size == 1``), ``metrics`` is returned
    unchanged (the very same object).
    """
    if world_size != 1:
        # The trailing 1 becomes `count` after scaling. After the SUM
        # all-reduce it therefore holds the total weight across workers.
        weighted = torch.tensor([*metrics, 1], device='cuda', dtype=torch.float32)
        weighted.mul_(count)
        torch.distributed.all_reduce(weighted, op=torch.distributed.ReduceOp.SUM)
        totals, total_weight = weighted[:-1], weighted[-1]
        return (totals / total_weight).cpu().numpy().tolist()
    return metrics
-
-
def wrap(model):
    """Wrap ``model`` in DistributedDataParallel when running distributed.

    If ``world_size`` is 1, ``model`` is returned as-is. Otherwise the DDP
    wrapper uses the currently selected CUDA device (``init`` selects it with
    ``torch.cuda.set_device``) both as its sole device and as its output
    device.
    """
    if world_size == 1:
        return model
    device = torch.cuda.current_device()
    return DistributedDataParallel(model, device_ids=[device], output_device=device)
-
-
def barrier():
    """Wait for all processes to reach this point; no-op when not distributed."""
    if world_size <= 1:
        return
    torch.distributed.barrier()
-
-
def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
    """Create a data loader that is aware of distributed training.

    With a single process this is simply
    ``klass(dataset, *args, shuffle=shuffle, **kwargs)``.

    With several processes:

    * ``shuffle=True`` (a gradient will be computed): a ``DistributedSampler``
      for this process's ``rank`` out of ``world_size`` is attached. It shards
      and shuffles the data itself, so ``shuffle`` is not forwarded to
      ``klass``.
    * ``shuffle=False``: the dataset is sharded manually by taking every
      ``world_size``-th example starting at ``rank``. Unlike
      ``DistributedSampler``, this never repeats examples to pad shards, so
      each example is seen exactly once across processes. Shard lengths may
      differ by one.

    If a gradient is going to be computed you must set ``shuffle=True``.

    :param dataset: the dataset to be parallelized.
    :param args: extra positional arguments forwarded to ``klass``.
    :param shuffle: shuffle examples.
    :param klass: loader class to instantiate (``DataLoader`` by default).
    :param kwargs: extra keyword arguments forwarded to ``klass`` in every
        branch (e.g. ``batch_size``, ``num_workers``).
    :return: an instance of ``klass``.
    """
    if world_size == 1:
        return klass(dataset, *args, shuffle=shuffle, **kwargs)

    if shuffle:
        # Pass this module's rank/world_size explicitly so the sampler shards
        # exactly like the manual branch below.
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
        # `shuffle` is deliberately not forwarded: the sampler already shuffles.
        return klass(dataset, *args, sampler=sampler, **kwargs)

    # Manual, non-overlapping shard: DistributedSampler would otherwise
    # replicate some examples to even out shard sizes.
    shard = Subset(dataset, list(range(rank, len(dataset), world_size)))
    # Forward **kwargs here too, so options such as batch_size also apply to
    # the evaluation shard.
    return klass(shard, *args, shuffle=shuffle, **kwargs)
|