|
|
@@ -6,18 +6,71 @@ import traceback |
|
|
|
import mindspore as ms |
|
|
|
|
|
|
|
from mindtorch.utils import get_backend, unsupported_attr |
|
|
|
from mindtorch.torch.tensor import BoolTensor, ByteTensor, CharTensor, ShortTensor, IntTensor, HalfTensor, \ |
|
|
|
FloatTensor, DoubleTensor, LongTensor, BFloat16Tensor, tensor |
|
|
|
from mindtorch.torch.tensor import BoolTensor as _BoolTensor, \ |
|
|
|
ByteTensor as _ByteTensor, \ |
|
|
|
CharTensor as _CharTensor, \ |
|
|
|
ShortTensor as _ShortTensor, \ |
|
|
|
IntTensor as _IntTensor, \ |
|
|
|
HalfTensor as _HalfTensor, \ |
|
|
|
FloatTensor as _FloatTensor, \ |
|
|
|
DoubleTensor as _DoubleTensor , \ |
|
|
|
LongTensor as _LongTensor, \ |
|
|
|
BFloat16Tensor as _BFloat16Tensor, \ |
|
|
|
tensor, cast_to_ms_tensor, cast_to_adapter_tensor |
|
|
|
import mindtorch.torch.cuda.amp as amp |
|
|
|
from mindtorch.torch.cuda.random import manual_seed_all, manual_seed |
|
|
|
from mindtorch.torch.logging import warning |
|
|
|
from mindtorch.torch.cuda.streams import * |
|
|
|
from mindtorch.torch.common.dtype import _get_type_from_dtype, _get_dtype_from_type |
|
|
|
from ._utils import _get_device_index |
|
|
|
|
|
|
|
# Per-thread record of CUDA/backend initialization state.
_tls = threading.local()

# Serializes one-time lazy initialization across threads.
_initialization_lock = threading.Lock()

# Work deferred until initialization completes; presumably (callable, traceback)
# pairs queued by registration helpers — TODO confirm against the enqueuing code.
_queued_calls = []
|
|
|
|
|
|
|
class _CudaTensor:
    """Mixin adding a torch-style ``type()`` method to the CUDA tensor classes."""

    def type(self, dtype=None, non_blocking=False, **kwargs):
        """Report this tensor's CUDA type string, or cast it to *dtype*.

        With ``dtype=None`` the torch-style type name (``'torch.cuda.…'``) is
        returned; otherwise a tensor converted to the requested dtype is
        returned (``self`` itself when no conversion is needed).
        ``non_blocking`` and any extra keyword arguments are accepted only for
        API compatibility and have no effect.
        """
        unsupported_attr(non_blocking)
        unsupported_attr(kwargs)
        if dtype is None:
            # Query form: no target dtype means "tell me what I am".
            return 'torch.cuda.' + _get_type_from_dtype(self.dtype)

        target = _get_dtype_from_type(dtype)
        if target == self.dtype:
            # Already the requested dtype — hand back the same object.
            return self
        converted = cast_to_ms_tensor(self).astype(target)
        return cast_to_adapter_tensor(converted)
|
|
|
|
|
|
|
class BoolTensor(_CudaTensor, _BoolTensor):
    """CUDA counterpart of ``BoolTensor``; behavior comes from its bases."""
    pass
|
|
|
|
|
|
|
class ByteTensor(_CudaTensor, _ByteTensor):
    """CUDA counterpart of ``ByteTensor``; behavior comes from its bases."""
    pass
|
|
|
|
|
|
|
class CharTensor(_CudaTensor, _CharTensor):
    """CUDA counterpart of ``CharTensor``; behavior comes from its bases."""
    pass
|
|
|
|
|
|
|
class ShortTensor(_CudaTensor, _ShortTensor):
    """CUDA counterpart of ``ShortTensor``; behavior comes from its bases."""
    pass
|
|
|
|
|
|
|
class IntTensor(_CudaTensor, _IntTensor):
    """CUDA counterpart of ``IntTensor``; behavior comes from its bases."""
    pass
|
|
|
|
|
|
|
class HalfTensor(_CudaTensor, _HalfTensor):
    """CUDA counterpart of ``HalfTensor``; behavior comes from its bases."""
    pass
|
|
|
|
|
|
|
class FloatTensor(_CudaTensor, _FloatTensor):
    """CUDA counterpart of ``FloatTensor``; behavior comes from its bases."""
    pass
|
|
|
|
|
|
|
class DoubleTensor(_CudaTensor, _DoubleTensor):
    """CUDA counterpart of ``DoubleTensor``; behavior comes from its bases."""
    pass
|
|
|
|
|
|
|
class LongTensor(_CudaTensor, _LongTensor):
    """CUDA counterpart of ``LongTensor``; behavior comes from its bases."""
    pass
|
|
|
|
|
|
|
class BFloat16Tensor(_CudaTensor, _BFloat16Tensor):
    """CUDA counterpart of ``BFloat16Tensor``; behavior comes from its bases."""
    pass
|
|
|
|
|
|
|
class device: |
|
|
|
def __init__(self, device): |
|
|
|