|
- import torch
- import torch.nn as nn
- import models
- import torch.nn.functional as F
- import math
- import matplotlib.pyplot as plt
-
-
# Custom contrastive loss
class ContrastiveLoss(torch.nn.Module):
    """Margin-based contrastive loss (Hadsell et al., 2006).

    Pairs with ``label == 0`` are treated as similar (their distance is
    minimized); pairs with ``label == 1`` are pushed apart until their
    Euclidean distance reaches ``margin``.
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        # Minimum distance desired between dissimilar pairs.
        self.margin = margin

    def forward(self, output1, output2, label):
        # NOTE(review): keepdim=True makes the distance shape (N, 1); the
        # caller must pass `label` with a broadcast-compatible shape
        # (e.g. (N, 1)), otherwise a (N,)-shaped label broadcasts the
        # products to (N, N) — confirm the label shape upstream.
        dist = F.pairwise_distance(output1, output2, keepdim=True)
        # Similar pairs: penalize any distance at all.
        pull = torch.pow(dist, 2)
        # Dissimilar pairs: penalize only distances inside the margin.
        push = torch.pow(torch.clamp(self.margin - dist, min=0.0), 2)
        return torch.mean((1 - label) * pull + label * push)
-
-
class TripMLMBert(nn.Module):
    """Classifier with a Transformer encoder and learned masking heads.

    The input is encoded by a Transformer followed by a bidirectional LSTM.
    A stack of small scoring heads then each stochastically selects one
    token position per example (straight-through Gumbel-softmax over the
    sequence) and replaces it with ``mask_id``.  The masked sequence is
    re-encoded and classified, and the cosine agreement between the two
    encodings is returned alongside both sets of logits.
    """

    def __init__(self, cfg, batch_size, n_labels, mask_id, n_mlm=15):
        """
        Args:
            cfg: model config; must expose ``dim`` and ``p_drop_hidden``.
            batch_size: kept for backward compatibility — unused here.
            n_labels: number of output classes.
            mask_id: vocabulary id of the [MASK] token.
            n_mlm: number of masking heads / masking rounds (previously a
                hard-coded 15; the default preserves old behavior).
        """
        super().__init__()
        self.mask_id = mask_id
        self.transformer = models.Transformer(cfg)
        # self.transformer = BertModel.from_pretrained('bert-base-uncased')
        self.fc = nn.Linear(cfg.dim, cfg.dim)
        # One per-token scalar scoring head per masking round.
        self.mlms = nn.ModuleList(nn.Linear(cfg.dim, 1) for _ in range(n_mlm))
        self.activ = nn.Tanh()
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(cfg.p_drop_hidden)
        self.classifier = nn.Linear(cfg.dim, n_labels)
        # NOTE(review): constra_loss is constructed but never used in
        # forward(); presumably consumed by external training code — verify.
        self.constra_loss = ContrastiveLoss()
        # Bidirectional halves of size dim/2 keep the output width at cfg.dim.
        self.lstm = nn.LSTM(input_size=cfg.dim, hidden_size=int(0.5 * cfg.dim),
                            num_layers=4, bias=True, batch_first=True,
                            dropout=0.1, bidirectional=True)

    def forward(self, input_ids, input_mask, segment_ids, global_step=0, tokenizer=None):
        """Encode, stochastically mask, re-encode, and classify.

        Args:
            input_ids: token ids, (batch, seq).
            input_mask: 1/0 attention mask, (batch, seq).
            segment_ids: segment ids, (batch, seq).
            global_step, tokenizer: accepted for interface compatibility;
                unused in this implementation.

        Returns:
            logits: class logits for the original sequence.
            masked_logits: class logits for the masked sequence.
            cos_simi: mean token-wise cosine similarity between the two
                encodings (scalar tensor).
            mlm_energy: scalar regularizer accumulated over the mask heads.
        """
        h = self.lstm(self.transformer(input_ids, segment_ids, input_mask))[0]
        # h = self.transformer(input_ids, segment_ids, input_mask)
        mlm_energy = 0.
        masked_input_ids = input_ids
        for mlm in self.mlms:
            # Per-token non-negative score; padding positions are zeroed.
            mlm_logit = self.relu(mlm(h).squeeze(-1)) * input_mask
            mlm_energy += torch.mean((mlm_logit + math.e) / (2 * math.e))
            # Straight-through sample: a one-hot over the sequence dim picks
            # one position per example to mask.
            # NOTE(review): Gumbel noise can still select a zero-logit
            # (padding) position — confirm this is acceptable.
            mlm_samples = F.gumbel_softmax(mlm_logit, hard=True).long()
            masked_input_ids = (1 - mlm_samples) * masked_input_ids + mlm_samples * self.mask_id
        # Only use the first hidden state in the sequence (CLS-style pooling).
        pooled_h = self.activ(self.fc(h[:, 0]))
        logits = self.classifier(self.drop(pooled_h))
        # Re-encode the masked sequence with the same encoder stack.
        masked_h = self.lstm(self.transformer(masked_input_ids, segment_ids, input_mask))[0]
        # Only use the first hidden state in the sequence.
        masked_pooled_h = self.activ(self.fc(masked_h[:, 0]))
        masked_logits = self.classifier(self.drop(masked_pooled_h))
        # Agreement between original and masked encodings, averaged over
        # batch and sequence.
        cos_simi = torch.mean(F.cosine_similarity(h, masked_h, dim=-1))
        return logits, masked_logits, cos_simi, mlm_energy
|