@@ -62,9 +62,9 @@ To install the stable release of BrainPy, please use
 **Other dependencies**: if you want to get full support from BrainPy, please install the following packages:
-- `JAX >= 0.2.10`, needed for "jax" backend and [many other supports](https://brainpy.readthedocs.io/en/latest/apis/math/special_jax.html)
-- `Numba >= 0.52`, needed for JIT compilation on "numpy" backend
-- `SymPy >= 1.4`, needed for dynamics "analysis" module and Exponential Euler method
+- `JAX >= 0.2.10`, needed for "jax" backend and many other supports ([how to install jax?](https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#jax))
+- `Numba >= 0.52`, needed for JIT compilation on "numpy" backend ([how to install numba?](https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#numba))
+- `SymPy >= 1.4`, needed for dynamics "analysis" module and Exponential Euler method ([how to install sympy?](https://brainpy.readthedocs.io/en/latest/quickstart/installation.html#sympy))
@@ -126,8 +126,6 @@ See [brainmodels.synapses](https://brainmodels.readthedocs.io/en/latest/apis/syn
 - **[Working Memory]** [*(Mi, et. al., 2017)* STP for Working Memory Capacity](https://brainpy-examples.readthedocs.io/en/latest/working_memory/Mi_2017_working_memory_capacity.html)
 - **[Working Memory]** [*(Bouchacourt & Buschman, 2019)* Flexible Working Memory Model](https://brainpy-examples.readthedocs.io/en/latest/working_memory/Bouchacourt_2019_Flexible_working_memory.html)
 - **[Decision Making]** [*(Wang, 2002)* Decision making spiking model](https://brainpy-examples.readthedocs.io/en/latest/decision_making/Wang_2002_decision_making_spiking.html)
-- **[Recurrent Network]** [*(Laje & Buonomano, 2013)* Robust Timing in RNN](https://brainpy-examples.readthedocs.io/en/latest/recurrent_networks/Laje_Buonomano_2013_robust_timing_rnn.html)
-- **[Recurrent Network]** [*(Sussillo & Abbott, 2009)* FORCE Learning](https://brainpy-examples.readthedocs.io/en/latest/recurrent_networks/Sussillo_Abbott_2009_FORCE_Learning.html)
@@ -135,16 +133,20 @@ See [brainmodels.synapses](https://brainmodels.readthedocs.io/en/latest/apis/syn
 - [Train Integrator RNN with BP](https://brainpy-examples.readthedocs.io/en/latest/recurrent_networks/integrator_rnn.html)
-- [FORCE Learning](https://brainpy-examples.readthedocs.io/en/latest/recurrent_networks/Sussillo_Abbott_2009_FORCE_Learning.html)
+- [*(Sussillo & Abbott, 2009)* FORCE Learning](https://brainpy-examples.readthedocs.io/en/latest/recurrent_networks/Sussillo_Abbott_2009_FORCE_Learning.html)
+- [*(Laje & Buonomano, 2013)* Robust Timing in RNN](https://brainpy-examples.readthedocs.io/en/latest/recurrent_networks/Laje_Buonomano_2013_robust_timing_rnn.html)
+- [*(Song, et al., 2016)*: Training excitatory-inhibitory recurrent network](https://brainpy-examples.readthedocs.io/en/latest/recurrent_networks/Song_2016_EI_RNN.html)
+- **[Working Memory]** [*(Masse, et al., 2019)*: RNN with STP for Working Memory](https://brainpy-examples.readthedocs.io/en/latest/recurrent_networks/Masse_2019_STP_RNN.html)

 ### Low-dimension dynamics analysis

-- [Phase plane analysis of the I<sub>Na,p</sub>-I<sub>K</sub> model](https://brainmodels.readthedocs.io/en/latest/tutorials/dynamics_analysis/NaK_model_analysis.html)
-- [Codimension 1 bifurcation analysis of FitzHugh Nagumo model](https://brainmodels.readthedocs.io/en/latest/tutorials/dynamics_analysis/FitzHugh_Nagumo_analysis.html)
-- [Codimension 2 bifurcation analysis of FitzHugh Nagumo model](https://brainmodels.readthedocs.io/en/latest/tutorials/dynamics_analysis/FitzHugh_Nagumo_analysis.html#Codimension-2-bifurcation-analysis)
+- [1D system bifurcation](https://brainmodels.readthedocs.io/en/latest/low_dim_analysis/1D_system_bifurcation.html)
+- [Codimension 1 bifurcation analysis of FitzHugh Nagumo model](https://brainpy-examples.readthedocs.io/en/latest/low_dim_analysis/FitzHugh_Nagumo_analysis.html)
+- [Codimension 2 bifurcation analysis of FitzHugh Nagumo model](https://brainpy-examples.readthedocs.io/en/latest/low_dim_analysis/FitzHugh_Nagumo_analysis.html#Codimension-2-bifurcation-analysis)
 - **[Decision Making Model]** [*(Wong & Wang, 2006)* Decision making rate model](https://brainpy-examples.readthedocs.io/en/latest/decision_making/Wang_2006_decision_making_rate.html)
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-

-__version__ = "1.1.4"
+__version__ = "1.1.5"

 # "base" module
@@ -35,6 +35,7 @@ from .simulation import connect
 from .simulation import initialize
 from .simulation import inputs
 from .simulation import measure
+init = initialize

 # "analysis" module
@@ -165,31 +165,34 @@ def find_root_of_1d(f, f_points, args=(), tol=1e-8):
   """
   vals = f(f_points, *args)
   fs_len = len(f_points)
-  fs_sign = np.sign(vals)
+  signs = np.sign(vals)
   roots = []
-  fl_sign = fs_sign[0]
-  f_i = 1
-  while f_i < fs_len and fl_sign == 0.:
-    roots.append(f_points[f_i - 1])
-    fl_sign = fs_sign[f_i]
-    f_i += 1
-  while f_i < fs_len:
-    fr_sign = fs_sign[f_i]
-    if fr_sign == 0.:
-      roots.append(f_points[f_i])
-      if f_i + 1 < fs_len:
-        fl_sign = fs_sign[f_i + 1]
+  sign_l = signs[0]
+  point_l = f_points[0]
+  idx = 1
+  while idx < fs_len and sign_l == 0.:
+    roots.append(f_points[idx - 1])
+    sign_l = signs[idx]
+    idx += 1
+  while idx < fs_len:
+    sign_r = signs[idx]
+    point_r = f_points[idx]
+    if sign_r == 0.:
+      roots.append(point_r)
+      if idx + 1 < fs_len:
+        sign_l = sign_r
+        point_l = point_r
       else:
         break
-      f_i += 2
+      idx += 1
     else:
-      if not np.isnan(fr_sign) and fl_sign != fr_sign:
-        root, funcalls, itr = brentq(f, f_points[f_i - 1], f_points[f_i], args)
-        if abs(f(root, *args)) < tol:
-          roots.append(root)
-      fl_sign = fr_sign
-      f_i += 1
+      if not np.isnan(sign_r) and sign_l != sign_r:
+        root, funcalls, itr = brentq(f, point_l, point_r, args)
+        if abs(f(root, *args)) < tol: roots.append(root)
+      sign_l = sign_r
+      point_l = point_r
+      idx += 1
   return roots
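
The hunk above is a pure rename (`fs_sign`→`signs`, `f_i`→`idx`) plus keeping the left bracket point in `point_l` instead of re-indexing `f_points`; the algorithm itself is unchanged: scan the sampled signs for a flip, then polish each bracketing interval with Brent's method. A minimal standalone sketch of the same technique, using `scipy.optimize.brentq` directly rather than BrainPy's internal wrapper:

```python
import numpy as np
from scipy.optimize import brentq

def find_roots_1d(f, xs, tol=1e-8):
  """Roots of f on the sampled grid xs, via sign-change bracketing."""
  signs = np.sign(f(xs))
  roots = [x for x, s in zip(xs, signs) if s == 0.]  # exact zeros on the grid
  for i in range(1, len(xs)):
    # A sign flip between neighbouring samples brackets at least one root.
    if not np.isnan(signs[i]) and signs[i - 1] * signs[i] < 0:
      root = brentq(f, xs[i - 1], xs[i])
      if abs(f(root)) < tol:
        roots.append(root)
  return roots

print(find_roots_1d(np.cos, np.linspace(0., 10., 201)))  # ~ [pi/2, 3*pi/2, 5*pi/2]
```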
@@ -244,6 +247,7 @@ def find_root_of_2d(f, x_bound, y_bound, args=(), shgo_args=None,
   res : tuple
     The roots.
   """
+  print('Using scipy.optimize.shgo to solve fixed points.')
   if shgo is None:
     raise errors.PackageMissingError('Package "scipy" must be installed when the users '
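
The added `print` makes it explicit that `find_root_of_2d` delegates to `scipy.optimize.shgo`, which locates 2D fixed points by globally minimising the squared residual of the two right-hand sides. A hedged sketch of that idea on a toy system (the system is illustrative, not from BrainPy):

```python
import numpy as np
from scipy.optimize import shgo

# Toy system: dx/dt = x - x**3, dy/dt = -y, fixed points (0, 0) and (+-1, 0).
def residual(v):
  x, y = v
  return (x - x ** 3) ** 2 + y ** 2  # zero exactly at the fixed points

res = shgo(residual, bounds=[(-2., 2.), (-2., 2.)])
# res.xl lists every local minimum found; keep those with ~zero residual.
fixed_points = [p for p in res.xl if residual(p) < 1e-8]
print(fixed_points)
```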
@@ -287,8 +287,8 @@ class Base1DSymAnalyzer(BaseSymAnalyzer):
     sympy_failed = True
     if not self.options.escape_sympy_solver and not x_eq.contain_unknown_func:
       try:
-        logger.info(f'SymPy solve derivative of "{self.x_eq_group.func_name}'
-                    f'({argument})" by "{x_var}", ')
+        logger.warning(f'SymPy solve derivative of "{self.x_eq_group.func_name}'
+                       f'({argument})" by "{x_var}", ')
         x_eq = x_eq.expr
         f = utils.timeout(time_out)(lambda: sympy.diff(x_eq, x_symbol))
         dfxdx_expr = f()
@@ -297,10 +297,10 @@ class Base1DSymAnalyzer(BaseSymAnalyzer):
         all_vars = set(eq_x_scope.keys())
         all_vars.update(self.dvar_names + self.dpar_names)
         if utils.contain_unknown_symbol(analysis_by_sympy.sympy2str(dfxdx_expr), all_vars):
-          logger.info('\tfailed because contain unknown symbols.')
+          logger.warning('\tfailed because contain unknown symbols.')
           sympy_failed = True
         else:
-          logger.info('\tsuccess.')
+          logger.warning('\tsuccess.')
           func_codes = [f'def dfdx({argument}):']
           for expr in self.x_eq_group.sub_exprs[:-1]:
             func_codes.append(f'{expr.var_name} = {expr.code}')
@@ -309,9 +309,9 @@ class Base1DSymAnalyzer(BaseSymAnalyzer):
           dfdx = eq_x_scope['dfdx']
           sympy_failed = False
       except KeyboardInterrupt:
-        logger.info(f'\tfailed because {time_out} s timeout.')
+        logger.warning(f'\tfailed because {time_out} s timeout.')
       except NotImplementedError:
-        logger.info('\tfailed because the equation is too complex.')
+        logger.warning('\tfailed because the equation is too complex.')
     if sympy_failed:
       scope = dict(_fx=self.get_f_dx(), perturb=self.options.perturbation, math=math)
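
This and the many similar hunks that follow all make the same change: `logger.info` becomes `logger.warning`. Under Python's default logging configuration only records at WARNING and above reach the console, so these analysis progress messages were previously silent. A small sketch of the visibility difference (the logger name `'brainpy'` is an assumption for illustration):

```python
import logging

logger = logging.getLogger('brainpy')  # name assumed, for illustration only
logger.info('hidden: below the default WARNING threshold')
logger.warning('shown: meets the default threshold')

# Anyone preferring the old, quieter behaviour can raise the threshold back:
logging.getLogger('brainpy').setLevel(logging.ERROR)
```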
@@ -349,8 +349,8 @@ class Base1DSymAnalyzer(BaseSymAnalyzer):
     sympy_failed = True
     if not self.options.escape_sympy_solver and not x_eq.contain_unknown_func:
       try:
-        logger.info(f'SymPy solve "{self.x_eq_group.func_name}({argument1}) = 0" '
-                    f'to "{self.x_var} = f({argument2})", ')
+        logger.warning(f'SymPy solve "{self.x_eq_group.func_name}({argument1}) = 0" '
+                       f'to "{self.x_var} = f({argument2})", ')
         # solver
         f = utils.timeout(timeout_len)(
@@ -360,11 +360,11 @@ class Base1DSymAnalyzer(BaseSymAnalyzer):
           all_vars = set(scope.keys())
           all_vars.update(self.dvar_names + self.dpar_names)
           if utils.contain_unknown_symbol(analysis_by_sympy.sympy2str(res), all_vars):
-            logger.info('\tfailed because contain unknown symbols.')
+            logger.warning('\tfailed because contain unknown symbols.')
             sympy_failed = True
             break
           else:
-            logger.info('\tsuccess.')
+            logger.warning('\tsuccess.')
             # function codes
             func_codes = [f'def solve_x({argument2}):']
             for expr in self.x_eq_group.sub_exprs[:-1]:
@@ -379,10 +379,10 @@ class Base1DSymAnalyzer(BaseSymAnalyzer):
             self.analyzed_results['fixed_point'] = scope['solve_x']
             sympy_failed = False
       except NotImplementedError:
-        logger.info('\tfailed because the equation is too complex.')
+        logger.warning('\tfailed because the equation is too complex.')
         sympy_failed = True
       except KeyboardInterrupt:
-        logger.info(f'\tfailed because {timeout_len} s timeout.')
+        logger.warning(f'\tfailed because {timeout_len} s timeout.')
         sympy_failed = True
     if sympy_failed:
@@ -531,8 +531,8 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
     sympy_failed = True
     if not self.options.escape_sympy_solver and not x_eq.contain_unknown_func:
       try:
-        logger.info(f'SymPy solve derivative of "{self.x_eq_group.func_name}'
-                    f'({argument})" by "{y_var}", ')
+        logger.warning(f'SymPy solve derivative of "{self.x_eq_group.func_name}'
+                       f'({argument})" by "{y_var}", ')
         x_eq = x_eq.expr
         f = utils.timeout(time_out)(lambda: sympy.diff(x_eq, y_symbol))
         dfxdy_expr = f()
@@ -541,10 +541,10 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
         all_vars = set(eq_x_scope.keys())
         all_vars.update(self.dvar_names + self.dpar_names)
         if utils.contain_unknown_symbol(analysis_by_sympy.sympy2str(dfxdy_expr), all_vars):
-          logger.info('\tfailed because contain unknown symbols.')
+          logger.warning('\tfailed because contain unknown symbols.')
           sympy_failed = True
         else:
-          logger.info('\tsuccess.')
+          logger.warning('\tsuccess.')
           func_codes = [f'def dfdy({argument}):']
           for expr in self.x_eq_group.sub_exprs[:-1]:
             func_codes.append(f'{expr.var_name} = {expr.code}')
@@ -553,9 +553,9 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
           dfdy = eq_x_scope['dfdy']
           sympy_failed = False
       except KeyboardInterrupt:
-        logger.info(f'\tfailed because {time_out} s timeout.')
+        logger.warning(f'\tfailed because {time_out} s timeout.')
       except NotImplementedError:
-        logger.info('\tfailed because the equation is too complex.')
+        logger.warning('\tfailed because the equation is too complex.')
     if sympy_failed:
       scope = dict(_fx=self.get_f_dx(), perturb=self.options.perturbation, math=math)
@@ -595,8 +595,8 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
     sympy_failed = True
     if not self.options.escape_sympy_solver and not y_eq.contain_unknown_func:
       try:
-        logger.info(f'SymPy solve derivative of "{self.y_eq_group.func_name}'
-                    f'({argument})" by "{x_var}", ')
+        logger.warning(f'SymPy solve derivative of "{self.y_eq_group.func_name}'
+                       f'({argument})" by "{x_var}", ')
         y_eq = y_eq.expr
         f = utils.timeout(time_out)(lambda: sympy.diff(y_eq, x_symbol))
         dfydx_expr = f()
@@ -605,10 +605,10 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
         all_vars = set(eq_y_scope.keys())
         all_vars.update(self.dvar_names + self.dpar_names)
         if utils.contain_unknown_symbol(analysis_by_sympy.sympy2str(dfydx_expr), all_vars):
-          logger.info('\tfailed because contain unknown symbols.')
+          logger.warning('\tfailed because contain unknown symbols.')
           sympy_failed = True
         else:
-          logger.info('\tsuccess.')
+          logger.warning('\tsuccess.')
           func_codes = [f'def dgdx({argument}):']
           for expr in self.y_eq_group.sub_exprs[:-1]:
             func_codes.append(f'{expr.var_name} = {expr.code}')
@@ -617,9 +617,9 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
           dgdx = eq_y_scope['dgdx']
           sympy_failed = False
       except KeyboardInterrupt:
-        logger.info(f'\tfailed because {time_out} s timeout.')
+        logger.warning(f'\tfailed because {time_out} s timeout.')
       except NotImplementedError:
-        logger.info('\tfailed because the equation is too complex.')
+        logger.warning('\tfailed because the equation is too complex.')
     if sympy_failed:
       scope = dict(_fy=self.get_f_dy(), perturb=self.options.perturbation, math=math)
@@ -660,8 +660,8 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
     sympy_failed = True
     if not self.options.escape_sympy_solver and not y_eq.contain_unknown_func:
       try:
-        logger.info(f'\tSymPy solve derivative of "{self.y_eq_group.func_name}'
-                    f'({argument})" by "{y_var}", ')
+        logger.warning(f'\tSymPy solve derivative of "{self.y_eq_group.func_name}'
+                       f'({argument})" by "{y_var}", ')
         y_eq = y_eq.expr
         f = utils.timeout(time_out)(lambda: sympy.diff(y_eq, y_symbol))
         dfydx_expr = f()
@@ -670,10 +670,10 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
         all_vars = set(eq_y_scope.keys())
         all_vars.update(self.dvar_names + self.dpar_names)
         if utils.contain_unknown_symbol(analysis_by_sympy.sympy2str(dfydx_expr), all_vars):
-          logger.info('\tfailed because contain unknown symbols.')
+          logger.warning('\tfailed because contain unknown symbols.')
           sympy_failed = True
         else:
-          logger.info('\tsuccess.')
+          logger.warning('\tsuccess.')
           func_codes = [f'def dgdy({argument}):']
           for expr in self.y_eq_group.sub_exprs[:-1]:
             func_codes.append(f'{expr.var_name} = {expr.code}')
@@ -682,9 +682,9 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
           dgdy = eq_y_scope['dgdy']
           sympy_failed = False
       except KeyboardInterrupt:
-        logger.info(f'\tfailed because {time_out} s timeout.')
+        logger.warning(f'\tfailed because {time_out} s timeout.')
       except NotImplementedError:
-        logger.info('\tfailed because the equation is too complex.')
+        logger.warning('\tfailed because the equation is too complex.')
     if sympy_failed:
       scope = dict(_fy=self.get_f_dy(), perturb=self.options.perturbation, math=math)
@@ -1065,9 +1065,9 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
       timeout_len = self.options.sympy_solver_timeout
       try:
-        logger.info(f'SymPy solve "{self.y_eq_group.func_name}({argument}) = 0" to '
-                    f'"{self.y_var} = f({self.x_var}, '
-                    f'{",".join(self.dvar_names[2:] + self.dpar_names)})", ')
+        logger.warning(f'SymPy solve "{self.y_eq_group.func_name}({argument}) = 0" to '
+                       f'"{self.y_var} = f({self.x_var}, '
+                       f'{",".join(self.dvar_names[2:] + self.dpar_names)})", ')
         # solve the expression
         f = utils.timeout(timeout_len)(lambda: sympy.solve(y_eq, y_symbol))
         y_by_x_in_y_eq = f()
@@ -1079,10 +1079,10 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
         all_vars = set(eq_y_scope.keys())
         all_vars.update(self.dvar_names + self.dpar_names)
         if utils.contain_unknown_symbol(y_by_x_in_y_eq, all_vars):
-          logger.info('\tfailed because contain unknown symbols.')
+          logger.warning('\tfailed because contain unknown symbols.')
           results['status'] = 'sympy_failed'
         else:
-          logger.info('\tsuccess.')
+          logger.warning('\tsuccess.')
           # substituted codes
           subs_codes = [f'{expr.var_name} = {expr.code}'
                         for expr in self.y_eq_group.sub_exprs[:-1]]
@@ -1101,10 +1101,10 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
           results['f'] = eq_y_scope['func']
       except NotImplementedError:
-        logger.info('\tfailed because the equation is too complex.')
+        logger.warning('\tfailed because the equation is too complex.')
         results['status'] = 'sympy_failed'
       except KeyboardInterrupt:
-        logger.info(f'\tfailed because {timeout_len} s timeout.')
+        logger.warning(f'\tfailed because {timeout_len} s timeout.')
         results['status'] = 'sympy_failed'
     else:
       results['status'] = 'escape'
@@ -1139,9 +1139,9 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
       timeout_len = self.options.sympy_solver_timeout
      try:
-        logger.info(f'SymPy solve "{self.x_eq_group.func_name}({argument}) = 0" to '
-                    f'"{self.y_var} = f({self.x_var}, '
-                    f'{",".join(self.dvar_names[2:] + self.dpar_names)})", ')
+        logger.warning(f'SymPy solve "{self.x_eq_group.func_name}({argument}) = 0" to '
+                       f'"{self.y_var} = f({self.x_var}, '
+                       f'{",".join(self.dvar_names[2:] + self.dpar_names)})", ')
         # solve the expression
         f = utils.timeout(timeout_len)(lambda: sympy.solve(x_eq, y_symbol))
@@ -1153,10 +1153,10 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
         all_vars = set(eq_x_scope.keys())
         all_vars.update(self.dvar_names + self.dpar_names)
         if utils.contain_unknown_symbol(y_by_x_in_x_eq, all_vars):
-          logger.info('\tfailed because contain unknown symbols.')
+          logger.warning('\tfailed because contain unknown symbols.')
           results['status'] = 'sympy_failed'
         else:
-          logger.info('\tsuccess.')
+          logger.warning('\tsuccess.')
           # substituted codes
           subs_codes = [f'{expr.var_name} = {expr.code}'
@@ -1175,10 +1175,10 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
           results['subs'] = subs_codes
           results['f'] = eq_x_scope['func']
       except NotImplementedError:
-        logger.info('\tfailed because the equation is too complex.')
+        logger.warning('\tfailed because the equation is too complex.')
         results['status'] = 'sympy_failed'
       except KeyboardInterrupt:
-        logger.info(f'\tfailed because {timeout_len} s timeout.')
+        logger.warning(f'\tfailed because {timeout_len} s timeout.')
         results['status'] = 'sympy_failed'
     else:
       results['status'] = 'escape'
@@ -1213,8 +1213,8 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
       timeout_len = self.options.sympy_solver_timeout
       try:
-        logger.info(f'SymPy solve "{self.y_eq_group.func_name}({argument}) = 0" to '
-                    f'"{self.x_var} = f({",".join(self.dvar_names[1:] + self.dpar_names)})", ')
+        logger.warning(f'SymPy solve "{self.y_eq_group.func_name}({argument}) = 0" to '
+                       f'"{self.x_var} = f({",".join(self.dvar_names[1:] + self.dpar_names)})", ')
         # solve the expression
         f = utils.timeout(timeout_len)(lambda: sympy.solve(y_eq, x_symbol))
         x_by_y_in_y_eq = f()
@@ -1226,10 +1226,10 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
         all_vars = set(eq_y_scope.keys())
         all_vars.update(self.dvar_names + self.dpar_names)
         if utils.contain_unknown_symbol(x_by_y_in_y_eq, all_vars):
-          logger.info('\tfailed because contain unknown symbols.')
+          logger.warning('\tfailed because contain unknown symbols.')
           results['status'] = 'sympy_failed'
         else:
-          logger.info('\tsuccess.')
+          logger.warning('\tsuccess.')
           # substituted codes
           subs_codes = [f'{expr.var_name} = {expr.code}'
@@ -1248,10 +1248,10 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
           results['subs'] = subs_codes
           results['f'] = eq_y_scope['func']
       except NotImplementedError:
-        logger.info('\tfailed because the equation is too complex.')
+        logger.warning('\tfailed because the equation is too complex.')
         results['status'] = 'sympy_failed'
       except KeyboardInterrupt:
-        logger.info(f'\tfailed because {timeout_len} s timeout.')
+        logger.warning(f'\tfailed because {timeout_len} s timeout.')
         results['status'] = 'sympy_failed'
     else:
       results['status'] = 'escape'
@@ -1286,8 +1286,8 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
       timeout_len = self.options.sympy_solver_timeout
       try:
-        logger.info(f'SymPy solve "{self.x_eq_group.func_name}({argument}) = 0" to '
-                    f'"{self.x_var} = f({",".join(self.dvar_names[1:] + self.dpar_names)})", ')
+        logger.warning(f'SymPy solve "{self.x_eq_group.func_name}({argument}) = 0" to '
+                       f'"{self.x_var} = f({",".join(self.dvar_names[1:] + self.dpar_names)})", ')
         # solve the expression
         f = utils.timeout(timeout_len)(lambda: sympy.solve(x_eq, x_symbol))
         x_by_y_in_x_eq = f()
@@ -1299,10 +1299,10 @@ class Base2DSymAnalyzer(Base1DSymAnalyzer):
         all_vars = set(eq_x_scope.keys())
         all_vars.update(self.dvar_names + self.dpar_names)
         if utils.contain_unknown_symbol(x_by_y_in_x_eq, all_vars):
-          logger.info('\tfailed because contain unknown symbols.')
+          logger.warning('\tfailed because contain unknown symbols.')
           results['status'] = 'sympy_failed'
         else:
-          logger.info('\tsuccess.')
+          logger.warning('\tsuccess.')
           # substituted codes
           subs_codes = [f'{expr.var_name} = {expr.code}'
results['subs'] = subs_codes | results['subs'] = subs_codes | ||||
results['f'] = eq_x_scope['func'] | results['f'] = eq_x_scope['func'] | ||||
except NotImplementedError: | except NotImplementedError: | ||||
logger.info('\tfailed because the equation is too complex.') | |||||
logger.warning('\tfailed because the equation is too complex.') | |||||
results['status'] = 'sympy_failed' | results['status'] = 'sympy_failed' | ||||
except KeyboardInterrupt: | except KeyboardInterrupt: | ||||
logger.info(f'\tfailed because {timeout_len} s timeout.') | |||||
logger.warning(f'\tfailed because {timeout_len} s timeout.') | |||||
results['status'] = 'sympy_failed' | results['status'] = 'sympy_failed' | ||||
else: | else: | ||||
results['status'] = 'escape' | results['status'] = 'escape' | ||||
@@ -214,7 +214,7 @@ class _Bifurcation1D(base.Base1DSymAnalyzer):
                          options=options)

   def plot_bifurcation(self, show=False):
-    logger.info('plot bifurcation ...')
+    logger.warning('plot bifurcation ...')

     f_fixed_point = self.get_f_fixed_point()
     f_dfdx = self.get_f_dfdx()
@@ -316,7 +316,7 @@ class _Bifurcation2D(base.Base2DSymAnalyzer):
     self.fixed_points = None

   def plot_bifurcation(self, show=False):
-    logger.info('plot bifurcation ...')
+    logger.warning('plot bifurcation ...')

     # functions
     f_fixed_point = self.get_f_fixed_point()
@@ -405,7 +405,7 @@ class _Bifurcation2D(base.Base2DSymAnalyzer):
     return container

   def plot_limit_cycle_by_sim(self, var, duration=100, inputs=(), plot_style=None, tol=0.001, show=False):
-    logger.info('plot limit cycle ...')
+    logger.warning('plot limit cycle ...')

     if self.fixed_points is None:
       raise errors.AnalyzerError('Please call "plot_bifurcation()" before "plot_limit_cycle_by_sim()".')
@@ -773,7 +773,7 @@ class _FastSlowTrajectory(object):
     show : bool
       Whether show or not.
     """
-    logger.info('plot trajectory ...')
+    logger.warning('plot trajectory ...')

     # 1. format the initial values
     all_vars = self.fast_var_names + self.slow_var_names
@@ -228,7 +228,7 @@ class _PhasePlane1D(base.Base1DSymAnalyzer):
     results : np.ndarray
       The dx values.
     """
-    logger.info('plot vector field ...')
+    logger.warning('plot vector field ...')

     # 1. Nullcline of the x variable
     try:
@@ -265,7 +265,7 @@ class _PhasePlane1D(base.Base1DSymAnalyzer):
     points : np.ndarray
       The fixed points.
     """
-    logger.info('plot fixed point ...')
+    logger.warning('plot fixed point ...')

     # 1. functions
     f_fixed_point = self.get_f_fixed_point()
@@ -278,7 +278,7 @@ class _PhasePlane1D(base.Base1DSymAnalyzer):
       x = x_values[i]
       dfdx = f_dfdx(x)
       fp_type = stability.stability_analysis(dfdx)
-      logger.info(f"Fixed point #{i + 1} at {self.x_var}={x} is a {fp_type}.")
+      logger.warning(f"Fixed point #{i + 1} at {self.x_var}={x} is a {fp_type}.")
       container[fp_type].append(x)

     # 3. visualization
@@ -330,7 +330,7 @@ class _PhasePlane2D(base.Base2DSymAnalyzer):
     result : tuple
       The ``dx``, ``dy`` values.
     """
-    logger.info('plot vector field ...')
+    logger.warning('plot vector field ...')

     if plot_style is None:
       plot_style = dict()
@@ -398,7 +398,7 @@ class _PhasePlane2D(base.Base2DSymAnalyzer):
     results : tuple
       The value points.
     """
-    logger.info('plot fixed point ...')
+    logger.warning('plot fixed point ...')

     # function for fixed point solving
     f_fixed_point = self.get_f_fixed_point()
@@ -414,7 +414,7 @@ class _PhasePlane2D(base.Base2DSymAnalyzer):
       x = x_values[i]
       y = y_values[i]
       fp_type = stability.stability_analysis(f_jacobian(x, y))
-      logger.info(f"Fixed point #{i + 1} at {self.x_var}={x}, {self.y_var}={y} is a {fp_type}.")
+      logger.warning(f"Fixed point #{i + 1} at {self.x_var}={x}, {self.y_var}={y} is a {fp_type}.")
       container[fp_type]['x'].append(x)
       container[fp_type]['y'].append(y)
@@ -453,7 +453,7 @@ class _PhasePlane2D(base.Base2DSymAnalyzer):
     values : dict
       A dict with the format of ``{func1: (x_val, y_val), func2: (x_val, y_val)}``.
     """
-    logger.info('plot nullcline ...')
+    logger.warning('plot nullcline ...')

     if numerical_setting is None:
       numerical_setting = dict()
@@ -579,7 +579,7 @@ class _PhasePlane2D(base.Base2DSymAnalyzer):
       Whether show or not.
     """
-    logger.info('plot trajectory ...')
+    logger.warning('plot trajectory ...')

     if axes not in ['v-v', 't-v']:
       raise errors.BrainPyError(f'Unknown axes "{axes}", only support "v-v" and "t-v".')
@@ -687,7 +687,7 @@ class _PhasePlane2D(base.Base2DSymAnalyzer):
     show : bool
       Whether show or not.
     """
-    logger.info('plot limit cycle ...')
+    logger.warning('plot limit cycle ...')

     # 1. format the initial values
     if isinstance(initials, dict):
@@ -732,7 +732,7 @@ class _PhasePlane2D(base.Base2DSymAnalyzer):
         lines = plt.plot(x_cycle, y_cycle, label='limit cycle')
         utils.add_arrow(lines[0])
       else:
-        logger.info(f'No limit cycle found for initial value {initial}')
+        logger.warning(f'No limit cycle found for initial value {initial}')

     # 6. visualization
     plt.xlabel(self.x_var)
@@ -50,6 +50,8 @@ class Collector(dict):
   >>> import brainpy as bp
   >>>
+  >>> some_collector = Collector()
+  >>>
   >>> # get all trainable variables
   >>> some_collector.subset(bp.math.TrainVar)
   >>>
@@ -59,7 +61,7 @@ class Collector(dict):
   or, it can be used to get a subset of integrators:

   >>> # get all ODE integrators
-  >>> some_collector.subset(bp.integrators.ODE_INT)
+  >>> some_collector.subset(bp.ode.ODEIntegrator)

   Parameters
   ----------
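
The doctest now constructs `some_collector` before using it and points `subset` at `bp.ode.ODEIntegrator` instead of the old `bp.integrators.ODE_INT` constant. For context, a hedged sketch of how a collector is typically obtained and filtered (assuming the JAX backend, where `TrainVar`/`Variable` live in `brainpy.math`; the `Base` import path is taken from the hunk below):

```python
import brainpy.math as bm
from brainpy.base.base import Base

class MyModule(Base):
  def __init__(self):
    super(MyModule, self).__init__()
    self.w = bm.TrainVar(bm.zeros((3, 3)))  # trainable weight
    self.state = bm.Variable(bm.zeros(3))   # plain state, not trainable

m = MyModule()
print(m.vars().subset(bm.TrainVar))  # only 'w' survives the type filter
```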
@@ -4,7 +4,7 @@ from brainpy import errors
 from brainpy.base.base import Base
 from brainpy.base import collector

-ndarray = None
+math = None

 __all__ = [
   'Function',
def _check_var(var): | def _check_var(var): | ||||
global ndarray | |||||
if ndarray is None: from brainpy.math import ndarray | |||||
if not isinstance(var, ndarray): | |||||
global math | |||||
if math is None: from brainpy import math | |||||
if not isinstance(var, math.ndarray): | |||||
raise errors.BrainPyError(f'Element in "dyn_vars" must be an instance of ' | raise errors.BrainPyError(f'Element in "dyn_vars" must be an instance of ' | ||||
f'{ndarray.__name__}, but we got {type(var)}.') | |||||
f'{math.ndarray.__name__}, but we got {type(var)}.') | |||||
class Function(Base): | class Function(Base): | ||||
@@ -70,9 +70,9 @@ class Function(Base):
     # ---
     if dyn_vars is not None:
       self.implicit_vars = collector.TensorCollector()
-      global ndarray
-      if ndarray is None: from brainpy.math import ndarray
-      if isinstance(dyn_vars, ndarray):
+      global math
+      if math is None: from brainpy import math
+      if isinstance(dyn_vars, math.ndarray):
         dyn_vars = (dyn_vars,)
       if isinstance(dyn_vars, (tuple, list)):
         for i, v in enumerate(dyn_vars):
@@ -83,7 +83,7 @@ class Function(Base):
           _check_var(v)
         self.implicit_vars.update(dyn_vars)
       else:
-        raise ValueError(f'"dyn_vars" only support list/tuple/dict of {ndarray.__name__}, '
+        raise ValueError(f'"dyn_vars" only support list/tuple/dict of {math.ndarray.__name__}, '
                          f'but we got {type(dyn_vars)}: {dyn_vars}')

   def __call__(self, *args, **kwargs):
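
Both hunks above swap a lazily imported `ndarray` class for a lazily imported `brainpy.math` module, so the checks can reference `math.ndarray` (and anything else in `math`) through one cached global. A stripped-down sketch of this deferred-import pattern (error type simplified):

```python
# Module-level cache: the dependency is resolved on first use rather than at
# import time, which breaks the circular import between the two modules.
math = None

def _check_var(var):
  global math
  if math is None:
    from brainpy import math  # runs once; later calls reuse the cached module
  if not isinstance(var, math.ndarray):
    raise TypeError(f'expected {math.ndarray.__name__}, but got {type(var)}')
```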
@@ -61,8 +61,8 @@ def bernoulli(p, size=None):
   return numpy.random.binomial(1, p=p, size=size)


-def truncated_normal():
-  raise NotImplementedError
+def truncated_normal(lower, upper, size, scale=1.):
+  raise NotImplementedError('Please use `brainpy.math.jax.random.truncated_normal()`')


 @tools.numba_jit
@@ -8,15 +8,25 @@ from .base import Initializer

 __all__ = [
   'Normal',
   'Uniform',
-  'Orthogonal',
+  'VarianceScaling',
+  'KaimingUniform',
   'KaimingNormal',
-  'KaimingNormalTruncated',
+  'XavierUniform',
   'XavierNormal',
-  'XavierNormalTruncated',
-  'TruncatedNormal',
+  'LecunUniform',
+  'LecunNormal',
+  'Orthogonal',
+  'DeltaOrthogonal',
 ]


+def _compute_fans(shape, in_axis=-2, out_axis=-1):
+  receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
+  fan_in = shape[in_axis] * receptive_field_size
+  fan_out = shape[out_axis] * receptive_field_size
+  return fan_in, fan_out
+
+
 class Normal(Initializer):
   """Initialize weights with normal distribution.
@@ -26,12 +36,13 @@ class Normal(Initializer):
     The gain of the deviation of the normal distribution.
   """

-  def __init__(self, gain=1.):
+  def __init__(self, scale=1.):
     super(Normal, self).__init__()
-    self.gain = gain
+    self.scale = scale

   def __call__(self, shape, dtype=None):
-    weights = math.random.normal(size=shape, scale=self.gain * np.sqrt(1 / np.prod(shape)))
+    weights = math.random.normal(size=shape, scale=self.scale)
     return math.asarray(weights, dtype=dtype)
@@ -46,41 +57,117 @@ class Uniform(Initializer):
     The upper limit of the uniform distribution.
   """

-  def __init__(self, min_val=0., max_val=1.):
+  def __init__(self, min_val=0., max_val=1., scale=1e-2):
     super(Uniform, self).__init__()
     self.min_val = min_val
     self.max_val = max_val
+    self.scale = scale

   def __call__(self, shape, dtype=None):
     r = math.random.uniform(low=self.min_val, high=self.max_val, size=shape)
-    return math.asarray(r, dtype=dtype)
+    return math.asarray(r * self.scale, dtype=dtype)
+
+
+class VarianceScaling(Initializer):
+  def __init__(self, scale, mode, distribution, in_axis=-2, out_axis=-1):
+    self.scale = scale
+    self.mode = mode
+    self.in_axis = in_axis
+    self.out_axis = out_axis
+    self.distribution = distribution
+
+  def __call__(self, shape, dtype=None):
+    fan_in, fan_out = _compute_fans(shape, in_axis=self.in_axis, out_axis=self.out_axis)
+    if self.mode == "fan_in":
+      denominator = fan_in
+    elif self.mode == "fan_out":
+      denominator = fan_out
+    elif self.mode == "fan_avg":
+      denominator = (fan_in + fan_out) / 2
+    else:
+      raise ValueError("invalid mode for variance scaling initializer: {}".format(self.mode))
+    variance = math.array(self.scale / denominator, dtype=dtype)
+    if self.distribution == "truncated_normal":
+      # constant is stddev of standard normal truncated to (-2, 2)
+      stddev = math.sqrt(variance) / math.array(.87962566103423978, dtype)
+      res = math.random.truncated_normal(-2, 2, shape) * stddev
+      return math.asarray(res, dtype=dtype)
+    elif self.distribution == "normal":
+      res = math.random.normal(size=shape) * math.sqrt(variance)
+      return math.asarray(res, dtype=dtype)
+    elif self.distribution == "uniform":
+      res = math.random.uniform(low=-1, high=1, size=shape) * math.sqrt(3 * variance)
+      return math.asarray(res, dtype=dtype)
+    else:
+      raise ValueError("invalid distribution for variance scaling initializer")
+
+
+class KaimingUniform(VarianceScaling):
+  def __init__(self, scale=2.0, mode="fan_in",
+               distribution="uniform",
+               in_axis=-2, out_axis=-1):
+    super(KaimingUniform, self).__init__(scale, mode, distribution,
+                                         in_axis=in_axis,
+                                         out_axis=out_axis)
+
+
+class KaimingNormal(VarianceScaling):
+  def __init__(self, scale=2.0, mode="fan_in",
+               distribution="truncated_normal",
+               in_axis=-2, out_axis=-1):
+    super(KaimingNormal, self).__init__(scale, mode, distribution,
+                                        in_axis=in_axis,
+                                        out_axis=out_axis)
+
+
+class XavierUniform(VarianceScaling):
+  def __init__(self, scale=1.0, mode="fan_avg",
+               distribution="uniform",
+               in_axis=-2, out_axis=-1):
+    super(XavierUniform, self).__init__(scale, mode, distribution,
+                                        in_axis=in_axis,
+                                        out_axis=out_axis)
+
+
+class XavierNormal(VarianceScaling):
+  def __init__(self, scale=1.0, mode="fan_avg",
+               distribution="truncated_normal",
+               in_axis=-2, out_axis=-1):
+    super(XavierNormal, self).__init__(scale, mode, distribution,
+                                       in_axis=in_axis,
+                                       out_axis=out_axis)
+
+
+class LecunUniform(VarianceScaling):
+  def __init__(self, scale=1.0, mode="fan_in",
+               distribution="uniform",
+               in_axis=-2, out_axis=-1):
+    super(LecunUniform, self).__init__(scale, mode, distribution,
+                                       in_axis=in_axis,
+                                       out_axis=out_axis)
+
+
+class LecunNormal(VarianceScaling):
+  def __init__(self, scale=1.0, mode="fan_in",
+               distribution="truncated_normal",
+               in_axis=-2, out_axis=-1):
+    super(LecunNormal, self).__init__(scale, mode, distribution,
+                                      in_axis=in_axis,
+                                      out_axis=out_axis)
+
+
 class Orthogonal(Initializer):
-  """Returns a uniformly distributed orthogonal tensor from
-  `Exact solutions to the nonlinear dynamics of learning in deep linear neural networks
-  <https://openreview.net/forum?id=_wzZwKpTDF_9C>`_.
-
-  Args:
-    shape: shape of the output tensor.
-    gain: optional scaling factor.
-    axis: the orthogonalizarion axis
-
-  Returns:
-    An orthogonally initialized tensor.
-    These tensors will be row-orthonormal along the access specified by
-    ``axis``. If the rank of the weight is greater than 2, the shape will be
-    flattened in all other dimensions and then will be row-orthonormal along the
-    final dimension. Note that this only works if the ``axis`` dimension is
-    larger, otherwise the tensor will be transposed (equivalently, it will be
-    column orthonormal instead of row orthonormal).
-    If the shape is not square, the matrices will have orthonormal rows or
-    columns depending on which side is smaller.
-  """
-
-  def __init__(self, gain=1., axis=-1):
+  """
+  Construct an initializer for uniformly distributed orthogonal matrices.
+
+  If the shape is not square, the matrices will have orthonormal rows or columns
+  depending on which side is smaller.
+  """
+
+  def __init__(self, scale=1., axis=-1):
     super(Orthogonal, self).__init__()
-    self.gain = gain
+    self.scale = scale
     self.axis = axis

   def __call__(self, shape, dtype=None):
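
Every named initializer added above is a thin `VarianceScaling` wrapper that fixes the `(scale, mode, distribution)` triple; the sampled standard deviation is `sqrt(scale / denominator)`, with the denominator chosen by `mode`. A hedged usage sketch:

```python
# Kaiming-normal on a (128, 64) kernel: std ~ sqrt(2 / fan_in) = sqrt(2 / 128),
# drawn from a truncated normal and rescaled to undo the truncation's variance loss.
w1 = KaimingNormal()((128, 64))

# Xavier-uniform: uniform on +-sqrt(3 * 1 / fan_avg) = +-sqrt(3 / 96) ~ +-0.18.
w2 = XavierUniform()((128, 64))
```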
@@ -94,149 +181,36 @@ class Orthogonal(Initializer):
     if n_rows < n_cols: q_mat = q_mat.T
     q_mat = np.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))
     q_mat = np.moveaxis(q_mat, 0, self.axis)
-    return self.gain * math.asarray(q_mat, dtype=dtype)
+    return self.scale * math.asarray(q_mat, dtype=dtype)
-class KaimingNormal(Initializer):
-  """Returns a tensor with values assigned using Kaiming He normal initializer from
-  `Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification
-  <https://arxiv.org/abs/1502.01852>`_.
-
-  Args:
-    shape: shape of the output tensor.
-    gain: optional scaling factor.
-
-  Returns:
-    Tensor initialized with normal random variables with standard deviation (gain * kaiming_normal_gain).
-  """
-
-  def __init__(self, gain=1.):
-    self.gain = gain
-    super(KaimingNormal, self).__init__()
-
-  def __call__(self, shape, dtype=None):
-    gain = np.sqrt(1 / np.prod(shape[:-1]))
-    res = math.random.normal(size=shape, scale=self.gain * gain)
-    return math.asarray(res, dtype=dtype)
-
-
-class KaimingNormalTruncated(Initializer):
-  """Returns a tensor with values assigned using Kaiming He truncated normal initializer from
-  `Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification
-  <https://arxiv.org/abs/1502.01852>`_.
-
-  Args:
-    shape: shape of the output tensor.
-    lower: lower truncation of the normal.
-    upper: upper truncation of the normal.
-    gain: optional scaling factor.
-
-  Returns:
-    Tensor initialized with truncated normal random variables with standard
-    deviation (gain * kaiming_normal_gain) and support [lower, upper].
-  """
-
-  def __init__(self, lower=-2., upper=2., gain=1.):
-    self.lower = lower
-    self.upper = upper
-    self.gain = gain
-    super(KaimingNormalTruncated, self).__init__()
-
-  def __call__(self, shape, dtype=None):
-    truncated_std = scipy.stats.truncnorm.std(a=self.lower,
-                                              b=self.upper,
-                                              loc=0.,
-                                              scale=1.)
-    stddev = self.gain * np.sqrt(1 / np.prod(shape[:-1])) / truncated_std
-    res = math.random.truncated_normal(size=shape,
-                                       scale=stddev,
-                                       lower=self.lower,
-                                       upper=self.upper)
-    return math.asarray(res, dtype=dtype)
-
-
-class XavierNormal(Initializer):
-  """Returns a tensor with values assigned using Xavier Glorot normal initializer from
-  `Understanding the difficulty of training deep feedforward neural networks
-  <http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_.
-
-  Args:
-    shape: shape of the output tensor.
-    gain: optional scaling factor.
-
-  Returns:
-    Tensor initialized with normal random variables with standard deviation (gain * xavier_normal_gain).
-  """
-
-  def __init__(self, gain=1.):
-    super(XavierNormal, self).__init__()
-    self.gain = gain
+class DeltaOrthogonal(Initializer):
+  """
+  Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.
-
-  def __call__(self, shape, dtype=None):
-    fan_in, fan_out = np.prod(shape[:-1]), shape[-1]
-    gain = np.sqrt(2 / (fan_in + fan_out))
-    res = math.random.normal(size=shape, scale=self.gain * gain)
-    return math.asarray(res, dtype=dtype)
-
-
-class XavierNormalTruncated(Initializer):
-  """Returns a tensor with values assigned using Xavier Glorot truncated normal initializer from
-  `Understanding the difficulty of training deep feedforward neural networks
-  <http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_.
-
-  Args:
-    shape: shape of the output tensor.
-    lower: lower truncation of the normal.
-    upper: upper truncation of the normal.
-    gain: optional scaling factor.
-
-  Returns:
-    Tensor initialized with truncated normal random variables with standard
-    deviation (gain * xavier_normal_gain) and support [lower, upper].
-  """
-
-  def __init__(self, lower=-2., upper=2., gain=1.):
-    self.lower = lower
-    self.upper = upper
-    self.gain = gain
-    super(XavierNormalTruncated, self).__init__()
+
+  The shape must be 3D, 4D or 5D.
+  """
-
-  def __call__(self, shape, dtype=None):
-    truncated_std = scipy.stats.truncnorm.std(a=self.lower, b=self.upper, loc=0., scale=1)
-    fan_in, fan_out = np.prod(shape[:-1]), shape[-1]
-    gain = np.sqrt(2 / (fan_in + fan_out))
-    stddev = self.gain * gain / truncated_std
-    res = math.random.truncated_normal(size=shape,
-                                       scale=stddev,
-                                       lower=self.lower,
-                                       upper=self.upper)
-    return math.asarray(res, dtype=dtype)
-
-
-class TruncatedNormal(Initializer):
-  """Returns a tensor with values assigned using truncated normal initialization.
-
-  Args:
-    shape: shape of the output tensor.
-    lower: lower truncation of the normal.
-    upper: upper truncation of the normal.
-    stddev: expected standard deviation.
-
-  Returns:
-    Tensor initialized with truncated normal random variables with standard
-    deviation stddev and support [lower, upper].
-  """
-
-  def __init__(self, lower=-2., upper=2., scale=1.):
-    self.lower = lower
-    self.upper = upper
+  def __init__(self, scale=1.0, axis=-1, ):
+    super(DeltaOrthogonal, self).__init__()
     self.scale = scale
-    super(TruncatedNormal, self).__init__()
+    self.axis = axis

   def __call__(self, shape, dtype=None):
-    truncated_std = scipy.stats.truncnorm.std(a=self.lower, b=self.upper, loc=0., scale=1)
-    res = math.random.truncated_normal(size=shape,
-                                       scale=self.scale / truncated_std,
-                                       lower=self.lower,
-                                       upper=self.upper)
-    return math.asarray(res, dtype=dtype)
+    if len(shape) not in [3, 4, 5]:
+      raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D shape.")
+    if shape[-1] < shape[-2]:
+      raise ValueError("`fan_in` must be less than or equal to `fan_out`.")
+    ortho_init = Orthogonal(scale=self.scale, axis=self.axis)
+    ortho_matrix = ortho_init(shape[-2:], dtype=dtype)
+    W = math.zeros(shape, dtype=dtype)
+    if len(shape) == 3:
+      k = shape[0]
+      W[(k - 1) // 2, ...] = ortho_matrix
+    elif len(shape) == 4:
+      k1, k2 = shape[:2]
+      W[(k1 - 1) // 2, (k2 - 1) // 2, ...] = ortho_matrix
+    else:
+      k1, k2, k3 = shape[:3]
+      W[(k1 - 1) // 2, (k2 - 1) // 2, (k3 - 1) // 2, ...] = ortho_matrix
+    return W
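
`DeltaOrthogonal` places a single orthogonal block at the spatial centre of a convolution kernel and zeros every other tap, so a freshly initialised convolution acts as an orthogonal pointwise map (arXiv:1806.05393). A hedged usage sketch for an HWIO kernel:

```python
init = DeltaOrthogonal()
w = init((3, 3, 16, 32))  # 3x3 spatial taps; requires shape[-2] <= shape[-1]

centre = w[1, 1]          # (3 - 1) // 2 == 1 on both spatial axes; shape (16, 32)
# Only this tap is non-zero, and its rows are orthonormal (up to `scale`),
# so centre @ centre.T is approximately the 16x16 identity.
```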
@@ -1,7 +1,10 @@
 # -*- coding: utf-8 -*-

 import inspect
+import numpy as onp
 from typing import Union

+import jax.numpy as jnp
+import brainpy.math.jax as bm
 from brainpy import errors
 from brainpy.base.collector import Collector
@@ -26,6 +29,19 @@ class Module(DynamicalSystem):
   """Basic module class for DNN networks."""
   target_backend = 'jax'

+  @staticmethod
+  def get_param(param, size):
+    if param is None:
+      return None
+    if callable(param):
+      return bm.TrainVar(param(size))
+    if isinstance(param, onp.ndarray):
+      assert param.shape == size
+      return bm.TrainVar(bm.asarray(param))
+    if isinstance(param, (bm.JaxArray, jnp.ndarray)):
+      return bm.TrainVar(param)
+    raise ValueError
+

 class Sequential(Module):
   """Basic sequential object to control data flow.
@@ -82,18 +82,8 @@ class Conv2D(Module): | |||||
self.has_bias = True | self.has_bias = True | ||||
# weight initialization | # weight initialization | ||||
if callable(w): | |||||
self.w = bm.TrainVar(w((*_check_tuple(kernel_size), num_input // groups, num_output))) # HWIO | |||||
else: | |||||
assert w.shape == (*_check_tuple(kernel_size), num_input // groups, num_output) | |||||
self.w = bm.TrainVar(w) | |||||
if callable(b): | |||||
self.b = bm.TrainVar(b((num_output, 1, 1))) | |||||
elif b is None: | |||||
self.has_bias = False | |||||
else: | |||||
assert b.shape == (num_output, 1, 1) | |||||
self.b = bm.TrainVar(b) | |||||
self.w = self.get_param(w, (*_check_tuple(kernel_size), num_input // groups, num_output)) | |||||
self.b = self.get_param(b, (num_output, 1, 1)) | |||||
def update(self, x): | def update(self, x): | ||||
nin = self.w.value.shape[2] * self.groups | nin = self.w.value.shape[2] * self.groups | ||||
@@ -108,5 +98,5 @@ class Conv2D(Module): | |||||
rhs_dilation=self.dilations, | rhs_dilation=self.dilations, | ||||
feature_group_count=self.groups, | feature_group_count=self.groups, | ||||
dimension_numbers=('NCHW', 'HWIO', 'NCHW')) | dimension_numbers=('NCHW', 'HWIO', 'NCHW')) | ||||
if self.has_bias: y += self.b.value | |||||
return y | |||||
if self.b is None: return y | |||||
return y + self.b.value |
@@ -40,22 +40,12 @@ class Dense(Module):
     self.num_hidden = num_hidden
 
     # variables
-    if callable(w):
-      self.w = bm.TrainVar(w((num_input, num_hidden)))
-    else:
-      assert w.shape == (num_input, num_hidden)
-      self.w = bm.TrainVar(w)
-    if b is None:
-      self.has_bias = False
-    elif callable(b):
-      self.b = bm.TrainVar(b((num_hidden,)))
-    else:
-      assert b.shape == (num_hidden,)
-      self.b = bm.TrainVar(b)
+    self.w = self.get_param(w, (num_input, num_hidden))
+    self.b = self.get_param(b, (num_hidden,))
 
   def update(self, x):
     """Returns the results of applying the linear transformation to input x."""
-    if self.has_bias:
-      return x @ self.w + self.b
-    else:
+    if self.b is None:
       return x @ self.w
+    else:
+      return x @ self.w + self.b
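
With ``get_param`` in place, a layer accepts either an initializer or a concrete array for each parameter. A sketch under assumed keyword names (only ``num_hidden`` and ``num_input`` appear in the hunk above):

    # Hypothetical constructions; the keyword names are assumptions.
    layer_a = Dense(num_hidden=10, num_input=100, w=XavierNormal(), b=ZeroInit())
    layer_b = Dense(num_hidden=10, num_input=100,
                    w=onp.random.randn(100, 10) * 0.01,  # concrete array, shape-checked
                    b=None)                              # disables the bias branch in update()

Conv2D follows the same pattern, with an HWIO-shaped ``w`` and a ``(num_output, 1, 1)`` bias.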
@@ -43,35 +43,17 @@ class VanillaRNN(RNNCore):
   def __init__(self, num_hidden, num_input, num_batch, h=Uniform(), w=XavierNormal(), b=ZeroInit(), **kwargs):
     super(VanillaRNN, self).__init__(num_hidden, num_input, **kwargs)
-    self.has_bias = True
 
     # variables
-    if callable(h):
-      self.h = bm.Variable(h((num_batch, self.num_hidden)))
-    else:
-      self.h = bm.Variable(h)
+    self.h = bm.Variable(self.get_param(h, (num_batch, self.num_hidden)))
 
     # weights
-    if callable(w):
-      self.w_ir = bm.TrainVar(w((num_input, num_hidden)))
-      self.w_rr = bm.TrainVar(w((num_hidden, num_hidden)))
-    else:
-      w_ir, w_rr = w
-      assert w_ir.shape == (num_input, num_hidden)
-      assert w_rr.shape == (num_hidden, num_hidden)
-      self.w_ir = bm.TrainVar(w_ir)
-      self.w_rr = bm.TrainVar(w_rr)
-    if b is None:
-      self.has_bias = False
-    elif callable(b):
-      self.b = bm.TrainVar(b((num_hidden,)))
-    else:
-      assert b.shape == (num_hidden,)
-      self.b = bm.TrainVar(b)
+    ws = self.get_param(w, (num_input + num_hidden, num_hidden))
+    self.w_ir = bm.TrainVar(ws[:num_input])
+    self.w_rr = bm.TrainVar(ws[num_input:])
+    self.b = self.get_param(b, (num_hidden,))
 
   def update(self, x):
     h = x @ self.w_ir + self.h @ self.w_rr
-    if self.has_bias: h += self.b
+    if self.b is not None:  # get_param returns None when b is None
+      h += self.b
     self.h.value = bm.relu(h)
     return self.h
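
Packing the input and recurrent weights into one ``(num_input + num_hidden, num_hidden)`` draw is algebraically equivalent to concatenating the input and the state. One caveat: a fan-in-dependent initializer such as ``XavierNormal`` now sees the packed fan-in ``num_input + num_hidden``, so the sampled variance differs slightly from two separate draws. A quick NumPy check of the identity (shapes invented for the example):

    import numpy as onp

    x = onp.random.randn(4, 3)       # batch of inputs
    h = onp.random.randn(4, 5)       # batch of hidden states
    ws = onp.random.randn(3 + 5, 5)  # packed weights, rows stacked as [w_ir; w_rr]
    w_ir, w_rr = ws[:3], ws[3:]
    assert onp.allclose(x @ w_ir + h @ w_rr,
                        onp.concatenate([x, h], axis=-1) @ ws)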
@@ -112,60 +94,35 @@ class GRU(RNNCore):
     self.has_bias = True
 
     # variables
-    if callable(h):
-      self.h = bm.Variable(h((num_batch, self.num_hidden)))
-    else:
-      self.h = bm.Variable(h)
+    self.h = bm.Variable(self.get_param(h, (num_batch, self.num_hidden)))
 
     # weights
-    if callable(wx):
-      self.w_iz = bm.TrainVar(wx((num_input, num_hidden)))
-      self.w_ir = bm.TrainVar(wx((num_input, num_hidden)))
-      self.w_ia = bm.TrainVar(wx((num_input, num_hidden)))
-    else:
-      w_iz, w_ir, w_ia = wx
-      assert w_iz.shape == (num_input, num_hidden)
-      assert w_ir.shape == (num_input, num_hidden)
-      assert w_ia.shape == (num_input, num_hidden)
-      self.w_iz = bm.TrainVar(w_iz)
-      self.w_ir = bm.TrainVar(w_ir)
-      self.w_ia = bm.TrainVar(w_ia)
-    if callable(wh):
-      self.w_hz = bm.TrainVar(wh((num_hidden, num_hidden)))
-      self.w_hr = bm.TrainVar(wh((num_hidden, num_hidden)))
-      self.w_ha = bm.TrainVar(wh((num_hidden, num_hidden)))
-    else:
-      w_hz, w_hr, w_ha = wh
-      assert w_hz.shape == (num_hidden, num_hidden)
-      assert w_hr.shape == (num_hidden, num_hidden)
-      assert w_ha.shape == (num_hidden, num_hidden)
-      self.w_hz = bm.TrainVar(w_hz)
-      self.w_hr = bm.TrainVar(w_hr)
-      self.w_ha = bm.TrainVar(w_ha)
-    if b is None:
-      self.has_bias = False
-      self.bz = 0.
-      self.br = 0.
-      self.ba = 0.
-    elif callable(b):
-      self.bz = bm.TrainVar(b((num_hidden,)))
-      self.br = bm.TrainVar(b((num_hidden,)))
-      self.ba = bm.TrainVar(b((num_hidden,)))
-    else:
-      bz, br, ba = b
-      assert bz.shape == (num_hidden,)
-      assert br.shape == (num_hidden,)
-      assert ba.shape == (num_hidden,)
-      self.bz = bm.TrainVar(bz)
-      self.br = bm.TrainVar(br)
-      self.ba = bm.TrainVar(ba)
+    wxs = self.get_param(wx, (num_input * 3, num_hidden))
+    self.w_iz = bm.TrainVar(wxs[:num_input])
+    self.w_ir = bm.TrainVar(wxs[num_input: num_input * 2])
+    self.w_ia = bm.TrainVar(wxs[num_input * 2:])
+    whs = self.get_param(wh, (num_hidden * 3, num_hidden))
+    self.w_hz = bm.TrainVar(whs[:num_hidden])
+    self.w_hr = bm.TrainVar(whs[num_hidden: num_hidden * 2])
+    self.w_ha = bm.TrainVar(whs[num_hidden * 2:])
+    bs = self.get_param(b, (num_hidden * 3,))
+    if bs is None:  # bias disabled with b=None; update() takes the bias-free branch
+      self.bz = self.br = self.ba = None
+    else:
+      self.bz = bm.TrainVar(bs[:num_hidden])
+      self.br = bm.TrainVar(bs[num_hidden: num_hidden * 2])
+      self.ba = bm.TrainVar(bs[num_hidden * 2:])
 
   def update(self, x):
-    z = bm.sigmoid(x @ self.w_iz + self.h @ self.w_hz + self.bz)
-    r = bm.sigmoid(x @ self.w_ir + self.h @ self.w_hr + self.br)
-    a = bm.tanh(x @ self.w_ia + (r * self.h) @ self.w_ha + self.ba)
-    self.h.value = (1 - z) * self.h + z * a
-    return self.h.value
+    if self.bz is None:
+      z = bm.sigmoid(x @ self.w_iz + self.h @ self.w_hz)
+      r = bm.sigmoid(x @ self.w_ir + self.h @ self.w_hr)
+      a = bm.tanh(x @ self.w_ia + (r * self.h) @ self.w_ha)
+    else:
+      z = bm.sigmoid(x @ self.w_iz + self.h @ self.w_hz + self.bz)
+      r = bm.sigmoid(x @ self.w_ir + self.h @ self.w_hr + self.br)
+      a = bm.tanh(x @ self.w_ia + (r * self.h) @ self.w_ha + self.ba)
+    self.h.value = (1 - z) * self.h + z * a
+    return self.h.value
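
For reference, the biased branch above implements the standard GRU update (with :math:`\odot` denoting the elementwise product):

.. math::

    z_t &= \sigma(x_t W_{iz} + h_{t-1} W_{hz} + b_z) \\
    r_t &= \sigma(x_t W_{ir} + h_{t-1} W_{hr} + b_r) \\
    a_t &= \tanh(x_t W_{ia} + (r_t \odot h_{t-1}) W_{ha} + b_a) \\
    h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot a_t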
 class LSTM(RNNCore):
@@ -214,38 +171,23 @@ class LSTM(RNNCore):
     self.has_bias = True
 
     # variables
-    if callable(hc):
-      self.h = bm.Variable(hc((num_batch, self.num_hidden)))
-      self.c = bm.Variable(hc((num_batch, self.num_hidden)))
-    else:
-      h, c = hc
-      assert h.shape == (num_batch, self.num_hidden)
-      assert c.shape == (num_batch, self.num_hidden)
-      self.h = bm.Variable(h)
-      self.c = bm.Variable(c)
+    hc = bm.Variable(self.get_param(hc, (num_batch * 2, self.num_hidden)))
+    self.h = bm.Variable(hc[:num_batch])
+    self.c = bm.Variable(hc[num_batch:])
 
     # weights
-    if callable(w):
-      self.w = bm.TrainVar(w((num_input + num_hidden, num_hidden * 4)))
-    else:
-      assert w.shape == (num_input + num_hidden, num_hidden * 4)
-      self.w = bm.TrainVar(w)
-    if b is None:
-      self.b = 0.
-      self.has_bias = False
-    elif callable(b):
-      self.b = bm.TrainVar(b((num_hidden * 4,)))
-    else:
-      assert b.shape == (num_hidden * 4,)
-      self.b = bm.TrainVar(b)
+    self.w = self.get_param(w, (num_input + num_hidden, num_hidden * 4))
+    self.b = self.get_param(b, (num_hidden * 4,))
 
   def update(self, x):
     xh = bm.concatenate([x, self.h], axis=-1)
-    gated = xh @ self.w + self.b
+    if self.b is None:
+      gated = xh @ self.w
+    else:
+      gated = xh @ self.w + self.b
     i, g, f, o = bm.split(gated, indices_or_sections=4, axis=-1)
     c = bm.sigmoid(f + 1.) * self.c + bm.sigmoid(i) * bm.tanh(g)
     h = bm.sigmoid(o) * bm.tanh(c)
     self.h.value = h
     self.c.value = c
     return self.h.value
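
Note the ``f + 1.`` in the cell update: starting the effective forget-gate bias at one is a common trick to keep gradients flowing early in training. A hypothetical single step (the constructor keyword order and ``bm.random.normal`` are assumptions, not verified against this diff):

    cell = LSTM(num_hidden=32, num_input=16, num_batch=8)
    x = bm.random.normal(size=(8, 16))
    h = cell.update(x)  # new hidden state, shape (8, 32)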
@@ -2,6 +2,20 @@ Release notes
 =============
 
+Version 1.1.5
+-------------
+
+**API changes:**
+
+- fix a bug of the ndarray import in ``brainpy.base.function.py``
+- add a convenient ``get_param`` interface to ``brainpy.simulation.layers``
+- add more weight initialization methods
+
+**Doc changes:**
+
+- add more examples in the README
+
 Version 1.1.4
 -------------
@@ -84,7 +84,21 @@ is based on JAX.
 Currently, JAX supports **Linux** (Ubuntu 16.04 or later) and **macOS** (10.12 or
 later) platforms. The provided binary releases of JAX for Linux and macOS
 systems are available at https://storage.googleapis.com/jax-releases/jax_releases.html .
-Users can download the preferred release ".whl" file, and install it via ``pip``:
+
+To install a CPU-only version of JAX, you can run
+
+.. code-block:: bash
+
+    pip install --upgrade "jax[cpu]"
+
+If you want to install JAX with both CPU and NVIDIA GPU support, you must first
+install `CUDA`_ and `CuDNN`_ if they have not already been installed. Next, run
+
+.. code-block:: bash
+
+    pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
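
After installing the CUDA build, a quick sanity check (our suggestion here, not part of the official JAX guide) is to ask JAX which devices it sees:

.. code-block:: python

    import jax
    print(jax.devices())  # expect one or more GPU devices for the CUDA build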
+
+Alternatively, you can download the preferred release ".whl" file and install it via ``pip``:
 
 .. code-block:: bash
@@ -92,13 +106,23 @@ Users can download the preferred release ".whl" file, and install it via ``pip``
 
 For **Windows** users, JAX can be installed by the following methods:
 
-- For Windows 10+ system, you can `Windows Subsystem for Linux (WSL)`_.
-  The installation guide can be found in `WSL Installation Guide for Windows 10`_.
-  Then, you can install JAX in WSL just like the installation step in Linux.
-- There are several precompiled Windows wheels, like `jaxlib_0.1.68_Windows_wheels`_ and `jaxlib_0.1.61_Windows_wheels`_.
-- Finally, you can also `build JAX from source`_.
+Method 1: For Windows 10 or later, you can use the `Windows Subsystem for Linux (WSL)`_.
+The installation guide can be found in the `WSL Installation Guide for Windows 10`_.
+Then, you can install JAX in WSL just like the installation steps on Linux.
+
+Method 2: There are several community-supported Windows builds of JAX; please refer
+to https://github.com/cloudhan/jax-windows-builder for more details.
+In short, you can run:
+
+.. code-block:: bash
+
+    # CPU only
+    pip install jaxlib -f https://whls.blob.core.windows.net/unstable/index.html
+
+    # for GPU support, download a matching wheel from the link above, then
+    pip install <downloaded jaxlib>
+
+More details of the JAX installation can be found at https://github.com/google/jax#installation .
+
+Method 3: You can also `build JAX from source`_.
 
 Numba
@@ -148,3 +172,5 @@ Therefore, we highly recommend you to install sympy, just typing
 .. _SymPy: https://github.com/sympy/sympy
 .. _Exponential Euler numerical solver: https://brainpy.readthedocs.io/en/latest/tutorials_advanced/ode_numerical_solvers.html#Exponential-Euler-methods
 .. _dynamics analysis module: https://brainpy.readthedocs.io/en/latest/apis/analysis.html
+.. _CUDA: https://developer.nvidia.com/cuda-downloads
+.. _CuDNN: https://developer.nvidia.com/CUDNN