import os
from pathlib import Path
from random import choice
from typing import Callable, Dict, Optional, Tuple

import numpy as np
import torch
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torchvision.datasets.utils import (
    download_url,
    extract_archive,
    verify_str_arg,
)

FOLDER_IN_ARCHIVE = "SpeechCommands"
URL = "speech_commands_v0.02"
HASH_DIVIDER = "_nohash_"
EXCEPT_FOLDER = "_background_noise_"
_CHECKSUMS = {
    "https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz":
        "3cd23799cb2bbdec517f1cc028f8d43c",
    "https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz":
        "6b74f3901214cb2c2934e98196829835",
}
VAL_RECORD = "validation_list.txt"
TEST_RECORD = "testing_list.txt"
TRAIN_RECORD = "training_list.txt"


def load_speechcommands_item(relpath: str, path: str) -> Tuple[Tensor, int, str, str, int]:
    filepath = os.path.join(path, relpath)
    # relpath has the form "<label>/<speaker_id>_nohash_<utterance_number>.wav":
    # the folder name is the label, the file name encodes speaker and utterance.
    label, filename = os.path.split(relpath)
    speaker, _ = os.path.splitext(filename)

    speaker_id, utterance_number = speaker.split(HASH_DIVIDER)
    utterance_number = int(utterance_number)

    # Load audio
    waveform, sample_rate = torchaudio.load(filepath)
    return waveform, sample_rate, label, speaker_id, utterance_number
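
# Illustrative example of the parsing above (the file name is hypothetical, but
# follows the dataset's naming scheme): for relpath "yes/0a7c2a8d_nohash_0.wav",
# load_speechcommands_item returns (waveform, 16000, "yes", "0a7c2a8d", 0),
# where waveform is a (1, num_samples) float tensor; all Speech Commands clips
# are sampled at 16 kHz.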


class SPEECHCOMMANDS(Dataset):
    def __init__(self,
                 label_dict: Dict[str, int],
                 root: str,
                 silence_cnt: int = 0,
                 silence_size: int = 16000,
                 transform: Optional[Callable] = None,
                 url: str = URL,
                 split: str = "train",
                 folder_in_archive: str = FOLDER_IN_ARCHIVE,
                 download: bool = False) -> None:
        '''
        :param label_dict: dictionary mapping each folder label to a class index
        :type label_dict: Dict
        :param root: root directory of the dataset
        :type root: str
        :param silence_cnt: number of silence samples to generate
        :type silence_cnt: int, optional
        :param silence_size: length, in samples, of each silence clip
        :type silence_size: int, optional
        :param transform: A function/transform that takes in a raw audio
        :type transform: Callable, optional
        :param url: dataset version, defaults to v0.02
        :type url: str, optional
        :param split: dataset split, one of ``"train", "test", "val"``, defaults to ``"train"``
        :type split: str, optional
        :param folder_in_archive: name of the extracted directory, defaults to ``"SpeechCommands"``
        :type folder_in_archive: str, optional
        :param download: whether to download the dataset, defaults to False
        :type download: bool, optional

        The SpeechCommands dataset, introduced in `Speech Commands: A Dataset for Limited-Vocabulary Speech Recognition <https://arxiv.org/abs/1804.03209>`_, split here according to the provided validation and testing lists; both the v0.01 and v0.02 versions are available.

        The dataset contains audio recordings of three groups of words:

        #. Ten command words: "Yes", "No", "Up", "Down", "Left", "Right", "On", "Off", "Stop", "Go". v0.02 adds five more: "Forward", "Backward", "Follow", "Learn", "Visual".

        #. Ten digits from 0 to 9: "Zero", "One", "Two", "Three", "Four", "Five", "Six", "Seven", "Eight", "Nine".

        #. Ten auxiliary words that can be regarded as distractors: "Bed", "Bird", "Cat", "Dog", "Happy", "House", "Marvin", "Sheila", "Tree", "Wow".

        v0.01 contains 64,727 audio clips over 30 classes; v0.02 contains 105,829 clips over 35 classes. See the paper above and the dataset's README for details.

        The implementation is based on torchaudio with extended functionality, and also draws on `the original paper's implementation <https://github.com/romainzimmer/s2net/blob/b073f755e70966ef133bbcd4a8f0343354f5edcd/data.py>`_.
        '''

        self.split = verify_str_arg(split, "split", ("train", "val", "test"))
        self.label_dict = label_dict
        self.transform = transform
        self.silence_cnt = silence_cnt
        self.silence_size = silence_size

        if silence_cnt < 0:
            raise ValueError(f"Invalid silence_cnt parameter: {silence_cnt}")
        if silence_size <= 0:
            raise ValueError(f"Invalid silence_size parameter: {silence_size}")

        if url in [
            "speech_commands_v0.01",
            "speech_commands_v0.02",
        ]:
            base_url = "https://storage.googleapis.com/download.tensorflow.org/data/"
            ext_archive = ".tar.gz"

            # Plain concatenation: os.path.join would insert a backslash on
            # Windows, producing an invalid URL and a failed checksum lookup.
            url = base_url + url + ext_archive

        basename = os.path.basename(url)
        archive = os.path.join(root, basename)

        basename = basename.rsplit(".", 2)[0]  # strip ".tar.gz"
        folder_in_archive = os.path.join(folder_in_archive, basename)

        self._path = os.path.join(root, folder_in_archive)

        if download:
            if not os.path.isdir(self._path):
                if not os.path.isfile(archive):
                    checksum = _CHECKSUMS.get(url, None)
                    download_url(url, root, md5=checksum)
                extract_archive(archive, self._path)
        elif not os.path.isdir(self._path):
            raise FileNotFoundError("Audio data not found. Please specify \"download=True\" and try again.")

        # Background noise recordings, from which silence samples are cropped.
        # This glob must run after the download above; otherwise it would be
        # evaluated against a not-yet-extracted directory and come back empty.
        self.noise_list = sorted(str(p) for p in Path(self._path).glob('_background_noise_/*.wav'))

        if self.split == "train":
            record = os.path.join(self._path, TRAIN_RECORD)
            if os.path.exists(record):
                with open(record, 'r') as f:
                    self._walker = [line.rstrip('\n') for line in f]
            else:
                # The archive only ships validation_list.txt and testing_list.txt;
                # the training list is everything else, generated once and cached.
                print("No training list, generating...")
                walker = sorted(str(p) for p in Path(self._path).glob('*/*.wav'))
                walker = filter(lambda w: HASH_DIVIDER in w and EXCEPT_FOLDER not in w, walker)
                walker = map(lambda w: os.path.relpath(w, self._path), walker)

                walker = set(walker)

                val_record = os.path.join(self._path, VAL_RECORD)
                with open(val_record, 'r') as f:
                    val_walker = {line.rstrip('\n') for line in f}

                test_record = os.path.join(self._path, TEST_RECORD)
                with open(test_record, 'r') as f:
                    test_walker = {line.rstrip('\n') for line in f}

                # Sorted so the generated list is deterministic across runs.
                self._walker = sorted(walker - val_walker - test_walker)

                with open(record, 'w') as f:
                    f.write('\n'.join(self._walker))

                print("Training list generated!")

            # Per-sample weights for torch.utils.data.WeightedRandomSampler: each
            # sample is weighted by the reciprocal of its class count, so every
            # class receives equal total probability mass. This assumes label_dict
            # maps every folder to a contiguous index starting at 0.
            labels = [self.label_dict.get(os.path.split(relpath)[0]) for relpath in self._walker]
            label_weights = 1. / np.unique(labels, return_counts=True)[1]
            if self.silence_cnt == 0:
                label_weights /= np.sum(label_weights)
                self.weights = torch.DoubleTensor([label_weights[label] for label in labels])
            else:
                # Silence is treated as one extra class of silence_cnt samples.
                silence_weight = 1. / self.silence_cnt
                total_weight = np.sum(label_weights) + silence_weight
                label_weights /= total_weight
                self.weights = torch.DoubleTensor([label_weights[label] for label in labels]
                                                  + [silence_weight / total_weight] * self.silence_cnt)

        else:
            if self.split == "val":
                record = os.path.join(self._path, VAL_RECORD)
            else:
                record = os.path.join(self._path, TEST_RECORD)
            with open(record, 'r') as f:
                self._walker = [line.rstrip('\n') for line in f]

    def __getitem__(self, n: int) -> Tuple[Tensor, int]:
        if n < len(self._walker):
            fileid = self._walker[n]
            waveform, sample_rate, label, speaker_id, utterance_number = load_speechcommands_item(fileid, self._path)
        else:
            # Indices past the file list map to silence samples, which are
            # randomly and dynamically cropped from the noise recordings.

            # Load random noise
            noisepath = choice(self.noise_list)
            waveform, sample_rate = torchaudio.load(noisepath)

            # Random crop; the "+ 1" keeps the offset valid when the noise clip
            # is exactly silence_size samples long (clips are assumed to be at
            # least that long).
            offset = np.random.randint(0, waveform.shape[1] - self.silence_size + 1)
            waveform = waveform[:, offset:offset + self.silence_size]
            label = "_silence_"

        # Peak-normalize before the optional transform.
        m = waveform.abs().max()
        if m > 0:
            waveform /= m
        if self.transform is not None:
            waveform = self.transform(waveform)

        label = self.label_dict.get(label)
        return waveform, label

    def __len__(self) -> int:
        # Silence samples are virtual, so they extend the length beyond the
        # on-disk file list.
        return len(self._walker) + self.silence_cnt
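

# A minimal usage sketch (an illustration, not part of the module): it shows how
# the per-sample `weights` computed for the training split pair with
# WeightedRandomSampler to draw class-balanced batches. The word list matches
# the 35 v0.02 keyword folders; the root path, silence_cnt, and batch size are
# illustrative assumptions.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, WeightedRandomSampler

    # label_dict must assign a contiguous index, starting at 0, to every folder
    # in the archive, with "_silence_" appended when silence_cnt > 0.
    words = [
        "yes", "no", "up", "down", "left", "right", "on", "off", "stop", "go",
        "forward", "backward", "follow", "learn", "visual",
        "zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine",
        "bed", "bird", "cat", "dog", "happy", "house", "marvin", "sheila", "tree", "wow",
    ]
    label_dict = {word: i for i, word in enumerate(words)}
    label_dict["_silence_"] = len(words)

    train_set = SPEECHCOMMANDS(
        label_dict=label_dict,
        root="./data",     # hypothetical download location
        silence_cnt=2000,  # number of on-the-fly silence samples
        split="train",
        download=True,
    )

    # One weight per sample (files + silence); sampling with replacement
    # according to these weights balances the classes within an epoch. Note
    # that clips vary in length, so a transform that pads or crops them to a
    # fixed size is typically needed before batching with the default collate.
    sampler = WeightedRandomSampler(train_set.weights, num_samples=len(train_set))
    loader = DataLoader(train_set, batch_size=32, sampler=sampler)

    waveform, label = train_set[0]
    print(waveform.shape, label)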