#55 Fix format

Merged
laich merged 1 commits from add_layer_26 into master 1 year ago
  1. +10
    -10
      ms_adapter/pytorch/nn/modules/container.py

+ 10
- 10
ms_adapter/pytorch/nn/modules/container.py View File

@@ -96,14 +96,14 @@ class Sequential(Module):
index = _valid_index(len(self), index, self.__class__.__name__)
return list(self._cells.values())[index]

def __setitem__(self, index, cell):
def __setitem__(self, index, module):
cls_name = self.__class__.__name__
if _valid_cell(cell, cls_name):
if _valid_cell(module, cls_name):
prefix, _ = _get_prefix_and_index(self._cells)
index = _valid_index(len(self), index, cls_name)
key = list(self._cells.keys())[index]
self._cells[key] = cell
cell.update_parameters_name(prefix + key + ".")
self._cells[key] = module
module.update_parameters_name(prefix + key + ".")
self.cell_list = list(self._cells.values())

def __delitem__(self, index):
@@ -142,12 +142,12 @@ class Sequential(Module):
for cell in self._cells.values():
cell.set_grad(flag)

def append(self, cell):
def append(self, module):
"""
Appends a given Module to the end of the list.

Args:
cell(Module): The Module to be appended.
module(Module): The Module to be appended.

Examples:
>>> conv = nn.Conv2d(3, 2, 3, pad_mode='valid', weight_init="ones")
@@ -163,11 +163,11 @@ class Sequential(Module):
[[26.999863 26.999863]
[26.999863 26.999863]]]]
"""
if _valid_cell(cell, self.__class__.__name__):
if _valid_cell(module, self.__class__.__name__):
prefix, _ = _get_prefix_and_index(self._cells)
cell.update_parameters_name(prefix + str(len(self)) + ".")
module.update_parameters_name(prefix + str(len(self)) + ".")
self._is_dynamic_name.append(True)
self._cells[str(len(self))] = cell
self._cells[str(len(self))] = module
self.cell_list = list(self._cells.values())

def add_module(self, name, module):
@@ -350,7 +350,7 @@ class ModuleList(_ModuleListBase, Module):
Appends a given Module to the end of the list.

Args:
cell(Module): The subcell to be appended.
module(Module): The subcell to be appended.
"""
if _valid_cell(module, self.__class__.__name__):
if self._auto_prefix:


Loading…
Cancel
Save