|
- '''
- Contributors: Aolin Feng, Jianpin Lin, Dezhao Wang, Ding Ding, Wei Wang, Yueyu Hu, Haojie Liu, Tong Chen, Chuanmin Jia, Yihang Chen
- '''
-
- import argparse
- import math
- import os
- import struct
- import sys
- import time
- import glob
-
- import numpy as np
- import torch
- import torch.nn.functional as F
- from PIL import Image
-
-
- import Model.model as model
- from Model.context_model import Weighted_Gaussian, Weighted_Gaussian_res
- from Util.metrics import evaluate
- from Util import torch_msssim
from Util.block_metric import check_RD_GEO  # block-based metric
- from Util.config import dict
- from Util.generate_substitute import SubstituteGenerator
-
os.environ['LRU_CACHE_CAPACITY'] = '1'

############################## Load Configuration Parameters ########################
# NOTE: `dict` is the configuration mapping imported from Util.config
# (its imported name unfortunately shadows the builtin).
GPU = dict['GPU']
SAVE_REC = dict['SAVE_REC']
USE_GEO = dict['USE_GEO']
USE_VR_MODEL = dict['USE_VR_MODEL']
USE_PREPROCESSING = dict['USE_PREPROCESSING']
if USE_PREPROCESSING:
    num_steps = dict['num_steps']  # optimization steps for substitute generation

# FIX: the CTU geometry is used by the coder unconditionally (headers,
# block splitting, log names), so read it outside the preprocessing guard
# to avoid a NameError when USE_PREPROCESSING is disabled.
block_width = dict['CTU_size']
block_height = dict['CTU_size']

USE_MULTI_HYPER = dict['USE_MULTI_HYPER']
USE_PREDICTOR = dict['USE_PREDICTOR']

# FIX: `(a and b) is False` is fragile — `and` returns one of its operands,
# not a bool, so falsy non-bool config values (e.g. 0) would fail the check
# incorrectly.  Express the constraints with `not` instead:
# multi-hyper and variable-rate models are mutually exclusive, and the
# predictor requires the multi-hyper model.
assert not (USE_MULTI_HYPER and USE_VR_MODEL)
assert not (USE_PREDICTOR and not USE_MULTI_HYPER)

# index - [0-15]
if USE_VR_MODEL:
    models = ["mse_VR_low", "mse_VR_high", "msssim_VR_low", "msssim_VR_high"]
    max_lambdas = [64, 256, 1.28, 6.40]  # per-model maximum lambda
else:
    models = ["mse200", "mse400", "mse800", "mse1600", "mse3200", "mse6400", "mse12800", "mse25600",
              "msssim4", "msssim8", "msssim16", "msssim32", "msssim64", "msssim128", "msssim320", "msssim640"]
-
- #os.environ["CUDA_VISIBLE_DEVICES"] = "3"
-
- # model_dir = "/model/ljp105/NIC_v01_baseline/"
- # VR_model_dir = "/model/ljp105/NIC_v02_VR_models"
- #@torch.no_grad()
def inference_rd(im_dirs, out_dir, model_dir, model_index, lambda_rd_ori):
    """Encode each image in *im_dirs* block-by-block and log its rate figure.

    Every image is split into CTUs of ``block_width`` x ``block_height``,
    each CTU is padded to a multiple of 64 (network stride), optionally
    flipped/rotated (USE_GEO), optionally replaced by an optimized
    substitute (USE_PREPROCESSING), and then quantized with the selected
    model.  A per-image bit figure ``R / num_pixels`` is appended to a
    '<model>..._RD.txt' log in *out_dir*.

    Args:
        im_dirs: iterable of image file paths (expects RGB images).
        out_dir: directory for 'enc.bin' and the RD log (created if missing).
        model_dir: directory holding '<model>.pkl' / '<model>p.pkl' weights.
        model_index: index into the module-level ``models`` list.
        lambda_rd_ori: R-D lambda; only meaningful when USE_VR_MODEL.
    """
    # FIX: exist_ok avoids the racy exists()-then-makedirs pattern.
    os.makedirs(out_dir, exist_ok=True)

    if USE_VR_MODEL:
        # Clamp the requested lambda to 1.2x the model maximum, then
        # quantize the normalized value to 16 bits so the decoder can
        # reconstruct the exact lambda from the bitstream header.
        lambda_rd_max = max_lambdas[model_index]
        if lambda_rd_ori > 1.2 * lambda_rd_max:
            lambda_rd_ori = 1.2 * lambda_rd_max
        lambda_rd_nom = lambda_rd_ori / lambda_rd_max
        lambda_rd_nom_scaled = int(lambda_rd_nom / 1.2 * pow(2, 16))
        lambda_rd_nom_used = lambda_rd_nom_scaled / pow(2, 16) * 1.2
        lambda_rd_numpy = np.zeros((1, 1), np.float32)
        lambda_rd_numpy[0, 0] = lambda_rd_nom_used
        lambda_rd = torch.Tensor(lambda_rd_numpy)
        # The "high" VR models use wider layers.
        M, N2 = 192, 128
        if model_index in (1, 3):
            M, N2 = 256, 192
        image_comp = model.Image_coding(3, M, N2, M, M // 2)
        context = Weighted_Gaussian(M)
        log_name = os.path.join(
            out_dir,
            models[model_index] + '_lmbda' + str(int(lambda_rd_ori * 100))
            + '_test_' + str(block_width) + '_RD.txt')
    else:
        # Fixed-rate models: indices 6/7 (mse) and 14/15 (msssim) are the
        # high-rate variants with wider layers.
        M, N2 = 192, 128
        if model_index in (6, 7, 14, 15):
            M, N2 = 256, 192
        if USE_MULTI_HYPER:
            if USE_PREDICTOR:
                image_comp = model.Image_coding_multi_hyper_res(3, M, N2, M, M // 2)
            else:
                image_comp = model.Image_coding_multi_hyper(3, M, N2, M, M // 2)
        else:
            image_comp = model.Image_coding(3, M, N2, M, M // 2)
        if USE_PREDICTOR:
            context = Weighted_Gaussian_res(M)
        else:
            context = Weighted_Gaussian(M)
        lambda_rd = None
        log_name = os.path.join(out_dir, models[model_index] + '_test_' + str(block_width) + '_RD.txt')

    if USE_PREPROCESSING:
        # Per-model lambda and optimization step size for substitute search.
        lmbda_list = [200, 400, 800, 1600, 3200, 6400, 12800, 25600, 4, 8, 16, 32, 64, 128, 320, 640]
        stepsize_list = [150, 75, 30, 10, 10, 5, 3, 1, 100, 10, 7, 5, 1, 1, 1, 0.3]

        if USE_VR_MODEL:
            lmbda = lambda_rd_ori * 100
            step_size = stepsize_list[lmbda_list.index(lmbda)]
            reconstruction_metric = 'mse' if model_index <= 1 else 'msssim'
        else:
            step_size = stepsize_list[model_index]
            lmbda = lmbda_list[model_index]
            reconstruction_metric = 'mse' if model_index <= 7 else 'msssim'

        substitute_generator = SubstituteGenerator(model=image_comp, context_model=context, llambda=lmbda,
                                                   num_steps=num_steps, step_size=step_size,
                                                   reconstruct_metric=reconstruction_metric,
                                                   )

    ######################### Load Model #########################
    image_comp.load_state_dict(torch.load(
        os.path.join(model_dir, models[model_index] + r'.pkl'), map_location='cpu'))
    context.load_state_dict(torch.load(
        os.path.join(model_dir, models[model_index] + r'p.pkl'), map_location='cpu'))
    if GPU:
        image_comp.cuda()
        context.cuda()

    #################### Compress Each Image ###################
    for im_dir in im_dirs:
        R = 0  # extra bits accumulated for this image (nec bookkeeping below)
        dec_time = 0
        enc_dec_time_start = time.time()
        bin_dir = os.path.join(out_dir, 'enc.bin')
        file_object = open(bin_dir, 'wb')

        ######################### Read Image #########################
        img = Image.open(im_dir)
        source_img = np.array(img)
        img = source_img / 255.0
        H, W, _ = img.shape
        print(img.shape)
        num_pixels = H * W
        C = 3
        # Header: height, width, model index, four feature flags, CTU size.
        Head = struct.pack('2HB4?H', H, W, model_index, USE_GEO, USE_VR_MODEL, USE_MULTI_HYPER, USE_PREDICTOR,
                           block_width)
        file_object.write(Head)
        if USE_VR_MODEL:
            # 16-bit quantized lambda so the decoder matches the encoder.
            Head_lmbda = struct.pack('H', lambda_rd_nom_scaled)
            file_object.write(Head_lmbda)

        H_offset = 0
        W_offset = 0

        ######################### Split Image into CTUs #########################
        Block_Num_in_Width = int(np.ceil(W / block_width))
        Block_Num_in_Height = int(np.ceil(H / block_height))
        img_block_list = []
        for i in range(Block_Num_in_Height):
            for j in range(Block_Num_in_Width):
                img_block_list.append(img[i * block_height:np.minimum((i + 1) * block_height, H),
                                          j * block_width:np.minimum((j + 1) * block_width, W), ...])

        ######################### Pad and Encode Each CTU #########################
        Block_Idx = 0
        nec_tmp = 0  # channel-constancy mask of the previous CTU (0 before the first)
        for img in img_block_list:  # traverse CTUs
            block_H = img.shape[0]
            block_W = img.shape[1]
            # Pad the CTU up to a multiple of 64 (the network's stride).
            tile = 64.
            block_H_PAD = int(tile * np.ceil(block_H / tile))
            block_W_PAD = int(tile * np.ceil(block_W / tile))
            im = np.zeros([block_H_PAD, block_W_PAD, 3], dtype='float32')
            im[:block_H, :block_W, :] = img[:, :, :3]
            im = torch.FloatTensor(im)
            im = im.permute(2, 0, 1).contiguous()
            im = im.view(1, C, block_H_PAD, block_W_PAD)
            if GPU:
                im = im.cuda()
                if USE_VR_MODEL:
                    lambda_rd = lambda_rd.cuda()
            print('====> Encoding Image:', im_dir, "%dx%d" % (block_H, block_W), 'to', out_dir,
                  " Block Idx: %d" % (Block_Idx))
            Block_Idx += 1

            # A CTU currently maps to a single CU covering the whole padded block.
            im_block_list = [im]
            im_block_loc_list = [[0, 0, block_H_PAD, block_W_PAD]]

            for im_block_loc, im_block in zip(im_block_loc_list, im_block_list):  # traverse CUs
                ################ Geometric flip and rotate ################
                if USE_GEO:
                    _, _, geo_index, _, orig_rd = check_RD_GEO(im_block, lambda_rd, image_comp, context, model_index)
                    i_rot = int(geo_index % 4)
                    if geo_index < 4:
                        im_block = torch.rot90(im_block, k=i_rot, dims=[2, 3])
                    else:
                        im_block = torch.rot90(torch.flip(im_block, dims=[2]), k=i_rot, dims=[2, 3])

                ################ Preprocessing: find a substitute ################
                if USE_PREPROCESSING:
                    if USE_VR_MODEL:
                        im_block = substitute_generator.perturb(orig_image=im_block, orig_rd=orig_rd,
                                                                lambda_rd=lambda_rd)
                    else:
                        im_block = substitute_generator.perturb(orig_image=im_block, orig_rd=orig_rd)

                # NOTE(review): indentation was lost in the original paste;
                # the quantization/nec accounting below uses y_main_q, which
                # only exists on the multi-hyper path, so it is kept inside
                # this branch — confirm against the upstream repository.
                if USE_MULTI_HYPER:
                    with torch.no_grad():
                        y_main, y_hyper, y_hyper_2 = image_comp.encoder(im_block, lambda_rd)

                        # Hyper 2
                        y_hyper_2_q, xp3 = image_comp.factorized_entropy_func(y_hyper_2, 2)

                        # Hyper 1: optionally code a residual against a prediction.
                        tmp2 = image_comp.hyper_2_dec(y_hyper_2_q)
                        if USE_PREDICTOR:
                            y_hyper_predict = image_comp.Y_2(tmp2)
                            y_hyper_res = y_hyper - y_hyper_predict
                            y_hyper_q = torch.round(y_hyper_res)
                        else:
                            y_hyper_q = torch.round(y_hyper)

                        # Main latents, again optionally as a residual.
                        if USE_PREDICTOR:
                            tmp = image_comp.hyper_1_dec(y_hyper_q + y_hyper_predict)
                            y_main_predict = image_comp.Y(tmp)
                            y_main_res = y_main - y_main_predict
                            y_main_q = torch.round(y_main_res)
                        else:
                            tmp = image_comp.hyper_1_dec(y_hyper_q)
                            y_main_q = torch.round(y_main)

                    # FIX: np.int was removed in NumPy 1.24 — use builtin int.
                    y_main_q = torch.Tensor(y_main_q.cpu().numpy().astype(int))
                    if GPU:
                        y_main_q = y_main_q.cuda()

                    _, c, h, w = y_main_q.size()
                    Datas = torch.reshape(y_main_q, [-1]).cpu().numpy().astype(int).tolist()
                    # Mark channels whose quantized values are all identical
                    # ("not-encoded channels" mask).
                    nec = []
                    for i in range(c):
                        current_channel = Datas[i * h * w:(i + 1) * h * w]
                        if min(current_channel) == max(current_channel):
                            nec.append(1)
                        else:
                            nec.append(0)
                    # FIX: only move the mask to GPU when GPU is enabled (the
                    # original called .cuda() unconditionally, crashing CPU runs).
                    nec_idx = torch.LongTensor(nec)
                    if GPU:
                        nec_idx = nec_idx.cuda()
                    # Signalling cost: 1 bit if the mask repeats the previous
                    # CTU's mask, else c+1 bits; the first CTU always costs c.
                    if torch.sum((nec_idx - nec_tmp).pow(2)) == 0:
                        additional_bits = 1
                    else:
                        additional_bits = c + 1
                    if Block_Idx == 1:
                        additional_bits = c
                    nec_tmp = nec_idx
                    print(Block_Idx, additional_bits)
                    R += additional_bits - c

            # Advance the CTU write cursor in raster order.
            W_offset += block_W
            if W_offset >= W:
                W_offset = 0
                H_offset += block_H

        # FIX: close the bitstream file; the original leaked the handle.
        file_object.close()

        class_name = im_dir.split('/')[-2]
        image_name = im_dir.split('/')[-1].replace('.png', '')
        enc_dec_time = time.time() - enc_dec_time_start
        enc_time = enc_dec_time - dec_time  # NOTE(review): computed but never logged

        # FIX: dropped the redundant f.close() — `with` already closes the log.
        with open(log_name, 'a') as f:
            f.write(class_name + '/' + image_name + '\t' + str(R / num_pixels) + '\n')
-
-
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", type=str, default='D:/CV_dataset/NIC_test/', help="Input Image")
    parser.add_argument("-o", "--output", type=str, default='performance/', help="Output Bin(encode)/Image(decode)")
    parser.add_argument("-m_dir", "--model_dir", type=str, default='../Weights_0.4/',
                        help="Directory containing trained models")
    # FIX: help text said "[0-9]" but 16 fixed-rate models (0-15) exist
    # (4 models, 0-3, in VR mode).
    parser.add_argument("-m", "--model", type=int, default=3, help="Model Index [0-15]")
    parser.add_argument("--lambda_rd", type=float, default=16, help="Input lambda for variable-rate models")
    parser.add_argument("--number", type=int, default=24, help="Number of Coding Images")
    args = parser.parse_args()

    # Collect up to `args.number` PNGs from each class sub-directory, plus
    # up to `args.number` image files placed directly in the input directory.
    num = 0
    test_image_paths = []
    for entry in os.listdir(args.input):  # FIX: renamed `dir` — shadowed the builtin
        path = os.path.join(args.input, entry)
        if os.path.isdir(path):
            # FIX: sort so "first N images" is deterministic — glob order is
            # filesystem-dependent.
            test_image_paths += sorted(glob.glob(path + '/*.png'))[0:args.number]
        if os.path.isfile(path):
            if num < args.number:
                test_image_paths.append(path)
                num += 1

    inference_rd(test_image_paths, args.output, args.model_dir, args.model, args.lambda_rd)
|