|
- # -*- coding: UTF-8 -*-
- """
- -----------------------------------
- @Author : Encore
- @Date : 2022/7/18
- -----------------------------------
- """
- import os
- import argparse
- import logging
- import sys
-
- # 系统路径设置
- root = os.path.dirname(os.path.dirname(__file__))
- sys.path.append(root)
- # print(sys.path)
-
- parser = argparse.ArgumentParser()
- # 路径配置
- # parser.add_argument("--pretrain_dir", type=str, default="D:/pretrain/pt_bert")
- parser.add_argument("--pretrain_dir", type=str, default="/userhome/pretrain/bert-base-uncased")
- parser.add_argument("--data_dir", type=str, default="/userhome/data/SST-2")
- # parser.add_argument("--data_dir", type=str, default="D:/data/SST-2")
- parser.add_argument("--save_path", type=str, default="/finetune.pt")
- parser.add_argument("--log_path", type=str, default="./train.log")
-
- # 超参数配置
- parser.add_argument("--epoch", type=int, default=3)
- parser.add_argument("--train_batch_size", type=int, default=32)
- parser.add_argument("--dev_batch_size", type=int, default=32)
- parser.add_argument("--test_batch_size", type=int, default=32)
- parser.add_argument("--max_length", type=int, default=66)
- parser.add_argument("--learning_rate", type=float, default=2e-5)
- parser.add_argument("--warm_up_pct", type=float, default=0.1)
-
- parser.add_argument("--fw_threshold", type=float, default=0.5)
- # parser.add_argument("--bw_threshold", type=float, default=0.00003)
- parser.add_argument("--bw_threshold", type=float, default=0.)
- parser.add_argument("--sparse_pct", type=float, default=0.9)
- parser.add_argument("--rank", type=float, default=1)
-
- args = parser.parse_args()
-
- args.vocab_path = os.path.join(args.pretrain_dir, "vocab.txt")
- args.weight_path = os.path.join(args.pretrain_dir, "pytorch_model.bin")
- args.config_path = os.path.join(args.pretrain_dir, "config.json")
-
- args.train_path = os.path.join(args.data_dir, "train.tsv")
- args.dev_path = os.path.join(args.data_dir, "dev.tsv")
- args.test_path = os.path.join(args.data_dir, "test.tsv")
- args.label_path = os.path.join(args.data_dir, "labels.tsv")
-
- # print(dir(args))
-
-
- # log设置
- with open(args.log_path, 'a', encoding="utf-8") as f:
- f.write('\n')
-
- logger = logging.getLogger()
- logger.setLevel(logging.DEBUG)
-
- fh = logging.FileHandler(args.log_path, 'a', encoding='utf-8')
- fh.setLevel(logging.INFO)
- formatter = logging.Formatter("%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s:%(message)s")
- fh.setFormatter(formatter)
- logger.addHandler(fh)
-
- sh = logging.StreamHandler(stream=sys.stdout)
- sh.setLevel(logging.INFO)
- formatter = logging.Formatter("%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s:%(message)s")
- sh.setFormatter(formatter)
- logger.addHandler(sh)
-
- # System path setup (disabled duplicate; the active version is at the top of this file)
- # root = os.path.dirname(os.path.dirname(__file__))
- # sys.path.append(root)
- # print(sys.path)
|