|
- import math
- from functools import partial
-
- import mindspore
- from mindspore import nn
- #from torch.nn import functional as F
- import mindspore.ops as ops
- from mindspore.ops import Custom
-
-
class SwishImplementation(Custom):
    """Swish (x * sigmoid(x)) with a hand-written backward pass.

    NOTE(review): this is torch.autograd.Function-style code. MindSpore's
    ops.Custom has a different contract (no `.apply`, no ctx save/restore),
    and MindSpore autodiff derives the swish gradient automatically, so a
    plain Cell (see MemoryEfficientSwish) should be preferred — confirm
    whether this class is still needed.
    """

    @staticmethod
    def forward(ctx, i):
        # swish(i) = i * sigmoid(i). nn.Sigmoid must be instantiated before
        # being applied; the original `nn.Sigmoid(i)` built a cell instead.
        result = i * ops.Sigmoid()(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_variables[0]
        sigmoid_i = ops.Sigmoid()(i)
        # d/di swish(i) = sigmoid(i) * (1 + i * (1 - sigmoid(i)))
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
-
class MemoryEfficientSwish(nn.Cell):
    """Swish activation cell: x * sigmoid(x).

    MindSpore cells implement `construct` (not torch's `forward`), and
    `SwishImplementation.apply` does not exist on MindSpore's ops.Custom,
    so the activation is computed directly; MindSpore autodiff supplies
    the gradient.
    """

    def construct(self, x):
        return x * ops.Sigmoid()(x)
-
-
def drop_connect(inputs, p, training):
    """Drop connect (stochastic depth): randomly zero entire samples.

    Args:
        inputs: tensor of shape (N, C, H, W).
        p: drop probability in [0, 1).
        training: mask is applied only in training mode; otherwise inputs
            are returned unchanged.

    Returns:
        Tensor with ~p of the batch samples zeroed, rescaled by 1/keep_prob
        so the expected activation is unchanged.
    """
    if not training:
        return inputs
    batch_size = inputs.shape[0]
    keep_prob = 1 - p
    # Per-sample Bernoulli(keep_prob) mask: floor(keep_prob + U[0,1)) is 1
    # with probability keep_prob, else 0. UniformReal/Floor are primitive
    # classes and must be instantiated before being called; the original
    # also passed torch-style dtype/device kwargs MindSpore doesn't accept.
    uniform = ops.UniformReal()
    random_tensor = keep_prob + uniform((batch_size, 1, 1, 1))
    binary_tensor = ops.Floor()(random_tensor)
    return inputs / keep_prob * binary_tensor
-
-
def get_same_padding_conv2d(image_size=None):
    """Return a conv constructor with TF-style 'SAME' padding bound to image_size."""
    def conv_ctor(*args, **kwargs):
        # Pre-bind image_size; an explicit image_size in kwargs still wins,
        # matching functools.partial keyword-override semantics.
        return Conv2dStaticSamePadding(*args, **{"image_size": image_size, **kwargs})
    return conv_ctor
-
def get_width_and_height_from_size(x):
    """Obtain (width, height) from an int or a sequence.

    An int yields a square (x, x); a list/tuple is returned unchanged.

    Raises:
        TypeError: if x is neither an int nor a list/tuple.
    """
    if isinstance(x, int):
        return x, x
    if isinstance(x, (list, tuple)):
        return x
    raise TypeError(f"size must be int, list or tuple, got {type(x).__name__}")
-
def calculate_output_image_size(input_image_size, stride):
    """Output spatial size of a Conv2dSamePadding with the given stride.

    Returns None when input_image_size is None, otherwise
    [ceil(h / stride), ceil(w / stride)].
    """
    if input_image_size is None:
        return None
    height, width = get_width_and_height_from_size(input_image_size)
    step = stride if isinstance(stride, int) else stride[0]
    return [int(math.ceil(height / step)), int(math.ceil(width / step))]
-
-
-
class Conv2dStaticSamePadding(nn.Conv2d):
    """2D convolution with TensorFlow-like 'SAME' padding, precomputed for a
    fixed image size.

    TF 'SAME' padding can be asymmetric (extra pixel on the bottom/right),
    so the pad is applied explicitly before the convolution instead of via
    Conv2d's own (symmetric) padding.
    """

    def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
        # Translate the torch-style `bias` kwarg used by callers in this
        # file to MindSpore's `has_bias`.
        if 'bias' in kwargs:
            kwargs['has_bias'] = kwargs.pop('bias')
        # pad_mode='pad' with padding=0 disables Conv2d's internal padding;
        # we pad explicitly (and possibly asymmetrically) below.
        super().__init__(in_channels, out_channels, kernel_size,
                         pad_mode='pad', padding=0, **kwargs)

        # Calculate padding based on image size and save it.
        assert image_size is not None
        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
        # MindSpore Parameters expose .shape, not torch's .size().
        kh, kw = self.weight.shape[-2:]
        # Use [-2]/[-1] so this works whether stride/dilation are stored as
        # 2-tuples or NCHW 4-tuples — TODO confirm against the MS version used.
        sh, sw = self.stride[-2], self.stride[-1]
        dh, dw = self.dilation[-2], self.dilation[-1]
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * sh + (kh - 1) * dh + 1 - ih, 0)
        pad_w = max((ow - 1) * sw + (kw - 1) * dw + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            # nn.Pad replaces torch's nn.ZeroPad2d; paddings are per-dimension
            # (before, after) pairs in NCHW order, extra pixel on bottom/right.
            self.static_padding = nn.Pad(paddings=((0, 0), (0, 0),
                                                   (pad_h // 2, pad_h - pad_h // 2),
                                                   (pad_w // 2, pad_w - pad_w // 2)))
        else:
            self.static_padding = Identity()

    def construct(self, x):
        # Pad first, then run the stock convolution (which pads nothing).
        x = self.static_padding(x)
        return super().construct(x)
-
class Identity(nn.Cell):
    """No-op cell: returns its input unchanged (used when no padding is needed).

    The original subclassed torch's nn.Module and defined forward();
    MindSpore cells subclass nn.Cell and implement construct().
    """

    def __init__(self):
        super(Identity, self).__init__()

    def construct(self, x):
        return x
-
-
# MBConvBlock
class MBConvBlock(nn.Cell):
    """Mobile Inverted Residual Bottleneck (MBConv) block.

    Pipeline: optional 1x1 expansion -> k x k depthwise conv (stride s)
    -> squeeze-and-excitation -> 1x1 projection, with a skip connection
    plus drop-connect when stride == 1 and in/out channel counts match.
    e.g. ksize 3x3, input 32, output 16, stride 1.
    """

    def __init__(self, ksize, input_filters, output_filters, expand_ratio=1, stride=1, image_size=224):
        super().__init__()
        # NOTE(review): MindSpore BatchNorm `momentum` is the keep factor
        # (1 - torch momentum); torch's 0.1 corresponds to 0.9 here — confirm.
        self._bn_mom = 0.9
        self._bn_eps = 0.01
        self._se_ratio = 0.25
        self._input_filters = input_filters
        self._output_filters = output_filters
        self._expand_ratio = expand_ratio
        self._kernel_size = ksize
        self._stride = stride

        inp = self._input_filters
        oup = self._input_filters * self._expand_ratio
        # Expansion phase (1x1 conv), only when the block actually expands.
        if self._expand_ratio != 1:
            Conv2d = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)

        # Depthwise convolution (groups == channels).
        k = self._kernel_size
        s = self._stride
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = Conv2d(
            in_channels=oup, out_channels=oup, groups=oup,
            kernel_size=k, stride=s, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
        image_size = calculate_output_image_size(image_size, s)

        # Squeeze-and-Excitation layer (operates on 1x1 pooled maps).
        Conv2d = get_same_padding_conv2d(image_size=(1, 1))
        num_squeezed_channels = max(1, int(self._input_filters * self._se_ratio))
        self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
        self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)

        # Output phase (1x1 projection back to output_filters).
        final_oup = self._output_filters
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
        self._swish = MemoryEfficientSwish()
        # Global average pool over H, W — replaces torch F.adaptive_avg_pool2d(x, 1).
        self._avg_pool = ops.ReduceMean(keep_dims=True)
        self._sigmoid = nn.Sigmoid()

    def construct(self, inputs, drop_connect_rate=None):
        """Run the block.

        :param inputs: input tensor (N, C, H, W)
        :param drop_connect_rate: drop connect rate (float, between 0 and 1)
        :return: output of block
        """
        # Expansion and depthwise convolution.
        x = inputs
        if self._expand_ratio != 1:
            x = self._swish(self._bn0(self._expand_conv(inputs)))
        x = self._swish(self._bn1(self._depthwise_conv(x)))

        # Squeeze-and-Excitation: pool -> reduce -> swish -> expand -> gate.
        # The original applied `nn.Sigmoid(x)`, which builds a cell instead
        # of computing the sigmoid.
        x_squeezed = self._avg_pool(x, (2, 3))
        x_squeezed = self._se_reduce(x_squeezed)
        x_squeezed = self._swish(x_squeezed)
        x_squeezed = self._se_expand(x_squeezed)
        x = self._sigmoid(x_squeezed) * x

        x = self._bn2(self._project_conv(x))

        # Skip connection with optional drop connect.
        if self._stride == 1 and self._input_filters == self._output_filters:
            if drop_connect_rate:
                x = drop_connect(x, p=drop_connect_rate, training=self.training)
            x = x + inputs  # skip connection
        return x
-
if __name__ == '__main__':
    # StandardNormal is a primitive class: instantiate, then call with a
    # shape tuple. Also avoid shadowing the builtin `input`.
    x = ops.StandardNormal()((1, 3, 112, 112))
    mbconv = MBConvBlock(ksize=3, input_filters=3, output_filters=3, image_size=112)
    out = mbconv(x)
    print(out.shape)
|