|
- import argparse
- import os
- import matplotlib.pyplot as plt
- import numpy as np
- import torch
- import torch.nn as nn
- from PIL import Image
- import torch
- import torch.nn.functional as F
- from torch.autograd import Variable
-
- import Model.model as model
- import Util.torch_msssim as torch_msssim
- import time
- import glob
- import struct
- from Model.context_model import Weighted_Gaussian
- from Util.metrics import evaluate
- import AE
-
-
def adjust_learning_rate(optimizer, epoch, init_lr):
    """Set the learning rate on every param group and return the value used.

    Schedule: hold ``init_lr`` for the first 10 epochs, then halve every
    3 epochs (offset so the first halving lands at epoch 10, i.e.
    ``init_lr * 0.5 ** ((epoch - 7) // 3)``), floored at 1e-6.

    Args:
        optimizer: torch-style optimizer exposing ``param_groups``.
        epoch: current epoch index (0-based).
        init_lr: base learning rate before decay.

    Returns:
        The learning rate actually applied.
    """
    if epoch < 10:
        lr = init_lr
    else:
        lr = init_lr * (0.5 ** ((epoch - 7) // 3))
    # Floor applies to both branches, matching the original behavior
    # (an init_lr below 1e-6 is also clamped during the warm-hold phase).
    lr = max(lr, 1e-6)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
-
-
def test(args):
    """Encode/decode each test image at every lambda in ``test_lambda_list``.

    For each (lambda, image) pair the image is tiled into blocks of at most
    1024 x 2048 pixels; each block is compressed with the NIC_SQL model
    (main latents + hyper latents, GMM entropy model, arithmetic coder via
    ``AE``), concatenated into ``args.out_dir/enc.bin``, reconstructed to
    ``args.out_dir/dec.png``, and bpp / PSNR / MS-SSIM results are appended
    to a per-lambda text file in ``args.out_dir``.

    NOTE(review): requires a CUDA device and the './SQLmse.pkl' checkpoint;
    if the checkpoint is missing the randomly initialized model is used.
    """

    # --- Collect test images: a directory input only picks up the
    # 'ClassD_Kodak' subfolder; a file input is used directly. ---
    test_images = []

    if os.path.isdir(args.input):
        dirs = os.listdir(args.input)
        for entry in dirs:  # renamed from `dir` (shadowed the builtin)
            if entry == 'ClassD_Kodak':
                path = os.path.join(args.input, entry)
                if os.path.isdir(path):
                    test_images += glob.glob(path + '/*.png')
                if os.path.isfile(path):
                    test_images.append(path)
    else:
        test_images.append(args.input)

    test_lambda_list = [4, 8, 16, 32, 64, 128]
    im_dirs = test_images

    # Build the model and resume pretrained weights when available.
    whole = model.NIC_SQL(3, args.M, args.N2, args.M, args.M // 2).cuda()

    model_existed = os.path.exists('./SQLmse.pkl')

    if model_existed:
        whole.load_state_dict(torch.load('SQLmse.pkl'))
        print('resumed the trained model')
    else:
        print('model not exists')

    for k in range(len(test_lambda_list)):
        for im_dir in im_dirs:

            dec_time = 0
            enc_dec_time_start = time.time()

            bin_dir = os.path.join(args.out_dir, 'enc.bin')
            rec_dir = os.path.join(args.out_dir, 'dec.png')
            file_object = open(bin_dir, 'wb')
            img = Image.open(im_dir)
            ori_img = np.array(img)
            img = ori_img
            H, W, _ = img.shape
            num_pixels = H * W
            C = 3
            H_offset = 0
            W_offset = 0
            out_img = np.zeros([H, W, C])

            # Tile the image into blocks of at most 1024 (H) x 2048 (W).
            Block_Num_in_Width = int(np.ceil(W / 2048))
            Block_Num_in_Height = int(np.ceil(H / 1024))
            img_block_list = []
            for i in range(Block_Num_in_Height):
                for j in range(Block_Num_in_Width):
                    img_block_list.append(img[i * 1024:np.minimum((i + 1) * 1024, H),
                                              j * 2048:np.minimum((j + 1) * 2048, W), ...])

            Block_Idx = 0
            for img in img_block_list:
                block_H = img.shape[0]
                block_W = img.shape[1]

                # Zero-pad each block to a multiple of 64 so the encoder's
                # downsampling stages divide evenly.
                tile = 64.
                block_H_PAD = int(tile * np.ceil(block_H / tile))
                block_W_PAD = int(tile * np.ceil(block_W / tile))
                im = np.zeros([block_H_PAD, block_W_PAD, 3], dtype='float32')
                im[:block_H, :block_W, :] = img[:, :, :3] / 255.0
                im = torch.FloatTensor(im)
                im = im.permute(2, 0, 1).contiguous()
                im = im.view(1, C, block_H_PAD, block_W_PAD).cuda()

                Block_Idx += 1

                with torch.no_grad():

                    # Analysis transform -> main and hyper latents.
                    y_main, y_hyper = whole.encoder(im)

                    # Scalable quantization layers conditioned on lambda.
                    y_main_q = whole.SQL1(y_main, test_lambda_list[k])
                    y_main_q = whole.SQL2(y_main_q, test_lambda_list[k])

                    y_main_q = torch.round(y_main_q)
                    # `np.int` was removed in NumPy 1.24; the builtin `int`
                    # is the exact alias it stood for.
                    y_main_q = torch.Tensor(y_main_q.cpu().numpy().astype(int)).cuda()

                    Datas = torch.reshape(y_main_q, [-1]).cpu().numpy().astype(int).tolist()

                    Max_Main = max(Datas)
                    Min_Main = min(Datas)

                    dec_time_start = time.time()
                    rec = whole.decoder(y_main_q)

                    # NN encode+decode: ~0.65s

                    # Paste the de-padded reconstruction into the output canvas.
                    output_ = torch.clamp(rec, min=0., max=1.0)
                    out = output_.data[0].cpu().numpy()
                    out = out.transpose(1, 2, 0)
                    out_img[H_offset: H_offset + block_H, W_offset: W_offset + block_W, :] = out[:block_H, :block_W, :]
                    dec_time += (time.time() - dec_time_start)
                    W_offset += block_W
                    if W_offset >= W:
                        W_offset = 0
                        H_offset += block_H

                    # reconstruction: ~0.23s

                    # Hyper-prior: quantize, then predict GMM parameters
                    # through the context model.
                    y_hyper_q, xp2 = whole.factorized_entropy_func(y_hyper, 2)
                    y_hyper_q = torch.Tensor(y_hyper_q.cpu().numpy().astype(int)).cuda()
                    hyper_dec = whole.p(whole.hyper_dec(y_hyper_q))

                    xp3, params_prob = whole.context(y_main_q, hyper_dec)
                    bpp_hyper = (torch.sum(torch.log(xp2)) / (-np.log(2) * num_pixels)).item()
                    bpp_main = (torch.sum(torch.log(xp3)) / (-np.log(2) * num_pixels)).item()
                    print('bpp_hyper_info:', bpp_hyper, 'bpp_main_info:', bpp_main, 'bpp_total_info:',
                          bpp_hyper + bpp_main)

                    # GMM evaluation: ~0.02s

                    # Unpack the 3-component Gaussian mixture parameters.
                    params_prob = params_prob.cpu()
                    prob0, mean0, scale0, prob1, mean1, scale1, prob2, mean2, scale2 = [
                        torch.chunk(params_prob, 9, dim=1)[i].squeeze(1) for i in range(9)]
                    del params_prob

                    # unpacking params_prob: ~0.006s

                    probs = torch.stack([prob0, prob1, prob2], dim=-1)
                    del prob0, prob1, prob2

                    # stacking probs: ~0.09s

                    probs = F.softmax(probs, dim=-1)

                    # Clamp mixture scales away from zero for numerical stability.
                    scale0 = torch.abs(scale0)
                    scale1 = torch.abs(scale1)
                    scale2 = torch.abs(scale2)
                    scale0[scale0 < 1e-6] = 1e-6
                    scale1[scale1 < 1e-6] = 1e-6
                    scale2[scale2 < 1e-6] = 1e-6

                    # torch ops: ~0.10s

                    _, c, h, w = y_main_q.shape
                    print("Main Channel:", c)

                    if args.is_ac == 0:
                        # Vectorized CDF table over the full symbol range.
                        t0 = time.time()
                        lower = torch.zeros(1, c, h, w, Max_Main - Min_Main + 2)

                        sample = np.arange(Min_Main, Max_Main + 1 + 1)
                        sample = torch.FloatTensor(np.tile(sample, [1, c, h, w, 1])).cpu()

                        m0 = torch.distributions.normal.Normal(mean0, scale0)
                        m1 = torch.distributions.normal.Normal(mean1, scale1)
                        m2 = torch.distributions.normal.Normal(mean2, scale2)

                        for i in range(sample.shape[4]):
                            lower0 = m0.cdf(sample[:, :, :, :, i] - 0.5)
                            lower1 = m1.cdf(sample[:, :, :, :, i] - 0.5)
                            lower2 = m2.cdf(sample[:, :, :, :, i] - 0.5)
                            lower[:, :, :, :, i] = probs[:, :, :, :, 0] * lower0 + \
                                probs[:, :, :, :, 1] * lower1 + probs[:, :, :, :, 2] * lower2
                        print(time.time() - t0)

                    else:
                        # Per-symbol CDF evaluation (mirrors the sequential
                        # arithmetic-coding order; much slower).
                        t0 = time.time()
                        lower = torch.zeros(1, c, h, w, Max_Main - Min_Main + 2)
                        for ci in range(c):
                            for hi in range(h):
                                for wi in range(w):
                                    current_pixel = torch.reshape(y_main_q[0, ci, hi, wi], [-1]).cpu().numpy()
                                    value = int(current_pixel)
                                    sample = np.arange(value, value + 2)
                                    sample = torch.FloatTensor(np.tile(sample, [1, 1])).cpu()
                                    m0 = torch.distributions.normal.Normal(mean0[:, ci, hi, wi], scale0[:, ci, hi, wi])
                                    m1 = torch.distributions.normal.Normal(mean1[:, ci, hi, wi], scale1[:, ci, hi, wi])
                                    m2 = torch.distributions.normal.Normal(mean2[:, ci, hi, wi], scale2[:, ci, hi, wi])

                                    for i in range(2):
                                        lower0 = m0.cdf(sample[:, i] - 0.5)
                                        lower1 = m1.cdf(sample[:, i] - 0.5)
                                        lower2 = m2.cdf(sample[:, i] - 0.5)
                                        lower[:, ci, hi, wi, i + value - Min_Main] = probs[:, ci, hi, wi, 0] * lower0 + \
                                            probs[:, ci, hi, wi, 1] * lower1 + probs[:, ci, hi, wi, 2] * lower2
                        print(time.time() - t0)

                    del probs, lower0, lower1, lower2
                    # codebook construction: ~48.20s

                    sample = np.arange(Min_Main, Max_Main + 1 + 1)
                    sample = torch.FloatTensor(np.tile(sample, [1, c, h, w, 1])).cpu()

                    # Convert CDFs to integer frequency tables for the coder;
                    # adding the sample index keeps each CDF strictly increasing.
                    precise = 16
                    cdf_m = lower.data.cpu().numpy() * ((1 << precise) - (Max_Main - Min_Main + 1))  # [1, c, h, w, Max-Min+1]
                    cdf_m = cdf_m.astype(np.int32) + sample.numpy().astype(np.int32) - Min_Main
                    cdf_main = np.reshape(cdf_m, [len(Datas), -1])

                    # precision rounding: ~0.79s
                    # total entropy-coding time: ~58.78s — the main optimization target

                    Cdf_lower = list(map(lambda x, y: int(y[x - Min_Main]), Datas, cdf_main))
                    Cdf_upper = list(map(lambda x, y: int(
                        y[x - Min_Main]), Datas, cdf_main[:, 1:]))
                    AE.encode_cdf(Cdf_lower, Cdf_upper, args.out_dir + "/main.bin")
                    FileSizeMain = os.path.getsize(args.out_dir + "/main.bin")
                    print("main.bin: %d bytes" % (FileSizeMain))

                    # AE.encode_cdf time: ~59.21s

                    # Entropy-code the hyper latents with the factorized prior.
                    Min_V_HYPER = torch.min(y_hyper_q).cpu().numpy().astype(int).tolist()
                    Max_V_HYPER = torch.max(y_hyper_q).cpu().numpy().astype(int).tolist()
                    _, c, h, w = y_hyper_q.shape
                    Datas_hyper = torch.reshape(
                        y_hyper_q, [c, -1]).cpu().numpy().astype(int).tolist()

                    sample = np.arange(Min_V_HYPER, Max_V_HYPER + 1 + 1)
                    sample = np.tile(sample, [c, 1, 1])
                    sample = torch.FloatTensor(sample).cuda()
                    lower = torch.sigmoid(whole.factorized_entropy_func._logits_cumulative(
                        sample - 0.5, stop_gradient=False))
                    cdf_h = lower.data.cpu().numpy() * ((1 << precise) - (Max_V_HYPER -
                                                                          Min_V_HYPER + 1))

                    cdf_h = cdf_h.astype(int) + sample.detach().cpu().numpy().astype(int) - Min_V_HYPER
                    cdf_hyper = np.reshape(np.tile(cdf_h, [len(Datas_hyper[0]), 1, 1, 1]), [
                        len(Datas_hyper[0]), c, -1])

                    Cdf_0, Cdf_1 = [], []
                    for i in range(c):
                        Cdf_0.extend(list(map(lambda x, y: int(
                            y[x - Min_V_HYPER]), Datas_hyper[i], cdf_hyper[:, i, :])))
                        Cdf_1.extend(list(map(lambda x, y: int(
                            y[x - Min_V_HYPER]), Datas_hyper[i], cdf_hyper[:, i, 1:])))
                    AE.encode_cdf(Cdf_0, Cdf_1, args.out_dir + "/hyper.bin")
                    FileSizeHyper = os.path.getsize(args.out_dir + "/hyper.bin")
                    print("hyper.bin: %d bytes" % (FileSizeHyper))

                    # hyper entropy coding: ~0.13s

                    # Per-block header: block size, symbol ranges, payload lengths.
                    Head_block = struct.pack('2H4h2I', block_H, block_W, Min_Main, Max_Main, Min_V_HYPER, Max_V_HYPER,
                                             FileSizeMain, FileSizeHyper)
                    file_object.write(Head_block)

                    # Append both payloads after the header (`with` already
                    # closes the files; the original's extra f.close() was
                    # redundant).
                    with open(args.out_dir + "/main.bin", 'rb') as f:
                        bits = f.read()
                    file_object.write(bits)
                    with open(args.out_dir + "/hyper.bin", 'rb') as f:
                        bits = f.read()
                    file_object.write(bits)
                    del im

            file_object.close()
            with open(bin_dir, "rb") as f:
                bpp = len(f.read()) * 8. / num_pixels
            print('bpp_total_true:', bpp)

            # Save the reconstruction and compute quality metrics.
            out_img = np.round(out_img * 255.0)
            out_img = out_img.astype('uint8')
            img = Image.fromarray(out_img[:H, :W, :])
            img.save(rec_dir)
            [rgb_psnr, rgb_msssim, yuv_psnr, y_msssim] = evaluate(ori_img, out_img)

            class_name = im_dir.split('/')[-2]
            image_name = im_dir.split('/')[-1].replace('.png', '')
            enc_dec_time = time.time() - enc_dec_time_start
            enc_time = enc_dec_time - dec_time

            # Append the result line; '_ac' marks the per-symbol coding path.
            # (The two original branches differed only in the filename.)
            if args.is_ac == 0:
                result_txt = os.path.join(args.out_dir, 'mse' + str(test_lambda_list[k]) + '.txt')
            else:
                result_txt = os.path.join(args.out_dir, 'mse_ac' + str(test_lambda_list[k]) + '.txt')
            with open(result_txt, "a") as f:
                f.write(class_name + '/' + image_name + '\t' + str(bpp) + '\t' + str(rgb_psnr) + '\t' + str(
                    rgb_msssim) + '\t' + str(-10 * np.log10(1 - rgb_msssim)) +
                    '\t' + str(yuv_psnr) + '\t' + str(y_msssim) + '\t' + str(-10 * np.log10(1 - y_msssim)) + '\t' + str(enc_time) + '\n')
            del out_img
-
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--M", type=int, default=256, help="the value of M")
    parser.add_argument("--b_size", type=int, default=1)
    parser.add_argument("--N2", type=int, default=192, help="the value of N2")
    parser.add_argument('--out_dir', type=str, default='output1/')
    parser.add_argument("--test_set", type=int, default=0)
    parser.add_argument("-i", "--input", type=str, help="Input Image",default='../../../userdata/nic_v02_test.zip/test')
    parser.add_argument("--gpu", type=int, default=1)
    parser.add_argument("--is_ac", type=int, default=0)

    args = parser.parse_args()
    print(args)

    # Recreate the output directory from scratch. shutil/os calls replace the
    # original shell strings (`rm -rf`/`mkdir` via os.system), avoiding shell
    # interpolation of the user-supplied path.
    import shutil
    if os.path.exists(args.out_dir):
        shutil.rmtree(args.out_dir)
    os.makedirs(args.out_dir)

    # Drop Windows thumbnail cruft from the Kodak test folder. The original
    # ran `os.system('cd ...')` followed by `os.system('rm Thumbs.db')`, but
    # each os.system spawns its own subshell, so the `cd` never took effect
    # and the delete hit the current directory instead. Remove the file at
    # its intended path, tolerating its absence.
    try:
        os.remove('../../../userdata/test/ClassD_Kodak/Thumbs.db')
    except OSError:
        pass

    test(args)
|