|
- import os
- from tqdm import tqdm
- from tqdm import trange
- # import torch
- # from torch.nn import functional as F
- # from torch import distributions as dist
- import tensorflow as tf
- import numpy
-
-
class Trainer():
    ''' Trainer object for the Occupancy Network (TensorFlow implementation).

    Args:
        model: model exposing ``encode_inputs(points)``, ``decode(c)`` and
            ``trainable_variables`` (presumably a ``tf.keras.Model`` --
            confirm against the caller)
        optimizer: optimizer providing ``apply_gradients`` (e.g. a
            ``tf.keras.optimizers.Optimizer``)
        device: stored on the instance; not used by this class
        input_type (str): input type (stored only)
        vis_dir (str): visualization directory; created if missing
        threshold (float): threshold value (stored only)
        eval_sample (bool): whether to evaluate samples (stored only)
    '''

    def __init__(self, model, optimizer, device=None, input_type='img',
                 vis_dir=None, threshold=0.5, eval_sample=False):
        self.model = model
        self.optimizer = optimizer
        self.device = device
        self.input_type = input_type
        self.vis_dir = vis_dir
        self.threshold = threshold
        self.eval_sample = eval_sample

        if vis_dir is not None:
            # exist_ok avoids the race between an exists() check and makedirs().
            os.makedirs(vis_dir, exist_ok=True)

    def train_step(self, points, output):
        ''' Performs one optimization step on a batch.

        Args:
            points: model input batch, forwarded to ``compute_loss``
            output: target values for ``points``

        Returns:
            The loss tensor computed before the parameter update.
        '''
        with tf.GradientTape() as tape:
            loss = self.compute_loss(points, output)
        variables = self.model.trainable_variables
        gradients = tape.gradient(loss, variables)
        self.optimizer.apply_gradients(zip(gradients, variables))
        return loss

    def evaluate(self, val_loader):
        ''' Averages the per-batch evaluation loss over a validation set.

        Args:
            val_loader: iterable yielding ``(points, output)`` pairs

        Returns:
            Mean of ``eval_step`` over all batches.

        Raises:
            ValueError: if ``val_loader`` yields no batches.
        '''
        total = 0.0
        num_batches = 0
        for points, output in tqdm(val_loader):
            total = total + self.eval_step(points, output)
            num_batches += 1

        if num_batches == 0:
            # Previously this surfaced as an opaque ZeroDivisionError.
            raise ValueError('val_loader yielded no batches')
        return total / num_batches

    def eval_step(self, points, output):
        ''' Computes the loss for one validation batch (no parameter update).

        Args:
            points: model input batch
            output: target values for ``points``

        Returns:
            The scalar loss from ``compute_loss``.
        '''
        # Use the trainer's own loss (not ``self.model.compute_loss``) so that
        # evaluation reports the same metric that training minimizes.
        return self.compute_loss(points, output)

    def compute_loss(self, points, output):
        ''' Mean squared error between the decoded prediction and the target.

        Both the prediction and ``output`` are flattened to 1-D first, so the
        result is a scalar; they presumably hold the same number of elements.

        Args:
            points: model input batch passed to ``model.encode_inputs``
            output: target values

        Returns:
            Scalar MSE tensor.
        '''
        c = self.model.encode_inputs(points)
        prediction = self.model.decode(c)
        # Flatten so MSE (which averages over the last axis) yields a scalar.
        prediction = tf.reshape(prediction, [-1])
        target = tf.reshape(output, [-1])
        # Keras losses take (y_true, y_pred).
        return tf.keras.losses.MSE(target, prediction)
|