From ce6e8a64c559587dbaf83002e9f8ffc93e3a15a1 Mon Sep 17 00:00:00 2001 From: pengtao <756625088@qq.com> Date: Wed, 30 Nov 2022 20:20:55 +0800 Subject: [PATCH] =?UTF-8?q?[=E4=BC=97=E6=99=BA=E6=B4=BB=E5=8A=A8=E4=BB=BB?= =?UTF-8?q?=E5=8A=A11]MaxUnpool?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ms_adapter/pytorch/nn/functional.py | 122 ++++++++++++++ .../pytorch/nn/functional/test_maxunpool.py | 150 ++++++++++++++++++ 2 files changed, 272 insertions(+) create mode 100644 testing/ut/pytorch/nn/functional/test_maxunpool.py diff --git a/ms_adapter/pytorch/nn/functional.py b/ms_adapter/pytorch/nn/functional.py index f9dd70d4..793d7641 100644 --- a/ms_adapter/pytorch/nn/functional.py +++ b/ms_adapter/pytorch/nn/functional.py @@ -1433,6 +1433,128 @@ def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, out = _max_pool(input) return cast_to_adapter_tensor(out) +def _unpool_output_size(input, kernel_size, stride, padding, output_size): + input_size = input.shape + default_size = [] + + if isinstance(kernel_size, int): + kernel_size = [kernel_size] + if isinstance(stride, int): + stride = [stride] + if isinstance(padding, int): + padding = [padding] + + for d in range(len(kernel_size)): + default_size.append((input_size[-len(kernel_size) + d] - 1) * stride[d] + kernel_size[d] - 2 * padding[d]) + if output_size is None: + ret = default_size + else: + if len(output_size) == len(kernel_size) + 2: + output_size = output_size[2:] + if len(output_size) != len(kernel_size): + raise ValueError( + "output_size should be a sequence containing " + "{} or {} elements, but it has a length of '{}'".format( + len(kernel_size), len(kernel_size) + 2, len(output_size) + ) + ) + for d in range(len(kernel_size)): + min_size = default_size[d] - stride[d] + max_size = default_size[d] + stride[d] + if not (min_size < output_size[d] < max_size): + raise ValueError( + 'invalid output_size "{}" (dim {} must be between {} and {})'.format( + output_size, d, min_size, max_size + ) + ) + + ret = output_size + return ret + +def max_unpool1d(input, indices, kernel_size, stride, padding, output_size = None): + input = cast_to_ms_tensor(input) + indices = cast_to_ms_tensor(indices) + + input_shape = list(input.shape) + size = _unpool_output_size(input, kernel_size, stride, padding, output_size) + + output_size = tuple(input_shape[0:-1] + size) + + out = ms.ops.Zeros()(output_size, input.dtype) + + if len(input_shape) == 2: + row = ms.numpy.arange(indices.shape[0]).reshape(-1,1).broadcast_to((-1, indices.shape[1])) + out[row,indices] = input + elif len(input_shape) == 3: + batch = ms.numpy.arange(indices.shape[0]).reshape(-1,1,1).broadcast_to((-1, indices.shape[1], indices.shape[2])) + row = ms.numpy.arange(indices.shape[1]).reshape(-1,1).broadcast_to((indices.shape[0], -1, indices.shape[2])) + out[batch,row,indices] = input + + return out + + +def max_unpool2d(input, indices, kernel_size, stride, padding, output_size = None): + input = cast_to_ms_tensor(input) + indices = cast_to_ms_tensor(indices) + kernel_size = (kernel_size, kernel_size) + if stride is not None: + _stride = (stride, stride) + else: + _stride = kernel_size + padding = (padding, padding) + input_shape = list(input.shape) + + size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) + + output_size = tuple(input_shape[0:-2] + size) + + out = ms.ops.Zeros()(output_size, input.dtype) + + if len(input_shape) == 3: + row = ms.numpy.arange(input_shape[0]).reshape(-1,1).broadcast_to((-1, input_shape[1] * input_shape[2])) + indices = indices.reshape(-1,input_shape[1] * input_shape[2]) + out = out.reshape((-1, output_size[1] * output_size[2])) + out[row,indices] = input.reshape(-1,input_shape[1] * input_shape[2]) + elif len(input_shape) == 4: + batch = ms.numpy.arange(input_shape[0]).reshape(-1,1,1).broadcast_to((-1, input_shape[1], input_shape[2]*input_shape[3])) + row = ms.numpy.arange(input_shape[1]).reshape(-1,1).broadcast_to((input_shape[0], -1, input_shape[2]*input_shape[3])) + indices = indices.reshape(input_shape[0],-1,input_shape[2] * input_shape[3]) + out = out.view((output_size[0], -1, output_size[2] * output_size[3])) + out[batch,row,indices] = input.view((input_shape[0], -1, input_shape[2] * input_shape[3])) + + return out.reshape(output_size) + +def max_unpool3d(input, indices, kernel_size, stride, padding, output_size = None): + input = cast_to_ms_tensor(input) + indices = cast_to_ms_tensor(indices) + kernel_size = (kernel_size, kernel_size, kernel_size) + if stride is not None: + _stride = (stride, stride, stride) + else: + _stride = kernel_size + padding = (padding, padding, padding) + input_shape = list(input.shape) + + size = _unpool_output_size(input, kernel_size, _stride, padding, output_size) + + output_size = tuple(input_shape[0:-3] + size) + + out = ms.ops.Zeros()(output_size, input.dtype) + + if len(input_shape) == 4: + row = ms.numpy.arange(input_shape[0]).reshape(-1,1).broadcast_to((-1, input_shape[1] * input_shape[2] * input_shape[3])) + indices = indices.reshape(-1,input_shape[1] * input_shape[2] * input_shape[3]) + out = out.reshape((-1, output_size[1] * output_size[2] * output_size[3])) + out[row,indices] = input.reshape(-1,input_shape[1] * input_shape[2] * input_shape[3]) + elif len(input_shape) == 5: + dim = input_shape[2]*input_shape[3]*input_shape[4] + batch = ms.numpy.arange(input_shape[0]).reshape(-1,1,1).broadcast_to((-1, input_shape[1], dim)) + row = ms.numpy.arange(input_shape[1]).reshape(-1,1).broadcast_to((input_shape[0], -1, dim)) + indices = indices.reshape(input_shape[0],-1, dim) + out = out.view((output_size[0], -1, output_size[2] * output_size[3] * output_size[4])) + out[batch,row,indices] = input.reshape((input_shape[0], -1, dim)) + + return out.reshape(output_size) def linear(input, weight, bias=None): @constexpr diff --git a/testing/ut/pytorch/nn/functional/test_maxunpool.py b/testing/ut/pytorch/nn/functional/test_maxunpool.py new file mode 100644 index 00000000..9d561ccc --- /dev/null +++ b/testing/ut/pytorch/nn/functional/test_maxunpool.py @@ -0,0 +1,150 @@ +import numpy as np +import torch +import mindspore as ms +import ms_adapter.pytorch as ms_torch + +def test_maxunpool1d_with_2dim(): + ms.context.set_context(mode=ms.PYNATIVE_MODE) + N = np.random.randint(1, 65) + C = np.random.randint(1, 513) + tensor = np.random.randn(N, C).astype(np.float32) + + torch_tensor = torch.tensor(tensor) + kernel_size = np.random.randint(1, C + 1) + padding = np.random.randint(0, kernel_size/2 + 1) + stride = 1 + + torch_pooling, torch_indices = torch.nn.functional.max_pool1d(torch_tensor, kernel_size, stride, padding, return_indices=True) + torch_output = torch.nn.functional.max_unpool1d(torch_pooling, torch_indices, kernel_size, stride, padding) + + ms_pooling = ms_torch.tensor(torch_pooling.numpy()) + ms_indices = ms_torch.tensor(torch_indices.numpy()) + + ms_output = ms_torch.nn.functional.max_unpool1d(ms_pooling, ms_indices, kernel_size, stride, padding) + + assert np.allclose(ms_output.numpy(), torch_output.numpy()) + + +def test_maxunpool1d_with_3dim(): + ms.context.set_context(mode=ms.PYNATIVE_MODE) + B = np.random.randint(1, 65) + N = np.random.randint(1, 65) + C = np.random.randint(1, 513) + tensor = np.random.randn(B, N, C).astype(np.float32) + + torch_tensor = torch.tensor(tensor) + kernel_size = np.random.randint(1, C + 1) + padding = np.random.randint(0, kernel_size/2 + 1) + stride = 1 + + torch_pooling, torch_indices = torch.nn.functional.max_pool1d(torch_tensor, kernel_size, stride, padding, return_indices=True) + torch_output = torch.nn.functional.max_unpool1d(torch_pooling, torch_indices, kernel_size, stride, padding) + + ms_pooling = ms_torch.tensor(torch_pooling.numpy()) + ms_indices = ms_torch.tensor(torch_indices.numpy()) + + ms_output = ms_torch.nn.functional.max_unpool1d(ms_pooling, ms_indices, kernel_size, stride, padding) + + assert np.allclose(ms_output.numpy(), torch_output.numpy()) + + +def test_maxunpool2d_with_3dim(): + ms.context.set_context(mode=ms.PYNATIVE_MODE) + H = W = np.random.randint(1, 65) + + C = np.random.randint(1, 513) + tensor = np.random.randn(C, H, W).astype(np.float32) + + torch_tensor = torch.tensor(tensor) + kernel_size = np.random.randint(1, H + 1) + padding = np.random.randint(0, kernel_size/2 + 1) + stride = 1 + + torch_pooling, torch_indices = torch.nn.functional.max_pool2d(torch_tensor, kernel_size, stride, padding, return_indices=True) + torch_output = torch.nn.functional.max_unpool2d(torch_pooling, torch_indices, kernel_size, stride, padding) + + ms_pooling = ms_torch.tensor(torch_pooling.numpy()) + ms_indices = ms_torch.tensor(torch_indices.numpy()) + + ms_output = ms_torch.nn.functional.max_unpool2d(ms_pooling, ms_indices, kernel_size, stride, padding) + + assert np.allclose(ms_output.numpy(), torch_output.numpy()) + + +def test_maxunpool2d_with_4dim(): + ms.context.set_context(mode=ms.PYNATIVE_MODE) + H = W = np.random.randint(1, 65) + C = np.random.randint(1, 513) + N = np.random.randint(1, 32) + tensor = np.random.randn(N, C, H, W).astype(np.float32) + + torch_tensor = torch.tensor(tensor) + kernel_size = np.random.randint(1, H + 1) + padding = np.random.randint(0, kernel_size/2 + 1) + stride = 1 + + torch_pooling, torch_indices = torch.nn.functional.max_pool2d(torch_tensor, kernel_size, stride, padding, return_indices=True) + torch_output = torch.nn.functional.max_unpool2d(torch_pooling, torch_indices, kernel_size, stride, padding) + + ms_pooling = ms_torch.tensor(torch_pooling.numpy()) + ms_indices = ms_torch.tensor(torch_indices.numpy()) + + ms_output = ms_torch.nn.functional.max_unpool2d(ms_pooling, ms_indices, kernel_size, stride, padding) + + assert np.allclose(ms_output.numpy(), torch_output.numpy()) + + +def test_maxunpool3d_with_4dim(): + ms.context.set_context(mode=ms.PYNATIVE_MODE) + H = W = np.random.randint(1, 65) + C = np.random.randint(1, 1025) + D = np.random.randint(1, H + 1) + tensor = np.random.randn(C, D, H, W).astype(np.float32) + + torch_tensor = torch.tensor(tensor) + kernel_size = np.random.randint(1, D + 1) + padding = np.random.randint(0, kernel_size/2 + 1) + stride = 1 + + torch_pooling, torch_indices = torch.nn.functional.max_pool3d(torch_tensor, kernel_size, stride, padding, return_indices=True) + torch_output = torch.nn.functional.max_unpool3d(torch_pooling, torch_indices, kernel_size, stride, padding) + + ms_pooling = ms_torch.tensor(torch_pooling.numpy()) + ms_indices = ms_torch.tensor(torch_indices.numpy()) + + ms_output = ms_torch.nn.functional.max_unpool3d(ms_pooling, ms_indices, kernel_size, stride, padding) + + assert np.allclose(ms_output.numpy(), torch_output.numpy()) + + +def test_maxunpool3d_with_5dim(): + ms.context.set_context(mode=ms.PYNATIVE_MODE) + H = W = np.random.randint(1, 33) + C = np.random.randint(1, 513) + D = np.random.randint(1, H + 1) + N = np.random.randint(1, 17) + tensor = np.random.randn(N, C, D, H, W).astype(np.float32) + + torch_tensor = torch.tensor(tensor) + kernel_size = np.random.randint(1, D + 1) + padding = np.random.randint(0, kernel_size/2 + 1) + stride = 1 + + torch_pooling, torch_indices = torch.nn.functional.max_pool3d(torch_tensor, kernel_size, stride, padding, return_indices=True) + torch_output = torch.nn.functional.max_unpool3d(torch_pooling, torch_indices, kernel_size, stride, padding) + + ms_pooling = ms_torch.tensor(torch_pooling.numpy()) + ms_indices = ms_torch.tensor(torch_indices.numpy()) + + ms_output = ms_torch.nn.functional.max_unpool3d(ms_pooling, ms_indices, kernel_size, stride, padding) + + assert np.allclose(ms_output.numpy(), torch_output.numpy()) + +if __name__ == '__main__': + test_maxunpool1d_with_2dim() + test_maxunpool1d_with_3dim() + test_maxunpool2d_with_3dim() + test_maxunpool2d_with_4dim() + test_maxunpool3d_with_4dim() + test_maxunpool3d_with_5dim() + -- 2.34.1