|
- from collections import OrderedDict
-
- import torch
-
- import utils
- from models import SynthesizerTrn
-
-
- def copyStateDict(state_dict):
- if list(state_dict.keys())[0].startswith('module'):
- start_idx = 1
- else:
- start_idx = 0
- new_state_dict = OrderedDict()
- for k, v in state_dict.items():
- name = ','.join(k.split('.')[start_idx:])
- new_state_dict[name] = v
- return new_state_dict
-
-
- def removeOptimizer(config: str, input_model: str, ishalf: bool, output_model: str):
- hps = utils.get_hparams_from_file(config)
-
- net_g = SynthesizerTrn(hps.data.filter_length // 2 + 1,
- hps.train.segment_size // hps.data.hop_length,
- **hps.model)
-
- optim_g = torch.optim.AdamW(net_g.parameters(),
- hps.train.learning_rate,
- betas=hps.train.betas,
- eps=hps.train.eps)
-
- state_dict_g = torch.load(input_model, map_location="cpu")
- new_dict_g = copyStateDict(state_dict_g)
- keys = []
- for k, v in new_dict_g['model'].items():
- if "enc_q" in k: continue # noqa: E701
- keys.append(k)
-
- new_dict_g = {k: new_dict_g['model'][k].half() for k in keys} if ishalf else {k: new_dict_g['model'][k] for k in keys}
-
- torch.save(
- {
- 'model': new_dict_g,
- 'iteration': 0,
- 'optimizer': optim_g.state_dict(),
- 'learning_rate': 0.0001
- }, output_model)
-
-
- if __name__ == "__main__":
- import argparse
- parser = argparse.ArgumentParser()
- parser.add_argument("-c",
- "--config",
- type=str,
- default='configs/config.json')
- parser.add_argument("-i", "--input", type=str)
- parser.add_argument("-o", "--output", type=str, default=None)
- parser.add_argument('-hf', '--half', action='store_true', default=False, help='Save as FP16')
-
- args = parser.parse_args()
-
- output = args.output
-
- if output is None:
- import os.path
- filename, ext = os.path.splitext(args.input)
- half = "_half" if args.half else ""
- output = filename + "_release" + half + ext
-
- removeOptimizer(args.config, args.input, args.half, output)
|