|
- import mindspore.nn as nn
- import numpy as np
- import mindspore
- from mindspore import Tensor
- from mindspore import ms_function
- import numba as nb
- import mindspore as ms
- import mindspore.ops as P
- import mindspore.numpy as mnp
- from mindspore.ops.primitive import constexpr
- import math
- from mindspore.common.initializer import initializer, HeUniform, Uniform
# Stateless MindSpore primitives instantiated once at module level and shared
# by every Cell/function below, instead of being re-created per call.
mindspore_expand_dims = mindspore.ops.ExpandDims()  # insert a size-1 axis at a given position
mindspore_transpose = mindspore.ops.Transpose()     # permute tensor axes
mindspore_concat = mindspore.ops.Concat(-1)         # concatenate along the last axis
mindspore_concat1 = mindspore.ops.Concat(1)         # concatenate along axis 1
mindspore_sort = mindspore.ops.Sort()               # sort along the last axis (returns values and indices)
mindspore_reshape = mindspore.ops.Reshape()
mindspore_tile = mindspore.ops.Tile()
mindspore_zeros = mindspore.ops.Zeros()
mindspore_ones = mindspore.ops.Ones()
mindspore_shape=mindspore.ops.Shape()
class BatchNorm1d_cus(nn.Cell):
    """1D batch normalization implemented via BatchNorm2d.

    A trailing singleton axis is appended so the [B, C, N] input can be fed
    through BatchNorm2d, then removed again after normalization.
    """

    def __init__(self, num_features):
        super(BatchNorm1d_cus, self).__init__()
        self.bn = nn.BatchNorm2d(num_features)
        self.squeeze = mindspore.ops.Squeeze(3)

    @ms_function
    def construct(self, x):
        expanded = mindspore_expand_dims(x, 3)   # [B, C, N] -> [B, C, N, 1]
        normalized = self.bn(expanded)
        return self.squeeze(normalized)          # back to [B, C, N]
-
def sample_and_group(npoint, radius, nsample, xyz, points):
    """Sample `npoint` centroids with FPS and group their ball neighborhoods.

    Input:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D] (may be None)
    Return:
        new_xyz: sampled points position data, [B, npoint, 3]
        new_points: grouped points data, [B, npoint, nsample, 3+D]
    """
    batch, _, channels = xyz.shape
    # Pick well-spread centroids, then gather their coordinates.
    centroid_idx = farthest_point_sample(xyz, npoint)       # [B, npoint]
    new_xyz = index_points(xyz, centroid_idx)               # [B, npoint, C]
    # Up to `nsample` neighbor indices within `radius` of each centroid.
    neighbor_idx = query_ball_point(radius, nsample, xyz, new_xyz)
    grouped_xyz = index_points(xyz, neighbor_idx)           # [B, npoint, nsample, C]
    # Express neighbor coordinates relative to their centroid.
    grouped_xyz_norm = grouped_xyz - new_xyz.view(batch, npoint, 1, channels)

    if points is None:
        new_points = grouped_xyz_norm
    else:
        grouped_feats = index_points(points, neighbor_idx)
        # [B, npoint, nsample, C+D]: relative coords first, features after.
        new_points = P.Concat(-1)((grouped_xyz_norm, grouped_feats))

    return new_xyz, new_points
-
def sample_and_group_all(xyz, points):
    """Group the entire point cloud into a single region around the origin.

    Input:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D] (may be None)
    Return:
        new_xyz: sampled points position data, [B, 1, 3] (all zeros)
        new_points: grouped points data, [B, 1, N, 3+D]
    """
    B, N, C = mindspore_shape(xyz)
    # A single "centroid" at the origin represents the whole cloud.
    new_xyz = mindspore_zeros((B, 1, C), mindspore.float32)
    grouped_xyz = mindspore_reshape(xyz, (B, 1, N, C))
    if points is None:
        return new_xyz, grouped_xyz
    feats = mindspore_reshape(points, (B, 1, N, -1))
    return new_xyz, mindspore_concat((grouped_xyz, feats))
def square_distance(src, dst):
    """Compute pairwise squared Euclidean distances between two point sets.

    Uses the expansion
        ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a . b
    so the full [B, N, M] matrix comes from one batched matmul and two
    row/column norm reductions.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    # Cross term: src . dst^T, shape [B, N, M].
    cross = P.BatchMatMul()(src, P.Transpose()(dst, (0, 2, 1)))
    src_norm = P.Reshape()(P.ReduceSum()(src ** 2, -1), (B, N, 1))
    dst_norm = P.Reshape()(P.ReduceSum()(dst ** 2, -1), (B, 1, M))
    # Accumulate in the same order as the reference to keep float results stable.
    dist = -2 * cross
    dist = dist + src_norm
    dist = dist + dst_norm
    return dist
-
@constexpr
def generate_tensor_batch_indices(B):
    """Return a constant int32 tensor [0, 1, ..., B-1] of batch indices."""
    indices = np.arange(B)
    return Tensor(indices, ms.int32)
def index_points(points, idx):
    """Gather rows of `points` along the N axis using per-batch indices.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] or [B, S, nsample]
    Return:
        new_points: indexed points data, [B, S, C] or [B, S, nsample, C]
    """
    idx_shape = idx.shape
    batch_ids = generate_tensor_batch_indices(idx_shape[0])
    # Broadcast the batch ids to the same shape as idx.
    if len(idx_shape) == 2:
        batch_ids = batch_ids.view(idx_shape[0], 1)
    else:
        batch_ids = batch_ids.view(idx_shape[0], 1, 1)
    batch_ids = batch_ids.expand_as(idx)
    # Pair (batch, point) coordinates so GatherNd can select the rows.
    pair = (batch_ids.reshape(idx_shape + (1,)), idx.reshape(idx_shape + (1,)))
    index = P.Concat(-1)(pair)
    return P.GatherNd()(points, index)
-
def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    B, N, _ = xyz.shape
    _, S, _ = new_xyz.shape
    # Start with every point index [0..N-1] replicated for each query point.
    group_idx = mnp.arange(0, N, 1, ms.int32).view(1, 1, N)
    group_idx = P.Tile()(group_idx, (B, S, 1))
    sqrdists = square_distance(new_xyz, xyz)

    # Replace indices of points outside the ball with the sentinel -1.
    idx = sqrdists > radius ** 2
    group_idx = P.Select()(idx, -1 * P.OnesLike()(group_idx), group_idx)
    # TopK requires float input. It keeps the `nsample` largest entries, so
    # -1 sentinels only survive when fewer than nsample points lie inside
    # the ball. NOTE(review): this selects the *largest* valid indices,
    # whereas the reference implementation keeps the smallest — confirm this
    # is intended.
    group_idx = P.Cast()(group_idx, ms.float32)
    group_idx, _ = P.TopK()(group_idx, nsample)
    group_idx = P.Cast()(group_idx, ms.int32)

    # First entry of each group (a valid index if any point was in range),
    # used below to pad groups that came up short.
    group_first = group_idx[:, :, 0].view(B, S, 1)
    group_first = P.Tile()(group_first, (1, 1, nsample))  # [B, S, nsample]

    # Where group_idx is valid (!= -1) the pad tensor becomes -1; where it is
    # the sentinel it stays group_first. Maximum() then keeps valid entries
    # (index >= 0 > -1) and replaces sentinels with the first neighbor.
    index = group_idx != -1
    group_first = P.Select()(index, -1 * P.OnesLike()(group_first), group_first)
    group_idx = P.Maximum()(group_idx, group_first)

    return group_idx
@constexpr
def generate_tensor_fps(B, N):
    """Return one random starting point index in [0, N) per batch element."""
    start = np.random.randint(N, size=(B,))
    return Tensor(start, ms.int32)
-
def farthest_point_sample(xyz, npoint):
    """
    Iterative farthest point sampling: repeatedly pick the point farthest
    from all previously chosen centroids.

    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    B, N, _ = xyz.shape
    # Stored as (npoint, B) so each iteration writes one full row; transposed
    # to (B, npoint) on return.
    centroids = mnp.zeros((npoint, B), ms.int32)
    # NOTE(review): declared int32 but scaled by the float 1e9 — the result
    # dtype depends on MindSpore's promotion rules; it is only used through
    # Minimum/Argmax below. Confirm float32 was intended.
    distance = mnp.ones((B, N), ms.int32) * 1e9
    farthest = generate_tensor_fps(B, N)            # random first centroid per batch
    batch_indices = generate_tensor_batch_indices(B)
    for i in range(npoint):
        # Cast to float so the row assignment is supported on this backend,
        # then cast back to preserve integer index semantics.
        centroids = P.Cast()(centroids, ms.float32)
        farthest = P.Cast()(farthest, ms.float32)
        centroids[i] = farthest
        centroids = P.Cast()(centroids, ms.int32)
        farthest = P.Cast()(farthest, ms.int32)
        # Gather the coordinates of each batch's current centroid.
        index = P.Concat(-1)((batch_indices.reshape(batch_indices.shape + (1,)),
                              farthest.reshape(farthest.shape + (1,))))
        centroid = P.GatherNd()(xyz, index).reshape((B, 1, 3))
        # Track each point's distance to its nearest chosen centroid and pick
        # the point maximizing that distance as the next centroid.
        dist = P.ReduceSum()((xyz - centroid) ** 2, -1)
        distance = P.Minimum()(distance, dist)
        farthest = P.Argmax()(distance)
    return P.Transpose()(centroids, (1, 0))
-
class Conv2d(nn.Conv2d):
    """Conv2d with PyTorch-style default initialization.

    Weights use Kaiming (He) uniform with a = sqrt(5); biases are uniform in
    [-1/sqrt(fan_in), 1/sqrt(fan_in)].
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, pad_mode='same', padding=0, dilation=1,
                 group=1, has_bias=True):
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride, pad_mode, padding, dilation, group,
                                     has_bias, weight_init='normal', bias_init='zeros')
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize weight and (if present) bias in place."""
        self.weight.set_data(initializer(HeUniform(math.sqrt(5)), self.weight.shape))
        if not self.has_bias:
            return
        fan_in, _ = calculate_fan_in_and_fan_out(self.weight.shape)
        bound = 1 / math.sqrt(fan_in)
        self.bias.set_data(initializer(Uniform(bound), [self.out_channels]))
-
class PointNetSetAbstractionMsg(nn.Cell):
    """Multi-scale grouping (MSG) set abstraction layer of PointNet++.

    Groups neighbors of the same `npoint` FPS centroids at several radii,
    runs a per-scale Conv2d/BatchNorm/ReLU MLP, max-pools over the neighbors,
    and concatenates the per-scale features along the channel axis.
    """

    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
        super(PointNetSetAbstractionMsg, self).__init__()
        self.npoint = npoint
        self.radius_list = radius_list
        self.nsample_list = nsample_list
        self.conv_blocks = nn.CellList()
        self.bn_blocks = nn.CellList()
        # One conv/bn stack per scale; mlp_list[i] holds the output channel
        # sizes of scale i, and every scale starts again from in_channel.
        for i in range(len(mlp_list)):
            convs = nn.CellList()
            bns = nn.CellList()
            last_channel = in_channel
            for out_channel in mlp_list[i]:
                convs.append(nn.Conv2d(last_channel, out_channel, 1))
                bns.append(nn.BatchNorm2d(out_channel))
                last_channel = out_channel
            self.conv_blocks.append(convs)
            self.bn_blocks.append(bns)
        self.relu = nn.ReLU()

    def construct(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """

        # Switch to channels-last for the grouping helpers.
        xyz=mindspore_transpose(xyz,(0,2,1))
        if points is not None:
            points = mindspore_transpose(points,(0,2,1))

        B, N, C = mindspore_shape(xyz)
        S = self.npoint
        # The same FPS centroids are shared by every scale.
        idx=farthest_point_sample(xyz, S)
        new_xyz = index_points(xyz, idx)
        new_points_list = []
        for i, radius in enumerate(self.radius_list):
            K = self.nsample_list[i]

            group_idx = query_ball_point(radius, K, xyz, new_xyz)
            grouped_xyz = index_points(xyz, group_idx)
            # Neighbor coordinates relative to their centroid.
            grouped_xyz -= mindspore_reshape(new_xyz,(B, S, 1, C))
            if points is not None:
                grouped_points = index_points(points, group_idx)
                # Append relative coordinates after the point features.
                grouped_points = mindspore_concat((grouped_points, grouped_xyz))
            else:
                grouped_points = grouped_xyz
            grouped_points = mindspore_transpose(grouped_points,(0, 3, 2, 1))# [B, D, K, S]

            # Per-scale shared MLP over the grouped features.
            for j in range(len(self.conv_blocks[i])):
                conv = self.conv_blocks[i][j]
                bn = self.bn_blocks[i][j]
                grouped_points = self.relu(bn(conv(grouped_points)))
            # Max-pool across the K neighbors of each centroid.
            new_points = grouped_points.max(axis=2) # [B, D', S]
            new_points_list.append(new_points)

        # Back to channels-first; stack all scales on the channel axis.
        new_xyz = mindspore_transpose(new_xyz,(0,2,1))
        new_points_concat = mindspore_concat1(new_points_list)
        return new_xyz, new_points_concat
def calculate_fan_in_and_fan_out(shape):
    """
    calculate fan_in and fan_out

    Args:
        shape (tuple): parameter shape; (out, in) for linear weights, or
            (out, in, *kernel_dims) for convolution weights of any rank.

    Returns:
        Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.

    Raises:
        ValueError: if `shape` has fewer than 2 dimensions.
    """
    dimensions = len(shape)
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
    if dimensions == 2:  # Linear
        return shape[1], shape[0]
    # Convolution: multiply channel count by the receptive field size.
    # Generalized from shape[2] * shape[3] so 1-D and 3-D kernels work too
    # (the old code raised IndexError for 3-D shapes); also drops the
    # redundant `dimensions > 2` check that was always true on this path.
    receptive_field_size = 1
    for dim in shape[2:]:
        receptive_field_size *= dim
    fan_in = shape[1] * receptive_field_size
    fan_out = shape[0] * receptive_field_size
    return fan_in, fan_out
class PointNetSetAbstraction(nn.Cell):
    """Single-scale set abstraction layer of PointNet++ (sample + group + pointwise MLP + max-pool)."""

    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.group_all = group_all

        # Fixed three-layer 1x1-conv MLP; `mlp` must contain exactly three
        # channel sizes (only mlp[0..2] are used).
        self.conv1 = Conv2d(in_channel, mlp[0], 1)
        self.bn1 = nn.BatchNorm2d(mlp[0])
        self.conv2 = Conv2d(mlp[0], mlp[1], 1)
        self.bn2 = nn.BatchNorm2d(mlp[1])
        self.conv3 = Conv2d(mlp[1], mlp[2], 1)
        self.bn3 = nn.BatchNorm2d(mlp[2])

        self.relu = P.ReLU()
        self.transpose = P.Transpose()
        self.reduce_max = P.ReduceMax()

    def construct(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        # Channels-last layout for the grouping helpers.
        xyz=mindspore_transpose(xyz,(0,2,1))
        if points is not None:
            points = self.transpose(points, (0, 2, 1))

        if self.group_all:
            new_xyz, new_points = sample_and_group_all(xyz, points)
        else:
            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
        # new_xyz: sampled points position data, [B, npoint, C]
        # new_points: sampled points data, [B, npoint, nsample, C+D]
        # The two reshape+transpose steps together emulate a 4-D
        # permute(0, 3, 2, 1): [B, npoint, nsample, C+D] -> [B, C+D, nsample, npoint].
        d1, d2, d3, d4 = new_points.shape
        new_points = self.transpose(new_points.reshape((d1, d2 * d3, d4)), (0, 2, 1))
        new_points = self.transpose(new_points.reshape((d1 * d4, d2, d3)), (0, 2, 1)).reshape((d1, d4, d3, d2))

        new_points = self.relu(self.bn1(self.conv1(new_points)))
        new_points = self.relu(self.bn2(self.conv2(new_points)))
        new_points = self.relu(self.bn3(self.conv3(new_points)))

        # Max-pool over the nsample axis -> [B, D', npoint].
        new_points = self.reduce_max(new_points, 2)
        new_xyz=mindspore_transpose(new_xyz,(0,2,1))
        return new_xyz, new_points
class PointNetFeaturePropagation(nn.Cell):
    """Feature propagation (upsampling) layer of PointNet++.

    Interpolates the sparse-level features `points2` back onto the dense
    positions `xyz1` by inverse-distance-weighted averaging of the three
    nearest neighbors, optionally concatenates the dense-level skip features
    `points1`, then refines the result with a Conv1d/BatchNorm/ReLU MLP.
    """

    def __init__(self, in_channel, mlp):
        super(PointNetFeaturePropagation, self).__init__()
        self.mlp_convs = nn.CellList()
        self.mlp_bns = nn.CellList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
            self.mlp_bns.append(BatchNorm1d_cus(out_channel))
            last_channel = out_channel
        self.relu = nn.ReLU()

    @ms_function
    def layer(self, new_points):
        """Apply the shared MLP to [B, N, D] features (channels-first inside).

        BUGFIX: this was declared `def layer(new_points):` (no `self`) while
        being called as `self.layer(new_points)` and reading `self.mlp_convs`
        in its body — a guaranteed runtime error. It is now a proper method.
        """
        new_points = mindspore_transpose(new_points, (0, 2, 1))
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = self.relu(bn(conv(new_points)))
        return new_points

    def construct(self, xyz1, xyz2, points1, points2):
        """
        Input:
            xyz1: input points position data, [B, C, N]
            xyz2: sampled input points position data, [B, C, S]
            points1: input points data, [B, D, N]
            points2: input points data, [B, D, S]
        Return:
            new_points: upsampled points data, [B, D', N]
        """
        # Channels-last layout for the distance/gather helpers.
        xyz1 = mindspore_transpose(xyz1, (0, 2, 1))
        xyz2 = mindspore_transpose(xyz2, (0, 2, 1))
        # Accept a squeezed [B, D] feature tensor (single source point).
        if points2.ndim == 2:
            points2 = mindspore_expand_dims(points2, 2)
        points2 = mindspore_transpose(points2, (0, 2, 1))
        B, N, C = mindspore_shape(xyz1)
        _, S, _ = mindspore_shape(xyz2)

        if S == 1:
            # Only one source point: broadcast its features to all N targets.
            interpolated_points = mindspore_tile(points2, (1, N, 1))
        else:
            # Three nearest source points per target point.
            dists = square_distance(xyz1, xyz2)
            dists, idx = mindspore_sort(dists)
            dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]

            dist_recip = 1.0 / (dists + 1e-8)
            # BUGFIX: normalize by the SUM of the reciprocal distances so the
            # three weights add up to 1 (standard inverse-distance
            # interpolation); the previous code divided by the max.
            norm = dist_recip.sum(axis=2)
            weight = dist_recip / mindspore_expand_dims(norm, -1)
            interpolated_points = (index_points(points2, idx) * mindspore_reshape(weight, (B, N, 3, 1))).sum(axis=2)
        if points1 is not None:
            # Concatenate skip-link features from the dense level.
            points1 = mindspore_transpose(points1, (0, 2, 1))
            new_points = mindspore_concat((points1, interpolated_points))
        else:
            new_points = interpolated_points

        return self.layer(new_points)
|