# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. 2022, PCL. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- """Pretrain AlphaEnhanced"""
-
- import torch
-
- from megatron import get_args
- from megatron import print_rank_0
- from megatron import get_timers
- from megatron import get_tokenizer
- from megatron import mpu
- from megatron.initialize import initialize_megatron
- from megatron.checkpointing import load_checkpoint
- from megatron.data.gpt2_dataset import build_train_valid_test_datasets
- from megatron.model import EnhancedModel
- from megatron.training import pretrain, get_model
- from megatron.utils import get_ltor_masks_and_position_ids
- from megatron.utils import reduce_losses
- import torch.nn.functional as F
-
-
MODEL_CONFIG = {
    '350M': {
        'num_layers': 23,
        'hidden_size': 1024,
        'num_attention_heads': 16,
        'load': '',
        'save': '/workspace/models/pangu-alpha-evolution_small'
    },
    '2B6': {
        'num_layers': 31,
        'hidden_size': 2560,
        'num_attention_heads': 32,
        'load': '/workspace/models/pangu-alpha-evolution_2.6b_fp16',
    }
}
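
# The two entries above are applied on top of the parsed command-line arguments
# via set_args() below: '2B6' is used for the frozen teacher and '350M' for the
# student that is trained from scratch.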

teacher_model = None
# kl_loss_function_batchmean = torch.nn.KLDivLoss(reduction="batchmean")
kl_loss_function_none = torch.nn.KLDivLoss(reduction="none")
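# reduction="none" keeps the KL term per token (and per vocab entry), so that
# forward_step below can sum over the vocabulary dimension and then apply the
# same loss_mask that is used for the cross-entropy term.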


def set_args(model_size):
    assert model_size in MODEL_CONFIG
    args = get_args()

    for key, value in MODEL_CONFIG[model_size].items():
        setattr(args, key, value)


def load_teacher_model():
    global teacher_model

    print_rank_0('Loading AlphaEnhanced Teacher model ...')

    def teacher_model_provider():
        set_args('2B6')
        _model = EnhancedModel(num_tokentypes=0, parallel_output=True)
        return _model

    teacher_model = get_model(teacher_model_provider)
    load_checkpoint(teacher_model, None, None)  # load the teacher model parameters
    teacher_model.eval()
    print_rank_0('Load done')
    return teacher_model
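
# Note (observation, not a behaviour change): eval() only switches off dropout
# in the teacher; it does not stop autograd from tracking the teacher forward
# pass. Wrapping the teacher forward in torch.no_grad() would avoid storing its
# activations, but the forward steps below call it directly.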


def student_model_provider():
    print_rank_0('building AlphaEnhanced Student model ...')

    set_args('350M')  # parameters are randomly initialized ('load' is empty)
    model = EnhancedModel(num_tokentypes=0, parallel_output=True)
    return model


def get_batch(data_iterator):
    """Generate a batch."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Items and their type.
    keys = ['text']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, labels, loss_mask, attention_mask, position_ids


def forward_step(data_iterator, student_model):  # deprecated; pretrain() uses forward_step2 instead
    """Forward step: combine label cross entropy and teacher KL divergence."""
    global teacher_model
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch generator').start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
    timers('batch generator').stop()

    # Forward the teacher and the student on the same batch.
    # losses, _ = model(tokens, position_ids, attention_mask, labels=labels)
    timers('forward-teacher').start()
    teacher_logits = teacher_model(tokens, position_ids, attention_mask)
    timers('forward-teacher').stop()

    timers('forward-student').start()
    student_logits = student_model(tokens, position_ids, attention_mask)
    timers('forward-student').stop()

    # Cross entropy against the hard labels, per token: [batch, seq] (e.g. 8 x 1024).
    losses1 = mpu.vocab_parallel_cross_entropy(student_logits.float(), labels)

    # KL divergence between the student and teacher distributions.
    student_log_softmax = F.log_softmax(student_logits.float(), dim=2)
    teacher_softmax = F.softmax(teacher_logits.float(), dim=2)
    kl_result = kl_loss_function_none(student_log_softmax, teacher_softmax)  # [batch, seq, vocab] (e.g. 8 x 1024 x 40000)
    losses2 = torch.sum(kl_result, dim=2)  # [batch, seq] (e.g. 8 x 1024)

    # Distillation objective: alpha * cross entropy + (1 - alpha) * KL, per token.
    losses = args.alpha * losses1 + (1 - args.alpha) * losses2

    loss_mask = loss_mask.view(-1)
    loss_mask_sum = loss_mask.sum()
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask_sum
    # loss1 = torch.sum(losses1.view(-1) * loss_mask) / loss_mask_sum
    # loss2 = torch.sum(losses2.view(-1) * loss_mask) / loss_mask_sum
    # print(loss.tolist(), loss1.tolist(), loss2.tolist())

    # Reduce loss for logging.
    reduced_loss = reduce_losses([loss])
    return loss, {'lm loss': reduced_loss[0]}
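

# For reference only: a minimal single-GPU sketch of the same per-token
# distillation objective, written with dense logits and plain torch ops instead
# of mpu.vocab_parallel_cross_entropy. The function is hypothetical (nothing in
# this file calls it) and only illustrates the math that forward_step implements.
def _dense_distillation_loss_sketch(student_logits, teacher_logits, labels,
                                    loss_mask, alpha):
    # student_logits / teacher_logits: [batch, seq, vocab]; labels / loss_mask: [batch, seq]
    ce = F.cross_entropy(student_logits.float().transpose(1, 2), labels,
                         reduction='none')                       # [batch, seq]
    kl = F.kl_div(F.log_softmax(student_logits.float(), dim=-1),
                  F.softmax(teacher_logits.float(), dim=-1),
                  reduction='none').sum(dim=-1)                  # [batch, seq]
    losses = alpha * ce + (1 - alpha) * kl
    return torch.sum(losses.view(-1) * loss_mask.view(-1)) / loss_mask.sum()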


def forward_step2(data_iterator, model):
    """Forward step: the combined distillation loss is computed inside the model."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch generator').start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
    timers('batch generator').stop()

    # Forward the teacher, then let the student model combine the label loss and
    # the teacher logits internally (both labels and teacher_logits are passed in).
    timers('forward-teacher').start()
    teacher_logits = teacher_model(tokens, position_ids, attention_mask)
    timers('forward-teacher').stop()

    timers('forward-student').start()
    losses, _ = model(tokens, position_ids, attention_mask, labels=labels, teacher_logits=teacher_logits)
    timers('forward-student').stop()

    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

    # Reduce loss for logging.
    reduced_loss = reduce_losses([loss])

    return loss, {'lm loss': reduced_loss[0]}


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()

    print_rank_0('> building train, validation, and test datasets for AlphaEnhanced ...')
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        seq_length=args.seq_length,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup))
    print_rank_0("> finished creating AlphaEnhanced datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":
    def add_distillation_args(parser):
        group = parser.add_argument_group(title='distillation arguments')
        group.add_argument("--alpha", type=float, default=0.3,
                           help='alpha=0: pure KL distillation from the teacher; '
                                'alpha=1: pure label cross entropy, no teacher.')
        # Ideally the MODEL_CONFIG entries would also be adjustable through
        # command-line arguments added here.
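        # A possible extension (sketch only; these flags are not part of the
        # original script and nothing below reads them yet):
        #   group.add_argument("--student-model-size", type=str, default='350M',
        #                      choices=list(MODEL_CONFIG.keys()))
        #   group.add_argument("--teacher-model-size", type=str, default='2B6',
        #                      choices=list(MODEL_CONFIG.keys()))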
        return parser

    initialize_megatron(args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
                        extra_args_provider=add_distillation_args)
    load_teacher_model()

    pretrain(train_valid_test_datasets_provider, student_model_provider, forward_step2)
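
# Example launch (illustrative only; the script name and the exact Megatron
# flags depend on the local setup and Megatron version):
#   python -m torch.distributed.launch --nproc_per_node=8 \
#       pretrain_alpha_enhanced_distill.py \
#       --data-path <path-to-preprocessed-data> \
#       --seq-length 1024 --split 949,50,1 --alpha 0.3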