|
- """
- Author: Yonglong Tian (yonglong@mit.edu)
- Date: May 07, 2020
- """
- from __future__ import print_function
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
def normalization(data):
    """Min-max rescale every entry of ``data`` to [0, 1], in place.

    Each ``data[i]`` (e.g. each row of a tensor indexed along its first
    dimension, or each tensor in a list) is rescaled independently as
    ``(x - min(x)) / (max(x) - min(x))``.

    Args:
        data: an indexable, item-assignable collection of tensors.

    Returns:
        The same ``data`` object, modified in place.

    Note:
        An entry whose values are all equal has zero range. Instead of
        producing NaN from ``0 / 0``, it is mapped to all zeros.
    """
    for i, row in enumerate(data):
        row_min = torch.min(row)
        shifted = row - row_min
        value_range = torch.max(row) - row_min
        # For a zero-range (constant) entry, `shifted` is already all zeros.
        data[i] = shifted / value_range if value_range > 0 else shifted
    return data
-
-
# Supervised contrastive learning loss.
class SupConLoss(nn.Module):
    """Supervised contrastive loss with per-anchor min-max rescaled logits.

    Features are L2-normalized, compared by dot product, and the resulting
    logits are rescaled to [0, 1] per anchor row before the softmax-style
    contrastive objective is applied.
    """

    def __init__(self, temperature=0.07, contrast_mode='one',
                 base_temperature=0.07):
        """
        Args:
            temperature: divisor applied to the similarity logits.
            contrast_mode: 'one' uses only the first view of each sample as
                the anchor; 'all' uses every view as an anchor.
            base_temperature: reference temperature. The final loss is
                scaled by ``temperature / base_temperature``.
        """
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def _build_mask(self, batch_size, labels, mask, device):
        """Return the float [bsz, bsz] positive-pair mask (diagonal included)."""
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        if labels is None and mask is None:
            # No supervision: each sample is only positive with itself.
            return torch.eye(batch_size, dtype=torch.float32).to(device)
        if labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            # 1 where the two samples share a label, 0 otherwise.
            return torch.eq(labels, labels.T).float().to(device)
        return mask.float().to(device)

    def forward(self, features, labels=None, mask=None):
        """Compute the supervised contrastive loss.

        Args:
            features: hidden vectors of shape [bsz, dim] (a single view) or
                [bsz, n_views, ...]. Trailing dimensions past the second are
                flattened.
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample
                j has the same class as sample i. Can be asymmetric. Mutually
                exclusive with ``labels``.

        Returns:
            A loss scalar.

        Raises:
            ValueError: if ``features`` has fewer than 2 dimensions, if both
                ``labels`` and ``mask`` are given, if the label count does not
                match the batch size, or if ``contrast_mode`` is unknown.
        """
        # Use the input's exact device (keeps the CUDA index, unlike 'cuda').
        device = features.device

        if features.dim() < 2:
            raise ValueError('`features` needs to be [bsz, dim] or '
                             '[bsz, n_views, ...], '
                             'at least 2 dimensions are required')
        if features.dim() == 2:
            # Single view: [bsz, dim] -> [bsz, 1, dim].
            features = features.unsqueeze(dim=1)
        else:
            # [bsz, n_views, ...] -> [bsz, n_views, dim].
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        mask = self._build_mask(batch_size, labels, mask, device)

        # L2-normalize each feature vector along the feature dimension.
        features = F.normalize(features, dim=2)

        contrast_count = features.shape[1]
        # Stack the views along the batch axis: [n_views * bsz, dim].
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)

        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # Logits of shape [anchor_count * bsz, contrast_count * bsz].
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)

        # Subtract the (detached) per-row maximum for numerical stability.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Min-max rescale each row to [0, 1].
        # NOTE(review): this rescaling is invariant to a positive scale
        # factor, so it cancels the division by self.temperature above. The
        # temperature then only enters through the final scaling factor.
        # Clamp the range so constant rows give 0 instead of 0 / 0 = NaN.
        logits_min, _ = torch.min(logits, dim=1, keepdim=True)
        logits_max, _ = torch.max(logits, dim=1, keepdim=True)
        value_range = (logits_max - logits_min).clamp_min(
            torch.finfo(logits.dtype).tiny)
        logits = torch.div(logits - logits_min, value_range)

        # Tile the [bsz, bsz] mask to the logits' shape.
        mask = mask.repeat(anchor_count, contrast_count)

        # 1 everywhere except where an anchor is compared with itself.
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # Log-softmax over all non-self contrasts.
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # NOTE(review): this divides by (#positives + 1), not #positives. That
        # avoids 0 / 0 for anchors without positives, but it is not the plain
        # mean over positives.
        mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1)

        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()
        return loss
|