# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

- import os
- import json
- import pickle
- import numpy as np
- import mindspore as ms
- from mindspore import Tensor
- from src.config import eval_config
- from src.greedydecoder import MSGreedyDecoder
-
def get_target_indices(filePath):
    """Read reference transcripts from *filePath*, one per line.

    Args:
        filePath: path to a UTF-8 text file with one reference string per line.

    Returns:
        list[str]: the lines of the file with surrounding whitespace stripped
        (blank lines become empty strings).
    """
    with open(filePath, 'r', encoding='utf-8') as f:
        # Comprehension replaces the manual append loop (same order, same values).
        return [refer.strip() for refer in f]
-
def generate_output():
    """Post-process exported inference results and report average WER/CER.

    Reads per-batch logits binaries from ``result_Files_20`` and the matching
    output-length binaries from ``./preprocess_Result_20/outputlength_data``,
    greedy-decodes each batch, and scores the transcripts against the
    references in ``target_20.txt``. Per-sample metrics are printed when
    ``config.verbose`` is set; ``output_data`` is pickled to
    ``config.save_output + '.bin'`` when configured.

    Raises:
        NotImplementedError: if a decoder other than 'greedy' is configured.
    """
    post_result_path = "result_Files_20"
    targets_file = "target_20.txt"
    config = eval_config
    # One exported .bin file per inference batch (len() is already an int).
    file_num = len(os.listdir(post_result_path))
    with open(config.DataConfig.labels_path, encoding='utf-8') as label_file:
        labels = json.load(label_file)
    if config.LMConfig.decoder_type == 'greedy':
        decoder = MSGreedyDecoder(labels=labels, blank_index=labels.index('_'))
    else:
        raise NotImplementedError("Only greedy decoder is supported now")
    total_cer, total_wer, num_tokens, num_chars = 0, 0, 0, 0
    # NOTE(review): output_data is pickled below but nothing is ever appended,
    # so the saved file always contains an empty list — confirm intent.
    output_data = []
    start = 0
    batch_size = config.DataConfig.batch_size  # loop-invariant, hoisted
    target_strings = get_target_indices(targets_file)
    for i in range(file_num):
        out_name = "deepspeech2_" + str(batch_size) + "_" + str(i) + "_0.bin"
        output_sizes_name = "deepspeech2_" + str(batch_size) + "_" + str(i) + ".bin"
        out = np.fromfile(os.path.join(post_result_path, out_name), np.float32)
        # Lengths were exported as float32; they are cast to int32 below.
        output_sizes = np.fromfile(os.path.join("./preprocess_Result_20/outputlength_data",
                                                output_sizes_name), np.float32)
        # (seq_len, batch, 29) -> (batch, seq_len, 29); 29 is the label-set size.
        out = out.reshape(-1, batch_size, 29)
        out = out.transpose(1, 0, 2)
        out, output_sizes = Tensor(out), Tensor(output_sizes, ms.int32)
        decoded_output, _ = decoder.decode(out, output_sizes)
        # The last batch may be short if targets are not a multiple of batch_size.
        target_string = target_strings[start:start + batch_size]
        start += batch_size
        for j in range(len(target_string)):
            transcript, reference = decoded_output[j][0], target_string[j]
            wer_inst = decoder.wer(transcript, reference)
            cer_inst = decoder.cer(transcript, reference)
            total_wer += wer_inst
            total_cer += cer_inst
            num_tokens += len(reference.split())
            num_chars += len(reference.replace(' ', ''))
            if config.verbose:
                # Per-sample rates assume a non-empty reference — an empty
                # reference line would divide by zero here.
                print("Ref:", reference.lower())
                print("Hyp:", transcript.lower())
                print("WER:", float(wer_inst) / len(reference.split()),
                      "CER:", float(cer_inst) / len(reference.replace(' ', '')), "\n")
    # Guard the aggregate rates so an empty result set reports 0 instead of
    # raising ZeroDivisionError.
    wer = float(total_wer) / num_tokens if num_tokens else 0.0
    cer = float(total_cer) / num_chars if num_chars else 0.0

    print('Test Summary \t'
          'Average WER {wer:.3f}\t'
          'Average CER {cer:.3f}\t'.format(wer=wer * 100, cer=cer * 100))

    if config.save_output is not None:
        with open(config.save_output + '.bin', 'wb') as output:
            pickle.dump(output_data, output)
-
- if __name__ == "__main__":
- generate_output()
|