
first commit

Branch: master
root committed 2 months ago · commit 88642171e1
30 changed files with 2289 additions and 0 deletions
  1. .gitignore (+108, -0)
  2. README.md (+0, -0)
  3. README_CN.md (+0, -0)
  4. eval.py (+0, -0)
  5. export.py (+0, -0)
  6. extract_subimages.py (+221, -0)
  7. module_test.py (+0, -0)
  8. src/__init__.py (+0, -0)
  9. src/config/__init__.py (+0, -0)
  10. src/config/config.py (+49, -0)
  11. src/dataset/__init__.py (+0, -0)
  12. src/dataset/dataset_DIV2K.py (+145, -0)
  13. src/dataset/transform_opencv.py (+172, -0)
  14. src/model/RRDB_Net.py (+84, -0)
  15. src/model/VGG.py (+95, -0)
  16. src/model/__init__.py (+0, -0)
  17. src/model/discriminator_net.py (+83, -0)
  18. src/model/loss.py (+56, -0)
  19. src/utils/__init__.py (+0, -0)
  20. src/utils/eval_util.py (+326, -0)
  21. src/utils/extract_subimages.py (+221, -0)
  22. src/utils/loss.py (+81, -0)
  23. src/utils/matlab_functions.py (+192, -0)
  24. src/utils/metric_util.py (+47, -0)
  25. src/utils/niqe.py (+205, -0)
  26. src/utils/niqe_pris_params.npz (BIN)
  27. src/utils/psnr_ssim.py (+140, -0)
  28. src/utils/rename.py (+23, -0)
  29. train.py (+0, -0)
  30. train_psnr.py (+41, -0)

.gitignore (+108, -0)

@@ -0,0 +1,108 @@
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
# IDE and project files
.idea/
.vscode/
data/
kernel_meta/

README.md (+0, -0)


README_CN.md (+0, -0)


eval.py (+0, -0)


export.py (+0, -0)


extract_subimages.py (+221, -0)

@@ -0,0 +1,221 @@
"""A multi-thread tool to crop large images to sub-images for faster IO.
(This preprocessing code is copied and modified from official implement:
https://github.com/open-mmlab/mmsr/tree/master/codes/data_scripts)"""
import os
import os.path as osp
import sys
from multiprocessing import Pool
import numpy as np
import cv2
from PIL import Image
import time
from shutil import get_terminal_size
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
def main():
mode = 'pair' # single (one input folder) | pair (extract corresponding GT and LR pairs)
opt = {}
opt['n_thread'] = 20
opt['compression_level'] = 3  # 3 is the default value in cv2
# CV_IMWRITE_PNG_COMPRESSION ranges from 0 to 9. A higher value means a smaller file size but a
# longer compression time. If raw images are read during training, use 0 for faster IO.
if mode == 'single':
opt['input_folder'] = './data/DIV2K/DIV2K_train_HR'
opt['save_folder'] = './data/DIV2K/DIV2K800_sub'
opt['crop_sz'] = 480 # the size of each sub-image
opt['step'] = 240 # step of the sliding crop window
opt['thres_sz'] = 48 # size threshold
extract_single(opt)
elif mode == 'pair':
GT_folder = './data/DIV2K/DIV2K_train_HR'
LR_folder = './data/DIV2K/DIV2K_train_LR_X4'
save_GT_folder = './data/DIV2K/DIV2K800_sub'
save_LR_folder = './data/DIV2K/DIV2K800_sub_LRx4'
scale_ratio = 4
crop_sz = 480 # the size of each sub-image (GT)
step = 240 # step of the sliding crop window (GT)
thres_sz = 48 # size threshold
########################################################################
# check that all the GT and LR images have correct scale ratio
img_GT_list = _get_paths_from_images(GT_folder)
img_LR_list = _get_paths_from_images(LR_folder)
assert len(img_GT_list) == len(img_LR_list), 'different length of GT_folder and LR_folder.'
for path_GT, path_LR in zip(img_GT_list, img_LR_list):
img_GT = Image.open(path_GT)
img_LR = Image.open(path_LR)
w_GT, h_GT = img_GT.size
w_LR, h_LR = img_LR.size
assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X as LR width [{:d}] for {:s}.'.format(  # noqa: E501
w_GT, scale_ratio, w_LR, path_GT)
assert h_GT / h_LR == scale_ratio, 'GT height [{:d}] is not {:d}X as LR height [{:d}] for {:s}.'.format(  # noqa: E501
h_GT, scale_ratio, h_LR, path_GT)
# check crop size, step and threshold size
assert crop_sz % scale_ratio == 0, 'crop size is not a multiple of {:d}.'.format(
scale_ratio)
assert step % scale_ratio == 0, 'step is not a multiple of {:d}.'.format(scale_ratio)
assert thres_sz % scale_ratio == 0, 'thres_sz is not a multiple of {:d}.'.format(
scale_ratio)
print('process GT...')
opt['input_folder'] = GT_folder
opt['save_folder'] = save_GT_folder
opt['crop_sz'] = crop_sz
opt['step'] = step
opt['thres_sz'] = thres_sz
extract_single(opt)
print('process LR...')
opt['input_folder'] = LR_folder
opt['save_folder'] = save_LR_folder
opt['crop_sz'] = crop_sz // scale_ratio
opt['step'] = step // scale_ratio
opt['thres_sz'] = thres_sz // scale_ratio
extract_single(opt)
assert len(_get_paths_from_images(save_GT_folder)) == len(
_get_paths_from_images(
save_LR_folder)), 'different length of save_GT_folder and save_LR_folder.'
else:
raise ValueError('Wrong mode.')
def extract_single(opt):
input_folder = opt['input_folder']
save_folder = opt['save_folder']
if not osp.exists(save_folder):
os.makedirs(save_folder)
print('mkdir [{:s}] ...'.format(save_folder))
else:
print('Folder [{:s}] already exists. Exit...'.format(save_folder))
sys.exit(1)
img_list = _get_paths_from_images(input_folder)
def update(arg):
pbar.update(arg)
pbar = ProgressBar(len(img_list))
pool = Pool(opt['n_thread'])
for path in img_list:
pool.apply_async(worker, args=(path, opt), callback=update)
pool.close()
pool.join()
print('All subprocesses done.')
def worker(path, opt):
crop_sz = opt['crop_sz']
step = opt['step']
thres_sz = opt['thres_sz']
img_name = osp.basename(path)
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
n_channels = len(img.shape)
if n_channels == 2:
h, w = img.shape
elif n_channels == 3:
h, w, c = img.shape
else:
raise ValueError('Wrong image shape - {}'.format(n_channels))
h_space = np.arange(0, h - crop_sz + 1, step)
if h - (h_space[-1] + crop_sz) > thres_sz:
h_space = np.append(h_space, h - crop_sz)
w_space = np.arange(0, w - crop_sz + 1, step)
if w - (w_space[-1] + crop_sz) > thres_sz:
w_space = np.append(w_space, w - crop_sz)
index = 0
for x in h_space:
for y in w_space:
index += 1
if n_channels == 2:
crop_img = img[x:x + crop_sz, y:y + crop_sz]
else:
crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
crop_img = np.ascontiguousarray(crop_img)
cv2.imwrite(
osp.join(opt['save_folder'],
img_name.replace('.png', '_s{:03d}.png'.format(index))), crop_img,
[cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
return 'Processing {:s} ...'.format(img_name)
# ##############
# ### Utils ####
# ##############
class ProgressBar(object):
'''A progress bar which can print the progress
modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
'''
def __init__(self, task_num=0, bar_width=50, start=True):
self.task_num = task_num
max_bar_width = self._get_max_bar_width()
self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width)
self.completed = 0
if start:
self.start()
def _get_max_bar_width(self):
terminal_width, _ = get_terminal_size()
max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
if max_bar_width < 10:
print('terminal width is too small ({}), please consider widening the terminal for better '
'progressbar visualization'.format(terminal_width))
max_bar_width = 10
return max_bar_width
def start(self):
if self.task_num > 0:
sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format(
' ' * self.bar_width, self.task_num, 'Start...'))
else:
sys.stdout.write('completed: 0, elapsed: 0s')
sys.stdout.flush()
self.start_time = time.time()
def update(self, msg='In progress...'):
self.completed += 1
elapsed = time.time() - self.start_time + 1e-9
fps = self.completed / elapsed
if self.task_num > 0:
percentage = self.completed / float(self.task_num)
eta = int(elapsed * (1 - percentage) / percentage + 0.5)
mark_width = int(self.bar_width * percentage)
bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
sys.stdout.write('\033[2F') # cursor up 2 lines
sys.stdout.write('\033[J') # clean the output (remove extra chars since last display)
sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format(
bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg))
else:
sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
self.completed, int(elapsed + 0.5), fps))
sys.stdout.flush()
# ###################
# ### Data Utils ####
# ###################
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def _get_paths_from_images(path):
"""get image path list from image folder"""
assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
images = []
for dirpath, _, fnames in sorted(os.walk(path)):
for fname in sorted(fnames):
if is_image_file(fname):
img_path = os.path.join(dirpath, fname)
images.append(img_path)
assert images, '{:s} has no valid image file'.format(path)
return images
if __name__ == '__main__':
main()
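As a worked example of the tiling arithmetic in worker above (pure numpy, no file IO; the 2040x1356 size is just a representative DIV2K-like resolution, not a value taken from this repo):

import numpy as np

h, w = 1356, 2040
crop_sz, step, thres_sz = 480, 240, 48
h_space = np.arange(0, h - crop_sz + 1, step)   # [0, 240, 480, 720]
if h - (h_space[-1] + crop_sz) > thres_sz:      # 1356 - 1200 = 156 > 48
    h_space = np.append(h_space, h - crop_sz)   # -> 5 row offsets
w_space = np.arange(0, w - crop_sz + 1, step)   # [0, 240, ..., 1440]
if w - (w_space[-1] + crop_sz) > thres_sz:      # 2040 - 1920 = 120 > 48
    w_space = np.append(w_space, w - crop_sz)   # -> 8 column offsets
print(len(h_space) * len(w_space))              # 40 sub-images for this one image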

module_test.py (+0, -0)


src/__init__.py (+0, -0)


src/config/__init__.py (+0, -0)


src/config/config.py (+49, -0)

@@ -0,0 +1,49 @@
ESRGAN_config = {
# image setting
"input_size": 32,
"gt_size": 128,
"ch_size": 3,
"scale": 4,
# Generator setting
"G_nf": 64,
"G_nb": 23,
# discriminator setting
"D_nf": 64,
# training setting
"niter": 400000,
"lr_G": [1e-4, 5e-5, 2e-5, 1e-5],
"lr_D": [1e-4, 5e-5, 2e-5, 1e-5],
"lr_steps": [50000, 100000, 200000, 300000],
"w_pixel": 1e-2,
"w_feature": 1.0,
"w_gan": 5e-3,
"gan_type": "gan",  # gan | ragan
"save_steps": 5000
}
PSNR_config = {
# image setting
"input_size": 32,
"gt_size": 128,
"ch_size": 3,
"scale": 4,
# Generator setting
"G_nf": 64,
"G_nb": 23,
# training setting
"niter": 400000,
"lr": [2e-4, 1e-4, 5e-5, 2e-5],
"lr_steps": [200000, 400000, 600000, 800000],
"w_pixel": 1.0,
"pixel_criterion": "l1",
"save_steps": 5000
}
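The paired lr_G/lr_steps lists imply a piecewise-constant decay schedule. A minimal sketch of how such a config could be consumed (the lr_at_step helper is illustrative, not part of this repo):

import bisect

def lr_at_step(step, base_lrs, lr_steps):
    """Piecewise-constant schedule: the learning rate drops at each boundary in lr_steps."""
    # bisect_right counts how many boundaries `step` has already passed
    idx = bisect.bisect_right(lr_steps, step)
    return base_lrs[min(idx, len(base_lrs) - 1)]

cfg = {"lr_G": [1e-4, 5e-5, 2e-5, 1e-5], "lr_steps": [50000, 100000, 200000, 300000]}
print(lr_at_step(0, cfg["lr_G"], cfg["lr_steps"]))       # 1e-4
print(lr_at_step(150000, cfg["lr_G"], cfg["lr_steps"]))  # 2e-5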

src/dataset/__init__.py (+0, -0)


src/dataset/dataset_DIV2K.py (+145, -0)

@@ -0,0 +1,145 @@
import os
from os import listdir
from os.path import join
from PIL import Image, ImageOps
import random
import mindspore.dataset.engine as de
import numpy as np
from mindspore.dataset.vision import c_transforms as C
def is_image_file(filename):
return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])
def load_img(filepath):
img = Image.open(filepath).convert('RGB')
# y, _, _ = img.split()
return img
def rescale_img(img_in, scale):
size_in = img_in.size
new_size_in = tuple([int(x * scale) for x in size_in])
img_in = img_in.resize(new_size_in, resample=Image.BICUBIC)
return img_in
def augment(img_in, img_tar, flip_h=True, rot=True):
info_aug = {'flip_h': False, 'flip_v': False, 'trans': False}
if random.random() < 0.5 and flip_h:
img_in = ImageOps.flip(img_in)
img_tar = ImageOps.flip(img_tar)
info_aug['flip_h'] = True
if rot:
if random.random() < 0.5:
img_in = ImageOps.mirror(img_in)
img_tar = ImageOps.mirror(img_tar)
info_aug['flip_v'] = True
if random.random() < 0.5:
img_in = img_in.rotate(180)
img_tar = img_tar.rotate(180)
info_aug['trans'] = True
return img_in, img_tar, info_aug
class Dataset_DIV2K():
def __init__(self, base_dir, downsample_factor, data_augmentation, mode):
super(Dataset_DIV2K, self).__init__()
if mode == "train":
self.filelist_label = sorted(
[join(base_dir, "DIV2K800_sub", x) for x in listdir(os.path.join(base_dir, "data/DIV2K/DIV2K800_sub")) if
is_image_file(x)])
LR_dir = join(base_dir, "data/DIV2K",("DIV2K800_sub_LRx" + str(downsample_factor)))
self.filelist_img = sorted(
[os.path.join(base_dir, LR_dir, x) for x in listdir(os.path.join(base_dir, LR_dir)) if
is_image_file(x)])
elif mode == "valid":
self.filelist_label = sorted(
[join(base_dir,"data", "DIV2K_valid_HR", x) for x in listdir(os.path.join(base_dir, "data/DIV2K/DIV2K_train_HR")) if
is_image_file(x)])
LR_dir = join(base_dir, "data/DIV2K",("DIV2K_valid_LR_X" + str(downsample_factor)))
self.filelist_img = sorted(
[os.path.join(base_dir, LR_dir, x) for x in listdir(os.path.join(base_dir, LR_dir)) if
is_image_file(x)])
self.data_augmentation = data_augmentation
def __getitem__(self, index):
input = load_img(self.filelist_img[index])
target = load_img(self.filelist_label[index])
if self.data_augmentation:
input, target, _ = augment(input, target)
# convert the PIL images to float arrays before scaling and the CHW transpose
input = np.transpose(np.asarray(input, np.float32) / 255.0, (2, 0, 1))
target = np.transpose(np.asarray(target, np.float32) / 255.0, (2, 0, 1))
return input, target
def __len__(self):
return len(self.filelist_img)
class Dataset_Flickr():
def __init__(self, base_dir, downsample_factor, data_augmentation, mode, num_trainset):
super(Dataset_Flickr, self).__init__()
self.filelist_label = sorted(
[join(base_dir, "Flickr2K_HR", x) for x in listdir(os.path.join(base_dir, "Flickr2K_HR")) if
is_image_file(x)])
LR_dir = join(base_dir, "Flickr2K_LR_bicubic", "X" + str(downsample_factor))
self.filelist_img = sorted(
[os.path.join(LR_dir, x) for x in listdir(LR_dir) if
is_image_file(x)])
if mode == "train":
self.filelist_img = self.filelist_img[:num_trainset]
self.filelist_label = self.filelist_label[:num_trainset]
else:
self.filelist_img = self.filelist_img[num_trainset:]
self.filelist_label = self.filelist_label[num_trainset:]
self.data_augmentation = data_augmentation
def __getitem__(self, index):
input = load_img(self.filelist_img[index])
target = load_img(self.filelist_label[index])
if self.data_augmentation:
input, target, _ = augment(input, target)
input = np.transpose(np.asarray(input, np.float32) / 255.0, (2, 0, 1))  # PIL -> float CHW
target = np.transpose(np.asarray(target, np.float32) / 255.0, (2, 0, 1))
return input, target
def __len__(self):
return len(self.filelist_img)
def get_dataset_DIV2K(base_dir, downsample_factor, mode, aug, repeat, num_readers, shard_num, shard_id, batch_size):
dataset_DIV2K = Dataset_DIV2K(base_dir, downsample_factor, data_augmentation=aug, mode=mode)
data_set = de.GeneratorDataset(source=dataset_DIV2K, column_names=["input", "target"],
shuffle=True, num_parallel_workers=num_readers,
num_shards=shard_num, shard_id=shard_id)
data_set = data_set.shuffle(buffer_size=batch_size * 10)
data_set = data_set.batch(batch_size, drop_remainder=True)
data_set = data_set.repeat(repeat)
return data_set
def get_dataset_Flickr(base_dir, downsample_factor, mode, aug, repeat, resize_shape, num_trainset, num_readers,
shard_num, shard_id,
num_parallel_calls, batch_size):
dataset_Flickr = Dataset_Flickr(base_dir, downsample_factor, data_augmentation=aug, mode=mode,
num_trainset=num_trainset)
data_set = de.GeneratorDataset(source=dataset_Flickr, column_names=["input", "target"],
shuffle=True, num_parallel_workers=num_readers,
num_shards=shard_num, shard_id=shard_id)
data_set = data_set.shuffle(buffer_size=batch_size * 10)
data_set = data_set.batch(batch_size, drop_remainder=True)
data_set = data_set.repeat(repeat)
return data_set
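A minimal usage sketch for get_dataset_DIV2K. The paths, shard arguments, and printed shapes are illustrative: they assume the 480-pixel GT sub-images produced by extract_subimages.py and an x4 LR set on disk.

from src.dataset.dataset_DIV2K import get_dataset_DIV2K

train_ds = get_dataset_DIV2K(
    base_dir="./", downsample_factor=4, mode="train", aug=True,
    repeat=1, num_readers=4, shard_num=1, shard_id=0, batch_size=16)
for batch in train_ds.create_dict_iterator():
    lr, hr = batch["input"], batch["target"]   # CHW float tensors in [0, 1]
    print(lr.shape, hr.shape)                  # e.g. (16, 3, 120, 120) (16, 3, 480, 480)
    break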

src/dataset/transform_opencv.py (+172, -0)

@@ -0,0 +1,172 @@
import cv2
import random
def mod_crop(img, scale):
"""Mod crop images, used during testing.
Args:
img (ndarray): Input image.
scale (int): Scale factor.
Returns:
ndarray: Result image.
"""
img = img.copy()
if img.ndim in (2, 3):
h, w = img.shape[0], img.shape[1]
h_remainder, w_remainder = h % scale, w % scale
img = img[:h - h_remainder, :w - w_remainder, ...]
else:
raise ValueError(f'Wrong img ndim: {img.ndim}.')
return img
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
"""Paired random crop.
It crops lists of lq and gt images with corresponding locations.
Args:
img_gts (list[ndarray] | ndarray): GT images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
gt_patch_size (int): GT patch size.
scale (int): Scale factor.
gt_path (str): Path to ground-truth.
Returns:
list[ndarray] | ndarray: GT images and LQ images. If returned results
only have one element, just return ndarray.
"""
if not isinstance(img_gts, list):
img_gts = [img_gts]
if not isinstance(img_lqs, list):
img_lqs = [img_lqs]
h_lq, w_lq, _ = img_lqs[0].shape
h_gt, w_gt, _ = img_gts[0].shape
lq_patch_size = gt_patch_size // scale
if h_gt != h_lq * scale or w_gt != w_lq * scale:
raise ValueError(
f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
f'multiplication of LQ ({h_lq}, {w_lq}).')
if h_lq < lq_patch_size or w_lq < lq_patch_size:
raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
f'({lq_patch_size}, {lq_patch_size}). '
f'Please remove {gt_path}.')
# randomly choose top and left coordinates for lq patch
top = random.randint(0, h_lq - lq_patch_size)
left = random.randint(0, w_lq - lq_patch_size)
# crop lq patch
img_lqs = [
v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
for v in img_lqs
]
# crop corresponding gt patch
top_gt, left_gt = int(top * scale), int(left * scale)
img_gts = [
v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...]
for v in img_gts
]
if len(img_gts) == 1:
img_gts = img_gts[0]
if len(img_lqs) == 1:
img_lqs = img_lqs[0]
return img_gts, img_lqs
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
"""Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
We use vertical flip and transpose for rotation implementation.
All the images in the list use the same augmentation.
Args:
imgs (list[ndarray] | ndarray): Images to be augmented. If the input
is an ndarray, it will be transformed to a list.
hflip (bool): Horizontal flip. Default: True.
rotation (bool): Rotation. Default: True.
flows (list[ndarray]): Flows to be augmented. If the input is an
ndarray, it will be transformed to a list.
Dimension is (h, w, 2). Default: None.
return_status (bool): Return the status of flip and rotation.
Default: False.
Returns:
list[ndarray] | ndarray: Augmented images and flows. If returned
results only have one element, just return ndarray.
"""
hflip = hflip and random.random() < 0.5
vflip = rotation and random.random() < 0.5
rot90 = rotation and random.random() < 0.5
def _augment(img):
if hflip: # horizontal
cv2.flip(img, 1, img)
if vflip: # vertical
cv2.flip(img, 0, img)
if rot90:
img = img.transpose(1, 0, 2)
return img
def _augment_flow(flow):
if hflip: # horizontal
cv2.flip(flow, 1, flow)
flow[:, :, 0] *= -1
if vflip: # vertical
cv2.flip(flow, 0, flow)
flow[:, :, 1] *= -1
if rot90:
flow = flow.transpose(1, 0, 2)
flow = flow[:, :, [1, 0]]
return flow
if not isinstance(imgs, list):
imgs = [imgs]
imgs = [_augment(img) for img in imgs]
if len(imgs) == 1:
imgs = imgs[0]
if flows is not None:
if not isinstance(flows, list):
flows = [flows]
flows = [_augment_flow(flow) for flow in flows]
if len(flows) == 1:
flows = flows[0]
return imgs, flows
else:
if return_status:
return imgs, (hflip, vflip, rot90)
else:
return imgs
def img_rotate(img, angle, center=None, scale=1.0):
"""Rotate image.
Args:
img (ndarray): Image to be rotated.
angle (float): Rotation angle in degrees. Positive values mean
counter-clockwise rotation.
center (tuple[int]): Rotation center. If the center is None,
initialize it as the center of the image. Default: None.
scale (float): Isotropic scale factor. Default: 1.0.
"""
(h, w) = img.shape[:2]
if center is None:
center = (w // 2, h // 2)
matrix = cv2.getRotationMatrix2D(center, angle, scale)
rotated_img = cv2.warpAffine(img, matrix, (w, h))
return rotated_img
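A self-contained sketch of the two main helpers above on random data (run in this module's scope; shapes assume scale 4 and a hypothetical gt_path used only for error messages):

import numpy as np

gt = np.random.rand(128, 128, 3).astype(np.float32)   # HR image
lq = np.random.rand(32, 32, 3).astype(np.float32)     # LR image at scale 4
gt_patch, lq_patch = paired_random_crop(gt, lq, gt_patch_size=96, scale=4, gt_path="demo.png")
print(gt_patch.shape, lq_patch.shape)                 # (96, 96, 3) (24, 24, 3)
out = augment([gt_patch, lq_patch], hflip=True, rotation=True)
print([im.shape for im in out])                       # rot90 may swap H and W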

src/model/RRDB_Net.py (+84, -0)

@@ -0,0 +1,84 @@
import mindspore
import mindspore.nn as nn
from mindspore import ops
import numpy as np
import functools
# All components have been verified except the final full network.
def make_layer(block, n_layers):
layers = []
for _ in range(n_layers):
layers.append(block())
return nn.SequentialCell(*layers)
class ResidualDenseBlock_5C(nn.Cell):
def __init__(self, nf=64, gc=32, res_beta=0.2, bias=True):
super(ResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 = nn.Conv2d(nf, gc, 3, 1, padding=1, has_bias=bias, pad_mode="pad")
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, padding=1, has_bias=bias, pad_mode="pad")
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, padding=1, has_bias=bias, pad_mode="pad")
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, padding=1, has_bias=bias, pad_mode="pad")
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, padding=1, has_bias=bias, pad_mode="pad")
self.lrelu = nn.LeakyReLU(0.2)
self.res_beta = res_beta
self.cat = mindspore.ops.Concat(1)
def construct(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(self.cat((x, x1))))
x3 = self.lrelu(self.conv3(self.cat((x, x1, x2))))
x4 = self.lrelu(self.conv4(self.cat((x, x1, x2, x3))))
x5 = self.conv5(self.cat((x, x1, x2, x3, x4)))
return x5 * self.res_beta + x
class RRDB(nn.Cell):
'''Residual in Residual Dense Block'''
def __init__(self, nf, gc=32, res_beta=0.2):
super(RRDB, self).__init__()
self.res_beta = res_beta
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
def construct(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
return out * self.res_beta + x  # use the configured residual scale, not a hard-coded 0.2
class RRDBNet(nn.Cell):
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
super(RRDBNet, self).__init__()
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, padding=1, has_bias=True, pad_mode="pad",
bias_init="zeros", weight_init="normal")
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, padding=1, has_bias=True, pad_mode="pad")
#### upsampling
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, padding=1, has_bias=True, pad_mode="pad")
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, padding=1, has_bias=True, pad_mode="pad")
self.HRconv = nn.Conv2d(nf, nf, 3, 1, padding=1, has_bias=True, pad_mode="pad")
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, padding=1, has_bias=True, pad_mode="pad")
self.lrelu = nn.LeakyReLU(0.2)
self.shape = mindspore.ops.Shape()
def construct(self, x):
fea = self.conv_first(x)
trunk = self.trunk_conv(self.RRDB_trunk(fea))
fea = fea + trunk
fea_size = list(self.shape(fea))
fea = self.lrelu(self.upconv1(ops.ResizeNearestNeighbor((fea_size[2]*2, fea_size[3]*2), True)(fea)))
fea_size = list(self.shape(fea))
fea = self.lrelu(self.upconv2(ops.ResizeNearestNeighbor((fea_size[2]*2, fea_size[3]*2), True)(fea)))
out = self.conv_last(self.lrelu(self.HRconv(fea)))
return out
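A quick shape check for RRDBNet, assuming a working MindSpore install and this module's scope (nb is reduced to 2 here to keep it fast; the config above uses nb=23):

import numpy as np
from mindspore import Tensor

net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=2)
x = Tensor(np.random.rand(1, 3, 32, 32).astype(np.float32))
y = net(x)
print(y.shape)  # (1, 3, 128, 128): two nearest-neighbor 2x upsamples give 4x overall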

src/model/VGG.py (+95, -0)

@@ -0,0 +1,95 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""VGG."""
import mindspore.nn as nn
from mindspore.common.initializer import initializer
import mindspore.common.dtype as mstype
def _make_layer(base, batch_norm):
"""Make stage network of VGG."""
layers = []
in_channels = 3
for v in base:
if v == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
weight_shape = (v, in_channels, 3, 3)
weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor()
conv2d = nn.Conv2d(in_channels=in_channels,
out_channels=v,
kernel_size=3,
padding=0,
pad_mode='same',
weight_init=weight)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU()]
else:
layers += [conv2d, nn.ReLU()]
in_channels = v
return nn.SequentialCell(layers)
class Vgg(nn.Cell):
"""
VGG network definition.
Args:
base (list): Configuration for different layers, mainly the channel number of Conv layer.
num_classes (int): Class numbers. Default: 1000.
batch_norm (bool): Whether to do the batchnorm. Default: False.
batch_size (int): Batch size. Default: 1.
Returns:
Tensor, infer output tensor.
"""
def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1):
super(Vgg, self).__init__()
_ = batch_size
self.layers = _make_layer(base, batch_norm=batch_norm)
self.flatten = nn.Flatten()
self.classifier = nn.SequentialCell([
nn.Dense(512 * 7 * 7, 4096),
nn.ReLU(),
nn.Dense(4096, 4096),
nn.ReLU(),
nn.Dense(4096, num_classes)])
def construct(self, x):
x = self.layers(x)
x = self.flatten(x)
x = self.classifier(x)
return x
cfg = {
'11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
'19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
def vgg16(num_classes=1000):
net = Vgg(cfg['16'], num_classes=num_classes, batch_norm=False)
return net
def vgg19(num_classes=1000):
net = Vgg(cfg['19'], num_classes=num_classes, batch_norm=False)
return net
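Layer-count note: with batch_norm=False, cfg['16'] expands to 13 conv+ReLU pairs plus 5 pools = 31 cells, while the batch-norm variant has 13 * 3 + 5 = 44. PerceptualLoss in src/model/loss.py (below) slices layers[0:44], which only matches the batch-norm variant exactly; Python slicing silently clamps otherwise. A quick check, run in this module's scope:

net = vgg16()
print(len(list(net.layers)))        # 31 cells without batch norm
net_bn = Vgg(cfg['16'], batch_norm=True)
print(len(list(net_bn.layers)))     # 44 cells with batch norm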

src/model/__init__.py (+0, -0)


src/model/discriminator_net.py (+83, -0)

@@ -0,0 +1,83 @@
import mindspore.nn as nn
import mindspore
import numpy as np
from mindspore import Tensor
class VGGStyleDiscriminator512(nn.Cell):
"""VGG style discriminator with input size 512 x 512.
It is used to train SRGAN and ESRGAN.
Args:
num_in_ch (int): Channel number of inputs. Default: 3.
num_feat (int): Channel number of base intermediate features.
Default: 64.
"""
def __init__(self, num_in_ch, num_feat):
super(VGGStyleDiscriminator512, self).__init__()
self.conv0_0 = nn.Conv2d(num_in_ch, num_feat, 3, 1, padding=1, has_bias=True, pad_mode="pad")
self.conv0_1 = nn.Conv2d(num_feat, num_feat, 4, 2, padding=1, has_bias=False, pad_mode="pad")
self.bn0_1 = nn.BatchNorm2d(num_feat, affine=True)
self.conv1_0 = nn.Conv2d(num_feat, num_feat * 2, 3, 1, padding=1, has_bias=False, pad_mode="pad")
self.bn1_0 = nn.BatchNorm2d(num_feat * 2, affine=True)
self.conv1_1 = nn.Conv2d(
num_feat * 2, num_feat * 2, 4, 2, padding=1, has_bias=False, pad_mode="pad")
self.bn1_1 = nn.BatchNorm2d(num_feat * 2, affine=True)
self.conv2_0 = nn.Conv2d(
num_feat * 2, num_feat * 4, 3, 1, padding=1, has_bias=False, pad_mode="pad")
self.bn2_0 = nn.BatchNorm2d(num_feat * 4, affine=True)
self.conv2_1 = nn.Conv2d(
num_feat * 4, num_feat * 4, 4, 2, padding=1, has_bias=False, pad_mode="pad")
self.bn2_1 = nn.BatchNorm2d(num_feat * 4, affine=True)
self.conv3_0 = nn.Conv2d(
num_feat * 4, num_feat * 8, 3, 1, padding=1, has_bias=False, pad_mode="pad")
self.bn3_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
self.conv3_1 = nn.Conv2d(
num_feat * 8, num_feat * 8, 4, 2, padding=1, has_bias=False, pad_mode="pad")
self.bn3_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
self.conv4_0 = nn.Conv2d(
num_feat * 8, num_feat * 8, 3, 1, padding=1, has_bias=False, pad_mode="pad")
self.bn4_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
self.conv4_1 = nn.Conv2d(
num_feat * 8, num_feat * 8, 4, 2, padding=1, has_bias=False, pad_mode="pad")
self.bn4_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
# flatten size for a 512 input: 512 / 2**5 = 16, so num_feat * 8 * 16 * 16 (= num_feat * 8 * 4 * 64)
self.linear1 = nn.Dense(num_feat * 8 * 4 * 64, 100)
self.linear2 = nn.Dense(100, 1)
self.lrelu = nn.LeakyReLU(0.2)
self.flatten = nn.Flatten()
def construct(self, x):
feat = self.lrelu(self.conv0_0(x))
feat = self.lrelu(self.bn0_1(
self.conv0_1(feat)))  # output spatial size: (256, 256)
feat = self.lrelu(self.bn1_0(self.conv1_0(feat)))
feat = self.lrelu(self.bn1_1(
self.conv1_1(feat)))  # output spatial size: (128, 128)
feat = self.lrelu(self.bn2_0(self.conv2_0(feat)))
feat = self.lrelu(self.bn2_1(
self.conv2_1(feat)))  # output spatial size: (64, 64)
feat = self.lrelu(self.bn3_0(self.conv3_0(feat)))
feat = self.lrelu(self.bn3_1(
self.conv3_1(feat)))  # output spatial size: (32, 32)
feat = self.lrelu(self.bn4_0(self.conv4_0(feat)))
feat = self.lrelu(self.bn4_1(
self.conv4_1(feat)))  # output spatial size: (16, 16)
feat = self.flatten(feat)
feat = self.lrelu(self.linear1(feat))
out = self.linear2(feat)
return out
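A shape check for the discriminator, assuming a working MindSpore install (five stride-2 convolutions take 512 down to 16, matching the num_feat * 8 * 4 * 64 flatten size above):

import numpy as np
from mindspore import Tensor

d = VGGStyleDiscriminator512(num_in_ch=3, num_feat=64)
x = Tensor(np.random.rand(1, 3, 512, 512).astype(np.float32))
print(d(x).shape)  # (1, 1): one realness score per image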

src/model/loss.py (+56, -0)

@@ -0,0 +1,56 @@
import mindspore
from mindspore import nn as nn
from src.model.VGG import vgg16
import mindspore.ops.functional as F
class PerceptualLoss(nn.Cell):
# content (perceptual) loss
def __init__(self):
super(PerceptualLoss, self).__init__()
vgg = vgg16()
loss_network = nn.SequentialCell(*list(vgg.layers)[0:44])
for l in loss_network.layers:
l.requires_grad = False
self.loss_network = loss_network
self.l1_loss = nn.L1Loss()
def construct(self, high_resolution, fake_high_resolution):
# the input range is [0, 1] while VGG expects [0, 255]: images were scaled to
# [0, 1] at load time, so rescale them back before computing the content loss.
# 12.75 is the rescale factor for the VGG feature maps.
perception_loss = self.l1_loss(self.loss_network(high_resolution * 255.) / 12.75,
self.loss_network(fake_high_resolution * 255.) / 12.75)
return perception_loss
class DiscriminatorLoss(nn.Cell):
def __init__(self, gan_type='ragan'):
super(DiscriminatorLoss, self).__init__()
self.gan_type = gan_type
# ops must be instantiated before they can be called on tensors
self.cross_entropy = mindspore.ops.BinaryCrossEntropy()
self.sigmoid = mindspore.ops.Sigmoid()
self.reduce_mean = mindspore.ops.ReduceMean()
def construct(self, hr, sr):
# hr: discriminator output for the real high-resolution image; sr: for the generated image.
# BinaryCrossEntropy takes (probabilities, labels), so the sigmoid output comes first.
if self.gan_type == "ragan":
return 0.5 * (
self.cross_entropy(self.sigmoid(hr - self.reduce_mean(sr)), F.ones_like(hr)) +
self.cross_entropy(self.sigmoid(sr - self.reduce_mean(hr)), F.zeros_like(sr)))
elif self.gan_type == 'gan':
real_loss = self.cross_entropy(self.sigmoid(hr), F.ones_like(hr))
fake_loss = self.cross_entropy(self.sigmoid(sr), F.zeros_like(sr))
return real_loss + fake_loss
class GeneratorLoss(nn.Cell):
def __init__(self, gan_type='ragan'):
super(GeneratorLoss, self).__init__()
self.gan_type = gan_type
self.cross_entropy = mindspore.ops.BinaryCrossEntropy()
self.sigmoid = mindspore.ops.Sigmoid()
self.reduce_mean = mindspore.ops.ReduceMean()
def construct(self, hr, sr):
# hr: discriminator output for the real high-resolution image; sr: for the generated image
if self.gan_type == "ragan":
return 0.5 * (
self.cross_entropy(self.sigmoid(sr - self.reduce_mean(hr)), F.ones_like(sr)) +
self.cross_entropy(self.sigmoid(hr - self.reduce_mean(sr)), F.zeros_like(hr)))
elif self.gan_type == 'gan':
return self.cross_entropy(self.sigmoid(sr), F.ones_like(sr))
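A hypothetical usage sketch of the two GAN losses on random stand-in discriminator logits (after the instantiation fixes above; 'gan' mode for simplicity):

import numpy as np
from mindspore import Tensor

d_hr = Tensor(np.random.randn(4, 1).astype(np.float32))  # discriminator logits for real images
d_sr = Tensor(np.random.randn(4, 1).astype(np.float32))  # discriminator logits for fakes
print(DiscriminatorLoss(gan_type='gan')(d_hr, d_sr))
print(GeneratorLoss(gan_type='gan')(d_hr, d_sr))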

src/utils/__init__.py (+0, -0)


src/utils/eval_util.py (+326, -0)

@@ -0,0 +1,326 @@
import cv2
import yaml
import sys
import time
import numpy as np
from absl import logging
from src.utils.convert_tfrecord import load_tfrecord_dataset
def load_yaml(load_path):
"""load yaml file"""
with open(load_path, 'r') as f:
loaded = yaml.load(f, Loader=yaml.Loader)
return loaded
def load_dataset(cfg, key, shuffle=True, buffer_size=10240):
"""load dataset"""
dataset_cfg = cfg[key]
logging.info("load {} from {}".format(key, dataset_cfg['path']))
dataset = load_tfrecord_dataset(
tfrecord_name=dataset_cfg['path'],
batch_size=cfg['batch_size'],
gt_size=cfg['gt_size'],
scale=cfg['scale'],
shuffle=shuffle,
using_bin=dataset_cfg['using_bin'],
using_flip=dataset_cfg['using_flip'],
using_rot=dataset_cfg['using_rot'],
buffer_size=buffer_size)
return dataset
def create_lr_hr_pair(raw_img, scale=4):
# an integer scale keeps the floor-divided sizes usable as array indices
lr_h, lr_w = raw_img.shape[0] // scale, raw_img.shape[1] // scale
hr_h, hr_w = lr_h * scale, lr_w * scale
hr_img = raw_img[:hr_h, :hr_w, :]
lr_img = imresize_np(hr_img, 1 / scale)
return lr_img, hr_img
def tensor2img(tensor):
return (np.squeeze(tensor.numpy()).clip(0, 1) * 255).astype(np.uint8)
def change_weight(model, vars1, vars2, alpha=1.0):
for i, var in enumerate(model.trainable_variables):
var.assign((1 - alpha) * vars1[i] + alpha * vars2[i])
class ProgressBar(object):
"""A progress bar which can print the progress modified from
https://github.com/hellock/cvbase/blob/master/cvbase/progress.py"""
def __init__(self, task_num=0, completed=0, bar_width=25):
self.task_num = task_num
max_bar_width = self._get_max_bar_width()
self.bar_width = (bar_width
if bar_width <= max_bar_width else max_bar_width)
self.completed = completed
self.first_step = completed
self.warm_up = False
def _get_max_bar_width(self):
from shutil import get_terminal_size
terminal_width, _ = get_terminal_size()
max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
if max_bar_width < 10:
logging.info('terminal width is too small ({}), please consider '
'widening the terminal for better progressbar '
'visualization'.format(terminal_width))
max_bar_width = 10
return max_bar_width
def reset(self):
"""reset"""
self.completed = 0
def update(self, inf_str=''):
"""update"""
self.completed += 1
if not self.warm_up:
self.start_time = time.time() - 1e-2
self.warm_up = True
elapsed = time.time() - self.start_time
fps = (self.completed - self.first_step) / elapsed
percentage = self.completed / float(self.task_num)
mark_width = int(self.bar_width * percentage)
bar_chars = '>' * mark_width + ' ' * (self.bar_width - mark_width)
stdout_str = \
'\rTraining [{}] {}/{}, {} {:.1f} step/sec'
sys.stdout.write(stdout_str.format(
bar_chars, self.completed, self.task_num, inf_str, fps))
sys.stdout.flush()
###############################################################################
# This processing code is copied and modified from the official implementation:
# https://github.com/open-mmlab/mmsr
###############################################################################
def imresize_np(img, scale, antialiasing=True):
# Now the scale should be the same for H and W
# input: img: Numpy, HWC RGB [0,1]
# output: HWC RGB [0,1] w/o round
# (Modified from
# https://github.com/open-mmlab/mmsr/blob/master/codes/data/util.py)
in_H, in_W, in_C = img.shape
out_H, out_W = int(np.ceil(in_H * scale)), int(np.ceil(in_W * scale))
kernel_width = 4
kernel = 'cubic'
# Return the desired dimension order for performing the resize. The
# strategy is to perform the resize first along the dimension with the
# smallest scale factor.
# Now we do not support this.
# get weights and indices
weights_H, indices_H, sym_len_Hs, sym_len_He = _calculate_weights_indices(
in_H, out_H, scale, kernel, kernel_width, antialiasing)
weights_W, indices_W, sym_len_Ws, sym_len_We = _calculate_weights_indices(
in_W, out_W, scale, kernel, kernel_width, antialiasing)
# process H dimension
# symmetric copying
img_aug = np.zeros(((in_H + sym_len_Hs + sym_len_He), in_W, in_C))
img_aug[sym_len_Hs:sym_len_Hs + in_H] = img
sym_patch = img[:sym_len_Hs, :, :]
sym_patch_inv = sym_patch[::-1]
img_aug[0:sym_len_Hs] = sym_patch_inv
sym_patch = img[-sym_len_He:, :, :]
sym_patch_inv = sym_patch[::-1]
img_aug[sym_len_Hs + in_H:sym_len_Hs + in_H + sym_len_He] = sym_patch_inv
out_1 = np.zeros((out_H, in_W, in_C))
kernel_width = weights_H.shape[1]
for i in range(out_H):
idx = int(indices_H[i][0])
out_1[i, :, 0] = weights_H[i].dot(
img_aug[idx:idx + kernel_width, :, 0].transpose(0, 1))
out_1[i, :, 1] = weights_H[i].dot(
img_aug[idx:idx + kernel_width, :, 1].transpose(0, 1))
out_1[i, :, 2] = weights_H[i].dot(
img_aug[idx:idx + kernel_width, :, 2].transpose(0, 1))
# process W dimension
# symmetric copying
out_1_aug = np.zeros((out_H, in_W + sym_len_Ws + sym_len_We, in_C))
out_1_aug[:, sym_len_Ws:sym_len_Ws + in_W] = out_1
sym_patch = out_1[:, :sym_len_Ws, :]
sym_patch_inv = sym_patch[:, ::-1]
out_1_aug[:, 0:sym_len_Ws] = sym_patch_inv
sym_patch = out_1[:, -sym_len_We:, :]
sym_patch_inv = sym_patch[:, ::-1]
out_1_aug[:, sym_len_Ws + in_W:sym_len_Ws + in_W + sym_len_We] = \
sym_patch_inv
out_2 = np.zeros((out_H, out_W, in_C))
kernel_width = weights_W.shape[1]
for i in range(out_W):
idx = int(indices_W[i][0])
out_2[:, i, 0] = out_1_aug[:, idx:idx + kernel_width, 0].dot(
weights_W[i])
out_2[:, i, 1] = out_1_aug[:, idx:idx + kernel_width, 1].dot(
weights_W[i])
out_2[:, i, 2] = out_1_aug[:, idx:idx + kernel_width, 2].dot(
weights_W[i])
return out_2.clip(0, 255)
def _cubic(x):
absx = np.abs(x)
absx2 = absx ** 2
absx3 = absx ** 3
return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).astype(np.float64)) \
+ (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (
((absx > 1) * (absx <= 2)).astype(np.float64))
def _calculate_weights_indices(in_length, out_length, scale, kernel,
kernel_width, antialiasing):
if (scale < 1) and (antialiasing):
# Use a modified kernel to simultaneously interpolate and antialias
# larger kernel width
kernel_width = kernel_width / scale
# Output-space coordinates
x = np.linspace(1, out_length, out_length)
# Input-space coordinates. Calculate the inverse mapping such that 0.5
# in output space maps to 0.5 in input space, and 0.5+scale in output
# space maps to 1.5 in input space.
u = x / scale + 0.5 * (1 - 1 / scale)
# What is the left-most pixel that can be involved in the computation?
left = np.floor(u - kernel_width / 2)
# What is the maximum number of pixels that can be involved in the
# computation? Note: it's OK to use an extra pixel here; if the
# corresponding weights are all zero, it will be eliminated at the end
# of this function.
P = (np.ceil(kernel_width) + 2).astype(np.int32)
# The indices of the input pixels involved in computing the k-th output
# pixel are in row k of the indices matrix.
indices = left.reshape(int(out_length), 1).repeat(P, axis=1) + \
np.linspace(0, P - 1, P).reshape(1, int(P)).repeat(out_length, axis=0)
# The weights used to compute the k-th output pixel are in row k of the
# weights matrix.
distance_to_center = \
u.reshape(int(out_length), 1).repeat(P, axis=1) - indices
# apply cubic kernel
if (scale < 1) and (antialiasing):
weights = scale * _cubic(distance_to_center * scale)
else:
weights = _cubic(distance_to_center)
# Normalize the weights matrix so that each row sums to 1.
weights_sum = np.sum(weights, 1).reshape(int(out_length), 1)
weights = weights / weights_sum.repeat(P, axis=1)
# If a column in weights is all zero, get rid of it. only consider the
# first and last column.
weights_zero_tmp = np.sum((weights == 0), 0)
if not np.isclose(weights_zero_tmp[0], 0, rtol=1e-6):
indices = indices[:, 1:1 + int(P) - 2]
weights = weights[:, 1:1 + int(P) - 2]
if not np.isclose(weights_zero_tmp[-1], 0, rtol=1e-6):
indices = indices[:, 0:0 + int(P) - 2]
weights = weights[:, 0:0 + int(P) - 2]
weights = weights.copy()
indices = indices.copy()
sym_len_s = -indices.min() + 1
sym_len_e = indices.max() - in_length
indices = indices + sym_len_s - 1
return weights, indices, int(sym_len_s), int(sym_len_e)
def calculate_psnr(img1, img2):
# img1 and img2 have range [0, 255]
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
mse = np.mean((img1 - img2)**2)
if mse == 0:
return float('inf')
return 20 * np.log10(255.0 / np.sqrt(mse))
def _ssim(img1, img2):
C1 = (0.01 * 255)**2
C2 = (0.03 * 255)**2
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
mu1_sq = mu1**2
mu2_sq = mu2**2
mu1_mu2 = mu1 * mu2
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) \
/ ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
return ssim_map.mean()
def calculate_ssim(img1, img2):
'''calculate SSIM
the same outputs as MATLAB's
img1, img2: [0, 255]
'''
if not img1.shape == img2.shape:
raise ValueError('Input images must have the same dimensions.')
if img1.ndim == 2:
return _ssim(img1, img2)
elif img1.ndim == 3:
if img1.shape[2] == 3:
ssims = []
for i in range(3):
ssims.append(_ssim(img1[..., i], img2[..., i]))  # SSIM per channel, then averaged
return np.array(ssims).mean()
elif img1.shape[2] == 1:
return _ssim(np.squeeze(img1), np.squeeze(img2))
else:
raise ValueError('Wrong input image dimensions.')
def rgb2ycbcr(img, only_y=True):
"""Convert rgb to ycbcr
only_y: only return Y channel
Input:
uint8, [0, 255]
float, [0, 1]
"""
in_img_type = img.dtype
img = img.astype(np.float32)  # astype returns a new array; the result must be reassigned
if in_img_type != np.uint8:
img *= 255.
img = img[:, :, ::-1]  # RGB -> BGR, to match the BGR-ordered coefficients below
# convert
if only_y:
rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
else:
rlt = np.matmul(
img, [[24.966, 112.0, -18.214],
[128.553, -74.203, -93.786],
[65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
if in_img_type == np.uint8:
rlt = rlt.round()
else:
rlt /= 255.
return rlt.astype(in_img_type)
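A quick check of the two metrics above on synthetic uint8 images (random data, so the exact numbers vary run to run):

import numpy as np

img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
noise = np.random.randint(-5, 6, img.shape)
noisy = np.clip(img.astype(np.int16) + noise, 0, 255).astype(np.uint8)
print(calculate_psnr(img.astype(np.float64), noisy.astype(np.float64)))  # roughly 37 dB
print(calculate_ssim(img, noisy))                                        # close to 1.0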

src/utils/extract_subimages.py (+221, -0)

@@ -0,0 +1,221 @@
"""A multi-thread tool to crop large images to sub-images for faster IO.
(This preprocessing code is copied and modified from official implement:
https://github.com/open-mmlab/mmsr/tree/master/codes/data_scripts)"""
import os
import os.path as osp
import sys
from multiprocessing import Pool
import numpy as np
import cv2
from PIL import Image
import time
from shutil import get_terminal_size
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
def main():
mode = 'pair' # single (one input folder) | pair (extract corresponding GT and LR pairs)
opt = {}
opt['n_thread'] = 20
opt['compression_level'] = 3  # 3 is the default value in cv2
# CV_IMWRITE_PNG_COMPRESSION ranges from 0 to 9. A higher value means a smaller file size but a
# longer compression time. If raw images are read during training, use 0 for faster IO.
if mode == 'single':
opt['input_folder'] = './data/DIV2K/DIV2K_train_HR'
opt['save_folder'] = './data/DIV2K/DIV2K800_sub'
opt['crop_sz'] = 480 # the size of each sub-image
opt['step'] = 240 # step of the sliding crop window
opt['thres_sz'] = 48 # size threshold
extract_single(opt)
elif mode == 'pair':
GT_folder = './data/DIV2K/DIV2K_train_HR'
LR_folder = './data/DIV2K/DIV2K_train_LR_bicubic/X4'
save_GT_folder = './data/DIV2K/DIV2K800_sub'
save_LR_folder = './data/DIV2K/DIV2K800_sub_bicLRx4'
scale_ratio = 4
crop_sz = 480 # the size of each sub-image (GT)
step = 240 # step of the sliding crop window (GT)
thres_sz = 48 # size threshold
########################################################################
# check that all the GT and LR images have correct scale ratio
img_GT_list = _get_paths_from_images(GT_folder)
img_LR_list = _get_paths_from_images(LR_folder)
assert len(img_GT_list) == len(img_LR_list), 'different length of GT_folder and LR_folder.'
for path_GT, path_LR in zip(img_GT_list, img_LR_list):
img_GT = Image.open(path_GT)
img_LR = Image.open(path_LR)
w_GT, h_GT = img_GT.size
w_LR, h_LR = img_LR.size
assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X as LR width [{:d}] for {:s}.'.format(  # noqa: E501
w_GT, scale_ratio, w_LR, path_GT)
assert h_GT / h_LR == scale_ratio, 'GT height [{:d}] is not {:d}X as LR height [{:d}] for {:s}.'.format(  # noqa: E501
h_GT, scale_ratio, h_LR, path_GT)
# check crop size, step and threshold size
assert crop_sz % scale_ratio == 0, 'crop size is not a multiple of {:d}.'.format(
scale_ratio)
assert step % scale_ratio == 0, 'step is not a multiple of {:d}.'.format(scale_ratio)
assert thres_sz % scale_ratio == 0, 'thres_sz is not a multiple of {:d}.'.format(
scale_ratio)
print('process GT...')
opt['input_folder'] = GT_folder
opt['save_folder'] = save_GT_folder
opt['crop_sz'] = crop_sz
opt['step'] = step
opt['thres_sz'] = thres_sz
extract_single(opt)
print('process LR...')
opt['input_folder'] = LR_folder
opt['save_folder'] = save_LR_folder
opt['crop_sz'] = crop_sz // scale_ratio
opt['step'] = step // scale_ratio
opt['thres_sz'] = thres_sz // scale_ratio
extract_single(opt)
assert len(_get_paths_from_images(save_GT_folder)) == len(
_get_paths_from_images(
save_LR_folder)), 'different length of save_GT_folder and save_LR_folder.'
else:
raise ValueError('Wrong mode.')
def extract_single(opt):
input_folder = opt['input_folder']
save_folder = opt['save_folder']
if not osp.exists(save_folder):
os.makedirs(save_folder)
print('mkdir [{:s}] ...'.format(save_folder))
else:
print('Folder [{:s}] already exists. Exit...'.format(save_folder))
sys.exit(1)
img_list = _get_paths_from_images(input_folder)
def update(arg):
pbar.update(arg)
pbar = ProgressBar(len(img_list))
pool = Pool(opt['n_thread'])
for path in img_list:
pool.apply_async(worker, args=(path, opt), callback=update)
pool.close()
pool.join()
print('All subprocesses done.')
def worker(path, opt):
crop_sz = opt['crop_sz']
step = opt['step']
thres_sz = opt['thres_sz']
img_name = osp.basename(path)
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
n_channels = len(img.shape)
if n_channels == 2:
h, w = img.shape
elif n_channels == 3:
h, w, c = img.shape
else:
raise ValueError('Wrong image shape - {}'.format(n_channels))
h_space = np.arange(0, h - crop_sz + 1, step)
if h - (h_space[-1] + crop_sz) > thres_sz:
h_space = np.append(h_space, h - crop_sz)
w_space = np.arange(0, w - crop_sz + 1, step)
if w - (w_space[-1] + crop_sz) > thres_sz:
w_space = np.append(w_space, w - crop_sz)
index = 0
for x in h_space:
for y in w_space:
index += 1
if n_channels == 2:
crop_img = img[x:x + crop_sz, y:y + crop_sz]
else:
crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
crop_img = np.ascontiguousarray(crop_img)
cv2.imwrite(
osp.join(opt['save_folder'],
img_name.replace('.png', '_s{:03d}.png'.format(index))), crop_img,
[cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
return 'Processing {:s} ...'.format(img_name)
# ##############
# ### Utils ####
# ##############
class ProgressBar(object):
'''A progress bar which can print the progress
modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
'''
def __init__(self, task_num=0, bar_width=50, start=True):
self.task_num = task_num
max_bar_width = self._get_max_bar_width()
self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width)
self.completed = 0
if start:
self.start()
def _get_max_bar_width(self):
terminal_width, _ = get_terminal_size()
max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
if max_bar_width < 10:
print('terminal width is too small ({}), please consider widening the terminal for better '
'progressbar visualization'.format(terminal_width))
max_bar_width = 10
return max_bar_width
def start(self):
if self.task_num > 0:
sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format(
' ' * self.bar_width, self.task_num, 'Start...'))
else:
sys.stdout.write('completed: 0, elapsed: 0s')
sys.stdout.flush()
self.start_time = time.time()
def update(self, msg='In progress...'):
self.completed += 1
elapsed = time.time() - self.start_time + 1e-9
fps = self.completed / elapsed
if self.task_num > 0:
percentage = self.completed / float(self.task_num)
eta = int(elapsed * (1 - percentage) / percentage + 0.5)
mark_width = int(self.bar_width * percentage)
bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
sys.stdout.write('\033[2F') # cursor up 2 lines
sys.stdout.write('\033[J') # clean the output (remove extra chars since last display)
sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format(
bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg))
else:
sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
self.completed, int(elapsed + 0.5), fps))
sys.stdout.flush()
# ###################
# ### Data Utils ####
# ###################
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def _get_paths_from_images(path):
"""get image path list from image folder"""
assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
images = []
for dirpath, _, fnames in sorted(os.walk(path)):
for fname in sorted(fnames):
if is_image_file(fname):
img_path = os.path.join(dirpath, fname)
images.append(img_path)
assert images, '{:s} has no valid image file'.format(path)
return images
if __name__ == '__main__':
main()

src/utils/loss.py (+81, -0)

@@ -0,0 +1,81 @@
import mindspore
from mindspore import nn as nn
from src.model.VGG import vgg16
import mindspore.ops.functional as F
class PixelLoss(nn.Cell):
# pixel-wise loss
def __init__(self, criterion='l1'):
super(PixelLoss, self).__init__()
if criterion == 'l1':
self.loss = nn.L1Loss()
elif criterion == 'l2':
self.loss = nn.MSELoss()
else:
raise NotImplementedError('Pixel loss type {} is not recognized.'.format(criterion))
def construct(self, hr, sr):
return self.loss(hr, sr)
class PerceptualLoss(nn.Cell):
# content (perceptual) loss
def __init__(self):
super(PerceptualLoss, self).__init__()
vgg = vgg16()
loss_network = nn.SequentialCell(*list(vgg.layers)[0:44])
for l in loss_network.layers:
l.requires_grad = False
self.loss_network = loss_network
self.l1_loss = nn.L1Loss()
def construct(self, high_resolution, fake_high_resolution):
# the input range is [0, 1] while VGG expects [0, 255];
# 12.75 is the rescale factor for the VGG feature maps.
perception_loss = self.l1_loss(self.loss_network(high_resolution * 255.) / 12.75,
self.loss_network(fake_high_resolution * 255.) / 12.75)
return perception_loss
def DiscriminatorLoss(gan_type='ragan'):
"""discriminator loss"""
# ops must be instantiated once, then called on tensors
cross_entropy = mindspore.ops.BinaryCrossEntropy()
sigmoid = mindspore.ops.Sigmoid()
reduce_mean = mindspore.ops.ReduceMean()
def discriminator_loss_ragan(hr, sr):
return 0.5 * (
cross_entropy(sigmoid(hr - reduce_mean(sr)), F.ones_like(hr)) +
cross_entropy(sigmoid(sr - reduce_mean(hr)), F.zeros_like(sr)))
def discriminator_loss(hr, sr):
real_loss = cross_entropy(sigmoid(hr), F.ones_like(hr))
fake_loss = cross_entropy(sigmoid(sr), F.zeros_like(sr))
return real_loss + fake_loss
if gan_type == 'ragan':
return discriminator_loss_ragan
elif gan_type == 'gan':
return discriminator_loss
else:
raise NotImplementedError(
'Discriminator loss type {} is not recognized.'.format(gan_type))
def GeneratorLoss(gan_type='ragan'):
"""generator loss"""
cross_entropy = mindspore.ops.BinaryCrossEntropy()
sigmoid = mindspore.ops.Sigmoid()
reduce_mean = mindspore.ops.ReduceMean()
def generator_loss_ragan(hr, sr):
return 0.5 * (
cross_entropy(sigmoid(sr - reduce_mean(hr)), F.ones_like(sr)) +
cross_entropy(sigmoid(hr - reduce_mean(sr)), F.zeros_like(hr)))
def generator_loss(hr, sr):
return cross_entropy(sigmoid(sr), F.ones_like(sr))
if gan_type == 'ragan':
return generator_loss_ragan
elif gan_type == 'gan':
return generator_loss
else:
raise NotImplementedError(
'Generator loss type {} is not recognized.'.format(gan_type))

src/utils/matlab_functions.py (+192, -0)

@@ -0,0 +1,192 @@
import numpy as np
def rgb2ycbcr(img, y_only=False):
"""Convert a RGB image to YCbCr image.
This function produces the same results as Matlab's `rgb2ycbcr` function.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
y_only (bool): Whether to only return Y channel. Default: False.
Returns:
ndarray: The converted YCbCr image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img)
if y_only:
out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
else:
out_img = np.matmul(
img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
[24.966, 112.0, -18.214]]) + [16, 128, 128]
out_img = _convert_output_type_range(out_img, img_type)
return out_img
def bgr2ycbcr(img, y_only=False):
"""Convert a BGR image to YCbCr image.
The bgr version of rgb2ycbcr.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
y_only (bool): Whether to only return Y channel. Default: False.
Returns:
ndarray: The converted YCbCr image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img)
if y_only:
out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
else:
out_img = np.matmul(
img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
[65.481, -37.797, 112.0]]) + [16, 128, 128]
out_img = _convert_output_type_range(out_img, img_type)
return out_img
def ycbcr2rgb(img):
"""Convert a YCbCr image to RGB image.
This function produces the same results as Matlab's ycbcr2rgb function.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
Returns:
ndarray: The converted RGB image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img) * 255
out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
[0, -0.00153632, 0.00791071],
[0.00625893, -0.00318811, 0]]) * 255.0 + [
-222.921, 135.576, -276.836
] # noqa: E126
out_img = _convert_output_type_range(out_img, img_type)
return out_img
def ycbcr2bgr(img):
"""Convert a YCbCr image to BGR image.
The bgr version of ycbcr2rgb.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
Returns:
ndarray: The converted BGR image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img) * 255
out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
[0.00791071, -0.00153632, 0],
[0, -0.00318811, 0.00625893]]) * 255.0 + [
-276.836, 135.576, -222.921
] # noqa: E126
out_img = _convert_output_type_range(out_img, img_type)
return out_img
def _convert_input_type_range(img):
"""Convert the type and range of the input image.
It converts the input image to np.float32 type and range of [0, 1].
It is mainly used for pre-processing the input image in colorspace
conversion functions such as rgb2ycbcr and ycbcr2rgb.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
Returns:
(ndarray): The converted image with type of np.float32 and range of
[0, 1].
"""
img_type = img.dtype
img = img.astype(np.float32)
if img_type == np.float32:
pass
elif img_type == np.uint8:
img /= 255.
else:
raise TypeError('The img type should be np.float32 or np.uint8, '
f'but got {img_type}')
return img
def _convert_output_type_range(img, dst_type):
"""Convert the type and range of the image according to dst_type.
It converts the image to desired type and range. If `dst_type` is np.uint8,
images will be converted to np.uint8 type with range [0, 255]. If
`dst_type` is np.float32, it converts the image to np.float32 type with
range [0, 1].
It is mainly used for post-processing images in colorspace conversion
functions such as rgb2ycbcr and ycbcr2rgb.
Args:
img (ndarray): The image to be converted with np.float32 type and
range [0, 255].
dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
converts the image to np.uint8 type with range [0, 255]. If
dst_type is np.float32, it converts the image to np.float32 type
with range [0, 1].
Returns:
(ndarray): The converted image with desired type and range.
"""
if dst_type not in (np.uint8, np.float32):
raise TypeError('The dst_type should be np.float32 or np.uint8, '
f'but got {dst_type}')
if dst_type == np.uint8:
img = img.round()
else:
img /= 255.
return img.astype(dst_type)
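The two helpers enforce a simple contract, sketched below (editorial addition): uint8 inputs travel internally as float32 in [0, 1] and are restored to their original type and range on the way out.

import numpy as np
f = _convert_input_type_range(np.array([0, 128, 255], dtype=np.uint8))
print(f)  # approx [0., 0.502, 1.], dtype float32
back = _convert_output_type_range(f * 255., np.uint8)
print(back)  # [  0 128 255] again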

+ 47
- 0
src/utils/metric_util.py View File

@@ -0,0 +1,47 @@
import numpy as np
from src.utils.matlab_functions import bgr2ycbcr
def reorder_image(img, input_order='HWC'):
"""Reorder images to 'HWC' order.
    If the input image has shape (h, w), return (h, w, 1);
    if input_order is 'CHW', i.e. shape (c, h, w), return (h, w, c);
    if input_order is 'HWC', i.e. shape (h, w, c), return it as is.
Args:
img (ndarray): Input image.
input_order (str): Whether the input order is 'HWC' or 'CHW'.
            If the input image shape is (h, w), input_order has no
            effect. Default: 'HWC'.
Returns:
ndarray: reordered image.
"""
if input_order not in ['HWC', 'CHW']:
raise ValueError(
f'Wrong input_order {input_order}. Supported input_orders are '
"'HWC' and 'CHW'")
if len(img.shape) == 2:
img = img[..., None]
if input_order == 'CHW':
img = img.transpose(1, 2, 0)
return img
def to_y_channel(img):
"""Change to Y channel of YCbCr.
Args:
img (ndarray): Images with range [0, 255].
Returns:
        (ndarray): Images with range [0, 255] (float type) without rounding.
"""
img = img.astype(np.float32) / 255.
if img.ndim == 3 and img.shape[2] == 3:
img = bgr2ycbcr(img, y_only=True)
img = img[..., None]
return img * 255.
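A short usage sketch (editorial addition) showing how the two helpers are typically chained before computing a metric:

import numpy as np
chw = np.random.randint(0, 256, (3, 32, 32)).astype(np.float32)
hwc = reorder_image(chw, input_order='CHW')  # -> (32, 32, 3)
y = to_y_channel(hwc)  # -> (32, 32, 1), float on the [0, 255] scale
print(hwc.shape, y.shape)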

+ 205
- 0
src/utils/niqe.py View File

@@ -0,0 +1,205 @@
import cv2
import math
import numpy as np
from scipy.ndimage import convolve  # scipy.ndimage.filters is deprecated
from scipy.special import gamma
from src.utils.metric_util import reorder_image, to_y_channel
def estimate_aggd_param(block):
"""Estimate AGGD (Asymmetric Generalized Gaussian Distribution) paramters.
Args:
block (ndarray): 2D Image block.
Returns:
tuple: alpha (float), beta_l (float) and beta_r (float) for the AGGD
            distribution (estimating the parameters in Equation 7 of the paper).
"""
block = block.flatten()
gam = np.arange(0.2, 10.001, 0.001) # len = 9801
gam_reciprocal = np.reciprocal(gam)
r_gam = np.square(gamma(gam_reciprocal * 2)) / (
gamma(gam_reciprocal) * gamma(gam_reciprocal * 3))
left_std = np.sqrt(np.mean(block[block < 0]**2))
right_std = np.sqrt(np.mean(block[block > 0]**2))
gammahat = left_std / right_std
rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2)
rhatnorm = (rhat * (gammahat**3 + 1) *
(gammahat + 1)) / ((gammahat**2 + 1)**2)
array_position = np.argmin((r_gam - rhatnorm)**2)
alpha = gam[array_position]
beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
return (alpha, beta_l, beta_r)
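A sanity-check sketch (editorial addition): for zero-mean Gaussian samples the AGGD fit should recover a shape parameter near 2 and symmetric scale parameters.

import numpy as np
rng = np.random.default_rng(0)
alpha, beta_l, beta_r = estimate_aggd_param(rng.normal(0.0, 1.0, (96, 96)))
# expect alpha close to 2 and beta_l ~ beta_r ~ sqrt(2) for unit variance
print(alpha, beta_l, beta_r)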
def compute_feature(block):
"""Compute features.
Args:
block (ndarray): 2D Image block.
Returns:
list: Features with length of 18.
"""
feat = []
alpha, beta_l, beta_r = estimate_aggd_param(block)
feat.extend([alpha, (beta_l + beta_r) / 2])
# distortions disturb the fairly regular structure of natural images.
# This deviation can be captured by analyzing the sample distribution of
# the products of pairs of adjacent coefficients computed along
# horizontal, vertical and diagonal orientations.
shifts = [[0, 1], [1, 0], [1, 1], [1, -1]]
for shift in shifts:
    shifted_block = np.roll(block, shift, axis=(0, 1))
    alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block)
# Eq. 8
mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha))
feat.extend([alpha, mean, beta_l, beta_r])
return feat
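The feature length follows directly from the construction: 2 features from the block itself plus 4 features for each of the 4 product orientations, i.e. 2 + 4 * 4 = 18. A quick check (editorial addition):

import numpy as np
block = np.random.default_rng(1).normal(size=(48, 48))
assert len(compute_feature(block)) == 18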
def niqe(img,
mu_pris_param,
cov_pris_param,
gaussian_window,
block_size_h=96,
block_size_w=96):
"""Calculate NIQE (Natural Image Quality Evaluator) metric.
Ref: Making a "Completely Blind" Image Quality Analyzer.
    This implementation produces almost the same results as the official
    MATLAB code: http://live.ece.utexas.edu/research/quality/niqe_release.zip
Note that we do not include block overlap height and width, since they are
always 0 in the official implementation.
    For good performance, the official implementation advises dividing the
    distorted image into patches of the same size as those used for the
    construction of the multivariate Gaussian model.
Args:
img (ndarray): Input image whose quality needs to be computed. The
image must be a gray or Y (of YCbCr) image with shape (h, w).
Range [0, 255] with float type.
mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian
model calculated on the pristine dataset.
cov_pris_param (ndarray): Covariance of a pre-defined multivariate
Gaussian model calculated on the pristine dataset.
gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing the
image.
        block_size_h (int): Height of the blocks into which the image is divided.
Default: 96 (the official recommended value).
        block_size_w (int): Width of the blocks into which the image is divided.
Default: 96 (the official recommended value).
"""
assert img.ndim == 2, (
'Input image must be a gray or Y (of YCbCr) image with shape (h, w).')
# crop image
h, w = img.shape
num_block_h = math.floor(h / block_size_h)
num_block_w = math.floor(w / block_size_w)
img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w]
distparam = []  # collects the multiscale NIQE features
for scale in (1, 2): # perform on two scales (1, 2)
mu = convolve(img, gaussian_window, mode='nearest')
sigma = np.sqrt(
np.abs(
convolve(np.square(img), gaussian_window, mode='nearest') -
np.square(mu)))
# normalize, as in Eq. 1 in the paper
img_normalized = (img - mu) / (sigma + 1)
feat = []
for idx_w in range(num_block_w):
for idx_h in range(num_block_h):
# process each block
block = img_normalized[idx_h * block_size_h //
                       scale:(idx_h + 1) * block_size_h //
                       scale, idx_w * block_size_w //
                       scale:(idx_w + 1) * block_size_w //
                       scale]
feat.append(compute_feature(block))
distparam.append(np.array(feat))
# TODO: use MATLAB-style bicubic downsampling with anti-aliasing.
# For simplicity, we use OpenCV's bilinear resize instead, which
# results in a slight difference from the official scores.
if scale == 1:
h, w = img.shape
img = cv2.resize(
img / 255., (w // 2, h // 2), interpolation=cv2.INTER_LINEAR)
img = img * 255.
distparam = np.concatenate(distparam, axis=1)
# fit a MVG (multivariate Gaussian) model to distorted patch features
mu_distparam = np.nanmean(distparam, axis=0)
# use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html
distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)]
cov_distparam = np.cov(distparam_no_nan, rowvar=False)
# compute niqe quality, Eq. 10 in the paper
invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2)
quality = np.matmul(
np.matmul((mu_pris_param - mu_distparam), invcov_param),
np.transpose((mu_pris_param - mu_distparam)))
quality = np.sqrt(quality)
return quality
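An end-to-end sketch (editorial addition) of calling niqe directly. The key names mu_pris_param, cov_pris_param and gaussian_window are an assumption based on the BasicSR reference implementation; they should match whatever the bundled src/utils/niqe_pris_params.npz actually stores.

import numpy as np
params = np.load('src/utils/niqe_pris_params.npz')
gray = np.random.rand(256, 256) * 255.  # stand-in for a real gray/Y image
score = niqe(gray, params['mu_pris_param'], params['cov_pris_param'],
             params['gaussian_window'])
print(score)  # lower is better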
def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y'):
"""Calculate NIQE (Natural Image Quality Evaluator) metric.
Ref: Making a "Completely Blind" Image Quality Analyzer.
    This implementation produces almost the same results as the official
    MATLAB code: http://live.ece.utexas.edu/research/quality/niqe_release.zip
We use the official params estimated from the pristine dataset.
We use the recommended block size (96, 96) without overlaps.
Args:
img (ndarray): Input image whose quality needs to be computed.
The input image must be in range [0, 255] with float/int type.
            The input_order of the image can be 'HW', 'HWC' or 'CHW'
            (BGR channel order). If the input order is 'HWC' or 'CHW', the
            image is converted to a gray or Y (of YCbCr) image according to
            the ``convert_to`` argument.
crop_border (int): Cropped pixels in each edge of an image. These
pixels are not involved in the metric calculation.