import itertools
import csv
import fire
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import tokenization
import models
import optim
import train
from utils import set_seeds, get_device, truncate_tokens_pair, my_cross_entropy
from MLMAugment import create_mask_sample
import numpy as np
from ADBlur_RoBERTa import VAE_RoBERTa
from transformers import RobertaTokenizer


def f1_loss(y_true: torch.Tensor, y_pred: torch.Tensor, is_training=False) -> torch.Tensor:
    """Compute the binary F1 score from label tensors.

    `y_pred` may be hard predictions (1-D) or logits/probabilities (2-D);
    in the 2-D case the argmax over classes is taken first.
    """
    assert y_true.ndim == 1
    assert y_pred.ndim == 1 or y_pred.ndim == 2

    if y_pred.ndim == 2:
        y_pred = y_pred.argmax(dim=1)

    tp = (y_true * y_pred).sum().to(torch.float32)
    tn = ((1 - y_true) * (1 - y_pred)).sum().to(torch.float32)
    fp = ((1 - y_true) * y_pred).sum().to(torch.float32)
    fn = (y_true * (1 - y_pred)).sum().to(torch.float32)

    epsilon = 1e-7

    precision = tp / (tp + fp + epsilon)
    recall = tp / (tp + fn + epsilon)

    f1 = 2 * (precision * recall) / (precision + recall + epsilon)
    f1.requires_grad = is_training
    return f1
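
# A minimal usage sketch for f1_loss (the tensors below are illustrative only
# and are not part of the training pipeline):
#
#     y_true = torch.tensor([1, 0, 1, 1])
#     logits = torch.tensor([[0.2, 0.8], [0.9, 0.1], [0.6, 0.4], [0.3, 0.7]])
#     f1_loss(y_true, logits)  # argmax -> [1, 0, 0, 1], F1 ≈ 0.8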


class CsvDataset(Dataset):
    """ Dataset Class for CSV file """
    labels = None

    def __init__(self, file, pipeline=[]):  # csv file and pipeline object
        super().__init__()
        data = []
        with open(file, "r", encoding='utf-8') as f:
            # list of split lines : each line is itself a list of fields
            lines = csv.reader(f, delimiter='\t', quotechar=None)
            for instance in self.get_instances(lines):  # instance : tuple of fields
                for proc in pipeline:  # a bunch of pre-processing
                    instance = proc(instance)
                data.append(instance)

        # To Tensors
        self.tensors = [torch.tensor(x, dtype=torch.long) for x in zip(*data)]

    def __len__(self):
        return self.tensors[0].size(0)

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def get_instances(self, lines):
        """ get instance array from (csv-separated) line list """
        raise NotImplementedError


class MRPC(CsvDataset):
    """ Dataset class for MRPC """
    labels = ("0", "1")  # label names

    def __init__(self, file, pipeline=[]):
        super().__init__(file, pipeline)

    def get_instances(self, lines):
        for line in itertools.islice(lines, 1, None):  # skip header
            yield line[0], line[3], line[4]  # label, text_a, text_b


class SST2(CsvDataset):
    """ Dataset class for SST-2 """
    labels = ("0", "1")  # label names

    def __init__(self, file, pipeline=[]):
        super().__init__(file, pipeline)

    def get_instances(self, lines):
        for line in itertools.islice(lines, 1, None):  # skip header
            yield line[-1], line[0]  # label, text_a

class MNLI(CsvDataset):
    """ Dataset class for MNLI """
    labels = ("contradiction", "entailment", "neutral", "non-entailment")  # label names
    trainable = (0, 1)
    spurious = (0, 1)

    def __init__(self, file, pipeline=[]):
        super().__init__(file, pipeline)

    def get_instances(self, lines):
        print_line = 0
        for line in itertools.islice(lines, 1, None):  # skip header
            yield line[-1], line[-2], line[-3], line[0], line[1]  # trainable, label, spurious, text_a, text_b
            print_line += 1
            if print_line % 10000 == 0:
                print("***********{}***********".format(print_line))

class SNLI(CsvDataset):
    """ Dataset class for SNLI """
    labels = (
        "contradiction", "entailment", "neutral", "0", "1", "SUPPORTS", "REFUTES", "NOT ENOUGH INFO")  # label names

    def __init__(self, file, pipeline=[]):
        super().__init__(file, pipeline)

    def get_instances(self, lines):
        print_line = 0
        for line in itertools.islice(lines, 1, None):  # skip header
            yield line[-1], line[8], line[9]  # label, text_a, text_b
            print_line += 1
            if print_line % 10000 == 0:
                print("***********{}***********".format(print_line))

class HANS(CsvDataset):
    """ Dataset class for HANS """
    labels = ("contradiction", "entailment", "neutral")  # label names

    def __init__(self, file, pipeline=[]):
        super().__init__(file, pipeline)

    def get_instances(self, lines):
        print_line = 0
        for line in itertools.islice(lines, 1, None):  # skip header
            if line[0] == 'entailment':
                yield 'entailment', line[5], line[6]  # label, text_a, text_b
                print_line += 1
                if print_line % 10000 == 0:
                    print("***********{}***********".format(print_line))

class FEVER(CsvDataset):
    """ Dataset class for FEVER """
    labels = ("SUPPORTS", "REFUTES", "NOT ENOUGH INFO")  # label names

    def __init__(self, file, pipeline=[]):
        super().__init__(file, pipeline)

    def get_instances(self, lines):
        print_line = 0
        for line in itertools.islice(lines, 1, None):  # skip header
            yield line[-1], line[0], line[1]  # label, text_a, text_b
            print_line += 1
            if print_line % 10000 == 0:
                print("***********{}***********".format(print_line))

class FEVERBiase(CsvDataset):
    """ Dataset class for FEVER (bias split, claim only) """
    labels = ("SUPPORTS", "REFUTES", "NOT ENOUGH INFO")  # label names

    def __init__(self, file, pipeline=[]):
        super().__init__(file, pipeline)

    def get_instances(self, lines):
        print_line = 0
        for line in itertools.islice(lines, 1, None):  # skip header
            yield line[-1], line[0]  # label, text_a
            print_line += 1
            if print_line % 10000 == 0:
                print("***********{}***********".format(print_line))

class FEVERSYS(CsvDataset):
    """ Dataset class for FEVER Symmetric """
    labels = ("SUPPORTS", "REFUTES", "NOT ENOUGH INFO")  # label names

    def __init__(self, file, pipeline=[]):
        super().__init__(file, pipeline)

    def get_instances(self, lines):
        print_line = 0
        for line in itertools.islice(lines, 1, None):  # skip header
            yield line[-1], line[0], line[1]  # label, text_a, text_b
            print_line += 1
            if print_line % 10000 == 0:
                print("***********{}***********".format(print_line))

class QQP(CsvDataset):
    """ Dataset class for QQP """
    labels = ("0", "1")  # label names
    trainable = (0, 1)
    spurious = (0, 1)

    def __init__(self, file, pipeline=[]):
        super().__init__(file, pipeline)

    def get_instances(self, lines):
        print_line = 0
        for line in itertools.islice(lines, 1, None):  # skip header
            yield line[-1], line[-2], line[-3], line[0], line[1]  # trainable, label, spurious, text_a, text_b
            print_line += 1
            if print_line % 10000 == 0:
                print('***********{}************'.format(print_line))

class QQPPAWS(CsvDataset):
    """ Dataset class for QQP-PAWS """
    labels = (
        "contradiction", "entailment", "neutral", "0", "1", "SUPPORTS", "REFUTES", "NOT ENOUGH INFO")  # label names

    def __init__(self, file, pipeline=[]):
        super().__init__(file, pipeline)

    def get_instances(self, lines):
        print_line = 0
        for line in itertools.islice(lines, 1, None):  # skip header
            if len(line) == 3:
                yield line[-1], line[1], line[2]  # label, text_a, text_b
                print_line += 1
                if print_line % 10000 == 0:
                    print('***********{}************'.format(print_line))

class IMDB(CsvDataset):
    """ Dataset class for IMDB """
    labels = ("0", "1")  # label names

    def __init__(self, file, pipeline=[]):
        super().__init__(file, pipeline)

    def get_instances(self, lines):
        for line in itertools.islice(lines, 1, None):  # skip header
            yield line[0], line[1]  # label, text_a

def dataset_class(task):
    """ Mapping from task string to Dataset Class """
    table = {'mrpc': MRPC, 'mnli': MNLI, 'snli': SNLI, 'imdb': IMDB, 'hans': HANS, 'qqp': QQP, 'qqppaws': QQPPAWS,
             'fever': FEVER, 'feverbias': FEVERBiase,
             'feversym': FEVERSYS,
             'sst2': SST2}
    return table[task]
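
# Hedged usage sketch: `my_pipeline` below stands for the Tokenizing ->
# AddSpecialTokensWithTruncation -> TokenIndexing chain built in main();
# the path and batch size are illustrative.
#
#     TaskDataset = dataset_class('mnli')                       # -> MNLI
#     dataset = TaskDataset('glue_data/MNLI/train_aug.tsv', my_pipeline)
#     loader = DataLoader(dataset, batch_size=32, shuffle=True)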


class Pipeline():
    """ Preprocess Pipeline Class : callable """

    def __init__(self):
        super().__init__()

    def __call__(self, instance):
        raise NotImplementedError


class Tokenizing(Pipeline):
    """ Tokenizing sentence pair """

    def __init__(self, tokenize):
        super().__init__()
        self.tokenize = tokenize  # tokenize function

    def __call__(self, instance):
        if len(instance) > 4:
            trainable, label, spurious, text_a, text_b = instance
            tokens_a = self.tokenize(text_a)
            tokens_b = self.tokenize(text_b) if text_b else []
            return (trainable, label, spurious, tokens_a, tokens_b)
        else:
            label, text = instance
            tokens = self.tokenize(text)
            # Note: a masked copy is produced here, but only the original tokens
            # are returned; the masked sequence is currently unused downstream.
            masked_tokens, _, _ = create_mask_sample(sequence=tokens, mask_prob=0.2, vocab_words=tokens)
            return (label, tokens)

class AddSpecialTokensWithTruncation(Pipeline):
    """ Add special tokens (<s>/</s> for sentence pairs, [CLS] for single sentences) with truncation """

    def __init__(self, max_len=512):
        super().__init__()
        self.max_len = max_len

    def __call__(self, instance):
        if len(instance) > 4:
            trainable, label, spurious, tokens_a, tokens_b = instance

            # -3 special tokens for <s> text_a </s> text_b </s>
            # -2 special tokens for <s> text_a </s>
            _max_len = self.max_len - 3 if tokens_b else self.max_len - 2
            truncate_tokens_pair(tokens_a, tokens_b, _max_len)

            # Add Special Tokens
            tokens_a = ['<s>'] + tokens_a + ['</s>']
            tokens_b = tokens_b + ['</s>'] if tokens_b else []

            return (trainable, label, spurious, tokens_a, tokens_b)
        else:
            label, tokens = instance

            # reserve 2 positions for special tokens: [CLS] text_a
            _max_len = self.max_len - 2
            while len(tokens) > _max_len:
                tokens.pop()

            # Add Special Tokens
            tokens = ['[CLS]'] + tokens

            return (label, tokens)

class TokenIndexing(Pipeline):
    """ Convert tokens into token indexes and do zero-padding """

    def __init__(self, indexer, labels, trainable, spurious, max_len=256):
        super().__init__()
        self.indexer = indexer  # function : tokens to indexes
        # map from a label name to a label index
        self.label_map = {name: i for i, name in enumerate(labels)}
        self.trainable_map = {str(name): i for i, name in enumerate(trainable)}
        self.spurious_map = {str(name): i for i, name in enumerate(spurious)}
        self.max_len = max_len

    def __call__(self, instance):
        if len(instance) > 4:
            trainable, label, spurious, tokens_a, tokens_b = instance
            input_ids = self.indexer(tokens_a + tokens_b)
            segment_ids = [0] * len(tokens_a) + [1] * len(tokens_b)  # token type ids
            input_mask = [1] * (len(tokens_a) + len(tokens_b))

            label_id = self.label_map[label]
            trainable_id = self.trainable_map[trainable]

            # zero padding
            n_pad = self.max_len - len(input_ids)
            input_ids.extend([0] * n_pad)
            segment_ids.extend([0] * n_pad)
            input_mask.extend([0] * n_pad)

            spurious_id = int(label_id * 2 + int(spurious))

            return (input_ids, segment_ids, input_mask, spurious_id, trainable_id, label_id)
        else:
            label, tokens = instance
            input_ids = self.indexer(tokens)
            segment_ids = [0] * len(tokens)  # token type ids
            input_mask = [1] * len(tokens)
            label_id = self.label_map[label]

            # zero padding
            n_pad = self.max_len - len(input_ids)
            input_ids.extend([0] * n_pad)
            segment_ids.extend([0] * n_pad)
            input_mask.extend([0] * n_pad)

            return (input_ids, segment_ids, label_id)
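
# A hedged sketch of how one sentence-pair instance flows through the pipeline
# above (field values are illustrative; the real fields come from the TSV columns):
#
#     ('0', 'entailment', '1', 'A man is eating.', 'A person eats.')
#       -> Tokenizing:                     (trainable, label, spurious, tokens_a, tokens_b)
#       -> AddSpecialTokensWithTruncation: ['<s>'] + tokens_a + ['</s>'], tokens_b + ['</s>']
#       -> TokenIndexing:                  (input_ids, segment_ids, input_mask,
#                                           spurious_id, trainable_id, label_id),
#                                          with id lists zero-padded to max_len
#
# Single-sentence datasets (e.g. SST2, IMDB) take the two-field branch and end
# up as (input_ids, segment_ids, label_id).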


# pretrain_file='../uncased_L-12_H-768_A-12/bert_model.ckpt',
# pretrain_file='../exp/bert/pretrain_100k/model_epoch_3_steps_9732.pt',

def main(task='mnli',
         train_cfg='config/train_mrpc.json',
         model_cfg='config/bert_large.json',
         train_file='glue_data/MNLI/train_aug.tsv',
         dev_file='glue_data/MNLI/hans_train_aug#64.tsv',
         iid_dev_file='glue_data/MNLI/dev_matched_aug.tsv',
         ood_ent_file='glue_data/MNLI/hans_eval_aug.tsv',
         ood_nent_file='glue_data/MNLI/hans_nen_eval_aug.tsv',
         model_file='save_model/model_steps_12272.pt',
         pretrain_file='E:/Neurips2022writing/GAMLMBert/roberta-large',
         data_parallel=True,
         vocab='uncased_L-12_H-768_A-12/vocab.txt',
         save_dir='save_model/roberta-base/mnli/few-shot#64',
         max_len=128,
         mode='eval'):
    cfg = train.Config.from_json(train_cfg)
    model_cfg = models.Config.from_json(model_cfg)

    set_seeds(cfg.seed)

    # tokenizer = tokenization.FullTokenizer(vocab_file=vocab, do_lower_case=False)
    tokenizer = RobertaTokenizer('E:/Neurips2022writing/GAMLMBert/roberta-large/vocab.json',
                                 'E:/Neurips2022writing/GAMLMBert/roberta-large/merges.txt')
    # mask_id = tokenizer.mask_token_id
    TaskDataset = dataset_class(task)  # task dataset class according to the task
    pipeline = [Tokenizing(tokenizer.tokenize),
                AddSpecialTokensWithTruncation(max_len),
                TokenIndexing(tokenizer.convert_tokens_to_ids, TaskDataset.labels,
                              TaskDataset.trainable, TaskDataset.spurious, max_len)]
    # train_dataset = TaskDataset(train_file, pipeline)
    # train_iter = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True)
    train_iter = None
    # dev_dataset = TaskDataset(dev_file, pipeline)
    # dev_iter = DataLoader(dev_dataset, batch_size=cfg.batch_size, shuffle=True)
    dev_iter = None
    iid_dev_dataset = TaskDataset(iid_dev_file, pipeline)
    iid_dev_iter = DataLoader(iid_dev_dataset, batch_size=cfg.batch_size, shuffle=False)
    ood_ent_dataset = TaskDataset(ood_ent_file, pipeline)
    ood_ent_iter = DataLoader(ood_ent_dataset, batch_size=cfg.batch_size, shuffle=False)
    ood_nent_dataset = TaskDataset(ood_nent_file, pipeline)
    ood_nent_iter = DataLoader(ood_nent_dataset, batch_size=cfg.batch_size, shuffle=False)
    ood_dataset = ood_ent_dataset + ood_nent_dataset
    ood_iter = DataLoader(ood_dataset, batch_size=cfg.batch_size, shuffle=False)

    # model = Classifier(model_cfg, len(TaskDataset.labels))
    # model = SiameseNetwork(model_cfg, len(TaskDataset.labels))
    model = VAE_RoBERTa(model_cfg, cfg.vae_hidden, cfg.vae_z_dim, model_cfg.vocab_size, len(TaskDataset.labels))
    criterion = nn.CrossEntropyLoss()

    optimizer_D = optim.optim4GPU(cfg, model, discriminator=True)
    optimizer_G = optim.optim4GPU(cfg, model, discriminator=False)

    trainer = train.Trainer(cfg, model, tokenizer, train_iter, dev_iter, iid_dev_iter,
                            ood_ent_iter, ood_nent_iter, optimizer_D, optimizer_G,
                            save_dir, get_device())

    if mode == 'train':
        def get_loss(model, batch, global_step, tokenizer):  # make sure loss is a scalar tensor
            input_ids, segment_ids, input_mask, spurious_id, trainable_id, label_id = batch
            task_logits, confounder_logits, OOD_logits, _ = model(input_ids, segment_ids, input_mask)
            predicted_loss = my_cross_entropy(task_logits, label_id)
            spurious_loss = my_cross_entropy(confounder_logits, spurious_id)
            return predicted_loss, spurious_loss, OOD_logits, trainable_id

        def evaluate(model, batch):
            input_ids, segment_ids, input_mask, spurious_id, trainable_id, label_id = batch
            task_logits, confounder_logits, OOD_logits, _ = model(input_ids, segment_ids, input_mask)
            _, label_pred = task_logits.max(1)
            result = (label_pred == label_id).float()  # .cpu().numpy()
            accuracy = result.mean()
            f1 = f1_loss(label_id, label_pred)
            return accuracy, result, f1

        trainer.train(get_loss, evaluate, model_file, pretrain_file, data_parallel)

    elif mode == 'eval':
        def evaluate(model, batch):
            input_ids, segment_ids, input_mask, spurious_id, trainable_id, label_id = batch
            task_logits, confounder_logits, OOD_logits, hidden = model(input_ids, segment_ids, input_mask)
            _, label_pred = task_logits.max(1)
            result = (label_pred == label_id).float()  # .cpu().numpy()
            accuracy = result.mean()
            f1 = f1_loss(label_id, label_pred)
            return accuracy, result, f1, hidden

        results, f1s = trainer.eval(evaluate, model_file, iid_dev_iter, data_parallel)
        total_accuracy = torch.cat(results).mean().item()
        averaged_f1 = np.mean(f1s)
        print('Accuracy:', total_accuracy, 'Averaged F1:', averaged_f1)


if __name__ == '__main__':
    fire.Fire(main)
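
# Hedged invocation sketch via python-fire (the module filename and paths are
# assumptions; adjust them to your checkout):
#
#     python classify.py --task mnli --mode eval \
#         --model_file save_model/model_steps_12272.pt \
#         --max_len 128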