|
- import torch
- import itertools
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.autograd import Variable
-
def grid_sample(input, grid, canvas=None):
    """Sample ``input`` at the locations in ``grid``, optionally over a canvas.

    Args:
        input: image batch passed straight to ``F.grid_sample``.
        grid: sampling locations in normalised [-1, 1] coordinates, passed
            straight to ``F.grid_sample``.
        canvas: optional background. When given, output pixels whose sampling
            location falls outside ``input`` are blended towards ``canvas``
            instead of the zero padding ``F.grid_sample`` produces.

    Returns:
        The sampled batch, composited onto ``canvas`` when one is supplied.
    """
    # The grids built in this file normalise pixel indices as
    # ``index * 2 / (size - 1) - 1``, i.e. -1 and +1 are the centres of the
    # corner pixels. That is the ``align_corners=True`` convention, which was
    # the implicit behaviour before PyTorch 1.3 changed the default to False.
    output = F.grid_sample(input, grid, align_corners=True)
    if canvas is None:
        return output

    # Sampling an all-ones image gives the per-pixel coverage of ``input``:
    # 1 where the sample lies inside it, fading to 0 outside (zero padding).
    coverage = F.grid_sample(torch.ones_like(input), grid, align_corners=True)
    return output * coverage + canvas * (1 - coverage)
-
def compute_partial_repr(input_points, control_points):
    """Evaluate phi(x1, x2) = r^2 * log(r), with r = ||x1 - x2||_2, pairwise.

    Args:
        input_points: tensor viewed as (N, 1, 2), N being its first dimension.
        control_points: tensor viewed as (1, M, 2), M being its first dimension.

    Returns:
        An (N, M) tensor whose entry (i, j) is phi(input_points[i],
        control_points[j]); coincident points yield 0.
    """
    num_inputs = input_points.size(0)
    num_controls = control_points.size(0)
    delta = input_points.view(num_inputs, 1, 2) - control_points.view(1, num_controls, 2)

    # The two squared coordinates are added explicitly; the original author
    # found a ``torch.sum(..., dim=2)`` reduction very slow here.
    dx = delta[:, :, 0]
    dy = delta[:, :, 1]
    squared_dist = dx * dx + dy * dy

    # 0.5 * r^2 * log(r^2) == r^2 * log(r)
    basis = 0.5 * squared_dist * torch.log(squared_dist)

    # 0 * log(0) evaluates to NaN; those entries are replaced with 0.
    basis.masked_fill_(torch.isnan(basis), 0)
    return basis
-
class TPSGridGen(nn.Module):
    """Generate sampling grids from control-point correspondences.

    The module is built around a fixed set of N target control points. At
    construction it inverts an (N + 3) x (N + 3) kernel matrix assembled from
    those points and precomputes, for every pixel of the target image, its
    representation with respect to them. ``forward`` then turns a batch of
    source control points into per-pixel source coordinates with two matrix
    products.

    Buffers:
        target_control_points: (N, 2) float tensor of target control points.
        inverse_kernel: (N + 3, N + 3) inverse of the padded kernel matrix.
        padding_matrix: (3, 2) zeros appended to the source control points.
        target_coordinate_repr: (H * W, N + 3) per-pixel representation.
    """

    def __init__(self, target_shape, target_control_points):
        """Store the control points and precompute the matrices.

        Args:
            target_shape: ``(height, width)`` of the output grid.
            target_control_points: (N, 2) tensor of control points, in the same
                normalised (x, y) coordinates as the generated grid.
        """
        super(TPSGridGen, self).__init__()
        assert target_control_points.ndimension() == 2
        assert target_control_points.size(1) == 2
        self.num_points = target_control_points.size(0)
        self.register_buffer('target_control_points', target_control_points.float())
        self._precompute(target_shape)

    def _precompute(self, target_shape):
        """(Re)build every buffer that depends on the control points or target size."""
        target_height, target_width = target_shape
        self.target_shape = target_shape

        control = self.target_control_points
        N = self.num_points

        # Padded kernel matrix:
        #   [ K    1   P ]
        #   [ 1^T  0   0 ]
        #   [ P^T  0   0 ]
        # K holds the pairwise radial basis values, P the control points.
        forward_kernel = control.new_zeros(N + 3, N + 3)
        forward_kernel[:N, :N].copy_(compute_partial_repr(control, control))
        forward_kernel[:N, -3].fill_(1)
        forward_kernel[-3, :N].fill_(1)
        forward_kernel[:N, -2:].copy_(control)
        forward_kernel[-2:, :N].copy_(control.transpose(0, 1))
        inverse_kernel = torch.inverse(forward_kernel)

        # Target coordinate matrix, HW x 2, in row-major pixel order.
        # Built from broadcast views rather than a Python list of H*W tuples.
        HW = target_height * target_width
        rows = torch.arange(target_height, dtype=control.dtype, device=control.device)
        cols = torch.arange(target_width, dtype=control.dtype, device=control.device)
        Y = rows.view(-1, 1).expand(target_height, target_width).reshape(HW, 1)
        X = cols.view(1, -1).expand(target_height, target_width).reshape(HW, 1)
        # Map pixel indices to [-1, 1] with the corner pixels at exactly -1/+1.
        # NOTE: a target dimension of size 1 divides by zero here.
        Y = Y * 2 / (target_height - 1) - 1
        X = X * 2 / (target_width - 1) - 1
        target_coordinate = torch.cat([X, Y], dim=1)  # convert from (y, x) to (x, y)
        target_coordinate_repr = torch.cat([
            compute_partial_repr(target_coordinate, control),
            control.new_ones(HW, 1),
            target_coordinate,
        ], dim=1)

        # Register the precomputed matrices. Re-registering an existing buffer
        # name simply replaces its tensor.
        self.register_buffer('inverse_kernel', inverse_kernel)
        self.register_buffer('padding_matrix', control.new_zeros(3, 2))
        self.register_buffer('target_coordinate_repr', target_coordinate_repr)

    def forward(self, source_control_points):
        """Map every target pixel to its source coordinate.

        Args:
            source_control_points: (batch, N, 2) tensor of control points in
                the source image.

        Returns:
            A (batch, H * W, 2) tensor of source coordinates, one per target
            pixel.
        """
        assert source_control_points.ndimension() == 3
        assert source_control_points.size(1) == self.num_points
        assert source_control_points.size(2) == 2
        batch_size = source_control_points.size(0)

        padding = self.padding_matrix.expand(batch_size, 3, 2)
        Y = torch.cat([source_control_points, padding], 1)
        mapping_matrix = torch.matmul(self.inverse_kernel, Y)
        source_coordinate = torch.matmul(self.target_coordinate_repr, mapping_matrix)
        return source_coordinate

    def tps_trans(self, inputs, source_control_points=None, max_range=0.1, canvas=0.5, target_shape=None):
        """Warp ``inputs`` with given or randomly jittered control points.

        Args:
            inputs: image batch to warp; its first dimension is the batch size
                and its second the channel count.
            source_control_points: optional (batch, N, 2) source control
                points. When omitted, the target control points are jittered
                uniformly in ``[-max_range, max_range]`` for each sample.
            max_range: half-width of the random jitter.
            canvas: background for pixels sampled outside ``inputs``. A float
                is expanded to a constant image matching the channel count,
                dtype and device of ``inputs``; any other value is passed on
                unchanged.
            target_shape: optional ``(height, width)``; when it differs from
                the current target shape the precomputed buffers are rebuilt.

        Returns:
            ``(target_image, source_control_points)``.
        """
        # Compare as tuples so a list and a tuple with the same size do not
        # trigger a rebuild. Rebuilding in place (instead of re-running
        # ``__init__``) keeps hooks, training mode and device.
        if target_shape is not None and tuple(target_shape) != tuple(self.target_shape):
            self._precompute(target_shape)

        target_height, target_width = self.target_shape
        batch_size = inputs.shape[0]

        if source_control_points is None:
            control = self.target_control_points
            jitter = torch.empty(
                (batch_size,) + tuple(control.shape),
                dtype=control.dtype,
                device=control.device,
            ).uniform_(-max_range, max_range)
            source_control_points = control.unsqueeze(0) + jitter

        if isinstance(canvas, float):
            # Use the channel count of ``inputs`` rather than assuming 3.
            canvas = inputs.new_full(
                (batch_size, inputs.shape[1], target_height, target_width), canvas)

        source_coordinate = self.forward(source_control_points)
        grid = source_coordinate.view(batch_size, target_height, target_width, 2)
        target_image = grid_sample(inputs, grid, canvas)
        return target_image, source_control_points
|