import os
import time
import datetime

import torch
import torch.nn as nn
import torch.optim as optim

import datasets
from models.Causformer import Causformer
from models.VanillaTransformer import VanillaTransformer
from utils import *
from sklearn import metrics
from attentionviz import heatmap_attention, bipartite_attention, cross_bipartite_attention

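# Helper for checkpoints saved from a DataParallel-wrapped model: the wrap
# happens *before* load_state_dict, so the saved keys are expected to carry
# the 'module.' prefix that DataParallel adds.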
def load_model(model_dir):
    checkpoint = torch.load(model_dir, map_location=torch.device('cpu'))
    model_opt = checkpoint['options']
    model = Causformer(model_opt)
    model = torch.nn.DataParallel(model)
    model.load_state_dict(checkpoint['model_dict'])
    model = model.cuda()
    print('[Info] Trained model state loaded.')
    return model

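# Variant for checkpoints saved from an unwrapped model: the state dict is
# loaded into the bare Causformer first (keys without the 'module.' prefix),
# and only then is the model wrapped in DataParallel.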
def load_causal_model(model_dir):
    checkpoint = torch.load(model_dir, map_location=torch.device('cpu'))
    model_opt = checkpoint['options']
    model = Causformer(model_opt)
    model.load_state_dict(checkpoint['model_dict'])
    model = torch.nn.DataParallel(model)
    model = model.cuda()
    print('[Info] Trained model state loaded.')
    return model

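# Training driver: builds the dataset and the Causformer model from
# command-line arguments, then runs the train / validate / HANS-evaluation
# loop and checkpoints on dev accuracy.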
class Train:
    def __init__(self):
        print("program execution start: {}".format(datetime.datetime.now()))
        self.args = parse_args()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.logger = get_logger(self.args, "train")
        self.logger.info("Arguments: {}".format(self.args))
        dataset_options = {
            'batch_size': self.args.batch_size,
            'device': self.device,
            'train_csv': '.data/qqp/new_train.tsv',
            'test_csv': '.data/qqp/new_dev.tsv',
            'paws_train_csv': '.data/paws_output/new_train.tsv',
            'paws_test_csv': '.data/paws_output/new_test.tsv',
            'max_length': 256,
        }
        self.dataset = datasets.__dict__[self.args.dataset](dataset_options)

        self.model_options = {
            'out_dim': self.dataset.out_dim(),
            'dp_ratio': self.args.dp_ratio,
            'd_hidden': self.args.d_hidden,
            'device': self.device,
            'dataset': self.args.dataset,
            'pad_idx': self.dataset.pad_idx,
            'vocab_size': self.dataset.vocab_size,
            'nclass': self.dataset.nclass
        }

        self.model = Causformer(self.model_options).to(self.device)
        self.model = torch.nn.DataParallel(self.model)
        # self.model = load_causal_model("results/Causformer_main/snli/best-Causformer_main-snli-params.pt")
        self.criterion = nn.CrossEntropyLoss(reduction='sum')
        self.opt = optim.Adam(self.model.parameters(), lr=self.args.lr)
        self.scheduler = optim.lr_scheduler.StepLR(self.opt, step_size=50, gamma=0.5)
        self.best_val_acc = None
        print("resource preparation done: {}".format(datetime.datetime.now()))

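    # Save a checkpoint whenever validation accuracy improves, and log the
    # epoch summary either way. Checkpoints keep the model options so that
    # load_model / load_causal_model can rebuild the architecture from the
    # file alone.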
    def result_checkpoint(self, epoch, train_loss, val_loss, train_acc, val_acc, took):
        if self.best_val_acc is None or val_acc > self.best_val_acc:
            self.best_val_acc = val_acc
            save_dir = '{}/{}/{}'.format(self.args.results_dir, self.args.model, self.args.dataset)
            os.makedirs(save_dir, exist_ok=True)  # make sure the results directory exists before saving
            torch.save({
                'accuracy': self.best_val_acc,
                'options': self.model_options,
                'model_dict': self.model.state_dict(),
            }, '{}/best-{}-{}-params.pt'.format(save_dir, self.args.model, self.args.dataset))
        self.logger.info('| Epoch {:3d} | train loss {:5.2f} | train acc {:5.2f} | val loss {:5.2f} | val acc {:5.2f} | time: {:5.2f}s |'
                         .format(epoch, train_loss, train_acc, val_loss, val_acc, took))

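    # One epoch over the training split. The model returns both class logits
    # and MBLoss, an auxiliary regularization term from Causformer's forward
    # pass; the two are summed for the backward step. Note that notears_total
    # is only ever initialized, so the "Notears" field in the progress bar
    # always reads 0.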
    def train(self):
        self.model.train()

        self.dataset.train_iter.init_epoch()
        n_correct, n_total, n_loss, mb_total, notears_total = 0, 0, 0, 0, 0
        for batch_idx, batch in enumerate(self.dataset.train_iter):
            self.opt.zero_grad()
            cau_answer, MBLoss, cau_enc_attns_list, cau_dec_slf_attn_list, cau_dec_enc_attn_list = self.model(batch.premise, batch.hypothesis)
            loss = self.criterion(cau_answer, batch.label)

            # combined objective: cross-entropy plus the MBLoss regularizer
            loss_sum = loss.mean() + MBLoss.mean()
            # loss_sum = loss.mean()  # ablation: cross-entropy only
            n_correct += (torch.max(cau_answer, 1)[1].view(batch.label.size()) == batch.label).sum().item()
            n_total += batch.batch_size
            n_loss += loss.mean().item()
            mb_total += MBLoss.mean().item()

            loss_sum.backward()
            self.opt.step()
            lr = self.opt.param_groups[0]['lr']
            progress_bar(batch_idx, len(self.dataset.train_iter),
                         'Lr: %.6f | Loss: %.3f | MBLoss: %.3f | Notears: %.3f | Acc: %.3f%%' % (
                             lr, n_loss / (batch_idx + 1), mb_total / (batch_idx + 1), notears_total / (batch_idx + 1), 100. * n_correct / n_total))
        train_loss = n_loss / (batch_idx + 1)
        train_acc = 100. * n_correct / n_total
        return train_loss, train_acc

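    # Evaluate on the dev split: per-example loss, accuracy, and macro F1.
    # The disabled block below renders per-head attention maps for
    # misclassified examples; note it indexes dataset.examples by batch index,
    # which does not generally correspond to the examples in the current batch.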
    def validate(self, epoch_i):
        self.model.eval()
        self.dataset.dev_iter.init_epoch()
        n_correct, n_total, n_loss, mbloss_total, notears_total, f1 = 0, 0, 0, 0, 0, 0
        true_labels = []
        pred_labels = []
        with torch.no_grad():
            for batch_idx, batch in enumerate(self.dataset.dev_iter):
                premise = self.dataset.dev_iter.dataset.examples[batch_idx].premise
                hypothesis = self.dataset.dev_iter.dataset.examples[batch_idx].hypothesis
                answer, _, enc_slf_attn_list, dec_slf_attn_list, dec_enc_attn_list = self.model(batch.premise, batch.hypothesis)

                loss = self.criterion(answer, batch.label)
                correct = (torch.max(answer, 1)[1].view(batch.label.size()) == batch.label).sum().item()
                """
                if correct == 0:
                    for n_layer in range(len(dec_enc_attn_list)):
                        for nhead in range(dec_enc_attn_list[n_layer].size(1)):
                            cross_bipartite_attention(dec_enc_attn_list[n_layer][0, nhead, :len(premise), :len(premise)].cpu().detach(), premise, hypothesis, batch_idx, n_layer, nhead, "cross")

                    for n_layer in range(len(enc_slf_attn_list)):
                        for nhead in range(enc_slf_attn_list[n_layer].size(1)):
                            bipartite_attention(enc_slf_attn_list[n_layer][0, nhead, :len(premise), :len(premise)].cpu().detach(), premise, batch_idx, n_layer, nhead, "encoder")

                    for n_layer in range(len(dec_slf_attn_list)):
                        for nhead in range(dec_slf_attn_list[n_layer].size(1)):
                            bipartite_attention(dec_slf_attn_list[n_layer][0, nhead, :len(premise), :len(premise)].cpu().detach(), hypothesis, batch_idx, n_layer, nhead, "decoder")
                """
                n_correct += correct
                n_total += batch.batch_size
                n_loss += loss.mean().item()
                pred_labels += list(torch.max(answer, 1)[1].view(batch.label.size()).cpu().numpy())
                true_labels += list(batch.label.cpu().numpy())

        val_loss = n_loss / n_total
        val_mbloss = mbloss_total / n_total  # mbloss_total is never accumulated, so this is always 0
        val_acc = 100. * n_correct / n_total
        assert len(true_labels) == len(pred_labels)
        f1 = metrics.f1_score(true_labels, pred_labels, average='macro')  # macro-averaged F1 over dev predictions
        return val_loss, val_mbloss, val_acc, 100. * f1

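    # HANS-style out-of-distribution checks. The dataset object is assumed to
    # expose hans_train_iter (entailment examples) and hans_dev_iter
    # (non-entailment examples) with claim / evidence text fields; both
    # helpers mirror validate() but skip the attention bookkeeping.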
    def hans_en_validate(self, epoch_i):
        self.model.eval()
        self.dataset.hans_train_iter.init_epoch()
        n_correct, n_total, n_loss, mbloss_total, f1 = 0, 0, 0, 0, 0
        true_labels = []
        pred_labels = []
        with torch.no_grad():
            for batch_idx, batch in enumerate(self.dataset.hans_train_iter):
                answer, _, _, _, _ = self.model(batch.claim, batch.evidence)
                loss = self.criterion(answer, batch.label)
                n_correct += (torch.max(answer, 1)[1].view(batch.label.size()) == batch.label).sum().item()
                n_total += batch.batch_size
                n_loss += loss.mean().item()
                pred_labels += list(torch.max(answer, 1)[1].view(batch.label.size()).cpu().numpy())
                true_labels += list(batch.label.cpu().numpy())

        val_loss = n_loss / n_total
        val_mbloss = mbloss_total / n_total  # always 0; kept for a uniform return signature
        val_acc = 100. * n_correct / n_total
        assert len(true_labels) == len(pred_labels)
        f1 = metrics.f1_score(true_labels, pred_labels, average='macro')
        return val_loss, val_mbloss, val_acc, 100. * f1

    def hans_non_validate(self, epoch_i):
        self.model.eval()
        self.dataset.hans_dev_iter.init_epoch()
        n_correct, n_total, n_loss, mbloss_total, f1 = 0, 0, 0, 0, 0
        true_labels = []
        pred_labels = []
        with torch.no_grad():
            for batch_idx, batch in enumerate(self.dataset.hans_dev_iter):
                answer, _, _, _, _ = self.model(batch.claim, batch.evidence)
                loss = self.criterion(answer, batch.label)
                n_correct += (torch.max(answer, 1)[1].view(batch.label.size()) == batch.label).sum().item()
                n_total += batch.batch_size
                n_loss += loss.mean().item()
                pred_labels += list(torch.max(answer, 1)[1].view(batch.label.size()).cpu().numpy())
                true_labels += list(batch.label.cpu().numpy())

        val_loss = n_loss / n_total
        val_mbloss = mbloss_total / n_total  # always 0; kept for a uniform return signature
        val_acc = 100. * n_correct / n_total
        assert len(true_labels) == len(pred_labels)
        f1 = metrics.f1_score(true_labels, pred_labels, average='macro')
        return val_loss, val_mbloss, val_acc, 100. * f1

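    # Main loop: train one epoch, evaluate on dev and on both HANS splits,
    # step the LR schedule, then checkpoint on dev accuracy. The commented
    # block at the top allowed a HANS evaluation pass before any training.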
    def execute(self):
        print(" [*] Training starts!")
        print('-' * 99)
        for epoch in range(1, self.args.epochs + 1):
            start = time.time()

            # hans_en_loss, hans_en_mbloss, hans_en_acc, hans_en_f1 = self.hans_en_validate(epoch_i=epoch)
            # hans_non_loss, hans_non_mbloss, hans_non_acc, hans_non_f1 = self.hans_non_validate(epoch_i=epoch)
            # took = time.time() - start
            """
            print(
                '| Epoch {:3d} | entailment loss {:5.2f} | entailment acc {:5.2f} | non-entailment loss {:5.2f} | non-entailment mbloss {:5.2f} | non-entailment acc {:5.2f} | non-entailment f1 {:5.2f} | time: {:5.2f}s |' \
                .format(epoch, hans_en_loss, hans_en_acc, hans_non_loss, hans_non_mbloss, hans_non_acc, hans_non_f1, took))
            """
            train_loss, train_acc = self.train()
            val_loss, val_mbloss, val_acc, val_f1 = self.validate(epoch_i=epoch)
            hans_en_loss, hans_en_mbloss, hans_en_acc, hans_en_f1 = self.hans_en_validate(epoch_i=epoch)
            hans_nen_loss, hans_nen_mbloss, hans_nen_acc, hans_nen_f1 = self.hans_non_validate(epoch_i=epoch)
            self.scheduler.step()

            took = time.time() - start
            self.result_checkpoint(epoch, train_loss, val_loss, train_acc, val_acc, took)

            print('| Epoch {:3d} | train loss {:5.2f} | train acc {:5.2f} | val loss {:5.2f} | val mbloss {:5.2f} | val acc {:5.2f} | val f1 {:5.2f} | time: {:5.2f}s |'
                  .format(epoch, train_loss, train_acc, val_loss, val_mbloss, val_acc, val_f1, took))
            print('| Epoch {:3d} | hans en loss {:5.2f} | hans en acc {:5.2f} | hans non loss {:5.2f} | hans non acc {:5.2f} |'
                  .format(epoch, hans_en_loss, hans_en_acc, hans_nen_loss, hans_nen_acc))
        self.finish()

-
- def finish(self):
- self.logger.info("[*] Training finished!\n\n")
- print('-' * 99)
- print(" [*] Training finished!")
- print(" [*] Please find the saved model and training log in results_dir")
-
-
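# Example (hypothetical path): after training, a saved checkpoint can be
# restored for evaluation with
#   model = load_model('results/Causformer/qqp/best-Causformer-qqp-params.pt')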
if __name__ == '__main__':
    task = Train()
    task.execute()