|
- import math
- import io
- import torch
- from torchvision import transforms
- import numpy as np
- import time
- from PIL import Image
-
- import matplotlib.pyplot as plt
- from compressai.zoo import bmshj2018_hyperprior, bmshj2018_factorized
- import mindspore as ms
- from mindspore import save_checkpoint, Tensor
- from mindspore import load_checkpoint, load_param_into_net
-
# Padding helpers: enlarge an image so its height and width become multiples of a given size.
def compute_padding(in_h: int, in_w: int, *, out_h=None, out_w=None, min_div=1):
    """Compute centered padding for an image and the matching negative "unpad".

    Args:
        in_h: Input height.
        in_w: Input width.
        out_h: Target height. When omitted, ``in_h`` rounded up to a multiple
            of ``min_div``.
        out_w: Target width. When omitted, ``in_w`` rounded up to a multiple
            of ``min_div``.
        min_div: Value that the output height and width must be divisible by.

    Returns:
        ``(pad, unpad)``. ``pad`` is ``(left, right, top, bottom)``; when the
        extra size is odd, the additional pixel goes to the right/bottom side.
        ``unpad`` is the element-wise negation of ``pad``.

    Raises:
        ValueError: If the output height or width is not divisible by ``min_div``.
    """

    def _round_up(length):
        # Smallest multiple of min_div that is >= length.
        return (length + min_div - 1) // min_div * min_div

    target_h = _round_up(in_h) if out_h is None else out_h
    target_w = _round_up(in_w) if out_w is None else out_w

    if target_h % min_div or target_w % min_div:
        raise ValueError(
            f"Padded output height and width are not divisible by min_div={min_div}."
        )

    extra_w = target_w - in_w
    extra_h = target_h - in_h
    left = extra_w // 2
    top = extra_h // 2

    pad = (left, extra_w - left, top, extra_h - top)
    unpad = tuple(-side for side in pad)
    return pad, unpad
-
def pad(x, p=2**6):
    """Zero-pad the spatial dims of ``x`` so both become multiples of ``p``.

    Height and width are read from ``x.shape[2]`` and ``x.shape[3]``
    (presumably an NCHW batch — confirm against callers). Padding is centered
    as computed by ``compute_padding``; the unpad tuple is discarded.
    """
    height, width = x.shape[2], x.shape[3]
    padding = compute_padding(height, width, min_div=p)[0]
    return ms.ops.pad(x, padding, mode="constant", value=0)
-
def _to_float_array(img):
    """Return ``img`` as a float64 NumPy array.

    Objects exposing ``asnumpy()`` (MindSpore tensors) are converted through
    it; anything else goes through ``np.asarray``. Casting to float64 keeps the
    subtraction in ``compute_psnr`` from wrapping around on unsigned-integer
    images.
    """
    if hasattr(img, "asnumpy"):
        img = img.asnumpy()
    return np.asarray(img, dtype=np.float64)


def compute_psnr(img1, img2, max_val=1.0):
    """Return the PSNR (in dB) between two images as a 1-element NumPy array.

    Args:
        img1: First image — a MindSpore tensor or any array-like.
        img2: Second image, same shape as ``img1``.
        max_val: Peak signal value (1.0 for images scaled to [0, 1]).

    Returns:
        ``np.array([psnr])``. When the mean squared error is below 1e-10 the
        value is capped at 100 dB. The array shape is the same on both paths,
        so callers can always ``np.concatenate`` the result.
    """
    a = _to_float_array(img1)
    b = _to_float_array(img2)

    mse = np.mean((a - b) ** 2)
    if mse < 1.0e-10:
        # Cap instead of dividing by ~0. Return an array (not a bare int)
        # so the type matches the normal path.
        return np.array([100.0])
    return np.array([20 * math.log10(max_val / math.sqrt(mse))])
-
def compute_msssim(img1, img2):
    """Return the MS-SSIM score between two image batches as a NumPy array.

    A ``mindspore.nn.MSSSIM`` cell is built on every call with max_val=1,
    five per-scale power factors, filter_size=11, filter_sigma=1.5, k1=0.01 and
    k2=0.03. It is applied to ``(img1, img2)`` and the result is converted with
    ``asnumpy()``.
    """
    scale_weights = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
    msssim_cell = ms.nn.MSSSIM(
        max_val=1,
        power_factors=scale_weights,
        filter_size=11,
        filter_sigma=1.5,
        k1=0.01,
        k2=0.03,
    )
    score = msssim_cell(img1, img2)
    return score.asnumpy()
-
def compute_bpp(out_net):
    """Estimate bits per pixel from the likelihoods produced by the network.

    Args:
        out_net: Mapping with an ``'x_hat'`` tensor and a ``'likelihoods'``
            dict of MindSpore tensors. The pixel count is
            ``shape[0] * shape[2] * shape[3]`` of ``x_hat`` (presumably N*H*W
            of an NCHW batch). It therefore reflects the size the network
            processed, including any padding.

    Returns:
        A 0-d NumPy array: the sum over every likelihood tensor of
        ``-log2(likelihood)``, divided by the pixel count. When
        ``'likelihoods'`` is empty, ``0.0`` is returned (no bits).
    """
    shape = out_net['x_hat'].shape
    num_pixels = shape[0] * shape[2] * shape[3]

    # log(p) / -log(2) == -log2(p): the ideal code length in bits.
    bits_per_entry = [
        ms.ops.log(likelihoods).sum() / (-math.log(2) * num_pixels)
        for likelihoods in out_net['likelihoods'].values()
    ]
    if not bits_per_entry:
        # builtin sum() over nothing is the int 0, which has no .asnumpy().
        return np.array(0.0, dtype=np.float32)
    return sum(bits_per_entry).asnumpy()
-
if __name__ == '__main__':
    import argparse
    import glob
    import os

    import pandas as pd

    # Index into the experiment list: 0 -> factorized prior, 1 -> hyperprior.
    choose = 1
    exp_name = ['bmshj2018-factorized', 'bmshj2018-hyperprior'][choose]
    ckpt_name = 'bmshj2018-factorized-prior' if exp_name == 'bmshj2018-factorized' else 'bmshj2018-hyperprior'
    # Explicit name -> constructor table (replaces the fragile locals()[...] lookup).
    model_builders = {
        'bmshj2018-factorized': bmshj2018_factorized,
        'bmshj2018-hyperprior': bmshj2018_hyperprior,
    }

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--input", default='example.png', type=str,
        help="Input filename.")
    parser.add_argument(
        "--qp", default=2, type=int,
        help="Quality parameter, choose from [1~7] (model0) or [1~8] (model1)"
    )
    parser.add_argument(
        "--model_type", default=0, type=int
    )
    # NOTE(review): --input, --qp and --model_type are parsed but not used;
    # the loop below uses the hard-coded image pattern, qp_list and `choose`.
    args = parser.parse_args()
    ms.set_context(device_target="GPU")

    output_dir = 'test_output'
    os.makedirs(output_dir, exist_ok=True)

    qp_list = [1, 2, 5, 7]
    image_pattern = "/userhome/CAE-ADMM/kodak/test/*.png"
    imglist = glob.glob(image_pattern)
    if not imglist:
        # Explicit check rather than `assert`, which `python -O` strips out.
        raise FileNotFoundError(f"No images match {image_pattern}")

    # Inputs are padded to a multiple of this before the forward pass.
    pad_multiple = 64
    # Built once instead of once per image.
    totensor = ms.dataset.vision.ToTensor()

    for index, qp in enumerate(qp_list):
        net = model_builders[exp_name](quality=qp, pretrained=False)
        param_dir = 'compressai/models/convert_%s-%d.ckpt' % (ckpt_name, qp)
        param_dict = load_checkpoint(param_dir)
        param_not_load = load_param_into_net(net, param_dict)
        print(param_not_load)

        net.set_train(False)

        psnr_values = []
        bpp_values = []
        msssim_values = []
        time_all = 0.0

        for oneimg in imglist:
            img = Image.open(oneimg).convert('RGB')
            ms_x = Tensor(totensor(img)).unsqueeze(0)
            height, width = ms_x.shape[2], ms_x.shape[3]
            (left, _, top, _), _ = compute_padding(height, width, min_div=pad_multiple)
            ms_x_padded = pad(ms_x, p=pad_multiple)

            # forward (perf_counter is the monotonic clock meant for durations)
            start_time = time.perf_counter()
            ms_out = net.construct(ms_x_padded)
            time_all += time.perf_counter() - start_time

            x_hat_full = ms_out['x_hat']
            # Crop back to the original image so the zero-padded border does
            # not enter PSNR / MS-SSIM.
            x_hat = x_hat_full.clamp(0, 1)[:, :, top:top + height, left:left + width]

            # np.ravel accepts both scalar and 1-element-array metric results.
            psnr_values.extend(np.ravel(compute_psnr(ms_x, x_hat)))
            msssim_values.extend(np.ravel(compute_msssim(ms_x, x_hat)))

            # compute_bpp normalizes by x_hat's (padded) H*W; renormalize so the
            # same number of bits is spread over the original pixels only.
            processed_pixels = x_hat_full.shape[2] * x_hat_full.shape[3]
            bpp_values.append(float(compute_bpp(ms_out)) * processed_pixels / (height * width))

        psnr_mean = float(np.mean(psnr_values))
        bpp_mean = float(np.mean(bpp_values))
        msssim_mean = float(np.mean(msssim_values))

        # NOTE(review): 'GPU M' is a hard-coded constant, not a measured value.
        all_results = [{'qp': qp, 'bpp': bpp_mean, 'PSNR': psnr_mean, 'MSSSIM': msssim_mean,
                        'time_cost': time_all, 'GPU M': 4004}]
        print(all_results)
        # The first qp overwrites the CSV and writes the header; later qps append rows.
        pd.DataFrame(data=all_results).to_csv(
            os.path.join(output_dir, '%s_mindspore.csv' % exp_name),
            index=False,
            mode='w' if index == 0 else 'a',
            header=index == 0,
        )
|