|
- import torch
- import torch.nn as nn
- import torch.nn.functional as f
-
- from Model.GDN_transform import GDN
-
-
class ResGDN(nn.Module):
    """Residual block of two convolutions with GDN activations.

    Computes ``ac2(x + conv2(ac1(conv1(x))))``.

    Args:
        in_channel: number of channels of the input tensor.
        out_channel: number of channels produced by each convolution.
        kernel_size: kernel size of both ``nn.Conv2d`` layers.
        stride: stride of both ``nn.Conv2d`` layers.
        padding: zero padding of both ``nn.Conv2d`` layers.
        inv: forwarded to both ``GDN`` layers (presumably selects the inverse
            transform / IGDN -- confirm against ``Model.GDN_transform``).

    Note:
        The skip connection adds the input to the second convolution's
        output, so ``forward`` only works when ``in_channel == out_channel``
        and the convolutions preserve spatial size (e.g. stride 1 with
        ``padding == kernel_size // 2`` for an odd kernel).
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride, padding, inv=False):
        super(ResGDN, self).__init__()
        self.in_ch = int(in_channel)
        self.out_ch = int(out_channel)
        self.k = int(kernel_size)
        self.stride = int(stride)
        self.padding = int(padding)
        self.inv = bool(inv)
        self.conv1 = nn.Conv2d(self.in_ch, self.out_ch,
                               self.k, self.stride, self.padding)
        # conv2 consumes conv1's output, so its input width is out_ch
        # (the original declared in_ch, which only matched when in_ch == out_ch).
        self.conv2 = nn.Conv2d(self.out_ch, self.out_ch,
                               self.k, self.stride, self.padding)
        # Both GDN layers act on out_ch-channel tensors: ac1 on conv1's output,
        # ac2 on the residual sum (which has conv2's out_ch channels).
        self.ac1 = GDN(self.out_ch, self.inv)
        self.ac2 = GDN(self.out_ch, self.inv)

    def forward(self, x):
        x1 = self.ac1(self.conv1(x))
        x2 = self.conv2(x1)
        # Residual add: requires x and x2 to have identical shapes.
        out = self.ac2(x + x2)
        return out
-
-
class ResBlock(nn.Module):
    """Plain residual block: ``x + conv2(relu(conv1(x)))``.

    Args:
        in_channel: number of channels of the input tensor.
        out_channel: number of channels produced by each convolution.
        kernel_size: kernel size of both ``nn.Conv2d`` layers.
        stride: stride of both ``nn.Conv2d`` layers.
        padding: zero padding of both ``nn.Conv2d`` layers.

    Note:
        The skip connection adds the input to the second convolution's
        output, so ``forward`` only works when ``in_channel == out_channel``
        and the convolutions preserve spatial size (e.g. stride 1 with
        ``padding == kernel_size // 2`` for an odd kernel).
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride, padding):
        super(ResBlock, self).__init__()
        self.in_ch = int(in_channel)
        self.out_ch = int(out_channel)
        self.k = int(kernel_size)
        self.stride = int(stride)
        self.padding = int(padding)

        self.conv1 = nn.Conv2d(self.in_ch, self.out_ch,
                               self.k, self.stride, self.padding)
        # conv2 consumes conv1's out_ch-channel output (the original declared
        # in_ch here, which only matched when in_ch == out_ch).
        self.conv2 = nn.Conv2d(self.out_ch, self.out_ch,
                               self.k, self.stride, self.padding)

    def forward(self, x):
        x1 = self.conv2(f.relu(self.conv1(x)))
        # Residual add: requires x and x1 to have identical shapes.
        out = x + x1
        return out
-
- # The non-local block below uses the embedded Gaussian pairwise function (softmax over theta(x)^T phi(x)).
-
-
class Non_local_Block(nn.Module):
    """Non-local block using the embedded Gaussian pairwise function.

    For an input ``x`` of shape ``(b, in_channel, h, w)`` it returns
    ``x + W(softmax(theta(x)^T phi(x)) g(x))``.

    - ``g``, ``theta`` and ``phi`` are 1x1 convolutions to ``out_channel``
      channels.
    - The softmax is taken over the ``h * w`` flattened positions.
    - ``W`` is a 1x1 convolution back to ``in_channel`` channels.

    ``W`` is zero-initialised, so a freshly constructed block returns its
    input unchanged.

    Args:
        in_channel: number of channels of the input (and output) tensor.
        out_channel: number of channels of the intermediate embeddings.
    """

    def __init__(self, in_channel, out_channel):
        super(Non_local_Block, self).__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.g = nn.Conv2d(self.in_channel, self.out_channel, 1, 1, 0)
        self.theta = nn.Conv2d(self.in_channel, self.out_channel, 1, 1, 0)
        self.phi = nn.Conv2d(self.in_channel, self.out_channel, 1, 1, 0)
        self.W = nn.Conv2d(self.out_channel, self.in_channel, 1, 1, 0)
        # Zero-init W so the block starts as an identity mapping.
        # constant_ is the in-place initialiser; the un-suffixed
        # nn.init.constant is a deprecated alias that emits a warning.
        nn.init.constant_(self.W.weight, 0)
        nn.init.constant_(self.W.bias, 0)

    def forward(self, x):
        # x_size: (b c h w)
        batch_size = x.size(0)

        # Flatten spatial positions: (b, out_channel, N) with N = h * w.
        g_x = self.g(x).view(batch_size, self.out_channel, -1)
        g_x = g_x.permute(0, 2, 1)  # (b, N, out_channel)
        theta_x = self.theta(x).view(batch_size, self.out_channel, -1)
        theta_x = theta_x.permute(0, 2, 1)  # (b, N, out_channel)
        phi_x = self.phi(x).view(batch_size, self.out_channel, -1)  # (b, out_channel, N)

        # Pairwise affinities between positions, normalised per query row.
        f1 = torch.matmul(theta_x, phi_x)  # (b, N, N)
        f_div_C = f.softmax(f1, dim=-1)
        y = torch.matmul(f_div_C, g_x)  # (b, N, out_channel)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.out_channel, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z
-
class ScalingNet(nn.Module):
    """Channel-wise scaling of a feature map conditioned on a scalar lambda.

    ``lambda_rd`` is mapped through ``fc1 -> ReLU -> fc2 -> exp`` to one
    positive factor per channel, and ``x`` is multiplied by those factors.
    ``fc2`` is zero-initialised, so a freshly constructed module returns
    ``x`` unchanged (``exp(0) == 1``).

    Args:
        channel: number of channels of the feature maps passed to ``forward``.
    """

    def __init__(self, channel):
        super(ScalingNet, self).__init__()
        self.channel = int(channel)

        self.fc1 = nn.Linear(1, channel // 2, bias=True)
        self.fc2 = nn.Linear(channel // 2, channel, bias=True)
        # constant_ is the in-place initialiser; the un-suffixed
        # nn.init.constant is a deprecated alias that emits a warning.
        nn.init.constant_(self.fc2.weight, 0)
        nn.init.constant_(self.fc2.bias, 0)

    def forward(self, x, lambda_rd):
        """Scale ``x`` channel-wise by factors derived from ``lambda_rd``.

        Args:
            x: 4-D tensor of shape ``(b, channel, h, w)``.
            lambda_rd: float tensor whose last dimension is 1. It can hold
                one value per sample (shape ``(b, 1)``) or one value shared
                by the whole batch (shape ``(1,)`` or ``(1, 1)``).

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if ``x`` is not 4-D.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D (b, c, h, w) input, got shape {tuple(x.shape)}")
        scaling_vector = torch.exp(self.fc2(f.relu(self.fc1(lambda_rd))))
        # (k, channel) -> (k, channel, 1, 1). expand_as broadcasts a shared
        # lambda (k == 1) across the batch. It still rejects a batch or
        # channel count that does not match x.
        scaling_vector = scaling_vector.reshape(-1, self.channel, 1, 1)
        x_scaled = x * scaling_vector.expand_as(x)
        return x_scaled
|