|
"""Run one inference prompt with the pcl_pangu ``alpha`` model, then fine-tune it.

At import time the script installs ``pcl_pangu`` with pip and prints two
diagnostics (the contents of ``/tmp/dataset`` and the ``nvidia-smi`` output).
When executed as a script it selects the PyTorch backend, runs a single
inference prompt against the loaded checkpoint, and then fine-tunes the model
on the given dataset, saving the result to the output directory.
"""
import argparse
import os

# Install pcl_pangu before importing it below. The install stays best-effort:
# a failure is only reported, and the imports that follow decide whether the
# script can actually continue.
if os.system('python -m pip install pcl_pangu') != 0:
    print('warning: "python -m pip install pcl_pangu" exited with a non-zero status')

# Diagnostics for the job log: dataset directory listing and GPU status.
os.system('ls /tmp/dataset')
os.system('nvidia-smi')

from pcl_pangu.context import set_context  # noqa: E402  (must follow the install)
from pcl_pangu.model import alpha  # noqa: E402


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options for model size and data/checkpoint paths.

    Returns:
        The parsed namespace with ``model``, ``data_url``, ``load`` and
        ``train_url`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model', default='350M', type=str, choices=['350M', '2B6', '13B'],
        help="setting model size from ['350M', '2B6', '13B']")
    parser.add_argument(
        '--data_url', default='/tmp/dataset/text_document', type=str,
        help="setting bin dataset text_document from: '/tmp/dataset'.")
    parser.add_argument(
        '--load', default='/tmp/dataset/', type=str,
        help="loading pretrained model ckpt, from: '/tmp/dataset'.")
    parser.add_argument(
        '--train_url', default='/tmp/output/', type=str,
        help="save your model to: '/tmp/output'.")
    return parser.parse_args()


def main() -> None:
    """Run the inference prompt and then fine-tune with the parsed options."""
    args = _parse_args()
    set_context(backend='pytorch')
    print(args)

    # Inference against the pretrained checkpoint.
    inference_config = alpha.model_config_gpu(model=args.model, load=args.load)
    alpha.inference(inference_config, input='四川的省会是?')

    # Fine-tune from the same checkpoint. The model size is passed explicitly
    # so the fine-tuning config matches the size chosen with --model (and used
    # for inference above) instead of falling back to the library default.
    # To train from scratch instead, build a config from the dataset path only
    # and call alpha.train(config).
    fine_tune_config = alpha.model_config_gpu(
        model=args.model,
        data_path=args.data_url,
        load=args.load,
        save=args.train_url,
    )
    alpha.fine_tune(fine_tune_config)

    # Emit a trailing blank line after the fine-tuning output.
    print()


if __name__ == '__main__':
    main()
|