|
- # coding: UTF-8
- import time
- import torch
- import numpy as np
- from train_eval import train, init_network
- from importlib import import_module
- import argparse
-
- parser = argparse.ArgumentParser(description='Chinese Text Classification')
- parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer')
- parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')
- parser.add_argument('--word', default=False, type=bool, help='True for word, False for char')
- args = parser.parse_args()
-
-
- if __name__ == '__main__':
- dataset = 'THUCNews' # 数据集
-
- # 搜狗新闻:embedding_SougouNews.npz, 腾讯:embedding_Tencent.npz, 随机初始化:random
- embedding = 'embedding_SougouNews.npz'
- if args.embedding == 'random':
- embedding = 'random'
- model_name = args.model # 'TextRCNN' # TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer
- if model_name == 'FastText':
- from utils_fasttext import build_dataset, build_iterator, get_time_dif
- embedding = 'random'
- else:
- from utils import build_dataset, build_iterator, get_time_dif
-
- x = import_module('models.' + model_name)
- config = x.Config(dataset, embedding)
- np.random.seed(1)
- torch.manual_seed(1)
- torch.cuda.manual_seed_all(1)
- torch.backends.cudnn.deterministic = True # 保证每次结果一样
-
- start_time = time.time()
- print("Loading data...")
- vocab, train_data, dev_data, test_data = build_dataset(config, args.word)
- train_iter = build_iterator(train_data, config)
- dev_iter = build_iterator(dev_data, config)
- test_iter = build_iterator(test_data, config)
- time_dif = get_time_dif(start_time)
- print("Time usage:", time_dif)
-
- # train
- config.n_vocab = len(vocab)
- model = x.Model(config).to(config.device)
- if model_name != 'Transformer':
- init_network(model)
- print(model.parameters)
- train(config, model, train_iter, dev_iter, test_iter)
|