|
- from tensorflow.python.ops.init_ops import Initializer,_compute_fans
- from numpy.random import RandomState
- import numpy as np
- import tensorflow as tf
-
class ComplexInit(Initializer):
    """Complex weight initializer using either the Glorot or the He criterion.

    Each complex weight is drawn in polar form: a Rayleigh-distributed
    modulus and a phase uniform in [-pi, pi).  The real and imaginary parts
    are then concatenated along the last axis.

    `weight_dim` is a sanity-check parameter: `kernel_size` must be a
    sequence whose length equals `weight_dim` (an integer `kernel_size` is
    rejected because it has no length).  Expected combinations:

    * matrix (non-convolutional) weights for 2D input:
      `nb_filters is None`, `len(kernel_size) == 2`, `weight_dim == 2`
    * conv1D: `len(kernel_size) == 1`, `weight_dim == 1`
    * conv2D: `len(kernel_size) == 2`, `weight_dim == 2`
    * conv3D: `len(kernel_size) == 3`, `weight_dim == 3`

    Args:
        kernel_size: Sequence of kernel dimensions.
        input_dim: Number of input units/channels (converted with `int`).
        weight_dim: Expected length of `kernel_size`; one of 0, 1, 2, 3.
        nb_filters: Number of filters, or `None` for matrix weights.
        criterion: `'glorot'` or `'he'`.
        seed: Seed for the NumPy random state; 1337 is used when `None`.

    Raises:
        ValueError: If `kernel_size` and `weight_dim` are inconsistent
            (at construction), or if `criterion` is unknown (at call time).
    """

    def __init__(self, kernel_size, input_dim,
                 weight_dim, nb_filters=None,
                 criterion='glorot', seed=None):
        # Raise explicitly instead of using `assert`, so the check is not
        # stripped when Python runs with optimizations enabled (-O).
        if not (len(kernel_size) == weight_dim and weight_dim in {0, 1, 2, 3}):
            raise ValueError(
                'kernel_size must have length weight_dim, with weight_dim '
                'in {{0, 1, 2, 3}}; got kernel_size={!r}, weight_dim={!r}'
                .format(kernel_size, weight_dim)
            )
        self.nb_filters = nb_filters
        self.kernel_size = kernel_size
        self.input_dim = input_dim
        self.weight_dim = weight_dim
        self.criterion = criterion
        self.seed = 1337 if seed is None else seed

    def __call__(self, shape, dtype=None, partition_info=None):
        """Return the initial weights as a NumPy array.

        The `shape`, `dtype` and `partition_info` arguments are accepted for
        interface compatibility but are not used: the kernel shape is derived
        from the constructor arguments.  A fresh `RandomState(self.seed)` is
        created on every call, so repeated calls return identical values.

        Returns:
            Array shaped like the kernel with its last axis doubled: the
            real parts followed by the imaginary parts.

        Raises:
            ValueError: If `self.criterion` is neither 'glorot' nor 'he'.
        """
        if self.nb_filters is not None:
            # Convolutional kernel: spatial dims + (in_channels, filters).
            kernel_shape = (
                tuple(self.kernel_size) + (int(self.input_dim), self.nb_filters)
            )
        else:
            # Matrix weights: (input units, output units).
            kernel_shape = (int(self.input_dim), self.kernel_size[-1])

        # Compute the fans from the shape that is actually sampled.  Using
        # `kernel_size + (input_dim, nb_filters)` unconditionally would put
        # `None` into the shape whenever `nb_filters` is None.
        fan_in, fan_out = _compute_fans(kernel_shape)

        if self.criterion == 'glorot':
            s = 1. / (fan_in + fan_out)
        elif self.criterion == 'he':
            s = 1. / fan_in
        else:
            raise ValueError('Invalid criterion: ' + self.criterion)

        # NOTE(review): `s` is passed directly as the Rayleigh scale; confirm
        # whether a square root (1/sqrt(fan)) was intended here.
        rng = RandomState(self.seed)
        modulus = rng.rayleigh(scale=s, size=kernel_shape)
        phase = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape)

        # Polar -> Cartesian, then stack real and imaginary parts together.
        weight_real = modulus * np.cos(phase)
        weight_imag = modulus * np.sin(phase)
        weight = np.concatenate([weight_real, weight_imag], axis=-1)

        return weight
|