|
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- """
- @Author: Yue Wang
- @Contact: yuewangx@mit.edu
- @File: model.py
- @Time: 2018/10/13 6:35 PM
-
- Modified by
- @Author: An Tao
- @Contact: ta19@mails.tsinghua.edu.cn
- @Time: 2020/3/9 9:32 PM
-
- Modified by
- @Author: Dinghao Yang
- @Contact: dinghaoyang@gmail.com
- @Time: 2020/9/28 7:29 PM
- """
-
-
- import os
- import sys
- import copy
- import math
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.init as init
- import torch.nn.functional as F
- from sklearn import manifold
-
-
def knn(x, k):
    """Return the indices of the k nearest neighbours of every point.

    Args:
        x: (batch_size, dims, num_points) point features.
        k: number of neighbours to select.

    Returns:
        LongTensor of shape (batch_size, num_points, k).
    """
    # -|a-b|^2 = 2*a.b - |a|^2 - |b|^2, so the *largest* entries of this
    # matrix correspond to the nearest points, which is what topk picks.
    dot = torch.matmul(x.transpose(2, 1), x)
    sq_norm = torch.sum(x ** 2, dim=1, keepdim=True)
    neg_sq_dist = 2 * dot - sq_norm - sq_norm.transpose(2, 1)
    return neg_sq_dist.topk(k=k, dim=-1)[1]
-
-
def adaptive_sample(x, k):
    """For every point, select its k nearest neighbours measured in a local
    PCA-reduced frame centred on that point.

    Args:
        x: (batch_size, dims, num_points) point features.
        k: number of neighbours to select per point.

    Returns:
        LongTensor of shape (batch_size, num_points, k) of neighbour indices,
        on the same device as the input (the original hard-coded .cuda()).
    """
    feature_dims = x.size(1)
    num_points = x.size(2)
    x = x.permute(0, 2, 1)  # (b, n, d)
    reduced_dim = feature_dims // 2 + 1
    per_point_idx = []
    for i in range(num_points):
        # Centre the cloud on point i, then project onto the leading
        # right-singular vectors before measuring squared distances.
        points_norm = x - x[:, i, :].unsqueeze(1)
        _, _, V = torch.svd(points_norm.float())
        points_pca = points_norm.float() @ V[:, :, :reduced_dim]
        neg_sq_dist = -torch.sum(points_pca ** 2, dim=2, keepdim=True)
        idx = neg_sq_dist.topk(k=k, dim=1)[1].squeeze(2).unsqueeze(1)  # (b, 1, k)
        per_point_idx.append(idx)
    # Single concatenation instead of growing a tensor inside the loop
    # (the original cat-per-iteration is quadratic); debug prints removed.
    return torch.cat(per_point_idx, dim=1)
-
-
def pca_points(x):
    """Project a batch of point clouds onto their leading principal directions.

    Args:
        x: (batch_size, dims, num_points) point features.

    Returns:
        Tensor of shape (batch_size, dims // 2 + 1, num_points): the negated
        PCA projection of the input points.
    """
    half_plus_one = x.size(1) // 2 + 1
    points = x.permute(0, 2, 1)  # (b, n, d)
    # Right-singular vectors give the principal axes of each cloud.
    _, _, v = torch.svd(points.float())
    projected = points.float() @ v[:, :, :half_plus_one]
    return -projected.permute(0, 2, 1)  # back to (b, d', n)
-
-
def get_LLE(x, k, out_dim):
    """Embed one point set with locally linear embedding.

    Args:
        x: (n, d) array-like of points.
        k: number of neighbours used by LLE.
        out_dim: target dimensionality.

    Returns:
        (n, out_dim) embedded coordinates (the reconstruction error is
        discarded).
    """
    embedded, _err = manifold.locally_linear_embedding(
        x, n_neighbors=k, n_components=out_dim)
    return embedded
-
-
def LLE(x, k, out_dim):
    """Apply locally linear embedding to every cloud in a batch.

    Args:
        x: (batch_size, dims, num_points) tensor.
        k: number of neighbours used by LLE.
        out_dim: target dimensionality.

    Returns:
        Float tensor of shape (batch_size, out_dim, num_points) on the same
        device as the input (the original hard-coded .cuda(), breaking CPU
        runs).
    """
    device = x.device
    x = x.permute(0, 2, 1).cpu()  # (b, n, d); sklearn needs host memory
    points_r = torch.tensor([get_LLE(points, k, out_dim) for points in x])
    return points_r.permute(0, 2, 1).to(device).float()  # (b, out_dim, n)
-
-
def get_octant_index(points_norm):
    """Bucket points (already centred on a query point) by octant.

    Args:
        points_norm: iterable of 3-component points relative to the query.

    Returns:
        List of 8 lists; entry o holds the indices of the points that fall
        in octant o. Octants 0-3 have z >= 0 with xy-quadrants
        (+,+), (-,+), (-,-), (+,-); octants 4-7 are the same quadrants
        with z < 0. Boundary points (coordinate == 0) count as positive.
    """
    index_list = [[] for _ in range(8)]
    for i, point in enumerate(points_norm):
        x_pos = point[0] >= 0
        y_pos = point[1] >= 0
        z_pos = point[2] >= 0
        # xy-quadrant: (+,+)->0, (-,+)->1, (-,-)->2, (+,-)->3;
        # add 4 when the point lies below the xy-plane.
        quadrant = (0 if x_pos else 1) if y_pos else (3 if x_pos else 2)
        index_list[(0 if z_pos else 4) + quadrant].append(i)
    return index_list
-
-
def get_octant_nn(points_norm, index_list, k):
    """Pick up to k candidate indices from each octant bucket.

    Buckets with fewer than k candidates are padded with index 0 and empty
    buckets yield k zeros (the original padding behaviour is preserved —
    note that index 0 aliases the first candidate point).

    Args:
        points_norm: (n, d) centred candidate points; only its device is
            used here (the original also computed per-octant distances into
            a local that was never read — that dead code is removed).
        index_list: 8 lists of candidate indices, one per octant.
        k: number of indices taken per octant.

    Returns:
        1-D LongTensor of length 8 * k on points_norm's device (the
        original hard-coded .cuda()).
    """
    device = points_norm.device
    chunks = []
    for octant in index_list:
        count = len(octant)
        if count >= k:
            nn = torch.tensor(octant[:k], dtype=torch.long, device=device)
        elif count > 0:
            # Partially filled bucket: keep what we have, pad with zeros.
            nn = torch.cat((torch.tensor(octant, dtype=torch.long, device=device),
                            torch.zeros(k - count, dtype=torch.long, device=device)), 0)
        else:
            nn = torch.zeros(k, dtype=torch.long, device=device)
        chunks.append(nn)
    if not chunks:  # degenerate input: no buckets at all
        return torch.zeros(0, dtype=torch.long, device=device)
    return torch.cat(chunks, 0)
-
-
def get_octants_knn(x, k, rough_idx):
    """Refine rough kNN candidates so neighbours are balanced across octants.

    Args:
        x: (n, d) points of one batch element.
        k: total neighbour budget; k // 8 neighbours are taken per octant.
        rough_idx: (n, 8*k) candidate neighbour indices from plain kNN.

    Returns:
        LongTensor of shape (n, 8 * (k // 8)) of absolute point indices,
        on the input's device (the original hard-coded .cuda()).
    """
    per_point = []
    for i, point in enumerate(x):
        # Candidates of point i, expressed relative to it.
        rel = x[rough_idx[i]] - point
        octants_index = get_octant_index(rel)
        # Positions *within* the candidate list, k // 8 per octant.
        relative_idx = get_octant_nn(rel, octants_index, int(k / 8))
        # Map candidate positions back to absolute indices via tensor
        # indexing — replaces the original per-element Python list
        # comprehension and tensor re-construction.
        per_point.append(rough_idx[i][relative_idx].unsqueeze(0))
    if not per_point:  # empty cloud: keep the original empty-result shape
        return torch.zeros(0, 8 * int(k / 8), dtype=torch.long, device=rough_idx.device)
    # One concatenation instead of growing a tensor per point (quadratic).
    return torch.cat(per_point, 0)
-
-
def octants_knn(x, k):
    """Octant-balanced k nearest neighbours for a whole batch.

    Args:
        x: (batch_size, dims, num_points) point features.
        k: neighbour budget per point (split as k // 8 per octant).

    Returns:
        LongTensor of shape (batch_size, num_points, 8 * (k // 8)) on the
        input's device (the original hard-coded .cuda()).
    """
    rough_idx = knn(x, k=8 * k)  # (b, n, 8k) candidate pool
    x = x.permute(0, 2, 1)  # (b, n, d)
    per_batch = []
    for i, points in enumerate(x):
        per_batch.append(get_octants_knn(points, k, rough_idx[i]).unsqueeze(0))
    # One concatenation instead of growing a tensor per batch element;
    # per-batch progress prints removed.
    return torch.cat(per_batch, 0)
-
-
- def get_graph_feature(x, k=20, idx=None, dim9=False):
- batch_size = x.size(0)
- num_points = x.size(2)
- x = x.view(batch_size, -1, num_points)
- if idx is None:
- if dim9 == False:
- idx = knn(x, k=k) # (batch_size, num_points, k)
- else:
- idx = knn(x[:, 6:], k=k)
- device = torch.device('cuda')
-
- idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1)*num_points
-
- idx = idx + idx_base
-
- idx = idx.view(-1)
-
- _, num_dims, _ = x.size()
-
- x = x.transpose(2, 1).contiguous() # (batch_size, num_points, num_dims) -> (batch_size*num_points, num_dims) # batch_size * num_points * k + range(0, batch_size*num_points)
- feature = x.view(batch_size*num_points, -1)[idx, :]
- feature = feature.view(batch_size, num_points, k, num_dims)
- x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
-
- feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2).contiguous()
-
- return feature # (batch_size, 2*num_dims, num_points, k)
-
-
class PointManifold_LLE(nn.Module):
    """DGCNN-style point-cloud classifier.

    Four EdgeConv stages (graph feature + pointwise conv + neighbour
    max-pool), a 1-D embedding conv over the concatenated stage outputs,
    global max+avg pooling, and a 3-layer MLP classification head.

    Note: conv1 expects 10 input channels, i.e. 2 * 5 — presumably the
    input clouds carry 5 feature dims (e.g. xyz plus a 2-D LLE embedding);
    confirm against the caller.

    Args:
        args: namespace providing k, emb_dims, dropout.
        output_channels: number of classes.
    """

    def __init__(self, args, output_channels=40):
        super(PointManifold_LLE, self).__init__()
        self.args = args
        self.k = args.k

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm1d(args.emb_dims)

        # EdgeConv stages; each consumes 2x the previous stage's channels
        # because get_graph_feature concatenates (neighbour - point, point).
        self.conv1 = nn.Sequential(nn.Conv2d(10, 64, kernel_size=1, bias=False),
                                   self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
                                   self.bn2,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv2d(64*2, 128, kernel_size=1, bias=False),
                                   self.bn3,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv4 = nn.Sequential(nn.Conv2d(128*2, 256, kernel_size=1, bias=False),
                                   self.bn4,
                                   nn.LeakyReLU(negative_slope=0.2))
        # Embedding over the concatenated stage outputs (64+64+128+256 = 512).
        self.conv5 = nn.Sequential(nn.Conv1d(512, args.emb_dims, kernel_size=1, bias=False),
                                   self.bn5,
                                   nn.LeakyReLU(negative_slope=0.2))
        # Classification head.
        self.linear1 = nn.Linear(args.emb_dims*2, 512, bias=False)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=args.dropout)
        self.linear2 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=args.dropout)
        self.linear3 = nn.Linear(256, output_channels)

    def forward(self, x):
        """x: (batch_size, dims, num_points) -> (batch_size, output_channels)."""
        batch_size = x.size(0)

        # Run the four EdgeConv stages, keeping each stage's pooled output
        # for the later channel-wise concatenation.
        stage_outputs = []
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            edge = get_graph_feature(x, k=self.k)   # (b, 2c, n, k)
            x = conv(edge).max(dim=-1, keepdim=False)[0]  # pool over neighbours
            stage_outputs.append(x)

        x = self.conv5(torch.cat(stage_outputs, dim=1))  # (b, emb_dims, n)

        # Global descriptors: max-pool and avg-pool over points, concatenated.
        pooled_max = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
        pooled_avg = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1)
        x = torch.cat((pooled_max, pooled_avg), 1)  # (b, emb_dims*2)

        # MLP head with batch-norm, leaky ReLU and dropout.
        x = self.dp1(F.leaky_relu(self.bn6(self.linear1(x)), negative_slope=0.2))
        x = self.dp2(F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2))
        return self.linear3(x)  # (b, output_channels) logits
-
-
class PointManifold_NNML(nn.Module):
    """Point-cloud classifier that augments xyz with learned 2-D features
    before a DGCNN-style EdgeConv stack.

    Bug fixes relative to the original code:
      * x_2d_x now goes through conv0_2 (the original reused conv0_1 a
        second time, leaving conv0_2 defined but never used).
      * get_graph_feature is called without the unsupported ``lle_knn``
        keyword, which raised TypeError at runtime.

    Args:
        args: namespace providing k, emb_dims, dropout, hyper_times.
        output_channels: number of classes.
    """

    def __init__(self, args, output_channels=40):
        super(PointManifold_NNML, self).__init__()
        self.args = args
        self.k = args.k

        self.bn0_0 = nn.BatchNorm1d(2)
        self.bn0_1 = nn.BatchNorm1d(2)
        self.bn0_2 = nn.BatchNorm1d(2)
        self.bn1 = nn.BatchNorm2d(128*args.hyper_times)
        self.bn2 = nn.BatchNorm2d(128*args.hyper_times)
        self.bn3 = nn.BatchNorm2d(256*args.hyper_times)
        self.bn4 = nn.BatchNorm2d(512*args.hyper_times)
        self.bn5 = nn.BatchNorm1d(args.emb_dims)

        # Three tiny convs, one per coordinate pair, used to build the
        # learned 2-D features in forward().
        self.conv0_0 = nn.Sequential(nn.Conv1d(2, 2, kernel_size = 1, bias=False),
                                     self.bn0_0,
                                     nn.LeakyReLU(negative_slope=0.2))
        self.conv0_1 = nn.Sequential(nn.Conv1d(2, 2, kernel_size = 1, bias=False),
                                     self.bn0_1,
                                     nn.LeakyReLU(negative_slope=0.2))
        self.conv0_2 = nn.Sequential(nn.Conv1d(2, 2, kernel_size = 1, bias=False),
                                     self.bn0_2,
                                     nn.LeakyReLU(negative_slope=0.2))
        # EdgeConv stack; input is 18 = 2 * 9 channels (xyz + three learned
        # 2-D features), widths scaled by args.hyper_times.
        self.conv1 = nn.Sequential(nn.Conv2d(18, 128*args.hyper_times, kernel_size=1, bias=False),
                                   self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(128*2*args.hyper_times, 128*args.hyper_times, kernel_size=1, bias=False),
                                   self.bn2,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv2d(128*2*args.hyper_times, 256*args.hyper_times, kernel_size=1, bias=False),
                                   self.bn3,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv4 = nn.Sequential(nn.Conv2d(256*2*args.hyper_times, 512*args.hyper_times, kernel_size=1, bias=False),
                                   self.bn4,
                                   nn.LeakyReLU(negative_slope=0.2))
        # Embedding over the concatenated stages (128+128+256+512 = 1024, scaled).
        self.conv5 = nn.Sequential(nn.Conv1d(1024*args.hyper_times, args.emb_dims, kernel_size=1, bias=False),
                                   self.bn5,
                                   nn.LeakyReLU(negative_slope=0.2))
        # Classification head.
        self.linear1 = nn.Linear(args.emb_dims*2, 512, bias=False)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=args.dropout)
        self.linear2 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=args.dropout)
        self.linear3 = nn.Linear(256, output_channels)

    def forward(self, x):
        """x: (batch_size, 3, num_points) -> (batch_size, output_channels) logits."""
        batch_size = x.size(0)  # x (b, 3, n)
        # Learned 2-D features: each coordinate pair through its own small
        # conv, scaled elementwise by the remaining coordinate.
        x_2d_z = self.conv0_0(x[:, :2, :])  # from (x, y), (b, 2, n)
        x_2d_z = torch.mul(x_2d_z, x[:, 2, :].reshape(batch_size, 1, -1))  # weight by z
        x_2d_y = self.conv0_1(x[:, [0, 2], :])  # from (x, z)
        x_2d_y = torch.mul(x_2d_y, x[:, 1, :].reshape(batch_size, 1, -1))  # weight by y
        # Fixed: was self.conv0_1 again (copy-paste), leaving conv0_2 dead.
        x_2d_x = self.conv0_2(x[:, 1:3, :])  # from (y, z)
        x_2d_x = torch.mul(x_2d_x, x[:, 0, :].reshape(batch_size, 1, -1))  # weight by x
        x = torch.cat((x, x_2d_x, x_2d_y, x_2d_z), 1)  # (b, 9, n)
        # Fixed: get_graph_feature takes no `lle_knn` keyword.
        x = get_graph_feature(x, k=self.k)  # (b, 9, n) -> (b, 18, n, k)
        x = self.conv1(x)                   # (b, 18, n, k) -> (b, 128*h, n, k)
        x1 = x.max(dim=-1, keepdim=False)[0]  # pool over neighbours -> (b, 128*h, n)

        x = get_graph_feature(x1, k=self.k)   # -> (b, 128*2*h, n, k)
        x = self.conv2(x)                     # -> (b, 128*h, n, k)
        x2 = x.max(dim=-1, keepdim=False)[0]  # -> (b, 128*h, n)

        x = get_graph_feature(x2, k=self.k)   # -> (b, 128*2*h, n, k)
        x = self.conv3(x)                     # -> (b, 256*h, n, k)
        x3 = x.max(dim=-1, keepdim=False)[0]  # -> (b, 256*h, n)

        x = get_graph_feature(x3, k=self.k)   # -> (b, 256*2*h, n, k)
        x = self.conv4(x)                     # -> (b, 512*h, n, k)
        x4 = x.max(dim=-1, keepdim=False)[0]  # -> (b, 512*h, n)

        x = torch.cat((x1, x2, x3, x4), dim=1)  # (b, (128+128+256+512)*h, n)

        x = self.conv5(x)  # -> (b, emb_dims, n)
        x1 = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)  # (b, emb_dims)
        x2 = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1)  # (b, emb_dims)
        x = torch.cat((x1, x2), 1)  # (b, emb_dims*2)

        x = F.leaky_relu(self.bn6(self.linear1(x)), negative_slope=0.2)  # -> (b, 512)
        x = self.dp1(x)
        x = F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2)  # -> (b, 256)
        x = self.dp2(x)
        x = self.linear3(x)  # -> (b, output_channels)

        return x
|