|
- # Copyright (c) Open-MMLab. All rights reserved.
- import torch
- from torch.nn.parallel._functions import _get_stream
-
-
def scatter(input, devices, streams=None):
    """Scatter a tensor, or a (nested) list of tensors, across devices.

    Args:
        input: A ``torch.Tensor`` or a list whose items are tensors or
            further lists.
        devices: List of device ids. The value ``[-1]`` selects the
            non-CUDA path, where the tensor is not copied and only gains a
            leading dimension.
        streams: Optional list of streams, indexed in parallel with
            ``devices``. Defaults to ``None`` for every device.

    Returns:
        For a list input, a list with one scattered result per item. For a
        tensor input, the resulting tensor.

    Raises:
        Exception: If ``input`` is neither a list nor a ``torch.Tensor``.
    """
    if streams is None:
        streams = [None] * len(devices)

    if isinstance(input, list):
        # Ceiling division: consecutive runs of ``per_device`` items are
        # assigned to the same device/stream pair.
        per_device = (len(input) - 1) // len(devices) + 1
        scattered = []
        for position, item in enumerate(input):
            slot = position // per_device
            scattered.append(scatter(item, [devices[slot]], [streams[slot]]))
        return scattered

    if not isinstance(input, torch.Tensor):
        raise Exception(f'Unknown type {type(input)}.')

    tensor = input.contiguous()
    # TODO: copy to a pinned buffer first (if copying from CPU)
    # Empty tensors are copied without a dedicated stream.
    copy_stream = None if tensor.numel() <= 0 else streams[0]
    if devices == [-1]:
        # Unsqueeze the first dimension so the tensor's shape is the same
        # as that of tensors scattered with GPU.
        return tensor.unsqueeze(0)

    target = devices[0]
    with torch.cuda.device(target), torch.cuda.stream(copy_stream):
        tensor = tensor.cuda(target, non_blocking=True)
    return tensor
-
-
def synchronize_stream(output, devices, streams):
    """Make each device's current stream wait for its copy stream.

    Args:
        output: A ``torch.Tensor`` or a (nested) list of tensors, laid out
            the way ``scatter`` in this module produces them.
        devices: List of device ids, indexed in parallel with ``streams``.
        streams: List of streams that the copies were issued on.

    Raises:
        Exception: If ``output`` (or any nested item) is neither a list nor
            a ``torch.Tensor``.
    """
    if isinstance(output, list):
        # Use the same ceiling-division chunking as ``scatter`` so every
        # item is paired with the device/stream it was actually copied on.
        # Floor division skipped trailing items (and mis-paired others)
        # whenever len(output) was not a multiple of len(devices).
        chunk_size = (len(output) - 1) // len(devices) + 1
        for index, item in enumerate(output):
            slot = index // chunk_size
            synchronize_stream(item, [devices[slot]], [streams[slot]])
    elif isinstance(output, torch.Tensor):
        # Empty tensors are skipped entirely.
        if output.numel() != 0:
            with torch.cuda.device(devices[0]):
                main_stream = torch.cuda.current_stream()
                # Block the device's current stream until the copy stream
                # has finished, then register the tensor as used on it.
                main_stream.wait_stream(streams[0])
                output.record_stream(main_stream)
    else:
        raise Exception(f'Unknown type {type(output)}.')
-
-
def get_input_device(input):
    """Return the CUDA device index of the first CUDA tensor in ``input``.

    Args:
        input: A ``torch.Tensor`` or a (nested) list of tensors.

    Returns:
        The value of ``get_device()`` for the first CUDA tensor found, or
        ``-1`` when no tensor is on CUDA (including for an empty list).

    Raises:
        Exception: If ``input`` (or a nested item reached before a CUDA
            tensor is found) is neither a list nor a ``torch.Tensor``.
    """
    if isinstance(input, list):
        # Lazily walk the items and stop at the first one that reports a
        # CUDA device; fall back to -1 when none does.
        found = (device for device in map(get_input_device, input)
                 if device != -1)
        return next(found, -1)

    if not isinstance(input, torch.Tensor):
        raise Exception(f'Unknown type {type(input)}.')

    if input.is_cuda:
        return input.get_device()
    return -1
-
-
class Scatter:
    """Entry point that scatters an input and synchronises the copies."""

    @staticmethod
    def forward(target_gpus, input):
        """Scatter ``input`` onto ``target_gpus`` and return a tuple.

        Args:
            target_gpus: List of device ids; ``[-1]`` means no CUDA copy.
            input: A ``torch.Tensor`` or a (nested) list of tensors.

        Returns:
            The result of ``scatter`` passed through ``tuple()``.
        """
        source_device = get_input_device(input)

        # Only when no input tensor is on CUDA and a real GPU target is
        # requested are the copies issued on background streams.
        use_copy_streams = source_device == -1 and target_gpus != [-1]
        if use_copy_streams:
            streams = [_get_stream(gpu) for gpu in target_gpus]
        else:
            streams = None

        outputs = scatter(input, target_gpus, streams)

        # Make the current streams wait for the background copies.
        if streams is not None:
            synchronize_stream(outputs, target_gpus, streams)

        return tuple(outputs)
|