You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

58 lines
2.3 kB

  1. from tensorflow.python.ops.init_ops import Initializer,_compute_fans
  2. from numpy.random import RandomState
  3. import numpy as np
  4. import tensorflow as tf
  5. class ComplexInit(Initializer):
  6. # The standard complex initialization using
  7. # either the He or the Glorot criterion.
  8. def __init__(self, kernel_size, input_dim,
  9. weight_dim, nb_filters=None,
  10. criterion='glorot', seed=None):
  11. # `weight_dim` is used as a parameter for sanity check
  12. # as we should not pass an integer as kernel_size when
  13. # the weight dimension is >= 2.
  14. # nb_filters == 0 if weights are not convolutional (matrix instead of filters)
  15. # then in such a case, weight_dim = 2.
  16. # (in case of 2D input):
  17. # nb_filters == None and len(kernel_size) == 2 and_weight_dim == 2
  18. # conv1D: len(kernel_size) == 1 and weight_dim == 1
  19. # conv2D: len(kernel_size) == 2 and weight_dim == 2
  20. # conv3d: len(kernel_size) == 3 and weight_dim == 3
  21. assert len(kernel_size) == weight_dim and weight_dim in {0, 1, 2, 3}
  22. self.nb_filters = nb_filters
  23. self.kernel_size = kernel_size
  24. self.input_dim = input_dim
  25. self.weight_dim = weight_dim
  26. self.criterion = criterion
  27. self.seed = 1337 if seed is None else seed
  28. def __call__(self, shape, dtype=None, partition_info=None):
  29. if self.nb_filters is not None:
  30. kernel_shape = tuple(self.kernel_size) + (int(self.input_dim), self.nb_filters)
  31. else:
  32. kernel_shape = (int(self.input_dim), self.kernel_size[-1])
  33. fan_in, fan_out = _compute_fans(
  34. tuple(self.kernel_size) + (self.input_dim, self.nb_filters)
  35. )
  36. if self.criterion == 'glorot':
  37. s = 1. / (fan_in + fan_out)
  38. elif self.criterion == 'he':
  39. s = 1. / fan_in
  40. else:
  41. raise ValueError('Invalid criterion: ' + self.criterion)
  42. rng = RandomState(self.seed)
  43. modulus = rng.rayleigh(scale=s, size=kernel_shape)
  44. phase = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape)
  45. weight_real = modulus * np.cos(phase)
  46. weight_imag = modulus * np.sin(phase)
  47. weight = np.concatenate([weight_real, weight_imag], axis=-1)
  48. return weight

简介

No Description

Python

贡献者 (1)