您最多选择 25 个主题。主题必须以字母或数字开头，可以包含连字符 (-)，并且长度不得超过 35 个字符。

132 行
4.7 KiB

  1. """Evaluation"""
  2. import os
  3. import time
  4. import argparse
  5. import datetime
  6. import glob
  7. import numpy as np
  8. import cv2
  9. from collections import OrderedDict
  10. import mindspore.nn as nn
  11. from mindspore.nn import PSNR,SSIM
  12. from mindspore import Tensor, context
  13. from mindspore.train.model import Model
  14. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  15. from mindspore.ops import operations as P
  16. from mindspore.ops import functional as F
  17. from mindspore.common import dtype as mstype
  18. from src.config.config import ESRGAN_config,PSNR_config
  19. from src.model.RRDB_Net import RRDBNet
  20. class BuildEvalNetwork(nn.Cell):
  21. def __init__(self, network):
  22. super(BuildEvalNetwork, self).__init__()
  23. self.network = network
  24. def construct(self, input_data):
  25. output = self.network(input_data)
  26. return output
  27. def parse_args(cloud_args=None):
  28. """parse_args"""
  29. parser = argparse.ArgumentParser('Eval ESRGAN')
  30. parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
  31. help='device where the code will be implemented. (Default: Ascend)')
  32. # dataset related
  33. parser.add_argument('--data_path', type=str,
  34. default='', help='eval data dir')
  35. parser.add_argument('--ganckpt_path', type=str,
  36. default='', help='gan ckpt file')
  37. parser.add_argument('--psnrckpt_path', type=str,
  38. default='', help='psnr ckpt file')
  39. parser.add_argument('--batch_size', default=16,
  40. type=int, help='batch size for per npu')
  41. # logging related
  42. parser.add_argument('--log_path', type=str,
  43. default='outputs/', help='path to save log')
  44. parser.add_argument('--rank', type=int, default=0,
  45. help='local rank of distributed')
  46. parser.add_argument('--group_size', type=int, default=1,
  47. help='world size of distributed')
  48. parser.add_argument('--alpha', type=float, default=0.4,
  49. help='weight factor of psnr model in eval')
  50. args_opt = parser.parse_args()
  51. return args_opt
  52. set_seed(1)
  53. def test():
  54. args_opt = parse_args()
  55. config_psnr = PSNR_config
  56. print(f"test args: {args_opt}\ncfg: {config}")
  57. context.set_context(
  58. mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=1
  59. )
  60. model_psnr = RRDBNet(
  61. in_nc=config_psnr["ch_size"],
  62. out_nc=config_psnr["ch_size"],
  63. nf=config_psnr["G_nf"],
  64. nb=config_psnr["G_nb"],
  65. )
  66. # 需要对每个参数进行单独计算
  67. dataset,dataset_len = get_dataset_DIV2K(
  68. base_dir="./data",
  69. downsample_factor=config_psnr["down_factor"],
  70. mode="valid",
  71. aug=False,
  72. repeat=1,
  73. num_readers=4,
  74. shard_id=args.rank,
  75. shard_num=args.group_size,
  76. batch_size=1,
  77. )
  78. eval_net = BuildEvalNetwork(model_psnr)
  79. # load model and Interpolating
  80. param_dict_gan = load_checkpoint(args_opt.ganckpt_path)
  81. param_dict_psnr = load_checkpoint(args_opt.psnrckpt_path)
  82. param_dict = OrderedDict()
  83. alpha = args_opt.alpha
  84. print('Interpolating with alpha = ', alpha)
  85. for name,cell_PSNR in net_PSNR.cells_and_names():
  86. cell_ESRGAN = param_dict_gan[name]
  87. net_interp[name] = (1 - alpha) * cell_PSNR + alpha * cell_ESRGAN
  88. load_param_into_net(eval_net, param_dict)
  89. eval_net.set_train(False)
  90. ssim = nn.SSIM()
  91. psnr = nn.PSNR()
  92. test_data_iter = dataset.create_dict_iter(out_numpy=False)
  93. psnr_bic_all = 0.0
  94. psnr_real_all = 0.0
  95. ssim_bic_all = 0.0
  96. ssim_real_all = 0.0
  97. for i, sample in enumerate(test_data):
  98. lr = sample['inputs']
  99. real_hr = sample['target']
  100. gen_hr = eval_net(lr)
  101. # 这里用mindspore的双三次插值采样结果
  102. bic_hr = None
  103. psnr_bic = psnr(gen_hr,bic_hr)
  104. psnr_real = psnr(gen_hr,real_hr)
  105. ssim_bic = ssim(gen_hr,bic_hr)
  106. ssim_real = ssim(gen_hr,real_hr)
  107. psnr_bic_all += psnr_bic
  108. psnr_real_all = psnr_real
  109. ssim_bic_all = ssim_bic
  110. ssim_real_all = ssim_real
  111. print(psnr_bic,psnr_real,ssim_bic,ssim_real)
  112. result_img_path = os.path.join(args_opt.results_path + "DIV2K", 'Bic_SR_HR_' + str(i))
  113. if i % 50 == 0:
  114. results_img = np.concatenate((bic_img[0].asnumpy(), sr_img[0].asnumpy(), hr_img[0].asnumpy()), 1)
  115. cv2.imwrite(result_img_path, results_img)
  116. psnr_bic_all += psnr_bic_all/dataset_len
  117. psnr_real_all = psnr_real_all/dataset_len
  118. ssim_bic_all = ssim_bic_all/dataset_len
  119. ssim_real_all = ssim_real_all/dataset_len
  120. print(psnr_bic_all,psnr_real_all,ssim_bic_all,ssim_real_all)