from time import time
import os
import argparse
import ast
import numpy as np
import cv2
import mindspore
import mindspore.common.dtype as mstype
from mindspore import nn
from mindspore.train.serialization import load_param_into_net, load_checkpoint
from mindspore.ops import operations as ops
from mindspore import Tensor, context
from mindspore.common import set_seed
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import (
    CheckpointConfig,
    ModelCheckpoint,
    _InternalCallbackParam,
    RunContext,
)
from mindspore.ops import composite as C
from src.model.RRDB_Net import RRDBNet
from src.model.discriminator_net import VGGStyleDiscriminator128
from src.model.cell import GeneratorLossCell, DiscriminatorLossCell, TrainOneStepCellDis, TrainOneStepCellGen
from src.config.config import ESRGAN_config
from src.dataset.dataset_DIV2K import get_dataset_DIV2K
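
# Command-line options: device/platform selection, dataset and loss-scale settings,
# how often the discriminator is updated relative to the generator, and
# checkpoint/snapshot frequencies.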
def parse_args():
    parser = argparse.ArgumentParser("ESRGAN")
    parser.add_argument('--device_target', type=str,
                        default="Ascend", help='Platform')
    parser.add_argument('--device_id', type=int,
                        default=7, help='device_id')
    parser.add_argument(
        "--aug", type=bool, default=True, help="Use augmentation for dataset"
    )
    parser.add_argument("--loss_scale", type=float,
                        default=1024.0, help="loss scale")
    parser.add_argument('--data_dir', type=str,
                        default=None, help='Dataset path')
    parser.add_argument("--batch_size", type=int, default=16, help="batch_size")
    parser.add_argument("--epoch_size", type=int,
                        default=20, help="epoch_size")
    parser.add_argument('--Giters', type=int, default=2,
                        help='number of G iters per each D iter')
    parser.add_argument("--rank", type=int, default=1,
                        help="local rank of distributed")
    parser.add_argument(
        "--group_size", type=int, default=0, help="world size of distributed"
    )
    parser.add_argument(
        "--keep_checkpoint_max", type=int, default=40, help="max checkpoint for saving"
    )
    parser.add_argument(
        "--model_save_step", type=int, default=5000, help="step num for saving"
    )
    parser.add_argument('--Gpretrained_path', type=str, default="psnr.ckpt")
    parser.add_argument('--experiment', default="./images",
                        help='Where to store samples and models')
    parser.add_argument("--run_distribute", type=ast.literal_eval,
                        default=False, help="Run distribute, default: false.")
    args, _ = parser.parse_known_args()
    return args
# Save a CHW float tensor in [0, 1] as an 8-bit PNG image.
def save_img(img, img_name, save_dir):
    img_np = C.clip_by_value(img.squeeze(), 0, 1).asnumpy().transpose(1, 2, 0)
    img_np = (img_np * 255.0).round().astype(np.uint8)
    save_fn = save_dir + '/' + img_name
    # Network outputs are RGB; OpenCV writes BGR, so swap channels before saving.
    cv2.imwrite(save_fn, cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_PNG_COMPRESSION, 0])
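
# GAN-phase training for ESRGAN: the generator is initialised from a PSNR-oriented
# pretrained checkpoint (--Gpretrained_path), then generator and discriminator are
# updated alternately on DIV2K LR/HR pairs.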
def train():
    args_opt = parse_args()
    config = ESRGAN_config
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
                        device_id=args_opt.device_id, save_graphs=False)
    # Device Environment
    if args_opt.run_distribute:
        if args_opt.device_target == "Ascend":
            rank = args_opt.device_id
            # On Ascend the device count comes from the RANK_SIZE variable set by the launch script.
            device_num = int(os.getenv("RANK_SIZE", "1"))
            context.set_auto_parallel_context(device_num=device_num,
                                              parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True)
            init()
        else:
            init("nccl")
            context.reset_auto_parallel_context()
            rank = get_rank()
            device_num = get_group_size()
            context.set_auto_parallel_context(device_num=device_num,
                                              parallel_mode=ParallelMode.DATA_PARALLEL,
                                              gradients_mean=True)
    else:
        rank = 0
        device_num = 1
    dataset, dataset_len = get_dataset_DIV2K(
        base_dir=args_opt.data_dir if args_opt.data_dir else "./data/lw",
        downsample_factor=config["down_factor"], mode="train", aug=args_opt.aug,
        repeat=10, batch_size=args_opt.batch_size,
        shard_id=rank, shard_num=device_num, num_readers=4)
    dataset_iter = dataset.create_dict_iterator()
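
    # The generator is the RRDB super-resolution backbone; the discriminator is the
    # VGG-style network for 128x128 patches used in the original ESRGAN.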
    generator = RRDBNet(
        in_nc=config["ch_size"],
        out_nc=config["ch_size"],
        nf=config["G_nf"],
        nb=config["G_nb"],
    )
    discriminator = VGGStyleDiscriminator128(
        num_in_ch=config["ch_size"], num_feat=config["D_nf"])
    param_dict = load_checkpoint(args_opt.Gpretrained_path)
    load_param_into_net(generator, param_dict)
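
    # The loss cells are defined in src/model/cell.py (not shown here); in the ESRGAN
    # formulation the generator loss combines pixel, perceptual (VGG feature) and
    # adversarial terms, which is why a pretrained VGG checkpoint path is passed in.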
    # Define network with loss
    G_loss_cell = GeneratorLossCell(generator, discriminator, config["vgg_pretrain_path"])
    D_loss_cell = DiscriminatorLossCell(discriminator)
    lr_G = nn.piecewise_constant_lr(
        milestone=config["lr_steps"], learning_rates=config["lr_G"]
    )
    lr_D = nn.piecewise_constant_lr(
        milestone=config["lr_steps"], learning_rates=config["lr_D"]
    )
    optimizerD = nn.Adam(discriminator.trainable_params(), learning_rate=lr_D,
                         beta1=0.5, beta2=0.999, loss_scale=args_opt.loss_scale)
    optimizerG = nn.Adam(generator.trainable_params(), learning_rate=lr_G,
                         beta1=0.5, beta2=0.999, loss_scale=args_opt.loss_scale)
    # Define one-step train cells
    G_trainOneStep = TrainOneStepCellGen(G_loss_cell, optimizerG)
    D_trainOneStep = TrainOneStepCellDis(D_loss_cell, optimizerD)
    # Train
    G_trainOneStep.set_train()
    D_trainOneStep.set_train()
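
    # Because this is a hand-written training loop rather than Model.train(), the
    # ModelCheckpoint callbacks are driven manually through _InternalCallbackParam
    # and RunContext objects.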
    print('Start Training')
    ckpt_config = CheckpointConfig(
        save_checkpoint_steps=args_opt.model_save_step,
        keep_checkpoint_max=args_opt.keep_checkpoint_max)
    ckpt_cb_g = ModelCheckpoint(
        config=ckpt_config, directory="./checkpoints", prefix='Generator')
    ckpt_cb_d = ModelCheckpoint(
        config=ckpt_config, directory="./checkpoints", prefix='Discriminator')
    cb_params_g = _InternalCallbackParam()
    cb_params_g.train_network = generator
    cb_params_g.cur_step_num = 0
    cb_params_g.batch_num = args_opt.batch_size
    cb_params_g.cur_epoch_num = 0
    cb_params_d = _InternalCallbackParam()
    cb_params_d.train_network = discriminator
    cb_params_d.cur_step_num = 0
    cb_params_d.batch_num = args_opt.batch_size
    cb_params_d.cur_epoch_num = 0
    run_context_g = RunContext(cb_params_g)
    run_context_d = RunContext(cb_params_d)
    ckpt_cb_g.begin(run_context_g)
    ckpt_cb_d.begin(run_context_d)
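
    # GAN targets: real labels are 1, fake labels are 0 plus a small random offset
    # (label noise), a common trick to stabilise discriminator training.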
    start = time()
    minibatch = args_opt.batch_size
    ones = ops.Ones()
    zeros = ops.Zeros()
    real_labels = ones((minibatch, 1), mindspore.float32)
    fake_labels = zeros((minibatch, 1), mindspore.float32) + \
        Tensor(np.random.random(size=(minibatch, 1)), dtype=mindspore.float32) * 0.05
    num_iters = config["niter"]
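
    # Main loop: the generator is updated every iteration, the discriminator every
    # `Giters` iterations; every 500 iterations the losses are logged and sample
    # real/fake HR images are written to the experiment directory.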
    for iterator in range(num_iters):
        data = next(dataset_iter)
        inputs = data["inputs"]
        real_hr = data["target"]
        generator_loss_all = G_trainOneStep(inputs, real_hr, fake_labels, real_labels)
        fake_hr = generator_loss_all[0]
        generator_loss = generator_loss_all[1]
        if (iterator + 1) % args_opt.Giters == 0:
            discriminator_loss = D_trainOneStep(fake_hr, real_hr)
        if (iterator + 1) % 500 == 0:
            print('%d:[%d/%d] Loss_D: %10f Loss_G: %10f'
                  % ((iterator + 1) // dataset_len, iterator, num_iters,
                     np.sum(discriminator_loss.asnumpy()), generator_loss.asnumpy()))
            save_img(real_hr[0], 'real_samples_{0}.png'.format(iterator + 1), args_opt.experiment)
            save_img(fake_hr[0], 'fake_samples_{0}.png'.format(iterator + 1), args_opt.experiment)
        # Drive the checkpoint callbacks so Generator/Discriminator snapshots are
        # actually written every `model_save_step` steps.
        cb_params_g.cur_step_num = iterator + 1
        cb_params_d.cur_step_num = iterator + 1
        if (iterator + 1) % args_opt.model_save_step == 0:
            ckpt_cb_g.step_end(run_context_g)
            ckpt_cb_d.step_end(run_context_d)


if __name__ == "__main__":
    train()