|
- from collections import OrderedDict
- from text.symbols import symbols
- import torch
-
- from tools.log import logger
- import utils
- from models import SynthesizerTrn
- import os
-
-
def copyStateDict(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module`` key component removed.

    The first key decides how every key is treated: if it starts with
    ``"module"``, the first dot-separated component of each key is dropped;
    otherwise all keys are copied unchanged.

    Args:
        state_dict: Mapping with string keys.

    Returns:
        An ``OrderedDict`` holding the same values under the rewritten keys,
        in the original iteration order. An empty mapping yields an empty
        ``OrderedDict``.
    """
    # next() with a default avoids an IndexError on an empty mapping.
    first_key = next(iter(state_dict), "")
    start_idx = 1 if first_key.startswith("module") else 0

    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        # Re-join with "." (the separator the key was split on) so dotted
        # names such as "enc.weight" are preserved.
        new_state_dict[".".join(key.split(".")[start_idx:])] = value
    return new_state_dict
-
-
def removeOptimizer(config: str, input_model: str, ishalf: bool, output_model: str):
    """Write a reduced copy of a checkpoint to ``output_model``.

    A ``SynthesizerTrn`` is built from the hyper-parameters in ``config`` and
    a new ``AdamW`` optimizer is created over its parameters. The checkpoint
    at ``input_model`` is loaded onto the CPU, every entry of its ``"model"``
    mapping whose name contains ``"enc_q"`` is left out, and the remaining
    tensors are optionally converted with ``.half()``.

    Args:
        config: Path to the hyper-parameter file read by
            ``utils.get_hparams_from_file``.
        input_model: Path of the checkpoint to load.
        ishalf: When true, each kept tensor is converted with ``.half()``.
        output_model: Path the reduced checkpoint is saved to.
    """
    hparams = utils.get_hparams_from_file(config)

    generator = SynthesizerTrn(
        len(symbols),
        hparams.data.filter_length // 2 + 1,
        hparams.train.segment_size // hparams.data.hop_length,
        n_speakers=hparams.data.n_speakers,
        **hparams.model,
    )

    # Newly created optimizer: its state dict replaces whatever optimizer
    # state the input checkpoint carried.
    fresh_optimizer = torch.optim.AdamW(
        generator.parameters(),
        hparams.train.learning_rate,
        betas=hparams.train.betas,
        eps=hparams.train.eps,
    )

    checkpoint = copyStateDict(torch.load(input_model, map_location="cpu"))
    weights = checkpoint["model"]

    # Keep every entry except those whose name mentions "enc_q".
    kept_names = [name for name in weights if "enc_q" not in name]

    if ishalf:
        slim_weights = {name: weights[name].half() for name in kept_names}
    else:
        slim_weights = {name: weights[name] for name in kept_names}

    payload = {
        "model": slim_weights,
        "iteration": 0,
        "optimizer": fresh_optimizer.state_dict(),
        "learning_rate": 0.0001,
    }
    torch.save(payload, output_model)
-
-
if __name__ == "__main__":
    # Command-line entry point: slim a checkpoint and log where it was written.
    import argparse
    import os.path

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        default="configs/config.json",
        help="Path to the hyper-parameter config file",
    )
    # Required: without an input path the default output name cannot be
    # derived and loading would fail with an opaque TypeError.
    parser.add_argument(
        "-i", "--input", type=str, required=True, help="Checkpoint to compress"
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default=None,
        help="Output path (default: <input>_release[_half]<ext>)",
    )
    parser.add_argument(
        "-hf", "--half", action="store_true", default=False, help="Save as FP16"
    )

    args = parser.parse_args()

    output = args.output

    if output is None:
        # Derive "<input stem>_release[_half]<ext>" from the input path.
        filename, ext = os.path.splitext(args.input)
        half = "_half" if args.half else ""
        output = filename + "_release" + half + ext

    removeOptimizer(args.config, args.input, args.half, output)
    logger.info(f"压缩模型成功, 输出模型: {os.path.abspath(output)}")
|