|
- '''
- Contributors: Aolin Feng, Jianpin Lin, Dezhao Wang, Ding Ding, Wei Wang, Yueyu Hu, Haojie Liu, Tong Chen, Chuanmin Jia, Yihang Chen
- '''
-
- import argparse
- import math
- import os
- import struct
- import sys
- import time
- import glob
-
- import numpy as np
- import torch
- import torch.nn.functional as F
- import matplotlib as mpl
- import matplotlib.pyplot as plt
- from mpl_toolkits.mplot3d import Axes3D
- import matplotlib.image as mpimg
-
- import numpy as np
- from numpy import *
- from pylab import *
- import timeit
-
- from PIL import Image
- #import cv2
- import Model.model as model
- from Model.context_model import Weighted_Gaussian, Weighted_Gaussian_res
- from Util.metrics import evaluate
- from Util import torch_msssim
- from Util.block_metric import check_RD_GEO # blcok based metric
- from Util.config import dict
- from Util.generate_substitute import SubstituteGenerator
-
# Cap libtorch's LRU cache for cuDNN plans / allocations (string-valued env var).
os.environ['LRU_CACHE_CAPACITY'] = '9'
#os.environ["CUDA_VISIBLE_DEVICES"] = "8"
############################## Load Configuration Parameters ########################
# All switches come from the project-level config mapping `Util.config.dict`
# (note: the config module shadows the builtin name `dict`).
GPU = dict['GPU']                          # run models on CUDA when truthy
SAVE_REC = dict['SAVE_REC']                # save reconstructed images
USE_GEO = dict['USE_GEO']                  # try geometric flips/rotations per block
USE_VR_MODEL = dict['USE_VR_MODEL']        # variable-rate (lambda-conditioned) models
USE_PREPROCESSING = dict['USE_PREPROCESSING']  # substitute-image preprocessing
if USE_PREPROCESSING:
    num_steps = dict['num_steps']          # perturbation steps for SubstituteGenerator
# CTU (coding tree unit) size: images are split into square blocks of this size.
block_width = dict['CTU_size']
block_height = dict['CTU_size']

USE_MULTI_HYPER = dict['USE_MULTI_HYPER']  # two-level hyperprior variant
USE_PREDICTOR = dict['USE_PREDICTOR']      # residual predictor variant (needs multi-hyper)

# Config sanity checks: multi-hyper and VR models are mutually exclusive,
# and the predictor only exists for the multi-hyper models.
# NOTE(review): `assert` is stripped under `python -O`; consider explicit raises.
assert (USE_MULTI_HYPER and USE_VR_MODEL) is False
assert (USE_PREDICTOR and bool(1 - USE_MULTI_HYPER)) is False

# index - [0-15]
if USE_VR_MODEL:
    # Four variable-rate checkpoints; max_lambdas[i] is the largest lambda
    # each VR model was trained for (used for clamping/normalization below).
    models = ["mse_VR_low", "mse_VR_high", "msssim_VR_low", "msssim_VR_high"]
    max_lambdas = [64, 256, 1.28, 6.40]
else:
    # Sixteen fixed-rate checkpoints: indices 0-7 are MSE-optimized,
    # 8-15 are MS-SSIM-optimized (suffix encodes the training lambda).
    models = ["mse200", "mse400", "mse800", "mse1600", "mse3200", "mse6400", "mse12800", "mse25600",
              "msssim4", "msssim8", "msssim16", "msssim32", "msssim64", "msssim128", "msssim320", "msssim640"]


# model_dir = "/model/ljp105/NIC_v01_baseline/"
# VR_model_dir = "/model/ljp105/NIC_v02_VR_models"
# @torch.no_grad()
-
def histeq(im, nbr_bins=256):
    """Histogram-equalize a grayscale image.

    Computes the image histogram, forms its cumulative distribution
    function (CDF) rescaled to [0, 255], and maps every pixel through
    the CDF by linear interpolation.

    Args:
        im: grayscale image as a NumPy array (any shape).
        nbr_bins: number of histogram bins (default 256).

    Returns:
        (im2, cdf): the equalized image with the same shape as ``im``,
        and the rescaled CDF used for the mapping.
    """
    # `density=True` replaces the long-deprecated `normed=True` keyword,
    # which was removed in NumPy 1.24 and would raise a TypeError here.
    # The result is unchanged: cdf is renormalized by cdf[-1] below, so
    # any uniform scaling of the histogram cancels out.
    imhist, bins = np.histogram(im.flatten(), nbr_bins, density=True)
    cdf = imhist.cumsum()
    cdf = 255.0 * cdf / cdf[-1]  # rescale CDF to the [0, 255] range
    # Map pixel values through the CDF via linear interpolation.
    im2 = np.interp(im.flatten(), bins[:-1], cdf)
    return im2.reshape(im.shape), cdf
-
-
def b0(tensor, threshold):  # (9,-1)
    """Return 1 if any inspected entry of the 9-element parameter vector
    exceeds ``threshold``, else 0.

    Only indices 1, 2, 4, 5, 7 and 8 are checked; indices 0, 3 and 6
    are ignored.
    """
    checked = (1, 2, 4, 5, 7, 8)
    return int(any(tensor[idx] > threshold for idx in checked))
-
-
def inference_rd(im_dirs, out_dir, model_dir, model_index, lambda_rd_ori):
    """Encode each image in ``im_dirs`` with the selected NIC model and
    collect per-image diagnostics about the context-model pruning pass.

    The image is split into CTU blocks, each block is padded to a
    multiple of 64, run through the encoder/hyper networks, and the
    main latent is iteratively "pruned": positions whose 9-way mixture
    weights are all below a threshold are assumed zero; whenever a
    nonzero latent is found at such a position the threshold is lowered
    and the context model is re-run.

    Args:
        im_dirs: iterable of input image file paths.
        out_dir: output directory (created if missing); a temporary
            ``enc.bin`` with a packed header is written per image.
        model_dir: directory holding ``<name>.pkl`` / ``<name>p.pkl``
            checkpoints.
        model_index: index into the module-level ``models`` list.
        lambda_rd_ori: requested lambda for variable-rate models
            (ignored otherwise except for log naming).

    Returns:
        (time_red, psnr_loss, loops): accumulated fraction of skipped
        latents, accumulated PSNR difference, and average number of
        pruning iterations — each summed over images (caller divides
        by ``len(im_dirs)``).

    NOTE(review): this function has several latent defects left as-is:
    ``psnr0`` is never defined (NameError when the MSE branch runs),
    ``mse`` is computed after ``rec1`` became a PIL Image, ``np.int``
    was removed in NumPy 1.24, and ``y_main_predict`` is referenced
    unconditionally even though it is only assigned when
    ``USE_PREDICTOR`` is true.
    """
    if os.path.exists(out_dir) is False:
        os.makedirs(out_dir)

    if USE_VR_MODEL:
        # Clamp the requested lambda to 1.2x the model's trained maximum,
        # then quantize the normalized lambda to 16 bits so it can be
        # stored in the bitstream header and reproduced at decode time.
        lambda_rd_max = max_lambdas[model_index]
        if lambda_rd_ori > 1.2 * lambda_rd_max:
            lambda_rd_ori = 1.2 * lambda_rd_max
        lambda_rd_nom = lambda_rd_ori / lambda_rd_max
        lambda_rd_nom_scaled = int(lambda_rd_nom / 1.2 * pow(2, 16))
        lambda_rd_nom_used = lambda_rd_nom_scaled / pow(2, 16) * 1.2
        lambda_rd_numpy = np.zeros((1, 1), np.float32)
        lambda_rd_numpy[0, 0] = lambda_rd_nom_used
        lambda_rd = torch.Tensor(lambda_rd_numpy)
        # M = main-latent channels, N2 = hyper-latent channels; the
        # "high" models (indices 1, 3) use the larger configuration.
        M, N2 = 192, 128
        if (model_index == 1) or (model_index == 3):
            M, N2 = 256, 192
        image_comp = model.Image_coding(3, M, N2, M, M // 2)
        context = Weighted_Gaussian(M)

        log_name = os.path.join(out_dir,
                                models[model_index] + '_lmbda' + str(int(lambda_rd_ori * 100)) + '_test_' + str(
                                    block_width) + '_RD.txt')
    else:
        # Fixed-rate models: the four highest-rate checkpoints use the
        # larger channel configuration.
        M, N2 = 192, 128
        if (model_index == 6) or (model_index == 7) or (model_index == 14) or (model_index == 15):
            M, N2 = 256, 192
        if USE_MULTI_HYPER:
            if USE_PREDICTOR:
                image_comp = model.Image_coding_multi_hyper_res(3, M, N2, M, M // 2)
            else:
                image_comp = model.Image_coding_multi_hyper(3, M, N2, M, M // 2)
        else:
            image_comp = model.Image_coding(3, M, N2, M, M // 2)
        if USE_PREDICTOR:
            context = Weighted_Gaussian_res(M)
        else:
            context = Weighted_Gaussian(M)
        lambda_rd = None  # fixed-rate models take no lambda conditioning
        log_name = os.path.join(out_dir, models[model_index] + '_test_' + str(block_width) + '_RD.txt')
        # NOTE(review): log_name is never used after this point.

    if USE_PREPROCESSING:
        # Per-lambda gradient step sizes for the substitute-image search;
        # positions mirror the fixed-rate `models` list (8 MSE + 8 MS-SSIM).
        lmbda_list = [200, 400, 800, 1600, 3200, 6400, 12800, 25600, 4, 8, 16, 32, 64, 128, 320, 640]
        stepsize_list = [150, 75, 30, 10, 10, 5, 3, 1, 100, 10, 7, 5, 1, 1, 1, 0.3]

        if USE_VR_MODEL:
            # Look the step size up by the (scaled) lambda value itself.
            # NOTE(review): raises ValueError if lambda*100 is not exactly
            # one of the entries in lmbda_list.
            lmbda = lambda_rd_ori * 100
            step_size = stepsize_list[lmbda_list.index(lmbda)]
            reconstruction_metric = 'mse' if model_index <= 1 else 'msssim'
        else:
            step_size = stepsize_list[model_index]
            lmbda = lmbda_list[model_index]
            reconstruction_metric = 'mse' if model_index <= 7 else 'msssim'

        substitute_generator = SubstituteGenerator(model=image_comp, context_model=context, llambda=lmbda,
                                                   num_steps=num_steps, step_size=step_size,
                                                   reconstruct_metric=reconstruction_metric,
                                                   )

    ######################### Load Model #########################
    # Checkpoints: `<name>.pkl` for the autoencoder, `<name>p.pkl` for the
    # context model; loaded onto CPU first, then moved to GPU if enabled.
    image_comp.load_state_dict(torch.load(
        os.path.join(model_dir, models[model_index] + r'.pkl'), map_location='cpu'))
    context.load_state_dict(torch.load(
        os.path.join(model_dir, models[model_index] + r'p.pkl'), map_location='cpu'))
    if GPU:
        image_comp.cuda()
        context.cuda()
    #################### Compress Each Image ###################
    # Accumulators across images (averaged by the caller).
    time_red = 0   # fraction of latents skipped by the threshold test
    psnr_loss = 0  # PSNR delta vs. a baseline (see psnr0 note above)
    loops = 0      # pruning iterations per image
    cnt = 0
    for im_dir in im_dirs:
        print(cnt)
        cnt += 1
        dec_time = 0
        enc_dec_time_start = time.time()
        bin_dir = os.path.join(out_dir, 'enc.bin')

        # NOTE(review): file_object is never closed; a `with` block would
        # be safer (left unchanged here).
        file_object = open(bin_dir, 'wb')
        ######################### Read Image #########################
        img = Image.open(im_dir)
        source_img = np.array(img)
        img = source_img / 255.0  # normalize to [0, 1]
        H, W, _ = img.shape
        print(img.shape)
        num_pixels = H * W
        C = 3
        # Header layout '2HB4?H': H, W (uint16), model index (uint8),
        # four feature flags (bool), CTU width (uint16).
        Head = struct.pack('2HB4?H', H, W, model_index, USE_GEO, USE_VR_MODEL, USE_MULTI_HYPER, USE_PREDICTOR,
                           block_width)
        file_object.write(Head)
        if USE_VR_MODEL:
            # Store the 16-bit quantized lambda so the decoder can rebuild it.
            Head_lmbda = struct.pack('H', lambda_rd_nom_scaled)
            file_object.write(Head_lmbda)

        H_offset = 0
        W_offset = 0
        ######################### Spliting Image #########################
        # Split the image into CTU blocks in raster order; edge blocks
        # may be smaller than block_width x block_height.
        Block_Num_in_Width = int(np.ceil(W / block_width))
        Block_Num_in_Height = int(np.ceil(H / block_height))
        img_block_list = []
        for i in range(Block_Num_in_Height):
            for j in range(Block_Num_in_Width):
                img_block_list.append(img[i * block_height:np.minimum((i + 1) * block_height, H),
                                      j * block_width:np.minimum((j + 1) * block_width, W), ...])

        ######################### Padding Image #########################
        Block_Idx = 0
        for img in img_block_list:  # Traverse CTUs
            block_H = img.shape[0]
            block_W = img.shape[1]
            # Zero-pad each block to a multiple of 64 (network stride).
            tile = 64.
            block_H_PAD = int(tile * np.ceil(block_H / tile))
            block_W_PAD = int(tile * np.ceil(block_W / tile))
            im = np.zeros([block_H_PAD, block_W_PAD, 3], dtype='float32')
            im[:block_H, :block_W, :] = img[:, :, :3]
            im = torch.FloatTensor(im)
            im = im.permute(2, 0, 1).contiguous()  # HWC -> CHW
            im = im.view(1, C, block_H_PAD, block_W_PAD)
            if GPU:
                im = im.cuda()
                if USE_VR_MODEL:
                    lambda_rd = lambda_rd.cuda()
            print('====> Encoding Image:', im_dir, "%dx%d" % (block_H, block_W), 'to', out_dir,
                  " Block Idx: %d" % (Block_Idx))
            Block_Idx += 1
            # begin processing CTU
            # Each CTU is currently processed as a single CU covering the
            # whole padded block (lists kept for a future CU split).
            im_block_list = []
            im_block_loc_list = []
            im_block_list.append(im)  # list size = 1
            im_block_loc_list.append([0, 0, block_H_PAD, block_W_PAD])

            for im_block_loc, im_block in zip(im_block_loc_list, im_block_list):  # Traverse CUs
                # NOTE(review): im_block_loc is currently unused.
                ############################ Geometric Flip and rotate ########################
                if USE_GEO:
                    # Pick the flip/rotation with the best RD cost; indices
                    # 0-3 are pure rotations, 4-7 flip vertically first.
                    _, _, geo_index, _, orig_rd = check_RD_GEO(im_block, lambda_rd, image_comp, context, model_index)
                    i_rot = int(geo_index % 4)
                    if geo_index < 4:
                        im_block = torch.rot90(im_block, k=i_rot, dims=[2, 3])
                    else:
                        im_block = torch.rot90(torch.flip(im_block, dims=[2]), k=i_rot, dims=[2, 3])

                ############################ Preprocessing to find a substitute ########################
                if USE_PREPROCESSING:
                    # NOTE(review): orig_rd is only defined when USE_GEO is
                    # on; this branch raises NameError otherwise.
                    if USE_VR_MODEL:
                        im_block = substitute_generator.perturb(orig_image=im_block, orig_rd=orig_rd,
                                                                lambda_rd=lambda_rd)
                    else:
                        im_block = substitute_generator.perturb(orig_image=im_block, orig_rd=orig_rd)

                if USE_MULTI_HYPER:
                    with torch.no_grad():
                        # NOTE(review): fig/ax are created but never used.
                        fig = plt.figure()
                        ax = Axes3D(fig)

                        y_main, y_hyper, y_hyper_2 = image_comp.encoder(im_block, lambda_rd)
                        _, c, h, w = y_main.size()
                        y_hyper_q = torch.round(y_hyper)
                        cm_hot = mpl.cm.get_cmap('coolwarm')
                        # Dynamic range of each main-latent channel
                        # (diagnostic only; not used further below).
                        rangelatent = []
                        for ci in range(M):
                            current_channel = torch.squeeze(y_main[0, ci, :, :].view(1, -1)).cpu().numpy().tolist()
                            rangelatent.append(max(current_channel) - min(current_channel))

                        # Hand-picked channels whose feature maps get dumped
                        # to ./map/ below.
                        c_list=[3,14,16,47,69,95,122,131,135,251,254,255]

                        hyper_dec = image_comp.p(image_comp.hyper_1_dec(y_hyper_q))

                        # Hyper 2
                        y_hyper_2_q, xp3 = image_comp.factorized_entropy_func(y_hyper_2, 2)

                        # Hyper 1
                        tmp2 = image_comp.hyper_2_dec(y_hyper_2_q)
                        if USE_PREDICTOR:
                            # Residual coding: quantize the prediction residual
                            # instead of the raw hyper latent.
                            y_hyper_predict = image_comp.Y_2(tmp2)
                            y_hyper_res = y_hyper - y_hyper_predict
                            y_hyper_q = torch.round(y_hyper_res)
                        else:
                            y_hyper_q = torch.round(y_hyper)

                        # Main
                        if USE_PREDICTOR:
                            tmp = image_comp.hyper_1_dec(y_hyper_q + y_hyper_predict)
                        else:
                            tmp = image_comp.hyper_1_dec(y_hyper_q)
                        if USE_PREDICTOR:
                            y_main_predict = image_comp.Y(tmp)
                            y_main_res = y_main - y_main_predict
                            y_main_q = torch.round(y_main_res)
                        else:
                            y_main_q = torch.round(y_main)

                        # NOTE(review): y_main_predict is used here even when
                        # USE_PREDICTOR is False (it would be undefined).
                        xp3, params_prob = context(y_main_q, y_main_q + y_main_predict, hyper_dec)
                        entropy = -torch.log2(xp3)
                        # Mean bits per channel (diagnostic print-out).
                        entropy_channel = torch.mean(entropy,dim=(2,3),keepdim=False)[0,:].cpu().numpy().tolist()
                        for ci in range(len(entropy_channel)):
                            print(ci, entropy_channel[ci])
                        print(entropy_channel.index(min(entropy_channel)))
                        # Flatten mixture parameters to one 9-vector per
                        # latent position.
                        params_prob = torch.squeeze(params_prob.view(1, 9, -1)).permute(1, 0).cpu().numpy().tolist()
                        if GPU:
                            y_main_q = y_main_q.cuda()

                        _, c, h, w = y_main_q.size()
                        a = 0             # loop-exit flag (1 when a full pass finds no conflict)
                        ticks = 0         # number of threshold relaxations
                        threshold = 0.04  # initial mixture-weight threshold
                        # NOTE(review): np.int was removed in NumPy 1.24; use int.
                        Datas = torch.reshape(y_main_q, [-1]).cpu().numpy().astype(np.int).tolist()

                        # Pruning loop: positions failing the b0 test are
                        # assumed zero. If such a position actually holds a
                        # nonzero latent, force it to decode to zero
                        # (store -prediction), lower the threshold, re-run
                        # the context model, and restart the scan.
                        while a == 0:
                            b = 0
                            e_data = []         # latents that must be entropy-coded
                            e_params_prob = []  # their mixture parameters
                            sc = np.zeros(M)    # per-channel skip ratio
                            e_id = []           # 1 = coded, 0 = skipped
                            for i in range(c * h * w):
                                if b0(params_prob[i], threshold) == 1:
                                    e_id.append(1)
                                    e_data.append(Datas[i])
                                    e_params_prob.append(params_prob[i])
                                else:
                                    if Datas[i] != 0:
                                        # Conflict: a "skippable" position holds a
                                        # nonzero value.
                                        print('!!!')
                                        e_id.append(1)
                                        ticks += 1
                                        threshold -= 0.001
                                        b = 1
                                        y_main_q = y_main_q.cpu()
                                        y_main_predict = y_main_predict.cpu()
                                        # Recover (ci, hi, wi) from the flat index.
                                        ci = i // (h * w)
                                        hi = (i - ci * h * w) // w
                                        wi = i - ci * h * w - hi * w
                                        # Make the position decode to exactly zero.
                                        y_main_q[0, ci, hi, wi] = -y_main_predict[0, ci, hi, wi]
                                        y_main_q = y_main_q.cuda()
                                        y_main_predict = y_main_predict.cuda()
                                        xp3, params_prob = context(y_main_q, y_main_q + y_main_predict, hyper_dec)
                                        Datas = torch.reshape(y_main_q, [-1]).cpu().numpy().astype(np.int).tolist()
                                        params_prob = torch.squeeze(params_prob.view(1, 9, -1)).permute(1,
                                                                                                        0).cpu().numpy().tolist()
                                        break
                                    else:
                                        e_id.append(0)
                                        sc[i // (h * w)] += 1 / (h * w)

                            a = 1 - b
                        MAXD = max(Datas)
                        MIND=min(Datas)
                        # Dump colorized feature/skip maps for the selected
                        # channels into ./map/ (48x32 assumes a fixed latent
                        # size — TODO confirm against the input resolution).
                        for ci in c_list:
                            eid = e_id[ci*h*w:(ci+1)*h*w]
                            eid = np.reshape(np.array(eid),(48,32))
                            rec = y_main[0, ci, :, :]

                            com = np.random.rand(48,32)*7/11 + 4/11

                            com = cm_hot(com)
                            com = np.uint8(com*150)
                            com = Image.fromarray(com)
                            com.save('map/'+str(ci)+'com.bmp')

                            com2 = np.random.rand(48, 32)*7/11 + 4/11
                            com2 = com2 * eid
                            com2 = cm_hot(com2)
                            com2 = np.uint8(com2 * 150)
                            com2 = Image.fromarray(com2)
                            com2.save('map/' + str(ci) + 'com2.bmp')

                            # rec = torch.clamp(rec,min=0,max=1)
                            # Normalize the channel to [0,1] over the whole
                            # latent's range, equalize, invert, and colorize.
                            rec = rec.cpu().numpy()
                            rec_l = y_main.view(1,-1)[0,:].cpu().numpy().tolist()
                            MAX=max(rec_l)
                            MIN=min(rec_l)


                            rec = (rec-MIN)/(MAX-MIN)
                            rec, _ = histeq(rec*255, 255)
                            rec = 1-rec/255
                            rec = cm_hot(rec)

                            # Vertical color bar legend (480x20).
                            bar = np.zeros((48*10,2*10))
                            for hi in range(48*10):
                                for wi in range(2*10):
                                    bar[hi][wi] = 1-hi/480

                            bar = cm_hot(bar)
                            bar = np.uint8(bar*255)
                            bar = Image.fromarray(bar)
                            bar.save('map/bar.bmp')

                            rec = np.uint8(rec*150)
                            rec = Image.fromarray(rec)
                            #rec.show()
                            rec.save('map/'+str(ci)+'.bmp')
                            print('feature map saved')


                        # Print per-channel skip ratios (assumes M == 256 here).
                        sc = sc.tolist()
                        for ci in range(256):
                            print(ci,sc[ci])

                        loops += (ticks + 1) / len(img_block_list)
                        y_main_q = y_main_q.cuda()
                        y_main_predict = y_main_predict.cuda()
                        # Decode the (pruned) latent and save the reconstruction.
                        rec1 = image_comp.decoder(y_main_q + y_main_predict)
                        rec1 = torch.clamp(rec1, min=0, max=1)
                        rec1 = torch.squeeze(rec1, dim=0)
                        rec1 = rec1.permute(1, 2, 0)  # CHW -> HWC
                        rec1 = 255 * rec1
                        rec1 = rec1.cpu().numpy().astype(np.uint8)
                        rec1 = Image.fromarray(rec1)
                        rec1.save('nic_helic20.png')
                        # NOTE(review): rec1 is a PIL Image here, so this
                        # subtraction against a tensor would fail; psnr0 is
                        # also never defined anywhere in this file.
                        mse = torch.mean((rec1 - im_block) ** 2).item()
                        psnr1 = -10 * np.log10(mse)
                        psnr_loss += (psnr1 - psnr0) / len(img_block_list)
                        time_red += (1 - len(e_data) / len(Datas)) / len(img_block_list)
                        # Channels containing at least one coded position.
                        ec = []
                        for i in range(c):
                            for j in range(h * w):
                                if e_id[i * h * w + j] == 1:
                                    ec.append(i)
                                    break
                        # print(len(ec))
            # Advance the CTU write offsets in raster order.
            W_offset += block_W
            if W_offset >= W:
                W_offset = 0
                H_offset += block_H

    return time_red, psnr_loss, loops
-
-
if __name__ == '__main__':
    # Command-line front end: gather test images and run inference_rd.
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", type=str, default='D:/CV_dataset/NIC_test/ClassD_Kodak', help="Input Image")
    parser.add_argument("-o", "--output", type=str, default='performance2/', help="Output Bin(encode)/Image(decode)")
    parser.add_argument("-m_dir", "--model_dir", type=str, default='./',
                        help="Directory containing trained models")
    parser.add_argument("-m", "--model", type=int, default=1, help="Model Index [0-9]")
    parser.add_argument("--lambda_rd", type=float, default=16, help="Input lambda for variable-rate models")
    # parser.add_argument('--encode', dest='coder_flag', action='store_true')
    # parser.add_argument('--decode', dest='coder_flag', action='store_false')
    # parser.add_argument("--block_width", type=int, default=2048, help="coding block width")
    # parser.add_argument("--block_height", type=int, default=1024, help="coding block height")
    parser.add_argument("--number", type=int, default=25000, help="Number of Coding Images")
    args = parser.parse_args()

    num = 0
    test_image_paths = []
    # NOTE(review): `dir` shadows the builtin of the same name, and the
    # `dir != '4.png'` filter keeps ONLY the file named 4.png — this looks
    # like a debugging leftover; confirm before running on a full dataset.
    dirs = os.listdir(args.input)
    for dir in dirs:
        path = os.path.join(args.input, dir)
        if dir == 'Thumbs.db' or dir != '4.png':
            continue
        if os.path.isdir(path):
            # Subdirectory: take up to --number PNGs from it.
            test_image_paths += glob.glob(path + '/*.png')[0:args.number]
        if os.path.isfile(path):
            if num < args.number:
                test_image_paths.append(path)
                num += 1
    print(len(test_image_paths))
    # inference_rd returns sums over images; report per-image averages.
    time_red, psnr_loss, loops = inference_rd(test_image_paths, args.output, args.model_dir, args.model, args.lambda_rd)
    print(time_red / len(test_image_paths), psnr_loss / len(test_image_paths), loops / len(test_image_paths))
-
-
-
-
-
-
-
|