|
|
@@ -0,0 +1,69 @@ |
|
|
|
#!/usr/bin/env python |
|
|
|
# coding: utf-8 |
|
|
|
# |
|
|
|
# Author: Kazuto Nakashima |
|
|
|
# URL: http://kazuto1011.github.io |
|
|
|
# Created: 2017-11-19 |
|
|
|
|
|
|
|
from __future__ import absolute_import, print_function |
|
|
|
|
|
|
|
import torch |
|
|
|
import torch.nn as nn |
|
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
from .resnet import _ConvBnReLU, _ResLayer, _Stem |
|
|
|
|
|
|
|
|
|
|
|
class _ASPP(nn.Module):
    """
    Atrous spatial pyramid pooling (ASPP)

    One 3x3 dilated convolution is created per entry of ``rates``. All
    branches receive the same input and their outputs are summed
    element-wise.

    Args:
        in_ch: number of input channels of every branch.
        out_ch: number of output channels of every branch.
        rates: iterable of dilation rates; branch ``c<i>`` uses ``rates[i]``
            as both its dilation and its padding. Must not be empty.

    Raises:
        ValueError: if ``rates`` contains no entries.
    """

    def __init__(self, in_ch, out_ch, rates):
        super(_ASPP, self).__init__()
        # Materialise once so that one-shot iterators are supported and the
        # emptiness check below does not consume the input.
        rates = list(rates)
        if not rates:
            # Without any branch, forward() would return the plain int 0
            # (the start value of sum) instead of a tensor.
            raise ValueError("rates must contain at least one dilation rate")

        for i, rate in enumerate(rates):
            self.add_module(
                "c{}".format(i),
                # padding == dilation with a 3x3 kernel and stride 1 keeps the
                # spatial size, so the branch outputs can be summed directly.
                nn.Conv2d(in_ch, out_ch, 3, 1, padding=rate, dilation=rate, bias=True),
            )

        # Every child registered above is a Conv2d with a bias term.
        for m in self.children():
            nn.init.normal_(m.weight, mean=0, std=0.01)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return the element-wise sum of all branch outputs for ``x``."""
        return sum(stage(x) for stage in self.children())
|
|
|
|
|
|
|
|
|
|
|
class DeepLabV2(nn.Sequential):
    """
    DeepLab v2: Dilated ResNet + ASPP
    Output stride is fixed at 8

    The network is a plain sequence of six named stages:
    ``layer1`` (stem), ``layer2`` .. ``layer5`` (residual stages) and
    ``aspp`` (classifier head).

    Args:
        n_classes: number of output channels of the ASPP head.
        n_blocks: indexable with at least four entries; entry ``k`` is passed
            as the first argument of the ``_ResLayer`` for ``layer<k + 2>``.
        atrous_rates: dilation rates forwarded to ``_ASPP``.
    """

    def __init__(self, n_classes, n_blocks, atrous_rates):
        super(DeepLabV2, self).__init__()

        # Channel widths: 64, 128, 256, 512, 1024, 2048.
        widths = [64 << power for power in range(6)]

        self.add_module("layer1", _Stem(widths[0]))

        # (name, in_ch, out_ch, 4th arg, 5th arg) for each residual stage.
        # NOTE(review): the last two values are presumably stride and
        # dilation -- confirm against resnet._ResLayer.
        stage_specs = (
            ("layer2", widths[0], widths[2], 1, 1),
            ("layer3", widths[2], widths[3], 2, 1),
            ("layer4", widths[3], widths[4], 1, 2),
            ("layer5", widths[4], widths[5], 1, 4),
        )
        for index, (name, in_ch, out_ch, fourth, fifth) in enumerate(stage_specs):
            # n_blocks is indexed lazily, one stage at a time.
            stage = _ResLayer(n_blocks[index], in_ch, out_ch, fourth, fifth)
            self.add_module(name, stage)

        self.add_module("aspp", _ASPP(widths[5], n_classes, atrous_rates))

    def freeze_bn(self):
        """Switch every ``_ConvBnReLU.BATCH_NORM`` submodule to eval mode.

        Only the mode flag is touched; a later ``train()`` call on the model
        puts these modules back into training mode.
        """
        bn_type = _ConvBnReLU.BATCH_NORM
        for module in self.modules():
            if not isinstance(module, bn_type):
                continue
            module.eval()
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the network, run one random image through it and
    # print the architecture together with the input/output shapes.
    settings = dict(
        n_classes=21,
        n_blocks=[3, 4, 23, 3],
        atrous_rates=[6, 12, 18, 24],
    )
    model = DeepLabV2(**settings)
    model.eval()

    image = torch.randn(1, 3, 513, 513)

    print(model)
    print("input:", image.shape)
    # The forward pass runs only after the two prints above.
    output = model(image)
    print("output:", output.shape)