#918 load/save fix

Merged
Erpim merged 22 commits from load_fix into master 1 week ago
hanjr commented 2 weeks ago
完成: 1、结合storage整改load的流程,fix之前未覆盖的场景,目前整改legacy save/load, zipfile save/load(仅限 torch weights state dict) 2、legacy save/ zipfile save相关流程,以及涉及到的对象__reduce_ex__ 和 __setstate__整改
hanjr changed title from [WIP]load fix to [WIP]load/save fix 2 weeks ago
Erpim reviewed 2 weeks ago
mindtorch/torch/serialization.py
@@ -606,0 +615,4 @@
for i in range(num_storages):
args = pickle_module.load(f, **pickle_load_args)
key, location, storage_type = args
dtype = storage_type._dtype
Erpim commented 2 weeks ago
storage_type没有_dtype对象,可调用dtype。
hanjr commented 1 week ago
done
Erpim reviewed 2 weeks ago
mindtorch/torch/serialization.py
@@ -606,0 +632,4 @@
element_size = _utils._element_size(root.dtype)
offset_bytes = offset * element_size
deserialized_objects[target_cdata] = storage.TypedStorage(
wrap_storage=root._untyped_storage[offset_bytes:offset_bytes + numel * element_size],
Erpim commented 2 weeks ago
storage.TypedStorage 替换为_TypedStorage
hanjr commented 1 week ago
done
Erpim reviewed 2 weeks ago
mindtorch/torch/serialization.py
@@ -606,0 +621,4 @@
data = np.fromfile(f, dtype=np.uint8, count=nbytes, offset=0)
obj = _UntypedStorage.from_buffer(data)

deserialized_objects[key] = storage.TypedStorage(
Erpim commented 2 weeks ago
_TypedStorage
hanjr commented 1 week ago
done
Erpim reviewed 2 weeks ago
mindtorch/torch/serialization.py
@@ -606,0 +624,4 @@
deserialized_objects[key] = storage.TypedStorage(
wrap_storage=obj,
dtype=dtype,
_internal=True)
Erpim commented 2 weeks ago
_TypedStorage 需要新增一个_internal入参
hanjr marked this conversation as resolved
Erpim reviewed 2 weeks ago
mindtorch/torch/serialization.py
@@ -606,0 +634,4 @@
deserialized_objects[target_cdata] = storage.TypedStorage(
wrap_storage=root._untyped_storage[offset_bytes:offset_bytes + numel * element_size],
dtype=root.dtype,
_internal=True)
Erpim commented 2 weeks ago
同上
hanjr marked this conversation as resolved
Erpim reviewed 2 weeks ago
mindtorch/torch/serialization.py
@@ -606,0 +649,4 @@
stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
storage_offset, = struct.unpack('<q', f.read(8))
tmp_tensor = tensor([], dtype=storage.dtype).set_(
storage._untyped_storage, storage_offset, numel, stride)
Erpim commented 2 weeks ago
storage没有_untyped_storage,改为storage._untyped()
hanjr marked this conversation as resolved
Erpim reviewed 2 weeks ago
mindtorch/torch/tensor.py
@@ -621,3 +650,3 @@

def __setstate__(self, state):
Tensor.__init__(self, state["input_data"], dtype=state["dtype"], inner=True)
if not isinstance(state, tuple):
Erpim commented 2 weeks ago
isinstance(state, tuple) 才进
hanjr marked this conversation as resolved
Erpim reviewed 2 weeks ago
mindtorch/torch/nn/parameter.py
@@ -201,3 +188,1 @@
source = cast_to_ms_tensor(source)
self.set_data(source, True)
return self
#
Erpim commented 2 weeks ago
直接删除
hanjr marked this conversation as resolved
Erpim reviewed 2 weeks ago
mindtorch/torch/nn/parameter.py
@@ -202,2 +188,2 @@
self.set_data(source, True)
return self
#
# def set_(self, source=None, storage_offset=0, size=None, stride=None):
Erpim commented 2 weeks ago
新增该函数 ``` def __setstate__(self, state): if isinstance(state, tuple): if len(state) == 4: self.set_(*state) return elif len(state) == 5: data = state[0] Parameter.__init__(self, data, requires_grad=state[3]) self.set_dtype(data.dtype) self.set_data(data=data, slice_shape=True) self._requires_grad = state[3] return ```
hanjr marked this conversation as resolved
Erpim commented 2 weeks ago
Collaborator
parameter new函数修改: ``` def __new__(cls, data=None, *args, **kwargs): if data is None: data = np.array(1) ``` init函数修改: ``` def __init__(self, data=None, requires_grad=True, name=None, layerwise_parallel=False, parallel_optimizer=True): if data is None: data = np.array(1) ```
Erpim reviewed 2 weeks ago
mindtorch/torch/serialization.py
@@ -593,3 +596,4 @@
pass
return super().find_class(mod_name, name)

def legacy_load(f):
Erpim commented 2 weeks ago
找个合适的位置调用torch_disable, torch_pop
hanjr marked this conversation as resolved
Erpim reviewed 2 weeks ago
mindtorch/torch/serialization.py
@@ -598,2 +602,4 @@
def persistent_load(saved_id):
if isinstance(saved_id, tuple):
if all(saved_id[1:]):
_check_container_source(*saved_id)
Erpim commented 2 weeks ago
补充_check_container_source实现
hanjr marked this conversation as resolved
Erpim reviewed 2 weeks ago
mindtorch/torch/_tensor.py
@@ -0,0 +4,4 @@
if type is Tensor:
return func(*args)

ret = func(*args).as_subclass(type)
Erpim commented 2 weeks ago
当前有应该不支持as_subclass?什么场景会进这个函数?
hanjr commented 1 week ago
def as_subclass(self, cls): return cls(self) 补充了简单实现,浩宇的那个样例会进这个函数
hanjr commented 1 week ago
已删除,当前不需要
hanjr marked this conversation as resolved
Erpim reviewed 2 weeks ago
mindtorch/torch/serialization.py
@@ -24,10 +24,14 @@ from typing import Any, BinaryIO, Union, IO, Optional, Type, Dict, Tuple
from typing_extensions import TypeAlias
from ml_dtypes import bfloat16
import numpy as np
from mindtorch.torch import storage
Erpim commented 2 weeks ago
和后面的变量重名了
hanjr marked this conversation as resolved
Erpim reviewed 2 weeks ago
mindtorch/torch/serialization.py
@@ -448,3 +440,1 @@
if storage_dataptr != 0:
if storage_dataptr in storage_dtypes:
if storage_dtype != storage_dtypes[storage_dataptr]:
if storage.data_ptr() != 0:
Erpim commented 2 weeks ago
这部分data_ptr的作用是否和内存访问相关,频繁调用性能较差
Erpim commented 2 weeks ago
Collaborator
需验证,加载torch的pth,保存mindtorch的pth,再加载回来,整个流程的正确性
Erpim reviewed 2 weeks ago
mindtorch/torch/serialization.py
@@ -550,3 +546,3 @@
name = f'archive/data/{key}'
storage = serialized_storages[key]
storage_data = storage.get_bytes()
storage_data = storage.inner_data
Erpim commented 2 weeks ago
为什么不调用get_bytes? 如果效果等价,是不是修改get_bytes接口实现?
hanjr commented 1 week ago
get_bytes是mindspore tensor的方法,之前是直接保存的tensor,所以用的get_bytes,现在我们保存的storage,直接保存storage的inner data数据就行
Erpim reviewed 2 weeks ago
mindtorch/torch/serialization.py
@@ -590,3 +590,1 @@
return eval(name)
if mod_name == 'torch':
return str(name)
if type(mod_name) is str and mod_name.startswith('torch.'):
Erpim commented 2 weeks ago
用torch_disable来解决该类问题
hanjr marked this conversation as resolved
Erpim reviewed 2 weeks ago
mindtorch/torch/serialization.py
@@ -791,3 +807,1 @@
return eval(name)
if mod_name == 'torch':
return str(name)
if type(mod_name) is str and mod_name.startswith('torch.'):
Erpim commented 2 weeks ago
同上
hanjr marked this conversation as resolved
frelam commented 2 weeks ago
Collaborator
import torch a = torch.tensor(2.) a.kkk = 3 pth = torch.save(a, './pt.pth') tensor = torch.load('./pt.pth') print(tensor) print(tensor.kkk) 这个用例可以跑过。 但是mindtorch无法恢复出tensor.kkk这个值
frelam commented 2 weeks ago
Collaborator
> import torch > a = torch.tensor(2.) > > a.kkk = 3 > > pth = torch.save(a, './pt.pth') > > tensor = torch.load('./pt.pth') > > print(tensor) > print(tensor.kkk) > > > 这个用例可以跑过。 但是mindtorch无法恢复出tensor.kkk这个值 > import torch a = torch.tensor(2.) a.kkk = 3 pth = torch.save(a, './pt.pth') import mindtorch.torch as ms_torch tensor = ms_torch.load('./pt.pth') print(tensor) print(tensor.kkk)
Erpim commented 1 week ago
Collaborator
> > import torch > > a = torch.tensor(2.) > > > > a.kkk = 3 > > > > pth = torch.save(a, './pt.pth') > > > > tensor = torch.load('./pt.pth') > > > > print(tensor) > > print(tensor.kkk) > > > > > > 这个用例可以跑过。 但是mindtorch无法恢复出tensor.kkk这个值 > > > > import torch > a = torch.tensor(2.) > > a.kkk = 3 > > pth = torch.save(a, './pt.pth') > > import mindtorch.torch as ms_torch > > tensor = ms_torch.load('./pt.pth') > > print(tensor) > print(tensor.kkk) 需新增_set_obj_state _rebuild_from_type_v2中调用
hanjr commented 1 week ago
Poster
> > import torch > > a = torch.tensor(2.) > > > > a.kkk = 3 > > > > pth = torch.save(a, './pt.pth') > > > > tensor = torch.load('./pt.pth') > > > > print(tensor) > > print(tensor.kkk) > > > > > > 这个用例可以跑过。 但是mindtorch无法恢复出tensor.kkk这个值 > > > > import torch > a = torch.tensor(2.) > > a.kkk = 3 > > pth = torch.save(a, './pt.pth') > > import mindtorch.torch as ms_torch > > tensor = ms_torch.load('./pt.pth') > > print(tensor) > print(tensor.kkk) 目前已经可以正常load torch具有额外属性的tensor,mindtorch也可以正确保存tensor的额外属性
zoulq reviewed 1 week ago
@@ -54,0 +56,4 @@
unsupported_attr(stride)
from mindtorch.torch.tensor import tensor # pylint: disable=R0401, C0415
t = tensor([], dtype=storage.dtype, device=storage._untyped().device)
return t.set_(storage._untyped(), storage_offset, size)
zoulq commented 1 week ago
直接创建一个tensor的方法快一点吧?
Erpim commented 1 week ago
不可以直接创建tensor,加载场景是通过修改storage的值,同步改变tensor的值,如果直接场景tensor,外部storage和tensor直接没有建立连接,不会同步更新
Erpim commented 1 week ago
不可以直接创建tensor,加载场景是通过修改storage的值,同步改变tensor的值,如果直接场景tensor,外部storage和tensor直接没有建立连接,不会同步更新
zoulq reviewed 1 week ago
mindtorch/torch/_utils.py
@@ -54,0 +67,4 @@

def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
from mindtorch.torch.functional import from_numpy # pylint: disable=R0401, C0415
tensor = from_numpy(data).to(dtype=dtype, device=device)
zoulq commented 1 week ago
同上,是不是可以直接创建tensor?
Erpim commented 1 week ago
不可以直接创建tensor,加载场景是通过修改storage的值,同步改变tensor的值,如果直接场景tensor,外部storage和tensor直接没有建立连接,不会同步更新
zoulq reviewed 1 week ago
mindtorch/torch/nn/parameter.py
@@ -45,0 +45,4 @@
def __new__(cls, data=None, requires_grad=True, name=None, layerwise_parallel=False, # pylint: disable = W0613
parallel_optimizer=True): # pylint: disable = W0613
if data is None:
data = np.array(1)
zoulq commented 1 week ago
data=1和data = np.array(1)对结果的区别是dtype不同,现在这么写是因为要求int64?
Erpim commented 1 week ago
可以修改,效果是一样的。
hanjr commented 1 week ago
已修改
zoulq reviewed 1 week ago
mindtorch/torch/nn/parameter.py
@@ -202,3 +194,1 @@
self.set_data(source, True)
return self

def __setstate__(self, state):
zoulq commented 1 week ago
为啥要把set_接口删掉呢?
hanjr commented 1 week ago
因为我们的parameter继承自Tensor,他直接可以使用Tensor的类内方法 set_。
zoulq reviewed 1 week ago
@@ -581,1 +568,3 @@

torch_pop()
return result
torch_disable()
zoulq commented 1 week ago
现在还有要依赖pytorch的场景吗?
hanjr commented 1 week ago
没有依赖pytorch的场景,在load torch的权重时,会读取到保存的torch函数指针,这个地方只是为了保证这个函数指针一定指向mindtorch实现的同名函数位置。
zoulq reviewed 1 week ago
mindtorch/torch/serialization.py
@@ -597,0 +587,4 @@
try:
current_source = ''.join(get_source_lines_and_file(container_type)[0])
except Exception: # saving the source is optional, so we can ignore any errors
warnings.warn("Couldn't retrieve source code for container of "
zoulq commented 1 week ago
warnings是否要用mindtorch封装后的?这样可以受环境变量配置控制
hanjr commented 1 week ago
已修改
zoulq reviewed 1 week ago
testing/ut/pytorch/torch/test_import.py
@@ -14,3 +5,2 @@
set_mode_by_env_config()
test_import()
#
zoulq commented 1 week ago
这些用例注释掉的原因是什么?
hanjr commented 1 week ago
已放开
zoulq reviewed 1 week ago
@@ -185,0 +358,4 @@
pytorch.save(state_dict, 'mindtorch_module_oldfile.pt', _use_new_zipfile_serialization=False)
os.remove('mindtorch_module_oldfile.pt')

@SKIP_ENV_CPU(reason="This function need torch version >= 2.1.0")
zoulq commented 1 week ago
这个用例在GPU和Ascend也是不能跑的,可以都先加上skip,如果后续需要跑再单独放开
hanjr commented 1 week ago
已修改
hanjr changed title from [WIP]load/save fix to load/save fix 1 week ago
frelam commented 1 week ago
Collaborator
遇到个问题: _rebuild_tensor_v2() takes 6 positional arguments but 7 were given
hanjr commented 1 week ago
Poster
> 遇到个问题: > _rebuild_tensor_v2() takes 6 positional arguments but 7 were given 在torch 2.1版本的rebuild_tensor_v2中第七个参数 metadata,主要作用是: Currently, this only returns a dict[string, bool] specifing whether `conj` or `neg` bit is set. 用于指示tensor是否有共轭和取负。 目前我们并不支持tensor的metadata属性,所以设置该函数入参为None,unsupported_attr。可以正常传参,但是不会set metadata
hanjr commented 1 week ago
Poster
修改了 parameter `__reduce_ex__` 函数主要是为了和torch保持一致,能够为parameter设置额外的属性,比如a.kkk = 2这种情况下能够正确save和load。
frelam commented 1 week ago
Collaborator
> > 遇到个问题: > > _rebuild_tensor_v2() takes 6 positional arguments but 7 were given > > 在torch 2.1版本的rebuild_tensor_v2中第七个参数 metadata,主要作用是: > Currently, this only returns a dict[string, bool] specifing whether `conj` or `neg` bit is set. > 用于指示tensor是否有共轭和取负。 > 目前我们并不支持tensor的metadata属性,所以设置该函数入参为None,unsupported_attr。可以正常传参,但是不会set metadata 已验证最新修改, 在我这边场景可行。
Erpim reviewed 1 week ago
mindtorch/torch/utils/data/dataloader.py
@@ -219,3 +219,3 @@
shuffle: Optional[bool] = None, sampler: Union[Sampler, Iterable, None] = None,
batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,
num_workers: int = 1, collate_fn: Optional[_collate_fn_t] = None,
num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
Erpim commented 1 week ago
默认值改成None,如果是None的时候走后面的判断。如果用户设0走非多进程
Erpim marked this conversation as resolved
Erpim merged commit 2bf20afdfb into master 1 week ago
hanjr deleted branch load_fix 6 days ago
The pull request has been merged as 2bf20afdfb.
Sign in to join this conversation.
No reviewers
No Label
No Milestone
No Assignees
4 Participants
Notifications
Due Date

No due date set.

Dependencies

This pull request currently doesn't have any dependencies.

Loading…
There is no content yet.