# -*- coding: UTF-8 -*-
"""
-----------------------------------
@Author : Encore
@Date : 2022/7/18
-----------------------------------
"""
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizerFast

from config import args


class Classification(nn.Module):
    """BERT sentence classifier: pooled [CLS] embedding -> dropout -> linear prediction head."""

    def __init__(self, pretrain_dir, label_nums):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrain_dir)
        self.dropout = nn.Dropout(p=0.1)
        self.fc_prediction = nn.Linear(in_features=768, out_features=label_nums)  # 768 = bert-base hidden size

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        cls_embedding = self.bert(input_ids=input_ids,
                                  token_type_ids=token_type_ids,
                                  attention_mask=attention_mask).pooler_output
        cls_embedding = self.dropout(cls_embedding)
        prediction = self.fc_prediction(cls_embedding)

        return prediction


class Hook(object):
    """Factory for forward/backward hooks that record activations and gradients and,
    in the sparse variants, zero out values below a configured threshold."""

    def __init__(self):
        self.fw = {}      # forward tensors keyed by "<name>_in" / "_out" / "_zero" / "_compensate"
        self.bw = {}      # backward gradients keyed by "<name>_in" / "_out" / "_zero"
        self.ratio = {}   # sparsity ratio recorded per hooked module

    def save_fw(self, name):
        # Forward hook that only records the module's input and output.
        def hook_fn_forward(module, fw_input, fw_output):
            self.fw[f"{name}_in"] = fw_input[0].detach()
            self.fw[f"{name}_out"] = fw_output.detach()

        return hook_fn_forward

    def sparse_fw(self, name):
        # Forward hook that zeroes activations with |x| < args.fw_threshold and
        # replaces the module output with the sparsified tensor.
        def hook_fn_forward(module, fw_input, fw_output):
            after_zero = SetZero.apply(fw_output)
            with torch.no_grad():
                zero_num = (after_zero == 0).sum().item()

            self.fw[f"{name}_in"] = fw_input[0].detach()
            self.fw[f"{name}_out"] = fw_output.detach()
            self.fw[f"{name}_zero"] = after_zero.detach()
            self.ratio[f"{name}_fw"] = zero_num / torch.numel(after_zero)

            return after_zero

        return hook_fn_forward

    def save_bw(self, name):
        # Backward hook that only records the gradients w.r.t. the module's input and output.
        def hook_fn_backward(module, grad_input, grad_output):
            self.bw[f"{name}_in"] = grad_input[0].detach()
            self.bw[f"{name}_out"] = grad_output[0].detach()

        return hook_fn_backward

    def sparse_bw(self, name):
        # Backward hook that zeroes input gradients with |g| < args.bw_threshold and
        # returns the sparsified tensor as the new grad_input.
        def hook_fn_backward(module, grad_input, grad_output):
            mod = grad_input[0]
            with torch.no_grad():
                zero_tensor = torch.tensor(0., dtype=torch.float32, device=mod.device)
                after_zero = torch.where(torch.abs(mod) >= args.bw_threshold, mod, zero_tensor)
                zero_num = (after_zero == 0).sum().item()

            self.bw[f"{name}_in"] = grad_input[0].detach()
            self.bw[f"{name}_out"] = grad_output[0].detach()
            self.bw[f"{name}_zero"] = after_zero.detach()
            self.ratio[f"{name}_bw"] = zero_num / torch.numel(after_zero)

            return after_zero,

        return hook_fn_backward

    def fw_compensation_cum(self, name):
        # Forward hook with cumulative error compensation: activations zeroed in the
        # previous step (taken from the previous *compensated* output) are added back
        # before thresholding, so compensation accumulates across steps.
        def hook_fn_forward(module, fw_input, fw_output):
            compensate_v = torch.zeros_like(fw_output)
            if f"{name}_compensate" in self.fw:
                last_before_zero = self.fw[f"{name}_compensate"]
                last_after_zero = self.fw[f"{name}_zero"]
                if last_before_zero.shape == fw_output.shape:
                    zero_tensor = torch.tensor(0., dtype=torch.float32, device=fw_output.device)
                    compensate_v = torch.where(last_after_zero == 0, last_before_zero, zero_tensor)

            fw_compensate = fw_output + compensate_v
            after_zero = SetZero.apply(fw_output, fw_compensate)
            with torch.no_grad():
                zero_num = (after_zero == 0).sum().item()

            self.fw[f"{name}_in"] = fw_input[0].detach()
            self.fw[f"{name}_out"] = fw_output.detach()
            self.fw[f"{name}_compensate"] = fw_compensate.detach()
            self.fw[f"{name}_zero"] = after_zero.detach()
            self.ratio[f"{name}_fw"] = zero_num / torch.numel(after_zero)

            return after_zero

        return hook_fn_forward

    def fw_compensation(self, name):
        # Forward hook with one-step error compensation: activations zeroed in the
        # previous step (taken from the previous *raw* output) are added back before
        # thresholding; compensation does not accumulate.
        def hook_fn_forward(module, fw_input, fw_output):
            compensate_v = torch.zeros_like(fw_output)
            if f"{name}_compensate" in self.fw:
                last_before_zero = self.fw[f"{name}_out"]
                last_after_zero = self.fw[f"{name}_zero"]
                if last_before_zero.shape == fw_output.shape:
                    zero_tensor = torch.tensor(0., dtype=torch.float32, device=fw_output.device)
                    compensate_v = torch.where(last_after_zero == 0, last_before_zero, zero_tensor)

            fw_compensate = fw_output + compensate_v
            after_zero = SetZero.apply(fw_output, fw_compensate)
            with torch.no_grad():
                zero_num = (after_zero == 0).sum().item()

            self.fw[f"{name}_in"] = fw_input[0].detach()
            self.fw[f"{name}_out"] = fw_output.detach()
            self.fw[f"{name}_compensate"] = fw_compensate.detach()
            self.fw[f"{name}_zero"] = after_zero.detach()
            self.ratio[f"{name}_fw"] = zero_num / torch.numel(after_zero)

            return after_zero

        return hook_fn_forward

    def mod_pre(self, name):
        # Forward pre-hook that zeroes module inputs with |x| < 0.1 (hard-coded threshold)
        # and records the fraction of newly zeroed elements.
        def hook_fn_pre(module, fw_input):
            mod = fw_input[0]
            before_mod = mod.detach().cpu().numpy()
            zero = torch.tensor(0., dtype=torch.float32, device=mod.device)
            mod = torch.where(mod.abs() >= 0.1, mod, zero)
            after_mod = mod.detach().cpu().numpy()

            before_num = (before_mod == 0).sum()
            after_num = (after_mod == 0).sum()
            self.ratio[f"{name}_in"] = (after_num - before_num) / before_mod.size

            return mod

        return hook_fn_pre


class SetZero(torch.autograd.Function):
    """Zeroes elements below args.fw_threshold in the forward pass while passing gradients
    through unchanged (straight-through estimator). When `alter` is given, the threshold is
    applied to `alter` but the surviving values are taken from `i`."""

    @staticmethod
    def forward(ctx, i, alter=None):
        zero = torch.tensor(0., dtype=torch.float32, device=i.device)
        if alter is None:
            mod = torch.where(torch.abs(i) >= args.fw_threshold, i, zero)
        else:
            # Compensation path: decide which positions survive based on the compensated
            # tensor `alter`, but keep the original values from `i`.
            mod = torch.where(torch.abs(alter) >= args.fw_threshold, i, zero)
        return mod

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through gradient for `i`; `alter` receives no gradient.
        return grad_output, None
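
# Worked illustration (added comment; assumes args.fw_threshold == 0.5): for
# x = torch.tensor([0.2, -0.8, 0.4]), SetZero.apply(x) keeps -0.8 and zeroes 0.2 and 0.4,
# while backward() passes the incoming gradient through unchanged at all three positions.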


if __name__ == '__main__':
    classifier = Classification(args.pretrain_dir, 1)
    # for layer in classifier.bert.encoder.layer:
    #     # print(dir(layer.output.dropout))
    #     layer.output.dropout.register_forward_hook(hook_fn_forward)
    #     layer.output.dropout.register_full_backward_hook(hook_fn_backward)
    #     break
    #
    # text = "哈哈哈哈"
    # tokenizer = BertTokenizerFast.from_pretrained(args.pretrain_dir)
    # encode_text = tokenizer(text, return_tensors="pt")
    # print(encode_text)
    #
    # classifier.train()
    # out = classifier(**encode_text)
    # out.backward()
    for n, p in classifier.named_parameters():
        print(n)
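
    # Hedged usage sketch (added for illustration, not in the original script): register the
    # sparsifying hooks from Hook on each encoder layer's output dropout, run one forward and
    # backward pass, and print the recorded sparsity ratios. Assumes config.args provides
    # pretrain_dir, fw_threshold and bw_threshold.
    hook = Hook()
    for idx, layer in enumerate(classifier.bert.encoder.layer):
        layer.output.dropout.register_forward_hook(hook.sparse_fw(f"layer_{idx}"))
        layer.output.dropout.register_full_backward_hook(hook.sparse_bw(f"layer_{idx}"))

    tokenizer = BertTokenizerFast.from_pretrained(args.pretrain_dir)
    encode_text = tokenizer("哈哈哈哈", return_tensors="pt")

    classifier.train()
    out = classifier(**encode_text)  # shape (1, 1) because label_nums=1
    out.sum().backward()             # reduce to a scalar before calling backward
    print(hook.ratio)                # e.g. {"layer_0_fw": ..., "layer_0_bw": ..., ...}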