|
- from PIL import ImageFile
- from mindspore import dtype as mstype
- import mindspore.dataset as de
- import mindspore.dataset.vision.c_transforms as vision_C
- import mindspore.dataset.transforms.c_transforms as normal_C
-
- ImageFile.LOAD_TRUNCATED_IMAGES = True
-
def create_imagenet_dataset(mode, data_dir, image_size, per_batch_size,
                            rank=0,
                            group_size=1,
                            num_parallel_workers=2,
                            sampler=None,
                            class_indexing=None,
                            drop_remainder=True,
                            transform=None,
                            target_transform=None):
    """Build a MindSpore ImageNet-style dataset pipeline from an image folder.

    Args:
        mode (str): 'train' enables random augmentation and shuffling;
            any other value uses deterministic resize + center-crop eval transforms.
        data_dir (str): Root directory laid out as one sub-folder per class.
        image_size (int): Output spatial size (height == width) after crop/resize.
        per_batch_size (int): Batch size per device.
        rank (int): Shard id of this process (used only when ``sampler`` is None).
        group_size (int): Total number of shards (used only when ``sampler`` is None).
        num_parallel_workers (int): Worker threads for dataset ops; also set globally.
        sampler: Optional custom sampler. MindSpore forbids combining a sampler
            with ``shuffle``/``num_shards``, so when one is given the
            shuffle/shard arguments are not passed.
        class_indexing (dict): Optional explicit class-name -> label mapping.
        drop_remainder (bool): Drop the final incomplete batch.
        transform: Optional list of image ops replacing the built-in defaults.
        target_transform: Optional list of label ops replacing the default
            int32 type-cast.

    Returns:
        mindspore.dataset.Dataset: batched dataset yielding (image, label).
    """
    # NOTE: this mutates global dataset config, affecting other pipelines too.
    de.config.set_num_parallel_workers(num_parallel_workers)

    # Standard ImageNet channel statistics, scaled to the 0-255 pixel range
    # because Normalize here runs before any 0-1 rescaling.
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    shuffle = mode == 'train'

    if transform is None:
        if mode == 'train':
            transform_img = [
                vision_C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
                vision_C.RandomHorizontalFlip(prob=0.5),
                vision_C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4),
                vision_C.Normalize(mean=mean, std=std),
                vision_C.HWC2CHW()
            ]
        else:
            # Eval: resize the short side proportionally (256 when image_size
            # is 224), then center-crop to the target size.
            resize_size = int(256 * (image_size / 224))
            transform_img = [
                vision_C.Decode(),
                vision_C.Resize((resize_size, resize_size)),
                vision_C.CenterCrop(image_size),
                vision_C.Normalize(mean=mean, std=std),
                vision_C.HWC2CHW()
            ]
    else:
        transform_img = transform

    if target_transform is None:
        # Labels arrive as int64 by default; most MindSpore losses expect int32.
        transform_label = [
            normal_C.TypeCast(mstype.int32)
        ]
    else:
        transform_label = target_transform

    # MindSpore raises if `sampler` is combined with `shuffle` or
    # `num_shards`/`shard_id`, so only pass one of the two configurations.
    dataset_kwargs = dict(num_parallel_workers=num_parallel_workers,
                          class_indexing=class_indexing)
    if sampler is not None:
        dataset_kwargs['sampler'] = sampler
    else:
        dataset_kwargs['shuffle'] = shuffle
        dataset_kwargs['num_shards'] = group_size
        dataset_kwargs['shard_id'] = rank

    de_dataset = de.ImageFolderDataset(data_dir, **dataset_kwargs)

    de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=num_parallel_workers, operations=transform_img)
    de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=num_parallel_workers, operations=transform_label)

    de_dataset = de_dataset.project(columns=["image", "label"])
    de_dataset = de_dataset.batch(per_batch_size, drop_remainder=drop_remainder)

    return de_dataset
|