import torch
import torch.nn as nn
import torch.nn.functional as F

from pt2_modules import PointnetFPModule, PointnetSAModuleMSG
import utils

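# Backbone hyperparameters: points kept by each set-abstraction (SA) level,
# the multi-scale grouping radii and per-radius sample counts, the MLP widths
# for each grouping scale, the feature-propagation (FP) decoder MLP widths,
# and the classification head's hidden widths and dropout ratio.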
NPOINTS = [4096, 1024, 256, 64]
RADIUS = [[0.1, 0.5], [0.5, 1.0], [1.0, 2.0], [2.0, 4.0]]
NSAMPLE = [[16, 32], [16, 32], [16, 32], [16, 32]]
MLPS = [[[16, 16, 32], [32, 32, 64]], [[64, 64, 128], [64, 96, 128]],
        [[128, 196, 256], [128, 196, 256]], [[256, 256, 512], [256, 384, 512]]]
FP_MLPS = [[128, 128], [256, 256], [512, 512], [512, 512]]
CLS_FC = [128]
DP_RATIO = 0.5


class Pointnet2MSG(nn.Module):
    """PointNet++ segmentation network with multi-scale grouping (MSG):
    four SA encoder levels, four FP decoder levels, and a per-point head
    emitting log-probabilities over ``k`` classes."""

    def __init__(self, input_channels=0, k=1):
        super().__init__()

        self.k = k

        self.SA_modules = nn.ModuleList()
        channel_in = input_channels

        # Build the SA (encoder) levels, recording each level's output width
        # so the FP (decoder) levels can be wired to the skip connections.
        skip_channel_list = [input_channels]
        for i in range(len(NPOINTS)):
            # Prepend the incoming channel width to every per-scale MLP spec;
            # the level's output width is the sum of the scales' last widths.
            mlps = MLPS[i].copy()
            channel_out = 0
            for idx in range(len(mlps)):
                mlps[idx] = [channel_in] + mlps[idx]
                channel_out += mlps[idx][-1]

            self.SA_modules.append(
                PointnetSAModuleMSG(
                    npoint=NPOINTS[i],
                    radii=RADIUS[i],
                    nsamples=NSAMPLE[i],
                    mlps=mlps,
                    use_xyz=True,
                    bn=True
                )
            )
            skip_channel_list.append(channel_out)
            channel_in = channel_out

        # Build the FP (decoder) levels. Each level fuses features propagated
        # from the next-coarser level (the deepest SA output for the last
        # entry) with the skip features of its own resolution.
        self.FP_modules = nn.ModuleList()
        for i in range(len(FP_MLPS)):
            pre_channel = FP_MLPS[i + 1][-1] if i + 1 < len(FP_MLPS) else channel_out
            self.FP_modules.append(
                PointnetFPModule(mlp=[pre_channel + skip_channel_list[i]] + FP_MLPS[i])
            )

        # Per-point classification head: Conv1d stack with dropout after the
        # first layer, ending in ``k`` unnormalized scores per point.
        cls_layers = []
        pre_channel = FP_MLPS[0][-1]
        for fc in CLS_FC:
            cls_layers.append(utils.Conv1d(pre_channel, fc, bn=True))
            pre_channel = fc
        cls_layers.append(utils.Conv1d(pre_channel, self.k, activation=None))
        cls_layers.insert(1, nn.Dropout(DP_RATIO))
        self.cls_layer = nn.Sequential(*cls_layers)

    def forward(self, pointcloud: torch.Tensor):
        batchsize = pointcloud.size(0)
        n_pts = pointcloud.size(1)

        # Split the input into xyz coordinates (B, N, 3) and optional extra
        # per-point features, transposed to (B, C, N) for the SA modules.
        xyz = pointcloud[..., 0:3].contiguous()
        features = (
            pointcloud[..., 3:].transpose(1, 2).contiguous()
            if pointcloud.size(-1) > 3 else None
        )

        # Encoder: each SA level downsamples the cloud and aggregates
        # multi-scale neighborhood features.
        l_xyz, l_features = [xyz], [features]
        for i in range(len(self.SA_modules)):
            li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
            l_xyz.append(li_xyz)
            l_features.append(li_features)

        # Decoder: walk the FP levels from coarsest to finest, propagating
        # features back onto the denser point sets via the skip connections.
        for i in range(-1, -(len(self.FP_modules) + 1), -1):
            l_features[i - 1] = self.FP_modules[i](
                l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i]
            )

        # Per-point class scores, normalized to log-probabilities.
        pred_cls = self.cls_layer(l_features[0]).transpose(1, 2).contiguous()  # (B, N, k)
        pred_cls = F.log_softmax(pred_cls.view(-1, self.k), dim=-1)
        pred_cls = pred_cls.view(batchsize, n_pts, self.k)
        return pred_cls
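

# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch, not part of the original module. It assumes
# the pt2_modules CUDA ops are compiled, a GPU is available, and utils.Conv1d
# behaves like a standard Conv1d-BN-ReLU wrapper; the shapes below are only
# illustrative (N must be >= NPOINTS[0]).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = Pointnet2MSG(input_channels=3, k=4).cuda().eval()
    cloud = torch.randn(2, 16384, 6).cuda()  # (B, N, 3 xyz + 3 features)
    with torch.no_grad():
        log_probs = model(cloud)
    print(log_probs.shape)  # expected: torch.Size([2, 16384, 4])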