|
- '''
- Contributors: Aolin Feng, Jianpin Lin, Dezhao Wang, Ding Ding, Wei Wang, Yueyu Hu, Haojie Liu, Tong Chen, Chuanmin Jia, Yihang Chen, Ziqing Ge
- '''
-
- import argparse
- import math
- import os
- import struct
- import sys
- import time
- import glob
-
- import numpy as np
- import torch
- import torch.nn.functional as F
- from PIL import Image
- import random
-
- # import Util.AE as AE
- import AE
- import Model.model as model
- from Model.context_model import Weighted_Gaussian, Weighted_Gaussian_res
- from Util.metrics import evaluate
- from Util import torch_msssim
- from Util.block_metric import check_RD_GEO # blcok based metric
- from Util.config import dict
- from Util.generate_substitute import SubstituteGenerator
- from Model.ol_post_enhancement import Enhancement_net
- from Model.ol_post_enhancement import online_training
- from Model.acceleration_utils import pack_bools, unpack_bools
-
os.environ['LRU_CACHE_CAPACITY'] = '1'
############################## Load Configuration Parameters ########################
# Global coding-tool switches read from the project config (Util.config.dict).
GPU = dict['GPU']                      # run networks on CUDA
SAVE_REC = dict['SAVE_REC']            # write reconstructed images to disk
USE_GEO = dict['USE_GEO']              # per-block geometric flip/rotate search
USE_VR_MODEL = dict['USE_VR_MODEL']    # variable-rate (lambda-conditioned) model
USE_PREPROCESSING = dict['USE_PREPROCESSING']    # substitute-image search before encoding
USE_POSTPROCESSING = dict['USE_POSTPROCESSING']  # online-trained enhancement after decoding

if USE_PREPROCESSING:
    num_steps = dict['num_steps']  # gradient steps for the substitute search
if USE_POSTPROCESSING:
    post_processing_epochs = dict['Post_processing_epochs']
    post_processing_learning_rate = dict['Post_processing_learning_rate']
    postprocessing_CTU = dict['Postprocessing_CTU']  # block size for post-enhancement

# Coding-tree-unit (CTU) dimensions; square blocks.
block_width = dict['CTU_size']
block_height = dict['CTU_size']

USE_MULTI_HYPER = dict['USE_MULTI_HYPER']    # two-level hyperprior
USE_PREDICTOR = dict['USE_PREDICTOR']        # latent-residual predictor (needs multi-hyper)
USE_ACCELERATION = dict['USE_ACCELERATION']  # skip entropy coding of all-zero channels

# Tool-combination constraints.  NOTE: `assert` is stripped under `python -O`;
# these are sanity checks, not runtime validation.
assert (USE_MULTI_HYPER and USE_VR_MODEL) is False
assert (USE_PREDICTOR and bool(1-USE_MULTI_HYPER)) is False

# index - [0-15]; checkpoint base names (model_index selects one)
if USE_VR_MODEL:
    models = ["mse_VR_low", "mse_VR_high", "msssim_VR_low", "msssim_VR_high"]
    max_lambdas = [64, 256, 1.28, 6.40]  # per-model maximum rate-distortion lambda
else:
    models = ["mse200", "mse400", "mse800", "mse1600", "mse3200", "mse6400", "mse12800", "mse25600",
              "msssim4", "msssim8", "msssim16", "msssim32", "msssim64", "msssim128", "msssim320", "msssim640"]
#model_dir = "/model/ljp105/NIC_v01_baseline/"
#VR_model_dir = "/model/ljp105/NIC_v02_VR_models"
# @torch.no_grad()
def inference_rd(im_dirs, out_dir, model_dir, model_index, lambda_rd_ori):
    """Encode every image in *im_dirs* block by block and log rate/distortion.

    For each image a bitstream (``enc.bin``) is written into *out_dir*, the
    reconstruction is computed on the fly, and bpp / PSNR / MS-SSIM figures are
    appended to a per-model log file.  Temporary sub-streams (``main.bin``,
    ``hyper*.bin``) are written into the current working directory.

    Args:
        im_dirs: list of paths to input ``.png`` images.
        out_dir: output directory for bitstream, reconstructions and log file.
        model_dir: directory with the trained ``<name>.pkl`` (autoencoder) and
            ``<name>p.pkl`` (context model) checkpoints.
        model_index: index into the global ``models`` table.
        lambda_rd_ori: rate-distortion lambda (used by variable-rate models).
    """
    os.makedirs(out_dir, exist_ok=True)

    ######################### Build networks and log name #########################
    if USE_VR_MODEL:
        # Clamp lambda to 1.2x the model maximum, then quantise it to 16 bits so
        # the decoder can reproduce exactly the same conditioning value.
        lambda_rd_max = max_lambdas[model_index]
        if lambda_rd_ori > 1.2 * lambda_rd_max:
            lambda_rd_ori = 1.2 * lambda_rd_max
        lambda_rd_nom = lambda_rd_ori / lambda_rd_max
        lambda_rd_nom_scaled = int(lambda_rd_nom / 1.2 * pow(2, 16))
        lambda_rd_nom_used = lambda_rd_nom_scaled / pow(2, 16) * 1.2
        lambda_rd_numpy = np.zeros((1, 1), np.float32)
        lambda_rd_numpy[0, 0] = lambda_rd_nom_used
        lambda_rd = torch.Tensor(lambda_rd_numpy)
        M, N2 = 192, 128
        if (model_index == 1) or (model_index == 3):
            M, N2 = 256, 192  # high-rate models use wider layers
        image_comp = model.Image_coding(3, M, N2, M, M // 2)
        context = Weighted_Gaussian(M)
        log_name = os.path.join(out_dir, models[model_index]+'_lmbda'+str(int(lambda_rd_ori*100)) + '_test_' + str(block_width) + '_RD.txt')
    else:
        M, N2 = 192, 128
        if (model_index == 6) or (model_index == 7) or (model_index == 14) or (model_index == 15):
            M, N2 = 256, 192
        if USE_MULTI_HYPER:
            if USE_PREDICTOR:
                image_comp = model.Image_coding_multi_hyper_res(3, M, N2, M, M // 2)
            else:
                image_comp = model.Image_coding_multi_hyper(3, M, N2, M, M // 2)
        else:
            image_comp = model.Image_coding(3, M, N2, M, M // 2)
        if USE_PREDICTOR:
            context = Weighted_Gaussian_res(M)
        else:
            context = Weighted_Gaussian(M)
        lambda_rd = None
        log_name = os.path.join(out_dir, models[model_index]+'_test_'+str(block_width)+'_RD.txt')

    # Quantities needed by BOTH the preprocessing and the post-processing
    # stages.  Computed unconditionally so that enabling USE_POSTPROCESSING
    # without USE_PREPROCESSING no longer raises a NameError.
    lmbda_list = [200, 400, 800, 1600, 3200, 6400, 12800, 25600, 4, 8, 16, 32, 64, 128, 320, 640]
    if USE_VR_MODEL:
        lmbda = lambda_rd_ori * 100
        reconstruction_metric = 'mse' if model_index <= 1 else 'msssim'
    else:
        lmbda = lmbda_list[model_index]
        reconstruction_metric = 'mse' if model_index <= 7 else 'msssim'

    if USE_PREPROCESSING:
        # Per-lambda gradient step sizes for the substitute-image search.
        stepsize_list = [150, 75, 30, 10, 10, 5, 3, 1, 100, 10, 7, 5, 1, 1, 1, 0.3]
        if USE_VR_MODEL:
            step_size = stepsize_list[lmbda_list.index(lmbda)]
        else:
            step_size = stepsize_list[model_index]
        substitute_generator = SubstituteGenerator(model=image_comp, context_model=context, llambda=lmbda,
                                                   num_steps=num_steps, step_size=step_size,
                                                   reconstruct_metric=reconstruction_metric,
                                                   )

    ######################### Load Model #########################
    image_comp.load_state_dict(torch.load(
        os.path.join(model_dir, models[model_index] + r'.pkl'), map_location='cpu'))
    context.load_state_dict(torch.load(
        os.path.join(model_dir, models[model_index] + r'p.pkl'), map_location='cpu'))
    if GPU:
        image_comp.cuda()
        context.cuda()

    #################### Compress Each Image ###################
    bpp_list = []
    rgb_psnr_list = []
    rgb_msssim_list = []
    for im_dir in im_dirs:
        dec_time = 0
        enc_dec_time_start = time.time()
        bin_dir = os.path.join(out_dir, 'enc.bin')
        rec_dir = os.path.join(out_dir, im_dir.split('/')[-1].replace('.png', 'dec.png'))
        file_object = open(bin_dir, 'wb')

        ######################### Read Image #########################
        img = Image.open(im_dir)
        source_img = np.array(img)
        img = source_img/255.0
        H, W, _ = img.shape
        print(img.shape)
        num_pixels = H * W
        C = 3
        # Sequence header: size, model id and the tool flags the decoder must
        # mirror to parse the stream.
        Head = struct.pack('2HB4?H', H, W, model_index, USE_GEO, USE_VR_MODEL, USE_MULTI_HYPER, USE_PREDICTOR, block_width)
        file_object.write(Head)
        if USE_VR_MODEL:
            Head_lmbda = struct.pack('H', lambda_rd_nom_scaled)
            file_object.write(Head_lmbda)
        out_img = np.zeros([H, W, C])  # reconstructed image
        H_offset = 0
        W_offset = 0

        ######################### Splitting Image #########################
        Block_Num_in_Width = int(np.ceil(W / block_width))
        Block_Num_in_Height = int(np.ceil(H / block_height))
        img_block_list = []
        for i in range(Block_Num_in_Height):
            for j in range(Block_Num_in_Width):
                img_block_list.append(img[i * block_height:np.minimum((i + 1) * block_height, H), j * block_width:np.minimum((j + 1) * block_width, W), ...])

        ######################### Padding Image #########################
        Block_Idx = 0
        for img in img_block_list:  # Traverse CTUs
            block_H = img.shape[0]
            block_W = img.shape[1]
            tile = 64.
            # Pad each CTU up to a multiple of 64 so the strided auto-encoder fits.
            block_H_PAD = int(tile * np.ceil(block_H / tile))
            block_W_PAD = int(tile * np.ceil(block_W / tile))
            im = np.zeros([block_H_PAD, block_W_PAD, 3], dtype='float32')
            im[:block_H, :block_W, :] = img[:, :, :3]
            im = torch.FloatTensor(im)
            im = im.permute(2, 0, 1).contiguous()
            im = im.view(1, C, block_H_PAD, block_W_PAD)
            if GPU:
                im = im.cuda()
                if USE_VR_MODEL:
                    lambda_rd = lambda_rd.cuda()
            print('====> Encoding Image:', im_dir, "%dx%d" % (block_H, block_W), 'to', out_dir, " Block Idx: %d" % (Block_Idx))
            Block_Idx += 1
            # begin processing CTU; currently one CU == one whole (padded) CTU
            im_block_list = []
            im_block_loc_list = []
            im_block_list.append(im)  # list size = 1
            im_block_loc_list.append([0, 0, block_H_PAD, block_W_PAD])

            for im_block_loc, im_block in zip(im_block_loc_list, im_block_list):  # Traverse CUs
                ############################ Geometric Flip and rotate ########################
                if USE_GEO:
                    _, _, geo_index, _, orig_rd = check_RD_GEO(im_block, lambda_rd, image_comp, context, model_index)
                    i_rot = int(geo_index % 4)
                    if geo_index < 4:
                        im_block = torch.rot90(im_block, k=i_rot, dims=[2, 3])
                    else:
                        im_block = torch.rot90(torch.flip(im_block, dims=[2]), k=i_rot, dims=[2, 3])

                ############################ Preprocessing to find a substitute ########################
                # NOTE(review): `orig_rd` is only defined when USE_GEO is on —
                # confirm preprocessing is never enabled without GEO.
                if USE_PREPROCESSING:
                    if USE_VR_MODEL:
                        im_block = substitute_generator.perturb(orig_image=im_block, orig_rd=orig_rd, lambda_rd=lambda_rd)
                    else:
                        im_block = substitute_generator.perturb(orig_image=im_block, orig_rd=orig_rd)

                if USE_MULTI_HYPER:
                    with torch.no_grad():
                        y_main, y_hyper, y_hyper_2 = image_comp.encoder(im_block, lambda_rd)

                        # Hyper 2 (fully factorized prior)
                        y_hyper_2_q, xp3 = image_comp.factorized_entropy_func(y_hyper_2, 2)

                        # Hyper 1
                        tmp2 = image_comp.hyper_2_dec(y_hyper_2_q)
                        hyper_2_dec = image_comp.p_2(tmp2)
                        if USE_PREDICTOR:
                            # code only the residual w.r.t. the predicted hyper latent
                            y_hyper_predict = image_comp.Y_2(tmp2)
                            y_hyper_res = y_hyper - y_hyper_predict
                            y_hyper_q = torch.round(y_hyper_res)
                        else:
                            y_hyper_q = torch.round(y_hyper)

                        # Main
                        if USE_PREDICTOR:
                            tmp = image_comp.hyper_1_dec(y_hyper_q + y_hyper_predict)
                        else:
                            tmp = image_comp.hyper_1_dec(y_hyper_q)
                        hyper_dec = image_comp.p(tmp)
                        if USE_PREDICTOR:
                            y_main_predict = image_comp.Y(tmp)
                            y_main_res = y_main - y_main_predict
                            y_main_q = torch.round(y_main_res)
                        else:
                            y_main_q = torch.round(y_main)

                        # `np.int` was removed in NumPy 1.24; the builtin `int`
                        # is the exact dtype equivalent for astype().
                        y_main_q = torch.Tensor(y_main_q.cpu().numpy().astype(int))
                        y_hyper_q = torch.Tensor(y_hyper_q.cpu().numpy().astype(int))
                        y_hyper_2_q = torch.Tensor(y_hyper_2_q.cpu().numpy().astype(int))

                        if GPU:
                            y_main_q = y_main_q.cuda()
                        dec_time_start = time.time()

                        if USE_PREDICTOR:
                            rec_img_block = image_comp.decoder(y_main_q + y_main_predict, lambda_rd)
                        else:
                            rec_img_block = image_comp.decoder(y_main_q, lambda_rd)
                        ############################ Reverse Geometric Flip and Rotate ########################
                        if USE_GEO:
                            if geo_index < 4:
                                rec_img_block = torch.rot90(rec_img_block, k=4-i_rot, dims=[2, 3])
                            else:
                                rec_img_block = torch.flip(torch.rot90(rec_img_block, k=4-i_rot, dims=[2, 3]), dims=[2])

                        ################################### Reconstruct Image #######################################
                        output_ = torch.clamp(rec_img_block, min=0., max=1.0)
                        out = output_.data[0].cpu().numpy()
                        out = out.transpose(1, 2, 0)
                        out_img[H_offset: H_offset + block_H, W_offset: W_offset + block_W, :] = out[:block_H, :block_W, :]
                        dec_time += (time.time()-dec_time_start)

                        if USE_PREDICTOR:
                            xp3, params_prob = context(y_main_q, y_main_q + y_main_predict, hyper_dec)
                        else:
                            xp3, params_prob = context(y_main_q, hyper_dec)

                        # Main Arith Encode
                        Datas = torch.reshape(y_main_q, [-1]).cpu().numpy().astype(int).tolist()
                        _, c, h, w = y_main_q.shape
                        if USE_ACCELERATION:
                            # Skip all-zero channels: signal a per-channel bitmap
                            # and entropy-code only the effective channels.
                            EC_this_block = []
                            Effective_Datas = []
                            # NOTE(review): reset per CU, so the "same as last
                            # block" flag below is always 0 for Block_Idx > 1 —
                            # confirm this matches the decoder.
                            EC_last_block = []
                            for i in range(c):
                                current_channel = Datas[i * h * w:(i + 1) * h * w]
                                if max(current_channel) == 0 and min(current_channel) == 0:
                                    EC_this_block.append(0)
                                else:
                                    EC_this_block.append(1)
                                    Effective_Datas += current_channel
                            if Block_Idx == 1:
                                pack_bools(EC_this_block, M, file_object)
                            else:
                                if EC_this_block == EC_last_block:
                                    b = struct.pack('1?', 1)
                                    file_object.write(b)
                                else:
                                    b = struct.pack('1?', 0)
                                    file_object.write(b)
                                    pack_bools(EC_this_block, M, file_object)
                            EC_last_block = EC_this_block
                            EC_index = [i for i in range(c) if EC_this_block[i] == 1]
                            EC_index = torch.LongTensor(EC_index).cuda()
                            params_prob = torch.index_select(params_prob, dim=2, index=EC_index)

                        Max_Main = max(Datas)
                        Min_Main = min(Datas)
                        sample = np.arange(Min_Main, Max_Main+1+1)  # [Min_V - 0.5 , Max_V + 0.5]
                        print("Main Channel:", c)
                        if USE_ACCELERATION:
                            sample = torch.FloatTensor(np.tile(sample, [1, len(EC_index), h, w, 1])).cuda()
                        else:
                            sample = torch.FloatTensor(np.tile(sample, [1, c, h, w, 1])).cuda()

                        # mixture of 3 Gaussians
                        prob0, mean0, scale0, prob1, mean1, scale1, prob2, mean2, scale2 = [
                            torch.chunk(params_prob, 9, dim=1)[i].squeeze(1) for i in range(9)]
                        del params_prob
                        # keep the weight summation of prob == 1
                        probs = torch.stack([prob0, prob1, prob2], dim=-1)
                        del prob0, prob1, prob2
                        probs = F.softmax(probs, dim=-1)
                        # process the scale value to positive non-zero
                        scale0 = torch.abs(scale0)
                        scale1 = torch.abs(scale1)
                        scale2 = torch.abs(scale2)
                        scale0[scale0 < 1e-6] = 1e-6
                        scale1[scale1 < 1e-6] = 1e-6
                        scale2[scale2 < 1e-6] = 1e-6
                        m0 = torch.distributions.normal.Normal(mean0, scale0)
                        m1 = torch.distributions.normal.Normal(mean1, scale1)
                        m2 = torch.distributions.normal.Normal(mean2, scale2)
                        if USE_ACCELERATION:
                            lower = torch.zeros(1, len(EC_index), h, w, Max_Main - Min_Main + 2)
                        else:
                            lower = torch.zeros(1, c, h, w, Max_Main - Min_Main + 2)

                        for i in range(sample.shape[4]):
                            lower0 = m0.cdf(sample[:, :, :, :, i].cuda()-0.5)
                            lower1 = m1.cdf(sample[:, :, :, :, i].cuda()-0.5)
                            lower2 = m2.cdf(sample[:, :, :, :, i].cuda()-0.5)
                            lower[:, :, :, :, i] = probs[:, :, :, :, 0]*lower0 + \
                                probs[:, :, :, :, 1]*lower1+probs[:, :, :, :, 2]*lower2
                        del probs, lower0, lower1, lower2

                        precise = 16
                        cdf_m = lower.data.cpu().numpy()*((1 << precise) - (Max_Main -
                                                          Min_Main + 1))  # [1, c, h, w ,Max-Min+1]
                        # add the sample offset so neighbouring CDF bins differ by >= 1
                        cdf_m = cdf_m.astype(np.int32) + sample.cpu().numpy().astype(np.int32) - Min_Main
                        if USE_ACCELERATION:
                            cdf_main = np.reshape(cdf_m, [len(Effective_Datas), -1])
                            Cdf_lower = list(map(lambda x, y: int(y[x - Min_Main]), Effective_Datas, cdf_main))
                            Cdf_upper = list(map(lambda x, y: int(
                                y[x - Min_Main]), Effective_Datas, cdf_main[:, 1:]))
                        else:
                            cdf_main = np.reshape(cdf_m, [len(Datas), -1])
                            # Cdf[Datas - Min_V]
                            Cdf_lower = list(map(lambda x, y: int(y[x - Min_Main]), Datas, cdf_main))
                            # Cdf[Datas + 1 - Min_V]
                            Cdf_upper = list(map(lambda x, y: int(
                                y[x - Min_Main]), Datas, cdf_main[:, 1:]))

                        AE.encode_cdf(Cdf_lower, Cdf_upper, "main.bin")
                        FileSizeMain = os.path.getsize("main.bin")
                        print("main.bin: %d bytes" % (FileSizeMain))

                        # Hyper 1 Arith Encode (Gaussian parameters from hyper 2)
                        Datas = torch.reshape(y_hyper_q, [-1]).cpu().numpy().astype(int).tolist()
                        Max_HYPER_1 = max(Datas)
                        Min_HYPER_1 = min(Datas)
                        sample = np.arange(Min_HYPER_1, Max_HYPER_1+1+1)  # [Min_V - 0.5 , Max_V + 0.5]
                        _, c, h, w = y_hyper_q.shape
                        print("Hyper 1 Channel:", c)
                        sample = torch.FloatTensor(np.tile(sample, [1, c, h, w, 1])).cuda()

                        mean = hyper_2_dec[:, :c, :, :]
                        scale = hyper_2_dec[:, c:, :, :]

                        scale = torch.abs(scale)
                        scale[scale < 1e-6] = 1e-6

                        m = torch.distributions.normal.Normal(mean, scale)
                        lower = torch.zeros(1, c, h, w, Max_HYPER_1-Min_HYPER_1+2).cuda()
                        for ii in range(sample.shape[4]):
                            lower[:, :, :, :, ii] = m.cdf(sample[:, :, :, :, ii]-0.5)
                        precise = 16
                        cdf_m = lower.data.cpu().numpy()*((1 << precise) - (Max_HYPER_1 -
                                                          Min_HYPER_1 + 1))  # [1, c, h, w ,Max-Min+1]
                        cdf_m = cdf_m.astype(np.int32) + sample.cpu().numpy().astype(np.int32) - Min_HYPER_1
                        cdf_main = np.reshape(cdf_m, [len(Datas), -1])

                        # Cdf[Datas - Min_V]
                        Cdf_lower = list(map(lambda x, y: int(y[x - Min_HYPER_1]), Datas, cdf_main))
                        # Cdf[Datas + 1 - Min_V]
                        Cdf_upper = list(map(lambda x, y: int(
                            y[x - Min_HYPER_1]), Datas, cdf_main[:, 1:]))
                        AE.encode_cdf(Cdf_lower, Cdf_upper, "hyper_1.bin")
                        FileSizeHyper1 = os.path.getsize("hyper_1.bin")
                        print("hyper_1.bin: %d bytes" % (FileSizeHyper1))

                        # Hyper 2 Arith Encode (learned factorized CDF)
                        Min_HYPER_2 = torch.min(y_hyper_2_q).cpu().numpy().astype(np.int16).tolist()
                        Max_HYPER_2 = torch.max(y_hyper_2_q).cpu().numpy().astype(np.int16).tolist()
                        _, c, h, w = y_hyper_2_q.shape
                        Datas_hyper = torch.reshape(
                            y_hyper_2_q, [c, -1]).cpu().numpy().astype(np.int16).tolist()
                        # [Min_V - 0.5 , Max_V + 0.5]
                        sample = np.arange(Min_HYPER_2, Max_HYPER_2+1+1)
                        sample = np.tile(sample, [c, 1, 1])
                        lower = torch.sigmoid(image_comp.factorized_entropy_func._logits_cumulative(
                            torch.FloatTensor(sample).cuda() - 0.5, stop_gradient=False))

                        cdf_h = lower.data.cpu().numpy()*((1 << precise) - (Max_HYPER_2 -
                                                          Min_HYPER_2 + 1))  # [N1, 1, Max-Min+1]
                        cdf_h = cdf_h.astype(np.int16) + sample.astype(np.int16) - Min_HYPER_2
                        cdf_hyper = np.reshape(np.tile(cdf_h, [len(Datas_hyper[0]), 1, 1, 1]), [
                            len(Datas_hyper[0]), c, -1])

                        # Datas_hyper [256, N], cdf_hyper [256,1,X]
                        Cdf_0, Cdf_1 = [], []
                        for i in range(c):
                            Cdf_0.extend(list(map(lambda x, y: int(
                                y[x - Min_HYPER_2]), Datas_hyper[i], cdf_hyper[:, i, :])))  # Cdf[Datas - Min_V]
                            Cdf_1.extend(list(map(lambda x, y: int(
                                y[x - Min_HYPER_2]), Datas_hyper[i], cdf_hyper[:, i, 1:])))  # Cdf[Datas + 1 - Min_V]
                        AE.encode_cdf(Cdf_0, Cdf_1, "hyper_2.bin")
                        FileSizeHyper2 = os.path.getsize("hyper_2.bin")
                        print("hyper_2.bin: %d bytes" % (FileSizeHyper2))

                    # CU header: symbol ranges plus the three sub-stream sizes.
                    if USE_GEO:
                        Head_block = struct.pack('6h3IB', Min_Main, Max_Main, Min_HYPER_1, Max_HYPER_1, Min_HYPER_2, Max_HYPER_2, FileSizeMain, FileSizeHyper1, FileSizeHyper2, geo_index)
                    else:
                        Head_block = struct.pack('6h3I', Min_Main, Max_Main, Min_HYPER_1, Max_HYPER_1, Min_HYPER_2, Max_HYPER_2, FileSizeMain, FileSizeHyper1, FileSizeHyper2)

                else:
                    with torch.no_grad():
                        y_main, y_hyper = image_comp.encoder(im_block, lambda_rd)
                        y_main_q = torch.round(y_main)
                        # `np.int` was removed in NumPy 1.24; builtin `int` is the
                        # exact dtype equivalent for astype().
                        y_main_q = torch.Tensor(y_main_q.cpu().numpy().astype(int))
                        if GPU:
                            y_main_q = y_main_q.cuda()
                        dec_time_start = time.time()
                        rec_img_block = image_comp.decoder(y_main_q, lambda_rd)
                        ############################ Reverse Geometric Flip and Rotate ########################
                        if USE_GEO:
                            if geo_index < 4:
                                rec_img_block = torch.rot90(rec_img_block, k=4-i_rot, dims=[2, 3])
                            else:
                                rec_img_block = torch.flip(torch.rot90(rec_img_block, k=4-i_rot, dims=[2, 3]), dims=[2])

                        ################################### Reconstruct Image #######################################
                        output_ = torch.clamp(rec_img_block, min=0., max=1.0)
                        out = output_.data[0].cpu().numpy()
                        out = out.transpose(1, 2, 0)
                        out_img[H_offset: H_offset + block_H, W_offset: W_offset + block_W, :] = out[:block_H, :block_W, :]
                        dec_time += (time.time()-dec_time_start)

                        y_hyper_q, xp2 = image_comp.factorized_entropy_func(y_hyper, 2)
                        y_hyper_q = torch.Tensor(y_hyper_q.cpu().numpy().astype(int))
                        if GPU:
                            y_hyper_q = y_hyper_q.cuda()

                        hyper_dec = image_comp.p(image_comp.hyper_dec(y_hyper_q))
                        xp3, params_prob = context(y_main_q, hyper_dec)

                        # Main Arith Encode
                        Datas = torch.reshape(y_main_q, [-1]).cpu().numpy().astype(int).tolist()
                        Max_Main = max(Datas)
                        Min_Main = min(Datas)
                        sample = np.arange(Min_Main, Max_Main+1+1)  # [Min_V - 0.5 , Max_V + 0.5]
                        _, c, h, w = y_main_q.shape
                        print("Main Channel:", c)
                        sample = torch.FloatTensor(np.tile(sample, [1, c, h, w, 1]))
                        if GPU:
                            sample = sample.cuda()

                        # mixture of 3 Gaussians
                        prob0, mean0, scale0, prob1, mean1, scale1, prob2, mean2, scale2 = [
                            torch.chunk(params_prob, 9, dim=1)[i].squeeze(1) for i in range(9)]
                        del params_prob
                        # keep the weight summation of prob == 1
                        probs = torch.stack([prob0, prob1, prob2], dim=-1)
                        del prob0, prob1, prob2
                        probs = F.softmax(probs, dim=-1)
                        # process the scale value to positive non-zero
                        scale0 = torch.abs(scale0)
                        scale1 = torch.abs(scale1)
                        scale2 = torch.abs(scale2)
                        scale0[scale0 < 1e-6] = 1e-6
                        scale1[scale1 < 1e-6] = 1e-6
                        scale2[scale2 < 1e-6] = 1e-6

                        m0 = torch.distributions.normal.Normal(mean0, scale0)
                        m1 = torch.distributions.normal.Normal(mean1, scale1)
                        m2 = torch.distributions.normal.Normal(mean2, scale2)
                        lower = torch.zeros(1, c, h, w, Max_Main-Min_Main+2)
                        for i in range(sample.shape[4]):
                            lower0 = m0.cdf(sample[:, :, :, :, i]-0.5)
                            lower1 = m1.cdf(sample[:, :, :, :, i]-0.5)
                            lower2 = m2.cdf(sample[:, :, :, :, i]-0.5)
                            lower[:, :, :, :, i] = probs[:, :, :, :, 0]*lower0 + \
                                probs[:, :, :, :, 1]*lower1+probs[:, :, :, :, 2]*lower2
                        del probs, lower0, lower1, lower2

                        precise = 16
                        cdf_m = lower.data.cpu().numpy()*((1 << precise) - (Max_Main -
                                                          Min_Main + 1))  # [1, c, h, w ,Max-Min+1]
                        cdf_m = cdf_m.astype(np.int32) + sample.cpu().numpy().astype(np.int32) - Min_Main
                        cdf_main = np.reshape(cdf_m, [len(Datas), -1])

                        # Cdf[Datas - Min_V]
                        Cdf_lower = list(map(lambda x, y: int(y[x - Min_Main]), Datas, cdf_main))
                        # Cdf[Datas + 1 - Min_V]
                        Cdf_upper = list(map(lambda x, y: int(
                            y[x - Min_Main]), Datas, cdf_main[:, 1:]))
                        AE.encode_cdf(Cdf_lower, Cdf_upper, "main.bin")
                        FileSizeMain = os.path.getsize("main.bin")

                        # Hyper Arith Encode
                        Min_V_HYPER = torch.min(y_hyper_q).cpu().numpy().astype(int).tolist()
                        Max_V_HYPER = torch.max(y_hyper_q).cpu().numpy().astype(int).tolist()
                        _, c, h, w = y_hyper_q.shape
                        Datas_hyper = torch.reshape(
                            y_hyper_q, [c, -1]).cpu().numpy().astype(int).tolist()
                        # [Min_V - 0.5 , Max_V + 0.5]
                        sample = np.arange(Min_V_HYPER, Max_V_HYPER+1+1)
                        sample = np.tile(sample, [c, 1, 1])
                        sample_tensor = torch.FloatTensor(sample)
                        if GPU:
                            sample_tensor = sample_tensor.cuda()
                        lower = torch.sigmoid(image_comp.factorized_entropy_func._logits_cumulative(
                            sample_tensor - 0.5, stop_gradient=False))
                        cdf_h = lower.data.cpu().numpy()*((1 << precise) - (Max_V_HYPER -
                                                          Min_V_HYPER + 1))  # [N1, 1, Max-Min+1]
                        cdf_h = cdf_h.astype(int) + sample.astype(int) - Min_V_HYPER
                        cdf_hyper = np.reshape(np.tile(cdf_h, [len(Datas_hyper[0]), 1, 1, 1]), [
                            len(Datas_hyper[0]), c, -1])

                        # Datas_hyper [256, N], cdf_hyper [256,1,X]
                        Cdf_0, Cdf_1 = [], []
                        for i in range(c):
                            Cdf_0.extend(list(map(lambda x, y: int(
                                y[x - Min_V_HYPER]), Datas_hyper[i], cdf_hyper[:, i, :])))  # Cdf[Datas - Min_V]
                            Cdf_1.extend(list(map(lambda x, y: int(
                                y[x - Min_V_HYPER]), Datas_hyper[i], cdf_hyper[:, i, 1:])))  # Cdf[Datas + 1 - Min_V]
                        AE.encode_cdf(Cdf_0, Cdf_1, "hyper.bin")
                        FileSizeHyper = os.path.getsize("hyper.bin")

                    # CU header: symbol ranges plus the two sub-stream sizes.
                    if USE_GEO:
                        Head_block = struct.pack('4h2IB', Min_Main, Max_Main, Min_V_HYPER, Max_V_HYPER,
                                                 FileSizeMain, FileSizeHyper, geo_index)
                    else:
                        Head_block = struct.pack('4h2I', Min_Main, Max_Main, Min_V_HYPER, Max_V_HYPER,
                                                 FileSizeMain, FileSizeHyper)

                file_object.write(Head_block)  # CU information
                # concatenate the CU header and the temporary sub-streams
                if USE_MULTI_HYPER:
                    with open("main.bin", 'rb') as f:
                        bits = f.read()
                    file_object.write(bits)
                    with open("hyper_1.bin", 'rb') as f:
                        bits = f.read()
                    file_object.write(bits)
                    with open("hyper_2.bin", 'rb') as f:
                        bits = f.read()
                    file_object.write(bits)
                else:
                    with open("main.bin", 'rb') as f:
                        bits = f.read()
                    file_object.write(bits)
                    with open("hyper.bin", 'rb') as f:
                        bits = f.read()
                    file_object.write(bits)

            del im, im_block_list, im_block_loc_list
            # advance the CTU raster-scan position
            W_offset += block_W
            if W_offset >= W:
                W_offset = 0
                H_offset += block_H

        ############################################# postprocessing ###################################################################
        # Content-adaptive online-trained enhancement; only performed when the
        # reconstruction metric is MSE.
        if USE_POSTPROCESSING and reconstruction_metric == "mse":
            W_offset_pp = 0
            H_offset_pp = 0
            out_img_pp = np.zeros([H, W, C])  # enhanced reconstruction
            source_img_normalized = source_img / 255.0
            out_img_normalized = out_img[:H, :W, :]
            ######################### Splitting Image #########################
            block_height_pp = postprocessing_CTU
            block_width_pp = postprocessing_CTU
            Block_Num_in_Width_pp = int(np.ceil(W / block_width_pp))
            Block_Num_in_Height_pp = int(np.ceil(H / block_height_pp))
            img_block_list_pp = []
            out_img_block_list_pp = []
            for i in range(Block_Num_in_Height_pp):
                for j in range(Block_Num_in_Width_pp):
                    img_block_list_pp.append(source_img_normalized[i * block_height_pp:np.minimum((i + 1) * block_height_pp, H), j * block_width_pp:np.minimum((j + 1) * block_width_pp, W), ...])
                    out_img_block_list_pp.append(out_img_normalized[i * block_height_pp:np.minimum((i + 1) * block_height_pp, H), j * block_width_pp:np.minimum((j + 1) * block_width_pp, W), ...])
            Block_Idx_pp = 0
            for idx, img_pp in enumerate(img_block_list_pp):  # Traverse CTUs
                block_H_pp, block_W_pp, _ = img_pp.shape
                im_pp = torch.FloatTensor(img_pp)
                im_pp = im_pp.permute(2, 0, 1).contiguous()
                im_pp = im_pp.view(1, C, block_H_pp, block_W_pp)

                out_img_block_pp = out_img_block_list_pp[idx]
                out_im_pp = torch.FloatTensor(out_img_block_pp)
                out_im_pp = out_im_pp.permute(2, 0, 1).contiguous()
                out_im_pp = out_im_pp.view(1, C, block_H_pp, block_W_pp)

                if GPU:
                    im_pp = im_pp.cuda()
                    if USE_VR_MODEL:
                        lambda_rd = lambda_rd.cuda()
                print('====> Post processing Image:', im_dir, "%dx%d" % (block_H_pp, block_W_pp), 'to', out_dir, " Block Idx: %d" % (Block_Idx_pp))
                Block_Idx_pp += 1

                #################################### Content adaptive Online Training (post processing) ########################
                if block_H_pp >= 3 and block_W_pp >= 3:
                    enhance_net = Enhancement_net()
                    enhance_net.cuda()
                    isEnhance, params_list, rec_img_block = online_training(out_im_pp, im_pp,
                                                                            block_H_pp, block_W_pp,
                                                                            enhance_net, lmbda,
                                                                            post_processing_epochs, post_processing_learning_rate)

                    # One flag byte per block, plus 84 half-precision filter
                    # parameters whenever enhancement paid off.
                    Head_pp_block = struct.pack('?', isEnhance)
                    file_object.write(Head_pp_block)
                    if isEnhance:
                        pp_block_params = struct.pack('84e', *params_list)
                        file_object.write(pp_block_params)
                else:
                    rec_img_block = out_im_pp
                ################################### Reconstruct Image #######################################
                output_ = torch.clamp(rec_img_block, min=0., max=1.0)
                out = output_.data[0].cpu().numpy()
                out = out.transpose(1, 2, 0)
                out_img_pp[H_offset_pp: H_offset_pp + block_H_pp, W_offset_pp: W_offset_pp + block_W_pp, :] = out[:block_H_pp, :block_W_pp, :]

                W_offset_pp += block_W_pp
                if W_offset_pp >= W:
                    W_offset_pp = 0
                    H_offset_pp += block_H_pp
            out_img = out_img_pp

        file_object.close()
        with open(bin_dir, 'rb') as f:
            bpp = len(f.read())*8./num_pixels

        out_img = np.round(out_img * 255.0)
        out_img = out_img.astype('uint8')
        out_img = out_img[:H, :W, :]
        if SAVE_REC:
            img = Image.fromarray(out_img)
            img.save(rec_dir)
        [rgb_psnr, rgb_msssim, yuv_psnr, y_msssim] = evaluate(source_img, out_img)
        bpp_list.append(bpp)
        rgb_psnr_list.append(rgb_psnr)
        rgb_msssim_list.append(rgb_msssim)

        class_name = im_dir.split('/')[-2]
        image_name = im_dir.split('/')[-1].replace('.png', '')
        enc_dec_time = time.time() - enc_dec_time_start
        enc_time = enc_dec_time - dec_time

        # One tab-separated record per image (also mirrored to stdout).
        log_line = (class_name+'/'+image_name+'\t'+str(bpp)+'\t'+str(rgb_psnr)+'\t'+str(rgb_msssim)+'\t'+str(-10*np.log10(1-rgb_msssim)) +
                    '\t'+str(yuv_psnr)+'\t'+str(y_msssim)+'\t'+str(-10*np.log10(1-y_msssim)) +
                    '\t'+str(enc_time)+'\t'+str(dec_time)+'\n')
        print(log_line)
        with open(log_name, 'a') as f:
            f.write(log_line)
        del out_img

    mean_bpp = np.mean(bpp_list)
    mean_psnr = np.mean(rgb_psnr_list)
    mean_msssim = np.mean(rgb_msssim_list)
    with open(log_name, 'a') as f:
        f.write("mean of bpp:" + str(mean_bpp) + '\n' + "mean of psnr:" + str(mean_psnr) + '\n' + "mean of msssim:" + str(mean_msssim))
    print("mean of bpp:", mean_bpp)
    print("mean of psnr:", mean_psnr)
    print("mean of msssim:", mean_msssim)
-
if __name__ == '__main__':

    # Fix every RNG and force deterministic cuDNN so runs are reproducible.
    SEED = 0
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    np.random.seed(SEED)
    random.seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", type=str, required=True, help="Input Image")
    parser.add_argument("-o", "--output", type=str, required=True, help="Output Bin(encode)/Image(decode)")
    parser.add_argument("-m_dir", "--model_dir", type=str, required=True, help="Directory containing trained models")
    parser.add_argument("-m", "--model", type=int, default=0, help="Model Index [0-9]")
    parser.add_argument("--lambda_rd", type=float, default=16, help="Input lambda for variable-rate models")
    parser.add_argument("--number", type=int, default=24, help="Number of Coding Images")
    args = parser.parse_args()

    # Collect test images: up to `number` PNGs from every class sub-directory,
    # plus up to `number` loose files found directly under the input directory.
    test_image_paths = []
    loose_file_count = 0
    for entry in os.listdir(args.input):
        entry_path = os.path.join(args.input, entry)
        if os.path.isdir(entry_path):
            test_image_paths.extend(glob.glob(entry_path + '/*.png')[:args.number])
        elif os.path.isfile(entry_path) and loose_file_count < args.number:
            test_image_paths.append(entry_path)
            loose_file_count += 1

    inference_rd(test_image_paths, args.output, args.model_dir, args.model, args.lambda_rd)
|