|
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- import os
- import time
- from tracemalloc import is_tracing
- import numpy as np
- import cv2
- from src.mask_rcnn_r50 import MaskTextSpotter_Resnet50
- from src.config import config
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
- from mindspore import context, Tensor, export
- from src.dataset import data_to_mindrecord_byte_image, create_maskrcnn_dataset
- import argparse
-
# Command-line interface for the MaskTextSpotter export script.
parser = argparse.ArgumentParser(description="Masktextspotter")
# Optional dataset root; an empty string means "not provided on the command line".
parser.add_argument(
    "--icdar_root",
    default="",
    help="your dataset path",
)
-
def create_mindrecord_dir(model_prefix="", model_mindrecord_dir=None, is_training=True):
    """Ensure the MindRecord directory exists and convert the configured dataset.

    The source dataset is selected by ``config.dataset``:

    * ``"coco"``  -- requires ``config.coco_root`` to be a directory;
    * ``"icdar"`` -- requires ``config.icdar_root`` to be a directory;
    * anything else -- requires ``config.IMAGE_DIR`` to be a directory and
      ``config.ANNO_PATH`` to exist.

    The conversion itself is delegated to ``data_to_mindrecord_byte_image``.

    Args:
        model_prefix: MindRecord file-name prefix, forwarded to
            ``data_to_mindrecord_byte_image``.
        model_mindrecord_dir: Directory that is created if missing and is
            reported in the completion message. Required despite the ``None``
            default.
        is_training: Forwarded to ``data_to_mindrecord_byte_image`` as its
            second positional argument. (Previously this parameter was
            ignored and ``True`` was always passed.)

    Raises:
        TypeError: If ``model_mindrecord_dir`` is ``None``.
        FileNotFoundError: If the source data for the selected dataset is
            missing. Previously only the fallback branch raised. The coco and
            icdar branches silently did nothing, and the problem surfaced
            later as a failure to open the never-written MindRecord file.
    """
    if model_mindrecord_dir is None:
        raise TypeError("model_mindrecord_dir must be a directory path, got None")
    # exist_ok avoids the check-then-create race of isdir() followed by makedirs().
    os.makedirs(model_mindrecord_dir, exist_ok=True)

    if config.dataset == "coco":
        dataset_name = "coco"
        if not os.path.isdir(config.coco_root):
            raise FileNotFoundError("coco_root {} does not exist.".format(config.coco_root))
    elif config.dataset == "icdar":
        dataset_name = "icdar"
        if not os.path.isdir(config.icdar_root):
            raise FileNotFoundError("icdar_root {} does not exist.".format(config.icdar_root))
    else:
        dataset_name = "other"
        if not (os.path.isdir(config.IMAGE_DIR) and os.path.exists(config.ANNO_PATH)):
            raise FileNotFoundError("IMAGE_DIR or ANNO_PATH does not exist.")

    print("Create Mindrecord.")
    # NOTE(review): model_mindrecord_dir is not passed to the converter. It
    # presumably derives its output path from config -- confirm the two agree.
    data_to_mindrecord_byte_image(dataset_name, is_training, model_prefix)
    print("Create Mindrecord Done, at {}".format(model_mindrecord_dir))
-
# Device selection was previously hard-coded. These options keep the old values
# as defaults, so existing invocations behave the same.
parser.add_argument('--device_id', type=int, default=1, help='device id to run on')
parser.add_argument('--device_target', default='Ascend', help='device target, e.g. Ascend')
args = parser.parse_args()

# Honour --icdar_root. It used to be parsed but never read, so the flag had no
# effect. create_mindrecord_dir reads config.icdar_root, so the override is
# written back onto config.
# NOTE(review): assumes config supports attribute assignment (e.g. EasyDict) -- confirm.
if args.icdar_root:
    config.icdar_root = args.icdar_root

context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)

net = MaskTextSpotter_Resnet50(config=config)
net.set_train(False)

prefix = "MaskRcnn.mindrecord"
mindrecord_dir = os.path.join(config.icdar_root, config.mindrecord_test_dir)
mindrecord_file = os.path.join(mindrecord_dir, prefix)
if not os.path.exists(mindrecord_file):
    # NOTE(review): is_training=True for a *test* MindRecord mirrors the
    # original call; verify this is the intended conversion mode.
    create_mindrecord_dir(model_prefix=prefix, model_mindrecord_dir=mindrecord_dir, is_training=True)
ds = create_maskrcnn_dataset(mindrecord_file,
                             batch_size=config.test_batch_size, is_training=False)
dataset_size = ds.get_dataset_size()

print("\n========================================\n")
print("total images num: ", dataset_size)
print("Processing, please wait a moment.")

# One batch is enough to fix the input signature for export. The previous loop
# ran forward + export three times, overwriting the same output file each time.
data = next(iter(ds.create_dict_iterator(output_numpy=True, num_epochs=1)), None)
if data is None:
    raise RuntimeError("Dataset at {} is empty; nothing to export.".format(mindrecord_file))

# Network inputs, in the positional order the network's forward call expects.
input_keys = ("image", "image_shape", "box", "label", "valid_num", "mask_gt", "mask_char")
input_data = [Tensor(data[key]) for key in input_keys]

start = time.time()
# Run the network once as a sanity check that it accepts these inputs before exporting.
net(*input_data)
export(net, *input_data, file_name=config.file_name, file_format=config.file_format)
end = time.time()
print("Export cost time {}".format(end - start))
|