|
- import mindspore
- import mindspore.dataset.vision.c_transforms as c_vision
- from mindspore.dataset.transforms import c_transforms
- from mindspore import dtype
- from config import convmixer_cfg as cfg
- from mindspore.communication.management import get_rank, get_group_size
-
def _resolve_shard(rank_id, rank_size):
    """Return the ``(rank_id, rank_size)`` pair used to shard the dataset.

    Explicitly supplied values are used as-is. A value left as ``None`` is
    taken from the distributed communication group via ``get_rank()`` /
    ``get_group_size()``. If those calls fail because communication has not
    been initialised (e.g. a single-device run), a single shard (rank 0 of 1)
    is assumed.
    """
    if rank_id is not None and rank_size is not None:
        return rank_id, rank_size
    try:
        group_rank, group_size = get_rank(), get_group_size()
    except (RuntimeError, ValueError):
        # NOTE(review): RuntimeError/ValueError are the exception types the
        # MindSpore docs list for get_rank/get_group_size when the collective
        # backend is unavailable or not initialised; confirm for the MindSpore
        # version in use.
        group_rank, group_size = 0, 1
    return (group_rank if rank_id is None else rank_id,
            group_size if rank_size is None else rank_size)


def create_dataset(data_path, batch_size=cfg.batch_size, repeat_size=cfg.repeat_size,
                   rank_id=None, rank_size=None, num_parallel_workers=6):
    """Create the sharded, augmented, batched image-folder dataset.

    Random crop/flip augmentation is always applied, so this pipeline is
    suited to training rather than evaluation.

    Args:
        data_path: Root directory read by ``ImageFolderDataset``.
        batch_size: Number of samples per batch; the last incomplete batch is
            dropped.
        repeat_size: Number of times the dataset is repeated (training
            repetition factor).
        rank_id: Shard index of this process. ``None`` (default) queries
            ``get_rank()``, falling back to 0 when distributed communication
            is not initialised.
        rank_size: Total number of shards. ``None`` (default) queries
            ``get_group_size()``, falling back to 1.
        num_parallel_workers: Number of parallel workers for each ``map``.

    Returns:
        The MindSpore dataset with an ``image`` column (float32, CHW layout)
        and a ``label`` column (int32), shuffled, batched and repeated.
    """
    rank_id, rank_size = _resolve_shard(rank_id, rank_size)

    imagenet_dataset = mindspore.dataset.ImageFolderDataset(
        data_path, num_shards=rank_size, shard_id=rank_id)

    # Target image size and pixel-normalisation parameters.
    resize_height, resize_width = cfg.image_height, cfg.image_width
    # First rescale maps uint8 pixels in [0, 255] to [0, 1].
    rescale = 1.0 / 255.0
    shift = 0.0
    # Second rescale standardises with mean 0.1307 / std 0.3081.
    # NOTE(review): these are the single-channel MNIST statistics (the file
    # appears to derive from a LeNet script); ImageNet pipelines normally use
    # per-channel mean/std. Kept unchanged so existing checkpoints stay
    # valid -- confirm intent.
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # Image transforms: decode + random-resized crop, random flip, normalise.
    trans = [
        c_vision.RandomCropDecodeResize(size=(resize_height, resize_width),
                                        scale=(0.08, 1.0), ratio=(0.75, 1.333)),
        c_vision.RandomHorizontalFlip(prob=0.5),
        c_transforms.TypeCast(dtype.float32),
        c_vision.Rescale(rescale, shift),  # scale pixel values to [0, 1]
        c_vision.Rescale(rescale_nml, shift_nml),
        # Change shape from (height, width, channel) to (channel, height, width)
        # to fit the network.
        c_vision.HWC2CHW(),
    ]

    labels_typecast_op = c_transforms.TypeCast(dtype.int32)

    # Apply the transforms to the image and label columns.
    imagenet_dataset = imagenet_dataset.map(input_columns="image",
                                            num_parallel_workers=num_parallel_workers,
                                            operations=trans)
    imagenet_dataset = imagenet_dataset.map(input_columns="label",
                                            num_parallel_workers=num_parallel_workers,
                                            operations=labels_typecast_op)

    # Shuffle, batch and repeat.
    # NOTE(review): shuffling here buffers cfg.buffer_size already-decoded
    # float32 images in memory. ImageFolderDataset already shuffles file order
    # by default, so a smaller buffer may suffice -- confirm memory budget.
    imagenet_dataset = imagenet_dataset.shuffle(buffer_size=cfg.buffer_size)
    imagenet_dataset = imagenet_dataset.batch(batch_size, drop_remainder=True)
    imagenet_dataset = imagenet_dataset.repeat(repeat_size)

    return imagenet_dataset
-
-
if __name__ == '__main__':
    import sys

    # Smoke test: build the pipeline and report its size and class count.
    # The dataset root may be passed as the first command-line argument;
    # otherwise the original hard-coded path is used.
    data_path = sys.argv[1] if len(sys.argv) > 1 else r'D:\imagenet-1k\ImageNetdata\train'
    imagenet_1k_dataset = create_dataset(data_path)
    # ``dataset_size`` is an internal cache attribute that may still be None;
    # get_dataset_size() actually computes the number of batches.
    print(imagenet_1k_dataset.get_dataset_size())
    print(imagenet_1k_dataset.num_classes())
|