|
- import torch
- from torch.autograd import Variable
- from torch.autograd import Function
- import torch.nn as nn
- from typing import Tuple
-
- import pointnet2_cuda as pointnet2
-
-
class FurthestPointSampling(Function):
    """Autograd wrapper around the CUDA iterative furthest-point-sampling kernel."""

    @staticmethod
    def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
        """
        Uses iterative furthest point sampling to select a set of npoint points that have the
        largest minimum distance to each other.

        :param ctx: autograd context (unused)
        :param xyz: (B, N, 3) contiguous point coordinates, where N > npoint
        :param npoint: int, number of points in the sampled set
        :return:
            output: (B, npoint) int32 tensor containing the indices of the sampled points
        """
        assert xyz.is_contiguous()

        B, N, _ = xyz.size()
        # Allocate on xyz's device instead of the *current* CUDA device (the legacy
        # torch.cuda.*Tensor constructors do the latter, which breaks on multi-GPU).
        output = xyz.new_empty((B, npoint), dtype=torch.int32)
        # Per-point scratch buffer for the kernel, initialised to a large sentinel
        # (presumably the running minimum distance to the selected set).
        temp = xyz.new_full((B, N), 1e10, dtype=torch.float32)

        pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
        return output

    @staticmethod
    def backward(ctx, grad_out=None):
        # Index selection is not differentiable: no gradient for xyz or npoint.
        return None, None


furthest_point_sample = FurthestPointSampling.apply
-
-
class GatherOperation(Function):
    """Autograd wrapper around the CUDA kernel that gathers point features by index."""

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        """
        Select, for every batch, the feature columns listed in ``idx``.

        :param ctx: autograd context
        :param features: (B, C, N) contiguous feature tensor
        :param idx: (B, npoint) contiguous index tensor of the features to gather
        :return:
            output: (B, C, npoint) float32 tensor
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()

        B, npoint = idx.size()
        _, C, N = features.size()
        # Allocate on the input's device rather than the current CUDA device.
        output = features.new_empty((B, C, npoint), dtype=torch.float32)

        pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output)

        ctx.for_backwards = (idx, C, N)
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """
        :param ctx: autograd context holding (idx, C, N) from forward
        :param grad_out: (B, C, npoint) gradient of the forward output
        :return:
            grad_features: (B, C, N) gradient w.r.t. features
            None: idx is not differentiable
        """
        idx, C, N = ctx.for_backwards
        B, npoint = idx.size()

        # Zero-initialised because the grad kernel writes contributions into it
        # (presumably a scatter-add, since several outputs may gather the same point).
        grad_features = grad_out.new_zeros((B, C, N), dtype=torch.float32)
        grad_out_data = grad_out.detach().contiguous()
        pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features)
        return grad_features, None


gather_operation = GatherOperation.apply
-
-
class ThreeNN(Function):
    """Autograd wrapper around the CUDA three-nearest-neighbour search kernel."""

    @staticmethod
    def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Find the three nearest neighbors of each point of ``unknown`` among ``known``.

        :param ctx: autograd context (unused)
        :param unknown: (B, N, 3) contiguous query points
        :param known: (B, M, 3) contiguous reference points
        :return:
            dist: (B, N, 3) l2 distance to the three nearest neighbors
            idx: (B, N, 3) int32 index of the 3 nearest neighbors
        """
        assert unknown.is_contiguous()
        assert known.is_contiguous()

        B, N, _ = unknown.size()
        m = known.size(1)
        # Allocate on the input's device rather than the current CUDA device.
        dist2 = unknown.new_empty((B, N, 3), dtype=torch.float32)
        idx = unknown.new_empty((B, N, 3), dtype=torch.int32)

        pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
        # The kernel produces squared distances; return the Euclidean distance.
        return torch.sqrt(dist2), idx

    @staticmethod
    def backward(ctx, a=None, b=None):
        # The neighbour search is treated as non-differentiable.
        return None, None


three_nn = ThreeNN.apply
-
-
class ThreeInterpolate(Function):
    """Autograd wrapper around the CUDA three-neighbour weighted interpolation kernel."""

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        """
        Performs weighted linear interpolation over 3 neighbouring features.

        :param ctx: autograd context
        :param features: (B, C, M) feature descriptors to be interpolated from
        :param idx: (B, n, 3) indices (into the M axis) of the three nearest neighbors of each target
        :param weight: (B, n, 3) interpolation weights
        :return:
            output: (B, C, n) float32 tensor of the interpolated features
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()
        assert weight.is_contiguous()

        B, c, m = features.size()
        n = idx.size(1)
        # Backward only needs idx, weight and the source size M to route gradients back.
        ctx.three_interpolate_for_backward = (idx, weight, m)
        # Allocate on the input's device rather than the current CUDA device.
        output = features.new_empty((B, c, n), dtype=torch.float32)

        pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output)
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
        """
        :param ctx: autograd context holding (idx, weight, M) from forward
        :param grad_out: (B, C, n) tensor with gradients of outputs
        :return:
            grad_features: (B, C, M) tensor with gradients of features
            None: idx is not differentiable
            None: no gradient is computed for weight
        """
        idx, weight, m = ctx.three_interpolate_for_backward
        B, c, n = grad_out.size()

        # Zero-initialised because the grad kernel writes contributions into it.
        grad_features = grad_out.new_zeros((B, c, m), dtype=torch.float32)
        grad_out_data = grad_out.detach().contiguous()

        pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features)
        return grad_features, None, None


three_interpolate = ThreeInterpolate.apply
-
-
class GroupingOperation(Function):
    """Autograd wrapper around the CUDA kernel that groups point features by neighbourhood index."""

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        """
        :param ctx: autograd context
        :param features: (B, C, N) tensor of features to group
        :param idx: (B, npoint, nsample) tensor containing the indices of features to group with
        :return:
            output: (B, C, npoint, nsample) float32 tensor
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()

        B, nfeatures, nsample = idx.size()
        _, C, N = features.size()
        # Allocate on the input's device rather than the current CUDA device.
        output = features.new_empty((B, C, nfeatures, nsample), dtype=torch.float32)

        pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)

        ctx.for_backwards = (idx, N)
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """
        :param ctx: autograd context holding (idx, N) from forward
        :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
        :return:
            grad_features: (B, C, N) gradient of the features
            None: idx is not differentiable
        """
        idx, N = ctx.for_backwards

        B, C, npoint, nsample = grad_out.size()
        # Zero-initialised because the grad kernel writes contributions into it
        # (a point can appear in several groups).
        grad_features = grad_out.new_zeros((B, C, N), dtype=torch.float32)

        grad_out_data = grad_out.detach().contiguous()
        pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features)
        return grad_features, None


grouping_operation = GroupingOperation.apply
-
-
class BallQuery(Function):
    """Autograd wrapper around the CUDA ball-query kernel."""

    @staticmethod
    def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
        """
        :param ctx: autograd context (unused)
        :param radius: float, radius of the balls
        :param nsample: int, maximum number of features in the balls
        :param xyz: (B, N, 3) xyz coordinates of the features
        :param new_xyz: (B, npoint, 3) centers of the ball query
        :return:
            idx: (B, npoint, nsample) int32 tensor with the indices of the features that form the query balls
        """
        assert new_xyz.is_contiguous()
        assert xyz.is_contiguous()

        B, N, _ = xyz.size()
        npoint = new_xyz.size(1)
        # Zero-filled on the input's device (not the current CUDA device); the
        # zero fill is kept from the original in case the kernel leaves slots unwritten.
        idx = new_xyz.new_zeros((B, npoint, nsample), dtype=torch.int32)

        pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)
        return idx

    @staticmethod
    def backward(ctx, a=None):
        # Neighbourhood indices are not differentiable w.r.t. any input.
        return None, None, None, None


ball_query = BallQuery.apply
-
-
class QueryAndGroup(nn.Module):
    """Group neighbouring points around each centroid via a ball query."""

    def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
        """
        :param radius: float, radius of ball
        :param nsample: int, maximum number of features to gather in the ball
        :param use_xyz: if True, prepend the centroid-relative xyz offsets to the grouped features
        """
        super().__init__()
        self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz

    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None) -> torch.Tensor:
        """
        :param xyz: (B, N, 3) xyz coordinates of the features
        :param new_xyz: (B, npoint, 3) centroids
        :param features: (B, C, N) descriptors of the features, or None
        :return:
            new_features: a single tensor of shape
                (B, 3 + C, npoint, nsample) when features is given and use_xyz is True,
                (B, C, npoint, nsample)     when features is given and use_xyz is False,
                (B, 3, npoint, nsample)     when features is None.
        """
        idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
        xyz_trans = xyz.transpose(1, 2).contiguous()
        grouped_xyz = grouping_operation(xyz_trans, idx)  # (B, 3, npoint, nsample)
        # Express each grouped point relative to its centroid.
        grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)

        if features is not None:
            grouped_features = grouping_operation(features, idx)
            if self.use_xyz:
                new_features = torch.cat([grouped_xyz, grouped_features], dim=1)  # (B, C + 3, npoint, nsample)
            else:
                new_features = grouped_features
        else:
            assert self.use_xyz, "Cannot have not features and not use xyz as a feature!"
            new_features = grouped_xyz

        return new_features
-
-
class GroupAll(nn.Module):
    """Treat the entire point cloud as one group (no sampling, no ball query)."""

    def __init__(self, use_xyz: bool = True):
        """
        :param use_xyz: when descriptors are supplied, also prepend the raw xyz channels
        """
        super().__init__()
        self.use_xyz = use_xyz

    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None):
        """
        :param xyz: (B, N, 3) xyz coordinates of the features
        :param new_xyz: ignored
        :param features: (B, C, N) descriptors of the features, or None
        :return:
            new_features: (B, 3 + C, 1, N) with use_xyz, (B, C, 1, N) without it,
                or (B, 3, 1, N) when features is None
        """
        # (B, N, 3) -> (B, 3, N) -> (B, 3, 1, N)
        xyz_group = xyz.transpose(1, 2).unsqueeze(2)

        if features is None:
            # Without descriptors the xyz channels are returned regardless of use_xyz.
            return xyz_group

        feature_group = features.unsqueeze(2)  # (B, C, 1, N)
        if not self.use_xyz:
            return feature_group
        return torch.cat([xyz_group, feature_group], dim=1)  # (B, 3 + C, 1, N)
|