Make tf.contrib.distributions
quadrature family accept a Tensor
for
`quadrature_grid_and_probs` argument. PiperOrigin-RevId: 172950094
This commit is contained in:
parent
8ff33271ea
commit
4948379369
@ -111,7 +111,7 @@ class MaskedAutoregressiveFlowTest(test_util.VectorDistributionTestHelpers,
|
|||||||
event_shape=[dims],
|
event_shape=[dims],
|
||||||
validate_args=True)
|
validate_args=True)
|
||||||
self.run_test_sample_consistent_log_prob(
|
self.run_test_sample_consistent_log_prob(
|
||||||
sess=sess,
|
sess_run_fn=sess.run,
|
||||||
dist=dist,
|
dist=dist,
|
||||||
num_samples=int(1e5),
|
num_samples=int(1e5),
|
||||||
radius=1.,
|
radius=1.,
|
||||||
@ -130,7 +130,7 @@ class MaskedAutoregressiveFlowTest(test_util.VectorDistributionTestHelpers,
|
|||||||
event_shape=[dims],
|
event_shape=[dims],
|
||||||
validate_args=True)
|
validate_args=True)
|
||||||
self.run_test_sample_consistent_log_prob(
|
self.run_test_sample_consistent_log_prob(
|
||||||
sess=sess,
|
sess_run_fn=sess.run,
|
||||||
dist=dist,
|
dist=dist,
|
||||||
num_samples=int(1e5),
|
num_samples=int(1e5),
|
||||||
radius=1.,
|
radius=1.,
|
||||||
|
@ -23,7 +23,6 @@ import numpy as np
|
|||||||
|
|
||||||
from tensorflow.contrib.distributions.python.ops import independent as independent_lib
|
from tensorflow.contrib.distributions.python.ops import independent as independent_lib
|
||||||
from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib
|
from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib
|
||||||
from tensorflow.contrib.distributions.python.ops import test_util
|
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops.distributions import normal as normal_lib
|
from tensorflow.python.ops.distributions import normal as normal_lib
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -41,8 +40,7 @@ def try_import(name): # pylint: disable=invalid-name
|
|||||||
stats = try_import("scipy.stats")
|
stats = try_import("scipy.stats")
|
||||||
|
|
||||||
|
|
||||||
class ProductDistributionTest(
|
class ProductDistributionTest(test.TestCase):
|
||||||
test_util.VectorDistributionTestHelpers, test.TestCase):
|
|
||||||
|
|
||||||
def testSampleAndLogProbUnivariate(self):
|
def testSampleAndLogProbUnivariate(self):
|
||||||
loc = np.float32([-1., 1])
|
loc = np.float32([-1., 1])
|
||||||
|
@ -94,10 +94,10 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers,
|
|||||||
loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]))
|
loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]))
|
||||||
# Ball centered at component0's mean.
|
# Ball centered at component0's mean.
|
||||||
self.run_test_sample_consistent_log_prob(
|
self.run_test_sample_consistent_log_prob(
|
||||||
sess, gm, radius=1., center=[-1., 1], rtol=0.02)
|
sess.run, gm, radius=1., center=[-1., 1], rtol=0.02)
|
||||||
# Larger ball centered at component1's mean.
|
# Larger ball centered at component1's mean.
|
||||||
self.run_test_sample_consistent_log_prob(
|
self.run_test_sample_consistent_log_prob(
|
||||||
sess, gm, radius=1., center=[1., -1], rtol=0.02)
|
sess.run, gm, radius=1., center=[1., -1], rtol=0.02)
|
||||||
|
|
||||||
def testLogCdf(self):
|
def testLogCdf(self):
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
@ -122,7 +122,7 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers,
|
|||||||
mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
|
mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]),
|
||||||
components_distribution=mvn_diag_lib.MultivariateNormalDiag(
|
components_distribution=mvn_diag_lib.MultivariateNormalDiag(
|
||||||
loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]))
|
loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]))
|
||||||
self.run_test_sample_consistent_mean_covariance(sess, gm)
|
self.run_test_sample_consistent_mean_covariance(sess.run, gm)
|
||||||
|
|
||||||
def testVarianceConsistentCovariance(self):
|
def testVarianceConsistentCovariance(self):
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
|
@ -22,6 +22,8 @@ import numpy as np
|
|||||||
|
|
||||||
from tensorflow.contrib.distributions.python.ops import poisson_lognormal
|
from tensorflow.contrib.distributions.python.ops import poisson_lognormal
|
||||||
from tensorflow.contrib.distributions.python.ops import test_util
|
from tensorflow.contrib.distributions.python.ops import test_util
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -38,7 +40,7 @@ class PoissonLogNormalQuadratureCompoundTest(
|
|||||||
np.polynomial.hermite.hermgauss(deg=10)),
|
np.polynomial.hermite.hermgauss(deg=10)),
|
||||||
validate_args=True)
|
validate_args=True)
|
||||||
self.run_test_sample_consistent_log_prob(
|
self.run_test_sample_consistent_log_prob(
|
||||||
sess, pln, rtol=0.1)
|
sess.run, pln, rtol=0.1)
|
||||||
|
|
||||||
def testMeanVariance(self):
|
def testMeanVariance(self):
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
@ -49,7 +51,7 @@ class PoissonLogNormalQuadratureCompoundTest(
|
|||||||
np.polynomial.hermite.hermgauss(deg=10)),
|
np.polynomial.hermite.hermgauss(deg=10)),
|
||||||
validate_args=True)
|
validate_args=True)
|
||||||
self.run_test_sample_consistent_mean_variance(
|
self.run_test_sample_consistent_mean_variance(
|
||||||
sess, pln, rtol=0.02)
|
sess.run, pln, rtol=0.02)
|
||||||
|
|
||||||
def testSampleProbConsistentBroadcastScalar(self):
|
def testSampleProbConsistentBroadcastScalar(self):
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
@ -60,7 +62,7 @@ class PoissonLogNormalQuadratureCompoundTest(
|
|||||||
np.polynomial.hermite.hermgauss(deg=10)),
|
np.polynomial.hermite.hermgauss(deg=10)),
|
||||||
validate_args=True)
|
validate_args=True)
|
||||||
self.run_test_sample_consistent_log_prob(
|
self.run_test_sample_consistent_log_prob(
|
||||||
sess, pln, rtol=0.1, atol=0.01)
|
sess.run, pln, rtol=0.1, atol=0.01)
|
||||||
|
|
||||||
def testMeanVarianceBroadcastScalar(self):
|
def testMeanVarianceBroadcastScalar(self):
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
@ -71,7 +73,7 @@ class PoissonLogNormalQuadratureCompoundTest(
|
|||||||
np.polynomial.hermite.hermgauss(deg=10)),
|
np.polynomial.hermite.hermgauss(deg=10)),
|
||||||
validate_args=True)
|
validate_args=True)
|
||||||
self.run_test_sample_consistent_mean_variance(
|
self.run_test_sample_consistent_mean_variance(
|
||||||
sess, pln, rtol=0.1, atol=0.01)
|
sess.run, pln, rtol=0.1, atol=0.01)
|
||||||
|
|
||||||
def testSampleProbConsistentBroadcastBoth(self):
|
def testSampleProbConsistentBroadcastBoth(self):
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
@ -82,7 +84,7 @@ class PoissonLogNormalQuadratureCompoundTest(
|
|||||||
np.polynomial.hermite.hermgauss(deg=10)),
|
np.polynomial.hermite.hermgauss(deg=10)),
|
||||||
validate_args=True)
|
validate_args=True)
|
||||||
self.run_test_sample_consistent_log_prob(
|
self.run_test_sample_consistent_log_prob(
|
||||||
sess, pln, rtol=0.1, atol=0.08)
|
sess.run, pln, rtol=0.1, atol=0.08)
|
||||||
|
|
||||||
def testMeanVarianceBroadcastBoth(self):
|
def testMeanVarianceBroadcastBoth(self):
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
@ -93,7 +95,21 @@ class PoissonLogNormalQuadratureCompoundTest(
|
|||||||
np.polynomial.hermite.hermgauss(deg=10)),
|
np.polynomial.hermite.hermgauss(deg=10)),
|
||||||
validate_args=True)
|
validate_args=True)
|
||||||
self.run_test_sample_consistent_mean_variance(
|
self.run_test_sample_consistent_mean_variance(
|
||||||
sess, pln, rtol=0.1, atol=0.01)
|
sess.run, pln, rtol=0.1, atol=0.01)
|
||||||
|
|
||||||
|
def testSampleProbConsistentDynamicQuadrature(self):
|
||||||
|
with self.test_session() as sess:
|
||||||
|
qgrid = array_ops.placeholder(dtype=dtypes.float32)
|
||||||
|
qprobs = array_ops.placeholder(dtype=dtypes.float32)
|
||||||
|
g, p = np.polynomial.hermite.hermgauss(deg=10)
|
||||||
|
pln = poisson_lognormal.PoissonLogNormalQuadratureCompound(
|
||||||
|
loc=-2.,
|
||||||
|
scale=1.1,
|
||||||
|
quadrature_grid_and_probs=(g, p),
|
||||||
|
validate_args=True)
|
||||||
|
self.run_test_sample_consistent_log_prob(
|
||||||
|
lambda x: sess.run(x, feed_dict={qgrid: g, qprobs: p}),
|
||||||
|
pln, rtol=0.1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -22,6 +22,8 @@ import numpy as np
|
|||||||
|
|
||||||
from tensorflow.contrib.distributions.python.ops import test_util
|
from tensorflow.contrib.distributions.python.ops import test_util
|
||||||
from tensorflow.contrib.distributions.python.ops import vector_diffeomixture as vector_diffeomixture_lib
|
from tensorflow.contrib.distributions.python.ops import vector_diffeomixture as vector_diffeomixture_lib
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops.distributions import normal as normal_lib
|
from tensorflow.python.ops.distributions import normal as normal_lib
|
||||||
from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib
|
from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib
|
||||||
from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib
|
from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib
|
||||||
@ -55,10 +57,10 @@ class VectorDiffeomixtureTest(
|
|||||||
validate_args=True)
|
validate_args=True)
|
||||||
# Ball centered at component0's mean.
|
# Ball centered at component0's mean.
|
||||||
self.run_test_sample_consistent_log_prob(
|
self.run_test_sample_consistent_log_prob(
|
||||||
sess, vdm, radius=2., center=0., rtol=0.005)
|
sess.run, vdm, radius=2., center=0., rtol=0.005)
|
||||||
# Larger ball centered at component1's mean.
|
# Larger ball centered at component1's mean.
|
||||||
self.run_test_sample_consistent_log_prob(
|
self.run_test_sample_consistent_log_prob(
|
||||||
sess, vdm, radius=4., center=2., rtol=0.005)
|
sess.run, vdm, radius=4., center=2., rtol=0.005)
|
||||||
|
|
||||||
def testSampleProbConsistentBroadcastMixNonStandardBase(self):
|
def testSampleProbConsistentBroadcastMixNonStandardBase(self):
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
@ -83,10 +85,10 @@ class VectorDiffeomixtureTest(
|
|||||||
validate_args=True)
|
validate_args=True)
|
||||||
# Ball centered at component0's mean.
|
# Ball centered at component0's mean.
|
||||||
self.run_test_sample_consistent_log_prob(
|
self.run_test_sample_consistent_log_prob(
|
||||||
sess, vdm, radius=2., center=1., rtol=0.006)
|
sess.run, vdm, radius=2., center=1., rtol=0.006)
|
||||||
# Larger ball centered at component1's mean.
|
# Larger ball centered at component1's mean.
|
||||||
self.run_test_sample_consistent_log_prob(
|
self.run_test_sample_consistent_log_prob(
|
||||||
sess, vdm, radius=4., center=3., rtol=0.009)
|
sess.run, vdm, radius=4., center=3., rtol=0.009)
|
||||||
|
|
||||||
def testSampleProbConsistentBroadcastMixBatch(self):
|
def testSampleProbConsistentBroadcastMixBatch(self):
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
@ -114,10 +116,10 @@ class VectorDiffeomixtureTest(
|
|||||||
validate_args=True)
|
validate_args=True)
|
||||||
# Ball centered at component0's mean.
|
# Ball centered at component0's mean.
|
||||||
self.run_test_sample_consistent_log_prob(
|
self.run_test_sample_consistent_log_prob(
|
||||||
sess, vdm, radius=2., center=0., rtol=0.005)
|
sess.run, vdm, radius=2., center=0., rtol=0.005)
|
||||||
# Larger ball centered at component1's mean.
|
# Larger ball centered at component1's mean.
|
||||||
self.run_test_sample_consistent_log_prob(
|
self.run_test_sample_consistent_log_prob(
|
||||||
sess, vdm, radius=4., center=2., rtol=0.005)
|
sess.run, vdm, radius=4., center=2., rtol=0.005)
|
||||||
|
|
||||||
def testMeanCovarianceNoBatch(self):
|
def testMeanCovarianceNoBatch(self):
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
@ -141,7 +143,7 @@ class VectorDiffeomixtureTest(
|
|||||||
],
|
],
|
||||||
validate_args=True)
|
validate_args=True)
|
||||||
self.run_test_sample_consistent_mean_covariance(
|
self.run_test_sample_consistent_mean_covariance(
|
||||||
sess, vdm, rtol=0.02, cov_rtol=0.06)
|
sess.run, vdm, rtol=0.02, cov_rtol=0.06)
|
||||||
|
|
||||||
def testMeanCovarianceNoBatchUncenteredNonStandardBase(self):
|
def testMeanCovarianceNoBatchUncenteredNonStandardBase(self):
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
@ -165,7 +167,7 @@ class VectorDiffeomixtureTest(
|
|||||||
],
|
],
|
||||||
validate_args=True)
|
validate_args=True)
|
||||||
self.run_test_sample_consistent_mean_covariance(
|
self.run_test_sample_consistent_mean_covariance(
|
||||||
sess, vdm, num_samples=int(1e6), rtol=0.01, cov_atol=0.025)
|
sess.run, vdm, num_samples=int(1e6), rtol=0.01, cov_atol=0.025)
|
||||||
|
|
||||||
def testMeanCovarianceBatch(self):
|
def testMeanCovarianceBatch(self):
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
@ -192,7 +194,40 @@ class VectorDiffeomixtureTest(
|
|||||||
],
|
],
|
||||||
validate_args=True)
|
validate_args=True)
|
||||||
self.run_test_sample_consistent_mean_covariance(
|
self.run_test_sample_consistent_mean_covariance(
|
||||||
sess, vdm, rtol=0.02, cov_rtol=0.06)
|
sess.run, vdm, rtol=0.02, cov_rtol=0.06)
|
||||||
|
|
||||||
|
def testSampleProbConsistentDynamicQuadrature(self):
|
||||||
|
with self.test_session() as sess:
|
||||||
|
qgrid = array_ops.placeholder(dtype=dtypes.float32)
|
||||||
|
qprobs = array_ops.placeholder(dtype=dtypes.float32)
|
||||||
|
g, p = np.polynomial.hermite.hermgauss(deg=8)
|
||||||
|
dims = 4
|
||||||
|
vdm = vector_diffeomixture_lib.VectorDiffeomixture(
|
||||||
|
mix_loc=[[0.], [1.]],
|
||||||
|
mix_scale=[1.],
|
||||||
|
distribution=normal_lib.Normal(0., 1.),
|
||||||
|
loc=[
|
||||||
|
None,
|
||||||
|
np.float32([2.]*dims),
|
||||||
|
],
|
||||||
|
scale=[
|
||||||
|
linop_identity_lib.LinearOperatorScaledIdentity(
|
||||||
|
num_rows=dims,
|
||||||
|
multiplier=np.float32(1.1),
|
||||||
|
is_positive_definite=True),
|
||||||
|
linop_diag_lib.LinearOperatorDiag(
|
||||||
|
diag=np.linspace(2.5, 3.5, dims, dtype=np.float32),
|
||||||
|
is_positive_definite=True),
|
||||||
|
],
|
||||||
|
quadrature_grid_and_probs=(g, p),
|
||||||
|
validate_args=True)
|
||||||
|
# Ball centered at component0's mean.
|
||||||
|
sess_run_fn = lambda x: sess.run(x, feed_dict={qgrid: g, qprobs: p})
|
||||||
|
self.run_test_sample_consistent_log_prob(
|
||||||
|
sess_run_fn, vdm, radius=2., center=0., rtol=0.005)
|
||||||
|
# Larger ball centered at component1's mean.
|
||||||
|
self.run_test_sample_consistent_log_prob(
|
||||||
|
sess_run_fn, vdm, radius=4., center=2., rtol=0.005)
|
||||||
|
|
||||||
# TODO(jvdillon): We've tested that (i) .sample and .log_prob are consistent,
|
# TODO(jvdillon): We've tested that (i) .sample and .log_prob are consistent,
|
||||||
# (ii) .mean, .stddev etc... and .sample are consistent. However, we haven't
|
# (ii) .mean, .stddev etc... and .sample are consistent. However, we haven't
|
||||||
|
@ -210,15 +210,15 @@ class VectorSinhArcsinhDiagTest(test_util.VectorDistributionTestHelpers,
|
|||||||
validate_args=True)
|
validate_args=True)
|
||||||
|
|
||||||
self.run_test_sample_consistent_log_prob(
|
self.run_test_sample_consistent_log_prob(
|
||||||
sess, sasnorm, radius=1.0, center=0., rtol=0.1)
|
sess.run, sasnorm, radius=1.0, center=0., rtol=0.1)
|
||||||
self.run_test_sample_consistent_log_prob(
|
self.run_test_sample_consistent_log_prob(
|
||||||
sess,
|
sess.run,
|
||||||
sasnorm,
|
sasnorm,
|
||||||
radius=1.0,
|
radius=1.0,
|
||||||
center=-0.15,
|
center=-0.15,
|
||||||
rtol=0.1)
|
rtol=0.1)
|
||||||
self.run_test_sample_consistent_log_prob(
|
self.run_test_sample_consistent_log_prob(
|
||||||
sess,
|
sess.run,
|
||||||
sasnorm,
|
sasnorm,
|
||||||
radius=1.0,
|
radius=1.0,
|
||||||
center=0.15,
|
center=0.15,
|
||||||
@ -237,15 +237,15 @@ class VectorSinhArcsinhDiagTest(test_util.VectorDistributionTestHelpers,
|
|||||||
validate_args=True)
|
validate_args=True)
|
||||||
|
|
||||||
self.run_test_sample_consistent_log_prob(
|
self.run_test_sample_consistent_log_prob(
|
||||||
sess, sasnorm, radius=1.0, center=0., rtol=0.1)
|
sess.run, sasnorm, radius=1.0, center=0., rtol=0.1)
|
||||||
self.run_test_sample_consistent_log_prob(
|
self.run_test_sample_consistent_log_prob(
|
||||||
sess,
|
sess.run,
|
||||||
sasnorm,
|
sasnorm,
|
||||||
radius=1.0,
|
radius=1.0,
|
||||||
center=-0.15,
|
center=-0.15,
|
||||||
rtol=0.1)
|
rtol=0.1)
|
||||||
self.run_test_sample_consistent_log_prob(
|
self.run_test_sample_consistent_log_prob(
|
||||||
sess,
|
sess.run,
|
||||||
sasnorm,
|
sasnorm,
|
||||||
radius=1.0,
|
radius=1.0,
|
||||||
center=0.15,
|
center=0.15,
|
||||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.contrib.distributions.python.ops import distribution_util
|
||||||
from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib
|
from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
@ -29,7 +30,6 @@ from tensorflow.python.ops import math_ops
|
|||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
from tensorflow.python.ops.distributions import categorical as categorical_lib
|
from tensorflow.python.ops.distributions import categorical as categorical_lib
|
||||||
from tensorflow.python.ops.distributions import distribution as distribution_lib
|
from tensorflow.python.ops.distributions import distribution as distribution_lib
|
||||||
from tensorflow.python.ops.distributions import util as distribution_util
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -55,8 +55,10 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
|
|||||||
```
|
```
|
||||||
|
|
||||||
where `lambda(z) = exp(sqrt(2) scale z + loc)` and the `prob,grid` terms
|
where `lambda(z) = exp(sqrt(2) scale z + loc)` and the `prob,grid` terms
|
||||||
are from [Gauss--Hermite quadrature](
|
are from [numerical quadrature](
|
||||||
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature). Note that
|
https://en.wikipedia.org/wiki/Numerical_integration) (default:
|
||||||
|
[Gauss--Hermite quadrature](
|
||||||
|
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). Note that
|
||||||
the second line made the substitution:
|
the second line made the substitution:
|
||||||
`z(l) = (log(l) - loc) / (sqrt(2) scale)` which implies `lambda(z)` [above]
|
`z(l) = (log(l) - loc) / (sqrt(2) scale)` which implies `lambda(z)` [above]
|
||||||
and `dl = sqrt(2) scale lambda(z) dz`
|
and `dl = sqrt(2) scale lambda(z) dz`
|
||||||
@ -65,8 +67,11 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
|
|||||||
Poisson rate parameter. Unfortunately, the non-approximate distribution lacks
|
Poisson rate parameter. Unfortunately, the non-approximate distribution lacks
|
||||||
an analytical probability density function (pdf). Therefore the
|
an analytical probability density function (pdf). Therefore the
|
||||||
`PoissonLogNormalQuadratureCompound` class implements an approximation based
|
`PoissonLogNormalQuadratureCompound` class implements an approximation based
|
||||||
on [Gauss-Hermite quadrature](
|
on [numerical quadrature](
|
||||||
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature).
|
https://en.wikipedia.org/wiki/Numerical_integration) (default:
|
||||||
|
[Gauss--Hermite quadrature](
|
||||||
|
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)).
|
||||||
|
|
||||||
Note: although the `PoissonLogNormalQuadratureCompound` is approximately the
|
Note: although the `PoissonLogNormalQuadratureCompound` is approximately the
|
||||||
Poisson-LogNormal compound distribution, it is itself a valid distribution.
|
Poisson-LogNormal compound distribution, it is itself a valid distribution.
|
||||||
Viz., it possesses a `sample`, `log_prob`, `mean`, `variance`, etc. which are
|
Viz., it possesses a `sample`, `log_prob`, `mean`, `variance`, etc. which are
|
||||||
@ -76,9 +81,11 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
|
|||||||
|
|
||||||
The `PoissonLogNormalQuadratureCompound` approximates a Poisson-LogNormal
|
The `PoissonLogNormalQuadratureCompound` approximates a Poisson-LogNormal
|
||||||
[compound distribution](
|
[compound distribution](
|
||||||
https://en.wikipedia.org/wiki/Compound_probability_distribution).
|
https://en.wikipedia.org/wiki/Compound_probability_distribution). Using
|
||||||
Using variable-substitution and [Gauss-Hermite quadrature](
|
variable-substitution and [numerical quadrature](
|
||||||
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature) we can
|
https://en.wikipedia.org/wiki/Numerical_integration) (default:
|
||||||
|
[Gauss--Hermite quadrature](
|
||||||
|
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)) we can
|
||||||
redefine the distribution to be a parameter-less convex combination of `deg`
|
redefine the distribution to be a parameter-less convex combination of `deg`
|
||||||
different Poisson samples.
|
different Poisson samples.
|
||||||
|
|
||||||
@ -125,9 +132,10 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
|
|||||||
the LogNormal prior.
|
the LogNormal prior.
|
||||||
scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of
|
scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of
|
||||||
the LogNormal prior.
|
the LogNormal prior.
|
||||||
quadrature_grid_and_probs: Python pair of `list`-like objects representing
|
quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
|
||||||
the sample points and the corresponding (possibly normalized) weight.
|
representing the sample points and the corresponding (possibly
|
||||||
When `None`, defaults to: `np.polynomial.hermite.hermgauss(deg=8)`.
|
normalized) weight. When `None`, defaults to:
|
||||||
|
`np.polynomial.hermite.hermgauss(deg=8)`.
|
||||||
validate_args: Python `bool`, default `False`. When `True` distribution
|
validate_args: Python `bool`, default `False`. When `True` distribution
|
||||||
parameters are checked for validity despite possibly degrading runtime
|
parameters are checked for validity despite possibly degrading runtime
|
||||||
performance. When `False` invalid inputs may silently render incorrect
|
performance. When `False` invalid inputs may silently render incorrect
|
||||||
@ -140,8 +148,6 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: if `loc.dtype != scale[0].dtype`.
|
TypeError: if `loc.dtype != scale[0].dtype`.
|
||||||
ValueError: if `quadrature_grid_and_probs is not None` and
|
|
||||||
`len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
|
|
||||||
"""
|
"""
|
||||||
parameters = locals()
|
parameters = locals()
|
||||||
with ops.name_scope(name, values=[loc, scale]):
|
with ops.name_scope(name, values=[loc, scale]):
|
||||||
@ -157,21 +163,14 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
|
|||||||
"loc.dtype(\"{}\") does not match scale.dtype(\"{}\")".format(
|
"loc.dtype(\"{}\") does not match scale.dtype(\"{}\")".format(
|
||||||
loc.dtype.name, scale.dtype.name))
|
loc.dtype.name, scale.dtype.name))
|
||||||
|
|
||||||
if quadrature_grid_and_probs is None:
|
grid, probs = distribution_util.process_quadrature_grid_and_probs(
|
||||||
grid, probs = np.polynomial.hermite.hermgauss(deg=8)
|
quadrature_grid_and_probs, dtype, validate_args)
|
||||||
else:
|
|
||||||
grid, probs = tuple(quadrature_grid_and_probs)
|
|
||||||
if len(grid) != len(probs):
|
|
||||||
raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of "
|
|
||||||
"same-length list-like objects")
|
|
||||||
grid = grid.astype(dtype.as_numpy_dtype)
|
|
||||||
probs = probs.astype(dtype.as_numpy_dtype)
|
|
||||||
probs /= np.linalg.norm(probs, ord=1)
|
|
||||||
self._quadrature_grid = grid
|
self._quadrature_grid = grid
|
||||||
self._quadrature_probs = probs
|
self._quadrature_probs = probs
|
||||||
|
self._quadrature_size = distribution_util.dimension_size(probs, axis=0)
|
||||||
|
|
||||||
self._mixture_distribution = categorical_lib.Categorical(
|
self._mixture_distribution = categorical_lib.Categorical(
|
||||||
logits=np.log(probs),
|
logits=math_ops.log(self._quadrature_probs),
|
||||||
validate_args=validate_args,
|
validate_args=validate_args,
|
||||||
allow_nan_stats=allow_nan_stats)
|
allow_nan_stats=allow_nan_stats)
|
||||||
|
|
||||||
@ -254,10 +253,10 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
|
|||||||
[batch_size])),
|
[batch_size])),
|
||||||
seed=distribution_util.gen_new_seed(
|
seed=distribution_util.gen_new_seed(
|
||||||
seed, "poisson_lognormal_quadrature_compound"))
|
seed, "poisson_lognormal_quadrature_compound"))
|
||||||
# Stride `quadrature_degree` for `batch_size` number of times.
|
# Stride `quadrature_size` for `batch_size` number of times.
|
||||||
offset = math_ops.range(start=0,
|
offset = math_ops.range(start=0,
|
||||||
limit=batch_size * len(self.quadrature_probs),
|
limit=batch_size * self._quadrature_size,
|
||||||
delta=len(self.quadrature_probs),
|
delta=self._quadrature_size,
|
||||||
dtype=ids.dtype)
|
dtype=ids.dtype)
|
||||||
ids += offset
|
ids += offset
|
||||||
rate = array_ops.gather(
|
rate = array_ops.gather(
|
||||||
|
@ -38,7 +38,7 @@ class DiscreteScalarDistributionTestHelpers(object):
|
|||||||
"""DiscreteScalarDistributionTestHelpers."""
|
"""DiscreteScalarDistributionTestHelpers."""
|
||||||
|
|
||||||
def run_test_sample_consistent_log_prob(
|
def run_test_sample_consistent_log_prob(
|
||||||
self, sess, dist,
|
self, sess_run_fn, dist,
|
||||||
num_samples=int(1e5), num_threshold=int(1e3), seed=42,
|
num_samples=int(1e5), num_threshold=int(1e3), seed=42,
|
||||||
rtol=1e-2, atol=0.):
|
rtol=1e-2, atol=0.):
|
||||||
"""Tests that sample/log_prob are consistent with each other.
|
"""Tests that sample/log_prob are consistent with each other.
|
||||||
@ -51,7 +51,9 @@ class DiscreteScalarDistributionTestHelpers(object):
|
|||||||
are consistent.
|
are consistent.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sess: Tensorflow session.
|
sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and
|
||||||
|
returning a list of results after running one "step" of TensorFlow
|
||||||
|
computation, typically set to `sess.run`.
|
||||||
dist: Distribution instance or object which implements `sample`,
|
dist: Distribution instance or object which implements `sample`,
|
||||||
`log_prob`, `event_shape_tensor` and `batch_shape_tensor`.
|
`log_prob`, `event_shape_tensor` and `batch_shape_tensor`.
|
||||||
num_samples: Python `int` scalar indicating the number of Monte-Carlo
|
num_samples: Python `int` scalar indicating the number of Monte-Carlo
|
||||||
@ -87,7 +89,7 @@ class DiscreteScalarDistributionTestHelpers(object):
|
|||||||
probs = math_ops.exp(dist.log_prob(edges))
|
probs = math_ops.exp(dist.log_prob(edges))
|
||||||
probs = array_ops.reshape(probs, shape=[-1, batch_size])[:, b]
|
probs = array_ops.reshape(probs, shape=[-1, batch_size])[:, b]
|
||||||
|
|
||||||
[counts_, probs_] = sess.run([counts, probs])
|
[counts_, probs_] = sess_run_fn([counts, probs])
|
||||||
valid = counts_ > num_threshold
|
valid = counts_ > num_threshold
|
||||||
probs_ = probs_[valid]
|
probs_ = probs_[valid]
|
||||||
counts_ = counts_[valid]
|
counts_ = counts_[valid]
|
||||||
@ -95,7 +97,7 @@ class DiscreteScalarDistributionTestHelpers(object):
|
|||||||
rtol=rtol, atol=atol)
|
rtol=rtol, atol=atol)
|
||||||
|
|
||||||
def run_test_sample_consistent_mean_variance(
|
def run_test_sample_consistent_mean_variance(
|
||||||
self, sess, dist,
|
self, sess_run_fn, dist,
|
||||||
num_samples=int(1e5), seed=24,
|
num_samples=int(1e5), seed=24,
|
||||||
rtol=1e-2, atol=0.):
|
rtol=1e-2, atol=0.):
|
||||||
"""Tests that sample/mean/variance are consistent with each other.
|
"""Tests that sample/mean/variance are consistent with each other.
|
||||||
@ -104,7 +106,9 @@ class DiscreteScalarDistributionTestHelpers(object):
|
|||||||
to the same distribution.
|
to the same distribution.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sess: Tensorflow session.
|
sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and
|
||||||
|
returning a list of results after running one "step" of TensorFlow
|
||||||
|
computation, typically set to `sess.run`.
|
||||||
dist: Distribution instance or object which implements `sample`,
|
dist: Distribution instance or object which implements `sample`,
|
||||||
`log_prob`, `event_shape_tensor` and `batch_shape_tensor`.
|
`log_prob`, `event_shape_tensor` and `batch_shape_tensor`.
|
||||||
num_samples: Python `int` scalar indicating the number of Monte-Carlo
|
num_samples: Python `int` scalar indicating the number of Monte-Carlo
|
||||||
@ -130,7 +134,7 @@ class DiscreteScalarDistributionTestHelpers(object):
|
|||||||
mean_,
|
mean_,
|
||||||
variance_,
|
variance_,
|
||||||
stddev_
|
stddev_
|
||||||
] = sess.run([
|
] = sess_run_fn([
|
||||||
sample_mean,
|
sample_mean,
|
||||||
sample_variance,
|
sample_variance,
|
||||||
sample_stddev,
|
sample_stddev,
|
||||||
@ -187,7 +191,7 @@ class VectorDistributionTestHelpers(object):
|
|||||||
|
|
||||||
def run_test_sample_consistent_log_prob(
|
def run_test_sample_consistent_log_prob(
|
||||||
self,
|
self,
|
||||||
sess,
|
sess_run_fn,
|
||||||
dist,
|
dist,
|
||||||
num_samples=int(1e5),
|
num_samples=int(1e5),
|
||||||
radius=1.,
|
radius=1.,
|
||||||
@ -240,7 +244,9 @@ class VectorDistributionTestHelpers(object):
|
|||||||
https://en.wikipedia.org/wiki/Importance_sampling.
|
https://en.wikipedia.org/wiki/Importance_sampling.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sess: Tensorflow session.
|
sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and
|
||||||
|
returning a list of results after running one "step" of TensorFlow
|
||||||
|
computation, typically set to `sess.run`.
|
||||||
dist: Distribution instance or object which implements `sample`,
|
dist: Distribution instance or object which implements `sample`,
|
||||||
`log_prob`, `event_shape_tensor` and `batch_shape_tensor`. The
|
`log_prob`, `event_shape_tensor` and `batch_shape_tensor`. The
|
||||||
distribution must have non-zero probability of sampling every point
|
distribution must have non-zero probability of sampling every point
|
||||||
@ -301,8 +307,8 @@ class VectorDistributionTestHelpers(object):
|
|||||||
init_op = variables_ops.global_variables_initializer()
|
init_op = variables_ops.global_variables_initializer()
|
||||||
|
|
||||||
# Execute graph.
|
# Execute graph.
|
||||||
sess.run(init_op)
|
sess_run_fn(init_op)
|
||||||
[batch_shape_, actual_volume_, sample_volume_] = sess.run([
|
[batch_shape_, actual_volume_, sample_volume_] = sess_run_fn([
|
||||||
batch_shape, actual_volume, sample_volume])
|
batch_shape, actual_volume, sample_volume])
|
||||||
|
|
||||||
# Check results.
|
# Check results.
|
||||||
@ -312,7 +318,7 @@ class VectorDistributionTestHelpers(object):
|
|||||||
|
|
||||||
def run_test_sample_consistent_mean_covariance(
|
def run_test_sample_consistent_mean_covariance(
|
||||||
self,
|
self,
|
||||||
sess,
|
sess_run_fn,
|
||||||
dist,
|
dist,
|
||||||
num_samples=int(1e5),
|
num_samples=int(1e5),
|
||||||
seed=24,
|
seed=24,
|
||||||
@ -326,7 +332,9 @@ class VectorDistributionTestHelpers(object):
|
|||||||
to the same distribution.
|
to the same distribution.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sess: Tensorflow session.
|
sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and
|
||||||
|
returning a list of results after running one "step" of TensorFlow
|
||||||
|
computation, typically set to `sess.run`.
|
||||||
dist: Distribution instance or object which implements `sample`,
|
dist: Distribution instance or object which implements `sample`,
|
||||||
`log_prob`, `event_shape_tensor` and `batch_shape_tensor`.
|
`log_prob`, `event_shape_tensor` and `batch_shape_tensor`.
|
||||||
num_samples: Python `int` scalar indicating the number of Monte-Carlo
|
num_samples: Python `int` scalar indicating the number of Monte-Carlo
|
||||||
@ -360,7 +368,7 @@ class VectorDistributionTestHelpers(object):
|
|||||||
covariance_,
|
covariance_,
|
||||||
variance_,
|
variance_,
|
||||||
stddev_
|
stddev_
|
||||||
] = sess.run([
|
] = sess_run_fn([
|
||||||
sample_mean,
|
sample_mean,
|
||||||
sample_covariance,
|
sample_covariance,
|
||||||
sample_variance,
|
sample_variance,
|
||||||
|
@ -73,8 +73,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
|
|||||||
denotes matrix multiplication. However, the non-approximate distribution does
|
denotes matrix multiplication. However, the non-approximate distribution does
|
||||||
not have an analytical probability density function (pdf). Therefore the
|
not have an analytical probability density function (pdf). Therefore the
|
||||||
`VectorDiffeomixture` class implements an approximation based on
|
`VectorDiffeomixture` class implements an approximation based on
|
||||||
[Gauss-Hermite quadrature](
|
[numerical quadrature](
|
||||||
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature). I.e., in
|
https://en.wikipedia.org/wiki/Numerical_integration) (default:
|
||||||
|
[Gauss--Hermite quadrature](
|
||||||
|
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). I.e., in
|
||||||
Note: although the `VectorDiffeomixture` is approximately the
|
Note: although the `VectorDiffeomixture` is approximately the
|
||||||
`SoftmaxNormal-Distribution` compound distribution, it is itself a valid
|
`SoftmaxNormal-Distribution` compound distribution, it is itself a valid
|
||||||
distribution. It possesses a `sample`, `log_prob`, `mean`, `covariance` which
|
distribution. It possesses a `sample`, `log_prob`, `mean`, `covariance` which
|
||||||
@ -109,8 +111,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
|
|||||||
The `VectorDiffeomixture` approximates a SoftmaxNormal-mixed ("prior")
|
The `VectorDiffeomixture` approximates a SoftmaxNormal-mixed ("prior")
|
||||||
[compound distribution](
|
[compound distribution](
|
||||||
https://en.wikipedia.org/wiki/Compound_probability_distribution).
|
https://en.wikipedia.org/wiki/Compound_probability_distribution).
|
||||||
Using variable-substitution and [Gauss-Hermite quadrature](
|
Using variable-substitution and [numerical quadrature](
|
||||||
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature) we can
|
https://en.wikipedia.org/wiki/Numerical_integration) (default:
|
||||||
|
[Gauss--Hermite quadrature](
|
||||||
|
https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)) we can
|
||||||
redefine the distribution to be a parameter-less convex combination of `K`
|
redefine the distribution to be a parameter-less convex combination of `K`
|
||||||
different affine combinations of a `d` iid samples from `distribution`.
|
different affine combinations of a `d` iid samples from `distribution`.
|
||||||
|
|
||||||
@ -141,7 +145,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
|
|||||||
and,
|
and,
|
||||||
|
|
||||||
```none
|
```none
|
||||||
grid, weight = np.polynomial.hermite.hermgauss(quadrature_degree)
|
grid, weight = np.polynomial.hermite.hermgauss(quadrature_size)
|
||||||
prob[k] = weight[k] / sqrt(pi)
|
prob[k] = weight[k] / sqrt(pi)
|
||||||
lambda[k; i] = sigmoid(mix_loc[k] + sqrt(2) mix_scale[k] grid[i])
|
lambda[k; i] = sigmoid(mix_loc[k] + sqrt(2) mix_scale[k] grid[i])
|
||||||
```
|
```
|
||||||
@ -248,9 +252,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
|
|||||||
`k`-th element represents the `scale` used for the `k`-th affine
|
`k`-th element represents the `scale` used for the `k`-th affine
|
||||||
transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`,
|
transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`,
|
||||||
`b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices
|
`b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices
|
||||||
quadrature_grid_and_probs: Python pair of `list`-like objects representing
|
quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
|
||||||
the sample points and the corresponding (possibly normalized) weight.
|
representing the sample points and the corresponding (possibly
|
||||||
When `None`, defaults to: `np.polynomial.hermite.hermgauss(deg=8)`.
|
normalized) weight. When `None`, defaults to:
|
||||||
|
`np.polynomial.hermite.hermgauss(deg=8)`.
|
||||||
validate_args: Python `bool`, default `False`. When `True` distribution
|
validate_args: Python `bool`, default `False`. When `True` distribution
|
||||||
parameters are checked for validity despite possibly degrading runtime
|
parameters are checked for validity despite possibly degrading runtime
|
||||||
performance. When `False` invalid inputs may silently render incorrect
|
performance. When `False` invalid inputs may silently render incorrect
|
||||||
@ -317,24 +322,17 @@ class VectorDiffeomixture(distribution_lib.Distribution):
|
|||||||
raise NotImplementedError("Currently only bimixtures are supported; "
|
raise NotImplementedError("Currently only bimixtures are supported; "
|
||||||
"len(scale)={} is not 2.".format(len(scale)))
|
"len(scale)={} is not 2.".format(len(scale)))
|
||||||
|
|
||||||
if quadrature_grid_and_probs is None:
|
grid, probs = distribution_util.process_quadrature_grid_and_probs(
|
||||||
grid, probs = np.polynomial.hermite.hermgauss(deg=8)
|
quadrature_grid_and_probs, dtype, validate_args)
|
||||||
else:
|
|
||||||
grid, probs = tuple(quadrature_grid_and_probs)
|
|
||||||
if len(grid) != len(probs):
|
|
||||||
raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of "
|
|
||||||
"same-length list-like objects")
|
|
||||||
grid = grid.astype(dtype.as_numpy_dtype)
|
|
||||||
probs = probs.astype(dtype.as_numpy_dtype)
|
|
||||||
probs /= np.linalg.norm(probs, ord=1)
|
|
||||||
self._quadrature_grid = grid
|
self._quadrature_grid = grid
|
||||||
self._quadrature_probs = probs
|
self._quadrature_probs = probs
|
||||||
|
self._quadrature_size = distribution_util.dimension_size(probs, axis=0)
|
||||||
|
|
||||||
# Note: by creating the logits as `log(prob)` we ensure that
|
# Note: by creating the logits as `log(prob)` we ensure that
|
||||||
# `self.mixture_distribution.logits` is equivalent to
|
# `self.mixture_distribution.logits` is equivalent to
|
||||||
# `math_ops.log(self.mixture_distribution.probs)`.
|
# `math_ops.log(self.mixture_distribution.probs)`.
|
||||||
self._mixture_distribution = categorical_lib.Categorical(
|
self._mixture_distribution = categorical_lib.Categorical(
|
||||||
logits=np.log(probs),
|
logits=math_ops.log(probs),
|
||||||
validate_args=validate_args,
|
validate_args=validate_args,
|
||||||
allow_nan_stats=allow_nan_stats)
|
allow_nan_stats=allow_nan_stats)
|
||||||
|
|
||||||
@ -361,10 +359,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
|
|||||||
validate_args=validate_args,
|
validate_args=validate_args,
|
||||||
name="interpolated_affine_{}".format(k))
|
name="interpolated_affine_{}".format(k))
|
||||||
for k, (loc_, scale_) in enumerate(zip(
|
for k, (loc_, scale_) in enumerate(zip(
|
||||||
interpolate_loc(len(self._quadrature_grid),
|
interpolate_loc(self._quadrature_size,
|
||||||
self._interpolate_weight,
|
self._interpolate_weight,
|
||||||
loc),
|
loc),
|
||||||
interpolate_scale(len(self._quadrature_grid),
|
interpolate_scale(self._quadrature_size,
|
||||||
self._interpolate_weight,
|
self._interpolate_weight,
|
||||||
scale)))]
|
scale)))]
|
||||||
|
|
||||||
@ -463,10 +461,10 @@ class VectorDiffeomixture(distribution_lib.Distribution):
|
|||||||
seed=distribution_util.gen_new_seed(
|
seed=distribution_util.gen_new_seed(
|
||||||
seed, "vector_diffeomixture"))
|
seed, "vector_diffeomixture"))
|
||||||
|
|
||||||
# Stride `quadrature_degree` for `batch_size` number of times.
|
# Stride `quadrature_size` for `batch_size` number of times.
|
||||||
offset = math_ops.range(start=0,
|
offset = math_ops.range(start=0,
|
||||||
limit=batch_size * len(self.quadrature_probs),
|
limit=batch_size * self._quadrature_size,
|
||||||
delta=len(self.quadrature_probs),
|
delta=self._quadrature_size,
|
||||||
dtype=ids.dtype)
|
dtype=ids.dtype)
|
||||||
|
|
||||||
weight = array_ops.gather(
|
weight = array_ops.gather(
|
||||||
|
@ -29,6 +29,7 @@ from tensorflow.python.framework import tensor_util
|
|||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import check_ops
|
from tensorflow.python.ops import check_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import linalg_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import nn
|
from tensorflow.python.ops import nn
|
||||||
|
|
||||||
@ -1049,13 +1050,77 @@ def dimension_size(x, axis):
|
|||||||
"""Returns the size of a specific dimension."""
|
"""Returns the size of a specific dimension."""
|
||||||
# Since tf.gather isn't "constant-in, constant-out", we must first check the
|
# Since tf.gather isn't "constant-in, constant-out", we must first check the
|
||||||
# static shape or fallback to dynamic shape.
|
# static shape or fallback to dynamic shape.
|
||||||
num_rows = (None if x.get_shape().ndims is None
|
s = x.shape.with_rank_at_least(axis + 1)[axis].value
|
||||||
else x.get_shape()[axis].value)
|
if axis > -1 and s is not None:
|
||||||
if num_rows is not None:
|
return s
|
||||||
return num_rows
|
|
||||||
return array_ops.shape(x)[axis]
|
return array_ops.shape(x)[axis]
|
||||||
|
|
||||||
|
|
||||||
|
def process_quadrature_grid_and_probs(
|
||||||
|
quadrature_grid_and_probs, dtype, validate_args, name=None):
|
||||||
|
"""Validates quadrature grid, probs or computes them as necessary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
|
||||||
|
representing the sample points and the corresponding (possibly
|
||||||
|
normalized) weight. When `None`, defaults to:
|
||||||
|
`np.polynomial.hermite.hermgauss(deg=8)`.
|
||||||
|
dtype: The expected `dtype` of `grid` and `probs`.
|
||||||
|
validate_args: Python `bool`, default `False`. When `True` distribution
|
||||||
|
parameters are checked for validity despite possibly degrading runtime
|
||||||
|
performance. When `False` invalid inputs may silently render incorrect
|
||||||
|
outputs.
|
||||||
|
name: Python `str` name prefixed to Ops created by this class.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s
|
||||||
|
representing the sample points and the corresponding (possibly
|
||||||
|
normalized) weight.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if `quadrature_grid_and_probs is not None` and
|
||||||
|
`len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])`
|
||||||
|
"""
|
||||||
|
with ops.name_scope(name, "process_quadrature_grid_and_probs",
|
||||||
|
[quadrature_grid_and_probs]):
|
||||||
|
if quadrature_grid_and_probs is None:
|
||||||
|
grid, probs = np.polynomial.hermite.hermgauss(deg=8)
|
||||||
|
grid = grid.astype(dtype.as_numpy_dtype)
|
||||||
|
probs = probs.astype(dtype.as_numpy_dtype)
|
||||||
|
probs /= np.linalg.norm(probs, ord=1, keepdims=True)
|
||||||
|
grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
|
||||||
|
probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
|
||||||
|
return grid, probs
|
||||||
|
|
||||||
|
grid, probs = tuple(quadrature_grid_and_probs)
|
||||||
|
grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
|
||||||
|
probs = ops.convert_to_tensor(probs, name="unnormalized_probs",
|
||||||
|
dtype=dtype)
|
||||||
|
probs /= linalg_ops.norm(probs, ord=1, axis=-1, keep_dims=True,
|
||||||
|
name="probs")
|
||||||
|
|
||||||
|
def _static_dim_size(x, axis):
|
||||||
|
"""Returns the static size of a specific dimension or `None`."""
|
||||||
|
return x.shape.with_rank_at_least(axis + 1)[axis].value
|
||||||
|
|
||||||
|
m, n = _static_dim_size(probs, axis=0), _static_dim_size(grid, axis=0)
|
||||||
|
if m is not None and n is not None:
|
||||||
|
if m != n:
|
||||||
|
raise ValueError("`quadrature_grid_and_probs` must be a `tuple` of "
|
||||||
|
"same-length zero-th-dimension `Tensor`s "
|
||||||
|
"(saw lengths {}, {})".format(m, n))
|
||||||
|
elif validate_args:
|
||||||
|
grid = control_flow_ops.with_dependencies([
|
||||||
|
check_ops.assert_equal(
|
||||||
|
dimension_size(probs, axis=0),
|
||||||
|
dimension_size(grid, axis=0),
|
||||||
|
message=("`quadrature_grid_and_probs` must be a `tuple` of "
|
||||||
|
"same-length zero-th-dimension `Tensor`s")),
|
||||||
|
], grid)
|
||||||
|
|
||||||
|
return grid, probs
|
||||||
|
|
||||||
|
|
||||||
class AppendDocstring(object):
|
class AppendDocstring(object):
|
||||||
"""Helper class to promote private subclass docstring to public counterpart.
|
"""Helper class to promote private subclass docstring to public counterpart.
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user