@@ -1,168 +1,160 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from abc import abstractmethod
import operator
from itertools import chain
from typing import Dict
import warnings
from collections import OrderedDict, abc as container_abcs
from mindspore.nn.layer.container import _get_prefix_and_index, _valid_index, _valid_cell
from itertools import chain, islice
import operator
from mindtorch.torch.tensor import Tensor, cast_to_adapter_tensor
from mindtorch.torch.nn.parameter import Parameter
from mindtorch.torch._ref import typename
from mindtorch import torch
from .module import Module
from ..parameter import Parameter
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union
from typing_extensions import Self
__all__ = ['Container', 'Sequential', 'ModuleList', 'ModuleDict', 'ParameterList', 'ParameterDict']
T = TypeVar('T', bound=Module)
# Copied from torch.nn.modules.module, required for a custom __repr__ for ModuleList
def _addindent(s_, numSpaces):
s = s_.split('\n')
# don't do anything for single-line stuff
if len(s) == 1:
return s_
first = s.pop(0)
s = [(numSpaces * ' ') + line for line in s]
s = '\n'.join(s)
s = first + '\n' + s
return s
class Container(Module):
def __init__(self, **kwargs: Any) -> None:
super().__init__()
# DeprecationWarning is ignored by default <sigh>
warnings.warn("nn.Container is deprecated. All of it's functionality "
"is now implemented in nn.Module. Subclass that instead.")
for key, value in kwargs.items():
self.add_module(key, value)
class Sequential(Module):
r"""A sequential container.
Modules will be added to it in the order they are passed in the
constructor. Alternatively, an ``OrderedDict`` of modules can be
passed in. The ``forward()`` method of ``Sequential`` accepts any
input and forwards it to the first module it contains. It then
"chains" outputs to inputs sequentially for each subsequent module,
finally returning the output of the last module.
The value a ``Sequential`` provides over manually calling a sequence
of modules is that it allows treating the whole container as a
single module, such that performing a transformation on the
``Sequential`` applies to each of the modules it stores (which are
each a registered submodule of the ``Sequential``).
What's the difference between a ``Sequential`` and a
:class:`torch.nn.ModuleList`? A ``ModuleList`` is exactly what it
sounds like--a list for storing ``Module`` s! On the other hand,
the layers in a ``Sequential`` are connected in a cascading way.
Example::
# Using Sequential to create a small model. When `model` is run,
# input will first be passed to `Conv2d(1,20,5)`. The output of
# `Conv2d(1,20,5)` will be used as the input to the first
# `ReLU`; the output of the first `ReLU` will become the input
# for `Conv2d(20,64,5)`. Finally, the output of
# `Conv2d(20,64,5)` will be used as input to the second `ReLU`
model = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
)
# Using Sequential with OrderedDict. This is functionally the
# same as the above code
model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1,20,5)),
('relu1', nn.ReLU()),
('conv2', nn.Conv2d(20,64,5)),
('relu2', nn.ReLU())
]))
"""
Sequential Module container. For more details about Module, please refer to
A list of Cells will be added to it in the order they are passed in the constructor.
Alternatively, an ordered dict of cells can also be passed in.
_modules: Dict[str, Module] # type: ignore[assignment]
Note:
Sequential and nn.ModuleList are different, ModuleList is a list for storing modules. However,
the layers in a Sequential are connected in a cascading way.
@overload
def __init__(self, *args: Module) -> None:
...
@overload
def __init__(self, arg: 'OrderedDict[str, Module]') -> None:
...
Args:
args (list, OrderedDict): List or OrderedDict of subclass of Module.
Inputs:
- **x** (Tensor) - Tensor with shape according to the first Module in the sequence.
Outputs:
Tensor, the output Tensor with shape depending on the input `x` and defined sequence of Cells.
Raises:
TypeError: If the type of the `args` is not list or OrderedDict.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> conv = nn.Conv2d(3, 2, 3, pad_mode='valid', weight_init="ones")
>>> relu = nn.ReLU()
>>> seq = nn.Sequential([conv, relu])
>>> x = Tensor(np.ones([1, 3, 4, 4]), dtype=mindspore.float32)
>>> output = seq(x)
>>> print(output)
[[[[27. 27.]
[27. 27.]]
[[27. 27.]
[27. 27.]]]]
>>> from collections import OrderedDict
>>> d = OrderedDict()
>>> d["conv"] = conv
>>> d["relu"] = relu
>>> seq = nn.Sequential(d)
>>> x = Tensor(np.ones([1, 3, 4, 4]), dtype=mindspore.float32)
>>> output = seq(x)
>>> print(output)
[[[[27. 27.]
[27. 27.]]
[[27. 27.]
[27. 27.]]]]
"""
def __init__(self, *args):
"""Initialize Sequential."""
super(Sequential, self).__init__()
self._is_dynamic_name = []
if len(args) == 1:
cells = args[0]
if isinstance(cells, list):
for index, cell in enumerate(cells):
self.insert_child_to_cell(str(index), cell)
cell.update_parameters_name(str(index) + ".")
self._is_dynamic_name.append(True)
elif isinstance(cells, OrderedDict):
for name, cell in cells.items():
self.insert_child_to_cell(name, cell)
cell.update_parameters_name(name + ".")
self._is_dynamic_name.append(False)
elif isinstance(cells, Module):
for index, cell in enumerate(args):
self.insert_child_to_cell(str(index), cell)
cell.update_parameters_name(str(index) + ".")
self._is_dynamic_name.append(True)
else:
raise TypeError(f"For '{self.__class__.__name__}', the 'args[0]' must be list or orderedDict, "
f"but got {type(cells).__name__}")
super().__init__()
if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items():
self.add_module(key, module)
else:
for index, cell in enumerate(args):
self.insert_child_to_cell(str(index), cell)
cell.update_parameters_name(str(index) + ".")
self._is_dynamic_name.append(True)
self.cell_list = list(self._cells.values())
def __getitem__(self, index):
if isinstance(index, slice):
return self.__class__(
OrderedDict(list(self._cells.items())[index]))
if isinstance(index, Tensor):
index = int(index)
index = _valid_index(len(self), index, self.__class__.__name__)
return list(self._cells.values())[index]
def __setitem__(self, index, module):
if isinstance(index, Tensor):
index = int(index)
cls_name = self.__class__.__name__
if _valid_cell(module, cls_name):
prefix, _ = _get_prefix_and_index(self._cells)
index = _valid_index(len(self), index, cls_name)
key = list(self._cells.keys())[index]
self._cells[key] = module
module.update_parameters_name(prefix + key + ".")
self.cell_list = list(self._cells.values())
def __delitem__(self, index):
cls_name = self.__class__.__name__
if isinstance(index, int):
index = _valid_index(len(self), index, cls_name)
key = list(self._cells.keys())[index]
del self._cells[key]
del self._is_dynamic_name[index]
elif isinstance(index, slice):
keys = list(self._cells.keys())[index]
for key in keys:
del self._cells[key]
del self._is_dynamic_name[index]
for idx, module in enumerate(args):
self.add_module(str(idx), module)
def _get_item_by_idx(self, iterator, idx) -> T: # type: ignore[misc, type-var]
"""Get the idx-th item of the iterator."""
size = len(self)
idx = operator.index(idx)
if not -size <= idx < size:
raise IndexError(f'index {idx} is out of range')
idx %= size
return next(islice(iterator, idx, None))
def __getitem__(self, idx: Union[slice, int]) -> Union['Sequential', T]:
if isinstance(idx, slice):
return self.__class__(OrderedDict(list(self._modules.items())[idx]))
else:
raise TypeError(f"For '{cls_name}', the type of index must be int type or slice type, "
f"but got {type(index).__name__}")
prefix, key_index = _get_prefix_and_index(self._cells)
temp_dict = OrderedDict()
for idx, key in enumerate(self._cells.keys()):
cell = self._cells[key]
if self._is_dynamic_name[idx]:
for _, param in cell.parameters_and_names():
param.name = prefix + str(idx) + "." + ".".join(param.name.split(".")[key_index+1:])
temp_dict[str(idx)] = cell
else:
temp_dict[key] = cell
self._cells = temp_dict
self.cell_list = list(self._cells.values())
return self._get_item_by_idx(self._modules.values(), idx)
def __setitem__(self, idx: int, module: Module) -> None:
key: str = self._get_item_by_idx(self._modules.keys(), idx)
return setattr(self, key, module)
def __len__(self):
return len(self._cells)
def __delitem__(self, idx: Union[slice, int]) -> None:
if isinstance(idx, slice):
for key in list(self._modules.keys())[idx]:
delattr(self, key)
else:
key = self._get_item_by_idx(self._modules.keys(), idx)
delattr(self, key)
# To preserve numbering
str_indices = [str(i) for i in range(len(self._modules))]
self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
def __bool__(self):
return len(self._cells) != 0
def __len__(self) -> int:
return len(self._modules)
def __add__(self, other):
def __add__(self, other) -> 'Sequential':
if isinstance(other, Sequential):
ret = Sequential()
for layer in self:
self.append(ret, layer)
ret.append( layer)
for layer in other:
self.append(ret, layer)
ret.append( layer)
return ret
else:
raise ValueError('add operator supports only objects '
'of Sequential class, but {} is given.'.format(
str(type(other))))
f'of Sequential class, but {str(type(other))} is given.')
def pop(self, key: Union[int, slice]) -> Module:
v = self[key]
del self[key]
return v
def __iadd__(self, other):
def __iadd__(self, other) -> Self :
if isinstance(other, Sequential):
offset = len(self)
for i, module in enumerate(other):
@@ -170,13 +162,12 @@ class Sequential(Module):
return self
else:
raise ValueError('add operator supports only objects '
'of Sequential class, but {} is given.'.format(
str(type(other))))
f'of Sequential class, but {str(type(other))} is given.')
def __mul__(self, other):
def __mul__(self, other: int ) -> 'Sequential' :
if not isinstance(other, int):
raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}")
elif other <= 0:
elif ( other <= 0) :
raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}")
else:
combined = Sequential()
@@ -187,164 +178,85 @@ class Sequential(Module):
offset += 1
return combined
def __rmul__(self, other):
def __rmul__(self, other: int ) -> 'Sequential' :
return self.__mul__(other)
def __imul__(self, other):
def __imul__(self, other: int ) -> Self :
if not isinstance(other, int):
raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}")
elif other <= 0:
elif ( other <= 0) :
raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}")
else:
len_original = len(self)
offset = len(self)
for _ in range(other - 1):
for i in range(len_original):
self.add_module(str(i + offset), self._cell s[str(i)])
self.add_module(str(i + offset), self._module s[str(i)])
offset += len_original
return self
def __dir__(self):
keys = Module.__dir__(self )
keys = super().__dir__( )
keys = [key for key in keys if not key.isdigit()]
return keys
def __iter__(self):
return iter(self._cells.values())
@property
def _modules(self):
return self._cells
def __iter__(self) -> Iterator[Module]:
return iter(self._modules.values())
def set_grad(self, flag=True):
self.requires_grad = flag
for cell in self._cells.values():
cell.set_grad(flag)
# NB: We can't really type check this function as the type of input
# may change dynamically (as is tested in
# TestScript.test_sequential_intermediary_types). Cannot annotate
# with Any as TorchScript expects a more precise type
def forward(self, input):
for module in self:
input = module(input)
return input
def append(self, module):
"""
Appends a given Module to the end of the list.
def append(self, module: Module) -> 'Sequential':
r"""Append a given module to the end.
Args:
module(Module): The Module to be appended.
Examples:
>>> conv = nn.Conv2d(3, 2, 3, pad_mode='valid', weight_init="ones")
>>> bn = nn.BatchNorm2d(2)
>>> relu = nn.ReLU()
>>> seq = nn.Sequential([conv, bn])
>>> seq.append(relu)
>>> x = Tensor(np.ones([1, 3, 4, 4]), dtype=mindspore.float32)
>>> output = seq(x)
>>> print(output)
[[[[26.999863 26.999863]
[26.999863 26.999863]]
[[26.999863 26.999863]
[26.999863 26.999863]]]]
module (nn.Module): module to append
"""
if _valid_cell(module, self.__class__.__name__):
prefix, _ = _get_prefix_and_index(self._cells)
module.update_parameters_name(prefix + str(len(self)) + ".")
self._is_dynamic_name.append(True)
self._cells[str(len(self))] = module
self.cell_list = list(self._cells.values())
self.add_module(str(len(self)), module)
return self
def add_module(self, name, module):
if not isinstance(module, Module) and module is not None:
raise TypeError("{} is not a Module subclass".format(
module.__name__))
elif hasattr(self, name) and name not in self._cells:
raise KeyError("attribute '{}' already exists".format(name))
elif '.' in name:
raise KeyError("module name can't contain \".\", got: {}".format(name))
elif name == '':
raise KeyError("module name can't be empty string \"\"")
if _valid_cell(module, self.__class__.__name__):
module.update_parameters_name(name + ".")
self._is_dynamic_name.append(False)
self._cells[name] = module
self.cell_list = list(self._cells.values())
def forward(self, input):
for cell in self.cell_list:
input = cell(input)
return cast_to_adapter_tensor(input)
def pop(self, key):
v = self[key]
del self[key]
return v
def insert(self, index: int, module: Module) -> 'Sequential':
if not isinstance(module, Module):
raise AssertionError(
f'module should be of type: {Module}')
n = len(self._modules)
if not (-n <= index <= n):
raise IndexError(
f'Index out of range: {index}')
if index < 0:
index += n
for i in range(n, index, -1):
self._modules[str(i)] = self._modules[str(i - 1)]
self._modules[str(index)] = module
return self
def extend(self, sequential):
def extend(self, sequential) -> 'Sequential':
for layer in sequential:
self.append(layer)
return self
def insert(self, index, module):
"""
Inserts a given Cell before a given index in the list.
Args:
index(int): The Insert index in the CellList.
cell(Cell): The Cell to be inserted.
"""
cls_name = self.__class__.__name__
idx = _valid_index(len(self), index, cls_name)
_valid_cell(module, cls_name)
length = len(self)
prefix, key_index = _get_prefix_and_index(self._cells)
while length > idx:
if self._auto_prefix:
tmp_cell = self._cells[str(length-1)]
for _, param in tmp_cell.parameters_and_names():
param.name = f'{prefix}{str(length)}{"."}{".".join(param.name.split(".")[key_index+1:])}'
self._cells[str(length)] = self._cells[str(length - 1)]
length -= 1
self._cells[str(idx)] = module
if self._auto_prefix:
module.update_parameters_name(prefix + str(idx) + ".")
self.cell_list = list(self._cells.values())
self._is_dynamic_name.insert(index, True)
#_ModuleListBase is similar to ms.nn._CellListBase
class _ModuleListBase:
"""
An interface for base the Module as list.
The sequential Module may be iterated using the construct method using for-in statement.
But there are some scenarios that the construct method built-in does not fit.
For convenience, we provide an interface that indicates the sequential
Module may be interpreted as list of Cells, so it can be accessed using
iterator or subscript when a sequential Module instantiate is accessed
by iterator or subscript, it will be interpreted as a list of Cells.
"""
def __init__(self):
"""Initialize _ModuleListBase."""
self.__cell_as_list__ = True #for ms jit parse
@abstractmethod
def __len__(self):
pass
@abstractmethod
def __getitem__(self, index):
pass
class ModuleList(Module):
r"""Holds submodules in a list.
class ModuleList(_ModuleListBase, Module):
"""
Holds Cells in a list.
ModuleList can be used like a regular Python list, the Cells it contains have been initialized.
:class:`~torch.nn.ModuleList` can be indexed like a regular Python list, but
modules it contains are properly registered, and will be visible by all
:class:`~torch.nn.Module` methods.
Args:
modules (iterable, optional): an iterable of modules to add
modules (iterable, optional): an iterable of modules to add
Example::
Examples:
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self ).__init__()
super().__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
def forward(self, x):
@@ -353,172 +265,154 @@ class ModuleList(_ModuleListBase, Module):
x = self.linears[i // 2](x) + l(x)
return x
"""
def __init__(self, modules=None):
"""Initialize ModuleList."""
_ModuleListBase.__init__(self)
Module.__init__(self)
_modules: Dict[str, Module] # type: ignore[assignment]
def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
super().__init__()
if modules is not None:
self.extend(modules)
self += modules
def __getitem__(self, idx):
if isinstance(idx, Tensor):
idx = int(idx)
cls_name = self.__class__.__name__
def _get_abs_string_index(self, idx):
"""Get the absolute index for the list of modules."""
idx = operator.index(idx)
if not (-len(self) <= idx < len(self)):
raise IndexError(f'index {idx} is out of range')
if idx < 0:
idx += len(self)
return str(idx)
def __getitem__(self, idx: Union[int, slice]) -> Union[Module, 'ModuleList']:
if isinstance(idx, slice):
return self.__class__(list(self._modules.values())[idx])
else:
return self._modules[self._get_abs_string_index(idx)]
def __setitem__(self, idx: int, module: Module) -> None:
idx = self._get_abs_string_index(idx)
return setattr(self, str(idx), module)
def __delitem__(self, idx: Union[int, slice]) -> None:
if isinstance(idx, slice):
return self.__class__(list(self._cells.values())[idx])
if isinstance(idx, int):
idx = _valid_index(len(self), idx, cls_name)
return self._cells[str(idx)]
raise TypeError(f"For '{cls_name}', the type of 'idx' must be int or slice, "
f"but got {type(idx).__name__}.")
def __setitem__(self, idx, module):
if isinstance(idx, Tensor):
idx = int(idx)
cls_name = self.__class__.__name__
if not isinstance(idx, int) and _valid_cell(module, cls_name):
raise TypeError(f"For '{cls_name}', the type of 'idx' must be int, "
f"but got {type(idx).__name__}.")
idx = _valid_index(len(self), idx, cls_name)
if self._auto_prefix:
prefix, _ = _get_prefix_and_index(self._cells)
module.update_parameters_name(prefix + str(idx) + ".")
self._cells[str(idx)] = module
def __delitem__(self, idx):
if isinstance(idx, Tensor):
idx = int(idx)
cls_name = self.__class__.__name__
if isinstance(idx, int):
idx = _valid_index(len(self), idx, cls_name)
del self._cells[str(idx)]
elif isinstance(idx, slice):
keys = list(self._cells.keys())[idx]
for key in keys:
del self._cells[key]
for k in range(len(self._modules))[idx]:
delattr(self, str(k))
else:
raise TypeError(f"For '{cls_name}', the type of 'index' must be int or slice, "
f"but got {type(idx).__name__}.")
# adjust orderedDict
prefix, key_index = _get_prefix_and_index(self._cells)
temp_dict = OrderedDict()
for id, cell in enumerate(self._cells.values()):
if self._auto_prefix:
for _, param in cell.parameters_and_names():
param.name = prefix + str(id) + "." + ".".join(param.name.split(".")[key_index+1:])
temp_dict[str(id)] = cell
self._cells = temp_dict
def __len__(self):
return len(self._cells)
def __iter__(self):
return iter(self._cells.values())
def __iadd__(self, modules):
delattr(self, self._get_abs_string_index(idx))
# To preserve numbering, self._modules is being reconstructed with modules after deletion
str_indices = [str(i) for i in range(len(self._modules))]
self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
def __len__(self) -> int:
return len(self._modules)
def __iter__(self) -> Iterator[Module]:
return iter(self._modules.values())
def __iadd__(self, modules: Iterable[Module]) -> Self:
return self.extend(modules)
def __add__(self, other):
def __add__(self, other: Iterable[Module]) -> 'ModuleList':
combined = ModuleList()
for _ , module in enumerate(chain(self, other)):
combined.append( module)
for i, module in enumerate(chain(self, other)):
combined.add_module(str(i), module)
return combined
def __repr__(self):
"""Return a custom repr for ModuleList that compresses repeated module representations."""
list_of_reprs = [repr(item) for item in self]
if len(list_of_reprs) == 0:
return self._get_name() + '()'
start_end_indices = [[0, 0]]
repeated_blocks = [list_of_reprs[0]]
for i, r in enumerate(list_of_reprs[1:], 1):
if r == repeated_blocks[-1]:
start_end_indices[-1][1] += 1
continue
start_end_indices.append([i, i])
repeated_blocks.append(r)
lines = []
main_str = self._get_name() + '('
for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
local_repr = f"({start_id}): {b}" # default repr
if start_id != end_id:
n = end_id - start_id + 1
local_repr = f"({start_id}-{end_id}): {n} x {b}"
local_repr = _addindent(local_repr, 2)
lines.append(local_repr)
main_str += '\n ' + '\n '.join(lines) + '\n'
main_str += ')'
return main_str
def __dir__(self):
keys = super(ModuleList, self).__dir__()
keys = super().__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys
def pop(self, key):
v = self[key]
del self[key]
return v
def insert(self, index: int, module: Module) -> None:
r"""Insert a given module before a given index in the list.
def insert(self, index, module):
Args:
index (int): index to insert.
module (nn.Module): module to insert
"""
Inserts a given Module before a given index in the list.
for i in range(len(self._modules), index, -1):
self._modules[str(i)] = self._modules[str(i - 1)]
self._modules[str(index)] = module
def append(self, module: Module) -> 'ModuleList':
r"""Append a given module to the end of the list.
Args:
index(int): The Insert index in the ModuleList.
module(Module): The Module to be inserted.
module (nn.Module): module to append
"""
cls_name = self.__class__.__name__
#TODO: after _valid_index fixed, below code can be remove
if len(self) == 0 and index == 0:
idx = index
else:
idx = _valid_index(len(self), index, cls_name)
_valid_cell(module, cls_name)
length = len(self)
prefix, key_index = _get_prefix_and_index(self._cells)
while length > idx:
if self._auto_prefix:
tmp_cell = self._cells[str(length-1)]
for _, param in tmp_cell.parameters_and_names():
param.name = prefix + str(length) + "." + ".".join(param.name.split(".")[key_index+1:])
self._cells[str(length)] = self._cells[str(length - 1)]
length -= 1
self._cells[str(idx)] = module
if self._auto_prefix:
module.update_parameters_name(prefix + str(idx) + ".")
def extend(self, modules):
"""
Appends Cells from a Python iterable to the end of the list.
self.add_module(str(len(self)), module)
return self
Args:
cells(list): The Cells to be extended.
def pop(self, key: Union[int, slice]) -> Module:
v = self[key]
del self[key]
return v
def extend(self, modules: Iterable[Module]) -> Self:
r"""Append modules from a Python iterable to the end of the list.
Raises:
TypeError: If the argument cells are not a list of Cells.
Args:
modules (iterable): iterable of modules to append
"""
cls_name = self.__class__.__name__
if not isinstance(modules, container_abcs.Iterable):
raise TypeError("ModuleList.extend should be called with an "
"iterable, but got " + type(modules).__name__)
prefix, _ = _get_prefix_and_index(self._cells)
for module in modules:
if _valid_cell(module, cls_name):
if self._auto_prefix:
module.update_parameters_name(prefix + str(len(self)) + ".")
self._cells[str(len(self))] = module
offset = len(self)
for i, module in enumerate(modules):
self.add_module(str(offset + i), module)
return self
def append(self, module):
"""
Appends a given Module to the end of the list.
Args:
module(Module): The subcell to be appended.
"""
if _valid_cell(module, self.__class__.__name__):
if self._auto_prefix:
prefix, _ = _get_prefix_and_index(self._cells)
module.update_parameters_name(prefix + str(len(self)) + ".")
self._cells[str(len(self))] = module
# remove forward alltogether to fallback on Module's _forward_unimplemented
def set_grad(self, flag=True):
self.requires_grad = flag
for cell in self._cells.values():
cell.set_grad(flag)
class ModuleDict(Module):
r"""Holds submodules in a dictionary.
:class:`nn.ModuleDict` can be indexed like a regular Python dictionary,
:class:`~torch.nn.ModuleDict` can be indexed like a regular Python dictionary,
but modules it contains are properly registered, and will be visible by all
:class:`nn.Module` methods.
:class:`~torch.nn.Module` methods.
:class:`nn.ModuleDict` is an **ordered** dictionary that respects
:class:`~torch.nn.ModuleDict` is an **ordered** dictionary that respects
* the order of insertion, and
* in :meth:`nn.ModuleDict.update`, the order of the merged
* in :meth:`~torch. nn.ModuleDict.update`, the order of the merged
``OrderedDict``, ``dict`` (started from Python 3.6) or another
:class:`nn.ModuleDict` (the argument to
:meth:`nn.ModuleDict.update`).
:class:`~torch. nn.ModuleDict` (the argument to
:meth:`~torch. nn.ModuleDict.update`).
Note that :meth:`nn.ModuleDict.update` with other unordered mapping
Note that :meth:`~torch. nn.ModuleDict.update` with other unordered mapping
types (e.g., Python's plain ``dict`` before Python version 3.6) does not
preserve the order of the merged mapping.
@@ -530,7 +424,7 @@ class ModuleDict(Module):
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self ).__init__()
super().__init__()
self.choices = nn.ModuleDict({
'conv': nn.Conv2d(10, 10, 3),
'pool': nn.MaxPool2d(3)
@@ -546,42 +440,36 @@ class ModuleDict(Module):
return x
"""
def __init__(self, modules=None):
super(ModuleDict, self).__init__()
self.__cell_as_dict__ = True
_modules: Dict[str, Module] # type: ignore[assignment]
def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
super().__init__()
if modules is not None:
self.update(modules)
def __getitem__(self, key):
return self._cell s[key]
def __getitem__(self, key: str ) -> Module :
return self._module s[key]
def __setitem__(self, key, module):
self._update_cell_para_name(key, module)
def __setitem__(self, key: str, module: Module) -> None:
self.add_module(key, module)
def __delitem__(self, key):
del self._cell s[key]
def __delitem__(self, key: str ) -> None :
del self._module s[key]
def __len__(self):
return len(self._cell s)
def __len__(self) -> int :
return len(self._module s)
def __iter__(self):
return iter(self._cell s)
def __iter__(self) -> Iterator[str] :
return iter(self._module s)
def __contains__(self, key):
return key in self._cell s
def __contains__(self, key: str ) -> bool :
return key in self._module s
def _update_cell_para_name(self, key, cell):
if self._auto_prefix:
prefix, _ = _get_prefix_and_index(self._cells)
cell.update_parameters_name(prefix + key + ".")
def clear(self) -> None:
"""Remove all items from the ModuleDict."""
self._modules.clear()
def clear(self):
"""Remove all items from the ModuleDict.
"""
self._cells.clear()
def pop(self, key):
def pop(self, key: str) -> Module:
r"""Remove key from the ModuleDict and return its module.
Args:
@@ -591,32 +479,28 @@ class ModuleDict(Module):
del self[key]
return v
def keys(self):
r"""Return an iterable of the ModuleDict keys.
"""
return self._cells.keys()
def keys(self) -> Iterable[str]:
r"""Return an iterable of the ModuleDict keys."""
return self._modules.keys()
def items(self):
r"""Return an iterable of the ModuleDict key/value pairs.
"""
return self._cells.items()
def items(self) -> Iterable[Tuple[str, Module]]:
r"""Return an iterable of the ModuleDict key/value pairs."""
return self._modules.items()
def values(self):
r"""Return an iterable of the ModuleDict values.
"""
return self._cells.values()
def values(self) -> Iterable[Module]:
r"""Return an iterable of the ModuleDict values."""
return self._modules.values()
def update(self, modules):
r"""Update the :class:`nn.ModuleDict` with the key-value pairs from a
mapping or an iterable, overwriting existing keys.
def update(self, modules: Mapping[str, Module]) -> None:
r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys.
.. note::
If :attr:`modules` is an ``OrderedDict``, a :class:`nn.ModuleDict`, or
If :attr:`modules` is an ``OrderedDict``, a :class:`~torch. nn.ModuleDict`, or
an iterable of key-value pairs, the order of new elements in it is preserved.
Args:
modules (iterable): a mapping (dictionary) from string to :class:`nn.Module`,
or an iterable of key-value pairs of type (string, :class:`nn.Module`)
modules (iterable): a mapping (dictionary) from string to :class:`~torch. nn.Module`,
or an iterable of key-value pairs of type (string, :class:`~torch. nn.Module`)
"""
if not isinstance(modules, container_abcs.Iterable):
raise TypeError("ModuleDict.update should be called with an "
@@ -645,15 +529,15 @@ class ModuleDict(Module):
class ParameterList(Module):
"""Holds parameters in a list.
r """Holds parameters in a list.
:class:`nn.ParameterList` can be used like a regular Python
list, but Tensors that are :class:`nn.Parameter` are properly registered,
and will be visible by all :class:`nn.Module` methods.
:class:`~torch. nn.ParameterList` can be used like a regular Python
list, but Tensors that are :class:`~torch. nn.Parameter` are properly registered,
and will be visible by all :class:`~torch. nn.Module` methods.
Note that the constructor, assigning an element of the list, the
:meth:`nn.ParameterDict.append` method and the :meth:`nn.ParameterDict.extend`
method will convert any :class:`Tensor` into :class:`nn.Parameter`.
:meth:`~torch. nn.ParameterDict.append` method and the :meth:`~torch. nn.ParameterDict.extend`
method will convert any :class:`~torch. Tensor` into :class:`~torch. nn.Parameter`.
Args:
parameters (iterable, optional): an iterable of elements to add to the list.
@@ -662,8 +546,8 @@ class ParameterList(Module):
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self ).__init__()
self.params = nn.ParameterList([nn.Parameter(ms_ torch.randn(10, 10)) for i in range(10)])
super().__init__()
self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])
def forward(self, x):
# ParameterList can act as an iterable, or be indexed using ints
@@ -672,21 +556,29 @@ class ParameterList(Module):
return x
"""
def __init__(self, values=None):
super(ParameterList, self ).__init__()
def __init__(self, values: Optional[Iterable[Any]] = None) -> None :
super().__init__()
self._size = 0
if values is not None:
self += values
def _get_abs_string_index(self, idx):
"""Get the absolute index for the list of modules"""
"""Get the absolute index for the list of modules. """
idx = operator.index(idx)
if not -len(self) <= idx < len(self):
raise IndexError('index {} is out of range'.format(idx) )
if not ( -len(self) <= idx < len(self) ):
raise IndexError(f 'index {idx } is out of range')
if idx < 0:
idx += len(self)
return str(idx)
@overload
def __getitem__(self, idx: int) -> Any:
...
@overload
def __getitem__(self: T, idx: slice) -> T:
...
def __getitem__(self, idx):
if isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
@@ -698,33 +590,33 @@ class ParameterList(Module):
idx = self._get_abs_string_index(idx)
return getattr(self, str(idx))
def __setitem__(self, idx, param):
def __setitem__(self, idx: int , param: Any ) -> None :
# Note that all other function that add an entry to the list part of
# the ParameterList end up here. So this is the only place where we need
# to wrap things into Parameter if needed.
# Objects added via setattr() are not in the list part and thus won't
# call into this function.
idx = self._get_abs_string_index(idx)
if isinstance(param, Tensor) and not isinstance(param, Parameter):
if isinstance(param, torch. Tensor) and not isinstance(param, Parameter):
param = Parameter(param)
return setattr(self, str(idx), param)
def __len__(self):
def __len__(self) -> int :
return self._size
def __iter__(self):
def __iter__(self) -> Iterator[Any] :
return iter(self[i] for i in range(len(self)))
def __iadd__(self, parameters):
def __iadd__(self, parameters: Iterable[Any] ) -> Self :
return self.extend(parameters)
def __dir__(self):
keys = super(ParameterList, self ).__dir__()
keys = super().__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys
def append(self, value):
"""Appends a given value at the end of the list.
def append(self, value: Any ) -> 'ParameterList' :
"""Append a given value at the end of the list.
Args:
value (Any): value to append
@@ -734,26 +626,26 @@ class ParameterList(Module):
self[new_idx] = value
return self
def extend(self, values):
"""Appends values from a Python iterable to the end of the list.
def extend(self, values: Iterable[Any] ) -> Self :
"""Append values from a Python iterable to the end of the list.
Args:
values (iterable): iterable of values to append
"""
# Tensor is an iterable but we never want to unpack it here
if not isinstance(values, container_abcs.Iterable) or isinstance(values, Tensor):
if not isinstance(values, container_abcs.Iterable) or isinstance(values, torch. Tensor):
raise TypeError("ParameterList.extend should be called with an "
"iterable, but got " + type(values).__name__)
for value in values:
self.append(value)
return self
def extra_repr(self):
def extra_repr(self) -> str :
child_lines = []
for k, p in enumerate(self):
if isinstance(p, Tensor):
if isinstance(p, torch. Tensor):
size_str = 'x'.join(str(size) for size in p.size())
device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
device_str = ''
parastr = '{} containing: [{} of size {}{}]'.format(
"Parameter" if isinstance(p, Parameter) else "Tensor",
p.dtype, size_str, device_str)
@@ -767,31 +659,23 @@ class ParameterList(Module):
def __call__(self, *args, **kwargs):
raise RuntimeError('ParameterList should not be called.')
# adpater api, to convert ParameterList to list[Parameter]
def to_list(self):
list_params = []
for i, p in enumerate(self):
p.name = str(i) + "." + p.name
list_params.append(p)
return list_params
class ParameterDict(Module):
"""Holds parameters in a dictionary.
r"""Holds parameters in a dictionary.
ParameterDict can be indexed like a regular Python dictionary, but Parameters it
contains are properly registered, and will be visible by all Module methods.
Other objects are treated as would be done by a regular Python dictionary
:class:`nn.ParameterDict` is an **ordered** dictionary.
:meth:`nn.ParameterDict.update` with other unordered mapping
:class:`~torch. nn.ParameterDict` is an **ordered** dictionary.
:meth:`~torch. nn.ParameterDict.update` with other unordered mapping
types (e.g., Python's plain ``dict``) does not preserve the order of the
merged mapping. On the other hand, ``OrderedDict`` or another :class:`nn.ParameterDict`
merged mapping. On the other hand, ``OrderedDict`` or another :class:`~torch. nn.ParameterDict`
will preserve their ordering.
Note that the constructor, assigning an element of the dictionary and the
:meth:`nn.ParameterDict.update` method will convert any :class:`Tensor` into
:class:`nn.Parameter`.
:meth:`~torch. nn.ParameterDict.update` method will convert any :class:`~torch. Tensor` into
:class:`~torch. nn.Parameter`.
Args:
values (iterable, optional): a mapping (dictionary) of
@@ -802,10 +686,10 @@ class ParameterDict(Module):
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self ).__init__()
super().__init__()
self.params = nn.ParameterDict({
'left': nn.Parameter(ms_ torch.randn(5, 10)),
'right': nn.Parameter(ms_ torch.randn(5, 10))
'left': nn.Parameter(torch.randn(5, 10)),
'right': nn.Parameter(torch.randn(5, 10))
})
def forward(self, x, choice):
@@ -813,13 +697,13 @@ class ParameterDict(Module):
return x
"""
def __init__(self, parameters = None):
super(ParameterDict, self ).__init__()
def __init__(self, parameters: Any = None) -> None :
super().__init__()
self._keys: Dict[str, None] = {}
if parameters is not None:
self.update(parameters)
def _key_to_attr(self, key):
def _key_to_attr(self, key: str ) -> str :
if not isinstance(key, str):
raise TypeError("Index given to ParameterDict cannot be used as a key as it is "
f"not a string (type is '{type(key).__name__}'). Open an issue on "
@@ -828,11 +712,11 @@ class ParameterDict(Module):
# Use the key as-is so that `.named_parameters()` returns the right thing
return key
def __getitem__(self, key):
def __getitem__(self, key: str ) -> Any :
attr = self._key_to_attr(key)
return getattr(self, attr)
def __setitem__(self, key, value):
def __setitem__(self, key: str , value: Any ) -> None :
# Note that all other function that add an entry to the dictionary part of
# the ParameterDict end up here. So this is the only place where we need
# to wrap things into Parameter if needed.
@@ -840,36 +724,37 @@ class ParameterDict(Module):
# call into this function.
self._keys[key] = None
attr = self._key_to_attr(key)
if isinstance(value, Tensor) and not isinstance(value, Parameter):
if isinstance(value, torch. Tensor) and not isinstance(value, Parameter):
value = Parameter(value)
setattr(self, attr, value)
def __delitem__(self, key):
def __delitem__(self, key: str ) -> None :
del self._keys[key]
attr = self._key_to_attr(key)
delattr(self, attr)
def __len__(self):
def __len__(self) -> int :
return len(self._keys)
def __iter__(self):
def __iter__(self) -> Iterator[str] :
return iter(self._keys)
def __reversed__(self):
def __reversed__(self) -> Iterator[str] :
return reversed(list(self._keys))
def copy(self):
"""Returns a copy of this :class:`nn.ParameterDict` instance.
"""
def copy(self) -> 'ParameterDict':
"""Return a copy of this :class:`~torch.nn.ParameterDict` instance."""
# We have to use an OrderedDict because the ParameterDict constructor
# behaves differently on plain dict vs OrderedDict
return ParameterDict(OrderedDict((k, self[k]) for k in self._keys))
def __contains__(self, key):
def __contains__(self, key: str ) -> bool :
return key in self._keys
def setdefault(self, key, default = None):
"""If key is in the ParameterDict, return its value.
def setdefault(self, key: str, default: Optional[Any] = None) -> Any:
"""Set the default for a key in the Parameterdict.
If key is in the ParameterDict, return its value.
If not, insert `key` with a parameter `default` and return `default`.
`default` defaults to `None`.
@@ -877,18 +762,16 @@ class ParameterDict(Module):
key (str): key to set default for
default (Any): the parameter set to the key
"""
if key not in self:
self[key] = default
return self[key]
def clear(self):
"""Remove all items from the ParameterDict.
"""
def clear(self) -> None:
"""Remove all items from the ParameterDict."""
for k in self._keys.copy():
del self[k]
def pop(self, key):
def pop(self, key: str ) -> Any :
r"""Remove key from the ParameterDict and return its parameter.
Args:
@@ -898,10 +781,8 @@ class ParameterDict(Module):
del self[key]
return v
def popitem(self):
"""Remove and return the last inserted `(key, parameter)` pair
from the ParameterDict
"""
def popitem(self) -> Tuple[str, Any]:
"""Remove and return the last inserted `(key, parameter)` pair from the ParameterDict."""
k, _ = self._keys.popitem()
# We need the key in the _keys to be able to access/del
self._keys[k] = None
@@ -909,9 +790,8 @@ class ParameterDict(Module):
del self[k]
return k, val
def get(self, key, default = None):
r"""Return the parameter associated with key if present.
Otherwise return default if provided, None if not.
def get(self, key: str, default: Optional[Any] = None) -> Any:
r"""Return the parameter associated with key if present. Otherwise return default if provided, None if not.
Args:
key (str): key to get from the ParameterDict
@@ -919,42 +799,38 @@ class ParameterDict(Module):
"""
return self[key] if key in self else default
def fromkeys(self, keys, default = None):
r"""Return a new ParameterDict with the keys provided
def fromkeys(self, keys: Iterable[str] , default: Optional[Any] = None) -> 'ParameterDict' :
r"""Return a new ParameterDict with the keys provided.
Args:
keys (iterable, string): keys to make the new ParameterDict from
default (Parameter, optional): value to set for all keys
"""
return ParameterDict((( k, default) for k in keys))
return ParameterDict((k, default) for k in keys)
def keys(self):
r"""Return an iterable of the ParameterDict keys.
"""
def keys(self) -> Iterable[str]:
r"""Return an iterable of the ParameterDict keys."""
return self._keys.keys()
def items(self):
r"""Return an iterable of the ParameterDict key/value pairs.
"""
def items(self) -> Iterable[Tuple[str, Any]]:
r"""Return an iterable of the ParameterDict key/value pairs."""
return ((k, self[k]) for k in self._keys)
def values(self):
r"""Return an iterable of the ParameterDict values.
"""
def values(self) -> Iterable[Any]:
r"""Return an iterable of the ParameterDict values."""
return (self[k] for k in self._keys)
def update(self, parameters):
r"""Update the :class:`~nn.ParameterDict` with the key-value pairs from a
mapping or an iterable, overwriting existing keys.
def update(self, parameters: Union[Mapping[str, Any], 'ParameterDict']) -> None:
r"""Update the :class:`~torch.nn.ParameterDict` with key-value pairs from ``parameters``, overwriting existing keys.
.. note::
If :attr:`parameters` is an ``OrderedDict``, a :class:`~nn.ParameterDict`, or
If :attr:`parameters` is an ``OrderedDict``, a :class:`~torch. nn.ParameterDict`, or
an iterable of key-value pairs, the order of new elements in it is preserved.
Args:
parameters (iterable): a mapping (dictionary) from string to
:class:`~nn.Parameter`, or an iterable of
key-value pairs of type (string, :class:`~nn.Parameter`)
:class:`~torch. nn.Parameter`, or an iterable of
key-value pairs of type (string, :class:`~torch. nn.Parameter`)
"""
if not isinstance(parameters, container_abcs.Iterable):
raise TypeError("ParametersDict.update should be called with an "
@@ -980,15 +856,15 @@ class ParameterDict(Module):
# parameters as length-2 list too cumbersome to type, see ModuleDict.update comment
self[p[0]] = p[1] # type: ignore[assignment]
def extra_repr(self):
def extra_repr(self) -> str :
child_lines = []
for k, p in self.items():
if isinstance(p, Tensor):
if isinstance(p, torch. Tensor):
size_str = 'x'.join(str(size) for size in p.size())
device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
device_str = ''
parastr = '{} containing: [{} of size {}{}]'.format(
"Parameter" if isinstance(p, Parameter) else "Tensor",
typename(p), size_str, device_str)
torch.t ypename(p), size_str, device_str)
child_lines.append(' (' + str(k) + '): ' + parastr)
else:
child_lines.append(' (' + str(k) + '): Object of type: ' + type(p).__name__)
@@ -998,22 +874,16 @@ class ParameterDict(Module):
def __call__(self, input):
raise RuntimeError('ParameterDict should not be called.')
def __or__(self, other):
def __or__(self, other: 'ParameterDict' ) -> 'ParameterDict' :
copy = self.copy()
copy.update(other)
return copy
def __ror__(self, other):
def __ror__(self, other: 'ParameterDict' ) -> 'ParameterDict' :
copy = other.copy()
copy.update(self)
return copy
def __ior__(self, other):
def __ior__(self, other : 'ParameterDict' ) -> Self :
self.update(other)
return self
def to_dict(self):
new_dict = {}
for key in self._keys:
new_dict[key] = self[key]
return new_dict
没有使用到的导入包请删除