|
- import sys
- import time
- from importlib import import_module
- import mxnet
- from train import train
- from train_bert import train_bert
- import torch
- import data_got
- import numpy as np
- import argparse
- import random
- import torch.backends.cudnn as cudnn
- import platform
-
-
- # log
class Logger(object):
    """File-like object that mirrors everything written to it into a log file.

    Intended to be assigned to ``sys.stdout`` so that ``print`` output shows up
    on the original stream and is also appended to ``fileN``.

    Args:
        fileN: Path of the log file to append to. When ``None`` the logger
            only forwards to the original stream and writes no file.
    """

    def __init__(self, fileN=None):
        # Capture the stream that is current at construction time; this is the
        # "real" stdout that output keeps being forwarded to.
        self.terminal = sys.stdout
        self.filename = fileN

    def write(self, message):
        """Write ``message`` to the original stream and append it to the log."""
        # Forward to the terminal first so console output is not lost when the
        # log file cannot be opened.
        self.terminal.write(message)
        if self.filename is None:
            return
        # Explicit UTF-8: the script prints non-ASCII text, which the platform
        # default encoding cannot always represent.
        with open(self.filename, 'a+', encoding='utf-8') as log:
            log.write(message)

    def flush(self):
        """Flush the original stream (the log file is closed after each write)."""
        self.terminal.flush()
-
-
- # 设置随机数
def setup_seed(seed):
    """Seed every random number generator this script uses with ``seed``.

    Covers Python's ``random``, NumPy, MXNet and PyTorch (CPU, the current
    CUDA device and all CUDA devices), then sets the cuDNN ``deterministic``
    flag and turns cuDNN ``benchmark`` off.
    """
    seeders = (
        random.seed,
        np.random.seed,
        mxnet.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Ask cuDNN for deterministic behaviour so repeated runs give the same result.
    torch.backends.cudnn.deterministic = True
    cudnn.benchmark = False
-
-
# Command-line interface. Boolean-like options are passed as the strings
# 'True' / 'False' and compared literally below.
parser = argparse.ArgumentParser(description='Text Classification')
parser.add_argument('--model', type=str, default='big_paper_model',
                    help='choose a model: big_paper_model,model,model_bert')
parser.add_argument('--log', type=str, default='False')
parser.add_argument('--BERT', type=str, default='True')

# NOTE(review): the default is a torch.device while a user-supplied value stays
# a plain str (argparse only applies `type` to string defaults) -- confirm
# that downstream code accepts both.
parser.add_argument('--device', type=str, default=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'))
# Default is None so an explicit --path can be told apart from "not given".
parser.add_argument('--path', type=str, default=None)
args = parser.parse_args()

# Pick the platform-specific default root. Any other platform falls back to
# '/', which was the original argparse default.
system_name = platform.system().lower()
default_path = '/'
if system_name == 'windows':
    default_path = './'
    print("windows")
elif system_name == 'linux':
    print("linux")

# Only fall back to the platform default when --path was not supplied;
# previously a user-provided --path was silently overwritten on Windows/Linux.
if args.path is None:
    args.path = default_path

# log
time_now = time.strftime('%Y-%m-%d_%H-%M-%S', time.localtime(time.time()))
if args.log == 'True':
    # Mirror stdout into <path>result/<model>+<timestamp>.txt
    sys.stdout = Logger(args.path + 'result/' + args.model + '+' + time_now + '.txt')

# Set the random seed
setup_seed(1)

# Load the model module chosen on the command line and build its config.
x = import_module('models.' + args.model)
config = x.Config()
config.device = args.device
config.path = args.path

print("模型:" + args.model)
print("路由算法:" + config.Routing)
print('loading data...\n')

# Both modes share the same load -> build -> train pipeline; they differ only
# in the loader, the training function, and the numpy -> tensor conversion.
use_bert = args.BERT == 'True'
load_data = data_got.load_aapd_bert_data if use_bert else data_got.load_aapd_data
run_training = train_bert if use_bert else train

train_loader, test_loader, embed, label_num, label_embed = load_data(config.path,
                                                                     batch_size=config.batch_size,
                                                                     max_length=config.pad_size)

if not use_bert:
    # The non-BERT path converts the embeddings from numpy arrays to float tensors.
    embed = torch.from_numpy(embed).float()
    label_embed = torch.from_numpy(label_embed).float()

print("load done")

model = x.Model(config, embed, label_num, label_embed)

run_training(config, model, train_loader, test_loader, epochs=config.epochs)
|