|
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.init as init
- import torch.nn.functional as f
- from torch.distributions.uniform import Uniform
- from Model.context_model import Weighted_Gaussian
- from Model.basic_module import Non_local_Block, ResBlock
- from Model.context_model import P_Model
- from Model.factorized_entropy_model import Entropy_bottleneck
- from Model.gaussian_entropy_model import Distribution_for_entropy
- from Model.sign_conv2d import SignConv2d
- from Model.gdn import GDN2d
-
-
class lmd_map(nn.Module):
    """Map a scalar rate-control parameter (lambda) to a per-channel gating
    vector in (0, 1).

    ``fcn2`` is initialized with zero weights and unit bias so that, at the
    start of training, the mask is sigmoid(1) for every channel regardless
    of the input lambda (i.e. an almost-open gate).
    """

    def __init__(self, latent_channels):
        super(lmd_map, self).__init__()

        self.lc = int(latent_channels)
        self.fcn1 = nn.Linear(1, self.lc)
        self.fcn2 = nn.Linear(self.lc, self.lc)
        # nn.init.constant was removed from PyTorch; the in-place variant
        # constant_ is the supported API.
        nn.init.constant_(self.fcn2.weight, 0)
        nn.init.constant_(self.fcn2.bias, 1)

    def forward(self, y):
        """y: (b, 1) lambda values -> (b, lc) mask in (0, 1)."""
        y1 = self.fcn1(y)
        y2 = self.fcn2(y1)
        # f.sigmoid is deprecated; torch.sigmoid is the supported form.
        mask = torch.sigmoid(y2)

        return mask
-
class SQL(nn.Module):
    """Scalable-quality layer: eight 1x1 conv stages where the input of each
    stage after the first is gated channel-wise by a lambda-conditioned mask.

    Attribute names (conv1..conv8, lmbd_map1..lmbd_map7) are kept exactly as
    before so existing checkpoints still load.
    """

    def __init__(self, in_channels, latent_channels):
        super(SQL, self).__init__()

        self.m = int(in_channels)
        self.c = int(latent_channels)

        self.conv1 = nn.Conv2d(self.m, self.c, 1, 1, 0)
        self.conv2 = nn.Conv2d(self.c, self.c, 1, 1, 0)
        self.conv3 = nn.Conv2d(self.c, self.c, 1, 1, 0)
        self.conv4 = nn.Conv2d(self.c, self.c, 1, 1, 0)
        self.conv5 = nn.Conv2d(self.c, self.c, 1, 1, 0)
        self.conv6 = nn.Conv2d(self.c, self.c, 1, 1, 0)
        self.conv7 = nn.Conv2d(self.c, self.c, 1, 1, 0)
        self.conv8 = nn.Conv2d(self.c, self.m, 1, 1, 0)

        self.lmbd_map1 = lmd_map(self.c)
        self.lmbd_map2 = lmd_map(self.c)
        self.lmbd_map3 = lmd_map(self.c)
        self.lmbd_map4 = lmd_map(self.c)
        self.lmbd_map5 = lmd_map(self.c)
        self.lmbd_map6 = lmd_map(self.c)
        self.lmbd_map7 = lmd_map(self.c)

    def forward(self, x, lmd):
        """x: (b, m, h, w); lmd: scalar or (b, 1) tensor of lambda values.

        Returns a (b, m, h, w) tensor.
        """
        b = x.size(0)
        # Per-sample lambda column on the same device as the input.  The
        # previous version hard-coded .cuda() (twice per mask), which broke
        # CPU execution; the masks inherit the device from y automatically.
        y = torch.ones((b, 1), device=x.device) * lmd

        gates = (self.lmbd_map1, self.lmbd_map2, self.lmbd_map3,
                 self.lmbd_map4, self.lmbd_map5, self.lmbd_map6,
                 self.lmbd_map7)
        convs = (self.conv2, self.conv3, self.conv4, self.conv5,
                 self.conv6, self.conv7, self.conv8)

        h = self.conv1(x)
        for gate, conv in zip(gates, convs):
            # Broadcast the (b, c) mask over the spatial dimensions.
            mask = torch.reshape(gate(y), (b, self.c, 1, 1))
            h = conv(mask * h)

        return h
-
class SQL_resi(nn.Module):
    """Residual scalable-quality layer: same stage stack as ``SQL`` but the
    stack output is squashed with a sigmoid and used to rescale the input
    (``x * sigmoid(stack(x))``) instead of being returned directly.

    Attribute names (conv1..conv8, lmbd_map1..lmbd_map7) are kept exactly as
    before so existing checkpoints still load.
    """

    def __init__(self, in_channels, latent_channels):
        super(SQL_resi, self).__init__()

        self.m = int(in_channels)
        self.c = int(latent_channels)

        self.conv1 = nn.Conv2d(self.m, self.c, 1, 1, 0)
        self.conv2 = nn.Conv2d(self.c, self.c, 1, 1, 0)
        self.conv3 = nn.Conv2d(self.c, self.c, 1, 1, 0)
        self.conv4 = nn.Conv2d(self.c, self.c, 1, 1, 0)
        self.conv5 = nn.Conv2d(self.c, self.c, 1, 1, 0)
        self.conv6 = nn.Conv2d(self.c, self.c, 1, 1, 0)
        self.conv7 = nn.Conv2d(self.c, self.c, 1, 1, 0)
        self.conv8 = nn.Conv2d(self.c, self.m, 1, 1, 0)

        self.lmbd_map1 = lmd_map(self.c)
        self.lmbd_map2 = lmd_map(self.c)
        self.lmbd_map3 = lmd_map(self.c)
        self.lmbd_map4 = lmd_map(self.c)
        self.lmbd_map5 = lmd_map(self.c)
        self.lmbd_map6 = lmd_map(self.c)
        self.lmbd_map7 = lmd_map(self.c)

    def forward(self, x, lmd):
        """x: (b, m, h, w); lmd: scalar or (b, 1) tensor of lambda values.

        Returns ``x * sigmoid(stack(x))`` of shape (b, m, h, w).
        """
        b = x.size(0)
        # Per-sample lambda column on the same device as the input (the
        # previous version hard-coded .cuda(), breaking CPU execution).
        y = torch.ones((b, 1), device=x.device) * lmd

        gates = (self.lmbd_map1, self.lmbd_map2, self.lmbd_map3,
                 self.lmbd_map4, self.lmbd_map5, self.lmbd_map6,
                 self.lmbd_map7)
        convs = (self.conv2, self.conv3, self.conv4, self.conv5,
                 self.conv6, self.conv7, self.conv8)

        h = self.conv1(x)
        for gate, conv in zip(gates, convs):
            # Broadcast the (b, c) mask over the spatial dimensions.
            mask = torch.reshape(gate(y), (b, self.c, 1, 1))
            h = conv(mask * h)

        # f.sigmoid is deprecated; torch.sigmoid is the supported form.
        return x * torch.sigmoid(h)
-
class Enc(nn.Module):
    """Analysis transform (encoder): downsamples the input 16x into the main
    latent and a further 4x into the hyper-latent.

    Returns (x6, x10): the intermediate feature map after the second
    attention block (4x downsampled, M channels) and the hyper-encoder
    output (N2 channels).
    """

    def __init__(self, num_features, N1, N2, M, M1):

        super(Enc, self).__init__()
        self.N1 = int(N1)
        self.N2 = int(N2)
        self.M = int(M)
        self.M1 = int(M1)
        self.n_features = int(num_features)

        self.conv1 = nn.Conv2d(self.n_features, self.M1, 5, 1, 2)
        self.trunk1 = nn.Sequential(ResBlock(self.M1, self.M1, 3, 1, 1), ResBlock(
            self.M1, self.M1, 3, 1, 1), nn.Conv2d(self.M1, 2*self.M1, 5, 2, 2))

        self.down1 = nn.Conv2d(2*self.M1, self.M, 5, 2, 2)
        self.trunk2 = nn.Sequential(ResBlock(2*self.M1, 2*self.M1, 3, 1, 1), ResBlock(2*self.M1, 2*self.M1, 3, 1, 1),
                                    ResBlock(2*self.M1, 2*self.M1, 3, 1, 1))
        # NOTE(review): mask1 is registered but never used in forward();
        # kept so existing checkpoints still load -- confirm intent.
        self.mask1 = nn.Sequential(Non_local_Block(2*self.M1, self.M1), ResBlock(2*self.M1, 2*self.M1, 3, 1, 1),
                                   ResBlock(
                                       2*self.M1, 2*self.M1, 3, 1, 1), ResBlock(2*self.M1, 2*self.M1, 3, 1, 1),
                                   nn.Conv2d(2*self.M1, 2*self.M1, 1, 1, 0))
        self.trunk3 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    ResBlock(self.M, self.M, 3, 1, 1), nn.Conv2d(self.M, self.M, 5, 2, 2))

        self.trunk4 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    ResBlock(self.M, self.M, 3, 1, 1), nn.Conv2d(self.M, self.M, 5, 2, 2))

        self.trunk5 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    ResBlock(self.M, self.M, 3, 1, 1))
        self.mask2 = nn.Sequential(Non_local_Block(self.M, self.M // 2), ResBlock(self.M, self.M, 3, 1, 1),
                                   ResBlock(self.M, self.M, 3, 1, 1),
                                   ResBlock(self.M, self.M, 3, 1, 1), nn.Conv2d(self.M, self.M, 1, 1, 0))

        self.trunk6 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    nn.Conv2d(self.M, self.M, 5, 2, 2))
        self.trunk7 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    nn.Conv2d(self.M, self.M, 5, 2, 2))

        self.trunk8 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    ResBlock(self.M, self.M, 3, 1, 1))
        self.mask3 = nn.Sequential(Non_local_Block(self.M, self.M // 2), ResBlock(self.M, self.M, 3, 1, 1),
                                   ResBlock(self.M, self.M, 3, 1, 1),
                                   ResBlock(self.M, self.M, 3, 1, 1), nn.Conv2d(self.M, self.M, 1, 1, 0))
        self.conv2 = nn.Conv2d(self.M, self.N2, 3, 1, 1)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.trunk1(x1)
        x3 = self.trunk2(x2) + x2
        x3 = self.down1(x3)
        x4 = self.trunk3(x3)
        x5 = self.trunk4(x4)
        # Attention-style residual gate (f.sigmoid is deprecated; use
        # torch.sigmoid).
        x6 = self.trunk5(x5) * torch.sigmoid(self.mask2(x5)) + x5

        x7 = self.trunk6(x6)
        x8 = self.trunk7(x7)
        x9 = self.trunk8(x8) * torch.sigmoid(self.mask3(x8)) + x8
        x10 = self.conv2(x9)

        return x6, x10
-
class Enc_lfc(nn.Module):
    """Encoder variant with a local fully-connected (LFC) decorrelation step
    applied to the 4x-downsampled feature map before the hyper branch.

    ``local_fc_enc`` is a grouped 4x4/stride-4 conv (one group per channel)
    whose weights are initialized so that, together with ``nn.Fold``, the
    unfold/fold round trip starts as an identity rearrangement of each 4x4
    patch; the weights are learnable afterwards.

    Returns (x6, x10) like ``Enc``.
    """

    def __init__(self, num_features, N1, N2, M, M1):

        super(Enc_lfc, self).__init__()
        self.N1 = int(N1)
        self.N2 = int(N2)
        self.M = int(M)
        self.M1 = int(M1)
        self.n_features = int(num_features)

        # NOTE(review): the 256-channel assumption here must match M --
        # confirm against the caller's configuration.
        self.local_fc_enc = nn.Conv2d(256, 256 * 16, 4, stride=4, groups=256, bias=False)
        # One-hot identity initialization: output channel c picks pixel
        # (c % 16) of its 4x4 input patch.
        init_lfc = np.zeros((256 * 16, 16))
        for c in range(256 * 16):
            i = c % 16
            init_lfc[c][i] = 1
        init_lfc = torch.from_numpy(init_lfc)
        init_lfc = torch.reshape(init_lfc, (256 * 16, 1, 4, 4)).float()
        self.local_fc_enc.weight = nn.Parameter(init_lfc)

        self.conv1 = nn.Conv2d(self.n_features, self.M1, 5, 1, 2)
        self.trunk1 = nn.Sequential(ResBlock(self.M1, self.M1, 3, 1, 1), ResBlock(
            self.M1, self.M1, 3, 1, 1), nn.Conv2d(self.M1, 2*self.M1, 5, 2, 2))

        self.down1 = nn.Conv2d(2*self.M1, self.M, 5, 2, 2)
        self.trunk2 = nn.Sequential(ResBlock(2*self.M1, 2*self.M1, 3, 1, 1), ResBlock(2*self.M1, 2*self.M1, 3, 1, 1),
                                    ResBlock(2*self.M1, 2*self.M1, 3, 1, 1))
        # NOTE(review): mask1 is registered but never used in forward();
        # kept so existing checkpoints still load -- confirm intent.
        self.mask1 = nn.Sequential(Non_local_Block(2*self.M1, self.M1), ResBlock(2*self.M1, 2*self.M1, 3, 1, 1),
                                   ResBlock(
                                       2*self.M1, 2*self.M1, 3, 1, 1), ResBlock(2*self.M1, 2*self.M1, 3, 1, 1),
                                   nn.Conv2d(2*self.M1, 2*self.M1, 1, 1, 0))
        self.trunk3 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    ResBlock(self.M, self.M, 3, 1, 1), nn.Conv2d(self.M, self.M, 5, 2, 2))

        self.trunk4 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    ResBlock(self.M, self.M, 3, 1, 1), nn.Conv2d(self.M, self.M, 5, 2, 2))

        self.trunk5 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    ResBlock(self.M, self.M, 3, 1, 1))
        self.mask2 = nn.Sequential(Non_local_Block(self.M, self.M // 2), ResBlock(self.M, self.M, 3, 1, 1),
                                   ResBlock(self.M, self.M, 3, 1, 1),
                                   ResBlock(self.M, self.M, 3, 1, 1), nn.Conv2d(self.M, self.M, 1, 1, 0))

        self.trunk6 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    nn.Conv2d(self.M, self.M, 5, 2, 2))
        self.trunk7 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    nn.Conv2d(self.M, self.M, 5, 2, 2))

        self.trunk8 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    ResBlock(self.M, self.M, 3, 1, 1))
        self.mask3 = nn.Sequential(Non_local_Block(self.M, self.M // 2), ResBlock(self.M, self.M, 3, 1, 1),
                                   ResBlock(self.M, self.M, 3, 1, 1),
                                   ResBlock(self.M, self.M, 3, 1, 1), nn.Conv2d(self.M, self.M, 1, 1, 0))
        self.conv2 = nn.Conv2d(self.M, self.N2, 3, 1, 1)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.trunk1(x1)
        x3 = self.trunk2(x2) + x2
        x3 = self.down1(x3)
        x4 = self.trunk3(x3)
        x5 = self.trunk4(x4)
        # Attention-style residual gate (f.sigmoid is deprecated; use
        # torch.sigmoid).
        x6 = self.trunk5(x5) * torch.sigmoid(self.mask2(x5)) + x5

        # LFC decorrelation: grouped conv over 4x4 patches, then fold the
        # 16 per-patch values back to the original spatial layout.
        b, c, h, w = x6.size()
        fold = nn.Fold(output_size=(h, w), kernel_size=4, stride=4)

        uncorrelated = self.local_fc_enc(x6)
        uncorrelated = torch.reshape(uncorrelated, (b, c * 16, h * w // 16))
        uncorrelated = fold(uncorrelated)

        # x6 is intentionally replaced: the decorrelated map is both the
        # returned intermediate and the input to the hyper branch.
        x6 = f.relu(uncorrelated)

        x7 = self.trunk6(x6)
        x8 = self.trunk7(x7)
        x9 = self.trunk8(x8) * torch.sigmoid(self.mask3(x8)) + x8
        x10 = self.conv2(x9)

        return x6, x10
-
class Enc_SQL(nn.Module):
    """Encoder variant with a residual scalable-quality layer (``SQL_resi``)
    inserted after the second attention block, conditioning the latent on a
    rate-control parameter ``lmd``.

    Returns (x6, x10) like ``Enc``.
    """

    def __init__(self, num_features, N1, N2, M, M1):

        super(Enc_SQL, self).__init__()
        self.N1 = int(N1)
        self.N2 = int(N2)
        self.M = int(M)
        self.M1 = int(M1)
        self.n_features = int(num_features)

        self.conv1 = nn.Conv2d(self.n_features, self.M1, 5, 1, 2)
        self.trunk1 = nn.Sequential(ResBlock(self.M1, self.M1, 3, 1, 1), ResBlock(
            self.M1, self.M1, 3, 1, 1), nn.Conv2d(self.M1, 2*self.M1, 5, 2, 2))

        self.down1 = nn.Conv2d(2*self.M1, self.M, 5, 2, 2)
        self.trunk2 = nn.Sequential(ResBlock(2*self.M1, 2*self.M1, 3, 1, 1), ResBlock(2*self.M1, 2*self.M1, 3, 1, 1),
                                    ResBlock(2*self.M1, 2*self.M1, 3, 1, 1))
        # NOTE(review): mask1 is registered but never used in forward();
        # kept so existing checkpoints still load -- confirm intent.
        self.mask1 = nn.Sequential(Non_local_Block(2*self.M1, self.M1), ResBlock(2*self.M1, 2*self.M1, 3, 1, 1),
                                   ResBlock(
                                       2*self.M1, 2*self.M1, 3, 1, 1), ResBlock(2*self.M1, 2*self.M1, 3, 1, 1),
                                   nn.Conv2d(2*self.M1, 2*self.M1, 1, 1, 0))
        self.trunk3 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    ResBlock(self.M, self.M, 3, 1, 1), nn.Conv2d(self.M, self.M, 5, 2, 2))

        self.trunk4 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    ResBlock(self.M, self.M, 3, 1, 1), nn.Conv2d(self.M, self.M, 5, 2, 2))

        self.trunk5 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    ResBlock(self.M, self.M, 3, 1, 1))
        self.mask2 = nn.Sequential(Non_local_Block(self.M, self.M // 2), ResBlock(self.M, self.M, 3, 1, 1),
                                   ResBlock(self.M, self.M, 3, 1, 1),
                                   ResBlock(self.M, self.M, 3, 1, 1), nn.Conv2d(self.M, self.M, 1, 1, 0))

        # NOTE(review): 256 must match M for the SQL stage to be shape
        # compatible with x6 -- confirm against the caller's configuration.
        self.SQL = SQL_resi(256, 256)

        self.trunk6 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    nn.Conv2d(self.M, self.M, 5, 2, 2))
        self.trunk7 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    nn.Conv2d(self.M, self.M, 5, 2, 2))

        self.trunk8 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    ResBlock(self.M, self.M, 3, 1, 1))
        self.mask3 = nn.Sequential(Non_local_Block(self.M, self.M // 2), ResBlock(self.M, self.M, 3, 1, 1),
                                   ResBlock(self.M, self.M, 3, 1, 1),
                                   ResBlock(self.M, self.M, 3, 1, 1), nn.Conv2d(self.M, self.M, 1, 1, 0))
        self.conv2 = nn.Conv2d(self.M, self.N2, 3, 1, 1)

    def forward(self, x, lmd):
        x1 = self.conv1(x)
        x2 = self.trunk1(x1)
        x3 = self.trunk2(x2) + x2
        x3 = self.down1(x3)
        x4 = self.trunk3(x3)
        x5 = self.trunk4(x4)
        # Attention-style residual gate (f.sigmoid is deprecated; use
        # torch.sigmoid).
        x6 = self.trunk5(x5) * torch.sigmoid(self.mask2(x5)) + x5

        # Rate-adaptive rescaling of the latent.
        x6 = self.SQL(x6, lmd)

        x7 = self.trunk6(x6)
        x8 = self.trunk7(x7)
        x9 = self.trunk8(x8) * torch.sigmoid(self.mask3(x8)) + x8
        x10 = self.conv2(x9)

        return x6, x10
-
-
class Hyper_Dec(nn.Module):
    """Hyper-decoder: maps the quantized hyper-latent (N2 channels) back up
    to M channels and upsamples it 4x via two transposed convolutions.
    """

    def __init__(self, N2, M):
        super(Hyper_Dec, self).__init__()

        self.N2 = N2
        self.M = M
        self.conv1 = nn.Conv2d(self.N2, M, 3, 1, 1)
        self.trunk1 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    ResBlock(self.M, self.M, 3, 1, 1))
        self.mask1 = nn.Sequential(Non_local_Block(self.M, self.M // 2), ResBlock(self.M, self.M, 3, 1, 1),
                                   ResBlock(self.M, self.M, 3, 1, 1),
                                   ResBlock(self.M, self.M, 3, 1, 1), nn.Conv2d(self.M, self.M, 1, 1, 0))

        self.trunk2 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    nn.ConvTranspose2d(M, M, 5, 2, 2, 1))
        self.trunk3 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    nn.ConvTranspose2d(M, M, 5, 2, 2, 1))

    def forward(self, xq2):
        x1 = self.conv1(xq2)
        # Attention-style residual gate (f.sigmoid is deprecated; use
        # torch.sigmoid).
        x2 = self.trunk1(x1) * torch.sigmoid(self.mask1(x1)) + x1
        x3 = self.trunk2(x2)
        x4 = self.trunk3(x3)

        return x4
-
-
class Dec(nn.Module):
    """Synthesis transform (decoder): upsamples the quantized latent 16x
    back to ``input_features`` channels.
    """

    def __init__(self, input_features, N1, M, M1):
        super(Dec, self).__init__()

        self.N1 = N1
        self.M = M
        self.M1 = M1
        self.input = input_features

        self.trunk1 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    ResBlock(self.M, self.M, 3, 1, 1))
        self.mask1 = nn.Sequential(Non_local_Block(self.M, self.M // 2), ResBlock(self.M, self.M, 3, 1, 1),
                                   ResBlock(self.M, self.M, 3, 1, 1),
                                   ResBlock(self.M, self.M, 3, 1, 1), nn.Conv2d(self.M, self.M, 1, 1, 0))

        self.up1 = nn.ConvTranspose2d(M, M, 5, 2, 2, 1)
        self.trunk2 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    ResBlock(self.M, self.M, 3, 1, 1), nn.ConvTranspose2d(M, M, 5, 2, 2, 1))
        self.trunk3 = nn.Sequential(ResBlock(self.M, self.M, 3, 1, 1), ResBlock(self.M, self.M, 3, 1, 1),
                                    ResBlock(self.M, self.M, 3, 1, 1), nn.ConvTranspose2d(M, 2 * self.M1, 5, 2, 2, 1))

        self.trunk4 = nn.Sequential(ResBlock(2 * self.M1, 2 * self.M1, 3, 1, 1),
                                    ResBlock(2 * self.M1, 2 * self.M1, 3, 1, 1),
                                    ResBlock(2 * self.M1, 2 * self.M1, 3, 1, 1))
        # NOTE(review): mask2 is registered but never used in forward()
        # (trunk4 is applied as a plain residual); kept so existing
        # checkpoints still load -- confirm intent.
        self.mask2 = nn.Sequential(Non_local_Block(2 * self.M1, self.M1), ResBlock(2 * self.M1, 2 * self.M1, 3, 1, 1),
                                   ResBlock(2 * self.M1, 2 * self.M1, 3, 1, 1),
                                   ResBlock(2 * self.M1, 2 * self.M1, 3, 1, 1),
                                   nn.Conv2d(2 * self.M1, 2 * self.M1, 1, 1, 0))

        self.trunk5 = nn.Sequential(nn.ConvTranspose2d(2 * M1, M1, 5, 2, 2, 1), ResBlock(self.M1, self.M1, 3, 1, 1),
                                    ResBlock(self.M1, self.M1, 3, 1, 1),
                                    ResBlock(self.M1, self.M1, 3, 1, 1))

        self.conv1 = nn.Conv2d(self.M1, self.input, 5, 1, 2)

    def forward(self, x):
        # Attention-style residual gate (f.sigmoid is deprecated; use
        # torch.sigmoid).
        x1 = self.trunk1(x) * torch.sigmoid(self.mask1(x)) + x
        x1 = self.up1(x1)
        x2 = self.trunk2(x1)
        x3 = self.trunk3(x2)
        x4 = self.trunk4(x3) + x3
        x5 = self.trunk5(x4)
        output = self.conv1(x5)
        return output
-
class Image_coding(nn.Module):
    """End-to-end image compression model: analysis transform, factorized
    hyper-prior entropy model, context model and synthesis transform.
    """

    def __init__(self, input_features, N1, N2, M, M1):

        super(Image_coding, self).__init__()

        self.encoder = Enc(input_features, N1, N2, M, M1)
        self.factorized_entropy_func = Entropy_bottleneck(N2)
        self.hyper_dec = Hyper_Dec(N2, M)
        self.p = P_Model(M)
        self.decoder = Dec(input_features, N1, M, M1)
        self.context = Weighted_Gaussian(M)

    def add_noise(self, x):
        """Return x plus i.i.d. U(-0.5, 0.5) noise (differentiable proxy
        for quantization during training)."""
        # Sample directly with torch so the noise lives on x's device and
        # dtype; the old numpy version forced .cuda() and broke CPU runs.
        noise = torch.empty_like(x).uniform_(-0.5, 0.5)
        return x + noise

    def forward(self, x, if_training):
        """if_training: 0 -> additive noise, 1 -> universal quantization,
        anything else -> hard rounding of the main latent."""
        x1, x2 = self.encoder(x)
        xq2, xp2 = self.factorized_entropy_func(x2, if_training)
        x3 = self.hyper_dec(xq2)
        hyper_dec = self.p(x3)
        if if_training == 0:
            xq1 = self.add_noise(x1)
        elif if_training == 1:
            xq1 = UniverseQuant.apply(x1)
        else:
            xq1 = torch.round(x1)
        xp3, _ = self.context(xq1, hyper_dec)

        output = self.decoder(xq1)

        # [reconstruction, hyper-latent likelihoods, latent likelihoods]
        return [output, xp2, xp3]
-
-
class Image_coding_lfc(nn.Module):
    """Compression model variant whose decoder input first goes through a
    local fully-connected (LFC) rearrangement, mirroring ``Enc_lfc``.
    """

    def __init__(self, input_features, N1, N2, M, M1):

        super(Image_coding_lfc, self).__init__()

        self.encoder = Enc_lfc(input_features, N1, N2, M, M1)
        # Grouped 4x4/stride-4 conv initialized (with nn.Fold) as an
        # identity rearrangement of each 4x4 patch; learnable afterwards.
        # NOTE(review): the 256-channel assumption must match M.
        self.local_fc_dec = nn.Conv2d(256, 256 * 16, 4, stride=4, groups=256, bias=False)
        init_lfc = np.zeros((256 * 16, 16))
        for c in range(256 * 16):
            i = c % 16
            init_lfc[c][i] = 1
        init_lfc = torch.from_numpy(init_lfc)
        init_lfc = torch.reshape(init_lfc, (256 * 16, 1, 4, 4)).float()
        self.local_fc_dec.weight = nn.Parameter(init_lfc)

        self.factorized_entropy_func = Entropy_bottleneck(N2)
        self.hyper_dec = Hyper_Dec(N2, M)
        self.p = P_Model(M)
        self.decoder = Dec(input_features, N1, M, M1)
        self.context = Weighted_Gaussian(M)

    def add_noise(self, x):
        """Return x plus i.i.d. U(-0.5, 0.5) noise (differentiable proxy
        for quantization during training)."""
        # Sample directly with torch so the noise lives on x's device and
        # dtype; the old numpy version forced .cuda() and broke CPU runs.
        noise = torch.empty_like(x).uniform_(-0.5, 0.5)
        return x + noise

    def forward(self, x, if_training):
        """if_training: 0 -> additive noise, 1 -> universal quantization,
        anything else -> hard rounding of the main latent."""
        x1, x2 = self.encoder(x)

        xq2, xp2 = self.factorized_entropy_func(x2, if_training)
        x3 = self.hyper_dec(xq2)
        hyper_dec = self.p(x3)
        if if_training == 0:
            xq1 = self.add_noise(x1)
        elif if_training == 1:
            xq1 = UniverseQuant.apply(x1)
        else:
            xq1 = torch.round(x1)
        xp3, _ = self.context(xq1, hyper_dec)

        # LFC rearrangement of the quantized latent before decoding.
        b, c, h, w = xq1.size()
        fold = nn.Fold(output_size=(h, w), kernel_size=4, stride=4)

        uncorrelated = self.local_fc_dec(xq1)
        uncorrelated = torch.reshape(uncorrelated, (b, c * 16, h * w // 16))
        uncorrelated = fold(uncorrelated)

        xq1 = f.relu(uncorrelated)
        output = self.decoder(xq1)

        # [reconstruction, hyper-latent likelihoods, latent likelihoods]
        return [output, xp2, xp3]
-
class NIC_SQL(nn.Module):
    """Variable-rate compression model: draws a random integer lambda per
    batch element each forward pass and conditions the SQL encoder on it.
    """

    def __init__(self, input_features, N1, N2, M, M1):
        super(NIC_SQL, self).__init__()

        self.encoder = Enc_SQL(input_features, N1, N2, M, M1)
        self.decoder = Dec(input_features, N1, M, M1)
        self.factorized_entropy_func = Entropy_bottleneck(N2)
        self.hyper_dec = Hyper_Dec(N2, M)
        self.p = P_Model(M)
        self.context = Weighted_Gaussian(M)

    def forward(self, x):
        """Returns (lmd_info, reconstruction, hyper likelihoods, latent
        likelihoods); lmd_info holds the sampled lambdas."""
        b = x.size(0)
        # One lambda per batch element, truncated to an integer in
        # [1, 256].  np.int was removed from NumPy (>=1.24) and crashed
        # here; use an explicit concrete dtype instead.
        rand_lambda = np.random.rand(b)
        rand_lambda = 256 * rand_lambda + 1
        rand_lambda = rand_lambda.astype(np.int64)

        # Feed the lambdas to the network as a float (b, 1) column on the
        # input's device (was hard-coded .cuda()).
        lmd_info = torch.from_numpy(rand_lambda.astype(np.float32))
        lmd_info = torch.reshape(lmd_info, (b, 1)).to(x.device)

        x1, x2 = self.encoder(x, lmd_info)

        # Entropy-bottleneck in training mode (flag 1).
        xq2, xp2 = self.factorized_entropy_func(x2, 1)
        x3 = self.hyper_dec(xq2)

        hyper_dec_p = self.p(x3)
        xq1 = UniverseQuant.apply(x1)
        xp3, _ = self.context(xq1, hyper_dec_p)

        fake = self.decoder(xq1)

        lmd_info = torch.squeeze(lmd_info)

        return lmd_info, fake, xp2, xp3
-
-
class UniverseQuant(torch.autograd.Function):
    """Universal quantization: round with a shared random dither whose
    half-width 0.5 * 2**b is drawn once per call, b ~ U(-1, 1).  The
    backward pass is a straight-through identity.
    """

    @staticmethod
    def forward(ctx, x):
        b = np.random.uniform(-1, 1)
        half = 0.5 * (2 ** b)
        # Sample on x's own device/dtype; the previous version sampled on
        # CPU and unconditionally called .cuda(), breaking CPU tensors.
        noise = torch.empty_like(x).uniform_(-half, half)
        return torch.round(x + noise) - noise

    @staticmethod
    def backward(ctx, g):
        # Straight-through estimator: pass the gradient unchanged.
        return g
|