# Copyright (c) Open-MMLab. All rights reserved.
from itertools import chain

from torch.nn.parallel import DataParallel

from .scatter_gather import scatter_kwargs


class MMDataParallel(DataParallel):
- """The DataParallel module that supports DataContainer.
-
- MMDataParallel has two main differences with PyTorch DataParallel:
-
- - It supports a custom type :class:`DataContainer` which allows more
- flexible control of input data during both GPU and CPU inference.
- - It implement two more APIs ``train_step()`` and ``val_step()``.
-
- Args:
- module (:class:`nn.Module`): Module to be encapsulated.
- device_ids (list[int]): Device IDS of modules to be scattered to.
- Defaults to None when GPU is not available.
- output_device (str | int): Device ID for output. Defaults to None.
- dim (int): Dimension used to scatter the data. Defaults to 0.
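
    Example (a minimal usage sketch, assuming a single visible GPU)::

        >>> import torch
        >>> import torch.nn as nn
        >>> model = nn.Linear(2, 2).cuda(0)
        >>> model = MMDataParallel(model, device_ids=[0])
        >>> out = model(torch.rand(4, 2))  # input is scattered to GPU 0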
- """
-
- def __init__(self, *args, dim=0, **kwargs):
- super(MMDataParallel, self).__init__(*args, dim=dim, **kwargs)
- self.dim = dim
-
- def forward(self, *inputs, **kwargs):
- """Override the original forward function.
-
- The main difference lies in the CPU inference where the datas in
- :class:`DataContainers` will still be gathered.
- """
- if not self.device_ids:
- # We add the following line thus the module could gather and
- # convert data containers as those in GPU inference
- inputs, kwargs = self.scatter(inputs, kwargs, [-1])
- return self.module(*inputs[0], **kwargs[0])
- else:
- return super().forward(*inputs, **kwargs)

    def scatter(self, inputs, kwargs, device_ids):
        """Scatter inputs to devices with support for DataContainer."""
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def train_step(self, *inputs, **kwargs):
        if not self.device_ids:
            # Scatter to "device" -1 so that the module can gather and
            # convert data containers, just as in GPU inference.
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return self.module.train_step(*inputs[0], **kwargs[0])

        assert len(self.device_ids) == 1, \
            ('MMDataParallel only supports single GPU training; if you '
             'need to train with multiple GPUs, please use '
             'MMDistributedDataParallel instead.')

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError(
                    'module must have its parameters and buffers '
                    f'on device {self.src_device_obj} (device_ids[0]) but '
                    f'found one of them on device: {t.device}')

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        return self.module.train_step(*inputs[0], **kwargs[0])

    def val_step(self, *inputs, **kwargs):
        if not self.device_ids:
            # Scatter to "device" -1 so that the module can gather and
            # convert data containers, just as in GPU inference.
            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
            return self.module.val_step(*inputs[0], **kwargs[0])

        assert len(self.device_ids) == 1, \
            ('MMDataParallel only supports single GPU training; if you '
             'need to train with multiple GPUs, please use '
             'MMDistributedDataParallel instead.')

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError(
                    'module must have its parameters and buffers '
                    f'on device {self.src_device_obj} (device_ids[0]) but '
                    f'found one of them on device: {t.device}')

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        return self.module.val_step(*inputs[0], **kwargs[0])
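

if __name__ == '__main__':
    # A minimal smoke-test sketch of the CPU code path. On a CPU-only
    # machine ``device_ids`` falls back to an empty list, so ``train_step``
    # scatters to "device" -1 and calls the wrapped module directly.
    # ``ToyModel`` is a hypothetical module written only for illustration.
    import torch
    import torch.nn as nn

    class ToyModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(2, 2)

        def forward(self, x):
            return self.fc(x)

        def train_step(self, x):
            # By convention, ``train_step`` returns a dict with the loss.
            return {'loss': self(x).sum()}

    model = MMDataParallel(ToyModel())
    print(model.train_step(torch.rand(4, 2)))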