|
- import json
-
- from monai.data import CacheDataset, DataLoader
- from monai.transforms import (
- Compose,
- EnsureTyped,
- EnsureChannelFirstd,
- LoadImaged,
- ScaleIntensityRanged,
- RandSpatialCropd,
- RandGaussianNoised,
- RandScaleIntensityd,
- RandShiftIntensityd,
- ResizeWithPadOrCropd,
- )
-
def get_data(args):
    """Load the training and validation file lists from a JSON split file.

    Args:
        args: Object whose ``data_dir`` attribute is the path of a JSON file
            containing the top-level keys ``'train_list'`` and
            ``'valid_list'``.

    Returns:
        tuple: ``(train_files, valid_files)`` exactly as stored under those
        two keys.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If either expected key is missing.
    """
    # The context manager guarantees the handle is closed even when parsing
    # fails; the previous bare open() never closed it.
    with open(args.data_dir, "r") as split_file:
        splits = json.load(split_file)

    train_files = splits['train_list']
    valid_files = splits['valid_list']

    print('Train files: {}'.format(len(train_files)))
    print('Valid files: {}'.format(len(valid_files)))
    return train_files, valid_files
-
-
def get_transforms(args):
    """Build the training and validation transform pipelines.

    Both pipelines start with the same preprocessing transform instances, in
    this order: ``LoadImaged``, ``EnsureChannelFirstd``,
    ``ScaleIntensityRanged`` (image only, ``[args.a_min, args.a_max]`` mapped
    to ``[0.0, 1.0]`` with ``clip=True``), ``RandSpatialCropd`` with
    ``roi_size=args.image_size`` and ``random_size=False``, then
    ``ResizeWithPadOrCropd`` to ``args.image_size``. Both end with
    ``EnsureTyped``. The training pipeline additionally applies the
    image-only intensity augmentations between those two stages.

    Args:
        args: Object providing ``a_min``, ``a_max`` and ``image_size``.

    Returns:
        tuple: ``(train_transform, valid_transform)``, two ``Compose``
        objects.
    """
    paired_keys = ["image", "label"]

    # Preprocessing shared by training and validation (same instances).
    preprocessing = [
        LoadImaged(keys=paired_keys),
        EnsureChannelFirstd(keys=paired_keys),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=args.a_min,
            a_max=args.a_max,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        RandSpatialCropd(
            keys=paired_keys,
            roi_size=args.image_size,
            random_size=False,
        ),
        ResizeWithPadOrCropd(keys=paired_keys, spatial_size=args.image_size),
    ]

    # Data augmentation, applied to the image during training only.
    augmentation = [
        RandGaussianNoised(keys="image", prob=0.5, std=0.05),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
    ]

    # Final step of both pipelines.
    finalize = [EnsureTyped(keys=paired_keys)]

    train_transform = Compose([*preprocessing, *augmentation, *finalize])
    valid_transform = Compose([*preprocessing, *finalize])
    return train_transform, valid_transform
-
-
def get_loader(args):
    """Create the cached datasets and data loaders for training and validation.

    File lists come from ``get_data(args)`` and transform pipelines from
    ``get_transforms(args)``. Each split is wrapped in a ``CacheDataset`` and
    then in a ``DataLoader``; only the training loader is created with
    ``shuffle=True``.

    Args:
        args: Object providing ``cache_rate``, ``workers`` and ``batch_size``,
            plus whatever ``get_data`` and ``get_transforms`` read from it.

    Returns:
        tuple: ``(train_ds, train_loader, val_loader)``.
    """
    train_items, valid_items = get_data(args)
    train_pipeline, valid_pipeline = get_transforms(args)

    # Dataset options common to both splits.
    dataset_options = {
        "cache_rate": args.cache_rate,
        "num_workers": args.workers,
    }

    # The validation dataset is constructed before the training dataset,
    # matching the original construction order.
    valid_dataset = CacheDataset(
        data=valid_items, transform=valid_pipeline, **dataset_options
    )
    train_dataset = CacheDataset(
        data=train_items, transform=train_pipeline, **dataset_options
    )

    training_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
    )
    validation_loader = DataLoader(
        valid_dataset,
        batch_size=args.batch_size,
        num_workers=args.workers,
    )

    return train_dataset, training_loader, validation_loader
|