|
- import argparse
- import os
- import random
- import torch
- import torch.optim as optim
- import torch.nn.functional as F
- from torch.utils.data import DataLoader
- import sys
- sys.path.append("..")
- from data.dataset_shapenet import ShapeNetDataset
- import pointnet2.pt2_model as ptn2
- from tqdm import tqdm
- import numpy as np
-
# Pin this process to physical GPU 1; must run before any CUDA context is created.
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
-
- # class seg_loss(nn.Module):
- # def __init__(self):
- # super(seg_loss, self).__init__()
- # self.loss = nn.CrossEntropyLoss()
- # def forward(self, pred, label):
- # loss = self.loss(pred, label)
- # return loss
-
def train(seg_model, opt, dataset, dataloader):
    """Train a part-segmentation model with Adam + step LR decay.

    Args:
        seg_model: segmentation network; forward returns per-point log-probs
            of shape (B, N, nclasses) (NLL loss is applied directly, so the
            model is expected to emit log-softmax outputs).
        opt: parsed CLI options (batchsize, epochs, nclasses, ckp).
        dataset: training dataset (used only for batch-count reporting).
        dataloader: iterable of (points, target) batches.

    Side effects: moves the model to GPU, prints per-batch stats, and saves a
    checkpoint to opt.ckp every 10 epochs.
    """
    optimizer = optim.Adam(seg_model.parameters(), lr=0.001, betas=(0.9, 0.999))
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
    seg_model.cuda()

    # Integer batch count for the progress printout (was a float via '/').
    num_batch = len(dataset) // opt.batchsize

    for epoch in tqdm(range(opt.epochs)):
        seg_model.train()  # hoisted: no need to re-set mode every batch
        for i, data in enumerate(dataloader, 0):
            points, target = data
            points, target = points.cuda(), target.cuda()
            optimizer.zero_grad()

            pred = seg_model(points)
            pred = pred.view(-1, opt.nclasses)
            # Labels are stored 1-based; shift to 0-based class indices.
            target = target.view(-1, 1)[:, 0] - 1
            loss = F.nll_loss(pred, target)

            loss.backward()
            optimizer.step()

            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.data).cpu().sum()
            # Denominator is the actual number of points in this batch
            # (previously hardcoded as batchsize * 2500, which was wrong for
            # the last, possibly smaller, batch).
            print('[%d: %d/%d] train loss: %f accuracy: %f' % (
                epoch, i, num_batch, loss.item(),
                correct.item() / float(target.numel())))

        # Step the LR schedule AFTER the epoch's optimizer steps, per the
        # PyTorch >= 1.1 ordering contract (stepping first skips the
        # initial learning rate).
        scheduler.step()

        if epoch % 10 == 0:
            torch.save(seg_model.state_dict(), '%s/seg_model_%d.pth' % (opt.ckp, epoch))
-
-
## Testing...
## benchmark mIOU
def test(seg_model, opt, test_dataloader, checkpoint_path):
    """Evaluate part-segmentation mIoU for one category and log the result.

    Args:
        seg_model: segmentation network; forward returns per-point scores of
            shape (B, N, C) -- argmax is taken over dim 2.
        opt: parsed CLI options (ckp, cat_choice).
        test_dataloader: iterable of (points, target) batches.
        checkpoint_path: checkpoint filename inside opt.ckp to load.

    NOTE(review): relies on module-level globals `logfile` (open file handle,
    closed here) and `num_classes` (per-category part count) defined in the
    __main__ block -- confirm before reusing this function elsewhere.
    """
    seg_model.cuda()
    seg_model.load_state_dict(torch.load(os.path.join(opt.ckp, checkpoint_path)))
    seg_model.eval()  # fixed: result was previously assigned to a typo'd name

    shape_ious = []
    for i, data in tqdm(enumerate(test_dataloader, 0)):
        points, target = data
        points, target = points.cuda(), target.cuda()

        pred = seg_model(points)
        pred_choice = pred.data.max(2)[1]

        pred_np = pred_choice.cpu().data.numpy()
        # Labels are stored 1-based; shift to 0-based part indices.
        target_np = target.cpu().data.numpy() - 1

        for shape_idx in range(target_np.shape[0]):
            # Iterate over ALL parts of the category (not just those present
            # in this shape); absent parts with U == 0 count as IoU 1.
            # NOTE(review): assumes predicted labels share the same 0-based
            # part index space as target_np -- confirm against the model's
            # output head.
            parts = range(num_classes)
            part_ious = []
            for part in parts:
                I = np.sum(np.logical_and(pred_np[shape_idx] == part, target_np[shape_idx] == part))
                U = np.sum(np.logical_or(pred_np[shape_idx] == part, target_np[shape_idx] == part))
                if U == 0:
                    iou = 1
                else:
                    iou = I / float(U)
                part_ious.append(iou)
            shape_ious.append(np.mean(part_ious))

    print("mIOU for class {}: {}\n".format(opt.cat_choice, np.mean(shape_ious)))
    logfile.write("mIOU for class {}: {}\n".format(opt.cat_choice, np.mean(shape_ious)))
    logfile.close()
-
-
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--batchsize', type=int, default=16, help='input batch size')
    parser.add_argument('--epochs', type=int, default=101, help='number of epochs to train for')
    parser.add_argument('--ckp', type=str, default='checkpoints', help='output folder')
    parser.add_argument('--model', type=str, default='', help='model path')
    parser.add_argument('--dataset', type=str, default='../data/', required=False, help="dataset path")
    parser.add_argument('--cat_choice', type=str, default='Chair', help="class_choice")
    parser.add_argument('--nclasses', type=int, default=50, help='Number of classes')
    parser.add_argument('--mode', type=str, default='train', help='train or test')

    # Shared log file, also used (and closed) by test(); appended across runs.
    logfile = open('./log.txt', 'a')
    opt = parser.parse_args()
    print(opt)

    if not os.path.exists(opt.ckp):
        os.makedirs(opt.ckp)

    opt.manualSeed = random.randint(1, 10000)  # fix seed
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    dataset = ShapeNetDataset(data_root=opt.dataset, split='train')
    dataloader = DataLoader(dataset, batch_size=opt.batchsize, shuffle=True, num_workers=4)

    test_dataset = ShapeNetDataset(data_root=opt.dataset, cat_choice=opt.cat_choice, split='test')
    test_dataloader = DataLoader(test_dataset, batch_size=opt.batchsize // 2, shuffle=False, num_workers=4)

    print('-----------------------------')
    print(len(dataset), len(test_dataset))
    print('train on all category data')
    print('test on a specified category: %s' % opt.cat_choice)

    # Fall back to the global class count so num_classes is always defined
    # (previously a NameError if cat_choice was None).
    num_classes = opt.nclasses
    if opt.cat_choice is not None:
        num_classes = test_dataset.seg_classes[opt.cat_choice]

    print('test_classes', num_classes)
    print('-----------------------------')

    seg_model = ptn2.Pointnet2MSG(k=opt.nclasses)

    if opt.mode == 'train':
        train(seg_model, opt, dataset, dataloader)

    if opt.mode == 'test':
        checkpoint_path = 'seg_model_90.pth'
        test(seg_model, opt, test_dataloader, checkpoint_path)
|