|
- import os
- import mindspore
- from config import convmixer_cfg as cfg
- import mindspore.dataset.vision.py_transforms as py_vision
- import mindspore.dataset.vision.c_transforms as c_vision
- from mindspore.dataset.transforms import py_transforms
- from mindspore.dataset.transforms import c_transforms
- from mindspore.communication import get_rank, get_group_size
- from mindspore import dtype
-
-
def create_dataset(data_path=r'D:\imagenet-1k\ImageNetdata\train', mode='train', batch_size=2, repeat_size=1,
                   num_parallel_workers=6):
    """
    Create the image-folder dataset pipeline for training or evaluation.

    Args:
        data_path: Root directory of the dataset, read with ImageFolderDataset.
        mode: Pipeline to build, one of 'train', 'val' or 'test'. 'train' applies a
            random crop and a random horizontal flip; 'val' and 'test' only decode
            and resize.
        batch_size: Number of samples per batch (an incomplete last batch is dropped).
        repeat_size: Number of times the batched dataset is repeated.
        num_parallel_workers: Number of workers used by the map operations.

    Returns:
        The mapped, shuffled, batched and repeated dataset.

    Raises:
        ValueError: If mode is not 'train', 'val' or 'test'.
    """
    # Reject unknown modes up front; otherwise no transform list would be defined
    # and the failure would only surface later as an unbound-name error.
    if mode not in ('train', 'val', 'test'):
        raise ValueError(f"mode must be 'train', 'val' or 'test', got {mode!r}")

    print(os.path.dirname(os.path.abspath(__file__)))
    print(os.listdir(data_path))

    # Each process reads only its own shard of the data.
    # NOTE(review): these calls presumably require mindspore.communication.init()
    # to have been run first - confirm before using this on a single device.
    num_shards = get_group_size()
    shard_id = get_rank()
    imagenet_dataset = mindspore.dataset.ImageFolderDataset(data_path, num_shards=num_shards, shard_id=shard_id)

    # Target image size and pixel normalisation parameters.
    image_size = (cfg.image_height, cfg.image_width)
    rescale = 1.0 / 255.0
    shift = 0.0
    # NOTE(review): 0.1307 / 0.3081 look like the MNIST mean/std rather than
    # ImageNet statistics - confirm they are intended for this dataset.
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # Mode-specific head of the image pipeline.
    if mode == 'train':
        head_ops = [
            c_vision.RandomCropDecodeResize(size=image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            c_vision.RandomHorizontalFlip(prob=0.5),
        ]
    else:
        # 'val' and 'test' share the same deterministic preprocessing.
        head_ops = [c_vision.Decode()]

    # Tail shared by every mode.
    trans = head_ops + [
        # NOTE(review): in 'train' mode the crop op above is already given the
        # target size, so this resize is probably redundant; kept so the
        # pipeline stays unchanged.
        c_vision.Resize(image_size),
        c_transforms.TypeCast(dtype.float32),
        c_vision.Rescale(rescale, shift),  # scale pixel values by 1/255
        c_vision.Rescale(rescale_nml, shift_nml),  # then apply (x - 0.1307) / 0.3081
        c_vision.HWC2CHW(),  # (height, width, channel) -> (channel, height, width) to fit the network
    ]

    label_typecast_op = c_transforms.TypeCast(dtype.int32)  # cast labels to int32 to fit the network

    # Apply the image and label operations.
    imagenet_dataset = imagenet_dataset.map(input_columns="image", num_parallel_workers=num_parallel_workers,
                                            operations=trans)
    imagenet_dataset = imagenet_dataset.map(input_columns="label", num_parallel_workers=num_parallel_workers,
                                            operations=label_typecast_op)

    # Shuffle, batch and repeat.
    buffer_size = 10000
    imagenet_dataset = imagenet_dataset.shuffle(buffer_size=buffer_size)
    imagenet_dataset = imagenet_dataset.batch(batch_size, drop_remainder=True)
    imagenet_dataset = imagenet_dataset.repeat(repeat_size)

    return imagenet_dataset
-
-
if __name__ == '__main__':
    # Manual smoke test: build the default ('train') pipeline and print its
    # size and number of classes.
    data_path = r'D:\imagenet-1k\ImageNetdata\train'
    imagenet_1k_dataset = create_dataset(data_path)
    # Use the documented get_dataset_size() accessor rather than the bare
    # `dataset_size` attribute.
    # NOTE(review): that attribute appears to be an internal cache that is
    # unset until get_dataset_size() has been called - confirm.
    print(imagenet_1k_dataset.get_dataset_size())
    print(imagenet_1k_dataset.num_classes())
|