- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """
- PanGu-Alpha inference entry point: loads the model and runs generation or log-prob scoring.
- """
- import json
- import os
- import requests
- import datetime
- import glob
-
- import numpy as np
- from tqdm import tqdm
-
- import mindspore.common.dtype as mstype
- import mindspore.communication.management as D
- import mindspore as ms
- from mindspore import context, Tensor
- from mindspore import export
- from mindspore.context import ParallelMode
- from mindspore.parallel import set_algo_parameters
- from mindspore.parallel._cost_model_context import _set_multi_subgraphs
- from mindspore.train.model import Model
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
- from mindspore.nn.transformer.transformer import TransformerOpParallelConfig
-
- from src.generate import get_scores
- from src.pangu_alpha import EvalNet, PanguAlphaModel, EvalNet_200B
- from src.pangu_alpha_config import set_parse, PanguAlphaConfig
- from src.utils import get_args, download_ckpt_from_obs
-
-
- def restore_checkpoint(args_param, network, cache_url='/cache/Ckpt/'):
- r"""
- Load the checkpoint for this rank from cache_url into the network.
- """
- restore_ranks = D.get_rank()
- print("======start single checkpoint", flush=True)
- ckpt_name = os.path.join(cache_url, f"rank_{restore_ranks}.ckpt")
-
- if not os.path.exists(ckpt_name):
- print(f"There is no ckpt file at {ckpt_name}, so skip the loading.")
- return
-
- time_stamp = datetime.datetime.now()
- print(f"time stamp {time_stamp.strftime('%Y.%m.%d-%H:%M:%S')} pre trained ckpt model {ckpt_name} loading",
- flush=True)
- # Load the per-rank checkpoint file
- print(f'Start to load from {ckpt_name}')
- param_dict = load_checkpoint(ckpt_name)
- # for k, v in param_dict.items():
- # print("rank: ", restore_ranks, k)
- load_param_into_net(network, param_dict, strict_load=False)
-
- def set_auto_parallel_context(args_opt):
- """Set the auto parallel context"""
- rank = 0
- device_num = 1
- context.reset_auto_parallel_context()
- # context.set_auto_parallel_context(
- # strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path)
- if args_opt.distribute == "true":
- D.init()
- device_num = D.get_group_size()
- rank = D.get_rank()
- print("rank_id is {}, device_num is {}".format(rank, device_num))
- context.set_auto_parallel_context(
- parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
- gradients_mean=False,
- full_batch=True,
- loss_repeated_mean=True,
- enable_parallel_optimizer=False,
- strategy_ckpt_save_file=f'/cache/strategy_{rank}.ckpt',
- pipeline_stages=args_opt.stage_num)
- set_algo_parameters(elementwise_op_strategy_follow=True)
- _set_multi_subgraphs()
-
- return rank, device_num
-
- def load_model(args_opt):
- r"""
- Build the generation network, compile it, and restore the checkpoint.
- """
- context.set_context(mode=context.GRAPH_MODE)
- # Set parallel context
- rank, device_num = set_auto_parallel_context(args_opt)
-
- context.set_context(variable_memory_max_size="31GB")
- context.set_context(save_graphs=False,
- save_graphs_path="/cache/graphs_of_device_id_" + str(rank),
- device_target=args_opt.device_target)
-
- strategy_local_file = f"/cache/inference_strategy_100b_d8_mp8_dp1-{rank}.ckpt"
- ms.set_auto_parallel_context(strategy_ckpt_save_file=strategy_local_file)
-
- # if args_opt.eval_task:
- # use_past = False
- # else:
- # use_past = True if args_opt.export else (args_opt.use_past == "true")
- use_past = False
- print('local_rank:{}, start to run...'.format(rank), flush=True)
-
- # Set model property, rewrite the model parallel
- if device_num < args_opt.op_level_model_parallel_num:
- print(f"The op_level_model_parallel_num {args_opt.op_level_model_parallel_num} is smaller than the device num,"
- f"so change it to the {device_num}", flush=True)
- args_opt.op_level_model_parallel_num = device_num
- model_parallel_num = args_opt.op_level_model_parallel_num
- data_parallel_num = int(device_num / (model_parallel_num*args_opt.stage_num))
-
- parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num,
- model_parallel=model_parallel_num,
- pipeline_stage=args_opt.stage_num,
- micro_batch_num=args_opt.micro_size,
- vocab_emb_dp=False,
- recompute=False)
- # add sequence_parallel
- parallel_config.sequence_parallel = args_opt.sequence_parallel
- # add select_recompute
- parallel_config.select_recompute = args_opt.select_recompute
-
- per_batch_size = args_opt.per_batch_size
- batch_size = per_batch_size * data_parallel_num
- # Only a single batch size is currently supported for prediction
- if args_opt.run_type == "predict":
- batch_size = 1
-
- # Get the rank id used when downloading the checkpoint to the local cache
- D.init()
- device_num = D.get_group_size()
- rank_id = D.get_rank()
-
-
- softmax_compute_type = mstype.float16
- top_query_softmax = mstype.float16
- config = PanguAlphaConfig(
- batch_size=batch_size,
- seq_length=args_opt.seq_length,
- vocab_size=args_opt.vocab_size,
- hidden_size=args_opt.embedding_size,
- num_layers=args_opt.num_layers,
- num_heads=args_opt.num_heads,
- post_layernorm_residual=False,
- dropout_rate=0.0,
- ffn_hidden_size=args_opt.embedding_size * 4,
- use_past=use_past,
- eod_reset=False,
- parallel_config=parallel_config,
- load_ckpt_path=None,
- run_type=args_opt.run_type,
- param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16,
- use_rope=args_opt.use_rope,)
-
- config.softmax_compute_fp32 = softmax_compute_type
- config.top_query_softmax_fp32 = top_query_softmax
- print("===config is: ", config, flush=True)
- print("=====args_opt is: ", args_opt, flush=True)
-
- # Define network
- pangu_alpha = PanguAlphaModel(config)
-
- from src.pangu_alpha import PanGUAlphaLossWithPrompt
- from mindspore.nn.transformer import CrossEntropyLoss
- # loss = CrossEntropyLoss()
- # eval_net = PanGUAlphaLossWithPrompt(config, pangu_alpha, loss)
- eval_net = EvalNet_200B(pangu_alpha, pad_token=args_opt.padding_id)
- eval_net.set_train(False)
-
- # # Full-model checkpoint loading must happen before graph construction
- # import time
- # time.sleep((rank % 8)*20)
- # load_checkpoint(local_ckpt_path, net=eval_net)
-
- model_predict = Model(eval_net)
- # Compile network and obtain tensor layout for loading ckpt
- inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32)
-
- if args_opt.distribute == "false":
- predict_layout = None
- else:
- # Compiling only needs the shape
- current_index = Tensor(np.array([0]), mstype.int32)
- predict_layout = model_predict.infer_predict_layout(inputs_np, current_index)
-
- cache_url = '/cache/Ckpt/'
- download_ckpt_from_obs(args_opt, cache_url, rank=rank_id)
- restore_checkpoint(args_opt, eval_net, cache_url=cache_url)
- print("================load param ok=================", flush=True)
-
- return model_predict, config
-
- def load_model_probs(args_opt):
- r"""
- Build the log-prob scoring network, compile it, and restore the checkpoint.
- """
- context.set_context(mode=context.GRAPH_MODE)
- # Set parallel context
- rank, device_num = set_auto_parallel_context(args_opt)
-
- context.set_context(variable_memory_max_size="31GB")
- context.set_context(save_graphs=False,
- save_graphs_path="/cache/graphs_of_device_id_" + str(rank),
- device_target=args_opt.device_target)
-
- # strategy_local_file = f"/cache/inference_strategy_100b_d8_mp8_dp1-{rank}.ckpt"
- # ms.set_auto_parallel_context(strategy_ckpt_save_file=strategy_local_file)
-
- # if args_opt.eval_task:
- # use_past = False
- # else:
- # use_past = True if args_opt.export else (args_opt.use_past == "true")
- use_past = False
- print('local_rank:{}, start to run...'.format(rank), flush=True)
-
- # Set model property, rewrite the model parallel
- if device_num < args_opt.op_level_model_parallel_num:
- print(f"The op_level_model_parallel_num {args_opt.op_level_model_parallel_num} is smaller than the device num,"
- f"so change it to the {device_num}", flush=True)
- args_opt.op_level_model_parallel_num = device_num
- model_parallel_num = args_opt.op_level_model_parallel_num
- data_parallel_num = int(device_num / (model_parallel_num*args_opt.stage_num))
-
- parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num,
- model_parallel=model_parallel_num,
- pipeline_stage=args_opt.stage_num,
- micro_batch_num=args_opt.micro_size,
- vocab_emb_dp=False,
- recompute=False)
- # add sequence_parallel
- parallel_config.sequence_parallel = args_opt.sequence_parallel
- # add select_recompute
- parallel_config.select_recompute = args_opt.select_recompute
-
- per_batch_size = args_opt.per_batch_size
- batch_size = per_batch_size * data_parallel_num
- # Only a single batch size is currently supported for prediction
- if args_opt.run_type == "predict":
- batch_size = 1
-
- # Get the rank id used when downloading the checkpoint to the local cache
- D.init()
- rank_id = D.get_rank()
-
- softmax_compute_type = mstype.float16
- top_query_softmax = mstype.float16
- config = PanguAlphaConfig(
- batch_size=batch_size,
- seq_length=args_opt.seq_length,
- vocab_size=args_opt.vocab_size,
- hidden_size=args_opt.embedding_size,
- num_layers=args_opt.num_layers,
- num_heads=args_opt.num_heads,
- post_layernorm_residual=False,
- dropout_rate=0.0,
- ffn_hidden_size=args_opt.embedding_size * 4,
- use_past=use_past,
- eod_reset=False,
- parallel_config=parallel_config,
- load_ckpt_path=None,
- run_type=args_opt.run_type,
- param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16,
- use_rope=args_opt.use_rope,)
-
- config.softmax_compute_fp32 = softmax_compute_type
- config.top_query_softmax_fp32 = top_query_softmax
- print("===config is: ", config, flush=True)
- print("=====args_opt is: ", args_opt, flush=True)
-
- # Define network
- pangu_alpha = PanguAlphaModel(config)
-
- # # Full-model checkpoint loading must happen before graph construction
- from src.pangu_alpha import PanGUAlphaLossWith_notPrompt
- # from mindspore.nn.transformer import CrossEntropyLoss
- from src.loss import CrossEntropyLoss_eval
- loss = CrossEntropyLoss_eval()
- eval_net = PanGUAlphaLossWith_notPrompt(config,
- pangu_alpha,
- loss,
- pad_token=args_opt.padding_id,
- seq_length=args_opt.seq_length)
- eval_net.set_train(False)
- model_predict = Model(eval_net)
- # Compile network and obtain tensor layout for loading ckpt
- inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32)
- input_mask = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.float32)
- predict_layout = model_predict.infer_predict_layout(inputs_np, input_mask)
-
- cache_url = '/cache/Ckpt/'
- download_ckpt_from_obs(args_opt, cache_url, rank=rank_id)
- restore_checkpoint(args_opt, eval_net, cache_url=cache_url)
- print("================load param ok=================", flush=True)
-
- return model_predict, config
-
-
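- # The tokenizer model file is selected by the TOKEN_VERSION environment variable ('v1' or 'v2').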
- def get_local_tokenizer():
- from transformers import LlamaTokenizer
- MAIN_DIR = os.path.dirname(os.path.abspath(__file__))
- token_version = os.getenv('TOKEN_VERSION')
- if token_version == 'v1':
- vocab_file = f'{MAIN_DIR}/tokenizer/llama_vocab/llama_zh_hf/tokenizer.model'
- else:
- assert token_version == 'v2'
- vocab_file = f'{MAIN_DIR}/tokenizer/llama_vocab/llama_zh_hf/tokenizer_2.model'
- tokenizer = LlamaTokenizer.from_pretrained(vocab_file)
-
- return tokenizer
-
-
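- # get_model builds the generation network (EvalNet_200B); get_model_probs builds the log-prob scoring network.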
- def get_model():
-
- opt = get_args(True)
- set_parse(opt)
- model_predict, config = load_model(opt)
-
- return model_predict, config, opt
-
- def get_model_probs():
-
- opt = get_args(True)
- set_parse(opt)
- model_predict, config = load_model_probs(opt)
-
- return model_predict, config, opt
-
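- # Lazily builds and caches the model, config, args and tokenizer so they are constructed only once per process.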
- class ModelManager:
- def __init__(self, init=False):
- self.init = init
- self.model = None
- self.config = None
- self.opt = None
- self.tokenizer = None
-
-
- def get_my_tokenizer(self):
- if not self.init:
- self.tokenizer = get_local_tokenizer()
- return self.tokenizer
-
- def get_my_model(self):
- if not self.init:
- MODEL_PROBS = os.getenv('MODEL_PROBS')
- if MODEL_PROBS == 'TRUE':
- model, config, opt = get_model_probs()
- else:
- assert MODEL_PROBS == 'FALSE'
- model, config, opt = get_model()
- self.model, self.config, self.opt = model, config, opt
- return self.model, self.config, self.opt
-
- def init_finish(self):
- self.init = True
-
- model_manager = ModelManager(init=False)
-
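- # Serve one request. With logprobs=True, input_str must be a (text, input_length, mask_length)
- # tuple and the negated loss over the unmasked tokens is returned; otherwise up to
- # tokens_to_generate tokens are generated and decoded on rank 0.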
- def get_local_model_resp_one_item(input_str, tokens_to_generate, top_k=3, logprobs=False):
-
- model, config, opt = model_manager.get_my_model()
- tokenizer = model_manager.get_my_tokenizer()
- model_manager.init_finish()
-
- rank_id = D.get_rank()
-
- return_item = None
-
- if logprobs:
- input_str, input_length, mask_length = input_str
- # input_str = mask_str + pred_str
- # assert len(tokenizer.encode(mask_str, add_special_tokens=False)) == mask_length
-
- # Tokenize input sentence to ids
- start_sentence = tokenizer.encode(input_str, add_special_tokens=False)
- input_ids = np.array(start_sentence).reshape(1, -1)
- assert len(tokenizer.encode(input_str, add_special_tokens=False)) == input_length
- # assert len(tokenizer.encode(mask_str, add_special_tokens=False)) == mask_length
- if input_ids.shape[-1] > opt.seq_length:
- input_ids = input_ids[:, :opt.seq_length]
- pad_length = opt.seq_length - input_length
- input_ids = np.pad(input_ids, ((0, 0), (0, pad_length)), 'constant', constant_values=(0, opt.padding_id))
- input_tokens = Tensor(input_ids, mstype.int32)
-
- input_mask = np.array([[0] * mask_length +
- [1] * (input_length-mask_length) +
- [0] * (config.seq_length - input_length)]).reshape(1, -1)
- input_mask = Tensor(input_mask, mstype.float32)
- loss = model.predict(input_tokens, input_mask).asnumpy().tolist()
- if rank_id == 0:
- return_item = [0 - loss[0]]
-
-
- else:
- # generate_func = generate_increment if config.use_past else generate
- from src.generate import generate, generate_increment, generate_100b_task
-
- generate_func = generate_100b_task
-
- # Tokenize input sentence to ids
- start_sentence = tokenizer.encode(input_str, add_special_tokens=False)
- input_ids = np.array(start_sentence).reshape(1, -1)
-
-
- # Call inference
- output_ids = generate_func(model, input_ids, opt,
- top_p=0.0,
- top_k_num=top_k,
- max_generate_length=tokens_to_generate,
- duRepeate=True)
- if rank_id == 0:
- # Decode output ids to sentence
- return_item = tokenizer.decode(output_ids.tolist())
-
- return return_item
-
-
-
-
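- # Demo: one generation request, then one log-prob scoring request. Note that the logprobs
- # path expects input_str to be a (text, input_length, mask_length) tuple.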
- if __name__ == '__main__':
- input_str = "阅读:" + "是不是cad系统毛病 建议重新下载" + "问:" + "cad捕捉不到点 一直跳来跳去" + "?答:"
- output = get_local_model_resp_one_item(input_str, 100, logprobs=False)
- print(output)
- output = get_local_model_resp_one_item(input_str, 0, logprobs=True)
- print(output)
-
-