|
- from os import path as osp
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
- import numpy as np
- from mindspore import Tensor, context
- from mindspore.common import dtype as mstype
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
- import sys
- sys.path.append("/disk2/yuyanze/BasicSR/BasicSR")
- from basicsr.archs import build_network
- from basicsr.data import build_dataloader, build_dataset
- from basicsr.data.data_sampler import EnlargedSampler
- from basicsr.metrics.metric import calc_psnr, quantize, calc_ssim,hwc2chw
- from basicsr.utils_edvr.options import parse_options
-
def create_train_val_dataloader(opt, phase, dataset_enlarge_ratio=200):
    """Build the dataloader for a single phase.

    Args:
        opt (dict): Parsed options. For ``'train'`` this reads
            ``opt['datasets']['train']``, ``world_size`` and ``rank``; for
            ``'val'``/``'test'`` it reads ``opt['datasets']['val']``. Both
            paths also read ``num_gpu``, ``dist`` and ``manual_seed``.
        phase (str): One of ``'train'``, ``'val'`` or ``'test'``.
        dataset_enlarge_ratio (int): Enlarge ratio passed to
            ``EnlargedSampler`` for the training phase. Defaults to 200, the
            value previously hard-coded here.

    Returns:
        Whatever ``build_dataloader`` returns for the requested phase.

    Raises:
        ValueError: If ``phase`` is not recognized.
    """
    if phase == 'train':
        dataset_opt = opt['datasets']['train']
        train_set = build_dataset(dataset_opt)
        train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)
        return build_dataloader(
            train_set,
            phase,
            num_gpu=opt['num_gpu'],
            dist=opt['dist'],
            sampler=train_sampler,
            seed=opt['manual_seed'])

    if phase in ('val', 'test'):
        # Both evaluation phases are built from the 'val' dataset options.
        val_set = build_dataset(opt['datasets']['val'])
        return build_dataloader(val_set, phase, num_gpu=opt['num_gpu'], dist=opt['dist'], seed=opt['manual_seed'])

    raise ValueError(f'Dataset phase {phase} is not recognized.')
-
def init_net(opt):
    """Build the generator network described by ``opt['network_g']``.

    The network is constructed with ``build_network`` and printed to stdout.
    No weights are loaded here: a pretrained path in ``opt['path']`` is NOT
    consulted, so the caller must load a checkpoint into the returned network
    itself.

    Args:
        opt (dict): Parsed options; only ``opt['network_g']`` is read.

    Returns:
        The network instance returned by ``build_network``.
    """
    net_g = build_network(opt['network_g'])
    # NOTE(review): 'pretrain_network_g' / 'param_key_g' used to be read from
    # opt['path'] here but were never used; checkpoint loading is done by the
    # caller with MindSpore's load_checkpoint/load_param_into_net.
    print(net_g)
    return net_g
-
def test_pipeline(root_path, ckpt_path="EDVR_L_x4_SR_Vimeo90K_official-162b54e4.ckpt"):
    """Evaluate a MindSpore checkpoint on the validation set.

    Builds the 'val' dataloader and the network, loads ``ckpt_path``
    non-strictly, then prints PSNR/SSIM for every batch and their averages.

    Args:
        root_path (str): Project root, forwarded to ``parse_options`` and
            stored as ``opt['root_path']``.
        ckpt_path (str): Checkpoint file to load. Defaults to the checkpoint
            name previously hard-coded here.

    Returns:
        tuple | None: ``(mean_psnr, mean_ssim)`` over all evaluated batches,
        or ``None`` if the dataloader yielded no batches.
    """
    # NOTE(review): is_train=True is kept from the original; confirm it is
    # intended for an evaluation run.
    opt, _ = parse_options(root_path, is_train=True)
    opt['root_path'] = root_path

    val_dataset = create_train_val_dataloader(opt, 'val')

    net_m = init_net(opt)
    param_dict = load_checkpoint(ckpt_path)
    param_not_load = load_param_into_net(net_m, param_dict, strict_load=False)
    print("The param is not loaded", param_not_load)
    net_m.set_train(False)
    print("load net weights successfully")

    psnrs = []
    ssims = []
    for batch_idx, (lr, hr) in enumerate(val_dataset):
        print("*****************current index*****************:", batch_idx)

        lr = Tensor(lr, mstype.float32)
        pred_np = net_m(lr).asnumpy()
        hr_np = hr.asnumpy()
        pred_np = quantize(pred_np, 255)
        psnr = calc_psnr(pred_np, hr_np, 4, 255.0)

        # Keep only the last three dims and move channels last for SSIM.
        # NOTE(review): collapsing to shape[-3:] assumes batch size 1 — confirm.
        pred_hwc = pred_np.reshape(pred_np.shape[-3:]).transpose(1, 2, 0)
        hr_hwc = hr_np.reshape(hr_np.shape[-3:]).transpose(1, 2, 0)
        ssim = calc_ssim(pred_hwc, hr_hwc, 4)

        print("current psnr: ", psnr)
        print("current ssim: ", ssim)
        psnrs.append(psnr)
        ssims.append(ssim)

    if not psnrs:
        # Avoid dividing by zero when the loader is empty.
        print("no validation batches were evaluated")
        return None

    mean_psnr = sum(psnrs) / len(psnrs)
    mean_ssim = sum(ssims) / len(ssims)
    print("average psnr: ", mean_psnr)
    print("average ssim: ", mean_ssim)
    return mean_psnr, mean_ssim
-
-
if __name__ == '__main__':
    import os

    # Select the Ascend device via the DEVICE_ID environment variable,
    # falling back to the previously hard-coded device 5.
    device_id = int(os.environ.get("DEVICE_ID", 5))
    context.set_context(device_target="Ascend", device_id=device_id)
    # Parent of the directory that contains this file.
    root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
    test_pipeline(root_path)
-
|