|
- # flake8: noqa: E402
- import platform
- import os
- import torch
- from torch.nn import functional as F
- from torch.utils.data import DataLoader
- from torch.utils.tensorboard import SummaryWriter
- import torch.distributed as dist
- from torch.nn.parallel import DistributedDataParallel as DDP
- from torch.cuda.amp import autocast, GradScaler
- from tqdm import tqdm
- import logging
- from config import config
- import argparse
- import datetime
-
- logging.getLogger("numba").setLevel(logging.WARNING)
- import commons
- import utils
- from data_utils import (
- TextAudioSpeakerLoader,
- TextAudioSpeakerCollate,
- DistributedBucketSampler,
- )
- from models import (
- SynthesizerTrn,
- MultiPeriodDiscriminator,
- DurationDiscriminator,
- WavLMDiscriminator,
- )
- from losses import (
- generator_loss,
- discriminator_loss,
- feature_loss,
- kl_loss,
- WavLMLoss,
- )
- from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
- from text.symbols import symbols
-
# Global numeric/back-end switches applied once at import time.
# TF32 speeds up matmul/cuDNN kernels at slightly reduced precision.
torch.backends.cuda.matmul.allow_tf32 = True
# If you encounter training problems, try disabling TF32.
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("medium")
# NOTE: a bare `torch.backends.cuda.sdp_kernel("flash")` call used to sit here.
# `sdp_kernel` is a context manager, so calling it without `with` changes
# nothing; the explicit enable_* switches below are what take effect.
# These switches are not available if the torch version is lower than 2.0.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)

# Optimizer-step counter shared by run(), train_and_evaluate() and evaluate().
global_step = 0
-
-
def run():
    """Set up distributed training and run the epoch loop.

    Steps, in order:
      1. Fill missing environment variables from ``config.train_ms_config.env``
         and initialise the process group via ``env://`` (intended for
         ``torchrun``).
      2. Resolve the hyper-parameter file and model directory from the command
         line / ``config``.
      3. Build the training loader (all ranks) and, on rank 0 only, the
         logger, TensorBoard writers and evaluation loader.
      4. Build the generator, discriminators and their optimizers, wrap them in
         DDP, optionally download a base model and resume from the latest
         checkpoints found in the model directory.
      5. Call ``train_and_evaluate`` once per epoch and step the LR schedulers.

    Updates the module-level ``global_step`` when resuming.
    """
    global global_step

    # --- Environment variables: config values only fill in what is unset. ---
    envs = config.train_ms_config.env
    for env_name, env_value in envs.items():
        if env_name not in os.environ:
            print("加载config中的配置{}".format(str(env_value)))
            os.environ[env_name] = str(env_value)
    print(
        "加载环境变量 \nMASTER_ADDR: {},\nMASTER_PORT: {},\nWORLD_SIZE: {},\nRANK: {},\nLOCAL_RANK: {}".format(
            os.environ["MASTER_ADDR"],
            os.environ["MASTER_PORT"],
            os.environ["WORLD_SIZE"],
            os.environ["RANK"],
            os.environ["LOCAL_RANK"],
        )
    )

    backend = "nccl"
    if platform.system() == "Windows":
        backend = "gloo"  # If Windows, switch to the gloo backend.
    dist.init_process_group(
        backend=backend,
        init_method="env://",
        timeout=datetime.timedelta(seconds=300),
    )  # Use torchrun instead of mp.spawn
    rank = dist.get_rank()
    local_rank = int(os.environ["LOCAL_RANK"])
    n_gpus = dist.get_world_size()

    # --- Command-line / config.yml parsing. ---
    # Command-line configuration is discouraged unless necessary; prefer
    # the config.yml file.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        default=config.train_ms_config.config_path,
        help="JSON file for configuration",
    )
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        help="数据集文件夹路径,请注意,数据不再默认放在/logs文件夹下。如果需要用命令行配置,请声明相对于根目录的路径",
        default=config.dataset_path,
    )
    # parse_known_args() returns (namespace, unrecognised_args); the namespace
    # must be unpacked before its attributes can be read.
    args, _ = parser.parse_known_args()
    model_dir = os.path.join(args.model, config.train_ms_config.model)
    os.makedirs(model_dir, exist_ok=True)
    hps = utils.get_hparams_from_file(args.config)
    hps.model_dir = model_dir
    # If a different config file was given on the command line, copy its
    # contents over the configured config path.
    if os.path.realpath(args.config) != os.path.realpath(
        config.train_ms_config.config_path
    ):
        with open(args.config, "r", encoding="utf-8") as f:
            data = f.read()
        with open(config.train_ms_config.config_path, "w", encoding="utf-8") as f:
            f.write(data)

    torch.manual_seed(hps.train.seed)
    torch.cuda.set_device(local_rank)

    # --- Rank-0-only facilities (None on every other rank). ---
    logger = None
    writers = None
    eval_loader = None
    if rank == 0:
        logger = utils.get_logger(hps.model_dir)
        logger.info(hps)
        utils.check_git_hash(hps.model_dir)
        writer = SummaryWriter(log_dir=hps.model_dir)
        writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
        writers = [writer, writer_eval]

    # --- Data. ---
    train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)
    train_sampler = DistributedBucketSampler(
        train_dataset,
        hps.train.batch_size,
        [32, 300, 400, 500, 600, 700, 800, 900, 1000],
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True,
    )
    collate_fn = TextAudioSpeakerCollate()
    # os.cpu_count() may return None, and "cpu_count - 1" may be 0 on a
    # single-core host. DataLoader rejects persistent_workers/prefetch_factor
    # when num_workers == 0, so only pass them when workers are actually used.
    cpu_count = os.cpu_count() or 1
    num_workers = max(0, min(config.train_ms_config.num_workers, cpu_count - 1))
    worker_kwargs = {}
    if num_workers > 0:
        worker_kwargs = {"persistent_workers": True, "prefetch_factor": 4}
    train_loader = DataLoader(
        train_dataset,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=True,
        collate_fn=collate_fn,
        batch_sampler=train_sampler,
        **worker_kwargs,
    )  # DataLoader config could be adjusted.
    if rank == 0:
        eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data)
        eval_loader = DataLoader(
            eval_dataset,
            num_workers=0,
            shuffle=False,
            batch_size=1,
            pin_memory=True,
            drop_last=False,
            collate_fn=collate_fn,
        )

    # --- Model options. ---
    if (
        "use_noise_scaled_mas" in hps.model.keys()
        and hps.model.use_noise_scaled_mas is True
    ):
        print("Using noise scaled MAS for VITS2")
        mas_noise_scale_initial = 0.01
        noise_scale_delta = 2e-6
    else:
        print("Using normal MAS for VITS1")
        mas_noise_scale_initial = 0.0
        noise_scale_delta = 0.0

    if (
        "use_duration_discriminator" in hps.model.keys()
        and hps.model.use_duration_discriminator is True
    ):
        print("Using duration discriminator for VITS2")
        net_dur_disc = DurationDiscriminator(
            hps.model.hidden_channels,
            hps.model.hidden_channels,
            3,
            0.1,
            gin_channels=hps.model.gin_channels if hps.data.n_speakers != 0 else 0,
        ).cuda(local_rank)
    else:
        net_dur_disc = None

    if (
        "use_spk_conditioned_encoder" in hps.model.keys()
        and hps.model.use_spk_conditioned_encoder is True
    ):
        if hps.data.n_speakers == 0:
            raise ValueError(
                "n_speakers must be > 0 when using spk conditioned encoder to train multi-speaker model"
            )
    else:
        print("Using normal encoder for VITS1")

    # --- Networks. ---
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        mas_noise_scale_initial=mas_noise_scale_initial,
        noise_scale_delta=noise_scale_delta,
        **hps.model,
    ).cuda(local_rank)

    # Optionally freeze the per-language BERT projection layers.
    if getattr(hps.train, "freeze_ZH_bert", False):
        print("Freezing ZH bert encoder !!!")
        for param in net_g.enc_p.bert_proj.parameters():
            param.requires_grad = False

    if getattr(hps.train, "freeze_EN_bert", False):
        print("Freezing EN bert encoder !!!")
        for param in net_g.enc_p.en_bert_proj.parameters():
            param.requires_grad = False

    if getattr(hps.train, "freeze_JP_bert", False):
        print("Freezing JP bert encoder !!!")
        for param in net_g.enc_p.ja_bert_proj.parameters():
            param.requires_grad = False

    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(local_rank)
    net_wd = WavLMDiscriminator(
        hps.model.slm.hidden, hps.model.slm.nlayers, hps.model.slm.initial_channel
    ).cuda(local_rank)

    # --- Optimizers (frozen generator parameters are excluded). ---
    optim_g = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, net_g.parameters()),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )
    optim_d = torch.optim.AdamW(
        net_d.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )
    optim_wd = torch.optim.AdamW(
        net_wd.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )
    if net_dur_disc is not None:
        optim_dur_disc = torch.optim.AdamW(
            net_dur_disc.parameters(),
            hps.train.learning_rate,
            betas=hps.train.betas,
            eps=hps.train.eps,
        )
    else:
        optim_dur_disc = None

    net_g = DDP(net_g, device_ids=[local_rank], bucket_cap_mb=512)
    net_d = DDP(net_d, device_ids=[local_rank], bucket_cap_mb=512)
    net_wd = DDP(net_wd, device_ids=[local_rank], bucket_cap_mb=512)
    if net_dur_disc is not None:
        net_dur_disc = DDP(
            net_dur_disc,
            device_ids=[local_rank],
            bucket_cap_mb=512,
        )

    # --- Download the base model, if configured. ---
    if config.train_ms_config.base["use_base_model"]:
        utils.download_checkpoint(
            hps.model_dir,
            config.train_ms_config.base,
            token=config.openi_token,
            mirror=config.mirror,
        )

    # --- Resume from the latest checkpoints, where present. ---
    # The optimizer state is skipped unless hps.train.skip_optimizer says otherwise.
    skip_optimizer = (
        hps.train.skip_optimizer if "skip_optimizer" in hps.train else True
    )
    dur_resume_lr = hps.train.learning_rate
    wd_resume_lr = hps.train.learning_rate
    if net_dur_disc is not None:
        try:
            _, _, dur_resume_lr, _ = utils.load_checkpoint(
                utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"),
                net_dur_disc,
                optim_dur_disc,
                skip_optimizer=skip_optimizer,
            )
            if not optim_dur_disc.param_groups[0].get("initial_lr"):
                optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr
        except Exception as e:
            # Best effort: a missing/unreadable checkpoint means a fresh
            # duration discriminator. KeyboardInterrupt/SystemExit propagate.
            print(e)
            print("Initialize dur_disc")

    try:
        _, optim_g, g_resume_lr, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"),
            net_g,
            optim_g,
            skip_optimizer=skip_optimizer,
        )
        _, optim_d, d_resume_lr, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"),
            net_d,
            optim_d,
            skip_optimizer=skip_optimizer,
        )
        # ExponentialLR with last_epoch != -1 needs "initial_lr" in the param group.
        if not optim_g.param_groups[0].get("initial_lr"):
            optim_g.param_groups[0]["initial_lr"] = g_resume_lr
        if not optim_d.param_groups[0].get("initial_lr"):
            optim_d.param_groups[0]["initial_lr"] = d_resume_lr

        epoch_str = max(epoch_str, 1)
        global_step = int(
            utils.get_steps(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"))
        )
        print(
            f"******************检测到模型存在,epoch为 {epoch_str},gloabl step为 {global_step}*********************"
        )
    except Exception as e:
        print(e)
        epoch_str = 1
        global_step = 0

    try:
        # The epoch stored in the WD checkpoint is deliberately ignored so that
        # the start epoch always matches the global step taken from the
        # generator checkpoint above.
        _, optim_wd, wd_resume_lr, _ = utils.load_checkpoint(
            utils.latest_checkpoint_path(hps.model_dir, "WD_*.pth"),
            net_wd,
            optim_wd,
            skip_optimizer=skip_optimizer,
        )
        if not optim_wd.param_groups[0].get("initial_lr"):
            optim_wd.param_groups[0]["initial_lr"] = wd_resume_lr
    except Exception as e:
        print(e)

    # --- LR schedulers, fast-forwarded to the resume epoch. ---
    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
    )
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
        optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
    )
    scheduler_wd = torch.optim.lr_scheduler.ExponentialLR(
        optim_wd, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
    )
    if net_dur_disc is not None:
        scheduler_dur_disc = torch.optim.lr_scheduler.ExponentialLR(
            optim_dur_disc, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
        )
    else:
        scheduler_dur_disc = None
    scaler = GradScaler(enabled=hps.train.bf16_run)

    wl = WavLMLoss(
        hps.model.slm.model,
        net_wd,
        hps.data.sampling_rate,
        hps.model.slm.sr,
    ).to(local_rank)

    # --- Epoch loop. Non-zero ranks get None for eval loader/logger/writers. ---
    nets = [net_g, net_d, net_dur_disc, net_wd, wl]
    optims = [optim_g, optim_d, optim_dur_disc, optim_wd]
    schedulers = [scheduler_g, scheduler_d, scheduler_dur_disc, scheduler_wd]
    loaders = [train_loader, eval_loader]
    for epoch in range(epoch_str, hps.train.epochs + 1):
        train_and_evaluate(
            rank,
            local_rank,
            epoch,
            hps,
            nets,
            optims,
            schedulers,
            scaler,
            loaders,
            logger,
            writers,
        )
        scheduler_g.step()
        scheduler_d.step()
        scheduler_wd.step()
        if net_dur_disc is not None:
            scheduler_dur_disc.step()
-
-
def train_and_evaluate(
    rank,
    local_rank,
    epoch,
    hps,
    nets,
    optims,
    schedulers,
    scaler,
    loaders,
    logger,
    writers,
):
    """Train for one epoch; on rank 0 also log, evaluate and checkpoint.

    Args:
        rank: Global process rank; only rank 0 logs, evaluates and saves.
        local_rank: CUDA device index the batch tensors are moved to.
        epoch: Current epoch number (passed to the sampler and checkpoints).
        hps: Hyper-parameter object (``hps.train``, ``hps.data``, ``hps.model_dir``).
        nets: ``[net_g, net_d, net_dur_disc, net_wd, wl]``; ``net_dur_disc``
            may be ``None`` when the duration discriminator is disabled.
        optims: ``[optim_g, optim_d, optim_dur_disc, optim_wd]``.
        schedulers: Accepted for interface compatibility; the schedulers are
            stepped by the caller, not here.
        scaler: Gradient scaler shared by all optimizers.
        loaders: ``[train_loader, eval_loader]``.
        logger: Logger used on rank 0 (may be ``None`` on other ranks).
        writers: ``[writer, writer_eval]`` on rank 0, ``None`` elsewhere.

    Increments the module-level ``global_step`` once per batch.
    """
    global global_step

    net_g, net_d, net_dur_disc, net_wd, wl = nets
    optim_g, optim_d, optim_dur_disc, optim_wd = optims
    train_loader, eval_loader = loaders
    if writers is not None:
        writer, writer_eval = writers

    def amp():
        # Every mixed-precision region in this function uses the same settings.
        return autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16)

    train_loader.batch_sampler.set_epoch(epoch)

    net_g.train()
    net_d.train()
    net_wd.train()
    if net_dur_disc is not None:
        net_dur_disc.train()

    for batch_idx, batch in enumerate(tqdm(train_loader)):
        if net_g.module.use_noise_scaled_mas:
            # The MAS noise scale decays linearly with the step, floored at 0.
            current_mas_noise_scale = (
                net_g.module.mas_noise_scale_initial
                - net_g.module.noise_scale_delta * global_step
            )
            net_g.module.current_mas_noise_scale = max(current_mas_noise_scale, 0.0)

        # Move the whole batch to this rank's device.
        (
            x,
            x_lengths,
            spec,
            spec_lengths,
            y,
            y_lengths,
            speakers,
            tone,
            language,
            bert,
            ja_bert,
            en_bert,
        ) = (t.cuda(local_rank, non_blocking=True) for t in batch)

        with amp():
            (
                y_hat,
                l_length,
                attn,
                ids_slice,
                x_mask,
                z_mask,
                (z, z_p, m_p, logs_p, m_q, logs_q),
                (hidden_x, logw, logw_, logw_sdp),
                g,
            ) = net_g(
                x,
                x_lengths,
                spec,
                spec_lengths,
                speakers,
                tone,
                language,
                bert,
                ja_bert,
                en_bert,
            )
            mel = spec_to_mel_torch(
                spec,
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.mel_fmin,
                hps.data.mel_fmax,
            )
            y_mel = commons.slice_segments(
                mel, ids_slice, hps.train.segment_size // hps.data.hop_length
            )
            y_hat_mel = mel_spectrogram_torch(
                y_hat.squeeze(1).float(),
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.hop_length,
                hps.data.win_length,
                hps.data.mel_fmin,
                hps.data.mel_fmax,
            )

            # Cut the ground-truth waveform to the segment the generator produced.
            y = commons.slice_segments(
                y, ids_slice * hps.data.hop_length, hps.train.segment_size
            )

            # ---- Discriminator (generator output detached) ----
            y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
            with amp():
                loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
                    y_d_hat_r, y_d_hat_g
                )
                loss_disc_all = loss_disc

            # ---- Duration discriminator (optional) ----
            if net_dur_disc is not None:
                y_dur_hat_r, y_dur_hat_g = net_dur_disc(
                    hidden_x.detach(),
                    x_mask.detach(),
                    logw_.detach(),
                    logw.detach(),
                    g.detach(),
                )
                y_dur_hat_r_sdp, y_dur_hat_g_sdp = net_dur_disc(
                    hidden_x.detach(),
                    x_mask.detach(),
                    logw_.detach(),
                    logw_sdp.detach(),
                    g.detach(),
                )
                y_dur_hat_r = y_dur_hat_r + y_dur_hat_r_sdp
                y_dur_hat_g = y_dur_hat_g + y_dur_hat_g_sdp
                with amp():
                    # TODO: I think need to mean using the mask, but for now, just mean all
                    (
                        loss_dur_disc,
                        losses_dur_disc_r,
                        losses_dur_disc_g,
                    ) = discriminator_loss(y_dur_hat_r, y_dur_hat_g)
                    loss_dur_disc_all = loss_dur_disc
                optim_dur_disc.zero_grad()
                scaler.scale(loss_dur_disc_all).backward()
                scaler.unscale_(optim_dur_disc)
                grad_norm_dur = commons.clip_grad_value_(
                    net_dur_disc.parameters(), None
                )
                scaler.step(optim_dur_disc)

        optim_d.zero_grad()
        scaler.scale(loss_disc_all).backward()
        scaler.unscale_(optim_d)
        if getattr(hps.train, "bf16_run", False):
            torch.nn.utils.clip_grad_norm_(parameters=net_d.parameters(), max_norm=200)
        grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
        scaler.step(optim_d)

        # ---- WavLM discriminator ----
        with amp():
            loss_slm = wl.discriminator(
                y.detach().squeeze(), y_hat.detach().squeeze()
            ).mean()

        optim_wd.zero_grad()
        scaler.scale(loss_slm).backward()
        scaler.unscale_(optim_wd)
        grad_norm_wd = commons.clip_grad_value_(net_wd.parameters(), None)
        scaler.step(optim_wd)

        # ---- Generator ----
        with amp():
            y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
            if net_dur_disc is not None:
                _, y_dur_hat_g = net_dur_disc(hidden_x, x_mask, logw_, logw, g)
                _, y_dur_hat_g_sdp = net_dur_disc(hidden_x, x_mask, logw_, logw_sdp, g)
                y_dur_hat_g = y_dur_hat_g + y_dur_hat_g_sdp
            with amp():
                loss_dur = torch.sum(l_length.float())
                loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
                loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl

                loss_fm = feature_loss(fmap_r, fmap_g)
                loss_gen, losses_gen = generator_loss(y_d_hat_g)

                loss_lm = wl(y.detach().squeeze(), y_hat.squeeze()).mean()
                loss_lm_gen = wl.generator(y_hat.squeeze())

                loss_gen_all = (
                    loss_gen
                    + loss_fm
                    + loss_mel
                    + loss_dur
                    + loss_kl
                    + loss_lm
                    + loss_lm_gen
                )
                if net_dur_disc is not None:
                    loss_dur_gen, losses_dur_gen = generator_loss(y_dur_hat_g)
                    loss_gen_all += loss_dur_gen
        optim_g.zero_grad()
        scaler.scale(loss_gen_all).backward()
        scaler.unscale_(optim_g)
        if getattr(hps.train, "bf16_run", False):
            torch.nn.utils.clip_grad_norm_(parameters=net_g.parameters(), max_norm=500)
        grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
        scaler.step(optim_g)
        scaler.update()

        if rank == 0:
            if global_step % hps.train.log_interval == 0:
                lr = optim_g.param_groups[0]["lr"]
                losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl]
                logger.info(
                    "Train Epoch: {} [{:.0f}%]".format(
                        epoch, 100.0 * batch_idx / len(train_loader)
                    )
                )
                logger.info([loss.item() for loss in losses] + [global_step, lr])

                scalar_dict = {
                    "loss/g/total": loss_gen_all,
                    "loss/d/total": loss_disc_all,
                    "loss/wd/total": loss_slm,
                    "learning_rate": lr,
                    "grad_norm_d": grad_norm_d,
                    "grad_norm_g": grad_norm_g,
                    "grad_norm_wd": grad_norm_wd,
                }
                scalar_dict.update(
                    {
                        "loss/g/fm": loss_fm,
                        "loss/g/mel": loss_mel,
                        "loss/g/dur": loss_dur,
                        "loss/g/kl": loss_kl,
                        "loss/g/lm": loss_lm,
                        "loss/g/lm_gen": loss_lm_gen,
                    }
                )
                scalar_dict.update(
                    {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
                )
                scalar_dict.update(
                    {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}
                )
                scalar_dict.update(
                    {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
                )

                if net_dur_disc is not None:
                    # grad_norm_dur only exists when the duration discriminator
                    # was trained this step, so it must be logged here rather
                    # than unconditionally above.
                    scalar_dict.update(
                        {
                            "grad_norm_dur": grad_norm_dur,
                            "loss/dur_disc/total": loss_dur_disc_all,
                        }
                    )
                    scalar_dict.update(
                        {
                            "loss/dur_disc_g/{}".format(i): v
                            for i, v in enumerate(losses_dur_disc_g)
                        }
                    )
                    scalar_dict.update(
                        {
                            "loss/dur_disc_r/{}".format(i): v
                            for i, v in enumerate(losses_dur_disc_r)
                        }
                    )
                    scalar_dict.update({"loss/g/dur_gen": loss_dur_gen})
                    scalar_dict.update(
                        {
                            "loss/g/dur_gen_{}".format(i): v
                            for i, v in enumerate(losses_dur_gen)
                        }
                    )

                image_dict = {
                    "slice/mel_org": utils.plot_spectrogram_to_numpy(
                        y_mel[0].data.cpu().numpy()
                    ),
                    "slice/mel_gen": utils.plot_spectrogram_to_numpy(
                        y_hat_mel[0].data.cpu().numpy()
                    ),
                    "all/mel": utils.plot_spectrogram_to_numpy(
                        mel[0].data.cpu().numpy()
                    ),
                    "all/attn": utils.plot_alignment_to_numpy(
                        attn[0, 0].data.cpu().numpy()
                    ),
                }
                utils.summarize(
                    writer=writer,
                    global_step=global_step,
                    images=image_dict,
                    scalars=scalar_dict,
                )

            if global_step % hps.train.eval_interval == 0:
                evaluate(hps, net_g, eval_loader, writer_eval)
                _save_training_checkpoints(
                    hps,
                    epoch,
                    [
                        ("G", net_g, optim_g),
                        ("D", net_d, optim_d),
                        ("WD", net_wd, optim_wd),
                        ("DUR", net_dur_disc, optim_dur_disc),
                    ],
                )

        global_step += 1

    if rank == 0:
        logger.info("====> Epoch: {}".format(epoch))


def _save_training_checkpoints(hps, epoch, targets):
    """Save one ``<prefix>_<global_step>.pth`` per non-None network, then prune old files.

    ``targets`` is a list of ``(prefix, network, optimizer)`` tuples; entries
    whose network is ``None`` (a disabled duration discriminator) are skipped.
    """
    for prefix, net, optim in targets:
        if net is None:
            continue
        utils.save_checkpoint(
            net,
            optim,
            hps.train.learning_rate,
            epoch,
            os.path.join(hps.model_dir, "{}_{}.pth".format(prefix, global_step)),
        )
    keep_ckpts = config.train_ms_config.keep_ckpts
    if keep_ckpts > 0:
        utils.clean_checkpoints(
            path_to_models=hps.model_dir,
            n_ckpts_to_keep=keep_ckpts,
            sort_by_time=True,
        )
-
-
def evaluate(hps, generator, eval_loader, writer_eval):
    """Synthesize every evaluation item and write the results to TensorBoard.

    Each batch is inferred twice, once with ``sdp_ratio=1.0`` and once with
    ``sdp_ratio=0.0``. Generated mel images and audio are logged under tags
    that carry the batch index and the ``use_sdp`` flag; the ground-truth mel
    and audio are logged once per batch. The generator is switched to eval
    mode for the duration of the call and back to train mode afterwards.

    Args:
        hps: Hyper-parameter object; ``hps.data`` supplies the mel/STFT settings.
        generator: DDP-wrapped generator; ``generator.module.infer`` is called.
        eval_loader: Loader yielding the same 12-field batches as training.
        writer_eval: Summary writer passed to ``utils.summarize``.
    """
    generator.eval()
    image_dict = {}
    audio_dict = {}
    print("Evaluating ...")
    with torch.no_grad():
        for batch_idx, batch in enumerate(eval_loader):
            (
                x,
                x_lengths,
                spec,
                spec_lengths,
                y,
                y_lengths,
                speakers,
                tone,
                language,
                bert,
                ja_bert,
                en_bert,
            ) = (t.cuda() for t in batch)

            # The ground truth does not depend on the inference mode, so it is
            # computed and recorded once per batch.
            mel = spec_to_mel_torch(
                spec,
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.mel_fmin,
                hps.data.mel_fmax,
            )
            image_dict[f"gt/mel_{batch_idx}"] = utils.plot_spectrogram_to_numpy(
                mel[0].cpu().numpy()
            )
            audio_dict[f"gt/audio_{batch_idx}"] = y[0, :, : y_lengths[0]]

            for use_sdp in [True, False]:
                y_hat, attn, mask, *_ = generator.module.infer(
                    x,
                    x_lengths,
                    speakers,
                    tone,
                    language,
                    bert,
                    ja_bert,
                    en_bert,
                    y=spec,
                    max_len=1000,
                    sdp_ratio=0.0 if not use_sdp else 1.0,
                )
                y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length

                y_hat_mel = mel_spectrogram_torch(
                    y_hat.squeeze(1).float(),
                    hps.data.filter_length,
                    hps.data.n_mel_channels,
                    hps.data.sampling_rate,
                    hps.data.hop_length,
                    hps.data.win_length,
                    hps.data.mel_fmin,
                    hps.data.mel_fmax,
                )
                # The tag includes use_sdp (as the audio tag does) so the two
                # passes do not overwrite each other's spectrogram.
                image_dict[
                    f"gen/mel_{batch_idx}_{use_sdp}"
                ] = utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy())
                audio_dict[f"gen/audio_{batch_idx}_{use_sdp}"] = y_hat[
                    0, :, : y_hat_lengths[0]
                ]

    utils.summarize(
        writer=writer_eval,
        global_step=global_step,
        images=image_dict,
        audios=audio_dict,
        audio_sampling_rate=hps.data.sampling_rate,
    )
    generator.train()
-
-
# Script entry point: start training only when executed directly
# (e.g. through torchrun), not when the module is imported.
if __name__ == "__main__":
    run()
|