#195 [众智活动任务1]MaxUnpool

Closed
pengtaox wants to merge 1 commits from pengtaox/MSAdapter:maxunpool into master
  1. +122
    -0
      ms_adapter/pytorch/nn/functional.py
  2. +150
    -0
      testing/ut/pytorch/nn/functional/test_maxunpool.py

+ 122
- 0
ms_adapter/pytorch/nn/functional.py View File

@@ -1433,6 +1433,128 @@ def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
out = _max_pool(input)
return cast_to_adapter_tensor(out)

def _unpool_output_size(input, kernel_size, stride, padding, output_size):
input_size = input.shape
default_size = []

if isinstance(kernel_size, int):
kernel_size = [kernel_size]
if isinstance(stride, int):
stride = [stride]
if isinstance(padding, int):
padding = [padding]

for d in range(len(kernel_size)):
default_size.append((input_size[-len(kernel_size) + d] - 1) * stride[d] + kernel_size[d] - 2 * padding[d])
if output_size is None:
ret = default_size
else:
if len(output_size) == len(kernel_size) + 2:
output_size = output_size[2:]
if len(output_size) != len(kernel_size):
raise ValueError(
"output_size should be a sequence containing "
"{} or {} elements, but it has a length of '{}'".format(
len(kernel_size), len(kernel_size) + 2, len(output_size)
)
)
for d in range(len(kernel_size)):
min_size = default_size[d] - stride[d]
max_size = default_size[d] + stride[d]
if not (min_size < output_size[d] < max_size):
raise ValueError(
'invalid output_size "{}" (dim {} must be between {} and {})'.format(
output_size, d, min_size, max_size
)
)

ret = output_size
return ret

def max_unpool1d(input, indices, kernel_size, stride, padding, output_size = None):
input = cast_to_ms_tensor(input)
indices = cast_to_ms_tensor(indices)

input_shape = list(input.shape)
size = _unpool_output_size(input, kernel_size, stride, padding, output_size)

output_size = tuple(input_shape[0:-1] + size)

out = ms.ops.Zeros()(output_size, input.dtype)

if len(input_shape) == 2:
row = ms.numpy.arange(indices.shape[0]).reshape(-1,1).broadcast_to((-1, indices.shape[1]))
out[row,indices] = input
elif len(input_shape) == 3:
batch = ms.numpy.arange(indices.shape[0]).reshape(-1,1,1).broadcast_to((-1, indices.shape[1], indices.shape[2]))
row = ms.numpy.arange(indices.shape[1]).reshape(-1,1).broadcast_to((indices.shape[0], -1, indices.shape[2]))
out[batch,row,indices] = input

return out


def max_unpool2d(input, indices, kernel_size, stride, padding, output_size = None):
input = cast_to_ms_tensor(input)
indices = cast_to_ms_tensor(indices)
kernel_size = (kernel_size, kernel_size)
if stride is not None:
_stride = (stride, stride)
else:
_stride = kernel_size
padding = (padding, padding)
input_shape = list(input.shape)

size = _unpool_output_size(input, kernel_size, _stride, padding, output_size)

output_size = tuple(input_shape[0:-2] + size)

out = ms.ops.Zeros()(output_size, input.dtype)

if len(input_shape) == 3:
row = ms.numpy.arange(input_shape[0]).reshape(-1,1).broadcast_to((-1, input_shape[1] * input_shape[2]))
indices = indices.reshape(-1,input_shape[1] * input_shape[2])
out = out.reshape((-1, output_size[1] * output_size[2]))
out[row,indices] = input.reshape(-1,input_shape[1] * input_shape[2])
elif len(input_shape) == 4:
batch = ms.numpy.arange(input_shape[0]).reshape(-1,1,1).broadcast_to((-1, input_shape[1], input_shape[2]*input_shape[3]))
row = ms.numpy.arange(input_shape[1]).reshape(-1,1).broadcast_to((input_shape[0], -1, input_shape[2]*input_shape[3]))
indices = indices.reshape(input_shape[0],-1,input_shape[2] * input_shape[3])
out = out.view((output_size[0], -1, output_size[2] * output_size[3]))
out[batch,row,indices] = input.view((input_shape[0], -1, input_shape[2] * input_shape[3]))

return out.reshape(output_size)

def max_unpool3d(input, indices, kernel_size, stride, padding, output_size = None):
input = cast_to_ms_tensor(input)
indices = cast_to_ms_tensor(indices)
kernel_size = (kernel_size, kernel_size, kernel_size)
if stride is not None:
_stride = (stride, stride, stride)
else:
_stride = kernel_size
padding = (padding, padding, padding)
input_shape = list(input.shape)

size = _unpool_output_size(input, kernel_size, _stride, padding, output_size)

output_size = tuple(input_shape[0:-3] + size)

out = ms.ops.Zeros()(output_size, input.dtype)

if len(input_shape) == 4:
row = ms.numpy.arange(input_shape[0]).reshape(-1,1).broadcast_to((-1, input_shape[1] * input_shape[2] * input_shape[3]))
indices = indices.reshape(-1,input_shape[1] * input_shape[2] * input_shape[3])
out = out.reshape((-1, output_size[1] * output_size[2] * output_size[3]))
out[row,indices] = input.reshape(-1,input_shape[1] * input_shape[2] * input_shape[3])
elif len(input_shape) == 5:
dim = input_shape[2]*input_shape[3]*input_shape[4]
batch = ms.numpy.arange(input_shape[0]).reshape(-1,1,1).broadcast_to((-1, input_shape[1], dim))
row = ms.numpy.arange(input_shape[1]).reshape(-1,1).broadcast_to((input_shape[0], -1, dim))
indices = indices.reshape(input_shape[0],-1, dim)
out = out.view((output_size[0], -1, output_size[2] * output_size[3] * output_size[4]))
out[batch,row,indices] = input.reshape((input_shape[0], -1, dim))

return out.reshape(output_size)

def linear(input, weight, bias=None):
@constexpr


+ 150
- 0
testing/ut/pytorch/nn/functional/test_maxunpool.py View File

@@ -0,0 +1,150 @@
import numpy as np
import torch
import mindspore as ms
import ms_adapter.pytorch as ms_torch

def test_maxunpool1d_with_2dim():
ms.context.set_context(mode=ms.PYNATIVE_MODE)
N = np.random.randint(1, 65)
C = np.random.randint(1, 513)
tensor = np.random.randn(N, C).astype(np.float32)

torch_tensor = torch.tensor(tensor)
kernel_size = np.random.randint(1, C + 1)
padding = np.random.randint(0, kernel_size/2 + 1)
stride = 1

torch_pooling, torch_indices = torch.nn.functional.max_pool1d(torch_tensor, kernel_size, stride, padding, return_indices=True)
torch_output = torch.nn.functional.max_unpool1d(torch_pooling, torch_indices, kernel_size, stride, padding)

ms_pooling = ms_torch.tensor(torch_pooling.numpy())
ms_indices = ms_torch.tensor(torch_indices.numpy())

ms_output = ms_torch.nn.functional.max_unpool1d(ms_pooling, ms_indices, kernel_size, stride, padding)

assert np.allclose(ms_output.numpy(), torch_output.numpy())


def test_maxunpool1d_with_3dim():
ms.context.set_context(mode=ms.PYNATIVE_MODE)
B = np.random.randint(1, 65)
N = np.random.randint(1, 65)
C = np.random.randint(1, 513)
tensor = np.random.randn(B, N, C).astype(np.float32)

torch_tensor = torch.tensor(tensor)
kernel_size = np.random.randint(1, C + 1)
padding = np.random.randint(0, kernel_size/2 + 1)
stride = 1

torch_pooling, torch_indices = torch.nn.functional.max_pool1d(torch_tensor, kernel_size, stride, padding, return_indices=True)
torch_output = torch.nn.functional.max_unpool1d(torch_pooling, torch_indices, kernel_size, stride, padding)

ms_pooling = ms_torch.tensor(torch_pooling.numpy())
ms_indices = ms_torch.tensor(torch_indices.numpy())

ms_output = ms_torch.nn.functional.max_unpool1d(ms_pooling, ms_indices, kernel_size, stride, padding)

assert np.allclose(ms_output.numpy(), torch_output.numpy())


def test_maxunpool2d_with_3dim():
ms.context.set_context(mode=ms.PYNATIVE_MODE)
H = W = np.random.randint(1, 65)

C = np.random.randint(1, 513)
tensor = np.random.randn(C, H, W).astype(np.float32)

torch_tensor = torch.tensor(tensor)
kernel_size = np.random.randint(1, H + 1)
padding = np.random.randint(0, kernel_size/2 + 1)
stride = 1

torch_pooling, torch_indices = torch.nn.functional.max_pool2d(torch_tensor, kernel_size, stride, padding, return_indices=True)
torch_output = torch.nn.functional.max_unpool2d(torch_pooling, torch_indices, kernel_size, stride, padding)

ms_pooling = ms_torch.tensor(torch_pooling.numpy())
ms_indices = ms_torch.tensor(torch_indices.numpy())

ms_output = ms_torch.nn.functional.max_unpool2d(ms_pooling, ms_indices, kernel_size, stride, padding)

assert np.allclose(ms_output.numpy(), torch_output.numpy())


def test_maxunpool2d_with_4dim():
ms.context.set_context(mode=ms.PYNATIVE_MODE)
H = W = np.random.randint(1, 65)
C = np.random.randint(1, 513)
N = np.random.randint(1, 32)
tensor = np.random.randn(N, C, H, W).astype(np.float32)

torch_tensor = torch.tensor(tensor)
kernel_size = np.random.randint(1, H + 1)
padding = np.random.randint(0, kernel_size/2 + 1)
stride = 1

torch_pooling, torch_indices = torch.nn.functional.max_pool2d(torch_tensor, kernel_size, stride, padding, return_indices=True)
torch_output = torch.nn.functional.max_unpool2d(torch_pooling, torch_indices, kernel_size, stride, padding)

ms_pooling = ms_torch.tensor(torch_pooling.numpy())
ms_indices = ms_torch.tensor(torch_indices.numpy())

ms_output = ms_torch.nn.functional.max_unpool2d(ms_pooling, ms_indices, kernel_size, stride, padding)

assert np.allclose(ms_output.numpy(), torch_output.numpy())


def test_maxunpool3d_with_4dim():
ms.context.set_context(mode=ms.PYNATIVE_MODE)
H = W = np.random.randint(1, 65)
C = np.random.randint(1, 1025)
D = np.random.randint(1, H + 1)
tensor = np.random.randn(C, D, H, W).astype(np.float32)

torch_tensor = torch.tensor(tensor)
kernel_size = np.random.randint(1, D + 1)
padding = np.random.randint(0, kernel_size/2 + 1)
stride = 1

torch_pooling, torch_indices = torch.nn.functional.max_pool3d(torch_tensor, kernel_size, stride, padding, return_indices=True)
torch_output = torch.nn.functional.max_unpool3d(torch_pooling, torch_indices, kernel_size, stride, padding)

ms_pooling = ms_torch.tensor(torch_pooling.numpy())
ms_indices = ms_torch.tensor(torch_indices.numpy())

ms_output = ms_torch.nn.functional.max_unpool3d(ms_pooling, ms_indices, kernel_size, stride, padding)

assert np.allclose(ms_output.numpy(), torch_output.numpy())


def test_maxunpool3d_with_5dim():
ms.context.set_context(mode=ms.PYNATIVE_MODE)
H = W = np.random.randint(1, 33)
C = np.random.randint(1, 513)
D = np.random.randint(1, H + 1)
N = np.random.randint(1, 17)
tensor = np.random.randn(N, C, D, H, W).astype(np.float32)

torch_tensor = torch.tensor(tensor)
kernel_size = np.random.randint(1, D + 1)
padding = np.random.randint(0, kernel_size/2 + 1)
stride = 1

torch_pooling, torch_indices = torch.nn.functional.max_pool3d(torch_tensor, kernel_size, stride, padding, return_indices=True)
torch_output = torch.nn.functional.max_unpool3d(torch_pooling, torch_indices, kernel_size, stride, padding)

ms_pooling = ms_torch.tensor(torch_pooling.numpy())
ms_indices = ms_torch.tensor(torch_indices.numpy())

ms_output = ms_torch.nn.functional.max_unpool3d(ms_pooling, ms_indices, kernel_size, stride, padding)

assert np.allclose(ms_output.numpy(), torch_output.numpy())

if __name__ == '__main__':
test_maxunpool1d_with_2dim()
test_maxunpool1d_with_3dim()
test_maxunpool2d_with_3dim()
test_maxunpool2d_with_4dim()
test_maxunpool3d_with_4dim()
test_maxunpool3d_with_5dim()


Loading…
Cancel
Save