|
- # -*- coding: UTF-8 -*-
- """
- -----------------------------------
- @Author : Encore
- @Date : 2022/10/31
- -----------------------------------
- """
- import os
-
- import torch
- import torch.nn as nn
- import torch.distributed.rpc as rpc
- from torch.distributed.rpc import RRef
- from transformers import BertModel, BertTokenizerFast
-
- from config import args
-
-
class BertFront(nn.Module):
    """First pipeline stage of a BERT model split across two RPC workers.

    Runs the embedding layer, the first ``split_num - 1`` complete encoder
    layers, and the attention part of encoder layer ``split_num - 1``.
    The intermediate dense projection of that split layer is factorized
    via SVD; this stage applies only the ``V`` and ``S`` factors, so a
    low-rank (``rank``-dimensional) activation is all that has to be
    shipped over RPC to the second stage (which applies ``U``).
    """

    def __init__(self, pretrain_dir, split_num, rank, device):
        """
        Args:
            pretrain_dir: directory with the pretrained BERT checkpoint.
            split_num: 1-based index of the encoder layer split between
                the two workers.
            rank: number of singular values kept in the low-rank
                factorization of the split layer's intermediate dense.
            device: torch device this stage runs on.
        """
        super().__init__()
        self.device = device
        bert = BertModel.from_pretrained(pretrain_dir)
        bert.to(device)

        self.embeddings = bert.embeddings
        self.encoder = nn.ModuleList([bert.encoder.layer[i] for i in range(split_num - 1)])

        split_layer = bert.encoder.layer[split_num - 1]
        self.attention = split_layer.attention

        # NOTE: BertBack recomputes this same SVD independently on the
        # other worker; the call must stay identical on both sides so the
        # factors (including their signs) agree.
        weight = split_layer.intermediate.dense.weight
        _, s, v = torch.linalg.svd(weight)

        # Placeholder Linear modules; in/out_features are irrelevant
        # because the weights are replaced with the SVD factors below.
        self.dense_s = nn.Linear(in_features=1, out_features=1, bias=False)
        self.dense_v = nn.Linear(in_features=1, out_features=1, bias=False)

        self.dense_s.weight = nn.Parameter(torch.diag(s[:rank]))  # (rank, rank)
        self.dense_v.weight = nn.Parameter(v[:rank].clone())      # (rank, hidden)

        self.to(device)

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None):
        """Run the front half and return the low-rank hidden states on CPU
        (CPU tensors are what gets serialized over RPC)."""
        input_ids = input_ids.to(self.device)
        token_type_ids = token_type_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)

        # Broadcastable additive mask: 0 for real tokens, -10000 for padding.
        extended_attention_mask = attention_mask[:, None, None, :]
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        hidden_states = self.embeddings(
            input_ids=input_ids,
            position_ids=None,
            token_type_ids=token_type_ids,
            inputs_embeds=None,
            past_key_values_length=0,
        )

        for layer in self.encoder:
            layer_output = layer(hidden_states, extended_attention_mask)
            hidden_states = layer_output[0]

        # First half of the split layer: attention, then the V and S
        # factors of the factorized intermediate dense (dim shrinks to rank).
        hidden_states = self.attention(hidden_states, extended_attention_mask)[0]
        hidden_states = self.dense_v(hidden_states)
        hidden_states = self.dense_s(hidden_states)

        return hidden_states.cpu()

    def parameter_rrefs(self):
        """Wrap every local parameter in an RRef (for DistributedOptimizer)."""
        return [RRef(p) for p in self.parameters()]
-
-
class BertBack(nn.Module):
    """Second pipeline stage: finishes the split encoder layer (the ``U``
    factor of the SVD-factorized intermediate dense, the activation, and
    the output sub-layer), runs the remaining encoder layers, then pools
    and classifies.
    """

    def __init__(self, pretrain_dir, split_num, rank, label_nums, device):
        """
        Args:
            pretrain_dir: directory with the pretrained BERT checkpoint.
            split_num: 1-based index of the split encoder layer.
            rank: number of singular values kept in the factorization.
            label_nums: number of output classes.
            device: torch device this stage runs on.
        """
        super().__init__()
        self.device = device

        bert = BertModel.from_pretrained(pretrain_dir)
        bert.to(device)

        self.pooler = bert.pooler
        # Use the checkpoint's configured depth instead of a hard-coded 12
        # so non-base checkpoints (large, tiny, ...) also work.
        num_layers = bert.config.num_hidden_layers
        self.encoder = nn.ModuleList([bert.encoder.layer[i] for i in range(split_num, num_layers)])

        split_layer = bert.encoder.layer[split_num - 1]
        self.intermediate_act_fn = split_layer.intermediate.intermediate_act_fn
        self.output = split_layer.output

        # NOTE: BertFront computes this same SVD independently on the other
        # worker; the call must stay identical on both sides so the factors
        # (including their signs) agree.
        weight = split_layer.intermediate.dense.weight
        bias = split_layer.intermediate.dense.bias
        u, _, _ = torch.linalg.svd(weight)

        # Placeholder Linear; its parameters are replaced with the U factor
        # and the original intermediate bias.
        self.dense_u = nn.Linear(in_features=1, out_features=1, bias=True)
        self.dense_u.weight = nn.Parameter(u[:, :rank].clone())  # (intermediate, rank)
        self.dense_u.bias = bias

        self.dropout = nn.Dropout(p=0.1)
        self.fc_prediction = nn.Linear(in_features=768, out_features=label_nums)
        self.to(device)

    def forward(self, hidden_states=None, attention_mask=None):
        """Finish the split layer, run the tail encoder layers, and return
        class logits on CPU (for RPC serialization)."""
        hidden_states = hidden_states.to(self.device)
        attention_mask = attention_mask.to(self.device)

        # Broadcastable additive mask: 0 for real tokens, -10000 for padding.
        extended_attention_mask = attention_mask[:, None, None, :]
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Second half of the split layer.  NOTE(review): the stock
        # BertOutput adds the attention output as a residual before
        # LayerNorm; that tensor stays on the front worker and is not
        # transmitted, so the residual is (apparently deliberately)
        # omitted here — confirm this matches the intended approximation.
        hidden_states = self.dense_u(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.output.dense(hidden_states)
        hidden_states = self.output.dropout(hidden_states)
        hidden_states = self.output.LayerNorm(hidden_states)

        for layer in self.encoder:
            layer_output = layer(hidden_states, extended_attention_mask)
            hidden_states = layer_output[0]

        hidden_states = self.pooler(hidden_states)
        cls_embedding = self.dropout(hidden_states)
        prediction = self.fc_prediction(cls_embedding)

        return prediction.cpu()

    def parameter_rrefs(self):
        """Wrap every local parameter in an RRef (for DistributedOptimizer)."""
        return [RRef(p) for p in self.parameters()]
-
-
class TaskHead(nn.Module):
    """Classification head: dropout over a CLS embedding followed by a
    linear projection to ``label_nums`` logits."""

    def __init__(self, label_nums, device):
        super().__init__()
        self.dropout = nn.Dropout(p=0.1)
        self.fc_prediction = nn.Linear(in_features=768, out_features=label_nums)
        self.device = device

        self.to(device)

    def forward(self, cls_embedding):
        # The embedding arrives as an RRef from another worker: fetch it
        # locally, move it onto this head's device, then classify.
        local_embedding = cls_embedding.to_here().to(self.device)
        logits = self.fc_prediction(self.dropout(local_embedding))
        return logits.cpu()

    def parameter_rrefs(self):
        """Wrap every parameter in an RRef (for DistributedOptimizer)."""
        refs = []
        for param in self.parameters():
            refs.append(RRef(param))
        return refs
-
-
class DistBert(nn.Module):
    """BERT pipelined across two RPC workers, with the split layer's
    intermediate dense factorized by SVD so only a low-rank activation
    crosses the wire."""

    def __init__(self, pretrain_dir, split_num, rank, label_nums, device):
        super().__init__()
        # Front half lives on worker0, back half on worker1.
        self.front_ref = rpc.remote("worker0", BertFront, args=(pretrain_dir, split_num, rank, device))
        self.back_ref = rpc.remote("worker1", BertBack, args=(pretrain_dir, split_num, rank, label_nums, device))

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        # Synchronous two-hop pipeline: front half, then back half.
        hidden = self.front_ref.rpc_sync().forward(input_ids, token_type_ids, attention_mask)
        return self.back_ref.rpc_sync().forward(hidden, attention_mask)

    def parameter_rrefs(self):
        """Collect parameter RRefs from both remote halves."""
        return (self.front_ref.rpc_sync().parameter_rrefs()
                + self.back_ref.rpc_sync().parameter_rrefs())
-
-
class BertFrontOri(nn.Module):
    """Unfactorized baseline front stage: embeddings plus the first
    ``split_num`` complete encoder layers (no SVD compression)."""

    def __init__(self, pretrain_dir, split_num, rank, device):
        super().__init__()
        self.device = device
        bert = BertModel.from_pretrained(pretrain_dir)
        bert.to(device)

        self.embeddings = bert.embeddings
        self.encoder = nn.ModuleList(bert.encoder.layer[:split_num])

        self.to(device)

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None):
        """Embed, run the front encoder layers, and return hidden states
        on CPU for RPC transfer."""
        input_ids = input_ids.to(self.device)
        token_type_ids = token_type_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)

        # Additive attention mask: 0 keeps a position, -10000 blanks padding.
        mask = (1.0 - attention_mask[:, None, None, :]) * -10000.0

        hidden = self.embeddings(
            input_ids=input_ids,
            position_ids=None,
            token_type_ids=token_type_ids,
            inputs_embeds=None,
            past_key_values_length=0,
        )

        for block in self.encoder:
            hidden = block(hidden, mask)[0]

        return hidden.cpu()

    def parameter_rrefs(self):
        """Wrap every local parameter in an RRef (for DistributedOptimizer)."""
        return [RRef(param) for param in self.parameters()]
-
-
class BertBackOri(nn.Module):
    """Unfactorized baseline back stage: remaining encoder layers, pooler,
    dropout, and the classification projection.

    ``rank`` is accepted only for signature parity with ``BertBack`` and
    is unused here.
    """

    def __init__(self, pretrain_dir, split_num, rank, label_nums, device):
        super().__init__()
        self.device = device

        bert = BertModel.from_pretrained(pretrain_dir)
        bert.to(device)

        self.pooler = bert.pooler
        # Use the checkpoint's configured depth instead of a hard-coded 12
        # so non-base checkpoints also work.
        num_layers = bert.config.num_hidden_layers
        self.encoder = nn.ModuleList([bert.encoder.layer[i] for i in range(split_num, num_layers)])

        self.dropout = nn.Dropout(p=0.1)
        self.fc_prediction = nn.Linear(in_features=768, out_features=label_nums)
        self.to(device)

    def forward(self, hidden_states=None, attention_mask=None):
        """Run the tail encoder layers, pool, and return logits on CPU."""
        hidden_states = hidden_states.to(self.device)
        attention_mask = attention_mask.to(self.device)

        # Broadcastable additive mask: 0 for real tokens, -10000 for padding.
        extended_attention_mask = attention_mask[:, None, None, :]
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        for layer in self.encoder:
            layer_output = layer(hidden_states, extended_attention_mask)
            hidden_states = layer_output[0]

        hidden_states = self.pooler(hidden_states)
        cls_embedding = self.dropout(hidden_states)
        prediction = self.fc_prediction(cls_embedding)

        return prediction.cpu()

    def parameter_rrefs(self):
        """Wrap every local parameter in an RRef (for DistributedOptimizer)."""
        return [RRef(p) for p in self.parameters()]
-
-
class DistBertOri(nn.Module):
    """Baseline two-worker BERT pipeline without SVD compression: the
    full hidden states are shipped between workers."""

    def __init__(self, pretrain_dir, split_num, rank, label_nums, device):
        super().__init__()
        # Front half lives on worker0, back half on worker1.
        self.front_ref = rpc.remote("worker0", BertFrontOri, args=(pretrain_dir, split_num, rank, device))
        self.back_ref = rpc.remote("worker1", BertBackOri, args=(pretrain_dir, split_num, rank, label_nums, device))

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        # Synchronous two-hop pipeline: front half, then back half.
        hidden = self.front_ref.rpc_sync().forward(input_ids, token_type_ids, attention_mask)
        return self.back_ref.rpc_sync().forward(hidden, attention_mask)

    def parameter_rrefs(self):
        """Collect parameter RRefs from both remote halves."""
        return (self.front_ref.rpc_sync().parameter_rrefs()
                + self.back_ref.rpc_sync().parameter_rrefs())
-
-
if __name__ == '__main__':
    tokenizer = BertTokenizerFast.from_pretrained(args.pretrain_dir)

    text = "it's a sunny day"
    encoded_text = tokenizer(text, return_tensors="pt")

    # Single-machine rendezvous; a second process must start worker1 with
    # rank=1 and the same world_size for init_rpc to return.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    rpc.init_rpc("worker0", rank=0, world_size=2)

    # BUG FIX: DistBert.__init__ takes (pretrain_dir, split_num, rank,
    # label_nums, device); the original call passed only three arguments
    # and raised a TypeError.  rank/label_nums below are example values —
    # adjust for the actual task.
    model = DistBert(args.pretrain_dir, split_num=2, rank=256, label_nums=2,
                     device=torch.device("cuda:0"))
    # NOTE(review): eval() only flips the local wrapper's training flag;
    # the remote BertFront/BertBack modules keep their own train mode
    # (dropout active) — confirm whether that is intended for inference.
    model.eval()
    ret = model(**encoded_text)

    print(ret)
    rpc.shutdown()
|