|
- import tensorflow as tf
- import numpy as np
- import os
- import matplotlib.pyplot as plt
-
-
class mine_layer(tf.keras.Model):
    """Small dense network that maps a pair of inputs to one value per sample.

    Both inputs are flattened, concatenated along axis 1, and passed through
    two 32-unit dense layers (each followed by an ELU activation) and a final
    single-unit dense layer.
    """

    def __init__(self):
        super().__init__()
        self.fl_1 = tf.keras.layers.Flatten()

        # Each hidden layer gets its own initializer instance.
        hidden1_init = tf.keras.initializers.TruncatedNormal(mean=0., stddev=0.02)
        hidden2_init = tf.keras.initializers.TruncatedNormal(mean=0., stddev=0.02)
        self.mul1 = tf.keras.layers.Dense(
            32,
            kernel_initializer=hidden1_init,
            bias_initializer='zeros',
        )
        self.mul2 = tf.keras.layers.Dense(
            32,
            kernel_initializer=hidden2_init,
            bias_initializer='zeros',
        )

        # Instantiated here but not applied anywhere in call().
        self.batch = tf.keras.layers.BatchNormalization()

        self.A1 = tf.keras.layers.ELU(1)
        self.A2 = tf.keras.layers.ELU(1)
        self.D3 = tf.keras.layers.Dense(1)

    def call(self, x, y):
        """Run the forward pass on the input pair.

        Args:
            x: First input; flattened to 2-D before use.
            y: Second input; flattened to 2-D before use.

        Returns:
            Output of the final single-unit dense layer.
        """
        flat_x = self.fl_1(x)
        flat_y = self.fl_1(y)
        # Join the two flattened inputs feature-wise, then flatten once more
        # exactly as the layer stack has always done.
        joined = self.fl_1(tf.concat([flat_x, flat_y], 1))

        hidden = self.A1(self.mul1(joined))
        hidden = self.A2(self.mul2(hidden))
        return self.D3(hidden)
-
-
-
class MINE(tf.keras.Model):
    """Bank of independent ``mine_layer`` networks, one per user.

    Args:
        user_num: Number of ``mine_layer`` sub-models to create.
        M: Slice parameter; sub-model ``i`` receives columns
            ``(M + 1) * i`` up to ``(M + 1) * (i + 1)`` of ``x``.
    """

    def __init__(self, user_num, M):
        super().__init__()
        # Build the sub-models directly. The previous implementation stored
        # them through ``locals()[name] = ...`` and read them back with a
        # second ``locals()`` call; on Python 3.13+ every ``locals()`` call in
        # a function returns an independent snapshot, so that read-back
        # raises KeyError.
        self.MINE_seq = [mine_layer() for _ in range(user_num)]
        self.M = M
        # One placeholder slot per user, initially None. Nothing in this
        # class writes to it.
        self.param = [None] * user_num

    def call(self, x, y):
        """Apply each per-user sub-model to its slice of the inputs.

        Args:
            x: Input indexed as ``x[:, a:b]``; sub-model ``i`` receives the
                ``M + 1`` columns starting at ``(M + 1) * i``.
            y: Input indexed as ``y[:, i, :, :]``, i.e. a 4-D value whose
                axis 1 selects the sub-model.

        Returns:
            A list with one output per sub-model, in ``MINE_seq`` order.
        """
        width = self.M + 1
        out = []
        for i, layer in enumerate(self.MINE_seq):
            start = width * i
            x_i = x[:, start:start + width]
            # Cast with TensorFlow ops instead of round-tripping through
            # NumPy so this also works on symbolic tensors (tf.function,
            # fit/predict). stop_gradient keeps the earlier semantics: the
            # NumPy conversion turned the y slice into a constant, so no
            # gradient ever flowed back into y.
            y_i = tf.stop_gradient(tf.cast(y[:, i, :, :], tf.float32))
            out.append(layer(x_i, y_i))
        return out
|