|
-
- import os
- import argparse
- import logging
- import numpy as np
- from net import *
- import json
- from dataset import DataSet, UVGDataSet
- from drawuvg import uvgdrawplt
- import mindspore as ms
- import mindspore as mindspore
- import mindspore.nn as nn
- import mindspore.dataset as ds
- from mindspore import context
- import datetime
- import mindspore.ops as ops
- from tqdm import tqdm
- import subprocess
-
# Command-line interface for the DVC_P training run.
parser = argparse.ArgumentParser(description='DVC_P reimplement')
parser.add_argument('-l', '--log', default='./train.log',
                    help='output training details')
parser.add_argument('-p', '--pretrain', default='',
                    help='load pretrain model')
parser.add_argument('--config', dest='config', default='config.json',
                    help='hyperparameter of Reid in json format')
parser.add_argument('--image_size', type=int, nargs='+', default=[256, 256])
parser.add_argument('--train_lambda', type=int, default=1024)
-
-
def adjust_learning_rate(optimizer, global_step):
    """Decay the module-level learning rate by a factor of 0.8 and write it
    into the optimizer's learning-rate parameter.

    Note: `global_step` is accepted for interface compatibility but is not
    consulted by the current decay schedule.
    """
    global cur_lr
    cur_lr = cur_lr * 0.8
    ops.assign(optimizer.learning_rate, ms.Tensor(cur_lr, ms.float32))
-
-
def forward_fn(input_image, ref_image, quant_noise_feature,
               quant_noise_z, quant_noise_mv, train_lambda):
    """Forward pass of the compressor and rate-distortion loss assembly.

    Returns (rd_loss, mse_loss, warploss, interloss, bpp, bpp_feature,
    bpp_mv, bpp_z). Combined with has_aux=True in the grad wrapper, only
    the first output (rd_loss) is differentiated; the rest are passed
    through for logging. Reads the module-level `model` and `global_step`.
    """
    (clipped_recon_image, mse_loss, warploss, interloss, bpp_feature,
     bpp_z, bpp_mv, bpp, VGG_loss, G_loss, D_loss) = model(
        input_image, ref_image, quant_noise_feature, quant_noise_z,
        quant_noise_mv, is_train=True, global_step=global_step)

    reduce_mean = ops.ReduceMean(keep_dims=False)
    mse_loss = reduce_mean(mse_loss)
    warploss = reduce_mean(warploss)
    interloss = reduce_mean(interloss)
    bpp_feature = reduce_mean(bpp_feature)
    bpp_z = reduce_mean(bpp_z)
    bpp_mv = reduce_mean(bpp_mv)
    bpp = reduce_mean(bpp)

    distribution_loss = bpp
    # Warp/interpolation guidance only helps early training; it is
    # switched off after 500k steps.
    warp_weight = 0.1 if global_step < 500000 else 0

    distortion = mse_loss + warp_weight * (warploss + interloss)
    # Perceptual (VGG) and adversarial terms are enabled late in training.
    if global_step >= 400000:
        distortion = distortion + 0.04 * VGG_loss + 0.1 * G_loss + 0.1 * D_loss

    rd_loss = train_lambda * distortion + distribution_loss

    return rd_loss, mse_loss, warploss, interloss, bpp, bpp_feature, bpp_mv, bpp_z
-
def train_step(optimizer, input_image, ref_image, quant_noise_feature,
               quant_noise_z, quant_noise_mv, train_lambda):
    """Run one optimization step on a single batch and return the losses.

    Builds a value-and-grad wrapper around forward_fn, clips the gradients
    to [-0.5, 0.5], applies them through the optimizer, and returns the
    scalar losses keyed by name.
    """
    # has_aux=True: forward_fn must return more than one output; only the
    # first one participates in differentiation, the others are returned
    # untouched. grad_position=None: differentiate w.r.t. `weights` only.
    grad_fn = mindspore.value_and_grad(forward_fn,
                                       grad_position=None,
                                       weights=optimizer.parameters,
                                       has_aux=True)

    # The arguments are forwarded straight to forward_fn; the first result
    # is forward_fn's output tuple, the second is the gradients.
    losses, grads = grad_fn(input_image, ref_image, quant_noise_feature,
                            quant_noise_z, quant_noise_mv, train_lambda)
    grads = ops.clip_by_value(grads, -0.5, 0.5)

    keys = ('rd_loss', 'mse_loss', 'warploss', 'interloss',
            'bpp', 'bpp_feature', 'bpp_mv', 'bpp_z')
    out_loss = dict(zip(keys, losses))
    # optimizer(grads) updates the parameters; Depend guarantees the update
    # has executed before out_loss is returned.
    out_loss = ops.Depend()(out_loss, optimizer(grads))
    return out_loss
-
class AverageMeter:
    """Track the most recent value and the running average of a scalar.

    Attributes:
        val: last value passed to update().
        sum: weighted sum of all values seen so far.
        count: total weight accumulated so far.
        avg: sum / count after each update.
    """

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Fold `val` (with weight `n`) into the running statistics."""
        self.count += n
        self.sum += val * n
        self.val = val
        self.avg = self.sum / self.count
-
def test(epoch, net, test_loader):
    """Evaluate `net` over `test_loader`; log and return the average rd loss.

    Uses the module-level `args.train_lambda` as the rate-distortion weight
    and the module-level `logger` for output. Puts `net` into eval mode.
    """
    net.set_train(False)
    loss_avg = AverageMeter()
    bpp_loss_avg = AverageMeter()
    mse_loss_avg = AverageMeter()
    reduce_mean = ops.ReduceMean(keep_dims=False)
    for batch_idx, batch in enumerate(test_loader.create_dict_iterator()):
        input_image = batch['input_image']
        ref_image = batch['ref_image']
        (clipped_recon_image, mse_loss, warploss, interloss, bpp_feature,
         bpp_z, bpp_mv, bpp, _, _, _) = net(input_image, ref_image, is_train=False)
        mse_loss = reduce_mean(mse_loss)
        warploss = reduce_mean(warploss)
        interloss = reduce_mean(interloss)
        bpp_feature = reduce_mean(bpp_feature)
        bpp_z = reduce_mean(bpp_z)
        bpp_mv = reduce_mean(bpp_mv)
        bpp = reduce_mean(bpp)
        rd_loss = args.train_lambda * mse_loss + bpp
        bpp_loss_avg.update(bpp.asnumpy().item())
        loss_avg.update(rd_loss.asnumpy().item())
        mse_loss_avg.update(mse_loss.asnumpy().item())
    log_data = f"Test epoch {epoch}: Average losses:"\
        f"\tLoss: {loss_avg.avg:.3f} |"\
        f"\tMSE loss: {mse_loss_avg.avg:.3f} |"\
        f"\tBpp loss: {bpp_loss_avg.avg:.2f}\n"
    logger.info(log_data)
    return loss_avg.avg
-
-
def _psnr_from_mse(mse):
    """Return PSNR in dB for an MSE tensor, or 100 dB when the MSE is non-positive."""
    if mse > 0:
        return 10 * (ops.log(1 * 1.0 / mse) / np.log(10)).asnumpy().item()
    return 100


def train(epoch, traindata_loader, batch_num, model, optimizer, global_step, train_lambda):
    """Run one training epoch and return the updated global step counter.

    Args:
        epoch: epoch index, used for logging only.
        traindata_loader: dict iterator yielding batches with keys
            input_image / ref_image / quant_noise_feature / quant_noise_z /
            quant_noise_mv.
        batch_num: number of batches per epoch (for progress percentages).
        model: the video-compression network; switched to train mode here.
        optimizer: optimizer driven through train_step.
        global_step: running step counter carried across epochs.
        train_lambda: rate-distortion trade-off weight forwarded to the loss.

    Returns:
        The updated global_step.

    Reads the module-level cal_step / print_step / log_step / cur_lr / logger.
    """
    model.set_train()

    bat_cnt = 0   # batches processed since the last log-step report
    cal_cnt = 0   # sampled (every cal_step) batches since the last report
    sumloss = 0
    sumpsnr = 0
    suminterpsnr = 0
    sumwarppsnr = 0
    sumbpp = 0
    sumbpp_feature = 0
    sumbpp_mv = 0
    sumbpp_z = 0
    t0 = datetime.datetime.now()
    for batch_idx, batch in tqdm(enumerate(traindata_loader)):
        global_step += 1
        bat_cnt += 1
        input_image = batch['input_image']
        ref_image = batch['ref_image']
        quant_noise_feature = batch['quant_noise_feature']
        quant_noise_z = batch['quant_noise_z']
        quant_noise_mv = batch['quant_noise_mv']

        out_loss = train_step(optimizer, input_image, ref_image,
                              quant_noise_feature, quant_noise_z,
                              quant_noise_mv, train_lambda)

        # Sample statistics only every cal_step steps to keep the host-side
        # tensor syncs (.asnumpy()) cheap.
        if global_step % cal_step == 0:
            cal_cnt += 1
            psnr = _psnr_from_mse(out_loss['mse_loss'])
            warppsnr = _psnr_from_mse(out_loss['warploss'])
            interpsnr = _psnr_from_mse(out_loss['interloss'])

            sumloss += out_loss['rd_loss'].asnumpy().item()
            sumpsnr += psnr
            suminterpsnr += interpsnr
            sumwarppsnr += warppsnr
            sumbpp += out_loss['bpp'].asnumpy().item()
            sumbpp_feature += out_loss['bpp_feature'].asnumpy().item()
            sumbpp_mv += out_loss['bpp_mv'].asnumpy().item()
            sumbpp_z += out_loss['bpp_z'].asnumpy().item()

        if (global_step % print_step) == 0 and cal_cnt > 1:
            print("out_loss['rd_loss']:", out_loss['rd_loss'])
            t1 = datetime.datetime.now()
            deltatime = t1 - t0
            # BUG FIX: this log previously read the module-level train_loader
            # instead of the batch_num argument, which broke train() for any
            # loader other than the global one.
            log = 'Train Epoch : {:02} [{:4}/{:4} ({:3.0f}%)] Avgloss:{:.6f} lr:{} time:{}'.format(epoch,
                batch_idx, batch_num, 100. * batch_idx / batch_num, sumloss / cal_cnt, cur_lr,
                (deltatime.seconds + 1e-6 * deltatime.microseconds) / bat_cnt)
            print(log)
            log = 'details : warppsnr : {:.2f} interpsnr : {:.2f} psnr : {:.2f}'.format(sumwarppsnr / cal_cnt,
                suminterpsnr / cal_cnt, sumpsnr / cal_cnt)
            print(log)

        if (batch_idx % log_step) == 0 and bat_cnt > 1:
            # BUG FIX: cal_cnt can still be 0 here (no cal_step multiple hit
            # since the last reset), which used to raise ZeroDivisionError.
            cnt = max(cal_cnt, 1)
            log_data = f'input_image.shape: {input_image.shape}, cur_lr: {cur_lr}'
            logger.info(log_data)
            log_data = f'Train epoch {epoch}, global_step {global_step}: ['\
                f'{batch_idx}/{batch_num}'\
                f' ({100. * batch_idx / batch_num:.0f}%)]'\
                f'\trd_loss: {sumloss / cnt:.3f} |'\
                f'\tpsnr: {sumpsnr / cnt:.5f} |'\
                f'\twarppsnr: {sumwarppsnr / cnt:.5f} |'\
                f'\tbpp: {sumbpp / cnt:.5f} |'\
                f'\tbpp_feature: {sumbpp_feature / cnt:.5f} |'\
                f'\tbpp_z: {sumbpp_z / cnt:.5f} |'\
                f'\tbpp_mv: {sumbpp_mv / cnt:.5f} |'\
                f'\tinterpsnr: {suminterpsnr / cnt:.2f}'
            logger.info(log_data)
            bat_cnt = 0
            cal_cnt = 0
            sumbpp = sumbpp_feature = sumbpp_mv = sumbpp_z = sumloss = sumpsnr = suminterpsnr = sumwarppsnr = 0
            # BUG FIX: t1 is only bound inside the print_step branch above and
            # could be unbound here (NameError); take a fresh timestamp.
            t0 = datetime.datetime.now()
    # BUG FIX: bat_cnt can be 0 after a log-step reset on the final batch (or
    # with an empty loader); guard the average against ZeroDivisionError.
    log = 'Train Epoch : {:02} Loss:\t {:.6f}\t lr:{}'.format(epoch, sumloss / max(bat_cnt, 1), cur_lr)
    logger.info(log)
    return global_step
-
-
if __name__ == "__main__":
    # PYNATIVE_MODE = eager (dynamic-graph) execution; GRAPH_MODE = static graph.
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    context.set_context(save_graphs=False)
    context.set_context(device_id=int(os.getenv('DEVICE_ID', '0')))
    print("int(os.getenv('DEVICE_ID', '0')): ", int(os.getenv('DEVICE_ID', '0')))
    if ms.get_context("device_target") == "GPU":
        # Graph-kernel fusion would also emit extra intermediate files
        # locally, so it is kept disabled here.
        context.set_context(enable_graph_kernel=False)
    ms.reset_auto_parallel_context()
    ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.STAND_ALONE, gradients_mean=True, device_num=1)
    ds.config.set_enable_shared_mem(False)  # shared memory for multi-process data loading

    args = parser.parse_args()
    cur_lr = base_lr = 2e-5  # * gpu_num
    train_lambda = args.train_lambda  # reference implementation uses 2048
    print('train_lambda', train_lambda)
    log_step = 500      # batches between file-log reports
    cal_step = 10       # steps between statistic samples
    print_step = 50     # steps between console progress prints
    warmup_step = 0  # // gpu_num
    gpu_per_batch = 16  # batch size: 8 on T4, 16 on V100
    # test_step = 10000  # // gpu_num
    tot_epoch = 10000
    tot_step = 700000
    decay_interval = 1800000
    lr_decay = 0.1
    logger = logging.getLogger("video compression")
    global_step = 0
    ref_i_dir = geti(train_lambda)

    best_loss = float("inf")

    # BUG FIX: the format string had an unbalanced bracket
    # ('%(asctime)s - %(levelname)s] %(message)s'); add the missing '['.
    formatter = logging.Formatter('[%(asctime)s - %(levelname)s] %(message)s')
    stdhandler = logging.StreamHandler()
    stdhandler.setLevel(logging.INFO)
    stdhandler.setFormatter(formatter)
    logger.addHandler(stdhandler)
    if args.log != '':
        # NOTE(review): the log file name is derived from train_lambda and
        # ignores the args.log value — confirm whether --log should be
        # honoured here before changing the file name scheme.
        filehandler = logging.FileHandler("./train_{}.log".format(args.train_lambda))
        filehandler.setLevel(logging.INFO)
        filehandler.setFormatter(formatter)
        logger.addHandler(filehandler)
    logger.setLevel(logging.INFO)
    logger.info("DVC_P training")

    model = VideoCompressor(is_train=True, args=args)

    if args.pretrain != '':
        load_model(model, args.pretrain)
    # VGG parameters belong to the frozen perceptual-loss network and must
    # not be optimized.
    optimizer = nn.Adam(params=[param for param in model.trainable_params() if 'vgg' not in param.name],
                        learning_rate=base_lr)

    train_dataset = DataSet(rootdir="/userhome/data/vimeo_septuplet/sequences/",
                            test_txt="/userhome/data/vimeo_septuplet/sep_trainlist_DVC.txt")

    train_loader = ds.GeneratorDataset(train_dataset,
                                       column_names=["input_image", "ref_image", "quant_noise_feature",
                                                     "quant_noise_z", "quant_noise_mv"],
                                       num_parallel_workers=2, shuffle=True, python_multiprocessing=False)
    train_loader = train_loader.batch(gpu_per_batch, drop_remainder=True)
    batch_num = train_loader.get_dataset_size()
    traindata_loader = train_loader.create_dict_iterator()

    test_dataset = DataSet(rootdir="/userhome/data/vimeo_septuplet/sequences/",
                           test_txt="/userhome/data/vimeo_septuplet/sep_testlist_DVC.txt")
    test_loader = ds.GeneratorDataset(test_dataset,
                                      column_names=["input_image", "ref_image", "quant_noise_feature",
                                                    "quant_noise_z", "quant_noise_mv"],
                                      num_parallel_workers=2, shuffle=False, python_multiprocessing=False)
    test_loader = test_loader.batch(1, drop_remainder=False)

    # Resume at the epoch implied by global_step (0 unless a checkpoint
    # restored the counter).
    stepoch = global_step // (len(train_dataset) // gpu_per_batch)  # * gpu_num
    for epoch in range(stepoch, tot_epoch):
        if (epoch + 1) % 5 == 0:
            adjust_learning_rate(optimizer, global_step)
        if global_step > tot_step:
            save_model(model, train_lambda, global_step)
            break
        global_step = train(epoch, traindata_loader, batch_num, model, optimizer, global_step, train_lambda)
        save_model(model, train_lambda, global_step)

        print('Beginning once test!')
        loss = test(epoch, model, test_loader)
        is_best = loss < best_loss
        best_loss = min(loss, best_loss)
        if is_best:
            save_model(model, train_lambda, str(global_step) + '_best')
            log_data = f"Update best model using best_loss: {best_loss}"
            logger.info(log_data)

    logger.info('Finish once training!')
|