|
- # import torch
- # import torch.nn.functional as F
- import mindspore as ms
- import mindspore.nn as nn
- import mindspore.ops as ops
-
-
def _fspecial_gauss_1d(size, sigma):
    r"""Build a normalized 1-D Gaussian kernel.

    Args:
        size (int): number of taps in the kernel.
        sigma (float): standard deviation of the Gaussian.

    Returns:
        mindspore.Tensor: float32 kernel of shape ``(1, 1, size)`` whose
        entries sum to 1.
    """
    # Tap positions relative to the centre, e.g. size=5 -> [-2, -1, 0, 1, 2].
    offsets = ops.arange(size).astype(ms.float32) - size // 2
    kernel = ops.exp(-(offsets ** 2) / (2 * sigma ** 2))
    kernel = kernel / kernel.sum()
    # Prepend two singleton axes: (size,) -> (1, 1, size).
    return kernel.unsqueeze(0).unsqueeze(0)
-
-
def gaussian_filter(input, win):
    r"""Blur a batch of images with a separable 1-D kernel.

    The kernel is applied as a grouped convolution with one group per channel.
    It runs first as given (along the last/width axis), then with its last two
    axes swapped (along the height axis). ``padding=0`` is used, so the output
    is spatially smaller than the input.

    Args:
        input (mindspore.Tensor): images of shape ``(N, C, H, W)``.
        win (mindspore.Tensor): kernel laid out as a conv weight, presumably
            ``(C, 1, 1, K)`` as built by the callers in this module -- confirm
            for externally supplied kernels.

    Returns:
        mindspore.Tensor: the blurred images.
    """
    _, channels, _, _ = input.shape
    blurred_w = ops.conv2d(input, win, stride=1, padding=0, group=channels)
    # MindSpore's Tensor.transpose takes a full axis permutation (unlike
    # PyTorch's two-axis swap); this swaps the kernel's last two axes.
    win_h = win.transpose((0, 1, 3, 2))
    return ops.conv2d(blurred_w, win_h, stride=1, padding=0, group=channels)
-
-
def _ssim(X, Y, win, data_range=255, size_average=True, full=False):
    r"""Compute the SSIM index (and its contrast-structure term) for X and Y.

    Args:
        X (mindspore.Tensor): images of shape ``(N, C, H, W)``.
        Y (mindspore.Tensor): images with the same shape as ``X``.
        win (mindspore.Tensor): kernel in the layout expected by
            ``gaussian_filter``.
        data_range (float or int, optional): value range of the input images
            (usually 1.0 or 255); it scales the stabilising constants C1, C2.
        size_average (bool, optional): if True, average over every element and
            return scalars; otherwise average over the last three axes (C, H, W),
            giving one value per image.
        full (bool, optional): if True, also return the contrast-structure
            term ``cs``.

    Returns:
        mindspore.Tensor or tuple: ``ssim``, or ``(ssim, cs)`` when ``full``.
    """
    K1, K2 = 0.01, 0.03
    # Kept for its side effect: the unpacking raises if X is not 4-D.
    _batch, _channels, _height, _width = X.shape
    compensation = 1.0

    C1 = (K1 * data_range) ** 2
    C2 = (K2 * data_range) ** 2

    def blur(t):
        return gaussian_filter(t, win)

    mu1 = blur(X)
    mu2 = blur(Y)
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local (co)variances as E[ab] - E[a]E[b] under the blur window.
    sigma1_sq = compensation * (blur(X * X) - mu1_sq)
    sigma2_sq = compensation * (blur(Y * Y) - mu2_sq)
    sigma12 = compensation * (blur(X * Y) - mu1_mu2)

    cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
    luminance = (2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)
    ssim_map = luminance * cs_map

    if not size_average:
        # Reduce W, then H, then C, leaving one value per image.
        ssim_val = ssim_map.mean(-1).mean(-1).mean(-1)
        cs = cs_map.mean(-1).mean(-1).mean(-1)
    else:
        ssim_val = ssim_map.mean()
        cs = cs_map.mean()

    return (ssim_val, cs) if full else ssim_val
-
-
def ssim(X, Y, win_size=11, win_sigma=1.5, win=None, data_range=255, size_average=True, full=False):
    r"""Interface of SSIM.

    Args:
        X (mindspore.Tensor): a batch of images, (N, C, H, W).
        Y (mindspore.Tensor): a batch of images, (N, C, H, W).
        win_size (int, optional): size of the Gaussian kernel; must be odd.
        win_sigma (float, optional): sigma of the Gaussian.
        win (mindspore.Tensor, optional): precomputed kernel, presumably of
            shape (C, 1, 1, K). If None, one is built from ``win_size`` and
            ``win_sigma``.
        data_range (float or int, optional): value range of the input images
            (usually 1.0 or 255).
        size_average (bool, optional): if True, average to a scalar; otherwise
            return one value per image.
        full (bool, optional): if True, also return the contrast-structure
            term ``cs``.

    Returns:
        mindspore.Tensor or tuple: ``ssim``, or ``(ssim, cs)`` when ``full``.

    Raises:
        ValueError: if the inputs are not 4-D, differ in dtype or shape, or
            ``win_size`` is even.
    """
    if len(X.shape) != 4:
        raise ValueError('Input images must 4-d tensor.')

    # Compare element dtypes. Comparing type(X) with type(Y) only checks the
    # Python class, which is the same for any two Tensors.
    if X.dtype != Y.dtype:
        raise ValueError('Input images must have the same dtype.')

    if X.shape != Y.shape:
        raise ValueError('Input images must have the same dimensions.')

    if win_size % 2 != 1:
        raise ValueError('Window size must be odd.')

    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma)
        # (1, 1, K) -> (C, 1, 1, K): one filter per channel for the grouped conv.
        win = ms.numpy.tile(win, (X.shape[1], 1, 1, 1))

    # Always get per-image values first; average afterwards if requested.
    ssim_val, cs = _ssim(X, Y,
                         win=win,
                         data_range=data_range,
                         size_average=False,
                         full=True)
    if size_average:
        ssim_val = ssim_val.mean()
        cs = cs.mean()

    if full:
        return ssim_val, cs
    return ssim_val
-
-
def ms_ssim(X, Y, win=None, win_size=11, win_sigma=1.5, data_range=255, size_average=True, full=False, weights=None):
    r"""Interface of MS-SSIM (multi-scale structural similarity).

    The images are compared at ``len(weights)`` scales, and both images are
    halved with 2x2 average pooling between scales. The result multiplies two
    kinds of factor, each raised to its scale's weight:

    * the contrast-structure term ``cs`` of every scale except the coarsest;
    * the full SSIM of the coarsest scale.

    Args:
        X (mindspore.Tensor): a batch of images, (N, C, H, W).
        Y (mindspore.Tensor): a batch of images, (N, C, H, W).
        win (mindspore.Tensor, optional): precomputed kernel, presumably of
            shape (C, 1, 1, K). If None, one is built from ``win_size`` and
            ``win_sigma``.
        win_size (int, optional): size of the Gaussian kernel; must be odd.
        win_sigma (float, optional): sigma of the Gaussian.
        data_range (float or int, optional): value range of the input images
            (usually 1.0 or 255).
        size_average (bool, optional): if True, average to a scalar; otherwise
            return one value per image.
        full (bool, optional): accepted for signature symmetry with ``ssim``;
            it is not used, and only the MS-SSIM value is returned.
        weights (list or mindspore.Tensor, optional): per-scale exponents,
            finest scale first. Defaults to
            ``[0.0448, 0.2856, 0.3001, 0.2363, 0.1333]``.

    Returns:
        mindspore.Tensor: MS-SSIM, a scalar if ``size_average`` else one value
        per image.

    Raises:
        ValueError: if the inputs are not 4-D or differ in dtype or shape, if
            ``win_size`` is even, or if the images are too small for the
            number of scales.
    """
    if len(X.shape) != 4:
        raise ValueError('Input images must 4-d tensor.')

    # Compare element dtypes. Comparing type(X) with type(Y) only checks the
    # Python class, which is the same for any two Tensors.
    if X.dtype != Y.dtype:
        raise ValueError('Input images must have the same dtype.')

    if X.shape != Y.shape:
        raise ValueError('Input images must have the same dimensions.')

    if win_size % 2 != 1:
        raise ValueError('Window size must be odd.')

    if weights is None:
        weights = ms.Tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333], ms.float32)
    elif not isinstance(weights, ms.Tensor):
        # Plain sequences are accepted; `.shape` / `.unsqueeze` below need a Tensor.
        weights = ms.Tensor(weights, ms.float32)

    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma)
        # (1, 1, K) -> (C, 1, 1, K): one filter per channel for the grouped conv.
        win = ms.numpy.tile(win, (X.shape[1], 1, 1, 1))
    else:
        win_size = win.shape[-1]

    levels = weights.shape[0]

    # Each of the (levels - 1) poolings maps a side s to ceil(s / 2). The
    # unpadded blur at the coarsest scale needs a side of at least win_size.
    min_side = (win_size - 1) * 2 ** (levels - 1)
    if min(X.shape[-2:]) <= min_side:
        raise ValueError(
            f'Image size must be larger than {min_side} for {levels - 1} '
            f'downsamplings with a window of size {win_size}.')

    # stride 2 is required: ops.AvgPool defaults to strides=1, which would not
    # downsample at all.
    avg_pool = ops.AvgPool(kernel_size=2, strides=2)

    mcs = []
    for level in range(levels):
        ssim_val, cs = _ssim(X, Y,
                             win=win,
                             data_range=data_range,
                             size_average=False,
                             full=True)
        if level < levels - 1:
            # Clamp at 0: a negative cs raised to a fractional weight is NaN.
            mcs.append(ops.relu(cs))
            # Mirrors PyTorch's avg_pool2d(kernel_size=2, padding=(H % 2, W % 2)).
            # An odd dimension gets one zero row/column on each side before the
            # 2x2 stride-2 pooling. ops.pad takes (left, right, top, bottom)
            # for the last two axes.
            h_pad, w_pad = X.shape[2] % 2, X.shape[3] % 2
            padding = (w_pad, w_pad, h_pad, h_pad)
            X = avg_pool(ops.pad(X, padding))
            Y = avg_pool(ops.pad(Y, padding))

    ssim_val = ops.relu(ssim_val)

    # Stack (levels, batch): cs for scales 0..levels-2, then ssim of the coarsest.
    # Each row gets its own weight exactly once. The coarsest ssim must not be
    # folded into every cs row, or it would be raised to (levels - 1) * w[-1].
    values = ops.stack(mcs + [ssim_val], axis=0)
    msssim_val = ops.prod(values ** weights.unsqueeze(1), axis=0)  # (batch,)

    if size_average:
        msssim_val = msssim_val.mean()
    return msssim_val
-
-
# Cell wrappers that build the Gaussian window once and reuse it on every call
class SSIM(nn.Cell):
    r"""Cell wrapper around ``ssim`` that builds the Gaussian window once."""

    def __init__(self, win_size=11, win_sigma=1.5, data_range=None, size_average=True, channel=3):
        r"""Class for SSIM.

        Args:
            win_size (int, optional): size of the Gaussian kernel.
            win_sigma (float, optional): sigma of the Gaussian.
            data_range (float or int, optional): value range of the input
                images (usually 1.0 or 255). If None, 255 is used, which
                matches the default of the functional ``ssim``.
            size_average (bool, optional): if True, SSIM of all images is
                averaged to a scalar.
            channel (int, optional): number of input channels (default: 3).
        """
        super(SSIM, self).__init__()
        # Tile the (1, 1, win_size) kernel to (channel, 1, 1, win_size) once,
        # so every call reuses it. (The module is `ms.numpy`.)
        self.win = ms.numpy.tile(_fspecial_gauss_1d(
            win_size, win_sigma), (channel, 1, 1, 1))
        self.size_average = size_average
        # None would reach `(K * data_range) ** 2` inside _ssim and raise TypeError.
        self.data_range = 255 if data_range is None else data_range

    def construct(self, X, Y):
        return ssim(X, Y, win=self.win, data_range=self.data_range, size_average=self.size_average)
-
-
class MS_SSIM(nn.Cell):
    r"""Cell wrapper around ``ms_ssim`` that builds the Gaussian window once."""

    def __init__(self, win_size=11, win_sigma=1.5, data_range=None, size_average=True, channel=3, weights=None):
        r"""Class for MS-SSIM.

        Args:
            win_size (int, optional): size of the Gaussian kernel.
            win_sigma (float, optional): sigma of the Gaussian.
            data_range (float or int, optional): value range of the input
                images (usually 1.0 or 255). If None, 255 is used, which
                matches the default of the functional ``ms_ssim``.
            size_average (bool, optional): if True, MS-SSIM of all images is
                averaged to a scalar.
            channel (int, optional): number of input channels (default: 3).
            weights (list or mindspore.Tensor, optional): per-scale weights.
                If None, ``ms_ssim`` uses its built-in defaults.
        """
        super(MS_SSIM, self).__init__()
        # Tile the (1, 1, win_size) kernel to (channel, 1, 1, win_size) once.
        self.win = ms.numpy.tile(_fspecial_gauss_1d(
            win_size, win_sigma), (channel, 1, 1, 1))
        self.size_average = size_average
        # data_range is always forwarded explicitly, so None would override
        # ms_ssim's default and fail inside _ssim with a TypeError.
        self.data_range = 255 if data_range is None else data_range
        # ms_ssim reads `weights.shape` and calls `weights.unsqueeze`, so
        # plain sequences are converted to a Tensor up front.
        if weights is not None and not isinstance(weights, ms.Tensor):
            weights = ms.Tensor(weights, ms.float32)
        self.weights = weights

    def construct(self, X, Y):
        return ms_ssim(X, Y, win=self.win, size_average=self.size_average, data_range=self.data_range, weights=self.weights)
-
-
if __name__ == '__main__':
    # Smoke test: 4 images, 3 channels, 256x256, all-zero vs. all-one.
    # Pass the shape positionally with an explicit dtype. The name of the
    # shape parameter varies across MindSpore versions (ops.zeros calls it
    # `size` in 2.x), so `shape=` is not portable.
    X = ops.zeros((4, 3, 256, 256), ms.float32)
    Y = ops.ones((4, 3, 256, 256), ms.float32)

    net = MS_SSIM(data_range=255)
    out = net(X, Y)
    print('out:', out)
|