|
- '''
- Contributors: Aolin Feng, Dezhao Wang, Yueyu Hu, Haojie Liu, Tong Chen, Chuanmin Jia, Yihang Chen
- '''
-
- import argparse
- import os
-
- import numpy as np
- import torch
- import torch.nn as nn
- from PIL import Image
- from torch.autograd import Variable
- from torch.utils.data import DataLoader, Dataset
- from torchvision import transforms
- from torchvision.datasets import ImageFolder
-
- import Model.model as model
- import Util.torch_msssim as torch_msssim
- from Model.context_model import Weighted_Gaussian, Weighted_Gaussian_res
- import time
- import math
- from collections import OrderedDict
- from Util.config import dict
-
- USE_MULTI_HYPER = False
- USE_PREDICTOR = False
-
- #os.environ["CUDA_VISIBLE_DEVICES"] = "2"
-
def adjust_learning_rate(optimizer, epoch, init_lr):
    """Halve the learning rate every 3 epochs, clamped below at 1e-6.

    Args:
        optimizer: torch optimizer whose ``param_groups`` are updated in place.
        epoch: zero-based epoch index.
        init_lr: learning rate used at epoch 0.

    Returns:
        The learning rate that was written into every param group.
    """
    decayed = init_lr * 0.5 ** (epoch // 3)
    new_lr = max(decayed, 1e-6)  # floor so training never fully stalls
    for group in optimizer.param_groups:
        group['lr'] = new_lr
    return new_lr
-
-
def train(args):
    """Fine-tune the lambda-conditioned mask network (modnet) and its
    context model on top of a frozen pre-trained NIC codec.

    Loads ``mse25600.pkl`` / ``mse25600p.pkl`` as the frozen base codec,
    optionally resumes ``modnet.pkl`` / ``context_mod.pkl``, then trains
    only ``modnet`` and ``context_mod`` with a rate-distortion loss where
    the distortion weight lambda is sampled per image.

    Args:
        args: argparse namespace with ``M``, ``N2``, ``bs``, ``lr``,
            ``out_dir``.

    Side effects: reads checkpoints from the CWD, appends a log file and
    writes checkpoints under ``args.out_dir``. Requires CUDA.
    """
    # NIC_Dataset
    train_data = ImageFolder(root='D:/CV_dataset/NIC_train', transform=transforms.Compose([transforms.RandomCrop(256),transforms.ToTensor()]))
    train_loader = DataLoader(train_data, batch_size=args.bs,
                              shuffle=True, num_workers=0)

    NIC = model.Image_coding(3, args.M, args.N2, args.M, args.M // 2).cuda()
    context = Weighted_Gaussian(args.M).cuda()
    modnet = model.modnet(args.M).cuda()
    context_mod = Weighted_Gaussian(args.M).cuda()
    model_existed = os.path.exists('modnet.pkl') and os.path.exists('context_mod.pkl')
    if model_existed:
        modnet.load_state_dict(torch.load('modnet.pkl'))
        context_mod.load_state_dict(torch.load('context_mod.pkl'))
        print('loaded')

    # Frozen pre-trained codec; only modnet/context_mod receive gradients below.
    NIC.load_state_dict(torch.load('mse25600.pkl'))
    context.load_state_dict(torch.load('mse25600p.pkl'))

    opt1 = torch.optim.Adam(modnet.parameters(), lr=args.lr)
    opt2 = torch.optim.Adam(context_mod.parameters(), lr=args.lr)

    # Candidate lambda values (loop-invariant; 8 + 14 + 23 = 45 entries).
    lmd_set = [i + 1 for i in range(8)] + [4 * i + 12 for i in range(14)] + [8 * i + 80 for i in range(23)]

    for epoch in range(20):

        rec_loss_tmp = 0
        last_time = time.time()
        train_bpp_tmp = 0
        mse_tmp = 0

        cur_lr = adjust_learning_rate(opt1, epoch, args.lr)
        adjust_learning_rate(opt2, epoch, args.lr)

        for step, batch_x in enumerate(train_loader):
            batch_x = batch_x[0]
            num_pixels = batch_x.size()[0] * \
                batch_x.size()[2] * batch_x.size()[3]
            # NOTE: torch.autograd.Variable is a no-op since PyTorch 0.4.
            batch_x = batch_x.cuda()
            b = batch_x.size()[0]
            # One lambda per image, sampled uniformly from the candidate set.
            # (np.int was removed in NumPy 1.24; randint avoids it and the
            # hard-coded 45.)
            random_index = np.random.randint(0, len(lmd_set), size=b)

            random_lambda = [lmd_set[idx] for idx in random_index]
            random_lambda = np.array(random_lambda, dtype=np.float32)

            lmd = torch.from_numpy(random_lambda).cuda()
            lmd = torch.reshape(lmd, (b, 1)).cuda()

            x1, x2 = NIC.encoder(batch_x)
            xq2, xp2 = NIC.factorized_entropy_func(x2, 2)
            x3 = NIC.hyper_dec(xq2)
            hyper_dec = NIC.p(x3)
            xq1 = model.UniverseQuant.apply(x1)
            xp1, _ = context(xq1, hyper_dec)
            modnet_mask = modnet(xq1, xp1, lmd)
            xq1 = xq1 * modnet_mask
            xq1 = model.UniverseQuant.apply(xq1)  # quantize again after masking
            xp1, _ = context_mod(xq1, hyper_dec)
            fake = NIC.decoder(xq1)

            # Per-image MSE, weighted by that image's sampled lambda.
            delta = (fake - batch_x) ** 2
            delta = delta.view(b, -1)
            batch_mse = torch.mean(delta, dim=1, keepdim=False).cuda()
            dloss = torch.mean(lmd * batch_mse)

            # Rate in bits (log-likelihoods converted from nats) per pixel.
            rate = torch.sum(torch.log(xp1)) / (-np.log(2)) + torch.sum(torch.log(xp2)) / (-np.log(2))
            bpp = rate / num_pixels

            l_rec = dloss + 0.01 * bpp

            opt1.zero_grad()
            opt2.zero_grad()

            l_rec.backward()

            # gradient clip
            torch.nn.utils.clip_grad_norm_(modnet.parameters(), 5)
            torch.nn.utils.clip_grad_norm_(context_mod.parameters(), 5)

            opt1.step()
            opt2.step()

            rec_loss_tmp += l_rec.item()
            mse_tmp += dloss.item()
            train_bpp_tmp += bpp.item()

            if step % 100 == 0:
                with open(os.path.join(args.out_dir, 'train_modnet.txt'), 'a') as fd:
                    time_used = time.time() - last_time
                    last_time = time.time()
                    # Running averages over the epoch so far.
                    mse = mse_tmp / (step + 1)
                    psnr = 10.0 * np.log10(1. / mse)
                    bpp_total = train_bpp_tmp / (step + 1)

                    fd.write('ep:%d step:%d time:%.1f lr:%.8f loss:%.6f MSE:%.6f bpp_total:%.4f psnr:%.2f\n'
                             % (epoch, step, time_used, cur_lr, rec_loss_tmp/(step+1), mse, bpp_total, psnr))
                    print('ep:%d step:%d time:%.1f lr:%.8f loss:%.6f MSE:%.6f bpp_total:%.4f psnr:%.2f\n'
                          % (epoch, step, time_used, cur_lr, rec_loss_tmp/(step+1), mse, bpp_total, psnr))
                    # with-block closes the file; no explicit close needed

            if (step + 1) % 300 == 0:
                torch.save(modnet.state_dict(),
                           os.path.join(args.out_dir, 'modnet.pkl'))
                torch.save(context_mod.state_dict(),
                           os.path.join(args.out_dir, 'context_mod.pkl'))
                print('saved')
-
if __name__ == '__main__':
    # Command-line entry point: parse hyper-parameters and launch training.
    parser = argparse.ArgumentParser()
    parser.add_argument("--M", type=int, default=256, help="the value of M")
    parser.add_argument("--N2", type=int, default=192, help="the value of N2")
    parser.add_argument("--bs", type=int, default=6)
    parser.add_argument("--lr", type=float, default=5e-5, help="initial learning rate.")
    parser.add_argument('--out_dir', type=str, default='./')

    cli_args = parser.parse_args()
    print(cli_args)
    train(cli_args)
|