|
- # from subnet import endecoder
- # out = endecoder.build_model()
-
- import os
- import argparse
- import logging
- import numpy as np
- from net import *
- import json
- from dataset import DataSet, UVGDataSet
- from drawuvg import uvgdrawplt
- import mindspore as ms
- import mindspore as mindspore
- import mindspore.nn as nn
- import mindspore.dataset as ds
- from mindspore import context
- import datetime
- import mindspore.ops as ops
# torch.backends.cudnn.enabled = True
# gpu_num = 4
# gpu_num = torch.cuda.device_count()

# --- Training/testing hyper-parameters (module-level constants) ---
cur_lr = base_lr = 1e-4  # * gpu_num
train_lambda = 1024  # rate-distortion trade-off; the original author used 2048
log_step = 100
cal_step = 10
print_step = 10
warmup_step = 0  # // gpu_num
gpu_per_batch = 16  # per-GPU batch size: use 8 on a T4, 16 on a V100
test_step = 10000  # // gpu_num
tot_epoch = 1000000
tot_step = 2000000
decay_interval = 1800000
lr_decay = 0.1
logger = logging.getLogger("VideoCompression")
# tb_logger = None
global_step = 0
# I-frame reference directory for this lambda; `geti` presumably comes from `from net import *` — TODO confirm
ref_i_dir = geti(train_lambda)

parser = argparse.ArgumentParser(description='DVC reimplement')
parser.add_argument('-l', '--log', default='/userhome/DVC/MindSpore/loguvg.txt',
                    help='output training details')
parser.add_argument('-p', '--pretrain', default = '/userhome/DVC/MindSpore/snapshot/iter16503.ckpt',
                    help='load pretrain model')
# NOTE(review): with action='store_true' AND default=True this flag can never become False — confirm intent.
parser.add_argument('--testuvg', default = True, action='store_true')
parser.add_argument('--config', dest='config', default = 'config.json',
                    help = 'hyperparameter of Reid in json format')
-
def testuvg(test_dataset, net, testfull=False):
    """Evaluate `net` on the UVG dataset and log average bpp / PSNR / MS-SSIM.

    Args:
        test_dataset: generator-style dataset yielding
            (input_images, ref_image, refbpp, refpsnr, refmsssim) per sequence.
        net: video-compression model, called as ``net(frame, ref, is_train=False)``.
        testfull: forwarded to ``uvgdrawplt`` to select full-test plotting.

    Side effects: logs summary lines via the module logger and draws a plot.
    """
    test_loader = ds.GeneratorDataset(
        test_dataset,
        column_names=["input_images", "ref_image", "refbpp", "refpsnr", "refmsssim"],
        num_parallel_workers=2, shuffle=False, python_multiprocessing=False)
    test_loader = test_loader.batch(1, drop_remainder=False)
    net.set_train(False)
    sumbpp = 0
    sumpsnr = 0
    summsssim = 0
    cnt = 0
    # Hoisted out of the loop: build the ReduceMean op and query the dataset size once.
    reduce_mean = ops.ReduceMean(keep_dims=False)
    num_batches = test_loader.get_dataset_size()
    test_dataloader = test_loader.create_dict_iterator()
    # Renamed loop variable from `input` to `batch` to stop shadowing the builtin.
    for batch_idx, batch in enumerate(test_dataloader):
        if batch_idx % 2 == 0:
            print("testing : %d/%d" % (batch_idx, num_batches))
        input_images = batch["input_images"]
        ref_image = batch["ref_image"]
        ref_bpp = batch["refbpp"]
        ref_psnr = batch["refpsnr"]
        ref_msssim = batch["refmsssim"]
        seqlen = input_images.shape[1]
        # The I-frame (reference) statistics count as one sample in the running averages.
        sumbpp += reduce_mean(ref_bpp).asnumpy()
        sumpsnr += reduce_mean(ref_psnr).asnumpy()
        summsssim += reduce_mean(ref_msssim).asnumpy()
        cnt += 1
        for i in range(seqlen):
            input_image = input_images[:, i, :, :, :]
            inputframe, refframe = input_image, ref_image
            clipped_recon_image, mse_loss, warploss, interloss, bpp_feature, bpp_z, bpp_mv, bpp, VGG_loss, G_loss, D_loss = net(inputframe, refframe, is_train=False)
            sumbpp += reduce_mean(bpp).asnumpy()
            # PSNR = 10 * log10(1 / MSE); ops.log is the natural log, hence the / ln(10).
            sumpsnr += reduce_mean(10 * (ops.log(1. / mse_loss) / np.log(10))).asnumpy()
            summsssim += ms_ssim(clipped_recon_image, input_image, data_range=1.0, size_average=True).asnumpy()
            cnt += 1
            # The reconstruction becomes the reference for the next P-frame in the sequence.
            ref_image = clipped_recon_image
    log = "global step %d : " % (global_step) + "\n"
    logger.info(log)
    sumbpp /= cnt
    sumpsnr /= cnt
    summsssim /= cnt
    log = "UVGdataset : average bpp : %.6lf, average psnr : %.6lf, average msssim: %.6lf\n" % (sumbpp, sumpsnr, summsssim)
    logger.info(log)
    uvgdrawplt([sumbpp], [sumpsnr], [summsssim], global_step, testfull=testfull)
-
-
-
if __name__ == "__main__":
    # PYNATIVE_MODE = dynamic-graph (eager) mode; GRAPH_MODE = static-graph mode.
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    context.set_context(save_graphs=False)
    context.set_context(device_id=int(os.getenv('DEVICE_ID', '0')))
    if ms.get_context("device_target") == "GPU":
        # Graph-kernel fusion (mainly useful on GPU, works in both graph and pynative modes)
        # would emit extra intermediate files locally, so it is disabled here.
        context.set_context(enable_graph_kernel=False)
    # Single-device evaluation: reset any parallel context and force stand-alone mode.
    ms.reset_auto_parallel_context()
    ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.STAND_ALONE, gradients_mean=True, device_num=1)
    ds.config.set_enable_shared_mem(False)  # multi-process data loading may use shared memory

    args = parser.parse_args()

    # Build the model in inference mode and restore weights from the pretrained checkpoint.
    model = VideoCompressor(is_train=False, args=args)
    global_step = load_model(model, args.pretrain)
    # NOTE(review): `Test` is not defined in this file; it presumably comes from `from net import *`.
    # The local evaluation routine is `testuvg(dataset, net)` — confirm which entry point is intended.
    Test(model)
|