|
- import os
-
- os.environ['DEVICE_ID'] = "1"
- # import random
- import argparse
- from solver import Solver
- from dataloader import ImageFolder
- import mindspore.dataset as ds
-
-
def main(config):
    """Prepare output paths and datasets from ``config`` and start training.

    Args:
        config: Parsed command-line namespace. This function reads ``type``,
            ``result_path``, ``num_epochs``, ``decay_ratio``, ``train_path``,
            ``valid_path``, ``crop_size``, ``augmentation_prob`` and
            ``batch_size`` from it, and stores ``ckpt_save_root``,
            ``num_epochs_decay`` and ``steps`` back onto it before handing
            it to ``Solver``.
    """
    # A truthy --type selects "checkpoints_n", otherwise "checkpoints_p".
    ckpt_dir_name = "checkpoints_n" if config.type else "checkpoints_p"
    config.ckpt_save_root = os.path.join(config.result_path, ckpt_dir_name)
    print(">>>>>>>>>>>Save Path>>>>>>>>>>>>:", config.ckpt_save_root)
    # exist_ok avoids the check-then-create race; this also creates
    # result_path itself, which the config dump below relies on.
    os.makedirs(config.ckpt_save_root, exist_ok=True)
    config.num_epochs_decay = int(config.num_epochs * config.decay_ratio)
    print(config)

    # Save config. The file is opened in append mode, so each run is
    # terminated with a newline to keep successive runs on separate lines.
    config_txt_path = os.path.join(config.result_path, "config.txt")
    with open(config_txt_path, 'a') as config_file:
        config_file.write(str(config) + "\n")

    column_names = ["input", "target", "paths"]

    # ----- load data for training ----- #
    train_image = ImageFolder(root=config.train_path,
                              crop_size=config.crop_size,
                              mode='train',
                              augmentation_prob=config.augmentation_prob,
                              type=config.type)
    train_loader = ds.GeneratorDataset(train_image,
                                       column_names=column_names,
                                       shuffle=True)
    train_loader = train_loader.batch(config.batch_size)
    # Number of batches per epoch, recorded on config for the solver.
    config.steps = train_loader.get_dataset_size()

    # ----- load data for validation ----- #
    # NOTE(review): validation is built with mode='train' and the training
    # augmentation probability; confirm against ImageFolder whether a
    # dedicated validation mode should be used here instead.
    valid_image = ImageFolder(root=config.valid_path,
                              crop_size=config.crop_size,
                              mode='train',
                              augmentation_prob=config.augmentation_prob,
                              type=config.type)
    valid_loader = ds.GeneratorDataset(valid_image,
                                       column_names=column_names,
                                       shuffle=False)
    valid_loader = valid_loader.batch(config.batch_size)

    solver = Solver(config, train_loader, valid_loader)

    # Train and sample the images
    solver.train()
-
-
if __name__ == '__main__':
    # Command-line options as (flag, value type, default) rows. Every option
    # is optional; the rows are registered in the order listed here.
    cli_options = (
        ('--img_ch', int, 3),
        ('--output_ch', int, 1),
        ('--type', int, 0),
        ('--num_epochs', int, 90),
        ('--decay_ratio', float, 0.5),
        ('--batch_size', int, 4),
        ('--num_workers', int, 8),
        ('--lr', float, 1e-6),
        ('--augmentation_prob', float, 1.),
        ('--crop_size', int, 512),

        ('--log_step', int, 2),
        ('--val_step', int, 2),
        ('--ckpt_save_freq', int, 1),

        ('--model_type', str, 'UCSRNet'),
        ('--train_path', str, 'BCData/images/train/'),
        ('--valid_path', str, 'BCData/images/validation/'),
        ('--result_path', str, 'results'),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value)

    # Parse sys.argv and hand the resulting namespace to the training entry point.
    config = parser.parse_args()
    main(config)
|