You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

#### 58 lines 2.3 kB Raw Permalink Blame History

 from tensorflow.python.ops.init_ops import Initializer,_compute_fans from numpy.random import RandomState import numpy as np import tensorflow as tf class ComplexInit(Initializer): # The standard complex initialization using # either the He or the Glorot criterion. def __init__(self, kernel_size, input_dim, weight_dim, nb_filters=None, criterion='glorot', seed=None): # weight_dim is used as a parameter for sanity check # as we should not pass an integer as kernel_size when # the weight dimension is >= 2. # nb_filters == 0 if weights are not convolutional (matrix instead of filters) # then in such a case, weight_dim = 2. # (in case of 2D input): # nb_filters == None and len(kernel_size) == 2 and_weight_dim == 2 # conv1D: len(kernel_size) == 1 and weight_dim == 1 # conv2D: len(kernel_size) == 2 and weight_dim == 2 # conv3d: len(kernel_size) == 3 and weight_dim == 3 assert len(kernel_size) == weight_dim and weight_dim in {0, 1, 2, 3} self.nb_filters = nb_filters self.kernel_size = kernel_size self.input_dim = input_dim self.weight_dim = weight_dim self.criterion = criterion self.seed = 1337 if seed is None else seed def __call__(self, shape, dtype=None, partition_info=None): if self.nb_filters is not None: kernel_shape = tuple(self.kernel_size) + (int(self.input_dim), self.nb_filters) else: kernel_shape = (int(self.input_dim), self.kernel_size[-1]) fan_in, fan_out = _compute_fans( tuple(self.kernel_size) + (self.input_dim, self.nb_filters) ) if self.criterion == 'glorot': s = 1. / (fan_in + fan_out) elif self.criterion == 'he': s = 1. / fan_in else: raise ValueError('Invalid criterion: ' + self.criterion) rng = RandomState(self.seed) modulus = rng.rayleigh(scale=s, size=kernel_shape) phase = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape) weight_real = modulus * np.cos(phase) weight_imag = modulus * np.sin(phase) weight = np.concatenate([weight_real, weight_imag], axis=-1) return weight 

No Description

Python