|
|
@@ -0,0 +1,212 @@ |
|
|
|
# Copyright 2022 Huawei Technologies Co., Ltd |
|
|
|
# |
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
# you may not use this file except in compliance with the License. |
|
|
|
# You may obtain a copy of the License at |
|
|
|
# |
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
# |
|
|
|
# Unless required by applicable law or agreed to in writing, software |
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
# limitations under the License. |
|
|
|
# ============================================================================ |
|
|
|
|
|
|
|
import argparse |
|
|
|
import numpy as np |
|
|
|
from model import ConvTasNet |
|
|
|
from mir_eval.separation import bss_eval_sources |
|
|
|
from data_test import DatasetGenerator |
|
|
|
from Loss_final1 import loss |
|
|
|
#from function import get_input_with_list |
|
|
|
import mindspore |
|
|
|
import mindspore.dataset as ds |
|
|
|
import mindspore.ops as ops |
|
|
|
from mindspore import context |
|
|
|
from mindspore import load_checkpoint, load_param_into_net |
|
|
|
|
|
|
|
parser = argparse.ArgumentParser('Evaluate separation performance using TasNet') |
|
|
|
parser.add_argument('--model_path', type=str, |
|
|
|
default=r"/home/heu_MEDAI/RenQQ/6.5relu/src/exp/temp/ConvTasnet_3-10_303.ckpt", |
|
|
|
help='Path to model file created by training') |
|
|
|
parser.add_argument('--data_dir', type=str, |
|
|
|
default=r"/mass_data/dataset/LS-2mix/Libri2Mix/tt", |
|
|
|
help='directory including mix.json, s1.json and s2.json') |
|
|
|
parser.add_argument('--cal_sdr', type=int, default=0, |
|
|
|
help='Whether calculate SDR, add this option because calculation of SDR is very slow') |
|
|
|
parser.add_argument('--use_cuda', type=int, default=0, |
|
|
|
help='Whether use GPU') |
|
|
|
parser.add_argument('--sample_rate', default=8000, type=int, |
|
|
|
help='Sample rate') |
|
|
|
parser.add_argument('--segment', default=4, type=float, # 取音频的长度,2s。#数据集语音长度要相同 |
|
|
|
help='Segment length (seconds)') |
|
|
|
parser.add_argument('--batch_size', default=4, type=int, # 需要抛弃的音频长度 |
|
|
|
help='Batch size') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Network architecture |
|
|
|
parser.add_argument('--N', default=512, type=int,#256 |
|
|
|
help='Number of filters in autoencoder') |
|
|
|
parser.add_argument('--L', default=20, type=int, |
|
|
|
help='Length of the filters in samples (40=5ms at 8kHZ)') |
|
|
|
parser.add_argument('--B', default=256, type=int,#256 |
|
|
|
help='Number of channels in bottleneck 1 × 1-conv block') #1 × 1-conv的通道数 |
|
|
|
parser.add_argument('--H', default=512, type=int,#512 |
|
|
|
help='Number of channels in convolutional blocks') #卷几块的通道数 |
|
|
|
parser.add_argument('--P', default=3, type=int, |
|
|
|
help='Kernel size in convolutional blocks') #卷积核的大小 |
|
|
|
parser.add_argument('--X', default=8, type=int, |
|
|
|
help='Number of convolutional blocks in each repeat') #每次8个卷几块 |
|
|
|
parser.add_argument('--R', default=4, type=int, |
|
|
|
help='Number of repeats') # 重复4次 |
|
|
|
parser.add_argument('--C', default=2, type=int, |
|
|
|
help='Number of speakers') #说话者数量 |
|
|
|
parser.add_argument('--norm_type', default='gLN', type=str, |
|
|
|
choices=['gLN', 'cLN', 'BN'], help='Layer norm type') #归一化的方法 |
|
|
|
parser.add_argument('--causal', type=int, default=0, |
|
|
|
help='Causal (1) or noncausal(0) training') #因果设定 |
|
|
|
parser.add_argument('--mask_nonlinear', default='relu', type=str, |
|
|
|
choices=['relu', 'softmax'], help='non-linear to generate mask') #产生mask的非线性层 |
|
|
|
|
|
|
|
|
|
|
|
def evaluate(args): |
|
|
|
total_SISNRi = 0 |
|
|
|
total_SDRi = 0 |
|
|
|
total_cnt = 0 |
|
|
|
|
|
|
|
# Load model |
|
|
|
model = ConvTasNet(args.N, args.L, args.B, args.H, args.P, args.X, args.R, |
|
|
|
args.C, norm_type=args.norm_type, causal=args.causal, mask_nonlinear=args.mask_nonlinear) |
|
|
|
model.set_train(mode=False) |
|
|
|
param_dict = load_checkpoint(args.model_path) |
|
|
|
load_param_into_net(model, param_dict) |
|
|
|
print(model) |
|
|
|
|
|
|
|
# Load data |
|
|
|
tt_dataset = DatasetGenerator(args.data_dir, args.batch_size, |
|
|
|
sample_rate=args.sample_rate, segment=args.segment) |
|
|
|
tt_loader = ds.GeneratorDataset(tt_dataset, ["mixture", "lens", "sources"], shuffle=False) |
|
|
|
tt_loader = tt_loader.batch(batch_size=8) |
|
|
|
|
|
|
|
for data in tt_loader.create_dict_iterator(): |
|
|
|
padded_mixture = data["mixture"] |
|
|
|
mixture_lengths = data["lens"] |
|
|
|
padded_source = data["sources"] |
|
|
|
padded_mixture = ops.Cast()(padded_mixture, mindspore.float32) |
|
|
|
padded_source = ops.Cast()(padded_source, mindspore.float32) |
|
|
|
# mixture_lengths_with_list = get_input_with_list(args.data_dir) |
|
|
|
mixture_lengths_with_list = mixture_lengths.asnumpy().tolist() |
|
|
|
estimate_source = model(padded_mixture) # [B, C, T] |
|
|
|
|
|
|
|
my_loss = loss() |
|
|
|
estimate_source, reorder_estimate_source = \ |
|
|
|
my_loss(padded_source, estimate_source, mixture_lengths) |
|
|
|
# Remove padding and flat |
|
|
|
mixture = remove_pad(padded_mixture, mixture_lengths_with_list) |
|
|
|
source = remove_pad(padded_source, mixture_lengths_with_list) |
|
|
|
# NOTE: use reorder estimate source |
|
|
|
estimate_source = remove_pad(reorder_estimate_source, |
|
|
|
mixture_lengths_with_list) |
|
|
|
# for each utterance |
|
|
|
for mix, src_ref, src_est in zip(mixture, source, estimate_source): |
|
|
|
print("Utt", total_cnt + 1) |
|
|
|
# Compute SDRi |
|
|
|
if args.cal_sdr: |
|
|
|
avg_SDRi = cal_SDRi(src_ref, src_est, mix) |
|
|
|
total_SDRi += avg_SDRi |
|
|
|
print("\tSDRi={0:.2f}".format(-avg_SDRi)) |
|
|
|
# Compute SI-SNRi |
|
|
|
avg_SISNRi = cal_SISNRi(src_ref, src_est, mix) |
|
|
|
print("\tSI-SNRi={0:.2f}".format(avg_SISNRi)) |
|
|
|
total_SISNRi += avg_SISNRi |
|
|
|
total_cnt += 1 |
|
|
|
if args.cal_sdr: |
|
|
|
print("Average SDR improvement: {0:.2f}".format(-total_SDRi / total_cnt)) |
|
|
|
print("Average SISNR improvement: {0:.2f}".format(-total_SISNRi / total_cnt)) |
|
|
|
|
|
|
|
|
|
|
|
def cal_SDRi(src_ref, src_est, mix): |
|
|
|
"""Calculate Source-to-Distortion Ratio improvement (SDRi). |
|
|
|
NOTE: bss_eval_sources is very very slow. |
|
|
|
Args: |
|
|
|
src_ref: numpy.ndarray, [C, T] |
|
|
|
src_est: numpy.ndarray, [C, T], reordered by best PIT permutation |
|
|
|
mix: numpy.ndarray, [T] |
|
|
|
Returns: |
|
|
|
average_SDRi |
|
|
|
""" |
|
|
|
src_anchor = np.stack([mix, mix], axis=0) |
|
|
|
sdr = bss_eval_sources(src_ref, src_est) |
|
|
|
sdr0 = bss_eval_sources(src_ref, src_anchor) |
|
|
|
avg_SDRi = ((sdr[0]-sdr0[0]) + (sdr[1]-sdr0[1])) / 2 |
|
|
|
# print("SDRi1: {0:.2f}, SDRi2: {1:.2f}".format(sdr[0]-sdr0[0], sdr[1]-sdr0[1])) |
|
|
|
return avg_SDRi |
|
|
|
|
|
|
|
|
|
|
|
def cal_SISNRi(src_ref, src_est, mix): |
|
|
|
"""Calculate Scale-Invariant Source-to-Noise Ratio improvement (SI-SNRi) |
|
|
|
Args: |
|
|
|
src_ref: numpy.ndarray, [C, T] |
|
|
|
src_est: numpy.ndarray, [C, T], reordered by best PIT permutation |
|
|
|
mix: numpy.ndarray, [T] |
|
|
|
Returns: |
|
|
|
average_SISNRi |
|
|
|
""" |
|
|
|
sisnr1 = cal_SISNR(src_ref[0], src_est[0]) |
|
|
|
sisnr2 = cal_SISNR(src_ref[1], src_est[1]) |
|
|
|
sisnr1b = cal_SISNR(src_ref[0], mix) |
|
|
|
sisnr2b = cal_SISNR(src_ref[1], mix) |
|
|
|
avg_SISNRi = ((sisnr1 - sisnr1b) + (sisnr2 - sisnr2b)) / 2 |
|
|
|
return avg_SISNRi |
|
|
|
|
|
|
|
|
|
|
|
def cal_SISNR(ref_sig, out_sig, eps=1e-8): |
|
|
|
"""Calculate Scale-Invariant Source-to-Noise Ratio (SI-SNR) |
|
|
|
Args: |
|
|
|
ref_sig: numpy.ndarray, [T] |
|
|
|
out_sig: numpy.ndarray, [T] |
|
|
|
Returns: |
|
|
|
SISNR |
|
|
|
""" |
|
|
|
assert len(ref_sig) == len(out_sig) |
|
|
|
ref_sig = ref_sig - np.mean(ref_sig) |
|
|
|
out_sig = out_sig - np.mean(out_sig) |
|
|
|
ref_energy = np.sum(ref_sig ** 2) + eps |
|
|
|
proj = np.sum(ref_sig * out_sig) * ref_sig / ref_energy |
|
|
|
noise = out_sig - proj |
|
|
|
ratio = np.sum(proj ** 2) / (np.sum(noise ** 2) + eps) |
|
|
|
sisnr = 10 * np.log(ratio + eps) / np.log(10.0) |
|
|
|
return sisnr |
|
|
|
|
|
|
|
def remove_pad(inputs, inputs_lengths): |
|
|
|
""" |
|
|
|
Args: |
|
|
|
inputs: torch.Tensor, [B, C, T] or [B, T], B is batch size |
|
|
|
inputs_lengths: torch.Tensor, [B] |
|
|
|
Returns: |
|
|
|
results: a list containing B items, each item is [C, T], T varies |
|
|
|
""" |
|
|
|
results = [] |
|
|
|
dim = inputs.ndim |
|
|
|
if dim == 3: |
|
|
|
C = inputs.shape[1] |
|
|
|
# for input, length in zip(inputs, inputs_lengths): |
|
|
|
# if dim == 3: # [B, C, T] |
|
|
|
# results.append(input[:, :length].view(C, -1).asnumpy()) |
|
|
|
# elif dim == 2: # [B, T] |
|
|
|
# results.append(input[:length].view(-1).asnumpy()) |
|
|
|
for i, data in enumerate(inputs): |
|
|
|
if dim == 3: # [B, C, T] |
|
|
|
results.append(data[:, :inputs_lengths[i]].view(C, -1).asnumpy()) |
|
|
|
elif dim == 2: # [B, T] |
|
|
|
results.append(data[:inputs_lengths[i]].view(-1).asnumpy()) |
|
|
|
return results |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend", device_id=4) |
|
|
|
arg = parser.parse_args() |
|
|
|
print(arg) |
|
|
|
evaluate(arg) |
|
|
|
|