#24 V1.1.5

Merged
BrainPy merged 7 commits from V1.1.5 into master 2 years ago
  1. +12
    -10
      README.md
  2. +2
    -1
      brainpy/__init__.py
  3. +24
    -20
      brainpy/analysis/solver.py
  4. +56
    -56
      brainpy/analysis/symbolic/base.py
  5. +4
    -4
      brainpy/analysis/symbolic/bifurcation.py
  6. +10
    -10
      brainpy/analysis/symbolic/phase_plane.py
  7. +3
    -1
      brainpy/base/collector.py
  8. +9
    -9
      brainpy/base/function.py
  9. +2
    -2
      brainpy/math/numpy/random.py
  10. +145
    -171
      brainpy/simulation/initialize/random_inits.py
  11. +16
    -0
      brainpy/simulation/layers/base.py
  12. +4
    -14
      brainpy/simulation/layers/conv.py
  13. +5
    -15
      brainpy/simulation/layers/dense.py
  14. +40
    -98
      brainpy/simulation/layers/recurrent.py
  15. +14
    -0
      changelog.rst
  16. +33
    -7
      docs/quickstart/installation.rst

+ 12
- 10
README.md View File

@@ -62,9 +62,9 @@ To install the stable release of BrainPy, please use

**Other dependencies**: you want to get the full supports by BrainPy, please install the following packages:

- `JAX >= 0.2.10`, needed for "jax" backend and [many other supports](https://brainpy.readthedocs.io/en/latest/apis/math/special_jax.html)
- `Numba >= 0.52`, needed for JIT compilation on "numpy" backend
- `SymPy >= 1.4`, needed for dynamics "analysis" module and Exponential Euler method
- `JAX >= 0.2.10`, needed for "jax" backend and many other supports ([how to install jax?](https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#jax))
- `Numba >= 0.52`, needed for JIT compilation on "numpy" backend ([how to install numba?](https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#numba))
- `SymPy >= 1.4`, needed for dynamics "analysis" module and Exponential Euler method ([how to install sympy?](https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#sympy))



@@ -126,8 +126,6 @@ See [brainmodels.synapses](https://brainmodels.readthedocs.io/en/latest/apis/syn
- **[Working Memory]** [*(Mi, et. al., 2017)* STP for Working Memory Capacity](https://brainpy-examples.readthedocs.io/en/latest/working_memory/Mi_2017_working_memory_capacity.html)
- **[Working Memory]** [*(Bouchacourt & Buschman, 2019)* Flexible Working Memory Model](https://brainpy-examples.readthedocs.io/en/latest/working_memory/Bouchacourt_2019_Flexible_working_memory.html)
- **[Decision Making]** [*(Wang, 2002)* Decision making spiking model](https://brainpy-examples.readthedocs.io/en/latest/decision_making/Wang_2002_decision_making_spiking.html)
- **[Recurrent Network]** [*(Laje & Buonomano, 2013)* Robust Timing in RNN](https://brainpy-examples.readthedocs.io/en/latest/recurrent_networks/Laje_Buonomano_2013_robust_timing_rnn.html)
- **[Recurrent Network]** [*(Sussillo & Abbott, 2009)* FORCE Learning](https://brainpy-examples.readthedocs.io/en/latest/recurrent_networks/Sussillo_Abbott_2009_FORCE_Learning.html)



@@ -135,16 +133,20 @@ See [brainmodels.synapses](https://brainmodels.readthedocs.io/en/latest/apis/syn

- [Train Integrator RNN with BP](https://brainpy-examples.readthedocs.io/en/latest/recurrent_networks/integrator_rnn.html)

- [FORCE Learning](https://brainpy-examples.readthedocs.io/en/latest/recurrent_networks/Sussillo_Abbott_2009_FORCE_Learning.html)
- [*(Sussillo & Abbott, 2009)* FORCE Learning](https://brainpy-examples.readthedocs.io/en/latest/recurrent_networks/Sussillo_Abbott_2009_FORCE_Learning.html)

- [*(Laje & Buonomano, 2013)* Robust Timing in RNN](https://brainpy-examples.readthedocs.io/en/latest/recurrent_networks/Laje_Buonomano_2013_robust_timing_rnn.html)
- [*(Song, et al., 2016)*: Training excitatory-inhibitory recurrent network](https://brainpy-examples.readthedocs.io/en/latest/recurrent_networks/Song_2016_EI_RNN.html)
- **[Working Memory]** [*(Masse, et al., 2019)*: RNN with STP for Working Memory](https://brainpy-examples.readthedocs.io/en/latest/recurrent_networks/Masse_2019_STP_RNN.html)




### Low-dimension dynamics analysis

- [Phase plane analysis of the I<sub>Na,p</sub>-I<sub>K</sub> model](https://brainmodels.readthedocs.io/en/latest/tutorials/dynamics_analysis/NaK_model_analysis.html)
- [Codimension 1 bifurcation analysis of FitzHugh Nagumo model](https://brainmodels.readthedocs.io/en/latest/tutorials/dynamics_analysis/FitzHugh_Nagumo_analysis.html)
- [Codimension 2 bifurcation analysis of FitzHugh Nagumo model](https://brainmodels.readthedocs.io/en/latest/tutorials/dynamics_analysis/FitzHugh_Nagumo_analysis.html#Codimension-2-bifurcation-analysis)
- [1D system bifurcation](https://brainmodels.readthedocs.io/en/latest/low_dim_analysis/1D_system_bifurcation.html)
- [Codimension 1 bifurcation analysis of FitzHugh Nagumo model](https://brainpy-examples.readthedocs.io/en/latest/low_dim_analysis/FitzHugh_Nagumo_analysis.html)
- [Codimension 2 bifurcation analysis of FitzHugh Nagumo model](https://brainpy-examples.readthedocs.io/en/latest/low_dim_analysis/FitzHugh_Nagumo_analysis.html#Codimension-2-bifurcation-analysis)
- **[Decision Making Model]** [*(Wong & Wang, 2006)* Decision making rate model](https://brainpy-examples.readthedocs.io/en/latest/decision_making/Wang_2006_decision_making_rate.html)




+ 2
- 1
brainpy/__init__.py View File

@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "1.1.4"
__version__ = "1.1.5"


# "base" module
@@ -35,6 +35,7 @@ from .simulation import connect
from .simulation import initialize
from .simulation import inputs
from .simulation import measure
init = initialize


# "analysis" module


+ 24
- 20
brainpy/analysis/solver.py View File

@@ -165,31 +165,34 @@ def find_root_of_1d(f, f_points, args=(), tol=1e-8):
"""
vals = f(f_points, *args)
fs_len = len(f_points)
fs_sign = np.sign(vals)
signs = np.sign(vals)

roots = []
fl_sign = fs_sign[0]
f_i = 1
while f_i < fs_len and fl_sign == 0.:
roots.append(f_points[f_i - 1])
fl_sign = fs_sign[f_i]
f_i += 1
while f_i < fs_len:
fr_sign = fs_sign[f_i]
if fr_sign == 0.:
roots.append(f_points[f_i])
if f_i + 1 < fs_len:
fl_sign = fs_sign[f_i + 1]
sign_l = signs[0]
point_l = f_points[0]
idx = 1
while idx < fs_len and sign_l == 0.:
roots.append(f_points[idx - 1])
sign_l = signs[idx]
idx += 1
while idx < fs_len:
sign_r = signs[idx]
point_r = f_points[idx]
if sign_r == 0.:
roots.append(point_r)
if idx + 1 < fs_len:
sign_l = sign_r
point_l = point_r
else:
break
f_i += 2
idx += 1
else:
if not np.isnan(fr_sign) and fl_sign != fr_sign:
root, funcalls, itr = brentq(f, f_points[f_i - 1], f_points[f_i], args)
if abs(f(root, *args)) < tol:
roots.append(root)
fl_sign = fr_sign
f_i += 1
if not np.isnan(sign_r) and sign_l != sign_r:
root, funcalls, itr = brentq(f, point_l, point_r, args)
if abs(f(root, *args)) < tol: roots.append(root)
sign_l = sign_r
point_l = point_r
idx += 1

return roots

@@ -244,6 +247,7 @@ def find_root_of_2d(f, x_bound, y_bound, args=(), shgo_args=None,
res : tuple
The roots.
"""
print('Using scipy.optimize.shgo to solve fixed points.')

if shgo is None:
raise errors.PackageMissingError('Package "scipy" must be installed when the users '


+ 56
- 56
brainpy/analysis/symbolic/base.py View File

@@ -287,8 +287,8 @@ class Base1DSymAnalyzer(BaseSymAnalyzer):
sympy_failed = True
if not self.options.escape_sympy_solver and not x_eq.contain_unknown_func:
try:
logger.info(f'SymPy solve derivative of "{self.x_eq_group.func_name}'
f'({argument})" by "{x_var}", ')
logger.warning(f'SymPy solve derivative of "{self.x_eq_group.func_name}'
f'({argument})" by "{x_var}", ')
x_eq = x_eq.expr
f = utils.timeout(time_out)(lambda: sympy.diff(x_eq, x_symbol))
dfxdx_expr = f()
@@ -297,10 +297,10 @@ class Base1DSymAnalyzer(BaseSymAnalyzer):
all_vars = set(eq_x_scope.keys())
all_vars.update(self.dvar_names + self.dpar_names)
if utils.contain_unknown_symbol(analysis_by_sympy.sympy2str(dfxdx_expr), all_vars):
logger.info('\tfailed because contain unknown symbols.')
logger.warning('\tfailed because contain unknown symbols.')
sympy_failed = True
else:
logger.info('\tsuccess.')
logger.warning('\tsuccess.')
func_codes = [f'def dfdx({argument}):']
for expr in self.x_eq_group.sub_exprs[:-1]:
func_codes.append(f'{expr.var_name} = {expr.code}')
@@ -309,9 +309,9 @@ class Base1DSymAnalyzer(BaseSymAnalyzer):
dfdx = eq_x_scope['dfdx']
sympy_failed = False
except KeyboardInterrupt:
logger.info(f'\tfailed because {time_out} s timeout.')
logger.warning(f'\tfailed because {time_out} s timeout.')
except NotImplementedError:
logger.info('\tfailed because the equation is too complex.')
logger.warning('\tfailed because the equation is too complex.')

if sympy_failed:
scope = dict(_fx=self.get_f_dx(), perturb=self.options.perturbation, math=math)
@@ -349,8 +349,8 @@ class Base1DSymAnalyzer(BaseSymAnalyzer):
sympy_failed = True
if not self.options.escape_sympy_solver and not x_eq.contain_unknown_func:
try:
logger.info(f'SymPy solve "{self.x_eq_group.func_name}({argument1}) = 0" '
f'to "{self.x_var} = f({argument2})", ')
logger.warning(f'SymPy solve "{self.x_eq_group.func_name}({argument1}) = 0" '
f'to "{self.x_var} = f({argument2})", ')

# solver
f = utils.timeout(timeout_len)(
@@ -360,11 +360,11 @@ class Base1DSymAnalyzer(BaseSymAnalyzer):
all_vars = set(scope.keys())
all_vars.update(self.dvar_names + self.dpar_names)
if utils.contain_unknown_symbol(analysis_by_sympy.sympy2str(res), all_vars):
logger.info('\tfailed because contain unknown symbols.')
logger.warning('\tfailed because contain unknown symbols.')
sympy_failed = True
break
else:
logger.info('\tsuccess.')
logger.warning('\tsuccess.')
# function codes
func_codes = [f'def solve_x({argument2}):']
for expr in self.x_eq_group.sub_exprs[:-1]:
@@ -379,10 +379,10 @@ class Base1DSymAnalyzer(BaseSymAnalyzer):
self.analyzed_results['fixed_point'] = scope['solve_x']
sympy_failed = False
except NotImplementedError:
logger.info('\tfailed because the equation is too complex.')
logger.warning('\tfailed because the equation is too complex.')
sympy_failed = True
except KeyboardInterrupt:
logger.info(f'\tfailed because {timeout_len} s timeout.')
logger.warning(f'\tfailed because {timeout_len} s timeout.')
sympy_failed = True

if sympy_failed:
@@ -531,8 +531,8 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
sympy_failed = True
if not self.options.escape_sympy_solver and not x_eq.contain_unknown_func:
try:
logger.info(f'SymPy solve derivative of "{self.x_eq_group.func_name}'
f'({argument})" by "{y_var}", ')
logger.warning(f'SymPy solve derivative of "{self.x_eq_group.func_name}'
f'({argument})" by "{y_var}", ')
x_eq = x_eq.expr
f = utils.timeout(time_out)(lambda: sympy.diff(x_eq, y_symbol))
dfxdy_expr = f()
@@ -541,10 +541,10 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
all_vars = set(eq_x_scope.keys())
all_vars.update(self.dvar_names + self.dpar_names)
if utils.contain_unknown_symbol(analysis_by_sympy.sympy2str(dfxdy_expr), all_vars):
logger.info('\tfailed because contain unknown symbols.')
logger.warning('\tfailed because contain unknown symbols.')
sympy_failed = True
else:
logger.info('\tsuccess.')
logger.warning('\tsuccess.')
func_codes = [f'def dfdy({argument}):']
for expr in self.x_eq_group.sub_exprs[:-1]:
func_codes.append(f'{expr.var_name} = {expr.code}')
@@ -553,9 +553,9 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
dfdy = eq_x_scope['dfdy']
sympy_failed = False
except KeyboardInterrupt:
logger.info(f'\tfailed because {time_out} s timeout.')
logger.warning(f'\tfailed because {time_out} s timeout.')
except NotImplementedError:
logger.info('\tfailed because the equation is too complex.')
logger.warning('\tfailed because the equation is too complex.')

if sympy_failed:
scope = dict(_fx=self.get_f_dx(), perturb=self.options.perturbation, math=math)
@@ -595,8 +595,8 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
sympy_failed = True
if not self.options.escape_sympy_solver and not y_eq.contain_unknown_func:
try:
logger.info(f'SymPy solve derivative of "{self.y_eq_group.func_name}'
f'({argument})" by "{x_var}", ')
logger.warning(f'SymPy solve derivative of "{self.y_eq_group.func_name}'
f'({argument})" by "{x_var}", ')
y_eq = y_eq.expr
f = utils.timeout(time_out)(lambda: sympy.diff(y_eq, x_symbol))
dfydx_expr = f()
@@ -605,10 +605,10 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
all_vars = set(eq_y_scope.keys())
all_vars.update(self.dvar_names + self.dpar_names)
if utils.contain_unknown_symbol(analysis_by_sympy.sympy2str(dfydx_expr), all_vars):
logger.info('\tfailed because contain unknown symbols.')
logger.warning('\tfailed because contain unknown symbols.')
sympy_failed = True
else:
logger.info('\tsuccess.')
logger.warning('\tsuccess.')
func_codes = [f'def dgdx({argument}):']
for expr in self.y_eq_group.sub_exprs[:-1]:
func_codes.append(f'{expr.var_name} = {expr.code}')
@@ -617,9 +617,9 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
dgdx = eq_y_scope['dgdx']
sympy_failed = False
except KeyboardInterrupt:
logger.info(f'\tfailed because {time_out} s timeout.')
logger.warning(f'\tfailed because {time_out} s timeout.')
except NotImplementedError:
logger.info('\tfailed because the equation is too complex.')
logger.warning('\tfailed because the equation is too complex.')

if sympy_failed:
scope = dict(_fy=self.get_f_dy(), perturb=self.options.perturbation, math=math)
@@ -660,8 +660,8 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
sympy_failed = True
if not self.options.escape_sympy_solver and not y_eq.contain_unknown_func:
try:
logger.info(f'\tSymPy solve derivative of "{self.y_eq_group.func_name}'
f'({argument})" by "{y_var}", ')
logger.warning(f'\tSymPy solve derivative of "{self.y_eq_group.func_name}'
f'({argument})" by "{y_var}", ')
y_eq = y_eq.expr
f = utils.timeout(time_out)(lambda: sympy.diff(y_eq, y_symbol))
dfydx_expr = f()
@@ -670,10 +670,10 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
all_vars = set(eq_y_scope.keys())
all_vars.update(self.dvar_names + self.dpar_names)
if utils.contain_unknown_symbol(analysis_by_sympy.sympy2str(dfydx_expr), all_vars):
logger.info('\tfailed because contain unknown symbols.')
logger.warning('\tfailed because contain unknown symbols.')
sympy_failed = True
else:
logger.info('\tsuccess.')
logger.warning('\tsuccess.')
func_codes = [f'def dgdy({argument}):']
for expr in self.y_eq_group.sub_exprs[:-1]:
func_codes.append(f'{expr.var_name} = {expr.code}')
@@ -682,9 +682,9 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
dgdy = eq_y_scope['dgdy']
sympy_failed = False
except KeyboardInterrupt:
logger.info(f'\tfailed because {time_out} s timeout.')
logger.warning(f'\tfailed because {time_out} s timeout.')
except NotImplementedError:
logger.info('\tfailed because the equation is too complex.')
logger.warning('\tfailed because the equation is too complex.')

if sympy_failed:
scope = dict(_fy=self.get_f_dy(), perturb=self.options.perturbation, math=math)
@@ -1065,9 +1065,9 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
timeout_len = self.options.sympy_solver_timeout

try:
logger.info(f'SymPy solve "{self.y_eq_group.func_name}({argument}) = 0" to '
f'"{self.y_var} = f({self.x_var}, '
f'{",".join(self.dvar_names[2:] + self.dpar_names)})", ')
logger.warning(f'SymPy solve "{self.y_eq_group.func_name}({argument}) = 0" to '
f'"{self.y_var} = f({self.x_var}, '
f'{",".join(self.dvar_names[2:] + self.dpar_names)})", ')
# solve the expression
f = utils.timeout(timeout_len)(lambda: sympy.solve(y_eq, y_symbol))
y_by_x_in_y_eq = f()
@@ -1079,10 +1079,10 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
all_vars = set(eq_y_scope.keys())
all_vars.update(self.dvar_names + self.dpar_names)
if utils.contain_unknown_symbol(y_by_x_in_y_eq, all_vars):
logger.info('\tfailed because contain unknown symbols.')
logger.warning('\tfailed because contain unknown symbols.')
results['status'] = 'sympy_failed'
else:
logger.info('\tsuccess.')
logger.warning('\tsuccess.')
# substituted codes
subs_codes = [f'{expr.var_name} = {expr.code}'
for expr in self.y_eq_group.sub_exprs[:-1]]
@@ -1101,10 +1101,10 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
results['f'] = eq_y_scope['func']

except NotImplementedError:
logger.info('\tfailed because the equation is too complex.')
logger.warning('\tfailed because the equation is too complex.')
results['status'] = 'sympy_failed'
except KeyboardInterrupt:
logger.info(f'\tfailed because {timeout_len} s timeout.')
logger.warning(f'\tfailed because {timeout_len} s timeout.')
results['status'] = 'sympy_failed'
else:
results['status'] = 'escape'
@@ -1139,9 +1139,9 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
timeout_len = self.options.sympy_solver_timeout

try:
logger.info(f'SymPy solve "{self.x_eq_group.func_name}({argument}) = 0" to '
f'"{self.y_var} = f({self.x_var}, '
f'{",".join(self.dvar_names[2:] + self.dpar_names)})", ')
logger.warning(f'SymPy solve "{self.x_eq_group.func_name}({argument}) = 0" to '
f'"{self.y_var} = f({self.x_var}, '
f'{",".join(self.dvar_names[2:] + self.dpar_names)})", ')

# solve the expression
f = utils.timeout(timeout_len)(lambda: sympy.solve(x_eq, y_symbol))
@@ -1153,10 +1153,10 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
all_vars = set(eq_x_scope.keys())
all_vars.update(self.dvar_names + self.dpar_names)
if utils.contain_unknown_symbol(y_by_x_in_x_eq, all_vars):
logger.info('\tfailed because contain unknown symbols.')
logger.warning('\tfailed because contain unknown symbols.')
results['status'] = 'sympy_failed'
else:
logger.info('\tsuccess.')
logger.warning('\tsuccess.')

# substituted codes
subs_codes = [f'{expr.var_name} = {expr.code}'
@@ -1175,10 +1175,10 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
results['subs'] = subs_codes
results['f'] = eq_x_scope['func']
except NotImplementedError:
logger.info('\tfailed because the equation is too complex.')
logger.warning('\tfailed because the equation is too complex.')
results['status'] = 'sympy_failed'
except KeyboardInterrupt:
logger.info(f'\tfailed because {timeout_len} s timeout.')
logger.warning(f'\tfailed because {timeout_len} s timeout.')
results['status'] = 'sympy_failed'
else:
results['status'] = 'escape'
@@ -1213,8 +1213,8 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
timeout_len = self.options.sympy_solver_timeout

try:
logger.info(f'SymPy solve "{self.y_eq_group.func_name}({argument}) = 0" to '
f'"{self.x_var} = f({",".join(self.dvar_names[1:] + self.dpar_names)})", ')
logger.warning(f'SymPy solve "{self.y_eq_group.func_name}({argument}) = 0" to '
f'"{self.x_var} = f({",".join(self.dvar_names[1:] + self.dpar_names)})", ')
# solve the expression
f = utils.timeout(timeout_len)(lambda: sympy.solve(y_eq, x_symbol))
x_by_y_in_y_eq = f()
@@ -1226,10 +1226,10 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
all_vars = set(eq_y_scope.keys())
all_vars.update(self.dvar_names + self.dpar_names)
if utils.contain_unknown_symbol(x_by_y_in_y_eq, all_vars):
logger.info('\tfailed because contain unknown symbols.')
logger.warning('\tfailed because contain unknown symbols.')
results['status'] = 'sympy_failed'
else:
logger.info('\tsuccess.')
logger.warning('\tsuccess.')

# substituted codes
subs_codes = [f'{expr.var_name} = {expr.code}'
@@ -1248,10 +1248,10 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
results['subs'] = subs_codes
results['f'] = eq_y_scope['func']
except NotImplementedError:
logger.info('\tfailed because the equation is too complex.')
logger.warning('\tfailed because the equation is too complex.')
results['status'] = 'sympy_failed'
except KeyboardInterrupt:
logger.info(f'\tfailed because {timeout_len} s timeout.')
logger.warning(f'\tfailed because {timeout_len} s timeout.')
results['status'] = 'sympy_failed'
else:
results['status'] = 'escape'
@@ -1286,8 +1286,8 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
timeout_len = self.options.sympy_solver_timeout

try:
logger.info(f'SymPy solve "{self.x_eq_group.func_name}({argument}) = 0" to '
f'"{self.x_var} = f({",".join(self.dvar_names[1:] + self.dpar_names)})", ')
logger.warning(f'SymPy solve "{self.x_eq_group.func_name}({argument}) = 0" to '
f'"{self.x_var} = f({",".join(self.dvar_names[1:] + self.dpar_names)})", ')
# solve the expression
f = utils.timeout(timeout_len)(lambda: sympy.solve(x_eq, x_symbol))
x_by_y_in_x_eq = f()
@@ -1299,10 +1299,10 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
all_vars = set(eq_x_scope.keys())
all_vars.update(self.dvar_names + self.dpar_names)
if utils.contain_unknown_symbol(x_by_y_in_x_eq, all_vars):
logger.info('\tfailed because contain unknown symbols.')
logger.warning('\tfailed because contain unknown symbols.')
results['status'] = 'sympy_failed'
else:
logger.info('\tsuccess.')
logger.warning('\tsuccess.')

# substituted codes
subs_codes = [f'{expr.var_name} = {expr.code}'
@@ -1321,10 +1321,10 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
results['subs'] = subs_codes
results['f'] = eq_x_scope['func']
except NotImplementedError:
logger.info('\tfailed because the equation is too complex.')
logger.warning('\tfailed because the equation is too complex.')
results['status'] = 'sympy_failed'
except KeyboardInterrupt:
logger.info(f'\tfailed because {timeout_len} s timeout.')
logger.warning(f'\tfailed because {timeout_len} s timeout.')
results['status'] = 'sympy_failed'
else:
results['status'] = 'escape'


+ 4
- 4
brainpy/analysis/symbolic/bifurcation.py View File

@@ -214,7 +214,7 @@ class _Bifurcation1D(base.Base1DSymAnalyzer):
options=options)

def plot_bifurcation(self, show=False):
logger.info('plot bifurcation ...')
logger.warning('plot bifurcation ...')

f_fixed_point = self.get_f_fixed_point()
f_dfdx = self.get_f_dfdx()
@@ -316,7 +316,7 @@ class _Bifurcation2D(base.Base2DSymAnalyzer):
self.fixed_points = None

def plot_bifurcation(self, show=False):
logger.info('plot bifurcation ...')
logger.warning('plot bifurcation ...')

# functions
f_fixed_point = self.get_f_fixed_point()
@@ -405,7 +405,7 @@ class _Bifurcation2D(base.Base2DSymAnalyzer):
return container

def plot_limit_cycle_by_sim(self, var, duration=100, inputs=(), plot_style=None, tol=0.001, show=False):
logger.info('plot limit cycle ...')
logger.warning('plot limit cycle ...')

if self.fixed_points is None:
raise errors.AnalyzerError('Please call "plot_bifurcation()" before "plot_limit_cycle_by_sim()".')
@@ -773,7 +773,7 @@ class _FastSlowTrajectory(object):
show : bool
Whether show or not.
"""
logger.info('plot trajectory ...')
logger.warning('plot trajectory ...')

# 1. format the initial values
all_vars = self.fast_var_names + self.slow_var_names


+ 10
- 10
brainpy/analysis/symbolic/phase_plane.py View File

@@ -228,7 +228,7 @@ class _PhasePlane1D(base.Base1DSymAnalyzer):
results : np.ndarray
The dx values.
"""
logger.info('plot vector field ...')
logger.warning('plot vector field ...')

# 1. Nullcline of the x variable
try:
@@ -265,7 +265,7 @@ class _PhasePlane1D(base.Base1DSymAnalyzer):
points : np.ndarray
The fixed points.
"""
logger.info('plot fixed point ...')
logger.warning('plot fixed point ...')

# 1. functions
f_fixed_point = self.get_f_fixed_point()
@@ -278,7 +278,7 @@ class _PhasePlane1D(base.Base1DSymAnalyzer):
x = x_values[i]
dfdx = f_dfdx(x)
fp_type = stability.stability_analysis(dfdx)
logger.info(f"Fixed point #{i + 1} at {self.x_var}={x} is a {fp_type}.")
logger.warning(f"Fixed point #{i + 1} at {self.x_var}={x} is a {fp_type}.")
container[fp_type].append(x)

# 3. visualization
@@ -330,7 +330,7 @@ class _PhasePlane2D(base.Base2DSymAnalyzer):
result : tuple
The ``dx``, ``dy`` values.
"""
logger.info('plot vector field ...')
logger.warning('plot vector field ...')

if plot_style is None:
plot_style = dict()
@@ -398,7 +398,7 @@ class _PhasePlane2D(base.Base2DSymAnalyzer):
results : tuple
The value points.
"""
logger.info('plot fixed point ...')
logger.warning('plot fixed point ...')

# function for fixed point solving
f_fixed_point = self.get_f_fixed_point()
@@ -414,7 +414,7 @@ class _PhasePlane2D(base.Base2DSymAnalyzer):
x = x_values[i]
y = y_values[i]
fp_type = stability.stability_analysis(f_jacobian(x, y))
logger.info(f"Fixed point #{i + 1} at {self.x_var}={x}, {self.y_var}={y} is a {fp_type}.")
logger.warning(f"Fixed point #{i + 1} at {self.x_var}={x}, {self.y_var}={y} is a {fp_type}.")
container[fp_type]['x'].append(x)
container[fp_type]['y'].append(y)

@@ -453,7 +453,7 @@ class _PhasePlane2D(base.Base2DSymAnalyzer):
values : dict
A dict with the format of ``{func1: (x_val, y_val), func2: (x_val, y_val)}``.
"""
logger.info('plot nullcline ...')
logger.warning('plot nullcline ...')

if numerical_setting is None:
numerical_setting = dict()
@@ -579,7 +579,7 @@ class _PhasePlane2D(base.Base2DSymAnalyzer):
Whether show or not.
"""

logger.info('plot trajectory ...')
logger.warning('plot trajectory ...')

if axes not in ['v-v', 't-v']:
raise errors.BrainPyError(f'Unknown axes "{axes}", only support "v-v" and "t-v".')
@@ -687,7 +687,7 @@ class _PhasePlane2D(base.Base2DSymAnalyzer):
show : bool
Whether show or not.
"""
logger.info('plot limit cycle ...')
logger.warning('plot limit cycle ...')

# 1. format the initial values
if isinstance(initials, dict):
@@ -732,7 +732,7 @@ class _PhasePlane2D(base.Base2DSymAnalyzer):
lines = plt.plot(x_cycle, y_cycle, label='limit cycle')
utils.add_arrow(lines[0])
else:
logger.info(f'No limit cycle found for initial value {initial}')
logger.warning(f'No limit cycle found for initial value {initial}')

# 6. visualization
plt.xlabel(self.x_var)


+ 3
- 1
brainpy/base/collector.py View File

@@ -50,6 +50,8 @@ class Collector(dict):

>>> import brainpy as bp
>>>
>>> some_collector = Collector()
>>>
>>> # get all trainable variables
>>> some_collector.subset(bp.math.TrainVar)
>>>
@@ -59,7 +61,7 @@ class Collector(dict):
or, it can be used to get a subset of integrators:

>>> # get all ODE integrators
>>> some_collector.subset(bp.integrators.ODE_INT)
>>> some_collector.subset(bp.ode.ODEIntegrator)

Parameters
----------


+ 9
- 9
brainpy/base/function.py View File

@@ -4,7 +4,7 @@ from brainpy import errors
from brainpy.base.base import Base
from brainpy.base import collector

ndarray = None
math = None

__all__ = [
'Function',
@@ -18,11 +18,11 @@ def _check_node(node):


def _check_var(var):
global ndarray
if ndarray is None: from brainpy.math import ndarray
if not isinstance(var, ndarray):
global math
if math is None: from brainpy import math
if not isinstance(var, math.ndarray):
raise errors.BrainPyError(f'Element in "dyn_vars" must be an instance of '
f'{ndarray.__name__}, but we got {type(var)}.')
f'{math.ndarray.__name__}, but we got {type(var)}.')


class Function(Base):
@@ -70,9 +70,9 @@ class Function(Base):
# ---
if dyn_vars is not None:
self.implicit_vars = collector.TensorCollector()
global ndarray
if ndarray is None: from brainpy.math import ndarray
if isinstance(dyn_vars, ndarray):
global math
if math is None: from brainpy import math
if isinstance(dyn_vars, math.ndarray):
dyn_vars = (dyn_vars,)
if isinstance(dyn_vars, (tuple, list)):
for i, v in enumerate(dyn_vars):
@@ -83,7 +83,7 @@ class Function(Base):
_check_var(v)
self.implicit_vars.update(dyn_vars)
else:
raise ValueError(f'"dyn_vars" only support list/tuple/dict of {ndarray.__name__}, '
raise ValueError(f'"dyn_vars" only support list/tuple/dict of {math.ndarray.__name__}, '
f'but we got {type(dyn_vars)}: {dyn_vars}')

def __call__(self, *args, **kwargs):


+ 2
- 2
brainpy/math/numpy/random.py View File

@@ -61,8 +61,8 @@ def bernoulli(p, size=None):
return numpy.random.binomial(1, p=p, size=size)


def truncated_normal():
raise NotImplementedError
def truncated_normal(lower, upper, size, scale=1.):
raise NotImplementedError('Please use `brainpy.math.jax.random.truncated_normal()`')


@tools.numba_jit


+ 145
- 171
brainpy/simulation/initialize/random_inits.py View File

@@ -8,15 +8,25 @@ from .base import Initializer
__all__ = [
'Normal',
'Uniform',
'Orthogonal',
'VarianceScaling',
'KaimingUniform',
'KaimingNormal',
'KaimingNormalTruncated',
'XavierUniform',
'XavierNormal',
'XavierNormalTruncated',
'TruncatedNormal',
'LecunUniform',
'LecunNormal',
'Orthogonal',
'DeltaOrthogonal',
]


def _compute_fans(shape, in_axis=-2, out_axis=-1):
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
fan_in = shape[in_axis] * receptive_field_size
fan_out = shape[out_axis] * receptive_field_size
return fan_in, fan_out


class Normal(Initializer):
"""Initialize weights with normal distribution.

@@ -26,12 +36,13 @@ class Normal(Initializer):
The gain of the derivation of the normal distribution.

"""
def __init__(self, gain=1.):

def __init__(self, scale=1.):
super(Normal, self).__init__()
self.gain = gain
self.scale = scale

def __call__(self, shape, dtype=None):
weights = math.random.normal(size=shape, scale=self.gain * np.sqrt(1 / np.prod(shape)))
weights = math.random.normal(size=shape, scale=self.scale)
return math.asarray(weights, dtype=dtype)


@@ -46,41 +57,117 @@ class Uniform(Initializer):
The upper limit of the uniform distribution.

"""
def __init__(self, min_val=0., max_val=1.):

def __init__(self, min_val=0., max_val=1., scale=1e-2):
super(Uniform, self).__init__()
self.min_val = min_val
self.max_val = max_val
self.scale = scale

def __call__(self, shape, dtype=None):
r = math.random.uniform(low=self.min_val, high=self.max_val, size=shape)
return math.asarray(r, dtype=dtype)
return math.asarray(r * self.scale, dtype=dtype)


class VarianceScaling(Initializer):
def __init__(self, scale, mode, distribution, in_axis=-2, out_axis=-1):
self.scale = scale
self.mode = mode
self.in_axis = in_axis
self.out_axis = out_axis
self.distribution = distribution

def __call__(self, shape, dtype=None):
fan_in, fan_out = _compute_fans(shape, in_axis=self.in_axis, out_axis=self.out_axis)
if self.mode == "fan_in":
denominator = fan_in
elif self.mode == "fan_out":
denominator = fan_out
elif self.mode == "fan_avg":
denominator = (fan_in + fan_out) / 2
else:
raise ValueError("invalid mode for variance scaling initializer: {}".format(self.mode))
variance = math.array(self.scale / denominator, dtype=dtype)
if self.distribution == "truncated_normal":
# constant is stddev of standard normal truncated to (-2, 2)
stddev = math.sqrt(variance) / math.array(.87962566103423978, dtype)
res = math.random.truncated_normal(-2, 2, shape) * stddev
return math.asarray(res, dtype=dtype)
elif self.distribution == "normal":
res = math.random.normal(size=shape) * math.sqrt(variance)
return math.asarray(res, dtype=dtype)
elif self.distribution == "uniform":
res = math.random.uniform(low=-1, high=1, size=shape) * math.sqrt(3 * variance)
return math.asarray(res, dtype=dtype)
else:
raise ValueError("invalid distribution for variance scaling initializer")


class KaimingUniform(VarianceScaling):
def __init__(self, scale=2.0, mode="fan_in",
distribution="uniform",
in_axis=-2, out_axis=-1):
super(KaimingUniform, self).__init__(scale, mode, distribution,
in_axis=in_axis,
out_axis=out_axis)


class KaimingNormal(VarianceScaling):
def __init__(self, scale=2.0, mode="fan_in",
distribution="truncated_normal",
in_axis=-2, out_axis=-1):
super(KaimingNormal, self).__init__(scale, mode, distribution,
in_axis=in_axis,
out_axis=out_axis)


class XavierUniform(VarianceScaling):
def __init__(self, scale=1.0, mode="fan_avg",
distribution="uniform",
in_axis=-2, out_axis=-1):
super(XavierUniform, self).__init__(scale, mode, distribution,
in_axis=in_axis,
out_axis=out_axis)


class XavierNormal(VarianceScaling):
def __init__(self, scale=1.0, mode="fan_avg",
distribution="truncated_normal",
in_axis=-2, out_axis=-1):
super(XavierNormal, self).__init__(scale, mode, distribution,
in_axis=in_axis,
out_axis=out_axis)


class LecunUniform(VarianceScaling):
def __init__(self, scale=1.0, mode="fan_in",
distribution="uniform",
in_axis=-2, out_axis=-1):
super(LecunUniform, self).__init__(scale, mode, distribution,
in_axis=in_axis,
out_axis=out_axis)


class LecunNormal(VarianceScaling):
def __init__(self, scale=1.0, mode="fan_in",
distribution="truncated_normal",
in_axis=-2, out_axis=-1):
super(LecunNormal, self).__init__(scale, mode, distribution,
in_axis=in_axis,
out_axis=out_axis)


class Orthogonal(Initializer):
"""Returns a uniformly distributed orthogonal tensor from
`Exact solutions to the nonlinear dynamics of learning in deep linear neural networks
<https://openreview.net/forum?id=_wzZwKpTDF_9C>`_.

Args:
shape: shape of the output tensor.
gain: optional scaling factor.
axis: the orthogonalizarion axis

Returns:
An orthogonally initialized tensor.
These tensors will be row-orthonormal along the access specified by
``axis``. If the rank of the weight is greater than 2, the shape will be
flattened in all other dimensions and then will be row-orthonormal along the
final dimension. Note that this only works if the ``axis`` dimension is
larger, otherwise the tensor will be transposed (equivalently, it will be
column orthonormal instead of row orthonormal).
If the shape is not square, the matrices will have orthonormal rows or
columns depending on which side is smaller.
"""

def __init__(self, gain=1., axis=-1):
"""
Construct an initializer for uniformly distributed orthogonal matrices.

If the shape is not square, the matrices will have orthonormal rows or columns
depending on which side is smaller.
"""

def __init__(self, scale=1., axis=-1):
super(Orthogonal, self).__init__()
self.gain = gain
self.scale = scale
self.axis = axis

def __call__(self, shape, dtype=None):
@@ -94,149 +181,36 @@ class Orthogonal(Initializer):
if n_rows < n_cols: q_mat = q_mat.T
q_mat = np.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))
q_mat = np.moveaxis(q_mat, 0, self.axis)
return self.gain * math.asarray(q_mat, dtype=dtype)
return self.scale * math.asarray(q_mat, dtype=dtype)


class KaimingNormal(Initializer):
"""Returns a tensor with values assigned using Kaiming He normal initializer from
`Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification
<https://arxiv.org/abs/1502.01852>`_.

Args:
shape: shape of the output tensor.
gain: optional scaling factor.

Returns:
Tensor initialized with normal random variables with standard deviation (gain * kaiming_normal_gain).
"""

def __init__(self, gain=1.):
self.gain = gain
super(KaimingNormal, self).__init__()

def __call__(self, shape, dtype=None):
gain = np.sqrt(1 / np.prod(shape[:-1]))
res = math.random.normal(size=shape, scale=self.gain * gain)
return math.asarray(res, dtype=dtype)


class KaimingNormalTruncated(Initializer):
"""Returns a tensor with values assigned using Kaiming He truncated normal initializer from
`Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification
<https://arxiv.org/abs/1502.01852>`_.

Args:
shape: shape of the output tensor.
lower: lower truncation of the normal.
upper: upper truncation of the normal.
gain: optional scaling factor.

Returns:
Tensor initialized with truncated normal random variables with standard
deviation (gain * kaiming_normal_gain) and support [lower, upper].
"""

def __init__(self, lower=-2., upper=2., gain=1.):
self.lower = lower
self.upper = upper
self.gain = gain
super(KaimingNormalTruncated, self).__init__()

def __call__(self, shape, dtype=None):
truncated_std = scipy.stats.truncnorm.std(a=self.lower,
b=self.upper,
loc=0.,
scale=1.)
stddev = self.gain * np.sqrt(1 / np.prod(shape[:-1])) / truncated_std
res = math.random.truncated_normal(size=shape,
scale=stddev,
lower=self.lower,
upper=self.upper)
return math.asarray(res, dtype=dtype)


class XavierNormal(Initializer):
"""Returns a tensor with values assigned using Xavier Glorot normal initializer from
`Understanding the difficulty of training deep feedforward neural networks
<http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_.

Args:
shape: shape of the output tensor.
gain: optional scaling factor.

Returns:
Tensor initialized with normal random variables with standard deviation (gain * xavier_normal_gain).
"""

def __init__(self, gain=1.):
super(XavierNormal, self).__init__()
self.gain = gain
class DeltaOrthogonal(Initializer):
"""
Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.

def __call__(self, shape, dtype=None):
fan_in, fan_out = np.prod(shape[:-1]), shape[-1]
gain = np.sqrt(2 / (fan_in + fan_out))
res = math.random.normal(size=shape, scale=self.gain * gain)
return math.asarray(res, dtype=dtype)


class XavierNormalTruncated(Initializer):
"""Returns a tensor with values assigned using Xavier Glorot truncated normal initializer from
`Understanding the difficulty of training deep feedforward neural networks
<http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_.

Args:
shape: shape of the output tensor.
lower: lower truncation of the normal.
upper: upper truncation of the normal.
gain: optional scaling factor.

Returns:
Tensor initialized with truncated normal random variables with standard
deviation (gain * xavier_normal_gain) and support [lower, upper].
"""

def __init__(self, lower=-2., upper=2., gain=1.):
self.lower = lower
self.upper = upper
self.gain = gain
super(XavierNormalTruncated, self).__init__()
The shape must be 3D, 4D or 5D.
"""

def __call__(self, shape, dtype=None):
truncated_std = scipy.stats.truncnorm.std(a=self.lower, b=self.upper, loc=0., scale=1)
fan_in, fan_out = np.prod(shape[:-1]), shape[-1]
gain = np.sqrt(2 / (fan_in + fan_out))
stddev = self.gain * gain / truncated_std
res = math.random.truncated_normal(size=shape,
scale=stddev,
lower=self.lower,
upper=self.upper)
return math.asarray(res, dtype=dtype)


class TruncatedNormal(Initializer):
"""Returns a tensor with values assigned using truncated normal initialization.

Args:
shape: shape of the output tensor.
lower: lower truncation of the normal.
upper: upper truncation of the normal.
stddev: expected standard deviation.

Returns:
Tensor initialized with truncated normal random variables with standard
deviation stddev and support [lower, upper].
"""

def __init__(self, lower=-2., upper=2., scale=1.):
self.lower = lower
self.upper = upper
def __init__(self, scale=1.0, axis=-1, ):
super(DeltaOrthogonal, self).__init__()
self.scale = scale
super(TruncatedNormal, self).__init__()
self.axis = axis

def __call__(self, shape, dtype=None):
truncated_std = scipy.stats.truncnorm.std(a=self.lower, b=self.upper, loc=0., scale=1)
res = math.random.truncated_normal(size=shape,
scale=self.scale / truncated_std,
lower=self.lower,
upper=self.upper)
return math.asarray(res, dtype=dtype)
if len(shape) not in [3, 4, 5]:
raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D shape.")
if shape[-1] < shape[-2]:
raise ValueError("`fan_in` must be less or equal than `fan_out`. ")
ortho_init = Orthogonal(scale=self.scale, axis=self.axis)
ortho_matrix = ortho_init(shape[-2:], dtype=dtype)
W = math.zeros(shape, dtype=dtype)
if len(shape) == 3:
k = shape[0]
W[(k - 1) // 2, ...] = ortho_matrix
elif len(shape) == 4:
k1, k2 = shape[:2]
W[(k1 - 1) // 2, (k2 - 1) // 2, ...] = ortho_matrix
else:
k1, k2, k3 = shape[:3]
W[(k1 - 1) // 2, (k2 - 1) // 2, (k3 - 1) // 2, ...] = ortho_matrix
return W

+ 16
- 0
brainpy/simulation/layers/base.py View File

@@ -1,7 +1,10 @@
# -*- coding: utf-8 -*-

import inspect
import numpy as onp
from typing import Union
import jax.numpy as jnp
import brainpy.math.jax as bm

from brainpy import errors
from brainpy.base.collector import Collector
@@ -26,6 +29,19 @@ class Module(DynamicalSystem):
"""Basic module class for DNN networks."""
target_backend = 'jax'

@staticmethod
def get_param(param, size):
if param is None:
return None
if callable(param):
return bm.TrainVar(param(size))
if isinstance(param, onp.ndarray):
assert param.shape == size
return bm.TrainVar(bm.asarray(param))
if isinstance(param, (bm.JaxArray, jnp.ndarray)):
return bm.TrainVar(param)
raise ValueError


class Sequential(Module):
"""Basic sequential object to control data flow.


+ 4
- 14
brainpy/simulation/layers/conv.py View File

@@ -82,18 +82,8 @@ class Conv2D(Module):
self.has_bias = True

# weight initialization
if callable(w):
self.w = bm.TrainVar(w((*_check_tuple(kernel_size), num_input // groups, num_output))) # HWIO
else:
assert w.shape == (*_check_tuple(kernel_size), num_input // groups, num_output)
self.w = bm.TrainVar(w)
if callable(b):
self.b = bm.TrainVar(b((num_output, 1, 1)))
elif b is None:
self.has_bias = False
else:
assert b.shape == (num_output, 1, 1)
self.b = bm.TrainVar(b)
self.w = self.get_param(w, (*_check_tuple(kernel_size), num_input // groups, num_output))
self.b = self.get_param(b, (num_output, 1, 1))

def update(self, x):
nin = self.w.value.shape[2] * self.groups
@@ -108,5 +98,5 @@ class Conv2D(Module):
rhs_dilation=self.dilations,
feature_group_count=self.groups,
dimension_numbers=('NCHW', 'HWIO', 'NCHW'))
if self.has_bias: y += self.b.value
return y
if self.b is None: return y
return y + self.b.value

+ 5
- 15
brainpy/simulation/layers/dense.py View File

@@ -40,22 +40,12 @@ class Dense(Module):
self.num_hidden = num_hidden

# variables
if callable(w):
self.w = bm.TrainVar(w((num_input, num_hidden)))
else:
assert w.shape == (num_input, num_hidden)
self.w = bm.TrainVar(w)
if b is None:
self.has_bias = False
elif callable(b):
self.b = bm.TrainVar(b((num_hidden,)))
else:
assert b.shape == (num_hidden, )
self.b = bm.TrainVar(b)
self.w = self.get_param(w, (num_input, num_hidden))
self.b = self.get_param(b, (num_hidden,))

def update(self, x):
"""Returns the results of applying the linear transformation to input x."""
if self.has_bias:
return x @ self.w + self.b
else:
if self.b is None:
return x @ self.w
else:
return x @ self.w + self.b

+ 40
- 98
brainpy/simulation/layers/recurrent.py View File

@@ -43,35 +43,17 @@ class VanillaRNN(RNNCore):
def __init__(self, num_hidden, num_input, num_batch, h=Uniform(), w=XavierNormal(), b=ZeroInit(), **kwargs):
super(VanillaRNN, self).__init__(num_hidden, num_input, **kwargs)

self.has_bias = True

# variables
if callable(h):
self.h = bm.Variable(h((num_batch, self.num_hidden)))
else:
self.h = bm.Variable(h)
self.h = bm.Variable(self.get_param(h, (num_batch, self.num_hidden)))

# weights
if callable(w):
self.w_ir = bm.TrainVar(w((num_input, num_hidden)))
self.w_rr = bm.TrainVar(w((num_hidden, num_hidden)))
else:
w_ir, w_rr = w
assert w_ir.shape == (num_input, num_hidden)
assert w_rr.shape == (num_hidden, num_hidden)
self.w_ir = bm.TrainVar(w_ir)
self.w_rr = bm.TrainVar(w_rr)
if b is None:
self.has_bias = False
elif callable(b):
self.b = bm.TrainVar(b((num_hidden,)))
else:
assert b.shape == (num_hidden,)
self.b = bm.TrainVar(b)
ws = self.get_param(w, (num_input + num_hidden, num_hidden))
self.w_ir = bm.TrainVar(ws[:num_input])
self.w_rr = bm.TrainVar(ws[num_input:])
self.b = self.get_param(b, (num_hidden,))

def update(self, x):
h = x @ self.w_ir + self.h @ self.w_rr
if self.has_bias: h += self.b
h = x @ self.w_ir + self.h @ self.w_rr + self.b
self.h.value = bm.relu(h)
return self.h

@@ -112,60 +94,35 @@ class GRU(RNNCore):
self.has_bias = True

# variables
if callable(h):
self.h = bm.Variable(h((num_batch, self.num_hidden)))
else:
self.h = bm.Variable(h)
self.h = bm.Variable(self.get_param(h, (num_batch, self.num_hidden)))

# weights
if callable(wx):
self.w_iz = bm.TrainVar(wx((num_input, num_hidden)))
self.w_ir = bm.TrainVar(wx((num_input, num_hidden)))
self.w_ia = bm.TrainVar(wx((num_input, num_hidden)))
else:
w_iz, w_ir, w_ia = wx
assert w_iz.shape == (num_input, num_hidden)
assert w_ir.shape == (num_input, num_hidden)
assert w_ia.shape == (num_input, num_hidden)
self.w_iz = bm.TrainVar(w_iz)
self.w_ir = bm.TrainVar(w_ir)
self.w_ia = bm.TrainVar(w_ia)
if callable(wh):
self.w_hz = bm.TrainVar(wh((num_hidden, num_hidden)))
self.w_hr = bm.TrainVar(wh((num_hidden, num_hidden)))
self.w_ha = bm.TrainVar(wh((num_hidden, num_hidden)))
else:
w_hz, w_hr, w_ha = wh
assert w_hz.shape == (num_hidden, num_hidden)
assert w_hr.shape == (num_hidden, num_hidden)
assert w_ha.shape == (num_hidden, num_hidden)
self.w_hz = bm.TrainVar(w_hz)
self.w_hr = bm.TrainVar(w_hr)
self.w_ha = bm.TrainVar(w_ha)
if b is None:
self.has_bias = False
self.bz = 0.
self.br = 0.
self.ba = 0.
elif callable(b):
self.bz = bm.TrainVar(b((num_hidden,)))
self.br = bm.TrainVar(b((num_hidden,)))
self.ba = bm.TrainVar(b((num_hidden,)))
else:
bz, br, ba = b
assert bz.shape == (num_hidden, )
assert br.shape == (num_hidden, )
assert ba.shape == (num_hidden, )
self.bz = bm.TrainVar(bz)
self.br = bm.TrainVar(br)
self.ba = bm.TrainVar(ba)
wxs = self.get_param(wx, (num_input * 3, num_hidden))
self.w_iz = bm.TrainVar(wxs[:num_input])
self.w_ir = bm.TrainVar(wxs[num_input: num_input * 2])
self.w_ia = bm.TrainVar(wxs[num_input * 2:])
whs = self.get_param(wh, (num_hidden * 3, num_hidden))
self.w_hz = bm.TrainVar(whs[:num_hidden])
self.w_hr = bm.TrainVar(whs[num_hidden: num_hidden * 2])
self.w_ha = bm.TrainVar(whs[num_hidden * 2:])
bs = self.get_param(b, (num_hidden * 3,))
self.bz = bm.TrainVar(bs[:num_hidden])
self.br = bm.TrainVar(bs[num_hidden: num_hidden * 2])
self.ba = bm.TrainVar(bs[num_hidden * 2:])

def update(self, x):
z = bm.sigmoid(x @ self.w_iz + self.h @ self.w_hz + self.bz)
r = bm.sigmoid(x @ self.w_ir + self.h @ self.w_hr + self.br)
a = bm.tanh(x @ self.w_ia + (r * self.h) @ self.w_ha + self.ba)
self.h.value = (1 - z) * self.h + z * a
return self.h.value
if self.bz is None:
z = bm.sigmoid(x @ self.w_iz + self.h @ self.w_hz)
r = bm.sigmoid(x @ self.w_ir + self.h @ self.w_hr)
a = bm.tanh(x @ self.w_ia + (r * self.h) @ self.w_ha)
self.h.value = (1 - z) * self.h + z * a
return self.h.value
else:
z = bm.sigmoid(x @ self.w_iz + self.h @ self.w_hz + self.bz)
r = bm.sigmoid(x @ self.w_ir + self.h @ self.w_hr + self.br)
a = bm.tanh(x @ self.w_ia + (r * self.h) @ self.w_ha + self.ba)
self.h.value = (1 - z) * self.h + z * a
return self.h.value


class LSTM(RNNCore):
@@ -214,38 +171,23 @@ class LSTM(RNNCore):
self.has_bias = True

# variables
if callable(hc):
self.h = bm.Variable(hc((num_batch, self.num_hidden)))
self.c = bm.Variable(hc((num_batch, self.num_hidden)))
else:
h, c = hc
assert h.shape == (num_batch, self.num_hidden)
assert c.shape == (num_batch, self.num_hidden)
self.h = bm.Variable(h)
self.c = bm.Variable(c)
hc = bm.Variable(self.get_param(hc, (num_batch * 2, self.num_hidden)))
self.h = bm.Variable(hc[:num_batch])
self.c = bm.Variable(hc[num_batch:])

# weights
if callable(w):
self.w = bm.TrainVar(w((num_input + num_hidden, num_hidden * 4)))
else:
assert w.shape == (num_input + num_hidden, num_hidden * 4)
self.w = bm.TrainVar(w)
if b is None:
self.b = 0.
self.has_bias = False
elif callable(b):
self.b = bm.TrainVar(b((num_hidden * 4,)))
else:
assert b.shape == (num_hidden * 4, )
self.b = bm.TrainVar(b)
self.w = self.get_param(w, (num_input + num_hidden, num_hidden * 4))
self.b = self.get_param(b, (num_hidden * 4,))

def update(self, x):
xh = bm.concatenate([x, self.h], axis=-1)
gated = xh @ self.w + self.b
if self.b is None:
gated = xh @ self.w
else:
gated = xh @ self.w + self.b
i, g, f, o = bm.split(gated, indices_or_sections=4, axis=-1)
c = bm.sigmoid(f + 1.) * self.c + bm.sigmoid(i) * bm.tanh(g)
h = bm.sigmoid(o) * bm.tanh(c)
self.h.value = h
self.c.value = c
return self.h.value


+ 14
- 0
changelog.rst View File

@@ -2,6 +2,20 @@ Release notes
=============


Version 1.1.5
-------------

**API changes:**

- fix bugs on ndarray import in `brainpy.base.function.py`
- convenient 'get_param' interface `brainpy.simulation.layers`
- add more weight initialization methods

**Doc changes:**

- add more examples in README


Version 1.1.4
-------------



+ 33
- 7
docs/quickstart/installation.rst View File

@@ -84,7 +84,21 @@ is based on JAX.
Currently, JAX supports **Linux** (Ubuntu 16.04 or later) and **macOS** (10.12 or
later) platforms. The provided binary releases of JAX for Linux and macOS
systems are available at https://storage.googleapis.com/jax-releases/jax_releases.html .
Users can download the preferred release ".whl" file, and install it via ``pip``:

To install a CPU-only version of JAX, you can run

.. code-block:: bash

pip install --upgrade "jax[cpu]"

If you want to install JAX with both CPU and NVidia GPU support, you must first install
`CUDA`_ and `CuDNN`_, if they have not already been installed. Next, run

.. code-block:: bash

pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

Alternatively, you can download the preferred release ".whl" file, and install it via ``pip``:

.. code-block:: bash

@@ -92,13 +106,23 @@ Users can download the preferred release ".whl" file, and install it via ``pip``

For **Windows** users, JAX can be installed by the following methods:

- For Windows 10+ system, you can `Windows Subsystem for Linux (WSL)`_.
The installation guide can be found in `WSL Installation Guide for Windows 10`_.
Then, you can install JAX in WSL just like the installation step in Linux.
- There are several precompiled Windows wheels, like `jaxlib_0.1.68_Windows_wheels`_ and `jaxlib_0.1.61_Windows_wheels`_.
- Finally, you can also `build JAX from source`_.
Method 1: For Windows 10+ system, you can `Windows Subsystem for Linux (WSL)`_.
The installation guide can be found in `WSL Installation Guide for Windows 10`_.
Then, you can install JAX in WSL just like the installation step in Linux.

Method 2: There are several community supported Windows build for jax, please refer
to the github link for more details: https://github.com/cloudhan/jax-windows-builder .
Simply speaking, you can run:

.. code-block:: bash

# for only CPU
pip install jaxlib -f https://whls.blob.core.windows.net/unstable/index.html

# for GPU support
pip install <downloaded jaxlib>

More details of JAX installation can be found in https://github.com/google/jax#installation .
Method 3: You can also `build JAX from source`_.


Numba
@@ -148,3 +172,5 @@ Therefore, we highly recommend you to install sympy, just typing
.. _SymPy: https://github.com/sympy/sympy
.. _Exponential Euler numerical solver: https://brainpy.readthedocs.io/en/latest/tutorials_advanced/ode_numerical_solvers.html#Exponential-Euler-methods
.. _dynamics analysis module: https://brainpy.readthedocs.io/en/latest/apis/analysis.html
.. _CUDA: https://developer.nvidia.com/cuda-downloads
.. _CuDNN: https://developer.nvidia.com/CUDNN

Loading…
Cancel
Save