|
- # The implementation of GDN is inherited from
- # https://github.com/jorge-pessoa/pytorch-gdn,
- # under the MIT License.
- #
- # This file is being made available under the BSD License.
- # Copyright (c) 2021 Yueyu Hu
- import argparse
- import glob
- import math
- import mindspore as ms
- import mindspore.nn as nn
- import mindspore.ops as ops
- import mindspore.nn.probability.distribution as msd
- import mindspore.dataset as ds
- from mindspore.common.initializer import initializer, XavierUniform
- import numpy as np
- import pickle
- from PIL import Image
- import time
- import os
- from gdn_v3 import GDN, IGDN
- import torch
-
- # Main analysis transform model with GDN
class analysisTransformModel(nn.Cell):
    """Main analysis transform: four stride-2 5x5 convolutions, with a GDN
    nonlinearity after each of the first three (the final latent layer is linear).
    """

    def __init__(self, in_dim, num_filters, conv_trainable=True):
        super(analysisTransformModel, self).__init__()
        self.t0 = nn.SequentialCell()
        channels = [in_dim] + list(num_filters)
        layers = []
        for i in range(4):
            # Asymmetric (1, 2) padding on height/width reproduces TensorFlow
            # "same" padding for kernel 5, stride 2 (pad order: top/bottom, left/right).
            layers.append(nn.Pad(paddings=((0, 0), (0, 0), (1, 2), (1, 2)), mode="CONSTANT"))
            layers.append(nn.Conv2d(channels[i], channels[i + 1], 5, 2,
                                    has_bias=True, pad_mode='valid', padding=0))
            if i < 3:
                layers.append(GDN(channels[i + 1]))
        self.transform = nn.SequentialCell(layers)

    def construct(self, inputs):
        return self.transform(inputs)
-
- # Main synthesis transform model with IGDN
class synthesisTransformModel(nn.Cell):
    """Main synthesis transform: three stride-2 5x5 transposed convolutions,
    each followed by an IGDN nonlinearity (num_filters[3] is intentionally unused
    here; the final layer is handled by the reconstruction sub-network).
    """

    def __init__(self, in_dim, num_filters, conv_trainable=True):
        super(synthesisTransformModel, self).__init__()
        channels = [in_dim] + list(num_filters)
        layers = []
        for i in range(3):
            # The zero Pad is a no-op kept for layer-index compatibility.
            layers.append(nn.Pad(paddings=((0, 0), (0, 0), (0, 0), (0, 0)), mode="CONSTANT"))
            layers.append(nn.Conv2dTranspose(channels[i], channels[i + 1], 5, 2,
                                             padding=0, has_bias=True, pad_mode='same'))
            layers.append(IGDN(channels[i + 1]))
        self.transform = nn.SequentialCell(layers)

    def construct(self, inputs):
        return self.transform(inputs)
-
- # # Space-to-depth & depth-to-space module
- # # same to TensorFlow implementations
class Space2Depth(nn.Cell):
    """Rearrange r x r spatial blocks into channels (TensorFlow space_to_depth).

    Input  (b, c, h, w) -> output (b, c*r*r, h//r, w//r). h and w must be
    divisible by r.
    """

    def __init__(self, r):
        super(Space2Depth, self).__init__()
        self.r = r

    def construct(self, x):
        r = self.r
        b, c, h, w = x.shape
        out_c = c * (r**2)
        # Fix: was h//2 and w//2, hard-coding r=2 even though r is a parameter.
        out_h = h // r
        out_w = w // r
        x_view = x.view(b, c, out_h, r, out_w, r)
        x_prime = x_view.permute(0, 3, 5, 1, 2, 4)
        output = x_prime.view(b, out_c, out_h, out_w)
        return output
-
class Depth2Space(nn.Cell):
    """Inverse of Space2Depth (TensorFlow depth_to_space).

    Input  (b, c, h, w) -> output (b, c//(r*r), h*r, w*r). c must be divisible
    by r*r.
    """

    def __init__(self, r):
        super(Depth2Space, self).__init__()
        self.r = r

    def construct(self, x):
        r = self.r
        b, c, h, w = x.shape
        out_c = c // (r**2)
        # Fix: was h*2 and w*2, hard-coding r=2 even though r is a parameter.
        out_h = h * r
        out_w = w * r
        x_view = x.view(b, r, r, out_c, h, w)
        x_prime = x_view.permute(0, 3, 4, 1, 5, 2)
        output = x_prime.view(b, out_c, out_h, out_w)
        return output
-
- # Hyper analysis transform (w/o GDN)
class h_analysisTransformModel(nn.Cell):
    """Hyper analysis transform (no GDN): a 3x3 conv, a 2x space-to-depth,
    then a stack of 1x1 convolutions with ReLU activations.
    """

    def __init__(self, in_dim, num_filters, strides_list, conv_trainable=True):
        super(h_analysisTransformModel, self).__init__()
        self.transform = nn.SequentialCell([
            nn.Conv2d(in_dim, num_filters[0], 3, strides_list[0],
                      has_bias=True, pad_mode='pad', padding=1),
            # space-to-depth quadruples the channel count, hence the *4 below.
            Space2Depth(2),
            nn.Conv2d(num_filters[0]*4, num_filters[1], 1, strides_list[1],
                      has_bias=True, pad_mode='valid', padding=0),
            nn.ReLU(),
            nn.Conv2d(num_filters[1], num_filters[1], 1, 1,
                      has_bias=True, pad_mode='valid', padding=0),
            nn.ReLU(),
            nn.Conv2d(num_filters[1], num_filters[2], 1, 1,
                      has_bias=True, pad_mode='valid', padding=0),
        ])

    def construct(self, inputs):
        return self.transform(inputs)
-
- # Hyper synthesis transform (w/o GDN)
class h_synthesisTransformModel(nn.Cell):
    """Hyper synthesis transform (no GDN): 1x1 transposed convolutions, a 2x
    depth-to-space, and a final 3x3 transposed convolution.

    NOTE(review): strides_list is consumed as [2], [1], [1], [0] — i.e. roughly
    mirroring the analysis transform; strides_list[1] is used twice, which is
    only safe for the all-ones stride lists used in this file.
    """

    def __init__(self, in_dim, num_filters, strides_list, conv_trainable=True):
        super(h_synthesisTransformModel, self).__init__()
        self.transform = nn.SequentialCell([
            nn.Conv2dTranspose(in_dim, num_filters[0], 1, strides_list[2],
                               padding=0, has_bias=True, pad_mode='pad'),
            nn.Conv2dTranspose(num_filters[0], num_filters[1], 1, strides_list[1],
                               padding=0, has_bias=True, pad_mode='pad'),
            nn.ReLU(),
            nn.Conv2dTranspose(num_filters[1], num_filters[1], 1, strides_list[1],
                               padding=0, has_bias=True, pad_mode='pad'),
            nn.ReLU(),
            # depth-to-space divides the channel count by 4, hence //4 below.
            Depth2Space(2),
            # No-op Pad kept for layer-index compatibility.
            nn.Pad(paddings=((0, 0), (0, 0), (0, 0), (0, 0)), mode="CONSTANT"),
            nn.Conv2dTranspose(num_filters[1]//4, num_filters[2], 3, strides_list[0],
                               padding=1, has_bias=True, pad_mode='pad'),
        ])

    def construct(self, inputs):
        return self.transform(inputs)
-
class NeighborSample(nn.Cell):
    """Extract the 5x5 neighborhood around every spatial position of the input.

    Returns a tensor of shape (b*h*w, c, 5, 5): one patch per spatial location.
    """

    def __init__(self):
        super(NeighborSample, self).__init__()
        # torch fallback: MindSpore's unfold operator with padding produced
        # results inconsistent with torch's, so torch.nn.Unfold is used instead
        # (this forces a host round-trip through numpy in construct below).
        self.unfolder = torch.nn.Unfold(5, padding=2)

    def construct(self, inputs):
        b, c, h, w = inputs.shape
        # MindSpore -> torch -> MindSpore conversion around the unfold call.
        inputs = torch.tensor(inputs.asnumpy())
        t = self.unfolder(inputs) # (b, c*5*5, h*w)
        t = ms.Tensor(t.numpy(), ms.float32)
        t = t.permute((0,2,1)).reshape(b*h*w, c, 5, 5)
        return t
-
- # class NeighborSample(nn.Cell):
- # def __init__(self):
- # super(NeighborSample, self).__init__()
-
- # self.unfolder = nn.Unfold(ksizes=[1, 5, 5, 1], strides=[1, 1, 1, 1], rates=[1, 1, 1, 1], padding="valid")
-
- # def construct(self, inputs):
- # b, c, h, w = inputs.shape
- # pad = nn.Pad(paddings=((0, 0), (0, 0), (2, 2), (2, 2)), mode="CONSTANT")
- # inputs = pad(inputs)
- # t = self.unfolder(inputs) # (b, c*5*5, h*w)
-
- # out_batch, out_depth, out_row, out_col = t.shape
- # ms_reshape = ops.Reshape()
- # t = ms_reshape(t, (out_batch, out_depth, out_row * out_col))
-
- # ms_concat = ops.Concat()
- # output = None
- # for i in range(out_batch):
- # odd = None
- # even = None
- # for j in range(out_depth):
- # data = t[i,j,:]
- # data = ms_reshape(data, (1, data.shape[0]))
- # if j % 2 == 0:
- # if even is None:
- # even = data
- # else:
- # even = ms_concat((even, data))
- # else:
- # if odd is None:
- # odd = data
- # else:
- # odd = ms_concat((odd, data))
- # temp = ms_concat((even, odd))
- # temp = ms_reshape(temp, (1, temp.shape[0], temp.shape[1]))
- # if i == 0:
- # output = temp
- # else:
- # output = ms_concat((output, temp))
-
- # t = output.permute((0,2,1)).reshape(b*h*w, c, 5, 5)
-
- # return t
-
- # Gaussian likelihood calculation module
class GaussianModel(nn.Cell):
    """Likelihood of integer-quantized values under a Gaussian entropy model.

    For each element, computes P(y in [x-0.5, x+0.5]) with y ~ N(mu, sigma),
    i.e. the probability mass of the unit-width quantization bin.
    """

    def __init__(self):
        super(GaussianModel, self).__init__()
        self.m_normal_dist = msd.Normal(0.0, 1.0, dtype=ms.float32)

    def _cumulative(self, inputs, stds, mu):
        """CDF(x+0.5) - CDF(x-0.5), evaluated on standardized values."""
        half = 0.5
        # (unused local `eps = 1e-6` removed; it was dead code)
        upper = (inputs - mu + half) / (stds)
        lower = (inputs - mu - half) / (stds)
        cdf_upper = self.m_normal_dist.cdf(ms.Tensor(upper))
        cdf_lower = self.m_normal_dist.cdf(ms.Tensor(lower))
        res = cdf_upper - cdf_lower
        return res

    def construct(self, inputs, hyper_sigma, hyper_mu):
        likelihood = self._cumulative(inputs, hyper_sigma, hyper_mu)
        # Lower-bound the likelihood so downstream log() stays finite; the
        # upper bound (likelihood + 100) is intentionally never binding.
        likelihood_bound = 1e-8
        likelihood = ops.clip_by_value(likelihood, likelihood_bound, likelihood + 100)
        return likelihood
-
- # Prediction module to generate mean and scale for entropy coding
class PredictionModel(nn.Cell):
    """Predict per-element Gaussian mean and scale for entropy coding.

    Each 5x5 neighborhood patch (produced by `h_sampler`) is reduced to a 3x3
    feature map by three convolutions, flattened, and mapped by a dense layer to
    the concatenation [mu | log_sigma] of size `outdim`.
    """

    def __init__(self, in_dim, dim=192, trainable=True, outdim=None):
        super(PredictionModel, self).__init__()
        if outdim is None:
            outdim = dim
        self.transform = nn.SequentialCell([
            nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)), mode="CONSTANT"),
            nn.Conv2d(in_dim, dim, 3, 1, has_bias=True, pad_mode='valid', padding=0),
            nn.LeakyReLU(0.2),
            nn.Pad(paddings=((0, 0), (0, 0), (1, 2), (1, 2)), mode="CONSTANT"),
            nn.Conv2d(dim, dim, 3, 2, has_bias=True, pad_mode='valid', padding=0),
            nn.LeakyReLU(0.2),
            nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)), mode="CONSTANT"),
            nn.Conv2d(dim, dim, 3, 1, has_bias=True, pad_mode='valid', padding=0),
            nn.LeakyReLU(0.2),
        ])
        # 5x5 patches come out of the conv stack as dim x 3 x 3.
        self.fc = nn.Dense(dim*3*3, outdim)

    def construct(self, input_shape, h_tilde, h_sampler):
        b, c, h, w = input_shape
        feat = h_sampler(h_tilde)          # (b*h*w, c_in, 5, 5) patches
        feat = self.transform(feat)
        feat = ops.Flatten()(feat)
        feat = self.fc(feat)

        reshape = ops.Reshape()
        # First c channels are the mean; the remainder is exponentiated to a
        # strictly positive scale. Both are reshaped NHWC -> NCHW.
        hyper_mu = reshape(feat[:, :c], (b, h, w, c)).permute(0, 3, 1, 2)
        hyper_sigma = reshape(ops.exp(feat[:, c:]), (b, h, w, c)).permute(0, 3, 1, 2)
        return hyper_mu, hyper_sigma
-
- ## differentiable rounding function
-
def BypassRound(x):
    """Forward kernel of the straight-through ("bypass") rounding custom op.

    NOTE(review): this is written in the AKG custom-op DSL — `output_tensor`
    and the per-element indexing are resolved by the AKG compiler when wrapped
    in ops.Custom (see NetLow.__init__), not by plain Python; do not call it
    directly.
    """
    y = output_tensor(x.shape, x.dtype)
    for i0 in range(x.shape[0]):
        y[i0] = round(x[i0])
    return y
-
def BypassRound_grad(x, dout):
    """Backward kernel of BypassRound: pass the incoming gradient through
    unchanged (straight-through estimator for the non-differentiable round).

    NOTE(review): AKG custom-op DSL, same caveats as BypassRound — only valid
    inside ops.Custom.
    """
    dx = output_tensor(x.shape, x.dtype)
    for i0 in range(x.shape[0]):
        dx[i0] = dout[i0]
    return dx
-
def bprop():
    """Build the bprop function for the BypassRound custom op.

    Returns a closure with the (x, out, dout) -> (dx,) signature MindSpore
    expects for ops.Custom backward registration; the gradient is simply
    forwarded (identity), matching BypassRound_grad.
    """
    op = ops.Custom(BypassRound_grad, lambda x, _: x, lambda x, _: x, func_type="akg")

    def custom_bprop(x, out, dout):
        dx = op(x, dout)
        return (dx,)

    return custom_bprop
-
- # Information-Aggregation Reconstruction network
class SideInfoReconModel(nn.Cell):
    """Information-aggregation reconstruction network.

    Fuses the two decoded hyperpriors (h1, h2) with the main synthesis features
    (pf), refines them through a residual block, and produces the final image.
    """

    def __init__(self, input_dim, num_filters=192):
        super(SideInfoReconModel, self).__init__()
        # Upsampling path for the fused hyperprior features (three 2x steps).
        self.layer_1 = nn.SequentialCell([
            nn.Pad(paddings=((0, 0), (0, 0), (0, 0), (0, 0)), mode="CONSTANT"),
            nn.Conv2dTranspose(input_dim, num_filters, 5, 2, padding=0, has_bias=True, pad_mode='same'),
        ])
        self.layer_1a = nn.SequentialCell([
            nn.Pad(paddings=((0, 0), (0, 0), (0, 0), (0, 0)), mode="CONSTANT"),
            nn.Conv2dTranspose(num_filters, num_filters, 5, 2, padding=0, has_bias=True, pad_mode='same'),
            nn.LeakyReLU(0.2),
        ])
        self.layer_1b = nn.SequentialCell([
            nn.Pad(paddings=((0, 0), (0, 0), (0, 0), (0, 0)), mode="CONSTANT"),
            nn.Conv2dTranspose(num_filters, num_filters, 5, 2, padding=0, has_bias=True, pad_mode='same'),
            nn.LeakyReLU(0.2),
        ])
        # Residual refinement over [pf, upsampled hyperprior features].
        self.layer_3_1 = nn.SequentialCell([
            nn.Conv2d(num_filters*2, num_filters, 3, 1, has_bias=True, pad_mode='pad', padding=1),
            nn.LeakyReLU(0.2),
        ])
        self.layer_3_2 = nn.SequentialCell([
            nn.Conv2d(num_filters, num_filters, 3, 1, has_bias=True, pad_mode='pad', padding=1),
            nn.LeakyReLU(0.2),
        ])
        self.layer_3_3 = nn.SequentialCell([
            nn.Conv2d(num_filters, num_filters*2, 3, 1, has_bias=True, pad_mode='pad', padding=1),
            nn.LeakyReLU(0.2),
        ])
        # Final upsampling and projection to 3 image channels.
        self.layer_4 = nn.SequentialCell([
            nn.Pad(paddings=((0, 0), (0, 0), (0, 0), (0, 0)), mode="CONSTANT"),
            nn.Conv2dTranspose(num_filters*2, num_filters//3, 5, 2, padding=0, has_bias=True, pad_mode='same'),
        ])
        self.layer_5 = nn.SequentialCell([
            nn.Conv2d(num_filters//3, num_filters//12, 3, 1, has_bias=True, pad_mode='pad', padding=1),
            nn.LeakyReLU(0.2),
        ])
        self.layer_6 = nn.Conv2d(num_filters//12, 3, 1, 1, has_bias=True, pad_mode='valid', padding=0)
        self.d2s = Depth2Space(2)

    def construct(self, pf, h2, h1):
        # Bring h1 to h2's spatial resolution, fuse along channels, upsample.
        fused = ops.concat([h2, self.d2s(h1)], 1)
        fused = self.layer_1(fused)
        fused = self.layer_1a(fused)
        fused = self.layer_1b(fused)

        feat_in = ops.concat([pf, fused], 1)
        feat = self.layer_3_1(feat_in)
        feat = self.layer_3_2(feat)
        feat = self.layer_3_3(feat)
        feat = feat_in + feat  # residual connection

        out = self.layer_4(feat)
        out = self.layer_5(out)
        out = self.layer_6(out)
        return out
-
- # bypass_round = BypassRound.apply
-
- # Projection head for constructing helping loss.
- # It is used to apply a constraint to the decoded hyperprior.
- # It helps hyperpriors preserve vital information during multi-layer training.
class ProjHead(nn.Cell):
    """3x3 convolutional projection head for the auxiliary (helping) loss.

    Applies a constraint to decoded hyperpriors so they preserve vital
    information during multi-layer training.
    """

    def __init__(self, in_dim, out_dim):
        super(ProjHead, self).__init__()
        self.transform = nn.Conv2d(in_dim, out_dim, 3, 1, has_bias=True, pad_mode='pad', padding=1)

    def construct(self, inputs):
        x = self.transform(inputs)
        return x
-
class NetHigh(nn.Cell):
    """Three-layer hyperprior codec, high-bit-rate configuration (384-channel
    main transforms). Inference is driven through encode()/decode().
    """

    def __init__(self):
        super(NetHigh, self).__init__()
        self.a_model = analysisTransformModel(3, [384, 384, 384, 384])
        self.s_model = synthesisTransformModel(384, [384, 384, 384, 3])

        self.ha_model_1 = h_analysisTransformModel(64*4, [64*4*2, 32*4*2, 32*4], [1, 1, 1])
        self.hs_model_1 = h_synthesisTransformModel(32*4, [64*4*2, 64*4*2, 64*4], [1, 1, 1])

        self.ha_model_2 = h_analysisTransformModel(384, [384*2, 192*4*2, 64*4], [1, 1, 1])
        self.hs_model_2 = h_synthesisTransformModel(64*4, [192*4*2, 192*4*2, 384], [1, 1, 1])

        self.entropy_bottleneck_z1 = GaussianModel()
        self.entropy_bottleneck_z2 = GaussianModel()
        self.entropy_bottleneck_z3 = GaussianModel()
        # Fixed (non-trainable) scale of the zero-mean prior on the innermost latent z1.
        self.h1_sigma = ms.Parameter(ops.ones(
            (1, 32*4, 1, 1), ms.float32), requires_grad=False)
        self.prediction_model_2 = PredictionModel(in_dim=64*4, dim=64*4, outdim=64*4*2)

        self.prediction_model_3 = PredictionModel(in_dim=384, dim=384, outdim=384*2)

        self.sampler_2 = NeighborSample()
        self.sampler_3 = NeighborSample()

        self.side_recon_model = SideInfoReconModel(384+64, num_filters=384)

    def construct(self, inputs, mode='test'):
        # Intentionally a stub: use encode()/decode() instead.
        pass

    def encode(self, inputs):
        """Full forward pass; returns latents and coding statistics as numpy arrays.

        inputs: NCHW image tensor, assumed scaled to [-1, 1] (see the 127.5
        rescaling below) — TODO confirm against the data pipeline.
        """
        b, c, h, w = inputs.shape
        tb, tc, th, tw = inputs.shape

        # Analysis transforms with hard rounding of each latent.
        z3 = self.a_model(inputs)
        z3_rounded = ops.Rint()(z3)

        z2 = self.ha_model_2(z3_rounded)
        z2_rounded = ops.Rint()(z2)

        z1 = self.ha_model_1(z2_rounded)
        z1_rounded = ops.Rint()(z1)

        # z1 prior: zero mean, fixed per-channel scale.
        z1_sigma = ops.abs(self.h1_sigma)
        z1_mu = ops.ZerosLike()(z1_sigma)

        h1 = self.hs_model_1(z1_rounded)
        h2 = self.hs_model_2(z2_rounded)
        h3 = self.s_model(z3_rounded)

        z1_likelihoods = self.entropy_bottleneck_z1(z1_rounded, z1_sigma, z1_mu)

        # z2 and z3 priors are predicted from the decoded coarser hyperpriors.
        z2_mu, z2_sigma = self.prediction_model_2((tb, 64*4, th//2//16, tw//2//16), h1, self.sampler_2)
        z2_likelihoods = self.entropy_bottleneck_z2(z2_rounded, z2_sigma, z2_mu)

        z3_mu, z3_sigma = self.prediction_model_3((tb, 384, th//16, tw//16), h2, self.sampler_3)
        z3_likelihoods = self.entropy_bottleneck_z3(z3_rounded, z3_sigma, z3_mu)

        pf = self.s_model(z3_rounded)
        x_tilde = self.side_recon_model(pf, h2, h1)

        num_pixels = inputs.shape[0] * h * w
        test_num_pixels = inputs.shape[0] * h * w

        # Natural-log likelihoods are converted to bits via division by -log(2).
        eval_bpp = ops.ReduceSum()(ops.Log()(z3_likelihoods), [0,1,2,3]) / (-np.log(2) * test_num_pixels) + ops.ReduceSum()(ops.Log()(z2_likelihoods), [0,1,2,3]) / (-np.log(2) * test_num_pixels) + ops.ReduceSum()(ops.Log()(z1_likelihoods), [0,1,2,3]) / (-np.log(2) * test_num_pixels)

        gt = ops.round((inputs + 1) * 127.5)
        x_hat = ops.clip_by_value((x_tilde + 1) * 127.5, 0, 255)
        x_hat = ops.round(x_hat).float()
        v_mse = ops.ReduceMean()((x_hat - gt) ** 2, [1,2,3])
        # PSNR = 20*log10(255/rms). ops.Log is the natural log, so divide by
        # ln(10); previously ln was used directly, inflating PSNR by ~2.3x.
        v_psnr = ops.ReduceMean()(20 * ops.Log()(255 / ops.Sqrt()(v_mse)) / np.log(10), 0)

        ret = {}
        ret['z1_mu'] = z1_mu.numpy()
        ret['z1_sigma'] = z1_sigma.numpy()
        ret['z2_mu'] = z2_mu.numpy()
        ret['z2_sigma'] = z2_sigma.numpy()
        ret['z3_mu'] = z3_mu.numpy()
        ret['z3_sigma'] = z3_sigma.numpy()
        ret['z1_rounded'] = z1_rounded.numpy()
        ret['z2_rounded'] = z2_rounded.numpy()
        ret['z3_rounded'] = z3_rounded.numpy()
        ret['v_psnr'] = v_psnr.numpy()
        ret['eval_bpp'] = eval_bpp.numpy()
        return ret

    def decode(self, inputs, stage):
        """Stage-wise decoding; stages must be run in order 0..3 because stages
        1 and 2 cache h1/h2 on self for stage 3.
        """
        if stage == 0:
            # Stage 0: fixed prior parameters of the innermost latent.
            z1_sigma = ops.abs(self.h1_sigma)
            z1_mu = ops.ZerosLike()(z1_sigma)

            ret = {}
            ret['z1_sigma'] = z1_sigma.numpy()
            ret['z1_mu'] = z1_mu.numpy()
            return ret

        elif stage == 1:
            # Stage 1: decode h1 from z1 and predict z2's prior.
            z1_rounded = inputs['z1_rounded']
            h1 = self.hs_model_1(z1_rounded)

            self.h1 = h1
            z2_mu, z2_sigma = self.prediction_model_2((h1.shape[0],64*4,h1.shape[2],h1.shape[3]), h1, self.sampler_2)
            ret = {}
            ret['z2_sigma'] = z2_sigma.numpy()
            ret['z2_mu'] = z2_mu.numpy()

            return ret

        elif stage == 2:
            # Stage 2: decode h2 from z2 and predict z3's prior.
            z2_rounded = inputs['z2_rounded']
            h2 = self.hs_model_2(z2_rounded)
            self.h2 = h2
            z3_mu, z3_sigma = self.prediction_model_3((h2.shape[0],384,h2.shape[2],h2.shape[3]), h2, self.sampler_3)
            ret = {}
            ret['z3_sigma'] = z3_sigma.numpy()
            ret['z3_mu'] = z3_mu.numpy()
            return ret

        elif stage == 3:
            # Stage 3: final image reconstruction.
            z3_rounded = inputs['z3_rounded']
            pf = self.s_model(z3_rounded)
            x_tilde = self.side_recon_model(pf, self.h2, self.h1)
            # Fix: use ops.round — the Python builtin round() does not apply
            # element-wise to a MindSpore tensor.
            x_tilde = ops.round(ops.clip_by_value((x_tilde + 1) * 127.5, 0, 255))
            return x_tilde.numpy()
-
- # Main network
- # The current hyper parameters are for higher-bit-rate compression (2x)
- # Stage 1: train the main encoder & decoder, fine hyperprior
- # Stage 2: train the whole network w/o info-agg sub-network
- # Stage 3: disable the final layer of the synthesis transform and enable info-agg net
- # Stage 4: End-to-end train the whole network w/o the helping (auxillary) loss
class NetLow(nn.Cell):
    """Three-layer hyperprior codec, lower-bit-rate configuration (192-channel
    main transforms). Inference is driven through encode()/decode(); the
    multi-stage training procedure is described in the file-level comments.
    """

    def __init__(self, train_size=(1,256,256,3), test_size=(1,256,256,3)):
        super(NetLow, self).__init__()
        self.train_size = train_size
        self.test_size = test_size
        self.a_model = analysisTransformModel(
            3, [192, 192, 192, 192])
        self.s_model = synthesisTransformModel(
            192, [192, 192, 192, 3])

        self.ha_model_1 = h_analysisTransformModel(
            64*4, [64*4, 32*4, 32*4], [1, 1, 1])
        self.hs_model_1 = h_synthesisTransformModel(
            32*4, [64*4, 64*4, 64*4], [1, 1, 1])

        self.ha_model_2 = h_analysisTransformModel(
            192, [384, 192*4, 64*4], [1, 1, 1])
        self.hs_model_2 = h_synthesisTransformModel(
            64*4, [192*4, 192*4, 192], [1, 1, 1])

        self.entropy_bottleneck_z1 = GaussianModel()
        self.entropy_bottleneck_z2 = GaussianModel()
        self.entropy_bottleneck_z3 = GaussianModel()
        b, h, w, c = train_size
        tb, th, tw, tc = test_size

        # Fixed (non-trainable) scale of the zero-mean prior on the innermost latent z1.
        self.h1_sigma = ms.Parameter(ops.ones(
            (1, 32*4, 1, 1), ms.float32), requires_grad=False)

        self.prediction_model_2 = PredictionModel(
            in_dim=64*4, dim=64*4, outdim=64*4*2)

        self.prediction_model_3 = PredictionModel(
            in_dim=192, dim=192, outdim=192*2)

        self.sampler_2 = NeighborSample()
        self.sampler_3 = NeighborSample()

        self.side_recon_model = SideInfoReconModel(192+64, num_filters=192)

        # self.proj_head_z3 = ProjHead(384, 384)
        # self.proj_head_z2 = ProjHead(64*4, 64*4)

        # Straight-through rounding (identity gradient) as an AKG custom op.
        self.BypassRound = ops.Custom(BypassRound, lambda x: x, lambda x: x, bprop=bprop(), func_type="akg")

    def stage1_params(self):
        """Parameter list for training stage 1.

        NOTE(review): this references `self.proj_head_z3` (commented out in
        __init__), `self.z2_sigma` (never defined) and `.value()` on Cells;
        calling it as-is will raise AttributeError. Looks like an incomplete
        port of the training code — left unchanged so the issue stays visible.
        """
        params = []
        for v in self.a_model.value():
            params.append(v)
        for v in self.s_model.value():
            params.append(v)

        for v in self.ha_model_2.value():
            params.append(v)
        for v in self.hs_model_2.value():
            params.append(v)
        for v in self.proj_head_z3.value():
            params.append(v)
        for v in self.prediction_model_3.value():
            params.append(v)
        params.append(self.z2_sigma)

        return params

    def stage2_params(self):
        """Parameter list for training stage 2.

        NOTE(review): same caveats as stage1_params — `self.proj_head_z3`,
        `self.proj_head_z2` and `self.get_h1_sigma` are not defined on this
        class; left unchanged so the issue stays visible.
        """
        params = []

        for v in self.a_model.value():
            params.append(v)
        for v in self.s_model.value():
            params.append(v)

        for v in self.ha_model_2.value():
            params.append(v)
        for v in self.hs_model_2.value():
            params.append(v)
        for v in self.proj_head_z3.value():
            params.append(v)
        for v in self.prediction_model_3.value():
            params.append(v)

        for v in self.ha_model_1.value():
            params.append(v)
        for v in self.hs_model_1.value():
            params.append(v)
        for v in self.proj_head_z2.value():
            params.append(v)
        for v in self.prediction_model_2.value():
            params.append(v)
        params.append(self.get_h1_sigma)

        return params

    def construct(self, inputs, mode='test', stage=1):
        # Intentionally a stub: use encode()/decode() instead.
        pass

    def encode(self, inputs):
        """Full forward pass; returns latents and coding statistics as numpy arrays.

        inputs: NCHW image tensor, assumed scaled to [-1, 1] (see the 127.5
        rescaling below) — TODO confirm against the data pipeline.
        """
        b, c, h, w = inputs.shape
        tb, tc, th, tw = inputs.shape

        # Analysis transforms with hard rounding of each latent.
        z3 = self.a_model(inputs)
        z3_rounded = ops.Rint()(z3)

        z2 = self.ha_model_2(z3_rounded)
        z2_rounded = ops.Rint()(z2)

        z1 = self.ha_model_1(z2_rounded)
        z1_rounded = ops.Rint()(z1)

        # z1 prior: zero mean, fixed per-channel scale.
        z1_sigma = ops.abs(self.h1_sigma)
        z1_mu = ops.ZerosLike()(z1_sigma)

        h1 = self.hs_model_1(z1_rounded)
        h2 = self.hs_model_2(z2_rounded)
        h3 = self.s_model(z3_rounded)

        z1_likelihoods = self.entropy_bottleneck_z1(z1_rounded, z1_sigma, z1_mu)
        # z2 and z3 priors are predicted from the decoded coarser hyperpriors.
        z2_mu, z2_sigma = self.prediction_model_2((tb, 64*4,th//2//16,tw//2//16), h1, self.sampler_2)

        z2_likelihoods = self.entropy_bottleneck_z2(z2_rounded, z2_sigma, z2_mu)

        z3_mu, z3_sigma = self.prediction_model_3((tb, 192,th//16,tw//16), h2, self.sampler_3)

        z3_likelihoods = self.entropy_bottleneck_z3(z3_rounded, z3_sigma, z3_mu)

        pf = self.s_model(z3_rounded)
        x_tilde = self.side_recon_model(pf, h2, h1)

        num_pixels = inputs.shape[0] * h * w

        test_num_pixels = inputs.shape[0] * h * w

        # Natural-log likelihoods are converted to bits via division by -log(2).
        eval_bpp = ops.ReduceSum()(ops.Log()(z3_likelihoods), [0,1,2,3]) / (-np.log(2) * test_num_pixels) + ops.ReduceSum()(ops.Log()(z2_likelihoods), [0,1,2,3]) / (-np.log(2) * test_num_pixels) + ops.ReduceSum()(ops.Log()(z1_likelihoods), [0,1,2,3]) / (-np.log(2) * test_num_pixels)

        gt = ops.round((inputs + 1) * 127.5)
        x_hat = ops.clip_by_value((x_tilde + 1) * 127.5, 0, 255)
        x_hat = ops.round(x_hat).float()
        v_mse = ops.ReduceMean()((x_hat - gt) ** 2, [1,2,3])
        # PSNR = 20*log10(255/rms). ops.Log is the natural log, so divide by
        # ln(10); previously ln was used directly, inflating PSNR by ~2.3x.
        v_psnr = ops.ReduceMean()(20 * ops.Log()(255 / ops.Sqrt()(v_mse)) / np.log(10), 0)

        ret = {}
        ret['z1_mu'] = z1_mu.numpy()
        ret['z1_sigma'] = z1_sigma.numpy()
        ret['z2_mu'] = z2_mu.numpy()
        ret['z2_sigma'] = z2_sigma.numpy()
        ret['z3_mu'] = z3_mu.numpy()
        ret['z3_sigma'] = z3_sigma.numpy()
        ret['z1_rounded'] = z1_rounded.numpy()
        ret['z2_rounded'] = z2_rounded.numpy()
        ret['z3_rounded'] = z3_rounded.numpy()
        ret['v_psnr'] = v_psnr.numpy()
        ret['eval_bpp'] = eval_bpp.numpy()
        return ret

    def decode(self, inputs, stage):
        """Stage-wise decoding; stages must be run in order 0..3 because stages
        1 and 2 cache h1/h2 on self for stage 3.
        """
        if stage == 0:
            # Stage 0: fixed prior parameters of the innermost latent.
            z1_sigma = ops.abs(self.h1_sigma)
            z1_mu = ops.ZerosLike()(z1_sigma)

            ret = {}
            ret['z1_sigma'] = z1_sigma.numpy()
            ret['z1_mu'] = z1_mu.numpy()
            return ret

        elif stage == 1:
            # Stage 1: decode h1 from z1 and predict z2's prior.
            z1_rounded = inputs['z1_rounded']
            h1 = self.hs_model_1(z1_rounded)

            self.h1 = h1
            z2_mu, z2_sigma = self.prediction_model_2((h1.shape[0],64*4,h1.shape[2],h1.shape[3]), h1, self.sampler_2)
            ret = {}
            ret['z2_sigma'] = z2_sigma.numpy()
            ret['z2_mu'] = z2_mu.numpy()

            return ret

        elif stage == 2:
            # Stage 2: decode h2 from z2 and predict z3's prior.
            z2_rounded = inputs['z2_rounded']
            h2 = self.hs_model_2(z2_rounded)
            self.h2 = h2
            z3_mu, z3_sigma = self.prediction_model_3((h2.shape[0],192,h2.shape[2],h2.shape[3]), h2, self.sampler_3)
            ret = {}
            ret['z3_sigma'] = z3_sigma.numpy()
            ret['z3_mu'] = z3_mu.numpy()
            return ret

        elif stage == 3:
            # Stage 3: final image reconstruction.
            z3_rounded = inputs['z3_rounded']
            pf = self.s_model(z3_rounded)
            x_tilde = self.side_recon_model(pf, self.h2, self.h1)
            # Fix: use ops.round — the Python builtin round() does not apply
            # element-wise to a MindSpore tensor.
            x_tilde = ops.round(ops.clip_by_value((x_tilde + 1) * 127.5, 0, 255))
            return x_tilde.numpy()
-
- if __name__ == "__main__":
-
- #测试net代码
- x = np.random.random((1, 3, 256, 256)).astype(np.float32)
- net = NetLow()
- param_dict = ms.load_checkpoint('my_model.ckpt')
- param_not_load = ms.load_param_into_net(net, param_dict)
- print(param_not_load)
- y = net(ms.Tensor(x))
- print(y)
|