|
- import torch
- import torch.nn as nn
- import models
- import torch.nn.functional as F
-
-
# Custom contrastive loss (Hadsell et al., 2006):
#   L = mean( (1 - y) * d^2 + y * max(margin - d, 0)^2 )
# where y == 0 marks a similar pair (pull together) and y == 1 a
# dissimilar pair (push apart until at least `margin`).
class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss over pairs of embeddings.

    Args:
        margin: minimum Euclidean distance enforced between the
            embeddings of a dissimilar pair (default 2.0).
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Return the scalar contrastive loss for a batch of pairs.

        `output1`/`output2` are (batch, dim) embeddings; `label` may be
        (batch,) or (batch, 1) with 0 = similar, 1 = dissimilar.
        """
        # Per-pair Euclidean distance, shape (batch,). The previous code
        # passed keepdim=True, which made the distance (batch, 1); a
        # (batch,) label then broadcast the products to (batch, batch)
        # and the mean averaged the wrong quantity. Flattening both
        # tensors gives identical results for (batch, 1) labels and
        # correct results for (batch,) labels.
        euclidean_distance = F.pairwise_distance(output1, output2)
        label = label.view(-1).to(euclidean_distance.dtype)
        loss_contrastive = torch.mean(
            (1 - label) * torch.pow(euclidean_distance, 2)
            + label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        )
        return loss_contrastive
-
-
class Classifier(nn.Module):
    """Transformer encoder with a linear classification head.

    forward() encodes two views of one example — the raw input and a
    masked copy — classifies each, and also reports the mean cosine
    similarity between the two pooled representations.
    """

    def __init__(self, cfg, n_labels):
        super().__init__()
        self.transformer = models.Transformer(cfg)
        self.fc = nn.Linear(cfg.dim, cfg.dim)
        self.activ = nn.Tanh()
        self.drop = nn.Dropout(cfg.p_drop_hidden)
        self.classifier = nn.Linear(cfg.dim, n_labels)
        # Exposed for callers that combine classification with a
        # contrastive objective; not used inside forward() itself.
        self.constra_loss = ContrastiveLoss()

    def _pool_and_classify(self, ids, seg, mask):
        """Encode one view; return (pooled [CLS] state, label logits)."""
        hidden = self.transformer(ids, seg, mask)
        # Pool on the first token's hidden state (BERT-style [CLS] pooling).
        pooled = self.activ(self.fc(hidden[:, 0]))
        return pooled, self.classifier(self.drop(pooled))

    def forward(self, input_ids, segment_ids, input_mask, masked_input_ids, masked_segment_ids, masked_input_mask):
        """Return (logits, masked_logits, mean cosine similarity)."""
        pooled, logits = self._pool_and_classify(input_ids, segment_ids, input_mask)
        masked_pooled, masked_logits = self._pool_and_classify(
            masked_input_ids, masked_segment_ids, masked_input_mask)
        cos_simi = torch.mean(F.cosine_similarity(pooled, masked_pooled, dim=-1))
        return logits, masked_logits, cos_simi
|