|
-
- import os
- import argparse
- import logging
- import numpy as np
- from net import *
- import json
- from dataset import DataSet, UVGDataSet
- from drawuvg import uvgdrawplt
- import mindspore as ms
- import mindspore as mindspore
- import mindspore.nn as nn
- import mindspore.dataset as ds
- from mindspore import context
- import datetime
- import mindspore.ops as ops
-
# Rate-distortion tradeoff weight (the original DVC author uses 2048).
train_lambda = 512
# Per-device batch size: use 8 on a T4 GPU, 16 on a V100.
gpu_per_batch = 16
global cur_lr  # NOTE: a module-level `global` statement is a no-op; kept as-is.
# Initial learning rate (single-device; multiply by gpu_num for multi-GPU).
cur_lr = 2e-4  # * gpu_num
log_step = 50    # log training stats every `log_step` batches
cal_step = 1     # accumulate metrics every `cal_step` global steps
print_step = 50
warmup_step = 0  # // gpu_num
# test_step = 10000# // gpu_num
# tot_epoch = 1000
# tot_step = 2000000
# decay_interval = 1800000
# lr_decay = 0.1
logger = logging.getLogger("VideoCompression")
# tb_logger = None
global_step = 0
# I-frame reference directory matching this lambda (geti comes from `net`).
ref_i_dir = geti(train_lambda)
-
# Command-line interface: log path, pretrained checkpoint, eval-only flag,
# and the number of training epochs.
parser = argparse.ArgumentParser(description='DVC reimplement')
parser.add_argument('-l', '--log', default='/userhome/DVC/MindSpore/train.log',
                    help='output training details')
parser.add_argument('-p', '--pretrain', default='/userhome/DVC/MindSpore/snapshot/best_512.ckpt',
                    help='load pretrain model')
parser.add_argument('--testuvg', default=False, action='store_true')
parser.add_argument("-e", "--epochs", default=12, type=int,
                    help="Number of epochs (default: %(default)s)")
-
def adjust_learning_rate(optimizer, global_step):
    """Decay the current learning rate by a factor of 0.8.

    Updates the module-level ``cur_lr`` and writes the new value into the
    optimizer's learning-rate parameter in place.

    Args:
        optimizer: MindSpore optimizer whose ``learning_rate`` is updated.
        global_step: accepted for interface compatibility; unused by the
            current fixed-decay schedule.
    """
    global cur_lr
    decayed = 0.8 * cur_lr
    cur_lr = decayed
    # Assign into the optimizer's Parameter so the change takes effect.
    ops.assign(optimizer.learning_rate, ms.Tensor(decayed, ms.float32))
-
def testuvg(global_step, test_dataset, net, testfull=False):
    """Evaluate `net` on the UVG dataset; log and plot average bpp/PSNR/MS-SSIM.

    Each sequence starts from a pre-encoded reference frame whose metrics
    (bpp/psnr/msssim) are supplied by the dataset; every subsequent frame is
    coded by the network using the previous reconstruction as reference.
    """
    test_loader = ds.GeneratorDataset(test_dataset, column_names=["input_images", "ref_image", "refbpp", "refpsnr", "refmsssim"],
                                      num_parallel_workers=2, shuffle=False, python_multiprocessing=False)
    test_loader = test_loader.batch(1, drop_remainder=False)
    net.set_train(False)  # evaluation mode
    sumbpp = 0
    sumpsnr = 0
    summsssim = 0
    cnt = 0  # total number of frames counted (reference + coded)
    test_dataloader = test_loader.create_dict_iterator()
    for batch_idx, input in enumerate(test_dataloader):
        if batch_idx % 10 == 0:
            print("testing : %d/%d" % (batch_idx, test_loader.get_dataset_size()))
        input_images = input["input_images"]
        ref_image = input["ref_image"]
        ref_bpp = input["refbpp"]
        ref_psnr = input["refpsnr"]
        ref_msssim = input["refmsssim"]
        seqlen = input_images.shape[1]
        # Account for the intra-coded reference frame using the metrics
        # shipped with the dataset.
        sumbpp += ops.ReduceMean(keep_dims=False)(ref_bpp).asnumpy()
        sumpsnr += ops.ReduceMean(keep_dims=False)(ref_psnr).asnumpy()
        summsssim += ops.ReduceMean(keep_dims=False)(ref_msssim).asnumpy()
        cnt += 1
        for i in range(seqlen):
            input_image = input_images[:, i, :, :, :]
            inputframe, refframe = input_image, ref_image
            clipped_recon_image, mse_loss, warploss, interloss, bpp_feature, bpp_z, bpp_mv, bpp = net(inputframe, refframe, is_train=False)
            sumbpp += ops.ReduceMean(keep_dims=False)(bpp).asnumpy()
            # PSNR from MSE: 10 * log10(1 / mse), pixel range assumed [0, 1].
            sumpsnr += ops.ReduceMean(keep_dims=False)(10 * (ops.log(1. / mse_loss) / np.log(10))).asnumpy()
            summsssim += ms_ssim(clipped_recon_image, input_image, data_range=1.0, size_average=True).asnumpy()
            cnt += 1
            # The reconstruction becomes the reference for the next frame.
            ref_image = clipped_recon_image
    log = "global step %d : " % (global_step) + "\n"
    logger.info(log)
    # Per-frame averages over the whole dataset.
    sumbpp /= cnt
    sumpsnr /= cnt
    summsssim /= cnt
    log = "UVGdataset : average bpp : %.6lf, average psnr : %.6lf, average msssim: %.6lf\n" % (sumbpp, sumpsnr, summsssim)
    logger.info(log)
    uvgdrawplt([sumbpp], [sumpsnr], [summsssim], global_step, testfull=testfull)
-
class AverageMeter:
    """Track the latest value plus a running, count-weighted average."""

    def __init__(self):
        # val: most recent sample; sum/count accumulate for avg.
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the running mean."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
-
def test(epoch, net, test_loader):
    """Run one validation pass over `test_loader` and log averaged losses.

    Args:
        epoch: epoch index, used only in the log line.
        net: VideoCompressor network to evaluate.
        test_loader: batched MindSpore dataset yielding 'input_image' /
            'ref_image' pairs.

    Returns:
        The average rate-distortion loss over the whole validation set,
        used by the caller for best-model selection.
    """
    net.set_train(False)
    test_dataloader = test_loader.create_dict_iterator()
    # Hoisted: one ReduceMean op reused for every tensor instead of
    # instantiating a new op per reduction.
    reduce_mean = ops.ReduceMean(keep_dims=False)
    loss_avg = AverageMeter()
    bpp_loss_avg = AverageMeter()
    mse_loss_avg = AverageMeter()
    # `batch` avoids shadowing the builtin `input`.
    for batch_idx, batch in enumerate(test_dataloader):
        input_image = batch['input_image']
        ref_image = batch['ref_image']
        # Only mse_loss and bpp feed the validation rd-loss; the remaining
        # network outputs are not needed here, so they are not reduced.
        _, mse_loss, _, _, _, _, _, bpp = net(input_image, ref_image, is_train=False)
        mse_loss = reduce_mean(mse_loss)
        bpp = reduce_mean(bpp)
        rd_loss = train_lambda * mse_loss + bpp
        bpp_loss_avg.update(bpp.asnumpy().item())
        loss_avg.update(rd_loss.asnumpy().item())
        mse_loss_avg.update(mse_loss.asnumpy().item())
    log_data = f"Test epoch {epoch}: Average losses:"\
        f"\tLoss: {loss_avg.avg:.3f} |"\
        f"\tMSE loss: {mse_loss_avg.avg:.3f} |"\
        f"\tBpp loss: {bpp_loss_avg.avg:.2f}\n"
    logger.info(log_data)
    return loss_avg.avg
-
def train(epoch, train_loader, model, optimizer, global_step):
    """Train `model` for one epoch over `train_loader`.

    Returns the updated `global_step`. Warp/inter losses are weighted in only
    for the first five epochs (motion-compensation warm-up); after that only
    the final reconstruction error contributes to distortion.
    """
    batch_num = train_loader.get_dataset_size()
    traindata_loader = train_loader.create_dict_iterator()

    def forward_fn(input_image, ref_image, quant_noise_feature,
                   quant_noise_z, quant_noise_mv, is_train):
        # Forward pass: reconstruction plus all rate/distortion terms.
        clipped_recon_image, mse_loss, warploss, interloss, bpp_feature, bpp_z, bpp_mv, bpp = model(input_image,
            ref_image, quant_noise_feature, quant_noise_z, quant_noise_mv, is_train=is_train)
        mse_loss, warploss, interloss, bpp_feature, bpp_z, bpp_mv, bpp = \
            ops.ReduceMean(keep_dims=False)(mse_loss), ops.ReduceMean(keep_dims=False)(warploss), \
            ops.ReduceMean(keep_dims=False)(interloss), ops.ReduceMean(keep_dims=False)(bpp_feature), \
            ops.ReduceMean(keep_dims=False)(bpp_z), ops.ReduceMean(keep_dims=False)(bpp_mv), ops.ReduceMean(keep_dims=False)(bpp)
        distribution_loss = bpp
        # Warm-up: for the first 5 epochs the warp/inter losses also guide
        # the motion-compensation network.
        if epoch < 5:
            warp_weight = 0.1
        else:
            warp_weight = 0
        distortion = mse_loss + warp_weight * (warploss + interloss)
        rd_loss = train_lambda * distortion + distribution_loss
        # With has_aux=True below, only the first output (rd_loss) is
        # differentiated; the rest are returned unchanged as auxiliary values.
        return rd_loss, mse_loss, warploss, interloss, bpp, bpp_feature, bpp_mv, bpp_z

    grad_fn = mindspore.value_and_grad(forward_fn,                 # gradient function
                                       grad_position=None,         # differentiate w.r.t. weights only, not inputs
                                       weights=optimizer.parameters,  # network parameters whose gradients are returned
                                       has_aux=True)               # fn must return >1 value; only the first is differentiated

    def train_step(input_image, ref_image, quant_noise_feature,
                   quant_noise_z, quant_noise_mv, is_train):
        # One optimization step over a batch: forward, backward, clip, update.
        # Arguments are forwarded to forward_fn; grad_fn returns
        # (forward_fn outputs, gradients).
        (rd_loss, mse_loss, warploss, interloss, bpp, bpp_feature, bpp_mv, bpp_z), grads = grad_fn(input_image, ref_image, quant_noise_feature,
            quant_noise_z, quant_noise_mv, is_train)
        # Clip gradients element-wise to [-0.5, 0.5] for stability.
        grads = ops.clip_by_value(grads, -0.5, 0.5)
        out_loss = {}
        out_loss['rd_loss'] = rd_loss
        out_loss['mse_loss'] = mse_loss
        out_loss['warploss'] = warploss
        out_loss['interloss'] = interloss
        out_loss['bpp'] = bpp
        out_loss['bpp_feature'] = bpp_feature
        out_loss['bpp_mv'] = bpp_mv
        out_loss['bpp_z'] = bpp_z
        # optimizer(grads) applies the parameter update; Depend guarantees the
        # update has executed before out_loss is returned.
        out_loss = ops.Depend()(out_loss, optimizer(grads))
        return out_loss

    model.set_train()

    # Running counters, reset at every periodic log below.
    bat_cnt = 0
    cal_cnt = 0
    sumloss = 0
    sumpsnr = 0
    suminterpsnr = 0
    sumwarppsnr = 0
    sumbpp = 0
    sumbpp_feature = 0
    sumbpp_mv = 0
    sumbpp_z = 0
    # NOTE(review): t0 is only read by commented-out timing code.
    t0 = datetime.datetime.now()
    for batch_idx, input in enumerate(traindata_loader):
        global_step += 1
        bat_cnt += 1
        input_image = input['input_image']
        ref_image = input['ref_image']
        quant_noise_feature = input['quant_noise_feature']
        quant_noise_z = input['quant_noise_z']
        quant_noise_mv = input['quant_noise_mv']

        out_loss = train_step(input_image,
                              ref_image, quant_noise_feature, quant_noise_z, quant_noise_mv, is_train=True)

        if global_step % cal_step == 0:
            cal_cnt += 1
            # PSNR = 10 * log10(1 / mse); capped at 100 dB when the loss is 0.
            if out_loss['mse_loss'] > 0:
                psnr = 10 * (ops.log(1 * 1.0 / out_loss['mse_loss']) / np.log(10)).asnumpy().item()
            else:
                psnr = 100
            if out_loss['warploss'] > 0:
                warppsnr = 10 * (ops.log(1 * 1.0 / out_loss['warploss']) / np.log(10)).asnumpy().item()
            else:
                warppsnr = 100
            if out_loss['interloss'] > 0:
                interpsnr = 10 * (ops.log(1 * 1.0 / out_loss['interloss']) / np.log(10)).asnumpy().item()
            else:
                interpsnr = 100

            loss_ = out_loss['rd_loss'].asnumpy().item()

            sumloss += loss_
            sumpsnr += psnr
            suminterpsnr += interpsnr
            sumwarppsnr += warppsnr
            sumbpp += out_loss['bpp'].asnumpy().item()
            sumbpp_feature += out_loss['bpp_feature'].asnumpy().item()
            sumbpp_mv += out_loss['bpp_mv'].asnumpy().item()
            sumbpp_z += out_loss['bpp_z'].asnumpy().item()

        # Periodic progress log; afterwards all running sums are reset.
        if (batch_idx % log_step) == 0 and bat_cnt > 1:
            log_data = f'Train epoch {epoch}, global_step {global_step}: ['\
                f'{batch_idx}/{batch_num}'\
                f' ({100. * batch_idx / batch_num:.0f}%)]'\
                f'\trd_loss: {sumloss / cal_cnt:.3f} |'\
                f'\tpsnr: {sumpsnr / cal_cnt:.5f} |'\
                f'\twarppsnr: {sumwarppsnr / cal_cnt:.5f} |'\
                f'\tbpp: {sumbpp / cal_cnt:.5f} |'\
                f'\tbpp_feature: {sumbpp_feature / cal_cnt:.5f} |'\
                f'\tbpp_z: {sumbpp_z / cal_cnt:.5f} |'\
                f'\tbpp_mv: {sumbpp_mv / cal_cnt:.5f} |'\
                f'\tinterpsnr: {suminterpsnr / cal_cnt:.2f}'
            logger.info(log_data)
            bat_cnt = 0
            cal_cnt = 0
            sumbpp = sumbpp_feature = sumbpp_mv = sumbpp_z = sumloss = sumpsnr = suminterpsnr = sumwarppsnr = 0
    # NOTE(review): the sums are reset at every periodic log above, so this
    # end-of-epoch line only averages batches since the last reset; if the
    # final batch triggered that reset, bat_cnt is 0 here and this raises
    # ZeroDivisionError — confirm whether this is intended.
    log = 'Train Epoch : {:02} Loss:\t {:.6f}\t lr:{}'.format(epoch, sumloss / bat_cnt, cur_lr)
    logger.info(log)
    return global_step
-
def save_model(model, append_info, lmbda, dirname='/userhome/DVC/MindSpore/snapshot/'):
    """Save a checkpoint named ``best_<lmbda>.ckpt`` under `dirname`.

    Args:
        model: network whose parameters are checkpointed.
        append_info: dict of extra scalars (e.g. loss/epoch/lr) stored in the
            checkpoint alongside the weights.
        lmbda: rate-distortion lambda, embedded in the checkpoint filename.
        dirname: target directory; default kept for backward compatibility.
    """
    # os.path.join is robust to dirnames with or without a trailing separator
    # (the original string concatenation required the trailing '/').
    filename = os.path.join(dirname, f'best_{lmbda}.ckpt')
    mindspore.save_checkpoint(model, filename, append_dict=append_info)
    print('checkpoint存储成功: ', filename)
-
def load_model(model, model_path):
    """Load checkpoint weights into `model` and return stored training state.

    Returns a dict containing any of 'loss', 'lr' and 'epoch' that were saved
    in the checkpoint, each as a plain Python scalar.
    """
    param_dict = mindspore.load_checkpoint(model_path)
    param_not_load = mindspore.load_param_into_net(model, param_dict)
    print('Load checkpoint from ' + model_path)
    # Parameters left unloaded; this should normally be empty.
    print("param_not_load: ", param_not_load)
    append_info = {}
    # Pull optional training-state scalars out of the checkpoint.
    for key in ('loss', 'lr', 'epoch'):
        if key in param_dict.keys():
            append_info[key] = param_dict[key].value().asnumpy().item()
    return append_info
-
if __name__ == "__main__":
    # PYNATIVE_MODE = dynamic graph (eager); GRAPH_MODE = static graph.
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    context.set_context(save_graphs=False)
    context.set_context(device_id=int(os.getenv('DEVICE_ID', '0')))
    print("int(os.getenv('DEVICE_ID', '0')): ", int(os.getenv('DEVICE_ID', '0')))
    if ms.get_context("device_target") == "GPU":
        # Graph-kernel fusion can speed up execution (mainly on GPU) but
        # emits extra intermediate files locally; disabled here.
        context.set_context(enable_graph_kernel=False)
    # Single-device (stand-alone) execution.
    ms.reset_auto_parallel_context()
    ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.STAND_ALONE, gradients_mean=True, device_num=1)
    # Disable shared memory for dataset worker processes.
    ds.config.set_enable_shared_mem(False)

    args = parser.parse_args()

    # Log to stdout and, when a log path is given, to a lambda-specific file.
    formatter = logging.Formatter('%(asctime)s - %(levelname)s] %(message)s')
    stdhandler = logging.StreamHandler()
    stdhandler.setLevel(logging.INFO)
    stdhandler.setFormatter(formatter)
    logger.addHandler(stdhandler)
    if args.log != '':
        # Insert the lambda value before the '.log' extension.
        args.log = args.log[:-4] + '_' + str(train_lambda) + '.log'
        filehandler = logging.FileHandler(args.log)
        filehandler.setLevel(logging.INFO)
        filehandler.setFormatter(formatter)
        logger.addHandler(filehandler)
    logger.setLevel(logging.INFO)
    logger.info("DVC training")

    model = VideoCompressor(is_train=True)
    stepoch = 0
    best_loss = float("inf")
    # Resume from a checkpoint when available; restores epoch/loss/lr too.
    if args.pretrain != '' and os.path.isfile(args.pretrain):
        append_info = load_model(model, args.pretrain)
        if 'epoch' in append_info.keys():
            stepoch = append_info['epoch']
        if 'loss' in append_info.keys():
            best_loss = append_info['loss']
        if 'lr' in append_info.keys():
            cur_lr = append_info['lr']
        print(append_info)
    optimizer = nn.Adam(params=model.trainable_params(), learning_rate=cur_lr)
    if args.testuvg:
        # Evaluation-only mode: score the UVG test set, then exit normally.
        test_dataset = UVGDataSet(refdir=ref_i_dir, testfull=True)
        print('testing UVG')
        testuvg(0, test_dataset, model, testfull=True)
        exit(0)

    # Vimeo-90k septuplet training and validation splits; dataset workers
    # use multiple parallel workers without Python multiprocessing.
    train_dataset = DataSet(rootdir="/userhome/DVC/PyTorch/data/vimeo_septuplet/sequences/",
                            test_txt="/userhome/DVC/PyTorch/data/vimeo_septuplet/sep_trainlist_DVC.txt")
    train_loader = ds.GeneratorDataset(train_dataset, column_names=["input_image", "ref_image", "quant_noise_feature", "quant_noise_z", "quant_noise_mv"],
                                       num_parallel_workers=2, shuffle=True, python_multiprocessing=False)
    train_loader = train_loader.batch(gpu_per_batch, drop_remainder=False)
    test_dataset = DataSet(rootdir="/userhome/DVC/PyTorch/data/vimeo_septuplet/sequences/",
                           test_txt="/userhome/DVC/PyTorch/data/vimeo_septuplet/sep_testlist_DVC.txt")
    test_loader = ds.GeneratorDataset(test_dataset, column_names=["input_image", "ref_image", "quant_noise_feature", "quant_noise_z", "quant_noise_mv"],
                                      num_parallel_workers=2, shuffle=False, python_multiprocessing=False)
    test_loader = test_loader.batch(1, drop_remainder=False)

    for epoch in range(stepoch, stepoch+args.epochs):
        # Decay the learning rate every 5 epochs (at epochs 4, 9, 14, ...).
        if epoch % 5 == 4:
            adjust_learning_rate(optimizer, global_step)
        global_step = train(epoch, train_loader, model, optimizer, global_step)
        loss = test(epoch, model, test_loader)
        is_best = loss < best_loss
        best_loss = min(loss, best_loss)
        if is_best:
            # Keep only the best checkpoint (lowest validation rd-loss).
            save_model(model, {"loss": best_loss, 'epoch': epoch+1, 'lr': cur_lr}, train_lambda)
            log_data = f"Update best model using best_loss: {best_loss}"
            logger.info(log_data)
    logger.info('Finish once training!')
|