Add Poisson-LogNormal (approximate) compound distribution.
PiperOrigin-RevId: 163480957
commit 4f60ddb470 (parent 7635e9db10)
tensorflow/contrib/distributions/BUILD

@@ -44,6 +44,7 @@ py_library(
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:data_flow_ops",
         "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:histogram_ops",
         "//tensorflow/python:init_ops",
         "//tensorflow/python:linalg_ops",
         "//tensorflow/python:math_ops",
@@ -327,6 +328,17 @@ cuda_py_test(
     ],
 )
 
+cuda_py_test(
+    name = "poisson_lognormal_test",
+    size = "small",
+    srcs = ["python/kernel_tests/poisson_lognormal_test.py"],
+    additional_deps = [
+        ":distributions_py",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
 cuda_py_test(
     name = "sample_stats_test",
     size = "small",
@@ -518,19 +530,13 @@ cuda_py_test(
 
 cuda_py_test(
     name = "vector_diffeomixture_test",
-    size = "large",
+    size = "small",
     srcs = ["python/kernel_tests/vector_diffeomixture_test.py"],
     additional_deps = [
         ":bijectors_py",
         ":distributions_py",
         "//third_party/py/numpy",
         "//tensorflow/contrib/linalg:linalg_py",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:nn_ops",
         "//tensorflow/python:platform_test",
     ],
 )
tensorflow/contrib/distributions/__init__.py

@@ -44,6 +44,7 @@ from tensorflow.contrib.distributions.python.ops.negative_binomial import *
 from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import *
 from tensorflow.contrib.distributions.python.ops.onehot_categorical import *
 from tensorflow.contrib.distributions.python.ops.poisson import *
+from tensorflow.contrib.distributions.python.ops.poisson_lognormal import *
 from tensorflow.contrib.distributions.python.ops.quantized_distribution import *
 from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import *
 from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import *
@@ -117,6 +118,7 @@ _allowed_symbols = [
     'Normal',
     'NormalWithSoftplusScale',
     'Poisson',
+    'PoissonLogNormalQuadratureCompound',
     'StudentT',
     'StudentTWithAbsDfSoftplusScale',
     'Uniform',
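Once exported this way, the class is reachable from the contrib namespace. A minimal usage sketch (TF 1.x graph-mode API; the argument values are taken from the class docstring further down, not from this hunk):

```python
import tensorflow as tf

ds = tf.contrib.distributions
pln = ds.PoissonLogNormalQuadratureCompound(
    loc=[0., -0.5], scale=1., quadrature_polynomial_degree=10)

with tf.Session() as sess:
  # sample() broadcasts over the batch of two `loc`s; log_prob broadcasts x.
  samples, log_prob = sess.run(
      [pln.sample(4, seed=42), pln.log_prob([1., 2.])])
```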
tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py (new file)

@@ -0,0 +1,92 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for PoissonLogNormalQuadratureCompound."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.distributions.python.ops import poisson_lognormal
from tensorflow.contrib.distributions.python.ops import test_util
from tensorflow.python.platform import test


class PoissonLogNormalQuadratureCompoundTest(
    test_util.DiscreteScalarDistributionTestHelpers, test.TestCase):
  """Tests the PoissonLogNormalQuadratureCompound distribution."""

  def testSampleProbConsistent(self):
    with self.test_session() as sess:
      pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
          loc=-2.,
          scale=1.1,
          quadrature_polynomial_degree=10,
          validate_args=True)
      self.run_test_sample_consistent_log_prob(
          sess, pln, rtol=0.1)

  def testMeanVariance(self):
    with self.test_session() as sess:
      pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
          loc=0.,
          scale=1.,
          quadrature_polynomial_degree=10,
          validate_args=True)
      self.run_test_sample_consistent_mean_variance(
          sess, pln, rtol=0.02)

  def testSampleProbConsistentBroadcastScalar(self):
    with self.test_session() as sess:
      pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
          loc=[0., -0.5],
          scale=1.,
          quadrature_polynomial_degree=10,
          validate_args=True)
      self.run_test_sample_consistent_log_prob(
          sess, pln, rtol=0.1, atol=0.01)

  def testMeanVarianceBroadcastScalar(self):
    with self.test_session() as sess:
      pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
          loc=[0., -0.5],
          scale=1.,
          quadrature_polynomial_degree=10,
          validate_args=True)
      self.run_test_sample_consistent_mean_variance(
          sess, pln, rtol=0.1, atol=0.01)

  def testSampleProbConsistentBroadcastBoth(self):
    with self.test_session() as sess:
      pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
          loc=[[0.], [-0.5]],
          scale=[[1., 0.9]],
          quadrature_polynomial_degree=10,
          validate_args=True)
      self.run_test_sample_consistent_log_prob(
          sess, pln, rtol=0.1, atol=0.08)

  def testMeanVarianceBroadcastBoth(self):
    with self.test_session() as sess:
      pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
          loc=[[0.], [-0.5]],
          scale=[[1., 0.9]],
          quadrature_polynomial_degree=10,
          validate_args=True)
      self.run_test_sample_consistent_mean_variance(
          sess, pln, rtol=0.1, atol=0.01)


if __name__ == "__main__":
  test.main()
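As context for what `run_test_sample_consistent_log_prob` asserts: it histograms draws from the distribution and compares each well-populated bucket's empirical frequency against `exp(log_prob)`. A self-contained NumPy sketch of the same idea, using an exact Poisson as a stand-in (nothing below is part of the diff):

```python
import math
import numpy as np

# Histogram the draws, then compare each bucket's empirical frequency
# against the distribution's pmf at that bucket's value.
rng = np.random.RandomState(42)
lam, n = 3., 100000
samples = rng.poisson(lam, size=n)

counts = np.bincount(samples)
values = np.arange(counts.size)
log_pmf = (values * np.log(lam) - lam
           - np.array([math.lgamma(k + 1.) for k in values]))

valid = counts > 1000  # cf. num_threshold: skip sparsely-populated buckets
np.testing.assert_allclose(
    np.exp(log_pmf)[valid], counts[valid] / float(n), rtol=0.1)
```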
tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py

@@ -20,236 +20,16 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.contrib.distributions.python.ops import test_util
 from tensorflow.contrib.distributions.python.ops import vector_diffeomixture as vector_diffeomixture_lib
 from tensorflow.contrib.linalg.python.ops import linear_operator_diag as linop_diag_lib
 from tensorflow.contrib.linalg.python.ops import linear_operator_identity as linop_identity_lib
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.distributions import normal as normal_lib
 from tensorflow.python.platform import test
 
 
-class VectorDistributionTestHelpers(object):
-  """VectorDistributionTestHelpers helps test vector-event distributions."""
-
-  def linop(self, num_rows=None, multiplier=None, diag=None):
-    """Helper to create non-singular, symmetric, positive definite matrices."""
-    if num_rows is not None and multiplier is not None:
-      if any(p is not None for p in [diag]):
-        raise ValueError("Found extra args for scaled identity.")
-      return linop_identity_lib.LinearOperatorScaledIdentity(
-          num_rows=num_rows,
-          multiplier=multiplier,
-          is_positive_definite=True)
-    elif num_rows is not None:
-      if any(p is not None for p in [multiplier, diag]):
-        raise ValueError("Found extra args for identity.")
-      return linop_identity_lib.LinearOperatorIdentity(
-          num_rows=num_rows,
-          is_positive_definite=True)
-    elif diag is not None:
-      if any(p is not None for p in [num_rows, multiplier]):
-        raise ValueError("Found extra args for diag.")
-      return linop_diag_lib.LinearOperatorDiag(
-          diag=diag,
-          is_positive_definite=True)
-    else:
-      raise ValueError("Must specify at least one arg.")
-
 [The removed run_test_sample_consistent_log_prob and
 run_test_sample_consistent_mean_covariance methods are elided here; they
 reappear essentially verbatim in the new test_util.py below.]
 
-class VectorDiffeomixtureTest(VectorDistributionTestHelpers, test.TestCase):
+class VectorDiffeomixtureTest(
+    test_util.VectorDistributionTestHelpers, test.TestCase):
   """Tests the VectorDiffeomixture distribution."""
 
   def testSampleProbConsistentBroadcastMix(self):

tensorflow/contrib/distributions/python/ops/poisson_lognormal.py (new file, 313 lines)

@@ -0,0 +1,313 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The PoissonLogNormalQuadratureCompound distribution class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import categorical as categorical_lib
from tensorflow.python.ops.distributions import distribution as distribution_lib
from tensorflow.python.ops.distributions import util as distribution_util


__all__ = [
    "PoissonLogNormalQuadratureCompound",
]


class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
  """`PoissonLogNormalQuadratureCompound` distribution.

  The `PoissonLogNormalQuadratureCompound` is an approximation to a
  Poisson-LogNormal [compound distribution](
  https://en.wikipedia.org/wiki/Compound_probability_distribution), i.e.,

  ```none
  p(k|loc, scale)
  = int_{R_+} dl LogNormal(l | loc, scale) Poisson(k | l)
  = int_{R} dz ((lambda(z) sqrt(2) scale)
                * exp(-z**2) / (lambda(z) sqrt(2 pi) scale)
                * Poisson(k | lambda(z)))
  = int_{R} dz exp(-z**2) / sqrt(pi) Poisson(k | lambda(z))
  approx= sum{ prob[d] Poisson(k | lambda(grid[d])) : d=0, ..., deg-1 }
  ```

  where `lambda(z) = exp(sqrt(2) scale z + loc)` and the `prob,grid` terms
  are from [Gauss--Hermite quadrature](
  https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature). Note that
  the second line made the substitution
  `z(l) = (log(l) - loc) / (sqrt(2) scale)`, which implies `lambda(z)`
  [above] and `dl = sqrt(2) scale lambda(z) dz`.

  In the non-approximation case, a draw from the LogNormal prior represents
  the Poisson rate parameter. Unfortunately, the non-approximate distribution
  lacks an analytical probability density function (pdf). Therefore the
  `PoissonLogNormalQuadratureCompound` class implements an approximation
  based on [Gauss-Hermite quadrature](
  https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature).
  Note: although the `PoissonLogNormalQuadratureCompound` is approximately
  the Poisson-LogNormal compound distribution, it is itself a valid
  distribution. Viz., it possesses a `sample`, `log_prob`, `mean`,
  `variance`, etc. which are all mutually consistent.

  #### Mathematical Details

  The `PoissonLogNormalQuadratureCompound` approximates a Poisson-LogNormal
  [compound distribution](
  https://en.wikipedia.org/wiki/Compound_probability_distribution). Using
  variable-substitution and [Gauss-Hermite quadrature](
  https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature) we can
  redefine the distribution to be a parameter-less convex combination of
  `deg` different Poisson samples.

  That is, defined over positive integers, this distribution is parameterized
  by a (batch of) `loc` and `scale` scalars.

  The probability density function (pdf) is,

  ```none
  pdf(k | loc, scale, deg)
  = sum{ prob[d] Poisson(k | lambda=exp(sqrt(2) scale grid[d] + loc))
         : d=0, ..., deg-1 }
  ```

  where, [`grid, w = numpy.polynomial.hermite.hermgauss(deg)`](
  https://docs.scipy.org/doc/numpy-1.10.0/reference/generated/numpy.polynomial.hermite.hermgauss.html)
  and `prob = w / sqrt(pi)`.

  #### Examples

  ```python
  ds = tf.contrib.distributions
  # Create two batches of PoissonLogNormalQuadratureCompounds, one with
  # prior `loc = 0.` and another with `loc = -0.5`. In both cases
  # `scale = 1.`
  pln = ds.PoissonLogNormalQuadratureCompound(
      loc=[0., -0.5],
      scale=1.,
      quadrature_polynomial_degree=10,
      validate_args=True)
  ```
  """

  def __init__(self,
               loc,
               scale,
               quadrature_polynomial_degree=8,
               validate_args=False,
               allow_nan_stats=True,
               name="PoissonLogNormalQuadratureCompound"):
    """Constructs the PoissonLogNormalQuadratureCompound distribution.

    Args:
      loc: `float`-like (batch of) scalar `Tensor`; the location parameter
        of the LogNormal prior.
      scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of
        the LogNormal prior.
      quadrature_polynomial_degree: Python `int`-like scalar; the number of
        Gauss-Hermite quadrature grid points.
        Default value: 8.
      validate_args: Python `bool`, default `False`. When `True`
        distribution parameters are checked for validity despite possibly
        degrading runtime performance. When `False` invalid inputs may
        silently render incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is
        raised if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `loc.dtype != scale.dtype`.
    """
    parameters = locals()
    with ops.name_scope(name, values=[loc, scale]):
      loc = ops.convert_to_tensor(loc, name="loc")
      self._loc = loc

      scale = ops.convert_to_tensor(scale, name="scale")
      self._scale = scale

      dtype = loc.dtype.base_dtype
      if dtype != scale.dtype.base_dtype:
        raise TypeError(
            "loc.dtype(\"{}\") does not match scale.dtype(\"{}\")".format(
                loc.dtype.name, scale.dtype.name))

      self._degree = quadrature_polynomial_degree

      grid, prob = np.polynomial.hermite.hermgauss(
          deg=quadrature_polynomial_degree)

      # It should be that `sum(prob) == sqrt(pi)`, but self-normalization is
      # more numerically stable.
      prob = prob.astype(dtype.as_numpy_dtype)
      prob /= np.linalg.norm(prob, ord=1)

      self._mixture_distribution = categorical_lib.Categorical(
          logits=np.log(prob),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats)

      # The following maps the broadcast of `loc` and `scale` to each grid
      # point, i.e., we are creating several log-rates that correspond to
      # the different Gauss-Hermite quadrature points and (possible) batches
      # of `loc` and `scale`.
      self._log_rate = (loc[..., array_ops.newaxis]
                        + np.sqrt(2.) * scale[..., array_ops.newaxis] * grid)

      self._distribution = poisson_lib.Poisson(
          rate=math_ops.exp(self._log_rate, name="rate"),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats)

      super(PoissonLogNormalQuadratureCompound, self).__init__(
          dtype=dtype,
          reparameterization_type=distribution_lib.NOT_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          graph_parents=[loc, scale],
          name=name)

  @property
  def mixture_distribution(self):
    """Distribution which randomly selects a Poisson with Gauss-Hermite rate."""
    return self._mixture_distribution

  @property
  def distribution(self):
    """Base Poisson parameterized by a Gauss-Hermite grid of rates."""
    return self._distribution

  @property
  def loc(self):
    """Location parameter of the LogNormal prior."""
    return self._loc

  @property
  def scale(self):
    """Scale parameter of the LogNormal prior."""
    return self._scale

  @property
  def quadrature_polynomial_degree(self):
    """Degree of the Hermite polynomial used for Gauss-Hermite quadrature."""
    return self._degree

  def _batch_shape_tensor(self):
    return array_ops.broadcast_dynamic_shape(
        array_ops.shape(self.loc),
        array_ops.shape(self.scale))

  def _batch_shape(self):
    return array_ops.broadcast_static_shape(
        self.loc.shape,
        self.scale.shape)

  def _event_shape(self):
    return tensor_shape.scalar()

  def _sample_n(self, n, seed=None):
    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] in
    # which case get ids as a [n]-shaped vector.
    batch_size = (np.prod(self.batch_shape.as_list(), dtype=np.int32)
                  if self.batch_shape.is_fully_defined()
                  else math_ops.reduce_prod(self.batch_shape_tensor()))
    ids = self._mixture_distribution.sample(
        sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                self.is_scalar_batch(),
                np.int32([]),
                [batch_size])),
        seed=distribution_util.gen_new_seed(
            seed, "poisson_lognormal_quadrature_compound"))
    # Stride by `quadrature_polynomial_degree` so each batch member's ids
    # index into its own block of the flattened [batch_size, degree] rate
    # table.
    offset = math_ops.range(start=0,
                            limit=batch_size * self._degree,
                            delta=self._degree,
                            dtype=ids.dtype)
    ids += offset
    rate = array_ops.gather(
        array_ops.reshape(self.distribution.rate, shape=[-1]), ids)
    rate = array_ops.reshape(
        rate, shape=concat_vectors([n], self.batch_shape_tensor()))
    return random_ops.random_poisson(
        lam=rate, shape=[], dtype=self.dtype, seed=seed)

  def _log_prob(self, x):
    return math_ops.reduce_logsumexp(
        (self.mixture_distribution.logits
         + self.distribution.log_prob(x[..., array_ops.newaxis])),
        axis=-1)

  def _mean(self):
    return math_ops.exp(
        math_ops.reduce_logsumexp(
            self.mixture_distribution.logits + self._log_rate,
            axis=-1))

  def _variance(self):
    return math_ops.exp(self._log_variance())

  def _stddev(self):
    return math_ops.exp(0.5 * self._log_variance())

  def _log_variance(self):
    # The following calculation is based on the law of total variance:
    #
    # Var[Z] = E[Var[Z | V]] + Var[E[Z | V]]
    #
    # where,
    #
    # Z | v ~ Poisson(lambda(grid[v]))
    # V ~ mixture_distribution
    #
    # thus,
    #
    # E[Var[Z | V]] = sum{ prob[d] Var[d] : d=0, ..., deg-1 }
    # Var[E[Z | V]] = sum{ prob[d] (Mean[d] - Mean)**2 : d=0, ..., deg-1 }
    v = array_ops.stack([
        # log(self.distribution.variance()) = log(Var[d]) = log(rate[d])
        self._log_rate,
        # log((Mean[d] - Mean)**2)
        2. * math_ops.log(
            math_ops.abs(self.distribution.mean()
                         - self._mean()[..., array_ops.newaxis])),
    ], axis=-1)
    return math_ops.reduce_logsumexp(
        self.mixture_distribution.logits[..., array_ops.newaxis] + v,
        axis=[-2, -1])


def static_value(x):
  """Returns the static value of a `Tensor` or `None`."""
  return tensor_util.constant_value(ops.convert_to_tensor(x))


def concat_vectors(*args):
  """Concatenates input vectors, statically if possible."""
  args_ = [static_value(x) for x in args]
  if any(vec is None for vec in args_):
    return array_ops.concat(args, axis=0)
  return [val for vec in args_ for val in vec]
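To make the quadrature identity from the class docstring concrete, here is a hedged, NumPy-only sketch (the `quadrature_pmf` helper is illustrative, not part of this file) that evaluates the approximate pmf and compares it against a Monte-Carlo estimate of the exact compound pmf:

```python
from math import lgamma

import numpy as np

def quadrature_pmf(k, loc, scale, deg=10):
  """pmf(k) ~= sum_d prob[d] * Poisson(k | exp(sqrt(2)*scale*grid[d] + loc))."""
  grid, w = np.polynomial.hermite.hermgauss(deg)
  prob = w / np.sqrt(np.pi)  # equivalently, w / sum(w)
  lam = np.exp(np.sqrt(2.) * scale * grid + loc)
  return np.sum(prob * np.exp(k * np.log(lam) - lam - lgamma(k + 1.)))

# Monte-Carlo estimate of the exact compound pmf: draw LogNormal rates,
# then Poisson counts given those rates.
rng = np.random.RandomState(0)
loc, scale = -2., 1.1
samples = rng.poisson(np.exp(rng.normal(loc, scale, size=200000)))
for k in range(4):
  print(k, quadrature_pmf(k, loc, scale), np.mean(samples == k))
```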

tensorflow/contrib/distributions/python/ops/test_util.py (new file, 378 lines)

@@ -0,0 +1,378 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for testing distributions and/or bijectors."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import histogram_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops


__all__ = [
    "DiscreteScalarDistributionTestHelpers",
    "VectorDistributionTestHelpers",
]


class DiscreteScalarDistributionTestHelpers(object):
  """Helpers for testing discrete, scalar-event distributions."""

  def run_test_sample_consistent_log_prob(
      self, sess, dist,
      num_samples=int(1e5), num_threshold=int(1e3), seed=42,
      rtol=1e-2, atol=0.):
    """Tests that sample/log_prob are consistent with each other.

    "Consistency" means that `sample` and `log_prob` correspond to the same
    distribution.

    Note: this test only verifies a necessary condition for consistency--it
    does not verify sufficiency and hence does not prove that `sample` and
    `log_prob` truly are consistent.

    Args:
      sess: Tensorflow session.
      dist: Distribution instance or object which implements `sample`,
        `log_prob`, `event_shape_tensor` and `batch_shape_tensor`.
      num_samples: Python `int` scalar indicating the number of Monte-Carlo
        samples to draw from `dist`.
      num_threshold: Python `int` scalar indicating the number of samples a
        bucket must contain before being compared to the probability.
        Default value: 1e3; must be at least 1.
        Warning: setting this too high will cause the test to falsely pass,
        while setting it too low will cause the test to falsely fail.
      seed: Python `int` indicating the seed to use when sampling from
        `dist`. In general it is not recommended to use `None` during a test
        as this increases the likelihood of spurious test failure.
      rtol: Python `float`-type indicating the admissible relative error
        between analytical and sample statistics.
      atol: Python `float`-type indicating the admissible absolute error
        between analytical and sample statistics.

    Raises:
      ValueError: if `num_threshold < 1`.
    """
    if num_threshold < 1:
      raise ValueError("num_threshold({}) must be at least 1.".format(
          num_threshold))
    # Histogram only supports vectors so we call it once per batch coordinate.
    y = dist.sample(num_samples, seed=seed)
    y = array_ops.reshape(y, shape=[num_samples, -1])
    batch_size = math_ops.reduce_prod(dist.batch_shape_tensor())
    batch_dims = array_ops.shape(dist.batch_shape_tensor())[0]
    edges_expanded_shape = 1 + array_ops.pad([-2], paddings=[[0, batch_dims]])
    for b, x in enumerate(array_ops.unstack(y, axis=1)):
      counts, edges = self.histogram(x)
      edges = array_ops.reshape(edges, edges_expanded_shape)
      probs = math_ops.exp(dist.log_prob(edges))
      probs = array_ops.reshape(probs, shape=[-1, batch_size])[:, b]

      [counts_, probs_] = sess.run([counts, probs])
      valid = counts_ > num_threshold
      probs_ = probs_[valid]
      counts_ = counts_[valid]
      self.assertAllClose(probs_, counts_ / num_samples,
                          rtol=rtol, atol=atol)

  def run_test_sample_consistent_mean_variance(
      self, sess, dist,
      num_samples=int(1e5), seed=24,
      rtol=1e-2, atol=0.):
    """Tests that sample/mean/variance are consistent with each other.

    "Consistency" means that `sample`, `mean`, `variance`, etc all
    correspond to the same distribution.

    Args:
      sess: Tensorflow session.
      dist: Distribution instance or object which implements `sample`,
        `log_prob`, `event_shape_tensor` and `batch_shape_tensor`.
      num_samples: Python `int` scalar indicating the number of Monte-Carlo
        samples to draw from `dist`.
      seed: Python `int` indicating the seed to use when sampling from
        `dist`. In general it is not recommended to use `None` during a test
        as this increases the likelihood of spurious test failure.
      rtol: Python `float`-type indicating the admissible relative error
        between analytical and sample statistics.
      atol: Python `float`-type indicating the admissible absolute error
        between analytical and sample statistics.
    """
    x = math_ops.to_float(dist.sample(num_samples, seed=seed))
    sample_mean = math_ops.reduce_mean(x, axis=0)
    sample_variance = math_ops.reduce_mean(
        math_ops.square(x - sample_mean), axis=0)
    sample_stddev = math_ops.sqrt(sample_variance)

    [
        sample_mean_,
        sample_variance_,
        sample_stddev_,
        mean_,
        variance_,
        stddev_
    ] = sess.run([
        sample_mean,
        sample_variance,
        sample_stddev,
        dist.mean(),
        dist.variance(),
        dist.stddev(),
    ])

    self.assertAllClose(mean_, sample_mean_, rtol=rtol, atol=atol)
    self.assertAllClose(variance_, sample_variance_, rtol=rtol, atol=atol)
    self.assertAllClose(stddev_, sample_stddev_, rtol=rtol, atol=atol)

  def histogram(self, x, value_range=None, nbins=None, name=None):
    """Return histogram of values.

    Given the tensor `x`, this operation returns a rank 1 histogram counting
    the number of entries in `x` that fell into every bin. The bins are
    equal width and determined by the arguments `value_range` and `nbins`.

    Args:
      x: 1D numeric `Tensor` of items to count.
      value_range: Shape [2] `Tensor`. `values <= value_range[0]` will be
        mapped to `hist[0]`, `values >= value_range[1]` will be mapped to
        `hist[-1]`. Must be same dtype as `x`.
      nbins: Scalar `int32 Tensor`. Number of histogram bins.
      name: Python `str` name prefixed to Ops created by this class.

    Returns:
      counts: 1D `Tensor` of counts, i.e.,
        `counts[i] = sum{ edges[i-1] <= values[j] < edges[i] : j }`.
      edges: 1D `Tensor` characterizing intervals used for counting.
    """
    with ops.name_scope(name, "histogram", [x]):
      x = ops.convert_to_tensor(x, name="x")
      if value_range is None:
        value_range = [math_ops.reduce_min(x), 1 + math_ops.reduce_max(x)]
      value_range = ops.convert_to_tensor(value_range, name="value_range")
      lo = value_range[0]
      hi = value_range[1]
      if nbins is None:
        nbins = math_ops.to_int32(hi - lo)
      delta = (hi - lo) / math_ops.cast(
          nbins, dtype=value_range.dtype.base_dtype)
      edges = math_ops.range(
          start=lo, limit=hi, delta=delta, dtype=x.dtype.base_dtype)
      counts = histogram_ops.histogram_fixed_width(
          x, value_range=value_range, nbins=nbins)
      return counts, edges


class VectorDistributionTestHelpers(object):
  """VectorDistributionTestHelpers helps test vector-event distributions."""

  def run_test_sample_consistent_log_prob(
      self,
      sess,
      dist,
      num_samples=int(1e5),
      radius=1.,
      center=0.,
      seed=42,
      rtol=1e-2,
      atol=0.):
    """Tests that sample/log_prob are mutually consistent.

    "Consistency" means that `sample` and `log_prob` correspond to the same
    distribution.

    The idea of this test is to compute the Monte-Carlo estimate of the
    volume enclosed by a hypersphere, i.e., the volume of an `n`-ball. While
    we could choose an arbitrary function to integrate, the hypersphere's
    volume is nice because it is intuitive, has an easy analytical
    expression, and works for `dimensions > 1`.

    Technical Details:

    Observe that:

    ```none
    int_{R**d} dx [x in Ball(radius=r, center=c)]
    = E_{p(X)}[ [X in Ball(r, c)] / p(X) ]
    = lim_{m->infty} m**-1 sum_j^m [x[j] in Ball(r, c)] / p(x[j]),
        where x[j] ~iid p(X)
    ```

    Thus, for fixed `m`, the above is approximately true when `sample` and
    `log_prob` are mutually consistent.

    Furthermore, the above calculation has the analytical result:
    `pi**(d/2) r**d / Gamma(1 + d/2)`.

    Note: this test only verifies a necessary condition for consistency--it
    does not verify sufficiency and hence does not prove that `sample` and
    `log_prob` truly are consistent. For this reason we recommend testing
    several different hyperspheres (assuming the hypersphere is supported by
    the distribution). Furthermore, we gain additional trust in this test
    when we also test `sample` against the first and second moments
    (`run_test_sample_consistent_mean_covariance`); it is unlikely that a
    "best-effort" implementation of `log_prob` would incorrectly pass both
    tests and for different hyperspheres.

    For a discussion on the analytical result (second-line) see:
    https://en.wikipedia.org/wiki/Volume_of_an_n-ball.

    For a discussion of importance sampling (fourth-line) see:
    https://en.wikipedia.org/wiki/Importance_sampling.

    Args:
      sess: Tensorflow session.
      dist: Distribution instance or object which implements `sample`,
        `log_prob`, `event_shape_tensor` and `batch_shape_tensor`. The
        distribution must have non-zero probability of sampling every point
        enclosed by the hypersphere.
      num_samples: Python `int` scalar indicating the number of Monte-Carlo
        samples to draw from `dist`.
      radius: Python `float`-type indicating the radius of the `n`-ball
        whose volume we're computing.
      center: Python floating-type vector (or scalar) indicating the center
        of the `n`-ball whose volume we're computing. When scalar, the value
        is broadcast to all event dims.
      seed: Python `int` indicating the seed to use when sampling from
        `dist`. In general it is not recommended to use `None` during a test
        as this increases the likelihood of spurious test failure.
      rtol: Python `float`-type indicating the admissible relative error
        between actual- and approximate-volumes.
      atol: Python `float`-type indicating the admissible absolute error
        between actual- and approximate-volumes. In general this should be
        zero since a typical radius implies a non-zero volume.
    """

    def actual_hypersphere_volume(dims, radius):
      # https://en.wikipedia.org/wiki/Volume_of_an_n-ball
      # Using tf.lgamma because we'd have to otherwise use SciPy which is
      # not a required dependency of core.
      radius = np.asarray(radius)
      dims = math_ops.cast(dims, dtype=radius.dtype)
      return math_ops.exp(
          (dims / 2.) * np.log(np.pi)
          - math_ops.lgamma(1. + dims / 2.)
          + dims * math_ops.log(radius))

    def is_in_ball(x, radius, center):
      return math_ops.cast(linalg_ops.norm(x - center, axis=-1) <= radius,
                           dtype=x.dtype)

    def monte_carlo_hypersphere_volume(dist, num_samples, radius, center):
      # https://en.wikipedia.org/wiki/Importance_sampling
      x = dist.sample(num_samples, seed=seed)
      return math_ops.reduce_mean(
          math_ops.exp(-dist.log_prob(x)) * is_in_ball(x, radius, center),
          axis=0)

    [
        batch_shape_,
        actual_volume_,
        sample_volume_,
    ] = sess.run([
        dist.batch_shape_tensor(),
        actual_hypersphere_volume(
            dims=dist.event_shape_tensor()[0],
            radius=radius),
        monte_carlo_hypersphere_volume(
            dist,
            num_samples=num_samples,
            radius=radius,
            center=center),
    ])

    self.assertAllClose(np.tile(actual_volume_, reps=batch_shape_),
                        sample_volume_,
                        rtol=rtol, atol=atol)

  def run_test_sample_consistent_mean_covariance(
      self,
      sess,
      dist,
      num_samples=int(1e5),
      seed=24,
      rtol=1e-2,
      atol=0.,
      cov_rtol=None,
      cov_atol=None):
    """Tests that sample/mean/covariance are consistent with each other.

    "Consistency" means that `sample`, `mean`, `covariance`, etc all
    correspond to the same distribution.

    Args:
      sess: Tensorflow session.
      dist: Distribution instance or object which implements `sample`,
        `log_prob`, `event_shape_tensor` and `batch_shape_tensor`.
      num_samples: Python `int` scalar indicating the number of Monte-Carlo
        samples to draw from `dist`.
      seed: Python `int` indicating the seed to use when sampling from
        `dist`. In general it is not recommended to use `None` during a test
        as this increases the likelihood of spurious test failure.
      rtol: Python `float`-type indicating the admissible relative error
        between analytical and sample statistics.
      atol: Python `float`-type indicating the admissible absolute error
        between analytical and sample statistics.
      cov_rtol: Python `float`-type indicating the admissible relative error
        between analytical and sample covariance. Default: rtol.
      cov_atol: Python `float`-type indicating the admissible absolute error
        between analytical and sample covariance. Default: atol.
    """

    x = dist.sample(num_samples, seed=seed)
    sample_mean = math_ops.reduce_mean(x, axis=0)
    sample_covariance = math_ops.reduce_mean(
        _vec_outer_square(x - sample_mean), axis=0)
    sample_variance = array_ops.matrix_diag_part(sample_covariance)
    sample_stddev = math_ops.sqrt(sample_variance)

    [
        sample_mean_,
        sample_covariance_,
        sample_variance_,
        sample_stddev_,
        mean_,
        covariance_,
        variance_,
        stddev_
    ] = sess.run([
        sample_mean,
        sample_covariance,
        sample_variance,
        sample_stddev,
        dist.mean(),
        dist.covariance(),
        dist.variance(),
        dist.stddev(),
    ])

    self.assertAllClose(mean_, sample_mean_, rtol=rtol, atol=atol)
    self.assertAllClose(covariance_, sample_covariance_,
                        rtol=cov_rtol or rtol,
                        atol=cov_atol or atol)
    self.assertAllClose(variance_, sample_variance_, rtol=rtol, atol=atol)
    self.assertAllClose(stddev_, sample_stddev_, rtol=rtol, atol=atol)


def _vec_outer_square(x, name=None):
  """Computes the outer-square of a vector, i.e., x x^T."""
  with ops.name_scope(name, "vec_osquare", [x]):
    return x[..., :, array_ops.newaxis] * x[..., array_ops.newaxis, :]
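As a sanity check on the identity used by `run_test_sample_consistent_log_prob`, here is a hedged, NumPy-only sketch (not part of the diff) that estimates the area of the unit disk via importance sampling under a standard bivariate normal and compares it to the analytic `pi**(d/2) r**d / Gamma(1 + d/2)`:

```python
from math import lgamma

import numpy as np

# Estimate vol(Ball) as E[ 1{X in Ball} / p(X) ] with X ~ p, here a
# standard bivariate normal.
d, radius, n = 2, 1.0, 200000
rng = np.random.RandomState(42)
x = rng.standard_normal(size=(n, d))
log_p = -0.5 * np.sum(x**2, axis=-1) - 0.5 * d * np.log(2. * np.pi)
in_ball = np.linalg.norm(x, axis=-1) <= radius

mc_volume = np.mean(np.exp(-log_p) * in_ball)
true_volume = np.exp(0.5 * d * np.log(np.pi) - lgamma(1. + d / 2.)
                     + d * np.log(radius))  # pi for the unit disk
np.testing.assert_allclose(mc_volume, true_volume, rtol=0.05)
```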
tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py

@@ -221,7 +221,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
                quadrature_polynomial_degree=8,
                validate_args=False,
                allow_nan_stats=True,
-               name="VectorLocationScaleDiffeomixture"):
+               name="VectorDiffeomixture"):
     """Constructs the VectorDiffeomixture on `R**k`.
 
     Args:
@@ -387,24 +387,34 @@ class VectorDiffeomixture(distribution_lib.Distribution):
 
   @property
   def mixture_distribution(self):
     """Distribution used to select a convex combination of affine transforms."""
     return self._mixture_distribution
 
   @property
   def distribution(self):
     """Base scalar-event, scalar-batch distribution."""
     return self._distribution
 
   @property
   def interpolate_weight(self):
     """Grid of mixing probabilities, one for each grid point."""
     return self._interpolate_weight
 
   @property
   def endpoint_affine(self):
     """Affine transformation for each of `K` components."""
     return self._endpoint_affine
 
   @property
   def interpolated_affine(self):
     """Affine transformation for each convex combination of `K` components."""
     return self._interpolated_affine
 
   @property
   def quadrature_polynomial_degree(self):
     """Degree of the Hermite polynomial used for Gauss-Hermite quadrature."""
     return self._degree
 
   def _batch_shape_tensor(self):
     return self._batch_shape_
 
@@ -457,12 +467,12 @@ class VectorDiffeomixture(distribution_lib.Distribution):
 
     # Alternatively:
     # x = weight * x[0] + (1. - weight) * x[1]
-    x = weight * (x[0] - x[1]) + array_ops.ones_like(x[0]) * x[1]
+    x = weight * (x[0] - x[1]) + x[1]
 
     return x
 
   def _log_prob(self, x):
-    # By convention, we always put the the grid points right-most.
+    # By convention, we always put the grid points right-most.
     y = array_ops.stack(
         [aff.inverse(x) for aff in self.interpolated_affine],
         axis=-1)
@@ -740,8 +750,7 @@ def interpolate_loc(deg, interpolate_weight, loc):
       x = interpolate_weight[..., array_ops.newaxis] * loc[0]
     else:
       delta = loc[0] - loc[1]
-      offset = array_ops.ones_like(loc[0]) * loc[1]
-      x = interpolate_weight[..., array_ops.newaxis] * delta + offset
+      x = interpolate_weight[..., array_ops.newaxis] * delta + loc[1]
   return [x[..., k, :] for k in range(deg)]
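The two `ones_like` removals in this file rest on the same observation: broadcasting already aligns the operand shapes, so the explicit `ones_like(...) * y` is redundant, and the interpolation is the standard lerp identity. A quick NumPy illustration (not part of the diff):

```python
import numpy as np

w = np.random.rand(5, 1)                       # interpolation weight
a, b = np.random.randn(5, 3), np.random.randn(5, 3)

# w*a + (1 - w)*b rewritten as w*(a - b) + b; broadcasting handles the
# shape alignment, so multiplying b by ones_like(a) adds nothing.
np.testing.assert_allclose(w * a + (1. - w) * b, w * (a - b) + b)
```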