|
- # from subnet import endecoder
- # out = endecoder.build_model()
-
import os
import time  # time.time() is used in testuvg() but was never imported explicitly
import argparse
import logging
import numpy as np
from net import *
import json
from dataset import DataSet, UVGDataSet
from drawuvg import uvgdrawplt
import mindspore as ms
import mindspore as mindspore
import mindspore.nn as nn
import mindspore.dataset as ds
from mindspore import context
import datetime
import mindspore.ops as ops
import pandas as pd
# NOTE: `from net import *` makes regrouping these imports unsafe (it may
# shadow or be shadowed by later bindings), so the original order is kept.
# torch.backends.cudnn.enabled = True
-
# Command-line interface for the DVC (Deep Video Compression) evaluation script.
parser = argparse.ArgumentParser(description='DVC reimplement')
parser.add_argument(
    "-d", "--dataset", type=str, default="/userhome/DVC/PyTorch/data/UVG/images/", help="test dataset"
)
parser.add_argument(
    "--ckpt_path", type=str, default='/userhome/DVC/MindSpore/snapshot/best_256.ckpt', help="ckpt path"
)
parser.add_argument('-C', '--csv', default='DVC_mindspore.csv', type=str,
                    help='csv results')
# BUG FIX: with nargs='+' argparse produces a list, so the defaults must be
# lists as well.  The original defaults (int 256 / str 'ShakeNDry') break the
# consumers: `for lmbda in args.lmbda_list` raises TypeError on an int, and
# `for folder in args.folders` would iterate a string character by character.
parser.add_argument('-L', '--lmbda_list', default=[256], type=int, nargs='+',
                    help='lambda')
parser.add_argument('-F', '--folders', default=['ShakeNDry'], type=str, nargs='+',
                    help='folders')
-
def load_model(model, model_path):
    """Load a checkpoint into `model` and return auxiliary training metadata.

    Args:
        model: the network to populate (modified in place).
        model_path: path to a MindSpore `.ckpt` file.

    Returns:
        dict with any of 'loss', 'lr', 'epoch' that the checkpoint carried,
        each converted to a plain Python scalar; empty if none are present.
    """
    param_dict = mindspore.load_checkpoint(model_path)
    param_not_load = mindspore.load_param_into_net(model, param_dict)
    print('Load checkpoint from ' + model_path)
    # Parameters of the net that got no value from the checkpoint;
    # this should normally be empty.
    print("param_not_load: ", param_not_load)
    append_info = {}
    # Training checkpoints may embed scalar bookkeeping values alongside the
    # weights; the original unrolled this loop into three identical blocks.
    for key in ('loss', 'lr', 'epoch'):
        if key in param_dict:
            append_info[key] = param_dict[key].value().asnumpy().item()
    return append_info
-
def testuvg(global_step, test_dataset, net, testfull=False):
    """Evaluate `net` on a UVG dataset, accumulating bpp / PSNR / MS-SSIM.

    Args:
        global_step: current training step (kept for interface compatibility;
            unused in this body).
        test_dataset: generator-style dataset yielding
            (input_images, ref_image, refbpp, refpsnr, refmsssim) per sequence.
        net: video compression network; called as
            net(frame, reference, is_train=False).
        testfull: kept for interface compatibility (unused in this body).

    Returns:
        (cnt, sumbpp, sumpsnr, summsssim, run_time) — raw accumulated sums
        plus the frame count; callers divide by cnt to get averages.
    """
    test_loader = ds.GeneratorDataset(
        test_dataset,
        column_names=["input_images", "ref_image", "refbpp", "refpsnr", "refmsssim"],
        num_parallel_workers=2, shuffle=False, python_multiprocessing=False)
    test_loader = test_loader.batch(1, drop_remainder=False)
    net.set_train(False)
    # Hoisted out of the loops: the original constructed a fresh
    # ops.ReduceMean(keep_dims=False) instance on every single use.
    reduce_mean = ops.ReduceMean(keep_dims=False)
    sumbpp = 0
    sumpsnr = 0
    summsssim = 0
    cnt = 0
    run_time = 0
    # Renamed from `input`, which shadowed the builtin.
    for batch_idx, batch in enumerate(test_loader.create_dict_iterator()):
        if batch_idx % 2 == 0:
            print("testing : %d/%d" % (batch_idx, test_loader.get_dataset_size()))
        input_images = batch["input_images"]
        ref_image = batch["ref_image"]
        # The I-frame (reference) statistics are pre-computed by the dataset.
        sumbpp += reduce_mean(batch["refbpp"]).asnumpy()
        sumpsnr += reduce_mean(batch["refpsnr"]).asnumpy()
        summsssim += reduce_mean(batch["refmsssim"]).asnumpy()
        cnt += 1
        seqlen = input_images.shape[1]
        for i in range(seqlen):
            input_image = input_images[:, i, :, :, :]
            start_time = time.time()
            clipped_recon_image, mse_loss, warploss, interloss, bpp_feature, bpp_z, bpp_mv, bpp = net(
                input_image, ref_image, is_train=False)
            run_time += time.time() - start_time
            sumbpp += reduce_mean(bpp).asnumpy()
            # PSNR = 10 * log10(1 / MSE) for data in [0, 1].
            sumpsnr += reduce_mean(10 * (ops.log(1. / mse_loss) / np.log(10))).asnumpy()
            summsssim += ms_ssim(clipped_recon_image, input_image, data_range=1.0, size_average=True).asnumpy()
            cnt += 1
            # Closed-loop evaluation: the reconstruction becomes the next reference.
            ref_image = clipped_recon_image
    log = "UVGdataset : average bpp : %.6lf, average psnr : %.6lf, average msssim: %.6lf\n" % (
        sumbpp / cnt, sumpsnr / cnt, summsssim / cnt)
    print(log)
    # uvgdrawplt([sumbpp], [sumpsnr], [summsssim], global_step, testfull=testfull)
    return cnt, sumbpp, sumpsnr, summsssim, run_time
-
def save_result(lambda_all, folder_all, PSNR_all, MSSSIM_all, bpp_all, run_time, cnt_all, file):
    """Write the accumulated per-run metrics to `file` as a CSV table.

    Each argument is a flat numpy array with one entry per (lambda, folder)
    evaluation; they are stacked column-wise in a fixed column order and
    written without an index column.
    """
    columns = ['lambda', 'cnt', 'bpp', 'PSNR', 'MSSSIM', 'folder', 'run_time']
    # Order here must mirror `columns` above.
    series = (lambda_all, cnt_all, bpp_all, PSNR_all, MSSSIM_all, folder_all, run_time)
    # Turn each metric into a column vector, then glue them side by side.
    table = np.concatenate([s.reshape(-1, 1) for s in series], axis=1)
    pd.DataFrame(columns=columns, data=table).to_csv(file, index=False)
-
if __name__ == "__main__":
    # PYNATIVE_MODE = dynamic (eager) execution; GRAPH_MODE = static graph.
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    context.set_context(save_graphs=False)
    context.set_context(device_id=int(os.getenv('DEVICE_ID', '0')))
    print("int(os.getenv('DEVICE_ID', '0')): ",int(os.getenv('DEVICE_ID', '0'))) #0
    if ms.get_context("device_target") == "GPU":
        # Graph-kernel fusion speeds up execution (mainly on GPU) but emits
        # extra intermediate files locally; it is explicitly disabled here.
        context.set_context(enable_graph_kernel=False)
        # NOTE(review): original indentation was lost in extraction; the two
        # parallel-context calls are read as part of the GPU branch — confirm.
        ms.reset_auto_parallel_context()
        ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.STAND_ALONE, gradients_mean=True, device_num=1)
    ds.config.set_enable_shared_mem(False)  # multiprocessing could use shared memory; disabled here

    args = parser.parse_args()
    # args.lmbda_list = [args.lmbda_list]
    # lmbda_list = [256,512,1024,2048]
    # Per-(lambda, folder) result accumulators, grown incrementally below.
    PSNR_all = np.array([])
    MSSSIM_all = np.array([])
    bpp_all = np.array([])
    run_time = np.array([])
    lambda_all = np.array([])
    folder_all = np.array([])
    cnt_all = np.array([])
    # folders = [
    # # 'Beauty',
    # # 'HoneyBee',
    # # 'ReadySteadyGo',
    # 'YachtRide',
    # 'Bosphorus',
    # 'Jockey',
    # 'ShakeNDry'
    # ]
    for lmbda in args.lmbda_list:
        # filelist='/userhome/DVC/PyTorch/data/UVG/testv.txt'
        # filelist="/userhome/DVC/PyTorch/data/UVG/originalv.txt"
        model = VideoCompressor(is_train=False)
        ckpt_path = args.ckpt_path #'/userhome/DVC/MindSpore/snapshot/best_'+str(lmbda)+'.ckpt'
        # append_info (loss/lr/epoch from the checkpoint) is loaded but unused here.
        append_info = load_model(model, ckpt_path)
        ref_i_dir = geti(lmbda)
        save_refmsssim = '/userhome/DVC/MindSpore/refmsssim_'+ref_i_dir+'.npy'
        # Evaluate one folder at a time: loading everything at once blows
        # memory and gets the process killed.
        for folder in args.folders:
            test_dataset = UVGDataSet([folder],root=args.dataset,
                refdir=ref_i_dir, testfull=True,save_refmsssim=save_refmsssim)
            print('testing '+folder)
            cnt,sumbpp,sumpsnr,summsssim,one_run_time = testuvg(0, test_dataset, model, testfull=True)
            lambda_all = np.concatenate((lambda_all, [str(lmbda)]), axis=0)
            folder_all = np.concatenate((folder_all, [folder]), axis=0)
            PSNR_all = np.concatenate((PSNR_all, [sumpsnr]), axis=0)
            MSSSIM_all = np.concatenate((MSSSIM_all, [summsssim]), axis=0)
            bpp_all = np.concatenate((bpp_all, [sumbpp]), axis=0)
            run_time = np.concatenate((run_time, [one_run_time]), axis=0)
            cnt_all = np.concatenate((cnt_all, [cnt]), axis=0)
            # Re-written after every folder — presumably so partial results
            # survive a crash; confirm this is intentional.
            save_result(lambda_all, folder_all, PSNR_all, MSSSIM_all, bpp_all, run_time, cnt_all, args.csv)
            # NOTE(review): original indentation was lost; this cache-save is
            # read as per-folder (it uses the folder's dataset) — confirm.
            if test_dataset.needsave:
                test_dataset.SaveRefmsssim(save_refmsssim)
|