
Add introduction of SpeechCommands example

master · Yanqi-Chen · 1 month ago · commit 6e3cf264b1

3 changed files with 56 additions and 23 deletions:

  1. docs/source/spikingjelly.clock_driven.examples.rst (+8 -0)
  2. spikingjelly/clock_driven/examples/speechcommands.py (+47 -22)
  3. spikingjelly/datasets/speechcommands.py (+1 -1)

docs/source/spikingjelly.clock_driven.examples.rst (+8 -0)

@@ -124,6 +124,14 @@ spikingjelly.clock\_driven.examples.spiking\_lstm\_text module
    :undoc-members:
    :show-inheritance:
 
+spikingjelly.clock\_driven.examples.speechcommands module
+--------------------------------------------------------------
+
+.. automodule:: spikingjelly.clock_driven.examples.speechcommands
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 Module contents
 ---------------



spikingjelly/clock_driven/examples/speechcommands.py (+47 -22)

@@ -8,6 +8,38 @@ This code reproduces an audio recognition task using convolutional SNN. It provi
 .. note::
 
     To avoid extra dependencies such as `librosa <https://librosa.org/doc/latest/index.html>`_, we implement MelScale ourselves. Two DCT types are provided: Slaney and HTK. The Slaney style is used in the original paper and is applied by default.
+
+Confusion matrix on the TEST set after training (50 epochs):
+
++------------------------+--------------------------------------------------------------------------------------------------+
+| Count                  | Prediction                                                                                       |
+|                        +-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
+|                        | "Yes" | "Stop" | "No" | "Right" | "Up" | "Left" | "On" | "Down" | "Off" | "Go" | Other | Silence |
++--------------+---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
+| Ground Truth | "Yes"   | 234   | 0      | 2    | 0       | 0    | 3      | 0    | 0      | 0     | 1    | 16    | 0       |
+|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
+|              | "Stop"  | 0     | 233    | 0    | 1       | 5    | 0      | 0    | 0      | 0     | 1    | 9     | 0       |
+|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
+|              | "No"    | 0     | 1      | 223  | 1       | 0    | 1      | 0    | 5      | 0     | 9    | 12    | 0       |
+|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
+|              | "Right" | 0     | 0      | 0    | 234     | 0    | 0      | 0    | 0      | 0     | 0    | 24    | 1       |
+|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
+|              | "Up"    | 0     | 4      | 0    | 0       | 249  | 0      | 0    | 0      | 8     | 0    | 11    | 0       |
+|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
+|              | "Left"  | 3     | 1      | 2    | 3       | 1    | 250    | 0    | 0      | 1     | 0    | 6     | 0       |
+|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
+|              | "On"    | 0     | 3      | 0    | 0       | 0    | 0      | 231  | 0      | 2     | 1    | 9     | 0       |
+|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
+|              | "Down"  | 0     | 0      | 7    | 0       | 0    | 1      | 2    | 230    | 0     | 4    | 8     | 1       |
+|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
+|              | "Off"   | 0     | 0      | 2    | 1       | 4    | 2      | 6    | 0      | 237   | 1    | 9     | 0       |
+|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
+|              | "Go"    | 0     | 2      | 5    | 0       | 0    | 2      | 0    | 1      | 5     | 220  | 16    | 0       |
+|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
+|              | Other   | 6     | 21     | 12   | 25      | 22   | 19     | 25   | 14     | 11    | 40   | 4072  | 1       |
+|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
+|              | Silence | 0     | 0      | 0    | 0       | 0    | 0      | 0    | 0      | 0     | 0    | 0     | 260     |
++--------------+---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
 """
 
 import torch
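For reference, the Slaney and HTK styles named in the note differ in how frequencies are mapped to the mel scale. A minimal sketch of the two conversions, assuming the standard constants (the function names here are illustrative, not the module's actual API):

    import math

    def hz_to_mel_htk(f):
        # HTK style: a single logarithmic curve over the whole frequency range.
        return 2595.0 * math.log10(1.0 + f / 700.0)

    def hz_to_mel_slaney(f):
        # Slaney style: linear below 1 kHz, logarithmic above it.
        f_sp = 200.0 / 3.0               # Hz per mel in the linear region
        min_log_hz = 1000.0              # switch-over frequency
        min_log_mel = min_log_hz / f_sp  # mel value at 1 kHz (= 15)
        logstep = math.log(6.4) / 27.0   # log step size above 1 kHz
        if f < min_log_hz:
            return f / f_sp
        return min_log_mel + math.log(f / min_log_hz) / logstep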
@@ -217,12 +249,11 @@ class Pad(object):
 class Rescale(object):
 
     def __call__(self, input):
-        std = torch.std(input, axis=2, keepdims=True, unbiased=False)
+        std = torch.std(input, axis=2, keepdims=True, unbiased=False)  # unbiased=False matches NumPy's std, which uses the biased estimator; see https://github.com/romainzimmer/s2net/blob/82c38bf80b55d16d12d0243440e34e52d237a2df/data.py#L201
         std.masked_fill_(std == 0, 1)
 
         return input / std
 
 
 def collate_fn(data):
 
     X_batch = torch.cat([d[0] for d in data])
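The comment added to Rescale is worth spelling out: with unbiased=False, torch.std divides by N, matching NumPy's default estimator rather than PyTorch's default of N - 1. A self-contained check of this equivalence (illustrative only, not part of the diff):

    import numpy as np
    import torch

    x = torch.arange(6, dtype=torch.float32).reshape(1, 2, 3)
    t = torch.std(x, axis=2, keepdims=True, unbiased=False)  # divides by N
    n = np.std(x.numpy(), axis=2, keepdims=True)             # NumPy default also divides by N
    assert np.allclose(t.numpy(), n)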
@@ -295,7 +326,7 @@ if __name__ == '__main__':
     parser.add_argument('-sr', '--sample-rate', type=int, default=16000)
     parser.add_argument('-lr', '--learning-rate', type=float, default=1e-2)
     parser.add_argument('-dir', '--dataset-dir', type=str)
-    parser.add_argument('-e', '--epoch', type=int, default=15)
+    parser.add_argument('-e', '--epoch', type=int, default=50)
     args = parser.parse_args()
 
     sr = args.sample_rate
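With this default in place, running the example as `python speechcommands.py -dir <dataset_dir>` now trains for 50 epochs; the sample rate (16000) and learning rate (1e-2) keep their defaults unless overridden.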
@@ -326,17 +357,6 @@ if __name__ == '__main__':
     train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=16,
                                   sampler=train_sampler, collate_fn=collate_fn)
 
-
-    # print(train_dataset[48310][0])
-    # print(pad(train_dataset[48310][0]))
-    # print(spec(pad(train_dataset[48310][0])))
-    # print(melscale(spec(pad(train_dataset[48310][0]))))
-    # print(rescale(melscale(spec(pad(train_dataset[48310][0])))))
-    # exit(0)
-
-    # val_dataset = SPEECHCOMMANDS(label_dict, dataset_dir, url="speech_commands_v0.01", split="val", transform=transform, download=True)
-    # val_dataloader = DataLoader(val_dataset, batch_size=batch_size, num_workers=4, collate_fn=collate_fn)
-
     test_dataset = SPEECHCOMMANDS(
         label_dict, dataset_dir, silence_cnt=260, url="speech_commands_v0.01", split="test", transform=transform, download=True)
     test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=16, collate_fn=collate_fn, shuffle=False,
@@ -388,6 +408,8 @@ if __name__ == '__main__':

         net.eval()
 
+        writer.add_scalar('Train Loss', loss.item(), global_step=net.epochs)
+
         ##### TEST #####
         with torch.no_grad():
             test_sum = 0
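The new add_scalar call logs the training loss to the same TensorBoard writer used for the test metrics below. A minimal sketch of that logging pattern (the log directory and loop are placeholders):

    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter('./logs')  # placeholder log directory
    for epoch in range(3):
        loss = 1.0 / (epoch + 1)      # stand-in for the real training loss
        writer.add_scalar('Train Loss', loss, global_step=epoch)
    writer.close()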
@@ -413,18 +435,21 @@ if __name__ == '__main__':
             pred = torch.cat(pred).cpu().numpy()
             label = torch.cat(label).cpu().numpy()
 
-            # Plot confusion matrix
+            # Confusion matrix
             cmatrix = confusion_matrix(label, pred)
 
-            plt.clf()
-            fig = plt.figure()
-            plt.imshow(cmatrix)
-            writer.add_figure('Confusion Matrix', figure=fig,
-                              global_step=net.epochs)
+            print("Confusion Matrix:")
+            print(cmatrix)
+
+            # plt.clf()
+            # fig = plt.figure()
+            # plt.imshow(cmatrix)
+            # writer.add_figure('Confusion Matrix', figure=fig,
+            #                   global_step=net.epochs)
 
             test_accuracy = correct_sum / test_sum
             writer.add_scalar('Test Acc.', test_accuracy, global_step=net.epochs)
 
             net.epochs += 1
             time_end = time.time()
-            print(
-                f'Test Acc: {test_accuracy} Loss: {loss} Elapse: {time_end - time_start:.2f}s')
+            print(f'Test Acc: {test_accuracy} Loss: {loss} Elapse: {time_end - time_start:.2f}s')
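The printed matrix follows scikit-learn's convention: rows are the ground truth and columns are the predictions, which is also the orientation of the table in the module docstring. A toy example of the call:

    import numpy as np
    from sklearn.metrics import confusion_matrix

    label = np.array([0, 1, 1, 2, 2, 2])  # ground-truth class indices
    pred = np.array([0, 1, 2, 2, 2, 1])   # predicted class indices
    print(confusion_matrix(label, pred))  # rows: truth, columns: prediction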

spikingjelly/datasets/speechcommands.py (+1 -1)

@@ -82,7 +82,7 @@ class SPEECHCOMMANDS(Dataset):

     #. Ten digits 0-9: "Zero", "One", "Two", "Three", "Four", "Five", "Six", "Seven", "Eight", "Nine".
 
-    #. Ten non-keywords, which can be regarded as distractor words: "Bed", "Bird", "Cat", "Dog", "Happy", "House", "Marvin", "Sheila", "Tree", "Wow".
+    #. Ten auxiliary words, which can be regarded as distractor words: "Bed", "Bird", "Cat", "Dog", "Happy", "House", "Marvin", "Sheila", "Tree", "Wow".
 
     Version v0.01 contains 30 classes with 64,727 audio clips in total; v0.02 contains 35 classes with 105,829 clips. For more details, see the paper cited above and the dataset's README.
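For context, the example script builds this dataset along the following lines; the label_dict contents and the dataset path are illustrative placeholders (the transform argument is omitted for brevity), while the remaining keyword arguments mirror the call in speechcommands.py:

    from spikingjelly.datasets.speechcommands import SPEECHCOMMANDS

    # Placeholder mapping from the ten keywords to class indices.
    label_dict = {'yes': 0, 'stop': 1, 'no': 2, 'right': 3, 'up': 4,
                  'left': 5, 'on': 6, 'down': 7, 'off': 8, 'go': 9}
    test_dataset = SPEECHCOMMANDS(label_dict, './SpeechCommands', silence_cnt=260,
                                  url="speech_commands_v0.01", split="test",
                                  download=True)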


