|
- import torch
- import torch.nn.functional as F
- import torch.nn as nn
-
# Legacy (commented-out) residual block definition — superseded by ResBlk below.
- # class ResBlk(nn.Module):
- # def __init__(self, ch_in, ch_out, stride):
- # super(ResBlk, self).__init__()
- # self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
- # self.bn1 = nn.BatchNorm2d(ch_out)
- # self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
- # self.bn2 = nn.BatchNorm2d(ch_out)
- #
- # self.extra = nn.Sequential(
- # nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
- # nn.BatchNorm2d(ch_out)
- # )
- #
- # def forward(self, x):
- # out = F.relu(self.bn1(self.conv1(x)))
- # out = self.bn2(self.conv2(out))
- # out = self.extra(x) + out
- # out = F.relu(out)
- # return out
- #
# # Legacy (commented-out) ResNet18 definition — superseded by ResNet below.
- # class ResNet18(nn.Module):
- # def __init__(self):
- # super(ResNet18, self).__init__()
- # self.conv1 = nn.Sequential(
- # nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0),
- # nn.BatchNorm2d(64)
- # )
- # self.blk1 = ResBlk(64, 64, stride=2)
- # self.blk2 = ResBlk(64, 128, stride=2)
- # self.blk3 = ResBlk(128, 256, stride=2)
- # self.blk4 = ResBlk(256, 512, stride=2)
- #
- # self.outlayer = nn.Linear(512*1*1, 10)
- #
- # def forward(self, x):
- #
- # x = F.relu(self.conv1(x))
- #
- # x = self.blk1(x)
- # x = self.blk2(x)
- # x = self.blk3(x)
- # x = self.blk4(x)
- #
- # x = F.adaptive_avg_pool2d(x, [1, 1])
- # x = x.view(x.size(0), -1)
- # x = self.outlayer(x)
- #
- # return x
-
class ResBlk(nn.Module):
    """Basic residual block: two 3x3 convs with an optional 1x1 projection shortcut.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both 3x3 convolutions.
        stride: stride of the first conv (and of the projection, if any).
        downsample: when True, route the shortcut through a 1x1 conv + BN so
            its shape matches the main path; otherwise use the identity.
    """

    expansion = 1  # basic block: output width == out_channels

    def __init__(self, in_channels, out_channels, stride=1, downsample=False):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU(inplace=True)

        # Projection shortcut only when explicitly requested; otherwise the
        # residual add uses the raw input.
        self.downsample = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if downsample
            else None
        )

    def forward(self, x):
        """Return relu(main_path(x) + shortcut(x))."""
        identity = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out = out + identity
        return self.relu(out)
-
# Configurable ResNet backbone (ResNet-18 layout when built with ResBlk, [2,2,2,2], [64,128,256,512]).
class ResNet(nn.Module):
    """Configurable 4-stage ResNet classifier.

    Args:
        config: tuple ``(block, n_blocks, channels)`` — the residual block
            class (must expose ``expansion``), the number of blocks per stage,
            and the base channel width per stage (both length 4).
        output_dim: number of output classes produced by the final linear layer.
    """

    def __init__(self, config, output_dim):
        super().__init__()

        block, n_blocks, channels = config
        self.in_channels = channels[0]

        assert len(n_blocks) == len(channels) == 4, "expected exactly 4 stages"

        # Stem: 7x7/2 conv + BN + ReLU + 3x3/2 max-pool (4x spatial reduction).
        self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stages 2-4 halve the spatial resolution via a stride-2 first block.
        self.layer1 = self.get_resnet_layer(block, n_blocks[0], channels[0])
        self.layer2 = self.get_resnet_layer(block, n_blocks[1], channels[1], stride=2)
        self.layer3 = self.get_resnet_layer(block, n_blocks[2], channels[2], stride=2)
        self.layer4 = self.get_resnet_layer(block, n_blocks[3], channels[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # self.in_channels has been advanced to the final stage's output width.
        self.fc = nn.Linear(self.in_channels, output_dim)

    def get_resnet_layer(self, block, n_blocks, channels, stride=1):
        """Build one stage of ``n_blocks`` residual blocks; the first uses ``stride``.

        Advances ``self.in_channels`` to the stage's output width and returns
        the stage as an ``nn.Sequential``.
        """
        # BUGFIX: a projection shortcut is needed not only when the channel
        # count changes but also whenever the first block strides — otherwise
        # the identity and the strided main path have mismatched spatial sizes
        # and the residual add fails.
        downsample = stride != 1 or self.in_channels != block.expansion * channels

        layers = [block(self.in_channels, channels, stride, downsample)]
        # Remaining blocks keep stride 1 and matching widths (no projection).
        for _ in range(1, n_blocks):
            layers.append(block(block.expansion * channels, channels))

        self.in_channels = block.expansion * channels

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape (batch, output_dim) for images x of shape (batch, 3, H, W)."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Global average pool -> flatten -> linear head.
        x = self.avgpool(x)
        h = x.view(x.shape[0], -1)
        x = self.fc(h)

        return x
|