|
- import argparse
- import math
- import os
- import struct
- import sys
- import time
- import glob
-
- import numpy as np
- import torch
- import torch.nn.functional as F
- from PIL import Image
-
- # import Util.AE as AE
- #import AE
- import Model.model as model
- from Model.context_model import Weighted_Gaussian
- from Util.metrics import evaluate
- from Util import torch_msssim
- from Util.block_metric import check_RD_GEO # block-based metric
- from Util.config import dict
- from Util.generate_substitute import SubstituteGenerator
-
- ############################## Load Configuration Parameters ########################
# Truthy -> encode() moves each input block tensor to CUDA.
GPU = dict['GPU']

# Feature switches; neither is referenced in the visible part of this file.
USE_GEO = False
USE_MULTI_HYPER = False

# Size in pixels of the blocks each image is tiled into before coding.
block_width = block_height = 512
-
-
def _collect_image_paths(input_root, set_name):
    """Return the image paths of the test set named *set_name* under *input_root*.

    A matching directory contributes its ``*.png`` files in sorted order; a
    matching plain file contributes itself.
    """
    paths = []
    for entry in os.listdir(input_root):
        if entry != set_name:
            continue
        path = os.path.join(input_root, entry)
        if os.path.isdir(path):
            # glob yields files in arbitrary order; sort so that a positional
            # image index always refers to the same file.
            paths += sorted(glob.glob(os.path.join(path, '*.png')))
        if os.path.isfile(path):
            paths.append(path)
    return paths


def _split_into_blocks(image, block_h, block_w):
    """Cut *image* into row-major ``(top, left, block)`` tiles of at most block_h x block_w."""
    height, width = image.shape[:2]
    return [
        (top, left, image[top:top + block_h, left:left + block_w, ...])
        for top in range(0, height, block_h)
        for left in range(0, width, block_w)
    ]


def _pad_block_to_tensor(block, tile=64):
    """Zero-pad *block* (bottom/right) to multiples of *tile*; return a 1x3xHxW float32 tensor."""
    block_h, block_w = block.shape[:2]
    padded_h = tile * math.ceil(block_h / tile)
    padded_w = tile * math.ceil(block_w / tile)
    padded = np.zeros([padded_h, padded_w, 3], dtype='float32')
    padded[:block_h, :block_w, :] = block[:, :, :3]
    return torch.from_numpy(padded).permute(2, 0, 1).contiguous().unsqueeze(0)


def encode(out_dir):
    """Code one image of the selected test set block by block and print rate/PSNR.

    The image is tiled into ``block_height`` x ``block_width`` blocks. Each
    block is zero-padded to a multiple of 64, run through the encoder, the
    mask network and the decoder, and the decoded pixels are pasted back into
    a full-size reconstruction. The estimated bits of all blocks are summed
    and divided by the pixel count of the whole image to obtain bpp.

    Reads the module-level ``args`` (``input``, ``test_set``, ``lmd``) and the
    configuration globals ``GPU``, ``block_width`` and ``block_height``.
    Checkpoints are loaded from the current working directory.

    Args:
        out_dir: Accepted for interface compatibility; not used, and nothing
            is written to disk.
    """
    # Honour the GPU flag everywhere (models, checkpoints, inputs) so that a
    # CPU run does not fail on a device mismatch.
    device = torch.device('cuda' if GPU else 'cpu')

    set_names = ['ClassA_6k', 'ClassB_4k', 'ClassC_2k', 'ClassD_Kodak']
    im_dirs = _collect_image_paths(args.input, set_names[args.test_set])

    M, N2 = 256, 192

    NIC = model.Image_coding(3, M, N2, M, M // 2).to(device)
    context = Weighted_Gaussian(M).to(device)
    modnet = model.modnet(M).to(device)
    context_mod = Weighted_Gaussian(M).to(device)

    ######################### Load Model #########################
    modnet.load_state_dict(torch.load('modnet.pkl', map_location=device))
    context_mod.load_state_dict(torch.load('context_mod.pkl', map_location=device))
    NIC.load_state_dict(torch.load('mse25600.pkl', map_location=device))
    context.load_state_dict(torch.load('mse25600p.pkl', map_location=device))

    # Loop-invariant conditioning value handed to the mask network.
    lmd = (torch.ones((1, 1, 1)) * args.lmd).to(device)

    # NOTE(review): only the image at this position in the list is processed;
    # the --pic_idx argument is parsed but not consulted here. Confirm whether
    # it was meant to drive this selection.
    target_idx = 23

    #################### Compress Each Image ###################
    bpp_list = []
    rgb_psnr_list = []

    for idx, im_dir in enumerate(im_dirs):
        if idx != target_idx:
            continue
        dec_time = 0
        enc_dec_time_start = time.time()

        ######################### Read Image #########################
        # Close the file handle promptly and force three channels so that
        # palette / grayscale / RGBA inputs match the 3-channel reconstruction.
        with Image.open(im_dir) as handle:
            source_img = np.array(handle.convert('RGB'))
        image = source_img / 255.0
        H, W, _ = image.shape
        print(image.shape)
        num_pixels = H * W

        out_img = np.zeros([H, W, 3])  # reconstructed image
        bpp = 0

        ######################### Splitting Image #########################
        for top, left, block in _split_into_blocks(image, block_height, block_width):
            block_H, block_W = block.shape[:2]
            im = _pad_block_to_tensor(block).to(device)

            with torch.no_grad():
                x1, x2 = NIC.encoder(im)
                xq2, xp2 = NIC.factorized_entropy_func(x2, 2)
                x3 = NIC.hyper_dec(xq2)
                hyper_dec = NIC.p(x3)
                xq1 = torch.round(x1)
                xp1, _ = context(xq1, hyper_dec)
                modnet_mask = modnet(xq1, xp1, lmd)
                xq1 = xq1 * modnet_mask
                xq1 = torch.round(xq1)  # quantize once more after masking
                xp1, _ = context_mod(xq1, hyper_dec)
                fake = NIC.decoder(xq1)

                # -log2 of the two probability tensors, summed over all elements.
                bits = (torch.sum(torch.log(xp1)) / (-np.log(2))
                        + torch.sum(torch.log(xp2)) / (-np.log(2)))
                bpp += (bits / num_pixels).item()

                ############################ Reconstruct Image ############################
                # Only this clamp/copy/paste step is accounted as decoding time.
                dec_time_start = time.time()
                output_ = torch.clamp(fake, min=0., max=1.0)
                out = output_.data[0].cpu().numpy().transpose(1, 2, 0)
                # Drop the padding again when pasting the block back.
                out_img[top:top + block_H, left:left + block_W, :] = out[:block_H, :block_W, :]
                dec_time += time.time() - dec_time_start

        out_img = np.round(out_img * 255.0).astype('uint8')

        # calculate bpp, psnr, msssim
        rgb_psnr, _rgb_msssim, _yuv_psnr, _y_msssim = evaluate(source_img, out_img)
        bpp_list.append(bpp)
        rgb_psnr_list.append(rgb_psnr)

        # os.path handles both '/' and the platform separator that glob inserts.
        class_name = os.path.basename(os.path.dirname(im_dir))
        image_name = os.path.basename(im_dir).replace('.png', '')
        enc_dec_time = time.time() - enc_dec_time_start
        enc_time = enc_dec_time - dec_time

        print(class_name + '/' + image_name + '\t' + str(bpp) + '\t' + str(rgb_psnr) + '\t' +
              str(enc_time) + '\t' + str(dec_time) + '\n')

    if not bpp_list:
        # Avoid np.mean([]) -> nan plus a RuntimeWarning when nothing matched.
        print("no image processed: the test set has no image at index", target_idx)
        return

    print("mean of bpp:", np.mean(bpp_list))
    print("mean of psnr:", np.mean(rgb_psnr_list))
-
-
# Script entry point: parse the command line into the module-level ``args``
# (which encode() reads) and run the evaluation.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Path options.
    parser.add_argument(
        "-i", "--input",
        type=str,
        default='D:/CV_dataset/NIC_test/',
        help="Input Image",
    )
    parser.add_argument(
        "-o", "--output",
        type=str,
        default='./',
        help="Output Bin(encode)/Image(decode)",
    )

    # Integer options, registered in the same order as before.
    for flag, default_value in (("--lmd", 1), ("--test_set", 3), ("--pic_idx", 0)):
        parser.add_argument(flag, type=int, default=default_value)

    args = parser.parse_args()

    encode(args.output)
|