#1 master

Merged
junan merged 167 commits from OpenI/spikingjelly:master into master 1 year ago
  1. +62 -13 README.md
  2. +58 -8 README_cn.md
  3. +2 -0 bugs.md
  4. BIN docs/source/_static/API/clock_driven/lava_exchange/step_quantize.pdf
  5. +2016 -0 docs/source/_static/API/clock_driven/lava_exchange/step_quantize.svg
  6. BIN docs/source/_static/API/clock_driven/surrogate/QPseudoSpike.pdf
  7. +2194 -0 docs/source/_static/API/clock_driven/surrogate/QPseudoSpike.svg
  8. BIN docs/source/_static/API/clock_driven/surrogate/S2NN.pdf
  9. +2379 -0 docs/source/_static/API/clock_driven/surrogate/S2NN.svg
  10. +50 -5 docs/source/clock_driven/13_neuromorphic_datasets.rst
  11. +387 -0 docs/source/clock_driven/17_loihi_sim.rst
  12. +157 -179 docs/source/clock_driven/5_ann2snn.rst
  13. +2 -2 docs/source/clock_driven_en/0_neuron.rst
  14. +51 -3 docs/source/clock_driven_en/13_neuromorphic_datasets.rst
  15. +305 -352 docs/source/clock_driven_en/5_ann2snn.rst
  16. +1 -1 docs/source/conf.py
  17. +13 -7 docs/source/index.rst
  18. +10 -0 docs/source/spikingjelly.clock_driven.lava_exchange.rst
  19. +1 -0 docs/source/spikingjelly.clock_driven.rst
  20. +17 -1 docs/source/spikingjelly.datasets.rst
  21. +25 -11 publications.md
  22. +0 -1 requirements.txt
  23. +1 -1 setup.py
  24. +2 -476 spikingjelly/clock_driven/ann2snn/__init__.py
  25. +107 -0 spikingjelly/clock_driven/ann2snn/converter.py
  26. +0 -175 spikingjelly/clock_driven/ann2snn/examples/cnn_fashionmnist.py
  27. +132 -198 spikingjelly/clock_driven/ann2snn/examples/cnn_mnist.py
  28. +0 -0 spikingjelly/clock_driven/ann2snn/examples/model_sample/__init__.py
  29. +0 -0 spikingjelly/clock_driven/ann2snn/examples/model_sample/cifar10/__init__.py
  30. +0 -70 spikingjelly/clock_driven/ann2snn/examples/model_sample/cifar10/vgg.py
  31. +0 -0 spikingjelly/clock_driven/ann2snn/examples/model_sample/imagenet/__init__.py
  32. +0 -339 spikingjelly/clock_driven/ann2snn/examples/model_sample/imagenet/resnet.py
  33. +49 -74 spikingjelly/clock_driven/ann2snn/examples/resnet18_cifar10.py
  34. +0 -170 spikingjelly/clock_driven/ann2snn/examples/utils.py
  35. +0 -0 spikingjelly/clock_driven/ann2snn/kernels/__init__.py
  36. +0 -1215 spikingjelly/clock_driven/ann2snn/kernels/onnx.py
  37. +0 -127 spikingjelly/clock_driven/ann2snn/kernels/pytorch.py
  38. +87 -126 spikingjelly/clock_driven/ann2snn/modules.py
  39. +12 -6 spikingjelly/clock_driven/ann2snn/sample_models/cifar10_resnet.py
  40. +28 -0 spikingjelly/clock_driven/ann2snn/sample_models/mnist_cnn.py
  41. +29 -0 spikingjelly/clock_driven/ann2snn/utils.py
  42. +177 -65 spikingjelly/clock_driven/cu_kernel_opt.py
  43. +9 -9 spikingjelly/clock_driven/encoding.py
  44. +1 -1 spikingjelly/clock_driven/examples/DQN_state.py
  45. +9 -9 spikingjelly/clock_driven/examples/Spiking_DQN_state.py
  46. +16 -3 spikingjelly/clock_driven/examples/lif_fc_mnist.py
  47. +30 -29 spikingjelly/clock_driven/functional.py
  48. +303 -0 spikingjelly/clock_driven/lava_exchange.py
  49. +245 -41 spikingjelly/clock_driven/layer.py
  50. +1 -3 spikingjelly/clock_driven/model/parametric_lif_net.py
  51. +5 -2 spikingjelly/clock_driven/model/sew_resnet.py
  52. +4 -1 spikingjelly/clock_driven/model/spiking_resnet.py
  53. +12 -10 spikingjelly/clock_driven/model/spiking_vgg.py
  54. +565 -98 spikingjelly/clock_driven/neuron.py
  55. +13750 -1635 spikingjelly/clock_driven/neuron_kernel.cu
  56. +29 -2 spikingjelly/clock_driven/neuron_kernel.md
  57. +1567 -1423 spikingjelly/clock_driven/neuron_kernel.py
  58. +506 -0 spikingjelly/clock_driven/spike_op.py
  59. +333 -11 spikingjelly/clock_driven/surrogate.py
  60. +212 -0 spikingjelly/clock_driven/tensor_cache.py
  61. +28 -3 spikingjelly/configure.py
  62. +116 -119 spikingjelly/datasets/__init__.py
  63. +9 -68 spikingjelly/datasets/asl_dvs.py
  64. +10 -67 spikingjelly/datasets/cifar10_dvs.py
  65. +9 -69 spikingjelly/datasets/dvs128_gesture.py
  66. +217 -0 spikingjelly/datasets/es_imagenet.py
  67. +10 -69 spikingjelly/datasets/n_caltech101.py
  68. +9 -71 spikingjelly/datasets/n_mnist.py
  69. +331 -0 spikingjelly/datasets/nav_gesture.py
  70. +0 -1 spikingjelly/datasets/speechcommands.py

+62 -13 README.md

@@ -26,7 +26,7 @@ Note that SpikingJelly is based on PyTorch. Please make sure that you have insta

The odd version number is the developing version, which is updated with GitHub/OpenI repository. The even version number is the stable version and available at PyPI.

**Install the latest stable version (0.0.0.0.8) from** [**PyPI**](https://pypi.org/project/spikingjelly/):
**Install the latest stable version from** [**PyPI**](https://pypi.org/project/spikingjelly/):

```bash
pip install spikingjelly
@@ -40,7 +40,7 @@ git clone https://github.com/fangwei123456/spikingjelly.git
cd spikingjelly
python setup.py install
```
From [OpenI](https://git.openi.org.cn/OpenI/spikingjelly)
From [OpenI](https://git.openi.org.cn/OpenI/spikingjelly):
```bash
git clone https://git.openi.org.cn/OpenI/spikingjelly.git
cd spikingjelly
@@ -80,7 +80,7 @@ Read [spikingjelly.clock_driven.examples](https://spikingjelly.readthedocs.io/zh

## Fast And Handy ANN-SNN Conversion

SpikingJelly implements a relatively general ANN-SNN Conversion interface. Users can realize the conversion through PyTorch or ONNX packages. What's more, users can customize the conversion module to add to the conversion.
SpikingJelly implements a relatively general ANN-SNN Conversion interface. Users can realize the conversion through PyTorch. What's more, users can customize the conversion mode.

```python
class ANN(nn.Module):
@@ -103,8 +103,7 @@ class ANN(nn.Module):
nn.AvgPool2d(2, 2),

nn.Flatten(),
nn.Linear(32, 10),
nn.ReLU()
nn.Linear(32, 10)
)

def forward(self,x):
@@ -112,7 +111,7 @@ class ANN(nn.Module):
return x
```

This simple network with analog encoding can achieve 98.51% accuracy after conversion on the MNIST test dataset. Read [the tutorial of ann2snn](https://spikingjelly.readthedocs.io/zh_CN/latest/clock_driven/5_ann2snn.html) for more details. You can also run this code in a Python terminal to train the ANN on MNIST and evaluate the converted model:
This simple network with analog encoding can achieve 98.44% accuracy after conversion on the MNIST test dataset. Read [the tutorial of ann2snn](https://spikingjelly.readthedocs.io/zh_CN/latest/clock_driven/5_ann2snn.html) for more details. You can also run this code in a Python terminal to train the ANN on MNIST and evaluate the converted model:

```python
>>> import spikingjelly.clock_driven.ann2snn.examples.cnn_mnist as cnn_mnist
@@ -146,21 +145,71 @@ As simple as using PyTorch.
## Neuromorphic Datasets Supports
SpikingJelly includes the following neuromorphic datasets:

| Dataset | Source |
| Dataset | Source |
| -------------- | ------------------------------------------------------------ |
| ASL-DVS | Graph-based Object Classification for Neuromorphic Vision Sensing |
| CIFAR10-DVS | CIFAR10-DVS: An Event-Stream Dataset for Object Classification |
| DVS128 Gesture | A Low Power, Fully Event-Based Gesture Recognition System |
| N-Caltech101 | Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades |
| N-MNIST | Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades |
| ASL-DVS | [Graph-based Object Classification for Neuromorphic Vision Sensing](https://openaccess.thecvf.com/content_ICCV_2019/html/Bi_Graph-Based_Object_Classification_for_Neuromorphic_Vision_Sensing_ICCV_2019_paper.html) |
| CIFAR10-DVS | [CIFAR10-DVS: An Event-Stream Dataset for Object Classification](https://internal-journal.frontiersin.org/articles/10.3389/fnins.2017.00309/full) |
| DVS128 Gesture | [A Low Power, Fully Event-Based Gesture Recognition System](https://openaccess.thecvf.com/content_cvpr_2017/html/Amir_A_Low_Power_CVPR_2017_paper.html) |
| ES-ImageNet | [ES-ImageNet: A Million Event-Stream Classification Dataset for Spiking Neural Networks](https://www.frontiersin.org/articles/10.3389/fnins.2021.726582/full) |
| N-Caltech101 | [Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades](https://www.frontiersin.org/articles/10.3389/fnins.2015.00437/full) |
| N-MNIST | [Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades](https://www.frontiersin.org/articles/10.3389/fnins.2015.00437/full) |
| Nav Gesture | [Event-Based Gesture Recognition With Dynamic Background Suppression Using Smartphone Computational Capabilities](https://www.frontiersin.org/articles/10.3389/fnins.2020.00275/full) |

Users can use both the origin events data and frames data integrated by SpikingJelly:

```python
import torch
from torch.utils.data import DataLoader
from spikingjelly.datasets import pad_sequence_collate, padded_sequence_mask
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
root_dir = 'D:/datasets/DVS128Gesture'
event_set = DVS128Gesture(root_dir, train=True, data_type='event')
frame_set = DVS128Gesture(root_dir, train=True, data_type='frame', frames_number=20, split_by='number')
event, label = event_set[0]
for k in event.keys():
    print(k, event[k])

# t [80048267 80048277 80048278 ... 85092406 85092538 85092700]
# x [49 55 55 ... 60 85 45]
# y [82 92 92 ... 96 86 90]
# p [1 0 0 ... 1 0 0]
# label 0

fixed_frames_number_set = DVS128Gesture(root_dir, train=True, data_type='frame', frames_number=20, split_by='number')
rand_index = torch.randint(low=0, high=fixed_frames_number_set.__len__(), size=[2])
for i in rand_index:
    frame, label = fixed_frames_number_set[i]
    print(f'frame[{i}].shape=[T, C, H, W]={frame.shape}')

# frame[308].shape=[T, C, H, W]=(20, 2, 128, 128)
# frame[453].shape=[T, C, H, W]=(20, 2, 128, 128)

fixed_duration_frame_set = DVS128Gesture(root_dir, data_type='frame', duration=1000000, train=True)
for i in range(5):
    x, y = fixed_duration_frame_set[i]
    print(f'x[{i}].shape=[T, C, H, W]={x.shape}')

# x[0].shape=[T, C, H, W]=(6, 2, 128, 128)
# x[1].shape=[T, C, H, W]=(6, 2, 128, 128)
# x[2].shape=[T, C, H, W]=(5, 2, 128, 128)
# x[3].shape=[T, C, H, W]=(5, 2, 128, 128)
# x[4].shape=[T, C, H, W]=(7, 2, 128, 128)

train_data_loader = DataLoader(fixed_duration_frame_set, collate_fn=pad_sequence_collate, batch_size=5)
for x, y, x_len in train_data_loader:
    print(f'x.shape=[N, T, C, H, W]={tuple(x.shape)}')
    print(f'x_len={x_len}')
    mask = padded_sequence_mask(x_len)  # mask.shape = [T, N]
    print(f'mask=\n{mask.t().int()}')
    break

# x.shape=[N, T, C, H, W]=(5, 7, 2, 128, 128)
# x_len=tensor([6, 6, 5, 5, 7])
# mask=
# tensor([[1, 1, 1, 1, 1, 1, 0],
# [1, 1, 1, 1, 1, 1, 0],
# [1, 1, 1, 1, 1, 0, 0],
# [1, 1, 1, 1, 1, 0, 0],
# [1, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)
```
More datasets will be included in the future.



+58 -8 README_cn.md

@@ -26,7 +26,7 @@ The documentation of SpikingJelly is written in both Chinese and English: https://spikingjelly.readthedoc

The odd version number is the developing version, updated continuously on GitHub/OpenI. The even version number is the stable version, available from PyPI.

**Install the latest stable version (0.0.0.0.8) from** [**PyPI**](https://pypi.org/project/spikingjelly/):
**Install the latest stable version from** [**PyPI**](https://pypi.org/project/spikingjelly/):

```bash
pip install spikingjelly
@@ -40,7 +40,7 @@ git clone https://github.com/fangwei123456/spikingjelly.git
cd spikingjelly
python setup.py install
```
From [OpenI](https://git.openi.org.cn/OpenI/spikingjelly)
From [OpenI](https://git.openi.org.cn/OpenI/spikingjelly):
```bash
git clone https://git.openi.org.cn/OpenI/spikingjelly.git
cd spikingjelly
@@ -148,19 +148,69 @@ SpikingJelly has included the following datasets:

| Dataset | Source |
| -------------- | ------------------------------------------------------------ |
| ASL-DVS | Graph-based Object Classification for Neuromorphic Vision Sensing |
| CIFAR10-DVS | CIFAR10-DVS: An Event-Stream Dataset for Object Classification |
| DVS128 Gesture | A Low Power, Fully Event-Based Gesture Recognition System |
| N-Caltech101 | Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades |
| N-MNIST | Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades |
| ASL-DVS | [Graph-based Object Classification for Neuromorphic Vision Sensing](https://openaccess.thecvf.com/content_ICCV_2019/html/Bi_Graph-Based_Object_Classification_for_Neuromorphic_Vision_Sensing_ICCV_2019_paper.html) |
| CIFAR10-DVS | [CIFAR10-DVS: An Event-Stream Dataset for Object Classification](https://internal-journal.frontiersin.org/articles/10.3389/fnins.2017.00309/full) |
| DVS128 Gesture | [A Low Power, Fully Event-Based Gesture Recognition System](https://openaccess.thecvf.com/content_cvpr_2017/html/Amir_A_Low_Power_CVPR_2017_paper.html) |
| ES-ImageNet | [ES-ImageNet: A Million Event-Stream Classification Dataset for Spiking Neural Networks](https://www.frontiersin.org/articles/10.3389/fnins.2021.726582/full) |
| N-Caltech101 | [Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades](https://www.frontiersin.org/articles/10.3389/fnins.2015.00437/full) |
| N-MNIST | [Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades](https://www.frontiersin.org/articles/10.3389/fnins.2015.00437/full) |
| Nav Gesture | [Event-Based Gesture Recognition With Dynamic Background Suppression Using Smartphone Computational Capabilities](https://www.frontiersin.org/articles/10.3389/fnins.2020.00275/full) |

Users can easily use either the raw event data or the frame data integrated by SpikingJelly:

```python
import torch
from torch.utils.data import DataLoader
from spikingjelly.datasets import pad_sequence_collate, padded_sequence_mask
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
root_dir = 'D:/datasets/DVS128Gesture'
event_set = DVS128Gesture(root_dir, train=True, data_type='event')
frame_set = DVS128Gesture(root_dir, train=True, data_type='frame', frames_number=20, split_by='number')
event, label = event_set[0]
for k in event.keys():
    print(k, event[k])

# t [80048267 80048277 80048278 ... 85092406 85092538 85092700]
# x [49 55 55 ... 60 85 45]
# y [82 92 92 ... 96 86 90]
# p [1 0 0 ... 1 0 0]
# label 0

fixed_frames_number_set = DVS128Gesture(root_dir, train=True, data_type='frame', frames_number=20, split_by='number')
rand_index = torch.randint(low=0, high=fixed_frames_number_set.__len__(), size=[2])
for i in rand_index:
    frame, label = fixed_frames_number_set[i]
    print(f'frame[{i}].shape=[T, C, H, W]={frame.shape}')

# frame[308].shape=[T, C, H, W]=(20, 2, 128, 128)
# frame[453].shape=[T, C, H, W]=(20, 2, 128, 128)

fixed_duration_frame_set = DVS128Gesture(root_dir, data_type='frame', duration=1000000, train=True)
for i in range(5):
    x, y = fixed_duration_frame_set[i]
    print(f'x[{i}].shape=[T, C, H, W]={x.shape}')

# x[0].shape=[T, C, H, W]=(6, 2, 128, 128)
# x[1].shape=[T, C, H, W]=(6, 2, 128, 128)
# x[2].shape=[T, C, H, W]=(5, 2, 128, 128)
# x[3].shape=[T, C, H, W]=(5, 2, 128, 128)
# x[4].shape=[T, C, H, W]=(7, 2, 128, 128)

train_data_loader = DataLoader(fixed_duration_frame_set, collate_fn=pad_sequence_collate, batch_size=5)
for x, y, x_len in train_data_loader:
    print(f'x.shape=[N, T, C, H, W]={tuple(x.shape)}')
    print(f'x_len={x_len}')
    mask = padded_sequence_mask(x_len)  # mask.shape = [T, N]
    print(f'mask=\n{mask.t().int()}')
    break

# x.shape=[N, T, C, H, W]=(5, 7, 2, 128, 128)
# x_len=tensor([6, 6, 5, 5, 7])
# mask=
# tensor([[1, 1, 1, 1, 1, 1, 0],
# [1, 1, 1, 1, 1, 1, 0],
# [1, 1, 1, 1, 1, 0, 0],
# [1, 1, 1, 1, 1, 0, 0],
# [1, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)
```

More datasets will be included in the future.


+2 -0 bugs.md

@@ -11,4 +11,6 @@ Some fatal bugs and when the bug is fixed are shown in this table. Note that the
| Bug: Cupy backend for spiking neurons, https://github.com/fangwei123456/spikingjelly/issues/106. This bug makes spiking neurons with cupy backend output wrong spikes and voltages. This bug has no influence on release 0.0.0.0.4, which does not use cupy. | 2021-09-16 |
| **Release: 0.0.0.0.8** | 2021-11-21 |
| Bug: MultiStepParametricLIFNode, https://github.com/fangwei123456/spikingjelly/issues/151. This bug makes the gradient of the learnable parameter in MultiStepParametricLIFNode incomplete when backend is cupy. | 2021-12-10 |
| **Release: 0.0.0.0.10** | |
| Bug: When using CuPy with `version >= 10`, CuPy will change `torch.cuda.current_device()` to 0, https://github.com/cupy/cupy/issues/6569. This bug will break training when using Distributed Data Parallel (DDP). | 2022-03-22 |
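A common mitigation (not part of this PR; a minimal sketch assuming a DDP worker that reads `LOCAL_RANK` from the environment, as `torchrun` sets it) is simply to re-select the intended device after code paths that go through CuPy:

```python
import os
import torch

local_rank = int(os.environ.get('LOCAL_RANK', 0))
torch.cuda.set_device(local_rank)

# ... calls that use CuPy >= 10 may silently reset the current device to 0 ...

torch.cuda.set_device(local_rank)  # re-assert the device before the next DDP step
```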


BIN docs/source/_static/API/clock_driven/lava_exchange/step_quantize.pdf

+2016 -0 docs/source/_static/API/clock_driven/lava_exchange/step_quantize.svg (file diff suppressed because it is too large)

BIN docs/source/_static/API/clock_driven/surrogate/QPseudoSpike.pdf

+2194 -0 docs/source/_static/API/clock_driven/surrogate/QPseudoSpike.svg (file diff suppressed because it is too large)

BIN docs/source/_static/API/clock_driven/surrogate/S2NN.pdf

+2379 -0 docs/source/_static/API/clock_driven/surrogate/S2NN.svg (file diff suppressed because it is too large)


+50 -5 docs/source/clock_driven/13_neuromorphic_datasets.rst

@@ -58,9 +58,7 @@ The DVS128 Gesture dataset does not support automatic download, but its ``resource_url_md5()``
Running this code, SpikingJelly will do the following:

#. Check whether the dataset exists; if it does, verify it with an MD5 check and then start extraction, unpacking the raw data into the ``extract`` folder in the same directory
#. Each sample in DVS128 Gesture is a gesture video recorded from a performer under a certain illumination condition. One AER file contains several gestures, and an accompanying csv file annotates

which gesture appears in each time segment of the video. Therefore, a single video file is not one class but a collection of classes. SpikingJelly starts multiple threads to extract each gesture class of each video into its own file
#. Each sample in DVS128 Gesture is a gesture video recorded from a performer under a certain illumination condition. One AER file contains several gestures, and an accompanying csv file annotates which gesture appears in each time segment of the whole video. Therefore, a single video file is not one class but a collection of classes. SpikingJelly starts multiple threads to extract each gesture class of each video into its own file

Below is the command-line output during the run:

@@ -202,6 +200,52 @@ The DVS128 Gesture dataset does not support automatic download, but its ``resource_url_md5()``
.. image:: ../_static/tutorials/clock_driven/13_neuromorphic_datasets/dvsg.*
:width: 100%

Integrating with a fixed time interval
----------------------------
Integrating with a fixed time interval is closer to a real physical system. For example, if we integrate every ``10 ms``, a recording of length ``L ms`` yields ``math.floor(L / 10)`` frames. However,
the samples in a neuromorphic dataset usually have different lengths, so we obtain frame tensors with different numbers of frames. SpikingJelly provides :class:`spikingjelly.datasets.pad_sequence_collate`
and :class:`spikingjelly.datasets.padded_sequence_mask` to conveniently pad and unpad such variable-length data.

Example code:

.. code:: python

import torch
from torch.utils.data import DataLoader
from spikingjelly.datasets import pad_sequence_collate, padded_sequence_mask, dvs128_gesture
root='D:/datasets/DVS128Gesture'
train_set = dvs128_gesture.DVS128Gesture(root, data_type='frame', duration=1000000, train=True)
for i in range(5):
    x, y = train_set[i]
    print(f'x[{i}].shape=[T, C, H, W]={x.shape}')
train_data_loader = DataLoader(train_set, collate_fn=pad_sequence_collate, batch_size=5)
for x, y, x_len in train_data_loader:
    print(f'x.shape=[N, T, C, H, W]={tuple(x.shape)}')
    print(f'x_len={x_len}')
    mask = padded_sequence_mask(x_len)  # mask.shape = [T, N]
    print(f'mask=\n{mask.t().int()}')
    break

The output is:

.. code:: bash

The directory [D:/datasets/DVS128Gesture\duration_1000000] already exists.
x[0].shape=[T, C, H, W]=(6, 2, 128, 128)
x[1].shape=[T, C, H, W]=(6, 2, 128, 128)
x[2].shape=[T, C, H, W]=(5, 2, 128, 128)
x[3].shape=[T, C, H, W]=(5, 2, 128, 128)
x[4].shape=[T, C, H, W]=(7, 2, 128, 128)
x.shape=[N, T, C, H, W]=(5, 7, 2, 128, 128)
x_len=tensor([6, 6, 5, 5, 7])
mask=
tensor([[1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)


Custom integration method
-----------------------
SpikingJelly also supports user-defined integration methods. The user only needs to provide the integration function ``custom_integrate_function`` and the name of the folder ``custom_integrated_frames_dir_name`` for saving the frames.
@@ -220,8 +264,9 @@ The DVS128 Gesture dataset does not support automatic download, but its ``resource_url_md5()``
def integrate_events_to_2_frames_randomly(events: Dict, H: int, W: int):
    index_split = np.random.randint(low=0, high=events['t'].__len__())
    frames = np.zeros([2, 2, H, W])
    frames[0] = sjds.integrate_events_segment_to_frame(events, H, W, 0, index_split)
    frames[1] = sjds.integrate_events_segment_to_frame(events, H, W, index_split, events['t'].__len__())
    t, x, y, p = (events[key] for key in ('t', 'x', 'y', 'p'))
    frames[0] = sjds.integrate_events_segment_to_frame(x, y, p, H, W, 0, index_split)
    frames[1] = sjds.integrate_events_segment_to_frame(x, y, p, H, W, index_split, events['t'].__len__())
    return frames

Next, create the dataset:


+387 -0 docs/source/clock_driven/17_loihi_sim.rst

@@ -0,0 +1,387 @@
Loihi Simulation
======================================

Author of this tutorial: `fangwei123456 <https://github.com/fangwei123456>`_

Behavior of Blocks in the LAVA-DL framework
-----------------------------------------------------

`lava.lib.dl.slayer.block` wraps a synapse and a neuron into a single Block. A Loihi simulation based on Blocks follows this workflow:

1. Use Blocks to export the network definition to hdf5
2. Load the network and convert it to a Process in LAVA
3. Simulate the Process with the Loihi simulator provided by LAVA

Blocks exist for Loihi simulation; a Block is not a simple wrapper around two modules like `nn.Sequential`, but has more complex behavior.

From our analysis of the source code, the conclusions are:

In `slayer.block`:

- `p_scale = 1 << 12`

- `w_scale = scale`

- `s_scale = scale * (1 << 6)`

- Unless `pre_hook_fx = None` or some other specific function is given, `self.synapse.weight` is quantized and then clamped, so its final values are `2k / w_scale, k = -128, -127, ..., 127`, i.e. 256 possible values

- `p_scale = 1 << 12, self.neuron.current_decay = int(p_scale * current_decay), self.neuron.voltage_decay = int(p_scale * voltage_decay)`,
  but when the decay is computed, the decayed value is restored by `right_shift_to_zero(x, bits=12)`

- `self.threshold = int(threshold * w_scale) / w_scale`

- When the neuronal dynamics are computed, `x, self.current_state, self.voltage_state, self.threshold` are first multiplied by `s_scale`, and the final output is divided by `s_scale` to restore the scale


The rest of this section walks through the source-code analysis; readers who are not interested can skip it.

We take `slayer.block.Dense` as an example to describe this behavior.


The parameters of `slayer.block.Dense` are described as follows:

- neuron_params (dict, optional) –- a dictionary of CUBA LIF neuron parameter. Defaults to None.

- in_neurons (int) –- number of input neurons.

- out_neurons (int) –- number of output neurons.

- weight_scale (int, optional) –- weight initialization scaling. Defaults to 1.

- weight_norm (bool, optional) –- flag to enable weight normalization. Defaults to False.

- pre_hook_fx (optional) –- a function pointer or lambda that is applied to synaptic weights before synaptic operation. None means no transformation. Defaults to None.

- delay (bool, optional) -– flag to enable axonal delay. Defaults to False.

- delay_shift (bool, optional) –- flag to simulate spike propagation delay from one layer to next. Defaults to True.

- mask (bool array, optional) -– boolean synapse mask that only enables relevant synapses. None means no masking is applied. Defaults to None.

- count_log (bool, optional) -– flag to return event count log. If True, an additional value of average event rate is returned. Defaults to False.

The forward pass of `slayer.block.Dense` is:

`x` -> `synapse` -> `neuron`

Quantization in the synapse
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the forward pass of `synapse`, the weight is transformed before the computation:

.. code-block:: python

# lava\lib\dl\slayer\synapse\layer.py
class Dense(torch.torch.nn.Conv3d, GenericLayer):
    def forward(self, input):
        # ...
        if self._pre_hook_fx is None:
            weight = self.weight
        else:
            weight = self._pre_hook_fx(self.weight)
        # ...

From the constructor of `slayer.block.Dense`:

.. code-block:: python

# lava\lib\dl\slayer\block\cuba.py
class Dense(AbstractCuba, base.AbstractDense):
    def __init__(self, *args, **kwargs):
        super(Dense, self).__init__(*args, **kwargs)
        self.synapse = synapse.Dense(**self.synapse_params)
        if 'pre_hook_fx' not in kwargs.keys():
            self.synapse.pre_hook_fx = self.neuron.quantize_8bit
        del self.synapse_params

We can see that, unless 'pre_hook_fx' is explicitly specified, `self.synapse.pre_hook_fx = self.neuron.quantize_8bit`.
Therefore, the synapse in `slayer.block.Dense` is quantized by default.

Let us look at what the quantization function actually does:

.. code-block:: python

# lava\lib\dl\slayer\neuron\base.py
class Neuron(torch.nn.Module):
    def quantize_8bit(self, weight, descale=False):
        if descale is False:
            return quantize(
                weight, step=2 / self.w_scale
            ).clamp(-256 / self.w_scale, 255 / self.w_scale)
        else:
            return quantize(
                weight, step=2 / self.w_scale
            ).clamp(-256 / self.w_scale, 255 / self.w_scale) * self.w_scale

# lava\lib\dl\slayer\utils\quantize.py
class _quantize(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, step=1):
        return torch.round(input / step) * step

    @staticmethod
    def backward(ctx, gradOutput):
        return gradOutput, None

def quantize(input, step=1):
    return _quantize.apply(input, step)


A schematic of this quantization function is shown in
`spikingjelly.clock_driven.lava_exchange.step_quantize <https://spikingjelly.readthedocs.io/zh_CN/latest/spikingjelly.clock_driven.lava_exchange.html#spikingjelly.clock_driven.lava_exchange.step_quantize>`_:

.. image:: ../_static/API/clock_driven/lava_exchange/step_quantize.*
:width: 100%

We can see that `self.synapse.weight` is quantized with `step = 2 / self.neuron.w_scale` and then clamped to `[-256 / self.neuron.w_scale, 255 / self.neuron.w_scale]`.
The quantized `self.synapse.weight` therefore takes the values `2k / self.neuron.w_scale, k = -128, -127, ..., 127`, i.e. 256 values, which is an 8-bit quantization.
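As a quick numeric check of this 8-bit weight quantization (a minimal sketch, not from LAVA or this PR; it only re-applies the quoted `quantize`/`clamp` steps with an assumed `w_scale = 64`):

.. code-block:: python

    import torch

    w_scale = 64  # assumption: scale = 1 << 6
    step = 2 / w_scale
    w = torch.tensor([-5.0, -0.033, 0.0, 0.05, 2.0])
    w_q = (torch.round(w / step) * step).clamp(-256 / w_scale, 255 / w_scale)
    print(w_q * w_scale / 2)  # integer levels k: tensor([-128., -1., 0., 2., 64.])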


Quantization in the neuronal dynamics
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the forward pass of `neuron`, the neuronal dynamics are computed first (in LAVA the reset is fused into the dynamics), and then spikes are generated:

.. code-block:: python

# lava\lib\dl\slayer\neuron\cuba.py
class Neuron(base.Neuron):
    def forward(self, input):
        _, voltage = self.dynamics(input)
        return self.spike(voltage)

The neuronal dynamics mainly compute the current and the voltage. Their decay factors are `self.current_decay` and `self.voltage_decay`, which are scaled once at initialization:

.. code-block:: python

# lava\lib\dl\slayer\neuron\cuba.py
class Neuron(base.Neuron):
    def __init__(
        self, threshold, current_decay, voltage_decay,
        tau_grad=1, scale_grad=1, scale=1 << 6,
        norm=None, dropout=None,
        shared_param=True, persistent_state=False, requires_grad=False,
        graded_spike=False
    ):
        super(Neuron, self).__init__(
            threshold=threshold,
            tau_grad=tau_grad,
            scale_grad=scale_grad,
            p_scale=1 << 12,
            w_scale=scale,
            s_scale=scale * (1 << 6),
            norm=norm,
            dropout=dropout,
            persistent_state=persistent_state,
            shared_param=shared_param,
            requires_grad=requires_grad
        )
        # ...
        self.register_parameter(
            'current_decay',
            torch.nn.Parameter(
                torch.FloatTensor([self.p_scale * current_decay]),
                requires_grad=self.requires_grad,
            )
        )
        self.register_parameter(
            'voltage_decay',
            torch.nn.Parameter(
                torch.FloatTensor([self.p_scale * voltage_decay]),
                requires_grad=self.requires_grad,
            )
        )
        # ...

Therefore, their actual values are not the `current_decay` and `voltage_decay` given at construction, but these values multiplied by `self.p_scale`, i.e. `1 << 12`.

When they are used in the dynamics, they are quantized once more by the `quantize` function:

.. code-block:: python

# lava\lib\dl\slayer\neuron\cuba.py
class Neuron(base.Neuron):
    def dynamics(self, input):
        # ...
        # clamp the values only when learning is enabled
        # This means we don't need to clamp the values after gradient update.
        # It is done in runtime now. Might be slow, but overhead is negligible.
        if self.requires_grad is True:
            self.clamp()

        current = leaky_integrator.dynamics(
            input,
            quantize(self.current_decay),
            self.current_state.contiguous(),
            self.s_scale,
            debug=self.debug
        )

        voltage = leaky_integrator.dynamics(
            current,  # bias can be enabled by adding it here
            quantize(self.voltage_decay),
            self.voltage_state.contiguous(),
            self.s_scale,
            self.threshold,
            debug=self.debug
        )
        # ...

During training, `self.clamp()` is called before every forward pass to clamp the decays:

.. code-block:: python

# lava\lib\dl\slayer\neuron\cuba.py
def clamp(self):
    """A function to clamp the sin decay and cosine decay parameters to be
    within valid range. The user will generally not need to call this
    function.
    """
    with torch.no_grad():
        self.current_decay.data.clamp_(0, self.p_scale)
        self.voltage_decay.data.clamp_(0, self.p_scale)



Combining the clamping and the quantization, we know that when the current and voltage decays are computed in the dynamics:

- the decay factors actually used are `quantize(self.current_decay)` and `quantize(self.voltage_decay)`

- the decay factors are quantized, taking values in `0, 1, 2, ..., self.p_scale`


Next we look at the quantization of the states and the threshold.

First, following the constructor, let us recall the relations between the scaling factors:

.. code-block:: python

# lava\lib\dl\slayer\neuron\cuba.py
class Neuron(base.Neuron):
    def __init__(
        self, threshold, current_decay, voltage_decay,
        tau_grad=1, scale_grad=1, scale=1 << 6,
        norm=None, dropout=None,
        shared_param=True, persistent_state=False, requires_grad=False,
        graded_spike=False
    ):
        super(Neuron, self).__init__(
            # ...
            p_scale=1 << 12,
            w_scale=scale,
            s_scale=scale * (1 << 6),
            # ...

From the constructor of `base.Neuron`:

.. code-block:: python

# lava\lib\dl\slayer\neuron\base.py
class Neuron(torch.nn.Module):
    def __init__(
        self, threshold,
        tau_grad=1, scale_grad=1,
        p_scale=1, w_scale=1, s_scale=1,
        norm=None, dropout=None,
        persistent_state=False, shared_param=True,
        requires_grad=True,
        complex=False
    ):
        # ...
        self.p_scale = p_scale
        self.w_scale = int(w_scale)
        self.s_scale = int(s_scale)
        # quantize to proper value
        self._threshold = int(threshold * self.w_scale) / self.w_scale
        # ...

We can see that the threshold is effectively quantized to a multiple of `1 / self.w_scale` (i.e. it is quantized after scaling by `self.w_scale`).
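For example (a one-line check, assuming `w_scale = 64`, the value obtained from the default `scale = 1 << 6`):

.. code-block:: python

    w_scale = 64
    threshold = 1.03
    print(int(threshold * w_scale) / w_scale)  # 1.015625, a multiple of 1/64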

Finally, let us look at the role of `self.s_scale` in `leaky_integrator.dynamics`. Checking the source code:

.. code-block:: python

# lava\lib\dl\slayer\neuron\cuba.py
class Neuron(base.Neuron):
    def dynamics(self, input):
        # ...
        current = leaky_integrator.dynamics(
            input,
            quantize(self.current_decay),
            self.current_state.contiguous(),
            self.s_scale,
            debug=self.debug
        )

        voltage = leaky_integrator.dynamics(
            current,  # bias can be enabled by adding it here
            quantize(self.voltage_decay),
            self.voltage_state.contiguous(),
            self.s_scale,
            self.threshold,
            debug=self.debug
        )
        # ...

# lava\lib\dl\slayer\neuron\dynamics\leaky_integrator.py
def _li_dynamics_fwd(
    input, decay, state, threshold, w_scale, dtype=torch.int32
):
    output_old = (state * w_scale).clone().detach().to(dtype).to(input.device)
    decay_int = (1 << 12) - decay.clone().detach().to(dtype).to(input.device)
    output = torch.zeros_like(input)

    threshold *= w_scale

    for n in range(input.shape[-1]):
        output_new = right_shift_to_zero(output_old * decay_int, 12) + \
            (w_scale * input[..., n]).to(dtype)
        if threshold > 0:
            spike_new = (output_new >= threshold)
            output_old = output_new * (spike_new < 0.5)
        else:
            output_old = output_new

        output[..., n] = output_new / w_scale

    return output

# lava\lib\dl\slayer\utils\int_utils.py
def right_shift_to_zero(x, bits):
    """Right shift with quantization towards zero implementation.

    Parameters
    ----------
    x : torch.int32 or torch.int64
        input tensor.
    bits : int
        number of bits to shift.

    Returns
    -------
    torch.int32 or torch.int64
        right shift to zero result.

    """
    # ...


We can see that `input, state, threshold` are first multiplied by `w_scale` for the computation, and the result is finally divided by `w_scale` to restore the scale. Since `p_scale = 1 << 12`, the decay uses `right_shift_to_zero(x, bits=12)`.
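A tiny fixed-point sketch of one decay step (not from LAVA or this PR; it just re-implements, for a single scalar, the integer update quoted above, with assumed scales):

.. code-block:: python

    def right_shift_to_zero(x: int, bits: int) -> int:
        # shift right, rounding towards zero
        return -((-x) >> bits) if x < 0 else x >> bits

    p_scale = 1 << 12
    w_scale = 64 * (1 << 6)               # assumption: s_scale for scale = 1 << 6
    voltage_decay = int(p_scale * 0.03)   # stored decay parameter
    v_int = int(0.5 * w_scale)            # voltage state in the integer domain
    x_int = int(0.1 * w_scale)            # weighted input in the integer domain

    v_int = right_shift_to_zero(v_int * (p_scale - voltage_decay), 12) + x_int
    print(v_int / w_scale)                # back to floating point, roughly 0.5 * 0.97 + 0.1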

To summarize, in `slayer.block`:

- `p_scale = 1 << 12`

- `w_scale = scale`

- `s_scale = scale * (1 << 6)`

- Unless `pre_hook_fx = None` or some other specific function is given, `self.synapse.weight` is quantized and then clamped, so its final values are `2k / w_scale, k = -128, -127, ..., 127`, i.e. 256 possible values

- `p_scale = 1 << 12, self.neuron.current_decay = int(p_scale * current_decay), self.neuron.voltage_decay = int(p_scale * voltage_decay)`,
  but when the decay is computed, the final output is restored by `right_shift_to_zero(x, bits=12)`

- `self.threshold = int(threshold * w_scale) / w_scale`

- When the neuronal dynamics are computed, `x, self.current_state, self.voltage_state, self.threshold` are first multiplied by `s_scale`, and the final output is divided by `s_scale` to restore the scale

+157 -179 docs/source/clock_driven/5_ann2snn.rst

@@ -4,10 +4,7 @@ Converting ANN to SNN

This tutorial focuses on ``spikingjelly.clock_driven.ann2snn`` and introduces how to convert a trained ANN to an SNN and simulate it in the SpikingJelly framework.

There used to be two implementations: one based on ONNX and one based on PyTorch, called the ONNX kernel and the PyTorch kernel in the framework.
The two implementations have their own characteristics: the ONNX kernel is more general and supports more complex topologies (e.g. ResNet);
the PyTorch kernel is mainly for simple tests, supports a limited set of modules and may have many bugs under the current configuration.
More modules can be supported through ONNX extensions, which users can implement themselves...
The earlier releases contained two implementations, one based on ONNX and one based on PyTorch. Because ONNX is unstable, this version is an enhanced PyTorch implementation that natively supports complex topologies (such as ResNet). Let's take a look!

Theoretical basis of ANN-to-SNN conversion
--------------------
@@ -107,7 +104,7 @@ Compared with ANNs, the spikes generated by SNNs are discrete, which facilitates efficient communication.
.. math::
\frac{V_T-V_0}{T} = z - V_{threshold} \frac{\sum_{t=1}^{T}\theta_t}{T} = z- V_{threshold} \frac{N}{T}

where :math:`N` is the number of spikes within :math:`T` time steps, and :math:`\frac{N}{T}` is the firing rate :math:`r`. Using :math:`z = V_{threshold} a`,
that is:

.. math::
@@ -123,19 +120,14 @@ Compared with ANNs, the spikes generated by SNNs are discrete, which facilitates efficient communication.
.. math::
r^l = W^l r^{l-1}+b^l- \frac{V^l_T}{T V_{threshold}}

A detailed explanation can be found in [#f1]_. The methods in ann2snn also mainly come from [#f1]_.

Conversion and simulation
----------
Converting to a spiking neural network
^^^^^^^^^^^^^^^^

Specifically, converting a feedforward ANN to an SNN involves two main steps: model parsing and simulation.
The conversion mainly solves two problems:

Model parsing
^^^^^^^^

Model parsing mainly solves two problems:

1. Batch Normalization was proposed to speed up ANN training and convergence. It aims to normalize the ANN output to zero mean, which conflicts with the characteristics of SNNs. Therefore, the BN parameters need to be absorbed into the preceding parameter layers (Linear, Conv2d)
1. Batch Normalization was proposed to speed up ANN training and convergence. It aims to normalize the ANN output to zero mean, which conflicts with the characteristics of SNNs. Therefore, the BN parameters can be absorbed into the preceding parameter layers (Linear, Conv2d)

2. According to the conversion theory, the input and output of every ANN layer need to be restricted to the range [0,1], which requires scaling the parameters (model normalization)

@@ -155,7 +147,7 @@ Compared with ANNs, the spikes generated by SNNs are discrete, which facilitates efficient communication.

◆ Model normalization

For a parameter module, suppose we have obtained its input tensor and output tensor, where the maximum of the input tensor is :math:`\lambda_{pre}` and the maximum of the output tensor is :math:`\lambda`.
Then the normalized weight :math:`\hat{W}` is:

.. math::
@@ -171,65 +163,28 @@ Although the output of each ANN layer follows a certain distribution, the data often contain

Up to now, the operations we have applied to the network are numerically equivalent, so the current model should perform the same as the original model.

Model simulation
^^^^^^^^

Before simulation, we need to replace the ReLU activations in the original model with IF neurons.
During conversion, we need to replace the ReLU activations in the original model with IF neurons.
For average pooling in the ANN, we convert it to spatial subsampling. Since IF neurons can act as ReLU activations, whether an IF neuron is added after the spatial subsampling has almost no influence on the result.
For max pooling in the ANN there is no ideal solution yet. The best current scheme is to gate the spike channels with a gating function based on momentum-accumulated spikes [#f1]_. It is not used in the ONNX kernel, although an implementation is still provided in ``ann2snn.modules``. Some papers also propose replacing MaxPool2d with spatial subsampling. Here we still recommend using AvgPool2d.

During simulation, according to the conversion theory, the SNN needs a constant analog input. Using a Poisson encoder will reduce the accuracy. Both Poisson encoding and constant input are implemented; interested users can run different experiments through the configuration.
For max pooling in the ANN there is no ideal solution yet. The best current scheme is to gate the spike channels with a gating function based on momentum-accumulated spikes [#f1]_. Here we still recommend using AvgPool2d.
During simulation, according to the conversion theory, the SNN needs a constant analog input. Using a Poisson encoder will reduce the accuracy.
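As a sketch of what "constant analog input over ``T`` steps" means in practice (a minimal illustration, not the tutorial's own helper; it assumes an already converted ``snn`` and an input batch ``x``):

.. code-block:: python

    from spikingjelly.clock_driven import functional

    def simulate_constant_input(snn, x, T):
        functional.reset_net(snn)   # clear membrane states before a new sample
        out_spikes = 0.
        for t in range(T):
            out_spikes += snn(x)    # the same analog input x is fed at every step
        return out_spikes / T       # firing rates are used as classification scores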

Implementation and options
^^^^^^^^^^^^^^^^^^^^^^^^

The ann2snn framework received a major update in December 2020. The biggest change moved the configuration back into module parameters, while trying to accommodate the users' needs for flexibility and incremental operation. Here we briefly introduce these classes and methods.
Around the two pillars of the theory, parsing and simulation, two classes were designed: parser and simulator. The classes are defined in ``spikingjelly.ann2snn.__init__``.

◆ The parser class
1. Constructor arguments
- kernel: the conversion kernel; 'onnx' or 'pytorch', which decides whether the ONNX kernel or the PyTorch kernel is used
- name: the model name; usually related to the task and model, and possibly used when generating folders later
- z_norm: many deep learning models use data normalization (Z normalization). If your ANN model uses it, the format of this parameter is (mean, std), e.g. for CIFAR10 z_norm can be ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
- log_dir: the folder for temporary files; if not given, it is generated automatically from name and the current time
- json: the name of a previous configuration file. After running the parser once, a json file is saved to log_dir, which can be used to quickly re-initialize the parser

2. The parse function
- channelwise: if ``True``, the activation-magnitude statistics are channelwise; otherwise they are layerwise
- robust: if ``True``, the activation-magnitude statistics use the 99.9th percentile of the activations; otherwise they use the extreme values
- user_methods: defaults to ``spikingjelly.ann2snn.kernel.onnx._o2p_converter``; when the ONNX kernel lacks an ONNX-to-PyTorch conversion for some operator, users can supply their own conversion functions. The interface can be found in the staticmethods of ``spikingjelly.ann2snn.kernel.onnx._o2p_converter``

◆ The simulator class
1. Constructor arguments
- snn: the converted SNN to be simulated
- device: the simulation device(s); a single device (a string) or multiple devices (a list, set or tuple)
- name: the model name; usually related to the task and model, and possibly used when generating folders later
- log_dir: the folder for temporary files; if not given, it is generated automatically from name and the current time
- encoder: the encoder; 'constant' or 'poisson'

2. The simulate function
- data_loader: the dataloader of the dataset to simulate
- T: the simulation duration
- canvas: a plt.fig object used to plot scalar metrics (e.g. accuracy) of the simulated model
- online_drawer: if ``True``, plot online; otherwise plot after the simulation finishes
- func_dict: users can plot their own scalar metrics by defining the functions themselves

Beyond that, users can refine the simulator by inheriting from the simulate class.
For example, ``spikingjelly.ann2snn.__init__`` implements ``classify_simulator`` for simulating classification tasks.

3. The classify_simulator.simulate function
In addition to the inherited parameters:
- ann_acc: the classification accuracy of the ANN before conversion (a float between 0 and 1)
- fig_name: the name of the simulation figure
- step_max: if ``True``, the maximum accuracy reached during inference is marked in the figure
The ann2snn framework received another major update in April 2022. The parser and simulator classes were removed and replaced by the Converter class. The current scheme is more concise and leaves more room for conversion settings.

◆ The Converter class
This class converts a ReLU-based ANN into an SNN. Three common modes are implemented; a short usage sketch follows.
The most common one is the maximum-current conversion mode, which uses the activation upper bounds of the preceding and following layers so that the maximum firing rate corresponds to the maximum activation. To use this mode, set the parameter mode to ``max`` [#f2]_.
The 99.9% current conversion mode limits the activation upper bound to the 99.9th percentile of the activations. To use this mode, set mode to ``99.9%`` [#f1]_.
In the scaling conversion mode, the user passes a scaling factor, and the current is limited by the scaled maximum activation. To use this mode, set mode to a float between 0 and 1.
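A compressed usage sketch of the three modes described above (``model`` and ``train_data_loader`` are assumed to exist, as in the MNIST example below):

.. code-block:: python

    from spikingjelly.clock_driven import ann2snn

    converter_max = ann2snn.Converter(mode='max', dataloader=train_data_loader)
    converter_robust = ann2snn.Converter(mode='99.9%', dataloader=train_data_loader)
    converter_scaled = ann2snn.Converter(mode=1.0 / 3, dataloader=train_data_loader)

    snn_model = converter_max(model)  # each converter maps the ReLU ANN to an SNN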

Classifying MNIST
---------

Now we use ``ann2snn`` to build a simple convolutional network and classify the MNIST dataset.

First define our network structure:
First define our network structure (see ``ann2snn.sample_models.mnist_cnn``):

.. code-block:: python

@@ -267,156 +222,179 @@ The ann2snn framework received a major update in December 2020. The biggest change moved the configuration back into module parameters

.. code-block:: python

device = input('输入运行的设备,例如“cpu”或“cuda:0”\n input device, e.g., "cpu" or "cuda:0": ')
dataset_dir = input('输入保存MNIST数据集的位置,例如“./”\n input root directory for saving MNIST dataset, e.g., "./": ')
batch_size = int(input('输入batch_size,例如“64”\n input batch_size, e.g., "64": '))
learning_rate = float(input('输入学习率,例如“1e-3”\n input learning rate, e.g., "1e-3": '))
T = int(input('输入仿真时长,例如“100”\n input simulating steps, e.g., "100": '))
train_epoch = int(input('输入训练轮数,即遍历训练集的次数,例如“10”\n input training epochs, e.g., "10": '))
model_name = input('输入模型名字,例如“mnist”\n input model name, for log_dir generating , e.g., "mnist": ')
torch.random.manual_seed(0)
torch.cuda.manual_seed(0)
device = 'cuda'
dataset_dir = 'G:/Dataset/mnist'
batch_size = 100
T = 50

All later temporary files will be stored in this folder
Here T is the number of time steps that will be used for inference later

Initialize the data loader, network, optimizer and loss function:
If you want to train the model, you also need to initialize the data loader, optimizer and loss function, for example

.. code-block:: python

# initialize the network
ann = ANN().to(device)
lr = 1e-3
epochs = 10
# define the loss function
loss_function = nn.CrossEntropyLoss()
# use the Adam optimizer
optimizer = torch.optim.Adam(ann.parameters(), lr=learning_rate, weight_decay=5e-4)
optimizer = torch.optim.Adam(ann.parameters(), lr=lr, weight_decay=5e-4)

Train the ANN and test it periodically. For training you can also use the pre-written training utilities in utils
Train the ANN. In this example the model is trained for 10 epochs. The test-set accuracy during training changes as follows

.. code-block:: python

for epoch in range(train_epoch):
    # train the network using the pre-written training code in ``utils``
    # the training code is written in the same way as for a classical ANN
    # Train the network using a pre-prepared code in ''utils''
    utils.train_ann(net=ann,
                    device=device,
                    data_loader=train_data_loader,
                    optimizer=optimizer,
                    loss_function=loss_function,
                    epoch=epoch
                    )
    # validate the network using the pre-written validation code in ``utils``
    # Validate the network using a pre-prepared code in ''utils''
    acc = utils.val_ann(net=ann,
                        device=device,
                        data_loader=test_data_loader,
                        epoch=epoch
                        )
    if best_acc <= acc:
        utils.save_model(ann, log_dir, model_name + '.pkl')
The complete code is in ``ann2snn.examples.cnn_mnist.py``; it also uses Tensorboard to save the training logs. You can run it directly from the Python command line
Epoch: 0 100%|██████████| 600/600 [00:05<00:00, 112.04it/s]
Validating Accuracy: 0.972
Epoch: 1 100%|██████████| 600/600 [00:05<00:00, 105.43it/s]
Validating Accuracy: 0.986
Epoch: 2 100%|██████████| 600/600 [00:05<00:00, 107.49it/s]
Validating Accuracy: 0.987
Epoch: 3 100%|██████████| 600/600 [00:05<00:00, 109.26it/s]
Validating Accuracy: 0.990
Epoch: 4 100%|██████████| 600/600 [00:05<00:00, 103.98it/s]
Validating Accuracy: 0.984
Epoch: 5 100%|██████████| 600/600 [00:05<00:00, 100.42it/s]
Validating Accuracy: 0.989
Epoch: 6 100%|██████████| 600/600 [00:06<00:00, 96.24it/s]
Validating Accuracy: 0.991
Epoch: 7 100%|██████████| 600/600 [00:05<00:00, 104.97it/s]
Validating Accuracy: 0.992
Epoch: 8 100%|██████████| 600/600 [00:05<00:00, 106.45it/s]
Validating Accuracy: 0.991
Epoch: 9 100%|██████████| 600/600 [00:05<00:00, 111.93it/s]
Validating Accuracy: 0.991
After training, we quickly load the saved model and test its performance

.. code-block:: python

>>> import spikingjelly.clock_driven.ann2snn.examples.cnn_mnist as cnn_mnist
>>> cnn_mnist.main()
输入运行的设备,例如“cpu”或“cuda:0”
input device, e.g., "cpu" or "cuda:0": cuda:15
输入保存MNIST数据集的位置,例如“./”
input root directory for saving MNIST dataset, e.g., "./": ./mnist
输入batch_size,例如“64”
input batch_size, e.g., "64": 128
输入学习率,例如“1e-3”
input learning rate, e.g., "1e-3": 1e-3
输入仿真时长,例如“100”
input simulating steps, e.g., "100": 100
输入训练轮数,即遍历训练集的次数,例如“10”
input training epochs, e.g., "10": 10
输入模型名字,用于自动生成日志文档,例如“cnn_mnist”
input model name, for log_dir generating , e.g., "cnn_mnist"

Epoch 0 [1/937] ANN Training Loss:2.252 Accuracy:0.078
Epoch 0 [101/937] ANN Training Loss:1.423 Accuracy:0.669
Epoch 0 [201/937] ANN Training Loss:1.117 Accuracy:0.773
Epoch 0 [301/937] ANN Training Loss:0.953 Accuracy:0.795
Epoch 0 [401/937] ANN Training Loss:0.865 Accuracy:0.788
Epoch 0 [501/937] ANN Training Loss:0.807 Accuracy:0.792
Epoch 0 [601/937] ANN Training Loss:0.764 Accuracy:0.795
Epoch 0 [701/937] ANN Training Loss:0.726 Accuracy:0.835
Epoch 0 [801/937] ANN Training Loss:0.681 Accuracy:0.880
Epoch 0 [901/937] ANN Training Loss:0.641 Accuracy:0.889
100%|██████████| 100/100 [00:00<00:00, 116.12it/s]
Epoch 0 [100/100] ANN Validating Loss:0.327 Accuracy:0.881
Save model to: cnn_mnist-XXXXX\cnn_mnist.pkl
......

In this example, the model is trained for 10 epochs. The test-set accuracy during training changes as follows:

.. image:: ../_static/tutorials/clock_driven/5_ann2snn/accuracy_curve.png

It finally reaches 98.8% accuracy on the test set.

Take part of the training data for the model normalization step. Here we take 192 images.
model.load_state_dict(torch.load('SJ-mnist-cnn_model-sample.pth'))
acc = val(model, device, test_data_loader)
print('ANN Validating Accuracy: %.4f' % (acc))

.. code-block:: python
The output is as follows:

# load the data used to normalize the model
# Load the data to normalize the model
percentage = 0.004  # load 0.004 of the data
norm_data_list = []
for idx, (imgs, targets) in enumerate(train_data_loader):
    norm_data_list.append(imgs)
    if idx == int(len(train_data_loader) * percentage) - 1:
        break
norm_data = torch.cat(norm_data_list)
print('use %d imgs to parse' % (norm_data.size(0)))
.. code-block:: python

100%|██████████| 200/200 [00:02<00:00, 89.44it/s]
ANN Validating Accuracy: 0.9870

Call the parser class in \ ``ann2snn``\  and use the ONNX kernel.
Converting with Converter is very simple: you only need to set the desired mode in the arguments. For example, to use MaxNorm, first define an ``ann2snn.Converter`` and forward the model through this object:

.. code-block:: python

onnxparser = parser(name=model_name,
                    log_dir=log_dir + '/parser',
                    kernel='onnx')
snn = onnxparser.parse(ann, norm_data.to(parser_device))
model_converter = ann2snn.Converter(mode='max', dataloader=train_data_loader)
snn_model = model_converter(model)

We can save the converted SNN model and define a plt.figure for plotting
snn_model is the resulting SNN model.

.. code-block:: python
Following this example, we define SNN conversions with mode set to ``max``, ``99.9%``, ``1.0/2``, ``1.0/3``, ``1.0/4`` and ``1.0/5``, and run inference for T steps in each case to obtain the accuracy.

torch.save(snn, os.path.join(log_dir,'snn-'+model_name+'.pkl'))
fig = plt.figure('simulator')
.. code-block:: python

Now we define the simulator for the SNN. Since our task is classification, we choose the class ``classify_simulator``
print('---------------------------------------------')
print('Converting using MaxNorm')
model_converter = ann2snn.Converter(mode='max', dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_max_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_max_accs[-1]))

print('---------------------------------------------')
print('Converting using RobustNorm')
model_converter = ann2snn.Converter(mode='99.9%', dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_robust_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_robust_accs[-1]))

print('---------------------------------------------')
print('Converting using 1/2 max(activation) as scales...')
model_converter = ann2snn.Converter(mode=1.0 / 2, dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_two_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_two_accs[-1]))

print('---------------------------------------------')
print('Converting using 1/3 max(activation) as scales')
model_converter = ann2snn.Converter(mode=1.0 / 3, dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_three_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_three_accs[-1]))

print('---------------------------------------------')
print('Converting using 1/4 max(activation) as scales')
model_converter = ann2snn.Converter(mode=1.0 / 4, dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_four_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_four_accs[-1]))

print('---------------------------------------------')
print('Converting using 1/5 max(activation) as scales')
model_converter = ann2snn.Converter(mode=1.0 / 5, dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_five_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_five_accs[-1]))

Check the console output:

.. code-block:: python

sim = classify_simulator(snn,
                         log_dir=log_dir + '/simulator',
                         device=simulator_device,
                         canvas=fig
                         )
sim.simulate(test_data_loader,
             T=T,
             online_drawer=True,
             ann_acc=ann_acc,
             fig_name=model_name,
             step_max=True
             )

Since the simulation takes a long time, we use a tqdm progress bar to estimate the remaining time. When the simulation finishes, the simulator prints a summary
---------------------------------------------
Converting using MaxNorm
100%|██████████| 600/600 [00:04<00:00, 128.25it/s] Simulating...
100%|██████████| 200/200 [00:13<00:00, 14.44it/s] SNN accuracy (simulation 50 time-steps): 0.9777
---------------------------------------------
Converting using RobustNorm
100%|██████████| 600/600 [00:19<00:00, 31.06it/s] Simulating...
100%|██████████| 200/200 [00:13<00:00, 14.75it/s] SNN accuracy (simulation 50 time-steps): 0.9841
---------------------------------------------
Converting using 1/2 max(activation) as scales...
100%|██████████| 600/600 [00:04<00:00, 126.64it/s] ]Simulating...
100%|██████████| 200/200 [00:13<00:00, 14.90it/s] SNN accuracy (simulation 50 time-steps): 0.9844
---------------------------------------------
Converting using 1/3 max(activation) as scales
100%|██████████| 600/600 [00:04<00:00, 126.27it/s] Simulating...
100%|██████████| 200/200 [00:13<00:00, 14.73it/s] SNN accuracy (simulation 50 time-steps): 0.9828
---------------------------------------------
Converting using 1/4 max(activation) as scales
100%|██████████| 600/600 [00:04<00:00, 128.94it/s] Simulating...
100%|██████████| 200/200 [00:13<00:00, 14.47it/s] SNN accuracy (simulation 50 time-steps): 0.9747
---------------------------------------------
Converting using 1/5 max(activation) as scales
100%|██████████| 600/600 [00:04<00:00, 121.18it/s] Simulating...
100%|██████████| 200/200 [00:13<00:00, 14.42it/s] SNN accuracy (simulation 50 time-steps): 0.9487
---------------------------------------------

The conversion itself is very fast, and inference finishes its 200 steps in only about 11 s (GTX 2080 Ti).
From the accuracy over time returned by the model, we can plot the accuracy curves for the different settings.

.. code-block:: python

simulator is working on the normal mode, device: cuda:0
100%|██████████| 100/100 [00:46<00:00, 2.15it/s]
--------------------simulator summary--------------------
time elapsed: 46.55072790000008 (sec)
---------------------------------------------------------
fig = plt.figure()
plt.plot(np.arange(0, T), mode_max_accs, label='mode: max')
plt.plot(np.arange(0, T), mode_robust_accs, label='mode: 99.9%')
plt.plot(np.arange(0, T), mode_two_accs, label='mode: 1.0/2')
plt.plot(np.arange(0, T), mode_three_accs, label='mode: 1.0/3')
plt.plot(np.arange(0, T), mode_four_accs, label='mode: 1.0/4')
plt.plot(np.arange(0, T), mode_five_accs, label='mode: 1.0/5')
plt.legend()
plt.xlabel('t')
plt.ylabel('Acc')
plt.show()

.. image:: ../_static/tutorials/clock_driven/5_ann2snn/accuracy_mode.png

From the final output we can see that the simulator took 46.6 s. The accuracy of the converted SNN can be seen in plot.pdf in the simulator folder; the best converted accuracy is 98.51%, a drop of 0.37% caused by the conversion. The conversion loss can be reduced by increasing the inference time.
Different settings give different results: some infer quickly but reach a lower final accuracy, others infer slowly but are more accurate. Users can choose the setting according to their own needs.

.. [#f1] Rueckauer B, Lungu I-A, Hu Y, Pfeiffer M and Liu S-C (2017) Conversion of Continuous-Valued Deep Networks to Efficient Event-Driven Networks for Image Classification. Front. Neurosci. 11:682.
.. [#f2] Diehl, Peter U. , et al. Fast classifying, high-accuracy spiking deep networks through weight and threshold balancing. Neural Networks (IJCNN), 2015 International Joint Conference on IEEE, 2015.
.. [#f3] Rueckauer, B., Lungu, I. A., Hu, Y., & Pfeiffer, M. (2016). Theory and tools for the conversion of analog to spiking convolutional neural networks. arXiv preprint arXiv:1612.04052.
.. [#f4] Sengupta, A., Ye, Y., Wang, R., Liu, C., & Roy, K. (2019). Going deeper in spiking neural networks: Vgg and residual architectures. Frontiers in neuroscience, 13, 95.

+2 -2 docs/source/clock_driven_en/0_neuron.rst

@@ -7,7 +7,7 @@ Translator: `YeYumin <https://github.com/YEYUMIN>`_
This tutorial focuses on :class:`spikingjelly.clock_driven.neuron` and introduces spiking neurons and clock-driven
simulation methods.
Spiking Nneuron Model
Spiking Neuron Model
-----------------------------------------------
In ``spikingjelly``, we define the neuron which can only output spikes, i.e. 0 or 1, as a "spiking neuron".
Networks that use spiking neurons are called Spiking Neural Networks (SNNs).
@@ -235,4 +235,4 @@ The results are as follows:
:width: 100%
.. image:: ../_static/tutorials/clock_driven/0_neuron/2.*
:width: 100%
:width: 100%

+51 -3 docs/source/clock_driven_en/13_neuromorphic_datasets.rst

@@ -203,8 +203,55 @@ We will get the images like:
.. image:: ../_static/tutorials/clock_driven/13_neuromorphic_datasets/dvsg.*
:width: 100%

Fixed Duration Integrating
--------------------------------------
Integrating with a fixed duration is closer to practical applications. For example, if we set the duration to ``10 ms``,
then a sample of length ``L ms`` can be integrated into ``math.floor(L / 10)`` frames. However, the lengths
of the samples in a neuromorphic dataset are not identical, so integrating with a fixed duration yields frame tensors with different numbers of frames. Fortunately, we can use :class:`spikingjelly.datasets.pad_sequence_collate` and
:class:`spikingjelly.datasets.padded_sequence_mask` to pad/unpad the frames.

Example codes:

.. code:: python

import torch
from torch.utils.data import DataLoader
from spikingjelly.datasets import pad_sequence_collate, padded_sequence_mask, dvs128_gesture
root='D:/datasets/DVS128Gesture'
train_set = dvs128_gesture.DVS128Gesture(root, data_type='frame', duration=1000000, train=True)
for i in range(5):
    x, y = train_set[i]
    print(f'x[{i}].shape=[T, C, H, W]={x.shape}')
train_data_loader = DataLoader(train_set, collate_fn=pad_sequence_collate, batch_size=5)
for x, y, x_len in train_data_loader:
    print(f'x.shape=[N, T, C, H, W]={tuple(x.shape)}')
    print(f'x_len={x_len}')
    mask = padded_sequence_mask(x_len)  # mask.shape = [T, N]
    print(f'mask=\n{mask.t().int()}')
    break

The outputs are:

.. code:: bash

The directory [D:/datasets/DVS128Gesture\duration_1000000] already exists.
x[0].shape=[T, C, H, W]=(6, 2, 128, 128)
x[1].shape=[T, C, H, W]=(6, 2, 128, 128)
x[2].shape=[T, C, H, W]=(5, 2, 128, 128)
x[3].shape=[T, C, H, W]=(5, 2, 128, 128)
x[4].shape=[T, C, H, W]=(7, 2, 128, 128)
x.shape=[N, T, C, H, W]=(5, 7, 2, 128, 128)
x_len=tensor([6, 6, 5, 5, 7])
mask=
tensor([[1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)

Custom Integrating Method
-----------------------
----------------------------
SpikingJelly supports user-defined integration methods. The user should provide a function ``custom_integrate_function`` and
the name of the directory ``custom_integrated_frames_dir_name`` for saving the frames.

@@ -224,8 +271,9 @@ a function:
def integrate_events_to_2_frames_randomly(events: Dict, H: int, W: int):
    index_split = np.random.randint(low=0, high=events['t'].__len__())
    frames = np.zeros([2, 2, H, W])
    frames[0] = sjds.integrate_events_segment_to_frame(events, H, W, 0, index_split)
    frames[1] = sjds.integrate_events_segment_to_frame(events, H, W, index_split, events['t'].__len__())
    t, x, y, p = (events[key] for key in ('t', 'x', 'y', 'p'))
    frames[0] = sjds.integrate_events_segment_to_frame(x, y, p, H, W, 0, index_split)
    frames[1] = sjds.integrate_events_segment_to_frame(x, y, p, H, W, index_split, events['t'].__len__())
    return frames

Now let us use this function to create frames dataset:


+305 -352 docs/source/clock_driven_en/5_ann2snn.rst

@@ -2,22 +2,89 @@ spikingjelly.clock_driven.ann2snn
=======================================
Author: `DingJianhao <https://github.com/DingJianhao>`_, `fangwei123456 <https://github.com/fangwei123456>`_

This tutorial focuses on ``spikingjelly.clock_driven.ann2snn`` and introduces how to convert a trained feedforward ANN to an SNN and simulate it on the SpikingJelly framework.

Currently support conversion of Pytorch modules including ``nn.Conv2d`` , ``nn.Linear`` , ``nn.MaxPool2d`` , ``nn.AvgPool2d`` , ``nn.BatchNorm1d`` , ``nn.BatchNorm2d`` , ``nn.Flatten`` , ``nn.ReLU`` ,other module solutions are under development...
There are two sets of implementations in earlier implementations: ONNX-based and PyTorch-based. Due to the instability of ONNX, this version is an enhanced version of PyTorch, which natively supports complex topologies (such as ResNet). Let's have a look!

Theoretical basis of ANN2SNN
----------------------------

Compared with ANN, SNN generates discrete spikes, which is conducive to efficient communication. Today, ANN is popular, while direct training of SNN requires far more resources. Naturally, people will think of using very mature ANN to switch to SNN, and hope that SNN can have similar performance. This leads to the question of how to build a bridge between ANN and SNN. The current SNN mainstream method is to use frequency coding. So for the output layer, we will use the number of neuron output spikes to determine the category. Is the firing rate related to ANN?
Compared with ANNs, the spikes generated by SNNs are discrete, which is conducive to efficient communication. Today ANNs are popular, while directly training SNNs requires far more resources. It is therefore natural to convert the now very mature ANNs to SNNs, hoping that the SNN reaches similar performance. This raises the question of how to build a bridge between ANNs and SNNs. The current mainstream approach for SNNs is rate coding, so for the output layer we use the spike count of each neuron to decide the category. Is the firing rate related to the ANN?

Fortunately, there is a strong correlation between the non-linear activation of ReLU neurons in ANN and the firing rate of IF neurons in SNN (reset by subtracting the threshold :math:`V_{threshold}` ). We can use this feature for conversion. The neuron update method mentioned here is the Soft method mentioned in the `Clock Driven Tutorial <https://spikingjelly.readthedocs.io/zh_CN/latest/clock_driven_en/0_neuron.html>`_.
Fortunately, there is a strong correlation between the nonlinear activation of ReLU neurons in ANNs and the firing rate of IF neurons in SNNs (reset by subtracting the threshold :math:`V_{threshold}`). We can use this feature for the conversion. The neuron update method mentioned here is the Soft method described in the `clock-driven tutorial <https://spikingjelly.readthedocs.io/zh_CN/latest/clock_driven/0_neuron.html>`_.

The following figure shows this correspondence: the left figure is a curve obtained by giving a constant input to an IF neuron and observing its firing over a period of time. The right one is the ReLU activation curve, which satisfies :math:`activation = max(input,0)`.
Experiment: Relationship between IF neuron spiking frequency and input
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We gave constant input to the IF neuron and observed its output spikes and spike firing frequency. First import the relevant modules, create a new IF neuron layer, determine the input and draw the input of each IF neuron :math:`x_{i}`:

.. code-block:: python

import torch
from spikingjelly.clock_driven import neuron
from spikingjelly import visualizing
from matplotlib import pyplot as plt
import numpy as np

plt.rcParams['figure.dpi'] = 200
if_node = neuron.IFNode(v_reset=None)
T = 128
x = torch.arange(-0.2, 1.2, 0.04)
plt.scatter(torch.arange(x.shape[0]), x)
plt.title('Input $x_{i}$ to IF neurons')
plt.xlabel('Neuron index $i$')
plt.ylabel('Input $x_{i}$')
plt.grid(linestyle='-.')
plt.show()

.. image:: ../_static/tutorials/clock_driven/5_ann2snn/0.*
:width: 100%

Next, send the input to the IF neuron layer and run it for ``T=128`` steps to observe the spikes and firing frequency of each neuron:

.. code-block:: python

s_list = []
for t in range(T):
    s_list.append(if_node(x).unsqueeze(0))

out_spikes = np.asarray(torch.cat(s_list))
visualizing.plot_1d_spikes(out_spikes, 'IF neurons\' spikes and firing rates', 't', 'Neuron index $i$')
plt.show()

.. image:: ../_static/tutorials/clock_driven/5_ann2snn/1.*
:width: 100%

We can see that, within a certain range, the firing frequency is proportional to the input :math:`x_{i}`.

Next, let's plot the firing frequency of the IF neuron against the input :math:`x_{i}` and compare it with :math:`\mathrm{ReLU}(x_{i})`:

.. code-block:: python

plt.subplot(1, 2, 1)
firing_rate = np.mean(out_spikes, axis=1)
plt.plot(x, firing_rate)
plt.title('Input $x_{i}$ and firing rate')
plt.xlabel('Input $x_{i}$')
plt.ylabel('Firing rate')
plt.grid(linestyle='-.')

plt.subplot(1, 2, 2)
plt.plot(x, x.relu())
plt.title('Input $x_{i}$ and ReLU($x_{i}$)')
plt.xlabel('Input $x_{i}$')
plt.ylabel('ReLU($x_{i}$)')
plt.grid(linestyle='-.')
plt.show()

.. image:: ../_static/tutorials/clock_driven/5_ann2snn/2.*
:width: 100%

The two curves are almost identical. Note that the firing frequency cannot exceed 1, so an IF neuron cannot fit the part of the ANN's ReLU input that is larger than 1.

Theoretical basis of ANN2SNN
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The literature [#f1]_ provides a theoretical basis for analyzing ANN-to-SNN conversion. The theory shows that the IF neuron in an SNN is an unbiased estimator of the ReLU activation function over time.

For the first layer of the neural network (the input layer), consider the relationship between the firing rate :math:`r` of the SNN neurons and the corresponding ANN activation. Assume that the input is a constant :math:`z \in [0,1]`.
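As a brief sketch of the result derived in [#f1]_ (see the paper for the full derivation), an IF neuron with threshold :math:`V_{threshold}` and reset-by-subtraction that receives the constant input :math:`z` at every step fires :math:`N(T)` spikes within :math:`T` steps, and its firing rate satisfies

.. math::
    r(T) = \frac{N(T)}{T} = \frac{z}{V_{threshold}} - \frac{V(T)}{T \cdot V_{threshold}},

so the residual term vanishes as :math:`T \rightarrow \infty` and :math:`r` converges to the ReLU-like activation :math:`\min(\max(z / V_{threshold}, 0), 1)`.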
@@ -55,393 +122,279 @@ Similarly, for the higher layers of the neural network, literature [#f1]_ furthe

For details, please refer to [#f1]_. The methods in ann2snn also mainly come from [#f1]_ .

Conversion and simulation
-------------------------

Specifically, there are two main steps for converting a feedforward ANN into an SNN: parsing (converting) the model and simulating the converted model.

Converting to spiking neural network
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Conversion (model parsing) mainly solves two problems:

1. Batch Normalization is widely used for fast training and convergence. BatchNorm aims to normalize the output of the ANN to zero mean, which is contrary to the properties of SNNs. Therefore, the parameters of BN need to be absorbed into the preceding parameter layer (Linear, Conv2d)

2. According to the conversion theory, the input and output of each layer of the ANN need to be limited to the range [0,1], which requires scaling the parameters (model normalization)

◆ BatchNorm parameter absorption

Assume that the parameters of BatchNorm are :math:`\gamma` (``BatchNorm.weight``), :math:`\beta` (``BatchNorm.bias``), :math:`\mu` (``BatchNorm.running_mean``) and :math:`\sigma` (:math:`\sigma = \sqrt{\mathrm{running\_var}}`, the square root of ``BatchNorm.running_var``). For the exact parameter definitions, see
`torch.nn.BatchNorm1d <https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html>`_ .
A parameter module (e.g. Linear) has parameters :math:`W` and :math:`b`. Absorbing the BatchNorm parameters means transferring the parameters of BatchNorm into :math:`W` and :math:`b` of the parameter module, so that the output of the new module for the same input is the same as when the BatchNorm layer is present.
In this regard, the :math:`\bar{W}` and :math:`\bar{b}` of the new module are:

.. math::
    \bar{W} = \frac{\gamma}{\sigma} W

.. math::
    \bar{b} = \frac{\gamma}{\sigma} (b - \mu) + \beta
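As an illustration (a minimal sketch, not the code actually used by ann2snn), absorbing a ``BatchNorm2d`` into a preceding ``Conv2d`` according to the formulas above could look like this:

.. code-block:: python

    import torch
    import torch.nn as nn

    def absorb_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
        # sigma is the square root of the running variance (eps added for numerical stability)
        sigma = torch.sqrt(bn.running_var + bn.eps)
        scale = bn.weight.data / sigma                      # gamma / sigma
        fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                          stride=conv.stride, padding=conv.padding, bias=True)
        fused.weight.data = conv.weight.data * scale.reshape(-1, 1, 1, 1)      # W_bar
        b = conv.bias.data if conv.bias is not None else torch.zeros_like(bn.running_mean)
        fused.bias.data = scale * (b - bn.running_mean) + bn.bias.data          # b_bar
        return fused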

◆ Model normalization

For a parameter module, assume that its input tensor and output tensor have been obtained; the maximum value of the input tensor is :math:`\lambda_{pre}` and the maximum value of the output tensor is :math:`\lambda`.
Then, the normalized weight :math:`\hat{W}` is:

.. math::
    \hat{W} = W \cdot \frac{\lambda_{pre}}{\lambda}

The normalized bias :math:`\hat{b}` is:

.. math::
    \hat{b} = \frac{b}{\lambda}

Although the outputs of each layer of the ANN obey a certain distribution, there are often large outliers in the data, which lowers the overall neuron firing rate.
To solve this problem, robust normalization adjusts the scaling factor from the maximum value of the tensor to the p-th percentile of the tensor. The percentile value recommended in the literature is 99.9.
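As a minimal sketch of this step (the function name and arguments are illustrative, not a SpikingJelly API), the scaling factor can be taken either as the maximum or as a percentile of the recorded activations:

.. code-block:: python

    import torch

    def normalize_params(weight, bias, act_pre, act, robust=True, p=99.9):
        # lambda_pre / lambda: maximum (or p-th percentile) of the input / output activations
        if robust:
            lam_pre = torch.quantile(act_pre.flatten(), p / 100.0)
            lam = torch.quantile(act.flatten(), p / 100.0)
        else:
            lam_pre = act_pre.max()
            lam = act.max()
        w_hat = weight * lam_pre / lam   # W_hat = W * lambda_pre / lambda
        b_hat = bias / lam               # b_hat = b / lambda
        return w_hat, b_hat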

So far, the operations we have performed on the neural network are strictly equivalent; the current model should perform the same as the original model.

Model simulation
^^^^^^^^^^^^^^^^

Before simulation, we need to replace the ReLU activation functions in the original model with IF neurons.
For average pooling in the ANN, we need to transform it into spatial subsampling. Since the IF neuron can be regarded as equivalent to the ReLU activation function, whether or not IF neurons are added after spatial subsampling has little effect on the results.
There is currently no ideal solution for max pooling in the ANN. The best solution at present is to gate the spike channels with a gating function based on momentum-accumulated spikes [#f1]_, which is also the default method in ann2snn; some works propose replacing MaxPool2d with spatial subsampling, and simply using AvgPool2d in the ANN remains the recommended practice.

During simulation, according to the conversion theory, the SNN needs a constant analog input. Using a Poisson encoder brings a decrease in accuracy. Both Poisson encoding and constant input are implemented, and one can experiment with both if interested.
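As a sketch of what the simulation amounts to (using only APIs already shown in this tutorial), the converted SNN is driven by the constant analog input for ``T`` steps and the output spikes are accumulated to obtain the prediction:

.. code-block:: python

    import torch
    from spikingjelly.clock_driven import functional

    def snn_predict(snn, img, T=100):
        functional.reset_net(snn)      # reset membrane potentials before each sample
        with torch.no_grad():
            out_spikes_counter = 0.
            for t in range(T):
                out_spikes_counter += snn(img)   # constant analog input at every time step
        return out_spikes_counter.argmax(dim=1)  # the class with the most spikes is the prediction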

Optional configuration
^^^^^^^^^^^^^^^^^^^^^^

In view of the various optional settings in the conversion, the ``Config`` class implemented in ``ann2snn.utils`` is used to load and save the default configuration. By loading the default configuration in ``Config`` and modifying it, one can set the parameters needed at run time.

Below we introduce the configuration item corresponding to each parameter, its feasible values, and why it is needed.

(1) conf['parser']['robust_norm']

Available value:``bool``

Note:when ``True``, use robust normalization

(2) conf['simulation']['reset_to_zero']

Available value: ``None``, floating point

Note: When set to a floating point value, the membrane potential of neurons that have just fired is reset to :math:`V_{reset}`; when set to ``None``, the membrane potential of neurons that have just fired is decreased by :math:`V_{threshold}`. For models that need normalization, ``None`` is the default, which has a theoretical guarantee.
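For reference, with the soft reset used here the membrane update of the IF neuron can be sketched as follows (consistent with the neuron tutorial; :math:`\Theta` is the Heaviside step function):

.. math::
    H[t] = V[t-1] + X[t], \quad
    S[t] = \Theta(H[t] - V_{threshold}), \quad
    V[t] = H[t] - S[t] \, V_{threshold}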

(3) conf['simulation']['encoder']['possion']

Available value:``bool``

Note: When ``True``, use the Poisson encoder; otherwise, use a constant input over T steps.

(4) conf['simulation']['avg_pool']['has_neuron']

Available value: ``bool``

Note: When ``True``, avgpool2d is converted to spatial subsampling followed by a layer of IF neurons; otherwise, it is converted to spatial subsampling only.

(5) conf['simulation']['max_pool']['if_spatial_avg']

Available value: ``bool``

Note: When ``True``, maxpool2d is converted to avgpool2d. As reported in many works, this method causes accuracy degradation.

(6) conf['simulation']['max_pool']['if_wta']

Available value: ``bool``

Note: When ``True``, maxpool2d in the SNN behaves exactly like maxpool2d in the ANN, i.e., a spike is output whenever there is a spike in the receptive field.

(7) conf['simulation']['max_pool']['momentum']

Available value: ``None``, floating point in [0,1]

Note: By default, the maxpool2d layer is converted into a channel gated by a function based on momentum-accumulated spikes. When set to ``None``, spikes are accumulated directly; when set to a floating point value in [0,1], spikes are accumulated with that momentum.

The default configuration is:

.. code-block:: python

    default_config = {
        'simulation': {
            'reset_to_zero': False,
            'encoder': {
                'possion': False
            },
            'avg_pool': {
                'has_neuron': True
            },
            'max_pool': {
                'if_spatial_avg': False,
                'if_wta': False,
                'momentum': None
            }
        },
        'parser': {
            'robust_norm': True
        }
    }



Implementation and optional configuration
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The ann2snn framework received another major update in April 2022. The separate parser and simulator classes have been removed and replaced by the Converter class. The current scheme is more compact and leaves more room for conversion settings.

◆ Converter class

This class is used to convert a ReLU-based ANN into an SNN. Three common modes are implemented, as shown in the usage sketch below.
The most common is the maximum-current conversion mode, which uses the activation bounds of the preceding and following layers so that the case with the highest firing rate corresponds to the case where the activation reaches its maximum value. Using this mode requires setting the parameter ``mode`` to ``max`` [#f2]_.
The 99.9% current conversion mode uses the 99.9th activation percentile to bound the activations from above. Using this mode requires setting the parameter ``mode`` to ``99.9%`` [#f1]_.
In the scaling conversion mode, the user passes a scaling factor as the mode, and the current is bounded by the scaled maximum activation. Using this mode requires setting the parameter ``mode`` to a float in (0, 1].
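As a quick illustration (a minimal sketch using the ``Converter`` calls shown later in this tutorial; ``model`` and ``train_data_loader`` are assumed to be defined), the three modes are selected through the ``mode`` argument:

.. code-block:: python

    from spikingjelly.clock_driven import ann2snn

    # maximum-current conversion mode
    model_converter = ann2snn.Converter(mode='max', dataloader=train_data_loader)
    # 99.9% current conversion mode
    # model_converter = ann2snn.Converter(mode='99.9%', dataloader=train_data_loader)
    # scaling conversion mode with a factor of 1/2
    # model_converter = ann2snn.Converter(mode=1.0 / 2, dataloader=train_data_loader)

    snn_model = model_converter(model)  # the converted SNN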

Classify MNIST
--------------

Now we use ``ann2snn`` to build a simple convolutional network to classify the MNIST dataset.

First define our network structure (the same structure is provided in ``ann2snn.sample_models.mnist_cnn``):

.. code-block:: python

    class ANN(nn.Module):
        def __init__(self):
            super().__init__()
            self.network = nn.Sequential(
                nn.Conv2d(1, 32, 3, 1),
                nn.BatchNorm2d(32, eps=1e-3),
                nn.ReLU(),
                nn.AvgPool2d(2, 2),

                nn.Conv2d(32, 32, 3, 1),
                nn.BatchNorm2d(32, eps=1e-3),
                nn.ReLU(),
                nn.AvgPool2d(2, 2),

                nn.Conv2d(32, 32, 3, 1),
                nn.BatchNorm2d(32, eps=1e-3),
                nn.ReLU(),
                nn.AvgPool2d(2, 2),

                nn.Flatten(),
                nn.Linear(32, 10),
                nn.ReLU()
            )

        def forward(self, x):
            x = self.network(x)
            return x

Note: In the defined network, the order of module definitions must be consistent with the forward order, otherwise the automatic parsing of the network will be affected. It is best to define the network entirely with ``nn.Sequential``. After each Conv2d and Linear layer, a ReLU layer must follow (a BatchNorm layer may come in between); no ReLU is added after pooling layers. If you need to flatten a tensor, define an ``nn.Flatten`` module in the network and use it in the forward function instead of calling ``view``.

Define our hyperparameters:

.. code-block:: python

device = input('输入运行的设备,例如“cpu”或“cuda:0”\n input device, e.g., "cpu" or "cuda:0": ')
dataset_dir = input('输入保存MNIST数据集的位置,例如“./”\n input root directory for saving MNIST dataset, e.g., "./": ')
batch_size = int(input('输入batch_size,例如“64”\n input batch_size, e.g., "64": '))
learning_rate = float(input('输入学习率,例如“1e-3”\n input learning rate, e.g., "1e-3": '))
T = int(input('输入仿真时长,例如“100”\n input simulating steps, e.g., "100": '))
train_epoch = int(input('输入训练轮数,即遍历训练集的次数,例如“10”\n input training epochs, e.g., "10": '))
model_name = input('输入模型名字,例如“mnist”\n input model name, for log_dir generating , e.g., "mnist": ')

The program searches for the trained model archive (a file with the same name as `model_name`) according to the specified folder, and all subsequent temporary files will be stored in that folder.

Load the default conversion configuration and save it:

.. code-block:: python

config = utils.Config.default_config
print('ann2snn config:\n\t', config)
utils.Config.store_config(os.path.join(log_dir,'default_config.json'),config)


Initialize data loader, network, optimizer, loss function

.. code-block:: python

# Initialize the network
ann = ANN().to(device)
# Define loss function
loss_function = nn.CrossEntropyLoss()
# Use Adam optimizer
optimizer = torch.optim.Adam(ann.parameters(), lr=learning_rate, weight_decay=5e-4)

Train the ANN and test it periodically. During training you can also use the pre-written training functions in ``utils``.

.. code-block:: python

    for epoch in range(train_epoch):
        # Train the network using the pre-written code in ``utils``
        utils.train_ann(net=ann,
                        device=device,
                        data_loader=train_data_loader,
                        optimizer=optimizer,
                        loss_function=loss_function,
                        epoch=epoch
                        )
        # Validate the network using the pre-written code in ``utils``
        acc = utils.val_ann(net=ann,
                            device=device,
                            data_loader=test_data_loader,
                            epoch=epoch
                            )
        if best_acc <= acc:
            utils.save_model(ann, log_dir, model_name + '.pkl')

The complete code is located in ``ann2snn.examples.if_cnn_mnist.py``; in the code we also use Tensorboard to save training logs. You can run it directly from the Python command line:

.. code-block:: python

>>> import spikingjelly.clock_driven.ann2snn.examples.if_cnn_mnist as if_cnn_mnist
>>> if_cnn_mnist.main()
输入运行的设备,例如“cpu”或“cuda:0”
input device, e.g., "cpu" or "cuda:0": cuda:15
输入保存MNIST数据集的位置,例如“./”
input root directory for saving MNIST dataset, e.g., "./": ./mnist
输入batch_size,例如“64”
input batch_size, e.g., "64": 128
输入学习率,例如“1e-3”
input learning rate, e.g., "1e-3": 1e-3
输入仿真时长,例如“100”
input simulating steps, e.g., "100": 100
输入训练轮数,即遍历训练集的次数,例如“10”
input training epochs, e.g., "10": 10
输入模型名字,用于自动生成日志文档,例如“mnist”
input model name, for log_dir generating , e.g., "mnist"

If the input to the main function is not a folder containing valid files, a log folder is generated automatically.
Terminal outputs root directory for saving logs, e.g., "./": ./log-mnist1596804385.476601

Epoch 0 [1/937] ANN Training Loss:2.252 Accuracy:0.078
Epoch 0 [101/937] ANN Training Loss:1.424 Accuracy:0.669
Epoch 0 [201/937] ANN Training Loss:1.117 Accuracy:0.773
Epoch 0 [301/937] ANN Training Loss:0.953 Accuracy:0.795
Epoch 0 [401/937] ANN Training Loss:0.865 Accuracy:0.788
Epoch 0 [501/937] ANN Training Loss:0.807 Accuracy:0.792
Epoch 0 [601/937] ANN Training Loss:0.764 Accuracy:0.795
Epoch 0 [701/937] ANN Training Loss:0.726 Accuracy:0.834
Epoch 0 [801/937] ANN Training Loss:0.681 Accuracy:0.880
Epoch 0 [901/937] ANN Training Loss:0.641 Accuracy:0.888
Epoch 0 [100/100] ANN Validating Loss:0.328 Accuracy:0.881
Save model to: ./log-mnist1596804385.476601\mnist.pkl
...
Epoch 9 [901/937] ANN Training Loss:0.036 Accuracy:0.990
Epoch 9 [100/100] ANN Validating Loss:0.042 Accuracy:0.988
Save model to: ./log-mnist1596804957.0179427\mnist.pkl

In the example, this model is trained for 10 epochs. The changes in the accuracy of the test set during training are as follows:

.. image:: ../_static/tutorials/clock_driven/5_ann2snn/accuracy_curve.png

In the end, the accuracy on test dataset is 98.8%.

Take a part of the data from the training set for the normalization step of the model. Here we take 1/500 of the training data, i.e., 100 images. Note that the data tensor taken from the dataset is in the range [0, 255]; it needs to be divided by 255 to become a floating-point tensor in [0.0, 1.0] to match the feasible range of firing rates.

.. code-block:: python

norm_set_len = int(train_data_dataset.data.shape[0] / 500)
print('Using %d pictures as norm set'%(norm_set_len))
norm_set = train_data_dataset.data[:norm_set_len, :, :].float() / 255
norm_tensor = torch.FloatTensor(norm_set).view(-1,1,28,28)

Call the standard conversion function ``standard_conversion`` implemented in ``ann2snn.utils`` to realize ANN conversion and SNN simulation.

.. code-block:: python

utils.standard_conversion(model_name=model_name,
norm_data=norm_tensor,
test_data_loader=test_data_loader,
device=device,
T=T,
log_dir=log_dir,
config=config
)

In the process, the normalized model structure is output:

.. code-block:: python

ModelParser(
(network): Sequential(
(0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
(1): ReLU()
(2): AvgPool2d(kernel_size=2, stride=2, padding=0)
(3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
(4): ReLU()
(5): AvgPool2d(kernel_size=2, stride=2, padding=0)
(6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
(7): ReLU()
(8): AvgPool2d(kernel_size=2, stride=2, padding=0)
(9): Flatten()
(10): Linear(in_features=32, out_features=10, bias=True)
(11): ReLU()
)
)

At the same time, one can also observe the structure of SNN:

.. code-block:: python

SNN(
(network): Sequential(
(0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
(1): IFNode(
v_threshold=1.0, v_reset=None
(surrogate_function): Sigmoid()
)
(2): AvgPool2d(kernel_size=2, stride=2, padding=0)
(3): IFNode(
v_threshold=1.0, v_reset=None
(surrogate_function): Sigmoid()
)
(4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
(5): IFNode(
v_threshold=1.0, v_reset=None
(surrogate_function): Sigmoid()
)
(6): AvgPool2d(kernel_size=2, stride=2, padding=0)
(7): IFNode(
v_threshold=1.0, v_reset=None
(surrogate_function): Sigmoid()
)
(8): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
(9): IFNode(
v_threshold=1.0, v_reset=None
(surrogate_function): Sigmoid()
)
(10): AvgPool2d(kernel_size=2, stride=2, padding=0)
(11): IFNode(
v_threshold=1.0, v_reset=None
(surrogate_function): Sigmoid()
)
(12): Flatten()
(13): Linear(in_features=32, out_features=10, bias=True)
(14): IFNode(
v_threshold=1.0, v_reset=None
(surrogate_function): Sigmoid()
)
)
)

It can be seen that the activation of ReLU in the ANN model is replaced by the IFNode of SNN. Each layer of AvgPool2d is followed by a layer of IFNode.

Due to the long time of model simulation, the current accuracy and simulation progress are continuously output:

.. code-block:: python

[SNN Simulating... 1.00%] Acc:0.990
[SNN Simulating... 2.00%] Acc:0.990
[SNN Simulating... 3.00%] Acc:0.990
[SNN Simulating... 4.00%] Acc:0.988
[SNN Simulating... 5.00%] Acc:0.990
……
[SNN Simulating... 95.00%] Acc:0.986
[SNN Simulating... 96.00%] Acc:0.986
[SNN Simulating... 97.00%] Acc:0.986
[SNN Simulating... 98.00%] Acc:0.986
[SNN Simulating... 99.00%] Acc:0.987
SNN Simulating Accuracy:0.987
Summary: ANN Accuracy:98.7900% SNN Accuracy:98.6500% [Decreased 0.1400%]

Through the final output, we can see that the ANN achieves 98.79% accuracy on the MNIST classification task, while the converted SNN achieves 98.65%; the conversion causes a 0.14% drop in performance.

Define our hyperparameters:

.. code-block:: python

torch.random.manual_seed(0)
torch.cuda.manual_seed(0)
device = 'cuda'
dataset_dir = 'G:/Dataset/mnist'
batch_size = 100
T = 50

Here ``T`` is the number of time steps that will be used for SNN inference below.
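For reference, the ``val`` helper used below comes from the example script; a minimal sketch of what it is assumed to do (accumulating output spikes over ``T`` steps and recording the accuracy after every step; see ``ann2snn.examples.cnn_mnist`` for the actual code) is:

.. code-block:: python

    import torch
    from spikingjelly.clock_driven import functional

    def val(net, device, data_loader, T=None):
        net = net.to(device)
        net.eval()
        correct = torch.zeros(T) if T is not None else 0.
        total = 0
        with torch.no_grad():
            for img, label in data_loader:
                img, label = img.to(device), label.to(device)
                if T is None:
                    # ANN: a single forward pass
                    correct += (net(img).argmax(1) == label).float().sum().item()
                else:
                    # SNN: accumulate output spikes over T steps,
                    # recording the accuracy after every step
                    functional.reset_net(net)
                    out = 0.
                    for t in range(T):
                        out += net(img)
                        correct[t] += (out.argmax(1) == label).float().sum().item()
                total += label.numel()
        return correct / total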

If you want to train, you also need to initialize the data loader, optimizer, and loss function, for example:

.. code-block:: python

lr = 1e-3
epochs = 10
# define the loss function
loss_function = nn.CrossEntropyLoss()
# Use Adam optimizer
optimizer = torch.optim.Adam(ann.parameters(), lr=lr, weight_decay=5e-4)

Train the ANN. In this example, the model is trained for 10 epochs. The test-set accuracy during training changes as follows:

.. code-block:: python

Epoch: 0 100%|██████████| 600/600 [00:05<00:00, 112.04it/s]
Validating Accuracy: 0.972
Epoch: 1 100%|██████████| 600/600 [00:05<00:00, 105.43it/s]
Validating Accuracy: 0.986
Epoch: 2 100%|██████████| 600/600 [00:05<00:00, 107.49it/s]
Validating Accuracy: 0.987
Epoch: 3 100%|██████████| 600/600 [00:05<00:00, 109.26it/s]
Validating Accuracy: 0.990
Epoch: 4 100%|██████████| 600/600 [00:05<00:00, 103.98it/s]
Validating Accuracy: 0.984
Epoch: 5 100%|██████████| 600/600 [00:05<00:00, 100.42it/s]
Validating Accuracy: 0.989
Epoch: 6 100%|██████████| 600/600 [00:06<00:00, 96.24it/s]
Validating Accuracy: 0.991
Epoch: 7 100%|██████████| 600/600 [00:05<00:00, 104.97it/s]
Validating Accuracy: 0.992
Epoch: 8 100%|██████████| 600/600 [00:05<00:00, 106.45it/s]
Validating Accuracy: 0.991
Epoch: 9 100%|██████████| 600/600 [00:05<00:00, 111.93it/s]
Validating Accuracy: 0.991

After training, we load the saved model and test its performance:

.. code-block:: python

model.load_state_dict(torch.load('SJ-mnist-cnn_model-sample.pth'))
acc = val(model, device, test_data_loader)
print('ANN Validating Accuracy: %.4f' % (acc))

The output is as follows:

.. code-block:: python

100%|██████████| 200/200 [00:02<00:00, 89.44it/s]
ANN Validating Accuracy: 0.9870

Converting with Converter is very simple: you only need to set the desired mode in the parameters. For example, to use MaxNorm, first define an ``ann2snn.Converter`` and then forward the model through this object:

.. code-block:: python

model_converter = ann2snn.Converter(mode='max', dataloader=train_data_loader)
snn_model = model_converter(model)

``snn_model`` is the converted SNN model.

Following this example, we convert the model with the mode set to ``max``, ``99.9%``, ``1.0/2``, ``1.0/3``, ``1.0/4`` and ``1.0/5`` respectively, and run inference for T steps in each case to obtain the accuracy.

.. code-block:: python

print('---------------------------------------------')
print('Converting using MaxNorm')
model_converter = ann2snn.Converter(mode='max', dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_max_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_max_accs[-1]))

print('---------------------------------------------')
print('Converting using RobustNorm')
model_converter = ann2snn.Converter(mode='99.9%', dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_robust_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_robust_accs[-1]))

print('---------------------------------------------')
print('Converting using 1/2 max(activation) as scales...')
model_converter = ann2snn.Converter(mode=1.0 / 2, dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_two_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_two_accs[-1]))

print('---------------------------------------------')
print('Converting using 1/3 max(activation) as scales')
model_converter = ann2snn.Converter(mode=1.0 / 3, dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_three_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_three_accs[-1]))

print('---------------------------------------------')
print('Converting using 1/4 max(activation) as scales')
model_converter = ann2snn.Converter(mode=1.0 / 4, dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_four_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_four_accs[-1]))

print('---------------------------------------------')
print('Converting using 1/5 max(activation) as scales')
model_converter = ann2snn.Converter(mode=1.0 / 5, dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_five_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_five_accs[-1]))

Observe the console output:

.. code-block:: python

---------------------------------------------
Converting using MaxNorm
100%|██████████| 600/600 [00:04<00:00, 128.25it/s] Simulating...
100%|██████████| 200/200 [00:13<00:00, 14.44it/s] SNN accuracy (simulation 50 time-steps): 0.9777
---------------------------------------------
Converting using RobustNorm
100%|██████████| 600/600 [00:19<00:00, 31.06it/s] Simulating...
100%|██████████| 200/200 [00:13<00:00, 14.75it/s] SNN accuracy (simulation 50 time-steps): 0.9841
---------------------------------------------
Converting using 1/2 max(activation) as scales...
100%|██████████| 600/600 [00:04<00:00, 126.64it/s] ]Simulating...
100%|██████████| 200/200 [00:13<00:00, 14.90it/s] SNN accuracy (simulation 50 time-steps): 0.9844
---------------------------------------------
Converting using 1/3 max(activation) as scales
100%|██████████| 600/600 [00:04<00:00, 126.27it/s] Simulating...
100%|██████████| 200/200 [00:13<00:00, 14.73it/s] SNN accuracy (simulation 50 time-steps): 0.9828
---------------------------------------------
Converting using 1/4 max(activation) as scales
100%|██████████| 600/600 [00:04<00:00, 128.94it/s] Simulating...
100%|██████████| 200/200 [00:13<00:00, 14.47it/s] SNN accuracy (simulation 50 time-steps): 0.9747
---------------------------------------------
Converting using 1/5 max(activation) as scales
100%|██████████| 600/600 [00:04<00:00, 121.18it/s] Simulating...
100%|██████████| 200/200 [00:13<00:00, 14.42it/s] SNN accuracy (simulation 50 time-steps): 0.9487
---------------------------------------------

It can be seen that model conversion is very fast, and model inference of 200 steps takes only 11 s to complete (GTX 2080ti).
Based on the accuracy of the model output over time, we can plot the accuracy curves for the different settings.

.. code-block:: python

fig = plt.figure()
plt.plot(np.arange(0, T), mode_max_accs, label='mode: max')
plt.plot(np.arange(0, T), mode_robust_accs, label='mode: 99.9%')
plt.plot(np.arange(0, T), mode_two_accs, label='mode: 1.0/2')
plt.plot(np.arange(0, T), mode_three_accs, label='mode: 1.0/3')
plt.plot(np.arange(0, T), mode_four_accs, label='mode: 1.0/4')
plt.plot(np.arange(0, T), mode_five_accs, label='mode: 1.0/5')
plt.legend()
plt.xlabel('t')
plt.ylabel('Acc')
plt.show()

.. image:: ../_static/tutorials/clock_driven/5_ann2snn/accuracy_mode.png

Different settings yield different results: some give fast inference but lower final accuracy, while others infer more slowly but reach higher accuracy. Users can choose the setting according to their needs.

.. [#f1] Rueckauer B, Lungu I-A, Hu Y, Pfeiffer M and Liu S-C (2017) Conversion of Continuous-Valued Deep Networks to Efficient Event-Driven Networks for Image Classification. Front. Neurosci. 11:682.
.. [#f2] Diehl, Peter U. , et al. Fast classifying, high-accuracy spiking deep networks through weight and threshold balancing. Neural Networks (IJCNN), 2015 International Joint Conference on IEEE, 2015.
.. [#f3] Rueckauer, B., Lungu, I. A., Hu, Y., & Pfeiffer, M. (2016). Theory and tools for the conversion of analog to spiking convolutional neural networks. arXiv preprint arXiv:1612.04052.
.. [#f4] Sengupta, A., Ye, Y., Wang, R., Liu, C., & Roy, K. (2019). Going deeper in spiking neural networks: Vgg and residual architectures. Frontiers in neuroscience, 13, 95.

+ 1
- 1
docs/source/conf.py View File

@@ -80,7 +80,7 @@ napoleon_use_ivar = True
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

autodoc_mock_imports = ['loris', 'readline', '_C_gemm', '_C_neuron', 'torchaudio', 'onnx', 'onnxruntime', 'gym', 'cloudpickle']
autodoc_mock_imports = ['loris', 'readline', '_C_gemm', '_C_neuron', 'torchaudio', 'onnx', 'onnxruntime', 'gym', 'cloudpickle', 'rarfile']
autoclass_content = 'both'
autodoc_member_order = 'bysource'
autodoc_inherit_docstrings = False


+ 13
- 7
docs/source/index.rst View File

@@ -15,7 +15,7 @@

Odd version numbers are development versions, continuously updated with GitHub/OpenI. Even version numbers are stable versions, available from PyPI.

Install the latest stable version (0.0.0.0.8) from `PyPI <https://pypi.org/project/spikingjelly/>`_
Install the latest stable version from `PyPI <https://pypi.org/project/spikingjelly/>`_:

.. code-block:: bash

@@ -76,7 +76,7 @@
* :ref:`search`


Citation
Citation and publications
-------------------------
If you use SpikingJelly in your own work, you can cite it in the following format:

@@ -86,11 +86,14 @@
title = {SpikingJelly},
author = {Fang, Wei and Chen, Yanqi and Ding, Jianhao and Chen, Ding and Yu, Zhaofei and Zhou, Huihui and Tian, Yonghong and other contributors},
year = {2020},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/fangwei123456/spikingjelly}},
note = {Accessed: YYYY-MM-DD},
}

The `YYYY-MM-DD` should be replaced with the date of the last code change of the SpikingJelly version used in your work.

Publications using SpikingJelly are listed in `Publications using SpikingJelly <https://github.com/fangwei123456/spikingjelly/blob/master/publications.md>`_.

About
-------------------------
The `Multimedia Learning Group <https://pkuml.org/>`_ of the Institute of Digital Media, School of EECS, Peking University, and the `Peng Cheng Laboratory <https://www.pcl.ac.cn/>`_ are the main developers of SpikingJelly.
@@ -124,7 +127,7 @@ Note that SpikingJelly is based on PyTorch. Please make sure that you have insta

The odd version number is the developing version, which is updated with GitHub/OpenI repository. The even version number is the stable version and available at PyPI.

Install the last stable version (0.0.0.0.8) from `PyPI <https://pypi.org/project/spikingjelly/>`_:
Install the last stable version from `PyPI <https://pypi.org/project/spikingjelly/>`_:

.. code-block:: bash

@@ -195,11 +198,14 @@ If you use SpikingJelly in your work, please cite it as follows:
title = {SpikingJelly},
author = {Fang, Wei and Chen, Yanqi and Ding, Jianhao and Chen, Ding and Yu, Zhaofei and Zhou, Huihui and Tian, Yonghong and other contributors},
year = {2020},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/fangwei123456/spikingjelly}},
note = {Accessed: YYYY-MM-DD},
}

Note: To specify the version of framework you are using, the default value YYYY-MM-DD in the note field should be replaced with the date of the last change of the framework you are using, i.e. the date of the latest commit.

Publications using SpikingJelly are recorded in `Publications using SpikingJelly <https://github.com/fangwei123456/spikingjelly/blob/master/publications.md>`_. If you use SpikingJelly in your paper, you can also add it to this table by pull request.

About
-------------------------
`Multimedia Learning Group, Institute of Digital Media (NELVT), Peking University <https://pkuml.org/>`_ and `Peng Cheng Laboratory <http://www.szpclab.com/>`_ are the main developers of SpikingJelly.


+ 10
- 0
docs/source/spikingjelly.clock_driven.lava_exchange.rst View File

@@ -0,0 +1,10 @@
spikingjelly.clock_driven.lava_exchange package
======================================

Module contents
---------------

.. automodule:: spikingjelly.clock_driven.lava_exchange
:members:
:undoc-members:
:show-inheritance:

+ 1
- 0
docs/source/spikingjelly.clock_driven.rst View File

@@ -14,6 +14,7 @@ spikingjelly.clock_driven package
spikingjelly.clock_driven.rnn
spikingjelly.clock_driven.surrogate
spikingjelly.clock_driven.ann2snn
spikingjelly.clock_driven.lava_exchange

Module contents
---------------


+ 17
- 1
docs/source/spikingjelly.datasets.rst View File

@@ -28,6 +28,14 @@ spikingjelly.datasets.dvs128\_gesture module
:undoc-members:
:show-inheritance:

spikingjelly.datasets.es\_imagenet module
------------------------------------------

.. automodule:: spikingjelly.datasets.es_imagenet
:members:
:undoc-members:
:show-inheritance:

spikingjelly.datasets.n\_caltech101 module
------------------------------------------

@@ -44,6 +52,14 @@ spikingjelly.datasets.n\_mnist module
:undoc-members:
:show-inheritance:

spikingjelly.datasets.nav\_gesture module
-------------------------------------

.. automodule:: spikingjelly.datasets.nav_gesture
:members:
:undoc-members:
:show-inheritance:

spikingjelly.datasets.speechcommands module
-------------------------------------------

@@ -58,4 +74,4 @@ Module contents
.. automodule:: spikingjelly.datasets
:members:
:undoc-members:
:show-inheritance:
:show-inheritance:

+ 25
- 11
publications.md View File

@@ -1,15 +1,29 @@
## Publications using SpikingJelly

| Papers | Codes |
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------- |
| [Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks](https://arxiv.org/abs/2007.05785) | https://github.com/fangwei123456/Parametric-Leaky-Integrate-and-Fire-Spiking-Neuron |
| [Pruning of Deep Spiking Neural Networks through Gradient Rewiring](https://arxiv.org/abs/2105.04916) | https://github.com/Yanqi-Chen/Gradient-Rewiring |
| [Optimal ANN-SNN Conversion for Fast and Accurate Inference in Deep Spiking Neural Networks](https://arxiv.org/abs/2105.11654) | https://github.com/DingJianhao/OptSNNConvertion-RNL-RIL |
| [Deep Residual Learning in Spiking Neural Networks](https://arxiv.org/abs/2102.04159) | https://github.com/fangwei123456/Spike-Element-Wise-ResNet |
| [Spiking Neural Networks Trained via Proxy](https://arxiv.org/abs/2109.13208) | https://github.com/SRKH/ProxyLearning |
| [StereoSpike: Depth Learning with a Spiking Neural Network](https://arxiv.org/abs/2109.13751) | https://github.com/urancon/StereoSpike |
| [An Odor Recognition Algorithm of Electronic Noses Based on Convolutional Spiking Neural Network for Spoiled Food Identification](https://iopscience.iop.org/article/10.1149/1945-7111/ac1699/meta) | |
| [Cascade Spiking Neuron Network For Event-based Image Classification In Noisy Environment](https://www.techrxiv.org/articles/preprint/Cascade_Spiking_Neuron_Network_For_Event-based_Image_Classification_In_Noisy_Environment/16571043) | |
| [Keys to Accurate Feature Extraction Using Residual Spiking Neural Networks](https://arxiv.org/abs/2111.05955) | https://github.com/VicenteAlex/Spiking_ResNet |
| Papers | Codes |
| ------------------------------------------------------------ | ------------------------------------------------------------ |
| [Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks](https://arxiv.org/abs/2007.05785) | https://github.com/fangwei123456/Parametric-Leaky-Integrate-and-Fire-Spiking-Neuron |
| [Pruning of Deep Spiking Neural Networks through Gradient Rewiring](https://arxiv.org/abs/2105.04916) | https://github.com/Yanqi-Chen/Gradient-Rewiring |
| [Optimal ANN-SNN Conversion for Fast and Accurate Inference in Deep Spiking Neural Networks](https://arxiv.org/abs/2105.11654) | https://github.com/DingJianhao/OptSNNConvertion-RNL-RIL |
| [Deep Residual Learning in Spiking Neural Networks](https://arxiv.org/abs/2102.04159) | https://github.com/fangwei123456/Spike-Element-Wise-ResNet |
| [Spiking Neural Networks Trained via Proxy](https://arxiv.org/abs/2109.13208) | https://github.com/SRKH/ProxyLearning |
| [StereoSpike: Depth Learning with a Spiking Neural Network](https://arxiv.org/abs/2109.13751) | https://github.com/urancon/StereoSpike |
| [An Odor Recognition Algorithm of Electronic Noses Based on Convolutional Spiking Neural Network for Spoiled Food Identification](https://iopscience.iop.org/article/10.1149/1945-7111/ac1699/meta) | |
| [Cascade Spiking Neuron Network For Event-based Image Classification In Noisy Environment](https://www.techrxiv.org/articles/preprint/Cascade_Spiking_Neuron_Network_For_Event-based_Image_Classification_In_Noisy_Environment/16571043) | |
| [Keys to Accurate Feature Extraction Using Residual Spiking Neural Networks](https://arxiv.org/abs/2111.05955) | https://github.com/VicenteAlex/Spiking_ResNet |
| [Human-Level Control through Directly-Trained Deep Spiking Q-Networks](https://arxiv.org/abs/2201.07211) | https://github.com/AptX395/Deep-Spiking-Q-Networks |
| [Deep Reinforcement Learning with Spiking Q-learning](https://arxiv.org/abs/2201.09754) | |
| [Event-based Video Reconstruction via Potential-assisted Spiking Neural Network](https://arxiv.org/abs/2201.10943) | https://github.com/LinZhu111/EVSNN |
| [Optimal ANN-SNN Conversion for High-accuracy and Ultra-low-latency Spiking Neural Networks](https://openreview.net/forum?id=7B3IJMM1k_M) | https://github.com/putshua/SNN-conversion-QCFS |
| [Optimized Potential Initialization for Low-latency Spiking Neural Networks](https://arxiv.org/abs/2202.01440) | |
| [AutoSNN: Towards Energy-Efficient Spiking Neural Networks](https://arxiv.org/abs/2201.12738) | |
| [Neural Architecture Search for Spiking Neural Networks](https://arxiv.org/abs/2201.10355) | https://github.com/Intelligent-Computing-Lab-Yale/Neural-Architecture-Search-for-Spiking-Neural-Networks |
| [FEAS: A Faster Event-driven Accelerator Supporting Inhibitory Spiking Neural Network](https://ieeexplore.ieee.org/document/9720483/) | |
| [Neuromorphic Data Augmentation for Training Spiking Neural Networks](https://arxiv.org/abs/2203.06145) | |
| [SIT: A Bionic and Non-Linear Neuron for Spiking Neural Network](https://arxiv.org/abs/2203.16117) | |
| [Building and training a deep spiking neural network for ECG classification](https://www.sciencedirect.com/science/article/pii/S1746809422002713) | |
| [DynSNN: A Dynamic Approach to Reduce Redundancy in Spiking Neural Networks](https://ieeexplore.ieee.org/abstract/document/9746566) | |
| [Object Detection with Spiking Neural Networks on Automotive Event Data](https://arxiv.org/abs/2205.04339) | |

If you use SpikingJelly in your paper, you can also add it to this table by pull request.


+ 0
- 1
requirements.txt View File

@@ -4,4 +4,3 @@ numpy
tqdm
torchvision
scipy
onnx==1.8.0

+ 1
- 1
setup.py View File

@@ -19,7 +19,7 @@ with open("./README.md", "r", encoding="utf-8") as fh:
setup(
install_requires=install_requires,
name="spikingjelly",
version="0.0.0.0.9",
version="0.0.0.0.13",
author="PKU MLG, PCL, and other contributors",
author_email="fwei@pku.edu.cn, chyq@pku.edu.cn",
description="A deep learning framework for SNNs built on PyTorch.",


+ 2
- 476
spikingjelly/clock_driven/ann2snn/__init__.py View File

@@ -1,476 +1,2 @@
import numpy as np
import torch
import torch.nn as nn
import os
from tqdm import tqdm
import json
from spikingjelly.clock_driven import neuron,encoding,functional
from collections import defaultdict
import copy
import time
import inspect
import matplotlib.pyplot as plt
import warnings

from spikingjelly.clock_driven.ann2snn.kernels.onnx import _o2p_converter as onnx2pytorch

class parser:
def __init__(self, name='', kernel='onnx', **kargs):
try:
with open(kargs['json'], 'r') as f:
self.config = json.load(f)
except KeyError:
try:
self.log_dir = kargs['log_dir']
except KeyError:
from datetime import datetime
current_time = datetime.now().strftime('%b%d_%H-%M-%S')
log_dir = os.path.join(
self.__class__.__name__ + '-' + current_time +
('' if len(name) == 0 else '_' + name))
self.log_dir = log_dir
self.config = kargs
print('parser log_dir:', self.log_dir)
self.config['log_dir'] = self.log_dir
self.kernel = kernel
assert(self.kernel.lower() in ('onnx','pytorch'))
if not os.path.isdir(self.log_dir):
os.makedirs(self.log_dir)
with open(os.path.join(self.log_dir,'parser_args.json'), 'w') as fw:
json.dump(self.config, fw)

def parse(self, model: nn.Module, data: torch.Tensor, **kargs) -> nn.Module:
model_name = model.__class__.__name__
model.eval()

for m in model.modules():
if hasattr(m,'weight'):
assert(data.get_device() == m.weight.get_device())

try:
model = z_norm_integration(model=model, z_norm=self.config['z_norm'])
except KeyError:
pass
layer_reduc = False
for m in model.modules():
if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d, nn.BatchNorm3d)):
layer_reduc = True
break
if self.kernel.lower() == 'onnx':
try:
import onnx
import onnxruntime as ort
except ImportError:
print(Warning("Package onnx or onnxruntime not found: launch pytorch convert engine,"
" only support very simple arctitecture"))
self.kernel = 'pytorch'
else:
pass

if self.kernel.lower() == 'onnx':
# use onnx engine

data = data.cpu()
model = model.cpu()

import spikingjelly.clock_driven.ann2snn.kernels.onnx as onnx_kernel

onnx_model = onnx_kernel.pytorch2onnx_model(model=model, data=data, log_dir=self.config['log_dir'])
# onnx_kernel.print_onnx_model(onnx_model.graph)
onnx.checker.check_model(onnx_model)
if layer_reduc:
onnx_model = onnx_kernel.layer_reduction(onnx_model)
# onnx.checker.check_model(onnx_model)
onnx_model = onnx_kernel.rate_normalization(onnx_model, data.numpy(), **kargs) #**self.config['normalization']
onnx_kernel.save_model(onnx_model,os.path.join(self.config['log_dir'],model_name+".onnx"))

convert_methods = onnx2pytorch
try:
user_defined = kargs['user_methods']
assert (user_defined is dict)
for k in user_defined:
convert_methods.add_method(op_name=k, func=user_defined[k])
except KeyError:
print('no user-defined conversion method found, use default')
except AssertionError:
print('user-defined conversion method should be organized into a dict!')
model = onnx_kernel.onnx2pytorch_model(onnx_model, convert_methods)
else:
# use pytorch engine

import spikingjelly.clock_driven.ann2snn.kernels.pytorch as pytorch_kernel

if layer_reduc:
model = pytorch_kernel.layer_reduction(model)
model = pytorch_kernel.rate_normalization(model, data)#, **self.config['normalization']

self.ann_filename = os.path.join(self.config['log_dir'], model_name + ".pth")
torch.save(model, os.path.join(self.config['log_dir'], "debug.pth"))
torch.save(model, self.ann_filename)
model = self.to_snn(model)
return model

def to_snn(self, model: nn.Module, **kargs) -> nn.Module:
for name, module in model._modules.items():
if hasattr(module, "_modules"):
model._modules[name] = self.to_snn(module, **kargs)
if module.__class__.__name__ == "AvgPool2d":
new_module = nn.Sequential(module, neuron.IFNode(v_reset=None))
model._modules[name] = new_module
if "BatchNorm" in module.__class__.__name__:
try:
# NSIFNode is a model that can emit both positive and negative spikes; it has been removed in the current version
new_module = nn.Sequential(module, neuron.NSIFNode(v_threshold=(-1.0, 1.0), v_reset=None))
except AttributeError:
new_module = module
model._modules[name] = new_module
if module.__class__.__name__ == "ReLU":
new_module = neuron.IFNode(v_reset=None)
model._modules[name] = new_module
try:
if module.__class__.__name__ == 'PReLU':
p = module.weight
assert (p.size(0) == 1 and p != 0)
if -1 / p.item() > 0:
model._modules[name] = neuron.NSIFNode(v_threshold=(1.0 / p.item(), 1.0),
bipolar=(1.0, 1.0), v_reset=None)
else:
model._modules[name] = neuron.NSIFNode(v_threshold=(-1 / p.item(), 1.0),
bipolar=(-1.0, 1.0), v_reset=None)
except AttributeError:
assert False, 'NSIFNode has been removed.'
if module.__class__.__name__ == "MaxPool2d":
new_module = nn.AvgPool2d(
kernel_size=module.kernel_size,
stride=module.stride,
padding=module.padding)
model._modules[name] = new_module
return model

def z_norm_integration(model: nn.Module, z_norm=None) -> nn.Module:
if z_norm is not None:
(z_norm_mean, z_norm_std) = z_norm
z_norm_mean = torch.from_numpy(np.array(z_norm_mean).astype(np.float32))
z_norm_std = torch.from_numpy(np.array(z_norm_std).astype(np.float32))
bn = nn.BatchNorm2d(num_features=len(z_norm_std))
bn.weight.data = torch.ones_like(bn.weight.data)
bn.bias.data = torch.zeros_like(bn.bias.data)
bn.running_var.data = torch.pow(z_norm_std, exponent=2) - bn.eps
bn.running_mean.data = z_norm_mean
return nn.Sequential(bn, model)
else:
return model

import threading
mutex_schedule = threading.Lock()
mutex_shared = threading.Lock()
global_shared = {}

class simulator:
def __init__(self, snn, device, name='', **kargs):
snn.eval()
try:
self.log_dir = kargs['log_dir']
except KeyError:
from datetime import datetime
current_time = datetime.now().strftime('%b%d_%H-%M-%S')
log_dir = os.path.join(
self.__class__.__name__ + '-' + current_time +
('' if len(name)==0 else '_' + name))
self.log_dir = log_dir
print('simulator log_dir:',self.log_dir)
if not os.path.isdir(self.log_dir):
os.makedirs(self.log_dir)

try:
self.fig = kargs['canvas']
self.ax = self.fig.add_subplot(1, 1, 1)
plt.ion()
except KeyError:
self.fig = None

try:
encoder = kargs['encoder']
except KeyError:
encoder = 'constant'
if encoder == 'poisson':
self.encoder = encoding.PoissonEncoder()
else:
self.encoder = lambda x: x

if isinstance(device,(list,set,tuple)):
if len(device)==1:
device = device[0]
self.pi = False
else:
self.pi = True # parallel inference
else:
self.pi = False
if self.pi:
print('simulator is working on the parallel mode, device(s):', device)
else:
print('simulator is working on the normal mode, device:', device)
self.device = device

global global_shared, mutex_schedule, mutex_shared
self.mutex_shared = mutex_shared
self.mutex_schedule = mutex_schedule
self.global_shared = global_shared
if self.pi:
self.global_shared['device_used'] = defaultdict(int)
self.global_shared['device_stat'] = defaultdict(int)
self.global_shared['distri_model'] = {}
self.global_shared['batch'] = 0
self.global_shared['batch_sum'] = 0
self.global_shared['T'] = None
for dev in self.device:
self.global_shared['distri_model'][dev] = copy.deepcopy(snn).to(dev)
else:
self.global_shared['distri_model'] = {}
self.global_shared['distri_model'][self.device] = copy.deepcopy(snn).to(self.device)
self.config = dict()
self.config['device'] = self.device
self.config['name'] = name
self.config['log_dir'] = self.log_dir
self.config = {**self.config, **kargs}


def simulate(self, data_loader, T, **kargs):
self.config['T'] = T
self.config = {**self.config, **kargs}
with open(os.path.join(self.log_dir,'simulator_args.json'), 'w') as fw:
json.dump({k: self.config[k] for k in self.config.keys() if _if_json_serializable(self.config[k])}
, fw)
try:
if kargs['online_drawer']:
if isinstance(self.device, (list, set, tuple)):
warnings.warn("online drawer deprecated because package Matplotlib is not thread safe!")
except KeyError:
pass
try:
func_dict = kargs['func_dict']
except KeyError:
func_dict = {}
for n in self._get_user_defined_static_methods():
func_dict[n] = getattr(self,n)
try:
assert(len(func_dict.keys())>0)
except AssertionError:
raise KeyError("Please add valid func_dict for simulator, or use pre-defined subclass of ``simulator``!")
if self.pi:
threads = []
start = time.perf_counter()
global global_shared
self.global_shared['T'] = T
for value_name in func_dict:
self.global_shared[value_name] = []
self.global_shared['batch_sum'] = len(data_loader)
for batch, (data, targets) in enumerate(tqdm(data_loader)):
self.global_shared['batch'] = batch
if self.pi:
distributed = False
while not distributed:
time.sleep(0.001) # time delay
for device in self.device:
if self.global_shared['device_used'][device] == 0:
t = threading.Thread(target=self.get_values,
kwargs=dict(data=data,
targets=targets,
device=device,
T=T,
func_dict=func_dict,
**kargs)
)
t.start()
threads.append(t)
distributed = True
self.global_shared['device_stat'][device] += 1
break
else:
self.get_values(data=data,
targets=targets,
device=self.device,
T=T,
func_dict=func_dict,
**kargs)
if self.pi:
for t in threads:
t.join()
elapsed = time.perf_counter() - start
print('--------------------simulator summary--------------------')
print('time elapsed:', elapsed, '(sec)')
if self.pi:
print('load stat:',self.global_shared['device_stat'])
print('---------------------------------------------------------')

try:
if kargs['canvas'] is not None:
plt.ioff()
plt.close()
except KeyError:
pass

ret_dict = {}

for value_name in func_dict:
ret_dict[value_name] = self.global_shared[value_name]
return ret_dict

def get_values(self, data, targets, device, T, func_dict, **kargs):
if self.pi:
if mutex_shared.acquire():
getattr(self, '_pre_batch_sim')(**kargs)
mutex_shared.release()
else:
getattr(self, '_pre_batch_sim')(**kargs)
global global_shared
data = data.to(device)
targets = targets.to(device)
values_list = defaultdict(list)

if self.pi:
if mutex_schedule.acquire():
self.global_shared['device_used'][device] = 1
mutex_schedule.release()

snn = self.global_shared['distri_model'][device]
functional.reset_net(snn)
with torch.no_grad():
for t in range(T):
enc = self.encoder(data).float().to(device)
out = snn(enc)
if t == 0:
counter = out
else:
counter += out
for value_name in func_dict.keys():
value = func_dict[value_name](data=data,
targets=targets,
out_spike=out,
out_spike_cnt=counter,
device=device,
**kargs)
values_list[value_name].append(value)

for value_name in func_dict.keys():
values_list[value_name] = np.array(values_list[value_name]).astype(np.float32)

if self.pi:
if mutex_shared.acquire():
for value_name in func_dict.keys():
self.global_shared[value_name].append(values_list[value_name])
getattr(self, '_after_batch_sim')(**kargs)
mutex_shared.release()
else:
for value_name in func_dict.keys():
self.global_shared[value_name].append(values_list[value_name])
getattr(self, '_after_batch_sim')(**kargs)

if self.pi:
if mutex_schedule.acquire():
self.global_shared['device_used'][device] = 0
mutex_schedule.release()

def _get_user_defined_static_methods(self):
method = []
attrs = dir(self)
for attr in attrs:
if attr[0] != '_':
user_defined = inspect.isroutine(getattr(self, attr))
static_method = False
for cls in inspect.getmro(type(self)):
if attr in cls.__dict__:
v = cls.__dict__[attr]
if isinstance(v, staticmethod):
static_method = True
if user_defined and static_method:
method.append(attr)
return method

def _pre_batch_sim(self, **kargs):
pass

def _after_batch_sim(self, **kargs):
pass



class classify_simulator(simulator): # an example simulator for a classification task
def __init__(self, snn, device, **kargs):
super().__init__(snn, device, **kargs)
self.global_shared['accu_correct'] = 0.0
self.global_shared['accu_total'] = 0.0
self.global_shared['acc'] = 0.0
# try:
# self.fig = kargs['canvas']
# self.ax = self.fig.add_subplot(1, 1, 1)
# plt.ion()
# except KeyError:
# self.fig = None

@staticmethod
def correct_num(targets, out_spike_cnt, **kargs) -> float:
n = (out_spike_cnt.max(1)[1] == targets).float().sum().item()
return n

@staticmethod
def total_num(targets, **kargs) -> float:
n = targets.size(0)
return n

def _after_batch_sim(self, **kargs):
import matplotlib.pyplot as plt
T = self.global_shared['T']
self.global_shared['accu_correct'] += self.global_shared['correct_num'][-1]
self.global_shared['accu_total'] += self.global_shared['total_num'][-1]
self.global_shared['acc'] = self.global_shared['accu_correct'] \
/ self.global_shared['accu_total']
np.savetxt(os.path.join(self.log_dir, 'acc.csv'), self.global_shared['acc'], delimiter=",")

if self.fig is not None:
self.ax.cla()
x = np.arange(self.global_shared['acc'].shape[0])
self.ax.plot(x,self.global_shared['acc'] * 100,label='SNN Acc')

try:
ann_acc = kargs['ann_acc'] * 100
self.ax.plot(x, np.ones_like(x) * ann_acc, label='ANN', c='g', linestyle=':')
self.ax.text(0, ann_acc + 1, "%.3f%%" % (ann_acc), fontdict={'size': '8', 'color': 'g'})
except KeyError:
pass
try:
self.ax.set_title("%s\n[%.1f%% dataset]" % (
kargs['fig_name'],
100.0 * (self.global_shared['batch']+1) / self.global_shared['batch_sum']
))
except KeyError:
pass
try:
if kargs['step_max']:
y = self.global_shared['acc'] * 100
argmax = np.argmax(y)
disp_bias = 0.3 * float(T) if x[argmax] / T > 0.7 else 0
self.ax.text(x[argmax] - 0.8 - disp_bias, y[argmax] + 0.8, "MAX:%.3f%% T=%d" % (y[argmax], x[argmax]),
fontdict={'size': '12', 'color': 'r'})
self.ax.scatter([x[argmax]], [y[argmax]], c='r')
except KeyError:
pass

self.ax.set_xlabel("T")
self.ax.set_ylabel("Percentage(%)")
self.ax.legend()
plt.savefig(os.path.join(self.log_dir,'plot.pdf'))

try:
if kargs['online_drawer']:
if not isinstance(self.device, (list, set, tuple)):
plt.pause(0.001)
except KeyError:
pass

def _if_json_serializable(x):
try:
json.dumps(x)
return True
except:
return False
from spikingjelly.clock_driven.ann2snn.converter import Converter
from spikingjelly.clock_driven.ann2snn.utils import download_url

+ 107
- 0
spikingjelly/clock_driven/ann2snn/converter.py View File

@@ -0,0 +1,107 @@
from spikingjelly.clock_driven.ann2snn.modules import *
from tqdm import tqdm
from spikingjelly.clock_driven import neuron
import copy


class Converter(nn.Module):

def __init__(self, dataloader, mode='Max'):
"""
* :ref:`API in English <Converter.__init__-en>`

.. _Converter.__init__-cn:

:param dataloader: 数据加载器
:type dataloader: Dataloader
:param mode: 转换模式。目前支持三种模式,最大电流转换模式,99.9%电流转换模式,以及缩放转换模式
:type mode: str, float

``Converter`` 用于将ReLU的ANN转换为SNN。这里实现了常见的三种模式。
最常见的是最大电流转换模式,它利用前后层的激活上限,使发放率最高的情况能够对应激活取得最大值的情况。
99.9%电流转换模式利用99.9%的激活分位点限制了激活上限。
缩放转换模式下,用户需要给定缩放参数到模式中,即可利用缩放后的激活最大值对电流进行限制。

* :ref:`中文API <Converter.__init__-cn>`

.. _Converter.__init__-en:

:param dataloader: Dataloader for converting
:type dataloader: Dataloader
:param mode: Conversion mode. Now support three mode, MaxNorm, RobustNorm(99.9%), and scaling mode
:type mode: str, float

``Converter`` is used to convert ReLU's ANN to SNN. Three common methods are implemented here.
The most common is the maximum mode, which utilizes the upper activation limits of
the front and rear layers so that the case with the highest firing rate corresponds to the case where the
activation achieves the maximum value.
The 99.9% mode utilizes the 99.9% activation quantile to limit the upper activation limit.
In the scaling conversion mode, the user needs to specify the scaling parameters into the mode, and the current
can be limited by the activated maximum value after scaling.

"""
super().__init__()
self.mode = mode
self.dataloader = dataloader
self._check_mode()
self.device = None
def forward(self, relu_model):
relu_model = copy.deepcopy(relu_model)
if self.device is None:
self.device = next(relu_model.parameters()).device
relu_model.eval()
model = self.set_voltagehook(relu_model, mode=self.mode).to(self.device)
for _, (imgs, _) in enumerate(tqdm(self.dataloader)):
model(imgs.to(self.device))
model = self.replace_by_ifnode(model)
return model

def _check_mode(self):
err_msg = 'You have used a non-defined VoltageScale Method.'
if isinstance(self.mode, str):
if self.mode[-1] == '%':
try:
float(self.mode[:-1])
except ValueError:
raise NotImplementedError(err_msg)
elif self.mode.lower() in ['max']:
pass
else:
raise NotImplementedError(err_msg)
elif isinstance(self.mode, float):
try:
assert(self.mode <= 1 and self.mode > 0)
except AssertionError:
raise NotImplementedError(err_msg)
else:
raise NotImplementedError(err_msg)


@staticmethod
def set_voltagehook(model, mode='MaxNorm'):
for name, module in model._modules.items():
if hasattr(module, "_modules"):
model._modules[name] = Converter.set_voltagehook(module, mode=mode)
if module.__class__.__name__ == 'ReLU':
model._modules[name] = nn.Sequential(
nn.ReLU(),
VoltageHook(mode=mode)
)
return model

@staticmethod
def replace_by_ifnode(model):
for name,module in model._modules.items():
if hasattr(module, "_modules"):
model._modules[name] = Converter.replace_by_ifnode(module)
if module.__class__.__name__ == 'Sequential' and len(module) == 2 and \
module[0].__class__.__name__ == 'ReLU' and \
module[1].__class__.__name__ == 'VoltageHook':
s = module[1].scale.item()
model._modules[name] = nn.Sequential(
VoltageScaler(1.0 / s),
neuron.IFNode(v_threshold=1., v_reset=None),
VoltageScaler(s)
)
return model
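
For readers skimming this diff: the new Converter is a callable module that deep-copies a trained ReLU ANN, calibrates a per-layer scale s from a dataloader, and then replaces every ReLU by a VoltageScaler(1/s) -> IFNode -> VoltageScaler(s) sandwich. A minimal, self-contained sketch of how it is driven follows; the toy model, calibration data, input tensor and T are placeholders for illustration only, not part of this PR:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from spikingjelly.clock_driven import ann2snn

# toy ReLU ANN and toy calibration data (placeholders, not part of this PR)
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 10))
calib_set = TensorDataset(torch.rand(256, 1, 28, 28), torch.randint(0, 10, (256,)))
calib_loader = DataLoader(calib_set, batch_size=64)

# mode can be 'max', a percentile string such as '99.9%', or a float in (0, 1]
converter = ann2snn.Converter(mode='max', dataloader=calib_loader)
snn_model = converter(model)  # each ReLU is now VoltageScaler -> IFNode -> VoltageScaler

# reset the stateful neurons, then accumulate the SNN output over T time-steps
for m in snn_model.modules():
    if hasattr(m, 'reset'):
        m.reset()
T = 50
img = torch.rand(1, 1, 28, 28)
out = 0.
with torch.no_grad():
    for t in range(T):
        out = out + snn_model(img)
print(out.argmax(dim=1))
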

+ 0
- 175
spikingjelly/clock_driven/ann2snn/examples/cnn_fashionmnist.py View File

@@ -1,175 +0,0 @@
import torch
import torch.nn as nn
import torchvision
import os
from torch.utils.tensorboard import SummaryWriter
import spikingjelly.clock_driven.ann2snn.examples.utils as utils
from spikingjelly.clock_driven.ann2snn import parser, classify_simulator
import matplotlib.pyplot as plt

class ANN(nn.Module):
def __init__(self):
super().__init__()
# 网络结构:类似AlexNet的结构
# Network structure: AlexNet-like structure
self.network = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Dropout2d(0.2),
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Dropout2d(0.2),
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.Dropout2d(0.2),
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Flatten(),
nn.Linear(256 * 3 * 3, 1024),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(1024, 512),
nn.ReLU(),
nn.Linear(512, 10)
)

def forward(self,x):
x = self.network(x)
return x

def main(log_dir=None):
torch.random.manual_seed(0)
torch.cuda.manual_seed(0)

train_device = input('输入运行的设备,例如“cpu”或“cuda:0”\n input device, e.g., "cpu" or "cuda:0": ')
parser_device = input('输入分析模型的设备,例如“cpu”或“cuda:0”\n input parsing device, e.g., "cpu" or "cuda:0": ')
simulator_device = parser_device
# simulator_device = input(
# '输入SNN仿真的设备(支持多线程),例如“cpu,cuda:0”或“cuda:0,cuda:1”\n input SNN simulating device (support multithread), e.g., "cpu,cuda:0" or "cuda:0,cuda:1": ').split(
# ',')
dataset_dir = input('输入保存MNIST数据集的位置,例如“./”\n input root directory for saving FashionMNIST dataset, e.g., "./": ')
batch_size = int(input('输入batch_size,例如“128”\n input batch_size, e.g., "128": '))
learning_rate = float(input('输入学习率,例如“1e-3”\n input learning rate, e.g., "1e-3": '))
T = int(input('输入仿真时长,例如“400”\n input simulating steps, e.g., "400": '))
train_epoch = int(input('输入训练轮数,即遍历训练集的次数,例如“100”\n input training epochs, e.g., "100": '))
model_name = input('输入模型名字,例如“cnn_fashionmnist”\n input model name, for log_dir generating , e.g., "cnn_fashionmnist": ')

load = False
if log_dir == None:
from datetime import datetime
current_time = datetime.now().strftime('%b%d_%H-%M-%S')
log_dir = model_name + '-' + current_time
if not os.path.exists(log_dir):
os.makedirs(log_dir)
else:
if not os.path.exists(log_dir):
os.makedirs(log_dir)
if not os.path.exists(os.path.join(log_dir, model_name + '.pkl')):
print('%s has no model to load.' % (log_dir))
load = False
else:
load = True

if not load:
writer = SummaryWriter(log_dir)

# 初始化数据加载器
# initialize data loader
train_data_dataset = torchvision.datasets.FashionMNIST(
root=dataset_dir,
train=True,
transform=torchvision.transforms.ToTensor(),
download=True)
train_data_loader = torch.utils.data.DataLoader(
train_data_dataset,
batch_size=batch_size,
shuffle=True,
drop_last=True)
test_data_loader = torch.utils.data.DataLoader(
dataset=torchvision.datasets.FashionMNIST(
root=dataset_dir,
train=False,
transform=torchvision.transforms.ToTensor(),
download=True),
batch_size=100,
shuffle=True,
drop_last=False)

ann = ANN().to(train_device)
loss_function = nn.CrossEntropyLoss()
if not load:
optimizer = torch.optim.Adam(ann.parameters(), lr=learning_rate, weight_decay=5e-4)
best_acc = 0.0
for epoch in range(train_epoch):
# 使用utils中预先写好的训练程序训练网络
# 训练程序的写法和经典ANN中的训练也是一样的
# Train the network using a pre-prepared code in ''utils''
utils.train_ann(net=ann,
device=train_device,
data_loader=train_data_loader,
optimizer=optimizer,
loss_function=loss_function,
epoch=epoch
)
# 使用utils中预先写好的验证程序验证网络输出
# Validate the network using a pre-prepared code in ''utils''
acc = utils.val_ann(net=ann,
device=train_device,
data_loader=test_data_loader,
loss_function=loss_function,
epoch=epoch
)
if best_acc <= acc:
utils.save_model(ann, log_dir, model_name + '.pkl')
writer.add_scalar('val_accuracy', acc, epoch)
ann = torch.load(os.path.join(log_dir, model_name + '.pkl'))
print('validating best model...')
ann_acc = utils.val_ann(net=ann,
device=train_device,
data_loader=test_data_loader,
loss_function=loss_function
)

# 加载用于归一化模型的数据
# Load the data to normalize the model
percentage = 0.004 # load 0.004 of the data
norm_data_list = []
for idx, (imgs, targets) in enumerate(train_data_loader):
norm_data_list.append(imgs)
if idx == int(len(train_data_loader) * percentage) - 1:
break
norm_data = torch.cat(norm_data_list)
print('use %d imgs to parse' % (norm_data.size(0)))

onnxparser = parser(name=model_name,
log_dir=log_dir + '/parser',
kernel='onnx')
snn = onnxparser.parse(ann, norm_data.to(parser_device))

torch.save(snn, os.path.join(log_dir, 'snn-' + model_name + '.pkl'))
fig = plt.figure('simulator')
sim = classify_simulator(snn,
log_dir=log_dir + '/simulator',
device=simulator_device,
canvas=fig
)
sim.simulate(test_data_loader,
T=T,
online_drawer=True,
ann_acc=ann_acc,
fig_name=model_name,
step_max=True
)

if __name__ == '__main__':
main('./cnn_fashionmnist')

+ 132
- 198
spikingjelly/clock_driven/ann2snn/examples/cnn_mnist.py View File

@@ -1,219 +1,153 @@
import torch
import torch.nn as nn
import torchvision
import os
from torch.utils.tensorboard import SummaryWriter
import spikingjelly.clock_driven.ann2snn.examples.utils as utils
from spikingjelly.clock_driven.ann2snn import parser, classify_simulator
import torch.nn as nn
import spikingjelly
from spikingjelly.clock_driven import ann2snn
from tqdm import tqdm
from spikingjelly.clock_driven.ann2snn.sample_models import mnist_cnn
import numpy as np
import matplotlib.pyplot as plt

class ANN(nn.Module):
def __init__(self):
super().__init__()
# 网络结构:三层卷积块串联一个全连接层,每个卷积块由一个卷积层、一个批正则化、一个ReLU激活和一个平均池化层组成
# Network structure: Three convolution blocks connected with a full-connection layer, each convolution
# block consists of a convolution layer, a batch normalization, a ReLU activation and an average pool
# layer.
self.network = nn.Sequential(
nn.Conv2d(1, 32, 3, 1),
nn.BatchNorm2d(32, eps=1e-3),
nn.ReLU(),
nn.AvgPool2d(2, 2),

nn.Conv2d(32, 32, 3, 1),
nn.BatchNorm2d(32, eps=1e-3),
nn.ReLU(),
nn.AvgPool2d(2, 2),

nn.Conv2d(32, 32, 3, 1),
nn.BatchNorm2d(32, eps=1e-3),
nn.ReLU(),
nn.AvgPool2d(2, 2),

nn.Flatten(),
nn.Linear(32, 10),
nn.ReLU()
)

def forward(self,x):
x = self.network(x)
return x


def main(log_dir=None):
'''
:return: None

Train a network with a Conv-ReLU-[Conv-ReLU]-FC-ReLU structure, convert it to an SNN, and run MNIST classification. Example run:

.. code-block:: python

>>> import spikingjelly.clock_driven.ann2snn.examples.cnn_mnist as cnn_mnist
>>> cnn_mnist.main()
输入运行的设备,例如“cpu”或“cuda:0”
input device, e.g., "cpu" or "cuda:0": cuda:15
输入保存MNIST数据集的位置,例如“./”
input root directory for saving MNIST dataset, e.g., "./": ./mnist
输入batch_size,例如“64”
input batch_size, e.g., "64": 128
输入学习率,例如“1e-3”
input learning rate, e.g., "1e-3": 1e-3
输入仿真时长,例如“100”
input simulating steps, e.g., "100": 100
输入训练轮数,即遍历训练集的次数,例如“10”
input training epochs, e.g., "10": 10
输入模型名字,用于自动生成日志文档,例如“cnn_mnist”
input model name, for log_dir generating , e.g., "cnn_mnist"

Epoch 0 [1/937] ANN Training Loss:2.252 Accuracy:0.078
Epoch 0 [101/937] ANN Training Loss:1.423 Accuracy:0.669
Epoch 0 [201/937] ANN Training Loss:1.117 Accuracy:0.773
Epoch 0 [301/937] ANN Training Loss:0.953 Accuracy:0.795
Epoch 0 [401/937] ANN Training Loss:0.865 Accuracy:0.788
Epoch 0 [501/937] ANN Training Loss:0.807 Accuracy:0.792
Epoch 0 [601/937] ANN Training Loss:0.764 Accuracy:0.795
Epoch 0 [701/937] ANN Training Loss:0.726 Accuracy:0.835
Epoch 0 [801/937] ANN Training Loss:0.681 Accuracy:0.880
Epoch 0 [901/937] ANN Training Loss:0.641 Accuracy:0.889
100%|██████████| 100/100 [00:00<00:00, 116.12it/s]
Epoch 0 [100/100] ANN Validating Loss:0.327 Accuracy:0.881
Save model to: cnn_mnist-XXXXX\cnn_mnist.pkl
......
--------------------simulator summary--------------------
time elapsed: 46.55072790000008 (sec)
---------------------------------------------------------
'''
def val(net, device, data_loader, T=None):
net.eval().to(device)
correct = 0.0
total = 0.0
if T is not None:
corrects = np.zeros(T)
with torch.no_grad():
for batch, (img, label) in enumerate(tqdm(data_loader)):
img = img.to(device)
if T is None:
out = net(img)
correct += (out.argmax(dim=1) == label.to(device)).float().sum().item()
else:
for m in net.modules():
if hasattr(m, 'reset'):
m.reset()
for t in range(T):
if t == 0:
out = net(img)
else:
out += net(img)
corrects[t] += (out.argmax(dim=1) == label.to(device)).float().sum().item()
total += out.shape[0]
return correct / total if T is None else corrects / total

def main():
torch.random.manual_seed(0)
torch.cuda.manual_seed(0)

train_device = input('输入运行的设备,例如“cpu”或“cuda:0”\n input device, e.g., "cpu" or "cuda:0": ')
parser_device = input('输入分析模型的设备,例如“cpu”或“cuda:0”\n input parsing device, e.g., "cpu" or "cuda:0": ')
simulator_device = parser_device
# simulator_device = input(
# '输入SNN仿真的设备(支持多线程),例如“cpu,cuda:0”或“cuda:0,cuda:1”\n input SNN simulating device (support multithread), e.g., "cpu,cuda:0" or "cuda:0,cuda:1": ').split(
# ',')
dataset_dir = input('输入保存MNIST数据集的位置,例如“./”\n input root directory for saving MNIST dataset, e.g., "./": ')
batch_size = int(input('输入batch_size,例如“64”\n input batch_size, e.g., "64": '))
learning_rate = float(input('输入学习率,例如“1e-3”\n input learning rate, e.g., "1e-3": '))
T = int(input('输入仿真时长,例如“100”\n input simulating steps, e.g., "100": '))
train_epoch = int(input('输入训练轮数,即遍历训练集的次数,例如“10”\n input training epochs, e.g., "10": '))
model_name = input('输入模型名字,例如“cnn_mnist”\n input model name, for log_dir generating , e.g., "cnn_mnist": ')

load = False
if log_dir == None:
from datetime import datetime
current_time = datetime.now().strftime('%b%d_%H-%M-%S')
log_dir = model_name+'-'+current_time
if not os.path.exists(log_dir):
os.makedirs(log_dir)
else:
if not os.path.exists(log_dir):
os.makedirs(log_dir)
if not os.path.exists(os.path.join(log_dir,model_name+'.pkl')):
print('%s has no model to load.'%(log_dir))
load = False
else:
load = True

if not load:
writer = SummaryWriter(log_dir)

# 初始化数据加载器
# initialize data loader
device = 'cuda'
dataset_dir = 'G:/Dataset/mnist'
batch_size = 100
T = 50
# training hyper-parameters
lr = 1e-3
epochs = 10

model = mnist_cnn.CNN().to(device)
train_data_dataset = torchvision.datasets.MNIST(
root=dataset_dir,
train=True,
transform=torchvision.transforms.ToTensor(),
download=True)
train_data_loader = torch.utils.data.DataLoader(
train_data_dataset,
dataset=train_data_dataset,
batch_size=batch_size,
shuffle=True,
drop_last=True)
drop_last=False)
test_data_dataset = torchvision.datasets.MNIST(
root=dataset_dir,
train=False,
transform=torchvision.transforms.ToTensor(),
download=True)
test_data_loader = torch.utils.data.DataLoader(
dataset=torchvision.datasets.MNIST(
root=dataset_dir,
train=False,
transform=torchvision.transforms.ToTensor(),
download=True),
batch_size=100,
dataset=test_data_dataset,
batch_size=50,
shuffle=True,
drop_last=False)

ann = ANN().to(train_device)
loss_function = nn.CrossEntropyLoss()
if not load:
optimizer = torch.optim.Adam(ann.parameters(), lr=learning_rate, weight_decay=5e-4)
best_acc = 0.0
for epoch in range(train_epoch):
# 使用utils中预先写好的训练程序训练网络
# 训练程序的写法和经典ANN中的训练也是一样的
# Train the network using a pre-prepared code in ''utils''
utils.train_ann(net=ann,
device=train_device,
data_loader=train_data_loader,
optimizer=optimizer,
loss_function=loss_function,
epoch=epoch
)
# 使用utils中预先写好的验证程序验证网络输出
# Validate the network using a pre-prepared code in ''utils''
acc = utils.val_ann(net=ann,
device=train_device,
data_loader=test_data_loader,
loss_function=loss_function,
epoch=epoch
)
if best_acc <= acc:
utils.save_model(ann, log_dir, model_name + '.pkl')
writer.add_scalar('val_accuracy', acc, epoch)
ann = torch.load(os.path.join(log_dir, model_name + '.pkl'))
print('validating best model...')
ann_acc = utils.val_ann(net=ann,
device=train_device,
data_loader=test_data_loader,
loss_function=loss_function
)

# 加载用于归一化模型的数据
# Load the data to normalize the model
percentage = 0.004 # load 0.004 of the data
norm_data_list = []
for idx, (imgs, targets) in enumerate(train_data_loader):
norm_data_list.append(imgs)
if idx == int(len(train_data_loader) * percentage) - 1:
break
norm_data = torch.cat(norm_data_list)
print('use %d imgs to parse' % (norm_data.size(0)))

# 调用parser,使用kernel为onnx
# Call parser, use onnx kernel
onnxparser = parser(name=model_name,
log_dir=log_dir + '/parser',
kernel='onnx')
snn = onnxparser.parse(ann, norm_data.to(parser_device))

# 保存转换好的SNN模型
# Save SNN model
torch.save(snn, os.path.join(log_dir,'snn-'+model_name+'.pkl'))
fig = plt.figure('simulator')
# loss_function = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
# for epoch in range(epochs):
# model.train()
# for (img, label) in train_data_loader:
# optimizer.zero_grad()
# out = model(img.to(device))
# loss = loss_function(out, label.to(device))
# loss.backward()
# optimizer.step()
# torch.save(model.state_dict(), 'SJ-mnist-cnn_model-sample.pth')
# print('Epoch: %d' % epoch)
# acc = val(model, device, train_data_loader)
# print('Validating Accuracy: %.3f' % (acc))
# print()

model.load_state_dict(torch.load('SJ-mnist-cnn_model-sample.pth'))
acc = val(model, device, test_data_loader)
print('ANN Validating Accuracy: %.4f' % (acc))

print('---------------------------------------------')
print('Converting using MaxNorm')
model_converter = ann2snn.Converter(mode='max', dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_max_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_max_accs[-1]))

print('---------------------------------------------')
print('Converting using RobustNorm')
model_converter = ann2snn.Converter(mode='99.9%', dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_robust_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_robust_accs[-1]))

print('---------------------------------------------')
print('Converting using 1/2 max(activation) as scales...')
model_converter = ann2snn.Converter(mode=1.0 / 2, dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_two_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_two_accs[-1]))

print('---------------------------------------------')
print('Converting using 1/3 max(activation) as scales')
model_converter = ann2snn.Converter(mode=1.0 / 3, dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_three_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_three_accs[-1]))

print('---------------------------------------------')
print('Converting using 1/4 max(activation) as scales')
model_converter = ann2snn.Converter(mode=1.0 / 4, dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_four_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_four_accs[-1]))

print('---------------------------------------------')
print('Converting using 1/5 max(activation) as scales')
model_converter = ann2snn.Converter(mode=1.0 / 5, dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_five_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_five_accs[-1]))

fig = plt.figure()
plt.plot(np.arange(0, T), mode_max_accs, label='mode: max')
plt.plot(np.arange(0, T), mode_robust_accs, label='mode: 99.9%')
plt.plot(np.arange(0, T), mode_two_accs, label='mode: 1.0/2')
plt.plot(np.arange(0, T), mode_three_accs, label='mode: 1.0/3')
plt.plot(np.arange(0, T), mode_four_accs, label='mode: 1.0/4')
plt.plot(np.arange(0, T), mode_five_accs, label='mode: 1.0/5')
plt.legend()
plt.xlabel('t')
plt.ylabel('Acc')
plt.show()

# 定义用于分类的SNN仿真器
# define simulator for classification task
sim = classify_simulator(snn,
log_dir=log_dir + '/simulator',
device=simulator_device,
canvas=fig
)
# 仿真SNN
# Simulate SNN
sim.simulate(test_data_loader,
T=T,
online_drawer=True,
ann_acc=ann_acc,
fig_name=model_name,
step_max=True
)

if __name__ == '__main__':
main('./cnn_mnist')
print('Downloading SJ-mnist-cnn_model-sample.pth...')
ann2snn.download_url("https://ndownloader.figshare.com/files/34960191", './SJ-mnist-cnn_model-sample.pth')
main()
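
The six convert-and-evaluate blocks above differ only in the mode argument. For reference, they could equally be written as a single loop; the sketch below is illustrative only and reuses the val, model, train_data_loader, test_data_loader, device and T names defined earlier in this example file:

import numpy as np
import matplotlib.pyplot as plt
from spikingjelly.clock_driven import ann2snn

modes = ['max', '99.9%', 1.0 / 2, 1.0 / 3, 1.0 / 4, 1.0 / 5]
plt.figure()
for mode in modes:
    print('Converting using mode = %s' % str(mode))
    model_converter = ann2snn.Converter(mode=mode, dataloader=train_data_loader)
    snn_model = model_converter(model)
    print('Simulating...')
    accs = val(snn_model, device, test_data_loader, T=T)
    print('SNN accuracy (simulation %d time-steps): %.4f' % (T, accs[-1]))
    plt.plot(np.arange(0, T), accs, label='mode: %s' % str(mode))
plt.legend()
plt.xlabel('t')
plt.ylabel('Acc')
plt.show()
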

+ 0
- 0
spikingjelly/clock_driven/ann2snn/examples/model_sample/__init__.py View File


+ 0
- 0
spikingjelly/clock_driven/ann2snn/examples/model_sample/cifar10/__init__.py View File


+ 0
- 70
spikingjelly/clock_driven/ann2snn/examples/model_sample/cifar10/vgg.py View File

@@ -1,70 +0,0 @@
'''VGG11/13/16/19 in Pytorch.'''
import torch
import torch.nn as nn


cfg = {
'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGG(nn.Module):
def __init__(self, vgg_name):
super(VGG, self).__init__()
self.features = self._make_layers(cfg[vgg_name])
self.classifier = nn.Linear(512, 10)

def forward(self, x):
out = self.features(x)
out = out.view(out.size(0), -1)
out = self.classifier(out)
return out

def _make_layers(self, cfg):
layers = []
in_channels = 3
for x in cfg:
if x == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
nn.BatchNorm2d(x),
nn.ReLU(inplace=True)]
in_channels = x
layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
return nn.Sequential(*layers)


class VGG_no_bias_bn(nn.Module):
def __init__(self, vgg_name):
super(VGG_no_bias_bn, self).__init__()
self.features = self._make_layers(cfg[vgg_name])
self.classifier = nn.Linear(512, 10,bias=False)

def forward(self, x):
out = self.features(x)
out = out.view(out.size(0), -1)
out = self.classifier(out)
return out

def _make_layers(self, cfg):
layers = []
in_channels = 3
for x in cfg:
if x == 'M':
layers += [nn.AvgPool2d(kernel_size=2, stride=2)]
else:
layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1,bias=False),
nn.ReLU(inplace=True)]
in_channels = x
layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
return nn.Sequential(*layers)

def test():
net = VGG('VGG11')
x = torch.randn(2,3,32,32)
y = net(x)
print(y.size())

+ 0
- 0
spikingjelly/clock_driven/ann2snn/examples/model_sample/imagenet/__init__.py View File


+ 0
- 339
spikingjelly/clock_driven/ann2snn/examples/model_sample/imagenet/resnet.py View File

@@ -1,339 +0,0 @@
import torch
import torch.nn as nn


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
'wide_resnet50_2', 'wide_resnet101_2']



def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
expansion = 1

def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride

def forward(self, x):
identity = x

out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)

out = self.conv2(out)
out = self.bn2(out)

if self.downsample is not None:
identity = self.downsample(x)

out += identity
out = self.relu(out)

return out


class Bottleneck(nn.Module):
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

expansion = 4

def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride

def forward(self, x):
identity = x

out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)

out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)

out = self.conv3(out)
out = self.bn3(out)

if self.downsample is not None:
identity = self.downsample(x)

out += identity
out = self.relu(out)

return out


class ResNet(nn.Module):

def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer

self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.avgpool1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool2 = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)

for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)

# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)

def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)

layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))

return nn.Sequential(*layers)

def _forward_impl(self, x):
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.avgpool1(x)

x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)

x = self.avgpool2(x)
x = torch.flatten(x, 1)
x = self.fc(x)

return x

def forward(self, x):
return self._forward_impl(x)


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
model = ResNet(block, layers, **kwargs)
if pretrained:
raise NotImplementedError

return model


def resnet18(pretrained=False, progress=True, **kwargs):
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
**kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
**kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
**kwargs)


def resnet101(pretrained=False, progress=True, **kwargs):
r"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
**kwargs)


def resnet152(pretrained=False, progress=True, **kwargs):
r"""ResNet-152 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
**kwargs)


def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
r"""ResNeXt-50 32x4d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 4
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
r"""ResNeXt-101 32x8d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 8
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
r"""Wide ResNet-50-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
r"""Wide ResNet-101-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
pretrained, progress, **kwargs)

+ 49
- 74
spikingjelly/clock_driven/ann2snn/examples/resnet18_cifar10.py View File

@@ -1,47 +1,50 @@
import torch
import torch.nn as nn
import torchvision
import os
from torch.utils.tensorboard import SummaryWriter
import spikingjelly.clock_driven.ann2snn.examples.utils as utils
from spikingjelly.clock_driven.ann2snn import parser, classify_simulator
from spikingjelly.clock_driven.ann2snn.examples.model_sample.cifar10 import resnet
import matplotlib.pyplot as plt
from tqdm import tqdm
import spikingjelly.clock_driven.ann2snn as ann2snn
from spikingjelly.clock_driven.ann2snn.sample_models import cifar10_resnet

def main(log_dir=None):
torch.random.manual_seed(0)
torch.cuda.manual_seed(0)

train_device = input('输入运行的设备,例如“cpu”或“cuda:0”\n input training device, e.g., "cpu" or "cuda:0": ')
parser_device = input('输入分析模型的设备,例如“cpu”或“cuda:0”\n input parsing device, e.g., "cpu" or "cuda:0": ')
simulator_device = parser_device
# simulator_device = input('输入SNN仿真的设备(支持多线程),例如“cpu,cuda:0”或“cuda:0,cuda:1”\n input SNN simulating device (support multithread), e.g., "cpu,cuda:0" or "cuda:0,cuda:1": ').split(',')
dataset_dir = input('输入保存cifar10数据集的位置,例如“./”\n input root directory for saving cifar10 dataset, e.g., "./": ')
batch_size = int(input('输入batch_size,例如“128”\n input batch_size, e.g., "128": '))
T = int(input('输入仿真时长,例如“400”\n input simulating steps, e.g., "400": '))
model_name = input('输入模型名字,例如“resnet18_cifar10”\n input model name, for log_dir generating , e.g., "resnet18_cifar10": ')

z_norm_mean = (0.4914, 0.4822, 0.4465)
z_norm_std = (0.2023, 0.1994, 0.2010)
def val(net, device, data_loader, T=None):
net.eval().to(device)
correct = 0.0
total = 0.0
with torch.no_grad():
for batch, (img, label) in enumerate(tqdm(data_loader)):
img = img.to(device)
if T is None:
out = net(img)
else:
for m in net.modules():
if hasattr(m, 'reset'):
m.reset()
for t in range(T):
if t == 0:
out = net(img)
else:
out += net(img)
correct += (out.argmax(dim=1) == label.to(device)).float().sum().item()
total += out.shape[0]
acc = correct / total
print('Validating Accuracy: %.3f' % (acc))
return acc

load = False
if log_dir == None:
from datetime import datetime
current_time = datetime.now().strftime('%b%d_%H-%M-%S')
log_dir = model_name + '-' + current_time
if not os.path.exists(log_dir):
os.makedirs(log_dir)
else:
if not os.path.exists(log_dir):
os.makedirs(log_dir)

if not load:
writer = SummaryWriter(log_dir)
def main():
torch.random.manual_seed(0)
torch.cuda.manual_seed(0)
device = 'cuda:9'
dataset_dir = '~/dataset/cifar10'
batch_size = 100
T = 400

transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

model = cifar10_resnet.ResNet18()
model.load_state_dict(torch.load('SJ-cifar10-resnet18_model-sample.pth'))

train_data_dataset = torchvision.datasets.CIFAR10(
root=dataset_dir,
train=True,
@@ -63,44 +66,16 @@ def main(log_dir=None):
shuffle=True,
drop_last=False)

ann = resnet.ResNet18().to(train_device)
loss_function = nn.CrossEntropyLoss()
checkpoint_state_dict = torch.load('./SJ-cifar10-resnet18_model-sample.pth')
ann.load_state_dict(checkpoint_state_dict)

# 加载用于归一化模型的数据
# Load the data to normalize the model
percentage = 0.004 # load 0.004 of the data
norm_data_list = []
for idx, (imgs, targets) in enumerate(train_data_loader):
norm_data_list.append(imgs)
if idx == int(len(train_data_loader) * percentage) - 1:
break
norm_data = torch.cat(norm_data_list)
print('use %d imgs to parse' % (norm_data.size(0)))

onnxparser = parser(name=model_name,
log_dir=log_dir + '/parser',
kernel='onnx',
z_norm=(z_norm_mean, z_norm_std))

snn = onnxparser.parse(ann, norm_data.to(parser_device))
ann_acc = utils.val_ann(torch.load(onnxparser.ann_filename).to(train_device),train_device,test_data_loader,loss_function)
torch.save(snn, os.path.join(log_dir, 'snn-' + model_name + '.pkl'))
fig = plt.figure('simulator')
sim = classify_simulator(snn,
log_dir=log_dir + '/simulator',
device=simulator_device,
canvas=fig
)
sim.simulate(test_data_loader,
T=T,
online_drawer=True,
ann_acc=ann_acc,
fig_name=model_name,
step_max=True
)
print('ANN accuracy:')
val(model, device, test_data_loader)
print('Converting...')
model_converter = ann2snn.Converter(mode='Max', dataloader=train_data_loader)
snn_model = model_converter(model)
print('SNN accuracy:')
val(snn_model, device, test_data_loader, T=T)

if __name__ == '__main__':
utils.download_sample_pth("https://ndownloader.figshare.com/files/26676110",'./SJ-cifar10-resnet18_model-sample.pth')
main('./resnet18_cifar10')
print('Downloading SJ-cifar10-resnet18_model-sample.pth')
ann2snn.download_url("https://ndownloader.figshare.com/files/26676110",'./SJ-cifar10-resnet18_model-sample.pth')
main()


+ 0
- 170
spikingjelly/clock_driven/ann2snn/examples/utils.py View File

@@ -1,170 +0,0 @@
import torch
import os
import numpy as np
from tqdm import tqdm
import requests

def train_ann(net, device, data_loader, optimizer, loss_function, epoch=None):
'''
* :ref:`API in English <train_ann-en>`

.. _train_ann-cn:

:param net: 训练的模型
:param device: 运行的设备
:param data_loader: 训练集
:param optimizer: 神经网络优化器
:param loss_function: 损失函数
:param epoch: 当前训练期数
:return: ``None``

经典的神经网络训练程序预设,便于直接调用训练网络

* :ref:`中文API <train_ann-cn>`

.. _train_ann-en:

:param net: network to train
:param device: running device
:param data_loader: training data loader
:param optimizer: neural network optimizer
:param loss_function: neural network loss function
:param epoch: current training epoch
:return: ``None``

Preset classic neural network training program
'''
net.train()
losses = []
correct = 0.0
total = 0.0
for batch, (img, label) in enumerate(data_loader):
img = img.to(device)
optimizer.zero_grad()
out = net(img)
loss = loss_function(out, label.to(device))
loss.backward()
optimizer.step()
losses.append(loss.item())
correct += (out.max(dim=1)[1] == label.to(device)).float().sum().item()
total += out.shape[0]
if batch % 100 == 0:
acc = correct / total
print('Epoch %d [%d/%d] ANN Training Loss:%.3f Accuracy:%.3f' % (epoch,
batch + 1,
len(data_loader),
np.array(losses).mean(),
acc))
correct = 0.0
total = 0.0


def val_ann(net, device, data_loader, loss_function, epoch=None):
'''
* :ref:`API in English <val_ann-en>`

.. _val_ann-cn:

:param net: 待验证的模型
:param device: 运行的设备
:param data_loader: 测试集
:param epoch: 当前训练期数
:return: 验证准确率

经典的神经网络训练程序预设,便于直接调用训练网络

* :ref:`中文API <val_ann-cn>`

.. _val_ann-en:

:param net: network to test
:param device: running device
:param data_loader: testing data loader
:param epoch: current training epoch
:return: testing accuracy

Preset classic neural network training program
'''
net.eval()
correct = 0.0
total = 0.0
losses = []
with torch.no_grad():
for batch, (img, label) in enumerate(tqdm(data_loader)):
img = img.to(device)
out = net(img)
loss = loss_function(out, label.to(device))
correct += (out.argmax(dim=1) == label.to(device)).float().sum().item()
total += out.shape[0]
losses.append(loss.item())
acc = correct / total
if epoch == None:
print('ANN Validating Accuracy:%.3f' % (acc))
else:
print('Epoch %d [%d/%d] ANN Validating Loss:%.3f Accuracy:%.3f' % (epoch,
batch + 1,
len(data_loader),
np.array(losses).mean(),
acc))
return acc


def save_model(net, log_dir, file_name):
'''
* :ref:`API in English <save_model-en>`

.. _save_model-cn:

:param net: 要保存的模型
:param log_dir: 日志文件夹
:param file_name: 文件名
:return: ``None``

保存模型的参数,以两种形式保存,分别为Pytorch保存的完整模型(适用于网络模型中只用了Pytorch预设模块的)
以及模型参数(适用于网络模型中有自己定义的非参数模块无法保存完整模型)

* :ref:`中文API <save_model-cn>`

.. _save_model-en:

:param net: network model to save
:param log_dir: log file folder
:param file_name: file name
:return: ``None``

Save the model, which is saved in two forms, the full model saved by Pytorch (for the network model only possessing
the Pytorch preset module) and model parameters only (for network models that have their own defined nonparametric
modules. In that case, Pytorch cannot save the full model)
'''
if not os.path.exists(log_dir):
os.makedirs(log_dir)
torch.save(net, os.path.join(log_dir, file_name))
torch.save(net.state_dict(), os.path.join(log_dir, 'param_' + file_name))
print('Save model to:', os.path.join(log_dir, file_name))


def download_sample_pth(url, filename):
'''
* :ref:`API in English <download_sample_pth-en>`

.. _download_sample_pth-cn:

:param url: 链接
:param filename: 文件名
:return: ``None``

下载例子的模型文件

* :ref:`中文API <download_sample_pth-cn>`

.. _download_sample_pth-en:

:param url: links
:param filename: file name
:return: ``None``

Download model state dict for examples
'''
print('Downloading %s from %s, please wait...'%(filename,url))
r = requests.get(url, allow_redirects=True)
open(filename, 'wb').write(r.content)

+ 0
- 0
spikingjelly/clock_driven/ann2snn/kernels/__init__.py View File


+ 0
- 1215
spikingjelly/clock_driven/ann2snn/kernels/onnx.py View File

@@ -1,1215 +0,0 @@
import onnx
import onnx.helper as helper
import onnx.numpy_helper as numpy_helper
import collections
import numpy as np
import torch
import torch.nn as nn
import os
import tqdm
import onnxruntime as ort
from collections import defaultdict
import json


class Mul(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input1, input2):
return input1 * input2


class Add(nn.Module):
def __init__(self):
super().__init__()
def forward(self,input1,input2):
return input1 + input2


class Reshape(nn.Module):
def __init__(self):
super().__init__()

def forward(self, input1, input2):
return torch.reshape(input1,shape=list(input2))


class Concat(nn.Module):
def __init__(self, dim=[1]):
super().__init__()
self.dim = dim
if not isinstance(self.dim, list):
self.dim = [self.dim]
for i, d in enumerate(self.dim):
if not isinstance(d, int):
self.dim[i] = int(d)

def forward(self, *args):
args = list(args)
for i,a in enumerate(args):
args[i] = a.type_as(args[0])
return torch.cat(args,dim=self.dim[0])

class Shape(nn.Module):
def __init__(self):
super().__init__()

def forward(self, input):
return torch.IntTensor([input.size(i) for i in range(len(input.size()))])

class Gather(nn.Module):
def __init__(self,dim=1):
super().__init__()
self.dim= int(dim)

def forward(self, input1, input2):
return torch.gather(input1,dim=self.dim,index=input2.cpu())

class Unsqueeze(nn.Module):
def __init__(self, dim=[1]):
super().__init__()
self.dim = dim
if not isinstance(self.dim, list):
self.dim = [self.dim]
for i,d in enumerate(self.dim):
if not isinstance(d,int):
self.dim[i] = int(d)

def forward(self, input):
x = input
for i in self.dim:
x = torch.unsqueeze(x,dim=i)
return x

class TopologyAnalyser:
def __init__(self):
'''
* :ref:`API in English <TopologyAnalyser.__init__-en>`

.. _TopologyAnalyser.__init__-cn:

这个类通过onnx分析模型的拓扑结构,方便后续处理
此处还有更多更好的实现方法,欢迎开发者不断优化

* :ref:`API in English <TopologyAnalyser.__init__-cn>`

.. _TopologyAnalyser.__init__-en:

This class analyzes the topological structure of the model through onnx to facilitate subsequent processing
There are better implementation methods here, developers are welcome to continue to optimize
'''
self.data_nodes = []
self.module_output = collections.OrderedDict()
self.module_op = collections.OrderedDict()
self.module_idx = collections.OrderedDict()
self.param_idx = collections.OrderedDict()
self.edge = collections.OrderedDict()
self.reverse_edge = collections.OrderedDict() # for quickly finding predecessor nodes

def add_data_node(self, a):
if not a in self.data_nodes:
self.data_nodes.append(a)

def insert(self, a, b, info=None):
self.add_data_node(a)
self.add_data_node(b)
if a not in self.edge.keys():
self.edge[a] = [(b, info)]
else:
self.edge[a].append((b, info))
if b not in self.reverse_edge.keys():
self.reverse_edge[b] = [a]
else:
self.reverse_edge[b].append(a)

def findNext(self, id):
if isinstance(id, str):
if id in self.edge.keys():
return self.edge[id]
else:
return []
elif isinstance(id, list):
l = []
for i in id:
l += self.findNext(i)
return l

def findPre(self, id):
l = []
if isinstance(id, str):
for pre_id in self.reverse_edge[id]:
if pre_id in self.reverse_edge.keys():
for pre_pre_id in self.reverse_edge[pre_id]:
if pre_pre_id in self.edge.keys():
for item in self.edge[pre_pre_id]:
if item[0] == pre_id:
l.append(item)
elif isinstance(id, list):
for i in id:
l += self.findPre(i)
return l

def find_pre_module(self, module_name):
if module_name in self.module_output.keys():
ids = self.module_output[module_name]
return set(['%s:%s' % (k[1]['op'], k[1]['param_module_name']) for k in self.findPre(ids)])
else:
return set()

def find_next_module(self, module_name):
if module_name in self.module_output.keys():
ids = self.module_output[module_name]
return set(['%s:%s' % (k[1]['op'], k[1]['param_module_name']) for k in self.findNext(ids)])
else:
return set()

def update_module_idx(self, onnx_graph):
for idx, n in enumerate(onnx_graph.node):
trainable_input = n.input[1:]
op = n.op_type
k = set()
for i in trainable_input:
n = self._get_module_name_from_value_name(i)
if n is not None:
k.add(n)
if len(k) > 1:
# TODO: sanity check, raise error
pass
if len(k) == 1:
param_module_name = list(k)[0]
self.module_op[param_module_name] = op
self.module_idx[param_module_name] = idx

def analyse(self, onnx_graph): # the input onnx graph must already have all constants moved into the initializer
# first move such parameters into the initializer, so that only operations (no constants) remain below
for idx, constant in enumerate(onnx_graph.initializer):
self.param_idx[constant.name] = idx
for idx, n in enumerate(onnx_graph.node):
param_module_name = None
op = n.op_type
inputs = n.input
outputs = n.output
# print(inputs,outputs)
k = set()
trainable_input = inputs[1:]
for i in trainable_input:
n = self._get_module_name_from_value_name(i)
if n is not None:
k.add(n)
if len(k) > 1:
# TODO: sanity check, raise error
pass
if len(k) == 1:
param_module_name = list(k)[0]
self.module_op[param_module_name] = op
self.module_idx[param_module_name] = idx
if op is not None:
for o in outputs:
for i in inputs:
self.insert(i, o, {'op': op, 'param_module_name': param_module_name})
if param_module_name is not None:
if param_module_name not in self.module_output.keys():
self.module_output[param_module_name] = [o]
else:
self.module_output[param_module_name].append(o)
return self

@staticmethod
def _get_module_name_from_value_name(value_name):
module_name = None
if len(value_name.split('.')) > 1:
l = value_name.split('.')[:-1]
l = '.'.join(l)
module_name = l # [1:]
module_name.replace(' ', '')
return module_name

def pytorch2onnx_model(model: nn.Module, data, **kargs) -> onnx.ModelProto:
'''

* :ref:`API in English <pytorch2onnx_model-en>`

.. _pytorch2onnx_model-cn:

:param model: 待转换的PyTorch模型

:param data: 用于转换的数据(用来确定输入维度)

:param log_dir: 输出文件夹

转换PyTorch模型到onnx模型

* :ref:`API in English <pytorch2onnx_model-cn>`

.. _pytorch2onnx_model-en:

:param model: the PyTorch model to be converted

:param data: The data used for conversion (used to determine the input dimension)

:param log_dir: output folder

Convert PyTorch model to onnx model

'''
try:
log_dir = kargs['log_dir']
except KeyError:
print('pytorch2onnx_model need argument log_dir!')
dump_input_size = [data.size(i) for i in range(len(data.size()))]
dump_input_size[0] = 1
fname = os.path.join(log_dir,'onnxmodel')
try:
dynamic_axes = {'input': {0: 'batch_size'},
'output': {0: 'batch_size'}}
torch.onnx.export(model, torch.ones(dump_input_size), fname,
input_names=['input'],
output_names=['output'],
dynamic_axes=dynamic_axes)
except BaseException:
raise NotImplementedError("Models with multiple inputs are not supported yet!")
return onnx.load(fname)

def onnx2pytorch_model(model: onnx.ModelProto, _converter) -> nn.Module:
model = _pt_model(model, _converter)
model = model.reduce()
return model

def layer_reduction(model: onnx.ModelProto) -> onnx.ModelProto:
graph = model.graph
topo_analyser = TopologyAnalyser()
graph = move_constant_to_initializer(graph)
topo_analyser.analyse(graph)

absorb_bn(graph, topo_analyser)
remove_unreferenced_initializer(graph)
update_topology(graph)
print("Finish layer reduction!")
return model

def rate_normalization(model: onnx.ModelProto, data: torch.Tensor, **kargs) -> onnx.ModelProto:
'''

* :ref:`API in English <rate_normalization-en>`

.. _rate_normalization-cn:

:param model: ANN模型,类型为onnx.ModelProto

:param data: 用于转换的数据,类型为torch.Tensor

:param channelwise: 如果为``True``,则控制激活幅值的统计是channelwise的;否则,控制激活幅值的统计是layerwise的

:param robust: 如果为``True``,则控制激活幅值的统计是激活的99.9百分位;否则,控制激活幅值的统计是激活的最值

:param eps: epsilon;未设置值时默认1e-5

发放率归一化

* :ref:`API in English <rate_normalization-cn>`

.. _rate_normalization-en:

:param model: ANN model, the type is onnx.ModelProto

:param data: the data used for conversion, the type is torch.Tensor

:param channelwise: If ``True`` , the statistics that control the activation amplitude are channelwise; otherwise, the statistics that control the activation amplitude are layerwise

:param robust: If ``True``, the statistic of the control activation amplitude is the 99.9th percentile of activation; otherwise, the statistic of the activation amplitude is the maximum value of activation

:param eps: epsilon; if no value is set, the default is 1e-5

normalize the firing rate

'''

try:
channelwise = kargs['channelwise']
except KeyError:
channelwise = False
try:
robust_norm = kargs['robust']
except KeyError:
robust_norm = False
try:
eps = kargs['eps']
except KeyError:
eps = 1e-5
topo_analyser = update_topology(model.graph)
output_debug = {}
output_statistics = get_intermediate_output_statistics(model, data,
channelwise=channelwise) # if want debug, debug=output_debug
model = normalize_model(model, output_statistics, topo_analyser, robust_norm=robust_norm,
channelwise=channelwise, eps=eps)
return model

def save_model(model: onnx.ModelProto, f=None):
fb = model.SerializeToString()
if f is not None:
if hasattr(f, 'write'):
f.write(fb)
else:
with open(f, "wb") as f:
f.write(fb)
return fb

def move_constant_to_initializer(graph):
constant_idx = []
for idx, n in enumerate(graph.node):
op = n.op_type
if op == 'Constant':
constant_idx.append(idx)
if len(constant_idx):
for idx in reversed(constant_idx):
n = graph.node[idx]
graph.initializer.append(
numpy_helper.from_array(numpy_helper.to_array(n.attribute[0].t), n.output[0]))
graph.node.remove(n)
return graph

def print_onnx_model(graph):
print(onnx.helper.printable_graph(graph))

def absorb_bn(graph, topo_analyser):
print("\nAbsorbing BatchNorm Parameters...\n")
for mn in tqdm.tqdm(reversed(topo_analyser.module_output.keys())):
if topo_analyser.module_op[mn] == 'BatchNormalization':
pre_m = topo_analyser.find_pre_module(mn)
next_m = topo_analyser.find_next_module(mn)
bn_weight_idx = topo_analyser.param_idx[graph.node[topo_analyser.module_idx[mn]].input[1]]
bn_weight = np.array(numpy_helper.to_array(graph.initializer[bn_weight_idx]))
bn_bias_idx = topo_analyser.param_idx[graph.node[topo_analyser.module_idx[mn]].input[2]]
bn_bias = np.array(numpy_helper.to_array(graph.initializer[bn_bias_idx]))
bn_mean_idx = topo_analyser.param_idx[graph.node[topo_analyser.module_idx[mn]].input[3]]
bn_mean = np.array(numpy_helper.to_array(graph.initializer[bn_mean_idx]))
bn_var_idx = topo_analyser.param_idx[graph.node[topo_analyser.module_idx[mn]].input[4]]
bn_var = np.array(numpy_helper.to_array(graph.initializer[bn_var_idx]))
bn_eps = graph.node[topo_analyser.module_idx[mn]].attribute[0].f
bn_std = np.sqrt(bn_var + bn_eps)
if len(pre_m) == 1 and list(pre_m)[0].split(':')[0] in ['Conv', 'Gemm']:
pre_mn = list(pre_m)[0].split(':')[1]
weight_idx = topo_analyser.param_idx[graph.node[topo_analyser.module_idx[pre_mn]].input[1]]
weight = np.array(numpy_helper.to_array(graph.initializer[weight_idx]))
if len(graph.node[topo_analyser.module_idx[pre_mn]].input) == 2:
bias = None
else:
bias_idx = topo_analyser.param_idx[graph.node[topo_analyser.module_idx[pre_mn]].input[2]]
bias = np.array(numpy_helper.to_array(graph.initializer[bias_idx]))
wrsp_args = (-1, 1) if len(weight.shape) == 2 else (-1, 1, 1, 1)

weight_ = weight * bn_weight.reshape(*wrsp_args) / bn_std.reshape(*wrsp_args)
bias_ = ((bias if bias is not None else 0) - bn_mean.reshape(-1)) * bn_weight.reshape(
-1) / bn_std.reshape(-1) \
+ bn_bias.reshape(-1)
assert (list(pre_m)[0].split(':')[0] in ['Conv', 'Gemm'])
args = {}
for attr in graph.node[topo_analyser.module_idx[pre_mn]].attribute:
args[attr.name] = helper.get_attribute_value(attr)
new_node = onnx.helper.make_node(
list(pre_m)[0].split(':')[0],
inputs=[graph.node[topo_analyser.module_idx[pre_mn]].input[0], pre_mn + ".new.weight", pre_mn + ".new.bias"],
outputs=[graph.node[topo_analyser.module_idx[mn]].output[0]],
**args
)
graph.initializer.append(numpy_helper.from_array(weight_.astype(np.float32), pre_mn + ".new.weight"))
graph.initializer.append(numpy_helper.from_array(bias_.astype(np.float32), pre_mn + ".new.bias"))
graph.node.remove(graph.node[topo_analyser.module_idx[pre_mn]])
graph.node.insert(topo_analyser.module_idx[pre_mn], new_node)
graph.node.remove(graph.node[topo_analyser.module_idx[mn]])
else:
weight_ = bn_weight / bn_std
bias_ = bn_bias - bn_weight * bn_mean / bn_std
name = graph.initializer[bn_weight_idx].name
graph.initializer.remove(graph.initializer[bn_weight_idx])
graph.initializer.insert(bn_weight_idx, numpy_helper.from_array(weight_.astype(np.float32), name))
name = graph.initializer[bn_bias_idx].name
graph.initializer.remove(graph.initializer[bn_bias_idx])
graph.initializer.insert(bn_bias_idx, numpy_helper.from_array(bias_.astype(np.float32), name))
name = graph.initializer[bn_mean_idx].name
graph.initializer.remove(graph.initializer[bn_mean_idx])
graph.initializer.insert(bn_mean_idx,
numpy_helper.from_array(np.zeros_like(bn_mean).astype(np.float32), name))
name = graph.initializer[bn_var_idx].name
graph.initializer.remove(graph.initializer[bn_var_idx])
graph.initializer.insert(bn_var_idx,
numpy_helper.from_array(np.ones_like(bn_var).astype(np.float32), name))

def remove_unreferenced_initializer(graph):
in_graph = set()
in_initializer = set()
for node in graph.node:
in_graph.update(node.input)
in_graph.update(node.output)
for init in graph.initializer:
in_initializer.add(init.name)
not_in_graph = in_initializer - in_graph
l = len(graph.initializer)
for i in range(l - 1, -1, -1):
if graph.initializer[i].name in not_in_graph:
graph.initializer.remove(graph.initializer[i])

def update_topology(graph):
topo_analyser = TopologyAnalyser()
move_constant_to_initializer(graph)
topo_analyser.analyse(graph)
return topo_analyser

def find_node_by_output(output_name, graph):
flag = False
idx, node = None, None
for idx, node in enumerate(graph.node):
if output_name in node.output:
flag = True
break
if not flag:
idx, node = None, None
return idx, node

def scale_node_weight_bias(topo_analyser, graph, node_idx, scale):
initializer = graph.initializer
node = graph.node[node_idx]
if len(node.input) < 2:
return
weight_idx = topo_analyser.param_idx[node.input[1]]
bias_idx = topo_analyser.param_idx[node.input[2]] if len(node.input) >= 3 else None
weight = np.array(numpy_helper.to_array(initializer[weight_idx]))
bias = np.array(numpy_helper.to_array(initializer[bias_idx])) if bias_idx is not None else None

w_scale = scale.reshape([*scale.shape] + [1 for _ in range(len(weight.shape) - 1)]) \
if len(scale.shape) == 1 else scale
b_scale = scale

weight_ = weight * w_scale
name = initializer[weight_idx].name
initializer.remove(initializer[weight_idx])
initializer.insert(weight_idx, numpy_helper.from_array(weight_.astype(np.float32), name))
if bias is not None:
bias_ = bias * b_scale
name = initializer[bias_idx].name
initializer.remove(initializer[bias_idx])
initializer.insert(bias_idx, numpy_helper.from_array(bias_.astype(np.float32), name))

def get_onnx_output(model, numpy_tensor):
ort_session = ort.InferenceSession(model.SerializeToString())
outputs = ort_session.run(None, {'input': numpy_tensor})
return outputs

def get_intermediate_output_statistics(model, numpy_tensor, channelwise=False, debug=None):
graph = model.graph
output_needed_module = {}
output_needed_all_input = {}
for idx, node in enumerate(graph.node):
output = node.output
input = node.input
if 'input' in node.input:
for out in output:
output_needed_module[out] = set([idx])
output_needed_all_input[out] = set(input)
else:
s = set()
s_i = set()
for in_ in input:
s |= (output_needed_module[in_] if in_ in output_needed_module.keys() else set())
s_i |= (output_needed_all_input[in_] if in_ in output_needed_all_input.keys() else set())
for out in output:
output_needed_module[out] = s | set([idx])
output_needed_all_input[out] = s_i | set(input)

output_statistics = {}
if not channelwise:
statistic = {'shape': numpy_tensor.shape,
'min': np.min(numpy_tensor),
'max': np.max(numpy_tensor) if np.max(numpy_tensor) > 0 else np.abs(np.min(numpy_tensor)),
'99.9': np.percentile(numpy_tensor, 99.9)
}
else:
axis_args = (0, 2, 3) if len(numpy_tensor.shape) == 4 else (0)
statistic = {'shape': numpy_tensor.shape,
'min': np.min(numpy_tensor, axis=axis_args),
'max': np.max(numpy_tensor, axis=axis_args),
'99.9': np.percentile(numpy_tensor, 99.9, axis=axis_args)
}
output_statistics['input'] = statistic
print("\nGetting intermediate output statistics...\n")
for out in tqdm.tqdm(output_needed_module.keys()):
keep_nodes = [graph.node[i] for i in list(output_needed_module[out])]
keep_initializer = [init for init in graph.initializer
if init.name in list(output_needed_all_input[out])]
var_out = []
value_info = onnx.ValueInfoProto()
value_info.name = out
var_out.append(value_info)
new_graph = onnx.helper.make_graph(keep_nodes, graph.name, graph.input,
var_out, keep_initializer)
tmp_model = onnx.helper.make_model(new_graph)
tmp_model.ir_version = model.ir_version
tmp_model.producer_name = model.producer_name
tmp_model.producer_version = model.producer_version
tmp_model.domain = model.domain
tmp_model.model_version = model.model_version
tmp_model.doc_string = model.doc_string
if len(tmp_model.metadata_props) > 0:
values = {p.key: p.value for p in model.metadata_props}
onnx.helper.set_model_props(tmp_model, values)
# fix opset import
for oimp in model.opset_import:
op_set = tmp_model.opset_import.add()
op_set.domain = oimp.domain
op_set.version = oimp.version

ort_session = ort.InferenceSession(tmp_model.SerializeToString())
outputs = ort_session.run(None, {'input': numpy_tensor})
if debug is not None:
# print(out,outputs[0].reshape(1,-1)[0,10:20])
debug[out] = outputs[0]
if not channelwise:
statistic = {'shape': outputs[0].shape,
'min': np.min(outputs[0]),
'max': np.max(outputs[0]) if np.max(outputs[0]) > 0 else np.abs(np.min(outputs[0])),
'99.9': np.percentile(outputs[0], 99.9) if np.percentile(outputs[0], 99.9) > 0 else np.abs(np.min(outputs[0]))
}
else:
axis_args = (0, 2, 3) if len(outputs[0].shape) == 4 else (0)
statistic = {'shape': outputs[0].shape,
'min': np.min(outputs[0], axis=axis_args),
'max': np.max(outputs[0], axis=axis_args),
'99.9': np.percentile(outputs[0], 99.9, axis=axis_args)
}
# print(np.max(statistic['max']),np.max(outputs[0]))
output_statistics[out] = statistic
print("Finished getting intermediate output statistics!")
if debug is not None:
return output_statistics,debug
else:
return output_statistics

def normalize_model(model, output_statistics, topo_analyser, robust_norm=True, channelwise=False, eps=1e-5):
nodes = model.graph.node
graph = model.graph
initializer = model.graph.initializer
if robust_norm:
statistic_key = '99.9'
else:
statistic_key = 'max'
node_scaled_range = {}
seperate_scale = collections.OrderedDict()
print("\nNormalizing model...\n")
for node_idx, node in enumerate(tqdm.tqdm(nodes)):
output = node.output
input = node.input
op = node.op_type
if input[0] == 'input': # single input model
l = output_statistics[input[0]]['shape'][1]
node_scaled_range[input[0]] = np.ones(l) if channelwise else 1.0 \
* output_statistics[input[0]][
statistic_key]

if op in ['Conv', 'Gemm']:
weight_idx = topo_analyser.param_idx[input[1]]
bias_idx = topo_analyser.param_idx[input[2]] if len(input) == 3 else None
weight = np.array(numpy_helper.to_array(initializer[weight_idx]))
bias = np.array(numpy_helper.to_array(initializer[bias_idx])) if bias_idx is not None else None

l = output_statistics[output[0]]['shape'][1]
input_real_range = node_scaled_range[input[0]]
input_range = output_statistics[input[0]][statistic_key]
output_range = output_statistics[output[0]][statistic_key]
demand = np.ones(l) if channelwise else 1.0
w_scale = demand / (output_range + eps) * (
input_range / (input_real_range + eps)) if not channelwise else \
(demand / (output_range + eps)).reshape(-1, 1).dot(
(input_range / (input_real_range + eps)).reshape(1, -1))
w_scale = w_scale.reshape([*w_scale.shape, 1, 1]) if len(weight.shape) == 4 else w_scale
b_scale = 1 / (output_range + eps)
node_scaled_range[output[0]] = demand

weight_ = weight * w_scale

name = initializer[weight_idx].name
initializer.remove(initializer[weight_idx])
initializer.insert(weight_idx, numpy_helper.from_array(weight_.astype(np.float32), name))
if bias is not None:
bias_ = bias * b_scale
name = initializer[bias_idx].name
initializer.remove(initializer[bias_idx])
initializer.insert(bias_idx, numpy_helper.from_array(bias_.astype(np.float32), name))

elif op == 'BatchNormalization': # var=1 mean=0
weight_idx = topo_analyser.param_idx[input[1]]
bias_idx = topo_analyser.param_idx[input[2]]
weight = np.array(numpy_helper.to_array(initializer[weight_idx]))
bias = np.array(numpy_helper.to_array(initializer[bias_idx]))

# node_scaled_range[output[0]] = node_scaled_range[input[0]] * self.output_statistics[input[0]][statistic_key] / self.output_statistics[output[0]][statistic_key]
# lamda_last = self.output_statistics[input[0]][statistic_key]
# lamda = self.output_statistics[output[0]][statistic_key]
# weight_ = weight * node_scaled_range[output[0]]
# bias_ = bias / lamda

# print(output_statistics[output[0]])
input_real_range = node_scaled_range[input[0]]
input_range = output_statistics[input[0]][statistic_key]
output_range = output_statistics[output[0]][statistic_key]
demand = 1.0
w_scale = demand / (output_range + eps) * (input_range / (input_real_range + eps))
b_scale = 1 / (output_range + eps)
node_scaled_range[output[0]] = demand
weight_ = weight * w_scale
bias_ = bias * b_scale
# print(output[0],op,input[0], input_range, output_range, demand, input_real_range, w_scale)

name = initializer[weight_idx].name
initializer.remove(initializer[weight_idx])
initializer.insert(weight_idx, numpy_helper.from_array(weight_.astype(np.float32), name))
name = initializer[bias_idx].name
initializer.remove(initializer[bias_idx])
initializer.insert(bias_idx, numpy_helper.from_array(bias_.astype(np.float32), name))

elif op == 'Add':
l = output_statistics[output[0]]['shape'][1]
demand = np.ones(l) if channelwise else 1.0
node_scaled_range[output[0]] = demand
output_range = output_statistics[output[0]][statistic_key]

# node_scaled_range[output[0]] = 1.0
# lamda = self.output_statistics[output[0]][statistic_key]
# lamda_lasts = {}
for i in input:
if i in output_statistics.keys():
# lamda_lasts[i] = self.output_statistics[i][statistic_key]
# scale = lamda_lasts[i] / lamda
input_real_range = node_scaled_range[i]
input_range = output_statistics[i][statistic_key]
scale = demand / (output_range + eps) * (input_range / (input_real_range + eps))

# print(output[0], op, i, input_range, output_range, demand, input_real_range, scale)

idx, _ = find_node_by_output(i, graph)
if idx is not None and nodes[idx].op_type in ['Conv', 'Gemm', 'BatchNormalization']:
scale_node_weight_bias(topo_analyser, graph, idx, scale)
else:
scale = scale.reshape(
[1, *scale.shape] + [1 for _ in range(len(output_statistics[i]['shape']) - 2)]) \
if len(scale.shape) == 1 else scale
initializer.append(numpy_helper.from_array(scale.astype(np.float32), "scale_" + i))
if idx not in seperate_scale.keys():
# seperate_scale[node_idx] = [(i,"scale_"+i,"scaled_"+i)]
seperate_scale[node_idx] = {i: ("scale_" + i, "scaled_" + i)}
else:
# seperate_scale[node_idx].append((i,"scale_"+i,"scaled_"+i))
seperate_scale[node_idx][i] = ("scale_" + i, "scaled_" + i)
pass
elif op in ['Gather', 'Unsqueeze', 'Shape', 'Concat']:
continue
# elif op == "Concat":
# raise NotImplementedError("Not supported %s yet!"%(op))
elif op == "Softmax":
raise NotImplementedError("Not supported %s yet!" % (op))
else: # single input single output module
# print(op,self.output_statistics[output[0]]['shape'])
input_range = output_statistics[input[0]][statistic_key]
output_range = output_statistics[output[0]][statistic_key]
input_scaled_range = node_scaled_range[input[0]]
output_scaled_range = input_scaled_range / (input_range + eps) * output_range
node_scaled_range[output[0]] = output_scaled_range

# print(output[0], op, input[0], input_range, output_range, output_scaled_range)
# print(op, node_scaled_range[output[0]],'=',input_scaled_range,'/',input_range,'*',output_range)
# else:
# raise NotImplementedError("Not supported yet! %s"%(op))

if len(seperate_scale.keys()) != 0:
print("Making new scale node...")

for node_idx in reversed(seperate_scale.keys()):
args = {}
for attr in nodes[node_idx].attribute:
args[attr.name] = helper.get_attribute_value(attr)
input = [str(i) if i not in seperate_scale[node_idx].keys() else seperate_scale[node_idx][i][1] \
for i in nodes[node_idx].input]

output = [str(i) for i in nodes[node_idx].output]

new_node = onnx.helper.make_node(
nodes[node_idx].op_type,
inputs=input,
outputs=output,
**args
)
nodes.remove(nodes[node_idx])
nodes.insert(node_idx, new_node)

for i in seperate_scale[node_idx].keys():
new_node = onnx.helper.make_node(
'Mul',
inputs=[seperate_scale[node_idx][i][0], i],
outputs=[seperate_scale[node_idx][i][1]]
)
nodes.insert(node_idx, new_node)
print("Finished normalizing model!")
return model
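
A rough sketch of how the two helpers above fit together; the ONNX path, the calibration batch, and the `topo_analyser` object (which must expose `param_idx`, as used above but defined elsewhere in this file) are assumptions for illustration only.

.. code-block:: python

    import onnx
    import numpy as np

    onnx_model = onnx.load('ann.onnx')  # hypothetical model path
    calib = np.random.rand(32, 1, 28, 28).astype(np.float32)  # hypothetical calibration batch
    # collect per-layer activation statistics, then rescale weights/biases in place
    stats = get_intermediate_output_statistics(onnx_model, calib, channelwise=False)
    snn_ready = normalize_model(onnx_model, stats, topo_analyser, robust_norm=True)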

def _pre_onnx_shape_inference(model:onnx.ModelProto):
'''
为了对模型进行shape inference,需要先对onnxmodel运行此函数进行准备

To perform shape inference on the model, this function must be run on the ONNX model first to prepare it

This function references code from https://github.com/onnx/onnx/issues/2660#issuecomment-605874784
'''
if model.ir_version < 4:
return

def add_some_graph_info(graph:onnx.GraphProto):
inputs = {i.name for i in graph.input}
vi_dict = {vi.name: vi for vi in graph.value_info}
for init in graph.initializer:
if init.name in inputs:
continue
vi = vi_dict.get(init.name)
if vi is None:
vi = graph.value_info.add()
vi.name = init.name

tensor_type = vi.type.tensor_type
if tensor_type.elem_type == onnx.TensorProto.UNDEFINED:
tensor_type.elem_type = init.data_type
if not tensor_type.HasField("shape"):
tensor_type.shape.dim.extend([])
for dim in init.dims:
tensor_type.shape.dim.add().dim_value = dim

for node in graph.node:
for attr in node.attribute:
if attr.ref_attr_name != "":
continue

if attr.type == onnx.AttributeProto.GRAPH:
add_some_graph_info(attr.g)
if attr.type == onnx.AttributeProto.GRAPHS:
for g in attr.graphs:
add_some_graph_info(g)
return add_some_graph_info(model.graph)
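
A minimal sketch of the intended call order (mirroring how `_pt_model.__init__` uses it below): the function is run once before ONNX shape inference. The model path is hypothetical.

.. code-block:: python

    import onnx

    model = onnx.load('model.onnx')   # hypothetical path
    _pre_onnx_shape_inference(model)  # add value_info for initializers first
    inferred = onnx.shape_inference.infer_shapes(model)
    print(len(inferred.graph.value_info))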

class _pt_model(nn.Module):
def __init__(self, path_or_model, _converter=None):
super(_pt_model, self).__init__()
if path_or_model is not None:
if isinstance(path_or_model, str):
onnx_model = onnx.load(path_or_model)
else:
onnx_model = path_or_model
self.onnx_model = onnx_model

self.loaded_weights = load_parameters(self, onnx_model.graph.initializer)
self.module_list = nn.ModuleList([])
self.op_tree = {}

_pre_onnx_shape_inference(onnx_model)
inferred_model = onnx.shape_inference.infer_shapes(onnx_model)
self.value_info = inferred_model.graph.value_info
self.dim_info = {}
for idx, v in enumerate(self.value_info):
self.dim_info[v.name] = len(v.type.tensor_type.shape.dim)

self.graph = defaultdict(list)
# self.V = set()
for idx, node in enumerate(onnx_model.graph.node):
op = node.op_type
# if op=='MatMul': # TODO temporary
# op = 'Gemm'
(op_idx, inputs, outputs) = getattr(_converter, 'convert_' + op.lower())(node, self)
for out_seq, output in enumerate(outputs):
self.op_tree[str(output)] = (int(op_idx), [str(i) for i in inputs], out_seq)
for output in outputs:
for input in inputs:
self.graph[input].append(output)
# self.V.update(inputs)
# self.V.update(outputs)
# print(self.V)
self.op_tree = json.dumps(self.op_tree)

self.input_name = [i.name for i in onnx_model.graph.input]
self.output_name = [i.name for i in onnx_model.graph.output]

for out in onnx_model.graph.output:
self.graph[out.name] = []

def TopologicalSort(G):
in_degrees = dict((u, 0) for u in G)
for u in G:
for v in G[u]:
in_degrees[v] += 1
Q = [u for u in G if in_degrees[u] == 0]
res = []
while Q:
u = Q.pop()
res.append(u)
for v in G[u]:
in_degrees[v] -= 1
if in_degrees[v] == 0:
Q.append(v)
return res

self.compute_seq = TopologicalSort(self.graph)
# print(self.compute_seq)
# self.tensors = {}

for k in self.loaded_weights.keys():
if isinstance(self.loaded_weights[k], torch.FloatTensor):
setattr(self, 'P' + k.replace('.', '@'), torch.nn.Parameter(self.loaded_weights[k]))
else:
# print(self.loaded_weights[k])
self.register_buffer('P' + k.replace('.', '@'), self.loaded_weights[k])

self.reserved_tensors_name = list(self.loaded_weights.keys())
# print('reserve', self.reserved_tensors_name)
# print(self.tensors)

def refresh_running_tensor(self):
self.tensors = {}
for k in set(self.tensors.keys()) | set(self.reserved_tensors_name):
if k not in self.reserved_tensors_name:
del self.tensors[k]
else:
self.tensors[k] = getattr(self, 'P' + k.replace('.', '@'))


def forward(self, input):
op_tree = json.loads(self.op_tree)

tensors = {}
for k in set(tensors.keys()) | set(self.reserved_tensors_name):
if k not in self.reserved_tensors_name:
del tensors[k]
else:
tensors[k] = getattr(self, 'P' + k.replace('.', '@'))
# self.refresh_running_tensor()
if not isinstance(input, (list, tuple)):
input = [input]
for i, n in enumerate(self.input_name):
tensors[n] = input[i]
for name in self.compute_seq:
if name in op_tree.keys():
op_idx, inputs, out_seq = op_tree[name]
# print(name,op_idx, inputs,out_seq)
args = []
for input in inputs:
args.append(tensors[input])
# print(len(args))
# print(type(args[0]))
result = self.module_list[op_idx](*args)
if not isinstance(result, tuple):
tensors[name] = result
# print(' %s = self.module_list[%d] (%s)'%(name,op_idx,inputs))
else:
tensors[name] = result[out_seq]
# print(' %s = self.module_list[%d] (%s)[%d]'%(name,op_idx,inputs,out_seq) )

if len(self.output_name) == 1:
return tensors[self.output_name[0]]
else:
ret = []
for output in self.output_name:
ret.append(tensors[output])
return ret

def reduce(self):
import copy
net = _pt_model(None)
for k in self.reserved_tensors_name:
if isinstance(self.loaded_weights[k], torch.FloatTensor):
setattr(net, 'P' + k.replace('.', '@'),
torch.nn.Parameter(getattr(self,'P' + k.replace('.', '@')).data.detach().clone()) )
else:
net.register_buffer('P' + k.replace('.', '@'),
getattr(self,'P' + k.replace('.', '@')).data.clone() )

net.compute_seq = copy.deepcopy(self.compute_seq)

net.input_name = copy.deepcopy(self.input_name)
net.output_name = copy.deepcopy(self.output_name)
net.module_list = copy.deepcopy(self.module_list)
net.op_tree = copy.deepcopy(self.op_tree)
net.reserved_tensors_name = copy.deepcopy(self.reserved_tensors_name)
return net
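
A minimal usage sketch, assuming the ONNX file only contains operators covered by the `convert_*` methods of `_o2p_converter` defined below; the path and input shape are hypothetical.

.. code-block:: python

    import torch

    net = _pt_model('ann.onnx', _converter=_o2p_converter())  # hypothetical ONNX file
    x = torch.rand(1, 1, 28, 28)
    y = net(x)  # runs the rebuilt PyTorch graph in topological order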

def load_parameters(model:_pt_model, initializer):
param_dict = {}
for init in initializer:
param_dict[init.name] = torch.from_numpy(numpy_helper.to_array(init).copy())
return param_dict

class _o2p_converter:
def __init__(self):
'''
* :ref:`API in English <ONNX_Converter.__init__-en>`

.. _ONNX_Converter.__init__-cn:

该类主要将onnx模型转换为Pytorch的ANN模型,从而转换为SpikingJelly的SNN模型
链接中 [#f1]_ 提供了一个onnx-pytorch转换的主要版本。更复杂的版本可以在这里找到。
大多数使用过的onnx运算符已在此处定义,但仍然有一些未被覆盖,或没有被完美实现
用户可以通过添加如下面例子所示的静态方法来定义您的例外情况

* :ref:`中文API <ONNX_Converter.__init__-cn>`

.. _ONNX_Converter.__init__-en:

This class mainly converts an ONNX model to a PyTorch ANN model, which can then be converted to a SpikingJelly SNN model.
The link [#f1]_ provides a primary version of ONNX-PyTorch conversion. A more complete version can be found here.
Most commonly used ONNX operators are covered here, yet some are still missing or not perfectly implemented.
Users can handle such exceptions by adding static methods like the ones below.

.. [#f1] https://gist.github.com/qinjian623/6aa777037534c1c1dccbb66f832e93b8
'''
pass

def add_method(self, op_name, func):
setattr(self, 'convert_'+op_name, staticmethod(func))
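
A hypothetical example of registering a converter for an operator that is not covered below (here `Sigmoid`), following the same `(module index, inputs, outputs)` contract as the built-in `convert_*` methods:

.. code-block:: python

    def convert_sigmoid(node, model):
        # hypothetical user-defined converter, not part of this file
        model.module_list.append(nn.Sigmoid())
        return len(model.module_list) - 1, node.input, node.output

    converter = _o2p_converter()
    converter.add_method('sigmoid', convert_sigmoid)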

@staticmethod
def convert_conv(node, model:_pt_model):
attr_map = {
"pads": "padding",
"strides": "stride",
"kernel_shape": "kernel_size",
"group": "groups",
"dilations": "dilation"
}
assert len(node.output) == 1
with_bias = False
if len(node.input) == 3:
with_bias = True
bias = model.loaded_weights[node.input[2]]
del model.loaded_weights[node.input[2]]
weight = model.loaded_weights[node.input[1]]
del model.loaded_weights[node.input[1]]
in_channels = weight.shape[1]
out_channels = weight.shape[0]
kwargs = {}
for att in node.attribute:
kwargs[attr_map[att.name]] = list(att.ints) if att.name != 'group' else att.i
if 'padding' in kwargs:
assert(kwargs["padding"][0]==kwargs["padding"][2] and kwargs["padding"][1]==kwargs["padding"][3])
kwargs["padding"] = kwargs["padding"][0],kwargs["padding"][1]
groups = 1 if 'groups' not in kwargs else kwargs['groups']
in_channels *= groups
conv = nn.Conv2d(in_channels, out_channels, **kwargs, bias=with_bias)
conv.weight.data = weight
if with_bias:
conv.bias.data = bias
model.module_list.append(conv)
return len(model.module_list)-1, node.input[:1], node.output

@staticmethod
def convert_relu(node, model:_pt_model):
relu = nn.ReLU()
model.module_list.append(relu)
return len(model.module_list)-1, node.input, node.output

@staticmethod
def convert_prelu(node, model:_pt_model):
weight = model.loaded_weights[node.input[1]]
del model.loaded_weights[node.input[1]]
prelu = nn.PReLU()
prelu.weight.data = weight
model.module_list.append(prelu)
return len(model.module_list) - 1, node.input[:-1], node.output

@staticmethod
def convert_shape(node, model:_pt_model):
shape = Shape()
model.module_list.append(shape)
return len(model.module_list) - 1, node.input, node.output

@staticmethod
def convert_gather(node, model:_pt_model):
attr_map = {
"axis": "dim"
}
kwargs = {}
for att in node.attribute:
if att.name in attr_map:
kwargs[attr_map[att.name]] = att.f
gather = Gather(**kwargs)
model.module_list.append(gather)
return len(model.module_list) - 1, node.input, node.output

@staticmethod
def convert_unsqueeze(node, model:_pt_model):
attr_map = {
"axes": "dim"
}
kwargs = {}
for att in node.attribute:
if att.name in attr_map:
kwargs[attr_map[att.name]] = att.f
unsqueeze = Unsqueeze(**kwargs)
model.module_list.append(unsqueeze)
return len(model.module_list) - 1, node.input, node.output

@staticmethod
def convert_concat(node, model:_pt_model):
attr_map = {
"axis": "dim"
}
kwargs = {}
for att in node.attribute:
if att.name in attr_map:
kwargs[attr_map[att.name]] = att.f

concat = Concat(**kwargs)
model.module_list.append(concat)
return len(model.module_list) - 1, node.input, node.output

@staticmethod
def convert_reshape(node, model:_pt_model):
reshape = Reshape()
model.module_list.append(reshape)
return len(model.module_list) - 1, node.input, node.output

@staticmethod
def convert_matmul(node, model:_pt_model):
class MatMul(nn.Module):
def __init__(self):
super().__init__()
def forward(self,input1,input2):
return input1 @ input2
mul = MatMul()
model.module_list.append(mul)
return len(model.module_list)-1, node.input, node.output


@staticmethod
def convert_batchnormalization(node, model:_pt_model):
attr_map = {
"epsilon": "eps",
"momentum": "momentum"
}
assert len(node.input) == 5
assert len(node.output) == 1
weight = model.loaded_weights[node.input[1]]
bias = model.loaded_weights[node.input[2]]
running_mean = model.loaded_weights[node.input[3]]
running_var = model.loaded_weights[node.input[4]]
del model.loaded_weights[node.input[1]]
del model.loaded_weights[node.input[2]]
del model.loaded_weights[node.input[3]]
del model.loaded_weights[node.input[4]]
dim = weight.shape[0]
kwargs = {}
# _check_attr(node.attribute, rebuild_batchnormalization.bn_attr_map)
for att in node.attribute:
if att.name in attr_map:
kwargs[attr_map[att.name]] = att.f
bn = None
if model.dim_info[node.output[0]] == 5:
bn = nn.BatchNorm3d(num_features=dim)
elif model.dim_info[node.output[0]] == 4:
bn = nn.BatchNorm2d(num_features=dim)
elif model.dim_info[node.output[0]] == 2 or model.dim_info[node.output[0]] == 3:
bn = nn.BatchNorm1d(num_features=dim)
bn.weight.data = weight
bn.bias.data = bias
bn.running_mean.data = running_mean
bn.running_var.data = running_var
model.module_list.append(bn)
return len(model.module_list)-1, node.input[:1], node.output

@staticmethod
def convert_add(node, model:_pt_model):
add = Add()
model.module_list.append(add)
return len(model.module_list)-1, node.input, node.output

@staticmethod
def convert_mul(node, model:_pt_model):
mul = Mul()
model.module_list.append(mul)
return len(model.module_list)-1, node.input, node.output

@staticmethod
def convert_averagepool(node, model:_pt_model):
attr_map = {
"pads": "padding",
"strides": "stride",
"kernel_shape": "kernel_size",
}
kwargs = {}
for att in node.attribute:
kwargs[attr_map[att.name]] = list(att.ints)
if 'padding' in kwargs:
assert (kwargs["padding"][0] == kwargs["padding"][2] and kwargs["padding"][1] == kwargs["padding"][3])
kwargs["padding"] = kwargs["padding"][0], kwargs["padding"][1]
ap = nn.AvgPool2d(**kwargs)
model.module_list.append(ap)
return len(model.module_list)-1, node.input, node.output

@staticmethod
def convert_globalaveragepool(node, model:_pt_model):
gap = nn.AdaptiveAvgPool2d((1, 1))
model.module_list.append(gap)
return len(model.module_list) - 1, node.input, node.output

@staticmethod
def convert_maxpool(node, model:_pt_model):
attr_map = {
"pads": "padding",
"strides": "stride",
"kernel_shape": "kernel_size",
}
kwargs = {}
for att in node.attribute:
kwargs[attr_map[att.name]] = list(att.ints)
if 'padding' in kwargs:
assert (kwargs["padding"][0] == kwargs["padding"][2] and kwargs["padding"][1] == kwargs["padding"][3])
kwargs["padding"] = kwargs["padding"][0], kwargs["padding"][1]
ap = nn.MaxPool2d(**kwargs)
model.module_list.append(ap)
return len(model.module_list) - 1, node.input, node.output

@staticmethod
def convert_flatten(node, model:_pt_model):
if len(node.attribute) == 0:
axis = 1
else:
axis = node.attribute[0].i
if axis==1:
flatten = nn.Flatten()
model.module_list.append(flatten)
return len(model.module_list)-1, node.input, node.output
else:
raise NotImplementedError("Not Implemented yet!")

@staticmethod
def convert_gemm(node, model:_pt_model):
weight = model.loaded_weights[node.input[1]]
bias = model.loaded_weights[node.input[2]]
del model.loaded_weights[node.input[2]]
del model.loaded_weights[node.input[1]]
in_features = weight.shape[1]
out_features = weight.shape[0]
linear = nn.Linear(in_features=in_features, out_features=out_features)
linear.weight.data = weight
linear.bias.data = bias
model.module_list.append(linear)
return len(model.module_list)-1, node.input[:1], node.output

@staticmethod
def convert_pad(node, model:_pt_model):
mode = node.attribute[0].s
pads = list(node.attribute[1].ints)
value = node.attribute[2].f
try:
assert(mode == b'constant')
assert(sum(pads[:4]) == 0)
except AssertionError:
print("Now only support converting to nn.ConstantPad2d")
pad = nn.ConstantPad2d([*pads[2:4],*pads[3:5]],value)
model.module_list.append(pad)
return len(model.module_list)-1, node.input, node.output

+ 0
- 127
spikingjelly/clock_driven/ann2snn/kernels/pytorch.py View File

@@ -1,127 +0,0 @@
import numpy as np
import torch
import torch.nn as nn
import copy
from collections import defaultdict

def layer_reduction(model: nn.Module) -> nn.Module:
relu_linker = {} # dict: maps the index of a ReLU layer in the network to the index of the parametric module right before it
param_module_relu_linker = {} # dict: maps the index of the parametric module before a ReLU to the index of that ReLU layer
activation_range = defaultdict(float) # dict: stores the maximum activation (or a given percentile) of the layer at each index in the network

module_len = 0
module_list = nn.ModuleList([])
last_parammodule_idx = 0
for n, m in model.named_modules():
Name = m.__class__.__name__
# load activation layers
if isinstance(m,nn.Softmax):
Name = 'ReLU'
print(UserWarning("Replacing Softmax by ReLU."))
if isinstance(m,nn.ReLU) or Name == "ReLU":
module_list.append(m)
relu_linker[module_len] = last_parammodule_idx
param_module_relu_linker[last_parammodule_idx] = module_len
module_len += 1
activation_range[module_len] = -1e5
# load BatchNorm layers
if isinstance(m,(nn.BatchNorm2d,nn.BatchNorm1d)):
if isinstance(module_list[last_parammodule_idx], (nn.Conv2d,nn.Linear)):
absorb(module_list[last_parammodule_idx], m)
else:
module_list.append(copy.deepcopy(m))
# load parametric layers
if isinstance(m,(nn.Conv2d,nn.Linear)):
module_list.append(m)
last_parammodule_idx = module_len
module_len += 1
# load parameter-free layers
if isinstance(m,nn.MaxPool2d):
module_list.append(m)
module_len += 1
if isinstance(m,nn.AvgPool2d):
module_list.append(nn.AvgPool2d(kernel_size=m.kernel_size, stride=m.stride, padding=m.padding))
module_len += 1
# if isinstance(m,nn.Flatten):
if m.__class__.__name__ == "Flatten":
module_list.append(m)
module_len += 1
network = torch.nn.Sequential(*module_list)
setattr(network,'param_module_relu_linker',param_module_relu_linker)
setattr(network, 'activation_range', activation_range)
return network

def rate_normalization(model: nn.Module, data: torch.Tensor, **kargs) -> nn.Module:
if not hasattr(model,"activation_range") or not hasattr(model,"param_module_relu_linker"):
raise(AttributeError("run layer_reduction first!"))
try:
robust_norm = kargs['robust']
except KeyError:
robust_norm = False
x = data
i = 0
with torch.no_grad():
for n, m in model.named_modules():
Name = m.__class__.__name__
if Name in ['Conv2d', 'ReLU', 'MaxPool2d', 'AvgPool2d', 'Flatten', 'Linear']:
x = m.forward(x)
a = x.cpu().numpy().reshape(-1)
if robust_norm:
model.activation_range[i] = np.percentile(a[np.nonzero(a)], 99.9)
else:
model.activation_range[i] = np.max(a)
i += 1
i = 0
last_lambda = 1.0
for n, m in model.named_modules():
Name = m.__class__.__name__
if Name in ['Conv2d', 'ReLU', 'MaxPool2d', 'AvgPool2d', 'Flatten', 'Linear']:
if Name in ['Conv2d', 'Linear']:
relu_output_layer = model.param_module_relu_linker[i]
if hasattr(m, 'weight') and m.weight is not None:
m.weight.data = m.weight.data * last_lambda / model.activation_range[relu_output_layer]
if hasattr(m, 'bias') and m.bias is not None:
m.bias.data = m.bias.data / model.activation_range[relu_output_layer]
last_lambda = model.activation_range[relu_output_layer]
i += 1
return model

def save_model(model: nn.Module, f):
if isinstance(f,str):
torch.save(model,f)
return

def absorb(param_module, bn_module):
if_2d = len(param_module.weight.size()) == 4 # check whether the parametric module is 2D (i.e., the BN to absorb is BatchNorm2d)
bn_std = torch.sqrt(bn_module.running_var.data + bn_module.eps)
if not if_2d:
if param_module.bias is not None:
param_module.weight.data = param_module.weight.data * bn_module.weight.data.view(-1, 1) / bn_std.view(
-1,
1)
param_module.bias.data = (param_module.bias.data - bn_module.running_mean.data.view(
-1)) * bn_module.weight.data.view(-1) / bn_std.view(
-1) + bn_module.bias.data.view(-1)
else:
param_module.weight.data = param_module.weight.data * bn_module.weight.data.view(-1, 1) / bn_std.view(
-1,
1)
param_module.bias.data = (torch.zeros_like(
bn_module.running_mean.data.view(-1)) - bn_module.running_mean.data.view(
-1)) * bn_module.weight.data.view(-1) / bn_std.view(-1) + bn_module.bias.data.view(-1)
else:
if param_module.bias is not None:
param_module.weight.data = param_module.weight.data * bn_module.weight.data.view(-1, 1, 1,
1) / bn_std.view(-1, 1,
1, 1)
param_module.bias.data = (param_module.bias.data - bn_module.running_mean.data.view(
-1)) * bn_module.weight.data.view(-1) / bn_std.view(
-1) + bn_module.bias.data.view(-1)
else:
param_module.weight.data = param_module.weight.data * bn_module.weight.data.view(-1, 1, 1,
1) / bn_std.view(-1, 1,
1, 1)
param_module.bias.data = (torch.zeros_like(
bn_module.running_mean.data.view(-1)) - bn_module.running_mean.data.view(
-1)) * bn_module.weight.data.view(-1) / bn_std.view(-1) + bn_module.bias.data.view(-1)
return param_module

+ 87
- 126
spikingjelly/clock_driven/ann2snn/modules.py View File

@@ -1,133 +1,94 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaxPool2d(nn.Module):
def __init__(self, kernel_size, stride=None, padding=0, dilation=1,
return_indices=False, ceil_mode=False, momentum=None):
'''
* :ref:`API in English <MaxPool2d.__init__-en>`

.. _MaxPool2d.__init__-cn:

:param kernel_size: 窗口取最大的大小
:param stride: 窗口的步长. 默认值为 :attr:`kernel_size`
:param padding: 隐式两侧填充零的大小
:param dilation: 控制窗口中元素的步幅的参数
:param return_indices: 当 ``True`` ,将返回最大序号并输出
:param ceil_mode: 当 ``True`` ,将使用 `ceil` 而不是 `floor` 来计算输出形状
:param momentum: 当在[0,1]中,将在门控函数中使用在线动量统计;
当为 ``None`` 时,将在门控函数中使用累计脉冲数
:return: ``None``

基于文献 [#f1]_ 中2.2.6章节设计MaxPool2d模块。为了兼容Pytorch的MaxPool2d模块,众多参数设定和Pytorch相同。详情请见 ``torch.nn.MaxPool2d`` 。
基本想法是对输入脉冲进行统计,统计量可以控制门控函数确定以哪一路输入信号作为输出。
根据 `momentum` 参数类型不同可以有不同的统计功能。 `momentum` 参数支持None值和[0,1]区间的浮点数数值作为输出。
假定在t时刻,脉冲输入张量为 :math:`s_t` ,脉冲统计量为 :math:`p_t`
当 `momentum` 参数为 ``None`` 时,统计量为累计脉冲数

.. math::
p_t = p_{t-1} + s_t

当 `momentum` 参数为[0,1]区间的浮点数时,统计量为在线的动量累积

.. math::
p_t = momentum * p_{t-1} + (1-momentum) * s_t

* :ref:`中文API <MaxPool2d.__init__-cn>`

.. _MaxPool2d.__init__-en:

:param kernel_size: the size of the window to take a max over
:param stride: the stride of the window. Default value is :attr:`kernel_size`
:param padding: implicit zero padding to be added on both sides
:param dilation: a parameter that controls the stride of elements in the window
:param return_indices: if ``True``, will return the max indices along with the outputs.
Useful for :class:`torch.nn.MaxUnpool2d` later
:param ceil_mode: when ``True``, will use `ceil` instead of `floor` to compute the output shape
:param momentum: when in [0,1], will use online momentum statistics in gate functions;
when ``None``, will use accumulated spike in gate functions
:return: ``None``

Design the MaxPool2d module based on section 2.2.6 in [#f1]_ . In order to be compatible with Pytorch's MaxPool2d module, many parameter settings are the same as Pytorch. See ``torch.nn.MaxPool2d`` for details.
The basic idea is to accumulate the input spikes, which can control the gating function to determine which input spike is used as output.
Depending on the type of `momentum` parameter, different statistical functions can be used.
`momentum` supports the floating-point value in [0,1] or value ``None``
Assume at time t, the spike input is :math:`s_t` and the spike statistic is :math:`p_t`.
When `momentum` is ``None``, the statistic is sum of spikes over time.

.. math::
p_t = p_{t-1} + s_t

When `momentum` is a floating point in [0,1], the statistic is online momentum of spikes.

.. math::
p_t = momentum * p_{t-1} + (1-momentum) * s_t

.. [#f1] Rueckauer B, Lungu I-A, Hu Y, Pfeiffer M and Liu S-C (2017) Conversion of Continuous-Valued Deep Networks to
Efficient Event-Driven Networks for Image Classification. Front. Neurosci. 11:682.
'''

super(MaxPool2d, self).__init__()
self.kernel_size = kernel_size
self.stride = stride or kernel_size
self.padding = padding
self.dilation = dilation
self.return_indices = return_indices
self.ceil_mode = ceil_mode

assert (momentum is None or momentum <= 1)
self.momentum = momentum
self.v = 0

def forward(self, dv: torch.Tensor):
if self.momentum is not None:
self.v = self.v * self.momentum + (1 - self.momentum) * dv
else:
self.v += dv
(dv_out, ind) = F.max_pool2d(self.v, self.kernel_size, self.stride,
self.padding, self.dilation, self.ceil_mode, True)
unpool_dv_out = F.max_unpool2d(dv_out, ind, self.kernel_size, self.stride, self.padding, self.v.size())
max_gate = (unpool_dv_out != 0.0).float()
gated_spk = dv * max_gate
spk = F.max_pool2d(gated_spk, self.kernel_size, self.stride,
self.padding)
return spk

def reset(self):
'''
:return: None

Reset the neuron to the initial state
'''
self.v = 0


class AccuLayer(nn.Module):
def __init__(self, momentum=None):
super(AccuLayer, self).__init__()

assert (momentum is None or momentum <= 1)
self.momentum = momentum
self.v = 0
self.t = 0.0

def forward(self, spk: torch.Tensor):
self.t += 1.0
if self.momentum is not None:
self.v = self.v * self.momentum + (1 - self.momentum) * spk
return self.v
else:
self.v += spk
return self.v / self.t

def reset(self):
'''
:return: None

Reset the neuron to the initial state
'''
self.t = 0.0
self.v = 0.0


import torch
import numpy as np

class VoltageHook(nn.Module):
def __init__(self, scale=1.0, momentum=0.1, mode='Max'):
"""
* :ref:`API in English <VoltageHook.__init__-en>`

.. _voltageHook.__init__-cn:

:param scale: 缩放初始值
:type scale: float
:param momentum: 动量值
:type momentum: float
:param mode: 模式。输入“Max”表示记录ANN激活最大值,“99.9%”表示记录ANN激活的99.9%分位点,输入0-1的float型浮点数表示记录激活最大值的对应倍数。
:type mode: str, float

``VoltageHook`` 用于在ANN推理中确定激活的范围。

* :ref:`中文API <VoltageHook.__init__-cn>`

.. _voltageHook.__init__-en:

:param scale: initial scaling value
:type scale: float
:param momentum: momentum value
:type momentum: float
:param mode: The mode. Value "Max" means recording the maximum value of ANN activation, "99.9%" means recording the 99.9% percentile of ANN activation, and a float of 0-1 means recording the corresponding multiple of the maximum activation value.
:type mode: str, float

``VoltageHook`` is used to determine the range of activations in ANN inference.

"""
super().__init__()
self.register_buffer('scale', torch.tensor(scale))
self.mode = mode
self.num_batches_tracked = 0
self.momentum = momentum

def forward(self, x):
err_msg = 'You have used a non-defined VoltageScale Method.'
if isinstance(self.mode, str):
if self.mode[-1] == '%':
try:
s_t = torch.tensor(np.percentile(x.detach().cpu(), float(self.mode[:-1])))
except ValueError:
raise NotImplementedError(err_msg)
elif self.mode.lower() in ['max']:
s_t = x.max().detach()
else:
raise NotImplementedError(err_msg)
elif isinstance(self.mode, float) and self.mode <= 1 and self.mode > 0:
s_t = x.max().detach() * self.mode
else:
raise NotImplementedError(err_msg)
if self.num_batches_tracked == 0:
self.scale = s_t
else:
self.scale = (1 - self.momentum) * self.scale + self.momentum * s_t
self.num_batches_tracked += x.shape[0]
return x

class VoltageScaler(nn.Module):
def __init__(self, scale=1.0):
"""
* :ref:`API in English <VoltageScaler.__init__-en>`

.. _voltageScaler.__init__-cn:

:param scale: 缩放值
:type scale: float

``VoltageScaler`` 用于SNN推理中缩放电流。

* :ref:`中文API <VoltageScaler.__init__-cn>`

.. _voltageScaler.__init__-en:

:param scale: scaling value
:type scale: float

``VoltageScaler`` is used for scaling current in SNN inference.

"""
super().__init__()
self.register_buffer('scale', torch.tensor(scale))

def forward(self, x):
return x * self.scale

def extra_repr(self):
return '%f' % self.scale.item()
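
A minimal sketch of how the two modules above might be combined, based on their docstrings: `VoltageHook` records the activation range after a ReLU during ANN calibration, and `VoltageScaler` reuses that range to rescale input currents in the SNN. The layer size and number of calibration batches are arbitrary.

.. code-block:: python

    import torch
    import torch.nn as nn

    hook = VoltageHook(momentum=0.1, mode='99.9%')  # record the 99.9% percentile of activations
    ann_block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), hook)
    with torch.no_grad():
        for _ in range(4):
            ann_block(torch.rand(8, 16))      # calibration batches update hook.scale
    scaler = VoltageScaler(hook.scale.item())  # reuse the recorded range in the SNN
    print(scaler)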

spikingjelly/clock_driven/ann2snn/examples/model_sample/cifar10/resnet.py → spikingjelly/clock_driven/ann2snn/sample_models/cifar10_resnet.py View File

@@ -30,12 +30,14 @@ class BasicBlock(nn.Module):
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
)
self.relu1 = nn.ReLU()
self.relu2 = nn.ReLU()

def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.relu1(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
out = self.relu2(out)
return out


@@ -60,13 +62,16 @@ class Bottleneck(nn.Module):
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
)
self.relu1 = nn.ReLU()
self.relu2 = nn.ReLU()
self.relu3 = nn.ReLU()

def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.relu1(self.bn1(self.conv1(x)))
out = self.relu2(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
out += self.shortcut(x)
out = F.relu(out)
out = self.relu3(out)
return out


@@ -84,6 +89,7 @@ class ResNet(nn.Module):
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.linear = nn.Linear(512*block.expansion, num_classes)
self.flatten = nn.Flatten()
self.relu = nn.ReLU()

def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1]*(num_blocks-1)
@@ -94,7 +100,7 @@ class ResNet(nn.Module):
return nn.Sequential(*layers)

def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)

+ 28
- 0
spikingjelly/clock_driven/ann2snn/sample_models/mnist_cnn.py View File

@@ -0,0 +1,28 @@
import torch.nn as nn

class CNN(nn.Module):
def __init__(self):
super().__init__()
self.network = nn.Sequential(
nn.Conv2d(1, 32, 3, 1),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.AvgPool2d(2, 2),

nn.Conv2d(32, 32, 3, 1),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.AvgPool2d(2, 2),

nn.Conv2d(32, 32, 3, 1),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.AvgPool2d(2, 2),

nn.Flatten(),
nn.Linear(32, 10)
)

def forward(self,x):
x = self.network(x)
return x
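
A quick sanity check of the sample network above; with 1x28x28 MNIST-sized inputs the three conv/pool stages reduce the feature map to 1x1, so the flattened features match the 32-unit linear layer.

.. code-block:: python

    import torch

    net = CNN()
    y = net(torch.rand(2, 1, 28, 28))
    print(y.shape)  # torch.Size([2, 10])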

+ 29
- 0
spikingjelly/clock_driven/ann2snn/utils.py View File

@@ -0,0 +1,29 @@
import requests
import os
from tqdm import tqdm

def download_url(url, dst):
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:67.0) Gecko/20100101 Firefox/67.0'
}

response = requests.get(url, headers=headers, stream=True) # (1)
file_size = int(response.headers['content-length']) # (2)
if os.path.exists(dst):
first_byte = os.path.getsize(dst) # (3)
else:
first_byte = 0
if first_byte >= file_size: # (4)
return file_size

header = {"Range": f"bytes={first_byte}-{file_size}"}

pbar = tqdm(total=file_size, initial=first_byte, unit='B', unit_scale=True, desc=dst)
req = requests.get(url, headers=header, stream=True) # (5)
with open(dst, 'ab') as f:
for chunk in req.iter_content(chunk_size=1024): # (6)
if chunk:
f.write(chunk)
pbar.update(1024)
pbar.close()
return file_size
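
A short usage sketch; the URL and destination are placeholders. The function resumes a partial download if the destination file already exists.

.. code-block:: python

    size = download_url('https://example.com/checkpoint.pth', './checkpoint.pth')  # hypothetical URL
    print(f'file size: {size} bytes')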

+ 177
- 65
spikingjelly/clock_driven/cu_kernel_opt.py View File

@@ -1,67 +1,179 @@
import logging
import torch
import time
import numpy as np
from .. import configure
import os
import threading
import datetime
from torch.utils.tensorboard import SummaryWriter
import re
try:
import cupy
except ImportError:
pass
except BaseException as e:
logging.info(f'spikingjelly.clock_driven.cu_kernel_opt: {e}')
pass

import torch
import time
import numpy as np
from ..configure import cuda_threads, cuda_compiler_options


def cal_fun_t(n, device, f, *args, **kwargs):
if n <= 2:
torch.cuda.synchronize(device)
t_start = time.perf_counter()
f(*args, **kwargs)
torch.cuda.synchronize(device)
return (time.perf_counter() - t_start)
# warm up
f(*args, **kwargs)
torch.cuda.synchronize(device)

t_list = []
for _ in range(n * 2):
torch.cuda.synchronize(device)
t_start = time.perf_counter()
f(*args, **kwargs)
torch.cuda.synchronize(device)
t_list.append(time.perf_counter() - t_start)
t_list = np.asarray(t_list)
return t_list[n:].mean()

def cal_blocks(numel: int):
return (numel + cuda_threads - 1) // cuda_threads

def get_contiguous(*args):
ret_list = []
for item in args:
if isinstance(item, torch.Tensor):
ret_list.append(item.contiguous())

elif isinstance(item, cupy.ndarray):
ret_list.append(cupy.ascontiguousarray(item))

else:
raise TypeError
return ret_list

def wrap_args_to_raw_kernel(device: int, *args):
# note that the input must be contiguous
# check device and get data_ptr from tensor
ret_list = []
for item in args:
if isinstance(item, torch.Tensor):
assert item.get_device() == device
assert item.is_contiguous()
ret_list.append(item.data_ptr())

elif isinstance(item, cupy.ndarray):
assert item.device.id == device
assert item.flags['C_CONTIGUOUS']
ret_list.append(item)

else:
raise TypeError

return tuple(ret_list)


def cuda_timer(device, f, *args, **kwargs):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
f(*args, **kwargs)
end.record()
torch.cuda.synchronize(device)
return start.elapsed_time(end)

def cal_fun_t(n, device, f, *args, **kwargs):
assert n > 2
# warm up
cuda_timer(device, f, *args, **kwargs)

t_list = []
for _ in range(n * 2):
t_list.append(cuda_timer(device, f, *args, **kwargs))
t_list = np.asarray(t_list)
return t_list[n:].mean()
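
A minimal timing sketch, assuming a CUDA device is available: one warm-up call, then the mean of the last `n` CUDA-event timings (in milliseconds) is returned.

.. code-block:: python

    import torch

    device = 0
    x = torch.rand(1024, 1024, device=f'cuda:{device}')
    t_mean_ms = cal_fun_t(16, device, torch.mm, x, x)
    print(t_mean_ms)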

def cal_blocks(numel: int):
return (numel + configure.cuda_threads - 1) // configure.cuda_threads

def get_contiguous(*args):
ret_list = []
for item in args:
if isinstance(item, torch.Tensor):
ret_list.append(item.contiguous())

elif isinstance(item, cupy.ndarray):
ret_list.append(cupy.ascontiguousarray(item))

else:
raise TypeError
return ret_list

def wrap_args_to_raw_kernel(device: int, *args):
# note that the input must be contiguous
# check device and get data_ptr from tensor
ret_list = []
for item in args:
if isinstance(item, torch.Tensor):
assert item.get_device() == device
assert item.is_contiguous()
ret_list.append(item.data_ptr())

elif isinstance(item, cupy.ndarray):
assert item.device.id == device
assert item.flags['C_CONTIGUOUS']
ret_list.append(item)

else:
raise TypeError

return tuple(ret_list)

class GPUMonitor(threading.Thread):
def __init__(self, log_dir: str = None, gpu_ids: tuple = (0, ), interval: float = 60., start_now=True):
"""
:param log_dir: the directory for saving logs with tensorboard. If it is None, this module will print logs
:type log_dir: str
:param gpu_ids: the id of GPUs to be monitored, e.g., `(0, 1, 2, 3)`. The default value is `(0, )`
:type gpu_ids: tuple
:param interval: the recording interval (in seconds)
:type interval: float
:param start_now: if true, the monitor will start to record now. Otherwise, it will start after the user calls `start()` manually
:type start_now: bool

The GPU monitor, which starts a new thread to record the utilization and memory used of `gpu_ids` every `interval` seconds.

.. admonition:: Warning
:class: warning

Do not forget to call `stop()` after the main thread finishes its job, otherwise the main thread will never stop!

Codes example:

.. code-block:: python

import time

gm = GPUMonitor(interval=1)
time.sleep(2) # make the main thread sleep
gm.stop()

# The outputs are:

# 2022-04-28 10:52:25
# utilization.gpu [%], memory.used [MiB]
# 0 %, 376 MiB
"""
super().__init__()
self.gpu_ids = gpu_ids
self.interval = interval
self.stopped = False
self.cmds = 'nvidia-smi --query-gpu=utilization.gpu,memory.used --format=csv'
self.cmds += ' -i '
id_str = []
for gpu_id in self.gpu_ids:
id_str.append(str(gpu_id))
self.cmds += ','.join(id_str)
self.step = 0

if log_dir is None:
self.writer = None
else:
self.writer = []
for i in range(self.gpu_ids.__len__()):
self.writer.append(SummaryWriter(os.path.join(log_dir, f'gpu_{id_str[i]}')))

if start_now:
self.start()

def stop(self):
self.stopped = True

def run(self):
while not self.stopped:
with os.popen(self.cmds) as fp:
outputs = fp.read()
if self.writer is not None:
outputs = outputs.split('\n')[1:-1]
# skip the first row 'utilization.gpu [%], memory.used [MiB]' and the last row ('\n')
for i in range(outputs.__len__()):
utilization_memory = re.findall(r'\d+', outputs[i])
utilization = int(utilization_memory[0])
memory_used = int(utilization_memory[1])
self.writer[i].add_scalar('utilization', utilization, self.step)
self.writer[i].add_scalar('memory_used', memory_used, self.step)
else:
print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
print(outputs)
'''
2022-04-20 18:14:26
utilization.gpu [%], memory.used [MiB]
4 %, 1816 MiB
0 %, 1840 MiB
0 %, 1840 MiB
0 %, 1720 MiB
'''
time.sleep(self.interval)
self.step += 1


class DeviceEnvironment:
def __init__(self, device: int):
"""
This module is used as a context to make CuPy use the specific device, and avoids `torch.cuda.current_device()` being changed by CuPy.
Refer to https://github.com/cupy/cupy/issues/6569 for more details.
"""
self.device = device
self.previous_device = None

def __enter__(self):
current_device = torch.cuda.current_device()
if current_device != self.device:
torch.cuda.set_device(self.device)
self.previous_device = current_device

def __exit__(self, exc_type, exc_val, exc_tb):
if self.previous_device is not None:
torch.cuda.set_device(self.previous_device)
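
A minimal sketch, assuming at least two visible GPUs: CuPy work runs on device 1 inside the context, and the previously selected torch device is restored on exit.

.. code-block:: python

    import torch

    torch.cuda.set_device(0)
    with DeviceEnvironment(1):
        pass  # compile/launch CuPy kernels targeting device 1 here
    print(torch.cuda.current_device())  # 0, restored after the context exits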


+ 9
- 9
spikingjelly/clock_driven/encoding.py View File

@@ -312,14 +312,14 @@ class PoissonEncoder(StatelessEncoder):
return out_spike

class WeightedPhaseEncoder(StatefulEncoder):
def __init__(self, T: int):
def __init__(self, K: int):
"""
* :ref:`API in English <WeightedPhaseEncoder.__init__-en>`

.. _WeightedPhaseEncoder.__init__-cn:

:param T: 编码周期。通常情况下,与SNN的仿真周期(总步长一致)
:type T: int
:param K: 编码周期。通常情况下,与SNN的仿真周期(总步长一致)
:type K: int

Kim J, Kim H, Huh S, et al. Deep neural networks with weighted spikes[J]. Neurocomputing, 2018, 311: 373-386.

@@ -346,8 +346,8 @@ class WeightedPhaseEncoder(StatefulEncoder):

.. _WeightedPhaseEncoder.__init__-en:

:param T: the encoding period. It is usually same with the total simulation time-steps of SNN
:type T: int
:param K: the encoding period. It is usually same with the total simulation time-steps of SNN
:type K: int

The weighted phase encoder, which is based on binary system. It will flatten ``x`` as a binary number. When
``T=k``, it can encode :math:`x \in [0, 1-2^{-K}]` to different spikes. Here is the example from the origin paper:
@@ -368,14 +368,14 @@ class WeightedPhaseEncoder(StatefulEncoder):


"""
super().__init__(T)
super().__init__(K)

def encode(self, x: torch.Tensor):
assert (x >= 0).all() and (x <= 1 - 2 ** (-self.phase)).all()
assert (x >= 0).all() and (x <= 1 - 2 ** (-self.T)).all()
inputs = x.clone()
self.spike = torch.empty((self.phase,) + x.shape, device=x.device) # 编码为[phase, batch_size, *]
self.spike = torch.empty((self.T,) + x.shape, device=x.device) # Encoding to [T, batch_size, *]
w = 0.5
for i in range(self.phase):
for i in range(self.T):
self.spike[i] = inputs >= w
inputs -= w * self.spike[i]
w *= 0.5
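
A minimal sketch of the renamed parameter in use, based on the `encode` method shown above: `K` phases encode inputs in :math:`[0, 1-2^{-K}]` into a spike tensor of shape `[K, ...]`.

.. code-block:: python

    import torch
    from spikingjelly.clock_driven import encoding

    K = 8
    encoder = encoding.WeightedPhaseEncoder(K)
    x = torch.rand(4, 16) * (1 - 2 ** (-K))
    encoder.encode(x)           # fills encoder.spike with shape [K, 4, 16]
    print(encoder.spike.shape)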

+ 1
- 1
spikingjelly/clock_driven/examples/DQN_state.py View File

@@ -69,7 +69,7 @@ if __name__ == '__main__':

device = torch.device("cuda" if args.use_cuda else "cpu")

writer = SummaryWriter(logdir='./log')
writer = SummaryWriter(log_dir='./log')

env = gym.make(env_name).unwrapped
env.seed(args.seed)


+ 9
- 9
spikingjelly/clock_driven/examples/Spiking_DQN_state.py View File

@@ -38,15 +38,15 @@ class ReplayMemory(object):
return len(self.memory)


class NonSpikingLIFNode(neuron.LIFNode):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, dv: torch.Tensor):
self.neuronal_charge(dv)
# self.neuronal_fire()
# self.neuronal_reset()
return self.v
class NonSpikingLIFNode(neuron.LIFNode):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, dv: torch.Tensor):
self.neuronal_charge(dv)
# self.neuronal_fire()
# self.neuronal_reset()
return self.v


# Spiking DQN algorithm


+ 16
- 3
spikingjelly/clock_driven/examples/lif_fc_mnist.py View File

@@ -179,6 +179,17 @@ def main():

# save data for plotting
net.eval()
# register a forward hook
output_layer = net[-1] # the output layer
output_layer.v_seq = []
output_layer.s_seq = []
def save_hook(m, x, y):
m.v_seq.append(m.v.unsqueeze(0))
m.s_seq.append(y.unsqueeze(0))

output_layer.register_forward_hook(save_hook)


with torch.no_grad():
img, label = test_dataset[0]
img = img.to(device)
@@ -189,10 +200,12 @@ def main():
out_spikes_counter += net(encoder(img).float())
out_spikes_counter_frequency = (out_spikes_counter / T).cpu().numpy()
print(f'Firing rate: {out_spikes_counter_frequency}')
output_layer = net[-1] # the output layer
v_t_array = output_layer.v.cpu().numpy().squeeze().T # v_t_array[i][j] is the membrane potential of neuron i at time-step j

output_layer.v_seq = torch.cat(output_layer.v_seq)
output_layer.s_seq = torch.cat(output_layer.s_seq)
v_t_array = output_layer.v_seq.cpu().numpy().squeeze().T # v_t_array[i][j] is the membrane potential of neuron i at time-step j
np.save("v_t_array.npy",v_t_array)
s_t_array = output_layer.spike.cpu().numpy().squeeze().T # s_t_array[i][j] is the spike (0 or 1) fired by neuron i at time-step j
s_t_array = output_layer.s_seq.cpu().numpy().squeeze().T # s_t_array[i][j] is the spike (0 or 1) fired by neuron i at time-step j
np.save("s_t_array.npy",s_t_array)

train_accs = np.array(train_accs)


+ 30
- 29
spikingjelly/clock_driven/functional.py View File

@@ -4,6 +4,8 @@ import torch.nn.functional as F
import math
from . import neuron

from torch import Tensor

def reset_net(net: nn.Module):
'''
* :ref:`API in English <reset_net-en>`
@@ -30,7 +32,7 @@ def reset_net(net: nn.Module):
if hasattr(m, 'reset'):
m.reset()

def spike_cluster(v: torch.Tensor, v_threshold, T_in: int):
def spike_cluster(v: Tensor, v_threshold, T_in: int):
'''
* :ref:`API in English <spike_cluster-en>`

@@ -180,7 +182,7 @@ def spike_cluster(v: torch.Tensor, v_threshold, T_in: int):

return N_o, k_positive, k_negative

def spike_similar_loss(spikes:torch.Tensor, labels:torch.Tensor, kernel_type='linear', loss_type='mse', *args):
def spike_similar_loss(spikes:Tensor, labels:Tensor, kernel_type='linear', loss_type='mse', *args):
'''
* :ref:`API in English <spike_similar_loss-en>`

@@ -285,7 +287,7 @@ def spike_similar_loss(spikes:torch.Tensor, labels:torch.Tensor, kernel_type='li
else:
raise NotImplementedError

def kernel_dot_product(x:torch.Tensor, y:torch.Tensor, kernel='linear', *args):
def kernel_dot_product(x:Tensor, y:Tensor, kernel='linear', *args):

'''
* :ref:`API in English <kernel_dot_product-en>`
@@ -349,7 +351,7 @@ def kernel_dot_product(x:torch.Tensor, y:torch.Tensor, kernel='linear', *args):
else:
raise NotImplementedError

def set_threshold_margin(output_layer:neuron.BaseNode, label_one_hot:torch.Tensor,
def set_threshold_margin(output_layer:neuron.BaseNode, label_one_hot:Tensor,
eval_threshold=1.0, threshold0=0.9, threshold1=1.1):
'''
* :ref:`API in English <set_threshold_margin-en>`
@@ -391,7 +393,7 @@ def set_threshold_margin(output_layer:neuron.BaseNode, label_one_hot:torch.Tenso
else:
output_layer.v_threshold = eval_threshold

def redundant_one_hot(labels:torch.Tensor, num_classes:int, n:int):
def redundant_one_hot(labels:Tensor, num_classes:int, n:int):
'''
* :ref:`API in English <redundant_one_hot-en>`

@@ -453,7 +455,7 @@ def redundant_one_hot(labels:torch.Tensor, num_classes:int, n:int):
codes += F.one_hot(labels * n + i, redundant_classes)
return codes

def first_spike_index(spikes: torch.Tensor):
def first_spike_index(spikes: Tensor):
'''
* :ref:`API in English <first_spike_index-en>`

@@ -522,47 +524,46 @@ def first_spike_index(spikes: torch.Tensor):
# along the time dimension, after applying cumsum twice, the positions whose value is 1 are where the first spikes are fired
return spikes.cumsum(dim=-1).cumsum(dim=-1) == 1

def multi_step_forward(x_seq: torch.Tensor, multi_step_module: nn.Module or list or tuple):
def multi_step_forward(x_seq: Tensor, single_step_module: nn.Module or list or tuple or nn.Sequential):
"""
:param x_seq: shape=[T, batch_size, ...]
:type x_seq: torch.Tensor
:param multi_step_module: a multi-step module, or a list/tuple that contains multi-step modules
:type multi_step_module: torch.nn.Module or list or tuple
:type x_seq: Tensor
:param single_step_module: a single-step module, or a list/tuple that contains single-step modules
:type single_step_module: torch.nn.Module or list or tuple or torch.nn.Sequential
:return: y_seq, shape=[T, batch_size, ...]
:rtype: torch.Tensor
:rtype: Tensor

See :class:`spikingjelly.clock_driven.layer.MultiStepContainer` for more details.
"""
y_seq = []
if isinstance(multi_step_module, (list, tuple)):
if isinstance(single_step_module, (list, tuple, nn.Sequential)):
for t in range(x_seq.shape[0]):
x_seq_t = x_seq[t]
for m in multi_step_module:
for m in single_step_module:
x_seq_t = m(x_seq_t)
y_seq.append(x_seq_t)
else:
for t in range(x_seq.shape[0]):
y_seq.append(multi_step_module(x_seq[t]))
y_seq.append(single_step_module(x_seq[t]))

for t in range(y_seq.__len__()):
# y_seq[t].unsqueeze_(0)
y_seq[t] = y_seq[t].unsqueeze(0)
return torch.cat(y_seq, 0)

def seq_to_ann_forward(x_seq: torch.Tensor, stateless_module: nn.Module or list or tuple):
def seq_to_ann_forward(x_seq: Tensor, stateless_module: nn.Module or list or tuple or nn.Sequential):
"""
:param x_seq: shape=[T, batch_size, ...]
:type x_seq: torch.Tensor
:param multi_step_module: a stateless module, e.g., 'torch.nn.Conv2d' or a list contains stateless modules, e.g., '[torch.nn.Conv2d, torch.nn.BatchNorm2d]
:type multi_step_module: torch.nn.Module or list or tuple
:type x_seq: Tensor
:param stateless_module: a stateless module, e.g., 'torch.nn.Conv2d', or a list that contains stateless modules, e.g., '[torch.nn.Conv2d, torch.nn.BatchNorm2d]'
:type stateless_module: torch.nn.Module or list or tuple or torch.nn.Sequential
:return: y_seq, shape=[T, batch_size, ...]
:rtype: torch.Tensor
:rtype: Tensor

See :class:`spikingjelly.clock_driven.layer.SeqToANNContainer` for more details.
"""
y_shape = [x_seq.shape[0], x_seq.shape[1]]
y = x_seq.flatten(0, 1)
if isinstance(stateless_module, (list, tuple)):
if isinstance(stateless_module, (list, tuple, nn.Sequential)):
for m in stateless_module:
y = m(y)
else:
@@ -577,7 +578,7 @@ def fused_conv2d_weight_of_convbn2d(conv2d: nn.Conv2d, bn2d: nn.BatchNorm2d):
:param bn2d: a BatchNorm2d layer
:type bn2d: torch.nn.BatchNorm2d
:return: the weight of this fused module
:rtype: torch.Tensor
:rtype: Tensor

A {Conv2d-BatchNorm2d} can be fused to a {Conv2d} module with BatchNorm2d's parameters being absorbed into Conv2d.
This function returns the weight of this fused module.
@@ -600,7 +601,7 @@ def fused_conv2d_bias_of_convbn2d(conv2d: nn.Conv2d, bn2d: nn.BatchNorm2d):
:param bn2d: a BatchNorm2d layer
:type bn2d: torch.nn.BatchNorm2d
:return: the bias of this fused module
:rtype: torch.Tensor
:rtype: Tensor

A {Conv2d-BatchNorm2d} can be fused to a {Conv2d} module with BatchNorm2d's parameters being absorbed into Conv2d.
This function returns the bias of this fused module.
@@ -690,14 +691,14 @@ def fuse_convbn2d(conv2d: nn.Conv2d, bn2d: nn.BatchNorm2d, k=None, b=None):
fused_conv.bias.data = fused_conv2d_bias_of_convbn2d(conv2d, bn2d)
return fused_conv

def temporal_efficient_training_cross_entropy(x_seq: torch.Tensor, target: torch.LongTensor):
def temporal_efficient_training_cross_entropy(x_seq: Tensor, target: torch.LongTensor):
"""
:param x_seq: ``shape=[T, N, C, *]``, where ``C`` is the number of classes
:type x_seq: torch.Tensor
:type x_seq: Tensor
:param target: ``shape=[N]``, where ``0 <= target[i] <= C-1``
:type target: torch.LongTensor
:return: the temporal efficient training cross entropy
:rtype: torch.Tensor
:rtype: Tensor

The temporal efficient training (TET) cross entropy, which is the mean of cross entropy of each time-step.

@@ -705,7 +706,7 @@ def temporal_efficient_training_cross_entropy(x_seq: torch.Tensor, target: torch

.. code-block:: python

def tet_ce_for_loop_version(x_seq: torch.Tensor, target: torch.LongTensor):
def tet_ce_for_loop_version(x_seq: Tensor, target: torch.LongTensor):
loss = 0.
for t in range(x_seq.shape[0]):
loss += F.cross_entropy(x_seq[t], target)
@@ -761,9 +762,9 @@ def kaiming_normal_conv_linear_weight(net: nn.Module):

:return: None

initialize all weights (not including bias) of :class:`torch.nn._ConvNd` and `:class:`torch.nn.Linear` in `net` by the kaiming normal. See :class:`torch.nn.init.kaiming_normal_`
initialize all weights (not including bias) of :class:`torch.nn._ConvNd` and :class:`torch.nn.Linear` in `net` by the kaiming normal. See :class:`torch.nn.init.kaiming_normal_`
for more details.
'''
for m in net.modules():
if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
nn.init.kaiming_normal_(m.weight, a=math.sqrt(5))
nn.init.kaiming_normal_(m.weight, a=math.sqrt(5))

+ 303
- 0
spikingjelly/clock_driven/lava_exchange.py View File

@@ -0,0 +1,303 @@
import torch
import torch.nn as nn
import logging
from . import neuron

try:
import lava.lib.dl.slayer as slayer

except BaseException as e:
logging.info(f'spikingjelly.clock_driven.lava_exchange: {e}')
slayer = None

# ----------------------------------------
# data reshape function

def TNX_to_NXT(x_seq: torch.Tensor):
# x_seq.shape = [T, N, *]
permute_args = list(range(1, x_seq.dim()))
permute_args.append(0)
return x_seq.permute(permute_args)

def NXT_to_TNX(x_seq: torch.Tensor):
# x_seq.shape = [N, *, T]
permute_args = list(range(x_seq.dim() - 1))
permute_args.insert(0, x_seq.dim() - 1)
return x_seq.permute(permute_args)


def lava_neuron_forward(lava_neuron: nn.Module, x_seq: torch.Tensor, v: torch.Tensor or float):
# x_seq.shape = [T, N, *]
# lava uses shape = [*, T], while SJ uses shape = [T, *]
unsqueeze_flag = False
if x_seq.dim() == 2:
x_seq = x_seq.unsqueeze(1)
# lava needs input with shape [N, ..., T]
unsqueeze_flag = True

if isinstance(v, float):
v_init = v
v = torch.zeros_like(x_seq[0])
if v_init != 0.:
torch.fill_(v, v_init)

x_seq_shape = x_seq.shape
x_seq = x_seq.flatten(2).permute(1, 2, 0)
# [T, N, *] -> [N, *, T]

lava_neuron.voltage_state = v
spike = lava_neuron(x_seq).permute(2, 0, 1)

v = lava_neuron.voltage_state.reshape(x_seq_shape[1:])
spike = spike.reshape(x_seq_shape)
if unsqueeze_flag:
v = v.squeeze(1)
spike = spike.squeeze(1)

return spike, v

# ----------------------------------------
# quantize function

class _step_quantize(torch.autograd.Function):
@staticmethod
def forward(ctx, x, step):
return torch.round(x / step) * step

@staticmethod
def backward(ctx, grad_output):
return grad_output, None

def step_quantize(x: torch.Tensor, step: float = 1.):
"""
:param x: the input tensor
:type x: torch.Tensor
:param step: the quantize step
:type step: float
:return: quantized tensor
:rtype: torch.Tensor

The step quantize function. Here is an example:

.. code-block:: python

# plt.style.use(['science', 'muted', 'grid'])
fig = plt.figure(dpi=200, figsize=(6, 4))
x = torch.arange(-4, 4, 0.001)
plt.plot(x, lava_exchange.step_quantize(x, 2.), label='quantize(x, step=2)')
plt.plot(x, x, label='y=x', ls='-.')
plt.legend()
plt.grid(ls='--')
plt.title('step quantize')
plt.xlabel('Input')
plt.ylabel('Output')
plt.savefig('./docs/source/_static/API/clock_driven/lava_exchange/step_quantize.svg')
plt.savefig('./docs/source/_static/API/clock_driven/lava_exchange/step_quantize.pdf')

.. image:: ./_static/API/clock_driven/lava_exchange/step_quantize.*
:width: 100%

"""
return _step_quantize.apply(x, step)


def quantize_8bit(x: torch.Tensor, scale, descale=False):
if descale:
return step_quantize(x, 2. / scale).clamp(-256. / scale, 255. / scale) * scale
else:
return step_quantize(x, 2. / scale).clamp(-256. / scale, 255. / scale)

# ----------------------------------------
# convert function
def check_conv2d(conv2d_nn: nn.Conv2d):
if not isinstance(conv2d_nn, nn.Conv2d):
raise ValueError(f'expected conv2d_nn with type torch.nn.Conv2d, but got conv2d_nn with type {type(conv2d_nn)}!')

if conv2d_nn.bias is not None:
raise ValueError('lava does not support for convolutional synapse with bias!')

def check_fc(fc: nn.Linear):
if not isinstance(fc, nn.Linear):
raise ValueError(f'expected fc with type torch.nn.Linear, but got fc with type {type(fc)}!')

if fc.bias is not None:
raise ValueError('lava does not support for dense synapse with bias!')

def to_lava_neuron_param_dict(sj_ms_neuron: nn.Module):
if isinstance(sj_ms_neuron, neuron.MultiStepIFNode):
if sj_ms_neuron.v_reset != 0.:
raise ValueError('lava only supports for v_reset == 0!')
return {
'threshold': sj_ms_neuron.v_threshold,
'current_decay': 1.,
'voltage_decay': 0.,
'tau_grad': 1, 'scale_grad': 1, 'scale': sj_ms_neuron.lava_s_cale,
'norm': None, 'dropout': None,
'shared_param': True, 'persistent_state': True, 'requires_grad': False,
'graded_spike': False
}

elif isinstance(sj_ms_neuron, neuron.MultiStepLIFNode):
if sj_ms_neuron.v_reset != 0.:
raise ValueError('lava only supports for v_reset == 0!')
if sj_ms_neuron.decay_input:
raise ValueError('lava only supports decay_input == False!')
return {
'threshold': sj_ms_neuron.v_threshold,
'current_decay': 1.,
'voltage_decay': 1. / sj_ms_neuron.tau,
'tau_grad': 1, 'scale_grad': 1, 'scale': sj_ms_neuron.lava_s_cale,
'norm': None, 'dropout': None,
'shared_param': True, 'persistent_state': True, 'requires_grad': False,
'graded_spike': False
}
else:
raise NotImplementedError(sj_ms_neuron)


def to_lava_neuron(sj_ms_neuron: nn.Module):
if isinstance(sj_ms_neuron, (neuron.MultiStepIFNode, neuron.MultiStepLIFNode)):
return slayer.neuron.cuba.Neuron(
**to_lava_neuron_param_dict(sj_ms_neuron)
)
else:
raise NotImplementedError(sj_ms_neuron)

def linear_to_lava_synapse_dense(fc: nn.Linear):
"""
:param fc: a pytorch linear layer without bias
:type fc: nn.Linear
:return: a lava slayer dense synapse
:rtype: slayer.synapse.Dense

Code example:

.. code-block:: python

T = 4
N = 2
layer_nn = nn.Linear(8, 4, bias=False)
layer_sl = lava_exchange.linear_to_lava_synapse_dense(layer_nn)
x_seq = torch.rand([T, N, 8])
with torch.no_grad():
y_nn = functional.seq_to_ann_forward(x_seq, layer_nn)
y_sl = lava_exchange.NXT_to_TNX(layer_sl(lava_exchange.TNX_to_NXT(x_seq)))
print('max error:', (y_nn - y_sl).abs().max())
"""
check_fc(fc)

dense_slayer = slayer.synapse.Dense(fc.in_features, fc.out_features)

# `dense_slayer` is a `torch.nn.Conv3d`. Its weight has shape [out_features, in_features, 1, 1, 1]
dense_slayer.weight.data[:, :, 0, 0, 0] = fc.weight.data.clone()

return dense_slayer

def conv2d_to_lava_synapse_conv(conv2d_nn: nn.Conv2d):
"""
:param conv2d_nn: a pytorch conv2d layer without bias
:type conv2d_nn: nn.Conv2d
:return: a lava slayer conv synapse
:rtype: slayer.synapse.Conv

Code example:

.. code-block:: python

T = 4
N = 2
layer_nn = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1, bias=False)
layer_sl = lava_exchange.conv2d_to_lava_synapse_conv(layer_nn)
x_seq = torch.rand([T, N, 3, 28, 28])
with torch.no_grad():
y_nn = functional.seq_to_ann_forward(x_seq, layer_nn)
y_sl = lava_exchange.NXT_to_TNX(layer_sl(lava_exchange.TNX_to_NXT(x_seq)))
print('max error:', (y_nn - y_sl).abs().max())
"""
check_conv2d(conv2d_nn)

conv_slayer = slayer.synapse.Conv(in_features=conv2d_nn.in_channels, out_features=conv2d_nn.out_channels, kernel_size=conv2d_nn.kernel_size, stride=conv2d_nn.stride, padding=conv2d_nn.padding, dilation=conv2d_nn.dilation, groups=conv2d_nn.groups)
# `conv_slayer` is a `torch.nn.Conv3d`.
conv_slayer.weight.data[:, :, :, :, 0] = conv2d_nn.weight.data.clone()

return conv_slayer

def avgpool2d_to_lava_synapse_pool(pool2d_nn: nn.AvgPool2d):
"""
:param pool2d_nn: a pytorch AvgPool2d layer
:type pool2d_nn: nn.AvgPool2d
:return: a lava slayer pool layer
:rtype: slayer.synapse.Pool

.. admonition:: Warning
:class: warning

The lava slayer pool layer applies sum pooling, rather than average pooling.

.. code-block:: python

T = 4
N = 2
layer_nn = nn.AvgPool2d(kernel_size=2, stride=2)
layer_sl = lava_exchange.avgpool2d_to_lava_synapse_pool(layer_nn)
x_seq = torch.rand([T, N, 3, 28, 28])
with torch.no_grad():
y_nn = functional.seq_to_ann_forward(x_seq, layer_nn)
y_sl = lava_exchange.NXT_to_TNX(layer_sl(lava_exchange.TNX_to_NXT(x_seq))) / 4.
print('max error:', (y_nn - y_sl).abs().max())
"""
if not isinstance(pool2d_nn, nn.AvgPool2d):
raise ValueError(f'expected pool2d_nn with type torch.nn.AvgPool2d, but got pool2d_nn with type {type(pool2d_nn)}!')

return slayer.synapse.Pool(pool2d_nn.kernel_size, pool2d_nn.stride, pool2d_nn.padding)

def to_lava_block_dense(fc: nn.Linear, sj_ms_neuron: nn.Module, quantize_to_8bit: bool = True):

check_fc(fc)

neuron_params = to_lava_neuron_param_dict(sj_ms_neuron)
if isinstance(sj_ms_neuron, (neuron.MultiStepIFNode, neuron.MultiStepLIFNode)):
block_init = slayer.block.cuba.Dense
else:
raise NotImplementedError(sj_ms_neuron)


if quantize_to_8bit:
# if 'pre_hook_fx' not in kwargs.keys(), then `pre_hook_fx` will be set to `quantize_8bit` by default
block_lava = block_init(neuron_params, fc.in_features, fc.out_features, delay_shift=False)
else:
block_lava = block_init(neuron_params, fc.in_features, fc.out_features, delay_shift=False, pre_hook_fx=None)

block_lava.synapse.weight.data[:, :, 0, 0, 0] = fc.weight.data.clone()

return block_lava
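# A brief sketch, assuming lava-dl (`slayer`) is installed: convert a bias-free
# linear layer plus a multi-step IF neuron into a single CUBA Dense block.
#
#   fc = nn.Linear(8, 4, bias=False)
#   sn = neuron.MultiStepIFNode()
#   block = to_lava_block_dense(fc, sn, quantize_to_8bit=False)
#   x_nxt = torch.rand([2, 8, 4])          # lava layout [N, C, T]
#   y_nxt = block(x_nxt)                   # output spikes, expected shape [2, 4, 4]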


def to_lava_block_conv(conv2d_nn: nn.Conv2d, sj_ms_neuron: nn.Module, quantize_to_8bit: bool = True):

check_conv2d(conv2d_nn)

neuron_params = to_lava_neuron_param_dict(sj_ms_neuron)
if isinstance(sj_ms_neuron, (neuron.MultiStepIFNode, neuron.MultiStepLIFNode)):
block_init = slayer.block.cuba.Conv
else:
raise NotImplementedError(sj_ms_neuron)

if quantize_to_8bit:
# if 'pre_hook_fx' not in kwargs.keys(), then `pre_hook_fx` will be set to `quantize_8bit` by default
block_lava = block_init(neuron_params, in_features=conv2d_nn.in_channels, out_features=conv2d_nn.out_channels, kernel_size=conv2d_nn.kernel_size, stride=conv2d_nn.stride, padding=conv2d_nn.padding, dilation=conv2d_nn.dilation, groups=conv2d_nn.groups, delay_shift=False)
else:
block_lava = block_init(neuron_params, in_features=conv2d_nn.in_channels, out_features=conv2d_nn.out_channels, kernel_size=conv2d_nn.kernel_size, stride=conv2d_nn.stride, padding=conv2d_nn.padding, dilation=conv2d_nn.dilation, groups=conv2d_nn.groups, delay_shift=False, pre_hook_fx=None)

block_lava.synapse.weight.data[:, :, :, :, 0] = conv2d_nn.weight.data.clone()

return block_lava


def to_lava_block_flatten(flatten_nn: nn.Flatten):
if flatten_nn.start_dim != 1:
raise ValueError('lava only supports flatten_nn.start_dim == 1!')
return slayer.block.cuba.Flatten()
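# Putting the conversion helpers together: a hedged end-to-end sketch (layer
# sizes are illustrative only; requires lava-dl to be installed):
#
#   conv = nn.Conv2d(1, 8, kernel_size=3, padding=1, bias=False)
#   fc = nn.Linear(8 * 28 * 28, 10, bias=False)
#   sn1, sn2 = neuron.MultiStepIFNode(), neuron.MultiStepIFNode()
#
#   lava_blocks = nn.Sequential(
#       to_lava_block_conv(conv, sn1),
#       to_lava_block_flatten(nn.Flatten(start_dim=1)),
#       to_lava_block_dense(fc, sn2),
#   )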




+ 245
- 41
spikingjelly/clock_driven/layer.py View File

@@ -3,8 +3,11 @@ import torch.nn as nn
import torch.nn.functional as F
import math
from . import base, functional
from torch import Tensor
from torch.nn.common_types import _size_2_t
from typing import Callable
from torch.nn.modules.batchnorm import _BatchNorm


class NeuNorm(base.MemoryModule):
def __init__(self, in_channels, height, width, k=0.9, shared_across_channels=False):
@@ -71,12 +74,12 @@ class NeuNorm(base.MemoryModule):
self.k0 = k
self.k1 = (1. - self.k0) / in_channels ** 2
if shared_across_channels:
self.w = nn.Parameter(torch.Tensor(1, height, width))
self.w = nn.Parameter(Tensor(1, height, width))
else:
self.w = nn.Parameter(torch.Tensor(in_channels, height, width))
self.w = nn.Parameter(Tensor(in_channels, height, width))
nn.init.kaiming_uniform_(self.w, a=math.sqrt(5))

def forward(self, in_spikes: torch.Tensor):
def forward(self, in_spikes: Tensor):
self.x = self.k0 * self.x + self.k1 * in_spikes.sum(dim=1,
keepdim=True) # x.shape = [batch_size, 1, height, width]
return in_spikes - self.w * self.x
@@ -119,7 +122,7 @@ class DCT(nn.Module):
else:
self.kernel[i][j] = math.sqrt(2 / kernel_size) * math.cos((j + 0.5) * math.pi * i / kernel_size)

def forward(self, x: torch.Tensor):
def forward(self, x: Tensor):
if self.kernel.device != x.device:
self.kernel = self.kernel.to(x.device)
x_shape = x.shape
@@ -160,10 +163,10 @@ class AXAT(nn.Module):
The input will be regarded as a batch of tensors with ``shape = [in_features, in_features]``.
"""
super().__init__()
self.A = nn.Parameter(torch.Tensor(out_features, in_features))
self.A = nn.Parameter(Tensor(out_features, in_features))
nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))

def forward(self, x: torch.Tensor):
def forward(self, x: Tensor):
x_shape = list(x.shape)
x = x.view(-1, x_shape[-2], x_shape[-1])
x = self.A.matmul(x).matmul(self.A.t())
@@ -241,10 +244,10 @@ class Dropout(base.MemoryModule):
def extra_repr(self):
return f'p={self.p}'

def create_mask(self, x: torch.Tensor):
def create_mask(self, x: Tensor):
self.mask = F.dropout(torch.ones_like(x.data), self.p, training=True)

def forward(self, x: torch.Tensor):
def forward(self, x: Tensor):
if self.training:
if self.mask is None:
self.create_mask(x)
@@ -284,7 +287,7 @@ class Dropout2d(Dropout):
"""
super().__init__(p)

def create_mask(self, x: torch.Tensor):
def create_mask(self, x: Tensor):
self.mask = F.dropout2d(torch.ones_like(x.data), self.p, training=True)


@@ -321,7 +324,7 @@ class MultiStepDropout(Dropout):
"""
super().__init__(p)

def forward(self, x_seq: torch.Tensor):
def forward(self, x_seq: Tensor):
if self.training:
if self.mask is None:
self.create_mask(x_seq[0])
@@ -364,7 +367,7 @@ class MultiStepDropout2d(Dropout2d):
"""
super().__init__(p)

def forward(self, x_seq: torch.Tensor):
def forward(self, x_seq: Tensor):
if self.training:
if self.mask is None:
self.create_mask(x_seq[0])
@@ -526,7 +529,7 @@ class SynapseFilter(base.MemoryModule):

return f'tau={tau}, learnable={self.learnable}'

def forward(self, in_spikes: torch.Tensor):
def forward(self, in_spikes: Tensor):
if self.learnable:
inv_tau = self.w.sigmoid()
else:
@@ -536,6 +539,7 @@ class SynapseFilter(base.MemoryModule):

return self.out_i


class ChannelsPool(nn.Module):
def __init__(self, pool: nn.MaxPool1d or nn.AvgPool1d):
"""
@@ -578,7 +582,7 @@ class ChannelsPool(nn.Module):
super().__init__()
self.pool = pool

def forward(self, x: torch.Tensor):
def forward(self, x: Tensor):
x_shape = x.shape
return self.pool(x.flatten(2).permute(0, 2, 1)).permute(0, 2, 1).view((x_shape[0], -1) + x_shape[2:])

@@ -660,9 +664,9 @@ class DropConnectLinear(base.MemoryModule):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
self.weight = nn.Parameter(Tensor(out_features, in_features))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_features))
self.bias = nn.Parameter(Tensor(out_features))
else:
self.register_parameter('bias', None)

@@ -738,7 +742,7 @@ class DropConnectLinear(base.MemoryModule):
# self.dropped_b = mask_b.to(self.bias) * self.bias
self.dropped_b = self.bias * mask_b

def forward(self, input: torch.Tensor) -> torch.Tensor:
def forward(self, input: Tensor) -> Tensor:
if self.training:
if self.invariant:
if self.dropped_w is None:
@@ -809,21 +813,15 @@ class MultiStepContainer(nn.Sequential):
"""
super().__init__(*args)

def forward(self, x_seq: torch.Tensor):
def forward(self, x_seq: Tensor):
"""
:param x_seq: shape=[T, batch_size, ...]
:type x_seq: torch.Tensor
:type x_seq: Tensor
:return: y_seq, shape=[T, batch_size, ...]
:rtype: torch.Tensor
:rtype: Tensor
"""
y_seq = []
for t in range(x_seq.shape[0]):
y_seq.append(super().forward(x_seq[t]))

for t in range(y_seq.__len__()):
# y_seq[t].unsqueeze_(0)
y_seq[t] = y_seq[t].unsqueeze(0)
return torch.cat(y_seq, 0)
return functional.multi_step_forward(x_seq, self)


class SeqToANNContainer(nn.Sequential):
@@ -869,17 +867,14 @@ class SeqToANNContainer(nn.Sequential):
"""
super().__init__(*args)

def forward(self, x_seq: torch.Tensor):
def forward(self, x_seq: Tensor):
"""
:param x_seq: shape=[T, batch_size, ...]
:type x_seq: torch.Tensor
:type x_seq: Tensor
:return: y_seq, shape=[T, batch_size, ...]
:rtype: torch.Tensor
:rtype: Tensor
"""
y_shape = [x_seq.shape[0], x_seq.shape[1]]
y_seq = super().forward(x_seq.flatten(0, 1))
y_shape.extend(y_seq.shape[1:])
return y_seq.view(y_shape)
return functional.seq_to_ann_forward(x_seq, self)


class STDPLearner(base.MemoryModule):
@@ -970,7 +965,7 @@ class STDPLearner(base.MemoryModule):
self.f_post = f_post

@torch.no_grad()
def stdp(self, s_pre: torch.Tensor, s_post: torch.Tensor, module: nn.Module, learning_rate: float):
def stdp(self, s_pre: Tensor, s_post: Tensor, module: nn.Module, learning_rate: float):
if isinstance(module, nn.Linear):
# update trace
self.trace_pre += - self.trace_pre / self.tau_pre + s_pre
@@ -1009,7 +1004,7 @@ class PrintShapeModule(nn.Module):
super().__init__()
self.ext_str = ext_str

def forward(self, x: torch.Tensor):
def forward(self, x: Tensor):
print(self.ext_str, x.shape)
return x

@@ -1064,20 +1059,20 @@ class ConvBatchNorm2d(nn.Module):
self.bn = nn.BatchNorm2d(num_features=out_channels, eps=eps, momentum=momentum, affine=affine,
track_running_stats=track_running_stats)

def forward(self, x: torch.Tensor):
def forward(self, x: Tensor):
return self.bn(self.conv(x))

def get_fused_weight(self):
"""
:return: the weight of this fused module
:rtype: torch.Tensor
:rtype: Tensor
"""
return functional.fused_conv2d_weight_of_convbn2d(self.conv, self.bn)

def get_fused_bias(self):
"""
:return: the bias of this fused module
:rtype: torch.Tensor
:rtype: Tensor
"""
return functional.fused_conv2d_bias_of_convbn2d(self.conv, self.bn)

@@ -1108,6 +1103,7 @@ class ConvBatchNorm2d(nn.Module):
def get_fused_conv(self):
return functional.fuse_convbn2d(self.conv, self.bn)


class ElementWiseRecurrentContainer(base.MemoryModule):
def __init__(self, sub_module: nn.Module, element_wise_function: Callable):
"""
@@ -1150,7 +1146,7 @@ class ElementWiseRecurrentContainer(base.MemoryModule):
self.element_wise_function = element_wise_function
self.register_memory('y', None)

def forward(self, x: torch.Tensor):
def forward(self, x: Tensor):
if self.y is None:
self.y = torch.zeros_like(x.data)
self.y = self.sub_module(self.element_wise_function(self.y, x))
@@ -1159,6 +1155,7 @@ class ElementWiseRecurrentContainer(base.MemoryModule):
def extra_repr(self) -> str:
return f'element-wise function={self.element_wise_function}'


class LinearRecurrentContainer(base.MemoryModule):
def __init__(self, sub_module: nn.Module, in_features: int, out_features: int, bias: bool = True) -> None:
"""
@@ -1214,7 +1211,7 @@ class LinearRecurrentContainer(base.MemoryModule):
self.sub_module = sub_module
self.register_memory('y', None)

def forward(self, x: torch.Tensor):
def forward(self, x: Tensor):
if self.y is None:
if x.ndim == 2:
self.y = torch.zeros([x.shape[0], self.sub_module_out_features]).to(x)
@@ -1225,4 +1222,211 @@ class LinearRecurrentContainer(base.MemoryModule):
self.y = torch.zeros(out_shape).to(x)
x = torch.cat((x, self.y), dim=-1)
self.y = self.sub_module(self.rc(x))
return self.y
return self.y
class _MultiStepThresholdDependentBatchNormBase(_BatchNorm):
def __init__(self, alpha: float, v_th: float, *args, **kwargs):
super().__init__(*args, **kwargs)
self.alpha = alpha
self.v_th = v_th
assert self.affine, "ThresholdDependentBatchNorm needs to set `affine = True`!"
torch.nn.init.constant_(self.weight, alpha * v_th)

def forward(self, x_seq):
y_shape = [x_seq.shape[0], x_seq.shape[1]]
y = x_seq.flatten(0, 1)
y = super().forward(y)
y_shape.extend(y.shape[1:])
return y.view(y_shape)


class MultiStepThresholdDependentBatchNorm1d(_MultiStepThresholdDependentBatchNormBase):
def __init__(self, alpha: float, v_th: float, *args, **kwargs):
"""
* :ref:`API in English <MultiStepThresholdDependentBatchNorm1d.__init__-en>`

.. _MultiStepThresholdDependentBatchNorm1d.__init__-cn:

:param alpha: 由网络结构决定的超参数
:type alpha: float
:param v_th: 下一个脉冲神经元层的阈值
:type v_th: float

``*args, **kwargs`` 中的参数与 :class:`torch.nn.BatchNorm1d` 的参数相同。

`Going Deeper With Directly-Trained Larger Spiking Neural Networks <https://arxiv.org/abs/2011.05280>`_ 一文提出
的Threshold-Dependent Batch Normalization (tdBN)。

* :ref:`中文API <MultiStepThresholdDependentBatchNorm1d.__init__-cn>`

.. _MultiStepThresholdDependentBatchNorm1d.__init__-en:

:param alpha: the hyper-parameter depending on network structure
:type alpha: float
:param v_th: the threshold of the next spiking neuron layer
:type v_th: float

Other parameters in ``*args, **kwargs`` are the same as those of :class:`torch.nn.BatchNorm1d`.

The Threshold-Dependent Batch Normalization (tdBN) proposed in `Going Deeper With Directly-Trained Larger Spiking Neural Networks <https://arxiv.org/abs/2011.05280>`_.
"""
super().__init__(alpha, v_th, *args, **kwargs)

def _check_input_dim(self, x):
if x.dim() != 2 and x.dim() != 3:
raise ValueError(
f'expected 3D or 4D input with shape [T, N, C] or [T, N, C, M], but got input with shape {x.shape}')


class MultiStepThresholdDependentBatchNorm2d(_MultiStepThresholdDependentBatchNormBase):
def __init__(self, alpha: float, v_th: float, *args, **kwargs):
"""
* :ref:`API in English <MultiStepThresholdDependentBatchNorm2d.__init__-en>`

.. _MultiStepThresholdDependentBatchNorm2d.__init__-cn:

:param alpha: 由网络结构决定的超参数
:type alpha: float
:param v_th: 下一个脉冲神经元层的阈值
:type v_th: float

``*args, **kwargs`` 中的参数与 :class:`torch.nn.BatchNorm2d` 的参数相同。

`Going Deeper With Directly-Trained Larger Spiking Neural Networks <https://arxiv.org/abs/2011.05280>`_ 一文提出
的Threshold-Dependent Batch Normalization (tdBN)。

* :ref:`中文API <MultiStepThresholdDependentBatchNorm2d.__init__-cn>`

.. _MultiStepThresholdDependentBatchNorm2d.__init__-en:

:param alpha: the hyper-parameter depending on network structure
:type alpha: float
:param v_th: the threshold of the next spiking neuron layer
:type v_th: float

Other parameters in ``*args, **kwargs`` are the same as those of :class:`torch.nn.BatchNorm2d`.

The Threshold-Dependent Batch Normalization (tdBN) proposed in `Going Deeper With Directly-Trained Larger Spiking Neural Networks <https://arxiv.org/abs/2011.05280>`_.
"""
super().__init__(alpha, v_th, *args, **kwargs)

def _check_input_dim(self, x):
if x.dim() != 4:
raise ValueError(f'expected 5D input with shape [T, N, C, H, W], but got input with shape {x.shape}')


class MultiStepThresholdDependentBatchNorm3d(_MultiStepThresholdDependentBatchNormBase):
def __init__(self, alpha: float, v_th: float, *args, **kwargs):
"""
* :ref:`API in English <MultiStepThresholdDependentBatchNorm3d.__init__-en>`

.. _MultiStepThresholdDependentBatchNorm3d.__init__-cn:

:param alpha: 由网络结构决定的超参数
:type alpha: float
:param v_th: 下一个脉冲神经元层的阈值
:type v_th: float

``*args, **kwargs`` 中的参数与 :class:`torch.nn.BatchNorm3d` 的参数相同。

`Going Deeper With Directly-Trained Larger Spiking Neural Networks <https://arxiv.org/abs/2011.05280>`_ 一文提出
的Threshold-Dependent Batch Normalization (tdBN)。

* :ref:`中文API <MultiStepThresholdDependentBatchNorm3d.__init__-cn>`

.. _MultiStepThresholdDependentBatchNorm3d.__init__-en:

:param alpha: the hyper-parameter depending on network structure
:type alpha: float
:param v_th: the threshold of the next spiking neuron layer
:type v_th: float

Other parameters in ``*args, **kwargs`` are the same as those of :class:`torch.nn.BatchNorm3d`.

The Threshold-Dependent Batch Normalization (tdBN) proposed in `Going Deeper With Directly-Trained Larger Spiking Neural Networks <https://arxiv.org/abs/2011.05280>`_.
"""
super().__init__(alpha, v_th, *args, **kwargs)

def _check_input_dim(self, x):
if x.dim() != 5:
raise ValueError(f'expected 6D input with shape [T, N, C, D, H, W], but got input with shape {x.shape}')


class MultiStepTemporalWiseAttention(nn.Module):
def __init__(self, T: int, reduction: int = 16, dimension: int = 4):
"""
* :ref:`API in English <MultiStepTemporalWiseAttention.__init__-en>`

.. _MultiStepTemporalWiseAttention.__init__-cn:

:param T: 输入数据的时间步长

:param reduction: 压缩比

:param dimension: 输入数据的维度。当输入数据为[T, N, C, H, W]时, dimension = 4;输入数据维度为[T, N, L]时,dimension = 2。

`Temporal-Wise Attention Spiking Neural Networks for Event Streams Classification <https://openaccess.thecvf.com/content/ICCV2021/html/Yao_Temporal-Wise_Attention_Spiking_Neural_Networks_for_Event_Streams_Classification_ICCV_2021_paper.html>`_ 中提出
的MultiStepTemporalWiseAttention层。MultiStepTemporalWiseAttention层必须放在二维卷积层之后脉冲神经元之前,例如:

``Conv2d -> MultiStepTemporalWiseAttention -> LIF``

输入的尺寸是 ``[T, N, C, H, W]`` 或者 ``[T, N, L]`` ,经过MultiStepTemporalWiseAttention层,输出为 ``[T, N, C, H, W]`` 或者 ``[T, N, L]`` 。

``reduction`` 是压缩比,相当于论文中的 :math:`r`。

* :ref:`中文API <MultiStepTemporalWiseAttention.__init__-cn>`

.. _MultiStepTemporalWiseAttention.__init__-en:

:param T: the number of time-steps of the input

:param reduction: the reduction ratio

:param dimension: the dimension of the input. If the input shape is [T, N, C, H, W], dimension = 4; if the input shape is [T, N, L], dimension = 2.

The MultiStepTemporalWiseAttention layer is proposed in `Temporal-Wise Attention Spiking Neural Networks for Event Streams Classification <https://openaccess.thecvf.com/content/ICCV2021/html/Yao_Temporal-Wise_Attention_Spiking_Neural_Networks_for_Event_Streams_Classification_ICCV_2021_paper.html>`_.

It should be placed after the convolution layer and before the spiking neurons, e.g.,

``Conv2d -> MultiStepTemporalWiseAttention -> LIF``

The input shape is ``[T, N, C, H, W]`` or ``[T, N, L]``; after the MultiStepTemporalWiseAttention layer, the output keeps the same shape.

``reduction`` is the reduction ratio, which is :math:`r` in the paper.

"""
super().__init__()
assert dimension == 4 or dimension == 2, 'dimension must be 4 or 2'

self.dimension = dimension

# Sequence
if self.dimension == 2:
self.avg_pool = nn.AdaptiveAvgPool1d(1)
self.max_pool = nn.AdaptiveMaxPool1d(1)
elif self.dimension == 4:
self.avg_pool = nn.AdaptiveAvgPool3d(1)
self.max_pool = nn.AdaptiveMaxPool3d(1)

assert T >= reduction, 'reduction cannot be greater than T'

# Excitation
self.sharedMLP = nn.Sequential(
nn.Linear(T, T // reduction, bias=False),
nn.ReLU(),
nn.Linear(T // reduction, T, bias=False)
)

self.sigmoid = nn.Sigmoid()

def forward(self, x_seq: torch.Tensor):
assert x_seq.dim() == 3 or x_seq.dim() == 5, ValueError(f'expected 3D or 5D input with shape [T, N, M] or [T, N, C, H, W], but got input with shape {x_seq.shape}')
x_seq = x_seq.transpose(0, 1)
avgout = self.sharedMLP(self.avg_pool(x_seq).view([x_seq.shape[0], x_seq.shape[1]]))
maxout = self.sharedMLP(self.max_pool(x_seq).view([x_seq.shape[0], x_seq.shape[1]]))
scores = self.sigmoid(avgout + maxout)
if self.dimension == 2:
y_seq = x_seq * scores[:, :, None]
elif self.dimension == 4:
y_seq = x_seq * scores[:, :, None, None, None]
y_seq = y_seq.transpose(0, 1)
return y_seq
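The new multi-step layers above keep the leading time dimension convention; a minimal usage sketch (shapes and hyper-parameters chosen only for illustration):

.. code-block:: python

    import torch
    from spikingjelly.clock_driven import layer

    T, N, C, H, W = 8, 2, 16, 32, 32
    x_seq = torch.rand([T, N, C, H, W])

    tdbn = layer.MultiStepThresholdDependentBatchNorm2d(alpha=1., v_th=1., num_features=C)
    ta = layer.MultiStepTemporalWiseAttention(T=T, reduction=4, dimension=4)

    y_seq = ta(tdbn(x_seq))  # both layers preserve the shape [T, N, C, H, W]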

+ 1
- 3
spikingjelly/clock_driven/model/parametric_lif_net.py View File

@@ -15,9 +15,7 @@ class VotingLayer(nn.Module):
self.voting_size = voting_size

def forward(self, x: torch.Tensor):
x.unsqueeze_(1) # [N, C] -> [N, 1, C]
y = F.avg_pool1d(x, self.voting_size, self.voting_size)
y.squeeze_(1)
y = F.avg_pool1d(x.unsqueeze(1), self.voting_size, self.voting_size).squeeze(1)
return y




+ 5
- 2
spikingjelly/clock_driven/model/sew_resnet.py View File

@@ -1,8 +1,11 @@
import torch
import torch.nn as nn
from .. import functional
from torchvision.models.utils import load_state_dict_from_url

try:
from torchvision.models.utils import load_state_dict_from_url
except ImportError:
from torchvision._internally_replaced_utils import load_state_dict_from_url
__all__ = ['SEWResNet', 'sew_resnet18', 'sew_resnet34', 'sew_resnet50', 'sew_resnet101',
'sew_resnet152', 'sew_resnext50_32x4d', 'sew_resnext101_32x8d',
'sew_wide_resnet50_2', 'sew_wide_resnet101_2',


+ 4
- 1
spikingjelly/clock_driven/model/spiking_resnet.py View File

@@ -1,7 +1,10 @@
import torch
import torch.nn as nn
from .. import functional
from torchvision.models.utils import load_state_dict_from_url
try:
from torchvision.models.utils import load_state_dict_from_url
except ImportError:
from torchvision._internally_replaced_utils import load_state_dict_from_url

__all__ = ['SpikingResNet', 'spiking_resnet18', 'spiking_resnet34', 'spiking_resnet50', 'spiking_resnet101',
'spiking_resnet152', 'spiking_resnext50_32x4d', 'spiking_resnext101_32x8d',


+ 12
- 10
spikingjelly/clock_driven/model/spiking_vgg.py View File

@@ -1,8 +1,10 @@
import torch
import torch.nn as nn
from spikingjelly.clock_driven import functional, neuron
from torchvision.models.utils import load_state_dict_from_url

try:
from torchvision.models.utils import load_state_dict_from_url
except ImportError:
from torchvision._internally_replaced_utils import load_state_dict_from_url

__all__ = [
'SpikingVGG', 'MultiStepSpikingVGG',
@@ -356,7 +358,7 @@ def spiking_vgg16(pretrained=False, progress=True, single_step_neuron: callable
A spiking version of VGG-16 model from `"Very Deep Convolutional Networks for Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
"""

return _spiking_vgg('vgg16', 'C', False, pretrained, progress, None, single_step_neuron, **kwargs)
return _spiking_vgg('vgg16', 'D', False, pretrained, progress, None, single_step_neuron, **kwargs)


def multi_step_spiking_vgg16(pretrained=False, progress=True, T: int = None, multi_step_neuron: callable = None, **kwargs):
@@ -377,7 +379,7 @@ def multi_step_spiking_vgg16(pretrained=False, progress=True, T: int = None, mul
A multi-step spiking version of VGG-16 model from `"Very Deep Convolutional Networks for Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
"""

return _multi_step_spiking_vgg('vgg16', 'C', False, pretrained, progress, None, T, multi_step_neuron, **kwargs)
return _multi_step_spiking_vgg('vgg16', 'D', False, pretrained, progress, None, T, multi_step_neuron, **kwargs)


def spiking_vgg16_bn(pretrained=False, progress=True, norm_layer: callable = None, single_step_neuron: callable = None, **kwargs):
@@ -398,7 +400,7 @@ def spiking_vgg16_bn(pretrained=False, progress=True, norm_layer: callable = Non
A spiking version of VGG-16-BN model from `"Very Deep Convolutional Networks for Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
"""

return _spiking_vgg('vgg16', 'C', True, pretrained, progress, norm_layer, single_step_neuron, **kwargs)
return _spiking_vgg('vgg16', 'D', True, pretrained, progress, norm_layer, single_step_neuron, **kwargs)


def multi_step_spiking_vgg16_bn(pretrained=False, progress=True, norm_layer: callable = None, T: int = None, multi_step_neuron: callable = None, **kwargs):
@@ -421,7 +423,7 @@ def multi_step_spiking_vgg16_bn(pretrained=False, progress=True, norm_layer: cal
A multi-step spiking version of VGG-16-BN model from `"Very Deep Convolutional Networks for Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
"""

return _multi_step_spiking_vgg('vgg16', 'C', True, pretrained, progress, norm_layer, T, multi_step_neuron, **kwargs)
return _multi_step_spiking_vgg('vgg16', 'D', True, pretrained, progress, norm_layer, T, multi_step_neuron, **kwargs)


def spiking_vgg19(pretrained=False, progress=True, single_step_neuron: callable = None, **kwargs):
@@ -440,7 +442,7 @@ def spiking_vgg19(pretrained=False, progress=True, single_step_neuron: callable
A spiking version of VGG-19 model from `"Very Deep Convolutional Networks for Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
"""

return _spiking_vgg('vgg19', 'D', False, pretrained, progress, None, single_step_neuron, **kwargs)
return _spiking_vgg('vgg19', 'E', False, pretrained, progress, None, single_step_neuron, **kwargs)


def multi_step_spiking_vgg19(pretrained=False, progress=True, T: int = None, multi_step_neuron: callable = None, **kwargs):
@@ -461,7 +463,7 @@ def multi_step_spiking_vgg19(pretrained=False, progress=True, T: int = None, mul
A multi-step spiking version of VGG-19 model from `"Very Deep Convolutional Networks for Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
"""

return _multi_step_spiking_vgg('vgg19', 'D', False, pretrained, progress, None, T, multi_step_neuron, **kwargs)
return _multi_step_spiking_vgg('vgg19', 'E', False, pretrained, progress, None, T, multi_step_neuron, **kwargs)


def spiking_vgg19_bn(pretrained=False, progress=True, norm_layer: callable = None, single_step_neuron: callable = None, **kwargs):
@@ -482,7 +484,7 @@ def spiking_vgg19_bn(pretrained=False, progress=True, norm_layer: callable = Non
A spiking version of VGG-19-BN model from `"Very Deep Convolutional Networks for Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
"""

return _spiking_vgg('vgg19', 'D', True, pretrained, progress, norm_layer, single_step_neuron, **kwargs)
return _spiking_vgg('vgg19', 'E', True, pretrained, progress, norm_layer, single_step_neuron, **kwargs)


def multi_step_spiking_vgg19_bn(pretrained=False, progress=True, norm_layer: callable = None, T: int = None, multi_step_neuron: callable = None, **kwargs):
@@ -505,5 +507,5 @@ def multi_step_spiking_vgg19_bn(pretrained=False, progress=True, norm_layer: cal
A multi-step spiking version of VGG-19-BN model from `"Very Deep Convolutional Networks for Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
"""

return _multi_step_spiking_vgg('vgg19', 'D', True, pretrained, progress, norm_layer, T, multi_step_neuron, **kwargs)
return _multi_step_spiking_vgg('vgg19', 'E', True, pretrained, progress, norm_layer, T, multi_step_neuron, **kwargs)
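The VGG builders above take the neuron class as a callable; a minimal construction sketch (assuming, as in other SpikingJelly models, that extra keyword arguments such as ``tau`` are forwarded to the neuron constructor):

.. code-block:: python

    from spikingjelly.clock_driven import neuron
    from spikingjelly.clock_driven.model import spiking_vgg

    net = spiking_vgg.spiking_vgg16(single_step_neuron=neuron.LIFNode, tau=2.0)
    print(net)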


+ 565
- 98
spikingjelly/clock_driven/neuron.py View File

@@ -1,15 +1,37 @@
from abc import abstractmethod
from typing import Callable
from typing import Callable, overload
import torch
import torch.nn as nn
from . import surrogate, base
from . import surrogate, base, lava_exchange
from .. import configure
import math
import numpy as np
import logging
try:
import cupy
from . import neuron_kernel, cu_kernel_opt
except ImportError:
except BaseException as e:
logging.info(f'spikingjelly.clock_driven.neuron: {e}')
cupy = None
neuron_kernel = None
cu_kernel_opt = None

try:
import lava.lib.dl.slayer as slayer

except BaseException as e:
logging.info(f'spikingjelly.clock_driven.neuron: {e}')
slayer = None

def check_backend(backend: str):
if backend == 'torch':
return
elif backend == 'cupy':
assert cupy is not None, 'CuPy is not installed! You can install it from "https://github.com/cupy/cupy".'
elif backend == 'lava':
assert slayer is not None, 'Lava-DL is not installed! You can install it from "https://github.com/lava-nc/lava-dl".'
else:
raise NotImplementedError(backend)

class BaseNode(base.MemoryModule):
def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
@@ -60,13 +82,11 @@ class BaseNode(base.MemoryModule):

if v_reset is None:
self.register_memory('v', 0.)
self.register_memory('spike', 0.)
else:
self.register_memory('v', v_reset)
self.register_memory('spike', 0.)

self.v_threshold = v_threshold
self.v_reset = v_reset
self.register_memory('v_threshold', v_threshold)
self.register_memory('v_reset', v_reset)

self.detach_reset = detach_reset
self.surrogate_function = surrogate_function
@@ -105,9 +125,9 @@ class BaseNode(base.MemoryModule):
Calculate out spikes of neurons by their current membrane potential and threshold voltage.
"""

self.spike = self.surrogate_function(self.v - self.v_threshold)
return self.surrogate_function(self.v - self.v_threshold)

def neuronal_reset(self):
def neuronal_reset(self, spike):
"""
* :ref:`API in English <BaseNode.neuronal_reset-en>`

@@ -123,17 +143,17 @@ class BaseNode(base.MemoryModule):
Reset the membrane potential according to neurons' output spikes.
"""
if self.detach_reset:
spike = self.spike.detach()
spike_d = spike.detach()
else:
spike = self.spike
spike_d = spike

if self.v_reset is None:
# soft reset
self.v = self.v - spike * self.v_threshold
self.v = self.v - spike_d * self.v_threshold

else:
# hard reset
self.v = (1. - spike) * self.v + spike * self.v_reset
self.v = (1. - spike_d) * self.v + spike_d * self.v_reset

def extra_repr(self):
return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}'
@@ -167,14 +187,50 @@ class BaseNode(base.MemoryModule):

"""
self.neuronal_charge(x)
self.neuronal_fire()
self.neuronal_reset()
return self.spike
spike = self.neuronal_fire()
self.neuronal_reset(spike)
return spike

class AdaptiveBaseNode(BaseNode):
def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
v_rest: float = 0., w_rest: float = 0., tau_w: float = 2., a: float = 0., b: float = 0.,
surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False):
# b: jump amplitudes
# a: subthreshold coupling
assert isinstance(w_rest, float)
assert isinstance(v_rest, float)
assert isinstance(tau_w, float)
assert isinstance(a, float)
assert isinstance(b, float)

super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)

self.register_memory('w', w_rest)

self.w_rest = w_rest
self.v_rest = v_rest
self.tau_w = tau_w
self.a = a
self.b = b


def neuronal_adaptation(self, spike):
self.w = self.w + 1. / self.tau_w * (self.a * (self.v - self.v_rest) - self.w) + self.b * spike

def extra_repr(self):
return super().extra_repr() + f', v_rest={self.v_rest}, w_rest={self.w_rest}, tau_w={self.tau_w}, a={self.a}, b={self.b}'

def forward(self, x: torch.Tensor):
self.neuronal_charge(x)
spike = self.neuronal_fire()
self.neuronal_adaptation(spike)
self.neuronal_reset(spike)
return spike

class IFNode(BaseNode):
def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False):
surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, cupy_fp32_inference=False):
"""
* :ref:`API in English <IFNode.__init__-en>`

@@ -193,6 +249,9 @@ class IFNode(BaseNode):
:param detach_reset: 是否将reset过程的计算图分离
:type detach_reset: bool

:param cupy_fp32_inference: 若为 `True`,在 `eval` 模式下,使用float32,却在GPU上运行,并且 `cupy` 已经安装,则会自动使用 `cupy` 进行加速
:type cupy_fp32_inference: bool

Integrate-and-Fire 神经元模型,可以看作理想积分器,无输入时电压保持恒定,不会像LIF神经元那样衰减。其阈下神经动力学方程为:

.. math::
@@ -215,21 +274,108 @@ class IFNode(BaseNode):
:param detach_reset: whether detach the computation graph of reset
:type detach_reset: bool

:param cupy_fp32_inference: If `True`, this module will automatically use `cupy` to accelerate inference when it is in `eval` mode, the inputs are float32, it runs on the GPU, and `cupy` is installed
:type cupy_fp32_inference: bool

The Integrate-and-Fire neuron can be seen as an ideal integrator: unlike the LIF neuron, its voltage does not decay. Its subthreshold neural dynamics are described by:

.. math::
V[t] = V[t-1] + X[t]

"""
super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)

if cupy_fp32_inference:
check_backend('cupy')
self.cupy_fp32_inference = cupy_fp32_inference

def neuronal_charge(self, x: torch.Tensor):
self.v = self.v + x

def forward(self, x: torch.Tensor):
if self.cupy_fp32_inference and cupy is not None and not self.training and x.dtype == torch.float32:
# cupy is installed && eval mode && fp32
device_id = x.get_device()
if device_id < 0:
return super().forward(x)

# use cupy to accelerate
if isinstance(self.v, float):
v = torch.zeros_like(x)
if self.v != 0.:
torch.fill_(v, self.v)
self.v = v

if self.v_reset is None:
hard_reset = False
else:
hard_reset = True

code = rf'''
extern "C" __global__
void IFNode_{'hard' if hard_reset else 'soft'}_reset_inference_forward(
const float * x, const float & v_threshold, {'const float & v_reset,' if hard_reset else ''}
float * spike, float * v,
const int & numel)
'''

code += r'''
{
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < numel)
{
v[index] += x[index];
spike[index] = (float) (v[index] >= v_threshold);
'''

code += rf'''
{'v[index] = (1.0f - spike[index]) * v[index] + spike[index] * v_reset;' if hard_reset else 'v[index] -= spike[index] * v_threshold;'}
'''

code += r'''
}
}
'''
if hasattr(self, 'cp_kernel'):
if self.cp_kernel.code != code:
# replace codes
del self.cp_kernel
self.cp_kernel = cupy.RawKernel(code, f"IFNode_{'hard' if hard_reset else 'soft'}_reset_inference_forward", options=configure.cuda_compiler_options, backend=configure.cuda_compiler_backend)
else:
self.cp_kernel = cupy.RawKernel(code, f"IFNode_{'hard' if hard_reset else 'soft'}_reset_inference_forward", options=configure.cuda_compiler_options, backend=configure.cuda_compiler_backend)

with cu_kernel_opt.DeviceEnvironment(device_id):
numel = x.numel()
threads = configure.cuda_threads
blocks = cu_kernel_opt.cal_blocks(numel)
cp_numel = cupy.asarray(numel)
cp_v_threshold = cupy.asarray(self.v_threshold, dtype=np.float32)
if hard_reset:
cp_v_reset = cupy.asarray(self.v_reset, dtype=np.float32)

spike = torch.zeros_like(x)
if hard_reset:
x, cp_v_threshold, cp_v_reset, spike, self.v, cp_numel = cu_kernel_opt.get_contiguous(x, cp_v_threshold, cp_v_reset, spike, self.v, cp_numel)
kernel_args = [x, cp_v_threshold, cp_v_reset, spike, self.v, cp_numel]
else:
x, cp_v_threshold, spike, self.v, cp_numel = cu_kernel_opt.get_contiguous(x, cp_v_threshold, spike, self.v, cp_numel)
kernel_args = [x, cp_v_threshold, spike, self.v, cp_numel]
self.cp_kernel(
(blocks,), (threads,),
cu_kernel_opt.wrap_args_to_raw_kernel(
device_id,
*kernel_args
)
)
return spike
else:
return super().forward(x)

class MultiStepIFNode(IFNode):
def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, backend='torch'):
surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, backend='torch', lava_s_cale=1 << 6):
"""
* :ref:`API in English <MultiStepIFNode.__init__-en>`

@@ -301,24 +447,32 @@ class MultiStepIFNode(IFNode):
super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)

self.register_memory('v_seq', None)
self.register_memory('spike_seq', None)

assert backend == 'torch' or backend == 'cupy'
assert not (backend == 'cupy' and neuron_kernel is None), 'cupy is not installed'
check_backend(backend)

self.backend = backend

self.lava_s_cale = lava_s_cale

if backend == 'lava':
self.lava_neuron = self.to_lava()
else:
self.lava_neuron = None


def forward(self, x_seq: torch.Tensor):
assert x_seq.dim() > 1
# x_seq.shape = [T, *]
self.v_seq = torch.zeros_like(x_seq.data)
self.spike_seq = torch.zeros_like(x_seq.data)

if self.backend == 'torch':
spike_seq = []
self.v_seq = []
for t in range(x_seq.shape[0]):
self.spike_seq[t] = super().forward(x_seq[t])
self.v_seq[t] = self.v
return self.spike_seq
spike_seq.append(super().forward(x_seq[t]).unsqueeze(0))
self.v_seq.append(self.v.unsqueeze(0))
spike_seq = torch.cat(spike_seq, 0)
self.v_seq = torch.cat(self.v_seq, 0)
return spike_seq

elif self.backend == 'cupy':
if isinstance(self.v, float):
@@ -327,27 +481,43 @@ class MultiStepIFNode(IFNode):
if v_init != 0.:
torch.fill_(self.v, v_init)

self.spike_seq, self.v_seq = neuron_kernel.MultiStepIFNodePTT.apply(
spike_seq, self.v_seq = neuron_kernel.MultiStepIFNodePTT.apply(
x_seq.flatten(1), self.v.flatten(0), self.v_threshold, self.v_reset, self.detach_reset, self.surrogate_function.cuda_code)

self.spike_seq = self.spike_seq.reshape(x_seq.shape)
spike_seq = spike_seq.reshape(x_seq.shape)
self.v_seq = self.v_seq.reshape(x_seq.shape)


self.spike = self.spike_seq[-1].clone()
self.v = self.v_seq[-1].clone()

return self.spike_seq
return spike_seq

elif self.backend == 'lava':
if self.lava_neuron is None:
self.lava_neuron = self.to_lava()

spike, self.v = lava_exchange.lava_neuron_forward(self.lava_neuron, x_seq, self.v)

return spike

else:
raise NotImplementedError
raise NotImplementedError(self.backend)

def extra_repr(self):
return super().extra_repr() + f', backend={self.backend}'

def to_lava(self):
return lava_exchange.to_lava_neuron(self)

def reset(self):
super().reset()
if self.lava_neuron is not None:
self.lava_neuron.current_state.zero_()
self.lava_neuron.voltage_state.zero_()

class LIFNode(BaseNode):
def __init__(self, tau: float = 2., v_threshold: float = 1.,
def __init__(self, tau: float = 2., decay_input: bool = True, v_threshold: float = 1.,
v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
detach_reset: bool = False):
detach_reset: bool = False, cupy_fp32_inference=False):
"""
* :ref:`API in English <LIFNode.__init__-en>`

@@ -356,6 +526,9 @@ class LIFNode(BaseNode):
:param tau: 膜电位时间常数
:type tau: float

:param decay_input: 输入是否会衰减
:type decay_input: bool

:param v_threshold: 神经元的阈值电压
:type v_threshold: float

@@ -369,11 +542,24 @@ class LIFNode(BaseNode):
:param detach_reset: 是否将reset过程的计算图分离
:type detach_reset: bool

:param cupy_fp32_inference: 若为 `True`,在 `eval` 模式下,使用float32,却在GPU上运行,并且 `cupy` 已经安装,则会自动使用 `cupy` 进行加速
:type cupy_fp32_inference: bool

Leaky Integrate-and-Fire 神经元模型,可以看作是带漏电的积分器。其阈下神经动力学方程为:

.. math::
V[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset})
若 ``decay_input == True``:

.. math::
V[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset}))

若 ``decay_input == False``:

.. math::
V[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t]

.. tip::

在 `eval` 模式下,使用float32,却在GPU上运行,并且 `cupy` 已经安装,则会自动使用 `cupy` 进行加速。

* :ref:`中文API <LIFNode.__init__-cn>`

@@ -382,6 +568,9 @@ class LIFNode(BaseNode):
:param tau: membrane time constant
:type tau: float

:param decay_input: whether the input will decay
:type decay_input: bool

:param v_threshold: threshold voltage of neurons
:type v_threshold: float

@@ -395,34 +584,159 @@ class LIFNode(BaseNode):
:param detach_reset: whether detach the computation graph of reset
:type detach_reset: bool

:param cupy_fp32_inference: If `True`, this module will automatically use `cupy` to accelerate inference when it is in `eval` mode, the inputs are float32, it runs on the GPU, and `cupy` is installed
:type cupy_fp32_inference: bool

The Leaky Integrate-and-Fire neuron can be seen as a leaky integrator. Its subthreshold neural dynamics are described by:

.. math::
V[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset})
IF ``decay_input == True``:

.. math::
V[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset}))

IF ``decay_input == False``:

.. math::
V[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t]

.. admonition:: Tip
:class: tip

If this module is in `eval` mode, the inputs are float32, it runs on the GPU, and `cupy` is installed, then it will automatically use `cupy` to accelerate inference.

"""
assert isinstance(tau, float) and tau > 1.

super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)
self.tau = tau
self.decay_input = decay_input

if cupy_fp32_inference:
check_backend('cupy')
self.cupy_fp32_inference = cupy_fp32_inference

def extra_repr(self):
return super().extra_repr() + f', tau={self.tau}'

def neuronal_charge(self, x: torch.Tensor):
if self.v_reset is None:
self.v = self.v + (x - self.v) / self.tau

else:
if isinstance(self.v_reset, float) and self.v_reset == 0.:
if self.decay_input:
if self.v_reset is None or self.v_reset == 0.:
self.v = self.v + (x - self.v) / self.tau
else:
self.v = self.v + (x - (self.v - self.v_reset)) / self.tau

else:
if self.v_reset is None or self.v_reset == 0.:
self.v = self.v * (1. - 1. / self.tau) + x
else:
self.v = self.v - (self.v - self.v_reset) / self.tau + x

def forward(self, x: torch.Tensor):
if self.cupy_fp32_inference and cupy is not None and not self.training and x.dtype == torch.float32:
# cupy is installed && eval mode && fp32
device_id = x.get_device()
if device_id < 0:
return super().forward(x)

# use cupy to accelerate
if isinstance(self.v, float):
v = torch.zeros_like(x)
if self.v != 0.:
torch.fill_(v, self.v)
self.v = v

if self.v_reset is None:
hard_reset = False
else:
hard_reset = True

code = rf'''
extern "C" __global__
void LIFNode_{'hard' if hard_reset else 'soft'}_reset_decayInput_{self.decay_input}_inference_forward(
const float * x, const float & v_threshold, {'const float & v_reset,' if hard_reset else ''} const float & tau,
float * spike, float * v,
const int & numel)
'''

code += r'''
{
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < numel)
{
'''

if self.decay_input:
if hard_reset:
code += r'''
v[index] += (x[index] - (v[index] - v_reset)) / tau;
'''
else:
code += r'''
v[index] += (x[index] - v[index]) / tau;
'''
else:
if hard_reset:
code += r'''
v[index] = x[index] + v[index] - (v[index] - v_reset) / tau;
'''
else:
code += r'''
v[index] = x[index] + v[index] * (1.0f - 1.0f / tau);
'''

code += rf'''
spike[index] = (float) (v[index] >= v_threshold);
{'v[index] = (1.0f - spike[index]) * v[index] + spike[index] * v_reset;' if hard_reset else 'v[index] -= spike[index] * v_threshold;'}
'''

code += r'''
}
}
'''
if hasattr(self, 'cp_kernel'):
if self.cp_kernel.code != code:
# replace codes
del self.cp_kernel
self.cp_kernel = cupy.RawKernel(code, f"LIFNode_{'hard' if hard_reset else 'soft'}_reset_decayInput_{self.decay_input}_inference_forward", options=configure.cuda_compiler_options, backend=configure.cuda_compiler_backend)
else:
self.cp_kernel = cupy.RawKernel(code, f"LIFNode_{'hard' if hard_reset else 'soft'}_reset_decayInput_{self.decay_input}_inference_forward", options=configure.cuda_compiler_options, backend=configure.cuda_compiler_backend)

with cu_kernel_opt.DeviceEnvironment(device_id):
numel = x.numel()
threads = configure.cuda_threads
blocks = cu_kernel_opt.cal_blocks(numel)
cp_numel = cupy.asarray(numel)
cp_v_threshold = cupy.asarray(self.v_threshold, dtype=np.float32)
if hard_reset:
cp_v_reset = cupy.asarray(self.v_reset, dtype=np.float32)
cp_tau = cupy.asarray(self.tau, dtype=np.float32)
spike = torch.zeros_like(x)
if hard_reset:
x, cp_v_threshold, cp_v_reset, cp_tau, spike, self.v, cp_numel = cu_kernel_opt.get_contiguous(x, cp_v_threshold, cp_v_reset, cp_tau, spike, self.v, cp_numel)
kernel_args = [x, cp_v_threshold, cp_v_reset, cp_tau, spike, self.v, cp_numel]
else:
x, cp_v_threshold, cp_tau, spike, self.v, cp_numel = cu_kernel_opt.get_contiguous(x, cp_v_threshold, cp_tau, spike, self.v, cp_numel)
kernel_args = [x, cp_v_threshold, cp_tau, spike, self.v, cp_numel]

self.cp_kernel(
(blocks,), (threads,),
cu_kernel_opt.wrap_args_to_raw_kernel(
device_id,
*kernel_args
)
)
return spike
else:
return super().forward(x)

class MultiStepLIFNode(LIFNode):
def __init__(self, tau: float = 2., v_threshold: float = 1.,
def __init__(self, tau: float = 2., decay_input: bool = True, v_threshold: float = 1.,
v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
detach_reset: bool = False, backend='torch'):
detach_reset: bool = False, backend='torch', lava_s_cale=1 << 6):
"""
* :ref:`API in English <MultiStepLIFNode.__init__-en>`

@@ -431,6 +745,9 @@ class MultiStepLIFNode(LIFNode):
:param tau: 膜电位时间常数
:type tau: float

:param decay_input: 输入是否会衰减
:type decay_input: bool

:param v_threshold: 神经元的阈值电压
:type v_threshold: float

@@ -465,6 +782,9 @@ class MultiStepLIFNode(LIFNode):
:param tau: membrane time constant
:type tau: float

:param decay_input: whether the input will decay
:type decay_input: bool

:param v_threshold: threshold voltage of neurons
:type v_threshold: float

@@ -497,25 +817,33 @@ class MultiStepLIFNode(LIFNode):
and multi-step propagation.

"""
super().__init__(tau, v_threshold, v_reset, surrogate_function, detach_reset)
super().__init__(tau, decay_input, v_threshold, v_reset, surrogate_function, detach_reset)
self.register_memory('v_seq', None)
self.register_memory('spike_seq', None)

assert backend == 'torch' or backend == 'cupy'
assert not (backend == 'cupy' and neuron_kernel is None), 'cupy is not installed'
check_backend(backend)
self.backend = backend

self.lava_s_cale = lava_s_cale

if backend == 'lava':
self.lava_neuron = self.to_lava()
else:
self.lava_neuron = None

def forward(self, x_seq: torch.Tensor):
assert x_seq.dim() > 1
# x_seq.shape = [T, *]
self.v_seq = torch.zeros_like(x_seq.data)
self.spike_seq = torch.zeros_like(x_seq.data)

if self.backend == 'torch':
spike_seq = []
self.v_seq = []
for t in range(x_seq.shape[0]):
self.spike_seq[t] = super().forward(x_seq[t])
self.v_seq[t] = self.v
return self.spike_seq
spike_seq.append(super().forward(x_seq[t]).unsqueeze(0))
self.v_seq.append(self.v.unsqueeze(0))
spike_seq = torch.cat(spike_seq, 0)
self.v_seq = torch.cat(self.v_seq, 0)
return spike_seq

elif self.backend == 'cupy':
if isinstance(self.v, float):
@@ -524,25 +852,41 @@ class MultiStepLIFNode(LIFNode):
if v_init != 0.:
torch.fill_(self.v, v_init)

spike_seq, self.v_seq = neuron_kernel.MultiStepLIFNodePTT.apply(
x_seq.flatten(1), self.v.flatten(0), self.decay_input, self.tau, self.v_threshold, self.v_reset, self.detach_reset, self.surrogate_function.cuda_code)

self.spike_seq, self.v_seq = neuron_kernel.MultiStepLIFNodePTT.apply(
x_seq.flatten(1), self.v.flatten(0), self.tau, self.v_threshold, self.v_reset, self.detach_reset, self.surrogate_function.cuda_code)

self.spike_seq = self.spike_seq.reshape(x_seq.shape)
spike_seq = spike_seq.reshape(x_seq.shape)
self.v_seq = self.v_seq.reshape(x_seq.shape)

self.spike = self.spike_seq[-1].clone()
self.v = self.v_seq[-1].clone()

return self.spike_seq
return spike_seq

elif self.backend == 'lava':
if self.lava_neuron is None:
self.lava_neuron = self.to_lava()

spike, self.v = lava_exchange.lava_neuron_forward(self.lava_neuron, x_seq, self.v)

return spike

else:
raise NotImplementedError
raise NotImplementedError(self.backend)

def extra_repr(self):
return super().extra_repr() + f', backend={self.backend}'

def to_lava(self):
return lava_exchange.to_lava_neuron(self)

def reset(self):
super().reset()
if self.lava_neuron is not None:
self.lava_neuron.current_state.zero_()
self.lava_neuron.voltage_state.zero_()

class ParametricLIFNode(BaseNode):
def __init__(self, init_tau: float = 2.0, v_threshold: float = 1.,
def __init__(self, init_tau: float = 2.0, decay_input: bool = True, v_threshold: float = 1.,
v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
detach_reset: bool = False):
"""
@@ -553,6 +897,9 @@ class ParametricLIFNode(BaseNode):
:param init_tau: 膜电位时间常数的初始值
:type init_tau: float

:param decay_input: 输入是否会衰减
:type decay_input: bool

:param v_threshold: 神经元的阈值电压
:type v_threshold: float

@@ -569,8 +916,15 @@ class ParametricLIFNode(BaseNode):
`Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks <https://arxiv.org/abs/2007.05785>`_
提出的 Parametric Leaky Integrate-and-Fire (PLIF)神经元模型,可以看作是带漏电的积分器。其阈下神经动力学方程为:

.. math::
V[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset})
若 ``decay_input == True``:

.. math::
V[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset}))

若 ``decay_input == False``:

.. math::
V[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t]

其中 :math:`\\frac{1}{\\tau} = {\\rm Sigmoid}(w)`,:math:`w` 是可学习的参数。

@@ -581,6 +935,9 @@ class ParametricLIFNode(BaseNode):
:param init_tau: the initial value of membrane time constant
:type init_tau: float

:param decay_input: whether the input will decay
:type decay_input: bool

:param v_threshold: threshold voltage of neurons
:type v_threshold: float

@@ -597,14 +954,22 @@ class ParametricLIFNode(BaseNode):
The Parametric Leaky Integrate-and-Fire (PLIF) neuron, proposed in `Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks <https://arxiv.org/abs/2007.05785>`_, can be seen as a leaky integrator. Its subthreshold neural dynamics are described by:

.. math::
V[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset})
IF ``decay_input == True``:

.. math::
V[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset}))

IF ``decay_input == False``:

.. math::
V[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t]

where :math:`\\frac{1}{\\tau} = {\\rm Sigmoid}(w)`, :math:`w` is a learnable parameter.
"""

assert isinstance(init_tau, float) and init_tau > 1.
super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)
self.decay_input = decay_input
init_w = - math.log(init_tau - 1.)
self.w = nn.Parameter(torch.as_tensor(init_w))

@@ -614,17 +979,19 @@ class ParametricLIFNode(BaseNode):
return super().extra_repr() + f', tau={tau}'

def neuronal_charge(self, x: torch.Tensor):
if self.v_reset is None:
self.v = self.v + (x - self.v) * self.w.sigmoid()
else:
if self.v_reset == 0.:
if self.decay_input:
if self.v_reset is None or self.v_reset == 0.:
self.v = self.v + (x - self.v) * self.w.sigmoid()
else:
self.v = self.v + (x - (self.v - self.v_reset)) * self.w.sigmoid()

else:
if self.v_reset is None or self.v_reset == 0.:
self.v = self.v * (1. - self.w.sigmoid()) + x
else:
self.v = self.v - (self.v - self.v_reset) * self.w.sigmoid() + x

class MultiStepParametricLIFNode(ParametricLIFNode):
def __init__(self, init_tau: float = 2., v_threshold: float = 1.,
def __init__(self, init_tau: float = 2., decay_input: bool = True, v_threshold: float = 1.,
v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
detach_reset: bool = False, backend='torch'):
"""
@@ -635,6 +1002,9 @@ class MultiStepParametricLIFNode(ParametricLIFNode):
:param init_tau: 膜电位时间常数的初始值
:type init_tau: float

:param decay_input: 输入是否会衰减
:type decay_input: bool

:param v_threshold: 神经元的阈值电压
:type v_threshold: float

@@ -672,6 +1042,9 @@ class MultiStepParametricLIFNode(ParametricLIFNode):
:param init_tau: the initial value of membrane time constant
:type init_tau: float

:param decay_input: whether the input will decay
:type decay_input: bool

:param v_threshold: threshold voltage of neurons
:type v_threshold: float

@@ -709,25 +1082,26 @@ class MultiStepParametricLIFNode(ParametricLIFNode):
Read :doc:`Propagation Pattern <./clock_driven_en/10_propagation_pattern>` for more details about single-step
and multi-step propagation.
"""
super().__init__(init_tau, v_threshold, v_reset, surrogate_function, detach_reset)
super().__init__(init_tau, decay_input, v_threshold, v_reset, surrogate_function, detach_reset)
self.register_memory('v_seq', None)
self.register_memory('spike_seq', None)

assert backend == 'torch' or backend == 'cupy'
assert not (backend == 'cupy' and neuron_kernel is None), 'cupy is not installed'
check_backend(backend)
self.backend = backend

def forward(self, x_seq: torch.Tensor):
assert x_seq.dim() > 1
# x_seq.shape = [T, *]
self.v_seq = torch.zeros_like(x_seq.data)
self.spike_seq = torch.zeros_like(x_seq.data)

if self.backend == 'torch':
spike_seq = []
self.v_seq = []
for t in range(x_seq.shape[0]):
self.spike_seq[t] = super().forward(x_seq[t])
self.v_seq[t] = self.v
return self.spike_seq
spike_seq.append(super().forward(x_seq[t]).unsqueeze(0))
self.v_seq.append(self.v.unsqueeze(0))
spike_seq = torch.cat(spike_seq, 0)
self.v_seq = torch.cat(self.v_seq, 0)
return spike_seq

elif self.backend == 'cupy':
if isinstance(self.v, float):
@@ -737,16 +1111,15 @@ class MultiStepParametricLIFNode(ParametricLIFNode):
torch.fill_(self.v, v_init)


self.spike_seq, self.v_seq = neuron_kernel.MultiStepParametricLIFNodePTT.apply(
x_seq.flatten(1), self.v.flatten(0), self.w.sigmoid(), self.v_threshold, self.v_reset, self.detach_reset, self.surrogate_function.cuda_code)
spike_seq, self.v_seq = neuron_kernel.MultiStepParametricLIFNodePTT.apply(
x_seq.flatten(1), self.v.flatten(0), self.w.sigmoid(), self.decay_input, self.v_threshold, self.v_reset, self.detach_reset, self.surrogate_function.cuda_code)

self.spike_seq = self.spike_seq.reshape(x_seq.shape)
spike_seq = spike_seq.reshape(x_seq.shape)
self.v_seq = self.v_seq.reshape(x_seq.shape)

self.spike = self.spike_seq[-1].clone()
self.v = self.v_seq[-1].clone()

return self.spike_seq
return spike_seq
else:
raise NotImplementedError

@@ -1037,23 +1410,24 @@ class MultiStepEIFNode(EIFNode):
super().__init__(tau, delta_T, theta_rh, v_threshold, v_rest, v_reset,
surrogate_function, detach_reset)
self.register_memory('v_seq', None)
self.register_memory('spike_seq', None)

assert backend == 'torch' or backend == 'cupy'
assert not (backend == 'cupy' and neuron_kernel is None), 'cupy is not installed'
check_backend(backend)
self.backend = backend

def forward(self, x_seq: torch.Tensor):
assert x_seq.dim() > 1
# x_seq.shape = [T, *]
self.v_seq = torch.zeros_like(x_seq.data)
self.spike_seq = torch.zeros_like(x_seq.data)

if self.backend == 'torch':
spike_seq = []
self.v_seq = []
for t in range(x_seq.shape[0]):
self.spike_seq[t] = super().forward(x_seq[t])
self.v_seq[t] = self.v
return self.spike_seq
spike_seq.append(super().forward(x_seq[t]).unsqueeze(0))
self.v_seq.append(self.v.unsqueeze(0))
spike_seq = torch.cat(spike_seq, 0)
self.v_seq = torch.cat(self.v_seq, 0)
return spike_seq

elif self.backend == 'cupy':
if isinstance(self.v, float):
@@ -1063,18 +1437,111 @@ class MultiStepEIFNode(EIFNode):
torch.fill_(self.v, v_init)


self.spike_seq, self.v_seq = neuron_kernel.MultiStepEIFNodePTT.apply(
spike_seq, self.v_seq = neuron_kernel.MultiStepEIFNodePTT.apply(
x_seq.flatten(1), self.v.flatten(0), self.tau, self.v_threshold, self.v_reset, self.v_rest, self.theta_rh, self.delta_T, self.detach_reset, self.surrogate_function.cuda_code)

self.spike_seq = self.spike_seq.reshape(x_seq.shape)
spike_seq = spike_seq.reshape(x_seq.shape)
self.v_seq = self.v_seq.reshape(x_seq.shape)

self.v = self.v_seq[-1].clone()

return spike_seq
else:
raise NotImplementedError

def extra_repr(self):
return super().extra_repr() + f', backend={self.backend}'

class GeneralNode(BaseNode):
def __init__(self, a: float or torch.Tensor, b: float or torch.Tensor, c: float or torch.Tensor = 0., v_threshold: float = 1., v_reset: float = 0.,
surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False):
super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)
self.register_buffer('a', torch.as_tensor(a))
self.register_buffer('b', torch.as_tensor(b))
self.register_buffer('c', torch.as_tensor(c))

def neuronal_charge(self, x: torch.Tensor):
self.v = self.a * self.v + self.b * x + self.c

class MultiStepGeneralNode(GeneralNode):
def __init__(self, a: float, b: float, c: float, v_threshold: float = 1., v_reset: float = 0.,
surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, backend='torch'):

super().__init__(a, b, c, v_threshold, v_reset, surrogate_function, detach_reset)

self.register_memory('v_seq', None)

check_backend(backend)

self.backend = backend

def forward(self, x_seq: torch.Tensor):
assert x_seq.dim() > 1
# x_seq.shape = [T, *]

if self.backend == 'torch':
spike_seq = []
self.v_seq = []
for t in range(x_seq.shape[0]):
spike_seq.append(super().forward(x_seq[t]).unsqueeze(0))
self.v_seq.append(self.v.unsqueeze(0))
spike_seq = torch.cat(spike_seq, 0)
self.v_seq = torch.cat(self.v_seq, 0)
return spike_seq

elif self.backend == 'cupy':
if isinstance(self.v, float):
v_init = self.v
self.v = torch.zeros_like(x_seq[0].data)
if v_init != 0.:
torch.fill_(self.v, v_init)

raise NotImplementedError

spike_seq = spike_seq.reshape(x_seq.shape)
self.v_seq = self.v_seq.reshape(x_seq.shape)

self.spike = self.spike_seq[-1].clone()
self.v = self.v_seq[-1].clone()

return self.spike_seq
return spike_seq
else:
raise NotImplementedError

def extra_repr(self):
return super().extra_repr() + f', backend={self.backend}'


class LIAFNode(LIFNode):
def __init__(self, act: Callable, threshold_related: bool, *args, **kwargs):
"""
:param act: the activation function
:type act: Callable
:param threshold_related: whether the neuron uses the threshold-related mode (TR mode). If ``True``, `y = act(h - v_th)`,
otherwise `y = act(h)`
:type threshold_related: bool

Other parameters in `*args, **kwargs` are the same as those of :class:`LIFNode`.

The LIAF neuron proposed in `LIAF-Net: Leaky Integrate and Analog Fire Network for Lightweight and Efficient Spatiotemporal Information Processing <https://arxiv.org/abs/2011.06176>`_.

.. admonition:: Warning
:class: warning

The outputs of this neuron are not binary spikes.

"""
super().__init__(*args, **kwargs)
self.act = act
self.threshold_related = threshold_related

def forward(self, x: torch.Tensor):
self.neuronal_charge(x)
if self.threshold_related:
y = self.act(self.v - self.v_threshold)
else:
y = self.act(self.v)
spike = self.neuronal_fire()
self.neuronal_reset(spike)
return y
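# A hedged usage sketch with illustrative arguments: LIAFNode keeps the LIF spiking dynamics
# internally but returns analog activations, so its outputs are not binary.
import torch
import torch.nn as nn
liaf = LIAFNode(act=nn.ReLU(), threshold_related=True, tau=2.)
x = torch.rand([8, 4])
y = liaf(x)  # y = ReLU(v - v_threshold): real-valued, while v still spikes and resets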



+ 13750
- 1635
spikingjelly/clock_driven/neuron_kernel.cu
File diff suppressed because it is too large
View File


+ 29
- 2
spikingjelly/clock_driven/neuron_kernel.md View File

@@ -85,7 +85,7 @@ $$
$$
## Leaky-Integrate-and-Fire Neuron (LIF Neuron)

For the LIF neuron, the charge function is
For the LIF neuron with decay input, the charge function is
$$
H[t] = V[t - 1] + \frac{1}{\tau}(X[t] - (V[t - 1] - V_{reset}))
$$
@@ -98,9 +98,22 @@ $$
\end{align}
$$

For the LIF neuron without decay input, the charge function is
$$
H[t] = V[t - 1] - \frac{1}{\tau}(V[t - 1] - V_{reset}) + X[t]
$$
Then the gradients are
$$
\begin{align}
\frac{\mathrm{d} L}{\mathrm{d} H[t]} &=\frac{\partial L}{\partial S[t]}\frac{\mathrm{d} S[t]}{\mathrm{d} H[t]} + (\frac{\partial L}{\partial V[t]}+\frac{\mathrm{d} L}{\mathrm{d} H[t+1]}(1 - \frac{1}{\tau}))\frac{\mathrm{d} V[t]}{\mathrm{d} H[t]}\\
\frac{\mathrm{d} L}{\mathrm{d} X[t]} &= \frac{\mathrm{d} L}{\mathrm{d} H[t]}\\
\frac{\mathrm{d} L}{\mathrm{d} V[0]} &= \frac{\mathrm{d} L}{\mathrm{d} H[1]} (1 - \frac{1}{\tau})
\end{align}
$$
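As a quick sanity check, the no-decay charge step can be differentiated with autograd (a minimal sketch, not part of the kernel derivation), reproducing $\frac{\mathrm{d} H[t]}{\mathrm{d} X[t]} = 1$ and $\frac{\mathrm{d} H[t]}{\mathrm{d} V[t-1]} = 1 - \frac{1}{\tau}$:

```python
import torch

tau, v_reset = 2.0, 0.0
v_prev = torch.tensor(0.3, requires_grad=True)
x = torch.tensor(1.0, requires_grad=True)
h = v_prev - (v_prev - v_reset) / tau + x  # H[t] = V[t-1] - (V[t-1] - V_reset)/tau + X[t]
h.backward()
print(x.grad)       # 1.0
print(v_prev.grad)  # 1 - 1/tau = 0.5
```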

## Parametric Leaky-Integrate-and-Fire Neuron (PLIF Neuron)

For the PLIF neuron, the charge function is
For the PLIF neuron with decay input, the charge function is
$$
H[t] = V[t - 1] + \frac{1}{\tau}(X[t] - (V[t - 1] - V_{reset}))
$$
@@ -114,6 +127,20 @@ $$
\end{align}
$$

For the PLIF neuron without decay input, the charge function is
$$
H[t] = V[t - 1] - \frac{1}{\tau}(V[t - 1] - V_{reset}) + X[t]
$$
Then the gradients are
$$
\begin{align}
\frac{\mathrm{d} L}{\mathrm{d} H[t]} &=\frac{\partial L}{\partial S[t]}\frac{\mathrm{d} S[t]}{\mathrm{d} H[t]} + (\frac{\partial L}{\partial V[t]}+\frac{\mathrm{d} L}{\mathrm{d} H[t+1]}(1 - \frac{1}{\tau}))\frac{\mathrm{d} V[t]}{\mathrm{d} H[t]}\\
\frac{\mathrm{d} L}{\mathrm{d} X[t]} &= \frac{\mathrm{d} L}{\mathrm{d} H[t]}\\
\frac{\mathrm{d} L}{\mathrm{d} \frac{1}{\tau}} &= \sum_{t} \frac{\mathrm{d} L}{\mathrm{d} H[t]} (V_{reset} - V[t - 1])\\
\frac{\mathrm{d} L}{\mathrm{d} V[0]} &= \frac{\mathrm{d} L}{\mathrm{d} H[1]} (1 - \frac{1}{\tau})
\end{align}
$$

## Exponential Integrate-and-Fire Neuron (EIF Neuron)

For the EIF neuron, the charge function is


+ 1567
- 1423
spikingjelly/clock_driven/neuron_kernel.py
File diff suppressed because it is too large
View File


+ 506
- 0
spikingjelly/clock_driven/spike_op.py View File

@@ -0,0 +1,506 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.cpp_extension import load_inline
from torch.cuda.amp import custom_fwd, custom_bwd
import logging
from . import tensor_cache

from torch import Tensor
from typing import Optional, Union
from torch.types import _int, _size
from torch.nn.modules.utils import _single, _pair, _triple

try:
import cupy
except BaseException as e:
logging.info(f'spikingjelly.clock_driven.spike_op: {e}')
cupy = None


try:
logging.info('spikingjelly.clock_driven.spike_op: try to use `torch.utils.cpp_extension.load_inline` to load cudnn functions.')
logging.info(f'If it is hanging, please try to delete the torch_extensions cache directory. (In most cases, the directory is {torch.utils.cpp_extension._get_build_directory("", False)}.)')
cpp_wrapper = load_inline(
name='cpp_wrapper',
cpp_sources='using namespace at;',
functions=[
'cudnn_convolution_backward',
'cudnn_convolution_backward_input',
'cudnn_convolution_backward_weight'
],
with_cuda=True
)
except BaseException as e:
logging.info(f'spikingjelly.clock_driven.spike_op: {e}')
cpp_wrapper = None

'''
aten/src/ATen/native/cudnn/ConvPlaceholders.cpp

at::Tensor cudnn_convolution(
const at::Tensor& input, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
int64_t groups, bool benchmark, bool deterministic, bool allow_tf32)

There are two overloaded C++ methods named `cudnn_convolution`, so we need to use an alternative syntax to cast the overloaded function.
Refer to https://pybind11.readthedocs.io/en/stable/classes.html#overloaded-methods and https://github.com/pytorch/pytorch/issues/39518 for more details.
aten/src/ATen/native/cudnn/ConvShared.cpp

Tensor cudnn_convolution_forward(
CheckedFrom c,
const TensorArg& input, const TensorArg& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic, bool allow_tf32)

aten/src/ATen/native/cudnn/ConvPlaceholders.cpp

std::tuple<at::Tensor,at::Tensor> cudnn_convolution_backward(
const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic, bool allow_tf32, std::array<bool,2> output_mask)
aten/src/ATen/native/cudnn/ConvShared.cpp

at::Tensor cudnn_convolution_backward_input(
IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic, bool allow_tf32)
aten/src/ATen/native/cudnn/ConvShared.cpp

at::Tensor cudnn_convolution_backward_weight(
IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic, bool allow_tf32)
'''

class spikeConvolution(torch.autograd.Function):
# PyTorch only provides cudnn_convolution without bias.
# Refer to https://github.com/pytorch/pytorch/issues/3823 for more details.
@staticmethod
@custom_fwd
def forward(ctx, spike, weight, bias, stride, padding, dilation, groups):
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1] or ctx.needs_input_grad[2]:
if ctx.needs_input_grad[1]:
ctx.s_shape = spike.shape
ctx.s_tk = tensor_cache.BOOL_TENSOR_CACHE.store_bool(spike)

if ctx.needs_input_grad[0]:
ctx.save_for_backward(weight)

ctx.padding = padding
ctx.stride = stride
ctx.dilation = dilation
ctx.groups = groups
ctx.weight_shape = weight.shape
ctx.spike_shape = spike.shape

if spike.dim() == 3:
return F.conv1d(input=spike, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
elif spike.dim() == 4:
return F.conv2d(input=spike, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
elif spike.dim() == 5:
return F.conv3d(input=spike, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)



@staticmethod
@custom_bwd
def backward(ctx, grad_output):
grad_spike = None
grad_weight = None
grad_bias = None
if ctx.needs_input_grad[0] and ctx.needs_input_grad[1]:
weight = ctx.saved_tensors[0]
spike = tensor_cache.BOOL_TENSOR_CACHE.get_float(ctx.s_tk, ctx.s_shape)
weight = weight.to(grad_output.dtype)
grad_spike, grad_weight = cpp_wrapper.cudnn_convolution_backward(spike, grad_output, weight, ctx.padding,
ctx.stride, ctx.dilation, ctx.groups,
torch.backends.cudnn.benchmark,
torch.backends.cudnn.deterministic,
torch.backends.cudnn.allow_tf32, (
True,
True))

elif not ctx.needs_input_grad[0] and ctx.needs_input_grad[1]:
spike = tensor_cache.BOOL_TENSOR_CACHE.get_float(ctx.s_tk, ctx.s_shape)
grad_weight = cpp_wrapper.cudnn_convolution_backward_weight(ctx.weight_shape, grad_output, spike, ctx.padding,
ctx.stride, ctx.dilation, ctx.groups,
torch.backends.cudnn.benchmark,
torch.backends.cudnn.deterministic,
torch.backends.cudnn.allow_tf32)

elif ctx.needs_input_grad[0] and not ctx.needs_input_grad[1]:
weight = ctx.saved_tensors[0]
weight = weight.to(grad_output.dtype)
grad_spike = cpp_wrapper.cudnn_convolution_backward_input(ctx.spike_shape, grad_output, weight, ctx.padding,
ctx.stride, ctx.dilation, ctx.groups,
torch.backends.cudnn.benchmark,
torch.backends.cudnn.deterministic,
torch.backends.cudnn.allow_tf32)

if ctx.needs_input_grad[2]:
# grad_output.shape = [N, C, *]
out_channels = grad_output.shape[1]
grad_bias = grad_output.transpose(0, 1).reshape(out_channels, -1).sum(1)
return grad_spike, grad_weight, grad_bias, None, None, None, None

class spikeLinear(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, spike, weight, bias=None):
# spike.shape = [N, *, in_features]
# weight.shape = [out_features, in_features]
# bias.shape = [out_features]
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1] or ctx.needs_input_grad[2]:
if ctx.needs_input_grad[1]:
ctx.s_shape = spike.shape
ctx.s_tk = tensor_cache.BOOL_TENSOR_CACHE.store_bool(spike)
if ctx.needs_input_grad[0]:
ctx.save_for_backward(weight)
return F.linear(spike, weight, bias)

@staticmethod
@custom_bwd
def backward(ctx, grad_output):
# grad_output.shape = [N, *, out_features]
if ctx.needs_input_grad[0]:
weight = ctx.saved_tensors[0]
if ctx.needs_input_grad[1]:
spike = tensor_cache.BOOL_TENSOR_CACHE.get_float(ctx.s_tk, ctx.s_shape)

grad_spike = grad_weight = grad_bias = None

if ctx.needs_input_grad[0]:
grad_spike = F.linear(grad_output, weight.t(), bias=None)
if ctx.needs_input_grad[1]:
in_features = spike.shape[-1]
out_features = grad_output.shape[-1]
# grad_output.reshape(-1, out_features).t().shape = [out_features, N*]
# spike.reshape(-1, in_features).shape = [N*, in_features]
grad_weight = torch.mm(grad_output.reshape(-1, out_features).t(), spike.reshape(-1, in_features).to(grad_output.dtype))
if ctx.needs_input_grad[2]:
out_features = grad_output.shape[-1]
grad_bias = grad_output.reshape(-1, out_features).sum(0)
return grad_spike, grad_weight, grad_bias

def spike_linear(spike: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
"""
* :ref:`API in English <spike_linear-en>`

.. _spike_linear-cn:

:class:`torch.nn.functional.linear` 在输入为脉冲时的特例。

.. note::

在CUDA设备上训练时拥有比 :class:`torch.nn.functional.linear` 更低的显存消耗。

.. warning::

`spike` 中的任何元素都必须为0或1。

* :ref:`中文API <spike_linear-cn>`

.. _spike_linear-en:

A specific case of :class:`torch.nn.functional.linear` whose inputs are spikes.

.. admonition:: Note
:class: note

This function has less memory consumption than :class:`torch.nn.functional.linear` when training on CUDA devices.

.. admonition:: Warning
:class: warning

Any element in `spike` must be 0 or 1.
"""
if spike.get_device() < 0:
return F.linear(spike, weight, bias)
else:
return spikeLinear.apply(spike, weight, bias)

def spike_conv1d(spike: Tensor, weight: Tensor, bias: Tensor=None, stride: Union[_int, _size]=1, padding: str="valid", dilation: Union[_int, _size]=1, groups: _int=1) -> Tensor:
"""
* :ref:`API in English <spike_conv1d-en>`

.. _spike_conv1d-cn:

:class:`torch.nn.functional.conv1d` 在输入为脉冲时的特例。

.. note::

在CUDA设备上训练时拥有比 :class:`torch.nn.functional.conv1d` 更低的显存消耗。

.. warning::

`spike` 中的任何元素都必须为0或1。

* :ref:`中文API <spike_conv1d-cn>`

.. _spike_conv1d-en:

A specific case of :class:`torch.nn.functional.conv1d` whose inputs are spikes.

.. admonition:: Note
:class: note

This function has less memory consumption than :class:`torch.nn.functional.conv1d` when training on CUDA devices.

.. admonition:: Warning
:class: warning

Any element in `spike` must be 0 or 1.
"""
if spike.get_device() < 0:
return F.conv1d(spike, weight, bias, stride, padding, dilation, groups)
else:
return spikeConvolution.apply(spike, weight, bias, stride, padding, dilation, groups)

def spike_conv2d(spike: Tensor, weight: Tensor, bias: Optional[Tensor]=None, stride: Union[_int, _size]=1, padding: str="valid", dilation: Union[_int, _size]=1, groups: _int=1) -> Tensor:
"""
* :ref:`API in English <spike_conv2d-en>`

.. _spike_conv2d-cn:

:class:`torch.nn.functional.conv2d` 在输入为脉冲时的特例。

.. note::

在CUDA设备上训练时拥有比 :class:`torch.nn.functional.conv2d` 更低的显存消耗。

.. warning::

`spike` 中的任何元素都必须为0或1。

* :ref:`中文API <spike_conv2d-cn>`

.. _spike_conv2d-en:

A specific case of :class:`torch.nn.functional.conv2d` whose inputs are spikes.

.. admonition:: Note
:class: note

This function has less memory consumption than :class:`torch.nn.functional.conv2d` when training on CUDA devices.

.. admonition:: Warning
:class: warning

Any element in `spike` must be 0 or 1.
"""
if spike.get_device() < 0:
return F.conv2d(spike, weight, bias, stride, padding, dilation, groups)
else:
return spikeConvolution.apply(spike, weight, bias, stride, padding, dilation, groups)

def spike_conv3d(spike: Tensor, weight: Tensor, bias: Optional[Tensor]=None, stride: Union[_int, _size]=1, padding: str="valid", dilation: Union[_int, _size]=1, groups: _int=1) -> Tensor:
"""
* :ref:`API in English <spike_conv3d-en>`

.. _spike_conv3d-cn:

:class:`torch.nn.functional.conv3d` 在输入为脉冲时的特例。

.. note::

在CUDA设备上训练时拥有比 :class:`torch.nn.functional.conv3d` 更低的显存消耗。

.. warning::

`spike` 中的任何元素都必须为0或1。

* :ref:`中文API <spike_conv3d-cn>`

.. _spike_conv3d-en:

A specific case of :class:`torch.nn.functional.conv3d` whose inputs are spikes.

.. admonition:: Note
:class: note

This function has less memory consumption than :class:`torch.nn.functional.conv3d` when training on CUDA devices.

.. admonition:: Warning
:class: warning

Any element in `spike` must be 0 or 1.
"""
if spike.get_device() < 0:
return F.conv3d(spike, weight, bias, stride, padding, dilation, groups)
else:
return spikeConvolution.apply(spike, weight, bias, stride, padding, dilation, groups)
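# A hedged usage sketch: the memory-saving path requires a CUDA tensor and the cudnn
# cpp_wrapper built above; on CPU these wrappers fall back to torch.nn.functional.
# Inputs must be binary (0 or 1).
spike = (torch.rand([4, 3, 32, 32], device='cuda') >= 0.5).float()
weight = torch.randn([8, 3, 3, 3], device='cuda', requires_grad=True)
y = spike_conv2d(spike, weight, bias=None, stride=1, padding=1)
y.sum().backward()  # the gradient w.r.t. weight is rebuilt from the cached bool spikes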


class SpikeLinear(nn.Linear):
"""
* :ref:`API in English <SpikeLinear-en>`

.. _SpikeLinear-cn:

:class:`torch.nn.Linear` 在输入为脉冲时的特例。

.. note::

在CUDA设备上运行时拥有比 :class:`torch.nn.Linear` 更低的显存消耗。

.. warning::

`spike` 中的任何元素都必须为0或1。

* :ref:`中文API <SpikeLinear-cn>`

.. _SpikeLinear-en:

A specific case of :class:`torch.nn.Linear` whose inputs are spikes.

.. admonition:: Note
:class: note

This module has less memory consumption than :class:`torch.nn.Linear` when training on CUDA devices.

.. admonition:: Warning
:class: warning

Any element in `spike` must be 0 or 1.
"""

def forward(self, spike: Tensor) -> Tensor:
return spike_linear(spike, self.weight, self.bias)


class SpikeConv1d(nn.Conv1d):
"""
* :ref:`API in English <SpikeConv1d-en>`

.. _SpikeConv1d-cn:

:class:`torch.nn.Conv1d` 在输入为脉冲时的特例。

.. note::

在CUDA设备上运行时拥有比 :class:`torch.nn.Conv1d` 更低的显存消耗。

.. warning::

`spike` 中的任何元素都必须为0或1。

* :ref:`中文API <SpikeConv1d-cn>`

.. _SpikeConv1d-en:

A specific case of :class:`torch.nn.Conv1d` whose inputs are spikes.

.. admonition:: Note
:class: note

This module has less memory consumption than :class:`torch.nn.Conv1d` when training on CUDA devices.

.. admonition:: Warning
:class: warning

Any element in `spike` must be 0 or 1.
"""

def _conv_forward(self, spike: Tensor, weight: Tensor, bias: Optional[Tensor]):
if self.padding_mode != 'zeros':
return spike_conv1d(F.pad(spike, self._reversed_padding_repeated_twice, mode=self.padding_mode),
weight, bias, self.stride,
_single(0), self.dilation, self.groups)
return spike_conv1d(spike, weight, bias, self.stride,
self.padding, self.dilation, self.groups)


class SpikeConv2d(nn.Conv2d):
"""
* :ref:`API in English <SpikeConv2d-en>`

.. _SpikeConv2d-cn:

:class:`torch.nn.Conv2d` 在输入为脉冲时的特例。

.. note::

在CUDA设备上运行时拥有比 :class:`torch.nn.Conv2d` 更低的显存消耗。

.. warning::

`spike` 中的任何元素都必须为0或1。

* :ref:`中文API <SpikeConv2d-cn>`

.. _SpikeConv2d-en:

A specific case of :class:`torch.nn.Conv2d` whose inputs are spikes.

.. admonition:: Note
:class: note

This module has less memory consumption than :class:`torch.nn.Conv2d` when training on CUDA devices.

.. admonition:: Warning
:class: warning

Any element in `spike` must be 0 or 1.
"""

def _conv_forward(self, spike: Tensor, weight: Tensor, bias: Optional[Tensor]):
if self.padding_mode != 'zeros':
return spike_conv2d(F.pad(spike, self._reversed_padding_repeated_twice, mode=self.padding_mode),
weight, bias, self.stride,
_pair(0), self.dilation, self.groups)
return spike_conv2d(spike, weight, bias, self.stride,
self.padding, self.dilation, self.groups)


class SpikeConv3d(nn.Conv3d):
"""
* :ref:`API in English <SpikeConv3d-en>`

.. _SpikeConv3d-cn:

:class:`torch.nn.Conv3d` 在输入为脉冲时的特例。

.. note::

在CUDA设备上运行时拥有比 :class:`torch.nn.Conv3d` 更低的显存消耗。

.. warning::

`spike` 中的任何元素都必须为0或1。

* :ref:`中文API <SpikeConv3d-cn>`

.. _SpikeConv3d-en:

A specific case of :class:`torch.nn.Conv3d` whose inputs are spikes.

.. admonition:: Note
:class: note

This module has less memory consumption than :class:`torch.nn.Conv3d` when training on CUDA devices.

.. admonition:: Warning
:class: warning

Any element in `spike` must be 0 or 1.
"""

def _conv_forward(self, spike: Tensor, weight: Tensor, bias: Optional[Tensor]):
if self.padding_mode != "zeros":
return spike_conv3d(
F.pad(
spike, self._reversed_padding_repeated_twice, mode=self.padding_mode
),
weight,
bias,
self.stride,
_triple(0),
self.dilation,
self.groups,
)
return spike_conv3d(
spike, weight, bias, self.stride, self.padding, self.dilation, self.groups
)

+ 333
- 11
spikingjelly/clock_driven/surrogate.py View File

@@ -45,14 +45,12 @@ def heaviside(x: torch.Tensor):
'''
return (x >= 0).to(x)

def check_manual_grad(primitive_function, spiking_function, eps=1e-5):
def check_manual_grad(primitive_function, spiking_function, *args, **kwargs):
'''
:param primitive_function: 梯度替代函数的原函数
:type primitive_function: callable
:param spiking_function: 梯度替代函数
:type spiking_function: callable
:param eps: 最大误差
:type eps: float

梯度替代函数的反向传播一般是手写的,可以用此函数去检查手写梯度是否正确。

@@ -62,18 +60,54 @@ def check_manual_grad(primitive_function, spiking_function, eps=1e-5):

.. code-block:: python

surrogate.check_manual_grad(surrogate.ATan.primitive_function, surrogate.atan.apply)
def s2nn_apply(x, alpha, beta):
return surrogate.s2nn.apply(x, alpha, beta)

surrogate.check_manual_grad(surrogate.S2NN.primitive_function, s2nn_apply, alpha=4., beta=1.)
'''
alpha = torch.tensor(1.0, dtype=torch.float)
x = torch.arange(-16, 16, 32 / 8192)
x = torch.arange(-2, 2, 32 / 8192)
# x = torch.as_tensor([-1., 0., 1.])
x.requires_grad_(True)
primitive_function(x, alpha).sum().backward()
primitive_function(x, *args, **kwargs).sum().backward()
x_grad_auto = x.grad.clone()
x.grad.zero_()
spiking_function(x, alpha).sum().backward()
spiking_function(x, *args, **kwargs).sum().backward()
x_grad_manual = x.grad.clone()
assert (x_grad_manual - x_grad_auto).abs().max().item() <= eps, 'x.grad is wrong!'
print('grad check pass')
print('auto grad', x_grad_auto)
print('manual grad', x_grad_manual)
abs_error = (x_grad_manual - x_grad_auto).abs()
idx = abs_error.argmax()
print('max error', abs_error[idx], 'occurs at')
print(f'x[{idx}] = {x[idx]}')
print('auto grad', x_grad_auto[idx])
print('manual grad', x_grad_manual[idx])

def check_cuda_grad(neu: nn.Module, surrogate_function, device, *args, **kwargs):
# check_cuda_grad(neuron.MultiStepIFNode, surrogate.S2NN, device='cuda:1', alpha=4., beta=1.)
for dtype in [torch.float, torch.half]:
print(dtype)
net = neu(surrogate_function=surrogate_function(*args, **kwargs))
net.to(device)
x = torch.arange(-2, 2, 32 / 8192, device=device, dtype=dtype)
x = x.unsqueeze(-1)
x.requires_grad_(True)
net.backend = 'torch'
net(x).sum().backward()
x_grad_py = x.grad.clone()
x.grad.zero_()
net.reset()
net.backend = 'cupy'
net(x).sum().backward()
x_grad_cp = x.grad.clone()
# print('python grad', x_grad_py)
# print('cupy grad', x_grad_cp)
abs_error = (x_grad_cp - x_grad_py).abs()
idx = abs_error.argmax()
print('max error', abs_error[idx], 'occurs at')
print(f'x[{idx}] = {x[idx]}')
print('python grad', x_grad_py[idx])
print('cupy grad', x_grad_cp[idx])


class SurrogateFunctionBase(nn.Module):
def __init__(self, alpha, spiking=True):
@@ -133,7 +167,8 @@ class piecewise_quadratic(torch.autograd.Function):
@staticmethod
def forward(ctx, x, alpha):
if x.requires_grad:
ctx.save_for_backward(x, alpha)
ctx.save_for_backward(x)
ctx.alpha = alpha
return heaviside(x)

@staticmethod
@@ -1221,5 +1256,292 @@ class SquarewaveFourierSeries(MultiArgsSurrogateFunctionBase):
# plt.savefig('./docs/source/_static/API/clock_driven/surrogate/SquarewaveFourierSeries2.pdf')
# plt.savefig('./docs/source/_static/API/clock_driven/surrogate/SquarewaveFourierSeries2.svg')

class s2nn(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor, alpha: float, beta: float):
if x.requires_grad:
ctx.save_for_backward(x)
ctx.alpha = alpha
ctx.beta = beta
return heaviside(x)

@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
sgax = torch.sigmoid(ctx.alpha * x)
grad_x = torch.where(x < 0., ctx.alpha * sgax * (1. - sgax), ctx.beta / (x + 1.))
return grad_x * grad_output, None, None

class S2NN(MultiArgsSurrogateFunctionBase):
def __init__(self, alpha=4., beta=1., spiking=True):
"""
* :ref:`API in English <S2NN.__init__-en>`
.. _S2NN.__init__-cn:

:param alpha: 控制 ``x < 0`` 时梯度的参数
:param beta: 控制 ``x >= 0`` 时梯度的参数
:param spiking: 是否输出脉冲,默认为 ``True``,在前向传播时使用 ``heaviside`` 而在反向传播使用替代梯度。若为 ``False``
则不使用替代梯度,前向传播时,使用反向传播时的梯度替代函数对应的原函数

`S2NN: Time Step Reduction of Spiking Surrogate Gradients for Training Energy Efficient Single-Step Neural Networks <https://arxiv.org/abs/2201.10879>`_ 提出的S2NN替代函数。反向传播为

.. math::
g'(x) = \\begin{cases}
\\alpha * (1 - \\mathrm{sigmoid} (\\alpha x)) \\mathrm{sigmoid} (\\alpha x), x < 0 \\\\
\\beta (x + 1), x \ge 0
\\end{cases}

对应的原函数为

.. math::
g(x) = \\begin{cases}
\\mathrm{sigmoid} (\\alpha x), x < 0 \\\\
\\beta \\mathrm{ln}(x + 1) + 1, x \ge 0
\\end{cases}

.. image:: ./_static/API/clock_driven/surrogate/S2NN.*
:width: 100%


* :ref:`中文API <S2NN.__init__-cn>`
.. _S2NN.__init__-en:

:param alpha: the parameter that controls the gradient when ``x < 0``
:param beta: the parameter that controls the gradient when ``x >= 0``
:param spiking: whether to output spikes. The default is ``True``, which means using ``heaviside`` in forward
propagation and the surrogate gradient in backward propagation. If ``False``, the forward propagation uses
the primitive function corresponding to the surrogate gradient used in backward propagation

The S2NN surrogate spiking function, which is proposed by `S2NN: Time Step Reduction of Spiking Surrogate Gradients for Training Energy Efficient Single-Step Neural Networks <https://arxiv.org/abs/2201.10879>`_. The gradient is defined by

.. math::
g'(x) = \\begin{cases}
\\alpha * (1 - \\mathrm{sigmoid} (\\alpha x)) \\mathrm{sigmoid} (\\alpha x), x < 0 \\\\
\\beta (x + 1), x \ge 0
\\end{cases}

The primitive function is defined by

.. math::
g(x) = \\begin{cases}
\\mathrm{sigmoid} (\\alpha x), x < 0 \\\\
\\beta \\mathrm{ln}(x + 1) + 1, x \ge 0
\\end{cases}

.. image:: ./_static/API/clock_driven/surrogate/S2NN.*
:width: 100%
"""
super().__init__(spiking)
self.alpha = alpha
self.beta = beta
self.spiking = spiking
if spiking:
self.f = self.spiking_function
else:
self.f = self.primitive_function

def forward(self, x):
return self.f(x, self.alpha, self.beta)

@staticmethod
def spiking_function(x: torch.Tensor, alpha, beta):
return s2nn.apply(x, alpha, beta)

@staticmethod
def primitive_function(x: torch.Tensor, alpha, beta):
return torch.where(x < 0., torch.sigmoid(x * alpha), beta * torch.log((x + 1.).abs_() + 1e-5) + 0.5)
# abs and 1e-5 are used to avoid nan

def cuda_code(self, x: str, y: str, dtype='fp32'):
sg_name = 'sg_' + self._get_name()
alpha = str(self.alpha) + 'f'
beta = str(self.beta) + 'f'
code = f'''
{tab4_str}{self.cuda_code_start_comments()}
'''

if dtype == 'fp32':
code += f'''
{tab4_str}const float {sg_name}_sigmoid_ax = 1.0f / (1.0f + expf(- {alpha} * {x}));
{tab4_str}const float {sg_name}_mask_l = (float)({x} < 0.0f);
{tab4_str}const float {y} = (1.0f - {sg_name}_sigmoid_ax) * {sg_name}_sigmoid_ax * {alpha} * {sg_name}_mask_l + {beta} / ({x} + 1.0f) * (1.0f - {sg_name}_mask_l);
'''
elif dtype == 'fp16':
code += f'''
{tab4_str}const half2 {sg_name}_alpha = __float2half2_rn({alpha});
{tab4_str}const half2 {sg_name}_sigmoid_ax = __h2div(__float2half2_rn(1.0f), __hadd2(h2exp(__hneg2(__hmul2({sg_name}_alpha, {x}))), __float2half2_rn(1.0f)));
{tab4_str}const half2 {sg_name}_mask_l = __hlt2({x}, __float2half2_rn(0.0f));
{tab4_str}const half2 {y} = __hadd2(__hmul2(__hmul2(__hmul2(__hsub2(__float2half2_rn(1.0f), {sg_name}_sigmoid_ax), {sg_name}_sigmoid_ax), {sg_name}_alpha), {sg_name}_mask_l), __hmul2(__h2div(__float2half2_rn({beta}), __hadd2({x}, __float2half2_rn(1.0f))), __hsub2(__float2half2_rn(1.0f), {sg_name}_mask_l)));
'''
else:
raise NotImplementedError
code += f'''
{tab4_str}{self.cuda_code_end_comments()}
'''
return code
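# A hedged usage sketch (user code, outside this module): plug the new S2NN surrogate into
# a LIF neuron and backpropagate through the spiking nonlinearity.
import torch
from spikingjelly.clock_driven import neuron, surrogate
lif = neuron.LIFNode(surrogate_function=surrogate.S2NN(alpha=4., beta=1.))
x = torch.rand([8], requires_grad=True)
lif(x).sum().backward()  # gradients flow through s2nn.backward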

# plt.style.use(['science', 'muted', 'grid'])
# fig = plt.figure(dpi=200, figsize=(6, 4))
# x = torch.arange(-2.5, 2.5, 0.001)
# plt.plot(x.data, surrogate.heaviside(x), label='Heaviside', linestyle='-.')
# surrogate_function = surrogate.S2NN(alpha=4., beta=1., spiking=False)
# y = surrogate_function(x)
# plt.plot(x.data, y.data, label='Primitive, $\\alpha=4, \\beta=1$')
#
# surrogate_function = surrogate.S2NN(alpha=4, beta=1., spiking=True)
# x.requires_grad_(True)
# y = surrogate_function(x)
# z = y.sum()
# z.backward()
# plt.plot(x.data, x.grad, label='Gradient, $\\alpha=4, \\beta=1$')
# plt.xlim(-2, 2)
# plt.legend()
# plt.title('S2NN surrogate function')
# plt.xlabel('Input')
# plt.ylabel('Output')
# plt.grid(linestyle='--')
# # plt.show()
# plt.savefig('./S2NN.svg')
# plt.savefig('./S2NN.pdf')

class q_pseudo_spike(torch.autograd.Function):
@staticmethod
def forward(ctx, x, alpha):
if x.requires_grad:
ctx.save_for_backward(x)
ctx.alpha = alpha
return heaviside(x)

@staticmethod
def backward(ctx, grad_output):
grad_x = None
x = ctx.saved_tensors[0]
if ctx.needs_input_grad[0]:
grad_x = ((1 + 2 / (ctx.alpha - 1) * x.abs()).pow_(-ctx.alpha)) * grad_output
return grad_x, None

class QPseudoSpike(SurrogateFunctionBase):
def __init__(self, alpha=2.0, spiking=True):
'''
* :ref:`API in English <QPseudoSpike.__init__-en>`
.. _QPseudoSpike.__init__-cn:

:param alpha: 控制反向传播时梯度函数尾部厚度的参数
:param spiking: 是否输出脉冲,默认为 ``True``,在前向传播时使用 ``heaviside`` 而在反向传播使用替代梯度。若为 ``False``
则不使用替代梯度,前向传播时,使用反向传播时的梯度替代函数对应的原函数

`Surrogate Gradients Design <https://arxiv.org/abs/2202.00282>`_ 提出的 :math:`q`-PseudoSpike替代函数。反向传播为

.. math::
g'(x) = (1+\\frac{2|x|}{\\alpha-1})^{-\\alpha}

其中 :math:`\\alpha>1` 对应原文中的 :math:`q`。

对应的原函数为

.. math::
g(x) =
\\begin{cases}
\\frac{1}{2}(1-\\frac{2x}{\\alpha-1})^{1-\\alpha}, & x < 0 \\\\
1 - \\frac{1}{2}(1+\\frac{2x}{\\alpha-1})^{1-\\alpha}, & x \\geq 0.
\\end{cases}

.. image:: ./_static/API/clock_driven/surrogate/QPseudoSpike.*
:width: 100%

* :ref:`中文API <QPseudoSpike.__init__-cn>`
.. _QPseudoSpike.__init__-en:

:param alpha: the parameter that controls the tail fatness of the gradient
:param spiking: whether to output spikes. The default is ``True``, which means using ``heaviside`` in forward
propagation and the surrogate gradient in backward propagation. If ``False``, the forward propagation uses
the primitive function corresponding to the surrogate gradient used in backward propagation

The :math:`q`-PseudoSpike surrogate spiking function, which is first proposed in `Surrogate Gradients Design <https://arxiv.org/abs/2202.00282>`_. The gradient is defined by

.. math::
g'(x) = (1+\\frac{2|x|}{\\alpha-1})^{-\\alpha}

where :math:`\\alpha>1` corresponds to :math:`q` in the paper.

The primitive function is defined by

.. math::
g(x) =
\\begin{cases}
\\frac{1}{2}(1-\\frac{2x}{\\alpha-1})^{1-\\alpha}, & x < 0 \\\\
1 - \\frac{1}{2}(1+\\frac{2x}{\\alpha-1})^{1-\\alpha}, & x \\geq 0.
\\end{cases}

.. image:: ./_static/API/clock_driven/surrogate/QPseudoSpike.*
:width: 100%
'''
super().__init__(alpha, spiking)


@staticmethod
def spiking_function(x, alpha):
return q_pseudo_spike.apply(x, alpha)

@staticmethod
def primitive_function(x: torch.Tensor, alpha):
mask_nonnegative = heaviside(x)
mask_sign = mask_nonnegative * 2. - 1.

return mask_nonnegative - mask_sign * (0.5 * ((1. + 2. / (alpha - 1.) * x * mask_sign).pow_(1. - alpha)))

def cuda_code(self, x: str, y: str, dtype='fp32'):
sg_name = 'sg_' + self._get_name()
alpha = str(self.alpha) + 'f'
code = f'''
{tab4_str}{self.cuda_code_start_comments()}
'''

if dtype == 'fp32':
code += f'''
{tab4_str}const float {sg_name}_base = 1.0f + 2.0f / ({alpha} - 1.0f) * fabsf({x});
{tab4_str}const float {y} = powf({sg_name}_base, -{alpha});
'''
elif dtype == 'fp16':
code += f'''
{tab4_str}const half2 {sg_name}_alpha = __float2half2_rn({alpha});
{tab4_str}const half2 {sg_name}_base = __hadd2(__float2half2_rn(1.0f), __h2div(__hmul2(__float2half2_rn(2.0f), __habs2({x})), __hsub2({sg_name}_alpha, __float2half2_rn(1.0f))));
{tab4_str}const half2 {y} = h2exp2(__hmul2(h2log2({sg_name}_base), __hneg2({sg_name}_alpha))); // Replace power with combination of log and exp, since CUDA has no power function for FP16.
'''
else:
raise NotImplementedError
code += f'''
{tab4_str}{self.cuda_code_end_comments()}
'''
return code

# plt.style.use(['science', 'muted', 'grid'])
# fig = plt.figure(dpi=200, figsize=(6, 4))
# x = torch.arange(-2.5, 2.5, 0.001)
# plt.plot(x.data, surrogate.heaviside(x), label='Heaviside', linestyle='-.')
# surrogate_function = surrogate.QPseudoSpike(alpha=2, spiking=False)
# y = surrogate_function(x)
# plt.plot(x.data, y.data, label='Primitive, $\\alpha=2$')

# surrogate_function = surrogate.QPseudoSpike(alpha=2, spiking=True)
# x.requires_grad_(True)
# y = surrogate_function(x)
# z = y.sum()
# z.backward()
# plt.plot(x.data, x.grad, label='Gradient, $\\alpha=2$')
# plt.xlim(-2, 2)
# plt.legend()
# plt.title('QPseudoSpike surrogate function')
# plt.xlabel('Input')
# plt.ylabel('Output')
# plt.grid(linestyle='--')
# # plt.savefig('QPseudoSpike.svg')
# # plt.savefig('QPseudoSpike.pdf')

_has_cuda_ = [
ATan,
Sigmoid,
PiecewiseLeakyReLU,
S2NN,
QPseudoSpike
]
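# A hedged sketch: surrogates listed in _has_cuda_ above expose cuda_code(), which the cupy
# neuron kernels splice into their backward pass; the argument names here are illustrative.
print(QPseudoSpike(alpha=2.0).cuda_code(x='over_th', y='grad_s_to_h', dtype='fp32'))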

+ 212
- 0
spikingjelly/clock_driven/tensor_cache.py View File

@@ -0,0 +1,212 @@
import torch
import torch.nn.functional as F
import threading
from .. import configure
from . import cu_kernel_opt
import logging
try:
import cupy
except BaseException as e:
logging.info(f'spikingjelly.clock_driven.tensor_cache: {e}')
cupy = None

class DataTypeConvertCUDACode:
float2bool = r'''
extern "C" __global__
void float2bool(const float* fs, unsigned char* bs, const int &N)
{
// assert N == numel / 8 and numel % 8 == 0
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < N)
{
bs[index] = 0;
const int mem_offset = (index << 3);
#pragma unroll
for(int i = 0; i < 8; i++)
{
bs[index] += ( ((unsigned char) fs[mem_offset + i]) << i);
}
}
}
'''

half2bool = r'''
#include <cuda_fp16.h>
extern "C" __global__
void half2bool(const half* fs, unsigned char* bs, const int &N)
{
// assert N == numel / 8 and numel % 8 == 0
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < N)
{
bs[index] = 0;
const int mem_offset = (index << 3);
#pragma unroll
for(int i = 0; i < 8; i++)
{
bs[index] += ( ((unsigned char) __half2float(fs[mem_offset + i])) << i);
}
}
}
'''

bool2float = r'''
extern "C" __global__
void bool2float(const unsigned char* bs, float* fs, const int &N)
{
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < N)
{
const int mem_offset = (index << 3);
unsigned char compressed_v = bs[index];
#pragma unroll
for(int i = 0; i < 8; i++)
{
fs[mem_offset + i] = (float) (compressed_v % 2);
compressed_v = (compressed_v >> 1);
}
}
}
'''

bool2half = r'''
#include <cuda_fp16.h>
extern "C" __global__
void bool2half(const unsigned char* bs, half* fs, const int &N)
{
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < N)
{
const int mem_offset = (index << 3);
unsigned char compressed_v = bs[index];
#pragma unroll
for(int i = 0; i < 8; i++)
{
fs[mem_offset + i] = __float2half((float) (compressed_v % 2));
compressed_v = (compressed_v >> 1);
}
}
}
'''
def float_spike_to_bool(spike: torch.Tensor):
s_dtype = spike.dtype
if s_dtype == torch.float:
kernel_codes = DataTypeConvertCUDACode.float2bool
kernel_name = 'float2bool'
elif s_dtype == torch.half:
kernel_codes = DataTypeConvertCUDACode.half2bool
kernel_name = 'half2bool'
else:
raise NotImplementedError

s_shape = spike.shape

spike = spike.flatten()
s_padding = 8 - spike.numel() % 8
if s_padding != 0:
spike = F.pad(spike, (0, s_padding))
device_id = spike.get_device()
spike_b = torch.zeros([spike.numel() // 8], device=spike.device, dtype=torch.uint8)
with cu_kernel_opt.DeviceEnvironment(device_id):
numel = spike_b.numel()
blocks = cu_kernel_opt.cal_blocks(numel)
numel = cupy.asarray(numel)
spike, spike_b, numel = cu_kernel_opt.get_contiguous(spike, spike_b, numel)
kernel_args = [spike, spike_b, numel]
kernel = cupy.RawKernel(
kernel_codes,
kernel_name,
options=configure.cuda_compiler_options, backend=configure.cuda_compiler_backend
)
kernel(
(blocks,), (configure.cuda_threads,),
cu_kernel_opt.wrap_args_to_raw_kernel(
device_id,
*kernel_args
)
)
return spike_b, s_dtype, s_shape, s_padding

def bool_spike_to_float(spike_b: torch.Tensor, s_dtype: torch.dtype, s_shape: torch.Size, s_padding: int = 0):
device_id = spike_b.get_device()
spike = torch.zeros(spike_b.numel() * 8, device=spike_b.device, dtype=s_dtype)
if s_dtype == torch.float:
kernel_codes = DataTypeConvertCUDACode.bool2float
kernel_name = 'bool2float'
elif s_dtype == torch.half:
kernel_codes = DataTypeConvertCUDACode.bool2half
kernel_name = 'bool2half'
else:
raise NotImplementedError
with cu_kernel_opt.DeviceEnvironment(device_id):
numel = spike_b.numel()
blocks = cu_kernel_opt.cal_blocks(numel)
numel = cupy.asarray(numel)
spike_b, spike, numel = cu_kernel_opt.get_contiguous(spike_b, spike, numel)
kernel_args = [spike_b, spike, numel]
kernel = cupy.RawKernel(
kernel_codes,
kernel_name,
options=configure.cuda_compiler_options, backend=configure.cuda_compiler_backend
)
kernel(
(blocks,), (configure.cuda_threads,),
cu_kernel_opt.wrap_args_to_raw_kernel(
device_id,
*kernel_args
)
)
if s_padding is not None and s_padding != 0:
spike = spike[0: spike.numel() - s_padding]
return spike.reshape(s_shape)


def tensor_key(x: torch.Tensor):
x = x.flatten()
return x.data_ptr(), x[-1].data_ptr(), x.numel()

class BoolTensorCache:
def __init__(self):
super().__init__()
self.cache_dict = {}
self.cache_refcount_dict = {}
self.lock = threading.Lock()

def store_bool(self, spike: torch.FloatTensor or torch.HalfTensor):
tk = tensor_key(spike)

self.lock.acquire()
if tk not in self.cache_dict:
if configure.save_bool_spike_level == 0:
self.cache_dict[tk] = (spike.bool(), spike.dtype)
elif configure.save_bool_spike_level == 1:
self.cache_dict[tk] = float_spike_to_bool(spike)
else:
raise NotImplementedError
self.cache_refcount_dict[tk] = 1
else:
self.cache_refcount_dict[tk] += 1
self.lock.release()

return tk

def get_float(self, tk, spike_shape: torch.Size):
if configure.save_bool_spike_level == 0:
spike, s_dtype = self.cache_dict[tk]
spike = spike.to(s_dtype)
elif configure.save_bool_spike_level == 1:
spike = bool_spike_to_float(*self.cache_dict[tk])
else:
raise NotImplementedError

self.lock.acquire()
self.cache_refcount_dict[tk] -= 1
if self.cache_refcount_dict[tk] == 0:
del self.cache_refcount_dict[tk]
del self.cache_dict[tk]
self.lock.release()

return spike.view(spike_shape)


BOOL_TENSOR_CACHE = BoolTensorCache()
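# A hedged round-trip sketch: with the default configure.save_bool_spike_level == 0 the
# cache stores spikes as bool tensors; level 1 additionally packs 8 spikes per uint8
# (CUDA and cupy are required for the packing kernels above).
spike = (torch.rand([2, 8]) >= 0.5).float()
tk = BOOL_TENSOR_CACHE.store_bool(spike)
restored = BOOL_TENSOR_CACHE.get_float(tk, spike.shape)
assert torch.equal(spike, restored)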

+ 28
- 3
spikingjelly/configure.py View File

@@ -1,7 +1,17 @@
# This py file defines some variables used in SpikingJelly.
# The user can change them and install SpikingJelly manually.
'''
This py file defines some variables used in SpikingJelly.
Here is an example of how you can change them to make effect in your codes:

import spikingjelly
spikingjelly.configure.cuda_threads = 512

Do not change them in the following way, which will not take effect:

from spikingjelly.configure import cuda_threads
cuda_threads = 512

max_threads_number_for_datasets_preprocess = 4
'''
max_threads_number_for_datasets_preprocess = 16
'''
`max_threads_number_for_datasets_preprocess` defines the maximum number of threads for dataset preprocessing, which is
1. reading binary events and saving them to numpy format
@@ -41,4 +51,19 @@ If `save_datasets_compressed == True`, events and frames in spikingjelly.dataset
The compressed npz file consumes less memory in disk but more time in reading.
'''

save_spike_as_bool_in_neuron_kernel = False
'''
If `save_spike_as_bool_in_neuron_kernel == True`, the neuron kernel used in the neuron's cupy backend will save spikes as bool rather than float/half tensors for backward, which can reduce memory consumption.
'''

save_bool_spike_level = 0
'''
`save_bool_spike_level` takes effect on SpikeConv/SpikeLinear, and on the neuron's cupy kernel when `save_spike_as_bool_in_neuron_kernel == True`.

If `save_bool_spike_level == 0`, spikes will be saved as bool. Note that bool uses 8 bits, rather than 1 bit.

If `save_bool_spike_level == 1`, spikes will be saved as uint8, with each 8-bit element storing 8 spikes.

A larger `save_bool_spike_level` means less memory consumption but slower speed.
'''
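# A hedged usage sketch: as explained at the top of this file, set these flags on the
# module object before constructing layers that read them.
import spikingjelly
spikingjelly.configure.save_spike_as_bool_in_neuron_kernel = True
spikingjelly.configure.save_bool_spike_level = 1  # pack 8 spikes into each uint8 element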


+ 116
- 119
spikingjelly/datasets/__init__.py View File

@@ -1,6 +1,5 @@
import torchvision.transforms
from torchvision.datasets import DatasetFolder
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
from typing import Callable, Dict, Optional, Tuple
from abc import abstractmethod
import scipy.io
import struct
@@ -10,18 +9,19 @@ import torch.utils.data
import os
from concurrent.futures import ThreadPoolExecutor
import time
import multiprocessing
from torchvision import transforms
import torch
from matplotlib import pyplot as plt
import math
import tqdm
from ..configure import max_threads_number_for_datasets_preprocess, cuda_threads, cuda_compiler_options, cuda_compiler_backend, save_datasets_compressed
np_savez = np.savez_compressed if save_datasets_compressed else np.savez
import shutil
from .. import configure
import logging
np_savez = np.savez_compressed if configure.save_datasets_compressed else np.savez

try:
import cupy
from spikingjelly.clock_driven import cu_kernel_opt
from ..clock_driven import cu_kernel_opt

padded_sequence_mask_kernel_code = r'''
extern "C" __global__
@@ -37,7 +37,8 @@ try:
}
}
'''
except ImportError:
except BaseException as e:
logging.info(f'spikingjelly.dataset.__init__: {e}')
cupy = None
pass

@@ -91,7 +92,6 @@ def load_aedat_v3(file_name: str) -> Dict:
:type file_name: str
:return: a dict whose keys are ``['t', 'x', 'y', 'p']`` and values are ``numpy.ndarray``
:rtype: Dict

This function is written by referring to https://gitlab.com/inivation/dv/dv-python . It can be used for DVS128 Gesture.
'''
with open(file_name, 'rb') as bin_f:
@@ -156,19 +156,12 @@ def load_ATIS_bin(file_name: str) -> Dict:
:type file_name: str
:return: a dict whose keys are ``['t', 'x', 'y', 'p']`` and values are ``numpy.ndarray``
:rtype: Dict

This function is written by referring to https://github.com/jackd/events-tfds .

Each ATIS binary example is a separate binary file consisting of a list of events. Each event occupies 40 bits as described below:

bit 39 - 32: Xaddress (in pixels)

bit 31 - 24: Yaddress (in pixels)

bit 23: Polarity (0 for OFF, 1 for ON)

bit 22 - 0: Timestamp (in microseconds)

'''
with open(file_name, 'rb') as bin_f:
# `& 128` extracts the highest bit of an 8-bit binary number
@@ -191,10 +184,14 @@ def load_npz_frames(file_name: str) -> np.ndarray:
'''
return np.load(file_name, allow_pickle=True)['frames']

def integrate_events_segment_to_frame(events: Dict, H: int, W: int, j_l: int = 0, j_r: int = -1) -> np.ndarray:
def integrate_events_segment_to_frame(x: np.ndarray, y: np.ndarray, p: np.ndarray, H: int, W: int, j_l: int = 0, j_r: int = -1) -> np.ndarray:
'''
:param events: a dict whose keys are ``['t', 'x', 'y', 'p']`` and values are ``numpy.ndarray``
:type events: Dict
:param x: x-coordinate of events
:type x: numpy.ndarray
:param y: y-coordinate of events
:type y: numpy.ndarray
:param p: polarity of events
:type p: numpy.ndarray
:param H: height of the frame
:type H: int
:param W: width of the frame
@@ -205,13 +202,9 @@ def integrate_events_segment_to_frame(events: Dict, H: int, W: int, j_l: int = 0
:type j_r:
:return: frames
:rtype: np.ndarray

Denote a two channels frame as :math:`F` and a pixel at :math:`(p, x, y)` as :math:`F(p, x, y)`, the pixel value is integrated from the events data whose indices are in :math:`[j_{l}, j_{r})`:

.. math::

F(p, x, y) = \sum_{i = j_{l}}^{j_{r} - 1} \mathcal{I}_{p, x, y}(p_{i}, x_{i}, y_{i})

where :math:`\lfloor \cdot \rfloor` is the floor operation, :math:`\mathcal{I}_{p, x, y}(p_{i}, x_{i}, y_{i})` is an indicator function and it equals 1 only when :math:`(p, x, y) = (p_{i}, x_{i}, y_{i})`.
'''
# Spikes must be accumulated with bincount rather than direct addition; for the reason, see the example code below, and
@@ -253,9 +246,9 @@ def integrate_events_segment_to_frame(events: Dict, H: int, W: int, j_l: int = 0
# print('correct accumulation by bincount\n', frames)

frame = np.zeros(shape=[2, H * W])
x = events['x'][j_l: j_r].astype(int) # avoid overflow
y = events['y'][j_l: j_r].astype(int)
p = events['p'][j_l: j_r]
x = x[j_l: j_r].astype(int) # avoid overflow
y = y[j_l: j_r].astype(int)
p = p[j_l: j_r]
mask = []
mask.append(p == 0)
mask.append(np.logical_not(mask[0]))
@@ -275,17 +268,12 @@ def cal_fixed_frames_number_segment_index(events_t: np.ndarray, split_by: str, f
:type frames_num: int
:return: a tuple ``(j_l, j_r)``
:rtype: tuple

Denote ``frames_num`` as :math:`M`, if ``split_by`` is ``'time'``, then

.. math::

\\Delta T & = [\\frac{t_{N-1} - t_{0}}{M}] \\\\
j_{l} & = \\mathop{\\arg\\min}\\limits_{k} \\{t_{k} | t_{k} \\geq t_{0} + \\Delta T \\cdot j\\} \\\\
j_{r} & = \\begin{cases} \\mathop{\\arg\\max}\\limits_{k} \\{t_{k} | t_{k} < t_{0} + \\Delta T \\cdot (j + 1)\\} + 1, & j < M - 1 \\cr N, & j = M - 1 \\end{cases}

If ``split_by`` is ``'number'``, then

.. math::
j_{l} & = [\\frac{N}{M}] \\cdot j \\\\
j_{r} & = \\begin{cases} [\\frac{N}{M}] \\cdot (j + 1), & j < M - 1 \\cr N, & j = M - 1 \\end{cases}
@@ -332,17 +320,19 @@ def integrate_events_by_fixed_frames_number(events: Dict, split_by: str, frames_
:type W: int
:return: frames
:rtype: np.ndarray

Integrate events to frames by fixed frames number. See :class:`cal_fixed_frames_number_segment_index` and :class:`integrate_events_segment_to_frame` for more details.
'''
j_l, j_r = cal_fixed_frames_number_segment_index(events['t'], split_by, frames_num)
t, x, y, p = (events[key] for key in ('t', 'x', 'y', 'p'))
j_l, j_r = cal_fixed_frames_number_segment_index(t, split_by, frames_num)
frames = np.zeros([frames_num, 2, H, W])
for i in range(frames_num):
frames[i] = integrate_events_segment_to_frame(events, H, W, j_l[i], j_r[i])
frames[i] = integrate_events_segment_to_frame(x, y, p, H, W, j_l[i], j_r[i])
return frames

def integrate_events_file_to_frames_file_by_fixed_frames_number(events_np_file: str, output_dir: str, split_by: str, frames_num: int, H: int, W: int, print_save: bool = False) -> None:
def integrate_events_file_to_frames_file_by_fixed_frames_number(loader: Callable, events_np_file: str, output_dir: str, split_by: str, frames_num: int, H: int, W: int, print_save: bool = False) -> None:
'''
:param loader: a function that can load events from `events_np_file`
:type loader: Callable
:param events_np_file: path of the events np file
:type events_np_file: str
:param output_dir: output directory for saving the frames
@@ -358,11 +348,10 @@ def integrate_events_file_to_frames_file_by_fixed_frames_number(events_np_file:
:param print_save: If ``True``, this function will print saved files' paths.
:type print_save: bool
:return: None

Integrate a events file to frames by fixed frames number and save it. See :class:`cal_fixed_frames_number_segment_index` and :class:`integrate_events_segment_to_frame` for more details.
'''
fname = os.path.join(output_dir, os.path.basename(events_np_file))
np_savez(fname, frames=integrate_events_by_fixed_frames_number(np.load(events_np_file), split_by, frames_num, H, W))
np_savez(fname, frames=integrate_events_by_fixed_frames_number(loader(events_np_file), split_by, frames_num, H, W))
if print_save:
print(f'Frames [{fname}] saved.')

@@ -380,10 +369,12 @@ def integrate_events_by_fixed_duration(events: Dict, duration: int, H: int, W: i
:type W: int
:return: frames
:rtype: np.ndarray

Integrate events to frames by fixed time duration of each frame.
'''
x = events['x']
y = events['y']
t = events['t']
p = events['p']
N = t.size

frames = []
@@ -397,15 +388,17 @@ def integrate_events_by_fixed_duration(events: Dict, duration: int, H: int, W: i
else:
right += 1
# integrate from index [left, right)
frames.append(np.expand_dims(integrate_events_segment_to_frame(events, H, W, left, right), 0))
frames.append(np.expand_dims(integrate_events_segment_to_frame(x, y, p, H, W, left, right), 0))

left = right

if right == N:
return np.concatenate(frames)

def integrate_events_file_to_frames_file_by_fixed_duration(events_np_file: str, output_dir: str, duration: int, H: int, W: int, print_save: bool = False) -> None:
def integrate_events_file_to_frames_file_by_fixed_duration(loader: Callable, events_np_file: str, output_dir: str, duration: int, H: int, W: int, print_save: bool = False) -> None:
'''
:param loader: a function that can load events from `events_np_file`
:type loader: Callable
:param events_np_file: path of the events np file
:type events_np_file: str
:param output_dir: output directory for saving the frames
@@ -419,10 +412,9 @@ def integrate_events_file_to_frames_file_by_fixed_duration(events_np_file: str,
:param print_save: If ``True``, this function will print saved files' paths.
:type print_save: bool
:return: None

Integrate events to frames by fixed time duration of each frame.
'''
frames = integrate_events_by_fixed_duration(np.load(events_np_file), duration, H, W)
frames = integrate_events_by_fixed_duration(loader(events_np_file), duration, H, W)
fname, _ = os.path.splitext(os.path.basename(events_np_file))
fname = os.path.join(output_dir, f'{fname}_{frames.shape[0]}.npz')
np_savez(fname, frames=frames)
@@ -441,7 +433,6 @@ def create_same_directory_structure(source_dir: str, target_dir: str) -> None:
:param target_dir: Path of the directory that be copied to
:type target_dir: str
:return: None

Create the same directory structure in ``target_dir`` with that of ``source_dir``.
'''
for sub_dir_name in os.listdir(source_dir):
@@ -492,40 +483,32 @@ def split_to_train_test_set(train_ratio: float, origin_dataset: torch.utils.data

def pad_sequence_collate(batch: list):
'''
:param batch: a list of samples that contains ``(x, y)``, where ``x.shape=[T, *]`` and ``y`` is the label
:param batch: a list of samples that contains ``(x, y)``, where ``x`` is a sequence whose length can differ across samples and ``y`` is the label
:type batch: list
:return: batched samples, where ``x`` is padded with the same length
:return: batched samples ``(x_p, y, x_len)``, where ``x_p`` is the padded ``x`` with the same length, ``y`` is the label, and ``x_len`` is the original length of each ``x``
:rtype: tuple

This function can be used as the ``collate_fn`` for ``DataLoader`` to process the dataset with variable length, e.g., a ``NeuromorphicDatasetFolder`` with fixed duration to integrate events to frames.

Here is an example:

.. code-block:: python

class RandomLengthDataset(torch.utils.data.Dataset):
def __init__(self, n=1000):
super().__init__()
self.n = n

def __getitem__(self, i):
return torch.rand([random.randint(1, 10), 28, 28]), random.randint(0, 10)

def __len__(self):
return self.n

loader = torch.utils.data.DataLoader(RandomLengthDataset(n=32), batch_size=16, collate_fn=pad_sequence_collate)

for x, y, z in loader:
print(x.shape, y.shape, z)

class VariableLengthDataset(torch.utils.data.Dataset):
def __init__(self, n=1000):
super().__init__()
self.n = n
def __getitem__(self, i):
return torch.rand([i + 1, 2]), self.n - i - 1
def __len__(self):
return self.n
loader = torch.utils.data.DataLoader(VariableLengthDataset(n=32), batch_size=2, collate_fn=pad_sequence_collate,
shuffle=True)
for i, (x_p, label, x_len) in enumerate(loader):
print(f'x_p.shape={x_p.shape}, label={label}, x_len={x_len}')
if i == 2:
break
And the outputs are:

.. code-block:: bash

torch.Size([10, 16, 28, 28]) torch.Size([16]) tensor([ 1, 9, 3, 4, 1, 2, 9, 7, 2, 1, 5, 7, 4, 10, 9, 5])
torch.Size([10, 16, 28, 28]) torch.Size([16]) tensor([ 1, 8, 7, 10, 3, 10, 6, 7, 5, 9, 10, 5, 9, 6, 7, 6])

x_p.shape=torch.Size([2, 18, 2]), label=tensor([14, 30]), x_len=tensor([18, 2])
x_p.shape=torch.Size([2, 29, 2]), label=tensor([3, 6]), x_len=tensor([29, 26])
x_p.shape=torch.Size([2, 23, 2]), label=tensor([ 9, 23]), x_len=tensor([23, 9])
'''
x_list = []
x_len_list = []
@@ -545,11 +528,8 @@ def padded_sequence_mask(sequence_len: torch.Tensor, T=None):
:type T: int
:return: a bool mask with shape = [T, N], where the padded position is ``False``
:rtype: torch.Tensor

Here is an example:

.. code-block:: python

x1 = torch.rand([2, 6])
x2 = torch.rand([3, 6])
x3 = torch.rand([4, 6])
@@ -559,11 +539,8 @@ def padded_sequence_mask(sequence_len: torch.Tensor, T=None):
mask = padded_sequence_mask(x_len)
print('mask.shape=', mask.shape)
print('mask=\\n', mask)

And the outputs are:

.. code-block:: bash

x.shape= torch.Size([4, 3, 6])
mask.shape= torch.Size([4, 3])
mask=
@@ -571,7 +548,6 @@ def padded_sequence_mask(sequence_len: torch.Tensor, T=None):
[ True, True, True],
[False, True, True],
[False, False, True]])

'''
if T is None:
T = sequence_len.max().item()
@@ -580,15 +556,15 @@ def padded_sequence_mask(sequence_len: torch.Tensor, T=None):

if device_id >= 0 and cupy is not None:
mask = torch.zeros([T, N], dtype=bool, device=sequence_len.device)
with cupy.cuda.Device(device_id):
with cu_kernel_opt.DeviceEnvironment(device_id):
T = cupy.asarray(T)
N = cupy.asarray(N)
sequence_len, mask, T, N = cu_kernel_opt.get_contiguous(sequence_len.to(torch.int), mask, T, N)
kernel_args = [sequence_len, mask, T, N]
kernel = cupy.RawKernel(padded_sequence_mask_kernel_code, 'padded_sequence_mask_kernel', options=cuda_compiler_options, backend=cuda_compiler_backend)
kernel = cupy.RawKernel(padded_sequence_mask_kernel_code, 'padded_sequence_mask_kernel', options=configure.cuda_compiler_options, backend=configure.cuda_compiler_backend)
blocks = cu_kernel_opt.cal_blocks(N)
kernel(
(blocks,), (cuda_threads,),
(blocks,), (configure.cuda_threads,),
cu_kernel_opt.wrap_args_to_raw_kernel(
device_id,
*kernel_args
@@ -645,27 +621,20 @@ class NeuromorphicDatasetFolder(DatasetFolder):
:param target_transform: a function/transform that takes
in the target and transforms it.
:type target_transform: callable

The base class for neuromorphic datasets. Users can define a new dataset by inheriting this class and implementing
all abstract methods. Users can refer to :class:`spikingjelly.datasets.dvs128_gesture.DVS128Gesture`.

If ``data_type == 'event'``
the sample in this dataset is a dict whose keys are ``['t', 'x', 'y', 'p']`` and values are ``numpy.ndarray``.

If ``data_type == 'frame'`` and ``frames_number`` is not ``None``
events will be integrated to frames with fixed frames number. ``split_by`` will define how to split events.
See :class:`cal_fixed_frames_number_segment_index` for
more details.

If ``data_type == 'frame'``, ``frames_number`` is ``None``, and ``duration`` is not ``None``
events will be integrated to frames with fixed time duration.

If ``data_type == 'frame'``, ``frames_number`` is ``None``, ``duration`` is ``None``, and ``custom_integrate_function`` is not ``None``:
events will be integrated by the user-defined function and saved to the ``custom_integrated_frames_dir_name`` directory in ``root`` directory.
Here is an example from SpikingJelly's tutorials:

.. code-block:: python

from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
from typing import Dict
import numpy as np
@@ -673,13 +642,12 @@ class NeuromorphicDatasetFolder(DatasetFolder):
def integrate_events_to_2_frames_randomly(events: Dict, H: int, W: int):
index_split = np.random.randint(low=0, high=events['t'].__len__())
frames = np.zeros([2, 2, H, W])
frames[0] = sjds.integrate_events_segment_to_frame(events, H, W, 0, index_split)
frames[1] = sjds.integrate_events_segment_to_frame(events, H, W, index_split, events['t'].__len__())
t, x, y, p = (events[key] for key in ('t', 'x', 'y', 'p'))
frames[0] = sjds.integrate_events_segment_to_frame(x, y, p, H, W, 0, index_split)
frames[1] = sjds.integrate_events_segment_to_frame(x, y, p, H, W, index_split, events['t'].__len__())
return frames

root_dir = 'D:/datasets/DVS128Gesture'
train_set = DVS128Gesture(root_dir, train=True, data_type='frame', custom_integrate_function=integrate_events_to_2_frames_randomly)

from spikingjelly.datasets import play_frame
frame, label = train_set[500]
play_frame(frame)
@@ -773,7 +741,7 @@ class NeuromorphicDatasetFolder(DatasetFolder):

# use multi-thread to accelerate
t_ckp = time.time()
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), max_threads_number_for_datasets_preprocess)) as tpe:
with ThreadPoolExecutor(max_workers=configure.max_threads_number_for_datasets_preprocess) as tpe:
print(f'Start ThreadPoolExecutor with max workers = [{tpe._max_workers}].')
for e_root, e_dirs, e_files in os.walk(events_np_root):
if e_files.__len__() > 0:
@@ -781,7 +749,7 @@ class NeuromorphicDatasetFolder(DatasetFolder):
for e_file in e_files:
events_np_file = os.path.join(e_root, e_file)
print(f'Start to integrate [{events_np_file}] to frames and save to [{output_dir}].')
tpe.submit(integrate_events_file_to_frames_file_by_fixed_frames_number, events_np_file, output_dir, split_by, frames_number, H, W, True)
tpe.submit(integrate_events_file_to_frames_file_by_fixed_frames_number, self.load_events_np, events_np_file, output_dir, split_by, frames_number, H, W, True)

print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')

@@ -803,7 +771,7 @@ class NeuromorphicDatasetFolder(DatasetFolder):
create_same_directory_structure(events_np_root, frames_np_root)
# use multi-thread to accelerate
t_ckp = time.time()
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), max_threads_number_for_datasets_preprocess)) as tpe:
with ThreadPoolExecutor(max_workers=configure.max_threads_number_for_datasets_preprocess) as tpe:
print(f'Start ThreadPoolExecutor with max workers = [{tpe._max_workers}].')
for e_root, e_dirs, e_files in os.walk(events_np_root):
if e_files.__len__() > 0:
@@ -811,7 +779,7 @@ class NeuromorphicDatasetFolder(DatasetFolder):
for e_file in e_files:
events_np_file = os.path.join(e_root, e_file)
print(f'Start to integrate [{events_np_file}] to frames and save to [{output_dir}].')
tpe.submit(integrate_events_file_to_frames_file_by_fixed_duration, events_np_file, output_dir, duration, H, W, True)
tpe.submit(integrate_events_file_to_frames_file_by_fixed_duration, self.load_events_np, events_np_file, output_dir, duration, H, W, True)

print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')

@@ -834,7 +802,7 @@ class NeuromorphicDatasetFolder(DatasetFolder):
create_same_directory_structure(events_np_root, frames_np_root)
# use multi-thread to accelerate
t_ckp = time.time()
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), max_threads_number_for_datasets_preprocess)) as tpe:
with ThreadPoolExecutor(max_workers=configure.max_threads_number_for_datasets_preprocess) as tpe:
print(f'Start ThreadPoolExecutor with max workers = [{tpe._max_workers}].')
for e_root, e_dirs, e_files in os.walk(events_np_root):
if e_files.__len__() > 0:
@@ -865,19 +833,6 @@ class NeuromorphicDatasetFolder(DatasetFolder):
super().__init__(root=_root, loader=_loader, extensions=('.npz', ), transform=_transform,
target_transform=_target_transform)

@staticmethod
@abstractmethod
def load_origin_data(file_name: str) -> Dict:
'''
:param file_name: path of the events file
:type file_name: str
:return: a dict whose keys are ``['t', 'x', 'y', 'p']`` and values are ``numpy.ndarray``
:rtype: Dict

This function defines how to read the origin binary data.
'''
pass

@staticmethod
@abstractmethod
def resource_url_md5() -> list:
@@ -905,7 +860,6 @@ class NeuromorphicDatasetFolder(DatasetFolder):
:param extract_root: Root directory path which saves extracted files from downloaded files
:type extract_root: str
:return: None

This function defines how to extract the downloaded files.
'''
pass
@@ -919,7 +873,6 @@ class NeuromorphicDatasetFolder(DatasetFolder):
:param events_np_root: Root directory path which saves events files in the ``npz`` format
:type events_np_root:
:return: None

This function defines how to convert the original binary data in ``extract_root`` to the ``npz`` format and save the converted files in ``events_np_root``.
'''
pass
@@ -934,6 +887,16 @@ class NeuromorphicDatasetFolder(DatasetFolder):
'''
pass

@staticmethod
def load_events_np(fname: str):
'''
:param fname: file name
:type fname: str
:return: a dict whose keys are ``['t', 'x', 'y', 'p']`` and values are ``numpy.ndarray``
:rtype: Dict
This function defines how to load a sample from ``events_np``. In most cases, this function is ``np.load``.
But for some datasets, e.g., ES-ImageNet, it can be different.
'''
return np.load(fname)
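The docstring above notes that ``load_events_np`` can be overridden when a dataset does not store its ``events_np`` samples in the standard layout. A minimal, hypothetical sketch (not part of this PR; the key names ``'ts', 'xs', 'ys', 'ps'`` are invented for illustration) of such an override could look like:

.. code-block:: python

    import numpy as np
    from spikingjelly.datasets import NeuromorphicDatasetFolder

    class MyEventDataset(NeuromorphicDatasetFolder):
        # hypothetical subclass: only load_events_np is sketched here,
        # the other abstract methods are omitted for brevity
        @staticmethod
        def load_events_np(fname: str):
            data = np.load(fname)
            # remap the invented key names to the standard ['t', 'x', 'y', 'p'] dict
            return {'t': data['ts'], 'x': data['xs'], 'y': data['ys'], 'p': data['ps']}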



def random_temporal_delete(x_seq: torch.Tensor or np.ndarray, T_remain: int, batch_first):
@@ -946,13 +909,9 @@ def random_temporal_delete(x_seq: torch.Tensor or np.ndarray, T_remain: int, bat
:type batch_first: bool
:return: the sequence with length `T_remain`, which is obtained by randomly removing `T - T_remain` slices
:rtype: torch.Tensor or np.ndarray

The random temporal delete data augmentation used in `Deep Residual Learning in Spiking Neural Networks <https://arxiv.org/abs/2102.04159>`_.

Code example:

.. code-block:: python

import torch
from spikingjelly.datasets import random_temporal_delete
T = 8
@@ -961,11 +920,8 @@ def random_temporal_delete(x_seq: torch.Tensor or np.ndarray, T_remain: int, bat
x_seq = torch.arange(0, N*T).view([N, T])
print('x_seq=\\n', x_seq)
print('random_temporal_delete(x_seq)=\\n', random_temporal_delete(x_seq, T_remain, batch_first=True))

Outputs:

.. code-block:: shell

x_seq=
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15],
@@ -994,9 +950,7 @@ class RandomTemporalDelete(torch.nn.Module):
:type T_remain: int
:param batch_first: if `True`, `x_seq` will be regarded as `shape = [N, T, *]`

The random temporal delete data augmentation used in `Deep Residual Learning in Spiking Neural Networks <https://arxiv.org/abs/2102.04159>`_.

Refer to :class:`random_temporal_delete` for more details.
"""
super().__init__()
@@ -1007,3 +961,46 @@ class RandomTemporalDelete(torch.nn.Module):
return random_temporal_delete(x_seq, self.T_remain, self.batch_first)
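As a quick illustration of the module form above, here is a hedged usage sketch (the tensor shape is an assumption for illustration; ``RandomTemporalDelete`` is used as defined in this file):

.. code-block:: python

    import torch
    from spikingjelly.datasets import RandomTemporalDelete

    # a fake frame sequence with shape [N, T, C, H, W]
    x_seq = torch.rand([4, 8, 2, 128, 128])
    rtd = RandomTemporalDelete(T_remain=6, batch_first=True)
    # 8 - 6 = 2 time steps are removed at random
    print(rtd(x_seq).shape)  # expected: torch.Size([4, 6, 2, 128, 128])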


def create_sub_dataset(source_dir: str, target_dir:str, ratio: float, use_soft_link=True, randomly=False):
"""
:param source_dir: the directory path of the original dataset
:type source_dir: str
:param target_dir: the directory path of the sub dataset
:type target_dir: str
:param ratio: the ratio of samples that the sub dataset will copy from the original dataset
:type ratio: float
:param use_soft_link: if ``True``, samples in the sub dataset will be soft links to the original files; otherwise, the files will be copied
:type use_soft_link: bool
:param randomly: if ``True``, the files copied from the original dataset will be picked randomly. The randomness is controlled by
``numpy.random.seed``
:type randomly: bool
Create a sub dataset by copying ``ratio`` of the samples from the original dataset.
"""
if not os.path.exists(target_dir):
os.makedirs(target_dir)
print(f'Mkdir [{target_dir}].')
create_same_directory_structure(source_dir, target_dir)
warnings_info = []
for e_root, e_dirs, e_files in os.walk(source_dir, followlinks=True):
if e_files.__len__() > 0:
output_dir = os.path.join(target_dir, os.path.relpath(e_root, source_dir))
samples_number = int(ratio * e_files.__len__())
if samples_number == 0:
warnings_info.append(f'Warning: the number of samples is 0 in [{output_dir}].')
if randomly:
np.random.shuffle(e_files)
for i, e_file in enumerate(e_files):
if i >= samples_number:
break
source_file = os.path.join(e_root, e_file)
target_file = os.path.join(output_dir, os.path.basename(source_file))
if use_soft_link:
os.symlink(source_file, target_file)
# print(f'symlink {source_file} -> {target_file}')
else:
shutil.copyfile(source_file, target_file)
# print(f'copyfile {source_file} -> {target_file}')
print(f'[{samples_number}] files in [{e_root}] have been copied to [{output_dir}].')

for i in range(warnings_info.__len__()):
print(warnings_info[i])
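A usage sketch for ``create_sub_dataset`` (the paths below are placeholders): build a 10% subset of an existing ``events_np`` directory with soft links, picking files at random.

.. code-block:: python

    import numpy as np
    from spikingjelly.datasets import create_sub_dataset

    np.random.seed(0)  # the random pick is controlled by numpy's global seed
    create_sub_dataset(
        source_dir='/datasets/DVS128Gesture/events_np',        # placeholder path
        target_dir='/datasets/DVS128Gesture_small/events_np',  # placeholder path
        ratio=0.1,
        use_soft_link=True,
        randomly=True,
    )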

+ 9
- 68
spikingjelly/datasets/asl_dvs.py View File

@@ -1,5 +1,4 @@
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
import numpy as np
from typing import Callable, Dict, Optional, Tuple
import spikingjelly.datasets as sjds
from torchvision.datasets.utils import extract_archive
import os
@@ -7,8 +6,8 @@ import multiprocessing
from concurrent.futures import ThreadPoolExecutor
import time
import shutil
from ..configure import max_threads_number_for_datasets_preprocess
from spikingjelly.datasets import np_savez
from .. import configure
from ..datasets import np_savez

class ASLDVS(sjds.NeuromorphicDatasetFolder):
def __init__(
@@ -23,69 +22,11 @@ class ASLDVS(sjds.NeuromorphicDatasetFolder):
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
'''
:param root: root path of the dataset
:type root: str
:param data_type: `event` or `frame`
:type data_type: str
:param frames_number: the integrated frame number
:type frames_number: int
:param split_by: `time` or `number`
:type split_by: str
:param duration: the time duration of each frame
:type duration: int
:param custom_integrate_function: a user-defined function that inputs are ``events, H, W``.
``events`` is a dict whose keys are ``['t', 'x', 'y', 'p']`` and values are ``numpy.ndarray``
``H`` is the height of the data and ``W`` is the weight of the data.
For example, H=128 and W=128 for the DVS128 Gesture dataset.
The user should define how to integrate events to frames, and return frames.
:type custom_integrate_function: Callable
:param custom_integrated_frames_dir_name: The name of directory for saving frames integrating by ``custom_integrate_function``.
If ``custom_integrated_frames_dir_name`` is ``None``, it will be set to ``custom_integrate_function.__name__``
:type custom_integrated_frames_dir_name: str or None
:param transform: a function/transform that takes in
a sample and returns a transformed version.
E.g, ``transforms.RandomCrop`` for images.
:type transform: callable
:param target_transform: a function/transform that takes
in the target and transforms it.
:type target_transform: callable

If ``data_type == 'event'``
the sample in this dataset is a dict whose keys are ``['t', 'x', 'y', 'p']`` and values are ``numpy.ndarray``.

If ``data_type == 'frame'`` and ``frames_number`` is not ``None``
events will be integrated to frames with fixed frames number. ``split_by`` will define how to split events.
See :class:`spikingjelly.datasets.cal_fixed_frames_number_segment_index` for
more details.

If ``data_type == 'frame'``, ``frames_number`` is ``None``, and ``duration`` is not ``None``
events will be integrated to frames with fixed time duration.

If ``data_type == 'frame'``, ``frames_number`` is ``None``, ``duration`` is ``None``, and ``custom_integrate_function`` is not ``None``:
events will be integrated by the user-defined function and saved to the ``custom_integrated_frames_dir_name`` directory in ``root`` directory.
Here is an example from SpikingJelly's tutorials:

.. code-block:: python

from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
from typing import Dict
import numpy as np
import spikingjelly.datasets as sjds
def integrate_events_to_2_frames_randomly(events: Dict, H: int, W: int):
index_split = np.random.randint(low=0, high=events['t'].__len__())
frames = np.zeros([2, 2, H, W])
frames[0] = sjds.integrate_events_segment_to_frame(events, H, W, 0, index_split)
frames[1] = sjds.integrate_events_segment_to_frame(events, H, W, index_split, events['t'].__len__())
return frames

root_dir = 'D:/datasets/DVS128Gesture'
train_set = DVS128Gesture(root_dir, train=True, data_type='frame', custom_integrate_function=integrate_events_to_2_frames_randomly)

from spikingjelly.datasets import play_frame
frame, label = train_set[500]
play_frame(frame)
'''
"""
The ASL-DVS dataset, which is proposed by `Graph-based Object Classification for Neuromorphic Vision Sensing <https://openaccess.thecvf.com/content_ICCV_2019/html/Bi_Graph-Based_Object_Classification_for_Neuromorphic_Vision_Sensing_ICCV_2019_paper.html>`_.

Refer to :class:`spikingjelly.datasets.NeuromorphicDatasetFolder` for more details about the parameters.
"""
super().__init__(root, None, data_type, frames_number, split_by, duration, custom_integrate_function, custom_integrated_frames_dir_name, transform,
target_transform)
@staticmethod
@@ -179,7 +120,7 @@ class ASLDVS(sjds.NeuromorphicDatasetFolder):
This function defines how to convert the original binary data in ``extract_root`` to the ``npz`` format and save the converted files in ``events_np_root``.
'''
t_ckp = time.time()
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), max_threads_number_for_datasets_preprocess)) as tpe:
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), configure.max_threads_number_for_datasets_preprocess)) as tpe:
for class_name in os.listdir(extract_root):
mat_dir = os.path.join(extract_root, class_name)
np_dir = os.path.join(events_np_root, class_name)


+ 10
- 67
spikingjelly/datasets/cifar10_dvs.py View File

@@ -1,13 +1,13 @@
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
from typing import Callable, Dict, Optional, Tuple
import numpy as np
import spikingjelly.datasets as sjds
from .. import datasets as sjds
from torchvision.datasets.utils import extract_archive
import os
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
import time
from ..configure import max_threads_number_for_datasets_preprocess
from spikingjelly.datasets import np_savez
from .. import configure
from ..datasets import np_savez
# https://github.com/jackd/events-tfds/blob/master/events_tfds/data_io/aedat.py


@@ -119,69 +119,12 @@ class CIFAR10DVS(sjds.NeuromorphicDatasetFolder):
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
'''
:param root: root path of the dataset
:type root: str
:param data_type: `event` or `frame`
:type data_type: str
:param frames_number: the integrated frame number
:type frames_number: int
:param split_by: `time` or `number`
:type split_by: str
:param duration: the time duration of each frame
:type duration: int
:param custom_integrate_function: a user-defined function that inputs are ``events, H, W``.
``events`` is a dict whose keys are ``['t', 'x', 'y', 'p']`` and values are ``numpy.ndarray``
``H`` is the height of the data and ``W`` is the weight of the data.
For example, H=128 and W=128 for the DVS128 Gesture dataset.
The user should define how to integrate events to frames, and return frames.
:type custom_integrate_function: Callable
:param custom_integrated_frames_dir_name: The name of directory for saving frames integrating by ``custom_integrate_function``.
If ``custom_integrated_frames_dir_name`` is ``None``, it will be set to ``custom_integrate_function.__name__``
:type custom_integrated_frames_dir_name: str or None
:param transform: a function/transform that takes in
a sample and returns a transformed version.
E.g, ``transforms.RandomCrop`` for images.
:type transform: callable
:param target_transform: a function/transform that takes
in the target and transforms it.
:type target_transform: callable

If ``data_type == 'event'``
the sample in this dataset is a dict whose keys are ``['t', 'x', 'y', 'p']`` and values are ``numpy.ndarray``.

If ``data_type == 'frame'`` and ``frames_number`` is not ``None``
events will be integrated to frames with fixed frames number. ``split_by`` will define how to split events.
See :class:`spikingjelly.datasets.cal_fixed_frames_number_segment_index` for
more details.

If ``data_type == 'frame'``, ``frames_number`` is ``None``, and ``duration`` is not ``None``
events will be integrated to frames with fixed time duration.

If ``data_type == 'frame'``, ``frames_number`` is ``None``, ``duration`` is ``None``, and ``custom_integrate_function`` is not ``None``:
events will be integrated by the user-defined function and saved to the ``custom_integrated_frames_dir_name`` directory in ``root`` directory.
Here is an example from SpikingJelly's tutorials:
"""
The CIFAR10-DVS dataset, which is proposed by `CIFAR10-DVS: An Event-Stream Dataset for Object Classification
<https://internal-journal.frontiersin.org/articles/10.3389/fnins.2017.00309/full>`_.

.. code-block:: python

from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
from typing import Dict
import numpy as np
import spikingjelly.datasets as sjds
def integrate_events_to_2_frames_randomly(events: Dict, H: int, W: int):
index_split = np.random.randint(low=0, high=events['t'].__len__())
frames = np.zeros([2, 2, H, W])
frames[0] = sjds.integrate_events_segment_to_frame(events, H, W, 0, index_split)
frames[1] = sjds.integrate_events_segment_to_frame(events, H, W, index_split, events['t'].__len__())
return frames

root_dir = 'D:/datasets/DVS128Gesture'
train_set = DVS128Gesture(root_dir, train=True, data_type='frame', custom_integrate_function=integrate_events_to_2_frames_randomly)

from spikingjelly.datasets import play_frame
frame, label = train_set[500]
play_frame(frame)
'''
Refer to :class:`spikingjelly.datasets.NeuromorphicDatasetFolder` for more details about the parameters.
"""
super().__init__(root, None, data_type, frames_number, split_by, duration, custom_integrate_function, custom_integrated_frames_dir_name, transform,
target_transform)
@staticmethod
@@ -283,7 +226,7 @@ class CIFAR10DVS(sjds.NeuromorphicDatasetFolder):
This function defines how to convert the original binary data in ``extract_root`` to the ``npz`` format and save the converted files in ``events_np_root``.
'''
t_ckp = time.time()
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), max_threads_number_for_datasets_preprocess)) as tpe:
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), configure.max_threads_number_for_datasets_preprocess)) as tpe:
for class_name in os.listdir(extract_root):
aedat_dir = os.path.join(extract_root, class_name)
np_dir = os.path.join(events_np_root, class_name)


+ 9
- 69
spikingjelly/datasets/dvs128_gesture.py View File

@@ -1,13 +1,13 @@
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
from typing import Callable, Dict, Optional, Tuple
import numpy as np
import spikingjelly.datasets as sjds
from .. import datasets as sjds
from torchvision.datasets.utils import extract_archive
import os
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
import time
from ..configure import max_threads_number_for_datasets_preprocess
from spikingjelly.datasets import np_savez
from .. import configure
from ..datasets import np_savez

class DVS128Gesture(sjds.NeuromorphicDatasetFolder):
def __init__(
@@ -23,71 +23,11 @@ class DVS128Gesture(sjds.NeuromorphicDatasetFolder):
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
'''
:param root: root path of the dataset
:type root: str
:param train: whether use the train set
:type train: bool
:param data_type: `event` or `frame`
:type data_type: str
:param frames_number: the integrated frame number
:type frames_number: int
:param split_by: `time` or `number`
:type split_by: str
:param duration: the time duration of each frame
:type duration: int
:param custom_integrate_function: a user-defined function that inputs are ``events, H, W``.
``events`` is a dict whose keys are ``['t', 'x', 'y', 'p']`` and values are ``numpy.ndarray``
``H`` is the height of the data and ``W`` is the weight of the data.
For example, H=128 and W=128 for the DVS128 Gesture dataset.
The user should define how to integrate events to frames, and return frames.
:type custom_integrate_function: Callable
:param custom_integrated_frames_dir_name: The name of directory for saving frames integrating by ``custom_integrate_function``.
If ``custom_integrated_frames_dir_name`` is ``None``, it will be set to ``custom_integrate_function.__name__``
:type custom_integrated_frames_dir_name: str or None
:param transform: a function/transform that takes in
a sample and returns a transformed version.
E.g, ``transforms.RandomCrop`` for images.
:type transform: callable
:param target_transform: a function/transform that takes
in the target and transforms it.
:type target_transform: callable

If ``data_type == 'event'``
the sample in this dataset is a dict whose keys are ``['t', 'x', 'y', 'p']`` and values are ``numpy.ndarray``.

If ``data_type == 'frame'`` and ``frames_number`` is not ``None``
events will be integrated to frames with fixed frames number. ``split_by`` will define how to split events.
See :class:`spikingjelly.datasets.cal_fixed_frames_number_segment_index` for
more details.

If ``data_type == 'frame'``, ``frames_number`` is ``None``, and ``duration`` is not ``None``
events will be integrated to frames with fixed time duration.

If ``data_type == 'frame'``, ``frames_number`` is ``None``, ``duration`` is ``None``, and ``custom_integrate_function`` is not ``None``:
events will be integrated by the user-defined function and saved to the ``custom_integrated_frames_dir_name`` directory in ``root`` directory.
Here is an example from SpikingJelly's tutorials:
"""
The DVS128 Gesture dataset, which is proposed by `A Low Power, Fully Event-Based Gesture Recognition System <https://openaccess.thecvf.com/content_cvpr_2017/html/Amir_A_Low_Power_CVPR_2017_paper.html>`_.

.. code-block:: python

from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
from typing import Dict
import numpy as np
import spikingjelly.datasets as sjds
def integrate_events_to_2_frames_randomly(events: Dict, H: int, W: int):
index_split = np.random.randint(low=0, high=events['t'].__len__())
frames = np.zeros([2, 2, H, W])
frames[0] = sjds.integrate_events_segment_to_frame(events, H, W, 0, index_split)
frames[1] = sjds.integrate_events_segment_to_frame(events, H, W, index_split, events['t'].__len__())
return frames

root_dir = 'D:/datasets/DVS128Gesture'
train_set = DVS128Gesture(root_dir, train=True, data_type='frame', custom_integrate_function=integrate_events_to_2_frames_randomly)

from spikingjelly.datasets import play_frame
frame, label = train_set[500]
play_frame(frame)
'''
Refer to :class:`spikingjelly.datasets.NeuromorphicDatasetFolder` for more details about the parameters.
"""
assert train is not None
super().__init__(root, train, data_type, frames_number, split_by, duration, custom_integrate_function, custom_integrated_frames_dir_name, transform, target_transform)
@staticmethod
@@ -234,7 +174,7 @@ class DVS128Gesture(sjds.NeuromorphicDatasetFolder):
os.path.join(aedat_dir, 'trials_to_test.txt')) as trials_to_test_txt:
# use multi-thread to accelerate
t_ckp = time.time()
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), max_threads_number_for_datasets_preprocess)) as tpe:
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), configure.max_threads_number_for_datasets_preprocess)) as tpe:
print(f'Start the ThreadPoolExecutor with max workers = [{tpe._max_workers}].')

for fname in trials_to_train_txt.readlines():


+ 217
- 0
spikingjelly/datasets/es_imagenet.py View File

@@ -0,0 +1,217 @@
from typing import Callable, Dict, Optional, Tuple
import numpy as np
from .. import datasets as sjds
import os
import rarfile
import time


def load_events(fname: str):
events = np.load(fname)
e_pos = events['pos']
e_neg = events['neg']
e_pos = np.hstack((e_pos, np.ones((e_pos.shape[0], 1))))
e_neg = np.hstack((e_neg, np.zeros((e_neg.shape[0], 1))))
events = np.vstack((e_pos, e_neg)) # shape = [N, 4], N * (x, y, t, p)
idx = np.argsort(events[:, 2])
events = events[idx]
return {
'x': events[:, 1],
'y': events[:, 0],
't': events[:, 2],
'p': events[:, 3]
}


class ESImageNet(sjds.NeuromorphicDatasetFolder):
def __init__(
self,
root: str,
train: bool = None,
data_type: str = 'event',
frames_number: int = None,
split_by: str = None,
duration: int = None,
custom_integrate_function: Callable = None,
custom_integrated_frames_dir_name: str = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
"""
The ES-ImageNet dataset, which is proposed by `ES-ImageNet: A Million Event-Stream Classification Dataset for Spiking Neural Networks <https://www.frontiersin.org/articles/10.3389/fnins.2021.726582/full>`_.

Refer to :class:`spikingjelly.datasets.NeuromorphicDatasetFolder` for more details about the parameters.
"""
assert train is not None
super().__init__(root, train, data_type, frames_number, split_by, duration, custom_integrate_function, custom_integrated_frames_dir_name, transform, target_transform)

if data_type == 'event':
self.loader = load_events

@staticmethod
def load_events_np(fname: str):
return load_events(fname)

@staticmethod
def resource_url_md5() -> list:
'''
:return: A list ``url`` where ``url[i]`` is a tuple containing the i-th file's name, download link, and MD5
:rtype: list
'''
urls = [
('ES-imagenet-0.18.part01.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part01.rar&dl=1',
'900bdd57b5641f7d81cd4620283fef76'),
('ES-imagenet-0.18.part02.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part02.rar&dl=1',
'5982532009e863a8f4e18e793314c54b'),
('ES-imagenet-0.18.part03.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part03.rar&dl=1',
'8f408c1f5a1d4604e48d0d062a8289a0'),
('ES-imagenet-0.18.part04.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part04.rar&dl=1',
'5c5b5cf0a55954eb639964e3da510097'),
('ES-imagenet-0.18.part05.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part05.rar&dl=1',
'51feb661b4c9fa87860b63e76b914673'),
('ES-imagenet-0.18.part06.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part06.rar&dl=1',
'fcd007a2b17b7c13f338734c53f6db31'),
('ES-imagenet-0.18.part07.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part07.rar&dl=1',
'd3e74b96d9c5df15714bbc3abcd329fc'),
('ES-imagenet-0.18.part08.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part08.rar&dl=1',
'65b9cf7fa63e18d2e7d92ff45a42a5e5'),
('ES-imagenet-0.18.part09.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part09.rar&dl=1',
'241c9a37a83ff9efd305fe46d012211e'),
('ES-imagenet-0.18.part10.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part10.rar&dl=1',
'ceee96971008e30d0cdc34086c49fd75'),
('ES-imagenet-0.18.part11.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part11.rar&dl=1',
'4fbfefbe6e48758fbb72427c81f119cf'),
('ES-imagenet-0.18.part12.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part12.rar&dl=1',
'c8cc163be4e5f6451201dccbded4ec24'),
('ES-imagenet-0.18.part13.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part13.rar&dl=1',
'08c9dff32f6b42c49ef7cd78e37c728e'),
('ES-imagenet-0.18.part14.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part14.rar&dl=1',
'43aa157dc5bd5fcea81315a46e0322cf'),
('ES-imagenet-0.18.part15.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part15.rar&dl=1',
'480a69b050f465ef01efcc44ae29f7df'),
('ES-imagenet-0.18.part16.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part16.rar&dl=1',
'11abd24d92b93e7f85acd63abd4a18ab'),
('ES-imagenet-0.18.part17.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part17.rar&dl=1',
'3891486a6862c63a325c5f16cd01fdd1'),
('ES-imagenet-0.18.part18.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part18.rar&dl=1',
'cf8bb0525b514f411bca9d7c2d681f7c'),
('ES-imagenet-0.18.part19.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part19.rar&dl=1',
'3766bc35572ccacc03f0f293c571d0ae'),
('ES-imagenet-0.18.part20.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part20.rar&dl=1',
'bf73a5e338644122220e41da7b5630e6'),
('ES-imagenet-0.18.part21.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part21.rar&dl=1',
'564de4a2609cbb0bb67ffa1bc51f2487'),
('ES-imagenet-0.18.part22.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part22.rar&dl=1',
'60a9e52db1acadfccc9a9809073f0b04'),
('ES-imagenet-0.18.part23.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part23.rar&dl=1',
'373b5484826d40d7ec35f0e1605cb6ea'),
('ES-imagenet-0.18.part24.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part24.rar&dl=1',
'a50612e889b20f99cc7b2725dfd72e9e'),
('ES-imagenet-0.18.part25.rar',
'https://cloud.tsinghua.edu.cn/d/94873ab4ec2a4eb497b3/files/?p=%2FES-imagenet-0.18.part25.rar&dl=1',
'0802ccdeb0cff29237faf55164524101')
]

return urls


@staticmethod
def downloadable() -> bool:
'''
:return: Whether the dataset can be downloaded directly by Python code. If not, the user has to download it manually
:rtype: bool
'''
return True

@staticmethod
def extract_downloaded_files(download_root: str, extract_root: str):
'''
:param download_root: Root directory path which saves downloaded dataset files
:type download_root: str
:param extract_root: Root directory path which saves extracted files from downloaded files
:type extract_root: str
:return: None

This function defines how to extract the downloaded files.
'''
rar_file = os.path.join(download_root, 'ES-imagenet-0.18.part01.rar')
print(f'Extract [{rar_file}] to [{extract_root}].')
rar_file = rarfile.RarFile(rar_file)
rar_file.extractall(extract_root)
rar_file.close()



@staticmethod
def create_events_np_files(extract_root: str, events_np_root: str):
'''
:param extract_root: Root directory path which saves extracted files from downloaded files
:type extract_root: str
:param events_np_root: Root directory path which saves events files in the ``npz`` format
:type events_np_root:
:return: None

This function defines how to convert the original binary data in ``extract_root`` to the ``npz`` format and save the converted files in ``events_np_root``.
'''
t_ckp = time.time()
train_dir = os.path.join(events_np_root, 'train')
os.mkdir(train_dir)
print(f'Mkdir [{train_dir}].')
sjds.create_same_directory_structure(os.path.join(extract_root, 'ES-imagenet-0.18/train'), train_dir)
for class_dir in os.listdir(os.path.join(extract_root, 'ES-imagenet-0.18/train')):
source_dir = os.path.join(extract_root, 'ES-imagenet-0.18/train', class_dir)
target_dir = os.path.join(train_dir, class_dir)
print(f'Create soft links from [{source_dir}] to [{target_dir}].')
for class_sample in os.listdir(source_dir):
os.symlink(os.path.join(source_dir, class_sample),
os.path.join(target_dir, class_sample))




val_label = np.loadtxt(os.path.join(extract_root, 'ES-imagenet-0.18/vallabel.txt'), delimiter=' ', usecols=(1, ), dtype=int)
val_fname = np.loadtxt(os.path.join(extract_root, 'ES-imagenet-0.18/vallabel.txt'), delimiter=' ', usecols=(0, ), dtype=str)
source_dir = os.path.join(extract_root, 'ES-imagenet-0.18/val')
target_dir = os.path.join(events_np_root, 'test')
os.mkdir(target_dir)
print(f'Mkdir [{target_dir}].')
sjds.create_same_directory_structure(train_dir, target_dir)

for i in range(val_fname.__len__()):
os.symlink(os.path.join(source_dir, val_fname[i]), os.path.join(target_dir, f'class{val_label[i]}/{val_fname[i]}'))

print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')
print(f'Note that files in [{events_np_root}] are soft links whose source files are in [{extract_root}]. If you want to use events, do not delete [{extract_root}].')

@staticmethod
def get_H_W() -> Tuple:
'''
:return: A tuple ``(H, W)``, where ``H`` is the height of the data and ``W`` is the width of the data.
For example, this function returns ``(128, 128)`` for the DVS128 Gesture dataset.
:rtype: tuple
'''
return 256, 256
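A hedged usage sketch for ``ESImageNet`` (the root path is a placeholder and the printed shape is the expected one, not verified here): load the train set as frames integrated with a fixed frames number.

.. code-block:: python

    from spikingjelly.datasets.es_imagenet import ESImageNet

    root_dir = '/datasets/ES-ImageNet'  # placeholder path
    train_set = ESImageNet(root_dir, train=True, data_type='frame',
                           frames_number=8, split_by='number')
    frame, label = train_set[0]
    print(frame.shape)  # expected: (8, 2, 256, 256), since get_H_W() returns (256, 256)

Note that, as ``create_events_np_files`` prints, the ``events_np`` files of this dataset are soft links into the extracted directory, so the extracted files should not be deleted if events are still needed.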

+ 10
- 69
spikingjelly/datasets/n_caltech101.py View File

@@ -1,13 +1,12 @@
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
import numpy as np
import spikingjelly.datasets as sjds
from typing import Callable, Dict, Optional, Tuple
from .. import datasets as sjds
from torchvision.datasets.utils import extract_archive
import os
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
import time
from ..configure import max_threads_number_for_datasets_preprocess
from spikingjelly.datasets import np_savez
from .. import configure
from ..datasets import np_savez

class NCaltech101(sjds.NeuromorphicDatasetFolder):
def __init__(
@@ -22,69 +21,11 @@ class NCaltech101(sjds.NeuromorphicDatasetFolder):
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
'''
:param root: root path of the dataset
:type root: str
:param data_type: `event` or `frame`
:type data_type: str
:param frames_number: the integrated frame number
:type frames_number: int
:param split_by: `time` or `number`
:type split_by: str
:param duration: the time duration of each frame
:type duration: int
:param custom_integrate_function: a user-defined function that inputs are ``events, H, W``.
``events`` is a dict whose keys are ``['t', 'x', 'y', 'p']`` and values are ``numpy.ndarray``
``H`` is the height of the data and ``W`` is the weight of the data.
For example, H=128 and W=128 for the DVS128 Gesture dataset.
The user should define how to integrate events to frames, and return frames.
:type custom_integrate_function: Callable
:param custom_integrated_frames_dir_name: The name of directory for saving frames integrating by ``custom_integrate_function``.
If ``custom_integrated_frames_dir_name`` is ``None``, it will be set to ``custom_integrate_function.__name__``
:type custom_integrated_frames_dir_name: str or None
:param transform: a function/transform that takes in
a sample and returns a transformed version.
E.g, ``transforms.RandomCrop`` for images.
:type transform: callable
:param target_transform: a function/transform that takes
in the target and transforms it.
:type target_transform: callable

If ``data_type == 'event'``
the sample in this dataset is a dict whose keys are ``['t', 'x', 'y', 'p']`` and values are ``numpy.ndarray``.

If ``data_type == 'frame'`` and ``frames_number`` is not ``None``
events will be integrated to frames with fixed frames number. ``split_by`` will define how to split events.
See :class:`spikingjelly.datasets.cal_fixed_frames_number_segment_index` for
more details.

If ``data_type == 'frame'``, ``frames_number`` is ``None``, and ``duration`` is not ``None``
events will be integrated to frames with fixed time duration.

If ``data_type == 'frame'``, ``frames_number`` is ``None``, ``duration`` is ``None``, and ``custom_integrate_function`` is not ``None``:
events will be integrated by the user-defined function and saved to the ``custom_integrated_frames_dir_name`` directory in ``root`` directory.
Here is an example from SpikingJelly's tutorials:

.. code-block:: python

from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
from typing import Dict
import numpy as np
import spikingjelly.datasets as sjds
def integrate_events_to_2_frames_randomly(events: Dict, H: int, W: int):
index_split = np.random.randint(low=0, high=events['t'].__len__())
frames = np.zeros([2, 2, H, W])
frames[0] = sjds.integrate_events_segment_to_frame(events, H, W, 0, index_split)
frames[1] = sjds.integrate_events_segment_to_frame(events, H, W, index_split, events['t'].__len__())
return frames

root_dir = 'D:/datasets/DVS128Gesture'
train_set = DVS128Gesture(root_dir, train=True, data_type='frame', custom_integrate_function=integrate_events_to_2_frames_randomly)

from spikingjelly.datasets import play_frame
frame, label = train_set[500]
play_frame(frame)
'''
"""
The N-Caltech101 dataset, which is proposed by `Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades <https://www.frontiersin.org/articles/10.3389/fnins.2015.00437/full>`_.

Refer to :class:`spikingjelly.datasets.NeuromorphicDatasetFolder` for more details about the parameters.
"""
super().__init__(root, None, data_type, frames_number, split_by, duration, custom_integrate_function, custom_integrated_frames_dir_name, transform, target_transform)
@staticmethod
def resource_url_md5() -> list:
@@ -171,7 +112,7 @@ class NCaltech101(sjds.NeuromorphicDatasetFolder):
'''
t_ckp = time.time()
extract_root = os.path.join(extract_root, 'Caltech101')
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), max_threads_number_for_datasets_preprocess)) as tpe:
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), configure.max_threads_number_for_datasets_preprocess)) as tpe:
# too many threads will make the disk overload
for class_name in os.listdir(extract_root):
bin_dir = os.path.join(extract_root, class_name)


+ 9
- 71
spikingjelly/datasets/n_mnist.py View File

@@ -1,13 +1,12 @@
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
import numpy as np
import spikingjelly.datasets as sjds
from typing import Callable, Dict, Optional, Tuple
from .. import datasets as sjds
from torchvision.datasets.utils import extract_archive
import os
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
import time
from ..configure import max_threads_number_for_datasets_preprocess
from spikingjelly.datasets import np_savez
from .. import configure
from ..datasets import np_savez

class NMNIST(sjds.NeuromorphicDatasetFolder):
def __init__(
@@ -23,72 +22,11 @@ class NMNIST(sjds.NeuromorphicDatasetFolder):
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
'''
:param root: root path of the dataset
:type root: str
:param train: whether use the train set
:type train: bool
:param data_type: `event` or `frame`
:type data_type: str
:param frames_number: the integrated frame number
:type frames_number: int
:param split_by: `time` or `number`
:type split_by: str
:param duration: the time duration of each frame
:type duration: int
:param custom_integrate_function: a user-defined function that inputs are ``events, H, W``.
``events`` is a dict whose keys are ``['t', 'x', 'y', 'p']`` and values are ``numpy.ndarray``
``H`` is the height of the data and ``W`` is the weight of the data.
For example, H=128 and W=128 for the DVS128 Gesture dataset.
The user should define how to integrate events to frames, and return frames.
:type custom_integrate_function: Callable
:param custom_integrated_frames_dir_name: The name of directory for saving frames integrating by ``custom_integrate_function``.
If ``custom_integrated_frames_dir_name`` is ``None``, it will be set to ``custom_integrate_function.__name__``
:type custom_integrated_frames_dir_name: str or None
:param transform: a function/transform that takes in
a sample and returns a transformed version.
E.g, ``transforms.RandomCrop`` for images.
:type transform: callable
:param target_transform: a function/transform that takes
in the target and transforms it.
:type target_transform: callable

If ``data_type == 'event'``
the sample in this dataset is a dict whose keys are ``['t', 'x', 'y', 'p']`` and values are ``numpy.ndarray``.

If ``data_type == 'frame'`` and ``frames_number`` is not ``None``
events will be integrated to frames with fixed frames number. ``split_by`` will define how to split events.
See :class:`spikingjelly.datasets.cal_fixed_frames_number_segment_index` for
more details.

If ``data_type == 'frame'``, ``frames_number`` is ``None``, and ``duration`` is not ``None``
events will be integrated to frames with fixed time duration.

If ``data_type == 'frame'``, ``frames_number`` is ``None``, ``duration`` is ``None``, and ``custom_integrate_function`` is not ``None``:
events will be integrated by the user-defined function and saved to the ``custom_integrated_frames_dir_name`` directory in ``root`` directory.
Here is an example from SpikingJelly's tutorials:

.. code-block:: python

from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
from typing import Dict
import numpy as np
import spikingjelly.datasets as sjds
def integrate_events_to_2_frames_randomly(events: Dict, H: int, W: int):
index_split = np.random.randint(low=0, high=events['t'].__len__())
frames = np.zeros([2, 2, H, W])
frames[0] = sjds.integrate_events_segment_to_frame(events, H, W, 0, index_split)
frames[1] = sjds.integrate_events_segment_to_frame(events, H, W, index_split, events['t'].__len__())
return frames

root_dir = 'D:/datasets/DVS128Gesture'
train_set = DVS128Gesture(root_dir, train=True, data_type='frame', custom_integrate_function=integrate_events_to_2_frames_randomly)

from spikingjelly.datasets import play_frame
frame, label = train_set[500]
play_frame(frame)
"""
The N-MNIST dataset, which is proposed by `Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades <https://www.frontiersin.org/articles/10.3389/fnins.2015.00437/full>`_.

'''
Refer to :class:`spikingjelly.datasets.NeuromorphicDatasetFolder` for more details about the parameters.
"""
assert train is not None
super().__init__(root, train, data_type, frames_number, split_by, duration, custom_integrate_function, custom_integrated_frames_dir_name, transform, target_transform)
@staticmethod
@@ -175,7 +113,7 @@ class NMNIST(sjds.NeuromorphicDatasetFolder):
This function defines how to convert the original binary data in ``extract_root`` to the ``npz`` format and save the converted files in ``events_np_root``.
'''
t_ckp = time.time()
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), max_threads_number_for_datasets_preprocess)) as tpe:
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), configure.max_threads_number_for_datasets_preprocess)) as tpe:
# too many threads will make the disk overload
for train_test_dir in ['Train', 'Test']:
source_dir = os.path.join(extract_root, train_test_dir)


+ 331
- 0
spikingjelly/datasets/nav_gesture.py View File

@@ -0,0 +1,331 @@

# Codes from the source dataset:
# ---------------------------------------------------------------------------------------------
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################
# read_td_events.py #
#####################
# Feb 2017 - Jean-Matthieu Maro
# Email: jean-matthieu dot maro, hosted at inserm, which is located in France.
# Thanks to Germain Haessig and Laurent Dardelet.

from struct import unpack, pack
import numpy as np
import sys


def peek(f, length=1):
pos = f.tell()
data = f.read(length)
f.seek(pos)
return data

def readATIS_tddat(file_name, orig_at_zero = True, drop_negative_dt = True, verbose = True, events_restriction = [0, np.inf]):

"""
reads ATIS td events in .dat format

input:
filename: string, path to the .dat file
orig_at_zero: bool, if True, timestamps will start at 0
drop_negative_dt: bool, if True, events whose timestamp is smaller than the previous event's timestamp are dismissed
verbose: bool, if True, verbose mode.
events_restriction: list [min ts, max ts], will return only events with ts in the defined boundaries

output:
timestamps: numpy array of length (number of events), timestamps
coords: numpy array of size (number of events, 2), spatial coordinates: col 0 is x, col 1 is y.
polarities: numpy array of length (number of events), polarities
removed_events: integer, number of removed events (negative delta-ts)

"""

polmask = 0x0002000000000000
xmask = 0x000001FF00000000
ymask = 0x0001FE0000000000
polpadding = 49
ypadding = 41
xpadding = 32

# This one read _td.dat files generated by kAER
if verbose:
print('Reading _td dat file... (' + file_name + ')')
file = open(file_name,'rb')

header = False
while peek(file) == b'%':
file.readline()
header = True
if header:
ev_type = unpack('B',file.read(1))[0]
ev_size = unpack('B',file.read(1))[0]
if verbose:
print('> Header exists. Event type is ' + str(ev_type) + ', event size is ' + str(ev_size))
if ev_size != 8:
print('Wrong event size. Aborting.')
return -1, -1, -1, -1
else: # set default ev type and size
if verbose:
print('> No header. Setting default event type and size.')
ev_size = 8
ev_type = 0

# Compute number of events in the file
start = file.tell()
file.seek(0,2)
stop = file.tell()
file.seek(start)

Nevents = int( (stop-start)/ev_size )
dNEvents = Nevents/100
if verbose:
print("> The file contains %d events." %Nevents)

# store read data
timestamps = np.zeros(Nevents, dtype = int)
polarities = np.zeros(Nevents, dtype = int)
coords = np.zeros((Nevents, 2), dtype = int)

ActualEvents = 0
for i in np.arange(0, int(Nevents)):

event = unpack('Q',file.read(8))
ts = event[0] & 0x00000000FFFFFFFF
# padding = event[0] & 0xFFFC000000000000
pol = (event[0] & polmask) >> polpadding
y = (event[0] & ymask) >> ypadding
x = (event[0] & xmask) >> xpadding
if i >= events_restriction[0] and ts>=timestamps[max(0,i-1)]:
ActualEvents += 1
timestamps[i] = ts
polarities[i] = pol
coords[i, 0] = x
coords[i, 1] = y

if verbose and i%dNEvents == 0:
sys.stdout.write("> "+str(i/dNEvents)+"% \r")
sys.stdout.flush()
if i > events_restriction[1]:
break
file.close()
if verbose:
print ("> After loading events, actually found {0} events.".format(ActualEvents))

timestamps = timestamps[:ActualEvents]
coords = coords[:ActualEvents, :]
polarities = polarities[:ActualEvents]

#check for negative timestamps
for ts in timestamps:
if ts < 0:
print('Found a negative timestamp.')

if orig_at_zero:
timestamps = timestamps - timestamps[0]

drop_sum = 0
if drop_negative_dt:
if verbose:
print('> Looking for negative dts...')
# first check if negative TS differences
just_dropped = True
nPasses = 0
while just_dropped:
nPasses += 1
index_neg = []
just_dropped = False
ii = 0
while ii < (timestamps.size - 1):
dt = timestamps[ii+1] - timestamps[ii]
if dt < 0: # alors ts en ii+1 plus petit que ii
index_neg += [ii+1]
ii += 1
just_dropped = True
if verbose and ii%dNEvents == 0:
sys.stdout.write("> "+str(ii/dNEvents)+"% (pass "+str(nPasses)+") \r")
sys.stdout.flush()
ii += 1
if len(index_neg) > 0:
drop_sum += len(index_neg)
index_neg = np.array(index_neg)
timestamps = np.delete(timestamps, index_neg)
polarities = np.delete(polarities, index_neg)
coords = np.delete(coords, index_neg, axis = 0)
if verbose:
print('> Removed {0} events in {1} passes.'.format(drop_sum, nPasses))
removed_events = drop_sum
else:
removed_events = -1
if verbose:
print("> Sequence duration: {0:.2f}s, ts[0] = {1}, ts[{2}] = {3}.".format(float(timestamps[-1] - timestamps[0]) / 1e6, timestamps[0], len(timestamps)-1, timestamps[-1]))


return timestamps, coords, polarities, removed_events
# ---------------------------------------------------------------------------------------------
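A small worked example of the bit layout used above (pure arithmetic, independent of any file): pack one 64-bit ATIS event word and decode it back with the same masks and paddings.

.. code-block:: python

    # pack t=123456, x=100, y=50, p=1 into one 64-bit event word
    word = (1 << 49) | (50 << 41) | (100 << 32) | 123456

    t = word & 0x00000000FFFFFFFF            # lower 32 bits: timestamp
    x = (word & 0x000001FF00000000) >> 32    # 9 bits of x at xpadding = 32
    y = (word & 0x0001FE0000000000) >> 41    # 8 bits of y at ypadding = 41
    p = (word & 0x0002000000000000) >> 49    # polarity bit at polpadding = 49
    print(t, x, y, p)                        # 123456 100 50 1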

from typing import Callable, Dict, Optional, Tuple
from .. import datasets as sjds
from torchvision.datasets.utils import extract_archive
import os
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
import shutil
import time
from .. import configure
from ..datasets import np_savez



class NAVGestureWalk(sjds.NeuromorphicDatasetFolder):
# 6 gestures: left, right, up, down, home, select.
# 10 subjects, holding the phone in one hand (selfie mode) while walking indoor and outdoor
def __init__(
self,
root: str,
data_type: str = 'event',
frames_number: int = None,
split_by: str = None,
duration: int = None,
custom_integrate_function: Callable = None,
custom_integrated_frames_dir_name: str = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
"""
The Nav Gesture dataset, which is proposed by `Event-Based Gesture Recognition With Dynamic Background Suppression Using Smartphone Computational Capabilities <https://www.frontiersin.org/articles/10.3389/fnins.2020.00275/full>`_.

Refer to :class:`spikingjelly.datasets.NeuromorphicDatasetFolder` for more details about the parameters.
"""
super().__init__(root, None, data_type, frames_number, split_by, duration, custom_integrate_function, custom_integrated_frames_dir_name, transform, target_transform)

@staticmethod
def resource_url_md5() -> list:
'''
:return: A list ``url`` where ``url[i]`` is a tuple containing the i-th file's name, download link, and MD5
:rtype: list
'''
return [('navgesture-walk.zip', 'https://www.neuromorphic-vision.com/public/downloads/navgesture/navgesture-walk.zip', '5d305266f13005401959e819abe206f0')]

@staticmethod
def downloadable() -> bool:
'''
:return: Whether the dataset can be downloaded directly by Python code. If not, the user has to download it manually
:rtype: bool
'''
return True

@staticmethod
def extract_downloaded_files(download_root: str, extract_root: str):
'''
:param download_root: Root directory path which saves downloaded dataset files
:type download_root: str
:param extract_root: Root directory path which saves extracted files from downloaded files
:type extract_root: str
:return: None

This function defines how to extract the downloaded files.
'''
temp_ext_dir = os.path.join(download_root, 'temp_ext')
os.mkdir(temp_ext_dir)
print(f'Mkdir [{temp_ext_dir}].')
extract_archive(os.path.join(download_root, 'navgesture-walk.zip'), temp_ext_dir)
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 4)) as tpe:
for zip_file in os.listdir(temp_ext_dir):
if os.path.splitext(zip_file)[1] == '.zip':
zip_file = os.path.join(temp_ext_dir, zip_file)
print(f'Extract [{zip_file}] to [{extract_root}].')
tpe.submit(extract_archive, zip_file, extract_root)

shutil.rmtree(temp_ext_dir)
print(f'Rmtree [{temp_ext_dir}].')

@staticmethod
def get_H_W() -> Tuple:
'''
:return: A tuple ``(H, W)``, where ``H`` is the height of the data and ``W`` is the width of the data.
For example, this function returns ``(128, 128)`` for the DVS128 Gesture dataset.
:rtype: tuple
'''
return 240, 304 # this camera is 240*320, but x.max() = 303. So, I set W = 304.

@staticmethod
def read_aedat_save_to_np(bin_file: str, np_file: str):
t, xy, p, _ = readATIS_tddat(bin_file, verbose=False)
x = xy[:, 0]
y = 239 - xy[:, 1]
np_savez(np_file,
t=t,
x=x,
y=y,
p=p
)
print(f'Save [{bin_file}] to [{np_file}].')

@staticmethod
def create_events_np_files(extract_root: str, events_np_root: str):
'''
:param extract_root: Root directory path which saves extracted files from downloaded files
:type extract_root: str
:param events_np_root: Root directory path which saves events files in the ``npz`` format
:type events_np_root:
:return: None

This function defines how to convert the original binary data in ``extract_root`` to the ``npz`` format and save the converted files in ``events_np_root``.
'''
t_ckp = time.time()
np_dir_dict = {}
for label in ['le', 'ri', 'up', 'do', 'ho', 'se']:
np_dir = os.path.join(events_np_root, label)
os.mkdir(np_dir)
print(f'Mkdir [{np_dir}].')
np_dir_dict[label] = np_dir

with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(),
configure.max_threads_number_for_datasets_preprocess)) as tpe:
for user_name in os.listdir(extract_root):
aedat_dir = os.path.join(extract_root, user_name)
for bin_file in os.listdir(aedat_dir):
base_name = os.path.splitext(bin_file)[0]
label = base_name.split('_')[1]
source_file = os.path.join(aedat_dir, bin_file)
target_file = os.path.join(np_dir_dict[label], base_name + '.npz')
print(f'Start to convert [{source_file}] to [{target_file}].')
tpe.submit(NAVGestureWalk.read_aedat_save_to_np, source_file,
target_file)
print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')


class NAVGestureSit(NAVGestureWalk):
@staticmethod
def resource_url_md5() -> list:
'''
:return: A list ``url`` where ``url[i]`` is a tuple containing the i-th file's name, download link, and MD5
:rtype: list
'''
return [('navgesture-sit.zip', 'https://www.neuromorphic-vision.com/public/downloads/navgesture/navgesture-sit.zip', '1571753ace4d9e0946e6503313712c22')]

@staticmethod
def extract_downloaded_files(download_root: str, extract_root: str):
'''
:param download_root: Root directory path which saves downloaded dataset files
:type download_root: str
:param extract_root: Root directory path which saves extracted files from downloaded files
:type extract_root: str
:return: None

This function defines how to extract the downloaded files.
'''
temp_ext_dir = os.path.join(download_root, 'temp_ext')
os.mkdir(temp_ext_dir)
print(f'Mkdir [{temp_ext_dir}].')
extract_archive(os.path.join(download_root, 'navgesture-sit.zip'), temp_ext_dir)
with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 4)) as tpe:
for zip_file in os.listdir(temp_ext_dir):
if os.path.splitext(zip_file)[1] == '.zip':
zip_file = os.path.join(temp_ext_dir, zip_file)
print(f'Extract [{zip_file}] to [{extract_root}].')
tpe.submit(extract_archive, zip_file, extract_root)

shutil.rmtree(temp_ext_dir)
print(f'Rmtree [{temp_ext_dir}].')

+ 0
- 1
spikingjelly/datasets/speechcommands.py View File

@@ -10,7 +10,6 @@ from torchaudio.datasets.utils import (
download_url,
extract_archive
)
from torchvision import transforms
from torchvision.datasets.utils import verify_str_arg
import numpy as np
from random import choice

