#7 上传文件至 ''

Open
dfh123157750 wants to merge 1 commits from dfh123157750-patch-5 into master
  1. +180
    -0
      data_test.py

+ 180
- 0
data_test.py View File

@@ -0,0 +1,180 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import json
import math
import os
import argparse
import librosa
import numpy as np
parser = argparse.ArgumentParser(
"Dual-path transformer"
"with Permutation Invariant Training")
parser.add_argument('--train_dir', type=str, default='/home/heu_MEDAI/RenQQ/6.24/src/out/tr',
help='directory including mix.json, s1.json and s2.json')
parser.add_argument('--valid_dir', type=str, default='/mass_data/dataset/LS-2mix/Libri2Mix/cv',
help='directory including mix.json, s1.json and s2.json')
parser.add_argument('--batch_size', default=4, type=int, #default =3
help='Batch size')
parser.add_argument('--sample_rate', default=8000, type=int,
help='Sample rate')
parser.add_argument('--segment', default=4, type=float,
help='Segment length (seconds)')
def load_mixtures_and_sources(batch):
"""
Each info include wav path and wav duration.
Returns:
mixtures: a list containing B items, each item is T np.ndarray
sources: a list containing B items, each item is T x C np.ndarray
T varies from item to item.
"""
mixtures, sources = [], []
mix_infos, s1_infos, s2_infos, sample_rate, segment_len = batch
# for each utterance
for mix_info, s1_info, s2_info in zip(mix_infos, s1_infos, s2_infos):
mix_path = mix_info[0]
s1_path = s1_info[0]
s2_path = s2_info[0]
assert mix_info[1] == s1_info[1] and s1_info[1] == s2_info[1]
# read wav file
mix, _ = librosa.load(mix_path, sr=sample_rate)
s1, _ = librosa.load(s1_path, sr=sample_rate)
s2, _ = librosa.load(s2_path, sr=sample_rate)
# merge s1 and s2
s = np.dstack((s1, s2))[0] # T x C, C = 2
utt_len = mix.shape[-1]
if segment_len >= 0:
# segment
for i in range(0, utt_len - segment_len + 1, segment_len):
mixtures.append(mix[i:i+segment_len])
sources.append(s[i:i+segment_len])
if utt_len % segment_len != 0:
mixtures.append(mix[-segment_len:])
sources.append(s[-segment_len:])
else: # full utterance
mixtures.append(mix)
sources.append(s)
return mixtures, sources
def pad_list(xs):
n_batch = len(xs)
max_len = max(x.shape for x in xs)
if len(max_len) == 1:
pad = np.zeros((n_batch, max_len[0]), np.float32)
else:
pad = np.zeros((n_batch, max_len[0], max_len[1]), np.float32)
for i in range(n_batch):
temp = xs[i].shape
pad[i, :temp[0]] = xs[i]
return pad
class DatasetGenerator:
def __init__(self, json_dir, batch_size, sample_rate=8000, segment=4.0, cv_maxlen=8.0):
"""
Args:
json_dir: directory including mix.json, s1.json and s2.json
segment: duration of audio segment, when set to -1, use full audio
xxx_infos is a list and each item is a tuple (wav_file, #samples)
"""
super(DatasetGenerator, self).__init__()
mix_json = os.path.join(json_dir, 'mix.json')
s1_json = os.path.join(json_dir, 's1.json')
s2_json = os.path.join(json_dir, 's2.json')
with open(mix_json, 'r') as f:
mix_infos = json.load(f)
with open(s1_json, 'r') as f:
s1_infos = json.load(f)
with open(s2_json, 'r') as f:
s2_infos = json.load(f)
# sort it by #samples (impl bucket)
def sort(infos):
return sorted(infos, key=lambda info: int(info[1]), reverse=True)
sorted_mix_infos = sort(mix_infos)
sorted_s1_infos = sort(s1_infos)
sorted_s2_infos = sort(s2_infos)
# segment length and count dropped utts
segment_len = int(segment * sample_rate) # 4s * 8000/s = 32000 samples
drop_utt, drop_len = 0, 0
for _, sample in sorted_mix_infos:
if sample < segment_len:
drop_utt += 1
drop_len += sample
print("Drop {} utts({:.2f} h) which is short than {} samples".format(
drop_utt, drop_len/sample_rate/36000, segment_len))
# generate minibach infomations
mixture_pad = []
lens = []
source_pad = []
start = 0
while True:
num_segments = 0
end = start
part_mix, part_s1, part_s2 = [], [], []
while num_segments < batch_size and end < len(sorted_mix_infos):
utt_len = int(sorted_mix_infos[end][1])
if utt_len >= segment_len: # skip too short utt
num_segments += math.ceil(utt_len / segment_len)
# Ensure num_segments is less than batch_size
if num_segments > batch_size:
# if num_segments of 1st audio > batch_size, skip it
if start == end: end += 1
break
part_mix.append(sorted_mix_infos[end])
part_s1.append(sorted_s1_infos[end])
part_s2.append(sorted_s2_infos[end])
end += 1
if part_mix:
meta = [part_mix, part_s1, part_s2, sample_rate, segment_len]
mixtures_pad, ilens, sources_pad = self.sort_and_pad(meta)
for i in range(len(mixtures_pad)):
mixture_pad.append(mixtures_pad[i])
lens.append(ilens[i])
source_pad.append(sources_pad[i])
if end == len(sorted_mix_infos):
break
start = end
self.mixture = mixture_pad
self.len = lens
self.sources = source_pad
def __getitem__(self, index):
return (self.mixture[index], self.len[index], self.sources[index])
def __len__(self):
return len(self.mixture)
def sort_and_pad(self, batch):
#assert len(batch) == 1
mixtures, sources = load_mixtures_and_sources(batch)
# get batch of lengths of input sequences
ilens = np.array([mix.shape[0] for mix in mixtures])
mixtures_pad = pad_list([mix for mix in mixtures])
sources_pad = pad_list([s for s in sources])
sources_pad = sources_pad.transpose((0, 2, 1))
return mixtures_pad, ilens, sources_pad

Loading…
Cancel
Save