import math

import numpy as np
from mindspore import context
from mindspore import dataset as ds
from mindspore.context import ParallelMode
from mindspore.communication import get_rank, get_group_size


class DistributedSampler:
    """Shard dataset indices across devices for distributed training."""

    def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            print("num_replicas not passed in, defaulting world size to 1")
            num_replicas = 1
        if rank is None:
            print("rank not passed in, defaulting rank to 0")
            rank = 0
        self.dataset_size = dataset_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Every rank draws ceil(dataset_size / num_replicas) samples so all
        # ranks produce the same number of batches per epoch.
        self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        if self.shuffle:
            # Seed the permutation with the epoch so every rank draws the
            # same shuffle, while successive epochs shuffle differently.
            indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
            indices = indices.tolist()
            self.epoch += 1
        else:
            indices = list(range(self.dataset_size))

        # Pad with leading indices so the total divides evenly across ranks.
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # Subsample: rank k keeps every num_replicas-th index starting at k.
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

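# Worked sketch of the subsampling above (illustrative numbers, not values
# from the training setup): with dataset_size=10, num_replicas=2, and
# shuffle=False, each rank gets num_samples = ceil(10 / 2) = 5 indices:
#
#   DistributedSampler(10, num_replicas=2, rank=0, shuffle=False)  ->  [0, 2, 4, 6, 8]
#   DistributedSampler(10, num_replicas=2, rank=1, shuffle=False)  ->  [1, 3, 5, 7, 9]
#
# Rank k keeps every num_replicas-th index starting at offset k, so the ranks
# see disjoint, equally sized shards of the padded index list.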


def create_train_dataset(dataset, args):
    """Return the training dataset, sharded across devices under data parallelism."""
    parallel_mode = context.get_auto_parallel_context("parallel_mode")
    # num_parallel_workers=args.threads can additionally be passed to
    # GeneratorDataset below to parallelize data loading.
    if parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
        rank = get_rank()
        device_num = get_group_size()
        distributed_sampler = DistributedSampler(len(dataset), device_num, rank, shuffle=True)
        tr_loader = ds.GeneratorDataset(dataset, ["mixture", "lens", "sources"],
                                        sampler=distributed_sampler)
    else:
        # Single-device training: shuffle globally instead of sharding.
        tr_loader = ds.GeneratorDataset(dataset, ["mixture", "lens", "sources"],
                                        shuffle=True)
    tr_loader = tr_loader.batch(args.batch_size, drop_remainder=True)
    return tr_loader
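

# Minimal usage sketch, assuming a map-style dataset that yields the three
# columns above. ToyDataset and the bare args namespace are hypothetical
# stand-ins for the project's own dataset class and argument parser.
if __name__ == "__main__":
    from types import SimpleNamespace

    class ToyDataset:
        """Tiny random-access dataset exposing the expected three columns."""

        def __len__(self):
            return 8

        def __getitem__(self, index):
            mixture = np.zeros(16, dtype=np.float32)       # mixed waveform
            lens = np.array(16, dtype=np.int32)            # valid sample length
            sources = np.zeros((2, 16), dtype=np.float32)  # separated sources
            return mixture, lens, sources

    toy_args = SimpleNamespace(batch_size=4)
    loader = create_train_dataset(ToyDataset(), toy_args)
    for batch in loader.create_dict_iterator():
        print(batch["mixture"].shape)  # (4, 16)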