|
- # -*- coding: utf-8 -*-
- """
- @Author : zhwzhong
- @License : (C) Copyright 2013-2018, hit
- @Contact : zhwzhong@hit.edu.cn
- @Software: PyCharm
- @File : utils.py
- @Time : 2022/2/20 14:07
- @Desc :
- """
- import os
- # import cv2
- import time
- import yaml
- import torch
- import shutil
- import datetime
- import itertools
- import numpy as np
- from skimage import measure
- from collections import Iterable
- import torch.distributed as dist
- from image_resize import imresize
- from skimage.color import rgb2ycbcr
- from collections import defaultdict, deque
- import torchvision.transforms.functional as TF
-
class SmoothedValue(object):
    """Track a sliding window of scalar values plus running global totals.

    Exposes window statistics (median/avg/max/last value) and a global
    average over everything ever fed in via ``update``.
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        """Record ``value`` observed ``n`` times (only once in the window)."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """Sum count/total across distributed ranks.

        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        stats = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(stats)
        count, total = stats.tolist()
        self.count = int(count)
        self.total = total

    @property
    def median(self):
        """Median of the values currently in the window."""
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window."""
        return torch.tensor(list(self.deque), dtype=torch.float32).mean().item()

    @property
    def global_avg(self):
        """Average over every value ever recorded, weighted by ``n``."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
-
-
class MetricLogger:
    """Collects named SmoothedValue meters and pretty-prints training progress."""

    def __init__(self, delimiter="\t"):
        # Unknown meter names are created lazily with default SmoothedValue settings.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed scalar values (float/int, or 0-d tensors which are unwrapped) into the named meters."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Convenience: `logger.loss` resolves to `logger.meters['loss']`.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

    def __str__(self):
        # One "name: <meter>" entry per meter, joined by the delimiter.
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(f"{name}: {str(meter)}")
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Synchronize every meter's count/total across distributed ranks."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register a pre-configured meter (e.g. one with a custom fmt)."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Generator that yields items from ``iterable`` while printing progress
        (ETA, meters, per-batch timing, GPU memory) every ``print_freq`` steps.

        NOTE(review): ``len(iterable)`` is used for ETA/width, so the iterable
        must be sized (e.g. a DataLoader).
        """
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        # Pad the iteration counter to the decimal width of len(iterable).
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        if torch.cuda.is_available():
            log_msg = self.delimiter.join(
                [
                    header,
                    "[{0" + space_fmt + "}/{1}]",
                    "eta: {eta}",
                    "{meters}",
                    "time: {time}",
                    "data: {data}",
                    "max mem: {memory:.0f}",
                ]
            )
        else:
            log_msg = self.delimiter.join(
                [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
            )
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)  # estimated remaining time
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),  # time of the current batch
                            data=str(data_time),  # data-loading time of the current batch
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print(
                        log_msg.format(
                            i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
                        )
                    )
            i += 1
            end = time.time()
        # total_time = time.time() - start_time
        # total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # print(f"{header} Total time: {total_time_str}")
-
-
def setup_for_distributed(is_master):
    """Replace the builtin ``print`` so only the master process emits output.

    Non-master ranks stay silent unless the caller passes ``force=True``.
    """
    import builtins
    original_print = builtins.print

    def gated_print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if force or is_master:
            original_print(*args, **kwargs)

    builtins.print = gated_print
-
-
def is_dist_avail_and_initialized():
    """True only when torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()
-
-
def get_world_size():
    """World size of the default process group, or 1 when not distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
-
-
def get_rank():
    """Rank of this process in the default group, or 0 when not distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
-
-
def is_main_process():
    """True on rank 0, and also when not running distributed at all."""
    rank = get_rank()
    return rank == 0
-
-
def save_on_master(*args, **kwargs):
    """``torch.save`` that is a no-op on every process except the master."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
-
-
def init_distributed_mode(args):
    """Initialize torch.distributed from whichever launcher environment is
    present (OpenMPI, torchrun/torch.distributed.launch, SLURM), filling in
    ``args.rank`` / ``args.world_size`` / ``args.gpu`` / ``args.distributed``.

    Falls back to single-process mode (args.distributed = False) when no
    launcher environment variables are detected.
    """
    if 'OMPI_COMM_WORLD_RANK' in os.environ:
        # Launched through OpenMPI (mpirun).
        args.rank = int(os.environ.get('OMPI_COMM_WORLD_RANK'))
        args.world_size = int(os.environ.get('OMPI_COMM_WORLD_SIZE'))
        args.gpu = args.rank % torch.cuda.device_count()
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # Launched through torchrun / torch.distributed.launch.
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # Launched through SLURM.
        # NOTE(review): this branch never sets args.world_size; unless the
        # caller pre-populated it (or WORLD_SIZE is also exported), the
        # init_process_group call below will fail — confirm against the
        # launch scripts.
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True
    if 'WORLD_SIZE' in os.environ:
        # NOTE(review): stored under num_gpus rather than world_size — verify
        # consumers expect that name.
        args.num_gpus = int(os.environ['WORLD_SIZE'])
    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    # Wait for every rank before silencing non-master printing.
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
-
def set_seed(seed=42):
    """Seed Python's, NumPy's and PyTorch's RNGs for reproducibility.

    Args:
        seed (int): seed applied to every generator. Defaults to 42.
    """
    import random  # local import: `random` is not imported at module level
    random.seed(seed)  # fix: the builtin RNG was previously left unseeded
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():  # GPU operations have separate seeds
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
-
def update_config_from_file(args):
    """Merge a YAML config file into an argparse namespace's options.

    Values from the file at ``args.config`` override the command-line values
    (the namespace's ``__dict__`` is mutated in place).

    Returns:
        dict: the merged option dictionary (note: a dict, not a Namespace).
    """
    opt = vars(args)
    # fix: the original leaked the file handle (open() without close)
    with open(args.config) as f:
        overrides = yaml.load(f, Loader=yaml.FullLoader)
    opt.update(overrides)
    return opt
-
-
def to_device(sample, device):
    """Move every tensor entry of ``sample`` to ``device`` (mutates in place).

    The non-tensor bookkeeping keys 'img_name' and 'pad' are left untouched.
    """
    skip = ('img_name', 'pad')
    for key in sample:
        if key not in skip:
            sample[key] = sample[key].to(device, non_blocking=True)
    return sample
-
-
-
-
-
def set_checkpoint_dir(args):
    """Reset the log and checkpoint directories for a fresh training run.

    Returns False (doing nothing) when testing or resuming, so existing
    checkpoints are preserved.

    NOTE(review): ``create_dir`` is neither defined nor imported in this
    module — presumably a helper that wipes/creates the directory; confirm
    it exists at runtime.
    """
    if args.test_only or args.resume:
        return False
    print('Removing Previous Checkpoints and Get New Checkpoints Dir')
    suffix = '{}/{}'.format(args.dataset, args.file_name)
    create_dir('./logs/' + suffix)
    create_dir('./checkpoints/' + suffix)
-
def clever_format(nums, format="%.2f"):
    """Human-readable formatting of large counts with T/G/M/K/B suffixes.

    Args:
        nums: a single number or an iterable of numbers.
        format: printf-style float format applied to the scaled value
            (name kept for backward compatibility although it shadows the
            builtin).

    Returns:
        One formatted string when a single number was given, otherwise a
        tuple of formatted strings.
    """
    # fix: ``from collections import Iterable`` (used at module level) was
    # removed in Python 3.10; import from the supported ``collections.abc``.
    from collections.abc import Iterable

    if not isinstance(nums, Iterable):
        nums = [nums]

    scales = ((1e12, "T"), (1e9, "G"), (1e6, "M"), (1e3, "K"))
    clever_nums = []
    for num in nums:
        for scale, suffix in scales:
            if num > scale:
                clever_nums.append(format % (num / scale) + suffix)
                break
        else:
            clever_nums.append(format % num + "B")  # "B" = bare (unscaled) count

    return clever_nums[0] if len(clever_nums) == 1 else tuple(clever_nums)
-
def get_parameter_number(net):
    """Count total and trainable parameters of ``net``, clever-formatted.

    Returns:
        dict with 'Total' and 'Trainable' formatted counts (e.g. '1.50M').
    """
    counts = [
        sum(p.numel() for p in net.parameters()),
        sum(p.numel() for p in net.parameters() if p.requires_grad),
    ]
    total_num, trainable_num = clever_format(counts)
    return {'Total': total_num, 'Trainable': trainable_num}
-
- # @torch.no_grad()
- # def rgb2ycbcr(rgb):
- #
- # r, g, b = rgb[:, 0: 1, :, :], rgb[:, 1: 2, :, :], rgb[:, 2: 3, :, :]
- # gray = 0.257 * r + 0.504 * g + 0.098 * b + 16
- #
- # return gray
-
@torch.no_grad()
def torch_psnr(img1, img2, border, data_range):
    """PSNR between two (B, C, H, W) tensors, computed on a 0-255 scale.

    Args:
        img1, img2: tensors with values in [0, data_range].
        border: number of edge pixels cropped from each side before comparing.
        data_range: nominal maximum value of the inputs.

    Returns:
        float: PSNR in dB; ``inf`` when the (cropped) images are identical.
    """
    if border != 0:
        img1 = img1[:, :, border: -border, border: -border]
        img2 = img2[:, :, border: -border, border: -border]

    # Rescale both images to a 255 peak so PSNR is comparable across ranges.
    scale = 255 / data_range
    img1, img2 = img1 * scale, img2 * scale

    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    # fix: the original used the in-place variant ``torch.log10_`` on a
    # temporary (not a stable public API) and returned a 0-d numpy array
    # here but a float in the mse==0 branch; return a plain float always.
    return 20 * torch.log10(255.0 / torch.sqrt(mse)).item()
-
- # def down_sample(img, scale, method='cv2'):
- # noisy_image = img + np.random.normal(0, 10 ** 0.5, img.shape)
- # cv2.normalize(noisy_image, noisy_image, 0, 255, cv2.NORM_MINMAX, dtype=-1)
- # noisy_image = noisy_image.astype(np.uint8)
- # return cv2.resize(noisy_image, fx=scale, fy=scale, dsize=None) if method == 'cv2' else imresize(noisy_image, scalar_scale=scale)
-
-
def tensor2uint(*args, data_range):
    """Convert tensors with values in [0, data_range] to uint8 numpy images.

    CHW tensors become HWC arrays. Returns a single array when one tensor is
    given, otherwise a list of arrays.
    """
    def _convert(tensor):
        arr = tensor.data.squeeze().float().clamp_(0, data_range).cpu().numpy()
        if arr.ndim == 3:  # CHW -> HWC for image libraries
            arr = np.transpose(arr, (1, 2, 0))
        return np.uint8((arr * (255.0 / data_range)).round())

    converted = [_convert(t) for t in args]
    if len(converted) == 1:
        return converted[0]
    return converted
-
-
def calc_psnr(img1, img2):
    """PSNR between two uint8 images (data_range=255).

    fix: ``skimage.measure.compare_psnr`` was removed in scikit-image 0.18;
    prefer the modern ``skimage.metrics.peak_signal_noise_ratio`` and fall
    back to the legacy name on old installs.
    """
    a, b = np.array(img1), np.array(img2)
    try:
        from skimage.metrics import peak_signal_noise_ratio
    except ImportError:
        peak_signal_noise_ratio = measure.compare_psnr
    return peak_signal_noise_ratio(a, b, data_range=255)
-
def calc_psnr_y(img1, img2):
    """PSNR between the Y (luma) channels of two RGB images (data_range=255).

    fix: ``skimage.measure.compare_psnr`` was removed in scikit-image 0.18;
    prefer ``skimage.metrics.peak_signal_noise_ratio`` with a legacy fallback.
    """
    y1 = np.array(rgb2ycbcr(img1))[:, :, 0]
    y2 = np.array(rgb2ycbcr(img2))[:, :, 0]
    try:
        from skimage.metrics import peak_signal_noise_ratio
    except ImportError:
        peak_signal_noise_ratio = measure.compare_psnr
    return peak_signal_noise_ratio(y1, y2, data_range=255)
-
def calc_metrics(img_out, img_gt, args):
    """Compute evaluation metrics for one batch, dispatched on args.dataset.

    PBVS: tensor-domain PSNR via torch_psnr.
    SSR: per-sample PSNR on uint8 images; 6-channel tensors are treated as a
    stereo pair (left/right RGB views) and the two views are averaged.
    otherwise: PSNR + SSIM directly on the given arrays.

    NOTE(review): measure.compare_ssim/compare_psnr were removed in
    scikit-image 0.18, so this path requires an old scikit-image.
    NOTE(review): the default branch uses img_out/img_gt as numpy-compatible
    uint8 images while the commented-out tensor2uint call suggests conversion
    once happened here — confirm what callers pass.
    """
    # img_out, img_gt = tensor2uint(img_out, img_gt, data_range=args.data_range)
    # print(img_out.shape, img_gt.shape)
    metrics = 0
    if args.dataset == 'PBVS':
        metrics = {'PSNR': torch_psnr(img_out, img_gt, data_range=args.data_range, border=0)}
    elif args.dataset == 'SSR':
        psnr_y, psnr_rgb = [], []
        for bi in range(img_out.size(0)):
            if img_gt.size(1) == 6:
                # Stereo pair: channels 0-2 are the left view, 3-5 the right.
                left_out_img, left_gt_img = tensor2uint(img_out[bi, 0: 3], img_gt[bi, 0: 3], data_range=args.data_range)
                right_out_img, right_gt_img = tensor2uint(img_out[bi, 3: 6], img_gt[bi, 3: 6], data_range=args.data_range)
                psnr_rgb.append((calc_psnr(left_out_img, left_gt_img) + calc_psnr(right_out_img, right_gt_img)) / 2)
                psnr_y.append((calc_psnr_y(left_out_img, left_gt_img) + calc_psnr_y(right_out_img, right_gt_img)) / 2)
                # Recomputed every iteration; the final assignment covers the batch.
                metrics = {'PSNR': np.mean(psnr_y), 'PSNR_RGB': np.mean(psnr_rgb)}
            else:
                out_img, gt_img = tensor2uint(img_out[bi], img_gt[bi], data_range=args.data_range)
                # PSNR averaged over the first two channels only — presumably
                # the dataset's two relevant bands; TODO confirm.
                psnr_y.append(calc_psnr(out_img[:, :, 0], gt_img[:, :, 0]) + calc_psnr(out_img[:, :, 1], gt_img[:, :, 1]))
                metrics = {'PSNR': np.mean(psnr_y) / 2}
    else:
        ssim = measure.compare_ssim(img_out, img_gt, data_range=255, multichannel=(img_out.ndim == 3))
        metrics = {'PSNR': calc_psnr(img_out, img_gt), 'SSIM': ssim}
    return metrics
-
-
def mix_up(samples, alpha, prob=0.7):
    """MixUp augmentation applied jointly to every tensor in ``samples``.

    With probability ``prob`` (and only when alpha > 0), draws
    lam ~ Beta(alpha, alpha) and blends each tensor with a batch-shuffled
    copy of itself. 'img_name' and 'pad' entries are skipped. Mutates and
    returns ``samples``.
    """
    gt_img = samples['gt_img']

    if np.random.rand(1) < prob and alpha > 0:
        lam = np.random.beta(alpha, alpha)
        perm = torch.randperm(gt_img.size(0)).to(gt_img.device)

        for key in samples:
            if key in ('img_name', 'pad'):
                continue
            samples[key] = lam * samples[key] + (1 - lam) * samples[key][perm]

    return samples
-
-
def transform(*args, xflip, yflip, transpose, reverse=False):
    """Apply (or undo) a dihedral augmentation to NCHW tensors.

    Forward order: horizontal flip (dim 3), vertical flip (dim 2), then H/W
    transpose. With reverse=True the same enabled ops are applied in the
    opposite order, undoing a forward call. Returns one tensor for a single
    input, otherwise a list.
    """
    def _apply(img):
        ops = [
            (xflip, lambda t: torch.flip(t, [3])),
            (yflip, lambda t: torch.flip(t, [2])),
            (transpose, lambda t: torch.transpose(t, 2, 3)),
        ]
        if reverse:
            ops = ops[::-1]
        for enabled, op in ops:
            if enabled:
                img = op(img)
        return img

    results = [_apply(a) for a in args]
    return results if len(results) > 1 else results[0]
-
-
def self_ensemble(samples, model, ensemble_mode='mean'):
    """x8 self-ensemble: run the model on all 8 dihedral variants of the
    input, undo each transform on the prediction, and aggregate.

    ensemble_mode: 'mean' or 'median'; anything else raises ValueError.
    Returns {'out_img': aggregated prediction}.
    """
    base_lr_up = samples['lr_up'].clone()
    base_lr = samples['lr_img'].clone()
    outputs = []
    for x_flip, y_flip, transpose in itertools.product((False, True), repeat=3):
        samples['lr_img'], samples['lr_up'] = transform(
            base_lr.clone(), base_lr_up.clone(),
            xflip=x_flip, yflip=y_flip, transpose=transpose)
        prediction = model(samples)['out_img']
        outputs.append(transform(prediction, xflip=x_flip, yflip=y_flip,
                                 transpose=transpose, reverse=True))

    if ensemble_mode == 'mean':
        out_img = torch.stack(outputs, 0).mean(0)
    elif ensemble_mode == 'median':
        out_img = torch.stack(outputs, 0).median(0)[0]
    else:
        raise ValueError("Unknown ensemble mode %s." % ensemble_mode)
    return {'out_img': out_img}
-
def freeze_tail(model, defrost: bool = False):
    """Freeze every parameter whose (3-level) name prefix lacks 'up_x'.

    With defrost=True, instead re-enables gradients on all parameters and
    returns immediately.
    """
    if defrost:
        print("Defrost all frozen layers.")
        for param in model.parameters():
            param.requires_grad = True
        return

    for name, param in model.named_parameters():
        head = ".".join(name.split(".")[:3])
        if 'up_x' in head:
            print("{} is trainable.".format(name))
        else:
            param.requires_grad = False
|