#32 Latest updates on code and docs

Merged
BrainPy merged 160 commits from sync_openi into master 1 year ago
Changed files (100):

1. .github/ISSUE_TEMPLATE/bug_report.md (+1, -1)
2. .github/ISSUE_TEMPLATE/config.yml (+0, -5)
3. .github/ISSUE_TEMPLATE/feature_request.md (+10, -0)
4. .github/workflows/Linux_CI.yml (+1, -2)
5. .github/workflows/MacOS_CI.yml (+1, -3)
6. .github/workflows/Windows_CI.yml (+2, -2)
7. .gitignore (+1, -0)
8. LICENSE (+201, -674)
9. README.md (+21, -10)
10. brainpy/__init__.py (+50, -52)
11. brainpy/analysis/__init__.py (+1, -3)
12. brainpy/analysis/highdim/slow_points.py (+4, -6)
13. brainpy/analysis/lowdim/lowdim_analyzer.py (+8, -1)
14. brainpy/analysis/lowdim/lowdim_bifurcation.py (+53, -17)
15. brainpy/analysis/lowdim/lowdim_phase_plane.py (+12, -11)
16. brainpy/analysis/lowdim/tests/test_bifurcation.py (+89, -0)
17. brainpy/analysis/lowdim/tests/test_phase_plane.py (+1, -3)
18. brainpy/analysis/plotstyle.py (+72, -0)
19. brainpy/analysis/stability.py (+3, -29)
20. brainpy/analysis/utils/model.py (+4, -3)
21. brainpy/check.py (+54, -0)
22. brainpy/connect/base.py (+357, -105)
23. brainpy/connect/custom_conn.py (+54, -39)
24. brainpy/connect/random_conn.py (+256, -104)
25. brainpy/connect/regular_conn.py (+189, -104)
26. brainpy/connect/tests/test_random_conn.py (+24, -14)
27. brainpy/connect/tests/test_regular_conn.py (+81, -37)
28. brainpy/datasets/vision/cifar.py (+167, -0)
29. brainpy/dyn/base.py (+17, -13)
30. brainpy/dyn/layers/__init__.py (+2, -1)
31. brainpy/dyn/layers/activate.py (+36, -0)
32. brainpy/dyn/layers/conv.py (+327, -130)
33. brainpy/dyn/layers/dropout.py (+4, -6)
34. brainpy/dyn/layers/linear.py (+30, -1)
35. brainpy/dyn/layers/normalization.py (+378, -278)
36. brainpy/dyn/layers/nvar.py (+3, -3)
37. brainpy/dyn/layers/pooling.py (+242, -124)
38. brainpy/dyn/layers/rnncells.py (+80, -38)
39. brainpy/dyn/layers/tests/test_normalization.py (+0, -201)
40. brainpy/dyn/layers/tests/test_pooling.py (+29, -14)
41. brainpy/dyn/neurons/biological_models.py (+53, -56)
42. brainpy/dyn/neurons/input_groups.py (+8, -6)
43. brainpy/dyn/neurons/noise_groups.py (+2, -2)
44. brainpy/dyn/neurons/reduced_models.py (+148, -129)
45. brainpy/dyn/rates/populations.py (+12, -12)
46. brainpy/dyn/runners.py (+4, -5)
47. brainpy/dyn/synapses/abstract_models.py (+117, -29)
48. brainpy/dyn/synapses/biological_models.py (+8, -8)
49. brainpy/dyn/synapses/gap_junction.py (+2, -2)
50. brainpy/dyn/tests/test_base_classes.py (+20, -0)
51. brainpy/dyn/tests/test_dyn_runner.py (+47, -0)
52. brainpy/initialize/generic.py (+110, -10)
53. brainpy/inputs/currents.py (+2, -2)
54. brainpy/integrators/ode/tests/test_ode_method_adaptive_rk.py (+1, -1)
55. brainpy/integrators/runner.py (+14, -17)
56. brainpy/integrators/sde/tests/test_sde_scalar.py (+1, -1)
57. brainpy/math/autograd.py (+9, -1)
58. brainpy/math/controls.py (+53, -37)
59. brainpy/math/delayvars.py (+110, -25)
60. brainpy/math/jaxarray.py (+322, -204)
61. brainpy/math/jit.py (+13, -9)
62. brainpy/math/operators/__init__.py (+7, -12)
63. brainpy/math/operators/event_matmul.py (+52, -0)
64. brainpy/math/operators/op_register.py (+18, -59)
65. brainpy/math/operators/pre2post.py (+0, -489)
66. brainpy/math/operators/pre2syn.py (+0, -47)
67. brainpy/math/operators/pre_syn_post.py (+625, -0)
68. brainpy/math/operators/sparse_matmul.py (+57, -16)
69. brainpy/math/operators/spikegrad.py (+2, -2)
70. brainpy/math/operators/syn2post.py (+0, -235)
71. brainpy/math/operators/tests/test_op_register.py (+3, -45)
72. brainpy/math/operators/utils.py (+0, -28)
73. brainpy/math/operators/wrap_jax.py (+139, -9)
74. brainpy/math/random.py (+68, -28)
75. brainpy/math/setting.py (+14, -2)
76. brainpy/math/tests/test_delay_vars.py (+40, -26)
77. brainpy/math/tests/test_jaxarray.py (+51, -0)
78. brainpy/math/tests/test_numpy_einsum.py (+5, -0)
79. brainpy/math/tests/test_numpy_indexing.py (+5, -1)
80. brainpy/math/tests/test_numpy_ops.py (+6, -1)
81. brainpy/math/tests/test_transformation_context.py (+56, -0)
82. brainpy/measure/__init__.py (+4, -0)
83. brainpy/measure/correlation.py (+124, -73)
84. brainpy/measure/firings.py (+13, -18)
85. brainpy/measure/lfp.py (+114, -0)
86. brainpy/measure/tests/test_correlation.py (+33, -7)
87. brainpy/modes.py (+3, -3)
88. brainpy/optimizers/optimizer.py (+2, -2)
89. brainpy/running/__init__.py (+19, -2)
90. brainpy/running/jax_multiprocessing.py (+141, -0)
91. brainpy/running/native_multiprocessing.py (+14, -23)
92. brainpy/running/pathos_multiprocessing.py (+228, -0)
93. brainpy/tools/others/numba_util.py (+3, -2)
94. brainpy/tools/others/others.py (+22, -1)
95. brainpy/train/back_propagation.py (+12, -5)
96. brainpy/train/offline.py (+1, -0)
97. brainpy/train/online.py (+4, -5)
98. docs/auto_generater.py (+18, -135)
99. docs/conf.py (+5, -10)
100. docs/index.rst (+9, -2)

.github/ISSUE_TEMPLATE/bug_report.md (+1, -1)

@@ -1,6 +1,6 @@
 ---
 name: 'Bug Report'
-about: 'Report a bug to help improve the package'
+about: 'Report a bug or unexpected behavior to help us improve the package'
 labels: 'bug'
 ---



.github/ISSUE_TEMPLATE/config.yml (+0, -5)

@@ -1,5 +0,0 @@
-blank_issues_enabled: false
-contact_links:
-  - name: Question
-    url: https://github.com/PKU-NIP-Lab/BrainPy/discussions
-    about: Please ask questions on the Discussions tab

.github/ISSUE_TEMPLATE/feature_request.md (+10, -0)

@@ -0,0 +1,10 @@
+---
+name: 'Feature Request'
+about: 'Suggest a new idea or improvement for BrainPy'
+labels: 'enhancement'
+---
+
+Please:
+
+- [ ] Check for duplicate requests.
+- [ ] Describe your goal, and if possible provide a code snippet with a motivating example.

.github/workflows/Linux_CI.yml (+1, -2)

@@ -28,7 +28,6 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         python -m pip install flake8 pytest
-        # python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.14.tar.gz
         if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
         python setup.py install
     - name: Lint with flake8
@@ -36,7 +35,7 @@ jobs:
         # stop the build if there are Python syntax errors or undefined names
         flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
         # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-        # flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
     - name: Test with pytest
       run: |
         pytest brainpy/

.github/workflows/MacOS_CI.yml (+1, -3)

@@ -28,8 +28,6 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         python -m pip install flake8 pytest
-        python -m pip install jax==0.3.14
-        python -m pip install jaxlib==0.3.14
         if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
         python setup.py install
     - name: Lint with flake8
@@ -37,7 +35,7 @@ jobs:
         # stop the build if there are Python syntax errors or undefined names
         flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
         # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-        # flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
     - name: Test with pytest
       run: |
         pytest brainpy/

.github/workflows/Windows_CI.yml (+2, -2)

@@ -31,7 +31,7 @@
         python -m pip install numpy>=1.21.0
         python -m pip install "jaxlib==0.3.14" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
         python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.14.tar.gz
-        python -m pip install -r requirements-win.txt
+        python -m pip install -r requirements-dev.txt
         python -m pip install tqdm brainpylib
         python setup.py install
     - name: Lint with flake8
@@ -39,7 +39,7 @@
         # stop the build if there are Python syntax errors or undefined names
         flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
         # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-        # flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
     - name: Test with pytest
       run: |
         pytest brainpy/

.gitignore (+1, -0)

@@ -10,6 +10,7 @@ development
 brainpy/dyn/tests/data
 examples/simulation/data
 examples/simulation/results
+examples/ANN_models/data
 examples/analysis/data
 extensions/.idea
 extensions/wheelhouse

LICENSE (+201, -674)

@@ -1,674 +1,201 @@
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007

Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.

Preamble

The GNU General Public License is a free, copyleft license for
software and other kinds of works.

The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.

When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.

To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.

For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.

Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.

For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software. For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.

Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so. This is fundamentally incompatible with the aim of
protecting users' freedom to change the software. The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable. Therefore, we
have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.

Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary. To prevent this, the GPL assures that
patents cannot be used to render the program non-free.

The precise terms and conditions for copying, distribution and
modification follow.

TERMS AND CONDITIONS

0. Definitions.

"This License" refers to version 3 of the GNU General Public License.

"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.

"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.

To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.

A "covered work" means either the unmodified Program or a work based
on the Program.

To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.

To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.

An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.

1. Source Code.

The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.

A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.

The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.

The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.

The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.

The Corresponding Source for a work in source code form is that
same work.

2. Basic Permissions.

All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.

You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.

Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.

3. Protecting Users' Legal Rights From Anti-Circumvention Law.

No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.

When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.

4. Conveying Verbatim Copies.

You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.

You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.

5. Conveying Modified Source Versions.

You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:

a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.

b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".

c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.

d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.

A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.

6. Conveying Non-Source Forms.

You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:

a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.

b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.

c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.

d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.

e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.

A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.

A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.

"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.

If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).

The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.

Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.

7. Additional Terms.

"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.

When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.

Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:

a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or

b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or

c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or

d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or

e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or

f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.

All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.

If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.

Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.

8. Termination.

You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).

However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.

Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.

Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.

9. Acceptance Not Required for Having Copies.

You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.

10. Automatic Licensing of Downstream Recipients.

Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.

An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.

You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.

11. Patents.

A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".

A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.

Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.

In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.

If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.

If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.

A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.

Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.

12. No Surrender of Others' Freedom.

If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.

13. Use with the GNU Affero General Public License.

Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.

14. Revised Versions of this License.

The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.

Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.

If the Program specifies that a proxy can decide which future
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.

Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.

15. Disclaimer of Warranty.

THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.

16. Limitation of Liability.

IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.

17. Interpretation of Sections 15 and 16.

If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.

END OF TERMS AND CONDITIONS

How to Apply These Terms to Your New Programs

If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.

To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.

BrainPy: Brain Dynamics Programming in Python
Copyright (C) 2022 BrainPy team

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.

Also add information on how to contact you by electronic and paper mail.

If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:

BrainPy Copyright (C) 2022 BrainPy team
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
This is free software, and you are welcome to redistribute it
under certain conditions; type `show c' for details.

The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".

You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.

The GNU General Public License does not permit incorporating your program
into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

README.md (+21, -10)

@@ -1,5 +1,5 @@
<p align="center">
<img alt="Header image of BrainPy - brain dynamics programming in Python." src="./images/logo.png" width=80%>
<img alt="Header image of BrainPy - brain dynamics programming in Python." src="https://github.com/PKU-NIP-Lab/BrainPy/blob/master/images/logo.png" width=80%>
</p>


@@ -10,20 +10,28 @@
<a href="https://brainpy.readthedocs.io/en/latest/?badge=latest"><img alt="Documentation" src="https://readthedocs.org/projects/brainpy/badge/?version=latest"></a>
<a href="https://badge.fury.io/py/brainpy"><img alt="PyPI version" src="https://badge.fury.io/py/brainpy.svg"></a>
<a href="https://github.com/PKU-NIP-Lab/BrainPy"><img alt="Linux CI" src="https://github.com/PKU-NIP-Lab/BrainPy/actions/workflows/Linux_CI.yml/badge.svg"></a>
<a href="https://github.com/PKU-NIP-Lab/BrainPy"><img alt="Linux CI" src="https://github.com/PKU-NIP-Lab/BrainPy/actions/workflows/Windows_CI.yml/badge.svg"></a>
<a href="https://github.com/PKU-NIP-Lab/BrainPy"><img alt="Linux CI" src="https://github.com/PKU-NIP-Lab/BrainPy/actions/workflows/MacOS_CI.yml/badge.svg"></a>
<a href="https://github.com/PKU-NIP-Lab/BrainPy"><img alt="Windows CI" src="https://github.com/PKU-NIP-Lab/BrainPy/actions/workflows/Windows_CI.yml/badge.svg"></a>
<a href="https://github.com/PKU-NIP-Lab/BrainPy"><img alt="MacOS CI" src="https://github.com/PKU-NIP-Lab/BrainPy/actions/workflows/MacOS_CI.yml/badge.svg"></a>
</p>




-BrainPy is a flexible, efficient, and extensible framework for computational neuroscience and brain-inspired computation based on the Just-In-Time (JIT) compilation (built on top of [JAX](https://github.com/google/jax)). It provides an integrative ecosystem for brain dynamics programming, including brain dynamics **simulation**, **training**, **analysis**, etc.
+BrainPy is a flexible, efficient, and extensible framework for computational neuroscience and brain-inspired computation based on the Just-In-Time (JIT) compilation (built on top of [JAX](https://github.com/google/jax), [Numba](https://github.com/numba/numba), and other JIT compilers). It provides an integrative ecosystem for brain dynamics programming, including brain dynamics **building**, **simulation**, **training**, **analysis**, etc.

- **Website (documentation and APIs)**: https://brainpy.readthedocs.io/en/latest
- **Source**: https://github.com/PKU-NIP-Lab/BrainPy
- **Bug reports**: https://github.com/PKU-NIP-Lab/BrainPy/issues
- **Source on OpenI**: https://git.openi.org.cn/OpenI/BrainPy
- **Examples from literature**: https://brainpy-examples.readthedocs.io/



## Ecosystem

- **[BrainPy](https://github.com/PKU-NIP-Lab/BrainPy)**: The solution for general-purpose brain dynamics programming.
- **[brainpylib](https://github.com/PKU-NIP-Lab/brainpylib)**: Efficient operators for sparse and event-driven computation.
- **[BrainPyExamples](https://github.com/PKU-NIP-Lab/BrainPyExamples)**: Comprehensive examples of BrainPy computation.
- **[brainpy-largescale](https://github.com/NH-NCL/brainpy-largescale)**: One solution for large-scale brain modeling.



@@ -35,13 +43,16 @@ BrainPy is based on Python (>=3.7) and can be installed on Linux (Ubuntu 16.04
$ pip install brainpy -U
```

The following packages are required for ``BrainPy``:

`numpy >= 1.15` and `jax >= 0.3.0` ([how to install jax?](https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#dependency-2-jax))

For detailed installation instructions, please refer to the documentation: [Quickstart/Installation](https://brainpy.readthedocs.io/en/latest/quickstart/installation.html)
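A minimal way to sanity-check the installation after the steps above (the printed version string simply reflects whichever release was installed):

```python
# Quick post-install check: import BrainPy and its math module.
import brainpy as bp
import brainpy.math as bm

print(bp.__version__)   # e.g. "2.2.4.0"
print(bm.zeros(3))      # a JAX-backed array created through brainpy.math
```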



## License

[GNU General Public License v3.0](https://github.com/PKU-NIP-Lab/BrainPy/blob/master/LICENSE)
[Apache License, Version 2.0](https://github.com/PKU-NIP-Lab/BrainPy/blob/master/LICENSE)



## Citing

If you are using BrainPy, please consider citing [the corresponding papers](https://brainpy.readthedocs.io/en/latest/tutorial_FAQs/citing_and_publication.html).

+ 50
- 52
brainpy/__init__.py View File

@@ -1,44 +1,30 @@
# -*- coding: utf-8 -*-

__version__ = "2.2.1"


try:
import jaxlib
del jaxlib
except ModuleNotFoundError:
raise ModuleNotFoundError(
'Please install jaxlib. See '
'https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#dependency-2-jax '
'for installation instructions.'
)
__version__ = "2.2.4.0"


# fundamental modules
from . import errors, tools, check, modes


# "base" module
from . import base
from .base.base import Base
from .base.collector import Collector, TensorCollector


# math foundation
from . import math


# toolboxes
from . import (connect, # synaptic connection
initialize, # weight initialization
optimizers, # gradient descent optimizers
losses, # loss functions
measure, # methods for data analysis
datasets, # methods for generating data
inputs, # methods for generating input currents
algorithms, # online or offline training algorithms
)
from . import (
connect, # synaptic connection
initialize, # weight initialization
optimizers, # gradient descent optimizers
losses, # loss functions
measure, # methods for data analysis
datasets, # methods for generating data
inputs, # methods for generating input currents
algorithms, # online or offline training algorithms
)

# numerical integrators
from . import integrators
@@ -50,49 +36,61 @@ from .integrators.sde import sdeint
from .integrators.fde import fdeint
from .integrators.joint_eq import JointEq


# dynamics simulation
from . import dyn
from .dyn import (channels, # channel models
layers, # ANN layers
networks, # network models
neurons, # neuron groups
rates, # rate models
synapses, # synaptic dynamics
synouts, # synaptic output
synplast, # synaptic plasticity
)
from brainpy.dyn.base import (DynamicalSystem,
Container,
Sequential,
Network,
NeuGroup,
SynConn,
SynOut,
SynSTP,
SynLTP,
TwoEndConn,
CondNeuGroup,
Channel,)
from .dyn import (
channels, # channel models
layers, # ANN layers
networks, # network models
neurons, # neuron groups
rates, # rate models
synapses, # synaptic dynamics
synouts, # synaptic output
synplast, # synaptic plasticity
)
from .dyn.base import (
DynamicalSystem,
Container,
Sequential,
Network,
NeuGroup,
SynConn,
SynOut,
SynSTP,
SynLTP,
TwoEndConn,
CondNeuGroup,
Channel,
)
from .dyn.runners import *


# dynamics training
from . import train

from .train import (
DSTrainer,
OnlineTrainer, ForceTrainer,
OfflineTrainer, RidgeTrainer,
BPFF,
BPTT,
OnlineBPTT,
)

# automatic dynamics analysis
from . import analysis

from .analysis import (
DSAnalyzer,
PhasePlane1D, PhasePlane2D,
Bifurcation1D, Bifurcation2D,
FastSlow1D, FastSlow2D,
SlowPointFinder,
)

# running
from . import running


# "visualization" module, will be removed soon
from .visualization import visualize


# convenient access
conn = connect
init = initialize
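For illustration, a minimal sketch of how these shorthand aliases are meant to be used; `FixedProb` appears later in this PR, `Normal` is an existing initializer, and the parameter values here are only indicative:

```python
import brainpy as bp

# `bp.conn` and `bp.init` are aliases of `bp.connect` and `bp.initialize`.
connector = bp.conn.FixedProb(prob=0.1)    # same class as bp.connect.FixedProb
initializer = bp.init.Normal(scale=0.1)    # same class as bp.initialize.Normal
```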


+ 1
- 3
brainpy/analysis/__init__.py View File

@@ -22,6 +22,4 @@ from .lowdim.lowdim_phase_plane import *
from .lowdim.lowdim_bifurcation import *

from .constants import *
from . import constants as C
from . import stability
from . import utils
from . import constants as C, stability, plotstyle, utils

+ 4
- 6
brainpy/analysis/highdim/slow_points.py View File

@@ -355,8 +355,7 @@ class SlowPointFinder(base.DSAnalyzer):
return loss

def batch_train(start_i, n_batch):
f = bm.make_loop(train, dyn_vars=dyn_vars, has_return=True)
return f(bm.arange(start_i, start_i + n_batch))
return bm.for_loop(train, dyn_vars, bm.arange(start_i, start_i + n_batch))

# Run the optimization
if self.verbose:
@@ -369,7 +368,7 @@ class SlowPointFinder(base.DSAnalyzer):
break
batch_idx_start = oidx * num_batch
start_time = time.time()
(_, train_losses) = batch_train(start_i=batch_idx_start, n_batch=num_batch)
train_losses = batch_train(start_i=batch_idx_start, n_batch=num_batch)
batch_time = time.time() - start_time
opt_losses.append(train_losses)

@@ -722,8 +721,6 @@ class SlowPointFinder(base.DSAnalyzer):
shared = DotDict(t=t, dt=dt, i=0)

def f_cell(h: Dict):
target.clear_input()

# update target variables
for k, v in self.target_vars.items():
v.value = (bm.asarray(h[k], dtype=v.dtype)
@@ -735,6 +732,7 @@ class SlowPointFinder(base.DSAnalyzer):
v.value = self.excluded_data[k]

# add inputs
target.clear_input()
if f_input is not None:
f_input(shared)

@@ -743,7 +741,7 @@ class SlowPointFinder(base.DSAnalyzer):
target.update(*args)

# get new states
new_h = {k: (v.value if v.batch_axis is None else jnp.squeeze(v.value, axis=v.batch_axis))
new_h = {k: (v.value if (v.batch_axis is None) else jnp.squeeze(v.value, axis=v.batch_axis))
for k, v in self.target_vars.items()}
return new_h



+ 8
- 1
brainpy/analysis/lowdim/lowdim_analyzer.py View File

@@ -96,6 +96,9 @@ class LowDimAnalyzer(DSAnalyzer):
for key in self.target_vars.keys():
if key not in self.model.variables:
raise errors.AnalyzerError(f'{key} is not a dynamical variable in {self.model}.')
value = self.target_vars[key]
if value[0] > value[1]:
raise errors.AnalyzerError(f'The range of variable {key} is reversed: the lower bound {value[0]} must be smaller than the upper bound {value[1]}.')

# fixed variables
# ----------------
@@ -134,9 +137,13 @@ class LowDimAnalyzer(DSAnalyzer):
target_pars = dict()
if not isinstance(target_pars, dict):
raise errors.AnalyzerError('"target_pars" must be a dict with the format of {"par1": (val1, val2)}.')
for key in target_pars.keys():
for key, value in target_pars.items():
if key not in self.model.parameters:
raise errors.AnalyzerError(f'"{key}" is not a valid parameter in "{self.model}" model.')
if value[0] > value[1]:
raise errors.AnalyzerError(
f'The range of parameter {key} is reversed: the lower bound {value[0]} must be smaller than the upper bound {value[1]}.')

self.target_pars = Collector(target_pars)
self.target_par_names = list(self.target_pars.keys()) # list of target_pars
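For illustration, a sketch of the behavior these new range checks enforce; the ODE below is a toy model, and `PhasePlane1D` is used only because it routes `target_vars` through the `LowDimAnalyzer` checks shown above:

```python
import brainpy as bp

@bp.odeint
def int_x(x, t, a=1.):
    return bp.math.sin(a * x)

try:
    # The lower bound is larger than the upper bound, so the new check raises.
    bp.analysis.PhasePlane1D(
        model=int_x,
        target_vars={'x': [bp.math.pi, -bp.math.pi]},
        resolutions=0.1,
    )
except bp.errors.AnalyzerError as err:
    print(err)
```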



+ 53
- 17
brainpy/analysis/lowdim/lowdim_bifurcation.py View File

@@ -5,10 +5,11 @@ from functools import partial
import jax.numpy as jnp
from jax import vmap
import numpy as np
from copy import deepcopy

import brainpy.math as bm
from brainpy import errors
from brainpy.analysis import stability, utils, constants as C
from brainpy.analysis import stability, plotstyle, utils, constants as C
from brainpy.analysis.lowdim.lowdim_analyzer import *

pyplot = None
@@ -79,8 +80,8 @@ class Bifurcation1D(Num1DAnalyzer):
pyplot.figure(self.x_var)
for fp_type, points in container.items():
if len(points['x']):
plot_style = stability.plot_scheme[fp_type]
pyplot.plot(points['p'], points['x'], '.', **plot_style, label=fp_type)
plot_style = deepcopy(plotstyle.plot_schema[fp_type])
pyplot.plot(points['p'], points['x'], **plot_style, label=fp_type)
pyplot.xlabel(self.target_par_names[0])
pyplot.ylabel(self.x_var)

@@ -107,10 +108,12 @@ class Bifurcation1D(Num1DAnalyzer):
ax = fig.add_subplot(projection='3d')
for fp_type, points in container.items():
if len(points['x']):
plot_style = stability.plot_scheme[fp_type]
plot_style = deepcopy(plotstyle.plot_schema[fp_type])
xs = points['p0']
ys = points['p1']
zs = points['x']
plot_style.pop('linestyle')
plot_style['s'] = plot_style.pop('markersize', None)
ax.scatter(xs, ys, zs, **plot_style, label=fp_type)

ax.set_xlabel(self.target_par_names[0])
@@ -298,8 +301,8 @@ class Bifurcation2D(Num2DAnalyzer):
pyplot.figure(var)
for fp_type, points in container.items():
if len(points['p']):
plot_style = stability.plot_scheme[fp_type]
pyplot.plot(points['p'], points[var], '.', **plot_style, label=fp_type)
plot_style = deepcopy(plotstyle.plot_schema[fp_type])
pyplot.plot(points['p'], points[var], **plot_style, label=fp_type)
pyplot.xlabel(self.target_par_names[0])
pyplot.ylabel(var)

@@ -330,10 +333,12 @@ class Bifurcation2D(Num2DAnalyzer):
ax = fig.add_subplot(projection='3d')
for fp_type, points in container.items():
if len(points['p0']):
plot_style = stability.plot_scheme[fp_type]
plot_style = deepcopy(plotstyle.plot_schema[fp_type])
xs = points['p0']
ys = points['p1']
zs = points[var]
plot_style.pop('linestyle')
plot_style['s'] = plot_style.pop('markersize', None)
ax.scatter(xs, ys, zs, **plot_style, label=fp_type)

ax.set_xlabel(self.target_par_names[0])
@@ -354,8 +359,17 @@ class Bifurcation2D(Num2DAnalyzer):
if with_return:
return final_fps, final_pars, jacobians

def plot_limit_cycle_by_sim(self, duration=100, with_plot=True, with_return=False,
plot_style=None, tol=0.001, show=False, dt=None, offset=1.):
def plot_limit_cycle_by_sim(
self,
duration=100,
with_plot: bool = True,
with_return: bool = False,
plot_style: dict = None,
tol: float = 0.001,
show: bool = False,
dt: float = None,
offset: float = 1.
):
global pyplot
if pyplot is None: from matplotlib import pyplot
utils.output('I am plotting the limit cycle ...')
@@ -400,10 +414,16 @@ class Bifurcation2D(Num2DAnalyzer):
if len(ps_limit_cycle[0]):
for i, var in enumerate(self.target_var_names):
pyplot.figure(var)
pyplot.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['max'],
**plot_style, label='limit cycle (max)')
pyplot.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['min'],
**plot_style, label='limit cycle (min)')
pyplot.plot(ps_limit_cycle[0],
ps_limit_cycle[1],
vs_limit_cycle[i]['max'],
**plot_style,
label='limit cycle (max)')
pyplot.plot(ps_limit_cycle[0],
ps_limit_cycle[1],
vs_limit_cycle[i]['min'],
**plot_style,
label='limit cycle (min)')
pyplot.legend()

elif len(self.target_par_names) == 1:
@@ -427,8 +447,16 @@ class Bifurcation2D(Num2DAnalyzer):


class FastSlow1D(Bifurcation1D):
def __init__(self, model, fast_vars, slow_vars, fixed_vars=None,
pars_update=None, resolutions=None, options=None):
def __init__(
self,
model,
fast_vars: dict,
slow_vars: dict,
fixed_vars: dict = None,
pars_update: dict = None,
resolutions=None,
options: dict = None
):
super(FastSlow1D, self).__init__(model=model,
target_pars=slow_vars,
target_vars=fast_vars,
@@ -510,8 +538,16 @@ class FastSlow1D(Bifurcation1D):


class FastSlow2D(Bifurcation2D):
def __init__(self, model, fast_vars, slow_vars, fixed_vars=None,
pars_update=None, resolutions=0.1, options=None):
def __init__(
self,
model,
fast_vars: dict,
slow_vars: dict,
fixed_vars: dict = None,
pars_update: dict = None,
resolutions=0.1,
options: dict = None
):
super(FastSlow2D, self).__init__(model=model,
target_pars=slow_vars,
target_vars=fast_vars,


+ 12
- 11
brainpy/analysis/lowdim/lowdim_phase_plane.py View File

@@ -4,9 +4,10 @@ import jax.numpy as jnp
import numpy as np
from jax import vmap

from copy import deepcopy
import brainpy.math as bm
from brainpy import errors, math
from brainpy.analysis import stability, constants as C, utils
from brainpy.analysis import stability, plotstyle, constants as C, utils
from brainpy.analysis.lowdim.lowdim_analyzer import *

pyplot = None
@@ -107,8 +108,8 @@ class PhasePlane1D(Num1DAnalyzer):
if with_plot:
for fp_type, points in container.items():
if len(points):
plot_style = stability.plot_scheme[fp_type]
pyplot.plot(points, [0] * len(points), '.', markersize=20, **plot_style, label=fp_type)
plot_style = deepcopy(plotstyle.plot_schema[fp_type])
pyplot.plot(points, [0] * len(points), **plot_style, label=fp_type)
pyplot.legend()
if show:
pyplot.show()
@@ -248,9 +249,9 @@ class PhasePlane2D(Num2DAnalyzer):

if with_plot:
if x_style is None:
x_style = dict(color='cornflowerblue', alpha=.7, )
fmt = x_style.pop('fmt', '.')
pyplot.plot(x_values_in_fx, y_values_in_fx, fmt, **x_style, label=f"{self.x_var} nullcline")
x_style = dict(color='cornflowerblue', alpha=.7, fmt='.')
line_args = (x_style.pop('fmt'), ) if 'fmt' in x_style else tuple()
pyplot.plot(x_values_in_fx, y_values_in_fx, *line_args, **x_style, label=f"{self.x_var} nullcline")

# Nullcline of the y variable
utils.output('I am computing fy-nullcline ...')
@@ -260,9 +261,9 @@ class PhasePlane2D(Num2DAnalyzer):

if with_plot:
if y_style is None:
y_style = dict(color='lightcoral', alpha=.7, )
fmt = y_style.pop('fmt', '.')
pyplot.plot(x_values_in_fy, y_values_in_fy, fmt, **y_style, label=f"{self.y_var} nullcline")
y_style = dict(color='lightcoral', alpha=.7, fmt='.')
line_args = (y_style.pop('fmt'), ) if 'fmt' in y_style else tuple()
pyplot.plot(x_values_in_fy, y_values_in_fy, *line_args, **y_style, label=f"{self.y_var} nullcline")

if with_plot:
pyplot.xlabel(self.x_var)
@@ -349,8 +350,8 @@ class PhasePlane2D(Num2DAnalyzer):
if with_plot:
for fp_type, points in container.items():
if len(points['x']):
plot_style = stability.plot_scheme[fp_type]
pyplot.plot(points['x'], points['y'], '.', markersize=20, **plot_style, label=fp_type)
plot_style = deepcopy(plotstyle.plot_schema[fp_type])
pyplot.plot(points['x'], points['y'], **plot_style, label=fp_type)
pyplot.legend()
if show:
pyplot.show()


+ 89
- 0
brainpy/analysis/lowdim/tests/test_bifurcation.py View File

@@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-


import pytest
pytest.skip('Test cannot pass in github action.', allow_module_level=True)
import unittest

import brainpy as bp
import brainpy.math as bm
import matplotlib.pyplot as plt

block = False


class FitzHughNagumoModel(bp.dyn.DynamicalSystem):
def __init__(self, method='exp_auto'):
super(FitzHughNagumoModel, self).__init__()

# parameters
self.a = 0.7
self.b = 0.8
self.tau = 12.5

# variables
self.V = bm.Variable(bm.zeros(1))
self.w = bm.Variable(bm.zeros(1))
self.Iext = bm.Variable(bm.zeros(1))

# functions
def dV(V, t, w, Iext=0.):
dV = V - V * V * V / 3 - w + Iext
return dV

def dw(w, t, V, a=0.7, b=0.8):
dw = (V + a - b * w) / self.tau
return dw

self.int_V = bp.odeint(dV, method=method)
self.int_w = bp.odeint(dw, method=method)

def update(self, tdi):
t, dt = tdi['t'], tdi['dt']
self.V.value = self.int_V(self.V, t, self.w, self.Iext, dt)
self.w.value = self.int_w(self.w, t, self.V, self.a, self.b, dt)
self.Iext[:] = 0.


class TestBifurcation1D(unittest.TestCase):
def test_bifurcation_1d(self):
bp.math.enable_x64()

@bp.odeint
def int_x(x, t, a=1., b=1.):
return bp.math.sin(a * x) + bp.math.cos(b * x)

pp = bp.analysis.PhasePlane1D(
model=int_x,
target_vars={'x': [-bp.math.pi, bp.math.pi]},
resolutions=0.1
)
pp.plot_vector_field()
pp.plot_fixed_point(show=True)

bf = bp.analysis.Bifurcation1D(
model=int_x,
target_vars={'x': [-bp.math.pi, bp.math.pi]},
target_pars={'a': [0.5, 1.5], 'b': [0.5, 1.5]},
resolutions={'a': 0.1, 'b': 0.1}
)
bf.plot_bifurcation(show=False)
plt.show(block=block)
plt.close()
bp.math.disable_x64()

def test_bifurcation_2d(self):
bp.math.enable_x64()

model = FitzHughNagumoModel()
bif = bp.analysis.Bifurcation2D(
model=model,
target_vars={'V': [-3., 3.], 'w': [-1, 3.]},
target_pars={'Iext': [0., 1.]},
resolutions={'Iext': 0.1}
)
bif.plot_bifurcation()
bif.plot_limit_cycle_by_sim()
plt.show(block=block)

# bp.math.disable_x64()

+ 1
- 3
brainpy/analysis/lowdim/tests/test_phase_plane.py View File

@@ -3,13 +3,13 @@
import unittest

import brainpy as bp
import matplotlib.pyplot as plt

block = False


class TestPhasePlane(unittest.TestCase):
def test_1d(self):
import matplotlib.pyplot as plt
bp.math.enable_x64()

@bp.odeint
@@ -30,8 +30,6 @@ class TestPhasePlane(unittest.TestCase):
bp.math.disable_x64()

def test_2d_decision_making_model(self):
import matplotlib.pyplot as plt

bp.math.enable_x64()
gamma = 0.641 # Saturation factor for gating variable
tau = 0.06 # Synaptic time constant [sec]


+ 72
- 0
brainpy/analysis/plotstyle.py View File

@@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-


__all__ = [
'plot_schema',
'set_plot_schema',
]

from .stability import (CENTER_MANIFOLD, SADDLE_NODE, STABLE_POINT_1D,
UNSTABLE_POINT_1D, CENTER_2D, STABLE_NODE_2D,
STABLE_FOCUS_2D, STABLE_STAR_2D, STABLE_DEGENERATE_2D,
UNSTABLE_NODE_2D, UNSTABLE_FOCUS_2D, UNSTABLE_STAR_2D,
UNSTABLE_DEGENERATE_2D, UNSTABLE_LINE_2D,
STABLE_POINT_3D, UNSTABLE_POINT_3D, STABLE_NODE_3D,
UNSTABLE_SADDLE_3D, UNSTABLE_NODE_3D, STABLE_FOCUS_3D,
UNSTABLE_FOCUS_3D, UNSTABLE_CENTER_3D, UNKNOWN_3D)


_markersize = 10

plot_schema = {}

plot_schema[CENTER_MANIFOLD] = {'color': 'orangered', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}
plot_schema[SADDLE_NODE] = {"color": 'tab:blue', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}

plot_schema[STABLE_POINT_1D] = {"color": 'tab:red', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}
plot_schema[UNSTABLE_POINT_1D] = {"color": 'tab:olive', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}

plot_schema.update({
CENTER_2D: {'color': 'lime', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
STABLE_NODE_2D: {"color": 'tab:red', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
STABLE_FOCUS_2D: {"color": 'tab:purple', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
STABLE_STAR_2D: {'color': 'tab:olive', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
STABLE_DEGENERATE_2D: {'color': 'blueviolet', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_NODE_2D: {"color": 'tab:orange', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_FOCUS_2D: {"color": 'tab:cyan', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_STAR_2D: {'color': 'green', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_DEGENERATE_2D: {'color': 'springgreen', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_LINE_2D: {'color': 'dodgerblue', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
})


plot_schema.update({
STABLE_POINT_3D: {'color': 'tab:gray', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_POINT_3D: {'color': 'tab:purple', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
STABLE_NODE_3D: {'color': 'tab:green', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_SADDLE_3D: {'color': 'tab:red', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_FOCUS_3D: {'color': 'tab:pink', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
STABLE_FOCUS_3D: {'color': 'tab:purple', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_NODE_3D: {'color': 'tab:orange', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_CENTER_3D: {'color': 'tab:olive', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNKNOWN_3D: {'color': 'tab:cyan', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
})


def set_plot_schema(fixed_point: str, **schema):
if not isinstance(fixed_point, str):
raise TypeError(f'Must be an instance of str, but we got {type(fixed_point)}: {fixed_point}')
if fixed_point not in plot_schema:
raise KeyError(f'Fixed point type {fixed_point} is not found in the built-in types.')
plot_schema[fixed_point].update(**schema)


def set_markersize(markersize):
if not isinstance(markersize, int):
raise TypeError(f"Must be an integer, but got {type(markersize)}: {markersize}")
global _markersize
_markersize = markersize
for key in tuple(plot_schema.keys()):
plot_schema[key]['markersize'] = markersize
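A minimal sketch of how the new plot schema could be customized (the color and marker values here are arbitrary examples):

```python
from brainpy.analysis import plotstyle, stability

# Override the style used when plotting 1D stable fixed points.
plotstyle.set_plot_schema(stability.STABLE_POINT_1D, color='black', marker='o')

# Enlarge the marker size for every fixed-point type at once.
plotstyle.set_markersize(15)
```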



+ 3
- 29
brainpy/analysis/stability.py View File

@@ -6,7 +6,7 @@ __all__ = [
'get_1d_stability_types',
'get_2d_stability_types',
'get_3d_stability_types',
'plot_scheme',

'stability_analysis',

@@ -27,17 +27,13 @@ __all__ = [
'UNSTABLE_LINE_2D',
]

plot_scheme = {}

SADDLE_NODE = 'saddle node'
CENTER_MANIFOLD = 'center manifold'
plot_scheme[CENTER_MANIFOLD] = {'color': 'orangered'}
plot_scheme[SADDLE_NODE] = {"color": 'tab:blue'}

STABLE_POINT_1D = 'stable point'
UNSTABLE_POINT_1D = 'unstable point'
plot_scheme[STABLE_POINT_1D] = {"color": 'tab:red'}
plot_scheme[UNSTABLE_POINT_1D] = {"color": 'tab:olive'}

CENTER_2D = 'center'
STABLE_NODE_2D = 'stable node'
@@ -49,18 +45,7 @@ UNSTABLE_FOCUS_2D = 'unstable focus'
UNSTABLE_STAR_2D = 'unstable star'
UNSTABLE_DEGENERATE_2D = 'unstable degenerate'
UNSTABLE_LINE_2D = 'unstable line'
plot_scheme.update({
CENTER_2D: {'color': 'lime'},
STABLE_NODE_2D: {"color": 'tab:red'},
STABLE_FOCUS_2D: {"color": 'tab:purple'},
STABLE_STAR_2D: {'color': 'tab:olive'},
STABLE_DEGENERATE_2D: {'color': 'blueviolet'},
UNSTABLE_NODE_2D: {"color": 'tab:orange'},
UNSTABLE_FOCUS_2D: {"color": 'tab:cyan'},
UNSTABLE_STAR_2D: {'color': 'green'},
UNSTABLE_DEGENERATE_2D: {'color': 'springgreen'},
UNSTABLE_LINE_2D: {'color': 'dodgerblue'},
})


STABLE_POINT_3D = 'unclassified stable point'
UNSTABLE_POINT_3D = 'unclassified unstable point'
@@ -71,17 +56,6 @@ STABLE_FOCUS_3D = 'stable focus'
UNSTABLE_FOCUS_3D = 'unstable focus'
UNSTABLE_CENTER_3D = 'unstable center'
UNKNOWN_3D = 'unknown 3d'
plot_scheme.update({
STABLE_POINT_3D: {'color': 'tab:gray'},
UNSTABLE_POINT_3D: {'color': 'tab:purple'},
STABLE_NODE_3D: {'color': 'tab:green'},
UNSTABLE_SADDLE_3D: {'color': 'tab:red'},
UNSTABLE_FOCUS_3D: {'color': 'tab:pink'},
STABLE_FOCUS_3D: {'color': 'tab:purple'},
UNSTABLE_NODE_3D: {'color': 'tab:orange'},
UNSTABLE_CENTER_3D: {'color': 'tab:olive'},
UNKNOWN_3D: {'color': 'tab:cyan'},
})


def get_1d_stability_types():


+ 4
- 3
brainpy/analysis/utils/model.py View File

@@ -112,14 +112,14 @@ class TrajectModel(DynamicalSystem):

# variables
assert isinstance(initial_vars, dict)
initial_vars = {k: bm.Variable(jnp.asarray(bm.as_device_array(v), dtype=jnp.float_))
initial_vars = {k: bm.Variable(jnp.asarray(bm.as_device_array(v), dtype=bm.dftype()))
for k, v in initial_vars.items()}
self.register_implicit_vars(initial_vars)

# parameters
pars = dict() if pars is None else pars
assert isinstance(pars, dict)
self.pars = [jnp.asarray(bm.as_device_array(v), dtype=jnp.float_)
self.pars = [jnp.asarray(bm.as_device_array(v), dtype=bm.dftype())
for k, v in pars.items()]

# integrals
@@ -128,7 +128,8 @@ class TrajectModel(DynamicalSystem):
# runner
self.runner = DSRunner(self,
monitors=list(initial_vars.keys()),
dyn_vars=self.vars().unique(), dt=dt,
dyn_vars=self.vars().unique(),
dt=dt,
progress_bar=False)

def update(self, sha):


+ 54
- 0
brainpy/check.py View File

@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-


from brainpy.errors import PackageMissingError

__all__ = [
'is_checking',
'turn_on',
@@ -9,6 +11,8 @@ __all__ = [

_check = True

_BRAINPYLIB_MINIMAL_VERSION = '0.1.2'


def is_checking():
"""Whether the checking is turn on."""
@@ -25,3 +29,53 @@ def turn_off():
"""Turn off the checking."""
global _check
_check = False


try:
import jaxlib
del jaxlib
except ModuleNotFoundError:
raise ModuleNotFoundError(
'''

BrainPy needs jaxlib, please install it.

1. If you are using Windows system, install jaxlib through

>>> pip install jaxlib -f https://whls.blob.core.windows.net/unstable/index.html

2. If you are using macOS platform, install jaxlib through

>>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.html

3. If you are using Linux platform, install jaxlib through

>>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.html

4. If you are using Linux + CUDA platform, install jaxlib through

>>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Note that the versions of "jax" and "jaxlib" should be consistent, like "jax=0.3.14" and "jaxlib=0.3.14".

For more detailed installation instructions, please see https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#dependency-2-jax

''') from None

try:
import brainpylib

if brainpylib.__version__ < _BRAINPYLIB_MINIMAL_VERSION:
raise PackageMissingError(
f'BrainPy needs "brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}". \n'
f'Please install it through:\n\n'
f'>>> pip install brainpylib -U'
)

del brainpylib
except ModuleNotFoundError:
raise PackageMissingError(
f'BrainPy needs "brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}". \n'
f'Please install "brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}" through:\n\n'
f'>>> pip install brainpylib'
)

+ 357
- 105
brainpy/connect/base.py View File

@@ -3,7 +3,8 @@
import abc
from typing import Union, List, Tuple

import numpy as np
import jax.numpy as jnp
import numpy as onp

from brainpy import tools, math as bm
from brainpy.errors import ConnectorError
@@ -23,7 +24,9 @@ __all__ = [
'Connector', 'TwoEndConnector', 'OneEndConnector',

# methods
'csr2csc', 'csr2mat', 'mat2csr', 'ij2csr'
'mat2coo', 'mat2csc', 'mat2csr',
'csr2csc', 'csr2mat', 'csr2coo',
'coo2csr', 'coo2csc', 'coo2mat',
]

CONN_MAT = 'conn_mat'
@@ -35,15 +38,19 @@ PRE2SYN = 'pre2syn'
POST2SYN = 'post2syn'
PRE_SLICE = 'pre_slice'
POST_SLICE = 'post_slice'
COO = 'coo'
CSR = 'csr'
CSC = 'csc'

SUPPORTED_SYN_STRUCTURE = [CONN_MAT,
PRE_IDS, POST_IDS,
PRE2POST, POST2PRE,
PRE2SYN, POST2SYN,
PRE_SLICE, POST_SLICE]
PRE_SLICE, POST_SLICE,
COO, CSR, CSC]

MAT_DTYPE = np.bool_
IDX_DTYPE = np.uint32
MAT_DTYPE = jnp.bool_
IDX_DTYPE = jnp.uint32


def set_default_dtype(mat_dtype=None, idx_dtype=None):
@@ -92,7 +99,39 @@ class Connector(abc.ABC):


class TwoEndConnector(Connector):
"""Synaptic connector to build synapse connections between two neuron groups."""
"""Synaptic connector to build connections between two neuron groups.

If users want to customize their `Connector`, there are two ways:

1. Implementing ``build_conn(self)`` function, which returns one of
the connection data ``csr`` (CSR sparse data, a tuple of <post_ids, inptr>),
``coo`` (COO sparse data, a tuple of <pre_ids, post_ids>), or ``mat``
(a binary connection matrix). For instance,

.. code-block:: python

import brainpy as bp
class MyConnector(bp.conn.TwoEndConnector):
def build_conn(self):
return dict(csr=(indices, indptr), mat=None, coo=None)

2. Implementing functions ``build_mat()``, ``build_csr()``, and
``build_coo()``. Users can provide all three functions, or one of them.

.. code-block:: python

import brainpy as bp
class MyConnector(bp.conn.TwoEndConnector):
def build_mat(self, ):
return conn_matrix

def build_csr(self, ):
return post_ids, inptr

def build_coo(self, ):
return pre_ids, post_ids

"""

def __init__(self, ):
self.pre_size = None
@@ -100,6 +139,9 @@ class TwoEndConnector(Connector):
self.pre_num = None
self.post_num = None

def __repr__(self):
return self.__class__.__name__

def __call__(self, pre_size, post_size):
"""Create the concrete connections between two end objects.

@@ -140,15 +182,16 @@ class TwoEndConnector(Connector):
"""
self.__call__(pre_size, post_size)

def check(self, structures: Union[Tuple, List, str]):
# check "pre_num" and "post_num"
try:
assert self.pre_num is not None and self.post_num is not None
except AssertionError:
raise ConnectorError(f'self.pre_num or self.post_num is not defined. '
f'Please use self.__call__(pre_size, post_size) '
f'before requiring properties.')
@property
def is_version2_style(self):
if ((hasattr(self.build_coo, 'not_customized') and self.build_coo.not_customized) and
(hasattr(self.build_csr, 'not_customized') and self.build_csr.not_customized) and
(hasattr(self.build_mat, 'not_customized') and self.build_mat.not_customized)):
return False
else:
return True

def _check(self, structures: Union[Tuple, List, str]):
# check synaptic structures
if isinstance(structures, str):
structures = [structures]
@@ -160,21 +203,18 @@ class TwoEndConnector(Connector):
f'Only {SUPPORTED_SYN_STRUCTURE} is supported.')

def _return_by_mat(self, structures, mat, all_data: dict):
assert isinstance(mat, np.ndarray) and np.ndim(mat) == 2
assert mat.ndim == 2
if (CONN_MAT in structures) and (CONN_MAT not in all_data):
all_data[CONN_MAT] = bm.asarray(mat, dtype=MAT_DTYPE)

require_other_structs = len([s for s in structures if s != CONN_MAT]) > 0
if require_other_structs:
pre_ids, post_ids = np.where(mat > 0)
pre_ids = np.ascontiguousarray(pre_ids, dtype=IDX_DTYPE)
post_ids = np.ascontiguousarray(post_ids, dtype=IDX_DTYPE)
self._return_by_ij(structures, ij=(pre_ids, post_ids), all_data=all_data)
if len([s for s in structures
if s not in [CONN_MAT]]) > 0:
ij = mat2coo(mat)
self._return_by_coo(structures, coo=ij, all_data=all_data)

def _return_by_csr(self, structures, csr: tuple, all_data: dict):
indices, indptr = csr
assert isinstance(indices, np.ndarray)
assert isinstance(indptr, np.ndarray)
np = onp if isinstance(indices, onp.ndarray) else bm
assert self.pre_num == indptr.size - 1

if (CONN_MAT in structures) and (CONN_MAT not in all_data):
@@ -188,15 +228,29 @@ class TwoEndConnector(Connector):
if (POST_IDS in structures) and (POST_IDS not in all_data):
all_data[POST_IDS] = bm.asarray(indices, dtype=IDX_DTYPE)

if (COO in structures) and (COO not in all_data):
pre_ids = np.repeat(np.arange(self.pre_num), np.diff(indptr))
all_data[COO] = (bm.asarray(pre_ids, dtype=IDX_DTYPE),
bm.asarray(indices, dtype=IDX_DTYPE))

if (PRE2POST in structures) and (PRE2POST not in all_data):
all_data[PRE2POST] = (bm.asarray(indices, dtype=IDX_DTYPE),
bm.asarray(indptr, dtype=IDX_DTYPE))

if (CSR in structures) and (CSR not in all_data):
all_data[CSR] = (bm.asarray(indices, dtype=IDX_DTYPE),
bm.asarray(indptr, dtype=IDX_DTYPE))

if (POST2PRE in structures) and (POST2PRE not in all_data):
indc, indptrc = csr2csc((indices, indptr), self.post_num)
all_data[POST2PRE] = (bm.asarray(indc, dtype=IDX_DTYPE),
bm.asarray(indptrc, dtype=IDX_DTYPE))

if (CSC in structures) and (CSC not in all_data):
indc, indptrc = csr2csc((indices, indptr), self.post_num)
all_data[CSC] = (bm.asarray(indc, dtype=IDX_DTYPE),
bm.asarray(indptrc, dtype=IDX_DTYPE))

if (PRE2SYN in structures) and (PRE2SYN not in all_data):
syn_seq = np.arange(indices.size, dtype=IDX_DTYPE)
all_data[PRE2SYN] = (bm.asarray(syn_seq, dtype=IDX_DTYPE),
@@ -208,13 +262,11 @@ class TwoEndConnector(Connector):
all_data[POST2SYN] = (bm.asarray(syn_seqc, dtype=IDX_DTYPE),
bm.asarray(indptrc, dtype=IDX_DTYPE))

def _return_by_ij(self, structures, ij: tuple, all_data: dict):
pre_ids, post_ids = ij
assert isinstance(pre_ids, np.ndarray)
assert isinstance(post_ids, np.ndarray)
def _return_by_coo(self, structures, coo: tuple, all_data: dict):
pre_ids, post_ids = coo

if (CONN_MAT in structures) and (CONN_MAT not in all_data):
all_data[CONN_MAT] = bm.asarray(ij2mat(ij, self.pre_num, self.post_num), dtype=MAT_DTYPE)
all_data[CONN_MAT] = bm.asarray(coo2mat(coo, self.pre_num, self.post_num), dtype=MAT_DTYPE)

if (PRE_IDS in structures) and (PRE_IDS not in all_data):
all_data[PRE_IDS] = bm.asarray(pre_ids, dtype=IDX_DTYPE)
@@ -222,59 +274,76 @@ class TwoEndConnector(Connector):
if (POST_IDS in structures) and (POST_IDS not in all_data):
all_data[POST_IDS] = bm.asarray(post_ids, dtype=IDX_DTYPE)

require_other_structs = len([s for s in structures
if s not in [CONN_MAT, PRE_IDS, POST_IDS]]) > 0
if require_other_structs:
csr = ij2csr(pre_ids, post_ids, self.pre_num)
if (COO in structures) and (COO not in all_data):
all_data[COO] = (bm.asarray(pre_ids, dtype=IDX_DTYPE),
bm.asarray(post_ids, dtype=IDX_DTYPE))

if CSC in structures and CSC not in all_data:
csc = coo2csc(coo, self.post_num)
all_data[CSC] = (bm.asarray(csc[0], dtype=IDX_DTYPE),
bm.asarray(csc[1], dtype=IDX_DTYPE))

if POST2PRE in structures and POST2PRE not in all_data:
csc = coo2csc(coo, self.post_num)
all_data[POST2PRE] = (bm.asarray(csc[0], dtype=IDX_DTYPE),
bm.asarray(csc[1], dtype=IDX_DTYPE))

if (len([s for s in structures
if s not in [CONN_MAT, PRE_IDS, POST_IDS,
COO, CSC, POST2PRE]]) > 0):
csr = coo2csr(coo, self.pre_num)
self._return_by_csr(structures, csr=csr, all_data=all_data)

def make_returns(self, structures, conn_data, csr=None, mat=None, ij=None):
def _make_returns(self, structures, conn_data):
"""Make the desired synaptic structures and return them.
"""
csr = None
mat = None
coo = None
if isinstance(conn_data, dict):
csr = conn_data['csr']
mat = conn_data['mat']
ij = conn_data['ij']
csr = conn_data.get('csr', None)
mat = conn_data.get('mat', None)
coo = conn_data.get('coo', None)
elif isinstance(conn_data, tuple):
if conn_data[0] == 'csr':
csr = conn_data[1]
elif conn_data[0] == 'mat':
mat = conn_data[1]
elif conn_data[0] == 'ij':
ij = conn_data[1]
elif conn_data[0] == 'coo':
coo = conn_data[1]
else:
raise ConnectorError(f'Must provide one of "csr", "mat" or "ij". Got "{conn_data[0]}" instead.')
raise ConnectorError(f'Must provide one of "csr", "mat" or "coo". Got "{conn_data[0]}" instead.')
else:
raise ConnectorError('Unknown type')

# checking
all_data = dict()
if (csr is None) and (mat is None) and (ij is None):
raise ConnectorError('Must provide one of "csr", "mat" or "ij".')
if (csr is None) and (mat is None) and (coo is None):
raise ConnectorError('Must provide one of "csr", "mat" or "coo".')
structures = (structures,) if isinstance(structures, str) else structures
assert isinstance(structures, (tuple, list))

all_data = dict()
# "csr" structure
if csr is not None:
assert isinstance(csr[0], np.ndarray)
assert isinstance(csr[1], np.ndarray)
if (PRE2POST in structures) and (PRE2POST not in all_data):
all_data[PRE2POST] = (bm.asarray(csr[0], dtype=IDX_DTYPE),
bm.asarray(csr[1], dtype=IDX_DTYPE))
self._return_by_csr(structures, csr=csr, all_data=all_data)

# "mat" structure
if mat is not None:
assert isinstance(mat, np.ndarray) and np.ndim(mat) == 2
assert mat.ndim == 2
if (CONN_MAT in structures) and (CONN_MAT not in all_data):
all_data[CONN_MAT] = bm.asarray(mat, dtype=MAT_DTYPE)
self._return_by_mat(structures, mat=mat, all_data=all_data)
# "ij" structure
if ij is not None:
assert isinstance(ij[0], np.ndarray)
assert isinstance(ij[1], np.ndarray)

# "coo" structure
if coo is not None:
if (PRE_IDS in structures) and (PRE_IDS not in all_data):
all_data[PRE_IDS] = bm.asarray(ij[0], dtype=IDX_DTYPE)
all_data[PRE_IDS] = bm.asarray(coo[0], dtype=IDX_DTYPE)
if (POST_IDS in structures) and (POST_IDS not in all_data):
all_data[POST_IDS] = bm.asarray(ij[1], dtype=IDX_DTYPE)
self._return_by_ij(structures, ij=ij, all_data=all_data)
all_data[POST_IDS] = bm.asarray(coo[1], dtype=IDX_DTYPE)
self._return_by_coo(structures, coo=coo, all_data=all_data)

# return
if len(structures) == 1:
@@ -282,25 +351,191 @@ class TwoEndConnector(Connector):
else:
return tuple([all_data[n] for n in structures])

def require(self, *structures):
"""Require all the connection data needed.

Examples
--------

>>> import brainpy as bp
>>> conn = bp.connect.FixedProb(0.1)
>>> mat = conn.require(10, 20, 'conn_mat')
>>> mat.shape
(10, 20)
"""

if len(structures) > 0:
pre_size = None
post_size = None
if not isinstance(structures[0], str):
pre_size = structures[0]
structures = structures[1:]
if len(structures) > 0:
if not isinstance(structures[0], str):
post_size = structures[0]
structures = structures[1:]
if pre_size is not None:
self.__call__(pre_size, post_size)
else:
return tuple()

if self.pre_num is None or self.post_num is None:
raise ConnectorError(f'self.pre_num or self.post_num is not defined. '
f'Please use "self.require(pre_size, post_size, DATA1, DATA2, ...)" ')

_has_coo_imp = not hasattr(self.build_coo, 'not_customized')
_has_csr_imp = not hasattr(self.build_csr, 'not_customized')
_has_mat_imp = not hasattr(self.build_mat, 'not_customized')

self._check(structures)
if (_has_coo_imp or _has_csr_imp or _has_mat_imp):
if len(structures) == 1:
if PRE2POST in structures and _has_csr_imp:
r = self.build_csr()
return bm.asarray(r[0], dtype=IDX_DTYPE), bm.asarray(r[1], dtype=IDX_DTYPE)
elif CSR in structures and _has_csr_imp:
r = self.build_csr()
return bm.asarray(r[0], dtype=IDX_DTYPE), bm.asarray(r[1], dtype=IDX_DTYPE)
elif CONN_MAT in structures and _has_mat_imp:
return bm.asarray(self.build_mat(), dtype=MAT_DTYPE)
elif PRE_IDS in structures and _has_coo_imp:
return bm.asarray(self.build_coo()[0], dtype=IDX_DTYPE)
elif POST_IDS in structures and _has_coo_imp:
return bm.asarray(self.build_coo()[1], dtype=IDX_DTYPE)
elif COO in structures and _has_coo_imp:
return bm.asarray(self.build_coo(), dtype=IDX_DTYPE)

elif len(structures) == 2:
if (PRE_IDS in structures and POST_IDS in structures and _has_coo_imp):
r = self.build_coo()
if structures[0] == PRE_IDS:
return bm.asarray(r[0], dtype=IDX_DTYPE), bm.asarray(r[1], dtype=IDX_DTYPE)
else:
return bm.asarray(r[1], dtype=IDX_DTYPE), bm.asarray(r[0], dtype=IDX_DTYPE)

if ((CSR in structures or PRE2POST in structures)
and _has_csr_imp and COO in structures and _has_coo_imp):
csr = self.build_csr()
csr = (bm.asarray(csr[0], dtype=IDX_DTYPE), bm.asarray(csr[1], dtype=IDX_DTYPE))
coo = self.build_coo()
coo = (bm.asarray(coo[0], dtype=IDX_DTYPE), bm.asarray(coo[1], dtype=IDX_DTYPE))
if structures[0] == COO:
return coo, csr
else:
return csr, coo

if ((CSR in structures or PRE2POST in structures)
and _has_csr_imp and CONN_MAT in structures and _has_mat_imp):
csr = self.build_csr()
csr = (bm.asarray(csr[0], dtype=IDX_DTYPE), bm.asarray(csr[1], dtype=IDX_DTYPE))
mat = bm.asarray(self.build_mat(), dtype=MAT_DTYPE)
if structures[0] == CONN_MAT:
return mat, csr
else:
return csr, mat

if (COO in structures and _has_coo_imp and CONN_MAT in structures and _has_mat_imp):
coo = self.build_coo()
coo = (bm.asarray(coo[0], dtype=IDX_DTYPE), bm.asarray(coo[1], dtype=IDX_DTYPE))
mat = bm.asarray(self.build_mat(), dtype=MAT_DTYPE)
if structures[0] == COO:
return coo, mat
else:
return mat, coo

conn_data = dict(csr=None, coo=None, mat=None)
if _has_coo_imp:
conn_data['coo'] = self.build_coo()
# if (CSR in structures or PRE2POST in structures) and _has_csr_imp:
# conn_data['csr'] = self.build_csr()
# if CONN_MAT in structures and _has_mat_imp:
# conn_data['mat'] = self.build_mat()
elif _has_csr_imp:
conn_data['csr'] = self.build_csr()
# if COO in structures and _has_coo_imp:
# conn_data['coo'] = self.build_coo()
# if CONN_MAT in structures and _has_mat_imp:
# conn_data['mat'] = self.build_mat()
elif _has_mat_imp:
conn_data['mat'] = self.build_mat()
# if COO in structures and _has_coo_imp:
# conn_data['coo'] = self.build_coo()
# if (CSR in structures or PRE2POST in structures) and _has_csr_imp:
# conn_data['csr'] = self.build_csr()
else:
raise ValueError

else:
conn_data = self.build_conn()
return self._make_returns(structures, conn_data)

def requires(self, *structures):
"""Require all the connection data needed."""
return self.require(*structures)

@tools.not_customized
def build_conn(self):
"""build connections with certain data type.

If users want to customize their connections, please provide one
of the following functions:

- ``build_mat()``: build a binary connection matrix.
- ``build_csr()``: build a csr sparse connection data.
- ``build_coo()``: build a coo sparse connection data.
- ``build_conn()``: deprecated.

Returns
-------
A tuple with two elements: connection type (str) and connection data.
example: return 'csr', (ind, indptr)
Or a dict with three elements: csr, mat and ij.
example: return dict(csr=(ind, indptr), mat=None, ij=None)
conn: tuple, dict
A tuple with two elements: connection type (str) and connection data.
For example: ``return 'csr', (ind, indptr)``
Or a dict with three elements: csr, mat and coo. For example:
``return dict(csr=(ind, indptr), mat=None, coo=None)``
"""
raise NotImplementedError
pass

def require(self, *structures):
self.check(structures)
conn_data = self.build_conn()
return self.make_returns(structures, conn_data)
@tools.not_customized
def build_mat(self):
"""Build a binary matrix connection data.

def requires(self, *structures):
return self.require(*structures)

If users want to customize their connections, please provide one
of the following functions:

- ``build_mat()``: build a binary connection matrix.
- ``build_csr()``: build a csr sparse connection data.
- ``build_coo()``: build a coo sparse connection data.
- ``build_conn()``: deprecated.

Returns
-------
conn: Array
A binary matrix with the shape ``(num_pre, num_post)``.
"""
pass

@tools.not_customized
def build_csr(self):
"""Build a csr sparse connection data.

Returns
-------
conn: tuple
A tuple denoting the ``(indices, indptr)``.
"""
pass

@tools.not_customized
def build_coo(self):
"""Build a coo sparse connection data.

Returns
-------
conn: tuple
A tuple denoting the ``(pre_ids, post_ids)``.
"""
pass


class OneEndConnector(TwoEndConnector):
@@ -329,84 +564,101 @@ class OneEndConnector(TwoEndConnector):
else:
post_size = tuple(post_size)
self.pre_size, self.post_size = pre_size, post_size

self.pre_num = tools.size2num(self.pre_size)
self.post_num = tools.size2num(self.post_size)
return self

def _reset_conn(self, pre_size, post_size=None):
self.__init__()

self.__call__(pre_size, post_size)


def csr2csc(csr, post_num, data=None):
"""Convert csr to csc."""
indices, indptr = csr
pre_ids = np.repeat(np.arange(indptr.size - 1), np.diff(indptr))

sort_ids = np.argsort(indices, kind='mergesort') # to maintain the original order of the elements with the same value
pre_ids_new = np.asarray(pre_ids[sort_ids], dtype=IDX_DTYPE)

unique_post_ids, count = np.unique(indices, return_counts=True)
post_count = np.zeros(post_num, dtype=IDX_DTYPE)
post_count[unique_post_ids] = count

indptr_new = post_count.cumsum()
indptr_new = np.insert(indptr_new, 0, 0)
indptr_new = np.asarray(indptr_new, dtype=IDX_DTYPE)

if data is None:
return pre_ids_new, indptr_new
else:
data_new = data[sort_ids]
return pre_ids_new, indptr_new, data_new


def mat2csr(dense):
"""convert a dense matrix to (indices, indptr)."""
if isinstance(dense, bm.ndarray):
dense = np.asarray(dense)
np = onp if isinstance(dense, onp.ndarray) else bm
pre_ids, post_ids = np.where(dense > 0)
pre_num = dense.shape[0]
return coo2csr((pre_ids, post_ids), dense.shape[0])


def mat2coo(dense):
np = onp if isinstance(dense, onp.ndarray) else bm
pre_ids, post_ids = np.where(dense > 0)
return np.asarray(pre_ids, dtype=IDX_DTYPE), np.asarray(post_ids, dtype=IDX_DTYPE)

uni_idx, count = np.unique(pre_ids, return_counts=True)
pre_count = np.zeros(pre_num, dtype=IDX_DTYPE)
pre_count[uni_idx] = count
indptr = count.cumsum()
indptr = np.insert(indptr, 0, 0)

return np.asarray(post_ids, dtype=IDX_DTYPE), np.asarray(indptr, dtype=IDX_DTYPE)
def mat2csc(dense):
np = onp if isinstance(dense, onp.ndarray) else bm
pre_ids, post_ids = np.where(dense > 0)
return coo2csr((post_ids, pre_ids), dense.shape[1])


def csr2mat(csr, num_pre, num_post):
"""convert (indices, indptr) to a dense matrix."""
indices, indptr = csr
np = onp if isinstance(indices, onp.ndarray) else bm
d = np.zeros((num_pre, num_post), dtype=MAT_DTYPE) # num_pre, num_post
pre_ids = np.repeat(np.arange(indptr.size - 1), np.diff(indptr))
d[pre_ids, indices] = True
return d


def ij2mat(ij, num_pre, num_post):
def csr2csc(csr, post_num, data=None):
"""Convert csr to csc."""
return coo2csc(csr2coo(csr), post_num, data)


def csr2coo(csr):
np = onp if isinstance(csr[0], onp.ndarray) else bm
indices, indptr = csr
pre_ids = np.repeat(np.arange(indptr.size - 1), np.diff(indptr))
return pre_ids, indices


def coo2mat(ij, num_pre, num_post):
"""convert (indices, indptr) to a dense matrix."""
pre_ids, post_ids = ij
np = onp if isinstance(pre_ids, onp.ndarray) else bm
d = np.zeros((num_pre, num_post), dtype=MAT_DTYPE) # num_pre, num_post
d[pre_ids, post_ids] = True
return d


def ij2csr(pre_ids, post_ids, num_pre):
"""convert pre_ids, post_ids to (indices, indptr)."""
# sorting
sort_ids = np.argsort(pre_ids, kind='mergesort')
post_ids = post_ids[sort_ids]
def coo2csr(coo, num_pre):
"""convert pre_ids, post_ids to (indices, indptr) when'jax_platform_name' = 'gpu'"""
pre_ids, post_ids = coo
np = onp if isinstance(pre_ids, onp.ndarray) else bm

sort_ids = np.argsort(pre_ids)
post_ids = np.asarray(post_ids)
post_ids = post_ids[sort_ids]
indices = post_ids
unique_pre_ids, pre_count = np.unique(pre_ids, return_counts=True)
final_pre_count = np.zeros(num_pre, dtype=IDX_DTYPE)
final_pre_count = np.zeros(num_pre, dtype=jnp.uint32)
final_pre_count[unique_pre_ids] = pre_count
indptr = final_pre_count.cumsum()
indptr = np.insert(indptr, 0, 0)

return np.asarray(indices, dtype=IDX_DTYPE), np.asarray(indptr, dtype=IDX_DTYPE)


def coo2csc(coo, post_num, data=None):
"""Convert csr to csc."""
pre_ids, indices = coo
np = onp if isinstance(indices, onp.ndarray) else bm

# to maintain the original order of the elements with the same value
sort_ids = np.argsort(indices)
pre_ids_new = np.asarray(pre_ids[sort_ids], dtype=IDX_DTYPE)

unique_post_ids, count = np.unique(indices, return_counts=True)
post_count = np.zeros(post_num, dtype=IDX_DTYPE)
post_count[unique_post_ids] = count

indptr_new = post_count.cumsum()
indptr_new = np.insert(indptr_new, 0, 0)
indptr_new = np.asarray(indptr_new, dtype=IDX_DTYPE)

if data is None:
return pre_ids_new, indptr_new
else:
data_new = data[sort_ids]
return pre_ids_new, indptr_new, data_new
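To illustrate the version-2 customization style described in the `TwoEndConnector` docstring above, here is a hedged sketch of a connector that only implements `build_coo()`; the one-to-one rule and class name are invented for illustration, and the other structures are derived automatically by `require()`:

```python
import numpy as np
import brainpy as bp


class OneToOneCOO(bp.connect.TwoEndConnector):
    """Toy connector: pre-synaptic neuron i projects to post-synaptic neuron i."""

    def build_coo(self):
        n = min(self.pre_num, self.post_num)
        ids = np.arange(n, dtype=np.uint32)
        return ids, ids  # (pre_ids, post_ids)


conn = OneToOneCOO()
pre_ids, post_ids = conn.require(10, 10, 'pre_ids', 'post_ids')
mat = conn.require('conn_mat')  # the dense matrix is derived from the COO data
```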

+ 54
- 39
brainpy/connect/custom_conn.py View File

@@ -3,14 +3,15 @@
import jax.numpy as jnp
import numpy as np

from brainpy import math as bm
from brainpy import tools
from brainpy.errors import ConnectorError
from brainpy.math.jaxarray import JaxArray
from .base import *

__all__ = [
'MatConn',
'IJConn',
'CSRConn',
'SparseMatConn'
]

@@ -21,19 +22,21 @@ class MatConn(TwoEndConnector):
def __init__(self, conn_mat):
super(MatConn, self).__init__()

assert isinstance(conn_mat, (np.ndarray, JaxArray, jnp.ndarray)) and conn_mat.ndim == 2
assert isinstance(conn_mat, (np.ndarray, bm.JaxArray, jnp.ndarray)) and conn_mat.ndim == 2
self.pre_num, self.post_num = conn_mat.shape
self.pre_size, self.post_size = (self.pre_num,), (self.post_num,)
self.conn_mat = np.asarray(conn_mat).astype(MAT_DTYPE)
self.conn_mat = bm.asarray(conn_mat).astype(MAT_DTYPE)
def __call__(self, pre_size, post_size):
assert self.pre_num == tools.size2num(pre_size)
assert self.post_num == tools.size2num(post_size)
return self

def build_conn(self):
return 'mat', self.conn_mat
def build_mat(self):
assert self.conn_mat.shape[0] == self.pre_num
assert self.conn_mat.shape[1] == self.post_num
return self.conn_mat


class IJConn(TwoEndConnector):
@@ -42,37 +45,61 @@ class IJConn(TwoEndConnector):
def __init__(self, i, j):
super(IJConn, self).__init__()

assert isinstance(i, (np.ndarray, JaxArray, jnp.ndarray)) and i.ndim == 1
assert isinstance(j, (np.ndarray, JaxArray, jnp.ndarray)) and j.ndim == 1
assert isinstance(i, (np.ndarray, bm.JaxArray, jnp.ndarray)) and i.ndim == 1
assert isinstance(j, (np.ndarray, bm.JaxArray, jnp.ndarray)) and j.ndim == 1
assert i.size == j.size

# initialize the class via "pre_ids" and "post_ids"
self.pre_ids = np.asarray(i).astype(IDX_DTYPE)
self.post_ids = np.asarray(j).astype(IDX_DTYPE)
self.pre_ids = bm.asarray(i).astype(IDX_DTYPE)
self.post_ids = bm.asarray(j).astype(IDX_DTYPE)
self.max_pre = bm.max(self.pre_ids)
self.max_post = bm.max(self.post_ids)

def __call__(self, pre_size, post_size):
super(IJConn, self).__call__(pre_size, post_size)

max_pre = np.max(self.pre_ids)
max_post = np.max(self.post_ids)
if max_pre >= self.pre_num:
if self.max_pre >= self.pre_num:
raise ConnectorError(f'pre_num ({self.pre_num}) should be greater than '
f'the maximum id ({max_pre}) of self.pre_ids.')
if max_post >= self.post_num:
f'the maximum id ({self.max_pre}) of self.pre_ids.')
if self.max_post >= self.post_num:
raise ConnectorError(f'post_num ({self.post_num}) should be greater than '
f'the maximum id ({max_post}) of self.post_ids.')
f'the maximum id ({self.max_post}) of self.post_ids.')
return self

def build_conn(self):
return 'ij', (self.pre_ids, self.post_ids)
def build_coo(self):
if self.pre_num <= self.max_pre:
raise ConnectorError(f'pre_num ({self.pre_num}) should be greater than '
f'the maximum id ({self.max_pre}) of self.pre_ids.')
if self.post_num <= self.max_post:
raise ConnectorError(f'post_num ({self.post_num}) should be greater than '
f'the maximum id ({self.max_post}) of self.post_ids.')
return self.pre_ids, self.post_ids


class CSRConn(TwoEndConnector):
"""Connector built from the CSR sparse connection matrix."""

def __init__(self, indices, inptr):
super(CSRConn, self).__init__()

self.indices = bm.asarray(indices, dtype=IDX_DTYPE)
self.inptr = bm.asarray(inptr, dtype=IDX_DTYPE)
self.pre_num = self.inptr.size - 1
self.max_post = bm.max(self.indices)

def build_csr(self):
if self.pre_num != self.inptr.size - 1:
raise ConnectorError(f'(pre_size, post_size) is inconsistent with '
f'the shape of the sparse matrix.')
if self.post_num <= self.max_post:
raise ConnectorError(f'post_num ({self.post_num}) should be greater than '
f'the maximum id ({self.max_post}) of self.post_ids.')
return self.indices, self.inptr


class SparseMatConn(TwoEndConnector):
class SparseMatConn(CSRConn):
"""Connector built from the sparse connection matrix"""

def __init__(self, csr_mat):
super(SparseMatConn, self).__init__()

try:
from scipy.sparse import csr_matrix
except (ModuleNotFoundError, ImportError):
@@ -80,20 +107,8 @@ class SparseMatConn(TwoEndConnector):
f'Please run "pip install scipy" to install scipy.')

assert isinstance(csr_mat, csr_matrix)
csr_mat.data = np.asarray(csr_mat.data).astype(MAT_DTYPE)
self.csr_mat = csr_mat
self.pre_num, self.post_num = csr_mat.shape

def __call__(self, pre_size, post_size):
try:
assert self.pre_num == tools.size2num(pre_size)
assert self.post_num == tools.size2num(post_size)
except AssertionError:
raise ConnectorError(f'(pre_size, post_size) is inconsistent with the shape of the sparse matrix.')

super(SparseMatConn, self).__call__(pre_size, post_size)
return self

def build_conn(self):
ind, indptr = self.csr_mat.indices, self.csr_mat.indptr
return 'csr', (ind, indptr)
super(SparseMatConn, self).__init__(indices=bm.asarray(self.csr_mat.indices, dtype=IDX_DTYPE),
inptr=bm.asarray(self.csr_mat.indptr, dtype=IDX_DTYPE))
self.pre_num = csr_mat.shape[0]
self.post_num = csr_mat.shape[1]
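For reference, a hedged sketch of driving the new `SparseMatConn`/`CSRConn` path with a toy matrix (assumes `scipy` is installed; the connectivity pattern is arbitrary):

```python
import numpy as np
from scipy.sparse import csr_matrix
import brainpy as bp

# A 3x4 toy connectivity pattern.
dense = np.array([[1, 0, 0, 1],
                  [0, 1, 0, 0],
                  [0, 0, 1, 1]], dtype=bool)

conn = bp.connect.SparseMatConn(csr_matrix(dense))

indices, indptr = conn.require('pre2post')  # the CSR data stored by the connector
mat = conn.require('conn_mat')              # the dense matrix rebuilt from the CSR data
```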

+ 256
- 104
brainpy/connect/random_conn.py View File

@@ -1,15 +1,20 @@
# -*- coding: utf-8 -*-

from typing import Optional

import jax.numpy as jnp
import numpy as np

import brainpy.math as bm
from brainpy.errors import ConnectorError
from brainpy.tools.others import numba_seed, numba_jit, SUPPORT_NUMBA, format_seed
from brainpy.tools.others import numba_seed, numba_jit, numba_range, SUPPORT_NUMBA, format_seed
from .base import *

__all__ = [
'FixedProb',
'FixedPreNum',
'FixedPostNum',
'FixedTotalNum',
'GaussianProb',
'ProbDist',

@@ -25,70 +30,141 @@ class FixedProb(TwoEndConnector):

Parameters
----------
prob : float
prob: float
The conn probability.
pre_ratio: float
The ratio of pre-synaptic neurons to connect.
include_self : bool
Whether to create (i, i) connections.
allow_multi_conn: bool
Whether to allow one pre-synaptic neuron to connect to multiple post-synaptic neurons.

.. versionadded:: 2.2.3.2

seed : optional, int
Seed the random generator.
"""

def __init__(self, prob, pre_ratio=1., include_self=True, seed=None):
def __init__(self, prob, pre_ratio=1., include_self=True, allow_multi_conn=False, seed=None):
super(FixedProb, self).__init__()
assert 0. <= prob <= 1.
assert 0. <= pre_ratio <= 1.
self.prob = prob
self.pre_ratio = pre_ratio
self.include_self = include_self
self.seed = format_seed(seed)
self.rng = np.random.RandomState(seed=self.seed)

rng = np.random if SUPPORT_NUMBA else self.rng

def _connect(pre_i, num_post):
if rng.random() < pre_ratio:
p = rng.random(num_post) <= prob
if (not include_self) and pre_i < num_post:
p[pre_i] = False
return np.where(p)[0]

self._connect = numba_jit(_connect)
self._jaxrand = bm.random.RandomState(self.seed)
self._nprand = np.random.RandomState(self.seed)
self.allow_multi_conn = allow_multi_conn

def __repr__(self):
return (f'{self.__class__.__name__}(prob={self.prob}, pre_ratio={self.pre_ratio}, '
f'include_self={self.include_self}, allow_multi_conn={self.allow_multi_conn}, '
f'seed={self.seed})')

def _iii(self):
if (not self.include_self) and (self.pre_num != self.post_num):
raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). '
f'But `include_self` is set to False.')

if self.pre_ratio < 1.:
pre_num_to_select = int(self.pre_num * self.pre_ratio)
pre_ids = self._jaxrand.choice(self.pre_num, size=(pre_num_to_select,), replace=False)
else:
pre_num_to_select = self.pre_num
pre_ids = jnp.arange(self.pre_num)

def build_conn(self):
# seed
self.seed = self.rng.randint(1, int(1e7))
if SUPPORT_NUMBA: numba_seed(self.seed)
post_num_total = self.post_num
post_num_to_select = int(self.post_num * self.prob)

# make connections
ind = []
count = np.zeros(self.pre_num, dtype=IDX_DTYPE)
for i in range(self.pre_num):
posts = self._connect(pre_i=i, num_post=self.post_num)
if posts is not None:
ind.append(posts)
count[i] = len(posts)
ind = np.concatenate(ind) if len(ind) > 0 else np.asarray([], dtype=IDX_DTYPE)
indptr = np.concatenate(([0], count)).cumsum()
if self.allow_multi_conn:
selected_post_ids = self._jaxrand.randint(0, post_num_total, (pre_num_to_select, post_num_to_select))

return 'csr', (ind, indptr)
else:
if SUPPORT_NUMBA:
rng = np.random
numba_seed(self._nprand.randint(0, int(1e8)))
else:
rng = self._nprand

@numba_jit # (parallel=True, nogil=True)
def single_conn():
posts = np.zeros((pre_num_to_select, post_num_to_select), dtype=np.uint32)
for i in numba_range(pre_num_to_select):
posts[i] = rng.choice(post_num_total, post_num_to_select, replace=False)
return posts

selected_post_ids = jnp.asarray(single_conn())
return pre_num_to_select, post_num_to_select, bm.as_jax(selected_post_ids), bm.as_jax(pre_ids)

def build_coo(self):
_, post_num_to_select, selected_post_ids, pre_ids = self._iii()
selected_post_ids = selected_post_ids.flatten()
selected_pre_ids = jnp.repeat(pre_ids, post_num_to_select)
if not self.include_self:
true_ids = selected_pre_ids != selected_post_ids
selected_pre_ids = selected_pre_ids[true_ids]
selected_post_ids = selected_post_ids[true_ids]
return selected_pre_ids.astype(IDX_DTYPE), selected_post_ids.astype(IDX_DTYPE)

def build_csr(self):
pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids = self._iii()
pre_nums = jnp.ones(pre_num_to_select) * post_num_to_select
if not self.include_self:
true_ids = selected_post_ids == jnp.reshape(pre_ids, (-1, 1))
pre_nums -= jnp.sum(true_ids, axis=1)
selected_post_ids = selected_post_ids.flatten()[jnp.logical_not(true_ids).flatten()]
else:
selected_post_ids = selected_post_ids.flatten()
selected_pre_inptr = jnp.cumsum(jnp.concatenate([jnp.zeros(1), pre_nums]))
return selected_post_ids.astype(IDX_DTYPE), selected_pre_inptr.astype(IDX_DTYPE)

def build_mat(self):
pre_state = self._jaxrand.uniform(size=(self.pre_num, 1)) < self.pre_ratio
mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) * pre_state
mat = bm.asarray(mat)
if not self.include_self:
bm.fill_diagonal(mat, False)
return mat.astype(MAT_DTYPE)


class FixedNum(TwoEndConnector):
"""Connect with fixed number for each pre- or post-synaptic neuron.
class FixedTotalNum(TwoEndConnector):
"""Connect the synaptic neurons with fixed total number.

Parameters
----------
num : float, int
The conn probability (if "num" is float) or the fixed number of
connectivity (if "num" is int).
include_self : bool
Whether create (i, i) conn ?
seed : None, int
Seed the random generator.
num : float, int
The total number of connections.
seed: int, optional
The random number seed.
"""

def __init__(self, num, include_self=True, seed=None):
def __init__(self, num, seed=None):
super(FixedTotalNum, self).__init__()
if isinstance(num, int):
assert num >= 0, '"num" must be a non-negative integer.'
elif isinstance(num, float):
assert 0. <= num <= 1., '"num" must be in [0., 1.].'
else:
raise ConnectorError(f'Unknown type: {type(num)}')
self.num = num
self.seed = format_seed(seed)
self.rng = bm.random.RandomState(self.seed)

def build_coo(self):
if self.num > self.pre_num * self.post_num:
raise ConnectorError(f'"num" must be smaller than "all2all num", '
f'but got {self.num} > {self.pre_num * self.post_num}')
selected_pre_ids = self.rng.randint(0, self.pre_num, (self.num,))
selected_post_ids = self.rng.randint(0, self.post_num, (self.num,))
return selected_pre_ids.astype(IDX_DTYPE), selected_post_ids.astype(IDX_DTYPE)

def __repr__(self):
return f'{self.__class__.__name__}(num={self.num}, seed={self.seed})'


class FixedNum(TwoEndConnector):
def __init__(self, num, include_self=True, allow_multi_conn=False, seed=None):
super(FixedNum, self).__init__()
if isinstance(num, int):
assert num >= 0, '"num" must be a non-negative integer.'
@@ -99,58 +175,75 @@ class FixedNum(TwoEndConnector):
self.num = num
self.seed = format_seed(seed)
self.include_self = include_self
self.rng = np.random.RandomState(seed=self.seed)
rng = np.random if SUPPORT_NUMBA else self.rng
self.allow_multi_conn = allow_multi_conn
self.rng = bm.random.RandomState(self.seed) if allow_multi_conn else np.random.RandomState(self.seed)

def _fixed_num_prob(num_need, num_total, i=0):
prob = rng.random(num_total)
if not include_self and i <= num_total:
prob[i] = 1.
neu_idx = np.argsort(prob)[:num_need]
return np.asarray(neu_idx, dtype=IDX_DTYPE)

self._connect = numba_jit(_fixed_num_prob)
def __repr__(self):
return f'{self.__class__.__name__}(num={self.num}, include_self={self.include_self}, seed={self.seed})'


class FixedPreNum(FixedNum):
"""Connect the pre-synaptic neurons with fixed number for each post-synaptic neuron.
"""Connect a fixed number pf pre-synaptic neurons for each post-synaptic neuron.

Parameters
----------
num : float, int
The connection probability (if "num" is float) or the fixed number of
connectivity (if "num" is int).
The conn probability (if "num" is float) or the fixed number of
connectivity (if "num" is int).
include_self : bool
    Whether to create the (i, i) connection.
seed : None, int
Seed the random generator.
allow_multi_conn: bool
    Whether to allow one pre-synaptic neuron to connect to multiple post-synaptic neurons.

.. versionadded:: 2.2.3.2

"""

def build_conn(self):
# check
if isinstance(self.num, int):
assert 0 <= self.num <= self.pre_num, f'"num" must be smaller than "self.pre_num", ' \
f'but got {self.num} > {self.pre_num}'
num = self.num
def build_coo(self):
if isinstance(self.num, int) and self.num > self.pre_num:
raise ConnectorError(f'"num" must be smaller than "pre_num", '
f'but got {self.num} > {self.pre_num}')
if (not self.include_self) and (self.pre_num != self.post_num):
raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). '
                           f'But `include_self` is set to False.')
pre_num_to_select = int(self.pre_num * self.num) if isinstance(self.num, float) else self.num
pre_num_total = self.pre_num
post_num_total = self.post_num

if self.allow_multi_conn:
selected_pre_ids = self.rng.randint(0, pre_num_total, (post_num_total, pre_num_to_select,))

else:
assert 0. <= self.num <= 1., f'"num" must be in [0., 1.), but got {self.num}'
num = int(self.pre_num * self.num)
if SUPPORT_NUMBA:
rng = np.random
numba_seed(self.rng.randint(0, int(1e8)))
else:
rng = self.rng

# seed
self.seed = self.rng.randint(1, int(1e7))
numba_seed(self.seed)
@numba_jit # (parallel=True, nogil=True)
def single_conn():
posts = np.zeros((post_num_total, pre_num_to_select), dtype=np.uint32)
for i in numba_range(post_num_total):
posts[i] = rng.choice(pre_num_total, pre_num_to_select, replace=False)
return posts

# make connections
pre_ids = []
for i in range(self.post_num):
pres = self._connect(num_need=num, num_total=self.pre_num, i=i)
pre_ids.append(pres)
pre_ids = np.concatenate(pre_ids) if len(pre_ids) > 0 else np.asarray([], dtype=IDX_DTYPE)
post_ids = np.repeat(np.arange(self.post_num), num)
selected_pre_ids = jnp.asarray(single_conn())

return 'ij', (pre_ids, post_ids)
post_nums = jnp.ones((post_num_total,), dtype=IDX_DTYPE) * pre_num_to_select
if not self.include_self:
true_ids = selected_pre_ids == jnp.reshape(jnp.arange(pre_num_total), (-1, 1))
post_nums -= jnp.sum(true_ids, axis=1)
selected_pre_ids = selected_pre_ids.flatten()[jnp.logical_not(true_ids).flatten()]
else:
selected_pre_ids = selected_pre_ids.flatten()
selected_post_ids = jnp.repeat(jnp.arange(post_num_total), post_nums)
return selected_pre_ids.astype(IDX_DTYPE), selected_post_ids.astype(IDX_DTYPE)
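A hedged example of the rewritten FixedPreNum builder (group sizes are illustrative; with the default include_self=True and allow_multi_conn=False, pre-synaptic partners are sampled without replacement):

import brainpy as bp

# every post-synaptic neuron receives exactly 5 pre-synaptic partners
conn = bp.connect.FixedPreNum(num=5, seed=0)
conn(pre_size=100, post_size=80)
pre_ids, post_ids = conn.build_coo()
assert len(post_ids) == 80 * 5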


class FixedPostNum(FixedNum):
"""Connect the post-synaptic neurons with fixed number for each pre-synaptic neuron.
"""Connect the fixed number of post-synaptic neurons for each pre-synaptic neuron.

Parameters
----------
@@ -161,32 +254,66 @@ class FixedPostNum(FixedNum):
    Whether to create the (i, i) connection.
seed : None, int
Seed the random generator.
"""
allow_multi_conn: bool
    Whether to allow one pre-synaptic neuron to connect to multiple post-synaptic neurons.

def build_conn(self):
# check
if isinstance(self.num, int):
assert 0 <= self.num <= self.post_num, f'"num" must be smaller than "self.post_num", ' \
f'but got {self.num} > {self.post_num}'
num = self.num
else:
assert 0. <= self.num <= 1., f'"num" must be in [0., 1.), but got {self.num}'
num = int(self.post_num * self.num)
.. versionadded:: 2.2.3.2

# seed
self.seed = self.rng.randint(1, int(1e7))
numba_seed(self.seed)
"""

# make connections
post_ids = [] # i.e. post_ids
for i in range(self.pre_num):
posts = self._connect(num_need=num, num_total=self.post_num, i=i)
post_ids.append(posts)
post_ids = np.concatenate(post_ids)
count = np.ones(self.pre_num, dtype=IDX_DTYPE) * num
indptr = np.concatenate(([0], count)).cumsum()
def _ii(self):
if isinstance(self.num, int) and self.num > self.post_num:
raise ConnectorError(f'"num" must be smaller than "post_num", '
f'but got {self.num} > {self.post_num}')
if (not self.include_self) and (self.pre_num != self.post_num):
raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). '
                           f'But `include_self` is set to False.')
post_num_to_select = int(self.post_num * self.num) if isinstance(self.num, float) else self.num
pre_num_to_select = self.pre_num
pre_ids = jnp.arange(self.pre_num)
post_num_total = self.post_num

if self.allow_multi_conn:
selected_post_ids = self.rng.randint(0, post_num_total, (pre_num_to_select, post_num_to_select,))

return 'csr', (post_ids, indptr)
else:
if SUPPORT_NUMBA:
rng = np.random
numba_seed(self.rng.randint(0, int(1e8)))
else:
rng = self.rng

@numba_jit # (parallel=True, nogil=True)
def single_conn():
posts = np.zeros((pre_num_to_select, post_num_to_select), dtype=np.uint32)
for i in numba_range(pre_num_to_select):
posts[i] = rng.choice(post_num_total, post_num_to_select, replace=False)
return posts

selected_post_ids = jnp.asarray(single_conn())
return pre_num_to_select, post_num_to_select, bm.as_jax(selected_post_ids), bm.as_jax(pre_ids)

def build_coo(self):
_, post_num_to_select, selected_post_ids, pre_ids = self._ii()
selected_post_ids = selected_post_ids.flatten()
selected_pre_ids = jnp.repeat(pre_ids, post_num_to_select)
if not self.include_self:
true_ids = selected_pre_ids != selected_post_ids
selected_pre_ids = selected_pre_ids[true_ids]
selected_post_ids = selected_post_ids[true_ids]
return selected_pre_ids.astype(IDX_DTYPE), selected_post_ids.astype(IDX_DTYPE)

def build_csr(self):
pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids = self._ii()
pre_nums = jnp.ones(pre_num_to_select) * post_num_to_select
if not self.include_self:
true_ids = selected_post_ids == jnp.reshape(pre_ids, (-1, 1))
pre_nums -= jnp.sum(true_ids, axis=1)
selected_post_ids = selected_post_ids.flatten()[jnp.logical_not(true_ids).flatten()]
else:
selected_post_ids = selected_post_ids.flatten()
selected_pre_inptr = jnp.cumsum(jnp.concatenate([jnp.zeros(1), pre_nums]))
return selected_post_ids.astype(IDX_DTYPE), selected_pre_inptr.astype(IDX_DTYPE)
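And the mirror case, FixedPostNum, which fixes the out-degree of each pre-synaptic neuron (sizes illustrative; build_csr returns the (indices, indptr) pair of the CSR format):

import brainpy as bp

# every pre-synaptic neuron projects to exactly 8 post-synaptic neurons
conn = bp.connect.FixedPostNum(num=8, seed=0)
conn(pre_size=50, post_size=200)
indices, indptr = conn.build_csr()
assert len(indptr) == 50 + 1 and len(indices) == 50 * 8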


class GaussianProb(OneEndConnector):
@@ -222,7 +349,7 @@ class GaussianProb(OneEndConnector):
normalize : bool
    Whether to normalize the connection probability.
include_self : bool
Whether create the conn at the same position.
    Whether to create the connection at the same position.
seed : int
The random seed.
"""
@@ -230,7 +357,7 @@ class GaussianProb(OneEndConnector):
def __init__(
self,
sigma: float,
encoding_values=None,
encoding_values: Optional[np.ndarray] = None,
normalize: bool = True,
include_self: bool = True,
periodic_boundary: bool = False,
@@ -245,7 +372,14 @@ class GaussianProb(OneEndConnector):
self.seed = format_seed(seed)
self.rng = np.random.RandomState(self.seed)

def build_conn(self):
def __repr__(self):
return (f'{self.__class__.__name__}(sigma={self.sigma}, '
f'normalize={self.normalize}, '
f'periodic_boundary={self.periodic_boundary}, '
f'include_self={self.include_self}, '
f'seed={self.seed})')

def build_mat(self, pre_size=None, post_size=None):
# value range to encode
if self.encoding_values is None:
value_ranges = tuple([(0, s) for s in self.pre_size])
@@ -272,6 +406,7 @@ class GaussianProb(OneEndConnector):

# values
values = [np.linspace(vs[0], vs[1], n + 1)[:n] for vs, n in zip(value_ranges, self.pre_size)]
# post_values = np.stack([v.flatten() for v in np.meshgrid(*values, indexing='ij')])
post_values = np.stack([v.flatten() for v in np.meshgrid(*values)])
value_sizes = np.array([v[1] - v[0] for v in value_ranges])
if value_sizes.ndim < post_values.ndim:
@@ -300,12 +435,10 @@ class GaussianProb(OneEndConnector):
prob_mat /= prob_mat.max()

# connectivity
conn_mat = prob_mat >= self.rng.random(prob_mat.shape)

conn_mat = np.asarray(prob_mat) >= self.rng.random(prob_mat.shape)
if not self.include_self:
np.fill_diagonal(conn_mat, False)

return 'mat', conn_mat
return conn_mat
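A sketch of the Gaussian-probability connector on a 2-D sheet (sigma and grid size are illustrative):

import brainpy as bp

# neurons on a 20x20 grid; the connection probability decays with the
# distance between grid positions, wrapping around at the borders
conn = bp.connect.GaussianProb(sigma=2., include_self=False,
                               periodic_boundary=True, seed=1)
conn(pre_size=(20, 20), post_size=(20, 20))
mat = conn.build_mat()        # boolean (400, 400) connection matrix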


class SmallWorld(TwoEndConnector):
@@ -373,6 +506,13 @@ class SmallWorld(TwoEndConnector):

self._connect = numba_jit(_smallworld_rewire)

def __repr__(self):
return (f'{self.__class__.__name__}(prob={self.prob}, '
f'directed={self.directed}, '
f'num_neighbor={self.num_neighbor}, '
f'include_self={self.include_self}, '
f'seed={self.seed})')

def build_conn(self):
assert self.pre_size == self.post_size

@@ -487,6 +627,11 @@ class ScaleFreeBA(TwoEndConnector):

self._connect = numba_jit(_random_subset)

def __repr__(self):
return (f'{self.__class__.__name__}(m={self.m}, '
f'directed={self.directed}, '
f'seed={self.seed})')

def build_conn(self):
assert self.pre_num == self.post_num

@@ -573,6 +718,10 @@ class ScaleFreeBADual(TwoEndConnector):

self._connect = numba_jit(_random_subset)

def __repr__(self):
return (f'{self.__class__.__name__}(m1={self.m1}, m2={self.m2}, '
f'p={self.p}, directed={self.directed}, seed={self.seed})')

def build_conn(self):
assert self.pre_num == self.post_num
# seed
@@ -683,6 +832,9 @@ class PowerLaw(TwoEndConnector):

self._connect = numba_jit(_random_subset)

def __repr__(self):
return (f'{self.__class__.__name__}(m={self.m}, p={self.p}, directed={self.directed}, seed={self.seed})')

def build_conn(self):
assert self.pre_num == self.post_num
# seed
@@ -886,7 +1038,7 @@ class ProbDist(TwoEndConnector):
post_size = np.asarray(self.post_size)
connected_pres = []
connected_posts = []
pre_ids = np.meshgrid(*(np.arange(p) for p in self.pre_size))
pre_ids = np.meshgrid(*(np.arange(p) for p in self.pre_size), indexing='ij')
pre_ids = tuple([(np.moveaxis(p, 0, 1).flatten()) if p.ndim > 1 else p.flatten() for p in pre_ids])
size = np.prod(pre_size)
for i in range(size):


+189 -104 brainpy/connect/regular_conn.py

@@ -1,16 +1,13 @@
# -*- coding: utf-8 -*-
from typing import Union, Tuple, List

import logging

import jax
import numpy as np

from brainpy import math as bm
from brainpy.errors import ConnectorError
from brainpy.tools.others import numba_jit

from .base import *

logger = logging.getLogger('brainpy.building.connect')

__all__ = [
'One2One', 'one2one',
'All2All', 'all2all',
@@ -24,6 +21,7 @@ class One2One(TwoEndConnector):
"""Connect two neuron groups one by one. This means
The two neuron groups should have the same size.
"""

def __init__(self):
super(One2One, self).__init__()

@@ -36,11 +34,25 @@ class One2One(TwoEndConnector):
f'same size, but {self.pre_num} != {self.post_num}.')
return self

def build_conn(self):
def build_coo(self):
if self.pre_num != self.post_num:
raise ConnectorError(f'One2One connection must be defined in two groups with the '
f'same size, but {self.pre_num} != {self.post_num}.')
return np.arange(self.pre_num, dtype=IDX_DTYPE), np.arange(self.post_num, dtype=IDX_DTYPE),

def build_csr(self):
if self.pre_num != self.post_num:
raise ConnectorError(f'One2One connection must be defined in two groups with the '
f'same size, but {self.pre_num} != {self.post_num}.')
ind = np.arange(self.pre_num)
indptr = np.arange(self.pre_num + 1)
return (np.asarray(ind, dtype=IDX_DTYPE), np.asarray(indptr, dtype=IDX_DTYPE))

return dict(csr=(ind, indptr), mat=None, ij=None)
def build_mat(self, pre_size=None, post_size=None):
if self.pre_num != self.post_num:
raise ConnectorError(f'One2One connection must be defined in two groups with the '
f'same size, but {self.pre_num} != {self.post_num}.')
    mat = np.zeros((self.pre_num, self.post_num), dtype=MAT_DTYPE)
    np.fill_diagonal(mat, True)
    return mat


one2one = One2One()
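The explicit builders make the identity wiring easy to inspect (sizes illustrative):

import brainpy as bp

conn = bp.connect.One2One()(pre_size=100, post_size=100)
pre_ids, post_ids = conn.build_coo()   # both equal to arange(100)
indices, indptr = conn.build_csr()     # exactly one synapse per row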
@@ -56,95 +68,153 @@ class All2All(TwoEndConnector):
self.include_self = include_self
super(All2All, self).__init__()

def build_conn(self):
def __repr__(self):
return f'{self.__class__.__name__}(include_self={self.include_self})'

def build_mat(self):
mat = np.ones((self.pre_num, self.post_num), dtype=MAT_DTYPE)
if not self.include_self:
np.fill_diagonal(mat, False)

return dict(csr=None, mat=mat, ij=None)
return mat


all2all = All2All(include_self=True)
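And the dense counterpart (sizes illustrative):

import brainpy as bp

conn = bp.connect.All2All(include_self=False)(pre_size=10, post_size=10)
mat = conn.build_mat()   # (10, 10) boolean matrix, False on the diagonal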


@numba_jit
def _grid_four(height, width, row, include_self):
conn_i = []
conn_j = []

for col in range(width):
i_index = (row * width) + col
if 0 <= row - 1 < height:
j_index = ((row - 1) * width) + col
conn_i.append(i_index)
conn_j.append(j_index)
if 0 <= row + 1 < height:
j_index = ((row + 1) * width) + col
conn_i.append(i_index)
conn_j.append(j_index)
if 0 <= col - 1 < width:
j_index = (row * width) + col - 1
conn_i.append(i_index)
conn_j.append(j_index)
if 0 <= col + 1 < width:
j_index = (row * width) + col + 1
conn_i.append(i_index)
conn_j.append(j_index)
if include_self:
conn_i.append(i_index)
conn_j.append(i_index)
return conn_i, conn_j


class GridFour(OneEndConnector):
"""The nearest four neighbors conn method."""

def __init__(self, include_self=False):
super(GridFour, self).__init__()
def get_size_length(sizes: Union[Tuple, List]):
if not isinstance(sizes, (tuple, list)):
    raise TypeError('"sizes" must be a tuple or a list.')
lengths = []
a = 1
for s in reversed(sizes):
lengths.insert(0, a)
a *= s
return np.asarray(lengths)


class GridConn(OneEndConnector):
def __init__(
self,
strides,
include_self: bool = False,
periodic_boundary: bool = False,
):
super(GridConn, self).__init__()
self.strides = strides
self.include_self = include_self

def build_conn(self):
# only the 1- or 2-D structure is supported
if len(self.pre_size) == 1:
height, width = self.pre_size[0], 1
elif len(self.pre_size) == 2:
height, width = self.pre_size
self.periodic_boundary = periodic_boundary

def __repr__(self):
return f'{self.__class__.__name__}(include_self={self.include_self}, periodic_boundary={self.periodic_boundary})'

def _format(self):
dim = len(self.post_size)
if self.pre_num != self.post_num:
      raise ConnectorError(f'{self.__class__.__name__} is used for connections within '
                           f'the same population, but we detected pre_num != post_num '
f'({self.pre_num} != {self.post_num}).')
# point indices
indices = bm.meshgrid(*(bm.arange(size) for size in self.post_size), indexing='ij')
indices = bm.asarray(indices)
indices = indices.reshape(dim, self.post_num).T
lengths = bm.asarray(self.post_size)
return lengths, dim, indices

def _get_strides(self, dim):
# increments
increments = np.asarray(np.meshgrid(*(self.strides for _ in range(dim)))).reshape(dim, -1).T
select_ids = self._select_stride(increments)
increments = bm.asarray(increments[select_ids])
return increments

def _select_stride(self, stride: np.ndarray) -> np.ndarray:
raise NotImplementedError

def _select_dist(self, dist: bm.ndarray) -> bm.ndarray:
raise NotImplementedError

def build_mat(self):
sizes, _, indices = self._format()

@jax.vmap
def f_connect(pre_id):
# pre_id: R^(num_dim)
dist = bm.abs(pre_id - indices)
if self.periodic_boundary:
dist = bm.where(dist > sizes / 2, sizes - dist, dist)
return self._select_dist(dist)

return bm.asarray(f_connect(indices), dtype=MAT_DTYPE)

def build_coo(self):
sizes, dim, indices = self._format()
strides = self._get_strides(dim)

@jax.vmap
def f_connect(pre_id):
# pre_id: R^(num_dim)
post_ids = pre_id + strides
if self.periodic_boundary:
post_ids = post_ids % sizes
else:
post_ids = bm.where(post_ids < sizes, post_ids, -1)
size = len(post_ids)
pre_ids = bm.repeat(pre_id, size).reshape(dim, size).T
return pre_ids, post_ids

pres, posts = f_connect(indices)
pres = pres.reshape(-1, dim)
posts = posts.reshape(-1, dim)
idx = bm.nonzero(bm.all(posts >= 0, axis=1))[0]
pres = pres[idx]
posts = posts[idx]
if dim == 1:
pres = pres.flatten()
posts = posts.flatten()
else:
raise ConnectorError(f'Currently, GridFour only supports the two-dimensional geometry.')
strides = bm.asarray(get_size_length(self.post_size))
pres = bm.sum(pres * strides, axis=1)
posts = bm.sum(posts * strides, axis=1)
return bm.asarray(pres, dtype=IDX_DTYPE), bm.asarray(posts, dtype=IDX_DTYPE)


class GridFour(GridConn):
"""The nearest four neighbors connection method.

conn_i = []
conn_j = []
for row in range(height):
a = _grid_four(height, width, row, include_self=self.include_self)
conn_i.extend(a[0])
conn_j.extend(a[1])
pre_ids = np.asarray(conn_i, dtype=IDX_DTYPE)
post_ids = np.asarray(conn_j, dtype=IDX_DTYPE)
Parameters
----------
periodic_boundary : bool
    Whether the neurons encode the value space with a periodic boundary.
.. versionadded:: 2.2.3.2

include_self : bool
Whether create connection at the same position.
"""

def __init__(
self,
include_self: bool = False,
periodic_boundary: bool = False
):
super(GridFour, self).__init__(strides=np.asarray([-1, 0, 1]),
include_self=include_self,
periodic_boundary=periodic_boundary)
self.include_self = include_self
self.periodic_boundary = periodic_boundary

def _select_stride(self, stride: np.ndarray) -> np.ndarray:
temp = abs(stride).sum(axis=1)
return (temp <= 1) if self.include_self else (temp == 1)

return 'ij', (pre_ids, post_ids)
def _select_dist(self, dist: bm.ndarray) -> bm.ndarray:
dist = bm.linalg.norm(dist, axis=1)
return dist <= 1 if self.include_self else dist == 1


grid_four = GridFour()


@numba_jit
def _grid_n(height, width, row, n, include_self):
conn_i = []
conn_j = []
for col in range(width):
i_index = (row * width) + col
for row_diff in range(-n, n + 1):
for col_diff in range(-n, n + 1):
if (not include_self) and (row_diff == col_diff == 0):
continue
if 0 <= row + row_diff < height and 0 <= col + col_diff < width:
j_index = ((row + row_diff) * width) + col + col_diff
conn_i.append(i_index)
conn_j.append(j_index)
return conn_i, conn_j


class GridN(OneEndConnector):
class GridN(GridConn):
"""The nearest (2*N+1) * (2*N+1) neighbors conn method.

Parameters
@@ -162,40 +232,55 @@ class GridN(OneEndConnector):
[x x x x x]
[x x x x x]
include_self : bool
    Whether to create the (i, i) connection.
periodic_boundary: bool
    Whether the neurons encode the value space with a periodic boundary.
.. versionadded:: 2.2.3.2
"""

def __init__(self, N=1, include_self=False):
super(GridN, self).__init__()
def __init__(
self,
N: int = 1,
include_self: bool = False,
periodic_boundary: bool = False
):
super(GridN, self).__init__(strides=np.arange(-N, N + 1, 1),
include_self=include_self,
periodic_boundary=periodic_boundary)
self.N = N
self.include_self = include_self

def build_conn(self):
if len(self.pre_size) == 1:
height, width = self.pre_size[0], 1
elif len(self.pre_size) == 2:
height, width = self.pre_size
else:
raise ConnectorError(f'Currently, GridN only supports the two-dimensional geometry.')
def __repr__(self):
return (f'{self.__class__.__name__}(N={self.N}, '
f'include_self={self.include_self}, '
f'periodic_boundary={self.periodic_boundary})')

conn_i = []
conn_j = []
for row in range(height):
res = _grid_n(height=height, width=width, row=row,
n=self.N, include_self=self.include_self)
conn_i.extend(res[0])
conn_j.extend(res[1])
pre_ids = np.asarray(conn_i, dtype=IDX_DTYPE)
post_ids = np.asarray(conn_j, dtype=IDX_DTYPE)
def _select_stride(self, stride: np.ndarray) -> np.ndarray:
return (np.ones(len(stride), dtype=bool)
if self.include_self else
(np.sum(np.abs(stride), axis=1) > 0))

return 'ij', (pre_ids, post_ids)
def _select_dist(self, dist: bm.ndarray) -> bm.ndarray:
if self.include_self:
return bm.all(dist <= self.N, axis=1)
else:
return bm.logical_and(bm.all(dist <= self.N, axis=1),
bm.logical_not(bm.all(dist == 0, axis=1)))


class GridEight(GridN):
"""The nearest eight neighbors conn method."""
"""The nearest eight neighbors conn method.

Parameters
----------
include_self : bool
    Whether to create the (i, i) connection.
periodic_boundary: bool
    Whether the neurons encode the value space with a periodic boundary.
.. versionadded:: 2.2.3.2
"""

def __init__(self, include_self=False):
super(GridEight, self).__init__(N=1, include_self=include_self)
def __init__(self, include_self=False, periodic_boundary: bool = False):
super(GridEight, self).__init__(N=1, include_self=include_self, periodic_boundary=periodic_boundary)


grid_eight = GridEight()
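The grid connectors now share the GridConn machinery; a short sketch combining the new periodic_boundary option (sizes are illustrative and follow the tests below):

import brainpy as bp

# 2-D lattice, each neuron wired to its nearest four neighbours,
# with wrap-around at the borders
conn = bp.conn.GridFour(include_self=False, periodic_boundary=True)((10, 10), (10, 10))
pre_ids, post_ids = conn.build_coo()

# the (2N+1) x (2N+1) neighbourhood variant
conn_n = bp.conn.GridN(N=2, periodic_boundary=True)((10, 10), (10, 10))
mat = conn_n.build_mat()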

+24 -14 brainpy/connect/tests/test_random_conn.py

@@ -2,23 +2,30 @@

import pytest

import unittest

import brainpy as bp


def test_random_prob():
conn1 = bp.connect.FixedProb(prob=0.1, seed=123)
conn1(pre_size=(10, 20), post_size=(10, 20))
pre_ids, post_ids, pre2post = conn1.require('pre_ids', 'post_ids', 'pre2post')
class TestFixedProb(unittest.TestCase):
def test_size_consistent(self):
conn1 = bp.connect.FixedProb(prob=0.1, seed=123)
conn1(pre_size=(10, 20), post_size=(10, 20))
pre_ids, post_ids, pre2post = conn1.require('pre_ids', 'post_ids', 'pre2post')
self.assertTrue(len(pre_ids) == len(post_ids))
self.assertTrue(len(pre_ids) == len(pre2post[0]))

conn2 = bp.connect.FixedProb(prob=0.1, seed=123)
conn2(pre_size=(10, 20), post_size=(10, 20))
mat = conn2.require(bp.connect.CONN_MAT)
pre_ids2, post_ids2 = bp.math.where(mat)
def test_require_method(self):
conn2 = bp.connect.FixedProb(prob=0.1, seed=123)
conn2(pre_size=(10, 20), post_size=(10, 20))
mat = conn2.require(bp.connect.CONN_MAT)
self.assertTrue(mat.shape == (200, 200))

print()
assert bp.math.array_equal(pre_ids, pre_ids2)
assert bp.math.array_equal(post_ids, post_ids2)
print('weight_mat', mat)
mat = conn2(100, 1000).require(bp.connect.CONN_MAT)
self.assertTrue(mat.shape == (100, 1000))

mat = conn2.require(10, 20, bp.connect.CONN_MAT)
self.assertTrue(mat.shape == (10, 20))


def test_random_fix_pre1():
@@ -30,8 +37,11 @@ def test_random_fix_pre1():
mat2 = conn2.require(bp.connect.CONN_MAT)

print()
print(f'num = {num}')
print('conn_mat 1\n', mat1)
print(mat1.sum())
print('conn_mat 2\n', mat2)
print(mat2.sum())

assert bp.math.array_equal(mat1, mat2)

@@ -45,7 +55,7 @@ def test_random_fix_pre2():


def test_random_fix_pre3():
with pytest.raises(AssertionError):
with pytest.raises(bp.errors.ConnectorError):
conn1 = bp.connect.FixedPreNum(num=6, seed=1234)(pre_size=3, post_size=4)
conn1.require(bp.connect.CONN_MAT)

@@ -73,7 +83,7 @@ def test_random_fix_post2():


def test_random_fix_post3():
with pytest.raises(AssertionError):
with pytest.raises(bp.errors.ConnectorError):
conn1 = bp.connect.FixedPostNum(num=6, seed=1234)(pre_size=3, post_size=4)
conn1.require(bp.connect.CONN_MAT)



+81 -37 brainpy/connect/tests/test_regular_conn.py

@@ -1,52 +1,29 @@
# -*- coding: utf-8 -*-
import numpy as np

import brainpy as bp
from brainpy import connect

import unittest

def test_one2one():
for size in [100, (3, 4), (4, 5, 6)]:
conn = connect.One2One()(pre_size=size, post_size=size)

conn_mat, pre_ids, post_ids, pre2post, pre2syn, post2pre, post2syn = \
conn.require('conn_mat', 'pre_ids', 'post_ids', 'pre2post', 'pre2syn', 'post2pre', 'post2syn')

num = bp.tools.size2num(size)

actual_mat = bp.math.zeros((num, num), dtype=bp.math.bool_)
bp.math.fill_diagonal(actual_mat, True)

assert bp.math.array_equal(actual_mat, conn_mat)
assert bp.math.array_equal(pre_ids, bp.math.arange(num))
assert bp.math.array_equal(post_ids, bp.math.arange(num))

print()
print('conn_mat', conn_mat)
print('pre_ids', pre_ids)
print('post_ids', post_ids)
print('pre2post', pre2post)
print('post2pre', post2pre)
print('pre2syn', pre2syn)
print('post2syn', post2syn)


def test_all2all():
for has_self in [True, False]:
class TestOne2One(unittest.TestCase):
def test_one2one(self):
for size in [100, (3, 4), (4, 5, 6)]:
conn = connect.All2All(include_self=has_self)(pre_size=size, post_size=size)
mat = conn.require(connect.CONN_MAT)
conn = connect.One2One()(pre_size=size, post_size=size)

conn_mat, pre_ids, post_ids, pre2post, pre2syn, post2pre, post2syn = \
conn.require('conn_mat', 'pre_ids', 'post_ids', 'pre2post', 'pre2syn', 'post2pre', 'post2syn')

num = bp.tools.size2num(size)

print(mat)
actual_mat = bp.math.ones((num, num), dtype=bp.math.bool_)
if not has_self:
bp.math.fill_diagonal(actual_mat, False)
actual_mat = bp.math.zeros((num, num), dtype=bp.math.bool_)
bp.math.fill_diagonal(actual_mat, True)

assert bp.math.array_equal(actual_mat, mat)
assert bp.math.array_equal(actual_mat, conn_mat)
assert bp.math.array_equal(pre_ids, bp.math.arange(num))
assert bp.math.array_equal(post_ids, bp.math.arange(num))

print()
print('conn_mat', conn_mat)
print('pre_ids', pre_ids)
print('post_ids', post_ids)
@@ -56,5 +33,72 @@ def test_all2all():
print('post2syn', post2syn)


def test_grid_four():
pass
class TestAll2All(unittest.TestCase):
def test_all2all(self):
for has_self in [True, False]:
for size in [100, (3, 4), (4, 5, 6)]:
conn = connect.All2All(include_self=has_self)(pre_size=size, post_size=size)
mat = conn.require(connect.CONN_MAT)
conn_mat, pre_ids, post_ids, pre2post, pre2syn, post2pre, post2syn = \
conn.require('conn_mat', 'pre_ids', 'post_ids', 'pre2post', 'pre2syn', 'post2pre', 'post2syn')
num = bp.tools.size2num(size)

print(mat)
actual_mat = bp.math.ones((num, num), dtype=bp.math.bool_)
if not has_self:
bp.math.fill_diagonal(actual_mat, False)
assert bp.math.array_equal(actual_mat, mat)

print()
print('conn_mat', conn_mat)
print('pre_ids', pre_ids)
print('post_ids', post_ids)
print('pre2post', pre2post)
print('post2pre', post2pre)
print('pre2syn', pre2syn)
print('post2syn', post2syn)


class TestGridConn(unittest.TestCase):
def test_grid_four(self):
for periodic_boundary in [True, False]:
for include_self in [True, False]:
for size in (10, [10, 10], (4, 4, 5)):
conn = bp.conn.GridFour(include_self=include_self,
periodic_boundary=periodic_boundary)(size, size)
mat = conn.build_mat()
pre_ids, post_ids = conn.build_coo()
new_mat = bp.math.zeros((np.prod(size), np.prod(size)), dtype=bool)
new_mat[pre_ids, post_ids] = True

print(f'periodic_boundary = {periodic_boundary}, include_self = {include_self}, size = {size}')
self.assertTrue(bp.math.allclose(mat, new_mat))

def test_grid_eight(self):
for periodic_boundary in [True, False]:
for include_self in [True, False]:
for size in (10, [10, 10], (4, 4, 5)):
conn = bp.conn.GridEight(include_self=include_self,
periodic_boundary=periodic_boundary)(size, size)
mat = conn.build_mat()
pre_ids, post_ids = conn.build_coo()
new_mat = bp.math.zeros((np.prod(size), np.prod(size)), dtype=bool)
new_mat[pre_ids, post_ids] = True

print(f'periodic_boundary = {periodic_boundary}, include_self = {include_self}, size = {size}')
self.assertTrue(bp.math.allclose(mat, new_mat))

def test_grid_N(self):
for periodic_boundary in [True, False]:
for include_self in [True, False]:
for size in (10, [10, 10], (4, 4, 5)):
conn = bp.conn.GridN(include_self=include_self,
periodic_boundary=periodic_boundary,
N=2)(size, size)
mat = conn.build_mat()
pre_ids, post_ids = conn.build_coo()
new_mat = bp.math.zeros((np.prod(size), np.prod(size)), dtype=bool)
new_mat[pre_ids, post_ids] = True

print(f'periodic_boundary = {periodic_boundary}, include_self = {include_self}, size = {size}')
self.assertTrue(bp.math.allclose(mat, new_mat))

+167 -0 brainpy/datasets/vision/cifar.py

@@ -0,0 +1,167 @@
import os.path
import pickle
from typing import Any, Callable, Optional, Tuple

import numpy as np
from PIL import Image

from .utils import check_integrity, download_and_extract_archive
from .base import VisionDataset


class CIFAR10(VisionDataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

Args:
root (string): Root directory of dataset where directory
``cifar-10-batches-py`` exists or will be saved to if download is set to True.
train (bool, optional): If True, creates dataset from training set, otherwise
creates from test set.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.

"""

base_folder = "cifar-10-batches-py"
url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
filename = "cifar-10-python.tar.gz"
tgz_md5 = "c58f30108f718f92721af3b95e74349a"
train_list = [
["data_batch_1", "c99cafc152244af753f735de768cd75f"],
["data_batch_2", "d4bba439e000b95fd0a9bffe97cbabec"],
["data_batch_3", "54ebc095f3ab1f0389bbae665268c751"],
["data_batch_4", "634d18415352ddfa80567beed471001a"],
["data_batch_5", "482c414d41f54cd18b22e5b47cb7c3cb"],
]

test_list = [
["test_batch", "40351d587109b95175f43aff81a1287e"],
]
meta = {
"filename": "batches.meta",
"key": "label_names",
"md5": "5ff9c542aee3614f3951f8cda6e48888",
}

def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:

super().__init__(root, transform=transform, target_transform=target_transform)

self.train = train # training set or test set

if download:
self.download()

if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

if self.train:
downloaded_list = self.train_list
else:
downloaded_list = self.test_list

self.data: Any = []
self.targets = []

    # now load the pickled numpy arrays
for file_name, checksum in downloaded_list:
file_path = os.path.join(self.root, self.base_folder, file_name)
with open(file_path, "rb") as f:
entry = pickle.load(f, encoding="latin1")
self.data.append(entry["data"])
if "labels" in entry:
self.targets.extend(entry["labels"])
else:
self.targets.extend(entry["fine_labels"])

self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC

self._load_meta()

def _load_meta(self) -> None:
path = os.path.join(self.root, self.base_folder, self.meta["filename"])
if not check_integrity(path, self.meta["md5"]):
raise RuntimeError("Dataset metadata file not found or corrupted. You can use download=True to download it")
with open(path, "rb") as infile:
data = pickle.load(infile, encoding="latin1")
self.classes = data[self.meta["key"]]
self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}

def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index

Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], self.targets[index]

# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img)

if self.transform is not None:
img = self.transform(img)

if self.target_transform is not None:
target = self.target_transform(target)

return img, target

def __len__(self) -> int:
return len(self.data)

def _check_integrity(self) -> bool:
for filename, md5 in self.train_list + self.test_list:
fpath = os.path.join(self.root, self.base_folder, filename)
if not check_integrity(fpath, md5):
return False
return True

def download(self) -> None:
if self._check_integrity():
print("Files already downloaded and verified")
return
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)

def extra_repr(self) -> str:
split = "Train" if self.train is True else "Test"
return f"Split: {split}"


class CIFAR100(CIFAR10):
"""`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

This is a subclass of the `CIFAR10` Dataset.
"""

base_folder = "cifar-100-python"
url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
filename = "cifar-100-python.tar.gz"
tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
train_list = [
["train", "16019d7e3df5f24257cddd939b257f8d"],
]

test_list = [
["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"],
]
meta = {
"filename": "meta",
"key": "fine_label_names",
"md5": "7973b15100ade9c7d40fb424638fde48",
}
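A hedged usage sketch for the new CIFAR datasets; the import path mirrors the file location above (the class may also be re-exported elsewhere), and the root directory is just an example:

from brainpy.datasets.vision.cifar import CIFAR10, CIFAR100

# download CIFAR-10 into ./data and read one sample
train_set = CIFAR10(root='./data', train=True, download=True)
img, label = train_set[0]                  # PIL image and integer class index
print(len(train_set), train_set.classes[label])

test_set = CIFAR100(root='./data', train=False, download=True)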

+17 -13 brainpy/dyn/base.py

@@ -73,7 +73,7 @@ class DynamicalSystem(Base):

def __init__(
self,
name: str = None,
name: Optional[str] = None,
mode: Optional[Mode] = None,
):
# mode setting
@@ -182,7 +182,8 @@ class DynamicalSystem(Base):
elif delay.num_delay_step - 1 < max_delay_step:
self.global_delay_data[identifier][0].reset(delay_target, max_delay_step, initial_delay_data)
else:
self.global_delay_data[identifier] = (None, delay_target)
if identifier not in self.global_delay_data:
self.global_delay_data[identifier] = (None, delay_target)
self.register_implicit_nodes(self.local_delay_vars)
return delay_step

@@ -343,8 +344,7 @@ class DynamicalSystem(Base):
raise NoImplementationError('Subclass must implement offline_fit() function when using OfflineTrainer.')

def clear_input(self):
for node in self.nodes(level=1, include_self=False).subset(NeuGroup).unique().values():
node.clear_input()
pass


class Container(DynamicalSystem):
@@ -430,6 +430,10 @@ class Container(DynamicalSystem):
else:
return super(Container, self).__getattribute__(item)

def clear_input(self):
for node in self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values():
node.clear_input()


class Sequential(Container):
def __init__(
@@ -753,8 +757,8 @@ class SynConn(DynamicalSystem):

def __repr__(self):
names = self.__class__.__name__
return (f'{names}(name={self.name}, mode={self.mode}, '
f'{" " * len(names)} pre={self.pre}, '
return (f'{names}(name={self.name}, mode={self.mode}, \n'
f'{" " * len(names)} pre={self.pre}, \n'
f'{" " * len(names)} post={self.post})')

def check_pre_attrs(self, *attrs):
@@ -984,7 +988,7 @@ class TwoEndConn(SynConn):
ltp.register_master(master=self)
self.ltp: SynLTP = ltp

def init_weights(
def _init_weights(
self,
weight: Union[float, Array, Initializer, Callable],
comp_method: str,
@@ -992,7 +996,7 @@ class TwoEndConn(SynConn):
) -> Union[float, Array]:
if comp_method not in ['sparse', 'dense']:
raise ValueError(f'"comp_method" must be in "sparse" and "dense", but we got {comp_method}')
if sparse_data not in ['csr', 'ij']:
if sparse_data not in ['csr', 'ij', 'coo']:
raise ValueError(f'"sparse_data" must be in "csr" and "ij", but we got {sparse_data}')
if self.conn is None:
raise ValueError(f'Must provide "conn" when initialize the model {self.name}')
@@ -1010,11 +1014,11 @@ class TwoEndConn(SynConn):
if comp_method == 'sparse':
if sparse_data == 'csr':
conn_mask = self.conn.require('pre2post')
elif sparse_data == 'ij':
elif sparse_data in ['ij', 'coo']:
conn_mask = self.conn.require('post_ids', 'pre_ids')
else:
ValueError(f'Unknown sparse data type: {sparse_data}')
weight = parameter(weight, conn_mask[1].shape, allow_none=False)
weight = parameter(weight, conn_mask[0].shape, allow_none=False)
elif comp_method == 'dense':
weight = parameter(weight, (self.pre.num, self.post.num), allow_none=False)
conn_mask = self.conn.require('conn_mat')
@@ -1026,7 +1030,7 @@ class TwoEndConn(SynConn):
weight = bm.TrainVar(weight)
return weight, conn_mask

def syn2post_with_all2all(self, syn_value, syn_weight):
def _syn2post_with_all2all(self, syn_value, syn_weight):
if bm.ndim(syn_weight) == 0:
if isinstance(self.mode, BatchingMode):
post_vs = bm.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:])
@@ -1039,10 +1043,10 @@ class TwoEndConn(SynConn):
post_vs = syn_value @ syn_weight
return post_vs

def syn2post_with_one2one(self, syn_value, syn_weight):
def _syn2post_with_one2one(self, syn_value, syn_weight):
return syn_value * syn_weight

def syn2post_with_dense(self, syn_value, syn_weight, conn_mat):
def _syn2post_with_dense(self, syn_value, syn_weight, conn_mat):
if bm.ndim(syn_weight) == 0:
post_vs = (syn_weight * syn_value) @ conn_mat
else:


+2 -1 brainpy/dyn/layers/__init__.py

@@ -7,4 +7,5 @@ from .reservoir import *
from .rnncells import *
from .conv import *
from .normalization import *
from .pooling import *
from .pooling import *
from .activate import *

+36 -0 brainpy/dyn/layers/activate.py

@@ -0,0 +1,36 @@
from typing import Callable
from typing import Optional

from brainpy.dyn.base import DynamicalSystem
from brainpy.modes import Mode, training


class Activation(DynamicalSystem):
r"""Applies an activation function to the inputs

Parameters:
----------
  activate_fun: Callable, function
    The activation function to apply.
name: str, Optional
The name of the object
  mode: Mode
    The computation mode of the object (default: ``training``).
"""

def __init__(
self,
activate_fun: Callable,
name: Optional[str] = None,
mode: Mode = training,
**kwargs,
):
super().__init__(name, mode)
self.activate_fun = activate_fun
self.kwargs = kwargs

def update(self, sha, x):
return self.activate_fun(x, **self.kwargs)

def reset_state(self, batch_size=None):
pass
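A minimal sketch of the new Activation layer (the module path bp.dyn.layers and the choice of bm.tanh are assumptions for illustration; `sha` is the shared-argument dict passed by the runners and is ignored by this layer):

import brainpy as bp
import brainpy.math as bm

act = bp.dyn.layers.Activation(bm.tanh)    # wrap any elementwise function
y = act.update(dict(), bm.ones((4, 10)))   # same shape as the input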

+327 -130 brainpy/dyn/layers/conv.py

@@ -1,37 +1,38 @@
# -*- coding: utf-8 -*-


import jax.lax
from jax import lax
from typing import Union, Tuple, Optional, Sequence

import brainpy.math as bm
from brainpy import math as bm, tools
from brainpy.dyn.base import DynamicalSystem
from brainpy.initialize import XavierNormal, ZeroInit, parameter
from brainpy.modes import Mode, TrainingMode, NormalMode, training, check
from brainpy.initialize import Initializer, XavierNormal, ZeroInit, parameter
from brainpy.modes import Mode, TrainingMode, training
from brainpy.types import Array

__all__ = [
'GeneralConv',
'Conv1D',
'Conv2D',
'Conv3D'
]


def _check_tuple(v):
if isinstance(v, (tuple, list)):
return tuple(v)
elif isinstance(v, int):
return (v, v)
def to_dimension_numbers(num_spatial_dims: int, channels_last: bool, transpose: bool) -> lax.ConvDimensionNumbers:
"""Create a `lax.ConvDimensionNumbers` for the given inputs."""
num_dims = num_spatial_dims + 2
if channels_last:
spatial_dims = tuple(range(1, num_dims - 1))
image_dn = (0, num_dims - 1) + spatial_dims
else:
raise ValueError


def _conv_dimension_numbers(input_shape):
"""Computes the dimension numbers based on the input shape."""
ndim = len(input_shape)
lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1))
rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2))
out_spec = lhs_spec
return jax.lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
spatial_dims = tuple(range(2, num_dims))
image_dn = (0, 1) + spatial_dims
if transpose:
kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2))
else:
kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2))
return lax.ConvDimensionNumbers(lhs_spec=image_dn,
rhs_spec=kernel_dn,
out_spec=image_dn)


class GeneralConv(DynamicalSystem):
@@ -39,182 +40,378 @@ class GeneralConv(DynamicalSystem):

Parameters
----------
in_channels: integer
number of input channels.
out_channels: integer
number of output channels.
kernel_size: sequence[int]
shape of the convolutional kernel. For 1D convolution,
the kernel size can be passed as an integer. For all other cases, it must
be a sequence of integers.
strides: sequence[int]
an integer or a sequence of `n` integers, representing the inter-window strides (default: 1).
padding: str, sequence[int]
either the string `'SAME'`, the string `'VALID'`, the string
`'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low,
num_spatial_dims: int
The number of spatial dimensions of the input.
in_channels: int
The number of input channels.
out_channels: int
The number of output channels.
kernel_size: int, sequence of int
The shape of the convolutional kernel.
For 1D convolution, the kernel size can be passed as an integer.
For all other cases, it must be a sequence of integers.
strides: int, sequence of int
An integer or a sequence of `n` integers, representing the inter-window strides (default: 1).
padding: str, sequence of int, sequence of tuple
Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low,
high)` integer pairs that give the padding to apply before and after each
      spatial dimension. A single int is interpreted as applying the same padding
      in all dims, and passing a single int in a sequence causes the same padding
      to be used on both sides.
input_dilation: integer, sequence[int]
an integer or a sequence of `n` integers, giving the
spatial dimension.
lhs_dilation: int, sequence of int
An integer or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of `inputs`
(default: 1). Convolution with input dilation `d` is equivalent to
transposed convolution with stride `d`.
kernel_dilation: integer, sequence[int]
an integer or a sequence of `n` integers, giving the
rhs_dilation: int, sequence of int
An integer or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of the convolution
kernel (default: 1). Convolution with kernel dilation
is also known as 'atrous convolution'.
groups: integer, default 1.
If specified divides the input
features into groups.
w_init: brainpy.init.Initializer
initializer for the convolutional kernel.
b_init: brainpy.init.Initializer
initializer for the bias.
groups: int
If specified, divides the input features into groups. default 1.
w_init: Initializer
The initializer for the convolutional kernel.
b_init: Initializer
The initializer for the bias.
mask: Array, Optional
The optional mask of the weights.
mode: Mode
    The computation mode of the current object. Default is `training`.
name: str, Optional
The name of the object.
"""

def __init__(
self,
in_channels,
out_channels,
kernel_size,
strides=None,
padding='SAME',
input_dilation=None,
kernel_dilation=None,
groups=1,
w_init=XavierNormal(),
b_init=ZeroInit(),
num_spatial_dims: int,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, ...]],
strides: Union[int, Tuple[int, ...]] = 1,
padding: Union[str, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
lhs_dilation: Union[int, Tuple[int, ...]] = 1,
rhs_dilation: Union[int, Tuple[int, ...]] = 1,
groups: int = 1,
w_init: Initializer = XavierNormal(),
b_init: Initializer = ZeroInit(),
mask: Optional[Array] = None,
mode: Mode = training,
name: str = None,
):
super(GeneralConv, self).__init__(name=name, mode=mode)

self.num_spatial_dims = num_spatial_dims
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.strides = strides
self.padding = padding
self.input_dilation = input_dilation
self.kernel_dilation = kernel_dilation
self.strides = tools.replicate(strides, num_spatial_dims, 'strides')
self.kernel_size = tools.replicate(kernel_size, num_spatial_dims, 'kernel_size')
self.lhs_dilation = tools.replicate(lhs_dilation, num_spatial_dims, 'lhs_dilation')
self.rhs_dilation = tools.replicate(rhs_dilation, num_spatial_dims, 'rhs_dilation')
self.groups = groups
self.w_init = w_init
self.b_init = b_init
self.dimension_numbers = None
self.mask = mask
self.dimension_numbers = to_dimension_numbers(num_spatial_dims, channels_last=True, transpose=False)

if isinstance(padding, str):
assert padding in ['SAME', 'VALID']
elif isinstance(padding, tuple):
for k in padding:
assert isinstance(k, int)
elif isinstance(padding, (tuple, list)):
if isinstance(padding[0], int):
padding = (padding,) * num_spatial_dims
elif isinstance(padding[0], (tuple, list)):
if len(padding) == 1:
padding = tuple(padding) * num_spatial_dims
else:
if len(padding) != num_spatial_dims:
raise ValueError(f"Padding {padding} must be a Tuple[int, int], "
f"or sequence of Tuple[int, int] with length 1, "
f"or sequence of Tuple[int, int] length {num_spatial_dims}.")
padding = tuple(padding)
else:
raise ValueError
self.padding = padding

assert out_channels % self.groups == 0, '"nout" should be divisible by groups'
assert self.out_channels % self.groups == 0, '"out_channels" should be divisible by groups'
assert self.in_channels % self.groups == 0, '"in_channels" should be divisible by groups'

assert self.in_channels % self.groups == 0, '"nin" should be divisible by groups'
kernel_shape = _check_tuple(self.kernel_size) + (self.in_channels // self.groups, self.out_channels)
kernel_shape = tuple(self.kernel_size) + (self.in_channels // self.groups, self.out_channels)
bias_shape = (1,) * len(self.kernel_size) + (self.out_channels,)
self.w = parameter(self.w_init, kernel_shape)
self.b = parameter(self.b_init, (1,) * len(self.kernel_size) + (self.out_channels,))
self.b = parameter(self.b_init, bias_shape)
if isinstance(self.mode, TrainingMode):
self.w = bm.TrainVar(self.w)
self.b = bm.TrainVar(self.b)

def _check_input_dim(self, x):
pass
raise NotImplementedError

def update(self, sha, x):
self._check_input_dim(x)
if self.strides is None:
self.strides = (1,) * (len(x.shape) - 2)
y = jax.lax.conv_general_dilated(lhs=x.value if isinstance(x, bm.JaxArray) else x,
rhs=self.w.value,
window_strides=self.strides,
padding=self.padding,
lhs_dilation=self.input_dilation,
rhs_dilation=self.kernel_dilation,
feature_group_count=self.groups,
dimension_numbers=self.dimension_numbers)
w = self.w.value
if self.mask is not None:
if self.mask.shape != self.w.shape:
raise ValueError(f"Mask needs to have the same shape as weights. {self.mask.shape} != {self.w.shape}")
w *= self.mask
y = lax.conv_general_dilated(lhs=bm.as_jax(x),
rhs=bm.as_jax(w),
window_strides=self.strides,
padding=self.padding,
lhs_dilation=self.lhs_dilation,
rhs_dilation=self.rhs_dilation,
feature_group_count=self.groups,
dimension_numbers=self.dimension_numbers)
if self.b is None:
return y
return y + self.b.value
else:
return y + self.b.value

def reset_state(self, batch_size=None):
pass


class Conv1D(GeneralConv):
"""One-dimensional convolution.

Parameters
----------
in_channels: int
The number of input channels.
out_channels: int
The number of output channels.
kernel_size: int, sequence of int
The shape of the convolutional kernel.
For 1D convolution, the kernel size can be passed as an integer.
For all other cases, it must be a sequence of integers.
strides: int, sequence of int
An integer or a sequence of `n` integers, representing the inter-window strides (default: 1).
padding: str, sequence of int, sequence of tuple
Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low,
high)` integer pairs that give the padding to apply before and after each
spatial dimension.
lhs_dilation: int, sequence of int
An integer or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of `inputs`
(default: 1). Convolution with input dilation `d` is equivalent to
transposed convolution with stride `d`.
rhs_dilation: int, sequence of int
An integer or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of the convolution
kernel (default: 1). Convolution with kernel dilation
is also known as 'atrous convolution'.
groups: int
If specified, divides the input features into groups. default 1.
w_init: Initializer
The initializer for the convolutional kernel.
b_init: Initializer
The initializer for the bias.
mask: Array, Optional
The optional mask of the weights.
mode: Mode
    The computation mode of the current object. Default is `training`.
name: str, Optional
The name of the object.

"""

def __init__(
self,
in_channels,
out_channels,
kernel_size,
**kwargs
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, ...]],
strides: Union[int, Tuple[int, ...]] = 1,
padding: Union[str, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
lhs_dilation: Union[int, Tuple[int, ...]] = 1,
rhs_dilation: Union[int, Tuple[int, ...]] = 1,
groups: int = 1,
w_init: Initializer = XavierNormal(),
b_init: Initializer = ZeroInit(),
mask: Optional[Array] = None,
mode: Mode = training,
name: str = None,
):
super(Conv1D, self).__init__(in_channels, out_channels, kernel_size, **kwargs)

self.dimension_numbers = ('NWC', 'WIO', 'NWC')
super(Conv1D, self).__init__(num_spatial_dims=1,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
strides=strides,
padding=padding,
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
groups=groups,
w_init=w_init,
b_init=b_init,
mask=mask,
mode=mode,
name=name)

def _check_input_dim(self, x):
ndim = len(x.shape)
if ndim != 3:
raise ValueError(
"expected 3D input (got {}D input)".format(ndim)
)
if x.ndim != 3:
raise ValueError(f"expected 3D input (got {x.ndim}D input)")
if self.in_channels != x.shape[-1]:
raise ValueError(
f"input channels={x.shape[-1]} needs to have the same size as in_channels={self.in_channels}."
)
assert len(self.kernel_size) == 1, "expected 1D kernel size (got {}D input)".format(self.kernel_size)
raise ValueError(f"input channels={x.shape[-1]} needs to have "
f"the same size as in_channels={self.in_channels}.")


class Conv2D(GeneralConv):
"""Two-dimensional convolution.

Parameters
----------
in_channels: int
The number of input channels.
out_channels: int
The number of output channels.
kernel_size: int, sequence of int
The shape of the convolutional kernel.
For 1D convolution, the kernel size can be passed as an integer.
For all other cases, it must be a sequence of integers.
strides: int, sequence of int
An integer or a sequence of `n` integers, representing the inter-window strides (default: 1).
padding: str, sequence of int, sequence of tuple
Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low,
high)` integer pairs that give the padding to apply before and after each
spatial dimension.
lhs_dilation: int, sequence of int
An integer or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of `inputs`
(default: 1). Convolution with input dilation `d` is equivalent to
transposed convolution with stride `d`.
rhs_dilation: int, sequence of int
An integer or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of the convolution
kernel (default: 1). Convolution with kernel dilation
is also known as 'atrous convolution'.
groups: int
If specified, divides the input features into groups. default 1.
w_init: Initializer
The initializer for the convolutional kernel.
b_init: Initializer
The initializer for the bias.
mask: Array, Optional
The optional mask of the weights.
mode: Mode
    The computation mode of the current object. Default is `training`.
name: str, Optional
The name of the object.

"""

def __init__(
self,
in_channels,
out_channels,
kernel_size,
**kwargs
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, ...]],
strides: Union[int, Tuple[int, ...]] = 1,
padding: Union[str, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
lhs_dilation: Union[int, Tuple[int, ...]] = 1,
rhs_dilation: Union[int, Tuple[int, ...]] = 1,
groups: int = 1,
w_init: Initializer = XavierNormal(),
b_init: Initializer = ZeroInit(),
mask: Optional[Array] = None,
mode: Mode = training,
name: str = None,
):
super(Conv2D, self).__init__(in_channels, out_channels, kernel_size, **kwargs)

self.dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
super(Conv2D, self).__init__(num_spatial_dims=2,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
strides=strides,
padding=padding,
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
groups=groups,
w_init=w_init,
b_init=b_init,
mask=mask,
mode=mode,
name=name)

def _check_input_dim(self, x):
ndim = len(x.shape)
if ndim != 4:
raise ValueError(
"expected 4D input (got {}D input)".format(ndim)
)
if x.ndim != 4:
raise ValueError(f"expected 4D input (got {x.ndim}D input)")
if self.in_channels != x.shape[-1]:
raise ValueError(
f"input channels={x.shape[-1]} needs to have the same size as in_channels={self.in_channels}."
)
assert len(self.kernel_size) == 2, "expected 2D kernel size (got {}D input)".format(self.kernel_size)
raise ValueError(f"input channels={x.shape[-1]} needs to have "
f"the same size as in_channels={self.in_channels}.")


class Conv3D(GeneralConv):
"""Three-dimensional convolution.

Parameters
----------
in_channels: int
The number of input channels.
out_channels: int
The number of output channels.
kernel_size: int, sequence of int
The shape of the convolutional kernel.
For 1D convolution, the kernel size can be passed as an integer.
For all other cases, it must be a sequence of integers.
strides: int, sequence of int
An integer or a sequence of `n` integers, representing the inter-window strides (default: 1).
padding: str, sequence of int, sequence of tuple
Either the string `'SAME'`, the string `'VALID'`, or a sequence of n `(low,
high)` integer pairs that give the padding to apply before and after each
spatial dimension.
lhs_dilation: int, sequence of int
An integer or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of `inputs`
(default: 1). Convolution with input dilation `d` is equivalent to
transposed convolution with stride `d`.
rhs_dilation: int, sequence of int
An integer or a sequence of `n` integers, giving the
dilation factor to apply in each spatial dimension of the convolution
kernel (default: 1). Convolution with kernel dilation
is also known as 'atrous convolution'.
groups: int
If specified, divides the input features into groups. default 1.
w_init: Initializer
The initializer for the convolutional kernel.
b_init: Initializer
The initializer for the bias.
mask: Array, Optional
The optional mask of the weights.
mode: Mode
    The computation mode of the current object. Default is `training`.
name: str, Optional
The name of the object.

"""

def __init__(
self,
in_channels,
out_channels,
kernel_size,
**kwargs
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, ...]],
strides: Union[int, Tuple[int, ...]] = 1,
padding: Union[str, Tuple[int, int], Sequence[Tuple[int, int]]] = 'SAME',
lhs_dilation: Union[int, Tuple[int, ...]] = 1,
rhs_dilation: Union[int, Tuple[int, ...]] = 1,
groups: int = 1,
w_init: Initializer = XavierNormal(),
b_init: Initializer = ZeroInit(),
mask: Optional[Array] = None,
mode: Mode = training,
name: str = None,
):
super(Conv3D, self).__init__(in_channels, out_channels, kernel_size, **kwargs)

self.dimension_numbers = ('NHWDC', 'HWDIO', 'NHWDC')
super(Conv3D, self).__init__(num_spatial_dims=3,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
strides=strides,
padding=padding,
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
groups=groups,
w_init=w_init,
b_init=b_init,
mask=mask,
mode=mode,
name=name)

def _check_input_dim(self, x):
ndim = len(x.shape)
if ndim != 5:
raise ValueError(
"expected 5D input (got {}D input)".format(ndim)
)
if x.ndim != 5:
raise ValueError(f"expected 5D input (got {x.ndim}D input)")
if self.in_channels != x.shape[-1]:
raise ValueError(
f"input channels={x.shape[-1]} needs to have the same size as in_channels={self.in_channels}."
)
assert len(self.kernel_size) == 3, "expected 3D kernel size (got {}D input)".format(self.kernel_size)
raise ValueError(f"input channels={x.shape[-1]} needs to have "
f"the same size as in_channels={self.in_channels}.")

+4 -6 brainpy/dyn/layers/dropout.py

@@ -15,11 +15,7 @@ class Dropout(DynamicalSystem):
  In training, to compensate for the fraction of input values dropped (`1 - prob`),
  all surviving values are multiplied by `1 / prob`.

The parameter `shared_axes` allows to specify a list of axes on which
the mask will be shared: we will use size 1 on those axes for dropout mask
and broadcast it. Sharing reduces randomness, but can save memory.

This layer is active only during training (`mode='train'`). In other
This layer is active only during training (`mode=brainpy.modes.training`). In other
circumstances it is a no-op.

Parameters
@@ -28,6 +24,8 @@ class Dropout(DynamicalSystem):
    The probability of keeping each element of the tensor.
seed : optional, int
The random sampling seed.
mode: Mode
The computation mode of the object.
name : str, optional
The name of the dynamic system.

@@ -47,7 +45,7 @@ class Dropout(DynamicalSystem):
):
super(Dropout, self).__init__(mode=mode, name=name)
self.prob = prob
self.rng = bm.random.RandomState(seed=seed)
self.rng = bm.random.RandomState(seed)

def update(self, sha, x):
if sha.get('fit', True):


+30 -1 brainpy/dyn/layers/linear.py

@@ -9,12 +9,13 @@ from brainpy import math as bm
from brainpy.dyn.base import DynamicalSystem
from brainpy.errors import MathError
from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter
from brainpy.modes import Mode, TrainingMode, training
from brainpy.modes import Mode, TrainingMode, BatchingMode, training, batching
from brainpy.tools.checking import check_initializer
from brainpy.types import Array

__all__ = [
'Dense',
'Flatten'
]


@@ -188,3 +189,31 @@ class Dense(DynamicalSystem):
bias, Wff = bm.split(weights, [1])
self.W.value = Wff
self.b.value = bias[0]


class Flatten(DynamicalSystem):
r"""Flattens a contiguous range of dims into 2D or 1D.

Parameters
----------
name: str, Optional
The name of the object
mode: Mode
The computation mode of the object. Default is `batching`.
"""

def __init__(
self,
name: Optional[str] = None,
mode: Optional[Mode] = batching,
):
super().__init__(name, mode)

def update(self, shr, x):
if isinstance(self.mode, BatchingMode):
return x.reshape((x.shape[0], -1))
else:
return x.flatten()

def reset_state(self, batch_size=None):
pass
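
# A minimal sketch of the new Flatten layer, which keeps the batch axis in batching mode
# and flattens everything else (assuming it is exported as `bp.dyn.layers.Flatten`):
import brainpy as bp
import brainpy.math as bm

flat = bp.dyn.layers.Flatten()    # default mode: batching
x = bm.ones((8, 4, 4, 3))
y = flat(None, x)                 # -> shape (8, 48)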

+ 378
- 278
brainpy/dyn/layers/normalization.py

@@ -1,262 +1,374 @@
# -*- coding: utf-8 -*-

from typing import Union
from typing import Union, Optional, Sequence

import jax.nn
import jax.numpy as jnp
import jax.lax
from jax import lax, numpy as jnp

import brainpy.math as bm
from brainpy.initialize import ZeroInit, OneInit, Initializer, parameter
from brainpy.dyn.base import DynamicalSystem
from brainpy.modes import Mode, TrainingMode, NormalMode, training, check
from brainpy.initialize import ZeroInit, OneInit, Initializer, parameter
from brainpy.modes import Mode, TrainingMode, training

__all__ = [
'BatchNorm',
'BatchNorm1d',
'BatchNorm2d',
'BatchNorm3d',
'GroupNorm',
'BatchNorm1D',
'BatchNorm2D',
'BatchNorm3D',

'LayerNorm',
'GroupNorm',
'InstanceNorm',
]


def _abs_sq(x):
"""Computes the elementwise square of the absolute value |x|^2."""
if jnp.iscomplexobj(x):
return lax.square(lax.real(x)) + lax.square(lax.imag(x))
else:
return lax.square(x)


class BatchNorm(DynamicalSystem):
"""Batch Normalization node.
"""Batch Normalization layer.

This layer aims to reduce the internal covariate shift of data. It
normalizes a batch of data by fixing the mean and variance of inputs
on each feature (channel). Most commonly, the first axis of the data
is the batch, and the last is the channel. However, users can specify
the axes to be normalized.

adapted from jax.example_libraries.stax.BatchNorm
https://jax.readthedocs.io/en/latest/_modules/jax/example_libraries/stax.html#BatchNorm

Parameters
----------
num_features: int
``C`` from an expected input of size ``(..., C)``.
axis: int, tuple, list
axes where the data will be normalized. The feature (channel) axis should be excluded.
Axes where the data will be normalized. The feature (channel) axis should be excluded.
epsilon: float
a value added to the denominator for numerical stability. Default: 1e-5
use_bias: bool
whether to add a learnable translation (bias) to the normalized data. Default: True
use_scale: bool
whether to apply a learnable scaling to the normalized data. Default: True
beta_init: brainpy.init.Initializer
an initializer generating the original translation matrix
gamma_init: brainpy.init.Initializer
an initializer generating the original scaling matrix
A value added to the denominator for numerical stability. Default: 1e-5
affine: bool
A boolean value that when set to ``True``, this module has
learnable affine parameters. Default: ``True``
bias_init: Initializer
An initializer generating the original translation matrix
scale_init: Initializer
An initializer generating the original scaling matrix
"""

def __init__(self,
axis: Union[int, tuple, list],
epsilon: float = 1e-5,
use_bias: bool = True,
use_scale: bool = True,
beta_init: Initializer = ZeroInit(),
gamma_init: Initializer = OneInit(),
mode: Mode = training,
name: str = None,
**kwargs):
def __init__(
self,
num_features: int,
axis: Union[int, Sequence[int]],
epsilon: float = 1e-5,
momentum: Optional[float] = 0.99,
affine: bool = True,
bias_init: Initializer = ZeroInit(),
scale_init: Initializer = OneInit(),
mode: Mode = training,
name: str = None,
):
super(BatchNorm, self).__init__(name=name, mode=mode)

# parameters
self.num_features = num_features
self.epsilon = epsilon
self.bias = use_bias
self.scale = use_scale
self.beta_init = beta_init if use_bias else ()
self.gamma_init = gamma_init if use_scale else ()
self.momentum = momentum
self.affine = affine
self.bias_init = bias_init
self.scale_init = scale_init
self.axis = (axis,) if jnp.isscalar(axis) else axis

# variables
self.running_mean = bm.Variable(bm.zeros(self.num_features))
self.running_var = bm.Variable(bm.ones(self.num_features))
if self.affine:
assert isinstance(self.mode, TrainingMode)
self.bias = bm.TrainVar(parameter(self.bias_init, self.num_features))
self.scale = bm.TrainVar(parameter(self.scale_init, self.num_features))

def _check_input_dim(self, x):
pass
raise NotImplementedError

def update(self, sha, x):
self._check_input_dim(x)

input_shape = tuple(d for i, d in enumerate(x.shape) if i not in self.axis)
self.beta = parameter(self.beta_init, input_shape) if self.bias else None
self.gamma = parameter(self.gamma_init, input_shape) if self.scale else None
if isinstance(self.mode, TrainingMode):
self.beta = bm.TrainVar(self.beta)
self.gamma = bm.TrainVar(self.gamma)

ed = tuple(None if i in self.axis else slice(None) for i in range(jnp.ndim(x)))
# output = bm.normalize(x, self.axis, epsilon=self.epsilon)
print(x)
output = jax.nn.standardize(x.value, self.axis, epsilon=self.epsilon)
print(output)
if self.bias and self.scale: return self.gamma[ed] * output + self.beta[ed]
if self.bias: return output + self.beta[ed]
if self.scale: return self.gamma[ed] * output
return output
if sha['fit']:
mean = bm.mean(x, self.axis)
mean2 = bm.mean(_abs_sq(x), self.axis)
var = jnp.maximum(0., mean2 - _abs_sq(mean))
self.running_mean.value = (self.momentum * self.running_mean.value +
(1 - self.momentum) * mean)
self.running_var.value = (self.momentum * self.running_var.value +
(1 - self.momentum) * var)
else:
mean = self.running_mean.value
var = self.running_var.value
stats_shape = [(1 if i in self.axis else x.shape[i]) for i in range(x.ndim)]
mean = mean.reshape(stats_shape)
var = var.reshape(stats_shape)

y = x - mean
mul = lax.rsqrt(var + lax.convert_element_type(self.epsilon, x.dtype))
if self.affine:
mul *= self.scale
y *= mul
if self.affine:
y += self.bias
return y

def reset_state(self, batch_size=None):
pass


class BatchNorm1d(BatchNorm):
class BatchNorm1D(BatchNorm):
"""1-D batch normalization.

The data should be of `(b, l, c)`, where `b` is the batch dimension,
`l` is the layer dimension, and `c` is the channel dimension, or of
'(b, c)'.
`l` is the layer dimension, and `c` is the channel dimension.

Parameters
----------
num_features: int
``C`` from an expected input of size ``(B, L, C)``.
axis: int, tuple, list
axes where the data will be normalized. The feature (channel) axis should be excluded.
epsilon: float
a value added to the denominator for numerical stability. Default: 1e-5
use_bias: bool
whether to add a learnable translation (bias) to the normalized data. Default: True
use_scale: bool
whether to apply a learnable scaling to the normalized data. Default: True
beta_init: brainpy.init.Initializer
A value added to the denominator for numerical stability. Default: 1e-5
affine: bool
A boolean value that when set to ``True``, this module has
learnable affine parameters. Default: ``True``
bias_init: Initializer
an initializer generating the original translation matrix
gamma_init: brainpy.init.Initializer
scale_init: Initializer
an initializer generating the original scaling matrix
"""
def __init__(self, axis=(0, 1), **kwargs):
super(BatchNorm1d, self).__init__(axis=axis, **kwargs)

def __init__(
self,
num_features: int,
axis: Union[int, Sequence[int]] = (0, 1),
epsilon: float = 1e-5,
momentum: Optional[float] = 0.99,
affine: bool = True,
bias_init: Initializer = ZeroInit(),
scale_init: Initializer = OneInit(),
mode: Mode = training,
name: str = None,
):
super(BatchNorm1D, self).__init__(num_features=num_features,
axis=axis,
epsilon=epsilon,
momentum=momentum,
affine=affine,
bias_init=bias_init,
scale_init=scale_init,
mode=mode,
name=name)

def _check_input_dim(self, x):
ndim = len(x.shape)
if ndim != 2 and ndim != 3:
raise ValueError(
"expected 2D or 3D input (got {}D input)".format(ndim)
)
if ndim == 2 and len(self.axis) == 2:
self.axis = (0,)
if x.ndim != 3:
raise ValueError(f"expected 3D input (got {x.ndim}D input)")
assert x.shape[-1] == self.num_features
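
# A minimal sketch of the refactored BatchNorm, which now tracks running statistics and
# therefore behaves differently during fitting and evaluation (hypothetical values):
import brainpy as bp
import brainpy.math as bm

bn = bp.dyn.layers.BatchNorm1D(num_features=16)
x = bm.random.randn(32, 100, 16)   # (batch, length, channel)
y_fit = bn({'fit': True}, x)       # batch statistics; updates running_mean / running_var
y_eval = bn({'fit': False}, x)     # uses the stored running statistics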


class BatchNorm2d(BatchNorm):
class BatchNorm2D(BatchNorm):
"""2-D batch normalization.
The data should be of `(b, h, w, c)`, where `b` is the batch dimension,
`h` is the height dimension, `w` is the width dimension, and `c` is the
channel dimension.

Parameters
----------
axis: int, tuple, list
axes where the data will be normalized. The feature (channel) axis should be excluded.
epsilon: float
a value added to the denominator for numerical stability. Default: 1e-5
use_bias: bool
whether to add a learnable translation (bias) to the normalized data. Default: True
use_scale: bool
whether to apply a learnable scaling to the normalized data. Default: True
beta_init: brainpy.init.Initializer
an initializer generating the original translation matrix
gamma_init: brainpy.init.Initializer
an initializer generating the original scaling matrix
"""
def __init__(self, axis=(0, 1, 2), **kwargs):
super(BatchNorm2d, self).__init__(axis=axis, **kwargs)

The data should be of `(b, h, w, c)`, where `b` is the batch dimension,
`h` is the height dimension, `w` is the width dimension, and `c` is the
channel dimension.

Parameters
----------
num_features: int
``C`` from an expected input of size ``(B, H, W, C)``.
axis: int, tuple, list
axes where the data will be normalized. The feature (channel) axis should be excluded.
epsilon: float
a value added to the denominator for numerical stability. Default: 1e-5
affine: bool
A boolean value that when set to ``True``, this module has
learnable affine parameters. Default: ``True``
bias_init: Initializer
an initializer generating the original translation matrix
scale_init: Initializer
an initializer generating the original scaling matrix
"""

def __init__(
self,
num_features: int,
axis: Union[int, Sequence[int]] = (0, 1, 2),
epsilon: float = 1e-5,
momentum: Optional[float] = 0.99,
affine: bool = True,
bias_init: Initializer = ZeroInit(),
scale_init: Initializer = OneInit(),
mode: Mode = training,
name: str = None,
):
super(BatchNorm2D, self).__init__(num_features=num_features,
axis=axis,
epsilon=epsilon,
momentum=momentum,
affine=affine,
bias_init=bias_init,
scale_init=scale_init,
mode=mode,
name=name)

def _check_input_dim(self, x):
ndim = len(x.shape)
if ndim != 4:
raise ValueError(
"expected 4D input (got {}D input)".format(ndim)
)
if x.ndim != 4:
raise ValueError(f"expected 4D input (got {x.ndim}D input)")
assert x.shape[-1] == self.num_features


class BatchNorm3d(BatchNorm):
class BatchNorm3D(BatchNorm):
"""3-D batch normalization.
The data should be of `(b, h, w, d, c)`, where `b` is the batch dimension,
`h` is the height dimension, `w` is the width dimension, `d` is the depth
dimension, and `c` is the channel dimension.

Parameters
----------
axis: int, tuple, list
axes where the data will be normalized. The feature (channel) axis should be excluded.
epsilon: float
a value added to the denominator for numerical stability. Default: 1e-5
use_bias: bool
whether to add a learnable translation (bias) to the normalized data. Default: True
use_scale: bool
whether to apply a learnable scaling to the normalized data. Default: True
beta_init: brainpy.init.Initializer
an initializer generating the original translation matrix
gamma_init: brainpy.init.Initializer
an initializer generating the original scaling matrix
"""
def __init__(self, axis=(0, 1, 2, 3), **kwargs):
super(BatchNorm3d, self).__init__(axis=axis, **kwargs)

The data should be of `(b, h, w, d, c)`, where `b` is the batch dimension,
`h` is the height dimension, `w` is the width dimension, `d` is the depth
dimension, and `c` is the channel dimension.

Parameters
----------
num_features: int
``C`` from an expected input of size ``(B, H, W, D, C)``.
axis: int, tuple, list
axes where the data will be normalized. The feature (channel) axis should be excluded.
epsilon: float
a value added to the denominator for numerical stability. Default: 1e-5
affine: bool
A boolean value that when set to ``True``, this module has
learnable affine parameters. Default: ``True``
bias_init: Initializer
an initializer generating the original translation matrix
scale_init: Initializer
an initializer generating the original scaling matrix
"""

def __init__(
self,
num_features: int,
axis: Union[int, Sequence[int]] = (0, 1, 2, 3),
epsilon: float = 1e-5,
momentum: Optional[float] = 0.99,
affine: bool = True,
bias_init: Initializer = ZeroInit(),
scale_init: Initializer = OneInit(),
mode: Mode = training,
name: str = None,
):
super(BatchNorm3D, self).__init__(num_features=num_features,
axis=axis,
epsilon=epsilon,
momentum=momentum,
affine=affine,
bias_init=bias_init,
scale_init=scale_init,
mode=mode,
name=name)

def _check_input_dim(self, x):
ndim = len(x.shape)
if ndim != 5:
raise ValueError(
"expected 5D input (got {}D input)".format(ndim)
)
if x.ndim != 5:
raise ValueError(f"expected 5D input (got {x.ndim}D input)")
assert x.shape[-1] == self.num_features


class LayerNorm(DynamicalSystem):
"""Layer normalization (https://arxiv.org/abs/1607.06450).

.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

This layer normalizes data on each example, independently of the batch. More
specifically, it normalizes data of shape (b, d1, d2, ..., c) on the axes of
the data dimensions and the channel (d1, d2, ..., c). Different from batch
normalization, gamma and beta are assigned to each position (elementwise
normalization, scale and bias are assigned to each position (elementwise
operation) instead of the whole channel. If users want to assign a single
gamma and beta to a whole example/whole channel, please use GroupNorm/
scale and bias to a whole example/whole channel, please use GroupNorm/
InstanceNorm.

Parameters
----------
normalized_shape: int, sequence of int
The input shape from an expected input of size

.. math::
[* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
\times \ldots \times \text{normalized\_shape}[-1]]

If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
epsilon: float
a value added to the denominator for numerical stability. Default: 1e-5
use_bias: bool
whether to add a learnable translation (bias) to the normalized data. Default: True
use_scale: bool
whether to apply a learnable scaling to the normalized data. Default: True
beta_init: brainpy.init.Initializer
bias_init: Initializer
an initializer generating the original translation matrix
gamma_init: brainpy.init.Initializer
scale_init: Initializer
an initializer generating the original scaling matrix
axis: int, tuple, list
axes where the data will be normalized. The batch axis should be excluded.
elementwise_affine: bool
A boolean value that when set to ``True``, this module
has learnable per-element affine parameters initialized to ones (for weights)
and zeros (for biases). Default: ``True``.

Examples
--------
>>> import brainpy as bp
>>> import brainpy.math as bm
>>>
>>> # NLP Example
>>> batch, sentence_length, embedding_dim = 20, 5, 10
>>> embedding = bm.random.randn(batch, sentence_length, embedding_dim)
>>> layer_norm = bp.layers.LayerNorm(embedding_dim)
>>> # Activate module
>>> layer_norm(embedding)
>>>
>>> # Image Example
>>> N, C, H, W = 20, 5, 10, 10
>>> input = bm.random.randn(N, H, W, C)
>>> # Normalize over the last three dimensions (i.e. the channel and spatial dimensions)
>>> # as shown in the image below
>>> layer_norm = bp.layers.LayerNorm([H, W, C])
>>> output = layer_norm(input)

"""
def __init__(self,
epsilon: float = 1e-5,
use_bias: bool = True,
use_scale: bool = True,
beta_init: Initializer = ZeroInit(),
gamma_init: Initializer = OneInit(),
axis: Union[int, tuple] = None,
mode: Mode = training,
name: str = None,
**kwargs):

def __init__(
self,
normalized_shape: Union[int, Sequence[int]],
epsilon: float = 1e-5,
bias_init: Initializer = ZeroInit(),
scale_init: Initializer = OneInit(),
elementwise_affine: bool = True,
mode: Mode = training,
name: str = None
):
super(LayerNorm, self).__init__(name=name, mode=mode)
self.epsilon = epsilon
self.bias = use_bias
self.scale = use_scale
self.beta_init = beta_init if use_bias else ()
self.gamma_init = gamma_init if use_scale else ()
self.axis = (axis,) if jnp.isscalar(axis) else axis

def default_axis(self, x):
# default: the first axis (batch dim) is excluded
return tuple(i for i in range(1, len(x.shape)))
self.epsilon = epsilon
self.bias_init = bias_init
self.scale_init = scale_init
if isinstance(normalized_shape, int):
normalized_shape = (normalized_shape, )
self.normalized_shape = tuple(normalized_shape)
assert all([isinstance(s, int) for s in normalized_shape]), 'Must be a sequence of integers.'
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
assert isinstance(self.mode, TrainingMode)
self.bias = bm.TrainVar(parameter(self.bias_init, self.normalized_shape))
self.scale = bm.TrainVar(parameter(self.scale_init, self.normalized_shape))

def update(self, sha, x):
if self.axis is None:
self.axis = self.default_axis(x)
# todo: what if elementwise_affine = False?
input_shape = tuple(d for i, d in enumerate(x.shape) if i in self.axis)
self.beta = parameter(self.beta_init, input_shape) if self.bias else None
self.gamma = parameter(self.gamma_init, input_shape) if self.scale else None
if isinstance(self.mode, TrainingMode):
self.beta = bm.TrainVar(self.beta)
self.gamma = bm.TrainVar(self.gamma)

ed = tuple(None if i not in self.axis else slice(None) for i in range(jnp.ndim(x)))
output = bm.normalize(x, self.axis, epsilon=self.epsilon)
if self.bias and self.scale: return self.gamma[ed] * output + self.beta[ed]
if self.bias: return output + self.beta[ed]
if self.scale: return self.gamma[ed] * output
return output
if x.shape[-len(self.normalized_shape):] != self.normalized_shape:
raise ValueError(f'Expected the input shape to be (..., {", ".join(str(d) for d in self.normalized_shape)}), '
f'but got {x.shape}.')
axis = tuple(range(0, x.ndim - len(self.normalized_shape)))
mean = jnp.mean(bm.as_jax(x), axis=axis, keepdims=True)
variance = jnp.var(bm.as_jax(x), axis=axis, keepdims=True)
inv = lax.rsqrt(variance + lax.convert_element_type(self.epsilon, x.dtype))
out = (x - mean) * inv
if self.elementwise_affine:
out = self.scale * out + self.bias
return out

def reset_state(self, batch_size=None):
pass
@@ -265,107 +377,88 @@ class LayerNorm(DynamicalSystem):
class GroupNorm(DynamicalSystem):
"""Group normalization layer.

.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta


This layer divides channels into groups and normalizes the features within each
group. Its computation is also independent of the batch size. The feature size
must be a multiple of the group size.

The shape of the data should be (b, d1, d2, ..., c), where `b` denotes the batch
size and `c` denotes the feature (channel) size. The `b` and `c` axes should be
excluded in parameter `axis`.
size and `c` denotes the feature (channel) size.

Parameters
----------
num_groups: int
the number of groups. It should be a factor of the number of features.
group_size: int
the group size. It should be equal to int(num_features / num_groups).
Either `num_groups` or `group_size` should be specified.
The number of groups. It should be a factor of the number of channels.
num_channels: int
The number of channels expected in input.
epsilon: float
a value added to the denominator for numerical stability. Default: 1e-5
use_bias: bool
whether to add a learnable translation (bias) to the normalized data. Default: True
use_scale: bool
whether to apply a learnable scaling to the normalized data. Default: True
beta_init: brainpy.init.Initializer
an initializer generating the original translation matrix
gamma_init: brainpy.init.Initializer
an initializer generating the original scaling matrix
axis: int, tuple, list
axes where the data will be normalized. Besides the batch axis, the channel
axis should also be excluded, since it will be automatically added to `axis`.
affine: bool
A boolean value that when set to ``True``, this module
has learnable per-channel affine parameters initialized to ones (for weights)
and zeros (for biases). Default: ``True``.
bias_init: Initializer
An initializer generating the original translation matrix
scale_init: Initializer
An initializer generating the original scaling matrix

Examples
--------
>>> import brainpy as bp
>>> import brainpy.math as bm
>>> input = bm.random.randn(20, 10, 10, 6)
>>> # Separate 6 channels into 3 groups
>>> m = bp.layers.GroupNorm(3, 6)
>>> # Separate 6 channels into 6 groups (equivalent with InstanceNorm)
>>> m = bp.layers.GroupNorm(6, 6)
>>> # Put all 6 channels into a single group (equivalent with LayerNorm)
>>> m = bp.layers.GroupNorm(1, 6)
>>> # Activating the module
>>> output = m(input)
"""
def __init__(self,
num_groups: int = None,
group_size: int = None,
epsilon: float = 1e-5,
use_bias: bool = True,
use_scale: bool = True,
beta_init: Initializer = ZeroInit(),
gamma_init: Initializer = OneInit(),
axis: Union[int, tuple] = None,
mode: Mode = training,
name: str = None,
**kwargs):
def __init__(
self,
num_groups: int,
num_channels: int,
epsilon: float = 1e-5,
affine: bool = True,
bias_init: Initializer = ZeroInit(),
scale_init: Initializer = OneInit(),
mode: Mode = training,
name: str = None,
):
super(GroupNorm, self).__init__(name=name, mode=mode)
if num_channels % num_groups != 0:
raise ValueError('num_channels must be divisible by num_groups')
self.num_groups = num_groups
self.group_size = group_size
self.num_channels = num_channels
self.epsilon = epsilon
self.bias = use_bias
self.scale = use_scale
self.beta_init = beta_init if use_bias else ()
self.gamma_init = gamma_init if use_scale else ()
self.norm_axis = (axis,) if jnp.isscalar(axis) else axis
self.affine = affine
self.bias_init = bias_init
self.scale_init = scale_init
if self.affine:
assert isinstance(self.mode, TrainingMode)
self.bias = bm.TrainVar(parameter(self.bias_init, self.num_channels))
self.scale = bm.TrainVar(parameter(self.scale_init, self.num_channels))

def update(self, sha, x):
num_channels = x.shape[-1]
self.ndim = len(x)

# compute num_groups and group_size
if ((self.num_groups is None and self.group_size is None) or
(self.num_groups is not None and self.group_size is not None)):
raise ValueError('Either `num_groups` or `group_size` should be specified. '
'Once one is specified, the other will be automatically '
'computed.')

if self.num_groups is None:
assert self.group_size > 0, '`group_size` should be a positive integer.'
if num_channels % self.group_size != 0:
raise ValueError('The number of channels ({}) is not multiple of the '
'group size ({}).'.format(num_channels, self.group_size))
else:
self.num_groups = num_channels // self.group_size
else: # self.num_groups is not None:
assert self.num_groups > 0, '`num_groups` should be a positive integer.'
if num_channels % self.num_groups != 0:
raise ValueError('The number of channels ({}) is not multiple of the '
'number of groups ({}).'.format(num_channels, self.num_groups))
else:
self.group_size = num_channels // self.num_groups

# axes for normalization
if self.norm_axis is None:
# default: the first axis (batch dim) and the second-last axis (num_group dim) are excluded
self.norm_axis = tuple(i for i in range(1, len(x.shape) - 1)) + (self.ndim,)

group_shape = x.shape[:-1] + (self.num_groups, self.group_size)
input_shape = tuple(d for i, d in enumerate(group_shape) if i in self.norm_axis)
self.beta = parameter(self.beta_init, input_shape) if self.bias else None
self.gamma = parameter(self.gamma_init, input_shape) if self.scale else None
if isinstance(self.mode, TrainingMode):
self.beta = bm.TrainVar(self.beta)
self.gamma = bm.TrainVar(self.gamma)

group_shape = x.shape[:-1] + (self.num_groups, self.group_size)
ff_reshape = x.reshape(group_shape)
ed = tuple(None if i not in self.norm_axis else slice(None) for i in range(jnp.ndim(ff_reshape)))
output = bm.normalize(ff_reshape, self.norm_axis, epsilon=self.epsilon)
if self.bias and self.scale:
output = self.gamma[ed] * output + self.beta[ed]
elif self.bias:
output = output + self.beta[ed]
elif self.scale:
output = self.gamma[ed] * output
return output.reshape(x.shape)
assert x.shape[-1] == self.num_channels
origin_shape, origin_dim = x.shape, x.ndim
group_shape = (-1,) + x.shape[1:-1] + (self.num_groups, self.num_channels // self.num_groups)
x = bm.as_jax(x.reshape(group_shape))
reduction_axes = tuple(range(1, x.ndim - 1)) + (-1,)
mean = jnp.mean(x, reduction_axes, keepdims=True)
var = jnp.var(x, reduction_axes, keepdims=True)
x = (x - mean) * lax.rsqrt(var + lax.convert_element_type(self.epsilon, x.dtype))
x = x.reshape(origin_shape)
if self.affine:
x = x * lax.broadcast_to_rank(self.scale, origin_dim)
x = x + lax.broadcast_to_rank(self.bias, origin_dim)
return x


class InstanceNorm(GroupNorm):
@@ -376,28 +469,35 @@ class InstanceNorm(GroupNorm):

Parameters
----------
num_channels: int
The number of channels expected in input.
epsilon: float
a value added to the denominator for numerical stability. Default: 1e-5
use_bias: bool
whether to add a learnable translation (bias) to the normalized data. Default: True
use_scale: bool
whether to apply a learnable scaling to the normalized data. Default: True
beta_init: brainpy.init.Initializer
affine: bool
A boolean value that when set to ``True``, this module
has learnable per-channel affine parameters initialized to ones (for weights)
and zeros (for biases). Default: ``True``.
bias_init: Initializer
an initializer generating the original translation matrix
gamma_init: brainpy.init.Initializer
scale_init: Initializer
an initializer generating the original scaling matrix
axis: int, tuple, list
axes where the data will be normalized. The batch and channel axes
should be excluded.
"""
def __init__(self,
epsilon: float = 1e-5,
use_bias: bool = True,
use_scale: bool = True,
beta_init: Initializer = ZeroInit(),
gamma_init: Initializer = OneInit(),
axis: Union[int, tuple] = None,
**kwargs):
super(InstanceNorm, self).__init__(group_size=1, epsilon=epsilon, use_bias=use_bias,
use_scale=use_scale, beta_init=beta_init,
gamma_init=gamma_init, axis=axis, **kwargs)

def __init__(
self,
num_channels: int,
epsilon: float = 1e-5,
affine: bool = True,
bias_init: Initializer = ZeroInit(),
scale_init: Initializer = OneInit(),
mode: Mode = training,
name: str = None,
):
super(InstanceNorm, self).__init__(num_channels=num_channels,
num_groups=num_channels,
epsilon=epsilon,
affine=affine,
bias_init=bias_init,
scale_init=scale_init,
mode=mode,
name=name)
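
# A minimal sketch showing that, after this refactor, InstanceNorm is simply GroupNorm
# with one group per channel (hypothetical values):
import brainpy as bp
import brainpy.math as bm

inorm = bp.dyn.layers.InstanceNorm(num_channels=6)
x = bm.random.randn(20, 10, 10, 6)
y = inorm({'fit': True}, x)        # equivalent to bp.dyn.layers.GroupNorm(6, 6)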

+ 3
- 3
brainpy/dyn/layers/nvar.py

@@ -8,7 +8,7 @@ import numpy as np

import brainpy.math as bm
from brainpy.dyn.base import DynamicalSystem
from brainpy.modes import Mode, NormalMode, BatchingMode, batching, check
from brainpy.modes import Mode, NormalMode, BatchingMode, batching, check_mode
from brainpy.tools.checking import (check_integer, check_sequence)

__all__ = [
@@ -73,7 +73,7 @@ class NVAR(DynamicalSystem):
name: str = None,
):
super(NVAR, self).__init__(mode=mode, name=name)
check(self.mode, (BatchingMode, NormalMode), self.__class__.__name__)
check_mode(self.mode, (BatchingMode, NormalMode), self.__class__.__name__)

# parameters
order = tuple() if order is None else order
@@ -82,7 +82,7 @@ class NVAR(DynamicalSystem):
self.order = tuple(order)
check_sequence(order, 'order', allow_none=False)
for o in order:
check_integer(o, 'delay', allow_none=False, min_bound=2)
check_integer(o, 'order', allow_none=False, min_bound=2)
check_integer(delay, 'delay', allow_none=False, min_bound=1)
check_integer(stride, 'stride', allow_none=False, min_bound=1)
assert isinstance(constant, bool), f'Must be an instance of boolean, but got {constant}.'


+ 242
- 124
brainpy/dyn/layers/pooling.py

@@ -1,159 +1,277 @@
# -*- coding: utf-8 -*-

from typing import Union, Tuple, Sequence, Optional, Any, TypeVar

import numpy as np
from jax import lax

import jax.lax
import brainpy.math as bm
from brainpy.dyn.base import DynamicalSystem
from brainpy.modes import Mode, TrainingMode, NormalMode, training, check
from brainpy.modes import Mode, training, BatchingMode
from brainpy.types import Array

__all__ = [
'Pool',
'MaxPool',
'AvgPool',
'MinPool'
]

T = TypeVar('T')

class Pool(DynamicalSystem):
def __init__(self, init_v, reduce_fn, window_shape, strides, padding,
mode: Mode = training,
name: str = None,
**kwargs):
"""Pooling functions are implemented using the ReduceWindow XLA op.

Args:
init_v: scalar
the initial value for the reduction
reduce_fn: callable
a reduce function of the form `(T, T) -> T`.
window_shape: tuple
a shape tuple defining the window to reduce over.
strides: sequence[int]
a sequence of `n` integers, representing the inter-window strides.
padding: str, sequence[int]
either the string `'SAME'`, the string `'VALID'`, or a sequence
of `n` `(low, high)` integer pairs that give the padding to apply before
and after each spatial dimension.

Returns:
The output of the reduction for each window slice.
"""
super(Pool, self).__init__(name=name, mode=mode)
self.init_v = init_v
self.reduce_fn = reduce_fn
self.window_shape = window_shape
self.strides = strides or (1,) * len(window_shape)
assert len(self.window_shape) == len(self.strides), (
f"len({self.window_shape}) must equal len({self.strides})")
self.strides = (1,) + self.strides + (1,)
self.dims = (1,) + window_shape + (1,)
self.is_single_input = False

if not isinstance(padding, str):
padding = tuple(map(tuple, padding))
assert len(padding) == len(window_shape), (
f"padding {padding} must specify pads for same number of dims as "
f"window_shape {window_shape}")
assert all([len(x) == 2 for x in padding]), (
f"each entry in padding {padding} must be length 2")
padding = ((0, 0),) + padding + ((0, 0),)
self.padding = padding

def update(self, sha, x):
input_shapes = tuple(d for d in x.shape if d is not None)
assert len(input_shapes) == len(self.dims), f"len({len(input_shapes)}) != len({self.dims})"
def _infer_shape(x: Array,
mode: Mode,
size: Union[T, Sequence[T]],
channel_axis: Optional[int] = None,
element: T = 1):
"""Infer shape for pooling window or strides."""

# padding_vals = jax.lax.padtype_to_pads(input_shapes, self.dims, self.strides, self.padding)
# ones = (1,) * len(self.dims)
# out_shapes = jax.lax.reduce_window_shape_tuple(
# input_shapes, self.dims, self.strides, padding_vals, ones, ones)
#
# out_shapes = tuple((None,)) + tuple(d for i, d in enumerate(out_shapes) if i != 0)
# channel axis
if channel_axis and not 0 <= abs(channel_axis) < x.ndim:
raise ValueError(f"Invalid channel axis {channel_axis} for {x.shape}")
if channel_axis and channel_axis < 0:
channel_axis = x.ndim + channel_axis

y = jax.lax.reduce_window(x, self.init_v, self.reduce_fn, self.dims, self.strides, self.padding)
if isinstance(size, (tuple, list)):
assert isinstance(size, (tuple, list)), "Should be a tuple/list of integer."
size = tuple(size)
if len(size) > x.ndim:
raise ValueError(f'Invalid size {size}. Its dimension is bigger than its input.')
elif len(size) == x.ndim:
return size
else:
if isinstance(mode, BatchingMode):
size = (element,) + size
if len(size) + 1 == x.ndim:
if channel_axis is None:
raise ValueError('"channel_axis" should be provided.')
size = size[:channel_axis] + (element,) + size[channel_axis:]
else:
raise ValueError(f'size {size} is invalid. Please provide more elements.')
return size

return y
else:
if isinstance(mode, BatchingMode):
return (element,) + tuple((size if d != channel_axis else element) for d in range(1, x.ndim))
else:
return tuple((size if d != channel_axis else element) for d in range(0, x.ndim))
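
# How the helper above resolves an integer or partial size (illustrative trace, assuming a
# batched NHWC input of shape (8, 32, 32, 3) and channel_axis=-1):
#   _infer_shape(x, batching, 2, channel_axis=-1)      -> (1, 2, 2, 1)
#   _infer_shape(x, batching, (2, 2), channel_axis=-1) -> (1, 2, 2, 1)
# The batch and channel axes receive the neutral element, the spatial axes the given size.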


class AvgPool(Pool):
"""Pools the input by taking the average over a window.
class Pool(DynamicalSystem):
"""Pooling functions are implemented using the ReduceWindow XLA op.

Parameters
----------
window_shape: int, sequence of int
An integer, or a sequence of integers defining the window to reduce over.
strides: int, sequence of int
An integer, or a sequence of integers, representing the inter-window strides (default: `(1, ..., 1)`).
padding: str, sequence of tuple
Either the string `'SAME'`, the string `'VALID'`, or a sequence
of n `(low, high)` integer pairs that give the padding to apply before
and after each spatial dimension.
channel_axis: int, optional
Axis of the spatial channels for which pooling is skipped,
used to infer ``window_shape`` or ``strides`` if they are an integer.
mode: Mode
The computation mode.
name: optional, str
The object name.

Args:
window_shape: tuple
a shape tuple defining the window to reduce over.
strides: sequence[int]
a sequence of `n` integers, representing the inter-window strides (default: `(1, ..., 1)`).
padding: str, sequence[int]
either the string `'SAME'`, the string `'VALID'`, or a sequence
of `n` `(low, high)` integer pairs that give the padding to apply before
and after each spatial dimension (default: `'VALID'`).

Returns:
The average for each window slice.
"""

def __init__(self, window_shape, strides=None, padding="VALID"):
super(AvgPool, self).__init__(
init_v=0.,
reduce_fn=jax.lax.add,
window_shape=window_shape,
strides=strides,
padding=padding
)
def __init__(
self,
init_value,
computation,
window_shape: Union[int, Sequence[int]],
strides: Union[int, Sequence[int]],
padding: Union[str, Sequence[Tuple[int, int]]] = "VALID",
channel_axis: Optional[int] = None,
mode: Mode = training,
name: Optional[str] = None,
):
super(Pool, self).__init__(mode=mode, name=name)

self.init_value = init_value
self.computation = computation
self.window_shape = window_shape
self.strides = strides
self.padding = padding
self.channel_axis = channel_axis
if isinstance(padding, str):
if padding not in ("SAME", "VALID"):
raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")
else:
assert all([isinstance(x, (tuple, list)) for x in padding]), \
f'padding should be sequence of Tuple[int, int]. {padding}'
assert all([len(x) == 2 for x in padding]), f"each entry in padding {padding} must be length 2"

def update(self, sha, x):
y = jax.lax.reduce_window(x, self.init_v, self.reduce_fn, self.dims, self.strides, self.padding)
y = y / bm.prod(bm.asarray(self.window_shape))
return y
window_shape = _infer_shape(x, self.mode, self.window_shape, self.channel_axis)
strides = _infer_shape(x, self.mode, self.strides, self.channel_axis)
padding = (self.padding if isinstance(self.padding, str) else
_infer_shape(x, self.mode, self.padding, self.channel_axis, element=(0, 0)))
return lax.reduce_window(bm.as_jax(x),
init_value=self.init_value,
computation=self.computation,
window_dimensions=window_shape,
window_strides=strides,
padding=padding)

def reset_state(self, batch_size=None):
pass


class MaxPool(Pool):
"""Pools the input by taking the maximum over a window.

Args:
window_shape: tuple
a shape tuple defining the window to reduce over.
strides: sequence[int]
a sequence of `n` integers, representing the inter-window strides (default: `(1, ..., 1)`).
padding: str, sequence[int]
either the string `'SAME'`, the string `'VALID'`, or a sequence
of `n` `(low, high)` integer pairs that give the padding to apply before
and after each spatial dimension (default: `'VALID'`).

Returns:
The maximum for each window slice.
Parameters
----------
window_shape: int, sequence of int
An integer, or a sequence of integers defining the window to reduce over.
strides: int, sequence of int
An integer, or a sequence of integers, representing the inter-window strides (default: `(1, ..., 1)`).
padding: str, sequence of tuple
Either the string `'SAME'`, the string `'VALID'`, or a sequence
of n `(low, high)` integer pairs that give the padding to apply before
and after each spatial dimension.
channel_axis: int, optional
Axis of the spatial channels for which pooling is skipped,
used to infer ``window_shape`` or ``strides`` if they are an integer.
mode: Mode
The computation mode.
name: optional, str
The object name.

"""
def __init__(self, window_shape, strides=None, padding="VALID"):
super(MaxPool, self).__init__(
init_v=-bm.inf,
reduce_fn=jax.lax.max,
window_shape=window_shape,
strides=strides,
padding=padding
)

def __init__(
self,
window_shape: Union[int, Sequence[int]],
strides: Union[int, Sequence[int]],
padding: Union[str, Sequence[Tuple[int, int]]] = "VALID",
channel_axis: Optional[int] = None,
mode: Mode = training,
name: Optional[str] = None,
):
super(MaxPool, self).__init__(init_value=-bm.inf,
computation=lax.max,
window_shape=window_shape,
strides=strides,
padding=padding,
channel_axis=channel_axis,
mode=mode,
name=name)


class MinPool(Pool):
"""Pools the input by taking the minimum over a window.

Args:
window_shape: tuple
a shape tuple defining the window to reduce over.
strides: sequence[int]
a sequence of `n` integers, representing the inter-window strides (default: `(1, ..., 1)`).
padding: str, sequence[int]
either the string `'SAME'`, the string `'VALID'`, or a sequence
of `n` `(low, high)` integer pairs that give the padding to apply before
and after each spatial dimension (default: `'VALID'`).

Returns:
The minimum for each window slice.
"""
def __init__(self, window_shape, strides=None, padding="VALID"):
super(MinPool, self).__init__(
init_v=bm.inf,
reduce_fn=jax.lax.min,
window_shape=window_shape,
strides=strides,
padding=padding
)
Parameters
----------
window_shape: int, sequence of int
An integer, or a sequence of integers defining the window to reduce over.
strides: int, sequence of int
An integer, or a sequence of integers, representing the inter-window strides (default: `(1, ..., 1)`).
padding: str, sequence of tuple
Either the string `'SAME'`, the string `'VALID'`, or a sequence
of n `(low, high)` integer pairs that give the padding to apply before
and after each spatial dimension.
channel_axis: int, optional
Axis of the spatial channels for which pooling is skipped,
used to infer ``window_shape`` or ``strides`` if they are an integer.
mode: Mode
The computation mode.
name: optional, str
The object name.

"""

def __init__(
self,
window_shape: Union[int, Sequence[int]],
strides: Union[int, Sequence[int]],
padding: Union[str, Sequence[Tuple[int, int]]] = "VALID",
channel_axis: Optional[int] = None,
mode: Mode = training,
name: Optional[str] = None,
):
super(MinPool, self).__init__(init_value=bm.inf,
computation=lax.min,
window_shape=window_shape,
strides=strides,
padding=padding,
channel_axis=channel_axis,
mode=mode,
name=name)


class AvgPool(Pool):
"""Pools the input by taking the average over a window.


Parameters
----------
window_shape: int, sequence of int
An integer, or a sequence of integers defining the window to reduce over.
strides: int, sequence of int
An integer, or a sequence of integers, representing the inter-window strides (default: `(1, ..., 1)`).
padding: str, sequence of tuple
Either the string `'SAME'`, the string `'VALID'`, or a sequence
of n `(low, high)` integer pairs that give the padding to apply before
and after each spatial dimension.
channel_axis: int, optional
Axis of the spatial channels for which pooling is skipped,
used to infer ``window_shape`` or ``strides`` if they are an integer.
mode: Mode
The computation mode.
name: optional, str
The object name.

"""

def __init__(
self,
window_shape: Union[int, Sequence[int]],
strides: Union[int, Sequence[int]],
padding: Union[str, Sequence[Tuple[int, int]]] = "VALID",
channel_axis: Optional[int] = None,
mode: Mode = training,
name: Optional[str] = None,
):
super(AvgPool, self).__init__(init_value=0.,
computation=lax.add,
window_shape=window_shape,
strides=strides,
padding=padding,
channel_axis=channel_axis,
mode=mode,
name=name)

def update(self, sha, x):
window_shape = _infer_shape(x, self.mode, self.window_shape, self.channel_axis)
strides = _infer_shape(x, self.mode, self.strides, self.channel_axis)
padding = (self.padding if isinstance(self.padding, str) else
_infer_shape(x, self.mode, self.padding, self.channel_axis, element=(0, 0)))
pooled = lax.reduce_window(bm.as_jax(x),
init_value=self.init_value,
computation=self.computation,
window_dimensions=window_shape,
window_strides=strides,
padding=padding)
if padding == "VALID":
# Avoid the extra reduce_window.
return pooled / np.prod(window_shape)
else:
# Count the number of valid entries at each input point, then use that for
# computing average. Assumes that any two arrays of same shape will be
# padded the same.
window_counts = lax.reduce_window(bm.ones_like(x).value,
init_value=self.init_value,
computation=self.computation,
window_dimensions=window_shape,
window_strides=strides,
padding=padding)
assert pooled.shape == window_counts.shape
return pooled / window_counts
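
# A minimal sketch of average pooling with the new signature used by the tests below
# (hypothetical values):
import brainpy as bp
import jax.numpy as jnp

avg = bp.dyn.layers.AvgPool((2, 2), 2, channel_axis=-1)
x = jnp.arange(16.).reshape((1, 4, 4, 1))   # (batch, H, W, C)
y = avg({'fit': False}, x)                  # -> shape (1, 2, 2, 1), each entry a 2x2 window mean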

+ 80
- 38
brainpy/dyn/layers/rnncells.py

@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

import warnings
from typing import Union, Callable

import brainpy.math as bm
@@ -17,19 +17,22 @@ from brainpy.tools.checking import (check_integer,
from brainpy.types import Array

__all__ = [
'VanillaRNN',
'GRU',
'LSTM',
'RNNCell', 'GRUCell', 'LSTMCell',

# deprecated
'VanillaRNN', 'GRU', 'LSTM',
]


class RecurrentCell(DynamicalSystem):
def __init__(self,
num_out: int,
state_initializer: Union[Array, Callable, Initializer] = ZeroInit(),
mode: Mode = training,
train_state: bool = False,
name: str = None):
def __init__(
self,
num_out: int,
state_initializer: Union[Array, Callable, Initializer] = ZeroInit(),
mode: Mode = training,
train_state: bool = False,
name: str = None
):
super(RecurrentCell, self).__init__(mode=mode, name=name)

# parameters
@@ -40,7 +43,7 @@ class RecurrentCell(DynamicalSystem):
self.train_state = train_state


class VanillaRNN(RecurrentCell):
class RNNCell(RecurrentCell):
r"""Basic fully-connected RNN core.

Given :math:`x_t` and the previous hidden state :math:`h_{t-1}` the
@@ -68,8 +71,6 @@ class VanillaRNN(RecurrentCell):
activation: str, callable
The activation function. It can be a string or a callable function.
See ``brainpy.math.activations`` for more details.
trainable: bool
Whether set the node is trainable.

"""

@@ -86,11 +87,11 @@ class VanillaRNN(RecurrentCell):
train_state: bool = False,
name: str = None,
):
super(VanillaRNN, self).__init__(num_out=num_out,
state_initializer=state_initializer,
train_state=train_state,
mode=mode,
name=name)
super(RNNCell, self).__init__(num_out=num_out,
state_initializer=state_initializer,
train_state=train_state,
mode=mode,
name=name)

# parameters
self.num_in = num_in
@@ -137,7 +138,7 @@ class VanillaRNN(RecurrentCell):
return self.state.value


class GRU(RecurrentCell):
class GRUCell(RecurrentCell):
r"""Gated Recurrent Unit.

The implementation is based on (Chung et al., 2014) [1]_ with biases.
@@ -174,8 +175,6 @@ class GRU(RecurrentCell):
activation: str, callable
The activation function. It can be a string or a callable function.
See ``brainpy.math.activations`` for more details.
trainable: bool
Whether set the node is trainable.

References
----------
@@ -197,11 +196,11 @@ class GRU(RecurrentCell):
train_state: bool = False,
name: str = None,
):
super(GRU, self).__init__(num_out=num_out,
state_initializer=state_initializer,
train_state=train_state,
mode=mode,
name=name)
super(GRUCell, self).__init__(num_out=num_out,
state_initializer=state_initializer,
train_state=train_state,
mode=mode,
name=name)
# parameters
self.num_in = num_in
check_integer(num_in, 'num_in', min_bound=1, allow_none=False)
@@ -259,7 +258,7 @@ class GRU(RecurrentCell):
return self.state.value


class LSTM(RecurrentCell):
class LSTMCell(RecurrentCell):
r"""Long short-term memory (LSTM) RNN core.

The implementation is based on (Zaremba et al., 2014) [1]_. Given
@@ -305,8 +304,6 @@ class LSTM(RecurrentCell):
activation: str, callable
The activation function. It can be a string or a callable function.
See ``brainpy.math.activations`` for more details.
trainable: bool
Whether set the node is trainable.

References
----------
@@ -331,11 +328,11 @@ class LSTM(RecurrentCell):
train_state: bool = False,
name: str = None,
):
super(LSTM, self).__init__(num_out=num_out,
state_initializer=state_initializer,
train_state=train_state,
mode=mode,
name=name)
super(LSTMCell, self).__init__(num_out=num_out,
state_initializer=state_initializer,
train_state=train_state,
mode=mode,
name=name)
# parameters
self.num_in = num_in
check_integer(num_in, 'num_in', min_bound=1, allow_none=False)
@@ -409,17 +406,62 @@ class LSTM(RecurrentCell):
self.state[self.state.shape[0] // 2:, :] = value


class ConvNDLSTM(DynamicalSystem):
class VanillaRNN(RNNCell):
"""Vanilla RNN.

.. deprecated:: 2.2.3.4
Use `RNNCell` instead. `VanillaRNN` will be removed since version 2.4.0.

"""

def __init__(self, *args, **kwargs):
super(VanillaRNN, self).__init__(*args, **kwargs)
warnings.warn('Use "brainpy.layers.RNNCell" instead. '
'"brainpy.layers.VanillaRNN" is deprecated and will be removed since 2.4.0.',
UserWarning)


class GRU(GRUCell):
"""GRU.

.. deprecated:: 2.2.3.4
Use `GRUCell` instead. `GRU` will be removed since version 2.4.0.

"""

def __init__(self, *args, **kwargs):
super(GRU, self).__init__(*args, **kwargs)
warnings.warn('Use "brainpy.layers.GRUCell" instead. '
'"brainpy.layers.GRU" is deprecated and will be removed since 2.4.0.',
UserWarning)


class LSTM(LSTMCell):
"""LSTM.

.. deprecated:: 2.2.3.4
Use `LSTMCell` instead. `LSTM` will be removed since version 2.4.0.

"""

def __init__(self, *args, **kwargs):
super(LSTM, self).__init__(*args, **kwargs)
warnings.warn('Use "brainpy.layers.LSTMCell" instead. '
'"brainpy.layers.LSTM" is deprecated and will be removed since 2.4.0.',
UserWarning)
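
# A minimal sketch of the renamed API: the *Cell classes are drop-in replacements, while
# the old names keep working but emit a deprecation warning (hypothetical values):
import brainpy as bp

gru_new = bp.dyn.layers.GRUCell(num_in=10, num_out=20)
gru_old = bp.dyn.layers.GRU(num_in=10, num_out=20)   # warns: use GRUCell instead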


class ConvNDLSTMCell(DynamicalSystem):
pass


class Conv1DLSTM(ConvNDLSTM):
class Conv1DLSTMCell(ConvNDLSTMCell):
pass


class Conv2DLSTM(ConvNDLSTM):
class Conv2DLSTMCell(ConvNDLSTMCell):
pass


class Conv3DLSTM(ConvNDLSTM):
class Conv3DLSTMCell(ConvNDLSTMCell):
pass

+ 0
- 201
brainpy/dyn/layers/tests/test_normalization.py

@@ -1,201 +0,0 @@
# -*- coding: utf-8 -*-


from unittest import TestCase

import brainpy as bp


class TestBatchNorm1d(TestCase):
def test_batchnorm1d1(self):
class BatchNormNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(BatchNormNet, self).__init__()
self.norm = bp.dyn.layers.BatchNorm1d(axis=(0, 1, 2))

def update(self, shared, x):
x = self.norm(shared, x)
return x

inputs = bp.math.ones((2, 3, 4))
inputs[0, 0, :] = 2.
inputs[0, 1, 0] = 5.
print(inputs)
model = BatchNormNet()
shared = {'fit': False}
print(model(shared, inputs))

def test_batchnorm1d2(self):
class BatchNormNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(BatchNormNet, self).__init__()
self.norm = bp.dyn.layers.BatchNorm1d()
self.dense = bp.dyn.layers.Dense(num_in=4, num_out=4)

def update(self, shared, x):
x = self.norm(shared, x)
x = self.dense(shared, x)
return x

inputs = bp.math.ones((2, 4))
inputs[0, :] = 2.
print(inputs)
model = BatchNormNet()
shared = {'fit': False}
print(model(shared, inputs))


class TestBatchNorm2d(TestCase):
def test_batchnorm2d(self):
class BatchNormNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(BatchNormNet, self).__init__()
self.norm = bp.dyn.layers.BatchNorm2d()

def update(self, shared, x):
x = self.norm(shared, x)
return x

inputs = bp.math.ones((10, 32, 32, 3))
inputs[0, 1, :, :] = 2.
print(inputs)
model = BatchNormNet()
shared = {'fit': False}
print(model(shared, inputs))


class TestBatchNorm3d(TestCase):
def test_batchnorm3d(self):
class BatchNormNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(BatchNormNet, self).__init__()
self.norm = bp.dyn.layers.BatchNorm3d()

def update(self, shared, x):
x = self.norm(shared, x)
return x

inputs = bp.math.ones((10, 32, 32, 16, 3))
print(inputs)
model = BatchNormNet()
shared = {'fit': False}
print(model(shared, inputs))


class TestBatchNorm(TestCase):
def test_batchnorm1(self):
class BatchNormNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(BatchNormNet, self).__init__()
self.norm = bp.dyn.layers.BatchNorm(axis=(0, 2), use_bias=False) # channel axis: 1

def update(self, shared, x):
x = self.norm(shared, x)
return x

inputs = bp.math.ones((2, 3, 4))
inputs[0, 0, :] = 2.
inputs[0, 1, 0] = 5.
print(inputs)
model = BatchNormNet()
shared = {'fit': False}
print(model(shared, inputs))

def test_batchnorm2(self):
class BatchNormNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(BatchNormNet, self).__init__()
self.norm = bp.dyn.layers.BatchNorm(axis=(0, 2)) # channel axis: 1
self.dense = bp.dyn.layers.Dense(num_in=12, num_out=2)

def update(self, shared, x):
x = self.norm(shared, x)
x = x.reshape(-1, 12)
x = self.dense(shared, x)
return x

inputs = bp.math.ones((2, 3, 4))
inputs[0, 0, :] = 2.
inputs[0, 1, 0] = 5.
# print(inputs)
model = BatchNormNet()
shared = {'fit': False}
print(model(shared, inputs))


class TestLayerNorm(TestCase):
def test_layernorm1(self):
class LayerNormNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(LayerNormNet, self).__init__()
self.norm = bp.dyn.layers.LayerNorm()

def update(self, shared, x):
x = self.norm(shared, x)
return x

inputs = bp.math.ones((2, 3, 4))
inputs[0, 0, :] = 2.
inputs[0, 1, 0] = 5.
print(inputs)
model = LayerNormNet()
shared = {'fit': False}
print(model(shared, inputs))

def test_layernorm2(self):
class LayerNormNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(LayerNormNet, self).__init__()
self.norm = bp.dyn.layers.LayerNorm(axis=2)

def update(self, shared, x):
x = self.norm(shared, x)
return x

inputs = bp.math.ones((2, 3, 4))
inputs[0, 0, :] = 2.
inputs[0, 1, 0] = 5.
print(inputs)
model = LayerNormNet()
shared = {'fit': False}
print(model(shared, inputs))


class TestInstanceNorm(TestCase):
def test_instancenorm(self):
class InstanceNormNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(InstanceNormNet, self).__init__()
self.norm = bp.dyn.layers.InstanceNorm()

def update(self, shared, x):
x = self.norm(shared, x)
return x

inputs = bp.math.ones((2, 3, 4))
inputs[0, 0, :] = 2.
inputs[0, 1, 0] = 5.
print(inputs)
model = InstanceNormNet()
shared = {'fit': False}
print(model(shared, inputs))


class TestGroupNorm(TestCase):
def test_groupnorm1(self):
class GroupNormNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(GroupNormNet, self).__init__()
self.norm = bp.dyn.layers.GroupNorm(num_groups=2)

def update(self, shared, x):
x = self.norm(shared, x)
return x

inputs = bp.math.ones((2, 3, 4))
inputs[0, 0, :] = 2.
inputs[0, 1, 0] = 5.
print(inputs)
model = GroupNormNet()
shared = {'fit': False}
print(model(shared, inputs))

+ 29
- 14
brainpy/dyn/layers/tests/test_pooling.py

@@ -1,41 +1,56 @@
# -*- coding: utf-8 -*-
import random

import pytest
from unittest import TestCase
import brainpy as bp
import jax.numpy as jnp
import jax
import numpy as np

import brainpy as bp
import brainpy.math as bm


class TestPool(TestCase):
def test_maxpool(self):
class MaxPoolNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(MaxPoolNet, self).__init__()
self.maxpool = bp.dyn.layers.MaxPool((2, 2))
self.maxpool = bp.dyn.layers.MaxPool((2, 2), 1, channel_axis=-1)

def update(self, sha, x):
x = self.maxpool(sha, x)
return x
return self.maxpool(sha, x)

x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
print(jnp.arange(9).reshape(3, 3))
print(x)
print(x.shape)
shared = {'fit': False}
net = MaxPoolNet()
y = net(shared, x)
print("out shape: ", y.shape)
expected_y = jnp.array([
[4., 5.],
[7., 8.],
]).reshape((1, 2, 2, 1))
expected_y = jnp.array([[4., 5.],
[7., 8.]]).reshape((1, 2, 2, 1))
np.testing.assert_allclose(y, expected_y)

def test_maxpool2(self):
class MaxPoolNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(MaxPoolNet, self).__init__()
self.maxpool = bp.dyn.layers.MaxPool((2, 2), (2, 2), channel_axis=-1)

def update(self, sha, x):
return self.maxpool(sha, x)

rng = bm.random.RandomState(123)
x = rng.rand(10, 20, 20, 4)
net = MaxPoolNet()
y = net(None, x)
print("out shape: ", y.shape)

def test_minpool(self):
class MinPoolNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(MinPoolNet, self).__init__()
self.maxpool = bp.dyn.layers.MinPool((2, 2))
self.maxpool = bp.dyn.layers.MinPool((2, 2), 1, channel_axis=-1)

def update(self, sha, x):
x = self.maxpool(sha, x)
@@ -56,7 +71,7 @@ class TestPool(TestCase):
class AvgPoolNet(bp.dyn.DynamicalSystem):
def __init__(self):
super(AvgPoolNet, self).__init__()
self.maxpool = bp.dyn.layers.AvgPool((2, 2))
self.maxpool = bp.dyn.layers.AvgPool((2, 2), 1, channel_axis=-1)

def update(self, sha, x):
x = self.maxpool(sha, x)
@@ -67,4 +82,4 @@ class TestPool(TestCase):
net = AvgPoolNet()
y = net(shared, x)
print("out shape: ", y.shape)
np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.))
np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.))

+ 53
- 56
brainpy/dyn/neurons/biological_models.py

@@ -4,11 +4,11 @@ from typing import Union, Callable, Optional

import brainpy.math as bm
from brainpy.dyn.base import NeuGroup
from brainpy.initialize import OneInit, Uniform, Initializer, parameter, noise as init_noise, variable
from brainpy.initialize import OneInit, Uniform, Initializer, parameter, noise as init_noise, variable_
from brainpy.integrators.joint_eq import JointEq
from brainpy.integrators.ode import odeint
from brainpy.integrators.sde import sdeint
from brainpy.modes import Mode, BatchingMode, TrainingMode, NormalMode, normal, check
from brainpy.modes import Mode, BatchingMode, NormalMode, normal, check_mode
from brainpy.tools.checking import check_initializer
from brainpy.types import Shape, Array

@@ -219,7 +219,7 @@ class HH(NeuGroup):
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (BatchingMode, NormalMode), self.__class__.__name__)
check_mode(self.mode, (BatchingMode, NormalMode), self.__class__.__name__)

# parameters
self.ENa = parameter(ENa, self.varshape, allow_none=False)
@@ -243,21 +243,18 @@ class HH(NeuGroup):
self._V_initializer = V_initializer

# variables
self.V = variable(self._V_initializer, mode, self.varshape)
if self._m_initializer is None:
self.m = bm.Variable(self.m_inf(self.V.value))
else:
self.m = variable(self._m_initializer, mode, self.varshape)
if self._h_initializer is None:
self.h = bm.Variable(self.h_inf(self.V.value))
else:
self.h = variable(self._h_initializer, mode, self.varshape)
if self._n_initializer is None:
self.n = bm.Variable(self.n_inf(self.V.value))
else:
self.n = variable(self._n_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
self.V = variable_(self._V_initializer, self.varshape, mode)
self.m = (bm.Variable(self.m_inf(self.V.value))
if m_initializer is None else
variable_(self._m_initializer, self.varshape, mode))
self.h = (bm.Variable(self.h_inf(self.V.value))
if h_initializer is None else
variable_(self._h_initializer, self.varshape, mode))
self.n = (bm.Variable(self.n_inf(self.V.value))
if n_initializer is None else
variable_(self._n_initializer, self.varshape, mode))
self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode)
self.input = variable_(bm.zeros, self.varshape, mode)

# integral
if self.noise is None:
@@ -284,21 +281,21 @@ class HH(NeuGroup):
dn = lambda self, n, t, V: self.n_alpha(V) * (1 - n) - self.n_beta(V) * n

def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
if self._m_initializer is None:
self.m.value = self.m_inf(self.V.value)
else:
self.m.value = variable(self._m_initializer, batch_size, self.varshape)
self.m.value = variable_(self._m_initializer, self.varshape, batch_size)
if self._h_initializer is None:
self.h.value = self.h_inf(self.V.value)
else:
self.h.value = variable(self._h_initializer, batch_size, self.varshape)
self.h.value = variable_(self._h_initializer, self.varshape, batch_size)
if self._n_initializer is None:
self.n.value = self.n_inf(self.V.value)
else:
self.n.value = variable(self._n_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
self.n.value = variable_(self._n_initializer, self.varshape, batch_size)
self.input.value = variable_(bm.zeros, self.varshape, batch_size)
self.spike.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)

def dV(self, V, t, m, h, n, I_ext):
I_Na = (self.gNa * m ** 3.0 * h) * (V - self.ENa)
@@ -309,7 +306,7 @@ class HH(NeuGroup):

@property
def derivative(self):
return JointEq([self.dV, self.dm, self.dh, self.dn])
return JointEq(self.dV, self.dm, self.dh, self.dn)

def update(self, tdi, x=None):
t, dt = tdi['t'], tdi['dt']
@@ -430,7 +427,7 @@ class MorrisLecar(NeuGroup):
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (BatchingMode, NormalMode), self.__class__)
check_mode(self.mode, (BatchingMode, NormalMode), self.__class__)

# params
self.V_Ca = parameter(V_Ca, self.varshape, allow_none=False)
@@ -455,10 +452,10 @@ class MorrisLecar(NeuGroup):
self._V_initializer = V_initializer

# variables
self.W = variable(self._W_initializer, mode, self.varshape)
self.V = variable(self._V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
self.W = variable_(self._W_initializer, self.varshape, mode)
self.V = variable_(self._V_initializer, self.varshape, mode)
self.input = variable_(bm.zeros, self.varshape, mode)
self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode)

# integral
if self.noise is None:
@@ -467,10 +464,10 @@ class MorrisLecar(NeuGroup):
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def reset_state(self, batch_size=None):
self.W.value = variable(self._W_initializer, batch_size, self.varshape)
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
self.W.value = variable_(self._W_initializer, self.varshape, batch_size)
self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
self.input.value = variable_(bm.zeros, self.varshape, batch_size)
self.spike.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)

def dV(self, V, t, W, I_ext):
M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2))
@@ -688,7 +685,7 @@ class PinskyRinzelModel(NeuGroup):
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (NormalMode, BatchingMode), self.__class__)
check_mode(self.mode, (NormalMode, BatchingMode), self.__class__)

# conductance parameters
self.gAHP = parameter(gAHP, self.varshape, allow_none=False)
@@ -721,16 +718,16 @@ class PinskyRinzelModel(NeuGroup):
self._Ca_initializer = Ca_initializer

# variables
self.Vs = variable(self._Vs_initializer, mode, self.varshape)
self.Vd = variable(self._Vd_initializer, mode, self.varshape)
self.Ca = variable(self._Ca_initializer, mode, self.varshape)
self.Vs = variable_(self._Vs_initializer, self.varshape, mode)
self.Vd = variable_(self._Vd_initializer, self.varshape, mode)
self.Ca = variable_(self._Ca_initializer, self.varshape, mode)
self.h = bm.Variable(self.inf_h(self.Vs), batch_axis=0 if isinstance(mode, BatchingMode) else None)
self.n = bm.Variable(self.inf_n(self.Vs), batch_axis=0 if isinstance(mode, BatchingMode) else None)
self.s = bm.Variable(self.inf_s(self.Vd), batch_axis=0 if isinstance(mode, BatchingMode) else None)
self.c = bm.Variable(self.inf_c(self.Vd), batch_axis=0 if isinstance(mode, BatchingMode) else None)
self.q = bm.Variable(self.inf_q(self.Ca), batch_axis=0 if isinstance(mode, BatchingMode) else None)
self.Id = variable(bm.zeros, mode, self.varshape) # input to soma
self.Is = variable(bm.zeros, mode, self.varshape) # input to dendrite
self.Id = variable_(bm.zeros, self.varshape, mode) # input to soma
self.Is = variable_(bm.zeros, self.varshape, mode) # input to dendrite
# self.spike = bm.Variable(bm.zeros(self.varshape, dtype=bool))

# integral
@@ -740,17 +737,17 @@ class PinskyRinzelModel(NeuGroup):
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def reset_state(self, batch_size=None):
self.Vd.value = variable(self._Vd_initializer, batch_size, self.varshape)
self.Vs.value = variable(self._Vs_initializer, batch_size, self.varshape)
self.Ca.value = variable(self._Ca_initializer, batch_size, self.varshape)
self.Vd.value = variable_(self._Vd_initializer, self.varshape, batch_size)
self.Vs.value = variable_(self._Vs_initializer, self.varshape, batch_size)
self.Ca.value = variable_(self._Ca_initializer, self.varshape, batch_size)
batch_axis = 0 if isinstance(self.mode, BatchingMode) else None
self.h.value = bm.Variable(self.inf_h(self.Vs), batch_axis=batch_axis)
self.n.value = bm.Variable(self.inf_n(self.Vs), batch_axis=batch_axis)
self.s.value = bm.Variable(self.inf_s(self.Vd), batch_axis=batch_axis)
self.c.value = bm.Variable(self.inf_c(self.Vd), batch_axis=batch_axis)
self.q.value = bm.Variable(self.inf_q(self.Ca), batch_axis=batch_axis)
self.Id.value = variable(bm.zeros, batch_size, self.varshape)
self.Is.value = variable(bm.zeros, batch_size, self.varshape)
self.Id.value = variable_(bm.zeros, self.varshape, batch_size)
self.Is.value = variable_(bm.zeros, self.varshape, batch_size)
# self.spike[:] = False

def dCa(self, Ca, t, s, Vd):
@@ -997,7 +994,7 @@ class WangBuzsakiModel(NeuGroup):
):
# initialization
super(WangBuzsakiModel, self).__init__(size=size, keep_size=keep_size, name=name, mode=mode)
check(self.mode, (BatchingMode, NormalMode), self.__class__)
check_mode(self.mode, (BatchingMode, NormalMode), self.__class__)

# parameters
self.ENa = parameter(ENa, self.varshape, allow_none=False)
@@ -1020,11 +1017,11 @@ class WangBuzsakiModel(NeuGroup):
self._V_initializer = V_initializer

# variables
self.h = variable(self._h_initializer, mode, self.varshape)
self.n = variable(self._n_initializer, mode, self.varshape)
self.V = variable(self._V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
self.h = variable_(self._h_initializer, self.varshape, mode)
self.n = variable_(self._n_initializer, self.varshape, mode)
self.V = variable_(self._V_initializer, self.varshape, mode)
self.input = variable_(bm.zeros, self.varshape, mode)
self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode)

# integral
if self.noise is None:
@@ -1033,11 +1030,11 @@ class WangBuzsakiModel(NeuGroup):
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def reset_state(self, batch_size=None):
self.h.value = variable(self._h_initializer, batch_size, self.varshape)
self.n.value = variable(self._n_initializer, batch_size, self.varshape)
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
self.h.value = variable_(self._h_initializer, self.varshape, batch_size)
self.n.value = variable_(self._n_initializer, self.varshape, batch_size)
self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
self.input.value = variable_(bm.zeros, self.varshape, batch_size)
self.spike.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)

def m_inf(self, V):
alpha = -0.1 * (V + 35) / (bm.exp(-0.1 * (V + 35)) - 1)
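
Most of the edits in this file follow two patterns: `check(...)` from `brainpy.modes` becomes `check_mode(...)`, and variables are built with `variable_(data, shape, mode)` instead of `variable(data, mode, shape)`, so the shape now precedes the batch-size/mode argument. A minimal sketch of the two call orders (hypothetical shape, no batching):

import brainpy.math as bm
from brainpy.initialize import variable, variable_

shape = (100,)                            # hypothetical variable shape
v_old = variable(bm.zeros, None, shape)   # old order: batch size/mode first, then shape
v_new = variable_(bm.zeros, shape, None)  # new order: shape first, then batch size/mode
assert v_old.shape == v_new.shape == shape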


+ 8
- 6
brainpy/dyn/neurons/input_groups.py

@@ -7,10 +7,11 @@ import jax.numpy as jnp
import brainpy.math as bm
from brainpy.dyn.base import NeuGroup
from brainpy.errors import ModelBuildError
from brainpy.initialize import Initializer, parameter, variable
from brainpy.initialize import Initializer, parameter, variable_
from brainpy.modes import Mode, BatchingMode, normal
from brainpy.types import Shape, Array


__all__ = [
'InputGroup',
'OutputGroup',
@@ -138,7 +139,7 @@ class SpikeTimeGroup(NeuGroup):

# variables
self.i = bm.Variable(bm.zeros(1, dtype=bm.ditype()))
self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode)
if need_sort:
sort_idx = bm.argsort(self.times)
self.indices.value = self.indices[sort_idx]
@@ -161,7 +162,7 @@ class SpikeTimeGroup(NeuGroup):

def reset_state(self, batch_size=None):
self.i[0] = 1
self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
self.spike.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)

def update(self, tdi, x=None):
self.spike[:] = False
@@ -192,8 +193,8 @@ class PoissonGroup(NeuGroup):
self.freqs = parameter(freqs, self.num, allow_none=False)

# variables
self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
self.rng = bm.random.RandomState(seed=seed)
self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode)
self.rng = bm.random.RandomState(seed)

def update(self, tdi, x=None):
shape = (self.spike.shape[:1] + self.varshape) if isinstance(self.mode, BatchingMode) else self.varshape
@@ -204,4 +205,5 @@ class PoissonGroup(NeuGroup):
self.reset_state(batch_size)

def reset_state(self, batch_size=None):
self.spike.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
self.spike.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)
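
The input groups get the same `variable_` treatment, and `PoissonGroup` now passes `seed` positionally to `bm.random.RandomState`. A rough usage sketch (assuming `PoissonGroup` is exposed under `bp.neurons` like the `LIF` used in the tests below; numbers are illustrative):

import brainpy as bp

pg = bp.neurons.PoissonGroup(50, freqs=20., seed=42)  # 50 independent 20 Hz Poisson units
pg.reset_state()                                      # rebuilds the boolean `spike` variable via variable_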


+ 2
- 2
brainpy/dyn/neurons/noise_groups.py

@@ -63,13 +63,13 @@ class OUProcess(NeuGroup):
self.tau = init.parameter(tau, self.varshape, allow_none=False)

# variables
self.x = init.variable(lambda s: bm.ones(s) * self.mean, mode, self.varshape)
self.x = init.variable_(lambda s: bm.ones(s) * self.mean, self.varshape, mode)

# integral functions
self.integral = sdeint(f=self.df, g=self.dg, method=method)

def reset_state(self, batch_size=None):
self.x.value = init.variable(lambda s: bm.ones(s) * self.mean, batch_size, self.varshape)
self.x.value = init.variable_(lambda s: bm.ones(s) * self.mean, self.varshape, batch_size)

def df(self, x, t):
return (self.mean - x) / self.tau


+ 148
- 129
brainpy/dyn/neurons/reduced_models.py

@@ -1,15 +1,16 @@
# -*- coding: utf-8 -*-

from typing import Union, Callable
from typing import Union, Callable, Optional
from functools import partial

from jax.lax import stop_gradient

import brainpy.math as bm
from brainpy.dyn.base import NeuGroup
from brainpy.initialize import (ZeroInit, OneInit, Initializer,
parameter, variable, noise as init_noise)
parameter, variable_, noise as init_noise)
from brainpy.integrators import sdeint, odeint, JointEq
from brainpy.modes import Mode, NormalMode, BatchingMode, TrainingMode, normal, check
from brainpy.modes import Mode, NormalMode, BatchingMode, TrainingMode, normal, check_mode
from brainpy.tools.checking import check_initializer, check_callable
from brainpy.types import Shape, Array

@@ -87,7 +88,7 @@ class LeakyIntegrator(NeuGroup):
mode=mode,
keep_size=keep_size,
name=name)
check(self.mode, (TrainingMode, NormalMode), self.__class__)
check_mode(self.mode, (TrainingMode, NormalMode), self.__class__)

# parameters
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
@@ -100,8 +101,8 @@ class LeakyIntegrator(NeuGroup):
self._V_initializer = V_initializer

# variables
self.V = variable(self._V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
self.V = variable_(self._V_initializer, self.varshape, mode)
self.input = variable_(bm.zeros, self.varshape, mode)

# integral
if self.noise is None:
@@ -113,8 +114,8 @@ class LeakyIntegrator(NeuGroup):
return (-V + self.V_rest + self.R * I_ext) / self.tau

def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
self.input.value = variable_(bm.zeros, self.varshape, batch_size)

def update(self, tdi, x=None):
if x is not None: self.input += x
@@ -191,11 +192,11 @@ class LIF(NeuGroup):
V_th: Union[float, Array, Initializer, Callable] = 20.,
R: Union[float, Array, Initializer, Callable] = 1.,
tau: Union[float, Array, Initializer, Callable] = 10.,
tau_ref: Union[float, Array, Initializer, Callable] = None,
tau_ref: Optional[Union[float, Array, Initializer, Callable]] = None,
V_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
noise: Union[float, Array, Initializer, Callable] = None,
noise: Optional[Union[float, Array, Initializer, Callable]] = None,
method: str = 'exp_auto',
name: str = None,
name: Optional[str] = None,

# training parameter
mode: Mode = normal,
@@ -206,7 +207,7 @@ class LIF(NeuGroup):
name=name,
keep_size=keep_size,
mode=mode)
check(self.mode, (TrainingMode, NormalMode), self.__class__)
check_mode(self.mode, (TrainingMode, NormalMode), self.__class__)

# parameters
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
@@ -223,13 +224,13 @@ class LIF(NeuGroup):
self._V_initializer = V_initializer

# variables
self.V = variable(self._V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
self.V = variable_(self._V_initializer, self.varshape, mode)
self.input = variable_(bm.zeros, self.varshape, mode)
sp_type = bm.dftype() if isinstance(mode, TrainingMode) else bool # the gradient of spike is a float
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode)
if self.tau_ref is not None:
self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape)
self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, mode)
self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode)

# integral
if self.noise is None:
@@ -241,13 +242,13 @@ class LIF(NeuGroup):
return (-V + self.V_rest + self.R * I_ext) / self.tau

def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
self.input.value = variable_(bm.zeros, self.varshape, batch_size)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)
if self.tau_ref is not None:
self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape)
self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
self.t_last_spike.value = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size)
self.refractory.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)

def update(self, tdi, x=None):
t, dt = tdi.t, tdi.dt
@@ -419,7 +420,7 @@ class ExpIF(NeuGroup):
name=name,
mode=mode,
keep_size=keep_size, )
check(self.mode, (TrainingMode, NormalMode), self.__class__)
check_mode(self.mode, (TrainingMode, NormalMode), self.__class__)

# parameters
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
@@ -437,13 +438,13 @@ class ExpIF(NeuGroup):
self._V_initializer = V_initializer

# variables
self.V = variable(V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
self.V = variable_(V_initializer, self.varshape, mode)
self.input = variable_(bm.zeros, self.varshape, mode)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape)
self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode)
self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, mode)
if self.tau_ref is not None:
self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode)

# integral
if self.noise is None:
@@ -452,13 +453,13 @@ class ExpIF(NeuGroup):
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
self.input.value = variable_(bm.zeros, self.varshape, batch_size)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape)
self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)
self.t_last_spike.value = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size)
if self.tau_ref is not None:
self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
self.refractory.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)

def derivative(self, V, t, I_ext):
exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T)
@@ -541,6 +542,7 @@ class AdExIF(NeuGroup):
R 1 \ Membrane resistance.
tau 10 ms Membrane time constant. Compute by R * C.
tau_w 30 ms Time constant of the adaptation current.
tau_ref 0. ms Refractory time.
============= ============== ======== ========================================================================================================================

**Model Variables**
@@ -552,6 +554,7 @@ class AdExIF(NeuGroup):
w 0 Adaptation current.
input 0 External and synaptic input current.
spike False Flag to mark whether the neuron is spiking.
refractory False Flag to mark whether the neuron is in refractory period.
t_last_spike -1e7 Last spike time stamp.
================== ================= =========================================================

@@ -575,32 +578,34 @@ class AdExIF(NeuGroup):
b: Union[float, Array, Initializer, Callable] = 1.,
tau: Union[float, Array, Initializer, Callable] = 10.,
tau_w: Union[float, Array, Initializer, Callable] = 30.,
tau_ref: Optional[Union[float, Array, Initializer, Callable]] = 30.,
R: Union[float, Array, Initializer, Callable] = 1.,
V_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
w_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
noise: Union[float, Array, Initializer, Callable] = None,
noise: Optional[Union[float, Array, Initializer, Callable]] = None,
method: str = 'exp_auto',
keep_size: bool = False,
mode: Mode = normal,
name: str = None
name: Optional[str] = None
):
super(AdExIF, self).__init__(size=size,
keep_size=keep_size,
name=name,
mode=mode, )
check(self.mode, (TrainingMode, NormalMode), self.__class__)
check_mode(self.mode, (TrainingMode, NormalMode), self.__class__)

# parameters
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
self.V_reset = parameter(V_reset, self.varshape, allow_none=False)
self.V_th = parameter(V_th, self.varshape, allow_none=False)
self.V_T = parameter(V_T, self.varshape, allow_none=False)
self.delta_T = parameter(delta_T, self.varshape, allow_none=False)
self.a = parameter(a, self.varshape, allow_none=False)
self.b = parameter(b, self.varshape, allow_none=False)
self.R = parameter(R, self.varshape, allow_none=False)
self.tau = parameter(tau, self.varshape, allow_none=False)
self.tau_w = parameter(tau_w, self.varshape, allow_none=False)
self.R = parameter(R, self.varshape, allow_none=False)
self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True)
self.delta_T = parameter(delta_T, self.varshape, allow_none=False)
self.noise = init_noise(noise, self.varshape, num_vars=2)

# initializers
@@ -610,12 +615,15 @@ class AdExIF(NeuGroup):
self._w_initializer = w_initializer

# variables
self.V = variable(V_initializer, mode, self.varshape)
self.w = variable(w_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
self.V = variable_(V_initializer, self.varshape, mode)
self.w = variable_(w_initializer, self.varshape, mode)
self.input = variable_(bm.zeros, self.varshape, mode)
sp_type = bm.dftype() if isinstance(mode, BatchingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)

self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode)
if self.tau_ref is not None:
self.refractory = variable_(partial(bm.zeros, dtype=bool), self.varshape, mode)
self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e8, self.varshape, mode)
# functions
if self.noise is None:
self.integral = odeint(method=method, f=self.derivative)
@@ -623,11 +631,16 @@ class AdExIF(NeuGroup):
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.w.value = variable(self._w_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
self.w.value = variable_(self._w_initializer, self.varshape, batch_size)
self.input.value = variable_(bm.zeros, self.varshape, batch_size)
self.spike.value = variable_(lambda s: bm.zeros(s, dtype=(bm.dftype()
if isinstance(self.mode, TrainingMode)
else bool)),
self.varshape, batch_size)
if self.tau_ref is not None:
self.refractory.value = variable_(partial(bm.zeros, dtype=bool), self.varshape, batch_size)
self.t_last_spike.value = variable_(lambda s: bm.ones(s) * -1e8, self.varshape, batch_size)

def dV(self, V, t, w, I_ext):
exp = self.delta_T * bm.exp((V - self.V_T) / self.delta_T)
@@ -646,10 +659,16 @@ class AdExIF(NeuGroup):
t, dt = tdi.t, tdi.dt
if x is not None: self.input += x
V, w = self.integral(self.V.value, self.w.value, t, self.input.value, dt)
if self.tau_ref is not None:
refractory = (t - self.t_last_spike) <= self.tau_ref
V = bm.where(refractory, self.V.value, V)
spike = V >= self.V_th
self.V.value = bm.where(spike, self.V_reset, V)
self.w.value = bm.where(spike, w + self.b, w)
self.spike.value = spike
if self.tau_ref is not None:
self.refractory.value = bm.logical_or(refractory, spike)
self.t_last_spike.value = bm.where(spike, t, self.t_last_spike)

def clear_input(self):
self.input[:] = 0.
@@ -745,7 +764,7 @@ class QuaIF(NeuGroup):
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (TrainingMode, NormalMode), self.__class__)
check_mode(self.mode, (TrainingMode, NormalMode), self.__class__)

# parameters
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
@@ -763,13 +782,13 @@ class QuaIF(NeuGroup):
self._V_initializer = V_initializer

# variables
self.V = variable(V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
self.V = variable_(V_initializer, self.varshape, mode)
self.input = variable_(bm.zeros, self.varshape, mode)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape)
self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode)
self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, mode)
if self.tau_ref is not None:
self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode)

# integral
if self.noise is None:
@@ -778,13 +797,13 @@ class QuaIF(NeuGroup):
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
self.input.value = variable_(bm.zeros, self.varshape, batch_size)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape)
self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)
self.t_last_spike.value = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size)
if self.tau_ref is not None:
self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
self.refractory.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)

def derivative(self, V, t, I_ext):
dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I_ext) / self.tau
@@ -914,7 +933,7 @@ class AdQuaIF(NeuGroup):
keep_size=keep_size,
name=name,
mode=mode, )
check(self.mode, (TrainingMode, NormalMode), self.__class__)
check_mode(self.mode, (TrainingMode, NormalMode), self.__class__)

# parameters
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
@@ -935,12 +954,12 @@ class AdQuaIF(NeuGroup):
self._w_initializer = w_initializer

# variables
self.V = variable(V_initializer, mode, self.varshape)
self.w = variable(w_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
self.V = variable_(V_initializer, self.varshape, mode)
self.w = variable_(w_initializer, self.varshape, mode)
self.input = variable_(bm.zeros, self.varshape, mode)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode)
self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode)

# integral
if self.noise is None:
@@ -949,12 +968,12 @@ class AdQuaIF(NeuGroup):
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.w.value = variable(self._w_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
self.w.value = variable_(self._w_initializer, self.varshape, batch_size)
self.input.value = variable_(bm.zeros, self.varshape, batch_size)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)
self.refractory.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)

def dV(self, V, t, w, I_ext):
dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I_ext) / self.tau
@@ -1098,7 +1117,7 @@ class GIF(NeuGroup):
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (TrainingMode, NormalMode), self.__class__)
check_mode(self.mode, (TrainingMode, NormalMode), self.__class__)

# params
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
@@ -1129,13 +1148,13 @@ class GIF(NeuGroup):
self._Vth_initializer = Vth_initializer

# variables
self.I1 = variable(I1_initializer, mode, self.varshape)
self.I2 = variable(I2_initializer, mode, self.varshape)
self.V_th = variable(Vth_initializer, mode, self.varshape)
self.V = variable(V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
self.I1 = variable_(I1_initializer, self.varshape, mode)
self.I2 = variable_(I2_initializer, self.varshape, mode)
self.V_th = variable_(Vth_initializer, self.varshape, mode)
self.V = variable_(V_initializer, self.varshape, mode)
self.input = variable_(bm.zeros, self.varshape, mode)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode)

# integral
if self.noise is None:
@@ -1144,13 +1163,13 @@ class GIF(NeuGroup):
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def reset_state(self, batch_size=None):
self.I1.value = variable(self._I1_initializer, batch_size, self.varshape)
self.I2.value = variable(self._I2_initializer, batch_size, self.varshape)
self.V_th.value = variable(self._Vth_initializer, batch_size, self.varshape)
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
self.I1.value = variable_(self._I1_initializer, self.varshape, batch_size)
self.I2.value = variable_(self._I2_initializer, self.varshape, batch_size)
self.V_th.value = variable_(self._Vth_initializer, self.varshape, batch_size)
self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
self.input.value = variable_(bm.zeros, self.varshape, batch_size)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)

def dI1(self, I1, t):
return - self.k1 * I1
@@ -1263,7 +1282,7 @@ class ALIFBellec2020(NeuGroup):
size=size,
keep_size=keep_size,
mode=mode)
check(self.mode, (TrainingMode, NormalMode), self.__class__)
check_mode(self.mode, (TrainingMode, NormalMode), self.__class__)

# parameters
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
@@ -1284,14 +1303,14 @@ class ALIFBellec2020(NeuGroup):
self._a_initializer = a_initializer

# variables
self.a = variable(a_initializer, mode, self.varshape)
self.V = variable(V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
self.a = variable_(a_initializer, self.varshape, mode)
self.V = variable_(V_initializer, self.varshape, mode)
self.input = variable_(bm.zeros, self.varshape, mode)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode)
if self.tau_ref is not None:
self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape)
self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, mode)
self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode)

# integral
if self.noise is None:
@@ -1310,14 +1329,14 @@ class ALIFBellec2020(NeuGroup):
return JointEq([self.dV, self.da])

def reset_state(self, batch_size=None):
self.a.value = variable(self._a_initializer, batch_size, self.varshape)
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
self.a.value = variable_(self._a_initializer, self.varshape, batch_size)
self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
self.input.value = variable_(bm.zeros, self.varshape, batch_size)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)
if self.tau_ref is not None:
self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape)
self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
self.t_last_spike.value = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size)
self.refractory.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)

def update(self, tdi, x=None):
t, dt = tdi.t, tdi.dt
@@ -1455,7 +1474,7 @@ class Izhikevich(NeuGroup):
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (TrainingMode, NormalMode), self.__class__)
check_mode(self.mode, (TrainingMode, NormalMode), self.__class__)

# params
self.a = parameter(a, self.varshape, allow_none=False)
@@ -1474,14 +1493,14 @@ class Izhikevich(NeuGroup):
self._u_initializer = u_initializer

# variables
self.u = variable(u_initializer, mode, self.varshape)
self.V = variable(V_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
self.u = variable_(u_initializer, self.varshape, mode)
self.V = variable_(V_initializer, self.varshape, mode)
self.input = variable_(bm.zeros, self.varshape, mode)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode)
if self.tau_ref is not None:
self.t_last_spike = variable(lambda s: bm.ones(s) * -1e7, mode, self.varshape)
self.refractory = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, mode)
self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, mode)

# functions
if self.noise is None:
@@ -1490,14 +1509,14 @@ class Izhikevich(NeuGroup):
self.integral = sdeint(method=method, f=JointEq([self.dV, self.du]), g=self.noise)

def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.u.value = variable(self._u_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
self.u.value = variable_(self._u_initializer, self.varshape, batch_size)
self.input.value = variable_(bm.zeros, self.varshape, batch_size)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)
if self.tau_ref is not None:
self.t_last_spike.value = variable(lambda s: bm.ones(s) * -1e7, batch_size, self.varshape)
self.refractory.value = variable(lambda s: bm.zeros(s, dtype=bool), batch_size, self.varshape)
self.t_last_spike.value = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size)
self.refractory.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size)

def dV(self, V, t, u, I_ext):
dVdt = 0.04 * V * V + 5 * V + 140 - u + I_ext
@@ -1685,7 +1704,7 @@ class HindmarshRose(NeuGroup):
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (TrainingMode, NormalMode), self.__class__)
check_mode(self.mode, (TrainingMode, NormalMode), self.__class__)

# parameters
self.a = parameter(a, self.varshape, allow_none=False)
@@ -1708,12 +1727,12 @@ class HindmarshRose(NeuGroup):
self._z_initializer = z_initializer

# variables
self.V = variable(self._V_initializer, mode, self.varshape)
self.y = variable(self._y_initializer, mode, self.varshape)
self.z = variable(self._z_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
self.V = variable_(self._V_initializer, self.varshape, mode)
self.y = variable_(self._y_initializer, self.varshape, mode)
self.z = variable_(self._z_initializer, self.varshape, mode)
self.input = variable_(bm.zeros, self.varshape, mode)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode)

# integral
if self.noise is None:
@@ -1722,12 +1741,12 @@ class HindmarshRose(NeuGroup):
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.y.value = variable(self._y_initializer, batch_size, self.varshape)
self.z.value = variable(self._z_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
self.y.value = variable_(self._y_initializer, self.varshape, batch_size)
self.z.value = variable_(self._z_initializer, self.varshape, batch_size)
self.input.value = variable_(bm.zeros, self.varshape, batch_size)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)

def dV(self, V, t, y, z, I_ext):
return y - self.a * V * V * V + self.b * V * V - z + I_ext
@@ -1864,7 +1883,7 @@ class FHN(NeuGroup):
keep_size=keep_size,
name=name,
mode=mode)
check(self.mode, (TrainingMode, NormalMode), self.__class__)
check_mode(self.mode, (TrainingMode, NormalMode), self.__class__)

# parameters
self.a = parameter(a, self.varshape, allow_none=False)
@@ -1881,11 +1900,11 @@ class FHN(NeuGroup):
self._w_initializer = w_initializer

# variables
self.V = variable(self._V_initializer, mode, self.varshape)
self.w = variable(self._w_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
self.V = variable_(self._V_initializer, self.varshape, mode)
self.w = variable_(self._w_initializer, self.varshape, mode)
self.input = variable_(bm.zeros, self.varshape, mode)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike = variable(lambda s: bm.zeros(s, dtype=sp_type), mode, self.varshape)
self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode)

# integral
if self.noise is None:
@@ -1894,11 +1913,11 @@ class FHN(NeuGroup):
self.integral = sdeint(method=method, f=self.derivative, g=self.noise)

def reset_state(self, batch_size=None):
self.V.value = variable(self._V_initializer, batch_size, self.varshape)
self.w.value = variable(self._w_initializer, batch_size, self.varshape)
self.input.value = variable(bm.zeros, batch_size, self.varshape)
self.V.value = variable_(self._V_initializer, self.varshape, batch_size)
self.w.value = variable_(self._w_initializer, self.varshape, batch_size)
self.input.value = variable_(bm.zeros, self.varshape, batch_size)
sp_type = bm.dftype() if isinstance(self.mode, TrainingMode) else bool
self.spike.value = variable(lambda s: bm.zeros(s, dtype=sp_type), batch_size, self.varshape)
self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size)

def dV(self, V, t, w, I_ext):
return V - V * V * V / 3 - w + I_ext
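
Besides the `variable_`/`check_mode` renames, the substantive change in this file is the optional refractory period added to `AdExIF`: a `tau_ref` parameter plus `refractory` and `t_last_spike` variables, with `V` held fixed while `(t - t_last_spike) <= tau_ref` in `update`. A hedged usage sketch, mirroring how the tests below drive models with `DSRunner` (numbers are illustrative):

import brainpy as bp

neu = bp.neurons.AdExIF(10, tau_ref=2.)   # 2 ms refractory period; tau_ref=None disables it
runner = bp.DSRunner(neu, monitors=['V', 'spike'], inputs=[(neu.input, 10.)])
runner.run(200.)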


+ 12
- 12
brainpy/dyn/rates/populations.py

@@ -180,8 +180,8 @@ class FHN(RateModel):
self.y.value = y

def clear_input(self):
self.input[:] = 0.
self.input_y[:] = 0.
self.input.value = bm.zeros_like(self.input)
self.input_y.value = bm.zeros_like(self.input_y)


class FeedbackFHN(RateModel):
@@ -375,8 +375,8 @@ class FeedbackFHN(RateModel):
self.y.value = y

def clear_input(self):
self.input[:] = 0.
self.input_y[:] = 0.
self.input.value = bm.zeros_like(self.input)
self.input_y.value = bm.zeros_like(self.input_y)


class QIF(RateModel):
@@ -558,8 +558,8 @@ class QIF(RateModel):
self.y.value = y

def clear_input(self):
self.input[:] = 0.
self.input_y[:] = 0.
self.input.value = bm.zeros_like(self.input)
self.input_y.value = bm.zeros_like(self.input_y)


class StuartLandauOscillator(RateModel):
@@ -700,8 +700,8 @@ class StuartLandauOscillator(RateModel):
self.y.value = y

def clear_input(self):
self.input[:] = 0.
self.input_y[:] = 0.
self.input.value = bm.zeros_like(self.input)
self.input_y.value = bm.zeros_like(self.input_y)


class WilsonCowanModel(RateModel):
@@ -857,8 +857,8 @@ class WilsonCowanModel(RateModel):
self.y.value = y

def clear_input(self):
self.input[:] = 0.
self.input_y[:] = 0.
self.input.value = bm.zeros_like(self.input)
self.input_y.value = bm.zeros_like(self.input_y)


class JansenRitModel(RateModel):
@@ -976,5 +976,5 @@ class ThresholdLinearModel(RateModel):
self.i.value = bm.maximum(self.i + di * dt, 0.)

def clear_input(self):
self.Ie[:] = 0.
self.Ii[:] = 0.
self.Ie.value = bm.zeros_like(self.Ie)
self.Ii.value = bm.zeros_like(self.Ii)
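
`clear_input` in the rate models now rebinds `.value` to a freshly allocated zeros array rather than writing zeros in place. Both idioms on a `bm.Variable`, as a small sketch:

import brainpy.math as bm

inp = bm.Variable(bm.ones(5))
inp[:] = 0.                        # in-place write into the existing buffer
inp.value = bm.zeros_like(inp)     # rebind the variable to a new zeros array, as done above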

+ 4
- 5
brainpy/dyn/runners.py

@@ -566,8 +566,7 @@ class DSRunner(Runner):

monitor_func = self.build_monitors(self._mon_info[0], self._mon_info[1], shared_args)

def _step_func(inputs):
t, i, x = inputs
def _step_func(t, i, x):
self.target.clear_input()
# input step
shared = DotDict(t=t, i=i, dt=self.dt)
@@ -586,8 +585,8 @@ class DSRunner(Runner):
if self.jit['predict']:
dyn_vars = self.target.vars()
dyn_vars.update(self.dyn_vars)
f = bm.make_loop(_step_func, dyn_vars=dyn_vars.unique(), has_return=True)
run_func = lambda all_inputs: f(all_inputs)[1]
dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
run_func = lambda all_inputs: bm.for_loop(_step_func, dyn_vars.unique(), all_inputs)

else:
def run_func(xs):
@@ -601,7 +600,7 @@ class DSRunner(Runner):
x = tree_map(lambda x: x[i], xs, is_leaf=lambda x: isinstance(x, bm.JaxArray))

# step at the i
output, mon = _step_func((times[i], indices[i], x))
output, mon = _step_func(times[i], indices[i], x)

# append output and monitor
outputs.append(output)
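
The jitted prediction loop is now built with `bm.for_loop(step, dyn_vars, operands)` instead of `bm.make_loop`, and `_step_func` receives `(t, i, x)` as positional arguments. A minimal sketch of that calling convention (hypothetical body function and variable, following the `(body, dyn_vars, operands)` order used above):

import brainpy.math as bm

acc = bm.Variable(bm.zeros(1))

def body(x):                  # one positional argument per operand
    acc.value += x
    return acc.value

outs = bm.for_loop(body, [acc], bm.arange(5.))  # returns the stacked per-step return values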


+ 117
- 29
brainpy/dyn/synapses/abstract_models.py

@@ -7,19 +7,22 @@ from jax.lax import stop_gradient

import brainpy.math as bm
from brainpy.connect import TwoEndConnector, All2All, One2One
from brainpy.dyn.base import NeuGroup, SynOut, SynSTP, TwoEndConn
from brainpy.initialize import Initializer, variable
from brainpy.dyn.base import NeuGroup, SynOut, SynSTP, TwoEndConn, SynConn
from brainpy.initialize import Initializer, variable_
from brainpy.integrators import odeint, JointEq
from brainpy.modes import Mode, BatchingMode, normal
from brainpy.tools.checking import check_integer, check_float
from brainpy.modes import Mode, BatchingMode, normal, NormalMode, check_mode
from brainpy.types import Array
from ..synouts import CUBA, MgBlock


__all__ = [
'Delta',
'Exponential',
'DualExponential',
'Alpha',
'NMDA',
'PoissonInput',
]


@@ -116,7 +119,7 @@ class Delta(TwoEndConn):
self.comp_method = comp_method

# connections and weights
self.g_max, self.conn_mask = self.init_weights(g_max, comp_method=comp_method, sparse_data='csr')
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method=comp_method, sparse_data='csr')

# register delay
self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)
@@ -140,10 +143,10 @@ class Delta(TwoEndConn):
# synaptic values onto the post
if isinstance(self.conn, All2All):
syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype()))
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype()))
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
else:
if self.comp_method == 'sparse':
f = lambda s: bm.pre2post_event_sum(s, self.conn_mask, self.post.num, self.g_max)
@@ -157,7 +160,7 @@ class Delta(TwoEndConn):
# post_vs *= f2(stp_value)
else:
syn_value = self.stp(bm.asarray(pre_spike, dtype=bm.dftype()))
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
if self.post_ref_key:
post_vs = post_vs * (1. - getattr(self.post, self.post_ref_key))

@@ -293,17 +296,17 @@ class Exponential(TwoEndConn):
raise ValueError(f'"tau" must be a scalar or a tensor with size of 1. But we got {self.tau}')

# connections and weights
self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='csr')
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='csr')

# variables
self.g = variable(bm.zeros, mode, self.post.num)
self.g = variable_(bm.zeros, self.post.num, mode)
self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)

# function
self.integral = odeint(lambda g, t: -g / self.tau, method=method)

def reset_state(self, batch_size=None):
self.g.value = variable(bm.zeros, batch_size, self.post.num)
self.g.value = variable_(bm.zeros, self.post.num, batch_size)
self.output.reset_state(batch_size)
if self.stp is not None: self.stp.reset_state(batch_size)

@@ -325,11 +328,11 @@ class Exponential(TwoEndConn):
if isinstance(self.conn, All2All):
syn_value = bm.asarray(pre_spike, dtype=bm.dftype())
if self.stp is not None: syn_value = self.stp(syn_value)
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
syn_value = bm.asarray(pre_spike, dtype=bm.dftype())
if self.stp is not None: syn_value = self.stp(syn_value)
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
else:
if self.comp_method == 'sparse':
f = lambda s: bm.pre2post_event_sum(s, self.conn_mask, self.post.num, self.g_max)
@@ -340,7 +343,7 @@ class Exponential(TwoEndConn):
else:
syn_value = bm.asarray(pre_spike, dtype=bm.dftype())
if self.stp is not None: syn_value = self.stp(syn_value)
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
# updates
self.g.value = self.integral(self.g.value, t, dt) + post_vs

@@ -484,19 +487,19 @@ class DualExponential(TwoEndConn):
f'But we got {self.tau_decay}')

# connections
self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij')
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='ij')

# variables
self.h = variable(bm.zeros, mode, self.pre.num)
self.g = variable(bm.zeros, mode, self.pre.num)
self.h = variable_(bm.zeros, self.pre.num, mode)
self.g = variable_(bm.zeros, self.pre.num, mode)
self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)

# integral
self.integral = odeint(method=method, f=JointEq([self.dg, self.dh]))

def reset_state(self, batch_size=None):
self.h.value = variable(bm.zeros, batch_size, self.pre.num)
self.g.value = variable(bm.zeros, batch_size, self.pre.num)
self.h.value = variable_(bm.zeros, self.pre.num, batch_size)
self.g.value = variable_(bm.zeros, self.pre.num, batch_size)
self.output.reset_state(batch_size)
if self.stp is not None: self.stp.reset_state(batch_size)

@@ -528,16 +531,16 @@ class DualExponential(TwoEndConn):
syn_value = self.g.value
if self.stp is not None: syn_value = self.stp(syn_value)
if isinstance(self.conn, All2All):
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
else:
if self.comp_method == 'sparse':
f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask)
if isinstance(self.mode, BatchingMode): f = vmap(f)
post_vs = f(syn_value)
else:
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)

# output
return self.output(post_vs)
@@ -826,11 +829,11 @@ class NMDA(TwoEndConn):
self.stop_spike_gradient = stop_spike_gradient

# connections and weights
self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij')
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='ij')

# variables
self.g = variable(bm.zeros, mode, self.pre.num)
self.x = variable(bm.zeros, mode, self.pre.num)
self.g = variable_(bm.zeros, self.pre.num, mode)
self.x = variable_(bm.zeros, self.pre.num, mode)
self.delay_step = self.register_delay(f"{self.pre.name}.spike", delay_step, self.pre.spike)

# integral
@@ -843,8 +846,8 @@ class NMDA(TwoEndConn):
return -x / self.tau_rise

def reset_state(self, batch_size=None):
self.g.value = variable(bm.zeros, batch_size, self.pre.num)
self.x.value = variable(bm.zeros, batch_size, self.pre.num)
self.g.value = variable_(bm.zeros, self.pre.num, batch_size)
self.x.value = variable_(bm.zeros, self.pre.num, batch_size)
self.output.reset_state(batch_size)
if self.stp is not None: self.stp.reset_state(batch_size)

@@ -869,16 +872,101 @@ class NMDA(TwoEndConn):
syn_value = self.g.value
if self.stp is not None: syn_value = self.stp(syn_value)
if isinstance(self.conn, All2All):
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
else:
if self.comp_method == 'sparse':
f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask)
if isinstance(self.mode, BatchingMode): f = vmap(f)
post_vs = f(syn_value)
else:
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)

# output
return self.output(post_vs)


class PoissonInput(SynConn):
"""Poisson Input to the given `Variable`.

Adds independent Poisson input to a target variable. For large
numbers of inputs, this is much more efficient than creating a
`PoissonGroup`. The synaptic events are generated randomly during the
simulation and are not preloaded and stored in memory. All the inputs must
target the same variable, have the same frequency, and share the same synaptic weight.
All neurons in the target variable receive independent realizations of
Poisson spike trains.

Parameters
----------
target_var: Variable
The variable that is targeted by this input.
num_input: int
The number of inputs.
freq: float
The frequency of each of the inputs. Must be a scalar.
weight: float
The synaptic weight. Must be a scalar.
"""

def __init__(
self,
target_var: bm.Variable,
num_input: int,
freq: Union[int, float],
weight: Union[int, float],
seed: Optional[int] = None,
mode: Mode = normal,
name: str = None
):
from ..neurons.input_groups import InputGroup, OutputGroup
super(PoissonInput, self).__init__(InputGroup(1), OutputGroup(1), name=name, mode=mode)
self.pre = None
self.post = None

# check data
if not isinstance(target_var, bm.Variable):
raise TypeError(f'"target_var" must be an instance of Variable. '
f'But we got {type(target_var)}: {target_var}')
check_integer(num_input, 'num_input', min_bound=1)
check_float(freq, 'freq', min_bound=0., allow_int=True)
check_float(weight, 'weight', allow_int=True)
check_mode(mode, (NormalMode, BatchingMode), name=self.__class__.__name__)

# parameters
self.target_var = target_var
self.num_input = num_input
self.freq = freq
self.weight = weight
self.seed = seed
self.rng = bm.random.RandomState(self.seed)

def update(self, tdi):
p = self.freq * tdi.dt / 1e3
a = self.num_input * p
b = self.num_input * (1 - p)
if isinstance(tdi.dt, (int, float)): # dt is not in tracing
if (a > 5) and (b > 5):
inp = self.rng.normal(a, b * p, self.target_var.shape)
else:
inp = self.rng.binomial(self.num_input, p, self.target_var.shape)

else: # dt is in tracing
inp = bm.cond((a > 5) * (b > 5),
lambda _: self.rng.normal(a, b * p, self.target_var.shape),
lambda _: self.rng.binomial(self.num_input, p, self.target_var.shape),
None)
self.target_var += inp * self.weight

def __repr__(self):
names = self.__class__.__name__
return f'{names}(name={self.name}, num_input={self.num_input}, freq={self.freq}, weight={self.weight})'

def reset_state(self, batch_size=None):
pass

def reset(self, batch_size=None):
self.rng.seed(self.seed)
self.reset_state(batch_size)
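
A rough usage sketch for the new `PoissonInput` (assuming it is exported alongside the other synapse models here, e.g. as `bp.synapses.PoissonInput`; the target variable and numbers are illustrative):

import brainpy as bp

lif = bp.neurons.LIF(100)
poisson = bp.synapses.PoissonInput(lif.V, num_input=200, freq=15., weight=0.1)
net = bp.Network(lif, poisson)
bp.DSRunner(net).run(100.)   # every LIF neuron receives an independent 200-input Poisson drive on V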


+ 8
- 8
brainpy/dyn/synapses/biological_models.py

@@ -181,7 +181,7 @@ class AMPA(TwoEndConn):
raise ValueError(f'"T_duration" must be a scalar or a tensor with size of 1. But we got {T_duration}')

# connection
self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij')
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='ij')

# variables
self.g = variable(bm.zeros, mode, self.pre.num)
@@ -226,16 +226,16 @@ class AMPA(TwoEndConn):
syn_value = self.g.value
if self.stp is not None: syn_value = self.stp(syn_value)
if isinstance(self.conn, All2All):
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
else:
if self.comp_method == 'sparse':
f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask)
if isinstance(self.mode, BatchingMode): f = vmap(f)
post_vs = f(syn_value)
else:
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)

# output
return self.output(post_vs)
@@ -526,7 +526,7 @@ class BioNMDA(TwoEndConn):
self.stop_spike_gradient = stop_spike_gradient

# connections and weights
self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='ij')
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, sparse_data='ij')

# variables
self.g = variable(bm.zeros, mode, self.pre.num)
@@ -575,16 +575,16 @@ class BioNMDA(TwoEndConn):
syn_value = self.g.value
if self.stp is not None: syn_value = self.stp(syn_value)
if isinstance(self.conn, All2All):
post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
elif isinstance(self.conn, One2One):
post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
else:
if self.comp_method == 'sparse':
f = lambda s: bm.pre2post_sum(s, self.post.num, *self.conn_mask)
if isinstance(self.mode, BatchingMode): f = vmap(f)
post_vs = f(syn_value)
else:
post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask)

# output
return self.output(post_vs)

+ 2
- 2
brainpy/dyn/synapses/gap_junction.py

@@ -29,8 +29,8 @@ class GapJunction(TwoEndConn):
conn=conn,
name=name)
# checking
self.check_pre_attrs('V', 'spike')
self.check_post_attrs('V', 'input', 'spike')
self.check_pre_attrs('V')
self.check_post_attrs('V', 'input')

# assert isinstance(self.output, _NullSynOut)
# assert isinstance(self.stp, _NullSynSTP)


+ 20
- 0
brainpy/dyn/tests/test_base_classes.py

@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-

import unittest

import brainpy as bp


class TestDynamicalSystem(unittest.TestCase):
def test_delay(self):
A = bp.neurons.LIF(1)
B = bp.neurons.LIF(1)
C = bp.neurons.LIF(1)
A2B = bp.synapses.Exponential(A, B, bp.conn.All2All(), delay_step=1)
A2C = bp.synapses.Exponential(A, C, bp.conn.All2All(), delay_step=None)
net = bp.Network(A, B, C, A2B, A2C)

runner = bp.DSRunner(net,)
runner.run(10.)



+ 47
- 0
brainpy/dyn/tests/test_dyn_runner.py

@@ -32,6 +32,53 @@ class TestDSRunner(unittest.TestCase):
runner = bp.dyn.DSRunner(ExampleDS(), dt=1., monitors=['i'], progress_bar=False)
runner.run(100.)

def test_DSView(self):
class EINet(bp.dyn.Network):
def __init__(self, scale=1.0, method='exp_auto'):
super(EINet, self).__init__()

# network size
num_exc = int(800 * scale)
num_inh = int(200 * scale)

# neurons
pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.)
self.E = bp.neurons.LIF(num_exc, **pars, method=method)
self.I = bp.neurons.LIF(num_inh, **pars, method=method)
self.E.V[:] = bm.random.randn(num_exc) * 2 - 55.
self.I.V[:] = bm.random.randn(num_inh) * 2 - 55.

# synapses
we = 0.6 / scale # excitatory synaptic weight (voltage)
wi = 6.7 / scale # inhibitory synaptic weight
self.E2E = bp.synapses.Exponential(self.E, self.E[:100], bp.conn.FixedProb(0.02),
output=bp.synouts.COBA(E=0.), g_max=we,
tau=5., method=method)
self.E2I = bp.synapses.Exponential(self.E, self.I[:100], bp.conn.FixedProb(0.02),
output=bp.synouts.COBA(E=0.), g_max=we,
tau=5., method=method)
self.I2E = bp.synapses.Exponential(self.I, self.E[:100], bp.conn.FixedProb(0.02),
output=bp.synouts.COBA(E=-80.), g_max=wi,
tau=10., method=method)
self.I2I = bp.synapses.Exponential(self.I, self.I[:100], bp.conn.FixedProb(0.02),
output=bp.synouts.COBA(E=-80.), g_max=wi,
tau=10., method=method)

net = EINet(scale=1., method='exp_auto')
# with JIT
runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike},
inputs=[(net.E.input, 20.), (net.I.input, 20.)]).run(1.)

# without JIT
runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike},
inputs=[(net.E.input, 20.), (net.I.input, 20.)],
jit=False).run(0.2)






# class TestMonitor(TestCase):
# def test_1d_array(self):
# try1 = TryGroup(monitors=['a'])


+ 110
- 10
brainpy/initialize/generic.py View File

@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
import warnings
from typing import Union, Callable, Optional

import jax.numpy as jnp
@@ -11,10 +11,10 @@ from brainpy.types import Shape, Array
from brainpy.modes import Mode, NormalMode, BatchingMode
from .base import Initializer


__all__ = [
'parameter',
'variable',
'variable_',
'noise',
'delay',

@@ -27,6 +27,7 @@ def parameter(
param: Union[Callable, Initializer, bm.ndarray, np.ndarray, jnp.ndarray, float, int, bool],
size: Shape,
allow_none: bool = True,
allow_scalar: bool = True,
):
"""Initialize parameters.

@@ -42,11 +43,17 @@ def parameter(
The shape of the parameter.
allow_none: bool
Whether to allow the parameter to be None.
allow_scalar: bool
Whether to allow the parameter to be a scalar value.

Returns
-------
param: JaxArray, float, None
param: JaxArray, float, int, bool, None
The initialized parameter.

See Also
--------
variable_, noise, delay
"""
if param is None:
if allow_none:
@@ -55,9 +62,9 @@ def parameter(
raise ValueError(f'Expect a parameter with type of float, JaxArray, Initializer, or '
f'Callable function, but we got None. ')
size = to_size(size)
if isinstance(param, (float, int, bool)):
if allow_scalar and isinstance(param, (float, int, bool)):
return param
elif callable(param):
if callable(param):
param = bm.asarray(param(size))
elif isinstance(param, (np.ndarray, jnp.ndarray)):
param = bm.asarray(param)
@@ -67,9 +74,11 @@ def parameter(
param = param
else:
raise ValueError(f'Unknown param type {type(param)}: {param}')
if param.shape != () and param.shape != (1,) and param.shape != size:
raise ValueError(f'The shape of the parameters should be (), (1,) '
f'or {size}, but we got {param.shape}')
if allow_scalar:
if param.shape == () or param.shape == (1,):
return param
if param.shape != size:
raise ValueError(f'The shape of the parameters should be {size}, but we got {param.shape}')
return param
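
Under the new logic, scalar parameters are returned as plain Python numbers, scalar-shaped arrays pass through, and any other array must match ``size`` exactly. A small sketch of the expected behaviour (the ``bp.init.parameter`` import path is an assumption):

import brainpy as bp
import brainpy.math as bm

w1 = bp.init.parameter(0.5, size=(10,))              # scalars pass through when allow_scalar=True (default)
w2 = bp.init.parameter(bm.ones((10,)), size=(10,))   # arrays must match (), (1,), or `size`
w3 = bp.init.parameter(None, size=(10,))             # None is accepted when allow_none=True (default)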


@@ -78,15 +87,81 @@ def init_param(
size: Shape,
allow_none: bool = True,
):
"""Initialize parameters. Same as ``parameter()``.

.. deprecated:: 2.2.3.4
Will be removed in version 2.4.0.
"""
return parameter(param, size, allow_none)


def variable_(
data: Union[Callable, Array],
size: Shape = None,
batch_size_or_mode: Optional[Union[int, bool, Mode]] = None,
batch_axis: int = 0,
):
"""Initialize variables. Same as `variable()`.

Parameters
----------
data: callable, function, Array
The data to be initialized as a ``Variable``.
batch_size_or_mode: int, bool, Mode, optional
The batch size, a computing ``Mode`` instance, or a boolean state.
This is used to specify the batch size of this variable.
If it is a boolean or an instance of ``Mode``, the batch size will be 1.
If it is None, the variable has no batch axis.
size: Shape
The shape of the variable.
batch_axis: int
The batch axis.

Returns
-------
variable: bm.Variable
The target ``Variable`` instance.

See Also
--------
variable, parameter, noise, delay

"""
return variable(data, batch_size_or_mode, size, batch_axis)
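
For reference, a sketch of how the two entry points relate: ``variable_`` simply reorders the arguments before delegating to ``variable``. The import path and the exact batched shapes below are assumptions based on the docstrings:

import brainpy.math as bm
from brainpy.initialize import variable, variable_

v1 = variable(bm.zeros, None, (10,))   # no batch axis -> shape (10,)
v2 = variable(bm.zeros, 4, (10,))      # explicit batch size -> shape (4, 10) with batch_axis=0
v3 = variable_(bm.zeros, (10,), 4)     # same as v2, argument order (data, size, batch_size_or_mode)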


def variable(
data: Union[Callable, Array],
batch_size_or_mode: Optional[Union[int, bool, Mode]] = None,
size: Shape = None,
batch_axis: int = 0,
):
"""Initialize variables.

Parameters
----------
data: callable, function, Array
The data to be initialized as a ``Variable``.
batch_size_or_mode: int, bool, Mode, optional
The batch size, a computing ``Mode`` instance, or a boolean state.
This is used to specify the batch size of this variable.
If it is a boolean or an instance of ``Mode``, the batch size will be 1.
If it is None, the variable has no batch axis.
size: Shape
The shape of the variable.
batch_axis: int
The batch axis.

Returns
-------
variable: bm.Variable
The target ``Variable`` instance.

See Also
--------
variable_, parameter, noise, delay

"""
size = to_size(size)
if callable(data):
if size is None:
@@ -124,11 +199,33 @@ def variable(


def noise(
noises: Optional[Union[int, bm.ndarray, jnp.ndarray, Initializer, Callable]],
noises: Optional[Union[int, float, bm.ndarray, jnp.ndarray, Initializer, Callable]],
size: Shape,
num_vars: int = 1,
noise_idx: int = 0,
) -> Optional[Callable]:
"""Initialize a noise function.

Parameters
----------
noises: int, float, Array, Initializer, Callable, optional
The noise value, or an initializer/callable used to generate it.
size: Shape
The size of the noise.
num_vars: int
The number of variables.
noise_idx: int
The index of the current noise among all noise variables.

Returns
-------
noise_func: function, None
The noise function.

See Also
--------
variable_, parameter, delay

"""
if callable(noises):
return noises
elif noises is None:
@@ -162,6 +259,10 @@ def delay(
-------
info: tuple
The triple of delay type, delay steps, and delay variable.

See Also
--------
variable_, parameter, noise
"""
# check delay type
if delay_step is None:
@@ -198,4 +299,3 @@ def delay(
delays = None

return delay_type, delay_step, delays


+ 2
- 2
brainpy/inputs/currents.py View File

@@ -307,9 +307,9 @@ def ou_process(mean, sigma, tau, duration, dt=None, n=1, t_start=0., t_end=None,

def _f(t):
x.value = x + dt * ((mean - x) / tau) + sigma * dt_sqrt * rng.rand(n)
return x.value

f = bm.make_loop(_f, dyn_vars=[x, rng], out_vars=x)
noises = f(jnp.arange(t_start, t_end, dt))
noises = bm.for_loop(_f, [x, rng], jnp.arange(t_start, t_end, dt))

t_end = duration if t_end is None else t_end
i_start = int(t_start / dt)
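
The hunk above replaces a ``make_loop``-built function with a direct ``bm.for_loop`` call. A self-contained sketch of the same pattern on a toy state update (the decay constant is arbitrary):

import jax.numpy as jnp
import brainpy.math as bm

x = bm.Variable(bm.zeros(1))

def _step(t):
  # one explicit Euler step towards 1.0; `x` is updated in place as a Variable
  x.value = x + 0.1 * (1. - x)
  return x.value

# for_loop(body_fun, dyn_vars, operands) stacks the body's return value over all steps
hist = bm.for_loop(_step, [x], jnp.arange(0., 1., 0.1))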


+ 1
- 1
brainpy/integrators/ode/tests/test_ode_method_adaptive_rk.py View File

@@ -45,7 +45,7 @@ def run_integrator(method, show=False, tol=0.001, adaptive=True):

if show:
fig = plt.figure()
ax = fig.gca(projection='3d')
ax = fig.add_subplot(111, projection='3d')
plt.plot(mon_x, mon_y, mon_z)
ax.set_xlabel('x')
ax.set_ylabel('y')
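
For context: passing ``projection`` to ``Figure.gca`` was deprecated and later removed in recent Matplotlib releases, which is why the tests switch to ``add_subplot``. A stand-alone version of the new call:

import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')   # replacement for the removed fig.gca(projection='3d')
ax.plot([0, 1], [0, 1], [0, 1])
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
plt.close(fig)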


+ 14
- 17
brainpy/integrators/runner.py View File

@@ -217,16 +217,12 @@ class IntegratorRunner(Runner):

# build the update step
if self.jit['predict']:
_loop_func = bm.make_loop(
self._step,
dyn_vars=self.dyn_vars,
out_vars={k: self.variables[k] for k in self.monitors.keys()},
has_return=True
)
def _loop_func(times):
return bm.for_loop(self._step, self.dyn_vars, times)
else:
def _loop_func(times):
out_vars = {k: [] for k in self.monitors.keys()}
returns = {k: [] for k in self.fun_monitors.keys()}
returns.update({k: [] for k in self.monitors.keys()})
for i in range(len(times)):
_t = times[i]
_dt = self.dt
@@ -237,9 +233,9 @@ class IntegratorRunner(Runner):
self._step(_t)
# variable monitors
for k in self.monitors.keys():
out_vars[k].append(bm.as_device_array(self.variables[k]))
out_vars = {k: bm.asarray(out_vars[k]) for k in self.monitors.keys()}
return out_vars, returns
returns[k].append(bm.as_device_array(self.variables[k]))
returns = {k: bm.asarray(returns[k]) for k in returns.keys()}
return returns
self.step_func = _loop_func

def _step(self, t):
@@ -252,11 +248,6 @@ class IntegratorRunner(Runner):
kwargs.update({k: v[self.idx.value] for k, v in self._dyn_args.items()})
self.idx += 1

# return of function monitors
returns = dict()
for key, func in self.fun_monitors.items():
returns[key] = func(t, self.dt)

# call integrator function
update_values = self.target(**kwargs)
if len(self.target.variables) == 1:
@@ -268,6 +259,13 @@ class IntegratorRunner(Runner):
# progress bar
if self.progress_bar:
id_tap(lambda *args: self._pbar.update(), ())

# return of function monitors
returns = dict()
for key, func in self.fun_monitors.items():
returns[key] = func(t, self.dt)
for k in self.monitors.keys():
returns[k] = self.variables[k].value
return returns

def run(self, duration, start_t=None, eval_time=False):
@@ -302,14 +300,13 @@ class IntegratorRunner(Runner):
refresh=True)
if eval_time:
t0 = time.time()
hists, returns = self.step_func(times)
hists = self.step_func(times)
if eval_time:
running_time = time.time() - t0
if self.progress_bar:
self._pbar.close()

# post-running
hists.update(returns)
times += self.dt
if self.numpy_mon_after_run:
times = np.asarray(times)
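
The refactor folds variable monitors and function monitors into a single ``returns`` dictionary built inside ``_step``, so ``step_func`` now yields one mapping instead of a ``(hists, returns)`` pair. A generic sketch of that shape (not the BrainPy API itself):

def step(t, dt, variables, monitors, fun_monitors):
  # one dictionary carries what used to be split across two containers
  returns = {}
  for key, func in fun_monitors.items():
    returns[key] = func(t, dt)        # user-supplied monitor functions
  for key in monitors:
    returns[key] = variables[key]     # monitored state variables
  return returns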


+ 1
- 1
brainpy/integrators/sde/tests/test_sde_scalar.py View File

@@ -50,7 +50,7 @@ def lorenz_system(method, **kwargs):
mon3 = bp.math.array(mon3).to_numpy()

fig = plt.figure()
ax = fig.gca(projection='3d')
ax = fig.add_subplot(111, projection='3d')
plt.plot(mon1, mon2, mon3)
ax.set_xlabel('x')
ax.set_ylabel('y')


+ 9
- 1
brainpy/math/autograd.py View File

@@ -16,7 +16,9 @@ from jax.tree_util import tree_flatten, tree_unflatten, tree_map, tree_transpose
from jax.util import safe_map

from brainpy import errors
from brainpy.math.jaxarray import JaxArray
from brainpy.base.naming import get_unique_name
from brainpy.math.jaxarray import JaxArray, add_context, del_context


__all__ = [
'grad', # gradient of scalar function
@@ -28,20 +30,26 @@ __all__ = [

def _make_cls_call_func(grad_func, grad_tree, grad_vars, dyn_vars,
argnums, return_value, has_aux):
name = get_unique_name('_brainpy_object_oriented_grad_')

# outputs
def call_func(*args, **kwargs):
old_grad_vs = [v.value for v in grad_vars]
old_dyn_vs = [v.value for v in dyn_vars]
try:
add_context(name)
grads, (outputs, new_grad_vs, new_dyn_vs) = grad_func(old_grad_vs,
old_dyn_vs,
*args,
**kwargs)
del_context(name)
except UnexpectedTracerError as e:
del_context(name)
for v, d in zip(grad_vars, old_grad_vs): v._value = d
for v, d in zip(dyn_vars, old_dyn_vs): v._value = d
raise errors.JaxTracerError(variables=dyn_vars + grad_vars) from e
except Exception as e:
del_context(name)
for v, d in zip(grad_vars, old_grad_vs): v._value = d
for v, d in zip(dyn_vars, old_dyn_vs): v._value = d
raise e
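
Every failure path above restores the cached variable values before re-raising, so a failed transformation cannot leave ``Variable`` state half-written. A generic sketch of the rollback pattern (not the BrainPy API):

def call_with_rollback(fn, variables):
  # snapshot, run, and restore on any failure so state stays consistent
  old_values = [v.value for v in variables]
  try:
    return fn()
  except Exception:
    for v, old in zip(variables, old_values):
      v._value = old
    raise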


+ 53
- 37
brainpy/math/controls.py View File

@@ -13,9 +13,10 @@ except ImportError:
from jax.core import UnexpectedTracerError

from brainpy import errors
from brainpy.base.naming import get_unique_name
from brainpy.math.jaxarray import (JaxArray, Variable,
turn_on_global_jit,
turn_off_global_jit)
add_context,
del_context)
from brainpy.math.numpy_ops import as_device_array

__all__ = [
@@ -158,17 +159,19 @@ def make_loop(body_fun, dyn_vars, out_vars=None, has_return=False):
out_vars=out_vars,
has_return=has_return)

name = get_unique_name('_brainpy_object_oriented_make_loop_')

# functions
if has_return:
def call(xs=None, length=None):
init_values = [v.value for v in dyn_vars]
try:
turn_on_global_jit()
add_context(name)
dyn_values, (out_values, results) = lax.scan(
f=fun2scan, init=init_values, xs=xs, length=length)
turn_off_global_jit()
del_context(name)
except UnexpectedTracerError as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, init_values): v._value = d
raise errors.JaxTracerError(variables=dyn_vars) from e
for v, d in zip(dyn_vars, dyn_values): v._value = d
@@ -178,15 +181,15 @@ def make_loop(body_fun, dyn_vars, out_vars=None, has_return=False):
def call(xs):
init_values = [v.value for v in dyn_vars]
try:
turn_on_global_jit()
add_context(name)
dyn_values, out_values = lax.scan(f=fun2scan, init=init_values, xs=xs)
turn_off_global_jit()
del_context(name)
except UnexpectedTracerError as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, init_values): v._value = d
raise errors.JaxTracerError(variables=dyn_vars) from e
except Exception as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, init_values): v._value = d
raise e
for v, d in zip(dyn_vars, dyn_values): v._value = d
@@ -255,20 +258,22 @@ def make_while(cond_fun, body_fun, dyn_vars):
for v, d in zip(dyn_vars, dyn_values): v._value = d
return as_device_array(cond_fun(static_values))

name = get_unique_name('_brainpy_object_oriented_make_while_')

def call(x=None):
dyn_init = [v.value for v in dyn_vars]
try:
turn_on_global_jit()
add_context(name)
dyn_values, _ = lax.while_loop(cond_fun=_cond_fun,
body_fun=_body_fun,
init_val=(dyn_init, x))
turn_off_global_jit()
del_context(name)
except UnexpectedTracerError as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, dyn_init): v._value = d
raise errors.JaxTracerError(variables=dyn_vars) from e
except Exception as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, dyn_init): v._value = d
raise e
for v, d in zip(dyn_vars, dyn_values): v._value = d
@@ -330,6 +335,8 @@ def make_cond(true_fun, false_fun, dyn_vars=None):
if not isinstance(v, JaxArray):
raise ValueError(f'Only support {JaxArray.__name__}, but got {type(v)}')

name = get_unique_name('_brainpy_object_oriented_make_cond_')

if len(dyn_vars) > 0:
def _true_fun(op):
dyn_vals, static_vals = op
@@ -348,15 +355,15 @@ def make_cond(true_fun, false_fun, dyn_vars=None):
def call(pred, x=None):
old_values = [v.value for v in dyn_vars]
try:
turn_on_global_jit()
add_context(name)
dyn_values, res = lax.cond(pred, _true_fun, _false_fun, (old_values, x))
turn_off_global_jit()
del_context(name)
except UnexpectedTracerError as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, old_values): v._value = d
raise errors.JaxTracerError(variables=dyn_vars) from e
except Exception as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, old_values): v._value = d
raise e
for v, d in zip(dyn_vars, dyn_values): v._value = d
@@ -364,9 +371,9 @@ def make_cond(true_fun, false_fun, dyn_vars=None):

else:
def call(pred, x=None):
turn_on_global_jit()
add_context(name)
res = lax.cond(pred, true_fun, false_fun, x)
turn_off_global_jit()
del_context(name)
return res

return call
@@ -445,6 +452,8 @@ def cond(
if not isinstance(v, Variable):
raise ValueError(f'Only support {Variable.__name__}, but got {type(v)}')

name = get_unique_name('_brainpy_object_oriented_cond_')

# calling the model
if len(dyn_vars) > 0:
def _true_fun(op):
@@ -463,25 +472,25 @@ def cond(

old_values = [v.value for v in dyn_vars]
try:
turn_on_global_jit()
add_context(name)
dyn_values, res = lax.cond(pred=pred,
true_fun=_true_fun,
false_fun=_false_fun,
operand=(old_values, operands))
turn_off_global_jit()
del_context(name)
except UnexpectedTracerError as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, old_values): v._value = d
raise errors.JaxTracerError(variables=dyn_vars) from e
except Exception as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, old_values): v._value = d
raise e
for v, d in zip(dyn_vars, dyn_values): v._value = d
else:
turn_on_global_jit()
add_context(name)
res = lax.cond(pred, true_fun, false_fun, operands)
turn_off_global_jit()
del_context(name)
return res


@@ -591,7 +600,11 @@ def ifelse(
if show_code: print(codes)
exec(compile(codes.strip(), '', 'exec'), code_scope)
f = code_scope['f']
return f(operands)
name = get_unique_name('_brainpy_object_oriented_ifelse_')
add_context(name)
r = f(operands)
del_context(name)
return r


def for_loop(body_fun: Callable,
@@ -616,9 +629,9 @@ def for_loop(body_fun: Callable,
>>> a_hist = bm.for_loop(body, dyn_vars=[a, b], operands=bm.arange(1, 5))
>>> a_hist
DeviceArray([[ 1.],
[ 3.],
[ 6.],
[10.]], dtype=float32)
[ 3.],
[ 6.],
[10.]], dtype=float32)
>>> a
Variable([10.], dtype=float32)
>>> b
@@ -694,22 +707,24 @@ def for_loop(body_fun: Callable,
results = body_fun(*x)
return [v.value for v in dyn_vars], results

name = get_unique_name('_brainpy_object_oriented_for_loop_')

# functions
init_vals = [v.value for v in dyn_vars]
try:
turn_on_global_jit()
add_context(name)
dyn_vals, out_vals = lax.scan(f=fun2scan,
init=init_vals,
xs=operands,
reverse=reverse,
unroll=unroll)
turn_off_global_jit()
del_context(name)
except UnexpectedTracerError as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, init_vals): v._value = d
raise errors.JaxTracerError(variables=dyn_vars) from e
except Exception as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, init_vals): v._value = d
raise e
for v, d in zip(dyn_vars, dyn_vals): v._value = d
@@ -797,19 +812,20 @@ def while_loop(
r = cond_fun(*static_vals)
return r if isinstance(r, JaxArray) else r

name = get_unique_name('_brainpy_object_oriented_while_loop_')
dyn_init = [v.value for v in dyn_vars]
try:
turn_on_global_jit()
add_context(name)
dyn_values, out = lax.while_loop(cond_fun=_cond_fun,
body_fun=_body_fun,
init_val=(dyn_init, operands))
turn_off_global_jit()
del_context(name)
except UnexpectedTracerError as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, dyn_init): v._value = d
raise errors.JaxTracerError(variables=dyn_vars) from e
except Exception as e:
turn_off_global_jit()
del_context(name)
for v, d in zip(dyn_vars, dyn_init): v._value = d
raise e
for v, d in zip(dyn_vars, dyn_values): v._value = d


+ 110
- 25
brainpy/math/delayvars.py View File

@@ -262,6 +262,10 @@ class NeuTimeDelay(TimeDelay):
pass


ROTATION_UPDATING = 'rotation'
CONCAT_UPDATING = 'concatenate'


class LengthDelay(AbstractDelay):
"""Delay variable which has a fixed delay length.

@@ -271,10 +275,36 @@ class LengthDelay(AbstractDelay):
The initial delay data.
delay_len: int
The maximum delay length.
initial_delay_data: Array
The delay data.
initial_delay_data: Any
The initial delay data. It can be a Python number (float, int, or bool),
an array, or a callable function that generates the data.
Note that ``initial_delay_data`` should be arranged in the following way::

delay = 1 [ data
delay = 2 data
... ....
... ....
delay = delay_len-1 data
delay = delay_len data ]

.. versionchanged:: 2.2.3.2

The data in the previous version of ``LengthDelay`` is::

delay = delay_len [ data
delay = delay_len-1 data
... ....
... ....
delay = 2 data
delay = 1 data ]


name: str
The delay object name.
batch_axis: int
The batch axis. If not provided, it will be inferred from the `delay_target`.
update_method: str
The method used for updating the delay.

See Also
--------
@@ -288,19 +318,35 @@ class LengthDelay(AbstractDelay):
initial_delay_data: Union[float, int, bool, ndarray, jnp.ndarray, Callable] = None,
name: str = None,
batch_axis: int = None,
update_method: str = ROTATION_UPDATING
):
super(LengthDelay, self).__init__(name=name)

assert update_method in [ROTATION_UPDATING, CONCAT_UPDATING]
self.update_method = update_method
# attributes and variables
self.data: Variable = None
self.num_delay_step: int = None
self.idx: Variable = None
self.data: Variable = None

# initialization
self.reset(delay_target, delay_len, initial_delay_data, batch_axis)

@property
def delay_shape(self):
"""The data shape of this delay variable."""
return self.data.shape

@property
def delay_target_shape(self):
"""The data shape of the delay target."""
return self.data.shape[1:]

def __repr__(self):
return f'{self.__class__.__name__}(num_delay_step={self.num_delay_step}, delay_target_shape={self.data.shape[1:]})'
name = self.__class__.__name__
return (f'{name}(num_delay_step={self.num_delay_step}, '
f'delay_target_shape={self.delay_target_shape}, '
f'update_method={self.update_method})')

def reset(
self,
@@ -321,13 +367,7 @@ class LengthDelay(AbstractDelay):
delay_len = self.num_delay_step - 1
self.num_delay_step = delay_len + 1

# time variables
if self.idx is None:
self.idx = Variable(jnp.asarray([0], dtype=jnp.int32))
else:
self.idx.value = jnp.asarray([0], dtype=jnp.int32)

# delay data
# initialize delay data
if self.data is None:
if batch_axis is None:
if isinstance(delay_target, Variable) and (delay_target.batch_axis is not None):
@@ -338,39 +378,84 @@ class LengthDelay(AbstractDelay):
else:
self.data._value = jnp.zeros((self.num_delay_step,) + delay_target.shape,
dtype=delay_target.dtype)
self.data[-1] = delay_target

# update delay data
self.data[0] = delay_target
if initial_delay_data is None:
pass
elif isinstance(initial_delay_data, (ndarray, jnp.ndarray, float, int, bool)):
self.data[:-1] = initial_delay_data
self.data[1:] = initial_delay_data
elif callable(initial_delay_data):
self.data[:-1] = initial_delay_data((delay_len,) + delay_target.shape,
self.data[1:] = initial_delay_data((delay_len,) + delay_target.shape,
dtype=delay_target.dtype)
else:
raise ValueError(f'"delay_data" does not support {type(initial_delay_data)}')

# time variables
if self.update_method == ROTATION_UPDATING:
if self.idx is None:
self.idx = Variable(stop_gradient(jnp.asarray([0], dtype=jnp.int32)))
else:
self.idx.value = stop_gradient(jnp.asarray([0], dtype=jnp.int32))

def _check_delay(self, delay_len):
raise ValueError(f'The requested delay length should be less than the '
f'maximum delay {self.num_delay_step}. '
f'But we got {delay_len}')

def __call__(self, delay_len, *indices):
# check
return self.retrieve(delay_len, *indices)

def retrieve(self, delay_len, *indices):
"""Retrieve the delay data acoording to the delay length.

Parameters
----------
delay_len: int, Array
The delay length used to retrieve the data.
"""
if check.is_checking():
check_error_in_jit(bm.any(delay_len >= self.num_delay_step), self._check_delay, delay_len)
# the delay length
delay_idx = (self.idx[0] - delay_len - 1) % self.num_delay_step
delay_idx = stop_gradient(delay_idx)
if not jnp.issubdtype(delay_idx.dtype, jnp.integer):
raise ValueError(f'"delay_len" must be integer, but we got {delay_len}')
# the delay data

if self.update_method == ROTATION_UPDATING:
delay_idx = (self.idx[0] + delay_len) % self.num_delay_step
delay_idx = stop_gradient(delay_idx)

elif self.update_method == CONCAT_UPDATING:
delay_idx = delay_len

else:
raise ValueError(f'Unknown updating method "{self.update_method}"')

# the delay index
if isinstance(delay_idx, int):
pass
elif hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer):
raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}')
indices = (delay_idx,) + tuple(indices)
# the delay data
return self.data[indices]

def update(self, value: Union[float, JaxArray, jnp.DeviceArray]):
idx = stop_gradient(self.idx[0])
self.data[idx] = value
self.idx.value = stop_gradient((self.idx + 1) % self.num_delay_step)
def update(self, value: Union[float, int, bool, JaxArray, jnp.DeviceArray]):
"""Update delay variable with the new data.

Parameters
----------
value: Any
The value of the latest data, used to update this delay variable.
"""
if self.update_method == ROTATION_UPDATING:
self.idx.value = stop_gradient((self.idx - 1) % self.num_delay_step)
self.data[self.idx[0]] = value

elif self.update_method == CONCAT_UPDATING:
if self.num_delay_step >= 2:
self.data.value = bm.vstack([bm.broadcast_to(value, self.data.shape[1:]), self.data[1:]])
else:
self.data[:] = value

else:
raise ValueError(f'Unknown updating method "{self.update_method}"')
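
With the new layout, ``retrieve(0)`` returns the most recent value and larger delay lengths reach further back, for both updating methods. A small usage sketch (the concrete numbers are illustrative only):

import brainpy.math as bm

delay = bm.LengthDelay(bm.zeros(3), 2, update_method='rotation')
delay.update(bm.ones(3))   # push the newest data
newest = delay(0)          # the value just pushed
older = delay(2)           # the value from two steps ago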


class NeuLenDelay(LengthDelay):


+ 322
- 204
brainpy/math/jaxarray.py View File

@@ -33,26 +33,69 @@ _all_slice = slice(None, None, None)
msg = ('JaxArray created outside of the jit function '
'cannot be updated in JIT mode. You should '
'mark it as brainpy.math.Variable instead.')
_global_jit_mode = False

_jax_transformation_context_ = []

def turn_on_global_jit():
"""Turn on the global JIT mode to declare
all instantiated JaxArray cannot be updated."""
global _global_jit_mode
_global_jit_mode = True

def add_context(name):
_jax_transformation_context_.append(name)

def turn_off_global_jit():
"""Turn off the global JIT mode."""
global _global_jit_mode
_global_jit_mode = False

def del_context(name=None):
try:
context = _jax_transformation_context_.pop(-1)
if name is not None:
if context != name:
raise MathError('Transformation context is different!')
# warnings.warn(, UserWarning)
except IndexError:
raise MathError('No transformation context!')
# warnings.warn('No transformation context!', UserWarning)


def get_context():
if len(_jax_transformation_context_) > 0:
return _jax_transformation_context_[-1]
else:
return None


def check_context(arr_context):
if arr_context is None:
if len(_jax_transformation_context_) > 0:
raise MathError(f'JaxArray created outside of the transformation functions '
f'({_jax_transformation_context_[-1]}) cannot be updated. '
f'You should mark it as a brainpy.math.Variable instead.')
return True
else:
return False
else:
if len(_jax_transformation_context_) > 0:
if arr_context != _jax_transformation_context_[-1]:
raise MathError(f'JaxArray context "{arr_context}" differs from the JAX '
f'transformation context "{_jax_transformation_context_[-1]}"'
'\n\n'
'JaxArray created in one transformation function '
'cannot be updated in another transformation function. '
'You should mark it as a brainpy.math.Variable instead.')
return True
else:
return False
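
In effect, every ``JaxArray`` remembers the transformation context it was created in, and in-place updates are rejected when the contexts differ; ``Variable`` instances are exempt because they are handed to the transformation explicitly. A hedged usage sketch (the ``dyn_vars`` wiring of ``bm.jit`` is assumed from the changes later in this diff):

import brainpy.math as bm

v = bm.Variable(bm.zeros(3))
a = bm.zeros(3)            # a plain JaxArray created outside any transformation

def step():
  v.value = v + 1.         # fine: Variables may be updated inside transformations
  # a[:] = a + 1.          # would raise MathError: `a` was created outside this
  #                        # transformation context
  return v.value

jitted = bm.jit(step, dyn_vars=[v])
jitted()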


def _check_input_array(array):
if isinstance(array, JaxArray):
return array.value
elif isinstance(array, np.ndarray):
return jnp.asarray(array)
else:
return array


class JaxArray(object):
"""Multiple-dimensional array in JAX backend.
"""
__slots__ = ("_value", "_outside_global_jit")
__slots__ = ("_value", "_transform_context")

def __init__(self, value, dtype=None):
# array value
@@ -64,7 +107,7 @@ class JaxArray(object):
value = jnp.asarray(value, dtype=dtype)
self._value = value
# jit mode
self._outside_global_jit = False if _global_jit_mode else True
self._transform_context = get_context()

@property
def value(self):
@@ -77,7 +120,7 @@ class JaxArray(object):
def update(self, value):
"""Update the value of this JaxArray.
"""
if self._outside_global_jit and _global_jit_mode:
if check_context(self._transform_context):
raise MathError(msg)
if isinstance(value, JaxArray):
value = value.value
@@ -174,26 +217,31 @@ class JaxArray(object):
if isinstance(index, slice) and (index == _all_slice):
return self.value
elif isinstance(index, tuple):
index = tuple(x.value if isinstance(x, JaxArray) else x for x in index)
index = tuple((x.value if isinstance(x, JaxArray) else x) for x in index)
elif isinstance(index, JaxArray):
index = index.value
return self.value[index]

def __setitem__(self, index, value):
if self._outside_global_jit and _global_jit_mode:
if check_context(self._transform_context):
raise MathError(msg)

# value is JaxArray
if isinstance(value, JaxArray):
value = value.value
# value is numpy.ndarray
elif isinstance(value, np.ndarray):
value = jnp.asarray(value)

# tuple index
# index is a tuple
if isinstance(index, tuple):
index = tuple(x.value if isinstance(x, JaxArray) else x for x in index)

# JaxArray index
index = tuple(_check_input_array(x) for x in index)
# index is JaxArray
elif isinstance(index, JaxArray):
index = index.value
# index is numpy.ndarray
elif isinstance(index, np.ndarray):
index = jnp.asarray(index)

# update
self._value = self._value.at[index].set(value)
@@ -221,199 +269,199 @@ class JaxArray(object):
return JaxArray(self._value.__invert__())

def __eq__(self, oc):
return JaxArray(self._value == (oc.value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value == _check_input_array(oc))

def __ne__(self, oc):
return JaxArray(self._value != (oc.value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value != _check_input_array(oc))

def __lt__(self, oc):
return JaxArray(self._value < (oc.value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value < _check_input_array(oc))

def __le__(self, oc):
return JaxArray(self._value <= (oc.value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value <= _check_input_array(oc))

def __gt__(self, oc):
return JaxArray(self._value > (oc.value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value > _check_input_array(oc))

def __ge__(self, oc):
return JaxArray(self._value >= (oc.value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value >= _check_input_array(oc))

def __add__(self, oc):
return JaxArray(self._value + (oc.value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value + _check_input_array(oc))

def __radd__(self, oc):
return JaxArray(self._value + (oc.value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value + _check_input_array(oc))

def __iadd__(self, oc):
# a += b
if self._outside_global_jit and _global_jit_mode:
if check_context(self._transform_context):
raise MathError(msg)
self._value += (oc.value if isinstance(oc, JaxArray) else oc)
self._value += _check_input_array(oc)
return self

def __sub__(self, oc):
return JaxArray(self._value - (oc.value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value - _check_input_array(oc))

def __rsub__(self, oc):
return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) - self._value)
return JaxArray(_check_input_array(oc) - self._value)

def __isub__(self, oc):
# a -= b
if self._outside_global_jit and _global_jit_mode:
if check_context(self._transform_context):
raise MathError(msg)
self._value = self._value - (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self._value - _check_input_array(oc)
return self

def __mul__(self, oc):
return JaxArray(self._value * (oc.value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value * _check_input_array(oc))

def __rmul__(self, oc):
return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) * self._value)
return JaxArray(_check_input_array(oc) * self._value)

def __imul__(self, oc):
# a *= b
if self._outside_global_jit and _global_jit_mode:
if check_context(self._transform_context):
raise MathError(msg)
self._value = self._value * (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self._value * _check_input_array(oc)
return self

def __rdiv__(self, oc):
return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) / self._value)
return JaxArray(_check_input_array(oc) / self._value)

def __truediv__(self, oc):
return JaxArray(self._value / (oc.value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value / _check_input_array(oc))

def __rtruediv__(self, oc):
return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) / self._value)
return JaxArray(_check_input_array(oc) / self._value)

def __itruediv__(self, oc):
# a /= b
if self._outside_global_jit and _global_jit_mode:
if check_context(self._transform_context):
raise MathError(msg)
self._value = self._value / (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self._value / _check_input_array(oc)
return self

def __floordiv__(self, oc):
return JaxArray(self._value // (oc.value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value // _check_input_array(oc))

def __rfloordiv__(self, oc):
return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) // self._value)
return JaxArray(_check_input_array(oc) // self._value)

def __ifloordiv__(self, oc):
# a //= b
if self._outside_global_jit and _global_jit_mode:
if check_context(self._transform_context):
raise MathError(msg)
self._value = self._value // (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self._value // _check_input_array(oc)
return self

def __divmod__(self, oc):
return JaxArray(self._value.__divmod__(oc.value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value.__divmod__(_check_input_array(oc)))

def __rdivmod__(self, oc):
return JaxArray(self._value.__rdivmod__(oc.value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value.__rdivmod__(_check_input_array(oc)))

def __mod__(self, oc):
return JaxArray(self._value % (oc.value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value % _check_input_array(oc))

def __rmod__(self, oc):
return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) % self._value)
return JaxArray(_check_input_array(oc) % self._value)

def __imod__(self, oc):
# a %= b
if self._outside_global_jit and _global_jit_mode:
if check_context(self._transform_context):
raise MathError(msg)
self._value = self._value % (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self._value % _check_input_array(oc)
return self

def __pow__(self, oc):
return JaxArray(self._value ** (oc.value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value ** _check_input_array(oc))

def __rpow__(self, oc):
return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) ** self._value)
return JaxArray(_check_input_array(oc) ** self._value)

def __ipow__(self, oc):
# a **= b
if self._outside_global_jit and _global_jit_mode:
if check_context(self._transform_context):
raise MathError(msg)
self._value = self._value ** (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self._value ** _check_input_array(oc)
return self

def __matmul__(self, oc):
return JaxArray(self._value @ (oc.value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value @ _check_input_array(oc))

def __rmatmul__(self, oc):
return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) @ self._value)
return JaxArray(_check_input_array(oc) @ self._value)

def __imatmul__(self, oc):
# a @= b
if self._outside_global_jit and _global_jit_mode:
if check_context(self._transform_context):
raise MathError(msg)
self._value = self._value @ (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self._value @ _check_input_array(oc)
return self

def __and__(self, oc):
return JaxArray(self._value & (oc.value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value & _check_input_array(oc))

def __rand__(self, oc):
return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) & self._value)
return JaxArray(_check_input_array(oc) & self._value)

def __iand__(self, oc):
# a &= b
if self._outside_global_jit and _global_jit_mode:
if check_context(self._transform_context):
raise MathError(msg)
self._value = self._value & (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self._value & _check_input_array(oc)
return self

def __or__(self, oc):
return JaxArray(self._value | (oc.value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value | _check_input_array(oc))

def __ror__(self, oc):
return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) | self._value)
return JaxArray(_check_input_array(oc) | self._value)

def __ior__(self, oc):
# a |= b
if self._outside_global_jit and _global_jit_mode:
if check_context(self._transform_context):
raise MathError(msg)
self._value = self._value | (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self._value | _check_input_array(oc)
return self

def __xor__(self, oc):
return JaxArray(self._value ^ (oc.value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value ^ _check_input_array(oc))

def __rxor__(self, oc):
return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) ^ self._value)
return JaxArray(_check_input_array(oc) ^ self._value)

def __ixor__(self, oc):
# a ^= b
if self._outside_global_jit and _global_jit_mode:
if check_context(self._transform_context):
raise MathError(msg)
self._value = self._value ^ (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self._value ^ _check_input_array(oc)
return self

def __lshift__(self, oc):
return JaxArray(self._value << (oc.value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value << _check_input_array(oc))

def __rlshift__(self, oc):
return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) << self._value)
return JaxArray(_check_input_array(oc) << self._value)

def __ilshift__(self, oc):
# a <<= b
if self._outside_global_jit and _global_jit_mode:
if check_context(self._transform_context):
raise MathError(msg)
self._value = self._value << (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self._value << _check_input_array(oc)
return self

def __rshift__(self, oc):
return JaxArray(self._value >> (oc.value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value >> _check_input_array(oc))

def __rrshift__(self, oc):
return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) >> self._value)
return JaxArray(_check_input_array(oc) >> self._value)

def __irshift__(self, oc):
# a >>= b
if self._outside_global_jit and _global_jit_mode:
if check_context(self._transform_context):
raise MathError(msg)
self._value = self._value >> (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self._value >> _check_input_array(oc)
return self

def __round__(self, ndigits=None):
@@ -428,17 +476,17 @@ class JaxArray(object):
return self.value.at

def block_host_until_ready(self, *args):
self._value.block_host_until_ready(*args)
return self.value.block_host_until_ready(*args)

def block_until_ready(self, *args):
self._value.block_until_ready(*args)
return self.value.block_until_ready(*args)

def device(self):
raise self.value.device()
return self.value.device()

@property
def device_buffer(self):
raise self.value.device_buffer
return self.value.device_buffer

# ----------------------- #
# NumPy methods #
@@ -533,7 +581,7 @@ class JaxArray(object):

def fill(self, value):
"""Fill the array with a scalar value."""
if self._outside_global_jit and _global_jit_mode:
if check_context(self._transform_context):
raise MathError(msg)
self._value = jnp.ones_like(self.value) * value

@@ -661,7 +709,7 @@ class JaxArray(object):
but unspecified fields will still be used, in the order in which
they come up in the dtype, to break ties.
"""
if self._outside_global_jit and _global_jit_mode:
if check_context(self._transform_context):
raise MathError(msg)
self._value = self.value.sort(axis=axis, kind=kind, order=order)

@@ -885,10 +933,42 @@ ndarray = JaxArray

class Variable(JaxArray):
"""The pointer to specify the dynamical variable.

An instance of ``Variable`` can be initialized in two ways:

>>> import brainpy.math as bm
>>> # 1. initialize a Variable with concrete data
>>> v1 = bm.Variable(bm.zeros(10))
>>> # 2. initialize a Variable by its data shape
>>> v2 = bm.Variable(10)

Note that when initializing a `Variable` by the data shape,
all values in this `Variable` will be initialized as zeros.

Parameters
----------
value_or_size: Shape, Array
The value or the size of the value.
dtype:
The type of the data.
batch_axis: optional, int
The batch axis.
"""
__slots__ = ('_value', '_batch_axis')

def __init__(self, value, dtype=None, batch_axis: int = None):
def __init__(
self,
value_or_size,
dtype=None,
batch_axis: int = None
):
if isinstance(value_or_size, int):
value = jnp.zeros(value_or_size, dtype=dtype)
elif isinstance(value_or_size, (tuple, list)) and all([isinstance(s, int) for s in value_or_size]):
value = jnp.zeros(value_or_size, dtype=dtype)
else:
value = value_or_size

super(Variable, self).__init__(value, dtype=dtype)

# check batch axis
@@ -961,7 +1041,7 @@ class Variable(JaxArray):

# tuple index
if isinstance(index, tuple):
index = tuple(x.value if isinstance(x, JaxArray) else x for x in index)
index = tuple(_check_input_array(x) for x in index)

# JaxArray index
elif isinstance(index, JaxArray):
@@ -972,77 +1052,67 @@ class Variable(JaxArray):

def __iadd__(self, oc):
# a += b
# self._value += (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self.value + (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self.value + _check_input_array(oc)
return self

def __isub__(self, oc):
# a -= b
self._value = self.value - (oc.value if isinstance(oc, JaxArray) else oc)
# self._value -= (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self.value - _check_input_array(oc)
return self

def __imul__(self, oc):
# a *= b
self._value = self.value * (oc.value if isinstance(oc, JaxArray) else oc)
# self._value *= (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self.value * _check_input_array(oc)
return self

def __itruediv__(self, oc):
# a /= b
self._value = self.value / (oc.value if isinstance(oc, JaxArray) else oc)
# self._value /= (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self.value / _check_input_array(oc)
return self

def __ifloordiv__(self, oc):
# a //= b
self._value = self.value // (oc.value if isinstance(oc, JaxArray) else oc)
# self._value //= (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self.value // _check_input_array(oc)
return self

def __imod__(self, oc):
# a %= b
self._value = self.value % (oc.value if isinstance(oc, JaxArray) else oc)
# self._value %= (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self.value % _check_input_array(oc)
return self

def __ipow__(self, oc):
# a **= b
self._value = self.value ** (oc.value if isinstance(oc, JaxArray) else oc)
# self._value **= (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self.value ** _check_input_array(oc)
return self

def __imatmul__(self, oc):
# a @= b
self._value = self.value @ (oc.value if isinstance(oc, JaxArray) else oc)
# self._value @= (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self.value @ _check_input_array(oc)
return self

def __iand__(self, oc):
# a &= b
self._value = self.value.__and__(oc.value if isinstance(oc, JaxArray) else oc)
# self._value &= (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self.value.__and__(_check_input_array(oc))
return self

def __ior__(self, oc):
# a |= b
self._value = self.value | (oc.value if isinstance(oc, JaxArray) else oc)
# self._value |= (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self.value | _check_input_array(oc)
return self

def __ixor__(self, oc):
# a ^= b
self._value = self.value ^ (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self.value ^ _check_input_array(oc)
return self

def __ilshift__(self, oc):
# a <<= b
self._value = self.value << (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self.value << _check_input_array(oc)
return self

def __irshift__(self, oc):
# a >>= b
self._value = self.value >> (oc.value if isinstance(oc, JaxArray) else oc)
self._value = self.value >> _check_input_array(oc)
return self

def fill(self, value):
@@ -1076,109 +1146,109 @@ class Variable(JaxArray):
return self.value.__invert__()

def __eq__(self, oc):
return self.value == (oc.value if isinstance(oc, JaxArray) else oc)
return self.value == _check_input_array(oc)

def __ne__(self, oc):
return self.value != (oc.value if isinstance(oc, JaxArray) else oc)
return self.value != _check_input_array(oc)

def __lt__(self, oc):
return self.value < (oc.value if isinstance(oc, JaxArray) else oc)
return self.value < _check_input_array(oc)

def __le__(self, oc):
return self.value <= (oc.value if isinstance(oc, JaxArray) else oc)
return self.value <= _check_input_array(oc)

def __gt__(self, oc):
return self.value > (oc.value if isinstance(oc, JaxArray) else oc)
return self.value > _check_input_array(oc)

def __ge__(self, oc):
return self.value >= (oc.value if isinstance(oc, JaxArray) else oc)
return self.value >= _check_input_array(oc)

def __add__(self, oc):
return self.value + (oc.value if isinstance(oc, JaxArray) else oc)
return self.value + _check_input_array(oc)

def __radd__(self, oc):
return self.value + (oc.value if isinstance(oc, JaxArray) else oc)
return self.value + _check_input_array(oc)

def __sub__(self, oc):
return self.value - (oc.value if isinstance(oc, JaxArray) else oc)
return self.value - _check_input_array(oc)

def __rsub__(self, oc):
return (oc.value if isinstance(oc, JaxArray) else oc) - self.value
return _check_input_array(oc) - self.value

def __mul__(self, oc):
return self.value * (oc.value if isinstance(oc, JaxArray) else oc)
return self.value * _check_input_array(oc)

def __rmul__(self, oc):
return (oc.value if isinstance(oc, JaxArray) else oc) * self.value
return _check_input_array(oc) * self.value

def __rdiv__(self, oc):
return (oc.value if isinstance(oc, JaxArray) else oc) / self.value
return _check_input_array(oc) / self.value

def __truediv__(self, oc):
return self.value / (oc.value if isinstance(oc, JaxArray) else oc)
return self.value / _check_input_array(oc)

def __rtruediv__(self, oc):
return (oc.value if isinstance(oc, JaxArray) else oc) / self.value
return _check_input_array(oc) / self.value

def __floordiv__(self, oc):
return self.value // (oc.value if isinstance(oc, JaxArray) else oc)
return self.value // _check_input_array(oc)

def __rfloordiv__(self, oc):
return (oc.value if isinstance(oc, JaxArray) else oc) // self.value
return _check_input_array(oc) // self.value

def __divmod__(self, oc):
return self.value.__divmod__(oc.value if isinstance(oc, JaxArray) else oc)
return self.value.__divmod__(_check_input_array(oc))

def __rdivmod__(self, oc):
return self.value.__rdivmod__(oc.value if isinstance(oc, JaxArray) else oc)
return self.value.__rdivmod__(_check_input_array(oc))

def __mod__(self, oc):
return self.value % (oc.value if isinstance(oc, JaxArray) else oc)
return self.value % _check_input_array(oc)

def __rmod__(self, oc):
return (oc.value if isinstance(oc, JaxArray) else oc) % self.value
return _check_input_array(oc) % self.value

def __pow__(self, oc):
return self.value ** (oc.value if isinstance(oc, JaxArray) else oc)
return self.value ** _check_input_array(oc)

def __rpow__(self, oc):
return (oc.value if isinstance(oc, JaxArray) else oc) ** self.value
return _check_input_array(oc) ** self.value

def __matmul__(self, oc):
return self.value @ (oc.value if isinstance(oc, JaxArray) else oc)
return self.value @ _check_input_array(oc)

def __rmatmul__(self, oc):
return (oc.value if isinstance(oc, JaxArray) else oc) @ self.value
return _check_input_array(oc) @ self.value

def __and__(self, oc):
return self.value & (oc.value if isinstance(oc, JaxArray) else oc)
return self.value & _check_input_array(oc)

def __rand__(self, oc):
return (oc.value if isinstance(oc, JaxArray) else oc) & self.value
return _check_input_array(oc) & self.value

def __or__(self, oc):
return self.value | (oc.value if isinstance(oc, JaxArray) else oc)
return self.value | _check_input_array(oc)

def __ror__(self, oc):
return (oc.value if isinstance(oc, JaxArray) else oc) | self.value
return _check_input_array(oc) | self.value

def __xor__(self, oc):
return self.value ^ (oc.value if isinstance(oc, JaxArray) else oc)
return self.value ^ _check_input_array(oc)

def __rxor__(self, oc):
return (oc.value if isinstance(oc, JaxArray) else oc) ^ self.value
return _check_input_array(oc) ^ self.value

def __lshift__(self, oc):
return self.value << (oc.value if isinstance(oc, JaxArray) else oc)
return self.value << _check_input_array(oc)

def __rlshift__(self, oc):
return (oc.value if isinstance(oc, JaxArray) else oc) << self.value
return _check_input_array(oc) << self.value

def __rshift__(self, oc):
return self.value >> (oc.value if isinstance(oc, JaxArray) else oc)
return self.value >> _check_input_array(oc)

def __rrshift__(self, oc):
return (oc.value if isinstance(oc, JaxArray) else oc) >> self.value
return _check_input_array(oc) >> self.value

def __round__(self, ndigits=None):
return self.value.__round__(ndigits)
@@ -1464,8 +1534,8 @@ class TrainVar(Variable):
"""
__slots__ = ('_value', '_batch_axis')

def __init__(self, value, dtype=None, batch_axis: int = None):
super(TrainVar, self).__init__(value, dtype=dtype, batch_axis=batch_axis)
def __init__(self, value_or_size, dtype=None, batch_axis: int = None):
super(TrainVar, self).__init__(value_or_size, dtype=dtype, batch_axis=batch_axis)


class Parameter(Variable):
@@ -1473,75 +1543,61 @@ class Parameter(Variable):
"""
__slots__ = ('_value', '_batch_axis')

def __init__(self, value, dtype=None, batch_axis: int = None):
super(Parameter, self).__init__(value, dtype=dtype, batch_axis=batch_axis)


register_pytree_node(JaxArray,
lambda t: ((t.value,), None),
lambda aux_data, flat_contents: JaxArray(*flat_contents))

register_pytree_node(Variable,
lambda t: ((t.value,), None),
lambda aux_data, flat_contents: Variable(*flat_contents))

register_pytree_node(TrainVar,
lambda t: ((t.value,), None),
lambda aux_data, flat_contents: TrainVar(*flat_contents))

register_pytree_node(Parameter,
lambda t: ((t.value,), None),
lambda aux_data, flat_contents: Parameter(*flat_contents))
def __init__(self, value_or_size, dtype=None, batch_axis: int = None):
super(Parameter, self).__init__(value_or_size, dtype=dtype, batch_axis=batch_axis)


class VariableView(Variable):
"""A view of a Variable instance.

This class is used to create a slice view of ``brainpy.math.Variable``.
This class is used to create a subset view of ``brainpy.math.Variable``.

>>> import brainpy.math as bm
>>> bm.random.seed(123)
>>> origin = bm.Variable(bm.random.random(5))
>>> view = bm.VariableView(origin, slice(None, 2, None))  # origin[:2]
>>> view
VariableView([0.02920651, 0.19066381], dtype=float32)

``VariableView`` can be used to update a subset of the original
``Variable`` instance, and to perform operations on that subset.

>>> view[:] = 1.
>>> view
VariableView([1., 1.], dtype=float32)
>>> origin
Variable([1. , 1. , 0.5482849, 0.6564884, 0.8446237], dtype=float32)
>>> view + 10
DeviceArray([11., 11.], dtype=float32)
>>> view *= 10
VariableView([10., 10.], dtype=float32)

The above example demonstrates that updates to a ``VariableView`` instance
are actually applied to the original ``Variable`` instance.

Moreover, it is worth noting that ``VariableView`` is not a PyTree.
"""

def __init__(self, value: Variable, index):
self.index = index
if not isinstance(value, Variable):
raise ValueError('Must be instance of Variable.')
temp_shape = tuple([1] * len(index))
super(VariableView, self).__init__(jnp.zeros(temp_shape), batch_axis=value.batch_axis)
super(VariableView, self).__init__(value.value, batch_axis=value.batch_axis)
self._value = value

@property
def value(self):
return self._value[self.index]

@value.setter
def value(self, value):
int_shape = self.shape
if self.batch_axis is None:
ext_shape = value.shape
else:
ext_shape = value.shape[:self.batch_axis] + value.shape[self.batch_axis + 1:]
int_shape = int_shape[:self.batch_axis] + int_shape[self.batch_axis + 1:]
if ext_shape != int_shape:
error = f"The shape of the original data is {int_shape}, while we got {value.shape}"
if self.batch_axis is None:
error += '. Did you forget to set "batch_axis" when initializing this variable?'
else:
error += f' with batch_axis={self.batch_axis}.'
raise MathError(error)
if value.dtype != self._value.dtype:
raise MathError(f"The dtype of the original data is {self._value.dtype}, "
f"while we got {value.dtype}.")
self._value[self.index] = value

def __setitem__(self, index, value):
# value is JaxArray
if isinstance(value, JaxArray):
value = value.value
elif isinstance(value, np.ndarray):
value = jnp.asarray(value)

# tuple index
if isinstance(index, tuple):
index = tuple(x.value if isinstance(x, JaxArray) else x for x in index)
index = tuple(_check_input_array(x) for x in index)

# JaxArray index
elif isinstance(index, JaxArray):
@@ -1552,67 +1608,67 @@ class VariableView(Variable):

def __iadd__(self, oc):
# a += b
self._value[self.index] = self.value + (oc.value if isinstance(oc, JaxArray) else oc)
self._value[self.index] = self.value + _check_input_array(oc)
return self

def __isub__(self, oc):
# a -= b
self._value[self.index] = self.value - (oc.value if isinstance(oc, JaxArray) else oc)
self._value[self.index] = self.value - _check_input_array(oc)
return self

def __imul__(self, oc):
# a *= b
self._value[self.index] = self.value * (oc.value if isinstance(oc, JaxArray) else oc)
self._value[self.index] = self.value * _check_input_array(oc)
return self

def __itruediv__(self, oc):
# a /= b
self._value[self.index] = self.value / (oc.value if isinstance(oc, JaxArray) else oc)
self._value[self.index] = self.value / _check_input_array(oc)
return self

def __ifloordiv__(self, oc):
# a //= b
self._value[self.index] = self.value // (oc.value if isinstance(oc, JaxArray) else oc)
self._value[self.index] = self.value // _check_input_array(oc)
return self

def __imod__(self, oc):
# a %= b
self._value[self.index] = self.value % (oc.value if isinstance(oc, JaxArray) else oc)
self._value[self.index] = self.value % _check_input_array(oc)
return self

def __ipow__(self, oc):
# a **= b
self._value[self.index] = self.value ** (oc.value if isinstance(oc, JaxArray) else oc)
self._value[self.index] = self.value ** _check_input_array(oc)
return self

def __imatmul__(self, oc):
# a @= b
self._value[self.index] = self.value @ (oc.value if isinstance(oc, JaxArray) else oc)
self._value[self.index] = self.value @ _check_input_array(oc)
return self

def __iand__(self, oc):
# a &= b
self._value[self.index] = self.value.__and__(oc.value if isinstance(oc, JaxArray) else oc)
self._value[self.index] = self.value.__and__(_check_input_array(oc))
return self

def __ior__(self, oc):
# a |= b
self._value[self.index] = self.value | (oc.value if isinstance(oc, JaxArray) else oc)
self._value[self.index] = self.value | _check_input_array(oc)
return self

def __ixor__(self, oc):
# a ^= b
self._value[self.index] = self.value ^ (oc.value if isinstance(oc, JaxArray) else oc)
self._value[self.index] = self.value ^ _check_input_array(oc)
return self

def __ilshift__(self, oc):
# a <<= b
self._value[self.index] = self.value << (oc.value if isinstance(oc, JaxArray) else oc)
self._value[self.index] = self.value << _check_input_array(oc)
return self

def __irshift__(self, oc):
# a >>= b
self._value[self.index] = self.value >> (oc.value if isinstance(oc, JaxArray) else oc)
self._value[self.index] = self.value >> _check_input_array(oc)
return self

def fill(self, value):
@@ -1622,3 +1678,65 @@ class VariableView(Variable):
def sort(self, axis=-1, kind=None, order=None):
"""Sort an array in-place."""
self._value[self.index] = self.value.sort(axis=axis, kind=kind, order=order)

def update(self, value):
if self.batch_axis is None:
ext_shape = value.shape
int_shape = self.shape
else:
ext_shape = value.shape[:self.batch_axis] + value.shape[self.batch_axis + 1:]
int_shape = self.shape[:self.batch_axis] + self.shape[self.batch_axis + 1:]
if ext_shape != int_shape:
error = f"The shape of the original data is {self.shape}, while we got {value.shape}"
if self.batch_axis is None:
error += '. Did you forget to set "batch_axis" when initializing this variable?'
else:
error += f' with batch_axis={self.batch_axis}.'
raise MathError(error)
if value.dtype != self._value.dtype:
raise MathError(f"The dtype of the original data is {self._value.dtype}, "
f"while we got {value.dtype}.")
self._value[self.index] = value.value if isinstance(value, JaxArray) else value

@value.setter
def value(self, value):
int_shape = self.shape
if self.batch_axis is None:
ext_shape = value.shape
else:
ext_shape = value.shape[:self.batch_axis] + value.shape[self.batch_axis + 1:]
int_shape = int_shape[:self.batch_axis] + int_shape[self.batch_axis + 1:]
if ext_shape != int_shape:
error = f"The shape of the original data is {int_shape}, while we got {value.shape}"
if self.batch_axis is None:
error += '. Did you forget to set "batch_axis" when initializing this variable?'
else:
error += f' with batch_axis={self.batch_axis}.'
raise MathError(error)
if value.dtype != self._value.dtype:
raise MathError(f"The dtype of the original data is {self._value.dtype}, "
f"while we got {value.dtype}.")
self._value[self.index] = value.value if isinstance(value, JaxArray) else value


def _jaxarray_unflatten(aux_data, flat_contents):
r = JaxArray(*flat_contents)
r._transform_context = aux_data[0]
return r


register_pytree_node(JaxArray,
lambda t: ((t.value,), (t._transform_context, )),
_jaxarray_unflatten)

register_pytree_node(Variable,
lambda t: ((t.value,), None),
lambda aux_data, flat_contents: Variable(*flat_contents))

register_pytree_node(TrainVar,
lambda t: ((t.value,), None),
lambda aux_data, flat_contents: TrainVar(*flat_contents))

register_pytree_node(Parameter,
lambda t: ((t.value,), None),
lambda aux_data, flat_contents: Parameter(*flat_contents))

+ 13
- 9
brainpy/math/jit.py View File

@@ -15,12 +15,13 @@ import jax
try:
from jax.errors import UnexpectedTracerError, ConcretizationTypeError
except ImportError:
from jax.core import UnexpectedTracerError
from jax.core import UnexpectedTracerError, ConcretizationTypeError

from brainpy import errors
from brainpy.base.base import Base
from brainpy.base.naming import get_unique_name
from brainpy.base.collector import TensorCollector
from brainpy.math.jaxarray import JaxArray, turn_on_global_jit, turn_off_global_jit
from brainpy.math.jaxarray import JaxArray, add_context, del_context
from brainpy.tools.codes import change_func_name

__all__ = [
@@ -38,22 +39,24 @@ def _make_jit_with_vars(func, vars, static_argnames=None, device=None, f_name=No
changes = vars.dict()
return out, changes

name = get_unique_name('_brainpy_object_oriented_jit_')

def call(*args, **kwargs):
variable_data = vars.dict()
try:
turn_on_global_jit()
add_context(name)
out, changes = jitted_func(variable_data, *args, **kwargs)
turn_off_global_jit()
del_context(name)
except UnexpectedTracerError as e:
turn_off_global_jit()
del_context(name)
for key, v in vars.items(): v._value = variable_data[key]
raise errors.JaxTracerError(variables=vars) from e
except ConcretizationTypeError as e:
turn_off_global_jit()
del_context(name)
for key, v in vars.items(): v._value = variable_data[key]
raise errors.ConcretizationTypeError() from e
except Exception as e:
turn_off_global_jit()
del_context(name)
for key, v in vars.items(): v._value = variable_data[key]
raise e
for key, v in vars.items(): v._value = changes[key]
@@ -64,11 +67,12 @@ def _make_jit_with_vars(func, vars, static_argnames=None, device=None, f_name=No

def _make_jit_without_vars(func, static_argnames=None, device=None, f_name=None):
jit_f = jax.jit(func, static_argnames=static_argnames, device=device)
name = get_unique_name('_jax_functional_jit_')

def call(*args, **kwargs):
turn_on_global_jit()
add_context(name)
r = jit_f(*args, **kwargs)
turn_off_global_jit()
del_context(name)
return r

return change_func_name(name=f_name, f=call) if f_name else call
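
The object-oriented jit above threads the collected Variables through the compiled function as an explicit dict and writes the traced results back afterwards; ``add_context``/``del_context`` simply bracket the traced region. A simplified sketch of that state-threading pattern (illustrative only, not BrainPy's exact code):

import jax
import jax.numpy as jnp

def jit_with_state(fn, state):
  @jax.jit
  def inner(state_dict, *args):
    out = fn(state_dict, *args)      # fn may overwrite entries of state_dict
    return out, state_dict

  def call(*args):
    out, changes = inner(dict(state), *args)
    state.update(changes)            # write the traced values back, as in the code above
    return out

  return call

state = {'w': jnp.ones(3)}

def loss(state_dict, x):
  state_dict['w'] = state_dict['w'] + x
  return (state_dict['w'] ** 2).sum()

step = jit_with_state(loss, state)
print(step(2.0), state['w'])         # state['w'] is updated after the jitted call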


+ 7
- 12
brainpy/math/operators/__init__.py View File

@@ -1,23 +1,18 @@
# -*- coding: utf-8 -*-


from . import multiplication
from . import sparse_matmul, event_matmul
from . import op_register
from . import pre2syn as pre2syn_module
from . import pre2post as pre2post_module
from . import syn2post as syn2post_module
from . import pre_syn_post as pre_syn_post_module
from . import wrap_jax
from . import spikegrad

__all__ = multiplication.__all__ + op_register.__all__
__all__ += pre2syn_module.__all__ + pre2post_module.__all__ + syn2post_module.__all__
__all__ += wrap_jax.__all__ + spikegrad.__all__
__all__ = (event_matmul.__all__ + sparse_matmul.__all__ + op_register.__all__
+ pre_syn_post_module.__all__ + wrap_jax.__all__ + spikegrad.__all__)

from .multiplication import *
from .event_matmul import *
from .sparse_matmul import *
from .op_register import *
from .pre2syn import *
from .pre2post import *
from .syn2post import *
from .pre_syn_post import *
from .wrap_jax import *
from .spikegrad import *

+ 52
- 0
brainpy/math/operators/event_matmul.py View File

@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-


from typing import Tuple

import brainpylib

from brainpy.math.numpy_ops import as_jax
from brainpy.types import Array

__all__ = [
'event_csr_matvec',
]


def event_csr_matvec(values: Array,
indices: Array,
indptr: Array,
events: Array,
shape: Tuple[int, ...],
transpose: bool = False):
"""The pre-to-post event-driven synaptic summation with `CSR` synapse structure.

Parameters
----------
values: Array, float
An array of shape ``(nse,)`` or a float.
indices: Array
An array of shape ``(nse,)``.
indptr: Array
An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``.
events: Array
An array of shape ``(shape[0] if transpose else shape[1],)``
and dtype ``bool``.
shape: tuple of int
A length-2 tuple representing the sparse matrix shape.
transpose: bool
A boolean specifying whether to transpose the sparse matrix
before computing. Default is False.

Returns
-------
out: Array
A tensor with the shape of ``shape[1]`` if `transpose=True`,
or ``shape[0]`` if `transpose=False`.
"""
events = as_jax(events)
indices = as_jax(indices)
indptr = as_jax(indptr)
values = as_jax(values)
return brainpylib.event_csr_matvec(values, indices, indptr, events,
shape=shape, transpose=transpose)
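
For intuition, the event-driven CSR matrix-vector product above behaves roughly like the following pure-NumPy loop (a sketch only; the optimized kernel lives in brainpylib):

import numpy as np

def event_csr_matvec_ref(values, indices, indptr, events, shape, transpose=False):
  values = np.broadcast_to(values, indices.shape)        # allow a scalar weight
  out = np.zeros(shape[1] if transpose else shape[0])
  for row in range(shape[0]):
    for k in range(indptr[row], indptr[row + 1]):
      col = indices[k]
      if transpose and events[row]:          # only rows with an event contribute
        out[col] += values[k]
      elif (not transpose) and events[col]:  # only columns with an event contribute
        out[row] += values[k]
  return out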

+ 18
- 59
brainpy/math/operators/op_register.py View File

@@ -1,22 +1,15 @@
# -*- coding: utf-8 -*-

from typing import Union, Sequence, Callable
from typing import Callable

from jax.abstract_arrays import ShapedArray
import brainpylib
from jax.tree_util import tree_map

from brainpy.base import Base
from brainpy.math.jaxarray import JaxArray
from .utils import _check_brainpylib

try:
import brainpylib
except ModuleNotFoundError:
brainpylib = None

__all__ = [
'XLACustomOp',
'register_op',
]


@@ -57,8 +50,11 @@ class XLACustomOp(Base):
gpu_func: Callable = None,
apply_cpu_func_to_gpu: bool = False,
name: str = None,
batching_translation: Callable = None,
jvp_translation: Callable = None,
transpose_translation: Callable = None,
multiple_results: bool = False,
):
_check_brainpylib(register_op.__name__)
super(XLACustomOp, self).__init__(name=name)

# abstract evaluation function
@@ -77,11 +73,17 @@ class XLACustomOp(Base):
gpu_func = None

# register OP
self.op = brainpylib.register_op(self.name,
cpu_func=cpu_func,
gpu_func=gpu_func,
out_shapes=eval_shape,
apply_cpu_func_to_gpu=apply_cpu_func_to_gpu)
self.op = brainpylib.register_op_with_numba(
self.name,
cpu_func=cpu_func,
gpu_func_translation=gpu_func,
out_shapes=eval_shape,
apply_cpu_func_to_gpu=apply_cpu_func_to_gpu,
batching_translation=batching_translation,
jvp_translation=jvp_translation,
transpose_translation=transpose_translation,
multiple_results=multiple_results,
)

def __call__(self, *args, **kwargs):
args = tree_map(lambda a: a.value if isinstance(a, JaxArray) else a,
@@ -89,48 +91,5 @@ class XLACustomOp(Base):
kwargs = tree_map(lambda a: a.value if isinstance(a, JaxArray) else a,
kwargs, is_leaf=lambda a: isinstance(a, JaxArray))
res = self.op.bind(*args, **kwargs)
return res[0] if len(res) == 1 else res


def register_op(
name: str,
eval_shape: Union[Callable, ShapedArray, Sequence[ShapedArray]],
cpu_func: Callable,
gpu_func: Callable = None,
apply_cpu_func_to_gpu: bool = False
):
"""
Converting the numba-jitted function in a Jax/XLA compatible primitive.
return res

Parameters
----------
name: str
Name of the operators.
cpu_func: Callble
A callable numba-jitted function or pure function (can be lambda function) running on CPU.
gpu_func: Callable, default = None
A callable cuda-jitted kernel running on GPU.
eval_shape: Callable, ShapedArray, Sequence[ShapedArray], default = None
Outputs shapes of target function. `out_shapes` can be a `ShapedArray` or
a sequence of `ShapedArray`. If it is a function, it takes as input the argument
shapes and dtypes and should return correct output shapes of `ShapedArray`.
apply_cpu_func_to_gpu: bool, default = False
True when gpu_func is implemented on CPU and other logics(data transfer) is implemented on GPU.

Returns
-------
A jitable JAX function.
"""
_check_brainpylib(register_op.__name__)
f = brainpylib.register_op(name,
cpu_func=cpu_func,
gpu_func=gpu_func,
out_shapes=eval_shape,
apply_cpu_func_to_gpu=apply_cpu_func_to_gpu)

def fixed_op(*inputs):
inputs = tuple([i.value if isinstance(i, JaxArray) else i for i in inputs])
res = f.bind(*inputs)
return res[0] if len(res) == 1 else res

return fixed_op
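
A hedged usage sketch of the new ``XLACustomOp`` path, modeled on the test file changed later in this PR. The abstract-eval signature below is an assumption (the test's ``abs_eval`` body is elided above), and the cpu function follows brainpylib's numba ``(outs, ins)`` convention:

import brainpy.math as bm
from jax.abstract_arrays import ShapedArray

def abs_eval(events, indices, indptr, post_val, values):
  # output buffer has the same shape and dtype as `post_val`
  return ShapedArray(post_val.shape, post_val.dtype)

def event_sum_op(outs, ins):
  events, indices, indptr, post_val, values = ins
  outs.fill(0.)
  for i in range(events.size):
    if events[i]:
      for j in range(indptr[i], indptr[i + 1]):
        outs[indices[j]] += values

event_sum = bm.XLACustomOp(name='my_event_sum', cpu_func=event_sum_op, eval_shape=abs_eval)
# called like: event_sum(spikes, post_ids, indptr, bm.zeros(post_num), weight)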

+ 0
- 489
brainpy/math/operators/pre2post.py View File

@@ -1,489 +0,0 @@
# -*- coding: utf-8 -*-

from functools import partial
from typing import Union, Tuple

import jax.numpy as jnp
from jax import vmap, jit
from jax.lax import cond

from brainpy.errors import MathError
from brainpy.math.jaxarray import JaxArray
from brainpy.math.numpy_ops import as_device_array
from brainpy.types import Array
from .pre2syn import pre2syn
from .syn2post import syn2post_mean
from .utils import _check_brainpylib

try:
import brainpylib
except ModuleNotFoundError:
brainpylib = None

__all__ = [
# pre-to-post
'pre2post_sum',
'pre2post_prod',
'pre2post_max',
'pre2post_min',
'pre2post_mean',

# pre-to-post event operator
'pre2post_event_sum',
'pre2post_event_prod',

]


def _raise_pre_ids_is_none(pre_ids):
if pre_ids is None:
raise MathError(f'pre2post synaptic computation needs "pre_ids" '
f'when providing heterogeneous "pre_values" '
f'(brainpy.math.ndim(pre_values) != 0).')


def pre2post_event_sum(events: Array,
pre2post: Tuple[Array, Array],
post_num: int,
values: Union[float, Array] = 1.):
"""The pre-to-post synaptic computation with event-driven summation.

When ``values`` is a scalar, this function is equivalent to

.. highlight:: python
.. code-block:: python

post_val = np.zeros(post_num)
post_ids, idnptr = pre2post
for i in range(pre_num):
if events[i]:
for j in range(idnptr[i], idnptr[i+1]):
post_val[post_ids[i]] += values

When ``values`` is a vector (with the length of ``len(post_ids)``),
this function is equivalent to

.. highlight:: python
.. code-block:: python

post_val = np.zeros(post_num)

post_ids, idnptr = pre2post
for i in range(pre_num):
if events[i]:
for j in range(idnptr[i], idnptr[i+1]):
post_val[post_ids[i]] += values[j]


Parameters
----------
events: Array
The events, must be bool.
pre2post: tuple of Array, tuple of Array
A tuple contains the connection information of pre-to-post.
post_num: int
The number of post-synaptic group.
values: float, Array
The value to make summation.

Returns
-------
out: JaxArray, jax.numpy.ndarray
A tensor with the shape of ``post_num``.
"""
_check_brainpylib(pre2post_event_sum.__name__)
indices, idnptr = pre2post
events = as_device_array(events)
indices = as_device_array(indices)
idnptr = as_device_array(idnptr)
values = as_device_array(values)
return brainpylib.event_sum(events, (indices, idnptr), post_num, values)


def pre2post_event_sum2(events: Array,
pre2post: Tuple[Array, Array],
post_num: int,
values: Union[float, Array] = 1.):
"""The pre-to-post synaptic computation with event-driven summation.

When ``values`` is a scalar, this function is equivalent to

.. highlight:: python
.. code-block:: python

post_val = np.zeros(post_num)
post_ids, idnptr = pre2post
for i in range(pre_num):
if events[i]:
for j in range(idnptr[i], idnptr[i+1]):
post_val[post_ids[i]] += values

When ``values`` is a vector (with the length of ``len(post_ids)``),
this function is equivalent to

.. highlight:: python
.. code-block:: python

post_val = np.zeros(post_num)

post_ids, idnptr = pre2post
for i in range(pre_num):
if events[i]:
for j in range(idnptr[i], idnptr[i+1]):
post_val[post_ids[i]] += values[j]


Parameters
----------
events: Array
The events, must be bool.
pre2post: tuple of Array, tuple of Array
A tuple contains the connection information of pre-to-post.
post_num: int
The number of post-synaptic group.
values: float, Array
The value to make summation.

Returns
-------
out: JaxArray, jax.numpy.ndarray
A tensor with the shape of ``post_num``.
"""
_check_brainpylib(pre2post_event_sum.__name__)
indices, idnptr = pre2post
events = as_device_array(events)
indices = as_device_array(indices)
idnptr = as_device_array(idnptr)
values = as_device_array(values)
return brainpylib.event_sum2(events, (indices, idnptr), post_num, values)


def pre2post_event_prod(events, pre2post, post_num, values=1.):
"""The pre-to-post synaptic computation with event-driven production.

When ``values`` is a scalar, this function is equivalent to

.. highlight:: python
.. code-block:: python

post_val = np.ones(post_num)
post_ids, idnptr = pre2post
for i in range(pre_num):
if events[i]:
for j in range(idnptr[i], idnptr[i+1]):
post_val[post_ids[i]] *= values

When ``values`` is a vector (with the length of ``len(post_ids)``),
this function is equivalent to

.. highlight:: python
.. code-block:: python

post_val = np.ones(post_num)

post_ids, idnptr = pre2post
for i in range(pre_num):
if events[i]:
for j in range(idnptr[i], idnptr[i+1]):
post_val[post_ids[i]] *= values[j]


Parameters
----------
events: JaxArray, jax.numpy.ndarray, Variable
The events, must be bool.
pre2post: tuple of JaxArray, tuple of jax.numpy.ndarray
A tuple contains the connection information of pre-to-post.
post_num: int
The number of post-synaptic group.
values: float, JaxArray, jax.numpy.ndarray
The value to make summation.

Returns
-------
out: JaxArray, jax.numpy.ndarray
A tensor with the shape of ``post_num``.
"""
_check_brainpylib(pre2post_event_prod.__name__)
indices, idnptr = pre2post
events = as_device_array(events)
indices = as_device_array(indices)
idnptr = as_device_array(idnptr)
values = as_device_array(values)
return brainpylib.event_prod(events, (indices, idnptr), post_num, values)


def pre2post_sum(pre_values, post_num, post_ids, pre_ids=None):
"""The pre-to-post synaptic summation.

This function is equivalent to:

.. highlight:: python
.. code-block:: python

post_val = np.zeros(post_num)
for i, j in zip(pre_ids, post_ids):
post_val[j] += pre_values[pre_ids[i]]

Parameters
----------
pre_values: float, jax.numpy.ndarray, JaxArray, Variable
The pre-synaptic values.
post_ids: jax.numpy.ndarray, JaxArray
The connected post-synaptic neuron ids.
post_num: int
Output dimension. The number of post-synaptic neurons.
pre_ids: optional, jax.numpy.ndarray, JaxArray
The connected pre-synaptic neuron ids.

Returns
-------
post_val: jax.numpy.ndarray, JaxArray
The value with the size of post-synaptic neurons.
"""
out = jnp.zeros(post_num)
pre_values = as_device_array(pre_values)
post_ids = as_device_array(post_ids)
if jnp.ndim(pre_values) != 0:
_raise_pre_ids_is_none(pre_ids)
pre_ids = as_device_array(pre_ids)
pre_values = pre_values[pre_ids]
return out.at[post_ids].add(pre_values)


def pre2post_prod(pre_values, post_num, post_ids, pre_ids=None):
"""The pre-to-post synaptic production.

This function is equivalent to:

.. highlight:: python
.. code-block:: python

post_val = np.zeros(post_num)
for i, j in zip(pre_ids, post_ids):
post_val[j] *= pre_values[pre_ids[i]]

Parameters
----------
pre_values: float, jax.numpy.ndarray, JaxArray, Variable
The pre-synaptic values.
pre_ids: jax.numpy.ndarray, JaxArray
The connected pre-synaptic neuron ids.
post_ids: jax.numpy.ndarray, JaxArray
The connected post-synaptic neuron ids.
post_num: int
Output dimension. The number of post-synaptic neurons.

Returns
-------
post_val: jax.numpy.ndarray, JaxArray
The value with the size of post-synaptic neurons.
"""
out = jnp.zeros(post_num)
pre_values = as_device_array(pre_values)
post_ids = as_device_array(post_ids)
if jnp.ndim(pre_values) != 0:
_raise_pre_ids_is_none(pre_ids)
pre_ids = as_device_array(pre_ids)
pre_values = pre_values[pre_ids]
return out.at[post_ids].multiply(pre_values)


def pre2post_min(pre_values, post_num, post_ids, pre_ids=None):
"""The pre-to-post synaptic minimization.

This function is equivalent to:

.. highlight:: python
.. code-block:: python

post_val = np.zeros(post_num)
for i, j in zip(pre_ids, post_ids):
post_val[j] = np.minimum(post_val[j], pre_values[pre_ids[i]])

Parameters
----------
pre_values: float, jax.numpy.ndarray, JaxArray
The pre-synaptic values.
pre_ids: jax.numpy.ndarray, JaxArray
The connected pre-synaptic neuron ids.
post_ids: jax.numpy.ndarray, JaxArray
The connected post-synaptic neuron ids.
post_num: int
Output dimension. The number of post-synaptic neurons.

Returns
-------
post_val: jax.numpy.ndarray, JaxArray
The value with the size of post-synaptic neurons.
"""
out = jnp.zeros(post_num)
pre_values = as_device_array(pre_values)
post_ids = as_device_array(post_ids)
if jnp.ndim(pre_values) != 0:
_raise_pre_ids_is_none(pre_ids)
pre_ids = as_device_array(pre_ids)
pre_values = pre_values[pre_ids]
return out.at[post_ids].min(pre_values)


def pre2post_max(pre_values, post_num, post_ids, pre_ids=None):
"""The pre-to-post synaptic maximization.

This function is equivalent to:

.. highlight:: python
.. code-block:: python

post_val = np.zeros(post_num)
for i, j in zip(pre_ids, post_ids):
post_val[j] = np.maximum(post_val[j], pre_values[pre_ids[i]])

Parameters
----------
pre_values: float, jax.numpy.ndarray, JaxArray, Variable
The pre-synaptic values.
pre_ids: jax.numpy.ndarray, JaxArray
The connected pre-synaptic neuron ids.
post_ids: jax.numpy.ndarray, JaxArray
The connected post-synaptic neuron ids.
post_num: int
Output dimension. The number of post-synaptic neurons.

Returns
-------
post_val: jax.numpy.ndarray, JaxArray
The value with the size of post-synaptic neurons.
"""
out = jnp.zeros(post_num)
pre_values = as_device_array(pre_values)
post_ids = as_device_array(post_ids)
if jnp.ndim(pre_values) != 0:
_raise_pre_ids_is_none(pre_ids)
pre_ids = as_device_array(pre_ids)
pre_values = pre_values[pre_ids]
return out.at[post_ids].max(pre_values)


def pre2post_mean(pre_values, post_num, post_ids, pre_ids=None):
"""The pre-to-post synaptic mean computation.

Parameters
----------
pre_values: float, jax.numpy.ndarray, JaxArray, Variable
The pre-synaptic values.
pre_ids: jax.numpy.ndarray, JaxArray
The connected pre-synaptic neuron ids.
post_ids: jax.numpy.ndarray, JaxArray
The connected post-synaptic neuron ids.
post_num: int
Output dimension. The number of post-synaptic neurons.

Returns
-------
post_val: jax.numpy.ndarray, JaxArray
The value with the size of post-synaptic neurons.
"""
out = jnp.zeros(post_num)
pre_values = as_device_array(pre_values)
post_ids = as_device_array(post_ids)
if jnp.ndim(pre_values) == 0:
# return out.at[post_ids].set(pre_values)
return out.at[jnp.unique(post_ids)].set(pre_values)
else:
_raise_pre_ids_is_none(pre_ids)
pre_ids = as_device_array(pre_ids)
pre_values = pre2syn(pre_values, pre_ids)
return syn2post_mean(pre_values, post_ids, post_num)


def pre2post_matmul(event, conn):
event = event.value if isinstance(event, JaxArray) else event
Cl = conn[0].value if isinstance(conn[0], JaxArray) else conn[0]
Cr = conn[1].value if isinstance(conn[1], JaxArray) else conn[1]
if jnp.ndim(event) != 1:
raise ValueError(f'"event" must be a one-dimensional vector. But we got {jnp.shape(event)}')
if jnp.ndim(Cl) != 2:
raise ValueError(f'"conn" must be a two-dimensional matrix. But we got {jnp.shape(Cl)}')
if jnp.ndim(Cr) != 2:
raise ValueError(f'"conn" must be a two-dimensional matrix. But we got {jnp.shape(Cr)}')

f0 = vmap(lambda i, j: event[i] * (Cl[i] * Cr[:, j]).sum(), in_axes=(0, None))
ii = jnp.arange(Cl.shape[0])
f1 = vmap(lambda j: f0(ii, j).sum(), in_axes=(None, 0))
return f1(jnp.arange(Cr.shape[1]))


def pre2post_matmul2(event, conn):
event = event.value if isinstance(event, JaxArray) else event
Cl = conn[0].value if isinstance(conn[0], JaxArray) else conn[0]
Cr = conn[1].value if isinstance(conn[1], JaxArray) else conn[1]
if jnp.ndim(event) != 1:
raise ValueError(f'"event" must be a one-dimensional vector. But we got {jnp.shape(event)}')
if jnp.ndim(Cl) != 2:
raise ValueError(f'"conn" must be a two-dimensional matrix. But we got {jnp.shape(Cl)}')
if jnp.ndim(Cr) != 2:
raise ValueError(f'"conn" must be a two-dimensional matrix. But we got {jnp.shape(Cr)}')
f1 = vmap(lambda j: (event * (Cl * Cr[:, j]).sum(1)).sum())
return f1(jnp.arange(Cr.shape[1]))


def pre2post_matmul_mask(event, conn, mask):
event = event.value if isinstance(event, JaxArray) else event
Cl = conn[0].value if isinstance(conn[0], JaxArray) else conn[0]
Cr = conn[1].value if isinstance(conn[1], JaxArray) else conn[1]
Ml = mask[0].value if isinstance(mask[0], JaxArray) else mask[0]
Mr = mask[1].value if isinstance(mask[1], JaxArray) else mask[1]
if jnp.ndim(event) != 1:
raise ValueError(f'"event" must be a one-dimensional vector. But we got {jnp.shape(event)}')
if jnp.ndim(Cl) != 2:
raise ValueError(f'"conn" must be a two-dimensional matrix. But we got {jnp.shape(Cl)}')
if jnp.ndim(Cr) != 2:
raise ValueError(f'"conn" must be a two-dimensional matrix. But we got {jnp.shape(Cr)}')
if jnp.ndim(Mr) != 2:
raise ValueError(f'"mask" must be a two-dimensional matrix. But we got {jnp.shape(Mr)}')
if jnp.ndim(Ml) != 2:
raise ValueError(f'"mask" must be a two-dimensional matrix. But we got {jnp.shape(Ml)}')

f0 = vmap(lambda i, j: event[i] * (Cl[i] * Cr[:, j]).sum() * (Ml[i] * Mr[:, j]).sum(), in_axes=(0, None))
f1 = jit(vmap(lambda ii, j: f0(ii, j).sum(), in_axes=(None, 0)))
return f1(jnp.arange(Cl.shape[0]), jnp.arange(Cr.shape[1]))


def pre2post_matmul_mask2(event, conn, mask):
event = event.value if isinstance(event, JaxArray) else event
Cl = conn[0].value if isinstance(conn[0], JaxArray) else conn[0]
Cr = conn[1].value if isinstance(conn[1], JaxArray) else conn[1]
Ml = mask[0].value if isinstance(mask[0], JaxArray) else mask[0]
Mr = mask[1].value if isinstance(mask[1], JaxArray) else mask[1]
if jnp.ndim(event) != 1:
raise ValueError(f'"event" must be a one-dimensional vector. But we got {jnp.shape(event)}')
if jnp.ndim(Cl) != 2:
raise ValueError(f'"conn" must be a two-dimensional matrix. But we got {jnp.shape(Cl)}')
if jnp.ndim(Cr) != 2:
raise ValueError(f'"conn" must be a two-dimensional matrix. But we got {jnp.shape(Cr)}')
if jnp.ndim(Mr) != 2:
raise ValueError(f'"mask" must be a two-dimensional matrix. But we got {jnp.shape(Mr)}')
if jnp.ndim(Ml) != 2:
raise ValueError(f'"mask" must be a two-dimensional matrix. But we got {jnp.shape(Ml)}')

# f0 = vmap(lambda i, j: event[i] * (Cl[i] * Cr[:, j]).sum() * (Ml[i] * Mr[:, j]).sum(), in_axes=(0, None))
@partial(vmap, in_axes=(0, None))
def f0(i, j):
return cond(event[i],
lambda: cond(Ml[i] @ Mr[:, j],
lambda: (Cl[i] * Cr[:, j]).sum(),
lambda: 0.),
lambda: 0.)

ii = jnp.arange(Cl.shape[0])
jj = jnp.arange(Cr.shape[1])

# def body(_, j):
# r = f0(ii, j).sum()
# return 0, r
# _, out = scan(body, 0, jj)
# return out

f = jit(vmap(lambda j: f0(ii, j).sum()))
return f(jj)

+ 0
- 47
brainpy/math/operators/pre2syn.py View File

@@ -1,47 +0,0 @@
# -*- coding: utf-8 -*-

import jax.numpy as jnp
from jax import vmap

from brainpy.math.numpy_ops import as_device_array

__all__ = [
'pre2syn'
]


_pre2syn = vmap(lambda pre_id, pre_vs: pre_vs[pre_id], in_axes=(0, None))


def pre2syn(pre_values, pre_ids):
"""The pre-to-syn computation.

Change the pre-synaptic data to the data with the dimension of synapses.

This function is equivalent to:

.. highlight:: python
.. code-block:: python

syn_val = np.zeros(len(pre_ids))
for syn_i, pre_i in enumerate(pre_ids):
syn_val[i] = pre_values[pre_i]

Parameters
----------
pre_values: float, jax.numpy.ndarray, JaxArray, Variable
The pre-synaptic value.
pre_ids: jax.numpy.ndarray, JaxArray
The pre-synaptic neuron index.

Returns
-------
syn_val: jax.numpy.ndarray, JaxArray
The synaptic value.
"""
pre_values = as_device_array(pre_values)
pre_ids = as_device_array(pre_ids)
if jnp.ndim(pre_values) == 0:
return jnp.ones(len(pre_ids), dtype=pre_values.dtype) * pre_values
else:
return _pre2syn(pre_ids, pre_values)

+ 625
- 0
brainpy/math/operators/pre_syn_post.py View File

@@ -0,0 +1,625 @@
# -*- coding: utf-8 -*-

from typing import Union, Tuple

import brainpylib
import jax.numpy as jnp
from jax import vmap, jit, ops as jops

from brainpy.errors import MathError
from brainpy.math.numpy_ops import as_jax
from brainpy.types import Array

__all__ = [
# pre-to-post
'pre2post_sum',
'pre2post_prod',
'pre2post_max',
'pre2post_min',
'pre2post_mean',

# pre-to-post event operator
'pre2post_event_sum',
'pre2post_coo_event_sum',
'pre2post_event_prod',

# pre-to-syn
'pre2syn',

# syn-to-post
'syn2post_sum', 'syn2post',
'syn2post_prod',
'syn2post_max',
'syn2post_min',
'syn2post_mean',
'syn2post_softmax',
]


def _raise_pre_ids_is_none(pre_ids):
if pre_ids is None:
raise MathError(f'pre2post synaptic computation needs "pre_ids" '
f'when providing heterogeneous "pre_values" '
f'(brainpy.math.ndim(pre_values) != 0).')


def pre2post_event_sum(events: Array,
pre2post: Tuple[Array, Array],
post_num: int,
values: Union[float, Array] = 1.):
"""The pre-to-post event-driven synaptic summation with `CSR` synapse structure.

When ``values`` is a scalar, this function is equivalent to

.. highlight:: python
.. code-block:: python

post_val = np.zeros(post_num)
post_ids, idnptr = pre2post
for i in range(pre_num):
if events[i]:
for j in range(idnptr[i], idnptr[i+1]):
post_val[post_ids[i]] += values

When ``values`` is a vector (with the length of ``len(post_ids)``),
this function is equivalent to

.. highlight:: python
.. code-block:: python

post_val = np.zeros(post_num)

post_ids, idnptr = pre2post
for i in range(pre_num):
if events[i]:
for j in range(idnptr[i], idnptr[i+1]):
post_val[post_ids[i]] += values[j]


Parameters
----------
events: Array
The events, must be bool.
pre2post: tuple of Array, tuple of Array
A tuple containing the connection information of the pre-to-post projection.
post_num: int
The number of neurons in the post-synaptic group.
values: float, Array
The value used for the summation.

Returns
-------
out: JaxArray, jax.numpy.ndarray
A tensor with the shape of ``post_num``.
"""
indices, idnptr = pre2post
events = as_jax(events)
indices = as_jax(indices)
idnptr = as_jax(idnptr)
values = as_jax(values)
return brainpylib.event_csr_matvec(values, indices, idnptr, events,
shape=(events.shape[0], post_num),
transpose=True)


def pre2post_coo_event_sum(events: Array,
pre_ids: Array,
post_ids: Array,
post_num: int,
values: Union[float, Array] = 1.):
"""The pre-to-post synaptic computation with event-driven summation.

Parameters
----------
events: Array
The events, must be bool.
pre_ids: Array
Pre-synaptic ids.
post_ids: Array
Post-synaptic ids.
post_num: int
The number of neurons in the post-synaptic group.
values: float, Array
The value used for the summation.

Returns
-------
out: JaxArray, jax.numpy.ndarray
A tensor with the shape of ``post_num``.
"""
events = as_jax(events)
post_ids = as_jax(post_ids)
pre_ids = as_jax(pre_ids)
values = as_jax(values)
return brainpylib.coo_event_sum(events, pre_ids, post_ids, post_num, values)
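
Like the CSR variant above, the COO event summation is roughly equivalent to the following loop (a sketch for intuition; the actual kernel is ``brainpylib.coo_event_sum``):

import numpy as np

def coo_event_sum_ref(events, pre_ids, post_ids, post_num, values=1.):
  values = np.broadcast_to(values, pre_ids.shape)   # allow a scalar weight
  post_val = np.zeros(post_num)
  for k, (i, j) in enumerate(zip(pre_ids, post_ids)):
    if events[i]:
      post_val[j] += values[k]
  return post_val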


def pre2post_event_prod(events, pre2post, post_num, values=1.):
"""The pre-to-post synaptic computation with event-driven production.

When ``values`` is a scalar, this function is equivalent to

.. highlight:: python
.. code-block:: python

post_val = np.ones(post_num)
post_ids, idnptr = pre2post
for i in range(pre_num):
if events[i]:
for j in range(idnptr[i], idnptr[i+1]):
post_val[post_ids[i]] *= values

When ``values`` is a vector (with the length of ``len(post_ids)``),
this function is equivalent to

.. highlight:: python
.. code-block:: python

post_val = np.ones(post_num)

post_ids, idnptr = pre2post
for i in range(pre_num):
if events[i]:
for j in range(idnptr[i], idnptr[i+1]):
post_val[post_ids[i]] *= values[j]


Parameters
----------
events: JaxArray, jax.numpy.ndarray, Variable
The events, must be bool.
pre2post: tuple of JaxArray, tuple of jax.numpy.ndarray
A tuple containing the connection information of the pre-to-post projection.
post_num: int
The number of neurons in the post-synaptic group.
values: float, JaxArray, jax.numpy.ndarray
The value used for the production.

Returns
-------
out: JaxArray, jax.numpy.ndarray
A tensor with the shape of ``post_num``.
"""
indices, idnptr = pre2post
events = as_jax(events)
indices = as_jax(indices)
idnptr = as_jax(idnptr)
values = as_jax(values)
return brainpylib.csr_event_prod(events, (indices, idnptr), post_num, values)


def pre2post_sum(pre_values, post_num, post_ids, pre_ids=None):
"""The pre-to-post synaptic summation.

This function is equivalent to:

.. highlight:: python
.. code-block:: python

post_val = np.zeros(post_num)
for i, j in zip(pre_ids, post_ids):
post_val[j] += pre_values[i]

Parameters
----------
pre_values: float, jax.numpy.ndarray, JaxArray, Variable
The pre-synaptic values.
post_ids: jax.numpy.ndarray, JaxArray
The connected post-synaptic neuron ids.
post_num: int
Output dimension. The number of post-synaptic neurons.
pre_ids: optional, jax.numpy.ndarray, JaxArray
The connected pre-synaptic neuron ids.

Returns
-------
post_val: jax.numpy.ndarray, JaxArray
The value with the size of post-synaptic neurons.
"""
out = jnp.zeros(post_num)
pre_values = as_jax(pre_values)
post_ids = as_jax(post_ids)
if jnp.ndim(pre_values) != 0:
_raise_pre_ids_is_none(pre_ids)
pre_ids = as_jax(pre_ids)
pre_values = pre_values[pre_ids]
return out.at[post_ids].add(pre_values)


def pre2post_prod(pre_values, post_num, post_ids, pre_ids=None):
"""The pre-to-post synaptic production.

This function is equivalent to:

.. highlight:: python
.. code-block:: python

post_val = np.zeros(post_num)
for i, j in zip(pre_ids, post_ids):
post_val[j] *= pre_values[i]

Parameters
----------
pre_values: float, jax.numpy.ndarray, JaxArray, Variable
The pre-synaptic values.
pre_ids: jax.numpy.ndarray, JaxArray
The connected pre-synaptic neuron ids.
post_ids: jax.numpy.ndarray, JaxArray
The connected post-synaptic neuron ids.
post_num: int
Output dimension. The number of post-synaptic neurons.

Returns
-------
post_val: jax.numpy.ndarray, JaxArray
The value with the size of post-synaptic neurons.
"""
out = jnp.zeros(post_num)
pre_values = as_jax(pre_values)
post_ids = as_jax(post_ids)
if jnp.ndim(pre_values) != 0:
_raise_pre_ids_is_none(pre_ids)
pre_ids = as_jax(pre_ids)
pre_values = pre_values[pre_ids]
return out.at[post_ids].multiply(pre_values)


def pre2post_min(pre_values, post_num, post_ids, pre_ids=None):
"""The pre-to-post synaptic minimization.

This function is equivalent to:

.. highlight:: python
.. code-block:: python

post_val = np.zeros(post_num)
for i, j in zip(pre_ids, post_ids):
post_val[j] = np.minimum(post_val[j], pre_values[i])

Parameters
----------
pre_values: float, jax.numpy.ndarray, JaxArray
The pre-synaptic values.
pre_ids: jax.numpy.ndarray, JaxArray
The connected pre-synaptic neuron ids.
post_ids: jax.numpy.ndarray, JaxArray
The connected post-synaptic neuron ids.
post_num: int
Output dimension. The number of post-synaptic neurons.

Returns
-------
post_val: jax.numpy.ndarray, JaxArray
The value with the size of post-synaptic neurons.
"""
out = jnp.zeros(post_num)
pre_values = as_jax(pre_values)
post_ids = as_jax(post_ids)
if jnp.ndim(pre_values) != 0:
_raise_pre_ids_is_none(pre_ids)
pre_ids = as_jax(pre_ids)
pre_values = pre_values[pre_ids]
return out.at[post_ids].min(pre_values)


def pre2post_max(pre_values, post_num, post_ids, pre_ids=None):
"""The pre-to-post synaptic maximization.

This function is equivalent to:

.. highlight:: python
.. code-block:: python

post_val = np.zeros(post_num)
for i, j in zip(pre_ids, post_ids):
post_val[j] = np.maximum(post_val[j], pre_values[i])

Parameters
----------
pre_values: float, jax.numpy.ndarray, JaxArray, Variable
The pre-synaptic values.
pre_ids: jax.numpy.ndarray, JaxArray
The connected pre-synaptic neuron ids.
post_ids: jax.numpy.ndarray, JaxArray
The connected post-synaptic neuron ids.
post_num: int
Output dimension. The number of post-synaptic neurons.

Returns
-------
post_val: jax.numpy.ndarray, JaxArray
The value with the size of post-synaptic neurons.
"""
out = jnp.zeros(post_num)
pre_values = as_jax(pre_values)
post_ids = as_jax(post_ids)
if jnp.ndim(pre_values) != 0:
_raise_pre_ids_is_none(pre_ids)
pre_ids = as_jax(pre_ids)
pre_values = pre_values[pre_ids]
return out.at[post_ids].max(pre_values)


def pre2post_mean(pre_values, post_num, post_ids, pre_ids=None):
"""The pre-to-post synaptic mean computation.

Parameters
----------
pre_values: float, jax.numpy.ndarray, JaxArray, Variable
The pre-synaptic values.
pre_ids: jax.numpy.ndarray, JaxArray
The connected pre-synaptic neuron ids.
post_ids: jax.numpy.ndarray, JaxArray
The connected post-synaptic neuron ids.
post_num: int
Output dimension. The number of post-synaptic neurons.

Returns
-------
post_val: jax.numpy.ndarray, JaxArray
The value with the size of post-synaptic neurons.
"""
out = jnp.zeros(post_num)
pre_values = as_jax(pre_values)
post_ids = as_jax(post_ids)
if jnp.ndim(pre_values) == 0:
return out.at[post_ids].set(pre_values)
# return out.at[jnp.unique(post_ids)].set(pre_values)
else:
_raise_pre_ids_is_none(pre_ids)
pre_ids = as_jax(pre_ids)
pre_values = pre2syn(pre_values, pre_ids)
return syn2post_mean(pre_values, post_ids, post_num)


_pre2syn = vmap(lambda pre_id, pre_vs: pre_vs[pre_id], in_axes=(0, None))


def pre2syn(pre_values, pre_ids):
"""The pre-to-syn computation.

Change the pre-synaptic data to the data with the dimension of synapses.

This function is equivalent to:

.. highlight:: python
.. code-block:: python

syn_val = np.zeros(len(pre_ids))
for syn_i, pre_i in enumerate(pre_ids):
syn_val[syn_i] = pre_values[pre_i]

Parameters
----------
pre_values: float, jax.numpy.ndarray, JaxArray, Variable
The pre-synaptic value.
pre_ids: jax.numpy.ndarray, JaxArray
The pre-synaptic neuron index.

Returns
-------
syn_val: jax.numpy.ndarray, JaxArray
The synaptic value.
"""
pre_values = as_jax(pre_values)
pre_ids = as_jax(pre_ids)
if jnp.ndim(pre_values) == 0:
return jnp.ones(len(pre_ids), dtype=pre_values.dtype) * pre_values
else:
return _pre2syn(pre_ids, pre_values)


_jit_seg_sum = jit(jops.segment_sum, static_argnums=(2, 3))
_jit_seg_prod = jit(jops.segment_prod, static_argnums=(2, 3))
_jit_seg_max = jit(jops.segment_max, static_argnums=(2, 3))
_jit_seg_min = jit(jops.segment_min, static_argnums=(2, 3))


def syn2post_sum(syn_values, post_ids, post_num: int, indices_are_sorted=False):
"""The syn-to-post summation computation.

This function is equivalent to:

.. highlight:: python
.. code-block:: python

post_val = np.zeros(post_num)
for syn_i, post_i in enumerate(post_ids):
post_val[post_i] += syn_values[syn_i]

Parameters
----------
syn_values: jax.numpy.ndarray, JaxArray, Variable
The synaptic values.
post_ids: jax.numpy.ndarray, JaxArray
The post-synaptic neuron ids.
post_num: int
The number of the post-synaptic neurons.

Returns
-------
post_val: jax.numpy.ndarray, JaxArray
The post-synaptic value.
"""
post_ids = as_jax(post_ids)
syn_values = as_jax(syn_values)
if syn_values.dtype == jnp.bool_:
syn_values = jnp.asarray(syn_values, dtype=jnp.int32)
return _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted)


syn2post = syn2post_sum


def syn2post_prod(syn_values, post_ids, post_num: int, indices_are_sorted=False):
"""The syn-to-post product computation.

This function is equivalent to:

.. highlight:: python
.. code-block:: python

post_val = np.ones(post_num)
for syn_i, post_i in enumerate(post_ids):
post_val[post_i] *= syn_values[syn_i]

Parameters
----------
syn_values: jax.numpy.ndarray, JaxArray, Variable
The synaptic values.
post_ids: jax.numpy.ndarray, JaxArray
The post-synaptic neuron ids. If ``post_ids`` is generated by
``brainpy.conn.TwoEndConnector``, then it has sorted indices.
Otherwise, this function cannot guarantee indices are sorted.
You'd better set ``indices_are_sorted=False``.
post_num: int
The number of the post-synaptic neurons.
indices_are_sorted: whether ``post_ids`` is known to be sorted.

Returns
-------
post_val: jax.numpy.ndarray, JaxArray
The post-synaptic value.
"""
post_ids = as_jax(post_ids)
syn_values = as_jax(syn_values)
if syn_values.dtype == jnp.bool_:
syn_values = jnp.asarray(syn_values, dtype=jnp.int32)
return _jit_seg_prod(syn_values, post_ids, post_num, indices_are_sorted)


def syn2post_max(syn_values, post_ids, post_num: int, indices_are_sorted=False):
"""The syn-to-post maximum computation.

This function is equivalent to:

.. highlight:: python
.. code-block:: python

post_val = np.zeros(post_num)
for syn_i, post_i in enumerate(post_ids):
post_val[post_i] = np.maximum(post_val[post_i], syn_values[syn_i])

Parameters
----------
syn_values: jax.numpy.ndarray, JaxArray, Variable
The synaptic values.
post_ids: jax.numpy.ndarray, JaxArray
The post-synaptic neuron ids. If ``post_ids`` is generated by
``brainpy.conn.TwoEndConnector``, then it has sorted indices.
Otherwise, this function cannot guarantee indices are sorted.
You'd better set ``indices_are_sorted=False``.
post_num: int
The number of the post-synaptic neurons.
indices_are_sorted: whether ``post_ids`` is known to be sorted.

Returns
-------
post_val: jax.numpy.ndarray, JaxArray
The post-synaptic value.
"""
post_ids = as_jax(post_ids)
syn_values = as_jax(syn_values)
if syn_values.dtype == jnp.bool_:
syn_values = jnp.asarray(syn_values, dtype=jnp.int32)
return _jit_seg_max(syn_values, post_ids, post_num, indices_are_sorted)


def syn2post_min(syn_values, post_ids, post_num: int, indices_are_sorted=False):
"""The syn-to-post minimization computation.

This function is equivalent to:

.. highlight:: python
.. code-block:: python

post_val = np.zeros(post_num)
for syn_i, post_i in enumerate(post_ids):
post_val[post_i] = np.minimum(post_val[post_i], syn_values[syn_i])

Parameters
----------
syn_values: jax.numpy.ndarray, JaxArray, Variable
The synaptic values.
post_ids: jax.numpy.ndarray, JaxArray
The post-synaptic neuron ids. If ``post_ids`` is generated by
``brainpy.conn.TwoEndConnector``, then it has sorted indices.
Otherwise, this function cannot guarantee indices are sorted.
You'd better set ``indices_are_sorted=False``.
post_num: int
The number of the post-synaptic neurons.
indices_are_sorted: whether ``post_ids`` is known to be sorted.

Returns
-------
post_val: jax.numpy.ndarray, JaxArray
The post-synaptic value.
"""
post_ids = as_jax(post_ids)
syn_values = as_jax(syn_values)
if syn_values.dtype == jnp.bool_:
syn_values = jnp.asarray(syn_values, dtype=jnp.int32)
return _jit_seg_min(syn_values, post_ids, post_num, indices_are_sorted)


def syn2post_mean(syn_values, post_ids, post_num: int, indices_are_sorted=False):
"""The syn-to-post mean computation.

Parameters
----------
syn_values: jax.numpy.ndarray, JaxArray, Variable
The synaptic values.
post_ids: jax.numpy.ndarray, JaxArray
The post-synaptic neuron ids. If ``post_ids`` is generated by
``brainpy.conn.TwoEndConnector``, then it has sorted indices.
Otherwise, this function cannot guarantee indices are sorted.
You'd better set ``indices_are_sorted=False``.
post_num: int
The number of the post-synaptic neurons.
indices_are_sorted: whether ``post_ids`` is known to be sorted.

Returns
-------
post_val: jax.numpy.ndarray, JaxArray
The post-synaptic value.
"""
post_ids = as_jax(post_ids)
syn_values = as_jax(syn_values)
if syn_values.dtype == jnp.bool_:
syn_values = jnp.asarray(syn_values, dtype=jnp.int32)
numerator = _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted)
denominator = _jit_seg_sum(jnp.ones_like(syn_values), post_ids, post_num, indices_are_sorted)
return jnp.nan_to_num(numerator / denominator)


def syn2post_softmax(syn_values, post_ids, post_num: int, indices_are_sorted=False):
"""The syn-to-post softmax computation.

Parameters
----------
syn_values: jax.numpy.ndarray, JaxArray, Variable
The synaptic values.
post_ids: jax.numpy.ndarray, JaxArray
The post-synaptic neuron ids. If ``post_ids`` is generated by
``brainpy.conn.TwoEndConnector``, then it has sorted indices.
Otherwise, this function cannot guarantee indices are sorted.
You'd better set ``indices_are_sorted=False``.
post_num: int
The number of the post-synaptic neurons.
indices_are_sorted: whether ``post_ids`` is known to be sorted.

Returns
-------
post_val: jax.numpy.ndarray, JaxArray
The post-synaptic value.
"""
post_ids = as_jax(post_ids)
syn_values = as_jax(syn_values)
if syn_values.dtype == jnp.bool_:
syn_values = jnp.asarray(syn_values, dtype=jnp.int32)
syn_maxs = _jit_seg_max(syn_values, post_ids, post_num, indices_are_sorted)
syn_values = syn_values - syn_maxs[post_ids]
syn_values = jnp.exp(syn_values)
normalizers = _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted)
softmax = syn_values / normalizers[post_ids]
return jnp.nan_to_num(softmax)
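
A tiny worked example of how these syn-to-post reductions group synaptic values by their post id (a sketch; it assumes the operators are re-exported under ``brainpy.math``, and the printed types are jax arrays rather than plain lists):

import brainpy.math as bm

syn_values = bm.asarray([1., 2., 3., 4.])
post_ids = bm.asarray([0, 0, 1, 2])

print(bm.syn2post_sum(syn_values, post_ids, 3))      # [3., 3., 4.]
print(bm.syn2post_mean(syn_values, post_ids, 3))     # [1.5, 3., 4.]
print(bm.syn2post_softmax(syn_values, post_ids, 3))  # per-synapse: [~0.27, ~0.73, 1., 1.]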

brainpy/math/operators/multiplication.py → brainpy/math/operators/sparse_matmul.py View File

@@ -1,16 +1,18 @@
# -*- coding: utf-8 -*-

from typing import Union, Dict, Tuple

from typing import Union, Dict

import brainpylib
import jax.numpy as jnp
from jax import ops as jops
from jax import ops

from brainpy.math.jaxarray import JaxArray
from brainpy.math.numpy_ops import _remove_jaxarray
from brainpy.math.numpy_ops import as_jax
from brainpy.types import Array

__all__ = [
'sparse_matmul'
'sparse_matmul',
'csr_matvec',
]


@@ -42,16 +44,16 @@ def _matmul_with_left_sparse(
shape = sparse['shape']
if len(shape) != 2:
raise ValueError(f'Sparse matrix must be a two-dimensional matrix. But we got {shape}')
values = _remove_jaxarray(values)
rows = _remove_jaxarray(rows)
cols = _remove_jaxarray(cols)
dense = _remove_jaxarray(dense)
values = as_jax(values)
rows = as_jax(rows)
cols = as_jax(cols)
dense = as_jax(dense)
B = dense.take(cols, axis=0)
if B.ndim == 2:
prod = B * jnp.reshape(values, (-1, 1))
else:
prod = B * values
return jops.segment_sum(prod, rows, shape[0])
return ops.segment_sum(prod, rows, shape[0])


def _matmul_with_right_sparse(
@@ -82,17 +84,17 @@ def _matmul_with_right_sparse(
shape = sparse['shape']
if len(shape) != 2:
raise ValueError(f'Sparse matrix must be a two-dimensional matrix. But we got {shape}')
values = _remove_jaxarray(values)
rows = _remove_jaxarray(rows)
cols = _remove_jaxarray(cols)
dense = _remove_jaxarray(dense)
values = as_jax(values)
rows = as_jax(rows)
cols = as_jax(cols)
dense = as_jax(dense)
if dense.ndim == 2:
A = dense[:, rows]
prod = (A * values).T
res = jops.segment_sum(prod, cols, shape[1]).T
res = ops.segment_sum(prod, cols, shape[1]).T
else:
prod = dense[rows] * values
res = jops.segment_sum(prod, cols, shape[1])
res = ops.segment_sum(prod, cols, shape[1])
return res


@@ -164,3 +166,42 @@ def sparse_matmul(A, B):
f'A:\n{A}\n'
f'B:\n{B}')
return _matmul_with_right_sparse(A, B)


def csr_matvec(values: Array,
indices: Array,
indptr: Array,
vector: Array,
shape: Tuple[int, ...],
transpose: bool = False):
"""Product of CSR sparse matrix and a dense vector.

Parameters
----------
values: Array
An array of shape ``(nse,)``.
indices: ndarray
An array of shape ``(nse,)``.
indptr: Array
An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``.
vector: Array
An array of shape ``(shape[0] if transpose else shape[1],)``
and dtype ``values.dtype``.
shape: tuple of int
A length-2 tuple representing the matrix shape.
transpose: bool
A boolean specifying whether to transpose the sparse matrix
before computing.

Returns
-------
y : Array
The array of shape ``(shape[1] if transpose else shape[0],)`` representing
the matrix vector product.
"""
vector = as_jax(vector)
indices = as_jax(indices)
indptr = as_jax(indptr)
values = as_jax(values)
return brainpylib.csr_matvec(values, indices, indptr, vector,
shape=shape, transpose=transpose)
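
For reference, ``csr_matvec`` computes the same product as a SciPy CSR matrix would (a sketch for intuition, assuming SciPy is available; the actual computation is delegated to ``brainpylib.csr_matvec``):

import numpy as np
from scipy.sparse import csr_matrix

values = np.array([1., 2., 3.])
indices = np.array([0, 2, 1])      # column ids of the non-zeros
indptr = np.array([0, 2, 3])       # row i owns non-zeros indptr[i]:indptr[i+1]
shape = (2, 3)

A = csr_matrix((values, indices, indptr), shape=shape)
print(A @ np.ones(3))        # same result as csr_matvec(values, indices, indptr, vector, shape=shape)
print(A.T @ np.ones(2))      # same as passing transpose=True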

+ 2
- 2
brainpy/math/operators/spikegrad.py View File

@@ -158,9 +158,9 @@ def spike_with_gaussian_grad(x, sigma=None, scale=None):
dE_dx = dE_dz * _gaussian(x, 0., _sigma) * _scale
returns = (_consistent_type(dE_dx, x),)
if sigma is not None:
returns += (_consistent_type(bm.zeros_like(_sigma), sigma), )
returns += (_consistent_type(bm.zeros_like(_sigma), sigma),)
if scale is not None:
returns += (_consistent_type(bm.zeros_like(_scale), scale), )
returns += (_consistent_type(bm.zeros_like(_scale), scale),)
return returns

return z, grad


+ 0
- 235
brainpy/math/operators/syn2post.py View File

@@ -1,235 +0,0 @@
# -*- coding: utf-8 -*-

import jax.numpy as jnp
from jax import jit, vmap
from jax import ops as jops

from brainpy.math.numpy_ops import as_device_array


_jit_seg_sum = jit(jops.segment_sum, static_argnums=(2, 3))
_jit_seg_prod = jit(jops.segment_prod, static_argnums=(2, 3))
_jit_seg_max = jit(jops.segment_max, static_argnums=(2, 3))
_jit_seg_min = jit(jops.segment_min, static_argnums=(2, 3))


__all__ = [
'syn2post_sum', 'syn2post',
'syn2post_prod',
'syn2post_max',
'syn2post_min',
'syn2post_mean',
'syn2post_softmax',

]


def syn2post_sum(syn_values, post_ids, post_num: int, indices_are_sorted=True):
"""The syn-to-post summation computation.

This function is equivalent to:

.. highlight:: python
.. code-block:: python

post_val = np.zeros(post_num)
for syn_i, post_i in enumerate(post_ids):
post_val[post_i] += syn_values[syn_i]

Parameters
----------
syn_values: jax.numpy.ndarray, JaxArray, Variable
The synaptic values.
post_ids: jax.numpy.ndarray, JaxArray
The post-synaptic neuron ids.
post_num: int
The number of the post-synaptic neurons.

Returns
-------
post_val: jax.numpy.ndarray, JaxArray
The post-synaptic value.
"""
post_ids = as_device_array(post_ids)
syn_values = as_device_array(syn_values)
if syn_values.dtype == jnp.bool_:
syn_values = jnp.asarray(syn_values, dtype=jnp.int32)
return _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted)


syn2post = syn2post_sum


def syn2post_prod(syn_values, post_ids, post_num: int, indices_are_sorted=True):
"""The syn-to-post product computation.

This function is equivalent to:

.. highlight:: python
.. code-block:: python

post_val = np.zeros(post_num)
for syn_i, post_i in enumerate(post_ids):
post_val[post_i] *= syn_values[syn_i]

Parameters
----------
syn_values: jax.numpy.ndarray, JaxArray, Variable
The synaptic values.
post_ids: jax.numpy.ndarray, JaxArray
The post-synaptic neuron ids. If ``post_ids`` is generated by
``brainpy.conn.TwoEndConnector``, then it has sorted indices.
Otherwise, this function cannot guarantee indices are sorted.
You's better set ``indices_are_sorted=False``.
post_num: int
The number of the post-synaptic neurons.
indices_are_sorted: whether ``post_ids`` is known to be sorted.

Returns
-------
post_val: jax.numpy.ndarray, JaxArray
The post-synaptic value.
"""
post_ids = as_device_array(post_ids)
syn_values = as_device_array(syn_values)
if syn_values.dtype == jnp.bool_:
syn_values = jnp.asarray(syn_values, dtype=jnp.int32)
return _jit_seg_prod(syn_values, post_ids, post_num, indices_are_sorted)


def syn2post_max(syn_values, post_ids, post_num: int, indices_are_sorted=True):
"""The syn-to-post maximum computation.

This function is equivalent to:

.. highlight:: python
.. code-block:: python

post_val = np.zeros(post_num)
for syn_i, post_i in enumerate(post_ids):
post_val[post_i] = np.maximum(post_val[post_i], syn_values[syn_i])

Parameters
----------
syn_values: jax.numpy.ndarray, JaxArray, Variable
The synaptic values.
post_ids: jax.numpy.ndarray, JaxArray
The post-synaptic neuron ids. If ``post_ids`` is generated by
``brainpy.conn.TwoEndConnector``, then it has sorted indices.
Otherwise, this function cannot guarantee indices are sorted.
You's better set ``indices_are_sorted=False``.
post_num: int
The number of the post-synaptic neurons.
indices_are_sorted: whether ``post_ids`` is known to be sorted.

Returns
-------
post_val: jax.numpy.ndarray, JaxArray
The post-synaptic value.
"""
post_ids = as_device_array(post_ids)
syn_values = as_device_array(syn_values)
if syn_values.dtype == jnp.bool_:
syn_values = jnp.asarray(syn_values, dtype=jnp.int32)
return _jit_seg_max(syn_values, post_ids, post_num, indices_are_sorted)


def syn2post_min(syn_values, post_ids, post_num: int, indices_are_sorted=True):
"""The syn-to-post minimization computation.

This function is equivalent to:

.. highlight:: python
.. code-block:: python

post_val = np.zeros(post_num)
for syn_i, post_i in enumerate(post_ids):
post_val[post_i] = np.minimum(post_val[post_i], syn_values[syn_i])

Parameters
----------
syn_values: jax.numpy.ndarray, JaxArray, Variable
The synaptic values.
post_ids: jax.numpy.ndarray, JaxArray
The post-synaptic neuron ids. If ``post_ids`` is generated by
``brainpy.conn.TwoEndConnector``, then it has sorted indices.
Otherwise, this function cannot guarantee indices are sorted.
You's better set ``indices_are_sorted=False``.
post_num: int
The number of the post-synaptic neurons.
indices_are_sorted: whether ``post_ids`` is known to be sorted.

Returns
-------
post_val: jax.numpy.ndarray, JaxArray
The post-synaptic value.
"""
post_ids = as_device_array(post_ids)
syn_values = as_device_array(syn_values)
if syn_values.dtype == jnp.bool_:
syn_values = jnp.asarray(syn_values, dtype=jnp.int32)
return _jit_seg_min(syn_values, post_ids, post_num, indices_are_sorted)


def syn2post_mean(syn_values, post_ids, post_num: int, indices_are_sorted=True):
"""The syn-to-post mean computation.

Parameters
----------
syn_values: jax.numpy.ndarray, JaxArray, Variable
The synaptic values.
post_ids: jax.numpy.ndarray, JaxArray
The post-synaptic neuron ids. If ``post_ids`` is generated by
``brainpy.conn.TwoEndConnector``, then it has sorted indices.
Otherwise, this function cannot guarantee indices are sorted.
You's better set ``indices_are_sorted=False``.
post_num: int
The number of the post-synaptic neurons.
indices_are_sorted: whether ``post_ids`` is known to be sorted.

Returns
-------
post_val: jax.numpy.ndarray, JaxArray
The post-synaptic value.
"""
post_ids = as_device_array(post_ids)
syn_values = as_device_array(syn_values)
if syn_values.dtype == jnp.bool_:
syn_values = jnp.asarray(syn_values, dtype=jnp.int32)
nominator = _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted)
denominator = _jit_seg_sum(jnp.ones_like(syn_values), post_ids, post_num, indices_are_sorted)
return jnp.nan_to_num(nominator / denominator)


def syn2post_softmax(syn_values, post_ids, post_num: int, indices_are_sorted=True):
"""The syn-to-post softmax computation.

Parameters
----------
syn_values: jax.numpy.ndarray, JaxArray, Variable
The synaptic values.
post_ids: jax.numpy.ndarray, JaxArray
The post-synaptic neuron ids. If ``post_ids`` is generated by
``brainpy.conn.TwoEndConnector``, then it has sorted indices.
Otherwise, this function cannot guarantee indices are sorted.
You's better set ``indices_are_sorted=False``.
post_num: int
The number of the post-synaptic neurons.
indices_are_sorted: whether ``post_ids`` is known to be sorted.

Returns
-------
post_val: jax.numpy.ndarray, JaxArray
The post-synaptic value.
"""
post_ids = as_device_array(post_ids)
syn_values = as_device_array(syn_values)
if syn_values.dtype == jnp.bool_:
syn_values = jnp.asarray(syn_values, dtype=jnp.int32)
syn_maxs = _jit_seg_max(syn_values, post_ids, post_num, indices_are_sorted)
syn_values = syn_values - syn_maxs[post_ids]
syn_values = jnp.exp(syn_values)
normalizers = _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted)
softmax = syn_values / normalizers[post_ids]
return jnp.nan_to_num(softmax)


+ 3
- 45
brainpy/math/operators/tests/test_op_register.py View File

@@ -23,9 +23,7 @@ def event_sum_op(outs, ins):
outs[index] += v


event_sum = bm.register_op(name='event_sum', cpu_func=event_sum_op, eval_shape=abs_eval)
event_sum2 = bm.XLACustomOp(name='event_sum', cpu_func=event_sum_op, eval_shape=abs_eval)
event_sum = bm.jit(event_sum)
event_sum2 = bm.XLACustomOp(name='event_sum2', cpu_func=event_sum_op, eval_shape=abs_eval)


class ExponentialSyn(bp.dyn.TwoEndConn):
@@ -54,36 +52,6 @@ class ExponentialSyn(bp.dyn.TwoEndConn):
self.post.input += self.g * (self.E - self.post.V)


class ExponentialSyn2(bp.dyn.TwoEndConn):
def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.,
method='exp_auto'):
super(ExponentialSyn2, self).__init__(pre=pre, post=post, conn=conn)
self.check_pre_attrs('spike')
self.check_post_attrs('input', 'V')

# parameters
self.E = E
self.tau = tau
self.delay = delay
self.g_max = g_max
self.pre2post = self.conn.require('pre2post')

# variables
self.g = bm.Variable(bm.zeros(self.post.num))

# function
self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method)

def update(self, tdi):
self.g.value = self.integral(self.g, tdi['t'], tdi['dt'])
# Customized operator
# ------------------------------------------------------------------------------------------------------------
post_val = bm.zeros(self.post.num)
self.g += event_sum(self.pre.spike, self.pre2post[0], self.pre2post[1], post_val, self.g_max)
# ------------------------------------------------------------------------------------------------------------
self.post.input += self.g * (self.E - self.post.V)


class ExponentialSyn3(bp.dyn.TwoEndConn):
def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.,
method='exp_auto'):
@@ -138,7 +106,6 @@ class EINet(bp.dyn.Network):
self.I2I = syn_class(self.I, self.I, bp.conn.FixedProb(0.02), E=-80., g_max=wi, tau=10., method=method)



class TestOpRegister(unittest.TestCase):
def test_op(self):

@@ -155,17 +122,6 @@ class TestOpRegister(unittest.TestCase):
ax = fig.add_subplot(gs[0, 0])
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], ax=ax)

net2 = EINet(ExponentialSyn2, scale=1., method='euler')
runner2 = bp.dyn.DSRunner(
net2,
inputs=[(net2.E.input, 20.), (net2.I.input, 20.)],
monitors={'E.spike': net2.E.spike},
)
t, _ = runner2.run(100., eval_time=True)
print(t)
ax = fig.add_subplot(gs[0, 1])
bp.visualize.raster_plot(runner2.mon.ts, runner2.mon['E.spike'], ax=ax)

net3 = EINet(ExponentialSyn3, scale=1., method='euler')
runner3 = bp.dyn.DSRunner(
net3,
@@ -176,4 +132,6 @@ class TestOpRegister(unittest.TestCase):
print(t)
ax = fig.add_subplot(gs[0, 2])
bp.visualize.raster_plot(runner3.mon.ts, runner3.mon['E.spike'], ax=ax, show=True)

bm.clear_buffer_memory()
plt.close()

+ 0
- 28
brainpy/math/operators/utils.py View File

@@ -1,28 +0,0 @@
# -*- coding: utf-8 -*-

from brainpy.errors import PackageMissingError

try:
import brainpylib
except ModuleNotFoundError:
brainpylib = None


_BRAINPYLIB_MINIMAL_VERSION = '0.0.6'


def _check_brainpylib(ops_name):
if brainpylib is not None:
if brainpylib.__version__ < _BRAINPYLIB_MINIMAL_VERSION:
raise PackageMissingError(
f'"{ops_name}" operator need "brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}". \n'
f'Please install it through:\n\n'
f'>>> pip install brainpylib>={_BRAINPYLIB_MINIMAL_VERSION} -U'
)
else:
raise PackageMissingError(
f'"brainpylib" must be installed when the user '
f'wants to use "{ops_name}" operator. \n'
f'Please install "brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}" through:\n\n'
f'>>> pip install brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}'
)

+ 139
- 9
brainpy/math/operators/wrap_jax.py View File

@@ -8,6 +8,7 @@ from jax import lax
from jax import ops as jops

from brainpy.math.jaxarray import JaxArray
from brainpy.math.numpy_ops import as_jax

__all__ = [
'segment_sum',
@@ -24,8 +25,40 @@ def segment_sum(data: Union[JaxArray, jnp.ndarray],
unique_indices: bool = False,
bucket_size: Optional[int] = None,
mode: Optional[lax.GatherScatterMode] = None) -> JaxArray:
return JaxArray(jops.segment_sum(data.value if isinstance(data, JaxArray) else data,
segment_ids.value if isinstance(segment_ids, JaxArray) else segment_ids,
"""``segment_sum`` operator for brainpy `JaxArray` and `Variable`.

Parameters
----------
data: Array
An array with the values to be reduced.
segment_ids: Array
An array with integer dtype that indicates the segments of
`data` (along its leading axis) to be summed. Values can be repeated and
need not be sorted.
num_segments: Optional, int
An int with nonnegative value indicating the number
of segments. The default is set to be the minimum number of segments that
would support all indices in ``segment_ids``, calculated as
``max(segment_ids) + 1``.
Since `num_segments` determines the size of the output, a static value
must be provided to use ``segment_sum`` in a ``jit``-compiled function.
indices_are_sorted: bool
whether ``segment_ids`` is known to be sorted.
unique_indices: bool
whether `segment_ids` is known to be free of duplicates.
bucket_size: int
Size of bucket to group indices into. ``segment_sum`` is
performed on each bucket separately to improve numerical stability of
addition. Default ``None`` means no bucketing.

Returns
-------
output: Array
An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
segment sums.
"""
return JaxArray(jops.segment_sum(as_jax(data),
as_jax(segment_ids),
num_segments,
indices_are_sorted,
unique_indices,
@@ -39,8 +72,40 @@ def segment_prod(data: Union[JaxArray, jnp.ndarray],
unique_indices: bool = False,
bucket_size: Optional[int] = None,
mode: Optional[lax.GatherScatterMode] = None) -> JaxArray:
return JaxArray(jops.segment_prod(data.value if isinstance(data, JaxArray) else data,
segment_ids.value if isinstance(segment_ids, JaxArray) else segment_ids,
"""``segment_prod`` operator for brainpy `JaxArray` and `Variable`.

Parameters
----------
data: Array
An array with the values to be reduced.
segment_ids: Array
An array with integer dtype that indicates the segments of
`data` (along its leading axis) to be multiplied. Values can be repeated and
need not be sorted.
num_segments: Optional, int
An int with nonnegative value indicating the number
of segments. The default is set to be the minimum number of segments that
would support all indices in ``segment_ids``, calculated as
``max(segment_ids) + 1``.
Since `num_segments` determines the size of the output, a static value
must be provided to use ``segment_prod`` in a ``jit``-compiled function.
indices_are_sorted: bool
whether ``segment_ids`` is known to be sorted.
unique_indices: bool
whether `segment_ids` is known to be free of duplicates.
bucket_size: int
Size of bucket to group indices into. ``segment_prod`` is
performed on each bucket separately to improve numerical stability.
Default ``None`` means no bucketing.

Returns
-------
output: Array
An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
segment products.
"""
return JaxArray(jops.segment_prod(as_jax(data),
as_jax(segment_ids),
num_segments,
indices_are_sorted,
unique_indices,
@@ -54,8 +119,40 @@ def segment_max(data: Union[JaxArray, jnp.ndarray],
unique_indices: bool = False,
bucket_size: Optional[int] = None,
mode: Optional[lax.GatherScatterMode] = None) -> JaxArray:
return JaxArray(jops.segment_max(data.value if isinstance(data, JaxArray) else data,
segment_ids.value if isinstance(segment_ids, JaxArray) else segment_ids,
"""``segment_max`` operator for brainpy `JaxArray` and `Variable`.

Parameters
----------
data: Array
An array with the values to be reduced.
segment_ids: Array
An array with integer dtype that indicates the segments of
`data` (along its leading axis) over which the maximum is taken. Values can be repeated and
need not be sorted.
num_segments: Optional, int
An int with nonnegative value indicating the number
of segments. The default is set to be the minimum number of segments that
would support all indices in ``segment_ids``, calculated as
``max(segment_ids) + 1``.
Since `num_segments` determines the size of the output, a static value
must be provided to use ``segment_max`` in a ``jit``-compiled function.
indices_are_sorted: bool
whether ``segment_ids`` is known to be sorted.
unique_indices: bool
whether `segment_ids` is known to be free of duplicates.
bucket_size: int
Size of bucket to group indices into. ``segment_max`` is
performed on each bucket separately.
Default ``None`` means no bucketing.

Returns
-------
output: Array
An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
segment maxima.
"""
return JaxArray(jops.segment_max(as_jax(data),
as_jax(segment_ids),
num_segments,
indices_are_sorted,
unique_indices,
@@ -69,9 +166,42 @@ def segment_min(data: Union[JaxArray, jnp.ndarray],
unique_indices: bool = False,
bucket_size: Optional[int] = None,
mode: Optional[lax.GatherScatterMode] = None) -> JaxArray:
return JaxArray(jops.segment_min(data.value if isinstance(data, JaxArray) else data,
segment_ids.value if isinstance(segment_ids, JaxArray) else segment_ids,
"""``segment_min`` operator for brainpy `JaxArray` and `Variable`.

Parameters
----------
data: Array
An array with the values to be reduced.
segment_ids: Array
An array with integer dtype that indicates the segments of
`data` (along its leading axis) over which to take the minimum. Values can be repeated and
need not be sorted.
num_segments: Optional, int
An int with nonnegative value indicating the number
of segments. The default is set to be the minimum number of segments that
would support all indices in ``segment_ids``, calculated as
``max(segment_ids) + 1``.
Since `num_segments` determines the size of the output, a static value
must be provided to use ``segment_min`` in a ``jit``-compiled function.
indices_are_sorted: bool
whether ``segment_ids`` is known to be sorted.
unique_indices: bool
whether `segment_ids` is known to be free of duplicates.
bucket_size: int
Size of bucket to group indices into. ``segment_min`` is
performed on each bucket separately to improve numerical stability.
Default ``None`` means no bucketing.

Returns
-------
output: Array
An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
segment minima.
"""
return JaxArray(jops.segment_min(as_jax(data),
as_jax(segment_ids),
num_segments,
indices_are_sorted,
unique_indices,
bucket_size, mode))
bucket_size,
mode))
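A minimal usage sketch of these segment reductions, assuming the wrappers above are re-exported under ``brainpy.math`` (the expected values in the comments are plain arithmetic on the toy inputs):

>>> import brainpy.math as bm
>>> data = bm.arange(5.)                         # [0., 1., 2., 3., 4.]
>>> ids = bm.asarray([0, 0, 1, 1, 2])            # segment id of each entry
>>> bm.segment_sum(data, ids, num_segments=3)    # -> [1., 5., 4.]
>>> bm.segment_max(data, ids, num_segments=3)    # -> [1., 3., 4.]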

+ 68
- 28
brainpy/math/random.py View File

@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
import warnings
from collections import namedtuple
from functools import partial
from operator import index
@@ -86,9 +86,8 @@ def _get_tr_params(n, p):
m = jnp.floor((n + 1) * p).astype(n.dtype)
log_p = jnp.log(p)
log1_p = jnp.log1p(-p)
log_h = (m + 0.5) * (jnp.log((m + 1.0) / (n - m + 1.0)) + log1_p - log_p) + (
_stirling_approx_tail(m) + _stirling_approx_tail(n - m)
)
log_h = ((m + 0.5) * (jnp.log((m + 1.0) / (n - m + 1.0)) + log1_p - log_p) +
_stirling_approx_tail(m) + _stirling_approx_tail(n - m))
return _tr_params(c, b, a, alpha, u_r, v_r, m, log_p, log1_p, log_h)


@@ -105,10 +104,9 @@ def _stirling_approx_tail(k):
0.008330563433362871, ])
kp1 = k + 1
kp1sq = (k + 1) ** 2
return jnp.where(
k < 10, precomputed[k],
(1.0 / 12 - (1.0 / 360 - (1.0 / 1260) / kp1sq) / kp1sq) / kp1,
)
return jnp.where(k < 10,
precomputed[k],
(1.0 / 12 - (1.0 / 360 - (1.0 / 1260) / kp1sq) / kp1sq) / kp1)


def _binomial_btrs(key, p, n):
@@ -290,9 +288,9 @@ def _multinomial(key, p, n, n_max, shape=()):
excess = 0
# NB: we transpose to move batch shape to the front
indices_2D = (jnp.reshape(indices * mask, (n_max, -1))).T
samples_2D = vmap(_scatter_add_one, (0, 0, 0))(jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype),
jnp.expand_dims(indices_2D, axis=-1),
jnp.ones(indices_2D.shape, dtype=indices.dtype))
samples_2D = vmap(_scatter_add_one)(jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype),
jnp.expand_dims(indices_2D, axis=-1),
jnp.ones(indices_2D.shape, dtype=indices.dtype))
return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess


@@ -387,37 +385,79 @@ class RandomState(Variable):
"""RandomState that track the random generator state. """
__slots__ = ()

def __init__(self, seed=None):
def __init__(self, seed_or_key=None, seed=None):
"""RandomState constructor.

Parameters
----------
seed : int, jax.DeviceArray, Optional
The initial seed of the random number generator.
seed_or_key: int, Array, optional
It can be an integer as the initial seed of the random number generator,
or a JAX ``PRNGKey``, which is an array with two elements and ``uint32`` dtype.

.. versionadded:: 2.2.3.4

seed : int, Array, optional
Same as `seed_or_key`.

.. deprecated:: 2.2.3.4
Will be removed in version 2.4.
"""
if seed is None:
seed = np.random.randint(0, 100000, 2, dtype=np.uint32)
if isinstance(seed, int):
key = jr.PRNGKey(seed)
if seed is not None:
if seed_or_key is not None:
raise ValueError('Please set "seed_or_key" or "seed", not both.')
seed_or_key = seed
warnings.warn('Please use "seed_or_key" instead. '
'"seed" will be removed in version 2.4.0.', UserWarning)

if seed_or_key is None:
seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
if isinstance(seed_or_key, int):
key = jr.PRNGKey(seed_or_key)
else:
assert len(seed) == 2
key = seed
if len(seed_or_key) != 2 or seed_or_key.dtype != np.uint32:
raise ValueError('key must be an array with two elements and dtype uint32. '
f'But we got {seed_or_key}')
key = seed_or_key
super(RandomState, self).__init__(key)

# ------------------- #
# seed and random key #
# ------------------- #

def seed(self, seed):
def seed(self, seed_or_key=None, seed=None):
"""Sets a new random seed.

Parameters
----------
seed : int
The new initial seed of the random number generator.
seed_or_key: int, Array, optional
It can be an integer as the initial seed of the random number generator,
or a JAX ``PRNGKey``, which is an array with two elements and ``uint32`` dtype.

.. versionadded:: 2.2.3.4

seed : int, Array, optional
Same as `seed_or_key`.

.. deprecated:: 2.2.3.4
Will be removed in version 2.4.
"""
if seed is not None:
self.value = jr.PRNGKey(seed)
if seed_or_key is not None:
raise ValueError('Please set "seed_or_key" or "seed", not both.')
seed_or_key = seed
warnings.warn('Please use "seed_or_key" instead. '
'"seed" will be removed in version 2.4.0.', UserWarning)

if seed_or_key is None:
seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
if isinstance(seed_or_key, int):
key = jr.PRNGKey(seed_or_key)
else:
if len(seed_or_key) != 2 or seed_or_key.dtype != np.uint32:
raise ValueError('key must be an array with two elements and dtype uint32. '
f'But we got {seed_or_key}')
key = seed_or_key
self.value = key

def split_key(self):
"""Create a new seed from the current seed.
@@ -509,11 +549,11 @@ class RandomState(Variable):
return JaxArray(jr.choice(key, a=a, shape=_size2shape(size),
replace=replace, p=p))

def permutation(self, x, key=None):
def permutation(self, x, axis: int = 0, independent: bool = False, key=None):
x = x.value if isinstance(x, JaxArray) else x
x = _check_py_seq(x)
key = self.split_key() if key is None else key
return JaxArray(jr.permutation(key, x))
return JaxArray(jr.permutation(key, x, axis=axis, independent=independent))

def shuffle(self, x, axis=0, key=None):
assert isinstance(x, JaxArray), f'Must be a JaxArray, but got {type(x)}'
@@ -1117,8 +1157,8 @@ def choice(a, size=None, replace=True, p=None, key=None):


@wraps(np.random.permutation)
def permutation(x, key=None):
return DEFAULT.permutation(x, key=key)
def permutation(x, axis: int = 0, independent: bool = False, key=None):
return DEFAULT.permutation(x, axis=axis, independent=independent, key=key)
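A short sketch of the new seeding and permutation interfaces, assuming the module is used as ``brainpy.math.random`` and that the key variant mirrors ``jax.random.PRNGKey``:

>>> import jax.random as jr
>>> import brainpy.math as bm
>>> rng = bm.random.RandomState(42)                  # an integer seed
>>> rng = bm.random.RandomState(jr.PRNGKey(42))      # or a uint32 key with two elements
>>> x = bm.arange(6).reshape((2, 3))
>>> bm.random.permutation(x, axis=1)                     # permute columns jointly
>>> bm.random.permutation(x, axis=1, independent=True)   # permute each row independently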


@wraps(np.random.shuffle)


+ 14
- 2
brainpy/math/setting.py View File

@@ -3,13 +3,14 @@
import os
import re

from jax import dtypes, config, numpy as jnp
from jax import dtypes, config, numpy as jnp, devices
from jax.lib import xla_bridge

__all__ = [
'enable_x64',
'disable_x64',
'set_platform',
'get_platform',
'set_host_device_count',

# device memory
@@ -92,7 +93,7 @@ def disable_x64():
config.update("jax_enable_x64", False)


def set_platform(platform):
def set_platform(platform: str):
"""
Changes platform to CPU, GPU, or TPU. This utility only takes
effect at the beginning of your program.
@@ -101,6 +102,17 @@ def set_platform(platform):
config.update("jax_platform_name", platform)


def get_platform() -> str:
"""Get the computing platform.

Returns
-------
platform: str
Either 'cpu', 'gpu' or 'tpu'.
"""
return devices()[0].platform
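A minimal sketch of how the two settings pair up, assuming they are exposed as ``brainpy.math.set_platform`` and ``brainpy.math.get_platform``:

>>> import brainpy.math as bm
>>> bm.set_platform('cpu')    # only takes effect at the very beginning of the program
>>> bm.get_platform()         # -> 'cpu'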


def set_host_device_count(n):
"""
By default, XLA considers all CPU cores as one device. This utility tells XLA


+ 40
- 26
brainpy/math/tests/test_delay_vars.py View File

@@ -5,6 +5,7 @@ import unittest
import jax.numpy as jnp

import brainpy.math as bm
from brainpy.math.delayvars import ROTATION_UPDATING, CONCAT_UPDATING


class TestTimeDelay(unittest.TestCase):
@@ -80,36 +81,49 @@ class TestTimeDelay(unittest.TestCase):
class TestLengthDelay(unittest.TestCase):
def test1(self):
dim = 3
delay = bm.LengthDelay(jnp.zeros(dim), 10)
print(delay(1))
self.assertTrue(jnp.array_equal(delay(1), jnp.zeros(dim)))
for update_method in [ROTATION_UPDATING, CONCAT_UPDATING]:
delay = bm.LengthDelay(jnp.zeros(dim), 10, update_method=update_method)
print(delay(1))
self.assertTrue(jnp.array_equal(delay(1), jnp.zeros(dim)))

delay = bm.jit(delay)
print(delay(1))
self.assertTrue(jnp.array_equal(delay(1), jnp.zeros(dim)))
delay = bm.jit(delay)
print(delay(1))
self.assertTrue(jnp.array_equal(delay(1), jnp.zeros(dim)))

def test2(self):
dim = 3
delay = bm.LengthDelay(jnp.zeros(dim), 10, initial_delay_data=jnp.arange(1, 11).reshape((10, 1)))
print(delay(0))
self.assertTrue(jnp.array_equal(delay(0), jnp.zeros(dim)))
print(delay(1))
self.assertTrue(jnp.array_equal(delay(1), jnp.ones(dim) * 10))

delay = bm.jit(delay)
print(delay(0))
self.assertTrue(jnp.array_equal(delay(0), jnp.zeros(dim)))
print(delay(1))
self.assertTrue(jnp.array_equal(delay(1), jnp.ones(dim) * 10))
for update_method in [ROTATION_UPDATING, CONCAT_UPDATING]:
delay = bm.LengthDelay(jnp.zeros(dim), 10,
# initial_delay_data=jnp.arange(1, 11).reshape((10, 1)),
initial_delay_data=jnp.arange(10, 0, -1).reshape((10, 1)),
update_method=update_method)
print(delay(0))
self.assertTrue(jnp.array_equal(delay(0), jnp.zeros(dim)))
print(delay(1))
self.assertTrue(jnp.array_equal(delay(1), jnp.ones(dim) * 10))

delay = bm.jit(delay)
print(delay(0))
self.assertTrue(jnp.array_equal(delay(0), jnp.zeros(dim)))
print(delay(1))
self.assertTrue(jnp.array_equal(delay(1), jnp.ones(dim) * 10))

def test3(self):
dim = 3
delay = bm.LengthDelay(jnp.zeros(dim), 10, initial_delay_data=jnp.arange(1, 11).reshape((10, 1)))
print(delay(jnp.asarray([1, 2, 3]),
jnp.arange(3)))
# self.assertTrue(bm.array_equal(delay(0), bm.zeros(dim)))

delay = bm.jit(delay)
print(delay(jnp.asarray([1, 2, 3]),
jnp.arange(3)))
# self.assertTrue(bm.array_equal(delay(1), bm.ones(dim) * 10))
for update_method in [ROTATION_UPDATING, CONCAT_UPDATING]:
delay = bm.LengthDelay(jnp.zeros(dim), 10,
# initial_delay_data=jnp.arange(1, 11).reshape((10, 1)),
initial_delay_data=jnp.arange(10, 0, -1).reshape((10, 1)),
update_method=update_method)
print(delay(jnp.asarray([1, 2, 3]),
jnp.arange(3)))
self.assertTrue(bm.array_equal(delay(jnp.asarray([1, 2, 3]), jnp.arange(3)),
bm.asarray([10., 9., 8.])))

delay = bm.jit(delay)
print(delay(jnp.asarray([1, 2, 3]),
jnp.arange(3)))
self.assertTrue(bm.array_equal(delay(jnp.asarray([1, 2, 3]), jnp.arange(3)),
bm.asarray([10., 9., 8.])))



+ 51
- 0
brainpy/math/tests/test_jaxarray.py View File

@@ -4,6 +4,7 @@
import unittest

import jax.numpy as jnp
import numpy as np
from jax.tree_util import tree_flatten, tree_unflatten

import brainpy.math as bm
@@ -39,4 +40,54 @@ class TestJaxArray(unittest.TestCase):
with self.assertRaises(TypeError):
ee = a + e

def test_operation_with_numpy_array(self):
rng = bm.random.RandomState(123)
add = lambda: rng.rand(10) + np.zeros(1)
self.assertTrue(isinstance(add(), bm.JaxArray))
self.assertTrue(isinstance(bm.jit(add, dyn_vars=rng)(), bm.JaxArray))


class TestVariable(unittest.TestCase):
def test_variable_init(self):
self.assertTrue(
bm.array_equal(bm.Variable(bm.zeros(10)),
bm.Variable(10))
)
bm.random.seed(123)
self.assertTrue(
not bm.array_equal(bm.Variable(bm.random.rand(10)),
bm.Variable(10))
)


class TestVariableView(unittest.TestCase):
def test_update(self):
origin = bm.Variable(bm.zeros(10))
view = bm.VariableView(origin, slice(0, 5, None))

view.update(bm.ones(5))
self.assertTrue(
bm.array_equal(origin, bm.concatenate([bm.ones(5), bm.zeros(5)]))
)

view.value = bm.arange(5.)
self.assertTrue(
bm.array_equal(origin, bm.concatenate([bm.arange(5), bm.zeros(5)]))
)

view += 10
self.assertTrue(
bm.array_equal(origin, bm.concatenate([bm.arange(5) + 10, bm.zeros(5)]))
)

bm.random.shuffle(view)
print(view)
print(origin)

view.sort()
self.assertTrue(
bm.array_equal(origin, bm.concatenate([bm.arange(5) + 10, bm.zeros(5)]))
)

self.assertTrue(view.sum() == bm.sum(bm.arange(5) + 10))


+ 5
- 0
brainpy/math/tests/test_numpy_einsum.py View File

@@ -14,6 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import pytest
pytest.skip("No need to test.", allow_module_level=True)


import itertools
from collections import defaultdict
from functools import partial


+ 5
- 1
brainpy/math/tests/test_numpy_indexing.py View File

@@ -15,6 +15,9 @@
# limitations under the License.


import pytest
pytest.skip("No need to test.", allow_module_level=True)

import enum
import itertools
import typing
@@ -403,7 +406,7 @@ MIXED_ADVANCED_INDEXING_TESTS = MIXED_ADVANCED_INDEXING_TESTS_NO_REPEATS + [

MODES = ["clip", "drop", "promise_in_bounds"]

@pytest.mark.skipif(True, reason="No longer need to test.")
class IndexingTest(jtu.JaxTestCase):
"""Tests for Numpy indexing translation rules."""

@@ -1013,6 +1016,7 @@ def _update_tol(op):
return tol


@pytest.mark.skipif(True, reason="No longer need to test.")
@jtu.with_config(jax_numpy_dtype_promotion='standard')
class IndexedUpdateTest(jtu.JaxTestCase):



+ 6
- 1
brainpy/math/tests/test_numpy_ops.py View File

@@ -13,6 +13,9 @@
# limitations under the License.


import pytest
pytest.skip("No need to test.", allow_module_level=True)

import collections
import functools
from functools import partial
@@ -545,7 +548,7 @@ def bm_func(fun):

return wrapper

@pytest.mark.skipif(True, reason="No longer need to test.")
@jtu.with_config(jax_numpy_dtype_promotion='standard')
class LaxBackedNumpyTests(jtu.JaxTestCase):
"""Tests for LAX-backed Numpy implementation."""
@@ -5991,6 +5994,7 @@ GRAD_SPECIAL_VALUE_TEST_RECORDS = [
GradSpecialValuesTestSpec(bm.sinc, [0.], 1),
]

@pytest.mark.skipif(True, reason="No longer need to test.")
@jtu.with_config(jax_numpy_dtype_promotion='standard')
class NumpyGradTests(jtu.JaxTestCase):
@parameterized.named_parameters(itertools.chain.from_iterable(
@@ -6095,6 +6099,7 @@ def _dtypes_for_ufunc(name: str) -> Iterator[Tuple[str, ...]]:
else:
yield arg_dtypes

@pytest.mark.skipif(True, reason="No longer need to test.")
@jtu.with_config(jax_numpy_dtype_promotion='standard')
class NumpyUfuncTests(jtu.JaxTestCase):
@parameterized.named_parameters(


+ 56
- 0
brainpy/math/tests/test_transformation_context.py View File

@@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-


import unittest

import brainpy as bp
import brainpy.math as bm


class TestJIT(unittest.TestCase):
def test1(self):
@bm.jit
def f1(a):
a[:] = 1.
return a

a = bm.zeros(10)
with self.assertRaises(bp.errors.MathError):
print(f1(a))

def test2(self):
@bm.jit
def f1(a):
b = a + 1

@bm.jit
def f2(x):
x.value = 1.
return x

return f2(b)

with self.assertRaises(bp.errors.MathError):
print(f1(bm.ones(2)))

def test3(self):
@bm.jit
def f1(a):
return a + 1

@bm.jit
def f2(b):
b[:] = 1.
return b

with self.assertRaises(bp.errors.MathError):
print(f2(f1(bm.ones(2))))

def test4(self):
@bm.jit
def f2(a):
b = bm.ones(1)
b += 10
return a + b

print(f2(bm.ones(1)))

+ 4
- 0
brainpy/measure/__init__.py View File

@@ -5,6 +5,10 @@ This module aims to provide commonly used analysis methods for simulated neurona
You can access them through ``brainpy.measure.XXX``.
"""

from . import correlation, firings, lfp

from .correlation import *
from .firings import *
from .lfp import *



+ 124
- 73
brainpy/measure/correlation.py View File

@@ -1,11 +1,11 @@
# -*- coding: utf-8 -*-

from functools import partial

import numpy as np
from jax import vmap, jit, lax, numpy as jnp
import numpy as onp
from jax import vmap, lax, numpy as jnp

from brainpy import math as bm
from brainpy.errors import UnsupportedError

__all__ = [
'cross_correlation',
@@ -17,17 +17,7 @@ __all__ = [
]


@jit
@partial(vmap, in_axes=(None, 0, 0))
def _cc(states, i, j):
sqrt_ij = jnp.sqrt(jnp.sum(states[i]) * jnp.sum(states[j]))
return lax.cond(sqrt_ij == 0.,
lambda _: 0.,
lambda _: jnp.sum(states[i] * states[j]) / sqrt_ij,
None)


def cross_correlation(spikes, bin, dt=None):
def cross_correlation(spikes, bin, dt=None, numpy=True, method='loop'):
r"""Calculate cross correlation index between neurons.

The coherence [1]_ between two neurons i and j is measured by their
@@ -47,14 +37,27 @@ def cross_correlation(spikes, bin, dt=None):
average of :math:`\kappa_{i j}(\tau)` over many pairs of neurons in the
network.

.. note::
To JIT compile this function, users should make ``bin``, ``dt``, ``numpy`` static.
For example, ``partial(brainpy.measure.cross_correlation, bin=10, numpy=False)``.

Parameters
----------
spikes :
spikes : ndarray
The history of spike states of the neuron group.
bin : float, int
The time bin to normalize spike states.
dt : float, optional
The time precision.
numpy: bool
Whether to return the result as a NumPy array.
If ``False``, the result is a JAX array and this function can be JIT compiled.
method: str
The method used to compute the cross correlation over all neuron pairs.
Two methods are supported: `loop` and `vmap`.
The `vmap` method needs much more memory.

.. versionadded:: 2.2.3.4

Returns
-------
@@ -67,27 +70,44 @@ def cross_correlation(spikes, bin, dt=None):
inhibition in a hippocampal interneuronal network model." Journal of
neuroscience 16.20 (1996): 6402-6413.
"""
spikes = bm.as_device_array(spikes)
spikes = bm.as_numpy(spikes) if numpy else bm.as_device_array(spikes)
np = onp if numpy else jnp
dt = bm.get_dt() if dt is None else dt
bin_size = int(bin / dt)
num_hist, num_neu = spikes.shape
num_bin = int(np.ceil(num_hist / bin_size))
num_bin = int(onp.ceil(num_hist / bin_size))
if num_bin * bin_size != num_hist:
spikes = jnp.append(spikes, jnp.zeros((num_bin * bin_size - num_hist, num_neu)), axis=0)
spikes = np.append(spikes, np.zeros((num_bin * bin_size - num_hist, num_neu)), axis=0)
states = spikes.T.reshape((num_neu, num_bin, bin_size))
states = jnp.asarray(jnp.sum(states, axis=2) > 0., dtype=jnp.float_)
states = jnp.asarray(np.sum(states, axis=2) > 0., dtype=jnp.float_)
indices = jnp.tril_indices(num_neu, k=-1)
return jnp.mean(_cc(states, *indices))

if method == 'loop':
def _f(i, j):
sqrt_ij = jnp.sqrt(jnp.sum(states[i]) * jnp.sum(states[j]))
return lax.cond(sqrt_ij == 0.,
lambda _: 0.,
lambda _: jnp.sum(states[i] * states[j]) / sqrt_ij,
None)
res = bm.for_loop(_f, dyn_vars=[], operands=indices)

elif method == 'vmap':
@vmap
def _cc(i, j):
sqrt_ij = jnp.sqrt(jnp.sum(states[i]) * jnp.sum(states[j]))
return lax.cond(sqrt_ij == 0.,
lambda _: 0.,
lambda _: jnp.sum(states[i] * states[j]) / sqrt_ij,
None)

@partial(vmap, in_axes=(None, 0))
def _var(neu_signal, i):
neu_signal = neu_signal[:, i]
return jnp.mean(neu_signal * neu_signal) - jnp.mean(neu_signal) ** 2
res = _cc(*indices)
else:
raise UnsupportedError(f'Do not support {method}. We only support "loop" or "vmap".')

return np.mean(np.asarray(res))
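Following the note in the docstring above, a hedged sketch of how the ``numpy=False`` path can be JIT compiled with static ``bin`` and ``dt`` (the spike matrix here is random toy data):

>>> from functools import partial
>>> from jax import jit
>>> import brainpy as bp
>>> spikes = bp.math.random.random((1000, 10)) < 0.2       # (num_time, num_neuron)
>>> cc = bp.measure.cross_correlation(spikes, bin=10., dt=0.1)               # NumPy result
>>> f = jit(partial(bp.measure.cross_correlation, bin=10., dt=0.1, numpy=False))
>>> cc_jit = f(spikes)                                                        # JAX result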

@jit
def voltage_fluctuation(potentials):
def voltage_fluctuation(potentials, numpy=True, method='loop'):
r"""Calculate neuronal synchronization via voltage variance.

The method comes from [1]_ [2]_ [3]_.
@@ -125,8 +145,18 @@ def voltage_fluctuation(potentials):

Parameters
----------
potentials :
The membrane potential matrix of the neuron group.
potentials : ndarray
The membrane potential matrix of the neuron group.
numpy: bool
Whether to return the result as a NumPy array.
If ``False``, the result is a JAX array and this function can be JIT compiled.
method: str
The method used to compute the per-neuron variance.
Two methods are supported: `loop` and `vmap`.
The `vmap` method will consume much more memory.

.. versionadded:: 2.2.3.4


Returns
-------
@@ -136,44 +166,60 @@ def voltage_fluctuation(potentials):
References
----------
.. [1] Golomb, D. and Rinzel J. (1993) Dynamics of globally coupled
inhibitory neurons with heterogeneity. Phys. Rev. reversal_potential 48:4810-4814.
inhibitory neurons with heterogeneity. Phys. Rev. E 48:4810-4814.
.. [2] Golomb D. and Rinzel J. (1994) Clustering in globally coupled
inhibitory neurons. Physica D 72:259-282.
.. [3] David Golomb (2007) Neuronal synchrony measures. Scholarpedia, 2(1):1347.
"""

potentials = bm.as_device_array(potentials)
num_hist, num_neu = potentials.shape
var_mean = jnp.mean(_var(potentials, jnp.arange(num_neu)))
avg = jnp.mean(potentials, axis=1)
avg_var = jnp.mean(avg * avg) - jnp.mean(avg) ** 2
return lax.cond(var_mean != 0., lambda _: avg_var / var_mean, lambda _: 1., None)

if method == 'loop':
_var = lambda aa: bm.for_loop(lambda signal: jnp.mean(signal * signal) - jnp.mean(signal) ** 2,
dyn_vars=(),
operands=bm.moveaxis(aa, 0, 1).value)

elif method == 'vmap':
_var = vmap(lambda signal: jnp.mean(signal * signal) - jnp.mean(signal) ** 2, in_axes=1)
else:
raise UnsupportedError(f'Do not support {method}. We only support "loop" or "vmap".')

var_mean = jnp.mean(_var(potentials))
r = jnp.where(var_mean == 0., 1., avg_var / var_mean)
return bm.as_numpy(r) if numpy else r

def matrix_correlation(x, y):

def matrix_correlation(x, y, numpy=True):
"""Pearson correlation of the lower triagonal of two matrices.

The triangular matrix is offset by k = 1 in order to ignore the diagonal line

Parameters
----------
x: tensor
x: ndarray
First matrix.
y: tensor
y: ndarray
Second matrix
numpy: bool
Whether to return the result as a NumPy array.
If ``False``, the result is a JAX array and this function can be JIT compiled.

Returns
-------
coef: tensor
coef: ndarray
Correlation coefficient
"""
x = bm.as_numpy(x)
y = bm.as_numpy(y)

x = bm.as_numpy(x) if numpy else bm.as_device_array(x)
y = bm.as_numpy(y) if numpy else bm.as_device_array(y)
np = onp if numpy else jnp
if x.ndim != 2:
raise ValueError(f'Only support 2d tensor, but we got a tensor '
raise ValueError(f'Only support 2d array, but we got an array '
f'with the shape of {x.shape}')
if y.ndim != 2:
raise ValueError(f'Only support 2d tensor, but we got a tensor '
raise ValueError(f'Only support 2d array, but we got an array '
f'with the shape of {y.shape}')
x = x[np.triu_indices_from(x, k=1)]
y = y[np.triu_indices_from(y, k=1)]
@@ -181,34 +227,37 @@ def matrix_correlation(x, y):
return cc


def functional_connectivity(activities):
def functional_connectivity(activities, numpy=True):
"""Functional connectivity matrix of timeseries activities.

Parameters
----------
activities: tensor
The multidimensional tensor with the shape of ``(num_time, num_sample)``.
activities: ndarray
The multidimensional array with the shape of ``(num_time, num_sample)``.
numpy: bool
Whether to return the result as a NumPy array.
If ``False``, the result is a JAX array and this function can be JIT compiled.

Returns
-------
connectivity_matrix: tensor
connectivity_matrix: ndarray
``num_sample x num_sample`` functional connectivity matrix.
"""
activities = bm.as_numpy(activities)
activities = bm.as_numpy(activities) if numpy else bm.as_device_array(activities)
np = onp if numpy else jnp
if activities.ndim != 2:
raise ValueError('Only support 2d tensor with shape of "(num_time, num_sample)". '
f'But we got a tensor with the shape of {activities.shape}')
raise ValueError('Only support 2d array with shape of "(num_time, num_sample)". '
f'But we got an array with the shape of {activities.shape}')
fc = np.corrcoef(activities.T)
return np.nan_to_num(fc)


@jit
def functional_connectivity_dynamics(activities, window_size=30, step_size=5):
"""Computes functional connectivity dynamics (FCD) matrix.

Parameters
----------
activities: tensor
activities: ndarray
The time series with shape of ``(num_time, num_sample)``.
window_size: int
Size of each rolling window in time steps, defaults to 30.
@@ -217,50 +266,52 @@ def functional_connectivity_dynamics(activities, window_size=30, step_size=5):

Returns
-------
fcd_matrix: tensor
fcd_matrix: ndarray
FCD matrix.
"""
pass


def _weighted_mean(x, w):
"""Weighted Mean"""
return jnp.sum(x * w) / jnp.sum(w)


def _weighted_cov(x, y, w):
"""Weighted Covariance"""
return jnp.sum(w * (x - _weighted_mean(x, w)) * (y - _weighted_mean(y, w))) / jnp.sum(w)


@jit
def weighted_correlation(x, y, w):
def weighted_correlation(x, y, w, numpy=True):
"""Weighted Pearson correlation of two data series.

Parameters
----------
x: tensor
x: ndarray
The data series 1.
y: tensor
y: ndarray
The data series 2.
w: tensor
w: ndarray
Weight vector, must have same length as x and y.
numpy: bool
Whether to return the result as a NumPy array.
If ``False``, the result is a JAX array and this function can be JIT compiled.

Returns
-------
corr: tensor
corr: ndarray
Weighted correlation coefficient.
"""
x = bm.as_device_array(x)
y = bm.as_device_array(y)
w = bm.as_device_array(w)
x = bm.as_numpy(x) if numpy else bm.as_device_array(x)
y = bm.as_numpy(y) if numpy else bm.as_device_array(y)
w = bm.as_numpy(w) if numpy else bm.as_device_array(w)
np = onp if numpy else jnp

def _weighted_mean(x, w):
"""Weighted Mean"""
return np.sum(x * w) / np.sum(w)

def _weighted_cov(x, y, w):
"""Weighted Covariance"""
return np.sum(w * (x - _weighted_mean(x, w)) * (y - _weighted_mean(y, w))) / np.sum(w)

if x.ndim != 1:
raise ValueError(f'Only support 1d tensor, but we got a tensor '
raise ValueError(f'Only support 1d array, but we got an array '
f'with the shape of {x.shape}')
if y.ndim != 1:
raise ValueError(f'Only support 1d tensor, but we got a tensor '
raise ValueError(f'Only support 1d array, but we got an array '
f'with the shape of {y.shape}')
if w.ndim != 1:
raise ValueError(f'Only support 1d tensor, but we got a tensor '
raise ValueError(f'Only support 1d array, but we got an array '
f'with the shape of {w.shape}')
return _weighted_cov(x, y, w) / jnp.sqrt(_weighted_cov(x, x, w) * _weighted_cov(y, y, w))
return _weighted_cov(x, y, w) / np.sqrt(_weighted_cov(x, x, w) * _weighted_cov(y, y, w))

+ 13
- 18
brainpy/measure/firings.py View File

@@ -1,8 +1,7 @@
# -*- coding: utf-8 -*-

import numpy as np
import numpy as onp
import jax.numpy as jnp
from jax import jit

from brainpy import math as bm

@@ -29,21 +28,14 @@ def raster_plot(sp_matrix, times):
Include (neuron index, spike time).
"""
sp_matrix = bm.as_numpy(sp_matrix)
times = np.asarray(times)
elements = np.where(sp_matrix > 0.)
times = onp.asarray(times)
elements = onp.where(sp_matrix > 0.)
index = elements[1]
time = times[elements[0]]
return index, time


@jit
def _firing_rate(sp_matrix, window):
sp_matrix = bm.as_device_array(sp_matrix)
rate = jnp.sum(sp_matrix, axis=1) / sp_matrix.shape[1]
return jnp.convolve(rate, window, mode='same')


def firing_rate(sp_matrix, width, dt=None, numpy=True):
def firing_rate(spikes, width, dt=None, numpy=True):
r"""Calculate the mean firing rate over in a neuron group.

This method is adopted from Brian2.
@@ -57,21 +49,24 @@ def firing_rate(sp_matrix, width, dt=None, numpy=True):

Parameters
----------
sp_matrix : math.JaxArray, np.ndarray
spikes : ndarray
The spike matrix that records spiking activities.
width : int, float
The width of the ``window`` in millisecond.
dt : float, optional
The sample rate.
numpy: bool
Whether to return the result as a NumPy array.
If ``False``, the result is a JAX array and this function can be JIT compiled.

Returns
-------
rate : numpy.ndarray
rate : ndarray
The population rate in Hz, smoothed with the given window.
"""
spikes = bm.as_numpy(spikes) if numpy else bm.as_device_array(spikes)
np = onp if numpy else jnp
dt = bm.get_dt() if (dt is None) else dt
width1 = int(width / 2 / dt) * 2 + 1
window = jnp.ones(width1) * 1000 / width
fr = _firing_rate(sp_matrix, window)
return bm.as_numpy(fr) if numpy else fr

window = np.ones(width1) * 1000 / width
return np.convolve(np.mean(spikes, axis=1), window, mode='same')
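A small usage sketch of the rewritten ``firing_rate`` on toy Poisson-like spikes; with ``numpy=False`` the call can additionally be JIT compiled when ``width`` and ``dt`` are held static:

>>> import brainpy as bp
>>> import brainpy.math as bm
>>> spikes = bm.random.random((10000, 100)) < 0.002        # (num_time, num_neuron)
>>> rate = bp.measure.firing_rate(spikes, width=10., dt=0.1)   # smoothed population rate in Hz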

+ 114
- 0
brainpy/measure/lfp.py View File

@@ -0,0 +1,114 @@
# -*- coding: utf-8 -*-


from jax import numpy as jnp

import brainpy.math as bm

__all__ = [
'unitary_LFP',
]


def unitary_LFP(times, spikes, spike_type='exc',
xmax=0.2, ymax=0.2, va=200., lambda_=0.2,
sig_i=2.1, sig_e=2.1 * 1.5, location='soma layer', seed=None):
"""A kernel-based method to calculate unitary local field potentials (uLFP)
from a network of spiking neurons [1]_.

.. note::
This method calculates LFP only from the neuronal spikes. It does not consider
the subthreshold synaptic events, or the dendritic voltage-dependent ion channels.

Examples
--------

If you have spike data of excitatory and inhibitory neurons, you can compute the LFP
as follows:

>>> import brainpy as bp
>>> n_time = 1000
>>> n_exc = 100
>>> n_inh = 25
>>> times = bp.math.arange(n_time) * 0.1
>>> exc_sps = bp.math.random.random((n_time, n_exc)) < 0.3
>>> inh_sps = bp.math.random.random((n_time, n_inh)) < 0.4
>>> lfp = bp.measure.unitary_LFP(times, exc_sps, 'exc')
>>> lfp += bp.measure.unitary_LFP(times, inh_sps, 'inh')

Parameters
----------
times: ndarray
The times of the recording points.
spikes: ndarray
The spikes of excitatory neurons recorded by brainpy monitors.
spike_type: str
The neuron type of the spike trains. It can be "exc" or "inh".
location: str
The location of the spikes recorded. It can be "soma layer", "deep layer",
"superficial layer" and "surface".
xmax: float
Horizontal size of the 2D grid in which the cells are placed (in mm).
ymax: float
Vertical size of the 2D grid in which the cells are placed (in mm).
va: int, float
The axon velocity (mm/sec).
lambda_: float
The space constant (mm).
sig_i: float
The std-dev of the inhibitory uLFP kernel (in ms).
sig_e: float
The std-dev of the excitatory uLFP kernel (in ms).
seed: int
The random seed.

References
----------
.. [1] Telenczuk, Bartosz, Maria Telenczuk, and Alain Destexhe. "A kernel-based
method to calculate local field potentials from networks of spiking
neurons." Journal of Neuroscience Methods 344 (2020): 108871.

"""
times = bm.as_device_array(times)
spikes = bm.as_device_array(spikes)
if spike_type not in ['exc', 'inh']:
raise ValueError('"spike_type" should be "exc or ""inh". ')
if spikes.ndim != 2:
raise ValueError('"E_spikes" should be a matrix with shape of (num_time, num_neuron). '
f'But we got {spikes.shape}')
if times.shape[0] != spikes.shape[0]:
raise ValueError('times and spikes should be consistent along the first axis. '
f'But we got {times.shape[0]} != {spikes.shape[0]}.')

# Distributing cells in a 2D grid
rng = bm.random.RandomState(seed)
num_neuron = spikes.shape[1]
pos_xs, pos_ys = rng.rand(2, num_neuron).value * jnp.array([[xmax], [ymax]])
pos_xs, pos_ys = jnp.asarray(pos_xs), jnp.asarray(pos_ys)

# distance/coordinates
xe, ye = xmax / 2, ymax / 2 # coordinates of electrode
dist = jnp.sqrt((pos_xs - xe) ** 2 + (pos_ys - ye) ** 2) # distance to electrode in mm

# amplitude
if location == 'soma layer':
amp_e, amp_i = 0.48, 3. # exc/inh uLFP amplitude (soma layer)
elif location == 'deep layer':
amp_e, amp_i = -0.16, -0.2 # exc/inh uLFP amplitude (deep layer)
elif location == 'superficial layer':
amp_e, amp_i = 0.24, -1.2 # exc/inh uLFP amplitude (superficial layer)
elif location == 'surface layer':
amp_e, amp_i = -0.08, 0.3 # exc/inh uLFP amplitude (surface)
else:
raise NotImplementedError
A = bm.exp(-dist / lambda_) * (amp_e if spike_type == 'exc' else amp_i)

# delay
delay = 10.4 + dist / va # delay to peak (in ms)

# LFP Calculation
iis, ids = jnp.where(spikes)
tts = times[iis] + delay[ids]
exc_amp = A[ids]
tau = (2 * sig_e * sig_e) if spike_type == 'exc' else (2 * sig_i * sig_i)
return bm.for_loop(lambda t: bm.sum(exc_amp * bm.exp(-(t - tts) ** 2 / tau)), [], times)

+ 33
- 7
brainpy/measure/tests/test_correlation.py View File

@@ -3,13 +3,19 @@

import unittest
import brainpy as bp
import brainpy.math as bm
from jax import jit
from functools import partial


class TestCrossCorrelation(unittest.TestCase):
def test_c(self):
spikes = bp.math.asarray([[1, 0, 1, 0, 1, 0, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 0, 0]]).T
cc1 = bp.measure.cross_correlation(spikes, 1., dt=1.)
print(cc1)
f_cc = jit(partial(bp.measure.cross_correlation, numpy=False, bin=1, dt=1.))
cc2 = f_cc(spikes)
print(cc1, cc2)
self.assertTrue(cc1 == cc2)

def test_cc(self):
spikes = bp.math.ones((1000, 10))
@@ -47,19 +53,33 @@ class TestCrossCorrelation(unittest.TestCase):

class TestVoltageFluctuation(unittest.TestCase):
def test_vf1(self):
bp.math.random.seed()
voltages = bp.math.random.normal(0, 10, size=(1000, 100))
rng = bp.math.random.RandomState(122)
voltages = rng.normal(0, 10, size=(1000, 100)).value
print(bp.measure.voltage_fluctuation(voltages))

voltages = bp.math.ones((1000, 100))
print(bp.measure.voltage_fluctuation(voltages))
bm.enable_x64()
voltages = bp.math.ones((1000, 100)).value
r1 = bp.measure.voltage_fluctuation(voltages)

jit_f = jit(partial(bp.measure.voltage_fluctuation, numpy=False))
jit_f = jit(lambda a: bp.measure.voltage_fluctuation(a, numpy=False))
r2 = jit_f(voltages)
print(r1, r2) # TODO: JIT results are different?
# self.assertTrue(r1 == r2)

bm.disable_x64()


class TestFunctionalConnectivity(unittest.TestCase):
def test_cf1(self):
bp.math.random.seed()
act = bp.math.random.random((10000, 3))
print(bp.measure.functional_connectivity(act))
r1 = bp.measure.functional_connectivity(act)

jit_f = jit(partial(bp.measure.functional_connectivity, numpy=False))
r2 = jit_f(act)

self.assertTrue(bm.allclose(r1, r2))


class TestMatrixCorrelation(unittest.TestCase):
@@ -67,5 +87,11 @@ class TestMatrixCorrelation(unittest.TestCase):
bp.math.random.seed()
A = bp.math.random.random((100, 100))
B = bp.math.random.random((100, 100))
print(bp.measure.matrix_correlation(A, B))
r1 = (bp.measure.matrix_correlation(A, B))

jit_f = jit(partial(bp.measure.matrix_correlation, numpy=False))
r2 = jit_f(A, B)

self.assertTrue(bm.allclose(r1, r2))



+ 3
- 3
brainpy/modes.py View File

@@ -13,7 +13,7 @@ __all__ = [
'batching',
'training',

'check',
'check_mode',
]


@@ -42,14 +42,14 @@ batching = BatchingMode()
training = TrainingMode()


def check(mode, supported_modes, name=''):
def check_mode(mode, supported_modes, name=''):
"""Check whether the used mode is in the list of the supported models.

Parameters
----------
mode: Mode
The mode used.
supported_modes: list of type, tuple of type
supported_modes: type, list of type, tuple of type
The list of all types to support.
name: Any
The name.
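A hedged sketch of the renamed helper, using the ``batching``/``training`` instances defined above; the single-type form follows the updated ``supported_modes`` annotation:

>>> from brainpy import modes
>>> # passes: ``training`` is an instance of a supported mode type
>>> modes.check_mode(modes.training, (modes.TrainingMode, modes.BatchingMode), name='MyLayer')
>>> # a single type is also accepted
>>> modes.check_mode(modes.batching, modes.BatchingMode, name='MyLayer')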


+ 2
- 2
brainpy/optimizers/optimizer.py View File

@@ -482,8 +482,8 @@ class LARS(Optimizer):
for k, p in self.vars_to_train.items():
g = grads[k]
m = self.implicit_vars[k + '_m']
p_norm = jnp.linalg.norm(bm.as_device_array(p))
g_norm = jnp.linalg.norm(bm.as_device_array(g))
p_norm = jnp.linalg.norm(bm.as_jax(p))
g_norm = jnp.linalg.norm(bm.as_jax(g))
trust_ratio = self.tc * p_norm / (g_norm + self.weight_decay * p_norm + self.eps)
local_lr = lr * jnp.maximum(jnp.logical_or(p_norm == 0, g_norm == 0), trust_ratio)
m.value = self.momentum * m.value + local_lr * (g + self.weight_decay * p.value)


+ 19
- 2
brainpy/running/__init__.py View File

@@ -2,9 +2,26 @@


"""
This module provides APIs for brain simulations.
This module provides APIs for parallel brain simulations.
"""

from .multiprocess import *
from . import jax_multiprocessing
from . import native_multiprocessing
from . import pathos_multiprocessing
from . import runner
from . import constants


__all__ = (native_multiprocessing.__all__ +
pathos_multiprocessing.__all__ +
jax_multiprocessing.__all__ +
runner.__all__ +
constants.__all__)


from .runner import *
from .jax_multiprocessing import *
from .native_multiprocessing import *
from .pathos_multiprocessing import *
from .constants import *


+ 141
- 0
brainpy/running/jax_multiprocessing.py View File

@@ -0,0 +1,141 @@
# -*- coding: utf-8 -*-

from typing import Sequence, Dict, Union

import numpy as np
from jax import vmap, pmap
from jax.tree_util import tree_unflatten, tree_flatten

import brainpy.math as bm
from brainpy.types import Array

__all__ = [
'jax_vectorize_map',
'jax_parallelize_map',
]


def jax_vectorize_map(
func: callable,
arguments: Union[Dict[str, Array], Sequence[Array]],
num_parallel: int,
clear_buffer: bool = False
):
"""Perform a vectorized map of a function by using ``jax.vmap``.

This function can be used on CPU or GPU backends, but it is best
suited to GPU backends, because ``jax.vmap`` can parallelize the
mapped axis on GPU devices.

Parameters
----------
func: callable, function
The function to be mapped.
arguments: sequence, dict
The function arguments, used to define tasks.
num_parallel: int
The batch size of each vectorized run.
clear_buffer: bool
Whether to clear the buffer memory after running each batch of data.

Returns
-------
results: Any
The running results.
"""
if not isinstance(arguments, (dict, tuple, list)):
raise TypeError(f'"arguments" must be sequence or dict, but we got {type(arguments)}')
elements, tree = tree_flatten(arguments, is_leaf=lambda a: isinstance(a, bm.JaxArray))
if clear_buffer:
elements = [np.asarray(ele) for ele in elements]
num_pars = [len(ele) for ele in elements]
if len(np.unique(num_pars)) != 1:
raise ValueError(f'All elements in parameters should have the same length. '
f'But we got {tree_unflatten(tree, num_pars)}')

res_tree = None
results = None
vmap_func = vmap(func)
for i in range(0, num_pars[0], num_parallel):
run_f = vmap(func) if clear_buffer else vmap_func
if isinstance(arguments, dict):
r = run_f(**tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements]))
else:
r = run_f(*tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements]))
res_values, res_tree = tree_flatten(r, is_leaf=lambda a: isinstance(a, bm.JaxArray))
if results is None:
results = tuple([np.asarray(val) if clear_buffer else val] for val in res_values)
else:
for j, val in enumerate(res_values):
results[j].append(np.asarray(val) if clear_buffer else val)
if clear_buffer:
bm.clear_buffer_memory()
if res_tree is None:
return None
results = ([np.concatenate(res, axis=0) for res in results]
if clear_buffer else
[bm.concatenate(res, axis=0) for res in results])
return tree_unflatten(res_tree, results)
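A toy sketch of ``jax_vectorize_map`` on a pure function of one parameter; the function and the parameter grid are made up for illustration, and a real use would wrap a BrainPy simulation instead:

>>> import jax.numpy as jnp
>>> import brainpy as bp
>>>
>>> def run(freq):
>>>   # a toy computation standing in for one simulation trial
>>>   return jnp.sum(jnp.sin(freq * jnp.arange(100) * 0.1))
>>>
>>> freqs = jnp.linspace(1., 2., 16)
>>> results = bp.running.jax_vectorize_map(run, [freqs], num_parallel=4)
>>> results.shape   # -> (16,)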


def jax_parallelize_map(
func: callable,
arguments: Union[Dict[str, Array], Sequence[Array]],
num_parallel: int,
clear_buffer: bool = False
):
"""Perform a parallelized map of a function by using ``jax.pmap``.

This function can be used on multi-CPU or multi-GPU backends.
If you are using it on a single CPU, please set the host device count
with ``brainpy.math.set_host_device_count(n)`` beforehand.

Parameters
----------
func: callable, function
The function to be mapped.
arguments: sequence, dict
The function arguments, used to define tasks.
num_parallel: int
The batch size of each parallelized run.
clear_buffer: bool
Whether to clear the buffer memory after running each batch of data.

Returns
-------
results: Any
The running results.
"""
if not isinstance(arguments, (dict, tuple, list)):
raise TypeError(f'"arguments" must be sequence or dict, but we got {type(arguments)}')
elements, tree = tree_flatten(arguments, is_leaf=lambda a: isinstance(a, bm.JaxArray))
if clear_buffer:
elements = [np.asarray(ele) for ele in elements]
num_pars = [len(ele) for ele in elements]
if len(np.unique(num_pars)) != 1:
raise ValueError(f'All elements in parameters should have the same length. '
f'But we got {tree_unflatten(tree, num_pars)}')

res_tree = None
results = None
vmap_func = pmap(func)
for i in range(0, num_pars[0], num_parallel):
run_f = pmap(func) if clear_buffer else vmap_func
if isinstance(arguments, dict):
r = run_f(**tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements]))
else:
r = run_f(*tree_unflatten(tree, [ele[i: i + num_parallel] for ele in elements]))
res_values, res_tree = tree_flatten(r, is_leaf=lambda a: isinstance(a, bm.JaxArray))
if results is None:
results = tuple([np.asarray(val) if clear_buffer else val] for val in res_values)
else:
for j, val in enumerate(res_values):
results[j].append(np.asarray(val) if clear_buffer else val)
if clear_buffer:
bm.clear_buffer_memory()
if res_tree is None:
return None
results = ([np.concatenate(res, axis=0) for res in results]
if clear_buffer else
[bm.concatenate(res, axis=0) for res in results])
return tree_unflatten(res_tree, results)

brainpy/running/multiprocess.py → brainpy/running/native_multiprocessing.py View File

@@ -1,17 +1,17 @@
# -*- coding: utf-8 -*-

from typing import Union, Sequence, Dict
import multiprocessing


__all__ = [
'process_pool',
'process_pool_lock',
'vectorize_map',
'parallelize_map',
]


def process_pool(func, all_params, num_process):
def process_pool(func: callable,
all_params: Union[Sequence, Dict],
num_process: int):
"""Run multiple models in multi-processes.

.. Note::
@@ -21,7 +21,7 @@ def process_pool(func, all_params, num_process):
----------
func : callable
The function to run model.
all_params : a_list, tuple
all_params : list, tuple, dict
The parameters of the function arguments.
The parameters for each process can be a tuple, or a dictionary.
num_process : int
@@ -47,7 +47,9 @@ def process_pool(func, all_params, num_process):
return [r.get() for r in results]


def process_pool_lock(func, all_net_params, nb_process):
def process_pool_lock(func: callable,
all_params: Union[Sequence, Dict],
num_process: int):
"""Run multiple models in multi-processes with lock.

Sometimes, you want to synchronize the processes. For example,
@@ -71,11 +73,11 @@ def process_pool_lock(func, all_net_params, nb_process):

Parameters
----------
func : callable
func: callable
The function to run model.
all_net_params : a_list, tuple
all_params : list, tuple, dict
The parameters of the function arguments.
nb_process : int
num_process : int
The number of the processes.

Returns
@@ -83,12 +85,12 @@ def process_pool_lock(func, all_net_params, nb_process):
results : list
Process results.
"""
print('{} jobs total.'.format(len(all_net_params)))
pool = multiprocessing.Pool(processes=nb_process)
print('{} jobs total.'.format(len(all_params)))
pool = multiprocessing.Pool(processes=num_process)
m = multiprocessing.Manager()
lock = m.Lock()
results = []
for net_params in all_net_params:
for net_params in all_params:
if isinstance(net_params, (list, tuple)):
results.append(pool.apply_async(func, args=tuple(net_params) + (lock,)))
elif isinstance(net_params, dict):
@@ -99,14 +101,3 @@ def process_pool_lock(func, all_net_params, nb_process):
pool.close()
pool.join()
return [r.get() for r in results]


def vectorize_map(func, all_params, num_thread):
pass


def parallelize_map(func, all_params, num_process):
pass




+ 228
- 0
brainpy/running/pathos_multiprocessing.py View File

@@ -0,0 +1,228 @@
# -*- coding: utf-8 -*-


"""The parallel execution of a BrainPy func on multiple CPU cores.

Specifically, these batch running functions include:

- ``cpu_ordered_parallel``: Performs a parallel ordered map.
- ``cpu_unordered_parallel``: Performs a parallel unordered map.
"""

from collections.abc import Sized
from typing import (Any, Callable, Generator, Iterable, List,
Union, Optional, Sequence, Dict)

from tqdm.auto import tqdm

from brainpy.errors import PackageMissingError

try:
from pathos.helpers import cpu_count
from pathos.multiprocessing import ProcessPool
except ModuleNotFoundError:
cpu_count = None
ProcessPool = None

__all__ = [
'cpu_ordered_parallel',
'cpu_unordered_parallel',
]


def _parallel(
ordered: bool,
function: Callable,
arguments: Union[Sequence[Iterable], Dict[str, Iterable]],
num_process: Union[int, float] = None,
num_task: int = None,
**tqdm_kwargs: Any
) -> Generator:
"""Perform a parallel map with a progress bar.

Parameters
----------
ordered: bool
True for an ordered map, false for an unordered map.
function: callable, function
The function to apply to each element of the given Iterables.
arguments: sequence of Iterable, dict
One or more Iterables containing the data to be mapped.
num_process: int, float
Number of processes used for parallel running. If `int`, it is
the number of processes to be used; if `float`, it is the fraction
of available CPU cores to be used.
num_task: int
The total number of tasks in this parallel running.
tqdm_kwargs: Any
The setting for the progress bar.

Returns
-------
results: Iterable
A generator that applies the function to each element of the given Iterables
in parallel, in order, with a progress bar.
"""
if ProcessPool is None or cpu_count is None:
raise PackageMissingError(
'''
Please install "pathos" package first.
>>> pip install pathos
'''
)

# Determine num_process
if num_process is None:
num_process = cpu_count()
elif isinstance(num_process, int):
pass
elif isinstance(num_process, float):
num_process = int(round(num_process * cpu_count()))
else:
raise ValueError('"num_process" must be an int or a float.')

# arguments
if isinstance(arguments, dict):
keys = list(arguments.keys())
arguments = list(arguments.values())
run_f = lambda *args: function(**{key: arg for key, arg in zip(keys, args)})
else:
if not isinstance(arguments, (tuple, list)):
raise TypeError('"arguments" must be a sequence of Iterable or a dict of Iterable. '
f'But we got {type(arguments)}')
run_f = function

# Determine length of tqdm
lengths = [len(iterable) for iterable in arguments if isinstance(iterable, Sized)]
num_task = num_task or (min(lengths) if lengths else None)

# Create parallel generator
pool = ProcessPool(nodes=num_process)
if ordered:
map_func = pool.imap
else:
map_func = pool.uimap

# Choose tqdm variant
for item in tqdm(map_func(run_f, *arguments), total=num_task, **tqdm_kwargs):
yield item

pool.clear()


def cpu_ordered_parallel(
func: Callable,
arguments: Union[Sequence[Iterable], Dict[str, Iterable]],
num_process: Optional[Union[int, float]] = None,
num_task: Optional[int] = None,
**tqdm_kwargs: Any
) -> List[Any]:
"""Performs a parallel ordered map with a progress bar.

Examples
--------

>>> import brainpy as bp
>>> import brainpy.math as bm
>>> import numpy as np
>>>
>>> def simulate(inp):
>>> inp = bm.as_jax(inp)
>>> hh = bp.neurons.HH(1)
>>> runner = bp.DSRunner(hh, inputs=['input', inp],
>>> monitors=['V', 'spike'],
>>> progress_bar=False)
>>> runner.run(100)
>>> bm.clear_buffer_memory() # clear all cached data and functions
>>> return runner.mon.spike.sum()
>>>
>>> if __name__ == '__main__': # This is important!
>>> results = bp.running.cpu_ordered_parallel(simulate, [np.arange(1, 10, 100)], num_process=10)
>>> print(results)

Parameters
----------
func: callable, function
The function to apply to each element of the given Iterables.
arguments: sequence of Iterable, dict
One or more Iterables containing the data to be mapped.
num_process: int, float
Number of processes used for parallel running. If `int`, it is
the number of processes to be used; if `float`, it is the fraction
of available CPU cores to be used.
num_task: int
The total number of tasks in this parallel running.
tqdm_kwargs: Any
The setting for the progress bar.

Returns
-------
results: list
A list of results obtained by applying the function to each of the given tasks.
"""
generator = _parallel(True,
func,
arguments,
num_process=num_process,
num_task=num_task,
**tqdm_kwargs)
return list(generator)


def cpu_unordered_parallel(
func: Callable,
arguments: Union[Sequence[Iterable], Dict[str, Iterable]],
num_process: Optional[Union[int, float]] = None,
num_task: Optional[int] = None,
**tqdm_kwargs: Any
) -> List[Any]:
"""Performs a parallel unordered map with a progress bar.

Examples
--------
>>> import brainpy as bp
>>> import brainpy.math as bm
>>> import numpy as np
>>>
>>> def simulate(inp):
>>> inp = bm.as_jax(inp)
>>> hh = bp.neurons.HH(1)
>>> runner = bp.DSRunner(hh, inputs=['input', inp],
>>> monitors=['V', 'spike'],
>>> progress_bar=False)
>>> runner.run(100)
>>> bm.clear_buffer_memory() # clear all cached data and functions
>>> return runner.mon.spike.sum()
>>>
>>> if __name__ == '__main__': # This is important!
>>> results = bp.running.cpu_unordered_parallel(simulate, [np.arange(1, 10, 100)], num_process=10)
>>> print(results)

Parameters
----------
func: callable, function
The function to apply to each element of the given Iterables.
arguments: sequence of Iterable, dict
One or more Iterables containing the data to be mapped.
num_process: int, float
Number of processes used for parallel running. If `int`, it is
the number of processes to be used; if `float`, it is the fraction
of available CPU cores to be used.
num_task: int
The total number of tasks in this parallel running.
tqdm_kwargs: Any
The setting for the progress bar.

Returns
-------
results: list
A list of results obtained by applying the function to each of the given tasks.
"""
generator = _parallel(False,
func,
arguments,
num_process=num_process,
num_task=num_task,
**tqdm_kwargs)
return list(generator)
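A minimal sketch of the dictionary-style arguments and the fractional ``num_process`` path; it assumes ``pathos`` is installed and uses a made-up two-argument function:

>>> import brainpy as bp
>>>
>>> def run(a, b):
>>>   return a + b
>>>
>>> if __name__ == '__main__':
>>>   results = bp.running.cpu_ordered_parallel(
>>>     run, {'a': [1, 2, 3], 'b': [10, 20, 30]}, num_process=0.5)
>>>   print(results)   # [11, 22, 33]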

+ 3
- 2
brainpy/tools/others/numba_util.py View File

@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-


import numba
import numpy as np
try:
from numba import njit
@@ -11,6 +10,7 @@ except (ImportError, ModuleNotFoundError):
__all__ = [
'numba_jit',
'numba_seed',
'numba_range',
'SUPPORT_NUMBA',
]

@@ -38,3 +38,4 @@ def numba_seed(seed):
_seed(seed)


numba_range = numba.prange if SUPPORT_NUMBA else range
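A hedged sketch of how these helpers are meant to be combined (the import path is taken from the file above; when Numba is missing, ``numba_jit`` is a no-op and ``numba_range`` falls back to ``range``):

>>> import numpy as np
>>> from brainpy.tools.others.numba_util import numba_jit, numba_range
>>>
>>> @numba_jit
>>> def sum_of_squares(x):
>>>   out = 0.
>>>   for i in numba_range(x.shape[0]):
>>>     out += x[i] * x[i]
>>>   return out
>>>
>>> sum_of_squares(np.arange(5.))   # -> 30.0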

+ 22
- 1
brainpy/tools/others/others.py View File

@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-

import collections.abc
import _thread as thread
import threading
from typing import Optional, Tuple, Callable
from typing import Optional, Tuple, Callable, Union, Sequence, TypeVar

import numpy as np
from jax import lax
@@ -10,6 +11,7 @@ from jax.experimental import host_callback
from tqdm.auto import tqdm

__all__ = [
'replicate',
'not_customized',
'to_size',
'size2num',
@@ -18,6 +20,25 @@ __all__ = [
]


T = TypeVar('T')


def replicate(
element: Union[T, Sequence[T]],
num_replicate: int,
name: str,
) -> Tuple[T, ...]:
"""Replicates entry in `element` `num_replicate` if needed."""
if isinstance(element, (str, bytes)) or not isinstance(element, collections.abc.Sequence):
return (element,) * num_replicate
elif len(element) == 1:
return tuple(element * num_replicate)
elif len(element) == num_replicate:
return tuple(element)
else:
raise TypeError(f"{name} must be a scalar or sequence of length 1 or "
f"sequence of length {num_replicate}.")


def not_customized(fun: Callable) -> Callable:
"""Marks the given module method is not implemented.


+ 12
- 5
brainpy/train/back_propagation.py View File

@@ -64,7 +64,7 @@ class BPTrainer(DSTrainer):
**kwargs)

self.shuffle_data = shuffle_data
self.rng = bm.random.RandomState(seed=seed)
self.rng = bm.random.RandomState(seed)

# jit settings
self.jit[c.PREDICT_PHASE] = self.jit.get(c.PREDICT_PHASE, True)
@@ -300,6 +300,7 @@ class BPTT(BPTrainer):
if self.jit[c.LOSS_PHASE] and jit:
dyn_vars = self.target.vars()
dyn_vars.update(self.dyn_vars)
dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
self._f_loss_compiled[shared_args_str] = bm.jit(self._f_loss_compiled[shared_args_str],
dyn_vars=dyn_vars)
return self._f_loss_compiled[shared_args_str]
@@ -311,6 +312,7 @@ class BPTT(BPTrainer):
_f_loss_internal = self.f_loss(shared_args, jit=False)
dyn_vars = self.target.vars()
dyn_vars.update(self.dyn_vars)
dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
tran_vars = dyn_vars.subset(bm.TrainVar)
grad_f = bm.grad(_f_loss_internal,
dyn_vars=dyn_vars.unique(),
@@ -339,6 +341,7 @@ class BPTT(BPTrainer):
dyn_vars = self.target.vars()
dyn_vars.update(self.dyn_vars)
dyn_vars.update(self.optimizer.vars())
dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
self._f_train_compiled[shared_args_str] = bm.jit(train_func, dyn_vars=dyn_vars.unique())
else:
self._f_train_compiled[shared_args_str] = train_func
@@ -453,6 +456,7 @@ class BPFF(BPTT):
if self.jit[c.LOSS_PHASE] and jit:
dyn_vars = self.target.vars()
dyn_vars.update(self.dyn_vars)
dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
self._f_loss_compiled[shared_args_str] = bm.jit(self._f_loss_compiled[shared_args_str],
dyn_vars=dyn_vars)
else:
@@ -480,6 +484,7 @@ class BPFF(BPTT):
if self.jit[c.PREDICT_PHASE] and jit:
dyn_vars = self.target.vars()
dyn_vars.update(self.dyn_vars)
dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
self._f_predict_compiled[shared_args_str] = bm.jit(run_func, dyn_vars=dyn_vars.unique())
else:
self._f_predict_compiled[shared_args_str] = run_func
@@ -505,6 +510,7 @@ class OnlineBPTT(BPTT):
if self.jit[c.LOSS_PHASE] and jit:
dyn_vars = self.target.vars()
dyn_vars.update(self.dyn_vars)
dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
self._f_loss_compiled[shared_args_str] = bm.jit(self._f_loss_compiled[shared_args_str],
dyn_vars=dyn_vars)
else:
@@ -520,7 +526,7 @@ class OnlineBPTT(BPTT):
shared_args_str = serialize_kwargs(shared_args)
if shared_args_str not in self._f_train_compiled:

def train_step(x):
def train_step(*x):
# t, i, input_, target_ = x
res = self.f_grad(shared_args)(*x)
self.optimizer.update(res[0])
@@ -529,8 +535,8 @@ class OnlineBPTT(BPTT):
if self.jit[c.FIT_PHASE]:
dyn_vars = self.target.vars()
dyn_vars.update(self.dyn_vars)
f = bm.make_loop(train_step, dyn_vars=dyn_vars.unique(), has_return=True)
run_func = lambda all_inputs: f(all_inputs)[1]
dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
run_func = lambda all_inputs: bm.for_loop(train_step, dyn_vars.unique(), all_inputs)

else:
def run_func(xs):
@@ -541,7 +547,7 @@ class OnlineBPTT(BPTT):
x = tree_map(lambda x: x[i], inputs, is_leaf=_is_jax_array)
y = tree_map(lambda x: x[i], targets, is_leaf=_is_jax_array)
# step at the i
loss = train_step((times[i], indices[i], x, y))
loss = train_step(times[i], indices[i], x, y)
# append output and monitor
losses.append(loss)
return bm.asarray(losses)
@@ -583,6 +589,7 @@ class OnlineBPTT(BPTT):
if self.jit[c.FIT_PHASE] and jit:
dyn_vars = self.target.vars()
dyn_vars.update(self.dyn_vars)
dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
self._f_predict_compiled[shared_args_str] = bm.jit(run_func, dyn_vars=dyn_vars.unique())
else:
self._f_predict_compiled[shared_args_str] = run_func


+ 1
- 0
brainpy/train/offline.py View File

@@ -231,6 +231,7 @@ class OfflineTrainer(DSTrainer):
if self.jit['fit']:
dyn_vars = self.target.vars()
dyn_vars.update(self.dyn_vars)
dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
train_func = bm.jit(train_func, dyn_vars=dyn_vars.unique())
return train_func



+ 4
- 5
brainpy/train/online.py View File

@@ -234,8 +234,7 @@ class OnlineTrainer(DSTrainer):

monitor_func = self.build_monitors(self._mon_info[0], self._mon_info[1], shared_args)

def _step_func(all_inputs):
t, i, x, ys = all_inputs
def _step_func(t, i, x, ys):
shared = DotDict(t=t, dt=self.dt, i=i)

# input step
@@ -262,8 +261,8 @@ class OnlineTrainer(DSTrainer):
if self.jit['fit']:
dyn_vars = self.target.vars()
dyn_vars.update(self.dyn_vars)
f = bm.make_loop(_step_func, dyn_vars=dyn_vars.unique(), has_return=True)
return lambda all_inputs: f(all_inputs)[1]
dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
return lambda all_inputs: bm.for_loop(_step_func, dyn_vars.unique(), all_inputs)

else:
def run_func(all_inputs):
@@ -273,7 +272,7 @@ class OnlineTrainer(DSTrainer):
for i in range(times.shape[0]):
x = tree_map(lambda x: x[i], xs)
y = tree_map(lambda x: x[i], ys)
output, mon = _step_func((times[i], indices[i], x, y))
output, mon = _step_func(times[i], indices[i], x, y)
outputs.append(output)
for key, value in mon.items():
monitors[key].append(value)


+ 18
- 135
docs/auto_generater.py View File

@@ -491,6 +491,20 @@ def generate_math_docs(path='apis/auto/math/'):
with open(os.path.join(path, 'comparison_table.rst.inc'), 'w') as f:
f.write(codes)

module_and_name = [
('pre_syn_post', '``pre-syn-post`` Transformations',),
('sparse_matmul', 'Sparse Matrix Multiplication',),
('event_matmul', 'Event-based Matrix Multiplication',),
('spikegrad', 'Surrogate Gradients for Spike Operation',),
('op_register', 'Operator Registration',),
('wrap_jax', 'Other Operators',),
]
write_submodules(module_name='brainpy.math.operators',
filename=os.path.join(path, 'operators.rst'),
header='Sparse & Event-based Operators',
submodule_names=[k[0] for k in module_and_name],
section_names=[k[1] for k in module_and_name])

write_module(module_name='brainpy.math.activations',
filename=os.path.join(path, 'activations.rst'),
header='Activation Functions')
@@ -500,9 +514,7 @@ def generate_math_docs(path='apis/auto/math/'):
write_module(module_name='brainpy.math.controls',
filename=os.path.join(path, 'controls.rst'),
header='Control Flows')
write_module(module_name='brainpy.math.operators',
filename=os.path.join(path, 'operators.rst'),
header='Operators')

write_module(module_name='brainpy.math.parallels',
filename=os.path.join(path, 'parallels.rst'),
header='Parallel Compilation')
@@ -553,7 +565,9 @@ def generate_running_docs(path='apis/auto/'):
os.makedirs(path)

module_and_name = [
('multiprocess', 'Parallel Pool'),
('pathos_multiprocessing', 'Parallel Processing 1'),
('native_multiprocessing', 'Parallel Processing 2'),
('jax_multiprocessing', 'Parallel Processing 3'),
('runner', 'Runners')
]
write_submodules(module_name='brainpy.running',
@@ -580,134 +594,3 @@ def generate_tools_docs(path='apis/auto/tools/'):
filename=os.path.join(path, 'errors.rst'),
header='Error Tools')


# ---------- #
# Deprecated #
# ---------- #

def generate_nn_docs(path='apis/auto/nn/'):
if not os.path.exists(path):
os.makedirs(path)

write_module(module_name='brainpy.nn.base',
filename=os.path.join(path, 'base.rst'),
header='Base Classes')
write_module(module_name='brainpy.nn.operations',
filename=os.path.join(path, 'operations.rst'),
header='Node Operations')
write_module(module_name='brainpy.nn.graph_flow',
filename=os.path.join(path, 'graph_flow.rst'),
header='Node Graph Tools')
write_module(module_name='brainpy.nn.datatypes',
filename=os.path.join(path, 'data_types.rst'),
header='Data Types')

module_and_name = [
('rnn_runner', 'Base RNN Runner'),
('rnn_trainer', 'Base RNN Trainer'),
('online_trainer', 'Online RNN Trainer'),
('offline_trainer', 'Offline RNN Trainer'),
('back_propagation', 'Back-propagation Trainer'),
]
write_submodules(module_name='brainpy.nn.runners',
filename=os.path.join(path, 'runners.rst'),
header='Runners and Trainers',
submodule_names=[k[0] for k in module_and_name],
section_names=[k[1] for k in module_and_name])

module_and_name = [
('online', 'Online Training Algorithms'),
('offline', 'Offline Training Algorithms'),
]
write_submodules(module_name='brainpy.nn.algorithms',
filename=os.path.join(path, 'algorithms.rst'),
header='Training Algorithms',
submodule_names=[k[0] for k in module_and_name],
section_names=[k[1] for k in module_and_name])

write_module(module_name='brainpy.nn.nodes.base',
filename=os.path.join(path, 'nodes_base.rst'),
header='Nodes: basic')
write_module(module_name='brainpy.nn.nodes.ANN',
filename=os.path.join(path, 'nodes_ANN.rst'),
header='Nodes: artificial neural network ')
write_module(module_name='brainpy.nn.nodes.RC',
filename=os.path.join(path, 'nodes_RC.rst'),
header='Nodes: reservoir computing')

def generate_compact_docs(path='apis/auto/compat/'):
if not os.path.exists(path):
os.makedirs(path)

write_module(module_name='brainpy.compat.brainobjects',
filename=os.path.join(path, 'brainobjects.rst'),
header='Brain Objects')
write_module(module_name='brainpy.compat.integrators',
filename=os.path.join(path, 'integrators.rst'),
header='Integrators')
write_module(module_name='brainpy.compat.layers',
filename=os.path.join(path, 'layers.rst'),
header='Layers')
write_module(module_name='brainpy.compat.monitor',
filename=os.path.join(path, 'monitor.rst'),
header='Monitor')
write_module(module_name='brainpy.compat.runners',
filename=os.path.join(path, 'runners.rst'),
header='Runners')

write_module(module_name='brainpy.compat.nn.base',
filename=os.path.join(path, 'nn_base.rst'),
header='Base Classes')
write_module(module_name='brainpy.compat.nn.operations',
filename=os.path.join(path, 'nn_operations.rst'),
header='Node Operations')
write_module(module_name='brainpy.compat.nn.graph_flow',
filename=os.path.join(path, 'nn_graph_flow.rst'),
header='Node Graph Tools')
write_module(module_name='brainpy.compat.nn.datatypes',
filename=os.path.join(path, 'nn_data_types.rst'),
header='Data Types')
module_and_name = [
('rnn_runner', 'Base RNN Runner'),
('rnn_trainer', 'Base RNN Trainer'),
('online_trainer', 'Online RNN Trainer'),
('offline_trainer', 'Offline RNN Trainer'),
('back_propagation', 'Back-propagation Trainer'),
]
write_submodules(module_name='brainpy.compat.nn.runners',
filename=os.path.join(path, 'nn_runners.rst'),
header='Runners and Trainers',
submodule_names=[k[0] for k in module_and_name],
section_names=[k[1] for k in module_and_name])
module_and_name = [
('online', 'Online Training Algorithms'),
('offline', 'Offline Training Algorithms'),
]
write_submodules(module_name='brainpy.compat.nn.algorithms',
filename=os.path.join(path, 'nn_algorithms.rst'),
header='Training Algorithms',
submodule_names=[k[0] for k in module_and_name],
section_names=[k[1] for k in module_and_name])
write_module(module_name='brainpy.compat.nn.nodes.base',
filename=os.path.join(path, 'nn_nodes_base.rst'),
header='Nodes: basic')
write_module(module_name='brainpy.compat.nn.nodes.ANN',
filename=os.path.join(path, 'nn_nodes_ANN.rst'),
header='Nodes: artificial neural network ')
write_module(module_name='brainpy.compat.nn.nodes.RC',
filename=os.path.join(path, 'nn_nodes_RC.rst'),
header='Nodes: reservoir computing')


def generate_math_compact_docs(path='apis/auto/math/'):
if not os.path.exists(path):
os.makedirs(path)


write_module(module_name='brainpy.math.compat.optimizers',
filename=os.path.join(path, 'optimizers.rst'),
header='Optimizers')

write_module(module_name='brainpy.math.compat.losses',
filename=os.path.join(path, 'losses.rst'),
header='Losses')

+ 5
- 10
docs/conf.py

@@ -13,11 +13,11 @@

import os
import sys
import shutil

sys.path.insert(0, os.path.abspath('../'))

import brainpy

from docs import auto_generater

auto_generater.generate_base_docs()
@@ -36,19 +36,14 @@ auto_generater.generate_optimizers_docs()
auto_generater.generate_measure_docs()
auto_generater.generate_datasets_docs()
auto_generater.generate_tools_docs()
# auto_generater.generate_nn_docs()
# auto_generater.generate_compact_docs()
# auto_generater.generate_math_compact_docs()


import shutil

changelogs = [
('../changelog.rst', 'apis/auto/changelog-brainpy.rst'),
('../extensions/changelog.rst', 'apis/auto/changelog-brainpylib.rst'),
('../changelog.rst', 'apis/auto/changelog.rst'),
]
for source, dest in changelogs:
if os.path.exists(dest): os.remove(dest)
if os.path.exists(dest):
os.remove(dest)
shutil.copyfile(source, dest)

# -- Project information -----------------------------------------------------
@@ -70,7 +65,7 @@ extensions = [
'sphinx.ext.autosummary',
'sphinx.ext.intersphinx',
'sphinx.ext.mathjax',
'sphinx-mathjax-offline',
# 'sphinx-mathjax-offline',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinx_autodoc_typehints',


+ 9
- 2
docs/index.rst

@@ -88,6 +88,14 @@ The code of BrainPy is open-sourced at GitHub:
tutorial_advanced/interoperation.ipynb



.. toctree::
:maxdepth: 1
:caption: Frequently Asked Questions

tutorial_FAQs/citing_and_publication


.. toctree::
:maxdepth: 1
:caption: API Documentation
@@ -108,8 +116,7 @@ The code of BrainPy is open-sourced at GitHub:
apis/auto/measure.rst
apis/auto/running.rst
apis/tools.rst
apis/auto/changelog-brainpy.rst
apis/auto/changelog-brainpylib.rst
apis/auto/changelog.rst


Indices and tables


Some files were not shown because too many files changed in this diff
