# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Transformed Distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distribution_lib
from tensorflow.python.ops.distributions import identity_bijector
from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util import deprecation
__all__ = [
"TransformedDistribution",
]
# The following helper functions attempt to statically perform a TF operation.
# These functions make debugging easier since we can do more validation during
# graph construction.
def _static_value(x):
"""Returns the static value of a `Tensor` or `None`."""
return tensor_util.constant_value(ops.convert_to_tensor(x))
def _logical_and(*args):
"""Convenience function which attempts to statically `reduce_all`."""
args_ = [_static_value(x) for x in args]
if any(x is not None and not bool(x) for x in args_):
return constant_op.constant(False)
if all(x is not None and bool(x) for x in args_):
return constant_op.constant(True)
if len(args) == 2:
return math_ops.logical_and(*args)
return math_ops.reduce_all(args)
def _logical_equal(x, y):
"""Convenience function which attempts to statically compute `x == y`."""
x_ = _static_value(x)
y_ = _static_value(y)
if x_ is None or y_ is None:
return math_ops.equal(x, y)
return constant_op.constant(np.array_equal(x_, y_))
def _logical_not(x):
"""Convenience function which attempts to statically apply `logical_not`."""
x_ = _static_value(x)
if x_ is None:
return math_ops.logical_not(x)
return constant_op.constant(np.logical_not(x_))
def _concat_vectors(*args):
"""Convenience function which concatenates input vectors."""
args_ = [_static_value(x) for x in args]
if any(x_ is None for x_ in args_):
return array_ops.concat(args, 0)
return constant_op.constant([x_ for vec_ in args_ for x_ in vec_])
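# For example, `_concat_vectors([2], [3])` folds to the constant `[2, 3]` at
# graph-construction time, while `_concat_vectors([2], some_dynamic_shape)`
# (where `some_dynamic_shape` stands for a shape `Tensor` with no statically
# known value) falls back to a dynamic `concat` op.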
def _pick_scalar_condition(pred, cond_true, cond_false):
"""Convenience function which chooses the condition based on the predicate."""
# Note: This function is only valid if all of pred, cond_true, and cond_false
# are scalars. This means its semantics are arguably more like tf.cond than
# tf.select even though we use tf.select to implement it.
pred_ = _static_value(pred)
if pred_ is None:
return array_ops.where_v2(pred, cond_true, cond_false)
return cond_true if pred_ else cond_false
def _ones_like(x):
"""Convenience function attempts to statically construct `ones_like`."""
# Should only be used for small vectors.
if x.get_shape().is_fully_defined():
return array_ops.ones(x.get_shape().as_list(), dtype=x.dtype)
return array_ops.ones_like(x)
def _ndims_from_shape(shape):
"""Returns `Tensor`'s `rank` implied by a `Tensor` shape."""
if shape.get_shape().ndims not in (None, 1):
raise ValueError("input is not a valid shape: not 1D")
if not shape.dtype.is_integer:
raise TypeError("input is not a valid shape: wrong dtype")
if shape.get_shape().is_fully_defined():
return constant_op.constant(shape.get_shape().as_list()[0])
return array_ops.shape(shape)[0]
def _is_scalar_from_shape(shape):
"""Returns `True` `Tensor` if `Tensor` shape implies a scalar."""
return _logical_equal(_ndims_from_shape(shape), 0)
class TransformedDistribution(distribution_lib.Distribution):
"""A Transformed Distribution.
A `TransformedDistribution` models `p(y)` given a base distribution `p(x)`,
and a deterministic, invertible, differentiable transform, `Y = g(X)`. The
transform is typically an instance of the `Bijector` class and the base
distribution is typically an instance of the `Distribution` class.
A `Bijector` is expected to implement the following functions:
- `forward`,
- `inverse`,
- `inverse_log_det_jacobian`.
The semantics of these functions are outlined in the `Bijector` documentation.
We now describe how a `TransformedDistribution` alters the input/outputs of a
`Distribution` associated with a random variable (rv) `X`.
Write `cdf(Y=y)` for an absolutely continuous cumulative distribution function
of random variable `Y`; write the probability density function `pdf(Y=y) :=
d^k / (dy_1,...,dy_k) cdf(Y=y)` for its derivative with respect to `y` evaluated at
`y`. Assume that `Y = g(X)` where `g` is a deterministic diffeomorphism,
i.e., a non-random, continuous, differentiable, and invertible function.
Write the inverse of `g` as `X = g^{-1}(Y)` and `(J o g)(x)` for the Jacobian
of `g` evaluated at `x`.
A `TransformedDistribution` implements the following operations:
* `sample`
Mathematically: `Y = g(X)`
Programmatically: `bijector.forward(distribution.sample(...))`
* `log_prob`
Mathematically: `(log o pdf)(Y=y) = (log o pdf o g^{-1})(y)
+ (log o abs o det o J o g^{-1})(y)`
Programmatically: `(distribution.log_prob(bijector.inverse(y))
+ bijector.inverse_log_det_jacobian(y))`
* `log_cdf`
Mathematically: `(log o cdf)(Y=y) = (log o cdf o g^{-1})(y)`
Programmatically: `distribution.log_cdf(bijector.inverse(y))`
* and similarly for: `cdf`, `prob`, `log_survival_function`,
`survival_function`.
A simple example constructing a Log-Normal distribution from a Normal
distribution:
```python
ds = tfp.distributions
log_normal = ds.TransformedDistribution(
    distribution=ds.Normal(loc=0., scale=1.),
    bijector=ds.bijectors.Exp(),
    name="LogNormalTransformedDistribution")
```
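Concretely, the `sample` and `log_prob` semantics above specialize to
(a sketch; `y` denotes any positive scalar):
```python
# log_normal.sample()    is distributed as  tf.exp(ds.Normal(0., 1.).sample())
# log_normal.log_prob(y) == ds.Normal(0., 1.).log_prob(tf.math.log(y))
#                           - tf.math.log(y)
```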
A `LogNormal` made from callables:
```python
ds = tfp.distributions
log_normal = ds.TransformedDistribution(
    distribution=ds.Normal(loc=0., scale=1.),
    bijector=ds.bijectors.Inline(
        forward_fn=tf.exp,
        inverse_fn=tf.math.log,
        inverse_log_det_jacobian_fn=(
            lambda y: -tf.reduce_sum(tf.math.log(y), axis=-1))),
    name="LogNormalTransformedDistribution")
```
Another example constructing a Normal from a StandardNormal:
```python
ds = tfp.distributions
normal = ds.TransformedDistribution(
    distribution=ds.Normal(loc=0., scale=1.),
    bijector=ds.bijectors.Affine(
        shift=-1.,
        scale_identity_multiplier=2.),
    name="NormalTransformedDistribution")
```
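Since `Affine` here computes `shift + scale * x`, `normal` above matches the
target distribution built directly (a sketch for comparison):
```python
# `normal` above has the same distribution as:
same_normal = ds.Normal(loc=-1., scale=2.)
```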
A `TransformedDistribution`'s batch- and event-shape are implied by the base
distribution unless explicitly overridden by `batch_shape` or `event_shape`
arguments. Specifying an overriding `batch_shape` (`event_shape`) is
permitted only if the base distribution has scalar batch-shape (event-shape).
The bijector is applied to the distribution as if the distribution possessed
the overridden shape(s). The following example demonstrates how to construct a
multivariate Normal as a `TransformedDistribution`.
```python
ds = tfp.distributions
# We will create two MVNs with batch_shape = event_shape = 2.
mean = [[-1., 0],      # batch:0
        [0., 1]]       # batch:1
chol_cov = [[[1., 0],
             [0, 1]],  # batch:0
            [[1, 0],
             [2, 2]]]  # batch:1
mvn1 = ds.TransformedDistribution(
    distribution=ds.Normal(loc=0., scale=1.),
    bijector=ds.bijectors.Affine(shift=mean, scale_tril=chol_cov),
    batch_shape=[2],  # Valid because base_distribution.batch_shape == [].
    event_shape=[2])  # Valid because base_distribution.event_shape == [].
mvn2 = ds.MultivariateNormalTriL(loc=mean, scale_tril=chol_cov)
# mvn1.log_prob(x) == mvn2.log_prob(x)
```
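With the overrides in place, the batch and event shapes of `mvn1` come from
the `batch_shape`/`event_shape` arguments rather than from the scalar base
distribution; a sketch of the resulting shapes:
```python
# mvn1.batch_shape == [2]; mvn1.event_shape == [2]
x = mvn1.sample(3)  # shape: [3, 2, 2]
mvn1.log_prob(x)    # shape: [3, 2]
```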
"""
@deprecation.deprecated(
"2019-01-01",
"The TensorFlow Distributions library has moved to "
"TensorFlow Probability "
"(https://github.com/tensorflow/probability). You "
"should update all references to use `tfp.distributions` "
"instead of `tf.distributions`.",
warn_once=True)
def __init__(self,
distribution,
bijector=None,
batch_shape=None,
event_shape=None,
validate_args=False,
name=None):
"""Construct a Transformed Distribution.
Args:
distribution: The base distribution instance to transform. Typically an
instance of `Distribution`.
bijector: The object responsible for calculating the transformation.
Typically an instance of `Bijector`. `None` means `Identity()`.
batch_shape: `integer` vector `Tensor` which overrides `distribution`
`batch_shape`; valid only if `distribution.is_scalar_batch()`.
event_shape: `integer` vector `Tensor` which overrides `distribution`
`event_shape`; valid only if `distribution.is_scalar_event()`.
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
performance. When `False` invalid inputs may silently render incorrect
outputs.
name: Python `str` name prefixed to Ops created by this class. Default:
`bijector.name + distribution.name`.
"""
parameters = dict(locals())
name = name or (("" if bijector is None else bijector.name) +
distribution.name)
with ops.name_scope(name, values=[event_shape, batch_shape]) as name:
# For convenience we define some handy constants.
self._zero = constant_op.constant(0, dtype=dtypes.int32, name="zero")
self._empty = constant_op.constant([], dtype=dtypes.int32, name="empty")
if bijector is None:
bijector = identity_bijector.Identity(validate_args=validate_args)
# We will keep track of a static and dynamic version of
# self._is_{batch,event}_override. This way we can do more prior to graph
# execution, including possibly raising Python exceptions.
self._override_batch_shape = self._maybe_validate_shape_override(
batch_shape, distribution.is_scalar_batch(), validate_args,
"batch_shape")
self._is_batch_override = _logical_not(_logical_equal(
_ndims_from_shape(self._override_batch_shape), self._zero))
self._is_maybe_batch_override = bool(
tensor_util.constant_value(self._override_batch_shape) is None or
tensor_util.constant_value(self._override_batch_shape).size != 0)
self._override_event_shape = self._maybe_validate_shape_override(
event_shape, distribution.is_scalar_event(), validate_args,
"event_shape")
self._is_event_override = _logical_not(_logical_equal(
_ndims_from_shape(self._override_event_shape), self._zero))
self._is_maybe_event_override = bool(
tensor_util.constant_value(self._override_event_shape) is None or
tensor_util.constant_value(self._override_event_shape).size != 0)
# To convert a scalar distribution into a multivariate distribution we
# will draw dims from the sample dims, which are otherwise iid. This is
# easy to do except in the case that the base distribution has batch dims
# and we're overriding event shape. When that case happens the event dims
# will incorrectly be to the left of the batch dims. In this case we'll
# cyclically permute the new dims to the left.
self._needs_rotation = _logical_and(
self._is_event_override,
_logical_not(self._is_batch_override),
_logical_not(distribution.is_scalar_batch()))
override_event_ndims = _ndims_from_shape(self._override_event_shape)
self._rotate_ndims = _pick_scalar_condition(
self._needs_rotation, override_event_ndims, 0)
# We'll be reducing the head dims (if at all), i.e., this will be []
# if we don't need to reduce.
self._reduce_event_indices = math_ops.range(
self._rotate_ndims - override_event_ndims, self._rotate_ndims)
self._distribution = distribution
self._bijector = bijector
super(TransformedDistribution, self).__init__(
dtype=self._distribution.dtype,
reparameterization_type=self._distribution.reparameterization_type,
validate_args=validate_args,
allow_nan_stats=self._distribution.allow_nan_stats,
parameters=parameters,
# We let TransformedDistribution access _graph_parents since this class
# is more like a base class than a derived one.
graph_parents=(distribution._graph_parents + # pylint: disable=protected-access
bijector.graph_parents),
name=name)
@property
def distribution(self):
"""Base distribution, p(x)."""
return self._distribution
@property
def bijector(self):
"""Function transforming x => y."""
return self._bijector
def _event_shape_tensor(self):
return self.bijector.forward_event_shape_tensor(
distribution_util.pick_vector(
self._is_event_override,
self._override_event_shape,
self.distribution.event_shape_tensor()))
def _event_shape(self):
# If there's a chance that the event_shape has been overridden, we return
# what we statically know about the `event_shape_override`. This works
# because `_is_maybe_event_override` implies `static_override` is either
# `None` (we don't statically know the `event_shape`) or a non-empty list
# (we do).
#
# Since the `bijector` may change the `event_shape`, we then forward what we
# know to the bijector. This allows the `bijector` to have final say in the
# `event_shape`.
static_override = tensor_util.constant_value_as_shape(
self._override_event_shape)
return self.bijector.forward_event_shape(
static_override
if self._is_maybe_event_override
else self.distribution.event_shape)
def _batch_shape_tensor(self):
return distribution_util.pick_vector(
self._is_batch_override,
self._override_batch_shape,
self.distribution.batch_shape_tensor())
def _batch_shape(self):
# If there's a chance that the batch_shape has been overridden, we return
# what we statically know about the `batch_shape_override`. This works
# because `_is_maybe_batch_override` implies `static_override` is either
# `None` (we don't statically know the `batch_shape`) or a non-empty list
# (we do).
#
# Notice that this implementation parallels the `_event_shape` except that
# the `bijector` doesn't get to alter the `batch_shape`. Recall that
# `batch_shape` is a property of a distribution while `event_shape` is
# shared between both the `distribution` instance and the `bijector`.
static_override = tensor_util.constant_value_as_shape(
self._override_batch_shape)
return (static_override
if self._is_maybe_batch_override
else self.distribution.batch_shape)
def _sample_n(self, n, seed=None):
sample_shape = _concat_vectors(
distribution_util.pick_vector(self._needs_rotation, self._empty, [n]),
self._override_batch_shape,
self._override_event_shape,
distribution_util.pick_vector(self._needs_rotation, [n], self._empty))
x = self.distribution.sample(sample_shape=sample_shape, seed=seed)
x = self._maybe_rotate_dims(x)
# We'll apply the bijector in the `_call_sample_n` function.
return x
def _call_sample_n(self, sample_shape, seed, name, **kwargs):
# We override `_call_sample_n` rather than `_sample_n` so we can ensure that
# the result of `self.bijector.forward` is not modified (and thus caching
# works).
with self._name_scope(name, values=[sample_shape]):
sample_shape = ops.convert_to_tensor(
sample_shape, dtype=dtypes.int32, name="sample_shape")
sample_shape, n = self._expand_sample_shape_to_vector(
sample_shape, "sample_shape")
# First, generate samples. We will possibly generate extra samples in the
# event that we need to reinterpret the samples as part of the
# event_shape.
x = self._sample_n(n, seed, **kwargs)
# Next, we reshape `x` into its final form. We do this prior to the call
# to the bijector to ensure that the bijector caching works.
batch_event_shape = array_ops.shape(x)[1:]
final_shape = array_ops.concat([sample_shape, batch_event_shape], 0)
x = array_ops.reshape(x, final_shape)
# Finally, we apply the bijector's forward transformation. For caching to
# work, it is imperative that this is the last modification to the
# returned result.
y = self.bijector.forward(x, **kwargs)
y = self._set_sample_static_shape(y, sample_shape)
return y
def _log_prob(self, y):
# For caching to work, it is imperative that the bijector is the first to
# modify the input.
x = self.bijector.inverse(y)
event_ndims = self._maybe_get_static_event_ndims()
ildj = self.bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims)
if self.bijector._is_injective: # pylint: disable=protected-access
return self._finish_log_prob_for_one_fiber(y, x, ildj, event_ndims)
lp_on_fibers = [
self._finish_log_prob_for_one_fiber(y, x_i, ildj_i, event_ndims)
for x_i, ildj_i in zip(x, ildj)]
return math_ops.reduce_logsumexp(array_ops.stack(lp_on_fibers), axis=0)
def _finish_log_prob_for_one_fiber(self, y, x, ildj, event_ndims):
"""Finish computation of log_prob on one element of the inverse image."""
x = self._maybe_rotate_dims(x, rotate_right=True)
log_prob = self.distribution.log_prob(x)
if self._is_maybe_event_override:
log_prob = math_ops.reduce_sum(log_prob, self._reduce_event_indices)
log_prob += math_ops.cast(ildj, log_prob.dtype)
if self._is_maybe_event_override and isinstance(event_ndims, int):
log_prob.set_shape(
array_ops.broadcast_static_shape(
y.get_shape().with_rank_at_least(1)[:-event_ndims],
self.batch_shape))
return log_prob
def _prob(self, y):
x = self.bijector.inverse(y)
event_ndims = self._maybe_get_static_event_ndims()
ildj = self.bijector.inverse_log_det_jacobian(y, event_ndims=event_ndims)
if self.bijector._is_injective: # pylint: disable=protected-access
return self._finish_prob_for_one_fiber(y, x, ildj, event_ndims)
prob_on_fibers = [
self._finish_prob_for_one_fiber(y, x_i, ildj_i, event_ndims)
for x_i, ildj_i in zip(x, ildj)]
return sum(prob_on_fibers)
def _finish_prob_for_one_fiber(self, y, x, ildj, event_ndims):
"""Finish computation of prob on one element of the inverse image."""
x = self._maybe_rotate_dims(x, rotate_right=True)
prob = self.distribution.prob(x)
if self._is_maybe_event_override:
prob = math_ops.reduce_prod(prob, self._reduce_event_indices)
prob *= math_ops.exp(math_ops.cast(ildj, prob.dtype))
if self._is_maybe_event_override and isinstance(event_ndims, int):
prob.set_shape(
array_ops.broadcast_static_shape(
y.get_shape().with_rank_at_least(1)[:-event_ndims],
self.batch_shape))
return prob
def _log_cdf(self, y):
if self._is_maybe_event_override:
raise NotImplementedError("log_cdf is not implemented when overriding "
"event_shape")
if not self.bijector._is_injective: # pylint: disable=protected-access
raise NotImplementedError("log_cdf is not implemented when "
"bijector is not injective.")
x = self.bijector.inverse(y)
return self.distribution.log_cdf(x)
def _cdf(self, y):
if self._is_maybe_event_override:
raise NotImplementedError("cdf is not implemented when overriding "
"event_shape")
if not self.bijector._is_injective: # pylint: disable=protected-access
raise NotImplementedError("cdf is not implemented when "
"bijector is not injective.")
x = self.bijector.inverse(y)
return self.distribution.cdf(x)
def _log_survival_function(self, y):
if self._is_maybe_event_override:
raise NotImplementedError("log_survival_function is not implemented when "
"overriding event_shape")
if not self.bijector._is_injective: # pylint: disable=protected-access
raise NotImplementedError("log_survival_function is not implemented when "
"bijector is not injective.")
x = self.bijector.inverse(y)
return self.distribution.log_survival_function(x)
def _survival_function(self, y):
if self._is_maybe_event_override:
raise NotImplementedError("survival_function is not implemented when "
"overriding event_shape")
if not self.bijector._is_injective: # pylint: disable=protected-access
raise NotImplementedError("survival_function is not implemented when "
"bijector is not injective.")
x = self.bijector.inverse(y)
return self.distribution.survival_function(x)
def _quantile(self, value):
if self._is_maybe_event_override:
raise NotImplementedError("quantile is not implemented when overriding "
"event_shape")
if not self.bijector._is_injective: # pylint: disable=protected-access
raise NotImplementedError("quantile is not implemented when "
"bijector is not injective.")
# x_q is the "qth quantile" of X iff q = P[X <= x_q]. Now, since X =
# g^{-1}(Y), q = P[X <= x_q] = P[g^{-1}(Y) <= x_q] = P[Y <= g(x_q)],
# which implies the qth quantile of Y is g(x_q).
inv_cdf = self.distribution.quantile(value)
return self.bijector.forward(inv_cdf)
def _entropy(self):
if not self.bijector.is_constant_jacobian:
raise NotImplementedError("entropy is not implemented")
if not self.bijector._is_injective: # pylint: disable=protected-access
raise NotImplementedError("entropy is not implemented when "
"bijector is not injective.")
# Suppose Y = g(X) where g is a diffeomorphism and X is a continuous rv. It
# can be shown that:
# H[Y] = H[X] + E_X[(log o abs o det o J o g)(X)].
# If is_constant_jacobian then:
# E_X[(log o abs o det o J o g)(X)] = (log o abs o det o J o g)(c)
# where c can be anything.
entropy = self.distribution.entropy()
if self._is_maybe_event_override:
# H[X] = sum_i H[X_i] if X_i are mutually independent.
# This means that a reduce_sum is a simple rescaling.
entropy *= math_ops.cast(math_ops.reduce_prod(self._override_event_shape),
dtype=entropy.dtype.base_dtype)
if self._is_maybe_batch_override:
new_shape = array_ops.concat([
_ones_like(self._override_batch_shape),
self.distribution.batch_shape_tensor()
], 0)
entropy = array_ops.reshape(entropy, new_shape)
multiples = array_ops.concat([
self._override_batch_shape,
_ones_like(self.distribution.batch_shape_tensor())
], 0)
entropy = array_ops.tile(entropy, multiples)
dummy = array_ops.zeros(
shape=array_ops.concat(
[self.batch_shape_tensor(), self.event_shape_tensor()],
0),
dtype=self.dtype)
event_ndims = (self.event_shape.ndims if self.event_shape.ndims is not None
else array_ops.size(self.event_shape_tensor()))
ildj = self.bijector.inverse_log_det_jacobian(
dummy, event_ndims=event_ndims)
entropy -= math_ops.cast(ildj, entropy.dtype)
entropy.set_shape(self.batch_shape)
return entropy
def _maybe_validate_shape_override(self, override_shape, base_is_scalar,
validate_args, name):
"""Helper to __init__ which ensures override batch/event_shape are valid."""
if override_shape is None:
override_shape = []
override_shape = ops.convert_to_tensor(override_shape, dtype=dtypes.int32,
name=name)
if not override_shape.dtype.is_integer:
raise TypeError("shape override must be an integer")
override_is_scalar = _is_scalar_from_shape(override_shape)
if tensor_util.constant_value(override_is_scalar):
return self._empty
dynamic_assertions = []
if override_shape.get_shape().ndims is not None:
if override_shape.get_shape().ndims != 1:
raise ValueError("shape override must be a vector")
elif validate_args:
dynamic_assertions += [check_ops.assert_rank(
override_shape, 1,
message="shape override must be a vector")]
if tensor_util.constant_value(override_shape) is not None:
if any(s <= 0 for s in tensor_util.constant_value(override_shape)):
raise ValueError("shape override must have positive elements")
elif validate_args:
dynamic_assertions += [check_ops.assert_positive(
override_shape,
message="shape override must have positive elements")]
is_both_nonscalar = _logical_and(_logical_not(base_is_scalar),
_logical_not(override_is_scalar))
if tensor_util.constant_value(is_both_nonscalar) is not None:
if tensor_util.constant_value(is_both_nonscalar):
raise ValueError("base distribution not scalar")
elif validate_args:
dynamic_assertions += [check_ops.assert_equal(
is_both_nonscalar, False,
message="base distribution not scalar")]
if not dynamic_assertions:
return override_shape
return control_flow_ops.with_dependencies(
dynamic_assertions, override_shape)
def _maybe_rotate_dims(self, x, rotate_right=False):
"""Helper which rolls left event_dims left or right event_dims right."""
needs_rotation_const = tensor_util.constant_value(self._needs_rotation)
if needs_rotation_const is not None and not needs_rotation_const:
return x
ndims = array_ops.rank(x)
n = (ndims - self._rotate_ndims) if rotate_right else self._rotate_ndims
return array_ops.transpose(
x, _concat_vectors(math_ops.range(n, ndims), math_ops.range(0, n)))
def _maybe_get_static_event_ndims(self):
if self.event_shape.ndims is not None:
return self.event_shape.ndims
event_ndims = array_ops.size(self.event_shape_tensor())
event_ndims_ = distribution_util.maybe_get_static_value(event_ndims)
if event_ndims_ is not None:
return event_ndims_
return event_ndims