|
- import argparse
- import os
-
- import numpy as np
- import torch
- import torch.nn as nn
- from PIL import Image
- from torch.autograd import Variable
- from torch.utils.data import DataLoader, Dataset
- from torchvision import transforms
- from torchvision.datasets import ImageFolder
- import Model.model as model
- from Model.context_model import Weighted_Gaussian
- import time
- from traindata import Traindataset
- import Util.torch_msssim as torch_msssim
-
-
def train(args):
    """Evaluate a pre-trained Modformer image codec on an image folder.

    Despite its name, this function performs no training (no optimizer, no
    backward pass). It prints the dataset size, then runs every image through
    the encoder, entropy models and decoder under ``torch.no_grad()``. Finally
    it prints:
    the average mask value, and the average bits-per-pixel and PSNR.

    Args:
        args: namespace providing ``M`` and ``N2`` (passed to the model
            constructor) and ``lmd`` (the value fed to ``modformer_mask``).
            Optional attributes ``data_root`` and ``checkpoint`` override the
            image folder and the weights file; when absent, the original
            hard-coded paths are used.
    """
    root = getattr(args, 'data_root', 'D:/CV_dataset/Kodak')
    checkpoint = getattr(args, 'checkpoint', './modnet/modformer_mse.pkl')

    test_data = ImageFolder(root, transform=transforms.Compose([transforms.ToTensor()]))
    print(len(test_data))
    test_loader = DataLoader(test_data, batch_size=1, shuffle=False, num_workers=0)

    NIC_modformer = model.Image_coding_multi_hyper_modformer(
        3, args.M, args.N2, args.M, args.M // 2).cuda()

    if os.path.exists(checkpoint):
        NIC_modformer.load_state_dict(torch.load(checkpoint))
        print('resumed the pre-trained model')
    else:
        print("pre-trained model not found")

    # Inference only: make dropout / normalization layers (if any) behave
    # deterministically instead of in training mode.
    NIC_modformer.eval()

    mse_func = nn.MSELoss()
    # The mask conditioning value is identical for every image; build it once.
    lmd = (torch.ones((1, 1, 1)) * args.lmd).cuda()
    log2 = np.log(2)

    bpp_sum = 0.0
    psnr_sum = 0.0
    mask_sum = 0.0
    num_images = 0

    with torch.no_grad():
        # ImageFolder yields (image, class_index); the label is not needed.
        for batch_x, _ in test_loader:
            batch_x = batch_x.cuda()
            num_pixels = batch_x.size(0) * batch_x.size(2) * batch_x.size(3)

            # Second-level hyper latent (x3) and first-level hyper latent (x2).
            x1, x2, x3 = NIC_modformer.encoder(batch_x)
            xq3, xp3 = NIC_modformer.factorized_entropy_func(x3, 2)
            hyper_2_dec = NIC_modformer.p_2(NIC_modformer.hyper_2_dec(xq3))
            xq2 = torch.round(x2)
            xp2 = NIC_modformer.gaussin_entropy_func_for_hyper(xq2, hyper_2_dec)

            # Main latent: masked, quantized, decoded and context-modelled.
            x5 = NIC_modformer.hyper_1_dec(xq2)
            mask = NIC_modformer.modformer_mask(x5, lmd)
            mask_sum += torch.mean(mask).item()

            x1 = x1 * mask
            hyper_dec = NIC_modformer.p(x5)
            xq1 = torch.round(x1)
            fake = NIC_modformer.decoder(xq1)
            xp1, _ = NIC_modformer.context(xq1, hyper_dec)

            # Rate = -log2(likelihood) summed over all latents; the main-latent
            # term is scaled element-wise by the mask.
            bits = -(torch.sum(torch.log(xp1) * mask)
                     + torch.sum(torch.log(xp2))
                     + torch.sum(torch.log(xp3))) / log2
            bpp_sum += bits.item() / num_pixels

            # ToTensor scales pixels to [0, 1], so the PSNR peak value is 1.
            mse = mse_func(fake, batch_x).item()
            psnr_sum += -10.0 * np.log10(mse)
            num_images += 1

    if num_images == 0:
        print('no images found in %s' % root)
        return

    # Average over the images actually evaluated (previously hard-coded to 24).
    bpp_total = bpp_sum / num_images
    psnr_total = psnr_sum / num_images
    print(mask_sum / num_images)
    print('total: bpp:%.4f psnr(dB):%.2f\n'
          % (bpp_total, psnr_total))
-
-
-
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Evaluate a pre-trained Modformer codec (average bpp / PSNR).')
    parser.add_argument("--M", type=int, default=256, help="the value of M")
    parser.add_argument("--lmd", type=float, default=1,
                        help="value passed to the model's modformer_mask")
    parser.add_argument("--N2", type=int, default=192, help="the value of N2")
    # NOTE(review): --out_dir and --gpu are not read by train(); the default
    # CUDA device is used. Kept for command-line compatibility.
    parser.add_argument('--out_dir', type=str, default='./modnet')
    parser.add_argument("--gpu", type=int, default=1)
    parser.add_argument('--data_root', type=str, default='D:/CV_dataset/Kodak',
                        help="ImageFolder root containing the evaluation images")
    parser.add_argument('--checkpoint', type=str, default='./modnet/modformer_mse.pkl',
                        help="path to the pre-trained state_dict")

    args = parser.parse_args()
    print(args)

    train(args)
|