Test some distributions stuff in Eager as well as Graph
PiperOrigin-RevId: 197033485
This commit is contained in:
parent 831aa3984d
commit 01dbc6ac45
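A quick orientation before the diff: the recurring change below replaces graph-only tensor.eval() calls with self.evaluate(tensor), which runs the tensor in the active session under graph mode and simply resolves it under eager mode. A minimal sketch of the pattern (the test class and values here are illustrative, not part of this commit):

# Illustrative sketch of the migration pattern used throughout this diff.
# Assumes the usual distribution-test imports; the class name is hypothetical.
class ExampleBernoulliTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes()
  def testProbs(self):
    dist = bernoulli.Bernoulli(probs=[0.2, 0.4])
    # Old, graph-only style: dist.probs.eval()
    # New style, valid in both graph and eager modes:
    self.assertAllClose([0.2, 0.4], self.evaluate(dist.probs))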
@@ -299,7 +299,7 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
GetContext(context), handle.get(), handle_dtype,
static_cast<TF_DataType>(desired_dtype), self->status));
if (TF_GetCode(self->status) != TF_OK) {
PyErr_SetString(PyExc_ValueError,
PyErr_SetString(PyExc_TypeError,
tensorflow::strings::StrCat(
"Error while casting from DataType ", handle_dtype,
" to ", desired_dtype, ". ", TF_Message(self->status))

@@ -639,6 +639,14 @@ def assert_no_garbage_created(f):
return decorator


def run_all_in_graph_and_eager_modes(cls):
base_decorator = run_in_graph_and_eager_modes()
for name, value in cls.__dict__.copy().items():
if callable(value) and name.startswith("test"):
setattr(cls, name, base_decorator(value))
return cls


def run_in_graph_and_eager_modes(__unused__=None,
config=None,
use_gpu=True,

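The run_all_in_graph_and_eager_modes class decorator added above applies run_in_graph_and_eager_modes() to every attribute that is callable and whose name starts with "test". A short usage sketch (the class below is illustrative, not from this commit), matching how the beta, dirichlet, exponential, and laplace tests further down use it:

@test_util.run_all_in_graph_and_eager_modes
class ExampleTest(test.TestCase):

  def testAdd(self):
    # Each test method now runs twice: once building a graph, once eagerly.
    total = math_ops.add(1., 2.)
    self.assertEqual(3., self.evaluate(total))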
@@ -24,6 +24,7 @@ import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import bernoulli
from tensorflow.python.ops.distributions import kullback_leibler
@@ -56,59 +57,65 @@ def entropy(p):

class BernoulliTest(test.TestCase):

@test_util.run_in_graph_and_eager_modes()
def testP(self):
p = [0.2, 0.4]
dist = bernoulli.Bernoulli(probs=p)
with self.test_session():
self.assertAllClose(p, dist.probs.eval())
self.assertAllClose(p, self.evaluate(dist.probs))

@test_util.run_in_graph_and_eager_modes()
def testLogits(self):
logits = [-42., 42.]
dist = bernoulli.Bernoulli(logits=logits)
with self.test_session():
self.assertAllClose(logits, dist.logits.eval())
self.assertAllClose(logits, self.evaluate(dist.logits))

if not special:
return

with self.test_session():
self.assertAllClose(special.expit(logits), dist.probs.eval())
self.assertAllClose(special.expit(logits), self.evaluate(dist.probs))

p = [0.01, 0.99, 0.42]
dist = bernoulli.Bernoulli(probs=p)
with self.test_session():
self.assertAllClose(special.logit(p), dist.logits.eval())
self.assertAllClose(special.logit(p), self.evaluate(dist.logits))

@test_util.run_in_graph_and_eager_modes()
def testInvalidP(self):
invalid_ps = [1.01, 2.]
for p in invalid_ps:
with self.test_session():
with self.assertRaisesOpError("probs has components greater than 1"):
dist = bernoulli.Bernoulli(probs=p, validate_args=True)
dist.probs.eval()
self.evaluate(dist.probs)

invalid_ps = [-0.01, -3.]
for p in invalid_ps:
with self.test_session():
with self.assertRaisesOpError("Condition x >= 0"):
dist = bernoulli.Bernoulli(probs=p, validate_args=True)
dist.probs.eval()
self.evaluate(dist.probs)

valid_ps = [0.0, 0.5, 1.0]
for p in valid_ps:
with self.test_session():
dist = bernoulli.Bernoulli(probs=p)
self.assertEqual(p, dist.probs.eval())  # Should not fail
self.assertEqual(p, self.evaluate(dist.probs))  # Should not fail

@test_util.run_in_graph_and_eager_modes()
def testShapes(self):
with self.test_session():
for batch_shape in ([], [1], [2, 3, 4]):
dist = make_bernoulli(batch_shape)
self.assertAllEqual(batch_shape, dist.batch_shape.as_list())
self.assertAllEqual(batch_shape, dist.batch_shape_tensor().eval())
self.assertAllEqual(batch_shape,
self.evaluate(dist.batch_shape_tensor()))
self.assertAllEqual([], dist.event_shape.as_list())
self.assertAllEqual([], dist.event_shape_tensor().eval())
self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))

@test_util.run_in_graph_and_eager_modes()
def testDtype(self):
dist = make_bernoulli([])
self.assertEqual(dist.dtype, dtypes.int32)
@@ -126,6 +133,7 @@ class BernoulliTest(test.TestCase):
self.assertEqual(dist64.dtype, dist64.sample(5).dtype)
self.assertEqual(dist64.dtype, dist64.mode().dtype)

@test_util.run_in_graph_and_eager_modes()
def _testPmf(self, **kwargs):
dist = bernoulli.Bernoulli(**kwargs)
with self.test_session():
@@ -147,8 +155,9 @@ class BernoulliTest(test.TestCase):
# pylint: enable=bad-continuation

for x, expected_pmf in zip(xs, expected_pmfs):
self.assertAllClose(dist.prob(x).eval(), expected_pmf)
self.assertAllClose(dist.log_prob(x).eval(), np.log(expected_pmf))
self.assertAllClose(self.evaluate(dist.prob(x)), expected_pmf)
self.assertAllClose(
self.evaluate(dist.log_prob(x)), np.log(expected_pmf))

def testPmfCorrectBroadcastDynamicShape(self):
with self.test_session():
@@ -165,15 +174,17 @@ class BernoulliTest(test.TestCase):
p: [0.2, 0.3, 0.4]
}), [[0.2, 0.7, 0.4]])

@test_util.run_in_graph_and_eager_modes()
def testPmfInvalid(self):
p = [0.1, 0.2, 0.7]
with self.test_session():
dist = bernoulli.Bernoulli(probs=p, validate_args=True)
with self.assertRaisesOpError("must be non-negative."):
dist.prob([1, 1, -1]).eval()
self.evaluate(dist.prob([1, 1, -1]))
with self.assertRaisesOpError("Elements cannot exceed 1."):
dist.prob([2, 0, 1]).eval()
self.evaluate(dist.prob([2, 0, 1]))

@test_util.run_in_graph_and_eager_modes()
def testPmfWithP(self):
p = [[0.2, 0.4], [0.3, 0.6]]
self._testPmf(probs=p)
@@ -203,7 +214,7 @@ class BernoulliTest(test.TestCase):

with self.test_session():
dist = bernoulli.Bernoulli(probs=0.5)
self.assertEqual(2, len(dist.log_prob([[1], [1]]).eval().shape))
self.assertEqual(2, len(self.evaluate(dist.log_prob([[1], [1]])).shape))

with self.test_session():
dist = bernoulli.Bernoulli(probs=0.5)
@@ -215,25 +226,31 @@ class BernoulliTest(test.TestCase):
dist = bernoulli.Bernoulli(probs=[[0.5], [0.5]])
self.assertEqual((2, 1), dist.log_prob(1).get_shape())

@test_util.run_in_graph_and_eager_modes()
def testBoundaryConditions(self):
with self.test_session():
dist = bernoulli.Bernoulli(probs=1.0)
self.assertAllClose(np.nan, dist.log_prob(0).eval())
self.assertAllClose([np.nan], [dist.log_prob(1).eval()])
self.assertAllClose(np.nan, self.evaluate(dist.log_prob(0)))
self.assertAllClose([np.nan], [self.evaluate(dist.log_prob(1))])

@test_util.run_in_graph_and_eager_modes()
def testEntropyNoBatch(self):
p = 0.2
dist = bernoulli.Bernoulli(probs=p)
with self.test_session():
self.assertAllClose(dist.entropy().eval(), entropy(p))
self.assertAllClose(self.evaluate(dist.entropy()), entropy(p))

@test_util.run_in_graph_and_eager_modes()
def testEntropyWithBatch(self):
p = [[0.1, 0.7], [0.2, 0.6]]
dist = bernoulli.Bernoulli(probs=p, validate_args=False)
with self.test_session():
self.assertAllClose(dist.entropy().eval(), [[entropy(0.1), entropy(0.7)],
[entropy(0.2), entropy(0.6)]])
self.assertAllClose(
self.evaluate(dist.entropy()),
[[entropy(0.1), entropy(0.7)], [entropy(0.2),
entropy(0.6)]])

@test_util.run_in_graph_and_eager_modes()
def testSampleN(self):
with self.test_session():
p = [0.2, 0.6]
@@ -242,7 +259,7 @@ class BernoulliTest(test.TestCase):
samples = dist.sample(n)
samples.set_shape([n, 2])
self.assertEqual(samples.dtype, dtypes.int32)
sample_values = samples.eval()
sample_values = self.evaluate(samples)
self.assertTrue(np.all(sample_values >= 0))
self.assertTrue(np.all(sample_values <= 1))
# Note that the standard error for the sample mean is ~ sqrt(p * (1 - p) /
@@ -262,51 +279,54 @@ class BernoulliTest(test.TestCase):
n = 1000
seed = 42
self.assertAllEqual(
dist.sample(n, seed).eval(), dist.sample(n, seed).eval())
self.evaluate(dist.sample(n, seed)),
self.evaluate(dist.sample(n, seed)))
n = array_ops.placeholder(dtypes.int32)
sample, sample = sess.run([dist.sample(n, seed), dist.sample(n, seed)],
feed_dict={n: 1000})
self.assertAllEqual(sample, sample)

@test_util.run_in_graph_and_eager_modes()
def testMean(self):
with self.test_session():
p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32)
dist = bernoulli.Bernoulli(probs=p)
self.assertAllEqual(dist.mean().eval(), p)
self.assertAllEqual(self.evaluate(dist.mean()), p)

@test_util.run_in_graph_and_eager_modes()
def testVarianceAndStd(self):
var = lambda p: p * (1. - p)
with self.test_session():
p = [[0.2, 0.7], [0.5, 0.4]]
dist = bernoulli.Bernoulli(probs=p)
self.assertAllClose(
dist.variance().eval(),
self.evaluate(dist.variance()),
np.array(
[[var(0.2), var(0.7)], [var(0.5), var(0.4)]], dtype=np.float32))
self.assertAllClose(
dist.stddev().eval(),
self.evaluate(dist.stddev()),
np.array(
[[np.sqrt(var(0.2)), np.sqrt(var(0.7))],
[np.sqrt(var(0.5)), np.sqrt(var(0.4))]],
dtype=np.float32))

@test_util.run_in_graph_and_eager_modes()
def testBernoulliBernoulliKL(self):
with self.test_session() as sess:
batch_size = 6
a_p = np.array([0.5] * batch_size, dtype=np.float32)
b_p = np.array([0.4] * batch_size, dtype=np.float32)
batch_size = 6
a_p = np.array([0.5] * batch_size, dtype=np.float32)
b_p = np.array([0.4] * batch_size, dtype=np.float32)

a = bernoulli.Bernoulli(probs=a_p)
b = bernoulli.Bernoulli(probs=b_p)
a = bernoulli.Bernoulli(probs=a_p)
b = bernoulli.Bernoulli(probs=b_p)

kl = kullback_leibler.kl_divergence(a, b)
kl_val = sess.run(kl)
kl = kullback_leibler.kl_divergence(a, b)
kl_val = self.evaluate(kl)

kl_expected = (a_p * np.log(a_p / b_p) + (1. - a_p) * np.log(
(1. - a_p) / (1. - b_p)))
kl_expected = (a_p * np.log(a_p / b_p) + (1. - a_p) * np.log(
(1. - a_p) / (1. - b_p)))

self.assertEqual(kl.get_shape(), (batch_size,))
self.assertAllClose(kl_val, kl_expected)
self.assertEqual(kl.get_shape(), (batch_size,))
self.assertAllClose(kl_val, kl_expected)


if __name__ == "__main__":

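One detail in the Bernoulli changes above: the deterministic-sampling test keeps an explicit sess.run for the dynamic-shape branch, because placeholders and feed_dict are graph-only constructs with no eager equivalent. A hedged sketch of that graph-only pattern (names follow the test above; dist stands for any distribution):

# Graph-only: array_ops.placeholder and feed_dict require a Session,
# so this branch cannot be expressed with self.evaluate in eager mode.
with self.test_session() as sess:
  n = array_ops.placeholder(dtypes.int32)
  sample1, sample2 = sess.run([dist.sample(n, seed=42), dist.sample(n, seed=42)],
                              feed_dict={n: 1000})
  self.assertAllEqual(sample1, sample2)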
@ -24,6 +24,7 @@ from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops.distributions import beta as beta_lib
|
||||
@ -45,6 +46,7 @@ special = try_import("scipy.special")
|
||||
stats = try_import("scipy.stats")
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class BetaTest(test.TestCase):
|
||||
|
||||
def testSimpleShapes(self):
|
||||
@ -52,8 +54,8 @@ class BetaTest(test.TestCase):
|
||||
a = np.random.rand(3)
|
||||
b = np.random.rand(3)
|
||||
dist = beta_lib.Beta(a, b)
|
||||
self.assertAllEqual([], dist.event_shape_tensor().eval())
|
||||
self.assertAllEqual([3], dist.batch_shape_tensor().eval())
|
||||
self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
|
||||
self.assertAllEqual([3], self.evaluate(dist.batch_shape_tensor()))
|
||||
self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
|
||||
self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape)
|
||||
|
||||
@ -62,8 +64,8 @@ class BetaTest(test.TestCase):
|
||||
a = np.random.rand(3, 2, 2)
|
||||
b = np.random.rand(3, 2, 2)
|
||||
dist = beta_lib.Beta(a, b)
|
||||
self.assertAllEqual([], dist.event_shape_tensor().eval())
|
||||
self.assertAllEqual([3, 2, 2], dist.batch_shape_tensor().eval())
|
||||
self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
|
||||
self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
|
||||
self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
|
||||
self.assertEqual(
|
||||
tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
|
||||
@ -73,8 +75,8 @@ class BetaTest(test.TestCase):
|
||||
a = np.random.rand(3, 2, 2)
|
||||
b = np.random.rand(2, 2)
|
||||
dist = beta_lib.Beta(a, b)
|
||||
self.assertAllEqual([], dist.event_shape_tensor().eval())
|
||||
self.assertAllEqual([3, 2, 2], dist.batch_shape_tensor().eval())
|
||||
self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
|
||||
self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
|
||||
self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
|
||||
self.assertEqual(
|
||||
tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
|
||||
@ -85,7 +87,7 @@ class BetaTest(test.TestCase):
|
||||
with self.test_session():
|
||||
dist = beta_lib.Beta(a, b)
|
||||
self.assertEqual([1, 3], dist.concentration1.get_shape())
|
||||
self.assertAllClose(a, dist.concentration1.eval())
|
||||
self.assertAllClose(a, self.evaluate(dist.concentration1))
|
||||
|
||||
def testBetaProperty(self):
|
||||
a = [[1., 2, 3]]
|
||||
@ -93,24 +95,24 @@ class BetaTest(test.TestCase):
|
||||
with self.test_session():
|
||||
dist = beta_lib.Beta(a, b)
|
||||
self.assertEqual([1, 3], dist.concentration0.get_shape())
|
||||
self.assertAllClose(b, dist.concentration0.eval())
|
||||
self.assertAllClose(b, self.evaluate(dist.concentration0))
|
||||
|
||||
def testPdfXProper(self):
|
||||
a = [[1., 2, 3]]
|
||||
b = [[2., 4, 3]]
|
||||
with self.test_session():
|
||||
dist = beta_lib.Beta(a, b, validate_args=True)
|
||||
dist.prob([.1, .3, .6]).eval()
|
||||
dist.prob([.2, .3, .5]).eval()
|
||||
self.evaluate(dist.prob([.1, .3, .6]))
|
||||
self.evaluate(dist.prob([.2, .3, .5]))
|
||||
# Either condition can trigger.
|
||||
with self.assertRaisesOpError("sample must be positive"):
|
||||
dist.prob([-1., 0.1, 0.5]).eval()
|
||||
self.evaluate(dist.prob([-1., 0.1, 0.5]))
|
||||
with self.assertRaisesOpError("sample must be positive"):
|
||||
dist.prob([0., 0.1, 0.5]).eval()
|
||||
self.evaluate(dist.prob([0., 0.1, 0.5]))
|
||||
with self.assertRaisesOpError("sample must be less than `1`"):
|
||||
dist.prob([.1, .2, 1.2]).eval()
|
||||
self.evaluate(dist.prob([.1, .2, 1.2]))
|
||||
with self.assertRaisesOpError("sample must be less than `1`"):
|
||||
dist.prob([.1, .2, 1.0]).eval()
|
||||
self.evaluate(dist.prob([.1, .2, 1.0]))
|
||||
|
||||
def testPdfTwoBatches(self):
|
||||
with self.test_session():
|
||||
@ -119,7 +121,7 @@ class BetaTest(test.TestCase):
|
||||
x = [.5, .5]
|
||||
dist = beta_lib.Beta(a, b)
|
||||
pdf = dist.prob(x)
|
||||
self.assertAllClose([1., 3. / 2], pdf.eval())
|
||||
self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
|
||||
self.assertEqual((2,), pdf.get_shape())
|
||||
|
||||
def testPdfTwoBatchesNontrivialX(self):
|
||||
@ -129,7 +131,7 @@ class BetaTest(test.TestCase):
|
||||
x = [.3, .7]
|
||||
dist = beta_lib.Beta(a, b)
|
||||
pdf = dist.prob(x)
|
||||
self.assertAllClose([1, 63. / 50], pdf.eval())
|
||||
self.assertAllClose([1, 63. / 50], self.evaluate(pdf))
|
||||
self.assertEqual((2,), pdf.get_shape())
|
||||
|
||||
def testPdfUniformZeroBatch(self):
|
||||
@ -140,7 +142,7 @@ class BetaTest(test.TestCase):
|
||||
x = np.array([.1, .2, .3, .5, .8], dtype=np.float32)
|
||||
dist = beta_lib.Beta(a, b)
|
||||
pdf = dist.prob(x)
|
||||
self.assertAllClose([1.] * 5, pdf.eval())
|
||||
self.assertAllClose([1.] * 5, self.evaluate(pdf))
|
||||
self.assertEqual((5,), pdf.get_shape())
|
||||
|
||||
def testPdfAlphaStretchedInBroadcastWhenSameRank(self):
|
||||
@ -150,7 +152,7 @@ class BetaTest(test.TestCase):
|
||||
x = [[.5, .5], [.3, .7]]
|
||||
dist = beta_lib.Beta(a, b)
|
||||
pdf = dist.prob(x)
|
||||
self.assertAllClose([[1., 3. / 2], [1., 63. / 50]], pdf.eval())
|
||||
self.assertAllClose([[1., 3. / 2], [1., 63. / 50]], self.evaluate(pdf))
|
||||
self.assertEqual((2, 2), pdf.get_shape())
|
||||
|
||||
def testPdfAlphaStretchedInBroadcastWhenLowerRank(self):
|
||||
@ -159,7 +161,7 @@ class BetaTest(test.TestCase):
|
||||
b = [1., 2]
|
||||
x = [[.5, .5], [.2, .8]]
|
||||
pdf = beta_lib.Beta(a, b).prob(x)
|
||||
self.assertAllClose([[1., 3. / 2], [1., 24. / 25]], pdf.eval())
|
||||
self.assertAllClose([[1., 3. / 2], [1., 24. / 25]], self.evaluate(pdf))
|
||||
self.assertEqual((2, 2), pdf.get_shape())
|
||||
|
||||
def testPdfXStretchedInBroadcastWhenSameRank(self):
|
||||
@ -168,7 +170,7 @@ class BetaTest(test.TestCase):
|
||||
b = [[1., 2], [2., 3]]
|
||||
x = [[.5, .5]]
|
||||
pdf = beta_lib.Beta(a, b).prob(x)
|
||||
self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], pdf.eval())
|
||||
self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf))
|
||||
self.assertEqual((2, 2), pdf.get_shape())
|
||||
|
||||
def testPdfXStretchedInBroadcastWhenLowerRank(self):
|
||||
@ -177,7 +179,7 @@ class BetaTest(test.TestCase):
|
||||
b = [[1., 2], [2., 3]]
|
||||
x = [.5, .5]
|
||||
pdf = beta_lib.Beta(a, b).prob(x)
|
||||
self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], pdf.eval())
|
||||
self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf))
|
||||
self.assertEqual((2, 2), pdf.get_shape())
|
||||
|
||||
def testBetaMean(self):
|
||||
@ -189,7 +191,7 @@ class BetaTest(test.TestCase):
|
||||
if not stats:
|
||||
return
|
||||
expected_mean = stats.beta.mean(a, b)
|
||||
self.assertAllClose(expected_mean, dist.mean().eval())
|
||||
self.assertAllClose(expected_mean, self.evaluate(dist.mean()))
|
||||
|
||||
def testBetaVariance(self):
|
||||
with session.Session():
|
||||
@ -200,7 +202,7 @@ class BetaTest(test.TestCase):
|
||||
if not stats:
|
||||
return
|
||||
expected_variance = stats.beta.var(a, b)
|
||||
self.assertAllClose(expected_variance, dist.variance().eval())
|
||||
self.assertAllClose(expected_variance, self.evaluate(dist.variance()))
|
||||
|
||||
def testBetaMode(self):
|
||||
with session.Session():
|
||||
@ -209,7 +211,7 @@ class BetaTest(test.TestCase):
|
||||
expected_mode = (a - 1) / (a + b - 2)
|
||||
dist = beta_lib.Beta(a, b)
|
||||
self.assertEqual(dist.mode().get_shape(), (3,))
|
||||
self.assertAllClose(expected_mode, dist.mode().eval())
|
||||
self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
|
||||
|
||||
def testBetaModeInvalid(self):
|
||||
with session.Session():
|
||||
@ -217,13 +219,13 @@ class BetaTest(test.TestCase):
|
||||
b = np.array([2., 4, 1.2])
|
||||
dist = beta_lib.Beta(a, b, allow_nan_stats=False)
|
||||
with self.assertRaisesOpError("Condition x < y.*"):
|
||||
dist.mode().eval()
|
||||
self.evaluate(dist.mode())
|
||||
|
||||
a = np.array([2., 2, 3])
|
||||
b = np.array([1., 4, 1.2])
|
||||
dist = beta_lib.Beta(a, b, allow_nan_stats=False)
|
||||
with self.assertRaisesOpError("Condition x < y.*"):
|
||||
dist.mode().eval()
|
||||
self.evaluate(dist.mode())
|
||||
|
||||
def testBetaModeEnableAllowNanStats(self):
|
||||
with session.Session():
|
||||
@ -234,7 +236,7 @@ class BetaTest(test.TestCase):
|
||||
expected_mode = (a - 1) / (a + b - 2)
|
||||
expected_mode[0] = np.nan
|
||||
self.assertEqual((3,), dist.mode().get_shape())
|
||||
self.assertAllClose(expected_mode, dist.mode().eval())
|
||||
self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
|
||||
|
||||
a = np.array([2., 2, 3])
|
||||
b = np.array([1., 4, 1.2])
|
||||
@ -243,7 +245,7 @@ class BetaTest(test.TestCase):
|
||||
expected_mode = (a - 1) / (a + b - 2)
|
||||
expected_mode[0] = np.nan
|
||||
self.assertEqual((3,), dist.mode().get_shape())
|
||||
self.assertAllClose(expected_mode, dist.mode().eval())
|
||||
self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
|
||||
|
||||
def testBetaEntropy(self):
|
||||
with session.Session():
|
||||
@ -254,7 +256,7 @@ class BetaTest(test.TestCase):
|
||||
if not stats:
|
||||
return
|
||||
expected_entropy = stats.beta.entropy(a, b)
|
||||
self.assertAllClose(expected_entropy, dist.entropy().eval())
|
||||
self.assertAllClose(expected_entropy, self.evaluate(dist.entropy()))
|
||||
|
||||
def testBetaSample(self):
|
||||
with self.test_session():
|
||||
@ -263,7 +265,7 @@ class BetaTest(test.TestCase):
|
||||
beta = beta_lib.Beta(a, b)
|
||||
n = constant_op.constant(100000)
|
||||
samples = beta.sample(n)
|
||||
sample_values = samples.eval()
|
||||
sample_values = self.evaluate(samples)
|
||||
self.assertEqual(sample_values.shape, (100000,))
|
||||
self.assertFalse(np.any(sample_values < 0.0))
|
||||
if not stats:
|
||||
@ -291,13 +293,13 @@ class BetaTest(test.TestCase):
|
||||
beta1 = beta_lib.Beta(concentration1=a_val,
|
||||
concentration0=b_val,
|
||||
name="beta1")
|
||||
samples1 = beta1.sample(n_val, seed=123456).eval()
|
||||
samples1 = self.evaluate(beta1.sample(n_val, seed=123456))
|
||||
|
||||
random_seed.set_random_seed(654321)
|
||||
beta2 = beta_lib.Beta(concentration1=a_val,
|
||||
concentration0=b_val,
|
||||
name="beta2")
|
||||
samples2 = beta2.sample(n_val, seed=123456).eval()
|
||||
samples2 = self.evaluate(beta2.sample(n_val, seed=123456))
|
||||
|
||||
self.assertAllClose(samples1, samples2)
|
||||
|
||||
@ -308,7 +310,7 @@ class BetaTest(test.TestCase):
|
||||
beta = beta_lib.Beta(a, b)
|
||||
n = constant_op.constant(100000)
|
||||
samples = beta.sample(n)
|
||||
sample_values = samples.eval()
|
||||
sample_values = self.evaluate(samples)
|
||||
self.assertEqual(sample_values.shape, (100000, 3, 2, 2))
|
||||
self.assertFalse(np.any(sample_values < 0.0))
|
||||
if not stats:
|
||||
@ -325,7 +327,7 @@ class BetaTest(test.TestCase):
|
||||
a = 10. * np.random.random(shape).astype(dt)
|
||||
b = 10. * np.random.random(shape).astype(dt)
|
||||
x = np.random.random(shape).astype(dt)
|
||||
actual = beta_lib.Beta(a, b).cdf(x).eval()
|
||||
actual = self.evaluate(beta_lib.Beta(a, b).cdf(x))
|
||||
self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
|
||||
self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
|
||||
if not stats:
|
||||
@ -339,7 +341,7 @@ class BetaTest(test.TestCase):
|
||||
a = 10. * np.random.random(shape).astype(dt)
|
||||
b = 10. * np.random.random(shape).astype(dt)
|
||||
x = np.random.random(shape).astype(dt)
|
||||
actual = math_ops.exp(beta_lib.Beta(a, b).log_cdf(x)).eval()
|
||||
actual = self.evaluate(math_ops.exp(beta_lib.Beta(a, b).log_cdf(x)))
|
||||
self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
|
||||
self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
|
||||
if not stats:
|
||||
@ -350,46 +352,47 @@ class BetaTest(test.TestCase):
|
||||
with self.test_session():
|
||||
a, b = -4.2, -9.1
|
||||
dist = beta_lib.BetaWithSoftplusConcentration(a, b)
|
||||
self.assertAllClose(nn_ops.softplus(a).eval(), dist.concentration1.eval())
|
||||
self.assertAllClose(nn_ops.softplus(b).eval(), dist.concentration0.eval())
|
||||
self.assertAllClose(
|
||||
self.evaluate(nn_ops.softplus(a)), self.evaluate(dist.concentration1))
|
||||
self.assertAllClose(
|
||||
self.evaluate(nn_ops.softplus(b)), self.evaluate(dist.concentration0))
|
||||
|
||||
def testBetaBetaKL(self):
|
||||
with self.test_session() as sess:
|
||||
for shape in [(10,), (4, 5)]:
|
||||
a1 = 6.0 * np.random.random(size=shape) + 1e-4
|
||||
b1 = 6.0 * np.random.random(size=shape) + 1e-4
|
||||
a2 = 6.0 * np.random.random(size=shape) + 1e-4
|
||||
b2 = 6.0 * np.random.random(size=shape) + 1e-4
|
||||
# Take inverse softplus of values to test BetaWithSoftplusConcentration
|
||||
a1_sp = np.log(np.exp(a1) - 1.0)
|
||||
b1_sp = np.log(np.exp(b1) - 1.0)
|
||||
a2_sp = np.log(np.exp(a2) - 1.0)
|
||||
b2_sp = np.log(np.exp(b2) - 1.0)
|
||||
for shape in [(10,), (4, 5)]:
|
||||
a1 = 6.0 * np.random.random(size=shape) + 1e-4
|
||||
b1 = 6.0 * np.random.random(size=shape) + 1e-4
|
||||
a2 = 6.0 * np.random.random(size=shape) + 1e-4
|
||||
b2 = 6.0 * np.random.random(size=shape) + 1e-4
|
||||
# Take inverse softplus of values to test BetaWithSoftplusConcentration
|
||||
a1_sp = np.log(np.exp(a1) - 1.0)
|
||||
b1_sp = np.log(np.exp(b1) - 1.0)
|
||||
a2_sp = np.log(np.exp(a2) - 1.0)
|
||||
b2_sp = np.log(np.exp(b2) - 1.0)
|
||||
|
||||
d1 = beta_lib.Beta(concentration1=a1, concentration0=b1)
|
||||
d2 = beta_lib.Beta(concentration1=a2, concentration0=b2)
|
||||
d1_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a1_sp,
|
||||
concentration0=b1_sp)
|
||||
d2_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a2_sp,
|
||||
concentration0=b2_sp)
|
||||
d1 = beta_lib.Beta(concentration1=a1, concentration0=b1)
|
||||
d2 = beta_lib.Beta(concentration1=a2, concentration0=b2)
|
||||
d1_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a1_sp,
|
||||
concentration0=b1_sp)
|
||||
d2_sp = beta_lib.BetaWithSoftplusConcentration(concentration1=a2_sp,
|
||||
concentration0=b2_sp)
|
||||
|
||||
if not special:
|
||||
return
|
||||
kl_expected = (special.betaln(a2, b2) - special.betaln(a1, b1) +
|
||||
(a1 - a2) * special.digamma(a1) +
|
||||
(b1 - b2) * special.digamma(b1) +
|
||||
(a2 - a1 + b2 - b1) * special.digamma(a1 + b1))
|
||||
if not special:
|
||||
return
|
||||
kl_expected = (special.betaln(a2, b2) - special.betaln(a1, b1) +
|
||||
(a1 - a2) * special.digamma(a1) +
|
||||
(b1 - b2) * special.digamma(b1) +
|
||||
(a2 - a1 + b2 - b1) * special.digamma(a1 + b1))
|
||||
|
||||
for dist1 in [d1, d1_sp]:
|
||||
for dist2 in [d2, d2_sp]:
|
||||
kl = kullback_leibler.kl_divergence(dist1, dist2)
|
||||
kl_val = sess.run(kl)
|
||||
self.assertEqual(kl.get_shape(), shape)
|
||||
self.assertAllClose(kl_val, kl_expected)
|
||||
for dist1 in [d1, d1_sp]:
|
||||
for dist2 in [d2, d2_sp]:
|
||||
kl = kullback_leibler.kl_divergence(dist1, dist2)
|
||||
kl_val = self.evaluate(kl)
|
||||
self.assertEqual(kl.get_shape(), shape)
|
||||
self.assertAllClose(kl_val, kl_expected)
|
||||
|
||||
# Make sure KL(d1||d1) is 0
|
||||
kl_same = sess.run(kullback_leibler.kl_divergence(d1, d1))
|
||||
self.assertAllClose(kl_same, np.zeros_like(kl_expected))
|
||||
# Make sure KL(d1||d1) is 0
|
||||
kl_same = self.evaluate(kullback_leibler.kl_divergence(d1, d1))
|
||||
self.assertAllClose(kl_same, np.zeros_like(kl_expected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -24,12 +24,14 @@ import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.distributions import bijector
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class BaseBijectorTest(test.TestCase):
|
||||
"""Tests properties of the Bijector base-class."""
|
||||
|
||||
@ -47,42 +49,38 @@ class BaseBijectorTest(test.TestCase):
|
||||
def __init__(self):
|
||||
super(_BareBonesBijector, self).__init__(forward_min_event_ndims=0)
|
||||
|
||||
with self.test_session() as sess:
|
||||
bij = _BareBonesBijector()
|
||||
self.assertEqual([], bij.graph_parents)
|
||||
self.assertEqual(False, bij.is_constant_jacobian)
|
||||
self.assertEqual(False, bij.validate_args)
|
||||
self.assertEqual(None, bij.dtype)
|
||||
self.assertEqual("bare_bones_bijector", bij.name)
|
||||
bij = _BareBonesBijector()
|
||||
self.assertEqual([], bij.graph_parents)
|
||||
self.assertEqual(False, bij.is_constant_jacobian)
|
||||
self.assertEqual(False, bij.validate_args)
|
||||
self.assertEqual(None, bij.dtype)
|
||||
self.assertEqual("bare_bones_bijector", bij.name)
|
||||
|
||||
for shape in [[], [1, 2], [1, 2, 3]]:
|
||||
[
|
||||
forward_event_shape_,
|
||||
inverse_event_shape_,
|
||||
] = sess.run([
|
||||
bij.inverse_event_shape_tensor(shape),
|
||||
bij.forward_event_shape_tensor(shape),
|
||||
])
|
||||
self.assertAllEqual(shape, forward_event_shape_)
|
||||
self.assertAllEqual(shape, bij.forward_event_shape(shape))
|
||||
self.assertAllEqual(shape, inverse_event_shape_)
|
||||
self.assertAllEqual(shape, bij.inverse_event_shape(shape))
|
||||
for shape in [[], [1, 2], [1, 2, 3]]:
|
||||
forward_event_shape_ = self.evaluate(
|
||||
bij.inverse_event_shape_tensor(shape))
|
||||
inverse_event_shape_ = self.evaluate(
|
||||
bij.forward_event_shape_tensor(shape))
|
||||
self.assertAllEqual(shape, forward_event_shape_)
|
||||
self.assertAllEqual(shape, bij.forward_event_shape(shape))
|
||||
self.assertAllEqual(shape, inverse_event_shape_)
|
||||
self.assertAllEqual(shape, bij.inverse_event_shape(shape))
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
NotImplementedError, "inverse not implemented"):
|
||||
bij.inverse(0)
|
||||
with self.assertRaisesRegexp(
|
||||
NotImplementedError, "inverse not implemented"):
|
||||
bij.inverse(0)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
NotImplementedError, "forward not implemented"):
|
||||
bij.forward(0)
|
||||
with self.assertRaisesRegexp(
|
||||
NotImplementedError, "forward not implemented"):
|
||||
bij.forward(0)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
NotImplementedError, "inverse_log_det_jacobian not implemented"):
|
||||
bij.inverse_log_det_jacobian(0, event_ndims=0)
|
||||
with self.assertRaisesRegexp(
|
||||
NotImplementedError, "inverse_log_det_jacobian not implemented"):
|
||||
bij.inverse_log_det_jacobian(0, event_ndims=0)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
NotImplementedError, "forward_log_det_jacobian not implemented"):
|
||||
bij.forward_log_det_jacobian(0, event_ndims=0)
|
||||
with self.assertRaisesRegexp(
|
||||
NotImplementedError, "forward_log_det_jacobian not implemented"):
|
||||
bij.forward_log_det_jacobian(0, event_ndims=0)
|
||||
|
||||
|
||||
class IntentionallyMissingError(Exception):
|
||||
|
@ -22,6 +22,7 @@ import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.distributions import dirichlet as dirichlet_lib
|
||||
@ -41,14 +42,15 @@ def try_import(name): # pylint: disable=invalid-name
|
||||
stats = try_import("scipy.stats")
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class DirichletTest(test.TestCase):
|
||||
|
||||
def testSimpleShapes(self):
|
||||
with self.test_session():
|
||||
alpha = np.random.rand(3)
|
||||
dist = dirichlet_lib.Dirichlet(alpha)
|
||||
self.assertEqual(3, dist.event_shape_tensor().eval())
|
||||
self.assertAllEqual([], dist.batch_shape_tensor().eval())
|
||||
self.assertEqual(3, self.evaluate(dist.event_shape_tensor()))
|
||||
self.assertAllEqual([], self.evaluate(dist.batch_shape_tensor()))
|
||||
self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape)
|
||||
self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)
|
||||
|
||||
@ -56,8 +58,8 @@ class DirichletTest(test.TestCase):
|
||||
with self.test_session():
|
||||
alpha = np.random.rand(3, 2, 2)
|
||||
dist = dirichlet_lib.Dirichlet(alpha)
|
||||
self.assertEqual(2, dist.event_shape_tensor().eval())
|
||||
self.assertAllEqual([3, 2], dist.batch_shape_tensor().eval())
|
||||
self.assertEqual(2, self.evaluate(dist.event_shape_tensor()))
|
||||
self.assertAllEqual([3, 2], self.evaluate(dist.batch_shape_tensor()))
|
||||
self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape)
|
||||
self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape)
|
||||
|
||||
@ -66,22 +68,22 @@ class DirichletTest(test.TestCase):
|
||||
with self.test_session():
|
||||
dist = dirichlet_lib.Dirichlet(alpha)
|
||||
self.assertEqual([1, 3], dist.concentration.get_shape())
|
||||
self.assertAllClose(alpha, dist.concentration.eval())
|
||||
self.assertAllClose(alpha, self.evaluate(dist.concentration))
|
||||
|
||||
def testPdfXProper(self):
|
||||
alpha = [[1., 2, 3]]
|
||||
with self.test_session():
|
||||
dist = dirichlet_lib.Dirichlet(alpha, validate_args=True)
|
||||
dist.prob([.1, .3, .6]).eval()
|
||||
dist.prob([.2, .3, .5]).eval()
|
||||
self.evaluate(dist.prob([.1, .3, .6]))
|
||||
self.evaluate(dist.prob([.2, .3, .5]))
|
||||
# Either condition can trigger.
|
||||
with self.assertRaisesOpError("samples must be positive"):
|
||||
dist.prob([-1., 1.5, 0.5]).eval()
|
||||
self.evaluate(dist.prob([-1., 1.5, 0.5]))
|
||||
with self.assertRaisesOpError("samples must be positive"):
|
||||
dist.prob([0., .1, .9]).eval()
|
||||
self.evaluate(dist.prob([0., .1, .9]))
|
||||
with self.assertRaisesOpError(
|
||||
"sample last-dimension must sum to `1`"):
|
||||
dist.prob([.1, .2, .8]).eval()
|
||||
self.evaluate(dist.prob([.1, .2, .8]))
|
||||
|
||||
def testPdfZeroBatches(self):
|
||||
with self.test_session():
|
||||
@ -89,7 +91,7 @@ class DirichletTest(test.TestCase):
|
||||
x = [.5, .5]
|
||||
dist = dirichlet_lib.Dirichlet(alpha)
|
||||
pdf = dist.prob(x)
|
||||
self.assertAllClose(1., pdf.eval())
|
||||
self.assertAllClose(1., self.evaluate(pdf))
|
||||
self.assertEqual((), pdf.get_shape())
|
||||
|
||||
def testPdfZeroBatchesNontrivialX(self):
|
||||
@ -98,7 +100,7 @@ class DirichletTest(test.TestCase):
|
||||
x = [.3, .7]
|
||||
dist = dirichlet_lib.Dirichlet(alpha)
|
||||
pdf = dist.prob(x)
|
||||
self.assertAllClose(7. / 5, pdf.eval())
|
||||
self.assertAllClose(7. / 5, self.evaluate(pdf))
|
||||
self.assertEqual((), pdf.get_shape())
|
||||
|
||||
def testPdfUniformZeroBatches(self):
|
||||
@ -108,7 +110,7 @@ class DirichletTest(test.TestCase):
|
||||
x = [[.2, .5, .3], [.3, .4, .3]]
|
||||
dist = dirichlet_lib.Dirichlet(alpha)
|
||||
pdf = dist.prob(x)
|
||||
self.assertAllClose([2., 2.], pdf.eval())
|
||||
self.assertAllClose([2., 2.], self.evaluate(pdf))
|
||||
self.assertEqual((2), pdf.get_shape())
|
||||
|
||||
def testPdfAlphaStretchedInBroadcastWhenSameRank(self):
|
||||
@ -117,7 +119,7 @@ class DirichletTest(test.TestCase):
|
||||
x = [[.5, .5], [.3, .7]]
|
||||
dist = dirichlet_lib.Dirichlet(alpha)
|
||||
pdf = dist.prob(x)
|
||||
self.assertAllClose([1., 7. / 5], pdf.eval())
|
||||
self.assertAllClose([1., 7. / 5], self.evaluate(pdf))
|
||||
self.assertEqual((2), pdf.get_shape())
|
||||
|
||||
def testPdfAlphaStretchedInBroadcastWhenLowerRank(self):
|
||||
@ -125,7 +127,7 @@ class DirichletTest(test.TestCase):
|
||||
alpha = [1., 2]
|
||||
x = [[.5, .5], [.2, .8]]
|
||||
pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
|
||||
self.assertAllClose([1., 8. / 5], pdf.eval())
|
||||
self.assertAllClose([1., 8. / 5], self.evaluate(pdf))
|
||||
self.assertEqual((2), pdf.get_shape())
|
||||
|
||||
def testPdfXStretchedInBroadcastWhenSameRank(self):
|
||||
@ -133,7 +135,7 @@ class DirichletTest(test.TestCase):
|
||||
alpha = [[1., 2], [2., 3]]
|
||||
x = [[.5, .5]]
|
||||
pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
|
||||
self.assertAllClose([1., 3. / 2], pdf.eval())
|
||||
self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
|
||||
self.assertEqual((2), pdf.get_shape())
|
||||
|
||||
def testPdfXStretchedInBroadcastWhenLowerRank(self):
|
||||
@ -141,7 +143,7 @@ class DirichletTest(test.TestCase):
|
||||
alpha = [[1., 2], [2., 3]]
|
||||
x = [.5, .5]
|
||||
pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
|
||||
self.assertAllClose([1., 3. / 2], pdf.eval())
|
||||
self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
|
||||
self.assertEqual((2), pdf.get_shape())
|
||||
|
||||
def testMean(self):
|
||||
@ -152,43 +154,44 @@ class DirichletTest(test.TestCase):
|
||||
if not stats:
|
||||
return
|
||||
expected_mean = stats.dirichlet.mean(alpha)
|
||||
self.assertAllClose(dirichlet.mean().eval(), expected_mean)
|
||||
self.assertAllClose(self.evaluate(dirichlet.mean()), expected_mean)
|
||||
|
||||
def testCovarianceFromSampling(self):
|
||||
alpha = np.array([[1., 2, 3],
|
||||
[2.5, 4, 0.01]], dtype=np.float32)
|
||||
with self.test_session() as sess:
|
||||
dist = dirichlet_lib.Dirichlet(alpha) # batch_shape=[2], event_shape=[3]
|
||||
x = dist.sample(int(250e3), seed=1)
|
||||
sample_mean = math_ops.reduce_mean(x, 0)
|
||||
x_centered = x - sample_mean[None, ...]
|
||||
sample_cov = math_ops.reduce_mean(math_ops.matmul(
|
||||
x_centered[..., None], x_centered[..., None, :]), 0)
|
||||
sample_var = array_ops.matrix_diag_part(sample_cov)
|
||||
sample_stddev = math_ops.sqrt(sample_var)
|
||||
[
|
||||
sample_mean_,
|
||||
sample_cov_,
|
||||
sample_var_,
|
||||
sample_stddev_,
|
||||
analytic_mean,
|
||||
analytic_cov,
|
||||
analytic_var,
|
||||
analytic_stddev,
|
||||
] = sess.run([
|
||||
sample_mean,
|
||||
sample_cov,
|
||||
sample_var,
|
||||
sample_stddev,
|
||||
dist.mean(),
|
||||
dist.covariance(),
|
||||
dist.variance(),
|
||||
dist.stddev(),
|
||||
])
|
||||
self.assertAllClose(sample_mean_, analytic_mean, atol=0., rtol=0.04)
|
||||
self.assertAllClose(sample_cov_, analytic_cov, atol=0., rtol=0.06)
|
||||
self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.03)
|
||||
self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.02)
|
||||
dist = dirichlet_lib.Dirichlet(alpha) # batch_shape=[2], event_shape=[3]
|
||||
x = dist.sample(int(250e3), seed=1)
|
||||
sample_mean = math_ops.reduce_mean(x, 0)
|
||||
x_centered = x - sample_mean[None, ...]
|
||||
sample_cov = math_ops.reduce_mean(math_ops.matmul(
|
||||
x_centered[..., None], x_centered[..., None, :]), 0)
|
||||
sample_var = array_ops.matrix_diag_part(sample_cov)
|
||||
sample_stddev = math_ops.sqrt(sample_var)
|
||||
|
||||
[
|
||||
sample_mean_,
|
||||
sample_cov_,
|
||||
sample_var_,
|
||||
sample_stddev_,
|
||||
analytic_mean,
|
||||
analytic_cov,
|
||||
analytic_var,
|
||||
analytic_stddev,
|
||||
] = self.evaluate([
|
||||
sample_mean,
|
||||
sample_cov,
|
||||
sample_var,
|
||||
sample_stddev,
|
||||
dist.mean(),
|
||||
dist.covariance(),
|
||||
dist.variance(),
|
||||
dist.stddev(),
|
||||
])
|
||||
|
||||
self.assertAllClose(sample_mean_, analytic_mean, atol=0., rtol=0.04)
|
||||
self.assertAllClose(sample_cov_, analytic_cov, atol=0., rtol=0.06)
|
||||
self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.03)
|
||||
self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.02)
|
||||
|
||||
def testVariance(self):
|
||||
with self.test_session():
|
||||
@ -201,7 +204,8 @@ class DirichletTest(test.TestCase):
|
||||
expected_covariance = np.diag(stats.dirichlet.var(alpha))
|
||||
expected_covariance += [[0., -2, -3], [-2, 0, -6],
|
||||
[-3, -6, 0]] / denominator
|
||||
self.assertAllClose(dirichlet.covariance().eval(), expected_covariance)
|
||||
self.assertAllClose(
|
||||
self.evaluate(dirichlet.covariance()), expected_covariance)
|
||||
|
||||
def testMode(self):
|
||||
with self.test_session():
|
||||
@ -209,7 +213,7 @@ class DirichletTest(test.TestCase):
|
||||
expected_mode = (alpha - 1) / (np.sum(alpha) - 3)
|
||||
dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
|
||||
self.assertEqual(dirichlet.mode().get_shape(), [3])
|
||||
self.assertAllClose(dirichlet.mode().eval(), expected_mode)
|
||||
self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)
|
||||
|
||||
def testModeInvalid(self):
|
||||
with self.test_session():
|
||||
@ -217,7 +221,7 @@ class DirichletTest(test.TestCase):
|
||||
dirichlet = dirichlet_lib.Dirichlet(concentration=alpha,
|
||||
allow_nan_stats=False)
|
||||
with self.assertRaisesOpError("Condition x < y.*"):
|
||||
dirichlet.mode().eval()
|
||||
self.evaluate(dirichlet.mode())
|
||||
|
||||
def testModeEnableAllowNanStats(self):
|
||||
with self.test_session():
|
||||
@ -227,7 +231,7 @@ class DirichletTest(test.TestCase):
|
||||
expected_mode = np.zeros_like(alpha) + np.nan
|
||||
|
||||
self.assertEqual(dirichlet.mode().get_shape(), [3])
|
||||
self.assertAllClose(dirichlet.mode().eval(), expected_mode)
|
||||
self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)
|
||||
|
||||
def testEntropy(self):
|
||||
with self.test_session():
|
||||
@ -237,7 +241,7 @@ class DirichletTest(test.TestCase):
|
||||
if not stats:
|
||||
return
|
||||
expected_entropy = stats.dirichlet.entropy(alpha)
|
||||
self.assertAllClose(dirichlet.entropy().eval(), expected_entropy)
|
||||
self.assertAllClose(self.evaluate(dirichlet.entropy()), expected_entropy)
|
||||
|
||||
def testSample(self):
|
||||
with self.test_session():
|
||||
@ -245,7 +249,7 @@ class DirichletTest(test.TestCase):
|
||||
dirichlet = dirichlet_lib.Dirichlet(alpha)
|
||||
n = constant_op.constant(100000)
|
||||
samples = dirichlet.sample(n)
|
||||
sample_values = samples.eval()
|
||||
sample_values = self.evaluate(samples)
|
||||
self.assertEqual(sample_values.shape, (100000, 2))
|
||||
self.assertTrue(np.all(sample_values > 0.0))
|
||||
if not stats:
|
||||
|
@ -24,6 +24,7 @@ import numpy as np
|
||||
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops.distributions import exponential as exponential_lib
|
||||
from tensorflow.python.platform import test
|
||||
@ -42,6 +43,7 @@ def try_import(name): # pylint: disable=invalid-name
|
||||
stats = try_import("scipy.stats")
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class ExponentialTest(test.TestCase):
|
||||
|
||||
def testExponentialLogPDF(self):
|
||||
@ -61,8 +63,8 @@ class ExponentialTest(test.TestCase):
|
||||
if not stats:
|
||||
return
|
||||
expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v)
|
||||
self.assertAllClose(log_pdf.eval(), expected_log_pdf)
|
||||
self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf))
|
||||
self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
|
||||
self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
|
||||
|
||||
def testExponentialCDF(self):
|
||||
with session.Session():
|
||||
@ -79,7 +81,7 @@ class ExponentialTest(test.TestCase):
|
||||
if not stats:
|
||||
return
|
||||
expected_cdf = stats.expon.cdf(x, scale=1 / lam_v)
|
||||
self.assertAllClose(cdf.eval(), expected_cdf)
|
||||
self.assertAllClose(self.evaluate(cdf), expected_cdf)
|
||||
|
||||
def testExponentialMean(self):
|
||||
with session.Session():
|
||||
@ -89,7 +91,7 @@ class ExponentialTest(test.TestCase):
|
||||
if not stats:
|
||||
return
|
||||
expected_mean = stats.expon.mean(scale=1 / lam_v)
|
||||
self.assertAllClose(exponential.mean().eval(), expected_mean)
|
||||
self.assertAllClose(self.evaluate(exponential.mean()), expected_mean)
|
||||
|
||||
def testExponentialVariance(self):
|
||||
with session.Session():
|
||||
@ -99,7 +101,8 @@ class ExponentialTest(test.TestCase):
|
||||
if not stats:
|
||||
return
|
||||
expected_variance = stats.expon.var(scale=1 / lam_v)
|
||||
self.assertAllClose(exponential.variance().eval(), expected_variance)
|
||||
self.assertAllClose(
|
||||
self.evaluate(exponential.variance()), expected_variance)
|
||||
|
||||
def testExponentialEntropy(self):
|
||||
with session.Session():
|
||||
@ -109,7 +112,8 @@ class ExponentialTest(test.TestCase):
|
||||
if not stats:
|
||||
return
|
||||
expected_entropy = stats.expon.entropy(scale=1 / lam_v)
|
||||
self.assertAllClose(exponential.entropy().eval(), expected_entropy)
|
||||
self.assertAllClose(
|
||||
self.evaluate(exponential.entropy()), expected_entropy)
|
||||
|
||||
def testExponentialSample(self):
|
||||
with self.test_session():
|
||||
@ -119,7 +123,7 @@ class ExponentialTest(test.TestCase):
|
||||
exponential = exponential_lib.Exponential(rate=lam)
|
||||
|
||||
samples = exponential.sample(n, seed=137)
|
||||
sample_values = samples.eval()
|
||||
sample_values = self.evaluate(samples)
|
||||
self.assertEqual(sample_values.shape, (100000, 2))
|
||||
self.assertFalse(np.any(sample_values < 0.0))
|
||||
if not stats:
|
||||
@ -142,7 +146,7 @@ class ExponentialTest(test.TestCase):
|
||||
samples = exponential.sample(n, seed=138)
|
||||
self.assertEqual(samples.get_shape(), (n, batch_size, 2))
|
||||
|
||||
sample_values = samples.eval()
|
||||
sample_values = self.evaluate(samples)
|
||||
|
||||
self.assertFalse(np.any(sample_values < 0.0))
|
||||
if not stats:
|
||||
@ -163,8 +167,8 @@ class ExponentialTest(test.TestCase):
|
||||
with self.test_session():
|
||||
lam = [-2.2, -3.4]
|
||||
exponential = exponential_lib.ExponentialWithSoftplusRate(rate=lam)
|
||||
self.assertAllClose(nn_ops.softplus(lam).eval(),
|
||||
exponential.rate.eval())
|
||||
self.assertAllClose(
|
||||
self.evaluate(nn_ops.softplus(lam)), self.evaluate(exponential.rate))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -24,6 +24,7 @@ import numpy as np
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops.distributions import laplace as laplace_lib
|
||||
from tensorflow.python.platform import test
|
||||
@ -43,6 +44,7 @@ def try_import(name): # pylint: disable=invalid-name
|
||||
stats = try_import("scipy.stats")
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class LaplaceTest(test.TestCase):
|
||||
|
||||
def testLaplaceShape(self):
|
||||
@ -51,9 +53,9 @@ class LaplaceTest(test.TestCase):
|
||||
scale = constant_op.constant(11.0)
|
||||
laplace = laplace_lib.Laplace(loc=loc, scale=scale)
|
||||
|
||||
self.assertEqual(laplace.batch_shape_tensor().eval(), (5,))
|
||||
self.assertEqual(self.evaluate(laplace.batch_shape_tensor()), (5,))
|
||||
self.assertEqual(laplace.batch_shape, tensor_shape.TensorShape([5]))
|
||||
self.assertAllEqual(laplace.event_shape_tensor().eval(), [])
|
||||
self.assertAllEqual(self.evaluate(laplace.event_shape_tensor()), [])
|
||||
self.assertEqual(laplace.event_shape, tensor_shape.TensorShape([]))
|
||||
|
||||
def testLaplaceLogPDF(self):
|
||||
@ -70,11 +72,11 @@ class LaplaceTest(test.TestCase):
|
||||
if not stats:
|
||||
return
|
||||
expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
|
||||
self.assertAllClose(log_pdf.eval(), expected_log_pdf)
|
||||
self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
|
||||
|
||||
pdf = laplace.prob(x)
|
||||
self.assertEqual(pdf.get_shape(), (6,))
|
||||
self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf))
|
||||
self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
|
||||
|
||||
def testLaplaceLogPDFMultidimensional(self):
|
||||
with self.test_session():
|
||||
@ -86,11 +88,11 @@ class LaplaceTest(test.TestCase):
|
||||
x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
|
||||
laplace = laplace_lib.Laplace(loc=loc, scale=scale)
|
||||
log_pdf = laplace.log_prob(x)
|
||||
log_pdf_values = log_pdf.eval()
|
||||
log_pdf_values = self.evaluate(log_pdf)
|
||||
self.assertEqual(log_pdf.get_shape(), (6, 2))
|
||||
|
||||
pdf = laplace.prob(x)
|
||||
pdf_values = pdf.eval()
|
||||
pdf_values = self.evaluate(pdf)
|
||||
self.assertEqual(pdf.get_shape(), (6, 2))
|
||||
if not stats:
|
||||
return
|
||||
@ -108,11 +110,11 @@ class LaplaceTest(test.TestCase):
|
||||
x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
|
||||
laplace = laplace_lib.Laplace(loc=loc, scale=scale)
|
||||
log_pdf = laplace.log_prob(x)
|
||||
log_pdf_values = log_pdf.eval()
|
||||
log_pdf_values = self.evaluate(log_pdf)
|
||||
self.assertEqual(log_pdf.get_shape(), (6, 2))
|
||||
|
||||
pdf = laplace.prob(x)
|
||||
pdf_values = pdf.eval()
|
||||
pdf_values = self.evaluate(pdf)
|
||||
self.assertEqual(pdf.get_shape(), (6, 2))
|
||||
if not stats:
|
||||
return
|
||||
@ -136,7 +138,7 @@ class LaplaceTest(test.TestCase):
|
||||
if not stats:
|
||||
return
|
||||
expected_cdf = stats.laplace.cdf(x, loc_v, scale=scale_v)
|
||||
self.assertAllClose(cdf.eval(), expected_cdf)
|
||||
self.assertAllClose(self.evaluate(cdf), expected_cdf)
|
||||
|
||||
def testLaplaceLogCDF(self):
|
||||
with self.test_session():
|
||||
@ -154,7 +156,7 @@ class LaplaceTest(test.TestCase):
|
||||
if not stats:
|
||||
return
|
||||
expected_cdf = stats.laplace.logcdf(x, loc_v, scale=scale_v)
|
||||
self.assertAllClose(cdf.eval(), expected_cdf)
|
||||
self.assertAllClose(self.evaluate(cdf), expected_cdf)
|
||||
|
||||
def testLaplaceLogSurvivalFunction(self):
|
||||
with self.test_session():
|
||||
@ -172,7 +174,7 @@ class LaplaceTest(test.TestCase):
|
||||
if not stats:
|
||||
return
|
||||
expected_sf = stats.laplace.logsf(x, loc_v, scale=scale_v)
|
||||
self.assertAllClose(sf.eval(), expected_sf)
|
||||
self.assertAllClose(self.evaluate(sf), expected_sf)
|
||||
|
||||
def testLaplaceMean(self):
|
||||
with self.test_session():
|
||||
@ -183,7 +185,7 @@ class LaplaceTest(test.TestCase):
|
||||
if not stats:
|
||||
return
|
||||
expected_means = stats.laplace.mean(loc_v, scale=scale_v)
|
||||
self.assertAllClose(laplace.mean().eval(), expected_means)
|
||||
self.assertAllClose(self.evaluate(laplace.mean()), expected_means)
|
||||
|
||||
def testLaplaceMode(self):
|
||||
with self.test_session():
|
||||
@ -191,7 +193,7 @@ class LaplaceTest(test.TestCase):
|
||||
scale_v = np.array([1.0, 4.0, 5.0])
|
||||
laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
|
||||
self.assertEqual(laplace.mode().get_shape(), (3,))
|
||||
self.assertAllClose(laplace.mode().eval(), loc_v)
|
||||
self.assertAllClose(self.evaluate(laplace.mode()), loc_v)
|
||||
|
||||
def testLaplaceVariance(self):
|
||||
with self.test_session():
|
||||
@ -202,7 +204,7 @@ class LaplaceTest(test.TestCase):
|
||||
if not stats:
|
||||
return
|
||||
expected_variances = stats.laplace.var(loc_v, scale=scale_v)
|
||||
self.assertAllClose(laplace.variance().eval(), expected_variances)
|
||||
self.assertAllClose(self.evaluate(laplace.variance()), expected_variances)
|
||||
|
||||
def testLaplaceStd(self):
|
||||
with self.test_session():
|
||||
@ -213,7 +215,7 @@ class LaplaceTest(test.TestCase):
|
||||
if not stats:
|
||||
return
|
||||
expected_stddev = stats.laplace.std(loc_v, scale=scale_v)
|
||||
self.assertAllClose(laplace.stddev().eval(), expected_stddev)
|
||||
self.assertAllClose(self.evaluate(laplace.stddev()), expected_stddev)
|
||||
|
||||
def testLaplaceEntropy(self):
|
||||
with self.test_session():
|
||||
@ -224,7 +226,7 @@ class LaplaceTest(test.TestCase):
|
||||
if not stats:
|
||||
return
|
||||
expected_entropy = stats.laplace.entropy(loc_v, scale=scale_v)
|
||||
self.assertAllClose(laplace.entropy().eval(), expected_entropy)
|
||||
self.assertAllClose(self.evaluate(laplace.entropy()), expected_entropy)
|
||||
|
||||
def testLaplaceSample(self):
|
||||
with session.Session():
|
||||
@ -235,7 +237,7 @@ class LaplaceTest(test.TestCase):
|
||||
n = 100000
|
||||
laplace = laplace_lib.Laplace(loc=loc, scale=scale)
|
||||
samples = laplace.sample(n, seed=137)
|
||||
sample_values = samples.eval()
|
||||
sample_values = self.evaluate(samples)
|
||||
self.assertEqual(samples.get_shape(), (n,))
|
||||
self.assertEqual(sample_values.shape, (n,))
|
||||
if not stats:
|
||||
@ -260,7 +262,7 @@ class LaplaceTest(test.TestCase):
|
||||
laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
|
||||
n = 10000
|
||||
samples = laplace.sample(n, seed=137)
|
||||
sample_values = samples.eval()
|
||||
sample_values = self.evaluate(samples)
|
||||
self.assertEqual(samples.get_shape(), (n, 10, 100))
|
||||
self.assertEqual(sample_values.shape, (n, 10, 100))
|
||||
zeros = np.zeros_like(loc_v + scale_v) # 10 x 100
|
||||
@ -297,32 +299,31 @@ class LaplaceTest(test.TestCase):
|
||||
return ks < 0.02
|
||||
|
||||
def testLaplacePdfOfSampleMultiDims(self):
|
||||
with session.Session() as sess:
|
||||
laplace = laplace_lib.Laplace(loc=[7., 11.], scale=[[5.], [6.]])
|
||||
num = 50000
|
||||
samples = laplace.sample(num, seed=137)
|
||||
pdfs = laplace.prob(samples)
|
||||
sample_vals, pdf_vals = sess.run([samples, pdfs])
|
||||
self.assertEqual(samples.get_shape(), (num, 2, 2))
|
||||
self.assertEqual(pdfs.get_shape(), (num, 2, 2))
|
||||
self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
|
||||
self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
|
||||
self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
|
||||
self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)
|
||||
if not stats:
|
||||
return
|
||||
self.assertAllClose(
|
||||
stats.laplace.mean(
|
||||
[[7., 11.], [7., 11.]], scale=np.array([[5., 5.], [6., 6.]])),
|
||||
sample_vals.mean(axis=0),
|
||||
rtol=0.05,
|
||||
atol=0.)
|
||||
self.assertAllClose(
|
||||
stats.laplace.var([[7., 11.], [7., 11.]],
|
||||
scale=np.array([[5., 5.], [6., 6.]])),
|
||||
sample_vals.var(axis=0),
|
||||
rtol=0.05,
|
||||
atol=0.)
|
||||
laplace = laplace_lib.Laplace(loc=[7., 11.], scale=[[5.], [6.]])
|
||||
num = 50000
|
||||
samples = laplace.sample(num, seed=137)
|
||||
pdfs = laplace.prob(samples)
|
||||
sample_vals, pdf_vals = self.evaluate([samples, pdfs])
|
||||
self.assertEqual(samples.get_shape(), (num, 2, 2))
|
||||
self.assertEqual(pdfs.get_shape(), (num, 2, 2))
|
||||
self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
|
||||
self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
|
||||
self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
|
||||
self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)
|
||||
if not stats:
|
||||
return
|
||||
self.assertAllClose(
|
||||
stats.laplace.mean(
|
||||
[[7., 11.], [7., 11.]], scale=np.array([[5., 5.], [6., 6.]])),
|
||||
sample_vals.mean(axis=0),
|
||||
rtol=0.05,
|
||||
atol=0.)
|
||||
self.assertAllClose(
|
||||
stats.laplace.var([[7., 11.], [7., 11.]],
|
||||
scale=np.array([[5., 5.], [6., 6.]])),
|
||||
sample_vals.var(axis=0),
|
||||
rtol=0.05,
|
||||
atol=0.)
|
||||
|
||||
def _assertIntegral(self, sample_vals, pdf_vals, err=1e-3):
|
||||
s_p = zip(sample_vals, pdf_vals)
|
||||
@ -338,24 +339,27 @@ class LaplaceTest(test.TestCase):
|
||||
with self.test_session():
|
||||
loc_v = constant_op.constant(0.0, name="loc")
|
||||
scale_v = constant_op.constant(-1.0, name="scale")
|
||||
laplace = laplace_lib.Laplace(
|
||||
loc=loc_v, scale=scale_v, validate_args=True)
|
||||
with self.assertRaisesOpError("scale"):
|
||||
laplace.mean().eval()
|
||||
with self.assertRaisesOpError(
|
||||
"Condition x > 0 did not hold element-wise"):
|
||||
laplace = laplace_lib.Laplace(
|
||||
loc=loc_v, scale=scale_v, validate_args=True)
|
||||
self.evaluate(laplace.mean())
|
||||
loc_v = constant_op.constant(1.0, name="loc")
|
||||
scale_v = constant_op.constant(0.0, name="scale")
|
||||
laplace = laplace_lib.Laplace(
|
||||
loc=loc_v, scale=scale_v, validate_args=True)
|
||||
with self.assertRaisesOpError("scale"):
|
||||
laplace.mean().eval()
|
||||
with self.assertRaisesOpError(
|
||||
"Condition x > 0 did not hold element-wise"):
|
||||
laplace = laplace_lib.Laplace(
|
||||
loc=loc_v, scale=scale_v, validate_args=True)
|
||||
self.evaluate(laplace.mean())
|
||||
|
||||
def testLaplaceWithSoftplusScale(self):
|
||||
with self.test_session():
|
||||
loc_v = constant_op.constant([0.0, 1.0], name="loc")
|
||||
scale_v = constant_op.constant([-1.0, 2.0], name="scale")
|
||||
laplace = laplace_lib.LaplaceWithSoftplusScale(loc=loc_v, scale=scale_v)
|
||||
self.assertAllClose(nn_ops.softplus(scale_v).eval(), laplace.scale.eval())
|
||||
self.assertAllClose(loc_v.eval(), laplace.loc.eval())
|
||||
self.assertAllClose(
|
||||
self.evaluate(nn_ops.softplus(scale_v)), self.evaluate(laplace.scale))
|
||||
self.assertAllClose(self.evaluate(loc_v), self.evaluate(laplace.loc))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -27,6 +27,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn_ops
@ -54,7 +55,7 @@ class NormalTest(test.TestCase):
    self._rng = np.random.RandomState(123)

  def assertAllFinite(self, tensor):
    is_finite = np.isfinite(tensor.eval())
    is_finite = np.isfinite(self.evaluate(tensor))
    all_true = np.ones_like(is_finite, dtype=np.bool)
    self.assertAllEqual(all_true, is_finite)

@ -62,13 +63,13 @@
    with self.test_session():
      param_shapes = normal_lib.Normal.param_shapes(sample_shape)
      mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"]
      self.assertAllEqual(expected, mu_shape.eval())
      self.assertAllEqual(expected, sigma_shape.eval())
      self.assertAllEqual(expected, self.evaluate(mu_shape))
      self.assertAllEqual(expected, self.evaluate(sigma_shape))
      mu = array_ops.zeros(mu_shape)
      sigma = array_ops.ones(sigma_shape)
      self.assertAllEqual(
          expected,
          array_ops.shape(normal_lib.Normal(mu, sigma).sample()).eval())
          self.evaluate(array_ops.shape(normal_lib.Normal(mu, sigma).sample())))

  def _testParamStaticShapes(self, sample_shape, expected):
    param_shapes = normal_lib.Normal.param_static_shapes(sample_shape)
@ -76,25 +77,30 @@
    self.assertEqual(expected, mu_shape)
    self.assertEqual(expected, sigma_shape)

  @test_util.run_in_graph_and_eager_modes()
  def testParamShapes(self):
    sample_shape = [10, 3, 4]
    self._testParamShapes(sample_shape, sample_shape)
    self._testParamShapes(constant_op.constant(sample_shape), sample_shape)

  @test_util.run_in_graph_and_eager_modes()
  def testParamStaticShapes(self):
    sample_shape = [10, 3, 4]
    self._testParamStaticShapes(sample_shape, sample_shape)
    self._testParamStaticShapes(
        tensor_shape.TensorShape(sample_shape), sample_shape)

  @test_util.run_in_graph_and_eager_modes()
  def testNormalWithSoftplusScale(self):
    with self.test_session():
      mu = array_ops.zeros((10, 3))
      rho = array_ops.ones((10, 3)) * -2.
      normal = normal_lib.NormalWithSoftplusScale(loc=mu, scale=rho)
      self.assertAllEqual(mu.eval(), normal.loc.eval())
      self.assertAllEqual(nn_ops.softplus(rho).eval(), normal.scale.eval())
      self.assertAllEqual(self.evaluate(mu), self.evaluate(normal.loc))
      self.assertAllEqual(
          self.evaluate(nn_ops.softplus(rho)), self.evaluate(normal.scale))

  @test_util.run_in_graph_and_eager_modes()
  def testNormalLogPDF(self):
    with self.test_session():
      batch_size = 6
@ -104,25 +110,31 @@
      normal = normal_lib.Normal(loc=mu, scale=sigma)

      log_pdf = normal.log_prob(x)
      self.assertAllEqual(normal.batch_shape_tensor().eval(),
                          log_pdf.get_shape())
      self.assertAllEqual(normal.batch_shape_tensor().eval(),
                          log_pdf.eval().shape)
      self.assertAllEqual(
          self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape())
      self.assertAllEqual(
          self.evaluate(normal.batch_shape_tensor()),
          self.evaluate(log_pdf).shape)
      self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
      self.assertAllEqual(normal.batch_shape, log_pdf.eval().shape)
      self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)

      pdf = normal.prob(x)
      self.assertAllEqual(normal.batch_shape_tensor().eval(), pdf.get_shape())
      self.assertAllEqual(normal.batch_shape_tensor().eval(), pdf.eval().shape)
      self.assertAllEqual(
          self.evaluate(normal.batch_shape_tensor()), pdf.get_shape())
      self.assertAllEqual(
          self.evaluate(normal.batch_shape_tensor()),
          self.evaluate(pdf).shape)
      self.assertAllEqual(normal.batch_shape, pdf.get_shape())
      self.assertAllEqual(normal.batch_shape, pdf.eval().shape)
      self.assertAllEqual(normal.batch_shape, self.evaluate(pdf).shape)

      if not stats:
        return
      expected_log_pdf = stats.norm(mu.eval(), sigma.eval()).logpdf(x)
      self.assertAllClose(expected_log_pdf, log_pdf.eval())
      self.assertAllClose(np.exp(expected_log_pdf), pdf.eval())
      expected_log_pdf = stats.norm(self.evaluate(mu),
                                    self.evaluate(sigma)).logpdf(x)
      self.assertAllClose(expected_log_pdf, self.evaluate(log_pdf))
      self.assertAllClose(np.exp(expected_log_pdf), self.evaluate(pdf))

  @test_util.run_in_graph_and_eager_modes()
  def testNormalLogPDFMultidimensional(self):
    with self.test_session():
      batch_size = 6
@ -133,29 +145,34 @@
      normal = normal_lib.Normal(loc=mu, scale=sigma)

      log_pdf = normal.log_prob(x)
      log_pdf_values = log_pdf.eval()
      log_pdf_values = self.evaluate(log_pdf)
      self.assertEqual(log_pdf.get_shape(), (6, 2))
      self.assertAllEqual(normal.batch_shape_tensor().eval(),
                          log_pdf.get_shape())
      self.assertAllEqual(normal.batch_shape_tensor().eval(),
                          log_pdf.eval().shape)
      self.assertAllEqual(
          self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape())
      self.assertAllEqual(
          self.evaluate(normal.batch_shape_tensor()),
          self.evaluate(log_pdf).shape)
      self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
      self.assertAllEqual(normal.batch_shape, log_pdf.eval().shape)
      self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)

      pdf = normal.prob(x)
      pdf_values = pdf.eval()
      pdf_values = self.evaluate(pdf)
      self.assertEqual(pdf.get_shape(), (6, 2))
      self.assertAllEqual(normal.batch_shape_tensor().eval(), pdf.get_shape())
      self.assertAllEqual(normal.batch_shape_tensor().eval(), pdf_values.shape)
      self.assertAllEqual(
          self.evaluate(normal.batch_shape_tensor()), pdf.get_shape())
      self.assertAllEqual(
          self.evaluate(normal.batch_shape_tensor()), pdf_values.shape)
      self.assertAllEqual(normal.batch_shape, pdf.get_shape())
      self.assertAllEqual(normal.batch_shape, pdf_values.shape)

      if not stats:
        return
      expected_log_pdf = stats.norm(mu.eval(), sigma.eval()).logpdf(x)
      expected_log_pdf = stats.norm(self.evaluate(mu),
                                    self.evaluate(sigma)).logpdf(x)
      self.assertAllClose(expected_log_pdf, log_pdf_values)
      self.assertAllClose(np.exp(expected_log_pdf), pdf_values)

  @test_util.run_in_graph_and_eager_modes()
  def testNormalCDF(self):
    with self.test_session():
      batch_size = 50
@ -165,15 +182,19 @@

      normal = normal_lib.Normal(loc=mu, scale=sigma)
      cdf = normal.cdf(x)
      self.assertAllEqual(normal.batch_shape_tensor().eval(), cdf.get_shape())
      self.assertAllEqual(normal.batch_shape_tensor().eval(), cdf.eval().shape)
      self.assertAllEqual(
          self.evaluate(normal.batch_shape_tensor()), cdf.get_shape())
      self.assertAllEqual(
          self.evaluate(normal.batch_shape_tensor()),
          self.evaluate(cdf).shape)
      self.assertAllEqual(normal.batch_shape, cdf.get_shape())
      self.assertAllEqual(normal.batch_shape, cdf.eval().shape)
      self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)
      if not stats:
        return
      expected_cdf = stats.norm(mu, sigma).cdf(x)
      self.assertAllClose(expected_cdf, cdf.eval(), atol=0)
      self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0)

  @test_util.run_in_graph_and_eager_modes()
  def testNormalSurvivalFunction(self):
    with self.test_session():
      batch_size = 50
@ -184,15 +205,19 @@
      normal = normal_lib.Normal(loc=mu, scale=sigma)

      sf = normal.survival_function(x)
      self.assertAllEqual(normal.batch_shape_tensor().eval(), sf.get_shape())
      self.assertAllEqual(normal.batch_shape_tensor().eval(), sf.eval().shape)
      self.assertAllEqual(
          self.evaluate(normal.batch_shape_tensor()), sf.get_shape())
      self.assertAllEqual(
          self.evaluate(normal.batch_shape_tensor()),
          self.evaluate(sf).shape)
      self.assertAllEqual(normal.batch_shape, sf.get_shape())
      self.assertAllEqual(normal.batch_shape, sf.eval().shape)
      self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)
      if not stats:
        return
      expected_sf = stats.norm(mu, sigma).sf(x)
      self.assertAllClose(expected_sf, sf.eval(), atol=0)
      self.assertAllClose(expected_sf, self.evaluate(sf), atol=0)

  @test_util.run_in_graph_and_eager_modes()
  def testNormalLogCDF(self):
    with self.test_session():
      batch_size = 50
@ -203,15 +228,18 @@
      normal = normal_lib.Normal(loc=mu, scale=sigma)

      cdf = normal.log_cdf(x)
      self.assertAllEqual(normal.batch_shape_tensor().eval(), cdf.get_shape())
      self.assertAllEqual(normal.batch_shape_tensor().eval(), cdf.eval().shape)
      self.assertAllEqual(
          self.evaluate(normal.batch_shape_tensor()), cdf.get_shape())
      self.assertAllEqual(
          self.evaluate(normal.batch_shape_tensor()),
          self.evaluate(cdf).shape)
      self.assertAllEqual(normal.batch_shape, cdf.get_shape())
      self.assertAllEqual(normal.batch_shape, cdf.eval().shape)
      self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)

      if not stats:
        return
      expected_cdf = stats.norm(mu, sigma).logcdf(x)
      self.assertAllClose(expected_cdf, cdf.eval(), atol=0, rtol=1e-5)
      self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-5)

  def testFiniteGradientAtDifficultPoints(self):
    for dtype in [np.float32, np.float64]:
@ -233,6 +261,7 @@
        self.assertAllFinite(grads[0])
        self.assertAllFinite(grads[1])

  @test_util.run_in_graph_and_eager_modes()
  def testNormalLogSurvivalFunction(self):
    with self.test_session():
      batch_size = 50
@ -243,16 +272,20 @@
      normal = normal_lib.Normal(loc=mu, scale=sigma)

      sf = normal.log_survival_function(x)
      self.assertAllEqual(normal.batch_shape_tensor().eval(), sf.get_shape())
      self.assertAllEqual(normal.batch_shape_tensor().eval(), sf.eval().shape)
      self.assertAllEqual(
          self.evaluate(normal.batch_shape_tensor()), sf.get_shape())
      self.assertAllEqual(
          self.evaluate(normal.batch_shape_tensor()),
          self.evaluate(sf).shape)
      self.assertAllEqual(normal.batch_shape, sf.get_shape())
      self.assertAllEqual(normal.batch_shape, sf.eval().shape)
      self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)

      if not stats:
        return
      expected_sf = stats.norm(mu, sigma).logsf(x)
      self.assertAllClose(expected_sf, sf.eval(), atol=0, rtol=1e-5)
      self.assertAllClose(expected_sf, self.evaluate(sf), atol=0, rtol=1e-5)

  @test_util.run_in_graph_and_eager_modes()
  def testNormalEntropyWithScalarInputs(self):
    # Scipy.stats.norm cannot deal with the shapes in the other test.
    with self.test_session():
@ -261,18 +294,20 @@
      normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)

      entropy = normal.entropy()
      self.assertAllEqual(normal.batch_shape_tensor().eval(),
                          entropy.get_shape())
      self.assertAllEqual(normal.batch_shape_tensor().eval(),
                          entropy.eval().shape)
      self.assertAllEqual(
          self.evaluate(normal.batch_shape_tensor()), entropy.get_shape())
      self.assertAllEqual(
          self.evaluate(normal.batch_shape_tensor()),
          self.evaluate(entropy).shape)
      self.assertAllEqual(normal.batch_shape, entropy.get_shape())
      self.assertAllEqual(normal.batch_shape, entropy.eval().shape)
      self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)
      # scipy.stats.norm cannot deal with these shapes.
      if not stats:
        return
      expected_entropy = stats.norm(mu_v, sigma_v).entropy()
      self.assertAllClose(expected_entropy, entropy.eval())
      self.assertAllClose(expected_entropy, self.evaluate(entropy))

  @test_util.run_in_graph_and_eager_modes()
  def testNormalEntropy(self):
    with self.test_session():
      mu_v = np.array([1.0, 1.0, 1.0])
@ -284,14 +319,16 @@
      expected_entropy = 0.5 * np.log(2 * np.pi * np.exp(1) * sigma_broadcast**
                                      2)
      entropy = normal.entropy()
      np.testing.assert_allclose(expected_entropy, entropy.eval())
      self.assertAllEqual(normal.batch_shape_tensor().eval(),
                          entropy.get_shape())
      self.assertAllEqual(normal.batch_shape_tensor().eval(),
                          entropy.eval().shape)
      np.testing.assert_allclose(expected_entropy, self.evaluate(entropy))
      self.assertAllEqual(
          self.evaluate(normal.batch_shape_tensor()), entropy.get_shape())
      self.assertAllEqual(
          self.evaluate(normal.batch_shape_tensor()),
          self.evaluate(entropy).shape)
      self.assertAllEqual(normal.batch_shape, entropy.get_shape())
      self.assertAllEqual(normal.batch_shape, entropy.eval().shape)
      self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)

  @test_util.run_in_graph_and_eager_modes()
  def testNormalMeanAndMode(self):
    with self.test_session():
      # Mu will be broadcast to [7, 7, 7].
@ -301,11 +338,12 @@
      normal = normal_lib.Normal(loc=mu, scale=sigma)

      self.assertAllEqual((3,), normal.mean().get_shape())
      self.assertAllEqual([7., 7, 7], normal.mean().eval())
      self.assertAllEqual([7., 7, 7], self.evaluate(normal.mean()))

      self.assertAllEqual((3,), normal.mode().get_shape())
      self.assertAllEqual([7., 7, 7], normal.mode().eval())
      self.assertAllEqual([7., 7, 7], self.evaluate(normal.mode()))

  @test_util.run_in_graph_and_eager_modes()
  def testNormalQuantile(self):
    with self.test_session():
      batch_size = 52
@ -319,15 +357,18 @@
      normal = normal_lib.Normal(loc=mu, scale=sigma)
      x = normal.quantile(p)

      self.assertAllEqual(normal.batch_shape_tensor().eval(), x.get_shape())
      self.assertAllEqual(normal.batch_shape_tensor().eval(), x.eval().shape)
      self.assertAllEqual(
          self.evaluate(normal.batch_shape_tensor()), x.get_shape())
      self.assertAllEqual(
          self.evaluate(normal.batch_shape_tensor()),
          self.evaluate(x).shape)
      self.assertAllEqual(normal.batch_shape, x.get_shape())
      self.assertAllEqual(normal.batch_shape, x.eval().shape)
      self.assertAllEqual(normal.batch_shape, self.evaluate(x).shape)

      if not stats:
        return
      expected_x = stats.norm(mu, sigma).ppf(p)
      self.assertAllClose(expected_x, x.eval(), atol=0.)
      self.assertAllClose(expected_x, self.evaluate(x), atol=0.)

  def _baseQuantileFiniteGradientAtDifficultPoints(self, dtype):
    g = ops.Graph()
@ -354,6 +395,7 @@
  def testQuantileFiniteGradientAtDifficultPointsFloat64(self):
    self._baseQuantileFiniteGradientAtDifficultPoints(np.float64)

  @test_util.run_in_graph_and_eager_modes()
  def testNormalVariance(self):
    with self.test_session():
      # sigma will be broadcast to [7, 7, 7]
@ -363,8 +405,9 @@
      normal = normal_lib.Normal(loc=mu, scale=sigma)

      self.assertAllEqual((3,), normal.variance().get_shape())
      self.assertAllEqual([49., 49, 49], normal.variance().eval())
      self.assertAllEqual([49., 49, 49], self.evaluate(normal.variance()))

  @test_util.run_in_graph_and_eager_modes()
  def testNormalStandardDeviation(self):
    with self.test_session():
      # sigma will be broadcast to [7, 7, 7]
@ -374,8 +417,9 @@
      normal = normal_lib.Normal(loc=mu, scale=sigma)

      self.assertAllEqual((3,), normal.stddev().get_shape())
      self.assertAllEqual([7., 7, 7], normal.stddev().eval())
      self.assertAllEqual([7., 7, 7], self.evaluate(normal.stddev()))

  @test_util.run_in_graph_and_eager_modes()
  def testNormalSample(self):
    with self.test_session():
      mu = constant_op.constant(3.0)
@ -385,7 +429,7 @@
      n = constant_op.constant(100000)
      normal = normal_lib.Normal(loc=mu, scale=sigma)
      samples = normal.sample(n)
      sample_values = samples.eval()
      sample_values = self.evaluate(samples)
      # Note that the standard error for the sample mean is ~ sigma / sqrt(n).
      # The sample variance similarly is dependent on sigma and n.
      # Thus, the tolerances below are very sensitive to number of samples
@ -394,18 +438,22 @@
      self.assertAllClose(sample_values.mean(), mu_v, atol=1e-1)
      self.assertAllClose(sample_values.std(), sigma_v, atol=1e-1)

      expected_samples_shape = tensor_shape.TensorShape([n.eval()]).concatenate(
          tensor_shape.TensorShape(normal.batch_shape_tensor().eval()))
      expected_samples_shape = tensor_shape.TensorShape(
          [self.evaluate(n)]).concatenate(
              tensor_shape.TensorShape(
                  self.evaluate(normal.batch_shape_tensor())))

      self.assertAllEqual(expected_samples_shape, samples.get_shape())
      self.assertAllEqual(expected_samples_shape, sample_values.shape)

      expected_samples_shape = (tensor_shape.TensorShape(
          [n.eval()]).concatenate(normal.batch_shape))
      expected_samples_shape = (
          tensor_shape.TensorShape([self.evaluate(n)]).concatenate(
              normal.batch_shape))

      self.assertAllEqual(expected_samples_shape, samples.get_shape())
      self.assertAllEqual(expected_samples_shape, sample_values.shape)
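The tolerance comments in `testNormalSample` rest on the usual Monte Carlo scaling for the sample mean of n i.i.d. draws:

    \mathrm{SE}(\bar{x}) = \frac{\sigma}{\sqrt{n}}

so with n = 10^5 the standard error is roughly 0.003 sigma, and an `atol=1e-1` leaves a wide margin while remaining sensitive to the sample count, as the comment warns.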
  @test_util.run_in_graph_and_eager_modes()
  def testNormalSampleMultiDimensional(self):
    with self.test_session():
      batch_size = 2
@ -417,7 +465,7 @@
      n = constant_op.constant(100000)
      normal = normal_lib.Normal(loc=mu, scale=sigma)
      samples = normal.sample(n)
      sample_values = samples.eval()
      sample_values = self.evaluate(samples)
      # Note that the standard error for the sample mean is ~ sigma / sqrt(n).
      # The sample variance similarly is dependent on sigma and n.
      # Thus, the tolerances below are very sensitive to number of samples
@ -428,32 +476,37 @@
      self.assertAllClose(sample_values[:, 0, 1].mean(), mu_v[1], atol=1e-1)
      self.assertAllClose(sample_values[:, 0, 1].std(), sigma_v[1], atol=1e-1)

      expected_samples_shape = tensor_shape.TensorShape([n.eval()]).concatenate(
          tensor_shape.TensorShape(normal.batch_shape_tensor().eval()))
      expected_samples_shape = tensor_shape.TensorShape(
          [self.evaluate(n)]).concatenate(
              tensor_shape.TensorShape(
                  self.evaluate(normal.batch_shape_tensor())))
      self.assertAllEqual(expected_samples_shape, samples.get_shape())
      self.assertAllEqual(expected_samples_shape, sample_values.shape)

      expected_samples_shape = (tensor_shape.TensorShape(
          [n.eval()]).concatenate(normal.batch_shape))
      expected_samples_shape = (
          tensor_shape.TensorShape([self.evaluate(n)]).concatenate(
              normal.batch_shape))
      self.assertAllEqual(expected_samples_shape, samples.get_shape())
      self.assertAllEqual(expected_samples_shape, sample_values.shape)

  @test_util.run_in_graph_and_eager_modes()
  def testNegativeSigmaFails(self):
    with self.test_session():
      normal = normal_lib.Normal(
          loc=[1.], scale=[-5.], validate_args=True, name="G")
      with self.assertRaisesOpError("Condition x > 0 did not hold"):
        normal.mean().eval()
        normal = normal_lib.Normal(
            loc=[1.], scale=[-5.], validate_args=True, name="G")
        self.evaluate(normal.mean())

  @test_util.run_in_graph_and_eager_modes()
  def testNormalShape(self):
    with self.test_session():
      mu = constant_op.constant([-3.0] * 5)
      sigma = constant_op.constant(11.0)
      normal = normal_lib.Normal(loc=mu, scale=sigma)

      self.assertEqual(normal.batch_shape_tensor().eval(), [5])
      self.assertEqual(self.evaluate(normal.batch_shape_tensor()), [5])
      self.assertEqual(normal.batch_shape, tensor_shape.TensorShape([5]))
      self.assertAllEqual(normal.event_shape_tensor().eval(), [])
      self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), [])
      self.assertEqual(normal.event_shape, tensor_shape.TensorShape([]))

  def testNormalShapeWithPlaceholders(self):
@ -465,31 +518,31 @@
      # get_batch_shape should return an "<unknown>" tensor.
      self.assertEqual(normal.batch_shape, tensor_shape.TensorShape(None))
      self.assertEqual(normal.event_shape, ())
      self.assertAllEqual(normal.event_shape_tensor().eval(), [])
      self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), [])
      self.assertAllEqual(
          sess.run(normal.batch_shape_tensor(),
                   feed_dict={mu: 5.0,
                              sigma: [1.0, 2.0]}), [2])

  @test_util.run_in_graph_and_eager_modes()
  def testNormalNormalKL(self):
    with self.test_session() as sess:
      batch_size = 6
      mu_a = np.array([3.0] * batch_size)
      sigma_a = np.array([1.0, 2.0, 3.0, 1.5, 2.5, 3.5])
      mu_b = np.array([-3.0] * batch_size)
      sigma_b = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])
    batch_size = 6
    mu_a = np.array([3.0] * batch_size)
    sigma_a = np.array([1.0, 2.0, 3.0, 1.5, 2.5, 3.5])
    mu_b = np.array([-3.0] * batch_size)
    sigma_b = np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])

      n_a = normal_lib.Normal(loc=mu_a, scale=sigma_a)
      n_b = normal_lib.Normal(loc=mu_b, scale=sigma_b)
    n_a = normal_lib.Normal(loc=mu_a, scale=sigma_a)
    n_b = normal_lib.Normal(loc=mu_b, scale=sigma_b)

      kl = kullback_leibler.kl_divergence(n_a, n_b)
      kl_val = sess.run(kl)
    kl = kullback_leibler.kl_divergence(n_a, n_b)
    kl_val = self.evaluate(kl)

      kl_expected = ((mu_a - mu_b)**2 / (2 * sigma_b**2) + 0.5 * (
          (sigma_a**2 / sigma_b**2) - 1 - 2 * np.log(sigma_a / sigma_b)))
    kl_expected = ((mu_a - mu_b)**2 / (2 * sigma_b**2) + 0.5 * (
        (sigma_a**2 / sigma_b**2) - 1 - 2 * np.log(sigma_a / sigma_b)))

      self.assertEqual(kl.get_shape(), (batch_size,))
      self.assertAllClose(kl_val, kl_expected)
    self.assertEqual(kl.get_shape(), (batch_size,))
    self.assertAllClose(kl_val, kl_expected)
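The `kl_expected` expression in `testNormalNormalKL` is the closed-form KL divergence between two univariate normals:

    \mathrm{KL}\big(\mathcal{N}(\mu_a,\sigma_a^2)\,\|\,\mathcal{N}(\mu_b,\sigma_b^2)\big)
      = \frac{(\mu_a-\mu_b)^2}{2\sigma_b^2}
      + \frac{1}{2}\left(\frac{\sigma_a^2}{\sigma_b^2} - 1
      - 2\ln\frac{\sigma_a}{\sigma_b}\right)

which is exactly what the NumPy expression computes elementwise over the batch.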
if __name__ == "__main__":
@ -147,6 +147,7 @@ class NdtriTest(test.TestCase):
    self._baseNdtriFiniteGradientTest(np.float64)


@test_util.run_all_in_graph_and_eager_modes
class NdtrTest(test.TestCase):
  _use_log = False
  # Grid min/max chosen to ensure 0 < cdf(x) < 1.
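Applied at class level, the decorator spares repeating the per-method variant on every test. An illustrative (hypothetical) test class, not taken from this diff:

    # Illustrative only: every `test*` method below runs in both modes.
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import test_util
    from tensorflow.python.platform import test


    @test_util.run_all_in_graph_and_eager_modes
    class SquareTest(test.TestCase):

      def testSquare(self):
        x = constant_op.constant([2., 3.])
        # self.evaluate works in whichever mode the wrapper selected.
        self.assertAllClose([4., 9.], self.evaluate(x * x))


    if __name__ == "__main__":
      test.main()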
@ -25,6 +25,7 @@ import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import student_t
@ -44,6 +45,7 @@ def try_import(name):  # pylint: disable=invalid-name
stats = try_import("scipy.stats")


@test_util.run_all_in_graph_and_eager_modes
class StudentTTest(test.TestCase):

  def testStudentPDFAndLogPDF(self):
@ -60,10 +62,10 @@

    log_pdf = student.log_prob(t)
    self.assertEquals(log_pdf.get_shape(), (6,))
    log_pdf_values = log_pdf.eval()
    log_pdf_values = self.evaluate(log_pdf)
    pdf = student.prob(t)
    self.assertEquals(pdf.get_shape(), (6,))
    pdf_values = pdf.eval()
    pdf_values = self.evaluate(pdf)

    if not stats:
      return
@ -88,10 +90,10 @@
    t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T
    student = student_t.StudentT(df, loc=mu, scale=sigma)
    log_pdf = student.log_prob(t)
    log_pdf_values = log_pdf.eval()
    log_pdf_values = self.evaluate(log_pdf)
    self.assertEqual(log_pdf.get_shape(), (6, 2))
    pdf = student.prob(t)
    pdf_values = pdf.eval()
    pdf_values = self.evaluate(pdf)
    self.assertEqual(pdf.get_shape(), (6, 2))

    if not stats:
@ -117,10 +119,10 @@

    log_cdf = student.log_cdf(t)
    self.assertEquals(log_cdf.get_shape(), (6,))
    log_cdf_values = log_cdf.eval()
    log_cdf_values = self.evaluate(log_cdf)
    cdf = student.cdf(t)
    self.assertEquals(cdf.get_shape(), (6,))
    cdf_values = cdf.eval()
    cdf_values = self.evaluate(cdf)

    if not stats:
      return
@ -140,7 +142,7 @@
    with self.test_session():
      student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v)
      ent = student.entropy()
      ent_values = ent.eval()
      ent_values = self.evaluate(ent)

    # Help scipy broadcast to 3x3
    ones = np.array([[1, 1, 1]])
@ -167,7 +169,7 @@
    n = constant_op.constant(200000)
    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    samples = student.sample(n, seed=123456)
    sample_values = samples.eval()
    sample_values = self.evaluate(samples)
    n_val = 200000
    self.assertEqual(sample_values.shape, (n_val,))
    self.assertAllClose(sample_values.mean(), mu_v, rtol=1e-2, atol=0)
@ -189,12 +191,12 @@
    random_seed.set_random_seed(654321)
    student = student_t.StudentT(
        df=df, loc=mu, scale=sigma, name="student_t1")
    samples1 = student.sample(n, seed=123456).eval()
    samples1 = self.evaluate(student.sample(n, seed=123456))

    random_seed.set_random_seed(654321)
    student2 = student_t.StudentT(
        df=df, loc=mu, scale=sigma, name="student_t2")
    samples2 = student2.sample(n, seed=123456).eval()
    samples2 = self.evaluate(student2.sample(n, seed=123456))

    self.assertAllClose(samples1, samples2)
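The determinism test just above leans on TensorFlow's two-level seeding: samples are reproducible only when both the graph-level seed (`random_seed.set_random_seed`) and the op-level `seed=` argument match. A minimal sketch of the same pattern outside a test class, assuming the TF 1.x-era APIs used in this diff:

    # Sketch: matching graph-level and op-level seeds reproduce the stream.
    from tensorflow.python.framework import random_seed
    from tensorflow.python.ops.distributions import normal as normal_lib

    random_seed.set_random_seed(654321)   # graph-level seed
    first = normal_lib.Normal(loc=0., scale=1.).sample(5, seed=123456)

    random_seed.set_random_seed(654321)   # reset the graph-level seed
    second = normal_lib.Normal(loc=0., scale=1.).sample(5, seed=123456)
    # In eager mode `first` and `second` now hold identical values; in
    # graph mode the two sample ops evaluate identically when run.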
@ -205,7 +207,7 @@
    n = constant_op.constant(200000)
    student = student_t.StudentT(df=df, loc=1., scale=1.)
    samples = student.sample(n, seed=123456)
    sample_values = samples.eval()
    sample_values = self.evaluate(samples)
    n_val = 200000
    self.assertEqual(sample_values.shape, (n_val, 4))
    self.assertTrue(np.all(np.logical_not(np.isnan(sample_values))))
@ -223,7 +225,7 @@
    n = constant_op.constant(200000)
    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
    samples = student.sample(n, seed=123456)
    sample_values = samples.eval()
    sample_values = self.evaluate(samples)
    self.assertEqual(samples.get_shape(), (200000, batch_size, 2))
    self.assertAllClose(
        sample_values[:, 0, 0].mean(), mu_v[0], rtol=1e-2, atol=0)
@ -325,7 +327,7 @@
    with self.test_session():
      mu = [1., 3.3, 4.4]
      student = student_t.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.])
      mean = student.mean().eval()
      mean = self.evaluate(student.mean())
      self.assertAllClose([1., 3.3, 4.4], mean)

  def testMeanAllowNanStatsIsFalseRaisesWhenBatchMemberIsUndefined(self):
@ -335,7 +337,7 @@
          df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.],
          allow_nan_stats=False)
      with self.assertRaisesOpError("x < y"):
        student.mean().eval()
        self.evaluate(student.mean())

  def testMeanAllowNanStatsIsTrueReturnsNaNForUndefinedBatchMembers(self):
    with self.test_session():
@ -344,7 +346,7 @@
      student = student_t.StudentT(
          df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma,
          allow_nan_stats=True)
      mean = student.mean().eval()
      mean = self.evaluate(student.mean())
      self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean)

  def testVarianceAllowNanStatsTrueReturnsNaNforUndefinedBatchMembers(self):
@ -356,7 +358,7 @@
      sigma = [5., 4., 3., 2., 1.]
      student = student_t.StudentT(
          df=df, loc=mu, scale=sigma, allow_nan_stats=True)
      var = student.variance().eval()
      var = self.evaluate(student.variance())
      ## scipy uses inf for variance when the mean is undefined. When mean is
      # undefined we say variance is undefined as well. So test the first
      # member of var, making sure it is NaN, then replace with inf and compare
@ -379,7 +381,7 @@
      mu = [0., 1., 3.3, 4.4]
      sigma = [4., 3., 2., 1.]
      student = student_t.StudentT(df=df, loc=mu, scale=sigma)
      var = student.variance().eval()
      var = self.evaluate(student.variance())

      if not stats:
        return
@ -394,14 +396,14 @@
      student = student_t.StudentT(
          df=1., loc=0., scale=1., allow_nan_stats=False)
      with self.assertRaisesOpError("x < y"):
        student.variance().eval()
        self.evaluate(student.variance())

    with self.test_session():
      # df <= 1 ==> variance not defined
      student = student_t.StudentT(
          df=0.5, loc=0., scale=1., allow_nan_stats=False)
      with self.assertRaisesOpError("x < y"):
        student.variance().eval()
        self.evaluate(student.variance())
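For reference, the Student-t variance these tests probe is only defined for large enough degrees of freedom, which is exactly what `allow_nan_stats` guards:

    \operatorname{Var}[X] =
    \begin{cases}
      \sigma^2 \dfrac{\nu}{\nu-2}, & \nu > 2,\\
      \infty, & 1 < \nu \le 2,\\
      \text{undefined}, & \nu \le 1,
    \end{cases}

with nu the `df` parameter and sigma the `scale`.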
  def testStd(self):
    with self.test_session():
@ -411,7 +413,7 @@
      sigma = [5., 4., 3., 2., 1.]
      student = student_t.StudentT(df=df, loc=mu, scale=sigma)
      # Test broadcast of mu across shape of df/sigma
      stddev = student.stddev().eval()
      stddev = self.evaluate(student.stddev())
      mu *= len(df)

      if not stats:
@ -428,59 +430,58 @@
      sigma = [5., 4., 3.]
      student = student_t.StudentT(df=df, loc=mu, scale=sigma)
      # Test broadcast of mu across shape of df/sigma
      mode = student.mode().eval()
      mode = self.evaluate(student.mode())
      self.assertAllClose([-1., 0, 1], mode)

  def testPdfOfSample(self):
    with self.test_session() as sess:
      student = student_t.StudentT(df=3., loc=np.pi, scale=1.)
      num = 20000
      samples = student.sample(num, seed=123456)
      pdfs = student.prob(samples)
      mean = student.mean()
      mean_pdf = student.prob(student.mean())
      sample_vals, pdf_vals, mean_val, mean_pdf_val = sess.run(
          [samples, pdfs, student.mean(), mean_pdf])
      self.assertEqual(samples.get_shape(), (num,))
      self.assertEqual(pdfs.get_shape(), (num,))
      self.assertEqual(mean.get_shape(), ())
      self.assertNear(np.pi, np.mean(sample_vals), err=0.02)
      self.assertNear(np.pi, mean_val, err=1e-6)
      # Verify integral over sample*pdf ~= 1.
      self._assertIntegral(sample_vals, pdf_vals, err=2e-3)
      if not stats:
        return
      self.assertNear(stats.t.pdf(np.pi, 3., loc=np.pi), mean_pdf_val, err=1e-6)
    student = student_t.StudentT(df=3., loc=np.pi, scale=1.)
    num = 20000
    samples = student.sample(num, seed=123456)
    pdfs = student.prob(samples)
    mean = student.mean()
    mean_pdf = student.prob(student.mean())
    sample_vals, pdf_vals, mean_val, mean_pdf_val = self.evaluate(
        [samples, pdfs, student.mean(), mean_pdf])
    self.assertEqual(samples.get_shape(), (num,))
    self.assertEqual(pdfs.get_shape(), (num,))
    self.assertEqual(mean.get_shape(), ())
    self.assertNear(np.pi, np.mean(sample_vals), err=0.02)
    self.assertNear(np.pi, mean_val, err=1e-6)
    # Verify integral over sample*pdf ~= 1.
    # Tolerance increased since eager was getting a value of 1.002041.
    self._assertIntegral(sample_vals, pdf_vals, err=3e-3)
    if not stats:
      return
    self.assertNear(stats.t.pdf(np.pi, 3., loc=np.pi), mean_pdf_val, err=1e-6)

  def testPdfOfSampleMultiDims(self):
    with self.test_session() as sess:
      student = student_t.StudentT(df=[7., 11.], loc=[[5.], [6.]], scale=3.)
      self.assertAllEqual([], student.event_shape)
      self.assertAllEqual([], student.event_shape_tensor().eval())
      self.assertAllEqual([2, 2], student.batch_shape)
      self.assertAllEqual([2, 2], student.batch_shape_tensor().eval())
      num = 50000
      samples = student.sample(num, seed=123456)
      pdfs = student.prob(samples)
      sample_vals, pdf_vals = sess.run([samples, pdfs])
      self.assertEqual(samples.get_shape(), (num, 2, 2))
      self.assertEqual(pdfs.get_shape(), (num, 2, 2))
      self.assertNear(5., np.mean(sample_vals[:, 0, :]), err=.03)
      self.assertNear(6., np.mean(sample_vals[:, 1, :]), err=.03)
      self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
      self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
      self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
      self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)
      if not stats:
        return
      self.assertNear(
          stats.t.var(7., loc=0., scale=3.),  # loc d.n. effect var
          np.var(sample_vals[:, :, 0]),
          err=.4)
      self.assertNear(
          stats.t.var(11., loc=0., scale=3.),  # loc d.n. effect var
          np.var(sample_vals[:, :, 1]),
          err=.4)
    student = student_t.StudentT(df=[7., 11.], loc=[[5.], [6.]], scale=3.)
    self.assertAllEqual([], student.event_shape)
    self.assertAllEqual([], self.evaluate(student.event_shape_tensor()))
    self.assertAllEqual([2, 2], student.batch_shape)
    self.assertAllEqual([2, 2], self.evaluate(student.batch_shape_tensor()))
    num = 50000
    samples = student.sample(num, seed=123456)
    pdfs = student.prob(samples)
    sample_vals, pdf_vals = self.evaluate([samples, pdfs])
    self.assertEqual(samples.get_shape(), (num, 2, 2))
    self.assertEqual(pdfs.get_shape(), (num, 2, 2))
    self.assertNear(5., np.mean(sample_vals[:, 0, :]), err=.03)
    self.assertNear(6., np.mean(sample_vals[:, 1, :]), err=.03)
    self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
    self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
    self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
    self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)
    if not stats:
      return
    self.assertNear(
        stats.t.var(7., loc=0., scale=3.),  # loc d.n. effect var
        np.var(sample_vals[:, :, 0]),
        err=.4)
    self.assertNear(
        stats.t.var(11., loc=0., scale=3.),  # loc d.n. effect var
        np.var(sample_vals[:, :, 1]),
        err=.4)

  def _assertIntegral(self, sample_vals, pdf_vals, err=1.5e-3):
    s_p = zip(sample_vals, pdf_vals)
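`_assertIntegral` is truncated by the hunk boundary here; the technique it implements is a trapezoid-rule sanity check that the empirical pdf integrates to roughly one over the sampled range. A hedged standalone sketch of that idea (not the exact TF helper):

    import numpy as np


    def assert_integral_close_to_one(sample_vals, pdf_vals, err=1.5e-3):
      # Trapezoid rule over (sample, pdf) pairs ordered by sample location.
      prev = (np.min(sample_vals) - 1000., 0.)
      total = 0.
      for k in sorted(zip(sample_vals, pdf_vals), key=lambda x: x[0]):
        pair_pdf = (k[1] + prev[1]) / 2.
        total += (k[0] - prev[0]) * pair_pdf
        prev = k
      assert abs(total - 1.) < err, total

This is exactly the kind of check whose tolerance the comment above loosens from 2e-3 to 3e-3 for eager mode.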
@ -494,10 +495,10 @@ class StudentTTest(test.TestCase):

  def testNegativeDofFails(self):
    with self.test_session():
      student = student_t.StudentT(df=[2, -5.], loc=0., scale=1.,
                                   validate_args=True, name="S")
      with self.assertRaisesOpError(r"Condition x > 0 did not hold"):
        student.mean().eval()
        student = student_t.StudentT(
            df=[2, -5.], loc=0., scale=1., validate_args=True, name="S")
        self.evaluate(student.mean())

  def testStudentTWithAbsDfSoftplusScale(self):
    with self.test_session():
@ -507,9 +508,11 @@
      student = student_t.StudentTWithAbsDfSoftplusScale(
          df=df, loc=mu, scale=sigma)
      self.assertAllClose(
          math_ops.floor(math_ops.abs(df)).eval(), student.df.eval())
      self.assertAllClose(mu.eval(), student.loc.eval())
      self.assertAllClose(nn_ops.softplus(sigma).eval(), student.scale.eval())
          math_ops.floor(self.evaluate(math_ops.abs(df))),
          self.evaluate(student.df))
      self.assertAllClose(self.evaluate(mu), self.evaluate(student.loc))
      self.assertAllClose(
          self.evaluate(nn_ops.softplus(sigma)), self.evaluate(student.scale))


if __name__ == "__main__":
@ -25,6 +25,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import uniform as uniform_lib
@ -46,15 +47,17 @@ stats = try_import("scipy.stats")

class UniformTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes()
  def testUniformRange(self):
    with self.test_session():
      a = 3.0
      b = 10.0
      uniform = uniform_lib.Uniform(low=a, high=b)
      self.assertAllClose(a, uniform.low.eval())
      self.assertAllClose(b, uniform.high.eval())
      self.assertAllClose(b - a, uniform.range().eval())
      self.assertAllClose(a, self.evaluate(uniform.low))
      self.assertAllClose(b, self.evaluate(uniform.high))
      self.assertAllClose(b - a, self.evaluate(uniform.range()))

  @test_util.run_in_graph_and_eager_modes()
  def testUniformPDF(self):
    with self.test_session():
      a = constant_op.constant([-3.0] * 5 + [15.0])
@ -75,22 +78,24 @@
      expected_pdf = _expected_pdf()

      pdf = uniform.prob(x)
      self.assertAllClose(expected_pdf, pdf.eval())
      self.assertAllClose(expected_pdf, self.evaluate(pdf))

      log_pdf = uniform.log_prob(x)
      self.assertAllClose(np.log(expected_pdf), log_pdf.eval())
      self.assertAllClose(np.log(expected_pdf), self.evaluate(log_pdf))

  @test_util.run_in_graph_and_eager_modes()
  def testUniformShape(self):
    with self.test_session():
      a = constant_op.constant([-3.0] * 5)
      b = constant_op.constant(11.0)
      uniform = uniform_lib.Uniform(low=a, high=b)

      self.assertEqual(uniform.batch_shape_tensor().eval(), (5,))
      self.assertEqual(self.evaluate(uniform.batch_shape_tensor()), (5,))
      self.assertEqual(uniform.batch_shape, tensor_shape.TensorShape([5]))
      self.assertAllEqual(uniform.event_shape_tensor().eval(), [])
      self.assertAllEqual(self.evaluate(uniform.event_shape_tensor()), [])
      self.assertEqual(uniform.event_shape, tensor_shape.TensorShape([]))

  @test_util.run_in_graph_and_eager_modes()
  def testUniformPDFWithScalarEndpoint(self):
    with self.test_session():
      a = constant_op.constant([0.0, 5.0])
@ -101,8 +106,9 @@
      expected_pdf = np.array([1.0 / (10.0 - 0.0), 1.0 / (10.0 - 5.0)])

      pdf = uniform.prob(x)
      self.assertAllClose(expected_pdf, pdf.eval())
      self.assertAllClose(expected_pdf, self.evaluate(pdf))

  @test_util.run_in_graph_and_eager_modes()
  def testUniformCDF(self):
    with self.test_session():
      batch_size = 6
@ -121,11 +127,12 @@
        return cdf

      cdf = uniform.cdf(x)
      self.assertAllClose(_expected_cdf(), cdf.eval())
      self.assertAllClose(_expected_cdf(), self.evaluate(cdf))

      log_cdf = uniform.log_cdf(x)
      self.assertAllClose(np.log(_expected_cdf()), log_cdf.eval())
      self.assertAllClose(np.log(_expected_cdf()), self.evaluate(log_cdf))
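The `_expected_cdf` helper referenced above encodes the piecewise CDF of the uniform distribution on [a, b):

    F(x) =
    \begin{cases}
      0, & x < a,\\
      \dfrac{x-a}{b-a}, & a \le x < b,\\
      1, & x \ge b.
    \end{cases}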
  @test_util.run_in_graph_and_eager_modes()
  def testUniformEntropy(self):
    with self.test_session():
      a_v = np.array([1.0, 1.0, 1.0])
@ -133,18 +140,20 @@
      uniform = uniform_lib.Uniform(low=a_v, high=b_v)

      expected_entropy = np.log(b_v - a_v)
      self.assertAllClose(expected_entropy, uniform.entropy().eval())
      self.assertAllClose(expected_entropy, self.evaluate(uniform.entropy()))

  @test_util.run_in_graph_and_eager_modes()
  def testUniformAssertMaxGtMin(self):
    with self.test_session():
      a_v = np.array([1.0, 1.0, 1.0], dtype=np.float32)
      b_v = np.array([1.0, 2.0, 3.0], dtype=np.float32)
      uniform = uniform_lib.Uniform(low=a_v, high=b_v, validate_args=True)

      with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
                                               "x < y"):
        uniform.low.eval()
        uniform = uniform_lib.Uniform(low=a_v, high=b_v, validate_args=True)
        self.evaluate(uniform.low)

  @test_util.run_in_graph_and_eager_modes()
  def testUniformSample(self):
    with self.test_session():
      a = constant_op.constant([3.0, 4.0])
@ -156,7 +165,7 @@
      uniform = uniform_lib.Uniform(low=a, high=b)

      samples = uniform.sample(n, seed=137)
      sample_values = samples.eval()
      sample_values = self.evaluate(samples)
      self.assertEqual(sample_values.shape, (100000, 2))
      self.assertAllClose(
          sample_values[::, 0].mean(), (b_v + a1_v) / 2, atol=1e-2)
@ -167,6 +176,7 @@
      self.assertFalse(
          np.any(sample_values[::, 1] < a2_v) or np.any(sample_values >= b_v))

  @test_util.run_in_graph_and_eager_modes()
  def _testUniformSampleMultiDimensional(self):
    # DISABLED: Please enable this test once b/issues/30149644 is resolved.
    with self.test_session():
@ -183,7 +193,7 @@
      samples = uniform.sample(n)
      self.assertEqual(samples.get_shape(), (n_v, batch_size, 2))

      sample_values = samples.eval()
      sample_values = self.evaluate(samples)

      self.assertFalse(
          np.any(sample_values[:, 0, 0] < a_v[0]) or
@ -197,6 +207,7 @@
      self.assertAllClose(
          sample_values[:, 0, 1].mean(), (a_v[1] + b_v[1]) / 2, atol=1e-2)

  @test_util.run_in_graph_and_eager_modes()
  def testUniformMean(self):
    with self.test_session():
      a = 10.0
@ -205,8 +216,9 @@
      if not stats:
        return
      s_uniform = stats.uniform(loc=a, scale=b - a)
      self.assertAllClose(uniform.mean().eval(), s_uniform.mean())
      self.assertAllClose(self.evaluate(uniform.mean()), s_uniform.mean())

  @test_util.run_in_graph_and_eager_modes()
  def testUniformVariance(self):
    with self.test_session():
      a = 10.0
@ -215,8 +227,9 @@
      if not stats:
        return
      s_uniform = stats.uniform(loc=a, scale=b - a)
      self.assertAllClose(uniform.variance().eval(), s_uniform.var())
      self.assertAllClose(self.evaluate(uniform.variance()), s_uniform.var())

  @test_util.run_in_graph_and_eager_modes()
  def testUniformStd(self):
    with self.test_session():
      a = 10.0
@ -225,8 +238,9 @@
      if not stats:
        return
      s_uniform = stats.uniform(loc=a, scale=b - a)
      self.assertAllClose(uniform.stddev().eval(), s_uniform.std())
      self.assertAllClose(self.evaluate(uniform.stddev()), s_uniform.std())

  @test_util.run_in_graph_and_eager_modes()
  def testUniformNans(self):
    with self.test_session():
      a = 10.0
@ -235,23 +249,26 @@

      no_nans = constant_op.constant(1.0)
      nans = constant_op.constant(0.0) / constant_op.constant(0.0)
      self.assertTrue(math_ops.is_nan(nans).eval())
      self.assertTrue(self.evaluate(math_ops.is_nan(nans)))
      with_nans = array_ops.stack([no_nans, nans])

      pdf = uniform.prob(with_nans)

      is_nan = math_ops.is_nan(pdf).eval()
      is_nan = self.evaluate(math_ops.is_nan(pdf))
      self.assertFalse(is_nan[0])
      self.assertTrue(is_nan[1])

  @test_util.run_in_graph_and_eager_modes()
  def testUniformSamplePdf(self):
    with self.test_session():
      a = 10.0
      b = [11.0, 100.0]
      uniform = uniform_lib.Uniform(a, b)
      self.assertTrue(
          math_ops.reduce_all(uniform.prob(uniform.sample(10)) > 0).eval())
          self.evaluate(
              math_ops.reduce_all(uniform.prob(uniform.sample(10)) > 0)))

  @test_util.run_in_graph_and_eager_modes()
  def testUniformBroadcasting(self):
    with self.test_session():
      a = 10.0
@ -260,8 +277,9 @@

      pdf = uniform.prob([[10.5, 11.5], [9.0, 19.0], [10.5, 21.0]])
      expected_pdf = np.array([[1.0, 0.1], [0.0, 0.1], [1.0, 0.0]])
      self.assertAllClose(expected_pdf, pdf.eval())
      self.assertAllClose(expected_pdf, self.evaluate(pdf))

  @test_util.run_in_graph_and_eager_modes()
  def testUniformSampleWithShape(self):
    with self.test_session():
      a = 10.0
@ -275,12 +293,13 @@
          [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]],
      ]
      # pylint: enable=bad-continuation
      self.assertAllClose(expected_pdf, pdf.eval())
      self.assertAllClose(expected_pdf, self.evaluate(pdf))

      pdf = uniform.prob(uniform.sample())
      expected_pdf = [1.0, 0.1]
      self.assertAllClose(expected_pdf, pdf.eval())
      self.assertAllClose(expected_pdf, self.evaluate(pdf))

  # Eager doesn't pass due to a type mismatch in one of the ops.
  def testUniformFloat64(self):
    uniform = uniform_lib.Uniform(
        low=np.float64(0.), high=np.float64(1.))
@ -22,9 +22,11 @@ import importlib

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
@ -97,6 +99,7 @@ class AssertCloseTest(test.TestCase):
        with ops.control_dependencies([du.assert_close(y, z)]):
          array_ops.identity(y).eval(feed_dict=feed_dict)

  @test_util.run_in_graph_and_eager_modes()
  def testAssertCloseEpsilon(self):
    x = [0., 5, 10, 15, 20]
    # x != y
@ -105,15 +108,15 @@
    z = [1e-8, 5, 10, 15, 20]
    with self.test_session():
      with ops.control_dependencies([du.assert_close(x, z)]):
        array_ops.identity(x).eval()
        self.evaluate(array_ops.identity(x))

      with self.assertRaisesOpError("Condition x ~= y"):
        with ops.control_dependencies([du.assert_close(x, y)]):
          array_ops.identity(x).eval()
          self.evaluate(array_ops.identity(x))

      with self.assertRaisesOpError("Condition x ~= y"):
        with ops.control_dependencies([du.assert_close(y, z)]):
          array_ops.identity(y).eval()
          self.evaluate(array_ops.identity(y))

  def testAssertIntegerForm(self):
    # This should only be detected as an integer.
@ -147,18 +150,21 @@

class MaybeGetStaticTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes()
  def testGetStaticInt(self):
    x = 2
    self.assertEqual(x, du.maybe_get_static_value(x))
    self.assertAllClose(
        np.array(2.), du.maybe_get_static_value(x, dtype=np.float64))

  @test_util.run_in_graph_and_eager_modes()
  def testGetStaticNumpyArray(self):
    x = np.array(2, dtype=np.int32)
    self.assertEqual(x, du.maybe_get_static_value(x))
    self.assertAllClose(
        np.array(2.), du.maybe_get_static_value(x, dtype=np.float64))

  @test_util.run_in_graph_and_eager_modes()
  def testGetStaticConstant(self):
    x = constant_op.constant(2, dtype=dtypes.int32)
    self.assertEqual(np.array(2, dtype=np.int32), du.maybe_get_static_value(x))
@ -173,6 +179,7 @@

class GetLogitsAndProbsTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes()
  def testImproperArguments(self):
    with self.test_session():
      with self.assertRaises(ValueError):
@ -181,6 +188,7 @@
      with self.assertRaises(ValueError):
        du.get_logits_and_probs(logits=[0.1], probs=[0.1])

  @test_util.run_in_graph_and_eager_modes()
  def testLogits(self):
    p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32)
    logits = _logit(p)
@ -189,9 +197,10 @@
      new_logits, new_p = du.get_logits_and_probs(
          logits=logits, validate_args=True)

      self.assertAllClose(p, new_p.eval(), rtol=1e-5, atol=0.)
      self.assertAllClose(logits, new_logits.eval(), rtol=1e-5, atol=0.)
      self.assertAllClose(p, self.evaluate(new_p), rtol=1e-5, atol=0.)
      self.assertAllClose(logits, self.evaluate(new_logits), rtol=1e-5, atol=0.)

  @test_util.run_in_graph_and_eager_modes()
  def testLogitsMultidimensional(self):
    p = np.array([0.2, 0.3, 0.5], dtype=np.float32)
    logits = np.log(p)
@ -200,9 +209,10 @@
      new_logits, new_p = du.get_logits_and_probs(
          logits=logits, multidimensional=True, validate_args=True)

      self.assertAllClose(new_p.eval(), p)
      self.assertAllClose(new_logits.eval(), logits)
      self.assertAllClose(self.evaluate(new_p), p)
      self.assertAllClose(self.evaluate(new_logits), logits)

  @test_util.run_in_graph_and_eager_modes()
  def testProbability(self):
    p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32)

@ -210,9 +220,10 @@
      new_logits, new_p = du.get_logits_and_probs(
          probs=p, validate_args=True)

      self.assertAllClose(_logit(p), new_logits.eval())
      self.assertAllClose(p, new_p.eval())
      self.assertAllClose(_logit(p), self.evaluate(new_logits))
      self.assertAllClose(p, self.evaluate(new_p))

  @test_util.run_in_graph_and_eager_modes()
  def testProbabilityMultidimensional(self):
    p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32)

@ -220,9 +231,10 @@
      new_logits, new_p = du.get_logits_and_probs(
          probs=p, multidimensional=True, validate_args=True)

      self.assertAllClose(np.log(p), new_logits.eval())
      self.assertAllClose(p, new_p.eval())
      self.assertAllClose(np.log(p), self.evaluate(new_logits))
      self.assertAllClose(p, self.evaluate(new_p))
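In the scalar (non-multidimensional) case these conversions are the standard logit/sigmoid pair:

    \text{logits} = \operatorname{logit}(p) = \ln\frac{p}{1-p},
    \qquad
    p = \sigma(\text{logits}) = \frac{1}{1+e^{-\text{logits}}}

while the multidimensional tests instead pair `np.log(p)` with probability vectors that sum to one.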
@test_util.run_in_graph_and_eager_modes()
|
||||
def testProbabilityValidateArgs(self):
|
||||
p = [0.01, 0.2, 0.5, 0.7, .99]
|
||||
# Component less than 0.
|
||||
@ -233,26 +245,27 @@ class GetLogitsAndProbsTest(test.TestCase):
|
||||
with self.test_session():
|
||||
_, prob = du.get_logits_and_probs(
|
||||
probs=p, validate_args=True)
|
||||
prob.eval()
|
||||
self.evaluate(prob)
|
||||
|
||||
with self.assertRaisesOpError("Condition x >= 0"):
|
||||
_, prob = du.get_logits_and_probs(
|
||||
probs=p2, validate_args=True)
|
||||
prob.eval()
|
||||
self.evaluate(prob)
|
||||
|
||||
_, prob = du.get_logits_and_probs(
|
||||
probs=p2, validate_args=False)
|
||||
prob.eval()
|
||||
self.evaluate(prob)
|
||||
|
||||
with self.assertRaisesOpError("probs has components greater than 1"):
|
||||
_, prob = du.get_logits_and_probs(
|
||||
probs=p3, validate_args=True)
|
||||
prob.eval()
|
||||
self.evaluate(prob)
|
||||
|
||||
_, prob = du.get_logits_and_probs(
|
||||
probs=p3, validate_args=False)
|
||||
prob.eval()
|
||||
self.evaluate(prob)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testProbabilityValidateArgsMultidimensional(self):
|
||||
p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32)
|
||||
# Component less than 0. Still sums to 1.
|
||||
@ -265,35 +278,35 @@ class GetLogitsAndProbsTest(test.TestCase):
|
||||
with self.test_session():
|
||||
_, prob = du.get_logits_and_probs(
|
||||
probs=p, multidimensional=True)
|
||||
prob.eval()
|
||||
self.evaluate(prob)
|
||||
|
||||
with self.assertRaisesOpError("Condition x >= 0"):
|
||||
_, prob = du.get_logits_and_probs(
|
||||
probs=p2, multidimensional=True, validate_args=True)
|
||||
prob.eval()
|
||||
self.evaluate(prob)
|
||||
|
||||
_, prob = du.get_logits_and_probs(
|
||||
probs=p2, multidimensional=True, validate_args=False)
|
||||
prob.eval()
|
||||
self.evaluate(prob)
|
||||
|
||||
with self.assertRaisesOpError(
|
||||
"(probs has components greater than 1|probs does not sum to 1)"):
|
||||
_, prob = du.get_logits_and_probs(
|
||||
probs=p3, multidimensional=True, validate_args=True)
|
||||
prob.eval()
|
||||
self.evaluate(prob)
|
||||
|
||||
_, prob = du.get_logits_and_probs(
|
||||
probs=p3, multidimensional=True, validate_args=False)
|
||||
prob.eval()
|
||||
self.evaluate(prob)
|
||||
|
||||
with self.assertRaisesOpError("probs does not sum to 1"):
|
||||
_, prob = du.get_logits_and_probs(
|
||||
probs=p4, multidimensional=True, validate_args=True)
|
||||
prob.eval()
|
||||
self.evaluate(prob)
|
||||
|
||||
_, prob = du.get_logits_and_probs(
|
||||
probs=p4, multidimensional=True, validate_args=False)
|
||||
prob.eval()
|
||||
self.evaluate(prob)
|
||||
|
||||
def testProbsMultidimShape(self):
|
||||
with self.test_session():
|
||||
@ -354,6 +367,7 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase):
|
||||
param)
|
||||
checked_param.eval(feed_dict={param: np.ones([int(2**11+1)])})
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testUnsupportedDtype(self):
|
||||
with self.test_session():
|
||||
with self.assertRaises(TypeError):
|
||||
@ -396,6 +410,7 @@ class EmbedCheckIntegerCastingClosedTest(test.TestCase):
|
||||
x_checked.eval(feed_dict={x: np.array([1, -1], dtype=np.int32)})
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class LogCombinationsTest(test.TestCase):
|
||||
|
||||
def testLogCombinationsBinomial(self):
|
||||
@ -412,7 +427,7 @@ class LogCombinationsTest(test.TestCase):
|
||||
counts = [[1., 1], [2., 3], [4., 8], [11, 4]]
|
||||
log_binom = du.log_combinations(n, counts)
|
||||
self.assertEqual([4], log_binom.get_shape())
|
||||
self.assertAllClose(log_combs, log_binom.eval())
|
||||
self.assertAllClose(log_combs, self.evaluate(log_binom))
|
||||
|
||||
def testLogCombinationsShape(self):
|
||||
# Shape [2, 2]
|
||||
@ -537,14 +552,20 @@ class RotateTransposeTest(test.TestCase):
x = np.array(x)
return np.transpose(x, np.roll(np.arange(len(x.shape)), shift))

+ @test_util.run_in_graph_and_eager_modes()
def testRollStatic(self):
with self.test_session():
- with self.assertRaisesRegexp(ValueError, "None values not supported."):
+ if context.executing_eagerly():
+ error_message = r"Attempt to convert a value \(None\)"
+ else:
+ error_message = "None values not supported."
+ with self.assertRaisesRegexp(ValueError, error_message):
du.rotate_transpose(None, 1)
for x in (np.ones(1), np.ones((2, 1)), np.ones((3, 2, 1))):
for shift in np.arange(-5, 5):
y = du.rotate_transpose(x, shift)
- self.assertAllEqual(self._np_rotate_transpose(x, shift), y.eval())
+ self.assertAllEqual(
+ self._np_rotate_transpose(x, shift), self.evaluate(y))
self.assertAllEqual(np.roll(x.shape, shift), y.get_shape().as_list())

def testRollDynamic(self):
@ -569,12 +590,10 @@ class PickVectorTest(test.TestCase):
with self.test_session():
x = np.arange(10, 12)
y = np.arange(15, 18)
- self.assertAllEqual(x,
- du.pick_vector(
- math_ops.less(0, 5), x, y).eval())
- self.assertAllEqual(y,
- du.pick_vector(
- math_ops.less(5, 0), x, y).eval())
+ self.assertAllEqual(
+ x, self.evaluate(du.pick_vector(math_ops.less(0, 5), x, y)))
+ self.assertAllEqual(
+ y, self.evaluate(du.pick_vector(math_ops.less(5, 0), x, y)))
self.assertAllEqual(x,
du.pick_vector(
constant_op.constant(True), x, y))  # No eval.
@ -870,25 +889,25 @@ class ReduceWeightedLogSumExp(test.TestCase):
[1, 1, 1]])

self.assertAllClose(
- np.log(4),
- du.reduce_weighted_logsumexp(x, w).eval())
+ np.log(4), self.evaluate(du.reduce_weighted_logsumexp(x, w)))

with np.errstate(divide="ignore"):
self.assertAllClose(
np.log([0, 2, 2]),
- du.reduce_weighted_logsumexp(x, w, axis=0).eval())
+ self.evaluate(du.reduce_weighted_logsumexp(x, w, axis=0)))

self.assertAllClose(
np.log([1, 3]),
- du.reduce_weighted_logsumexp(x, w, axis=1).eval())
+ self.evaluate(du.reduce_weighted_logsumexp(x, w, axis=1)))

self.assertAllClose(
np.log([[1], [3]]),
- du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True).eval())
+ self.evaluate(
+ du.reduce_weighted_logsumexp(x, w, axis=1, keep_dims=True)))

self.assertAllClose(
np.log(4),
- du.reduce_weighted_logsumexp(x, w, axis=[0, 1]).eval())
+ self.evaluate(du.reduce_weighted_logsumexp(x, w, axis=[0, 1])))


class GenNewSeedTest(test.TestCase):
@ -986,7 +1005,7 @@ class SoftplusTest(test.TestCase):
# Note that this range contains both zero and inf.
x = constant_op.constant(np.logspace(-8, 6).astype(np.float16))
y = du.softplus_inverse(x)
- grads = gradients_impl.gradients(y, x)[0].eval()
+ grads = self.evaluate(gradients_impl.gradients(y, x)[0])
# Equivalent to `assertAllFalse` (if it existed).
self.assertAllEqual(np.zeros_like(grads).astype(np.bool), np.isnan(grads))

@ -996,11 +1015,13 @@ class SoftplusTest(test.TestCase):
# gradient and its approximations should be finite as well.
x = constant_op.constant(np.logspace(-4.8, 4.5).astype(np.float16))
y = du.softplus_inverse(x)
- grads = gradients_impl.gradients(y, x)[0].eval()
+ grads = self.evaluate(gradients_impl.gradients(y, x)[0])
# Equivalent to `assertAllTrue` (if it existed).
self.assertAllEqual(
np.ones_like(grads).astype(np.bool), np.isfinite(grads))


+ @test_util.run_all_in_graph_and_eager_modes
class ArgumentsTest(test.TestCase):

def testNoArguments(self):