|
- import mindspore
- import mindspore.dataset as ds
- import mindspore.dataset.vision.c_transforms as c_trans
- from mindspore.dataset.transforms.c_transforms import TypeCast
- import mindspore as ms
- import pickle
- import os
- import numpy as np
-
class GetDatasetGenerator:
    """Source generator yielding CIFAR-10 ``(image, label)`` pairs.

    Loads the python-pickle batch files found under
    ``<root>/cifar-10-batches-py`` (train: data_batch_1..5, eval:
    test_batch) into memory as uint8 HWC images plus a flat label list.
    """

    def __init__(self, root, train=True):
        self.root = root
        self.train = train
        batch_files = (
            ['data_batch_%d' % i for i in range(1, 6)] if train
            else ['test_batch']
        )
        images, labels = [], []
        for name in batch_files:
            path = os.path.join(self.root, "cifar-10-batches-py", name)
            with open(path, 'rb') as fh:
                # latin1 keeps the py2-era pickles loadable on py3
                batch = pickle.load(fh, encoding='latin1')
            images.append(batch['data'])
            labels.extend(batch['labels'])
        # N x 3072 rows -> N x 3 x 32 x 32, then to HWC for the vision ops
        self.data = np.vstack(images).reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
        self.targets = labels

    def __getitem__(self, index):
        """Return the (image, label) pair at *index*."""
        return self.data[index], self.targets[index]

    def __len__(self):
        """Number of samples loaded."""
        return len(self.data)
-
-
def create_cifar_dataset(dataset_path, do_train, batch_size, image_size=(224, 224)):
    """Build a batched MindSpore pipeline over CIFAR-10.

    Args:
        dataset_path: directory containing ``cifar-10-batches-py``.
        do_train: True selects the training batches and enables
            shuffling plus random crop/flip augmentation.
        batch_size: samples per emitted batch.
        image_size: (H, W) that every image is resized to.

    Returns:
        A batched dataset with columns ``image`` (normalized float CHW)
        and ``label`` (int32).
    """
    generator = GetDatasetGenerator(dataset_path, do_train)
    data_set = ds.GeneratorDataset(
        generator, ["image", "label"], shuffle=do_train, num_parallel_workers=1)

    # Augmentation only in training; shared preprocessing for both splits.
    image_ops = []
    if do_train:
        image_ops.append(c_trans.RandomCrop((32, 32), (4, 4, 4, 4)))
        image_ops.append(c_trans.RandomHorizontalFlip(prob=0.5))
    image_ops.extend([
        c_trans.Resize(image_size),
        c_trans.Rescale(1.0 / 255.0, 0.0),
        # CIFAR-10 per-channel mean/std, in the rescaled [0, 1] range
        c_trans.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        c_trans.HWC2CHW(),
    ])

    data_set = data_set.map(operations=TypeCast(ms.int32), input_columns="label")
    data_set = data_set.map(operations=image_ops, input_columns="image")

    # Drop the incomplete final batch only while training.
    return data_set.batch(batch_size, drop_remainder=do_train)
-
def load_dataloader(data_root, batch_size):
    """Create the CIFAR-10 train/test pipelines plus an example-count map.

    Args:
        data_root: directory containing ``cifar-10-batches-py``.
        batch_size: samples per batch for both splits.

    Returns:
        ``(trainloader, testloader, num_examples)`` where ``num_examples``
        maps ``"trainset"``/``"testset"`` to ``batch_count * batch_size``.
    """
    trainloader = create_cifar_dataset(data_root, True, batch_size)
    testloader = create_cifar_dataset(data_root, False, batch_size)

    num_examples = {}
    for split, loader in (("trainset", trainloader), ("testset", testloader)):
        # NOTE(review): get_dataset_size() counts *batches*, so this product
        # overestimates the test split when its last batch is partial
        # (drop_remainder=False there) — confirm this approximation is intended.
        num_examples[split] = int(loader.get_dataset_size() * loader.batch_size)
    return trainloader, testloader, num_examples
-
-
-
if __name__ == '__main__':
    # NOTE(review): these paths point at the *binary* CIFAR-10 layout
    # ("cifar-10-batches-bin"), but GetDatasetGenerator above reads the
    # python-pickle layout ("cifar-10-batches-py") — confirm which
    # directory is intended. Both names are currently unused: this guard
    # defines them and executes nothing else.
    train_path = "dataset/cifar-10-batches-bin"
    eval_path = "dataset/cifar-10-verify-bin"
|