|
- import albumentations as A
- import mindspore.dataset.vision.py_transforms as py_vision
-
- # DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
- TRAIN_DIR = "maps/train"
- VAL_DIR = "maps/val"
- LEARNING_RATE = 0.0002
- BATCH_SIZE = 1
- NUM_WORKERS = 2
- IMAGE_SIZE = 256
- CHANNELS_IMG = 3
- L1_LAMBDA = 100
- LAMBDA_GP = 10
- NUM_EPOCHS = 100
- LOAD_MODEL = False
- SAVE_MODEL = True
- CHECKPOINT_DISC = "disc.pth.tar"
- CHECKPOINT_GEN = "gen.pth.tar"
-
- both_transform = A.Compose(
- [A.Resize(width=256, height=256),A.HorizontalFlip(p=0.5)], additional_targets={"image0": "image"},
- )
-
- transform_only_input = A.Compose(
- [
- A.HorizontalFlip(p=0.5),
- A.ColorJitter(p=0.2),
- A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
- py_vision.ToTensor(),
- ]
- )
-
- transform_only_mask = A.Compose(
- [
- A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
- py_vision.ToTensor(),
- ]
- )
-
-
-
-
-
-
-
-
- # Below is the modified code that does not use the albumentations library.
- # import torch
- # # import albumentations as A
- # # from albumentations.pytorch import ToTensorV2
- # import mindspore.dataset.vision.py_transforms as py_vision
- # import mindspore.dataset.transforms.py_transforms as py_transforms
- #
- #
- #
- # DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- # TRAIN_DIR = "maps/train"
- # VAL_DIR = "maps/val"
- # LEARNING_RATE = 2e-4
- # BATCH_SIZE = 16
- # NUM_WORKERS = 2
- # IMAGE_SIZE = 256
- # CHANNELS_IMG = 3
- # L1_LAMBDA = 100
- # LAMBDA_GP = 10
- # NUM_EPOCHS = 500
- # LOAD_MODEL = False
- # SAVE_MODEL = False
- # CHECKPOINT_DISC = "disc.pth.tar"
- # CHECKPOINT_GEN = "gen.pth.tar"
- #
- # both_transform = py_transforms.Compose(
- # [py_transforms.Resize(width=256, height=256),], additional_targets={"image0": "image"},
- # )
- #
- # transform_only_input = py_transforms.Compose(
- # [
- # py_vision.RandomHorizontalFlip(p=0.5),
- # py_vision.RandomColorAdjust(p=0.2),  # parameters do not correspond to those of ColorJitter()
- # py_vision.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # MindSpore has no max_pixel_value=255.0 parameter
- # py_vision.ToTensor(),
- # ]
- # )
- #
- # transform_only_mask = py_transforms.Compose(
- # [
- # py_vision.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
- # py_vision.ToTensor(),
- # ]
- # )
|