from collections import OrderedDict
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import model_zoo
from torchvision.models.densenet import densenet121
from torchvision.models.squeezenet import squeezenet1_1


def load_weights_sequential(target, source_state):
    # Copy weights positionally: the k-th tensor in the checkpoint is assigned
    # to the k-th parameter/buffer of the target, regardless of name. This
    # tolerates renamed layers but assumes both models enumerate their tensors
    # in the same order. `num_batches_tracked` buffers are skipped on both
    # sides because older checkpoints do not contain them.
    target_keys = [k for k in target.state_dict().keys()
                   if not k.endswith('num_batches_tracked')]
    source_items = [(k, v) for k, v in source_state.items()
                    if not k.endswith('num_batches_tracked')]
    new_dict = OrderedDict(
        (k_t, v_s) for k_t, (_, v_s) in zip(target_keys, source_items))
    target.load_state_dict(new_dict, strict=False)
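
# Usage sketch: this is how the resnet* constructors at the bottom of this file
# transfer ImageNet weights into the dilated variants, whose dilated
# convolutions reuse the same weight shapes:
#
#   model = ResNet(BasicBlock, [2, 2, 2, 2])
#   load_weights_sequential(model, model_zoo.load_url(model_urls['resnet18']))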

'''
Dilated ResNet backbones (18/34/50/101/152) with an auxiliary output for deep
supervision. The stock 32x downsampling is reduced to 8x by replacing the
strides in layer3/layer4 with dilations of 2 and 4.
'''
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    # padding == dilation keeps the spatial size unchanged for a 3x3 kernel
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, dilation=dilation, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               dilation=dilation, padding=dilation, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(self, block, layers=(3, 4, 23, 3)):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        # layer3/layer4 keep stride 1 and dilate instead, so the output stride
        # stays at 8 while the receptive field still grows
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He (fan-out) initialization for the convolutions
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        # As in torchvision's dilated ResNets, only the blocks after the first
        # one use the increased dilation
        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x_3 = self.layer3(x)
        x = self.layer4(x_3)

        # Return the final features together with the layer3 features, which
        # feed the auxiliary (deeply supervised) branch
        return x, x_3
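
# Shape sketch (a non-authoritative check, assuming a 1x3x224x224 input): the
# stem plus layer2 downsample by 8, and layer3/layer4 preserve that resolution,
# so for a Bottleneck ResNet both outputs are 28x28 spatially:
#
#   final, aux = resnet101(pretrained=False)(torch.randn(1, 3, 224, 224))
#   # final: (1, 2048, 28, 28), aux: (1, 1024, 28, 28)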


'''
DenseNet backbone with an auxiliary output for deep supervision. Only the first
transition layer pools, so the total downsampling is reduced from 32x to 8x.
'''


class _DenseLayer(nn.Sequential):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, index):
        super(_DenseLayer, self).__init__()
        # The last dense block (index == 3) dilates its 3x3 convolutions to
        # enlarge the receptive field without any further downsampling
        dilation = 2 if index == 3 else 1
        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * growth_rate,
                                           kernel_size=1, stride=1, bias=False))
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, dilation=dilation,
                                           padding=dilation, bias=False))
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        # Dense connectivity: concatenate the input with the new feature maps
        return torch.cat([x, new_features], 1)


class _DenseBlock(nn.Sequential):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, index):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate,
                                bn_size, drop_rate, index)
            self.add_module('denselayer%d' % (i + 1), layer)


class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features, downsample=True):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        if downsample:
            self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
        else:
            # Identity pooling: kept so the module layout (and hence the state
            # dict) still lines up with torchvision's transition layers
            self.add_module('pool', nn.AvgPool2d(kernel_size=1, stride=1))


class DenseNet(nn.Module):
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, pretrained=False):

        super(DenseNet, self).__init__()

        # First convolution. These defaults match torchvision's densenet121,
        # which is required for the block-by-block pretrained weight copy below
        self.start_features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock
        num_features = num_init_features

        # Only fetch the reference weights when they are actually needed
        init_weights = list(densenet121(pretrained=True).features.children()) if pretrained else None
        start = 0
        if pretrained:
            # Copy the stem (conv0/norm0/relu0/pool0) weights as well
            for i, c in enumerate(self.start_features.children()):
                c.load_state_dict(init_weights[i].state_dict())
                start += 1
        self.blocks = nn.ModuleList()
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate,
                                drop_rate=drop_rate, index=i)
            if pretrained:
                block.load_state_dict(init_weights[start].state_dict())
                start += 1
            self.blocks.append(block)
            setattr(self, 'denseblock%d' % (i + 1), block)

            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                # Only the first transition keeps its 2x2 average pool, capping
                # the total downsampling at 8x
                downsample = i < 1
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=num_features // 2,
                                    downsample=downsample)
                if pretrained:
                    trans.load_state_dict(init_weights[start].state_dict())
                    start += 1
                self.blocks.append(trans)
                setattr(self, 'transition%d' % (i + 1), trans)
                num_features = num_features // 2

    def forward(self, x):
        out = self.start_features(x)
        deep_features = None
        for i, block in enumerate(self.blocks):
            out = block(out)
            if i == 5:
                # blocks[5] is transition3, so this captures the features after
                # denseblock3 for the auxiliary (deeply supervised) branch
                deep_features = out

        return out, deep_features
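
# Shape sketch (a non-authoritative check, assuming the densenet121-compatible
# defaults and a 1x3x224x224 input): the stem downsamples by 4 and only the
# first transition pools, so both outputs sit at 1/8 resolution:
#
#   out, deep = DenseNet(pretrained=False)(torch.randn(1, 3, 224, 224))
#   # out: (1, 1024, 28, 28), deep: (1, 512, 28, 28)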


class Fire(nn.Module):

    def __init__(self, inplanes, squeeze_planes,
                 expand1x1_planes, expand3x3_planes, dilation=1):
        super(Fire, self).__init__()
        self.inplanes = inplanes
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,
                                   kernel_size=1)
        self.expand1x1_activation = nn.ReLU(inplace=True)
        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,
                                   kernel_size=3, padding=dilation, dilation=dilation)
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.squeeze_activation(self.squeeze(x))
        return torch.cat([
            self.expand1x1_activation(self.expand1x1(x)),
            self.expand3x3_activation(self.expand3x3(x))
        ], 1)


class SqueezeNet(nn.Module):

    def __init__(self, pretrained=False):
        super(SqueezeNet, self).__init__()

        self.feat_1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True)
        )
        self.feat_2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            Fire(64, 16, 64, 64),
            Fire(128, 16, 64, 64)
        )
        self.feat_3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            Fire(128, 32, 128, 128, 2),
            Fire(256, 32, 128, 128, 2)
        )
        # No further pooling: the last Fire modules dilate their 3x3 branches
        # instead, keeping the output stride at 8
        self.feat_4 = nn.Sequential(
            Fire(256, 48, 192, 192, 4),
            Fire(384, 48, 192, 192, 4),
            Fire(384, 64, 256, 256, 4),
            Fire(512, 64, 256, 256, 4)
        )
        if pretrained:
            weights = squeezenet1_1(pretrained=True).features.state_dict()
            load_weights_sequential(self, weights)

    def forward(self, x):
        f1 = self.feat_1(x)
        f2 = self.feat_2(f1)
        f3 = self.feat_3(f2)
        f4 = self.feat_4(f3)
        return f4, f3
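
# Shape sketch (a non-authoritative check, assuming a 1x3x224x224 input): the
# stem convolution and the two max pools each halve the resolution, and the
# dilated feat_4 keeps it, so:
#
#   f4, f3 = SqueezeNet(pretrained=False)(torch.randn(1, 3, 224, 224))
#   # f4: (1, 512, 28, 28), f3: (1, 256, 28, 28)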


'''
Convenience constructors
'''


def squeezenet(pretrained=True):
    return SqueezeNet(pretrained)


def densenet(pretrained=True):
    return DenseNet(pretrained=pretrained)


def resnet18(pretrained=True):
    model = ResNet(BasicBlock, [2, 2, 2, 2])
    if pretrained:
        load_weights_sequential(model, model_zoo.load_url(model_urls['resnet18']))
    return model


def resnet34(pretrained=True):
    model = ResNet(BasicBlock, [3, 4, 6, 3])
    if pretrained:
        load_weights_sequential(model, model_zoo.load_url(model_urls['resnet34']))
    return model


def resnet50(pretrained=True):
    model = ResNet(Bottleneck, [3, 4, 6, 3])
    if pretrained:
        load_weights_sequential(model, model_zoo.load_url(model_urls['resnet50']))
    return model


def resnet101(pretrained=True):
    model = ResNet(Bottleneck, [3, 4, 23, 3])
    if pretrained:
        load_weights_sequential(model, model_zoo.load_url(model_urls['resnet101']))
    return model


def resnet152(pretrained=True):
    model = ResNet(Bottleneck, [3, 8, 36, 3])
    if pretrained:
        load_weights_sequential(model, model_zoo.load_url(model_urls['resnet152']))
    return model
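

if __name__ == '__main__':
    # Minimal smoke test (a sketch): every extractor returns a
    # (final_features, auxiliary_features) pair, both at 1/8 input resolution
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        for name, ctor in [('squeezenet', squeezenet), ('resnet18', resnet18)]:
            final, aux = ctor(pretrained=False).eval()(x)
            print(name, tuple(final.shape), tuple(aux.shape))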