#39 add adaptivemaxpool

Merged
hanjr merged 1 commits from hanjr-patch into master 1 year ago
  1. +3
    -0
      ms_adapter/pytorch/nn/modules/__init__.py
  2. +108
    -5
      ms_adapter/pytorch/nn/modules/pooling.py
  3. +192
    -2
      testing/layers/test_adaptivepool.py

+ 3
- 0
ms_adapter/pytorch/nn/modules/__init__.py View File

@@ -22,6 +22,9 @@ __all__ = [
'AdaptiveAvgPool1d',
'AdaptiveAvgPool2d',
'AdaptiveAvgPool3d',
'AdaptiveMaxPool1d',
'AdaptiveMaxPool2d',
'AdaptiveMaxPool3d',
'MaxPool1d',
'MaxPool2d',
'MaxPool3d',


+ 108
- 5
ms_adapter/pytorch/nn/modules/pooling.py View File

@@ -11,7 +11,8 @@ from .module import Module

__all__ = ['MaxPool1d', 'MaxPool2d', 'MaxPool3d',
'AvgPool1d', 'AvgPool2d', 'AvgPool3d',
'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d']
'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d',
'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d']

class _MaxPoolNd(Module):
def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False):
@@ -184,16 +185,16 @@ class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd):
self.output_size = output_size
self.shape = F.shape

def construct(self, x):
_, _, width = self.shape(x)
def construct(self, input):
_, _, width = self.shape(input)
stride = width // self.output_size
kernel_size = width - (self.output_size - 1) * stride
stride = (1, width // self.output_size)
kernel_size = (1, kernel_size)

max_pool = P.AvgPool(kernel_size=kernel_size, strides=stride, pad_mode="valid", data_format="NCHW")
x = self.expand(x, 2)
x = max_pool(x)
input = self.expand(input, 2)
x = max_pool(input)
x = self.squeeze(x)

return x
@@ -268,3 +269,105 @@ class AdaptiveAvgPool3d(_AdaptiveAvgPoolNd):
)
outputs = avg_pool(input)
return outputs


class _AdaptiveMaxPoolNd(Module):
__constants__ = ['output_size', 'return_indices']
return_indices: bool

def __init__(self, output_size, return_indices = False):
super(_AdaptiveMaxPoolNd, self).__init__()
self.output_size = output_size
self.return_indices = return_indices

def extra_repr(self) -> str:
return 'output_size={}'.format(self.output_size)



class AdaptiveMaxPool1d(_AdaptiveMaxPoolNd):

def __init__(self, output_size, return_indices = False):
"""Initialize AdaptiveMaxPool1d."""
super(AdaptiveMaxPool1d, self).__init__(output_size, return_indices)
self.expand = P.ExpandDims()
self.squeeze = P.Squeeze(2)
self.output_size = output_size
self.shape = F.shape
self.return_indices = return_indices

def construct(self, input):
_, _, width = self.shape(input)
stride = width // self.output_size
kernel_size = width - (self.output_size - 1) * stride
stride = (1, width // self.output_size)
kernel_size = (1, kernel_size)
if self.return_indices:
max_pool = P.MaxPoolWithArgmax(kernel_size=kernel_size, strides=stride, pad_mode='valid', data_format="NCHW")
x = self.expand(input, 2)
x, idx = max_pool(x)
x = self.squeeze(x)
idx = self.squeeze(idx)
return (x, idx)
else:
max_pool = P.AvgPool(kernel_size=kernel_size, strides=stride, pad_mode="valid", data_format="NCHW")
x = self.expand(input, 2)
x = max_pool(x)
x = self.squeeze(x)

return x


class AdaptiveMaxPool2d(_AdaptiveMaxPoolNd):
def __init__(self, output_size, return_indices = False):
super(AdaptiveMaxPool2d, self).__init__(output_size, return_indices)

self.adaptive_max_pool2d = P.nn_ops.AdaptiveMaxPool2D(output_size, return_indices)

def forward(self, input):

return self.adaptive_max_pool2d(input)


class AdaptiveMaxPool3d(_AdaptiveMaxPoolNd):
def __init__(self, output_size, return_indices = False):
super(AdaptiveMaxPool3d, self).__init__(output_size, return_indices)
self.output_size = output_size
self.shape = P.Shape()
if not isinstance(self.output_size, Iterable):
self.output_size = [self.output_size, ] * 3
self.condition = [0,] * 3
if None in self.output_size:
self.output_size = list(self.output_size)
if self.output_size[0] == None:
self.condition [0] = 1
self.output_size[0] = 0
if self.output_size[1] == None:
self.condition [1] = 1
self.output_size[1] = 0
if self.output_size[2] == None:
self.condition[2] = 1
self.output_size[2] = 0
if return_indices:
raise NotImplementedError('AdaptiveMaxPool3d doesn\'t support return_indices now.')


def forward(self, input):
n, c, d, h, w = self.shape(input)
out_d = self.output_size[0] + self.condition[0] * d
out_h = self.output_size[1] + self.condition[1] * h
out_w = self.output_size[2] + self.condition[2] * w
stride_d = d // out_d
kernel_d = d - (out_d - 1) * stride_d
stride_h = h // out_h
kernel_h = h - (out_h - 1) * stride_h
stride_w = w // out_w
kernel_w = w - (out_w - 1) * stride_w
avg_pool = P.MaxPool3D(
kernel_size=(kernel_d, kernel_h, kernel_w), strides=(stride_d, stride_h, stride_w), pad_mode="valid",
data_format="NCDHW"
)
outputs = avg_pool(input)
return outputs



+ 192
- 2
testing/layers/test_adaptivepool.py View File

@@ -3,7 +3,7 @@ import torch

from mindspore import Tensor
from ms_adapter.pytorch.nn import AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d
from ms_adapter.pytorch.nn import AdaptiveMaxPool1d, AdaptiveMaxPool2d, AdaptiveMaxPool3d
from mindspore import context
import mindspore as ms
context.set_context(mode=ms.GRAPH_MODE)
@@ -157,6 +157,182 @@ def test_adaptiveavgpool3d_compare5():
assert np.allclose(ms_output.shape, torch_output.shape)



def test_adaptivemaxpool2d_compare1():
ms_net = AdaptiveMaxPool2d((3, 7))
torch_net = torch.nn.AdaptiveMaxPool2d((3, 7))

data = np.random.random((1, 64, 10, 9))
ms_input = Tensor(data.astype(np.float32))
torch_input = torch.Tensor(data)

ms_output = ms_net(ms_input)
torch_output = torch_net(torch_input)
print(ms_output.shape, torch_output.shape)
assert np.allclose(ms_output.shape, torch_output.shape)

def test_adaptivemaxpool2d_compare2():
ms_net = AdaptiveMaxPool2d(4)
torch_net = torch.nn.AdaptiveMaxPool2d(4)

data = np.random.random((1, 3, 10, 9))
ms_input = Tensor(data.astype(np.float32))
torch_input = torch.Tensor(data)

ms_output = ms_net(ms_input)
torch_output = torch_net(torch_input)
print(ms_output.shape, torch_output.shape)
assert np.allclose(ms_output.shape, torch_output.shape)

def test_adaptivemaxpool2d_compare3():
ms_net = AdaptiveMaxPool2d((4, None))
torch_net = torch.nn.AdaptiveMaxPool2d((4, None))

data = np.random.random((1, 3, 10, 9))
ms_input = Tensor(data.astype(np.float32))
torch_input = torch.Tensor(data)

ms_output = ms_net(ms_input)
torch_output = torch_net(torch_input)
print(ms_output.shape, torch_output.shape)
assert np.allclose(ms_output.shape, torch_output.shape)

def test_adaptivemaxpool2d_compare4():
ms_net = AdaptiveMaxPool2d((None, 4))
torch_net = torch.nn.AdaptiveMaxPool2d((None, 4))

data = np.random.random((1, 3, 10, 9))
ms_input = Tensor(data.astype(np.float32))
torch_input = torch.Tensor(data)

ms_output = ms_net(ms_input)
torch_output = torch_net(torch_input)
print(ms_output.shape, torch_output.shape)
assert np.allclose(ms_output.shape, torch_output.shape)

def test_adaptivemaxpool2d_compare5():
ms_net = AdaptiveMaxPool2d((None, None))
torch_net = torch.nn.AdaptiveMaxPool2d((None, None))

data = np.random.random((1, 3, 10, 9))
ms_input = Tensor(data.astype(np.float32))
torch_input = torch.Tensor(data)

ms_output = ms_net(ms_input)
torch_output = torch_net(torch_input)
print(ms_output.shape, torch_output.shape)
assert np.allclose(ms_output.shape, torch_output.shape)

def test_adaptivemaxpool2d_compare6():
ms_net = AdaptiveMaxPool2d((3, 7), return_indices=True)
torch_net = torch.nn.AdaptiveMaxPool2d((3, 7), return_indices=True)

data = np.random.random((1, 64, 10, 9))
ms_input = Tensor(data.astype(np.float32))
torch_input = torch.Tensor(data)

ms_output = ms_net(ms_input)
torch_output = torch_net(torch_input)
print(ms_output[0].shape, torch_output[0].shape)
print(ms_output[1].shape, torch_output[1].shape)
assert np.allclose(ms_output[0].shape, torch_output[0].shape)
assert np.allclose(ms_output[1].shape, torch_output[1].shape)

def test_adaptivemaxpool1d_compare1():
ms_net = AdaptiveMaxPool1d(3)
torch_net = torch.nn.AdaptiveMaxPool1d(3)

data = np.random.random((1, 10, 9))
ms_input = Tensor(data.astype(np.float32))
torch_input = torch.Tensor(data)

ms_output = ms_net(ms_input)
torch_output = torch_net(torch_input)
print(ms_output.shape, torch_output.shape)
assert np.allclose(ms_output.shape, torch_output.shape)

def test_adaptivemaxpool1d_compare2():
ms_net = AdaptiveMaxPool1d(3, True)
torch_net = torch.nn.AdaptiveMaxPool1d(3, True)

data = np.random.random((1, 10, 9))
ms_input = Tensor(data.astype(np.float32))
torch_input = torch.Tensor(data)

ms_output = ms_net(ms_input)
torch_output = torch_net(torch_input)
print(ms_output[0].shape, torch_output[0].shape)
print(ms_output[1].shape, torch_output[1].shape)
assert np.allclose(ms_output[0].shape, torch_output[0].shape)
assert np.allclose(ms_output[1].shape, torch_output[1].shape)

def test_adaptivemaxpool3d_compare1():
ms_net = AdaptiveMaxPool3d((3, 4, 5))
torch_net = torch.nn.AdaptiveMaxPool3d((3, 4, 5))

data = np.random.random((1, 3, 10, 9, 12))
ms_input = Tensor(data.astype(np.float32))
torch_input = torch.Tensor(data)

ms_output = ms_net(ms_input)
torch_output = torch_net(torch_input)
print(ms_output.shape, torch_output.shape)
assert np.allclose(ms_output.shape, torch_output.shape)

def test_adaptivemaxpool3d_compare2():
ms_net = AdaptiveMaxPool3d(3)
torch_net = torch.nn.AdaptiveMaxPool3d(3)

data = np.random.random((1, 3, 10, 9, 12))
ms_input = Tensor(data.astype(np.float32))
torch_input = torch.Tensor(data)

ms_output = ms_net(ms_input)
torch_output = torch_net(torch_input)
print(ms_output.shape, torch_output.shape)
assert np.allclose(ms_output.shape, torch_output.shape)

def test_adaptivemaxpool3d_compare3():
ms_net = AdaptiveMaxPool3d(3)
torch_net = torch.nn.AdaptiveMaxPool3d(3)

data = np.random.random((1, 3, 10, 9, 12))
ms_input = Tensor(data.astype(np.float32))
torch_input = torch.Tensor(data)

ms_output = ms_net(ms_input)
torch_output = torch_net(torch_input)
print(ms_output.shape, torch_output.shape)
assert np.allclose(ms_output.shape, torch_output.shape)

def test_adaptivemaxpool3d_compare4():
ms_net = AdaptiveMaxPool3d((3, None, 5))
torch_net = torch.nn.AdaptiveMaxPool3d((3, None, 5))

data = np.random.random((1, 3, 10, 9, 12))
ms_input = Tensor(data.astype(np.float32))
torch_input = torch.Tensor(data)

ms_output = ms_net(ms_input)
torch_output = torch_net(torch_input)
print(ms_output.shape, torch_output.shape)
assert np.allclose(ms_output.shape, torch_output.shape)

def test_adaptivemaxpool3d_compare5():
ms_net = AdaptiveMaxPool3d((None, None, 5))
torch_net = torch.nn.AdaptiveMaxPool3d((None, None, 5))

data = np.random.random((1, 3, 10, 9, 12))
ms_input = Tensor(data.astype(np.float32))
torch_input = torch.Tensor(data)

ms_output = ms_net(ms_input)
torch_output = torch_net(torch_input)
print(ms_output.shape, torch_output.shape)
assert np.allclose(ms_output.shape, torch_output.shape)



test_adaptiveavgpool2d_compare1()
test_adaptiveavgpool2d_compare2()
test_adaptiveavgpool2d_compare3()
@@ -167,4 +343,18 @@ test_adaptiveavgpool3d_compare1()
test_adaptiveavgpool3d_compare2()
test_adaptiveavgpool3d_compare3()
test_adaptiveavgpool3d_compare4()
test_adaptiveavgpool3d_compare5()
test_adaptiveavgpool3d_compare5()

test_adaptivemaxpool2d_compare1()
test_adaptivemaxpool2d_compare2()
test_adaptivemaxpool2d_compare3()
test_adaptivemaxpool2d_compare4()
test_adaptivemaxpool2d_compare5()
test_adaptivemaxpool2d_compare6()
test_adaptivemaxpool1d_compare1()
test_adaptivemaxpool1d_compare2()
test_adaptivemaxpool3d_compare1()
test_adaptivemaxpool3d_compare2()
test_adaptivemaxpool3d_compare3()
test_adaptivemaxpool3d_compare4()
test_adaptivemaxpool3d_compare5()

Loading…
Cancel
Save