|
- # Copyright (c) Open-MMLab. All rights reserved.
- import torch
- from torch.nn.parallel._functions import Scatter as OrigScatter
-
- from ._functions import Scatter
- from .data_container import DataContainer
-
-
def scatter(inputs, target_gpus, dim=0):
    """Scatter inputs to target gpus.

    The only difference from the original :func:`torch.nn.parallel.scatter`
    is the added support for :class:`~mmcv.parallel.DataContainer`.

    Args:
        inputs: Object to scatter.
            * Tensors are scattered per target.
            * ``DataContainer`` objects are unwrapped and scattered, or
              returned as-is when ``cpu_only`` is set.
            * Non-empty tuples, lists and dicts are handled recursively.
            * Any other object, including an empty container, is repeated
              by reference once per entry of ``target_gpus``.
        target_gpus (list[int]): Target device ids. Exactly ``[-1]``
            selects the CPU path, which uses this package's own ``Scatter``
            instead of torch's.
        dim (int): Dimension passed to torch's scatter for tensors on the
            GPU path. Defaults to 0.

    Returns:
        A sequence with one element per scattered chunk. Container inputs
        and unhandled objects yield a ``list``: tuples become one tuple per
        chunk, lists one list per chunk, and dicts one mapping of the same
        type per chunk. For a bare tensor or ``DataContainer``, the result
        is whatever the underlying scatter returns.
    """

    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            if target_gpus != [-1]:
                return OrigScatter.apply(target_gpus, None, dim, obj)
            # For CPU inference we use the self-implemented scatter.
            return Scatter.forward(target_gpus, obj)
        if isinstance(obj, DataContainer):
            if obj.cpu_only:
                # Returned unchanged; presumably the data already holds one
                # item per target device -- TODO(review): confirm with the
                # DataContainer/collate contract.
                return obj.data
            return Scatter.forward(target_gpus, obj.data)
        if isinstance(obj, tuple) and len(obj) > 0:
            # Scatter each element, then transpose so that every chunk gets
            # a tuple of its own pieces.
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(chunk) for chunk in zip(*map(scatter_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            # Scatter the (key, value) pairs, then rebuild a mapping of the
            # original type for each chunk.
            chunks = zip(*map(scatter_map, obj.items()))
            return [type(obj)(chunk) for chunk in chunks]
        # Unhandled objects and empty containers are repeated by reference,
        # once per target.
        return [obj for _ in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has
    # references to a closure that has a reference to the scatter_map cell
    # (because the fn is recursive). To avoid this reference cycle, we set
    # the function to None, clearing the cell.
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None
-
-
def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
    """Scatter positional inputs and a kwargs dictionary to target gpus.

    Both arguments go through :func:`scatter` (skipped when falsy, in which
    case an empty list stands in). The shorter of the two results is then
    padded so that both have the same length: positional inputs are padded
    with empty tuples, and kwargs with fresh empty dicts.

    Args:
        inputs: Positional arguments to scatter.
        kwargs: Keyword-argument dict to scatter.
        target_gpus (list[int]): Target device ids, forwarded to
            :func:`scatter`.
        dim (int): Scatter dimension, forwarded to :func:`scatter`.

    Returns:
        tuple: ``(inputs, kwargs)``, each converted to a tuple of equal
        length.
    """
    split_inputs = scatter(inputs, target_gpus, dim) if inputs else []
    split_kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []

    # Positive: inputs are short. Negative: kwargs are short.
    shortfall = len(split_kwargs) - len(split_inputs)
    if shortfall > 0:
        split_inputs.extend([()] * shortfall)
    elif shortfall < 0:
        # Build distinct dicts so no two chunks share one mutable object.
        split_kwargs.extend([{} for _ in range(-shortfall)])

    return tuple(split_inputs), tuple(split_kwargs)
|