contrib/distributions: Test code cleanups

- Remove unnecessary test_session() boilerplate when executing eagerly
- Use self.cached_session() instead of self.test_session() when using graphs

self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about:
* the fact that the session may be reused.
* the session is not closed even when doing a "with self.test_session()" statement.

PiperOrigin-RevId: 211542360
This commit is contained in:
Asim Shankar 2018-09-04 16:05:05 -07:00 committed by TensorFlower Gardener
parent 0065d3389a
commit ec6ea3ad0a
12 changed files with 1766 additions and 1979 deletions

View File

@ -62,57 +62,48 @@ class BernoulliTest(test.TestCase):
def testP(self): def testP(self):
p = [0.2, 0.4] p = [0.2, 0.4]
dist = bernoulli.Bernoulli(probs=p) dist = bernoulli.Bernoulli(probs=p)
with self.test_session():
self.assertAllClose(p, self.evaluate(dist.probs)) self.assertAllClose(p, self.evaluate(dist.probs))
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testLogits(self): def testLogits(self):
logits = [-42., 42.] logits = [-42., 42.]
dist = bernoulli.Bernoulli(logits=logits) dist = bernoulli.Bernoulli(logits=logits)
with self.test_session():
self.assertAllClose(logits, self.evaluate(dist.logits)) self.assertAllClose(logits, self.evaluate(dist.logits))
if not special: if not special:
return return
with self.test_session():
self.assertAllClose(special.expit(logits), self.evaluate(dist.probs)) self.assertAllClose(special.expit(logits), self.evaluate(dist.probs))
p = [0.01, 0.99, 0.42] p = [0.01, 0.99, 0.42]
dist = bernoulli.Bernoulli(probs=p) dist = bernoulli.Bernoulli(probs=p)
with self.test_session():
self.assertAllClose(special.logit(p), self.evaluate(dist.logits)) self.assertAllClose(special.logit(p), self.evaluate(dist.logits))
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testInvalidP(self): def testInvalidP(self):
invalid_ps = [1.01, 2.] invalid_ps = [1.01, 2.]
for p in invalid_ps: for p in invalid_ps:
with self.test_session():
with self.assertRaisesOpError("probs has components greater than 1"): with self.assertRaisesOpError("probs has components greater than 1"):
dist = bernoulli.Bernoulli(probs=p, validate_args=True) dist = bernoulli.Bernoulli(probs=p, validate_args=True)
self.evaluate(dist.probs) self.evaluate(dist.probs)
invalid_ps = [-0.01, -3.] invalid_ps = [-0.01, -3.]
for p in invalid_ps: for p in invalid_ps:
with self.test_session():
with self.assertRaisesOpError("Condition x >= 0"): with self.assertRaisesOpError("Condition x >= 0"):
dist = bernoulli.Bernoulli(probs=p, validate_args=True) dist = bernoulli.Bernoulli(probs=p, validate_args=True)
self.evaluate(dist.probs) self.evaluate(dist.probs)
valid_ps = [0.0, 0.5, 1.0] valid_ps = [0.0, 0.5, 1.0]
for p in valid_ps: for p in valid_ps:
with self.test_session():
dist = bernoulli.Bernoulli(probs=p) dist = bernoulli.Bernoulli(probs=p)
self.assertEqual(p, self.evaluate(dist.probs)) # Should not fail self.assertEqual(p, self.evaluate(dist.probs)) # Should not fail
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testShapes(self): def testShapes(self):
with self.test_session():
for batch_shape in ([], [1], [2, 3, 4]): for batch_shape in ([], [1], [2, 3, 4]):
dist = make_bernoulli(batch_shape) dist = make_bernoulli(batch_shape)
self.assertAllEqual(batch_shape, dist.batch_shape.as_list()) self.assertAllEqual(batch_shape, dist.batch_shape.as_list())
self.assertAllEqual(batch_shape, self.assertAllEqual(batch_shape, self.evaluate(dist.batch_shape_tensor()))
self.evaluate(dist.batch_shape_tensor()))
self.assertAllEqual([], dist.event_shape.as_list()) self.assertAllEqual([], dist.event_shape.as_list())
self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
@ -137,7 +128,6 @@ class BernoulliTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def _testPmf(self, **kwargs): def _testPmf(self, **kwargs):
dist = bernoulli.Bernoulli(**kwargs) dist = bernoulli.Bernoulli(**kwargs)
with self.test_session():
# pylint: disable=bad-continuation # pylint: disable=bad-continuation
xs = [ xs = [
0, 0,
@ -157,11 +147,10 @@ class BernoulliTest(test.TestCase):
for x, expected_pmf in zip(xs, expected_pmfs): for x, expected_pmf in zip(xs, expected_pmfs):
self.assertAllClose(self.evaluate(dist.prob(x)), expected_pmf) self.assertAllClose(self.evaluate(dist.prob(x)), expected_pmf)
self.assertAllClose( self.assertAllClose(self.evaluate(dist.log_prob(x)), np.log(expected_pmf))
self.evaluate(dist.log_prob(x)), np.log(expected_pmf))
def testPmfCorrectBroadcastDynamicShape(self): def testPmfCorrectBroadcastDynamicShape(self):
with self.test_session(): with self.cached_session():
p = array_ops.placeholder(dtype=dtypes.float32) p = array_ops.placeholder(dtype=dtypes.float32)
dist = bernoulli.Bernoulli(probs=p) dist = bernoulli.Bernoulli(probs=p)
event1 = [1, 0, 1] event1 = [1, 0, 1]
@ -178,7 +167,6 @@ class BernoulliTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testPmfInvalid(self): def testPmfInvalid(self):
p = [0.1, 0.2, 0.7] p = [0.1, 0.2, 0.7]
with self.test_session():
dist = bernoulli.Bernoulli(probs=p, validate_args=True) dist = bernoulli.Bernoulli(probs=p, validate_args=True)
with self.assertRaisesOpError("must be non-negative."): with self.assertRaisesOpError("must be non-negative."):
self.evaluate(dist.prob([1, 1, -1])) self.evaluate(dist.prob([1, 1, -1]))
@ -194,7 +182,7 @@ class BernoulliTest(test.TestCase):
self._testPmf(logits=special.logit(p)) self._testPmf(logits=special.logit(p))
def testBroadcasting(self): def testBroadcasting(self):
with self.test_session(): with self.cached_session():
p = array_ops.placeholder(dtypes.float32) p = array_ops.placeholder(dtypes.float32)
dist = bernoulli.Bernoulli(probs=p) dist = bernoulli.Bernoulli(probs=p)
self.assertAllClose(np.log(0.5), dist.log_prob(1).eval({p: 0.5})) self.assertAllClose(np.log(0.5), dist.log_prob(1).eval({p: 0.5}))
@ -208,28 +196,24 @@ class BernoulliTest(test.TestCase):
})) }))
def testPmfShapes(self): def testPmfShapes(self):
with self.test_session(): with self.cached_session():
p = array_ops.placeholder(dtypes.float32, shape=[None, 1]) p = array_ops.placeholder(dtypes.float32, shape=[None, 1])
dist = bernoulli.Bernoulli(probs=p) dist = bernoulli.Bernoulli(probs=p)
self.assertEqual(2, len(dist.log_prob(1).eval({p: [[0.5], [0.5]]}).shape)) self.assertEqual(2, len(dist.log_prob(1).eval({p: [[0.5], [0.5]]}).shape))
with self.test_session():
dist = bernoulli.Bernoulli(probs=0.5) dist = bernoulli.Bernoulli(probs=0.5)
self.assertEqual(2, len(self.evaluate(dist.log_prob([[1], [1]])).shape)) self.assertEqual(2, len(self.evaluate(dist.log_prob([[1], [1]])).shape))
with self.test_session():
dist = bernoulli.Bernoulli(probs=0.5) dist = bernoulli.Bernoulli(probs=0.5)
self.assertEqual((), dist.log_prob(1).get_shape()) self.assertEqual((), dist.log_prob(1).get_shape())
self.assertEqual((1), dist.log_prob([1]).get_shape()) self.assertEqual((1), dist.log_prob([1]).get_shape())
self.assertEqual((2, 1), dist.log_prob([[1], [1]]).get_shape()) self.assertEqual((2, 1), dist.log_prob([[1], [1]]).get_shape())
with self.test_session():
dist = bernoulli.Bernoulli(probs=[[0.5], [0.5]]) dist = bernoulli.Bernoulli(probs=[[0.5], [0.5]])
self.assertEqual((2, 1), dist.log_prob(1).get_shape()) self.assertEqual((2, 1), dist.log_prob(1).get_shape())
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testBoundaryConditions(self): def testBoundaryConditions(self):
with self.test_session():
dist = bernoulli.Bernoulli(probs=1.0) dist = bernoulli.Bernoulli(probs=1.0)
self.assertAllClose(np.nan, self.evaluate(dist.log_prob(0))) self.assertAllClose(np.nan, self.evaluate(dist.log_prob(0)))
self.assertAllClose([np.nan], [self.evaluate(dist.log_prob(1))]) self.assertAllClose([np.nan], [self.evaluate(dist.log_prob(1))])
@ -238,14 +222,12 @@ class BernoulliTest(test.TestCase):
def testEntropyNoBatch(self): def testEntropyNoBatch(self):
p = 0.2 p = 0.2
dist = bernoulli.Bernoulli(probs=p) dist = bernoulli.Bernoulli(probs=p)
with self.test_session():
self.assertAllClose(self.evaluate(dist.entropy()), entropy(p)) self.assertAllClose(self.evaluate(dist.entropy()), entropy(p))
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testEntropyWithBatch(self): def testEntropyWithBatch(self):
p = [[0.1, 0.7], [0.2, 0.6]] p = [[0.1, 0.7], [0.2, 0.6]]
dist = bernoulli.Bernoulli(probs=p, validate_args=False) dist = bernoulli.Bernoulli(probs=p, validate_args=False)
with self.test_session():
self.assertAllClose( self.assertAllClose(
self.evaluate(dist.entropy()), self.evaluate(dist.entropy()),
[[entropy(0.1), entropy(0.7)], [entropy(0.2), [[entropy(0.1), entropy(0.7)], [entropy(0.2),
@ -253,7 +235,6 @@ class BernoulliTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testSampleN(self): def testSampleN(self):
with self.test_session():
p = [0.2, 0.6] p = [0.2, 0.6]
dist = bernoulli.Bernoulli(probs=p) dist = bernoulli.Bernoulli(probs=p)
n = 100000 n = 100000
@ -284,7 +265,7 @@ class BernoulliTest(test.TestCase):
self.assertIsNone(grad_p) self.assertIsNone(grad_p)
def testSampleActsLikeSampleN(self): def testSampleActsLikeSampleN(self):
with self.test_session() as sess: with self.cached_session() as sess:
p = [0.2, 0.6] p = [0.2, 0.6]
dist = bernoulli.Bernoulli(probs=p) dist = bernoulli.Bernoulli(probs=p)
n = 1000 n = 1000
@ -299,7 +280,6 @@ class BernoulliTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testMean(self): def testMean(self):
with self.test_session():
p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32) p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32)
dist = bernoulli.Bernoulli(probs=p) dist = bernoulli.Bernoulli(probs=p)
self.assertAllEqual(self.evaluate(dist.mean()), p) self.assertAllEqual(self.evaluate(dist.mean()), p)
@ -307,17 +287,15 @@ class BernoulliTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testVarianceAndStd(self): def testVarianceAndStd(self):
var = lambda p: p * (1. - p) var = lambda p: p * (1. - p)
with self.test_session():
p = [[0.2, 0.7], [0.5, 0.4]] p = [[0.2, 0.7], [0.5, 0.4]]
dist = bernoulli.Bernoulli(probs=p) dist = bernoulli.Bernoulli(probs=p)
self.assertAllClose( self.assertAllClose(
self.evaluate(dist.variance()), self.evaluate(dist.variance()),
np.array( np.array([[var(0.2), var(0.7)], [var(0.5), var(0.4)]],
[[var(0.2), var(0.7)], [var(0.5), var(0.4)]], dtype=np.float32)) dtype=np.float32))
self.assertAllClose( self.assertAllClose(
self.evaluate(dist.stddev()), self.evaluate(dist.stddev()),
np.array( np.array([[np.sqrt(var(0.2)), np.sqrt(var(0.7))],
[[np.sqrt(var(0.2)), np.sqrt(var(0.7))],
[np.sqrt(var(0.5)), np.sqrt(var(0.4))]], [np.sqrt(var(0.5)), np.sqrt(var(0.4))]],
dtype=np.float32)) dtype=np.float32))

View File

@ -20,7 +20,6 @@ import importlib
import numpy as np import numpy as np
from tensorflow.python.client import session
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import random_seed from tensorflow.python.framework import random_seed
@ -51,7 +50,6 @@ stats = try_import("scipy.stats")
class BetaTest(test.TestCase): class BetaTest(test.TestCase):
def testSimpleShapes(self): def testSimpleShapes(self):
with self.test_session():
a = np.random.rand(3) a = np.random.rand(3)
b = np.random.rand(3) b = np.random.rand(3)
dist = beta_lib.Beta(a, b) dist = beta_lib.Beta(a, b)
@ -61,31 +59,26 @@ class BetaTest(test.TestCase):
self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape) self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape)
def testComplexShapes(self): def testComplexShapes(self):
with self.test_session():
a = np.random.rand(3, 2, 2) a = np.random.rand(3, 2, 2)
b = np.random.rand(3, 2, 2) b = np.random.rand(3, 2, 2)
dist = beta_lib.Beta(a, b) dist = beta_lib.Beta(a, b)
self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor())) self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
self.assertEqual( self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
def testComplexShapesBroadcast(self): def testComplexShapesBroadcast(self):
with self.test_session():
a = np.random.rand(3, 2, 2) a = np.random.rand(3, 2, 2)
b = np.random.rand(2, 2) b = np.random.rand(2, 2)
dist = beta_lib.Beta(a, b) dist = beta_lib.Beta(a, b)
self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor())) self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
self.assertEqual( self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
def testAlphaProperty(self): def testAlphaProperty(self):
a = [[1., 2, 3]] a = [[1., 2, 3]]
b = [[2., 4, 3]] b = [[2., 4, 3]]
with self.test_session():
dist = beta_lib.Beta(a, b) dist = beta_lib.Beta(a, b)
self.assertEqual([1, 3], dist.concentration1.get_shape()) self.assertEqual([1, 3], dist.concentration1.get_shape())
self.assertAllClose(a, self.evaluate(dist.concentration1)) self.assertAllClose(a, self.evaluate(dist.concentration1))
@ -93,7 +86,6 @@ class BetaTest(test.TestCase):
def testBetaProperty(self): def testBetaProperty(self):
a = [[1., 2, 3]] a = [[1., 2, 3]]
b = [[2., 4, 3]] b = [[2., 4, 3]]
with self.test_session():
dist = beta_lib.Beta(a, b) dist = beta_lib.Beta(a, b)
self.assertEqual([1, 3], dist.concentration0.get_shape()) self.assertEqual([1, 3], dist.concentration0.get_shape())
self.assertAllClose(b, self.evaluate(dist.concentration0)) self.assertAllClose(b, self.evaluate(dist.concentration0))
@ -101,7 +93,6 @@ class BetaTest(test.TestCase):
def testPdfXProper(self): def testPdfXProper(self):
a = [[1., 2, 3]] a = [[1., 2, 3]]
b = [[2., 4, 3]] b = [[2., 4, 3]]
with self.test_session():
dist = beta_lib.Beta(a, b, validate_args=True) dist = beta_lib.Beta(a, b, validate_args=True)
self.evaluate(dist.prob([.1, .3, .6])) self.evaluate(dist.prob([.1, .3, .6]))
self.evaluate(dist.prob([.2, .3, .5])) self.evaluate(dist.prob([.2, .3, .5]))
@ -116,7 +107,6 @@ class BetaTest(test.TestCase):
self.evaluate(dist.prob([.1, .2, 1.0])) self.evaluate(dist.prob([.1, .2, 1.0]))
def testPdfTwoBatches(self): def testPdfTwoBatches(self):
with self.test_session():
a = [1., 2] a = [1., 2]
b = [1., 2] b = [1., 2]
x = [.5, .5] x = [.5, .5]
@ -126,7 +116,6 @@ class BetaTest(test.TestCase):
self.assertEqual((2,), pdf.get_shape()) self.assertEqual((2,), pdf.get_shape())
def testPdfTwoBatchesNontrivialX(self): def testPdfTwoBatchesNontrivialX(self):
with self.test_session():
a = [1., 2] a = [1., 2]
b = [1., 2] b = [1., 2]
x = [.3, .7] x = [.3, .7]
@ -136,7 +125,6 @@ class BetaTest(test.TestCase):
self.assertEqual((2,), pdf.get_shape()) self.assertEqual((2,), pdf.get_shape())
def testPdfUniformZeroBatch(self): def testPdfUniformZeroBatch(self):
with self.test_session():
# This is equivalent to a uniform distribution # This is equivalent to a uniform distribution
a = 1. a = 1.
b = 1. b = 1.
@ -147,7 +135,6 @@ class BetaTest(test.TestCase):
self.assertEqual((5,), pdf.get_shape()) self.assertEqual((5,), pdf.get_shape())
def testPdfAlphaStretchedInBroadcastWhenSameRank(self): def testPdfAlphaStretchedInBroadcastWhenSameRank(self):
with self.test_session():
a = [[1., 2]] a = [[1., 2]]
b = [[1., 2]] b = [[1., 2]]
x = [[.5, .5], [.3, .7]] x = [[.5, .5], [.3, .7]]
@ -157,7 +144,6 @@ class BetaTest(test.TestCase):
self.assertEqual((2, 2), pdf.get_shape()) self.assertEqual((2, 2), pdf.get_shape())
def testPdfAlphaStretchedInBroadcastWhenLowerRank(self): def testPdfAlphaStretchedInBroadcastWhenLowerRank(self):
with self.test_session():
a = [1., 2] a = [1., 2]
b = [1., 2] b = [1., 2]
x = [[.5, .5], [.2, .8]] x = [[.5, .5], [.2, .8]]
@ -166,7 +152,6 @@ class BetaTest(test.TestCase):
self.assertEqual((2, 2), pdf.get_shape()) self.assertEqual((2, 2), pdf.get_shape())
def testPdfXStretchedInBroadcastWhenSameRank(self): def testPdfXStretchedInBroadcastWhenSameRank(self):
with self.test_session():
a = [[1., 2], [2., 3]] a = [[1., 2], [2., 3]]
b = [[1., 2], [2., 3]] b = [[1., 2], [2., 3]]
x = [[.5, .5]] x = [[.5, .5]]
@ -175,7 +160,6 @@ class BetaTest(test.TestCase):
self.assertEqual((2, 2), pdf.get_shape()) self.assertEqual((2, 2), pdf.get_shape())
def testPdfXStretchedInBroadcastWhenLowerRank(self): def testPdfXStretchedInBroadcastWhenLowerRank(self):
with self.test_session():
a = [[1., 2], [2., 3]] a = [[1., 2], [2., 3]]
b = [[1., 2], [2., 3]] b = [[1., 2], [2., 3]]
x = [.5, .5] x = [.5, .5]
@ -184,7 +168,6 @@ class BetaTest(test.TestCase):
self.assertEqual((2, 2), pdf.get_shape()) self.assertEqual((2, 2), pdf.get_shape())
def testBetaMean(self): def testBetaMean(self):
with session.Session():
a = [1., 2, 3] a = [1., 2, 3]
b = [2., 4, 1.2] b = [2., 4, 1.2]
dist = beta_lib.Beta(a, b) dist = beta_lib.Beta(a, b)
@ -195,7 +178,6 @@ class BetaTest(test.TestCase):
self.assertAllClose(expected_mean, self.evaluate(dist.mean())) self.assertAllClose(expected_mean, self.evaluate(dist.mean()))
def testBetaVariance(self): def testBetaVariance(self):
with session.Session():
a = [1., 2, 3] a = [1., 2, 3]
b = [2., 4, 1.2] b = [2., 4, 1.2]
dist = beta_lib.Beta(a, b) dist = beta_lib.Beta(a, b)
@ -206,7 +188,6 @@ class BetaTest(test.TestCase):
self.assertAllClose(expected_variance, self.evaluate(dist.variance())) self.assertAllClose(expected_variance, self.evaluate(dist.variance()))
def testBetaMode(self): def testBetaMode(self):
with session.Session():
a = np.array([1.1, 2, 3]) a = np.array([1.1, 2, 3])
b = np.array([2., 4, 1.2]) b = np.array([2., 4, 1.2])
expected_mode = (a - 1) / (a + b - 2) expected_mode = (a - 1) / (a + b - 2)
@ -215,7 +196,6 @@ class BetaTest(test.TestCase):
self.assertAllClose(expected_mode, self.evaluate(dist.mode())) self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
def testBetaModeInvalid(self): def testBetaModeInvalid(self):
with session.Session():
a = np.array([1., 2, 3]) a = np.array([1., 2, 3])
b = np.array([2., 4, 1.2]) b = np.array([2., 4, 1.2])
dist = beta_lib.Beta(a, b, allow_nan_stats=False) dist = beta_lib.Beta(a, b, allow_nan_stats=False)
@ -229,7 +209,6 @@ class BetaTest(test.TestCase):
self.evaluate(dist.mode()) self.evaluate(dist.mode())
def testBetaModeEnableAllowNanStats(self): def testBetaModeEnableAllowNanStats(self):
with session.Session():
a = np.array([1., 2, 3]) a = np.array([1., 2, 3])
b = np.array([2., 4, 1.2]) b = np.array([2., 4, 1.2])
dist = beta_lib.Beta(a, b, allow_nan_stats=True) dist = beta_lib.Beta(a, b, allow_nan_stats=True)
@ -249,7 +228,6 @@ class BetaTest(test.TestCase):
self.assertAllClose(expected_mode, self.evaluate(dist.mode())) self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
def testBetaEntropy(self): def testBetaEntropy(self):
with session.Session():
a = [1., 2, 3] a = [1., 2, 3]
b = [2., 4, 1.2] b = [2., 4, 1.2]
dist = beta_lib.Beta(a, b) dist = beta_lib.Beta(a, b)
@ -260,7 +238,6 @@ class BetaTest(test.TestCase):
self.assertAllClose(expected_entropy, self.evaluate(dist.entropy())) self.assertAllClose(expected_entropy, self.evaluate(dist.entropy()))
def testBetaSample(self): def testBetaSample(self):
with self.test_session():
a = 1. a = 1.
b = 2. b = 2.
beta = beta_lib.Beta(a, b) beta = beta_lib.Beta(a, b)
@ -297,27 +274,23 @@ class BetaTest(test.TestCase):
# Test that sampling with the same seed twice gives the same results. # Test that sampling with the same seed twice gives the same results.
def testBetaSampleMultipleTimes(self): def testBetaSampleMultipleTimes(self):
with self.test_session():
a_val = 1. a_val = 1.
b_val = 2. b_val = 2.
n_val = 100 n_val = 100
random_seed.set_random_seed(654321) random_seed.set_random_seed(654321)
beta1 = beta_lib.Beta(concentration1=a_val, beta1 = beta_lib.Beta(
concentration0=b_val, concentration1=a_val, concentration0=b_val, name="beta1")
name="beta1")
samples1 = self.evaluate(beta1.sample(n_val, seed=123456)) samples1 = self.evaluate(beta1.sample(n_val, seed=123456))
random_seed.set_random_seed(654321) random_seed.set_random_seed(654321)
beta2 = beta_lib.Beta(concentration1=a_val, beta2 = beta_lib.Beta(
concentration0=b_val, concentration1=a_val, concentration0=b_val, name="beta2")
name="beta2")
samples2 = self.evaluate(beta2.sample(n_val, seed=123456)) samples2 = self.evaluate(beta2.sample(n_val, seed=123456))
self.assertAllClose(samples1, samples2) self.assertAllClose(samples1, samples2)
def testBetaSampleMultidimensional(self): def testBetaSampleMultidimensional(self):
with self.test_session():
a = np.random.rand(3, 2, 2).astype(np.float32) a = np.random.rand(3, 2, 2).astype(np.float32)
b = np.random.rand(3, 2, 2).astype(np.float32) b = np.random.rand(3, 2, 2).astype(np.float32)
beta = beta_lib.Beta(a, b) beta = beta_lib.Beta(a, b)
@ -334,7 +307,6 @@ class BetaTest(test.TestCase):
atol=1e-1) atol=1e-1)
def testBetaCdf(self): def testBetaCdf(self):
with self.test_session():
shape = (30, 40, 50) shape = (30, 40, 50)
for dt in (np.float32, np.float64): for dt in (np.float32, np.float64):
a = 10. * np.random.random(shape).astype(dt) a = 10. * np.random.random(shape).astype(dt)
@ -348,7 +320,6 @@ class BetaTest(test.TestCase):
self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0) self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
def testBetaLogCdf(self): def testBetaLogCdf(self):
with self.test_session():
shape = (30, 40, 50) shape = (30, 40, 50)
for dt in (np.float32, np.float64): for dt in (np.float32, np.float64):
a = 10. * np.random.random(shape).astype(dt) a = 10. * np.random.random(shape).astype(dt)
@ -362,7 +333,6 @@ class BetaTest(test.TestCase):
self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0) self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
def testBetaWithSoftplusConcentration(self): def testBetaWithSoftplusConcentration(self):
with self.test_session():
a, b = -4.2, -9.1 a, b = -4.2, -9.1
dist = beta_lib.BetaWithSoftplusConcentration(a, b) dist = beta_lib.BetaWithSoftplusConcentration(a, b)
self.assertAllClose( self.assertAllClose(

View File

@ -36,7 +36,6 @@ class BaseBijectorTest(test.TestCase):
"""Tests properties of the Bijector base-class.""" """Tests properties of the Bijector base-class."""
def testIsAbstract(self): def testIsAbstract(self):
with self.test_session():
with self.assertRaisesRegexp(TypeError, with self.assertRaisesRegexp(TypeError,
("Can't instantiate abstract class Bijector " ("Can't instantiate abstract class Bijector "
"with abstract methods __init__")): "with abstract methods __init__")):
@ -136,7 +135,7 @@ class BijectorTestEventNdims(test.TestCase):
def testBijectorDynamicEventNdims(self): def testBijectorDynamicEventNdims(self):
bij = BrokenBijector(validate_args=True) bij = BrokenBijector(validate_args=True)
event_ndims = array_ops.placeholder(dtype=np.int32, shape=None) event_ndims = array_ops.placeholder(dtype=np.int32, shape=None)
with self.test_session(): with self.cached_session():
with self.assertRaisesOpError("Expected scalar"): with self.assertRaisesOpError("Expected scalar"):
bij.forward_log_det_jacobian(1., event_ndims=event_ndims).eval({ bij.forward_log_det_jacobian(1., event_ndims=event_ndims).eval({
event_ndims: (1, 2)}) event_ndims: (1, 2)})
@ -308,7 +307,7 @@ class BijectorReduceEventDimsTest(test.TestCase):
event_ndims = array_ops.placeholder(dtype=np.int32, shape=[]) event_ndims = array_ops.placeholder(dtype=np.int32, shape=[])
bij = ExpOnlyJacobian(forward_min_event_ndims=1) bij = ExpOnlyJacobian(forward_min_event_ndims=1)
bij.inverse_log_det_jacobian(x, event_ndims=event_ndims) bij.inverse_log_det_jacobian(x, event_ndims=event_ndims)
with self.test_session() as sess: with self.cached_session() as sess:
ildj = sess.run(bij.inverse_log_det_jacobian(x, event_ndims=event_ndims), ildj = sess.run(bij.inverse_log_det_jacobian(x, event_ndims=event_ndims),
feed_dict={event_ndims: 1}) feed_dict={event_ndims: 1})
self.assertAllClose(-np.log(x_), ildj) self.assertAllClose(-np.log(x_), ildj)

View File

@ -49,7 +49,6 @@ stats = try_import("scipy.stats")
class DirichletTest(test.TestCase): class DirichletTest(test.TestCase):
def testSimpleShapes(self): def testSimpleShapes(self):
with self.test_session():
alpha = np.random.rand(3) alpha = np.random.rand(3)
dist = dirichlet_lib.Dirichlet(alpha) dist = dirichlet_lib.Dirichlet(alpha)
self.assertEqual(3, self.evaluate(dist.event_shape_tensor())) self.assertEqual(3, self.evaluate(dist.event_shape_tensor()))
@ -58,7 +57,6 @@ class DirichletTest(test.TestCase):
self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape) self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)
def testComplexShapes(self): def testComplexShapes(self):
with self.test_session():
alpha = np.random.rand(3, 2, 2) alpha = np.random.rand(3, 2, 2)
dist = dirichlet_lib.Dirichlet(alpha) dist = dirichlet_lib.Dirichlet(alpha)
self.assertEqual(2, self.evaluate(dist.event_shape_tensor())) self.assertEqual(2, self.evaluate(dist.event_shape_tensor()))
@ -68,14 +66,12 @@ class DirichletTest(test.TestCase):
def testConcentrationProperty(self): def testConcentrationProperty(self):
alpha = [[1., 2, 3]] alpha = [[1., 2, 3]]
with self.test_session():
dist = dirichlet_lib.Dirichlet(alpha) dist = dirichlet_lib.Dirichlet(alpha)
self.assertEqual([1, 3], dist.concentration.get_shape()) self.assertEqual([1, 3], dist.concentration.get_shape())
self.assertAllClose(alpha, self.evaluate(dist.concentration)) self.assertAllClose(alpha, self.evaluate(dist.concentration))
def testPdfXProper(self): def testPdfXProper(self):
alpha = [[1., 2, 3]] alpha = [[1., 2, 3]]
with self.test_session():
dist = dirichlet_lib.Dirichlet(alpha, validate_args=True) dist = dirichlet_lib.Dirichlet(alpha, validate_args=True)
self.evaluate(dist.prob([.1, .3, .6])) self.evaluate(dist.prob([.1, .3, .6]))
self.evaluate(dist.prob([.2, .3, .5])) self.evaluate(dist.prob([.2, .3, .5]))
@ -84,12 +80,10 @@ class DirichletTest(test.TestCase):
self.evaluate(dist.prob([-1., 1.5, 0.5])) self.evaluate(dist.prob([-1., 1.5, 0.5]))
with self.assertRaisesOpError("samples must be positive"): with self.assertRaisesOpError("samples must be positive"):
self.evaluate(dist.prob([0., .1, .9])) self.evaluate(dist.prob([0., .1, .9]))
with self.assertRaisesOpError( with self.assertRaisesOpError("sample last-dimension must sum to `1`"):
"sample last-dimension must sum to `1`"):
self.evaluate(dist.prob([.1, .2, .8])) self.evaluate(dist.prob([.1, .2, .8]))
def testPdfZeroBatches(self): def testPdfZeroBatches(self):
with self.test_session():
alpha = [1., 2] alpha = [1., 2]
x = [.5, .5] x = [.5, .5]
dist = dirichlet_lib.Dirichlet(alpha) dist = dirichlet_lib.Dirichlet(alpha)
@ -98,7 +92,6 @@ class DirichletTest(test.TestCase):
self.assertEqual((), pdf.get_shape()) self.assertEqual((), pdf.get_shape())
def testPdfZeroBatchesNontrivialX(self): def testPdfZeroBatchesNontrivialX(self):
with self.test_session():
alpha = [1., 2] alpha = [1., 2]
x = [.3, .7] x = [.3, .7]
dist = dirichlet_lib.Dirichlet(alpha) dist = dirichlet_lib.Dirichlet(alpha)
@ -107,7 +100,6 @@ class DirichletTest(test.TestCase):
self.assertEqual((), pdf.get_shape()) self.assertEqual((), pdf.get_shape())
def testPdfUniformZeroBatches(self): def testPdfUniformZeroBatches(self):
with self.test_session():
# Corresponds to a uniform distribution # Corresponds to a uniform distribution
alpha = [1., 1, 1] alpha = [1., 1, 1]
x = [[.2, .5, .3], [.3, .4, .3]] x = [[.2, .5, .3], [.3, .4, .3]]
@ -117,7 +109,6 @@ class DirichletTest(test.TestCase):
self.assertEqual((2), pdf.get_shape()) self.assertEqual((2), pdf.get_shape())
def testPdfAlphaStretchedInBroadcastWhenSameRank(self): def testPdfAlphaStretchedInBroadcastWhenSameRank(self):
with self.test_session():
alpha = [[1., 2]] alpha = [[1., 2]]
x = [[.5, .5], [.3, .7]] x = [[.5, .5], [.3, .7]]
dist = dirichlet_lib.Dirichlet(alpha) dist = dirichlet_lib.Dirichlet(alpha)
@ -126,7 +117,6 @@ class DirichletTest(test.TestCase):
self.assertEqual((2), pdf.get_shape()) self.assertEqual((2), pdf.get_shape())
def testPdfAlphaStretchedInBroadcastWhenLowerRank(self): def testPdfAlphaStretchedInBroadcastWhenLowerRank(self):
with self.test_session():
alpha = [1., 2] alpha = [1., 2]
x = [[.5, .5], [.2, .8]] x = [[.5, .5], [.2, .8]]
pdf = dirichlet_lib.Dirichlet(alpha).prob(x) pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
@ -134,7 +124,6 @@ class DirichletTest(test.TestCase):
self.assertEqual((2), pdf.get_shape()) self.assertEqual((2), pdf.get_shape())
def testPdfXStretchedInBroadcastWhenSameRank(self): def testPdfXStretchedInBroadcastWhenSameRank(self):
with self.test_session():
alpha = [[1., 2], [2., 3]] alpha = [[1., 2], [2., 3]]
x = [[.5, .5]] x = [[.5, .5]]
pdf = dirichlet_lib.Dirichlet(alpha).prob(x) pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
@ -142,7 +131,6 @@ class DirichletTest(test.TestCase):
self.assertEqual((2), pdf.get_shape()) self.assertEqual((2), pdf.get_shape())
def testPdfXStretchedInBroadcastWhenLowerRank(self): def testPdfXStretchedInBroadcastWhenLowerRank(self):
with self.test_session():
alpha = [[1., 2], [2., 3]] alpha = [[1., 2], [2., 3]]
x = [.5, .5] x = [.5, .5]
pdf = dirichlet_lib.Dirichlet(alpha).prob(x) pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
@ -150,7 +138,6 @@ class DirichletTest(test.TestCase):
self.assertEqual((2), pdf.get_shape()) self.assertEqual((2), pdf.get_shape())
def testMean(self): def testMean(self):
with self.test_session():
alpha = [1., 2, 3] alpha = [1., 2, 3]
dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
self.assertEqual(dirichlet.mean().get_shape(), [3]) self.assertEqual(dirichlet.mean().get_shape(), [3])
@ -197,7 +184,6 @@ class DirichletTest(test.TestCase):
self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.) self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.)
def testVariance(self): def testVariance(self):
with self.test_session():
alpha = [1., 2, 3] alpha = [1., 2, 3]
denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1) denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1)
dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
@ -205,13 +191,12 @@ class DirichletTest(test.TestCase):
if not stats: if not stats:
return return
expected_covariance = np.diag(stats.dirichlet.var(alpha)) expected_covariance = np.diag(stats.dirichlet.var(alpha))
expected_covariance += [[0., -2, -3], [-2, 0, -6], expected_covariance += [[0., -2, -3], [-2, 0, -6], [-3, -6, 0]
[-3, -6, 0]] / denominator ] / denominator
self.assertAllClose( self.assertAllClose(
self.evaluate(dirichlet.covariance()), expected_covariance) self.evaluate(dirichlet.covariance()), expected_covariance)
def testMode(self): def testMode(self):
with self.test_session():
alpha = np.array([1.1, 2, 3]) alpha = np.array([1.1, 2, 3])
expected_mode = (alpha - 1) / (np.sum(alpha) - 3) expected_mode = (alpha - 1) / (np.sum(alpha) - 3)
dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
@ -219,25 +204,22 @@ class DirichletTest(test.TestCase):
self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode) self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)
def testModeInvalid(self): def testModeInvalid(self):
with self.test_session():
alpha = np.array([1., 2, 3]) alpha = np.array([1., 2, 3])
dirichlet = dirichlet_lib.Dirichlet(concentration=alpha, dirichlet = dirichlet_lib.Dirichlet(
allow_nan_stats=False) concentration=alpha, allow_nan_stats=False)
with self.assertRaisesOpError("Condition x < y.*"): with self.assertRaisesOpError("Condition x < y.*"):
self.evaluate(dirichlet.mode()) self.evaluate(dirichlet.mode())
def testModeEnableAllowNanStats(self): def testModeEnableAllowNanStats(self):
with self.test_session():
alpha = np.array([1., 2, 3]) alpha = np.array([1., 2, 3])
dirichlet = dirichlet_lib.Dirichlet(concentration=alpha, dirichlet = dirichlet_lib.Dirichlet(
allow_nan_stats=True) concentration=alpha, allow_nan_stats=True)
expected_mode = np.zeros_like(alpha) + np.nan expected_mode = np.zeros_like(alpha) + np.nan
self.assertEqual(dirichlet.mode().get_shape(), [3]) self.assertEqual(dirichlet.mode().get_shape(), [3])
self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode) self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)
def testEntropy(self): def testEntropy(self):
with self.test_session():
alpha = [1., 2, 3] alpha = [1., 2, 3]
dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
self.assertEqual(dirichlet.entropy().get_shape(), ()) self.assertEqual(dirichlet.entropy().get_shape(), ())
@ -247,7 +229,6 @@ class DirichletTest(test.TestCase):
self.assertAllClose(self.evaluate(dirichlet.entropy()), expected_entropy) self.assertAllClose(self.evaluate(dirichlet.entropy()), expected_entropy)
def testSample(self): def testSample(self):
with self.test_session():
alpha = [1., 2] alpha = [1., 2]
dirichlet = dirichlet_lib.Dirichlet(alpha) dirichlet = dirichlet_lib.Dirichlet(alpha)
n = constant_op.constant(100000) n = constant_op.constant(100000)
@ -261,8 +242,7 @@ class DirichletTest(test.TestCase):
stats.kstest( stats.kstest(
# Beta is a univariate distribution. # Beta is a univariate distribution.
sample_values[:, 0], sample_values[:, 0],
stats.beta( stats.beta(a=1., b=2.).cdf)[0],
a=1., b=2.).cdf)[0],
0.01) 0.01)
def testDirichletFullyReparameterized(self): def testDirichletFullyReparameterized(self):

View File

@ -22,7 +22,6 @@ import importlib
import numpy as np import numpy as np
from tensorflow.python.client import session
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
@ -48,7 +47,6 @@ stats = try_import("scipy.stats")
class ExponentialTest(test.TestCase): class ExponentialTest(test.TestCase):
def testExponentialLogPDF(self): def testExponentialLogPDF(self):
with session.Session():
batch_size = 6 batch_size = 6
lam = constant_op.constant([2.0] * batch_size) lam = constant_op.constant([2.0] * batch_size)
lam_v = 2.0 lam_v = 2.0
@ -68,7 +66,6 @@ class ExponentialTest(test.TestCase):
self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf)) self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
def testExponentialCDF(self): def testExponentialCDF(self):
with session.Session():
batch_size = 6 batch_size = 6
lam = constant_op.constant([2.0] * batch_size) lam = constant_op.constant([2.0] * batch_size)
lam_v = 2.0 lam_v = 2.0
@ -85,7 +82,6 @@ class ExponentialTest(test.TestCase):
self.assertAllClose(self.evaluate(cdf), expected_cdf) self.assertAllClose(self.evaluate(cdf), expected_cdf)
def testExponentialMean(self): def testExponentialMean(self):
with session.Session():
lam_v = np.array([1.0, 4.0, 2.5]) lam_v = np.array([1.0, 4.0, 2.5])
exponential = exponential_lib.Exponential(rate=lam_v) exponential = exponential_lib.Exponential(rate=lam_v)
self.assertEqual(exponential.mean().get_shape(), (3,)) self.assertEqual(exponential.mean().get_shape(), (3,))
@ -95,7 +91,6 @@ class ExponentialTest(test.TestCase):
self.assertAllClose(self.evaluate(exponential.mean()), expected_mean) self.assertAllClose(self.evaluate(exponential.mean()), expected_mean)
def testExponentialVariance(self): def testExponentialVariance(self):
with session.Session():
lam_v = np.array([1.0, 4.0, 2.5]) lam_v = np.array([1.0, 4.0, 2.5])
exponential = exponential_lib.Exponential(rate=lam_v) exponential = exponential_lib.Exponential(rate=lam_v)
self.assertEqual(exponential.variance().get_shape(), (3,)) self.assertEqual(exponential.variance().get_shape(), (3,))
@ -106,18 +101,15 @@ class ExponentialTest(test.TestCase):
self.evaluate(exponential.variance()), expected_variance) self.evaluate(exponential.variance()), expected_variance)
def testExponentialEntropy(self): def testExponentialEntropy(self):
with session.Session():
lam_v = np.array([1.0, 4.0, 2.5]) lam_v = np.array([1.0, 4.0, 2.5])
exponential = exponential_lib.Exponential(rate=lam_v) exponential = exponential_lib.Exponential(rate=lam_v)
self.assertEqual(exponential.entropy().get_shape(), (3,)) self.assertEqual(exponential.entropy().get_shape(), (3,))
if not stats: if not stats:
return return
expected_entropy = stats.expon.entropy(scale=1 / lam_v) expected_entropy = stats.expon.entropy(scale=1 / lam_v)
self.assertAllClose( self.assertAllClose(self.evaluate(exponential.entropy()), expected_entropy)
self.evaluate(exponential.entropy()), expected_entropy)
def testExponentialSample(self): def testExponentialSample(self):
with self.test_session():
lam = constant_op.constant([3.0, 4.0]) lam = constant_op.constant([3.0, 4.0])
lam_v = [3.0, 4.0] lam_v = [3.0, 4.0]
n = constant_op.constant(100000) n = constant_op.constant(100000)
@ -131,12 +123,10 @@ class ExponentialTest(test.TestCase):
return return
for i in range(2): for i in range(2):
self.assertLess( self.assertLess(
stats.kstest( stats.kstest(sample_values[:, i],
sample_values[:, i], stats.expon(scale=1.0 / lam_v[i]).cdf)[0], stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01)
0.01)
def testExponentialSampleMultiDimensional(self): def testExponentialSampleMultiDimensional(self):
with self.test_session():
batch_size = 2 batch_size = 2
lam_v = [3.0, 22.0] lam_v = [3.0, 22.0]
lam = constant_op.constant([lam_v] * batch_size) lam = constant_op.constant([lam_v] * batch_size)
@ -154,15 +144,11 @@ class ExponentialTest(test.TestCase):
return return
for i in range(2): for i in range(2):
self.assertLess( self.assertLess(
stats.kstest( stats.kstest(sample_values[:, 0, i],
sample_values[:, 0, i], stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01)
stats.expon(scale=1.0 / lam_v[i]).cdf)[0],
0.01)
self.assertLess( self.assertLess(
stats.kstest( stats.kstest(sample_values[:, 1, i],
sample_values[:, 1, i], stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01)
stats.expon(scale=1.0 / lam_v[i]).cdf)[0],
0.01)
def testFullyReparameterized(self): def testFullyReparameterized(self):
lam = constant_op.constant([0.1, 1.0]) lam = constant_op.constant([0.1, 1.0])
@ -174,7 +160,6 @@ class ExponentialTest(test.TestCase):
self.assertIsNotNone(grad_lam) self.assertIsNotNone(grad_lam)
def testExponentialWithSoftplusRate(self): def testExponentialWithSoftplusRate(self):
with self.test_session():
lam = [-2.2, -3.4] lam = [-2.2, -3.4]
exponential = exponential_lib.ExponentialWithSoftplusRate(rate=lam) exponential = exponential_lib.ExponentialWithSoftplusRate(rate=lam)
self.assertAllClose( self.assertAllClose(

View File

@ -50,7 +50,6 @@ stats = try_import("scipy.stats")
class GammaTest(test.TestCase): class GammaTest(test.TestCase):
def testGammaShape(self): def testGammaShape(self):
with self.test_session():
alpha = constant_op.constant([3.0] * 5) alpha = constant_op.constant([3.0] * 5)
beta = constant_op.constant(11.0) beta = constant_op.constant(11.0)
gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
@ -61,7 +60,6 @@ class GammaTest(test.TestCase):
self.assertEqual(gamma.event_shape, tensor_shape.TensorShape([])) self.assertEqual(gamma.event_shape, tensor_shape.TensorShape([]))
def testGammaLogPDF(self): def testGammaLogPDF(self):
with self.test_session():
batch_size = 6 batch_size = 6
alpha = constant_op.constant([2.0] * batch_size) alpha = constant_op.constant([2.0] * batch_size)
beta = constant_op.constant([3.0] * batch_size) beta = constant_op.constant([3.0] * batch_size)
@ -80,7 +78,6 @@ class GammaTest(test.TestCase):
self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf)) self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
def testGammaLogPDFMultidimensional(self): def testGammaLogPDFMultidimensional(self):
with self.test_session():
batch_size = 6 batch_size = 6
alpha = constant_op.constant([[2.0, 4.0]] * batch_size) alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
beta = constant_op.constant([[3.0, 4.0]] * batch_size) beta = constant_op.constant([[3.0, 4.0]] * batch_size)
@ -101,7 +98,6 @@ class GammaTest(test.TestCase):
self.assertAllClose(pdf_values, np.exp(expected_log_pdf)) self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
def testGammaLogPDFMultidimensionalBroadcasting(self): def testGammaLogPDFMultidimensionalBroadcasting(self):
with self.test_session():
batch_size = 6 batch_size = 6
alpha = constant_op.constant([[2.0, 4.0]] * batch_size) alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
beta = constant_op.constant(3.0) beta = constant_op.constant(3.0)
@ -123,7 +119,6 @@ class GammaTest(test.TestCase):
self.assertAllClose(pdf_values, np.exp(expected_log_pdf)) self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
def testGammaCDF(self): def testGammaCDF(self):
with self.test_session():
batch_size = 6 batch_size = 6
alpha = constant_op.constant([2.0] * batch_size) alpha = constant_op.constant([2.0] * batch_size)
beta = constant_op.constant([3.0] * batch_size) beta = constant_op.constant([3.0] * batch_size)
@ -140,7 +135,6 @@ class GammaTest(test.TestCase):
self.assertAllClose(self.evaluate(cdf), expected_cdf) self.assertAllClose(self.evaluate(cdf), expected_cdf)
def testGammaMean(self): def testGammaMean(self):
with self.test_session():
alpha_v = np.array([1.0, 3.0, 2.5]) alpha_v = np.array([1.0, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0]) beta_v = np.array([1.0, 4.0, 5.0])
gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
@ -151,7 +145,6 @@ class GammaTest(test.TestCase):
self.assertAllClose(self.evaluate(gamma.mean()), expected_means) self.assertAllClose(self.evaluate(gamma.mean()), expected_means)
def testGammaModeAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self): def testGammaModeAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self):
with self.test_session():
alpha_v = np.array([5.5, 3.0, 2.5]) alpha_v = np.array([5.5, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0]) beta_v = np.array([1.0, 4.0, 5.0])
gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
@ -160,31 +153,26 @@ class GammaTest(test.TestCase):
self.assertAllClose(self.evaluate(gamma.mode()), expected_modes) self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)
def testGammaModeAllowNanStatsFalseRaisesForUndefinedBatchMembers(self): def testGammaModeAllowNanStatsFalseRaisesForUndefinedBatchMembers(self):
with self.test_session():
# Mode will not be defined for the first entry. # Mode will not be defined for the first entry.
alpha_v = np.array([0.5, 3.0, 2.5]) alpha_v = np.array([0.5, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0]) beta_v = np.array([1.0, 4.0, 5.0])
gamma = gamma_lib.Gamma(concentration=alpha_v, gamma = gamma_lib.Gamma(
rate=beta_v, concentration=alpha_v, rate=beta_v, allow_nan_stats=False)
allow_nan_stats=False)
with self.assertRaisesOpError("x < y"): with self.assertRaisesOpError("x < y"):
self.evaluate(gamma.mode()) self.evaluate(gamma.mode())
def testGammaModeAllowNanStatsIsTrueReturnsNaNforUndefinedBatchMembers(self): def testGammaModeAllowNanStatsIsTrueReturnsNaNforUndefinedBatchMembers(self):
with self.test_session():
# Mode will not be defined for the first entry. # Mode will not be defined for the first entry.
alpha_v = np.array([0.5, 3.0, 2.5]) alpha_v = np.array([0.5, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0]) beta_v = np.array([1.0, 4.0, 5.0])
gamma = gamma_lib.Gamma(concentration=alpha_v, gamma = gamma_lib.Gamma(
rate=beta_v, concentration=alpha_v, rate=beta_v, allow_nan_stats=True)
allow_nan_stats=True)
expected_modes = (alpha_v - 1) / beta_v expected_modes = (alpha_v - 1) / beta_v
expected_modes[0] = np.nan expected_modes[0] = np.nan
self.assertEqual(gamma.mode().get_shape(), (3,)) self.assertEqual(gamma.mode().get_shape(), (3,))
self.assertAllClose(self.evaluate(gamma.mode()), expected_modes) self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)
def testGammaVariance(self): def testGammaVariance(self):
with self.test_session():
alpha_v = np.array([1.0, 3.0, 2.5]) alpha_v = np.array([1.0, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0]) beta_v = np.array([1.0, 4.0, 5.0])
gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
@ -195,7 +183,6 @@ class GammaTest(test.TestCase):
self.assertAllClose(self.evaluate(gamma.variance()), expected_variances) self.assertAllClose(self.evaluate(gamma.variance()), expected_variances)
def testGammaStd(self): def testGammaStd(self):
with self.test_session():
alpha_v = np.array([1.0, 3.0, 2.5]) alpha_v = np.array([1.0, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0]) beta_v = np.array([1.0, 4.0, 5.0])
gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
@ -206,7 +193,6 @@ class GammaTest(test.TestCase):
self.assertAllClose(self.evaluate(gamma.stddev()), expected_stddev) self.assertAllClose(self.evaluate(gamma.stddev()), expected_stddev)
def testGammaEntropy(self): def testGammaEntropy(self):
with self.test_session():
alpha_v = np.array([1.0, 3.0, 2.5]) alpha_v = np.array([1.0, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0]) beta_v = np.array([1.0, 4.0, 5.0])
gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
@ -217,7 +203,6 @@ class GammaTest(test.TestCase):
self.assertAllClose(self.evaluate(gamma.entropy()), expected_entropy) self.assertAllClose(self.evaluate(gamma.entropy()), expected_entropy)
def testGammaSampleSmallAlpha(self): def testGammaSampleSmallAlpha(self):
with self.test_session():
alpha_v = 0.05 alpha_v = 0.05
beta_v = 1.0 beta_v = 1.0
alpha = constant_op.constant(alpha_v) alpha = constant_op.constant(alpha_v)
@ -233,8 +218,7 @@ class GammaTest(test.TestCase):
return return
self.assertAllClose( self.assertAllClose(
sample_values.mean(), sample_values.mean(),
stats.gamma.mean( stats.gamma.mean(alpha_v, scale=1 / beta_v),
alpha_v, scale=1 / beta_v),
atol=.01) atol=.01)
self.assertAllClose( self.assertAllClose(
sample_values.var(), sample_values.var(),
@ -242,7 +226,6 @@ class GammaTest(test.TestCase):
atol=.15) atol=.15)
def testGammaSample(self): def testGammaSample(self):
with self.test_session():
alpha_v = 4.0 alpha_v = 4.0
beta_v = 3.0 beta_v = 3.0
alpha = constant_op.constant(alpha_v) alpha = constant_op.constant(alpha_v)
@ -258,8 +241,7 @@ class GammaTest(test.TestCase):
return return
self.assertAllClose( self.assertAllClose(
sample_values.mean(), sample_values.mean(),
stats.gamma.mean( stats.gamma.mean(alpha_v, scale=1 / beta_v),
alpha_v, scale=1 / beta_v),
atol=.01) atol=.01)
self.assertAllClose( self.assertAllClose(
sample_values.var(), sample_values.var(),
@ -279,7 +261,6 @@ class GammaTest(test.TestCase):
self.assertIsNotNone(grad_beta) self.assertIsNotNone(grad_beta)
def testGammaSampleMultiDimensional(self): def testGammaSampleMultiDimensional(self):
with self.test_session():
alpha_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100 alpha_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100
beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1 beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1
gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
@ -295,13 +276,14 @@ class GammaTest(test.TestCase):
return return
self.assertAllClose( self.assertAllClose(
sample_values.mean(axis=0), sample_values.mean(axis=0),
stats.gamma.mean( stats.gamma.mean(alpha_bc, scale=1 / beta_bc),
alpha_bc, scale=1 / beta_bc), atol=0.,
atol=0., rtol=.05) rtol=.05)
self.assertAllClose( self.assertAllClose(
sample_values.var(axis=0), sample_values.var(axis=0),
stats.gamma.var(alpha_bc, scale=1 / beta_bc), stats.gamma.var(alpha_bc, scale=1 / beta_bc),
atol=10.0, rtol=0.) atol=10.0,
rtol=0.)
fails = 0 fails = 0
trials = 0 trials = 0
for ai, a in enumerate(np.reshape(alpha_v, [-1])): for ai, a in enumerate(np.reshape(alpha_v, [-1])):
@ -320,7 +302,6 @@ class GammaTest(test.TestCase):
return ks < 0.02 return ks < 0.02
def testGammaPdfOfSampleMultiDims(self): def testGammaPdfOfSampleMultiDims(self):
with self.test_session():
gamma = gamma_lib.Gamma(concentration=[7., 11.], rate=[[5.], [6.]]) gamma = gamma_lib.Gamma(concentration=[7., 11.], rate=[[5.], [6.]])
num = 50000 num = 50000
samples = gamma.sample(num, seed=137) samples = gamma.sample(num, seed=137)
@ -335,8 +316,8 @@ class GammaTest(test.TestCase):
if not stats: if not stats:
return return
self.assertAllClose( self.assertAllClose(
stats.gamma.mean( stats.gamma.mean([[7., 11.], [7., 11.]],
[[7., 11.], [7., 11.]], scale=1 / np.array([[5., 5.], [6., 6.]])), scale=1 / np.array([[5., 5.], [6., 6.]])),
sample_vals.mean(axis=0), sample_vals.mean(axis=0),
atol=.1) atol=.1)
self.assertAllClose( self.assertAllClose(
@ -356,32 +337,29 @@ class GammaTest(test.TestCase):
self.assertNear(1., total, err=err) self.assertNear(1., total, err=err)
def testGammaNonPositiveInitializationParamsRaises(self): def testGammaNonPositiveInitializationParamsRaises(self):
with self.test_session():
alpha_v = constant_op.constant(0.0, name="alpha") alpha_v = constant_op.constant(0.0, name="alpha")
beta_v = constant_op.constant(1.0, name="beta") beta_v = constant_op.constant(1.0, name="beta")
with self.assertRaisesOpError("x > 0"): with self.assertRaisesOpError("x > 0"):
gamma = gamma_lib.Gamma(concentration=alpha_v, gamma = gamma_lib.Gamma(
rate=beta_v, concentration=alpha_v, rate=beta_v, validate_args=True)
validate_args=True)
self.evaluate(gamma.mean()) self.evaluate(gamma.mean())
alpha_v = constant_op.constant(1.0, name="alpha") alpha_v = constant_op.constant(1.0, name="alpha")
beta_v = constant_op.constant(0.0, name="beta") beta_v = constant_op.constant(0.0, name="beta")
with self.assertRaisesOpError("x > 0"): with self.assertRaisesOpError("x > 0"):
gamma = gamma_lib.Gamma(concentration=alpha_v, gamma = gamma_lib.Gamma(
rate=beta_v, concentration=alpha_v, rate=beta_v, validate_args=True)
validate_args=True)
self.evaluate(gamma.mean()) self.evaluate(gamma.mean())
def testGammaWithSoftplusConcentrationRate(self): def testGammaWithSoftplusConcentrationRate(self):
with self.test_session():
alpha_v = constant_op.constant([0.0, -2.1], name="alpha") alpha_v = constant_op.constant([0.0, -2.1], name="alpha")
beta_v = constant_op.constant([1.0, -3.6], name="beta") beta_v = constant_op.constant([1.0, -3.6], name="beta")
gamma = gamma_lib.GammaWithSoftplusConcentrationRate( gamma = gamma_lib.GammaWithSoftplusConcentrationRate(
concentration=alpha_v, rate=beta_v) concentration=alpha_v, rate=beta_v)
self.assertAllEqual(self.evaluate(nn_ops.softplus(alpha_v)), self.assertAllEqual(
self.evaluate(nn_ops.softplus(alpha_v)),
self.evaluate(gamma.concentration)) self.evaluate(gamma.concentration))
self.assertAllEqual(self.evaluate(nn_ops.softplus(beta_v)), self.assertAllEqual(
self.evaluate(gamma.rate)) self.evaluate(nn_ops.softplus(beta_v)), self.evaluate(gamma.rate))
def testGammaGammaKL(self): def testGammaGammaKL(self):
alpha0 = np.array([3.]) alpha0 = np.array([3.])
@ -391,7 +369,6 @@ class GammaTest(test.TestCase):
beta1 = np.array([0.5, 1., 1.5, 2., 2.5, 3.]) beta1 = np.array([0.5, 1., 1.5, 2., 2.5, 3.])
# Build graph. # Build graph.
with self.test_session():
g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0) g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0)
g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1) g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1)
x = g0.sample(int(1e4), seed=0) x = g0.sample(int(1e4), seed=0)

View File

@ -21,7 +21,6 @@ import importlib
import numpy as np import numpy as np
from tensorflow.python.client import session
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
@ -49,7 +48,6 @@ stats = try_import("scipy.stats")
class LaplaceTest(test.TestCase): class LaplaceTest(test.TestCase):
def testLaplaceShape(self): def testLaplaceShape(self):
with self.test_session():
loc = constant_op.constant([3.0] * 5) loc = constant_op.constant([3.0] * 5)
scale = constant_op.constant(11.0) scale = constant_op.constant(11.0)
laplace = laplace_lib.Laplace(loc=loc, scale=scale) laplace = laplace_lib.Laplace(loc=loc, scale=scale)
@ -60,7 +58,6 @@ class LaplaceTest(test.TestCase):
self.assertEqual(laplace.event_shape, tensor_shape.TensorShape([])) self.assertEqual(laplace.event_shape, tensor_shape.TensorShape([]))
def testLaplaceLogPDF(self): def testLaplaceLogPDF(self):
with self.test_session():
batch_size = 6 batch_size = 6
loc = constant_op.constant([2.0] * batch_size) loc = constant_op.constant([2.0] * batch_size)
scale = constant_op.constant([3.0] * batch_size) scale = constant_op.constant([3.0] * batch_size)
@ -80,7 +77,6 @@ class LaplaceTest(test.TestCase):
self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf)) self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
def testLaplaceLogPDFMultidimensional(self): def testLaplaceLogPDFMultidimensional(self):
with self.test_session():
batch_size = 6 batch_size = 6
loc = constant_op.constant([[2.0, 4.0]] * batch_size) loc = constant_op.constant([[2.0, 4.0]] * batch_size)
scale = constant_op.constant([[3.0, 4.0]] * batch_size) scale = constant_op.constant([[3.0, 4.0]] * batch_size)
@ -102,7 +98,6 @@ class LaplaceTest(test.TestCase):
self.assertAllClose(pdf_values, np.exp(expected_log_pdf)) self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
def testLaplaceLogPDFMultidimensionalBroadcasting(self): def testLaplaceLogPDFMultidimensionalBroadcasting(self):
with self.test_session():
batch_size = 6 batch_size = 6
loc = constant_op.constant([[2.0, 4.0]] * batch_size) loc = constant_op.constant([[2.0, 4.0]] * batch_size)
scale = constant_op.constant(3.0) scale = constant_op.constant(3.0)
@ -124,7 +119,6 @@ class LaplaceTest(test.TestCase):
self.assertAllClose(pdf_values, np.exp(expected_log_pdf)) self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
def testLaplaceCDF(self): def testLaplaceCDF(self):
with self.test_session():
batch_size = 6 batch_size = 6
loc = constant_op.constant([2.0] * batch_size) loc = constant_op.constant([2.0] * batch_size)
scale = constant_op.constant([3.0] * batch_size) scale = constant_op.constant([3.0] * batch_size)
@ -142,7 +136,6 @@ class LaplaceTest(test.TestCase):
self.assertAllClose(self.evaluate(cdf), expected_cdf) self.assertAllClose(self.evaluate(cdf), expected_cdf)
def testLaplaceLogCDF(self): def testLaplaceLogCDF(self):
with self.test_session():
batch_size = 6 batch_size = 6
loc = constant_op.constant([2.0] * batch_size) loc = constant_op.constant([2.0] * batch_size)
scale = constant_op.constant([3.0] * batch_size) scale = constant_op.constant([3.0] * batch_size)
@ -160,7 +153,6 @@ class LaplaceTest(test.TestCase):
self.assertAllClose(self.evaluate(cdf), expected_cdf) self.assertAllClose(self.evaluate(cdf), expected_cdf)
def testLaplaceLogSurvivalFunction(self): def testLaplaceLogSurvivalFunction(self):
with self.test_session():
batch_size = 6 batch_size = 6
loc = constant_op.constant([2.0] * batch_size) loc = constant_op.constant([2.0] * batch_size)
scale = constant_op.constant([3.0] * batch_size) scale = constant_op.constant([3.0] * batch_size)
@ -178,7 +170,6 @@ class LaplaceTest(test.TestCase):
self.assertAllClose(self.evaluate(sf), expected_sf) self.assertAllClose(self.evaluate(sf), expected_sf)
def testLaplaceMean(self): def testLaplaceMean(self):
with self.test_session():
loc_v = np.array([1.0, 3.0, 2.5]) loc_v = np.array([1.0, 3.0, 2.5])
scale_v = np.array([1.0, 4.0, 5.0]) scale_v = np.array([1.0, 4.0, 5.0])
laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
@ -189,7 +180,6 @@ class LaplaceTest(test.TestCase):
self.assertAllClose(self.evaluate(laplace.mean()), expected_means) self.assertAllClose(self.evaluate(laplace.mean()), expected_means)
def testLaplaceMode(self): def testLaplaceMode(self):
with self.test_session():
loc_v = np.array([0.5, 3.0, 2.5]) loc_v = np.array([0.5, 3.0, 2.5])
scale_v = np.array([1.0, 4.0, 5.0]) scale_v = np.array([1.0, 4.0, 5.0])
laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
@ -197,7 +187,6 @@ class LaplaceTest(test.TestCase):
self.assertAllClose(self.evaluate(laplace.mode()), loc_v) self.assertAllClose(self.evaluate(laplace.mode()), loc_v)
def testLaplaceVariance(self): def testLaplaceVariance(self):
with self.test_session():
loc_v = np.array([1.0, 3.0, 2.5]) loc_v = np.array([1.0, 3.0, 2.5])
scale_v = np.array([1.0, 4.0, 5.0]) scale_v = np.array([1.0, 4.0, 5.0])
laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
@ -208,7 +197,6 @@ class LaplaceTest(test.TestCase):
self.assertAllClose(self.evaluate(laplace.variance()), expected_variances) self.assertAllClose(self.evaluate(laplace.variance()), expected_variances)
def testLaplaceStd(self): def testLaplaceStd(self):
with self.test_session():
loc_v = np.array([1.0, 3.0, 2.5]) loc_v = np.array([1.0, 3.0, 2.5])
scale_v = np.array([1.0, 4.0, 5.0]) scale_v = np.array([1.0, 4.0, 5.0])
laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
@ -219,7 +207,6 @@ class LaplaceTest(test.TestCase):
self.assertAllClose(self.evaluate(laplace.stddev()), expected_stddev) self.assertAllClose(self.evaluate(laplace.stddev()), expected_stddev)
def testLaplaceEntropy(self): def testLaplaceEntropy(self):
with self.test_session():
loc_v = np.array([1.0, 3.0, 2.5]) loc_v = np.array([1.0, 3.0, 2.5])
scale_v = np.array([1.0, 4.0, 5.0]) scale_v = np.array([1.0, 4.0, 5.0])
laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
@ -230,7 +217,6 @@ class LaplaceTest(test.TestCase):
self.assertAllClose(self.evaluate(laplace.entropy()), expected_entropy) self.assertAllClose(self.evaluate(laplace.entropy()), expected_entropy)
def testLaplaceSample(self): def testLaplaceSample(self):
with session.Session():
loc_v = 4.0 loc_v = 4.0
scale_v = 3.0 scale_v = 3.0
loc = constant_op.constant(loc_v) loc = constant_op.constant(loc_v)
@ -245,8 +231,7 @@ class LaplaceTest(test.TestCase):
return return
self.assertAllClose( self.assertAllClose(
sample_values.mean(), sample_values.mean(),
stats.laplace.mean( stats.laplace.mean(loc_v, scale=scale_v),
loc_v, scale=scale_v),
rtol=0.05, rtol=0.05,
atol=0.) atol=0.)
self.assertAllClose( self.assertAllClose(
@ -269,7 +254,6 @@ class LaplaceTest(test.TestCase):
self.assertIsNotNone(grad_scale) self.assertIsNotNone(grad_scale)
def testLaplaceSampleMultiDimensional(self): def testLaplaceSampleMultiDimensional(self):
with session.Session():
loc_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100 loc_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100
scale_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1 scale_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1
laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
@ -285,8 +269,7 @@ class LaplaceTest(test.TestCase):
return return
self.assertAllClose( self.assertAllClose(
sample_values.mean(axis=0), sample_values.mean(axis=0),
stats.laplace.mean( stats.laplace.mean(loc_bc, scale=scale_bc),
loc_bc, scale=scale_bc),
rtol=0.35, rtol=0.35,
atol=0.) atol=0.)
self.assertAllClose( self.assertAllClose(
@ -349,24 +332,20 @@ class LaplaceTest(test.TestCase):
self.assertNear(1., total, err=err) self.assertNear(1., total, err=err)
def testLaplaceNonPositiveInitializationParamsRaises(self): def testLaplaceNonPositiveInitializationParamsRaises(self):
with self.test_session():
loc_v = constant_op.constant(0.0, name="loc") loc_v = constant_op.constant(0.0, name="loc")
scale_v = constant_op.constant(-1.0, name="scale") scale_v = constant_op.constant(-1.0, name="scale")
with self.assertRaisesOpError( with self.assertRaisesOpError("Condition x > 0 did not hold element-wise"):
"Condition x > 0 did not hold element-wise"):
laplace = laplace_lib.Laplace( laplace = laplace_lib.Laplace(
loc=loc_v, scale=scale_v, validate_args=True) loc=loc_v, scale=scale_v, validate_args=True)
self.evaluate(laplace.mean()) self.evaluate(laplace.mean())
loc_v = constant_op.constant(1.0, name="loc") loc_v = constant_op.constant(1.0, name="loc")
scale_v = constant_op.constant(0.0, name="scale") scale_v = constant_op.constant(0.0, name="scale")
with self.assertRaisesOpError( with self.assertRaisesOpError("Condition x > 0 did not hold element-wise"):
"Condition x > 0 did not hold element-wise"):
laplace = laplace_lib.Laplace( laplace = laplace_lib.Laplace(
loc=loc_v, scale=scale_v, validate_args=True) loc=loc_v, scale=scale_v, validate_args=True)
self.evaluate(laplace.mean()) self.evaluate(laplace.mean())
def testLaplaceWithSoftplusScale(self): def testLaplaceWithSoftplusScale(self):
with self.test_session():
loc_v = constant_op.constant([0.0, 1.0], name="loc") loc_v = constant_op.constant([0.0, 1.0], name="loc")
scale_v = constant_op.constant([-1.0, 2.0], name="scale") scale_v = constant_op.constant([-1.0, 2.0], name="scale")
laplace = laplace_lib.LaplaceWithSoftplusScale(loc=loc_v, scale=scale_v) laplace = laplace_lib.LaplaceWithSoftplusScale(loc=loc_v, scale=scale_v)

View File

@ -61,7 +61,6 @@ class NormalTest(test.TestCase):
self.assertAllEqual(all_true, is_finite) self.assertAllEqual(all_true, is_finite)
def _testParamShapes(self, sample_shape, expected): def _testParamShapes(self, sample_shape, expected):
with self.test_session():
param_shapes = normal_lib.Normal.param_shapes(sample_shape) param_shapes = normal_lib.Normal.param_shapes(sample_shape)
mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"] mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"]
self.assertAllEqual(expected, self.evaluate(mu_shape)) self.assertAllEqual(expected, self.evaluate(mu_shape))
@ -93,7 +92,6 @@ class NormalTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testNormalWithSoftplusScale(self): def testNormalWithSoftplusScale(self):
with self.test_session():
mu = array_ops.zeros((10, 3)) mu = array_ops.zeros((10, 3))
rho = array_ops.ones((10, 3)) * -2. rho = array_ops.ones((10, 3)) * -2.
normal = normal_lib.NormalWithSoftplusScale(loc=mu, scale=rho) normal = normal_lib.NormalWithSoftplusScale(loc=mu, scale=rho)
@ -103,7 +101,6 @@ class NormalTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testNormalLogPDF(self): def testNormalLogPDF(self):
with self.test_session():
batch_size = 6 batch_size = 6
mu = constant_op.constant([3.0] * batch_size) mu = constant_op.constant([3.0] * batch_size)
sigma = constant_op.constant([math.sqrt(10.0)] * batch_size) sigma = constant_op.constant([math.sqrt(10.0)] * batch_size)
@ -137,11 +134,10 @@ class NormalTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testNormalLogPDFMultidimensional(self): def testNormalLogPDFMultidimensional(self):
with self.test_session():
batch_size = 6 batch_size = 6
mu = constant_op.constant([[3.0, -3.0]] * batch_size) mu = constant_op.constant([[3.0, -3.0]] * batch_size)
sigma = constant_op.constant([[math.sqrt(10.0), math.sqrt(15.0)]] * sigma = constant_op.constant(
batch_size) [[math.sqrt(10.0), math.sqrt(15.0)]] * batch_size)
x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
normal = normal_lib.Normal(loc=mu, scale=sigma) normal = normal_lib.Normal(loc=mu, scale=sigma)
@ -175,7 +171,6 @@ class NormalTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testNormalCDF(self): def testNormalCDF(self):
with self.test_session():
batch_size = 50 batch_size = 50
mu = self._rng.randn(batch_size) mu = self._rng.randn(batch_size)
sigma = self._rng.rand(batch_size) + 1.0 sigma = self._rng.rand(batch_size) + 1.0
@ -197,7 +192,6 @@ class NormalTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testNormalSurvivalFunction(self): def testNormalSurvivalFunction(self):
with self.test_session():
batch_size = 50 batch_size = 50
mu = self._rng.randn(batch_size) mu = self._rng.randn(batch_size)
sigma = self._rng.rand(batch_size) + 1.0 sigma = self._rng.rand(batch_size) + 1.0
@ -220,7 +214,6 @@ class NormalTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testNormalLogCDF(self): def testNormalLogCDF(self):
with self.test_session():
batch_size = 50 batch_size = 50
mu = self._rng.randn(batch_size) mu = self._rng.randn(batch_size)
sigma = self._rng.rand(batch_size) + 1.0 sigma = self._rng.rand(batch_size) + 1.0
@ -256,7 +249,7 @@ class NormalTest(test.TestCase):
]: ]:
value = func(x) value = func(x)
grads = gradients_impl.gradients(value, [mu, sigma]) grads = gradients_impl.gradients(value, [mu, sigma])
with self.test_session(graph=g): with self.session(graph=g):
variables.global_variables_initializer().run() variables.global_variables_initializer().run()
self.assertAllFinite(value) self.assertAllFinite(value)
self.assertAllFinite(grads[0]) self.assertAllFinite(grads[0])
@ -264,7 +257,6 @@ class NormalTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testNormalLogSurvivalFunction(self): def testNormalLogSurvivalFunction(self):
with self.test_session():
batch_size = 50 batch_size = 50
mu = self._rng.randn(batch_size) mu = self._rng.randn(batch_size)
sigma = self._rng.rand(batch_size) + 1.0 sigma = self._rng.rand(batch_size) + 1.0
@ -289,7 +281,6 @@ class NormalTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testNormalEntropyWithScalarInputs(self): def testNormalEntropyWithScalarInputs(self):
# Scipy.stats.norm cannot deal with the shapes in the other test. # Scipy.stats.norm cannot deal with the shapes in the other test.
with self.test_session():
mu_v = 2.34 mu_v = 2.34
sigma_v = 4.56 sigma_v = 4.56
normal = normal_lib.Normal(loc=mu_v, scale=sigma_v) normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)
@ -310,15 +301,13 @@ class NormalTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testNormalEntropy(self): def testNormalEntropy(self):
with self.test_session():
mu_v = np.array([1.0, 1.0, 1.0]) mu_v = np.array([1.0, 1.0, 1.0])
sigma_v = np.array([[1.0, 2.0, 3.0]]).T sigma_v = np.array([[1.0, 2.0, 3.0]]).T
normal = normal_lib.Normal(loc=mu_v, scale=sigma_v) normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)
# scipy.stats.norm cannot deal with these shapes. # scipy.stats.norm cannot deal with these shapes.
sigma_broadcast = mu_v * sigma_v sigma_broadcast = mu_v * sigma_v
expected_entropy = 0.5 * np.log(2 * np.pi * np.exp(1) * sigma_broadcast** expected_entropy = 0.5 * np.log(2 * np.pi * np.exp(1) * sigma_broadcast**2)
2)
entropy = normal.entropy() entropy = normal.entropy()
np.testing.assert_allclose(expected_entropy, self.evaluate(entropy)) np.testing.assert_allclose(expected_entropy, self.evaluate(entropy))
self.assertAllEqual( self.assertAllEqual(
@ -331,7 +320,6 @@ class NormalTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testNormalMeanAndMode(self): def testNormalMeanAndMode(self):
with self.test_session():
# Mu will be broadcast to [7, 7, 7]. # Mu will be broadcast to [7, 7, 7].
mu = [7.] mu = [7.]
sigma = [11., 12., 13.] sigma = [11., 12., 13.]
@ -346,7 +334,6 @@ class NormalTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testNormalQuantile(self): def testNormalQuantile(self):
with self.test_session():
batch_size = 52 batch_size = 52
mu = self._rng.randn(batch_size) mu = self._rng.randn(batch_size)
sigma = self._rng.rand(batch_size) + 1.0 sigma = self._rng.rand(batch_size) + 1.0
@ -385,7 +372,7 @@ class NormalTest(test.TestCase):
value = dist.quantile(p) value = dist.quantile(p)
grads = gradients_impl.gradients(value, [mu, p]) grads = gradients_impl.gradients(value, [mu, p])
with self.test_session(graph=g): with self.cached_session(graph=g):
variables.global_variables_initializer().run() variables.global_variables_initializer().run()
self.assertAllFinite(grads[0]) self.assertAllFinite(grads[0])
self.assertAllFinite(grads[1]) self.assertAllFinite(grads[1])
@ -398,7 +385,6 @@ class NormalTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testNormalVariance(self): def testNormalVariance(self):
with self.test_session():
# sigma will be broadcast to [7, 7, 7] # sigma will be broadcast to [7, 7, 7]
mu = [1., 2., 3.] mu = [1., 2., 3.]
sigma = [7.] sigma = [7.]
@ -410,7 +396,6 @@ class NormalTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testNormalStandardDeviation(self): def testNormalStandardDeviation(self):
with self.test_session():
# sigma will be broadcast to [7, 7, 7] # sigma will be broadcast to [7, 7, 7]
mu = [1., 2., 3.] mu = [1., 2., 3.]
sigma = [7.] sigma = [7.]
@ -422,7 +407,6 @@ class NormalTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testNormalSample(self): def testNormalSample(self):
with self.test_session():
mu = constant_op.constant(3.0) mu = constant_op.constant(3.0)
sigma = constant_op.constant(math.sqrt(3.0)) sigma = constant_op.constant(math.sqrt(3.0))
mu_v = 3.0 mu_v = 3.0
@ -468,11 +452,10 @@ class NormalTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testNormalSampleMultiDimensional(self): def testNormalSampleMultiDimensional(self):
with self.test_session():
batch_size = 2 batch_size = 2
mu = constant_op.constant([[3.0, -3.0]] * batch_size) mu = constant_op.constant([[3.0, -3.0]] * batch_size)
sigma = constant_op.constant([[math.sqrt(2.0), math.sqrt(3.0)]] * sigma = constant_op.constant(
batch_size) [[math.sqrt(2.0), math.sqrt(3.0)]] * batch_size)
mu_v = [3.0, -3.0] mu_v = [3.0, -3.0]
sigma_v = [np.sqrt(2.0), np.sqrt(3.0)] sigma_v = [np.sqrt(2.0), np.sqrt(3.0)]
n = constant_op.constant(100000) n = constant_op.constant(100000)
@ -504,7 +487,6 @@ class NormalTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testNegativeSigmaFails(self): def testNegativeSigmaFails(self):
with self.test_session():
with self.assertRaisesOpError("Condition x > 0 did not hold"): with self.assertRaisesOpError("Condition x > 0 did not hold"):
normal = normal_lib.Normal( normal = normal_lib.Normal(
loc=[1.], scale=[-5.], validate_args=True, name="G") loc=[1.], scale=[-5.], validate_args=True, name="G")
@ -512,7 +494,6 @@ class NormalTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testNormalShape(self): def testNormalShape(self):
with self.test_session():
mu = constant_op.constant([-3.0] * 5) mu = constant_op.constant([-3.0] * 5)
sigma = constant_op.constant(11.0) sigma = constant_op.constant(11.0)
normal = normal_lib.Normal(loc=mu, scale=sigma) normal = normal_lib.Normal(loc=mu, scale=sigma)
@ -527,7 +508,7 @@ class NormalTest(test.TestCase):
sigma = array_ops.placeholder(dtype=dtypes.float32) sigma = array_ops.placeholder(dtype=dtypes.float32)
normal = normal_lib.Normal(loc=mu, scale=sigma) normal = normal_lib.Normal(loc=mu, scale=sigma)
with self.test_session() as sess: with self.cached_session() as sess:
# get_batch_shape should return an "<unknown>" tensor. # get_batch_shape should return an "<unknown>" tensor.
self.assertEqual(normal.batch_shape, tensor_shape.TensorShape(None)) self.assertEqual(normal.batch_shape, tensor_shape.TensorShape(None))
self.assertEqual(normal.event_shape, ()) self.assertEqual(normal.event_shape, ())

View File

@ -92,22 +92,21 @@ class NdtriTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testNdtri(self): def testNdtri(self):
"""Verifies that ndtri computation is correct.""" """Verifies that ndtri computation is correct."""
with self.test_session():
if not special: if not special:
return return
p = np.linspace(0., 1.0, 50).astype(np.float64) p = np.linspace(0., 1.0, 50).astype(np.float64)
# Quantile performs piecewise rational approximation so adding some # Quantile performs piecewise rational approximation so adding some
# special input values to make sure we hit all the pieces. # special input values to make sure we hit all the pieces.
p = np.hstack((p, np.exp(-32), 1. - np.exp(-32), p = np.hstack((p, np.exp(-32), 1. - np.exp(-32), np.exp(-2),
np.exp(-2), 1. - np.exp(-2))) 1. - np.exp(-2)))
expected_x = special.ndtri(p) expected_x = special.ndtri(p)
x = special_math.ndtri(p) x = special_math.ndtri(p)
self.assertAllClose(expected_x, self.evaluate(x), atol=0.) self.assertAllClose(expected_x, self.evaluate(x), atol=0.)
def testNdtriDynamicShape(self): def testNdtriDynamicShape(self):
"""Verifies that ndtri computation is correct.""" """Verifies that ndtri computation is correct."""
with self.test_session() as sess: with self.cached_session() as sess:
if not special: if not special:
return return
@ -286,7 +285,7 @@ class NdtrGradientTest(test.TestCase):
def _test_grad_accuracy(self, dtype, grid_spec, error_spec): def _test_grad_accuracy(self, dtype, grid_spec, error_spec):
raw_grid = _make_grid(dtype, grid_spec) raw_grid = _make_grid(dtype, grid_spec)
grid = ops.convert_to_tensor(raw_grid) grid = ops.convert_to_tensor(raw_grid)
with self.test_session(): with self.cached_session():
fn = sm.log_ndtr if self._use_log else sm.ndtr fn = sm.log_ndtr if self._use_log else sm.ndtr
# If there are N points in the grid, # If there are N points in the grid,
@ -355,7 +354,7 @@ class LogNdtrGradientTest(NdtrGradientTest):
class ErfInvTest(test.TestCase): class ErfInvTest(test.TestCase):
def testErfInvValues(self): def testErfInvValues(self):
with self.test_session(): with self.cached_session():
if not special: if not special:
return return
@ -366,7 +365,7 @@ class ErfInvTest(test.TestCase):
self.assertAllClose(expected_x, x.eval(), atol=0.) self.assertAllClose(expected_x, x.eval(), atol=0.)
def testErfInvIntegerInput(self): def testErfInvIntegerInput(self):
with self.test_session(): with self.cached_session():
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
x = np.array([1, 2, 3]).astype(np.int32) x = np.array([1, 2, 3]).astype(np.int32)
@ -397,7 +396,7 @@ class LogCDFLaplaceTest(test.TestCase):
self.assertAllEqual(np.ones_like(x, dtype=np.bool), x) self.assertAllEqual(np.ones_like(x, dtype=np.bool), x)
def _test_grid_log(self, dtype, scipy_dtype, grid_spec, error_spec): def _test_grid_log(self, dtype, scipy_dtype, grid_spec, error_spec):
with self.test_session(): with self.cached_session():
grid = _make_grid(dtype, grid_spec) grid = _make_grid(dtype, grid_spec)
actual = sm.log_cdf_laplace(grid).eval() actual = sm.log_cdf_laplace(grid).eval()
@ -439,7 +438,7 @@ class LogCDFLaplaceTest(test.TestCase):
ErrorSpec(rtol=0.05, atol=0)) ErrorSpec(rtol=0.05, atol=0))
def test_float32_extreme_values_result_and_gradient_finite_and_nonzero(self): def test_float32_extreme_values_result_and_gradient_finite_and_nonzero(self):
with self.test_session() as sess: with self.cached_session() as sess:
# On the lower branch, log_cdf_laplace(x) = x, so we know this will be # On the lower branch, log_cdf_laplace(x) = x, so we know this will be
# fine, but test to -200 anyways. # fine, but test to -200 anyways.
grid = _make_grid( grid = _make_grid(
@ -458,7 +457,7 @@ class LogCDFLaplaceTest(test.TestCase):
self.assertFalse(np.any(grad_ == 0)) self.assertFalse(np.any(grad_ == 0))
def test_float64_extreme_values_result_and_gradient_finite_and_nonzero(self): def test_float64_extreme_values_result_and_gradient_finite_and_nonzero(self):
with self.test_session() as sess: with self.cached_session() as sess:
# On the lower branch, log_cdf_laplace(x) = x, so we know this will be # On the lower branch, log_cdf_laplace(x) = x, so we know this will be
# fine, but test to -200 anyways. # fine, but test to -200 anyways.
grid = _make_grid( grid = _make_grid(

View File

@ -50,7 +50,6 @@ stats = try_import("scipy.stats")
class StudentTTest(test.TestCase): class StudentTTest(test.TestCase):
def testStudentPDFAndLogPDF(self): def testStudentPDFAndLogPDF(self):
with self.test_session():
batch_size = 6 batch_size = 6
df = constant_op.constant([3.] * batch_size) df = constant_op.constant([3.] * batch_size)
mu = constant_op.constant([7.] * batch_size) mu = constant_op.constant([7.] * batch_size)
@ -79,12 +78,11 @@ class StudentTTest(test.TestCase):
self.assertAllClose(np.exp(expected_log_pdf), pdf_values) self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
def testStudentLogPDFMultidimensional(self): def testStudentLogPDFMultidimensional(self):
with self.test_session():
batch_size = 6 batch_size = 6
df = constant_op.constant([[1.5, 7.2]] * batch_size) df = constant_op.constant([[1.5, 7.2]] * batch_size)
mu = constant_op.constant([[3., -3.]] * batch_size) mu = constant_op.constant([[3., -3.]] * batch_size)
sigma = constant_op.constant([[-math.sqrt(10.), math.sqrt(15.)]] * sigma = constant_op.constant(
batch_size) [[-math.sqrt(10.), math.sqrt(15.)]] * batch_size)
df_v = np.array([1.5, 7.2]) df_v = np.array([1.5, 7.2])
mu_v = np.array([3., -3.]) mu_v = np.array([3., -3.])
sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)]) sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)])
@ -107,7 +105,6 @@ class StudentTTest(test.TestCase):
self.assertAllClose(np.exp(expected_log_pdf), pdf_values) self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
def testStudentCDFAndLogCDF(self): def testStudentCDFAndLogCDF(self):
with self.test_session():
batch_size = 6 batch_size = 6
df = constant_op.constant([3.] * batch_size) df = constant_op.constant([3.] * batch_size)
mu = constant_op.constant([7.] * batch_size) mu = constant_op.constant([7.] * batch_size)
@ -140,7 +137,6 @@ class StudentTTest(test.TestCase):
df_v = np.array([[2., 3., 7.]]) # 1x3 df_v = np.array([[2., 3., 7.]]) # 1x3
mu_v = np.array([[1., -1, 0]]) # 1x3 mu_v = np.array([[1., -1, 0]]) # 1x3
sigma_v = np.array([[1., -2., 3.]]).T # transposed => 3x1 sigma_v = np.array([[1., -2., 3.]]).T # transposed => 3x1
with self.test_session():
student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v) student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v)
ent = student.entropy() ent = student.entropy()
ent_values = self.evaluate(ent) ent_values = self.evaluate(ent)
@ -160,7 +156,6 @@ class StudentTTest(test.TestCase):
self.assertAllClose(expected_entropy, ent_values) self.assertAllClose(expected_entropy, ent_values)
def testStudentSample(self): def testStudentSample(self):
with self.test_session():
df = constant_op.constant(4.) df = constant_op.constant(4.)
mu = constant_op.constant(3.) mu = constant_op.constant(3.)
sigma = constant_op.constant(-math.sqrt(10.)) sigma = constant_op.constant(-math.sqrt(10.))
@ -175,34 +170,27 @@ class StudentTTest(test.TestCase):
self.assertEqual(sample_values.shape, (n_val,)) self.assertEqual(sample_values.shape, (n_val,))
self.assertAllClose(sample_values.mean(), mu_v, rtol=0.1, atol=0) self.assertAllClose(sample_values.mean(), mu_v, rtol=0.1, atol=0)
self.assertAllClose( self.assertAllClose(
sample_values.var(), sample_values.var(), sigma_v**2 * df_v / (df_v - 2), rtol=0.1, atol=0)
sigma_v**2 * df_v / (df_v - 2),
rtol=0.1,
atol=0)
self._checkKLApprox(df_v, mu_v, sigma_v, sample_values) self._checkKLApprox(df_v, mu_v, sigma_v, sample_values)
# Test that sampling with the same seed twice gives the same results. # Test that sampling with the same seed twice gives the same results.
def testStudentSampleMultipleTimes(self): def testStudentSampleMultipleTimes(self):
with self.test_session():
df = constant_op.constant(4.) df = constant_op.constant(4.)
mu = constant_op.constant(3.) mu = constant_op.constant(3.)
sigma = constant_op.constant(math.sqrt(10.)) sigma = constant_op.constant(math.sqrt(10.))
n = constant_op.constant(100) n = constant_op.constant(100)
random_seed.set_random_seed(654321) random_seed.set_random_seed(654321)
student = student_t.StudentT( student = student_t.StudentT(df=df, loc=mu, scale=sigma, name="student_t1")
df=df, loc=mu, scale=sigma, name="student_t1")
samples1 = self.evaluate(student.sample(n, seed=123456)) samples1 = self.evaluate(student.sample(n, seed=123456))
random_seed.set_random_seed(654321) random_seed.set_random_seed(654321)
student2 = student_t.StudentT( student2 = student_t.StudentT(df=df, loc=mu, scale=sigma, name="student_t2")
df=df, loc=mu, scale=sigma, name="student_t2")
samples2 = self.evaluate(student2.sample(n, seed=123456)) samples2 = self.evaluate(student2.sample(n, seed=123456))
self.assertAllClose(samples1, samples2) self.assertAllClose(samples1, samples2)
def testStudentSampleSmallDfNoNan(self): def testStudentSampleSmallDfNoNan(self):
with self.test_session():
df_v = [1e-1, 1e-5, 1e-10, 1e-20] df_v = [1e-1, 1e-5, 1e-10, 1e-20]
df = constant_op.constant(df_v) df = constant_op.constant(df_v)
n = constant_op.constant(200000) n = constant_op.constant(200000)
@ -214,12 +202,11 @@ class StudentTTest(test.TestCase):
self.assertTrue(np.all(np.logical_not(np.isnan(sample_values)))) self.assertTrue(np.all(np.logical_not(np.isnan(sample_values))))
def testStudentSampleMultiDimensional(self): def testStudentSampleMultiDimensional(self):
with self.test_session():
batch_size = 7 batch_size = 7
df = constant_op.constant([[5., 7.]] * batch_size) df = constant_op.constant([[5., 7.]] * batch_size)
mu = constant_op.constant([[3., -3.]] * batch_size) mu = constant_op.constant([[3., -3.]] * batch_size)
sigma = constant_op.constant([[math.sqrt(10.), math.sqrt(15.)]] * sigma = constant_op.constant(
batch_size) [[math.sqrt(10.), math.sqrt(15.)]] * batch_size)
df_v = [5., 7.] df_v = [5., 7.]
mu_v = [3., -3.] mu_v = [3., -3.]
sigma_v = [np.sqrt(10.), np.sqrt(15.)] sigma_v = [np.sqrt(10.), np.sqrt(15.)]
@ -325,33 +312,27 @@ class StudentTTest(test.TestCase):
_check2d_rows(student_t.StudentT(df=7., loc=3., scale=[[2.], [3.], [4.]])) _check2d_rows(student_t.StudentT(df=7., loc=3., scale=[[2.], [3.], [4.]]))
def testMeanAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self): def testMeanAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self):
with self.test_session():
mu = [1., 3.3, 4.4] mu = [1., 3.3, 4.4]
student = student_t.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.]) student = student_t.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.])
mean = self.evaluate(student.mean()) mean = self.evaluate(student.mean())
self.assertAllClose([1., 3.3, 4.4], mean) self.assertAllClose([1., 3.3, 4.4], mean)
def testMeanAllowNanStatsIsFalseRaisesWhenBatchMemberIsUndefined(self): def testMeanAllowNanStatsIsFalseRaisesWhenBatchMemberIsUndefined(self):
with self.test_session():
mu = [1., 3.3, 4.4] mu = [1., 3.3, 4.4]
student = student_t.StudentT( student = student_t.StudentT(
df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.], df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.], allow_nan_stats=False)
allow_nan_stats=False)
with self.assertRaisesOpError("x < y"): with self.assertRaisesOpError("x < y"):
self.evaluate(student.mean()) self.evaluate(student.mean())
def testMeanAllowNanStatsIsTrueReturnsNaNForUndefinedBatchMembers(self): def testMeanAllowNanStatsIsTrueReturnsNaNForUndefinedBatchMembers(self):
with self.test_session():
mu = [-2, 0., 1., 3.3, 4.4] mu = [-2, 0., 1., 3.3, 4.4]
sigma = [5., 4., 3., 2., 1.] sigma = [5., 4., 3., 2., 1.]
student = student_t.StudentT( student = student_t.StudentT(
df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma, df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma, allow_nan_stats=True)
allow_nan_stats=True)
mean = self.evaluate(student.mean()) mean = self.evaluate(student.mean())
self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean) self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean)
def testVarianceAllowNanStatsTrueReturnsNaNforUndefinedBatchMembers(self): def testVarianceAllowNanStatsTrueReturnsNaNforUndefinedBatchMembers(self):
with self.test_session():
# df = 0.5 ==> undefined mean ==> undefined variance. # df = 0.5 ==> undefined mean ==> undefined variance.
# df = 1.5 ==> infinite variance. # df = 1.5 ==> infinite variance.
df = [0.5, 1.5, 3., 5., 7.] df = [0.5, 1.5, 3., 5., 7.]
@ -376,7 +357,6 @@ class StudentTTest(test.TestCase):
def testVarianceAllowNanStatsFalseGivesCorrectValueForDefinedBatchMembers( def testVarianceAllowNanStatsFalseGivesCorrectValueForDefinedBatchMembers(
self): self):
with self.test_session():
# df = 1.5 ==> infinite variance. # df = 1.5 ==> infinite variance.
df = [1.5, 3., 5., 7.] df = [1.5, 3., 5., 7.]
mu = [0., 1., 3.3, 4.4] mu = [0., 1., 3.3, 4.4]
@ -392,14 +372,11 @@ class StudentTTest(test.TestCase):
self.assertAllClose(expected_var, var) self.assertAllClose(expected_var, var)
def testVarianceAllowNanStatsFalseRaisesForUndefinedBatchMembers(self): def testVarianceAllowNanStatsFalseRaisesForUndefinedBatchMembers(self):
with self.test_session():
# df <= 1 ==> variance not defined # df <= 1 ==> variance not defined
student = student_t.StudentT( student = student_t.StudentT(df=1., loc=0., scale=1., allow_nan_stats=False)
df=1., loc=0., scale=1., allow_nan_stats=False)
with self.assertRaisesOpError("x < y"): with self.assertRaisesOpError("x < y"):
self.evaluate(student.variance()) self.evaluate(student.variance())
with self.test_session():
# df <= 1 ==> variance not defined # df <= 1 ==> variance not defined
student = student_t.StudentT( student = student_t.StudentT(
df=0.5, loc=0., scale=1., allow_nan_stats=False) df=0.5, loc=0., scale=1., allow_nan_stats=False)
@ -407,7 +384,6 @@ class StudentTTest(test.TestCase):
self.evaluate(student.variance()) self.evaluate(student.variance())
def testStd(self): def testStd(self):
with self.test_session():
# Defined for all batch members. # Defined for all batch members.
df = [3.5, 5., 3., 5., 7.] df = [3.5, 5., 3., 5., 7.]
mu = [-2.2] mu = [-2.2]
@ -425,7 +401,6 @@ class StudentTTest(test.TestCase):
self.assertAllClose(expected_stddev, stddev) self.assertAllClose(expected_stddev, stddev)
def testMode(self): def testMode(self):
with self.test_session():
df = [0.5, 1., 3] df = [0.5, 1., 3]
mu = [-1, 0., 1] mu = [-1, 0., 1]
sigma = [5., 4., 3.] sigma = [5., 4., 3.]
@ -510,14 +485,12 @@ class StudentTTest(test.TestCase):
self.assertNear(1., total, err=err) self.assertNear(1., total, err=err)
def testNegativeDofFails(self): def testNegativeDofFails(self):
with self.test_session():
with self.assertRaisesOpError(r"Condition x > 0 did not hold"): with self.assertRaisesOpError(r"Condition x > 0 did not hold"):
student = student_t.StudentT( student = student_t.StudentT(
df=[2, -5.], loc=0., scale=1., validate_args=True, name="S") df=[2, -5.], loc=0., scale=1., validate_args=True, name="S")
self.evaluate(student.mean()) self.evaluate(student.mean())
def testStudentTWithAbsDfSoftplusScale(self): def testStudentTWithAbsDfSoftplusScale(self):
with self.test_session():
df = constant_op.constant([-3.2, -4.6]) df = constant_op.constant([-3.2, -4.6])
mu = constant_op.constant([-4.2, 3.4]) mu = constant_op.constant([-4.2, 3.4])
sigma = constant_op.constant([-6.4, -8.8]) sigma = constant_op.constant([-6.4, -8.8])

View File

@ -50,7 +50,6 @@ class UniformTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testUniformRange(self): def testUniformRange(self):
with self.test_session():
a = 3.0 a = 3.0
b = 10.0 b = 10.0
uniform = uniform_lib.Uniform(low=a, high=b) uniform = uniform_lib.Uniform(low=a, high=b)
@ -60,7 +59,6 @@ class UniformTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testUniformPDF(self): def testUniformPDF(self):
with self.test_session():
a = constant_op.constant([-3.0] * 5 + [15.0]) a = constant_op.constant([-3.0] * 5 + [15.0])
b = constant_op.constant([11.0] * 5 + [20.0]) b = constant_op.constant([11.0] * 5 + [20.0])
uniform = uniform_lib.Uniform(low=a, high=b) uniform = uniform_lib.Uniform(low=a, high=b)
@ -86,7 +84,6 @@ class UniformTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testUniformShape(self): def testUniformShape(self):
with self.test_session():
a = constant_op.constant([-3.0] * 5) a = constant_op.constant([-3.0] * 5)
b = constant_op.constant(11.0) b = constant_op.constant(11.0)
uniform = uniform_lib.Uniform(low=a, high=b) uniform = uniform_lib.Uniform(low=a, high=b)
@ -98,7 +95,6 @@ class UniformTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testUniformPDFWithScalarEndpoint(self): def testUniformPDFWithScalarEndpoint(self):
with self.test_session():
a = constant_op.constant([0.0, 5.0]) a = constant_op.constant([0.0, 5.0])
b = constant_op.constant(10.0) b = constant_op.constant(10.0)
uniform = uniform_lib.Uniform(low=a, high=b) uniform = uniform_lib.Uniform(low=a, high=b)
@ -111,7 +107,6 @@ class UniformTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testUniformCDF(self): def testUniformCDF(self):
with self.test_session():
batch_size = 6 batch_size = 6
a = constant_op.constant([1.0] * batch_size) a = constant_op.constant([1.0] * batch_size)
b = constant_op.constant([11.0] * batch_size) b = constant_op.constant([11.0] * batch_size)
@ -135,7 +130,6 @@ class UniformTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testUniformEntropy(self): def testUniformEntropy(self):
with self.test_session():
a_v = np.array([1.0, 1.0, 1.0]) a_v = np.array([1.0, 1.0, 1.0])
b_v = np.array([[1.5, 2.0, 3.0]]) b_v = np.array([[1.5, 2.0, 3.0]])
uniform = uniform_lib.Uniform(low=a_v, high=b_v) uniform = uniform_lib.Uniform(low=a_v, high=b_v)
@ -145,7 +139,6 @@ class UniformTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testUniformAssertMaxGtMin(self): def testUniformAssertMaxGtMin(self):
with self.test_session():
a_v = np.array([1.0, 1.0, 1.0], dtype=np.float32) a_v = np.array([1.0, 1.0, 1.0], dtype=np.float32)
b_v = np.array([1.0, 2.0, 3.0], dtype=np.float32) b_v = np.array([1.0, 2.0, 3.0], dtype=np.float32)
@ -156,7 +149,6 @@ class UniformTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testUniformSample(self): def testUniformSample(self):
with self.test_session():
a = constant_op.constant([3.0, 4.0]) a = constant_op.constant([3.0, 4.0])
b = constant_op.constant(13.0) b = constant_op.constant(13.0)
a1_v = 3.0 a1_v = 3.0
@ -180,7 +172,6 @@ class UniformTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def _testUniformSampleMultiDimensional(self): def _testUniformSampleMultiDimensional(self):
# DISABLED: Please enable this test once b/issues/30149644 is resolved. # DISABLED: Please enable this test once b/issues/30149644 is resolved.
with self.test_session():
batch_size = 2 batch_size = 2
a_v = [3.0, 22.0] a_v = [3.0, 22.0]
b_v = [13.0, 35.0] b_v = [13.0, 35.0]
@ -210,7 +201,6 @@ class UniformTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testUniformMean(self): def testUniformMean(self):
with self.test_session():
a = 10.0 a = 10.0
b = 100.0 b = 100.0
uniform = uniform_lib.Uniform(low=a, high=b) uniform = uniform_lib.Uniform(low=a, high=b)
@ -221,7 +211,6 @@ class UniformTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testUniformVariance(self): def testUniformVariance(self):
with self.test_session():
a = 10.0 a = 10.0
b = 100.0 b = 100.0
uniform = uniform_lib.Uniform(low=a, high=b) uniform = uniform_lib.Uniform(low=a, high=b)
@ -232,7 +221,6 @@ class UniformTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testUniformStd(self): def testUniformStd(self):
with self.test_session():
a = 10.0 a = 10.0
b = 100.0 b = 100.0
uniform = uniform_lib.Uniform(low=a, high=b) uniform = uniform_lib.Uniform(low=a, high=b)
@ -243,7 +231,6 @@ class UniformTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testUniformNans(self): def testUniformNans(self):
with self.test_session():
a = 10.0 a = 10.0
b = [11.0, 100.0] b = [11.0, 100.0]
uniform = uniform_lib.Uniform(low=a, high=b) uniform = uniform_lib.Uniform(low=a, high=b)
@ -261,7 +248,6 @@ class UniformTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testUniformSamplePdf(self): def testUniformSamplePdf(self):
with self.test_session():
a = 10.0 a = 10.0
b = [11.0, 100.0] b = [11.0, 100.0]
uniform = uniform_lib.Uniform(a, b) uniform = uniform_lib.Uniform(a, b)
@ -271,7 +257,6 @@ class UniformTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testUniformBroadcasting(self): def testUniformBroadcasting(self):
with self.test_session():
a = 10.0 a = 10.0
b = [11.0, 20.0] b = [11.0, 20.0]
uniform = uniform_lib.Uniform(a, b) uniform = uniform_lib.Uniform(a, b)
@ -282,7 +267,6 @@ class UniformTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testUniformSampleWithShape(self): def testUniformSampleWithShape(self):
with self.test_session():
a = 10.0 a = 10.0
b = [11.0, 20.0] b = [11.0, 20.0]
uniform = uniform_lib.Uniform(a, b) uniform = uniform_lib.Uniform(a, b)

View File

@ -69,7 +69,7 @@ class AssertCloseTest(test.TestCase):
w = array_ops.placeholder(dtypes.float32) w = array_ops.placeholder(dtypes.float32)
feed_dict = {x: [1., 5, 10, 15, 20], y: [1.1, 5, 10, 15, 20], feed_dict = {x: [1., 5, 10, 15, 20], y: [1.1, 5, 10, 15, 20],
z: [1.0001, 5, 10, 15, 20], w: [1e-8, 5, 10, 15, 20]} z: [1.0001, 5, 10, 15, 20], w: [1e-8, 5, 10, 15, 20]}
with self.test_session(): with self.cached_session():
with ops.control_dependencies([du.assert_integer_form(x)]): with ops.control_dependencies([du.assert_integer_form(x)]):
array_ops.identity(x).eval(feed_dict=feed_dict) array_ops.identity(x).eval(feed_dict=feed_dict)
@ -122,7 +122,6 @@ class GetLogitsAndProbsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testImproperArguments(self): def testImproperArguments(self):
with self.test_session():
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
du.get_logits_and_probs(logits=None, probs=None) du.get_logits_and_probs(logits=None, probs=None)
@ -134,7 +133,6 @@ class GetLogitsAndProbsTest(test.TestCase):
p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32) p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32)
logits = _logit(p) logits = _logit(p)
with self.test_session():
new_logits, new_p = du.get_logits_and_probs( new_logits, new_p = du.get_logits_and_probs(
logits=logits, validate_args=True) logits=logits, validate_args=True)
@ -146,7 +144,6 @@ class GetLogitsAndProbsTest(test.TestCase):
p = np.array([0.2, 0.3, 0.5], dtype=np.float32) p = np.array([0.2, 0.3, 0.5], dtype=np.float32)
logits = np.log(p) logits = np.log(p)
with self.test_session():
new_logits, new_p = du.get_logits_and_probs( new_logits, new_p = du.get_logits_and_probs(
logits=logits, multidimensional=True, validate_args=True) logits=logits, multidimensional=True, validate_args=True)
@ -157,9 +154,7 @@ class GetLogitsAndProbsTest(test.TestCase):
def testProbability(self): def testProbability(self):
p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32) p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32)
with self.test_session(): new_logits, new_p = du.get_logits_and_probs(probs=p, validate_args=True)
new_logits, new_p = du.get_logits_and_probs(
probs=p, validate_args=True)
self.assertAllClose(_logit(p), self.evaluate(new_logits)) self.assertAllClose(_logit(p), self.evaluate(new_logits))
self.assertAllClose(p, self.evaluate(new_p)) self.assertAllClose(p, self.evaluate(new_p))
@ -168,7 +163,6 @@ class GetLogitsAndProbsTest(test.TestCase):
def testProbabilityMultidimensional(self): def testProbabilityMultidimensional(self):
p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32) p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32)
with self.test_session():
new_logits, new_p = du.get_logits_and_probs( new_logits, new_p = du.get_logits_and_probs(
probs=p, multidimensional=True, validate_args=True) probs=p, multidimensional=True, validate_args=True)
@ -183,27 +177,21 @@ class GetLogitsAndProbsTest(test.TestCase):
# Component greater than 1. # Component greater than 1.
p3 = [2, 0.2, 0.5, 0.3, .2] p3 = [2, 0.2, 0.5, 0.3, .2]
with self.test_session(): _, prob = du.get_logits_and_probs(probs=p, validate_args=True)
_, prob = du.get_logits_and_probs(
probs=p, validate_args=True)
self.evaluate(prob) self.evaluate(prob)
with self.assertRaisesOpError("Condition x >= 0"): with self.assertRaisesOpError("Condition x >= 0"):
_, prob = du.get_logits_and_probs( _, prob = du.get_logits_and_probs(probs=p2, validate_args=True)
probs=p2, validate_args=True)
self.evaluate(prob) self.evaluate(prob)
_, prob = du.get_logits_and_probs( _, prob = du.get_logits_and_probs(probs=p2, validate_args=False)
probs=p2, validate_args=False)
self.evaluate(prob) self.evaluate(prob)
with self.assertRaisesOpError("probs has components greater than 1"): with self.assertRaisesOpError("probs has components greater than 1"):
_, prob = du.get_logits_and_probs( _, prob = du.get_logits_and_probs(probs=p3, validate_args=True)
probs=p3, validate_args=True)
self.evaluate(prob) self.evaluate(prob)
_, prob = du.get_logits_and_probs( _, prob = du.get_logits_and_probs(probs=p3, validate_args=False)
probs=p3, validate_args=False)
self.evaluate(prob) self.evaluate(prob)
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
@ -216,9 +204,7 @@ class GetLogitsAndProbsTest(test.TestCase):
# Does not sum to 1. # Does not sum to 1.
p4 = np.array([[1.1, 0.3, 0.4], [0.1, 0.5, 0.4]], dtype=np.float32) p4 = np.array([[1.1, 0.3, 0.4], [0.1, 0.5, 0.4]], dtype=np.float32)
with self.test_session(): _, prob = du.get_logits_and_probs(probs=p, multidimensional=True)
_, prob = du.get_logits_and_probs(
probs=p, multidimensional=True)
self.evaluate(prob) self.evaluate(prob)
with self.assertRaisesOpError("Condition x >= 0"): with self.assertRaisesOpError("Condition x >= 0"):
@ -250,7 +236,7 @@ class GetLogitsAndProbsTest(test.TestCase):
self.evaluate(prob) self.evaluate(prob)
def testProbsMultidimShape(self): def testProbsMultidimShape(self):
with self.test_session(): with self.cached_session():
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
p = array_ops.ones([int(2**11+1)], dtype=np.float16) p = array_ops.ones([int(2**11+1)], dtype=np.float16)
du.get_logits_and_probs( du.get_logits_and_probs(
@ -264,7 +250,7 @@ class GetLogitsAndProbsTest(test.TestCase):
prob.eval(feed_dict={p: np.ones([int(2**11+1)])}) prob.eval(feed_dict={p: np.ones([int(2**11+1)])})
def testLogitsMultidimShape(self): def testLogitsMultidimShape(self):
with self.test_session(): with self.cached_session():
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
l = array_ops.ones([int(2**11+1)], dtype=np.float16) l = array_ops.ones([int(2**11+1)], dtype=np.float16)
du.get_logits_and_probs( du.get_logits_and_probs(
@ -281,7 +267,7 @@ class GetLogitsAndProbsTest(test.TestCase):
class EmbedCheckCategoricalEventShapeTest(test.TestCase): class EmbedCheckCategoricalEventShapeTest(test.TestCase):
def testTooSmall(self): def testTooSmall(self):
with self.test_session(): with self.cached_session():
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
param = array_ops.ones([1], dtype=np.float16) param = array_ops.ones([1], dtype=np.float16)
checked_param = du.embed_check_categorical_event_shape( checked_param = du.embed_check_categorical_event_shape(
@ -295,7 +281,7 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase):
checked_param.eval(feed_dict={param: np.ones([1])}) checked_param.eval(feed_dict={param: np.ones([1])})
def testTooLarge(self): def testTooLarge(self):
with self.test_session(): with self.cached_session():
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
param = array_ops.ones([int(2**11+1)], dtype=dtypes.float16) param = array_ops.ones([int(2**11+1)], dtype=dtypes.float16)
checked_param = du.embed_check_categorical_event_shape( checked_param = du.embed_check_categorical_event_shape(
@ -310,7 +296,6 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testUnsupportedDtype(self): def testUnsupportedDtype(self):
with self.test_session():
param = ops.convert_to_tensor( param = ops.convert_to_tensor(
np.ones([2**11 + 1]).astype(dtypes.qint16.as_numpy_dtype), np.ones([2**11 + 1]).astype(dtypes.qint16.as_numpy_dtype),
dtype=dtypes.qint16) dtype=dtypes.qint16)
@ -321,7 +306,7 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase):
class EmbedCheckIntegerCastingClosedTest(test.TestCase): class EmbedCheckIntegerCastingClosedTest(test.TestCase):
def testCorrectlyAssertsNonnegative(self): def testCorrectlyAssertsNonnegative(self):
with self.test_session(): with self.cached_session():
with self.assertRaisesOpError("Elements must be non-negative"): with self.assertRaisesOpError("Elements must be non-negative"):
x = array_ops.placeholder(dtype=dtypes.float16) x = array_ops.placeholder(dtype=dtypes.float16)
x_checked = du.embed_check_integer_casting_closed( x_checked = du.embed_check_integer_casting_closed(
@ -329,7 +314,7 @@ class EmbedCheckIntegerCastingClosedTest(test.TestCase):
x_checked.eval(feed_dict={x: np.array([1, -1], dtype=np.float16)}) x_checked.eval(feed_dict={x: np.array([1, -1], dtype=np.float16)})
def testCorrectlyAssersIntegerForm(self): def testCorrectlyAssersIntegerForm(self):
with self.test_session(): with self.cached_session():
with self.assertRaisesOpError("Elements must be int16-equivalent."): with self.assertRaisesOpError("Elements must be int16-equivalent."):
x = array_ops.placeholder(dtype=dtypes.float16) x = array_ops.placeholder(dtype=dtypes.float16)
x_checked = du.embed_check_integer_casting_closed( x_checked = du.embed_check_integer_casting_closed(
@ -337,7 +322,7 @@ class EmbedCheckIntegerCastingClosedTest(test.TestCase):
x_checked.eval(feed_dict={x: np.array([1, 1.5], dtype=np.float16)}) x_checked.eval(feed_dict={x: np.array([1, 1.5], dtype=np.float16)})
def testCorrectlyAssertsLargestPossibleInteger(self): def testCorrectlyAssertsLargestPossibleInteger(self):
with self.test_session(): with self.cached_session():
with self.assertRaisesOpError("Elements cannot exceed 32767."): with self.assertRaisesOpError("Elements cannot exceed 32767."):
x = array_ops.placeholder(dtype=dtypes.int32) x = array_ops.placeholder(dtype=dtypes.int32)
x_checked = du.embed_check_integer_casting_closed( x_checked = du.embed_check_integer_casting_closed(
@ -345,7 +330,7 @@ class EmbedCheckIntegerCastingClosedTest(test.TestCase):
x_checked.eval(feed_dict={x: np.array([1, 2**15], dtype=np.int32)}) x_checked.eval(feed_dict={x: np.array([1, 2**15], dtype=np.int32)})
def testCorrectlyAssertsSmallestPossibleInteger(self): def testCorrectlyAssertsSmallestPossibleInteger(self):
with self.test_session(): with self.cached_session():
with self.assertRaisesOpError("Elements cannot be smaller than 0."): with self.assertRaisesOpError("Elements cannot be smaller than 0."):
x = array_ops.placeholder(dtype=dtypes.int32) x = array_ops.placeholder(dtype=dtypes.int32)
x_checked = du.embed_check_integer_casting_closed( x_checked = du.embed_check_integer_casting_closed(
@ -365,7 +350,6 @@ class LogCombinationsTest(test.TestCase):
log_combs = np.log(special.binom(n, k)) log_combs = np.log(special.binom(n, k))
with self.test_session():
n = np.array(n, dtype=np.float32) n = np.array(n, dtype=np.float32)
counts = [[1., 1], [2., 3], [4., 8], [11, 4]] counts = [[1., 1], [2., 3], [4., 8], [11, 4]]
log_binom = du.log_combinations(n, counts) log_binom = du.log_combinations(n, counts)
@ -376,7 +360,6 @@ class LogCombinationsTest(test.TestCase):
# Shape [2, 2] # Shape [2, 2]
n = [[2, 5], [12, 15]] n = [[2, 5], [12, 15]]
with self.test_session():
n = np.array(n, dtype=np.float32) n = np.array(n, dtype=np.float32)
# Shape [2, 2, 4] # Shape [2, 2, 4]
counts = [[[1., 1, 0, 0], [2., 2, 1, 0]], [[4., 4, 1, 3], [10, 1, 1, 4]]] counts = [[[1., 1, 0, 0], [2., 2, 1, 0]], [[4., 4, 1, 3], [10, 1, 1, 4]]]
@ -387,7 +370,7 @@ class LogCombinationsTest(test.TestCase):
class DynamicShapeTest(test.TestCase): class DynamicShapeTest(test.TestCase):
def testSameDynamicShape(self): def testSameDynamicShape(self):
with self.test_session(): with self.cached_session():
scalar = constant_op.constant(2.0) scalar = constant_op.constant(2.0)
scalar1 = array_ops.placeholder(dtype=dtypes.float32) scalar1 = array_ops.placeholder(dtype=dtypes.float32)
@ -497,7 +480,6 @@ class RotateTransposeTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testRollStatic(self): def testRollStatic(self):
with self.test_session():
if context.executing_eagerly(): if context.executing_eagerly():
error_message = r"Attempt to convert a value \(None\)" error_message = r"Attempt to convert a value \(None\)"
else: else:
@ -512,7 +494,7 @@ class RotateTransposeTest(test.TestCase):
self.assertAllEqual(np.roll(x.shape, shift), y.get_shape().as_list()) self.assertAllEqual(np.roll(x.shape, shift), y.get_shape().as_list())
def testRollDynamic(self): def testRollDynamic(self):
with self.test_session() as sess: with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32) x = array_ops.placeholder(dtypes.float32)
shift = array_ops.placeholder(dtypes.int32) shift = array_ops.placeholder(dtypes.int32)
for x_value in (np.ones( for x_value in (np.ones(
@ -530,7 +512,7 @@ class RotateTransposeTest(test.TestCase):
class PickVectorTest(test.TestCase): class PickVectorTest(test.TestCase):
def testCorrectlyPicksVector(self): def testCorrectlyPicksVector(self):
with self.test_session(): with self.cached_session():
x = np.arange(10, 12) x = np.arange(10, 12)
y = np.arange(15, 18) y = np.arange(15, 18)
self.assertAllEqual( self.assertAllEqual(
@ -568,19 +550,19 @@ class PreferStaticRankTest(test.TestCase):
def testDynamicRankEndsUpBeingNonEmpty(self): def testDynamicRankEndsUpBeingNonEmpty(self):
x = array_ops.placeholder(np.float64, shape=None) x = array_ops.placeholder(np.float64, shape=None)
rank = du.prefer_static_rank(x) rank = du.prefer_static_rank(x)
with self.test_session(): with self.cached_session():
self.assertAllEqual(2, rank.eval(feed_dict={x: np.zeros((2, 3))})) self.assertAllEqual(2, rank.eval(feed_dict={x: np.zeros((2, 3))}))
def testDynamicRankEndsUpBeingEmpty(self): def testDynamicRankEndsUpBeingEmpty(self):
x = array_ops.placeholder(np.int32, shape=None) x = array_ops.placeholder(np.int32, shape=None)
rank = du.prefer_static_rank(x) rank = du.prefer_static_rank(x)
with self.test_session(): with self.cached_session():
self.assertAllEqual(1, rank.eval(feed_dict={x: []})) self.assertAllEqual(1, rank.eval(feed_dict={x: []}))
def testDynamicRankEndsUpBeingScalar(self): def testDynamicRankEndsUpBeingScalar(self):
x = array_ops.placeholder(np.int32, shape=None) x = array_ops.placeholder(np.int32, shape=None)
rank = du.prefer_static_rank(x) rank = du.prefer_static_rank(x)
with self.test_session(): with self.cached_session():
self.assertAllEqual(0, rank.eval(feed_dict={x: 1})) self.assertAllEqual(0, rank.eval(feed_dict={x: 1}))
@ -607,19 +589,19 @@ class PreferStaticShapeTest(test.TestCase):
def testDynamicShapeEndsUpBeingNonEmpty(self): def testDynamicShapeEndsUpBeingNonEmpty(self):
x = array_ops.placeholder(np.float64, shape=None) x = array_ops.placeholder(np.float64, shape=None)
shape = du.prefer_static_shape(x) shape = du.prefer_static_shape(x)
with self.test_session(): with self.cached_session():
self.assertAllEqual((2, 3), shape.eval(feed_dict={x: np.zeros((2, 3))})) self.assertAllEqual((2, 3), shape.eval(feed_dict={x: np.zeros((2, 3))}))
def testDynamicShapeEndsUpBeingEmpty(self): def testDynamicShapeEndsUpBeingEmpty(self):
x = array_ops.placeholder(np.int32, shape=None) x = array_ops.placeholder(np.int32, shape=None)
shape = du.prefer_static_shape(x) shape = du.prefer_static_shape(x)
with self.test_session(): with self.cached_session():
self.assertAllEqual(np.array([0]), shape.eval(feed_dict={x: []})) self.assertAllEqual(np.array([0]), shape.eval(feed_dict={x: []}))
def testDynamicShapeEndsUpBeingScalar(self): def testDynamicShapeEndsUpBeingScalar(self):
x = array_ops.placeholder(np.int32, shape=None) x = array_ops.placeholder(np.int32, shape=None)
shape = du.prefer_static_shape(x) shape = du.prefer_static_shape(x)
with self.test_session(): with self.cached_session():
self.assertAllEqual(np.array([]), shape.eval(feed_dict={x: 1})) self.assertAllEqual(np.array([]), shape.eval(feed_dict={x: 1}))
@ -646,20 +628,20 @@ class PreferStaticValueTest(test.TestCase):
def testDynamicValueEndsUpBeingNonEmpty(self): def testDynamicValueEndsUpBeingNonEmpty(self):
x = array_ops.placeholder(np.float64, shape=None) x = array_ops.placeholder(np.float64, shape=None)
value = du.prefer_static_value(x) value = du.prefer_static_value(x)
with self.test_session(): with self.cached_session():
self.assertAllEqual(np.zeros((2, 3)), self.assertAllEqual(np.zeros((2, 3)),
value.eval(feed_dict={x: np.zeros((2, 3))})) value.eval(feed_dict={x: np.zeros((2, 3))}))
def testDynamicValueEndsUpBeingEmpty(self): def testDynamicValueEndsUpBeingEmpty(self):
x = array_ops.placeholder(np.int32, shape=None) x = array_ops.placeholder(np.int32, shape=None)
value = du.prefer_static_value(x) value = du.prefer_static_value(x)
with self.test_session(): with self.cached_session():
self.assertAllEqual(np.array([]), value.eval(feed_dict={x: []})) self.assertAllEqual(np.array([]), value.eval(feed_dict={x: []}))
def testDynamicValueEndsUpBeingScalar(self): def testDynamicValueEndsUpBeingScalar(self):
x = array_ops.placeholder(np.int32, shape=None) x = array_ops.placeholder(np.int32, shape=None)
value = du.prefer_static_value(x) value = du.prefer_static_value(x)
with self.test_session(): with self.cached_session():
self.assertAllEqual(np.array(1), value.eval(feed_dict={x: 1})) self.assertAllEqual(np.array(1), value.eval(feed_dict={x: 1}))
@ -691,7 +673,7 @@ class FillTriangularTest(test.TestCase):
def _run_test(self, x_, use_deferred_shape=False, **kwargs): def _run_test(self, x_, use_deferred_shape=False, **kwargs):
x_ = np.asarray(x_) x_ = np.asarray(x_)
with self.test_session() as sess: with self.cached_session() as sess:
static_shape = None if use_deferred_shape else x_.shape static_shape = None if use_deferred_shape else x_.shape
x_pl = array_ops.placeholder_with_default(x_, shape=static_shape) x_pl = array_ops.placeholder_with_default(x_, shape=static_shape)
# Add `zeros_like(x)` such that x's value and gradient are identical. We # Add `zeros_like(x)` such that x's value and gradient are identical. We
@ -761,7 +743,7 @@ class FillTriangularInverseTest(FillTriangularTest):
def _run_test(self, x_, use_deferred_shape=False, **kwargs): def _run_test(self, x_, use_deferred_shape=False, **kwargs):
x_ = np.asarray(x_) x_ = np.asarray(x_)
with self.test_session() as sess: with self.cached_session() as sess:
static_shape = None if use_deferred_shape else x_.shape static_shape = None if use_deferred_shape else x_.shape
x_pl = array_ops.placeholder_with_default(x_, shape=static_shape) x_pl = array_ops.placeholder_with_default(x_, shape=static_shape)
zeros_like_x_pl = (x_pl * array_ops.stop_gradient(x_pl - 1.) zeros_like_x_pl = (x_pl * array_ops.stop_gradient(x_pl - 1.)
@ -795,7 +777,7 @@ class ReduceWeightedLogSumExp(test.TestCase):
logx_ = np.array([[0., -1, 1000.], logx_ = np.array([[0., -1, 1000.],
[0, 1, -1000.], [0, 1, -1000.],
[-5, 0, 5]]) [-5, 0, 5]])
with self.test_session() as sess: with self.cached_session() as sess:
logx = constant_op.constant(logx_) logx = constant_op.constant(logx_)
expected = math_ops.reduce_logsumexp(logx, axis=-1) expected = math_ops.reduce_logsumexp(logx, axis=-1)
grad_expected = gradients_impl.gradients(expected, logx)[0] grad_expected = gradients_impl.gradients(expected, logx)[0]
@ -818,7 +800,7 @@ class ReduceWeightedLogSumExp(test.TestCase):
[1, -2, 1], [1, -2, 1],
[1, 0, 1]]) [1, 0, 1]])
expected, _ = self._reduce_weighted_logsumexp(logx_, w_, axis=-1) expected, _ = self._reduce_weighted_logsumexp(logx_, w_, axis=-1)
with self.test_session() as sess: with self.cached_session() as sess:
logx = constant_op.constant(logx_) logx = constant_op.constant(logx_)
w = constant_op.constant(w_) w = constant_op.constant(w_)
actual, actual_sgn = du.reduce_weighted_logsumexp( actual, actual_sgn = du.reduce_weighted_logsumexp(
@ -836,7 +818,7 @@ class ReduceWeightedLogSumExp(test.TestCase):
[1, 0, 1]]) [1, 0, 1]])
expected, _ = self._reduce_weighted_logsumexp( expected, _ = self._reduce_weighted_logsumexp(
logx_, w_, axis=-1, keep_dims=True) logx_, w_, axis=-1, keep_dims=True)
with self.test_session() as sess: with self.cached_session() as sess:
logx = constant_op.constant(logx_) logx = constant_op.constant(logx_)
w = constant_op.constant(w_) w = constant_op.constant(w_)
actual, actual_sgn = du.reduce_weighted_logsumexp( actual, actual_sgn = du.reduce_weighted_logsumexp(
@ -848,7 +830,7 @@ class ReduceWeightedLogSumExp(test.TestCase):
def testDocString(self): def testDocString(self):
"""This test verifies the correctness of the docstring examples.""" """This test verifies the correctness of the docstring examples."""
with self.test_session(): with self.cached_session():
x = constant_op.constant([[0., 0, 0], x = constant_op.constant([[0., 0, 0],
[0, 0, 0]]) [0, 0, 0]])
@ -952,7 +934,7 @@ class SoftplusTest(test.TestCase):
use_gpu=True) use_gpu=True)
def testGradient(self): def testGradient(self):
with self.test_session(): with self.cached_session():
x = constant_op.constant( x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5], shape=[2, 5],
@ -968,7 +950,7 @@ class SoftplusTest(test.TestCase):
self.assertLess(err, 1e-4) self.assertLess(err, 1e-4)
def testInverseSoftplusGradientNeverNan(self): def testInverseSoftplusGradientNeverNan(self):
with self.test_session(): with self.cached_session():
# Note that this range contains both zero and inf. # Note that this range contains both zero and inf.
x = constant_op.constant(np.logspace(-8, 6).astype(np.float16)) x = constant_op.constant(np.logspace(-8, 6).astype(np.float16))
y = du.softplus_inverse(x) y = du.softplus_inverse(x)
@ -977,7 +959,7 @@ class SoftplusTest(test.TestCase):
self.assertAllEqual(np.zeros_like(grads).astype(np.bool), np.isnan(grads)) self.assertAllEqual(np.zeros_like(grads).astype(np.bool), np.isnan(grads))
def testInverseSoftplusGradientFinite(self): def testInverseSoftplusGradientFinite(self):
with self.test_session(): with self.cached_session():
# This range of x is all finite, and so is 1 / x. So the # This range of x is all finite, and so is 1 / x. So the
# gradient and its approximations should be finite as well. # gradient and its approximations should be finite as well.
x = constant_op.constant(np.logspace(-4.8, 4.5).astype(np.float16)) x = constant_op.constant(np.logspace(-4.8, 4.5).astype(np.float16))