|
- import argparse
- import glob
- import json
- import logging
- import os
- import re
- import subprocess
- import sys
- import traceback
- from multiprocessing import cpu_count
-
- import faiss
- import librosa
- import numpy as np
- import torch
- from scipy.io.wavfile import read
- from sklearn.cluster import MiniBatchKMeans
- from torch.nn import functional as F
-
# Set to True by the plotting helpers once matplotlib has been switched to the Agg backend.
MATPLOTLIB_FLAG = False

logging.basicConfig(stream=sys.stdout, level=logging.WARNING)
# Module-wide logger. It starts out as the `logging` module itself; get_logger() rebinds it.
logger = logging

# F0 (pitch) quantisation settings used by f0_to_coarse: f0_bin bins spanning
# [f0_min, f0_max] Hz on the mel scale mel = 1127 * ln(1 + f / 700).
f0_bin = 256
f0_min = 50.0
f0_max = 1100.0
f0_mel_min = 1127 * np.log(1 + f0_min / 700)
f0_mel_max = 1127 * np.log(1 + f0_max / 700)
-
def normalize_f0(f0, x_mask, uv, random_scale=True):
    """Subtract the uv-weighted mean from f0, optionally rescale, and apply x_mask.

    Args:
        f0: pitch tensor, indexed as ``f0[:, 0, :]`` — presumably
            (batch, 1, frames); TODO confirm against callers.
        x_mask: mask multiplied into the result (broadcast against f0).
        uv: per-frame weights summed over dim 1 and multiplied with
            ``f0[:, 0, :]`` — presumably a (batch, frames) voiced/unvoiced mask.
        random_scale: if True, each batch item is multiplied by a factor drawn
            uniformly from [0.8, 1.2) (generated on CPU, then moved to
            f0.device); otherwise the factor is 1.

    Returns:
        ``(f0 - mean) * factor * x_mask``.

    Exits the process with a non-zero status (SystemExit carrying a message)
    if the normalised f0 contains NaN.
    """
    # uv-weighted mean of f0 per batch item. A zero weight sum is replaced by a
    # large constant so all-zero uv rows yield mean 0 instead of 0 / 0 = NaN.
    uv_sum = torch.sum(uv, dim=1, keepdim=True)
    uv_sum[uv_sum == 0] = 9999
    means = torch.sum(f0[:, 0, :] * uv, dim=1, keepdim=True) / uv_sum

    if random_scale:
        factor = torch.Tensor(f0.shape[0], 1).uniform_(0.8, 1.2).to(f0.device)
    else:
        factor = torch.ones(f0.shape[0], 1).to(f0.device)
    # normalize f0 based on means and factor
    f0_norm = (f0 - means.unsqueeze(-1)) * factor.unsqueeze(-1)
    if torch.isnan(f0_norm).any():
        # Report the failure: a plain exit(0) would signal success to the caller.
        sys.exit("normalize_f0: NaN encountered in normalized f0")
    return f0_norm * x_mask
def plot_data_to_numpy(x, y):
    """Plot two sequences on one 10x2-inch figure and return the rendered RGB image.

    On first use (module flag MATPLOTLIB_FLAG) matplotlib is switched to the Agg
    backend and its logger is limited to WARNING.

    Returns:
        np.ndarray of dtype uint8 with shape (height, width, 3).
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger('matplotlib')
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt

    fig, ax = plt.subplots(figsize=(10, 2))
    ax.plot(x)
    ax.plot(y)
    plt.tight_layout()

    fig.canvas.draw()
    # buffer_rgba() replaces the deprecated tostring_rgb() + np.fromstring pair;
    # drop alpha and copy so the array does not depend on the closed figure.
    data = np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[..., :3])
    plt.close(fig)
    return data
-
-
def f0_to_coarse(f0):
    """Quantise pitch values (Hz) into integer mel-scale bins.

    f0 is converted to mel (1127 * ln(1 + f0 / 700)). Positive mel values are
    mapped linearly so that f0_min -> 1 and f0_max -> f0_bin - 1. The result is
    rounded and clamped to [1, f0_bin - 1], so f0 == 0 and anything below
    f0_min lands in bin 1 and anything above f0_max lands in bin f0_bin - 1.

    Args:
        f0: torch tensor of pitch values in Hz (``.log()`` is called on it).

    Returns:
        LongTensor with the same shape as f0, values in [1, f0_bin - 1].
    """
    f0_mel = 1127 * (1 + f0 / 700).log()
    a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
    b = f0_mel_min * a - 1.
    f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
    # A single clamp replaces the previous mask arithmetic, which zeroed values
    # >= f0_bin before the "snap to top bin" step could see them (sending them
    # to 0 instead of f0_bin - 1).
    return torch.clamp(torch.round(f0_mel).long(), min=1, max=f0_bin - 1)
-
def get_content(cmodel, y):
    """Extract content features from *y* with *cmodel*, without tracking gradients.

    ``y.squeeze(1)`` is passed to ``cmodel.extract_features``. The first element
    of its result is returned with dims 1 and 2 swapped — presumably
    (batch, frames, dim) -> (batch, dim, frames); TODO confirm per encoder.
    """
    waveform = y.squeeze(1)
    with torch.no_grad():
        features = cmodel.extract_features(waveform)[0]
        return features.transpose(1, 2)
-
def get_f0_predictor(f0_predictor, hop_length, sampling_rate, **kargs):
    """Instantiate the F0 (pitch) predictor named *f0_predictor*.

    Supported names: "pm", "crepe", "harvest", "dio", "rmvpe", "fcpe". The
    predictor class is imported only when it is selected.

    Args:
        f0_predictor: predictor name.
        hop_length, sampling_rate: forwarded to every predictor.
        **kargs: "crepe", "rmvpe" and "fcpe" read ``kargs["device"]`` and
            ``kargs["threshold"]`` (KeyError if absent); the others ignore kargs.

    Raises:
        ValueError: for an unknown predictor name (still an Exception subclass).
    """
    if f0_predictor == "pm":
        from modules.F0Predictor.PMF0Predictor import PMF0Predictor
        return PMF0Predictor(hop_length=hop_length, sampling_rate=sampling_rate)
    if f0_predictor == "crepe":
        from modules.F0Predictor.CrepeF0Predictor import CrepeF0Predictor
        return CrepeF0Predictor(hop_length=hop_length, sampling_rate=sampling_rate,
                                device=kargs["device"], threshold=kargs["threshold"])
    if f0_predictor == "harvest":
        from modules.F0Predictor.HarvestF0Predictor import HarvestF0Predictor
        return HarvestF0Predictor(hop_length=hop_length, sampling_rate=sampling_rate)
    if f0_predictor == "dio":
        from modules.F0Predictor.DioF0Predictor import DioF0Predictor
        return DioF0Predictor(hop_length=hop_length, sampling_rate=sampling_rate)
    if f0_predictor == "rmvpe":
        from modules.F0Predictor.RMVPEF0Predictor import RMVPEF0Predictor
        return RMVPEF0Predictor(hop_length=hop_length, sampling_rate=sampling_rate, dtype=torch.float32,
                                device=kargs["device"], threshold=kargs["threshold"])
    if f0_predictor == "fcpe":
        from modules.F0Predictor.FCPEF0Predictor import FCPEF0Predictor
        return FCPEF0Predictor(hop_length=hop_length, sampling_rate=sampling_rate, dtype=torch.float32,
                               device=kargs["device"], threshold=kargs["threshold"])
    raise ValueError(f"Unknown f0 predictor: {f0_predictor!r}")
-
def get_speech_encoder(speech_encoder, device=None, **kargs):
    """Instantiate the speech-content encoder named *speech_encoder*.

    Each encoder class ``X`` lives in module ``vencoder.X`` and is imported
    only when selected. It is constructed as ``X(device=device)``.

    Args:
        speech_encoder: encoder name (a key of ``encoder_classes`` below).
        device: forwarded to the encoder constructor.
        **kargs: accepted for interface compatibility; unused.

    Raises:
        ValueError: for an unknown encoder name (still an Exception subclass).
    """
    import importlib

    encoder_classes = {
        "vec768l12": "ContentVec768L12",
        "vec256l9": "ContentVec256L9",
        "vec256l9-onnx": "ContentVec256L9_Onnx",
        "vec256l12-onnx": "ContentVec256L12_Onnx",
        "vec768l9-onnx": "ContentVec768L9_Onnx",
        "vec768l12-onnx": "ContentVec768L12_Onnx",
        "hubertsoft-onnx": "HubertSoft_Onnx",
        "hubertsoft": "HubertSoft",
        "whisper-ppg": "WhisperPPG",
        "cnhubertlarge": "CNHubertLarge",
        "dphubert": "DPHubert",
        "whisper-ppg-large": "WhisperPPGLarge",
        "wavlmbase+": "WavLMBasePlus",
    }
    class_name = encoder_classes.get(speech_encoder) if isinstance(speech_encoder, str) else None
    if class_name is None:
        raise ValueError(f"Unknown speech encoder: {speech_encoder!r}")
    module = importlib.import_module(f"vencoder.{class_name}")
    return getattr(module, class_name)(device=device)
-
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
    """Load model weights (and optionally optimizer state) from *checkpoint_path*.

    The model is first cast to the dtype of the first tensor in the saved state
    dict. Parameters missing from the checkpoint, or whose shape differs from
    the model's, keep the model's current values instead of failing the load.
    A console warning is printed for such parameters unless their name contains
    "enc_q" or "emb_g". Those are presumably expected to differ when starting
    from a pretrained model (e.g. a different speaker count) — TODO confirm.

    Args:
        checkpoint_path: file written by save_checkpoint.
        model: module to load into (a wrapper exposing ``.module`` is unwrapped).
        optimizer: optional optimizer to restore.
        skip_optimizer: if True, never restore optimizer state.

    Returns:
        (model, optimizer, learning_rate, iteration)
    """
    assert os.path.isfile(checkpoint_path), checkpoint_path
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    iteration = checkpoint_dict['iteration']
    learning_rate = checkpoint_dict['learning_rate']
    saved_optimizer = checkpoint_dict.get('optimizer')
    if optimizer is not None and not skip_optimizer and saved_optimizer is not None:
        optimizer.load_state_dict(saved_optimizer)
    saved_state_dict = checkpoint_dict['model']
    model = model.to(list(saved_state_dict.values())[0].dtype)
    target = model.module if hasattr(model, 'module') else model
    state_dict = target.state_dict()

    new_state_dict = {}
    for k, v in state_dict.items():
        saved = saved_state_dict.get(k)
        if saved is not None and getattr(saved, 'shape', None) == v.shape:
            new_state_dict[k] = saved
            continue
        # Missing or shape-mismatched: keep the model's own value. The warning is
        # suppressed only for enc_q / emb_g keys (the old `or` made it unconditional).
        if "enc_q" not in k and "emb_g" not in k:
            print("%s is not in the checkpoint,please check your checkpoint.If you're using pretrain model,just ignore this warning." % k)
        logger.info("%s is not in the checkpoint" % k)
        new_state_dict[k] = v

    target.load_state_dict(new_state_dict)
    logger.info("Loaded checkpoint '{}' (iteration {})".format(
        checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration
-
-
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
    """Save model/optimizer state, learning rate and iteration to *checkpoint_path*.

    A wrapper exposing ``.module`` is unwrapped before taking the state dict.
    If *optimizer* is None, ``None`` is stored under 'optimizer' (load_checkpoint
    already skips a None optimizer entry).
    """
    logger.info("Saving model and optimizer state at iteration {} to {}".format(
        iteration, checkpoint_path))
    target = model.module if hasattr(model, 'module') else model
    torch.save({'model': target.state_dict(),
                'iteration': iteration,
                'optimizer': optimizer.state_dict() if optimizer is not None else None,
                'learning_rate': learning_rate}, checkpoint_path)
-
def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_time=True):
    """Freeing up space by deleting saved ckpts

    Only regular files named ``G_<step>.pth`` / ``D_<step>.pth`` are considered.
    Files ending in ``_0.pth`` (e.g. G_0.pth, D_0.pth) are never deleted.

    Arguments:
    path_to_models -- Path to the model directory
    n_ckpts_to_keep -- Number of newest G and newest D ckpts to keep each
                       (0 keeps none of them)
    sort_by_time -- True -> delete oldest by modification time
                    False -> delete lowest step number (from the file name)
    """
    ckpt_pattern = re.compile(r"^[GD]_(\d+)\.pth$")

    def step_key(fn):
        return int(ckpt_pattern.match(fn).group(1))

    def mtime_key(fn):
        return os.path.getmtime(os.path.join(path_to_models, fn))

    sort_key = mtime_key if sort_by_time else step_key
    candidates = [
        fn for fn in os.listdir(path_to_models)
        if ckpt_pattern.match(fn)
        and not fn.endswith("_0.pth")
        and os.path.isfile(os.path.join(path_to_models, fn))
    ]

    for prefix in ("G", "D"):
        ckpts = sorted((fn for fn in candidates if fn.startswith(prefix)), key=sort_key)
        # Explicit count instead of [:-n], which kept everything when n == 0.
        for fn in ckpts[:max(len(ckpts) - n_ckpts_to_keep, 0)]:
            path = os.path.join(path_to_models, fn)
            os.remove(path)
            logger.info(f".. Free up space by deleting ckpt {path}")
-
def summarize(writer, global_step, scalars=None, histograms=None, images=None, audios=None,
              audio_sampling_rate=22050):
    """Write scalars, histograms, images (HWC) and audio clips to *writer* at *global_step*.

    Each mapping is name -> value and may be omitted (None). *writer* must provide
    add_scalar / add_histogram / add_image / add_audio.
    """
    # None defaults instead of shared mutable {} defaults.
    for k, v in (scalars or {}).items():
        writer.add_scalar(k, v, global_step)
    for k, v in (histograms or {}).items():
        writer.add_histogram(k, v, global_step)
    for k, v in (images or {}).items():
        writer.add_image(k, v, global_step, dataformats='HWC')
    for k, v in (audios or {}).items():
        writer.add_audio(k, v, global_step, audio_sampling_rate)
-
-
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    """Return (and print) the newest checkpoint in *dir_path* matching glob *regex*.

    "Newest" means the largest integer formed by the digits of the file name.
    Directory digits are ignored, and names without digits rank lowest.

    Raises:
        IndexError: if nothing matches.
    """
    f_list = glob.glob(os.path.join(dir_path, regex))
    if not f_list:
        raise IndexError("no checkpoint matching {!r} in {!r}".format(regex, dir_path))

    def step(path):
        digits = "".join(filter(str.isdigit, os.path.basename(path)))
        return int(digits) if digits else -1

    f_list.sort(key=step)
    x = f_list[-1]
    print(x)
    return x
-
-
def plot_spectrogram_to_numpy(spectrogram):
    """Render a 2-D array as a heat map with a colour bar and return the RGB image.

    The array is drawn with ``imshow(origin="lower")`` on a 10x2-inch figure,
    with x labelled "Frames" and y "Channels". On first use (module flag
    MATPLOTLIB_FLAG) matplotlib is switched to the Agg backend.

    Returns:
        np.ndarray of dtype uint8 with shape (height, width, 3).
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger('matplotlib')
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                   interpolation='none')
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    # buffer_rgba() replaces the deprecated tostring_rgb() + np.fromstring pair;
    # drop alpha and copy so the array does not depend on the closed figure.
    data = np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[..., :3])
    plt.close(fig)
    return data
-
-
def plot_alignment_to_numpy(alignment, info=None):
    """Render a transposed alignment matrix as an image and return it as RGB.

    ``alignment.transpose()`` is drawn with origin="lower" on a 6x4-inch figure,
    with x labelled "Decoder timestep" (plus *info* on extra lines when given)
    and y "Encoder timestep". On first use (module flag MATPLOTLIB_FLAG)
    matplotlib is switched to the Agg backend.

    Returns:
        np.ndarray of dtype uint8 with shape (height, width, 3).
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger('matplotlib')
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt

    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
                   interpolation='none')
    fig.colorbar(im, ax=ax)
    xlabel = 'Decoder timestep'
    if info is not None:
        xlabel += '\n\n' + info
    plt.xlabel(xlabel)
    plt.ylabel('Encoder timestep')
    plt.tight_layout()

    fig.canvas.draw()
    # buffer_rgba() replaces the deprecated tostring_rgb() + np.fromstring pair;
    # drop alpha and copy so the array does not depend on the closed figure.
    data = np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[..., :3])
    plt.close(fig)
    return data
-
-
def load_wav_to_torch(full_path):
    """Read a WAV file with scipy and return (float32 tensor of samples, sample rate).

    Sample values are converted to float32 without rescaling (int16 data keeps
    its integer range).
    """
    rate, samples = read(full_path)
    as_float = samples.astype(np.float32)
    return torch.FloatTensor(as_float), rate
-
-
def load_filepaths_and_text(filename, split="|"):
    """Read a UTF-8 list file and split each line on *split*.

    Surrounding whitespace (including the newline) is stripped from every line
    before splitting. Returns one list of fields per line, blank lines included.
    """
    with open(filename, encoding='utf-8') as handle:
        rows = handle.readlines()
    return [row.strip().split(split) for row in rows]
-
-
def get_hparams(init=True):
    """Build HParams from the command line (-c/--config, required -m/--model).

    Creates ``./logs/<model>`` if needed. With *init* True the config file is
    copied to ``./logs/<model>/config.json`` and used; otherwise that saved copy
    is read. The returned HParams has ``model_dir`` set.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default="./configs/config.json",
                        help='JSON file for configuration')
    parser.add_argument('-m', '--model', type=str, required=True,
                        help='Model name')

    args = parser.parse_args()
    model_dir = os.path.join("./logs", args.model)
    os.makedirs(model_dir, exist_ok=True)

    config_path = args.config
    config_save_path = os.path.join(model_dir, "config.json")
    # JSON is UTF-8; don't depend on the platform default encoding (e.g. cp1252/gbk).
    if init:
        with open(config_path, "r", encoding="utf-8") as f:
            data = f.read()
        with open(config_save_path, "w", encoding="utf-8") as f:
            f.write(data)
    else:
        with open(config_save_path, "r", encoding="utf-8") as f:
            data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams
-
-
def get_hparams_from_dir(model_dir):
    """Load HParams from ``<model_dir>/config.json`` and set ``model_dir`` on it."""
    config_save_path = os.path.join(model_dir, "config.json")
    # JSON is UTF-8; don't depend on the platform default encoding.
    with open(config_save_path, "r", encoding="utf-8") as f:
        config = json.loads(f.read())

    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams
-
-
def get_hparams_from_file(config_path, infer_mode = False):
    """Load a JSON config file as HParams, or as InferHParams when *infer_mode* is true."""
    # JSON is UTF-8; don't depend on the platform default encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.loads(f.read())
    return InferHParams(**config) if infer_mode else HParams(**config)
-
-
def check_git_hash(model_dir):
    """Record or compare the source checkout's git HEAD hash for *model_dir*.

    If the directory containing this file has no ``.git`` entry, a warning is
    logged and nothing else happens. Otherwise ``git rev-parse HEAD`` is run in
    that directory, and then:
      * if ``<model_dir>/githash`` exists, a warning is logged when it differs;
      * otherwise the hash is written to ``<model_dir>/githash``.
    If git cannot be run or fails, a warning is logged and nothing is written.
    """
    source_dir = os.path.dirname(os.path.realpath(__file__))
    if not os.path.exists(os.path.join(source_dir, ".git")):
        logger.warning("{} is not a git repository, therefore hash value comparison will be ignored.".format(
            source_dir
        ))
        return

    # Run git in the source checkout (not the current working directory) and
    # never store git's error output as if it were a hash.
    try:
        cur_hash = subprocess.run(
            ["git", "rev-parse", "HEAD"], cwd=source_dir,
            capture_output=True, text=True, check=True,
        ).stdout.strip()
    except (OSError, subprocess.CalledProcessError) as err:
        logger.warning("could not read the git hash ({}); hash comparison skipped.".format(err))
        return

    path = os.path.join(model_dir, "githash")
    if os.path.exists(path):
        with open(path) as f:
            saved_hash = f.read().strip()
        if saved_hash != cur_hash:
            logger.warning("git hash values are different. {}(saved) != {}(current)".format(
                saved_hash[:8], cur_hash[:8]))
    else:
        with open(path, "w") as f:
            f.write(cur_hash)
-
-
def get_logger(model_dir, filename="train.log"):
    """Return a DEBUG-level logger named after *model_dir* that writes to model_dir/filename.

    The module-level ``logger`` is rebound to the returned logger. *model_dir* is
    created if needed. Calling this repeatedly for the same file does not attach
    duplicate handlers.
    """
    global logger
    # normpath so "logs/44k/" names the logger "44k" instead of "" (the root logger).
    logger = logging.getLogger(os.path.basename(os.path.normpath(model_dir)))
    logger.setLevel(logging.DEBUG)

    os.makedirs(model_dir, exist_ok=True)
    log_path = os.path.abspath(os.path.join(model_dir, filename))
    already_attached = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == log_path
        for h in logger.handlers
    )
    if not already_attached:
        formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
        h = logging.FileHandler(log_path)
        h.setLevel(logging.DEBUG)
        h.setFormatter(formatter)
        logger.addHandler(h)
    return logger
-
-
def repeat_expand_2d(content, target_len, mode = 'left'):
    """Stretch a [h, t] tensor along its last axis to *target_len* frames.

    mode='left' dispatches to repeat_expand_2d_left (column repetition). Any
    other value is passed to repeat_expand_2d_other as an interpolation mode.
    """
    if mode != 'left':
        return repeat_expand_2d_other(content, target_len, mode)
    return repeat_expand_2d_left(content, target_len)
-
-
-
def repeat_expand_2d_left(content, target_len):
    """Stretch a [h, t] tensor to [h, target_len] by repeating whole columns.

    Output column i copies input column pos[i], where (as in the original
    per-column loop) ``temp[j] = j * target_len / t`` marks the output index at
    which input column j starts. pos advances by at most one column per output
    step, whenever i reaches ``temp[pos + 1]``. When target_len >= t each column
    is repeated about target_len / t times. When target_len < t the one-step
    limit effectively yields the first target_len columns.

    Returns:
        float32 tensor of shape [h, target_len] on content's device.
    """
    src_len = content.shape[-1]
    if target_len == 0:
        return torch.zeros([content.shape[0], 0], dtype=torch.float).to(content.device)
    temp = torch.arange(src_len + 1) * target_len / src_len
    steps = torch.arange(target_len)
    # ideal[i]: how many boundaries temp[1:] are <= i (the column an unconstrained
    # mapping would pick). Compared in temp's dtype, like the original loop.
    ideal = torch.searchsorted(temp[1:].contiguous(), steps.to(temp.dtype), right=True)
    # The loop's recurrence pos[i] = min(ideal[i], pos[i-1] + 1) has the closed
    # form pos[i] = i + min_{k<=i}(ideal[k] - k). This replaces a Python loop that
    # wrote one column per iteration.
    pos = steps + torch.cummin(ideal - steps, dim=0).values
    return content[:, pos.to(content.device)].to(torch.float)
-
-
def repeat_expand_2d_other(content, target_len, mode = 'nearest'):
    """Resize a [h, t] tensor to [h, target_len] with torch.nn.functional.interpolate.

    content is treated as a batch of one with h channels (a 3-D input), so only
    the 1-D modes apply: 'nearest', 'linear', 'area' (and 'nearest-exact' on
    newer PyTorch). 'bilinear', 'bicubic' and 'trilinear' require 4-D/5-D input.
    """
    batched = content.unsqueeze(0)
    resized = F.interpolate(batched, size=target_len, mode=mode)
    return resized.squeeze(0)
-
-
def mix_model(model_paths,mix_rate,mode):
    """Blend the "model" state dicts of several checkpoints into ./output.pth.

    Args:
        model_paths: checkpoint files whose "model" dicts share the same keys.
        mix_rate: one weight per checkpoint, in percent (divided by 100).
        mode: 0 -> weights are normalised with softmax; anything else -> the
            /100 weights are used as given.

    Returns:
        ``os.path.join(os.path.curdir, "output.pth")``, the written file. All
        non-"model" entries are taken from the first checkpoint.
    """
    mix_rate = torch.FloatTensor(mix_rate) / 100
    if mode == 0:
        mix_rate = F.softmax(mix_rate, dim=0)
    # Load each file once, on CPU so GPU-saved checkpoints work on CPU-only hosts.
    # The first checkpoint doubles as the output template.
    checkpoints = [torch.load(path, map_location="cpu") for path in model_paths]
    model_tem = checkpoints[0]
    models = [ckpt["model"] for ckpt in checkpoints]
    state = model_tem["model"]
    for k in list(state.keys()):
        acc = torch.zeros_like(state[k])
        for i, model in enumerate(models):
            acc += model[k] * mix_rate[i]
        # Safe although models[0] is `state`: key k is only replaced after every
        # model's value for k has been read.
        state[k] = acc
    out_path = os.path.join(os.path.curdir, "output.pth")
    torch.save(model_tem, out_path)
    return out_path
-
def change_rms(data1, sr1, data2, sr2, rate):  # from RVC
    """Blend the loudness envelope of *data2* toward that of *data1*, in place.

    data1 is the input audio (array accepted by librosa) at sr1. data2 is the
    output audio (tensor) at sr2. *rate* is the share of data2's own envelope
    that is kept. Each RMS envelope has one point per half second and is
    linearly resampled to data2.shape[0] samples. data2 is multiplied by
    ``rms1 ** (1 - rate) * max(rms2, 1e-6) ** (rate - 1)`` in place and returned.
    """
    def envelope(samples, sr):
        half_second = sr // 2
        rms = librosa.feature.rms(y=samples, frame_length=half_second * 2, hop_length=half_second)
        rms = torch.from_numpy(rms).to(data2.device)
        return F.interpolate(rms.unsqueeze(0), size=data2.shape[0], mode="linear").squeeze()

    rms_in = envelope(data1, sr1)
    rms_out = envelope(data2.detach().cpu().numpy(), sr2)
    rms_out = torch.max(rms_out, torch.zeros_like(rms_out) + 1e-6)
    gain = torch.pow(rms_in, torch.tensor(1 - rate)) * torch.pow(rms_out, torch.tensor(rate - 1))
    data2 *= gain
    return data2
-
def train_index(spk_name,root_dir = "dataset/44k/"): #from: RVC https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI
    """Build a faiss IVF-Flat index over one speaker's content features.

    Every file in ``root_dir/spk_name`` whose name contains ".wav.soft.pt" is
    loaded (sorted by path). Element [0] of each is transposed on its last two
    dims — presumably (dim, frames) -> (frames, dim); TODO confirm. The rows are
    concatenated and shuffled. With more than 200k rows, MiniBatchKMeans
    reduction to 10k centres is attempted (on failure the traceback is printed
    and the full data is used).

    Returns:
        The trained and populated faiss index (not written to disk).

    Raises:
        Exception: if no feature files are found.
    """
    n_cpu = cpu_count()
    print("The feature index is constructing.")
    exp_dir = os.path.join(root_dir, spk_name)
    listdir_res = [os.path.join(exp_dir, file) for file in os.listdir(exp_dir) if ".wav.soft.pt" in file]
    if len(listdir_res) == 0:
        raise Exception("You need to run preprocess_hubert_f0.py!")
    # map_location="cpu" so features saved from a GPU run still convert with .numpy().
    npys = [torch.load(name, map_location="cpu")[0].transpose(-1, -2).numpy()
            for name in sorted(listdir_res)]
    big_npy = np.concatenate(npys, 0)
    big_npy_idx = np.arange(big_npy.shape[0])
    np.random.shuffle(big_npy_idx)
    big_npy = big_npy[big_npy_idx]
    if big_npy.shape[0] > 2e5:
        info = "Trying doing kmeans %s shape to 10k centers." % big_npy.shape[0]
        print(info)
        try:
            big_npy = (
                MiniBatchKMeans(
                    n_clusters=10000,
                    verbose=True,
                    batch_size=256 * n_cpu,
                    compute_labels=False,
                    init="random",
                )
                .fit(big_npy)
                .cluster_centers_
            )
        except Exception:
            # Best effort: fall back to indexing the full feature set.
            info = traceback.format_exc()
            print(info)
    # At least one inverted list: the N // 39 term is 0 for fewer than 39 rows,
    # which would produce the invalid factory string "IVF0,Flat".
    n_ivf = max(1, min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39))
    index = faiss.index_factory(big_npy.shape[1], "IVF%s,Flat" % n_ivf)
    index_ivf = faiss.extract_index_ivf(index)
    index_ivf.nprobe = 1
    index.train(big_npy)
    batch_size_add = 8192
    for i in range(0, big_npy.shape[0], batch_size_add):
        index.add(big_npy[i : i + batch_size_add])
    print("Successfully build index")
    return index
-
-
class HParams():
    """Attribute-style container for (nested) hyper-parameter dictionaries.

    Every keyword argument becomes an instance attribute. Values whose type is
    exactly ``dict`` are wrapped recursively in HParams. Mapping-style helpers
    (``hp["key"]``, ``"key" in hp``, keys/items/values/len/get) operate on the
    instance ``__dict__``.
    """

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            self[name] = HParams(**value) if type(value) is dict else value

    def keys(self):
        return vars(self).keys()

    def items(self):
        return vars(self).items()

    def values(self):
        return vars(self).values()

    def __len__(self):
        return len(vars(self))

    def __getitem__(self, key):
        # Attribute lookup: a missing key raises AttributeError, not KeyError.
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in vars(self)

    def __repr__(self):
        return repr(vars(self))

    def get(self,index):
        # Returns None for missing keys.
        return vars(self).get(index)
-
-
class InferHParams(HParams):
    """HParams variant for inference: missing non-special attributes read as None.

    Nested plain dicts are wrapped as InferHParams as well.
    """
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if type(v) == dict:
                v = InferHParams(**v)
            self[k] = v

    def __getattr__(self,index):
        # Only reached when normal attribute lookup fails. Dunder names must keep
        # raising AttributeError: returning None for e.g. __setstate__ makes
        # hasattr() true, and copy.deepcopy / pickle then try to call None.
        if index.startswith("__") and index.endswith("__"):
            raise AttributeError(index)
        return self.get(index)
-
-
class Volume_Extractor:
    """Frame-wise RMS volume, one value per hop_size samples."""

    def __init__(self, hop_size = 512):
        self.hop_size = hop_size

    def extract(self, audio):
        """Return the RMS of each hop_size-sample frame of the first row of *audio*.

        audio: 2-D tensor or array-like (converted with torch.Tensor), indexed as
            (rows, samples) by the [:, None, None, :] reshape. Only row 0's
            volume is returned.
        The squared signal is reflect-padded by hop_size // 2 on the left and
        (hop_size + 1) // 2 on the right. The result has
        ``audio.size(-1) // hop_size`` frames.
        """
        signal = audio if isinstance(audio, torch.Tensor) else torch.Tensor(audio)
        hop = self.hop_size
        frame_count = int(signal.size(-1) // hop)
        padding = (int(hop // 2), int((hop + 1) // 2))
        power = torch.nn.functional.pad(signal ** 2, padding, mode = 'reflect')
        windows = torch.nn.functional.unfold(power[:, None, None, :], (1, hop), stride=hop)
        mean_power = windows[:, :, :frame_count].mean(dim=1)[0]
        return torch.sqrt(mean_power)
|