import os
import argparse

import moxing as mox
from mindspore import nn
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init
from mindspore.parallel import set_algo_parameters
from mindspore.train.model import Model
from mindspore.train.callback import LossMonitor, TimeMonitor
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

from src.dataset.dataset_DIV2K import get_dataset_DIV2K
from src.model.RRDB_Net import RRDBNet
from src.config import config
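

# Wraps the generator and its loss in a single Cell so Model() can compute
# the forward pass and the loss together in graph mode.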
class BuildTrainNetwork(nn.Cell):
    def __init__(self, network, criterion):
        super(BuildTrainNetwork, self).__init__()
        self.network = network
        self.criterion = criterion

    def construct(self, input_data, label):
        output = self.network(input_data)
        net_loss = self.criterion(output, label)
        return net_loss
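

# Command-line options; the defaults correspond to a single-device run.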
def parse_args():
    parser = argparse.ArgumentParser("Generator Pretrain")
    parser.add_argument("--data_url", type=str, default=None, help="Dataset path")
    parser.add_argument("--train_url", type=str, default=None, help="Train output path")
    # NOTE: argparse's type=bool treats any non-empty string as True.
    parser.add_argument("--modelArts_mode", type=bool, default=True)
    parser.add_argument(
        "--device_id",
        type=int,
        default=0,
        help="device id of GPU or Ascend. (Default: 0)",
    )
    parser.add_argument(
        "--aug", type=bool, default=True, help="Use augmentation for the dataset"
    )
    parser.add_argument("--loss_scale", type=float, default=1024.0, help="loss scale")
    parser.add_argument("--batch_size", type=int, default=4, help="batch_size")
    parser.add_argument("--epoch_size", type=int, default=20, help="epoch_size")
    parser.add_argument("--rank", type=int, default=0, help="local rank of distributed")
    parser.add_argument(
        "--group_size", type=int, default=1, help="world size of distributed"
    )
    parser.add_argument(
        "--save_steps", type=int, default=1000, help="steps interval for saving"
    )
    parser.add_argument(
        "--keep_checkpoint_max", type=int, default=20, help="max checkpoints to keep"
    )
    # distributed training
    parser.add_argument("--distribute", type=bool, default=False, help="run distribute")
    args, _ = parser.parse_known_args()
    return args
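

# On ModelArts, RANK_SIZE, DEVICE_ID and RANK_ID are injected into each
# process's environment by the platform; the /cache directories are
# node-local scratch space.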
def train(config):
    args_opt = parse_args()
    config_psnr = config.PSNR_config
    # ModelArts setup starts here
    device_num = int(os.getenv("RANK_SIZE"))
    device_id = int(os.getenv("DEVICE_ID"))
    rank_id = int(os.getenv("RANK_ID"))
    local_data_url = "/cache/data"
    local_train_url = "/cache/lwESRGAN"
    local_zipfolder_url = "/cache/tarzip"
    local_pretrain_url = "/cache/pretrain"
    obs_res_path = "obs://heu-535/pretrain"
    pretrain_filename = "vgg19_ImageNet.ckpt"
    filename = "DIV2K.zip"
    mox.file.make_dirs(local_train_url)
    context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")
    # init multi-card training; the env vars were already read above
    if args_opt.modelArts_mode:
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
        context.set_auto_parallel_context(
            device_num=device_num, parallel_mode=parallel_mode, gradients_mean=True
        )
        set_algo_parameters(elementwise_op_strategy_follow=True)
        context.set_auto_parallel_context(all_reduce_fusion_config=[85, 160])
        init()
    local_data_url = os.path.join(local_data_url, str(device_id))
    mox.file.make_dirs(local_data_url)
    local_zip_path = os.path.join(local_zipfolder_url, str(device_id), filename)
    print("device:%d, local_zip_path: %s" % (device_id, local_zip_path))
    obs_zip_path = os.path.join(args_opt.data_url, filename)
    mox.file.copy(obs_zip_path, local_zip_path)
    print(
        "====================== device %d copy end =================================\n"
        % (device_id)
    )
    unzip_command = "unzip -o %s -d %s" % (local_zip_path, local_data_url)
    os.system(unzip_command)
    print(
        "======================= device %d unzip end =================================\n"
        % (device_id)
    )
    # transfer the pretrained VGG checkpoint from OBS
    local_pretrain_url = os.path.join(local_zipfolder_url, pretrain_filename)
    obs_pretrain_url = os.path.join(obs_res_path, pretrain_filename)
    mox.file.copy(obs_pretrain_url, local_pretrain_url)
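    # NOTE: the VGG checkpoint staged above is not referenced again in this
    # script; it is presumably consumed by the later GAN (perceptual-loss)
    # training stage.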
    model_psnr = RRDBNet(
        in_nc=config_psnr["ch_size"],
        out_nc=config_psnr["ch_size"],
        nf=config_psnr["G_nf"],
        nb=config_psnr["G_nb"],
    )
    dataset, dataset_len = get_dataset_DIV2K(
        base_dir=local_data_url,
        downsample_factor=config_psnr["down_factor"],
        mode="train",
        aug=args_opt.aug,
        repeat=1,
        num_readers=4,
        shard_id=args_opt.rank,
        shard_num=args_opt.group_size,
        batch_size=args_opt.batch_size,
    )
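    # Piecewise-constant schedule from the PSNR config; the Adam betas
    # (0.9, 0.99) follow the original ESRGAN training setup.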
    lr = nn.piecewise_constant_lr(
        milestone=config_psnr["lr_steps"], learning_rates=config_psnr["lr"]
    )
    opt = nn.Adam(
        params=model_psnr.trainable_params(),
        learning_rate=lr,
        beta1=0.9,
        beta2=0.99,
        loss_scale=args_opt.loss_scale,
    )
    loss = nn.L1Loss()
    loss.add_flags_recursive(fp32=True)
    train_net = BuildTrainNetwork(model_psnr, loss)
    iters_per_check = dataset_len
    model = Model(train_net, optimizer=opt)
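    # Checkpointing is configured below but attached only on device 0, so
    # parallel workers do not write duplicate checkpoint files.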
    # callbacks for timing, loss logging and checkpoint saving
    time_cb = TimeMonitor()
    loss_cb = LossMonitor()
    cbs = [time_cb, loss_cb]
    config_ck = CheckpointConfig(
        save_checkpoint_steps=args_opt.save_steps,
        keep_checkpoint_max=args_opt.keep_checkpoint_max,
    )
    ckpoint_cb = ModelCheckpoint(
        prefix="psnr", directory=local_train_url, config=config_ck
    )
    if device_id == 0:
        cbs.append(ckpoint_cb)
    model.train(
        args_opt.epoch_size, dataset, callbacks=cbs, dataset_sink_mode=False,
    )
    # copy the training outputs back to OBS
    if device_id == 0:
        mox.file.copy_parallel(local_train_url, args_opt.train_url)
if __name__ == "__main__":
    train(config)
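
# Example launch (script name and OBS paths are placeholders):
#   python train_psnr.py --data_url obs://<bucket>/DIV2K --train_url obs://<bucket>/output
# In a multi-card ModelArts job the same entry point is started once per device,
# with RANK_SIZE / DEVICE_ID / RANK_ID set individually for each process.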