|
- # -*- coding: UTF-8 -*-
- """
- -----------------------------------
- @Author : Encore
- @Date : 2022/9/5
- -----------------------------------
- """
- import os
- import argparse
- import logging
- import sys
-
# System path setup: make the directory containing this file importable.
# abspath() is applied first because __file__ can be a relative path (or a
# bare file name, whose dirname is ''), which would make the sys.path entry
# depend on the current working directory.
root = os.path.dirname(os.path.abspath(__file__))
# Skip the append when the entry is already present, so sys.path does not
# accumulate duplicates.
if root not in sys.path:
    sys.path.append(root)
-
# Command-line interface. Every option is a plain "--name <value>" flag, so
# the options are declared as (flag, type, default) rows and registered in
# one pass; row order is the order in which argparse lists them in --help.
parser = argparse.ArgumentParser()

_OPTION_ROWS = (
    # Path configuration.
    # (Previously used local alternatives for the two directories below:
    #  D:/pretrain/mbart-large-50-many-to-many-mmt and D:/data/whitepaper.)
    ("--pretrain_dir", str, "/userhome/pretrain/mbart-large-50-many-to-many-mmt"),
    ("--data_dir", str, "/userhome/data/whitepaper"),
    ("--save_path", str, "model.pt"),
    ("--load_path", str, "model.pt"),
    ("--log_path", str, "train.log"),
    # Hyperparameter configuration.
    # (Options "--max_length" (int, 66) and "--pct" (float, 0.7) are
    #  currently disabled.)
    ("--epoch", int, 3),
    ("--train_batch_size", int, 8),
    ("--dev_batch_size", int, 32),
    ("--test_batch_size", int, 32),
    ("--learning_rate", float, 2e-5),
    ("--warm_up_pct", float, 0.1),
    ("--cuda", int, 0),
)

for _flag, _value_type, _default in _OPTION_ROWS:
    parser.add_argument(_flag, type=_value_type, default=_default)

# Parsed once, at import time, from the process command line.
args = parser.parse_args()
-
# Derived file paths, attached to the parsed namespace so the rest of the
# program can read them like ordinary options.

# Files looked up inside the pretrained-model directory.
for _attr, _file_name in (
    ("vocab_path", "vocab.txt"),
    ("weight_path", "pytorch_model.bin"),
    ("config_path", "config.json"),
):
    setattr(args, _attr, os.path.join(args.pretrain_dir, _file_name))

# Train / dev files looked up inside the data directory
# (src_* attributes use the ".zh" files, tgt_* attributes the ".en" files).
for _attr, _file_name in (
    ("src_path", "train.zh"),
    ("tgt_path", "train.en"),
    ("dev_src_path", "dev.zh"),
    ("dev_tgt_path", "dev.en"),
):
    setattr(args, _attr, os.path.join(args.data_dir, _file_name))
-
-
- # print(dir(args))
-
-
# Logging setup.
# Append a newline to the log file before any handler writes to it
# (presumably to visually separate this run from earlier ones in the file).
with open(args.log_path, mode="a", encoding="utf-8") as f:
    f.write("\n")

# The root logger accepts everything from DEBUG up; each handler below then
# filters at INFO, so only INFO and above reach the file and stdout.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# Both outputs use the same record layout, so one formatter is shared.
formatter = logging.Formatter(
    "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s:%(message)s"
)

fh = logging.FileHandler(args.log_path, mode="a", encoding="utf-8")  # log file, append mode
sh = logging.StreamHandler(stream=sys.stdout)  # console output on stdout

for _handler in (fh, sh):
    _handler.setLevel(logging.INFO)
    _handler.setFormatter(formatter)
    logger.addHandler(_handler)
|