|
- # coding: utf-8
- from __future__ import print_function
- import os
- import time
- import random
- from PIL import Image
- import tensorflow as tf
- import numpy as np
- from utils import *
- from model import *
- from glob import glob
-
# A single session is shared by checkpoint restoring and by every inference
# call made later in this script.
sess = tf.Session()


def _image_placeholder(name, channels):
    """Return a 4-D float32 placeholder whose last dimension is `channels`.

    The three leading dimensions are left as None, so any batch size and
    spatial size can be fed.
    """
    return tf.placeholder(tf.float32, [None, None, None, channels], name=name)


def _trainable_with(tag):
    """Return the trainable variables whose name contains the substring `tag`."""
    return [v for v in tf.trainable_variables() if tag in v.name]


# Graph inputs. `input_high_r` / `input_high_i` are declared but are not wired
# into any of the networks built below.
input_decom = _image_placeholder('input_decom', 3)
input_low_r = _image_placeholder('input_low_r', 3)
input_low_i = _image_placeholder('input_low_i', 1)
input_high_r = _image_placeholder('input_high_r', 3)
input_high_i = _image_placeholder('input_high_i', 1)
input_low_i_ratio = _image_placeholder('input_low_i_ratio', 1)

# Decomposition network: one input yields a 2-element result, unpacked as
# (R, I). NOTE(review): presumably reflectance / illumination -- confirm in model.py.
R_decom, I_decom = DecomNet_simple(input_decom)
decom_output_R, decom_output_I = R_decom, I_decom

# Restoration works on the decomposed low-light (R, I) pair; illumination
# adjustment works on the low-light I map plus a ratio map.
output_r = Restoration_net(input_low_r, input_low_i)
output_i = Illumination_adjust_net(input_low_i, input_low_i_ratio)

# Variables are grouped by a substring of their name so that each sub-network
# can be restored from its own checkpoint directory.
var_Decom = _trainable_with('DecomNet')
var_adjust = _trainable_with('Illumination_adjust_net')
var_restoration = _trainable_with('Restoration_net')

saver_Decom = tf.train.Saver(var_list=var_Decom)
saver_adjust = tf.train.Saver(var_list=var_adjust)
saver_restoration = tf.train.Saver(var_list=var_restoration)
-
def _restore_checkpoint(saver, checkpoint_dir, missing_message):
    """Restore the latest checkpoint found in `checkpoint_dir` into `sess`.

    Args:
        saver: the `tf.train.Saver` covering the variables to restore.
        checkpoint_dir: directory that is searched for a checkpoint state.
        missing_message: text printed when no checkpoint state is found.

    Returns:
        The value returned by `tf.train.get_checkpoint_state`; it is falsy
        when nothing was restored.
    """
    state = tf.train.get_checkpoint_state(checkpoint_dir)
    if state:
        print('loaded ' + state.model_checkpoint_path)
        saver.restore(sess, state.model_checkpoint_path)
    else:
        # Best effort: only report the problem and carry on.
        # NOTE(review): this script never runs a variable initializer, so the
        # affected network presumably fails on its first sess.run -- confirm.
        print(missing_message)
    return state


decom_checkpoint_dir = './checkpoint/decom_net_train/'
ckpt_pre = _restore_checkpoint(
    saver_Decom, decom_checkpoint_dir, 'No decomnet checkpoint!')

checkpoint_dir_adjust = './checkpoint/illumination_adjust_net_train/'
ckpt_adjust = _restore_checkpoint(
    saver_adjust, checkpoint_dir_adjust, "No adjust pre model!")

checkpoint_dir_restoration = './checkpoint/Restoration_net_train/'
ckpt = _restore_checkpoint(
    saver_restoration, checkpoint_dir_restoration, "No restoration pre model!")
-
def image_stem(path):
    """Return the file name of `path` without its directory and last extension.

    Only the final extension is removed, so 'a.b.png' yields 'a.b'. Cutting at
    the first dot would map 'a.b.png' and 'a.c.png' to the same stem and make
    their result files overwrite each other.
    """
    return os.path.splitext(os.path.basename(path))[0]


### Load eval data
# Low-light inputs, visited in sorted path order. The stem of every file is
# kept so that the corresponding result can be named after it.
eval_low_data_name = sorted(glob('./LOLdataset/eval15/low/*.png'))
eval_img_name = []
eval_low_data = []
for low_path in eval_low_data_name:
    eval_img_name.append(image_stem(low_path))
    eval_low_im = load_images(low_path)
    eval_low_data.append(eval_low_im)
    print(eval_low_im.shape)

# To get better results, the illumination adjustment ratio is computed from
# the decomposed illumination of the high-light image, so the high-light data
# is needed as well.
# NOTE(review): low and high images are paired purely by their position in the
# sorted path lists -- this assumes both folders hold matching file names.
eval_high_data_name = sorted(glob('./LOLdataset/eval15/high/*.png'))
eval_high_data = [load_images(high_path) for high_path in eval_high_data_name]
-
# Directory that receives one enhanced image per evaluated input.
sample_dir = './results/LOLdataset_eval15/'
if not os.path.isdir(sample_dir):
    os.makedirs(sample_dir)

# Every low-light image needs a high-light reference at the same index. Fail
# up front with a clear message instead of an IndexError in the middle of the
# run. Surplus high-light images are tolerated and simply left unused.
if len(eval_high_data) < len(eval_low_data):
    raise ValueError(
        'Found %d low-light images but only %d high-light references'
        % (len(eval_low_data), len(eval_high_data)))

print("Start evalating!")
start_time = time.time()
for idx, (name, input_low, input_high) in enumerate(
        zip(eval_img_name, eval_low_data, eval_high_data)):
    print(idx)
    h, w, _ = input_low.shape

    # The placeholders are 4-D, so prepend a batch axis of size 1.
    input_low_eval = np.expand_dims(input_low, axis=0)
    input_high_eval = np.expand_dims(input_high, axis=0)

    # Decompose both images. Only the second decomposition output of the
    # high-light image is used below, so only that tensor is fetched.
    decom_r_low, decom_i_low = sess.run(
        [decom_output_R, decom_output_I],
        feed_dict={input_decom: input_low_eval})
    decom_i_high = sess.run(
        decom_output_I, feed_dict={input_decom: input_high_eval})

    restoration_r = sess.run(
        output_r,
        feed_dict={input_low_r: decom_r_low, input_low_i: decom_i_low})

    # Scalar adjustment ratio: mean of the element-wise quotient high / low.
    # The small constant keeps the division away from zero.
    ratio = np.mean(decom_i_high / (decom_i_low + 0.0001))

    # Broadcast the scalar into a (1, h, w, 1) map for `input_low_i_ratio`.
    i_low_ratio_expand2 = np.full((1, h, w, 1), ratio, dtype=np.float32)

    adjust_i = sess.run(
        output_i,
        feed_dict={input_low_i: decom_i_low,
                   input_low_i_ratio: i_low_ratio_expand2})

    # Final result: restored output multiplied element-wise by the adjusted
    # illumination map.
    fusion = restoration_r * adjust_i
    save_images(os.path.join(sample_dir, '%s_kindle.png' % (name)), fusion)

print('Evaluation finished in %.2f s' % (time.time() - start_time))
-
|