#7 pass eval

Merged
xuezhongxuan merged 1 commits from wangjq/KTNET:master into master 3 years ago
  1. +3
    -1
      .gitignore
  2. +9
    -8
      run_KTNET_squad.py
  3. +3
    -3
      scripts/run_squad_twomemory.sh

+ 3
- 1
.gitignore View File

@@ -142,4 +142,6 @@ analyze_fail.dat
src/__pycache__/
src/cased_L-24_H-1024_A-16/
data/
kernel_meta/
kernel_meta/
output/
log/

+ 9
- 8
run_KTNET_squad.py View File

@@ -15,6 +15,7 @@ from utils.util import LossCallBack, make_directory, LoadNewestCkpt
from src.reader.squad_twomemory import DataProcessor, write_predictions
from src.dataset import create_squad_train_dataset, create_squad_dev_dataset

import mindspore
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore import log as logger
@@ -215,13 +216,13 @@ def do_eval(processor, eval_concept_settings, eval_output_name='eval_result.json
# input_data.append(data[i])
# input_mask, src_ids, pos_ids, sent_ids, wn_concept_ids, nell_concept_ids, unique_id = input_data
src_ids = np.squeeze(data[0])
pos_ids = np.squeeze(data[1])
sent_ids = np.squeeze(data[2])
wn_concept_ids = data[3]
nell_concept_ids = data[4]
input_mask = np.squeeze(data[5])
unique_id = data[6]
src_ids = Tensor(np.squeeze(data[0]), mindspore.int32)
pos_ids = Tensor(np.squeeze(data[1]), mindspore.int32)
sent_ids = Tensor(np.squeeze(data[2]), mindspore.int32)
wn_concept_ids = Tensor(data[3], mindspore.int32)
nell_concept_ids = Tensor(data[4], mindspore.int32)
input_mask = Tensor(np.squeeze(data[5]), mindspore.float32)
unique_id = Tensor(data[6], mindspore.int32)

pad = ops.Pad(((0, 0), (0, 0), (0, 3), (0, 0)))
nell_concept_ids = pad(nell_concept_ids)
@@ -244,7 +245,7 @@ def do_eval(processor, eval_concept_settings, eval_output_name='eval_result.json
start_logits=start_logits,
end_logits=end_logits))
logger.info("unique_id: %d" % unique_id)
# logger.info("unique_id: %d" % unique_id)

# callback.update(logits, unique_id)
if not os.path.exists(args.checkpoints):


+ 3
- 3
scripts/run_squad_twomemory.sh View File

@@ -42,10 +42,10 @@ NELL_CPT_EMBEDDING_PATH=data/KB_embeddings/nell_concept2vec.txt

python3 run_KTNET_squad.py \
--device_target "Ascend" \
--device_id 4 \
--device_id 5 \
--batch_size 6 \
--do_train true \
--do_predict false \
--do_train false \
--do_predict true \
--do_lower_case false \
--init_pretraining_params $BERT_DIR/params \
--load_pretrain_checkpoint_path $BERT_DIR/roberta.ckpt \


Loading…
Cancel
Save