|
-
- from __future__ import absolute_import
-
- import paddle
- import paddle.nn as nn
- import paddle.nn.initializer as init
- import numpy as np
- from . import pretrained_networks as pn
- import paddle.nn
-
- import paddle_lpips as lpips
-
def spatial_average(in_tens, keepdim=True):
    """Average ``in_tens`` over axes 2 and 3 (the spatial axes for NCHW input).

    Args:
        in_tens: tensor with at least 4 axes.
        keepdim: if True, the reduced axes are kept as size-1 dimensions.

    Returns:
        The reduced tensor.
    """
    spatial_axes = [2, 3]
    return in_tens.mean(axis=spatial_axes, keepdim=keepdim)
-
def upsample(in_tens, out_HW=(64, 64)):
    """Bilinearly resize ``in_tens`` to the spatial size ``out_HW``.

    Args:
        in_tens: image-like tensor whose last two axes are H and W
            (presumably NCHW, the interpolate default -- TODO confirm for callers).
        out_HW: target ``(height, width)``. Height and width are given
            independently, so the two axes need not share a scale factor.

    Returns:
        The resized tensor.
    """
    # Same operation as nn.Upsample(...)(in_tens), without constructing a
    # throwaway Layer object on every call.
    return nn.functional.interpolate(in_tens, size=out_HW, mode='bilinear', align_corners=False)
-
- # Learned perceptual metric
class LPIPS(nn.Layer):
    """Learned Perceptual Image Patch Similarity (LPIPS) distance.

    Both inputs go through a trunk network. Each returned feature layer is
    unit-normalized with ``lpips.normalize_tensor`` and squared-differenced.
    The difference is weighted by a per-layer 1x1 linear head (``lpips=True``)
    or summed over channels (baseline). It is then spatially averaged, or
    upsampled to the input size when ``spatial``, and summed over layers.
    """

    def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False,
                 pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None,
                 eval_mode=True, verbose=True):
        """
        Args:
            pretrained: if True (and ``lpips`` is True), load linear-head weights
                from ``model_path``, or from
                ``weights/v<version>/<net>.pdparams`` next to this file.
            net: trunk network: 'alex', 'vgg'/'vgg16' or 'squeeze'.
            version: '0.1' rescales inputs with ScalingLayer before the trunk;
                any other value feeds inputs unscaled (the v0.0 behaviour).
            lpips: True -> learned linear calibration on top of the trunk;
                False -> baseline without learned heads.
            spatial: if True, forward returns per-pixel maps upsampled to the
                input size instead of spatial averages.
            pnet_rand: passed to the trunk constructor as ``pretrained=not pnet_rand``.
            pnet_tune: passed to the trunk constructor as ``requires_grad``.
            use_dropout: put a Dropout in front of each linear head.
            model_path: explicit path of the linear-head weight file.
            eval_mode: call ``self.eval()`` at the end of construction.
            verbose: print setup / weight-loading messages.

        Raises:
            ValueError: if ``net`` is not a supported trunk name.
        """
        super(LPIPS, self).__init__()
        if verbose:
            print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]' %
                  ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off'))

        self.pnet_type = net
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.lpips = lpips  # False: baseline that sums squared feature diffs over channels
        self.version = version
        self.scaling_layer = ScalingLayer()

        # Trunk constructor plus the channel count of each feature layer it returns.
        if self.pnet_type in ['vgg', 'vgg16']:
            net_type = pn.vgg16
            self.chns = [64, 128, 256, 512, 512]
        elif self.pnet_type == 'alex':
            net_type = pn.alexnet
            self.chns = [64, 192, 384, 256, 256]
        elif self.pnet_type == 'squeeze':
            net_type = pn.squeezenet
            self.chns = [64, 128, 256, 384, 384, 512, 512]
        else:
            raise ValueError(
                "Unsupported trunk net %r; expected 'alex', 'vgg', 'vgg16' or 'squeeze'" % (net,))
        self.L = len(self.chns)

        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)

        if lpips:
            # One 1x1-conv head per feature layer. Heads are stored both as the
            # attributes lin0..lin{L-1} (the names the original code used) and in
            # the self.lins LayerList used by forward().
            for i, chn in enumerate(self.chns):
                setattr(self, 'lin%d' % i, NetLinLayer(chn, use_dropout=use_dropout))
            self.lins = nn.LayerList([getattr(self, 'lin%d' % i) for i in range(self.L)])

            if pretrained:
                if model_path is None:
                    import inspect
                    import os
                    model_path = os.path.abspath(os.path.join(
                        inspect.getfile(self.__init__), '..',
                        'weights/v%s/%s.pdparams' % (version, net)))

                if verbose:
                    print('Loading model from: %s' % model_path)
                import warnings
                with warnings.catch_warnings():
                    # NOTE(review): presumably silences warnings about state-dict keys
                    # (e.g. trunk weights) missing from the head-only checkpoint -- confirm.
                    warnings.simplefilter('ignore')
                    self.set_state_dict(paddle.load(model_path))

        if eval_mode:
            self.eval()

    def forward(self, in0, in1, retPerLayer=False, normalize=False):
        """Compute the LPIPS distance between ``in0`` and ``in1``.

        Args:
            in0, in1: image batches, expected in [-1, 1] unless ``normalize`` is set
                (presumably (N, 3, H, W) -- TODO confirm).
            retPerLayer: if True, also return the per-layer contributions.
            normalize: set when inputs are in [0, 1]; they are mapped to [-1, 1].

        Returns:
            ``val``, the sum of the per-layer contributions, or ``(val, res)`` when
            ``retPerLayer`` is True. Each entry of ``res`` is a spatial average
            (keepdim) or, when ``spatial``, a map upsampled to the input's H x W.
        """
        if normalize:
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1

        # v0.0 - original release had a bug, where input was not scaled
        if self.version == '0.1':
            in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)
        else:
            in0_input, in1_input = in0, in1
        outs0, outs1 = self.net(in0_input), self.net(in1_input)

        diffs = [
            (lpips.normalize_tensor(outs0[kk]) - lpips.normalize_tensor(outs1[kk])) ** 2
            for kk in range(self.L)
        ]

        out_HW = in0.shape[2:]
        res = []
        for kk, diff in enumerate(diffs):
            if self.lpips:
                layer_map = self.lins[kk](diff)
            else:
                layer_map = diff.sum(axis=1, keepdim=True)
            if self.spatial:
                res.append(upsample(layer_map, out_HW=out_HW))
            else:
                res.append(spatial_average(layer_map, keepdim=True))

        # Out-of-place additions: an in-place ``+=`` on res[0] could modify the
        # per-layer tensor that is returned when retPerLayer is True.
        val = res[0]
        for layer_res in res[1:]:
            val = val + layer_res

        if retPerLayer:
            return (val, res)
        return val
-
-
class ScalingLayer(nn.Layer):
    """Per-channel affine rescaling of the input: ``(inp - shift) / scale``.

    ``shift`` and ``scale`` are registered as float32 buffers of shape
    (1, 3, 1, 1). They broadcast over a 3-channel batch whose channel axis is
    axis 1 (presumably NCHW images -- confirm against callers).
    """

    def __init__(self):
        super(ScalingLayer, self).__init__()
        constants = (
            ('shift', [-.030, -.088, -.188]),
            ('scale', [.458, .448, .450]),
        )
        for buf_name, values in constants:
            arr = np.asarray(values).astype('float32').reshape(1, 3, 1, 1)
            self.register_buffer(buf_name, paddle.to_tensor(arr))

    def forward(self, inp):
        centered = inp - self.shift
        return centered / self.scale
-
-
class NetLinLayer(nn.Layer):
    ''' A single linear layer which does a 1x1 conv '''

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        """
        Args:
            chn_in: number of input channels.
            chn_out: number of output channels (default 1).
            use_dropout: if True, a Dropout is placed before the conv, so the
                conv is sub-layer 1 of ``self.model`` instead of sub-layer 0.
        """
        super(NetLinLayer, self).__init__()
        stages = []
        if use_dropout:
            stages.append(nn.Dropout())
        # 1x1 conv without bias: a per-pixel linear map over channels.
        stages.append(nn.Conv2D(chn_in, chn_out, 1, stride=1, padding=0, bias_attr=False))
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        out = self.model(x)
        return out
-
class Dist2LogitLayer(nn.Layer):
    ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''

    def __init__(self, chn_mid=32, use_sigmoid=True):
        """
        Args:
            chn_mid: width of the two hidden 1x1-conv layers.
            use_sigmoid: append a Sigmoid so the output lies in [0, 1].
        """
        super(Dist2LogitLayer, self).__init__()
        # Paddle's nn.Conv2D takes ``bias_attr`` (None -> bias enabled), not ``bias``.
        # Passing ``bias=True`` made construction fail.
        layers = [nn.Conv2D(5, chn_mid, 1, stride=1, padding=0, bias_attr=None),]
        layers += [nn.LeakyReLU(0.2),]
        layers += [nn.Conv2D(chn_mid, chn_mid, 1, stride=1, padding=0, bias_attr=None),]
        layers += [nn.LeakyReLU(0.2),]
        layers += [nn.Conv2D(chn_mid, 1, 1, stride=1, padding=0, bias_attr=None),]
        if use_sigmoid:
            layers += [nn.Sigmoid(),]
        self.model = nn.Sequential(*layers)

    def forward(self, d0, d1, eps=0.1):
        """Map two distances to a (logit or probability) score.

        Five features are stacked along axis 1, matching the first conv's 5
        input channels: d0, d1, d0-d1, d0/(d1+eps), d1/(d0+eps). ``eps`` keeps
        the ratios finite when a distance is 0. Presumably d0/d1 are
        single-channel NCHW maps -- confirm against callers.
        """
        feats = paddle.concat(
            (d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), axis=1)
        return self.model(feats)
-
class BCERankingLoss(nn.Layer):
    """Binary cross-entropy between a learned 2AFC prediction and a judgment.

    Dist2LogitLayer turns the two distances into a score, which is compared with
    ``(judge + 1) / 2`` (presumably a judgment in [-1, 1] rescaled to [0, 1]).
    """

    def __init__(self, chn_mid=32):
        super(BCERankingLoss, self).__init__()
        self.net = Dist2LogitLayer(chn_mid=chn_mid)
        self.loss = paddle.nn.BCELoss()

    def forward(self, d0, d1, judge):
        per = (judge + 1.) / 2.
        # The prediction is also stored on the instance as self.logit. With the
        # default use_sigmoid=True it is a probability, not a raw logit.
        self.logit = self.net(d0, d1)
        return self.loss(self.logit, per)
-
- # L2, DSSIM metrics
class FakeNet(nn.Layer):
    """Common base for the non-learned L2 / DSSIM metrics.

    Only records two settings for subclasses: ``use_gpu`` and ``colorspace``
    (the subclasses in this file handle 'RGB' and 'Lab').
    """

    def __init__(self, use_gpu=True, colorspace='Lab'):
        super().__init__()
        self.colorspace = colorspace
        self.use_gpu = use_gpu
-
class L2(FakeNet):
    """Non-learned L2 distance between two images, in RGB or Lab colorspace."""

    def forward(self, in0, in1, retPerLayer=None):
        """Return the L2 distance between ``in0`` and ``in1``.

        'RGB': mean squared difference over axes 1, 2 and 3, giving one value
        per batch item (shape (N,)). Any batch size works.
        'Lab': both images are converted with ``lpips.tensor2tensorlab`` and
        compared with ``lpips.l2`` (range=100). The result is wrapped in a
        1-element tensor. Only batch size 1 is supported on this path.

        ``retPerLayer`` is accepted and ignored.

        Raises:
            AssertionError: Lab colorspace with a batch size other than 1.
            ValueError: ``self.colorspace`` is neither 'RGB' nor 'Lab'.
        """
        if self.colorspace == 'RGB':
            # Equal-sized axes, so this equals the old nested channel/H/W means.
            return ((in0 - in1) ** 2).mean(axis=[1, 2, 3])
        if self.colorspace == 'Lab':
            assert in0.shape[0] == 1  # the Lab path currently only supports batchSize 1
            value = lpips.l2(lpips.tensor2np(lpips.tensor2tensorlab(in0, to_norm=False)),
                             lpips.tensor2np(lpips.tensor2tensorlab(in1, to_norm=False)),
                             range=100.).astype('float')
            return paddle.to_tensor((value,))
        raise ValueError("Unsupported colorspace %r; expected 'RGB' or 'Lab'" % (self.colorspace,))
-
class DSSIM(FakeNet):
    """Structural dissimilarity (``lpips.dssim``) between two images."""

    def forward(self, in0, in1, retPerLayer=None):
        """Return the DSSIM between ``in0`` and ``in1`` as a 1-element tensor.

        'RGB': images are converted with ``lpips.tensor2im`` and compared with
        range=255.
        'Lab': images are converted with ``lpips.tensor2tensorlab`` and compared
        with range=100.

        ``retPerLayer`` is accepted and ignored.

        Raises:
            AssertionError: batch size other than 1.
            ValueError: ``self.colorspace`` is neither 'RGB' nor 'Lab' (this
                previously surfaced as an UnboundLocalError).
        """
        assert in0.shape[0] == 1  # currently only supports batchSize 1

        if self.colorspace == 'RGB':
            value = lpips.dssim(1. * lpips.tensor2im(in0), 1. * lpips.tensor2im(in1),
                                range=255.).astype('float')
        elif self.colorspace == 'Lab':
            value = lpips.dssim(lpips.tensor2np(lpips.tensor2tensorlab(in0, to_norm=False)),
                                lpips.tensor2np(lpips.tensor2tensorlab(in1, to_norm=False)),
                                range=100.).astype('float')
        else:
            raise ValueError(
                "Unsupported colorspace %r; expected 'RGB' or 'Lab'" % (self.colorspace,))
        return paddle.to_tensor((value,))
-
def print_network(net):
    """Print ``net`` and the total number of scalar parameters it holds.

    The count is the product of each parameter's ``shape``, as a plain int.
    ``param.numel()`` would return a framework tensor, which ``%d`` formatting
    would then have to coerce.
    """
    num_params = sum(int(np.prod(param.shape)) for param in net.parameters())
    print('Network', net)
    print('Total number of parameters: %d' % num_params)
|