|
- import numpy as np
- import torch
- from Util import torch_msssim
- from Util.config import dict
-
# Run-time switches pulled from the project configuration mapping
# (``dict`` here is ``Util.config.dict``, which shadows the builtin).
GPU, USE_GEO, USE_VR_MODEL = dict['GPU'], dict['USE_GEO'], dict['USE_VR_MODEL']

# Distortion weight per model index when USE_VR_MODEL is off: check_RD_GEO
# scores each candidate as ratio_list[model_index] * d + r. check_RD measures
# d with MSE for indices 0-7 and with 1 - MS-SSIM for indices 8-15.
ratio_list = [
    200, 400, 800, 1600, 3200, 6400, 12800, 25600,  # indices 0-7 (MSE)
    4, 8, 16, 32, 64, 128, 320, 640,                # indices 8-15 (MS-SSIM)
]

if USE_VR_MODEL:
    # Only defined for the variable-rate setup: check_RD_GEO multiplies
    # lambda_rd[0, 0] * 100 by max_lambdas[model_index] to get its weight.
    # check_RD uses MSE for indices 0-1 and 1 - MS-SSIM for indices 2-3.
    max_lambdas = [64, 256, 1.28, 6.40]
-
@torch.no_grad()
def check_RD(test_img, lambda_rd, image_comp, context, model_index):  # input 1*C*block_H_PAD*block_W_PAD
    """Code ``test_img`` once and return its estimated rate and distortion.

    Args:
        test_img: 4-D image tensor; dims 1, 2, 3 are read as C, H, W. The
            original annotation describes it as 1 x C x block_H_PAD x
            block_W_PAD.
        lambda_rd: RD weight, passed straight to ``image_comp``.
        image_comp: Compression model, called as
            ``image_comp(test_img, 2, lambda_rd)`` and unpacked into
            ``(fake, xp1, xp2, xq1, x3)``.
        context: Context model, called as ``context(xq1, x3)``; only its
            first output is used.
        model_index: Chooses the distortion metric. MSE is used when
            ``model_index < 8`` (USE_VR_MODEL off) or ``model_index < 2``
            (USE_VR_MODEL on); otherwise ``1 - MS-SSIM`` is used.

    Returns:
        ``(r, d, fake)``: the rate as a numpy value, the distortion as a numpy
        value, and the reconstruction tensor ``fake``.
    """
    _, C, H, W = test_img.shape
    num_pixels = H * W
    fake, _xp1, xp2, xq1, x3 = image_comp(test_img, 2, lambda_rd)
    xp3, _ = context(xq1, x3)

    use_mse = (model_index < 2) if USE_VR_MODEL else (model_index < 8)
    if use_mse:
        # Mean squared error over channels and pixels, per batch item.
        # Normalise by the real channel count C. The original divided by a
        # literal 3, which matches C only for 3-channel input.
        d = (torch.sum((fake - test_img) ** 2, [1, 2, 3]) / num_pixels / C).detach().cpu().numpy()
    else:
        msssim_func = torch_msssim.MS_SSIM(max_val=1.)
        if GPU:
            msssim_func = msssim_func.cuda()
        d = 1. - msssim_func(fake, test_img).detach().cpu().numpy()

    # Summed -log2 of xp2 and xp3, divided by H*W. Assuming xp2 and xp3 are
    # per-symbol likelihoods (TODO confirm), this is bits per pixel.
    r = ((torch.sum(torch.log(xp2)) + torch.sum(torch.log(xp3))) / (-np.log(2) * num_pixels)).detach().cpu().numpy()
    return r, d, fake
-
@torch.no_grad()
def check_RD_GEO(test_img, lambda_rd, image_comp, context, model_index):  # data augmentation based RDO
    """Choose the flip/rotation of ``test_img`` with the lowest RD cost.

    There are eight geometric candidates. Candidate ``geo_index`` flips
    ``test_img`` along dim 2 when ``geo_index >= 4``, then rotates it by
    ``geo_index % 4`` quarter turns in dims (2, 3). Each candidate is coded
    with ``check_RD`` and scored as ``ratio * d + r``. Ties go to the lowest
    index. Candidates whose cost is NaN are never preferred over finite ones.

    Args:
        test_img, lambda_rd, image_comp, context, model_index: as for
            ``check_RD``. ``lambda_rd`` is also indexed as ``[0, 0]`` when
            USE_VR_MODEL is on.

    Returns:
        ``(r, d, geo_index, rec)``:
        - ``r``: rate of the best candidate plus ``3 / num_pixels``. 3 equals
          log2(8); presumably this is the cost of signalling ``geo_index``
          (TODO confirm).
        - ``d``: distortion of the best candidate.
        - ``geo_index``: its index in 0-7, as a numpy integer.
        - ``rec``: its reconstruction, in the transformed (flipped/rotated)
          orientation. It is used for MS-SSIM model RDO.
    """
    if USE_VR_MODEL:
        # Weight = lambda_rd[0, 0] * 100, rescaled by this model's entry.
        ratio = lambda_rd.cpu().numpy()[0, 0] * 100
        ratio *= max_lambdas[model_index]
    else:
        ratio = ratio_list[model_index]
    _, _, H, W = test_img.shape
    num_pixels = H * W

    r_list, d_list, rd_list, rec_list = [], [], [], []
    for candidate_idx in range(8):
        # Same order as flip-outer / rotate-inner: idx = 4 * i_flip + i_rot.
        i_flip, i_rot = divmod(candidate_idx, 4)
        base = torch.flip(test_img, dims=[2]) if i_flip else test_img
        r_, d_, rec = check_RD(torch.rot90(base, k=i_rot, dims=[2, 3]), lambda_rd, image_comp, context,
                               model_index)
        r_list.append(r_)
        d_list.append(d_)
        rd_list.append(ratio * d_ + r_)
        rec_list.append(rec)

    # One row of costs per candidate, since d_ may carry a batch axis.
    # Treat NaN as +inf so that a NaN cost cannot be selected. The previous
    # equality scan raised IndexError whenever any cost was NaN.
    costs = np.asarray(rd_list, dtype=np.float64).reshape(len(rd_list), -1)
    costs = np.where(np.isnan(costs), np.inf, costs)
    # argmin returns the first minimum, matching the old first-match rule.
    geo_index = np.argmin(costs.min(axis=1))  # geometric operation index

    r = np.array(r_list)[geo_index] + 3 / num_pixels
    d = np.array(d_list)[geo_index]
    rec = rec_list[geo_index]
    return r, d, geo_index, rec  # rec is used for MS-SSIM model RDO
|