import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel  # optional drop-in replacement for models.Transformer
import models      # project-local Transformer implementation
import checkpoint  # project-local loader for TensorFlow checkpoints


# Custom contrastive loss (Hadsell et al., 2006): pulls similar pairs
# (label = 0) together and pushes dissimilar pairs (label = 1) at least
# `margin` apart in embedding space.
class ContrastiveLoss(nn.Module):

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # Euclidean distance between the two embeddings, shape (batch, 1)
        euclidean_distance = F.pairwise_distance(output1, output2, keepdim=True)
        # label = 0: minimize the distance; label = 1: penalize distances below margin
        loss_contrastive = torch.mean(
            (1 - label) * torch.pow(euclidean_distance, 2)
            + label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        )
        return loss_contrastive
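
# Example usage (a minimal sketch; the 768-dim embedding size is illustrative):
#   criterion = ContrastiveLoss(margin=2.0)
#   a = torch.randn(4, 768)
#   criterion(a, a, 0)   # similar pair at zero distance -> loss ~ 0
#   criterion(a, a, 1)   # dissimilar pair at zero distance -> loss = margin**2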


class SiameseNetwork(nn.Module):
    """BERT-style Siamese network: a shared transformer encoder with a
    classification head, trained with a contrastive loss between branches."""

    def __init__(self, cfg, n_labels):
        super().__init__()
        self.transformer = models.Transformer(cfg)
        # self.transformer = BertModel.from_pretrained('bert-base-uncased')
        self.fc = nn.Linear(cfg.dim, cfg.dim)
        self.activ = nn.Tanh()
        self.drop = nn.Dropout(cfg.p_drop_hidden)
        self.classifier = nn.Linear(cfg.dim, n_labels)
        self.constra_loss = ContrastiveLoss()

    def load(self, model_file, pretrain_file):
        """Load a saved model, or a pretrained transformer (a part of the model)."""
        if model_file:
            print('Loading the model from', model_file)
            self.load_state_dict(torch.load(model_file))
        elif pretrain_file:  # use a pretrained transformer
            print('Loading the pretrained model from', pretrain_file)
            if pretrain_file.endswith('.ckpt'):  # checkpoint file in TensorFlow
                checkpoint.load_model(self.transformer, pretrain_file)
            elif pretrain_file.endswith('.pt'):  # pretrained model file in PyTorch
                # keep only the transformer weights, stripping the 'transformer.' prefix
                self.transformer.load_state_dict(
                    {key[12:]: value
                     for key, value in torch.load(pretrain_file).items()
                     if key.startswith('transformer')}
                )

    def forward_once(self, input_ids, input_mask, segment_ids):
        h = self.transformer(input_ids, segment_ids, input_mask)
        # pool only the first token ([CLS]) of the sequence
        pooled_h = self.activ(self.fc(h[:, 0]))
        logits = self.classifier(self.drop(pooled_h))
        # return the raw [CLS] embedding (for the contrastive loss) and the logits
        return h[:, 0], logits

    def forward(self, input_ids, input_mask, segment_ids,
                masked_input_ids, masked_input_mask, masked_segment_ids):
        pooled_h, logits = self.forward_once(input_ids, input_mask, segment_ids)
        masked_pooled_h, masked_logits = self.forward_once(
            masked_input_ids, masked_input_mask, masked_segment_ids)
        # label 0: the original input and its masked view form a similar pair
        constra_loss = self.constra_loss(pooled_h, masked_pooled_h, 0)
        return constra_loss, logits, masked_logits
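
# A minimal training-step sketch (hypothetical names: `cfg`, `labels`, and the
# six input tensors; assumes `cfg` carries the fields models.Transformer expects):
#   model = SiameseNetwork(cfg, n_labels=2)
#   model.load(model_file=None, pretrain_file='pretrained_bert.pt')  # hypothetical path
#   constra_loss, logits, masked_logits = model(
#       input_ids, input_mask, segment_ids,
#       masked_input_ids, masked_input_mask, masked_segment_ids)
#   loss = constra_loss + F.cross_entropy(logits, labels)
#   loss.backward()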