|
- import tensorflow as tf
- import tensorflow.contrib.slim as slim
- from tensorflow.contrib.layers.python.layers import initializers
-
def lrelu(x, trainbable=None):
    """Leaky ReLU with a fixed negative slope of 0.2.

    Computes ``max(0.2 * x, x)`` elementwise, i.e. ``x`` for non-negative
    inputs and ``0.2 * x`` for negative ones.

    Args:
        x: input tensor.
        trainbable: unused. Kept only so existing call sites keep working
            (the name looks like a typo of "trainable").

    Returns:
        Tensor with the same shape as ``x``.
    """
    negative_slope = 0.2
    return tf.maximum(x * negative_slope, x)
-
def upsample_and_concat(x1, x2, output_channels, in_channels, scope_name, trainable=True):
    """Upsample ``x1`` 2x with a learned transposed conv, then concat with ``x2``.

    Args:
        x1: feature map to upsample; its channel count must equal
            ``in_channels`` (the filter's last dimension).
        x2: skip-connection feature map. Its dynamic shape is used as the
            transposed-convolution output shape, so its channel count must
            equal ``output_channels``.
        output_channels: channels produced by the transposed convolution.
        in_channels: channels of ``x1``.
        scope_name: variable scope (opened with ``tf.AUTO_REUSE``) holding the
            ``weights`` variable; also used as the deconvolution op name.
        trainable: whether a newly created deconvolution filter is trainable.
            Previously this argument was ignored and the filter was always
            trainable. Note that when the variable already exists and is
            reused, ``get_variable`` returns it unchanged.

    Returns:
        Tensor ``concat([upsampled x1, x2], axis=3)`` whose static channel
        dimension is set to ``output_channels * 2``.
    """
    with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE):
        pool_size = 2
        # conv2d_transpose filters are laid out as
        # [height, width, output_channels, input_channels].
        deconv_filter = tf.get_variable(
            'weights',
            [pool_size, pool_size, output_channels, in_channels],
            trainable=trainable)
        # Using x2's runtime shape as the output shape lets the upsampled map
        # match the skip connection exactly, even for odd spatial sizes.
        deconv = tf.nn.conv2d_transpose(
            x1, deconv_filter, tf.shape(x2),
            strides=[1, pool_size, pool_size, 1], name=scope_name)

        deconv_output = tf.concat([deconv, x2], 3)
        # tf.shape(x2) is dynamic, so restore the static channel count for
        # downstream layers (e.g. slim.conv2d) that need it.
        deconv_output.set_shape([None, None, None, output_channels * 2])

        return deconv_output
-
def DecomNet_simple(input):
    """Decompose an image into two sigmoid-bounded maps with a small U-Net.

    A two-level encoder/decoder trunk (32/64/128 channels) feeds two heads:
    a 3-channel head (``R_out``) and a 1-channel head (``L_out``). The names
    presumably mean reflectance and illumination, as in a Retinex-style
    decomposition; TODO confirm against the paper/caller.

    Args:
        input: 4-D tensor, channels last (features are concatenated on
            axis 3). The name shadows the builtin but is kept for callers.

    Returns:
        Tuple ``(R_out, L_out)``: 3-channel and 1-channel tensors in (0, 1).
    """
    def _conv(x, channels, kernel, scope, act=lrelu):
        # Every conv in this net shares rate=1; only the activation varies.
        return slim.conv2d(x, channels, kernel, rate=1, activation_fn=act, scope=scope)

    with tf.variable_scope('DecomNet', reuse=tf.AUTO_REUSE):
        # --- shared U-Net trunk ---
        enc1 = _conv(input, 32, [3, 3], 'g_conv1_1')
        down1 = slim.max_pool2d(enc1, [2, 2], stride=2, padding='SAME')
        enc2 = _conv(down1, 64, [3, 3], 'g_conv2_1')
        down2 = slim.max_pool2d(enc2, [2, 2], stride=2, padding='SAME')
        bottleneck = _conv(down2, 128, [3, 3], 'g_conv3_1')
        dec2 = _conv(upsample_and_concat(bottleneck, enc2, 64, 128, 'g_up_1'),
                     64, [3, 3], 'g_conv8_1')
        dec1 = _conv(upsample_and_concat(dec2, enc1, 32, 64, 'g_up_2'),
                     32, [3, 3], 'g_conv9_1')

        # --- 3-channel head ---
        # Here, we use 1*1 kernel to replace the 3*3 ones in the paper to get better results.
        r_logits = _conv(dec1, 3, [1, 1], 'g_conv10', act=None)
        R_out = tf.sigmoid(r_logits)

        # --- 1-channel head: a side branch from the first encoder level,
        # fused with the decoder output ---
        side = _conv(enc1, 32, [3, 3], 'l_conv1_2')
        fused = tf.concat([side, dec1], 3)
        # Here, we use 1*1 kernel to replace the 3*3 ones in the paper to get better results.
        l_logits = _conv(fused, 1, [1, 1], 'l_conv1_4', act=None)
        L_out = tf.sigmoid(l_logits)

        return R_out, L_out
-
def Restoration_net(input_r, input_i):
    """Restore a 3-channel map from two inputs with a 4-level U-Net.

    ``input_r`` and ``input_i`` are concatenated on axis 3 (channels last).
    The names suggest reflectance and illumination inputs; confirm with the
    caller. The encoder uses 32/64/128/256 channels and then a 512-channel
    bottleneck. The decoder mirrors it through ``upsample_and_concat`` skip
    connections. Each stage is two 3x3 lrelu convs.

    Returns:
        3-channel tensor in (0, 1) (sigmoid of the final 3x3 conv).
    """
    def _double_conv(x, channels, stage):
        # Two 3x3 lrelu convs named de_conv<stage>_1 / de_conv<stage>_2.
        x = slim.conv2d(x, channels, [3, 3], rate=1, activation_fn=lrelu,
                        scope='de_conv%d_1' % stage)
        return slim.conv2d(x, channels, [3, 3], rate=1, activation_fn=lrelu,
                           scope='de_conv%d_2' % stage)

    with tf.variable_scope('Restoration_net', reuse=tf.AUTO_REUSE):
        net = tf.concat([input_r, input_i], 3)

        # Encoder: stages 1-4, each followed by a 2x2 max-pool (slim's default
        # stride is 2). Pre-pool features are kept as skip connections.
        skips = []
        for stage, channels in enumerate((32, 64, 128, 256), start=1):
            features = _double_conv(net, channels, stage)
            skips.append(features)
            net = slim.max_pool2d(features, [2, 2], padding='SAME')

        # Bottleneck (stage 5).
        net = _double_conv(net, 512, 5)

        # Decoder: stages 6-9 consume the skips deepest-first, so the
        # upsample scopes are up_6..up_9 with input channels = 2 * channels.
        for stage, channels in zip(range(6, 10), (256, 128, 64, 32)):
            net = upsample_and_concat(net, skips.pop(), channels, channels * 2,
                                      'up_%d' % stage)
            net = _double_conv(net, channels, stage)

        logits = slim.conv2d(net, 3, [3, 3], rate=1, activation_fn=None, scope='de_conv10')

        return tf.sigmoid(logits)
-
def Illumination_adjust_net(input_i, input_ratio):
    """Predict an adjusted 1-channel map from ``input_i`` and ``input_ratio``.

    The two inputs are concatenated on axis 3 (channels last). They then pass
    through three 32-channel 3x3 lrelu convs and one 1-channel 3x3 conv. A
    sigmoid bounds the result to (0, 1). The names suggest an illumination
    map plus a target brightness ratio; TODO confirm with the caller.

    Returns:
        1-channel tensor in (0, 1).
    """
    with tf.variable_scope('Illumination_adjust_net', reuse=tf.AUTO_REUSE):
        features = tf.concat([input_i, input_ratio], 3)

        for idx in (1, 2, 3):
            features = slim.conv2d(features, 32, [3, 3], rate=1, activation_fn=lrelu,
                                   scope='en_conv_%d' % idx)

        # NOTE(review): unlike the other output heads, this conv keeps the
        # lrelu activation before the sigmoid. It is left as-is because
        # changing it would alter outputs of already-trained weights.
        features = slim.conv2d(features, 1, [3, 3], rate=1, activation_fn=lrelu, scope='en_conv_4')

        return tf.sigmoid(features)
|