#911 [WIP] Fix tensor.type()

Open
frelam wants to merge 5 commits from frelam/MSAdapter:master0401 into master
  1. mindtorch/torch/cuda/__init__.py (+55, -2)
  2. mindtorch/torch/tensor.py (+1, -1)
  3. testing/ut/pytorch/tensor/test_tensor2.py (+16, -0)

mindtorch/torch/cuda/__init__.py (+55, -2)

@@ -6,18 +6,71 @@ import traceback
import mindspore as ms

from mindtorch.utils import get_backend, unsupported_attr
from mindtorch.torch.tensor import BoolTensor, ByteTensor, CharTensor, ShortTensor, IntTensor, HalfTensor, \
    FloatTensor, DoubleTensor, LongTensor, BFloat16Tensor, tensor
from mindtorch.torch.tensor import BoolTensor as _BoolTensor, \
    ByteTensor as _ByteTensor, \
    CharTensor as _CharTensor, \
    ShortTensor as _ShortTensor, \
    IntTensor as _IntTensor, \
    HalfTensor as _HalfTensor, \
    FloatTensor as _FloatTensor, \
    DoubleTensor as _DoubleTensor, \
    LongTensor as _LongTensor, \
    BFloat16Tensor as _BFloat16Tensor, \
    tensor, cast_to_ms_tensor, cast_to_adapter_tensor
import mindtorch.torch.cuda.amp as amp
from mindtorch.torch.cuda.random import manual_seed_all, manual_seed
from mindtorch.torch.logging import warning
from mindtorch.torch.cuda.streams import *
from mindtorch.torch.common.dtype import _get_type_from_dtype, _get_dtype_from_type
from ._utils import _get_device_index

_tls = threading.local()
_initialization_lock = threading.Lock()
_queued_calls = []

class _CudaTensor:
    def type(self, dtype=None, non_blocking=False, **kwargs):
        unsupported_attr(non_blocking)
        unsupported_attr(kwargs)
        if dtype is None:
            return 'torch.cuda.' + _get_type_from_dtype(self.dtype)

        _dtype = _get_dtype_from_type(dtype)
        if _dtype == self.dtype:
            return self
        x = cast_to_ms_tensor(self)
        output = x.astype(_dtype)
        return cast_to_adapter_tensor(output)

class BoolTensor(_CudaTensor, _BoolTensor):
    ...

class ByteTensor(_CudaTensor, _ByteTensor):
    ...

class CharTensor(_CudaTensor, _CharTensor):
    ...

class ShortTensor(_CudaTensor, _ShortTensor):
    ...

class IntTensor(_CudaTensor, _IntTensor):
    ...

class HalfTensor(_CudaTensor, _HalfTensor):
    ...

class FloatTensor(_CudaTensor, _FloatTensor):
    ...

class DoubleTensor(_CudaTensor, _DoubleTensor):
    ...

class LongTensor(_CudaTensor, _LongTensor):
    ...

class BFloat16Tensor(_CudaTensor, _BFloat16Tensor):
    ...

class device:
    def __init__(self, device):
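
For reference, a minimal sketch of the behaviour the new _CudaTensor mixin is meant to provide, mirroring the unit test added later in this PR (it assumes mindtorch is installed so that mindtorch.torch is importable):

import mindtorch.torch as ms_torch

a = ms_torch.cuda.FloatTensor(3)
b = ms_torch.cuda.ByteTensor(3)

# With no argument, type() on the torch.cuda.* tensor classes now reports
# the CUDA-qualified type string rather than the CPU-style name.
print(a.type())   # expected: 'torch.cuda.FloatTensor'
print(b.type())   # expected: 'torch.cuda.ByteTensor'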


mindtorch/torch/tensor.py (+1, -1)

@@ -2182,7 +2182,7 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
        unsupported_attr(non_blocking)
        unsupported_attr(kwargs)
        if dtype is None:
            return _get_type_from_dtype(self.dtype)
            return 'torch.' + _get_type_from_dtype(self.dtype)

        _dtype = _get_dtype_from_type(dtype)
        if _dtype == self.dtype:
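
In this one-line change, the first return above is the old line and the second is its replacement: type() with no argument used to return the bare type name (e.g. 'FloatTensor') and now prefixes it with 'torch.' to match PyTorch's format. A minimal sketch of the expected behaviour, again mirroring the new unit test (assumes mindtorch is installed):

import mindtorch.torch as ms_torch

a = ms_torch.FloatTensor(3)

# Tensor.type() with no argument now returns the fully qualified type name.
print(a.type())   # expected: 'torch.FloatTensor'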


testing/ut/pytorch/tensor/test_tensor2.py (+16, -0)

@@ -800,6 +800,21 @@ def test_zero_dimention():
    # ms_tensor3 = ms_tensor2.add_one()
    # param_compare(torch_tensor3, ms_tensor3)

def test_type_cuda_tensor():
    a = ms_torch.cuda.FloatTensor(3)
    b = ms_torch.cuda.ByteTensor(3)
    assert a.type() == 'torch.cuda.FloatTensor'
    assert b.type() == 'torch.cuda.ByteTensor'

    a = ms_torch.FloatTensor(3)
    b = ms_torch.ByteTensor(3)
    assert a.type() == 'torch.FloatTensor'
    assert b.type() == 'torch.ByteTensor'

    a = a.cuda()
    b = b.cuda()
    assert a.type() == 'torch.FloatTensor'
    assert b.type() == 'torch.ByteTensor'

def test_mH():
    data_np = np.random.random((2, 3))
@@ -893,3 +908,4 @@ if __name__ == '__main__':
    test_mH()
    test_mT()
    test_storage_offset()
    test_type_cuda_tensor()
