Transformed Distribution

Change: 126362065
This commit is contained in:
A. Unique TensorFlower 2016-06-30 16:15:10 -08:00 committed by TensorFlower Gardener
parent aabfe1d03e
commit e2587a0476
4 changed files with 346 additions and 0 deletions

View File

@ -171,6 +171,16 @@ cuda_py_tests(
],
)
cuda_py_tests(
name = "transformed_distribution_test",
size = "small",
srcs = ["python/kernel_tests/transformed_distribution_test.py"],
additional_deps = [
":distributions_py",
"//tensorflow/python:platform_test",
],
)
filegroup(
name = "all_files",
srcs = glob(

View File

@ -47,6 +47,10 @@ initialized with parameters that define the distributions.
@@DirichletMultinomial
### Transformed distributions
@@ContinuousTransformedDistribution
## Operators allowing for matrix-free methods
### Positive definite operators
@ -95,4 +99,5 @@ from tensorflow.contrib.distributions.python.ops.operator_pd import *
from tensorflow.contrib.distributions.python.ops.operator_pd_cholesky import *
from tensorflow.contrib.distributions.python.ops.operator_pd_full import *
from tensorflow.contrib.distributions.python.ops.student_t import *
from tensorflow.contrib.distributions.python.ops.transformed_distribution import *
from tensorflow.contrib.distributions.python.ops.uniform import *

View File

@ -0,0 +1,79 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ContinuousTransformedDistribution."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from scipy import stats
import tensorflow as tf
class ContinuousTransformedDistributionTest(tf.test.TestCase):
def testContinuousTransformedDistribution(self):
with self.test_session():
mu = 3.0
sigma = 0.02
log_normal = tf.contrib.distributions.ContinuousTransformedDistribution(
base_dist_cls=tf.contrib.distributions.Normal,
mu=mu,
sigma=sigma,
transform=lambda x: tf.exp(x),
inverse=lambda y: tf.log(y),
log_det_jacobian=(lambda x: tf.reduce_sum(x)))
# sample
self.assertAllClose([stats.lognorm.mean(s=sigma, scale=np.exp(mu))],
[np.mean(log_normal.sample(100000, seed=235).eval())],
atol=1e-2)
# pdf, log_pdf
test_vals = np.linspace(0.00001, 10.).astype(np.float32)
for test_val in test_vals:
expected = stats.lognorm.logpdf(test_val, s=sigma, scale=np.exp(mu))
self.assertAllClose([expected], [log_normal.log_pdf(test_val).eval()])
self.assertAllClose([np.exp(expected)],
[log_normal.pdf(test_val).eval()])
def testCachedSamplesWithoutInverse(self):
with self.test_session() as sess:
mu = 3.0
sigma = 0.02
log_normal = tf.contrib.distributions.ContinuousTransformedDistribution(
base_dist_cls=tf.contrib.distributions.Normal,
mu=mu,
sigma=sigma,
transform=lambda x: tf.exp(x),
inverse=None,
log_det_jacobian=(lambda x: tf.reduce_sum(x)))
sample = log_normal.sample(1)
sample_val, log_pdf_val = sess.run([sample, log_normal.log_pdf(sample)])
self.assertAllClose(
stats.lognorm.logpdf(sample_val, s=sigma,
scale=np.exp(mu)),
log_pdf_val,
atol=1e-2)
with self.assertRaisesRegexp(ValueError,
"was not returned from `sample`"):
log_normal.log_pdf(tf.constant(3.0))
if __name__ == "__main__":
tf.test.main()

View File

@ -0,0 +1,252 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Transformed Distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import distribution # pylint: disable=line-too-long
from tensorflow.python.framework import ops
class ContinuousTransformedDistribution(distribution.ContinuousDistribution):
"""A Transformed Distribution.
A Transformed Distribution models `p(y)` given a base distribution `p(x)`,
an invertible transform, `y = f(x)`, and the determinant of the Jacobian of
`f(x)`.
Shapes, type, and reparameterization are taken from the base distribution.
#### Mathematical details
* `p(x)` - probability distribution for random variable X
* `p(y)` - probability distribution for random variable Y
* `f` - transform
* `g` - inverse transform, `f(g(x)) = x`
* `J(x)` - Jacobian of f(x)
A Transformed Distribution exposes `sample` and `pdf`:
* `sample`: `y = f(x)`, after drawing a sample of X.
* `pdf`: `p(y) = p(x) / det|J(x)| = p(g(y)) / det|J(g(y))|`
A simple example constructing a Log-Normal distribution from a Normal
distribution:
```
logit_normal = ContinuousTransformedDistribution(
base_dist=Normal(mu, sigma),
transform=lambda x: tf.sigmoid(x),
inverse=lambda y: tf.log(y) - tf.log(1. - y),
log_det_jacobian=(lambda x:
tf.reduce_sum(tf.log(tf.sigmoid(x)) + tf.log(1. - tf.sigmoid(x)),
reduction_indices=[-1])))
name="LogitNormalTransformedDistribution"
)
```
"""
def __init__(self,
base_dist_cls,
transform,
inverse,
log_det_jacobian,
name="ContinuousTransformedDistribution",
**base_dist_args):
"""Construct a Transformed Distribution.
Args:
base_dist_cls: the base distribution class to transform. Must be a
subclass of `ContinuousDistribution`.
transform: a callable that takes a `Tensor` sample from `base_dist` and
returns a `Tensor` of the same shape and type. `x => y`.
inverse: a callable that computes the inverse of transform. `y => x`. If
None, users can only call `log_pdf` on values returned by `sample`.
log_det_jacobian: a callable that takes a `Tensor` sample from `base_dist`
and returns the log of the determinant of the Jacobian of `transform`.
name: The name for the distribution.
**base_dist_args: kwargs to pass on to dist_cls on construction.
Raises:
TypeError: if `base_dist_cls` is not a subclass of
`ContinuousDistribution`.
"""
if not issubclass(base_dist_cls, distribution.ContinuousDistribution):
raise TypeError("base_dist_cls must be a subclass of"
"ContinuousDistribution.")
with ops.op_scope(base_dist_args.values(), name) as scope:
self._name = scope
self._base_dist = base_dist_cls(**base_dist_args)
self._transform = transform
self._inverse = inverse
self._log_det_jacobian = log_det_jacobian
self._inverse_cache = {}
@property
def name(self):
return self._name
@property
def dtype(self):
return self._base_dist.dtype
def batch_shape(self, name="batch_shape"):
"""Batch dimensions of this instance as a 1-D int32 `Tensor`.
The product of the dimensions of the `batch_shape` is the number of
independent distributions of this kind the instance represents.
Args:
name: name to give to the op.
Returns:
`Tensor` `batch_shape`
"""
with ops.name_scope(self.name):
return self._base_dist.batch_shape(name)
def get_batch_shape(self):
"""`TensorShape` available at graph construction time.
Same meaning as `batch_shape`. May be only partially defined.
Returns:
batch shape
"""
return self._base_dist.get_batch_shape()
def event_shape(self, name="event_shape"):
"""Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
Args:
name: name to give to the op.
Returns:
`Tensor` `event_shape`
"""
with ops.name_scope(self.name):
return self._base_dist.event_shape(name)
def get_event_shape(self):
"""`TensorShape` available at graph construction time.
Same meaning as `event_shape`. May be only partially defined.
Returns:
event shape
"""
return self._base_dist.get_event_shape()
@property
def base_distribution(self):
"""Base distribution, p(x)."""
return self._base_dist
@property
def transform(self):
"""Function transforming x => y."""
return self._transform
@property
def inverse(self):
"""Inverse function of transform, y => x."""
return self._inverse
@property
def log_det_jacobian(self):
"""Function computing the log determinant of the Jacobian of transform."""
return self._log_det_jacobian
def log_pdf(self, y, name="log_pdf"):
"""Log pdf of observations in `y`.
`log ( p(g(y)) / det|J(g(y))| )`, where `g` is the inverse of `transform`.
Args:
y: tensor of dtype `dtype`.
name: The name to give this op.
Returns:
log_pdf: tensor of dtype `dtype`, the log-PDFs of `y`.
Raises:
ValueError: if `inverse` was not provided to the distribution and `y` was
not returned from `sample`.
"""
with ops.name_scope(self.name):
with ops.op_scope([y], name):
y = ops.convert_to_tensor(y)
if y.dtype != self.dtype:
raise TypeError("Input x dtype does not match dtype: %s vs. %s" %
(y.dtype, self.dtype))
with ops.name_scope("inverse"):
if y in self._inverse_cache:
x = self._inverse_cache[y]
elif self._inverse:
x = self._inverse(y)
else:
raise ValueError("No inverse function exists and input `y` was not "
"returned from `sample`.")
with ops.name_scope("log_det_jacobian"):
log_det_jacobian = self._log_det_jacobian(x)
return self._base_dist.log_likelihood(x) - log_det_jacobian
def pdf(self, y, name="pdf"):
"""The PDF of observations in `y`.
`p(g(y)) / det|J(g(y))|`, where `g` is the inverse of `transform`.
Args:
y: `Tensor` of dtype `dtype`.
name: The name to give this op.
Returns:
pdf: `Tensor` of dtype `dtype`, the pdf values of `y`.
"""
return super(ContinuousTransformedDistribution, self).pdf(y, name=name)
def sample(self, n, seed=None, name="sample"):
"""Sample `n` observations.
Samples from the base distribution and then passes through the transform.
Args:
n: scalar, type int32, the number of observations to sample.
seed: Python integer, the random seed.
name: The name to give this op.
Returns:
samples: `[n, ...]`, a `Tensor` of `n` samples.
"""
with ops.name_scope(self.name):
with ops.name_scope(name):
samples = self._base_dist.sample(n=n, seed=seed)
with ops.name_scope("transform"):
transformed = self._transform(samples)
self._inverse_cache[transformed] = samples
return transformed
@property
def is_reparameterized(self):
return self._base_dist.is_reparameterized
@property
def strict_statistics(self):
return self._base_dist.strict_statistics
@property
def strict(self):
return self._base_dist.strict