|
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as f
- from torch.distributions.uniform import Uniform
-
- from Model.basic_module import Non_local_Block, ResBlock, ScalingNet
- from Model.context_model import P_Model
- from Model.factorized_entropy_model import Entropy_bottleneck
- from Model.gaussian_entropy_model import Distribution_for_entropy
-
- # Multi-hyper Model
- from Model.hyper_module import h_analysisTransform, h_synthesisTransform
-
- from Util.config import dict
# Feature switches, read once at import time from the project config mapping.
# NOTE: `dict` here is the object imported from Util.config; it shadows the
# builtin of the same name throughout this module.
#   USE_VR_MODEL    -> build ScalingNet layers driven by `lambda_rd`
#   USE_MULTI_HYPER -> use the two stacked h_analysisTransform hyper encoders
USE_VR_MODEL, USE_MULTI_HYPER = dict['USE_VR_MODEL'], dict['USE_MULTI_HYPER']
-
class Enc(nn.Module):
    """Analysis transform: image -> latent ``x6`` plus hyper-latent(s).

    Main path: ``conv1`` followed by four stride-2 downsampling stages
    (``trunk1``, ``down1``, ``trunk3``, ``trunk4``) and a non-local attention
    residual stage (``trunk5`` gated by ``sigmoid(mask2)``), producing an
    ``M``-channel latent.

    Hyper path, selected by the module-level ``USE_MULTI_HYPER`` flag:

    * multi-hyper: two stacked ``h_analysisTransform`` modules
      (``hyper_1_enc`` then ``hyper_2_enc``); only ``M`` of 192 or 256 is
      supported.
    * single hyper: two more stride-2 stages (``trunk6``, ``trunk7``), a
      non-local attention residual stage (``trunk8`` gated by
      ``sigmoid(mask3)``) and ``conv2`` projecting to ``N2`` channels.

    When the module-level ``USE_VR_MODEL`` flag is set, ``ScalingNet`` layers
    ``scaling1``..``scaling5`` are created and, if ``forward`` receives a
    ``lambda_rd``, applied after the corresponding main-path stages.

    Args:
        num_features: number of input channels.
        N1: stored as ``self.N1``; not used to build any layer here.
        N2: output channels of the single-hyper hyper-latent.
        M: latent channel count.
        M1: ``conv1`` emits ``M1`` channels; ``trunk1``/``trunk2`` work at
            ``2 * M1`` channels.

    Raises:
        ValueError: if ``USE_MULTI_HYPER`` is enabled and ``M`` is neither
            192 nor 256.
    """

    def __init__(self, num_features, N1, N2, M, M1):
        # Typical values: num_features = 3, N1 = 192, N2 = 128, M = 192, M1 = 96
        super(Enc, self).__init__()
        self.N1 = int(N1)
        self.N2 = int(N2)
        self.M = int(M)
        self.M1 = int(M1)
        self.n_features = int(num_features)
        m, m1, wide = self.M, self.M1, 2 * self.M1

        def res(ch, count):
            # `count` 3x3 residual blocks at constant width `ch`.
            return [ResBlock(ch, ch, 3, 1, 1) for _ in range(count)]

        # Modules are created and registered in a fixed order: it determines
        # the state_dict keys / Sequential indices, the order in which weights
        # are randomly initialized, and the parameter order seen by optimizers.
        self.conv1 = nn.Conv2d(self.n_features, m1, 5, 1, 2)
        self.trunk1 = nn.Sequential(*res(m1, 2), nn.Conv2d(m1, wide, 5, 2, 2))
        if USE_VR_MODEL:
            self.scaling1 = ScalingNet(wide)

        self.down1 = nn.Conv2d(wide, m, 5, 2, 2)
        self.trunk2 = nn.Sequential(*res(wide, 3))
        # NOTE(review): mask1 is built but forward() never applies it (trunk2
        # is a plain residual); kept so existing state_dict keys still load.
        self.mask1 = nn.Sequential(Non_local_Block(wide, m1), *res(wide, 3),
                                   nn.Conv2d(wide, wide, 1, 1, 0))
        if USE_VR_MODEL:
            self.scaling2 = ScalingNet(m)
        self.trunk3 = nn.Sequential(*res(m, 3), nn.Conv2d(m, m, 5, 2, 2))
        if USE_VR_MODEL:
            self.scaling3 = ScalingNet(m)

        self.trunk4 = nn.Sequential(*res(m, 3), nn.Conv2d(m, m, 5, 2, 2))
        if USE_VR_MODEL:
            self.scaling4 = ScalingNet(m)

        self.trunk5 = nn.Sequential(*res(m, 3))
        self.mask2 = nn.Sequential(Non_local_Block(m, m // 2), *res(m, 3),
                                   nn.Conv2d(m, m, 1, 1, 0))
        if USE_VR_MODEL:
            self.scaling5 = ScalingNet(m)

        if USE_MULTI_HYPER:
            if m == 192:  # low and middle bitrate range
                self.hyper_1_enc = h_analysisTransform(m, [384, 1536, 512, 256], [1, 1, 1])
                self.hyper_2_enc = h_analysisTransform(256, [64*4, 64*4*4, 32*4, 128], [1, 1, 1])
            elif m == 256:  # high bitrate range
                self.hyper_1_enc = h_analysisTransform(m, [384*2, 1536*2, 512*2, 256], [1, 1, 1])
                self.hyper_2_enc = h_analysisTransform(64*4, [64*4*2, 64*4*4*2, 32*4*2, 128], [1, 1, 1])
            else:
                raise ValueError(
                    "UNDEFINED MODEL HYPER PARAMETER, M should be 192 or 256, but current M is " + str(M))
        else:
            self.trunk6 = nn.Sequential(*res(m, 2), nn.Conv2d(m, m, 5, 2, 2))
            self.trunk7 = nn.Sequential(*res(m, 2), nn.Conv2d(m, m, 5, 2, 2))
            # NOTE(review): trunk8/mask3/conv2 are only used by the single-hyper
            # path of forward(); confirm that checkpoints trained with
            # USE_MULTI_HYPER do not contain these keys.
            self.trunk8 = nn.Sequential(*res(m, 3))
            self.mask3 = nn.Sequential(Non_local_Block(m, m // 2), *res(m, 3),
                                       nn.Conv2d(m, m, 1, 1, 0))
            self.conv2 = nn.Conv2d(m, self.N2, 3, 1, 1)

    def _scale(self, x, index, lambda_rd):
        """Apply ``scaling<index>`` to ``x`` when variable-rate scaling is active.

        The ScalingNet attributes only exist when ``USE_VR_MODEL`` is set, so
        they are looked up by name only after that check.
        """
        if USE_VR_MODEL and lambda_rd is not None:
            return getattr(self, 'scaling%d' % index)(x, lambda_rd)
        return x

    def forward(self, x, lambda_rd=None):
        """Encode ``x``.

        Args:
            x: input image batch.
            lambda_rd: value passed to the ScalingNet layers (presumably the
                rate-distortion lambda); ignored unless ``USE_VR_MODEL``.

        Returns:
            ``(x6, x7, x8)`` with ``USE_MULTI_HYPER`` (latent, first and
            second hyper-latent), otherwise ``(x6, x10)`` (latent and the
            single hyper-latent).
        """
        x1 = self.conv1(x)
        x2 = self._scale(self.trunk1(x1), 1, lambda_rd)
        x3 = self.trunk2(x2) + x2
        x3 = self._scale(self.down1(x3), 2, lambda_rd)
        x4 = self._scale(self.trunk3(x3), 3, lambda_rd)
        x5 = self._scale(self.trunk4(x4), 4, lambda_rd)
        x6 = self.trunk5(x5) * torch.sigmoid(self.mask2(x5)) + x5
        x6 = self._scale(x6, 5, lambda_rd)

        if USE_MULTI_HYPER:
            x7 = self.hyper_1_enc(x6)  # hyper_1
            x8 = self.hyper_2_enc(x7)  # hyper_2
            return x6, x7, x8

        # Single hyperprior.
        x7 = self.trunk6(x6)
        x8 = self.trunk7(x7)
        x9 = self.trunk8(x8) * torch.sigmoid(self.mask3(x8)) + x8
        x10 = self.conv2(x9)
        return x6, x10
-
-
class Hyper_Dec(nn.Module):
    """Hyper-synthesis transform of the single-hyperprior model.

    Maps the quantized ``N2``-channel hyper-latent to an ``M``-channel map:
    a 3x3 projection, a non-local attention residual stage (``trunk1`` gated
    by ``sigmoid(mask1)``), then ``trunk2`` and ``trunk3``, each ending in a
    stride-2 transposed convolution that doubles the spatial size.
    """

    def __init__(self, N2, M):
        super(Hyper_Dec, self).__init__()
        self.N2 = N2
        self.M = M

        def res_stack(count):
            # `count` 3x3 residual blocks at width M, created in order.
            return [ResBlock(M, M, 3, 1, 1) for _ in range(count)]

        self.conv1 = nn.Conv2d(N2, M, 3, 1, 1)
        self.trunk1 = nn.Sequential(*res_stack(3))
        self.mask1 = nn.Sequential(
            Non_local_Block(M, M // 2),
            *res_stack(3),
            nn.Conv2d(M, M, 1, 1, 0),
        )
        self.trunk2 = nn.Sequential(*res_stack(2), nn.ConvTranspose2d(M, M, 5, 2, 2, 1))
        self.trunk3 = nn.Sequential(*res_stack(2), nn.ConvTranspose2d(M, M, 5, 2, 2, 1))

    def forward(self, xq2):
        """Decode the quantized hyper-latent ``xq2`` into an M-channel feature map."""
        feat = self.conv1(xq2)
        feat = self.trunk1(feat) * torch.sigmoid(self.mask1(feat)) + feat
        return self.trunk3(self.trunk2(feat))
-
-
class Dec(nn.Module):
    """Synthesis transform: quantized latent -> reconstructed image.

    A non-local attention residual stage at latent resolution (``trunk1``
    gated by ``sigmoid(mask1)``), then four stride-2 transposed convolutions
    (``up1``, the tails of ``trunk2``/``trunk3``, the head of ``trunk5``) that
    each double the spatial size, a plain residual stage ``trunk4`` at
    ``2 * M1`` channels, and a final 5x5 convolution to ``input_features``
    channels. With ``USE_VR_MODEL`` set, ``scaling1``..``scaling5`` are
    created and applied when ``forward`` gets a ``lambda_rd``.
    """

    def __init__(self, input_features, N1, M, M1):
        super(Dec, self).__init__()
        self.N1 = N1
        self.M = M
        self.M1 = M1
        self.input = input_features
        wide = 2 * M1  # width of the stages right before the last upsampling

        def res(ch, count):
            # `count` 3x3 residual blocks at constant width `ch`.
            return [ResBlock(ch, ch, 3, 1, 1) for _ in range(count)]

        # Creation/registration order is fixed (state_dict keys, init order).
        if USE_VR_MODEL:
            self.scaling1 = ScalingNet(M)
        self.trunk1 = nn.Sequential(*res(M, 3))
        self.mask1 = nn.Sequential(Non_local_Block(M, M // 2), *res(M, 3),
                                   nn.Conv2d(M, M, 1, 1, 0))

        self.up1 = nn.ConvTranspose2d(M, M, 5, 2, 2, 1)
        if USE_VR_MODEL:
            self.scaling2 = ScalingNet(M)
        self.trunk2 = nn.Sequential(*res(M, 3), nn.ConvTranspose2d(M, M, 5, 2, 2, 1))
        if USE_VR_MODEL:
            self.scaling3 = ScalingNet(M)
        self.trunk3 = nn.Sequential(*res(M, 3), nn.ConvTranspose2d(M, wide, 5, 2, 2, 1))
        if USE_VR_MODEL:
            self.scaling4 = ScalingNet(wide)
        self.trunk4 = nn.Sequential(*res(wide, 3))
        # NOTE(review): mask2 is built but forward() never applies it.
        self.mask2 = nn.Sequential(Non_local_Block(wide, M1), *res(wide, 3),
                                   nn.Conv2d(wide, wide, 1, 1, 0))
        if USE_VR_MODEL:
            self.scaling5 = ScalingNet(wide)
        self.trunk5 = nn.Sequential(nn.ConvTranspose2d(wide, M1, 5, 2, 2, 1), *res(M1, 3))

        self.conv1 = nn.Conv2d(M1, input_features, 5, 1, 2)

    def forward(self, x, lambda_rd=None):
        """Decode latent ``x``; ``lambda_rd`` only matters with ``USE_VR_MODEL``."""
        apply_scaling = USE_VR_MODEL and (lambda_rd is not None)

        if apply_scaling:
            x = self.scaling1(x, lambda_rd)
        feat = self.up1(self.trunk1(x) * torch.sigmoid(self.mask1(x)) + x)
        if apply_scaling:
            feat = self.scaling2(feat, lambda_rd)
        feat = self.trunk2(feat)
        if apply_scaling:
            feat = self.scaling3(feat, lambda_rd)
        feat = self.trunk3(feat)
        if apply_scaling:
            feat = self.scaling4(feat, lambda_rd)
        feat = self.trunk4(feat) + feat
        if apply_scaling:
            feat = self.scaling5(feat, lambda_rd)
        return self.conv1(self.trunk5(feat))
-
-
class Image_coding(nn.Module):
    """Single-hyperprior learned image codec.

    ``Enc`` produces the latent ``x1`` and hyper-latent ``x2``. ``x2`` goes
    through the factorized entropy model. ``Hyper_Dec`` + ``P_Model`` turn
    the quantized ``x2`` into parameters for the Gaussian entropy model of the
    quantized ``x1``. ``Dec`` reconstructs the image from the quantized ``x1``.
    """

    def __init__(self, input_features, N1, N2, M, M1):
        # Typical values: input_features = 3, N1 = 192, N2 = 128, M = 192, M1 = 96
        super(Image_coding, self).__init__()
        self.N1 = N1
        self.encoder = Enc(input_features, N1, N2, M, M1)
        self.factorized_entropy_func = Entropy_bottleneck(N2)
        self.hyper_dec = Hyper_Dec(N2, M)
        self.p = P_Model(M)
        self.gaussin_entropy_func = Distribution_for_entropy()
        self.decoder = Dec(input_features, N1, M, M1)

    def add_noise(self, x):
        """Return ``x`` plus i.i.d. uniform noise in [-0.5, 0.5).

        The noise is drawn with numpy's global RNG (as before), but is placed
        on ``x``'s device with ``x``'s dtype. The former hard-coded ``.cuda()``
        failed for CPU tensors and tensors on non-default GPUs.
        """
        noise = np.random.uniform(-0.5, 0.5, x.size())
        return x + torch.as_tensor(noise, dtype=x.dtype, device=x.device)

    def _quantize(self, y, if_training):
        """Quantize ``y``.

        ``if_training == 0`` adds uniform noise, ``1`` uses ``UniverseQuant``,
        and any other value rounds to the nearest integer.
        """
        if if_training == 0:
            return self.add_noise(y)
        if if_training == 1:
            return UniverseQuant.apply(y)
        return torch.round(y)

    def forward(self, x, if_training, lambda_rd=None):
        """Run the full codec.

        Args:
            x: input image batch.
            if_training: quantization mode (see ``_quantize``); also passed to
                the factorized entropy model.
            lambda_rd: forwarded to encoder/decoder ScalingNet layers; only
                used with ``USE_VR_MODEL``.

        Returns:
            ``[output, xp1, xp2, xq1, hyper_dec]``: reconstruction, Gaussian
            entropy-model output for the latent, factorized entropy-model
            output for the hyper-latent, quantized latent, and the ``P_Model``
            prediction used as Gaussian parameters.
        """
        x1, x2 = self.encoder(x, lambda_rd)
        xq2, xp2 = self.factorized_entropy_func(x2, if_training)
        hyper_dec = self.p(self.hyper_dec(xq2))
        xq1 = self._quantize(x1, if_training)
        xp1 = self.gaussin_entropy_func(xq1, hyper_dec)
        output = self.decoder(xq1, lambda_rd)
        return [output, xp1, xp2, xq1, hyper_dec]
-
-
-
class Image_coding_multi_hyper(nn.Module):
    """Learned image codec with two stacked hyperpriors.

    ``Enc`` (built with ``USE_MULTI_HYPER``) returns the latent ``x1`` and
    hyper-latents ``x2`` and ``x3``.

    * ``stage == 1``: ``x2`` is coded directly with the factorized model
      ``factorized_entropy_func_for_hyper`` and ``x3`` is not used.
    * otherwise: ``x3`` is coded with ``factorized_entropy_func``, decoded by
      ``hyper_2_dec`` + ``p_2`` into Gaussian parameters for the quantized
      ``x2``, which is then coded with ``gaussin_entropy_func_for_hyper``.

    In both cases the quantized ``x2`` is decoded by ``hyper_1_dec`` + ``p``
    into Gaussian parameters for the quantized latent ``x1``.

    Raises:
        ValueError: if ``M`` is neither 192 nor 256 (previously the hyper
            decoders were silently left undefined).
    """

    def __init__(self, input_features, N1, N2, M, M1):
        # Typical values: input_features = 3, N1 = 192, N2 = 128, M = 192, M1 = 96
        super(Image_coding_multi_hyper, self).__init__()
        self.N1 = N1
        self.encoder = Enc(input_features, N1, N2, M, M1)
        self.factorized_entropy_func = Entropy_bottleneck(128)

        if M == 192:
            self.hyper_1_dec = h_synthesisTransform(256, [768, 768, 768, M], [1, 1, 1])
            self.hyper_2_dec = h_synthesisTransform(128, [64*4, 64*4, 64*4, 64*4], [1, 1, 1])
        elif M == 256:
            self.hyper_1_dec = h_synthesisTransform(256, [768*2, 768*2, 768*2, M], [1, 1, 1])
            self.hyper_2_dec = h_synthesisTransform(128, [64*4*2, 64*4*2, 64*4*2, 64*4], [1, 1, 1])
        else:
            raise ValueError(
                "UNDEFINED MODEL HYPER PARAMETER, M should be 192 or 256, but current M is " + str(M))

        self.p = P_Model(M)
        self.p_2 = P_Model(256)
        self.factorized_entropy_func_for_hyper = Entropy_bottleneck(256)  # STAGE 1 ONLY
        self.gaussin_entropy_func_for_hyper = Distribution_for_entropy()
        self.gaussin_entropy_func = Distribution_for_entropy()
        self.decoder = Dec(input_features, N1, M, M1)

    def add_noise(self, x):
        """Return ``x`` plus i.i.d. uniform noise in [-0.5, 0.5).

        Noise comes from numpy's global RNG and is placed on ``x``'s device
        with ``x``'s dtype (no hard-coded ``.cuda()``).
        """
        noise = np.random.uniform(-0.5, 0.5, x.size())
        return x + torch.as_tensor(noise, dtype=x.dtype, device=x.device)

    def _quantize(self, y, if_training):
        """Quantize ``y``: 0 -> uniform noise, 1 -> ``UniverseQuant``, else round."""
        if if_training == 0:
            return self.add_noise(y)
        if if_training == 1:
            return UniverseQuant.apply(y)
        return torch.round(y)

    def forward(self, x, if_training, stage=-1, lambda_rd=None):
        """Run the full codec.

        Args:
            x: input image batch.
            if_training: quantization mode (see ``_quantize``); also passed to
                the factorized entropy models.
            stage: ``1`` selects the single-level (factorized ``x2``) path.
            lambda_rd: optional value forwarded to the encoder/decoder
                ScalingNet layers (only used with ``USE_VR_MODEL``); the
                default ``None`` keeps the previous behavior.

        Returns:
            ``(output, xp1, xp2, xq1, hyper_dec)`` when ``stage == 1``,
            otherwise ``(output, xp1, xp2, xq1, hyper_dec, xp3, xq3)``.
        """
        x1, x2, x3 = self.encoder(x, lambda_rd)

        if stage == 1:
            xq2, xp2 = self.factorized_entropy_func_for_hyper(x2, if_training)
        else:
            xq3, xp3 = self.factorized_entropy_func(x3, if_training)
            hyper_2_dec = self.p_2(self.hyper_2_dec(xq3))
            xq2 = self._quantize(x2, if_training)
            xp2 = self.gaussin_entropy_func_for_hyper(xq2, hyper_2_dec)

        hyper_dec = self.p(self.hyper_1_dec(xq2))
        xq1 = self._quantize(x1, if_training)
        xp1 = self.gaussin_entropy_func(xq1, hyper_dec)
        output = self.decoder(xq1, lambda_rd)

        if stage == 1:
            return output, xp1, xp2, xq1, hyper_dec
        return output, xp1, xp2, xq1, hyper_dec, xp3, xq3
-
-
class UniverseQuant(torch.autograd.Function):
    """Dithered ("universal") quantization with a straight-through gradient.

    Forward: draw an exponent ``b ~ U(-1, 1)`` (numpy RNG), then elementwise
    dither ``u ~ U(-0.5 * 2**b, 0.5 * 2**b)`` and return
    ``round(x + u) - u``. Backward: pass the incoming gradient through
    unchanged.
    """

    @staticmethod
    def forward(ctx, x):
        b = np.random.uniform(-1, 1)
        half_width = 0.5 * (2 ** b)
        # Sample the dither directly on x's device and dtype. The former
        # CPU-side Uniform(...).sample().cuda() failed for CPU / non-default-GPU
        # inputs and always produced float32.
        dither = torch.empty_like(x).uniform_(-half_width, half_width)
        return torch.round(x + dither) - dither

    @staticmethod
    def backward(ctx, g):
        # Straight-through estimator: treat quantization as identity.
        return g
|