Transformed Distribution
Change: 126362065
This commit is contained in:
parent
aabfe1d03e
commit
e2587a0476
@ -171,6 +171,16 @@ cuda_py_tests(
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_tests(
|
||||
name = "transformed_distribution_test",
|
||||
size = "small",
|
||||
srcs = ["python/kernel_tests/transformed_distribution_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
@ -47,6 +47,10 @@ initialized with parameters that define the distributions.
|
||||
|
||||
@@DirichletMultinomial
|
||||
|
||||
### Transformed distributions
|
||||
|
||||
@@ContinuousTransformedDistribution
|
||||
|
||||
## Operators allowing for matrix-free methods
|
||||
|
||||
### Positive definite operators
|
||||
@ -95,4 +99,5 @@ from tensorflow.contrib.distributions.python.ops.operator_pd import *
|
||||
from tensorflow.contrib.distributions.python.ops.operator_pd_cholesky import *
|
||||
from tensorflow.contrib.distributions.python.ops.operator_pd_full import *
|
||||
from tensorflow.contrib.distributions.python.ops.student_t import *
|
||||
from tensorflow.contrib.distributions.python.ops.transformed_distribution import *
|
||||
from tensorflow.contrib.distributions.python.ops.uniform import *
|
||||
|
@ -0,0 +1,79 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for ContinuousTransformedDistribution."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
from scipy import stats
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class ContinuousTransformedDistributionTest(tf.test.TestCase):
|
||||
|
||||
def testContinuousTransformedDistribution(self):
|
||||
with self.test_session():
|
||||
mu = 3.0
|
||||
sigma = 0.02
|
||||
log_normal = tf.contrib.distributions.ContinuousTransformedDistribution(
|
||||
base_dist_cls=tf.contrib.distributions.Normal,
|
||||
mu=mu,
|
||||
sigma=sigma,
|
||||
transform=lambda x: tf.exp(x),
|
||||
inverse=lambda y: tf.log(y),
|
||||
log_det_jacobian=(lambda x: tf.reduce_sum(x)))
|
||||
|
||||
# sample
|
||||
self.assertAllClose([stats.lognorm.mean(s=sigma, scale=np.exp(mu))],
|
||||
[np.mean(log_normal.sample(100000, seed=235).eval())],
|
||||
atol=1e-2)
|
||||
|
||||
# pdf, log_pdf
|
||||
test_vals = np.linspace(0.00001, 10.).astype(np.float32)
|
||||
for test_val in test_vals:
|
||||
expected = stats.lognorm.logpdf(test_val, s=sigma, scale=np.exp(mu))
|
||||
self.assertAllClose([expected], [log_normal.log_pdf(test_val).eval()])
|
||||
self.assertAllClose([np.exp(expected)],
|
||||
[log_normal.pdf(test_val).eval()])
|
||||
|
||||
def testCachedSamplesWithoutInverse(self):
|
||||
with self.test_session() as sess:
|
||||
mu = 3.0
|
||||
sigma = 0.02
|
||||
log_normal = tf.contrib.distributions.ContinuousTransformedDistribution(
|
||||
base_dist_cls=tf.contrib.distributions.Normal,
|
||||
mu=mu,
|
||||
sigma=sigma,
|
||||
transform=lambda x: tf.exp(x),
|
||||
inverse=None,
|
||||
log_det_jacobian=(lambda x: tf.reduce_sum(x)))
|
||||
|
||||
sample = log_normal.sample(1)
|
||||
sample_val, log_pdf_val = sess.run([sample, log_normal.log_pdf(sample)])
|
||||
self.assertAllClose(
|
||||
stats.lognorm.logpdf(sample_val, s=sigma,
|
||||
scale=np.exp(mu)),
|
||||
log_pdf_val,
|
||||
atol=1e-2)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
"was not returned from `sample`"):
|
||||
log_normal.log_pdf(tf.constant(3.0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
@ -0,0 +1,252 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""A Transformed Distribution class."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.distributions.python.ops import distribution # pylint: disable=line-too-long
|
||||
from tensorflow.python.framework import ops
|
||||
|
||||
|
||||
class ContinuousTransformedDistribution(distribution.ContinuousDistribution):
|
||||
"""A Transformed Distribution.
|
||||
|
||||
A Transformed Distribution models `p(y)` given a base distribution `p(x)`,
|
||||
an invertible transform, `y = f(x)`, and the determinant of the Jacobian of
|
||||
`f(x)`.
|
||||
|
||||
Shapes, type, and reparameterization are taken from the base distribution.
|
||||
|
||||
#### Mathematical details
|
||||
|
||||
* `p(x)` - probability distribution for random variable X
|
||||
* `p(y)` - probability distribution for random variable Y
|
||||
* `f` - transform
|
||||
* `g` - inverse transform, `f(g(x)) = x`
|
||||
* `J(x)` - Jacobian of f(x)
|
||||
|
||||
A Transformed Distribution exposes `sample` and `pdf`:
|
||||
|
||||
* `sample`: `y = f(x)`, after drawing a sample of X.
|
||||
* `pdf`: `p(y) = p(x) / det|J(x)| = p(g(y)) / det|J(g(y))|`
|
||||
|
||||
A simple example constructing a Log-Normal distribution from a Normal
|
||||
distribution:
|
||||
|
||||
```
|
||||
logit_normal = ContinuousTransformedDistribution(
|
||||
base_dist=Normal(mu, sigma),
|
||||
transform=lambda x: tf.sigmoid(x),
|
||||
inverse=lambda y: tf.log(y) - tf.log(1. - y),
|
||||
log_det_jacobian=(lambda x:
|
||||
tf.reduce_sum(tf.log(tf.sigmoid(x)) + tf.log(1. - tf.sigmoid(x)),
|
||||
reduction_indices=[-1])))
|
||||
name="LogitNormalTransformedDistribution"
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
base_dist_cls,
|
||||
transform,
|
||||
inverse,
|
||||
log_det_jacobian,
|
||||
name="ContinuousTransformedDistribution",
|
||||
**base_dist_args):
|
||||
"""Construct a Transformed Distribution.
|
||||
|
||||
Args:
|
||||
base_dist_cls: the base distribution class to transform. Must be a
|
||||
subclass of `ContinuousDistribution`.
|
||||
transform: a callable that takes a `Tensor` sample from `base_dist` and
|
||||
returns a `Tensor` of the same shape and type. `x => y`.
|
||||
inverse: a callable that computes the inverse of transform. `y => x`. If
|
||||
None, users can only call `log_pdf` on values returned by `sample`.
|
||||
log_det_jacobian: a callable that takes a `Tensor` sample from `base_dist`
|
||||
and returns the log of the determinant of the Jacobian of `transform`.
|
||||
name: The name for the distribution.
|
||||
**base_dist_args: kwargs to pass on to dist_cls on construction.
|
||||
|
||||
Raises:
|
||||
TypeError: if `base_dist_cls` is not a subclass of
|
||||
`ContinuousDistribution`.
|
||||
"""
|
||||
if not issubclass(base_dist_cls, distribution.ContinuousDistribution):
|
||||
raise TypeError("base_dist_cls must be a subclass of"
|
||||
"ContinuousDistribution.")
|
||||
with ops.op_scope(base_dist_args.values(), name) as scope:
|
||||
self._name = scope
|
||||
self._base_dist = base_dist_cls(**base_dist_args)
|
||||
self._transform = transform
|
||||
self._inverse = inverse
|
||||
self._log_det_jacobian = log_det_jacobian
|
||||
self._inverse_cache = {}
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._base_dist.dtype
|
||||
|
||||
def batch_shape(self, name="batch_shape"):
|
||||
"""Batch dimensions of this instance as a 1-D int32 `Tensor`.
|
||||
|
||||
The product of the dimensions of the `batch_shape` is the number of
|
||||
independent distributions of this kind the instance represents.
|
||||
|
||||
Args:
|
||||
name: name to give to the op.
|
||||
|
||||
Returns:
|
||||
`Tensor` `batch_shape`
|
||||
"""
|
||||
with ops.name_scope(self.name):
|
||||
return self._base_dist.batch_shape(name)
|
||||
|
||||
def get_batch_shape(self):
|
||||
"""`TensorShape` available at graph construction time.
|
||||
|
||||
Same meaning as `batch_shape`. May be only partially defined.
|
||||
|
||||
Returns:
|
||||
batch shape
|
||||
"""
|
||||
return self._base_dist.get_batch_shape()
|
||||
|
||||
def event_shape(self, name="event_shape"):
|
||||
"""Shape of a sample from a single distribution as a 1-D int32 `Tensor`.
|
||||
|
||||
Args:
|
||||
name: name to give to the op.
|
||||
|
||||
Returns:
|
||||
`Tensor` `event_shape`
|
||||
"""
|
||||
with ops.name_scope(self.name):
|
||||
return self._base_dist.event_shape(name)
|
||||
|
||||
def get_event_shape(self):
|
||||
"""`TensorShape` available at graph construction time.
|
||||
|
||||
Same meaning as `event_shape`. May be only partially defined.
|
||||
|
||||
Returns:
|
||||
event shape
|
||||
"""
|
||||
return self._base_dist.get_event_shape()
|
||||
|
||||
@property
|
||||
def base_distribution(self):
|
||||
"""Base distribution, p(x)."""
|
||||
return self._base_dist
|
||||
|
||||
@property
|
||||
def transform(self):
|
||||
"""Function transforming x => y."""
|
||||
return self._transform
|
||||
|
||||
@property
|
||||
def inverse(self):
|
||||
"""Inverse function of transform, y => x."""
|
||||
return self._inverse
|
||||
|
||||
@property
|
||||
def log_det_jacobian(self):
|
||||
"""Function computing the log determinant of the Jacobian of transform."""
|
||||
return self._log_det_jacobian
|
||||
|
||||
def log_pdf(self, y, name="log_pdf"):
|
||||
"""Log pdf of observations in `y`.
|
||||
|
||||
`log ( p(g(y)) / det|J(g(y))| )`, where `g` is the inverse of `transform`.
|
||||
|
||||
Args:
|
||||
y: tensor of dtype `dtype`.
|
||||
name: The name to give this op.
|
||||
|
||||
Returns:
|
||||
log_pdf: tensor of dtype `dtype`, the log-PDFs of `y`.
|
||||
|
||||
Raises:
|
||||
ValueError: if `inverse` was not provided to the distribution and `y` was
|
||||
not returned from `sample`.
|
||||
"""
|
||||
with ops.name_scope(self.name):
|
||||
with ops.op_scope([y], name):
|
||||
y = ops.convert_to_tensor(y)
|
||||
if y.dtype != self.dtype:
|
||||
raise TypeError("Input x dtype does not match dtype: %s vs. %s" %
|
||||
(y.dtype, self.dtype))
|
||||
with ops.name_scope("inverse"):
|
||||
if y in self._inverse_cache:
|
||||
x = self._inverse_cache[y]
|
||||
elif self._inverse:
|
||||
x = self._inverse(y)
|
||||
else:
|
||||
raise ValueError("No inverse function exists and input `y` was not "
|
||||
"returned from `sample`.")
|
||||
with ops.name_scope("log_det_jacobian"):
|
||||
log_det_jacobian = self._log_det_jacobian(x)
|
||||
return self._base_dist.log_likelihood(x) - log_det_jacobian
|
||||
|
||||
def pdf(self, y, name="pdf"):
|
||||
"""The PDF of observations in `y`.
|
||||
|
||||
`p(g(y)) / det|J(g(y))|`, where `g` is the inverse of `transform`.
|
||||
|
||||
Args:
|
||||
y: `Tensor` of dtype `dtype`.
|
||||
name: The name to give this op.
|
||||
|
||||
Returns:
|
||||
pdf: `Tensor` of dtype `dtype`, the pdf values of `y`.
|
||||
"""
|
||||
return super(ContinuousTransformedDistribution, self).pdf(y, name=name)
|
||||
|
||||
def sample(self, n, seed=None, name="sample"):
|
||||
"""Sample `n` observations.
|
||||
|
||||
Samples from the base distribution and then passes through the transform.
|
||||
|
||||
Args:
|
||||
n: scalar, type int32, the number of observations to sample.
|
||||
seed: Python integer, the random seed.
|
||||
name: The name to give this op.
|
||||
|
||||
Returns:
|
||||
samples: `[n, ...]`, a `Tensor` of `n` samples.
|
||||
"""
|
||||
with ops.name_scope(self.name):
|
||||
with ops.name_scope(name):
|
||||
samples = self._base_dist.sample(n=n, seed=seed)
|
||||
with ops.name_scope("transform"):
|
||||
transformed = self._transform(samples)
|
||||
self._inverse_cache[transformed] = samples
|
||||
return transformed
|
||||
|
||||
@property
|
||||
def is_reparameterized(self):
|
||||
return self._base_dist.is_reparameterized
|
||||
|
||||
@property
|
||||
def strict_statistics(self):
|
||||
return self._base_dist.strict_statistics
|
||||
|
||||
@property
|
||||
def strict(self):
|
||||
return self._base_dist.strict
|
Loading…
Reference in New Issue
Block a user