|
- # Copyright (c) Open-MMLab. All rights reserved.
- import logging
-
- import torch.nn as nn
- import torch.utils.checkpoint as cp
-
- from ..runner import load_checkpoint
- from .utils import constant_init, kaiming_init
-
-
def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d``.

    The padding is set equal to the dilation, so with ``stride=1`` the
    spatial size of the input is preserved.

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.
        stride (int): Stride of the convolution. Default: 1.
        dilation (int): Dilation of the kernel; the same value is used as
            the padding. Default: 1.

    Returns:
        nn.Conv2d: The constructed convolution layer.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
-
-
class BasicBlock(nn.Module):
    """Residual block built from two 3x3 convolutions.

    The channel expansion factor of this block is 1.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of channels produced by both convolutions.
        stride (int): Stride of the first convolution. Default: 1.
        dilation (int): Dilation of the first convolution; the second
            convolution is always built with the default dilation of 1.
            Default: 1.
        downsample (nn.Module | None): Optional module applied to the block
            input to produce the residual branch. Default: None.
        style (str): Must be `pytorch` or `caffe`. It is only validated and
            does not change how this block is built. Default: 'pytorch'.
        with_cp (bool): Must be False; checkpointing is not supported by
            this block. Default: False.
    """

    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False):
        super().__init__()
        assert style in ('pytorch', 'caffe')

        # Plain (non-module) bookkeeping attributes.
        self.stride = stride
        self.dilation = dilation

        # First conv carries the stride and dilation of the block.
        self.conv1 = conv3x3(inplanes, planes, stride, dilation)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

        assert not with_cp

    def forward(self, x):
        """Apply conv-bn-relu, conv-bn, add the shortcut, then ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Identity shortcut unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
-
-
class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 conv -> 3x3 conv -> 1x1 conv.

    The last 1x1 convolution outputs ``planes * expansion`` channels, with
    an expansion factor of 4.

    If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
    it is "caffe", the stride-two layer is the first 1x1 conv layer.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of channels of the two inner convolutions.
        stride (int): Stride of the block, placed according to ``style``.
            Default: 1.
        dilation (int): Dilation (and padding) of the 3x3 convolution.
            Default: 1.
        downsample (nn.Module | None): Optional module applied to the block
            input to produce the residual branch. Default: None.
        style (str): `pytorch` or `caffe`. Default: 'pytorch'.
        with_cp (bool): Whether to run the residual computation through
            ``torch.utils.checkpoint``. Default: False.
    """

    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False):
        super().__init__()
        assert style in ('pytorch', 'caffe')

        # Exactly one of the first two convolutions carries the stride.
        if style == 'pytorch':
            conv1_stride, conv2_stride = 1, stride
        else:
            conv1_stride, conv2_stride = stride, 1
        expanded_planes = planes * self.expansion

        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=1, stride=conv1_stride, bias=False)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=conv2_stride,
            padding=dilation,
            dilation=dilation,
            bias=False)

        self.bn1 = nn.BatchNorm2d(planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes, expanded_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

        self.stride = stride
        self.dilation = dilation
        self.with_cp = with_cp

    def forward(self, x):
        """Run the block; the final ReLU is applied outside the checkpoint."""

        def _residual_sum(inp):
            # Main branch: three conv-bn stages, ReLU after the first two.
            feat = self.relu(self.bn1(self.conv1(inp)))
            feat = self.relu(self.bn2(self.conv2(feat)))
            feat = self.bn3(self.conv3(feat))

            # Identity shortcut unless a projection module was supplied.
            shortcut = inp if self.downsample is None else self.downsample(inp)

            feat += shortcut
            return feat

        # Checkpointing is only used when the input takes part in autograd.
        use_checkpoint = self.with_cp and x.requires_grad
        if use_checkpoint:
            out = cp.checkpoint(_residual_sum, x)
        else:
            out = _residual_sum(x)

        return self.relu(out)
-
-
def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   style='pytorch',
                   with_cp=False):
    """Stack residual blocks into a single ``nn.Sequential`` stage.

    Only the first block receives ``stride`` and the optional projection
    shortcut; the remaining blocks use stride 1 and take
    ``planes * block.expansion`` input channels.

    Args:
        block (type): Block class to instantiate; must expose an
            ``expansion`` attribute.
        inplanes (int): Number of input channels of the first block.
        planes (int): Base channel count passed to every block.
        blocks (int): Number of blocks in the stage.
        stride (int): Stride of the first block. Default: 1.
        dilation (int): Dilation passed to every block. Default: 1.
        style (str): Style string forwarded to every block.
            Default: 'pytorch'.
        with_cp (bool): Checkpoint flag forwarded to every block.
            Default: False.

    Returns:
        nn.Sequential: The assembled stage.
    """
    out_planes = planes * block.expansion

    # A 1x1 conv + BN projection is needed whenever the first block changes
    # the spatial resolution or the number of channels.
    shortcut = None
    if not (stride == 1 and inplanes == out_planes):
        shortcut = nn.Sequential(
            nn.Conv2d(
                inplanes,
                out_planes,
                kernel_size=1,
                stride=stride,
                bias=False),
            nn.BatchNorm2d(out_planes),
        )

    first_block = block(
        inplanes,
        planes,
        stride,
        dilation,
        shortcut,
        style=style,
        with_cp=with_cp)
    remaining_blocks = [
        block(out_planes, planes, 1, dilation, style=style, with_cp=with_cp)
        for _ in range(1, blocks)
    ]

    return nn.Sequential(first_block, *remaining_blocks)
-
-
class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters. Any value >= 0 freezes the stem
            (conv1 and bn1) plus ``layer1`` .. ``layer{frozen_stages}``.
        bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
            running stats (mean and var).
        bn_frozen (bool): Whether to freeze weight and bias of BN layers.
            Only takes effect when ``bn_eval`` is True.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
    """

    # depth -> (block class, number of blocks in each of the four stages)
    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 frozen_stages=-1,
                 bn_eval=True,
                 bn_frozen=False,
                 with_cp=False):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for resnet')
        assert num_stages >= 1 and num_stages <= 4
        block, stage_blocks = self.arch_settings[depth]
        stage_blocks = stage_blocks[:num_stages]
        assert len(strides) == len(dilations) == num_stages
        assert max(out_indices) < num_stages

        self.out_indices = out_indices
        self.style = style
        self.frozen_stages = frozen_stages
        self.bn_eval = bn_eval
        self.bn_frozen = bn_frozen
        self.with_cp = with_cp

        # Stem: 7x7 stride-2 conv, BN, ReLU and a 3x3 stride-2 max pool.
        self.inplanes = 64
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stages are registered as ``layer1`` .. ``layerN``; only their
        # names are kept in ``res_layers`` and looked up in ``forward``.
        self.res_layers = []
        for i, num_blocks in enumerate(stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            # Base width doubles with every stage: 64, 128, 256, 512.
            planes = 64 * 2**i
            res_layer = make_res_layer(
                block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                with_cp=with_cp)
            self.inplanes = planes * block.expansion
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        # Number of output channels of the last constructed stage.
        self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1)

    def init_weights(self, pretrained=None):
        """Initialize the weights, from a checkpoint or from scratch.

        Args:
            pretrained (str | None): If a str, it is passed to
                ``load_checkpoint`` (with ``strict=False``). If None, every
                ``nn.Conv2d`` is initialized with ``kaiming_init`` and every
                ``nn.BatchNorm2d`` with ``constant_init(m, 1)``.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Run the stem and all stages, collecting the requested outputs.

        Returns:
            Tensor | tuple[Tensor]: The single selected stage output when
            exactly one stage is selected by ``out_indices``, otherwise a
            tuple of the selected stage outputs in stage order.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def train(self, mode=True):
        """Switch train/eval mode while honoring the BN and freeze settings.

        Args:
            mode (bool): True for training mode, False for evaluation mode.

        Returns:
            ResNet: ``self``, as ``nn.Module.train`` does, so that calls such
            as ``model.eval()`` or ``model.train().cuda()`` keep working.
        """
        super(ResNet, self).train(mode)
        if self.bn_eval:
            # Keep BN running statistics fixed regardless of ``mode``.
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    if self.bn_frozen:
                        for params in m.parameters():
                            params.requires_grad = False
        if mode and self.frozen_stages >= 0:
            # Freeze the stem (bn1.parameters() covers its weight and bias).
            for param in self.conv1.parameters():
                param.requires_grad = False
            for param in self.bn1.parameters():
                param.requires_grad = False
            self.bn1.eval()
            # Freeze layer1 .. layer{frozen_stages}.
            for i in range(1, self.frozen_stages + 1):
                mod = getattr(self, f'layer{i}')
                mod.eval()
                for param in mod.parameters():
                    param.requires_grad = False
        # nn.Module.eval() is implemented as ``return self.train(False)``,
        # so the override must return the module itself.
        return self
|