Add Poisson-LogNormal (approximate) compound distribution.

PiperOrigin-RevId: 163480957
Joshua V. Dillon 2017-07-28 09:50:13 -07:00 committed by TensorFlower Gardener
parent 7635e9db10
commit 4f60ddb470
7 changed files with 815 additions and 235 deletions

tensorflow/contrib/distributions/BUILD

@@ -44,6 +44,7 @@ py_library(
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:histogram_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
@@ -327,6 +328,17 @@ cuda_py_test(
],
)
cuda_py_test(
name = "poisson_lognormal_test",
size = "small",
srcs = ["python/kernel_tests/poisson_lognormal_test.py"],
additional_deps = [
":distributions_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "sample_stats_test",
size = "small",
@@ -518,19 +530,13 @@ cuda_py_test(
cuda_py_test(
name = "vector_diffeomixture_test",
size = "large",
size = "small",
srcs = ["python/kernel_tests/vector_diffeomixture_test.py"],
additional_deps = [
":bijectors_py",
":distributions_py",
"//third_party/py/numpy",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
],
)

tensorflow/contrib/distributions/__init__.py

@@ -44,6 +44,7 @@ from tensorflow.contrib.distributions.python.ops.negative_binomial import *
from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import *
from tensorflow.contrib.distributions.python.ops.onehot_categorical import *
from tensorflow.contrib.distributions.python.ops.poisson import *
from tensorflow.contrib.distributions.python.ops.poisson_lognormal import *
from tensorflow.contrib.distributions.python.ops.quantized_distribution import *
from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import *
from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import *
@@ -117,6 +118,7 @@ _allowed_symbols = [
'Normal',
'NormalWithSoftplusScale',
'Poisson',
'PoissonLogNormalQuadratureCompound',
'StudentT',
'StudentTWithAbsDfSoftplusScale',
'Uniform',

tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py

@@ -0,0 +1,92 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for PoissonLogNormalQuadratureCompoundTest."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import poisson_lognormal
from tensorflow.contrib.distributions.python.ops import test_util
from tensorflow.python.platform import test
class PoissonLogNormalQuadratureCompoundTest(
test_util.DiscreteScalarDistributionTestHelpers, test.TestCase):
"""Tests the PoissonLogNormalQuadratureCompoundTest distribution."""
def testSampleProbConsistent(self):
with self.test_session() as sess:
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=-2.,
scale=1.1,
quadrature_polynomial_degree=10,
validate_args=True)
self.run_test_sample_consistent_log_prob(
sess, pln, rtol=0.1)
def testMeanVariance(self):
with self.test_session() as sess:
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=0.,
scale=1.,
quadrature_polynomial_degree=10,
validate_args=True)
self.run_test_sample_consistent_mean_variance(
sess, pln, rtol=0.02)
def testSampleProbConsistentBroadcastScalar(self):
with self.test_session() as sess:
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=[0., -0.5],
scale=1.,
quadrature_polynomial_degree=10,
validate_args=True)
self.run_test_sample_consistent_log_prob(
sess, pln, rtol=0.1, atol=0.01)
def testMeanVarianceBroadcastScalar(self):
with self.test_session() as sess:
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=[0., -0.5],
scale=1.,
quadrature_polynomial_degree=10,
validate_args=True)
self.run_test_sample_consistent_mean_variance(
sess, pln, rtol=0.1, atol=0.01)
def testSampleProbConsistentBroadcastBoth(self):
with self.test_session() as sess:
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=[[0.], [-0.5]],
scale=[[1., 0.9]],
quadrature_polynomial_degree=10,
validate_args=True)
self.run_test_sample_consistent_log_prob(
sess, pln, rtol=0.1, atol=0.08)
def testMeanVarianceBroadcastBoth(self):
with self.test_session() as sess:
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
loc=[[0.], [-0.5]],
scale=[[1., 0.9]],
quadrature_polynomial_degree=10,
validate_args=True)
self.run_test_sample_consistent_mean_variance(
sess, pln, rtol=0.1, atol=0.01)
if __name__ == "__main__":
test.main()
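
As a sanity check beyond these unit tests, the quadrature approximation can be compared against the closed-form moments of the exact Poisson-LogNormal compound; by the laws of total expectation and variance, `E[K] = exp(loc + scale**2/2)` and `Var[K] = E[K] + (exp(scale**2) - 1) * exp(2*loc + scale**2)`. A minimal sketch, assuming this commit's `tf.contrib.distributions` API:

```python
import numpy as np
import tensorflow as tf

ds = tf.contrib.distributions
pln = ds.PoissonLogNormalQuadratureCompound(
    loc=0., scale=0.5, quadrature_polynomial_degree=10)

with tf.Session() as sess:
  mean_, variance_ = sess.run([pln.mean(), pln.variance()])

# Closed-form moments of the exact (non-approximate) compound.
exact_mean = np.exp(0. + 0.5**2 / 2.)
exact_var = exact_mean + (np.exp(0.5**2) - 1.) * np.exp(2. * 0. + 0.5**2)
print(mean_, exact_mean)      # quadrature vs. exact mean
print(variance_, exact_var)   # quadrature vs. exact variance
```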

tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py

@@ -20,236 +20,16 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.distributions.python.ops import test_util
from tensorflow.contrib.distributions.python.ops import vector_diffeomixture as vector_diffeomixture_lib
from tensorflow.contrib.linalg.python.ops import linear_operator_diag as linop_diag_lib
from tensorflow.contrib.linalg.python.ops import linear_operator_identity as linop_identity_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.platform import test
class VectorDistributionTestHelpers(object):
"""VectorDistributionTestHelpers helps test vector-event distributions."""
def linop(self, num_rows=None, multiplier=None, diag=None):
"""Helper to create non-singular, symmetric, positive definite matrices."""
if num_rows is not None and multiplier is not None:
if any(p is not None for p in [diag]):
raise ValueError("Found extra args for scaled identity.")
return linop_identity_lib.LinearOperatorScaledIdentity(
num_rows=num_rows,
multiplier=multiplier,
is_positive_definite=True)
elif num_rows is not None:
if any(p is not None for p in [multiplier, diag]):
raise ValueError("Found extra args for identity.")
return linop_identity_lib.LinearOperatorIdentity(
num_rows=num_rows,
is_positive_definite=True)
elif diag is not None:
if any(p is not None for p in [num_rows, multiplier]):
raise ValueError("Found extra args for diag.")
return linop_diag_lib.LinearOperatorDiag(
diag=diag,
is_positive_definite=True)
else:
raise ValueError("Must specify at least one arg.")
def run_test_sample_consistent_log_prob(
self,
sess,
dist,
num_samples=int(1e5),
radius=1.,
center=0.,
seed=42,
rtol=1e-2,
atol=0.):
"""Tests that sample/log_prob are mutually consistent.
"Consistency" means that `sample` and `log_prob` correspond to the same
distribution.
The idea of this test is to compute the Monte-Carlo estimate of the volume
enclosed by a hypersphere, i.e., the volume of an `n`-ball. While we could
choose an arbitrary function to integrate, the hypersphere's volume is nice
because it is intuitive, has an easy analytical expression, and works for
`dimensions > 1`.
Technical Details:
Observe that:
```none
int_{R**d} dx [x in Ball(radius=r, center=c)]
= E_{p(X)}[ [X in Ball(r, c)] / p(X) ]
= lim_{m->infty} m**-1 sum_j^m [x[j] in Ball(r, c)] / p(x[j]),
where x[j] ~iid p(X)
```
Thus, for fixed `m`, the above is approximately true when `sample` and
`log_prob` are mutually consistent.
Furthermore, the above calculation has the analytical result:
`pi**(d/2) r**d / Gamma(1 + d/2)`.
Note: this test only verifies a necessary condition for consistency--it does
not verify sufficiency and hence does not prove that `sample` and `log_prob` truly
are consistent. For this reason we recommend testing several different
hyperspheres (assuming the hypersphere is supported by the distribution).
Furthermore, we gain additional trust in this test when `sample` is also
tested against the first and second moments
(`run_test_sample_consistent_mean_covariance`); it is unlikely that
a "best-effort" implementation of `log_prob` would incorrectly pass both
tests for different hyperspheres.
For a discussion on the analytical result (second-line) see:
https://en.wikipedia.org/wiki/Volume_of_an_n-ball.
For a discussion of importance sampling (fourth-line) see:
https://en.wikipedia.org/wiki/Importance_sampling.
Args:
sess: Tensorflow session.
dist: Distribution instance or object which implements `sample`,
`log_prob`, `event_shape_tensor` and `batch_shape_tensor`. The
distribution must have non-zero probability of sampling every point
enclosed by the hypersphere.
num_samples: Python `int` scalar indicating the number of Monte-Carlo
samples to draw from `dist`.
radius: Python `float`-type indicating the radius of the `n`-ball whose
volume we're computing.
center: Python floating-type vector (or scalar) indicating the center of
the `n`-ball whose volume we're computing. When scalar, the value is
broadcast to all event dims.
seed: Python `int` indicating the seed to use when sampling from `dist`.
In general it is not recommended to use `None` during a test as this
increases the likelihood of spurious test failure.
rtol: Python `float`-type indicating the admissible relative error between
actual- and approximate-volumes.
atol: Python `float`-type indicating the admissible absolute error between
actual- and approximate-volumes. In general this should be zero since
a typical radius implies a non-zero volume.
"""
def actual_hypersphere_volume(dims, radius):
# https://en.wikipedia.org/wiki/Volume_of_an_n-ball
# Using tf.lgamma because otherwise we'd have to use SciPy, which is not
# a required dependency of core.
radius = np.asarray(radius)
dims = math_ops.cast(dims, dtype=radius.dtype)
return math_ops.exp(
(dims / 2.) * np.log(np.pi)
- math_ops.lgamma(1. + dims / 2.)
+ dims * math_ops.log(radius))
def is_in_ball(x, radius, center):
return math_ops.cast(linalg_ops.norm(x - center, axis=-1) <= radius,
dtype=x.dtype)
def monte_carlo_hypersphere_volume(dist, num_samples, radius, center):
# https://en.wikipedia.org/wiki/Importance_sampling
x = dist.sample(num_samples, seed=seed)
return math_ops.reduce_mean(
math_ops.exp(-dist.log_prob(x)) * is_in_ball(x, radius, center),
axis=0)
[
batch_shape_,
actual_volume_,
sample_volume_,
] = sess.run([
dist.batch_shape_tensor(),
actual_hypersphere_volume(
dims=dist.event_shape_tensor()[0],
radius=radius),
monte_carlo_hypersphere_volume(
dist,
num_samples=num_samples,
radius=radius,
center=center),
])
self.assertAllClose(np.tile(actual_volume_, reps=batch_shape_),
sample_volume_,
rtol=rtol, atol=atol)
def run_test_sample_consistent_mean_covariance(
self,
sess,
dist,
num_samples=int(1e5),
seed=24,
rtol=1e-2,
atol=0.,
cov_rtol=None,
cov_atol=None):
"""Tests that sample/mean/covariance are consistent with each other.
"Consistency" means that `sample`, `mean`, `covariance`, etc all correspond
to the same distribution.
Args:
sess: Tensorflow session.
dist: Distribution instance or object which implements `sample`,
`log_prob`, `event_shape_tensor` and `batch_shape_tensor`.
num_samples: Python `int` scalar indicating the number of Monte-Carlo
samples to draw from `dist`.
seed: Python `int` indicating the seed to use when sampling from `dist`.
In general it is not recommended to use `None` during a test as this
increases the likelihood of spurious test failure.
rtol: Python `float`-type indicating the admissible relative error between
analytical and sample statistics.
atol: Python `float`-type indicating the admissible absolute error between
analytical and sample statistics.
cov_rtol: Python `float`-type indicating the admissible relative error
between analytical and sample covariance. Default: rtol.
cov_atol: Python `float`-type indicating the admissible absolute error
between analytical and sample covariance. Default: atol.
"""
def vec_osquare(x):
"""Computes the outer-product of a vector, i.e., x.T x."""
return x[..., :, array_ops.newaxis] * x[..., array_ops.newaxis, :]
x = dist.sample(num_samples, seed=seed)
sample_mean = math_ops.reduce_mean(x, axis=0)
sample_covariance = math_ops.reduce_mean(
vec_osquare(x - sample_mean), axis=0)
sample_variance = array_ops.matrix_diag_part(sample_covariance)
sample_stddev = math_ops.sqrt(sample_variance)
[
sample_mean_,
sample_covariance_,
sample_variance_,
sample_stddev_,
mean_,
covariance_,
variance_,
stddev_
] = sess.run([
sample_mean,
sample_covariance,
sample_variance,
sample_stddev,
dist.mean(),
dist.covariance(),
dist.variance(),
dist.stddev(),
])
self.assertAllClose(mean_, sample_mean_, rtol=rtol, atol=atol)
self.assertAllClose(covariance_, sample_covariance_,
rtol=cov_rtol or rtol,
atol=cov_atol or atol)
self.assertAllClose(variance_, sample_variance_, rtol=rtol, atol=atol)
self.assertAllClose(stddev_, sample_stddev_, rtol=rtol, atol=atol)
class VectorDiffeomixtureTest(VectorDistributionTestHelpers, test.TestCase):
class VectorDiffeomixtureTest(
test_util.VectorDistributionTestHelpers, test.TestCase):
"""Tests the VectorDiffeomixture distribution."""
def testSampleProbConsistentBroadcastMix(self):

tensorflow/contrib/distributions/python/ops/poisson_lognormal.py

@@ -0,0 +1,313 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The PoissonLogNormalQuadratureCompound distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import categorical as categorical_lib
from tensorflow.python.ops.distributions import distribution as distribution_lib
from tensorflow.python.ops.distributions import util as distribution_util
__all__ = [
"PoissonLogNormalQuadratureCompound",
]
class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
"""`PoissonLogNormalQuadratureCompound` distribution.
The `PoissonLogNormalQuadratureCompound` is an approximation to a
Poisson-LogNormal [compound distribution](
https://en.wikipedia.org/wiki/Compound_probability_distribution), i.e.,
```none
p(k|loc, scale)
= int_{R_+} dl LogNormal(l | loc, scale) Poisson(k | l)
= int_{R} dz ((lambda(z) sqrt(2) scale)
* exp(-z**2) / (lambda(z) sqrt(2 pi) scale)
* Poisson(k | lambda(z)))
= int_{R} dz exp(-z**2) / sqrt(pi) Poisson(k | lambda(z))
approx= sum{ prob[d] Poisson(k | lambda(grid[d])) : d=0, ..., deg-1 }
```
where `lambda(z) = exp(sqrt(2) scale z + loc)` and the `prob,grid` terms
are from [Gauss--Hermite quadrature](
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature). Note that
the second line makes the substitution
`z(l) = (log(l) - loc) / (sqrt(2) scale)`, which implies `lambda(z)` [above]
and `dl = sqrt(2) scale lambda(z) dz`.
In the non-approximation case, a draw from the LogNormal prior represents the
Poisson rate parameter. Unfortunately, the non-approximate distribution lacks
an analytical probability density function (pdf). Therefore the
`PoissonLogNormalQuadratureCompound` class implements an approximation based
on [Gauss-Hermite quadrature](
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature).
Note: although the `PoissonLogNormalQuadratureCompound` is approximately the
Poisson-LogNormal compound distribution, it is itself a valid distribution.
Viz., it possesses a `sample`, `log_prob`, `mean`, `variance`, etc. which are
all mutually consistent.
#### Mathematical Details
The `PoissonLogNormalQuadratureCompound` approximates a Poisson-LogNormal
[compound distribution](
https://en.wikipedia.org/wiki/Compound_probability_distribution).
Using variable-substitution and [Gauss-Hermite quadrature](
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature) we can
redefine the distribution to be a parameter-less convex combination of `deg`
different Poisson samples.
That is, defined over non-negative integers, this distribution is
parameterized by a (batch of) `loc` and `scale` scalars.
The probability density function (pdf) is,
```none
pdf(k | loc, scale, deg)
= sum{ prob[d] Poisson(k | lambda=exp(sqrt(2) scale grid[d] + loc))
: d=0, ..., deg-1 }
```
where, [`grid, w = numpy.polynomial.hermite.hermgauss(deg)`](
https://docs.scipy.org/doc/numpy-1.10.0/reference/generated/numpy.polynomial.hermite.hermgauss.html)
and `prob = w / sqrt(pi)`.
#### Examples
```python
ds = tf.contrib.distributions
# Create two batches of PoissonLogNormalQuadratureCompounds, one with
# prior `loc = 0.` and another with `loc = -0.5`. In both cases `scale = 1.`
pln = ds.PoissonLogNormalQuadratureCompound(
loc=[0., -0.5],
scale=1.,
quadrature_polynomial_degree=10,
validate_args=True)
```
"""
def __init__(self,
loc,
scale,
quadrature_polynomial_degree=8,
validate_args=False,
allow_nan_stats=True,
name="PoissonLogNormalQuadratureCompound"):
"""Constructs the PoissonLogNormalQuadratureCompound on `R**k`.
Args:
loc: `float`-like (batch of) scalar `Tensor`; the location parameter of
the LogNormal prior.
scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of
the LogNormal prior.
quadrature_polynomial_degree: Python `int`-like scalar.
Default value: 8.
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
performance. When `False` invalid inputs may silently render incorrect
outputs.
allow_nan_stats: Python `bool`, default `True`. When `True`,
statistics (e.g., mean, mode, variance) use the value "`NaN`" to
indicate the result is undefined. When `False`, an exception is raised
if one or more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
Raises:
TypeError: if `loc.dtype != scale.dtype`.
"""
parameters = locals()
with ops.name_scope(name, values=[loc, scale]):
loc = ops.convert_to_tensor(loc, name="loc")
self._loc = loc
scale = ops.convert_to_tensor(scale, name="scale")
self._scale = scale
dtype = loc.dtype.base_dtype
if dtype != scale.dtype.base_dtype:
raise TypeError(
"loc.dtype(\"{}\") does not match scale.dtype(\"{}\")".format(
loc.dtype.name, scale.dtype.name))
self._degree = quadrature_polynomial_degree
grid, prob = np.polynomial.hermite.hermgauss(
deg=quadrature_polynomial_degree)
# It should be that `sum(prob) == sqrt(pi)`, but self-normalization is
# more numerically stable.
prob = prob.astype(dtype.as_numpy_dtype)
prob /= np.linalg.norm(prob, ord=1)
self._mixture_distribution = categorical_lib.Categorical(
logits=np.log(prob),
validate_args=validate_args,
allow_nan_stats=allow_nan_stats)
# The following maps the broadcast of `loc` and `scale` to each grid
# point, i.e., we are creating several log-rates that correspond to the
# different Gauss-Hermite quadrature points and (possible) batches of
# `loc` and `scale`.
self._log_rate = (loc[..., array_ops.newaxis]
+ np.sqrt(2.) * scale[..., array_ops.newaxis] * grid)
self._distribution = poisson_lib.Poisson(
rate=math_ops.exp(self._log_rate, name="rate"),
validate_args=validate_args,
allow_nan_stats=allow_nan_stats)
super(PoissonLogNormalQuadratureCompound, self).__init__(
dtype=dtype,
reparameterization_type=distribution_lib.NOT_REPARAMETERIZED,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,
graph_parents=[loc, scale],
name=name)
@property
def mixture_distribution(self):
"""Distribution which randomly selects a Poisson with Gauss-Hermite rate."""
return self._mixture_distribution
@property
def distribution(self):
"""Base Poisson parameterized by a Gauss-Hermite grid of rates."""
return self._distribution
@property
def loc(self):
"""Location parameter of the LogNormal prior."""
return self._loc
@property
def scale(self):
"""Scale parameter of the LogNormal prior."""
return self._scale
@property
def quadrature_polynomial_degree(self):
"""Polynomial largest exponent used for Gauss-Hermite quadrature."""
return self._degree
def _batch_shape_tensor(self):
return array_ops.broadcast_dynamic_shape(
array_ops.shape(self.loc),
array_ops.shape(self.scale))
def _batch_shape(self):
return array_ops.broadcast_static_shape(
self.loc.shape,
self.scale.shape)
def _event_shape(self):
return tensor_shape.scalar()
def _sample_n(self, n, seed=None):
# Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[], in
# which case ids is a [n]-shaped vector.
batch_size = (np.prod(self.batch_shape.as_list(), dtype=np.int32)
if self.batch_shape.is_fully_defined()
else math_ops.reduce_prod(self.batch_shape_tensor()))
ids = self._mixture_distribution.sample(
sample_shape=concat_vectors(
[n],
distribution_util.pick_vector(
self.is_scalar_batch(),
np.int32([]),
[batch_size])),
seed=distribution_util.gen_new_seed(
seed, "poisson_lognormal_quadrature_compound"))
# Stride by `quadrature_polynomial_degree` once per batch member, so `ids`
# indexes into the flattened [batch_size * degree] rate vector.
offset = math_ops.range(start=0,
limit=batch_size * self._degree,
delta=self._degree,
dtype=ids.dtype)
ids += offset
rate = array_ops.gather(
array_ops.reshape(self.distribution.rate, shape=[-1]), ids)
rate = array_ops.reshape(
rate, shape=concat_vectors([n], self.batch_shape_tensor()))
return random_ops.random_poisson(
lam=rate, shape=[], dtype=self.dtype, seed=seed)
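
The offset/gather indexing in `_sample_n` is easiest to see in plain NumPy; a sketch of the same selection logic (shapes here are illustrative only):

```python
import numpy as np

batch_size, deg, n = 3, 4, 5
rate = np.arange(batch_size * deg, dtype=np.float64).reshape([batch_size, deg])

# ids[i, b] picks a quadrature component for draw i of batch member b.
ids = np.random.randint(deg, size=[n, batch_size])
# Striding by `deg` makes each batch member index its own row of the
# flattened [batch_size * deg] rate vector.
offset = np.arange(batch_size) * deg
picked = rate.reshape([-1])[ids + offset]    # shape [n, batch_size]
assert np.array_equal(picked, rate[np.arange(batch_size), ids])
```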
def _log_prob(self, x):
return math_ops.reduce_logsumexp(
(self.mixture_distribution.logits
+ self.distribution.log_prob(x[..., array_ops.newaxis])),
axis=-1)
def _mean(self):
return math_ops.exp(
math_ops.reduce_logsumexp(
self.mixture_distribution.logits + self._log_rate,
axis=-1))
def _variance(self):
return math_ops.exp(self._log_variance())
def _stddev(self):
return math_ops.exp(0.5 * self._log_variance())
def _log_variance(self):
# Following calculation is based on law of total variance:
#
# Var[Z] = E[Var[Z | V]] + Var[E[Z | V]]
#
# where,
#
# Z|v ~ interpolate_affine[v](distribution)
# V ~ mixture_distribution
#
# thus,
#
# E[Var[Z | V]] = sum{ prob[d] Var[d] : d=0, ..., deg-1 }
# Var[E[Z | V]] = sum{ prob[d] (Mean[d] - Mean)**2 : d=0, ..., deg-1 }
v = array_ops.stack([
# log(self.distribution.variance()) = log(Var[d]) = log(rate[d])
self._log_rate,
# log((Mean[d] - Mean)**2)
2. * math_ops.log(
math_ops.abs(self.distribution.mean()
- self._mean()[..., array_ops.newaxis])),
], axis=-1)
return math_ops.reduce_logsumexp(
self.mixture_distribution.logits[..., array_ops.newaxis] + v,
axis=[-2, -1])
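
In ordinary (rather than log) space, `_log_variance` is just the law of total variance over the quadrature mixture; a small NumPy sketch of the same arithmetic:

```python
import numpy as np

grid, w = np.polynomial.hermite.hermgauss(10)
prob = w / np.sqrt(np.pi)
rate = np.exp(np.sqrt(2.) * 1. * grid + 0.)     # loc=0., scale=1.

mean = np.sum(prob * rate)                      # E[E[Z | V]]
# For Poisson components, Var[d] == Mean[d] == rate[d].
var = np.sum(prob * rate) + np.sum(prob * (rate - mean) ** 2)
```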
def static_value(x):
"""Returns the static value of a `Tensor` or `None`."""
return tensor_util.constant_value(ops.convert_to_tensor(x))
def concat_vectors(*args):
"""Concatenates input vectors, statically if possible."""
args_ = [static_value(x) for x in args]
if any(vec is None for vec in args_):
return array_ops.concat(args, axis=0)
return [val for vec in args_ for val in vec]

tensorflow/contrib/distributions/python/ops/test_util.py

@@ -0,0 +1,378 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for testing distributions and/or bijectors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import histogram_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
__all__ = [
"DiscreteScalarDistributionTestHelpers",
"VectorDistributionTestHelpers",
]
class DiscreteScalarDistributionTestHelpers(object):
"""DiscreteScalarDistributionTestHelpers."""
def run_test_sample_consistent_log_prob(
self, sess, dist,
num_samples=int(1e5), num_threshold=int(1e3), seed=42,
rtol=1e-2, atol=0.):
"""Tests that sample/log_prob are consistent with each other.
"Consistency" means that `sample` and `log_prob` correspond to the same
distribution.
Note: this test only verifies a necessary condition for consistency--it does
not verify sufficiency and hence does not prove that `sample` and `log_prob` truly
are consistent.
Args:
sess: Tensorflow session.
dist: Distribution instance or object which implements `sample`,
`log_prob`, `event_shape_tensor` and `batch_shape_tensor`.
num_samples: Python `int` scalar indicating the number of Monte-Carlo
samples to draw from `dist`.
num_threshold: Python `int` scalar indicating the number of samples a
bucket must contain before being compared to the probability.
Default value: 1e3; must be at least 1.
Warning: setting this too high will cause the test to falsely pass, while
setting it too low will cause it to falsely fail.
seed: Python `int` indicating the seed to use when sampling from `dist`.
In general it is not recommended to use `None` during a test as this
increases the likelihood of spurious test failure.
rtol: Python `float`-type indicating the admissible relative error between
analytical and sample statistics.
atol: Python `float`-type indicating the admissible absolute error between
analytical and sample statistics.
Raises:
ValueError: if `num_threshold < 1`.
"""
if num_threshold < 1:
raise ValueError("num_threshold({}) must be at least 1.".format(
num_threshold))
# Histogram only supports vectors so we call it once per batch coordinate.
y = dist.sample(num_samples, seed=seed)
y = array_ops.reshape(y, shape=[num_samples, -1])
batch_size = math_ops.reduce_prod(dist.batch_shape_tensor())
batch_dims = array_ops.shape(dist.batch_shape_tensor())[0]
edges_expanded_shape = 1 + array_ops.pad([-2], paddings=[[0, batch_dims]])
for b, x in enumerate(array_ops.unstack(y, axis=1)):
counts, edges = self.histogram(x)
edges = array_ops.reshape(edges, edges_expanded_shape)
probs = math_ops.exp(dist.log_prob(edges))
probs = array_ops.reshape(probs, shape=[-1, batch_size])[:, b]
[counts_, probs_] = sess.run([counts, probs])
valid = counts_ > num_threshold
probs_ = probs_[valid]
counts_ = counts_[valid]
self.assertAllClose(probs_, counts_ / num_samples,
rtol=rtol, atol=atol)
def run_test_sample_consistent_mean_variance(
self, sess, dist,
num_samples=int(1e5), seed=24,
rtol=1e-2, atol=0.):
"""Tests that sample/mean/variance are consistent with each other.
"Consistency" means that `sample`, `mean`, `variance`, etc all correspond
to the same distribution.
Args:
sess: Tensorflow session.
dist: Distribution instance or object which implements `sample`,
`log_prob`, `event_shape_tensor` and `batch_shape_tensor`.
num_samples: Python `int` scalar indicating the number of Monte-Carlo
samples to draw from `dist`.
seed: Python `int` indicating the seed to use when sampling from `dist`.
In general it is not recommended to use `None` during a test as this
increases the likelihood of spurious test failure.
rtol: Python `float`-type indicating the admissible relative error between
analytical and sample statistics.
atol: Python `float`-type indicating the admissible absolute error between
analytical and sample statistics.
"""
x = math_ops.to_float(dist.sample(num_samples, seed=seed))
sample_mean = math_ops.reduce_mean(x, axis=0)
sample_variance = math_ops.reduce_mean(
math_ops.square(x - sample_mean), axis=0)
sample_stddev = math_ops.sqrt(sample_variance)
[
sample_mean_,
sample_variance_,
sample_stddev_,
mean_,
variance_,
stddev_
] = sess.run([
sample_mean,
sample_variance,
sample_stddev,
dist.mean(),
dist.variance(),
dist.stddev(),
])
self.assertAllClose(mean_, sample_mean_, rtol=rtol, atol=atol)
self.assertAllClose(variance_, sample_variance_, rtol=rtol, atol=atol)
self.assertAllClose(stddev_, sample_stddev_, rtol=rtol, atol=atol)
def histogram(self, x, value_range=None, nbins=None, name=None):
"""Return histogram of values.
Given the tensor `x`, this operation returns a rank-1 histogram
counting the number of entries in `x` that fall into every bin. The
bins are equal width and determined by the arguments `value_range` and
`nbins`.
Args:
x: 1D numeric `Tensor` of items to count.
value_range: Shape [2] `Tensor`. `x <= value_range[0]` will be
mapped to `counts[0]`, `x >= value_range[1]` will be mapped to
`counts[-1]`. Must be same dtype as `x`.
nbins: Scalar `int32 Tensor`. Number of histogram bins.
name: Python `str` name prefixed to Ops created by this class.
Returns:
counts: 1D `Tensor` of counts, i.e.,
`counts[i] = sum{ edges[i] <= x[j] < edges[i] + delta : j }` where
`delta` is the bin width and `edges` holds the left bin edges.
edges: 1D `Tensor` characterizing intervals used for counting.
"""
with ops.name_scope(name, "histogram", [x]):
x = ops.convert_to_tensor(x, name="x")
if value_range is None:
value_range = [math_ops.reduce_min(x), 1 + math_ops.reduce_max(x)]
value_range = ops.convert_to_tensor(value_range, name="value_range")
lo = value_range[0]
hi = value_range[1]
if nbins is None:
nbins = math_ops.to_int32(hi - lo)
delta = (hi - lo) / math_ops.cast(
nbins, dtype=value_range.dtype.base_dtype)
edges = math_ops.range(
start=lo, limit=hi, delta=delta, dtype=x.dtype.base_dtype)
counts = histogram_ops.histogram_fixed_width(
x, value_range=value_range, nbins=nbins)
return counts, edges
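
The bucket-versus-pmf comparison performed by `run_test_sample_consistent_log_prob` can be summarized in NumPy terms roughly as follows (SciPy is assumed here only to provide a reference pmf):

```python
import numpy as np
from scipy import stats  # assumed; reference pmf only

n = 100000
samples = np.random.poisson(lam=3., size=n)      # stand-in for dist.sample
lo, hi = samples.min(), samples.max() + 1
edges = np.arange(lo, hi)                        # integer left edges
counts, _ = np.histogram(samples, bins=np.arange(lo, hi + 1))

probs = stats.poisson.pmf(edges, 3.)
keep = counts > 1000                             # only well-populated buckets
np.testing.assert_allclose(counts[keep] / float(n), probs[keep], rtol=0.1)
```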
class VectorDistributionTestHelpers(object):
"""VectorDistributionTestHelpers helps test vector-event distributions."""
def run_test_sample_consistent_log_prob(
self,
sess,
dist,
num_samples=int(1e5),
radius=1.,
center=0.,
seed=42,
rtol=1e-2,
atol=0.):
"""Tests that sample/log_prob are mutually consistent.
"Consistency" means that `sample` and `log_prob` correspond to the same
distribution.
The idea of this test is to compute the Monte-Carlo estimate of the volume
enclosed by a hypersphere, i.e., the volume of an `n`-ball. While we could
choose an arbitrary function to integrate, the hypersphere's volume is nice
because it is intuitive, has an easy analytical expression, and works for
`dimensions > 1`.
Technical Details:
Observe that:
```none
int_{R**d} dx [x in Ball(radius=r, center=c)]
= E_{p(X)}[ [X in Ball(r, c)] / p(X) ]
= lim_{m->infty} m**-1 sum_j^m [x[j] in Ball(r, c)] / p(x[j]),
where x[j] ~iid p(X)
```
Thus, for fixed `m`, the above is approximately true when `sample` and
`log_prob` are mutually consistent.
Furthermore, the above calculation has the analytical result:
`pi**(d/2) r**d / Gamma(1 + d/2)`.
Note: this test only verifies a necessary condition for consistency--it does
not verify sufficiency and hence does not prove that `sample` and `log_prob` truly
are consistent. For this reason we recommend testing several different
hyperspheres (assuming the hypersphere is supported by the distribution).
Furthermore, we gain additional trust in this test when `sample` is also
tested against the first and second moments
(`run_test_sample_consistent_mean_covariance`); it is unlikely that
a "best-effort" implementation of `log_prob` would incorrectly pass both
tests for different hyperspheres.
For a discussion on the analytical result (second-line) see:
https://en.wikipedia.org/wiki/Volume_of_an_n-ball.
For a discussion of importance sampling (fourth-line) see:
https://en.wikipedia.org/wiki/Importance_sampling.
Args:
sess: Tensorflow session.
dist: Distribution instance or object which implements `sample`,
`log_prob`, `event_shape_tensor` and `batch_shape_tensor`. The
distribution must have non-zero probability of sampling every point
enclosed by the hypersphere.
num_samples: Python `int` scalar indicating the number of Monte-Carlo
samples to draw from `dist`.
radius: Python `float`-type indicating the radius of the `n`-ball whose
volume we're computing.
center: Python floating-type vector (or scalar) indicating the center of
the `n`-ball whose volume we're computing. When scalar, the value is
broadcast to all event dims.
seed: Python `int` indicating the seed to use when sampling from `dist`.
In general it is not recommended to use `None` during a test as this
increases the likelihood of spurious test failure.
rtol: Python `float`-type indicating the admissible relative error between
actual- and approximate-volumes.
atol: Python `float`-type indicating the admissible absolute error between
actual- and approximate-volumes. In general this should be zero since
a typical radius implies a non-zero volume.
"""
def actual_hypersphere_volume(dims, radius):
# https://en.wikipedia.org/wiki/Volume_of_an_n-ball
# Using tf.lgamma because otherwise we'd have to use SciPy, which is not
# a required dependency of core.
radius = np.asarray(radius)
dims = math_ops.cast(dims, dtype=radius.dtype)
return math_ops.exp(
(dims / 2.) * np.log(np.pi)
- math_ops.lgamma(1. + dims / 2.)
+ dims * math_ops.log(radius))
def is_in_ball(x, radius, center):
return math_ops.cast(linalg_ops.norm(x - center, axis=-1) <= radius,
dtype=x.dtype)
def monte_carlo_hypersphere_volume(dist, num_samples, radius, center):
# https://en.wikipedia.org/wiki/Importance_sampling
x = dist.sample(num_samples, seed=seed)
return math_ops.reduce_mean(
math_ops.exp(-dist.log_prob(x)) * is_in_ball(x, radius, center),
axis=0)
[
batch_shape_,
actual_volume_,
sample_volume_,
] = sess.run([
dist.batch_shape_tensor(),
actual_hypersphere_volume(
dims=dist.event_shape_tensor()[0],
radius=radius),
monte_carlo_hypersphere_volume(
dist,
num_samples=num_samples,
radius=radius,
center=center),
])
self.assertAllClose(np.tile(actual_volume_, reps=batch_shape_),
sample_volume_,
rtol=rtol, atol=atol)
def run_test_sample_consistent_mean_covariance(
self,
sess,
dist,
num_samples=int(1e5),
seed=24,
rtol=1e-2,
atol=0.,
cov_rtol=None,
cov_atol=None):
"""Tests that sample/mean/covariance are consistent with each other.
"Consistency" means that `sample`, `mean`, `covariance`, etc all correspond
to the same distribution.
Args:
sess: Tensorflow session.
dist: Distribution instance or object which implements `sample`,
`log_prob`, `event_shape_tensor` and `batch_shape_tensor`.
num_samples: Python `int` scalar indicating the number of Monte-Carlo
samples to draw from `dist`.
seed: Python `int` indicating the seed to use when sampling from `dist`.
In general it is not recommended to use `None` during a test as this
increases the likelihood of spurious test failure.
rtol: Python `float`-type indicating the admissible relative error between
analytical and sample statistics.
atol: Python `float`-type indicating the admissible absolute error between
analytical and sample statistics.
cov_rtol: Python `float`-type indicating the admissible relative error
between analytical and sample covariance. Default: rtol.
cov_atol: Python `float`-type indicating the admissible absolute error
between analytical and sample covariance. Default: atol.
"""
x = dist.sample(num_samples, seed=seed)
sample_mean = math_ops.reduce_mean(x, axis=0)
sample_covariance = math_ops.reduce_mean(
_vec_outer_square(x - sample_mean), axis=0)
sample_variance = array_ops.matrix_diag_part(sample_covariance)
sample_stddev = math_ops.sqrt(sample_variance)
[
sample_mean_,
sample_covariance_,
sample_variance_,
sample_stddev_,
mean_,
covariance_,
variance_,
stddev_
] = sess.run([
sample_mean,
sample_covariance,
sample_variance,
sample_stddev,
dist.mean(),
dist.covariance(),
dist.variance(),
dist.stddev(),
])
self.assertAllClose(mean_, sample_mean_, rtol=rtol, atol=atol)
self.assertAllClose(covariance_, sample_covariance_,
rtol=cov_rtol or rtol,
atol=cov_atol or atol)
self.assertAllClose(variance_, sample_variance_, rtol=rtol, atol=atol)
self.assertAllClose(stddev_, sample_stddev_, rtol=rtol, atol=atol)
def _vec_outer_square(x, name=None):
"""Computes the outer-product of a vector, i.e., x.T x."""
with ops.name_scope(name, "vec_osquare", [x]):
return x[..., :, array_ops.newaxis] * x[..., array_ops.newaxis, :]
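
As a worked instance of the hypersphere-volume check above: for a standard bivariate Normal, the importance-sampling estimate should recover `pi * r**2`. A NumPy/SciPy sketch (any density with support on the ball would do):

```python
import numpy as np
from scipy import stats  # assumed; provides the reference log-density

d, r, n = 2, 1., 100000
x = np.random.randn(n, d)                        # draws from p(X) = N(0, I)
log_p = stats.multivariate_normal(mean=np.zeros(d)).logpdf(x)
in_ball = (np.linalg.norm(x, axis=-1) <= r).astype(np.float64)

volume_estimate = np.mean(np.exp(-log_p) * in_ball)
print(volume_estimate, np.pi * r**2)             # both ~3.14
```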

tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py

@@ -221,7 +221,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
quadrature_polynomial_degree=8,
validate_args=False,
allow_nan_stats=True,
name="VectorLocationScaleDiffeomixture"):
name="VectorDiffeomixture"):
"""Constructs the VectorDiffeomixture on `R**k`.
Args:
@@ -387,24 +387,34 @@ class VectorDiffeomixture(distribution_lib.Distribution):
@property
def mixture_distribution(self):
"""Distribution used to select a convex combination of affine transforms."""
return self._mixture_distribution
@property
def distribution(self):
"""Base scalar-event, scalar-batch distribution."""
return self._distribution
@property
def interpolate_weight(self):
"""Grid of mixing probabilities, one for each grid point."""
return self._interpolate_weight
@property
def endpoint_affine(self):
"""Affine transformation for each of `K` components."""
return self._endpoint_affine
@property
def interpolated_affine(self):
"""Affine transformation for each convex combination of `K` components."""
return self._interpolated_affine
@property
def quadrature_polynomial_degree(self):
"""Polynomial largest exponent used for Gauss-Hermite quadrature."""
return self._degree
def _batch_shape_tensor(self):
return self._batch_shape_
@@ -457,12 +467,12 @@ class VectorDiffeomixture(distribution_lib.Distribution):
# Alternatively:
# x = weight * x[0] + (1. - weight) * x[1]
x = weight * (x[0] - x[1]) + array_ops.ones_like(x[0]) * x[1]
x = weight * (x[0] - x[1]) + x[1]
return x
def _log_prob(self, x):
# By convention, we always put the the grid points right-most.
# By convention, we always put the grid points right-most.
y = array_ops.stack(
[aff.inverse(x) for aff in self.interpolated_affine],
axis=-1)
@@ -740,8 +750,7 @@ def interpolate_loc(deg, interpolate_weight, loc):
x = interpolate_weight[..., array_ops.newaxis] * loc[0]
else:
delta = loc[0] - loc[1]
offset = array_ops.ones_like(loc[0]) * loc[1]
x = interpolate_weight[..., array_ops.newaxis] * delta + offset
x = interpolate_weight[..., array_ops.newaxis] * delta + loc[1]
return [x[..., k, :] for k in range(deg)]
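
The simplification in this hunk and the one above relies on the identity `w*a + (1 - w)*b == w*(a - b) + b`, which avoids materializing `ones_like(...)`; a one-line NumPy check:

```python
import numpy as np

w, a, b = np.random.rand(3), np.random.rand(3), np.random.rand(3)
np.testing.assert_allclose(w * a + (1. - w) * b, w * (a - b) + b)
```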