|
- import argparse
- import math
- import os
- import struct
- import sys
- import time
- import glob
-
- import numpy as np
- import torch
- import torch.nn.functional as F
- from PIL import Image
-
- # import Util.AE as AE
- #import AE
- from torch.autograd import Variable
- import Model.model as model
- from Model.context_model import Weighted_Gaussian
- from Util.metrics import evaluate
- from Util import torch_msssim
- from Util.block_metric import check_RD_GEO # blcok based metric
- from Util.config import dict
- import torch.nn as nn
- from Util.generate_substitute import SubstituteGenerator
-
############################## Load Configuration Parameters ########################
# NOTE: 'dict' is the configuration mapping imported from Util.config above;
# it shadows the builtin dict at module scope.
GPU = dict['GPU']            # truthy -> move models/tensors to CUDA
USE_GEO = dict['USE_GEO']
block_width = 1024           # CTU tile width used when splitting large images
block_height = 1024          # CTU tile height
USE_MULTI_HYPER = False

# Pin the process to a single visible CUDA device.
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
-
def meanstd(x):
    """Return the mean and population standard deviation of *x*.

    Parameters
    ----------
    x : sequence of numbers (non-empty)

    Returns
    -------
    (mean, std) : population std, i.e. ddof=0 (divide by n), matching the
        original E[x^2] - mean^2 formulation.
    """
    a = np.asarray(x, dtype=np.float64)
    # np.std is numerically stable: the original manual accumulation of
    # E[x^2] - m**2 could yield a tiny negative operand under rounding and
    # hand np.sqrt a negative number (-> NaN).
    return a.mean(), a.std()
-
def encode():
    """Run the latent-masking rate-distortion experiment on one test image.

    Collects the image paths of the class selected by ``args.test_set``,
    loads the pretrained NIC coder and the weighted-Gaussian context model,
    then (for the first image only) encodes it, derives an initial keep/drop
    mask over the quantized latents from distortion gradients, and optimizes
    a sigmoid-relaxed mask with Adam against ``bpp + lmd * MSE``. Results
    are printed; nothing is written to disk.

    Relies on module-level globals: ``args``, ``GPU``, ``block_width``,
    ``block_height``, and the checkpoint files ``mse25600.pkl`` /
    ``mse25600p.pkl`` in the working directory.
    """
    # ---- Collect test image paths for the selected class ----
    test_image_paths = []
    dirs = os.listdir(args.input)
    dir_list = ['ClassA_6k', 'ClassB_4k', 'ClassC_2k', 'ClassD_Kodak']
    for dir in dirs:  # NOTE: 'dir' shadows the builtin
        if dir == dir_list[args.test_set]:
            path = os.path.join(args.input, dir)
            if os.path.isdir(path):
                test_image_paths += glob.glob(path + '/*.png')
            if os.path.isfile(path):
                test_image_paths.append(path)

    im_dirs = test_image_paths

    # M = main-latent channel count, N2 = hyper-latent channel count.
    M, N2 = 256, 192

    NIC_anchor = model.Image_coding(3, M, N2, M, M // 2)
    context = Weighted_Gaussian(M)
    loss_func = nn.MSELoss()

    ######################### Load Model #########################
    NIC_anchor.load_state_dict(torch.load(
        'mse25600.pkl', map_location='cpu'))
    context.load_state_dict(torch.load(
        'mse25600p.pkl', map_location='cpu'))
    print('loaded')
    if GPU:
        NIC_anchor.cuda()
        context.cuda()

    #################### Compress Each Image ###################

    idx = -1
    for im_dir in im_dirs:
        idx += 1
        # NOTE(review): only the first image is processed; args.pic_idx is
        # parsed but unused here -- presumably this was meant to be
        # idx != args.pic_idx. Confirm against the experiment's intent.
        if idx != 0:
            continue
        ######################### Read Image #########################
        img = Image.open(im_dir)
        source_img = np.array(img)
        img = source_img / 255.0  # scale pixels to [0, 1]
        H, W, _ = img.shape

        C = 3
        ######################### Spliting Image #########################
        # Tile the image into blocks of at most block_height x block_width.
        Block_Num_in_Width = int(np.ceil(W / block_width))
        Block_Num_in_Height = int(np.ceil(H / block_height))
        img_block_list = []
        for i in range(Block_Num_in_Height):
            for j in range(Block_Num_in_Width):
                img_block_list.append(img[i * block_height:np.minimum((i + 1) * block_height, H),
                                          j * block_width:np.minimum((j + 1) * block_width, W), ...])

        ######################### Padding Image #########################
        Block_Idx = 0

        for img in img_block_list:  # Traverse CTUs
            block_H = img.shape[0]
            block_W = img.shape[1]
            # Zero-pad each block so both sides are multiples of 64.
            tile = 64.
            block_H_PAD = int(tile * np.ceil(block_H / tile))
            block_W_PAD = int(tile * np.ceil(block_W / tile))
            im = np.zeros([block_H_PAD, block_W_PAD, 3], dtype='float32')
            im[:block_H, :block_W, :] = img[:, :, :3]
            im = torch.FloatTensor(im)
            im = im.permute(2, 0, 1).contiguous()  # HWC -> CHW
            im = im.view(1, C, block_H_PAD, block_W_PAD)
            if GPU:
                im = im.cuda()

            Block_Idx += 1
            # begin processing CTU; currently a single CU spans the block
            im_block_list = []
            im_block_loc_list = []
            im_block_list.append(im)  # list size = 1
            im_block_loc_list.append([0, 0, block_H_PAD, block_W_PAD])

            for im_block_loc, im_block in zip(im_block_loc_list, im_block_list):  # Traverse CUs
                _, c, h, w = im_block.size()
                # Analysis transform: x1 = main latent, x2 = hyper latent.
                x1, x2 = NIC_anchor.encoder(im_block)
                shape = x1.shape
                xq1 = torch.round(x1)  # hard quantization (xq1 = model.UniverseQuant.apply(x1))
                xq1.retain_grad()  # keep d(dloss)/d(xq1) for the mask heuristic below
                xq2, xp2 = NIC_anchor.factorized_entropy_func(x2, 2)
                x3 = NIC_anchor.hyper_dec(xq2)
                hyper_dec = NIC_anchor.p(x3)

                output = NIC_anchor.decoder(xq1)
                dloss = loss_func(output, im_block)
                dloss.backward()  # populates xq1.grad
                xp1, _ = context(xq1, hyper_dec)
                # Per-element rate in bits, normalized per pixel.
                rate = -torch.log(xp1) / (h * w * np.log(2))
                # Anchor bpp and PSNR in dB (pixels are in [0,1], so
                # PSNR = -10*log10(MSE)).
                print(torch.sum(rate).item(), -10 * np.log10(dloss.item()))

                # Mask logits, one per latent element: +2 keep / -2 drop.
                mask_latent = torch.FloatTensor(shape)

                g = xq1.grad.detach().cpu()
                xq1 = xq1.detach().cpu()
                rate = rate.detach().cpu()
                # Drop elements whose distortion gradient is zero, or whose
                # rate/(value*gradient) ratio exceeds 100*lmd (cheap to drop).
                # NOTE(review): torch.abs(rate/(xq1*g)) re-evaluates the whole
                # tensor on every iteration of this triple loop -- extremely
                # slow; hoisting it out would not change the result.
                for i in range(256):  # 256 == M latent channels
                    for j in range(shape[2]):
                        for k in range(shape[3]):
                            if g[0, i, j, k] == 0 or (xq1[0, i, j, k] != 0 and torch.abs(rate / (xq1 * g))[0, i, j, k] > 100 * args.lmd):
                                mask_latent[0, i, j, k] = -2
                            else:
                                mask_latent[0, i, j, k] = 2
                mask_latent = nn.Parameter(mask_latent.cuda())
                mask_latent.requires_grad = True
                opt = torch.optim.Adam([mask_latent], lr=0.01)
                xq1 = Variable(xq1).cuda()
                hyper_dec = Variable(hyper_dec).cuda()

                # Optimize the relaxed mask against bpp + lmd * MSE.
                for itr in range(1000):
                    opt.zero_grad()
                    mask = F.sigmoid(mask_latent)  # NOTE: F.sigmoid is a deprecated alias of torch.sigmoid
                    output = NIC_anchor.decoder(xq1 * mask)
                    xp1, _ = context(xq1 * mask, hyper_dec)
                    bpp = torch.sum(mask * torch.log(xp1) / (-np.log(2) * h * w))
                    loss = bpp + args.lmd * loss_func(output, im_block)
                    loss.backward()
                    opt.step()

                # Final evaluation with the optimized mask applied.
                mask = F.sigmoid(mask_latent)
                xq1 = xq1 * mask
                xp1, _ = context(xq1, hyper_dec)
                output = NIC_anchor.decoder(xq1)
                dloss = loss_func(output, im_block)
                bpp = torch.sum(mask * torch.log(xp1) / (-np.log(2) * h * w))
                print(mask_latent)
                print(bpp.item(), -10 * np.log10(dloss.item()))
-
-
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Table-driven CLI definition: (flags, add_argument keyword arguments).
    _cli_options = (
        (("-i", "--input"), {"type": str, "default": '/backup3/NIC_test', "help": "Input Image"}),
        (("-o", "--output"), {"type": str, "default": './', "help": "Output Bin(encode)/Image(decode)"}),
        (("--lmd",), {"type": int, "default": 128}),
        (("--test_set",), {"type": int, "default": 3}),
        (("--pic_idx",), {"type": int, "default": 0}),
    )
    for _flags, _kwargs in _cli_options:
        parser.add_argument(*_flags, **_kwargs)

    # 'args' is read as a module-level global by encode().
    args = parser.parse_args()
    encode()
|