|
- # coding: utf-8
- from __future__ import print_function
- import os
- import time
- import random
- #from skimage import color
- from PIL import Image
- import tensorflow as tf
- import numpy as np
- from utils import *
- from model import *
- from glob import glob
-
- from npu_bridge.npu_init import *
-
-
# ---------------------------------------------------------------------------
# Hyper-parameters used when assembling training batches.
# ---------------------------------------------------------------------------
batch_size = 10   # number of patches per optimisation step
patch_size = 48   # side length of each square training crop

# ---------------------------------------------------------------------------
# Session configuration: register the NPU custom optimizer and switch off
# TensorFlow's remapping rewrite pass.  The meaning of the individual
# parameter_map entries is defined by npu_bridge, not by this script.
# ---------------------------------------------------------------------------
config = tf.ConfigProto()
config.graph_options.rewrite_options.remapping = RewriterConfig.OFF
custom_op = config.graph_options.rewrite_options.custom_optimizers.add()
custom_op.name = "NpuOptimizer"
custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes(
    "allow_mix_precision")
# The placeholders below leave batch/height/width unspecified, which is
# presumably why dynamic-input handling is requested here -- TODO confirm.
custom_op.parameter_map["dynamic_input"].b = True
custom_op.parameter_map["dynamic_graph_execute_mode"].s = tf.compat.as_bytes(
    "lazy_recompile")
sess = tf.Session(config=config)

# ---------------------------------------------------------------------------
# Graph inputs.
# ---------------------------------------------------------------------------
# Decomposition net input: 3-channel images of arbitrary batch/height/width.
input_decom = tf.placeholder(
    tf.float32, [None, None, None, 3], name='input_decom')
# Illumination adjustment net inputs: a single-channel source map, a
# single-channel ratio map, and the single-channel target map used by the loss.
input_low_i = tf.placeholder(
    tf.float32, [None, None, None, 1], name='input_low_i')
input_low_i_ratio = tf.placeholder(
    tf.float32, [None, None, None, 1], name='input_low_i_ratio')
input_high_i = tf.placeholder(
    tf.float32, [None, None, None, 1], name='input_high_i')

# ---------------------------------------------------------------------------
# Networks.
# ---------------------------------------------------------------------------
# The decomposition net returns a pair, kept here as (R_decom, I_decom).
R_decom, I_decom = DecomNet_simple(input_decom)
decom_output_R, decom_output_I = R_decom, I_decom
# Output of the illumination adjustment net (the tensor being trained).
output_i = Illumination_adjust_net(input_low_i, input_low_i_ratio)
-
- #define loss
-
def grad_loss(input_i_low, input_i_high):
    """Return the mean squared difference between the gradients of two maps.

    Both arguments are passed to the project's ``gradient`` helper once with
    direction ``'x'`` and once with ``'y'``.  The per-direction differences
    are squared, summed, and averaged over all elements.

    Args:
        input_i_low: first tensor (in this script, the adjusted output).
        input_i_high: second tensor (in this script, the target map).

    Returns:
        A scalar tensor holding the averaged gradient discrepancy.
    """
    delta_x = gradient(input_i_low, 'x') - gradient(input_i_high, 'x')
    delta_y = gradient(input_i_low, 'y') - gradient(input_i_high, 'y')
    return tf.reduce_mean(tf.square(delta_x) + tf.square(delta_y))
-
# Adjustment loss: gradient-consistency term plus plain mean squared error
# between the adjusted output and the target map.
loss_grad = grad_loss(output_i, input_high_i)
loss_square = tf.reduce_mean(tf.square(output_i - input_high_i))
loss_adjust = loss_square + loss_grad

# The learning rate is fed at run time, so it can be changed without
# rebuilding the graph.
lr = tf.placeholder(tf.float32, name='learning_rate')
optimizer = tf.train.AdamOptimizer(learning_rate=lr, name='AdamOptimizer')

# Split the trainable variables by the name fragment of the network that
# owns them.
var_Decom = [
    var for var in tf.trainable_variables() if 'DecomNet' in var.name
]
var_adjust = [
    var for var in tf.trainable_variables()
    if 'Illumination_adjust_net' in var.name
]

# One saver per network; each handles only the variables in its own list.
saver_adjust = tf.train.Saver(var_list=var_adjust)
saver_Decom = tf.train.Saver(var_list=var_Decom)

# Only the adjustment net variables are updated by this training op.
train_op_adjust = optimizer.minimize(loss_adjust, var_list=var_adjust)

sess.run(tf.global_variables_initializer())
print("[*] Initialize model successfully...")
-
# ---------------------------------------------------------------------------
# Load the training images.
# The decomposition net is applied to these images later in the script to
# obtain the illumination maps on which the adjustment net is trained.
# ---------------------------------------------------------------------------
train_low_data = []
train_high_data = []

# Both listings are sorted so that entry i of the "low" list is paired with
# entry i of the "high" list.
train_low_data_names = sorted(glob('./LOLdataset/our485/low/*.png'))
train_high_data_names = sorted(glob('./LOLdataset/our485/high/*.png'))
assert len(train_low_data_names) == len(train_high_data_names)
print('[*] Number of training data: %d' % len(train_low_data_names))

for low_path, high_path in zip(train_low_data_names, train_high_data_names):
    train_low_data.append(load_images(low_path))
    train_high_data.append(load_images(high_path))

# ---------------------------------------------------------------------------
# Restore the decomposition net weights from its own training run.
# NOTE(review): when no checkpoint is found the script only prints a message
# and carries on with the freshly initialised decomposition variables.
# ---------------------------------------------------------------------------
pre_decom_checkpoint_dir = './checkpoint/decom_net_train/'
ckpt_pre = tf.train.get_checkpoint_state(pre_decom_checkpoint_dir)
if not ckpt_pre:
    print('No pre_decom_net checkpoint!')
else:
    print('loaded ' + ckpt_pre.model_checkpoint_path)
    saver_Decom.restore(sess, ckpt_pre.model_checkpoint_path)
-
def _collect_illumination_maps(images):
    """Run the decomposition net on every image and return its I outputs.

    Each image is fed as a batch of one.  Both network outputs are fetched
    and squeezed, their shapes are printed, and only the squeezed
    illumination output is kept.
    """
    illumination_maps = []
    for image in images:
        single_batch = np.expand_dims(image, axis=0)
        reflectance, illumination = sess.run(
            [decom_output_R, decom_output_I],
            feed_dict={input_decom: single_batch})
        reflectance = np.squeeze(reflectance)
        illumination = np.squeeze(illumination)
        print(reflectance.shape, illumination.shape)
        illumination_maps.append(illumination)
    return illumination_maps


# Illumination maps of every low-light image, then of every normal-light
# image (same processing order as the image lists).
decomposed_low_i_data_480 = _collect_illumination_maps(train_low_data)
decomposed_high_i_data_480 = _collect_illumination_maps(train_high_data)

# Hold-out split used for the periodic visual evaluation.
eval_adjust_low_i_data = decomposed_low_i_data_480[451:480]
eval_adjust_high_i_data = decomposed_high_i_data_480[451:480]

# Training split.
# NOTE(review): index 450 falls in neither slice, and anything from index 480
# onwards is unused -- confirm this is intentional.
train_adjust_low_i_data = decomposed_low_i_data_480[0:450]
train_adjust_high_i_data = decomposed_high_i_data_480[0:450]

print('[*] Number of training data: %d' % len(train_adjust_high_i_data))
-
# ---------------------------------------------------------------------------
# Training schedule.
# ---------------------------------------------------------------------------
learning_rate = 0.0001      # value fed to the `lr` placeholder on every step
epoch = 2000                # total number of passes over the training split
eval_every_epoch = 200      # evaluate + checkpoint once per this many epochs
train_phase = 'adjustment'  # label used in the log messages
numBatch = len(train_adjust_low_i_data) // int(batch_size)

# Aliases consumed by the training loop.
train_op, train_loss, saver = train_op_adjust, loss_adjust, saver_adjust

# ---------------------------------------------------------------------------
# Checkpoint directory: create it if needed, then restore the adjustment
# net variables from it when a checkpoint is present.
# ---------------------------------------------------------------------------
checkpoint_dir = './checkpoint/illumination_adjust_net_train/'
if not os.path.isdir(checkpoint_dir):
    os.makedirs(checkpoint_dir)
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if not ckpt:
    print("No adjustment net pre model!")
else:
    print('loaded ' + ckpt.model_checkpoint_path)
    saver.restore(sess, ckpt.model_checkpoint_path)

# Counters always start from zero, including after a checkpoint restore.
start_step = start_epoch = iter_num = 0
print("[*] Start training for phase %s, with start epoch %d start iter %d : " % (train_phase, start_epoch, iter_num))

# Directory that receives the evaluation images.
sample_dir = './illumination_adjust_net_train/'
if not os.path.isdir(sample_dir):
    os.makedirs(sample_dir)

start_time = time.time()
image_id = 0  # index of the next training image to sample a patch from
-
# NOTE(review): the source this loop was recovered from had lost its
# indentation; the nesting below (batch loop inside epoch loop, evaluation
# and checkpointing once per `eval_every_epoch` epochs) was reconstructed
# from the data flow -- confirm against the intended schedule.


def _sample_patch_pair(low_i, high_i):
    """Crop and augment one aligned random patch from a low/high map pair.

    Args:
        low_i: 2-D illumination map of the low-light image.
        high_i: 2-D illumination map of the paired image (same size).

    Returns:
        ``(low_patch, high_patch, ratio)``: both augmented crops, each with a
        trailing channel axis, and the mean of ``low / (high + 0.0001)``
        computed on the un-augmented crops.
    """
    height, width = low_i.shape
    top = random.randint(0, height - patch_size)
    left = random.randint(0, width - patch_size)
    rows = slice(top, top + patch_size)
    cols = slice(left, left + patch_size)
    low_crop = np.expand_dims(low_i, axis=2)[rows, cols, :]
    high_crop = np.expand_dims(high_i, axis=2)[rows, cols, :]

    # The same augmentation mode is applied to both crops so they stay
    # aligned.  NOTE(review): the upper bound of np.random.randint is
    # exclusive, so only modes 0..6 are ever drawn -- confirm that is intended.
    aug_mode = np.random.randint(0, 7)
    low_patch = data_augmentation(low_crop, aug_mode)
    high_patch = data_augmentation(high_crop, aug_mode)

    ratio = np.mean(low_crop / (high_crop + 0.0001))
    return low_patch, high_patch, ratio


# A dedicated loop variable keeps the `epoch` setting (total epoch count)
# from being overwritten while iterating.
for epoch_id in range(start_epoch, epoch):
    for batch_id in range(start_step, numBatch):
        batch_shape = (batch_size, patch_size, patch_size, 1)
        input_low_i_rand = np.zeros(batch_shape, dtype="float32")
        input_high_i_rand = np.zeros(batch_shape, dtype="float32")
        input_low_i_rand_ratio = np.zeros(batch_shape, dtype="float32")

        for patch_id in range(batch_size):
            low_patch, high_patch, ratio = _sample_patch_pair(
                train_adjust_low_i_data[image_id],
                train_adjust_high_i_data[image_id])

            # Randomly pick the adjustment direction: low -> high, or the
            # reverse (high -> low) with the ratio inverted accordingly.
            if np.random.randint(0, 2) == 1:
                source, target = low_patch, high_patch
                # NOTE(review): the epsilon is added after the division, so a
                # zero `ratio` is not guarded against; `1 / (ratio + 0.0001)`
                # may have been intended.  Kept as-is to preserve training
                # behaviour.
                source_ratio = 1 / ratio + 0.0001
            else:
                source, target = high_patch, low_patch
                source_ratio = ratio

            input_low_i_rand[patch_id, :, :, :] = source
            input_high_i_rand[patch_id, :, :, :] = target
            # Scalar broadcast fills the whole patch with a constant ratio.
            input_low_i_rand_ratio[patch_id, :, :, :] = source_ratio

            # Cycle through the training images in order, wrapping around.
            image_id = (image_id + 1) % len(train_adjust_low_i_data)

        _, loss = sess.run(
            [train_op, train_loss],
            feed_dict={input_low_i: input_low_i_rand,
                       input_low_i_ratio: input_low_i_rand_ratio,
                       input_high_i: input_high_i_rand,
                       lr: learning_rate})
        print("%s Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.6f"
              % (train_phase, epoch_id + 1, batch_id + 1, numBatch,
                 time.time() - start_time, loss))
        iter_num += 1

    if (epoch_id + 1) % eval_every_epoch == 0:
        print("[*] Evaluating for phase %s / epoch %d..." % (train_phase, epoch_id + 1))

        # Evaluate at most 10 hold-out maps; the bound protects against an
        # IndexError when the evaluation split is smaller than that.
        for idx in range(min(10, len(eval_adjust_low_i_data))):
            input_uu_i = eval_adjust_low_i_data[idx]
            h, w = input_uu_i.shape
            # (h, w) -> (1, h, w, 1), matching the placeholder rank.
            input_low_eval_ii = np.expand_dims(
                np.expand_dims(input_uu_i, axis=0), axis=3)

            # Random ratio in [0, 2).  Extract a Python float so the '%5f'
            # formatting below does not depend on the deprecated implicit
            # conversion of a 1-element array to a scalar.
            rand_ratio = float(np.random.random(1)[0] * 2)
            input_low_eval_ii_ratio = np.expand_dims(
                np.expand_dims(np.ones([h, w]) * rand_ratio, axis=0), axis=3)

            result_1 = sess.run(
                output_i,
                feed_dict={input_low_i: input_low_eval_ii,
                           input_low_i_ratio: input_low_eval_ii_ratio})
            save_images(
                os.path.join(sample_dir,
                             'h_eval_%d_%d_%5f.png' % (epoch_id + 1, idx + 1, rand_ratio)),
                input_uu_i, result_1)

        # Checkpoint the adjustment net after each evaluation round.
        saver.save(sess, checkpoint_dir + 'model.ckpt')

print("[*] Finish training for phase %s." % train_phase)
-
-
|