|
- # -*- coding: UTF-8 -*-
- """
- -----------------------------------
- @Author : Encore
- @Date : 2022/9/5
- -----------------------------------
- """
- import os
- import argparse
- import logging
- import sys
-
# System path setup: make modules that live next to this file importable no
# matter which working directory the script is launched from.
# abspath() matters because __file__ may be relative (e.g. "config.py" when the
# file is run directly on older Pythons); dirname() of that is "", which would
# put the current working directory on sys.path instead of this directory.
root = os.path.dirname(os.path.abspath(__file__))
# Avoid a duplicate entry when the directory is already on sys.path
# (e.g. the script's own directory when it is executed directly).
if root not in sys.path:
    sys.path.append(root)
-
def _build_parser():
    """Create the command-line parser for this script's settings.

    Options are registered in a fixed order (paths, hyper-parameters, rank
    spec) so ``--help`` always lists them the same way.
    """
    cli = argparse.ArgumentParser()

    # Path options (all strings). Local Windows alternatives used during
    # development: --pretrain_dir D:/pretrain/vit,
    # --data_dir D:/data/cifar-100-python
    path_options = [
        ("--pretrain_dir", "/userhome/pretrain/vit"),
        ("--data_dir", "/userhome/data/cifar-100-python"),
        ("--save_path", "../pic.pt"),
        ("--log_path", "../pic.log"),
    ]
    for flag, default in path_options:
        cli.add_argument(flag, type=str, default=default)

    # Hyper-parameter options: (flag, converter, default).
    hyper_options = [
        ("--epoch", int, 4),
        ("--train_batch_size", int, 32),
        ("--dev_batch_size", int, 32),
        ("--test_batch_size", int, 32),
        ("--learning_rate", float, 1e-5),
        ("--warm_up_pct", float, 0.1),
        ("--cuda", int, 0),
    ]
    for flag, converter, default in hyper_options:
        cli.add_argument(flag, type=converter, default=default)

    # Rank specification, parsed elsewhere into 12 per-layer ranks (per the
    # original author's note): groups are separated by "-", each group has the
    # form <layers>*<rank>, a layer count of 1 may omit the "1*" prefix, and a
    # rank of 0 means that layer is not decomposed.
    # Example: 4*32-4*1-3*32-64 (4 + 4 + 3 + 1 = 12 layers).
    cli.add_argument("--rank", type=str, default="12*0")
    return cli


parser = _build_parser()
# Parsed once, at import time, from sys.argv.
args = parser.parse_args()
-
# Dataset split locations, resolved under --data_dir. With the default
# cifar-100-python layout these are presumably the pickled "train"/"test"
# files -- confirm against the data loader.
args.train_path, args.test_path = (
    os.path.join(args.data_dir, split_name) for split_name in ("train", "test")
)
-
-
# Logging setup: the root logger forwards INFO-and-above records both to
# args.log_path (appended to) and to stdout.

# Write a blank line first so this run's records are visually separated from
# earlier runs already present in the append-mode log file.
with open(args.log_path, 'a', encoding="utf-8") as f:
    f.write('\n')

# Root logger at DEBUG; the effective threshold is set per handler below.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# One formatter shared by both handlers, so file and console output cannot
# drift apart when the format string is edited.
formatter = logging.Formatter("%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s:%(message)s")

# File handler: appends UTF-8 records to the same log file.
fh = logging.FileHandler(args.log_path, 'a', encoding='utf-8')
fh.setLevel(logging.INFO)
fh.setFormatter(formatter)
logger.addHandler(fh)

# Console handler: mirrors the same records to stdout.
sh = logging.StreamHandler(stream=sys.stdout)
sh.setLevel(logging.INFO)
sh.setFormatter(formatter)
logger.addHandler(sh)
|