|
- import os
- os.environ["CUDA_VISIBLE_DEVICES"] = "4"
- import sys
- import torch
- import numpy as np
-
- import datetime
- import logging
- import provider
- import importlib
- import shutil
- import argparse
-
- from pathlib import Path
- from tqdm import tqdm
- from data_utils.CloudPointsDataLoader import CloudPointsDataLoader
- import torch.nn.functional as F
-
# Make this script's directory and the local model package importable
# regardless of the current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = BASE_DIR
sys.path.append(os.path.join(ROOT_DIR, 'models_yizx'))

# Single global compute device. CUDA_VISIBLE_DEVICES (set at import time near
# the top of the file) restricts which physical GPU is visible here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(torch.cuda.device_count())
# exit()
-
def parse_args(argv=None):
    """Build and parse the command-line arguments for training.

    Args:
        argv: optional list of argument strings. Defaults to ``sys.argv[1:]``
              when ``None`` (unchanged CLI behavior); passing an explicit list
              makes the parser unit-testable.

    Returns:
        argparse.Namespace holding all training hyper-parameters.
    """

    def str2bool(value):
        """Strictly parse a boolean flag value (safe replacement for type=eval)."""
        if isinstance(value, bool):
            return value
        lowered = value.lower()
        if lowered in ('true', '1', 'yes'):
            return True
        if lowered in ('false', '0', 'no'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean, got %r' % value)

    parser = argparse.ArgumentParser('training')
    parser.add_argument('--use_cpu', action='store_true', default=False, help='use cpu mode')
    parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
    parser.add_argument('--batch_size', type=int, default=8, help='batch size in training')
    parser.add_argument('--model', default='pointclouds_diffusion_model', help='model name [default: pointnet_cls]')
    parser.add_argument('--num_category', default=17, type=int, choices=[2, 7, 17], help='training on ModelNet10/40')
    parser.add_argument('--epoch', default=200, type=int, help='number of epoch in training')
    parser.add_argument('--learning_rate', default=2e-4, type=float, help='learning rate in training')  # 0.001
    parser.add_argument('--num_point', type=int, default=8192, help='Point Number')
    parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer for training')
    parser.add_argument('--log_dir', type=str, default='outputs', help='experiment root')
    parser.add_argument('--decay_rate', type=float, default=1e-4, help='decay rate')
    parser.add_argument('--use_normals', action='store_true', default=False, help='use normals')
    # NOTE(review): store_true with default=True means this flag can never be
    # switched off from the CLI — kept as-is for backward compatibility.
    parser.add_argument('--process_data', action='store_true', default=True, help='save data offline')
    parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampiling')

    parser.add_argument('--do_sample', type=int, default=0, help='Whether DoSample using centroid algorithm')
    parser.add_argument('--data_concat', type=int, default=1, help='Whether DoDataConcat/'
                        ' DO:multi-label, no: single label')

    parser.add_argument('--latent_dim', type=int, default=1024, help='default 1024 for encoder_latent')

    # Model Arguments
    parser.add_argument('--num_steps', type=int, default=200)
    parser.add_argument('--beta_1', type=float, default=1e-4)
    parser.add_argument('--beta_T', type=float, default=0.05)
    parser.add_argument('--sched_mode', type=str, default='linear')
    parser.add_argument('--flexibility', type=float, default=0.0)
    # FIX: type=eval executed arbitrary code typed on the command line;
    # str2bool parses 'True'/'False' to the same values safely.
    parser.add_argument('--residual', type=str2bool, default=True, choices=[True, False])
    parser.add_argument('--resume', type=str, default=None)

    # do_sample semantics:
    #   0 -> SampleFront num_points
    #   1 -> CentroidSample num_points
    #   2 -> uniformSample
    return parser.parse_args(argv)
-
def tricky_rule1_for_AVSD(in_tensor):
    """Enforce the AVSD label-dependency rule on a (bs, 17) multi-hot tensor.

    Whenever class index 2 is active for a sample, classes 0 and 1 are forced
    active as well. The tensor is modified in place and also returned for
    convenience.

    Args:
        in_tensor: integer tensor of shape (bs, 17) holding 0/1 labels.

    Returns:
        The same tensor, with the rule applied in place.
    """
    assert in_tensor.shape[-1] == 17
    bs, _ = in_tensor.shape
    for each_bs in range(bs):
        if in_tensor[each_bs, 2] == 1:
            # BUG FIX: the original used '==' here — a no-op comparison whose
            # result was discarded; assignment was clearly intended.
            in_tensor[each_bs, 0:2] = 1
    return in_tensor
-
def inplace_relu(m):
    """Flip any ReLU-family module to in-place mode (use via model.apply)."""
    if 'ReLU' in m.__class__.__name__:
        m.inplace = True
-
-
def test(model, loader, args):
    """Run one evaluation pass over `loader`; return mean per-batch accuracy.

    Two modes, selected by args.num_category:
      * 2  — binary classification; accuracy over samples in the batch.
      * 17 — multi-label classification; per-element accuracy over bs*17 labels.

    NOTE(review): batches are silently skipped if num_category is neither 2
    nor 17 (no else branch), and np.mean([]) would then return NaN.
    NOTE(review): original indentation was lost in this paste; the forward
    pass is assumed to sit outside the use_cpu guard (canonical layout) —
    confirm against the original file.
    """
    # mean_correct = []
    class_acc = []  # np.zeros((num_class, 3))
    classifier = model.eval()

    for item in tqdm(loader):
        # print(item)
        points, region_target, classify_target = item
        # (B, N, 3) -> (B, 3, N); assumes 3 coordinate channels per point.
        points = points.transpose(-1, -2)
        points = points.view(-1, 3, args.num_point).float()

        if args.num_category == 2:
            # Binary task: keep only the first label column, one-hot for the model.
            classify_target = classify_target[:, :, 0]

            classify_target = classify_target.view(-1, 1).long()
            target = F.one_hot(classify_target, num_classes=2)
            target = target.view(-1, 2)

            if not args.use_cpu:
                points, target = points.to(device), target.to(device)
            pred, trans_feat, diffusion_loss = classifier.get_code_with_diffusion_loss(points)

            # Argmax over the 2 logits; compare on CPU against raw targets.
            pred_choice = pred.data.max(1)[1]
            acc_target = classify_target.view(-1)
            pred_choice = pred_choice.cpu()
            correct = pred_choice.eq(acc_target.long().data).sum()
            class_acc.append(correct.item() / float(acc_target.size()[0]))

        elif args.num_category == 17:
            classify_target = classify_target.view(-1, args.num_category).long()
            target = classify_target
            if not args.use_cpu:
                points, target = points.to(device), target.to(device)
            pred, trans_feat, diffusion_loss = classifier.get_code_with_diffusion_loss(points)

            acc_target = classify_target
            acc_target = acc_target.to(device)
            bs, class_num = acc_target.shape

            # Multi-label: threshold sigmoid probabilities at 0.5, then apply
            # the AVSD label-dependency rule before scoring element-wise.
            probs = torch.sigmoid(pred)
            pred_choice = (probs > 0.5).long()
            pred_choice = tricky_rule1_for_AVSD(pred_choice)  ###############
            correct = pred_choice.eq(acc_target.long().data).sum()
            class_acc.append(correct.item() / float(bs * class_num))

    mean_class_acc = np.mean(class_acc)
    return mean_class_acc
-
-
def main(args):
    """Full training driver: set up dirs/logging, load data, train, checkpoint.

    Tracks the best validation accuracy; the corresponding test accuracy and
    model weights are saved whenever a new validation best is reached.

    NOTE(review): `args` is re-parsed inside (see the '''LOG''' section), so
    the value passed into main() is effectively ignored — confirm intended.
    """
    def log_string(str):
        # Mirror every message to both the file logger and stdout.
        # (Shadows builtin `str`; harmless here but worth renaming eventually.)
        logger.info(str)
        print(str)

    # '''HYPER PARAMETER'''
    # os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    '''CREATE DIR'''
    # Experiment layout: ./log/classification/<log_dir or timestamp>/...
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    exp_dir = Path('./log/')
    exp_dir.mkdir(exist_ok=True)
    exp_dir = exp_dir.joinpath('classification')
    exp_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        exp_dir = exp_dir.joinpath(timestr)
    else:
        exp_dir = exp_dir.joinpath(args.log_dir)
    exp_dir.mkdir(exist_ok=True)
    checkpoints_dir = exp_dir.joinpath('diffusionCheckpoints-num_category_{}-SampleMethod_{}/'.format(args.num_category, args.do_sample))
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = exp_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    '''LOG'''
    # NOTE(review): this re-parse discards the `args` parameter passed to main().
    args = parse_args()
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    '''DATA LOADING'''
    log_string('Load dataset ...')

    # data_concat: 1 -> multi-label targets, 0 -> single label (see --data_concat help).
    if args.data_concat == 1:
        data_concat = True
    else:
        data_concat = False
    train_dataset = CloudPointsDataLoader(args, split='train', process_data=args.process_data, data_concat=data_concat)
    val_dataset = CloudPointsDataLoader(args, split='valid', process_data=args.process_data, data_concat=data_concat)
    test_dataset = CloudPointsDataLoader(args, split='test', process_data=args.process_data, data_concat=data_concat)

    # for item in test_dataset:
    #     print(item)
    # exit()

    trainDataLoader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True)
    valDataLoader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2, drop_last=True)
    testDataLoader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=2, drop_last=False)

    '''MODEL LOADING'''
    num_class = args.num_category
    # The model module is resolved by name so --model can swap architectures.
    model = importlib.import_module(args.model)
    from models_yizx.pointclouds_diffusion_loss import get_classify_loss

    classifier = model.pointclouds_diffusion_model(args)
    criterion = get_classify_loss(num_category=args.num_category)
    classifier.apply(inplace_relu)

    if not args.use_cpu:
        classifier = classifier.to(device)
        criterion = criterion.to(device)

    # Resume from a previous checkpoint when available.
    # NOTE(review): loads from 'checkpoints_yizx/' but saves below to
    # `checkpoints_dir` ('diffusionCheckpoints-...'), so resume likely never
    # finds a file; the bare `except` also hides real load errors — confirm.
    try:
        checkpoint = torch.load(str(exp_dir) + '/checkpoints_yizx/best_model.pth')
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            classifier.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        # Any other --optimizer value falls back to SGD with a fixed lr.
        optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)

    # Decay LR by 0.7 every 20 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)
    global_epoch = 0
    global_step = 0
    best_class_acc_OnVal = 0.0
    best_class_acc_Ontest = 0.0

    '''TRANING'''
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
        mean_correct = []
        mean_cls_loss = []
        mean_diffusion_loss = []
        mean_total_loss = []
        classifier = classifier.train()

        # NOTE(review): scheduler.step() before the epoch's optimizer steps —
        # PyTorch >= 1.1 expects optimizer.step() first; confirm intended.
        scheduler.step()
        for batch_id, (points, region_target, classify_target) in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader), smoothing=0.9):
            optimizer.zero_grad()

            # points = points.data.numpy()
            # points = provider.random_point_dropout(points)
            # points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
            # points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            # points = torch.Tensor(points)

            # (B, N, 3) -> (B, 3, N) as expected by the encoder.
            points = points.transpose(-1, -2)
            points = points.view(-1, 3, args.num_point).float()

            # NoActivationFunc on [Pred]

            if args.num_category == 2:
                # Binary task: first label column only, one-hot targets.
                classify_target = classify_target[:, :, 0]

                classify_target = classify_target.view(-1, 1).long()
                target = F.one_hot(classify_target, num_classes=2)
                target = target.view(-1, 2)

                if not args.use_cpu:
                    points, target = points.to(device), target.to(device)
                pred, trans_feat, diffusion_loss = classifier.get_code_with_diffusion_loss(points)
                cls_loss = criterion(pred, target.float(), trans_feat)
                # Equal weighting of diffusion and classification losses here.
                total_loss = diffusion_loss + 1.0 * cls_loss

                mean_cls_loss.append(cls_loss.item())
                mean_diffusion_loss.append(diffusion_loss.item())
                mean_total_loss.append(total_loss.item())

                pred_choice = pred.data.max(1)[1]
                acc_target = classify_target.view(-1)
                pred_choice = pred_choice.cpu()
                correct = pred_choice.eq(acc_target.long().data).sum()
                mean_correct.append(correct.item() / float(acc_target.size()[0]))

                total_loss.backward()
                optimizer.step()
                global_step += 1

            elif args.num_category == 17:
                classify_target = classify_target.view(-1, args.num_category).long()
                target = classify_target

                if not args.use_cpu:
                    points, target = points.to(device), target.to(device)

                pred, trans_feat, diffusion_loss = classifier.get_code_with_diffusion_loss(points)

                # print(pred.shape, target.shape, trans_feat.shape)
                # exit()

                cls_loss = criterion(pred, target.float(), trans_feat)
                # Diffusion loss is down-weighted (0.2) in the multi-label case.
                total_loss = 0.2 * diffusion_loss + cls_loss

                mean_cls_loss.append(cls_loss.item())
                mean_diffusion_loss.append(diffusion_loss.item())
                mean_total_loss.append(total_loss.item())
                ################################################################
                # Training accuracy: threshold sigmoid probs at 0.5, apply the
                # AVSD label rule, count element-wise matches over bs*17 labels.
                acc_target = classify_target
                acc_target = acc_target.to(device)
                bs, class_num = acc_target.shape

                probs = torch.sigmoid(pred)
                pred_choice = (probs > 0.5).long()
                pred_choice = tricky_rule1_for_AVSD(pred_choice)  #######
                correct = pred_choice.eq(acc_target.long().data).sum()
                mean_correct.append(correct.item() / float(bs * class_num))
                ################################################################

                total_loss.backward()
                optimizer.step()
                global_step += 1

        train_instance_acc = np.mean(mean_correct)
        train_epoch_loss = np.mean(mean_total_loss)
        train_diffusion_loss = np.mean(mean_diffusion_loss)
        train_cls_loss = np.mean(mean_cls_loss)
        log_string('Train Loss: %f, [CLS] Loss: %f, [Diffusion] Loss: %f, Instance Accuracy: %f' % (train_epoch_loss, train_cls_loss, train_diffusion_loss, train_instance_acc))

        # Evaluate on validation and test every epoch (no grad).
        with torch.no_grad():
            mean_class_acc = test(classifier.eval(), valDataLoader, args)
            mean_class_acc_test = test(classifier.eval(), testDataLoader, args)

            # if (instance_acc >= best_instance_acc):
            #     best_instance_acc = instance_acc
            #     best_epoch = epoch + 1

            # if (mean_class_acc_test >= best_class_acc_Ontest):
            #     best_class_acc_Ontest = mean_class_acc_test

            # Model selection on validation accuracy; test accuracy recorded
            # at the selected epoch (not independently maximized).
            if (mean_class_acc >= best_class_acc_OnVal):
                best_class_acc_OnVal = mean_class_acc
                best_class_acc_Ontest = mean_class_acc_test
                best_epoch = epoch + 1
            log_string('[Val] Class Accuracy: %f' % (mean_class_acc))
            log_string('[Test] Class Accuracy: %f' % (mean_class_acc_test))
            log_string('----Best Class Accuracy[Val]: %f----' % (best_class_acc_OnVal))
            log_string('----Best Class Accuracy[Test]: %f----\n' % (best_class_acc_Ontest))

            # NOTE(review): after the update above, this condition is true
            # exactly when the previous one was — the two blocks could be merged.
            if (mean_class_acc >= best_class_acc_OnVal):
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    # 'instance_acc': instance_acc,
                    'class_acc': mean_class_acc,
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
            global_epoch += 1

    logger.info('End of training...')
-
-
if __name__ == '__main__':
    # Script entry point: parse CLI options, then run the training driver.
    cli_args = parse_args()
    main(cli_args)
|