|
- import argparse
- import math
- import os
- import struct
- import sys
- import time
- import glob
-
- import numpy as np
- import torch
- import torch.nn.functional as F
- from PIL import Image
-
- # import Util.AE as AE
- import AE
- import Model.model as model
- from Model.context_model import Weighted_Gaussian
- from Util.metrics import evaluate
- from Util import torch_msssim
- from Util.block_metric import check_RD_GEO # blcok based metric
- from Util.config import dict
- from Util.generate_substitute import SubstituteGenerator
-
os.environ['LRU_CACHE_CAPACITY'] = '1'
############################## Load Configuration Parameters ########################
# `dict` here is the configuration mapping imported from Util.config.
GPU = dict['GPU']
SAVE_REC = dict['SAVE_REC']
USE_GEO = dict['USE_GEO']
USE_VR_MODEL = dict['USE_VR_MODEL']
USE_PREPROCESSING = dict['USE_PREPROCESSING']
if USE_PREPROCESSING:
    num_steps = dict['num_steps']

# encode() always tiles the image into CTUs of this size and writes it into the
# bitstream header, so the block size is needed whether or not preprocessing
# is enabled.
block_width = dict['CTU_size']
block_height = dict['CTU_size']

USE_MULTI_HYPER = dict['USE_MULTI_HYPER']

# Explicit check instead of `assert (...) is False`: the identity test misfires
# for non-bool config values (e.g. `0 is False` is False), and asserts are
# stripped when Python runs with -O.
if USE_MULTI_HYPER and USE_VR_MODEL:
    raise ValueError('USE_MULTI_HYPER and USE_VR_MODEL cannot both be enabled')
-
-
def _cdf_table(lower, sample, min_v, max_v, precise=16):
    """Convert sampled CDF values into the integer table consumed by ``AE``.

    Args:
        lower: array of CDF values evaluated at ``sample - 0.5``; the last
            axis runs over the symbols ``min_v .. max_v + 1``.
        sample: integer-valued array broadcastable against ``lower``.
        min_v: smallest symbol being coded.
        max_v: largest symbol being coded.
        precise: bit precision of the arithmetic coder's frequency range.

    Returns:
        int32 array ``trunc(lower * ((1 << precise) - (max_v - min_v + 1)))
        + sample - min_v``. Adding the per-symbol offset keeps every symbol's
        interval non-empty.
    """
    # int32, not int16: the scaled values reach 1 << precise (65536 for
    # precise=16), which overflows int16 and corrupts the table.
    scaled = np.asarray(lower) * ((1 << precise) - (max_v - min_v + 1))
    return scaled.astype(np.int32) + np.asarray(sample).astype(np.int32) - min_v


def _cdf_bounds(symbols, table, min_v):
    """Return ``(Cdf[x - min_v], Cdf[x + 1 - min_v])`` lists, one row per symbol."""
    lower = [int(row[x - min_v]) for x, row in zip(symbols, table)]
    upper = [int(row[x - min_v + 1]) for x, row in zip(symbols, table)]
    return lower, upper


@torch.no_grad()
def encode(out_dir):
    """Encode the first image of the selected test set into ``out_dir/enc.bin``.

    Reads the module-level ``args`` (``input``, ``test_set``, ``lmd``) set in
    ``__main__`` and the configuration globals ``GPU``, ``SAVE_REC``,
    ``block_width`` and ``block_height``.

    Bitstream layout: a ``'4H'`` header ``(H, W, lmd, CTU size)`` followed,
    for every CTU in raster order, by a ``'6h3I'`` header (symbol ranges and
    sub-stream sizes) and the ``main``, ``hyper_1`` and ``hyper_2``
    arithmetic-coded sub-streams. The sub-streams are produced through the
    temporary files ``main.bin``, ``hyper_1.bin`` and ``hyper_2.bin`` in the
    current working directory.

    Args:
        out_dir: directory receiving ``enc.bin`` (and ``dec.png`` when
            ``SAVE_REC`` is set).

    Returns:
        ``(y_hyper_2_q, y_hyper_q)``: the quantized hyper latents of the last
        coded CTU, as CPU float tensors.

    Raises:
        FileNotFoundError: if no image is found for the selected test set.
    """
    device = torch.device('cuda' if GPU else 'cpu')
    precise = 16
    tile = 64

    # ---------------- Collect test images ----------------
    dir_list = ['ClassA_6k', 'ClassB_4k', 'ClassC_2k', 'ClassD_Kodak']
    target = dir_list[args.test_set]
    test_image_paths = []
    for entry in os.listdir(args.input):
        if entry != target:
            continue
        path = os.path.join(args.input, entry)
        if os.path.isdir(path):
            # Sorted so the "first image" below is deterministic.
            test_image_paths += sorted(glob.glob(path + '/*.png'))
        if os.path.isfile(path):
            test_image_paths.append(path)
    if not test_image_paths:
        raise FileNotFoundError('no test image found for %s under %s' % (target, args.input))
    im_dirs = test_image_paths

    # ---------------- Load model ----------------
    M, N2 = 256, 192
    NIC_modformer = model.Image_coding_multi_hyper_modformer(3, M, N2, M, M // 2)
    NIC_modformer.load_state_dict(torch.load('Weights/modformer_backup.pkl', map_location='cpu'))
    NIC_modformer.to(device)

    # ---------------- Compress each image ----------------
    bpp_list = []
    rgb_psnr_list = []
    rgb_msssim_list = []
    enc_time_list = []
    dec_time_list = []
    # Only the first image is coded: every image would write the same enc.bin.
    for im_dir in im_dirs[:1]:
        dec_time = 0
        enc_dec_time_start = time.time()
        bin_dir = os.path.join(out_dir, 'enc.bin')
        rec_dir = os.path.join(out_dir, 'dec.png')

        # Read image
        with Image.open(im_dir) as pil_img:
            source_img = np.array(pil_img)
        img = source_img / 255.0
        H, W, _ = img.shape
        print(img.shape)
        num_pixels = H * W
        C = 3

        out_img = np.zeros([H, W, C])  # reconstructed image
        H_offset = 0
        W_offset = 0

        # Split the image into CTUs in raster order (edge CTUs may be smaller).
        Block_Num_in_Width = int(np.ceil(W / block_width))
        Block_Num_in_Height = int(np.ceil(H / block_height))
        img_block_list = [
            img[i * block_height:min((i + 1) * block_height, H),
                j * block_width:min((j + 1) * block_width, W), ...]
            for i in range(Block_Num_in_Height)
            for j in range(Block_Num_in_Width)
        ]

        with open(bin_dir, 'wb') as file_object:
            file_object.write(struct.pack('4H', H, W, args.lmd, block_width))

            for Block_Idx, block in enumerate(img_block_list):  # traverse CTUs
                # Pad the CTU up to a multiple of 64 in both dimensions.
                block_H, block_W = block.shape[0], block.shape[1]
                block_H_PAD = -(-block_H // tile) * tile
                block_W_PAD = -(-block_W // tile) * tile
                im = np.zeros([block_H_PAD, block_W_PAD, 3], dtype='float32')
                im[:block_H, :block_W, :] = block[:, :, :3]
                im = torch.FloatTensor(im).permute(2, 0, 1).contiguous()
                im = im.view(1, C, block_H_PAD, block_W_PAD).to(device)

                print('====> Encoding Image:', im_dir, "%dx%d" % (block_H, block_W), 'to', out_dir,
                      " Block Idx: %d" % (Block_Idx))

                # ---------- Analysis transform, hyper priors and mask ----------
                x1, x2, x3 = NIC_modformer.encoder(im)
                xq3, _ = NIC_modformer.factorized_entropy_func(x3, 2)
                hyper_2_dec = NIC_modformer.p_2(NIC_modformer.hyper_2_dec(xq3))
                xq2 = torch.round(x2)
                x5 = NIC_modformer.hyper_1_dec(xq2)
                lmd = (torch.ones((1, 1, 1)) * args.lmd).to(device)
                mask = NIC_modformer.modformer_mask(x5, lmd)

                x1 = x1 * mask
                hyper_dec = NIC_modformer.p(x5)
                y_main_q = torch.round(x1)
                rec_img_block = NIC_modformer.decoder(y_main_q)

                # ---------- Reconstruction (timed as decoding) ----------
                dec_time_start = time.time()
                output_ = torch.clamp(rec_img_block, min=0., max=1.0)
                out = output_.data[0].cpu().numpy().transpose(1, 2, 0)
                out_img[H_offset:H_offset + block_H, W_offset:W_offset + block_W, :] = \
                    out[:block_H, :block_W, :]
                dec_time += time.time() - dec_time_start

                Datas = torch.reshape(y_main_q, [-1]).cpu().numpy().astype(np.int64)
                masks = torch.reshape(mask, [-1]).cpu().numpy().astype(np.int64)
                e_data = Datas[masks == 1].tolist()  # only unmasked symbols are coded
                _, c, h, w = y_main_q.shape

                y_hyper_q = torch.Tensor(xq2.cpu().numpy().astype(np.int64))
                y_hyper_2_q = torch.Tensor(xq3.cpu().numpy().astype(np.int64))

                _, params_prob = NIC_modformer.context(y_main_q, hyper_dec)
                params_prob = torch.reshape(params_prob, (1, 9, -1)).cpu().numpy()[:, :, masks == 1]
                params_prob = torch.from_numpy(params_prob)

                # ---------- Main arithmetic encoding (3-component Gaussian mixture) ----------
                print("Main Channel:", c)
                Max_Main = int(Datas.max())
                Min_Main = int(Datas.min())

                prob0, mean0, scale0, prob1, mean1, scale1, prob2, mean2, scale2 = [
                    chunk.squeeze(1) for chunk in torch.chunk(params_prob, 9, dim=1)]
                del params_prob
                # Mixture weights must sum to 1.
                probs = F.softmax(torch.stack([prob0, prob1, prob2], dim=-1), dim=-1)
                # Scales must be strictly positive.
                scale0, scale1, scale2 = (torch.abs(s).clamp(min=1e-6) for s in (scale0, scale1, scale2))

                # Symbols sampled on [Min - 0.5, Max + 0.5].
                sample = np.arange(Min_Main, Max_Main + 1 + 1)
                sample = torch.FloatTensor(np.tile(sample, [1, len(e_data), 1]))

                m0 = torch.distributions.normal.Normal(mean0, scale0)
                m1 = torch.distributions.normal.Normal(mean1, scale1)
                m2 = torch.distributions.normal.Normal(mean2, scale2)
                lower = torch.zeros(1, len(e_data), Max_Main - Min_Main + 2)
                for i in range(sample.shape[2]):
                    edge = sample[:, :, i] - 0.5
                    lower[:, :, i] = (probs[:, :, 0] * m0.cdf(edge) + probs[:, :, 1] * m1.cdf(edge)
                                      + probs[:, :, 2] * m2.cdf(edge))

                cdf_main = _cdf_table(lower.cpu().numpy(), sample.numpy(), Min_Main, Max_Main,
                                      precise).reshape(len(e_data), -1)
                Cdf_lower, Cdf_upper = _cdf_bounds(e_data, cdf_main, Min_Main)
                AE.encode_cdf(Cdf_lower, Cdf_upper, "main.bin")
                FileSizeMain = os.path.getsize("main.bin")
                del Cdf_lower, Cdf_upper, Datas, e_data, probs, lower
                print("main.bin: %d bytes" % (FileSizeMain))

                # ---------- Hyper 1 arithmetic encoding (single Gaussian) ----------
                Datas = torch.reshape(y_hyper_q, [-1]).cpu().numpy().astype(np.int64).tolist()
                Max_HYPER_1 = max(Datas)
                Min_HYPER_1 = min(Datas)
                _, c, h, w = y_hyper_q.shape
                print("Hyper 1 Channel:", c)
                sample = np.arange(Min_HYPER_1, Max_HYPER_1 + 1 + 1)  # [Min - 0.5, Max + 0.5]
                sample = torch.FloatTensor(np.tile(sample, [1, c, h, w, 1])).to(device)

                mean = hyper_2_dec[:, :c, :, :]
                scale = torch.abs(hyper_2_dec[:, c:, :, :]).clamp(min=1e-6)
                m_notgzq = torch.distributions.normal.Normal(mean, scale)
                lower = torch.zeros(1, c, h, w, Max_HYPER_1 - Min_HYPER_1 + 2, device=device)
                for ii in range(sample.shape[4]):
                    lower[..., ii] = m_notgzq.cdf(sample[..., ii] - 0.5)

                cdf_main = _cdf_table(lower.cpu().numpy(), sample.cpu().numpy(), Min_HYPER_1,
                                      Max_HYPER_1, precise).reshape(len(Datas), -1)
                Cdf_lower, Cdf_upper = _cdf_bounds(Datas, cdf_main, Min_HYPER_1)
                AE.encode_cdf(Cdf_lower, Cdf_upper, "hyper_1.bin")
                FileSizeHyper1 = os.path.getsize("hyper_1.bin")
                print("hyper_1.bin: %d bytes" % (FileSizeHyper1))

                # ---------- Hyper 2 arithmetic encoding (factorized prior) ----------
                Min_HYPER_2 = torch.min(y_hyper_2_q).cpu().numpy().astype(np.int16).tolist()
                Max_HYPER_2 = torch.max(y_hyper_2_q).cpu().numpy().astype(np.int16).tolist()
                _, c, h, w = y_hyper_2_q.shape
                Datas_hyper = torch.reshape(y_hyper_2_q, [c, -1]).cpu().numpy().astype(np.int16).tolist()
                sample = np.tile(np.arange(Min_HYPER_2, Max_HYPER_2 + 1 + 1), [c, 1, 1])  # [c, 1, K]
                lower = torch.sigmoid(NIC_modformer.factorized_entropy_func._logits_cumulative(
                    torch.FloatTensor(sample).to(device) - 0.5, stop_gradient=False))
                cdf_h = _cdf_table(lower.cpu().numpy(), sample, Min_HYPER_2, Max_HYPER_2, precise)

                # The factorized prior is shared by all positions of a channel,
                # so each channel uses a single CDF row.
                Cdf_0, Cdf_1 = [], []
                for i in range(c):
                    row = cdf_h[i, 0]
                    Cdf_0.extend(int(row[x - Min_HYPER_2]) for x in Datas_hyper[i])
                    Cdf_1.extend(int(row[x - Min_HYPER_2 + 1]) for x in Datas_hyper[i])
                AE.encode_cdf(Cdf_0, Cdf_1, "hyper_2.bin")
                FileSizeHyper2 = os.path.getsize("hyper_2.bin")
                print("hyper_2.bin: %d bytes" % (FileSizeHyper2))

                # ---------- Write CTU header and sub-streams ----------
                # Signed 'h': the minima come from torch.round of network outputs
                # and can be negative, which 'H' rejects with struct.error. For
                # non-negative values the bytes equal the former '6H' layout.
                Head_block = struct.pack('6h3I', Max_Main, Min_Main, Min_HYPER_1, Max_HYPER_1,
                                         Min_HYPER_2, Max_HYPER_2,
                                         FileSizeMain, FileSizeHyper1, FileSizeHyper2)
                file_object.write(Head_block)
                for part in ("main.bin", "hyper_1.bin", "hyper_2.bin"):
                    with open(part, 'rb') as f:
                        file_object.write(f.read())

                # Advance to the next CTU in raster order.
                W_offset += block_W
                if W_offset >= W:
                    W_offset = 0
                    H_offset += block_H

        out_img = np.round(out_img * 255.0).astype('uint8')[:H, :W, :]
        if SAVE_REC:
            Image.fromarray(out_img).save(rec_dir)

        # ---------- Metrics ----------
        bpp = os.path.getsize(bin_dir) * 8. / num_pixels
        print(bpp)

        rgb_psnr, rgb_msssim, _, _ = evaluate(source_img, out_img)
        bpp_list.append(bpp)
        rgb_psnr_list.append(rgb_psnr)
        rgb_msssim_list.append(rgb_msssim)

        class_name = im_dir.split('/')[-2]
        image_name = im_dir.split('/')[-1].replace('.png', '')
        enc_dec_time = time.time() - enc_dec_time_start
        enc_time = enc_dec_time - dec_time
        enc_time_list.append(enc_time)
        dec_time_list.append(dec_time)

        print(class_name + '/' + image_name + '\t' + str(bpp) + '\t' + str(rgb_psnr) + '\t' +
              str(enc_time) + '\t' + str(dec_time) + '\n')
        del out_img

    print("mean of bpp:", np.mean(bpp_list))
    print("mean of psnr:", np.mean(rgb_psnr_list))
    print("mean of msssim:", np.mean(rgb_msssim_list))
    print("mean of enctime:", np.mean(enc_time_list))
    return y_hyper_2_q, y_hyper_q
-
-
@torch.no_grad()
def decode(bin_dir, rec_dir):
    """Decode a bitstream written by ``encode`` and save the image to ``rec_dir``.

    The file starts with a ``'4H'`` header ``(H, W, lmd, CTU size)``. For every
    CTU in raster order it then holds a ``'6h3I'`` header followed by the
    ``main``, ``hyper_1`` and ``hyper_2`` sub-streams. Each sub-stream is
    copied to a temporary file (``main.bin`` etc. in the current working
    directory) for ``AE``. The hyper-2, hyper-1 and main latents are decoded
    in that order, and the reconstruction is written into the output image.

    Uses the configuration global ``GPU`` and loads the model weights from
    ``Weights/modformer_backup.pkl``. When the module-level names
    ``encode_y_hyper_2_q`` / ``encode_y_hyper_q`` exist with matching shapes,
    the squared error against them is printed as a debug check.

    Args:
        bin_dir: path of the bitstream to read.
        rec_dir: path where the reconstructed image is saved.
    """
    device = torch.device('cuda' if GPU else 'cpu')
    precise = 16
    tile = 64

    with open(bin_dir, 'rb') as file_object:
        # ---------------- Retrieve head info ----------------
        H, W, lmd, CTU_size = struct.unpack('4H', file_object.read(struct.calcsize('4H')))
        print(H, W)

        block_width = CTU_size
        block_height = CTU_size
        C = 3
        out_img = np.zeros([H, W, C])
        H_offset = 0
        W_offset = 0
        Block_Num_in_Width = int(np.ceil(W / block_width))
        Block_Num_in_Height = int(np.ceil(H / block_height))

        # ---------------- Load model ----------------
        M, N2 = 256, 192
        NIC_modformer = model.Image_coding_multi_hyper_modformer(3, M, N2, M, M // 2)
        lambda_rd = None
        c_main = M
        c_hyper_2 = 128
        NIC_modformer.load_state_dict(torch.load('Weights/modformer_backup.pkl', map_location='cpu'))
        NIC_modformer = NIC_modformer.to(device)

        block_head_len = struct.calcsize('6h3I')

        for i_block in range(Block_Num_in_Height):
            for j_block in range(Block_Num_in_Width):
                # [block_H, block_W] is the real (unpadded) size of the current CTU.
                block_H = block_height
                block_W = block_width
                if i_block == Block_Num_in_Height - 1:
                    block_H = H - (Block_Num_in_Height - 1) * block_height
                if j_block == Block_Num_in_Width - 1:
                    block_W = W - (Block_Num_in_Width - 1) * block_width
                print('==================> Decoding Block:', "(%d, %d)" % (i_block, j_block),
                      "[%d, %d]" % (block_H, block_W))

                # The encoder codes the whole CTU, padded to a multiple of 64.
                enc_height = -(-block_H // tile) * tile
                enc_width = -(-block_W // tile) * tile
                print('==================> Decoding sub_block:',
                      "(%d, %d, %d, %d)" % (0, 0, enc_height, enc_width))

                # Signed 'h' matches the encoder. It reads non-negative values
                # exactly like the former '6H' layout.
                (Max_Main, Min_Main, Min_HYPER_1, Max_HYPER_1, Min_HYPER_2, Max_HYPER_2,
                 FileSizeMain, FileSizeHyper1, FileSizeHyper2) = struct.unpack(
                    '6h3I', file_object.read(block_head_len))

                for part, size in (("main.bin", FileSizeMain),
                                   ("hyper_1.bin", FileSizeHyper1),
                                   ("hyper_2.bin", FileSizeHyper2)):
                    with open(part, 'wb') as f:
                        f.write(file_object.read(size))

                # ---------------- Hyper 2 decoder (factorized prior) ----------------
                sample = np.tile(np.arange(Min_HYPER_2, Max_HYPER_2 + 1 + 1), [c_hyper_2, 1, 1])
                lower = torch.sigmoid(NIC_modformer.factorized_entropy_func._logits_cumulative(
                    torch.FloatTensor(sample).to(device) - 0.5, stop_gradient=False))
                cdf_h = lower.cpu().numpy() * ((1 << precise) - (Max_HYPER_2 - Min_HYPER_2 + 1))
                cdf_h = cdf_h.astype(np.int64) + sample.astype(np.int64) - Min_HYPER_2

                AE.init_decoder("hyper_2.bin", Min_HYPER_2, Max_HYPER_2)
                n_hyper_2 = (enc_height // 64) * (enc_width // 64)  # positions per channel
                Recons = []
                for ii in range(c_hyper_2):
                    row = cdf_h[ii, 0, :].tolist()
                    for _ in range(n_hyper_2):
                        Recons.append(AE.decode_cdf(row))
                y_hyper_2_q = torch.reshape(torch.Tensor(Recons),
                                            [1, c_hyper_2, enc_height // 64, enc_width // 64])

                # Debug check against the encoder's output, only when available
                # and comparable (encode() returns the last CTU's latents only).
                ref = globals().get('encode_y_hyper_2_q')
                if ref is not None and ref.shape == y_hyper_2_q.shape:
                    print(torch.sum((y_hyper_2_q - ref) ** 2))

                # ---------------- Hyper 1 decoder (single Gaussian) ----------------
                hyper_2_dec = NIC_modformer.p_2(NIC_modformer.hyper_2_dec(y_hyper_2_q.to(device)))
                _, c, h, w = hyper_2_dec.shape
                c //= 2  # first half: means, second half: scales
                mean = hyper_2_dec[:, :c, :, :]
                scale = torch.abs(hyper_2_dec[:, c:, :, :]).clamp(min=1e-6)
                m = torch.distributions.normal.Normal(mean, scale)

                sample = np.arange(Min_HYPER_1, Max_HYPER_1 + 1 + 1)  # [Min - 0.5, Max + 0.5]
                sample = torch.FloatTensor(np.tile(sample, [1, c, h, w, 1])).to(device)
                lower = torch.zeros(1, c, h, w, Max_HYPER_1 - Min_HYPER_1 + 2, device=device)
                for cc in range(sample.shape[-1]):
                    lower[..., cc] = m.cdf(sample[..., cc] - 0.5)

                cdf_m = lower.cpu().numpy() * ((1 << precise) - (Max_HYPER_1 - Min_HYPER_1 + 1))
                cdf_m = cdf_m.astype(np.int32) + sample.cpu().numpy().astype(np.int32) - Min_HYPER_1

                AE.init_decoder("hyper_1.bin", Min_HYPER_1, Max_HYPER_1)
                Recons = [AE.decode_cdf(cdf_m[0, ii, jj, kk, :].tolist())
                          for ii in range(c) for jj in range(h) for kk in range(w)]
                y_hyper_q = torch.reshape(torch.Tensor(Recons), [1, c, h, w])

                ref = globals().get('encode_y_hyper_q')
                if ref is not None and ref.shape == y_hyper_q.shape:
                    print(torch.sum((y_hyper_q - ref) ** 2))

                # ---------------- Main decoder (context model + mixture) ----------------
                y_hyper_q = y_hyper_q.to(device)
                x5 = NIC_modformer.hyper_1_dec(y_hyper_q)
                hyper_dec = NIC_modformer.p(x5)
                lmd_info = (torch.ones((1, 1, 1)) * lmd).to(device)
                mask = NIC_modformer.modformer_mask(x5, lmd_info)
                masks = torch.reshape(mask, [-1]).cpu().numpy().astype(np.int64)

                h, w = enc_height // 16, enc_width // 16
                # Zero-padded by 5 on every side so the 12^3 context window
                # slices below never leave the tensor.
                y_main_q = torch.zeros(1, 1, c_main + 10, h + 10, w + 10, device=device)

                hyper = torch.unsqueeze(NIC_modformer.context.conv3(hyper_dec), dim=1)
                print("check4")
                # Apply the context-model weight mask (presumably a 0/1 causal
                # mask, making the repeat per CTU harmless; TODO confirm).
                NIC_modformer.context.conv1.weight.data *= NIC_modformer.context.conv1.mask
                conv1 = NIC_modformer.context.conv1

                AE.init_decoder("main.bin", Min_Main, Max_Main)
                sample = torch.FloatTensor(np.arange(Min_Main, Max_Main + 1 + 1))  # [Min - 0.5, Max + 0.5]
                sample_int = sample.numpy().astype(np.int64)
                n_levels = Max_Main - Min_Main + 2
                cdf_scale = (1 << precise) - (Max_Main - Min_Main + 1)

                # Visit every latent position in (c, h, w) raster order. The former
                # bound `np.sum(masks)` is only correct when the active positions
                # form a prefix of that order. Masked positions are not coded and
                # stay zero (y_main_q is zero-initialised).
                for i in range(masks.size):
                    if masks[i] == 0:
                        continue
                    ci, rem = divmod(i, h * w)
                    hi, wi = divmod(rem, w)

                    x1 = F.conv3d(y_main_q[:, :, ci:ci + 12, hi:hi + 12, wi:wi + 12],
                                  weight=conv1.weight, bias=conv1.bias)
                    params_prob = NIC_modformer.context.conv2(
                        torch.cat((x1, hyper[:, :, ci:ci + 2, hi:hi + 2, wi:wi + 2]), dim=1)).cpu()
                    # 3-component Gaussian mixture
                    prob0, mean0, scale0, prob1, mean1, scale1, prob2, mean2, scale2 = \
                        params_prob[0, :, 0, 0, 0]
                    # Mixture weights must sum to 1; scales must be strictly positive.
                    probs = F.softmax(torch.stack([prob0, prob1, prob2], dim=-1), dim=-1)
                    scale0, scale1, scale2 = (torch.abs(s).clamp(min=1e-6)
                                              for s in (scale0, scale1, scale2))

                    m0 = torch.distributions.normal.Normal(
                        mean0.view(1, 1).repeat(1, n_levels), scale0.view(1, 1).repeat(1, n_levels))
                    m1 = torch.distributions.normal.Normal(
                        mean1.view(1, 1).repeat(1, n_levels), scale1.view(1, 1).repeat(1, n_levels))
                    m2 = torch.distributions.normal.Normal(
                        mean2.view(1, 1).repeat(1, n_levels), scale2.view(1, 1).repeat(1, n_levels))
                    lower = (probs[0:1] * m0.cdf(sample - 0.5) + probs[1:2] * m1.cdf(sample - 0.5)
                             + probs[2:3] * m2.cdf(sample - 0.5))

                    cdf_m = (lower.numpy() * cdf_scale).astype(np.int64) + sample_int - Min_Main
                    y_main_q[0, 0, ci + 5, hi + 5, wi + 5] = AE.decode_cdf(cdf_m[0, :].tolist())

                del hyper, hyper_dec

                y_main_q = y_main_q[0, :, 5:-5, 5:-5, 5:-5]  # drop padding -> [1, c, h, w]
                rec = NIC_modformer.decoder(y_main_q, lambda_rd)
                out = torch.clamp(rec, min=0., max=1.0).data[0].cpu().numpy().transpose(1, 2, 0)
                out_img[H_offset:H_offset + block_H, W_offset:W_offset + block_W, :] = \
                    out[:block_H, :block_W, :]

                # Advance to the next CTU in raster order.
                W_offset += block_W
                if W_offset >= W:
                    W_offset = 0
                    H_offset += block_H

    print('Decoding success!')
    out_img = np.round(out_img * 255.0).astype('uint8')
    Image.fromarray(out_img[:H, :W, :]).save(rec_dir)
    del out_img
-
-
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", type=str, default='/home/gzq/v03_nic/NIC_test', help="Input Image")
    parser.add_argument("-o", "--output", type=str, default='./', help="Output Bin(encode)/Image(decode)")
    parser.add_argument("--lmd", type=int, default=1)
    parser.add_argument("--test_set", type=int, default=3)

    # encode() reads this module-level `args`.
    args = parser.parse_args()

    # decode() reads these module-level names for its debug comparison.
    encode_y_hyper_2_q, encode_y_hyper_q = encode(args.output)

    # encode() writes <output>/enc.bin, so decode from the same directory
    # rather than from the current working directory.
    bin_dir = os.path.join(args.output, 'enc.bin')
    rec_dir = os.path.join(args.output, 'dec.png')
    decode(bin_dir, rec_dir)
|