@@ -1,10 +0,0 @@
---
name: 'Feature Request'
about: 'Suggest a new idea or improvement for BrainPy'
labels: 'enhancement'
---

Please:

- [ ] Check for duplicate requests.
- [ ] Describe your goal, and if possible provide a code snippet with a motivating example.
@@ -1,5 +1,5 @@
---
name: 'Bug report'
name: 'Bug Report'
about: 'Report a bug to help improve the package'
labels: 'bug'
---
@@ -1,5 +1,5 @@
blank_issues_enabled: false
contact_links:
  - name: Question
    url: https://github.com/google/jax/discussions
    url: https://github.com/PKU-NIP-Lab/BrainPy/discussions
    about: Please ask questions on the Discussions tab
@@ -5,9 +5,9 @@ name: Linux CI

on:
  push:
    branches: [ master, brainpy-2.x, V2.1.0 ]
    branches: [ master ]
  pull_request:
    branches: [ master, brainpy-2.x, V2.1.0 ]
    branches: [ master ]

jobs:
@@ -16,7 +16,7 @@ jobs:
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.7", "3.8", "3.9"]
        python-version: ["3.7", "3.8", "3.9", "3.10"]
    steps:
    - uses: actions/checkout@v2
@@ -5,19 +5,18 @@ name: MacOS CI

on:
  push:
    branches: [ master, brainpy-2.x, V2.1.0 ]
    branches: [ master ]
  pull_request:
    branches: [ master, brainpy-2.x, V2.1.0 ]
    branches: [ master ]

jobs:
  build:
    runs-on: ${{ matrix.os }}
    runs-on: macos-latest
    strategy:
      fail-fast: false
      matrix:
        os: [macos-10.15, macos-11, macos-latest]
        python-version: ["3.7", "3.8", "3.9"]
        python-version: ["3.7", "3.8", "3.9", "3.10"]
    steps:
    - uses: actions/checkout@v2
@@ -39,4 +38,4 @@ jobs:
        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
    - name: Test with pytest
      run: |
        pytest
        pytest brainpy/
@@ -0,0 +1,18 @@
name: Sync multiple branches

on:
  pull_request:
    branches:
      - master

jobs:
  sync-branch:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@master
      - name: Merge master -> brainpy-2.x
        uses: devmasx/merge-branch@v1.3.1
        with:
          type: now
          from_branch: master
          target_branch: brainpy-2.x
          github_token: ${{ github.token }}
@@ -5,9 +5,9 @@ name: Windows CI

on:
  push:
    branches: [ master, brainpy-2.x, V2.1.0 ]
    branches: [ master ]
  pull_request:
    branches: [ master, brainpy-2.x, V2.1.0 ]
    branches: [ master ]

jobs:
@@ -16,7 +16,7 @@ jobs:
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.7", "3.8", "3.9"]
        python-version: ["3.7", "3.8", "3.9", "3.10"]
    steps:
    - uses: actions/checkout@v2
@@ -2,6 +2,8 @@ name: Add contributors
on:
  schedule:
    - cron: '20 20 * * *'
  push:
    branches: [ master ]

jobs:
  add-contributors:
@@ -17,6 +17,7 @@ BrainModels/
book/
docs/examples
docs/apis/jaxsetting.rst
docs/quickstart/data
examples/recurrent_neural_network/neurogym
develop/iconip_paper
develop/benchmark/COBA/results
@@ -28,9 +28,9 @@ BrainPy is a flexible, efficient, and extensible framework for computational neu

## Install
## Installation

BrainPy is based on Python (>=3.6) and can be installed on Linux (Ubuntu 16.04 or later), macOS (10.12 or later), and Windows platforms. Install the latest version of BrainPy:
BrainPy is based on Python (>=3.7) and can be installed on Linux (Ubuntu 16.04 or later), macOS (10.12 or later), and Windows platforms. Install the latest version of BrainPy:

```bash
$ pip install brain-py -U
@@ -54,7 +54,121 @@ import brainpy as bp

**1\. E-I balance network**
### 1. Operator level

Mathematical operators in BrainPy are the same as those in NumPy.

```python
>>> import numpy as np
>>> import brainpy.math as bm

# array creation
>>> np_arr = np.zeros((2, 4)); np_arr
array([[0., 0., 0., 0.],
       [0., 0., 0., 0.]])
>>> bm_arr = bm.zeros((2, 4)); bm_arr
JaxArray([[0., 0., 0., 0.],
          [0., 0., 0., 0.]], dtype=float32)

# in-place updating
>>> np_arr[0] += 1.; np_arr
array([[1., 1., 1., 1.],
       [0., 0., 0., 0.]])
>>> bm_arr[0] += 1.; bm_arr
JaxArray([[1., 1., 1., 1.],
          [0., 0., 0., 0.]], dtype=float32)

# mathematical functions
>>> np.sin(np_arr)
array([[0.84147098, 0.84147098, 0.84147098, 0.84147098],
       [0.        , 0.        , 0.        , 0.        ]])
>>> bm.sin(bm_arr)
JaxArray([[0.84147096, 0.84147096, 0.84147096, 0.84147096],
          [0.        , 0.        , 0.        , 0.        ]], dtype=float32)

# linear algebra
>>> np.dot(np_arr, np.ones((4, 2)))
array([[4., 4.],
       [0., 0.]])
>>> bm.dot(bm_arr, bm.ones((4, 2)))
JaxArray([[4., 4.],
          [0., 0.]], dtype=float32)

# random number generation
>>> np.random.uniform(-0.1, 0.1, (2, 3))
array([[-0.02773637,  0.03766689, -0.01363128],
       [-0.01946991, -0.06669802,  0.09426067]])
>>> bm.random.uniform(-0.1, 0.1, (2, 3))
JaxArray([[-0.03044081, -0.07787752,  0.04346445],
          [-0.01366713, -0.0522548 ,  0.04372055]], dtype=float32)
```
### 2. Integrator level

Numerical methods for ordinary differential equations (ODEs).

```python
sigma = 10; beta = 8/3; rho = 28

@bp.odeint(method='rk4')
def lorenz_system(x, y, z, t):
  dx = sigma * (y - x)
  dy = x * (rho - z) - y
  dz = x * y - beta * z
  return dx, dy, dz

runner = bp.integrators.IntegratorRunner(lorenz_system, dt=0.01)
runner.run(100.)
```
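To inspect the integrated trajectory, the runner can also record the state variables. A minimal sketch, assuming `IntegratorRunner`'s `monitors` argument and its `runner.mon` result container; the `inits` values here are illustrative:

```python
runner = bp.integrators.IntegratorRunner(lorenz_system,
                                         monitors=['x', 'y', 'z'],  # record these states
                                         inits=[1., 1., 1.],        # illustrative initial values
                                         dt=0.01)
runner.run(100.)
# runner.mon.x, runner.mon.y, runner.mon.z now hold the simulated time series
```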
Numerical methods for stochastic differential equations (SDEs).

```python
sigma = 10; beta = 8/3; rho = 28
p = 0.1

def lorenz_noise(x, y, z, t):
  return p * x, p * y, p * z

@bp.sdeint(method='milstein', g=lorenz_noise)
def lorenz_system(x, y, z, t):
  dx = sigma * (y - x)
  dy = x * (rho - z) - y
  dz = x * y - beta * z
  return dx, dy, dz

runner = bp.integrators.IntegratorRunner(lorenz_system, dt=0.01)
runner.run(100.)
```
Numerical methods for delay differential equations (DDEs).

```python
xdelay = bm.TimeDelay(bm.zeros(1), delay_len=1., before_t0=1., dt=0.01)

@bp.ddeint(method='rk4', state_delays={'x': xdelay})
def second_order_eq(x, y, t):
  dx = y
  dy = -y - 2 * x - 0.5 * xdelay(t - 1)
  return dx, dy

runner = bp.integrators.IntegratorRunner(second_order_eq, dt=0.01)
runner.run(100.)
```
### 3. Dynamics simulation level

Building an E-I balance network.

```python
class EINet(bp.dyn.Network):
@@ -77,9 +191,36 @@ runner = bp.dyn.DSRunner(net)
runner(100.)
```
Simulating a whole-brain network using rate models.

```python
import numpy as np

class WholeBrainNet(bp.dyn.Network):
  def __init__(self, signal_speed=20.):
    super(WholeBrainNet, self).__init__()
**2\. Echo state network**
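    # conn_mat / delay_mat below: structural connectivity and conduction-delay
    # matrices between brain regions, assumed to be loaded beforehand
    # (hypothetical data, e.g. from an empirical connectome)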
    self.fhn = bp.dyn.RateFHN(80, x_ou_sigma=0.01, y_ou_sigma=0.01, name='fhn')
    self.syn = bp.dyn.DiffusiveDelayCoupling(self.fhn, self.fhn,
                                             'x->input',
                                             conn_mat=conn_mat,
                                             delay_mat=delay_mat)

  def update(self, _t, _dt):
    self.syn.update(_t, _dt)
    self.fhn.update(_t, _dt)

net = WholeBrainNet()
runner = bp.dyn.DSRunner(net, monitors=['fhn.x'], inputs=['fhn.input', 0.72])
runner.run(6e3)
```
### 4. Dynamics training level

Training an echo state network.

```python
i = bp.nn.Input(3)
@@ -88,16 +229,14 @@ o = bp.nn.LinearReadout(3)
net = i >> r >> o

# Ridge Regression
trainer = bp.nn.RidgeTrainer(net, beta=1e-5)
trainer = bp.nn.RidgeTrainer(net, beta=1e-5)  # Ridge Regression

# FORCE Learning
trainer = bp.nn.FORCELearning(net, alpha=1.)
trainer = bp.nn.FORCELearning(net, alpha=1.)  # FORCE Learning
```
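A hypothetical follow-up on driving such a trainer; the `fit`/`predict` calls and the `X_train`/`Y_train`/`X_test` arrays below are illustrative assumptions, not the verified BrainPy API:

```python
# X_train, Y_train: assumed (num_time, num_feature) arrays (hypothetical)
trainer.fit([X_train, Y_train])    # fit the readout weights
outputs = trainer.predict(X_test)  # run the trained network on new inputs
```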
**3. Next generation reservoir computing**
Training a next-generation reservoir computing model.

```python
i = bp.nn.Input(3)
@@ -111,7 +250,7 @@ trainer = bp.nn.RidgeTrainer(net, beta=1e-5)

**4. Recurrent neural network**
Training an artificial recurrent neural network.

```python
i = bp.nn.Input(3)
@@ -128,7 +267,9 @@ trainer = bp.nn.BPTT(net,
**5\. Analyzing a low-dimensional FitzHugh–Nagumo neuron model**
### 5. Dynamics analysis level

Analyzing a low-dimensional FitzHugh–Nagumo neuron model.

```python
bp.math.enable_x64()
@@ -149,9 +290,10 @@ analyzer.show_figure()
</p>

For **more functions and examples**, please refer to the [documentation](https://brainpy.readthedocs.io/) and [examples](https://brainpy-examples.readthedocs.io/).
### 6. More

For **more functions and examples**, please refer to the [documentation](https://brainpy.readthedocs.io/) and [examples](https://brainpy-examples.readthedocs.io/).

## License
@@ -1,159 +0,0 @@
<p align="center">
  <img alt="Header image of BrainPy - brain dynamics programming in Python." src="./images/logo.png">
</p>

<p align="center">
  <a href="https://pypi.org/project/brain-py/"><img alt="Supported Python Version" src="https://img.shields.io/pypi/pyversions/brain-py"></a>
  <a href="https://github.com/PKU-NIP-Lab/BrainPy"><img alt="LICENSE" src="https://anaconda.org/brainpy/brainpy/badges/license.svg"></a>
  <a href="https://brainpy.readthedocs.io/en/latest/?badge=latest"><img alt="Documentation" src="https://readthedocs.org/projects/brainpy/badge/?version=latest"></a>
  <a href="https://badge.fury.io/py/brain-py"><img alt="PyPI version" src="https://badge.fury.io/py/brain-py.svg"></a>
  <a href="https://github.com/PKU-NIP-Lab/BrainPy"><img alt="Linux CI" src="https://github.com/PKU-NIP-Lab/BrainPy/actions/workflows/Linux_CI.yml/badge.svg"></a>
  <a href="https://github.com/PKU-NIP-Lab/BrainPy"><img alt="Windows CI" src="https://github.com/PKU-NIP-Lab/BrainPy/actions/workflows/Windows_CI.yml/badge.svg"></a>
</p>

:clap::clap: **CHEERS**: A new version of BrainPy (>=2.0.0, long term support) has been released! :clap::clap:

# Why use BrainPy

``BrainPy`` is an integrative framework for computational neuroscience and brain-inspired computation based on Just-In-Time (JIT) compilation (built on top of [JAX](https://github.com/google/jax)). Core functions provided in BrainPy include

- **JIT compilation** for class objects.
- **Numerical solvers** for ODEs, SDEs, and others.
- **Dynamics simulation tools** for various brain objects, like neurons, synapses, networks, somata, dendrites, channels, and more.
- **Dynamics analysis tools** for differential equations, including phase plane, bifurcation, and linearization analysis.
- **Seamless integration with deep learning models**.
- And more ...

`BrainPy` is designed to meet your basic requirements:

- **Pythonic**: BrainPy is based on the Python language and has a Pythonic coding style.
- **Flexible and transparent**: BrainPy gives users full control over the data and logic flow; users can code any logic they want with BrainPy.
- **Extensible**: BrainPy allows users to add new functionality using plain Python code; almost every part of the BrainPy system can be extended and customized.
- **Efficient**: All code in BrainPy can be just-in-time compiled (based on [JAX](https://github.com/google/jax)) to run on CPU, GPU, or TPU devices, ensuring high running efficiency.

# How to use BrainPy

## Step 1: installation

``BrainPy`` is based on Python (>=3.6) and requires the following packages: `numpy >= 1.15`, `matplotlib >= 3.4`, and `jax >= 0.2.10` ([how to install jax?](https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#dependency-2-jax)).

``BrainPy`` can be installed on Linux (Ubuntu 16.04 or later), macOS (10.12 or later), and Windows platforms. Use the following instructions to install ``brainpy``:

```bash
pip install brain-py -U
```

*For full installation details, please see the documentation: [Quickstart/Installation](https://brainpy.readthedocs.io/en/latest/quickstart/installation.html)*

## Step 2: useful links

- **Documentation:** https://brainpy.readthedocs.io/
- **Bug reports:** https://github.com/PKU-NIP-Lab/BrainPy/issues
- **Examples from papers**: https://brainpy-examples.readthedocs.io/
- **Canonical brain models**: https://brainmodels.readthedocs.io/

## Step 3: inspirational examples

Here we list several examples of BrainPy. For more detailed examples and tutorials please see [**BrainModels**](https://brainmodels.readthedocs.io) or [**BrainPy-Examples**](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/).

### Neuron models

- [Leaky integrate-and-fire neuron model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.neurons.LIF.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/neurons/LIF.py)
- [Exponential integrate-and-fire neuron model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.neurons.ExpIF.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/neurons/ExpIF.py)
- [Quadratic integrate-and-fire neuron model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.neurons.QuaIF.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/neurons/QuaIF.py)
- [Adaptive Quadratic integrate-and-fire model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.neurons.AdQuaIF.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/neurons/AdQuaIF.py)
- [Adaptive Exponential integrate-and-fire model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.neurons.AdExIF.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/neurons/AdExIF.py)
- [Generalized integrate-and-fire model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.neurons.GIF.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/neurons/GIF.py)
- [Hodgkin–Huxley neuron model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.neurons.HH.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/neurons/HH.py)
- [Izhikevich neuron model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.neurons.Izhikevich.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/neurons/Izhikevich.py)
- [Morris-Lecar neuron model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.neurons.MorrisLecar.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/neurons/MorrisLecar.py)
- [Hindmarsh-Rose bursting neuron model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.neurons.HindmarshRose.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/neurons/HindmarshRose.py)

See [brainmodels.neurons](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/neurons.html) to find more.

### Synapse models

- [Voltage jump synapse model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.synapses.VoltageJump.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/synapses/voltage_jump.py)
- [Exponential synapse model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.synapses.ExpCUBA.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/synapses/exponential.py)
- [Alpha synapse model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.synapses.AlphaCUBA.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/synapses/alpha.py)
- [Dual exponential synapse model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.synapses.DualExpCUBA.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/synapses/dual_exp.py)
- [AMPA synapse model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.synapses.AMPA.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/synapses/AMPA.py)
- [GABAA synapse model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.synapses.GABAa.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/synapses/GABAa.py)
- [NMDA synapse model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.synapses.NMDA.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/synapses/NMDA.py)
- [Short-term plasticity model](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/generated/brainmodels.synapses.STP.html), [source code](https://github.com/PKU-NIP-Lab/BrainModels/blob/brainpy-2.x/brainmodels/synapses/STP.py)

See [brainmodels.synapses](https://brainmodels.readthedocs.io/en/brainpy-2.x/apis/synapses.html) to find more.

### Network models

- **[CANN]** [*(Si Wu, 2008)* Continuous-attractor Neural Network](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/cann/Wu_2008_CANN.html)
- [*(Vreeswijk & Sompolinsky, 1996)* E/I balanced network](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/ei_nets/Vreeswijk_1996_EI_net.html)
- [*(Sherman & Rinzel, 1992)* Gap junction leads to anti-synchronization](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/gj_nets/Sherman_1992_gj_antisynchrony.html)
- [*(Wang & Buzsáki, 1996)* Gamma Oscillation](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/oscillation_synchronization/Wang_1996_gamma_oscillation.html)
- [*(Brunel & Hakim, 1999)* Fast Global Oscillation](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/oscillation_synchronization/Brunel_Hakim_1999_fast_oscillation.html)
- [*(Diesmann, et al., 1999)* Synfire Chains](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/oscillation_synchronization/Diesmann_1999_synfire_chains.html)
- **[Working Memory]** [*(Mi, et al., 2017)* STP for Working Memory Capacity](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/working_memory/Mi_2017_working_memory_capacity.html)
- **[Working Memory]** [*(Bouchacourt & Buschman, 2019)* Flexible Working Memory Model](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/working_memory/Bouchacourt_2019_Flexible_working_memory.html)
- **[Decision Making]** [*(Wang, 2002)* Decision making spiking model](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/decision_making/Wang_2002_decision_making_spiking.html)

### Dynamics training

- [Train Integrator RNN with BP](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/recurrent_networks/integrator_rnn.html)
- [*(Sussillo & Abbott, 2009)* FORCE Learning](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/recurrent_networks/Sussillo_Abbott_2009_FORCE_Learning.html)
- [*(Laje & Buonomano, 2013)* Robust Timing in RNN](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/recurrent_networks/Laje_Buonomano_2013_robust_timing_rnn.html)
- [*(Song, et al., 2016)* Training excitatory-inhibitory recurrent network](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/recurrent_networks/Song_2016_EI_RNN.html)
- **[Working Memory]** [*(Masse, et al., 2019)* RNN with STP for Working Memory](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/recurrent_networks/Masse_2019_STP_RNN.html)

### Low-dimensional dynamics analysis

- [[1D] Simple systems](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/dynamics_analysis/1d_simple_systems.html)
- [[2D] NaK model analysis](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/dynamics_analysis/2d_NaK_model.html)
- [[3D] Hindmarsh Rose Model](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/dynamics_analysis/3d_hindmarsh_rose_model.html)
- **[Decision Making Model]** [[2D] Decision making rate model](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/decision_making/Wang_2006_decision_making_rate.html)

### High-dimensional dynamics analysis

- [*(Yang, 2020)* Dynamical system analysis for RNN](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/recurrent_networks/Yang_2020_RNN_Analysis.html)
- [Continuous-attractor Neural Network](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/dynamics_analysis/highdim_CANN.html)
- [Gap junction-coupled FitzHugh-Nagumo Model](https://brainpy-examples.readthedocs.io/en/brainpy-2.x/dynamics_analysis/highdim_gj_coupled_fhn.html)

# BrainPy 1.x

If you are using ``brainpy==1.x``, you can find *documentation*, *examples*, and *models* through the following links:

- **Documentation:** https://brainpy.readthedocs.io/en/brainpy-1.x/
- **Examples from papers**: https://brainpy-examples.readthedocs.io/en/brainpy-1.x/
- **Canonical brain models**: https://brainmodels.readthedocs.io/en/brainpy-1.x/

The changes from ``brainpy==1.x`` to ``brainpy==2.x`` can be inspected through the [API documentation: release notes](https://brainpy.readthedocs.io/en/latest/apis/auto/changelog.html).

# Contributors
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.1.0"
__version__ = "2.1.2"

try:
@@ -15,7 +15,7 @@ except ModuleNotFoundError:

# fundamental modules
from . import errors, tools
from . import errors, tools, check

# "base" module
@@ -37,9 +37,11 @@ from . import integrators
from .integrators import ode
from .integrators import sde
from .integrators import dde
from .integrators import fde
from .integrators.ode import odeint
from .integrators.sde import sdeint
from .integrators.dde import ddeint
from .integrators.fde import fdeint
from .integrators.joint_eq import JointEq

@@ -59,12 +61,12 @@ from . import running
from . import analysis

# "visualization" module, will be remove soon
# "visualization" module, will be removed soon
from .visualization import visualize

# compatible interface
from .compact import *  # compact
from .compat import *  # compat

# convenient access
@@ -1,17 +1,18 @@
# -*- coding: utf-8 -*-

import inspect
import time
import warnings
from functools import partial

from jax import vmap
import jax.numpy
import numpy as np
from jax.scipy.optimize import minimize

import brainpy.math as bm
from brainpy import optimizers as optim
from brainpy.analysis import utils
from brainpy.errors import AnalyzerError
from brainpy import optimizers as optim

__all__ = [
  'SlowPointFinder',
@@ -56,15 +57,15 @@ class SlowPointFinder(object):
    if f_loss_batch is None:
      if f_type == 'discrete':
        self.f_loss = bm.jit(lambda h: bm.mean((h - f_cell(h)) ** 2))
        self.f_loss_batch = bm.jit(lambda h: bm.mean((h - bm.vmap(f_cell, auto_infer=False)(h)) ** 2, axis=1))
        self.f_loss_batch = bm.jit(lambda h: bm.mean((h - vmap(f_cell)(h)) ** 2, axis=1))
      if f_type == 'continuous':
        self.f_loss = bm.jit(lambda h: bm.mean(f_cell(h) ** 2))
        self.f_loss_batch = bm.jit(lambda h: bm.mean((bm.vmap(f_cell, auto_infer=False)(h)) ** 2, axis=1))
        self.f_loss_batch = bm.jit(lambda h: bm.mean((vmap(f_cell)(h)) ** 2, axis=1))
    else:
      self.f_loss_batch = f_loss_batch
      self.f_loss = bm.jit(lambda h: bm.mean(f_cell(h) ** 2))
    self.f_jacob_batch = bm.jit(bm.vmap(bm.jacobian(f_cell)))
    self.f_jacob_batch = bm.jit(vmap(bm.jacobian(f_cell)))

    # essential variables
    self._losses = None
@@ -87,8 +88,13 @@ class SlowPointFinder(object):
    """The selected ids of candidate points."""
    return self._selected_ids

  def find_fps_with_gd_method(self, candidates, tolerance=1e-5, num_batch=100,
                              num_opt=10000, opt_setting=None):
  def find_fps_with_gd_method(self,
                              candidates,
                              tolerance=1e-5,
                              num_batch=100,
                              num_opt=10000,
                              optimizer=None,
                              opt_setting=None):
    """Optimize fixed points with gradient descent methods.

    Parameters
@@ -104,17 +110,30 @@ class SlowPointFinder(object):
      Print training information during optimization every so often.
    opt_setting: optional, dict
      The optimization settings.

      .. deprecated:: 2.1.2
         Use "optimizer" to set the optimization method instead.

    optimizer: optim.Optimizer
      The optimizer instance.

      .. versionadded:: 2.1.2
    """
    # optimization settings
    if opt_setting is None:
      opt_method = optim.Adam
      opt_lr = optim.ExponentialDecay(0.2, 1, 0.9999)
      opt_setting = {'beta1': 0.9,
                     'beta2': 0.999,
                     'eps': 1e-8,
                     'name': None}
      if optimizer is None:
        optimizer = optim.Adam(lr=optim.ExponentialDecay(0.2, 1, 0.9999),
                               beta1=0.9, beta2=0.999, eps=1e-8)
      else:
        assert isinstance(optimizer, optim.Optimizer), (f'Must be an instance of '
                                                        f'{optim.Optimizer.__name__}, '
                                                        f'while we got {type(optimizer)}')
    else:
      warnings.warn('Please use "optimizer" to set the optimization method. '
                    '"opt_setting" is deprecated since version 2.1.2.',
                    DeprecationWarning)
      assert isinstance(opt_setting, dict)
      assert 'method' in opt_setting
      assert 'lr' in opt_setting
@@ -122,26 +141,25 @@ class SlowPointFinder(object):
      if isinstance(opt_method, str):
        assert opt_method in optim.__dict__
        opt_method = getattr(optim, opt_method)
      assert isinstance(opt_method, type)
      if optim.Optimizer not in inspect.getmro(opt_method):
        raise ValueError
      assert issubclass(opt_method, optim.Optimizer)
      opt_lr = opt_setting.pop('lr')
      assert isinstance(opt_lr, (int, float, optim.Scheduler))
      opt_setting = opt_setting
      optimizer = opt_method(lr=opt_lr, **opt_setting)

    if self.verbose:
      print(f"Optimizing with {opt_method.__name__} to find fixed points:")
      print(f"Optimizing with {type(optimizer).__name__} to find fixed points:")

    # set up optimization
    fixed_points = bm.Variable(bm.asarray(candidates))
    grad_f = bm.grad(lambda: self.f_loss_batch(fixed_points.value).mean(),
                     grad_vars={'a': fixed_points}, return_value=True)
    opt = opt_method(train_vars={'a': fixed_points}, lr=opt_lr, **opt_setting)
    dyn_vars = opt.vars() + {'_a': fixed_points}
    optimizer.register_vars({'a': fixed_points})
    dyn_vars = optimizer.vars() + {'_a': fixed_points}

    def train(idx):
      gradients, loss = grad_f()
      opt.update(gradients)
      optimizer.update(gradients)
      return loss

    @partial(bm.jit, dyn_vars=dyn_vars, static_argnames=('start_i', 'num_batch'))
@@ -191,7 +209,7 @@ class SlowPointFinder(object):
      opt_method = lambda f, x0: minimize(f, x0, method='BFGS')
    if self.verbose:
      print(f"Optimizing to find fixed points:")
    f_opt = bm.jit(bm.vmap(lambda x0: opt_method(self.f_loss, x0)))
    f_opt = bm.jit(vmap(lambda x0: opt_method(self.f_loss, x0)))
    res = f_opt(bm.as_device_array(candidates))
    valid_ids = jax.numpy.where(res.success)[0]
    self._fixed_points = np.asarray(res.x[valid_ids])
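A minimal usage sketch of the new interface; the `f_cell` function and the candidate shape (`num_dims`) are assumptions for illustration:

```python
import brainpy as bp
import brainpy.math as bm
from brainpy import optimizers as optim

# f_cell: the user's state-update function; num_dims is hypothetical
finder = bp.analysis.SlowPointFinder(f_cell=f_cell, f_type='discrete')
candidates = bm.random.uniform(-1., 1., (1024, num_dims))
finder.find_fps_with_gd_method(
  candidates,
  tolerance=1e-5,
  num_batch=100,
  optimizer=optim.Adam(lr=optim.ExponentialDecay(0.2, 1, 0.9999)),  # replaces opt_setting
)
```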
@@ -2,8 +2,8 @@
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
from jax import vmap
from jax import numpy as jnp
from jax.scipy.optimize import minimize
@@ -12,6 +12,8 @@ from brainpy import errors, tools
from brainpy.analysis import constants as C, utils
from brainpy.base.collector import Collector

pyplot = None

__all__ = [
  'LowDimAnalyzer',
  'Num1DAnalyzer',
@@ -207,7 +209,10 @@ class LowDimAnalyzer(object):
    self.analyzed_results = tools.DictPlus()

  def show_figure(self):
    plt.show()
    global pyplot
    if pyplot is None:
      from matplotlib import pyplot
    pyplot.show()


class Num1DAnalyzer(LowDimAnalyzer):
@@ -258,7 +263,7 @@ class Num1DAnalyzer(LowDimAnalyzer):
  @property
  def F_vmap_fx(self):
    if C.F_vmap_fx not in self.analyzed_results:
      self.analyzed_results[C.F_vmap_fx] = bm.jit(bm.vmap(self.F_fx), device=self.jit_device)
      self.analyzed_results[C.F_vmap_fx] = bm.jit(vmap(self.F_fx), device=self.jit_device)
    return self.analyzed_results[C.F_vmap_fx]

  @property
@@ -285,7 +290,7 @@ class Num1DAnalyzer(LowDimAnalyzer):
      # ---
      # "X": a two-dimensional matrix: (num_batch, num_var)
      # "args": a list of one-dimensional vectors, each has the shape of (num_batch,)
      self.analyzed_results[C.F_vmap_fp_aux] = bm.jit(bm.vmap(self.F_fixed_point_aux))
      self.analyzed_results[C.F_vmap_fp_aux] = bm.jit(vmap(self.F_fixed_point_aux))
    return self.analyzed_results[C.F_vmap_fp_aux]

  @property
@@ -304,7 +309,7 @@ class Num1DAnalyzer(LowDimAnalyzer):
      # ---
      # "X": a two-dimensional matrix: (num_batch, num_var)
      # "args": a list of one-dimensional vectors, each has the shape of (num_batch,)
      self.analyzed_results[C.F_vmap_fp_opt] = bm.jit(bm.vmap(self.F_fixed_point_opt))
      self.analyzed_results[C.F_vmap_fp_opt] = bm.jit(vmap(self.F_fixed_point_opt))
    return self.analyzed_results[C.F_vmap_fp_opt]

  def _get_fixed_points(self, candidates, *args, num_seg=None, tol_aux=1e-7, loss_screen=None):
@@ -497,7 +502,7 @@ class Num2DAnalyzer(Num1DAnalyzer):
  @property
  def F_vmap_fy(self):
    if C.F_vmap_fy not in self.analyzed_results:
      self.analyzed_results[C.F_vmap_fy] = bm.jit(bm.vmap(self.F_fy), device=self.jit_device)
      self.analyzed_results[C.F_vmap_fy] = bm.jit(vmap(self.F_fy), device=self.jit_device)
    return self.analyzed_results[C.F_vmap_fy]

  @property
@@ -659,7 +664,7 @@ class Num2DAnalyzer(Num1DAnalyzer):
    if self.F_x_by_y_in_fx is not None:
      utils.output("I am evaluating fx-nullcline by F_x_by_y_in_fx ...")
      vmap_f = bm.jit(bm.vmap(self.F_x_by_y_in_fx), device=self.jit_device)
      vmap_f = bm.jit(vmap(self.F_x_by_y_in_fx), device=self.jit_device)
      for j, pars in enumerate(par_seg):
        if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
        mesh_values = jnp.meshgrid(*((ys,) + pars))
@@ -675,7 +680,7 @@ class Num2DAnalyzer(Num1DAnalyzer):
    elif self.F_y_by_x_in_fx is not None:
      utils.output("I am evaluating fx-nullcline by F_y_by_x_in_fx ...")
      vmap_f = bm.jit(bm.vmap(self.F_y_by_x_in_fx), device=self.jit_device)
      vmap_f = bm.jit(vmap(self.F_y_by_x_in_fx), device=self.jit_device)
      for j, pars in enumerate(par_seg):
        if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
        mesh_values = jnp.meshgrid(*((xs,) + pars))
@@ -693,9 +698,9 @@ class Num2DAnalyzer(Num1DAnalyzer):
      utils.output("I am evaluating fx-nullcline by optimization ...")
      # auxiliary functions
      f2 = lambda y, x, *pars: self.F_fx(x, y, *pars)
      vmap_f2 = bm.jit(bm.vmap(f2), device=self.jit_device)
      vmap_brentq_f2 = bm.jit(bm.vmap(utils.jax_brentq(f2)), device=self.jit_device)
      vmap_brentq_f1 = bm.jit(bm.vmap(utils.jax_brentq(self.F_fx)), device=self.jit_device)
      vmap_f2 = bm.jit(vmap(f2), device=self.jit_device)
      vmap_brentq_f2 = bm.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device)
      vmap_brentq_f1 = bm.jit(vmap(utils.jax_brentq(self.F_fx)), device=self.jit_device)

      # num segments
      for _j, Ps in enumerate(par_seg):
@@ -752,7 +757,7 @@ class Num2DAnalyzer(Num1DAnalyzer):
    if self.F_x_by_y_in_fy is not None:
      utils.output("I am evaluating fy-nullcline by F_x_by_y_in_fy ...")
      vmap_f = bm.jit(bm.vmap(self.F_x_by_y_in_fy), device=self.jit_device)
      vmap_f = bm.jit(vmap(self.F_x_by_y_in_fy), device=self.jit_device)
      for j, pars in enumerate(par_seg):
        if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
        mesh_values = jnp.meshgrid(*((ys,) + pars))
@@ -768,7 +773,7 @@ class Num2DAnalyzer(Num1DAnalyzer):
    elif self.F_y_by_x_in_fy is not None:
      utils.output("I am evaluating fy-nullcline by F_y_by_x_in_fy ...")
      vmap_f = bm.jit(bm.vmap(self.F_y_by_x_in_fy), device=self.jit_device)
      vmap_f = bm.jit(vmap(self.F_y_by_x_in_fy), device=self.jit_device)
      for j, pars in enumerate(par_seg):
        if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
        mesh_values = jnp.meshgrid(*((xs,) + pars))
@@ -787,9 +792,9 @@ class Num2DAnalyzer(Num1DAnalyzer):
      # auxiliary functions
      f2 = lambda y, x, *pars: self.F_fy(x, y, *pars)
      vmap_f2 = bm.jit(bm.vmap(f2), device=self.jit_device)
      vmap_brentq_f2 = bm.jit(bm.vmap(utils.jax_brentq(f2)), device=self.jit_device)
      vmap_brentq_f1 = bm.jit(bm.vmap(utils.jax_brentq(self.F_fy)), device=self.jit_device)
      vmap_f2 = bm.jit(vmap(f2), device=self.jit_device)
      vmap_brentq_f2 = bm.jit(vmap(utils.jax_brentq(f2)), device=self.jit_device)
      vmap_brentq_f1 = bm.jit(vmap(utils.jax_brentq(self.F_fy)), device=self.jit_device)

      for j, Ps in enumerate(par_seg):
        if len(par_seg.arg_id_segments[0]) > 1: utils.output(f"{C.prefix}segment {j} ...")
@@ -837,7 +842,7 @@ class Num2DAnalyzer(Num1DAnalyzer):
    xs = self.resolutions[self.x_var].value
    ys = self.resolutions[self.y_var].value
    P = tuple(self.resolutions[p].value for p in self.target_par_names)
    f_select = bm.jit(bm.vmap(lambda vals, ids: vals[ids], in_axes=(1, 1)))
    f_select = bm.jit(vmap(lambda vals, ids: vals[ids], in_axes=(1, 1)))

    # num segments
    if isinstance(num_segments, int):
@@ -917,10 +922,10 @@ class Num2DAnalyzer(Num1DAnalyzer):
    if self.convert_type() == C.x_by_y:
      num_seg = len(self.resolutions[self.y_var])
      f_vmap = bm.jit(bm.vmap(self.F_y_convert[1]))
      f_vmap = bm.jit(vmap(self.F_y_convert[1]))
    else:
      num_seg = len(self.resolutions[self.x_var])
      f_vmap = bm.jit(bm.vmap(self.F_x_convert[1]))
      f_vmap = bm.jit(vmap(self.F_x_convert[1]))
    # get the signs
    signs = jnp.sign(f_vmap(candidates, *args))
    signs = signs.reshape((num_seg, -1))
@@ -950,10 +955,10 @@ class Num2DAnalyzer(Num1DAnalyzer):
    # get another value
    if self.convert_type() == C.x_by_y:
      y_values = fps
      x_values = bm.jit(bm.vmap(self.F_y_convert[0]))(y_values, *args)
      x_values = bm.jit(vmap(self.F_y_convert[0]))(y_values, *args)
    else:
      x_values = fps
      y_values = bm.jit(bm.vmap(self.F_x_convert[0]))(x_values, *args)
      y_values = bm.jit(vmap(self.F_x_convert[0]))(x_values, *args)
    fps = jnp.stack([x_values, y_values]).T
    return fps, selected_ids, args
@@ -3,7 +3,7 @@
from functools import partial

import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import vmap
import numpy as np

import brainpy.math as bm
@@ -11,6 +11,8 @@ from brainpy import errors
from brainpy.analysis import stability, utils, constants as C
from brainpy.analysis.lowdim.lowdim_analyzer import *

pyplot = None

__all__ = [
  'Bifurcation1D',
  'Bifurcation2D',
@@ -41,12 +43,14 @@ class Bifurcation1D(Num1DAnalyzer):
  @property
  def F_vmap_dfxdx(self):
    if C.F_vmap_dfxdx not in self.analyzed_results:
      f = bm.jit(bm.vmap(bm.vector_grad(self.F_fx, argnums=0)), device=self.jit_device)
      f = bm.jit(vmap(bm.vector_grad(self.F_fx, argnums=0)), device=self.jit_device)
      self.analyzed_results[C.F_vmap_dfxdx] = f
    return self.analyzed_results[C.F_vmap_dfxdx]

  def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
                       tol_aux=1e-8, loss_screen=None):
    global pyplot
    if pyplot is None: from matplotlib import pyplot
    utils.output('I am making bifurcation analysis ...')

    xs = self.resolutions[self.x_var]
@@ -72,21 +76,21 @@ class Bifurcation1D(Num1DAnalyzer):
          container[fp_type]['x'].append(x)

      # visualization
      plt.figure(self.x_var)
      pyplot.figure(self.x_var)
      for fp_type, points in container.items():
        if len(points['x']):
          plot_style = stability.plot_scheme[fp_type]
          plt.plot(points['p'], points['x'], '.', **plot_style, label=fp_type)
      plt.xlabel(self.target_par_names[0])
      plt.ylabel(self.x_var)
          pyplot.plot(points['p'], points['x'], '.', **plot_style, label=fp_type)
      pyplot.xlabel(self.target_par_names[0])
      pyplot.ylabel(self.x_var)

      scale = (self.lim_scale - 1) / 2
      plt.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale))
      plt.ylim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
      pyplot.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale))
      pyplot.ylim(*utils.rescale(self.target_vars[self.x_var], scale=scale))

      plt.legend()
      pyplot.legend()
      if show:
        plt.show()
        pyplot.show()

    elif len(self.target_pars) == 2:
      container = {c: {'p0': [], 'p1': [], 'x': []} for c in stability.get_1d_stability_types()}
@@ -99,7 +103,7 @@ class Bifurcation1D(Num1DAnalyzer):
            container[fp_type]['x'].append(x)

      # visualization
      fig = plt.figure(self.x_var)
      fig = pyplot.figure(self.x_var)
      ax = fig.add_subplot(projection='3d')
      for fp_type, points in container.items():
        if len(points['x']):
@@ -121,7 +125,7 @@ class Bifurcation1D(Num1DAnalyzer):
      ax.grid(True)
      ax.legend()
      if show:
        plt.show()
        pyplot.show()

    else:
      raise errors.BrainPyError(f'Cannot visualize co-dimension {len(self.target_pars)} '
@@ -156,7 +160,7 @@ class Bifurcation2D(Num2DAnalyzer):
    if C.F_vmap_jacobian not in self.analyzed_results:
      f1 = lambda xy, *args: jnp.array([self.F_fx(xy[0], xy[1], *args),
                                        self.F_fy(xy[0], xy[1], *args)])
      f2 = bm.jit(bm.vmap(bm.jacobian(f1)), device=self.jit_device)
      f2 = bm.jit(vmap(bm.jacobian(f1)), device=self.jit_device)
      self.analyzed_results[C.F_vmap_jacobian] = f2
    return self.analyzed_results[C.F_vmap_jacobian]
@@ -212,6 +216,8 @@ class Bifurcation2D(Num2DAnalyzer):
        - parameters: a 2D matrix with the shape of (num_point, num_par)
        - jacobians: a 3D tensor with the shape of (num_point, 2, 2)
    """
    global pyplot
    if pyplot is None: from matplotlib import pyplot
    utils.output('I am making bifurcation analysis ...')

    if self._can_convert_to_one_eq():
@@ -289,21 +295,21 @@ class Bifurcation2D(Num2DAnalyzer):

      # visualization
      for var in self.target_var_names:
        plt.figure(var)
        pyplot.figure(var)
        for fp_type, points in container.items():
          if len(points['p']):
            plot_style = stability.plot_scheme[fp_type]
            plt.plot(points['p'], points[var], '.', **plot_style, label=fp_type)
        plt.xlabel(self.target_par_names[0])
        plt.ylabel(var)
            pyplot.plot(points['p'], points[var], '.', **plot_style, label=fp_type)
        pyplot.xlabel(self.target_par_names[0])
        pyplot.ylabel(var)

        scale = (self.lim_scale - 1) / 2
        plt.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale))
        plt.ylim(*utils.rescale(self.target_vars[var], scale=scale))
        pyplot.xlim(*utils.rescale(self.target_pars[self.target_par_names[0]], scale=scale))
        pyplot.ylim(*utils.rescale(self.target_vars[var], scale=scale))

        plt.legend()
        pyplot.legend()
      if show:
        plt.show()
        pyplot.show()

    # bifurcation analysis of co-dimension 2
    elif len(self.target_pars) == 2:
@@ -320,7 +326,7 @@ class Bifurcation2D(Num2DAnalyzer):

      # visualization
      for var in self.target_var_names:
        fig = plt.figure(var)
        fig = pyplot.figure(var)
        ax = fig.add_subplot(projection='3d')
        for fp_type, points in container.items():
          if len(points['p0']):
@@ -340,7 +346,7 @@ class Bifurcation2D(Num2DAnalyzer):
        ax.grid(True)
        ax.legend()
      if show:
        plt.show()
        pyplot.show()

    else:
      raise ValueError('Unknown length of parameters.')
@@ -350,6 +356,8 @@ class Bifurcation2D(Num2DAnalyzer):

  def plot_limit_cycle_by_sim(self, duration=100, with_plot=True, with_return=False,
                              plot_style=None, tol=0.001, show=False, dt=None, offset=1.):
    global pyplot
    if pyplot is None: from matplotlib import pyplot
    utils.output('I am plotting the limit cycle ...')
    if self._fixed_points is None:
      utils.output('No fixed points found, you may call "plot_bifurcation(with_plot=True)" first.')
@@ -386,31 +394,33 @@ class Bifurcation2D(Num2DAnalyzer):
    # visualization
    if with_plot:
      if plot_style is None: plot_style = dict()
      fmt = plot_style.pop('fmt', '.')
      fmt = plot_style.pop('fmt', '*')

      if len(self.target_par_names) == 2:
        if len(ps_limit_cycle[0]):
          for i, var in enumerate(self.target_var_names):
            plt.figure(var)
            plt.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['max'],
            pyplot.figure(var)
            pyplot.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['max'],
                        **plot_style, label='limit cycle (max)')
            plt.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['min'],
            pyplot.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['min'],
                        **plot_style, label='limit cycle (min)')
            plt.legend()
            pyplot.legend()

      elif len(self.target_par_names) == 1:
        if len(ps_limit_cycle[0]):
          for i, var in enumerate(self.target_var_names):
            plt.figure(var)
            plt.plot(ps_limit_cycle[0], vs_limit_cycle[i]['max'], fmt,
            pyplot.figure(var)
            pyplot.plot(ps_limit_cycle[0], vs_limit_cycle[i]['max'], fmt,
                        **plot_style, label='limit cycle (max)')
            plt.plot(ps_limit_cycle[0], vs_limit_cycle[i]['min'], fmt,
            pyplot.plot(ps_limit_cycle[0], vs_limit_cycle[i]['min'], fmt,
                        **plot_style, label='limit cycle (min)')
            plt.legend()
            pyplot.legend()

      else:
        raise errors.AnalyzerError

      if show:
        plt.show()
        pyplot.show()

    if with_return:
      return vs_limit_cycle, ps_limit_cycle
@@ -437,6 +447,8 @@ class FastSlow1D(Bifurcation1D):

  def plot_trajectory(self, initials, duration, plot_durations=None,
                      dt=None, show=False, with_plot=True, with_return=False):
    global pyplot
    if pyplot is None: from matplotlib import pyplot
    utils.output('I am plotting the trajectory ...')

    # check the initial values
@@ -470,11 +482,11 @@ class FastSlow1D(Bifurcation1D):
        end = int(plot_durations[i][1] / dt)
        p1_var = self.target_par_names[0]
        if len(self.target_par_names) == 1:
          lines = plt.plot(mon_res[self.x_var][start: end, i],
          lines = pyplot.plot(mon_res[self.x_var][start: end, i],
                           mon_res[p1_var][start: end, i], label=legend)
        elif len(self.target_par_names) == 2:
          p2_var = self.target_par_names[1]
          lines = plt.plot(mon_res[self.x_var][start: end, i],
          lines = pyplot.plot(mon_res[self.x_var][start: end, i],
                           mon_res[p1_var][start: end, i],
                           mon_res[p2_var][start: end, i],
                           label=legend)
@@ -488,10 +500,10 @@ class FastSlow1D(Bifurcation1D):
      # scale = (self.lim_scale - 1.) / 2
      # plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
      # plt.ylim(*utils.rescale(self.target_vars[self.target_par_names[0]], scale=scale))
      plt.legend()
      pyplot.legend()
      if show:
        plt.show()
        pyplot.show()

    if with_return:
      return mon_res
@@ -517,6 +529,8 @@ class FastSlow2D(Bifurcation2D):

  def plot_trajectory(self, initials, duration, plot_durations=None,
                      dt=None, show=False, with_plot=True, with_return=False):
    global pyplot
    if pyplot is None: from matplotlib import pyplot
    utils.output('I am plotting the trajectory ...')

    # check the initial values
@@ -548,25 +562,25 @@ class FastSlow2D(Bifurcation2D):
        end = int(plot_durations[i][1] / dt)

        # visualization
        plt.figure(self.x_var)
        lines = plt.plot(mon_res[self.target_par_names[0]][start: end, i],
        pyplot.figure(self.x_var)
        lines = pyplot.plot(mon_res[self.target_par_names[0]][start: end, i],
                         mon_res[self.x_var][start: end, i],
                         label=legend)
        utils.add_arrow(lines[0])

        plt.figure(self.y_var)
        lines = plt.plot(mon_res[self.target_par_names[0]][start: end, i],
        pyplot.figure(self.y_var)
        lines = pyplot.plot(mon_res[self.target_par_names[0]][start: end, i],
                         mon_res[self.y_var][start: end, i],
                         label=legend)
        utils.add_arrow(lines[0])

      plt.figure(self.x_var)
      plt.legend()
      plt.figure(self.y_var)
      plt.legend()
      pyplot.figure(self.x_var)
      pyplot.legend()
      pyplot.figure(self.y_var)
      pyplot.legend()
      if show:
        plt.show()
        pyplot.show()

    if with_return:
      return mon_res
@@ -1,14 +1,16 @@
# -*- coding: utf-8 -*-

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jax import vmap

import brainpy.math as bm
from brainpy import errors, math
from brainpy.analysis import stability, constants as C, utils
from brainpy.analysis.lowdim.lowdim_analyzer import *

pyplot = None

__all__ = [
  'PhasePlane1D',
  'PhasePlane2D',
@@ -62,6 +64,8 @@ class PhasePlane1D(Num1DAnalyzer):
  def plot_vector_field(self, show=False, with_plot=True, with_return=False):
    """Plot the vector field."""
    global pyplot
    if pyplot is None: from matplotlib import pyplot
    utils.output('I am creating the vector field ...')

    # Nullcline of the x variable
@@ -72,19 +76,21 @@ class PhasePlane1D(Num1DAnalyzer):
    if with_plot:
      label = f"d{self.x_var}dt"
      x_style = dict(color='lightcoral', alpha=.7, linewidth=4)
      plt.plot(np.asarray(self.resolutions[self.x_var]), y_val, **x_style, label=label)
      plt.axhline(0)
      plt.xlabel(self.x_var)
      plt.ylabel(label)
      plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=(self.lim_scale - 1.) / 2))
      plt.legend()
      if show: plt.show()
      pyplot.plot(np.asarray(self.resolutions[self.x_var]), y_val, **x_style, label=label)
      pyplot.axhline(0)
      pyplot.xlabel(self.x_var)
      pyplot.ylabel(label)
      pyplot.xlim(*utils.rescale(self.target_vars[self.x_var], scale=(self.lim_scale - 1.) / 2))
      pyplot.legend()
      if show: pyplot.show()

    # return
    if with_return:
      return y_val

  def plot_fixed_point(self, show=False, with_plot=True, with_return=False):
    """Plot the fixed point."""
    global pyplot
    if pyplot is None: from matplotlib import pyplot
    utils.output('I am searching fixed points ...')

    # fixed points and stability analysis
@@ -102,10 +108,10 @@ class PhasePlane1D(Num1DAnalyzer):
      for fp_type, points in container.items():
        if len(points):
          plot_style = stability.plot_scheme[fp_type]
          plt.plot(points, [0] * len(points), '.', markersize=20, **plot_style, label=fp_type)
      plt.legend()
          pyplot.plot(points, [0] * len(points), '.', markersize=20, **plot_style, label=fp_type)
      pyplot.legend()
      if show:
        plt.show()
        pyplot.show()

    # return
    if with_return:
@@ -153,7 +159,7 @@ class PhasePlane2D(Num2DAnalyzer):
  @property
  def F_vmap_brentq_fy(self):
    if C.F_vmap_brentq_fy not in self.analyzed_results:
      f_opt = bm.jit(bm.vmap(utils.jax_brentq(self.F_fy)))
      f_opt = bm.jit(vmap(utils.jax_brentq(self.F_fy)))
      self.analyzed_results[C.F_vmap_brentq_fy] = f_opt
    return self.analyzed_results[C.F_vmap_brentq_fy]
@@ -178,6 +184,8 @@ class PhasePlane2D(Num2DAnalyzer):
      "units", "angles", "scale". For more settings, please check
      https://matplotlib.org/api/_as_gen/matplotlib.pyplot.quiver.html.
    """
    global pyplot
    if pyplot is None: from matplotlib import pyplot
    utils.output('I am creating the vector field ...')

    # get vector fields
@@ -197,7 +205,7 @@ class PhasePlane2D(Num2DAnalyzer):
        speed = np.sqrt(dx ** 2 + dy ** 2)
        dx = dx / speed
        dy = dy / speed
      plt.quiver(X, Y, dx, dy, **plot_style)
      pyplot.quiver(X, Y, dx, dy, **plot_style)
    elif plot_method == 'streamplot':
      if plot_style is None:
        plot_style = dict(arrowsize=1.2, density=1, color='thistle')
@@ -207,15 +215,15 @@ class PhasePlane2D(Num2DAnalyzer):
        min_width, max_width = 0.5, 5.5
        speed = np.nan_to_num(np.sqrt(dx ** 2 + dy ** 2))
        linewidth = min_width + max_width * (speed / speed.max())
      plt.streamplot(X, Y, dx, dy, linewidth=linewidth, **plot_style)
      pyplot.streamplot(X, Y, dx, dy, linewidth=linewidth, **plot_style)
    else:
      raise errors.AnalyzerError(f'Unknown plot_method "{plot_method}", '
                                 f'only supports "quiver" and "streamplot".')

    plt.xlabel(self.x_var)
    plt.ylabel(self.y_var)
    pyplot.xlabel(self.x_var)
    pyplot.ylabel(self.y_var)
    if show:
      plt.show()
      pyplot.show()

    if with_return:  # return vector fields
      return dx, dy
@@ -224,6 +232,8 @@ class PhasePlane2D(Num2DAnalyzer):
                     y_style=None, x_style=None, show=False,
                     coords=None, tol_nullcline=1e-7):
    """Plot the nullcline."""
    global pyplot
    if pyplot is None: from matplotlib import pyplot
    utils.output('I am computing fx-nullcline ...')

    if coords is None:
@@ -240,7 +250,7 @@ class PhasePlane2D(Num2DAnalyzer):
      if x_style is None:
        x_style = dict(color='cornflowerblue', alpha=.7, )
      fmt = x_style.pop('fmt', '.')
      plt.plot(x_values_in_fx, y_values_in_fx, fmt, **x_style, label=f"{self.x_var} nullcline")
      pyplot.plot(x_values_in_fx, y_values_in_fx, fmt, **x_style, label=f"{self.x_var} nullcline")

    # Nullcline of the y variable
    utils.output('I am computing fy-nullcline ...')
@@ -252,17 +262,17 @@ class PhasePlane2D(Num2DAnalyzer):
      if y_style is None:
        y_style = dict(color='lightcoral', alpha=.7, )
      fmt = y_style.pop('fmt', '.')
      plt.plot(x_values_in_fy, y_values_in_fy, fmt, **y_style, label=f"{self.y_var} nullcline")
      pyplot.plot(x_values_in_fy, y_values_in_fy, fmt, **y_style, label=f"{self.y_var} nullcline")

    if with_plot:
      plt.xlabel(self.x_var)
      plt.ylabel(self.y_var)
      pyplot.xlabel(self.x_var)
      pyplot.ylabel(self.y_var)
      scale = (self.lim_scale - 1.) / 2
      plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
      plt.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
      plt.legend()
      pyplot.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
      pyplot.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
      pyplot.legend()
      if show:
        plt.show()
        pyplot.show()

    if with_return:
      return {self.x_var: (x_values_in_fx, y_values_in_fx),
@@ -273,6 +283,8 @@ class PhasePlane2D(Num2DAnalyzer):
                       select_candidates='fx-nullcline', num_rank=100, ):
    """Plot the fixed point and analyze its stability.
    """
    global pyplot
    if pyplot is None: from matplotlib import pyplot
    utils.output('I am searching fixed points ...')

    if self._can_convert_to_one_eq():
@@ -338,10 +350,10 @@ class PhasePlane2D(Num2DAnalyzer):
      for fp_type, points in container.items():
        if len(points['x']):
          plot_style = stability.plot_scheme[fp_type]
          plt.plot(points['x'], points['y'], '.', markersize=20, **plot_style, label=fp_type)
      plt.legend()
          pyplot.plot(points['x'], points['y'], '.', markersize=20, **plot_style, label=fp_type)
      pyplot.legend()
      if show:
        plt.show()
        pyplot.show()

    if with_return:
      return fixed_points
@@ -377,7 +389,8 @@ class PhasePlane2D(Num2DAnalyzer):
    show : bool
      Whether show or not.
    """
    global pyplot
    if pyplot is None: from matplotlib import pyplot
    utils.output('I am plotting the trajectory ...')

    if axes not in ['v-v', 't-v']:
@@ -413,28 +426,31 @@ class PhasePlane2D(Num2DAnalyzer):
        start = int(plot_durations[i][0] / dt)
        end = int(plot_durations[i][1] / dt)
        if axes == 'v-v':
          lines = plt.plot(mon_res[self.x_var][start: end, i], mon_res[self.y_var][start: end, i],
          lines = pyplot.plot(mon_res[self.x_var][start: end, i],
                           mon_res[self.y_var][start: end, i],
                           label=legend, **kwargs)
          utils.add_arrow(lines[0])
        else:
          plt.plot(mon_res.ts[start: end], mon_res[self.x_var][start: end, i],
          pyplot.plot(mon_res.ts[start: end],
                      mon_res[self.x_var][start: end, i],
                      label=legend + f', {self.x_var}', **kwargs)
          plt.plot(mon_res.ts[start: end], mon_res[self.y_var][start: end, i],
          pyplot.plot(mon_res.ts[start: end],
                      mon_res[self.y_var][start: end, i],
                      label=legend + f', {self.y_var}', **kwargs)

      # visualization of others
      if axes == 'v-v':
        plt.xlabel(self.x_var)
        plt.ylabel(self.y_var)
        pyplot.xlabel(self.x_var)
        pyplot.ylabel(self.y_var)
        scale = (self.lim_scale - 1.) / 2
        plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
        plt.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
        plt.legend()
        pyplot.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
        pyplot.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
        pyplot.legend()
      else:
        plt.legend(title='Initial values')
        pyplot.legend(title='Initial values')

      if show:
        plt.show()
        pyplot.show()

    if with_return:
      return mon_res
@@ -462,6 +478,8 @@ class PhasePlane2D(Num2DAnalyzer):
    show : bool
      Whether show or not.
    """
    global pyplot
    if pyplot is None: from matplotlib import pyplot
    utils.output('I am plotting the limit cycle ...')

    # 1. format the initial values
@@ -487,18 +505,18 @@ class PhasePlane2D(Num2DAnalyzer):
        x_cycle = x_data[max_index[0]: max_index[1]]
        y_cycle = y_data[max_index[0]: max_index[1]]
        # 5.5 visualization
        lines = plt.plot(x_cycle, y_cycle, label='limit cycle')
        lines = pyplot.plot(x_cycle, y_cycle, label='limit cycle')
        utils.add_arrow(lines[0])
      else:
        utils.output(f'No limit cycle found for initial value {initial}')

    # 6. visualization
    plt.xlabel(self.x_var)
    plt.ylabel(self.y_var)
    pyplot.xlabel(self.x_var)
    pyplot.ylabel(self.y_var)
    scale = (self.lim_scale - 1.) / 2
    plt.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
    plt.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
    plt.legend()
    pyplot.xlim(*utils.rescale(self.target_vars[self.x_var], scale=scale))
    pyplot.ylim(*utils.rescale(self.target_vars[self.y_var], scale=scale))
    pyplot.legend()
    if show:
      plt.show()
      pyplot.show()
@@ -2,6 +2,7 @@
import jax.numpy as jnp
import numpy as np

from brainpy.tools.others import numba_jit

__all__ = [
@@ -10,7 +11,7 @@ __all__ = [
]

# @tools.numba_jit
@numba_jit
def _f1(arr, grad, tol):
  condition = np.logical_and(grad[:-1] * grad[1:] <= 0, grad[:-1] >= 0)
  indexes = np.where(condition)[0]
@@ -19,7 +20,8 @@ def _f1(arr, grad, tol):
    length = np.max(data) - np.min(data)
    a = arr[indexes[-2]]
    b = arr[indexes[-1]]
    if np.abs(a - b) <= tol * length:
    # TODO: how to choose length threshold, 1e-3?
    if length > 1e-3 and np.abs(a - b) <= tol * length:
      return indexes[-2:]
  return np.array([-1, -1])
@@ -49,8 +49,7 @@ def model_transform(model): | |||
new_model = [] | |||
for intg in model: | |||
if isinstance(intg.f, JointEq): | |||
new_model.extend([type(intg)(eq, var_type=intg.var_type, dt=intg.dt, dyn_var=intg.dyn_var) | |||
for eq in intg.f.eqs]) | |||
new_model.extend([type(intg)(eq, var_type=intg.var_type, dt=intg.dt) for eq in intg.f.eqs]) | |||
else: | |||
new_model.append(intg) | |||
@@ -4,7 +4,7 @@ | |||
import jax.lax | |||
import jax.numpy as jnp | |||
import numpy as np | |||
from jax import grad, jit | |||
from jax import grad, jit, vmap | |||
from jax.flatten_util import ravel_pytree | |||
import brainpy.math as bm | |||
@@ -197,7 +197,7 @@ def brentq_candidates(vmap_f, *values, args=()): | |||
def brentq_roots(f, starts, ends, *vmap_args, args=()): | |||
in_axes = (0, 0, tuple([0] * len(vmap_args)) + tuple([None] * len(args))) | |||
vmap_f_opt = bm.jit(bm.vmap(jax_brentq(f), in_axes=in_axes)) | |||
vmap_f_opt = bm.jit(vmap(jax_brentq(f), in_axes=in_axes)) | |||
all_args = vmap_args + args | |||
if len(all_args): | |||
res = vmap_f_opt(starts, ends, all_args) | |||
@@ -397,7 +397,7 @@ def roots_of_1d_by_x(f, candidates, args=()): | |||
return fps | |||
starts = candidates[candidate_ids] | |||
ends = candidates[candidate_ids + 1] | |||
f_opt = bm.jit(bm.vmap(jax_brentq(f), in_axes=(0, 0, None))) | |||
f_opt = bm.jit(vmap(jax_brentq(f), in_axes=(0, 0, None))) | |||
res = f_opt(starts, ends, args) | |||
valid_idx = jnp.where(res['status'] == ECONVERGED)[0] | |||
fps2 = res['root'][valid_idx] | |||
@@ -406,7 +406,7 @@ def roots_of_1d_by_x(f, candidates, args=()): | |||
def roots_of_1d_by_xy(f, starts, ends, args): | |||
f = f_without_jaxarray_return(f) | |||
f_opt = bm.jit(bm.vmap(jax_brentq(f))) | |||
f_opt = bm.jit(vmap(jax_brentq(f))) | |||
res = f_opt(starts, ends, (args,)) | |||
valid_idx = jnp.where(res['status'] == ECONVERGED)[0] | |||
xs = res['root'][valid_idx] | |||
@@ -1,6 +1,7 @@ | |||
# -*- coding: utf-8 -*- | |||
import jax.numpy as jnp | |||
from jax import vmap | |||
import numpy as np | |||
import brainpy.math as bm | |||
@@ -76,7 +77,7 @@ def get_sign(f, xs, ys): | |||
def get_sign2(f, *xyz, args=()): | |||
in_axes = tuple(range(len(xyz))) + tuple([None] * len(args)) | |||
f = bm.jit(bm.vmap(f_without_jaxarray_return(f), in_axes=in_axes)) | |||
f = bm.jit(vmap(f_without_jaxarray_return(f), in_axes=in_axes)) | |||
xyz = tuple((v.value if isinstance(v, bm.JaxArray) else v) for v in xyz) | |||
XYZ = jnp.meshgrid(*xyz) | |||
XYZ = tuple(jnp.moveaxis(v, 1, 0).flatten() for v in XYZ) | |||
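These hunks swap `bm.vmap` for JAX's own `vmap` while keeping `bm.jit` for compilation. A tiny self-contained sketch of the jit-plus-vmap composition the helpers rely on (the function `f` is illustrative):

import jax.numpy as jnp
from jax import vmap
import brainpy.math as bm

def f(x, y):
  return x * y

# vectorize over the first argument only, then compile the batched function
f_batched = bm.jit(vmap(f, in_axes=(0, None)))
print(f_batched(jnp.arange(4.), 2.))  # [0. 2. 4. 6.]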
@@ -0,0 +1,27 @@ | |||
# -*- coding: utf-8 -*- | |||
__all__ = [ | |||
'is_checking', | |||
'turn_on', | |||
'turn_off', | |||
] | |||
_check = True | |||
def is_checking(): | |||
"""Whether the checking is turn on.""" | |||
return _check | |||
def turn_on(): | |||
"""Turn on the checking.""" | |||
global _check | |||
_check = True | |||
def turn_off(): | |||
"""Turn off the checking.""" | |||
global _check | |||
_check = False |
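A usage sketch of this global switch; the import path is an assumption, since the diff does not show where the module lives:

# hypothetical import path for the checking switch defined above
from brainpy import check

assert check.is_checking()   # checking is on by default
check.turn_off()             # e.g. skip argument validation for speed
assert not check.is_checking()
check.turn_on()              # restore the default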
@@ -1,12 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
from brainpy.dyn import LIF, AdExIF, Izhikevich, ExpCOBA, ExpCUBA, DeltaSynapse | |||
__all__ = [ | |||
'LIF', | |||
'AdExIF', | |||
'Izhikevich', | |||
'ExpCOBA', | |||
'ExpCUBA', | |||
'DeltaSynapse', | |||
] |
@@ -1,11 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
from brainpy.integrators.runner import IntegratorRunner | |||
from brainpy.dyn.runners import DSRunner, StructRunner, ReportRunner | |||
__all__ = [ | |||
'IntegratorRunner', | |||
'DSRunner', | |||
'StructRunner', | |||
'ReportRunner' | |||
] |
@@ -15,6 +15,11 @@ __all__ = [ | |||
class DynamicalSystem(dyn.DynamicalSystem): | |||
"""Dynamical System. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.dyn.DynamicalSystem" instead. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
warnings.warn('Please use "brainpy.dyn.DynamicalSystem" instead. ' | |||
'"brainpy.DynamicalSystem" is deprecated since ' | |||
@@ -23,6 +28,11 @@ class DynamicalSystem(dyn.DynamicalSystem): | |||
class Container(dyn.Container): | |||
"""Container. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.dyn.Container" instead. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
warnings.warn('Please use "brainpy.dyn.Container" instead. ' | |||
'"brainpy.Container" is deprecated since ' | |||
@@ -31,6 +41,11 @@ class Container(dyn.Container): | |||
class Network(dyn.Network): | |||
"""Network. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.dyn.Network" instead. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
warnings.warn('Please use "brainpy.dyn.Network" instead. ' | |||
'"brainpy.Network" is deprecated since ' | |||
@@ -39,6 +54,11 @@ class Network(dyn.Network): | |||
class ConstantDelay(dyn.ConstantDelay): | |||
"""Constant Delay. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.dyn.ConstantDelay" instead. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
warnings.warn('Please use "brainpy.dyn.ConstantDelay" instead. ' | |||
'"brainpy.ConstantDelay" is deprecated since ' | |||
@@ -47,6 +67,11 @@ class ConstantDelay(dyn.ConstantDelay): | |||
class NeuGroup(dyn.NeuGroup): | |||
"""Neuron group. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.dyn.NeuGroup" instead. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
warnings.warn('Please use "brainpy.dyn.NeuGroup" instead. ' | |||
'"brainpy.NeuGroup" is deprecated since ' | |||
@@ -55,6 +80,11 @@ class NeuGroup(dyn.NeuGroup): | |||
class TwoEndConn(dyn.TwoEndConn): | |||
"""Two-end synaptic connection. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.dyn.TwoEndConn" instead. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
warnings.warn('Please use "brainpy.dyn.TwoEndConn" instead. ' | |||
'"brainpy.TwoEndConn" is deprecated since ' |
@@ -13,6 +13,11 @@ __all__ = [ | |||
def set_default_odeint(method): | |||
"""Set default ode integrator. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.ode.set_default_odeint" instead. | |||
""" | |||
warnings.warn('Please use "brainpy.ode.set_default_odeint" instead. ' | |||
'"brainpy.set_default_odeint" is deprecated since ' | |||
'version 2.1.0', DeprecationWarning) | |||
@@ -20,6 +25,11 @@ def set_default_odeint(method): | |||
def get_default_odeint(): | |||
"""Get default ode integrator. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.ode.get_default_odeint" instead. | |||
""" | |||
warnings.warn('Please use "brainpy.ode.get_default_odeint" instead. ' | |||
'"brainpy.get_default_odeint" is deprecated since ' | |||
'version 2.1.0', DeprecationWarning) | |||
@@ -27,6 +37,11 @@ def get_default_odeint(): | |||
def set_default_sdeint(method): | |||
"""Set default sde integrator. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.ode.set_default_sdeint" instead. | |||
""" | |||
warnings.warn('Please use "brainpy.sde.set_default_sdeint" instead. ' | |||
'"brainpy.set_default_sdeint" is deprecated since ' | |||
'version 2.1.0', DeprecationWarning) | |||
@@ -34,6 +49,11 @@ def set_default_sdeint(method): | |||
def get_default_sdeint(): | |||
"""Get default sde integrator. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.ode.get_default_sdeint" instead. | |||
""" | |||
warnings.warn('Please use "brainpy.sde.get_default_sdeint" instead. ' | |||
'"brainpy.get_default_sdeint" is deprecated since ' | |||
'version 2.1.0', DeprecationWarning) |
@@ -23,7 +23,10 @@ def _check_args(args): | |||
class Module(Base): | |||
"""Basic module class.""" | |||
"""Basic module class. | |||
.. deprecated:: 2.1.0
Please use "brainpy.rnns.Module" instead.
"""
@staticmethod | |||
def get_param(param, size): | |||
@@ -47,7 +50,7 @@ class Module(Base): | |||
def __init__(self, name=None): # initialize parameters | |||
warnings.warn('Please use "brainpy.rnns.Module" instead. ' | |||
'"brainpy.layers.Module" is deprecated since ' | |||
'version 2.0.3.', DeprecationWarning) | |||
'version 2.1.0.', DeprecationWarning) | |||
super(Module, self).__init__(name=name) | |||
def __call__(self, *args, **kwargs): # initialize variables |
@@ -0,0 +1,98 @@ | |||
# -*- coding: utf-8 -*- | |||
import warnings | |||
from brainpy.dyn import neurons, synapses | |||
__all__ = [ | |||
'LIF', | |||
'AdExIF', | |||
'Izhikevich', | |||
'ExpCOBA', | |||
'ExpCUBA', | |||
'DeltaSynapse', | |||
] | |||
class LIF(neurons.LIF): | |||
"""LIF neuron model. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.dyn.LIF" instead. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
warnings.warn('Please use "brainpy.dyn.LIF" instead. ' | |||
'"brainpy.models.LIF" is deprecated since ' | |||
'version 2.1.0', DeprecationWarning) | |||
super(LIF, self).__init__(*args, **kwargs) | |||
class AdExIF(neurons.AdExIF): | |||
"""AdExIF neuron model. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.dyn.AdExIF" instead. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
warnings.warn('Please use "brainpy.dyn.AdExIF" instead. ' | |||
'"brainpy.models.AdExIF" is deprecated since ' | |||
'version 2.1.0', DeprecationWarning) | |||
super(AdExIF, self).__init__(*args, **kwargs) | |||
class Izhikevich(neurons.Izhikevich): | |||
"""Izhikevich neuron model. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.dyn.Izhikevich" instead. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
warnings.warn('Please use "brainpy.dyn.Izhikevich" instead. ' | |||
'"brainpy.models.Izhikevich" is deprecated since ' | |||
'version 2.1.0', DeprecationWarning) | |||
super(Izhikevich, self).__init__(*args, **kwargs) | |||
class ExpCOBA(synapses.ExpCOBA): | |||
"""ExpCOBA synapse model. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.dyn.ExpCOBA" instead. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
warnings.warn('Please use "brainpy.dyn.ExpCOBA" instead. ' | |||
'"brainpy.models.ExpCOBA" is deprecated since ' | |||
'version 2.1.0', DeprecationWarning) | |||
super(ExpCOBA, self).__init__(*args, **kwargs) | |||
class ExpCUBA(synapses.ExpCUBA): | |||
"""ExpCUBA synapse model. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.dyn.ExpCUBA" instead. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
warnings.warn('Please use "brainpy.dyn.ExpCUBA" instead. ' | |||
'"brainpy.models.ExpCUBA" is deprecated since ' | |||
'version 2.1.0', DeprecationWarning) | |||
super(ExpCUBA, self).__init__(*args, **kwargs) | |||
class DeltaSynapse(synapses.DeltaSynapse): | |||
"""Delta synapse model. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.dyn.DeltaSynapse" instead. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
warnings.warn('Please use "brainpy.dyn.DeltaSynapse" instead. ' | |||
'"brainpy.models.DeltaSynapse" is deprecated since ' | |||
'version 2.1.0', DeprecationWarning) | |||
super(DeltaSynapse, self).__init__(*args, **kwargs) |
@@ -9,8 +9,13 @@ __all__ = [ | |||
class Monitor(monitor.Monitor): | |||
"""Monitor class. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.running.Monitor" instead. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
super(Monitor, self).__init__(*args, **kwargs) | |||
warnings.warn('Please use "brainpy.running.Monitor" instead. ' | |||
'"brainpy.Monitor" is deprecated since version 2.0.3.', | |||
'"brainpy.Monitor" is deprecated since version 2.1.0.', | |||
DeprecationWarning) | |||
super(Monitor, self).__init__(*args, **kwargs) |
@@ -0,0 +1,65 @@ | |||
# -*- coding: utf-8 -*- | |||
import warnings | |||
from brainpy.dyn import runners as dyn_runner | |||
from brainpy.integrators import runner as intg_runner | |||
__all__ = [ | |||
'IntegratorRunner', | |||
'DSRunner', | |||
'StructRunner', | |||
'ReportRunner' | |||
] | |||
class IntegratorRunner(intg_runner.IntegratorRunner): | |||
"""Integrator runner class. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.integrators.IntegratorRunner" instead. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
warnings.warn('Please use "brainpy.integrators.IntegratorRunner" instead. ' | |||
'"brainpy.IntegratorRunner" is deprecated since ' | |||
'version 2.1.0', DeprecationWarning) | |||
super(IntegratorRunner, self).__init__(*args, **kwargs) | |||
class DSRunner(dyn_runner.DSRunner): | |||
"""Dynamical system runner class. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.dyn.DSRunner" instead. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
warnings.warn('Please use "brainpy.dyn.DSRunner" instead. ' | |||
'"brainpy.DSRunner" is deprecated since ' | |||
'version 2.1.0', DeprecationWarning) | |||
super(DSRunner, self).__init__(*args, **kwargs) | |||
class StructRunner(dyn_runner.StructRunner): | |||
"""Dynamical system runner class. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.dyn.StructRunner" instead. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
warnings.warn('Please use "brainpy.dyn.StructRunner" instead. ' | |||
'"brainpy.StructRunner" is deprecated since ' | |||
'version 2.1.0', DeprecationWarning) | |||
super(StructRunner, self).__init__(*args, **kwargs) | |||
class ReportRunner(dyn_runner.ReportRunner): | |||
"""Dynamical system runner class. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.dyn.ReportRunner" instead. | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
warnings.warn('Please use "brainpy.dyn.ReportRunner" instead. ' | |||
'"brainpy.ReportRunner" is deprecated since ' | |||
'version 2.1.0', DeprecationWarning) | |||
super(ReportRunner, self).__init__(*args, **kwargs) |
@@ -14,7 +14,7 @@ def test_one2one(): | |||
num = bp.tools.size2num(size) | |||
actual_mat = bp.math.zeros((num, num), dtype=bp.math.bool_) | |||
actual_mat = bp.math.fill_diagonal(actual_mat, True) | |||
bp.math.fill_diagonal(actual_mat, True) | |||
assert bp.math.array_equal(actual_mat, conn_mat) | |||
assert bp.math.array_equal(pre_ids, bp.math.arange(num)) | |||
@@ -42,7 +42,7 @@ def test_all2all(): | |||
print(mat) | |||
actual_mat = bp.math.ones((num, num), dtype=bp.math.bool_) | |||
if not has_self: | |||
actual_mat = bp.math.fill_diagonal(actual_mat, False) | |||
bp.math.fill_diagonal(actual_mat, False) | |||
assert bp.math.array_equal(actual_mat, mat) | |||
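These test updates reflect that `bp.math.fill_diagonal` now mutates its argument in place instead of returning a new array. A short sketch of the updated usage (shape and values are illustrative):

import brainpy as bp

mat = bp.math.zeros((4, 4), dtype=bp.math.bool_)
# fill_diagonal works in place; its return value is no longer assigned
bp.math.fill_diagonal(mat, True)
assert bool(mat[0, 0]) and not bool(mat[0, 1])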
@@ -167,8 +167,8 @@ def mackey_glass_series(duration, dt=0.1, beta=2., gamma=1., tau=2., n=9.65, | |||
assert isinstance(inits, (bm.ndarray, jnp.ndarray)) | |||
rng = bm.random.RandomState(seed) | |||
xdelay = bm.FixedLenDelay(inits.shape, tau, dt=dt) | |||
xdelay.data = inits + 0.2 * (rng.random((xdelay.num_delay_steps,) + inits.shape) - 0.5) | |||
xdelay = bm.TimeDelay(inits, tau, dt=dt) | |||
xdelay.data = inits + 0.2 * (rng.random((xdelay.num_delay_step,) + inits.shape) - 0.5) | |||
@ddeint(method=method, state_delays={'x': xdelay}) | |||
def mg_eq(x, t): | |||
@@ -1,752 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
import brainpy.math as bm | |||
from brainpy.integrators.joint_eq import JointEq | |||
from brainpy.integrators.ode import odeint | |||
from brainpy.dyn.base import NeuGroup | |||
__all__ = [ | |||
'LIF', | |||
'ExpIF', | |||
'AdExIF', | |||
'QuaIF', | |||
'AdQuaIF', | |||
'GIF', | |||
] | |||
class LIF(NeuGroup): | |||
r"""Leaky integrate-and-fire neuron model. | |||
**Model Descriptions** | |||
The formal equations of a LIF model [1]_ are given by:
.. math:: | |||
\tau \frac{dV}{dt} = - (V(t) - V_{rest}) + I(t) \\ | |||
\text{after} \quad V(t) \gt V_{th}, V(t) = V_{reset} \quad | |||
\text{last} \quad \tau_{ref} \quad \text{ms} | |||
where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting | |||
membrane potential, :math:`V_{reset}` is the reset membrane potential, | |||
:math:`V_{th}` is the spike threshold, :math:`\tau` is the time constant, | |||
:math:`\tau_{ref}` is the refractory time period, | |||
and :math:`I` is the time-variant synaptic inputs. | |||
**Model Examples** | |||
- `(Brette, Romain. 2004) LIF phase locking <https://brainpy-examples.readthedocs.io/en/latest/neurons/Romain_2004_LIF_phase_locking.html>`_ | |||
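A minimal usage sketch in the same style as the other models in this file (the input amplitude is illustrative):
.. plot::
:include-source: True
>>> import brainpy as bp
>>> group = bp.dyn.LIF(1)
>>> runner = bp.dyn.DSRunner(group, monitors=['V'], inputs=('input', 22.))
>>> runner.run(200.)
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', show=True)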
**Model Parameters** | |||
============= ============== ======== ========================================= | |||
**Parameter** **Init Value** **Unit** **Explanation** | |||
------------- -------------- -------- ----------------------------------------- | |||
V_rest 0 mV Resting membrane potential. | |||
V_reset -5 mV Reset potential after spike. | |||
V_th 20 mV Threshold potential of spike. | |||
tau 10 ms Membrane time constant. Compute by R * C. | |||
tau_ref 1 ms Refractory period length.
============= ============== ======== ========================================= | |||
**Neuron Variables** | |||
================== ================= ========================================================= | |||
**Variables name** **Initial Value** **Explanation** | |||
------------------ ----------------- --------------------------------------------------------- | |||
V 0 Membrane potential. | |||
input 0 External and synaptic input current. | |||
spike False Flag to mark whether the neuron is spiking. | |||
refractory False Flag to mark whether the neuron is in refractory period. | |||
t_last_spike -1e7 Last spike time stamp. | |||
================== ================= ========================================================= | |||
**References** | |||
.. [1] Abbott, Larry F. "Lapicque’s introduction of the integrate-and-fire model | |||
neuron (1907)." Brain research bulletin 50, no. 5-6 (1999): 303-304. | |||
""" | |||
def __init__(self, size, V_rest=0., V_reset=-5., V_th=20., tau=10., | |||
tau_ref=1., method='exp_auto', name=None): | |||
# initialization | |||
super(LIF, self).__init__(size=size, name=name) | |||
# parameters | |||
self.V_rest = V_rest | |||
self.V_reset = V_reset | |||
self.V_th = V_th | |||
self.tau = tau | |||
self.tau_ref = tau_ref | |||
# variables | |||
self.V = bm.Variable(bm.zeros(self.num)) | |||
self.input = bm.Variable(bm.zeros(self.num)) | |||
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) | |||
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
# integral | |||
self.integral = odeint(method=method, f=self.derivative) | |||
def derivative(self, V, t, I_ext): | |||
dvdt = (-V + self.V_rest + I_ext) / self.tau | |||
return dvdt | |||
def update(self, _t, _dt): | |||
refractory = (_t - self.t_last_spike) <= self.tau_ref | |||
V = self.integral(self.V, _t, self.input, dt=_dt) | |||
V = bm.where(refractory, self.V, V) | |||
spike = V >= self.V_th | |||
self.t_last_spike.value = bm.where(spike, _t, self.t_last_spike) | |||
self.V.value = bm.where(spike, self.V_reset, V) | |||
self.refractory.value = bm.logical_or(refractory, spike) | |||
self.spike.value = spike | |||
self.input[:] = 0. | |||
class ExpIF(NeuGroup): | |||
r"""Exponential integrate-and-fire neuron model. | |||
**Model Descriptions** | |||
In the exponential integrate-and-fire model [1]_, the differential | |||
equation for the membrane potential is given by | |||
.. math:: | |||
\tau\frac{d V}{d t}= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} + RI(t), \\ | |||
\text{after} \, V(t) \gt V_{th}, V(t) = V_{reset} \, \text{last} \, \tau_{ref} \, \text{ms} | |||
This equation has an exponential nonlinearity with "sharpness" parameter :math:`\Delta_{T}` | |||
and "threshold" :math:`\vartheta_{rh}`. | |||
The moment when the membrane potential reaches the numerical threshold :math:`V_{th}` | |||
defines the firing time :math:`t^{(f)}`. After firing, the membrane potential is reset to | |||
:math:`V_{rest}` and integration restarts at time :math:`t^{(f)}+\tau_{\rm ref}`, | |||
where :math:`\tau_{\rm ref}` is an absolute refractory time. | |||
If the numerical threshold is chosen sufficiently high, :math:`V_{th}\gg v+\Delta_T`, | |||
its exact value does not play any role. The reason is that the upswing of the action | |||
potential for :math:`v\gg v +\Delta_{T}` is so rapid, that it goes to infinity in | |||
an incredibly short time. The threshold :math:`V_{th}` is introduced mainly for numerical | |||
convenience. For a formal mathematical analysis of the model, the threshold can be pushed | |||
to infinity. | |||
The model was first introduced by Nicolas Fourcaud-Trocmé, David Hansel, Carl van Vreeswijk | |||
and Nicolas Brunel [1]_. The exponential nonlinearity was later confirmed by Badel et al. [3]_. | |||
It is one of the prominent examples of a precise theoretical prediction in computational | |||
neuroscience that was later confirmed by experimental neuroscience. | |||
Two important remarks: | |||
- (i) The right-hand side of the above equation contains a nonlinearity | |||
that can be directly extracted from experimental data [3]_. In this sense the exponential | |||
nonlinearity is not an arbitrary choice but directly supported by experimental evidence. | |||
- (ii) Even though it is a nonlinear model, it is simple enough to calculate the firing | |||
rate for constant input, and the linear response to fluctuations, even in the presence | |||
of input noise [4]_. | |||
**Model Examples** | |||
.. plot:: | |||
:include-source: True | |||
>>> import brainpy as bp | |||
>>> group = bp.dyn.ExpIF(1) | |||
>>> runner = bp.dyn.DSRunner(group, monitors=['V'], inputs=('input', 10.)) | |||
>>> runner.run(300., ) | |||
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', show=True) | |||
**Model Parameters** | |||
============= ============== ======== =================================================== | |||
**Parameter** **Init Value** **Unit** **Explanation** | |||
------------- -------------- -------- --------------------------------------------------- | |||
V_rest -65 mV Resting potential. | |||
V_reset -68 mV Reset potential after spike. | |||
V_th -30 mV Threshold potential of spike. | |||
V_T -59.9 mV Threshold potential of generating action potential. | |||
delta_T 3.48 \ Spike slope factor. | |||
R 1 \ Membrane resistance. | |||
tau 10 \ Membrane time constant. Compute by R * C. | |||
tau_ref 1.7 \ Refractory period length. | |||
============= ============== ======== =================================================== | |||
**Model Variables** | |||
================== ================= ========================================================= | |||
**Variables name** **Initial Value** **Explanation** | |||
------------------ ----------------- --------------------------------------------------------- | |||
V 0 Membrane potential. | |||
input 0 External and synaptic input current. | |||
spike False Flag to mark whether the neuron is spiking. | |||
refractory False Flag to mark whether the neuron is in refractory period. | |||
t_last_spike -1e7 Last spike time stamp. | |||
================== ================= ========================================================= | |||
**References** | |||
.. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation | |||
mechanisms determine the neuronal response to fluctuating | |||
inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. | |||
.. [2] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014). | |||
Neuronal dynamics: From single neurons to networks and models | |||
of cognition. Cambridge University Press. | |||
.. [3] Badel, Laurent, Sandrine Lefort, Romain Brette, Carl CH Petersen, | |||
Wulfram Gerstner, and Magnus JE Richardson. "Dynamic IV curves | |||
are reliable predictors of naturalistic pyramidal-neuron voltage | |||
traces." Journal of Neurophysiology 99, no. 2 (2008): 656-666. | |||
.. [4] Richardson, Magnus JE. "Firing-rate response of linear and nonlinear | |||
integrate-and-fire neurons to modulated current-based and | |||
conductance-based synaptic drive." Physical Review E 76, no. 2 (2007): 021919. | |||
.. [5] https://en.wikipedia.org/wiki/Exponential_integrate-and-fire | |||
""" | |||
def __init__(self, size, V_rest=-65., V_reset=-68., V_th=-30., V_T=-59.9, delta_T=3.48, | |||
R=1., tau=10., tau_ref=1.7, method='exp_auto', name=None): | |||
# initialize | |||
super(ExpIF, self).__init__(size=size, name=name) | |||
# parameters | |||
self.V_rest = V_rest | |||
self.V_reset = V_reset | |||
self.V_th = V_th | |||
self.V_T = V_T | |||
self.delta_T = delta_T | |||
self.R = R | |||
self.tau = tau | |||
self.tau_ref = tau_ref | |||
# variables | |||
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
# variables | |||
self.V = bm.Variable(bm.zeros(self.num)) | |||
self.input = bm.Variable(bm.zeros(self.num)) | |||
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) | |||
# integral | |||
self.integral = odeint(method=method, f=self.derivative) | |||
def derivative(self, V, t, I_ext): | |||
exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) | |||
dvdt = (- (V - self.V_rest) + exp_v + self.R * I_ext) / self.tau | |||
return dvdt | |||
def update(self, _t, _dt): | |||
refractory = (_t - self.t_last_spike) <= self.tau_ref | |||
V = self.integral(self.V, _t, self.input, dt=_dt) | |||
V = bm.where(refractory, self.V, V) | |||
spike = self.V_th <= V | |||
self.t_last_spike.value = bm.where(spike, _t, self.t_last_spike) | |||
self.V.value = bm.where(spike, self.V_reset, V) | |||
self.refractory.value = bm.logical_or(refractory, spike) | |||
self.spike.value = spike | |||
self.input[:] = 0. | |||
class AdExIF(NeuGroup): | |||
r"""Adaptive exponential integrate-and-fire neuron model. | |||
**Model Descriptions** | |||
The **adaptive exponential integrate-and-fire model**, also called AdEx, is a | |||
spiking neuron model with two variables [1]_ [2]_. | |||
.. math:: | |||
\begin{aligned} | |||
\tau_m\frac{d V}{d t} &= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} - Rw + RI(t), \\ | |||
\tau_w \frac{d w}{d t} &=a(V-V_{rest}) - w | |||
\end{aligned} | |||
once the membrane potential reaches the spike threshold, | |||
.. math:: | |||
V \rightarrow V_{reset}, \\ | |||
w \rightarrow w+b. | |||
The first equation describes the dynamics of the membrane potential and includes | |||
an activation term with an exponential voltage dependence. Voltage is coupled to | |||
a second equation which describes adaptation. Both variables are reset if an action | |||
potential has been triggered. The combination of adaptation and exponential voltage | |||
dependence gives rise to the name Adaptive Exponential Integrate-and-Fire model. | |||
The adaptive exponential integrate-and-fire model is capable of describing known | |||
neuronal firing patterns, e.g., adapting, bursting, delayed spike initiation, | |||
initial bursting, fast spiking, and regular spiking. | |||
**Model Examples** | |||
- `Examples for different firing patterns <https://brainpy-examples.readthedocs.io/en/latest/neurons/AdExIF_model.html>`_ | |||
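A minimal usage sketch in the same style as the other models in this file (the input amplitude is illustrative):
.. plot::
:include-source: True
>>> import brainpy as bp
>>> group = bp.dyn.AdExIF(1)
>>> runner = bp.dyn.DSRunner(group, monitors=['V', 'w'], inputs=('input', 65.))
>>> runner.run(300.)
>>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
>>> fig.add_subplot(gs[0, 0])
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V')
>>> fig.add_subplot(gs[1, 0])
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.w, ylabel='w', show=True)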
**Model Parameters** | |||
============= ============== ======== ======================================================================================================================== | |||
**Parameter** **Init Value** **Unit** **Explanation** | |||
------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------ | |||
V_rest -65 mV Resting potential. | |||
V_reset -68 mV Reset potential after spike. | |||
V_th -30 mV Threshold potential of spike and reset. | |||
V_T -59.9 mV Threshold potential of generating action potential. | |||
delta_T 3.48 \ Spike slope factor. | |||
a 1 \ The sensitivity of the recovery variable :math:`u` to the sub-threshold fluctuations of the membrane potential :math:`v` | |||
b 1 \ The increment of :math:`w` produced by a spike. | |||
R 1 \ Membrane resistance. | |||
tau 10 ms Membrane time constant. Compute by R * C. | |||
tau_w 30 ms Time constant of the adaptation current. | |||
============= ============== ======== ======================================================================================================================== | |||
**Model Variables** | |||
================== ================= ========================================================= | |||
**Variables name** **Initial Value** **Explanation** | |||
------------------ ----------------- --------------------------------------------------------- | |||
V 0 Membrane potential. | |||
w 0 Adaptation current. | |||
input 0 External and synaptic input current. | |||
spike False Flag to mark whether the neuron is spiking. | |||
t_last_spike -1e7 Last spike time stamp. | |||
================== ================= ========================================================= | |||
**References** | |||
.. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation | |||
mechanisms determine the neuronal response to fluctuating | |||
inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. | |||
.. [2] http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model | |||
""" | |||
def __init__(self, size, V_rest=-65., V_reset=-68., V_th=-30., V_T=-59.9, delta_T=3.48, a=1., | |||
b=1., tau=10., tau_w=30., R=1., method='exp_auto', name=None): | |||
super(AdExIF, self).__init__(size=size, name=name) | |||
# parameters | |||
self.V_rest = V_rest | |||
self.V_reset = V_reset | |||
self.V_th = V_th | |||
self.V_T = V_T | |||
self.delta_T = delta_T | |||
self.a = a | |||
self.b = b | |||
self.tau = tau | |||
self.tau_w = tau_w | |||
self.R = R | |||
# variables | |||
self.w = bm.Variable(bm.zeros(self.num)) | |||
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.V = bm.Variable(bm.zeros(self.num)) | |||
self.input = bm.Variable(bm.zeros(self.num)) | |||
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) | |||
# functions | |||
self.integral = odeint(method=method, f=self.derivative) | |||
def dV(self, V, t, w, I_ext): | |||
dVdt = (- V + self.V_rest + self.delta_T * bm.exp((V - self.V_T) / self.delta_T) - | |||
self.R * w + self.R * I_ext) / self.tau | |||
return dVdt | |||
def dw(self, w, t, V): | |||
dwdt = (self.a * (V - self.V_rest) - w) / self.tau_w | |||
return dwdt | |||
@property | |||
def derivative(self): | |||
return JointEq([self.dV, self.dw]) | |||
def update(self, _t, _dt): | |||
V, w = self.integral(self.V, self.w, _t, self.input, dt=_dt) | |||
spike = V >= self.V_th | |||
self.t_last_spike[:] = bm.where(spike, _t, self.t_last_spike) | |||
self.V.value = bm.where(spike, self.V_reset, V) | |||
self.w.value = bm.where(spike, w + self.b, w) | |||
self.spike.value = spike | |||
self.input[:] = 0. | |||
class QuaIF(NeuGroup): | |||
r"""Quadratic Integrate-and-Fire neuron model. | |||
**Model Descriptions** | |||
In contrast to physiologically accurate but computationally expensive | |||
neuron models like the Hodgkin–Huxley model, the QIF model [1]_ seeks only | |||
to produce **action potential-like patterns** and ignores subtleties | |||
like gating variables, which play an important role in generating action | |||
potentials in a real neuron. However, the QIF model is incredibly easy
to implement and compute, and relatively straightforward to study and
understand; it has thus found ubiquitous use in computational neuroscience.
.. math:: | |||
\tau \frac{d V}{d t}=c(V-V_{rest})(V-V_c) + RI(t) | |||
where the parameters are taken to be :math:`c=0.07` and :math:`V_c = -50` mV (Latham et al., 2000).
**Model Examples** | |||
.. plot:: | |||
:include-source: True | |||
>>> import brainpy as bp | |||
>>> | |||
>>> group = bp.dyn.QuaIF(1,) | |||
>>> | |||
>>> runner = bp.dyn.DSRunner(group, monitors=['V'], inputs=('input', 20.)) | |||
>>> runner.run(duration=200.) | |||
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True) | |||
**Model Parameters** | |||
============= ============== ======== ======================================================================================================================== | |||
**Parameter** **Init Value** **Unit** **Explanation** | |||
------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------ | |||
V_rest -65 mV Resting potential. | |||
V_reset -68 mV Reset potential after spike. | |||
V_th -30 mV Threshold potential of spike and reset. | |||
V_c -50 mV Critical voltage for spike initiation. Must be larger than V_rest. | |||
c .07 \ Coefficient describes membrane potential update. Larger than 0. | |||
R 1 \ Membrane resistance. | |||
tau 10 ms Membrane time constant. Compute by R * C. | |||
tau_ref 0 ms Refractory period length. | |||
============= ============== ======== ======================================================================================================================== | |||
**Model Variables** | |||
================== ================= ========================================================= | |||
**Variables name** **Initial Value** **Explanation** | |||
------------------ ----------------- --------------------------------------------------------- | |||
V 0 Membrane potential. | |||
input 0 External and synaptic input current. | |||
spike False Flag to mark whether the neuron is spiking. | |||
refractory False Flag to mark whether the neuron is in refractory period. | |||
t_last_spike -1e7 Last spike time stamp. | |||
================== ================= ========================================================= | |||
**References** | |||
.. [1] P. E. Latham, B.J. Richmond, P. Nelson and S. Nirenberg | |||
(2000) Intrinsic dynamics in neuronal networks. I. Theory. | |||
J. Neurophysiology 83, pp. 808–827. | |||
""" | |||
def __init__(self, size, V_rest=-65., V_reset=-68., V_th=-30., V_c=-50.0, c=.07, | |||
R=1., tau=10., tau_ref=0., method='exp_auto', name=None): | |||
# initialization | |||
super(QuaIF, self).__init__(size=size, name=name) | |||
# parameters | |||
self.V_rest = V_rest | |||
self.V_reset = V_reset | |||
self.V_th = V_th | |||
self.V_c = V_c | |||
self.c = c | |||
self.R = R | |||
self.tau = tau | |||
self.tau_ref = tau_ref | |||
# variables | |||
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
# variables | |||
self.V = bm.Variable(bm.zeros(self.num)) | |||
self.input = bm.Variable(bm.zeros(self.num)) | |||
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) | |||
# integral | |||
self.integral = odeint(method=method, f=self.derivative) | |||
def derivative(self, V, t, I_ext): | |||
dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I_ext) / self.tau | |||
return dVdt | |||
def update(self, _t, _dt, **kwargs): | |||
refractory = (_t - self.t_last_spike) <= self.tau_ref | |||
V = self.integral(self.V, _t, self.input, dt=_dt) | |||
V = bm.where(refractory, self.V, V) | |||
spike = self.V_th <= V | |||
self.t_last_spike.value = bm.where(spike, _t, self.t_last_spike) | |||
self.V.value = bm.where(spike, self.V_reset, V) | |||
self.refractory.value = bm.logical_or(refractory, spike) | |||
self.spike.value = spike | |||
self.input[:] = 0. | |||
class AdQuaIF(NeuGroup): | |||
r"""Adaptive quadratic integrate-and-fire neuron model. | |||
**Model Descriptions** | |||
The adaptive quadratic integrate-and-fire neuron model [1]_ is given by: | |||
.. math:: | |||
\begin{aligned} | |||
\tau_m \frac{d V}{d t}&=c(V-V_{rest})(V-V_c) - w + I(t), \\ | |||
\tau_w \frac{d w}{d t}&=a(V-V_{rest}) - w, | |||
\end{aligned} | |||
once the membrane potential reaches the spike threshold, | |||
.. math:: | |||
V \rightarrow V_{reset}, \\ | |||
w \rightarrow w+b. | |||
**Model Examples** | |||
.. plot:: | |||
:include-source: True | |||
>>> import brainpy as bp | |||
>>> group = bp.dyn.AdQuaIF(1, ) | |||
>>> runner = bp.dyn.DSRunner(group, monitors=['V', 'w'], inputs=('input', 30.)) | |||
>>> runner.run(300) | |||
>>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8) | |||
>>> fig.add_subplot(gs[0, 0]) | |||
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V') | |||
>>> fig.add_subplot(gs[1, 0]) | |||
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.w, ylabel='w', show=True) | |||
**Model Parameters** | |||
============= ============== ======== ======================================================= | |||
**Parameter** **Init Value** **Unit** **Explanation** | |||
------------- -------------- -------- ------------------------------------------------------- | |||
V_rest -65 mV Resting potential. | |||
V_reset -68 mV Reset potential after spike. | |||
V_th -30 mV Threshold potential of spike and reset. | |||
V_c -50 mV Critical voltage for spike initiation. Must be larger | |||
than :math:`V_{rest}`. | |||
a 1 \ The sensitivity of the recovery variable :math:`u` to | |||
the sub-threshold fluctuations of the membrane | |||
potential :math:`v` | |||
b .1 \ The increment of :math:`w` produced by a spike. | |||
c .07 \ Coefficient describes membrane potential update. | |||
Larger than 0. | |||
tau 10 ms Membrane time constant. | |||
tau_w 10 ms Time constant of the adaptation current. | |||
============= ============== ======== ======================================================= | |||
**Model Variables** | |||
================== ================= ========================================================== | |||
**Variables name** **Initial Value** **Explanation** | |||
------------------ ----------------- ---------------------------------------------------------- | |||
V 0 Membrane potential. | |||
w 0 Adaptation current. | |||
input 0 External and synaptic input current. | |||
spike False Flag to mark whether the neuron is spiking. | |||
t_last_spike -1e7 Last spike time stamp. | |||
================== ================= ========================================================== | |||
**References** | |||
.. [1] Izhikevich, E. M. (2004). Which model to use for cortical spiking | |||
neurons?. IEEE transactions on neural networks, 15(5), 1063-1070. | |||
.. [2] Touboul, Jonathan. "Bifurcation analysis of a general class of | |||
nonlinear integrate-and-fire neurons." SIAM Journal on Applied | |||
Mathematics 68, no. 4 (2008): 1045-1079. | |||
""" | |||
def __init__(self, size, V_rest=-65., V_reset=-68., V_th=-30., V_c=-50.0, a=1., b=.1, | |||
c=.07, tau=10., tau_w=10., method='exp_auto', name=None): | |||
super(AdQuaIF, self).__init__(size=size, name=name) | |||
# parameters | |||
self.V_rest = V_rest | |||
self.V_reset = V_reset | |||
self.V_th = V_th | |||
self.V_c = V_c | |||
self.c = c | |||
self.a = a | |||
self.b = b | |||
self.tau = tau | |||
self.tau_w = tau_w | |||
# variables | |||
self.V = bm.Variable(bm.zeros(self.num)) | |||
self.w = bm.Variable(bm.zeros(self.num)) | |||
self.input = bm.Variable(bm.zeros(self.num)) | |||
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) | |||
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
# integral | |||
self.integral = odeint(method=method, f=self.derivative) | |||
def dV(self, V, t, w, I_ext): | |||
dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I_ext) / self.tau | |||
return dVdt | |||
def dw(self, w, t, V): | |||
dwdt = (self.a * (V - self.V_rest) - w) / self.tau_w | |||
return dwdt | |||
@property | |||
def derivative(self): | |||
return JointEq([self.dV, self.dw]) | |||
def update(self, _t, _dt): | |||
V, w = self.integral(self.V, self.w, _t, self.input, dt=_dt) | |||
spike = self.V_th <= V | |||
self.t_last_spike.value = bm.where(spike, _t, self.t_last_spike) | |||
self.V.value = bm.where(spike, self.V_reset, V) | |||
self.w.value = bm.where(spike, w + self.b, w) | |||
self.spike.value = spike | |||
self.input[:] = 0. | |||
class GIF(NeuGroup): | |||
r"""Generalized Integrate-and-Fire model. | |||
**Model Descriptions** | |||
The generalized integrate-and-fire model [1]_ is given by | |||
.. math:: | |||
&\frac{d I_j}{d t} = - k_j I_j | |||
&\frac{d V}{d t} = ( - (V - V_{rest}) + R\sum_{j}I_j + RI) / \tau | |||
&\frac{d V_{th}}{d t} = a(V - V_{rest}) - b(V_{th} - V_{th\infty}) | |||
When :math:`V` meets :math:`V_{th}`, the generalized IF neuron fires:
.. math:: | |||
&I_j \leftarrow R_j I_j + A_j | |||
&V \leftarrow V_{reset} | |||
&V_{th} \leftarrow max(V_{th_{reset}}, V_{th}) | |||
Note that :math:`I_j` refers to an arbitrary number of internal currents.
**Model Examples** | |||
- `Detailed examples to reproduce different firing patterns <https://brainpy-examples.readthedocs.io/en/latest/neurons/Niebur_2009_GIF.html>`_ | |||
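A minimal usage sketch in the same style as the other models in this file (the input amplitude is illustrative):
.. plot::
:include-source: True
>>> import brainpy as bp
>>> group = bp.dyn.GIF(1)
>>> runner = bp.dyn.DSRunner(group, monitors=['V'], inputs=('input', 1.5))
>>> runner.run(400.)
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', show=True)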
**Model Parameters** | |||
============= ============== ======== ==================================================================== | |||
**Parameter** **Init Value** **Unit** **Explanation** | |||
------------- -------------- -------- -------------------------------------------------------------------- | |||
V_rest -70 mV Resting potential. | |||
V_reset -70 mV Reset potential after spike. | |||
V_th_inf -50 mV Target value of threshold potential :math:`V_{th}` updating. | |||
V_th_reset -60 mV Free parameter, should be larger than :math:`V_{reset}`. | |||
R 20 \ Membrane resistance. | |||
tau 20 ms Membrane time constant. Compute by :math:`R * C`. | |||
a 0 \ Coefficient describes the dependence of | |||
:math:`V_{th}` on membrane potential. | |||
b 0.01 \ Coefficient describes :math:`V_{th}` update. | |||
k1 0.2 \ Constant of :math:`I_1`.
k2 0.02 \ Constant of :math:`I_2`.
R1 0 \ Free parameter. | |||
Describes dependence of :math:`I_1` reset value on | |||
:math:`I_1` value before spiking. | |||
R2 1 \ Free parameter. | |||
Describes dependence of :math:`I_2` reset value on | |||
:math:`I_2` value before spiking. | |||
A1 0 \ Free parameter. | |||
A2 0 \ Free parameter. | |||
============= ============== ======== ==================================================================== | |||
**Model Variables** | |||
================== ================= ========================================================= | |||
**Variables name** **Initial Value** **Explanation** | |||
------------------ ----------------- --------------------------------------------------------- | |||
V -70 Membrane potential. | |||
input 0 External and synaptic input current. | |||
spike False Flag to mark whether the neuron is spiking. | |||
V_th -50 Spiking threshold potential. | |||
I1 0 Internal current 1. | |||
I2 0 Internal current 2. | |||
t_last_spike -1e7 Last spike time stamp. | |||
================== ================= ========================================================= | |||
**References** | |||
.. [1] Mihalaş, Ştefan, and Ernst Niebur. "A generalized linear | |||
integrate-and-fire neural model produces diverse spiking | |||
behaviors." Neural computation 21.3 (2009): 704-718. | |||
.. [2] Teeter, Corinne, Ramakrishnan Iyer, Vilas Menon, Nathan | |||
Gouwens, David Feng, Jim Berg, Aaron Szafer et al. "Generalized | |||
leaky integrate-and-fire models classify multiple neuron types." | |||
Nature communications 9, no. 1 (2018): 1-15. | |||
""" | |||
def __init__(self, size, V_rest=-70., V_reset=-70., V_th_inf=-50., V_th_reset=-60., | |||
R=20., tau=20., a=0., b=0.01, k1=0.2, k2=0.02, R1=0., R2=1., A1=0., | |||
A2=0., method='exp_auto', name=None): | |||
# initialization | |||
super(GIF, self).__init__(size=size, name=name) | |||
# params | |||
self.V_rest = V_rest | |||
self.V_reset = V_reset | |||
self.V_th_inf = V_th_inf | |||
self.V_th_reset = V_th_reset | |||
self.R = R | |||
self.tau = tau | |||
self.a = a | |||
self.b = b | |||
self.k1 = k1 | |||
self.k2 = k2 | |||
self.R1 = R1 | |||
self.R2 = R2 | |||
self.A1 = A1 | |||
self.A2 = A2 | |||
# variables | |||
self.I1 = bm.Variable(bm.zeros(self.num)) | |||
self.I2 = bm.Variable(bm.zeros(self.num)) | |||
self.V_th = bm.Variable(bm.ones(self.num) * -50.) | |||
self.V = bm.Variable(bm.zeros(self.num)) | |||
self.input = bm.Variable(bm.zeros(self.num)) | |||
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) | |||
# integral | |||
self.integral = odeint(method=method, f=self.derivative) | |||
def dI1(self, I1, t): | |||
return - self.k1 * I1 | |||
def dI2(self, I2, t): | |||
return - self.k2 * I2 | |||
def dVth(self, V_th, t, V): | |||
return self.a * (V - self.V_rest) - self.b * (V_th - self.V_th_inf) | |||
def dV(self, V, t, I1, I2, I_ext): | |||
return (- (V - self.V_rest) + self.R * I_ext + self.R * I1 + self.R * I2) / self.tau | |||
@property | |||
def derivative(self): | |||
return JointEq([self.dI1, self.dI2, self.dVth, self.dV]) | |||
def update(self, _t, _dt): | |||
I1, I2, V_th, V = self.integral(self.I1, self.I2, self.V_th, self.V, _t, self.input, dt=_dt) | |||
spike = self.V_th <= V | |||
V = bm.where(spike, self.V_reset, V) | |||
I1 = bm.where(spike, self.R1 * I1 + self.A1, I1) | |||
I2 = bm.where(spike, self.R2 * I2 + self.A2, I2) | |||
reset_th = bm.logical_and(V_th < self.V_th_reset, spike) | |||
V_th = bm.where(reset_th, self.V_th_reset, V_th) | |||
self.spike.value = spike | |||
self.I1.value = I1 | |||
self.I2.value = I2 | |||
self.V_th.value = V_th | |||
self.V.value = V | |||
self.input[:] = 0. |
@@ -1,7 +1,8 @@ | |||
# -*- coding: utf-8 -*- | |||
from .biological_models import * | |||
from .IF_models import * | |||
from .fractional_models import * | |||
from .input_models import * | |||
from .noise_models import * | |||
from .rate_models import * | |||
from .reduced_models import * |
@@ -1,9 +1,14 @@ | |||
# -*- coding: utf-8 -*- | |||
from typing import Union, Callable | |||
import brainpy.math as bm | |||
from brainpy.dyn.base import NeuGroup | |||
from brainpy.initialize import OneInit, Uniform, Initializer, init_param | |||
from brainpy.integrators.joint_eq import JointEq | |||
from brainpy.integrators.ode import odeint | |||
from brainpy.dyn.base import NeuGroup | |||
from brainpy.tools.checking import check_initializer | |||
from brainpy.types import Shape, Parameter, Tensor | |||
__all__ = [ | |||
'HH', | |||
@@ -178,8 +183,24 @@ class HH(NeuGroup): | |||
The Journal of Mathematical Neuroscience 6, no. 1 (2016): 1-92. | |||
""" | |||
def __init__(self, size, ENa=50., gNa=120., EK=-77., gK=36., EL=-54.387, gL=0.03, | |||
V_th=20., C=1.0, method='exp_auto', name=None): | |||
def __init__( | |||
self, | |||
size: Shape, | |||
ENa: Parameter = 50., | |||
gNa: Parameter = 120., | |||
EK: Parameter = -77., | |||
gK: Parameter = 36., | |||
EL: Parameter = -54.387, | |||
gL: Parameter = 0.03, | |||
V_th: Parameter = 20., | |||
C: Parameter = 1.0, | |||
V_initializer: Union[Initializer, Callable, Tensor] = Uniform(-70, -60.), | |||
m_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.5), | |||
h_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.6), | |||
n_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.32), | |||
method: str = 'exp_auto', | |||
name: str = None | |||
): | |||
# initialization | |||
super(HH, self).__init__(size=size, name=name) | |||
@@ -194,10 +215,14 @@ class HH(NeuGroup): | |||
self.V_th = V_th | |||
# variables | |||
self.m = bm.Variable(0.5 * bm.ones(self.num)) | |||
self.h = bm.Variable(0.6 * bm.ones(self.num)) | |||
self.n = bm.Variable(0.32 * bm.ones(self.num)) | |||
self.V = bm.Variable(bm.zeros(self.num)) | |||
check_initializer(m_initializer, 'm_initializer', allow_none=False) | |||
check_initializer(h_initializer, 'h_initializer', allow_none=False) | |||
check_initializer(n_initializer, 'n_initializer', allow_none=False) | |||
check_initializer(V_initializer, 'V_initializer', allow_none=False) | |||
self.m = bm.Variable(init_param(m_initializer, (self.num,))) | |||
self.h = bm.Variable(init_param(h_initializer, (self.num,))) | |||
self.n = bm.Variable(init_param(n_initializer, (self.num,))) | |||
self.V = bm.Variable(init_param(V_initializer, (self.num,))) | |||
self.input = bm.Variable(bm.zeros(self.num)) | |||
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) | |||
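With this refactor, initial states come from configurable initializers instead of hard-coded constants. A usage sketch, assuming `HH` is exposed as `brainpy.dyn.HH` like the other neuron models in this diff:

import brainpy as bp
from brainpy.initialize import OneInit

# default: V ~ Uniform(-70, -60); gating variables from their OneInit values
group = bp.dyn.HH(10)
# override the membrane-potential initializer, e.g. a fixed starting voltage
group2 = bp.dyn.HH(10, V_initializer=OneInit(-65.))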
@@ -334,9 +359,27 @@ class MorrisLecar(NeuGroup): | |||
.. [3] https://en.wikipedia.org/wiki/Morris%E2%80%93Lecar_model | |||
""" | |||
def __init__(self, size, V_Ca=130., g_Ca=4.4, V_K=-84., g_K=8., V_leak=-60., | |||
g_leak=2., C=20., V1=-1.2, V2=18., V3=2., V4=30., phi=0.04, | |||
V_th=10., method='exp_auto', name=None): | |||
def __init__( | |||
self, | |||
size: Shape, | |||
V_Ca: Parameter = 130., | |||
g_Ca: Parameter = 4.4, | |||
V_K: Parameter = -84., | |||
g_K: Parameter = 8., | |||
V_leak: Parameter = -60., | |||
g_leak: Parameter = 2., | |||
C: Parameter = 20., | |||
V1: Parameter = -1.2, | |||
V2: Parameter = 18., | |||
V3: Parameter = 2., | |||
V4: Parameter = 30., | |||
phi: Parameter = 0.04, | |||
V_th: Parameter = 10., | |||
W_initializer: Union[Callable, Initializer, Tensor] = OneInit(0.02), | |||
V_initializer: Union[Callable, Initializer, Tensor] = Uniform(-70., -60.), | |||
method: str = 'exp_auto', | |||
name: str = None | |||
): | |||
# initialization | |||
super(MorrisLecar, self).__init__(size=size, name=name) | |||
@@ -356,8 +399,10 @@ class MorrisLecar(NeuGroup): | |||
self.V_th = V_th | |||
# vars | |||
self.W = bm.Variable(bm.ones(self.num) * 0.02) | |||
self.V = bm.Variable(bm.zeros(self.num)) | |||
check_initializer(V_initializer, 'V_initializer', allow_none=False) | |||
check_initializer(W_initializer, 'W_initializer', allow_none=False) | |||
self.W = bm.Variable(init_param(W_initializer, (self.num,))) | |||
self.V = bm.Variable(init_param(V_initializer, (self.num,))) | |||
self.input = bm.Variable(bm.zeros(self.num)) | |||
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) | |||
@@ -0,0 +1,294 @@ | |||
# -*- coding: utf-8 -*- | |||
from typing import Union, Sequence, Callable | |||
import brainpy.math as bm | |||
from brainpy.dyn.base import NeuGroup | |||
from brainpy.initialize import ZeroInit, OneInit, Initializer, init_param | |||
from brainpy.integrators.fde import CaputoL1Schema | |||
from brainpy.integrators.fde import GLShortMemory | |||
from brainpy.integrators.joint_eq import JointEq | |||
from brainpy.tools.checking import check_float, check_integer | |||
from brainpy.tools.checking import check_initializer | |||
from brainpy.types import Parameter, Shape, Tensor | |||
__all__ = [ | |||
'FractionalNeuron', | |||
'FractionalFHR', | |||
'FractionalIzhikevich', | |||
] | |||
class FractionalNeuron(NeuGroup): | |||
"""Fractional-order neuron model.""" | |||
pass | |||
class FractionalFHR(FractionalNeuron): | |||
r"""The fractional-order FH-R model [1]_. | |||
FitzHugh and Rinzel introduced the FH-R model (1976, in an unpublished article),
which is a modification of the classical FHN neuron model. The fractional-order
FH-R model is described as | |||
.. math:: | |||
\begin{array}{rcl} | |||
\frac{{d}^{\alpha }v}{d{t}^{\alpha }} & = & v-{v}^{3}/3-w+y+I={f}_{1}(v,w,y),\\ | |||
\frac{{d}^{\alpha }w}{d{t}^{\alpha }} & = & \delta (a+v-bw)={f}_{2}(v,w,y),\\ | |||
\frac{{d}^{\alpha }y}{d{t}^{\alpha }} & = & \mu (c-v-dy)={f}_{3}(v,w,y), | |||
\end{array} | |||
where :math:`v, w` and :math:`y` represent the membrane voltage, recovery variable | |||
and slow modulation of the current respectively. | |||
:math:`I` measures the constant magnitude of external stimulus current, and :math:`\alpha` | |||
is the fractional exponent which ranges in the interval :math:`(0 < \alpha \le 1)`. | |||
:math:`a, b, c, d, \delta` and :math:`\mu` are the system parameters. | |||
The system reduces to the original classical order system when :math:`\alpha=1`. | |||
:math:`\mu` indicates a small parameter that determines the pace of the slow system | |||
variable :math:`y`. The fast subsystem (:math:`v-w`) presents a relaxation oscillator | |||
in the phase plane where :math:`\delta` is a small parameter. | |||
:math:`v` is expressed in mV (millivolt) scale. Time :math:`t` is in ms (millisecond) scale. | |||
It exhibits tonic spiking or a quiescent state depending on the parameter set for a fixed
value of :math:`I`. The parameter :math:`a` in the 2D FHN model corresponds to the
parameter :math:`c` of the FH-R neuron model. Decreasing the value of :math:`a`
lengthens the interval between two burstings, while the bursting duration stays
relatively fixed. As :math:`a` increases, the interburst intervals become shorter
and periodic bursting changes to tonic spiking.
Examples | |||
-------- | |||
- [(Mondal, et, al., 2019): Fractional-order FitzHugh-Rinzel bursting neuron model](https://brainpy-examples.readthedocs.io/en/latest/neurons/2019_Fractional_order_FHR_model.html) | |||
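A minimal usage sketch (the fractional order and input amplitude are illustrative, and the class is assumed to be exposed as ``brainpy.dyn.FractionalFHR``):
>>> import brainpy as bp
>>> group = bp.dyn.FractionalFHR(1, alpha=0.9)
>>> runner = bp.dyn.DSRunner(group, monitors=['V'], inputs=('input', 0.8))
>>> runner.run(500.)
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', show=True)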
Parameters | |||
---------- | |||
size: int, sequence of int | |||
The size of the neuron group. | |||
alpha: float, tensor | |||
The fractional order. | |||
num_memory: int
The length of the short memory used by the integrator.
References | |||
---------- | |||
.. [1] Mondal, A., Sharma, S.K., Upadhyay, R.K. *et al.* Firing activities of a fractional-order FitzHugh-Rinzel bursting neuron model and its coupled dynamics. *Sci Rep* **9,** 15721 (2019). https://doi.org/10.1038/s41598-019-52061-4 | |||
""" | |||
def __init__( | |||
self, | |||
size: Shape, | |||
alpha: Union[float, Sequence[float]], | |||
num_memory: int = 1000, | |||
a: Parameter = 0.7, | |||
b: Parameter = 0.8, | |||
c: Parameter = -0.775, | |||
d: Parameter = 1., | |||
delta: Parameter = 0.08, | |||
mu: Parameter = 0.0001, | |||
Vth: Parameter = 1.8, | |||
V_initializer: Union[Initializer, Callable, Tensor] = OneInit(2.5), | |||
w_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), | |||
y_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), | |||
name: str = None | |||
): | |||
super(FractionalFHR, self).__init__(size, name=name) | |||
# fractional order | |||
self.alpha = alpha | |||
check_integer(num_memory, 'num_memory', allow_none=False) | |||
# parameters | |||
self.a = a | |||
self.b = b | |||
self.c = c | |||
self.d = d | |||
self.delta = delta | |||
self.mu = mu | |||
self.Vth = Vth | |||
# variables | |||
check_initializer(V_initializer, 'V_initializer', allow_none=False) | |||
check_initializer(w_initializer, 'w_initializer', allow_none=False) | |||
check_initializer(y_initializer, 'y_initializer', allow_none=False) | |||
self.V = bm.Variable(init_param(V_initializer, (self.num,))) | |||
self.w = bm.Variable(init_param(w_initializer, (self.num,))) | |||
self.y = bm.Variable(init_param(y_initializer, (self.num,))) | |||
self.input = bm.Variable(bm.zeros(self.num)) | |||
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) | |||
# integral function | |||
self.integral = GLShortMemory(self.derivative, | |||
alpha=alpha, | |||
num_memory=num_memory, | |||
inits=[self.V, self.w, self.y]) | |||
def dV(self, V, t, w, y): | |||
return V - V ** 3 / 3 - w + y + self.input | |||
def dw(self, w, t, V): | |||
return self.delta * (self.a + V - self.b * w) | |||
def dy(self, y, t, V): | |||
return self.mu * (self.c - V - self.d * y) | |||
@property | |||
def derivative(self): | |||
return JointEq([self.dV, self.dw, self.dy]) | |||
def update(self, _t, _dt): | |||
V, w, y = self.integral(self.V, self.w, self.y, _t, _dt) | |||
self.spike.value = bm.logical_and(V >= self.Vth, self.V < self.Vth) | |||
self.t_last_spike.value = bm.where(self.spike, _t, self.t_last_spike) | |||
self.V.value = V | |||
self.w.value = w | |||
self.y.value = y | |||
self.input[:] = 0. | |||
def set_init(self, values: dict): | |||
for k, v in values.items(): | |||
if k not in self.integral.inits: | |||
raise ValueError(f'Variable "{k}" is not defined in this model.') | |||
variable = getattr(self, k) | |||
variable[:] = v | |||
self.integral.inits[k][:] = v | |||
class FractionalIzhikevich(FractionalNeuron): | |||
r"""Fractional-order Izhikevich model [10]_. | |||
The fractional-order Izhikevich model is given by | |||
.. math:: | |||
\begin{aligned} | |||
&\tau \frac{d^{\alpha} v}{d t^{\alpha}}=\mathrm{f} v^{2}+g v+h-u+R I \\ | |||
&\tau \frac{d^{\alpha} u}{d t^{\alpha}}=a(b v-u) | |||
\end{aligned} | |||
where :math:`\alpha` is the fractional order (exponent) such that :math:`0<\alpha\le1`. | |||
It is a commensurate system that reduces to the classical Izhikevich model at :math:`\alpha=1`.
The time :math:`t` is in ms; and the system variable :math:`v` expressed in mV | |||
corresponds to membrane voltage. Moreover, :math:`u` expressed in mV is the | |||
recovery variable that corresponds to the activation of K+ ionic current and | |||
inactivation of Na+ ionic current. | |||
The parameters :math:`f, g, h` are fixed constants (they should not be changed), with
:math:`f=0.04\ \mathrm{mV}^{-1}`, :math:`g=5`, and :math:`h=140` mV; :math:`a` and :math:`b` are
dimensionless parameters. The time constant is :math:`\tau=1` ms; the resistance is
:math:`R=1` Ω; and :math:`I`, expressed in mA, measures the injected (applied)
DC stimulus current to the system.
When the membrane voltage reaches the spike peak :math:`v_{peak}`, the two variables | |||
are reset as follows:
.. math:: | |||
\text { if } v \geq v_{\text {peak }} \text { then }\left\{\begin{array}{l} | |||
v \leftarrow c \\ | |||
u \leftarrow u+d | |||
\end{array}\right. | |||
We use :math:`v_{peak}=30` mV, and :math:`c` and :math:`d` are parameters expressed
in mV. When the spike reaches its peak value, the membrane voltage :math:`v` and the | |||
recovery variable :math:`u` are reset according to the above condition. | |||
Examples | |||
-------- | |||
- `(Teka, et al., 2018): Fractional-order Izhikevich neuron model <https://brainpy-examples.readthedocs.io/en/latest/neurons/2018_Fractional_Izhikevich_model.html>`_
References | |||
---------- | |||
.. [10] Teka, Wondimu W., Ranjit Kumar Upadhyay, and Argha Mondal. "Spiking and | |||
bursting patterns of fractional-order Izhikevich model." Communications | |||
in Nonlinear Science and Numerical Simulation 56 (2018): 161-176. | |||
""" | |||
def __init__( | |||
self, | |||
size: Shape, | |||
alpha: Union[float, Sequence[float]], | |||
num_step: int, | |||
a: Parameter = 0.02, | |||
b: Parameter = 0.20, | |||
c: Parameter = -65., | |||
d: Parameter = 8., | |||
f: Parameter = 0.04, | |||
g: Parameter = 5., | |||
h: Parameter = 140., | |||
tau: Parameter = 1., | |||
R: Parameter = 1., | |||
V_th: Parameter = 30., | |||
V_initializer: Union[Initializer, Callable, Tensor] = OneInit(-65.), | |||
u_initializer: Union[Initializer, Callable, Tensor] = OneInit(0.20 * -65.), | |||
name: str = None | |||
): | |||
# initialization | |||
super(FractionalIzhikevich, self).__init__(size=size, name=name) | |||
# params | |||
self.alpha = alpha | |||
check_float(alpha, 'alpha', min_bound=0., max_bound=1., allow_none=False, allow_int=True) | |||
self.a = a | |||
self.b = b | |||
self.c = c | |||
self.d = d | |||
self.f = f | |||
self.g = g | |||
self.h = h | |||
self.tau = tau | |||
self.R = R | |||
self.V_th = V_th | |||
# variables | |||
check_initializer(V_initializer, 'V_initializer', allow_none=False) | |||
check_initializer(u_initializer, 'u_initializer', allow_none=False) | |||
self.V = bm.Variable(init_param(V_initializer, (self.num,))) | |||
self.u = bm.Variable(init_param(u_initializer, (self.num,))) | |||
self.input = bm.Variable(bm.zeros(self.num)) | |||
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) | |||
# functions | |||
check_integer(num_step, 'num_step', allow_none=False) | |||
self.integral = CaputoL1Schema(f=self.derivative, | |||
alpha=alpha, | |||
num_step=num_step, | |||
inits=[self.V, self.u]) | |||
def dV(self, V, t, u, I_ext): | |||
dVdt = self.f * V * V + self.g * V + self.h - u + self.R * I_ext | |||
return dVdt / self.tau | |||
def du(self, u, t, V): | |||
dudt = self.a * (self.b * V - u) | |||
return dudt / self.tau | |||
@property | |||
def derivative(self): | |||
return JointEq([self.dV, self.du]) | |||
def update(self, _t, _dt): | |||
V, u = self.integral(self.V, self.u, t=_t, I_ext=self.input, dt=_dt) | |||
spikes = V >= self.V_th | |||
self.t_last_spike.value = bm.where(spikes, _t, self.t_last_spike) | |||
self.V.value = bm.where(spikes, self.c, V) | |||
self.u.value = bm.where(spikes, u + self.d, u) | |||
self.spike.value = spikes | |||
self.input[:] = 0. | |||
def set_init(self, values: dict): | |||
for k, v in values.items(): | |||
if k not in self.integral.inits: | |||
raise ValueError(f'Variable "{k}" is not defined in this model.') | |||
variable = getattr(self, k) | |||
variable[:] = v | |||
self.integral.inits[k][:] = v |
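# Usage sketch (illustrative; not from the library source). ``CaputoL1Schema``
# integrates over the full past trajectory, so ``num_step`` should cover every
# step of the planned run (duration / dt below); the ``dt`` keyword on the
# runner is assumed to match the runner API used in the docstring examples.
#
# >>> import brainpy as bp
# >>> duration, dt = 500., 0.1
# >>> neu = bp.dyn.FractionalIzhikevich(1, alpha=0.9, num_step=int(duration / dt))
# >>> runner = bp.dyn.DSRunner(neu, monitors=['V'], inputs=('input', 10.), dt=dt)
# >>> runner.run(duration)
# >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V', show=True)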
@@ -0,0 +1,72 @@ | |||
# -*- coding: utf-8 -*- | |||
import brainpy.math as bm | |||
from brainpy.dyn.base import NeuGroup | |||
from brainpy.integrators.sde import sdeint | |||
from brainpy.types import Parameter, Shape | |||
__all__ = [ | |||
'OUProcess', | |||
] | |||
class OUProcess(NeuGroup): | |||
r"""The Ornstein–Uhlenbeck process. | |||
The Ornstein–Uhlenbeck process :math:`x_{t}` is defined by the following | |||
stochastic differential equation: | |||
.. math:: | |||
dx_{t} = \frac{\mu - x_{t}}{\tau}\,dt + \sigma\,dW_{t}
where :math:`\mu` is the mean, :math:`\tau > 0` is the decay time constant,
:math:`\sigma > 0` is the noise amplitude, and :math:`W_{t}` denotes the Wiener
process; this matches the drift ``df`` and diffusion ``dg`` defined below.
Parameters | |||
---------- | |||
size: int, sequence of int | |||
The model size. | |||
mean: Parameter | |||
The noise mean value. | |||
sigma: Parameter | |||
The noise amplitude. | |||
tau: Parameter | |||
The decay time constant. | |||
method: str | |||
The numerical integration method for stochastic differential equation. | |||
name: str | |||
The model name. | |||
""" | |||
def __init__( | |||
self, | |||
size: Shape, | |||
mean: Parameter, | |||
sigma: Parameter, | |||
tau: Parameter, | |||
method: str = 'euler', | |||
name: str = None | |||
): | |||
super(OUProcess, self).__init__(size=size, name=name) | |||
# parameters | |||
self.mean = mean | |||
self.sigma = sigma | |||
self.tau = tau | |||
# variables | |||
self.x = bm.Variable(bm.ones(self.num) * mean) | |||
# integral functions | |||
self.integral = sdeint(f=self.df, g=self.dg, method=method) | |||
def df(self, x, t): | |||
f_x_ou = (self.mean - x) / self.tau | |||
return f_x_ou | |||
def dg(self, x, t): | |||
return self.sigma | |||
def update(self, _t, _dt): | |||
self.x.value = self.integral(self.x, _t, _dt) |
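# Usage sketch (illustrative; not from the library source). The process can be
# simulated standalone as below; in the rate models later in this diff it is
# instead stepped inside their ``update`` functions and added onto ``input``.
#
# >>> import brainpy as bp
# >>> ou = bp.dyn.OUProcess(size=1, mean=1., sigma=0.5, tau=10.)
# >>> runner = bp.dyn.DSRunner(ou, monitors=['x'])
# >>> runner.run(200.)
# >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.x, legend='x', show=True)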
@@ -1,145 +1,155 @@ | |||
# -*- coding: utf-8 -*- | |||
from typing import Union, Callable | |||
import numpy as np | |||
from jax.experimental.host_callback import id_tap | |||
import brainpy.math as bm | |||
from brainpy import check | |||
from brainpy.dyn.base import NeuGroup | |||
from brainpy.initialize import Initializer, Uniform | |||
from brainpy.initialize import init_param | |||
from brainpy.integrators.dde import ddeint | |||
from brainpy.integrators.joint_eq import JointEq | |||
from brainpy.integrators.ode import odeint | |||
from brainpy.types import Parameter, Shape | |||
from brainpy.tools.checking import check_float, check_initializer | |||
from brainpy.types import Parameter, Shape, Tensor | |||
from .noise_models import OUProcess | |||
__all__ = [ | |||
'FHN', | |||
'RateGroup', | |||
'RateFHN', | |||
'FeedbackFHN', | |||
'MeanFieldQIF', | |||
'RateQIF', | |||
'StuartLandauOscillator', | |||
'WilsonCowanModel', | |||
] | |||
class FHN(NeuGroup): | |||
r"""FitzHugh-Nagumo neuron model. | |||
**Model Descriptions** | |||
The FitzHugh–Nagumo model (FHN), named after Richard FitzHugh (1922–2007) | |||
who suggested the system in 1961 [1]_ and J. Nagumo et al. who created the | |||
equivalent circuit the following year, describes a prototype of an excitable | |||
system (e.g., a neuron). | |||
class RateGroup(NeuGroup): | |||
def update(self, _t, _dt): | |||
raise NotImplementedError | |||
The motivation for the FitzHugh-Nagumo model was to isolate conceptually | |||
the essentially mathematical properties of excitation and propagation from | |||
the electrochemical properties of sodium and potassium ion flow. The model | |||
consists of | |||
- a *voltage-like variable* having cubic nonlinearity that allows regenerative | |||
self-excitation via a positive feedback, and | |||
- a *recovery variable* having a linear dynamics that provides a slower negative feedback. | |||
class RateFHN(NeuGroup): | |||
r"""FitzHugh-Nagumo system used in [1]_. | |||
.. math:: | |||
\begin{aligned} | |||
{\dot {v}} &=v-{\frac {v^{3}}{3}}-w+RI_{\rm {ext}}, \\ | |||
\tau {\dot {w}}&=v+a-bw. | |||
\end{aligned} | |||
The FHN Model is an example of a relaxation oscillator | |||
because, if the external stimulus :math:`I_{\text{ext}}` | |||
exceeds a certain threshold value, the system will exhibit | |||
a characteristic excursion in phase space, before the | |||
variables :math:`v` and :math:`w` relax back to their rest values. | |||
This behaviour is typical for spike generations (a short, | |||
nonlinear elevation of membrane voltage :math:`v`, | |||
diminished over time by a slower, linear recovery variable | |||
:math:`w`) in a neuron after stimulation by an external | |||
input current. | |||
**Model Examples** | |||
.. plot:: | |||
:include-source: True | |||
>>> import brainpy as bp | |||
>>> fhn = bp.dyn.FHN(1) | |||
>>> runner = bp.dyn.DSRunner(fhn, inputs=('input', 1.), monitors=['V', 'w']) | |||
>>> runner.run(100.) | |||
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.w, legend='w') | |||
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V', show=True) | |||
**Model Parameters** | |||
============= ============== ======== ======================== | |||
**Parameter** **Init Value** **Unit** **Explanation** | |||
------------- -------------- -------- ------------------------ | |||
a 1 \ Positive constant | |||
b 1 \ Positive constant | |||
tau 10 ms Membrane time constant. | |||
V_th 1.8 mV Threshold potential of spike. | |||
============= ============== ======== ======================== | |||
**Model Variables** | |||
\frac{dx}{dt} = -\alpha x^3 + \beta x^2 + \gamma x - y + I_{ext}\\
\tau \frac{dy}{dt} = x - \delta - \epsilon y
================== ================= ========================================================= | |||
**Variables name** **Initial Value** **Explanation** | |||
------------------ ----------------- --------------------------------------------------------- | |||
V 0 Membrane potential. | |||
w 0 A recovery variable which represents | |||
the combined effects of sodium channel | |||
de-inactivation and potassium channel | |||
deactivation. | |||
input 0 External and synaptic input current. | |||
spike False Flag to mark whether the neuron is spiking. | |||
t_last_spike -1e7 Last spike time stamp. | |||
================== ================= ========================================================= | |||
Parameters | |||
---------- | |||
size: Shape | |||
The model size. | |||
x_ou_mean: Parameter | |||
The noise mean of the :math:`x` variable, [mV/ms].
y_ou_mean: Parameter | |||
The noise mean of the :math:`y` variable, [mV/ms]. | |||
x_ou_sigma: Parameter | |||
The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)]. | |||
y_ou_sigma: Parameter | |||
The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)]. | |||
x_ou_tau: Parameter | |||
The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms]. | |||
y_ou_tau: Parameter | |||
The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms]. | |||
**References** | |||
.. [1] FitzHugh, Richard. "Impulses and physiological states in theoretical models of nerve membrane." Biophysical journal 1.6 (1961): 445-466. | |||
.. [2] https://en.wikipedia.org/wiki/FitzHugh%E2%80%93Nagumo_model | |||
.. [3] http://www.scholarpedia.org/article/FitzHugh-Nagumo_model | |||
References | |||
---------- | |||
.. [1] Kostova, T., Ravindran, R., & Schonbek, M. (2004). FitzHugh–Nagumo | |||
revisited: Types of bifurcations, periodical forcing and stability | |||
regions by a Lyapunov functional. International journal of | |||
bifurcation and chaos, 14(03), 913-925. | |||
""" | |||
def __init__(self, | |||
def __init__( | |||
self, | |||
size: Shape, | |||
a: Parameter = 0.7, | |||
b: Parameter = 0.8, | |||
tau: Parameter = 12.5, | |||
Vth: Parameter = 1.8, | |||
method: str = 'exp_auto', | |||
name: str = None): | |||
# initialization | |||
super(FHN, self).__init__(size=size, name=name) | |||
# parameters | |||
self.a = a | |||
self.b = b | |||
# fhn parameters | |||
alpha: Parameter = 3.0, | |||
beta: Parameter = 4.0, | |||
gamma: Parameter = -1.5, | |||
delta: Parameter = 0.0, | |||
epsilon: Parameter = 0.5, | |||
tau: Parameter = 20.0, | |||
# noise parameters | |||
x_ou_mean: Parameter = 0.0, | |||
x_ou_sigma: Parameter = 0.0, | |||
x_ou_tau: Parameter = 5.0, | |||
y_ou_mean: Parameter = 0.0, | |||
y_ou_sigma: Parameter = 0.0, | |||
y_ou_tau: Parameter = 5.0, | |||
# other parameters | |||
x_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.05), | |||
y_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.05), | |||
method: str = None, | |||
sde_method: str = None, | |||
name: str = None, | |||
): | |||
super(RateFHN, self).__init__(size=size, name=name) | |||
# model parameters | |||
self.alpha = alpha | |||
self.beta = beta | |||
self.gamma = gamma | |||
self.delta = delta | |||
self.epsilon = epsilon | |||
self.tau = tau | |||
self.Vth = Vth | |||
# noise parameters | |||
self.x_ou_mean = x_ou_mean # mV/ms, OU process | |||
self.y_ou_mean = y_ou_mean # mV/ms, OU process | |||
self.x_ou_sigma = x_ou_sigma # mV/ms/sqrt(ms), noise intensity | |||
self.y_ou_sigma = y_ou_sigma # mV/ms/sqrt(ms), noise intensity | |||
self.x_ou_tau = x_ou_tau # ms, timescale of the Ornstein-Uhlenbeck noise process | |||
self.y_ou_tau = y_ou_tau # ms, timescale of the Ornstein-Uhlenbeck noise process | |||
# variables | |||
self.w = bm.Variable(bm.zeros(self.num)) | |||
self.V = bm.Variable(bm.zeros(self.num)) | |||
check_initializer(x_initializer, 'x_initializer') | |||
check_initializer(y_initializer, 'y_initializer') | |||
self.x = bm.Variable(init_param(x_initializer, (self.num,))) | |||
self.y = bm.Variable(init_param(y_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num)) | |||
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) | |||
# integral | |||
self.integral = odeint(method=method, f=self.derivative) | |||
# noise variables | |||
self.x_ou = self.y_ou = None | |||
if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): | |||
self.x_ou = OUProcess(self.num, | |||
self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau, | |||
method=sde_method) | |||
if bm.any(self.y_ou_mean > 0.) or bm.any(self.y_ou_sigma > 0.): | |||
self.y_ou = OUProcess(self.num, | |||
self.y_ou_mean, self.y_ou_sigma, self.y_ou_tau, | |||
method=sde_method) | |||
def dV(self, V, t, w, I_ext): | |||
return V - V * V * V / 3 - w + I_ext | |||
# integral functions | |||
self.integral = odeint(f=JointEq([self.dx, self.dy]), method=method) | |||
def dw(self, w, t, V): | |||
return (V + self.a - self.b * w) / self.tau | |||
def dx(self, x, t, y, x_ext): | |||
return - self.alpha * x ** 3 + self.beta * x ** 2 + self.gamma * x - y + x_ext | |||
@property | |||
def derivative(self): | |||
return JointEq([self.dV, self.dw]) | |||
def dy(self, y, t, x, y_ext=0.): | |||
return (x - self.delta - self.epsilon * y) / self.tau + y_ext | |||
def update(self, _t, _dt): | |||
V, w = self.integral(self.V, self.w, _t, self.input, dt=_dt) | |||
self.spike.value = bm.logical_and(V >= self.Vth, self.V < self.Vth) | |||
self.t_last_spike.value = bm.where(self.spike, _t, self.t_last_spike) | |||
self.V.value = V | |||
self.w.value = w | |||
if self.x_ou is not None: | |||
self.input += self.x_ou.x | |||
self.x_ou.update(_t, _dt) | |||
y_ext = 0. | |||
if self.y_ou is not None: | |||
y_ext = self.y_ou.x | |||
self.y_ou.update(_t, _dt) | |||
x, y = self.integral(self.x, self.y, _t, x_ext=self.input, y_ext=y_ext, dt=_dt) | |||
self.x.value = x | |||
self.y.value = y | |||
self.input[:] = 0. | |||
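# Usage sketch (illustrative; not from the library source). A positive
# ``x_ou_sigma`` is what instantiates the ``OUProcess`` in ``__init__``, making
# the run below stochastic; ``sde_method`` selects the SDE integrator.
#
# >>> import brainpy as bp
# >>> fhn = bp.dyn.RateFHN(1, x_ou_sigma=0.1, sde_method='euler')
# >>> runner = bp.dyn.DSRunner(fhn, monitors=['x', 'y'], inputs=('input', 0.5))
# >>> runner.run(100.)
# >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.x, legend='x', show=True)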
@@ -151,8 +161,8 @@ class FeedbackFHN(NeuGroup): | |||
.. math:: | |||
\begin{aligned} | |||
\frac{dv}{dt} &= v(t) - \frac{v^3(t)}{3} - w(t) + \mu[v(t-\mathrm{delay}) - v_0] \\ | |||
\frac{dw}{dt} &= [v(t) + a - b w(t)] / \tau | |||
\frac{dx}{dt} &= x(t) - \frac{x^3(t)}{3} - y(t) + \mu[x(t-\mathrm{delay}) - x_0] \\ | |||
\frac{dy}{dt} &= [x(t) + a - b y(t)] / \tau | |||
\end{aligned} | |||
@@ -160,10 +170,10 @@ class FeedbackFHN(NeuGroup): | |||
>>> import brainpy as bp | |||
>>> fhn = bp.dyn.FeedbackFHN(1, delay=10.) | |||
>>> runner = bp.dyn.DSRunner(fhn, inputs=('input', 1.), monitors=['V', 'w']) | |||
>>> runner = bp.dyn.DSRunner(fhn, inputs=('input', 1.), monitors=['x', 'y']) | |||
>>> runner.run(100.) | |||
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.w, legend='w') | |||
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V', show=True) | |||
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.y, legend='y') | |||
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.x, legend='x', show=True) | |||
**Model Parameters** | |||
@@ -181,6 +191,23 @@ class FeedbackFHN(NeuGroup): | |||
when negative, it is an inhibitory feedback.
============= ============== ======== ======================== | |||
Parameters | |||
---------- | |||
x_ou_mean: Parameter | |||
The noise mean of the :math:`x` variable, [mV/ms].
y_ou_mean: Parameter | |||
The noise mean of the :math:`y` variable, [mV/ms]. | |||
x_ou_sigma: Parameter | |||
The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)]. | |||
y_ou_sigma: Parameter | |||
The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)]. | |||
x_ou_tau: Parameter | |||
The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms]. | |||
y_ou_tau: Parameter | |||
The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms]. | |||
References | |||
---------- | |||
.. [4] Plant, Richard E. (1981). *A FitzHugh Differential-Difference | |||
@@ -189,61 +216,109 @@ class FeedbackFHN(NeuGroup): | |||
""" | |||
def __init__(self, | |||
def __init__( | |||
self, | |||
size: Shape, | |||
# model parameters | |||
a: Parameter = 0.7, | |||
b: Parameter = 0.8, | |||
delay: Parameter = 10., | |||
tau: Parameter = 12.5, | |||
mu: Parameter = 1.6886, | |||
v0: Parameter = -1, | |||
Vth: Parameter = 1.8, | |||
x0: Parameter = -1, | |||
# noise parameters | |||
x_ou_mean: Parameter = 0.0, | |||
x_ou_sigma: Parameter = 0.0, | |||
x_ou_tau: Parameter = 5.0, | |||
y_ou_mean: Parameter = 0.0, | |||
y_ou_sigma: Parameter = 0.0, | |||
y_ou_tau: Parameter = 5.0, | |||
# other parameters | |||
x_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.05), | |||
y_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.05), | |||
method: str = 'rk4', | |||
name: str = None): | |||
sde_method: str = None, | |||
name: str = None, | |||
dt: float = None | |||
): | |||
super(FeedbackFHN, self).__init__(size=size, name=name) | |||
# dt | |||
self.dt = bm.get_dt() if dt is None else dt | |||
check_float(self.dt, 'dt', allow_none=False, min_bound=0., allow_int=False) | |||
# parameters | |||
self.a = a | |||
self.b = b | |||
self.delay = delay | |||
self.tau = tau | |||
self.mu = mu # feedback strength | |||
self.v0 = v0 # resting potential | |||
self.Vth = Vth | |||
self.x0 = x0  # resting potential
# noise parameters | |||
self.x_ou_mean = x_ou_mean | |||
self.y_ou_mean = y_ou_mean | |||
self.x_ou_sigma = x_ou_sigma | |||
self.y_ou_sigma = y_ou_sigma | |||
self.x_ou_tau = x_ou_tau | |||
self.y_ou_tau = y_ou_tau | |||
# variables | |||
self.w = bm.Variable(bm.zeros(self.num)) | |||
self.V = bm.Variable(bm.zeros(self.num)) | |||
self.Vdelay = bm.FixedLenDelay(self.num, self.delay) | |||
check_initializer(x_initializer, 'x_initializer') | |||
check_initializer(y_initializer, 'y_initializer') | |||
self.x = bm.Variable(init_param(x_initializer, (self.num,))) | |||
self.y = bm.Variable(init_param(y_initializer, (self.num,)))
self.x_delay = bm.TimeDelay(self.x, self.delay, dt=self.dt, interp_method='round') | |||
self.input = bm.Variable(bm.zeros(self.num)) | |||
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) | |||
# noise variables | |||
self.x_ou = self.y_ou = None | |||
if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): | |||
self.x_ou = OUProcess(self.num, | |||
self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau, | |||
method=sde_method) | |||
if bm.any(self.y_ou_mean > 0.) or bm.any(self.y_ou_sigma > 0.): | |||
self.y_ou = OUProcess(self.num, | |||
self.y_ou_mean, self.y_ou_sigma, self.y_ou_tau, | |||
method=sde_method) | |||
# integral | |||
self.integral = ddeint(method=method, f=self.derivative, | |||
state_delays={'V': self.Vdelay}) | |||
self.integral = ddeint(method=method, | |||
f=JointEq([self.dx, self.dy]), | |||
state_delays={'x': self.x_delay})
def dV(self, V, t, w, Vdelay): | |||
return (V - V * V * V / 3 - w + self.input + | |||
self.mu * (Vdelay(t - self.delay) - self.v0)) | |||
def dx(self, x, t, y, x_ext): | |||
return x - x * x * x / 3 - y + x_ext + self.mu * (self.x_delay(t - self.delay) - self.x0)
def dw(self, w, t, V): | |||
return (V + self.a - self.b * w) / self.tau | |||
def dy(self, y, t, x, y_ext): | |||
return (x + self.a - self.b * y + y_ext) / self.tau | |||
@property | |||
def derivative(self): | |||
return JointEq([self.dV, self.dw]) | |||
def _check_dt(self, dt, *args): | |||
if np.absolute(dt - self.dt) > 1e-6: | |||
raise ValueError(f'The "dt" {dt} used in model running is ' | |||
f'not consistent with the "dt" {self.dt} ' | |||
f'used in model definition.') | |||
def update(self, _t, _dt): | |||
V, w = self.integral(self.V, self.w, _t, Vdelay=self.Vdelay, dt=_dt) | |||
self.spike.value = bm.logical_and(V >= self.Vth, self.V < self.Vth) | |||
self.t_last_spike.value = bm.where(self.spike, _t, self.t_last_spike) | |||
self.V.value = V | |||
self.w.value = w | |||
if check.is_checking(): | |||
id_tap(self._check_dt, _dt) | |||
if self.x_ou is not None: | |||
self.input += self.x_ou.x | |||
self.x_ou.update(_t, _dt) | |||
y_ext = 0. | |||
if self.y_ou is not None: | |||
y_ext = self.y_ou.x | |||
self.y_ou.update(_t, _dt) | |||
x, y = self.integral(self.x, self.y, _t, x_ext=self.input, y_ext=y_ext, dt=_dt) | |||
self.x.value = x | |||
self.y.value = y | |||
self.input[:] = 0. | |||
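# Usage sketch (illustrative; not from the library source). The delay buffer is
# sized from the ``dt`` fixed at construction, and ``_check_dt`` aborts the run
# if the simulation ``dt`` differs from it, so the two must stay consistent:
#
# >>> import brainpy as bp
# >>> import brainpy.math as bm
# >>> fhn = bp.dyn.FeedbackFHN(1, delay=10., dt=bm.get_dt())
# >>> runner = bp.dyn.DSRunner(fhn, monitors=['x', 'y'], inputs=('input', 1.))
# >>> runner.run(100.)
# >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.x, legend='x', show=True)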
class MeanFieldQIF(NeuGroup): | |||
class RateQIF(NeuGroup): | |||
r"""A mean-field model of a quadratic integrate-and-fire neuron population. | |||
**Model Descriptions** | |||
@@ -282,6 +357,21 @@ class MeanFieldQIF(NeuGroup): | |||
J 15 \ the strength of the recurrent coupling inside the population | |||
============= ============== ======== ======================== | |||
Parameters | |||
---------- | |||
x_ou_mean: Parameter | |||
The noise mean of the :math:`x` variable, [mV/ms].
y_ou_mean: Parameter | |||
The noise mean of the :math:`y` variable, [mV/ms]. | |||
x_ou_sigma: Parameter | |||
The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)]. | |||
y_ou_sigma: Parameter | |||
The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)]. | |||
x_ou_tau: Parameter | |||
The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms]. | |||
y_ou_tau: Parameter | |||
The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms]. | |||
References | |||
---------- | |||
@@ -294,15 +384,32 @@ class MeanFieldQIF(NeuGroup): | |||
""" | |||
def __init__(self, | |||
def __init__( | |||
self, | |||
size: Shape, | |||
# model parameters | |||
tau: Parameter = 1., | |||
eta: Parameter = -5.0, | |||
delta: Parameter = 1.0, | |||
J: Parameter = 15., | |||
# noise parameters | |||
x_ou_mean: Parameter = 0.0, | |||
x_ou_sigma: Parameter = 0.0, | |||
x_ou_tau: Parameter = 5.0, | |||
y_ou_mean: Parameter = 0.0, | |||
y_ou_sigma: Parameter = 0.0, | |||
y_ou_tau: Parameter = 5.0, | |||
# other parameters | |||
x_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.05), | |||
y_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.05), | |||
method: str = 'exp_auto', | |||
name: str = None): | |||
super(MeanFieldQIF, self).__init__(size=size, name=name) | |||
name: str = None, | |||
sde_method: str = None, | |||
): | |||
super(RateQIF, self).__init__(size=size, name=name) | |||
# parameters | |||
self.tau = tau
@@ -310,54 +417,309 @@ class MeanFieldQIF(NeuGroup): | |||
self.delta = delta  # the half-width at half maximum of the Lorentzian distribution over the neural excitability
self.J = J # the strength of the recurrent coupling inside the population | |||
# noise parameters | |||
self.x_ou_mean = x_ou_mean | |||
self.y_ou_mean = y_ou_mean | |||
self.x_ou_sigma = x_ou_sigma | |||
self.y_ou_sigma = y_ou_sigma | |||
self.x_ou_tau = x_ou_tau | |||
self.y_ou_tau = y_ou_tau | |||
# variables | |||
self.r = bm.Variable(bm.ones(1)) | |||
self.V = bm.Variable(bm.ones(1)) | |||
self.input = bm.Variable(bm.zeros(1)) | |||
check_initializer(x_initializer, 'x_initializer') | |||
check_initializer(y_initializer, 'y_initializer') | |||
self.x = bm.Variable(init_param(x_initializer, (self.num,))) | |||
self.y = bm.Variable(init_param(y_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num)) | |||
# noise variables | |||
self.x_ou = self.y_ou = None | |||
if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): | |||
self.x_ou = OUProcess(self.num, | |||
self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau, | |||
method=sde_method) | |||
if bm.any(self.y_ou_mean > 0.) or bm.any(self.y_ou_sigma > 0.): | |||
self.y_ou = OUProcess(self.num, | |||
self.y_ou_mean, self.y_ou_sigma, self.y_ou_tau, | |||
method=sde_method) | |||
# functions | |||
self.integral = odeint(self.derivative, method=method) | |||
self.integral = odeint(JointEq([self.dx, self.dy]), method=method) | |||
def dy(self, y, t, x, y_ext): | |||
return (self.delta / (bm.pi * self.tau) + 2. * x * y + y_ext) / self.tau | |||
def dx(self, x, t, y, x_ext): | |||
return (x ** 2 + self.eta + x_ext + self.J * y * self.tau - | |||
(bm.pi * y * self.tau) ** 2) / self.tau | |||
def update(self, _t, _dt): | |||
if self.x_ou is not None: | |||
self.input += self.x_ou.x | |||
self.x_ou.update(_t, _dt) | |||
y_ext = 0. | |||
if self.y_ou is not None: | |||
y_ext = self.y_ou.x | |||
self.y_ou.update(_t, _dt) | |||
x, y = self.integral(self.x, self.y, t=_t, x_ext=self.input, y_ext=y_ext, dt=_dt) | |||
self.x.value = x | |||
self.y.value = y | |||
self.input[:] = 0. | |||
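# Usage sketch (illustrative; not from the library source). Here ``x`` plays the
# role of the mean membrane potential and ``y`` that of the population firing
# rate, matching the renaming from ``MeanFieldQIF`` (``V`` -> ``x``, ``r`` -> ``y``):
#
# >>> import brainpy as bp
# >>> qif = bp.dyn.RateQIF(1)
# >>> runner = bp.dyn.DSRunner(qif, monitors=['x', 'y'])
# >>> runner.run(100.)
# >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.y, legend='y', show=True)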
class StuartLandauOscillator(RateGroup): | |||
r""" | |||
Stuart-Landau model with Hopf bifurcation. | |||
.. math:: | |||
\frac{dx}{dt} = (a - x^2 - y^2) x - w y + I^x_{ext} \\
\frac{dy}{dt} = (a - x^2 - y^2) y + w x + I^y_{ext}
Parameters | |||
---------- | |||
x_ou_mean: Parameter | |||
The noise mean of the :math:`x` variable, [mV/ms].
y_ou_mean: Parameter | |||
The noise mean of the :math:`y` variable, [mV/ms]. | |||
x_ou_sigma: Parameter | |||
The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)]. | |||
y_ou_sigma: Parameter | |||
The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)]. | |||
x_ou_tau: Parameter | |||
The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms]. | |||
y_ou_tau: Parameter | |||
The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms]. | |||
""" | |||
def __init__( | |||
self, | |||
size: Shape, | |||
# model parameters | |||
a=0.25, | |||
w=0.2, | |||
# noise parameters | |||
x_ou_mean: Parameter = 0.0, | |||
x_ou_sigma: Parameter = 0.0, | |||
x_ou_tau: Parameter = 5.0, | |||
y_ou_mean: Parameter = 0.0, | |||
y_ou_sigma: Parameter = 0.0, | |||
y_ou_tau: Parameter = 5.0, | |||
# other parameters | |||
x_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.5), | |||
y_initializer: Union[Initializer, Callable, Tensor] = Uniform(0, 0.5), | |||
method: str = None, | |||
sde_method: str = None, | |||
name: str = None, | |||
): | |||
super(StuartLandauOscillator, self).__init__(size=size, | |||
name=name) | |||
# model parameters | |||
self.a = a | |||
self.w = w | |||
# noise parameters | |||
self.x_ou_mean = x_ou_mean | |||
self.y_ou_mean = y_ou_mean | |||
self.x_ou_sigma = x_ou_sigma | |||
self.y_ou_sigma = y_ou_sigma | |||
self.x_ou_tau = x_ou_tau | |||
self.y_ou_tau = y_ou_tau | |||
# variables | |||
check_initializer(x_initializer, 'x_initializer') | |||
check_initializer(y_initializer, 'y_initializer') | |||
self.x = bm.Variable(init_param(x_initializer, (self.num,))) | |||
self.y = bm.Variable(init_param(y_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num)) | |||
def dr(self, r, t, v): | |||
return (self.delta / (bm.pi * self.tau) + 2. * r * v) / self.tau | |||
# noise variables | |||
self.x_ou = self.y_ou = None | |||
if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): | |||
self.x_ou = OUProcess(self.num, | |||
self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau, | |||
method=sde_method) | |||
if bm.any(self.y_ou_mean > 0.) or bm.any(self.y_ou_sigma > 0.): | |||
self.y_ou = OUProcess(self.num, | |||
self.y_ou_mean, self.y_ou_sigma, self.y_ou_tau, | |||
method=sde_method) | |||
def dV(self, v, t, r): | |||
return (v ** 2 + self.eta + self.input + self.J * r * self.tau - | |||
(bm.pi * r * self.tau) ** 2) / self.tau | |||
# integral functions | |||
self.integral = odeint(f=JointEq([self.dx, self.dy]), method=method) | |||
@property | |||
def derivative(self): | |||
return JointEq([self.dV, self.dr]) | |||
def dx(self, x, t, y, x_ext, a, w): | |||
return (a - x * x - y * y) * x - w * y + x_ext | |||
def dy(self, y, t, x, y_ext, a, w): | |||
return (a - x * x - y * y) * y + w * x + y_ext
def update(self, _t, _dt): | |||
self.V.value, self.r.value = self.integral(self.V, self.r, _t, _dt) | |||
self.integral[:] = 0. | |||
if self.x_ou is not None: | |||
self.input += self.x_ou.x | |||
self.x_ou.update(_t, _dt) | |||
y_ext = 0. | |||
if self.y_ou is not None: | |||
y_ext = self.y_ou.x | |||
self.y_ou.update(_t, _dt) | |||
x, y = self.integral(self.x, self.y, _t, x_ext=self.input, | |||
y_ext=y_ext, a=self.a, w=self.w, dt=_dt) | |||
self.x.value = x | |||
self.y.value = y | |||
self.input[:] = 0. | |||
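# Usage sketch (illustrative; not from the library source). For ``a > 0`` the
# noise-free system settles onto a limit cycle of radius ``sqrt(a)`` with angular
# frequency ``w``, so ``x`` and ``y`` below oscillate with amplitude ~0.5:
#
# >>> import brainpy as bp
# >>> osc = bp.dyn.StuartLandauOscillator(1, a=0.25, w=0.2)
# >>> runner = bp.dyn.DSRunner(osc, monitors=['x', 'y'])
# >>> runner.run(200.)
# >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.x, legend='x', show=True)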
class WilsonCowanModel(RateGroup): | |||
"""Wilson-Cowan population model. | |||
class VanDerPolOscillator(NeuGroup): | |||
pass | |||
Parameters | |||
---------- | |||
x_ou_mean: Parameter | |||
The noise mean of the :math:`x` variable, [mV/ms].
y_ou_mean: Parameter | |||
The noise mean of the :math:`y` variable, [mV/ms]. | |||
x_ou_sigma: Parameter | |||
The noise intensity of the :math:`x` variable, [mV/ms/sqrt(ms)]. | |||
y_ou_sigma: Parameter | |||
The noise intensity of the :math:`y` variable, [mV/ms/sqrt(ms)]. | |||
x_ou_tau: Parameter | |||
The timescale of the Ornstein-Uhlenbeck noise process of :math:`x` variable, [ms]. | |||
y_ou_tau: Parameter | |||
The timescale of the Ornstein-Uhlenbeck noise process of :math:`y` variable, [ms]. | |||
class ThetaNeuron(NeuGroup): | |||
pass | |||
""" | |||
class MeanFieldQIFWithSFA(NeuGroup): | |||
pass | |||
def __init__( | |||
self, | |||
size: Shape, | |||
# Excitatory parameters | |||
E_tau=1., # excitatory time constant | |||
E_a=1.2, # excitatory gain | |||
E_theta=2.8, # excitatory firing threshold | |||
# Inhibitory parameters | |||
I_tau=1., # inhibitory time constant | |||
I_a=1., # inhibitory gain | |||
I_theta=4.0, # inhibitory firing threshold | |||
# connection parameters | |||
wEE=12., # local E-E coupling | |||
wIE=4., # local E-I coupling | |||
wEI=13., # local I-E coupling | |||
wII=11., # local I-I coupling | |||
# Refractory parameter | |||
r=1, | |||
# state initializer | |||
x_initializer: Union[Initializer, Callable, Tensor] = Uniform(max_val=0.05), | |||
y_initializer: Union[Initializer, Callable, Tensor] = Uniform(max_val=0.05), | |||
# noise parameters | |||
x_ou_mean: Parameter = 0.0, | |||
x_ou_sigma: Parameter = 0.0, | |||
x_ou_tau: Parameter = 5.0, | |||
y_ou_mean: Parameter = 0.0, | |||
y_ou_sigma: Parameter = 0.0, | |||
y_ou_tau: Parameter = 5.0, | |||
# other parameters | |||
sde_method: str = None, | |||
method: str = 'exp_euler_auto', | |||
name: str = None, | |||
): | |||
super(WilsonCowanModel, self).__init__(size=size, name=name) | |||
# model parameters | |||
self.E_tau = E_tau | |||
self.E_a = E_a | |||
self.E_theta = E_theta | |||
self.I_tau = I_tau | |||
self.I_a = I_a | |||
self.I_theta = I_theta | |||
self.wEE = wEE | |||
self.wIE = wIE | |||
self.wEI = wEI | |||
self.wII = wII | |||
self.r = r | |||
# noise parameters | |||
self.x_ou_mean = x_ou_mean | |||
self.y_ou_mean = y_ou_mean | |||
self.x_ou_sigma = x_ou_sigma | |||
self.y_ou_sigma = y_ou_sigma | |||
self.x_ou_tau = x_ou_tau | |||
self.y_ou_tau = y_ou_tau | |||
# variables | |||
check_initializer(x_initializer, 'x_initializer') | |||
check_initializer(y_initializer, 'y_initializer') | |||
self.x = bm.Variable(init_param(x_initializer, (self.num,))) | |||
self.y = bm.Variable(init_param(y_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num)) | |||
# noise variables | |||
self.x_ou = self.y_ou = None | |||
if bm.any(self.x_ou_mean > 0.) or bm.any(self.x_ou_sigma > 0.): | |||
self.x_ou = OUProcess(self.num, | |||
self.x_ou_mean, self.x_ou_sigma, self.x_ou_tau, | |||
method=sde_method) | |||
if bm.any(self.y_ou_mean > 0.) or bm.any(self.y_ou_sigma > 0.): | |||
self.y_ou = OUProcess(self.num, | |||
self.y_ou_mean, self.y_ou_sigma, self.y_ou_tau, | |||
method=sde_method) | |||
class JansenRitModel(NeuGroup): | |||
# functions | |||
self.integral = odeint(f=JointEq([self.dx, self.dy]), method=method) | |||
# transfer function
def F(self, x, a, theta): | |||
return 1 / (1 + bm.exp(-a * (x - theta))) - 1 / (1 + bm.exp(a * theta)) | |||
def dx(self, x, t, y, x_ext):
# the total synaptic drive to the excitatory population; a separate name keeps
# the state `x` available for the leak and refractory terms below
x_input = self.wEE * x - self.wIE * y + x_ext
return (-x + (1 - self.r * x) * self.F(x_input, self.E_a, self.E_theta)) / self.E_tau
def dy(self, y, t, x, y_ext):
# the total synaptic drive to the inhibitory population
y_input = self.wEI * x - self.wII * y + y_ext
return (-y + (1 - self.r * y) * self.F(y_input, self.I_a, self.I_theta)) / self.I_tau
def update(self, _t, _dt): | |||
if self.x_ou is not None: | |||
self.input += self.x_ou.x | |||
self.x_ou.update(_t, _dt) | |||
y_ext = 0. | |||
if self.y_ou is not None: | |||
y_ext = self.y_ou.x | |||
self.y_ou.update(_t, _dt) | |||
x, y = self.integral(self.x, self.y, _t, x_ext=self.input, y_ext=y_ext, dt=_dt) | |||
self.x.value = x | |||
self.y.value = y | |||
self.input[:] = 0. | |||
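# Usage sketch (illustrative; not from the library source). ``x`` and ``y`` are
# the excitatory and inhibitory population activities; with the default coupling
# (wEE=12, wIE=4, wEI=13, wII=11) a constant drive on ``input`` can push the
# system into an oscillatory regime:
#
# >>> import brainpy as bp
# >>> wc = bp.dyn.WilsonCowanModel(1)
# >>> runner = bp.dyn.DSRunner(wc, monitors=['x', 'y'], inputs=('input', 1.))
# >>> runner.run(100.)
# >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.x, legend='x', show=True)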
class JansenRitModel(RateGroup): | |||
pass | |||
class WilsonCowanModel(NeuGroup): | |||
class KuramotoOscillator(RateGroup): | |||
pass | |||
class StuartLandauOscillator(NeuGroup): | |||
class ThetaNeuron(RateGroup): | |||
pass | |||
class KuramotoOscillator(NeuGroup): | |||
class RateQIFWithSFA(RateGroup): | |||
pass | |||
class VanDerPolOscillator(RateGroup): | |||
pass |
@@ -1,16 +1,861 @@ | |||
# -*- coding: utf-8 -*- | |||
from typing import Union, Callable | |||
import brainpy.math as bm | |||
from brainpy.dyn.base import NeuGroup | |||
from brainpy.initialize import ZeroInit, OneInit, Initializer, init_param | |||
from brainpy.integrators.joint_eq import JointEq | |||
from brainpy.integrators.ode import odeint | |||
from brainpy.dyn.base import NeuGroup | |||
from brainpy.tools.checking import check_initializer | |||
from brainpy.types import Shape, Parameter, Tensor | |||
__all__ = [ | |||
'LIF', | |||
'ExpIF', | |||
'AdExIF', | |||
'QuaIF', | |||
'AdQuaIF', | |||
'GIF', | |||
'Izhikevich', | |||
'HindmarshRose', | |||
'FHN', | |||
] | |||
class LIF(NeuGroup): | |||
r"""Leaky integrate-and-fire neuron model. | |||
**Model Descriptions** | |||
The formal equations of a LIF model [1]_ is given by: | |||
.. math:: | |||
\tau \frac{dV}{dt} = - (V(t) - V_{rest}) + I(t) \\ | |||
\text{after } V(t) \gt V_{th}\text{: } V(t) = V_{reset} \text{ for the next } \tau_{ref} \text{ ms}
where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting | |||
membrane potential, :math:`V_{reset}` is the reset membrane potential, | |||
:math:`V_{th}` is the spike threshold, :math:`\tau` is the time constant, | |||
:math:`\tau_{ref}` is the refractory time period, | |||
and :math:`I` is the time-variant synaptic inputs. | |||
**Model Examples** | |||
- `(Brette, Romain. 2004) LIF phase locking <https://brainpy-examples.readthedocs.io/en/latest/neurons/Romain_2004_LIF_phase_locking.html>`_ | |||
**Model Parameters** | |||
============= ============== ======== ========================================= | |||
**Parameter** **Init Value** **Unit** **Explanation** | |||
------------- -------------- -------- ----------------------------------------- | |||
V_rest 0 mV Resting membrane potential. | |||
V_reset -5 mV Reset potential after spike. | |||
V_th 20 mV Threshold potential of spike. | |||
tau 10 ms Membrane time constant. Computed as R * C.
tau_ref 1 ms Refractory period length.
============= ============== ======== ========================================= | |||
**Neuron Variables** | |||
================== ================= ========================================================= | |||
**Variables name** **Initial Value** **Explanation** | |||
------------------ ----------------- --------------------------------------------------------- | |||
V 0 Membrane potential. | |||
input 0 External and synaptic input current. | |||
spike False Flag to mark whether the neuron is spiking. | |||
refractory False Flag to mark whether the neuron is in refractory period. | |||
t_last_spike -1e7 Last spike time stamp. | |||
================== ================= ========================================================= | |||
**References** | |||
.. [1] Abbott, Larry F. "Lapicque’s introduction of the integrate-and-fire model | |||
neuron (1907)." Brain research bulletin 50, no. 5-6 (1999): 303-304. | |||
""" | |||
def __init__( | |||
self, | |||
size: Shape, | |||
V_rest: Parameter = 0., | |||
V_reset: Parameter = -5., | |||
V_th: Parameter = 20., | |||
tau: Parameter = 10., | |||
tau_ref: Parameter = 1., | |||
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), | |||
method: str = 'exp_auto', | |||
name: str = None | |||
): | |||
# initialization | |||
super(LIF, self).__init__(size=size, name=name) | |||
# parameters | |||
self.V_rest = V_rest | |||
self.V_reset = V_reset | |||
self.V_th = V_th | |||
self.tau = tau | |||
self.tau_ref = tau_ref | |||
# variables | |||
check_initializer(V_initializer, 'V_initializer') | |||
self.V = bm.Variable(init_param(V_initializer, (self.num,))) | |||
self.input = bm.Variable(bm.zeros(self.num)) | |||
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) | |||
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
# integral | |||
self.integral = odeint(method=method, f=self.derivative) | |||
def derivative(self, V, t, I_ext): | |||
dvdt = (-V + self.V_rest + I_ext) / self.tau | |||
return dvdt | |||
def update(self, _t, _dt): | |||
refractory = (_t - self.t_last_spike) <= self.tau_ref | |||
V = self.integral(self.V, _t, self.input, dt=_dt) | |||
V = bm.where(refractory, self.V, V) | |||
spike = V >= self.V_th | |||
self.t_last_spike.value = bm.where(spike, _t, self.t_last_spike) | |||
self.V.value = bm.where(spike, self.V_reset, V) | |||
self.refractory.value = bm.logical_or(refractory, spike) | |||
self.spike.value = spike | |||
self.input[:] = 0. | |||
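# Usage sketch (illustrative; not from the library source). With the defaults
# (V_rest=0 mV, tau=10 ms, V_th=20 mV), any constant input above 20 drives
# repetitive firing, and the refractory mask above freezes ``V`` for ``tau_ref``
# ms after each spike:
#
# >>> import brainpy as bp
# >>> group = bp.dyn.LIF(1)
# >>> runner = bp.dyn.DSRunner(group, monitors=['V'], inputs=('input', 22.))
# >>> runner.run(100.)
# >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', show=True)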
class ExpIF(NeuGroup): | |||
r"""Exponential integrate-and-fire neuron model. | |||
**Model Descriptions** | |||
In the exponential integrate-and-fire model [1]_, the differential | |||
equation for the membrane potential is given by | |||
.. math:: | |||
\tau\frac{d V}{d t}= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} + RI(t), \\ | |||
\text{after } V(t) \gt V_{th}\text{: } V(t) = V_{reset} \text{ for the next } \tau_{ref} \text{ ms}
This equation has an exponential nonlinearity with "sharpness" parameter :math:`\Delta_{T}` | |||
and "threshold" :math:`\vartheta_{rh}`. | |||
The moment when the membrane potential reaches the numerical threshold :math:`V_{th}` | |||
defines the firing time :math:`t^{(f)}`. After firing, the membrane potential is reset to | |||
:math:`V_{reset}` and integration restarts at time :math:`t^{(f)}+\tau_{\rm ref}`,
where :math:`\tau_{\rm ref}` is an absolute refractory time. | |||
If the numerical threshold is chosen sufficiently high, :math:`V_{th}\gg \vartheta_{rh}+\Delta_T`,
its exact value does not play any role. The reason is that the upswing of the action | |||
potential for :math:`v\gg \vartheta_{rh}+\Delta_{T}` is so rapid that it goes to infinity in
an incredibly short time. The threshold :math:`V_{th}` is introduced mainly for numerical | |||
convenience. For a formal mathematical analysis of the model, the threshold can be pushed | |||
to infinity. | |||
The model was first introduced by Nicolas Fourcaud-Trocmé, David Hansel, Carl van Vreeswijk | |||
and Nicolas Brunel [1]_. The exponential nonlinearity was later confirmed by Badel et al. [3]_. | |||
It is one of the prominent examples of a precise theoretical prediction in computational | |||
neuroscience that was later confirmed by experimental neuroscience. | |||
Two important remarks: | |||
- (i) The right-hand side of the above equation contains a nonlinearity | |||
that can be directly extracted from experimental data [3]_. In this sense the exponential | |||
nonlinearity is not an arbitrary choice but directly supported by experimental evidence. | |||
- (ii) Even though it is a nonlinear model, it is simple enough to calculate the firing | |||
rate for constant input, and the linear response to fluctuations, even in the presence | |||
of input noise [4]_. | |||
**Model Examples** | |||
.. plot:: | |||
:include-source: True | |||
>>> import brainpy as bp | |||
>>> group = bp.dyn.ExpIF(1) | |||
>>> runner = bp.dyn.DSRunner(group, monitors=['V'], inputs=('input', 10.)) | |||
>>> runner.run(300., ) | |||
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', show=True) | |||
**Model Parameters** | |||
============= ============== ======== =================================================== | |||
**Parameter** **Init Value** **Unit** **Explanation** | |||
------------- -------------- -------- --------------------------------------------------- | |||
V_rest -65 mV Resting potential. | |||
V_reset -68 mV Reset potential after spike. | |||
V_th -30 mV Threshold potential of spike. | |||
V_T -59.9 mV Threshold potential of generating action potential. | |||
delta_T 3.48 mV Spike slope factor.
R 1 \ Membrane resistance.
tau 10 ms Membrane time constant. Computed as R * C.
tau_ref 1.7 ms Refractory period length.
============= ============== ======== =================================================== | |||
**Model Variables** | |||
================== ================= ========================================================= | |||
**Variables name** **Initial Value** **Explanation** | |||
------------------ ----------------- --------------------------------------------------------- | |||
V 0 Membrane potential. | |||
input 0 External and synaptic input current. | |||
spike False Flag to mark whether the neuron is spiking. | |||
refractory False Flag to mark whether the neuron is in refractory period. | |||
t_last_spike -1e7 Last spike time stamp. | |||
================== ================= ========================================================= | |||
**References** | |||
.. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation | |||
mechanisms determine the neuronal response to fluctuating | |||
inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. | |||
.. [2] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014). | |||
Neuronal dynamics: From single neurons to networks and models | |||
of cognition. Cambridge University Press. | |||
.. [3] Badel, Laurent, Sandrine Lefort, Romain Brette, Carl CH Petersen, | |||
Wulfram Gerstner, and Magnus JE Richardson. "Dynamic IV curves | |||
are reliable predictors of naturalistic pyramidal-neuron voltage | |||
traces." Journal of Neurophysiology 99, no. 2 (2008): 656-666. | |||
.. [4] Richardson, Magnus JE. "Firing-rate response of linear and nonlinear | |||
integrate-and-fire neurons to modulated current-based and | |||
conductance-based synaptic drive." Physical Review E 76, no. 2 (2007): 021919. | |||
.. [5] https://en.wikipedia.org/wiki/Exponential_integrate-and-fire | |||
""" | |||
def __init__( | |||
self, | |||
size: Shape, | |||
V_rest: Parameter = -65., | |||
V_reset: Parameter = -68., | |||
V_th: Parameter = -30., | |||
V_T: Parameter = -59.9, | |||
delta_T: Parameter = 3.48, | |||
R: Parameter = 1., | |||
tau: Parameter = 10., | |||
tau_ref: Parameter = 1.7, | |||
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), | |||
method: str = 'exp_auto', | |||
name: str = None | |||
): | |||
# initialize | |||
super(ExpIF, self).__init__(size=size, name=name) | |||
# parameters | |||
self.V_rest = V_rest | |||
self.V_reset = V_reset | |||
self.V_th = V_th | |||
self.V_T = V_T | |||
self.delta_T = delta_T | |||
self.R = R | |||
self.tau = tau | |||
self.tau_ref = tau_ref | |||
# variables | |||
check_initializer(V_initializer, 'V_initializer') | |||
self.V = bm.Variable(init_param(V_initializer, (self.num,))) | |||
self.input = bm.Variable(bm.zeros(self.num)) | |||
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) | |||
# integral | |||
self.integral = odeint(method=method, f=self.derivative) | |||
def derivative(self, V, t, I_ext): | |||
exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) | |||
dvdt = (- (V - self.V_rest) + exp_v + self.R * I_ext) / self.tau | |||
return dvdt | |||
def update(self, _t, _dt): | |||
refractory = (_t - self.t_last_spike) <= self.tau_ref | |||
V = self.integral(self.V, _t, self.input, dt=_dt) | |||
V = bm.where(refractory, self.V, V) | |||
spike = self.V_th <= V | |||
self.t_last_spike.value = bm.where(spike, _t, self.t_last_spike) | |||
self.V.value = bm.where(spike, self.V_reset, V) | |||
self.refractory.value = bm.logical_or(refractory, spike) | |||
self.spike.value = spike | |||
self.input[:] = 0. | |||
class AdExIF(NeuGroup): | |||
r"""Adaptive exponential integrate-and-fire neuron model. | |||
**Model Descriptions** | |||
The **adaptive exponential integrate-and-fire model**, also called AdEx, is a | |||
spiking neuron model with two variables [1]_ [2]_. | |||
.. math:: | |||
\begin{aligned} | |||
\tau_m\frac{d V}{d t} &= - (V-V_{rest}) + \Delta_T e^{\frac{V-V_T}{\Delta_T}} - Rw + RI(t), \\ | |||
\tau_w \frac{d w}{d t} &=a(V-V_{rest}) - w | |||
\end{aligned} | |||
once the membrane potential reaches the spike threshold, | |||
.. math:: | |||
V \rightarrow V_{reset}, \\ | |||
w \rightarrow w+b. | |||
The first equation describes the dynamics of the membrane potential and includes | |||
an activation term with an exponential voltage dependence. Voltage is coupled to | |||
a second equation which describes adaptation. Both variables are reset if an action | |||
potential has been triggered. The combination of adaptation and exponential voltage | |||
dependence gives rise to the name Adaptive Exponential Integrate-and-Fire model. | |||
The adaptive exponential integrate-and-fire model is capable of describing known | |||
neuronal firing patterns, e.g., adapting, bursting, delayed spike initiation, | |||
initial bursting, fast spiking, and regular spiking. | |||
**Model Examples** | |||
- `Examples for different firing patterns <https://brainpy-examples.readthedocs.io/en/latest/neurons/Gerstner_2005_AdExIF_model.html>`_ | |||
**Model Parameters** | |||
============= ============== ======== ======================================================================================================================== | |||
**Parameter** **Init Value** **Unit** **Explanation** | |||
------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------ | |||
V_rest -65 mV Resting potential. | |||
V_reset -68 mV Reset potential after spike. | |||
V_th -30 mV Threshold potential of spike and reset. | |||
V_T -59.9 mV Threshold potential of generating action potential. | |||
delta_T 3.48 \ Spike slope factor. | |||
a 1 \ The sensitivity of the recovery variable :math:`w` to the sub-threshold fluctuations of the membrane potential :math:`V`
b 1 \ The increment of :math:`w` produced by a spike. | |||
R 1 \ Membrane resistance. | |||
tau 10 ms Membrane time constant. Computed as R * C.
tau_w 30 ms Time constant of the adaptation current. | |||
============= ============== ======== ======================================================================================================================== | |||
**Model Variables** | |||
================== ================= ========================================================= | |||
**Variables name** **Initial Value** **Explanation** | |||
------------------ ----------------- --------------------------------------------------------- | |||
V 0 Membrane potential. | |||
w 0 Adaptation current. | |||
input 0 External and synaptic input current. | |||
spike False Flag to mark whether the neuron is spiking. | |||
t_last_spike -1e7 Last spike time stamp. | |||
================== ================= ========================================================= | |||
**References** | |||
.. [1] Fourcaud-Trocmé, Nicolas, et al. "How spike generation | |||
mechanisms determine the neuronal response to fluctuating | |||
inputs." Journal of Neuroscience 23.37 (2003): 11628-11640. | |||
.. [2] http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model | |||
""" | |||
def __init__( | |||
self, | |||
size: Shape, | |||
V_rest: Parameter = -65., | |||
V_reset: Parameter = -68., | |||
V_th: Parameter = -30., | |||
V_T: Parameter = -59.9, | |||
delta_T: Parameter = 3.48, | |||
a: Parameter = 1., | |||
b: Parameter = 1., | |||
tau: Parameter = 10., | |||
tau_w: Parameter = 30., | |||
R: Parameter = 1., | |||
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), | |||
w_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), | |||
method: str = 'exp_auto', | |||
name: str = None | |||
): | |||
super(AdExIF, self).__init__(size=size, name=name) | |||
# parameters | |||
self.V_rest = V_rest | |||
self.V_reset = V_reset | |||
self.V_th = V_th | |||
self.V_T = V_T | |||
self.delta_T = delta_T | |||
self.a = a | |||
self.b = b | |||
self.tau = tau | |||
self.tau_w = tau_w | |||
self.R = R | |||
# variables | |||
check_initializer(V_initializer, 'V_initializer') | |||
check_initializer(w_initializer, 'w_initializer') | |||
self.V = bm.Variable(init_param(V_initializer, (self.num,))) | |||
self.w = bm.Variable(init_param(w_initializer, (self.num,))) | |||
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.input = bm.Variable(bm.zeros(self.num)) | |||
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) | |||
# functions | |||
self.integral = odeint(method=method, f=self.derivative) | |||
def dV(self, V, t, w, I_ext): | |||
dVdt = (- V + self.V_rest + self.delta_T * bm.exp((V - self.V_T) / self.delta_T) - | |||
self.R * w + self.R * I_ext) / self.tau | |||
return dVdt | |||
def dw(self, w, t, V): | |||
dwdt = (self.a * (V - self.V_rest) - w) / self.tau_w | |||
return dwdt | |||
@property | |||
def derivative(self): | |||
return JointEq([self.dV, self.dw]) | |||
def update(self, _t, _dt): | |||
V, w = self.integral(self.V, self.w, _t, self.input, dt=_dt) | |||
spike = V >= self.V_th | |||
self.t_last_spike.value = bm.where(spike, _t, self.t_last_spike)
self.V.value = bm.where(spike, self.V_reset, V) | |||
self.w.value = bm.where(spike, w + self.b, w) | |||
self.spike.value = spike | |||
self.input[:] = 0. | |||
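# Usage sketch (illustrative; not from the library source). The adaptation
# current ``w`` jumps by ``b`` at each spike and decays with ``tau_w``, which is
# what produces the adapting and bursting patterns referenced in the docstring:
#
# >>> import brainpy as bp
# >>> group = bp.dyn.AdExIF(1)
# >>> runner = bp.dyn.DSRunner(group, monitors=['V', 'w'], inputs=('input', 10.))
# >>> runner.run(300.)
# >>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V', show=True)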
class QuaIF(NeuGroup): | |||
r"""Quadratic Integrate-and-Fire neuron model. | |||
**Model Descriptions** | |||
In contrast to physiologically accurate but computationally expensive | |||
neuron models like the Hodgkin–Huxley model, the QIF model [1]_ seeks only | |||
to produce **action potential-like patterns** and ignores subtleties | |||
like gating variables, which play an important role in generating action | |||
potentials in a real neuron. However, the QIF model is incredibly easy | |||
to implement and compute, and relatively straightforward to study and | |||
understand; it has thus found ubiquitous use in computational neuroscience.
.. math:: | |||
\tau \frac{d V}{d t}=c(V-V_{rest})(V-V_c) + RI(t) | |||
where the parameters are taken to be :math:`c=0.07` and :math:`V_c = -50` mV (Latham et al., 2000).
**Model Examples** | |||
.. plot:: | |||
:include-source: True | |||
>>> import brainpy as bp | |||
>>> | |||
>>> group = bp.dyn.QuaIF(1,) | |||
>>> | |||
>>> runner = bp.dyn.DSRunner(group, monitors=['V'], inputs=('input', 20.)) | |||
>>> runner.run(duration=200.) | |||
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True) | |||
**Model Parameters** | |||
============= ============== ======== ======================================================================================================================== | |||
**Parameter** **Init Value** **Unit** **Explanation** | |||
------------- -------------- -------- ------------------------------------------------------------------------------------------------------------------------ | |||
V_rest -65 mV Resting potential. | |||
V_reset -68 mV Reset potential after spike. | |||
V_th -30 mV Threshold potential of spike and reset. | |||
V_c -50 mV Critical voltage for spike initiation. Must be larger than V_rest. | |||
c .07 \ Coefficient describes membrane potential update. Larger than 0. | |||
R 1 \ Membrane resistance. | |||
tau 10 ms Membrane time constant. Computed as R * C.
tau_ref 0 ms Refractory period length. | |||
============= ============== ======== ======================================================================================================================== | |||
**Model Variables** | |||
================== ================= ========================================================= | |||
**Variables name** **Initial Value** **Explanation** | |||
------------------ ----------------- --------------------------------------------------------- | |||
V 0 Membrane potential. | |||
input 0 External and synaptic input current. | |||
spike False Flag to mark whether the neuron is spiking. | |||
refractory False Flag to mark whether the neuron is in refractory period. | |||
t_last_spike -1e7 Last spike time stamp. | |||
================== ================= ========================================================= | |||
**References** | |||
.. [1] P. E. Latham, B.J. Richmond, P. Nelson and S. Nirenberg | |||
(2000) Intrinsic dynamics in neuronal networks. I. Theory. | |||
J. Neurophysiology 83, pp. 808–827. | |||
""" | |||
def __init__( | |||
self, | |||
size: Shape, | |||
V_rest: Parameter = -65., | |||
V_reset: Parameter = -68., | |||
V_th: Parameter = -30., | |||
V_c: Parameter = -50.0, | |||
c: Parameter = .07, | |||
R: Parameter = 1., | |||
tau: Parameter = 10., | |||
tau_ref: Parameter = 0., | |||
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), | |||
method: str = 'exp_auto', | |||
name: str = None | |||
): | |||
# initialization | |||
super(QuaIF, self).__init__(size=size, name=name) | |||
# parameters | |||
self.V_rest = V_rest | |||
self.V_reset = V_reset | |||
self.V_th = V_th | |||
self.V_c = V_c | |||
self.c = c | |||
self.R = R | |||
self.tau = tau | |||
self.tau_ref = tau_ref | |||
# variables | |||
self.V = bm.Variable(init_param(V_initializer, (self.num,))) | |||
self.input = bm.Variable(bm.zeros(self.num)) | |||
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) | |||
# integral | |||
self.integral = odeint(method=method, f=self.derivative) | |||
def derivative(self, V, t, I_ext): | |||
dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I_ext) / self.tau | |||
return dVdt | |||
def update(self, _t, _dt, **kwargs): | |||
refractory = (_t - self.t_last_spike) <= self.tau_ref | |||
V = self.integral(self.V, _t, self.input, dt=_dt) | |||
V = bm.where(refractory, self.V, V) | |||
spike = self.V_th <= V | |||
self.t_last_spike.value = bm.where(spike, _t, self.t_last_spike) | |||
self.V.value = bm.where(spike, self.V_reset, V) | |||
self.refractory.value = bm.logical_or(refractory, spike) | |||
self.spike.value = spike | |||
self.input[:] = 0. | |||
class AdQuaIF(NeuGroup): | |||
r"""Adaptive quadratic integrate-and-fire neuron model. | |||
**Model Descriptions** | |||
The adaptive quadratic integrate-and-fire neuron model [1]_ is given by: | |||
.. math:: | |||
\begin{aligned} | |||
\tau_m \frac{d V}{d t}&=c(V-V_{rest})(V-V_c) - w + I(t), \\ | |||
\tau_w \frac{d w}{d t}&=a(V-V_{rest}) - w, | |||
\end{aligned} | |||
once the membrane potential reaches the spike threshold, | |||
.. math:: | |||
V \rightarrow V_{reset}, \\ | |||
w \rightarrow w+b. | |||
**Model Examples** | |||
.. plot:: | |||
:include-source: True | |||
>>> import brainpy as bp | |||
>>> group = bp.dyn.AdQuaIF(1, ) | |||
>>> runner = bp.dyn.DSRunner(group, monitors=['V', 'w'], inputs=('input', 30.)) | |||
>>> runner.run(300) | |||
>>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8) | |||
>>> fig.add_subplot(gs[0, 0]) | |||
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, ylabel='V') | |||
>>> fig.add_subplot(gs[1, 0]) | |||
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.w, ylabel='w', show=True) | |||
**Model Parameters** | |||
============= ============== ======== ======================================================= | |||
**Parameter** **Init Value** **Unit** **Explanation** | |||
------------- -------------- -------- ------------------------------------------------------- | |||
V_rest        -65            mV       Resting potential.
V_reset       -68            mV       Reset potential after spike.
V_th          -30            mV       Threshold potential of spike and reset.
V_c           -50            mV       Critical voltage for spike initiation. Must be larger
                                      than :math:`V_{rest}`.
a             1              \        The sensitivity of the recovery variable :math:`w` to
                                      the sub-threshold fluctuations of the membrane
                                      potential :math:`V`.
b             .1             \        The increment of :math:`w` produced by a spike.
c             .07            \        Coefficient describing the membrane potential update.
                                      Must be larger than 0.
tau           10             ms       Membrane time constant.
tau_w         10             ms       Time constant of the adaptation current.
============= ============== ======== ======================================================= | |||
**Model Variables** | |||
================== ================= ========================================================== | |||
**Variables name** **Initial Value** **Explanation** | |||
------------------ ----------------- ---------------------------------------------------------- | |||
V                  0                 Membrane potential.
w                  0                 Adaptation current.
input              0                 External and synaptic input current.
spike              False             Flag to mark whether the neuron is spiking.
t_last_spike       -1e7              Last spike time stamp.
================== ================= ========================================================== | |||
**References** | |||
.. [1] Izhikevich, E. M. (2004). Which model to use for cortical spiking | |||
neurons?. IEEE transactions on neural networks, 15(5), 1063-1070. | |||
.. [2] Touboul, Jonathan. "Bifurcation analysis of a general class of | |||
nonlinear integrate-and-fire neurons." SIAM Journal on Applied | |||
Mathematics 68, no. 4 (2008): 1045-1079. | |||
""" | |||
def __init__( | |||
self, | |||
size: Shape, | |||
V_rest: Parameter = -65., | |||
V_reset: Parameter = -68., | |||
V_th: Parameter = -30., | |||
V_c: Parameter = -50.0, | |||
a: Parameter = 1., | |||
b: Parameter = .1, | |||
c: Parameter = .07, | |||
tau: Parameter = 10., | |||
tau_w: Parameter = 10., | |||
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), | |||
w_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), | |||
method: str = 'exp_auto', | |||
name: str = None | |||
): | |||
super(AdQuaIF, self).__init__(size=size, name=name) | |||
# parameters | |||
self.V_rest = V_rest | |||
self.V_reset = V_reset | |||
self.V_th = V_th | |||
self.V_c = V_c | |||
self.c = c | |||
self.a = a | |||
self.b = b | |||
self.tau = tau | |||
self.tau_w = tau_w | |||
# variables | |||
check_initializer(V_initializer, 'V_initializer') | |||
check_initializer(w_initializer, 'w_initializer') | |||
self.V = bm.Variable(init_param(V_initializer, (self.num,))) | |||
self.w = bm.Variable(init_param(w_initializer, (self.num,))) | |||
self.input = bm.Variable(bm.zeros(self.num)) | |||
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) | |||
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
# integral | |||
self.integral = odeint(method=method, f=self.derivative) | |||
def dV(self, V, t, w, I_ext): | |||
dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I_ext) / self.tau | |||
return dVdt | |||
def dw(self, w, t, V): | |||
dwdt = (self.a * (V - self.V_rest) - w) / self.tau_w | |||
return dwdt | |||
@property | |||
def derivative(self): | |||
return JointEq([self.dV, self.dw]) | |||
def update(self, _t, _dt): | |||
V, w = self.integral(self.V, self.w, _t, self.input, dt=_dt) | |||
spike = self.V_th <= V | |||
self.t_last_spike.value = bm.where(spike, _t, self.t_last_spike) | |||
self.V.value = bm.where(spike, self.V_reset, V) | |||
self.w.value = bm.where(spike, w + self.b, w) | |||
self.spike.value = spike | |||
self.input[:] = 0. | |||
class GIF(NeuGroup): | |||
r"""Generalized Integrate-and-Fire model. | |||
**Model Descriptions** | |||
The generalized integrate-and-fire model [1]_ is given by | |||
.. math:: | |||
\begin{aligned}
\frac{d I_j}{d t} &= - k_j I_j, \\
\frac{d V}{d t} &= \big(- (V - V_{rest}) + R\sum_{j}I_j + RI\big) / \tau, \\
\frac{d V_{th}}{d t} &= a(V - V_{rest}) - b(V_{th} - V_{th\infty})
\end{aligned}
When :math:`V` meets :math:`V_{th}`, the generalized IF neuron fires:
.. math:: | |||
\begin{aligned}
I_j &\leftarrow R_j I_j + A_j, \\
V &\leftarrow V_{reset}, \\
V_{th} &\leftarrow \max(V_{th_{reset}}, V_{th})
\end{aligned}
Note that :math:`I_j` refers to an arbitrary number of internal currents.
**Model Examples** | |||
- `Detailed examples to reproduce different firing patterns <https://brainpy-examples.readthedocs.io/en/latest/neurons/Niebur_2009_GIF.html>`_ | |||
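A minimal simulation sketch in the same style as the other neuron models
in this module (the input strength is illustrative):
.. plot::
:include-source: True
>>> import brainpy as bp
>>> group = bp.dyn.GIF(1)
>>> runner = bp.dyn.DSRunner(group, monitors=['V', 'V_th'], inputs=('input', 1.5))
>>> runner.run(200.)
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V')
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V_th, legend='V_th', show=True)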
**Model Parameters** | |||
============= ============== ======== ==================================================================== | |||
**Parameter** **Init Value** **Unit** **Explanation** | |||
------------- -------------- -------- -------------------------------------------------------------------- | |||
V_rest        -70            mV       Resting potential.
V_reset       -70            mV       Reset potential after spike.
V_th_inf      -50            mV       Target value of the threshold potential :math:`V_{th}` update.
V_th_reset    -60            mV       Free parameter, should be larger than :math:`V_{reset}`.
R             20             \        Membrane resistance.
tau           20             ms       Membrane time constant. Computed as :math:`R * C`.
a             0              \        Coefficient describing the dependence of :math:`V_{th}`
                                      on the membrane potential.
b             0.01           \        Coefficient describing the :math:`V_{th}` update.
k1            0.2            \        Decay constant of :math:`I_1`.
k2            0.02           \        Decay constant of :math:`I_2`.
R1            0              \        Free parameter. Describes the dependence of the
                                      :math:`I_1` reset value on the :math:`I_1` value
                                      before spiking.
R2            1              \        Free parameter. Describes the dependence of the
                                      :math:`I_2` reset value on the :math:`I_2` value
                                      before spiking.
A1            0              \        Free parameter.
A2            0              \        Free parameter.
============= ============== ======== ==================================================================== | |||
**Model Variables** | |||
================== ================= ========================================================= | |||
**Variables name** **Initial Value** **Explanation** | |||
------------------ ----------------- --------------------------------------------------------- | |||
V                  -70               Membrane potential.
input              0                 External and synaptic input current.
spike              False             Flag to mark whether the neuron is spiking.
V_th               -50               Spiking threshold potential.
I1                 0                 Internal current 1.
I2                 0                 Internal current 2.
t_last_spike       -1e7              Last spike time stamp.
================== ================= ========================================================= | |||
**References** | |||
.. [1] Mihalaş, Ştefan, and Ernst Niebur. "A generalized linear | |||
integrate-and-fire neural model produces diverse spiking | |||
behaviors." Neural computation 21.3 (2009): 704-718. | |||
.. [2] Teeter, Corinne, Ramakrishnan Iyer, Vilas Menon, Nathan | |||
Gouwens, David Feng, Jim Berg, Aaron Szafer et al. "Generalized | |||
leaky integrate-and-fire models classify multiple neuron types." | |||
Nature communications 9, no. 1 (2018): 1-15. | |||
""" | |||
def __init__( | |||
self, | |||
size: Shape, | |||
V_rest: Parameter = -70., | |||
V_reset: Parameter = -70., | |||
V_th_inf: Parameter = -50., | |||
V_th_reset: Parameter = -60., | |||
R: Parameter = 20., | |||
tau: Parameter = 20., | |||
a: Parameter = 0., | |||
b: Parameter = 0.01, | |||
k1: Parameter = 0.2, | |||
k2: Parameter = 0.02, | |||
R1: Parameter = 0., | |||
R2: Parameter = 1., | |||
A1: Parameter = 0., | |||
A2: Parameter = 0., | |||
V_initializer: Union[Initializer, Callable, Tensor] = OneInit(-70.), | |||
I1_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), | |||
I2_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), | |||
Vth_initializer: Union[Initializer, Callable, Tensor] = OneInit(-50.), | |||
method: str = 'exp_auto', | |||
name: str = None | |||
): | |||
# initialization | |||
super(GIF, self).__init__(size=size, name=name) | |||
# params | |||
self.V_rest = V_rest | |||
self.V_reset = V_reset | |||
self.V_th_inf = V_th_inf | |||
self.V_th_reset = V_th_reset | |||
self.R = R | |||
self.tau = tau | |||
self.a = a | |||
self.b = b | |||
self.k1 = k1 | |||
self.k2 = k2 | |||
self.R1 = R1 | |||
self.R2 = R2 | |||
self.A1 = A1 | |||
self.A2 = A2 | |||
# variables | |||
check_initializer(V_initializer, 'V_initializer') | |||
check_initializer(I1_initializer, 'I1_initializer') | |||
check_initializer(I2_initializer, 'I2_initializer') | |||
check_initializer(Vth_initializer, 'Vth_initializer') | |||
self.I1 = bm.Variable(init_param(I1_initializer, (self.num,))) | |||
self.I2 = bm.Variable(init_param(I2_initializer, (self.num,))) | |||
self.V = bm.Variable(init_param(V_initializer, (self.num,))) | |||
self.V_th = bm.Variable(init_param(Vth_initializer, (self.num,))) | |||
self.input = bm.Variable(bm.zeros(self.num)) | |||
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) | |||
# integral | |||
self.integral = odeint(method=method, f=self.derivative) | |||
def dI1(self, I1, t): | |||
return - self.k1 * I1 | |||
def dI2(self, I2, t): | |||
return - self.k2 * I2 | |||
def dVth(self, V_th, t, V): | |||
return self.a * (V - self.V_rest) - self.b * (V_th - self.V_th_inf) | |||
def dV(self, V, t, I1, I2, I_ext): | |||
return (- (V - self.V_rest) + self.R * I_ext + self.R * I1 + self.R * I2) / self.tau | |||
@property | |||
def derivative(self): | |||
return JointEq([self.dI1, self.dI2, self.dVth, self.dV]) | |||
def update(self, _t, _dt): | |||
I1, I2, V_th, V = self.integral(self.I1, self.I2, self.V_th, self.V, _t, self.input, dt=_dt) | |||
spike = self.V_th <= V | |||
V = bm.where(spike, self.V_reset, V) | |||
I1 = bm.where(spike, self.R1 * I1 + self.A1, I1) | |||
I2 = bm.where(spike, self.R2 * I2 + self.A2, I2) | |||
reset_th = bm.logical_and(V_th < self.V_th_reset, spike) | |||
V_th = bm.where(reset_th, self.V_th_reset, V_th) | |||
self.spike.value = spike | |||
self.I1.value = I1 | |||
self.I2.value = I2 | |||
self.V_th.value = V_th | |||
self.V.value = V | |||
self.input[:] = 0. | |||
class Izhikevich(NeuGroup): | |||
r"""The Izhikevich neuron model. | |||
@@ -79,8 +924,20 @@ class Izhikevich(NeuGroup): | |||
IEEE transactions on neural networks 15.5 (2004): 1063-1070. | |||
""" | |||
def __init__(self, size, a=0.02, b=0.20, c=-65., d=8., tau_ref=0., | |||
V_th=30., method='exp_auto', name=None): | |||
def __init__( | |||
self, | |||
size: Shape, | |||
a: Parameter = 0.02, | |||
b: Parameter = 0.20, | |||
c: Parameter = -65., | |||
d: Parameter = 8., | |||
tau_ref: Parameter = 0., | |||
V_th: Parameter = 30., | |||
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), | |||
u_initializer: Union[Initializer, Callable, Tensor] = OneInit(), | |||
method: str = 'exp_auto', | |||
name: str = None | |||
): | |||
# initialization | |||
super(Izhikevich, self).__init__(size=size, name=name) | |||
@@ -93,11 +950,13 @@ class Izhikevich(NeuGroup): | |||
self.tau_ref = tau_ref | |||
# variables | |||
self.u = bm.Variable(bm.ones(self.num)) | |||
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.V = bm.Variable(bm.zeros(self.num)) | |||
check_initializer(V_initializer, 'V_initializer') | |||
check_initializer(u_initializer, 'u_initializer') | |||
self.u = bm.Variable(init_param(u_initializer, (self.num,))) | |||
self.V = bm.Variable(init_param(V_initializer, (self.num,))) | |||
self.input = bm.Variable(bm.zeros(self.num)) | |||
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.refractory = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) | |||
# functions | |||
@@ -157,7 +1016,7 @@ class HindmarshRose(NeuGroup): | |||
>>> import matplotlib.pyplot as plt | |||
>>> | |||
>>> bp.math.set_dt(dt=0.01) | |||
>>> bp.set_default_odeint('rk4') | |||
>>> bp.ode.set_default_odeint('rk4') | |||
>>> | |||
>>> types = ['quiescence', 'spiking', 'bursting', 'irregular_spiking', 'irregular_bursting'] | |||
>>> bs = bp.math.array([1.0, 3.5, 2.5, 2.95, 2.8]) | |||
@@ -222,8 +1081,23 @@ class HindmarshRose(NeuGroup): | |||
033128. | |||
""" | |||
def __init__(self, size, a=1., b=3., c=1., d=5., r=0.01, s=4., V_rest=-1.6, | |||
V_th=1.0, method='exp_auto', name=None): | |||
def __init__( | |||
self, | |||
size: Shape, | |||
a: Parameter = 1., | |||
b: Parameter = 3., | |||
c: Parameter = 1., | |||
d: Parameter = 5., | |||
r: Parameter = 0.01, | |||
s: Parameter = 4., | |||
V_rest: Parameter = -1.6, | |||
V_th: Parameter = 1.0, | |||
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), | |||
y_initializer: Union[Initializer, Callable, Tensor] = OneInit(-10.), | |||
z_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), | |||
method: str = 'exp_auto', | |||
name: str = None | |||
): | |||
# initialization | |||
super(HindmarshRose, self).__init__(size=size, name=name) | |||
@@ -238,9 +1112,12 @@ class HindmarshRose(NeuGroup): | |||
self.V_rest = V_rest | |||
# variables | |||
self.z = bm.Variable(bm.zeros(self.num)) | |||
self.y = bm.Variable(bm.ones(self.num) * -10.) | |||
self.V = bm.Variable(bm.zeros(self.num)) | |||
check_initializer(V_initializer, 'V_initializer') | |||
check_initializer(y_initializer, 'y_initializer') | |||
check_initializer(z_initializer, 'z_initializer') | |||
self.z = bm.Variable(init_param(z_initializer, (self.num,)))
self.y = bm.Variable(init_param(y_initializer, (self.num,)))
self.V = bm.Variable(init_param(V_initializer, (self.num,)))
self.input = bm.Variable(bm.zeros(self.num)) | |||
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) | |||
@@ -269,3 +1146,138 @@ class HindmarshRose(NeuGroup): | |||
self.y.value = y | |||
self.z.value = z | |||
self.input[:] = 0. | |||
class FHN(NeuGroup): | |||
r"""FitzHugh-Nagumo neuron model. | |||
**Model Descriptions** | |||
The FitzHugh–Nagumo model (FHN), named after Richard FitzHugh (1922–2007) | |||
who suggested the system in 1961 [1]_ and J. Nagumo et al. who created the | |||
equivalent circuit the following year, describes a prototype of an excitable | |||
system (e.g., a neuron). | |||
The motivation for the FitzHugh-Nagumo model was to isolate conceptually | |||
the essentially mathematical properties of excitation and propagation from | |||
the electrochemical properties of sodium and potassium ion flow. The model | |||
consists of | |||
- a *voltage-like variable* having cubic nonlinearity that allows regenerative | |||
self-excitation via a positive feedback, and | |||
- a *recovery variable* having a linear dynamics that provides a slower negative feedback. | |||
.. math:: | |||
\begin{aligned} | |||
{\dot {v}} &=v-{\frac {v^{3}}{3}}-w+RI_{\rm {ext}}, \\ | |||
\tau {\dot {w}}&=v+a-bw. | |||
\end{aligned} | |||
The FHN model is an example of a relaxation oscillator
because, if the external stimulus :math:`I_{\text{ext}}` | |||
exceeds a certain threshold value, the system will exhibit | |||
a characteristic excursion in phase space, before the | |||
variables :math:`v` and :math:`w` relax back to their rest values. | |||
This behaviour is typical for spike generations (a short, | |||
nonlinear elevation of membrane voltage :math:`v`, | |||
diminished over time by a slower, linear recovery variable | |||
:math:`w`) in a neuron after stimulation by an external | |||
input current. | |||
**Model Examples** | |||
.. plot:: | |||
:include-source: True | |||
>>> import brainpy as bp | |||
>>> fhn = bp.dyn.FHN(1) | |||
>>> runner = bp.dyn.DSRunner(fhn, inputs=('input', 1.), monitors=['V', 'w']) | |||
>>> runner.run(100.) | |||
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.w, legend='w') | |||
>>> bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V', show=True) | |||
**Model Parameters** | |||
============= ============== ======== ======================== | |||
**Parameter** **Init Value** **Unit** **Explanation** | |||
------------- -------------- -------- ------------------------ | |||
a             0.7            \        Positive constant.
b             0.8            \        Positive constant.
tau           12.5           ms       Membrane time constant.
Vth           1.8            mV       Threshold potential of spike.
============= ============== ======== ======================== | |||
**Model Variables** | |||
================== ================= ========================================================= | |||
**Variables name** **Initial Value** **Explanation** | |||
------------------ ----------------- --------------------------------------------------------- | |||
V                  0                 Membrane potential.
w                  0                 A recovery variable which represents the combined
                                     effects of sodium channel de-inactivation and
                                     potassium channel deactivation.
input              0                 External and synaptic input current.
spike              False             Flag to mark whether the neuron is spiking.
t_last_spike       -1e7              Last spike time stamp.
================== ================= ========================================================= | |||
**References** | |||
.. [1] FitzHugh, Richard. "Impulses and physiological states in theoretical models of nerve membrane." Biophysical journal 1.6 (1961): 445-466. | |||
.. [2] https://en.wikipedia.org/wiki/FitzHugh%E2%80%93Nagumo_model | |||
.. [3] http://www.scholarpedia.org/article/FitzHugh-Nagumo_model | |||
""" | |||
def __init__( | |||
self, | |||
size: Shape, | |||
a: Parameter = 0.7, | |||
b: Parameter = 0.8, | |||
tau: Parameter = 12.5, | |||
Vth: Parameter = 1.8, | |||
V_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), | |||
w_initializer: Union[Initializer, Callable, Tensor] = ZeroInit(), | |||
method: str = 'exp_auto', | |||
name: str = None | |||
): | |||
# initialization | |||
super(FHN, self).__init__(size=size, name=name) | |||
# parameters | |||
self.a = a | |||
self.b = b | |||
self.tau = tau | |||
self.Vth = Vth | |||
# variables | |||
check_initializer(V_initializer, 'V_initializer') | |||
check_initializer(w_initializer, 'w_initializer') | |||
self.w = bm.Variable(init_param(w_initializer, (self.num,))) | |||
self.V = bm.Variable(init_param(V_initializer, (self.num,))) | |||
self.input = bm.Variable(bm.zeros(self.num)) | |||
self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) | |||
self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) | |||
# integral | |||
self.integral = odeint(method=method, f=self.derivative) | |||
def dV(self, V, t, w, I_ext): | |||
return V - V * V * V / 3 - w + I_ext | |||
def dw(self, w, t, V): | |||
return (V + self.a - self.b * w) / self.tau | |||
@property | |||
def derivative(self): | |||
return JointEq([self.dV, self.dw]) | |||
def update(self, _t, _dt): | |||
V, w = self.integral(self.V, self.w, _t, self.input, dt=_dt) | |||
self.spike.value = bm.logical_and(V >= self.Vth, self.V < self.Vth) | |||
self.t_last_spike.value = bm.where(self.spike, _t, self.t_last_spike) | |||
self.V.value = V | |||
self.w.value = w | |||
self.input[:] = 0. |
@@ -7,6 +7,7 @@ import numpy as np | |||
import tqdm.auto | |||
from jax.experimental.host_callback import id_tap | |||
from brainpy.base.base import TensorCollector | |||
from brainpy import math as bm | |||
from brainpy.dyn import utils | |||
from brainpy.dyn.base import DynamicalSystem | |||
@@ -74,7 +75,6 @@ class DSRunner(Runner): | |||
self.dyn_vars.update({'_i': self._i}) | |||
else: | |||
self._i = None | |||
self.dyn_vars.update(self.target.vars().unique()) | |||
# run function | |||
self._run_func = self.build_run_function() | |||
@@ -159,29 +159,33 @@ class DSRunner(Runner): | |||
return_with_idx[key] = (data, bm.asarray(idx)) | |||
def func(_t, _dt): | |||
res = {k: (v.flatten() if bm.ndim(v) > 1 else v) for k, v in return_without_idx.items()} | |||
res = {k: (v.flatten() if bm.ndim(v) > 1 else v.value) | |||
for k, v in return_without_idx.items()} | |||
res.update({k: (v.flatten()[idx] if bm.ndim(v) > 1 else v[idx]) | |||
for k, (v, idx) in return_with_idx.items()}) | |||
return res | |||
return func | |||
def _run_one_step(self, t_and_dt): | |||
_t, _dt = t_and_dt[0], t_and_dt[1] | |||
self._input_step(_t=_t, _dt=_dt) | |||
self.target.update(_t=_t, _dt=_dt) | |||
def _run_one_step(self, _t): | |||
self._input_step(_t=_t, _dt=self.dt) | |||
self.target.update(_t=_t, _dt=self.dt) | |||
if self.progress_bar: | |||
id_tap(lambda *args: self._pbar.update(), ()) | |||
return self._monitor_step(_t=_t, _dt=_dt) | |||
return self._monitor_step(_t=_t, _dt=self.dt) | |||
def build_run_function(self): | |||
if self.jit: | |||
f_run = bm.make_loop(self._run_one_step, dyn_vars=self.dyn_vars, has_return=True) | |||
dyn_vars = TensorCollector() | |||
dyn_vars.update(self.dyn_vars) | |||
dyn_vars.update(self.target.vars().unique()) | |||
f_run = bm.make_loop(self._run_one_step, | |||
dyn_vars=dyn_vars, | |||
has_return=True) | |||
else: | |||
def f_run(t_and_dt): | |||
all_t, all_dt = t_and_dt | |||
def f_run(all_t): | |||
for i in range(all_t.shape[0]): | |||
mon = self._run_one_step((all_t[i], all_dt[i])) | |||
mon = self._run_one_step(all_t[i]) | |||
for k, v in mon.items(): | |||
self.mon.item_contents[k].append(v) | |||
return None, {} | |||
@@ -212,8 +216,7 @@ class DSRunner(Runner): | |||
start_t = float(self._start_t) | |||
end_t = float(start_t + duration) | |||
# times | |||
times = bm.arange(start_t, end_t, self.dt) | |||
time_steps = bm.ones_like(times) * self.dt | |||
times = np.arange(start_t, end_t, self.dt) | |||
# build monitor | |||
for key in self.mon.item_contents.keys(): | |||
self.mon.item_contents[key] = [] # reshape the monitor items | |||
@@ -223,7 +226,7 @@ class DSRunner(Runner): | |||
self._pbar.set_description(f"Running a duration of {round(float(duration), 3)} ({times.size} steps)", | |||
refresh=True) | |||
t0 = time.time() | |||
_, hists = self._run_func([times.value, time_steps.value]) | |||
_, hists = self._run_func(times) | |||
running_time = time.time() - t0 | |||
if self.progress_bar: | |||
self._pbar.close() | |||
@@ -277,23 +280,24 @@ class ReportRunner(DSRunner): | |||
# Build the update function | |||
if jit: | |||
self._update_step = bm.jit(self.target.update, dyn_vars=self.dyn_vars) | |||
dyn_vars = TensorCollector() | |||
dyn_vars.update(self.dyn_vars) | |||
dyn_vars.update(self.target.vars().unique()) | |||
self._update_step = bm.jit(self.target.update, dyn_vars=dyn_vars) | |||
else: | |||
self._update_step = self.target.update | |||
def _run_one_step(self, t_and_dt): | |||
_t, _dt = t_and_dt[0], t_and_dt[1] | |||
self._input_step(_t=_t, _dt=_dt) | |||
self._update_step(_t=_t, _dt=_dt) | |||
def _run_one_step(self, _t): | |||
self._input_step(_t, self.dt) | |||
self._update_step(_t, self.dt) | |||
if self.progress_bar: | |||
self._pbar.update() | |||
return self._monitor_step(_t=_t, _dt=_dt) | |||
return self._monitor_step(_t, self.dt) | |||
def build_run_function(self): | |||
def f_run(t_and_dt): | |||
all_t, all_dt = t_and_dt | |||
def f_run(all_t): | |||
for i in range(all_t.shape[0]): | |||
mon = self._run_one_step((all_t[i], all_dt[i])) | |||
mon = self._run_one_step(all_t[i]) | |||
for k, v in mon.items(): | |||
self.mon.item_contents[k].append(v) | |||
return None, {} |
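# Usage sketch: the public DSRunner API is unchanged by this refactor
# (the model and input value below are illustrative):
import brainpy as bp

net = bp.dyn.FHN(1)
runner = bp.dyn.DSRunner(net, monitors=['V'], inputs=('input', 1.))
runner.run(100.)  # builds times = np.arange(start_t, end_t, dt) and calls the run function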
@@ -1,3 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
from .ds_runner import * |
@@ -3,3 +3,5 @@ | |||
from .abstract_models import * | |||
from .biological_models import * | |||
from .learning_rules import * | |||
from .delay_coupling import * | |||
@@ -1,9 +1,10 @@ | |||
# -*- coding: utf-8 -*- | |||
import brainpy.math as bm | |||
from brainpy.dyn.base import NeuGroup | |||
from brainpy.dyn.base import TwoEndConn, ConstantDelay | |||
from brainpy.integrators.joint_eq import JointEq | |||
from brainpy.integrators.ode import odeint | |||
from brainpy.dyn.base import TwoEndConn, ConstantDelay | |||
__all__ = [ | |||
'DeltaSynapse', | |||
@@ -67,8 +68,17 @@ class DeltaSynapse(TwoEndConn): | |||
""" | |||
def __init__(self, pre, post, conn, delay=0., post_has_ref=False, w=1., | |||
post_key='V', name=None): | |||
def __init__( | |||
self, | |||
pre: NeuGroup, | |||
post: NeuGroup, | |||
conn, | |||
delay=0., | |||
post_has_ref=False, | |||
w=1., | |||
post_key='V', | |||
name=None | |||
): | |||
super(DeltaSynapse, self).__init__(pre=pre, post=post, conn=conn, name=name) | |||
self.check_pre_attrs('spike') | |||
self.check_post_attrs(post_key) | |||
@@ -193,8 +203,17 @@ class ExpCUBA(TwoEndConn): | |||
Cambridge: Cambridge UP, 2011. 172-95. Print. | |||
""" | |||
def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, | |||
method='exp_auto', name=None): | |||
def __init__( | |||
self, | |||
pre: NeuGroup, | |||
post: NeuGroup, | |||
conn, | |||
g_max=1., | |||
delay=0., | |||
tau=8.0, | |||
method='exp_auto', | |||
name=None | |||
): | |||
super(ExpCUBA, self).__init__(pre=pre, post=post, conn=conn, name=name) | |||
self.check_pre_attrs('spike') | |||
self.check_post_attrs('input', 'V') | |||
@@ -0,0 +1,206 @@ | |||
# -*- coding: utf-8 -*- | |||
from typing import Optional, Union, Sequence, Dict, List | |||
from jax import vmap | |||
import brainpy.math as bm | |||
from brainpy.dyn.base import TwoEndConn | |||
from brainpy.initialize import Initializer, ZeroInit | |||
from brainpy.tools.checking import check_sequence | |||
from brainpy.types import Tensor | |||
__all__ = [ | |||
'DelayCoupling', | |||
'DiffusiveDelayCoupling', | |||
'AdditiveDelayCoupling', | |||
] | |||
class DelayCoupling(TwoEndConn): | |||
""" | |||
Delay coupling base class. | |||
coupling: str | |||
The way of coupling. | |||
gc: float | |||
The global coupling strength. | |||
signal_speed: float | |||
Signal transmission speed between areas. | |||
sc_mat: optional, tensor | |||
Structural connectivity matrix. Adjacency matrix of coupling strengths, | |||
will be normalized to 1. If not given, then a single node simulation | |||
will be assumed. Default None | |||
fl_mat: optional, tensor | |||
Fiber length matrix. Will be used for computing the | |||
delay matrix together with the signal transmission | |||
speed parameter `signal_speed`. Default None. | |||
""" | |||
"""Global delay variables. Useful when the same target | |||
variable is used in multiple mappings.""" | |||
global_delay_vars: Dict[str, bm.LengthDelay] = dict() | |||
def __init__( | |||
self, | |||
pre, | |||
post, | |||
from_to: Union[str, Sequence[str]], | |||
conn_mat: Tensor, | |||
delay_mat: Optional[Tensor] = None, | |||
delay_initializer: Initializer = ZeroInit(), | |||
domain: str = 'local', | |||
name: str = None | |||
): | |||
super(DelayCoupling, self).__init__(pre, post, name=name) | |||
# local delay variables | |||
self.local_delay_vars: Dict[str, bm.LengthDelay] = dict() | |||
# domain | |||
if domain not in ['global', 'local']: | |||
raise ValueError('"domain" must be a string in ["global", "local"]. ' | |||
f'Bug we got {domain}.') | |||
self.domain = domain | |||
# pairs of (source, destination) | |||
self.source_target_pairs: Dict[str, List[bm.Variable]] = dict() | |||
source_vars = {} | |||
if isinstance(from_to, str): | |||
from_to = [from_to] | |||
check_sequence(from_to, 'from_to', elem_type=str, allow_none=False) | |||
for pair in from_to: | |||
splits = [v.strip() for v in pair.split('->')] | |||
if len(splits) != 2: | |||
raise ValueError('The (source, target) pair in "from_to" ' | |||
'should be defined as "a -> b".') | |||
if not hasattr(self.pre, splits[0]): | |||
raise ValueError(f'"{splits[0]}" is not defined in pre-synaptic group {self.pre.name}') | |||
if not hasattr(self.post, splits[1]): | |||
raise ValueError(f'"{splits[1]}" is not defined in post-synaptic group {self.post.name}') | |||
source = f'{self.pre.name}.{splits[0]}' | |||
target = getattr(self.post, splits[1]) | |||
if source not in self.source_target_pairs:
self.source_target_pairs[source] = [target] | |||
source_vars[source] = getattr(self.pre, splits[0]) | |||
if not isinstance(source_vars[source], bm.Variable): | |||
raise ValueError(f'The target variable {source} for delay should ' | |||
f'be an instance of brainpy.math.Variable, while ' | |||
f'we got {type(source_vars[source])}') | |||
else: | |||
if target in self.source_target_pairs[source]:
raise ValueError(f'{pair} has been defined twice in {from_to}.') | |||
self.source_target_pairs[source].append(target) | |||
# Connection matrix | |||
conn_mat = bm.asarray(conn_mat) | |||
required_shape = (self.post.num, self.pre.num) | |||
if conn_mat.shape != required_shape: | |||
raise ValueError(f'We expect the structural connection matrix to have the shape of '
f'(post.num, pre.num), i.e., {required_shape}, '
f'while we got {conn_mat.shape}.')
self.conn_mat = bm.asarray(conn_mat) | |||
bm.fill_diagonal(self.conn_mat, 0) | |||
# Delay matrix | |||
if delay_mat is None: | |||
self.delay_mat = bm.zeros(required_shape, dtype=bm.int_) | |||
else: | |||
if delay_mat.shape != required_shape: | |||
raise ValueError(f'We expect the delay matrix to have the shape of '
f'(post.num, pre.num), i.e., {required_shape}, '
f'while we got {delay_mat.shape}.')
self.delay_mat = bm.asarray(delay_mat, dtype=bm.int_) | |||
# delay variables | |||
num_delay_step = int(self.delay_mat.max()) | |||
for var in self.source_target_pairs.keys(): | |||
if domain == 'local': | |||
variable = source_vars[var] | |||
shape = (num_delay_step,) + variable.shape | |||
delay_data = delay_initializer(shape, dtype=variable.dtype) | |||
self.local_delay_vars[var] = bm.LengthDelay(variable, num_delay_step, delay_data) | |||
else: | |||
if var not in self.global_delay_vars: | |||
variable = source_vars[var] | |||
shape = (num_delay_step,) + variable.shape | |||
delay_data = delay_initializer(shape, dtype=variable.dtype) | |||
self.global_delay_vars[var] = bm.LengthDelay(variable, num_delay_step, delay_data) | |||
# save into local delay vars when first seen "var", | |||
# for later update current value! | |||
self.local_delay_vars[var] = self.global_delay_vars[var] | |||
else: | |||
if self.global_delay_vars[var].delay_len < num_delay_step: | |||
variable = source_vars[var] | |||
shape = (num_delay_step,) + variable.shape | |||
delay_data = delay_initializer(shape, dtype=variable.dtype) | |||
self.global_delay_vars[var].init(variable, num_delay_step, delay_data) | |||
self.register_implicit_nodes(self.local_delay_vars) | |||
self.register_implicit_nodes(self.global_delay_vars) | |||
def update(self, _t, _dt): | |||
raise NotImplementedError('The update() function must be implemented by subclasses.')
class DiffusiveDelayCoupling(DelayCoupling): | |||
def update(self, _t, _dt): | |||
for source, targets in self.source_target_pairs.items(): | |||
# delay variable | |||
if self.domain == 'local': | |||
delay_var: bm.LengthDelay = self.local_delay_vars[source] | |||
elif self.domain == 'global': | |||
delay_var: bm.LengthDelay = self.global_delay_vars[source] | |||
else: | |||
raise ValueError(f'Unknown domain: {self.domain}') | |||
# current data | |||
name, var = source.split('.') | |||
assert name == self.pre.name | |||
variable = getattr(self.pre, var) | |||
# delays | |||
f = vmap(lambda i: delay_var(self.delay_mat[i], bm.arange(self.pre.num))) # (pre.num,) | |||
delays = f(bm.arange(self.post.num).value) | |||
diffusive = delays - bm.expand_dims(variable, axis=1) # (post.num, pre.num) | |||
diffusive = (self.conn_mat * diffusive).sum(axis=1) | |||
# output to target variable | |||
for target in targets: | |||
target.value += diffusive | |||
# update | |||
if source in self.local_delay_vars: | |||
delay_var.update(variable) | |||
class AdditiveDelayCoupling(DelayCoupling): | |||
def update(self, _t, _dt): | |||
for source, targets in self.source_target_pairs.items(): | |||
# delay variable | |||
if self.domain == 'local': | |||
delay_var: bm.LengthDelay = self.local_delay_vars[source] | |||
elif self.domain == 'global': | |||
delay_var: bm.LengthDelay = self.global_delay_vars[source] | |||
else: | |||
raise ValueError(f'Unknown domain: {self.domain}') | |||
# current data | |||
name, var = source.split('.') | |||
assert name == self.pre.name | |||
variable = getattr(self.pre, var) | |||
# delay function | |||
f = vmap(lambda i: delay_var(self.delay_mat[i], bm.arange(self.pre.num))) # (pre.num,) | |||
delays = f(bm.arange(self.post.num)) # (post.num, pre.num) | |||
additive = (self.conn_mat * delays).sum(axis=1) | |||
# output to target variable | |||
for target in targets: | |||
target.value += additive | |||
# update | |||
if source in self.local_delay_vars: | |||
delay_var.update(variable) |
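# Construction sketch for the coupling classes above, using the FHN group
# defined elsewhere in this diff (connection and delay values are illustrative):
import brainpy as bp
import brainpy.math as bm

pre = bp.dyn.FHN(3)
post = bp.dyn.FHN(3)
coupling = DiffusiveDelayCoupling(
  pre, post,
  from_to='V -> input',  # couple the delayed pre.V into post.input
  conn_mat=bm.ones((post.num, pre.num)),  # (post.num, pre.num) coupling strengths
  delay_mat=bm.ones((post.num, pre.num), dtype=bm.int_),  # delays in integer steps
  domain='local')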
@@ -101,7 +101,8 @@ class JaxTracerError(MathError): | |||
else: | |||
raise ValueError | |||
msg += 'While there are changed variables which are not wrapped into "dyn_vars". Please check!' | |||
msg = 'While there are changed variables which are not wrapped into "dyn_vars". Please check!' | |||
super(JaxTracerError, self).__init__(msg) | |||
@@ -6,6 +6,7 @@ You can access them through ``brainpy.init.XXX``. | |||
""" | |||
from .base import * | |||
from .generic import * | |||
from .random_inits import * | |||
from .regular_inits import * | |||
from .decay_inits import * |
@@ -13,7 +13,7 @@ class Initializer(abc.ABC): | |||
"""Base Initialization Class.""" | |||
@abc.abstractmethod | |||
def __call__(self, shape): | |||
def __call__(self, shape, dtype=None): | |||
raise NotImplementedError | |||
@@ -21,7 +21,7 @@ class InterLayerInitializer(Initializer): | |||
"""The superclass of Initializers that initialize the weights between two layers.""" | |||
@abc.abstractmethod | |||
def __call__(self, shape): | |||
def __call__(self, shape, dtype=None): | |||
raise NotImplementedError | |||
@@ -29,5 +29,5 @@ class IntraLayerInitializer(Initializer): | |||
"""The superclass of Initializers that initialize the weights within a layer.""" | |||
@abc.abstractmethod | |||
def __call__(self, shape): | |||
def __call__(self, shape, dtype=None): | |||
raise NotImplementedError |
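# A minimal subclass sketch under the new `__call__(shape, dtype)` signature
# (the class name and fill value are illustrative, not part of the library):
import brainpy.math as bm

class ConstantInit(Initializer):
  """Fill the weights with a constant value."""

  def __init__(self, value=0.):
    self.value = value

  def __call__(self, shape, dtype=None):
    return bm.ones(shape, dtype=dtype) * self.value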
@@ -0,0 +1,46 @@ | |||
# -*- coding: utf-8 -*- | |||
from typing import Union, Callable | |||
import jax.numpy as jnp | |||
import numpy as onp | |||
import brainpy.math as bm | |||
from brainpy.tools.others import to_size | |||
from brainpy.types import Shape | |||
from .base import Initializer | |||
__all__ = [ | |||
'init_param', | |||
] | |||
def init_param(param: Union[Callable, Initializer, bm.ndarray, jnp.ndarray], | |||
size: Shape): | |||
"""Initialize parameters. | |||
Parameters | |||
---------- | |||
param: callable, Initializer, bm.ndarray, jnp.ndarray | |||
The initialization of the parameter. | |||
- If it is None, the created parameter will be None. | |||
- If it is a callable function :math:`f`, ``f(size)`` will be returned.
- If it is an instance of :py:class:`brainpy.init.Initializer`, ``f(size)`` will be returned.
- If it is a tensor, this function checks whether ``tensor.shape`` is equal to the given ``size``.
size: int, sequence of int | |||
The shape of the parameter. | |||
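Examples
--------
Illustrative calls (the shapes follow the checks below):
>>> init_param(ZeroInit(), (10,)).shape
(10,)
>>> init_param(bm.random.rand(10), 10).shape
(10,)
>>> init_param(None, (10,)) is None
True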
""" | |||
size = to_size(size) | |||
if param is None: | |||
return None | |||
elif callable(param): | |||
param = param(size) | |||
elif isinstance(param, (onp.ndarray, jnp.ndarray)): | |||
param = bm.asarray(param) | |||
elif isinstance(param, (bm.JaxArray,)):
pass  # already a brainpy math array
else: | |||
raise ValueError(f'Unknown param type {type(param)}: {param}') | |||
assert param.shape == size, f'The shape of "param" {param.shape} does not match the required size {size}.'
return param | |||
@@ -40,7 +40,7 @@ class Normal(InterLayerInitializer): | |||
def __init__(self, scale=1., seed=None): | |||
super(Normal, self).__init__() | |||
self.scale = scale | |||
self.rng = bm.random.RandomState(seed=seed) | |||
self.rng = np.random.RandomState(seed=seed) | |||
def __call__(self, shape, dtype=None): | |||
shape = [tools.size2num(d) for d in shape] | |||
@@ -64,7 +64,7 @@ class Uniform(InterLayerInitializer): | |||
super(Uniform, self).__init__() | |||
self.min_val = min_val | |||
self.max_val = max_val | |||
self.rng = bm.random.RandomState(seed=seed) | |||
self.rng = np.random.RandomState(seed=seed) | |||
def __call__(self, shape, dtype=None): | |||
shape = [tools.size2num(d) for d in shape] | |||
@@ -79,7 +79,7 @@ class VarianceScaling(InterLayerInitializer): | |||
self.in_axis = in_axis | |||
self.out_axis = out_axis | |||
self.distribution = distribution | |||
self.rng = bm.random.RandomState(seed=seed) | |||
self.rng = np.random.RandomState(seed=seed) | |||
def __call__(self, shape, dtype=None): | |||
shape = [tools.size2num(d) for d in shape] | |||
@@ -94,18 +94,17 @@ class VarianceScaling(InterLayerInitializer): | |||
raise ValueError("invalid mode for variance scaling initializer: {}".format(self.mode)) | |||
variance = bm.array(self.scale / denominator, dtype=dtype) | |||
if self.distribution == "truncated_normal": | |||
from scipy.stats import truncnorm | |||
# constant is stddev of standard normal truncated to (-2, 2) | |||
stddev = bm.sqrt(variance) / bm.array(.87962566103423978, dtype) | |||
res = self.rng.truncated_normal(-2, 2, shape) * stddev | |||
return bm.asarray(res, dtype=dtype) | |||
res = truncnorm(-2, 2).rvs(shape) * stddev | |||
elif self.distribution == "normal": | |||
res = self.rng.normal(size=shape) * bm.sqrt(variance) | |||
return bm.asarray(res, dtype=dtype) | |||
elif self.distribution == "uniform": | |||
res = self.rng.uniform(low=-1, high=1, size=shape) * bm.sqrt(3 * variance) | |||
return bm.asarray(res, dtype=dtype) | |||
else: | |||
raise ValueError("invalid distribution for variance scaling initializer") | |||
return bm.asarray(res, dtype=dtype) | |||
class KaimingUniform(VarianceScaling): | |||
@@ -180,7 +179,7 @@ class Orthogonal(InterLayerInitializer): | |||
super(Orthogonal, self).__init__() | |||
self.scale = scale | |||
self.axis = axis | |||
self.rng = bm.random.RandomState(seed=seed) | |||
self.rng = np.random.RandomState(seed=seed) | |||
def __call__(self, shape, dtype=None): | |||
shape = [tools.size2num(d) for d in shape] | |||
@@ -6,203 +6,5 @@ This module provides various methods to form current inputs. | |||
You can access them through ``brainpy.inputs.XXX``. | |||
""" | |||
import numpy as np | |||
from brainpy import math as bm | |||
__all__ = [ | |||
'section_input', | |||
'constant_input', 'constant_current', | |||
'spike_input', 'spike_current', | |||
'ramp_input', 'ramp_current', | |||
] | |||
def section_input(values, durations, dt=None, return_length=False): | |||
"""Format an input current with different sections. | |||
For example: | |||
If you want to get an input where the size is 0 between 0-100 ms,
and the size is 1 between 100-200 ms:
>>> section_input(values=[0, 1], | |||
>>> durations=[100, 100]) | |||
Parameters | |||
---------- | |||
values : list, np.ndarray | |||
The current values for each period duration. | |||
durations : list, np.ndarray | |||
The duration for each period. | |||
dt : float | |||
Default is None. | |||
return_length : bool | |||
Return the final duration length. | |||
Returns | |||
------- | |||
current_and_duration : tuple | |||
(The formatted current, total duration) | |||
""" | |||
assert len(durations) == len(values), f'"values" and "durations" must be the same length, while ' \ | |||
f'we got {len(values)} != {len(durations)}.' | |||
dt = bm.get_dt() if dt is None else dt | |||
# get input current shape, and duration | |||
I_duration = sum(durations) | |||
I_shape = () | |||
for val in values: | |||
shape = bm.shape(val) | |||
if len(shape) > len(I_shape): | |||
I_shape = shape | |||
# get the current | |||
start = 0 | |||
I_current = bm.zeros((int(np.ceil(I_duration / dt)),) + I_shape, dtype=bm.float_) | |||
for c_size, duration in zip(values, durations): | |||
length = int(duration / dt) | |||
I_current[start: start + length] = c_size | |||
start += length | |||
if return_length: | |||
return I_current, I_duration | |||
else: | |||
return I_current | |||
def constant_input(I_and_duration, dt=None): | |||
"""Format constant input in durations. | |||
For example: | |||
If you want to get an input where the size is 0 between 0-100 ms,
and the size is 1 between 100-200 ms:
>>> import brainpy.math as bm | |||
>>> constant_input([(0, 100), (1, 100)]) | |||
>>> constant_input([(bm.zeros(100), 100), (bm.random.rand(100), 100)]) | |||
Parameters | |||
---------- | |||
I_and_duration : list | |||
This parameter receives the current size and the current | |||
duration pairs, like `[(Isize1, duration1), (Isize2, duration2)]`. | |||
dt : float | |||
Default is None. | |||
Returns | |||
------- | |||
current_and_duration : tuple | |||
(The formatted current, total duration) | |||
""" | |||
dt = bm.get_dt() if dt is None else dt | |||
# get input current dimension, shape, and duration | |||
I_duration = 0. | |||
I_shape = () | |||
for I in I_and_duration: | |||
I_duration += I[1] | |||
shape = bm.shape(I[0]) | |||
if len(shape) > len(I_shape): | |||
I_shape = shape | |||
# get the current | |||
start = 0 | |||
I_current = bm.zeros((int(np.ceil(I_duration / dt)),) + I_shape, dtype=bm.float_) | |||
for c_size, duration in I_and_duration: | |||
length = int(duration / dt) | |||
I_current[start: start + length] = c_size | |||
start += length | |||
return I_current, I_duration | |||
constant_current = constant_input | |||
def spike_input(sp_times, sp_lens, sp_sizes, duration, dt=None): | |||
"""Format current input like a series of short-time spikes. | |||
For example: | |||
If you want to generate a spike train at 10 ms, 20 ms, 30 ms, 200 ms, 300 ms, | |||
and each spike lasts 1 ms and the spike current is 0.5, then you can use the | |||
following function:
>>> spike_input(sp_times=[10, 20, 30, 200, 300], | |||
>>> sp_lens=1., # can be a list to specify the spike length at each point | |||
>>> sp_sizes=0.5, # can be a list to specify the current size at each point | |||
>>> duration=400.) | |||
Parameters | |||
---------- | |||
sp_times : list, tuple | |||
The spike time-points. Must be an iterable object. | |||
sp_lens : int, float, list, tuple | |||
The length of each point-current, mimicking the spike durations. | |||
sp_sizes : int, float, list, tuple | |||
The current sizes. | |||
duration : int, float | |||
The total current duration. | |||
dt : float | |||
The default is None. | |||
Returns | |||
------- | |||
current : bm.ndarray | |||
The formatted input current. | |||
""" | |||
dt = bm.get_dt() if dt is None else dt | |||
assert isinstance(sp_times, (list, tuple)) | |||
if isinstance(sp_lens, (float, int)): | |||
sp_lens = [sp_lens] * len(sp_times) | |||
if isinstance(sp_sizes, (float, int)): | |||
sp_sizes = [sp_sizes] * len(sp_times) | |||
current = bm.zeros(int(np.ceil(duration / dt)), dtype=bm.float_) | |||
for time, dur, size in zip(sp_times, sp_lens, sp_sizes): | |||
pp = int(time / dt) | |||
p_len = int(dur / dt) | |||
current[pp: pp + p_len] = size | |||
return current | |||
spike_current = spike_input | |||
def ramp_input(c_start, c_end, duration, t_start=0, t_end=None, dt=None): | |||
"""Get the gradually changed input current. | |||
Parameters | |||
---------- | |||
c_start : float | |||
The minimum (or maximum) current size. | |||
c_end : float | |||
The maximum (or minimum) current size. | |||
duration : int, float | |||
The total duration. | |||
t_start : float | |||
The ramped current start time-point. | |||
t_end : float | |||
The ramped current end time-point. Default is None.
dt : float, int, optional | |||
The numerical precision. | |||
Returns | |||
------- | |||
current : bm.ndarray | |||
The formatted current | |||
""" | |||
dt = bm.get_dt() if dt is None else dt | |||
t_end = duration if t_end is None else t_end | |||
current = bm.zeros(int(np.ceil(duration / dt)), dtype=bm.float_) | |||
p1 = int(np.ceil(t_start / dt)) | |||
p2 = int(np.ceil(t_end / dt)) | |||
current[p1: p2] = bm.array(bm.linspace(c_start, c_end, p2 - p1), dtype=bm.float_) | |||
return current | |||
ramp_current = ramp_input | |||
from .currents import * | |||
@@ -0,0 +1,386 @@ | |||
# -*- coding: utf-8 -*- | |||
import numpy as np | |||
from brainpy import math as bm | |||
from brainpy.tools.checking import check_float, check_integer | |||
__all__ = [ | |||
'section_input', | |||
'constant_input', 'constant_current', | |||
'spike_input', 'spike_current', | |||
'ramp_input', 'ramp_current', | |||
'wiener_process', | |||
'ou_process', | |||
'sinusoidal_input', | |||
'square_input', | |||
] | |||
def section_input(values, durations, dt=None, return_length=False): | |||
"""Format an input current with different sections. | |||
For example: | |||
If you want to get an input where the size is 0 between 0-100 ms,
and the size is 1 between 100-200 ms:
>>> section_input(values=[0, 1], | |||
>>> durations=[100, 100]) | |||
Parameters | |||
---------- | |||
values : list, np.ndarray | |||
The current values for each period duration. | |||
durations : list, np.ndarray | |||
The duration for each period. | |||
dt : float | |||
Default is None. | |||
return_length : bool | |||
Return the final duration length. | |||
Returns | |||
------- | |||
current_and_duration : tuple | |||
(The formatted current, total duration) | |||
""" | |||
assert len(durations) == len(values), f'"values" and "durations" must be the same length, while ' \ | |||
f'we got {len(values)} != {len(durations)}.' | |||
dt = bm.get_dt() if dt is None else dt | |||
# get input current shape, and duration | |||
I_duration = sum(durations) | |||
I_shape = () | |||
for val in values: | |||
shape = bm.shape(val) | |||
if len(shape) > len(I_shape): | |||
I_shape = shape | |||
# get the current | |||
start = 0 | |||
I_current = bm.zeros((int(np.ceil(I_duration / dt)),) + I_shape, dtype=bm.float_) | |||
for c_size, duration in zip(values, durations): | |||
length = int(duration / dt) | |||
I_current[start: start + length] = c_size | |||
start += length | |||
if return_length: | |||
return I_current, I_duration | |||
else: | |||
return I_current | |||
def constant_input(I_and_duration, dt=None): | |||
"""Format constant input in durations. | |||
For example: | |||
If you want to get an input where the size is 0 between 0-100 ms,
and the size is 1 between 100-200 ms:
>>> import brainpy.math as bm | |||
>>> constant_input([(0, 100), (1, 100)]) | |||
>>> constant_input([(bm.zeros(100), 100), (bm.random.rand(100), 100)]) | |||
Parameters | |||
---------- | |||
I_and_duration : list | |||
This parameter receives the current size and the current | |||
duration pairs, like `[(Isize1, duration1), (Isize2, duration2)]`. | |||
dt : float | |||
Default is None. | |||
Returns | |||
------- | |||
current_and_duration : tuple | |||
(The formatted current, total duration) | |||
""" | |||
dt = bm.get_dt() if dt is None else dt | |||
# get input current dimension, shape, and duration | |||
I_duration = 0. | |||
I_shape = () | |||
for I in I_and_duration: | |||
I_duration += I[1] | |||
shape = bm.shape(I[0]) | |||
if len(shape) > len(I_shape): | |||
I_shape = shape | |||
# get the current | |||
start = 0 | |||
I_current = bm.zeros((int(np.ceil(I_duration / dt)),) + I_shape, dtype=bm.float_) | |||
for c_size, duration in I_and_duration: | |||
length = int(duration / dt) | |||
I_current[start: start + length] = c_size | |||
start += length | |||
return I_current, I_duration | |||
constant_current = constant_input | |||
def spike_input(sp_times, sp_lens, sp_sizes, duration, dt=None): | |||
"""Format current input like a series of short-time spikes. | |||
For example: | |||
If you want to generate a spike train at 10 ms, 20 ms, 30 ms, 200 ms, 300 ms, | |||
and each spike lasts 1 ms and the spike current is 0.5, then you can use the | |||
following function:
>>> spike_input(sp_times=[10, 20, 30, 200, 300], | |||
>>> sp_lens=1., # can be a list to specify the spike length at each point | |||
>>> sp_sizes=0.5, # can be a list to specify the current size at each point | |||
>>> duration=400.) | |||
Parameters | |||
---------- | |||
sp_times : list, tuple | |||
The spike time-points. Must be an iterable object. | |||
sp_lens : int, float, list, tuple | |||
The length of each point-current, mimicking the spike durations. | |||
sp_sizes : int, float, list, tuple | |||
The current sizes. | |||
duration : int, float | |||
The total current duration. | |||
dt : float | |||
The default is None. | |||
Returns | |||
------- | |||
current : bm.ndarray | |||
The formatted input current. | |||
""" | |||
dt = bm.get_dt() if dt is None else dt | |||
assert isinstance(sp_times, (list, tuple)) | |||
if isinstance(sp_lens, (float, int)): | |||
sp_lens = [sp_lens] * len(sp_times) | |||
if isinstance(sp_sizes, (float, int)): | |||
sp_sizes = [sp_sizes] * len(sp_times) | |||
current = bm.zeros(int(np.ceil(duration / dt)), dtype=bm.float_) | |||
for time, dur, size in zip(sp_times, sp_lens, sp_sizes): | |||
pp = int(time / dt) | |||
p_len = int(dur / dt) | |||
current[pp: pp + p_len] = size | |||
return current | |||
spike_current = spike_input | |||
def ramp_input(c_start, c_end, duration, t_start=0, t_end=None, dt=None): | |||
"""Get the gradually changed input current. | |||
Parameters | |||
---------- | |||
c_start : float | |||
The minimum (or maximum) current size. | |||
c_end : float | |||
The maximum (or minimum) current size. | |||
duration : int, float | |||
The total duration. | |||
t_start : float | |||
The ramped current start time-point. | |||
t_end : float | |||
The ramped current end time-point. Default is None.
dt : float, int, optional | |||
The numerical precision. | |||
Returns | |||
------- | |||
current : bm.ndarray | |||
The formatted current | |||
""" | |||
dt = bm.get_dt() if dt is None else dt | |||
t_end = duration if t_end is None else t_end | |||
current = bm.zeros(int(np.ceil(duration / dt)), dtype=bm.float_) | |||
p1 = int(np.ceil(t_start / dt)) | |||
p2 = int(np.ceil(t_end / dt)) | |||
current[p1: p2] = bm.array(bm.linspace(c_start, c_end, p2 - p1), dtype=bm.float_) | |||
return current | |||
ramp_current = ramp_input | |||
def wiener_process(duration, dt=None, n=1, t_start=0., t_end=None, seed=None): | |||
"""Stimulus sampled from a Wiener process, i.e. | |||
drawn from standard normal distribution N(0, sqrt(dt)). | |||
Parameters | |||
---------- | |||
duration: float | |||
The input duration. | |||
dt: float | |||
The numerical precision. | |||
n: int | |||
The variable number. | |||
t_start: float | |||
The start time. | |||
t_end: float | |||
The end time. | |||
seed: int | |||
The noise seed. | |||
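Examples
--------
A usage sketch (the values are illustrative):
>>> noise = wiener_process(duration=200., dt=0.1, n=2)
>>> noise.shape
(2000, 2)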
""" | |||
dt = bm.get_dt() if dt is None else dt | |||
check_float(dt, 'dt', allow_none=False, min_bound=0.) | |||
check_integer(n, 'n', allow_none=False, min_bound=0) | |||
rng = bm.random.RandomState(seed) | |||
t_end = duration if t_end is None else t_end | |||
i_start = int(t_start / dt) | |||
i_end = int(t_end / dt) | |||
noises = rng.standard_normal((i_end - i_start, n)) * bm.sqrt(dt) | |||
currents = bm.zeros((int(duration / dt), n)) | |||
currents[i_start: i_end] = noises | |||
return currents | |||
def ou_process(mean, sigma, tau, duration, dt=None, n=1, t_start=0., t_end=None, seed=None): | |||
r"""Ornstein–Uhlenbeck input. | |||
.. math:: | |||
dX = \frac{\mu - X}{\tau} dt + \sigma dW
Parameters | |||
---------- | |||
mean: float | |||
Drift of the OU process. | |||
sigma: float | |||
Standard deviation of the Wiener process, i.e. strength of the noise. | |||
tau: float | |||
Timescale of the OU process, in ms. | |||
duration: float | |||
The input duration. | |||
dt: float | |||
The numerical precision. | |||
n: int | |||
The variable number. | |||
t_start: float | |||
The start time. | |||
t_end: float | |||
The end time. | |||
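Examples
--------
A usage sketch (the values are illustrative):
>>> ou = ou_process(mean=1., sigma=0.1, tau=10., duration=100., dt=0.1, n=3)
>>> ou.shape
(1000, 3)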
""" | |||
dt = bm.get_dt() if dt is None else dt | |||
dt_sqrt = bm.sqrt(dt) | |||
check_float(dt, 'dt', allow_none=False, min_bound=0.) | |||
check_integer(n, 'n', allow_none=False, min_bound=0) | |||
rng = bm.random.RandomState(seed) | |||
x = bm.Variable(bm.ones(n) * mean) | |||
def _f(t):
# Euler-Maruyama update of the OU process
x.value = x + dt * ((mean - x) / tau) + sigma * dt_sqrt * rng.standard_normal(n)
f = bm.make_loop(_f, dyn_vars=[x, rng], out_vars=x) | |||
t_end = duration if t_end is None else t_end
noises = f(bm.arange(t_start, t_end, dt))
i_start = int(t_start / dt) | |||
i_end = int(t_end / dt) | |||
currents = bm.zeros((int(duration / dt), n)) | |||
currents[i_start: i_end] = noises | |||
return currents | |||
def sinusoidal_input(amplitude, frequency, duration, dt=None, t_start=0., t_end=None, dc_bias=False): | |||
"""Sinusoidal input. | |||
Parameters | |||
---------- | |||
amplitude: float | |||
Amplitude of the sinusoid. | |||
frequency: float | |||
Frequency of the sine oscillation, in Hz.
duration: float | |||
The input duration. | |||
t_start: float | |||
The start time. | |||
t_end: float | |||
The end time. | |||
dt: float | |||
The numerical precision. | |||
dc_bias: bool | |||
Whether the sinusoid oscillates around 0 (False), or | |||
has a positive DC bias, thus non-negative (True). | |||
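Examples
--------
A usage sketch (the values are illustrative):
>>> inp = sinusoidal_input(amplitude=1., frequency=10., duration=100., dt=0.1)
>>> inp.shape
(1000,)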
""" | |||
dt = bm.get_dt() if dt is None else dt | |||
check_float(dt, 'dt', allow_none=False, min_bound=0.) | |||
if t_end is None: | |||
t_end = duration | |||
times = bm.arange(0, t_end-t_start, dt) | |||
start_i = int(t_start/dt) | |||
end_i = int(t_end/dt) | |||
sin_inputs = amplitude * bm.sin(2 * bm.pi * times * (frequency / 1000.0)) | |||
if dc_bias: | |||
sin_inputs += amplitude | |||
currents = bm.zeros(int(duration / dt)) | |||
currents[start_i:end_i] = sin_inputs | |||
return currents | |||
def _square(t, duty=0.5): | |||
t, w = np.asarray(t), np.asarray(duty) | |||
w = np.asarray(w + (t - t)) | |||
t = np.asarray(t + (w - w)) | |||
if t.dtype.char in 'fFdD':
ytype = t.dtype.char | |||
else: | |||
ytype = 'd' | |||
y = np.zeros(t.shape, ytype) | |||
# width must be between 0 and 1 inclusive | |||
mask1 = (w > 1) | (w < 0) | |||
np.place(y, mask1, np.nan) | |||
# on the interval 0 to duty*2*pi function is 1 | |||
tmod = np.mod(t, 2 * np.pi) | |||
mask2 = (1 - mask1) & (tmod < w * 2 * np.pi) | |||
np.place(y, mask2, 1) | |||
# on the interval duty*2*pi to 2*pi the function is -1
mask3 = (1 - mask1) & (1 - mask2)
np.place(y, mask3, -1)
return y | |||
def square_input(amplitude, frequency, duration, dt=None, dc_bias=False, t_start=0., t_end=None):
"""Oscillatory square input. | |||
Parameters | |||
---------- | |||
amplitude: float | |||
Amplitude of the square oscillation. | |||
frequency: float | |||
Frequency of the square oscillation, in Hz. | |||
duration: float | |||
The input duration. | |||
t_start: float | |||
The start time. | |||
t_end: float | |||
The end time. | |||
dt: float | |||
The numerical precision. | |||
dc_bias: bool
Whether the square wave oscillates around 0 (False), or
has a positive DC bias, thus non-negative (True).
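Examples
--------
A usage sketch (the values are illustrative):
>>> inp = square_input(amplitude=1., frequency=5., duration=200., dt=0.1)
>>> inp.shape
(2000,)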
""" | |||
dt = bm.get_dt() if dt is None else dt | |||
check_float(dt, 'dt', allow_none=False, min_bound=0.) | |||
if t_end is None: | |||
t_end = duration | |||
times = bm.arange(0, t_end - t_start, dt) | |||
currents = bm.zeros(int(duration / dt)) | |||
start_i = int(t_start/dt) | |||
end_i = int(t_end/dt) | |||
square_inputs = amplitude * _square(2 * bm.pi * times * (frequency / 1000.0))
if dc_bias:
square_inputs += amplitude
currents[start_i:end_i] = square_inputs
return currents | |||
@@ -7,6 +7,7 @@ including: | |||
- ordinary differential equations (ODEs) | |||
- stochastic differential equations (SDEs) | |||
- delay differential equations (DDEs) | |||
- fractional differential equations (FDEs) | |||
For details, please see the following.
""" | |||
@@ -41,6 +42,14 @@ from .dde.generic import (ddeint, | |||
set_default_ddeint, | |||
register_dde_integrator) | |||
# others | |||
# FDE tools | |||
from . import fde | |||
from .fde.base import FDEIntegrator | |||
from .fde.generic import (fdeint, | |||
get_default_fdeint, | |||
set_default_fdeint, | |||
register_fde_integrator) | |||
# PDE tools | |||
from . import pde |
@@ -40,6 +40,7 @@ class Integrator(AbstractIntegrator): | |||
@property | |||
def dt(self): | |||
"""The numerical integration precision.""" | |||
return self._dt | |||
@dt.setter | |||
@@ -48,6 +49,7 @@ class Integrator(AbstractIntegrator): | |||
@property | |||
def variables(self): | |||
"""The variables defined in the differential equation.""" | |||
return self._variables | |||
@variables.setter | |||
@@ -56,6 +58,7 @@ class Integrator(AbstractIntegrator): | |||
@property | |||
def parameters(self): | |||
"""The parameters defined in the differential equation.""" | |||
return self._parameters | |||
@parameters.setter | |||
@@ -64,6 +67,7 @@ class Integrator(AbstractIntegrator): | |||
@property | |||
def arguments(self): | |||
"""All arguments when calling the numer integrator of the differential equation.""" | |||
return self._arguments | |||
@arguments.setter | |||
@@ -72,6 +76,7 @@ class Integrator(AbstractIntegrator): | |||
@property | |||
def integral(self): | |||
"""The integral function.""" | |||
return self._integral | |||
@integral.setter | |||
@@ -79,6 +84,7 @@ class Integrator(AbstractIntegrator): | |||
self.set_integral(f) | |||
def set_integral(self, f): | |||
"""Set the integral function.""" | |||
if not callable(f): | |||
raise ValueError(f'integral function must be a callable function, ' | |||
f'but we got {type(f)}: {f}') | |||
@@ -25,7 +25,7 @@ class DDEIntegrator(Integrator): | |||
dt: Union[float, int] = None, | |||
name: str = None, | |||
show_code: bool = False, | |||
state_delays: Dict[str, bm.FixedLenDelay] = None, | |||
state_delays: Dict[str, bm.TimeDelay] = None, | |||
neutral_delays: Dict[str, bm.NeutralDelay] = None, | |||
): | |||
dt = bm.get_dt() if dt is None else dt | |||
@@ -59,7 +59,9 @@ class DDEIntegrator(Integrator): | |||
# delays | |||
self._state_delays = dict() | |||
if state_delays is not None: | |||
check_dict_data(state_delays, key_type=str, val_type=bm.FixedLenDelay) | |||
check_dict_data(state_delays, | |||
key_type=str, | |||
val_type=(bm.TimeDelay, bm.LengthDelay)) | |||
for key, delay in state_delays.items(): | |||
if key not in self.variables: | |||
raise DiffEqError(f'"{key}" is not defined in the variables: {self.variables}') | |||
@@ -67,7 +69,9 @@ class DDEIntegrator(Integrator): | |||
self.register_implicit_nodes(self._state_delays) | |||
self._neutral_delays = dict() | |||
if neutral_delays is not None: | |||
check_dict_data(neutral_delays, key_type=str, val_type=bm.NeutralDelay) | |||
check_dict_data(neutral_delays, | |||
key_type=str, | |||
val_type=bm.NeutralDelay) | |||
for key, delay in neutral_delays.items(): | |||
if key not in self.variables: | |||
raise DiffEqError(f'"{key}" is not defined in the variables: {self.variables}') | |||
@@ -111,11 +115,19 @@ class DDEIntegrator(Integrator): | |||
else: | |||
new_dvars = {k: new_dvars[i] for i, k in enumerate(self.variables)} | |||
    for key, delay in self.neutral_delays.items():
      if isinstance(delay, bm.LengthDelay):
        delay.update(new_dvars[key])
      elif isinstance(delay, bm.TimeDelay):
        delay.update(kwargs['t'] + dt, new_dvars[key])
      else:
        raise ValueError('Unknown delay variable. Expect an instance of '
                         'bm.LengthDelay or bm.TimeDelay.')
# update state delay variables | |||
    for key, delay in self.state_delays.items():
      if isinstance(delay, bm.LengthDelay):
        delay.update(dict_vars[key])
      elif isinstance(delay, bm.TimeDelay):
        delay.update(kwargs['t'] + dt, dict_vars[key])
      else:
        raise ValueError('Unknown delay variable. Expect an instance of '
                         'bm.LengthDelay or bm.TimeDelay.')
return new_vars | |||
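# Note (grounded in the branches above): the two delay containers have
# different update signatures, which the isinstance dispatch selects between:
#   bm.LengthDelay.update(value)     # index-based rotation, no time stamp
#   bm.TimeDelay.update(t, value)    # time-stamped update at time t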
@@ -4,6 +4,7 @@ from brainpy.integrators.constants import F, DT | |||
from brainpy.integrators.dde.base import DDEIntegrator | |||
from brainpy.integrators.ode import common | |||
from brainpy.integrators.utils import compile_code, check_kws | |||
from brainpy.integrators.dde.generic import register_dde_integrator | |||
__all__ = [ | |||
'ExplicitRKIntegrator', | |||
@@ -47,8 +48,6 @@ class ExplicitRKIntegrator(DDEIntegrator): | |||
def integral(*vars, **kwargs): | |||
pass | |||
self.build() | |||
def build(self): | |||
@@ -72,24 +71,36 @@ class Euler(ExplicitRKIntegrator): | |||
C = [0] | |||
register_dde_integrator('euler', Euler) | |||
class MidPoint(ExplicitRKIntegrator): | |||
A = [(), (0.5,)] | |||
B = [0, 1] | |||
C = [0, 0.5] | |||
register_dde_integrator('midpoint', MidPoint) | |||
class Heun2(ExplicitRKIntegrator): | |||
A = [(), (1,)] | |||
B = [0.5, 0.5] | |||
C = [0, 1] | |||
register_dde_integrator('heun2', Heun2) | |||
class Ralston2(ExplicitRKIntegrator): | |||
A = [(), ('2/3',)] | |||
B = [0.25, 0.75] | |||
C = [0, '2/3'] | |||
register_dde_integrator('ralston2', Ralston2) | |||
class RK2(ExplicitRKIntegrator): | |||
def __init__(self, f, beta=2 / 3, var_type=None, dt=None, name=None, show_code=False): | |||
self.A = [(), (beta,)] | |||
@@ -98,43 +109,67 @@ class RK2(ExplicitRKIntegrator): | |||
super(RK2, self).__init__(f=f, var_type=var_type, dt=dt, name=name, show_code=show_code) | |||
register_dde_integrator('rk2', RK2) | |||
class RK3(ExplicitRKIntegrator): | |||
A = [(), (0.5,), (-1, 2)] | |||
B = ['1/6', '2/3', '1/6'] | |||
C = [0, 0.5, 1] | |||
register_dde_integrator('rk3', RK3) | |||
class Heun3(ExplicitRKIntegrator): | |||
A = [(), ('1/3',), (0, '2/3')] | |||
B = [0.25, 0, 0.75] | |||
C = [0, '1/3', '2/3'] | |||
register_dde_integrator('heun3', Heun3) | |||
class Ralston3(ExplicitRKIntegrator): | |||
A = [(), (0.5,), (0, 0.75)] | |||
B = ['2/9', '1/3', '4/9'] | |||
C = [0, 0.5, 0.75] | |||
register_dde_integrator('ralston3', Ralston3) | |||
class SSPRK3(ExplicitRKIntegrator): | |||
A = [(), (1,), (0.25, 0.25)] | |||
B = ['1/6', '1/6', '2/3'] | |||
C = [0, 1, 0.5] | |||
register_dde_integrator('ssprk3', SSPRK3) | |||
class RK4(ExplicitRKIntegrator): | |||
A = [(), (0.5,), (0., 0.5), (0., 0., 1)] | |||
B = ['1/6', '1/3', '1/3', '1/6'] | |||
C = [0, 0.5, 0.5, 1] | |||
register_dde_integrator('rk4', RK4) | |||
class Ralston4(ExplicitRKIntegrator): | |||
A = [(), (.4,), (.29697761, .15875964), (.21810040, -3.05096516, 3.83286476)] | |||
B = [.17476028, -.55148066, 1.20553560, .17118478] | |||
C = [0, .4, .45573725, 1] | |||
register_dde_integrator('ralston4', Ralston4) | |||
class RK4Rule38(ExplicitRKIntegrator): | |||
A = [(), ('1/3',), ('-1/3', '1'), (1, -1, 1)] | |||
B = [0.125, 0.375, 0.375, 0.125] | |||
C = [0, '1/3', '2/3', 1] | |||
register_dde_integrator('rk4_38rule', RK4Rule38) |
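# Sketch of extending this registry with a custom Butcher tableau
# (hypothetical class name, following the pattern of the classes above):
#
#   class MyRK2(ExplicitRKIntegrator):
#     A = [(), (0.5,)]   # lower-triangular Butcher matrix, one row per stage
#     B = [0., 1.]       # solution weights
#     C = [0, 0.5]       # stage nodes
#
#   register_dde_integrator('my_rk2', MyRK2)
#
# Once registered, the method should be reachable via
# bp.ddeint(f, method='my_rk2').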
@@ -1,7 +1,6 @@ | |||
# -*- coding: utf-8 -*- | |||
from .base import DDEIntegrator | |||
from .explicit_rk import * | |||
__all__ = [ | |||
'ddeint', | |||
@@ -12,19 +11,6 @@ __all__ = [ | |||
] | |||
name2method = { | |||
# explicit RK | |||
'euler': Euler, 'Euler': Euler, | |||
'midpoint': MidPoint, 'MidPoint': MidPoint, | |||
'heun2': Heun2, 'Heun2': Heun2, | |||
'ralston2': Ralston2, 'Ralston2': Ralston2, | |||
'rk2': RK2, 'RK2': RK2, | |||
'rk3': RK3, 'RK3': RK3, | |||
'heun3': Heun3, 'Heun3': Heun3, | |||
'ralston3': Ralston3, 'Ralston3': Ralston3, | |||
'ssprk3': SSPRK3, 'SSPRK3': SSPRK3, | |||
'rk4': RK4, 'RK4': RK4, | |||
'ralston4': Ralston4, 'Ralston4': Ralston4, | |||
'rk4_38rule': RK4Rule38, 'RK4Rule38': RK4Rule38, | |||
} | |||
@@ -132,7 +118,7 @@ def register_dde_integrator(name, integrator): | |||
""" | |||
if name in name2method: | |||
raise ValueError(f'"{name}" has been registered in DDE integrators.') | |||
if DDEIntegrator not in integrator.__bases__: | |||
if not issubclass(integrator, DDEIntegrator): | |||
raise ValueError(f'"integrator" must be an instance of {DDEIntegrator.__name__}') | |||
name2method[name] = integrator | |||
@@ -0,0 +1,401 @@ | |||
# -*- coding: utf-8 -*- | |||
""" | |||
This module provides numerical methods for integrating Caputo fractional derivative equations. | |||
""" | |||
import jax.numpy as jnp | |||
from jax.experimental.host_callback import id_tap | |||
from brainpy import check | |||
import brainpy.math as bm | |||
from brainpy.errors import UnsupportedError | |||
from brainpy.integrators.constants import DT | |||
from brainpy.integrators.utils import check_inits, format_args | |||
from brainpy.tools.checking import check_integer | |||
from .base import FDEIntegrator | |||
from .generic import register_fde_integrator | |||
__all__ = [ | |||
'CaputoEuler', | |||
'CaputoL1Schema', | |||
] | |||
class CaputoEuler(FDEIntegrator): | |||
r"""One-step Euler method for Caputo fractional differential equations. | |||
Given a fractional initial value problem, | |||
.. math:: | |||
D_{*}^{\alpha} y(t)=f(t, y(t)), \quad y^{(k)}(0)=y_{0}^{(k)}, \quad k=0,1, \ldots,\lceil\alpha\rceil-1 | |||
  where the :math:`y_0^{(k)}` may be arbitrary real numbers and where :math:`\alpha>0`.
:math:`D_{*}^{\alpha}` denotes the differential operator in the sense of Caputo, defined | |||
by | |||
.. math:: | |||
D_{*}^{\alpha} z(t)=J^{n-\alpha} D^{n} z(t) | |||
  where :math:`n:=\lceil\alpha\rceil` is the smallest integer :math:`\geqslant \alpha`.
  Here :math:`D^n` is the usual differential operator of (integer) order :math:`n`,
and for :math:`\mu > 0`, :math:`J^{\mu}` is the Riemann–Liouville integral operator | |||
of order :math:`\mu`, defined by | |||
.. math:: | |||
J^{\mu} z(t)=\frac{1}{\Gamma(\mu)} \int_{0}^{t}(t-u)^{\mu-1} z(u) \mathrm{d} u | |||
The one-step Euler method for fractional differential equation is defined as | |||
.. math:: | |||
y_{k+1} = y_0 + \frac{1}{\Gamma(\alpha)} \sum_{j=0}^{k} b_{j, k+1} f\left(t_{j}, y_{j}\right). | |||
where | |||
.. math:: | |||
b_{j, k+1}=\frac{h^{\alpha}}{\alpha}\left((k+1-j)^{\alpha}-(k-j)^{\alpha}\right). | |||
Examples | |||
-------- | |||
>>> import brainpy as bp | |||
>>> | |||
>>> a, b, c = 10, 28, 8 / 3 | |||
>>> def lorenz(x, y, z, t): | |||
>>> dx = a * (y - x) | |||
>>> dy = x * (b - z) - y | |||
>>> dz = x * y - c * z | |||
>>> return dx, dy, dz | |||
>>> | |||
>>> duration = 30. | |||
>>> dt = 0.005 | |||
>>> inits = [1., 0., 1.] | |||
>>> f = bp.fde.CaputoEuler(lorenz, alpha=0.97, num_step=int(duration / dt), inits=inits) | |||
>>> runner = bp.integrators.IntegratorRunner(f, monitors=list('xyz'), dt=dt, inits=inits) | |||
>>> runner.run(duration) | |||
>>> | |||
>>> import matplotlib.pyplot as plt | |||
>>> plt.plot(runner.mon.x.flatten(), runner.mon.z.flatten()) | |||
>>> plt.show() | |||
Parameters | |||
---------- | |||
f : callable | |||
The derivative function. | |||
alpha: int, float, jnp.ndarray, bm.ndarray, sequence | |||
The fractional-order of the derivative function. Should be in the range of ``(0., 1.)``. | |||
num_step: int | |||
The total time step of the simulation. | |||
inits: sequence | |||
A sequence of the initial values for variables. | |||
dt: float, int | |||
The numerical precision. | |||
name: str | |||
The integrator name. | |||
References | |||
---------- | |||
.. [1] Li, Changpin, and Fanhai Zeng. "The finite difference methods for fractional | |||
ordinary differential equations." Numerical Functional Analysis and | |||
Optimization 34.2 (2013): 149-179. | |||
.. [2] Diethelm, Kai, Neville J. Ford, and Alan D. Freed. "Detailed error analysis | |||
for a fractional Adams method." Numerical algorithms 36.1 (2004): 31-52. | |||
""" | |||
def __init__(self, f, alpha, num_step, inits, dt=None, name=None): | |||
super(CaputoEuler, self).__init__(f=f, alpha=alpha, dt=dt, name=name) | |||
# fractional order | |||
if not jnp.all(jnp.logical_and(self.alpha < 1, self.alpha > 0)): | |||
raise UnsupportedError(f'Only support the fractional order in (0, 1), ' | |||
f'but we got {self.alpha}.') | |||
# memory length | |||
check_integer(num_step, 'num_step', min_bound=1, allow_none=False) | |||
self.num_step = num_step | |||
# initial values | |||
self.inits = check_inits(inits, self.variables) | |||
# coefficients | |||
from scipy.special import rgamma | |||
rgamma_alpha = bm.asarray(rgamma(bm.as_numpy(self.alpha))) | |||
ranges = bm.asarray([bm.arange(num_step + 1) for _ in self.variables]).T | |||
coef = rgamma_alpha * bm.diff(bm.power(ranges, self.alpha), axis=0) | |||
self.coef = bm.flip(coef, axis=0) | |||
# variable states | |||
self.f_states = {v: bm.Variable(bm.zeros((num_step,) + self.inits[v].shape)) | |||
for v in self.variables} | |||
self.register_implicit_vars(self.f_states) | |||
self.idx = bm.Variable(bm.asarray([1], dtype=bm.int32)) | |||
self.set_integral(self._integral_func) | |||
def _check_step(self, args, transform): | |||
dt, t = args | |||
if self.num_step * dt < t: | |||
      raise ValueError(f'The maximum number of steps is {self.num_step}, '
                       f'but the current time {t} requires a step '
                       f'number of {t / dt}.')
def _integral_func(self, *args, **kwargs): | |||
# format arguments | |||
all_args = format_args(args, kwargs, self.arguments) | |||
dt = all_args.pop(DT, self.dt) | |||
if check.is_checking(): | |||
id_tap(self._check_step, (dt, all_args['t'])) | |||
# derivative values | |||
devs = self.f(**all_args) | |||
if len(self.variables) == 1: | |||
if not isinstance(devs, (bm.ndarray, jnp.ndarray)): | |||
raise ValueError('Derivative values must be a tensor when there ' | |||
'is only one variable in the equation.') | |||
devs = {self.variables[0]: devs} | |||
else: | |||
if not isinstance(devs, (tuple, list)): | |||
raise ValueError('Derivative values must be a list/tuple of tensors ' | |||
'when there are multiple variables in the equation.') | |||
devs = {var: devs[i] for i, var in enumerate(self.variables)} | |||
# function states | |||
for key in self.variables: | |||
self.f_states[key][self.idx[0]] = devs[key] | |||
# integral results | |||
integrals = [] | |||
idx = ((self.num_step - 1 - self.idx) + bm.arange(self.num_step)) % self.num_step | |||
for i, key in enumerate(self.variables): | |||
      integral = self.inits[key] + (dt ** self.alpha[i] / self.alpha[i]) * (self.coef[idx, i] @ self.f_states[key])
      integrals.append(integral)
self.idx.value = (self.idx + 1) % self.num_step | |||
# return integrals | |||
if len(self.variables) == 1: | |||
return integrals[0] | |||
else: | |||
return integrals | |||
register_fde_integrator(name='CaputoEuler', integrator=CaputoEuler) | |||
class CaputoABM(FDEIntegrator):
  """Adams-Bashforth-Moulton (ABM) method for Caputo fractional differential equations.
  Not yet implemented; this class is a placeholder.
  """
  pass
class CaputoL1Schema(FDEIntegrator): | |||
r"""The L1 scheme method for the numerical approximation of the Caputo | |||
fractional-order derivative equations [3]_. | |||
For the fractional order :math:`0<\alpha<1`, let the fractional derivative of variable | |||
:math:`x(t)` be | |||
.. math:: | |||
\frac{d^{\alpha} x}{d t^{\alpha}}=F(x, t) | |||
The Caputo definition of the fractional derivative for variable :math:`x` is | |||
.. math:: | |||
\frac{d^{\alpha} x}{d t^{\alpha}}=\frac{1}{\Gamma(1-\alpha)} \int_{0}^{t} \frac{x^{\prime}(u)}{(t-u)^{\alpha}} d u | |||
where :math:`\Gamma` is the Gamma function. | |||
The fractional-order derivative is capable of integrating the activity of the | |||
function over all past activities weighted by a function that follows a power-law. | |||
Using one of the numerical methods, the L1 scheme method [3]_, the numerical | |||
approximation of the fractional-order derivative of :math:`x` is | |||
.. math:: | |||
\frac{d^{\alpha} \chi}{d t^{\alpha}} \approx \frac{(d t)^{-\alpha}}{\Gamma(2-\alpha)}\left[\sum_{k=0}^{N-1}\left[x\left(t_{k+1}\right)- | |||
\mathrm{x}\left(t_{k}\right)\right]\left[(N-k)^{1-\alpha}-(N-1-k)^{1-\alpha}\right]\right] | |||
Therefore, the numerical solution of original system is given by | |||
.. math:: | |||
x\left(t_{N}\right) \approx d t^{\alpha} \Gamma(2-\alpha) F(x, t)+x\left(t_{N-1}\right)- | |||
\left[\sum_{k=0}^{N-2}\left[x\left(t_{k+1}\right)-x\left(t_{k}\right)\right]\left[(N-k)^{1-\alpha}-(N-1-k)^{1-\alpha}\right]\right] | |||
Hence, the solution of the fractional-order derivative can be described as the | |||
difference between the *Markov term* and the *memory trace*. The *Markov term* | |||
weighted by the gamma function is | |||
.. math:: | |||
\text { Markov term }=d t^{\alpha} \Gamma(2-\alpha) F(x, t)+x\left(t_{N-1}\right) | |||
The memory trace (:math:`x`-memory trace since it is related to variable :math:`x`) is | |||
.. math:: | |||
\text { Memory trace }=\sum_{k=0}^{N-2}\left[x\left(t_{k+1}\right)-x\left(t_{k}\right)\right]\left[(N-k)^{1-\alpha}-(N-(k+1))^{1-\alpha}\right] | |||
The memory trace integrates all the past activity and captures the long-term | |||
history of the system. For :math:`\alpha=1`, the memory trace is 0 for any | |||
time :math:`t`. When the fractional order :math:`\alpha` is decreased from 1, | |||
the memory trace non-linearly increases from 0, and its dynamics strongly | |||
depends on time. Thus, the fractional order dynamics strongly deviates | |||
from the first order dynamics. | |||
Examples | |||
-------- | |||
>>> import brainpy as bp | |||
>>> | |||
>>> a, b, c = 10, 28, 8 / 3 | |||
>>> def lorenz(x, y, z, t): | |||
>>> dx = a * (y - x) | |||
>>> dy = x * (b - z) - y | |||
>>> dz = x * y - c * z | |||
>>> return dx, dy, dz | |||
>>> | |||
>>> duration = 30. | |||
>>> dt = 0.005 | |||
>>> inits = [1., 0., 1.] | |||
>>> f = bp.fde.CaputoL1Schema(lorenz, alpha=0.99, num_step=int(duration / dt), inits=inits) | |||
>>> runner = bp.integrators.IntegratorRunner(f, monitors=list('xz'), dt=dt, inits=inits) | |||
>>> runner.run(duration) | |||
>>> | |||
>>> import matplotlib.pyplot as plt | |||
>>> plt.plot(runner.mon.x.flatten(), runner.mon.z.flatten()) | |||
>>> plt.show() | |||
Parameters | |||
---------- | |||
f : callable | |||
The derivative function. | |||
alpha: int, float, jnp.ndarray, bm.ndarray, sequence | |||
The fractional-order of the derivative function. Should be in the range of ``(0., 1.]``. | |||
num_step: int | |||
The total time step of the simulation. | |||
inits: sequence | |||
A sequence of the initial values for variables. | |||
dt: float, int | |||
The numerical precision. | |||
name: str | |||
The integrator name. | |||
References | |||
---------- | |||
.. [3] Oldham, K., & Spanier, J. (1974). The fractional calculus theory | |||
and applications of differentiation and integration to arbitrary | |||
order. Elsevier. | |||
""" | |||
def __init__(self, f, alpha, num_step, inits, dt=None, name=None): | |||
super(CaputoL1Schema, self).__init__(f=f, alpha=alpha, dt=dt, name=name) | |||
# fractional order | |||
    if not jnp.all(jnp.logical_and(self.alpha <= 1, self.alpha > 0)):
      raise UnsupportedError(f'Only support the fractional order in (0, 1], '
                             f'but we got {self.alpha}.')
from scipy.special import gamma | |||
self.gamma_alpha = bm.asarray(gamma(bm.as_numpy(2 - self.alpha))) | |||
# memory length | |||
check_integer(num_step, 'num_step', min_bound=1, allow_none=False) | |||
self.num_step = num_step | |||
# initial values | |||
inits = check_inits(inits, self.variables) | |||
self.inits = {v: bm.Variable(inits[v]) for v in self.variables} | |||
self.register_implicit_vars(self.inits) | |||
# coefficients | |||
ranges = bm.asarray([bm.arange(1, num_step + 2) for _ in self.variables]).T | |||
coef = bm.diff(bm.power(ranges, 1 - self.alpha), axis=0) | |||
self.coef = bm.flip(coef, axis=0) | |||
# variable states | |||
self.diff_states = {v + "_diff": bm.Variable(bm.zeros((num_step,) + self.inits[v].shape)) | |||
for v in self.variables} | |||
self.register_implicit_vars(self.diff_states) | |||
self.idx = bm.Variable(bm.asarray([self.num_step - 1], dtype=bm.int32)) | |||
# integral function | |||
self.set_integral(self._integral_func) | |||
def hists(self, var=None, numpy=True): | |||
if var is None: | |||
hists_ = {k: bm.vstack([self.inits[k], self.diff_states[k + '_diff']]) | |||
for k in self.variables} | |||
hists_ = {k: bm.cumsum(v, axis=0) for k, v in hists_.items()} | |||
if numpy: | |||
        hists_ = {k: v.numpy() for k, v in hists_.items()}
return hists_ | |||
else: | |||
assert var in self.variables, (f'"{var}" is not defined in equation ' | |||
f'variables: {self.variables}') | |||
hists_ = bm.vstack([self.inits[var], self.diff_states[var + '_diff']]) | |||
hists_ = bm.cumsum(hists_, axis=0) | |||
if numpy: | |||
hists_ = hists_.numpy() | |||
return hists_ | |||
def _check_step(self, args, transform): | |||
dt, t = args | |||
if self.num_step * dt < t: | |||
      raise ValueError(f'The maximum number of steps is {self.num_step}, '
                       f'but the current time {t} requires a step '
                       f'number of {t / dt}.')
def _integral_func(self, *args, **kwargs): | |||
# format arguments | |||
all_args = format_args(args, kwargs, self.arguments) | |||
dt = all_args.pop(DT, self.dt) | |||
if check.is_checking(): | |||
id_tap(self._check_step, (dt, all_args['t'])) | |||
# derivative values | |||
devs = self.f(**all_args) | |||
if len(self.variables) == 1: | |||
if not isinstance(devs, (bm.ndarray, jnp.ndarray)): | |||
raise ValueError('Derivative values must be a tensor when there ' | |||
'is only one variable in the equation.') | |||
devs = {self.variables[0]: devs} | |||
else: | |||
if not isinstance(devs, (tuple, list)): | |||
raise ValueError('Derivative values must be a list/tuple of tensors ' | |||
'when there are multiple variables in the equation.') | |||
devs = {var: devs[i] for i, var in enumerate(self.variables)} | |||
# integral results | |||
integrals = [] | |||
idx = ((self.num_step - 1 - self.idx) + bm.arange(self.num_step)) % self.num_step | |||
for i, key in enumerate(self.variables): | |||
self.diff_states[key + '_diff'][self.idx[0]] = all_args[key] - self.inits[key] | |||
self.inits[key].value = all_args[key] | |||
markov_term = dt ** self.alpha[i] * self.gamma_alpha[i] * devs[key] + all_args[key] | |||
memory_trace = self.coef[idx, i] @ self.diff_states[key + '_diff'] | |||
integral = markov_term - memory_trace | |||
integrals.append(integral) | |||
self.idx.value = (self.idx + 1) % self.num_step | |||
# return integrals | |||
if len(self.variables) == 1: | |||
return integrals[0] | |||
else: | |||
return integrals | |||
register_fde_integrator(name='CaputoL1', integrator=CaputoL1Schema) | |||
register_fde_integrator(name='CaputoL1Schema', integrator=CaputoL1Schema) |
@@ -0,0 +1,190 @@ | |||
# -*- coding: utf-8 -*- | |||
""" | |||
This module provides numerical solvers for Grünwald–Letnikov derivative FDEs. | |||
""" | |||
import jax.numpy as jnp | |||
import brainpy.math as bm | |||
from brainpy.errors import UnsupportedError | |||
from brainpy.integrators.constants import DT | |||
from brainpy.tools.checking import check_integer | |||
from .base import FDEIntegrator | |||
from brainpy.integrators.utils import check_inits, format_args | |||
__all__ = [ | |||
'GLShortMemory' | |||
] | |||
class GLShortMemory(FDEIntegrator): | |||
r"""Efficient Computation of the Short-Memory Principle in Grünwald-Letnikov Method [1]_. | |||
According to the explicit numerical approximation of Grünwald-Letnikov, the | |||
fractional-order derivative :math:`q` for a discrete function :math:`f(t_K)` | |||
can be described as follows: | |||
.. math:: | |||
{{}_{k-\frac{L_{m}}{h}}D_{t_{k}}^{q}}f(t_{k})\approx h^{-q} | |||
\sum\limits_{j=0}^{k}C_{j}^{q}f(t_{k-j}) | |||
  where :math:`L_{m}` is the memory length, :math:`h` is the integration step size,
and :math:`C_{j}^{q}` are the binomial coefficients which are calculated recursively with | |||
.. math:: | |||
C_{0}^{q}=1,\ C_{j}^{q}=\left(1- \frac{1+q}{j}\right)C_{j-1}^{q},\ j=1,2, \ldots k. | |||
Then, the numerical solution for a fractional-order differential equation (FODE) expressed | |||
in the form | |||
.. math:: | |||
D_{t_{k}}^{q}x(t_{k})=f(x(t_{k})) | |||
can be obtained by | |||
.. math:: | |||
x(t_{k})=f(x(t_{k-1}))h^{q}- \sum\limits_{j=1}^{k}C_{j}^{q}x(t_{k-j}). | |||
  for :math:`0 < q < 1`. The above expression requires an infinite memory length
  for the numerical solution, since the summation term depends on the discretized
  time :math:`t_k`. This implies relatively long simulation times.
To reduce the computational time, the upper bound of summation needs to be modified by | |||
:math:`k=v`, where | |||
.. math:: | |||
v=\begin{cases} k, & k\leq M,\\ L_{m}, & k > M. \end{cases} | |||
This is known as the short-memory principle, where :math:`M` | |||
is the memory window with a width defined by :math:`M=\frac{L_{m}}{h}`. | |||
  As reported in [2]_, the accuracy increases with the width of the memory window.
Examples | |||
-------- | |||
>>> import brainpy as bp | |||
>>> | |||
>>> a, b, c = 10, 28, 8 / 3 | |||
>>> def lorenz(x, y, z, t): | |||
>>> dx = a * (y - x) | |||
>>> dy = x * (b - z) - y | |||
>>> dz = x * y - c * z | |||
>>> return dx, dy, dz | |||
>>> | |||
>>> integral = bp.fde.GLShortMemory(lorenz, | |||
>>> alpha=0.96, | |||
>>> num_memory=500, | |||
>>> inits=[1., 0., 1.]) | |||
>>> runner = bp.integrators.IntegratorRunner(integral, | |||
>>> monitors=list('xyz'), | |||
>>> inits=[1., 0., 1.], | |||
>>> dt=0.005) | |||
>>> runner.run(100.) | |||
>>> | |||
>>> import matplotlib.pyplot as plt | |||
>>> plt.plot(runner.mon.x.flatten(), runner.mon.z.flatten()) | |||
>>> plt.show() | |||
Parameters | |||
---------- | |||
f : callable | |||
The derivative function. | |||
alpha: int, float, jnp.ndarray, bm.ndarray, sequence | |||
The fractional-order of the derivative function. Should be in the range of ``(0., 1.)``. | |||
num_memory: int | |||
The length of the short memory. | |||
inits: sequence | |||
A sequence of the initial values for variables. | |||
dt: float, int | |||
The numerical precision. | |||
name: str | |||
The integrator name. | |||
References | |||
---------- | |||
.. [1] Clemente-López, D., et al. "Efficient computation of the | |||
Grünwald-Letnikov method for arm-based implementations of | |||
fractional-order chaotic systems." 2019 8th International | |||
Conference on Modern Circuits and Systems Technologies (MOCAST). IEEE, 2019. | |||
.. [2] M. F. Tolba, A. M. AbdelAty, N. S. Soliman, L. A. Said, A. H. | |||
Madian, A. T. Azar, et al., "FPGA implementation of two fractional | |||
order chaotic systems", International Journal of Electronics and | |||
Communications, vol. 78, pp. 162-172, 2017. | |||
""" | |||
def __init__(self, f, alpha, num_memory, inits, dt=None, name=None): | |||
super(GLShortMemory, self).__init__(f=f, alpha=alpha, dt=dt, name=name) | |||
# fractional order | |||
    if not jnp.all(jnp.logical_and(self.alpha <= 1, self.alpha > 0)):
      raise UnsupportedError(f'Only support the fractional order in (0, 1], '
                             f'but we got {self.alpha}.')
# memory length | |||
check_integer(num_memory, 'num_memory', min_bound=1, allow_none=False) | |||
self.num_memory = num_memory | |||
# initial values | |||
self.inits = check_inits(inits, self.variables) | |||
# delays | |||
self.delays = {} | |||
for key, val in self.inits.items(): | |||
delay = bm.Variable(bm.zeros((self.num_memory,) + val.shape, dtype=bm.float_)) | |||
delay[0] = val | |||
self.delays[key] = delay | |||
self._idx = bm.Variable(bm.asarray([1], dtype=bm.int32)) | |||
self.register_implicit_vars(self.delays) | |||
# binomial coefficients | |||
bc = (1 - (1 + self.alpha.reshape((-1, 1))) / jnp.arange(1, num_memory + 1)) | |||
bc = jnp.cumprod(jnp.vstack([jnp.ones_like(self.alpha), bc.T]), axis=0) | |||
self._binomial_coef = jnp.flip(bc[1:], axis=0) | |||
# integral function | |||
self.set_integral(self._integral_func) | |||
@property | |||
def binomial_coef(self): | |||
return bm.as_numpy(jnp.flip(self._binomial_coef, axis=0)) | |||
def _integral_func(self, *args, **kwargs): | |||
# format arguments | |||
all_args = format_args(args, kwargs, self.arguments) | |||
dt = all_args.pop(DT, self.dt) | |||
# derivative values | |||
devs = self.f(**all_args) | |||
if len(self.variables) == 1: | |||
if not isinstance(devs, (bm.ndarray, jnp.ndarray)): | |||
raise ValueError('Derivative values must be a tensor when there ' | |||
'is only one variable in the equation.') | |||
devs = {self.variables[0]: devs} | |||
else: | |||
if not isinstance(devs, (tuple, list)): | |||
raise ValueError('Derivative values must be a list/tuple of tensors ' | |||
'when there are multiple variables in the equation.') | |||
devs = {var: devs[i] for i, var in enumerate(self.variables)} | |||
# integral results | |||
integrals = [] | |||
idx = (self._idx + bm.arange(self.num_memory)) % self.num_memory | |||
for i, var in enumerate(self.variables): | |||
summation = self._binomial_coef[:, i] @ self.delays[var][idx] | |||
integral = (dt ** self.alpha[i]) * devs[var] - summation | |||
self.delays[var][self._idx[0]] = integral | |||
integrals.append(integral) | |||
self._idx.value = (self._idx + 1) % self.num_memory | |||
# return integrals | |||
if len(self.variables) == 1: | |||
return integrals[0] | |||
else: | |||
return integrals |
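# Worked note on the short-memory window (hedged): the memory spans
# L_m = num_memory * dt time units, so the docstring example
# (num_memory=500, dt=0.005) only "remembers" the most recent 2.5 time
# units. Enlarging num_memory widens the window and improves accuracy at
# the cost of a longer summation per integration step.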
@@ -1,95 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
import jax.numpy as jnp | |||
from jax import vmap | |||
from jax.lax import cond | |||
from brainpy.math.special import Gamma | |||
from brainpy.tools.checking import check_float | |||
__all__ = [ | |||
'RL', | |||
] | |||
def RLcoeffs(index_k, index_j, alpha): | |||
"""Calculates coefficients for the RL differintegral operator. | |||
see Baleanu, D., Diethelm, K., Scalas, E., and Trujillo, J.J. (2012). Fractional | |||
Calculus: Models and Numerical Methods. World Scientific. | |||
""" | |||
def f1(x): | |||
k, j = x | |||
return ((k - 1) ** (1 - alpha) - | |||
(k + alpha - 1) * k ** -alpha) | |||
def f2(x): | |||
k, j = x | |||
return cond(k == j, lambda _: 1., f3, x) | |||
def f3(x): | |||
k, j = x | |||
return ((k - j + 1) ** (1 - alpha) + | |||
(k - j - 1) ** (1 - alpha) - | |||
2 * (k - j) ** (1 - alpha)) | |||
return cond(index_j == 0, f1, f2, (index_k, index_j)) | |||
def RLmatrix(alpha, N): | |||
""" Define the coefficient matrix for the RL algorithm. """ | |||
ij = jnp.tril_indices(N, -1) | |||
coeff = vmap(RLcoeffs, in_axes=(0, 0, None))(ij[0], ij[1], alpha) | |||
mat = jnp.zeros((N, N)).at[ij].set(coeff) | |||
diagonal = jnp.arange(N) | |||
mat = mat.at[diagonal, diagonal].set(1.) | |||
return mat / Gamma(2 - alpha) | |||
def RL(alpha, f, domain_start=0.0, domain_end=1.0, dt=0.01): | |||
""" Calculate the RL algorithm using a trapezoid rule over | |||
an array of function values. | |||
Examples | |||
-------- | |||
  >>> RL_sqrt = RL(0.5, lambda x: x ** 0.5)
  >>> RL_poly = RL(0.5, lambda x: x ** 2 - 1, 0., 1., dt=0.01)
Parameters | |||
---------- | |||
alpha : float | |||
The order of the differintegral to be computed. | |||
f : function | |||
This is the function that is to be differintegrated. | |||
domain_start : float, int | |||
The left-endpoint of the function domain. Default value is 0. | |||
domain_end : float, int | |||
The right-endpoint of the function domain; the point at which the | |||
differintegral is being evaluated. Default value is 1. | |||
  dt : float, int
    The step size used to discretize the domain. Default value is 0.01.
Returns | |||
------- | |||
RL : float 1d-array | |||
Each element of the array is the RL differintegral evaluated at the | |||
corresponding function array index. | |||
""" | |||
# checking | |||
assert domain_start < domain_end, ('"domain_start" should be lower than "domain_end", ' | |||
f'while we got {domain_start} >= {domain_end}') | |||
check_float(alpha, 'alpha', allow_none=False) | |||
check_float(domain_start, 'domain_start', allow_none=False) | |||
  check_float(domain_end, 'domain_end', allow_none=False)
check_float(dt, 'dt', allow_none=False) | |||
# computing | |||
points = jnp.arange(domain_start, domain_end, dt) | |||
f_values = vmap(f)(points) | |||
# Calculate the RL differintegral. | |||
D = RLmatrix(alpha, points.shape[0]) | |||
RL = dt ** -alpha * jnp.dot(D, f_values) | |||
return RL | |||
@@ -1 +1,8 @@ | |||
# -*- coding: utf-8 -*- | |||
from .base import * | |||
from .generic import * | |||
from .GL import * | |||
from .Caputo import * | |||
@@ -1,8 +1,82 @@ | |||
# -*- coding: utf-8 -*- | |||
from ..base import Integrator | |||
import abc | |||
from typing import Union, Callable | |||
import jax.numpy as jnp | |||
import brainpy.math as bm | |||
from brainpy.integrators.base import Integrator | |||
from brainpy.integrators.utils import get_args | |||
from brainpy.errors import UnsupportedError | |||
__all__ = [ | |||
'FDEIntegrator' | |||
] | |||
class FDEIntegrator(Integrator):
  """Numerical integrator for fractional differential equations (FDEs).
Parameters | |||
---------- | |||
f : callable | |||
The derivative function. | |||
alpha: int, float, jnp.ndarray, bm.ndarray, sequence | |||
The fractional-order of the derivative function. | |||
dt: float, int | |||
The numerical precision. | |||
name: str | |||
The integrator name. | |||
""" | |||
"""The fraction order for each variable.""" | |||
alpha: jnp.ndarray | |||
"""The numerical integration precision.""" | |||
dt: Union[float, int] | |||
"""The fraction derivative function.""" | |||
f: Callable | |||
def __init__(self, f, alpha, dt=None, name=None): | |||
dt = bm.get_dt() if dt is None else dt | |||
parses = get_args(f) | |||
variables = parses[0] # variable names, (before 't') | |||
parameters = parses[1] # parameter names, (after 't') | |||
arguments = parses[2] # function arguments | |||
# super initialization | |||
super(FDEIntegrator, self).__init__(name=name, | |||
variables=variables, | |||
parameters=parameters, | |||
arguments=arguments, | |||
dt=dt) | |||
# derivative function | |||
self.f = f | |||
# fractional-order | |||
if isinstance(alpha, (int, float)): | |||
alpha = jnp.ones(len(self.variables)) * alpha | |||
elif isinstance(alpha, (jnp.ndarray, bm.ndarray)): | |||
alpha = bm.as_device_array(alpha) | |||
elif isinstance(alpha, (list, tuple)): | |||
for a in alpha: | |||
assert isinstance(a, (float, int)), (f'Must be a tuple/list of int/float, ' | |||
f'but we got {type(a)}: {a}') | |||
alpha = jnp.asarray(alpha) | |||
else: | |||
raise UnsupportedError(f'Do not support {type(alpha)}, please ' | |||
f'set fractional-order as number/tuple/list/tensor.') | |||
if len(alpha) != len(self.variables): | |||
raise ValueError(f'There are {len(self.variables)} variables, ' | |||
f'while we only got {len(alpha)} fractional-order ' | |||
f'settings: {alpha}') | |||
self.alpha = alpha | |||
@abc.abstractmethod | |||
def build(self): | |||
raise NotImplementedError('Must implement how to build your step function.') | |||
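# Illustration of the `alpha` normalization above (hedged), for a derivative
# function f(x, y, t) with two variables:
#   alpha=0.9         -> self.alpha == jnp.asarray([0.9, 0.9])
#   alpha=[0.9, 0.8]  -> self.alpha == jnp.asarray([0.9, 0.8])
#   alpha=[0.9]       -> ValueError: 2 variables but only 1 fractional order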
@@ -0,0 +1,92 @@ | |||
# -*- coding: utf-8 -*- | |||
from .base import FDEIntegrator | |||
__all__ = [ | |||
'fdeint', | |||
'set_default_fdeint', | |||
'get_default_fdeint', | |||
'register_fde_integrator', | |||
'get_supported_methods', | |||
] | |||
name2method = {}
_DEFAULT_FDE_METHOD = 'CaputoL1'
def fdeint(f=None, method='CaputoL1', **kwargs): | |||
"""Numerical integration for FDEs. | |||
Parameters | |||
---------- | |||
f : callable, function | |||
The derivative function. | |||
method : str | |||
The shortcut name of the numerical integrator. | |||
Returns | |||
------- | |||
integral : FDEIntegrator | |||
The numerical solver of `f`. | |||
""" | |||
  method = _DEFAULT_FDE_METHOD if method is None else method
if method not in name2method: | |||
raise ValueError(f'Unknown FDE numerical method "{method}". Currently ' | |||
f'BrainPy only support: {list(name2method.keys())}') | |||
if f is None: | |||
return lambda f: name2method[method](f, **kwargs) | |||
else: | |||
return name2method[method](f, **kwargs) | |||
def set_default_fdeint(method):
  """Set the default FDE numerical integrator method for fractional differential equations.
  Parameters
  ----------
  method : str
    Numerical integrator method.
  """
  if not isinstance(method, str):
    raise ValueError(f'Only support string, not {type(method)}.')
  if method not in name2method:
    raise ValueError(f'Unsupported FDE numerical method: {method}.')
  global _DEFAULT_FDE_METHOD
  _DEFAULT_FDE_METHOD = method
def get_default_fdeint():
  """Get the default FDE numerical integrator method.
  Returns
  -------
  method : str
    The default numerical integrator method.
  """
  return _DEFAULT_FDE_METHOD
def register_fde_integrator(name, integrator):
  """Register a new FDE integrator.
  Parameters
  ----------
  name: str
    The integrator name.
  integrator: type
    The integrator.
  """
  if name in name2method:
    raise ValueError(f'"{name}" has been registered in FDE integrators.')
  if not issubclass(integrator, FDEIntegrator):
    raise ValueError(f'"integrator" must be a subclass of {FDEIntegrator.__name__}')
  name2method[name] = integrator
def get_supported_methods(): | |||
"""Get all supported numerical methods for DDEs.""" | |||
return list(name2method.keys()) |
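# Usage sketch (hedged; assumes `fdeint` is re-exported at the brainpy top
# level like `odeint`/`ddeint`, and that 'CaputoL1' has been registered):
#
#   import brainpy as bp
#
#   @bp.fdeint(method='CaputoL1', alpha=0.9, num_step=1000, inits=[1.])
#   def f(x, t):
#     return -x
#
# Extra keyword arguments are forwarded to the integrator class constructor.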
@@ -0,0 +1,33 @@ | |||
# -*- coding: utf-8 -*- | |||
import unittest | |||
import numpy as np | |||
import brainpy as bp | |||
class TestCaputoL1(unittest.TestCase): | |||
def test1(self): | |||
bp.math.enable_x64() | |||
alpha = 0.9 | |||
intg = bp.fde.CaputoL1Schema(lambda a, t: a, | |||
alpha=alpha, | |||
num_step=10, | |||
inits=[1., ]) | |||
for N in [2, 3, 4, 5, 6, 7, 8]: | |||
diff = np.random.rand(N - 1, 1) | |||
memory_trace = 0 | |||
for i in range(N - 1): | |||
c = (N - i) ** (1 - alpha) - (N - i - 1) ** (1 - alpha) | |||
memory_trace += c * diff[i] | |||
intg.idx[0] = N - 1 | |||
intg.diff_states['a_diff'][:N - 1] = bp.math.asarray(diff) | |||
idx = ((intg.num_step - intg.idx) + np.arange(intg.num_step)) % intg.num_step | |||
memory_trace2 = intg.coef[idx, 0] @ intg.diff_states['a_diff'] | |||
      print()
      print(memory_trace[0])
      print(memory_trace2[0])
      self.assertTrue(np.allclose(memory_trace[0], bp.math.as_numpy(memory_trace2[0])))
@@ -0,0 +1,32 @@ | |||
# -*- coding: utf-8 -*- | |||
import unittest | |||
import matplotlib.pyplot as plt | |||
import brainpy as bp | |||
class TestGLShortMemory(unittest.TestCase): | |||
def test_lorenz(self): | |||
a, b, c = 10, 28, 8 / 3 | |||
def lorenz(x, y, z, t): | |||
dx = a * (y - x) | |||
dy = x * (b - z) - y | |||
dz = x * y - c * z | |||
return dx, dy, dz | |||
integral = bp.fde.GLShortMemory(lorenz, | |||
alpha=0.99, | |||
num_memory=500, | |||
inits=[1., 0., 1.]) | |||
runner = bp.integrators.IntegratorRunner(integral, | |||
monitors=list('xyz'), | |||
inits=[1., 0., 1.], | |||
dt=0.005) | |||
runner.run(100.) | |||
plt.plot(runner.mon.x.flatten(), runner.mon.z.flatten()) | |||
plt.show(block=False) | |||
@@ -1,16 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
import unittest | |||
from brainpy.integrators.fde.RL import RLmatrix | |||
import brainpy.math as bm | |||
class TestRLAlgorithm(unittest.TestCase): | |||
def test_RL_matrix_shape(self): | |||
bm.enable_x64() | |||
print() | |||
print(RLmatrix(0.4, 5)) | |||
self.assertTrue(RLmatrix(0.4, 10).shape == (10, 10)) | |||
self.assertTrue(RLmatrix(1.2, 5).shape == (5, 5)) | |||
@@ -153,7 +153,7 @@ class JointEq(object): | |||
for par in args[len(vars) + 1:]: | |||
if (par not in vars_in_eqs) and (par not in all_arg_pars) and (par not in all_kwarg_pars): | |||
all_arg_pars.append(par) | |||
for key, value in kwargs.values(): | |||
for key, value in kwargs.items(): | |||
if key in all_kwarg_pars and value != all_kwarg_pars[key]: | |||
raise errors.DiffEqError(f'We got two different default value of "{key}": ' | |||
f'{all_kwarg_pars[key]} != {value}') | |||
@@ -58,6 +58,7 @@ from brainpy import errors | |||
from brainpy.integrators import constants as C, utils | |||
from brainpy.integrators.ode import common | |||
from brainpy.integrators.ode.base import ODEIntegrator | |||
from .generic import register_ode_integrator | |||
__all__ = [ | |||
'AdaptiveRKIntegrator', | |||
@@ -239,6 +240,9 @@ class RKF12(AdaptiveRKIntegrator): | |||
C = [0, 0.5, 1] | |||
register_ode_integrator('rkf12', RKF12) | |||
class RKF45(AdaptiveRKIntegrator): | |||
r"""The Runge–Kutta–Fehlberg method for ODEs. | |||
@@ -285,6 +289,9 @@ class RKF45(AdaptiveRKIntegrator): | |||
C = [0, 0.25, 0.375, '12/13', 1, '1/3'] | |||
register_ode_integrator('rkf45', RKF45) | |||
class DormandPrince(AdaptiveRKIntegrator): | |||
r"""The Dormand–Prince method for ODEs. | |||
@@ -336,6 +343,9 @@ class DormandPrince(AdaptiveRKIntegrator): | |||
C = [0, 0.2, 0.3, 0.8, '8/9', 1, 1] | |||
register_ode_integrator('rkdp', DormandPrince) | |||
class CashKarp(AdaptiveRKIntegrator): | |||
r"""The Cash–Karp method for ODEs. | |||
@@ -384,6 +394,9 @@ class CashKarp(AdaptiveRKIntegrator): | |||
C = [0, 0.2, 0.3, 0.6, 1, 0.875] | |||
register_ode_integrator('ck', CashKarp) | |||
class BogackiShampine(AdaptiveRKIntegrator): | |||
r"""The Bogacki–Shampine method for ODEs. | |||
@@ -427,6 +440,9 @@ class BogackiShampine(AdaptiveRKIntegrator): | |||
C = [0, 0.5, 0.75, 1] | |||
register_ode_integrator('bs', BogackiShampine) | |||
class HeunEuler(AdaptiveRKIntegrator): | |||
r"""The Heun–Euler method for ODEs. | |||
@@ -457,6 +473,9 @@ class HeunEuler(AdaptiveRKIntegrator): | |||
C = [0, 1] | |||
register_ode_integrator('heun_euler', HeunEuler) | |||
class DOP853(AdaptiveRKIntegrator): | |||
# def DOP853(f=None, tol=None, adaptive=None, dt=None, show_code=None, each_var_is_scalar=None): | |||
r"""The DOP853 method for ODEs. | |||
@@ -473,3 +492,23 @@ class DOP853(AdaptiveRKIntegrator): | |||
.. [2] http://www.unige.ch/~hairer/software.html | |||
""" | |||
pass | |||
class BoSh3(AdaptiveRKIntegrator):
  """Bogacki–Shampine's 3/2 method.
  A 3rd-order explicit Runge–Kutta method with an embedded 2nd-order method
  for adaptive step sizing.
  """
A = [(), | |||
(0.5,), | |||
(0.0, 0.75), | |||
('2/9', '1/3', '4/9')] | |||
B1 = ['2/9', '1/3', '4/9', 0.0] | |||
  B2 = ['-5/72', '1/12', '1/9', '-1/8']
C = [0., 0.5, 0.75, 1.0] | |||
register_ode_integrator('BoSh3', BoSh3) |
@@ -21,6 +21,17 @@ def f_names(f): | |||
class ODEIntegrator(Integrator): | |||
"""Numerical Integrator for Ordinary Differential Equations (ODEs). | |||
Parameters | |||
---------- | |||
f : callable | |||
The derivative function. | |||
var_type: str | |||
The type for each variable. | |||
dt: float, int | |||
The numerical precision. | |||
name: str | |||
The integrator name. | |||
""" | |||
def __init__(self, f, var_type=None, dt=None, name=None, show_code=False): | |||
@@ -70,6 +70,7 @@ More details please see references [2]_ [3]_ [4]_. | |||
from brainpy.integrators import constants as C, utils | |||
from brainpy.integrators.ode import common | |||
from brainpy.integrators.ode.base import ODEIntegrator | |||
from .generic import register_ode_integrator | |||
__all__ = [ | |||
'ExplicitRKIntegrator', | |||
@@ -247,6 +248,9 @@ class Euler(ExplicitRKIntegrator): | |||
C = [0] | |||
register_ode_integrator('euler', Euler) | |||
class MidPoint(ExplicitRKIntegrator): | |||
r"""Explicit midpoint method for ODEs. | |||
@@ -341,6 +345,9 @@ class MidPoint(ExplicitRKIntegrator): | |||
C = [0, 0.5] | |||
register_ode_integrator('midpoint', MidPoint) | |||
class Heun2(ExplicitRKIntegrator): | |||
r"""Heun's method for ODEs. | |||
@@ -406,6 +413,9 @@ class Heun2(ExplicitRKIntegrator): | |||
C = [0, 1] | |||
register_ode_integrator('heun2', Heun2) | |||
class Ralston2(ExplicitRKIntegrator): | |||
r"""Ralston's method for ODEs. | |||
@@ -437,6 +447,9 @@ class Ralston2(ExplicitRKIntegrator): | |||
C = [0, '2/3'] | |||
register_ode_integrator('ralston2', Ralston2) | |||
class RK2(ExplicitRKIntegrator): | |||
r"""Generic second order Runge-Kutta method for ODEs. | |||
@@ -560,6 +573,9 @@ class RK2(ExplicitRKIntegrator): | |||
super(RK2, self).__init__(f=f, var_type=var_type, dt=dt, name=name, show_code=show_code) | |||
register_ode_integrator('rk2', RK2) | |||
class RK3(ExplicitRKIntegrator): | |||
r"""Classical third-order Runge-Kutta method for ODEs. | |||
@@ -598,6 +614,9 @@ class RK3(ExplicitRKIntegrator): | |||
C = [0, 0.5, 1] | |||
register_ode_integrator('rk3', RK3) | |||
class Heun3(ExplicitRKIntegrator): | |||
r"""Heun's third-order method for ODEs. | |||
@@ -622,6 +641,9 @@ class Heun3(ExplicitRKIntegrator): | |||
C = [0, '1/3', '2/3'] | |||
register_ode_integrator('heun3', Heun3) | |||
class Ralston3(ExplicitRKIntegrator): | |||
r"""Ralston's third-order method for ODEs. | |||
@@ -651,6 +673,9 @@ class Ralston3(ExplicitRKIntegrator): | |||
C = [0, 0.5, 0.75] | |||
register_ode_integrator('ralston3', Ralston3) | |||
class SSPRK3(ExplicitRKIntegrator): | |||
r"""Third-order Strong Stability Preserving Runge-Kutta (SSPRK3). | |||
@@ -674,6 +699,9 @@ class SSPRK3(ExplicitRKIntegrator): | |||
C = [0, 1, 0.5] | |||
register_ode_integrator('ssprk3', SSPRK3) | |||
class RK4(ExplicitRKIntegrator): | |||
r"""Classical fourth-order Runge-Kutta method for ODEs. | |||
@@ -741,6 +769,9 @@ class RK4(ExplicitRKIntegrator): | |||
C = [0, 0.5, 0.5, 1] | |||
register_ode_integrator('rk4', RK4) | |||
class Ralston4(ExplicitRKIntegrator): | |||
r"""Ralston's fourth-order method for ODEs. | |||
@@ -772,6 +803,9 @@ class Ralston4(ExplicitRKIntegrator): | |||
C = [0, .4, .45573725, 1] | |||
register_ode_integrator('ralston4', Ralston4) | |||
class RK4Rule38(ExplicitRKIntegrator): | |||
r"""3/8-rule fourth-order method for ODEs. | |||
@@ -811,3 +845,6 @@ class RK4Rule38(ExplicitRKIntegrator): | |||
A = [(), ('1/3',), ('-1/3', '1'), (1, -1, 1)] | |||
B = [0.125, 0.375, 0.375, 0.125] | |||
C = [0, '1/3', '2/3', 1] | |||
register_ode_integrator('rk4_38rule', RK4Rule38) |
@@ -113,6 +113,7 @@ from brainpy.base.collector import Collector | |||
from brainpy.integrators import constants as C, utils, joint_eq | |||
from brainpy.integrators.analysis_by_ast import separate_variables | |||
from brainpy.integrators.ode.base import ODEIntegrator | |||
from .generic import register_ode_integrator | |||
try: | |||
import sympy | |||
@@ -506,6 +507,10 @@ class ExponentialEuler(ODEIntegrator): | |||
return s_df_part | |||
register_ode_integrator('exponential_euler', ExponentialEuler) | |||
register_ode_integrator('exp_euler', ExponentialEuler) | |||
class ExpEulerAuto(ODEIntegrator): | |||
"""Exponential Euler method using automatic differentiation. | |||
@@ -762,3 +767,7 @@ class ExpEulerAuto(ODEIntegrator): | |||
return args[0] + dt * phi * derivative | |||
return [(integral, vars, pars), ] | |||
register_ode_integrator('exp_euler_auto', ExpEulerAuto) | |||
register_ode_integrator('exp_auto', ExpEulerAuto) |
@@ -1,9 +1,6 @@ | |||
# -*- coding: utf-8 -*- | |||
from .base import ODEIntegrator | |||
from .adaptive_rk import * | |||
from .explicit_rk import * | |||
from .exponential import * | |||
__all__ = [ | |||
'odeint', | |||
@@ -14,31 +11,6 @@ __all__ = [ | |||
] | |||
name2method = { | |||
# explicit RK | |||
'euler': Euler, 'Euler': Euler, | |||
'midpoint': MidPoint, 'MidPoint': MidPoint, | |||
'heun2': Heun2, 'Heun2': Heun2, | |||
'ralston2': Ralston2, 'Ralston2': Ralston2, | |||
'rk2': RK2, 'RK2': RK2, | |||
'rk3': RK3, 'RK3': RK3, | |||
'heun3': Heun3, 'Heun3': Heun3, | |||
'ralston3': Ralston3, 'Ralston3': Ralston3, | |||
'ssprk3': SSPRK3, 'SSPRK3': SSPRK3, | |||
'rk4': RK4, 'RK4': RK4, | |||
'ralston4': Ralston4, 'Ralston4': Ralston4, | |||
'rk4_38rule': RK4Rule38, 'RK4Rule38': RK4Rule38, | |||
# adaptive RK | |||
'rkf12': RKF12, 'RKF12': RKF12, | |||
'rkf45': RKF45, 'RKF45': RKF45, | |||
'rkdp': DormandPrince, 'dp': DormandPrince, 'DormandPrince': DormandPrince, | |||
'ck': CashKarp, 'CashKarp': CashKarp, | |||
'bs': BogackiShampine, 'BogackiShampine': BogackiShampine, | |||
'heun_euler': HeunEuler, 'HeunEuler': HeunEuler, | |||
# exponential integrators | |||
'exponential_euler': ExponentialEuler, 'exp_euler': ExponentialEuler, 'ExponentialEuler': ExponentialEuler, | |||
'exp_euler_auto': ExpEulerAuto, 'exp_auto': ExpEulerAuto, 'ExpEulerAuto': ExpEulerAuto, | |||
} | |||
_DEFAULT_ODE_METHOD = 'euler'
@@ -134,7 +106,7 @@ def register_ode_integrator(name, integrator): | |||
""" | |||
if name in name2method: | |||
raise ValueError(f'"{name}" has been registered in ODE integrators.') | |||
if ODEIntegrator not in integrator.__bases__: | |||
if not issubclass(integrator, ODEIntegrator): | |||
raise ValueError(f'"integrator" must be an instance of {ODEIntegrator.__name__}') | |||
name2method[name] = integrator | |||
@@ -93,7 +93,7 @@ class IntegratorRunner(Runner): | |||
>>> dt = 0.01; beta=2.; gamma=1.; tau=2.; n=9.65 | |||
>>> mg_eq = lambda x, t, xdelay: (beta * xdelay(t - tau) / (1 + xdelay(t - tau) ** n) | |||
>>> - gamma * x) | |||
>>> xdelay = bm.FixedLenDelay(1, delay_len=tau, dt=dt, before_t0=lambda t: 1.2) | |||
>>> xdelay = bm.TimeDelay(bm.asarray([1.2]), delay_len=tau, dt=dt, before_t0=lambda t: 1.2) | |||
>>> integral = bp.ddeint(mg_eq, method='rk4', state_delays={'x': xdelay}) | |||
>>> runner = bp.integrators.IntegratorRunner( | |||
>>> integral, | |||
@@ -1,8 +1,6 @@ | |||
# -*- coding: utf-8 -*- | |||
from .base import SDEIntegrator | |||
from .normal import * | |||
from .srk_scalar import * | |||
__all__ = [ | |||
'sdeint', | |||
@@ -13,15 +11,6 @@ __all__ = [ | |||
] | |||
name2method = { | |||
'euler': Euler, 'Euler': Euler, | |||
'heun': Heun, 'Heun': Heun, | |||
'milstein': Milstein, 'Milstein': Milstein, | |||
'exponential_euler': ExponentialEuler, 'exp_euler': ExponentialEuler, 'ExponentialEuler': ExponentialEuler, | |||
# RK methods | |||
'srk1w1': SRK1W1, 'SRK1W1': SRK1W1, | |||
'srk2w1': SRK2W1, 'SRK2W1': SRK2W1, | |||
'klpl': KlPl, 'KlPl': KlPl, | |||
} | |||
_DEFAULT_SDE_METHOD = 'euler' | |||
@@ -98,7 +87,7 @@ def register_sde_integrator(name, integrator): | |||
""" | |||
if name in name2method: | |||
raise ValueError(f'"{name}" has been registered in SDE integrators.') | |||
if SDEIntegrator not in integrator.__bases__: | |||
if not issubclass(integrator, SDEIntegrator): | |||
raise ValueError(f'"integrator" must be an instance of {SDEIntegrator.__name__}') | |||
name2method[name] = integrator | |||
@@ -6,6 +6,7 @@ from brainpy import errors, math | |||
from brainpy.integrators import constants, utils | |||
from brainpy.integrators.analysis_by_ast import separate_variables | |||
from brainpy.integrators.sde.base import SDEIntegrator | |||
from .generic import register_sde_integrator | |||
try: | |||
import sympy | |||
@@ -142,6 +143,9 @@ class Euler(SDEIntegrator): | |||
func_name=self.func_name) | |||
register_sde_integrator('euler', Euler) | |||
class Heun(Euler): | |||
def __init__(self, f, g, dt=None, name=None, show_code=False, | |||
var_type=None, intg_type=None, wiener_type=None): | |||
@@ -154,6 +158,9 @@ class Heun(Euler): | |||
self.build() | |||
register_sde_integrator('heun', Heun) | |||
class Milstein(SDEIntegrator): | |||
def __init__(self, f, g, dt=None, name=None, show_code=False, | |||
var_type=None, intg_type=None, wiener_type=None): | |||
@@ -238,6 +245,9 @@ class Milstein(SDEIntegrator): | |||
func_name=self.func_name) | |||
register_sde_integrator('milstein', Milstein) | |||
class ExponentialEuler(SDEIntegrator): | |||
r"""First order, explicit exponential Euler method. | |||
@@ -399,3 +409,7 @@ class ExponentialEuler(SDEIntegrator): | |||
if hasattr(self.derivative[constants.F], '__self__'): | |||
host = self.derivative[constants.F].__self__ | |||
self.integral = self.integral.__get__(host, host.__class__) | |||
register_sde_integrator('exponential_euler', ExponentialEuler) | |||
register_sde_integrator('exp_euler', ExponentialEuler) |
@@ -2,6 +2,7 @@ | |||
from brainpy.integrators import constants, utils | |||
from brainpy.integrators.sde.base import SDEIntegrator | |||
from .generic import register_sde_integrator | |||
__all__ = [ | |||
'SRK1W1', | |||
@@ -175,6 +176,9 @@ class SRK1W1(SDEIntegrator): | |||
func_name=self.func_name) | |||
register_sde_integrator('srk1w1', SRK1W1) | |||
class SRK2W1(SDEIntegrator): | |||
r"""Order 1.5 Strong SRK Methods for SDEs with Scalar Noise. | |||
@@ -315,6 +319,9 @@ class SRK2W1(SDEIntegrator): | |||
func_name=self.func_name) | |||
register_sde_integrator('srk2w1', SRK2W1) | |||
class KlPl(SDEIntegrator): | |||
def __init__(self, f, g, dt=None, name=None, show_code=False, | |||
var_type=None, intg_type=None, wiener_type=None): | |||
@@ -367,3 +374,6 @@ class KlPl(SDEIntegrator): | |||
code_lines=self.code_lines, | |||
show_code=self.show_code, | |||
func_name=self.func_name) | |||
register_sde_integrator('klpl', KlPl) |
@@ -4,12 +4,17 @@ | |||
import inspect | |||
from pprint import pprint | |||
import brainpy.math as bm | |||
from brainpy.errors import UnsupportedError | |||
from brainpy import errors | |||
__all__ = [ | |||
'get_args', | |||
'check_kws', | |||
'compile_code', | |||
'check_inits', | |||
'format_args', | |||
] | |||
@@ -103,3 +108,35 @@ def compile_code(code_lines, code_scope, func_name, show_code=False): | |||
exec(compile(code, '', 'exec'), code_scope) | |||
new_f = code_scope[func_name] | |||
return new_f | |||
def check_inits(inits, variables): | |||
if isinstance(inits, (tuple, list)): | |||
    assert len(inits) == len(variables), (f'The number of variables is {len(variables)}, '
                                          f'but we only got {len(inits)} initial values.')
inits = {v: inits[i] for i, v in enumerate(variables)} | |||
elif isinstance(inits, dict): | |||
    assert len(inits) == len(variables), (f'The number of variables is {len(variables)}, '
                                          f'but we only got {len(inits)} initial values.')
else: | |||
raise UnsupportedError('Only supports dict/sequence of data for initial values. ' | |||
f'But we got {type(inits)}: {inits}') | |||
for key in list(inits.keys()): | |||
if key not in variables: | |||
raise ValueError(f'"{key}" is not defined in variables: {variables}') | |||
val = inits[key] | |||
if isinstance(val, (float, int)): | |||
inits[key] = bm.asarray([val], dtype=bm.float_) | |||
return inits | |||
def format_args(args, kwargs, arguments): | |||
all_args = dict() | |||
for i, arg in enumerate(args): | |||
all_args[arguments[i]] = arg | |||
for key, arg in kwargs.items(): | |||
if key in all_args: | |||
raise ValueError(f'{key} has been provided in *args, ' | |||
f'but we detect it again in **kwargs.') | |||
all_args[key] = arg | |||
return all_args |
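# Behavior sketch for the helpers above (hedged):
#   check_inits([1., 0.], ['x', 'y'])
#     -> {'x': Array([1.]), 'y': Array([0.])}   # scalars become float arrays
#   format_args((1., 2.), {'t': 0.}, ['x', 'y', 't'])
#     -> {'x': 1., 'y': 2., 't': 0.}            # positional/keyword args merged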
@@ -41,7 +41,7 @@ def _return(outputs, reduction): | |||
def cross_entropy_loss(logits, targets, weight=None, reduction='mean'): | |||
"""This criterion combines ``LogSoftmax`` and `NLLLoss`` in one single class. | |||
r"""This criterion combines ``LogSoftmax`` and `NLLLoss`` in one single class. | |||
It is useful when training a classification problem with `C` classes. | |||
If provided, the optional argument :attr:`weight` should be a 1D `Tensor` | |||
@@ -120,7 +120,7 @@ def cross_entropy_loss(logits, targets, weight=None, reduction='mean'): | |||
def cross_entropy_sparse(logits, labels): | |||
"""Computes the softmax cross-entropy loss. | |||
r"""Computes the softmax cross-entropy loss. | |||
Args: | |||
logits: (batch, ..., #class) tensor of logits. | |||
@@ -155,7 +155,7 @@ def cross_entropy_sigmoid(logits, labels): | |||
def l1_loos(logits, targets, reduction='sum'): | |||
"""Creates a criterion that measures the mean absolute error (MAE) between each element in | |||
r"""Creates a criterion that measures the mean absolute error (MAE) between each element in | |||
the logits :math:`x` and targets :math:`y`. It is useful in regression problems. | |||
The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: | |||
@@ -207,7 +207,7 @@ def l1_loos(logits, targets, reduction='sum'): | |||
def l2_loss(predicts, targets): | |||
"""Computes the L2 loss. | |||
r"""Computes the L2 loss. | |||
The 0.5 term is standard in "Pattern Recognition and Machine Learning" | |||
by Bishop [1]_, but not "The Elements of Statistical Learning" by Tibshirani. | |||
@@ -246,7 +246,7 @@ def l2_norm(x): | |||
def mean_absolute_error(x, y, axis=None): | |||
"""Computes the mean absolute error between x and y. | |||
r"""Computes the mean absolute error between x and y. | |||
Args: | |||
x: a tensor of shape (d0, .. dN-1). | |||
@@ -261,7 +261,7 @@ def mean_absolute_error(x, y, axis=None): | |||
def mean_squared_error(predicts, targets, axis=None): | |||
"""Computes the mean squared error between x and y. | |||
r"""Computes the mean squared error between x and y. | |||
Args: | |||
predicts: a tensor of shape (d0, .. dN-1). | |||
@@ -276,7 +276,7 @@ def mean_squared_error(predicts, targets, axis=None): | |||
def mean_squared_log_error(y_true, y_pred, axis=None): | |||
"""Computes the mean squared logarithmic error between y_true and y_pred. | |||
r"""Computes the mean squared logarithmic error between y_true and y_pred. | |||
Args: | |||
y_true: a tensor of shape (d0, .. dN-1). | |||
@@ -291,7 +291,7 @@ def mean_squared_log_error(y_true, y_pred, axis=None): | |||
def huber_loss(predicts, targets, delta: float = 1.0): | |||
"""Huber loss. | |||
r"""Huber loss. | |||
Huber loss is similar to L2 loss close to zero, L1 loss away from zero. | |||
If gradient descent is applied to the `huber loss`, it is equivalent to | |||
@@ -353,7 +353,7 @@ def multiclass_logistic_loss(label: int, logits: jn.ndarray) -> float: | |||
def smooth_labels(labels, alpha: float) -> jn.ndarray: | |||
"""Apply label smoothing. | |||
r"""Apply label smoothing. | |||
Label smoothing is often used in combination with a cross-entropy loss. | |||
Smoothed labels favour small logit gaps, and it has been shown that this can | |||
provide better model calibration by preventing overconfident predictions. | |||
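A sketch of the usual smoothing rule (the exact convention used here is an assumption,
since this hunk truncates the implementation): ``smoothed = (1 - alpha) * onehot + alpha / num_classes``.
>>> import jax.numpy as jnp
>>> onehot = jnp.array([1., 0., 0.])
>>> (1 - 0.1) * onehot + 0.1 / 3   # -> [0.933..., 0.033..., 0.033...]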
@@ -411,7 +411,7 @@ def softmax_cross_entropy(logits, labels): | |||
def log_cosh(predicts, targets=None, ): | |||
"""Calculates the log-cosh loss for a set of predictions. | |||
r"""Calculates the log-cosh loss for a set of predictions. | |||
log(cosh(x)) is approximately `(x**2) / 2` for small x and `abs(x) - log(2)` | |||
for large x. It is a twice differentiable alternative to the Huber loss. | |||
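A quick numerical check of the two regimes (plain JAX; values rounded in the comments):
>>> import jax.numpy as jnp
>>> small = jnp.asarray(0.1)
>>> float(jnp.log(jnp.cosh(small))), float(small ** 2 / 2)                # ~0.00499 vs 0.00500
>>> large = jnp.asarray(10.)
>>> float(jnp.log(jnp.cosh(large))), float(jnp.abs(large) - jnp.log(2.))  # ~9.3069 vs 9.3069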
@@ -46,7 +46,7 @@ from . import random | |||
from .autograd import * | |||
from .controls import * | |||
from .jit import * | |||
from .parallels import * | |||
# from .parallels import * | |||
# settings | |||
from . import setting | |||
@@ -56,8 +56,7 @@ from .function import * | |||
# functions | |||
from .activations import * | |||
from . import activations | |||
from .compact import * | |||
from . import special | |||
from .compat import * | |||
def get_dint(): | |||
@@ -1,8 +1,7 @@ | |||
# -*- coding: utf-8 -*- | |||
from typing import Union, Callable, Dict, Sequence | |||
from functools import partial | |||
from typing import Union, Callable, Dict, Sequence | |||
import jax | |||
import numpy as np | |||
@@ -41,7 +40,7 @@ def _make_cls_call_func(grad_func, grad_tree, grad_vars, dyn_vars, | |||
except UnexpectedTracerError as e: | |||
for v, d in zip(grad_vars, old_grad_vs): v.value = d | |||
for v, d in zip(dyn_vars, old_dyn_vs): v.value = d | |||
raise errors.JaxTracerError(variables=dyn_vars+grad_vars) from e | |||
raise errors.JaxTracerError(variables=dyn_vars + grad_vars) from e | |||
for v, d in zip(grad_vars, new_grad_vs): v.value = d | |||
for v, d in zip(dyn_vars, new_dyn_vs): v.value = d | |||
@@ -1,8 +1,10 @@ | |||
# -*- coding: utf-8 -*- | |||
__all__ = [ | |||
'optimizers', 'losses' | |||
'optimizers', 'losses', | |||
'FixedLenDelay', | |||
] | |||
from . import optimizers, losses | |||
from .delay_vars import * | |||
@@ -0,0 +1,45 @@ | |||
# -*- coding: utf-8 -*- | |||
import warnings | |||
from typing import Union, Callable | |||
import jax.numpy as jnp | |||
from brainpy.math.jaxarray import ndarray | |||
from brainpy.math.numpy_ops import zeros | |||
from brainpy.math.delay_vars import TimeDelay | |||
__all__ = [ | |||
'FixedLenDelay' | |||
] | |||
def FixedLenDelay(shape, | |||
delay_len: Union[float, int], | |||
before_t0: Union[Callable, ndarray, jnp.ndarray, float, int] = None, | |||
t0: Union[float, int] = 0., | |||
dt: Union[float, int] = None, | |||
name: str = None, | |||
interp_method='linear_interp', ): | |||
"""Delay variable which has a fixed delay length. | |||
.. deprecated:: 2.1.2 | |||
Please use "brainpy.math.TimeDelay" instead. | |||
See Also | |||
-------- | |||
TimeDelay | |||
""" | |||
warnings.warn('Please use "brainpy.math.TimeDelay" instead. ' | |||
'"brainpy.math.FixedLenDelay" is deprecated since version 2.1.2. ', | |||
DeprecationWarning) | |||
return TimeDelay(inits=zeros(shape), | |||
delay_len=delay_len, | |||
before_t0=before_t0, | |||
t0=t0, | |||
dt=dt, | |||
name=name, | |||
interp_method=interp_method) | |||
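A migration sketch (the shim above simply forwards the old-style call to the new class):
>>> import brainpy.math as bm
>>> # deprecated (emits a DeprecationWarning):  FixedLenDelay(3, delay_len=1., dt=0.1)
>>> # preferred replacement:
>>> delay = bm.TimeDelay(bm.zeros(3), delay_len=1., dt=0.1)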
@@ -17,6 +17,11 @@ __all__ = [ | |||
def cross_entropy_loss(*args, **kwargs): | |||
"""Cross entropy loss. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.losses.cross_entropy_loss" instead. | |||
""" | |||
warnings.warn('Please use "brainpy.losses.XXX" instead. ' | |||
'"brainpy.math.losses.XXX" is deprecated since version 2.0.3. ', | |||
DeprecationWarning) | |||
@@ -24,6 +29,11 @@ def cross_entropy_loss(*args, **kwargs): | |||
def l1_loos(*args, **kwargs): | |||
"""L1 loss. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.losses.l1_loss" instead. | |||
""" | |||
warnings.warn('Please use "brainpy.losses.XXX" instead. ' | |||
'"brainpy.math.losses.XXX" is deprecated since version 2.0.3. ', | |||
DeprecationWarning) | |||
@@ -31,6 +41,11 @@ def l1_loos(*args, **kwargs): | |||
def l2_loss(*args, **kwargs): | |||
"""L2 loss. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.losses.l2_loss" instead. | |||
""" | |||
warnings.warn('Please use "brainpy.losses.XXX" instead. ' | |||
'"brainpy.math.losses.XXX" is deprecated since version 2.0.3. ', | |||
DeprecationWarning) | |||
@@ -38,6 +53,11 @@ def l2_loss(*args, **kwargs): | |||
def l2_norm(*args, **kwargs): | |||
"""L2 normal. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.losses.l2_norm" instead. | |||
""" | |||
warnings.warn('Please use "brainpy.losses.XXX" instead. ' | |||
'"brainpy.math.losses.XXX" is deprecated since version 2.0.3. ', | |||
DeprecationWarning) | |||
@@ -45,6 +65,11 @@ def l2_norm(*args, **kwargs): | |||
def huber_loss(*args, **kwargs): | |||
"""Huber loss. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.losses.huber_loss" instead. | |||
""" | |||
warnings.warn('Please use "brainpy.losses.XXX" instead. ' | |||
'"brainpy.math.losses.XXX" is deprecated since version 2.0.3. ', | |||
DeprecationWarning) | |||
@@ -52,6 +77,11 @@ def huber_loss(*args, **kwargs): | |||
def mean_absolute_error(*args, **kwargs): | |||
"""mean absolute error loss. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.losses.mean_absolute_error" instead. | |||
""" | |||
warnings.warn('Please use "brainpy.losses.XXX" instead. ' | |||
'"brainpy.math.losses.XXX" is deprecated since version 2.0.3. ', | |||
DeprecationWarning) | |||
@@ -59,6 +89,11 @@ def mean_absolute_error(*args, **kwargs): | |||
def mean_squared_error(*args, **kwargs): | |||
"""Mean squared error loss. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.losses.mean_squared_error" instead. | |||
""" | |||
warnings.warn('Please use "brainpy.losses.XXX" instead. ' | |||
'"brainpy.math.losses.XXX" is deprecated since version 2.0.3. ', | |||
DeprecationWarning) | |||
@@ -66,6 +101,11 @@ def mean_squared_error(*args, **kwargs): | |||
def mean_squared_log_error(*args, **kwargs): | |||
"""Mean squared log error loss. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.losses.mean_squared_log_error" instead. | |||
""" | |||
warnings.warn('Please use "brainpy.losses.XXX" instead. ' | |||
'"brainpy.math.losses.XXX" is deprecated since version 2.0.3. ', | |||
DeprecationWarning) |
@@ -22,6 +22,11 @@ __all__ = [ | |||
def SGD(*args, **kwargs): | |||
"""SGD optimizer. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.optim.SGD" instead. | |||
""" | |||
warnings.warn('Please use "brainpy.optim.SGD" instead. ' | |||
'"brainpy.math.optimizers.SGD" is ' | |||
'deprecated since version 2.0.3. ', | |||
@@ -30,6 +35,11 @@ def SGD(*args, **kwargs): | |||
def Momentum(*args, **kwargs): | |||
"""Momentum optimizer. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.optim.Momentum" instead. | |||
""" | |||
warnings.warn('Please use "brainpy.optim.Momentum" instead. ' | |||
'"brainpy.math.optimizers.Momentum" is ' | |||
'deprecated since version 2.0.3. ', | |||
@@ -38,6 +48,11 @@ def Momentum(*args, **kwargs): | |||
def MomentumNesterov(*args, **kwargs): | |||
"""MomentumNesterov optimizer. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.optim.MomentumNesterov" instead. | |||
""" | |||
warnings.warn('Please use "brainpy.optim.MomentumNesterov" instead. ' | |||
'"brainpy.math.optimizers.MomentumNesterov" is ' | |||
'deprecated since version 2.0.3. ', | |||
@@ -46,6 +61,11 @@ def MomentumNesterov(*args, **kwargs): | |||
def Adagrad(*args, **kwargs): | |||
"""Adagrad optimizer. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.optim.Adagrad" instead. | |||
""" | |||
warnings.warn('Please use "brainpy.optim.Adagrad" instead. ' | |||
'"brainpy.math.optimizers.Adagrad" is ' | |||
'deprecated since version 2.0.3. ', | |||
@@ -54,6 +74,11 @@ def Adagrad(*args, **kwargs): | |||
def Adadelta(*args, **kwargs): | |||
"""Adadelta optimizer. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.optim.Adadelta" instead. | |||
""" | |||
warnings.warn('Please use "brainpy.optim.Adadelta" instead. ' | |||
'"brainpy.math.optimizers.Adadelta" is ' | |||
'deprecated since version 2.0.3. ', | |||
@@ -62,6 +87,11 @@ def Adadelta(*args, **kwargs): | |||
def RMSProp(*args, **kwargs): | |||
"""RMSProp optimizer. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.optim.RMSProp" instead. | |||
""" | |||
warnings.warn('Please use "brainpy.optim.RMSProp" instead. ' | |||
'"brainpy.math.optimizers.RMSProp" is ' | |||
'deprecated since version 2.0.3. ', | |||
@@ -70,6 +100,11 @@ def RMSProp(*args, **kwargs): | |||
def Adam(*args, **kwargs): | |||
"""Adam optimizer. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.optim.Adam" instead. | |||
""" | |||
warnings.warn('Please use "brainpy.optim.Adam" instead. ' | |||
'"brainpy.math.optimizers.Adam" is ' | |||
'deprecated since version 2.0.3. ', | |||
@@ -78,6 +113,11 @@ def Adam(*args, **kwargs): | |||
def Constant(*args, **kwargs): | |||
"""Constant scheduler. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.optim.Constant" instead. | |||
""" | |||
warnings.warn('Please use "brainpy.optim.Constant" instead. ' | |||
'"brainpy.math.optimizers.Constant" is ' | |||
'deprecated since version 2.0.3. ', | |||
@@ -86,6 +126,11 @@ def Constant(*args, **kwargs): | |||
def ExponentialDecay(*args, **kwargs): | |||
"""ExponentialDecay scheduler. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.optim.ExponentialDecay" instead. | |||
""" | |||
warnings.warn('Please use "brainpy.optim.ExponentialDecay" instead. ' | |||
'"brainpy.math.optimizers.ExponentialDecay" is ' | |||
'deprecated since version 2.0.3. ', | |||
@@ -94,6 +139,11 @@ def ExponentialDecay(*args, **kwargs): | |||
def InverseTimeDecay(*args, **kwargs): | |||
"""InverseTimeDecay scheduler. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.optim.InverseTimeDecay" instead. | |||
""" | |||
warnings.warn('Please use "brainpy.optim.InverseTimeDecay" instead. ' | |||
'"brainpy.math.optimizers.InverseTimeDecay" is ' | |||
'deprecated since version 2.0.3. ', | |||
@@ -102,6 +152,11 @@ def InverseTimeDecay(*args, **kwargs): | |||
def PolynomialDecay(*args, **kwargs): | |||
"""PolynomialDecay scheduler. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.optim.PolynomialDecay" instead. | |||
""" | |||
warnings.warn('Please use "brainpy.optim.PolynomialDecay" instead. ' | |||
'"brainpy.math.optimizers.PolynomialDecay" is ' | |||
'deprecated since version 2.0.3. ', | |||
@@ -110,6 +165,11 @@ def PolynomialDecay(*args, **kwargs): | |||
def PiecewiseConstant(*args, **kwargs): | |||
"""PiecewiseConstant scheduler. | |||
.. deprecated:: 2.1.0 | |||
Please use "brainpy.optim.PiecewiseConstant" instead. | |||
""" | |||
warnings.warn('Please use "brainpy.optim.PiecewiseConstant" instead. ' | |||
'"brainpy.math.optimizers.PiecewiseConstant" is ' | |||
'deprecated since version 2.0.3. ', |
@@ -1,41 +1,47 @@ | |||
# -*- coding: utf-8 -*- | |||
from typing import Union, Callable, Tuple | |||
from typing import Union, Callable | |||
import jax.numpy as jnp | |||
import numpy as np | |||
from jax import vmap | |||
from jax.experimental.host_callback import id_tap | |||
from jax.lax import cond | |||
from brainpy import math as bm | |||
from brainpy import check | |||
from brainpy.base.base import Base | |||
from brainpy.tools.checking import check_float | |||
from brainpy.tools.others import to_size | |||
from brainpy.errors import UnsupportedError | |||
from brainpy.math import numpy_ops as ops | |||
from brainpy.math.jaxarray import ndarray, Variable | |||
from brainpy.math.setting import get_dt | |||
from brainpy.tools.checking import check_float, check_integer | |||
__all__ = [ | |||
'AbstractDelay', | |||
'FixedLenDelay', | |||
'TimeDelay', | |||
'NeutralDelay', | |||
'LengthDelay', | |||
] | |||
class AbstractDelay(Base): | |||
def update(self, time, value): | |||
def update(self, *args, **kwargs): | |||
raise NotImplementedError | |||
_FUNC_BEFORE = 'function' | |||
_DATA_BEFORE = 'data' | |||
_INTERP_LINEAR = 'linear_interp' | |||
_INTERP_ROUND = 'round' | |||
class FixedLenDelay(AbstractDelay): | |||
"""Delay variable which has a fixed delay length. | |||
class TimeDelay(AbstractDelay): | |||
"""Delay variable which has a fixed delay time length. | |||
For example, we create a delay variable which has a maximum delay length of 1 ms | |||
>>> import brainpy.math as bm | |||
>>> delay = bm.FixedLenDelay(bm.zeros(3), delay_len=1., dt=0.1) | |||
>>> delay = bm.TimeDelay(bm.zeros(3), delay_len=1., dt=0.1) | |||
>>> delay(-0.5) | |||
[-0. -0. -0.] | |||
@@ -43,13 +49,13 @@ class FixedLenDelay(AbstractDelay): | |||
1. the one-dimensional delay data | |||
>>> delay = bm.FixedLenDelay(3, delay_len=1., dt=0.1, before_t0=lambda t: t) | |||
>>> delay = bm.TimeDelay(bm.zeros(3), delay_len=1., dt=0.1, before_t0=lambda t: t) | |||
>>> delay(-0.2) | |||
[-0.2 -0.2 -0.2] | |||
2. the two-dimensional delay data | |||
>>> delay = bm.FixedLenDelay((3, 2), delay_len=1., dt=0.1, before_t0=lambda t: t) | |||
>>> delay = bm.TimeDelay(bm.zeros((3, 2)), delay_len=1., dt=0.1, before_t0=lambda t: t) | |||
>>> delay(-0.6) | |||
[[-0.6 -0.6] | |||
[-0.6 -0.6] | |||
@@ -57,8 +63,8 @@ class FixedLenDelay(AbstractDelay): | |||
3. the three-dimensional delay data | |||
>>> delay = bm.FixedLenDelay((3, 2, 1), delay_len=1., dt=0.1, before_t0=lambda t: t) | |||
>>> delay(-0.6) | |||
>>> delay = bm.TimeDelay(bm.zeros((3, 2, 1)), delay_len=1., dt=0.1, before_t0=lambda t: t) | |||
>>> delay(-0.8) | |||
[[[-0.8] | |||
[-0.8]] | |||
[[-0.8] | |||
@@ -68,8 +74,8 @@ class FixedLenDelay(AbstractDelay): | |||
Parameters | |||
---------- | |||
shape: int, sequence of int | |||
The delay data shape. | |||
inits: Tensor
The initial delay data. | |||
t0: float, int | |||
The zero time. | |||
delay_len: float, int | |||
@@ -83,155 +89,225 @@ class FixedLenDelay(AbstractDelay): | |||
of :math:`(num_delay, ...)`, where the longest delay data is arranged in
the first index. | |||
name: str | |||
The delay instance name. | |||
interp_method: str
The method for handling a requested delay time that is not an integer multiple of the time step.
For example, if the time step is ``dt=0.1`` and the delay length is ``delay_len=1.``,
a request for the delay data at ``t-0.53`` can be handled with one of the
following methods (a usage sketch follows this list):
- ``"linear_interp"``: use linear interpolation to get the delay value
at the required time (default).
- ``"round"``: round the requested time to the nearest integer multiple of the
time step; in the situation above, the data at ``t-0.5`` is used to approximate
the delay data at ``t-0.53``.
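For instance, with ten steps of history data (mirroring the unit tests in this changeset):
>>> hist = bm.repeat(bm.arange(10).reshape((-1, 1)), 10, axis=1)
>>> d_lin = bm.TimeDelay(bm.zeros(10), delay_len=1., dt=0.1, before_t0=hist)
>>> d_lin(-0.15)   # linear interpolation between steps 8 and 9 -> 8.5
>>> d_rnd = bm.TimeDelay(bm.zeros(10), delay_len=1., dt=0.1, before_t0=hist,
...                      interp_method='round')
>>> d_rnd(-0.15)   # rounded onto the time grid -> 8.0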
.. versionadded:: 2.1.1 | |||
See Also | |||
-------- | |||
LengthDelay | |||
""" | |||
def __init__( | |||
self, | |||
shape: Union[int, Tuple[int, ...]], | |||
inits: Union[ndarray, jnp.ndarray], | |||
delay_len: Union[float, int], | |||
before_t0: Union[Callable, bm.ndarray, jnp.ndarray, float, int] = None, | |||
before_t0: Union[Callable, ndarray, jnp.ndarray, float, int] = None, | |||
t0: Union[float, int] = 0., | |||
dt: Union[float, int] = None, | |||
name: str = None, | |||
dtype=None, | |||
interp_method='linear_interp', | |||
): | |||
super(FixedLenDelay, self).__init__(name=name) | |||
super(TimeDelay, self).__init__(name=name) | |||
# shape | |||
self.shape = to_size(shape) | |||
self.dtype = dtype | |||
assert isinstance(inits, (ndarray, np.ndarray)), (f'Must be an instance of brainpy.math.ndarray ' | |||
f'or jax.numpy.ndarray. But we got {type(inits)}') | |||
self.shape = inits.shape | |||
# delay_len | |||
self.t0 = t0 | |||
self._dt = bm.get_dt() if dt is None else dt | |||
self.dt = get_dt() if dt is None else dt | |||
check_float(delay_len, 'delay_len', allow_none=False, allow_int=True, min_bound=0.) | |||
self._delay_len = delay_len | |||
self.delay_len = delay_len + self._dt | |||
self.num_delay_steps = int(bm.ceil(self.delay_len / self._dt).value) | |||
self.delay_len = delay_len | |||
self.num_delay_step = int(ops.ceil(self.delay_len / self.dt).value) + 1 | |||
# interp method | |||
if interp_method not in [_INTERP_LINEAR, _INTERP_ROUND]: | |||
raise UnsupportedError(f'Unsupported interpolation method {interp_method}, '
f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}') | |||
self.interp_method = interp_method | |||
# other variables | |||
self._idx = bm.Variable(bm.asarray([0])) | |||
# time variables | |||
self.idx = Variable(ops.asarray([0]))
check_float(t0, 't0', allow_none=False, allow_int=True, ) | |||
self._current_time = bm.Variable(bm.asarray([t0])) | |||
self.current_time = Variable(ops.asarray([t0])) | |||
# delay data | |||
self._data = bm.Variable(bm.zeros((self.num_delay_steps,) + self.shape, dtype=dtype)) | |||
self.data = Variable(ops.zeros((self.num_delay_step,) + self.shape, | |||
dtype=inits.dtype)) | |||
if before_t0 is None: | |||
self._before_type = _DATA_BEFORE | |||
elif callable(before_t0): | |||
self._before_t0 = lambda t: jnp.asarray(bm.broadcast_to(before_t0(t), self.shape).value, | |||
dtype=self.dtype) | |||
self._before_t0 = lambda t: jnp.asarray(ops.broadcast_to(before_t0(t), self.shape).value, | |||
dtype=inits.dtype) | |||
self._before_type = _FUNC_BEFORE | |||
elif isinstance(before_t0, (bm.ndarray, jnp.ndarray, float, int)): | |||
elif isinstance(before_t0, (ndarray, jnp.ndarray, float, int)): | |||
self._before_type = _DATA_BEFORE | |||
try: | |||
self._data[:] = before_t0 | |||
except: | |||
raise ValueError(f'Cannot set delay data by using "before_t0". ' | |||
f'The delay data has the shape of ' | |||
f'{((self.num_delay_steps,) + self.shape)}, while ' | |||
f'we got "before_t0" of {bm.asarray(before_t0).shape}. ' | |||
f'They are not compatible. Note that the delay length ' | |||
f'{self._delay_len} will automatically add a dt {self.dt} ' | |||
f'to {self.delay_len}.') | |||
self.data[:-1] = before_t0 | |||
else: | |||
raise ValueError(f'"before_t0" does not support {type(before_t0)}: before_t0') | |||
@property | |||
def idx(self): | |||
return self._idx | |||
@idx.setter | |||
def idx(self, value): | |||
raise ValueError('Cannot set "idx" by users.') | |||
@property | |||
def dt(self): | |||
return self._dt | |||
@dt.setter | |||
def dt(self, value): | |||
raise ValueError('Cannot set "dt" by users.') | |||
@property | |||
def data(self): | |||
return self._data | |||
@data.setter | |||
def data(self, value): | |||
self._data[:] = value | |||
raise ValueError(f'"before_t0" does not support {type(before_t0)}') | |||
# set initial data | |||
self.data[-1] = inits | |||
@property | |||
def current_time(self): | |||
return self._current_time[0] | |||
# interpolation function | |||
self.f = jnp.interp | |||
for dim in range(1, len(self.shape) + 1, 1): | |||
self.f = vmap(self.f, in_axes=(None, None, dim), out_axes=dim - 1) | |||
def _check_time(self, times, transforms): | |||
prev_time, current_time = times | |||
current_time = bm.as_device_array(current_time) | |||
prev_time = bm.as_device_array(prev_time) | |||
if prev_time > current_time: | |||
current_time = current_time[0] | |||
if prev_time > current_time + 1e-6: | |||
raise ValueError(f'\n' | |||
f'!!! Error in {self.__class__.__name__}: \n' | |||
f'The requested time should be less than the '
f'current time {current_time}. But we ' | |||
f'got {prev_time} > {current_time}') | |||
lower_time = jnp.asarray(current_time - self.delay_len) | |||
if prev_time < lower_time: | |||
lower_time = current_time - self.delay_len | |||
if prev_time < lower_time - self.dt: | |||
raise ValueError(f'\n' | |||
f'!!! Error in {self.__class__.__name__}: \n' | |||
f'The requested time of the variable should be in '
f'[{lower_time}, {current_time}], but we got {prev_time}') | |||
def __call__(self, prev_time): | |||
def __call__(self, time, indices=None): | |||
# check | |||
id_tap(self._check_time, (prev_time, self.current_time)) | |||
if check.is_checking(): | |||
id_tap(self._check_time, (time, self.current_time)) | |||
if self._before_type == _FUNC_BEFORE: | |||
return cond(prev_time < self.t0, | |||
return cond(time < self.t0, | |||
self._before_t0, | |||
self._fn1, | |||
prev_time) | |||
self._after_t0, | |||
time) | |||
else: | |||
return self._fn1(prev_time) | |||
def _fn1(self, prev_time): | |||
diff = self.delay_len - (self.current_time - prev_time) | |||
if isinstance(diff, bm.ndarray): diff = diff.value | |||
req_num_step = jnp.asarray(diff / self._dt, dtype=bm.get_dint()) | |||
extra = diff - req_num_step * self._dt | |||
return self._after_t0(time) | |||
def _after_t0(self, prev_time): | |||
diff = self.delay_len - (self.current_time[0] - prev_time) | |||
if isinstance(diff, ndarray): | |||
diff = diff.value | |||
if self.interp_method == _INTERP_LINEAR: | |||
req_num_step = jnp.asarray(diff / self.dt, dtype=ops.int32) | |||
extra = diff - req_num_step * self.dt | |||
return cond(extra == 0., self._true_fn, self._false_fn, (req_num_step, extra)) | |||
elif self.interp_method == _INTERP_ROUND: | |||
req_num_step = jnp.asarray(jnp.round(diff / self.dt), dtype=ops.int32) | |||
return self._true_fn([req_num_step, 0.]) | |||
else: | |||
raise UnsupportedError(f'Unsupported interpolation method {self.interp_method}, '
f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}') | |||
def _true_fn(self, div_mod): | |||
req_num_step, extra = div_mod | |||
return self._data[self.idx[0] + req_num_step] | |||
return self.data[self.idx[0] + req_num_step] | |||
def _false_fn(self, div_mod): | |||
req_num_step, extra = div_mod | |||
f = jnp.interp | |||
for dim in range(1, len(self.shape) + 1, 1): | |||
f = vmap(f, in_axes=(None, None, dim), out_axes=dim - 1) | |||
idx = jnp.asarray([self.idx[0] + req_num_step, | |||
self.idx[0] + req_num_step + 1]) | |||
idx %= self.num_delay_steps | |||
return f(extra, jnp.asarray([0., self._dt]), self._data[idx]) | |||
idx %= self.num_delay_step | |||
return self.f(extra, jnp.asarray([0., self.dt]), self.data[idx]) | |||
def update(self, time, value): | |||
self._data[self._idx[0]] = value | |||
# check_float(time, 'time', allow_none=False, allow_int=True) | |||
self._current_time[0] = time | |||
self._idx.value = (self._idx + 1) % self.num_delay_steps | |||
self.data[self.idx[0]] = value | |||
self.current_time[0] = time | |||
self.idx.value = (self.idx + 1) % self.num_delay_step | |||
class NeutralDelay(TimeDelay): | |||
pass | |||
class VariedLenDelay(AbstractDelay): | |||
"""Delay variable which has a functional delay | |||
class LengthDelay(AbstractDelay): | |||
"""Delay variable which has a fixed delay length. | |||
Parameters | |||
---------- | |||
inits: Tensor
The initial delay data. | |||
delay_len: int | |||
The maximum delay length. | |||
delay_data: Tensor | |||
The delay data. | |||
name: str | |||
The delay object name. | |||
See Also | |||
-------- | |||
TimeDelay | |||
""" | |||
def update(self, time, value): | |||
pass | |||
def __init__( | |||
self, | |||
inits: Union[ndarray, jnp.ndarray], | |||
delay_len: int, | |||
delay_data: Union[ndarray, jnp.ndarray, float, int] = None, | |||
name: str = None, | |||
): | |||
super(LengthDelay, self).__init__(name=name) | |||
self.init(inits, delay_len, delay_data) | |||
def init(self, inits, delay_len, delay_data): | |||
assert isinstance(inits, (ndarray, np.ndarray)), (f'Must be an instance of brainpy.math.ndarray ' | |||
f'or jax.numpy.ndarray. But we got {type(inits)}') | |||
self.shape = inits.shape | |||
def __init__(self): | |||
super(VariedLenDelay, self).__init__() | |||
# delay_len | |||
check_integer(delay_len, 'delay_len', allow_none=False, min_bound=0) | |||
self.delay_len = delay_len | |||
self.num_delay_step = delay_len + 1 | |||
# time variables | |||
self.idx = Variable(ops.asarray([0], dtype=ops.int32)) | |||
class NeutralDelay(FixedLenDelay): | |||
# delay data | |||
self.data = Variable(ops.zeros((self.num_delay_step,) + self.shape, | |||
dtype=inits.dtype)) | |||
if delay_data is None: | |||
pass | |||
elif isinstance(delay_data, (ndarray, jnp.ndarray, float, int)): | |||
self.data[:-1] = delay_data | |||
else: | |||
raise ValueError(f'"delay_data" does not support {type(delay_data)}') | |||
def _check_delay(self, delay_len, transforms): | |||
if isinstance(delay_len, ndarray): | |||
delay_len = delay_len.value | |||
if np.any(delay_len >= self.num_delay_step): | |||
raise ValueError(f'\n' | |||
f'!!! Error in {self.__class__.__name__}: \n' | |||
f'The requested delay length should be less than the '
f'maximum delay {self.delay_len}. But we ' | |||
f'got {delay_len}') | |||
def __call__(self, delay_len, indices=None): | |||
# check | |||
if check.is_checking(): | |||
id_tap(self._check_delay, delay_len) | |||
# the delay length | |||
delay_idx = (self.idx[0] - delay_len - 1) % self.num_delay_step | |||
if delay_idx.dtype not in [ops.int32, ops.int64]: | |||
raise ValueError(f'"delay_len" must be integer, but we got {delay_len}') | |||
# the delay data | |||
if indices is None: | |||
return self.data[delay_idx] | |||
else: | |||
return self.data[delay_idx, indices] | |||
def update(self, value): | |||
if ops.shape(value) != self.shape: | |||
raise ValueError(f'value shape should be {self.shape}, but we got {ops.shape(value)}') | |||
self.data[self.idx[0]] = value | |||
self.idx.value = (self.idx + 1) % self.num_delay_step |
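A minimal usage sketch of ``LengthDelay`` (mirroring the tests in this changeset):
>>> import brainpy.math as bm
>>> delay = bm.LengthDelay(bm.zeros(3), 10, delay_data=bm.arange(1, 11).reshape((10, 1)))
>>> delay(0)                   # no delay -> [0. 0. 0.]
>>> delay(1)                   # one step back -> [10. 10. 10.]
>>> delay.update(bm.ones(3))   # push a new value; the ring-buffer index advances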
@@ -79,7 +79,7 @@ __all__ = [ | |||
'setxor1d', 'tensordot', 'trim_zeros', 'union1d', 'unravel_index', 'unwrap', 'take_along_axis', | |||
# others | |||
'clip_by_norm', 'as_device_array', 'as_variable', 'as_jaxarray', 'as_numpy', | |||
'clip_by_norm', 'as_device_array', 'as_variable', 'as_numpy', | |||
] | |||
_min = min | |||
@@ -89,6 +89,10 @@ _max = max | |||
# others | |||
# ------ | |||
# def as_jax_array(tensor): | |||
# return asarray(tensor) | |||
def as_device_array(tensor): | |||
if isinstance(tensor, JaxArray): | |||
return tensor.value | |||
@@ -111,10 +115,6 @@ def as_variable(tensor): | |||
return Variable(asarray(tensor)) | |||
def as_jaxarray(tensor): | |||
return asarray(tensor) | |||
def _remove_jaxarray(obj): | |||
if isinstance(obj, JaxArray): | |||
return obj.value | |||
@@ -1507,10 +1507,10 @@ def vander(x, N=None, increasing=False): | |||
def fill_diagonal(a, val): | |||
a = _remove_jaxarray(a) | |||
assert a.ndim >= 2 | |||
assert isinstance(a, JaxArray), f'Must be a JaxArray, but got {type(a)}' | |||
assert a.ndim >= 2, f'Only support tensor has dimension >= 2, but got {a.shape}' | |||
i, j = jnp.diag_indices(_min(a.shape[-2:])) | |||
return JaxArray(a.at[..., i, j].set(val)) | |||
a._value = a.value.at[..., i, j].set(val) | |||
# indexing funcs | |||
@@ -27,6 +27,7 @@ from brainpy import errors | |||
from brainpy.base.base import Base | |||
from brainpy.base.collector import TensorCollector | |||
from brainpy.math.random import RandomState | |||
from brainpy.math.jaxarray import JaxArray | |||
from brainpy.tools.codes import change_func_name | |||
__all__ = [ | |||
@@ -35,29 +36,31 @@ __all__ = [ | |||
] | |||
def _make_vmap(func, dyn_vars, rand_vars, in_axes, out_axes, | |||
batch_idx, axis_name, reduce_func, f_name=None): | |||
def _make_vmap(func, nonbatched_vars, batched_vars, in_axes, out_axes, | |||
batch_idx, axis_name, f_name=None): | |||
@functools.partial(jax.vmap, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name) | |||
def vmapped_func(dyn_data, rand_data, *args, **kwargs): | |||
dyn_vars.assign(dyn_data) | |||
rand_vars.assign(rand_data) | |||
def vmapped_func(nonbatched_data, batched_data, *args, **kwargs): | |||
nonbatched_vars.assign(nonbatched_data) | |||
batched_vars.assign(batched_data) | |||
out = func(*args, **kwargs) | |||
dyn_changes = dyn_vars.dict() | |||
rand_changes = rand_vars.dict() | |||
return out, dyn_changes, rand_changes | |||
nonbatched_changes = nonbatched_vars.dict() | |||
batched_changes = batched_vars.dict() | |||
return nonbatched_changes, batched_changes, out | |||
def call(*args, **kwargs): | |||
dyn_data = dyn_vars.dict() | |||
n = args[batch_idx[0]].shape[batch_idx[1]] | |||
rand_data = {key: val.split_keys(n) for key, val in rand_vars.items()} | |||
nonbatched_data = nonbatched_vars.dict() | |||
batched_data = {key: val.split_keys(n) for key, val in batched_vars.items()} | |||
try: | |||
out, dyn_changes, rand_changes = vmapped_func(dyn_data, rand_data, *args, **kwargs) | |||
nonbatched_changes, batched_changes, out = vmapped_func(nonbatched_data, batched_data, *args, **kwargs)
except UnexpectedTracerError as e: | |||
dyn_vars.assign(dyn_data) | |||
rand_vars.assign(rand_data) | |||
raise errors.JaxTracerError(variables=dyn_vars) from e | |||
for key, v in dyn_changes.items(): dyn_vars[key] = reduce_func(v) | |||
for key, v in rand_changes.items(): rand_vars[key] = reduce_func(v) | |||
nonbatched_vars.assign(nonbatched_data) | |||
batched_vars.assign(batched_data) | |||
raise errors.JaxTracerError() from e | |||
# for key, v in dyn_changes.items(): | |||
# dyn_vars[key] = reduce_func(v) | |||
# for key, v in rand_changes.items(): | |||
# rand_vars[key] = reduce_func(v) | |||
return out | |||
return change_func_name(name=f_name, f=call) if f_name else call | |||
@@ -77,7 +80,7 @@ def vmap(func, dyn_vars=None, batched_vars=None, | |||
---------- | |||
func : Base, function, callable | |||
The function or the module to compile. | |||
dyn_vars : dict | |||
dyn_vars : dict, sequence | |||
batched_vars : dict | |||
in_axes : optional, int, sequence of int | |||
Specify which input array axes to map over. If each positional argument to | |||
@@ -207,13 +210,19 @@ def vmap(func, dyn_vars=None, batched_vars=None, | |||
axis_name=axis_name) | |||
else: | |||
if isinstance(dyn_vars, JaxArray): | |||
dyn_vars = [dyn_vars] | |||
if isinstance(dyn_vars, (tuple, list)): | |||
dyn_vars = {f'_vmap_v{i}': v for i, v in enumerate(dyn_vars)} | |||
assert isinstance(dyn_vars, dict) | |||
# dynamical variables | |||
dyn_vars, rand_vars = TensorCollector(), TensorCollector() | |||
_dyn_vars, _rand_vars = TensorCollector(), TensorCollector() | |||
for key, val in dyn_vars.items(): | |||
if isinstance(val, RandomState): | |||
rand_vars[key] = val | |||
_rand_vars[key] = val | |||
else: | |||
dyn_vars[key] = val | |||
_dyn_vars[key] = val | |||
# in axes | |||
if in_axes is None: | |||
@@ -249,13 +258,12 @@ def vmap(func, dyn_vars=None, batched_vars=None, | |||
# jit function | |||
return _make_vmap(func=func, | |||
dyn_vars=dyn_vars, | |||
rand_vars=rand_vars, | |||
nonbatched_vars=_dyn_vars, | |||
batched_vars=_rand_vars, | |||
in_axes=in_axes, | |||
out_axes=out_axes, | |||
axis_name=axis_name, | |||
batch_idx=batch_idx, | |||
reduce_func=reduce_func) | |||
batch_idx=batch_idx) | |||
else: | |||
raise errors.BrainPyError(f'Only support instance of {Base.__name__}, or a callable ' | |||
@@ -1,71 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
import jax.numpy as jnp | |||
from brainpy.tools.checking import check_integer | |||
__all__ = [ | |||
'poch', | |||
'Gamma', | |||
'Beta', | |||
] | |||
def poch(a, n): | |||
""" Returns the Pochhammer symbol (a)_n. """ | |||
# First, check if 'a' is a real number (this is currently only working for reals). | |||
assert not isinstance(a, complex), "a must be real: %r" % a | |||
check_integer(n, allow_none=False, min_bound=0) | |||
# Compute the Pochhammer symbol. | |||
return 1.0 if n == 0 else jnp.prod(jnp.arange(n) + a) | |||
def Gamma(z): | |||
""" Paul Godfrey's Gamma function implementation valid for z complex. | |||
This is converted from Godfrey's Gamma.m Matlab file available at | |||
https://www.mathworks.com/matlabcentral/fileexchange/3572-gamma. | |||
15 significant digits of accuracy for real z and 13 | |||
significant digits for other values. | |||
""" | |||
zz = z | |||
# Find negative real parts of z and make them positive. | |||
if isinstance(z, (complex, jnp.complex64, jnp.complex128)): | |||
Z = [z.real, z.imag] | |||
if Z[0] < 0: | |||
Z[0] = -Z[0] | |||
z = jnp.asarray(Z) | |||
z = z.astype(complex) | |||
g = 607 / 128. | |||
c = jnp.asarray([0.99999999999999709182, 57.156235665862923517, -59.597960355475491248, | |||
14.136097974741747174, -0.49191381609762019978, .33994649984811888699e-4, | |||
.46523628927048575665e-4, -.98374475304879564677e-4, .15808870322491248884e-3, | |||
-.21026444172410488319e-3, .21743961811521264320e-3, -.16431810653676389022e-3, | |||
.84418223983852743293e-4, -.26190838401581408670e-4, .36899182659531622704e-5]) | |||
if z == 0 or z == 1: | |||
return 1. | |||
if ((jnp.round(zz) == zz) | |||
and (zz.imag == 0) | |||
and (zz.real <= 0)): # Adjust for negative poles. | |||
return jnp.inf | |||
z = z - 1 | |||
zh = z + 0.5 | |||
zgh = zh + g | |||
zp = zgh ** (zh * 0.5) # Trick for avoiding floating-point overflow above z = 141. | |||
idx = jnp.arange(len(c) - 1, 0, -1) | |||
ss = jnp.sum(c[idx] / (z + idx)) | |||
sq2pi = 2.5066282746310005024157652848110 | |||
f = (sq2pi * (c[0] + ss)) * ((zp * jnp.exp(-zgh)) * zp) | |||
if isinstance(zz, (complex, jnp.complex64, jnp.complex128)): | |||
return f.astype(complex) | |||
elif isinstance(zz, int) and zz >= 0: | |||
f = jnp.round(f) | |||
return f.astype(int) | |||
else: | |||
return f | |||
def Beta(x, y): | |||
""" Beta function using Godfrey's Gamma function. """ | |||
return Gamma(x) * Gamma(y) / Gamma(x + y) |
@@ -5,50 +5,66 @@ import unittest | |||
import brainpy.math as bm | |||
class TestFixedLenDelay(unittest.TestCase): | |||
class TestTimeDelay(unittest.TestCase): | |||
def test_dim1(self): | |||
bm.enable_x64() | |||
# linear interp | |||
t0 = 0. | |||
before_t0 = bm.repeat(bm.arange(11).reshape((-1, 1)), 10, axis=1) | |||
delay = bm.FixedLenDelay(10, delay_len=1., t0=t0, dt=0.1, before_t0=before_t0) | |||
self.assertTrue(bm.array_equal(delay(t0 - 0.1), bm.ones(10) * 10)) | |||
self.assertTrue(bm.array_equal(delay(t0 - 0.15), bm.ones(10) * 9.5)) | |||
before_t0 = bm.repeat(bm.arange(10).reshape((-1, 1)), 10, axis=1) | |||
delay = bm.TimeDelay(bm.zeros(10), delay_len=1., t0=t0, dt=0.1, before_t0=before_t0) | |||
print(delay(t0 - 0.1)) | |||
print(delay(t0 - 0.15)) | |||
self.assertTrue(bm.array_equal(delay(t0 - 0.1), bm.ones(10) * 9.)) | |||
self.assertTrue(bm.array_equal(delay(t0 - 0.15), bm.ones(10) * 8.5)) | |||
print() | |||
print(delay(t0 - 0.23)) | |||
print(delay(t0 - 0.23) - bm.ones(10) * 8.7) | |||
# self.assertTrue(bm.array_equal(delay(t0 - 0.23), bm.ones(10) * 8.7)) | |||
# round interp | |||
delay = bm.TimeDelay(bm.zeros(10), delay_len=1., t0=t0, dt=0.1, before_t0=before_t0, | |||
interp_method='round') | |||
self.assertTrue(bm.array_equal(delay(t0 - 0.1), bm.ones(10) * 9)) | |||
print(delay(t0 - 0.15)) | |||
self.assertTrue(bm.array_equal(delay(t0 - 0.15), bm.ones(10) * 8)) | |||
self.assertTrue(bm.array_equal(delay(t0 - 0.2), bm.ones(10) * 8)) | |||
def test_dim2(self): | |||
t0 = 0. | |||
before_t0 = bm.repeat(bm.arange(11).reshape((-1, 1)), 10, axis=1) | |||
before_t0 = bm.repeat(before_t0.reshape((11, 10, 1)), 5, axis=2) | |||
delay = bm.FixedLenDelay((10, 5), delay_len=1., t0=t0, dt=0.1, before_t0=before_t0) | |||
self.assertTrue(bm.array_equal(delay(t0 - 0.1), bm.ones((10, 5)) * 10)) | |||
self.assertTrue(bm.array_equal(delay(t0 - 0.15), bm.ones((10, 5)) * 9.5)) | |||
before_t0 = bm.repeat(bm.arange(10).reshape((-1, 1)), 10, axis=1) | |||
before_t0 = bm.repeat(before_t0.reshape((10, 10, 1)), 5, axis=2) | |||
delay = bm.TimeDelay(bm.zeros((10, 5)), delay_len=1., t0=t0, dt=0.1, before_t0=before_t0) | |||
self.assertTrue(bm.array_equal(delay(t0 - 0.1), bm.ones((10, 5)) * 9)) | |||
self.assertTrue(bm.array_equal(delay(t0 - 0.15), bm.ones((10, 5)) * 8.5)) | |||
# self.assertTrue(bm.array_equal(delay(t0 - 0.23), bm.ones((10, 5)) * 8.7)) | |||
def test_dim3(self): | |||
t0 = 0. | |||
before_t0 = bm.repeat(bm.arange(11).reshape((-1, 1)), 10, axis=1) | |||
before_t0 = bm.repeat(before_t0.reshape((11, 10, 1)), 5, axis=2) | |||
before_t0 = bm.repeat(before_t0.reshape((11, 10, 5, 1)), 3, axis=3) | |||
delay = bm.FixedLenDelay((10, 5, 3), delay_len=1., t0=t0, dt=0.1, before_t0=before_t0) | |||
self.assertTrue(bm.array_equal(delay(t0 - 0.1), bm.ones((10, 5, 3)) * 10)) | |||
self.assertTrue(bm.array_equal(delay(t0 - 0.15), bm.ones((10, 5, 3)) * 9.5)) | |||
before_t0 = bm.repeat(bm.arange(10).reshape((-1, 1)), 10, axis=1) | |||
before_t0 = bm.repeat(before_t0.reshape((10, 10, 1)), 5, axis=2) | |||
before_t0 = bm.repeat(before_t0.reshape((10, 10, 5, 1)), 3, axis=3) | |||
delay = bm.TimeDelay(bm.zeros((10, 5, 3)), delay_len=1., t0=t0, dt=0.1, before_t0=before_t0) | |||
self.assertTrue(bm.array_equal(delay(t0 - 0.1), bm.ones((10, 5, 3)) * 9)) | |||
self.assertTrue(bm.array_equal(delay(t0 - 0.15), bm.ones((10, 5, 3)) * 8.5)) | |||
# self.assertTrue(bm.array_equal(delay(t0 - 0.23), bm.ones((10, 5, 3)) * 8.7)) | |||
def test1(self): | |||
print() | |||
delay = bm.FixedLenDelay(3, delay_len=1., dt=0.1, before_t0=lambda t: t) | |||
delay = bm.TimeDelay(bm.zeros(3), delay_len=1., dt=0.1, before_t0=lambda t: t) | |||
print(delay(-0.2)) | |||
delay = bm.FixedLenDelay((3, 2), delay_len=1., dt=0.1, before_t0=lambda t: t) | |||
delay = bm.TimeDelay(bm.zeros((3, 2)), delay_len=1., dt=0.1, before_t0=lambda t: t) | |||
print(delay(-0.6)) | |||
delay = bm.FixedLenDelay((3, 2, 1), delay_len=1., dt=0.1, before_t0=lambda t: t) | |||
delay = bm.TimeDelay(bm.zeros((3, 2, 1)), delay_len=1., dt=0.1, before_t0=lambda t: t) | |||
print(delay(-0.8)) | |||
def test_current_time2(self): | |||
print() | |||
delay = bm.FixedLenDelay(3, delay_len=1., dt=0.1, before_t0=lambda t: t) | |||
delay = bm.TimeDelay(bm.zeros(3), delay_len=1., dt=0.1, before_t0=lambda t: t) | |||
print(delay(0.)) | |||
before_t0 = bm.repeat(bm.arange(11).reshape((-1, 1)), 10, axis=1) | |||
before_t0 = bm.repeat(before_t0.reshape((11, 10, 1)), 5, axis=2) | |||
delay = bm.FixedLenDelay((10, 5), delay_len=1., dt=0.1, before_t0=before_t0) | |||
before_t0 = bm.repeat(bm.arange(10).reshape((-1, 1)), 10, axis=1) | |||
before_t0 = bm.repeat(before_t0.reshape((10, 10, 1)), 5, axis=2) | |||
delay = bm.TimeDelay(bm.zeros((10, 5)), delay_len=1., dt=0.1, before_t0=before_t0) | |||
print(delay(0.)) | |||
# def test_prev_time_beyond_boundary(self): | |||
@@ -56,3 +72,42 @@ class TestFixedLenDelay(unittest.TestCase): | |||
# delay = bm.FixedLenDelay(3, delay_len=1., dt=0.1, before_t0=lambda t: t) | |||
# delay(-1.2) | |||
class TestLengthDelay(unittest.TestCase): | |||
def test1(self): | |||
dim = 3 | |||
delay = bm.LengthDelay(bm.zeros(dim), 10) | |||
print(delay(1)) | |||
self.assertTrue(bm.array_equal(delay(1), bm.zeros(dim))) | |||
delay = bm.jit(delay) | |||
print(delay(1)) | |||
self.assertTrue(bm.array_equal(delay(1), bm.zeros(dim))) | |||
def test2(self): | |||
dim = 3 | |||
delay = bm.LengthDelay(bm.zeros(dim), 10, delay_data=bm.arange(1, 11).reshape((10, 1))) | |||
print(delay(0)) | |||
self.assertTrue(bm.array_equal(delay(0), bm.zeros(dim))) | |||
print(delay(1)) | |||
self.assertTrue(bm.array_equal(delay(1), bm.ones(dim) * 10)) | |||
delay = bm.jit(delay) | |||
print(delay(0)) | |||
self.assertTrue(bm.array_equal(delay(0), bm.zeros(dim))) | |||
print(delay(1)) | |||
self.assertTrue(bm.array_equal(delay(1), bm.ones(dim) * 10)) | |||
def test3(self): | |||
dim = 3 | |||
delay = bm.LengthDelay(bm.zeros(dim), 10, delay_data=bm.arange(1, 11).reshape((10, 1))) | |||
print(delay(bm.asarray([1, 2, 3]), | |||
bm.arange(3))) | |||
# self.assertTrue(bm.array_equal(delay(0), bm.zeros(dim))) | |||
delay = bm.jit(delay) | |||
print(delay(bm.asarray([1, 2, 3]), | |||
bm.arange(3))) | |||
# self.assertTrue(bm.array_equal(delay(1), bm.ones(dim) * 10)) | |||
@@ -5,207 +5,6 @@ This module aims to provide commonly used analysis methods for simulated neurona | |||
You can access them through ``brainpy.measure.XXX``. | |||
""" | |||
from .correlation import * | |||
from .firings import * | |||
import numpy as np | |||
from brainpy import tools, math | |||
__all__ = [ | |||
'cross_correlation', | |||
'voltage_fluctuation', | |||
'raster_plot', | |||
'firing_rate', | |||
] | |||
# @tools.numba_jit | |||
def _cc(states, i, j): | |||
sqrt_ij = np.sqrt(np.sum(states[i]) * np.sum(states[j])) | |||
k = 0. if sqrt_ij == 0. else np.sum(states[i] * states[j]) / sqrt_ij | |||
return k | |||
def cross_correlation(spikes, bin, dt=None): | |||
r"""Calculate cross correlation index between neurons. | |||
The coherence [1]_ between two neurons i and j is measured by their | |||
cross-correlation of spike trains at zero time lag within a time bin | |||
of :math:`\Delta t = \tau`. More specifically, suppose that a long | |||
time interval T is divided into small bins of :math:`\Delta t` and | |||
that two spike trains are given by :math:`X(l)=` 0 or 1, :math:`Y(l)=` 0 | |||
or 1, :math:`l=1,2, \ldots, K(T / K=\tau)`. Thus, we define a coherence | |||
measure for the pair as: | |||
.. math:: | |||
\kappa_{i j}(\tau)=\frac{\sum_{l=1}^{K} X(l) Y(l)} | |||
{\sqrt{\sum_{l=1}^{K} X(l) \sum_{l=1}^{K} Y(l)}} | |||
The population coherence measure :math:`\kappa(\tau)` is defined by the | |||
average of :math:`\kappa_{i j}(\tau)` over many pairs of neurons in the | |||
network. | |||
Parameters | |||
---------- | |||
spikes : | |||
The history of spike states of the neuron group. | |||
It can be easily get via `StateMonitor(neu, ['spike'])`. | |||
bin : float, int | |||
The time bin to normalize spike states. | |||
dt : float, optional | |||
The time precision. | |||
Returns | |||
------- | |||
cc_index : float | |||
The cross correlation value which represents the synchronization index. | |||
References | |||
---------- | |||
.. [1] Wang, Xiao-Jing, and György Buzsáki. "Gamma oscillation by synaptic | |||
inhibition in a hippocampal interneuronal network model." Journal of | |||
neuroscience 16.20 (1996): 6402-6413. | |||
""" | |||
spikes = np.asarray(spikes) | |||
dt = math.get_dt() if dt is None else dt | |||
bin_size = int(bin / dt) | |||
num_hist, num_neu = spikes.shape | |||
num_bin = int(np.ceil(num_hist / bin_size)) | |||
if num_bin * bin_size != num_hist: | |||
spikes = np.append(spikes, np.zeros((num_bin * bin_size - num_hist, num_neu)), axis=0) | |||
states = spikes.T.reshape((num_neu, num_bin, bin_size)) | |||
states = (np.sum(states, axis=2) > 0.).astype(np.float_) | |||
all_k = [] | |||
for i in range(num_neu): | |||
for j in range(i + 1, num_neu): | |||
all_k.append(_cc(states, i, j)) | |||
return np.mean(all_k) | |||
# @tools.numba_jit | |||
def _var(neu_signal): | |||
return np.mean(neu_signal * neu_signal) - np.mean(neu_signal) ** 2 | |||
def voltage_fluctuation(potentials): | |||
r"""Calculate neuronal synchronization via voltage variance. | |||
The method comes from [1]_ [2]_ [3]_. | |||
First, average over the membrane potential :math:`V` | |||
.. math:: | |||
V(t) = \frac{1}{N} \sum_{i=1}^{N} V_i(t) | |||
The variance of the time fluctuations of :math:`V(t)` is | |||
.. math:: | |||
\sigma_V^2 = \left\langle \left[ V(t) \right]^2 \right\rangle_t - | |||
\left[ \left\langle V(t) \right\rangle_t \right]^2 | |||
where :math:`\left\langle \ldots \right\rangle_t = (1 / T_m) \int_0^{T_m} dt \, \ldots` | |||
denotes time-averaging over a large time, :math:`\tau_m`. After normalization | |||
of :math:`\sigma_V` to the average over the population of the single cell | |||
membrane potentials | |||
.. math:: | |||
\sigma_{V_i}^2 = \left\langle\left[ V_i(t) \right]^2 \right\rangle_t - | |||
\left[ \left\langle V_i(t) \right\rangle_t \right]^2 | |||
one defines a synchrony measure, :math:`\chi (N)`, for the activity of a system | |||
of :math:`N` neurons by: | |||
.. math:: | |||
\chi^2 \left( N \right) = \frac{\sigma_V^2}{ \frac{1}{N} \sum_{i=1}^N | |||
\sigma_{V_i}^2} | |||
Parameters | |||
---------- | |||
potentials : | |||
The membrane potential matrix of the neuron group. | |||
Returns | |||
------- | |||
sync_index : float | |||
The synchronization index. | |||
References | |||
---------- | |||
.. [1] Golomb, D. and Rinzel J. (1993) Dynamics of globally coupled | |||
inhibitory neurons with heterogeneity. Phys. Rev. E 48:4810-4814.
.. [2] Golomb D. and Rinzel J. (1994) Clustering in globally coupled | |||
inhibitory neurons. Physica D 72:259-282. | |||
.. [3] David Golomb (2007) Neuronal synchrony measures. Scholarpedia, 2(1):1347. | |||
""" | |||
potentials = np.asarray(potentials) | |||
num_hist, num_neu = potentials.shape | |||
avg = np.mean(potentials, axis=1) | |||
avg_var = np.mean(avg * avg) - np.mean(avg) ** 2 | |||
neu_vars = [] | |||
for i in range(num_neu): | |||
neu_vars.append(_var(potentials[:, i])) | |||
var_mean = np.mean(neu_vars) | |||
return avg_var / var_mean if var_mean != 0. else 1. | |||
def raster_plot(sp_matrix, times): | |||
"""Get spike raster plot which displays the spiking activity | |||
of a group of neurons over time. | |||
Parameters | |||
---------- | |||
sp_matrix : bnp.ndarray | |||
The matrix which record spiking activities. | |||
times : bnp.ndarray | |||
The time steps. | |||
Returns | |||
------- | |||
raster_plot : tuple | |||
Include (neuron index, spike time). | |||
""" | |||
sp_matrix = np.asarray(sp_matrix) | |||
times = np.asarray(times) | |||
elements = np.where(sp_matrix > 0.) | |||
index = elements[1] | |||
time = times[elements[0]] | |||
return index, time | |||
def firing_rate(sp_matrix, width, dt=None): | |||
r"""Calculate the mean firing rate over in a neuron group. | |||
This method is adopted from Brian2. | |||
The firing rate in trial :math:`k` is the spike count :math:`n_{k}^{sp}` | |||
in an interval of duration :math:`T` divided by :math:`T`: | |||
.. math:: | |||
v_k = {n_k^{sp} \over T} | |||
Parameters | |||
---------- | |||
sp_matrix : math.JaxArray, np.ndarray | |||
The spike matrix which record spiking activities. | |||
width : int, float | |||
The width of the ``window`` in millisecond. | |||
dt : float, optional | |||
The sample rate. | |||
Returns | |||
------- | |||
rate : numpy.ndarray | |||
The population rate in Hz, smoothed with the given window. | |||
""" | |||
sp_matrix = np.asarray(sp_matrix) | |||
rate = np.sum(sp_matrix, axis=1) / sp_matrix.shape[1] | |||
dt = math.get_dt() if dt is None else dt | |||
width1 = int(width / 2 / dt) * 2 + 1 | |||
window = np.ones(width1) * 1000 / width | |||
return np.convolve(rate, window, mode='same') |
@@ -0,0 +1,270 @@ | |||
# -*- coding: utf-8 -*- | |||
from functools import partial | |||
import numpy as np | |||
from jax import vmap, jit, lax, numpy as jnp | |||
from brainpy import math as bm | |||
__all__ = [ | |||
'cross_correlation', | |||
'voltage_fluctuation', | |||
'matrix_correlation', | |||
'weighted_correlation', | |||
'functional_connectivity', | |||
'functional_connectivity_dynamics', | |||
] | |||
@jit | |||
@partial(vmap, in_axes=(None, 0, 0)) | |||
def _cc(states, i, j): | |||
sqrt_ij = jnp.sqrt(jnp.sum(states[i]) * jnp.sum(states[j])) | |||
return lax.cond(sqrt_ij == 0., | |||
lambda _: 0., | |||
lambda ij: jnp.sum(states[i] * states[j]) / sqrt_ij, | |||
(i, j)) | |||
def cross_correlation(spikes, bin, dt=None): | |||
r"""Calculate cross correlation index between neurons. | |||
The coherence [1]_ between two neurons i and j is measured by their | |||
cross-correlation of spike trains at zero time lag within a time bin | |||
of :math:`\Delta t = \tau`. More specifically, suppose that a long | |||
time interval T is divided into small bins of :math:`\Delta t` and | |||
that two spike trains are given by :math:`X(l)=` 0 or 1, :math:`Y(l)=` 0 | |||
or 1, :math:`l=1,2, \ldots, K(T / K=\tau)`. Thus, we define a coherence | |||
measure for the pair as: | |||
.. math:: | |||
\kappa_{i j}(\tau)=\frac{\sum_{l=1}^{K} X(l) Y(l)} | |||
{\sqrt{\sum_{l=1}^{K} X(l) \sum_{l=1}^{K} Y(l)}} | |||
The population coherence measure :math:`\kappa(\tau)` is defined by the | |||
average of :math:`\kappa_{i j}(\tau)` over many pairs of neurons in the | |||
network. | |||
Parameters | |||
---------- | |||
spikes : | |||
The history of spike states of the neuron group. | |||
It can easily be obtained via `StateMonitor(neu, ['spike'])`.
bin : float, int | |||
The time bin to normalize spike states. | |||
dt : float, optional | |||
The time precision. | |||
Returns | |||
------- | |||
cc_index : float | |||
The cross correlation value which represents the synchronization index. | |||
References | |||
---------- | |||
.. [1] Wang, Xiao-Jing, and György Buzsáki. "Gamma oscillation by synaptic | |||
inhibition in a hippocampal interneuronal network model." Journal of | |||
neuroscience 16.20 (1996): 6402-6413. | |||
""" | |||
spikes = bm.asarray(spikes) | |||
dt = bm.get_dt() if dt is None else dt | |||
bin_size = int(bin / dt) | |||
num_hist, num_neu = spikes.shape | |||
num_bin = int(np.ceil(num_hist / bin_size)) | |||
if num_bin * bin_size != num_hist: | |||
spikes = bm.append(spikes, bm.zeros((num_bin * bin_size - num_hist, num_neu)), axis=0) | |||
states = spikes.T.reshape((num_neu, num_bin, bin_size)) | |||
states = bm.asarray(bm.sum(states, axis=2) > 0., dtype=jnp.float_) | |||
indices = jnp.tril_indices(num_neu, k=-1)  # all distinct (i, j) neuron pairs with i > j
return jnp.mean(_cc(states.value, *indices)) | |||
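Usage sketch (the limiting cases match the tests in this changeset):
>>> import brainpy as bp
>>> bp.measure.cross_correlation(bp.math.ones((1000, 10)), bin=1.)   # fully synchronous -> 1.0
>>> bp.measure.cross_correlation(bp.math.zeros((1000, 10)), bin=1.)  # silent -> 0.0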
@partial(vmap, in_axes=(None, 0)) | |||
def _var(neu_signal, i): | |||
neu_signal = neu_signal[:, i] | |||
return jnp.mean(neu_signal * neu_signal) - jnp.mean(neu_signal) ** 2 | |||
@jit | |||
def voltage_fluctuation(potentials): | |||
r"""Calculate neuronal synchronization via voltage variance. | |||
The method comes from [1]_ [2]_ [3]_. | |||
First, average over the membrane potential :math:`V` | |||
.. math:: | |||
V(t) = \frac{1}{N} \sum_{i=1}^{N} V_i(t) | |||
The variance of the time fluctuations of :math:`V(t)` is | |||
.. math:: | |||
\sigma_V^2 = \left\langle \left[ V(t) \right]^2 \right\rangle_t - | |||
\left[ \left\langle V(t) \right\rangle_t \right]^2 | |||
where :math:`\left\langle \ldots \right\rangle_t = (1 / T_m) \int_0^{T_m} dt \, \ldots` | |||
denotes time-averaging over a large time, :math:`\tau_m`. After normalization | |||
of :math:`\sigma_V` to the average over the population of the single cell | |||
membrane potentials | |||
.. math:: | |||
\sigma_{V_i}^2 = \left\langle\left[ V_i(t) \right]^2 \right\rangle_t - | |||
\left[ \left\langle V_i(t) \right\rangle_t \right]^2 | |||
one defines a synchrony measure, :math:`\chi (N)`, for the activity of a system | |||
of :math:`N` neurons by: | |||
.. math:: | |||
\chi^2 \left( N \right) = \frac{\sigma_V^2}{ \frac{1}{N} \sum_{i=1}^N | |||
\sigma_{V_i}^2} | |||
Parameters | |||
---------- | |||
potentials : | |||
The membrane potential matrix of the neuron group. | |||
Returns | |||
------- | |||
sync_index : float | |||
The synchronization index. | |||
References | |||
---------- | |||
.. [1] Golomb, D. and Rinzel J. (1993) Dynamics of globally coupled | |||
inhibitory neurons with heterogeneity. Phys. Rev. E 48:4810-4814.
.. [2] Golomb D. and Rinzel J. (1994) Clustering in globally coupled | |||
inhibitory neurons. Physica D 72:259-282. | |||
.. [3] David Golomb (2007) Neuronal synchrony measures. Scholarpedia, 2(1):1347. | |||
""" | |||
potentials = bm.as_device_array(potentials) | |||
num_hist, num_neu = potentials.shape | |||
var_mean = jnp.mean(_var(potentials, jnp.arange(num_neu))) | |||
avg = jnp.mean(potentials, axis=1) | |||
avg_var = jnp.mean(avg * avg) - jnp.mean(avg) ** 2 | |||
return lax.cond(var_mean != 0., | |||
lambda _: avg_var / var_mean, | |||
lambda _: 1., | |||
()) | |||
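Usage sketch (as in the tests in this changeset):
>>> import brainpy as bp
>>> v = bp.math.random.normal(0, 10, size=(1000, 100))
>>> bp.measure.voltage_fluctuation(v)                          # asynchronous -> small value
>>> bp.measure.voltage_fluctuation(bp.math.ones((1000, 100)))  # zero single-cell variance -> 1.0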
def matrix_correlation(x, y): | |||
"""Pearson correlation of the lower triagonal of two matrices. | |||
The triangular matrix is offset by k = 1 in order to ignore the diagonal line | |||
Parameters | |||
---------- | |||
x: tensor | |||
First matrix. | |||
y: tensor | |||
Second matrix | |||
Returns | |||
------- | |||
coef: tensor | |||
Correlation coefficient | |||
""" | |||
x = bm.as_numpy(x) | |||
y = bm.as_numpy(y) | |||
if x.ndim != 2: | |||
raise ValueError(f'Only support 2d tensor, but we got a tensor ' | |||
f'with the shape of {x.shape}') | |||
if y.ndim != 2: | |||
raise ValueError(f'Only support 2d tensor, but we got a tensor ' | |||
f'with the shape of {y.shape}') | |||
x = x[np.triu_indices_from(x, k=1)] | |||
y = y[np.triu_indices_from(y, k=1)] | |||
cc = np.corrcoef(x, y)[0, 1] | |||
return cc | |||
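Usage sketch (mirroring the tests in this changeset):
>>> import brainpy as bp
>>> A = bp.math.random.random((100, 100))
>>> B = bp.math.random.random((100, 100))
>>> bp.measure.matrix_correlation(A, B)   # near 0 for independent random matrices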
def functional_connectivity(activities): | |||
"""Functional connectivity matrix of timeseries activities. | |||
Parameters | |||
---------- | |||
activities: tensor | |||
The multidimensional tensor with the shape of ``(num_time, num_sample)``. | |||
Returns | |||
------- | |||
connectivity_matrix: tensor | |||
``num_sample x num_sample`` functional connectivity matrix. | |||
""" | |||
activities = bm.as_numpy(activities) | |||
if activities.ndim != 2: | |||
raise ValueError('Only support 2d tensor with shape of "(num_time, num_sample)". ' | |||
f'But we got a tensor with the shape of {activities.shape}') | |||
fc = np.corrcoef(activities.T) | |||
return np.nan_to_num(fc) | |||
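Usage sketch (as in the tests in this changeset):
>>> import brainpy as bp
>>> act = bp.math.random.random((10000, 3))
>>> bp.measure.functional_connectivity(act).shape   # (3, 3) correlation matrix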
@jit | |||
def functional_connectivity_dynamics(activities, window_size=30, step_size=5): | |||
"""Computes functional connectivity dynamics (FCD) matrix. | |||
Parameters | |||
---------- | |||
activities: tensor | |||
The time series with shape of ``(num_time, num_sample)``. | |||
window_size: int | |||
Size of each rolling window in time steps, defaults to 30. | |||
step_size: int | |||
Step size between each rolling window, defaults to 5. | |||
Returns | |||
------- | |||
fcd_matrix: tensor | |||
FCD matrix. | |||
""" | |||
raise NotImplementedError('functional_connectivity_dynamics is not implemented yet.')
def _weighted_mean(x, w): | |||
"""Weighted Mean""" | |||
return jnp.sum(x * w) / jnp.sum(w) | |||
def _weighted_cov(x, y, w): | |||
"""Weighted Covariance""" | |||
return jnp.sum(w * (x - _weighted_mean(x, w)) * (y - _weighted_mean(y, w))) / jnp.sum(w) | |||
@jit | |||
def weighted_correlation(x, y, w): | |||
"""Weighted Pearson correlation of two data series. | |||
Parameters | |||
---------- | |||
x: tensor | |||
The data series 1. | |||
y: tensor | |||
The data series 2. | |||
w: tensor | |||
Weight vector, must have same length as x and y. | |||
Returns | |||
------- | |||
corr: tensor | |||
Weighted correlation coefficient. | |||
""" | |||
x = bm.as_device_array(x) | |||
y = bm.as_device_array(y) | |||
w = bm.as_device_array(w) | |||
if x.ndim != 1: | |||
raise ValueError(f'Only support 1d tensor, but we got a tensor ' | |||
f'with the shape of {x.shape}') | |||
if y.ndim != 1: | |||
raise ValueError(f'Only support 1d tensor, but we got a tensor ' | |||
f'with the shape of {y.shape}') | |||
if w.ndim != 1: | |||
raise ValueError(f'Only support 1d tensor, but we got a tensor ' | |||
f'with the shape of {w.shape}') | |||
return _weighted_cov(x, y, w) / jnp.sqrt(_weighted_cov(x, x, w) * _weighted_cov(y, y, w)) |
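With uniform weights this reduces to the ordinary (unweighted) Pearson correlation, which gives a quick sanity check:
>>> import brainpy as bp
>>> x = bp.math.random.random(100)
>>> y = bp.math.random.random(100)
>>> bp.measure.weighted_correlation(x, y, bp.math.ones(100))
>>> # should match: np.corrcoef(bp.math.as_numpy(x), bp.math.as_numpy(y))[0, 1]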
@@ -0,0 +1,76 @@ | |||
# -*- coding: utf-8 -*- | |||
import numpy as np | |||
from jax import jit | |||
from brainpy import math as bm | |||
__all__ = [ | |||
'raster_plot', | |||
'firing_rate', | |||
] | |||
def raster_plot(sp_matrix, times): | |||
"""Get spike raster plot which displays the spiking activity | |||
of a group of neurons over time. | |||
Parameters | |||
---------- | |||
sp_matrix : bnp.ndarray | |||
The matrix which record spiking activities. | |||
times : bnp.ndarray | |||
The time steps. | |||
Returns | |||
------- | |||
raster_plot : tuple | |||
Include (neuron index, spike time). | |||
""" | |||
sp_matrix = np.asarray(sp_matrix) | |||
times = np.asarray(times) | |||
elements = np.where(sp_matrix > 0.) | |||
index = elements[1] | |||
time = times[elements[0]] | |||
return index, time | |||
@jit | |||
def _firing_rate(sp_matrix, window): | |||
sp_matrix = bm.asarray(sp_matrix) | |||
rate = bm.sum(sp_matrix, axis=1) / sp_matrix.shape[1] | |||
return bm.convolve(rate, window, mode='same') | |||
def firing_rate(sp_matrix, width, dt=None, numpy=True): | |||
r"""Calculate the mean firing rate over in a neuron group. | |||
This method is adopted from Brian2. | |||
The firing rate in trial :math:`k` is the spike count :math:`n_{k}^{sp}` | |||
in an interval of duration :math:`T` divided by :math:`T`: | |||
.. math:: | |||
v_k = {n_k^{sp} \over T} | |||
Parameters | |||
---------- | |||
sp_matrix : math.JaxArray, np.ndarray
The spike matrix which records spiking activities.
width : int, float
The width of the ``window`` in milliseconds.
dt : float, optional
The sampling time step in milliseconds.
numpy : bool, optional
Whether to return a NumPy array (default True).
Returns | |||
------- | |||
rate : numpy.ndarray | |||
The population rate in Hz, smoothed with the given window. | |||
""" | |||
dt = bm.get_dt() if (dt is None) else dt | |||
width1 = int(width / 2 / dt) * 2 + 1 | |||
window = bm.ones(width1) * 1000 / width | |||
fr = _firing_rate(sp_matrix, window) | |||
return fr.numpy() if numpy else fr | |||
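# Usage sketch (hypothetical data; assumes ``import brainpy as bp``): smooth
# the population rate of a random spike train with a 5 ms window.
#
#   spikes = bp.math.random.random((1000, 10)) < 0.2  # (num_time, num_neuron)
#   rate = bp.measure.firing_rate(spikes, width=5.)   # population rate in Hz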
@@ -0,0 +1,59 @@ | |||
# -*- coding: utf-8 -*- | |||
import unittest | |||
import brainpy as bp | |||
class TestCrossCorrelation(unittest.TestCase): | |||
def test_cc(self): | |||
spikes = bp.math.ones((1000, 10)) | |||
cc1 = bp.measure.cross_correlation(spikes, 1.) | |||
self.assertTrue(cc1 == 1.) | |||
spikes = bp.math.zeros((1000, 10)) | |||
cc2 = bp.measure.cross_correlation(spikes, 1.) | |||
self.assertTrue(cc2 == 0.) | |||
def test_cc2(self): | |||
spikes = bp.math.random.randint(0, 2, (1000, 10)) | |||
print(bp.measure.cross_correlation(spikes, 1.)) | |||
print(bp.measure.cross_correlation(spikes, 0.5)) | |||
def test_cc3(self): | |||
spikes = bp.math.random.random((1000, 100)) < 0.8 | |||
print(bp.measure.cross_correlation(spikes, 1.)) | |||
print(bp.measure.cross_correlation(spikes, 0.5)) | |||
def test_cc4(self): | |||
spikes = bp.math.random.random((1000, 100)) < 0.2 | |||
print(bp.measure.cross_correlation(spikes, 1.)) | |||
print(bp.measure.cross_correlation(spikes, 0.5)) | |||
def test_cc5(self): | |||
spikes = bp.math.random.random((1000, 100)) < 0.05 | |||
print(bp.measure.cross_correlation(spikes, 1.)) | |||
print(bp.measure.cross_correlation(spikes, 0.5)) | |||
class TestVoltageFluctuation(unittest.TestCase): | |||
def test_vf1(self): | |||
voltages = bp.math.random.normal(0, 10, size=(1000, 100)) | |||
print(bp.measure.voltage_fluctuation(voltages)) | |||
voltages = bp.math.ones((1000, 100)) | |||
print(bp.measure.voltage_fluctuation(voltages)) | |||
class TestFunctionalConnectivity(unittest.TestCase): | |||
def test_cf1(self): | |||
act = bp.math.random.random((10000, 3)) | |||
print(bp.measure.functional_connectivity(act)) | |||
class TestMatrixCorrelation(unittest.TestCase): | |||
def test_mc(self): | |||
A = bp.math.random.random((100, 100)) | |||
B = bp.math.random.random((100, 100)) | |||
print(bp.measure.matrix_correlation(A, B)) | |||
@@ -0,0 +1,22 @@ | |||
# -*- coding: utf-8 -*- | |||
import unittest | |||
import brainpy as bp | |||
class TestFiringRate(unittest.TestCase): | |||
def test_fr1(self): | |||
spikes = bp.math.ones((1000, 10)) | |||
print(bp.measure.firing_rate(spikes, 1.)) | |||
def test_fr2(self): | |||
spikes = bp.math.random.random((1000, 10)) < 0.2 | |||
print(bp.measure.firing_rate(spikes, 1.)) | |||
print(bp.measure.firing_rate(spikes, 10.)) | |||
def test_fr3(self): | |||
spikes = bp.math.random.random((1000, 10)) < 0.02 | |||
print(bp.measure.firing_rate(spikes, 1.)) | |||
print(bp.measure.firing_rate(spikes, 5.)) | |||
@@ -12,7 +12,7 @@ This module provide basic Node class for whole ``brainpy.nn`` system. | |||
This means ``brainpy.nn.Network`` is only used to pack element nodes. It will
never be an element node.
- ``brainpy.nn.FrozenNetwork``: The whole network which can be represented as a basic | |||
elementary node when composing a larger network. TODO | |||
elementary node when composing a larger network (TODO). | |||
""" | |||
from copy import copy, deepcopy | |||
@@ -48,6 +48,16 @@ __all__ = [ | |||
NODE_STATES = ['inputs', 'feedbacks', 'state', 'output'] | |||
SUPPORTED_LAYOUTS = ['shell_layout', | |||
'multipartite_layout', | |||
'spring_layout', | |||
'spiral_layout', | |||
'spectral_layout', | |||
'random_layout', | |||
'planar_layout', | |||
'kamada_kawai_layout', | |||
'circular_layout'] | |||
def not_implemented(fun: Callable) -> Callable: | |||
"""Marks the given module method is not implemented. | |||
@@ -92,8 +102,10 @@ class Node(Base): | |||
self._is_ff_initialized = False | |||
self._is_fb_initialized = False | |||
self._is_state_initialized = False | |||
self._is_fb_state_initialized = False | |||
self._trainable = trainable | |||
self._state = None # the state of the current node | |||
self._fb_output = None # the feedback output of the current node | |||
# data pass function | |||
if self.data_pass_type not in DATA_PASS_FUNC: | |||
raise ValueError(f'Unsupported data pass type {self.data_pass_type}. ' | |||
@@ -111,12 +123,9 @@ class Node(Base): | |||
name = type(self).__name__ | |||
prefix = ' ' * (len(name) + 1) | |||
line1 = (f"{name}(name={self.name}, " | |||
f"trainable={self.trainable}, " | |||
f"forwards={self.feedforward_shapes}, " | |||
f"feedbacks={self.feedback_shapes}, \n") | |||
line2 = (f"{prefix}output={self.output_shape}, " | |||
f"support_feedback={self.support_feedback}, " | |||
f"data_pass_type={self.data_pass_type})") | |||
line2 = f"{prefix}output={self.output_shape}" | |||
return line1 + line2 | |||
def __call__(self, *args, **kwargs) -> Tensor: | |||
@@ -194,7 +203,7 @@ class Node(Base): | |||
@property | |||
def state(self) -> Optional[Tensor]: | |||
"""Node current internal state.""" | |||
if self.is_ff_initialized: | |||
if self._is_ff_initialized: | |||
return self._state | |||
return None | |||
@@ -209,9 +218,9 @@ class Node(Base): | |||
This method allows the maximum flexibility to change the | |||
node state. It can set a new data (same shape, same dtype) | |||
to the state. It can also set the data with another batch size. | |||
We highly recommend the user to use this function. | |||
to the state. It can also set new data with a different
shape. We highly recommend the user use this function
instead of directly using ``self.state.value``.
""" | |||
if self.state is None: | |||
if self.output_shape is not None: | |||
@@ -225,31 +234,52 @@ class Node(Base): | |||
self.state._value = bm.as_device_array(state) | |||
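# Usage sketch (hypothetical node): re-batch a node state safely through
# ``set_state()`` instead of assigning to ``self.state.value`` directly.
#
#   node.initialize(num_batch=1)
#   node.set_state(bm.zeros((4,) + node.output_shape[1:]))  # new batch size 4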
@property | |||
def trainable(self) -> bool: | |||
"""Returns if the Node can be trained.""" | |||
return self._trainable | |||
def fb_output(self) -> Optional[Tensor]: | |||
return self._fb_output | |||
@property | |||
def is_ff_initialized(self) -> bool: | |||
return self._is_ff_initialized | |||
@fb_output.setter | |||
def fb_output(self, value: Tensor): | |||
raise NotImplementedError('Please use "set_fb_output()" to reset the node feedback state, ' | |||
'or use "self.fb_output.value" to change the state content.') | |||
@is_ff_initialized.setter | |||
def is_ff_initialized(self, value: bool): | |||
assert isinstance(value, bool) | |||
self._is_ff_initialized = value | |||
def set_fb_output(self, state: Tensor): | |||
""" | |||
Safely set the feedback state of the node. | |||
@property | |||
def is_fb_initialized(self) -> bool: | |||
return self._is_fb_initialized | |||
This method allows the maximum flexibility to change the | |||
node state. It can set a new data (same shape, same dtype) | |||
to the state. It can also set new data with a different
shape. We highly recommend the user use this function
instead of directly using ``self.fb_output.value``.
""" | |||
if self.fb_output is None: | |||
if self.output_shape is not None: | |||
check_batch_shape(self.output_shape, state.shape) | |||
self._fb_output = bm.Variable(state) if not isinstance(state, bm.Variable) else state | |||
else: | |||
check_batch_shape(self.fb_output.shape, state.shape) | |||
if self.fb_output.dtype != state.dtype: | |||
raise MathError('Cannot set the feedback state, because the dtype is ' | |||
f'not consistent: {self.fb_output.dtype} != {state.dtype}') | |||
self.fb_output._value = bm.as_device_array(state) | |||
@is_fb_initialized.setter | |||
def is_fb_initialized(self, value: bool): | |||
assert isinstance(value, bool) | |||
self._is_fb_initialized = value | |||
@property | |||
def trainable(self) -> bool: | |||
"""Returns if the Node can be trained.""" | |||
return self._trainable | |||
@property | |||
def is_state_initialized(self): | |||
return self._is_state_initialized | |||
def is_initialized(self) -> bool:
  if not (self._is_ff_initialized and self._is_state_initialized):
    return False
  if self.feedback_shapes is not None:
    return self._is_fb_initialized and self._is_fb_state_initialized
  return True
@trainable.setter | |||
def trainable(self, value: bool): | |||
@@ -268,7 +298,7 @@ class Node(Base): | |||
self.set_feedforward_shapes(size) | |||
def set_feedforward_shapes(self, feedforward_shapes: Dict): | |||
if not self.is_ff_initialized: | |||
if not self._is_ff_initialized: | |||
check_dict_data(feedforward_shapes, | |||
key_type=(Node, str), | |||
val_type=(list, tuple), | |||
@@ -278,11 +308,11 @@ class Node(Base): | |||
if self.feedforward_shapes is not None: | |||
for key, size in self._feedforward_shapes.items(): | |||
if key not in feedforward_shapes: | |||
raise ValueError(f"Impossible to reset the input data of {self.name}. " | |||
raise ValueError(f"Impossible to reset the input shape of {self.name}. " | |||
f"Because this Node has the input dimension {size} from {key}. " | |||
f"While we do not find it in the given feedforward_shapes") | |||
if not check_batch_shape(size, feedforward_shapes[key], mode='bool'): | |||
raise ValueError(f"Impossible to reset the input data of {self.name}. " | |||
raise ValueError(f"Impossible to reset the input shape of {self.name}. " | |||
f"Because this Node has the input dimension {size} from {key}. " | |||
f"While the give shape is {feedforward_shapes[key]}") | |||
@@ -296,7 +326,7 @@ class Node(Base): | |||
self.set_feedback_shapes(size) | |||
def set_feedback_shapes(self, fb_shapes: Dict): | |||
if not self.is_fb_initialized: | |||
if not self._is_fb_initialized: | |||
check_dict_data(fb_shapes, key_type=(Node, str), val_type=(tuple, list), name='fb_shapes') | |||
self._feedback_shapes = fb_shapes | |||
else: | |||
@@ -321,14 +351,21 @@ class Node(Base): | |||
self.set_output_shape(size) | |||
@property | |||
def support_feedback(self): | |||
if hasattr(self.init_fb, 'not_implemented'): | |||
if self.init_fb.not_implemented: | |||
def is_feedback_input_supported(self): | |||
if hasattr(self.init_fb_conn, 'not_implemented'): | |||
if self.init_fb_conn.not_implemented: | |||
return False | |||
return True | |||
@property | |||
def is_feedback_supported(self): | |||
if self.fb_output is None: | |||
return False | |||
else: | |||
return True | |||
def set_output_shape(self, shape: Sequence[int]): | |||
if not self.is_ff_initialized: | |||
if not self._is_ff_initialized: | |||
if not isinstance(shape, (tuple, list)): | |||
raise ValueError(f'Must be a sequence of int, but got {shape}') | |||
self._output_shape = tuple(shape) | |||
@@ -368,84 +405,88 @@ class Node(Base): | |||
new_obj.name = self.unique_name(name or (self.name + '_copy')) | |||
return new_obj | |||
def _ff_init(self): | |||
if not self.is_ff_initialized: | |||
def _init_ff_conn(self): | |||
if not self._is_ff_initialized: | |||
try: | |||
self.init_ff() | |||
self.init_ff_conn() | |||
except Exception as e: | |||
raise ModelBuildError(f'{self.name} initialization failed.') from e | |||
self._is_ff_initialized = True | |||
if self.output_shape is None: | |||
raise ValueError(f'Please set the output shape when implementing '
f'"init_ff_conn()" of the node {self.name}')
def _fb_init(self): | |||
if not self.is_fb_initialized: | |||
def _init_fb_conn(self): | |||
if not self._is_fb_initialized: | |||
try: | |||
self.init_fb() | |||
self.init_fb_conn() | |||
except Exception as e: | |||
raise ModelBuildError(f"{self.name} initialization failed.") from e | |||
self._is_fb_initialized = True | |||
@not_implemented | |||
def init_fb(self): | |||
def init_fb_conn(self): | |||
"""Initialize the feedback connections. | |||
This function will be called only once.""" | |||
raise ValueError(f'This node \n\n{self} \n\ndoes not support feedback connection.') | |||
def init_ff(self): | |||
def init_ff_conn(self): | |||
"""Initialize the feedforward connections. | |||
This function will be called only once.""" | |||
raise NotImplementedError('Please implement the feedforward initialization.') | |||
def init_state(self, num_batch=1): | |||
def _init_state(self, num_batch=1): | |||
state = self.init_state(num_batch) | |||
if state is not None: | |||
self.set_state(state) | |||
def _init_fb_output(self, num_batch=1): | |||
output = self.init_fb_output(num_batch) | |||
if output is not None: | |||
self.set_fb_output(output) | |||
def init_state(self, num_batch=1) -> Optional[Tensor]: | |||
"""Set the initial node state. | |||
This function can be called multiple times.""" | |||
pass | |||
def initialize(self, | |||
ff: Optional[Union[Tensor, Dict[Any, Tensor]]] = None, | |||
fb: Optional[Union[Tensor, Dict[Any, Tensor]]] = None, | |||
num_batch: int = None): | |||
def init_fb_output(self, num_batch=1) -> Optional[Tensor]: | |||
"""Set the initial node feedback state. | |||
This function can be called multiple times. However, | |||
it is only triggered when the node has feedback connections. | |||
""" | |||
Initialize the whole network. This function must be called before applying JIT. | |||
return bm.zeros((num_batch,) + self.output_shape[1:], dtype=bm.float_) | |||
def initialize(self, num_batch: int): | |||
""" | |||
Initialize the node. This function must be called before applying JIT. | |||
This function is useful, because it is independent from the __call__ function. | |||
We can use this function before we applying JIT to __call__ function. | |||
This function is useful, because it is independent of the __call__ function. | |||
We can use this function before applying JIT to the __call__ function.
""" | |||
# feedforward initialization | |||
if not self.is_ff_initialized: | |||
# feedforward data | |||
if ff is None: | |||
if self._feedforward_shapes is None: | |||
if self.feedforward_shapes is None: | |||
raise ValueError('Cannot initialize this node, because we detect ' | |||
'both "feedforward_shapes"and "ff" inputs are None. ') | |||
in_sizes = self._feedforward_shapes | |||
if num_batch is None: | |||
raise ValueError('"num_batch" cannot be None when "ff" is not provided.') | |||
'both "feedforward_shapes" is None. ' | |||
'Two ways can solve this problem:\n\n' | |||
'1. Connecting an instance of "brainpy.nn.Input()" to this node. \n' | |||
'2. Providing the "input_shape" when initialize the node.') | |||
check_integer(num_batch, 'num_batch', min_bound=0, allow_none=False) | |||
else: | |||
if isinstance(ff, (bm.ndarray, jnp.ndarray)): | |||
ff = {self.name: ff} | |||
assert isinstance(ff, dict), f'"ff" must be a dict or a tensor, got {type(ff)}: {ff}' | |||
assert self.name in ff, f'Cannot find input for this node \n\n{self} \n\nwhen given "ff" {ff}' | |||
batch_sizes = [v.shape[0] for v in ff.values()] | |||
if set(batch_sizes) != 1: | |||
raise ValueError('Batch sizes must be consistent, but we got multiple ' | |||
f'batch sizes {set(batch_sizes)} for the given input: \n' | |||
f'{ff}') | |||
in_sizes = {k: (None,) + v.shape[1:] for k, v in ff.items()} | |||
if (num_batch is not None) and (num_batch != batch_sizes[0]): | |||
raise ValueError(f'The provided "num_batch" {num_batch} is consistent with the ' | |||
f'batch size of the provided data {batch_sizes[0]}') | |||
# initialize feedforward | |||
self.set_feedforward_shapes(in_sizes) | |||
self._ff_init() | |||
self.init_state(num_batch) | |||
self._init_ff_conn() | |||
# initialize state | |||
self._init_state(num_batch) | |||
self._is_state_initialized = True | |||
if self.feedback_shapes is not None: | |||
# feedback initialization | |||
if fb is not None: | |||
if not self.is_fb_initialized: # initialize feedback | |||
assert isinstance(fb, dict), f'"fb" must be a dict, got {type(fb)}' | |||
fb_sizes = {k: (None,) + v.shape[1:] for k, v in fb.items()} | |||
self.set_feedback_shapes(fb_sizes) | |||
self._fb_init() | |||
else: | |||
self._is_fb_initialized = True | |||
self._init_fb_conn() | |||
# initialize feedback state | |||
self._init_fb_output(num_batch) | |||
self._is_fb_state_initialized = True | |||
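# Usage sketch (hypothetical node; names are illustrative): initialization
# must happen before JIT compilation so that all shapes and states exist.
#
#   node = bp.nn.VanillaRNN(num_unit=16, input_shape=(None, 4))
#   node.initialize(num_batch=1)
#   jitted = bm.jit(node.__call__)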
def _check_inputs(self, ff, fb=None): | |||
# check feedforward inputs | |||
@@ -477,9 +518,8 @@ class Node(Base): | |||
forced_feedbacks: Dict[str, Tensor] = None, | |||
monitors=None, | |||
**kwargs) -> Union[Tensor, Tuple[Tensor, Dict]]: | |||
# # initialization | |||
# self.initialize(ff, fb) | |||
if not (self.is_ff_initialized and self.is_fb_initialized and self.is_state_initialized): | |||
# checking | |||
if not self.is_initialized: | |||
raise ValueError('Please initialize the Node first by calling "initialize()" function.') | |||
# initialize the forced data | |||
@@ -511,6 +551,7 @@ class Node(Base): | |||
assert self.state is not None, (f'{self} \n\nhas no state, while ' | |||
f'the user tries to monitor its state.')
state_monitors[key] = None | |||
# calling | |||
ff, fb = self._check_inputs(ff, fb=fb) | |||
if 'inputs' in state_monitors: | |||
@@ -528,7 +569,7 @@ class Node(Base): | |||
else: | |||
return output | |||
def forward(self, ff, fb=None, **kwargs): | |||
def forward(self, ff, fb=None, **shared_kwargs): | |||
"""The feedforward computation function of a node. | |||
Parameters | |||
@@ -537,7 +578,7 @@ class Node(Base): | |||
The feedforward inputs. | |||
fb: optional, tensor, dict, sequence | |||
The feedback inputs. | |||
**kwargs | |||
**shared_kwargs | |||
Other parameters. | |||
Returns | |||
@@ -547,12 +588,12 @@ class Node(Base): | |||
""" | |||
raise NotImplementedError | |||
def feedback(self, **kwargs): | |||
def feedback(self, ff_output, **shared_kwargs): | |||
"""The feedback computation function of a node. | |||
Parameters | |||
---------- | |||
**kwargs | |||
ff_output: Tensor
The feedforward output of this node at the current step.
**shared_kwargs
Other global parameters.
Returns | |||
@@ -560,12 +601,17 @@ class Node(Base): | |||
Tensor | |||
A feedback output tensor value. | |||
""" | |||
return self.state | |||
return ff_output | |||
class RecurrentNode(Node): | |||
""" | |||
Basic class for recurrent node. | |||
The recurrent node supports:
- Drawing self-connections when using the ``plot_node_graph()`` function.
- Setting a trainable state with ``state_trainable=True``.
""" | |||
def __init__(self, | |||
@@ -617,19 +663,6 @@ class RecurrentNode(Node): | |||
else: | |||
self.state._value = bm.as_device_array(state) | |||
def __repr__(self): | |||
name = type(self).__name__ | |||
prefix = ' ' * (len(name) + 1) | |||
line1 = (f"{name}(name={self.name}, recurrent=True, " | |||
f"trainable={self.trainable}, \n") | |||
line2 = (f"{prefix}forwards={self.feedforward_shapes}, " | |||
f"feedbacks={self.feedback_shapes}, \n") | |||
line3 = (f"{prefix}output={self.output_shape}, " | |||
f"support_feedback={self.support_feedback}, " | |||
f"data_pass_type={self.data_pass_type})") | |||
return line1 + line2 + line3 | |||
class Network(Node): | |||
"""Basic Network class for neural network building in BrainPy.""" | |||
@@ -806,8 +839,8 @@ class Network(Node): | |||
def replace_graph(self, | |||
nodes: Sequence[Node], | |||
ff_edges: Sequence[Tuple[Node, Node]], | |||
fb_edges: Sequence[Tuple[Node, Node]] = None) -> "Network": | |||
ff_edges: Sequence[Tuple[Node, ...]], | |||
fb_edges: Sequence[Tuple[Node, ...]] = None) -> "Network": | |||
if fb_edges is None: fb_edges = tuple() | |||
# assign nodes and edges | |||
@@ -817,16 +850,45 @@ class Network(Node): | |||
self._network_init() | |||
return self | |||
def init_ff(self): | |||
def set_output_shape(self, shape: Dict[str, Sequence[int]]): | |||
# check shape | |||
if not isinstance(shape, dict): | |||
raise ValueError(f'Must be a dict of <node name, shape>, but got {type(shape)}: {shape}') | |||
for key, val in shape.items(): | |||
if not isinstance(val, (tuple, list)): | |||
raise ValueError(f'Must be a sequence of int, but got {val} for key "{key}"') | |||
# for s in val: | |||
# if not (isinstance(s, int) or (s is None)): | |||
# raise ValueError(f'Must be a sequence of int, but got {val}') | |||
if not self._is_ff_initialized: | |||
if len(self.exit_nodes) == 1: | |||
self._output_shape = tuple(shape.values())[0] | |||
else: | |||
self._output_shape = shape | |||
else: | |||
for val in shape.values(): | |||
check_batch_shape(val, self.output_shape) | |||
def init_ff_conn(self): | |||
"""Initialize the feedforward connections of the network. | |||
This function will be called only once.""" | |||
# input shapes of entry nodes | |||
for node in self.entry_nodes: | |||
# set ff shapes | |||
if node.feedforward_shapes is None: | |||
if self.feedforward_shapes is None: | |||
raise ValueError('Cannot find the input size. ' | |||
'Cannot initialize the network.') | |||
else: | |||
node.set_feedforward_shapes({node.name: self._feedforward_shapes[node.name]}) | |||
node._ff_init() | |||
# set fb shapes | |||
if node in self.fb_senders: | |||
fb_shapes = {node: node.output_shape for node in self.fb_senders.get(node, [])} | |||
if None not in fb_shapes.values(): | |||
node.set_feedback_shapes(fb_shapes) | |||
# init ff conn | |||
node._init_ff_conn() | |||
# initialize the data | |||
children_queue = [] | |||
@@ -840,49 +902,79 @@ class Network(Node): | |||
children_queue.append(child) | |||
while len(children_queue): | |||
node = children_queue.pop(0) | |||
# initialize input and output sizes | |||
# set ff shapes | |||
parent_sizes = {p: p.output_shape for p in self.ff_senders.get(node, [])} | |||
node.set_feedforward_shapes(parent_sizes) | |||
node._ff_init() | |||
if node in self.fb_senders: | |||
# set fb shapes | |||
fb_shapes = {node: node.output_shape for node in self.fb_senders.get(node, [])} | |||
if None not in fb_shapes.values(): | |||
node.set_feedback_shapes(fb_shapes) | |||
# init ff conn | |||
node._init_ff_conn() | |||
# append children | |||
for child in self.ff_receivers.get(node, []): | |||
ff_senders[child].remove(node) | |||
if len(ff_senders.get(child, [])) == 0: | |||
children_queue.append(child) | |||
def init_fb(self): | |||
# set output shape | |||
out_sizes = {node: node.output_shape for node in self.exit_nodes} | |||
self.set_output_shape(out_sizes) | |||
def init_fb_conn(self): | |||
"""Initialize the feedback connections of the network. | |||
This function will be called only once.""" | |||
for receiver, senders in self.fb_senders.items(): | |||
fb_sizes = {node: node.output_shape for node in senders} | |||
if None in fb_sizes.values(): | |||
none_size_nodes = [repr(n) for n, v in fb_sizes.items() if v is None] | |||
none_size_nodes = "\n".join(none_size_nodes) | |||
raise ValueError(f'Output shapes of nodes \n\n' | |||
f'{none_size_nodes}\n\n' | |||
f'have not been initialized, ' | |||
f'so we cannot initialize the '
f'feedback connection of node \n\n' | |||
f'{receiver}') | |||
receiver.set_feedback_shapes(fb_sizes) | |||
receiver._fb_init() | |||
receiver._init_fb_conn() | |||
def init_state(self, num_batch=1): | |||
"""Initialize the states of all children nodes.""" | |||
def _init_state(self, num_batch=1): | |||
"""Initialize the states of all children nodes. | |||
This function can be called multiple times.""" | |||
for node in self.lnodes: | |||
node.init_state(num_batch) | |||
node._init_state(num_batch) | |||
def initialize(self, | |||
ff: Optional[Union[Tensor, Dict[Any, Tensor]]] = None, | |||
fb: Optional[Union[Tensor, Dict[Any, Tensor]]] = None, | |||
num_batch: int = None): | |||
def _init_fb_output(self, num_batch=1): | |||
"""Initialize the node feedback state. | |||
This function can be called multiple times. However, | |||
it is only triggered when the node has feedback connections. | |||
""" | |||
for node in self.feedback_nodes: | |||
node._init_fb_output(num_batch) | |||
def initialize(self, num_batch: int): | |||
""" | |||
Initialize the whole network. This function must be called before applying JIT. | |||
This function is useful, because it is independent from the __call__ function. | |||
We can use this function before we applying JIT to __call__ function. | |||
This function is useful, because it is independent of the __call__ function. | |||
We can use this function before applying JIT to the __call__ function.
""" | |||
# feedforward initialization | |||
if not self.is_ff_initialized: | |||
# set feedforward shapes | |||
if not self._is_ff_initialized: | |||
# check input and output nodes | |||
assert len(self.entry_nodes) > 0, (f"We found this network \n\n" | |||
if len(self.entry_nodes) <= 0: | |||
raise ValueError(f"We found this network \n\n" | |||
f"{self} " | |||
f"\n\nhas no input nodes.") | |||
assert len(self.exit_nodes) > 0, (f"We found this network \n\n" | |||
if len(self.exit_nodes) <= 0: | |||
raise ValueError(f"We found this network \n\n" | |||
f"{self} " | |||
f"\n\nhas no output nodes.") | |||
# check whether has a feedforward path for each feedback pair | |||
# check whether it has a feedforward path for each feedback pair | |||
ff_edges = [(a.name, b.name) for a, b in self.ff_edges] | |||
for node, receiver in self.fb_edges: | |||
if not detect_path(receiver.name, node.name, ff_edges): | |||
@@ -895,49 +987,42 @@ class Network(Node): | |||
f'feedforward connection between them. ') | |||
# feedforward checking | |||
if ff is None: | |||
in_sizes = dict() | |||
for node in self.entry_nodes: | |||
if node._feedforward_shapes is None: | |||
if node.feedforward_shapes is None: | |||
raise ValueError('Cannot initialize this node, because we detect ' | |||
'both "feedforward_shapes" and "ff" inputs are None. ' | |||
'"feedforward_shapes" is None. ' | |||
'Maybe you need a brainpy.nn.Input instance ' | |||
'to instruct the input size.') | |||
in_sizes.update(node._feedforward_shapes) | |||
if num_batch is None: | |||
raise ValueError('"num_batch" cannot be None when "ff" is not provided.') | |||
check_integer(num_batch, 'num_batch', min_bound=0, allow_none=False) | |||
else: | |||
if isinstance(ff, (bm.ndarray, jnp.ndarray)): | |||
ff = {self.entry_nodes[0].name: ff} | |||
assert isinstance(ff, dict), f'ff must be a dict or a tensor, got {type(ff)}: {ff}' | |||
for n in self.entry_nodes: | |||
if n.name not in ff: | |||
raise ValueError(f'Cannot find the input of the node {n}') | |||
batch_sizes = [v.shape[0] for v in ff.values()] | |||
if len(set(batch_sizes)) != 1: | |||
raise ValueError('Batch sizes must be consistent, but we got multiple ' | |||
f'batch sizes {set(batch_sizes)} for the given input: \n' | |||
f'{ff}') | |||
in_sizes = {k: (None,) + v.shape[1:] for k, v in ff.items()} | |||
if (num_batch is not None) and (num_batch != batch_sizes[0]): | |||
raise ValueError(f'The provided "num_batch" {num_batch} is consistent with the ' | |||
f'batch size of the provided data {batch_sizes[0]}') | |||
# initialize feedforward | |||
self.set_feedforward_shapes(in_sizes) | |||
self._ff_init() | |||
self.init_state(num_batch) | |||
# feedforward initialization | |||
if self.feedforward_shapes is None: | |||
raise ValueError('Cannot initialize this node, because we detect ' | |||
'both "feedforward_shapes" is None. ') | |||
check_integer(num_batch, 'num_batch', min_bound=1, allow_none=False) | |||
self._init_ff_conn() | |||
# initialize state | |||
self._init_state(num_batch) | |||
self._is_state_initialized = True | |||
# set feedback shapes | |||
if not self._is_fb_initialized: | |||
if len(self.fb_senders) > 0: | |||
fb_sizes = dict() | |||
for sender in self.fb_senders.keys(): | |||
fb_sizes[sender] = sender.output_shape | |||
self.set_feedback_shapes(fb_sizes) | |||
# feedback initialization | |||
if len(self.fb_senders): | |||
# initialize feedback | |||
if not self.is_fb_initialized: | |||
self._fb_init() | |||
else: | |||
self.is_fb_initialized = True | |||
if self.feedback_shapes is not None: | |||
self._init_fb_conn() | |||
# initialize feedback state | |||
self._init_fb_output(num_batch) | |||
self._is_fb_state_initialized = True | |||
def _check_inputs(self, ff, fb=None): | |||
# feedforward inputs | |||
@@ -986,8 +1071,7 @@ class Network(Node): | |||
monitors: Optional[Sequence[str]] = None, | |||
**kwargs): | |||
# initialization | |||
# self.initialize(ff, fb) | |||
if not (self.is_ff_initialized and self.is_fb_initialized and self.is_state_initialized): | |||
if not self.is_initialized: | |||
raise ValueError('Please initialize the Network first by calling "initialize()" function.') | |||
# initialize the forced data | |||
@@ -1038,7 +1122,7 @@ class Network(Node): | |||
forced_states: Dict[str, Tensor] = None, | |||
forced_feedbacks: Dict[str, Tensor] = None, | |||
monitors: Dict = None, | |||
**kwargs): | |||
**shared_kwargs): | |||
"""The main computation function of a network. | |||
Parameters | |||
@@ -1053,7 +1137,7 @@ class Network(Node): | |||
The fixed feedback for the nodes in the network. | |||
monitors: optional, sequence | |||
Can be used to monitor the state or the attribute of a node in the network. | |||
**kwargs | |||
**shared_kwargs | |||
Other parameters which will be passed into every node.
Returns | |||
@@ -1077,10 +1161,11 @@ class Network(Node): | |||
parent_outputs = {} | |||
for i, node in enumerate(self._entry_nodes): | |||
ff_ = {node.name: ff[i]} | |||
fb_ = {p: (forced_feedbacks[p.name] if (p.name in forced_feedbacks) else p.feedback()) | |||
fb_ = {p: (forced_feedbacks[p.name] if (p.name in forced_feedbacks) else p.fb_output) | |||
for p in self.fb_senders.get(node, [])} | |||
self._call_a_node(node, ff_, fb_, monitors, forced_states, | |||
parent_outputs, children_queue, ff_senders, **kwargs) | |||
parent_outputs, children_queue, ff_senders, | |||
**shared_kwargs) | |||
runned_nodes.add(node.name) | |||
# run the model | |||
@@ -1088,23 +1173,23 @@ class Network(Node): | |||
node = children_queue.pop(0) | |||
# get feedforward and feedback inputs | |||
ff = {p: parent_outputs[p] for p in self.ff_senders.get(node, [])} | |||
fb = {p: (forced_feedbacks[p.name] if (p.name in forced_feedbacks) else p.feedback()) | |||
fb = {p: (forced_feedbacks[p.name] if (p.name in forced_feedbacks) else p.fb_output) | |||
for p in self.fb_senders.get(node, [])} | |||
# call the node | |||
self._call_a_node(node, ff, fb, monitors, forced_states, | |||
parent_outputs, children_queue, ff_senders, | |||
**kwargs) | |||
# #- remove unnecessary parent outputs -# | |||
# needed_parents = [] | |||
# runned_nodes.add(node.name) | |||
# for child in (all_nodes - runned_nodes): | |||
# for parent in self.ff_senders[self.implicit_nodes[child]]: | |||
# needed_parents.append(parent.name) | |||
# for parent in list(parent_outputs.keys()): | |||
# _name = parent.name | |||
# if _name not in needed_parents and _name not in output_nodes: | |||
# parent_outputs.pop(parent) | |||
**shared_kwargs) | |||
# - remove unnecessary parent outputs - # | |||
needed_parents = [] | |||
runned_nodes.add(node.name) | |||
for child in (all_nodes - runned_nodes): | |||
for parent in self.ff_senders[self.implicit_nodes[child]]: | |||
needed_parents.append(parent.name) | |||
for parent in list(parent_outputs.keys()): | |||
_name = parent.name | |||
if _name not in needed_parents and _name not in output_nodes: | |||
parent_outputs.pop(parent) | |||
# returns | |||
if len(self.exit_nodes) > 1: | |||
@@ -1114,7 +1199,8 @@ class Network(Node): | |||
return state, monitors | |||
def _call_a_node(self, node, ff, fb, monitors, forced_states, | |||
parent_outputs, children_queue, ff_senders, **kwargs): | |||
parent_outputs, children_queue, ff_senders, | |||
**shared_kwargs): | |||
ff = node.data_pass_func(ff) | |||
if f'{node.name}.inputs' in monitors: | |||
monitors[f'{node.name}.inputs'] = ff | |||
@@ -1123,12 +1209,17 @@ class Network(Node): | |||
fb = node.data_pass_func(fb) | |||
if f'{node.name}.feedbacks' in monitors: | |||
monitors[f'{node.name}.feedbacks'] = fb | |||
parent_outputs[node] = node.forward(ff, fb, **kwargs) | |||
parent_outputs[node] = node.forward(ff, fb, **shared_kwargs) | |||
else: | |||
parent_outputs[node] = node.forward(ff, **kwargs) | |||
if node.name in forced_states: # forced state | |||
parent_outputs[node] = node.forward(ff, **shared_kwargs) | |||
# get the feedback state | |||
if node in self.fb_receivers: | |||
node.set_fb_output(node.feedback(parent_outputs[node], **shared_kwargs)) | |||
# forced state | |||
if node.name in forced_states: | |||
node.state.value = forced_states[node.name] | |||
parent_outputs[node] = forced_states[node.name] | |||
# parent_outputs[node] = forced_states[node.name] | |||
# monitor the values | |||
if f'{node.name}.state' in monitors: | |||
monitors[f'{node.name}.state'] = node.state.value | |||
if f'{node.name}.output' in monitors: | |||
@@ -1143,7 +1234,7 @@ class Network(Node): | |||
fig_size: tuple = (10, 10), | |||
node_size: int = 2000, | |||
arrow_size: int = 20, | |||
layout='spectral_layout'): | |||
layout='shell_layout'): | |||
"""Plot the node graph based on NetworkX package | |||
Parameters | |||
@@ -1155,7 +1246,17 @@ class Network(Node): | |||
arrow_size:int, default to 20 | |||
The size of the arrow | |||
layout: str | |||
The graph layout. More please see networkx Graph Layout. | |||
The graph layout. The supported layouts are: | |||
- "shell_layout" | |||
- "multipartite_layout" | |||
- "spring_layout" | |||
- "spiral_layout" | |||
- "spectral_layout" | |||
- "random_layout" | |||
- "planar_layout" | |||
- "kamada_kawai_layout" | |||
- "circular_layout" | |||
""" | |||
try: | |||
import networkx as nx | |||
@@ -1204,15 +1305,8 @@ class Network(Node): | |||
G.add_edges_from(fb_edges) | |||
G.add_edges_from(rec_edges) | |||
assert layout in ['shell_layout', | |||
'multipartite_layout', | |||
'spring_layout', | |||
'spiral_layout', | |||
'spectral_layout', | |||
'random_layout', | |||
'planar_layout', | |||
'kamada_kawai_layout', | |||
'circular_layout'] | |||
if layout not in SUPPORTED_LAYOUTS: | |||
raise UnsupportedError(f'Only support layouts: {SUPPORTED_LAYOUTS}') | |||
layout = getattr(nx, layout)(G) | |||
plt.figure(figsize=fig_size) | |||
@@ -1252,10 +1346,12 @@ class Network(Node): | |||
proxie = [] | |||
labels = [] | |||
if len(nodes_trainable): | |||
proxie.append(Line2D([], [], color='white', marker='o', markerfacecolor=trainable_color)) | |||
proxie.append(Line2D([], [], color='white', marker='o', | |||
markerfacecolor=trainable_color)) | |||
labels.append('Trainable') | |||
if len(nodes_untrainable): | |||
proxie.append(Line2D([], [], color='white', marker='o', markerfacecolor=untrainable_color)) | |||
proxie.append(Line2D([], [], color='white', marker='o', | |||
markerfacecolor=untrainable_color)) | |||
labels.append('Untrainable') | |||
if len(ff_edges): | |||
proxie.append(Line2D([], [], color=ff_color, linewidth=2)) | |||
@@ -1267,8 +1363,7 @@ class Network(Node): | |||
proxie.append(Line2D([], [], color=rec_color, linewidth=2)) | |||
labels.append('Recurrent') | |||
plt.legend(proxie, labels, scatterpoints=1, markerscale=2, | |||
loc='best') | |||
plt.legend(proxie, labels, scatterpoints=1, markerscale=2, loc='best') | |||
plt.tight_layout() | |||
plt.show() | |||
@@ -4,9 +4,8 @@ | |||
import jax.lax | |||
import brainpy.math as bm | |||
from brainpy.initialize import XavierNormal, ZeroInit | |||
from brainpy.initialize import XavierNormal, ZeroInit, init_param | |||
from brainpy.nn.base import Node | |||
from brainpy.nn.utils import init_param | |||
__all__ = [ | |||
'Conv2D', | |||
@@ -81,7 +80,7 @@ class Conv2D(Node): | |||
self.padding = padding | |||
self.groups = groups | |||
def init_ff(self): | |||
def init_ff_conn(self): | |||
assert self.num_input % self.groups == 0, '"num_input" should be divisible by "groups"'
size = _check_tuple(self.kernel_size) + (self.num_input // self.groups, self.num_output) | |||
self.w = init_param(self.w_init, size) | |||
@@ -90,7 +89,7 @@ class Conv2D(Node): | |||
self.w = bm.TrainVar(self.w) | |||
self.b = bm.TrainVar(self.b) | |||
def forward(self, ff, **kwargs): | |||
def forward(self, ff, **shared_kwargs): | |||
x = ff[0] | |||
nin = self.w.value.shape[2] * self.groups | |||
assert x.shape[1] == nin, (f'Attempting to convolve an input with {x.shape[1]} input channels ' | |||
@@ -44,11 +44,11 @@ class Dropout(Node): | |||
self.prob = prob | |||
self.rng = bm.random.RandomState(seed=seed) | |||
def init_ff(self): | |||
def init_ff_conn(self): | |||
self.set_output_shape(self.feedforward_shapes) | |||
def forward(self, ff, **kwargs): | |||
if kwargs.get('train', True): | |||
def forward(self, ff, **shared_kwargs): | |||
if shared_kwargs.get('train', True): | |||
keep_mask = self.rng.bernoulli(self.prob, ff.shape) | |||
return bm.where(keep_mask, ff / self.prob, 0.) | |||
else: | |||
@@ -1,14 +1,20 @@ | |||
# -*- coding: utf-8 -*- | |||
from typing import Union, Callable | |||
import brainpy.math as bm | |||
from brainpy.initialize import (XavierNormal, ZeroInit, | |||
Uniform, Orthogonal) | |||
from brainpy.initialize import (XavierNormal, | |||
ZeroInit, | |||
Uniform, | |||
Orthogonal, | |||
init_param, | |||
Initializer) | |||
from brainpy.nn.base import RecurrentNode | |||
from brainpy.nn.utils import init_param | |||
from brainpy.tools.checking import (check_integer, | |||
check_initializer, | |||
check_shape_consistency) | |||
from brainpy.types import Tensor | |||
__all__ = [ | |||
'VanillaRNN', | |||
@@ -33,19 +39,21 @@ class VanillaRNN(RecurrentNode): | |||
def __init__( | |||
self, | |||
num_unit: int, | |||
state_initializer=Uniform(), | |||
wi_initializer=XavierNormal(), | |||
wh_initializer=XavierNormal(), | |||
bias_initializer=ZeroInit(), | |||
activation='relu', | |||
trainable=True, | |||
state_initializer: Union[Tensor, Callable, Initializer] = Uniform(), | |||
wi_initializer: Union[Tensor, Callable, Initializer] = XavierNormal(), | |||
wh_initializer: Union[Tensor, Callable, Initializer] = XavierNormal(), | |||
bias_initializer: Union[Tensor, Callable, Initializer] = ZeroInit(), | |||
activation: str = 'relu', | |||
trainable: bool = True, | |||
**kwargs | |||
): | |||
super(VanillaRNN, self).__init__(trainable=trainable, **kwargs) | |||
self.num_unit = num_unit | |||
check_integer(num_unit, 'num_unit', min_bound=1, allow_none=False) | |||
self.set_output_shape((None, self.num_unit)) | |||
# initializers | |||
self._state_initializer = state_initializer | |||
self._wi_initializer = wi_initializer | |||
self._wh_initializer = wh_initializer | |||
@@ -55,23 +63,23 @@ class VanillaRNN(RecurrentNode): | |||
check_initializer(state_initializer, 'state_initializer', allow_none=False) | |||
check_initializer(bias_initializer, 'bias_initializer', allow_none=True) | |||
# activation function | |||
self.activation = bm.activations.get(activation) | |||
def init_ff(self): | |||
def init_ff_conn(self): | |||
unique_size, free_sizes = check_shape_consistency(self.feedforward_shapes, -1, True) | |||
assert len(unique_size) == 1, 'Only support data with or without batch size.' | |||
num_input = sum(free_sizes) | |||
self.set_output_shape(unique_size + (self.num_unit,)) | |||
# weights | |||
num_input = sum(free_sizes) | |||
self.Wff = init_param(self._wi_initializer, (num_input, self.num_unit)) | |||
self.Wrec = init_param(self._wh_initializer, (self.num_unit, self.num_unit)) | |||
self.bff = init_param(self._bias_initializer, (self.num_unit,)) | |||
self.bias = init_param(self._bias_initializer, (self.num_unit,)) | |||
if self.trainable: | |||
self.Wff = bm.TrainVar(self.Wff) | |||
self.Wrec = bm.TrainVar(self.Wrec) | |||
self.bff = None if (self.bff is None) else bm.TrainVar(self.bff) | |||
self.bias = None if (self.bias is None) else bm.TrainVar(self.bias) | |||
def init_fb(self): | |||
def init_fb_conn(self): | |||
unique_size, free_sizes = check_shape_consistency(self.feedback_shapes, -1, True) | |||
assert len(unique_size) == 1, 'Only support data with or without batch size.' | |||
num_feedback = sum(free_sizes) | |||
@@ -80,16 +88,15 @@ class VanillaRNN(RecurrentNode): | |||
if self.trainable: | |||
self.Wfb = bm.TrainVar(self.Wfb) | |||
def init_state(self, num_batch): | |||
state = init_param(self._state_initializer, (num_batch, self.num_unit)) | |||
self.set_state(state) | |||
def init_state(self, num_batch=1): | |||
return init_param(self._state_initializer, (num_batch, self.num_unit)) | |||
def forward(self, ff, fb=None, **kwargs): | |||
def forward(self, ff, fb=None, **shared_kwargs): | |||
ff = bm.concatenate(ff, axis=-1) | |||
h = ff @ self.Wff | |||
h += self.state.value @ self.Wrec | |||
if self.bff is not None: | |||
h += self.bff | |||
if self.bias is not None: | |||
h += self.bias | |||
if fb is not None: | |||
fb = bm.concatenate(fb, axis=-1) | |||
h += fb @ self.Wfb | |||
@@ -98,8 +105,7 @@ class VanillaRNN(RecurrentNode): | |||
class GRU(RecurrentNode): | |||
r""" | |||
Gated Recurrent Unit. | |||
r"""Gated Recurrent Unit. | |||
The implementation is based on (Chung, et al., 2014) [1]_ with biases. | |||
@@ -130,17 +136,18 @@ class GRU(RecurrentNode): | |||
def __init__( | |||
self, | |||
num_unit: int, | |||
wi_initializer=Orthogonal(), | |||
wh_initializer=Orthogonal(), | |||
bias_initializer=ZeroInit(), | |||
state_initializer=ZeroInit(), | |||
trainable=True, | |||
wi_initializer: Union[Tensor, Callable, Initializer] = Orthogonal(), | |||
wh_initializer: Union[Tensor, Callable, Initializer] = Orthogonal(), | |||
bias_initializer: Union[Tensor, Callable, Initializer] = ZeroInit(), | |||
state_initializer: Union[Tensor, Callable, Initializer] = ZeroInit(), | |||
trainable: bool = True, | |||
**kwargs | |||
): | |||
super(GRU, self).__init__(trainable=trainable, **kwargs) | |||
self.num_unit = num_unit | |||
check_integer(num_unit, 'num_unit', min_bound=1, allow_none=False) | |||
self.set_output_shape((None, self.num_unit)) | |||
self._wi_initializer = wi_initializer | |||
self._wh_initializer = wh_initializer | |||
@@ -151,30 +158,39 @@ class GRU(RecurrentNode): | |||
check_initializer(state_initializer, 'state_initializer', allow_none=False) | |||
check_initializer(bias_initializer, 'bias_initializer', allow_none=True) | |||
def init_ff(self): | |||
def init_ff_conn(self): | |||
# data shape | |||
unique_size, free_sizes = check_shape_consistency(self.feedforward_shapes, -1, True) | |||
assert len(unique_size) == 1, 'Only support data with or without batch size.' | |||
num_input = sum(free_sizes) | |||
self.set_output_shape(unique_size + (self.num_unit,)) | |||
# weights | |||
self.i_weight = init_param(self._wi_initializer, (num_input, self.num_unit * 3)) | |||
self.h_weight = init_param(self._wh_initializer, (self.num_unit, self.num_unit * 3)) | |||
num_input = sum(free_sizes) | |||
self.Wi_ff = init_param(self._wi_initializer, (num_input, self.num_unit * 3)) | |||
self.Wh = init_param(self._wh_initializer, (self.num_unit, self.num_unit * 3)) | |||
self.bias = init_param(self._bias_initializer, (self.num_unit * 3,)) | |||
if self.trainable: | |||
self.i_weight = bm.TrainVar(self.i_weight) | |||
self.h_weight = bm.TrainVar(self.h_weight) | |||
self.Wi_ff = bm.TrainVar(self.Wi_ff) | |||
self.Wh = bm.TrainVar(self.Wh) | |||
self.bias = bm.TrainVar(self.bias) if (self.bias is not None) else None | |||
def init_state(self, num_batch): | |||
state = init_param(self._state_initializer, (num_batch, self.num_unit)) | |||
self.set_state(state) | |||
def init_fb_conn(self): | |||
unique_size, free_sizes = check_shape_consistency(self.feedback_shapes, -1, True) | |||
assert len(unique_size) == 1, 'Only support data with or without batch size.' | |||
num_feedback = sum(free_sizes) | |||
# weights | |||
self.Wi_fb = init_param(self._wi_initializer, (num_feedback, self.num_unit * 3)) | |||
if self.trainable: | |||
self.Wi_fb = bm.TrainVar(self.Wi_fb) | |||
def forward(self, ff, fb=None, **kwargs): | |||
ff = bm.concatenate(ff, axis=-1) | |||
gates_x = bm.matmul(ff, self.i_weight) | |||
def init_state(self, num_batch=1): | |||
return init_param(self._state_initializer, (num_batch, self.num_unit)) | |||
def forward(self, ff, fb=None, **shared_kwargs): | |||
gates_x = bm.matmul(bm.concatenate(ff, axis=-1), self.Wi_ff) | |||
if fb is not None: | |||
gates_x += bm.matmul(bm.concatenate(fb, axis=-1), self.Wi_fb) | |||
zr_x, a_x = bm.split(gates_x, indices_or_sections=[2 * self.num_unit], axis=-1) | |||
w_h_z, w_h_a = bm.split(self.h_weight, indices_or_sections=[2 * self.num_unit], axis=-1) | |||
w_h_z, w_h_a = bm.split(self.Wh, indices_or_sections=[2 * self.num_unit], axis=-1) | |||
zr_h = bm.matmul(self.state, w_h_z) | |||
zr = zr_x + zr_h | |||
has_bias = (self.bias is not None) | |||
@@ -235,48 +251,62 @@ class LSTM(RecurrentNode): | |||
def __init__( | |||
self, | |||
num_unit: int, | |||
weight_initializer=Orthogonal(), | |||
bias_initializer=ZeroInit(), | |||
state_initializer=ZeroInit(), | |||
trainable=True, | |||
wi_initializer: Union[Tensor, Callable, Initializer] = XavierNormal(), | |||
wh_initializer: Union[Tensor, Callable, Initializer] = XavierNormal(), | |||
bias_initializer: Union[Tensor, Callable, Initializer] = ZeroInit(), | |||
state_initializer: Union[Tensor, Callable, Initializer] = ZeroInit(), | |||
trainable: bool = True, | |||
**kwargs | |||
): | |||
super(LSTM, self).__init__(trainable=trainable, **kwargs) | |||
self.num_unit = num_unit | |||
check_integer(num_unit, 'num_unit', min_bound=1, allow_none=False) | |||
self.set_output_shape((None, self.num_unit,)) | |||
self._state_initializer = state_initializer | |||
self._weight_initializer = weight_initializer | |||
self._wi_initializer = wi_initializer | |||
self._wh_initializer = wh_initializer | |||
self._bias_initializer = bias_initializer | |||
check_initializer(weight_initializer, 'weight_initializer', allow_none=False) | |||
check_initializer(wi_initializer, 'wi_initializer', allow_none=False) | |||
check_initializer(wh_initializer, 'wh_initializer', allow_none=False) | |||
check_initializer(bias_initializer, 'bias_initializer', allow_none=True) | |||
check_initializer(state_initializer, 'state_initializer', allow_none=False) | |||
def init_ff(self): | |||
def init_ff_conn(self): | |||
# data shape | |||
unique_size, free_sizes = check_shape_consistency(self.feedforward_shapes, -1, True) | |||
assert len(unique_size) == 1, 'Only support data with or without batch size.' | |||
num_input = sum(free_sizes) | |||
self.set_output_shape(unique_size + (self.num_unit,)) | |||
# weights | |||
self.weight = init_param(self._weight_initializer, (num_input + self.num_unit, self.num_unit * 4)) | |||
num_input = sum(free_sizes) | |||
self.Wi_ff = init_param(self._wi_initializer, (num_input, self.num_unit * 4)) | |||
self.Wh = init_param(self._wh_initializer, (self.num_unit, self.num_unit * 4)) | |||
self.bias = init_param(self._bias_initializer, (self.num_unit * 4,)) | |||
if self.trainable: | |||
self.weight = bm.TrainVar(self.weight) | |||
self.Wi_ff = bm.TrainVar(self.Wi_ff) | |||
self.Wh = bm.TrainVar(self.Wh) | |||
self.bias = None if (self.bias is None) else bm.TrainVar(self.bias) | |||
def init_state(self, num_batch): | |||
hc = init_param(self._state_initializer, (num_batch * 2, self.num_unit)) | |||
self.set_state(hc) | |||
def init_fb_conn(self): | |||
unique_size, free_sizes = check_shape_consistency(self.feedback_shapes, -1, True) | |||
assert len(unique_size) == 1, 'Only support data with or without batch size.' | |||
num_feedback = sum(free_sizes) | |||
# weights | |||
self.Wi_fb = init_param(self._wi_initializer, (num_feedback, self.num_unit * 4)) | |||
if self.trainable: | |||
self.Wi_fb = bm.TrainVar(self.Wi_fb) | |||
def init_state(self, num_batch=1): | |||
return init_param(self._state_initializer, (num_batch * 2, self.num_unit)) | |||
def forward(self, ff, fb=None, **kwargs): | |||
def forward(self, ff, fb=None, **shared_kwargs): | |||
h, c = bm.split(self.state, 2) | |||
xh = bm.concatenate(tuple(ff) + (h,), axis=-1) | |||
if self.bias is None: | |||
gated = xh @ self.weight | |||
else: | |||
gated = xh @ self.weight + self.bias | |||
gated = bm.concatenate(ff, axis=-1) @ self.Wi_ff | |||
if fb is not None: | |||
gated += bm.concatenate(fb, axis=-1) @ self.Wi_fb | |||
if self.bias is not None: | |||
gated += self.bias | |||
gated += h @ self.Wh | |||
i, g, f, o = bm.split(gated, indices_or_sections=4, axis=-1) | |||
c = bm.sigmoid(f + 1.) * c + bm.sigmoid(i) * bm.tanh(g) | |||
h = bm.sigmoid(o) * bm.tanh(c) | |||
@@ -291,7 +321,7 @@ class LSTM(RecurrentNode): | |||
@h.setter | |||
def h(self, value): | |||
if self.state is None: | |||
raise ValueError('Cannot set "h" state. Because it is not initialized.') | |||
raise ValueError('Cannot set "h" state. Because the state is not initialized.') | |||
self.state[:self.state.shape[0] // 2, :] = value | |||
@property | |||
@@ -302,7 +332,7 @@ class LSTM(RecurrentNode): | |||
@c.setter | |||
def c(self, value): | |||
if self.state is None: | |||
raise ValueError('Cannot set "c" state. Because it is not initialized.') | |||
raise ValueError('Cannot set "c" state. Because the state is not initialized.') | |||
self.state[self.state.shape[0] // 2:, :] = value | |||
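# Sketch (hypothetical sizes): the LSTM packs h and c along the batch axis,
# so a state initialized with ``num_batch=4`` has shape (8, num_unit) and
# splits as:
#
#   h, c = bm.split(lstm.state, 2)  # each of shape (4, num_unit)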
@@ -39,12 +39,12 @@ class LinearReadout(Dense): | |||
super(LinearReadout, self).__init__(num_unit=num_unit, weight_initializer=weight_initializer, bias_initializer=bias_initializer, **kwargs) | |||
def init_state(self, num_batch=1): | |||
state = bm.Variable(bm.zeros((num_batch,) + self.output_shape[1:], dtype=bm.float_)) | |||
self.set_state(state) | |||
return bm.zeros((num_batch,) + self.output_shape[1:], dtype=bm.float_) | |||
def forward(self, ff, fb=None, **kwargs): | |||
self.state.value = super(LinearReadout, self).forward(ff, fb=fb, **kwargs) | |||
return self.state | |||
def forward(self, ff, fb=None, **shared_kwargs): | |||
h = super(LinearReadout, self).forward(ff, fb=fb, **shared_kwargs) | |||
self.state.value = h | |||
return h | |||
def __force_init__(self, train_pars: Optional[Dict] = None): | |||
if train_pars is None: train_pars = dict() | |||
@@ -76,4 +76,4 @@ class LinearReadout(Dense): | |||
# update the weights | |||
e = bm.atleast_2d(self.state - target) # (1, num_output) | |||
dw = bm.dot(-c * k, e) # (num_hidden, num_output) | |||
self.weights += dw | |||
self.Wff += dw |
@@ -1,7 +1,7 @@ | |||
# -*- coding: utf-8 -*- | |||
from itertools import combinations_with_replacement | |||
from typing import Union | |||
from typing import Union, Sequence | |||
import numpy as np | |||
@@ -9,7 +9,8 @@ import brainpy.math as bm | |||
from brainpy.nn.base import RecurrentNode | |||
from brainpy.tools.checking import (check_shape_consistency, | |||
check_float, | |||
check_integer) | |||
check_integer, | |||
check_sequence) | |||
__all__ = [ | |||
'NVAR' | |||
@@ -46,7 +47,7 @@ class NVAR(RecurrentNode): | |||
---------- | |||
delay: int | |||
The number of delay steps.
order: int | |||
order: int, sequence of int | |||
The nonlinear order. | |||
stride: int | |||
The stride used to sample the linear-part vector from the delays.
@@ -63,59 +64,67 @@ class NVAR(RecurrentNode): | |||
def __init__(self, | |||
delay: int, | |||
order: int, | |||
order: Union[int, Sequence[int]], | |||
stride: int = 1, | |||
constant: Union[float, int] = None, | |||
**kwargs): | |||
super(NVAR, self).__init__(**kwargs) | |||
self.delay = delay | |||
if not isinstance(order, (tuple, list)): | |||
order = [order] | |||
self.order = order | |||
self.stride = stride | |||
self.constant = constant | |||
check_sequence(order, 'order', elem_type=int, allow_none=False) | |||
self.delay = delay | |||
check_integer(delay, 'delay', allow_none=False) | |||
check_integer(order, 'order', allow_none=False) | |||
self.stride = stride | |||
check_integer(stride, 'stride', allow_none=False) | |||
self.constant = constant | |||
check_float(constant, 'constant', allow_none=True, allow_int=True) | |||
def init_ff(self): | |||
self.comb_ids = [] | |||
# delay variables | |||
self.num_delay = self.delay * self.stride | |||
self.idx = bm.Variable(bm.array([0], dtype=bm.uint32)) | |||
self.store = None | |||
def init_ff_conn(self): | |||
"""Initialize feedforward connections.""" | |||
# input dimension | |||
batch_size, free_size = check_shape_consistency(self.feedforward_shapes, -1, True) | |||
self.input_dim = sum(free_size) | |||
assert batch_size == (None,), f'batch_size must be None, but got {batch_size}' | |||
# linear dimension | |||
linear_dim = self.delay * self.input_dim | |||
# for each monomial created in the non linear part, indices | |||
# for each monomial created in the non-linear part, indices | |||
# of the n components involved, n being the order of the | |||
# monomials. Precompute them to improve efficiency. | |||
idx = np.array(list(combinations_with_replacement(np.arange(linear_dim), self.order))) | |||
self.comb_ids = bm.asarray(idx) | |||
# number of non linear components is (d + n - 1)! / (d - 1)! n! | |||
for order in self.order: | |||
idx = np.array(list(combinations_with_replacement(np.arange(linear_dim), order))) | |||
self.comb_ids.append(bm.asarray(idx)) | |||
# number of non-linear components is (d + n - 1)! / (d - 1)! n! | |||
# i.e. number of all unique monomials of order n made from the | |||
# linear components. | |||
nonlinear_dim = len(self.comb_ids) | |||
nonlinear_dim = sum([len(ids) for ids in self.comb_ids]) | |||
# output dimension | |||
output_dim = int(linear_dim + nonlinear_dim) | |||
self.output_dim = int(linear_dim + nonlinear_dim) | |||
if self.constant is not None: | |||
output_dim += 1 | |||
self.set_output_shape((None, output_dim)) | |||
# delay variables | |||
self.num_delay = self.delay * self.stride | |||
self.idx = bm.Variable(bm.array([0], dtype=bm.uint32)) | |||
self.store = None | |||
self.output_dim += 1 | |||
self.set_output_shape((None, self.output_dim)) | |||
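# A quick check of the dimension formula above (hypothetical numbers): for
# linear dimension ``d`` and monomial order ``n``, combinations_with_replacement
# yields C(d + n - 1, n) = (d + n - 1)! / ((d - 1)! n!) index tuples.
#
#   from math import comb  # Python >= 3.8
#   assert comb(6 + 2 - 1, 2) == len(list(combinations_with_replacement(range(6), 2)))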
def init_state(self, num_batch=1): | |||
# to store the k*s last inputs, k being the delay and s the strides | |||
"""Initialize the node state which depends on batch size.""" | |||
# To store the last inputs. | |||
# Note, the batch axis is not in the first dimension, so we | |||
# manually handle the state of NVAR rather than returning it.
state = bm.zeros((self.num_delay, num_batch, self.input_dim), dtype=bm.float_) | |||
if self.store is None: | |||
self.store = bm.Variable(state) | |||
else: | |||
self.store.value = state | |||
def forward(self, ff, fb=None, **kwargs): | |||
# 1. store the current input | |||
def forward(self, ff, fb=None, **shared_kwargs): | |||
all_parts = [] | |||
# 1. Store the current input | |||
ff = bm.concatenate(ff, axis=-1) | |||
self.store[self.idx[0]] = ff | |||
self.idx.value = (self.idx + 1) % self.num_delay | |||
@@ -124,12 +133,15 @@ class NVAR(RecurrentNode): | |||
select_ids = (self.idx[0] + bm.arange(self.num_delay)[::self.stride]) % self.num_delay | |||
linear_parts = bm.moveaxis(self.store[select_ids], 0, 1) # (num_batch, num_time, num_feature) | |||
linear_parts = bm.reshape(linear_parts, (linear_parts.shape[0], -1)) | |||
# 3. constant | |||
if self.constant is not None: | |||
constant = bm.broadcast_to(self.constant, linear_parts.shape[:-1] + (1,)) | |||
all_parts.append(constant) | |||
all_parts.append(linear_parts) | |||
# 3. Nonlinear part: | |||
# select monomial terms and compute them | |||
nonlinear_parts = bm.prod(linear_parts[:, self.comb_ids], axis=2) | |||
if self.constant is None: | |||
return bm.concatenate([linear_parts, nonlinear_parts], axis=-1) | |||
else: | |||
constant = bm.broadcast_to(self.constant, linear_parts.shape[:-1] + (1,)) | |||
return bm.concatenate([constant, linear_parts, nonlinear_parts], axis=-1) | |||
for ids in self.comb_ids: | |||
all_parts.append(bm.prod(linear_parts[:, ids], axis=2)) | |||
# 4. Return all parts | |||
return bm.concatenate(all_parts, axis=-1) | |||
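# Usage sketch (hypothetical; assumes the node is exposed as ``bp.nn.NVAR``):
# with ``delay=2, order=[2], stride=1`` and 3 input features, the output has
# 6 linear features plus C(6 + 2 - 1, 2) = 21 order-2 monomials, i.e. 27 in
# total (plus 1 if ``constant`` is given).
#
#   node = bp.nn.NVAR(delay=2, order=[2], input_shape=(None, 3))
#   node.initialize(num_batch=1)
#   out = node(bm.ones((1, 3)))  # shape (1, 27)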
@@ -3,9 +3,8 @@ | |||
from typing import Optional, Union, Callable | |||
import brainpy.math as bm | |||
from brainpy.initialize import Normal, ZeroInit, Initializer | |||
from brainpy.initialize import Normal, ZeroInit, Initializer, init_param | |||
from brainpy.nn.base import RecurrentNode | |||
from brainpy.nn.utils import init_param | |||
from brainpy.tools.checking import (check_shape_consistency, | |||
check_float, | |||
check_initializer, | |||
@@ -158,7 +157,7 @@ class Reservoir(RecurrentNode): | |||
self.noise_type = noise_type | |||
check_string(noise_type, 'noise_type', ['normal', 'uniform']) | |||
def init_ff(self): | |||
def init_ff_conn(self): | |||
"""Initialize feedforward connections, weights, and variables.""" | |||
unique_shape, free_shapes = check_shape_consistency(self.feedforward_shapes, -1, True) | |||
self.set_output_shape(unique_shape + (self.num_unit,)) | |||
@@ -197,10 +196,9 @@ class Reservoir(RecurrentNode): | |||
def init_state(self, num_batch=1): | |||
# initialize internal state | |||
state = bm.Variable(bm.zeros((num_batch, self.num_unit), dtype=bm.float_)) | |||
self.set_state(state) | |||
return bm.zeros((num_batch, self.num_unit), dtype=bm.float_) | |||
def init_fb(self): | |||
def init_fb_conn(self): | |||
"""Initialize feedback connections, weights, and variables.""" | |||
if self.feedback_shapes is not None: | |||
unique_shape, free_shapes = check_shape_consistency(self.feedback_shapes, -1, True) | |||
@@ -215,7 +213,7 @@ class Reservoir(RecurrentNode): | |||
if self.trainable: | |||
self.Wfb = bm.TrainVar(self.Wfb) | |||
def forward(self, ff, fb=None, **kwargs): | |||
def forward(self, ff, fb=None, **shared_kwargs): | |||
"""Feedforward output.""" | |||
# inputs | |||
x = bm.concatenate(ff, axis=-1) | |||