|
- # coding=utf-8
- # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
- # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """ Finetuning the library models for question-answering on SQuAD (Bert, XLM, XLNet)."""
-
- from __future__ import absolute_import, division, print_function
-
- import argparse
- import logging
- import os
- import random
- import glob
-
- import numpy as np
- import torch
- from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
- TensorDataset)
- from torch.utils.data.distributed import DistributedSampler
- from tqdm import tqdm, trange
-
- from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
- BertForQuestionAnswering, BertTokenizer,
- XLMConfig, XLMForQuestionAnswering,
- XLMTokenizer, XLNetConfig,
- XLNetForQuestionAnswering,
- XLNetTokenizer, RobertaConfig,
- RobertaTokenizer)
- from modeling_hotpotqa import RobertaForHotpotQA
-
- from pytorch_transformers import AdamW, WarmupLinearSchedule
-
- from utils_hotpotqa import (read_train_examples, read_eval_examples, convert_examples_to_features,
- RawResult, write_predictions,
- RawResultExtended, write_predictions_extended)
-
- from hotpotqa_loader import hotpotqa_joint_dataset, MyCollator
- from hotpotqa_utils_joint import *
-
# The following import is the official SQuAD evaluation script (2.0).
# You can remove it from the dependencies if you are using this script outside of the library.
# We've added it here for automated tests (see examples/test_examples.py file).
-
logger = logging.getLogger(__name__)

# All pretrained shortcut names known to the supported config classes
# (used only for the --model_name_or_path help text).
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig)), ())

# Maps the --model_type flag to its (config class, model class, tokenizer class)
# triple.  Note 'roberta' uses the project-local RobertaForHotpotQA head.
MODEL_CLASSES = {
    'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
    'roberta': (RobertaConfig, RobertaForHotpotQA, RobertaTokenizer),
    'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
    'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
}
-
def set_seed(args):
    """Seed every RNG in use (python, numpy, torch, cuda) from ``args.seed``."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
-
def to_list(tensor):
    """Detach ``tensor`` from the autograd graph and return it as nested Python lists."""
    detached = tensor.detach()
    return detached.cpu().tolist()
-
def train(args, train_dataset, model, tokenizer):
    """Train the model on ``train_dataset``.

    Returns:
        (global_step, tr_loss / global_step): number of optimizer updates
        performed and the mean loss per update.
    """

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    # MyCollator builds the per-batch graph tensors; the three flags select
    # which edge types (within-doc, question, across-doc) are included.
    train_collator = MyCollator(args.wdedge, args.quesedge, args.adedge)
    train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=train_collator)

    # t_total: total number of optimizer updates, either fixed via --max_steps
    # (epochs are then derived) or computed from --num_train_epochs.
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay).
    # Bias and LayerNorm weights are conventionally excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    # NOTE: --warmup_steps is used as a warmup *ratio* of t_total here
    # (see its help text in main()), not as an absolute step count.
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=int(args.warmup_steps*t_total), t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank,
                                                          find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            # Batch layout is produced by MyCollator; the index meanings below
            # are inferred from how the model consumes them — confirm against
            # hotpotqa_loader if the collator changes.
            inputs = {'input_ids': batch[0],
                      'input_mask': batch[1],
                      'segment_ids': None if args.model_type == 'xlm' else batch[2],
                      'start_positions': batch[8],
                      'end_positions': batch[9],
                      'adj_matrix': batch[4],
                      'graph_mask': batch[5],
                      'sent_start': batch[6],
                      'sent_end': batch[7],
                      'sp_label': batch[10],
                      'all_answer_type': batch[11],
                      'sent_sum_way': args.sent_sum_way,
                      'span_loss_weight': args.span_loss_weight,}
            if args.model_type in ['xlnet', 'xlm', 'roberta']:
                inputs.update({'p_mask': batch[3].float()})

            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel (not distributed) training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            tr_loss += loss.item()
            # Only step the optimizer every `gradient_accumulation_steps` batches.
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

            # The break appears twice: once to exit the batch loop, once more
            # below to exit the epoch loop.
            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    # NOTE(review): raises ZeroDivisionError if no optimizer update ever ran
    # (e.g. an empty dataloader) — acceptable for this script, but be aware.
    return global_step, tr_loss / global_step
-
-
def evaluate(args, model, tokenizer, prefix=""):
    """Run inference over the prediction file and write prediction JSONs.

    Writes supporting-fact predictions to 'output/predictions_sp.json' and
    answer-span predictions to 'output/predictions_ans.json' (plus n-best /
    null-odds files).  Returns nothing.
    """
    # NOTE(review): `json`, `process_logit` and `write_prediction` are not
    # imported by name in this module — presumably they come from
    # `from hotpotqa_utils_joint import *`; confirm.
    with open(args.predict_file, "r", encoding='utf-8') as reader:
        orig_data = json.load(reader)

    dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)

    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    # Note that DistributedSampler samples randomly
    # eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset)
    # eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
    eval_collator = MyCollator(args.wdedge, args.quesedge, args.adedge, is_training=False)
    # shuffle=False is required: example_indices below assume dataset order.
    eval_dataloader = DataLoader(dataset, batch_size=args.eval_batch_size, shuffle=False, collate_fn=eval_collator)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    all_results = []
    sp_preds = []
    answer_preds = []
    answer_type = []
    current_sample = 0  # running offset into `features`, tracks dataloader position
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        model.eval()
        batch = tuple(t.to(args.device) for t in batch)
        with torch.no_grad():
            inputs = {'input_ids': batch[0],
                      'input_mask': batch[1],
                      'segment_ids': None if args.model_type == 'xlm' else batch[2],  # XLM don't use segment_ids
                      'adj_matrix': batch[4],
                      'graph_mask': batch[5],
                      'sent_start': batch[6],
                      'sent_end': batch[7],
                      'sent_sum_way': args.sent_sum_way,
                      'span_loss_weight': args.span_loss_weight,
                      }
            # Feature indices for this batch; valid because shuffle=False above.
            example_indices = torch.arange(current_sample, current_sample + batch[0].size(0))
            if args.model_type in ['xlnet', 'xlm', 'roberta']:
                inputs.update({'p_mask': batch[3].float()})
            outputs = model(**inputs)

        # process_logit turns (sp logits, answer logits) into per-example
        # supporting-fact / answer / answer-type predictions.
        preds = process_logit(example_indices, (outputs[2], outputs[3]), features, examples, args.max_answer_length)
        sp_preds.extend(preds[0])
        answer_preds.extend(preds[1])
        answer_type.extend(preds[2])

        for i, example_index in enumerate(example_indices):
            eval_feature = features[example_index.item()]
            unique_id = int(eval_feature.unique_id)
            if args.model_type in ['xlnet', 'xlm']:
                # XLNet uses a more complex post-processing procedure
                result = RawResultExtended(unique_id = unique_id,
                                           start_top_log_probs = to_list(outputs[0][i]),
                                           start_top_index = to_list(outputs[1][i]),
                                           end_top_log_probs = to_list(outputs[2][i]),
                                           end_top_index = to_list(outputs[3][i]),
                                           cls_logits = to_list(outputs[4][i]))
            else:
                result = RawResult(unique_id = unique_id,
                                   start_logits = to_list(outputs[0][i]),
                                   end_logits = to_list(outputs[1][i]))
            all_results.append(result)

        current_sample += batch[0].size(0)

    write_prediction(sp_preds, answer_preds, orig_data, args.predict_file, 'output/predictions_sp.json')

    # Compute predictions
    output_prediction_file = "output/predictions_ans.json"
    output_nbest_file = "output/nbest_predictions_ans.json"
    if args.version_2_with_negative:
        output_null_log_odds_file = "output/null_odds_ans.json"
    else:
        output_null_log_odds_file = None

    if args.model_type in ['xlnet','xlm']:
        # XLNet uses a more complex post-processing procedure
        write_predictions_extended(examples, features, all_results, args.n_best_size,
                                   args.max_answer_length, output_prediction_file,
                                   output_nbest_file, output_null_log_odds_file, args.predict_file,
                                   model.config.start_n_top, model.config.end_n_top,
                                   args.version_2_with_negative, tokenizer, args.verbose_logging)
    else:
        write_predictions(examples, features, all_results, args.n_best_size,
                          args.max_answer_length, args.do_lower_case, output_prediction_file,
                          output_nbest_file, output_null_log_odds_file, args.verbose_logging,
                          args.version_2_with_negative, args.null_score_diff_threshold)

    # # Evaluate with the official SQuAD script
    # evaluate_options = EVAL_OPTS(data_file=args.predict_file,
    #                              pred_file=output_prediction_file,
    #                              na_prob_file=output_null_log_odds_file)

    #combine_hotpotqa(output_prediction_file, 'output/predictions_sp.json')

    # return results
-
-
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
    """Read HotpotQA examples and build (or load cached) input features.

    Args:
        args: parsed command-line arguments.
        tokenizer: pretrained tokenizer used to convert examples to features.
        evaluate: if True, read the predict file; otherwise the train file.
        output_examples: if True, also return the raw examples and features.

    Returns:
        A hotpotqa_joint_dataset, or (dataset, examples, features) when
        output_examples is True.
    """
    if args.local_rank not in [-1, 0] and not evaluate:
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

    # Model-specific tokenization conventions.
    cls_token_at_end = args.model_type == 'xlnet'  # XLNet puts the CLS token at the end
    if args.model_type == 'roberta':
        sep_token_extra = True        # RoBERTa uses a double separator between segments
        pad_token = 1                 # RoBERTa's <pad> token id
        sequence_b_segment_id = 0     # RoBERTa does not distinguish segments
    else:
        sep_token_extra = False
        # BUGFIX: these two were previously only assigned in the roberta
        # branch, so any other model type crashed with NameError below.
        # Defaults follow the BERT convention ([PAD] id 0, segment B id 1).
        pad_token = 0
        sequence_b_segment_id = 1

    # Load data features from cache or dataset file.
    input_file = args.predict_file if evaluate else args.train_file

    if evaluate:
        if args.is_gold:
            # With gold documents the dev file has the same structure as training data.
            examples = read_train_examples(input_file=args.predict_file, ner_file=args.predict_ner_file, is_gold=args.is_gold)[0]
        else:
            examples = read_eval_examples(input_file=args.predict_file, ner_file=args.predict_ner_file)[0]
    else:
        print("Only using {} squad examples for data augmentation!!!".format(args.squad_num))
        examples = read_train_examples(input_file=args.train_file, ner_file=args.train_ner_file, is_gold=args.is_gold, squad_num=args.squad_num)[0]

    # Cache key: split name, model name, and max sequence length.
    # (The original code called .format(args.squad_num) on a placeholder-free
    # string, which was a no-op; squad_num is intentionally not in the key.)
    cached_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format(
        'dev_submission' if evaluate else 'train_submission',
        list(filter(None, args.model_name_or_path.split('/'))).pop(),
        str(args.max_seq_length)))
    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", input_file)
        features = convert_examples_to_features(examples=examples,
                                                tokenizer=tokenizer,
                                                max_seq_length=args.max_seq_length,
                                                cls_token_at_end=cls_token_at_end,
                                                sep_token_extra=sep_token_extra,
                                                pad_token=pad_token,
                                                sequence_b_segment_id=sequence_b_segment_id,
                                                doc_stride=args.doc_stride,
                                                max_query_length=args.max_query_length,
                                                is_training=not evaluate)

    if args.local_rank == 0 and not evaluate:
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

    dataset = hotpotqa_joint_dataset(features)

    if output_examples:
        return dataset, examples, features
    return dataset
-
-
def main():
    """Parse command-line arguments, set up the environment, load the model,
    and run evaluation (this script variant only evaluates; see --do_train
    flags which are parsed but not acted on here)."""
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--train_file", default=None, type=str, required=True,
                        help="SQuAD json for training. E.g., train-v1.1.json")
    parser.add_argument("--train_ner_file", default=None, type=str, help="files give the ner and np in the original train file")
    parser.add_argument("--predict_file", default=None, type=str, required=True,
                        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
    parser.add_argument("--predict_ner_file", default=None, type=str,
                        help="files give the ner and np in the original dev/test file")
    parser.add_argument("--model_type", default=None, type=str, required=True,
                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model checkpoints and predictions will be written.")

    ## Other parameters
    parser.add_argument("--squad_model", action='store_true', help="Whether to use pre-finetuned squad model.")
    parser.add_argument("--squad_model_path", default="", type=str, help="pretrained squad model path")
    parser.add_argument("--is_gold", action='store_true', help="Whether to use gold documents for dev set.")
    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument("--tokenizer_name", default="", type=str,
                        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument("--cache_dir", default="", type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")

    parser.add_argument('--version_2_with_negative', action='store_true',
                        help='If true, the SQuAD examples contain some that do not have an answer.')
    parser.add_argument('--null_score_diff_threshold', type=float, default=0.0,
                        help="If null_score - best_non_null is greater than the threshold predict null.")

    parser.add_argument("--max_seq_length", default=384, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    parser.add_argument("--doc_stride", default=128, type=int,
                        help="When splitting up a long document into chunks, how much stride to take between chunks.")
    parser.add_argument("--max_query_length", default=64, type=int,
                        help="The maximum number of tokens for the question. Questions longer than this will "
                             "be truncated to this length.")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--evaluate_during_training", action='store_true',
                        help="Rul evaluation during training at each logging step.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--max_steps", default=-1, type=int,
                        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
    # NOTE: despite the name, this is consumed as a warmup *ratio* in train().
    parser.add_argument("--warmup_steps", default=0, type=float,
                        help="Linear warmup over warmup_steps. Actually warmup ratio")
    parser.add_argument("--n_best_size", default=20, type=int,
                        help="The total number of n-best predictions to generate in the nbest_predictions.json output file.")
    parser.add_argument("--max_answer_length", default=30, type=int,
                        help="The maximum length of an answer that can be generated. This is needed because the start "
                             "and end predictions are not conditioned on one another.")
    parser.add_argument("--verbose_logging", action='store_true',
                        help="If true, all of the warnings related to data processing will be printed. "
                             "A number of warnings are expected for a normal SQuAD evaluation.")

    parser.add_argument('--logging_steps', type=int, default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps', type=int, default=100000,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument("--eval_all_checkpoints", action='store_true',
                        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--overwrite_output_dir', action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument('--overwrite_cache', action='store_true',
                        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")

    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    # Graph-network (GNN) specific options.
    parser.add_argument('--no_gnn',
                        action='store_true',
                        help='If true, no GNN will be used.')
    parser.add_argument('--adjnorm',
                        action='store_true',
                        help='If true, apply adj normalization.')
    parser.add_argument('--hop',
                        type=int, default=1,
                        help="number of GNN hops.")
    parser.add_argument('--wdedge',
                        action='store_true',
                        help="whether to use within-doc edge")
    parser.add_argument('--adedge',
                        action='store_true',
                        help="whether to use across-doc edge")
    parser.add_argument('--quesedge',
                        action='store_true',
                        help="whether to use question edge")
    parser.add_argument('--sent_sum_way',
                        type=str, default="avg",
                        help="how to get sentence embedding from bert output")
    parser.add_argument('--span_from_sp',
                        action='store_true',
                        help="whether to get span logtis from sp logits")
    parser.add_argument('--sp_from_span',
                        action='store_true',
                        help="whether to get sp logits from span logits")
    parser.add_argument('--span_loss_weight',
                        type=float, default=1.0,
                        help="weight on span prediction loss")
    parser.add_argument('--gsn', action='store_true', help='whether to use GSN')
    parser.add_argument('--sent_with_cls', action='store_true', help='whether to append cls output to sent')
    parser.add_argument('--squad_num', type=int, default=0, help="how many suqad samples to use")
    args = parser.parse_args()

    print(args)

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
        raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    # Setup logging: only rank 0 (or single-process) logs at INFO level.
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
                   args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    # NOTE(review): model/config/tokenizer are loaded from the hardcoded local
    # path 'models/qa_model/'; --model_name_or_path is only used for the
    # from_tf check and the cache-file name.  Confirm this is intentional for
    # this submission script.
    config = config_class.from_pretrained('models/qa_model/')
    tokenizer = tokenizer_class.from_pretrained('models/qa_model/', do_lower_case=args.do_lower_case)
    model = model_class.from_pretrained("models/qa_model/", from_tf=bool('.ckpt' in args.model_name_or_path), config=config,
                                        num_hop=args.hop, no_gnn=args.no_gnn, num_rel=int(args.wdedge) + int(args.adedge) + int(args.quesedge), span_from_sp=args.span_from_sp,
                                        sp_from_span=args.sp_from_span, gsn=args.gsn, sent_with_cls=args.sent_with_cls)

    model.to(args.device)

    # Evaluate
    evaluate(args, model, tokenizer, prefix="")
-
-
# Script entry point.
if __name__ == "__main__":
    main()
|