|
- # Copyright (c) Open-MMLab. All rights reserved.
- from collections.abc import Mapping, Sequence
-
- import torch
- import torch.nn.functional as F
- from torch.utils.data.dataloader import default_collate
-
- from .data_container import DataContainer
-
-
def _stack_chunk(chunk):
    """Stack the tensors held by one chunk of DataContainers.

    If the first DataContainer of the chunk has ``pad_dims`` set, the last
    ``pad_dims`` dimensions of every sample are right-padded with that
    sample's ``padding_value`` up to the chunk-wide maximum before stacking.
    All leading dimensions must already agree with the first sample.

    Args:
        chunk (Sequence[DataContainer]): Consecutive samples for one device.

    Returns:
        The result of ``default_collate`` on the (possibly padded) tensors.
    """
    head = chunk[0]
    assert isinstance(head.data, torch.Tensor)
    pad_dims = head.pad_dims

    if pad_dims is None:
        return default_collate([sample.data for sample in chunk])

    ndim = head.dim()
    assert ndim > pad_dims
    # max_shape[k] tracks the largest size seen along dimension -(k + 1).
    max_shape = [head.size(-dim) for dim in range(1, pad_dims + 1)]
    for sample in chunk:
        # Only the trailing ``pad_dims`` dimensions may differ in size.
        for dim in range(0, ndim - pad_dims):
            assert head.size(dim) == sample.size(dim)
        for dim in range(1, pad_dims + 1):
            max_shape[dim - 1] = max(max_shape[dim - 1], sample.size(-dim))

    padded_samples = []
    for sample in chunk:
        # F.pad takes (left, right) pairs starting from the last dimension;
        # only the right-hand side of each padded dimension is filled.
        pad = [0] * (pad_dims * 2)
        for dim in range(1, pad_dims + 1):
            pad[2 * dim - 1] = max_shape[dim - 1] - sample.size(-dim)
        padded_samples.append(
            F.pad(sample.data, pad, value=sample.padding_value))
    return default_collate(padded_samples)


def _collate_data_containers(batch, samples_per_gpu):
    """Collate a batch of DataContainers into a single DataContainer.

    The batch is split into consecutive chunks of ``samples_per_gpu``
    samples. The ``cpu_only``/``stack``/``padding_value`` attributes of the
    first element decide how every chunk is combined.
    """
    first = batch[0]
    chunks = [
        batch[i:i + samples_per_gpu]
        for i in range(0, len(batch), samples_per_gpu)
    ]
    if first.cpu_only:
        stacked = [[sample.data for sample in chunk] for chunk in chunks]
        return DataContainer(
            stacked, first.stack, first.padding_value, cpu_only=True)
    if first.stack:
        stacked = [_stack_chunk(chunk) for chunk in chunks]
    else:
        stacked = [[sample.data for sample in chunk] for chunk in chunks]
    return DataContainer(stacked, first.stack, first.padding_value)


def collate(batch, samples_per_gpu=1):
    """Puts each data field into a tensor/DataContainer with outer dimension
    batch size.

    Extend default_collate to add support for
    :class:`~mmcv.parallel.DataContainer`. There are 3 cases.

    1. cpu_only = True, e.g., meta data
    2. cpu_only = False, stack = True, e.g., images tensors
    3. cpu_only = False, stack = False, e.g., gt bboxes

    Args:
        batch (Sequence): Samples to collate. Elements may be DataContainers,
            (non-string) sequences, mappings, or anything accepted by
            ``default_collate``.
        samples_per_gpu (int): Number of consecutive samples grouped
            together when the elements are DataContainers. Default: 1.

    Returns:
        A DataContainer holding one entry per chunk when the elements are
        DataContainers; a list (one entry per position) for sequence
        elements; a dict for mapping elements; otherwise whatever
        ``default_collate`` returns.

    Raises:
        TypeError: If ``batch`` is not a Sequence.
    """
    if not isinstance(batch, Sequence):
        # Not every object has ``dtype``; fall back to its type so the
        # intended TypeError is raised rather than an AttributeError.
        raise TypeError(
            f'{getattr(batch, "dtype", type(batch))} is not supported.')

    elem = batch[0]
    if isinstance(elem, DataContainer):
        return _collate_data_containers(batch, samples_per_gpu)
    if isinstance(elem, Sequence) and not isinstance(elem, (str, bytes)):
        # Transpose a batch of sequences into a sequence of batches.
        # str/bytes are Sequences too, but transposing them recurses
        # forever (a 1-char str is a Sequence of itself), so they fall
        # through to default_collate, which returns them as a list.
        transposed = zip(*batch)
        return [collate(samples, samples_per_gpu) for samples in transposed]
    if isinstance(elem, Mapping):
        return {
            key: collate([d[key] for d in batch], samples_per_gpu)
            for key in elem
        }
    return default_collate(batch)
|