|
- import argparse
- import os
-
- import numpy as np
- import torch
- import torch.nn as nn
- from PIL import Image
- from torch.autograd import Variable
- from torch.utils.data import DataLoader, Dataset
- from torchvision import transforms
- from torchvision.datasets import ImageFolder
- import Model.model as model
- from Model.context_model import Weighted_Gaussian
- import time
- from traindata import Traindataset
- import Util.torch_msssim as torch_msssim
-
def adjust_learning_rate(optimizer, epoch, init_lr):
    """Step-decay schedule: hold ``init_lr`` for the first 9 epochs, then
    halve it every 3 epochs (offset by 3), never dropping below 1e-6.

    Writes the computed rate into every param group of ``optimizer`` and
    returns it.
    """
    if epoch < 9:
        new_lr = init_lr
    else:
        # Floor the decayed rate at 1e-6 so late epochs keep learning.
        new_lr = max(init_lr * 0.5 ** ((epoch - 3) // 3), 1e-6)
    for group in optimizer.param_groups:
        group['lr'] = new_lr
    return new_lr
-
def train(args):
    """Train the NIC "modformer" image-compression model.

    Builds an ImageFolder loader over 256x256 random crops, resumes from
    ``./modnet/modformer_mse.pkl`` when that checkpoint exists, then runs
    16 epochs of rate-distortion training, logging every 100 steps to
    ``args.out_dir/modformer_mse.txt`` and checkpointing the unwrapped
    module state dict periodically.

    Args:
        args: parsed CLI namespace; uses M, N2, b_size, lr, gpu, out_dir.
    """
    # NOTE(review): hard-coded dataset root — consider promoting to a CLI flag.
    root = 'D:/CV_dataset/NIC_train'

    t0 = time.time()
    train_data = ImageFolder(
        root,
        transform=transforms.Compose([transforms.RandomCrop(256),
                                      transforms.ToTensor()]))
    print(len(train_data))
    train_loader = DataLoader(train_data, batch_size=args.b_size,
                              shuffle=True, num_workers=0)
    print(time.time() - t0)

    NIC_modformer = model.Image_coding_multi_hyper_modformer(
        3, args.M, args.N2, args.M, args.M // 2).cuda()

    if os.path.exists('./modnet/modformer_mse.pkl'):
        NIC_modformer.load_state_dict(torch.load('./modnet/modformer_mse.pkl'))
        print('resumed the pre-trained model')
    else:
        print("pre-trained model not found")

    # Fixed: the original comprehension shadowed the builtin `id`.
    gpu_ids = list(range(args.gpu))
    NIC_modformer = nn.DataParallel(NIC_modformer, device_ids=gpu_ids)

    opt = torch.optim.Adam(NIC_modformer.parameters(), lr=args.lr)
    # (An MS-SSIM module was constructed here in the original but never used;
    # removed as dead code — this is the MSE-loss training path.)

    for epoch in range(16):
        rec_loss_tmp = 0
        last_time = time.time()
        train_bpp_tmp = 0
        mse_tmp = 0
        cur_lr = adjust_learning_rate(opt, epoch, args.lr)

        for step, batch_x in enumerate(train_loader):
            # ImageFolder yields (image, label); the class label is unused.
            batch_x = batch_x[0]
            # bpp denominator: batch * H * W (channels excluded by convention).
            num_pixels = batch_x.size()[0] * \
                batch_x.size()[2] * batch_x.size()[3]
            # torch.autograd.Variable is deprecated; tensors track grads natively.
            batch_x = batch_x.cuda()

            dloss, bits = NIC_modformer(batch_x, if_training=1)
            train_bpp_total = bits / num_pixels

            # Rate-distortion objective: distortion + lambda * rate.
            l_rec = dloss + 0.01 * train_bpp_total

            opt.zero_grad()
            l_rec.backward()
            torch.nn.utils.clip_grad_norm_(NIC_modformer.parameters(), 5)
            opt.step()

            rec_loss_tmp += l_rec.item()
            train_bpp_tmp += train_bpp_total.item()
            mse_tmp += dloss.item()

            if step % 100 == 0:
                time_used = time.time() - last_time
                last_time = time.time()
                # assumes dloss is an MSE over [0, 1] images — TODO confirm
                psnr = -10.0 * np.log10(mse_tmp / (step + 1))
                bpp_total = train_bpp_tmp / (1 + step)
                msg = ('ep:%d step:%d time:%.1f lr:%.8f loss:%.6f bpp:%.4f psnr(dB):%.2f\n'
                       % (epoch, step, time_used, cur_lr,
                          rec_loss_tmp / (step + 1), bpp_total, psnr))

                # The `with` block closes the file; the original's explicit
                # fd.close() inside it was redundant and has been removed.
                with open(os.path.join(args.out_dir, 'modformer_mse.txt'), 'a') as fd:
                    fd.write(msg)
                print(msg)

            if (step + 2) % 300 == 0:
                # Save the wrapped .module so the checkpoint reloads without
                # a DataParallel "module." key prefix.
                torch.save(NIC_modformer.module.state_dict(),
                           os.path.join(args.out_dir, 'modformer_mse.pkl'),
                           _use_new_zipfile_serialization=False)
                print('model written.')
-
if __name__ == '__main__':
    # Collect training hyper-parameters from the command line, echo the
    # resolved configuration, then hand off to the training loop.
    cli = argparse.ArgumentParser()
    cli.add_argument("--M", type=int, default=256, help="the value of M")
    cli.add_argument("--b_size", type=int, default=6)
    cli.add_argument("--N2", type=int, default=192, help="the value of N2")
    cli.add_argument("--weight_dir", type=str,
                     default='/userdata/Weights.zip/Weights')
    cli.add_argument("--lr", type=float, default=5e-6,
                     help="initial learning rate.")
    cli.add_argument('--out_dir', type=str, default='./modnet')
    cli.add_argument("--gpu", type=int, default=1)

    args = cli.parse_args()
    print(args)

    train(args)
|