@@ -12,7 +12,7 @@ This module provides the basic Node class for the whole ``brainpy.nn`` system.
This means ``brainpy.nn.Network`` is only used to pack element nodes. It will
This means ``brainpy.nn.Network`` is only used to pack element nodes. It will
never be an element node.
never be an element node.
- ``brainpy.nn.FrozenNetwork``: The whole network which can be represented as a basic
- ``brainpy.nn.FrozenNetwork``: The whole network which can be represented as a basic
elementary node when composing a larger network. TODO
elementary node when composing a larger network (TODO).
"""
"""
from copy import copy, deepcopy
from copy import copy, deepcopy
@@ -48,6 +48,16 @@ __all__ = [
NODE_STATES = ['inputs', 'feedbacks', 'state', 'output']
NODE_STATES = ['inputs', 'feedbacks', 'state', 'output']
SUPPORTED_LAYOUTS = ['shell_layout',
'multipartite_layout',
'spring_layout',
'spiral_layout',
'spectral_layout',
'random_layout',
'planar_layout',
'kamada_kawai_layout',
'circular_layout']
def not_implemented(fun: Callable) -> Callable:
def not_implemented(fun: Callable) -> Callable:
"""Marks the given module method is not implemented.
"""Marks the given module method is not implemented.
@@ -92,8 +102,10 @@ class Node(Base):
self._is_ff_initialized = False
self._is_ff_initialized = False
self._is_fb_initialized = False
self._is_fb_initialized = False
self._is_state_initialized = False
self._is_state_initialized = False
self._is_fb_state_initialized = False
self._trainable = trainable
self._trainable = trainable
self._state = None # the state of the current node
self._state = None # the state of the current node
self._fb_output = None # the feedback output of the current node
# data pass function
# data pass function
if self.data_pass_type not in DATA_PASS_FUNC:
if self.data_pass_type not in DATA_PASS_FUNC:
raise ValueError(f'Unsupported data pass type {self.data_pass_type}. '
raise ValueError(f'Unsupported data pass type {self.data_pass_type}. '
@@ -111,12 +123,9 @@ class Node(Base):
name = type(self).__name__
name = type(self).__name__
prefix = ' ' * (len(name) + 1)
prefix = ' ' * (len(name) + 1)
line1 = (f"{name}(name={self.name}, "
line1 = (f"{name}(name={self.name}, "
f"trainable={self.trainable}, "
f"forwards={self.feedforward_shapes}, "
f"forwards={self.feedforward_shapes}, "
f"feedbacks={self.feedback_shapes}, \n")
f"feedbacks={self.feedback_shapes}, \n")
line2 = (f"{prefix}output={self.output_shape}, "
f"support_feedback={self.support_feedback}, "
f"data_pass_type={self.data_pass_type})")
line2 = f"{prefix}output={self.output_shape}"
return line1 + line2
return line1 + line2
def __call__(self, *args, **kwargs) -> Tensor:
def __call__(self, *args, **kwargs) -> Tensor:
@@ -194,7 +203,7 @@ class Node(Base):
@property
@property
def state(self) -> Optional[Tensor]:
def state(self) -> Optional[Tensor]:
"""Node current internal state."""
"""Node current internal state."""
if self.is_ff_initialized:
if self._is_ff_initialized:
return self._state
return self._state
return None
return None
@@ -209,9 +218,9 @@ class Node(Base):
This method allows the maximum flexibility to change the
This method allows the maximum flexibility to change the
node state. It can set new data (same shape, same dtype)
node state. It can set new data (same shape, same dtype)
to the state. It can also set the data with another batch size.
We highly recommend the user to use this function.
to the state. It can also set new data with a different shape.
We highly recommend the user to use this function
instead of using ``self.state.value``.
"""
"""
if self.state is None:
if self.state is None:
if self.output_shape is not None:
if self.output_shape is not None:
@@ -225,31 +234,52 @@ class Node(Base):
self.state._value = bm.as_device_array(state)
self.state._value = bm.as_device_array(state)
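# A minimal usage sketch of "set_state()" (with a hypothetical, already
# initialized node "node" whose output shape is (None, 10)); re-batching goes
# through the method, while same-shape in-place edits go through ".value":
node.set_state(bm.zeros((32, 10)))    # safely reset the state with a new batch size
node.state.value = bm.ones((32, 10))  # same-shape, same-dtype in-place update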
@property
@property
def trainable(self) -> bool:
"""Returns if the Node can be trained."""
return self._trainable
def fb_output(self) -> Optional[Tensor]:
return self._fb_output
@property
def is_ff_initialized(self) -> bool:
return self._is_ff_initialized
@fb_output.setter
def fb_output(self, value: Tensor):
raise NotImplementedError('Please use "set_fb_output()" to reset the node feedback state, '
'or use "self.fb_output.value" to change the state content.')
@is_ff_initialized.setter
def is_ff_initialized(self, value: bool):
assert isinstance(value, bool)
self._is_ff_initialized = value
def set_fb_output(self, state: Tensor):
"""
Safely set the feedback state of the node.
@property
def is_fb_initialized(self) -> bool:
return self._is_fb_initialized
This method allows the maximum flexibility to change the
node state. It can set new data (same shape, same dtype)
to the state. It can also set new data with a different shape.
We highly recommend the user to use this function
instead of using ``self.fb_output.value``.
"""
if self.fb_output is None:
if self.output_shape is not None:
check_batch_shape(self.output_shape, state.shape)
self._fb_output = bm.Variable(state) if not isinstance(state, bm.Variable) else state
else:
check_batch_shape(self.fb_output.shape, state.shape)
if self.fb_output.dtype != state.dtype:
raise MathError('Cannot set the feedback state, because the dtype is '
f'not consistent: {self.fb_output.dtype} != {state.dtype}')
self.fb_output._value = bm.as_device_array(state)
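# A matching usage sketch for the feedback output (again with a hypothetical node
# "node" whose output shape is (None, 10)); the setter above forbids assigning to
# "fb_output" directly, so go through this method or ".value":
node.set_fb_output(bm.zeros((32, 10)))    # safely reset the feedback output
node.fb_output.value = bm.ones((32, 10))  # in-place content update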
@is_fb_initialized.setter
def is_fb_initialized(self, value: bool):
assert isinstance(value, bool)
self._is_fb_initialized = value
@property
def trainable(self) -> bool:
"""Returns whether the Node can be trained."""
return self._trainable
@property
@property
def is_state_initialized(self):
return self._is_state_initialized
def is_initialized(self) -> bool:
if self._is_ff_initialized and self._is_state_initialized:
if self.feedback_shapes is not None:
if self._is_fb_initialized and self._is_fb_state_initialized:
return True
else:
return False
else:
return True
else:
return False
@trainable.setter
@trainable.setter
def trainable(self, value: bool):
def trainable(self, value: bool):
@@ -268,7 +298,7 @@ class Node(Base):
self.set_feedforward_shapes(size)
self.set_feedforward_shapes(size)
def set_feedforward_shapes(self, feedforward_shapes: Dict):
def set_feedforward_shapes(self, feedforward_shapes: Dict):
if not self.is_ff_initialized:
if not self._is_ff_initialized:
check_dict_data(feedforward_shapes,
check_dict_data(feedforward_shapes,
key_type=(Node, str),
key_type=(Node, str),
val_type=(list, tuple),
val_type=(list, tuple),
@@ -278,11 +308,11 @@ class Node(Base):
if self.feedforward_shapes is not None:
if self.feedforward_shapes is not None:
for key, size in self._feedforward_shapes.items():
for key, size in self._feedforward_shapes.items():
if key not in feedforward_shapes:
if key not in feedforward_shapes:
raise ValueError(f"Impossible to reset the input data of {self.name}. "
raise ValueError(f"Impossible to reset the input shape of {self.name}. "
f"Because this Node has the input dimension {size} from {key}. "
f"Because this Node has the input dimension {size} from {key}. "
f"While we do not find it in the given feedforward_shapes")
f"While we do not find it in the given feedforward_shapes")
if not check_batch_shape(size, feedforward_shapes[key], mode='bool'):
if not check_batch_shape(size, feedforward_shapes[key], mode='bool'):
raise ValueError(f"Impossible to reset the input data of {self.name}. "
raise ValueError(f"Impossible to reset the input shape of {self.name}. "
f"Because this Node has the input dimension {size} from {key}. "
f"Because this Node has the input dimension {size} from {key}. "
f"While the give shape is {feedforward_shapes[key]}")
f"While the give shape is {feedforward_shapes[key]}")
@@ -296,7 +326,7 @@ class Node(Base):
self.set_feedback_shapes(size)
self.set_feedback_shapes(size)
def set_feedback_shapes(self, fb_shapes: Dict):
def set_feedback_shapes(self, fb_shapes: Dict):
if not self.is_fb_initialized:
if not self._is_fb_initialized:
check_dict_data(fb_shapes, key_type=(Node, str), val_type=(tuple, list), name='fb_shapes')
check_dict_data(fb_shapes, key_type=(Node, str), val_type=(tuple, list), name='fb_shapes')
self._feedback_shapes = fb_shapes
self._feedback_shapes = fb_shapes
else:
else:
@@ -321,14 +351,21 @@ class Node(Base):
self.set_output_shape(size)
self.set_output_shape(size)
@property
@property
def support_feedback(self):
if hasattr(self.init_fb, 'not_implemented'):
if self.init_fb.not_implemented:
def is_feedback_input_supported(self):
if hasattr(self.init_fb_conn, 'not_implemented'):
if self.init_fb_conn.not_implemented:
return False
return False
return True
return True
@property
def is_feedback_supported(self):
if self.fb_output is None:
return False
else:
return True
def set_output_shape(self, shape: Sequence[int]):
def set_output_shape(self, shape: Sequence[int]):
if not self.is_ff_initialized:
if not self._is_ff_initialized:
if not isinstance(shape, (tuple, list)):
if not isinstance(shape, (tuple, list)):
raise ValueError(f'Must be a sequence of int, but got {shape}')
raise ValueError(f'Must be a sequence of int, but got {shape}')
self._output_shape = tuple(shape)
self._output_shape = tuple(shape)
@@ -368,84 +405,88 @@ class Node(Base):
new_obj.name = self.unique_name(name or (self.name + '_copy'))
new_obj.name = self.unique_name(name or (self.name + '_copy'))
return new_obj
return new_obj
def _ff_init(self):
if not self.is_ff_initialized:
def _init_ff_conn(self):
if not self._is_ff_initialized:
try:
try:
self.init_ff()
self.init_ff_conn()
except Exception as e:
except Exception as e:
raise ModelBuildError(f'{self.name} initialization failed.') from e
raise ModelBuildError(f'{self.name} initialization failed.') from e
self._is_ff_initialized = True
self._is_ff_initialized = True
if self.output_shape is None:
raise ValueError(f'Please set the output shape when implementing '
f'"init_ff()" of the node {self.name}')
def _fb_init(self):
if not self.is_fb_initialized:
def _init_fb_conn(self):
if not self._is_fb_initialized:
try:
try:
self.init_fb()
self.init_fb_conn()
except Exception as e:
except Exception as e:
raise ModelBuildError(f"{self.name} initialization failed.") from e
raise ModelBuildError(f"{self.name} initialization failed.") from e
self._is_fb_initialized = True
self._is_fb_initialized = True
@not_implemented
@not_implemented
def init_fb(self):
def init_fb_conn(self):
"""Initialize the feedback connections.
This function will be called only once."""
raise ValueError(f'This node \n\n{self} \n\ndoes not support feedback connection.')
raise ValueError(f'This node \n\n{self} \n\ndoes not support feedback connection.')
def init_ff(self):
def init_ff_conn(self):
"""Initialize the feedforward connections.
This function will be called only once."""
raise NotImplementedError('Please implement the feedforward initialization.')
raise NotImplementedError('Please implement the feedforward initialization.')
def init_state(self, num_batch=1):
def _init_state(self, num_batch=1):
state = self.init_state(num_batch)
if state is not None:
self.set_state(state)
def _init_fb_output(self, num_batch=1):
output = self.init_fb_output(num_batch)
if output is not None:
self.set_fb_output(output)
def init_state(self, num_batch=1) -> Optional[Tensor]:
"""Set the initial node state.
This function can be called multiple times."""
pass
pass
def initialize(self,
ff: Optional[Union[Tensor, Dict[Any, Tensor]]] = None,
fb: Optional[Union[Tensor, Dict[Any, Tensor]]] = None,
num_batch: int = None):
def init_fb_output(self, num_batch=1) -> Optional[Tensor]:
"""Set the initial node feedback state.
This function can be called multiple times. However,
it is only triggered when the node has feedback connections.
"""
"""
Initialize the whole network. This function must be called before applying JIT.
return bm.zeros((num_batch,) + self.output_shape[1:], dtype=bm.float_)
This function is useful, because it is independent from the __call__ function.
We can use this function before we applying JIT to __call__ function.
def initialize(self, num_batch: int):
"""
"""
Initialize the node. This function must be called before applying JIT.
# feedforward initialization
if not self.is_ff_initialized:
# feedforward data
if ff is None:
if self._feedforward_shapes is None:
raise ValueError('Cannot initialize this node, because we detect '
'both "feedforward_shapes"and "ff" inputs are None. ')
in_sizes = self._feedforward_shapes
if num_batch is None:
raise ValueError('"num_batch" cannot be None when "ff" is not provided.')
check_integer(num_batch, 'num_batch', min_bound=0, allow_none=False)
else:
if isinstance(ff, (bm.ndarray, jnp.ndarray)):
ff = {self.name: ff}
assert isinstance(ff, dict), f'"ff" must be a dict or a tensor, got {type(ff)}: {ff}'
assert self.name in ff, f'Cannot find input for this node \n\n{self} \n\nwhen given "ff" {ff}'
batch_sizes = [v.shape[0] for v in ff.values()]
if set(batch_sizes) != 1:
raise ValueError('Batch sizes must be consistent, but we got multiple '
f'batch sizes {set(batch_sizes)} for the given input: \n'
f'{ff}')
in_sizes = {k: (None,) + v.shape[1:] for k, v in ff.items()}
if (num_batch is not None) and (num_batch != batch_sizes[0]):
raise ValueError(f'The provided "num_batch" {num_batch} is consistent with the '
f'batch size of the provided data {batch_sizes[0]}')
# initialize feedforward
self.set_feedforward_shapes(in_sizes)
self._ff_init()
self.init_state(num_batch)
self._is_state_initialized = True
This function is useful, because it is independent of the __call__ function.
We can use this function before we apply JIT to the __call__ function.
"""
# feedback initialization
if fb is not None:
if not self.is_fb_initialized: # initialize feedback
assert isinstance(fb, dict), f'"fb" must be a dict, got {type(fb)}'
fb_sizes = {k: (None,) + v.shape[1:] for k, v in fb.items()}
self.set_feedback_shapes(fb_sizes)
self._fb_init()
else:
self._is_fb_initialized = True
# feedforward initialization
if self.feedforward_shapes is None:
raise ValueError('Cannot initialize this node, because we detect '
'both "feedforward_shapes" is None. '
'Two ways can solve this problem:\n\n'
'1. Connecting an instance of "brainpy.nn.Input()" to this node. \n'
'2. Providing the "input_shape" when initialize the node.')
check_integer(num_batch, 'num_batch', min_bound=0, allow_none=False)
self._init_ff_conn()
# initialize state
self._init_state(num_batch)
self._is_state_initialized = True
if self.feedback_shapes is not None:
# feedback initialization
self._init_fb_conn()
# initialize feedback state
self._init_fb_output(num_batch)
self._is_fb_state_initialized = True
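# A hedged usage sketch of the initialization protocol implemented above (the
# subclass name and sizes are illustrative only): "initialize()" must run before
# the node is called or JIT-compiled.
node = SomeNode(input_shape=5)    # hypothetical Node subclass taking "input_shape"
node.initialize(num_batch=32)     # builds ff connections, state and, if any, feedback
out = node(bm.ones((32, 5)))      # now the node can be called (and jitted) safely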
def _check_inputs(self, ff, fb=None):
def _check_inputs(self, ff, fb=None):
# check feedforward inputs
# check feedforward inputs
@@ -477,9 +518,8 @@ class Node(Base):
forced_feedbacks: Dict[str, Tensor] = None,
forced_feedbacks: Dict[str, Tensor] = None,
monitors=None,
monitors=None,
**kwargs) -> Union[Tensor, Tuple[Tensor, Dict]]:
**kwargs) -> Union[Tensor, Tuple[Tensor, Dict]]:
# # initialization
# self.initialize(ff, fb)
if not (self.is_ff_initialized and self.is_fb_initialized and self.is_state_initialized):
# checking
if not self.is_initialized:
raise ValueError('Please initialize the Node first by calling "initialize()" function.')
raise ValueError('Please initialize the Node first by calling "initialize()" function.')
# initialize the forced data
# initialize the forced data
@@ -511,6 +551,7 @@ class Node(Base):
assert self.state is not None, (f'{self} \n\nhas no state, while '
assert self.state is not None, (f'{self} \n\nhas no state, while '
f'the user tries to monitor its state.')
f'the user tries to monitor its state.')
state_monitors[key] = None
state_monitors[key] = None
# calling
# calling
ff, fb = self._check_inputs(ff, fb=fb)
ff, fb = self._check_inputs(ff, fb=fb)
if 'inputs' in state_monitors:
if 'inputs' in state_monitors:
@@ -528,7 +569,7 @@ class Node(Base):
else:
else:
return output
return output
def forward(self, ff, fb=None, **kwargs):
def forward(self, ff, fb=None, **shared_kwargs):
"""The feedforward computation function of a node.
"""The feedforward computation function of a node.
Parameters
Parameters
@@ -537,7 +578,7 @@ class Node(Base):
The feedforward inputs.
The feedforward inputs.
fb: optional, tensor, dict, sequence
fb: optional, tensor, dict, sequence
The feedback inputs.
The feedback inputs.
**kwargs
**shared_kwargs
Other parameters.
Other parameters.
Returns
Returns
@@ -547,12 +588,12 @@ class Node(Base):
"""
"""
raise NotImplementedError
raise NotImplementedError
def feedback(self, **kwargs):
def feedback(self, ff_output, **shared_kwargs):
"""The feedback computation function of a node.
"""The feedback computation function of a node.
Parameters
Parameters
----------
----------
**kwargs
ff_output: Tensor
The node output computed by the feedforward function.
**shared_kwargs
Other global parameters.
Other global parameters.
Returns
Returns
@@ -560,12 +601,17 @@ class Node(Base):
Tensor
Tensor
A feedback output tensor value.
A feedback output tensor value.
"""
"""
return self.state
return ff_output
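# A hedged sketch of a custom node implementing the hooks documented above (the
# class name, the leaky dynamics, and the assumption that "ff"/"fb" arrive as
# single tensors with "num_unit" features are illustrative only):
class LeakyNode(Node):
    def __init__(self, num_unit, tau=10., **kwargs):
        super(LeakyNode, self).__init__(**kwargs)
        self.num_unit = num_unit
        self.tau = tau
    def init_ff_conn(self):
        # the output shape must be set here (see "_init_ff_conn()" above)
        self.set_output_shape((None, self.num_unit))
    def init_state(self, num_batch=1):
        return bm.zeros((num_batch, self.num_unit))
    def forward(self, ff, fb=None, **shared_kwargs):
        inp = ff if fb is None else ff + fb
        self.state.value = (1 - 1 / self.tau) * self.state + inp / self.tau
        return self.state.value
    def feedback(self, ff_output, **shared_kwargs):
        return ff_output  # mirror the default behaviour shown above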
class RecurrentNode(Node):
class RecurrentNode(Node):
"""
"""
Basic class for recurrent node.
Basic class for recurrent node.
The recurrent node supports:
- Self-connection when using the ``plot_node_graph()`` function
- Setting a trainable state with ``state_trainable=True``.
"""
"""
def __init__(self,
def __init__(self,
@@ -617,19 +663,6 @@ class RecurrentNode(Node):
else:
else:
self.state._value = bm.as_device_array(state)
self.state._value = bm.as_device_array(state)
def __repr__(self):
name = type(self).__name__
prefix = ' ' * (len(name) + 1)
line1 = (f"{name}(name={self.name}, recurrent=True, "
f"trainable={self.trainable}, \n")
line2 = (f"{prefix}forwards={self.feedforward_shapes}, "
f"feedbacks={self.feedback_shapes}, \n")
line3 = (f"{prefix}output={self.output_shape}, "
f"support_feedback={self.support_feedback}, "
f"data_pass_type={self.data_pass_type})")
return line1 + line2 + line3
class Network(Node):
class Network(Node):
"""Basic Network class for neural network building in BrainPy."""
"""Basic Network class for neural network building in BrainPy."""
@@ -806,8 +839,8 @@ class Network(Node):
def replace_graph(self,
def replace_graph(self,
nodes: Sequence[Node],
nodes: Sequence[Node],
ff_edges: Sequence[Tuple[Node, Node]],
fb_edges: Sequence[Tuple[Node, Node]] = None) -> "Network":
ff_edges: Sequence[Tuple[Node, ...]],
fb_edges: Sequence[Tuple[Node, ...]] = None) -> "Network":
if fb_edges is None: fb_edges = tuple()
if fb_edges is None: fb_edges = tuple()
# assign nodes and edges
# assign nodes and edges
@@ -817,16 +850,45 @@ class Network(Node):
self._network_init()
self._network_init()
return self
return self
def init_ff(self):
def set_output_shape(self, shape: Dict[str, Sequence[int]]):
# check shape
if not isinstance(shape, dict):
raise ValueError(f'Must be a dict of <node name, shape>, but got {type(shape)}: {shape}')
for key, val in shape.items():
if not isinstance(val, (tuple, list)):
raise ValueError(f'Must be a sequence of int, but got {val} for key "{key}"')
# for s in val:
# if not (isinstance(s, int) or (s is None)):
# raise ValueError(f'Must be a sequence of int, but got {val}')
if not self._is_ff_initialized:
if len(self.exit_nodes) == 1:
self._output_shape = tuple(shape.values())[0]
else:
self._output_shape = shape
else:
for val in shape.values():
check_batch_shape(val, self.output_shape)
def init_ff_conn(self):
"""Initialize the feedforward connections of the network.
This function will be called only once."""
# input shapes of entry nodes
# input shapes of entry nodes
for node in self.entry_nodes:
for node in self.entry_nodes:
# set ff shapes
if node.feedforward_shapes is None:
if node.feedforward_shapes is None:
if self.feedforward_shapes is None:
if self.feedforward_shapes is None:
raise ValueError('Cannot find the input size. '
raise ValueError('Cannot find the input size. '
'Cannot initialize the network.')
'Cannot initialize the network.')
else:
else:
node.set_feedforward_shapes({node.name: self._feedforward_shapes[node.name]})
node.set_feedforward_shapes({node.name: self._feedforward_shapes[node.name]})
node._ff_init()
# set fb shapes
if node in self.fb_senders:
fb_shapes = {sender: sender.output_shape for sender in self.fb_senders.get(node, [])}
if None not in fb_shapes.values():
node.set_feedback_shapes(fb_shapes)
# init ff conn
node._init_ff_conn()
# initialize the data
# initialize the data
children_queue = []
children_queue = []
@@ -840,49 +902,79 @@ class Network(Node):
children_queue.append(child)
children_queue.append(child)
while len(children_queue):
while len(children_queue):
node = children_queue.pop(0)
node = children_queue.pop(0)
# initialize input and output sizes
# set ff shapes
parent_sizes = {p: p.output_shape for p in self.ff_senders.get(node, [])}
parent_sizes = {p: p.output_shape for p in self.ff_senders.get(node, [])}
node.set_feedforward_shapes(parent_sizes)
node.set_feedforward_shapes(parent_sizes)
node._ff_init()
if node in self.fb_senders:
# set fb shapes
fb_shapes = {sender: sender.output_shape for sender in self.fb_senders.get(node, [])}
if None not in fb_shapes.values():
node.set_feedback_shapes(fb_shapes)
# init ff conn
node._init_ff_conn()
# append children
# append children
for child in self.ff_receivers.get(node, []):
for child in self.ff_receivers.get(node, []):
ff_senders[child].remove(node)
ff_senders[child].remove(node)
if len(ff_senders.get(child, [])) == 0:
if len(ff_senders.get(child, [])) == 0:
children_queue.append(child)
children_queue.append(child)
def init_fb(self):
# set output shape
out_sizes = {node: node.output_shape for node in self.exit_nodes}
self.set_output_shape(out_sizes)
def init_fb_conn(self):
"""Initialize the feedback connections of the network.
This function will be called only once."""
for receiver, senders in self.fb_senders.items():
for receiver, senders in self.fb_senders.items():
fb_sizes = {node: node.output_shape for node in senders}
fb_sizes = {node: node.output_shape for node in senders}
if None in fb_sizes.values():
none_size_nodes = [repr(n) for n, v in fb_sizes.items() if v is None]
none_size_nodes = "\n".join(none_size_nodes)
raise ValueError(f'Output shapes of nodes \n\n'
f'{none_size_nodes}\n\n'
f'have not been initialized, '
f'so we cannot initialize the '
f'feedback connection of node \n\n'
f'{receiver}')
receiver.set_feedback_shapes(fb_sizes)
receiver.set_feedback_shapes(fb_sizes)
receiver._fb_init()
receiver._init_fb_conn()
def init_state(self, num_batch=1):
"""Initialize the states of all children nodes."""
def _init_state(self, num_batch=1):
"""Initialize the states of all children nodes.
This function can be called multiple times."""
for node in self.lnodes:
for node in self.lnodes:
node.init_state(num_batch)
node._init_state(num_batch)
def _init_fb_output(self, num_batch=1):
"""Initialize the node feedback state.
def initialize(self,
ff: Optional[Union[Tensor, Dict[Any, Tensor]]] = None,
fb: Optional[Union[Tensor, Dict[Any, Tensor]]] = None,
num_batch: int = None):
This function can be called multiple times. However,
it is only triggered when the network has feedback connections.
"""
for node in self.feedback_nodes:
node._init_fb_output(num_batch)
def initialize(self, num_batch: int):
"""
"""
Initialize the whole network. This function must be called before applying JIT.
Initialize the whole network. This function must be called before applying JIT.
This function is useful, because it is independent from the __call__ function.
We can use this function before we applying JIT to __call__ function.
This function is useful, because it is independent of the __call__ function.
We can use this function before we apply JIT to the __call__ function.
"""
"""
# feedforward initialization
if not self.is_ff_initialized:
# set feedforward shapes
if not self._is_ff_initialized:
# check input and output nodes
# check input and output nodes
assert len(self.entry_nodes) > 0, (f"We found this network \n\n"
f"{self} "
f"\n\nhas no input nodes.")
assert len(self.exit_nodes) > 0, (f"We found this network \n\n"
f"{self} "
f"\n\nhas no output nodes.")
# check whether has a feedforward path for each feedback pair
if len(self.entry_nodes) <= 0:
raise ValueError(f"We found this network \n\n"
f"{self} "
f"\n\nhas no input nodes.")
if len(self.exit_nodes) <= 0:
raise ValueError(f"We found this network \n\n"
f"{self} "
f"\n\nhas no output nodes.")
# check whether there is a feedforward path for each feedback pair
ff_edges = [(a.name, b.name) for a, b in self.ff_edges]
ff_edges = [(a.name, b.name) for a, b in self.ff_edges]
for node, receiver in self.fb_edges:
for node, receiver in self.fb_edges:
if not detect_path(receiver.name, node.name, ff_edges):
if not detect_path(receiver.name, node.name, ff_edges):
@@ -895,49 +987,42 @@ class Network(Node):
f'feedforward connection between them. ')
f'feedforward connection between them. ')
# feedforward checking
# feedforward checking
if ff is None:
in_sizes = dict()
for node in self.entry_nodes:
if node._feedforward_shapes is None:
raise ValueError('Cannot initialize this node, because we detect '
'both "feedforward_shapes" and "ff" inputs are None. '
'Maybe you need a brainpy.nn.Input instance '
'to instruct the input size.')
in_sizes.update(node._feedforward_shapes)
if num_batch is None:
raise ValueError('"num_batch" cannot be None when "ff" is not provided.')
check_integer(num_batch, 'num_batch', min_bound=0, allow_none=False)
else:
if isinstance(ff, (bm.ndarray, jnp.ndarray)):
ff = {self.entry_nodes[0].name: ff}
assert isinstance(ff, dict), f'ff must be a dict or a tensor, got {type(ff)}: {ff}'
for n in self.entry_nodes:
if n.name not in ff:
raise ValueError(f'Cannot find the input of the node {n}')
batch_sizes = [v.shape[0] for v in ff.values()]
if len(set(batch_sizes)) != 1:
raise ValueError('Batch sizes must be consistent, but we got multiple '
f'batch sizes {set(batch_sizes)} for the given input: \n'
f'{ff}')
in_sizes = {k: (None,) + v.shape[1:] for k, v in ff.items()}
if (num_batch is not None) and (num_batch != batch_sizes[0]):
raise ValueError(f'The provided "num_batch" {num_batch} is consistent with the '
f'batch size of the provided data {batch_sizes[0]}')
# initialize feedforward
in_sizes = dict()
for node in self.entry_nodes:
if node.feedforward_shapes is None:
raise ValueError('Cannot initialize this node, because we detect '
'that "feedforward_shapes" is None. '
'Maybe you need a brainpy.nn.Input instance '
'to specify the input size.')
in_sizes.update(node._feedforward_shapes)
self.set_feedforward_shapes(in_sizes)
self.set_feedforward_shapes(in_sizes)
self._ff_init()
self.init_state(num_batch)
self._is_state_initialized = True
# feedforward initialization
if self.feedforward_shapes is None:
raise ValueError('Cannot initialize this network, because we detect '
'that "feedforward_shapes" is None. ')
check_integer(num_batch, 'num_batch', min_bound=1, allow_none=False)
self._init_ff_conn()
# initialize state
self._init_state(num_batch)
self._is_state_initialized = True
# set feedback shapes
if not self._is_fb_initialized:
if len(self.fb_senders) > 0:
fb_sizes = dict()
for sender in self.fb_senders.keys():
fb_sizes[sender] = sender.output_shape
self.set_feedback_shapes(fb_sizes)
# feedback initialization
# feedback initialization
if len(self.fb_senders):
# initialize feedback
if not self.is_fb_initialized:
self._fb_init()
else:
self.is_fb_initialized = True
if self.feedback_shapes is not None:
self._init_fb_conn()
# initialize feedback state
self._init_fb_output(num_batch)
self._is_fb_state_initialized = True
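# A hedged usage sketch of the network-level initialization above (node classes
# and sizes are illustrative; ">>" is the brainpy.nn graph-composition operator):
import brainpy as bp
import brainpy.math as bm
net = bp.nn.Input(3) >> bp.nn.Dense(10)   # a tiny two-node feedforward network
net.initialize(num_batch=16)              # must precede calling / JIT compilation
out = net(bm.ones((16, 3)))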
def _check_inputs(self, ff, fb=None):
def _check_inputs(self, ff, fb=None):
# feedforward inputs
# feedforward inputs
@@ -986,8 +1071,7 @@ class Network(Node):
monitors: Optional[Sequence[str]] = None,
monitors: Optional[Sequence[str]] = None,
**kwargs):
**kwargs):
# initialization
# initialization
# self.initialize(ff, fb)
if not (self.is_ff_initialized and self.is_fb_initialized and self.is_state_initialized):
if not self.is_initialized:
raise ValueError('Please initialize the Network first by calling "initialize()" function.')
raise ValueError('Please initialize the Network first by calling "initialize()" function.')
# initialize the forced data
# initialize the forced data
@@ -1038,7 +1122,7 @@ class Network(Node):
forced_states: Dict[str, Tensor] = None,
forced_states: Dict[str, Tensor] = None,
forced_feedbacks: Dict[str, Tensor] = None,
forced_feedbacks: Dict[str, Tensor] = None,
monitors: Dict = None,
monitors: Dict = None,
**kwargs):
**shared_kwargs):
"""The main computation function of a network.
"""The main computation function of a network.
Parameters
Parameters
@@ -1053,7 +1137,7 @@ class Network(Node):
The fixed feedback for the nodes in the network.
The fixed feedback for the nodes in the network.
monitors: optional, sequence
monitors: optional, sequence
Can be used to monitor the state or the attribute of a node in the network.
Can be used to monitor the state or the attribute of a node in the network.
**kwargs
**shared_kwargs
Other parameters which will be parsed into every node.
Other parameters which will be parsed into every node.
Returns
Returns
@@ -1077,10 +1161,11 @@ class Network(Node):
parent_outputs = {}
parent_outputs = {}
for i, node in enumerate(self._entry_nodes):
for i, node in enumerate(self._entry_nodes):
ff_ = {node.name: ff[i]}
ff_ = {node.name: ff[i]}
fb_ = {p: (forced_feedbacks[p.name] if (p.name in forced_feedbacks) else p.feedback())
fb_ = {p: (forced_feedbacks[p.name] if (p.name in forced_feedbacks) else p.fb_output)
for p in self.fb_senders.get(node, [])}
for p in self.fb_senders.get(node, [])}
self._call_a_node(node, ff_, fb_, monitors, forced_states,
self._call_a_node(node, ff_, fb_, monitors, forced_states,
parent_outputs, children_queue, ff_senders, **kwargs)
parent_outputs, children_queue, ff_senders,
**shared_kwargs)
runned_nodes.add(node.name)
runned_nodes.add(node.name)
# run the model
# run the model
@@ -1088,23 +1173,23 @@ class Network(Node):
node = children_queue.pop(0)
node = children_queue.pop(0)
# get feedforward and feedback inputs
# get feedforward and feedback inputs
ff = {p: parent_outputs[p] for p in self.ff_senders.get(node, [])}
ff = {p: parent_outputs[p] for p in self.ff_senders.get(node, [])}
fb = {p: (forced_feedbacks[p.name] if (p.name in forced_feedbacks) else p.feedback())
fb = {p: (forced_feedbacks[p.name] if (p.name in forced_feedbacks) else p.fb_output)
for p in self.fb_senders.get(node, [])}
for p in self.fb_senders.get(node, [])}
# call the node
# call the node
self._call_a_node(node, ff, fb, monitors, forced_states,
self._call_a_node(node, ff, fb, monitors, forced_states,
parent_outputs, children_queue, ff_senders,
parent_outputs, children_queue, ff_senders,
**kwargs)
# # - remove unnecessary parent outputs -#
# needed_parents = []
# runned_nodes.add(node.name)
# for child in (all_nodes - runned_nodes):
# for parent in self.ff_senders[self.implicit_nodes[child]]:
# needed_parents.append(parent.name)
# for parent in list(parent_outputs.keys()):
# _name = parent.name
# if _name not in needed_parents and _name not in output_nodes:
# parent_outputs.pop(parent)
**shared_kwargs)
# - remove unnecessary parent outputs - #
needed_parents = []
runned_nodes.add(node.name)
for child in (all_nodes - runned_nodes):
for parent in self.ff_senders[self.implicit_nodes[child]]:
needed_parents.append(parent.name)
for parent in list(parent_outputs.keys()):
_name = parent.name
if _name not in needed_parents and _name not in output_nodes:
parent_outputs.pop(parent)
# returns
# returns
if len(self.exit_nodes) > 1:
if len(self.exit_nodes) > 1:
@@ -1114,7 +1199,8 @@ class Network(Node):
return state, monitors
return state, monitors
def _call_a_node(self, node, ff, fb, monitors, forced_states,
def _call_a_node(self, node, ff, fb, monitors, forced_states,
parent_outputs, children_queue, ff_senders, **kwargs):
parent_outputs, children_queue, ff_senders,
**shared_kwargs):
ff = node.data_pass_func(ff)
ff = node.data_pass_func(ff)
if f'{node.name}.inputs' in monitors:
if f'{node.name}.inputs' in monitors:
monitors[f'{node.name}.inputs'] = ff
monitors[f'{node.name}.inputs'] = ff
@@ -1123,12 +1209,17 @@ class Network(Node):
fb = node.data_pass_func(fb)
fb = node.data_pass_func(fb)
if f'{node.name}.feedbacks' in monitors:
if f'{node.name}.feedbacks' in monitors:
monitors[f'{node.name}.feedbacks'] = fb
monitors[f'{node.name}.feedbacks'] = fb
parent_outputs[node] = node.forward(ff, fb, **kwargs)
parent_outputs[node] = node.forward(ff, fb, **shared_kwargs)
else:
else:
parent_outputs[node] = node.forward(ff, **kwargs)
if node.name in forced_states: # forced state
parent_outputs[node] = node.forward(ff, **shared_kwargs)
# get the feedback state
if node in self.fb_receivers:
node.set_fb_output(node.feedback(parent_outputs[node], **shared_kwargs))
# forced state
if node.name in forced_states:
node.state.value = forced_states[node.name]
node.state.value = forced_states[node.name]
parent_outputs[node] = forced_states[node.name]
# parent_outputs[node] = forced_states[node.name]
# monitor the values
if f'{node.name}.state' in monitors:
if f'{node.name}.state' in monitors:
monitors[f'{node.name}.state'] = node.state.value
monitors[f'{node.name}.state'] = node.state.value
if f'{node.name}.output' in monitors:
if f'{node.name}.output' in monitors:
@@ -1143,7 +1234,7 @@ class Network(Node):
fig_size: tuple = (10, 10),
fig_size: tuple = (10, 10),
node_size: int = 2000,
node_size: int = 2000,
arrow_size: int = 20,
arrow_size: int = 20,
layout='spectral_layout'):
layout='shell_layout'):
"""Plot the node graph based on NetworkX package
"""Plot the node graph based on NetworkX package
Parameters
Parameters
@@ -1155,7 +1246,17 @@ class Network(Node):
arrow_size: int, default to 20
arrow_size: int, default to 20
The size of the arrow
The size of the arrow
layout: str
layout: str
The graph layout. More please see networkx Graph Layout.
The graph layout. The supported layouts are:
- "shell_layout"
- "multipartite_layout"
- "spring_layout"
- "spiral_layout"
- "spectral_layout"
- "random_layout"
- "planar_layout"
- "kamada_kawai_layout"
- "circular_layout"
"""
"""
try:
try:
import networkx as nx
import networkx as nx
@@ -1204,15 +1305,8 @@ class Network(Node):
G.add_edges_from(fb_edges)
G.add_edges_from(fb_edges)
G.add_edges_from(rec_edges)
G.add_edges_from(rec_edges)
assert layout in ['shell_layout',
'multipartite_layout',
'spring_layout',
'spiral_layout',
'spectral_layout',
'random_layout',
'planar_layout',
'kamada_kawai_layout',
'circular_layout']
if layout not in SUPPORTED_LAYOUTS:
raise UnsupportedError(f'Only the following layouts are supported: {SUPPORTED_LAYOUTS}')
layout = getattr(nx, layout)(G)
layout = getattr(nx, layout)(G)
plt.figure(figsize=fig_size)
plt.figure(figsize=fig_size)
@@ -1252,10 +1346,12 @@ class Network(Node):
proxie = []
proxie = []
labels = []
labels = []
if len(nodes_trainable):
if len(nodes_trainable):
proxie.append(Line2D([], [], color='white', marker='o', markerfacecolor=trainable_color))
proxie.append(Line2D([], [], color='white', marker='o',
markerfacecolor=trainable_color))
labels.append('Trainable')
labels.append('Trainable')
if len(nodes_untrainable):
if len(nodes_untrainable):
proxie.append(Line2D([], [], color='white', marker='o', markerfacecolor=untrainable_color))
proxie.append(Line2D([], [], color='white', marker='o',
markerfacecolor=untrainable_color))
labels.append('Untrainable')
labels.append('Untrainable')
if len(ff_edges):
if len(ff_edges):
proxie.append(Line2D([], [], color=ff_color, linewidth=2))
proxie.append(Line2D([], [], color=ff_color, linewidth=2))
@@ -1267,8 +1363,7 @@ class Network(Node):
proxie.append(Line2D([], [], color=rec_color, linewidth=2))
proxie.append(Line2D([], [], color=rec_color, linewidth=2))
labels.append('Recurrent')
labels.append('Recurrent')
plt.legend(proxie, labels, scatterpoints=1, markerscale=2,
loc='best')
plt.legend(proxie, labels, scatterpoints=1, markerscale=2, loc='best')
plt.tight_layout()
plt.tight_layout()
plt.show()
plt.show()
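# A hedged usage sketch of the plotting helper above (the network is illustrative;
# any name from SUPPORTED_LAYOUTS can be passed as "layout"):
import brainpy as bp
net = bp.nn.Input(3) >> bp.nn.Dense(10) >> bp.nn.Dense(2)
net.plot_node_graph(fig_size=(8, 6), node_size=1500, layout='spring_layout')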