|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ===========================================================================
- """
- network config setting, will be used in train.py and eval.py
- """
import copy
import inspect
import logging
import os
import typing

import yaml
from easydict import EasyDict as ed

from src.model import JasperBlock, JasperDecoderForCTC, JasperEncoder
-
# Absolute directory holding this module (symlinks resolved). Bundled data
# files such as the label list and the Jasper YAML are located relative to it.
src_file_root = os.path.split(os.path.realpath(__file__))[0]
labels_json_path = os.path.join(src_file_root, 'labels.json')
-
# Training-time settings consumed as an attribute-accessible EasyDict.
train_config = ed(dict(
    TrainingConfig=dict(
        # NOTE(review): a full run previously used 440 epochs; 2 looks like a
        # smoke-test value.
        epochs=2,
    ),
    DataConfig=dict(
        # NOTE(review): these point at the LibriSpeech *test-clean* data. The
        # full-training setup that used to be here was
        #   Data_dir = '/home/work/user-job-dir/data/processed_data'
        #   train_manifest = librispeech-train-clean-100-wav.json,
        #                    librispeech-train-clean-360-wav.json,
        #                    librispeech-train-other-500-wav.json
        #   (all under that Data_dir). Confirm before a real training run.
        Data_dir='/home/work/user-job-dir/data/test_data',
        train_manifest=[
            '/home/work/user-job-dir/data/test_data/librispeech-test-clean-wav.json',
        ],
        batch_size=30,
        labels_path=labels_json_path,
        SpectConfig=dict(
            sample_rate=16000,
            window_size=0.02,
            window_stride=0.01,
            # NOTE(review): eval_config uses "hanning" for this key; confirm the
            # train/eval window mismatch is intended.
            window="hamming",
        ),
        AugmentationConfig=dict(
            speed_volume_perturb=False,
            spec_augment=False,
            noise_dir='',
            noise_prob=0.4,
            noise_min=0.0,
            noise_max=0.5,
        ),
    ),
    OptimConfig=dict(
        learning_rate=0.01,
        learning_anneal=1.1,
        weight_decay=1e-5,
        momentum=0.9,
        eps=1e-8,
        betas=(0.9, 0.999),
        loss_scale=1024,
        # NOTE(review): both `eps` and `epsilon` are defined with different
        # values; confirm which one the optimizer actually reads.
        epsilon=0.00001,
    ),
    CheckpointConfig=dict(
        ckpt_file_name_prefix='Jasper',
        ckpt_path='./checkpoint',
        keep_checkpoint_max=10,
    ),
))
-
# Evaluation-time settings consumed as an attribute-accessible EasyDict.
eval_config = ed(dict(
    save_output='librispeech_val_output',
    verbose=True,
    DataConfig=dict(
        Data_dir='/disk2/wx/dataset/LibriSpeech',
        # Other manifests used previously: './data/libri_test_clean_manifest.csv',
        # 'data/libri_test_other_manifest.csv', 'data/libri_val_manifest.csv'.
        test_manifest=['/disk2/wx/dataset/LibriSpeech/librispeech-dev-clean-wav.json'],
        batch_size=32,
        labels_path=labels_json_path,
        SpectConfig=dict(
            sample_rate=16000,
            window_size=0.02,
            window_stride=0.01,
            # NOTE(review): train_config uses "hamming" for this key; confirm the
            # train/eval window mismatch is intended.
            window="hanning",
        ),
    ),
    # NOTE(review): decoder_type is "greedy"; the LM / beam settings here
    # presumably only matter for a beam-search decoder. Verify against eval.py.
    LMConfig=dict(
        decoder_type="greedy",
        lm_path='./3-gram.pruned.3e-7.arpa',
        top_paths=1,
        alpha=1.818182,
        beta=0,
        cutoff_top_n=40,
        cutoff_prob=1.0,
        beam_width=1024,
        lm_workers=4,
    ),
))
-
-
def default_args(klass):
    """Map each parameter of ``klass.__init__`` to its default value.

    Parameters without a default map to ``inspect.Parameter.empty``; the
    validation code treats that sentinel as "value must be supplied".

    ``self`` and variadic ``*args`` / ``**kwargs`` parameters are skipped:
    they can never be bound by a keyword argument, so keeping them would make
    validation demand a value that cannot be provided.

    Args:
        klass: class whose constructor signature is inspected.

    Returns:
        dict of parameter name -> default (or ``inspect.Parameter.empty``),
        in signature order.
    """
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    sig = inspect.signature(klass.__init__)
    return {
        name: param.default
        for name, param in sig.parameters.items()
        if name != 'self' and param.kind not in variadic
    }
-
-
def load(fpath):
    """Read a YAML config file into plain Python objects.

    Args:
        fpath: path to a ``.yaml`` config file.

    Returns:
        The parsed config, with every YAML alias expanded into an independent
        copy.

    Raises:
        ValueError: if *fpath* names a legacy ``.toml`` config.
    """
    if fpath.endswith('.toml'):
        raise ValueError('.toml config format has been changed to .yaml')

    # Context manager so the file handle is closed even if parsing fails.
    with open(fpath, 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f)

    # Nodes reached through YAML anchors/aliases are shared objects after
    # loading, so mutating one key would silently mutate the other. Dumping
    # with aliases disabled and reloading gives each key its own copy.
    # (copy.deepcopy would preserve the sharing, so it cannot be used here.)
    # A local Dumper subclass avoids monkeypatching yaml.Dumper globally,
    # which would change every other yaml.dump call in the process.
    class _NoAliasDumper(yaml.Dumper):
        def ignore_aliases(self, data):
            return True

    return yaml.safe_load(yaml.dump(cfg, Dumper=_NoAliasDumper))
-
-
def validate_and_fill(klass, user_conf, ignore_unk=(), optional=()):
    """Merge *user_conf* over the constructor defaults of *klass* and validate.

    Args:
        klass: class whose ``__init__`` signature defines the accepted keys.
        user_conf: mapping of user-supplied values; these override defaults.
        ignore_unk: keys accepted (and copied into the result) even though
            *klass* does not declare them.
        optional: parameters that may remain unset; if still unset they are
            dropped from the result instead of failing validation.

    Returns:
        dict of parameter name -> value.

    Raises:
        AssertionError: on an unknown key, or on a required parameter that
            has no value.
    """
    conf = default_args(klass)

    for k, v in user_conf.items():
        # Explicit raise instead of ``assert`` so validation still runs under
        # ``python -O``; AssertionError is kept for existing callers.
        if k not in conf and k not in ignore_unk:
            raise AssertionError(f'Unknown parameter {k} for {klass}')
        conf[k] = v

    # Keep only mandatory or optional-nonempty
    conf = {k: v for k, v in conf.items()
            if k not in optional or v is not inspect.Parameter.empty}

    # Anything still holding the "no default" sentinel was never supplied.
    for k, v in conf.items():
        if v is inspect.Parameter.empty:
            raise AssertionError(f'Value for {k} not specified for {klass}')
    return conf
-
-
def encoder(conf):
    """Validate the encoder section of a Jasper config; return JasperEncoder kwargs.

    Each entry of ``conf['jasper']['encoder']['blocks']`` is checked against the
    JasperBlock constructor. ``infilters`` may be omitted and ``residual_dense`` is
    tolerated as an extra key. The per-block results are discarded, so the block
    dicts are left exactly as written rather than filled with defaults. The
    encoder section itself is then validated and filled with JasperEncoder's
    constructor defaults.
    """
    encoder_section = conf['jasper']['encoder']
    for block_conf in encoder_section['blocks']:
        validate_and_fill(
            JasperBlock,
            block_conf,
            optional=['infilters'],
            ignore_unk=['residual_dense'],
        )
    return validate_and_fill(JasperEncoder, encoder_section)
-
-
def decoder(conf, n_classes):
    """Build validated JasperDecoderForCTC kwargs from ``conf['jasper']['decoder']``.

    ``n_classes`` is seeded first. If the config section also contains an
    ``n_classes`` key, that value takes precedence over the argument.
    """
    kwargs = {'n_classes': n_classes}
    kwargs.update(conf['jasper']['decoder'])
    return validate_and_fill(JasperDecoderForCTC, kwargs)
-
-
def add_ctc_blank(symbols, blank='_'):
    """Return a new list of *symbols* with the CTC blank token appended last.

    Args:
        symbols: any iterable of label symbols (list, tuple, ...). It is not
            modified.
        blank: the token appended as the blank class; defaults to ``'_'``.

    Returns:
        A new list whose last element is *blank*.
    """
    return [*symbols, blank]
-
-
# Load the Jasper model definition that sits next to this module, once at
# import time, and expose the validated encoder/decoder kwargs and the output
# alphabet.
jasper_yaml_path = os.path.join(src_file_root, 'jasper10x5dr_speca.yaml')
# Log rather than print: an unconditional print writes to stdout in every
# process that imports this config module.
logging.getLogger(__name__).info('Loading Jasper config from %s', jasper_yaml_path)
cfg = load(jasper_yaml_path)

# Output alphabet = YAML labels + the CTC blank, hence one extra decoder class.
symbols = add_ctc_blank(cfg['labels'])
encoder_kw = encoder(cfg)
decoder_kw = decoder(cfg, n_classes=len(symbols))
|