You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

36 lines
1.3 KiB

import argparse
from collections import OrderedDict

import numpy as np
from mindspore import Parameter, Tensor, context
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net

from src.config.config import ESRGAN_config, PSNR_config
from src.utils import get_network, resume_model
  6. if __name__ == '__main__':
  7. config_psnr = PSNR_config
  8. context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
  9. model_psnr = RRDBNet(
  10. in_nc=config_psnr["ch_size"],
  11. out_nc=config_psnr["ch_size"],
  12. nf=config_psnr["G_nf"],
  13. nb=config_psnr["G_nb"],
  14. )
  15. model_psnr.set_train(True)
  16. param_dict_gan = load_checkpoint(args_opt.ganckpt_path)
  17. param_dict_psnr = load_checkpoint(args_opt.psnrckpt_path)
  18. param_dict = OrderedDict()
  19. alpha = args_opt.alpha
  20. print('Interpolating with alpha = ', alpha)
  21. for name,cell_PSNR in net_PSNR.cells_and_names():
  22. cell_ESRGAN = param_dict_gan[name]
  23. net_interp[name] = (1 - alpha) * cell_PSNR + alpha * cell_ESRGAN
  24. load_param_into_net(model_psnr, param_dict)
  25. input_array = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 3, 32, 32)).astype(np.float32))
  26. input_label = Tensor(np.random.uniform(-1.0, 1.0, size=(1, 3, 128,128)).astype(np.float32))
  27. G_file = f"ESRGAN_Generator"
  28. export(G, input_array, file_name=G_file + '-300_11.air', file_format='AIR')