from time import time
import os
import argparse
import ast

import numpy as np
from PIL import Image

import mindspore
import mindspore.common.dtype as mstype
from mindspore import nn
from mindspore import Tensor, context
from mindspore.common import set_seed
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.ops import operations as ops
from mindspore.parallel import set_algo_parameters  # used by the auto-parallel setup in train()
from mindspore.train.serialization import load_param_into_net, load_checkpoint
from mindspore.train.callback import (
    CheckpointConfig,
    ModelCheckpoint,
    _InternalCallbackParam,
    RunContext,
)
import moxing as mox

from src.model.RRDB_Net import RRDBNet
from src.model.discriminator_net import VGGStyleDiscriminator128
from src.model.cell import GeneratorLossCell, DiscriminatorLossCell, TrainOneStepCellDis, TrainOneStepCellGen
from src.config.config import ESRGAN_config
from src.dataset.dataset_DIV2K import get_dataset_DIV2K
# Save a batch of images as an IMAGE_ROW x IMAGE_COLUMN grid.
def save_image(img, img_path):
    mul = ops.Mul()
    add = ops.Add()
    if isinstance(img, Tensor):
        # Map the network output from [-1, 1] back to [0, 255] before casting to uint8.
        img = mul(img, 0.5 * 255)
        img = add(img, 0.5 * 255)
        img = img.asnumpy().astype(np.uint8).transpose((0, 2, 3, 1))
    elif not isinstance(img, np.ndarray):
        raise ValueError("img should be Tensor or numpy array, but got {}".format(type(img)))

    IMAGE_SIZE = 64    # size of each tile
    IMAGE_ROW = 8      # number of rows
    IMAGE_COLUMN = 8   # number of columns
    PADDING = 2        # interval between tiles

    # Create a blank canvas and paste the tiles row by row.
    to_image = Image.new('RGB', (IMAGE_COLUMN * IMAGE_SIZE + PADDING * (IMAGE_COLUMN + 1),
                                 IMAGE_ROW * IMAGE_SIZE + PADDING * (IMAGE_ROW + 1)))
    i = 0
    for y in range(1, IMAGE_ROW + 1):
        for x in range(1, IMAGE_COLUMN + 1):
            from_image = Image.fromarray(img[i])
            to_image.paste(from_image, ((x - 1) * IMAGE_SIZE + PADDING * x, (y - 1) * IMAGE_SIZE + PADDING * y))
            i = i + 1
    to_image.save(img_path)
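
# A minimal usage sketch (illustration only, not part of the original script):
# save_image expects an NCHW batch normalized to [-1, 1] and tiles the first
# IMAGE_ROW * IMAGE_COLUMN samples, so the batch should hold at least 64 images.
#
#     sample = Tensor(np.random.uniform(-1, 1, (64, 3, 64, 64)), mindspore.float32)
#     save_image(sample, "./images/sample_grid.png")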
def parse_args():
    parser = argparse.ArgumentParser("ESRGAN")
    parser.add_argument("--data_url", type=str, default=None, help="Dataset path")
    parser.add_argument("--train_url", type=str, default=None, help="Train output path")
    # argparse's type=bool treats any non-empty string as True, so boolean flags
    # are parsed with ast.literal_eval instead.
    parser.add_argument("--modelArts_mode", type=ast.literal_eval, default=True)
    parser.add_argument("--device_target", type=str, default="Ascend", help="Platform")
    parser.add_argument("--device_id", type=int, default=3, help="device_id")
    parser.add_argument("--aug", type=ast.literal_eval, default=True, help="Use augmentation for dataset")
    parser.add_argument("--loss_scale", type=float, default=1024.0, help="loss scale")
    parser.add_argument("--data_dir", type=str, default=None, help="Dataset path")
    parser.add_argument("--batch_size", type=int, default=16, help="batch_size")
    parser.add_argument("--epoch_size", type=int, default=20, help="epoch_size")
    parser.add_argument("--Giters", type=int, default=5, help="number of G iters per each D iter")
    parser.add_argument("--rank", type=int, default=1, help="local rank of distributed")
    parser.add_argument("--group_size", type=int, default=0, help="world size of distributed")
    parser.add_argument("--keep_checkpoint_max", type=int, default=30, help="max checkpoint for saving")
    parser.add_argument("--model_save_step", type=int, default=3000, help="step num for saving")
    parser.add_argument("--experiment", default="./images", help="Where to store samples and models")
    parser.add_argument("--run_distribute", type=ast.literal_eval, default=False,
                        help="Run distribute, default: false.")
    args, _ = parser.parse_known_args()
    return args
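
# Example invocation (illustrative only; bucket names and paths are placeholders,
# not taken from this repository):
#
#     python train.py --data_url obs://<bucket>/DIV2K --train_url obs://<bucket>/output \
#         --batch_size 16 --run_distribute False
#
# The script expects DIV2K.zip under --data_url and copies it into the local /cache
# directories before training.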
def train():
    args_opt = parse_args()
    config = ESRGAN_config

    # Environment variables provided by ModelArts / the Ascend launcher.
    device_num = int(os.getenv("RANK_SIZE"))
    device_id = int(os.getenv("DEVICE_ID"))
    rank_id = int(os.getenv("RANK_ID"))

    # Local cache paths on the training node and the OBS locations to pull from.
    local_data_url = "/cache/data"
    local_train_url = "/cache/lwESRGAN"
    local_zipfolder_url = "/cache/tarzip"
    local_pretrain_url = "/cache/pretrain"
    local_image_url = "/cache/ESRGANimage"
    obs_res_path = "obs://heu-535/pretrain"
    pretrain_filename = "psnr-X_XXXXX.ckpt"
    vgg_filename = ""
    filename = "DIV2K.zip"
    mox.file.make_dirs(local_train_url)
    mox.file.make_dirs(local_image_url)

    context.set_context(mode=context.GRAPH_MODE, save_graphs=False, device_target="Ascend")

    # Init multi-card training.
    if args_opt.modelArts_mode:
        device_num = int(os.getenv("RANK_SIZE"))
        device_id = int(os.getenv("DEVICE_ID"))
        rank_id = int(os.getenv("RANK_ID"))
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=parallel_mode, gradients_mean=True)
        set_algo_parameters(elementwise_op_strategy_follow=True)
        context.set_auto_parallel_context(all_reduce_fusion_config=[85, 160])
        init()

        # Copy the zipped dataset from OBS to the local cache and unpack it, one copy per device.
        local_data_url = os.path.join(local_data_url, str(device_id))
        mox.file.make_dirs(local_data_url)
        local_zip_path = os.path.join(local_zipfolder_url, str(device_id), filename)
        print("device:%d, local_zip_path: %s" % (device_id, local_zip_path))
        obs_zip_path = os.path.join(args_opt.data_url, filename)
        mox.file.copy(obs_zip_path, local_zip_path)
        print(
            "====================== device %d copy end =================================\n"
            % (device_id)
        )
        unzip_command = "unzip -o %s -d %s" % (local_zip_path, local_data_url)
        os.system(unzip_command)
        print(
            "======================= device %d unzip end =================================\n"
            % (device_id)
        )

    # Transfer the pretrained PSNR checkpoint from OBS; the VGG checkpoint used by the
    # perceptual loss is expected under the same local folder.
    local_pretrain_url = os.path.join(local_zipfolder_url, pretrain_filename)
    local_pretrain_url_vgg = os.path.join(local_zipfolder_url, vgg_filename)
    obs_pretrain_url = os.path.join(obs_res_path, pretrain_filename)
    mox.file.copy(obs_pretrain_url, local_pretrain_url)

    dataset, dataset_len = get_dataset_DIV2K(
        base_dir=local_data_url,
        downsample_factor=config["down_factor"],
        mode="train",
        aug=args_opt.aug,
        repeat=1,
        batch_size=args_opt.batch_size,
        shard_id=args_opt.group_size,
        shard_num=args_opt.rank,
        num_readers=4,
    )
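    # get_dataset_DIV2K is assumed to return (dataset, number of batches), with each
    # batch being a dict holding an "inputs" (LR) and a "target" (HR) column, which is
    # how the training loop below consumes it.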
    # Generator (RRDB) and VGG-style discriminator.
    generator = RRDBNet(
        in_nc=config["ch_size"],
        out_nc=config["ch_size"],
        nf=config["G_nf"],
        nb=config["G_nb"],
    )
    discriminator = VGGStyleDiscriminator128(num_in_ch=config["ch_size"], num_feat=config["D_nf"])

    # Initialize the generator from the pretrained PSNR-oriented checkpoint.
    param_dict = load_checkpoint(local_pretrain_url)
    load_param_into_net(generator, param_dict)

    # Define networks with loss.
    G_loss_cell = GeneratorLossCell(generator, discriminator, local_pretrain_url_vgg)
    D_loss_cell = DiscriminatorLossCell(discriminator)
    lr_G = nn.piecewise_constant_lr(
        milestone=config["lr_steps"], learning_rates=config["lr_G"]
    )
    lr_D = nn.piecewise_constant_lr(
        milestone=config["lr_steps"], learning_rates=config["lr_D"]
    )
    optimizerD = nn.Adam(discriminator.trainable_params(),
                         learning_rate=lr_D, beta1=0.5, beta2=0.999, loss_scale=args_opt.loss_scale)
    optimizerG = nn.Adam(generator.trainable_params(),
                         learning_rate=lr_G, beta1=0.5, beta2=0.999, loss_scale=args_opt.loss_scale)

    # Define one-step training cells.
    G_trainOneStep = TrainOneStepCellGen(G_loss_cell, optimizerG)
    D_trainOneStep = TrainOneStepCellDis(D_loss_cell, optimizerD)

    # Train.
    G_trainOneStep.set_train()
    D_trainOneStep.set_train()
    print("Start Training")

    # Checkpoint callbacks for generator and discriminator.
    ckpt_config = CheckpointConfig(
        save_checkpoint_steps=args_opt.model_save_step, keep_checkpoint_max=args_opt.keep_checkpoint_max)
    ckpt_cb_g = ModelCheckpoint(
        config=ckpt_config, directory=local_train_url, prefix='Generator')
    ckpt_cb_d = ModelCheckpoint(
        config=ckpt_config, directory=local_train_url, prefix='Discriminator')
    cb_params_g = _InternalCallbackParam()
    cb_params_g.train_network = generator
    cb_params_g.cur_step_num = 0
    cb_params_g.batch_num = args_opt.batch_size
    cb_params_g.cur_epoch_num = 0
    cb_params_d = _InternalCallbackParam()
    cb_params_d.train_network = discriminator
    cb_params_d.cur_step_num = 0
    cb_params_d.batch_num = args_opt.batch_size
    cb_params_d.cur_epoch_num = 0
    run_context_g = RunContext(cb_params_g)
    run_context_d = RunContext(cb_params_d)
    ckpt_cb_g.begin(run_context_g)
    ckpt_cb_d.begin(run_context_d)
    start = time()
    minibatch = args_opt.batch_size
    ones = ops.Ones()
    zeros = ops.Zeros()
    # Real labels are 1; fake labels are 0 plus small random noise.
    real_labels = ones((minibatch, 1), mindspore.float32)
    fake_labels = zeros((minibatch, 1), mindspore.float32) + \
        Tensor(np.random.random(size=(minibatch, 1)), dtype=mindspore.float32) * 0.05

    # Iterate over the dataset as dicts of Tensors.
    dataset_iter = dataset.create_dict_iterator()
    num_iters = config["niter"]
    for iterator in range(num_iters):
        data = next(dataset_iter)
        inputs = data["inputs"]
        real_hr = data["target"]

        # Generator step: the one-step cell returns the generated SR batch and the loss.
        generator_loss_all = G_trainOneStep(inputs, real_hr, fake_labels, real_labels)
        fake_hr = generator_loss_all[0]
        generator_loss = generator_loss_all[1]

        # Discriminator step every Giters generator steps.
        if (iterator + 1) % args_opt.Giters == 0:
            discriminator_loss = D_trainOneStep(fake_hr, real_hr)

        # Drive the checkpoint callbacks so checkpoints are actually written.
        cb_params_g.cur_step_num = iterator + 1
        cb_params_d.cur_step_num = iterator + 1
        ckpt_cb_g.step_end(run_context_g)
        ckpt_cb_d.step_end(run_context_d)

        if (iterator + 1) % 5000 == 0:
            print('%d:[%d/%d]Loss_D: %10f Loss_G: %10f'
                  % ((iterator + 1) // dataset_len, iterator, num_iters,
                     np.sum(discriminator_loss.asnumpy()), generator_loss.asnumpy()))
            save_image(real_hr, os.path.join(local_image_url, 'real_samples_{0}.png'.format(iterator + 1)))
            save_image(fake_hr, os.path.join(local_image_url, 'fake_samples_{0}.png'.format(iterator + 1)))

    # Copy checkpoints and sample images back to OBS from device 0 only.
    if device_id == 0:
        mox.file.copy_parallel(local_train_url, args_opt.train_url)
        mox.file.copy_parallel(local_image_url, args_opt.train_url)


if __name__ == "__main__":
    train()