|
- import os
- import argparse
- import logging
- import math
- import random
- import shutil
- import sys
- from collections import defaultdict
- from typing import List
- from pathlib import Path
- import numpy as np
-
- import mindspore.nn as nn
- import mindspore.dataset as ds
- from mindspore import context
- import mindspore.ops as ops
- import mindspore as ms
-
- from PIL import Image
-
- from models.ssf2020 import ScaleSpaceFlow
- from models.dcvc import DCVC
- from dataset.vimeo90k import VideoFolder
-
def psnr(mse) -> float:
    """Peak signal-to-noise ratio (dB) for signals normalized to [0, 1].

    Args:
        mse: scalar mean squared error — a float or anything convertible
            via ``float()`` (e.g. a 0-d mindspore Tensor).

    Returns:
        PSNR in decibels; ``inf`` when the MSE is zero (identical signals),
        where the original implementation raised a math domain error.
    """
    mse = float(mse)
    if mse <= 0:
        # Zero (or numerically negative) MSE means a perfect reconstruction.
        return math.inf
    return -10 * math.log10(mse)
-
-
def collect_likelihoods_list(likelihoods_list, num_pixels: int):
    """Accumulate per-sample bits-per-pixel over every likelihood tensor.

    Each likelihood tensor is reduced over its (C, H, W) axes and converted
    from nats to bits, normalized by ``num_pixels``.
    """
    # -log2(p) / num_pixels, with the constant hoisted out of the loop.
    scale = -math.log(2) * num_pixels
    bpp_total = 0
    for lik in likelihoods_list:
        bpp_total = bpp_total + ops.log(lik).sum(axis=(1, 2, 3)) / scale
    return bpp_total
-
class RateDistortionLoss(nn.Cell):
    """Rate-distortion loss ``L = mean(lmbda * scaled_D) + R`` for video models.

    Distortion is the per-frame MSE (averaged over channel components),
    scaled by ``(2**bitdepth - 1)**2`` so that ``lmbda`` values follow the
    usual 8-bit convention; the rate is the estimated entropy (bits per
    pixel) of all latents produced for every frame.
    """

    def __init__(self, lmbda=1e-2, return_details: bool = False, bitdepth: int = 8):
        super().__init__()
        self.mse = nn.MSELoss(reduction="none")
        self.lmbda = lmbda
        # Scale an MSE computed on [0, 1] inputs up to the bitdepth's dynamic range.
        self._scaling_functions = lambda x: (2**bitdepth - 1) ** 2 * x
        self.return_details = bool(return_details)

    @staticmethod
    def _get_rate(likelihoods_list, num_pixels):
        """Total bpp over a nested list of per-frame likelihood tensors."""
        return sum(
            (ops.log(likelihoods).sum() / (-math.log(2) * num_pixels))
            for frame_likelihoods in likelihoods_list
            for likelihoods in frame_likelihoods
        )

    def _get_scaled_distortion(self, x, target):
        """Return ``(scaled_mse, raw_mse)`` per sample between x and target.

        Both arguments are 4D tensors (N, C, H, W) or per-channel tuples of
        them — TODO confirm tuple inputs against callers; the channel-count
        check below assumes tensor inputs.
        """
        if not len(x) == len(target):
            raise RuntimeError(f"len(x)={len(x)} != len(target)={len(target)})")

        nC = ms.ops.shape(x)[1]
        if not nC == ms.ops.shape(target)[1]:
            raise RuntimeError(
                "number of channels mismatches while computing distortion"
            )

        # Split whole tensors into per-channel components.
        if isinstance(x, ms.Tensor):
            x = ops.Split(1, ms.ops.shape(x)[1])(x)

        if isinstance(target, ms.Tensor):
            target = ops.Split(1, ms.ops.shape(target)[1])(target)

        # compute metric over each component (eg: y, u and v)
        metric_values = []
        for x0, x1 in zip(x, target):
            v = self.mse(x0.float(), x1.float())
            if v.ndimension() == 4:
                # Reduce (C, H, W) so each sample keeps one scalar.
                v = v.mean(axis=(1, 2, 3))
            metric_values.append(v)
        metric_values = ops.stack(tuple(metric_values))

        # sum value over the components dimension
        metric_value = ops.sum(metric_values.transpose(1, 0), dim=1) / nC
        scaled_metric = self._scaling_functions(metric_value)

        return scaled_metric, metric_value

    @staticmethod
    def _check_tensor(x) -> bool:
        """True for a 4D tensor, or a tuple/list whose first item is a tensor."""
        return (isinstance(x, ms.Tensor) and x.ndimension() == 4) or (
            isinstance(x, (tuple, list)) and isinstance(x[0], ms.Tensor)
        )

    @classmethod
    def _check_tensors_list(cls, lst):
        """Raise ValueError unless ``lst`` is a non-empty list of valid frames."""
        if (
            not isinstance(lst, (tuple, list))
            or len(lst) < 1
            or any(not cls._check_tensor(x) for x in lst)
        ):
            # Fixed: the message previously said "torch.Tensor" in this
            # MindSpore codebase, which misled users about the expected type.
            raise ValueError(
                "Expected a list of 4D ms.Tensor (or tuples of) as input"
            )

    def construct(self, reconstructions, frames_likelihoods, targets):
        """Return ``(distortion_loss, mse_loss, bpp_loss, loss)`` for a clip.

        Args:
            reconstructions: list of reconstructed frames (4D tensors).
            frames_likelihoods: per-frame latent likelihoods.
            targets: list of ground-truth frames, same length as above.
        """
        assert isinstance(targets, type(reconstructions))
        assert len(reconstructions) == len(targets)

        self._check_tensors_list(targets)
        self._check_tensors_list(reconstructions)

        _, _, H, W = ms.ops.shape(targets[0])
        num_frames = len(targets)

        # bpp is normalized by the pixel count of the whole clip.
        num_pixels = H * W * num_frames

        # Get scaled and raw loss distortions for each frame
        scaled_distortions = []
        distortions = []
        for x_hat, x in zip(reconstructions, targets):
            scaled_distortion, distortion = self._get_scaled_distortion(x_hat, x)

            distortions.append(distortion)
            scaled_distortions.append(scaled_distortion)

        # aggregate (over batch and frame dimensions).
        mse_loss = ops.mean(ops.stack(distortions))

        # average scaled_distortions across the frames
        scaled_distortions = sum(scaled_distortions) / num_frames

        assert isinstance(frames_likelihoods, list)
        likelihoods_list = frames_likelihoods

        # collect bpp info on noisy tensors (estimated differentiable entropy)
        bpp_loss = collect_likelihoods_list(likelihoods_list, num_pixels)

        lambdas = ms.numpy.full_like(bpp_loss, self.lmbda)
        bpp_loss = bpp_loss.mean()
        loss = (lambdas * scaled_distortions).mean() + bpp_loss

        distortion_loss = scaled_distortions.mean()

        return distortion_loss, mse_loss, bpp_loss, loss
-
class AverageMeter:
    """Track the latest value and a running mean of a scalar metric."""

    def __init__(self):
        self.val = 0    # most recent value seen
        self.avg = 0    # running mean over everything accumulated so far
        self.sum = 0    # weighted total of all values
        self.count = 0  # number of samples accumulated

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running mean."""
        self.val = val
        self.count = self.count + n
        self.sum = self.sum + val * n
        self.avg = self.sum / self.count
-
def compute_aux_loss(aux_list: List):
    """Return the total of all auxiliary losses (0 for an empty list).

    Equivalent to the manual accumulation loop: ``sum`` also starts from 0
    and left-folds with ``+``, so it works for numbers and tensors alike.
    """
    return sum(aux_list)
-
def inference_entropy_estimation(test_dataloader, model, criterion):
    """Evaluate ``model`` over ``test_dataloader`` using entropy estimation.

    Prints per-batch metrics and their dataset-wide averages, then returns
    the mean total loss.
    """
    model.set_train(False)

    # Running averages of each reported metric.
    loss_meter = AverageMeter()
    bpp_meter = AverageMeter()
    mse_meter = AverageMeter()
    psnr_meter = AverageMeter()

    for step, batch in enumerate(test_dataloader):
        frames = [batch['frame0'], batch['frame1'], batch['frame2']]

        reconstructions, frames_likelihoods = model(frames)
        distortion_loss, mse_loss, bpp_loss, loss = criterion(
            reconstructions, frames_likelihoods, frames
        )

        psnr_performance = psnr(mse=mse_loss)
        print(
            f"inference: {step}"
            f"\tLoss: {float(loss):.3f} |"
            f"\tMSE loss: {float(mse_loss):.3f} |"
            f"\tPSNR: {float(psnr_performance):.3f} |"
            f"\tBpp loss: {float(bpp_loss):.2f}"
        )

        bpp_meter.update(float(bpp_loss))
        loss_meter.update(float(loss))
        mse_meter.update(float(mse_loss))
        psnr_meter.update(float(psnr_performance))

    print(
        f"inference: Average losses:"
        f"\tLoss: {loss_meter.avg:.3f} |"
        f"\tMSE loss: {mse_meter.avg:.3f} |"
        f"\tPSNR: {psnr_meter.avg:.3f} |"
        f"\tBpp loss: {bpp_meter.avg:.2f}\n"
    )

    return loss_meter.avg
-
def parse_args(argv):
    """Parse command-line options for this script; see each flag's help text."""
    parser = argparse.ArgumentParser(description="Example training script.")
    # Data-driven option table: (flags, add_argument keyword arguments).
    options = [
        (("-m", "--model"),
         dict(default="ssf2020",
              help="Model architecture (default: %(default)s)")),
        (("-d", "--dataset"),
         dict(type=str, required=True, help="Training dataset")),
        (("-n", "--num-workers"),
         dict(type=int, default=4,
              help="Dataloaders threads (default: %(default)s)")),
        (("--lambda",),
         dict(dest="lmbda", type=float, default=1e-2,
              help="Bit-rate distortion parameter (default: %(default)s)")),
        (("--test-batch-size",),
         dict(type=int, default=1,
              help="Test batch size (default: %(default)s)")),
        (("--patch-size",),
         dict(type=int, nargs=2, default=(256, 256),
              help="Size of the patches to be cropped (default: %(default)s)")),
        (("--load_checkpoint",),
         dict(type=str, help="Path to a checkpoint")),
    ]
    for flags, kwargs in options:
        parser.add_argument(*flags, **kwargs)
    return parser.parse_args(argv)
-
def main(argv):
    """Entry point: configure MindSpore, build the test pipeline, run inference.

    Args:
        argv: command-line arguments (without the program name).

    Raises:
        ValueError: if ``--model`` names an unknown architecture (previously
            ``net`` was simply left undefined and crashed later with a
            NameError).
    """
    # GRAPH_MODE = static graph; PYNATIVE_MODE = dynamic (eager) execution.
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    context.set_context(save_graphs=False)
    context.set_context(device_id=int(os.getenv('DEVICE_ID', '0')))
    print("int(os.getenv('DEVICE_ID', '0')): ", int(os.getenv('DEVICE_ID', '0')))
    if ms.get_context("device_target") == "GPU":
        # Graph-kernel fusion can speed up GPU execution but writes extra
        # intermediate files locally, so it is kept disabled here.
        context.set_context(enable_graph_kernel=False)
    ms.reset_auto_parallel_context()
    ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.STAND_ALONE,
                                 gradients_mean=True, device_num=1)
    # Shared memory would only matter for multiprocess data loading.
    ds.config.set_enable_shared_mem(False)

    args = parse_args(argv)

    test_dataset = VideoFolder(
        args.dataset,
        rnd_interval=False,
        rnd_temp_order=False,
        split="test",
        patch_size=args.patch_size,
    )

    test_dataset = ds.GeneratorDataset(
        test_dataset,
        column_names=["frame0", "frame1", "frame2"],
        num_parallel_workers=args.num_workers,
        shuffle=False,
        python_multiprocessing=False,  # threads, not worker processes
    )
    test_loader = test_dataset.batch(batch_size=args.test_batch_size,
                                     drop_remainder=False)
    test_dataloader = test_loader.create_dict_iterator()

    # Model dispatch with an explicit failure for unrecognized names.
    if args.model == 'ssf2020':
        net = ScaleSpaceFlow()
    elif args.model == 'dcvc':
        net = DCVC()
    else:
        raise ValueError(f"Unknown model architecture: {args.model!r}")

    criterion = RateDistortionLoss(lmbda=args.lmbda, return_details=True)

    if args.load_checkpoint:  # load from previous checkpoint
        print("Loading", args.load_checkpoint)
        state_dict = ms.load_checkpoint(args.load_checkpoint)
        ms.load_param_into_net(net, state_dict)

    inference_entropy_estimation(test_dataloader, net, criterion)


if __name__ == "__main__":
    main(sys.argv[1:])
|