直接创建一个tensor的方法快一点吧?
@@ -0,0 +1,20 @@ | |||
from ._utils import _set_obj_state | |||
def _rebuild_from_type(func, type, args, dict): | |||
from mindtorch.torch.tensor import Tensor # pylint: disable=R0401, C0415 | |||
if type is Tensor: | |||
return func(*args) | |||
ret = func(*args).as_subclass(type) | |||
hanjr marked this conversation as resolved
|
|||
ret.__dict__ = dict | |||
return ret | |||
def _rebuild_from_type_v2(func, new_type, args, state): | |||
from mindtorch.torch.tensor import Tensor # pylint: disable=R0401, C0415 | |||
ret = func(*args) | |||
if not isinstance(ret, new_type): | |||
ret = ret.as_subclass(new_type) | |||
if getattr(ret.__class__, "__setstate__", Tensor.__setstate__) is not Tensor.__setstate__: | |||
ret.__setstate__(state) | |||
else: | |||
ret = _set_obj_state(ret, state) | |||
return ret |
@@ -1,5 +1,6 @@ | |||
import sys | |||
import traceback | |||
import copyreg | |||
import mindtorch.torch.common.dtype as _dtype | |||
from mindtorch.torch.common.dtype import finfo, iinfo | |||
from mindtorch.utils import unsupported_attr | |||
@@ -51,6 +52,53 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs): | |||
warning("'async' is deprecated; use 'non_blocking'") | |||
return kwargs['async'] | |||
def _rebuild_tensor(storage, storage_offset, size, stride): | |||
unsupported_attr(stride) | |||
from mindtorch.torch.tensor import tensor # pylint: disable=R0401, C0415 | |||
t = tensor([], dtype=storage.dtype, device=storage._untyped().device) | |||
return t.set_(storage._untyped(), storage_offset, size) | |||
zoulq commented 1 month ago
Review
直接创建一个tensor的方法快一点吧? 直接创建一个tensor的方法快一点吧?
Erpim commented 3 weeks ago
Review
不可以直接创建tensor,加载场景是通过修改storage的值,同步改变tensor的值,如果直接场景tensor,外部storage和tensor直接没有建立连接,不会同步更新 不可以直接创建tensor,加载场景是通过修改storage的值,同步改变tensor的值,如果直接场景tensor,外部storage和tensor直接没有建立连接,不会同步更新
Erpim commented 3 weeks ago
Review
不可以直接创建tensor,加载场景是通过修改storage的值,同步改变tensor的值,如果直接场景tensor,外部storage和tensor直接没有建立连接,不会同步更新 不可以直接创建tensor,加载场景是通过修改storage的值,同步改变tensor的值,如果直接场景tensor,外部storage和tensor直接没有建立连接,不会同步更新
|
|||
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None): | |||
unsupported_attr(backward_hooks) | |||
unsupported_attr(metadata) | |||
tensor = _rebuild_tensor(storage, storage_offset, size, stride) | |||
tensor.requires_grad = requires_grad | |||
return tensor | |||
def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad): | |||
from mindtorch.torch.functional import from_numpy # pylint: disable=R0401, C0415 | |||
tensor = from_numpy(data).to(dtype=dtype, device=device) | |||
tensor.requires_grad = requires_grad | |||
return tensor | |||
def _rebuild_parameter(data, requires_grad, backward_hooks): | |||
unsupported_attr(backward_hooks) | |||
from mindtorch.torch.nn import Parameter # pylint: disable=R0401, C0415 | |||
param = Parameter(data, requires_grad) | |||
param.set_(data.storage()._untyped(), 0, data.size()) | |||
return param | |||
def _rebuild_parameter_with_state(data, requires_grad, backward_hooks, state): | |||
unsupported_attr(backward_hooks) | |||
from mindtorch.torch.nn import Parameter # pylint: disable=R0401, C0415 | |||
param = Parameter(data, requires_grad) | |||
param._backward_hooks = backward_hooks | |||
param = _set_obj_state(param, state) | |||
return param | |||
def _rebuild_mindtorch_parameter(data, requires_grad, name, layerwise_parallel): | |||
from mindtorch.torch.nn import Parameter # pylint: disable=R0401, C0415 | |||
param = Parameter(data, requires_grad, name, layerwise_parallel) | |||
return param | |||
def _rebuild_mindtorch_parameter_with_state(data, requires_grad, name, layerwise_parallel, state): | |||
from mindtorch.torch.nn import Parameter # pylint: disable=R0401, C0415 | |||
param = Parameter(data, requires_grad, name, layerwise_parallel) | |||
param = _set_obj_state(param, state) | |||
return param | |||
def _import_dotted_name(name): | |||
components = name.split('.') | |||
obj = __import__(components[0]) | |||
@@ -127,3 +175,42 @@ def _unflatten_dense_tensors(flat, tensors): | |||
unsupported_attr(flat) | |||
unsupported_attr(tensors) | |||
raise NotImplementedError("`_unflatten_dense_tensors` is not implemented now.") | |||
def _set_obj_state(obj, state): | |||
if isinstance(state, tuple): | |||
if not len(state) == 2: | |||
raise RuntimeError(f"Invalid serialized state: {state}") | |||
dict_state = state[0] | |||
slots_state = state[1] | |||
else: | |||
dict_state = state | |||
slots_state = None | |||
if dict_state: | |||
for k, v in dict_state.items(): | |||
setattr(obj, k, v) | |||
if slots_state: | |||
for k, v in slots_state.items(): | |||
setattr(obj, k, v) | |||
return obj | |||
def _get_obj_state(obj): | |||
getstate_fn = getattr(obj, "__getstate__", None) | |||
if getstate_fn is not None: | |||
state = getstate_fn() | |||
else: | |||
slots_to_save = copyreg._slotnames(obj.__class__) # type: ignore[attr-defined] | |||
if slots_to_save: | |||
state = ( | |||
obj.__dict__, | |||
{ | |||
name: getattr(obj, name) | |||
for name in slots_to_save | |||
if hasattr(obj, name) | |||
}, | |||
) | |||
else: | |||
state = obj.__dict__.copy() | |||
return state |
@@ -12,7 +12,11 @@ except ImportError: | |||
import mindspore as ms | |||
from mindspore import ops | |||
from mindspore.common import dtype as mstype | |||
from mindspore.scipy.ops import SolveTriangular | |||
try: | |||
from mindspore.scipy.ops import SolveTriangular # not support on win cpu | |||
except ImportError: | |||
# do nothings here. | |||
... | |||
from mindspore.ops.primitive import _primexpr | |||
from mindspore.ops._primitive_cache import _get_cache_prim | |||
from mindspore._c_expression import Tensor as ms_Tensor_ | |||
@@ -3,7 +3,10 @@ | |||
import mindspore as ms | |||
from mindspore.ops.primitive import _primexpr | |||
from mindspore.scipy.ops import SolveTriangular | |||
try: | |||
from mindspore.scipy.ops import SolveTriangular# not support on win cpu | |||
except ImportError: | |||
... | |||
from mindtorch.torch.common._inner import _out_inplace_assign | |||
from mindtorch.utils import unsupported_attr, pynative_mode_condition, \ | |||
is_under_gpu_context, is_under_ascend_context, set_multiple_name_tuple | |||
@@ -15,8 +15,7 @@ from mindspore.parallel._ps_context import _insert_accumu_init_info | |||
from mindtorch.torch.tensor import Tensor, cast_to_ms_tensor, cast_to_adapter_tensor | |||
from mindtorch.torch.common.dtype import _msdtype2typeDict | |||
from mindtorch.torch.functional import empty as torch_empty | |||
from mindtorch.utils import graph_mode_condition | |||
from mindtorch.torch import _utils | |||
__all__ = ['Parameter', 'ParameterTuple', 'UninitializedParameter', 'UninitializedBuffer'] | |||
def init_to_value(init): | |||
@@ -41,7 +40,11 @@ def init_to_value(init): | |||
class Parameter(ms.Parameter): | |||
_base_type = {} | |||
def __new__(cls, data, *args, **kwargs): | |||
def __new__(cls, data=None, requires_grad=True, name=None, layerwise_parallel=False, # pylint: disable = W0613 | |||
parallel_optimizer=True): # pylint: disable = W0613 | |||
if data is None: | |||
data = 1 | |||
init_data_flag = bool(isinstance(data, ms.Tensor) and data.has_init) | |||
rc = sys.getrefcount(data) | |||
input_class, *class_init_args = Parameter._get_parameter_new_args(data, rc) | |||
@@ -55,16 +58,22 @@ class Parameter(ms.Parameter): | |||
return obj | |||
def __reduce_ex__(self, _): | |||
data = self | |||
state = _utils._get_obj_state(self) | |||
if self.init_mode is not None: | |||
data = self.init_mode | |||
else: | |||
# cast to break deep infinite loop while deepcopy | |||
data = ms.Tensor(self) | |||
return ( | |||
Parameter, (data, self.requires_grad, self.name, self.layerwise_parallel)) | |||
if not state: | |||
return (_utils._rebuild_mindtorch_parameter, (data, self.requires_grad, self.name, | |||
self.layerwise_parallel)) | |||
return (_utils._rebuild_mindtorch_parameter_with_state, (data, self.requires_grad, self.name, | |||
self.layerwise_parallel, state)) | |||
def __init__(self, data, requires_grad=True, name=None, layerwise_parallel=False, parallel_optimizer=True): | |||
def __init__(self, data=None, requires_grad=True, name=None, layerwise_parallel=False, parallel_optimizer=True): | |||
if data is None: | |||
data = 1 | |||
self.adapter_flag = True | |||
super().__init__(default_input=data, name=name, requires_grad=requires_grad, | |||
layerwise_parallel=layerwise_parallel, parallel_optimizer=parallel_optimizer) | |||
@@ -185,23 +194,22 @@ class Parameter(ms.Parameter): | |||
def shape(self): | |||
return self._shape | |||
def set_(self, source=None, storage_offset=0, size=None, stride=None): | |||
if storage_offset or size or stride: | |||
raise ValueError("Currently, `Parameter.set_` specifying `storage_offset`, " | |||
"`size` or `stride` are not supported.") | |||
if source is None: | |||
raise ValueError("Currently, `Parameter.set_` only supported specify the `source`, " \ | |||
"please ensure that it is not None.") | |||
if graph_mode_condition(): | |||
raise RuntimeError('`Parameter.set_` is an in-place operation and "x.set_()" is not supported to use ' | |||
'in MindSpore static graph mode.') | |||
source = cast_to_ms_tensor(source) | |||
self.set_data(source, True) | |||
return self | |||
def __setstate__(self, state): | |||
if isinstance(state, tuple): | |||
if len(state) == 4: | |||
self.set_(*state) | |||
return | |||
elif len(state) == 5: | |||
data = state[0] | |||
Parameter.__init__(self, data, requires_grad=state[3]) | |||
self.set_dtype(data.dtype) | |||
self.set_data(data=data, slice_shape=True) | |||
self._requires_grad = state[3] | |||
return | |||
def __getstate__(self): | |||
state = {key: value for key, value in self.__dict__.items() if key not in Parameter().__dict__} | |||
return state | |||
def _init_parameter_api(): | |||
param_func = dir(Parameter) | |||
@@ -2,6 +2,7 @@ | |||
# pylint: disable=unused-argument | |||
# pylint: disable=eval-used | |||
# pylint: disable=broad-except | |||
import difflib | |||
import os | |||
import io | |||
import struct | |||
@@ -11,23 +12,21 @@ import pathlib | |||
import shutil | |||
import zipfile | |||
import tarfile | |||
import warnings | |||
import tempfile | |||
import operator | |||
import inspect | |||
from functools import reduce | |||
from dataclasses import dataclass | |||
from enum import Enum | |||
from contextlib import closing, contextmanager | |||
from collections.abc import Mapping, Sequence | |||
from typing import Any, BinaryIO, Union, IO, Optional, Type, Dict, Tuple | |||
from typing_extensions import TypeAlias | |||
from ml_dtypes import bfloat16 | |||
import numpy as np | |||
from mindtorch.module_hooker import torch_disable, torch_pop | |||
from mindtorch.torch import _utils | |||
from mindtorch.torch.storage import _UntypedStorage, _TypedStorage | |||
from mindtorch.torch.tensor import tensor, Tensor | |||
from mindtorch.torch.nn.modules.module import Module, Parameter | |||
from mindtorch.torch.nn.modules.module import Module | |||
from mindtorch.torch.logging import warning | |||
import mindtorch.torch.common.dtype as _dtype | |||
from mindtorch.torch.storage import _get_dtype_from_pickle_storage_type | |||
DEFAULT_PROTOCOL = 2 | |||
LONG_SIZE = struct.Struct('=l').size | |||
@@ -46,41 +45,29 @@ __all__ = [ | |||
'load', | |||
] | |||
dtype_map = { | |||
"HalfStorage": np.float16, | |||
"FloatStorage": np.float32, | |||
'BFloat16Storage': bfloat16, | |||
'LongStorage': np.int64, | |||
'ByteStorage': np.uint8, | |||
'BoolStorage': np.bool_, | |||
'IntStorage': np.int32, | |||
'ShortStorage': np.int16, | |||
'CharStorage': np.int8, | |||
'DoubleStorage': np.float64, | |||
} | |||
_storage_classes_dict = {_dtype.double: "DoubleStorage", | |||
_dtype.float: "FloatStorage", | |||
_dtype.half: "HalfStorage", | |||
_dtype.long: "LongStorage", | |||
_dtype.int: "IntStorage", | |||
_dtype.int16: "ShortStorage", | |||
_dtype.int8: "CharStorage", | |||
_dtype.uint8: "ByteStorage", | |||
_dtype.bool: "BoolStorage", | |||
_dtype.bfloat16: "BFloat16Storage", | |||
_dtype.cdouble: "ComplexDoubleStorage", | |||
_dtype.cfloat: "ComplexFloatStorage", | |||
} | |||
element_size_map = { | |||
"HalfStorage": 2, | |||
"FloatStorage": 3, | |||
'BFloat16Storage': 2, | |||
'LongStorage': 4, | |||
'ByteStorage': 1, | |||
'BoolStorage': 1 | |||
} | |||
def typename(o): | |||
if isinstance(o, Tensor): | |||
return o.type() | |||
module = '' | |||
class_name = '' | |||
if hasattr(o, '__module__') and o.__module__ != 'builtins' \ | |||
and o.__module__ != '__builtin__' and o.__module__ is not None: | |||
module = o.__module__ + '.' | |||
if hasattr(o, '__qualname__'): | |||
class_name = o.__qualname__ | |||
elif hasattr(o, '__name__'): | |||
class_name = o.__name__ | |||
else: | |||
class_name = o.__class__.__name__ | |||
return module + class_name | |||
class SourceChangeWarning(Warning): | |||
pass | |||
def get_source_lines_and_file(obj, error_msg = None) : | |||
try: | |||
@@ -106,6 +93,12 @@ def mkdtemp(): | |||
finally: | |||
shutil.rmtree(path) | |||
class _HasStorage: | |||
def __init__(self, storage): | |||
self._storage = storage | |||
def storage(self): | |||
return self._storage | |||
class PyTorchFileReader: | |||
@@ -141,6 +134,11 @@ class PyTorchFileReader: | |||
return self.file.getinfo(filename).header_offset | |||
return None | |||
def get_storage_from_record(self, name, numel, dtype): | |||
filename = f"{self.directory}/{name}" | |||
storage = _UntypedStorage | |||
return _HasStorage(storage.from_buffer(self.read_record(name))) | |||
class PyTorchFileWriter: | |||
def __init__(self, file): | |||
@@ -338,44 +336,9 @@ def _should_read_directly(f): | |||
except AttributeError: | |||
return False | |||
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None): | |||
if size == (): | |||
size = () | |||
stride = (1,) | |||
num_elemets = 1 | |||
else: | |||
num_elemets = reduce(operator.mul, size) | |||
array = storage[storage_offset: storage_offset + num_elemets] | |||
origin_dtype = None | |||
if array.dtype == bfloat16: | |||
origin_dtype = 'bfloat16' | |||
array = array.astype(np.float32) | |||
if stride is not None and len(stride) > 1 and stride[0] == 1 and stride[1] > 1: | |||
stride = tuple((s * 4 for s in stride)) | |||
array = np.lib.stride_tricks.as_strided(array, size, stride) | |||
else: | |||
order = "C" | |||
array = array.reshape(size, order=order) | |||
if origin_dtype == 'bfloat16': | |||
return tensor(array, dtype=_dtype.bfloat16) | |||
param = tensor(array) | |||
return param | |||
def _rebuild_parameter(data, requires_grad, backward_hooks): | |||
param = Parameter(data, requires_grad) | |||
param._backward_hooks = backward_hooks | |||
return param | |||
@dataclass | |||
class FakeParameter: | |||
storage: np.ndarray = None | |||
storage_offset: int = None | |||
size: tuple = None | |||
requires_grad: bool = None | |||
def _rebuild_tensor_legacy(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None): | |||
return FakeParameter(storage, storage_offset, size, requires_grad) | |||
def normalize_storage_type(storage_type): | |||
import mindtorch.torch as ms_torch # pylint: disable=R0401, C0415 | |||
return getattr(ms_torch, storage_type.__name__) | |||
def _maybe_decode_ascii(bytes_str: Union[bytes, str]): | |||
if isinstance(bytes_str, bytes): | |||
@@ -430,21 +393,34 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol): | |||
# source_lines, _, source_file = get_source_lines_and_file(obj) | |||
# source = ''.join(source_lines) | |||
# except Exception: | |||
# warnings.warn("Couldn't retrieve source code for container of " | |||
# warning("Couldn't retrieve source code for container of " | |||
# "type " + obj.__name__ + ". It won't be checked " | |||
# "for correctness upon loading.") | |||
# return ('module', obj, source_file, source) | |||
raise NotImplementedError("Do not support save module now. Please use torch.save to save model parameters." | |||
"If you want to save model parameters, " | |||
"please use 'torch.save(net.state_dict(), filename)'") | |||
from mindtorch.torch import is_storage # pylint: disable=R0401, C0415 | |||
if isinstance(obj, _TypedStorage) or is_storage(obj): | |||
storage = None | |||
if isinstance(obj, _TypedStorage): | |||
import mindtorch.torch as ms_torch # pylint: disable=R0401, C0415 | |||
storage = obj._storage | |||
storage_dtype = obj.dtype | |||
storage_type_str = obj.pickle_storage_type() | |||
storage_type = getattr(ms_torch, storage_type_str) | |||
dtype = obj.dtype | |||
storage_numel = obj.size() | |||
elif isinstance(obj, _UntypedStorage): | |||
storage = obj | |||
storage_dtype = _dtype.uint8 | |||
storage_type = normalize_storage_type(type(obj)) | |||
dtype = _dtype.uint8 | |||
storage_numel = storage.nbytes() | |||
else: | |||
raise TypeError(f'type not recognized: {type(obj)}') | |||
if isinstance(obj, (Parameter, Tensor)): | |||
storage = obj | |||
storage_dtype = obj.dtype | |||
storage_type = _storage_classes_dict[obj.dtype] | |||
storage_numel = obj.numel() | |||
storage_dataptr = id(storage) | |||
storage_dataptr = storage.data_ptr() | |||
if storage_dataptr != 0: | |||
if storage_dataptr in storage_dtypes: | |||
if storage_dtype != storage_dtypes[storage_dataptr]: | |||
@@ -455,20 +431,19 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol): | |||
storage_dtypes[storage_dataptr] = storage_dtype | |||
view_metadata: Optional[Tuple[str, int, int]] | |||
storage_key = id_map.setdefault(storage_dataptr, str(len(id_map))) | |||
offset = 0 | |||
storage_key = str(id(storage)) | |||
location = 'cpu' | |||
if storage_key not in serialized_storages: | |||
serialized_storages[storage_key] = (storage, obj.dtype) | |||
serialized_storages[storage_key] = (storage, dtype) | |||
view_metadata = None | |||
mindtorch_info = storage.shape | |||
res = ('storage', | |||
storage_type, | |||
storage_key, | |||
location, | |||
storage_numel, | |||
view_metadata, | |||
mindtorch_info) | |||
view_metadata) | |||
return res | |||
return None | |||
@@ -494,8 +469,7 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol): | |||
f.flush() | |||
for key in serialized_storage_keys: | |||
storage, dtype = serialized_storages[key] | |||
f.write(np.array(storage.numel(), dtype=np.uint64).tobytes()) | |||
f.write(storage.get_bytes()) | |||
storage._write_file(f, _should_read_directly(f), True, _utils._element_size(dtype)) | |||
def _save(obj, zip_file, pickle_module, pickle_protocol): | |||
serialized_storages = {} | |||
@@ -507,14 +481,24 @@ def _save(obj, zip_file, pickle_module, pickle_protocol): | |||
raise NotImplementedError("Do not support save module now. Please use torch.save to save model parameters." | |||
"If you want to save model parameters, " | |||
"please use 'torch.save(net.state_dict(), filename)'") | |||
if isinstance(obj, (Parameter, Tensor)): | |||
storage = obj | |||
storage_dtype = obj.dtype | |||
storage_type = _storage_classes_dict[obj.dtype] | |||
storage_numel = obj.numel() | |||
storage_shape = storage.shape | |||
storage_dataptr = id(storage) | |||
from mindtorch.torch import is_storage # pylint: disable=R0401, C0415 | |||
if isinstance(obj, _TypedStorage) or is_storage(obj): | |||
if isinstance(obj, _TypedStorage): | |||
import mindtorch.torch as ms_torch # pylint: disable=R0401, C0415 | |||
storage = obj._storage | |||
storage_dtype = obj.dtype | |||
storage_type_str = obj.pickle_storage_type() | |||
storage_type = getattr(ms_torch, storage_type_str) | |||
storage_numel = obj.size() | |||
else: | |||
storage = obj | |||
storage_dtype = _dtype.uint8 | |||
storage_type = normalize_storage_type(type(obj)) | |||
storage_numel = storage.nbytes() | |||
storage_dataptr = storage.data_ptr() | |||
if storage_dataptr != 0: | |||
if storage_dataptr in storage_dtypes: | |||
if storage_dtype != storage_dtypes[storage_dataptr]: | |||
@@ -524,19 +508,16 @@ def _save(obj, zip_file, pickle_module, pickle_protocol): | |||
else: | |||
storage_dtypes[storage_dataptr] = storage_dtype | |||
view_metadata: Optional[Tuple[str, int, int]] | |||
storage_key = id_map.setdefault(storage_dataptr, str(len(id_map))) | |||
storage_key = id_map.setdefault(id(storage), str(len(id_map))) | |||
location = 'cpu' | |||
if storage_key not in serialized_storages: | |||
serialized_storages[storage_key] = storage | |||
serialized_storages[storage_key] = storage | |||
return ('storage', | |||
storage_type, | |||
storage_key, | |||
location, | |||
storage_numel) | |||
res = ('storage', | |||
storage_type, | |||
storage_key, | |||
location, | |||
storage_numel, | |||
storage_shape) | |||
return res | |||
return None | |||
data_buf = io.BytesIO() | |||
@@ -549,9 +530,16 @@ def _save(obj, zip_file, pickle_module, pickle_protocol): | |||
for key in sorted(serialized_storages.keys()): | |||
name = f'archive/data/{key}' | |||
storage = serialized_storages[key] | |||
storage_data = storage.get_bytes() | |||
storage_data = storage.inner_data | |||
zip_file.write_record(name, storage_data) | |||
class StorageType(): | |||
def __init__(self, name): | |||
self.dtype = _get_dtype_from_pickle_storage_type(name) | |||
def __str__(self): | |||
return f'StorageType(dtype={self.dtype})' | |||
def load(f: FILE_LIKE, | |||
map_location=None, | |||
@@ -571,38 +559,129 @@ def load(f: FILE_LIKE, | |||
with _open_zipfile_reader(opened_file, ) as opened_zipfile: | |||
if _is_torchscript_zip(opened_zipfile): | |||
raise ValueError('do not support torchscript now') | |||
return _load(opened_zipfile, | |||
torch_disable() | |||
result = _load(opened_zipfile, | |||
pickle_module, | |||
overall_storage=overall_storage, | |||
**pickle_load_args) | |||
return _legacy_load(opened_file, pickle_module, **pickle_load_args) | |||
torch_pop() | |||
return result | |||
torch_disable() | |||
result = _legacy_load(opened_file, pickle_module, **pickle_load_args) | |||
zoulq commented 1 month ago
Review
现在还有要依赖pytorch的场景吗? 现在还有要依赖pytorch的场景吗?
hanjr commented 3 weeks ago
Review
没有依赖pytorch的场景,在load torch的权重时,会读取到保存的torch函数指针,这个地方只是为了保证这个函数指针一定指向mindtorch实现的同名函数位置。 没有依赖pytorch的场景,在load torch的权重时,会读取到保存的torch函数指针,这个地方只是为了保证这个函数指针一定指向mindtorch实现的同名函数位置。
|
|||
torch_pop() | |||
return result | |||
def _legacy_load(f, pickle_module, **pickle_load_args): | |||
deserialized_objects: Dict[int, Any] = {} | |||
class UnpicklerWrapper(pickle_module.Unpickler): | |||
def find_class(self, mod_name, name): | |||
if name == '_rebuild_tensor_v2': | |||
name = '_rebuild_tensor_legacy' | |||
if mod_name == 'torch._utils': | |||
return eval(name) | |||
if mod_name == 'torch': | |||
return str(name) | |||
if isinstance(name, str) and 'Storage' in name: | |||
try: | |||
return StorageType(name) | |||
except KeyError: | |||
pass | |||
return super().find_class(mod_name, name) | |||
def legacy_load(f): | |||
deserialized_objects: Dict[int, Any] = {} | |||
def _check_container_source(container_type, source_file, original_source): | |||
try: | |||
current_source = ''.join(get_source_lines_and_file(container_type)[0]) | |||
except Exception: | |||
warning("Couldn't retrieve source code for container of " | |||
"type " + container_type.__name__ + ". It won't be checked " | |||
"for correctness upon loading.") | |||
return | |||
if original_source != current_source: | |||
if container_type.dump_patches: | |||
file_name = container_type.__name__ + '.patch' | |||
diff = difflib.unified_diff(current_source.split('\n'), | |||
original_source.split('\n'), | |||
source_file, | |||
source_file, lineterm="") | |||
lines = '\n'.join(diff) | |||
try: | |||
with open(file_name, 'a+') as f: | |||
file_size = f.seek(0, 2) | |||
f.seek(0) | |||
if file_size == 0: | |||
f.write(lines) | |||
elif file_size != len(lines) or f.read() != lines: | |||
raise OSError | |||
msg = ("Saved a reverse patch to " + file_name + ". " | |||
"Run `patch -p0 < " + file_name + "` to revert your changes.") | |||
except OSError: | |||
msg = ("Tried to save a patch, but couldn't create a " | |||
"writable file " + file_name + ". Make sure it " | |||
"doesn't exist and your working directory is " | |||
"writable.") | |||
else: | |||
msg = ("you can retrieve the original source code by " | |||
"accessing the object's source attribute or set " | |||
"`torch.nn.Module.dump_patches = True` and use the " | |||
"patch tool to revert the changes.") | |||
msg = f"source code of class '{typename(container_type)}' has changed. {msg}" | |||
warning(msg, SourceChangeWarning) | |||
def legacy_load(file): | |||
deserialized_objects: Dict[int, Any] = {} | |||
def persistent_load(saved_id): | |||
if isinstance(saved_id, tuple): | |||
if all(saved_id[1:]): | |||
_check_container_source(*saved_id) #TODO | |||
return saved_id[0] | |||
return deserialized_objects[int(saved_id)] | |||
with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \ | |||
with closing(tarfile.open(fileobj=file, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \ | |||
mkdtemp() as tmpdir: | |||
raise ValueError('do not support legacy load for Pytorch.') | |||
tar.extract('storages', path=tmpdir) | |||
with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as _file: | |||
num_storages = pickle_module.load(_file, **pickle_load_args) | |||
for i in range(num_storages): | |||
args = pickle_module.load(_file, **pickle_load_args) | |||
key, location, storage_type = args | |||
dtype = storage_type.dtype | |||
element_size = _utils._element_size(dtype) | |||
nbytes = np.frombuffer(_file.read(8), np.int64).item() * element_size | |||
data = np.fromfile(_file, dtype=np.uint8, count=nbytes, offset=0) | |||
obj = _UntypedStorage.from_buffer(data) | |||
deserialized_objects[key] = _TypedStorage( | |||
wrap_storage=obj, | |||
dtype=dtype, | |||
_internal=True) | |||
storage_views = pickle_module.load(_file, **pickle_load_args) | |||
for target_cdata, root_cdata, offset, numel in storage_views: | |||
root = deserialized_objects[root_cdata] | |||
element_size = _utils._element_size(root.dtype) | |||
offset_bytes = offset * element_size | |||
deserialized_objects[target_cdata] = _TypedStorage( | |||
wrap_storage=root._untyped()[offset_bytes:offset_bytes + numel * element_size], | |||
dtype=root.dtype, | |||
_internal=True) | |||
tar.extract('tensors', path=tmpdir) | |||
with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as _file: | |||
num_tensors = pickle_module.load(_file, **pickle_load_args) | |||
for _ in range(num_tensors): | |||
args = pickle_module.load(_file, **pickle_load_args) | |||
key, storage_id, original_tensor_type = args | |||
storage = deserialized_objects[storage_id] | |||
ndim, = struct.unpack('<i', _file.read(4)) | |||
_file.read(4) | |||
numel = struct.unpack(f'<{ndim}q', _file.read(8 * ndim)) | |||
stride = struct.unpack(f'<{ndim}q', _file.read(8 * ndim)) | |||
storage_offset, = struct.unpack('<q', _file.read(8)) | |||
tmp_tensor = tensor([], dtype=storage.dtype).set_( | |||
storage._untyped(), storage_offset, numel, stride) | |||
deserialized_objects[key] = tmp_tensor | |||
pickle_file = tar.extractfile('pickle') | |||
unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args) | |||
unpickler.persistent_load = persistent_load | |||
result = unpickler.load() | |||
return result | |||
deserialized_objects = {} | |||
@@ -617,25 +696,29 @@ def _legacy_load(f, pickle_module, **pickle_load_args): | |||
"Do not support load module now. Please use 'torch.load' to load model parameters." | |||
"Model parameters should be saved in 'PyTorch' by 'torch.save(net.state_dict(), filename)'.") | |||
if typename == 'storage': | |||
if len(data) == 6: | |||
storage_type, root_key, location, numel, view_metadata, mindtorch_info = data | |||
else: | |||
storage_type, root_key, location, numel, view_metadata = data | |||
storage_type, root_key, location, numel, view_metadata = data | |||
location = _maybe_decode_ascii(location) | |||
dtype = storage_type.dtype | |||
nbytes = numel * _utils._element_size(dtype) | |||
if root_key not in deserialized_objects: | |||
typed_storage = np.empty(numel, dtype_map[storage_type]) | |||
deserialized_objects[root_key] = typed_storage | |||
else: | |||
typed_storage = deserialized_objects[root_key] | |||
obj = _UntypedStorage(nbytes) | |||
deserialized_objects[root_key] = _TypedStorage( | |||
wrap_storage=obj, dtype=dtype) | |||
typed_storage = deserialized_objects[root_key] | |||
if view_metadata is not None: | |||
view_key, offset, view_size = view_metadata | |||
offset_bytes = offset * _utils._element_size(dtype) | |||
view_size_bytes = view_size * _utils._element_size(dtype) | |||
if view_key not in deserialized_objects: | |||
deserialized_objects[view_key] = typed_storage[offset: offset + view_size] | |||
deserialized_objects[view_key] = _TypedStorage( | |||
wrap_storage=typed_storage._storage[offset_bytes:offset_bytes + view_size_bytes], | |||
dtype=dtype) | |||
res = deserialized_objects[view_key] | |||
else: | |||
res = typed_storage | |||
if mindtorch_info is not None: | |||
res = _rebuild_tensor_legacy(res, 0, mindtorch_info, None, False, None) | |||
return res | |||
raise RuntimeError(f"Unknown saved id type: {saved_id[0]}") | |||
@@ -668,55 +751,19 @@ def _legacy_load(f, pickle_module, **pickle_load_args): | |||
unpickler = UnpicklerWrapper(f, **pickle_load_args) | |||
unpickler.persistent_load = persistent_load | |||
result = unpickler.load() | |||
deserialized_storage_keys = pickle_module.load(f, **pickle_load_args) | |||
offset = f.tell() if f_should_read_directly else None | |||
for key in deserialized_storage_keys: | |||
assert key in deserialized_objects | |||
typed_storage = deserialized_objects[key] | |||
f.read(8) | |||
array = np.frombuffer(f.read(typed_storage.nbytes), typed_storage.dtype) | |||
typed_storage[:] = array | |||
if typed_storage.dtype == bfloat16: | |||
assert np.allclose(typed_storage.astype(np.float32), array.astype(np.float32)) | |||
else: | |||
assert np.allclose(typed_storage, array) | |||
typed_storage._storage._set_from_file( | |||
f, 0, f_should_read_directly, | |||
_utils._element_size(typed_storage.dtype)) | |||
if offset is not None: | |||
offset = f.tell() | |||
def result_convert(result): | |||
elem_type = type(result) | |||
if isinstance(result, FakeParameter): | |||
if result.size == (): | |||
num_elemets = 1 | |||
else: | |||
num_elemets = reduce(operator.mul, result.size) | |||
array = result.storage[result.storage_offset: result.storage_offset + num_elemets] | |||
array = array.reshape(result.size) | |||
if array.dtype == bfloat16: | |||
array = array.astype(np.float32) | |||
return tensor(array, dtype=_dtype.bfloat16) | |||
return tensor(array) | |||
elif isinstance(result, Mapping): | |||
try: | |||
return elem_type({key: result_convert(result[key]) for key in result}) | |||
except TypeError: | |||
return {key: result_convert(result[key]) for key in result} | |||
elif isinstance(result, tuple) and hasattr(result, '_fields'): | |||
return elem_type(*(result_convert(d) for d in result)) | |||
elif isinstance(result, (tuple, list)): | |||
return [result_convert(d) for d in result] | |||
elif isinstance(result, Sequence) and not isinstance(result, string_classes): | |||
try: | |||
return elem_type([result_convert(d) for d in result]) | |||
except TypeError: | |||
return [result_convert(d) for d in result] | |||
else: | |||
return result | |||
new_result = result_convert(result) | |||
return new_result | |||
return result | |||
def _load(zip_file, pickle_module, overall_storage=None, pickle_file='data.pkl', **pickle_load_args): | |||
@@ -740,13 +787,20 @@ def _load(zip_file, pickle_module, overall_storage=None, pickle_file='data.pkl', | |||
if not zip_file.has_record(byteordername) and \ | |||
get_default_load_endianness() is None and \ | |||
sys.byteorder == 'big': | |||
warnings.warn("The default load endianness for checkpoints without a byteorder mark " | |||
warning("The default load endianness for checkpoints without a byteorder mark " | |||
"on big endian machines was changed from 'native' to 'little' endian, " | |||
"to avoid this behavior please use " | |||
"torch.serialization.set_default_load_endianness to set " | |||
"the desired default load endianness", | |||
UserWarning) | |||
def load_tensor(dtype, numel, key, location): | |||
name = f'data/{key}' | |||
tmp_storage = zip_file.get_storage_from_record(name, numel, _UntypedStorage).storage()._untyped() | |||
loaded_storages[key] = _TypedStorage( | |||
wrap_storage=tmp_storage,dtype=dtype) | |||
def persistent_load(saved_id): | |||
assert isinstance(saved_id, tuple) | |||
typename = _maybe_decode_ascii(saved_id[0]) | |||
@@ -754,29 +808,17 @@ def _load(zip_file, pickle_module, overall_storage=None, pickle_file='data.pkl', | |||
assert typename == 'storage', \ | |||
f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'" | |||
shape = None | |||
if len(data) == 5: | |||
storage_type, key, location, numel, shape = data | |||
storage_type, key, location, numel = data | |||
if storage_type is _UntypedStorage: | |||
dtype = _dtype.uint8 | |||
else: | |||
storage_type, key, location, numel = data | |||
name = f'data/{key}' | |||
if name in loaded_storages: | |||
return loaded_storages[name] | |||
dtype = storage_type.dtype | |||
if overall_storage is not None: | |||
array = np.memmap(overall_storage, dtype=dtype_map[storage_type], | |||
offset=zip_file.open_record(name)._fileobj.tell(), shape=(numel,)) | |||
else: | |||
array = np.frombuffer(zip_file.read_record(name), dtype_map[storage_type]) | |||
if shape is not None: | |||
array = np.reshape(array, shape) | |||
if dtype_map[storage_type] == bfloat16: | |||
array = array.astype(np.float32) | |||
array = tensor(array, dtype=_dtype.bfloat16) | |||
else: | |||
array = tensor(array) | |||
loaded_storages[name] = array | |||
return array | |||
if key not in loaded_storages: | |||
nbytes = numel * _utils._element_size(dtype) | |||
load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location)) | |||
return loaded_storages[key] | |||
load_module_mapping: Dict[str, str] = { | |||
'torch.tensor': 'torch._tensor' | |||
@@ -787,10 +829,11 @@ def _load(zip_file, pickle_module, overall_storage=None, pickle_file='data.pkl', | |||
raise NotImplementedError( | |||
"Do not support load module now. Please use 'torch.load' to load model parameters." | |||
"Model parameters should be saved in 'PyTorch' by 'torch.save(net.state_dict(), filename)'.") | |||
if mod_name == 'torch._utils': | |||
return eval(name) | |||
if mod_name == 'torch': | |||
return str(name) | |||
if isinstance(name, str) and 'Storage' in name: | |||
try: | |||
return StorageType(name) | |||
except KeyError: | |||
pass | |||
mod_name = load_module_mapping.get(mod_name, mod_name) | |||
return super().find_class(mod_name, name) | |||
@@ -4,6 +4,7 @@ import collections | |||
from functools import lru_cache | |||
from ast import literal_eval | |||
from typing import Any | |||
from ml_dtypes import bfloat16 as np_bfloat16 | |||
import mindspore as ms | |||
import mindtorch.torch.common.dtype as _dtype | |||
from mindtorch.torch.common.dtype import _TypeDict | |||
@@ -71,13 +72,20 @@ class _StorageBase(): | |||
self._update_referenced_tensor() | |||
return self | |||
def _update_referenced_tensor(self, strict=True): | |||
def _update_referenced_tensor(self, strict=True, size=None): | |||
if self.referenced_tensor is not None: | |||
np_data = np.frombuffer(self.inner_data, | |||
_TypeDict.get(self.referenced_tensor.dtype)) | |||
if size is not None: | |||
np_data = np_data.reshape(size) | |||
if strict: | |||
np_data = np_data.reshape(self.referenced_tensor.shape) | |||
value = ms.Tensor.from_numpy(np_data) | |||
if np_data.dtype == np_bfloat16: | |||
np_data = np_data.astype(np.float32) | |||
value = ms.Tensor.from_numpy(np_data) | |||
value = value.astype(_dtype.bfloat16) | |||
else: | |||
value = ms.Tensor.from_numpy(np_data) | |||
self.referenced_tensor.assign_value(value) | |||
def nbytes(self): | |||
@@ -154,7 +162,7 @@ class _StorageBase(): | |||
def resize_(self, size): | |||
if size <= self.size(): | |||
self.inner_data = np.frombuffer(self.inner_data, dtype=np.uint8, count=size) | |||
self.inner_data = self.inner_data[:size] | |||
else: | |||
append_data = np.random.randint(0, 255, size=size - self.size(), dtype=np.uint8) | |||
self.inner_data = np.concatenate((self.inner_data, append_data), axis=0) | |||
@@ -173,7 +181,7 @@ class _StorageBase(): | |||
raise RuntimeError("Currently, in `storage._set_from_file` only is_real_file==True supported.") | |||
nbytes = np.frombuffer(f.read(8), np.int64).item() * element_size | |||
array = np.fromfile(f, dtype=np.uint8, count=nbytes, offset=offset) | |||
self.inner_data = array | |||
self.inner_data[:] = array | |||
self._update_referenced_tensor() | |||
return self | |||
@@ -370,7 +378,8 @@ class _TypedStorage: | |||
self[0:len(self)] = value | |||
return self | |||
def __new__(cls, *args, wrap_storage=None, dtype=None, device=None): | |||
def __new__(cls, *args, wrap_storage=None, dtype=None, device=None, _internal=True): | |||
unsupported_attr(_internal) | |||
if cls == _LegacyStorage: | |||
raise RuntimeError("Only child classes of _LegacyStorage can be instantiated") | |||
@@ -436,7 +445,8 @@ class _TypedStorage: | |||
wrap_storage=wrap_storage, | |||
dtype=cls.dtype) | |||
def __init__(self, *args, device=None, dtype=None, wrap_storage=None): | |||
def __init__(self, *args, device=None, dtype=None, wrap_storage=None, _internal=True): | |||
unsupported_attr(_internal) | |||
arg_error_msg = ( | |||
'_TypedStorage.__init__ received an invalid combination ' | |||
'of arguments. Expected one of:\n' | |||
@@ -908,3 +918,11 @@ _storage_classes_dict = {_dtype.double: DoubleStorage, | |||
_dtype.cdouble: ComplexDoubleStorage, | |||
_dtype.cfloat: ComplexFloatStorage, | |||
} | |||
def _get_dtype_from_pickle_storage_type(pickle_storage_type: str): | |||
try: | |||
return _storage_type_to_dtype_map()[pickle_storage_type] | |||
except KeyError as e: | |||
raise KeyError( | |||
f'pickle storage type "{pickle_storage_type}" is not recognized') from e |
@@ -4,13 +4,17 @@ import os | |||
import abc | |||
import numbers | |||
import operator | |||
from collections import OrderedDict | |||
# from functools import reduce, lru_cache | |||
from copy import deepcopy | |||
from functools import reduce | |||
import numpy as np | |||
import mindspore as ms | |||
from mindspore import Tensor as ms_Tensor | |||
from mindspore.scipy.ops import SolveTriangular | |||
try: | |||
from mindspore.scipy.ops import SolveTriangular # not support on win cpu | |||
except ImportError: | |||
... | |||
from mindspore.common import dtype as mstype | |||
import mindspore.ops as P | |||
from mindspore.ops.primitive import _primexpr | |||
@@ -39,6 +43,8 @@ from mindtorch.torch.logging import warning, info | |||
import mindtorch.torch._register_numpy_primitive as numpy_cell | |||
from mindtorch.torch._default_dtype import _not_default_fp32_dtype, get_default_dtype | |||
from mindtorch.torch._C.Size import Size | |||
from mindtorch.torch._tensor import _rebuild_from_type_v2 | |||
from mindtorch.torch import _utils | |||
_dtypeDict = { | |||
'float16': mstype.float16, | |||
@@ -255,6 +261,8 @@ class _TensorMeta(type(ms_Tensor), abc.ABCMeta): | |||
""" | |||
class Tensor(StubTensor, metaclass=_TensorMeta): | |||
layout = property(lambda self: object(), lambda self, v: None, lambda self: None) | |||
def __init__(self, *data, requires_grad=False, dtype=None, inner=False, cast_tensor=False): | |||
if cast_tensor: | |||
if len(data) != 1: | |||
@@ -616,11 +624,38 @@ class Tensor(StubTensor, metaclass=_TensorMeta): | |||
return out | |||
def __getstate__(self): | |||
pickled = {"input_data": self.asnumpy(), "dtype": self.dtype} | |||
return pickled | |||
state = {key: value for key, value in self.__dict__.items() if key not in Tensor().__dict__} | |||
return state | |||
def __reduce_ex__(self, protocol): | |||
state = _utils._get_obj_state(self) | |||
if isinstance(self, Tensor) and not state: | |||
return self._reduce_ex_internal() | |||
func, args = self._reduce_ex_internal() | |||
return (_rebuild_from_type_v2, (func, type(self), args, state)) | |||
def _reduce_ex_internal(self): | |||
backward_hooks = OrderedDict() | |||
args = ( | |||
_TypedStorage( | |||
wrap_storage=self.storage()._untyped(), | |||
dtype=self.dtype), | |||
0, | |||
tuple(self.size()), | |||
self.stride(), | |||
self.requires_grad, | |||
backward_hooks) | |||
return (_utils._rebuild_tensor_v2, args) | |||
def __setstate__(self, state): | |||
Tensor.__init__(self, state["input_data"], dtype=state["dtype"], inner=True) | |||
if isinstance(state, tuple): | |||
if len(state) == 4: | |||
self.set_(*state) | |||
return | |||
elif len(state) == 5: | |||
data = state[0] | |||
Tensor.__init__(self, data, dtype=data.dtype, inner=True, requires_grad=state[3]) | |||
return | |||
@property | |||
def grad_fn(self): | |||
@@ -687,7 +722,7 @@ class Tensor(StubTensor, metaclass=_TensorMeta): | |||
def storage(self): | |||
if graph_mode_condition(): | |||
raise NotImplementedError('Currently, `tensor.storage()` is not supported in graph mode. ' | |||
warning('Currently, `tensor.storage()` is not supported in graph mode. ' | |||
'Please replace `Storage` related interfaces with the equivalent interface.') | |||
if self.dtype == mindtorch_dtype.bfloat16: | |||
@@ -1591,22 +1626,18 @@ class Tensor(StubTensor, metaclass=_TensorMeta): | |||
return cast_to_adapter_tensor(input_ms.copy()) | |||
def set_(self, source=None, storage_offset=0, size=None, stride=None): | |||
if storage_offset or size or stride: | |||
raise ValueError("Currently, `Tensor.set_` specifying `storage_offset`, " | |||
"`size` or `stride` are not supported.") | |||
if source is None: | |||
raise ValueError("Currently, `Tensor.set_` only supported specify the `source`, " \ | |||
"please ensure that it is not None.") | |||
unsupported_attr(storage_offset) | |||
unsupported_attr(stride) | |||
if graph_mode_condition(): | |||
raise RuntimeError('`Tensor.set_` is an in-place operation and "x.set_()" is not supported to use ' | |||
warning('`Tensor.set_` is an in-place operation and "x.set_()" is not supported to use ' | |||
'in MindSpore static graph mode.') | |||
if isinstance(source, Tensor): | |||
if source.dtype != self.dtype: | |||
raise RuntimeError("In `tensor.set_`, sourse.dtype must equal to self.dtype.") | |||
source = cast_to_ms_tensor(source) | |||
if size: | |||
source = source.reshape(size) | |||
self.assign_value(source) | |||
return self | |||
@@ -1615,12 +1646,12 @@ class Tensor(StubTensor, metaclass=_TensorMeta): | |||
if source.dtype != self.dtype: | |||
raise RuntimeError("In `tensor.set_`, _TypedStorage.dtype must equal to self.dtype.") | |||
source._storage.referenced_tensor = self | |||
source._storage._update_referenced_tensor(strict=False) | |||
source._storage._update_referenced_tensor(strict=False, size=size) | |||
return self | |||
# handle source is a _UntypedStorage | |||
source.referenced_tensor = self | |||
source._update_referenced_tensor(strict=False) | |||
source._update_referenced_tensor(strict=False, size=size) | |||
return self | |||
def to(self, *args, **kwargs): | |||
@@ -6,7 +6,7 @@ in `./_utils/worker.py`. | |||
""" | |||
from __future__ import absolute_import | |||
import functools | |||
import sys | |||
import itertools | |||
import logging | |||
import os | |||
@@ -218,7 +218,7 @@ class DataLoader(Generic[T_co]): | |||
def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1, | |||
shuffle: Optional[bool] = None, sampler: Union[Sampler, Iterable, None] = None, | |||
batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None, | |||
num_workers: int = 1, collate_fn: Optional[_collate_fn_t] = None, | |||
num_workers: int = None, collate_fn: Optional[_collate_fn_t] = None, | |||
pin_memory: bool = False, drop_last: bool = False, | |||
timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None, | |||
multiprocessing_context=None, generator=None, | |||
@@ -226,6 +226,11 @@ class DataLoader(Generic[T_co]): | |||
persistent_workers: bool = False, | |||
pin_memory_device: str = ""): | |||
# torch._C._log_api_usage_once("python.data_loader") | |||
if num_workers is None: | |||
if sys.platform == "win32": | |||
num_workers = 0 | |||
else: | |||
num_workers = 1 | |||
if num_workers < 0: | |||
raise ValueError('num_workers option should be non-negative; ' | |||
@@ -9,9 +9,9 @@ import numpy as np | |||
from mindspore import context | |||
import pytest | |||
from ...utils import set_mode_by_env_config, SKIP_ENV_GRAPH_MODE | |||
from ...utils import set_mode_by_env_config, SKIP_ENV_GRAPH_MODE, SKIP_ENV_CPU | |||
set_mode_by_env_config() | |||
@SKIP_ENV_CPU(reason="need stable network") | |||
def test_get_dir(): | |||
ms_hub_dir = ms_torch.hub.get_dir() | |||
torch_hub_dir = torch.hub.get_dir() | |||
@@ -3,6 +3,7 @@ from mindtorch.module_hooker import torch_enable, torch_pop | |||
from ...utils import set_mode_by_env_config | |||
set_mode_by_env_config() | |||
@pytest.fixture(scope='function') | |||
def test_import(): | |||
torch_enable() | |||
import torch | |||
@@ -1,11 +1,13 @@ | |||
import os | |||
import pytest | |||
import numpy as np | |||
import torch | |||
import mindtorch.torch as pytorch | |||
from ...utils import set_mode_by_env_config, param_compare | |||
from ...utils import set_mode_by_env_config, param_compare, SKIP_ENV_CPU, SKIP_ENV_GPU, SKIP_ENV_ASCEND | |||
set_mode_by_env_config() | |||
@pytest.fixture(scope='function') | |||
def test_save_load_1(): | |||
state_dict_torch ={} | |||
state_dict_mindtorch = {} | |||
@@ -31,6 +33,7 @@ def test_save_load_1(): | |||
param_compare(state_dict_torch["b"], state_dict_mindtorch["b"]) | |||
assert state_dict_torch["c"] == state_dict_mindtorch["c"] | |||
@pytest.fixture(scope='function') | |||
def test_save_load_2(): | |||
state_dict_torch = {} | |||
state_dict_mindtorch = {} | |||
@@ -56,6 +59,7 @@ def test_save_load_2(): | |||
param_compare(state_dict_torch["b"], state_dict_mindtorch["b"]) | |||
assert state_dict_torch["c"] == state_dict_mindtorch["c"] | |||
@pytest.fixture(scope='function') | |||
def test_save_load_3(): | |||
state_dict_torch = {} | |||
state_dict_mindtorch = {} | |||
@@ -80,7 +84,7 @@ def test_save_load_3(): | |||
param_compare(state_dict_torch["b"], state_dict_mindtorch["b"]) | |||
assert state_dict_torch["c"] == state_dict_mindtorch["c"] | |||
@pytest.fixture(scope='function') | |||
def test_save_load_4(): | |||
state_dict_torch = {} | |||
state_dict_mindtorch = {} | |||
@@ -105,6 +109,7 @@ def test_save_load_4(): | |||
param_compare(state_dict_torch["b"], state_dict_mindtorch["b"]) | |||
assert state_dict_torch["c"] == state_dict_mindtorch["c"] | |||
@pytest.fixture(scope='function') | |||
def test_save_load_bf16_1(): | |||
state_dict_mindtorch = {} | |||
a = pytorch.tensor(800000, dtype=pytorch.bfloat16) | |||
@@ -125,6 +130,7 @@ def test_save_load_bf16_1(): | |||
param_compare(b.to(pytorch.float32), state_dict_mindtorch["b"].to(pytorch.float32)) | |||
assert c == state_dict_mindtorch["c"] | |||
@pytest.fixture(scope='function') | |||
def test_save_load_bf16_2(): | |||
state_dict_mindtorch = {} | |||
a = pytorch.tensor(800000, dtype=pytorch.bfloat16) | |||
@@ -145,7 +151,7 @@ def test_save_load_bf16_2(): | |||
param_compare(b.to(pytorch.float32), state_dict_mindtorch["b"].to(pytorch.float32)) | |||
assert c == state_dict_mindtorch["c"] | |||
@pytest.fixture(scope='function') | |||
def test_save_load_bf16_3(): | |||
state_dict_mindtorch = {} | |||
a = torch.tensor(800000, dtype=torch.bfloat16) | |||
@@ -164,6 +170,7 @@ def test_save_load_bf16_3(): | |||
param_compare(b.to(torch.float32), state_dict_mindtorch["b"].to(pytorch.float32)) | |||
assert c == state_dict_mindtorch["c"] | |||
@pytest.fixture(scope='function') | |||
def test_save_load_bf16_4(): | |||
state_dict_mindtorch = {} | |||
a = torch.tensor(800000, dtype=torch.bfloat16) | |||
@@ -182,6 +189,229 @@ def test_save_load_bf16_4(): | |||
param_compare(b.to(torch.float32), state_dict_mindtorch["b"].to(pytorch.float32)) | |||
assert c == state_dict_mindtorch["c"] | |||
@pytest.fixture(scope='function') | |||
def test_save_load_parameter_1(): | |||
state_dict_torch ={} | |||
state_dict_mindtorch = {} | |||
a = np.random.rand(3, 3).astype(np.float32) | |||
b = np.random.rand(1, 64,64, 3).astype(np.float32) | |||
c = 1 | |||
state_dict_torch["a"] = torch.nn.Parameter(torch.tensor(a)) | |||
state_dict_torch["b"] = torch.nn.Parameter(torch.tensor(b)) | |||
state_dict_torch["c"] = c | |||
state_dict_mindtorch["a"] = pytorch.nn.Parameter(pytorch.tensor(a)) | |||
state_dict_mindtorch["b"] = pytorch.nn.Parameter(pytorch.tensor(b)) | |||
state_dict_mindtorch["c"] = c | |||
torch.save(state_dict_torch, "test_save_load_parameter_1_torch.pth") | |||
pytorch.save(state_dict_mindtorch, "test_save_load_parameter_1_mindtorch.pth") | |||
state_dict_torch = torch.load("test_save_load_parameter_1_torch.pth") | |||
state_dict_mindtorch = pytorch.load("test_save_load_parameter_1_mindtorch.pth") | |||
os.remove("test_save_load_parameter_1_torch.pth") | |||
os.remove("test_save_load_parameter_1_mindtorch.pth") | |||
param_compare(state_dict_torch["a"].detach(), state_dict_mindtorch["a"]) | |||
param_compare(state_dict_torch["b"].detach(), state_dict_mindtorch["b"]) | |||
assert state_dict_torch["c"] == state_dict_mindtorch["c"] | |||
@pytest.fixture(scope='function') | |||
def test_save_load_parameter_2(): | |||
state_dict_torch ={} | |||
state_dict_mindtorch = {} | |||
a = np.random.rand(3, 3).astype(np.float32) | |||
b = np.random.rand(1, 64,64, 3).astype(np.float32) | |||
c = 1 | |||
state_dict_torch["a"] = torch.nn.Parameter(torch.tensor(a)) | |||
state_dict_torch["b"] = torch.nn.Parameter(torch.tensor(b)) | |||
state_dict_torch["c"] = c | |||
state_dict_mindtorch["a"] = pytorch.nn.Parameter(pytorch.tensor(a)) | |||
state_dict_mindtorch["b"] = pytorch.nn.Parameter(pytorch.tensor(b)) | |||
state_dict_mindtorch["c"] = c | |||
torch.save(state_dict_torch, "test_save_load_parameter_2_torch.pth", _use_new_zipfile_serialization=False) | |||
pytorch.save(state_dict_mindtorch, "test_save_load_parameter_2_mindtorch.pth", _use_new_zipfile_serialization=False) | |||
state_dict_torch = torch.load("test_save_load_parameter_2_torch.pth") | |||
state_dict_mindtorch = pytorch.load("test_save_load_parameter_2_mindtorch.pth") | |||
os.remove("test_save_load_parameter_2_torch.pth") | |||
os.remove("test_save_load_parameter_2_mindtorch.pth") | |||
param_compare(state_dict_torch["a"].detach(), state_dict_mindtorch["a"]) | |||
param_compare(state_dict_torch["b"].detach(), state_dict_mindtorch["b"]) | |||
assert state_dict_torch["c"] == state_dict_mindtorch["c"] | |||
@pytest.fixture(scope='function') | |||
def test_save_load_parameter_3(): | |||
state_dict_torch = {} | |||
state_dict_mindtorch = {} | |||
a = np.random.rand(3, 3).astype(np.float32) | |||
b = np.random.rand(1, 64, 64, 3).astype(np.float32) | |||
c = 1 | |||
state_dict_torch["a"] = torch.nn.Parameter(torch.tensor(a)) | |||
state_dict_torch["b"] = torch.nn.Parameter(torch.tensor(b)) | |||
state_dict_torch["c"] = c | |||
state_dict_mindtorch["a"] = pytorch.nn.Parameter(pytorch.tensor(a)) | |||
state_dict_mindtorch["b"] = pytorch.nn.Parameter(pytorch.tensor(b)) | |||
state_dict_mindtorch["c"] = c | |||
torch.save(state_dict_torch, "test_save_load_parameter_3_torch.pth") | |||
state_dict_torch = torch.load("test_save_load_parameter_3_torch.pth") | |||
state_dict_mindtorch = pytorch.load("test_save_load_parameter_3_torch.pth") | |||
os.remove("test_save_load_parameter_3_torch.pth") | |||
param_compare(state_dict_torch["a"].detach(), state_dict_mindtorch["a"]) | |||
param_compare(state_dict_torch["b"].detach(), state_dict_mindtorch["b"]) | |||
assert state_dict_torch["c"] == state_dict_mindtorch["c"] | |||
@pytest.fixture(scope='function') | |||
def test_save_load_parameter_4(): | |||
state_dict_torch = {} | |||
state_dict_mindtorch = {} | |||
a = np.random.rand(3, 3).astype(np.float32) | |||
b = np.random.rand(1, 64, 64, 3).astype(np.float32) | |||
c = 1 | |||
state_dict_torch["a"] = torch.nn.Parameter(torch.tensor(a)) | |||
state_dict_torch["b"] = torch.nn.Parameter(torch.tensor(b)) | |||
state_dict_torch["c"] = c | |||
state_dict_mindtorch["a"] = pytorch.nn.Parameter(pytorch.tensor(a)) | |||
state_dict_mindtorch["b"] = pytorch.nn.Parameter(pytorch.tensor(b)) | |||
state_dict_mindtorch["c"] = c | |||
torch.save(state_dict_torch, "test_save_load_parameter_4_torch.pth", _use_new_zipfile_serialization=False) | |||
state_dict_torch = torch.load("test_save_load_parameter_4_torch.pth") | |||
state_dict_mindtorch = pytorch.load("test_save_load_parameter_4_torch.pth") | |||
os.remove("test_save_load_parameter_4_torch.pth") | |||
param_compare(state_dict_torch["a"].detach(), state_dict_mindtorch["a"]) | |||
param_compare(state_dict_torch["b"].detach(), state_dict_mindtorch["b"]) | |||
assert state_dict_torch["c"] == state_dict_mindtorch["c"] | |||
@pytest.fixture(scope='function') | |||
def test_save_load_net(): | |||
import torch | |||
import torch.nn as nn | |||
class Net(nn.Module): | |||
def __init__(self, num_classes: int = 10) -> None: | |||
super(Net, self).__init__() | |||
self.features = nn.Sequential( | |||
nn.Conv2d(3, 64, (11, 11), (4, 4), (2, 2), bias=False), | |||
nn.BatchNorm2d(64), | |||
nn.ReLU(), | |||
nn.MaxPool2d((3, 3), (2, 2)), | |||
) | |||
self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) | |||
self.classifier = nn.Sequential( | |||
nn.Dropout(0.5), | |||
nn.Linear(256 * 6 * 6, 4096), | |||
) | |||
net = Net() | |||
state_dict = { | |||
'net': net.state_dict(), | |||
} | |||
torch.save(state_dict, 'torch_module.pt', _use_new_zipfile_serialization=True) | |||
torch.save(state_dict, 'torch_module_oldfile.pt', _use_new_zipfile_serialization=False) | |||
import mindtorch.torch as pytorch | |||
import mindtorch.torch.nn as nn | |||
class Net(nn.Module): | |||
def __init__(self, num_classes: int = 10) -> None: | |||
super(Net, self).__init__() | |||
self.features = nn.Sequential( | |||
nn.Conv2d(3, 64, (11, 11), (4, 4), (2, 2), bias=False), | |||
nn.BatchNorm2d(64), | |||
nn.ReLU(), | |||
nn.MaxPool2d((3, 3), (2, 2)), | |||
) | |||
self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) | |||
self.classifier = nn.Sequential( | |||
nn.Dropout(0.5), | |||
nn.Linear(256 * 6 * 6, 4096), | |||
) | |||
net = Net() | |||
state = pytorch.load("torch_module.pt") | |||
net.load_state_dict(state['net']) | |||
os.remove("torch_module.pt") | |||
state_dict = { | |||
'net': net.state_dict(), | |||
} | |||
pytorch.save(state_dict, 'mindtorch_module.pt', _use_new_zipfile_serialization=True) | |||
os.remove("mindtorch_module.pt") | |||
state = pytorch.load("torch_module_oldfile.pt") | |||
net.load_state_dict(state['net']) | |||
os.remove("torch_module_oldfile.pt") | |||
state_dict = { | |||
'net': net.state_dict(), | |||
} | |||
pytorch.save(state_dict, 'mindtorch_module_oldfile.pt', _use_new_zipfile_serialization=False) | |||
os.remove('mindtorch_module_oldfile.pt') | |||
@pytest.fixture(scope='function') | |||
@SKIP_ENV_ASCEND(reason="This function need torch version >= 2.1.0") | |||
@SKIP_ENV_GPU(reason="This function need torch version >= 2.1.0") | |||
@SKIP_ENV_CPU(reason="This function need torch version >= 2.1.0") | |||
def test_save_load_5(): | |||
a = torch.tensor(2.) | |||
a.kkk = 3 | |||
torch.save(a, 'a.pth') | |||
tensor = pytorch.load('a.pth') | |||
os.remove('a.pth') | |||
assert tensor.kkk == a.kkk | |||
param_compare(a, tensor) | |||
@pytest.fixture(scope='function') | |||
def test_save_load_6(): | |||
a = pytorch.tensor(2.) | |||
a.kkk = 3 | |||
pytorch.save(a, 'a.pth') | |||
tensor = pytorch.load('a.pth') | |||
os.remove('a.pth') | |||
assert tensor.kkk == a.kkk | |||
param_compare(a, tensor) | |||
@pytest.fixture(scope='function') | |||
@SKIP_ENV_ASCEND(reason="This function need torch version >= 2.1.0") | |||
@SKIP_ENV_GPU(reason="This function need torch version >= 2.1.0") | |||
@SKIP_ENV_CPU(reason="This function need torch version >= 2.1.0") | |||
def test_save_load_7(): | |||
a = torch.nn.Parameter(torch.tensor(2.)) | |||
a.kkk = 3 | |||
torch.save(a, 'a.pth') | |||
tensor = pytorch.load('a.pth') | |||
os.remove('a.pth') | |||
assert tensor.kkk == a.kkk | |||
param_compare(a.detach(), tensor) | |||
@pytest.fixture(scope='function') | |||
def test_save_load_8(): | |||
a = pytorch.nn.Parameter(pytorch.tensor(2.)) | |||
a.kkk = 3 | |||
pytorch.save(a, 'a.pth') | |||
tensor = pytorch.load('a.pth') | |||
os.remove('a.pth') | |||
assert tensor.kkk == a.kkk | |||
param_compare(a, tensor) | |||
if __name__ == '__main__': | |||
test_save_load_1() | |||
test_save_load_2() | |||
@@ -191,3 +421,10 @@ if __name__ == '__main__': | |||
test_save_load_bf16_2() | |||
test_save_load_bf16_3() | |||
test_save_load_bf16_4() | |||
test_save_load_parameter_1() | |||
test_save_load_parameter_2() | |||
test_save_load_parameter_3() | |||
test_save_load_parameter_4() | |||
test_save_load_net() | |||
test_save_load_5() | |||
test_save_load_6() |
Dear OpenI User
Thank you for your continuous support to the Openl Qizhi Community AI Collaboration Platform. In order to protect your usage rights and ensure network security, we updated the Openl Qizhi Community AI Collaboration Platform Usage Agreement in January 2024. The updated agreement specifies that users are prohibited from using intranet penetration tools. After you click "Agree and continue", you can continue to use our services. Thank you for your cooperation and understanding.
For more agreement content, please refer to the《Openl Qizhi Community AI Collaboration Platform Usage Agreement》
当前有应该不支持as_subclass?什么场景会进这个函数?
补充了简单实现,浩宇的那个样例会进这个函数
已删除,当前不需要