# Copyright (c) Open-MMLab. All rights reserved.
import torch
from torch.nn.parallel.distributed import (DistributedDataParallel,
                                           _find_tensors)

from mmcv.utils import TORCH_VERSION
from .scatter_gather import scatter_kwargs


class MMDistributedDataParallel(DistributedDataParallel):
    """The DDP module that supports DataContainer.

    MMDDP has two main differences from PyTorch DDP:

    - It supports a custom type :class:`DataContainer` which allows more
      flexible control of input data.
    - It implements two APIs ``train_step()`` and ``val_step()``.
    """

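    # ``scatter_kwargs`` below is mmcv's DataContainer-aware counterpart of
    # ``torch.nn.parallel.scatter_kwargs``: it slices tensors along ``dim``
    # across ``device_ids``, while DataContainers marked ``cpu_only`` stay on
    # the CPU instead of being moved to GPU.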
    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def train_step(self, *inputs, **kwargs):
        """train_step() API for module wrapped by DistributedDataParallel.

        This method is basically the same as
        ``DistributedDataParallel.forward()``, while replacing
        ``self.module.forward()`` with ``self.module.train_step()``.
        It is compatible with PyTorch 1.1 - 1.5.
        """
        if getattr(self, 'require_forward_param_sync', True):
            self._sync_params()
        if self.device_ids:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            if len(self.device_ids) == 1:
                output = self.module.train_step(*inputs[0], **kwargs[0])
            else:
                outputs = self.parallel_apply(
                    self._module_copies[:len(inputs)], inputs, kwargs)
                output = self.gather(outputs, self.output_device)
        else:
            output = self.module.train_step(*inputs, **kwargs)

        if torch.is_grad_enabled() and getattr(
                self, 'require_backward_grad_sync', True):
            if self.find_unused_parameters:
                self.reducer.prepare_for_backward(list(_find_tensors(output)))
            else:
                self.reducer.prepare_for_backward([])
        else:
            if TORCH_VERSION > '1.2':
                # Older DDP releases lack ``require_forward_param_sync``,
                # hence this version guard (and the ``getattr`` fallback
                # above).
                self.require_forward_param_sync = False
        return output

    def val_step(self, *inputs, **kwargs):
        """val_step() API for module wrapped by DistributedDataParallel.

        This method is basically the same as
        ``DistributedDataParallel.forward()``, while replacing
        ``self.module.forward()`` with ``self.module.val_step()``.
        It is compatible with PyTorch 1.1 - 1.5.
        """
        if getattr(self, 'require_forward_param_sync', True):
            self._sync_params()
        if self.device_ids:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            if len(self.device_ids) == 1:
                output = self.module.val_step(*inputs[0], **kwargs[0])
            else:
                outputs = self.parallel_apply(
                    self._module_copies[:len(inputs)], inputs, kwargs)
                output = self.gather(outputs, self.output_device)
        else:
            output = self.module.val_step(*inputs, **kwargs)

        if torch.is_grad_enabled() and getattr(
                self, 'require_backward_grad_sync', True):
            if self.find_unused_parameters:
                self.reducer.prepare_for_backward(list(_find_tensors(output)))
            else:
                self.reducer.prepare_for_backward([])
        else:
            if TORCH_VERSION > '1.2':
                self.require_forward_param_sync = False
        return output
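
# ----------------------------------------------------------------------------
# A minimal usage sketch, not part of the original module. It assumes the
# process was launched with ``torch.distributed.launch`` (so MASTER_ADDR,
# MASTER_PORT, RANK and WORLD_SIZE are set) and that one GPU is visible per
# process. ``_ToyModel`` is a hypothetical stand-in: whatever module you wrap
# must itself implement ``train_step()``/``val_step()``.
if __name__ == '__main__':
    import torch.distributed as dist
    import torch.nn as nn

    class _ToyModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(2, 2)

        def train_step(self, data):
            # mmcv-style runners typically expect train_step() to return a
            # dict, e.g. containing the loss.
            return {'loss': self.fc(data).sum()}

        def val_step(self, data):
            return {'loss': self.fc(data).sum()}

    dist.init_process_group(backend='nccl')  # default init_method='env://'
    device = torch.cuda.current_device()
    model = _ToyModel().cuda(device)
    ddp_model = MMDistributedDataParallel(model, device_ids=[device])

    # Inputs pass through scatter_kwargs() before reaching train_step();
    # backward() then triggers the gradient all-reduce prepared above.
    outputs = ddp_model.train_step(torch.randn(4, 2).cuda(device))
    outputs['loss'].backward()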