|
- # Copyright (c) Open-MMLab. All rights reserved.
- import torch
- import torch.distributed as dist
- import torch.nn as nn
- from torch._utils import (_flatten_dense_tensors, _take_tensors,
- _unflatten_dense_tensors)
-
- from mmcv.utils import TORCH_VERSION
- from .registry import MODULE_WRAPPERS
- from .scatter_gather import scatter_kwargs
-
-
@MODULE_WRAPPERS.register_module()
class MMDistributedDataParallel(nn.Module):
    """Module wrapper for single-device-per-process distributed execution.

    On construction the wrapped module's state is broadcast from rank 0 so
    that every process holds the same tensors. At call time the positional
    and keyword arguments are scattered to the current CUDA device and then
    forwarded to the wrapped module's ``forward``, ``train_step`` or
    ``val_step``.

    Args:
        module: The module to wrap; stored as ``self.module``.
        dim: Dimension handed to ``scatter_kwargs`` when scattering inputs.
        broadcast_buffers: If true, the module's buffers are broadcast from
            rank 0 in addition to its ``state_dict`` values.
        bucket_cap_mb: Bucket capacity, in MiB, used to group tensors into
            coalesced broadcasts.
    """

    def __init__(self,
                 module,
                 dim=0,
                 broadcast_buffers=True,
                 bucket_cap_mb=25):
        super(MMDistributedDataParallel, self).__init__()
        self.module = module
        self.dim = dim
        self.broadcast_buffers = broadcast_buffers

        # MiB -> bytes; used as the size limit for each broadcast bucket.
        self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024
        self._sync_params()

    def _dist_broadcast_coalesced(self, tensors, buffer_size):
        """Broadcast ``tensors`` from rank 0, bucket by bucket, in place.

        Each bucket produced by ``_take_tensors`` is flattened into a single
        tensor, broadcast with source rank 0, then unflattened and copied
        back into the original tensors.
        """
        for bucket in _take_tensors(tensors, buffer_size):
            flat = _flatten_dense_tensors(bucket)
            dist.broadcast(flat, 0)
            received = _unflatten_dense_tensors(flat, bucket)
            for target, value in zip(bucket, received):
                target.copy_(value)

    def _sync_params(self):
        """Broadcast the module's state (and optionally buffers) from rank 0."""
        states = list(self.module.state_dict().values())
        if states:
            self._dist_broadcast_coalesced(states,
                                           self.broadcast_bucket_size)

        if not self.broadcast_buffers:
            return

        # Older torch releases expose buffers through ``_all_buffers``.
        if TORCH_VERSION < '1.0':
            buffer_source = self.module._all_buffers()
        else:
            buffer_source = self.module.buffers()
        buffers = [buf.data for buf in buffer_source]
        if buffers:
            self._dist_broadcast_coalesced(buffers,
                                           self.broadcast_bucket_size)

    def scatter(self, inputs, kwargs, device_ids):
        """Scatter ``inputs`` and ``kwargs`` across ``device_ids``.

        Thin wrapper over ``scatter_kwargs`` that supplies ``self.dim``.
        """
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def _scatter_to_current_device(self, inputs, kwargs):
        """Scatter the call arguments to the current CUDA device only.

        Returns whatever ``self.scatter`` returns for a one-element device
        list; callers use index 0 of each returned sequence.
        """
        current = torch.cuda.current_device()
        return self.scatter(inputs, kwargs, [current])

    def forward(self, *inputs, **kwargs):
        """Scatter the arguments and run the wrapped module's forward."""
        args, kw = self._scatter_to_current_device(inputs, kwargs)
        return self.module(*args[0], **kw[0])

    def train_step(self, *inputs, **kwargs):
        """Scatter the arguments and delegate to ``module.train_step``."""
        args, kw = self._scatter_to_current_device(inputs, kwargs)
        return self.module.train_step(*args[0], **kw[0])

    def val_step(self, *inputs, **kwargs):
        """Scatter the arguments and delegate to ``module.val_step``."""
        args, kw = self._scatter_to_current_device(inputs, kwargs)
        return self.module.val_step(*args[0], **kw[0])
|