|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- """RBPN evaluation script: load a checkpoint and report PSNR on the test set."""
- import argparse
- import ast
- import os
- import time
- from src.model.rbpn import Net as RBPN
- from mindspore import load_checkpoint, load_param_into_net, context
- from src.datasets.dataset import RBPNDataset, create_train_dataset , RBPNDatasetTest , create_val_dataset
- from src.util.utils import save_img, save_losses, save_psnr, compute_psnr , init_weights ,PSNR
- import numpy as np
-
- import moxing as mox
-
-
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is True for
    every non-empty string -- so ``--residual False`` used to enable the flag.
    This converter parses the text instead.

    Args:
        value: the raw command-line string (a bool is passed through as-is).

    Returns:
        bool: the parsed value. The empty string maps to False, which keeps
        the only spelling that produced False under ``type=bool`` working.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


# Selects the platform work root below; set to 'debug' for a debugging task.
environment = 'train'
if environment == 'debug':
    workroot = '/home/ma-user/work'  # used by debugging tasks
else:
    workroot = '/home/work/user-job-dir'  # used by training tasks
print('current work mode:' + environment + ', workroot:' + workroot)

parser = argparse.ArgumentParser(description="RBPN eval")
# The help text used to claim "default: 0" while the default is 1; let
# argparse substitute the real default so the two cannot drift apart again.
parser.add_argument("--device_id", type=int, default=1,
                    help="device id, default: %(default)s.")
parser.add_argument('--upscale_factor', type=int, default=4, choices=[2, 4, 8],
                    help="Super resolution upscale factor")
parser.add_argument('--testBatchSize', type=int, default=1, help='testing batch size')
parser.add_argument('--model_type', type=str, default='RBPN')
parser.add_argument('--save_eval_path', type=str, default="./Results/eval",
                    help='save eval image path')
parser.add_argument('--file_list', type=str, default='foliage.txt')
parser.add_argument('--ckpt_name', type=str, default='69_RBPN.ckpt')
# Boolean options go through _str2bool so that "False"/"0"/"no" really
# disable them.
parser.add_argument('--other_dataset', type=_str2bool, default=True,
                    help="use other dataset than vimeo-90k")
parser.add_argument('--future_frame', type=_str2bool, default=True,
                    help="use future frame")
parser.add_argument('--nFrames', type=int, default=7)
parser.add_argument('--residual', type=_str2bool, default=False)

parser.add_argument('--data_url',
                    help='path to training/inference dataset folder',
                    default=workroot + '/data/')

parser.add_argument('--train_url',
                    help='model folder to save/load',
                    default=workroot + '/model/')
# The help text used to say "default: CPU" although the default is Ascend.
parser.add_argument(
    '--device_target',
    type=str,
    default="Ascend",
    choices=['Ascend', 'CPU'],
    help='device where the code will be implemented (default: %(default)s),若要在启智平台上使用NPU,需要在启智平台训练界面上加上运行参数device_target=Ascend')

# Parsed at import time on purpose: the functions below read the
# module-level ``args``.
args = parser.parse_args()
print(args)
-
-
-
def predict(ds, model):
    """Run the model over an evaluation dataset and print PSNR statistics.

    Args:
        ds(Dataset): eval dataset; its ``create_dict_iterator()`` must yield
            dicts with the keys ``input_image``, ``target_image``,
            ``neigbor_image`` and ``flow_image``.
        model(Cell): the generate model, called as
            ``model(input, neigbor, flow)``.

    Returns:
        The mean PSNR over all batches, or ``None`` when the dataset yields
        no batch (the original code raised ZeroDivisionError in that case).
    """
    total_psnr = 0
    batch_count = 0
    for batch_count, batch in enumerate(ds.create_dict_iterator(), 1):
        input_image = batch['input_image']
        neigbor = batch['neigbor_image']
        flow = batch['flow_image']

        # Only the first element of the model output is scored.
        prediction = model(input_image, neigbor, flow)
        prediction = prediction[0].asnumpy().astype(np.float32)
        # NOTE(review): presumably rescales [0, 1] values to the 0-255 range
        # before PSNR is computed -- confirm against the dataset pipeline.
        prediction = prediction * 255.

        target = batch['target_image'].squeeze().asnumpy().astype(np.float32)
        target = target * 255.

        psnr_predicted = PSNR(prediction, target, shave_border=args.upscale_factor)
        total_psnr += psnr_predicted
        print("第{}次的psnr为{}:".format(batch_count, psnr_predicted))

    if batch_count == 0:
        # Nothing was evaluated; avoid dividing by zero below.
        print("predict: dataset yielded no batches, PSNR was not computed")
        return None

    mean_psnr = total_psnr / batch_count
    print("平均的psnr值为: ", mean_psnr)
    return mean_psnr
-
-
-
def main():
    """Evaluate an RBPN checkpoint and print per-batch and mean PSNR.

    Steps: configure the MindSpore context, make sure the local data and
    checkpoint directories exist, copy the dataset from ``args.data_url``
    when running as a training job, build the test dataset, load the
    checkpoint named by ``args.ckpt_name`` and call ``predict``.
    """
    context.set_context(mode=context.GRAPH_MODE, device_id=args.device_id,
                        device_target=args.device_target)
    home = os.path.dirname(os.path.realpath(__file__))
    data_dir = os.path.join(home, 'data')  # where the dataset is stored
    train_dir = os.path.join(home, 'checkpoints')  # where models are stored
    os.makedirs(data_dir, exist_ok=True)
    # Initialise the model directory.
    os.makedirs(train_dir, exist_ok=True)

    if environment == 'train':
        obs_data_url = args.data_url
        # Copy the data into the training environment.
        try:
            mox.file.copy_parallel(obs_data_url, data_dir)
        except Exception as e:
            # Deliberately best-effort, as in the original: the failure is
            # reported and evaluation continues with whatever is on disk.
            print('moxing download {} to {} failed: '.format(
                obs_data_url, data_dir) + str(e))
        else:
            print("Successfully Download {} to {}".format(obs_data_url,
                                                          data_dir))

    # Debug aid: show what is present next to the script.
    print("********************list:", os.listdir(home))

    # Dataset selection.
    zip_out_dir = os.path.join(data_dir, 'Vid4')
    # Honour --file_list (it was parsed but ignored before); its default,
    # 'foliage.txt', resolves to the same path the script used to hard-code.
    file_l = os.path.join(home, args.file_list)
    val_dataset = RBPNDatasetTest(zip_out_dir, args.nFrames, args.upscale_factor,
                                  file_l, args.other_dataset, args.future_frame)
    val_ds = create_val_dataset(val_dataset, args)
    print("=======> load model ckpt")

    ckpt = os.path.join(home, args.ckpt_name)
    params = load_checkpoint(ckpt)
    print('===> Building model ', args.model_type)

    model = RBPN(num_channels=3, base_filter=256, feat=64, num_stages=3,
                 n_resblock=5, nFrames=args.nFrames,
                 scale_factor=args.upscale_factor)
    # Weights are initialised first and then overwritten by the checkpoint
    # values loaded below.
    init_weights(model, 'normal', 0.02)
    load_param_into_net(model, params)
    predict(val_ds, model)


if __name__ == "__main__":
    main()
|