|
- # -*- coding: UTF-8 -*-
- """
- -----------------------------------
- @Author : Encore
- @Date : 2022/8/18
- -----------------------------------
- """
- import torch
- import torch.nn as nn
- from transformers import ViTModel, ViTImageProcessor
- # from transformers.models.vit.modeling_vit import ViTLayer
-
- from config import args
-
- torch.set_printoptions(precision=8)
-
-
class Classification(nn.Module):
    """Image classifier: a pretrained ViT backbone with a linear head.

    Args:
        pretrain_dir: model name or directory passed to ``ViTModel.from_pretrained``.
        label_nums: number of target classes (output dimension of the head).
    """

    def __init__(self, pretrain_dir, label_nums):
        super().__init__()
        # Pretrained ViT encoder; the head consumes its 768-dim pooled output.
        self.vit = ViTModel.from_pretrained(pretrain_dir)
        # Linear classification head over the pooled [CLS] embedding.
        self.fc_prediction = nn.Linear(in_features=768, out_features=label_nums)

    def forward(self, pixel_values):
        # Pooled [CLS] representation produced by the ViT backbone.
        pooled = self.vit(pixel_values).pooler_output
        return self.fc_prediction(pooled)
-
-
class SvdLinear(nn.Module):
    """Low-rank replacement for an ``nn.Linear`` layer via truncated SVD.

    The source weight ``W`` (shape ``(out_features, in_features)``) is factored
    as ``W ~= U_r @ diag(S_r) @ Vh_r`` keeping the top ``rank`` singular
    triplets, and the layer is realized as three chained linear maps:
    ``x -> Vh_r @ x -> diag(S_r) @ . -> U_r @ . (+ bias)``.
    With ``rank == min(out_features, in_features)`` this reproduces the source
    layer exactly (up to float error). The source layer's bias Parameter is
    shared, not copied.

    Args:
        layer: the ``nn.Linear`` whose weight/bias are decomposed.
        rank: number of singular values to keep (1 <= rank <= min(W.shape)).
    """

    def __init__(self, layer, rank):
        super().__init__()
        weight = layer.weight
        bias = layer.bias
        out_features, in_features = weight.shape

        # detach(): don't build an autograd graph back to the source layer.
        # full_matrices=False: reduced SVD — we only keep the top `rank`
        # triplets, so the full square U/Vh would be wasted allocation.
        # NOTE: torch.linalg.svd returns Vh (= V transposed), not V.
        u, s, vh = torch.linalg.svd(weight.detach(), full_matrices=False)

        # Declare the factors with their true shapes so in_features /
        # out_features metadata (and repr) are honest, unlike dummy (1, 1).
        self.dense_u = nn.Linear(in_features=rank, out_features=out_features,
                                 bias=bias is not None)
        self.dense_s = nn.Linear(in_features=rank, out_features=rank, bias=False)
        self.dense_v = nn.Linear(in_features=in_features, out_features=rank,
                                 bias=False)

        self.dense_u.weight = nn.Parameter(u[:, :rank].clone())
        if bias is not None:
            # Share the original bias Parameter (same behavior as before).
            self.dense_u.bias = bias
        self.dense_s.weight = nn.Parameter(torch.diag(s[:rank]))
        self.dense_v.weight = nn.Parameter(vh[:rank].clone())

    def forward(self, inputs):
        # Apply the factors right-to-left: Vh, then diag(S), then U (+ bias).
        out = self.dense_v(inputs)
        out = self.dense_s(out)
        out = self.dense_u(out)

        return out
-
-
class NoResViTLayer(nn.Module):
    """ViT encoder block (timm "Block" analogue) with the MLP residual removed.

    Reuses the sub-modules of an existing ViT layer, swapping its intermediate
    dense projection for a low-rank ``SvdLinear``. The attention residual is
    kept; the second (MLP) residual connection is intentionally omitted.
    """

    def __init__(self, layer, rank):
        super().__init__()
        # Borrow the (pretrained) sub-modules of the source ViT layer.
        self.attention = layer.attention
        self.intermediate = layer.intermediate
        self.output = layer.output
        self.layernorm_before = layer.layernorm_before
        self.layernorm_after = layer.layernorm_after

        # Replace the intermediate projection with its low-rank factorization.
        self.intermediate.dense = SvdLinear(self.intermediate.dense, rank=rank)

    def forward(self, hidden_states, head_mask=None, output_attentions=False):
        # Pre-norm self-attention: in ViT, layernorm precedes attention.
        attn_results = self.attention(
            self.layernorm_before(hidden_states),
            head_mask,
            output_attentions=output_attentions,
        )
        attn_out = attn_results[0]
        extras = attn_results[1:]  # attention weights, when requested

        # Residual connection around the attention sub-block.
        hidden_states = hidden_states + attn_out

        # Pre-norm MLP. NOTE: unlike a stock ViT layer, no residual is added
        # after the MLP — that omission is the point of this "NoRes" variant.
        mlp_out = self.intermediate(self.layernorm_after(hidden_states))
        mlp_out = self.output.dense(mlp_out)
        mlp_out = self.output.dropout(mlp_out)

        return (mlp_out,) + extras
|