#875 [WIP] Add distributions

Open
Erpim wants to merge 1 commit from erpim_0207_ into master
  1. mindtorch/torch/_C/__init__.py (+38 -0)
  2. mindtorch/torch/_six.py (+5 -0)
  3. mindtorch/torch/distributions/__init__.py (+96 -0)
  4. mindtorch/torch/distributions/bernoulli.py (+104 -0)
  5. mindtorch/torch/distributions/beta.py (+81 -0)
  6. mindtorch/torch/distributions/binomial.py (+116 -0)
  7. mindtorch/torch/distributions/categorical.py (+107 -0)
  8. mindtorch/torch/distributions/cauchy.py (+65 -0)
  9. mindtorch/torch/distributions/chi2.py (+17 -0)
  10. mindtorch/torch/distributions/constraint_registry.py (+168 -0)
  11. mindtorch/torch/distributions/constraints.py (+436 -0)
  12. mindtorch/torch/distributions/continuous_bernoulli.py (+175 -0)
  13. mindtorch/torch/distributions/dirichlet.py (+95 -0)
  14. mindtorch/torch/distributions/distribution.py (+166 -0)
  15. mindtorch/torch/distributions/exp_family.py (+32 -0)
  16. mindtorch/torch/distributions/exponential.py (+74 -0)
  17. mindtorch/torch/distributions/fishersnedecor.py (+77 -0)
  18. mindtorch/torch/distributions/gamma.py (+76 -0)
  19. mindtorch/torch/distributions/geometric.py (+97 -0)
  20. mindtorch/torch/distributions/gumbel.py (+57 -0)
  21. mindtorch/torch/distributions/half_cauchy.py (+58 -0)
  22. mindtorch/torch/distributions/half_normal.py (+57 -0)
  23. mindtorch/torch/distributions/independent.py (+83 -0)
  24. mindtorch/torch/distributions/kl.py (+784 -0)
  25. mindtorch/torch/distributions/kumaraswamy.py (+58 -0)
  26. mindtorch/torch/distributions/laplace.py (+77 -0)
  27. mindtorch/torch/distributions/lkj_cholesky.py (+71 -0)
  28. mindtorch/torch/distributions/log_normal.py (+41 -0)
  29. mindtorch/torch/distributions/logistic_normal.py (+30 -0)
  30. mindtorch/torch/distributions/lowrank_multivariate_normal.py (+156 -0)
  31. mindtorch/torch/distributions/mixture_same_family.py (+153 -0)
  32. mindtorch/torch/distributions/multinomial.py (+95 -0)
  33. mindtorch/torch/distributions/multivariate_normal.py (+171 -0)
  34. mindtorch/torch/distributions/negative_binomial.py (+95 -0)
  35. mindtorch/torch/distributions/normal.py (+86 -0)
  36. mindtorch/torch/distributions/one_hot_categorical.py (+92 -0)
  37. mindtorch/torch/distributions/pareto.py (+44 -0)
  38. mindtorch/torch/distributions/poisson.py (+59 -0)
  39. mindtorch/torch/distributions/relaxed_bernoulli.py (+103 -0)
  40. mindtorch/torch/distributions/relaxed_categorical.py (+93 -0)
  41. mindtorch/torch/distributions/studentT.py (+74 -0)
  42. mindtorch/torch/distributions/transformed_distribution.py (+128 -0)
  43. mindtorch/torch/distributions/transforms.py (+956 -0)
  44. mindtorch/torch/distributions/uniform.py (+80 -0)
  45. mindtorch/torch/distributions/utils.py (+108 -0)
  46. mindtorch/torch/distributions/von_mises.py (+115 -0)
  47. mindtorch/torch/distributions/weibull.py (+53 -0)
  48. mindtorch/torch/distributions/wishart.py (+264 -0)
  49. mindtorch/torchvision/transforms/autoaugment.py (+2 -1)

mindtorch/torch/_C/__init__.py (+38 -0)

@@ -1,4 +1,5 @@
import mindspore as ms
from mindspore.ops.primitive import _primexpr
from mindtorch.torch._C.Generator import *
from mindtorch.torch._C.Size import *
from mindtorch.utils import unsupported_attr
@@ -26,3 +27,40 @@ contiguous_format = memory_format("contiguous_format")
channels_last = memory_format("channels_last")
channels_last_3d = memory_format("channels_last_3d")
preserve_format = memory_format("preserve_format")

def _get_tracing_state():
# Currently, jit.trace is not supported, so the return value is always False.
return False


@_primexpr
def _infer_size(a_size, b_size):
# If a_size and b_size can broadcast to each other, return the shape after broadcasting,
# otherwise raise error.
exchange_flag = False
if len(a_size) >= len(b_size):
x_shape = a_size
y_shape = b_size
else:
exchange_flag = True
x_shape = b_size
y_shape = a_size

prefix_shape = x_shape[:len(x_shape) - len(y_shape)]
res_shape = ()

for i in range(-1, -(len(y_shape) + 1), -1):
if x_shape[i] == y_shape[i]:
res_shape = (x_shape[i],) + res_shape
elif x_shape[i] == 1:
res_shape = (y_shape[i],) + res_shape
elif y_shape[i] == 1:
res_shape = (x_shape[i],) + res_shape
elif exchange_flag:
raise RuntimeError(f"The size of tehsor a ({y_shape[i]}) must match the size of tensor b " \
f"({x_shape[i]}) at non-singleton dimension {len(x_shape) + i}")
else:
raise RuntimeError(f"The size of tehsor a ({x_shape[i]}) must match the size of tensor b " \
f"({y_shape[i]}) at non-singleton dimension {len(x_shape) + i}")

return prefix_shape + res_shape
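
For reference, the broadcasting rule implemented by `_infer_size` above is the standard NumPy/PyTorch one. A minimal pure-Python sketch of the same rule (illustrative only, not the `@_primexpr` function itself and not part of the diff) behaves as follows:

```python
# Reference-only sketch of the broadcasting rule used by _infer_size above.
# Plain Python, independent of MindSpore; the names here are illustrative.
def broadcast_shapes_sketch(a_size, b_size):
    a = (1,) * (len(b_size) - len(a_size)) + tuple(a_size)  # left-pad with 1s
    b = (1,) * (len(a_size) - len(b_size)) + tuple(b_size)
    out = []
    for x, y in zip(a, b):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            raise RuntimeError(f"The size of tensor a ({x}) must match "
                               f"the size of tensor b ({y})")
    return tuple(out)

assert broadcast_shapes_sketch((3, 1, 5), (4, 5)) == (3, 4, 5)
assert broadcast_shapes_sketch((2, 1), (7,)) == (2, 7)
```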

mindtorch/torch/_six.py (+5 -0)

@@ -1 +1,6 @@
import math

inf = math.inf
nan = math.nan

string_classes = (str, bytes)

mindtorch/torch/distributions/__init__.py (+96 -0)

@@ -0,0 +1,96 @@
from .bernoulli import Bernoulli
from .beta import Beta
from .binomial import Binomial
from .categorical import Categorical
from .cauchy import Cauchy
from .chi2 import Chi2
from .constraint_registry import biject_to, transform_to
from .continuous_bernoulli import ContinuousBernoulli
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exp_family import ExponentialFamily
from .exponential import Exponential
from .fishersnedecor import FisherSnedecor
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .half_cauchy import HalfCauchy
from .half_normal import HalfNormal
from .independent import Independent
from .kl import kl_divergence, register_kl, _add_kl_info
from .kumaraswamy import Kumaraswamy
from .laplace import Laplace
from .lkj_cholesky import LKJCholesky
from .log_normal import LogNormal
from .logistic_normal import LogisticNormal
from .lowrank_multivariate_normal import LowRankMultivariateNormal
from .mixture_same_family import MixtureSameFamily
from .multinomial import Multinomial
from .multivariate_normal import MultivariateNormal
from .negative_binomial import NegativeBinomial
from .normal import Normal
from .one_hot_categorical import OneHotCategorical, OneHotCategoricalStraightThrough
from .pareto import Pareto
from .poisson import Poisson
from .relaxed_bernoulli import RelaxedBernoulli
from .relaxed_categorical import RelaxedOneHotCategorical
from .studentT import StudentT
from .transformed_distribution import TransformedDistribution
from .transforms import * # noqa: F403
from .uniform import Uniform
from .von_mises import VonMises
from .weibull import Weibull
from .wishart import Wishart
from . import transforms

_add_kl_info()
del _add_kl_info

__all__ = [
'Bernoulli',
'Beta',
'Binomial',
'Categorical',
'Cauchy',
'Chi2',
'ContinuousBernoulli',
'Dirichlet',
'Distribution',
'Exponential',
'ExponentialFamily',
'FisherSnedecor',
'Gamma',
'Geometric',
'Gumbel',
'HalfCauchy',
'HalfNormal',
'Independent',
'Kumaraswamy',
'LKJCholesky',
'Laplace',
'LogNormal',
'LogisticNormal',
'LowRankMultivariateNormal',
'MixtureSameFamily',
'Multinomial',
'MultivariateNormal',
'NegativeBinomial',
'Normal',
'OneHotCategorical',
'OneHotCategoricalStraightThrough',
'Pareto',
'RelaxedBernoulli',
'RelaxedOneHotCategorical',
'StudentT',
'Poisson',
'Uniform',
'VonMises',
'Weibull',
'Wishart',
'TransformedDistribution',
'biject_to',
'kl_divergence',
'register_kl',
'transform_to',
]
__all__.extend(transforms.__all__)
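
A quick smoke-test sketch of the new package exports (untested on this branch; `Normal` and `kl_divergence` come from files added later in this PR, and the tensor constructor is the one imported elsewhere in the diff):

```python
# Untested sketch exercising a few exports wired up in __init__.py above.
from mindtorch.torch.tensor import tensor
from mindtorch.torch.distributions import Normal, kl_divergence

p = Normal(tensor([0.0]), tensor([1.0]))
q = Normal(tensor([1.0]), tensor([2.0]))
print(p.sample((3,)).shape)   # expected (3, 1): sample_shape + batch_shape
print(kl_divergence(p, q))    # uses the registry populated by _add_kl_info()
```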

mindtorch/torch/distributions/bernoulli.py (+104 -0)

@@ -0,0 +1,104 @@
from numbers import Number
from mindspore import _no_grad as torch_no_grad
from mindtorch.torch._six import nan
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.exp_family import ExponentialFamily
from mindtorch.torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property
from mindtorch.torch.nn.functional import binary_cross_entropy_with_logits
from mindtorch.torch._C.Size import Size
from mindtorch.torch.conflict_functional import arange as torch_arange
from mindtorch.torch.functional import bernoulli, log as torch_log, exp as torch_exp


class Bernoulli(ExponentialFamily):
arg_constraints = {'probs': constraints.unit_interval,
'logits': constraints.real}
support = constraints.boolean
has_enumerate_support = True
_mean_carrier_measure = 0

def __init__(self, probs=None, logits=None, validate_args=None):
if (probs is None) == (logits is None):
raise ValueError("Either `probs` or `logits` must be specified, but not both.")
if probs is not None:
is_scalar = isinstance(probs, Number)
self.probs, = broadcast_all(probs)
else:
is_scalar = isinstance(logits, Number)
self.logits, = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits
if is_scalar:
batch_shape = Size()
else:
batch_shape = self._param.size()
super(Bernoulli, self).__init__(batch_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Bernoulli, _instance)
batch_shape = Size(batch_shape)
if 'probs' in self.__dict__:
new.probs = self.probs.expand(batch_shape)
new._param = new.probs
if 'logits' in self.__dict__:
new.logits = self.logits.expand(batch_shape)
new._param = new.logits
super(Bernoulli, new).__init__(batch_shape, validate_args=False)
new._validate_args = self._validate_args
return new

def _new(self, *args, **kwargs):
return self._param.new(*args, **kwargs)

@property
def mean(self):
return self.probs

@property
def mode(self):
mode = (self.probs >= 0.5).to(self.probs)
mode[self.probs == 0.5] = nan
return mode

@property
def variance(self):
return self.probs * (1 - self.probs)

@lazy_property
def logits(self): # pylint: disable=E0202
return probs_to_logits(self.probs, is_binary=True)

@lazy_property
def probs(self): # pylint: disable=E0202
return logits_to_probs(self.logits, is_binary=True)

@property
def param_shape(self):
return self._param.size()

def sample(self, sample_shape=Size()):
shape = self._extended_shape(sample_shape)
with torch_no_grad():
return bernoulli(self.probs.expand(shape))

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
logits, value = broadcast_all(self.logits, value)
return -binary_cross_entropy_with_logits(logits, value, reduction='none') # pylint: disable=E1130

def entropy(self):
return binary_cross_entropy_with_logits(self.logits, self.probs, reduction='none')

def enumerate_support(self, expand=True):
values = torch_arange(2, dtype=self._param.dtype, device=self._param.device)
values = values.view((-1,) + (1,) * len(self._batch_shape))
if expand:
values = values.expand((-1,) + self._batch_shape)
return values

@property
def _natural_params(self):
return (torch_log(self.probs / (1 - self.probs)), )

def _log_normalizer(self, x):
return torch_log(1 + torch_exp(x))
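
A minimal usage sketch for the `Bernoulli` class above; semantics are expected to mirror `torch.distributions.Bernoulli`, but this is untested against the branch:

```python
# Untested sketch of the Bernoulli API defined above.
from mindtorch.torch.tensor import tensor
from mindtorch.torch.distributions import Bernoulli

d = Bernoulli(probs=tensor([0.2, 0.8]))
x = d.sample((4,))                  # shape (4, 2), values in {0., 1.}
print(d.log_prob(x))                # elementwise log P(X = x)
print(d.mean, d.variance)           # p and p * (1 - p)
print(d.enumerate_support().shape)  # (2, 2): the {0, 1} support over the batch
```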

mindtorch/torch/distributions/beta.py (+81 -0)

@@ -0,0 +1,81 @@
from numbers import Real, Number

from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.dirichlet import Dirichlet
from mindtorch.torch.distributions.exp_family import ExponentialFamily
from mindtorch.torch.distributions.utils import broadcast_all
from mindtorch.torch.tensor import tensor as torch_tensor
from mindtorch.torch.functional import stack, lgamma
from mindtorch.torch._C.Size import Size


class Beta(ExponentialFamily):
arg_constraints = {'concentration1': constraints.positive, 'concentration0': constraints.positive}
support = constraints.unit_interval
has_rsample = True

def __init__(self, concentration1, concentration0, validate_args=None):
if isinstance(concentration1, Real) and isinstance(concentration0, Real):
concentration1_concentration0 = torch_tensor([float(concentration1), float(concentration0)])
else:
concentration1, concentration0 = broadcast_all(concentration1, concentration0)
concentration1_concentration0 = stack([concentration1, concentration0], -1)
self._dirichlet = Dirichlet(concentration1_concentration0, validate_args=validate_args)
super(Beta, self).__init__(self._dirichlet._batch_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Beta, _instance)
batch_shape = Size(batch_shape)
new._dirichlet = self._dirichlet.expand(batch_shape)
super(Beta, new).__init__(batch_shape, validate_args=False)
new._validate_args = self._validate_args
return new

@property
def mean(self):
return self.concentration1 / (self.concentration1 + self.concentration0)

@property
def mode(self):
return self._dirichlet.mode[..., 0]

@property
def variance(self):
total = self.concentration1 + self.concentration0
return (self.concentration1 * self.concentration0 /
(total.pow(2) * (total + 1)))

def rsample(self, sample_shape=()):
return self._dirichlet.rsample(sample_shape).select(-1, 0)

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
heads_tails = stack([value, 1.0 - value], -1)
return self._dirichlet.log_prob(heads_tails)

def entropy(self):
return self._dirichlet.entropy()

@property
def concentration1(self):
result = self._dirichlet.concentration[..., 0]
if isinstance(result, Number):
return torch_tensor([result])
else:
return result

@property
def concentration0(self):
result = self._dirichlet.concentration[..., 1]
if isinstance(result, Number):
return torch_tensor([result])
else:
return result

@property
def _natural_params(self):
return (self.concentration1, self.concentration0)

def _log_normalizer(self, x, y):
return lgamma(x) + lgamma(y) - lgamma(x + y)
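
Usage sketch for `Beta` (untested). Note that `rsample` delegates to `Dirichlet.rsample`, which is still a TODO in this PR and raises `NotImplementedError`, so only the analytic quantities and `log_prob` are exercised here:

```python
# Untested sketch of the Beta API defined above.
from mindtorch.torch.tensor import tensor
from mindtorch.torch.distributions import Beta

d = Beta(tensor([2.0]), tensor([5.0]))
print(d.mean)                     # 2 / (2 + 5)
print(d.variance)
print(d.log_prob(tensor([0.3])))  # evaluated through the underlying Dirichlet
print(d.entropy())
```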

mindtorch/torch/distributions/binomial.py (+116 -0)

@@ -0,0 +1,116 @@
from mindspore import _no_grad as torch_no_grad
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.distribution import Distribution
from mindtorch.torch.distributions.utils import broadcast_all, probs_to_logits, lazy_property, logits_to_probs
from mindtorch.torch._C.Size import Size
from mindtorch.torch.conflict_functional import arange as torch_arange
from mindtorch.torch.functional import binomial, lgamma, exp as torch_exp, abs as torch_abs, log1p


def _clamp_by_zero(x):
return (x.clamp(min=0) + x - x.clamp(max=0)) / 2


class Binomial(Distribution):
arg_constraints = {'total_count': constraints.nonnegative_integer,
'probs': constraints.unit_interval,
'logits': constraints.real}
has_enumerate_support = True

def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
if (probs is None) == (logits is None):
raise ValueError("Either `probs` or `logits` must be specified, but not both.")
if probs is not None:
self.total_count, self.probs, = broadcast_all(total_count, probs)
self.total_count = self.total_count.type_as(self.probs)
else:
self.total_count, self.logits, = broadcast_all(total_count, logits)
self.total_count = self.total_count.type_as(self.logits)

self._param = self.probs if probs is not None else self.logits
batch_shape = self._param.size()
super(Binomial, self).__init__(batch_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Binomial, _instance)
batch_shape = Size(batch_shape)
new.total_count = self.total_count.expand(batch_shape)
if 'probs' in self.__dict__:
new.probs = self.probs.expand(batch_shape)
new._param = new.probs
if 'logits' in self.__dict__:
new.logits = self.logits.expand(batch_shape)
new._param = new.logits
super(Binomial, new).__init__(batch_shape, validate_args=False)
new._validate_args = self._validate_args
return new

def _new(self, *args, **kwargs):
return self._param.new(*args, **kwargs)

@constraints.dependent_property(is_discrete=True, event_dim=0)
def support(self):
return constraints.integer_interval(0, self.total_count)

@property
def mean(self):
return self.total_count * self.probs

@property
def mode(self):
return ((self.total_count + 1) * self.probs).floor().clamp(max=self.total_count)

@property
def variance(self):
return self.total_count * self.probs * (1 - self.probs)

@lazy_property
def logits(self): # pylint: disable=E0202
return probs_to_logits(self.probs, is_binary=True)

@lazy_property
def probs(self): # pylint: disable=E0202
return logits_to_probs(self.logits, is_binary=True)

@property
def param_shape(self):
return self._param.size()

def sample(self, sample_shape=Size()):
shape = self._extended_shape(sample_shape)
with torch_no_grad():
return binomial(self.total_count.expand(shape), self.probs.expand(shape))

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
log_factorial_n = lgamma(self.total_count + 1)
log_factorial_k = lgamma(value + 1)
log_factorial_nmk = lgamma(self.total_count - value + 1)
# k * log(p) + (n - k) * log(1 - p) = k * (log(p) - log(1 - p)) + n * log(1 - p)
# (case logit < 0) = k * logit - n * log1p(e^logit)
# (case logit > 0) = k * logit - n * (log(p) - log(1 - p)) + n * log(p)
# = k * logit - n * logit - n * log1p(e^-logit)
# (merge two cases) = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|)
normalize_term = (self.total_count * _clamp_by_zero(self.logits)
+ self.total_count * log1p(torch_exp(-torch_abs(self.logits))) # pylint: disable=E1130
- log_factorial_n)
return value * self.logits - log_factorial_k - log_factorial_nmk - normalize_term

def entropy(self):
total_count = int(self.total_count.max())
if not self.total_count.min() == total_count:
raise NotImplementedError("Inhomogeneous total count not supported by `entropy`.")

log_prob = self.log_prob(self.enumerate_support(False))
return -(torch_exp(log_prob) * log_prob).sum(0)

def enumerate_support(self, expand=True):
total_count = int(self.total_count.max())
if not self.total_count.min() == total_count:
raise NotImplementedError("Inhomogeneous total count not supported by `enumerate_support`.")
values = torch_arange(1 + total_count, dtype=self._param.dtype, device=self._param.device)
values = values.view((-1,) + (1,) * len(self._batch_shape))
if expand:
values = values.expand((-1,) + self._batch_shape)
return values
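
Usage sketch for `Binomial` (untested; mirrors the PyTorch-style API shown in the diff):

```python
# Untested sketch of the Binomial API defined above.
from mindtorch.torch.tensor import tensor
from mindtorch.torch.distributions import Binomial

d = Binomial(total_count=10, probs=tensor([0.25, 0.5]))
x = d.sample()                      # integer-valued draws in [0, 10]
print(d.log_prob(x))
print(d.mean, d.variance)           # n * p and n * p * (1 - p)
print(d.enumerate_support().shape)  # (11, 2): counts 0..10 over the batch
```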

mindtorch/torch/distributions/categorical.py (+107 -0)

@@ -0,0 +1,107 @@
import mindtorch.torch.common.dtype as mindtorch_dtype
from mindtorch.torch._six import nan
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.distribution import Distribution
from mindtorch.torch.distributions.utils import probs_to_logits, logits_to_probs, lazy_property
from mindtorch.torch._C.Size import Size
from mindtorch.torch.conflict_functional import arange as torch_arange
from mindtorch.torch.common.dtype import finfo
from mindtorch.torch.functional import full, multinomial, broadcast_tensors, clamp


class Categorical(Distribution):
arg_constraints = {'probs': constraints.simplex,
'logits': constraints.real_vector}
has_enumerate_support = True

def __init__(self, probs=None, logits=None, validate_args=None):
if (probs is None) == (logits is None):
raise ValueError("Either `probs` or `logits` must be specified, but not both.")
if probs is not None:
if probs.dim() < 1:
raise ValueError("`probs` parameter must be at least one-dimensional.")
self.probs = probs / probs.sum(-1, keepdim=True)
else:
if logits.dim() < 1:
raise ValueError("`logits` parameter must be at least one-dimensional.")
# Normalize
self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
self._param = self.probs if probs is not None else self.logits
self._num_events = self._param.size()[-1]
batch_shape = self._param.size()[:-1] if self._param.ndimension() > 1 else Size()
super(Categorical, self).__init__(batch_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Categorical, _instance)
batch_shape = Size(batch_shape)
param_shape = batch_shape + Size((self._num_events,))
if 'probs' in self.__dict__:
new.probs = self.probs.expand(param_shape)
new._param = new.probs
if 'logits' in self.__dict__:
new.logits = self.logits.expand(param_shape)
new._param = new.logits
new._num_events = self._num_events
super(Categorical, new).__init__(batch_shape, validate_args=False)
new._validate_args = self._validate_args
return new

def _new(self, *args, **kwargs):
return self._param.new(*args, **kwargs)

@constraints.dependent_property(is_discrete=True, event_dim=0)
def support(self):
return constraints.integer_interval(0, self._num_events - 1)

@lazy_property
def logits(self): # pylint: disable=E0202
return probs_to_logits(self.probs)

@lazy_property
def probs(self): # pylint: disable=E0202
return logits_to_probs(self.logits)

@property
def param_shape(self):
return self._param.size()

@property
def mean(self):
return full(self._extended_shape(), nan, dtype=self.probs.dtype, device=self.probs.device)

@property
def mode(self):
return self.probs.argmax(axis=-1)

@property
def variance(self):
return full(self._extended_shape(), nan, dtype=self.probs.dtype, device=self.probs.device)

def sample(self, sample_shape=Size()):
if not isinstance(sample_shape, Size):
sample_shape = Size(sample_shape)
probs_2d = self.probs.reshape(-1, self._num_events)
samples_2d = multinomial(probs_2d, sample_shape.numel(), True).T
return samples_2d.reshape(self._extended_shape(sample_shape))

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
value = value.long().unsqueeze(-1)
value, log_pmf = broadcast_tensors(value, self.logits)
value = value[..., :1]
return log_pmf.gather(-1, value).squeeze(-1)

def entropy(self):
min_real = finfo(self.logits.dtype).min
logits = clamp(self.logits, min=min_real)
p_log_p = logits * self.probs
return -p_log_p.sum(-1)

def enumerate_support(self, expand=True):
num_events = self._num_events
values = torch_arange(num_events, dtype=mindtorch_dtype.long, device=self._param.device)
values = values.view((-1,) + (1,) * len(self._batch_shape))
if expand:
values = values.expand((-1,) + self._batch_shape)
return values
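
Usage sketch for `Categorical` (untested):

```python
# Untested sketch of the Categorical API defined above.
from mindtorch.torch.tensor import tensor
from mindtorch.torch.distributions import Categorical

d = Categorical(probs=tensor([0.1, 0.3, 0.6]))
idx = d.sample((5,))    # five class indices in {0, 1, 2}
print(d.log_prob(idx))
print(d.entropy())
print(d.mode)           # argmax of probs, i.e. 2
```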

mindtorch/torch/distributions/cauchy.py (+65 -0)

@@ -0,0 +1,65 @@
import math
from numbers import Number
from mindtorch.torch._six import inf, nan

from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.distribution import Distribution
from mindtorch.torch.distributions.utils import broadcast_all
from mindtorch.torch._C.Size import Size
from mindtorch.torch.functional import full, atan, tan


class Cauchy(Distribution):
arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
support = constraints.real
has_rsample = True

def __init__(self, loc, scale, validate_args=None):
self.loc, self.scale = broadcast_all(loc, scale)
if isinstance(loc, Number) and isinstance(scale, Number):
batch_shape = Size()
else:
batch_shape = self.loc.size()
super(Cauchy, self).__init__(batch_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Cauchy, _instance)
batch_shape = Size(batch_shape)
new.loc = self.loc.expand(batch_shape)
new.scale = self.scale.expand(batch_shape)
super(Cauchy, new).__init__(batch_shape, validate_args=False)
new._validate_args = self._validate_args
return new

@property
def mean(self):
return full(self._extended_shape(), nan, dtype=self.loc.dtype, device=self.loc.device)

@property
def mode(self):
return self.loc

@property
def variance(self):
return full(self._extended_shape(), inf, dtype=self.loc.dtype, device=self.loc.device)

def rsample(self, sample_shape=Size()):
shape = self._extended_shape(sample_shape)
eps = self.loc.new(shape).cauchy_()
return self.loc + eps * self.scale

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
return -math.log(math.pi) - self.scale.log() - (1 + ((value - self.loc) / self.scale)**2).log()

def cdf(self, value):
if self._validate_args:
self._validate_sample(value)
return atan((value - self.loc) / self.scale) / math.pi + 0.5

def icdf(self, value):
return tan(math.pi * (value - 0.5)) * self.scale + self.loc

def entropy(self):
return math.log(4 * math.pi) + self.scale.log()
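
Usage sketch for `Cauchy` (untested). `mean` and `variance` are nan/inf by construction, so the closed-form cdf/icdf round trip is the more useful check:

```python
# Untested sketch of the Cauchy API defined above.
from mindtorch.torch.tensor import tensor
from mindtorch.torch.distributions import Cauchy

d = Cauchy(loc=tensor([0.0]), scale=tensor([1.0]))
print(d.log_prob(tensor([0.0])))  # -log(pi) for the standard Cauchy at x = 0
print(d.cdf(tensor([0.0])))       # 0.5
print(d.icdf(tensor([0.5])))      # back to loc = 0
```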

mindtorch/torch/distributions/chi2.py (+17 -0)

@@ -0,0 +1,17 @@
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.gamma import Gamma


class Chi2(Gamma):
arg_constraints = {'df': constraints.positive}

def __init__(self, df, validate_args=None):
super(Chi2, self).__init__(0.5 * df, 0.5, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Chi2, _instance)
return super(Chi2, self).expand(batch_shape, new)

@property
def df(self):
return self.concentration * 2
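
Sanity-check sketch for `Chi2` (untested): `Chi2(df)` above is just `Gamma(0.5 * df, rate=0.5)`, so its moments can be cross-checked against an explicit `Gamma`:

```python
# Untested sketch of the Chi2 API defined above.
from mindtorch.torch.tensor import tensor
from mindtorch.torch.distributions import Chi2, Gamma

d = Chi2(tensor([4.0]))
print(d.df)                # 4.0, recovered as concentration * 2
print(d.mean, d.variance)  # df and 2 * df, inherited from Gamma
g = Gamma(tensor([2.0]), tensor([0.5]))
print(g.log_prob(tensor([3.0])) - d.log_prob(tensor([3.0])))  # ~0
```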

mindtorch/torch/distributions/constraint_registry.py (+168 -0)

@@ -0,0 +1,168 @@
import numbers

from mindtorch.torch.distributions import constraints, transforms
from mindtorch.utils import unsupported_attr

__all__ = [
'ConstraintRegistry',
'biject_to',
'transform_to',
]


class ConstraintRegistry():
def __init__(self):
self._registry = {}
super(ConstraintRegistry, self).__init__()

def register(self, constraint, factory=None):
# Support use as decorator.
if factory is None:
return lambda factory: self.register(constraint, factory)

# Support calling on singleton instances.
if isinstance(constraint, constraints.Constraint):
constraint = type(constraint)

if not isinstance(constraint, type) or not issubclass(constraint, constraints.Constraint):
raise TypeError('Expected constraint to be either a Constraint subclass or instance, '
'but got {}'.format(constraint))

self._registry[constraint] = factory
return factory

def __call__(self, constraint):
# Look up by Constraint subclass.
try:
factory = self._registry[type(constraint)]
except KeyError:
raise NotImplementedError(
f'Cannot transform {type(constraint).__name__} constraints') from None
return factory(constraint)


biject_to = ConstraintRegistry()
transform_to = ConstraintRegistry()


################################################################################
# Registration Table
################################################################################

@biject_to.register(constraints.real)
@transform_to.register(constraints.real)
def _transform_to_real(constraint):
unsupported_attr(constraint)
return transforms.identity_transform


@biject_to.register(constraints.independent)
def _biject_to_independent(constraint):
base_transform = biject_to(constraint.base_constraint)
return transforms.IndependentTransform(
base_transform, constraint.reinterpreted_batch_ndims)


@transform_to.register(constraints.independent)
def _transform_to_independent(constraint):
base_transform = transform_to(constraint.base_constraint)
return transforms.IndependentTransform(
base_transform, constraint.reinterpreted_batch_ndims)


@biject_to.register(constraints.positive)
@biject_to.register(constraints.nonnegative)
@transform_to.register(constraints.positive)
@transform_to.register(constraints.nonnegative)
def _transform_to_positive(constraint):
unsupported_attr(constraint)
return transforms.ExpTransform()


@biject_to.register(constraints.greater_than)
@biject_to.register(constraints.greater_than_eq)
@transform_to.register(constraints.greater_than)
@transform_to.register(constraints.greater_than_eq)
def _transform_to_greater_than(constraint):
return transforms.ComposeTransform([transforms.ExpTransform(),
transforms.AffineTransform(constraint.lower_bound, 1)])


@biject_to.register(constraints.less_than)
@transform_to.register(constraints.less_than)
def _transform_to_less_than(constraint):
return transforms.ComposeTransform([transforms.ExpTransform(),
transforms.AffineTransform(constraint.upper_bound, -1)])


@biject_to.register(constraints.interval)
@biject_to.register(constraints.half_open_interval)
@transform_to.register(constraints.interval)
@transform_to.register(constraints.half_open_interval)
def _transform_to_interval(constraint):
# Handle the special case of the unit interval.
lower_is_0 = isinstance(constraint.lower_bound, numbers.Number) and constraint.lower_bound == 0
upper_is_1 = isinstance(constraint.upper_bound, numbers.Number) and constraint.upper_bound == 1
if lower_is_0 and upper_is_1:
return transforms.SigmoidTransform()

loc = constraint.lower_bound
scale = constraint.upper_bound - constraint.lower_bound
return transforms.ComposeTransform([transforms.SigmoidTransform(),
transforms.AffineTransform(loc, scale)])


@biject_to.register(constraints.simplex)
def _biject_to_simplex(constraint):
unsupported_attr(constraint)
return transforms.StickBreakingTransform()


@transform_to.register(constraints.simplex)
def _transform_to_simplex(constraint):
unsupported_attr(constraint)
return transforms.SoftmaxTransform()


# TODO define a bijection for LowerCholeskyTransform
@transform_to.register(constraints.lower_cholesky)
def _transform_to_lower_cholesky(constraint):
unsupported_attr(constraint)
return transforms.LowerCholeskyTransform()


@biject_to.register(constraints.corr_cholesky)
@transform_to.register(constraints.corr_cholesky)
def _transform_to_corr_cholesky(constraint):
unsupported_attr(constraint)
return transforms.CorrCholeskyTransform()


@biject_to.register(constraints.cat)
def _biject_to_cat(constraint):
return transforms.CatTransform([biject_to(c)
for c in constraint.cseq],
constraint.dim,
constraint.lengths)


@transform_to.register(constraints.cat)
def _transform_to_cat(constraint):
return transforms.CatTransform([transform_to(c)
for c in constraint.cseq],
constraint.dim,
constraint.lengths)


@biject_to.register(constraints.stack)
def _biject_to_stack(constraint):
return transforms.StackTransform(
[biject_to(c)
for c in constraint.cseq], constraint.dim)


@transform_to.register(constraints.stack)
def _transform_to_stack(constraint):
return transforms.StackTransform(
[transform_to(c)
for c in constraint.cseq], constraint.dim)
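
Usage sketch for the registry above (untested): look up a transform for a constraint and push unconstrained values into the constrained space:

```python
# Untested sketch of the constraint registry defined above.
from mindtorch.torch.tensor import tensor
from mindtorch.torch.distributions import constraints, transform_to, biject_to

t = transform_to(constraints.positive)   # resolves to ExpTransform()
x = tensor([-1.0, 0.0, 2.0])
y = t(x)                                 # strictly positive values
print(constraints.positive.check(y))     # all True
print(t.inv(y))                          # back to x, since ExpTransform is bijective

b = biject_to(constraints.interval(0.0, 5.0))  # Sigmoid composed with Affine
print(b(x))                                    # values in (0, 5)
```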

mindtorch/torch/distributions/constraints.py (+436 -0)

@@ -0,0 +1,436 @@
import mindtorch.torch.common.dtype as mindtorch_dtype
from mindtorch.torch.functional import all as torch_all, full as torch_full, stack as torch_stack, \
cat as torch_cat, isclose as torch_isclose
from mindtorch.torch.common.dtype import finfo
from mindtorch.torch.linalg import cholesky_ex, eigvalsh, norm

__all__ = [
'Constraint',
'boolean',
'cat',
'corr_cholesky',
'dependent',
'dependent_property',
'greater_than',
'greater_than_eq',
'independent',
'integer_interval',
'interval',
'half_open_interval',
'is_dependent',
'less_than',
'lower_cholesky',
'lower_triangular',
'multinomial',
'nonnegative_integer',
'positive',
'positive_semidefinite',
'positive_definite',
'positive_integer',
'real',
'real_vector',
'simplex',
'square',
'stack',
'symmetric',
'unit_interval',
]


class Constraint():
is_discrete = False # Default to continuous.
event_dim = 0 # Default to univariate.

def check(self, value):
raise NotImplementedError

def __repr__(self):
return self.__class__.__name__[1:] + '()'


class _Dependent(Constraint):
def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
self._is_discrete = is_discrete
self._event_dim = event_dim
super().__init__()

@property
def is_discrete(self):
if self._is_discrete is NotImplemented:
raise NotImplementedError(".is_discrete cannot be determined statically")
return self._is_discrete

@property
def event_dim(self):
if self._event_dim is NotImplemented:
raise NotImplementedError(".event_dim cannot be determined statically")
return self._event_dim

def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
if is_discrete is NotImplemented:
is_discrete = self._is_discrete
if event_dim is NotImplemented:
event_dim = self._event_dim
return _Dependent(is_discrete=is_discrete, event_dim=event_dim)

def check(self, x):
raise ValueError('Cannot determine validity of dependent constraint')


def is_dependent(constraint):
return isinstance(constraint, _Dependent)


class _DependentProperty(property, _Dependent):
def __init__(self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented):
super().__init__(fn)
self._is_discrete = is_discrete
self._event_dim = event_dim

def __call__(self, fn):
return _DependentProperty(fn, is_discrete=self._is_discrete, event_dim=self._event_dim)


class _IndependentConstraint(Constraint):
def __init__(self, base_constraint, reinterpreted_batch_ndims):
assert isinstance(base_constraint, Constraint)
assert isinstance(reinterpreted_batch_ndims, int)
assert reinterpreted_batch_ndims >= 0
self.base_constraint = base_constraint
self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
super().__init__()

@property
def is_discrete(self):
return self.base_constraint.is_discrete

@property
def event_dim(self):
return self.base_constraint.event_dim + self.reinterpreted_batch_ndims

def check(self, value):
result = self.base_constraint.check(value)
if result.dim() < self.reinterpreted_batch_ndims:
expected = self.base_constraint.event_dim + self.reinterpreted_batch_ndims
raise ValueError(f"Expected value.dim() >= {expected} but got {value.dim()}")
result = result.reshape(result.shape[:result.dim() - self.reinterpreted_batch_ndims] + (-1,))
result = result.all(-1)
return result

def __repr__(self):
return "{}({}, {})".format(self.__class__.__name__[1:], repr(self.base_constraint),
self.reinterpreted_batch_ndims)


class _Boolean(Constraint):
is_discrete = True

def check(self, value):
return (value == 0) | (value == 1)


class _OneHot(Constraint):
is_discrete = True
event_dim = 1

def check(self, value):
is_boolean = (value == 0) | (value == 1)
is_normalized = value.sum(-1).eq(1)
return is_boolean.all(-1) & is_normalized


class _IntegerInterval(Constraint):
is_discrete = True

def __init__(self, lower_bound, upper_bound):
self.lower_bound = lower_bound
self.upper_bound = upper_bound
super().__init__()

def check(self, value):
return (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound)

def __repr__(self):
fmt_string = self.__class__.__name__[1:]
fmt_string += '(lower_bound={}, upper_bound={})'.format(self.lower_bound, self.upper_bound)
return fmt_string


class _IntegerLessThan(Constraint):
is_discrete = True

def __init__(self, upper_bound):
self.upper_bound = upper_bound
super().__init__()

def check(self, value):
return (value % 1 == 0) & (value <= self.upper_bound)

def __repr__(self):
fmt_string = self.__class__.__name__[1:]
fmt_string += '(upper_bound={})'.format(self.upper_bound)
return fmt_string


class _IntegerGreaterThan(Constraint):
is_discrete = True

def __init__(self, lower_bound):
self.lower_bound = lower_bound
super().__init__()

def check(self, value):
return (value % 1 == 0) & (value >= self.lower_bound)

def __repr__(self):
fmt_string = self.__class__.__name__[1:]
fmt_string += '(lower_bound={})'.format(self.lower_bound)
return fmt_string


class _Real(Constraint):
def check(self, value):
# False for NANs.
return value == value # pylint: disable=R0124


class _GreaterThan(Constraint):
def __init__(self, lower_bound):
self.lower_bound = lower_bound
super().__init__()

def check(self, value):
return self.lower_bound < value

def __repr__(self):
fmt_string = self.__class__.__name__[1:]
fmt_string += '(lower_bound={})'.format(self.lower_bound)
return fmt_string


class _GreaterThanEq(Constraint):
def __init__(self, lower_bound):
self.lower_bound = lower_bound
super().__init__()

def check(self, value):
return self.lower_bound <= value

def __repr__(self):
fmt_string = self.__class__.__name__[1:]
fmt_string += '(lower_bound={})'.format(self.lower_bound)
return fmt_string


class _LessThan(Constraint):
def __init__(self, upper_bound):
self.upper_bound = upper_bound
super().__init__()

def check(self, value):
return value < self.upper_bound

def __repr__(self):
fmt_string = self.__class__.__name__[1:]
fmt_string += '(upper_bound={})'.format(self.upper_bound)
return fmt_string


class _Interval(Constraint):
def __init__(self, lower_bound, upper_bound):
self.lower_bound = lower_bound
self.upper_bound = upper_bound
super().__init__()

def check(self, value):
return (self.lower_bound <= value) & (value <= self.upper_bound)

def __repr__(self):
fmt_string = self.__class__.__name__[1:]
fmt_string += '(lower_bound={}, upper_bound={})'.format(self.lower_bound, self.upper_bound)
return fmt_string


class _HalfOpenInterval(Constraint):
def __init__(self, lower_bound, upper_bound):
self.lower_bound = lower_bound
self.upper_bound = upper_bound
super().__init__()

def check(self, value):
return (self.lower_bound <= value) & (value < self.upper_bound)

def __repr__(self):
fmt_string = self.__class__.__name__[1:]
fmt_string += '(lower_bound={}, upper_bound={})'.format(self.lower_bound, self.upper_bound)
return fmt_string


class _Simplex(Constraint):
event_dim = 1

def check(self, value):
return torch_all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6)


class _Multinomial(Constraint):
is_discrete = True
event_dim = 1

def __init__(self, upper_bound):
self.upper_bound = upper_bound

def check(self, x):
return (x >= 0).all(dim=-1) & (x.sum(dim=-1) <= self.upper_bound)


class _LowerTriangular(Constraint):
event_dim = 2

def check(self, value):
value_tril = value.tril()
return (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0]


class _LowerCholesky(Constraint):
event_dim = 2

def check(self, value):
value_tril = value.tril()
lower_triangular = (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0]

positive_diagonal = (value.diagonal(dim1=-2, dim2=-1) > 0).min(-1)[0]
return lower_triangular & positive_diagonal


class _CorrCholesky(Constraint):
event_dim = 2

def check(self, value):
tol = finfo(value.dtype).eps * value.size(-1) * 10 # 10 is an adjustable fudge factor
row_norm = norm(value.detach(), dim=-1)
unit_row_norm = (row_norm - 1.).abs().le(tol).all(dim=-1)
return _LowerCholesky().check(value) & unit_row_norm


class _Square(Constraint):
event_dim = 2

def check(self, value):
return torch_full(
size=value.shape[:-2],
fill_value=(value.shape[-2] == value.shape[-1]),
dtype=mindtorch_dtype.bool,
device=value.device
)


class _Symmetric(_Square):
def check(self, value):
square_check = super().check(value)
if not square_check.all():
return square_check
return torch_isclose(value, value.mT, atol=1e-6).all(-2).all(-1)


class _PositiveSemidefinite(_Symmetric):
def check(self, value):
sym_check = super().check(value)
if not sym_check.all():
return sym_check
return eigvalsh(value).ge(0).all(-1)


class _PositiveDefinite(_Symmetric):
def check(self, value):
sym_check = super().check(value)
if not sym_check.all():
return sym_check
return cholesky_ex(value).info.eq(0)


class _Cat(Constraint):
def __init__(self, cseq, dim=0, lengths=None):
assert all(isinstance(c, Constraint) for c in cseq)
self.cseq = list(cseq)
if lengths is None:
lengths = [1] * len(self.cseq)
self.lengths = list(lengths)
assert len(self.lengths) == len(self.cseq)
self.dim = dim
super().__init__()

@property
def is_discrete(self):
return any(c.is_discrete for c in self.cseq)

@property
def event_dim(self):
return max(c.event_dim for c in self.cseq)

def check(self, value):
assert -value.dim() <= self.dim < value.dim()
checks = []
start = 0
for constr, length in zip(self.cseq, self.lengths):
v = value.narrow(self.dim, start, length)
checks.append(constr.check(v))
start = start + length # avoid += for jit compat
return torch_cat(checks, self.dim)


class _Stack(Constraint):
def __init__(self, cseq, dim=0):
assert all(isinstance(c, Constraint) for c in cseq)
self.cseq = list(cseq)
self.dim = dim
super().__init__()

@property
def is_discrete(self):
return any(c.is_discrete for c in self.cseq)

@property
def event_dim(self):
dim = max(c.event_dim for c in self.cseq)
if self.dim + dim < 0:
dim += 1
return dim

def check(self, value):
assert -value.dim() <= self.dim < value.dim()
vs = [value.select(self.dim, i) for i in range(value.size(self.dim))]
return torch_stack([constr.check(v)
for v, constr in zip(vs, self.cseq)], self.dim)


# Public interface.
dependent = _Dependent()
dependent_property = _DependentProperty
independent = _IndependentConstraint
boolean = _Boolean()
one_hot = _OneHot()
nonnegative_integer = _IntegerGreaterThan(0)
positive_integer = _IntegerGreaterThan(1)
integer_interval = _IntegerInterval
real = _Real()
real_vector = independent(real, 1)
positive = _GreaterThan(0.)
nonnegative = _GreaterThanEq(0.)
greater_than = _GreaterThan
greater_than_eq = _GreaterThanEq
less_than = _LessThan
multinomial = _Multinomial
unit_interval = _Interval(0., 1.)
interval = _Interval
half_open_interval = _HalfOpenInterval
simplex = _Simplex()
lower_triangular = _LowerTriangular()
lower_cholesky = _LowerCholesky()
corr_cholesky = _CorrCholesky()
square = _Square()
symmetric = _Symmetric()
positive_semidefinite = _PositiveSemidefinite()
positive_definite = _PositiveDefinite()
cat = _Cat
stack = _Stack
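
Quick checks for a few of the constraint singletons above (untested sketch):

```python
# Untested sketch of the constraint objects defined above.
from mindtorch.torch.tensor import tensor
from mindtorch.torch.distributions import constraints

print(constraints.unit_interval.check(tensor([0.0, 0.5, 1.2])))      # [T, T, F]
print(constraints.simplex.check(tensor([0.2, 0.3, 0.5])))            # True
print(constraints.integer_interval(0, 5).check(tensor([3.0, 7.0])))  # [T, F]
print(constraints.positive_definite.check(tensor([[2.0, 0.0],
                                                  [0.0, 3.0]])))     # True
```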

mindtorch/torch/distributions/continuous_bernoulli.py (+175 -0)

@@ -0,0 +1,175 @@
from numbers import Number
import math

from mindspore import _no_grad as torch_no_grad
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.exp_family import ExponentialFamily
from mindtorch.torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property, \
clamp_probs
from mindtorch.torch.nn.functional import binary_cross_entropy_with_logits
from mindtorch.torch._C.Size import Size
from mindtorch.torch.functional import max as torch_max, le as torch_le, gt as torch_gt, where, ones_like, \
ge as torch_ge, zeros_like, log as torch_log, log1p, abs as torch_abs, pow as torch_pow, sqrt as torch_sqrt, \
rand, exp as torch_exp


class ContinuousBernoulli(ExponentialFamily):
arg_constraints = {'probs': constraints.unit_interval,
'logits': constraints.real}
support = constraints.unit_interval
_mean_carrier_measure = 0
has_rsample = True

def __init__(self, probs=None, logits=None, lims=(0.499, 0.501), validate_args=None):
if (probs is None) == (logits is None):
raise ValueError("Either `probs` or `logits` must be specified, but not both.")
if probs is not None:
is_scalar = isinstance(probs, Number)
self.probs, = broadcast_all(probs)
if validate_args is not None:
if not self.arg_constraints['probs'].check(getattr(self, 'probs')).all():
raise ValueError("The parameter {} has invalid values".format('probs'))
self.probs = clamp_probs(self.probs)
else:
is_scalar = isinstance(logits, Number)
self.logits, = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits
if is_scalar:
batch_shape = Size()
else:
batch_shape = self._param.size()
self._lims = lims
super(ContinuousBernoulli, self).__init__(batch_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(ContinuousBernoulli, _instance)
new._lims = self._lims
batch_shape = Size(batch_shape)
if 'probs' in self.__dict__:
new.probs = self.probs.expand(batch_shape)
new._param = new.probs
if 'logits' in self.__dict__:
new.logits = self.logits.expand(batch_shape)
new._param = new.logits
super(ContinuousBernoulli, new).__init__(batch_shape, validate_args=False)
new._validate_args = self._validate_args
return new

def _new(self, *args, **kwargs):
return self._param.new(*args, **kwargs)

def _outside_unstable_region(self):
return torch_max(torch_le(self.probs, self._lims[0]),
torch_gt(self.probs, self._lims[1]))

def _cut_probs(self):
return where(self._outside_unstable_region(),
self.probs,
self._lims[0] * ones_like(self.probs))

def _cont_bern_log_norm(self):
cut_probs = self._cut_probs()
cut_probs_below_half = where(torch_le(cut_probs, 0.5),
cut_probs,
zeros_like(cut_probs))
cut_probs_above_half = where(torch_ge(cut_probs, 0.5),
cut_probs,
ones_like(cut_probs))
log_norm = torch_log(torch_abs(log1p(-cut_probs) - torch_log(cut_probs))) - where( # pylint: disable=E1130
torch_le(cut_probs, 0.5),
log1p(-2.0 * cut_probs_below_half),
torch_log(2.0 * cut_probs_above_half - 1.0))
x = torch_pow(self.probs - 0.5, 2)
taylor = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * x) * x
return where(self._outside_unstable_region(), log_norm, taylor)

@property
def mean(self):
cut_probs = self._cut_probs()
mus = cut_probs / (2.0 * cut_probs - 1.0) + 1.0 / (
log1p(-cut_probs) - torch_log(cut_probs)) # pylint: disable=E1130
x = self.probs - 0.5
taylor = 0.5 + (1.0 / 3.0 + 16.0 / 45.0 * torch_pow(x, 2)) * x
return where(self._outside_unstable_region(), mus, taylor)

@property
def stddev(self):
return torch_sqrt(self.variance)

@property
def variance(self):
cut_probs = self._cut_probs()
vars = cut_probs * (cut_probs - 1.0) / torch_pow(1.0 - 2.0 * cut_probs, 2) + 1.0 / torch_pow(
log1p(-cut_probs) - torch_log(cut_probs), 2) # pylint: disable=E1130
x = torch_pow(self.probs - 0.5, 2)
taylor = 1.0 / 12.0 - (1.0 / 15.0 - 128. / 945.0 * x) * x
return where(self._outside_unstable_region(), vars, taylor)

@lazy_property
def logits(self): # pylint: disable=E0202
return probs_to_logits(self.probs, is_binary=True)

@lazy_property
def probs(self): # pylint: disable=E0202
return clamp_probs(logits_to_probs(self.logits, is_binary=True))

@property
def param_shape(self):
return self._param.size()

def sample(self, sample_shape=Size()):
shape = self._extended_shape(sample_shape)
u = rand(shape, dtype=self.probs.dtype, device=self.probs.device)
with torch_no_grad():
return self.icdf(u)

def rsample(self, sample_shape=Size()):
shape = self._extended_shape(sample_shape)
u = rand(shape, dtype=self.probs.dtype, device=self.probs.device)
return self.icdf(u)

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
logits, value = broadcast_all(self.logits, value)
return -binary_cross_entropy_with_logits(logits, value, # pylint: disable=E1130
reduction='none') + self._cont_bern_log_norm()

def cdf(self, value):
if self._validate_args:
self._validate_sample(value)
cut_probs = self._cut_probs()
cdfs = (torch_pow(cut_probs, value) * torch_pow(1.0 - cut_probs, 1.0 - value)
+ cut_probs - 1.0) / (2.0 * cut_probs - 1.0)
unbounded_cdfs = where(self._outside_unstable_region(), cdfs, value)
return where(
torch_le(value, 0.0),
zeros_like(value),
where(torch_ge(value, 1.0), ones_like(value), unbounded_cdfs))

def icdf(self, value):
cut_probs = self._cut_probs()
return where(
self._outside_unstable_region(),
(log1p(-cut_probs + value * (2.0 * cut_probs - 1.0)) # pylint: disable=E1130
- log1p(-cut_probs)) / (torch_log(cut_probs) - log1p(-cut_probs)), # pylint: disable=E1130
value)

def entropy(self):
log_probs0 = log1p(-self.probs)
log_probs1 = torch_log(self.probs)
return self.mean * (log_probs0 - log_probs1) - self._cont_bern_log_norm() - log_probs0

@property
def _natural_params(self):
return (self.logits, )

def _log_normalizer(self, x):
out_unst_reg = torch_max(torch_le(x, self._lims[0] - 0.5),
torch_gt(x, self._lims[1] - 0.5))
cut_nat_params = where(out_unst_reg,
x,
(self._lims[0] - 0.5) * ones_like(x))
log_norm = torch_log(torch_abs(torch_exp(cut_nat_params) - 1.0)) - torch_log(torch_abs(cut_nat_params))
taylor = 0.5 * x + torch_pow(x, 2) / 24.0 - torch_pow(x, 4) / 2880.0
return where(out_unst_reg, log_norm, taylor)
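
Usage sketch for `ContinuousBernoulli` (untested); it has full `[0, 1]` support and a closed-form icdf, so `rsample` is available:

```python
# Untested sketch of the ContinuousBernoulli API defined above.
from mindtorch.torch.tensor import tensor
from mindtorch.torch.distributions import ContinuousBernoulli

d = ContinuousBernoulli(probs=tensor([0.3]))
x = d.rsample((4,))          # reparameterized draws via the closed-form icdf
print(d.log_prob(x))
print(d.mean, d.stddev)
print(d.cdf(tensor([0.5])))
```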

mindtorch/torch/distributions/dirichlet.py (+95 -0)

@@ -0,0 +1,95 @@
# from mindtorch.torch.autograd import Function
# from mindtorch.torch.autograd.function import once_differentiable
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.exp_family import ExponentialFamily
# from mindtorch.torch.functional import _dirichlet_grad, _sample_dirichlet
from mindtorch.torch.functional import log as torch_log, lgamma, digamma
from mindtorch.torch.nn.functional import one_hot
from mindtorch.torch._C.Size import Size


# TODO:
# def _Dirichlet_backward(x, concentration, grad_output):
# total = concentration.sum(-1, True).expand_as(concentration)
# grad = _dirichlet_grad(x, concentration, total)
# return grad * (grad_output - (x * grad_output).sum(-1, True))
#
#
# class _Dirichlet(Function):
# @staticmethod
# def forward(ctx, concentration):
# x = _sample_dirichlet(concentration)
# ctx.save_for_backward(x, concentration)
# return x
#
# @staticmethod
# @once_differentiable
# def backward(ctx, grad_output):
# x, concentration = ctx.saved_tensors
# return _Dirichlet_backward(x, concentration, grad_output)


class Dirichlet(ExponentialFamily):
arg_constraints = {'concentration': constraints.independent(constraints.positive, 1)}
support = constraints.simplex
has_rsample = True

def __init__(self, concentration, validate_args=None):
if concentration.dim() < 1:
raise ValueError("`concentration` parameter must be at least one-dimensional.")
self.concentration = concentration
batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:]
super(Dirichlet, self).__init__(batch_shape, event_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Dirichlet, _instance)
batch_shape = Size(batch_shape)
new.concentration = self.concentration.expand(batch_shape + self.event_shape)
super(Dirichlet, new).__init__(batch_shape, self.event_shape, validate_args=False)
new._validate_args = self._validate_args
return new

def rsample(self, sample_shape=()):
# TODO:
# shape = self._extended_shape(sample_shape)
# concentration = self.concentration.expand(shape)
# return _Dirichlet.apply(concentration)
raise NotImplementedError("Currently, `Dirichlet.rsample` is not implemented.")

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
return ((torch_log(value) * (self.concentration - 1.0)).sum(-1) +
lgamma(self.concentration.sum(-1)) -
lgamma(self.concentration).sum(-1))

@property
def mean(self):
return self.concentration / self.concentration.sum(-1, True)

@property
def mode(self):
concentrationm1 = (self.concentration - 1).clamp(min=0.)
mode = concentrationm1 / concentrationm1.sum(-1, True)
mask = (self.concentration < 1).all(axis=-1)
mode[mask] = one_hot(mode[mask].argmax(axis=-1), concentrationm1.shape[-1]).to(mode)
return mode

@property
def variance(self):
con0 = self.concentration.sum(-1, True)
return self.concentration * (con0 - self.concentration) / (con0.pow(2) * (con0 + 1))

def entropy(self):
k = self.concentration.size(-1)
a0 = self.concentration.sum(-1)
return (lgamma(self.concentration).sum(-1) - lgamma(a0) -
(k - a0) * digamma(a0) -
((self.concentration - 1.0) * digamma(self.concentration)).sum(-1))

@property
def _natural_params(self):
return (self.concentration, )

def _log_normalizer(self, x):
return x.lgamma().sum(-1) - lgamma(x.sum(-1))
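
Usage sketch for `Dirichlet` (untested). `rsample` is still a TODO on this branch and raises `NotImplementedError`, but the analytic quantities and `log_prob` work:

```python
# Untested sketch of the Dirichlet API defined above.
from mindtorch.torch.tensor import tensor
from mindtorch.torch.distributions import Dirichlet

d = Dirichlet(tensor([2.0, 3.0, 5.0]))
print(d.mean)                               # [0.2, 0.3, 0.5]
print(d.log_prob(tensor([0.2, 0.3, 0.5])))
print(d.entropy())
```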

mindtorch/torch/distributions/distribution.py (+166 -0)

@@ -0,0 +1,166 @@
import warnings
from typing import Dict, Optional, Any
from mindspore import _no_grad as torch_no_grad
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.utils import lazy_property
from mindtorch.torch._C.Size import Size
from mindtorch.torch.tensor import Tensor as torch_Tensor
from mindtorch.torch.functional import exp as torch_exp


class Distribution():
has_rsample = False
has_enumerate_support = False
_validate_args = __debug__

@staticmethod
def set_default_validate_args(value):
if value not in [True, False]:
raise ValueError
Distribution._validate_args = value

def __init__(self, batch_shape=Size(), event_shape=Size(), validate_args=None):
self._batch_shape = batch_shape
self._event_shape = event_shape
if validate_args is not None:
self._validate_args = validate_args
if self._validate_args:
try:
arg_constraints = self.arg_constraints
except NotImplementedError:
arg_constraints = {}
warnings.warn(f'{self.__class__} does not define `arg_constraints`. ' +
'Please set `arg_constraints = {}` or initialize the distribution ' +
'with `validate_args=False` to turn off validation.')
for param, constraint in arg_constraints.items():
if constraints.is_dependent(constraint):
continue # skip constraints that cannot be checked
if param not in self.__dict__ and isinstance(getattr(type(self), param), lazy_property):
continue # skip checking lazily-constructed args
value = getattr(self, param)
valid = constraint.check(value)
if not valid.all():
raise ValueError(
f"Expected parameter {param} "
f"({type(value).__name__} of shape {tuple(value.shape)}) "
f"of distribution {repr(self)} "
f"to satisfy the constraint {repr(constraint)}, "
f"but found invalid values:\n{value}"
)
super(Distribution, self).__init__()

def expand(self, batch_shape, _instance=None):
raise NotImplementedError

@property
def batch_shape(self):
return self._batch_shape

@property
def event_shape(self):
return self._event_shape

@property
def arg_constraints(self) -> Dict[str, constraints.Constraint]:
raise NotImplementedError

@property
def support(self) -> Optional[Any]:
raise NotImplementedError

@property
def mean(self):
raise NotImplementedError

@property
def mode(self):
raise NotImplementedError(f"{self.__class__} does not implement mode")

@property
def variance(self):
raise NotImplementedError

@property
def stddev(self):
return self.variance.sqrt()

def sample(self, sample_shape=Size()):
with torch_no_grad():
return self.rsample(sample_shape)

def rsample(self, sample_shape=Size()):
raise NotImplementedError

def sample_n(self, n):
warnings.warn('sample_n will be deprecated. Use .sample((n,)) instead', UserWarning)
return self.sample(Size((n,)))

def log_prob(self, value):
raise NotImplementedError

def cdf(self, value):
raise NotImplementedError

def icdf(self, value):
raise NotImplementedError

def enumerate_support(self, expand=True):
raise NotImplementedError

def entropy(self):
raise NotImplementedError

def perplexity(self):
return torch_exp(self.entropy())

def _extended_shape(self, sample_shape=Size()):
if not isinstance(sample_shape, Size):
sample_shape = Size(sample_shape)
return sample_shape + self._batch_shape + self._event_shape

def _validate_sample(self, value):
if not isinstance(value, torch_Tensor):
raise ValueError('The value argument to log_prob must be a Tensor')

event_dim_start = len(value.size()) - len(self._event_shape)
if value.size()[event_dim_start:] != self._event_shape:
raise ValueError('The right-most size of value must match event_shape: {} vs {}.'.
format(value.size(), self._event_shape))

actual_shape = value.size()
expected_shape = self._batch_shape + self._event_shape
for i, j in zip(reversed(actual_shape), reversed(expected_shape)):
if i != 1 and j != 1 and i != j:
raise ValueError('Value is not broadcastable with batch_shape+event_shape: {} vs {}.'.
format(actual_shape, expected_shape))
try:
support = self.support
except NotImplementedError:
warnings.warn(f'{self.__class__} does not define `support` to enable ' +
'sample validation. Please initialize the distribution with ' +
'`validate_args=False` to turn off validation.')
return
assert support is not None
valid = support.check(value)
if not valid.all():
raise ValueError(
"Expected value argument "
f"({type(value).__name__} of shape {tuple(value.shape)}) "
f"to be within the support ({repr(support)}) "
f"of the distribution {repr(self)}, "
f"but found invalid values:\n{value}"
)

def _get_checked_instance(self, cls, _instance=None):
if _instance is None and type(self).__init__ != cls.__init__:
raise NotImplementedError("Subclass {} of {} that defines a custom __init__ method "
"must also define a custom .expand() method.".
format(self.__class__.__name__, cls.__name__))
return self.__new__(type(self)) if _instance is None else _instance

def __repr__(self):
param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__]
args_string = ', '.join(['{}: {}'.format(p, self.__dict__[p]
if self.__dict__[p].numel() == 1
else self.__dict__[p].size()) for p in param_names])
return self.__class__.__name__ + '(' + args_string + ')'
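
Sketch of the argument and sample validation wired into the base class above (untested; `Bernoulli` is only used here as a convenient concrete subclass):

```python
# Untested sketch of the validation hooks in Distribution above.
from mindtorch.torch.tensor import tensor
from mindtorch.torch.distributions import Bernoulli

try:
    Bernoulli(probs=tensor([1.5]), validate_args=True)   # violates unit_interval
except ValueError as e:
    print(e)

d = Bernoulli(probs=tensor([0.5]), validate_args=True)
try:
    d.log_prob(tensor([2.0]))                            # outside boolean support
except ValueError as e:
    print(e)
```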

mindtorch/torch/distributions/exp_family.py (+32 -0)

@@ -0,0 +1,32 @@
from mindspore import grad
from mindtorch.torch.functional import cast_to_adapter_tensor
from mindtorch.torch.distributions.distribution import Distribution


class ExponentialFamily(Distribution):
@property
def _natural_params(self):
raise NotImplementedError

def _log_normalizer(self, *natural_params):
raise NotImplementedError

@property
def _mean_carrier_measure(self):
raise NotImplementedError

def entropy(self):
result = -self._mean_carrier_measure
nparams = [p.detach().requires_grad_() for p in self._natural_params]
lg_normal = self._log_normalizer(*nparams)
# TODO:gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True)
def fn(*nparams):
lg_normal = self._log_normalizer(*nparams)
return lg_normal.sum()
gradients = grad(fn, tuple(range(len(nparams))))(*nparams)
gradients = cast_to_adapter_tensor(gradients)

result += lg_normal
for np, g in zip(nparams, gradients):
result -= (np * g).reshape(self._batch_shape + (-1,)).sum(-1)
return result

+ 74
- 0
mindtorch/torch/distributions/exponential.py View File

@@ -0,0 +1,74 @@
from numbers import Number

from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.exp_family import ExponentialFamily
from mindtorch.torch.distributions.utils import broadcast_all
from mindtorch.torch.functional import zeros_like, log as torch_log, exp as torch_exp, rand as torch_rand
from mindtorch.torch._C.Size import Size
from mindtorch.torch._C import _get_tracing_state

class Exponential(ExponentialFamily):
arg_constraints = {'rate': constraints.positive}
support = constraints.nonnegative
has_rsample = True
_mean_carrier_measure = 0

@property
def mean(self):
return self.rate.reciprocal()

@property
def mode(self):
return zeros_like(self.rate)

@property
def stddev(self):
return self.rate.reciprocal()

@property
def variance(self):
return self.rate.pow(-2)

def __init__(self, rate, validate_args=None):
self.rate, = broadcast_all(rate)
batch_shape = Size() if isinstance(rate, Number) else self.rate.size()
super(Exponential, self).__init__(batch_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Exponential, _instance)
batch_shape = Size(batch_shape)
new.rate = self.rate.expand(batch_shape)
super(Exponential, new).__init__(batch_shape, validate_args=False)
new._validate_args = self._validate_args
return new

def rsample(self, sample_shape=Size()):
shape = self._extended_shape(sample_shape)
if _get_tracing_state():
# [JIT WORKAROUND] lack of support for ._exponential()
u = torch_rand(shape, dtype=self.rate.dtype, device=self.rate.device)
return -(-u).log1p() / self.rate
return self.rate.new(shape).exponential_() / self.rate

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
return self.rate.log() - self.rate * value

def cdf(self, value):
if self._validate_args:
self._validate_sample(value)
return 1 - torch_exp(-self.rate * value)

def icdf(self, value):
return -torch_log(1 - value) / self.rate # pylint: disable=E1130

def entropy(self):
return 1.0 - torch_log(self.rate)

@property
def _natural_params(self):
return (-self.rate, )

def _log_normalizer(self, x):
return -torch_log(-x) # pylint: disable=E1130
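
A minimal usage sketch (illustrative only; it assumes broadcast_all promotes Python scalars to adapter tensors as in torch, and that `.new()`/`.exponential_()` are supported for sampling):
    >>> from mindtorch.torch.distributions.exponential import Exponential
    >>> m = Exponential(rate=2.0)
    >>> m.mean           # 1 / rate = 0.5
    >>> m.entropy()      # 1 - log(rate)
    >>> m.rsample((3,))  # three reparameterized draws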

+ 77
- 0
mindtorch/torch/distributions/fishersnedecor.py View File

@@ -0,0 +1,77 @@
from numbers import Number

from mindtorch.torch._six import nan
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.distribution import Distribution
from mindtorch.torch.distributions.gamma import Gamma
from mindtorch.torch.distributions.utils import broadcast_all
from mindtorch.torch._C.Size import Size
from mindtorch.torch.common.dtype import finfo
from mindtorch.torch.functional import log as torch_log, log1p


class FisherSnedecor(Distribution):
arg_constraints = {'df1': constraints.positive, 'df2': constraints.positive}
support = constraints.positive
has_rsample = True

def __init__(self, df1, df2, validate_args=None):
self.df1, self.df2 = broadcast_all(df1, df2)
self._gamma1 = Gamma(self.df1 * 0.5, self.df1)
self._gamma2 = Gamma(self.df2 * 0.5, self.df2)

if isinstance(df1, Number) and isinstance(df2, Number):
batch_shape = Size()
else:
batch_shape = self.df1.size()
super(FisherSnedecor, self).__init__(batch_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(FisherSnedecor, _instance)
batch_shape = Size(batch_shape)
new.df1 = self.df1.expand(batch_shape)
new.df2 = self.df2.expand(batch_shape)
new._gamma1 = self._gamma1.expand(batch_shape)
new._gamma2 = self._gamma2.expand(batch_shape)
super(FisherSnedecor, new).__init__(batch_shape, validate_args=False)
new._validate_args = self._validate_args
return new

@property
def mean(self):
df2 = self.df2.clone()
df2[df2 <= 2] = nan
return df2 / (df2 - 2)

@property
def mode(self):
mode = (self.df1 - 2) / self.df1 * self.df2 / (self.df2 + 2)
mode[self.df1 <= 2] = nan
return mode

@property
def variance(self):
df2 = self.df2.clone()
df2[df2 <= 4] = nan
return 2 * df2.pow(2) * (self.df1 + df2 - 2) / (self.df1 * (df2 - 2).pow(2) * (df2 - 4))

def rsample(self, sample_shape=Size(())):
shape = self._extended_shape(sample_shape)
X1 = self._gamma1.rsample(sample_shape).view(shape)
X2 = self._gamma2.rsample(sample_shape).view(shape)
tiny = finfo(X2.dtype).tiny
X2.clamp_(min=tiny)
Y = X1 / X2
Y.clamp_(min=tiny)
return Y

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
ct1 = self.df1 * 0.5
ct2 = self.df2 * 0.5
ct3 = self.df1 / self.df2
t1 = (ct1 + ct2).lgamma() - ct1.lgamma() - ct2.lgamma()
t2 = ct1 * ct3.log() + (ct1 - 1) * torch_log(value)
t3 = (ct1 + ct2) * log1p(ct3 * value)
return t1 + t2 - t3
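
A moments-only sketch (illustrative; note that rsample delegates to Gamma.rsample, which is still unimplemented in this port (gamma.py below), so sampling is not yet usable here):
    >>> from mindtorch.torch.distributions.fishersnedecor import FisherSnedecor
    >>> m = FisherSnedecor(df1=5.0, df2=6.0)
    >>> m.mean       # df2 / (df2 - 2), defined only for df2 > 2
    >>> m.variance   # defined only for df2 > 4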

+ 76
- 0
mindtorch/torch/distributions/gamma.py View File

@@ -0,0 +1,76 @@
from numbers import Number

from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.exp_family import ExponentialFamily
from mindtorch.torch.distributions.utils import broadcast_all
from mindtorch.torch._C.Size import Size
# from mindtorch.torch.common.dtype import finfo
from mindtorch.torch.functional import as_tensor, xlogy, lgamma, digamma, log as torch_log
# from mindtorch.torch.functional import _standard_gamma


# def _standard_gamma(concentration):
# return _standard_gamma(concentration)


class Gamma(ExponentialFamily):
arg_constraints = {'concentration': constraints.positive, 'rate': constraints.positive}
support = constraints.nonnegative
has_rsample = True
_mean_carrier_measure = 0

@property
def mean(self):
return self.concentration / self.rate

@property
def mode(self):
return ((self.concentration - 1) / self.rate).clamp(min=0)

@property
def variance(self):
return self.concentration / self.rate.pow(2)

def __init__(self, concentration, rate, validate_args=None):
self.concentration, self.rate = broadcast_all(concentration, rate)
if isinstance(concentration, Number) and isinstance(rate, Number):
batch_shape = Size()
else:
batch_shape = self.concentration.size()
super(Gamma, self).__init__(batch_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Gamma, _instance)
batch_shape = Size(batch_shape)
new.concentration = self.concentration.expand(batch_shape)
new.rate = self.rate.expand(batch_shape)
super(Gamma, new).__init__(batch_shape, validate_args=False)
new._validate_args = self._validate_args
return new

def rsample(self, sample_shape=Size()):
# TODO:
# shape = self._extended_shape(sample_shape)
# value = _standard_gamma(self.concentration.expand(shape)) / self.rate.expand(shape)
# value.detach().clamp_(min=finfo(value.dtype).tiny) # do not record in autograd graph
# return value
raise NotImplementedError(f"Currently, `{self.__class__}.rsample` is not implemented.")

def log_prob(self, value):
value = as_tensor(value, dtype=self.rate.dtype, device=self.rate.device)
if self._validate_args:
self._validate_sample(value)
return (xlogy(self.concentration, self.rate) +
xlogy(self.concentration - 1, value) -
self.rate * value - lgamma(self.concentration))

def entropy(self):
return (self.concentration - torch_log(self.rate) + lgamma(self.concentration) +
(1.0 - self.concentration) * digamma(self.concentration))

@property
def _natural_params(self):
return (self.concentration - 1, -self.rate)

def _log_normalizer(self, x, y):
return lgamma(x + 1) + (x + 1) * torch_log(-y.reciprocal())
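
A density/statistics sketch (illustrative; rsample above intentionally raises NotImplementedError for now, so only moments, log_prob and entropy are exercised):
    >>> from mindtorch.torch.distributions.gamma import Gamma
    >>> m = Gamma(concentration=2.0, rate=1.0)
    >>> m.mean            # concentration / rate = 2.0
    >>> m.log_prob(1.5)   # the value is promoted via as_tensor before validation
    >>> m.entropy()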

+ 97
- 0
mindtorch/torch/distributions/geometric.py View File

@@ -0,0 +1,97 @@
from numbers import Number

from mindspore import _no_grad as torch_no_grad
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.distribution import Distribution
from mindtorch.torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property
from mindtorch.torch.nn.functional import binary_cross_entropy_with_logits
from mindtorch.torch._C.Size import Size
from mindtorch.torch.common.dtype import finfo
from mindtorch.torch.functional import zeros_like, rand
from mindtorch.torch._C import _get_tracing_state


class Geometric(Distribution):
arg_constraints = {'probs': constraints.unit_interval,
'logits': constraints.real}
support = constraints.nonnegative_integer

def __init__(self, probs=None, logits=None, validate_args=None):
if (probs is None) == (logits is None):
raise ValueError("Either `probs` or `logits` must be specified, but not both.")
if probs is not None:
self.probs, = broadcast_all(probs)
else:
self.logits, = broadcast_all(logits)
probs_or_logits = probs if probs is not None else logits
if isinstance(probs_or_logits, Number):
batch_shape = Size()
else:
batch_shape = probs_or_logits.size()
super(Geometric, self).__init__(batch_shape, validate_args=validate_args)
if self._validate_args and probs is not None:
# Add an extra check beyond unit_interval
value = self.probs
valid = value > 0
if not valid.all():
invalid_value = value.data[~valid]
raise ValueError(
"Expected parameter probs "
f"({type(value).__name__} of shape {tuple(value.shape)}) "
f"of distribution {repr(self)} "
f"to be positive but found invalid values:\n{invalid_value}"
)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Geometric, _instance)
batch_shape = Size(batch_shape)
if 'probs' in self.__dict__:
new.probs = self.probs.expand(batch_shape)
if 'logits' in self.__dict__:
new.logits = self.logits.expand(batch_shape)
super(Geometric, new).__init__(batch_shape, validate_args=False)
new._validate_args = self._validate_args
return new

@property
def mean(self):
return 1. / self.probs - 1.

@property
def mode(self):
return zeros_like(self.probs)

@property
def variance(self):
return (1. / self.probs - 1.) / self.probs

@lazy_property
def logits(self): # pylint: disable=E0202
return probs_to_logits(self.probs, is_binary=True)

@lazy_property
def probs(self): # pylint: disable=E0202
return logits_to_probs(self.logits, is_binary=True)

def sample(self, sample_shape=Size()):
shape = self._extended_shape(sample_shape)
tiny = finfo(self.probs.dtype).tiny
with torch_no_grad():
if _get_tracing_state():
# [JIT WORKAROUND] lack of support for .uniform_()
u = rand(shape, dtype=self.probs.dtype, device=self.probs.device)
u = u.clamp(min=tiny)
else:
u = self.probs.new(shape).uniform_(tiny, 1)
return (u.log() / (-self.probs).log1p()).floor()

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
value, probs = broadcast_all(value, self.probs)
probs = probs.clone()
probs[(probs == 1) & (value == 0)] = 0
return value * (-probs).log1p() + self.probs.log()

def entropy(self):
return binary_cross_entropy_with_logits(self.logits, self.probs, reduction='none') / self.probs
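
A usage sketch (illustrative; the distribution counts failures before the first success, and sampling assumes `.uniform_()` is available on adapter tensors):
    >>> from mindtorch.torch.distributions.geometric import Geometric
    >>> m = Geometric(probs=0.25)
    >>> m.mean          # (1 - p) / p = 3.0
    >>> m.sample((4,))  # floor(log(u) / log1p(-p)) for u ~ Uniform(tiny, 1)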

+ 57
- 0
mindtorch/torch/distributions/gumbel.py View File

@@ -0,0 +1,57 @@
from numbers import Number
import math
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.uniform import Uniform
from mindtorch.torch.distributions.transformed_distribution import TransformedDistribution
from mindtorch.torch.distributions.transforms import AffineTransform, ExpTransform
from mindtorch.torch.distributions.utils import broadcast_all, euler_constant
from mindtorch.torch.common.dtype import finfo as torch_finfo
from mindtorch.torch.functional import full_like, ones_like


class Gumbel(TransformedDistribution):
arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
support = constraints.real

def __init__(self, loc, scale, validate_args=None):
self.loc, self.scale = broadcast_all(loc, scale)
finfo = torch_finfo(self.loc.dtype)
if isinstance(loc, Number) and isinstance(scale, Number):
base_dist = Uniform(finfo.tiny, 1 - finfo.eps)
else:
base_dist = Uniform(full_like(self.loc, finfo.tiny),
full_like(self.loc, 1 - finfo.eps))
transforms = [ExpTransform().inv, AffineTransform(loc=0, scale=-ones_like(self.scale)),
ExpTransform().inv, AffineTransform(loc=loc, scale=-self.scale)]
super(Gumbel, self).__init__(base_dist, transforms, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Gumbel, _instance)
new.loc = self.loc.expand(batch_shape)
new.scale = self.scale.expand(batch_shape)
return super(Gumbel, self).expand(batch_shape, _instance=new)

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
y = (self.loc - value) / self.scale
return (y - y.exp()) - self.scale.log()

@property
def mean(self):
return self.loc + self.scale * euler_constant

@property
def mode(self):
return self.loc

@property
def stddev(self):
return (math.pi / math.sqrt(6)) * self.scale

@property
def variance(self):
return self.stddev.pow(2)

def entropy(self):
return self.scale.log() + (1 + euler_constant)
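
A moments sketch (illustrative; scalar arguments are assumed to broadcast to adapter tensors as in torch):
    >>> from mindtorch.torch.distributions.gumbel import Gumbel
    >>> m = Gumbel(loc=1.0, scale=2.0)
    >>> m.mean        # loc + scale * euler_constant
    >>> m.stddev      # pi / sqrt(6) * scale
    >>> m.entropy()   # log(scale) + 1 + euler_constant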

+ 58
- 0
mindtorch/torch/distributions/half_cauchy.py View File

@@ -0,0 +1,58 @@
import math

from mindtorch.torch._six import inf
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.transforms import AbsTransform
from mindtorch.torch.distributions.cauchy import Cauchy
from mindtorch.torch.distributions.transformed_distribution import TransformedDistribution
from mindtorch.torch.functional import full, zeros_like, as_tensor


class HalfCauchy(TransformedDistribution):
arg_constraints = {'scale': constraints.positive}
support = constraints.nonnegative
has_rsample = True

def __init__(self, scale, validate_args=None):
base_dist = Cauchy(0, scale, validate_args=False)
super(HalfCauchy, self).__init__(base_dist, AbsTransform(),
validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(HalfCauchy, _instance)
return super(HalfCauchy, self).expand(batch_shape, _instance=new)

@property
def scale(self):
return self.base_dist.scale

@property
def mean(self):
return full(self._extended_shape(), math.inf, dtype=self.scale.dtype, device=self.scale.device)

@property
def mode(self):
return zeros_like(self.scale)

@property
def variance(self):
return self.base_dist.variance

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
value = as_tensor(value, dtype=self.base_dist.scale.dtype, device=self.base_dist.scale.device)
log_prob = self.base_dist.log_prob(value) + math.log(2)
log_prob[value.expand(log_prob.shape) < 0] = -inf
return log_prob

def cdf(self, value):
if self._validate_args:
self._validate_sample(value)
return 2 * self.base_dist.cdf(value) - 1

def icdf(self, prob):
return self.base_dist.icdf((prob + 1) / 2)

def entropy(self):
return self.base_dist.entropy() - math.log(2)
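
HalfCauchy folds Cauchy(0, scale) through AbsTransform, so its cdf/icdf are simple rescalings of the base distribution's. A construction sketch (illustrative):
    >>> from mindtorch.torch.distributions.half_cauchy import HalfCauchy
    >>> m = HalfCauchy(scale=1.0)
    >>> m.mode        # zeros_like(scale)
    >>> m.mean        # +inf everywhere: the half-Cauchy has no finite mean
    >>> m.entropy()   # base Cauchy entropy minus log(2)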

+ 57
- 0
mindtorch/torch/distributions/half_normal.py View File

@@ -0,0 +1,57 @@
import math

from mindtorch.torch._six import inf
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.transforms import AbsTransform
from mindtorch.torch.distributions.normal import Normal
from mindtorch.torch.distributions.transformed_distribution import TransformedDistribution
from mindtorch.torch.functional import zeros_like


class HalfNormal(TransformedDistribution):
arg_constraints = {'scale': constraints.positive}
support = constraints.nonnegative
has_rsample = True

def __init__(self, scale, validate_args=None):
base_dist = Normal(0, scale, validate_args=False)
super(HalfNormal, self).__init__(base_dist, AbsTransform(),
validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(HalfNormal, _instance)
return super(HalfNormal, self).expand(batch_shape, _instance=new)

@property
def scale(self):
return self.base_dist.scale

@property
def mean(self):
return self.scale * math.sqrt(2 / math.pi)

@property
def mode(self):
return zeros_like(self.scale)

@property
def variance(self):
return self.scale.pow(2) * (1 - 2 / math.pi)

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
log_prob = self.base_dist.log_prob(value) + math.log(2)
log_prob[value.expand(log_prob.shape) < 0] = -inf
return log_prob

def cdf(self, value):
if self._validate_args:
self._validate_sample(value)
return 2 * self.base_dist.cdf(value) - 1

def icdf(self, prob):
return self.base_dist.icdf((prob + 1) / 2)

def entropy(self):
return self.base_dist.entropy() - math.log(2)
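
HalfNormal mirrors the HalfCauchy port above but folds Normal(0, scale). For reference (illustrative):
    >>> from mindtorch.torch.distributions.half_normal import HalfNormal
    >>> m = HalfNormal(scale=2.0)
    >>> m.mean        # scale * sqrt(2 / pi)
    >>> m.variance    # scale**2 * (1 - 2 / pi)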

+ 83
- 0
mindtorch/torch/distributions/independent.py View File

@@ -0,0 +1,83 @@
from typing import Dict
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.distribution import Distribution
from mindtorch.torch.distributions.utils import _sum_rightmost
from mindtorch.torch._C.Size import Size


class Independent(Distribution):
arg_constraints: Dict[str, constraints.Constraint] = {}

def __init__(self, base_distribution, reinterpreted_batch_ndims, validate_args=None):
if reinterpreted_batch_ndims > len(base_distribution.batch_shape):
raise ValueError("Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), "
"actual {} vs {}".format(reinterpreted_batch_ndims,
len(base_distribution.batch_shape)))
shape = base_distribution.batch_shape + base_distribution.event_shape
event_dim = reinterpreted_batch_ndims + len(base_distribution.event_shape)
batch_shape = shape[:len(shape) - event_dim]
event_shape = shape[len(shape) - event_dim:]
self.base_dist = base_distribution
self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
super(Independent, self).__init__(batch_shape, event_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Independent, _instance)
batch_shape = Size(batch_shape)
new.base_dist = self.base_dist.expand(batch_shape +
self.event_shape[:self.reinterpreted_batch_ndims])
new.reinterpreted_batch_ndims = self.reinterpreted_batch_ndims
super(Independent, new).__init__(batch_shape, self.event_shape, validate_args=False)
new._validate_args = self._validate_args
return new

@property
def has_rsample(self):
return self.base_dist.has_rsample

@property
def has_enumerate_support(self):
if self.reinterpreted_batch_ndims > 0:
return False
return self.base_dist.has_enumerate_support

@constraints.dependent_property
def support(self):
result = self.base_dist.support
if self.reinterpreted_batch_ndims:
result = constraints.independent(result, self.reinterpreted_batch_ndims)
return result

@property
def mean(self):
return self.base_dist.mean

@property
def mode(self):
return self.base_dist.mode

@property
def variance(self):
return self.base_dist.variance

def sample(self, sample_shape=Size()):
return self.base_dist.sample(sample_shape)

def rsample(self, sample_shape=Size()):
return self.base_dist.rsample(sample_shape)

def log_prob(self, value):
log_prob = self.base_dist.log_prob(value)
return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)

def entropy(self):
entropy = self.base_dist.entropy()
return _sum_rightmost(entropy, self.reinterpreted_batch_ndims)

def enumerate_support(self, expand=True):
if self.reinterpreted_batch_ndims > 0:
raise NotImplementedError("Enumeration over cartesian product is not implemented")
return self.base_dist.enumerate_support(expand=expand)

def __repr__(self):
return self.__class__.__name__ + '({}, {})'.format(self.base_dist, self.reinterpreted_batch_ndims)
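
A reinterpretation sketch (illustrative; expand and Laplace sampling are assumed to behave as in the ports above):
    >>> from mindtorch.torch.distributions.independent import Independent
    >>> from mindtorch.torch.distributions.laplace import Laplace
    >>> base = Laplace(0.0, 1.0).expand((3,))   # batch_shape (3,), event_shape ()
    >>> d = Independent(base, 1)                # batch_shape (), event_shape (3,)
    >>> d.log_prob(d.sample())                  # scalar: per-dimension log-probs summed over the last axis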

+ 784
- 0
mindtorch/torch/distributions/kl.py View File

@@ -0,0 +1,784 @@
import math
import warnings
from functools import total_ordering
from typing import Type, Dict, Callable, Tuple

from mindspore import grad
from mindtorch.torch._six import inf
from mindtorch.torch.functional import full_like, digamma, log1p, lgamma, exp as torch_exp, diag_embed, \
log as torch_log, square as torch_square, ge as torch_ge, le as torch_le, ones_like, max as torch_max, \
where, erf, cast_to_adapter_tensor
from mindtorch.torch.linalg import solve_triangular
from mindtorch.torch._C import _infer_size
from mindtorch.utils import unsupported_attr

from .bernoulli import Bernoulli
from .beta import Beta
from .binomial import Binomial
from .categorical import Categorical
from .cauchy import Cauchy
from .continuous_bernoulli import ContinuousBernoulli
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exponential import Exponential
from .exp_family import ExponentialFamily
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .half_normal import HalfNormal
from .independent import Independent
from .laplace import Laplace
from .lowrank_multivariate_normal import (LowRankMultivariateNormal, _batch_lowrank_logdet,
_batch_lowrank_mahalanobis)
from .multivariate_normal import (MultivariateNormal, _batch_mahalanobis)
from .normal import Normal
from .one_hot_categorical import OneHotCategorical
from .pareto import Pareto
from .poisson import Poisson
from .transformed_distribution import TransformedDistribution
from .uniform import Uniform
from .utils import _sum_rightmost, euler_constant as _euler_gamma

# Source of truth mapping a few general (type, type) pairs to functions.
_KL_REGISTRY = {}
# Memoized version mapping many specific (type, type) pairs to functions.
_KL_MEMOIZE: Dict[Tuple[Type, Type], Callable] = {}


def register_kl(type_p, type_q):
if not isinstance(type_p, type) or not issubclass(type_p, Distribution):
raise TypeError('Expected type_p to be a Distribution subclass but got {}'.format(type_p))
if not isinstance(type_q, type) or not issubclass(type_q, Distribution):
raise TypeError('Expected type_q to be a Distribution subclass but got {}'.format(type_q))

def decorator(fun):
_KL_REGISTRY[type_p, type_q] = fun
_KL_MEMOIZE.clear() # reset since lookup order may have changed
return fun

return decorator


@total_ordering
class _Match():
__slots__ = ['types']

def __init__(self, *types):
self.types = types

def __eq__(self, other):
return self.types == other.types

def __le__(self, other):
for x, y in zip(self.types, other.types):
if not issubclass(x, y):
return False
if x is not y:
break
return True


def _dispatch_kl(type_p, type_q):
matches = [(super_p, super_q) for super_p, super_q in _KL_REGISTRY
if issubclass(type_p, super_p) and issubclass(type_q, super_q)]
if not matches:
return NotImplemented
left_p, left_q = min(_Match(*m) for m in matches).types # type: ignore[type-var]
right_q, right_p = min(_Match(*reversed(m)) for m in matches).types # type: ignore[type-var]
left_fun = _KL_REGISTRY[left_p, left_q]
right_fun = _KL_REGISTRY[right_p, right_q]
if left_fun is not right_fun:
warnings.warn('Ambiguous kl_divergence({}, {}). Please register_kl({}, {})'.format(
type_p.__name__, type_q.__name__, left_p.__name__, right_q.__name__),
RuntimeWarning)
return left_fun


def _infinite_like(tensor):
return full_like(tensor, inf)


def _x_log_x(tensor):
return tensor * tensor.log()


def _batch_trace_XXT(bmat):
n = bmat.size(-1)
m = bmat.size(-2)
flat_trace = bmat.reshape(-1, m * n).pow(2).sum(-1)
return flat_trace.reshape(bmat.shape[:-2])


def kl_divergence(p, q):
try:
fun = _KL_MEMOIZE[type(p), type(q)]
except KeyError:
fun = _dispatch_kl(type(p), type(q))
_KL_MEMOIZE[type(p), type(q)] = fun
if fun is NotImplemented:
raise NotImplementedError("No KL(p || q) is implemented for p type {} and q type {}"
.format(p.__class__.__name__, q.__class__.__name__))
return fun(p, q)
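
A dispatch sketch (illustrative; Normal is the port registered below, while the decorated pair uses a hypothetical placeholder subclass purely to show how the registry is extended):
    >>> from mindtorch.torch.distributions.normal import Normal
    >>> kl_divergence(Normal(0.0, 1.0), Normal(1.0, 2.0))   # resolved to _kl_normal_normal below
    >>> @register_kl(MyDist, Normal)                        # MyDist: placeholder Distribution subclass
    ... def _kl_mydist_normal(p, q):
    ...     ...
Lookup picks the most specific registered (type_p, type_q) pair and warns when two registrations are equally specific.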


################################################################################
# KL Divergence Implementations
################################################################################

# Same distributions


@register_kl(Bernoulli, Bernoulli)
def _kl_bernoulli_bernoulli(p, q):
t1 = p.probs * (p.probs / q.probs).log()
t1[q.probs == 0] = inf
t1[p.probs == 0] = 0
t2 = (1 - p.probs) * ((1 - p.probs) / (1 - q.probs)).log()
t2[q.probs == 1] = inf
t2[p.probs == 1] = 0
return t1 + t2


@register_kl(Beta, Beta)
def _kl_beta_beta(p, q):
sum_params_p = p.concentration1 + p.concentration0
sum_params_q = q.concentration1 + q.concentration0
t1 = q.concentration1.lgamma() + q.concentration0.lgamma() + (sum_params_p).lgamma()
t2 = p.concentration1.lgamma() + p.concentration0.lgamma() + (sum_params_q).lgamma()
t3 = (p.concentration1 - q.concentration1) * digamma(p.concentration1)
t4 = (p.concentration0 - q.concentration0) * digamma(p.concentration0)
t5 = (sum_params_q - sum_params_p) * digamma(sum_params_p)
return t1 - t2 + t3 + t4 + t5


@register_kl(Binomial, Binomial)
def _kl_binomial_binomial(p, q):
if (p.total_count < q.total_count).any():
raise NotImplementedError('KL between Binomials where q.total_count > p.total_count is not implemented')
kl = p.total_count * (p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p())
inf_idxs = p.total_count > q.total_count
kl[inf_idxs] = _infinite_like(kl[inf_idxs])
return kl


@register_kl(Categorical, Categorical)
def _kl_categorical_categorical(p, q):
t = p.probs * (p.logits - q.logits)
t[(q.probs == 0).expand_as(t)] = inf
t[(p.probs == 0).expand_as(t)] = 0
return t.sum(-1)


@register_kl(ContinuousBernoulli, ContinuousBernoulli)
def _kl_continuous_bernoulli_continuous_bernoulli(p, q):
t1 = p.mean * (p.logits - q.logits)
t2 = p._cont_bern_log_norm() + log1p(-p.probs)
t3 = - q._cont_bern_log_norm() - log1p(-q.probs)
return t1 + t2 + t3


@register_kl(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(p, q):
sum_p_concentration = p.concentration.sum(-1)
sum_q_concentration = q.concentration.sum(-1)
t1 = sum_p_concentration.lgamma() - sum_q_concentration.lgamma()
t2 = (p.concentration.lgamma() - q.concentration.lgamma()).sum(-1)
t3 = p.concentration - q.concentration
t4 = p.concentration.digamma() - sum_p_concentration.digamma().unsqueeze(-1)
return t1 - t2 + (t3 * t4).sum(-1)


@register_kl(Exponential, Exponential)
def _kl_exponential_exponential(p, q):
rate_ratio = q.rate / p.rate
t1 = -rate_ratio.log()
return t1 + rate_ratio - 1


@register_kl(ExponentialFamily, ExponentialFamily)
def _kl_expfamily_expfamily(p, q):
if not isinstance(p, type(q)):
raise NotImplementedError("The cross KL-divergence between different exponential families cannot \
be computed using Bregman divergences")
p_nparams = [np.detach().requires_grad_() for np in p._natural_params]
q_nparams = q._natural_params
lg_normal = p._log_normalizer(*p_nparams)

# TODO: gradients = torch.autograd.grad(lg_normal.sum(), p_nparams, create_graph=True)
def fn(*p_nparams):
lg_normal = p._log_normalizer(*p_nparams)
return lg_normal.sum()
gradients = grad(fn, tuple(range(len(p_nparams))))(*p_nparams)
gradients = cast_to_adapter_tensor(gradients)

result = q._log_normalizer(*q_nparams) - lg_normal
for pnp, qnp, g in zip(p_nparams, q_nparams, gradients):
term = (qnp - pnp) * g
result -= _sum_rightmost(term, len(q.event_shape))
return result


@register_kl(Gamma, Gamma)
def _kl_gamma_gamma(p, q):
t1 = q.concentration * (p.rate / q.rate).log()
t2 = lgamma(q.concentration) - lgamma(p.concentration)
t3 = (p.concentration - q.concentration) * digamma(p.concentration)
t4 = (q.rate - p.rate) * (p.concentration / p.rate)
return t1 + t2 + t3 + t4


@register_kl(Gumbel, Gumbel)
def _kl_gumbel_gumbel(p, q):
ct1 = p.scale / q.scale
ct2 = q.loc / q.scale
ct3 = p.loc / q.scale
t1 = -ct1.log() - ct2 + ct3
t2 = ct1 * _euler_gamma
t3 = torch_exp(ct2 + (1 + ct1).lgamma() - ct3)
return t1 + t2 + t3 - (1 + _euler_gamma)


@register_kl(Geometric, Geometric)
def _kl_geometric_geometric(p, q):
return -p.entropy() - log1p(-q.probs) / p.probs - q.logits


@register_kl(HalfNormal, HalfNormal)
def _kl_halfnormal_halfnormal(p, q):
return _kl_normal_normal(p.base_dist, q.base_dist)


@register_kl(Laplace, Laplace)
def _kl_laplace_laplace(p, q):
scale_ratio = p.scale / q.scale
loc_abs_diff = (p.loc - q.loc).abs()
t1 = -scale_ratio.log()
t2 = loc_abs_diff / q.scale
t3 = scale_ratio * torch_exp(-loc_abs_diff / p.scale)
return t1 + t2 + t3 - 1


@register_kl(LowRankMultivariateNormal, LowRankMultivariateNormal)
def _kl_lowrankmultivariatenormal_lowrankmultivariatenormal(p, q):
if p.event_shape != q.event_shape:
raise ValueError("KL-divergence between two Low Rank Multivariate Normals with\
different event shapes cannot be computed")

term1 = (_batch_lowrank_logdet(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
q._capacitance_tril) -
_batch_lowrank_logdet(p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag,
p._capacitance_tril))
term3 = _batch_lowrank_mahalanobis(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
q.loc - p.loc,
q._capacitance_tril)
# Expands term2 according to
# inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ (pW @ pW.T + pD)
# = [inv(qD) - A.T @ A] @ (pD + pW @ pW.T)
qWt_qDinv = (q._unbroadcasted_cov_factor.mT /
q._unbroadcasted_cov_diag.unsqueeze(-2))
A = solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False)
term21 = (p._unbroadcasted_cov_diag / q._unbroadcasted_cov_diag).sum(-1)
term22 = _batch_trace_XXT(p._unbroadcasted_cov_factor *
q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1))
term23 = _batch_trace_XXT(A * p._unbroadcasted_cov_diag.sqrt().unsqueeze(-2))
term24 = _batch_trace_XXT(A.matmul(p._unbroadcasted_cov_factor))
term2 = term21 + term22 - term23 - term24
return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(MultivariateNormal, LowRankMultivariateNormal)
def _kl_multivariatenormal_lowrankmultivariatenormal(p, q):
if p.event_shape != q.event_shape:
raise ValueError("KL-divergence between two (Low Rank) Multivariate Normals with\
different event shapes cannot be computed")

term1 = (_batch_lowrank_logdet(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
q._capacitance_tril) -
2 * p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1))
term3 = _batch_lowrank_mahalanobis(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
q.loc - p.loc,
q._capacitance_tril)
# Expands term2 according to
# inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ p_tril @ p_tril.T
# = [inv(qD) - A.T @ A] @ p_tril @ p_tril.T
qWt_qDinv = (q._unbroadcasted_cov_factor.mT /
q._unbroadcasted_cov_diag.unsqueeze(-2))
A = solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False)
term21 = _batch_trace_XXT(p._unbroadcasted_scale_tril *
q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1))
term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril))
term2 = term21 - term22
return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(LowRankMultivariateNormal, MultivariateNormal)
def _kl_lowrankmultivariatenormal_multivariatenormal(p, q):
if p.event_shape != q.event_shape:
raise ValueError("KL-divergence between two (Low Rank) Multivariate Normals with\
different event shapes cannot be computed")

term1 = (2 * q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) -
_batch_lowrank_logdet(p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag,
p._capacitance_tril))
term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
# Expands term2 according to
# inv(qcov) @ pcov = inv(q_tril @ q_tril.T) @ (pW @ pW.T + pD)
combined_batch_shape = _infer_size(q._unbroadcasted_scale_tril.shape[:-2],
p._unbroadcasted_cov_factor.shape[:-2])
n = p.event_shape[0]
q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
p_cov_factor = p._unbroadcasted_cov_factor.expand(combined_batch_shape +
(n, p.cov_factor.size(-1)))
p_cov_diag = (diag_embed(p._unbroadcasted_cov_diag.sqrt())
.expand(combined_batch_shape + (n, n)))
term21 = _batch_trace_XXT(solve_triangular(q_scale_tril, p_cov_factor, upper=False))
term22 = _batch_trace_XXT(solve_triangular(q_scale_tril, p_cov_diag, upper=False))
term2 = term21 + term22
return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(MultivariateNormal, MultivariateNormal)
def _kl_multivariatenormal_multivariatenormal(p, q):
if p.event_shape != q.event_shape:
raise ValueError("KL-divergence between two Multivariate Normals with\
different event shapes cannot be computed")

half_term1 = (q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) -
p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1))
combined_batch_shape = _infer_size(q._unbroadcasted_scale_tril.shape[:-2],
p._unbroadcasted_scale_tril.shape[:-2])
n = p.event_shape[0]
q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
p_scale_tril = p._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
term2 = _batch_trace_XXT(solve_triangular(q_scale_tril, p_scale_tril, upper=False))
term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
return half_term1 + 0.5 * (term2 + term3 - n)


@register_kl(Normal, Normal)
def _kl_normal_normal(p, q):
var_ratio = (p.scale / q.scale).pow(2)
t1 = ((p.loc - q.loc) / q.scale).pow(2)
return 0.5 * (var_ratio + t1 - 1 - var_ratio.log())


@register_kl(OneHotCategorical, OneHotCategorical)
def _kl_onehotcategorical_onehotcategorical(p, q):
return _kl_categorical_categorical(p._categorical, q._categorical)


@register_kl(Pareto, Pareto)
def _kl_pareto_pareto(p, q):
scale_ratio = p.scale / q.scale
alpha_ratio = q.alpha / p.alpha
t1 = q.alpha * scale_ratio.log()
t2 = -alpha_ratio.log()
result = t1 + t2 + alpha_ratio - 1
result[p.support.lower_bound < q.support.lower_bound] = inf
return result


@register_kl(Poisson, Poisson)
def _kl_poisson_poisson(p, q):
return p.rate * (p.rate.log() - q.rate.log()) - (p.rate - q.rate)


@register_kl(TransformedDistribution, TransformedDistribution)
def _kl_transformed_transformed(p, q):
if p.transforms != q.transforms:
raise NotImplementedError
if p.event_shape != q.event_shape:
raise NotImplementedError
return kl_divergence(p.base_dist, q.base_dist)


@register_kl(Uniform, Uniform)
def _kl_uniform_uniform(p, q):
result = ((q.high - q.low) / (p.high - p.low)).log()
result[(q.low > p.low) | (q.high < p.high)] = inf
return result


# Different distributions
@register_kl(Bernoulli, Poisson)
def _kl_bernoulli_poisson(p, q):
return -p.entropy() - (p.probs * q.rate.log() - q.rate)


@register_kl(Beta, ContinuousBernoulli)
def _kl_beta_continuous_bernoulli(p, q):
return -p.entropy() - p.mean * q.logits - log1p(-q.probs) - q._cont_bern_log_norm()


@register_kl(Beta, Pareto)
def _kl_beta_infinity(p, q):
unsupported_attr(q)
return _infinite_like(p.concentration1)


@register_kl(Beta, Exponential)
def _kl_beta_exponential(p, q):
return -p.entropy() - q.rate.log() + q.rate * (p.concentration1 / (p.concentration1 + p.concentration0))


@register_kl(Beta, Gamma)
def _kl_beta_gamma(p, q):
t1 = -p.entropy()
t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
t3 = (q.concentration - 1) * (p.concentration1.digamma() - (p.concentration1 + p.concentration0).digamma())
t4 = q.rate * p.concentration1 / (p.concentration1 + p.concentration0)
return t1 + t2 - t3 + t4

# TODO: Add Beta-Laplace KL Divergence


@register_kl(Beta, Normal)
def _kl_beta_normal(p, q):
E_beta = p.concentration1 / (p.concentration1 + p.concentration0)
var_normal = q.scale.pow(2)
t1 = -p.entropy()
t2 = 0.5 * (var_normal * 2 * math.pi).log()
t3 = (E_beta * (1 - E_beta) / (p.concentration1 + p.concentration0 + 1) + E_beta.pow(2)) * 0.5
t4 = q.loc * E_beta
t5 = q.loc.pow(2) * 0.5
return t1 + t2 + (t3 - t4 + t5) / var_normal


@register_kl(Beta, Uniform)
def _kl_beta_uniform(p, q):
result = -p.entropy() + (q.high - q.low).log()
result[(q.low > p.support.lower_bound) | (q.high < p.support.upper_bound)] = inf
return result

# Note that the KL between a ContinuousBernoulli and Beta has no closed form


@register_kl(ContinuousBernoulli, Pareto)
def _kl_continuous_bernoulli_infinity(p, q):
unsupported_attr(q)
return _infinite_like(p.probs)


@register_kl(ContinuousBernoulli, Exponential)
def _kl_continuous_bernoulli_exponential(p, q):
return -p.entropy() - torch_log(q.rate) + q.rate * p.mean

# Note that the KL between a ContinuousBernoulli and Gamma has no closed form
# TODO: Add ContinuousBernoulli-Laplace KL Divergence


@register_kl(ContinuousBernoulli, Normal)
def _kl_continuous_bernoulli_normal(p, q):
t1 = -p.entropy()
t2 = 0.5 * (math.log(2. * math.pi) + torch_square(q.loc / q.scale)) + torch_log(q.scale)
t3 = (p.variance + torch_square(p.mean) - 2. * q.loc * p.mean) / (2.0 * torch_square(q.scale))
return t1 + t2 + t3


@register_kl(ContinuousBernoulli, Uniform)
def _kl_continuous_bernoulli_uniform(p, q):
result = -p.entropy() + (q.high - q.low).log()
return where(torch_max(torch_ge(q.low, p.support.lower_bound),
torch_le(q.high, p.support.upper_bound)),
ones_like(result) * inf, result)


@register_kl(Exponential, Beta)
@register_kl(Exponential, ContinuousBernoulli)
@register_kl(Exponential, Pareto)
@register_kl(Exponential, Uniform)
def _kl_exponential_infinity(p, q):
unsupported_attr(q)
return _infinite_like(p.rate)


@register_kl(Exponential, Gamma)
def _kl_exponential_gamma(p, q):
ratio = q.rate / p.rate
t1 = -q.concentration * torch_log(ratio)
return t1 + ratio + q.concentration.lgamma() + q.concentration * _euler_gamma - (1 + _euler_gamma)


@register_kl(Exponential, Gumbel)
def _kl_exponential_gumbel(p, q):
scale_rate_prod = p.rate * q.scale
loc_scale_ratio = q.loc / q.scale
t1 = scale_rate_prod.log() - 1
t2 = torch_exp(loc_scale_ratio) * scale_rate_prod / (scale_rate_prod + 1)
t3 = scale_rate_prod.reciprocal()
return t1 - loc_scale_ratio + t2 + t3

# TODO: Add Exponential-Laplace KL Divergence


@register_kl(Exponential, Normal)
def _kl_exponential_normal(p, q):
var_normal = q.scale.pow(2)
rate_sqr = p.rate.pow(2)
t1 = 0.5 * torch_log(rate_sqr * var_normal * 2 * math.pi)
t2 = rate_sqr.reciprocal()
t3 = q.loc / p.rate
t4 = q.loc.pow(2) * 0.5
return t1 - 1 + (t2 - t3 + t4) / var_normal


@register_kl(Gamma, Beta)
@register_kl(Gamma, ContinuousBernoulli)
@register_kl(Gamma, Pareto)
@register_kl(Gamma, Uniform)
def _kl_gamma_infinity(p, q):
unsupported_attr(q)
return _infinite_like(p.concentration)


@register_kl(Gamma, Exponential)
def _kl_gamma_exponential(p, q):
return -p.entropy() - q.rate.log() + q.rate * p.concentration / p.rate


@register_kl(Gamma, Gumbel)
def _kl_gamma_gumbel(p, q):
beta_scale_prod = p.rate * q.scale
loc_scale_ratio = q.loc / q.scale
t1 = (p.concentration - 1) * p.concentration.digamma() - p.concentration.lgamma() - p.concentration
t2 = beta_scale_prod.log() + p.concentration / beta_scale_prod
t3 = torch_exp(loc_scale_ratio) * (1 + beta_scale_prod.reciprocal()).pow(-p.concentration) - loc_scale_ratio
return t1 + t2 + t3

# TODO: Add Gamma-Laplace KL Divergence


@register_kl(Gamma, Normal)
def _kl_gamma_normal(p, q):
var_normal = q.scale.pow(2)
beta_sqr = p.rate.pow(2)
t1 = 0.5 * torch_log(beta_sqr * var_normal * 2 * math.pi) - p.concentration - p.concentration.lgamma()
t2 = 0.5 * (p.concentration.pow(2) + p.concentration) / beta_sqr
t3 = q.loc * p.concentration / p.rate
t4 = 0.5 * q.loc.pow(2)
return t1 + (p.concentration - 1) * p.concentration.digamma() + (t2 - t3 + t4) / var_normal


@register_kl(Gumbel, Beta)
@register_kl(Gumbel, ContinuousBernoulli)
@register_kl(Gumbel, Exponential)
@register_kl(Gumbel, Gamma)
@register_kl(Gumbel, Pareto)
@register_kl(Gumbel, Uniform)
def _kl_gumbel_infinity(p, q):
unsupported_attr(q)
return _infinite_like(p.loc)

# TODO: Add Gumbel-Laplace KL Divergence


@register_kl(Gumbel, Normal)
def _kl_gumbel_normal(p, q):
param_ratio = p.scale / q.scale
t1 = (param_ratio / math.sqrt(2 * math.pi)).log()
t2 = (math.pi * param_ratio * 0.5).pow(2) / 3
t3 = ((p.loc + p.scale * _euler_gamma - q.loc) / q.scale).pow(2) * 0.5
return -t1 + t2 + t3 - (_euler_gamma + 1)


@register_kl(Laplace, Beta)
@register_kl(Laplace, ContinuousBernoulli)
@register_kl(Laplace, Exponential)
@register_kl(Laplace, Gamma)
@register_kl(Laplace, Pareto)
@register_kl(Laplace, Uniform)
def _kl_laplace_infinity(p, q):
unsupported_attr(q)
return _infinite_like(p.loc)


@register_kl(Laplace, Normal)
def _kl_laplace_normal(p, q):
var_normal = q.scale.pow(2)
scale_sqr_var_ratio = p.scale.pow(2) / var_normal
t1 = 0.5 * torch_log(2 * scale_sqr_var_ratio / math.pi)
t2 = 0.5 * p.loc.pow(2)
t3 = p.loc * q.loc
t4 = 0.5 * q.loc.pow(2)
return -t1 + scale_sqr_var_ratio + (t2 - t3 + t4) / var_normal - 1


@register_kl(Normal, Beta)
@register_kl(Normal, ContinuousBernoulli)
@register_kl(Normal, Exponential)
@register_kl(Normal, Gamma)
@register_kl(Normal, Pareto)
@register_kl(Normal, Uniform)
def _kl_normal_infinity(p, q):
unsupported_attr(q)
return _infinite_like(p.loc)


@register_kl(Normal, Gumbel)
def _kl_normal_gumbel(p, q):
mean_scale_ratio = p.loc / q.scale
var_scale_sqr_ratio = (p.scale / q.scale).pow(2)
loc_scale_ratio = q.loc / q.scale
t1 = var_scale_sqr_ratio.log() * 0.5
t2 = mean_scale_ratio - loc_scale_ratio
t3 = torch_exp(-mean_scale_ratio + 0.5 * var_scale_sqr_ratio + loc_scale_ratio)
return -t1 + t2 + t3 - (0.5 * (1 + math.log(2 * math.pi)))


@register_kl(Normal, Laplace)
def _kl_normal_laplace(p, q):
loc_diff = p.loc - q.loc
scale_ratio = p.scale / q.scale
loc_diff_scale_ratio = loc_diff / p.scale
t1 = torch_log(scale_ratio)
t2 = math.sqrt(2 / math.pi) * p.scale * torch_exp(-0.5 * loc_diff_scale_ratio.pow(2))
t3 = loc_diff * erf(math.sqrt(0.5) * loc_diff_scale_ratio)
return -t1 + (t2 + t3) / q.scale - (0.5 * (1 + math.log(0.5 * math.pi))) # pylint: disable=E1130


@register_kl(Pareto, Beta)
@register_kl(Pareto, ContinuousBernoulli)
@register_kl(Pareto, Uniform)
def _kl_pareto_infinity(p, q):
unsupported_attr(q)
return _infinite_like(p.scale)


@register_kl(Pareto, Exponential)
def _kl_pareto_exponential(p, q):
scale_rate_prod = p.scale * q.rate
t1 = (p.alpha / scale_rate_prod).log()
t2 = p.alpha.reciprocal()
t3 = p.alpha * scale_rate_prod / (p.alpha - 1)
result = t1 - t2 + t3 - 1
result[p.alpha <= 1] = inf
return result


@register_kl(Pareto, Gamma)
def _kl_pareto_gamma(p, q):
common_term = p.scale.log() + p.alpha.reciprocal()
t1 = p.alpha.log() - common_term
t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
t3 = (1 - q.concentration) * common_term
t4 = q.rate * p.alpha * p.scale / (p.alpha - 1)
result = t1 + t2 + t3 + t4 - 1
result[p.alpha <= 1] = inf
return result

# TODO: Add Pareto-Laplace KL Divergence


@register_kl(Pareto, Normal)
def _kl_pareto_normal(p, q):
var_normal = 2 * q.scale.pow(2)
common_term = p.scale / (p.alpha - 1)
t1 = (math.sqrt(2 * math.pi) * q.scale * p.alpha / p.scale).log()
t2 = p.alpha.reciprocal()
t3 = p.alpha * common_term.pow(2) / (p.alpha - 2)
t4 = (p.alpha * common_term - q.loc).pow(2)
result = t1 - t2 + (t3 + t4) / var_normal - 1
result[p.alpha <= 2] = inf
return result


@register_kl(Poisson, Bernoulli)
@register_kl(Poisson, Binomial)
def _kl_poisson_infinity(p, q):
unsupported_attr(q)
return _infinite_like(p.rate)


@register_kl(Uniform, Beta)
def _kl_uniform_beta(p, q):
common_term = p.high - p.low
t1 = torch_log(common_term)
t2 = (q.concentration1 - 1) * (_x_log_x(p.high) - _x_log_x(p.low) - common_term) / common_term
t3 = (q.concentration0 - 1) * (_x_log_x((1 - p.high)) - _x_log_x((1 - p.low)) + common_term) / common_term
t4 = q.concentration1.lgamma() + q.concentration0.lgamma() - (q.concentration1 + q.concentration0).lgamma()
result = t3 + t4 - t1 - t2
result[(p.high > q.support.upper_bound) | (p.low < q.support.lower_bound)] = inf
return result


@register_kl(Uniform, ContinuousBernoulli)
def _kl_uniform_continuous_bernoulli(p, q):
result = -p.entropy() - p.mean * q.logits - log1p(-q.probs) - q._cont_bern_log_norm()
return where(torch_max(torch_ge(p.high, q.support.upper_bound),
torch_le(p.low, q.support.lower_bound)),
ones_like(result) * inf, result)


@register_kl(Uniform, Exponential)
def _kl_uniform_exponential(p, q):
result = q.rate * (p.high + p.low) / 2 - ((p.high - p.low) * q.rate).log()
result[p.low < q.support.lower_bound] = inf
return result


@register_kl(Uniform, Gamma)
def _kl_uniform_gamma(p, q):
common_term = p.high - p.low
t1 = common_term.log()
t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
t3 = (1 - q.concentration) * (_x_log_x(p.high) - _x_log_x(p.low) - common_term) / common_term
t4 = q.rate * (p.high + p.low) / 2
result = -t1 + t2 + t3 + t4
result[p.low < q.support.lower_bound] = inf
return result


@register_kl(Uniform, Gumbel)
def _kl_uniform_gumbel(p, q):
common_term = q.scale / (p.high - p.low)
high_loc_diff = (p.high - q.loc) / q.scale
low_loc_diff = (p.low - q.loc) / q.scale
t1 = common_term.log() + 0.5 * (high_loc_diff + low_loc_diff)
t2 = common_term * (torch_exp(-high_loc_diff) - torch_exp(-low_loc_diff))
return t1 - t2

# TODO: Uniform-Laplace KL Divergence


@register_kl(Uniform, Normal)
def _kl_uniform_normal(p, q):
common_term = p.high - p.low
t1 = (math.sqrt(math.pi * 2) * q.scale / common_term).log()
t2 = (common_term).pow(2) / 12
t3 = ((p.high + p.low - 2 * q.loc) / 2).pow(2)
return t1 + 0.5 * (t2 + t3) / q.scale.pow(2)


@register_kl(Uniform, Pareto)
def _kl_uniform_pareto(p, q):
support_uniform = p.high - p.low
t1 = (q.alpha * q.scale.pow(q.alpha) * (support_uniform)).log()
t2 = (_x_log_x(p.high) - _x_log_x(p.low) - support_uniform) / support_uniform
result = t2 * (q.alpha + 1) - t1
result[p.low < q.support.lower_bound] = inf
return result


@register_kl(Independent, Independent)
def _kl_independent_independent(p, q):
if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
raise NotImplementedError
result = kl_divergence(p.base_dist, q.base_dist)
return _sum_rightmost(result, p.reinterpreted_batch_ndims)


@register_kl(Cauchy, Cauchy)
def _kl_cauchy_cauchy(p, q):
t1 = ((p.scale + q.scale).pow(2) + (p.loc - q.loc).pow(2)).log()
t2 = (4 * p.scale * q.scale).log()
return t1 - t2

def _add_kl_info():
rows = ["KL divergence is currently implemented for the following distribution pairs:"]
for p, q in sorted(_KL_REGISTRY,
key=lambda p_q: (p_q[0].__name__, p_q[1].__name__)):
rows.append("* :class:`~torch.distributions.{}` and :class:`~torch.distributions.{}`"
.format(p.__name__, q.__name__))
kl_info = '\n\t'.join(rows)
if kl_divergence.__doc__:
kl_divergence.__doc__ += kl_info # type: ignore[operator]

+ 58
- 0
mindtorch/torch/distributions/kumaraswamy.py View File

@@ -0,0 +1,58 @@
from mindtorch.torch._six import nan
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.uniform import Uniform
from mindtorch.torch.distributions.transformed_distribution import TransformedDistribution
from mindtorch.torch.distributions.transforms import AffineTransform, PowerTransform
from mindtorch.torch.distributions.utils import broadcast_all, euler_constant
from mindtorch.torch.functional import lgamma, exp as torch_exp, full_like, pow as torch_pow, log as torch_log, \
digamma


def _moments(a, b, n):
arg1 = 1 + n / a
log_value = lgamma(arg1) + lgamma(b) - lgamma(arg1 + b)
return b * torch_exp(log_value)


class Kumaraswamy(TransformedDistribution):
arg_constraints = {'concentration1': constraints.positive, 'concentration0': constraints.positive}
support = constraints.unit_interval
has_rsample = True

def __init__(self, concentration1, concentration0, validate_args=None):
self.concentration1, self.concentration0 = broadcast_all(concentration1, concentration0)
base_dist = Uniform(full_like(self.concentration0, 0),
full_like(self.concentration0, 1),
validate_args=validate_args)
transforms = [PowerTransform(exponent=self.concentration0.reciprocal()),
AffineTransform(loc=1., scale=-1.),
PowerTransform(exponent=self.concentration1.reciprocal())]
super(Kumaraswamy, self).__init__(base_dist, transforms, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Kumaraswamy, _instance)
new.concentration1 = self.concentration1.expand(batch_shape)
new.concentration0 = self.concentration0.expand(batch_shape)
return super(Kumaraswamy, self).expand(batch_shape, _instance=new)

@property
def mean(self):
return _moments(self.concentration1, self.concentration0, 1)

@property
def mode(self):
# Evaluate in log-space for numerical stability.
log_mode = self.concentration0.reciprocal() * \
(-self.concentration0).log1p() - (-self.concentration0 * self.concentration1).log1p()
log_mode[(self.concentration0 < 1) | (self.concentration1 < 1)] = nan
return log_mode.exp()

@property
def variance(self):
return _moments(self.concentration1, self.concentration0, 2) - torch_pow(self.mean, 2)

def entropy(self):
t1 = (1 - self.concentration1.reciprocal())
t0 = (1 - self.concentration0.reciprocal())
H0 = digamma(self.concentration0 + 1) + euler_constant
return t0 + t1 * H0 - torch_log(self.concentration1) - torch_log(self.concentration0)
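
A sanity sketch (illustrative): with both concentrations equal to 1, Kumaraswamy reduces to Uniform(0, 1):
    >>> from mindtorch.torch.distributions.kumaraswamy import Kumaraswamy
    >>> m = Kumaraswamy(concentration1=1.0, concentration0=1.0)
    >>> m.mean        # 0.5
    >>> m.entropy()   # 0, the differential entropy of Uniform(0, 1)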

+ 77
- 0
mindtorch/torch/distributions/laplace.py View File

@@ -0,0 +1,77 @@
from numbers import Number

from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.distribution import Distribution
from mindtorch.torch.distributions.utils import broadcast_all
from mindtorch.torch._C.Size import Size
from mindtorch.torch.common.dtype import finfo as torch_finfo
from mindtorch.torch._C import _get_tracing_state
from mindtorch.torch.functional import rand as torch_rand, log1p, log as torch_log, expm1, abs as torch_abs


class Laplace(Distribution):
arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
support = constraints.real
has_rsample = True

@property
def mean(self):
return self.loc

@property
def mode(self):
return self.loc

@property
def variance(self):
return 2 * self.scale.pow(2)

@property
def stddev(self):
return (2 ** 0.5) * self.scale

def __init__(self, loc, scale, validate_args=None):
self.loc, self.scale = broadcast_all(loc, scale)
if isinstance(loc, Number) and isinstance(scale, Number):
batch_shape = Size()
else:
batch_shape = self.loc.size()
super(Laplace, self).__init__(batch_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Laplace, _instance)
batch_shape = Size(batch_shape)
new.loc = self.loc.expand(batch_shape)
new.scale = self.scale.expand(batch_shape)
super(Laplace, new).__init__(batch_shape, validate_args=False)
new._validate_args = self._validate_args
return new

def rsample(self, sample_shape=Size()):
shape = self._extended_shape(sample_shape)
finfo = torch_finfo(self.loc.dtype)
if _get_tracing_state():
# [JIT WORKAROUND] lack of support for .uniform_()
u = torch_rand(shape, dtype=self.loc.dtype, device=self.loc.device) * 2 - 1
return self.loc - self.scale * u.sign() * log1p(-u.abs().clamp(min=finfo.tiny))
u = self.loc.new(shape).uniform_(finfo.eps - 1, 1)
# TODO: If we ever implement tensor.nextafter, below is what we want ideally.
# u = self.loc.new(shape).uniform_(self.loc.nextafter(-.5, 0), .5)
return self.loc - self.scale * u.sign() * log1p(-u.abs())

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
return -torch_log(2 * self.scale) - torch_abs(value - self.loc) / self.scale # pylint: disable=E1130

def cdf(self, value):
if self._validate_args:
self._validate_sample(value)
return 0.5 - 0.5 * (value - self.loc).sign() * expm1(-(value - self.loc).abs() / self.scale)

def icdf(self, value):
term = value - 0.5
return self.loc - self.scale * (term).sign() * log1p(-2 * term.abs())

def entropy(self):
return 1 + torch_log(2 * self.scale)
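
A usage sketch (illustrative; sampling assumes `.uniform_()` is available on adapter tensors):
    >>> from mindtorch.torch.distributions.laplace import Laplace
    >>> m = Laplace(loc=0.0, scale=1.0)
    >>> m.stddev         # sqrt(2) * scale
    >>> m.entropy()      # 1 + log(2 * scale)
    >>> m.rsample((2,))  # inverse-CDF sampling from a symmetric uniform draw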

+ 71
- 0
mindtorch/torch/distributions/lkj_cholesky.py View File

@@ -0,0 +1,71 @@
import math

from mindtorch.torch.distributions import constraints, Beta
from mindtorch.torch.distributions.distribution import Distribution
from mindtorch.torch.distributions.utils import broadcast_all
from mindtorch.torch._C.Size import Size
from mindtorch.torch.conflict_functional import arange as torch_arange
from mindtorch.torch.common.dtype import finfo as torch_finfo
from mindtorch.torch.functional import sum as torch_sum, cat as torch_cat, sqrt as torch_sqrt, clamp, randn, \
diag_embed, lgamma, mvlgamma


class LKJCholesky(Distribution):
arg_constraints = {'concentration': constraints.positive}
support = constraints.corr_cholesky

def __init__(self, dim, concentration=1., validate_args=None):
if dim < 2:
raise ValueError(f'Expected dim to be an integer greater than or equal to 2. Found dim={dim}.')
self.dim = dim
self.concentration, = broadcast_all(concentration)
batch_shape = self.concentration.size()
event_shape = Size((dim, dim))
# This is used to draw vectorized samples from the Beta distribution, following Sec. 3.2 of
# Lewandowski, Kurowicka & Joe (2009), "Generating random correlation matrices based on
# vines and extended onion method".
marginal_conc = self.concentration + 0.5 * (self.dim - 2)
offset = torch_arange(self.dim - 1, dtype=self.concentration.dtype, device=self.concentration.device)
offset = torch_cat([offset.new_zeros((1,)), offset])
beta_conc1 = offset + 0.5
beta_conc0 = marginal_conc.unsqueeze(-1) - 0.5 * offset
self._beta = Beta(beta_conc1, beta_conc0)
super(LKJCholesky, self).__init__(batch_shape, event_shape, validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(LKJCholesky, _instance)
batch_shape = Size(batch_shape)
new.dim = self.dim
new.concentration = self.concentration.expand(batch_shape)
new._beta = self._beta.expand(batch_shape + (self.dim,))
super(LKJCholesky, new).__init__(batch_shape, self.event_shape, validate_args=False)
new._validate_args = self._validate_args
return new

def sample(self, sample_shape=Size()):
y = self._beta.sample(sample_shape).unsqueeze(-1)
u_normal = randn(self._extended_shape(sample_shape),
dtype=y.dtype,
device=y.device).tril(-1)
u_hypersphere = u_normal / u_normal.norm(dim=-1, keepdim=True)
# Replace NaNs in first row
u_hypersphere[..., 0, :].fill_(0.)
w = torch_sqrt(y) * u_hypersphere
# Fill diagonal elements; clamp for numerical stability
eps = torch_finfo(w.dtype).tiny
diag_elems = clamp(1 - torch_sum(w**2, dim=-1), min=eps).sqrt()
w += diag_embed(diag_elems)
return w

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
diag_elems = value.diagonal(dim1=-1, dim2=-2)[..., 1:]
order = torch_arange(2, self.dim + 1, device=self.concentration.device)
order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order
unnormalized_log_pdf = torch_sum(order * diag_elems.log(), dim=-1)
dm1 = self.dim - 1
alpha = self.concentration + 0.5 * dm1
denominator = lgamma(alpha) * dm1
numerator = mvlgamma(alpha - 0.5, dm1)
pi_constant = 0.5 * dm1 * math.log(math.pi)
normalize_term = pi_constant + numerator - denominator
return unnormalized_log_pdf - normalize_term
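
A shape sketch (illustrative; sample() relies on Beta.sample and randn behaving as in torch):
    >>> from mindtorch.torch.distributions.lkj_cholesky import LKJCholesky
    >>> m = LKJCholesky(dim=3, concentration=1.0)
    >>> m.event_shape   # (3, 3)
    >>> L = m.sample()  # lower-triangular Cholesky factor; L @ L.mT is a correlation matrix with unit diagonal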

+ 41
- 0
mindtorch/torch/distributions/log_normal.py View File

@@ -0,0 +1,41 @@
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.transforms import ExpTransform
from mindtorch.torch.distributions.normal import Normal
from mindtorch.torch.distributions.transformed_distribution import TransformedDistribution


class LogNormal(TransformedDistribution):
arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
support = constraints.positive
has_rsample = True

def __init__(self, loc, scale, validate_args=None):
base_dist = Normal(loc, scale, validate_args=validate_args)
super(LogNormal, self).__init__(base_dist, ExpTransform(), validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(LogNormal, _instance)
return super(LogNormal, self).expand(batch_shape, _instance=new)

@property
def loc(self):
return self.base_dist.loc

@property
def scale(self):
return self.base_dist.scale

@property
def mean(self):
return (self.loc + self.scale.pow(2) / 2).exp()

@property
def mode(self):
return (self.loc - self.scale.square()).exp()

@property
def variance(self):
return (self.scale.pow(2).exp() - 1) * (2 * self.loc + self.scale.pow(2)).exp()

def entropy(self):
return self.base_dist.entropy() + self.loc
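
A moments sketch (illustrative):
    >>> from mindtorch.torch.distributions.log_normal import LogNormal
    >>> m = LogNormal(loc=0.0, scale=1.0)
    >>> m.mean        # exp(loc + scale**2 / 2)
    >>> m.mode        # exp(loc - scale**2)
    >>> m.entropy()   # base Normal entropy plus loc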

+ 30
- 0
mindtorch/torch/distributions/logistic_normal.py View File

@@ -0,0 +1,30 @@
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.normal import Normal
from mindtorch.torch.distributions.transformed_distribution import TransformedDistribution
from mindtorch.torch.distributions.transforms import StickBreakingTransform


class LogisticNormal(TransformedDistribution):
arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
support = constraints.simplex
has_rsample = True

def __init__(self, loc, scale, validate_args=None):
base_dist = Normal(loc, scale, validate_args=validate_args)
if not base_dist.batch_shape:
base_dist = base_dist.expand([1])
super(LogisticNormal, self).__init__(base_dist,
StickBreakingTransform(),
validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(LogisticNormal, _instance)
return super(LogisticNormal, self).expand(batch_shape, _instance=new)

@property
def loc(self):
return self.base_dist.base_dist.loc

@property
def scale(self):
return self.base_dist.base_dist.scale
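
A simplex sketch (illustrative; scalar parameters are expanded to a length-1 base Normal, and the stick-breaking transform maps R^K onto the (K+1)-simplex):
    >>> from mindtorch.torch.distributions.logistic_normal import LogisticNormal
    >>> m = LogisticNormal(0.0, 1.0)
    >>> m.event_shape   # (2,)
    >>> m.sample()      # a length-2 vector with positive entries summing to 1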

+ 156
- 0
mindtorch/torch/distributions/lowrank_multivariate_normal.py View File

@@ -0,0 +1,156 @@
import math

from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.distribution import Distribution
from mindtorch.torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv
from mindtorch.torch.distributions.utils import _standard_normal, lazy_property
from mindtorch.torch.functional import matmul, diag_embed, broadcast_tensors
from mindtorch.torch.linalg import cholesky, solve_triangular
from mindtorch.torch._C.Size import Size
from mindtorch.utils import unsupported_attr


def _batch_capacitance_tril(W, D):
m = W.size(-1)
Wt_Dinv = W.mT / D.unsqueeze(-2)
K = matmul(Wt_Dinv, W).contiguous()
K.view(-1, m * m)[:, ::m + 1] += 1 # add identity matrix to K
return cholesky(K)


def _batch_lowrank_logdet(W, D, capacitance_tril):
unsupported_attr(W)
return 2 * capacitance_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + D.log().sum(-1)


def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
Wt_Dinv = W.mT / D.unsqueeze(-2)
Wt_Dinv_x = _batch_mv(Wt_Dinv, x)
mahalanobis_term1 = (x.pow(2) / D).sum(-1)
mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x)
return mahalanobis_term1 - mahalanobis_term2


class LowRankMultivariateNormal(Distribution):
arg_constraints = {"loc": constraints.real_vector,
"cov_factor": constraints.independent(constraints.real, 2),
"cov_diag": constraints.independent(constraints.positive, 1)}
support = constraints.real_vector
has_rsample = True

def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
if loc.dim() < 1:
raise ValueError("loc must be at least one-dimensional.")
event_shape = loc.shape[-1:]
if cov_factor.dim() < 2:
raise ValueError("cov_factor must be at least two-dimensional, "
"with optional leading batch dimensions")
if cov_factor.shape[-2:-1] != event_shape:
raise ValueError("cov_factor must be a batch of matrices with shape {} x m"
.format(event_shape[0]))
if cov_diag.shape[-1:] != event_shape:
raise ValueError("cov_diag must be a batch of vectors with shape {}".format(event_shape))

loc_ = loc.unsqueeze(-1)
cov_diag_ = cov_diag.unsqueeze(-1)
try:
loc_, self.cov_factor, cov_diag_ = broadcast_tensors(loc_, cov_factor, cov_diag_)
except RuntimeError as e:
raise ValueError("Incompatible batch shapes: loc {}, cov_factor {}, cov_diag {}"
.format(loc.shape, cov_factor.shape, cov_diag.shape)) from e
self.loc = loc_[..., 0]
self.cov_diag = cov_diag_[..., 0]
batch_shape = self.loc.shape[:-1]

self._unbroadcasted_cov_factor = cov_factor
self._unbroadcasted_cov_diag = cov_diag
self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)
super(LowRankMultivariateNormal, self).__init__(batch_shape, event_shape,
validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(LowRankMultivariateNormal, _instance)
batch_shape = Size(batch_shape)
loc_shape = batch_shape + self.event_shape
new.loc = self.loc.expand(loc_shape)
new.cov_diag = self.cov_diag.expand(loc_shape)
new.cov_factor = self.cov_factor.expand(loc_shape + self.cov_factor.shape[-1:])
new._unbroadcasted_cov_factor = self._unbroadcasted_cov_factor
new._unbroadcasted_cov_diag = self._unbroadcasted_cov_diag
new._capacitance_tril = self._capacitance_tril
super(LowRankMultivariateNormal, new).__init__(batch_shape,
self.event_shape,
validate_args=False)
new._validate_args = self._validate_args
return new

@property
def mean(self):
return self.loc

@property
def mode(self):
return self.loc

@lazy_property
def variance(self):
return (self._unbroadcasted_cov_factor.pow(2).sum(-1)
+ self._unbroadcasted_cov_diag).expand(self._batch_shape + self._event_shape)

@lazy_property
def scale_tril(self):
n = self._event_shape[0]
cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1)
Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze
K = matmul(Dinvsqrt_W, Dinvsqrt_W.mT).contiguous()
K.view(-1, n * n)[:, ::n + 1] += 1 # add identity matrix to K
scale_tril = cov_diag_sqrt_unsqueeze * cholesky(K)
return scale_tril.expand(self._batch_shape + self._event_shape + self._event_shape)

@lazy_property
def covariance_matrix(self):
covariance_matrix = (matmul(self._unbroadcasted_cov_factor,
self._unbroadcasted_cov_factor.mT)
+ diag_embed(self._unbroadcasted_cov_diag))
return covariance_matrix.expand(self._batch_shape + self._event_shape +
self._event_shape)

@lazy_property
def precision_matrix(self):
Wt_Dinv = (self._unbroadcasted_cov_factor.mT
/ self._unbroadcasted_cov_diag.unsqueeze(-2))
A = solve_triangular(self._capacitance_tril, Wt_Dinv, upper=False)
precision_matrix = diag_embed(self._unbroadcasted_cov_diag.reciprocal()) - A.mT @ A
return precision_matrix.expand(self._batch_shape + self._event_shape +
self._event_shape)

def rsample(self, sample_shape=Size()):
shape = self._extended_shape(sample_shape)
W_shape = shape[:-1] + self.cov_factor.shape[-1:]
eps_W = _standard_normal(W_shape, dtype=self.loc.dtype, device=self.loc.device)
eps_D = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
return (self.loc + _batch_mv(self._unbroadcasted_cov_factor, eps_W)
+ self._unbroadcasted_cov_diag.sqrt() * eps_D)

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
diff = value - self.loc
M = _batch_lowrank_mahalanobis(self._unbroadcasted_cov_factor,
self._unbroadcasted_cov_diag,
diff,
self._capacitance_tril)
log_det = _batch_lowrank_logdet(self._unbroadcasted_cov_factor,
self._unbroadcasted_cov_diag,
self._capacitance_tril)
return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + log_det + M)

def entropy(self):
log_det = _batch_lowrank_logdet(self._unbroadcasted_cov_factor,
self._unbroadcasted_cov_diag,
self._capacitance_tril)
H = 0.5 * (self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + log_det)
if len(self._batch_shape) == 0:
return H
else:
return H.expand(self._batch_shape)
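
For reference, this class models covariance_matrix = cov_factor @ cov_factor.mT + diag_embed(cov_diag), and log_prob/entropy only Cholesky-factor the m x m capacitance matrix I + W.T D^-1 W rather than the full n x n covariance. A small sketch, not part of the diff, assuming a torch-style tensor(...) factory:

import mindtorch.torch as torch
from mindtorch.torch.distributions.lowrank_multivariate_normal import LowRankMultivariateNormal

loc = torch.tensor([0.0, 0.0, 0.0])            # event size n = 3
cov_factor = torch.tensor([[1.0, 0.0],
                           [0.5, 1.0],
                           [0.0, 2.0]])        # W, rank m = 2
cov_diag = torch.tensor([1.0, 1.0, 1.0])       # D, must be positive
dist = LowRankMultivariateNormal(loc, cov_factor, cov_diag)

x = dist.rsample()                             # loc + W @ eps_W + sqrt(D) * eps_D
print(dist.covariance_matrix)                  # W @ W.T + diag(D), shape (3, 3)
print(dist.log_prob(x))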

+ 153
- 0
mindtorch/torch/distributions/mixture_same_family.py View File

@@ -0,0 +1,153 @@
from typing import Dict
from mindspore import _no_grad as torch_no_grad
from mindtorch.torch.distributions.distribution import Distribution
from mindtorch.torch.distributions import Categorical
from mindtorch.torch.distributions import constraints
from mindtorch.torch._C.Size import Size
from mindtorch.torch.functional import sum as torch_sum, logsumexp, gather
from mindtorch.torch.nn.functional import log_softmax


class MixtureSameFamily(Distribution):
arg_constraints: Dict[str, constraints.Constraint] = {}
has_rsample = False

def __init__(self,
mixture_distribution,
component_distribution,
validate_args=None):
self._mixture_distribution = mixture_distribution
self._component_distribution = component_distribution

if not isinstance(self._mixture_distribution, Categorical):
raise ValueError(" The Mixture distribution needs to be an "
" instance of torch.distribtutions.Categorical")

if not isinstance(self._component_distribution, Distribution):
raise ValueError("The Component distribution need to be an "
"instance of torch.distributions.Distribution")

# Check that batch size matches
mdbs = self._mixture_distribution.batch_shape
cdbs = self._component_distribution.batch_shape[:-1]
for size1, size2 in zip(reversed(mdbs), reversed(cdbs)):
if size1 != 1 and size2 != 1 and size1 != size2:
raise ValueError("`mixture_distribution.batch_shape` ({0}) is not "
"compatible with `component_distribution."
"batch_shape`({1})".format(mdbs, cdbs))

        # Check that the number of mixture components matches
km = self._mixture_distribution.logits.shape[-1]
kc = self._component_distribution.batch_shape[-1]
if km is not None and kc is not None and km != kc:
raise ValueError("`mixture_distribution component` ({0}) does not"
" equal `component_distribution.batch_shape[-1]`"
" ({1})".format(km, kc))
self._num_component = km

event_shape = self._component_distribution.event_shape
self._event_ndims = len(event_shape)
super(MixtureSameFamily, self).__init__(batch_shape=cdbs,
event_shape=event_shape,
validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
batch_shape = Size(batch_shape)
batch_shape_comp = batch_shape + (self._num_component,)
new = self._get_checked_instance(MixtureSameFamily, _instance)
new._component_distribution = \
self._component_distribution.expand(batch_shape_comp)
new._mixture_distribution = \
self._mixture_distribution.expand(batch_shape)
new._num_component = self._num_component
new._event_ndims = self._event_ndims
event_shape = new._component_distribution.event_shape
super(MixtureSameFamily, new).__init__(batch_shape=batch_shape,
event_shape=event_shape,
validate_args=False)
new._validate_args = self._validate_args
return new

@constraints.dependent_property
def support(self):
return self._component_distribution.support

@property
def mixture_distribution(self):
return self._mixture_distribution

@property
def component_distribution(self):
return self._component_distribution

@property
def mean(self):
probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
return torch_sum(probs * self.component_distribution.mean,
dim=-1 - self._event_ndims) # [B, E]

@property
def variance(self):
probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
mean_cond_var = torch_sum(probs * self.component_distribution.variance,
dim=-1 - self._event_ndims)
var_cond_mean = torch_sum(probs * (self.component_distribution.mean -
self._pad(self.mean)).pow(2.0),
dim=-1 - self._event_ndims)
return mean_cond_var + var_cond_mean

def cdf(self, x):
x = self._pad(x)
cdf_x = self.component_distribution.cdf(x)
mix_prob = self.mixture_distribution.probs

return torch_sum(cdf_x * mix_prob, dim=-1)

def log_prob(self, x):
if self._validate_args:
self._validate_sample(x)
x = self._pad(x)
log_prob_x = self.component_distribution.log_prob(x) # [S, B, k]
log_mix_prob = log_softmax(self.mixture_distribution.logits, dim=-1) # [B, k]
return logsumexp(log_prob_x + log_mix_prob, dim=-1) # [S, B]

def sample(self, sample_shape=Size()):
with torch_no_grad():
sample_len = len(sample_shape)
batch_len = len(self.batch_shape)
gather_dim = sample_len + batch_len
es = self.event_shape

# mixture samples [n, B]
mix_sample = self.mixture_distribution.sample(sample_shape)
mix_shape = mix_sample.shape

# component samples [n, B, k, E]
comp_samples = self.component_distribution.sample(sample_shape)

# Gather along the k dimension
mix_sample_r = mix_sample.reshape(
mix_shape + Size([1] * (len(es) + 1)))
mix_sample_r = mix_sample_r.repeat(
Size([1] * len(mix_shape)) + Size([1]) + es)

samples = gather(comp_samples, gather_dim, mix_sample_r)
return samples.squeeze(gather_dim)

def _pad(self, x):
return x.unsqueeze(-1 - self._event_ndims)

def _pad_mixture_dimensions(self, x):
dist_batch_ndims = self.batch_shape.numel()
cat_batch_ndims = self.mixture_distribution.batch_shape.numel()
pad_ndims = 0 if cat_batch_ndims == 1 else \
dist_batch_ndims - cat_batch_ndims
xs = x.shape
x = x.reshape(xs[:-1] + Size(pad_ndims * [1]) +
xs[-1:] + Size(self._event_ndims * [1]))
return x

def __repr__(self):
args_string = '\n {},\n {}'.format(self.mixture_distribution,
self.component_distribution)
return 'MixtureSameFamily' + '(' + args_string + ')'
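
A two-component Gaussian mixture built from the pieces above; the mixing Categorical indexes the trailing "component" batch dimension of the Normal. Illustration only, not part of the diff, assuming a torch-style tensor(...) factory:

import mindtorch.torch as torch
from mindtorch.torch.distributions import Categorical
from mindtorch.torch.distributions.normal import Normal
from mindtorch.torch.distributions.mixture_same_family import MixtureSameFamily

mix = Categorical(probs=torch.tensor([0.3, 0.7]))
comp = Normal(loc=torch.tensor([-2.0, 2.0]), scale=torch.tensor([0.5, 1.0]))
gmm = MixtureSameFamily(mix, comp)

x = gmm.sample((8,))            # component index is sampled, then gathered away
print(gmm.log_prob(x))          # logsumexp over components, as in log_prob above
print(gmm.mean, gmm.variance)   # law of total expectation / total variance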

+ 95
- 0
mindtorch/torch/distributions/multinomial.py View File

@@ -0,0 +1,95 @@
from mindtorch.torch._six import inf
from mindtorch.torch.distributions.binomial import Binomial
from mindtorch.torch.distributions.distribution import Distribution
from mindtorch.torch.distributions import Categorical
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.utils import broadcast_all
from mindtorch.torch._C.Size import Size
from mindtorch.torch.functional import ones_like, lgamma, exp as torch_exp
from ..tensor import tensor as torch_tensor


class Multinomial(Distribution):
arg_constraints = {'probs': constraints.simplex,
'logits': constraints.real_vector}
total_count: int

@property
def mean(self):
return self.probs * self.total_count

@property
def variance(self):
return self.total_count * self.probs * (1 - self.probs)

def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
if not isinstance(total_count, int):
raise NotImplementedError('inhomogeneous total_count is not supported')
self.total_count = total_count
self._categorical = Categorical(probs=probs, logits=logits)
self._binomial = Binomial(total_count=total_count, probs=self.probs)
batch_shape = self._categorical.batch_shape
event_shape = self._categorical.param_shape[-1:]
super(Multinomial, self).__init__(batch_shape, event_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Multinomial, _instance)
batch_shape = Size(batch_shape)
new.total_count = self.total_count
new._categorical = self._categorical.expand(batch_shape)
super(Multinomial, new).__init__(batch_shape, self.event_shape, validate_args=False)
new._validate_args = self._validate_args
return new

def _new(self, *args, **kwargs):
return self._categorical._new(*args, **kwargs)

@constraints.dependent_property(is_discrete=True, event_dim=1)
def support(self):
return constraints.multinomial(self.total_count)

@property
def logits(self):
return self._categorical.logits

@property
def probs(self):
return self._categorical.probs

@property
def param_shape(self):
return self._categorical.param_shape

def sample(self, sample_shape=Size()):
sample_shape = Size(sample_shape)
samples = self._categorical.sample(Size((self.total_count,)) + sample_shape)
shifted_idx = list(range(samples.dim()))
shifted_idx.append(shifted_idx.pop(0))
samples = samples.permute(*shifted_idx)
counts = samples.new(self._extended_shape(sample_shape)).zero_()
counts.scatter_add_(-1, samples, ones_like(samples))
return counts.type_as(self.probs)

def entropy(self):
n = torch_tensor(self.total_count)

cat_entropy = self._categorical.entropy()
term1 = n * cat_entropy - lgamma(n + 1)

support = self._binomial.enumerate_support(expand=False)[1:]
binomial_probs = torch_exp(self._binomial.log_prob(support))
weights = lgamma(support + 1)
term2 = (binomial_probs * weights).sum([0, -1])

return term1 + term2

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
logits, value = broadcast_all(self.logits, value)
logits = logits.clone()
log_factorial_n = lgamma(value.sum(-1) + 1)
log_factorial_xs = lgamma(value + 1).sum(-1)
logits[(value == 0) & (logits == -inf)] = 0
log_powers = (logits * value).sum(-1)
return log_factorial_n - log_factorial_xs + log_powers
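
A quick check of the counting semantics (illustration only; assumes a torch-style tensor(...) factory): sample() draws total_count categorical indices and scatter-adds them into per-category counts, and log_prob applies the multinomial coefficient.

import mindtorch.torch as torch
from mindtorch.torch.distributions.multinomial import Multinomial

dist = Multinomial(total_count=10, probs=torch.tensor([0.2, 0.3, 0.5]))
counts = dist.sample()          # one draw over 3 categories, counts sum to 10
print(counts.sum(-1))           # tensor(10.)
print(dist.log_prob(counts))    # log n! - sum_i log x_i! + sum_i x_i * log p_i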

+ 171
- 0
mindtorch/torch/distributions/multivariate_normal.py View File

@@ -0,0 +1,171 @@
import math

from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.distribution import Distribution
from mindtorch.torch.distributions.utils import _standard_normal, lazy_property
from mindtorch.torch.functional import matmul, flip, transpose, eye as torch_eye, broadcast_shapes, cholesky_inverse
from mindtorch.torch.linalg import solve_triangular, cholesky
from mindtorch.torch._C.Size import Size


def _batch_mv(bmat, bvec):
return matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1)


def _batch_mahalanobis(bL, bx):
n = bx.size(-1)
bx_batch_shape = bx.shape[:-1]

bx_batch_dims = len(bx_batch_shape)
bL_batch_dims = bL.dim() - 2
outer_batch_dims = bx_batch_dims - bL_batch_dims
old_batch_dims = outer_batch_dims + bL_batch_dims
new_batch_dims = outer_batch_dims + 2 * bL_batch_dims
# Reshape bx with the shape (..., 1, i, j, 1, n)
bx_new_shape = bx.shape[:outer_batch_dims]
for (sL, sx) in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
bx_new_shape += (sx // sL, sL)
bx_new_shape += (n,)
bx = bx.reshape(bx_new_shape)
# Permute bx to make it have shape (..., 1, j, i, 1, n)
permute_dims = (list(range(outer_batch_dims)) +
list(range(outer_batch_dims, new_batch_dims, 2)) +
list(range(outer_batch_dims + 1, new_batch_dims, 2)) +
[new_batch_dims])
bx = bx.permute(permute_dims)

flat_L = bL.reshape(-1, n, n) # shape = b x n x n
flat_x = bx.reshape(-1, flat_L.size(0), n) # shape = c x b x n
flat_x_swap = flat_x.permute(1, 2, 0) # shape = b x n x c
M_swap = solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2) # shape = b x c
M = M_swap.t() # shape = c x b

# Now we revert the above reshape and permute operators.
permuted_M = M.reshape(bx.shape[:-1]) # shape = (..., 1, j, i, 1)
permute_inv_dims = list(range(outer_batch_dims))
for i in range(bL_batch_dims):
permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i]
reshaped_M = permuted_M.permute(permute_inv_dims) # shape = (..., 1, i, j, 1)
return reshaped_M.reshape(bx_batch_shape)


def _precision_to_scale_tril(P):
Lf = cholesky(flip(P, (-2, -1)))
L_inv = transpose(flip(Lf, (-2, -1)), -2, -1)
Id = torch_eye(P.shape[-1], dtype=P.dtype, device=P.device)
L = solve_triangular(L_inv, Id, upper=False)
return L


class MultivariateNormal(Distribution):
arg_constraints = {'loc': constraints.real_vector,
'covariance_matrix': constraints.positive_definite,
'precision_matrix': constraints.positive_definite,
'scale_tril': constraints.lower_cholesky}
support = constraints.real_vector
has_rsample = True

def __init__(self, loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None):
if loc.dim() < 1:
raise ValueError("loc must be at least one-dimensional.")
if (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) != 1:
raise ValueError("Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified.")

if scale_tril is not None:
if scale_tril.dim() < 2:
raise ValueError("scale_tril matrix must be at least two-dimensional, "
"with optional leading batch dimensions")
batch_shape = broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
elif covariance_matrix is not None:
if covariance_matrix.dim() < 2:
raise ValueError("covariance_matrix must be at least two-dimensional, "
"with optional leading batch dimensions")
batch_shape = broadcast_shapes(covariance_matrix.shape[:-2], loc.shape[:-1])
self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
else:
if precision_matrix.dim() < 2:
raise ValueError("precision_matrix must be at least two-dimensional, "
"with optional leading batch dimensions")
batch_shape = broadcast_shapes(precision_matrix.shape[:-2], loc.shape[:-1])
self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
self.loc = loc.expand(batch_shape + (-1,))

event_shape = self.loc.shape[-1:]
super(MultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=validate_args)

if scale_tril is not None:
self._unbroadcasted_scale_tril = scale_tril
elif covariance_matrix is not None:
self._unbroadcasted_scale_tril = cholesky(covariance_matrix)
else: # precision_matrix is not None
self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(MultivariateNormal, _instance)
batch_shape = Size(batch_shape)
loc_shape = batch_shape + self.event_shape
cov_shape = batch_shape + self.event_shape + self.event_shape
new.loc = self.loc.expand(loc_shape)
new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril
if 'covariance_matrix' in self.__dict__:
new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
if 'scale_tril' in self.__dict__:
new.scale_tril = self.scale_tril.expand(cov_shape)
if 'precision_matrix' in self.__dict__:
new.precision_matrix = self.precision_matrix.expand(cov_shape)
super(MultivariateNormal, new).__init__(batch_shape,
self.event_shape,
validate_args=False)
new._validate_args = self._validate_args
return new

@lazy_property
def scale_tril(self): # pylint: disable=E0202
return self._unbroadcasted_scale_tril.expand(
self._batch_shape + self._event_shape + self._event_shape)

@lazy_property
def covariance_matrix(self): # pylint: disable=E0202
return (matmul(self._unbroadcasted_scale_tril,
self._unbroadcasted_scale_tril.mT)
.expand(self._batch_shape + self._event_shape + self._event_shape))

@lazy_property
def precision_matrix(self): # pylint: disable=E0202
return cholesky_inverse(self._unbroadcasted_scale_tril).expand(
self._batch_shape + self._event_shape + self._event_shape)

@property
def mean(self):
return self.loc

@property
def mode(self):
return self.loc

@property
def variance(self):
return self._unbroadcasted_scale_tril.pow(2).sum(-1).expand(
self._batch_shape + self._event_shape)

def rsample(self, sample_shape=Size()):
shape = self._extended_shape(sample_shape)
eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
return self.loc + _batch_mv(self._unbroadcasted_scale_tril, eps)

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
diff = value - self.loc
M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
half_log_det = self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + M) - half_log_det

def entropy(self):
half_log_det = self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det
if len(self._batch_shape) == 0:
return H
else:
return H.expand(self._batch_shape)
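
Exactly one of covariance_matrix, precision_matrix, or scale_tril may be passed; scale_tril is the cheapest path because no factorization is needed at construction. A small sketch, not part of the diff, assuming a torch-style tensor(...) factory:

import mindtorch.torch as torch
from mindtorch.torch.distributions.multivariate_normal import MultivariateNormal

loc = torch.tensor([0.0, 0.0])
scale_tril = torch.tensor([[1.0, 0.0],
                           [0.5, 1.0]])   # lower-triangular Cholesky factor
dist = MultivariateNormal(loc, scale_tril=scale_tril)

x = dist.rsample((3,))                    # shape (3, 2)
# log_prob computes the Mahalanobis term via solve_triangular and reads the
# half log-determinant off the diagonal of scale_tril.
print(dist.log_prob(x), dist.covariance_matrix)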

+ 95
- 0
mindtorch/torch/distributions/negative_binomial.py View File

@@ -0,0 +1,95 @@
from mindspore import _no_grad as torch_no_grad
import mindtorch.torch.nn.functional as F
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.distribution import Distribution
from mindtorch.torch.distributions.gamma import Gamma
from mindtorch.torch.distributions.utils import broadcast_all, probs_to_logits, lazy_property, logits_to_probs
from mindtorch.torch.functional import exp as torch_exp, sigmoid, poisson, lgamma
from mindtorch.torch._C.Size import Size


class NegativeBinomial(Distribution):
arg_constraints = {'total_count': constraints.greater_than_eq(0),
'probs': constraints.half_open_interval(0., 1.),
'logits': constraints.real}
support = constraints.nonnegative_integer

def __init__(self, total_count, probs=None, logits=None, validate_args=None):
if (probs is None) == (logits is None):
raise ValueError("Either `probs` or `logits` must be specified, but not both.")
if probs is not None:
self.total_count, self.probs, = broadcast_all(total_count, probs)
self.total_count = self.total_count.type_as(self.probs)
else:
self.total_count, self.logits, = broadcast_all(total_count, logits)
self.total_count = self.total_count.type_as(self.logits)

self._param = self.probs if probs is not None else self.logits
batch_shape = self._param.size()
super(NegativeBinomial, self).__init__(batch_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(NegativeBinomial, _instance)
batch_shape = Size(batch_shape)
new.total_count = self.total_count.expand(batch_shape)
if 'probs' in self.__dict__:
new.probs = self.probs.expand(batch_shape)
new._param = new.probs
if 'logits' in self.__dict__:
new.logits = self.logits.expand(batch_shape)
new._param = new.logits
super(NegativeBinomial, new).__init__(batch_shape, validate_args=False)
new._validate_args = self._validate_args
return new

def _new(self, *args, **kwargs):
return self._param.new(*args, **kwargs)

@property
def mean(self):
return self.total_count * torch_exp(self.logits)

@property
def mode(self):
return ((self.total_count - 1) * self.logits.exp()).floor().clamp(min=0.)

@property
def variance(self):
return self.mean / sigmoid(-self.logits)

@lazy_property
def logits(self): # pylint: disable=E0202
return probs_to_logits(self.probs, is_binary=True)

@lazy_property
def probs(self): # pylint: disable=E0202
return logits_to_probs(self.logits, is_binary=True)

@property
def param_shape(self):
return self._param.size()

@lazy_property
def _gamma(self):
# Note we avoid validating because self.total_count can be zero.
return Gamma(concentration=self.total_count,
rate=torch_exp(-self.logits),
validate_args=False)

def sample(self, sample_shape=Size()):
with torch_no_grad():
rate = self._gamma.sample(sample_shape=sample_shape)
return poisson(rate)

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)

log_unnormalized_prob = (self.total_count * F.logsigmoid(-self.logits) +
value * F.logsigmoid(self.logits))

log_normalization = (-lgamma(self.total_count + value) + lgamma(1. + value) + # pylint: disable=E1130
lgamma(self.total_count))
log_normalization[self.total_count + value == 0.] = 0.

return log_unnormalized_prob - log_normalization
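
Sampling goes through the Gamma-Poisson mixture above (rate ~ Gamma(total_count, exp(-logits)), then Poisson(rate)), which is why the helper Gamma skips validation for total_count == 0. A sketch, not part of the diff, assuming a torch-style tensor(...) factory; under this parameterization mean = total_count * probs / (1 - probs):

import mindtorch.torch as torch
from mindtorch.torch.distributions.negative_binomial import NegativeBinomial

dist = NegativeBinomial(total_count=torch.tensor(5.0), probs=torch.tensor(0.4))
k = dist.sample((4,))    # non-negative integer counts
print(dist.mean)         # total_count * probs / (1 - probs) = 10/3
print(dist.log_prob(k))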

+ 86
- 0
mindtorch/torch/distributions/normal.py View File

@@ -0,0 +1,86 @@
import math
from numbers import Real
from numbers import Number

from mindspore import _no_grad as torch_no_grad
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.exp_family import ExponentialFamily
from mindtorch.torch.distributions.utils import _standard_normal, broadcast_all
from mindtorch.torch.functional import normal, erf, erfinv, log as torch_log
from mindtorch.torch._C.Size import Size


class Normal(ExponentialFamily):
arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
support = constraints.real
has_rsample = True
_mean_carrier_measure = 0

@property
def mean(self):
return self.loc

@property
def mode(self):
return self.loc

@property
def stddev(self):
return self.scale

@property
def variance(self):
return self.stddev.pow(2)

def __init__(self, loc, scale, validate_args=None):
self.loc, self.scale = broadcast_all(loc, scale)
if isinstance(loc, Number) and isinstance(scale, Number):
batch_shape = Size()
else:
batch_shape = self.loc.size()
super(Normal, self).__init__(batch_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Normal, _instance)
batch_shape = Size(batch_shape)
new.loc = self.loc.expand(batch_shape)
new.scale = self.scale.expand(batch_shape)
super(Normal, new).__init__(batch_shape, validate_args=False)
new._validate_args = self._validate_args
return new

def sample(self, sample_shape=Size()):
shape = self._extended_shape(sample_shape)
with torch_no_grad():
return normal(self.loc.expand(shape), self.scale.expand(shape))

def rsample(self, sample_shape=Size()):
shape = self._extended_shape(sample_shape)
eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
return self.loc + eps * self.scale

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
# compute the variance
var = (self.scale ** 2)
log_scale = math.log(self.scale) if isinstance(self.scale, Real) else self.scale.log()
return -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi))

def cdf(self, value):
if self._validate_args:
self._validate_sample(value)
return 0.5 * (1 + erf((value - self.loc) * self.scale.reciprocal() / math.sqrt(2)))

def icdf(self, value):
return self.loc + self.scale * erfinv(2 * value - 1) * math.sqrt(2)

def entropy(self):
return 0.5 + 0.5 * math.log(2 * math.pi) + torch_log(self.scale)

@property
def _natural_params(self):
return (self.loc / self.scale.pow(2), -0.5 * self.scale.pow(2).reciprocal())

def _log_normalizer(self, x, y):
return -0.25 * x.pow(2) / y + 0.5 * torch_log(-math.pi / y)
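
Normal provides both sample (wrapped in no-grad) and the reparameterized rsample = loc + eps * scale; the latter is what the transformed and relaxed distributions in this patch build on. A sketch, not part of the diff, assuming a torch-style tensor(...) factory:

import mindtorch.torch as torch
from mindtorch.torch.distributions.normal import Normal

dist = Normal(loc=torch.tensor([0.0, 1.0]), scale=torch.tensor([1.0, 0.5]))
x = dist.rsample((16,))                    # loc + eps * scale, eps ~ N(0, I)
print(dist.log_prob(x).shape)              # (16, 2)
print(dist.cdf(torch.tensor([0.0, 1.0])))  # 0.5 at each mean
print(dist.entropy())                      # 0.5 * log(2 * pi * e * scale**2)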

+ 92
- 0
mindtorch/torch/distributions/one_hot_categorical.py View File

@@ -0,0 +1,92 @@
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.categorical import Categorical
from mindtorch.torch.distributions.distribution import Distribution
from mindtorch.torch.nn.functional import one_hot
from mindtorch.torch.functional import eye
from mindtorch.torch._C.Size import Size


class OneHotCategorical(Distribution):
arg_constraints = {'probs': constraints.simplex,
'logits': constraints.real_vector}
support = constraints.one_hot
has_enumerate_support = True

def __init__(self, probs=None, logits=None, validate_args=None):
self._categorical = Categorical(probs, logits)
batch_shape = self._categorical.batch_shape
event_shape = self._categorical.param_shape[-1:]
super(OneHotCategorical, self).__init__(batch_shape, event_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(OneHotCategorical, _instance)
batch_shape = Size(batch_shape)
new._categorical = self._categorical.expand(batch_shape)
super(OneHotCategorical, new).__init__(batch_shape, self.event_shape, validate_args=False)
new._validate_args = self._validate_args
return new

def _new(self, *args, **kwargs):
return self._categorical._new(*args, **kwargs)

@property
def _param(self):
return self._categorical._param

@property
def probs(self):
return self._categorical.probs

@property
def logits(self):
return self._categorical.logits

@property
def mean(self):
return self._categorical.probs

@property
def mode(self):
probs = self._categorical.probs
mode = probs.argmax(axis=-1)
return one_hot(mode, num_classes=probs.shape[-1]).to(probs)

@property
def variance(self):
return self._categorical.probs * (1 - self._categorical.probs)

@property
def param_shape(self):
return self._categorical.param_shape

def sample(self, sample_shape=Size()):
sample_shape = Size(sample_shape)
probs = self._categorical.probs
num_events = self._categorical._num_events
indices = self._categorical.sample(sample_shape)
return one_hot(indices, num_events).to(probs)

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
indices = value.max(-1)[1]
return self._categorical.log_prob(indices)

def entropy(self):
return self._categorical.entropy()

def enumerate_support(self, expand=True):
n = self.event_shape[0]
values = eye(n, dtype=self._param.dtype, device=self._param.device)
values = values.view((n,) + (1,) * len(self.batch_shape) + (n,))
if expand:
values = values.expand((n,) + self.batch_shape + (n,))
return values

class OneHotCategoricalStraightThrough(OneHotCategorical):
has_rsample = True

def rsample(self, sample_shape=Size()):
samples = self.sample(sample_shape)
probs = self._categorical.probs # cached via @lazy_property
return samples + (probs - probs.detach())
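
The straight-through subclass returns a hard one-hot value in the forward pass while routing gradients through probs via samples + (probs - probs.detach()). A sketch, not part of the diff, assuming a torch-style tensor(...) factory and PyTorch-like detach/autograd behaviour in mindtorch:

import mindtorch.torch as torch
from mindtorch.torch.distributions.one_hot_categorical import OneHotCategoricalStraightThrough

logits = torch.tensor([0.1, 0.5, -0.2])   # would typically require grad during training
dist = OneHotCategoricalStraightThrough(logits=logits)
y = dist.rsample((4,))                    # hard one-hot rows, shape (4, 3)
print(y)
print(dist.log_prob(y))                   # delegates to the underlying Categorical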

+ 44
- 0
mindtorch/torch/distributions/pareto.py View File

@@ -0,0 +1,44 @@
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.exponential import Exponential
from mindtorch.torch.distributions.transformed_distribution import TransformedDistribution
from mindtorch.torch.distributions.transforms import AffineTransform, ExpTransform
from mindtorch.torch.distributions.utils import broadcast_all


class Pareto(TransformedDistribution):
arg_constraints = {'alpha': constraints.positive, 'scale': constraints.positive}

def __init__(self, scale, alpha, validate_args=None):
self.scale, self.alpha = broadcast_all(scale, alpha)
base_dist = Exponential(self.alpha, validate_args=validate_args)
transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
super(Pareto, self).__init__(base_dist, transforms, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Pareto, _instance)
new.scale = self.scale.expand(batch_shape)
new.alpha = self.alpha.expand(batch_shape)
return super(Pareto, self).expand(batch_shape, _instance=new)

@property
def mean(self):
# mean is inf for alpha <= 1
a = self.alpha.clamp(min=1)
return a * self.scale / (a - 1)

@property
def mode(self):
return self.scale

@property
def variance(self):
# var is inf for alpha <= 2
a = self.alpha.clamp(min=2)
return self.scale.pow(2) * a / ((a - 1).pow(2) * (a - 2))

@constraints.dependent_property(is_discrete=False, event_dim=0)
def support(self):
return constraints.greater_than_eq(self.scale)

def entropy(self):
return (self.scale / self.alpha).log() + (1 + self.alpha.reciprocal())

+ 59
- 0
mindtorch/torch/distributions/poisson.py View File

@@ -0,0 +1,59 @@
from numbers import Number

from mindspore import _no_grad as torch_no_grad
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.exp_family import ExponentialFamily
from mindtorch.torch.distributions.utils import broadcast_all
from mindtorch.torch.functional import exp as torch_exp, poisson, log as torch_log
from mindtorch.torch._C.Size import Size


class Poisson(ExponentialFamily):
arg_constraints = {'rate': constraints.nonnegative}
support = constraints.nonnegative_integer

@property
def mean(self):
return self.rate

@property
def mode(self):
return self.rate.floor()

@property
def variance(self):
return self.rate

def __init__(self, rate, validate_args=None):
self.rate, = broadcast_all(rate)
if isinstance(rate, Number):
batch_shape = Size()
else:
batch_shape = self.rate.size()
super(Poisson, self).__init__(batch_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Poisson, _instance)
batch_shape = Size(batch_shape)
new.rate = self.rate.expand(batch_shape)
super(Poisson, new).__init__(batch_shape, validate_args=False)
new._validate_args = self._validate_args
return new

def sample(self, sample_shape=Size()):
shape = self._extended_shape(sample_shape)
with torch_no_grad():
return poisson(self.rate.expand(shape))

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
rate, value = broadcast_all(self.rate, value)
return value.xlogy(rate) - rate - (value + 1).lgamma()

@property
def _natural_params(self):
return (torch_log(self.rate), )

def _log_normalizer(self, x):
return torch_exp(x)
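
A quick sanity check of the Poisson port (illustration only, not part of the diff; assumes a torch-style tensor(...) factory):

import mindtorch.torch as torch
from mindtorch.torch.distributions.poisson import Poisson

dist = Poisson(rate=torch.tensor([1.0, 4.0]))
k = dist.sample((6,))            # integer-valued counts, shape (6, 2)
print(dist.log_prob(k))          # k * log(rate) - rate - lgamma(k + 1)
print(dist.mean, dist.variance)  # both equal rate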

+ 103
- 0
mindtorch/torch/distributions/relaxed_bernoulli.py View File

@@ -0,0 +1,103 @@
from numbers import Number
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.distribution import Distribution
from mindtorch.torch.distributions.transformed_distribution import TransformedDistribution
from mindtorch.torch.distributions.transforms import SigmoidTransform
from mindtorch.torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, \
lazy_property, clamp_probs
from mindtorch.torch.functional import rand
from mindtorch.torch._C.Size import Size


class LogitRelaxedBernoulli(Distribution):
arg_constraints = {'probs': constraints.unit_interval,
'logits': constraints.real}
support = constraints.real

def __init__(self, temperature, probs=None, logits=None, validate_args=None):
self.temperature = temperature
if (probs is None) == (logits is None):
raise ValueError("Either `probs` or `logits` must be specified, but not both.")
if probs is not None:
is_scalar = isinstance(probs, Number)
self.probs, = broadcast_all(probs)
else:
is_scalar = isinstance(logits, Number)
self.logits, = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits
if is_scalar:
batch_shape = Size()
else:
batch_shape = self._param.size()
super(LogitRelaxedBernoulli, self).__init__(batch_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(LogitRelaxedBernoulli, _instance)
batch_shape = Size(batch_shape)
new.temperature = self.temperature
if 'probs' in self.__dict__:
new.probs = self.probs.expand(batch_shape)
new._param = new.probs
if 'logits' in self.__dict__:
new.logits = self.logits.expand(batch_shape)
new._param = new.logits
super(LogitRelaxedBernoulli, new).__init__(batch_shape, validate_args=False)
new._validate_args = self._validate_args
return new

def _new(self, *args, **kwargs):
return self._param.new(*args, **kwargs)

@lazy_property
def logits(self): # pylint: disable=E0202
return probs_to_logits(self.probs, is_binary=True)

@lazy_property
def probs(self): # pylint: disable=E0202
return logits_to_probs(self.logits, is_binary=True)

@property
def param_shape(self):
return self._param.size()

def rsample(self, sample_shape=Size()):
shape = self._extended_shape(sample_shape)
probs = clamp_probs(self.probs.expand(shape))
uniforms = clamp_probs(rand(shape, dtype=probs.dtype, device=probs.device))
return (uniforms.log() - (-uniforms).log1p() + probs.log() - (-probs).log1p()) / self.temperature

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
logits, value = broadcast_all(self.logits, value)
diff = logits - value.mul(self.temperature)
return self.temperature.log() + diff - 2 * diff.exp().log1p()


class RelaxedBernoulli(TransformedDistribution):
arg_constraints = {'probs': constraints.unit_interval,
'logits': constraints.real}
support = constraints.unit_interval
has_rsample = True

def __init__(self, temperature, probs=None, logits=None, validate_args=None):
base_dist = LogitRelaxedBernoulli(temperature, probs, logits)
super(RelaxedBernoulli, self).__init__(base_dist,
SigmoidTransform(),
validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(RelaxedBernoulli, _instance)
return super(RelaxedBernoulli, self).expand(batch_shape, _instance=new)

@property
def temperature(self):
return self.base_dist.temperature

@property
def logits(self):
return self.base_dist.logits

@property
def probs(self):
return self.base_dist.probs
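
RelaxedBernoulli is the sigmoid of the LogitRelaxedBernoulli above (a Concrete / Gumbel-Softmax style relaxation); lower temperatures push samples toward {0, 1} while keeping rsample differentiable. A sketch, not part of the diff, assuming a torch-style tensor(...) factory:

import mindtorch.torch as torch
from mindtorch.torch.distributions.relaxed_bernoulli import RelaxedBernoulli

temperature = torch.tensor(0.5)       # lower temperature -> closer to hard 0/1
dist = RelaxedBernoulli(temperature, probs=torch.tensor([0.2, 0.8]))
y = dist.rsample((4,))                # values in (0, 1), reparameterized
print(y)
print(dist.temperature, dist.probs)   # forwarded from the LogitRelaxedBernoulli base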

+ 93
- 0
mindtorch/torch/distributions/relaxed_categorical.py View File

@@ -0,0 +1,93 @@
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.categorical import Categorical
from mindtorch.torch.distributions.utils import clamp_probs, broadcast_all
from mindtorch.torch.distributions.distribution import Distribution
from mindtorch.torch.distributions.transformed_distribution import TransformedDistribution
from mindtorch.torch.distributions.transforms import ExpTransform
from mindtorch.torch.functional import rand, full_like
from mindtorch.torch._C.Size import Size


class ExpRelaxedCategorical(Distribution):
arg_constraints = {'probs': constraints.simplex,
'logits': constraints.real_vector}
support = constraints.real_vector # The true support is actually a submanifold of this.
has_rsample = True

def __init__(self, temperature, probs=None, logits=None, validate_args=None):
self._categorical = Categorical(probs, logits)
self.temperature = temperature
batch_shape = self._categorical.batch_shape
event_shape = self._categorical.param_shape[-1:]
super(ExpRelaxedCategorical, self).__init__(batch_shape, event_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(ExpRelaxedCategorical, _instance)
batch_shape = Size(batch_shape)
new.temperature = self.temperature
new._categorical = self._categorical.expand(batch_shape)
super(ExpRelaxedCategorical, new).__init__(batch_shape, self.event_shape, validate_args=False)
new._validate_args = self._validate_args
return new

def _new(self, *args, **kwargs):
return self._categorical._new(*args, **kwargs)

@property
def param_shape(self):
return self._categorical.param_shape

@property
def logits(self):
return self._categorical.logits

@property
def probs(self):
return self._categorical.probs

def rsample(self, sample_shape=Size()):
shape = self._extended_shape(sample_shape)
uniforms = clamp_probs(rand(shape, dtype=self.logits.dtype, device=self.logits.device))
gumbels = -((-(uniforms.log())).log())
scores = (self.logits + gumbels) / self.temperature
return scores - scores.logsumexp(dim=-1, keepdim=True)

def log_prob(self, value):
K = self._categorical._num_events
if self._validate_args:
self._validate_sample(value)
logits, value = broadcast_all(self.logits, value)
log_scale = (full_like(self.temperature, float(K)).lgamma() -
self.temperature.log().mul(-(K - 1)))
score = logits - value.mul(self.temperature)
score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1)
return score + log_scale


class RelaxedOneHotCategorical(TransformedDistribution):
arg_constraints = {'probs': constraints.simplex,
'logits': constraints.real_vector}
support = constraints.simplex
has_rsample = True

def __init__(self, temperature, probs=None, logits=None, validate_args=None):
base_dist = ExpRelaxedCategorical(temperature, probs, logits, validate_args=validate_args)
super(RelaxedOneHotCategorical, self).__init__(base_dist,
ExpTransform(),
validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(RelaxedOneHotCategorical, _instance)
return super(RelaxedOneHotCategorical, self).expand(batch_shape, _instance=new)

@property
def temperature(self):
return self.base_dist.temperature

@property
def logits(self):
return self.base_dist.logits

@property
def probs(self):
return self.base_dist.probs

+ 74
- 0
mindtorch/torch/distributions/studentT.py View File

@@ -0,0 +1,74 @@
import math

from mindtorch.torch._six import inf, nan
from mindtorch.torch.distributions import Chi2, constraints
from mindtorch.torch.distributions.distribution import Distribution
from mindtorch.torch.distributions.utils import _standard_normal, broadcast_all
from mindtorch.torch.functional import rsqrt, lgamma, log1p, digamma
from mindtorch.torch._C.Size import Size


class StudentT(Distribution):
arg_constraints = {'df': constraints.positive, 'loc': constraints.real, 'scale': constraints.positive}
support = constraints.real
has_rsample = True

@property
def mean(self):
m = self.loc.clone()
m[self.df <= 1] = nan
return m

@property
def mode(self):
return self.loc

@property
def variance(self):
m = self.df.clone()
m[self.df > 2] = self.scale[self.df > 2].pow(2) * self.df[self.df > 2] / (self.df[self.df > 2] - 2)
m[(self.df <= 2) & (self.df > 1)] = inf
m[self.df <= 1] = nan
return m

def __init__(self, df, loc=0., scale=1., validate_args=None):
self.df, self.loc, self.scale = broadcast_all(df, loc, scale)
self._chi2 = Chi2(self.df)
batch_shape = self.df.size()
super(StudentT, self).__init__(batch_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(StudentT, _instance)
batch_shape = Size(batch_shape)
new.df = self.df.expand(batch_shape)
new.loc = self.loc.expand(batch_shape)
new.scale = self.scale.expand(batch_shape)
new._chi2 = self._chi2.expand(batch_shape)
super(StudentT, new).__init__(batch_shape, validate_args=False)
new._validate_args = self._validate_args
return new

def rsample(self, sample_shape=Size()):
shape = self._extended_shape(sample_shape)
X = _standard_normal(shape, dtype=self.df.dtype, device=self.df.device)
Z = self._chi2.rsample(sample_shape)
Y = X * rsqrt(Z / self.df)
return self.loc + self.scale * Y

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
y = (value - self.loc) / self.scale
Z = (self.scale.log() +
0.5 * self.df.log() +
0.5 * math.log(math.pi) +
lgamma(0.5 * self.df) -
lgamma(0.5 * (self.df + 1.)))
return -0.5 * (self.df + 1.) * log1p(y**2. / self.df) - Z

def entropy(self):
lbeta = lgamma(0.5 * self.df) + math.lgamma(0.5) - lgamma(0.5 * (self.df + 1))
return (self.scale.log() +
0.5 * (self.df + 1) *
(digamma(0.5 * (self.df + 1)) - digamma(0.5 * self.df)) +
0.5 * self.df.log() + lbeta)
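
rsample follows the classical construction loc + scale * X / sqrt(Z / df) with X ~ N(0, 1) and Z ~ Chi2(df), matching the code above. A sketch, not part of the diff, assuming a torch-style tensor(...) factory:

import mindtorch.torch as torch
from mindtorch.torch.distributions.studentT import StudentT

dist = StudentT(df=torch.tensor([3.0]), loc=0.0, scale=1.0)
x = dist.rsample((5,))   # shape (5, 1)
print(dist.log_prob(x))
print(dist.variance)     # scale**2 * df / (df - 2) for df > 2, inf for 1 < df <= 2, nan otherwise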

+ 128
- 0
mindtorch/torch/distributions/transformed_distribution.py View File

@@ -0,0 +1,128 @@
from typing import Dict
from mindspore import _no_grad as torch_no_grad
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.distribution import Distribution
from mindtorch.torch.distributions.independent import Independent
from mindtorch.torch.distributions.transforms import ComposeTransform, Transform
from mindtorch.torch.distributions.utils import _sum_rightmost
from mindtorch.torch._C.Size import Size


class TransformedDistribution(Distribution):
arg_constraints: Dict[str, constraints.Constraint] = {}

def __init__(self, base_distribution, transforms, validate_args=None):
if isinstance(transforms, Transform):
self.transforms = [transforms, ]
elif isinstance(transforms, list):
if not all(isinstance(t, Transform) for t in transforms):
raise ValueError("transforms must be a Transform or a list of Transforms")
self.transforms = transforms
else:
raise ValueError("transforms must be a Transform or list, but was {}".format(transforms))

# Reshape base_distribution according to transforms.
base_shape = base_distribution.batch_shape + base_distribution.event_shape
base_event_dim = len(base_distribution.event_shape)
transform = ComposeTransform(self.transforms)
domain_event_dim = transform.domain.event_dim
if len(base_shape) < domain_event_dim:
raise ValueError("base_distribution needs to have shape with size at least {}, but got {}."
.format(domain_event_dim, base_shape))
shape = transform.forward_shape(base_shape)
expanded_base_shape = transform.inverse_shape(shape)
if base_shape != expanded_base_shape:
base_batch_shape = expanded_base_shape[:len(expanded_base_shape) - base_event_dim]
base_distribution = base_distribution.expand(base_batch_shape)
reinterpreted_batch_ndims = domain_event_dim - base_event_dim
if reinterpreted_batch_ndims > 0:
base_distribution = Independent(base_distribution, reinterpreted_batch_ndims)
self.base_dist = base_distribution

# Compute shapes.
event_dim = transform.codomain.event_dim + max(base_event_dim - domain_event_dim, 0)
assert len(shape) >= event_dim
cut = len(shape) - event_dim
batch_shape = shape[:cut]
event_shape = shape[cut:]
super(TransformedDistribution, self).__init__(batch_shape, event_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(TransformedDistribution, _instance)
batch_shape = Size(batch_shape)
shape = batch_shape + self.event_shape
for t in reversed(self.transforms):
shape = t.inverse_shape(shape)
base_batch_shape = shape[:len(shape) - len(self.base_dist.event_shape)]
new.base_dist = self.base_dist.expand(base_batch_shape)
new.transforms = self.transforms
super(TransformedDistribution, new).__init__(batch_shape, self.event_shape, validate_args=False)
new._validate_args = self._validate_args
return new

@constraints.dependent_property(is_discrete=False)
def support(self):
if not self.transforms:
return self.base_dist.support
support = self.transforms[-1].codomain
if len(self.event_shape) > support.event_dim:
support = constraints.independent(support, len(self.event_shape) - support.event_dim)
return support

@property
def has_rsample(self):
return self.base_dist.has_rsample

def sample(self, sample_shape=Size()):
with torch_no_grad():
x = self.base_dist.sample(sample_shape)
for transform in self.transforms:
x = transform(x)
return x

def rsample(self, sample_shape=Size()):
x = self.base_dist.rsample(sample_shape)
for transform in self.transforms:
x = transform(x)
return x

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
event_dim = len(self.event_shape)
log_prob = 0.0
y = value
for transform in reversed(self.transforms):
x = transform.inv(y)
event_dim += transform.domain.event_dim - transform.codomain.event_dim
log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y),
event_dim - transform.domain.event_dim)
y = x

log_prob = log_prob + _sum_rightmost(self.base_dist.log_prob(y),
event_dim - len(self.base_dist.event_shape))
return log_prob

def _monotonize_cdf(self, value):
sign = 1
for transform in self.transforms:
sign = sign * transform.sign
if isinstance(sign, int) and sign == 1:
return value
return sign * (value - 0.5) + 0.5

def cdf(self, value):
for transform in self.transforms[::-1]:
value = transform.inv(value)
if self._validate_args:
self.base_dist._validate_sample(value)
value = self.base_dist.cdf(value)
value = self._monotonize_cdf(value)
return value

def icdf(self, value):
value = self._monotonize_cdf(value)
value = self.base_dist.icdf(value)
for transform in self.transforms:
value = transform(value)
return value
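
log_prob walks the transforms in reverse and subtracts the summed log|det J| of each step (change of variables). Hand-rolling the LogNormal construction from earlier in this patch makes that bookkeeping visible; illustration only, not part of the diff, assuming a torch-style tensor(...) factory:

import mindtorch.torch as torch
from mindtorch.torch.distributions.normal import Normal
from mindtorch.torch.distributions.transformed_distribution import TransformedDistribution
from mindtorch.torch.distributions.transforms import ExpTransform

base = Normal(torch.tensor(0.0), torch.tensor(1.0))
dist = TransformedDistribution(base, [ExpTransform()])  # same construction LogNormal uses

y = dist.rsample((3,))
# log p(y) = base.log_prob(log y) - log|d exp(x)/dx| at x = log y, i.e. minus log y
print(dist.log_prob(y))
print(base.log_prob(y.log()) - y.log())                 # matches elementwise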

+ 956
- 0
mindtorch/torch/distributions/transforms.py View File

@@ -0,0 +1,956 @@
import functools
import math
import numbers
import operator
import weakref
from typing import List

import mindtorch.torch.nn.functional as F
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.utils import (_sum_rightmost, broadcast_all,
lazy_property, tril_matrix_to_vec,
vec_to_tril_matrix)
from mindtorch.torch.nn.functional import pad
from mindtorch.torch.nn.functional import softplus
from mindtorch.torch._C.Size import Size
from mindtorch.torch.functional import zeros_like, broadcast_shapes, sigmoid, clamp, atanh, full_like, \
abs as torch_abs, tanh, eye, cumsum, cat, stack
from mindtorch.torch.common.dtype import finfo as torch_finfo
from mindtorch.utils import unsupported_attr


__all__ = [
'AbsTransform',
'AffineTransform',
'CatTransform',
'ComposeTransform',
'CorrCholeskyTransform',
'CumulativeDistributionTransform',
'ExpTransform',
'IndependentTransform',
'LowerCholeskyTransform',
'PowerTransform',
'ReshapeTransform',
'SigmoidTransform',
'SoftplusTransform',
'TanhTransform',
'SoftmaxTransform',
'StackTransform',
'StickBreakingTransform',
'Transform',
'identity_transform',
]


class Transform():
bijective = False
domain: constraints.Constraint
codomain: constraints.Constraint

def __init__(self, cache_size=0):
self._cache_size = cache_size
self._inv = None
if cache_size == 0:
pass # default behavior
elif cache_size == 1:
self._cached_x_y = None, None
else:
raise ValueError('cache_size must be 0 or 1')
super(Transform, self).__init__()

@property
def event_dim(self):
if self.domain.event_dim == self.codomain.event_dim:
return self.domain.event_dim
raise ValueError("Please use either .domain.event_dim or .codomain.event_dim")

@property
def inv(self):
inv = None
if self._inv is not None:
inv = self._inv()
if inv is None:
inv = _InverseTransform(self)
self._inv = weakref.ref(inv)
return inv

@property
def sign(self):
raise NotImplementedError

def with_cache(self, cache_size=1):
if self._cache_size == cache_size:
return self
if type(self).__init__ is Transform.__init__:
return type(self)(cache_size=cache_size)
raise NotImplementedError("{}.with_cache is not implemented".format(type(self)))

def __eq__(self, other):
return self is other

def __ne__(self, other):
return not self.__eq__(other)

def __call__(self, x):
if self._cache_size == 0:
return self._call(x)
x_old, y_old = self._cached_x_y
if x is x_old:
return y_old
y = self._call(x)
self._cached_x_y = x, y
return y

def _inv_call(self, y):
if self._cache_size == 0:
return self._inverse(y)
x_old, y_old = self._cached_x_y
if y is y_old:
return x_old
x = self._inverse(y)
self._cached_x_y = x, y
return x

def _call(self, x):
raise NotImplementedError

def _inverse(self, y):
raise NotImplementedError

def log_abs_det_jacobian(self, x, y):
raise NotImplementedError

def __repr__(self):
return self.__class__.__name__ + '()'

def forward_shape(self, shape):
return shape

def inverse_shape(self, shape):
return shape


class _InverseTransform(Transform):
def __init__(self, transform: Transform):
super(_InverseTransform, self).__init__(cache_size=transform._cache_size)
self._inv: Transform = transform

@constraints.dependent_property(is_discrete=False)
def domain(self):
assert self._inv is not None
return self._inv.codomain

@constraints.dependent_property(is_discrete=False)
def codomain(self):
assert self._inv is not None
return self._inv.domain

@property
def bijective(self):
assert self._inv is not None
return self._inv.bijective

@property
def sign(self):
assert self._inv is not None
return self._inv.sign

@property
def inv(self):
return self._inv

def with_cache(self, cache_size=1):
assert self._inv is not None
return self.inv.with_cache(cache_size).inv

def __eq__(self, other):
if not isinstance(other, _InverseTransform):
return False
assert self._inv is not None
return self._inv == other._inv

def __repr__(self):
return f"{self.__class__.__name__}({repr(self._inv)})"

def __call__(self, x):
assert self._inv is not None
return self._inv._inv_call(x)

def log_abs_det_jacobian(self, x, y):
assert self._inv is not None
return -self._inv.log_abs_det_jacobian(y, x)

def forward_shape(self, shape):
return self._inv.inverse_shape(shape)

def inverse_shape(self, shape):
return self._inv.forward_shape(shape)


class ComposeTransform(Transform):
def __init__(self, parts: List[Transform], cache_size=0):
if cache_size:
parts = [part.with_cache(cache_size) for part in parts]
super(ComposeTransform, self).__init__(cache_size=cache_size)
self.parts = parts

def __eq__(self, other):
if not isinstance(other, ComposeTransform):
return False
return self.parts == other.parts

@constraints.dependent_property(is_discrete=False)
def domain(self):
if not self.parts:
return constraints.real
domain = self.parts[0].domain
# Adjust event_dim to be maximum among all parts.
event_dim = self.parts[-1].codomain.event_dim
for part in reversed(self.parts):
event_dim += part.domain.event_dim - part.codomain.event_dim
event_dim = max(event_dim, part.domain.event_dim)
assert event_dim >= domain.event_dim
if event_dim > domain.event_dim:
domain = constraints.independent(domain, event_dim - domain.event_dim)
return domain

@constraints.dependent_property(is_discrete=False)
def codomain(self):
if not self.parts:
return constraints.real
codomain = self.parts[-1].codomain
# Adjust event_dim to be maximum among all parts.
event_dim = self.parts[0].domain.event_dim
for part in self.parts:
event_dim += part.codomain.event_dim - part.domain.event_dim
event_dim = max(event_dim, part.codomain.event_dim)
assert event_dim >= codomain.event_dim
if event_dim > codomain.event_dim:
codomain = constraints.independent(codomain, event_dim - codomain.event_dim)
return codomain

@lazy_property
def bijective(self):
return all(p.bijective for p in self.parts)

@lazy_property
def sign(self):
sign = 1
for p in self.parts:
sign = sign * p.sign
return sign

@property
def inv(self):
inv = None
if self._inv is not None:
inv = self._inv()
if inv is None:
inv = ComposeTransform([p.inv for p in reversed(self.parts)])
self._inv = weakref.ref(inv)
inv._inv = weakref.ref(self)
return inv

def with_cache(self, cache_size=1):
if self._cache_size == cache_size:
return self
return ComposeTransform(self.parts, cache_size=cache_size)

def __call__(self, x):
for part in self.parts:
x = part(x)
return x

def log_abs_det_jacobian(self, x, y):
if not self.parts:
return zeros_like(x)

# Compute intermediates. This will be free if parts[:-1] are all cached.
xs = [x]
for part in self.parts[:-1]:
xs.append(part(xs[-1]))
xs.append(y)

terms = []
event_dim = self.domain.event_dim
for part, x_, y_ in zip(self.parts, xs[:-1], xs[1:]):
terms.append(_sum_rightmost(part.log_abs_det_jacobian(x_, y_),
event_dim - part.domain.event_dim))
event_dim += part.codomain.event_dim - part.domain.event_dim
return functools.reduce(operator.add, terms)

def forward_shape(self, shape):
for part in self.parts:
shape = part.forward_shape(shape)
return shape

def inverse_shape(self, shape):
for part in reversed(self.parts):
shape = part.inverse_shape(shape)
return shape

def __repr__(self):
fmt_string = self.__class__.__name__ + '(\n '
fmt_string += ',\n '.join([p.__repr__() for p in self.parts])
fmt_string += '\n)'
return fmt_string


identity_transform = ComposeTransform([])


class IndependentTransform(Transform):
def __init__(self, base_transform, reinterpreted_batch_ndims, cache_size=0):
super().__init__(cache_size=cache_size)
self.base_transform = base_transform.with_cache(cache_size)
self.reinterpreted_batch_ndims = reinterpreted_batch_ndims

def with_cache(self, cache_size=1):
if self._cache_size == cache_size:
return self
return IndependentTransform(self.base_transform,
self.reinterpreted_batch_ndims,
cache_size=cache_size)

@constraints.dependent_property(is_discrete=False)
def domain(self):
return constraints.independent(self.base_transform.domain,
self.reinterpreted_batch_ndims)

@constraints.dependent_property(is_discrete=False)
def codomain(self):
return constraints.independent(self.base_transform.codomain,
self.reinterpreted_batch_ndims)

@property
def bijective(self):
return self.base_transform.bijective

@property
def sign(self):
return self.base_transform.sign

def _call(self, x):
if x.dim() < self.domain.event_dim:
raise ValueError("Too few dimensions on input")
return self.base_transform(x)

def _inverse(self, y):
if y.dim() < self.codomain.event_dim:
raise ValueError("Too few dimensions on input")
return self.base_transform.inv(y)

def log_abs_det_jacobian(self, x, y):
result = self.base_transform.log_abs_det_jacobian(x, y)
result = _sum_rightmost(result, self.reinterpreted_batch_ndims)
return result

def __repr__(self):
return f"{self.__class__.__name__}({repr(self.base_transform)}, {self.reinterpreted_batch_ndims})"

def forward_shape(self, shape):
return self.base_transform.forward_shape(shape)

def inverse_shape(self, shape):
return self.base_transform.inverse_shape(shape)


class ReshapeTransform(Transform):
bijective = True

def __init__(self, in_shape, out_shape, cache_size=0):
self.in_shape = Size(in_shape)
self.out_shape = Size(out_shape)
if self.in_shape.numel() != self.out_shape.numel():
raise ValueError("in_shape, out_shape have different numbers of elements")
super().__init__(cache_size=cache_size)

@constraints.dependent_property
def domain(self):
return constraints.independent(constraints.real, len(self.in_shape))

@constraints.dependent_property
def codomain(self):
return constraints.independent(constraints.real, len(self.out_shape))

def with_cache(self, cache_size=1):
if self._cache_size == cache_size:
return self
return ReshapeTransform(self.in_shape, self.out_shape, cache_size=cache_size)

def _call(self, x):
batch_shape = x.shape[:x.dim() - len(self.in_shape)]
return x.reshape(batch_shape + self.out_shape)

def _inverse(self, y):
batch_shape = y.shape[:y.dim() - len(self.out_shape)]
return y.reshape(batch_shape + self.in_shape)

def log_abs_det_jacobian(self, x, y):
batch_shape = x.shape[:x.dim() - len(self.in_shape)]
return x.new_zeros(batch_shape)

def forward_shape(self, shape):
if len(shape) < len(self.in_shape):
raise ValueError("Too few dimensions on input")
cut = len(shape) - len(self.in_shape)
if shape[cut:] != self.in_shape:
raise ValueError("Shape mismatch: expected {} but got {}".format(shape[cut:], self.in_shape))
return shape[:cut] + self.out_shape

def inverse_shape(self, shape):
if len(shape) < len(self.out_shape):
raise ValueError("Too few dimensions on input")
cut = len(shape) - len(self.out_shape)
if shape[cut:] != self.out_shape:
raise ValueError("Shape mismatch: expected {} but got {}".format(shape[cut:], self.out_shape))
return shape[:cut] + self.in_shape
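
# Shape sketch (hedged): ReshapeTransform only rewrites the rightmost event dimensions
# and keeps any leading batch dimensions intact.
# >>> t = ReshapeTransform(in_shape=(2, 3), out_shape=(6,))
# >>> t.forward_shape((5, 2, 3))       # -> (5, 6)
# >>> t.inverse_shape((5, 6))          # -> (5, 2, 3)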


class ExpTransform(Transform):
domain = constraints.real
codomain = constraints.positive
bijective = True
sign = +1

def __eq__(self, other):
return isinstance(other, ExpTransform)

def _call(self, x):
return x.exp()

def _inverse(self, y):
return y.log()

def log_abs_det_jacobian(self, x, y):
return x
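
# Note: for y = exp(x) the Jacobian is dy/dx = exp(x), so log|det J| = x, which is why
# log_abs_det_jacobian above simply returns its first argument.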


class PowerTransform(Transform):
domain = constraints.positive
codomain = constraints.positive
bijective = True
sign = +1

def __init__(self, exponent, cache_size=0):
super(PowerTransform, self).__init__(cache_size=cache_size)
self.exponent, = broadcast_all(exponent)

def with_cache(self, cache_size=1):
if self._cache_size == cache_size:
return self
return PowerTransform(self.exponent, cache_size=cache_size)

def __eq__(self, other):
if not isinstance(other, PowerTransform):
return False
return self.exponent.eq(other.exponent).all().item()

def _call(self, x):
return x.pow(self.exponent)

def _inverse(self, y):
return y.pow(1 / self.exponent)

def log_abs_det_jacobian(self, x, y):
return (self.exponent * y / x).abs().log()

def forward_shape(self, shape):
return broadcast_shapes(shape, getattr(self.exponent, "shape", ()))

def inverse_shape(self, shape):
return broadcast_shapes(shape, getattr(self.exponent, "shape", ()))


def _clipped_sigmoid(x):
finfo = torch_finfo(x.dtype)
return clamp(sigmoid(x), min=finfo.tiny, max=1. - finfo.eps)


class SigmoidTransform(Transform):
domain = constraints.real
codomain = constraints.unit_interval
bijective = True
sign = +1

def __eq__(self, other):
return isinstance(other, SigmoidTransform)

def _call(self, x):
return _clipped_sigmoid(x)

def _inverse(self, y):
finfo = torch_finfo(y.dtype)
y = y.clamp(min=finfo.tiny, max=1. - finfo.eps)
return y.log() - (-y).log1p()

def log_abs_det_jacobian(self, x, y):
return -F.softplus(-x) - F.softplus(x) # pylint: disable=E1130


class SoftplusTransform(Transform):
domain = constraints.real
codomain = constraints.positive
bijective = True
sign = +1

def __eq__(self, other):
return isinstance(other, SoftplusTransform)

def _call(self, x):
return softplus(x)

def _inverse(self, y):
return (-y).expm1().neg().log() + y

def log_abs_det_jacobian(self, x, y):
return -softplus(-x) # pylint: disable=E1130


class TanhTransform(Transform):
domain = constraints.real
codomain = constraints.interval(-1.0, 1.0)
bijective = True
sign = +1

def __eq__(self, other):
return isinstance(other, TanhTransform)

def _call(self, x):
return x.tanh()

def _inverse(self, y):
return atanh(y)

def log_abs_det_jacobian(self, x, y):
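        # Uses the numerically stable identity
        # log(1 - tanh(x)^2) = 2 * (log(2) - x - softplus(-2x)),
        # avoiding an explicit tanh followed by a log of a value near zero.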
return 2. * (math.log(2.) - x - softplus(-2. * x))


class AbsTransform(Transform):
domain = constraints.real
codomain = constraints.positive

def __eq__(self, other):
return isinstance(other, AbsTransform)

def _call(self, x):
return x.abs()

def _inverse(self, y):
return y


class AffineTransform(Transform):
bijective = True

def __init__(self, loc, scale, event_dim=0, cache_size=0):
super(AffineTransform, self).__init__(cache_size=cache_size)
self.loc = loc
self.scale = scale
self._event_dim = event_dim

@property
def event_dim(self):
return self._event_dim

@constraints.dependent_property(is_discrete=False)
def domain(self):
if self.event_dim == 0:
return constraints.real
return constraints.independent(constraints.real, self.event_dim)

@constraints.dependent_property(is_discrete=False)
def codomain(self):
if self.event_dim == 0:
return constraints.real
return constraints.independent(constraints.real, self.event_dim)

def with_cache(self, cache_size=1):
if self._cache_size == cache_size:
return self
return AffineTransform(self.loc, self.scale, self.event_dim, cache_size=cache_size)

def __eq__(self, other):
if not isinstance(other, AffineTransform):
return False

if isinstance(self.loc, numbers.Number) and isinstance(other.loc, numbers.Number):
if self.loc != other.loc:
return False
else:
if not (self.loc == other.loc).all().item():
return False

if isinstance(self.scale, numbers.Number) and isinstance(other.scale, numbers.Number):
if self.scale != other.scale:
return False
else:
if not (self.scale == other.scale).all().item():
return False

return True

@property
def sign(self):
if isinstance(self.scale, numbers.Real):
return 1 if float(self.scale) > 0 else -1 if float(self.scale) < 0 else 0
return self.scale.sign()

def _call(self, x):
return self.loc + self.scale * x

def _inverse(self, y):
return (y - self.loc) / self.scale

def log_abs_det_jacobian(self, x, y):
shape = x.shape
scale = self.scale
if isinstance(scale, numbers.Real):
result = full_like(x, math.log(abs(scale)))
else:
result = torch_abs(scale).log()
if self.event_dim:
result_size = result.size()[:-self.event_dim] + (-1,)
result = result.view(result_size).sum(-1)
shape = shape[:-self.event_dim]
return result.expand(shape)

def forward_shape(self, shape):
return broadcast_shapes(shape,
getattr(self.loc, "shape", ()),
getattr(self.scale, "shape", ()))

def inverse_shape(self, shape):
return broadcast_shapes(shape,
getattr(self.loc, "shape", ()),
getattr(self.scale, "shape", ()))
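
# Usage sketch (hedged):
# >>> t = AffineTransform(loc=3.0, scale=2.0)
# >>> y = t(x)                         # y = 3 + 2 * x
# >>> t.inv(y)                         # (y - 3) / 2 recovers x
# >>> t.log_abs_det_jacobian(x, y)     # log|2| broadcast to the shape of x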


class CorrCholeskyTransform(Transform):
domain = constraints.real_vector
codomain = constraints.corr_cholesky
bijective = True

def _call(self, x):
x = tanh(x)
eps = torch_finfo(x.dtype).eps
x = x.clamp(min=-1 + eps, max=1 - eps)
r = vec_to_tril_matrix(x, diag=-1)
# apply stick-breaking on the squared values
# Note that y = sign(r) * sqrt(z * z1m_cumprod)
# = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod)
z = r ** 2
z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1)
# Diagonal elements must be 1.
r = r + eye(r.shape[-1], dtype=r.dtype, device=r.device)
y = r * pad(z1m_cumprod_sqrt[..., :-1], [1, 0], value=1)
return y

def _inverse(self, y):
# inverse stick-breaking
# See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html
y_cumsum = 1 - cumsum(y * y, dim=-1)
y_cumsum_shifted = pad(y_cumsum[..., :-1], [1, 0], value=1)
y_vec = tril_matrix_to_vec(y, diag=-1)
y_cumsum_vec = tril_matrix_to_vec(y_cumsum_shifted, diag=-1)
t = y_vec / (y_cumsum_vec).sqrt()
# inverse of tanh
x = ((1 + t) / (1 - t)).log() / 2
return x

def log_abs_det_jacobian(self, x, y, intermediates=None):
unsupported_attr(intermediates)
        # Because the domain and codomain are spaces of different dimensions, the
        # determinant of the Jacobian is not well-defined. We instead return the
        # `log_abs_det_jacobian` of the map from `x` to the flattened lower
        # triangular part of `y`.

# See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html
y1m_cumsum = 1 - (y * y).cumsum(dim=-1)
# by taking diagonal=-2, we don't need to shift z_cumprod to the right
# also works for 2 x 2 matrix
y1m_cumsum_tril = tril_matrix_to_vec(y1m_cumsum, diag=-2)
stick_breaking_logdet = 0.5 * (y1m_cumsum_tril).log().sum(-1)
tanh_logdet = -2 * (x + softplus(-2 * x) - math.log(2.)).sum(dim=-1)
return stick_breaking_logdet + tanh_logdet

def forward_shape(self, shape):
# Reshape from (..., N) to (..., D, D).
if len(shape) < 1:
raise ValueError("Too few dimensions on input")
N = shape[-1]
D = round((0.25 + 2 * N) ** 0.5 + 0.5)
if D * (D - 1) // 2 != N:
raise ValueError("Input is not a flattend lower-diagonal number")
return shape[:-1] + (D, D)

def inverse_shape(self, shape):
# Reshape from (..., D, D) to (..., N).
if len(shape) < 2:
raise ValueError("Too few dimensions on input")
if shape[-2] != shape[-1]:
raise ValueError("Input is not square")
D = shape[-1]
N = D * (D - 1) // 2
return shape[:-2] + (N,)


class SoftmaxTransform(Transform):
domain = constraints.real_vector
codomain = constraints.simplex

def __eq__(self, other):
return isinstance(other, SoftmaxTransform)

def _call(self, x):
logprobs = x
probs = (logprobs - logprobs.max(-1, True)[0]).exp()
return probs / probs.sum(-1, True)

def _inverse(self, y):
probs = y
return probs.log()

def forward_shape(self, shape):
if len(shape) < 1:
raise ValueError("Too few dimensions on input")
return shape

def inverse_shape(self, shape):
if len(shape) < 1:
raise ValueError("Too few dimensions on input")
return shape


class StickBreakingTransform(Transform):
domain = constraints.real_vector
codomain = constraints.simplex
bijective = True

def __eq__(self, other):
return isinstance(other, StickBreakingTransform)

def _call(self, x):
offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
z = _clipped_sigmoid(x - offset.log())
z_cumprod = (1 - z).cumprod(-1)
y = pad(z, [0, 1], value=1) * pad(z_cumprod, [1, 0], value=1)
return y

def _inverse(self, y):
y_crop = y[..., :-1]
offset = y.shape[-1] - y.new_ones(y_crop.shape[-1]).cumsum(-1)
sf = 1 - y_crop.cumsum(-1)
        # Clamp to keep sf strictly positive; it can underflow to zero or go
        # negative when y[-1] ~ 0 or y[:-1].sum() ~ 1.
sf = clamp(sf, min=torch_finfo(y.dtype).tiny)
x = y_crop.log() - sf.log() + offset.log()
return x

def log_abs_det_jacobian(self, x, y):
offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
x = x - offset.log()
# use the identity 1 - sigmoid(x) = exp(-x) * sigmoid(x)
detJ = (-x + F.logsigmoid(x) + y[..., :-1].log()).sum(-1)
return detJ

def forward_shape(self, shape):
if len(shape) < 1:
raise ValueError("Too few dimensions on input")
return shape[:-1] + (shape[-1] + 1,)

def inverse_shape(self, shape):
if len(shape) < 1:
raise ValueError("Too few dimensions on input")
return shape[:-1] + (shape[-1] - 1,)


class LowerCholeskyTransform(Transform):
domain = constraints.independent(constraints.real, 2)
codomain = constraints.lower_cholesky

def __eq__(self, other):
return isinstance(other, LowerCholeskyTransform)

def _call(self, x):
return x.tril(-1) + x.diagonal(dim1=-2, dim2=-1).exp().diag_embed()

def _inverse(self, y):
return y.tril(-1) + y.diagonal(dim1=-2, dim2=-1).log().diag_embed()


class CatTransform(Transform):
transforms: List[Transform]

def __init__(self, tseq, dim=0, lengths=None, cache_size=0):
assert all(isinstance(t, Transform) for t in tseq)
if cache_size:
tseq = [t.with_cache(cache_size) for t in tseq]
super(CatTransform, self).__init__(cache_size=cache_size)
self.transforms = list(tseq)
if lengths is None:
lengths = [1] * len(self.transforms)
self.lengths = list(lengths)
assert len(self.lengths) == len(self.transforms)
self.dim = dim

@lazy_property
def event_dim(self):
return max(t.event_dim for t in self.transforms)

@lazy_property
def length(self):
return sum(self.lengths)

def with_cache(self, cache_size=1):
if self._cache_size == cache_size:
return self
return CatTransform(self.transforms, self.dim, self.lengths, cache_size)

def _call(self, x):
assert -x.dim() <= self.dim < x.dim()
assert x.size(self.dim) == self.length
yslices = []
start = 0
for trans, length in zip(self.transforms, self.lengths):
xslice = x.narrow(self.dim, start, length)
yslices.append(trans(xslice))
start = start + length # avoid += for jit compat
return cat(yslices, dim=self.dim)

def _inverse(self, y):
assert -y.dim() <= self.dim < y.dim()
assert y.size(self.dim) == self.length
xslices = []
start = 0
for trans, length in zip(self.transforms, self.lengths):
yslice = y.narrow(self.dim, start, length)
xslices.append(trans.inv(yslice))
start = start + length # avoid += for jit compat
return cat(xslices, dim=self.dim)

def log_abs_det_jacobian(self, x, y):
assert -x.dim() <= self.dim < x.dim()
assert x.size(self.dim) == self.length
assert -y.dim() <= self.dim < y.dim()
assert y.size(self.dim) == self.length
logdetjacs = []
start = 0
for trans, length in zip(self.transforms, self.lengths):
xslice = x.narrow(self.dim, start, length)
yslice = y.narrow(self.dim, start, length)
logdetjac = trans.log_abs_det_jacobian(xslice, yslice)
if trans.event_dim < self.event_dim:
logdetjac = _sum_rightmost(logdetjac, self.event_dim - trans.event_dim)
logdetjacs.append(logdetjac)
start = start + length # avoid += for jit compat
# Decide whether to concatenate or sum.
dim = self.dim
if dim >= 0:
dim = dim - x.dim()
dim = dim + self.event_dim
if dim < 0:
return cat(logdetjacs, dim=dim)
else:
return sum(logdetjacs)

@property
def bijective(self):
return all(t.bijective for t in self.transforms)

@constraints.dependent_property
def domain(self):
return constraints.cat([t.domain for t in self.transforms],
self.dim, self.lengths)

@constraints.dependent_property
def codomain(self):
return constraints.cat([t.codomain for t in self.transforms],
self.dim, self.lengths)


class StackTransform(Transform):
transforms: List[Transform]

def __init__(self, tseq, dim=0, cache_size=0):
assert all(isinstance(t, Transform) for t in tseq)
if cache_size:
tseq = [t.with_cache(cache_size) for t in tseq]
super(StackTransform, self).__init__(cache_size=cache_size)
self.transforms = list(tseq)
self.dim = dim

def with_cache(self, cache_size=1):
if self._cache_size == cache_size:
return self
return StackTransform(self.transforms, self.dim, cache_size)

def _slice(self, z):
return [z.select(self.dim, i) for i in range(z.size(self.dim))]

def _call(self, x):
assert -x.dim() <= self.dim < x.dim()
assert x.size(self.dim) == len(self.transforms)
yslices = []
for xslice, trans in zip(self._slice(x), self.transforms):
yslices.append(trans(xslice))
return stack(yslices, dim=self.dim)

def _inverse(self, y):
assert -y.dim() <= self.dim < y.dim()
assert y.size(self.dim) == len(self.transforms)
xslices = []
for yslice, trans in zip(self._slice(y), self.transforms):
xslices.append(trans.inv(yslice))
return stack(xslices, dim=self.dim)

def log_abs_det_jacobian(self, x, y):
assert -x.dim() <= self.dim < x.dim()
assert x.size(self.dim) == len(self.transforms)
assert -y.dim() <= self.dim < y.dim()
assert y.size(self.dim) == len(self.transforms)
logdetjacs = []
yslices = self._slice(y)
xslices = self._slice(x)
for xslice, yslice, trans in zip(xslices, yslices, self.transforms):
logdetjacs.append(trans.log_abs_det_jacobian(xslice, yslice))
return stack(logdetjacs, dim=self.dim)

@property
def bijective(self):
return all(t.bijective for t in self.transforms)

@constraints.dependent_property
def domain(self):
return constraints.stack([t.domain for t in self.transforms], self.dim)

@constraints.dependent_property
def codomain(self):
return constraints.stack([t.codomain for t in self.transforms], self.dim)


class CumulativeDistributionTransform(Transform):
bijective = True
codomain = constraints.unit_interval
sign = +1

def __init__(self, distribution, cache_size=0):
super(CumulativeDistributionTransform, self).__init__(cache_size=cache_size)
self.distribution = distribution

@property
def domain(self):
return self.distribution.support

def _call(self, x):
return self.distribution.cdf(x)

def _inverse(self, y):
return self.distribution.icdf(y)

def log_abs_det_jacobian(self, x, y):
return self.distribution.log_prob(x)

def with_cache(self, cache_size=1):
if self._cache_size == cache_size:
return self
return CumulativeDistributionTransform(self.distribution, cache_size=cache_size)

+ 80
- 0
mindtorch/torch/distributions/uniform.py View File

@@ -0,0 +1,80 @@
from numbers import Number

from mindtorch.torch._six import nan
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.distribution import Distribution
from mindtorch.torch.distributions.utils import broadcast_all
from mindtorch.torch._C.Size import Size
from mindtorch.torch.functional import log as torch_log, rand as torch_rand, lt as torch_lt


class Uniform(Distribution):
arg_constraints = {'low': constraints.dependent(is_discrete=False, event_dim=0),
'high': constraints.dependent(is_discrete=False, event_dim=0)}
has_rsample = True

@property
def mean(self):
return (self.high + self.low) / 2

@property
def mode(self):
return nan * self.high

@property
def stddev(self):
return (self.high - self.low) / 12**0.5

@property
def variance(self):
return (self.high - self.low).pow(2) / 12

def __init__(self, low, high, validate_args=None):
self.low, self.high = broadcast_all(low, high)

if isinstance(low, Number) and isinstance(high, Number):
batch_shape = Size()
else:
batch_shape = self.low.size()
super(Uniform, self).__init__(batch_shape, validate_args=validate_args)

if self._validate_args and not torch_lt(self.low, self.high).all():
raise ValueError("Uniform is not defined when low>= high")

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Uniform, _instance)
batch_shape = Size(batch_shape)
new.low = self.low.expand(batch_shape)
new.high = self.high.expand(batch_shape)
super(Uniform, new).__init__(batch_shape, validate_args=False)
new._validate_args = self._validate_args
return new

@constraints.dependent_property(is_discrete=False, event_dim=0)
def support(self):
return constraints.interval(self.low, self.high)

def rsample(self, sample_shape=Size()):
shape = self._extended_shape(sample_shape)
rand = torch_rand(shape, dtype=self.low.dtype, device=self.low.device)
return self.low + rand * (self.high - self.low)

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
lb = self.low.le(value).type_as(self.low)
ub = self.high.gt(value).type_as(self.low)
return torch_log(lb.mul(ub)) - torch_log(self.high - self.low)

def cdf(self, value):
if self._validate_args:
self._validate_sample(value)
result = (value - self.low) / (self.high - self.low)
return result.clamp(min=0, max=1)

def icdf(self, value):
result = value * (self.high - self.low) + self.low
return result

def entropy(self):
return torch_log(self.high - self.low)
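
# Usage sketch (hedged; assumes this port behaves like torch.distributions.Uniform):
# >>> m = Uniform(0.0, 5.0)
# >>> s = m.rsample(Size((3,)))        # three draws in [0, 5)
# >>> m.log_prob(s)                    # each entry equals -log(5)
# >>> m.entropy()                      # log(5)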

+ 108
- 0
mindtorch/torch/distributions/utils.py View File

@@ -0,0 +1,108 @@
from functools import update_wrapper
from numbers import Number
from typing import Dict, Any
from mindtorch.torch._default_dtype import get_default_dtype
from mindtorch.torch.functional import broadcast_tensors, empty, sigmoid, normal, zeros, ones, log, log1p, \
round as torch_round
from mindtorch.torch.common.dtype import finfo
from mindtorch.torch.conflict_functional import arange as torch_arange
from mindtorch.torch.autograd import enable_grad
from mindtorch.torch._C.Size import Size
from mindtorch.torch._C import _get_tracing_state
from ..tensor import Tensor as torch_Tensor, tensor as torch_tensor
from ..overrides import is_tensor_like
from ..nn.functional import softmax

euler_constant = 0.57721566490153286060 # Euler Mascheroni Constant


def broadcast_all(*values):
if not all(is_tensor_like(v) or isinstance(v, Number)
for v in values):
raise ValueError('Input arguments must all be instances of numbers.Number, '
'torch.Tensor or objects implementing __torch_function__.')
if not all(is_tensor_like(v) for v in values):
options: Dict[str, Any] = dict(dtype=get_default_dtype())
for value in values:
if isinstance(value, torch_Tensor):
options = dict(dtype=value.dtype, device=value.device)
break
new_values = [v if is_tensor_like(v) else torch_tensor(v, **options)
for v in values]
return broadcast_tensors(*new_values)
return broadcast_tensors(*values)


def _standard_normal(shape, dtype, device):
if _get_tracing_state():
return normal(zeros(shape, dtype=dtype, device=device),
ones(shape, dtype=dtype, device=device))
return empty(shape, dtype=dtype, device=device).normal_()


def _sum_rightmost(value, dim):
if dim == 0:
return value
required_shape = value.shape[:-dim] + (-1,)
return value.reshape(required_shape).sum(-1)
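
# Note: _sum_rightmost sums out the trailing `dim` dimensions, e.g. a value of shape
# (B, 2, 3) with dim=2 becomes shape (B,), equivalent to value.reshape(B, -1).sum(-1).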


def logits_to_probs(logits, is_binary=False):
if is_binary:
return sigmoid(logits)
return softmax(logits, dim=-1)


def clamp_probs(probs):
eps = finfo(probs.dtype).eps
return probs.clamp(min=eps, max=1 - eps)


def probs_to_logits(probs, is_binary=False):
ps_clamped = clamp_probs(probs)
if is_binary:
return log(ps_clamped) - log1p(-ps_clamped)
return log(ps_clamped)


class lazy_property:
def __init__(self, wrapped):
self.wrapped = wrapped
update_wrapper(self, wrapped)

def __get__(self, instance, obj_type=None):
if instance is None:
return _lazy_property_and_property(self.wrapped)
with enable_grad():
value = self.wrapped(instance)
setattr(instance, self.wrapped.__name__, value)
return value


class _lazy_property_and_property(lazy_property, property):
def __init__(self, wrapped): # pylint: disable=E0101
return property.__init__(self, wrapped)


def tril_matrix_to_vec(mat, diag=0):
n = mat.shape[-1]
if not _get_tracing_state() and (diag < -n or diag >= n):
raise ValueError(f'diag ({diag}) provided is outside [{-n}, {n-1}].')
arange = torch_arange(n, device=mat.device)
tril_mask = arange < arange.view(-1, 1) + (diag + 1)
vec = mat[..., tril_mask]
return vec


def vec_to_tril_matrix(vec, diag=0):
n = (-(1 + 2 * diag) + ((1 + 2 * diag)**2 + 8 * vec.shape[-1] + 4 * abs(diag) * (diag + 1))**0.5) / 2
eps = finfo(vec.dtype).eps
if not _get_tracing_state() and (round(n) - n > eps):
raise ValueError(f'The size of last dimension is {vec.shape[-1]} which cannot be expressed as ' +
'the lower triangular part of a square D x D matrix.')
n = torch_round(n).long() if isinstance(n, torch_Tensor) else round(n)
mat = vec.new_zeros(vec.shape[:-1] + Size((n, n)))
arange = torch_arange(n, device=vec.device)
tril_mask = arange < arange.view(-1, 1) + (diag + 1)
mat[..., tril_mask] = vec
return mat
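
# Round-trip sketch (hedged):
# >>> mat = vec_to_tril_matrix(vec, diag=-1)   # strictly lower-triangular D x D matrix
# >>> tril_matrix_to_vec(mat, diag=-1)         # recovers `vec`
# With diag=-1, a vector of length D * (D - 1) // 2 maps to a D x D matrix and back.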

+ 115
- 0
mindtorch/torch/distributions/von_mises.py View File

@@ -0,0 +1,115 @@
import math

from mindspore import _no_grad as torch_no_grad
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.distribution import Distribution
from mindtorch.torch.distributions.utils import broadcast_all, lazy_property
from mindtorch.torch.functional import where, cos, zeros, rand, empty
from mindtorch.torch._C.Size import Size
import mindtorch.torch.common.dtype as mindtorch_dtype


def _eval_poly(y, coef):
coef = list(coef)
result = coef.pop()
while coef:
result = coef.pop() + y * result
return result


_I0_COEF_SMALL = [1.0, 3.5156229, 3.0899424, 1.2067492, 0.2659732, 0.360768e-1, 0.45813e-2]
_I0_COEF_LARGE = [0.39894228, 0.1328592e-1, 0.225319e-2, -0.157565e-2, 0.916281e-2,
-0.2057706e-1, 0.2635537e-1, -0.1647633e-1, 0.392377e-2]
_I1_COEF_SMALL = [0.5, 0.87890594, 0.51498869, 0.15084934, 0.2658733e-1, 0.301532e-2, 0.32411e-3]
_I1_COEF_LARGE = [0.39894228, -0.3988024e-1, -0.362018e-2, 0.163801e-2, -0.1031555e-1,
0.2282967e-1, -0.2895312e-1, 0.1787654e-1, -0.420059e-2]

_COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL]
_COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE]


def _log_modified_bessel_fn(x, order=0):
assert order in (0, 1)

# compute small solution
y = (x / 3.75)
y = y * y
small = _eval_poly(y, _COEF_SMALL[order])
if order == 1:
small = x.abs() * small
small = small.log()

# compute large solution
y = 3.75 / x
large = x - 0.5 * x.log() + _eval_poly(y, _COEF_LARGE[order]).log()

result = where(x < 3.75, small, large)
return result


def _rejection_sample(loc, concentration, proposal_r, x):
done = zeros(x.shape, dtype=mindtorch_dtype.bool, device=loc.device)
while not done.all():
u = rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
u1, u2, u3 = u.unbind()
z = cos(math.pi * u1)
f = (1 + proposal_r * z) / (proposal_r + z)
c = concentration * (proposal_r - f)
accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
if accept.any():
x = where(accept, (u3 - 0.5).sign() * f.acos(), x)
done = done | accept
return (x + math.pi + loc) % (2 * math.pi) - math.pi


class VonMises(Distribution):
arg_constraints = {'loc': constraints.real, 'concentration': constraints.positive}
support = constraints.real
has_rsample = False

def __init__(self, loc, concentration, validate_args=None):
self.loc, self.concentration = broadcast_all(loc, concentration)
batch_shape = self.loc.shape
event_shape = Size()

# Parameters for sampling
tau = 1 + (1 + 4 * self.concentration ** 2).sqrt()
rho = (tau - (2 * tau).sqrt()) / (2 * self.concentration)
self._proposal_r = (1 + rho ** 2) / (2 * rho)

super(VonMises, self).__init__(batch_shape, event_shape, validate_args)

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
log_prob = self.concentration * cos(value - self.loc)
log_prob = log_prob - math.log(2 * math.pi) - _log_modified_bessel_fn(self.concentration, order=0)
return log_prob

@torch_no_grad()
def sample(self, sample_shape=Size()):
shape = self._extended_shape(sample_shape)
x = empty(shape, dtype=self.loc.dtype, device=self.loc.device)
return _rejection_sample(self.loc, self.concentration, self._proposal_r, x)

def expand(self, batch_shape):
try:
return super(VonMises, self).expand(batch_shape)
except NotImplementedError:
validate_args = self.__dict__.get('_validate_args')
loc = self.loc.expand(batch_shape)
concentration = self.concentration.expand(batch_shape)
return type(self)(loc, concentration, validate_args=validate_args)

@property
def mean(self):
return self.loc

@property
def mode(self):
return self.loc

@lazy_property
def variance(self):
return 1 - (_log_modified_bessel_fn(self.concentration, order=1) -
_log_modified_bessel_fn(self.concentration, order=0)).exp()
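
# Usage sketch (hedged; reuses the broadcast_all and Size helpers imported above):
# >>> loc, conc = broadcast_all(0.0, 4.0)
# >>> m = VonMises(loc, conc)
# >>> m.sample(Size((5,)))             # rejection-sampled angles in [-pi, pi)
# >>> m.mean, m.variance               # loc, and 1 - I1(conc) / I0(conc)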

+ 53
- 0
mindtorch/torch/distributions/weibull.py View File

@@ -0,0 +1,53 @@
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.exponential import Exponential
from mindtorch.torch.distributions.transformed_distribution import TransformedDistribution
from mindtorch.torch.distributions.transforms import AffineTransform, PowerTransform
from mindtorch.torch.distributions.utils import broadcast_all
from mindtorch.torch.distributions.gumbel import euler_constant
from mindtorch.torch.functional import ones_like, lgamma, exp as torch_exp, log as torch_log


class Weibull(TransformedDistribution):
arg_constraints = {'scale': constraints.positive, 'concentration': constraints.positive}
support = constraints.positive

def __init__(self, scale, concentration, validate_args=None):
self.scale, self.concentration = broadcast_all(scale, concentration)
self.concentration_reciprocal = self.concentration.reciprocal()
base_dist = Exponential(ones_like(self.scale), validate_args=validate_args)
transforms = [PowerTransform(exponent=self.concentration_reciprocal),
AffineTransform(loc=0, scale=self.scale)]
super(Weibull, self).__init__(base_dist,
transforms,
validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Weibull, _instance)
new.scale = self.scale.expand(batch_shape)
new.concentration = self.concentration.expand(batch_shape)
new.concentration_reciprocal = new.concentration.reciprocal()
base_dist = self.base_dist.expand(batch_shape)
transforms = [PowerTransform(exponent=new.concentration_reciprocal),
AffineTransform(loc=0, scale=new.scale)]
super(Weibull, new).__init__(base_dist,
transforms,
validate_args=False)
new._validate_args = self._validate_args
return new

@property
def mean(self):
return self.scale * torch_exp(lgamma(1 + self.concentration_reciprocal))

@property
def mode(self):
return self.scale * ((self.concentration - 1) / self.concentration) ** self.concentration.reciprocal()

@property
def variance(self):
return self.scale.pow(2) * (torch_exp(lgamma(1 + 2 * self.concentration_reciprocal)) -
torch_exp(2 * lgamma(1 + self.concentration_reciprocal)))

def entropy(self):
return euler_constant * (1 - self.concentration_reciprocal) + \
torch_log(self.scale * self.concentration_reciprocal) + 1
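
# Note: Weibull(scale, concentration) is built as scale * Exponential(1) ** (1 / concentration),
# i.e. a PowerTransform followed by an AffineTransform applied to a unit Exponential base.
# Usage sketch (hedged):
# >>> m = Weibull(1.0, 2.0)
# >>> m.mean                           # scale * Gamma(1 + 1/2) = sqrt(pi) / 2 here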

+ 264
- 0
mindtorch/torch/distributions/wishart.py View File

@@ -0,0 +1,264 @@
import math
import warnings
from numbers import Number

from mindtorch.torch.distributions.chi2 import Chi2
from mindtorch.torch._six import nan
from mindtorch.torch.distributions import constraints
from mindtorch.torch.distributions.exp_family import ExponentialFamily
from mindtorch.torch.distributions.utils import lazy_property
from mindtorch.torch.distributions.multivariate_normal import _precision_to_scale_tril
from mindtorch.torch.functional import digamma, broadcast_shapes, eye, cholesky_solve, einsum, tril_indices, randn, \
where, mvlgamma
from mindtorch.torch.conflict_functional import arange as torch_arange
from mindtorch.torch.common.dtype import finfo
from mindtorch.torch._C.Size import Size
from mindtorch.torch._C import _get_tracing_state
from mindtorch.torch.linalg import cholesky, slogdet
from mindtorch.torch.tensor import tensor as torch_tensor


_log_2 = math.log(2)


def _mvdigamma(x, p):
assert x.gt((p - 1) / 2).all(), "Wrong domain for multivariate digamma function."
return digamma(
x.unsqueeze(-1)
- torch_arange(p, dtype=x.dtype, device=x.device).div(2).expand(x.shape + (-1,))
).sum(-1)

def _clamp_above_eps(x):
return x.clamp(min=finfo(x.dtype).eps)

class Wishart(ExponentialFamily):
arg_constraints = {
'covariance_matrix': constraints.positive_definite,
'precision_matrix': constraints.positive_definite,
'scale_tril': constraints.lower_cholesky,
'df': constraints.greater_than(0),
}
support = constraints.positive_definite
has_rsample = True
_mean_carrier_measure = 0

def __init__(self,
df,
covariance_matrix=None,
precision_matrix=None,
scale_tril=None,
validate_args=None):
assert (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) == 1, \
"Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."

param = next(p for p in (covariance_matrix, precision_matrix, scale_tril) if p is not None)

if param.dim() < 2:
raise ValueError("scale_tril must be at least two-dimensional, with optional leading batch dimensions")

if isinstance(df, Number):
batch_shape = Size(param.shape[:-2])
self.df = torch_tensor(df, dtype=param.dtype, device=param.device)
else:
batch_shape = broadcast_shapes(param.shape[:-2], df.shape)
self.df = df.expand(batch_shape)
event_shape = param.shape[-2:]

if self.df.le(event_shape[-1] - 1).any():
raise ValueError(f"Value of df={df} expected to be greater than ndim - 1 = {event_shape[-1]-1}.")

if scale_tril is not None:
self.scale_tril = param.expand(batch_shape + (-1, -1))
elif covariance_matrix is not None:
self.covariance_matrix = param.expand(batch_shape + (-1, -1))
elif precision_matrix is not None:
self.precision_matrix = param.expand(batch_shape + (-1, -1))

self.arg_constraints['df'] = constraints.greater_than(event_shape[-1] - 1)
if self.df.lt(event_shape[-1]).any():
warnings.warn("Low df values detected. Singular samples are highly likely to occur for ndim " \
"- 1 < df < ndim.")

super(Wishart, self).__init__(batch_shape, event_shape, validate_args=validate_args)
self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))]

if scale_tril is not None:
self._unbroadcasted_scale_tril = scale_tril
elif covariance_matrix is not None:
self._unbroadcasted_scale_tril = cholesky(covariance_matrix)
else: # precision_matrix is not None
self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)

# Chi2 distribution is needed for Bartlett decomposition sampling
self._dist_chi2 = Chi2(
df=(
self.df.unsqueeze(-1)
- torch_arange(
self._event_shape[-1],
dtype=self._unbroadcasted_scale_tril.dtype,
device=self._unbroadcasted_scale_tril.device,
).expand(batch_shape + (-1,))
)
)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Wishart, _instance)
batch_shape = Size(batch_shape)
cov_shape = batch_shape + self.event_shape
new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril.expand(cov_shape)
new.df = self.df.expand(batch_shape)

new._batch_dims = [-(x + 1) for x in range(len(batch_shape))]

if 'covariance_matrix' in self.__dict__:
new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
if 'scale_tril' in self.__dict__:
new.scale_tril = self.scale_tril.expand(cov_shape)
if 'precision_matrix' in self.__dict__:
new.precision_matrix = self.precision_matrix.expand(cov_shape)

new._dist_chi2 = Chi2(
df=(
new.df.unsqueeze(-1)
- torch_arange(
self.event_shape[-1],
dtype=new._unbroadcasted_scale_tril.dtype,
device=new._unbroadcasted_scale_tril.device,
).expand(batch_shape + (-1,))
)
)

super(Wishart, new).__init__(batch_shape, self.event_shape, validate_args=False)
new._validate_args = self._validate_args
return new

@lazy_property
def scale_tril(self): # pylint: disable=E0202
return self._unbroadcasted_scale_tril.expand(
self._batch_shape + self._event_shape)

@lazy_property
def covariance_matrix(self): # pylint: disable=E0202
return (
self._unbroadcasted_scale_tril @ self._unbroadcasted_scale_tril.transpose(-2, -1)
).expand(self._batch_shape + self._event_shape)

@lazy_property
def precision_matrix(self): # pylint: disable=E0202
identity = eye(
self._event_shape[-1],
device=self._unbroadcasted_scale_tril.device,
dtype=self._unbroadcasted_scale_tril.dtype,
)
return cholesky_solve(
identity, self._unbroadcasted_scale_tril
).expand(self._batch_shape + self._event_shape)

@property
def mean(self):
return self.df.view(self._batch_shape + (1, 1)) * self.covariance_matrix

@property
def mode(self):
factor = self.df - self.covariance_matrix.shape[-1] - 1
factor[factor <= 0] = nan
return factor.view(self._batch_shape + (1, 1)) * self.covariance_matrix


@property
def variance(self):
V = self.covariance_matrix
diag_V = V.diagonal(dim1=-2, dim2=-1)
return self.df.view(self._batch_shape + (1, 1)) * (V.pow(2) + einsum("...i,...j->...ij", diag_V, diag_V))
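
    # Note: _bartlett_sampling below draws W = L A A^T L^T, where L is the Cholesky
    # factor of the scale matrix and A is lower triangular with chi-distributed
    # diagonal entries (df - i degrees of freedom) and standard normal entries below
    # the diagonal, i.e. the Bartlett decomposition of the Wishart distribution.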

def _bartlett_sampling(self, sample_shape=Size()):
p = self._event_shape[-1]
noise = _clamp_above_eps(
self._dist_chi2.rsample(sample_shape).sqrt()
).diag_embed(dim1=-2, dim2=-1)

i, j = tril_indices(p, p, offset=-1)
noise[..., i, j] = randn(
Size(sample_shape) + self._batch_shape + (int(p * (p - 1) / 2),),
dtype=noise.dtype,
device=noise.device,
)
chol = self._unbroadcasted_scale_tril @ noise
return chol @ chol.transpose(-2, -1)

def rsample(self, sample_shape=Size(), max_try_correction=None):
if max_try_correction is None:
max_try_correction = 3 if _get_tracing_state() else 10

sample_shape = Size(sample_shape)
sample = self._bartlett_sampling(sample_shape)

        # The block below is a temporary workaround to improve numerical stability and
        # should be removed in the future.
        is_singular = ~self.support.check(sample)
if self._batch_shape:
is_singular = is_singular.amax(self._batch_dims)

if _get_tracing_state():
# Less optimized version for JIT
for _ in range(max_try_correction):
sample_new = self._bartlett_sampling(sample_shape)
sample = where(is_singular, sample_new, sample)

is_singular = ~self.support.check(sample)
if self._batch_shape:
is_singular = is_singular.amax(self._batch_dims)

else:
# More optimized version with data-dependent control flow.
if is_singular.any():
warnings.warn("Singular sample detected.")

for _ in range(max_try_correction):
sample_new = self._bartlett_sampling(is_singular[is_singular].shape)
sample[is_singular] = sample_new

is_singular_new = ~self.support.check(sample_new)
if self._batch_shape:
is_singular_new = is_singular_new.amax(self._batch_dims)
is_singular[is_singular.clone()] = is_singular_new

if not is_singular.any():
break

return sample

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
nu = self.df # has shape (batch_shape)
p = self._event_shape[-1] # has singleton shape
return (
- nu * (p * _log_2 / 2 + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1))
- mvlgamma(nu / 2, p=p)
+ (nu - p - 1) / 2 * slogdet(value).logabsdet
- cholesky_solve(value, self._unbroadcasted_scale_tril).diagonal(dim1=-2, dim2=-1).sum(dim=-1) / 2
)

def entropy(self):
nu = self.df # has shape (batch_shape)
p = self._event_shape[-1] # has singleton shape
# V = self.covariance_matrix # has shape (batch_shape x event_shape)
return (
(p + 1) * (p * _log_2 / 2 + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1))
+ mvlgamma(nu / 2, p=p)
- (nu - p - 1) / 2 * _mvdigamma(nu / 2, p=p)
+ nu * p / 2
)

@property
def _natural_params(self):
nu = self.df # has shape (batch_shape)
p = self._event_shape[-1] # has singleton shape
return - self.precision_matrix / 2, (nu - p - 1) / 2

def _log_normalizer(self, x, y):
p = self._event_shape[-1]
return (
(y + (p + 1) / 2) * (- slogdet(- 2 * x).logabsdet + _log_2 * p)
+ mvlgamma(y + (p + 1) / 2, p=p)
)
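
# Usage sketch (hedged; reuses the eye and torch_tensor helpers imported above):
# >>> m = Wishart(df=torch_tensor(4.0), covariance_matrix=eye(2))
# >>> W = m.rsample()                  # a 2 x 2 positive-definite sample
# >>> m.mean                           # df * covariance_matrix, here 4 * I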

+ 2
- 1
mindtorch/torchvision/transforms/autoaugment.py View File

@@ -569,7 +569,8 @@ class AugMix(torch.nn.Module):

def _sample_dirichlet(self, params: Tensor) -> Tensor:
# Must be on a separate method so that we can overwrite it in tests.
return torch._sample_dirichlet(params)
# return torch._sample_dirichlet(params)
raise NotImplementedError("Currently, `AugMix._sample_dirichlet` is not implemented.")

def forward(self, orig_img: Tensor) -> Tensor:
"""

