"""Pretrain the RRDBNet generator (PSNR-oriented stage) on the DIV2K dataset
with an L1 loss, using MindSpore's Model and callback APIs."""
import argparse

from mindspore import context, nn
from mindspore.communication.management import init
from mindspore.train.callback import (
    CheckpointConfig,
    LossMonitor,
    ModelCheckpoint,
    TimeMonitor,
)
from mindspore.train.model import Model

from src.config import config
from src.dataset.dataset_DIV2K import get_dataset_DIV2K
from src.model.RRDB_Net import RRDBNet


class BuildTrainNetwork(nn.Cell):
    """Wrap the generator and its criterion so Model can drive training."""

    def __init__(self, network, criterion):
        super(BuildTrainNetwork, self).__init__()
        self.network = network
        self.criterion = criterion

    def construct(self, input_data, label):
        output = self.network(input_data)
        net_loss = self.criterion(output, label)
        return net_loss


def parse_args():
    parser = argparse.ArgumentParser("Generator Pretrain")
    parser.add_argument(
        "--device_id",
        type=int,
        default=0,
        help="device id of GPU or Ascend. (Default: 0)",
    )
    parser.add_argument("--loss_scale", type=float,
                        default=1024.0, help="loss scale")
    # Caveat: argparse's type=bool treats any non-empty string (even "False")
    # as True; pass an empty string to disable augmentation.
    parser.add_argument(
        "--aug", type=bool, default=True, help="use data augmentation for dataset"
    )
    parser.add_argument("--batch_size", type=int, default=4, help="batch_size")
    parser.add_argument("--epoch_size", type=int, default=20, help="epoch_size")
    parser.add_argument("--rank", type=int, default=0, help="local rank of distributed")
    parser.add_argument(
        "--group_size", type=int, default=1, help="world size of distributed"
    )
    parser.add_argument(
        "--save_steps", type=int, default=1000, help="steps interval for saving"
    )
    parser.add_argument(
        "--keep_checkpoint_max", type=int, default=20, help="max checkpoint for saving"
    )
    args, _ = parser.parse_known_args()
    return args


def train(config):
    # Parse the CLI arguments first so the device id is honored instead of
    # being hardcoded.
    args = parse_args()
    context.set_context(
        mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id
    )
    # Initialize collective communication when running distributed.
    if args.group_size > 1:
        init()

    config_psnr = config.PSNR_config
    model_psnr = RRDBNet(
        in_nc=config_psnr["ch_size"],
        out_nc=config_psnr["ch_size"],
        nf=config_psnr["G_nf"],
        nb=config_psnr["G_nb"],
    )
    dataset, dataset_len = get_dataset_DIV2K(
        base_dir="./data",
        downsample_factor=config_psnr["down_factor"],
        mode="train",
        aug=args.aug,
        repeat=1,
        num_readers=4,
        shard_id=args.rank,
        shard_num=args.group_size,
        batch_size=args.batch_size,
    )

    lr = nn.piecewise_constant_lr(
        milestone=config_psnr["lr_steps"], learning_rates=config_psnr["lr"]
    )
    opt = nn.Adam(
        params=model_psnr.trainable_params(),
        learning_rate=lr,
        beta1=0.9,
        beta2=0.99,
        loss_scale=args.loss_scale,
    )
    loss = nn.L1Loss()
    loss.add_flags_recursive(fp32=True)
    train_net = BuildTrainNetwork(model_psnr, loss)
    model = Model(train_net, optimizer=opt)

    # Callbacks: step timing, loss logging, and (rank 0 only) checkpointing.
    # Use the dataset length returned above so TimeMonitor reports per-epoch
    # timing instead of a hardcoded step count.
    time_cb = TimeMonitor(data_size=dataset_len)
    loss_cb = LossMonitor()
    cbs = [time_cb, loss_cb]
    if args.rank == 0:
        config_ck = CheckpointConfig(
            save_checkpoint_steps=args.save_steps,
            keep_checkpoint_max=args.keep_checkpoint_max,
        )
        ckpoint_cb = ModelCheckpoint(
            prefix="psnr", directory="./checkpoints", config=config_ck
        )
        cbs.append(ckpoint_cb)

    model.train(
        args.epoch_size, dataset, callbacks=cbs, dataset_sink_mode=False,
    )


if __name__ == "__main__":
    train(config)
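
# Example launch, single device; the script filename is assumed here, and all
# flags shown are the defaults defined in parse_args() above:
#   python train_psnr.py --device_id 0 --batch_size 4 --epoch_size 20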