|
- import torch
- import torch.nn as nn
- from utils.sampling import fps
- from utils.grouping import ball_query
- from utils.common import gather_points, get_dists
-
def sample_and_group(xyz, points, M, radius, K, use_xyz=True):
    '''
    Pick M centroids from xyz and gather K neighbours around each of them.

    Centroids come from `fps(xyz, M)` (presumably farthest point sampling) and
    neighbours from `ball_query(xyz, new_xyz, radius, K)` (presumably a radius
    search) -- both helpers live in utils and are not visible here.

    :param xyz: shape=(B, N, 3)
    :param points: shape=(B, N, C), or None when there are no per-point features
    :param M: int, number of centroids
    :param radius: float, neighbourhood radius passed to ball_query
    :param K: int, number of neighbours per centroid
    :param use_xyz: bool, if True concat XYZ with local point features, otherwise just use point features
    :return: new_xyz, shape=(B, M, 3);
             new_points, shape=(B, M, K, C+3) if use_xyz, (B, M, K, C) if not,
             or (B, M, K, 3) when points is None;
             grouped_inds, shape=(B, M, K);
             grouped_xyz, shape=(B, M, K, 3), expressed relative to its centroid
    '''
    new_xyz = gather_points(xyz, fps(xyz, M))
    grouped_inds = ball_query(xyz, new_xyz, radius, K)
    # Translate every group so its centroid sits at the origin. Broadcasting
    # over the K axis avoids materialising a repeated copy of new_xyz, and the
    # out-of-place subtraction never mutates the tensor gather_points returned.
    grouped_xyz = gather_points(xyz, grouped_inds) - torch.unsqueeze(new_xyz, 2)
    if points is None:
        new_points = grouped_xyz
    else:
        grouped_points = gather_points(points, grouped_inds)
        if use_xyz:
            new_points = torch.cat((grouped_xyz.float(), grouped_points.float()), dim=-1)
        else:
            new_points = grouped_points
    return new_xyz, new_points, grouped_inds, grouped_xyz
-
-
def sample_and_group_all(xyz, points, use_xyz=True):
    '''
    Treat the whole point cloud as one group whose centroid is the origin.

    :param xyz: shape=(B, M, 3)
    :param points: shape=(B, M, C), or None when there are no per-point features
    :param use_xyz: bool, if True concat XYZ with point features, otherwise just use point features
    :return: new_xyz, shape=(B, 1, 3), all zeros, same device/dtype as xyz;
             new_points, shape=(B, 1, M, C+3) if use_xyz, (B, 1, M, C) if not,
             or (B, 1, M, 3) when points is None;
             grouped_inds, shape=(B, 1, M), every index 0..M-1;
             grouped_xyz, shape=(B, 1, M, 3)
    '''
    B, M, C = xyz.shape
    # Allocate on the input's device/dtype so GPU callers never receive a
    # CPU tensor (the previous torch.zeros/arange always landed on CPU).
    new_xyz = torch.zeros(B, 1, C, device=xyz.device, dtype=xyz.dtype)
    grouped_inds = torch.arange(M, device=xyz.device, dtype=torch.long).view(1, 1, M).repeat(B, 1, 1)
    # reshape (unlike view) also accepts non-contiguous inputs.
    grouped_xyz = xyz.reshape(B, 1, M, C)
    if points is None:
        return new_xyz, grouped_xyz, grouped_inds, grouped_xyz
    if use_xyz:
        new_points = torch.cat([xyz.float(), points.float()], dim=2)
    else:
        new_points = points
    new_points = torch.unsqueeze(new_points, dim=1)
    return new_xyz, new_points, grouped_inds, grouped_xyz
-
-
class PointNet_SA_Module(nn.Module):
    '''
    PointNet++ set-abstraction layer: group points, run a shared per-point MLP
    (1x1 Conv2d stack) over every group, then pool each group into one feature.
    '''
    def __init__(self, M, radius, K, in_channels, mlp, group_all, bn=True, pooling='max', use_xyz=True):
        '''
        :param M: int, number of centroids (ignored when group_all is True)
        :param radius: float, grouping radius (ignored when group_all is True)
        :param K: int, neighbours per centroid (ignored when group_all is True)
        :param in_channels: int, channels of the grouped features fed to the MLP
                            (C+3 when use_xyz and points are given)
        :param mlp: list of int, output channels of each 1x1 conv layer
        :param group_all: bool, if True treat the whole cloud as a single group
        :param bn: bool, add BatchNorm2d after every conv
        :param pooling: 'avg' for mean pooling over each group; anything else means max pooling
        :param use_xyz: bool, concat relative XYZ to the grouped features
        '''
        super(PointNet_SA_Module, self).__init__()
        self.M = M
        self.radius = radius
        self.K = K
        self.in_channels = in_channels
        self.mlp = mlp
        self.group_all = group_all
        self.bn = bn
        self.pooling = pooling
        self.use_xyz = use_xyz
        # BatchNorm provides its own shift, so a conv bias is redundant with BN;
        # without BN the bias is the layer's only offset term and must be kept
        # (matches PointNet_FP_Module).
        bias = not bn
        self.backbone = nn.Sequential()
        for i, out_channels in enumerate(mlp):
            self.backbone.add_module('Conv{}'.format(i),
                                     nn.Conv2d(in_channels, out_channels, 1,
                                               stride=1, padding=0, bias=bias))
            if bn:
                self.backbone.add_module('Bn{}'.format(i),
                                         nn.BatchNorm2d(out_channels))
            self.backbone.add_module('Relu{}'.format(i), nn.ReLU())
            in_channels = out_channels

    def forward(self, xyz, points):
        '''
        :param xyz: shape=(B, N, 3)
        :param points: shape=(B, N, C), or None
        :return: new_xyz, shape=(B, M, 3) ((B, 1, 3) when group_all);
                 new_points, shape=(B, M, mlp[-1]) ((B, 1, mlp[-1]) when group_all)
        '''
        if self.group_all:
            new_xyz, new_points, grouped_inds, grouped_xyz = sample_and_group_all(xyz, points, self.use_xyz)
        else:
            new_xyz, new_points, grouped_inds, grouped_xyz = sample_and_group(xyz=xyz,
                                                                              points=points,
                                                                              M=self.M,
                                                                              radius=self.radius,
                                                                              K=self.K,
                                                                              use_xyz=self.use_xyz)
        # (B, groups, K, C) -> (B, C, K, groups): channels first for Conv2d,
        # and the per-group neighbour axis lands on dim 2 for pooling.
        new_points = self.backbone(new_points.permute(0, 3, 2, 1).contiguous())
        if self.pooling == 'avg':
            new_points = torch.mean(new_points, dim=2)
        else:
            new_points = torch.max(new_points, dim=2)[0]
        # (B, C', groups) -> (B, groups, C')
        new_points = new_points.permute(0, 2, 1).contiguous()
        return new_xyz, new_points
-
-
def three_nn(xyz1, xyz2):
    '''
    For every point of xyz1, find its (up to) three closest points in xyz2.

    :param xyz1: shape=(B, N1, 3)
    :param xyz2: shape=(B, N2, 3)
    :return: dists: shape=(B, N1, k), inds: shape=(B, N1, k), with k = min(3, N2),
             ordered by ascending value of get_dists (whatever metric it computes)
    '''
    # assumes get_dists returns a (B, N1, N2) matrix -- TODO confirm in utils.common
    dists = get_dists(xyz1, xyz2)
    # Only the 3 smallest entries are needed, so a partial selection replaces a
    # full sort over N2; clamping k keeps working when N2 < 3 like the old slice.
    k = min(3, dists.shape[-1])
    dists, inds = torch.topk(dists, k, dim=-1, largest=False, sorted=True)
    return dists, inds
-
-
def three_interpolate(xyz1, xyz2, points2):
    '''
    Interpolate points2 features onto xyz1 using inverse-distance weights over
    the three nearest neighbours returned by three_nn.

    :param xyz1: shape=(B, N1, 3)
    :param xyz2: shape=(B, N2, 3)
    :param points2: shape=(B, N2, C2)
    :return: interpolated_points: shape=(B, N1, C2)
    '''
    dists, inds = three_nn(xyz1, xyz2)
    # 1e-8 avoids division by zero when a query point coincides with a source point.
    inversed_dists = 1.0 / (dists + 1e-8)
    weight = inversed_dists / torch.sum(inversed_dists, dim=-1, keepdim=True)  # shape=(B, N1, 3)
    neighbour_points = gather_points(points2, inds)  # shape=(B, N1, 3, C2)
    # Broadcast the weights across the channel axis rather than materialising a
    # repeated (B, N1, 3, C2) weight tensor.
    return torch.sum(torch.unsqueeze(weight, -1) * neighbour_points, dim=2)
-
-
class PointNet_FP_Module(nn.Module):
    '''
    PointNet++ feature-propagation layer: carry features from a sparse point set
    (xyz2) back onto a denser one (xyz1), optionally concatenate skip features,
    and run a shared per-point MLP (1x1 Conv2d stack).
    '''
    def __init__(self, in_channels, mlp, bn=True):
        '''
        :param in_channels: int, C2 + C1 (or just C2 when forward gets points1=None)
        :param mlp: list of int, output channels of each 1x1 conv layer
        :param bn: bool, add BatchNorm2d after every conv
        '''
        super(PointNet_FP_Module, self).__init__()
        self.backbone = nn.Sequential()
        # BatchNorm supplies its own shift, so the conv bias is only needed without BN.
        bias = not bn
        for i, out_channels in enumerate(mlp):
            self.backbone.add_module('Conv_{}'.format(i), nn.Conv2d(in_channels,
                                                                    out_channels,
                                                                    1,
                                                                    stride=1,
                                                                    padding=0,
                                                                    bias=bias))
            if bn:
                self.backbone.add_module('Bn_{}'.format(i), nn.BatchNorm2d(out_channels))
            self.backbone.add_module('Relu_{}'.format(i), nn.ReLU())
            in_channels = out_channels

    def forward(self, xyz1, xyz2, points1, points2):
        '''
        :param xyz1: shape=(B, N1, 3)
        :param xyz2: shape=(B, N2, 3) (N1 >= N2)
        :param points1: shape=(B, N1, C1), or None to propagate only the
                        interpolated points2 features
        :param points2: shape=(B, N2, C2)
        :return: new_points: shape=(B, N1, mlp[-1])
        '''
        N1 = points1.shape[1] if points1 is not None else xyz1.shape[1]
        N2 = points2.shape[1]
        if N2 == 1:
            # A single (global) feature: every target point receives it unchanged.
            interpolated_points = points2.expand(-1, N1, -1)
        else:
            interpolated_points = three_interpolate(xyz1, xyz2, points2)
        if points1 is not None:
            cat_points = torch.cat([interpolated_points, points1], dim=-1)
        else:
            cat_points = interpolated_points
        # (B, N1, C) -> (B, C, N1, 1) so the 1x1 Conv2d stack acts per point.
        cat_points = cat_points.permute(0, 2, 1).contiguous()
        new_points = torch.squeeze(self.backbone(torch.unsqueeze(cat_points, -1)), dim=-1)
        return new_points.permute(0, 2, 1).contiguous()
|