|
- from networks import compress_low,decompress_low,NetLow,NetHigh
- import argparse
- import glob
- import numpy as np
- from PIL import Image
- import os
- import time
- import pandas as pd
- import torch
- from pytorch_msssim import ms_ssim
-
def psnr(img1, img2, data_range=255.0):
    """Return the peak signal-to-noise ratio (dB) between two images.

    Args:
        img1: reference image, either a path readable by PIL or an
            already-decoded ``np.ndarray``.
        img2: image to compare, same accepted forms as ``img1``.
        data_range: peak pixel value; 255 for 8-bit images (the original
            hard-coded value, kept as the default).

    Returns:
        ``20 * log10(data_range / sqrt(MSE))``, or ``float('inf')`` when the
        images are identical.

    Raises:
        ValueError: if the two decoded images do not have the same shape.
    """
    def load(img):
        if isinstance(img, np.ndarray):
            return img.astype(np.float64)
        # Convert to RGB exactly like MSSSIM does, so both metrics compare the
        # same pixels and an RGBA/palette source does not mis-broadcast
        # against an RGB reconstruction.
        with Image.open(img) as f:
            return np.asarray(f.convert('RGB'), dtype=np.float64)

    a = load(img1)
    b = load(img2)
    if a.shape != b.shape:
        raise ValueError(f"image shapes differ: {a.shape} vs {b.shape}")
    mse = np.mean((a - b) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * np.log10(data_range / np.sqrt(mse))
-
def MSSSIM(img1, img2):
    """Compute the MS-SSIM between two image files.

    Both images are decoded with PIL, forced to 3-channel RGB, cast to
    float32 and laid out as a batch-of-one ``(1, 3, H, W)`` tensor. They are
    then scored with ``pytorch_msssim.ms_ssim`` using an 8-bit data range.

    Args:
        img1: path of the reference image.
        img2: path of the image to compare against it.

    Returns:
        The NumPy array produced by ``ms_ssim(..., size_average=False)``:
        one score per batch item, so a single value here.
    """
    def as_batch(path):
        # HWC pixels -> float32 -> (1, H, W, 3) -> (1, 3, H, W)
        rgb = np.array(Image.open(path).convert('RGB')).astype(np.float32)
        return torch.tensor(rgb).unsqueeze(0).permute((0, 3, 1, 2))

    reference = as_batch(img1)
    candidate = as_batch(img2)
    scores = ms_ssim(reference, candidate, data_range=255., size_average=False)
    return scores.numpy()
-
def bpp(bin, img):
    """Return the bit rate of a compressed file in bits per pixel.

    Args:
        bin: path to the compressed bitstream; its on-disk size is the rate.
        img: path to the source image; only its width and height are read.

    Returns:
        ``file_size_bytes * 8 / height / width`` as a float.
    """
    num_bits = os.path.getsize(bin) * 8
    # Image.open only parses the header (pixels are never loaded), so PIL
    # would keep the file open until garbage collection; close it explicitly.
    with Image.open(img) as im:
        width, height = im.size
    return num_bits / height / width
-
def _parse_args():
    """Build and parse the command-line options for the evaluation run."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--model_type", default=0, type=int,
        help="Model type, choose from 0:PSNR 1:MS-SSIM"
    )
    parser.add_argument(
        "--device", default='cuda', type=str,
        help="Which device does the network run on?"
    )
    parser.add_argument(
        "--data_dir", default="/userhome/dataset/kodak/test", type=str,
        help="Directory holding the *.png test images"
    )
    parser.add_argument(
        "--qp_list", default=[1, 2, 5, 7], type=int, nargs='+',
        help="Quality parameters to evaluate (models/model<type>_qp<qp>.pth each)"
    )
    return parser.parse_args()


def _load_net(qp, model_type, device):
    """Create the network for ``qp`` (NetLow if qp <= 3, else NetHigh) with weights loaded."""
    net = NetLow() if qp <= 3 else NetHigh()
    net = net.eval().to(device)
    # map_location lets a checkpoint saved from a GPU be loaded with --device cpu.
    state = torch.load(f'models/model{model_type}_qp{qp}.pth', map_location=device)
    net.load_state_dict(state)
    return net


def _evaluate_qp(net, args, qp, imglist):
    """Encode, decode and score every image at one qp, then write its CSV.

    Bitstreams go to ``test_output/qp<qp>/<stem>.bin`` and reconstructions to
    ``test_output/qp<qp>/<stem>_rec.png``. The CSV has one row per image plus
    a final 'all' row: mean bpp/PSNR/MSSSIM and total encode/decode time.
    """
    out_dir = f'test_output/qp{qp}'
    os.makedirs(out_dir, exist_ok=True)
    # compress_low/decompress_low take their parameters from the args
    # namespace, so qp/input/output are set on it before each call.
    args.qp = qp

    names = [os.path.basename(p) for p in imglist]
    stems = [os.path.splitext(n)[0] for n in names]
    bin_paths = [f'{out_dir}/{s}.bin' for s in stems]
    rec_paths = [f'{out_dir}/{s}_rec.png' for s in stems]

    # NOTE(review): compress_low/decompress_low are used for both NetLow and
    # NetHigh, as in the original script; confirm that is intended.
    enc_time = []
    for name, src, dst in zip(names, imglist, bin_paths):
        print('enc ' + name)
        start = time.perf_counter()
        args.input, args.output = src, dst
        compress_low(net, args)
        enc_time.append(time.perf_counter() - start)

    dec_time = []
    for name, src, dst in zip(names, bin_paths, rec_paths):
        print('dec ' + name)
        start = time.perf_counter()
        args.input, args.output = src, dst
        decompress_low(net, args)
        dec_time.append(time.perf_counter() - start)

    bpps, psnrs, msssims = [], [], []
    for src, bin_path, rec_path in zip(imglist, bin_paths, rec_paths):
        bpps.append(bpp(bin_path, src))
        psnrs.append(psnr(src, rec_path))
        # MSSSIM returns a one-element array (batch of one); unwrap it.
        msssims.append(np.asarray(MSSSIM(src, rec_path)).item())

    # Build numeric columns directly (not a string ndarray) so the CSV keeps
    # real numbers; the trailing 'all' row holds means / total times.
    results = pd.DataFrame({
        'imgname': names + ['all'],
        'bpp': bpps + [np.mean(bpps)],
        'PSNR': psnrs + [np.mean(psnrs)],
        'MSSSIM': msssims + [np.mean(msssims)],
        'enc_time': enc_time + [np.sum(enc_time)],
        'dec_time': dec_time + [np.sum(dec_time)],
    })
    results.to_csv(f'test_output/coarse2fine_pytorch_qp{qp}.csv', index=False)


def main():
    """Evaluate every requested qp on the test images, writing one CSV per qp."""
    args = _parse_args()
    os.makedirs('test_output', exist_ok=True)
    device = torch.device(args.device)
    # Sorted so the CSV row order is deterministic across runs/filesystems.
    imglist = sorted(glob.glob(os.path.join(args.data_dir, '*.png')))
    if not imglist:
        raise SystemExit(f'no *.png images found in {args.data_dir}')
    for qp in args.qp_list:
        net = _load_net(qp, args.model_type, device)
        _evaluate_qp(net, args, qp, imglist)


if __name__ == "__main__":
    main()
|