|
- import itertools
- import csv
- import fire
- import torch
- import torch.nn as nn
- from torch.utils.data import Dataset, DataLoader
- import tokenization
- import models
- import optim
- import train
- from utils import set_seeds, get_device, truncate_tokens_pair
- from transformers import BertModel, BertTokenizer
- from SiameseBert import SiameseNetwork
- from MLMAugment import create_mask_sample
- import numpy as np
- from classification import Classifier
- from TripMLMBert import TripMLMBert
- import os
- import jsonlines
-
-
def f1_loss(y_true: torch.Tensor, y_pred: torch.Tensor, is_training=False) -> torch.Tensor:
    """Return the binary F1 score of predictions against gold labels.

    Args:
        y_true: 1-D tensor of gold labels in {0, 1}.
        y_pred: 1-D tensor of predicted labels, or 2-D (N, C) logits/scores
            which are reduced with argmax over dim 1.
        is_training: value assigned to ``f1.requires_grad`` before returning.

    Returns:
        0-dim float32 tensor holding the F1 score.
    """
    assert y_true.ndim == 1
    assert y_pred.ndim in (1, 2)

    if y_pred.ndim == 2:
        y_pred = y_pred.argmax(dim=1)

    # Confusion-matrix counts (true negatives are not needed for F1; the
    # original dead `tn` computation was removed).
    tp = (y_true * y_pred).sum().to(torch.float32)
    fp = ((1 - y_true) * y_pred).sum().to(torch.float32)
    fn = (y_true * (1 - y_pred)).sum().to(torch.float32)

    epsilon = 1e-7  # keeps divisions finite when a class is absent

    precision = tp / (tp + fp + epsilon)
    recall = tp / (tp + fn + epsilon)

    f1 = 2 * (precision * recall) / (precision + recall + epsilon)
    # NOTE(review): argmax is non-differentiable, so flipping requires_grad
    # here cannot make F1 trainable -- confirm this flag is intentional.
    f1.requires_grad = is_training
    return f1
-
-
class CsvDataset(Dataset):
    """Base Dataset for tab-separated CSV/TSV files.

    Subclasses implement ``get_instances`` to turn raw rows into field
    tuples; each tuple is then pushed through the preprocessing pipeline
    and the resulting columns are stored as long tensors.
    """
    labels = None  # subclasses define the tuple of label names

    def __init__(self, file, pipeline=None):  # csv file path and pipeline callables
        super().__init__()
        # None default instead of `pipeline=[]`: a mutable default list is
        # shared across all calls (classic Python pitfall).
        pipeline = pipeline or []
        data = []
        with open(file, "r", encoding='utf-8') as f:
            # Iterator of rows; each row is a list of tab-separated fields.
            lines = csv.reader(f, delimiter='\t', quotechar=None)
            for instance in self.get_instances(lines):  # instance: tuple of fields
                for proc in pipeline:  # a bunch of pre-processing
                    instance = proc(instance)
                data.append(instance)

        # One long tensor per field position, stacked across all instances.
        self.tensors = [torch.tensor(x, dtype=torch.long) for x in zip(*data)]

    def __len__(self):
        return self.tensors[0].size(0)

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def get_instances(self, lines):
        """Yield instance tuples from an iterator of csv rows."""
        raise NotImplementedError
-
-
class JsonDataset(Dataset):
    """Base Dataset for jsonlines (.jsonl) files.

    Mirrors CsvDataset but parses each line as a JSON object via
    ``jsonlines.Reader``.
    """
    labels = None  # subclasses define the tuple of label names

    def __init__(self, file, pipeline=None):  # jsonl file path and pipeline callables
        super().__init__()
        # None default instead of `pipeline=[]`: a mutable default list is
        # shared across all calls (classic Python pitfall).
        pipeline = pipeline or []
        data = []
        # "r" instead of "r+": the file is only read here, so write
        # permission is unnecessary (and fails on read-only data dirs).
        with open(file, "r", encoding='utf-8') as f:
            lines = jsonlines.Reader(f)
            for instance in self.get_instances(lines):  # instance: tuple of fields
                for proc in pipeline:  # a bunch of pre-processing
                    instance = proc(instance)
                data.append(instance)

        # One long tensor per field position, stacked across all instances.
        self.tensors = [torch.tensor(x, dtype=torch.long) for x in zip(*data)]

    def __len__(self):
        return self.tensors[0].size(0)

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def get_instances(self, lines):
        """Yield instance tuples from an iterator of parsed JSON objects."""
        raise NotImplementedError
-
-
class MRPC(CsvDataset):
    """ Dataset class for MRPC (sentence-pair paraphrase task) """
    labels = ("0", "1")  # label names

    def __init__(self, file, pipeline=None):
        # None default avoids the shared mutable-default-list pitfall.
        super().__init__(file, pipeline or [])

    def get_instances(self, lines):
        for line in itertools.islice(lines, 1, None):  # skip header
            yield line[0], line[3], line[4]  # label, text_a, text_b
-
-
class SST2(CsvDataset):
    """ Dataset class for SST-2 (single-sentence sentiment) """
    labels = ("0", "1")  # label names

    def __init__(self, file, pipeline=None):
        # None default avoids the shared mutable-default-list pitfall.
        super().__init__(file, pipeline or [])

    def get_instances(self, lines):
        for line in itertools.islice(lines, 1, None):  # skip header
            yield line[-1], line[0]  # label, text_a (no text_b for SST-2)
-
-
class MNLI(CsvDataset):
    """ Dataset class for MNLI """
    labels = ("contradiction", "entailment", "neutral")  # label names

    def __init__(self, file, pipeline=None):
        # None default avoids the shared mutable-default-list pitfall.
        super().__init__(file, pipeline or [])

    def get_instances(self, lines):
        print_line = 0
        for line in itertools.islice(lines, 1, None):  # skip header
            yield line[-1], line[8], line[9]  # label, text_a, text_b
            print_line += 1
            if print_line % 10000 == 0:  # progress log every 10k rows
                print("***********{}***********".format(print_line))
-
-
class SNLI(JsonDataset):
    """ Dataset class for SNLI (jsonlines format) """
    labels = ("contradiction", "entailment", "neutral")  # label names

    def __init__(self, file, pipeline=None):
        # None default avoids the shared mutable-default-list pitfall.
        super().__init__(file, pipeline or [])

    def get_instances(self, lines):
        print_line = 0
        # NOTE(review): islice(..., 1, None) drops the FIRST record; jsonl
        # files normally have no header row -- confirm this skip is intended.
        for line in itertools.islice(lines, 1, None):
            # '-' means the annotators reached no consensus; such examples
            # are skipped.
            if line['annotator_labels'][0] != '-':
                yield line['annotator_labels'][0], line['sentence1'], line['sentence2']  # label, text_a, text_b
            print_line += 1
            if print_line % 10000 == 0:  # throttled progress log (was printed every row)
                print("***********{}***********".format(print_line))
-
-
class HANS(CsvDataset):
    """ Dataset class for the HANS diagnostic set """
    labels = ("contradiction", "entailment", "neutral")  # label names

    def __init__(self, file, pipeline=None):
        # None default avoids the shared mutable-default-list pitfall.
        super().__init__(file, pipeline or [])

    def get_instances(self, lines):
        print_line = 0
        # NOTE(review): the gold-label column is ignored and every pair is
        # yielded as 'entailment' -- presumably a fixed-label evaluation
        # protocol; confirm against the evaluation code.
        for line in itertools.islice(lines, 1, None):  # skip header
            yield 'entailment', line[5], line[6]  # label, text_a, text_b
            print_line += 1
            if print_line % 10000 == 0:  # throttled progress log (was printed every row)
                print("***********{}***********".format(print_line))
-
-
class FEVER(CsvDataset):
    """ Dataset class for FEVER (claim/evidence pairs) """
    labels = ("SUPPORTS", "REFUTES", "NOT ENOUGH INFO")  # label names

    def __init__(self, file, pipeline=None):
        # None default avoids the shared mutable-default-list pitfall.
        super().__init__(file, pipeline or [])

    def get_instances(self, lines):
        print_line = 0
        for line in itertools.islice(lines, 1, None):  # skip header
            yield line[-1], line[0], line[1]  # label, text_a, text_b
            print_line += 1
            if print_line % 10000 == 0:  # progress log every 10k rows
                print("***********{}***********".format(print_line))
-
-
class FEVERBiase(CsvDataset):
    """ Dataset class for a single-sentence FEVER variant """
    labels = ("SUPPORTS", "REFUTES", "NOT ENOUGH INFO")  # label names

    def __init__(self, file, pipeline=None):
        # None default avoids the shared mutable-default-list pitfall.
        super().__init__(file, pipeline or [])

    def get_instances(self, lines):
        print_line = 0
        for line in itertools.islice(lines, 1, None):  # skip header
            # Only the first text column is used -- presumably a claim-only
            # (bias-probing) setup; confirm against the data file.
            yield line[-1], line[0]  # label, text_a
            print_line += 1
            if print_line % 10000 == 0:  # progress log every 10k rows
                print("***********{}***********".format(print_line))
-
-
class FEVERSYS(CsvDataset):
    """ Dataset class for the FEVER-symmetric evaluation set """
    labels = ("SUPPORTS", "REFUTES", "NOT ENOUGH INFO")  # label names

    def __init__(self, file, pipeline=None):
        # None default avoids the shared mutable-default-list pitfall.
        super().__init__(file, pipeline or [])

    def get_instances(self, lines):
        print_line = 0
        for line in itertools.islice(lines, 1, None):  # skip header
            yield line[-1], line[0], line[1]  # label, text_a, text_b
            print_line += 1
            if print_line % 10000 == 0:  # progress log every 10k rows
                print("***********{}***********".format(print_line))
-
-
class QQP(CsvDataset):
    """ Dataset class for QQP (question-pair duplicates) """
    # NOTE(review): the label set merges labels from several tasks --
    # presumably so one checkpoint's label map covers them all; confirm.
    labels = ("contradiction", "entailment", "neutral", "0", "1", "SUPPORTS", "REFUTES", "NOT ENOUGH INFO")  # label names

    def __init__(self, file, pipeline=None):
        # None default avoids the shared mutable-default-list pitfall.
        super().__init__(file, pipeline or [])

    def get_instances(self, lines):
        print_line = 0
        for line in itertools.islice(lines, 1, None):  # skip header
            if len(line) == 6:  # skip malformed/truncated rows
                yield line[-1], line[3], line[4]  # label, text_a, text_b
                print_line += 1
                if print_line % 10000 == 0:  # throttled progress log (was printed every row)
                    print('***********{}************'.format(print_line))
-
-
class QQPPAWS(CsvDataset):
    """ Dataset class for PAWS-QQP (adversarial paraphrase pairs) """
    # NOTE(review): the label set merges labels from several tasks --
    # presumably so one checkpoint's label map covers them all; confirm.
    labels = ("contradiction", "entailment", "neutral", "0", "1", "SUPPORTS", "REFUTES", "NOT ENOUGH INFO")  # label names

    def __init__(self, file, pipeline=None):
        # None default avoids the shared mutable-default-list pitfall.
        super().__init__(file, pipeline or [])

    def get_instances(self, lines):
        print_line = 0
        for line in itertools.islice(lines, 1, None):  # skip header
            if len(line) == 3:  # skip malformed/truncated rows
                yield line[-1], line[1], line[2]  # label, text_a, text_b
                print_line += 1
                if print_line % 10000 == 0:  # throttled progress log (was printed every row)
                    print('***********{}************'.format(print_line))
-
-
class IMDB(CsvDataset):
    """ Dataset class for IMDB (single-sentence sentiment) """
    labels = ("0", "1")  # label names

    def __init__(self, file, pipeline=None):
        # None default avoids the shared mutable-default-list pitfall.
        super().__init__(file, pipeline or [])

    def get_instances(self, lines):
        for line in itertools.islice(lines, 1, None):  # skip header
            yield line[0], line[1]  # label, text_a (no text_b for IMDB)
-
-
def dataset_class(task):
    """Return the Dataset class registered under *task* (KeyError if unknown)."""
    registry = {
        'mrpc': MRPC,
        'mnli': MNLI,
        'snli': SNLI,
        'imdb': IMDB,
        'hans': HANS,
        'qqp': QQP,
        'qqppaws': QQPPAWS,
        'fever': FEVER,
        'feverbias': FEVERBiase,
        'feversym': FEVERSYS,
        'sst2': SST2,
    }
    return registry[task]
-
-
class Pipeline():
    """Abstract preprocessing step: subclasses implement __call__(instance)."""

    def __init__(self):
        super().__init__()

    def __call__(self, instance):
        # Subclasses must transform and return the instance tuple.
        raise NotImplementedError
-
-
class Tokenizing(Pipeline):
    """Normalize and tokenize the text fields of an instance.

    Handles both (label, text_a, text_b) pair instances and
    (label, text) single-sentence instances.
    """

    def __init__(self, preprocessor, tokenize):
        super().__init__()
        self.preprocessor = preprocessor  # e.g. text normalization
        self.tokenize = tokenize  # tokenize function

    def __call__(self, instance):
        if len(instance) > 2:  # sentence-pair instance
            label, text_a, text_b = instance
            label = self.preprocessor(label)
            tokens_a = self.tokenize(self.preprocessor(text_a))
            tokens_b = self.tokenize(self.preprocessor(text_b)) if text_b else []
            return (label, tokens_a, tokens_b)
        else:  # single-sentence instance
            label, text = instance
            label = self.preprocessor(label)
            tokens = self.tokenize(self.preprocessor(text))
            # A dead create_mask_sample(...) call was removed here: all of
            # its return values were discarded, so it only consumed RNG
            # state without affecting the output.
            return (label, tokens)
-
-
class AddSpecialTokensWithTruncation(Pipeline):
    """Insert [CLS]/[SEP] special tokens, truncating to fit max_len."""

    def __init__(self, max_len=512):
        super().__init__()
        self.max_len = max_len

    def __call__(self, instance):
        if len(instance) > 2:
            label, tokens_a, tokens_b = instance

            # Budget: max_len minus 3 specials ([CLS] a [SEP] b [SEP]) for
            # pairs, or minus 2 ([CLS] a [SEP]) when tokens_b is empty.
            budget = self.max_len - (3 if tokens_b else 2)
            truncate_tokens_pair(tokens_a, tokens_b, budget)

            wrapped_a = ['[CLS]'] + tokens_a + ['[SEP]']
            wrapped_b = tokens_b + ['[SEP]'] if tokens_b else []
            return (label, wrapped_a, wrapped_b)
        else:
            label, tokens = instance

            # Truncate in place from the tail so at most max_len - 2 tokens
            # remain, then prepend [CLS].
            # NOTE(review): two slots are reserved but only [CLS] is added;
            # a trailing [SEP] may be missing here -- confirm the intended
            # single-sentence format.
            del tokens[self.max_len - 2:]
            return (label, ['[CLS]'] + tokens)
-
-
class TokenIndexing(Pipeline):
    """Convert tokens into vocabulary indexes and zero-pad to max_len."""

    def __init__(self, indexer, labels, max_len=256):
        super().__init__()
        self.indexer = indexer  # function : tokens to indexes
        # map from a label name to a label index
        self.label_map = {name: i for i, name in enumerate(labels)}
        self.max_len = max_len

    def __call__(self, instance):
        """Return (input_ids, segment_ids, input_mask, label_id), padded to max_len."""
        if len(instance) > 2:  # sentence-pair instance
            label, tokens_a, tokens_b = instance
            input_ids = self.indexer(tokens_a + tokens_b)
            segment_ids = [0] * len(tokens_a) + [1] * len(tokens_b)  # token type ids
            input_mask = [1] * (len(tokens_a) + len(tokens_b))
        else:  # single-sentence instance
            label, tokens = instance
            input_ids = self.indexer(tokens)
            segment_ids = [0] * len(tokens)  # token type ids
            input_mask = [1] * len(tokens)
        label_id = self.label_map[label]

        # zero padding up to max_len
        n_pad = self.max_len - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)
        input_mask.extend([0] * n_pad)

        # BUG FIX: the single-sentence branch previously computed input_mask
        # but returned only a 3-tuple, breaking the 4-way batch unpacking
        # (input_ids, segment_ids, input_mask, label_id) used downstream.
        return (input_ids, segment_ids, input_mask, label_id)
-
-
- # pretrain_file='../uncased_L-12_H-768_A-12/bert_model.ckpt',
- # pretrain_file='../exp/bert/pretrain_100k/model_epoch_3_steps_9732.pt',
-
def main(task='mnli',
         train_cfg='config/train_mnli.json',
         model_cfg='config/bert_base.json',
         data_file='glue_data/MNLI/train.tsv',
         model_file=None,
         pretrain_file='uncased_L-12_H-768_A-12/bert_model.ckpt',
         data_parallel=True,
         vocab='uncased_L-12_H-768_A-12/vocab.txt',
         save_dir='save_model/mnli/',
         max_len=128,
         mode='train'):
    """Fine-tune or evaluate a TripMLMBert classifier on a GLUE-style task.

    Args:
        task: key understood by ``dataset_class`` (e.g. 'mnli', 'sst2').
        train_cfg: JSON file with training hyper-parameters (seed, batch_size, ...).
        model_cfg: JSON file with the BERT model configuration.
        data_file: TSV/JSONL file read by the task's Dataset class.
        model_file: fine-tuned checkpoint to load (or None).
        pretrain_file: pretrained BERT checkpoint used to initialise training.
        data_parallel: forwarded to the Trainer -- presumably enables
            nn.DataParallel across visible GPUs; confirm in train.Trainer.
        vocab: WordPiece vocabulary file.
        save_dir: directory where checkpoints are written.
        max_len: maximum sequence length including special tokens.
        mode: 'train' or 'eval'; any other value is silently a no-op.
    """
    cfg = train.Config.from_json(train_cfg)
    model_cfg = models.Config.from_json(model_cfg)

    set_seeds(cfg.seed)  # reproducibility

    tokenizer = tokenization.FullTokenizer(vocab_file=vocab, do_lower_case=True)
    mask_id = tokenizer.vocab['[MASK]']  # vocab id of [MASK], passed to the model
    # tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    TaskDataset = dataset_class(task)  # task dataset class according to the task
    # Preprocessing: tokenize -> add [CLS]/[SEP] with truncation -> index + pad.
    pipeline = [Tokenizing(tokenizer.convert_to_unicode, tokenizer.tokenize), AddSpecialTokensWithTruncation(max_len),
                TokenIndexing(tokenizer.convert_tokens_to_ids, TaskDataset.labels, max_len)]
    dataset = TaskDataset(data_file, pipeline)
    data_iter = DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True)

    # model = Classifier(model_cfg, len(TaskDataset.labels))
    model = TripMLMBert(model_cfg, cfg.batch_size, len(TaskDataset.labels), mask_id)
    criterion = nn.CrossEntropyLoss()

    trainer = train.Trainer(cfg, model, tokenizer, data_iter, optim.optim4GPU(cfg, model), save_dir, get_device())

    if mode == 'train':
        def get_loss(model, batch, global_step, tokenizer):  # make sure loss is a scalar tensor
            # NOTE(review): despite the comment above, this returns a
            # 5-tuple; train.Trainer.train presumably combines the
            # components into one loss -- confirm.
            input_ids, segment_ids, input_mask, label_id = batch
            logits, masked_logits, cos_simi, mlm_energy = model(input_ids, input_mask, segment_ids, global_step)
            loss = criterion(logits, label_id)  # CE on the clean forward pass
            masked_loss = criterion(masked_logits, label_id)  # CE on the masked forward pass
            _, label_pred = logits.max(1)
            f1 = f1_loss(label_id, label_pred)  # batch F1 (metric, not differentiable)
            return loss, masked_loss, f1, cos_simi, mlm_energy
            # return loss

        trainer.train(get_loss, model_file, pretrain_file, data_parallel)

    elif mode == 'eval':
        def evaluate(model, batch):
            # Returns (batch accuracy, per-example correctness, batch F1).
            input_ids, segment_ids, input_mask, label_id = batch
            logits, _, _, _ = model(input_ids, input_mask, segment_ids)
            _, label_pred = logits.max(1)
            result = (label_pred == label_id).float()  # .cpu().numpy()
            accuracy = result.mean()
            f1 = f1_loss(label_id, label_pred)
            return accuracy, result, f1

        results, f1s = trainer.eval(evaluate, model_file, data_parallel)
        total_accuracy = torch.cat(results).mean().item()
        averaged_f1 = np.mean(f1s)
        print('Accuracy:', total_accuracy, 'Averaged F1:', averaged_f1)
-
-
if __name__ == '__main__':
    # Expose main() as a command line interface: python-fire maps keyword
    # arguments to --flags (e.g. --task=sst2 --mode=eval).
    fire.Fire(main)
|