|
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as f
-
class Space2Depth(nn.Module):
    """Fold ``r x r`` spatial blocks of a 4-D tensor into the channel axis.

    An input of shape ``(b, c, h, w)`` is mapped to
    ``(b, c * r**2, h // r, w // r)``.  Output channel ``(i * r + j) * c + k``
    holds input channel ``k`` sampled at rows ``i::r`` and columns ``j::r``.

    Args:
        r: Block size (down-sampling factor) applied to both spatial axes.
    """

    def __init__(self, r):
        super(Space2Depth, self).__init__()
        self.r = r

    def forward(self, x):
        """Return the space-to-depth rearrangement of ``x``.

        Args:
            x: Tensor of shape ``(b, c, h, w)``; ``h`` and ``w`` must be
                multiples of ``r``.

        Returns:
            Tensor of shape ``(b, c * r**2, h // r, w // r)``.

        Raises:
            RuntimeError: Raised by the reshape when ``h`` or ``w`` is not a
                multiple of ``r``.
        """
        r = self.r
        b, c, h, w = x.size()
        out_c = c * (r ** 2)
        # The spatial size shrinks by the block size r (previously hard-coded
        # to 2, which made every r != 2 fail in the reshape below).
        out_h = h // r
        out_w = w // r
        # Split each spatial axis into (block index, offset inside block).
        blocks = x.reshape(b, c, out_h, r, out_w, r)
        # Move both in-block offsets in front of the channel axis, then merge
        # (row offset, column offset, channel) into a single channel axis.
        return blocks.permute(0, 3, 5, 1, 2, 4).reshape(b, out_c, out_h, out_w)
-
class Depth2Space(nn.Module):
    """Unfold groups of ``r**2`` channels into ``r x r`` spatial blocks.

    An input of shape ``(b, c, h, w)`` is mapped to
    ``(b, c // r**2, h * r, w * r)``.  Input channel ``(i * r + j) * out_c + k``
    (with ``out_c = c // r**2``) is written to output channel ``k`` at rows
    ``i::r`` and columns ``j::r``.

    Args:
        r: Block size (up-sampling factor) applied to both spatial axes.
    """

    def __init__(self, r):
        super(Depth2Space, self).__init__()
        self.r = r

    def forward(self, x):
        """Return the depth-to-space rearrangement of ``x``.

        Args:
            x: Tensor of shape ``(b, c, h, w)``; ``c`` must be a multiple of
                ``r**2``.

        Returns:
            Tensor of shape ``(b, c // r**2, h * r, w * r)``.

        Raises:
            RuntimeError: Raised by the reshape when ``c`` is not a multiple
                of ``r**2``.
        """
        r = self.r
        b, c, h, w = x.size()
        out_c = c // (r ** 2)
        # The spatial size grows by the block size r (previously hard-coded
        # to 2, which made every r != 2 fail in the final reshape).
        out_h = h * r
        out_w = w * r
        # Split the channel axis into (row offset, column offset, channel).
        groups = x.reshape(b, r, r, out_c, h, w)
        # Interleave each offset with its spatial axis, then merge the pairs
        # (h, row offset) and (w, column offset) into the enlarged axes.
        return groups.permute(0, 3, 4, 1, 5, 2).reshape(b, out_c, out_h, out_w)
-
class GroupConv(nn.Module):
    """Channel-grouped convolution built from independent ``nn.Conv2d`` layers.

    The input channels are cut into consecutive slices of
    ``in_dim // num_split`` channels.  Slice ``i`` is fed to its own
    convolution producing ``num_filters // num_split`` channels, and the
    per-slice results are concatenated along the channel axis.

    Args:
        in_dim: Number of input channels.
        num_filters: Total number of output channels.
        num_split: Number of channel groups (one convolution per group).
        k: Kernel size passed to every ``nn.Conv2d``.
        s: Stride passed to every ``nn.Conv2d``.
        p: Padding passed to every ``nn.Conv2d``.
    """

    def __init__(self, in_dim, num_filters, num_split, k, s, p):
        super().__init__()
        self.num_split = num_split
        self.sub_in_dim = in_dim // num_split
        filters_per_group = num_filters // num_split
        branches = []
        for _ in range(num_split):
            branches.append(
                nn.Conv2d(self.sub_in_dim, filters_per_group, k, s, p)
            )
        self.convs = nn.ModuleList(branches)

    def forward(self, inputs):
        """Apply each group's convolution to its channel slice and concatenate."""
        chunks = torch.split(inputs, self.sub_in_dim, dim=1)
        # Index the chunks by convolution position so that a missing chunk
        # still surfaces as an IndexError instead of being skipped silently.
        outputs = [conv(chunks[idx]) for idx, conv in enumerate(self.convs)]
        return torch.cat(outputs, 1)
-
class h_analysisTransform(nn.Module):
    """Sequential stack: 3x3 conv, 2x space-to-depth, two grouped 1x1 convs, 1x1 conv.

    Args:
        in_dim: Number of input channels of the first convolution.
        num_filters: Indexable with at least four entries giving the channel
            counts after each stage.
        strides_list: Accepted for interface compatibility; not used here.
        conv_trainable: Accepted for interface compatibility; not used here.
    """

    def __init__(self, in_dim, num_filters, strides_list, conv_trainable=True):
        super().__init__()
        stages = [
            nn.Conv2d(in_dim, num_filters[0], 3, 1, 1),
            # Space2Depth(2) multiplies the channel count by four, which is
            # why the next layer takes num_filters[0] * 4 input channels.
            Space2Depth(2),
            GroupConv(num_filters[0] * 4, num_filters[1], 16, 1, 1, 0),
            nn.ReLU(),
            GroupConv(num_filters[1], num_filters[2], 4, 1, 1, 0),
            nn.ReLU(),
            nn.Conv2d(num_filters[2], num_filters[3], 1, 1, 0),
        ]
        self.transform = nn.Sequential(*stages)

    def forward(self, inputs):
        """Run ``inputs`` through the layer stack and return the result."""
        return self.transform(inputs)
-
class h_synthesisTransform(nn.Module):
    """Sequential stack: 1x1 transposed conv, two grouped 1x1 convs, 2x depth-to-space, 3x3 transposed conv.

    Args:
        in_dim: Number of input channels of the first transposed convolution.
        num_filters: Indexable with at least four entries giving the channel
            counts after each stage.
        strides_list: Indexable; entry 2 is the stride of the first transposed
            convolution and entry 0 the stride of the last one.
        conv_trainable: Accepted for interface compatibility; not used here.
    """

    def __init__(self, in_dim, num_filters, strides_list, conv_trainable=True):
        super().__init__()
        stages = [
            nn.ConvTranspose2d(in_dim, num_filters[0], 1, strides_list[2], 0),
            GroupConv(num_filters[0], num_filters[1], 16, 1, 1, 0),
            nn.ReLU(),
            GroupConv(num_filters[1], num_filters[2], 4, 1, 1, 0),
            nn.ReLU(),
            # Depth2Space(2) divides the channel count by four, which is why
            # the last layer takes num_filters[2] // 4 input channels.
            Depth2Space(2),
            # Zero-width padding leaves the tensor unchanged; the layer is
            # kept so the indices of the following modules stay the same.
            nn.ZeroPad2d((0, 0, 0, 0)),
            nn.ConvTranspose2d(
                num_filters[2] // 4, num_filters[3], 3, strides_list[0], 1
            ),
        ]
        self.transform = nn.Sequential(*stages)

    def forward(self, inputs):
        """Run ``inputs`` through the layer stack and return the result."""
        return self.transform(inputs)
-
|