|
- # Copyright (c) Nanjing University, Vision Lab.
- # Last update:
- # 2021.9.9
-
- import os
- import argparse
- import numpy as np
- #import tensorflow as tf
- import time
- import importlib
- import subprocess
- import tensorlayer as tl
- import math
-
- ################### Compression Network (with conditional entropy model) ###################
def parallel_run(function, x, batch_parallel):
    """Apply `function` to `x` in chunks of at most `batch_parallel` items.

    Chunking bounds peak memory when `function` is a large network. The
    per-chunk outputs are concatenated along the batch (first) dimension.

    Args:
        function: callable mapping a batch to an output batch whose leading
            dimension matches the input chunk size.
        x: indexable batch supporting `len()`, slicing and `.shape`
            (e.g. a tensor or ndarray).
        batch_parallel: maximum number of items processed per call.

    Returns:
        The concatenated outputs, in the same order as `x`.

    Raises:
        ValueError: if `x` is empty (nothing to concatenate; the previous
            code raised an obscure UnboundLocalError here).
    """
    if len(x) == 0:
        raise ValueError("parallel_run() received an empty batch")
    ys = None
    for i in range(math.ceil(len(x) / batch_parallel)):
        # Slicing clamps at len(x), so no explicit end-index arithmetic is needed.
        x0 = x[i * batch_parallel:(i + 1) * batch_parallel]
        print("x0.shape:", x0.shape)
        start = time.time()
        y_a = function(x0)
        print(batch_parallel, " batch: {}s".format(round(time.time() - start, 4)))
        # Accumulate chunk outputs along the batch dimension.
        ys = y_a if ys is None else tl.layers.Concat(concat_dim=0)([ys, y_a])
    return ys
-
def compress_hyper(cubes, model, batch_parallel):
    """Compress cubes to bitstream.
    Input: cubes with shape [batch size, depth, width, height, channel(1)].
    Output: compressed bitstream (latent strings, hyperprior strings, ranges
    and shapes needed by the decoder).
    """
    print('===== Compress =====')
    x = tl.convert_to_tensor(cubes, dtype='float32')

    # Analysis transform, chunked to bound memory use.
    t0 = time.time()
    ys = parallel_run(model.analysis_transform, x, batch_parallel)
    print("{}: {}s".format("Analysis Transform", round(time.time() - t0, 4)))

    # Hyper encoder: a small network, so no chunking is needed here.
    t0 = time.time()
    zs = model.hyper_encoder(ys)
    print("{}: {}s".format("Hyper Encoder", round(time.time() - t0, 4)))

    # Calling the bottleneck instance invokes its call() method; quantizes z.
    z_hats, _ = model.entropy_bottleneck(zs, False)
    print("Quantize hyperprior.")

    # Decode the hyperprior into per-element location/scale parameters.
    t0 = time.time()
    locs, scales = model.hyper_decoder(z_hats)
    scales = tl.ops.Maximum()(scales, 1e-9)  # keep scales strictly positive
    print("{}: {}s".format("Hyper Decoder", round(time.time() - t0, 4)))

    # Range-code the hyperprior itself.
    t0 = time.time()
    z_strings, z_min_v, z_max_v = model.entropy_bottleneck.compress(zs)
    z_shape = zs.shape
    print("{}: {}s".format("Entropy Encode (Hyper)", round(time.time() - t0, 4)))

    # Range-code the latents conditioned on the decoded (locs, scales).
    t0 = time.time()
    y_shape = ys.shape
    y_strings, y_min_vs, y_max_vs = model.conditional_entropy_model.compress(ys, locs, scales)
    print("{}: {}s".format("Entropy Encode", round(time.time() - t0, 4)))

    return y_strings, y_min_vs, y_max_vs, y_shape, z_strings, z_min_v, z_max_v, z_shape
-
# NOTE: updated (translated from Chinese "已改" = "already modified")
def decompress_hyper(y_strings, y_min_vs, y_max_vs, y_shape, z_strings, z_min_v, z_max_v, z_shape, model, batch_parallel):
    """Decompress bitstream to cubes.
    Input: compressed bitstream — latent representations (y) and hyper prior (z),
    with their value ranges and shapes.
    Output: cubes with shape [batch size, length, width, height, channel(1)].
    """
    print('===== Decompress =====')

    # Recover the hyperprior first; it conditions the latent decoder.
    t0 = time.time()
    zs = model.entropy_bottleneck.decompress(z_strings, z_min_v, z_max_v, z_shape)
    print("{}: {}s".format("Entropy Decoder (Hyper)", round(time.time() - t0, 4)))

    # Decode the hyperprior into per-element location/scale parameters.
    t0 = time.time()
    locs, scales = model.hyper_decoder(zs)
    scales = tl.ops.Maximum()(scales, 1e-9)  # keep scales strictly positive
    print("{}: {}s".format("Hyper Decoder", round(time.time() - t0, 4)))

    # Range-decode the latents conditioned on (locs, scales).
    t0 = time.time()
    ys = model.conditional_entropy_model.decompress(y_strings, locs, scales, y_min_vs, y_max_vs, y_shape)
    print("{}: {}s".format("Entropy Decoder", round(time.time() - t0, 4)))

    # Synthesis transform back to voxel cubes, chunked to bound memory use.
    t0 = time.time()
    xs = parallel_run(model.synthesis_transform, ys, batch_parallel)
    print("{}: {}s".format("Synthesis Transform", round(time.time() - t0, 4)))

    return xs
|