diff --git a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
index 9ad77a54cbc..26d013bccb9 100644
--- a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
+++ b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
@@ -62,59 +62,50 @@ class BernoulliTest(test.TestCase):
   def testP(self):
     p = [0.2, 0.4]
     dist = bernoulli.Bernoulli(probs=p)
-    with self.test_session():
-      self.assertAllClose(p, self.evaluate(dist.probs))
+    self.assertAllClose(p, self.evaluate(dist.probs))
 
   @test_util.run_in_graph_and_eager_modes
   def testLogits(self):
     logits = [-42., 42.]
     dist = bernoulli.Bernoulli(logits=logits)
-    with self.test_session():
-      self.assertAllClose(logits, self.evaluate(dist.logits))
+    self.assertAllClose(logits, self.evaluate(dist.logits))
 
     if not special:
       return
 
-    with self.test_session():
-      self.assertAllClose(special.expit(logits), self.evaluate(dist.probs))
+    self.assertAllClose(special.expit(logits), self.evaluate(dist.probs))
 
     p = [0.01, 0.99, 0.42]
     dist = bernoulli.Bernoulli(probs=p)
-    with self.test_session():
-      self.assertAllClose(special.logit(p), self.evaluate(dist.logits))
+    self.assertAllClose(special.logit(p), self.evaluate(dist.logits))
 
   @test_util.run_in_graph_and_eager_modes
   def testInvalidP(self):
     invalid_ps = [1.01, 2.]
     for p in invalid_ps:
-      with self.test_session():
-        with self.assertRaisesOpError("probs has components greater than 1"):
-          dist = bernoulli.Bernoulli(probs=p, validate_args=True)
-          self.evaluate(dist.probs)
+      with self.assertRaisesOpError("probs has components greater than 1"):
+        dist = bernoulli.Bernoulli(probs=p, validate_args=True)
+        self.evaluate(dist.probs)
 
     invalid_ps = [-0.01, -3.]
     for p in invalid_ps:
-      with self.test_session():
-        with self.assertRaisesOpError("Condition x >= 0"):
-          dist = bernoulli.Bernoulli(probs=p, validate_args=True)
-          self.evaluate(dist.probs)
+      with self.assertRaisesOpError("Condition x >= 0"):
+        dist = bernoulli.Bernoulli(probs=p, validate_args=True)
+        self.evaluate(dist.probs)
 
     valid_ps = [0.0, 0.5, 1.0]
     for p in valid_ps:
-      with self.test_session():
-        dist = bernoulli.Bernoulli(probs=p)
-        self.assertEqual(p, self.evaluate(dist.probs))  # Should not fail
+      dist = bernoulli.Bernoulli(probs=p)
+      self.assertEqual(p, self.evaluate(dist.probs))  # Should not fail
 
   @test_util.run_in_graph_and_eager_modes
   def testShapes(self):
-    with self.test_session():
-      for batch_shape in ([], [1], [2, 3, 4]):
-        dist = make_bernoulli(batch_shape)
-        self.assertAllEqual(batch_shape, dist.batch_shape.as_list())
-        self.assertAllEqual(batch_shape,
-                            self.evaluate(dist.batch_shape_tensor()))
-        self.assertAllEqual([], dist.event_shape.as_list())
-        self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
+    for batch_shape in ([], [1], [2, 3, 4]):
+      dist = make_bernoulli(batch_shape)
+      self.assertAllEqual(batch_shape, dist.batch_shape.as_list())
+      self.assertAllEqual(batch_shape, self.evaluate(dist.batch_shape_tensor()))
+      self.assertAllEqual([], dist.event_shape.as_list())
+      self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
 
   @test_util.run_in_graph_and_eager_modes
   def testDtype(self):
@@ -137,31 +128,29 @@ class BernoulliTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def _testPmf(self, **kwargs):
     dist = bernoulli.Bernoulli(**kwargs)
-    with self.test_session():
-      # pylint: disable=bad-continuation
-      xs = [
-          0,
-          [1],
-          [1, 0],
-          [[1, 0]],
-          [[1, 0], [1, 1]],
-      ]
-      expected_pmfs = [
-          [[0.8, 0.6], [0.7, 0.4]],
-          [[0.2, 0.4], [0.3, 0.6]],
-          [[0.2, 0.6], [0.3, 0.4]],
-          [[0.2, 0.6], [0.3, 0.4]],
-          [[0.2, 0.6], [0.3, 0.6]],
-      ]
-      # pylint: enable=bad-continuation
+    # pylint: disable=bad-continuation
+    xs = [
+        0,
+        [1],
+        [1, 0],
+        [[1, 0]],
+        [[1, 0], [1, 1]],
+    ]
+    expected_pmfs = [
+        [[0.8, 0.6], [0.7, 0.4]],
+        [[0.2, 0.4], [0.3, 0.6]],
+        [[0.2, 0.6], [0.3, 0.4]],
+        [[0.2, 0.6], [0.3, 0.4]],
+        [[0.2, 0.6], [0.3, 0.6]],
+    ]
+    # pylint: enable=bad-continuation
 
-      for x, expected_pmf in zip(xs, expected_pmfs):
-        self.assertAllClose(self.evaluate(dist.prob(x)), expected_pmf)
-        self.assertAllClose(
-            self.evaluate(dist.log_prob(x)), np.log(expected_pmf))
+    for x, expected_pmf in zip(xs, expected_pmfs):
+      self.assertAllClose(self.evaluate(dist.prob(x)), expected_pmf)
+      self.assertAllClose(self.evaluate(dist.log_prob(x)), np.log(expected_pmf))
 
   def testPmfCorrectBroadcastDynamicShape(self):
-    with self.test_session():
+    with self.cached_session():
       p = array_ops.placeholder(dtype=dtypes.float32)
       dist = bernoulli.Bernoulli(probs=p)
       event1 = [1, 0, 1]
@@ -178,12 +167,11 @@ class BernoulliTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testPmfInvalid(self):
     p = [0.1, 0.2, 0.7]
-    with self.test_session():
-      dist = bernoulli.Bernoulli(probs=p, validate_args=True)
-      with self.assertRaisesOpError("must be non-negative."):
-        self.evaluate(dist.prob([1, 1, -1]))
-      with self.assertRaisesOpError("Elements cannot exceed 1."):
-        self.evaluate(dist.prob([2, 0, 1]))
+    dist = bernoulli.Bernoulli(probs=p, validate_args=True)
+    with self.assertRaisesOpError("must be non-negative."):
+      self.evaluate(dist.prob([1, 1, -1]))
+    with self.assertRaisesOpError("Elements cannot exceed 1."):
+      self.evaluate(dist.prob([2, 0, 1]))
 
   @test_util.run_in_graph_and_eager_modes
   def testPmfWithP(self):
@@ -194,7 +182,7 @@ class BernoulliTest(test.TestCase):
     self._testPmf(logits=special.logit(p))
 
   def testBroadcasting(self):
-    with self.test_session():
+    with self.cached_session():
       p = array_ops.placeholder(dtypes.float32)
       dist = bernoulli.Bernoulli(probs=p)
       self.assertAllClose(np.log(0.5), dist.log_prob(1).eval({p: 0.5}))
@@ -208,70 +196,63 @@ class BernoulliTest(test.TestCase):
           }))
 
   def testPmfShapes(self):
-    with self.test_session():
+    with self.cached_session():
       p = array_ops.placeholder(dtypes.float32, shape=[None, 1])
       dist = bernoulli.Bernoulli(probs=p)
       self.assertEqual(2, len(dist.log_prob(1).eval({p: [[0.5], [0.5]]}).shape))
 
-    with self.test_session():
       dist = bernoulli.Bernoulli(probs=0.5)
       self.assertEqual(2, len(self.evaluate(dist.log_prob([[1], [1]])).shape))
 
-    with self.test_session():
       dist = bernoulli.Bernoulli(probs=0.5)
       self.assertEqual((), dist.log_prob(1).get_shape())
       self.assertEqual((1), dist.log_prob([1]).get_shape())
       self.assertEqual((2, 1), dist.log_prob([[1], [1]]).get_shape())
 
-    with self.test_session():
       dist = bernoulli.Bernoulli(probs=[[0.5], [0.5]])
       self.assertEqual((2, 1), dist.log_prob(1).get_shape())
 
   @test_util.run_in_graph_and_eager_modes
   def testBoundaryConditions(self):
-    with self.test_session():
-      dist = bernoulli.Bernoulli(probs=1.0)
-      self.assertAllClose(np.nan, self.evaluate(dist.log_prob(0)))
-      self.assertAllClose([np.nan], [self.evaluate(dist.log_prob(1))])
+    dist = bernoulli.Bernoulli(probs=1.0)
+    self.assertAllClose(np.nan, self.evaluate(dist.log_prob(0)))
+    self.assertAllClose([np.nan], [self.evaluate(dist.log_prob(1))])
 
   @test_util.run_in_graph_and_eager_modes
   def testEntropyNoBatch(self):
     p = 0.2
     dist = bernoulli.Bernoulli(probs=p)
-    with self.test_session():
-      self.assertAllClose(self.evaluate(dist.entropy()), entropy(p))
+    self.assertAllClose(self.evaluate(dist.entropy()), entropy(p))
 
   @test_util.run_in_graph_and_eager_modes
   def testEntropyWithBatch(self):
     p = [[0.1, 0.7], [0.2, 0.6]]
     dist = bernoulli.Bernoulli(probs=p, validate_args=False)
-    with self.test_session():
-      self.assertAllClose(
-          self.evaluate(dist.entropy()),
-          [[entropy(0.1), entropy(0.7)], [entropy(0.2),
-                                          entropy(0.6)]])
+    self.assertAllClose(
+        self.evaluate(dist.entropy()),
+        [[entropy(0.1), entropy(0.7)], [entropy(0.2),
+                                        entropy(0.6)]])
 
   @test_util.run_in_graph_and_eager_modes
   def testSampleN(self):
-    with self.test_session():
-      p = [0.2, 0.6]
-      dist = bernoulli.Bernoulli(probs=p)
-      n = 100000
-      samples = dist.sample(n)
-      samples.set_shape([n, 2])
-      self.assertEqual(samples.dtype, dtypes.int32)
-      sample_values = self.evaluate(samples)
-      self.assertTrue(np.all(sample_values >= 0))
-      self.assertTrue(np.all(sample_values <= 1))
-      # Note that the standard error for the sample mean is ~ sqrt(p * (1 - p) /
-      # n). This means that the tolerance is very sensitive to the value of p
-      # as well as n.
-      self.assertAllClose(p, np.mean(sample_values, axis=0), atol=1e-2)
-      self.assertEqual(set([0, 1]), set(sample_values.flatten()))
-      # In this test we're just interested in verifying there isn't a crash
-      # owing to mismatched types. b/30940152
-      dist = bernoulli.Bernoulli(np.log([.2, .4]))
-      self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list())
+    p = [0.2, 0.6]
+    dist = bernoulli.Bernoulli(probs=p)
+    n = 100000
+    samples = dist.sample(n)
+    samples.set_shape([n, 2])
+    self.assertEqual(samples.dtype, dtypes.int32)
+    sample_values = self.evaluate(samples)
+    self.assertTrue(np.all(sample_values >= 0))
+    self.assertTrue(np.all(sample_values <= 1))
+    # Note that the standard error for the sample mean is ~ sqrt(p * (1 - p) /
+    # n). This means that the tolerance is very sensitive to the value of p
+    # as well as n.
+    self.assertAllClose(p, np.mean(sample_values, axis=0), atol=1e-2)
+    self.assertEqual(set([0, 1]), set(sample_values.flatten()))
+    # In this test we're just interested in verifying there isn't a crash
+    # owing to mismatched types. b/30940152
+    dist = bernoulli.Bernoulli(np.log([.2, .4]))
+    self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list())
 
   @test_util.run_in_graph_and_eager_modes
   def testNotReparameterized(self):
@@ -284,7 +265,7 @@ class BernoulliTest(test.TestCase):
     self.assertIsNone(grad_p)
 
   def testSampleActsLikeSampleN(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       p = [0.2, 0.6]
       dist = bernoulli.Bernoulli(probs=p)
       n = 1000
@@ -299,27 +280,24 @@ class BernoulliTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testMean(self):
-    with self.test_session():
-      p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32)
-      dist = bernoulli.Bernoulli(probs=p)
-      self.assertAllEqual(self.evaluate(dist.mean()), p)
+    p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32)
+    dist = bernoulli.Bernoulli(probs=p)
+    self.assertAllEqual(self.evaluate(dist.mean()), p)
 
   @test_util.run_in_graph_and_eager_modes
   def testVarianceAndStd(self):
     var = lambda p: p * (1. - p)
-    with self.test_session():
-      p = [[0.2, 0.7], [0.5, 0.4]]
-      dist = bernoulli.Bernoulli(probs=p)
-      self.assertAllClose(
-          self.evaluate(dist.variance()),
-          np.array(
-              [[var(0.2), var(0.7)], [var(0.5), var(0.4)]], dtype=np.float32))
-      self.assertAllClose(
-          self.evaluate(dist.stddev()),
-          np.array(
-              [[np.sqrt(var(0.2)), np.sqrt(var(0.7))],
-               [np.sqrt(var(0.5)), np.sqrt(var(0.4))]],
-              dtype=np.float32))
+    p = [[0.2, 0.7], [0.5, 0.4]]
+    dist = bernoulli.Bernoulli(probs=p)
+    self.assertAllClose(
+        self.evaluate(dist.variance()),
+        np.array([[var(0.2), var(0.7)], [var(0.5), var(0.4)]],
+                 dtype=np.float32))
+    self.assertAllClose(
+        self.evaluate(dist.stddev()),
+        np.array([[np.sqrt(var(0.2)), np.sqrt(var(0.7))],
+                  [np.sqrt(var(0.5)), np.sqrt(var(0.4))]],
+                 dtype=np.float32))
 
   @test_util.run_in_graph_and_eager_modes
   def testBernoulliBernoulliKL(self):
diff --git a/tensorflow/python/kernel_tests/distributions/beta_test.py b/tensorflow/python/kernel_tests/distributions/beta_test.py
index 36f3ffc333f..d580a415dd8 100644
--- a/tensorflow/python/kernel_tests/distributions/beta_test.py
+++ b/tensorflow/python/kernel_tests/distributions/beta_test.py
@@ -20,7 +20,6 @@ import importlib
 
 import numpy as np
 
-from tensorflow.python.client import session
 from tensorflow.python.eager import backprop
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import random_seed
@@ -51,237 +50,215 @@ stats = try_import("scipy.stats")
 class BetaTest(test.TestCase):
 
   def testSimpleShapes(self):
-    with self.test_session():
-      a = np.random.rand(3)
-      b = np.random.rand(3)
-      dist = beta_lib.Beta(a, b)
-      self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
-      self.assertAllEqual([3], self.evaluate(dist.batch_shape_tensor()))
-      self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
-      self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape)
+    a = np.random.rand(3)
+    b = np.random.rand(3)
+    dist = beta_lib.Beta(a, b)
+    self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
+    self.assertAllEqual([3], self.evaluate(dist.batch_shape_tensor()))
+    self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
+    self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape)
 
   def testComplexShapes(self):
-    with self.test_session():
-      a = np.random.rand(3, 2, 2)
-      b = np.random.rand(3, 2, 2)
-      dist = beta_lib.Beta(a, b)
-      self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
-      self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
-      self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
-      self.assertEqual(
-          tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
+    a = np.random.rand(3, 2, 2)
+    b = np.random.rand(3, 2, 2)
+    dist = beta_lib.Beta(a, b)
+    self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
+    self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
+    self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
+    self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
 
   def testComplexShapesBroadcast(self):
-    with self.test_session():
-      a = np.random.rand(3, 2, 2)
-      b = np.random.rand(2, 2)
-      dist = beta_lib.Beta(a, b)
-      self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
-      self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
-      self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
-      self.assertEqual(
-          tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
+    a = np.random.rand(3, 2, 2)
+    b = np.random.rand(2, 2)
+    dist = beta_lib.Beta(a, b)
+    self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
+    self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor()))
+    self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape)
+    self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape)
 
   def testAlphaProperty(self):
     a = [[1., 2, 3]]
     b = [[2., 4, 3]]
-    with self.test_session():
-      dist = beta_lib.Beta(a, b)
-      self.assertEqual([1, 3], dist.concentration1.get_shape())
-      self.assertAllClose(a, self.evaluate(dist.concentration1))
+    dist = beta_lib.Beta(a, b)
+    self.assertEqual([1, 3], dist.concentration1.get_shape())
+    self.assertAllClose(a, self.evaluate(dist.concentration1))
 
   def testBetaProperty(self):
     a = [[1., 2, 3]]
     b = [[2., 4, 3]]
-    with self.test_session():
-      dist = beta_lib.Beta(a, b)
-      self.assertEqual([1, 3], dist.concentration0.get_shape())
-      self.assertAllClose(b, self.evaluate(dist.concentration0))
+    dist = beta_lib.Beta(a, b)
+    self.assertEqual([1, 3], dist.concentration0.get_shape())
+    self.assertAllClose(b, self.evaluate(dist.concentration0))
 
   def testPdfXProper(self):
     a = [[1., 2, 3]]
     b = [[2., 4, 3]]
-    with self.test_session():
-      dist = beta_lib.Beta(a, b, validate_args=True)
-      self.evaluate(dist.prob([.1, .3, .6]))
-      self.evaluate(dist.prob([.2, .3, .5]))
-      # Either condition can trigger.
-      with self.assertRaisesOpError("sample must be positive"):
-        self.evaluate(dist.prob([-1., 0.1, 0.5]))
-      with self.assertRaisesOpError("sample must be positive"):
-        self.evaluate(dist.prob([0., 0.1, 0.5]))
-      with self.assertRaisesOpError("sample must be less than `1`"):
-        self.evaluate(dist.prob([.1, .2, 1.2]))
-      with self.assertRaisesOpError("sample must be less than `1`"):
-        self.evaluate(dist.prob([.1, .2, 1.0]))
+    dist = beta_lib.Beta(a, b, validate_args=True)
+    self.evaluate(dist.prob([.1, .3, .6]))
+    self.evaluate(dist.prob([.2, .3, .5]))
+    # Either condition can trigger.
+    with self.assertRaisesOpError("sample must be positive"):
+      self.evaluate(dist.prob([-1., 0.1, 0.5]))
+    with self.assertRaisesOpError("sample must be positive"):
+      self.evaluate(dist.prob([0., 0.1, 0.5]))
+    with self.assertRaisesOpError("sample must be less than `1`"):
+      self.evaluate(dist.prob([.1, .2, 1.2]))
+    with self.assertRaisesOpError("sample must be less than `1`"):
+      self.evaluate(dist.prob([.1, .2, 1.0]))
 
   def testPdfTwoBatches(self):
-    with self.test_session():
-      a = [1., 2]
-      b = [1., 2]
-      x = [.5, .5]
-      dist = beta_lib.Beta(a, b)
-      pdf = dist.prob(x)
-      self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
-      self.assertEqual((2,), pdf.get_shape())
+    a = [1., 2]
+    b = [1., 2]
+    x = [.5, .5]
+    dist = beta_lib.Beta(a, b)
+    pdf = dist.prob(x)
+    self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
+    self.assertEqual((2,), pdf.get_shape())
 
   def testPdfTwoBatchesNontrivialX(self):
-    with self.test_session():
-      a = [1., 2]
-      b = [1., 2]
-      x = [.3, .7]
-      dist = beta_lib.Beta(a, b)
-      pdf = dist.prob(x)
-      self.assertAllClose([1, 63. / 50], self.evaluate(pdf))
-      self.assertEqual((2,), pdf.get_shape())
+    a = [1., 2]
+    b = [1., 2]
+    x = [.3, .7]
+    dist = beta_lib.Beta(a, b)
+    pdf = dist.prob(x)
+    self.assertAllClose([1, 63. / 50], self.evaluate(pdf))
+    self.assertEqual((2,), pdf.get_shape())
 
   def testPdfUniformZeroBatch(self):
-    with self.test_session():
-      # This is equivalent to a uniform distribution
-      a = 1.
-      b = 1.
-      x = np.array([.1, .2, .3, .5, .8], dtype=np.float32)
-      dist = beta_lib.Beta(a, b)
-      pdf = dist.prob(x)
-      self.assertAllClose([1.] * 5, self.evaluate(pdf))
-      self.assertEqual((5,), pdf.get_shape())
+    # This is equivalent to a uniform distribution
+    a = 1.
+    b = 1.
+    x = np.array([.1, .2, .3, .5, .8], dtype=np.float32)
+    dist = beta_lib.Beta(a, b)
+    pdf = dist.prob(x)
+    self.assertAllClose([1.] * 5, self.evaluate(pdf))
+    self.assertEqual((5,), pdf.get_shape())
 
   def testPdfAlphaStretchedInBroadcastWhenSameRank(self):
-    with self.test_session():
-      a = [[1., 2]]
-      b = [[1., 2]]
-      x = [[.5, .5], [.3, .7]]
-      dist = beta_lib.Beta(a, b)
-      pdf = dist.prob(x)
-      self.assertAllClose([[1., 3. / 2], [1., 63. / 50]], self.evaluate(pdf))
-      self.assertEqual((2, 2), pdf.get_shape())
+    a = [[1., 2]]
+    b = [[1., 2]]
+    x = [[.5, .5], [.3, .7]]
+    dist = beta_lib.Beta(a, b)
+    pdf = dist.prob(x)
+    self.assertAllClose([[1., 3. / 2], [1., 63. / 50]], self.evaluate(pdf))
+    self.assertEqual((2, 2), pdf.get_shape())
 
   def testPdfAlphaStretchedInBroadcastWhenLowerRank(self):
-    with self.test_session():
-      a = [1., 2]
-      b = [1., 2]
-      x = [[.5, .5], [.2, .8]]
-      pdf = beta_lib.Beta(a, b).prob(x)
-      self.assertAllClose([[1., 3. / 2], [1., 24. / 25]], self.evaluate(pdf))
-      self.assertEqual((2, 2), pdf.get_shape())
+    a = [1., 2]
+    b = [1., 2]
+    x = [[.5, .5], [.2, .8]]
+    pdf = beta_lib.Beta(a, b).prob(x)
+    self.assertAllClose([[1., 3. / 2], [1., 24. / 25]], self.evaluate(pdf))
+    self.assertEqual((2, 2), pdf.get_shape())
 
   def testPdfXStretchedInBroadcastWhenSameRank(self):
-    with self.test_session():
-      a = [[1., 2], [2., 3]]
-      b = [[1., 2], [2., 3]]
-      x = [[.5, .5]]
-      pdf = beta_lib.Beta(a, b).prob(x)
-      self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf))
-      self.assertEqual((2, 2), pdf.get_shape())
+    a = [[1., 2], [2., 3]]
+    b = [[1., 2], [2., 3]]
+    x = [[.5, .5]]
+    pdf = beta_lib.Beta(a, b).prob(x)
+    self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf))
+    self.assertEqual((2, 2), pdf.get_shape())
 
   def testPdfXStretchedInBroadcastWhenLowerRank(self):
-    with self.test_session():
-      a = [[1., 2], [2., 3]]
-      b = [[1., 2], [2., 3]]
-      x = [.5, .5]
-      pdf = beta_lib.Beta(a, b).prob(x)
-      self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf))
-      self.assertEqual((2, 2), pdf.get_shape())
+    a = [[1., 2], [2., 3]]
+    b = [[1., 2], [2., 3]]
+    x = [.5, .5]
+    pdf = beta_lib.Beta(a, b).prob(x)
+    self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf))
+    self.assertEqual((2, 2), pdf.get_shape())
 
   def testBetaMean(self):
-    with session.Session():
-      a = [1., 2, 3]
-      b = [2., 4, 1.2]
-      dist = beta_lib.Beta(a, b)
-      self.assertEqual(dist.mean().get_shape(), (3,))
-      if not stats:
-        return
-      expected_mean = stats.beta.mean(a, b)
-      self.assertAllClose(expected_mean, self.evaluate(dist.mean()))
+    a = [1., 2, 3]
+    b = [2., 4, 1.2]
+    dist = beta_lib.Beta(a, b)
+    self.assertEqual(dist.mean().get_shape(), (3,))
+    if not stats:
+      return
+    expected_mean = stats.beta.mean(a, b)
+    self.assertAllClose(expected_mean, self.evaluate(dist.mean()))
 
   def testBetaVariance(self):
-    with session.Session():
-      a = [1., 2, 3]
-      b = [2., 4, 1.2]
-      dist = beta_lib.Beta(a, b)
-      self.assertEqual(dist.variance().get_shape(), (3,))
-      if not stats:
-        return
-      expected_variance = stats.beta.var(a, b)
-      self.assertAllClose(expected_variance, self.evaluate(dist.variance()))
+    a = [1., 2, 3]
+    b = [2., 4, 1.2]
+    dist = beta_lib.Beta(a, b)
+    self.assertEqual(dist.variance().get_shape(), (3,))
+    if not stats:
+      return
+    expected_variance = stats.beta.var(a, b)
+    self.assertAllClose(expected_variance, self.evaluate(dist.variance()))
 
   def testBetaMode(self):
-    with session.Session():
-      a = np.array([1.1, 2, 3])
-      b = np.array([2., 4, 1.2])
-      expected_mode = (a - 1) / (a + b - 2)
-      dist = beta_lib.Beta(a, b)
-      self.assertEqual(dist.mode().get_shape(), (3,))
-      self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
+    a = np.array([1.1, 2, 3])
+    b = np.array([2., 4, 1.2])
+    expected_mode = (a - 1) / (a + b - 2)
+    dist = beta_lib.Beta(a, b)
+    self.assertEqual(dist.mode().get_shape(), (3,))
+    self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
 
   def testBetaModeInvalid(self):
-    with session.Session():
-      a = np.array([1., 2, 3])
-      b = np.array([2., 4, 1.2])
-      dist = beta_lib.Beta(a, b, allow_nan_stats=False)
-      with self.assertRaisesOpError("Condition x < y.*"):
-        self.evaluate(dist.mode())
+    a = np.array([1., 2, 3])
+    b = np.array([2., 4, 1.2])
+    dist = beta_lib.Beta(a, b, allow_nan_stats=False)
+    with self.assertRaisesOpError("Condition x < y.*"):
+      self.evaluate(dist.mode())
 
-      a = np.array([2., 2, 3])
-      b = np.array([1., 4, 1.2])
-      dist = beta_lib.Beta(a, b, allow_nan_stats=False)
-      with self.assertRaisesOpError("Condition x < y.*"):
-        self.evaluate(dist.mode())
+    a = np.array([2., 2, 3])
+    b = np.array([1., 4, 1.2])
+    dist = beta_lib.Beta(a, b, allow_nan_stats=False)
+    with self.assertRaisesOpError("Condition x < y.*"):
+      self.evaluate(dist.mode())
 
   def testBetaModeEnableAllowNanStats(self):
-    with session.Session():
-      a = np.array([1., 2, 3])
-      b = np.array([2., 4, 1.2])
-      dist = beta_lib.Beta(a, b, allow_nan_stats=True)
+    a = np.array([1., 2, 3])
+    b = np.array([2., 4, 1.2])
+    dist = beta_lib.Beta(a, b, allow_nan_stats=True)
 
-      expected_mode = (a - 1) / (a + b - 2)
-      expected_mode[0] = np.nan
-      self.assertEqual((3,), dist.mode().get_shape())
-      self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
+    expected_mode = (a - 1) / (a + b - 2)
+    expected_mode[0] = np.nan
+    self.assertEqual((3,), dist.mode().get_shape())
+    self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
 
-      a = np.array([2., 2, 3])
-      b = np.array([1., 4, 1.2])
-      dist = beta_lib.Beta(a, b, allow_nan_stats=True)
+    a = np.array([2., 2, 3])
+    b = np.array([1., 4, 1.2])
+    dist = beta_lib.Beta(a, b, allow_nan_stats=True)
 
-      expected_mode = (a - 1) / (a + b - 2)
-      expected_mode[0] = np.nan
-      self.assertEqual((3,), dist.mode().get_shape())
-      self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
+    expected_mode = (a - 1) / (a + b - 2)
+    expected_mode[0] = np.nan
+    self.assertEqual((3,), dist.mode().get_shape())
+    self.assertAllClose(expected_mode, self.evaluate(dist.mode()))
 
   def testBetaEntropy(self):
-    with session.Session():
-      a = [1., 2, 3]
-      b = [2., 4, 1.2]
-      dist = beta_lib.Beta(a, b)
-      self.assertEqual(dist.entropy().get_shape(), (3,))
-      if not stats:
-        return
-      expected_entropy = stats.beta.entropy(a, b)
-      self.assertAllClose(expected_entropy, self.evaluate(dist.entropy()))
+    a = [1., 2, 3]
+    b = [2., 4, 1.2]
+    dist = beta_lib.Beta(a, b)
+    self.assertEqual(dist.entropy().get_shape(), (3,))
+    if not stats:
+      return
+    expected_entropy = stats.beta.entropy(a, b)
+    self.assertAllClose(expected_entropy, self.evaluate(dist.entropy()))
 
   def testBetaSample(self):
-    with self.test_session():
-      a = 1.
-      b = 2.
-      beta = beta_lib.Beta(a, b)
-      n = constant_op.constant(100000)
-      samples = beta.sample(n)
-      sample_values = self.evaluate(samples)
-      self.assertEqual(sample_values.shape, (100000,))
-      self.assertFalse(np.any(sample_values < 0.0))
-      if not stats:
-        return
-      self.assertLess(
-          stats.kstest(
-              # Beta is a univariate distribution.
-              sample_values,
-              stats.beta(a=1., b=2.).cdf)[0],
-          0.01)
-      # The standard error of the sample mean is 1 / (sqrt(18 * n))
-      self.assertAllClose(
-          sample_values.mean(axis=0), stats.beta.mean(a, b), atol=1e-2)
-      self.assertAllClose(
-          np.cov(sample_values, rowvar=0), stats.beta.var(a, b), atol=1e-1)
+    a = 1.
+    b = 2.
+    beta = beta_lib.Beta(a, b)
+    n = constant_op.constant(100000)
+    samples = beta.sample(n)
+    sample_values = self.evaluate(samples)
+    self.assertEqual(sample_values.shape, (100000,))
+    self.assertFalse(np.any(sample_values < 0.0))
+    if not stats:
+      return
+    self.assertLess(
+        stats.kstest(
+            # Beta is a univariate distribution.
+            sample_values,
+            stats.beta(a=1., b=2.).cdf)[0],
+        0.01)
+    # The standard error of the sample mean is 1 / (sqrt(18 * n))
+    self.assertAllClose(
+        sample_values.mean(axis=0), stats.beta.mean(a, b), atol=1e-2)
+    self.assertAllClose(
+        np.cov(sample_values, rowvar=0), stats.beta.var(a, b), atol=1e-1)
 
   def testBetaFullyReparameterized(self):
     a = constant_op.constant(1.0)
@@ -297,78 +274,71 @@ class BetaTest(test.TestCase):
 
   # Test that sampling with the same seed twice gives the same results.
   def testBetaSampleMultipleTimes(self):
-    with self.test_session():
-      a_val = 1.
-      b_val = 2.
-      n_val = 100
+    a_val = 1.
+    b_val = 2.
+    n_val = 100
 
-      random_seed.set_random_seed(654321)
-      beta1 = beta_lib.Beta(concentration1=a_val,
-                            concentration0=b_val,
-                            name="beta1")
-      samples1 = self.evaluate(beta1.sample(n_val, seed=123456))
+    random_seed.set_random_seed(654321)
+    beta1 = beta_lib.Beta(
+        concentration1=a_val, concentration0=b_val, name="beta1")
+    samples1 = self.evaluate(beta1.sample(n_val, seed=123456))
 
-      random_seed.set_random_seed(654321)
-      beta2 = beta_lib.Beta(concentration1=a_val,
-                            concentration0=b_val,
-                            name="beta2")
-      samples2 = self.evaluate(beta2.sample(n_val, seed=123456))
+    random_seed.set_random_seed(654321)
+    beta2 = beta_lib.Beta(
+        concentration1=a_val, concentration0=b_val, name="beta2")
+    samples2 = self.evaluate(beta2.sample(n_val, seed=123456))
 
-      self.assertAllClose(samples1, samples2)
+    self.assertAllClose(samples1, samples2)
 
   def testBetaSampleMultidimensional(self):
-    with self.test_session():
-      a = np.random.rand(3, 2, 2).astype(np.float32)
-      b = np.random.rand(3, 2, 2).astype(np.float32)
-      beta = beta_lib.Beta(a, b)
-      n = constant_op.constant(100000)
-      samples = beta.sample(n)
-      sample_values = self.evaluate(samples)
-      self.assertEqual(sample_values.shape, (100000, 3, 2, 2))
-      self.assertFalse(np.any(sample_values < 0.0))
-      if not stats:
-        return
-      self.assertAllClose(
-          sample_values[:, 1, :].mean(axis=0),
-          stats.beta.mean(a, b)[1, :],
-          atol=1e-1)
+    a = np.random.rand(3, 2, 2).astype(np.float32)
+    b = np.random.rand(3, 2, 2).astype(np.float32)
+    beta = beta_lib.Beta(a, b)
+    n = constant_op.constant(100000)
+    samples = beta.sample(n)
+    sample_values = self.evaluate(samples)
+    self.assertEqual(sample_values.shape, (100000, 3, 2, 2))
+    self.assertFalse(np.any(sample_values < 0.0))
+    if not stats:
+      return
+    self.assertAllClose(
+        sample_values[:, 1, :].mean(axis=0),
+        stats.beta.mean(a, b)[1, :],
+        atol=1e-1)
 
   def testBetaCdf(self):
-    with self.test_session():
-      shape = (30, 40, 50)
-      for dt in (np.float32, np.float64):
-        a = 10. * np.random.random(shape).astype(dt)
-        b = 10. * np.random.random(shape).astype(dt)
-        x = np.random.random(shape).astype(dt)
-        actual = self.evaluate(beta_lib.Beta(a, b).cdf(x))
-        self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
-        self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
-        if not stats:
-          return
-        self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
+    shape = (30, 40, 50)
+    for dt in (np.float32, np.float64):
+      a = 10. * np.random.random(shape).astype(dt)
+      b = 10. * np.random.random(shape).astype(dt)
+      x = np.random.random(shape).astype(dt)
+      actual = self.evaluate(beta_lib.Beta(a, b).cdf(x))
+      self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
+      self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
+      if not stats:
+        return
+      self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
 
   def testBetaLogCdf(self):
-    with self.test_session():
-      shape = (30, 40, 50)
-      for dt in (np.float32, np.float64):
-        a = 10. * np.random.random(shape).astype(dt)
-        b = 10. * np.random.random(shape).astype(dt)
-        x = np.random.random(shape).astype(dt)
-        actual = self.evaluate(math_ops.exp(beta_lib.Beta(a, b).log_cdf(x)))
-        self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
-        self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
-        if not stats:
-          return
-        self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
+    shape = (30, 40, 50)
+    for dt in (np.float32, np.float64):
+      a = 10. * np.random.random(shape).astype(dt)
+      b = 10. * np.random.random(shape).astype(dt)
+      x = np.random.random(shape).astype(dt)
+      actual = self.evaluate(math_ops.exp(beta_lib.Beta(a, b).log_cdf(x)))
+      self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x)
+      self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x)
+      if not stats:
+        return
+      self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0)
 
   def testBetaWithSoftplusConcentration(self):
-    with self.test_session():
-      a, b = -4.2, -9.1
-      dist = beta_lib.BetaWithSoftplusConcentration(a, b)
-      self.assertAllClose(
-          self.evaluate(nn_ops.softplus(a)), self.evaluate(dist.concentration1))
-      self.assertAllClose(
-          self.evaluate(nn_ops.softplus(b)), self.evaluate(dist.concentration0))
+    a, b = -4.2, -9.1
+    dist = beta_lib.BetaWithSoftplusConcentration(a, b)
+    self.assertAllClose(
+        self.evaluate(nn_ops.softplus(a)), self.evaluate(dist.concentration1))
+    self.assertAllClose(
+        self.evaluate(nn_ops.softplus(b)), self.evaluate(dist.concentration0))
 
   def testBetaBetaKL(self):
     for shape in [(10,), (4, 5)]:
diff --git a/tensorflow/python/kernel_tests/distributions/bijector_test.py b/tensorflow/python/kernel_tests/distributions/bijector_test.py
index 8b11556330a..e20f59f48ac 100644
--- a/tensorflow/python/kernel_tests/distributions/bijector_test.py
+++ b/tensorflow/python/kernel_tests/distributions/bijector_test.py
@@ -36,11 +36,10 @@ class BaseBijectorTest(test.TestCase):
   """Tests properties of the Bijector base-class."""
 
   def testIsAbstract(self):
-    with self.test_session():
-      with self.assertRaisesRegexp(TypeError,
-                                   ("Can't instantiate abstract class Bijector "
-                                    "with abstract methods __init__")):
-        bijector.Bijector()  # pylint: disable=abstract-class-instantiated
+    with self.assertRaisesRegexp(TypeError,
+                                 ("Can't instantiate abstract class Bijector "
+                                  "with abstract methods __init__")):
+      bijector.Bijector()  # pylint: disable=abstract-class-instantiated
 
   def testDefaults(self):
     class _BareBonesBijector(bijector.Bijector):
@@ -136,7 +135,7 @@ class BijectorTestEventNdims(test.TestCase):
   def testBijectorDynamicEventNdims(self):
     bij = BrokenBijector(validate_args=True)
     event_ndims = array_ops.placeholder(dtype=np.int32, shape=None)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesOpError("Expected scalar"):
         bij.forward_log_det_jacobian(1., event_ndims=event_ndims).eval({
             event_ndims: (1, 2)})
@@ -308,7 +307,7 @@ class BijectorReduceEventDimsTest(test.TestCase):
     event_ndims = array_ops.placeholder(dtype=np.int32, shape=[])
     bij = ExpOnlyJacobian(forward_min_event_ndims=1)
     bij.inverse_log_det_jacobian(x, event_ndims=event_ndims)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       ildj = sess.run(bij.inverse_log_det_jacobian(x, event_ndims=event_ndims),
                       feed_dict={event_ndims: 1})
     self.assertAllClose(-np.log(x_), ildj)
diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
index 67ed0447ede..cace5b3ba2c 100644
--- a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
+++ b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
@@ -49,115 +49,102 @@ stats = try_import("scipy.stats")
 class DirichletTest(test.TestCase):
 
   def testSimpleShapes(self):
-    with self.test_session():
-      alpha = np.random.rand(3)
-      dist = dirichlet_lib.Dirichlet(alpha)
-      self.assertEqual(3, self.evaluate(dist.event_shape_tensor()))
-      self.assertAllEqual([], self.evaluate(dist.batch_shape_tensor()))
-      self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape)
-      self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)
+    alpha = np.random.rand(3)
+    dist = dirichlet_lib.Dirichlet(alpha)
+    self.assertEqual(3, self.evaluate(dist.event_shape_tensor()))
+    self.assertAllEqual([], self.evaluate(dist.batch_shape_tensor()))
+    self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape)
+    self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)
 
   def testComplexShapes(self):
-    with self.test_session():
-      alpha = np.random.rand(3, 2, 2)
-      dist = dirichlet_lib.Dirichlet(alpha)
-      self.assertEqual(2, self.evaluate(dist.event_shape_tensor()))
-      self.assertAllEqual([3, 2], self.evaluate(dist.batch_shape_tensor()))
-      self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape)
-      self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape)
+    alpha = np.random.rand(3, 2, 2)
+    dist = dirichlet_lib.Dirichlet(alpha)
+    self.assertEqual(2, self.evaluate(dist.event_shape_tensor()))
+    self.assertAllEqual([3, 2], self.evaluate(dist.batch_shape_tensor()))
+    self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape)
+    self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape)
 
   def testConcentrationProperty(self):
     alpha = [[1., 2, 3]]
-    with self.test_session():
-      dist = dirichlet_lib.Dirichlet(alpha)
-      self.assertEqual([1, 3], dist.concentration.get_shape())
-      self.assertAllClose(alpha, self.evaluate(dist.concentration))
+    dist = dirichlet_lib.Dirichlet(alpha)
+    self.assertEqual([1, 3], dist.concentration.get_shape())
+    self.assertAllClose(alpha, self.evaluate(dist.concentration))
 
   def testPdfXProper(self):
     alpha = [[1., 2, 3]]
-    with self.test_session():
-      dist = dirichlet_lib.Dirichlet(alpha, validate_args=True)
-      self.evaluate(dist.prob([.1, .3, .6]))
-      self.evaluate(dist.prob([.2, .3, .5]))
-      # Either condition can trigger.
-      with self.assertRaisesOpError("samples must be positive"):
-        self.evaluate(dist.prob([-1., 1.5, 0.5]))
-      with self.assertRaisesOpError("samples must be positive"):
-        self.evaluate(dist.prob([0., .1, .9]))
-      with self.assertRaisesOpError(
-          "sample last-dimension must sum to `1`"):
-        self.evaluate(dist.prob([.1, .2, .8]))
+    dist = dirichlet_lib.Dirichlet(alpha, validate_args=True)
+    self.evaluate(dist.prob([.1, .3, .6]))
+    self.evaluate(dist.prob([.2, .3, .5]))
+    # Either condition can trigger.
+    with self.assertRaisesOpError("samples must be positive"):
+      self.evaluate(dist.prob([-1., 1.5, 0.5]))
+    with self.assertRaisesOpError("samples must be positive"):
+      self.evaluate(dist.prob([0., .1, .9]))
+    with self.assertRaisesOpError("sample last-dimension must sum to `1`"):
+      self.evaluate(dist.prob([.1, .2, .8]))
 
   def testPdfZeroBatches(self):
-    with self.test_session():
-      alpha = [1., 2]
-      x = [.5, .5]
-      dist = dirichlet_lib.Dirichlet(alpha)
-      pdf = dist.prob(x)
-      self.assertAllClose(1., self.evaluate(pdf))
-      self.assertEqual((), pdf.get_shape())
+    alpha = [1., 2]
+    x = [.5, .5]
+    dist = dirichlet_lib.Dirichlet(alpha)
+    pdf = dist.prob(x)
+    self.assertAllClose(1., self.evaluate(pdf))
+    self.assertEqual((), pdf.get_shape())
 
   def testPdfZeroBatchesNontrivialX(self):
-    with self.test_session():
-      alpha = [1., 2]
-      x = [.3, .7]
-      dist = dirichlet_lib.Dirichlet(alpha)
-      pdf = dist.prob(x)
-      self.assertAllClose(7. / 5, self.evaluate(pdf))
-      self.assertEqual((), pdf.get_shape())
+    alpha = [1., 2]
+    x = [.3, .7]
+    dist = dirichlet_lib.Dirichlet(alpha)
+    pdf = dist.prob(x)
+    self.assertAllClose(7. / 5, self.evaluate(pdf))
+    self.assertEqual((), pdf.get_shape())
 
   def testPdfUniformZeroBatches(self):
-    with self.test_session():
-      # Corresponds to a uniform distribution
-      alpha = [1., 1, 1]
-      x = [[.2, .5, .3], [.3, .4, .3]]
-      dist = dirichlet_lib.Dirichlet(alpha)
-      pdf = dist.prob(x)
-      self.assertAllClose([2., 2.], self.evaluate(pdf))
-      self.assertEqual((2), pdf.get_shape())
+    # Corresponds to a uniform distribution
+    alpha = [1., 1, 1]
+    x = [[.2, .5, .3], [.3, .4, .3]]
+    dist = dirichlet_lib.Dirichlet(alpha)
+    pdf = dist.prob(x)
+    self.assertAllClose([2., 2.], self.evaluate(pdf))
+    self.assertEqual((2), pdf.get_shape())
 
   def testPdfAlphaStretchedInBroadcastWhenSameRank(self):
-    with self.test_session():
-      alpha = [[1., 2]]
-      x = [[.5, .5], [.3, .7]]
-      dist = dirichlet_lib.Dirichlet(alpha)
-      pdf = dist.prob(x)
-      self.assertAllClose([1., 7. / 5], self.evaluate(pdf))
-      self.assertEqual((2), pdf.get_shape())
+    alpha = [[1., 2]]
+    x = [[.5, .5], [.3, .7]]
+    dist = dirichlet_lib.Dirichlet(alpha)
+    pdf = dist.prob(x)
+    self.assertAllClose([1., 7. / 5], self.evaluate(pdf))
+    self.assertEqual((2), pdf.get_shape())
 
   def testPdfAlphaStretchedInBroadcastWhenLowerRank(self):
-    with self.test_session():
-      alpha = [1., 2]
-      x = [[.5, .5], [.2, .8]]
-      pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
-      self.assertAllClose([1., 8. / 5], self.evaluate(pdf))
-      self.assertEqual((2), pdf.get_shape())
+    alpha = [1., 2]
+    x = [[.5, .5], [.2, .8]]
+    pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
+    self.assertAllClose([1., 8. / 5], self.evaluate(pdf))
+    self.assertEqual((2), pdf.get_shape())
 
   def testPdfXStretchedInBroadcastWhenSameRank(self):
-    with self.test_session():
-      alpha = [[1., 2], [2., 3]]
-      x = [[.5, .5]]
-      pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
-      self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
-      self.assertEqual((2), pdf.get_shape())
+    alpha = [[1., 2], [2., 3]]
+    x = [[.5, .5]]
+    pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
+    self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
+    self.assertEqual((2), pdf.get_shape())
 
   def testPdfXStretchedInBroadcastWhenLowerRank(self):
-    with self.test_session():
-      alpha = [[1., 2], [2., 3]]
-      x = [.5, .5]
-      pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
-      self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
-      self.assertEqual((2), pdf.get_shape())
+    alpha = [[1., 2], [2., 3]]
+    x = [.5, .5]
+    pdf = dirichlet_lib.Dirichlet(alpha).prob(x)
+    self.assertAllClose([1., 3. / 2], self.evaluate(pdf))
+    self.assertEqual((2), pdf.get_shape())
 
   def testMean(self):
-    with self.test_session():
-      alpha = [1., 2, 3]
-      dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
-      self.assertEqual(dirichlet.mean().get_shape(), [3])
-      if not stats:
-        return
-      expected_mean = stats.dirichlet.mean(alpha)
-      self.assertAllClose(self.evaluate(dirichlet.mean()), expected_mean)
+    alpha = [1., 2, 3]
+    dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
+    self.assertEqual(dirichlet.mean().get_shape(), [3])
+    if not stats:
+      return
+    expected_mean = stats.dirichlet.mean(alpha)
+    self.assertAllClose(self.evaluate(dirichlet.mean()), expected_mean)
 
   def testCovarianceFromSampling(self):
     alpha = np.array([[1., 2, 3],
@@ -197,73 +184,66 @@ class DirichletTest(test.TestCase):
     self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.)
 
   def testVariance(self):
-    with self.test_session():
-      alpha = [1., 2, 3]
-      denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1)
-      dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
-      self.assertEqual(dirichlet.covariance().get_shape(), (3, 3))
-      if not stats:
-        return
-      expected_covariance = np.diag(stats.dirichlet.var(alpha))
-      expected_covariance += [[0., -2, -3], [-2, 0, -6],
-                              [-3, -6, 0]] / denominator
-      self.assertAllClose(
-          self.evaluate(dirichlet.covariance()), expected_covariance)
+    alpha = [1., 2, 3]
+    denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1)
+    dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
+    self.assertEqual(dirichlet.covariance().get_shape(), (3, 3))
+    if not stats:
+      return
+    expected_covariance = np.diag(stats.dirichlet.var(alpha))
+    expected_covariance += [[0., -2, -3], [-2, 0, -6], [-3, -6, 0]
+                           ] / denominator
+    self.assertAllClose(
+        self.evaluate(dirichlet.covariance()), expected_covariance)
 
   def testMode(self):
-    with self.test_session():
-      alpha = np.array([1.1, 2, 3])
-      expected_mode = (alpha - 1) / (np.sum(alpha) - 3)
-      dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
-      self.assertEqual(dirichlet.mode().get_shape(), [3])
-      self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)
+    alpha = np.array([1.1, 2, 3])
+    expected_mode = (alpha - 1) / (np.sum(alpha) - 3)
+    dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
+    self.assertEqual(dirichlet.mode().get_shape(), [3])
+    self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)
 
   def testModeInvalid(self):
-    with self.test_session():
-      alpha = np.array([1., 2, 3])
-      dirichlet = dirichlet_lib.Dirichlet(concentration=alpha,
-                                          allow_nan_stats=False)
-      with self.assertRaisesOpError("Condition x < y.*"):
-        self.evaluate(dirichlet.mode())
+    alpha = np.array([1., 2, 3])
+    dirichlet = dirichlet_lib.Dirichlet(
+        concentration=alpha, allow_nan_stats=False)
+    with self.assertRaisesOpError("Condition x < y.*"):
+      self.evaluate(dirichlet.mode())
 
   def testModeEnableAllowNanStats(self):
-    with self.test_session():
-      alpha = np.array([1., 2, 3])
-      dirichlet = dirichlet_lib.Dirichlet(concentration=alpha,
-                                          allow_nan_stats=True)
-      expected_mode = np.zeros_like(alpha) + np.nan
+    alpha = np.array([1., 2, 3])
+    dirichlet = dirichlet_lib.Dirichlet(
+        concentration=alpha, allow_nan_stats=True)
+    expected_mode = np.zeros_like(alpha) + np.nan
 
-      self.assertEqual(dirichlet.mode().get_shape(), [3])
-      self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)
+    self.assertEqual(dirichlet.mode().get_shape(), [3])
+    self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode)
 
   def testEntropy(self):
-    with self.test_session():
-      alpha = [1., 2, 3]
-      dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
-      self.assertEqual(dirichlet.entropy().get_shape(), ())
-      if not stats:
-        return
-      expected_entropy = stats.dirichlet.entropy(alpha)
-      self.assertAllClose(self.evaluate(dirichlet.entropy()), expected_entropy)
+    alpha = [1., 2, 3]
+    dirichlet = dirichlet_lib.Dirichlet(concentration=alpha)
+    self.assertEqual(dirichlet.entropy().get_shape(), ())
+    if not stats:
+      return
+    expected_entropy = stats.dirichlet.entropy(alpha)
+    self.assertAllClose(self.evaluate(dirichlet.entropy()), expected_entropy)
 
   def testSample(self):
-    with self.test_session():
-      alpha = [1., 2]
-      dirichlet = dirichlet_lib.Dirichlet(alpha)
-      n = constant_op.constant(100000)
-      samples = dirichlet.sample(n)
-      sample_values = self.evaluate(samples)
-      self.assertEqual(sample_values.shape, (100000, 2))
-      self.assertTrue(np.all(sample_values > 0.0))
-      if not stats:
-        return
-      self.assertLess(
-          stats.kstest(
-              # Beta is a univariate distribution.
-              sample_values[:, 0],
-              stats.beta(
-                  a=1., b=2.).cdf)[0],
-          0.01)
+    alpha = [1., 2]
+    dirichlet = dirichlet_lib.Dirichlet(alpha)
+    n = constant_op.constant(100000)
+    samples = dirichlet.sample(n)
+    sample_values = self.evaluate(samples)
+    self.assertEqual(sample_values.shape, (100000, 2))
+    self.assertTrue(np.all(sample_values > 0.0))
+    if not stats:
+      return
+    self.assertLess(
+        stats.kstest(
+            # Beta is a univariate distribution.
+            sample_values[:, 0],
+            stats.beta(a=1., b=2.).cdf)[0],
+        0.01)
 
   def testDirichletFullyReparameterized(self):
     alpha = constant_op.constant([1.0, 2.0, 3.0])
diff --git a/tensorflow/python/kernel_tests/distributions/exponential_test.py b/tensorflow/python/kernel_tests/distributions/exponential_test.py
index 850da3e9697..27d12919121 100644
--- a/tensorflow/python/kernel_tests/distributions/exponential_test.py
+++ b/tensorflow/python/kernel_tests/distributions/exponential_test.py
@@ -22,7 +22,6 @@ import importlib
 
 import numpy as np
 
-from tensorflow.python.client import session
 from tensorflow.python.eager import backprop
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import test_util
@@ -48,121 +47,108 @@ stats = try_import("scipy.stats")
 class ExponentialTest(test.TestCase):
 
   def testExponentialLogPDF(self):
-    with session.Session():
-      batch_size = 6
-      lam = constant_op.constant([2.0] * batch_size)
-      lam_v = 2.0
-      x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
-      exponential = exponential_lib.Exponential(rate=lam)
+    batch_size = 6
+    lam = constant_op.constant([2.0] * batch_size)
+    lam_v = 2.0
+    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+    exponential = exponential_lib.Exponential(rate=lam)
 
-      log_pdf = exponential.log_prob(x)
-      self.assertEqual(log_pdf.get_shape(), (6,))
+    log_pdf = exponential.log_prob(x)
+    self.assertEqual(log_pdf.get_shape(), (6,))
 
-      pdf = exponential.prob(x)
-      self.assertEqual(pdf.get_shape(), (6,))
+    pdf = exponential.prob(x)
+    self.assertEqual(pdf.get_shape(), (6,))
 
-      if not stats:
-        return
-      expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v)
-      self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
-      self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
+    if not stats:
+      return
+    expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v)
+    self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
+    self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
 
   def testExponentialCDF(self):
-    with session.Session():
-      batch_size = 6
-      lam = constant_op.constant([2.0] * batch_size)
-      lam_v = 2.0
-      x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+    batch_size = 6
+    lam = constant_op.constant([2.0] * batch_size)
+    lam_v = 2.0
+    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
 
-      exponential = exponential_lib.Exponential(rate=lam)
+    exponential = exponential_lib.Exponential(rate=lam)
 
-      cdf = exponential.cdf(x)
-      self.assertEqual(cdf.get_shape(), (6,))
+    cdf = exponential.cdf(x)
+    self.assertEqual(cdf.get_shape(), (6,))
 
-      if not stats:
-        return
-      expected_cdf = stats.expon.cdf(x, scale=1 / lam_v)
-      self.assertAllClose(self.evaluate(cdf), expected_cdf)
+    if not stats:
+      return
+    expected_cdf = stats.expon.cdf(x, scale=1 / lam_v)
+    self.assertAllClose(self.evaluate(cdf), expected_cdf)
 
   def testExponentialMean(self):
-    with session.Session():
-      lam_v = np.array([1.0, 4.0, 2.5])
-      exponential = exponential_lib.Exponential(rate=lam_v)
-      self.assertEqual(exponential.mean().get_shape(), (3,))
-      if not stats:
-        return
-      expected_mean = stats.expon.mean(scale=1 / lam_v)
-      self.assertAllClose(self.evaluate(exponential.mean()), expected_mean)
+    lam_v = np.array([1.0, 4.0, 2.5])
+    exponential = exponential_lib.Exponential(rate=lam_v)
+    self.assertEqual(exponential.mean().get_shape(), (3,))
+    if not stats:
+      return
+    expected_mean = stats.expon.mean(scale=1 / lam_v)
+    self.assertAllClose(self.evaluate(exponential.mean()), expected_mean)
 
   def testExponentialVariance(self):
-    with session.Session():
-      lam_v = np.array([1.0, 4.0, 2.5])
-      exponential = exponential_lib.Exponential(rate=lam_v)
-      self.assertEqual(exponential.variance().get_shape(), (3,))
-      if not stats:
-        return
-      expected_variance = stats.expon.var(scale=1 / lam_v)
-      self.assertAllClose(
-          self.evaluate(exponential.variance()), expected_variance)
+    lam_v = np.array([1.0, 4.0, 2.5])
+    exponential = exponential_lib.Exponential(rate=lam_v)
+    self.assertEqual(exponential.variance().get_shape(), (3,))
+    if not stats:
+      return
+    expected_variance = stats.expon.var(scale=1 / lam_v)
+    self.assertAllClose(
+        self.evaluate(exponential.variance()), expected_variance)
 
   def testExponentialEntropy(self):
-    with session.Session():
-      lam_v = np.array([1.0, 4.0, 2.5])
-      exponential = exponential_lib.Exponential(rate=lam_v)
-      self.assertEqual(exponential.entropy().get_shape(), (3,))
-      if not stats:
-        return
-      expected_entropy = stats.expon.entropy(scale=1 / lam_v)
-      self.assertAllClose(
-          self.evaluate(exponential.entropy()), expected_entropy)
+    lam_v = np.array([1.0, 4.0, 2.5])
+    exponential = exponential_lib.Exponential(rate=lam_v)
+    self.assertEqual(exponential.entropy().get_shape(), (3,))
+    if not stats:
+      return
+    expected_entropy = stats.expon.entropy(scale=1 / lam_v)
+    self.assertAllClose(self.evaluate(exponential.entropy()), expected_entropy)
 
   def testExponentialSample(self):
-    with self.test_session():
-      lam = constant_op.constant([3.0, 4.0])
-      lam_v = [3.0, 4.0]
-      n = constant_op.constant(100000)
-      exponential = exponential_lib.Exponential(rate=lam)
+    lam = constant_op.constant([3.0, 4.0])
+    lam_v = [3.0, 4.0]
+    n = constant_op.constant(100000)
+    exponential = exponential_lib.Exponential(rate=lam)
 
-      samples = exponential.sample(n, seed=137)
-      sample_values = self.evaluate(samples)
-      self.assertEqual(sample_values.shape, (100000, 2))
-      self.assertFalse(np.any(sample_values < 0.0))
-      if not stats:
-        return
-      for i in range(2):
-        self.assertLess(
-            stats.kstest(
-                sample_values[:, i], stats.expon(scale=1.0 / lam_v[i]).cdf)[0],
-            0.01)
+    samples = exponential.sample(n, seed=137)
+    sample_values = self.evaluate(samples)
+    self.assertEqual(sample_values.shape, (100000, 2))
+    self.assertFalse(np.any(sample_values < 0.0))
+    if not stats:
+      return
+    for i in range(2):
+      self.assertLess(
+          stats.kstest(sample_values[:, i],
+                       stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01)
 
   def testExponentialSampleMultiDimensional(self):
-    with self.test_session():
-      batch_size = 2
-      lam_v = [3.0, 22.0]
-      lam = constant_op.constant([lam_v] * batch_size)
+    batch_size = 2
+    lam_v = [3.0, 22.0]
+    lam = constant_op.constant([lam_v] * batch_size)
 
-      exponential = exponential_lib.Exponential(rate=lam)
+    exponential = exponential_lib.Exponential(rate=lam)
 
-      n = 100000
-      samples = exponential.sample(n, seed=138)
-      self.assertEqual(samples.get_shape(), (n, batch_size, 2))
+    n = 100000
+    samples = exponential.sample(n, seed=138)
+    self.assertEqual(samples.get_shape(), (n, batch_size, 2))
 
-      sample_values = self.evaluate(samples)
+    sample_values = self.evaluate(samples)
 
-      self.assertFalse(np.any(sample_values < 0.0))
-      if not stats:
-        return
-      for i in range(2):
-        self.assertLess(
-            stats.kstest(
-                sample_values[:, 0, i],
-                stats.expon(scale=1.0 / lam_v[i]).cdf)[0],
-            0.01)
-        self.assertLess(
-            stats.kstest(
-                sample_values[:, 1, i],
-                stats.expon(scale=1.0 / lam_v[i]).cdf)[0],
-            0.01)
+    self.assertFalse(np.any(sample_values < 0.0))
+    if not stats:
+      return
+    for i in range(2):
+      self.assertLess(
+          stats.kstest(sample_values[:, 0, i],
+                       stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01)
+      self.assertLess(
+          stats.kstest(sample_values[:, 1, i],
+                       stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01)
 
   def testFullyReparameterized(self):
     lam = constant_op.constant([0.1, 1.0])
@@ -174,11 +160,10 @@ class ExponentialTest(test.TestCase):
     self.assertIsNotNone(grad_lam)
 
   def testExponentialWithSoftplusRate(self):
-    with self.test_session():
-      lam = [-2.2, -3.4]
-      exponential = exponential_lib.ExponentialWithSoftplusRate(rate=lam)
-      self.assertAllClose(
-          self.evaluate(nn_ops.softplus(lam)), self.evaluate(exponential.rate))
+    lam = [-2.2, -3.4]
+    exponential = exponential_lib.ExponentialWithSoftplusRate(rate=lam)
+    self.assertAllClose(
+        self.evaluate(nn_ops.softplus(lam)), self.evaluate(exponential.rate))
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/distributions/gamma_test.py b/tensorflow/python/kernel_tests/distributions/gamma_test.py
index 297e20264c6..4eff40b0295 100644
--- a/tensorflow/python/kernel_tests/distributions/gamma_test.py
+++ b/tensorflow/python/kernel_tests/distributions/gamma_test.py
@@ -50,221 +50,203 @@ stats = try_import("scipy.stats")
 class GammaTest(test.TestCase):
 
   def testGammaShape(self):
-    with self.test_session():
-      alpha = constant_op.constant([3.0] * 5)
-      beta = constant_op.constant(11.0)
-      gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+    alpha = constant_op.constant([3.0] * 5)
+    beta = constant_op.constant(11.0)
+    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
 
-      self.assertEqual(self.evaluate(gamma.batch_shape_tensor()), (5,))
-      self.assertEqual(gamma.batch_shape, tensor_shape.TensorShape([5]))
-      self.assertAllEqual(self.evaluate(gamma.event_shape_tensor()), [])
-      self.assertEqual(gamma.event_shape, tensor_shape.TensorShape([]))
+    self.assertEqual(self.evaluate(gamma.batch_shape_tensor()), (5,))
+    self.assertEqual(gamma.batch_shape, tensor_shape.TensorShape([5]))
+    self.assertAllEqual(self.evaluate(gamma.event_shape_tensor()), [])
+    self.assertEqual(gamma.event_shape, tensor_shape.TensorShape([]))
 
   def testGammaLogPDF(self):
-    with self.test_session():
-      batch_size = 6
-      alpha = constant_op.constant([2.0] * batch_size)
-      beta = constant_op.constant([3.0] * batch_size)
-      alpha_v = 2.0
-      beta_v = 3.0
-      x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
-      gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
-      log_pdf = gamma.log_prob(x)
-      self.assertEqual(log_pdf.get_shape(), (6,))
-      pdf = gamma.prob(x)
-      self.assertEqual(pdf.get_shape(), (6,))
-      if not stats:
-        return
-      expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
-      self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
-      self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
+    batch_size = 6
+    alpha = constant_op.constant([2.0] * batch_size)
+    beta = constant_op.constant([3.0] * batch_size)
+    alpha_v = 2.0
+    beta_v = 3.0
+    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+    log_pdf = gamma.log_prob(x)
+    self.assertEqual(log_pdf.get_shape(), (6,))
+    pdf = gamma.prob(x)
+    self.assertEqual(pdf.get_shape(), (6,))
+    if not stats:
+      return
+    expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
+    self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
+    self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
 
   def testGammaLogPDFMultidimensional(self):
-    with self.test_session():
-      batch_size = 6
-      alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
-      beta = constant_op.constant([[3.0, 4.0]] * batch_size)
-      alpha_v = np.array([2.0, 4.0])
-      beta_v = np.array([3.0, 4.0])
-      x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
-      gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
-      log_pdf = gamma.log_prob(x)
-      log_pdf_values = self.evaluate(log_pdf)
-      self.assertEqual(log_pdf.get_shape(), (6, 2))
-      pdf = gamma.prob(x)
-      pdf_values = self.evaluate(pdf)
-      self.assertEqual(pdf.get_shape(), (6, 2))
-      if not stats:
-        return
-      expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
-      self.assertAllClose(log_pdf_values, expected_log_pdf)
-      self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
+    batch_size = 6
+    alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
+    beta = constant_op.constant([[3.0, 4.0]] * batch_size)
+    alpha_v = np.array([2.0, 4.0])
+    beta_v = np.array([3.0, 4.0])
+    x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
+    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+    log_pdf = gamma.log_prob(x)
+    log_pdf_values = self.evaluate(log_pdf)
+    self.assertEqual(log_pdf.get_shape(), (6, 2))
+    pdf = gamma.prob(x)
+    pdf_values = self.evaluate(pdf)
+    self.assertEqual(pdf.get_shape(), (6, 2))
+    if not stats:
+      return
+    expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
+    self.assertAllClose(log_pdf_values, expected_log_pdf)
+    self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
 
   def testGammaLogPDFMultidimensionalBroadcasting(self):
-    with self.test_session():
-      batch_size = 6
-      alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
-      beta = constant_op.constant(3.0)
-      alpha_v = np.array([2.0, 4.0])
-      beta_v = 3.0
-      x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
-      gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
-      log_pdf = gamma.log_prob(x)
-      log_pdf_values = self.evaluate(log_pdf)
-      self.assertEqual(log_pdf.get_shape(), (6, 2))
-      pdf = gamma.prob(x)
-      pdf_values = self.evaluate(pdf)
-      self.assertEqual(pdf.get_shape(), (6, 2))
+    batch_size = 6
+    alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
+    beta = constant_op.constant(3.0)
+    alpha_v = np.array([2.0, 4.0])
+    beta_v = 3.0
+    x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
+    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+    log_pdf = gamma.log_prob(x)
+    log_pdf_values = self.evaluate(log_pdf)
+    self.assertEqual(log_pdf.get_shape(), (6, 2))
+    pdf = gamma.prob(x)
+    pdf_values = self.evaluate(pdf)
+    self.assertEqual(pdf.get_shape(), (6, 2))
 
-      if not stats:
-        return
-      expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
-      self.assertAllClose(log_pdf_values, expected_log_pdf)
-      self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
+    if not stats:
+      return
+    expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
+    self.assertAllClose(log_pdf_values, expected_log_pdf)
+    self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
 
   def testGammaCDF(self):
-    with self.test_session():
-      batch_size = 6
-      alpha = constant_op.constant([2.0] * batch_size)
-      beta = constant_op.constant([3.0] * batch_size)
-      alpha_v = 2.0
-      beta_v = 3.0
-      x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+    batch_size = 6
+    alpha = constant_op.constant([2.0] * batch_size)
+    beta = constant_op.constant([3.0] * batch_size)
+    alpha_v = 2.0
+    beta_v = 3.0
+    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
 
-      gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
-      cdf = gamma.cdf(x)
-      self.assertEqual(cdf.get_shape(), (6,))
-      if not stats:
-        return
-      expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v)
-      self.assertAllClose(self.evaluate(cdf), expected_cdf)
+    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+    cdf = gamma.cdf(x)
+    self.assertEqual(cdf.get_shape(), (6,))
+    if not stats:
+      return
+    expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v)
+    self.assertAllClose(self.evaluate(cdf), expected_cdf)
 
   def testGammaMean(self):
-    with self.test_session():
-      alpha_v = np.array([1.0, 3.0, 2.5])
-      beta_v = np.array([1.0, 4.0, 5.0])
-      gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
-      self.assertEqual(gamma.mean().get_shape(), (3,))
-      if not stats:
-        return
-      expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v)
-      self.assertAllClose(self.evaluate(gamma.mean()), expected_means)
+    alpha_v = np.array([1.0, 3.0, 2.5])
+    beta_v = np.array([1.0, 4.0, 5.0])
+    gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
+    self.assertEqual(gamma.mean().get_shape(), (3,))
+    if not stats:
+      return
+    expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v)
+    self.assertAllClose(self.evaluate(gamma.mean()), expected_means)
 
   def testGammaModeAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self):
-    with self.test_session():
-      alpha_v = np.array([5.5, 3.0, 2.5])
-      beta_v = np.array([1.0, 4.0, 5.0])
-      gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
-      expected_modes = (alpha_v - 1) / beta_v
-      self.assertEqual(gamma.mode().get_shape(), (3,))
-      self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)
+    alpha_v = np.array([5.5, 3.0, 2.5])
+    beta_v = np.array([1.0, 4.0, 5.0])
+    gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
+    expected_modes = (alpha_v - 1) / beta_v
+    self.assertEqual(gamma.mode().get_shape(), (3,))
+    self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)
 
   def testGammaModeAllowNanStatsFalseRaisesForUndefinedBatchMembers(self):
-    with self.test_session():
-      # Mode will not be defined for the first entry.
-      alpha_v = np.array([0.5, 3.0, 2.5])
-      beta_v = np.array([1.0, 4.0, 5.0])
-      gamma = gamma_lib.Gamma(concentration=alpha_v,
-                              rate=beta_v,
-                              allow_nan_stats=False)
-      with self.assertRaisesOpError("x < y"):
-        self.evaluate(gamma.mode())
+    # Mode will not be defined for the first entry.
+    alpha_v = np.array([0.5, 3.0, 2.5])
+    beta_v = np.array([1.0, 4.0, 5.0])
+    gamma = gamma_lib.Gamma(
+        concentration=alpha_v, rate=beta_v, allow_nan_stats=False)
+    with self.assertRaisesOpError("x < y"):
+      self.evaluate(gamma.mode())
 
   def testGammaModeAllowNanStatsIsTrueReturnsNaNforUndefinedBatchMembers(self):
-    with self.test_session():
-      # Mode will not be defined for the first entry.
-      alpha_v = np.array([0.5, 3.0, 2.5])
-      beta_v = np.array([1.0, 4.0, 5.0])
-      gamma = gamma_lib.Gamma(concentration=alpha_v,
-                              rate=beta_v,
-                              allow_nan_stats=True)
-      expected_modes = (alpha_v - 1) / beta_v
-      expected_modes[0] = np.nan
-      self.assertEqual(gamma.mode().get_shape(), (3,))
-      self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)
+    # Mode will not be defined for the first entry.
+    alpha_v = np.array([0.5, 3.0, 2.5])
+    beta_v = np.array([1.0, 4.0, 5.0])
+    gamma = gamma_lib.Gamma(
+        concentration=alpha_v, rate=beta_v, allow_nan_stats=True)
+    expected_modes = (alpha_v - 1) / beta_v
+    expected_modes[0] = np.nan
+    self.assertEqual(gamma.mode().get_shape(), (3,))
+    self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)
 
   def testGammaVariance(self):
-    with self.test_session():
-      alpha_v = np.array([1.0, 3.0, 2.5])
-      beta_v = np.array([1.0, 4.0, 5.0])
-      gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
-      self.assertEqual(gamma.variance().get_shape(), (3,))
-      if not stats:
-        return
-      expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v)
-      self.assertAllClose(self.evaluate(gamma.variance()), expected_variances)
+    alpha_v = np.array([1.0, 3.0, 2.5])
+    beta_v = np.array([1.0, 4.0, 5.0])
+    gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
+    self.assertEqual(gamma.variance().get_shape(), (3,))
+    if not stats:
+      return
+    expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v)
+    self.assertAllClose(self.evaluate(gamma.variance()), expected_variances)
 
   def testGammaStd(self):
-    with self.test_session():
-      alpha_v = np.array([1.0, 3.0, 2.5])
-      beta_v = np.array([1.0, 4.0, 5.0])
-      gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
-      self.assertEqual(gamma.stddev().get_shape(), (3,))
-      if not stats:
-        return
-      expected_stddev = stats.gamma.std(alpha_v, scale=1. / beta_v)
-      self.assertAllClose(self.evaluate(gamma.stddev()), expected_stddev)
+    alpha_v = np.array([1.0, 3.0, 2.5])
+    beta_v = np.array([1.0, 4.0, 5.0])
+    gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
+    self.assertEqual(gamma.stddev().get_shape(), (3,))
+    if not stats:
+      return
+    expected_stddev = stats.gamma.std(alpha_v, scale=1. / beta_v)
+    self.assertAllClose(self.evaluate(gamma.stddev()), expected_stddev)
 
   def testGammaEntropy(self):
-    with self.test_session():
-      alpha_v = np.array([1.0, 3.0, 2.5])
-      beta_v = np.array([1.0, 4.0, 5.0])
-      gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
-      self.assertEqual(gamma.entropy().get_shape(), (3,))
-      if not stats:
-        return
-      expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v)
-      self.assertAllClose(self.evaluate(gamma.entropy()), expected_entropy)
+    alpha_v = np.array([1.0, 3.0, 2.5])
+    beta_v = np.array([1.0, 4.0, 5.0])
+    gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
+    self.assertEqual(gamma.entropy().get_shape(), (3,))
+    if not stats:
+      return
+    expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v)
+    self.assertAllClose(self.evaluate(gamma.entropy()), expected_entropy)
 
   def testGammaSampleSmallAlpha(self):
-    with self.test_session():
-      alpha_v = 0.05
-      beta_v = 1.0
-      alpha = constant_op.constant(alpha_v)
-      beta = constant_op.constant(beta_v)
-      n = 100000
-      gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
-      samples = gamma.sample(n, seed=137)
-      sample_values = self.evaluate(samples)
-      self.assertEqual(samples.get_shape(), (n,))
-      self.assertEqual(sample_values.shape, (n,))
-      self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
-      if not stats:
-        return
-      self.assertAllClose(
-          sample_values.mean(),
-          stats.gamma.mean(
-              alpha_v, scale=1 / beta_v),
-          atol=.01)
-      self.assertAllClose(
-          sample_values.var(),
-          stats.gamma.var(alpha_v, scale=1 / beta_v),
-          atol=.15)
+    alpha_v = 0.05
+    beta_v = 1.0
+    alpha = constant_op.constant(alpha_v)
+    beta = constant_op.constant(beta_v)
+    n = 100000
+    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+    samples = gamma.sample(n, seed=137)
+    sample_values = self.evaluate(samples)
+    self.assertEqual(samples.get_shape(), (n,))
+    self.assertEqual(sample_values.shape, (n,))
+    self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
+    if not stats:
+      return
+    self.assertAllClose(
+        sample_values.mean(),
+        stats.gamma.mean(alpha_v, scale=1 / beta_v),
+        atol=.01)
+    self.assertAllClose(
+        sample_values.var(),
+        stats.gamma.var(alpha_v, scale=1 / beta_v),
+        atol=.15)
 
   def testGammaSample(self):
-    with self.test_session():
-      alpha_v = 4.0
-      beta_v = 3.0
-      alpha = constant_op.constant(alpha_v)
-      beta = constant_op.constant(beta_v)
-      n = 100000
-      gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
-      samples = gamma.sample(n, seed=137)
-      sample_values = self.evaluate(samples)
-      self.assertEqual(samples.get_shape(), (n,))
-      self.assertEqual(sample_values.shape, (n,))
-      self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
-      if not stats:
-        return
-      self.assertAllClose(
-          sample_values.mean(),
-          stats.gamma.mean(
-              alpha_v, scale=1 / beta_v),
-          atol=.01)
-      self.assertAllClose(
-          sample_values.var(),
-          stats.gamma.var(alpha_v, scale=1 / beta_v),
-          atol=.15)
+    alpha_v = 4.0
+    beta_v = 3.0
+    alpha = constant_op.constant(alpha_v)
+    beta = constant_op.constant(beta_v)
+    n = 100000
+    gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+    samples = gamma.sample(n, seed=137)
+    sample_values = self.evaluate(samples)
+    self.assertEqual(samples.get_shape(), (n,))
+    self.assertEqual(sample_values.shape, (n,))
+    self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
+    if not stats:
+      return
+    self.assertAllClose(
+        sample_values.mean(),
+        stats.gamma.mean(alpha_v, scale=1 / beta_v),
+        atol=.01)
+    self.assertAllClose(
+        sample_values.var(),
+        stats.gamma.var(alpha_v, scale=1 / beta_v),
+        atol=.15)
 
   def testGammaFullyReparameterized(self):
     alpha = constant_op.constant(4.0)
@@ -279,37 +261,37 @@ class GammaTest(test.TestCase):
     self.assertIsNotNone(grad_beta)
 
   def testGammaSampleMultiDimensional(self):
-    with self.test_session():
-      alpha_v = np.array([np.arange(1, 101, dtype=np.float32)])  # 1 x 100
-      beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T  # 10 x 1
-      gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
-      n = 10000
-      samples = gamma.sample(n, seed=137)
-      sample_values = self.evaluate(samples)
-      self.assertEqual(samples.get_shape(), (n, 10, 100))
-      self.assertEqual(sample_values.shape, (n, 10, 100))
-      zeros = np.zeros_like(alpha_v + beta_v)  # 10 x 100
-      alpha_bc = alpha_v + zeros
-      beta_bc = beta_v + zeros
-      if not stats:
-        return
-      self.assertAllClose(
-          sample_values.mean(axis=0),
-          stats.gamma.mean(
-              alpha_bc, scale=1 / beta_bc),
-          atol=0., rtol=.05)
-      self.assertAllClose(
-          sample_values.var(axis=0),
-          stats.gamma.var(alpha_bc, scale=1 / beta_bc),
-          atol=10.0, rtol=0.)
-      fails = 0
-      trials = 0
-      for ai, a in enumerate(np.reshape(alpha_v, [-1])):
-        for bi, b in enumerate(np.reshape(beta_v, [-1])):
-          s = sample_values[:, bi, ai]
-          trials += 1
-          fails += 0 if self._kstest(a, b, s) else 1
-      self.assertLess(fails, trials * 0.03)
+    alpha_v = np.array([np.arange(1, 101, dtype=np.float32)])  # 1 x 100
+    beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T  # 10 x 1
+    gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
+    n = 10000
+    samples = gamma.sample(n, seed=137)
+    sample_values = self.evaluate(samples)
+    self.assertEqual(samples.get_shape(), (n, 10, 100))
+    self.assertEqual(sample_values.shape, (n, 10, 100))
+    zeros = np.zeros_like(alpha_v + beta_v)  # 10 x 100
+    alpha_bc = alpha_v + zeros
+    beta_bc = beta_v + zeros
+    if not stats:
+      return
+    self.assertAllClose(
+        sample_values.mean(axis=0),
+        stats.gamma.mean(alpha_bc, scale=1 / beta_bc),
+        atol=0.,
+        rtol=.05)
+    self.assertAllClose(
+        sample_values.var(axis=0),
+        stats.gamma.var(alpha_bc, scale=1 / beta_bc),
+        atol=10.0,
+        rtol=0.)
+    fails = 0
+    trials = 0
+    for ai, a in enumerate(np.reshape(alpha_v, [-1])):
+      for bi, b in enumerate(np.reshape(beta_v, [-1])):
+        s = sample_values[:, bi, ai]
+        trials += 1
+        fails += 0 if self._kstest(a, b, s) else 1
+    self.assertLess(fails, trials * 0.03)
 
   def _kstest(self, alpha, beta, samples):
     # Uses the Kolmogorov-Smirnov test for goodness of fit.
@@ -320,30 +302,29 @@ class GammaTest(test.TestCase):
     return ks < 0.02
 
   def testGammaPdfOfSampleMultiDims(self):
-    with self.test_session():
-      gamma = gamma_lib.Gamma(concentration=[7., 11.], rate=[[5.], [6.]])
-      num = 50000
-      samples = gamma.sample(num, seed=137)
-      pdfs = gamma.prob(samples)
-      sample_vals, pdf_vals = self.evaluate([samples, pdfs])
-      self.assertEqual(samples.get_shape(), (num, 2, 2))
-      self.assertEqual(pdfs.get_shape(), (num, 2, 2))
-      self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
-      self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
-      self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
-      self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)
-      if not stats:
-        return
-      self.assertAllClose(
-          stats.gamma.mean(
-              [[7., 11.], [7., 11.]], scale=1 / np.array([[5., 5.], [6., 6.]])),
-          sample_vals.mean(axis=0),
-          atol=.1)
-      self.assertAllClose(
-          stats.gamma.var([[7., 11.], [7., 11.]],
-                          scale=1 / np.array([[5., 5.], [6., 6.]])),
-          sample_vals.var(axis=0),
-          atol=.1)
+    gamma = gamma_lib.Gamma(concentration=[7., 11.], rate=[[5.], [6.]])
+    num = 50000
+    samples = gamma.sample(num, seed=137)
+    pdfs = gamma.prob(samples)
+    sample_vals, pdf_vals = self.evaluate([samples, pdfs])
+    self.assertEqual(samples.get_shape(), (num, 2, 2))
+    self.assertEqual(pdfs.get_shape(), (num, 2, 2))
+    self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
+    self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
+    self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
+    self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)
+    if not stats:
+      return
+    self.assertAllClose(
+        stats.gamma.mean([[7., 11.], [7., 11.]],
+                         scale=1 / np.array([[5., 5.], [6., 6.]])),
+        sample_vals.mean(axis=0),
+        atol=.1)
+    self.assertAllClose(
+        stats.gamma.var([[7., 11.], [7., 11.]],
+                        scale=1 / np.array([[5., 5.], [6., 6.]])),
+        sample_vals.var(axis=0),
+        atol=.1)
 
   def _assertIntegral(self, sample_vals, pdf_vals, err=1e-3):
     s_p = zip(sample_vals, pdf_vals)
@@ -356,32 +337,29 @@ class GammaTest(test.TestCase):
     self.assertNear(1., total, err=err)
 
   def testGammaNonPositiveInitializationParamsRaises(self):
-    with self.test_session():
-      alpha_v = constant_op.constant(0.0, name="alpha")
-      beta_v = constant_op.constant(1.0, name="beta")
-      with self.assertRaisesOpError("x > 0"):
-        gamma = gamma_lib.Gamma(concentration=alpha_v,
-                                rate=beta_v,
-                                validate_args=True)
-        self.evaluate(gamma.mean())
-      alpha_v = constant_op.constant(1.0, name="alpha")
-      beta_v = constant_op.constant(0.0, name="beta")
-      with self.assertRaisesOpError("x > 0"):
-        gamma = gamma_lib.Gamma(concentration=alpha_v,
-                                rate=beta_v,
-                                validate_args=True)
-        self.evaluate(gamma.mean())
+    alpha_v = constant_op.constant(0.0, name="alpha")
+    beta_v = constant_op.constant(1.0, name="beta")
+    with self.assertRaisesOpError("x > 0"):
+      gamma = gamma_lib.Gamma(
+          concentration=alpha_v, rate=beta_v, validate_args=True)
+      self.evaluate(gamma.mean())
+    alpha_v = constant_op.constant(1.0, name="alpha")
+    beta_v = constant_op.constant(0.0, name="beta")
+    with self.assertRaisesOpError("x > 0"):
+      gamma = gamma_lib.Gamma(
+          concentration=alpha_v, rate=beta_v, validate_args=True)
+      self.evaluate(gamma.mean())
 
   def testGammaWithSoftplusConcentrationRate(self):
-    with self.test_session():
-      alpha_v = constant_op.constant([0.0, -2.1], name="alpha")
-      beta_v = constant_op.constant([1.0, -3.6], name="beta")
-      gamma = gamma_lib.GammaWithSoftplusConcentrationRate(
-          concentration=alpha_v, rate=beta_v)
-      self.assertAllEqual(self.evaluate(nn_ops.softplus(alpha_v)),
-                          self.evaluate(gamma.concentration))
-      self.assertAllEqual(self.evaluate(nn_ops.softplus(beta_v)),
-                          self.evaluate(gamma.rate))
+    alpha_v = constant_op.constant([0.0, -2.1], name="alpha")
+    beta_v = constant_op.constant([1.0, -3.6], name="beta")
+    gamma = gamma_lib.GammaWithSoftplusConcentrationRate(
+        concentration=alpha_v, rate=beta_v)
+    self.assertAllEqual(
+        self.evaluate(nn_ops.softplus(alpha_v)),
+        self.evaluate(gamma.concentration))
+    self.assertAllEqual(
+        self.evaluate(nn_ops.softplus(beta_v)), self.evaluate(gamma.rate))
 
   def testGammaGammaKL(self):
     alpha0 = np.array([3.])
@@ -391,15 +369,14 @@ class GammaTest(test.TestCase):
     beta1 = np.array([0.5, 1., 1.5, 2., 2.5, 3.])
 
     # Build graph.
-    with self.test_session():
-      g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0)
-      g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1)
-      x = g0.sample(int(1e4), seed=0)
-      kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0)
-      kl_actual = kullback_leibler.kl_divergence(g0, g1)
+    g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0)
+    g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1)
+    x = g0.sample(int(1e4), seed=0)
+    kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0)
+    kl_actual = kullback_leibler.kl_divergence(g0, g1)
 
-      # Execute graph.
-      [kl_sample_, kl_actual_] = self.evaluate([kl_sample, kl_actual])
+    # Execute graph.
+    [kl_sample_, kl_actual_] = self.evaluate([kl_sample, kl_actual])
 
     self.assertEqual(beta0.shape, kl_actual.get_shape())
 
diff --git a/tensorflow/python/kernel_tests/distributions/laplace_test.py b/tensorflow/python/kernel_tests/distributions/laplace_test.py
index 24b243f647e..630c2cb4241 100644
--- a/tensorflow/python/kernel_tests/distributions/laplace_test.py
+++ b/tensorflow/python/kernel_tests/distributions/laplace_test.py
@@ -21,7 +21,6 @@ import importlib
 
 import numpy as np
 
-from tensorflow.python.client import session
 from tensorflow.python.eager import backprop
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import tensor_shape
@@ -49,212 +48,198 @@ stats = try_import("scipy.stats")
 class LaplaceTest(test.TestCase):
 
   def testLaplaceShape(self):
-    with self.test_session():
-      loc = constant_op.constant([3.0] * 5)
-      scale = constant_op.constant(11.0)
-      laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+    loc = constant_op.constant([3.0] * 5)
+    scale = constant_op.constant(11.0)
+    laplace = laplace_lib.Laplace(loc=loc, scale=scale)
 
-      self.assertEqual(self.evaluate(laplace.batch_shape_tensor()), (5,))
-      self.assertEqual(laplace.batch_shape, tensor_shape.TensorShape([5]))
-      self.assertAllEqual(self.evaluate(laplace.event_shape_tensor()), [])
-      self.assertEqual(laplace.event_shape, tensor_shape.TensorShape([]))
+    self.assertEqual(self.evaluate(laplace.batch_shape_tensor()), (5,))
+    self.assertEqual(laplace.batch_shape, tensor_shape.TensorShape([5]))
+    self.assertAllEqual(self.evaluate(laplace.event_shape_tensor()), [])
+    self.assertEqual(laplace.event_shape, tensor_shape.TensorShape([]))
 
   def testLaplaceLogPDF(self):
-    with self.test_session():
-      batch_size = 6
-      loc = constant_op.constant([2.0] * batch_size)
-      scale = constant_op.constant([3.0] * batch_size)
-      loc_v = 2.0
-      scale_v = 3.0
-      x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
-      laplace = laplace_lib.Laplace(loc=loc, scale=scale)
-      log_pdf = laplace.log_prob(x)
-      self.assertEqual(log_pdf.get_shape(), (6,))
-      if not stats:
-        return
-      expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
-      self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
+    batch_size = 6
+    loc = constant_op.constant([2.0] * batch_size)
+    scale = constant_op.constant([3.0] * batch_size)
+    loc_v = 2.0
+    scale_v = 3.0
+    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+    laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+    log_pdf = laplace.log_prob(x)
+    self.assertEqual(log_pdf.get_shape(), (6,))
+    if not stats:
+      return
+    expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
+    self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
 
-      pdf = laplace.prob(x)
-      self.assertEqual(pdf.get_shape(), (6,))
-      self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
+    pdf = laplace.prob(x)
+    self.assertEqual(pdf.get_shape(), (6,))
+    self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
 
   def testLaplaceLogPDFMultidimensional(self):
-    with self.test_session():
-      batch_size = 6
-      loc = constant_op.constant([[2.0, 4.0]] * batch_size)
-      scale = constant_op.constant([[3.0, 4.0]] * batch_size)
-      loc_v = np.array([2.0, 4.0])
-      scale_v = np.array([3.0, 4.0])
-      x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
-      laplace = laplace_lib.Laplace(loc=loc, scale=scale)
-      log_pdf = laplace.log_prob(x)
-      log_pdf_values = self.evaluate(log_pdf)
-      self.assertEqual(log_pdf.get_shape(), (6, 2))
+    batch_size = 6
+    loc = constant_op.constant([[2.0, 4.0]] * batch_size)
+    scale = constant_op.constant([[3.0, 4.0]] * batch_size)
+    loc_v = np.array([2.0, 4.0])
+    scale_v = np.array([3.0, 4.0])
+    x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
+    laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+    log_pdf = laplace.log_prob(x)
+    log_pdf_values = self.evaluate(log_pdf)
+    self.assertEqual(log_pdf.get_shape(), (6, 2))
 
-      pdf = laplace.prob(x)
-      pdf_values = self.evaluate(pdf)
-      self.assertEqual(pdf.get_shape(), (6, 2))
-      if not stats:
-        return
-      expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
-      self.assertAllClose(log_pdf_values, expected_log_pdf)
-      self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
+    pdf = laplace.prob(x)
+    pdf_values = self.evaluate(pdf)
+    self.assertEqual(pdf.get_shape(), (6, 2))
+    if not stats:
+      return
+    expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
+    self.assertAllClose(log_pdf_values, expected_log_pdf)
+    self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
 
   def testLaplaceLogPDFMultidimensionalBroadcasting(self):
-    with self.test_session():
-      batch_size = 6
-      loc = constant_op.constant([[2.0, 4.0]] * batch_size)
-      scale = constant_op.constant(3.0)
-      loc_v = np.array([2.0, 4.0])
-      scale_v = 3.0
-      x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
-      laplace = laplace_lib.Laplace(loc=loc, scale=scale)
-      log_pdf = laplace.log_prob(x)
-      log_pdf_values = self.evaluate(log_pdf)
-      self.assertEqual(log_pdf.get_shape(), (6, 2))
+    batch_size = 6
+    loc = constant_op.constant([[2.0, 4.0]] * batch_size)
+    scale = constant_op.constant(3.0)
+    loc_v = np.array([2.0, 4.0])
+    scale_v = 3.0
+    x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
+    laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+    log_pdf = laplace.log_prob(x)
+    log_pdf_values = self.evaluate(log_pdf)
+    self.assertEqual(log_pdf.get_shape(), (6, 2))
 
-      pdf = laplace.prob(x)
-      pdf_values = self.evaluate(pdf)
-      self.assertEqual(pdf.get_shape(), (6, 2))
-      if not stats:
-        return
-      expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
-      self.assertAllClose(log_pdf_values, expected_log_pdf)
-      self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
+    pdf = laplace.prob(x)
+    pdf_values = self.evaluate(pdf)
+    self.assertEqual(pdf.get_shape(), (6, 2))
+    if not stats:
+      return
+    expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v)
+    self.assertAllClose(log_pdf_values, expected_log_pdf)
+    self.assertAllClose(pdf_values, np.exp(expected_log_pdf))
 
   def testLaplaceCDF(self):
-    with self.test_session():
-      batch_size = 6
-      loc = constant_op.constant([2.0] * batch_size)
-      scale = constant_op.constant([3.0] * batch_size)
-      loc_v = 2.0
-      scale_v = 3.0
-      x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+    batch_size = 6
+    loc = constant_op.constant([2.0] * batch_size)
+    scale = constant_op.constant([3.0] * batch_size)
+    loc_v = 2.0
+    scale_v = 3.0
+    x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
 
-      laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+    laplace = laplace_lib.Laplace(loc=loc, scale=scale)
 
-      cdf = laplace.cdf(x)
-      self.assertEqual(cdf.get_shape(), (6,))
-      if not stats:
-        return
-      expected_cdf = stats.laplace.cdf(x, loc_v, scale=scale_v)
-      self.assertAllClose(self.evaluate(cdf), expected_cdf)
+    cdf = laplace.cdf(x)
+    self.assertEqual(cdf.get_shape(), (6,))
+    if not stats:
+      return
+    expected_cdf = stats.laplace.cdf(x, loc_v, scale=scale_v)
+    self.assertAllClose(self.evaluate(cdf), expected_cdf)
 
   def testLaplaceLogCDF(self):
-    with self.test_session():
-      batch_size = 6
-      loc = constant_op.constant([2.0] * batch_size)
-      scale = constant_op.constant([3.0] * batch_size)
-      loc_v = 2.0
-      scale_v = 3.0
-      x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+    batch_size = 6
+    loc = constant_op.constant([2.0] * batch_size)
+    scale = constant_op.constant([3.0] * batch_size)
+    loc_v = 2.0
+    scale_v = 3.0
+    x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32)
 
-      laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+    laplace = laplace_lib.Laplace(loc=loc, scale=scale)
 
-      cdf = laplace.log_cdf(x)
-      self.assertEqual(cdf.get_shape(), (6,))
-      if not stats:
-        return
-      expected_cdf = stats.laplace.logcdf(x, loc_v, scale=scale_v)
-      self.assertAllClose(self.evaluate(cdf), expected_cdf)
+    cdf = laplace.log_cdf(x)
+    self.assertEqual(cdf.get_shape(), (6,))
+    if not stats:
+      return
+    expected_cdf = stats.laplace.logcdf(x, loc_v, scale=scale_v)
+    self.assertAllClose(self.evaluate(cdf), expected_cdf)
 
   def testLaplaceLogSurvivalFunction(self):
-    with self.test_session():
-      batch_size = 6
-      loc = constant_op.constant([2.0] * batch_size)
-      scale = constant_op.constant([3.0] * batch_size)
-      loc_v = 2.0
-      scale_v = 3.0
-      x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32)
+    batch_size = 6
+    loc = constant_op.constant([2.0] * batch_size)
+    scale = constant_op.constant([3.0] * batch_size)
+    loc_v = 2.0
+    scale_v = 3.0
+    x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32)
 
-      laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+    laplace = laplace_lib.Laplace(loc=loc, scale=scale)
 
-      sf = laplace.log_survival_function(x)
-      self.assertEqual(sf.get_shape(), (6,))
-      if not stats:
-        return
-      expected_sf = stats.laplace.logsf(x, loc_v, scale=scale_v)
-      self.assertAllClose(self.evaluate(sf), expected_sf)
+    sf = laplace.log_survival_function(x)
+    self.assertEqual(sf.get_shape(), (6,))
+    if not stats:
+      return
+    expected_sf = stats.laplace.logsf(x, loc_v, scale=scale_v)
+    self.assertAllClose(self.evaluate(sf), expected_sf)
 
   def testLaplaceMean(self):
-    with self.test_session():
-      loc_v = np.array([1.0, 3.0, 2.5])
-      scale_v = np.array([1.0, 4.0, 5.0])
-      laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
-      self.assertEqual(laplace.mean().get_shape(), (3,))
-      if not stats:
-        return
-      expected_means = stats.laplace.mean(loc_v, scale=scale_v)
-      self.assertAllClose(self.evaluate(laplace.mean()), expected_means)
+    loc_v = np.array([1.0, 3.0, 2.5])
+    scale_v = np.array([1.0, 4.0, 5.0])
+    laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
+    self.assertEqual(laplace.mean().get_shape(), (3,))
+    if not stats:
+      return
+    expected_means = stats.laplace.mean(loc_v, scale=scale_v)
+    self.assertAllClose(self.evaluate(laplace.mean()), expected_means)
 
   def testLaplaceMode(self):
-    with self.test_session():
-      loc_v = np.array([0.5, 3.0, 2.5])
-      scale_v = np.array([1.0, 4.0, 5.0])
-      laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
-      self.assertEqual(laplace.mode().get_shape(), (3,))
-      self.assertAllClose(self.evaluate(laplace.mode()), loc_v)
+    loc_v = np.array([0.5, 3.0, 2.5])
+    scale_v = np.array([1.0, 4.0, 5.0])
+    laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
+    self.assertEqual(laplace.mode().get_shape(), (3,))
+    self.assertAllClose(self.evaluate(laplace.mode()), loc_v)
 
   def testLaplaceVariance(self):
-    with self.test_session():
-      loc_v = np.array([1.0, 3.0, 2.5])
-      scale_v = np.array([1.0, 4.0, 5.0])
-      laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
-      self.assertEqual(laplace.variance().get_shape(), (3,))
-      if not stats:
-        return
-      expected_variances = stats.laplace.var(loc_v, scale=scale_v)
-      self.assertAllClose(self.evaluate(laplace.variance()), expected_variances)
+    loc_v = np.array([1.0, 3.0, 2.5])
+    scale_v = np.array([1.0, 4.0, 5.0])
+    laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
+    self.assertEqual(laplace.variance().get_shape(), (3,))
+    if not stats:
+      return
+    expected_variances = stats.laplace.var(loc_v, scale=scale_v)
+    self.assertAllClose(self.evaluate(laplace.variance()), expected_variances)
 
   def testLaplaceStd(self):
-    with self.test_session():
-      loc_v = np.array([1.0, 3.0, 2.5])
-      scale_v = np.array([1.0, 4.0, 5.0])
-      laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
-      self.assertEqual(laplace.stddev().get_shape(), (3,))
-      if not stats:
-        return
-      expected_stddev = stats.laplace.std(loc_v, scale=scale_v)
-      self.assertAllClose(self.evaluate(laplace.stddev()), expected_stddev)
+    loc_v = np.array([1.0, 3.0, 2.5])
+    scale_v = np.array([1.0, 4.0, 5.0])
+    laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
+    self.assertEqual(laplace.stddev().get_shape(), (3,))
+    if not stats:
+      return
+    expected_stddev = stats.laplace.std(loc_v, scale=scale_v)
+    self.assertAllClose(self.evaluate(laplace.stddev()), expected_stddev)
 
   def testLaplaceEntropy(self):
-    with self.test_session():
-      loc_v = np.array([1.0, 3.0, 2.5])
-      scale_v = np.array([1.0, 4.0, 5.0])
-      laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
-      self.assertEqual(laplace.entropy().get_shape(), (3,))
-      if not stats:
-        return
-      expected_entropy = stats.laplace.entropy(loc_v, scale=scale_v)
-      self.assertAllClose(self.evaluate(laplace.entropy()), expected_entropy)
+    loc_v = np.array([1.0, 3.0, 2.5])
+    scale_v = np.array([1.0, 4.0, 5.0])
+    laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
+    self.assertEqual(laplace.entropy().get_shape(), (3,))
+    if not stats:
+      return
+    expected_entropy = stats.laplace.entropy(loc_v, scale=scale_v)
+    self.assertAllClose(self.evaluate(laplace.entropy()), expected_entropy)
 
   def testLaplaceSample(self):
-    with session.Session():
-      loc_v = 4.0
-      scale_v = 3.0
-      loc = constant_op.constant(loc_v)
-      scale = constant_op.constant(scale_v)
-      n = 100000
-      laplace = laplace_lib.Laplace(loc=loc, scale=scale)
-      samples = laplace.sample(n, seed=137)
-      sample_values = self.evaluate(samples)
-      self.assertEqual(samples.get_shape(), (n,))
-      self.assertEqual(sample_values.shape, (n,))
-      if not stats:
-        return
-      self.assertAllClose(
-          sample_values.mean(),
-          stats.laplace.mean(
-              loc_v, scale=scale_v),
-          rtol=0.05,
-          atol=0.)
-      self.assertAllClose(
-          sample_values.var(),
-          stats.laplace.var(loc_v, scale=scale_v),
-          rtol=0.05,
-          atol=0.)
-      self.assertTrue(self._kstest(loc_v, scale_v, sample_values))
+    loc_v = 4.0
+    scale_v = 3.0
+    loc = constant_op.constant(loc_v)
+    scale = constant_op.constant(scale_v)
+    n = 100000
+    laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+    samples = laplace.sample(n, seed=137)
+    sample_values = self.evaluate(samples)
+    self.assertEqual(samples.get_shape(), (n,))
+    self.assertEqual(sample_values.shape, (n,))
+    if not stats:
+      return
+    self.assertAllClose(
+        sample_values.mean(),
+        stats.laplace.mean(loc_v, scale=scale_v),
+        rtol=0.05,
+        atol=0.)
+    self.assertAllClose(
+        sample_values.var(),
+        stats.laplace.var(loc_v, scale=scale_v),
+        rtol=0.05,
+        atol=0.)
+    self.assertTrue(self._kstest(loc_v, scale_v, sample_values))
 
   def testLaplaceFullyReparameterized(self):
     loc = constant_op.constant(4.0)
@@ -269,39 +254,37 @@ class LaplaceTest(test.TestCase):
     self.assertIsNotNone(grad_scale)
 
   def testLaplaceSampleMultiDimensional(self):
-    with session.Session():
-      loc_v = np.array([np.arange(1, 101, dtype=np.float32)])  # 1 x 100
-      scale_v = np.array([np.arange(1, 11, dtype=np.float32)]).T  # 10 x 1
-      laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
-      n = 10000
-      samples = laplace.sample(n, seed=137)
-      sample_values = self.evaluate(samples)
-      self.assertEqual(samples.get_shape(), (n, 10, 100))
-      self.assertEqual(sample_values.shape, (n, 10, 100))
-      zeros = np.zeros_like(loc_v + scale_v)  # 10 x 100
-      loc_bc = loc_v + zeros
-      scale_bc = scale_v + zeros
-      if not stats:
-        return
-      self.assertAllClose(
-          sample_values.mean(axis=0),
-          stats.laplace.mean(
-              loc_bc, scale=scale_bc),
-          rtol=0.35,
-          atol=0.)
-      self.assertAllClose(
-          sample_values.var(axis=0),
-          stats.laplace.var(loc_bc, scale=scale_bc),
-          rtol=0.10,
-          atol=0.)
-      fails = 0
-      trials = 0
-      for ai, a in enumerate(np.reshape(loc_v, [-1])):
-        for bi, b in enumerate(np.reshape(scale_v, [-1])):
-          s = sample_values[:, bi, ai]
-          trials += 1
-          fails += 0 if self._kstest(a, b, s) else 1
-      self.assertLess(fails, trials * 0.03)
+    loc_v = np.array([np.arange(1, 101, dtype=np.float32)])  # 1 x 100
+    scale_v = np.array([np.arange(1, 11, dtype=np.float32)]).T  # 10 x 1
+    laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v)
+    n = 10000
+    samples = laplace.sample(n, seed=137)
+    sample_values = self.evaluate(samples)
+    self.assertEqual(samples.get_shape(), (n, 10, 100))
+    self.assertEqual(sample_values.shape, (n, 10, 100))
+    zeros = np.zeros_like(loc_v + scale_v)  # 10 x 100
+    loc_bc = loc_v + zeros
+    scale_bc = scale_v + zeros
+    if not stats:
+      return
+    self.assertAllClose(
+        sample_values.mean(axis=0),
+        stats.laplace.mean(loc_bc, scale=scale_bc),
+        rtol=0.35,
+        atol=0.)
+    self.assertAllClose(
+        sample_values.var(axis=0),
+        stats.laplace.var(loc_bc, scale=scale_bc),
+        rtol=0.10,
+        atol=0.)
+    fails = 0
+    trials = 0
+    for ai, a in enumerate(np.reshape(loc_v, [-1])):
+      for bi, b in enumerate(np.reshape(scale_v, [-1])):
+        s = sample_values[:, bi, ai]
+        trials += 1
+        fails += 0 if self._kstest(a, b, s) else 1
+    self.assertLess(fails, trials * 0.03)
 
   def _kstest(self, loc, scale, samples):
     # Uses the Kolmogorov-Smirnov test for goodness of fit.
@@ -349,30 +332,26 @@ class LaplaceTest(test.TestCase):
     self.assertNear(1., total, err=err)
 
   def testLaplaceNonPositiveInitializationParamsRaises(self):
-    with self.test_session():
-      loc_v = constant_op.constant(0.0, name="loc")
-      scale_v = constant_op.constant(-1.0, name="scale")
-      with self.assertRaisesOpError(
-          "Condition x > 0 did not hold element-wise"):
-        laplace = laplace_lib.Laplace(
-            loc=loc_v, scale=scale_v, validate_args=True)
-        self.evaluate(laplace.mean())
-      loc_v = constant_op.constant(1.0, name="loc")
-      scale_v = constant_op.constant(0.0, name="scale")
-      with self.assertRaisesOpError(
-          "Condition x > 0 did not hold element-wise"):
-        laplace = laplace_lib.Laplace(
-            loc=loc_v, scale=scale_v, validate_args=True)
-        self.evaluate(laplace.mean())
+    loc_v = constant_op.constant(0.0, name="loc")
+    scale_v = constant_op.constant(-1.0, name="scale")
+    with self.assertRaisesOpError("Condition x > 0 did not hold element-wise"):
+      laplace = laplace_lib.Laplace(
+          loc=loc_v, scale=scale_v, validate_args=True)
+      self.evaluate(laplace.mean())
+    loc_v = constant_op.constant(1.0, name="loc")
+    scale_v = constant_op.constant(0.0, name="scale")
+    with self.assertRaisesOpError("Condition x > 0 did not hold element-wise"):
+      laplace = laplace_lib.Laplace(
+          loc=loc_v, scale=scale_v, validate_args=True)
+      self.evaluate(laplace.mean())
 
   def testLaplaceWithSoftplusScale(self):
-    with self.test_session():
-      loc_v = constant_op.constant([0.0, 1.0], name="loc")
-      scale_v = constant_op.constant([-1.0, 2.0], name="scale")
-      laplace = laplace_lib.LaplaceWithSoftplusScale(loc=loc_v, scale=scale_v)
-      self.assertAllClose(
-          self.evaluate(nn_ops.softplus(scale_v)), self.evaluate(laplace.scale))
-      self.assertAllClose(self.evaluate(loc_v), self.evaluate(laplace.loc))
+    loc_v = constant_op.constant([0.0, 1.0], name="loc")
+    scale_v = constant_op.constant([-1.0, 2.0], name="scale")
+    laplace = laplace_lib.LaplaceWithSoftplusScale(loc=loc_v, scale=scale_v)
+    self.assertAllClose(
+        self.evaluate(nn_ops.softplus(scale_v)), self.evaluate(laplace.scale))
+    self.assertAllClose(self.evaluate(loc_v), self.evaluate(laplace.loc))
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/distributions/normal_test.py b/tensorflow/python/kernel_tests/distributions/normal_test.py
index 5dcd6f6df46..de73a40b234 100644
--- a/tensorflow/python/kernel_tests/distributions/normal_test.py
+++ b/tensorflow/python/kernel_tests/distributions/normal_test.py
@@ -61,16 +61,15 @@ class NormalTest(test.TestCase):
     self.assertAllEqual(all_true, is_finite)
 
   def _testParamShapes(self, sample_shape, expected):
-    with self.test_session():
-      param_shapes = normal_lib.Normal.param_shapes(sample_shape)
-      mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"]
-      self.assertAllEqual(expected, self.evaluate(mu_shape))
-      self.assertAllEqual(expected, self.evaluate(sigma_shape))
-      mu = array_ops.zeros(mu_shape)
-      sigma = array_ops.ones(sigma_shape)
-      self.assertAllEqual(
-          expected,
-          self.evaluate(array_ops.shape(normal_lib.Normal(mu, sigma).sample())))
+    param_shapes = normal_lib.Normal.param_shapes(sample_shape)
+    mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"]
+    self.assertAllEqual(expected, self.evaluate(mu_shape))
+    self.assertAllEqual(expected, self.evaluate(sigma_shape))
+    mu = array_ops.zeros(mu_shape)
+    sigma = array_ops.ones(sigma_shape)
+    self.assertAllEqual(
+        expected,
+        self.evaluate(array_ops.shape(normal_lib.Normal(mu, sigma).sample())))
 
   def _testParamStaticShapes(self, sample_shape, expected):
     param_shapes = normal_lib.Normal.param_static_shapes(sample_shape)
@@ -93,154 +92,148 @@ class NormalTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def testNormalWithSoftplusScale(self):
-    with self.test_session():
-      mu = array_ops.zeros((10, 3))
-      rho = array_ops.ones((10, 3)) * -2.
-      normal = normal_lib.NormalWithSoftplusScale(loc=mu, scale=rho)
-      self.assertAllEqual(self.evaluate(mu), self.evaluate(normal.loc))
-      self.assertAllEqual(
-          self.evaluate(nn_ops.softplus(rho)), self.evaluate(normal.scale))
+    mu = array_ops.zeros((10, 3))
+    rho = array_ops.ones((10, 3)) * -2.
+    normal = normal_lib.NormalWithSoftplusScale(loc=mu, scale=rho)
+    self.assertAllEqual(self.evaluate(mu), self.evaluate(normal.loc))
+    self.assertAllEqual(
+        self.evaluate(nn_ops.softplus(rho)), self.evaluate(normal.scale))
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalLogPDF(self):
-    with self.test_session():
-      batch_size = 6
-      mu = constant_op.constant([3.0] * batch_size)
-      sigma = constant_op.constant([math.sqrt(10.0)] * batch_size)
-      x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
+    batch_size = 6
+    mu = constant_op.constant([3.0] * batch_size)
+    sigma = constant_op.constant([math.sqrt(10.0)] * batch_size)
+    x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
 
-      log_pdf = normal.log_prob(x)
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape())
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()),
-          self.evaluate(log_pdf).shape)
-      self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
-      self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)
+    log_pdf = normal.log_prob(x)
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape())
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()),
+        self.evaluate(log_pdf).shape)
+    self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
+    self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)
 
-      pdf = normal.prob(x)
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()), pdf.get_shape())
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()),
-          self.evaluate(pdf).shape)
-      self.assertAllEqual(normal.batch_shape, pdf.get_shape())
-      self.assertAllEqual(normal.batch_shape, self.evaluate(pdf).shape)
+    pdf = normal.prob(x)
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()), pdf.get_shape())
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()),
+        self.evaluate(pdf).shape)
+    self.assertAllEqual(normal.batch_shape, pdf.get_shape())
+    self.assertAllEqual(normal.batch_shape, self.evaluate(pdf).shape)
 
-      if not stats:
-        return
-      expected_log_pdf = stats.norm(self.evaluate(mu),
-                                    self.evaluate(sigma)).logpdf(x)
-      self.assertAllClose(expected_log_pdf, self.evaluate(log_pdf))
-      self.assertAllClose(np.exp(expected_log_pdf), self.evaluate(pdf))
+    if not stats:
+      return
+    expected_log_pdf = stats.norm(self.evaluate(mu),
+                                  self.evaluate(sigma)).logpdf(x)
+    self.assertAllClose(expected_log_pdf, self.evaluate(log_pdf))
+    self.assertAllClose(np.exp(expected_log_pdf), self.evaluate(pdf))
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalLogPDFMultidimensional(self):
-    with self.test_session():
-      batch_size = 6
-      mu = constant_op.constant([[3.0, -3.0]] * batch_size)
-      sigma = constant_op.constant([[math.sqrt(10.0), math.sqrt(15.0)]] *
-                                   batch_size)
-      x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
+    batch_size = 6
+    mu = constant_op.constant([[3.0, -3.0]] * batch_size)
+    sigma = constant_op.constant(
+        [[math.sqrt(10.0), math.sqrt(15.0)]] * batch_size)
+    x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
 
-      log_pdf = normal.log_prob(x)
-      log_pdf_values = self.evaluate(log_pdf)
-      self.assertEqual(log_pdf.get_shape(), (6, 2))
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape())
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()),
-          self.evaluate(log_pdf).shape)
-      self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
-      self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)
+    log_pdf = normal.log_prob(x)
+    log_pdf_values = self.evaluate(log_pdf)
+    self.assertEqual(log_pdf.get_shape(), (6, 2))
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape())
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()),
+        self.evaluate(log_pdf).shape)
+    self.assertAllEqual(normal.batch_shape, log_pdf.get_shape())
+    self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape)
 
-      pdf = normal.prob(x)
-      pdf_values = self.evaluate(pdf)
-      self.assertEqual(pdf.get_shape(), (6, 2))
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()), pdf.get_shape())
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()), pdf_values.shape)
-      self.assertAllEqual(normal.batch_shape, pdf.get_shape())
-      self.assertAllEqual(normal.batch_shape, pdf_values.shape)
+    pdf = normal.prob(x)
+    pdf_values = self.evaluate(pdf)
+    self.assertEqual(pdf.get_shape(), (6, 2))
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()), pdf.get_shape())
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()), pdf_values.shape)
+    self.assertAllEqual(normal.batch_shape, pdf.get_shape())
+    self.assertAllEqual(normal.batch_shape, pdf_values.shape)
 
-      if not stats:
-        return
-      expected_log_pdf = stats.norm(self.evaluate(mu),
-                                    self.evaluate(sigma)).logpdf(x)
-      self.assertAllClose(expected_log_pdf, log_pdf_values)
-      self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
+    if not stats:
+      return
+    expected_log_pdf = stats.norm(self.evaluate(mu),
+                                  self.evaluate(sigma)).logpdf(x)
+    self.assertAllClose(expected_log_pdf, log_pdf_values)
+    self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalCDF(self):
-    with self.test_session():
-      batch_size = 50
-      mu = self._rng.randn(batch_size)
-      sigma = self._rng.rand(batch_size) + 1.0
-      x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
+    batch_size = 50
+    mu = self._rng.randn(batch_size)
+    sigma = self._rng.rand(batch_size) + 1.0
+    x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
 
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
-      cdf = normal.cdf(x)
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()), cdf.get_shape())
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()),
-          self.evaluate(cdf).shape)
-      self.assertAllEqual(normal.batch_shape, cdf.get_shape())
-      self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)
-      if not stats:
-        return
-      expected_cdf = stats.norm(mu, sigma).cdf(x)
-      self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0)
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
+    cdf = normal.cdf(x)
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()), cdf.get_shape())
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()),
+        self.evaluate(cdf).shape)
+    self.assertAllEqual(normal.batch_shape, cdf.get_shape())
+    self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)
+    if not stats:
+      return
+    expected_cdf = stats.norm(mu, sigma).cdf(x)
+    self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0)
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalSurvivalFunction(self):
-    with self.test_session():
-      batch_size = 50
-      mu = self._rng.randn(batch_size)
-      sigma = self._rng.rand(batch_size) + 1.0
-      x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
+    batch_size = 50
+    mu = self._rng.randn(batch_size)
+    sigma = self._rng.rand(batch_size) + 1.0
+    x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
 
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
 
-      sf = normal.survival_function(x)
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()), sf.get_shape())
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()),
-          self.evaluate(sf).shape)
-      self.assertAllEqual(normal.batch_shape, sf.get_shape())
-      self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)
-      if not stats:
-        return
-      expected_sf = stats.norm(mu, sigma).sf(x)
-      self.assertAllClose(expected_sf, self.evaluate(sf), atol=0)
+    sf = normal.survival_function(x)
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()), sf.get_shape())
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()),
+        self.evaluate(sf).shape)
+    self.assertAllEqual(normal.batch_shape, sf.get_shape())
+    self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)
+    if not stats:
+      return
+    expected_sf = stats.norm(mu, sigma).sf(x)
+    self.assertAllClose(expected_sf, self.evaluate(sf), atol=0)
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalLogCDF(self):
-    with self.test_session():
-      batch_size = 50
-      mu = self._rng.randn(batch_size)
-      sigma = self._rng.rand(batch_size) + 1.0
-      x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64)
+    batch_size = 50
+    mu = self._rng.randn(batch_size)
+    sigma = self._rng.rand(batch_size) + 1.0
+    x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64)
 
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
 
-      cdf = normal.log_cdf(x)
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()), cdf.get_shape())
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()),
-          self.evaluate(cdf).shape)
-      self.assertAllEqual(normal.batch_shape, cdf.get_shape())
-      self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)
+    cdf = normal.log_cdf(x)
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()), cdf.get_shape())
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()),
+        self.evaluate(cdf).shape)
+    self.assertAllEqual(normal.batch_shape, cdf.get_shape())
+    self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape)
 
-      if not stats:
-        return
-      expected_cdf = stats.norm(mu, sigma).logcdf(x)
-      self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-3)
+    if not stats:
+      return
+    expected_cdf = stats.norm(mu, sigma).logcdf(x)
+    self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-3)
 
   def testFiniteGradientAtDifficultPoints(self):
     for dtype in [np.float32, np.float64]:
@@ -256,7 +249,7 @@ class NormalTest(test.TestCase):
         ]:
           value = func(x)
           grads = gradients_impl.gradients(value, [mu, sigma])
-          with self.test_session(graph=g):
+          with self.session(graph=g):
             variables.global_variables_initializer().run()
             self.assertAllFinite(value)
             self.assertAllFinite(grads[0])
@@ -264,112 +257,106 @@ class NormalTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalLogSurvivalFunction(self):
-    with self.test_session():
-      batch_size = 50
-      mu = self._rng.randn(batch_size)
-      sigma = self._rng.rand(batch_size) + 1.0
-      x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64)
+    batch_size = 50
+    mu = self._rng.randn(batch_size)
+    sigma = self._rng.rand(batch_size) + 1.0
+    x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64)
 
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
 
-      sf = normal.log_survival_function(x)
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()), sf.get_shape())
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()),
-          self.evaluate(sf).shape)
-      self.assertAllEqual(normal.batch_shape, sf.get_shape())
-      self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)
+    sf = normal.log_survival_function(x)
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()), sf.get_shape())
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()),
+        self.evaluate(sf).shape)
+    self.assertAllEqual(normal.batch_shape, sf.get_shape())
+    self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape)
 
-      if not stats:
-        return
-      expected_sf = stats.norm(mu, sigma).logsf(x)
-      self.assertAllClose(expected_sf, self.evaluate(sf), atol=0, rtol=1e-5)
+    if not stats:
+      return
+    expected_sf = stats.norm(mu, sigma).logsf(x)
+    self.assertAllClose(expected_sf, self.evaluate(sf), atol=0, rtol=1e-5)
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalEntropyWithScalarInputs(self):
     # Scipy.stats.norm cannot deal with the shapes in the other test.
-    with self.test_session():
-      mu_v = 2.34
-      sigma_v = 4.56
-      normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)
+    mu_v = 2.34
+    sigma_v = 4.56
+    normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)
 
-      entropy = normal.entropy()
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()), entropy.get_shape())
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()),
-          self.evaluate(entropy).shape)
-      self.assertAllEqual(normal.batch_shape, entropy.get_shape())
-      self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)
-      # scipy.stats.norm cannot deal with these shapes.
-      if not stats:
-        return
-      expected_entropy = stats.norm(mu_v, sigma_v).entropy()
-      self.assertAllClose(expected_entropy, self.evaluate(entropy))
+    entropy = normal.entropy()
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()), entropy.get_shape())
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()),
+        self.evaluate(entropy).shape)
+    self.assertAllEqual(normal.batch_shape, entropy.get_shape())
+    self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)
+    # scipy.stats.norm cannot deal with these shapes.
+    if not stats:
+      return
+    expected_entropy = stats.norm(mu_v, sigma_v).entropy()
+    self.assertAllClose(expected_entropy, self.evaluate(entropy))
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalEntropy(self):
-    with self.test_session():
-      mu_v = np.array([1.0, 1.0, 1.0])
-      sigma_v = np.array([[1.0, 2.0, 3.0]]).T
-      normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)
+    mu_v = np.array([1.0, 1.0, 1.0])
+    sigma_v = np.array([[1.0, 2.0, 3.0]]).T
+    normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)
 
-      # scipy.stats.norm cannot deal with these shapes.
-      sigma_broadcast = mu_v * sigma_v
-      expected_entropy = 0.5 * np.log(2 * np.pi * np.exp(1) * sigma_broadcast**
-                                      2)
-      entropy = normal.entropy()
-      np.testing.assert_allclose(expected_entropy, self.evaluate(entropy))
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()), entropy.get_shape())
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()),
-          self.evaluate(entropy).shape)
-      self.assertAllEqual(normal.batch_shape, entropy.get_shape())
-      self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)
+    # scipy.stats.norm cannot deal with these shapes.
+    sigma_broadcast = mu_v * sigma_v
+    expected_entropy = 0.5 * np.log(2 * np.pi * np.exp(1) * sigma_broadcast**2)
+    entropy = normal.entropy()
+    np.testing.assert_allclose(expected_entropy, self.evaluate(entropy))
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()), entropy.get_shape())
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()),
+        self.evaluate(entropy).shape)
+    self.assertAllEqual(normal.batch_shape, entropy.get_shape())
+    self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)
 
   @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def testNormalMeanAndMode(self):
-    with self.test_session():
-      # Mu will be broadcast to [7, 7, 7].
-      mu = [7.]
-      sigma = [11., 12., 13.]
+    # Mu will be broadcast to [7, 7, 7].
+    mu = [7.]
+    sigma = [11., 12., 13.]
 
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
 
-      self.assertAllEqual((3,), normal.mean().get_shape())
-      self.assertAllEqual([7., 7, 7], self.evaluate(normal.mean()))
+    self.assertAllEqual((3,), normal.mean().get_shape())
+    self.assertAllEqual([7., 7, 7], self.evaluate(normal.mean()))
 
-      self.assertAllEqual((3,), normal.mode().get_shape())
-      self.assertAllEqual([7., 7, 7], self.evaluate(normal.mode()))
+    self.assertAllEqual((3,), normal.mode().get_shape())
+    self.assertAllEqual([7., 7, 7], self.evaluate(normal.mode()))
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalQuantile(self):
-    with self.test_session():
-      batch_size = 52
-      mu = self._rng.randn(batch_size)
-      sigma = self._rng.rand(batch_size) + 1.0
-      p = np.linspace(0., 1.0, batch_size - 2).astype(np.float64)
-      # Quantile performs piecewise rational approximation so adding some
-      # special input values to make sure we hit all the pieces.
-      p = np.hstack((p, np.exp(-33), 1. - np.exp(-33)))
+    batch_size = 52
+    mu = self._rng.randn(batch_size)
+    sigma = self._rng.rand(batch_size) + 1.0
+    p = np.linspace(0., 1.0, batch_size - 2).astype(np.float64)
+    # Quantile performs piecewise rational approximation so adding some
+    # special input values to make sure we hit all the pieces.
+    p = np.hstack((p, np.exp(-33), 1. - np.exp(-33)))
 
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
-      x = normal.quantile(p)
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
+    x = normal.quantile(p)
 
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()), x.get_shape())
-      self.assertAllEqual(
-          self.evaluate(normal.batch_shape_tensor()),
-          self.evaluate(x).shape)
-      self.assertAllEqual(normal.batch_shape, x.get_shape())
-      self.assertAllEqual(normal.batch_shape, self.evaluate(x).shape)
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()), x.get_shape())
+    self.assertAllEqual(
+        self.evaluate(normal.batch_shape_tensor()),
+        self.evaluate(x).shape)
+    self.assertAllEqual(normal.batch_shape, x.get_shape())
+    self.assertAllEqual(normal.batch_shape, self.evaluate(x).shape)
 
-      if not stats:
-        return
-      expected_x = stats.norm(mu, sigma).ppf(p)
-      self.assertAllClose(expected_x, self.evaluate(x), atol=0.)
+    if not stats:
+      return
+    expected_x = stats.norm(mu, sigma).ppf(p)
+    self.assertAllClose(expected_x, self.evaluate(x), atol=0.)
 
   def _baseQuantileFiniteGradientAtDifficultPoints(self, dtype):
     g = ops.Graph()
@@ -385,7 +372,7 @@ class NormalTest(test.TestCase):
 
       value = dist.quantile(p)
       grads = gradients_impl.gradients(value, [mu, p])
-      with self.test_session(graph=g):
+      with self.cached_session(graph=g):
         variables.global_variables_initializer().run()
         self.assertAllFinite(grads[0])
         self.assertAllFinite(grads[1])
@@ -398,61 +385,58 @@ class NormalTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalVariance(self):
-    with self.test_session():
-      # sigma will be broadcast to [7, 7, 7]
-      mu = [1., 2., 3.]
-      sigma = [7.]
+    # sigma will be broadcast to [7, 7, 7]
+    mu = [1., 2., 3.]
+    sigma = [7.]
 
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
 
-      self.assertAllEqual((3,), normal.variance().get_shape())
-      self.assertAllEqual([49., 49, 49], self.evaluate(normal.variance()))
+    self.assertAllEqual((3,), normal.variance().get_shape())
+    self.assertAllEqual([49., 49, 49], self.evaluate(normal.variance()))
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalStandardDeviation(self):
-    with self.test_session():
-      # sigma will be broadcast to [7, 7, 7]
-      mu = [1., 2., 3.]
-      sigma = [7.]
+    # sigma will be broadcast to [7, 7, 7]
+    mu = [1., 2., 3.]
+    sigma = [7.]
 
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
 
-      self.assertAllEqual((3,), normal.stddev().get_shape())
-      self.assertAllEqual([7., 7, 7], self.evaluate(normal.stddev()))
+    self.assertAllEqual((3,), normal.stddev().get_shape())
+    self.assertAllEqual([7., 7, 7], self.evaluate(normal.stddev()))
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalSample(self):
-    with self.test_session():
-      mu = constant_op.constant(3.0)
-      sigma = constant_op.constant(math.sqrt(3.0))
-      mu_v = 3.0
-      sigma_v = np.sqrt(3.0)
-      n = constant_op.constant(100000)
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
-      samples = normal.sample(n)
-      sample_values = self.evaluate(samples)
-      # Note that the standard error for the sample mean is ~ sigma / sqrt(n).
-      # The sample variance similarly is dependent on sigma and n.
-      # Thus, the tolerances below are very sensitive to number of samples
-      # as well as the variances chosen.
-      self.assertEqual(sample_values.shape, (100000,))
-      self.assertAllClose(sample_values.mean(), mu_v, atol=1e-1)
-      self.assertAllClose(sample_values.std(), sigma_v, atol=1e-1)
+    mu = constant_op.constant(3.0)
+    sigma = constant_op.constant(math.sqrt(3.0))
+    mu_v = 3.0
+    sigma_v = np.sqrt(3.0)
+    n = constant_op.constant(100000)
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
+    samples = normal.sample(n)
+    sample_values = self.evaluate(samples)
+    # Note that the standard error for the sample mean is ~ sigma / sqrt(n).
+    # The sample variance similarly is dependent on sigma and n.
+    # Thus, the tolerances below are very sensitive to number of samples
+    # as well as the variances chosen.
+    self.assertEqual(sample_values.shape, (100000,))
+    self.assertAllClose(sample_values.mean(), mu_v, atol=1e-1)
+    self.assertAllClose(sample_values.std(), sigma_v, atol=1e-1)
 
-      expected_samples_shape = tensor_shape.TensorShape(
-          [self.evaluate(n)]).concatenate(
-              tensor_shape.TensorShape(
-                  self.evaluate(normal.batch_shape_tensor())))
+    expected_samples_shape = tensor_shape.TensorShape(
+        [self.evaluate(n)]).concatenate(
+            tensor_shape.TensorShape(
+                self.evaluate(normal.batch_shape_tensor())))
 
-      self.assertAllEqual(expected_samples_shape, samples.get_shape())
-      self.assertAllEqual(expected_samples_shape, sample_values.shape)
+    self.assertAllEqual(expected_samples_shape, samples.get_shape())
+    self.assertAllEqual(expected_samples_shape, sample_values.shape)
 
-      expected_samples_shape = (
-          tensor_shape.TensorShape([self.evaluate(n)]).concatenate(
-              normal.batch_shape))
+    expected_samples_shape = (
+        tensor_shape.TensorShape([self.evaluate(n)]).concatenate(
+            normal.batch_shape))
 
-      self.assertAllEqual(expected_samples_shape, samples.get_shape())
-      self.assertAllEqual(expected_samples_shape, sample_values.shape)
+    self.assertAllEqual(expected_samples_shape, samples.get_shape())
+    self.assertAllEqual(expected_samples_shape, sample_values.shape)
 
   def testNormalFullyReparameterized(self):
     mu = constant_op.constant(4.0)
@@ -468,66 +452,63 @@ class NormalTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalSampleMultiDimensional(self):
-    with self.test_session():
-      batch_size = 2
-      mu = constant_op.constant([[3.0, -3.0]] * batch_size)
-      sigma = constant_op.constant([[math.sqrt(2.0), math.sqrt(3.0)]] *
-                                   batch_size)
-      mu_v = [3.0, -3.0]
-      sigma_v = [np.sqrt(2.0), np.sqrt(3.0)]
-      n = constant_op.constant(100000)
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
-      samples = normal.sample(n)
-      sample_values = self.evaluate(samples)
-      # Note that the standard error for the sample mean is ~ sigma / sqrt(n).
-      # The sample variance similarly is dependent on sigma and n.
-      # Thus, the tolerances below are very sensitive to number of samples
-      # as well as the variances chosen.
-      self.assertEqual(samples.get_shape(), (100000, batch_size, 2))
-      self.assertAllClose(sample_values[:, 0, 0].mean(), mu_v[0], atol=1e-1)
-      self.assertAllClose(sample_values[:, 0, 0].std(), sigma_v[0], atol=1e-1)
-      self.assertAllClose(sample_values[:, 0, 1].mean(), mu_v[1], atol=1e-1)
-      self.assertAllClose(sample_values[:, 0, 1].std(), sigma_v[1], atol=1e-1)
+    batch_size = 2
+    mu = constant_op.constant([[3.0, -3.0]] * batch_size)
+    sigma = constant_op.constant(
+        [[math.sqrt(2.0), math.sqrt(3.0)]] * batch_size)
+    mu_v = [3.0, -3.0]
+    sigma_v = [np.sqrt(2.0), np.sqrt(3.0)]
+    n = constant_op.constant(100000)
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
+    samples = normal.sample(n)
+    sample_values = self.evaluate(samples)
+    # Note that the standard error for the sample mean is ~ sigma / sqrt(n).
+    # The sample variance similarly is dependent on sigma and n.
+    # Thus, the tolerances below are very sensitive to number of samples
+    # as well as the variances chosen.
+    self.assertEqual(samples.get_shape(), (100000, batch_size, 2))
+    self.assertAllClose(sample_values[:, 0, 0].mean(), mu_v[0], atol=1e-1)
+    self.assertAllClose(sample_values[:, 0, 0].std(), sigma_v[0], atol=1e-1)
+    self.assertAllClose(sample_values[:, 0, 1].mean(), mu_v[1], atol=1e-1)
+    self.assertAllClose(sample_values[:, 0, 1].std(), sigma_v[1], atol=1e-1)
 
-      expected_samples_shape = tensor_shape.TensorShape(
-          [self.evaluate(n)]).concatenate(
-              tensor_shape.TensorShape(
-                  self.evaluate(normal.batch_shape_tensor())))
-      self.assertAllEqual(expected_samples_shape, samples.get_shape())
-      self.assertAllEqual(expected_samples_shape, sample_values.shape)
+    expected_samples_shape = tensor_shape.TensorShape(
+        [self.evaluate(n)]).concatenate(
+            tensor_shape.TensorShape(
+                self.evaluate(normal.batch_shape_tensor())))
+    self.assertAllEqual(expected_samples_shape, samples.get_shape())
+    self.assertAllEqual(expected_samples_shape, sample_values.shape)
 
-      expected_samples_shape = (
-          tensor_shape.TensorShape([self.evaluate(n)]).concatenate(
-              normal.batch_shape))
-      self.assertAllEqual(expected_samples_shape, samples.get_shape())
-      self.assertAllEqual(expected_samples_shape, sample_values.shape)
+    expected_samples_shape = (
+        tensor_shape.TensorShape([self.evaluate(n)]).concatenate(
+            normal.batch_shape))
+    self.assertAllEqual(expected_samples_shape, samples.get_shape())
+    self.assertAllEqual(expected_samples_shape, sample_values.shape)
 
   @test_util.run_in_graph_and_eager_modes
   def testNegativeSigmaFails(self):
-    with self.test_session():
-      with self.assertRaisesOpError("Condition x > 0 did not hold"):
-        normal = normal_lib.Normal(
-            loc=[1.], scale=[-5.], validate_args=True, name="G")
-        self.evaluate(normal.mean())
+    with self.assertRaisesOpError("Condition x > 0 did not hold"):
+      normal = normal_lib.Normal(
+          loc=[1.], scale=[-5.], validate_args=True, name="G")
+      self.evaluate(normal.mean())
 
   @test_util.run_in_graph_and_eager_modes
   def testNormalShape(self):
-    with self.test_session():
-      mu = constant_op.constant([-3.0] * 5)
-      sigma = constant_op.constant(11.0)
-      normal = normal_lib.Normal(loc=mu, scale=sigma)
+    mu = constant_op.constant([-3.0] * 5)
+    sigma = constant_op.constant(11.0)
+    normal = normal_lib.Normal(loc=mu, scale=sigma)
 
-      self.assertEqual(self.evaluate(normal.batch_shape_tensor()), [5])
-      self.assertEqual(normal.batch_shape, tensor_shape.TensorShape([5]))
-      self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), [])
-      self.assertEqual(normal.event_shape, tensor_shape.TensorShape([]))
+    self.assertEqual(self.evaluate(normal.batch_shape_tensor()), [5])
+    self.assertEqual(normal.batch_shape, tensor_shape.TensorShape([5]))
+    self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), [])
+    self.assertEqual(normal.event_shape, tensor_shape.TensorShape([]))
 
   def testNormalShapeWithPlaceholders(self):
     mu = array_ops.placeholder(dtype=dtypes.float32)
     sigma = array_ops.placeholder(dtype=dtypes.float32)
     normal = normal_lib.Normal(loc=mu, scale=sigma)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # get_batch_shape should return an "<unknown>" tensor.
       self.assertEqual(normal.batch_shape, tensor_shape.TensorShape(None))
       self.assertEqual(normal.event_shape, ())
diff --git a/tensorflow/python/kernel_tests/distributions/special_math_test.py b/tensorflow/python/kernel_tests/distributions/special_math_test.py
index a634194ce52..cc43e121686 100644
--- a/tensorflow/python/kernel_tests/distributions/special_math_test.py
+++ b/tensorflow/python/kernel_tests/distributions/special_math_test.py
@@ -92,22 +92,21 @@ class NdtriTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testNdtri(self):
     """Verifies that ndtri computation is correct."""
-    with self.test_session():
-      if not special:
-        return
+    if not special:
+      return
 
-      p = np.linspace(0., 1.0, 50).astype(np.float64)
-      # Quantile performs piecewise rational approximation so adding some
-      # special input values to make sure we hit all the pieces.
-      p = np.hstack((p, np.exp(-32), 1. - np.exp(-32),
-                     np.exp(-2), 1. - np.exp(-2)))
-      expected_x = special.ndtri(p)
-      x = special_math.ndtri(p)
-      self.assertAllClose(expected_x, self.evaluate(x), atol=0.)
+    p = np.linspace(0., 1.0, 50).astype(np.float64)
+    # Quantile performs piecewise rational approximation so adding some
+    # special input values to make sure we hit all the pieces.
+    p = np.hstack((p, np.exp(-32), 1. - np.exp(-32), np.exp(-2),
+                   1. - np.exp(-2)))
+    expected_x = special.ndtri(p)
+    x = special_math.ndtri(p)
+    self.assertAllClose(expected_x, self.evaluate(x), atol=0.)
 
   def testNdtriDynamicShape(self):
     """Verifies that ndtri computation is correct."""
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       if not special:
         return
 
@@ -286,7 +285,7 @@ class NdtrGradientTest(test.TestCase):
   def _test_grad_accuracy(self, dtype, grid_spec, error_spec):
     raw_grid = _make_grid(dtype, grid_spec)
     grid = ops.convert_to_tensor(raw_grid)
-    with self.test_session():
+    with self.cached_session():
       fn = sm.log_ndtr if self._use_log else sm.ndtr
 
       # If there are N points in the grid,
@@ -355,7 +354,7 @@ class LogNdtrGradientTest(NdtrGradientTest):
 class ErfInvTest(test.TestCase):
 
   def testErfInvValues(self):
-    with self.test_session():
+    with self.cached_session():
       if not special:
         return
 
@@ -366,7 +365,7 @@ class ErfInvTest(test.TestCase):
       self.assertAllClose(expected_x, x.eval(), atol=0.)
 
   def testErfInvIntegerInput(self):
-    with self.test_session():
+    with self.cached_session():
 
       with self.assertRaises(TypeError):
         x = np.array([1, 2, 3]).astype(np.int32)
@@ -397,7 +396,7 @@ class LogCDFLaplaceTest(test.TestCase):
     self.assertAllEqual(np.ones_like(x, dtype=np.bool), x)
 
   def _test_grid_log(self, dtype, scipy_dtype, grid_spec, error_spec):
-    with self.test_session():
+    with self.cached_session():
       grid = _make_grid(dtype, grid_spec)
       actual = sm.log_cdf_laplace(grid).eval()
 
@@ -439,7 +438,7 @@ class LogCDFLaplaceTest(test.TestCase):
         ErrorSpec(rtol=0.05, atol=0))
 
   def test_float32_extreme_values_result_and_gradient_finite_and_nonzero(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # On the lower branch, log_cdf_laplace(x) = x, so we know this will be
       # fine, but test to -200 anyways.
       grid = _make_grid(
@@ -458,7 +457,7 @@ class LogCDFLaplaceTest(test.TestCase):
       self.assertFalse(np.any(grad_ == 0))
 
   def test_float64_extreme_values_result_and_gradient_finite_and_nonzero(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       # On the lower branch, log_cdf_laplace(x) = x, so we know this will be
       # fine, but test to -200 anyways.
       grid = _make_grid(
diff --git a/tensorflow/python/kernel_tests/distributions/student_t_test.py b/tensorflow/python/kernel_tests/distributions/student_t_test.py
index 05590542efe..b34b5381604 100644
--- a/tensorflow/python/kernel_tests/distributions/student_t_test.py
+++ b/tensorflow/python/kernel_tests/distributions/student_t_test.py
@@ -50,100 +50,96 @@ stats = try_import("scipy.stats")
 class StudentTTest(test.TestCase):
 
   def testStudentPDFAndLogPDF(self):
-    with self.test_session():
-      batch_size = 6
-      df = constant_op.constant([3.] * batch_size)
-      mu = constant_op.constant([7.] * batch_size)
-      sigma = constant_op.constant([8.] * batch_size)
-      df_v = 3.
-      mu_v = 7.
-      sigma_v = 8.
-      t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
-      student = student_t.StudentT(df, loc=mu, scale=-sigma)
+    batch_size = 6
+    df = constant_op.constant([3.] * batch_size)
+    mu = constant_op.constant([7.] * batch_size)
+    sigma = constant_op.constant([8.] * batch_size)
+    df_v = 3.
+    mu_v = 7.
+    sigma_v = 8.
+    t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
+    student = student_t.StudentT(df, loc=mu, scale=-sigma)
 
-      log_pdf = student.log_prob(t)
-      self.assertEquals(log_pdf.get_shape(), (6,))
-      log_pdf_values = self.evaluate(log_pdf)
-      pdf = student.prob(t)
-      self.assertEquals(pdf.get_shape(), (6,))
-      pdf_values = self.evaluate(pdf)
+    log_pdf = student.log_prob(t)
+    self.assertEquals(log_pdf.get_shape(), (6,))
+    log_pdf_values = self.evaluate(log_pdf)
+    pdf = student.prob(t)
+    self.assertEquals(pdf.get_shape(), (6,))
+    pdf_values = self.evaluate(pdf)
 
-      if not stats:
-        return
+    if not stats:
+      return
 
-      expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
-      expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
-      self.assertAllClose(expected_log_pdf, log_pdf_values)
-      self.assertAllClose(np.log(expected_pdf), log_pdf_values)
-      self.assertAllClose(expected_pdf, pdf_values)
-      self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
+    expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
+    expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
+    self.assertAllClose(expected_log_pdf, log_pdf_values)
+    self.assertAllClose(np.log(expected_pdf), log_pdf_values)
+    self.assertAllClose(expected_pdf, pdf_values)
+    self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
 
   def testStudentLogPDFMultidimensional(self):
-    with self.test_session():
-      batch_size = 6
-      df = constant_op.constant([[1.5, 7.2]] * batch_size)
-      mu = constant_op.constant([[3., -3.]] * batch_size)
-      sigma = constant_op.constant([[-math.sqrt(10.), math.sqrt(15.)]] *
-                                   batch_size)
-      df_v = np.array([1.5, 7.2])
-      mu_v = np.array([3., -3.])
-      sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)])
-      t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T
-      student = student_t.StudentT(df, loc=mu, scale=sigma)
-      log_pdf = student.log_prob(t)
-      log_pdf_values = self.evaluate(log_pdf)
-      self.assertEqual(log_pdf.get_shape(), (6, 2))
-      pdf = student.prob(t)
-      pdf_values = self.evaluate(pdf)
-      self.assertEqual(pdf.get_shape(), (6, 2))
+    batch_size = 6
+    df = constant_op.constant([[1.5, 7.2]] * batch_size)
+    mu = constant_op.constant([[3., -3.]] * batch_size)
+    sigma = constant_op.constant(
+        [[-math.sqrt(10.), math.sqrt(15.)]] * batch_size)
+    df_v = np.array([1.5, 7.2])
+    mu_v = np.array([3., -3.])
+    sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)])
+    t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T
+    student = student_t.StudentT(df, loc=mu, scale=sigma)
+    log_pdf = student.log_prob(t)
+    log_pdf_values = self.evaluate(log_pdf)
+    self.assertEqual(log_pdf.get_shape(), (6, 2))
+    pdf = student.prob(t)
+    pdf_values = self.evaluate(pdf)
+    self.assertEqual(pdf.get_shape(), (6, 2))
 
-      if not stats:
-        return
-      expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
-      expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
-      self.assertAllClose(expected_log_pdf, log_pdf_values)
-      self.assertAllClose(np.log(expected_pdf), log_pdf_values)
-      self.assertAllClose(expected_pdf, pdf_values)
-      self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
+    if not stats:
+      return
+    expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
+    expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
+    self.assertAllClose(expected_log_pdf, log_pdf_values)
+    self.assertAllClose(np.log(expected_pdf), log_pdf_values)
+    self.assertAllClose(expected_pdf, pdf_values)
+    self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
 
   def testStudentCDFAndLogCDF(self):
-    with self.test_session():
-      batch_size = 6
-      df = constant_op.constant([3.] * batch_size)
-      mu = constant_op.constant([7.] * batch_size)
-      sigma = constant_op.constant([-8.] * batch_size)
-      df_v = 3.
-      mu_v = 7.
-      sigma_v = 8.
-      t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
-      student = student_t.StudentT(df, loc=mu, scale=sigma)
+    batch_size = 6
+    df = constant_op.constant([3.] * batch_size)
+    mu = constant_op.constant([7.] * batch_size)
+    sigma = constant_op.constant([-8.] * batch_size)
+    df_v = 3.
+    mu_v = 7.
+    sigma_v = 8.
+    t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
+    student = student_t.StudentT(df, loc=mu, scale=sigma)
 
-      log_cdf = student.log_cdf(t)
-      self.assertEquals(log_cdf.get_shape(), (6,))
-      log_cdf_values = self.evaluate(log_cdf)
-      cdf = student.cdf(t)
-      self.assertEquals(cdf.get_shape(), (6,))
-      cdf_values = self.evaluate(cdf)
+    log_cdf = student.log_cdf(t)
+    self.assertEquals(log_cdf.get_shape(), (6,))
+    log_cdf_values = self.evaluate(log_cdf)
+    cdf = student.cdf(t)
+    self.assertEquals(cdf.get_shape(), (6,))
+    cdf_values = self.evaluate(cdf)
 
-      if not stats:
-        return
-      expected_log_cdf = stats.t.logcdf(t, df_v, loc=mu_v, scale=sigma_v)
-      expected_cdf = stats.t.cdf(t, df_v, loc=mu_v, scale=sigma_v)
-      self.assertAllClose(expected_log_cdf, log_cdf_values, atol=0., rtol=1e-5)
-      self.assertAllClose(
-          np.log(expected_cdf), log_cdf_values, atol=0., rtol=1e-5)
-      self.assertAllClose(expected_cdf, cdf_values, atol=0., rtol=1e-5)
-      self.assertAllClose(
-          np.exp(expected_log_cdf), cdf_values, atol=0., rtol=1e-5)
+    if not stats:
+      return
+    expected_log_cdf = stats.t.logcdf(t, df_v, loc=mu_v, scale=sigma_v)
+    expected_cdf = stats.t.cdf(t, df_v, loc=mu_v, scale=sigma_v)
+    self.assertAllClose(expected_log_cdf, log_cdf_values, atol=0., rtol=1e-5)
+    self.assertAllClose(
+        np.log(expected_cdf), log_cdf_values, atol=0., rtol=1e-5)
+    self.assertAllClose(expected_cdf, cdf_values, atol=0., rtol=1e-5)
+    self.assertAllClose(
+        np.exp(expected_log_cdf), cdf_values, atol=0., rtol=1e-5)
 
   def testStudentEntropy(self):
     df_v = np.array([[2., 3., 7.]])  # 1x3
     mu_v = np.array([[1., -1, 0]])  # 1x3
     sigma_v = np.array([[1., -2., 3.]]).T  # transposed => 3x1
-    with self.test_session():
-      student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v)
-      ent = student.entropy()
-      ent_values = self.evaluate(ent)
+    student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v)
+    ent = student.entropy()
+    ent_values = self.evaluate(ent)
 
     # Help scipy broadcast to 3x3
     ones = np.array([[1, 1, 1]])
@@ -160,90 +156,81 @@ class StudentTTest(test.TestCase):
     self.assertAllClose(expected_entropy, ent_values)
 
   def testStudentSample(self):
-    with self.test_session():
-      df = constant_op.constant(4.)
-      mu = constant_op.constant(3.)
-      sigma = constant_op.constant(-math.sqrt(10.))
-      df_v = 4.
-      mu_v = 3.
-      sigma_v = np.sqrt(10.)
-      n = constant_op.constant(200000)
-      student = student_t.StudentT(df=df, loc=mu, scale=sigma)
-      samples = student.sample(n, seed=123456)
-      sample_values = self.evaluate(samples)
-      n_val = 200000
-      self.assertEqual(sample_values.shape, (n_val,))
-      self.assertAllClose(sample_values.mean(), mu_v, rtol=0.1, atol=0)
-      self.assertAllClose(
-          sample_values.var(),
-          sigma_v**2 * df_v / (df_v - 2),
-          rtol=0.1,
-          atol=0)
-      self._checkKLApprox(df_v, mu_v, sigma_v, sample_values)
+    df = constant_op.constant(4.)
+    mu = constant_op.constant(3.)
+    sigma = constant_op.constant(-math.sqrt(10.))
+    df_v = 4.
+    mu_v = 3.
+    sigma_v = np.sqrt(10.)
+    n = constant_op.constant(200000)
+    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+    samples = student.sample(n, seed=123456)
+    sample_values = self.evaluate(samples)
+    n_val = 200000
+    self.assertEqual(sample_values.shape, (n_val,))
+    self.assertAllClose(sample_values.mean(), mu_v, rtol=0.1, atol=0)
+    self.assertAllClose(
+        sample_values.var(), sigma_v**2 * df_v / (df_v - 2), rtol=0.1, atol=0)
+    self._checkKLApprox(df_v, mu_v, sigma_v, sample_values)
 
   # Test that sampling with the same seed twice gives the same results.
   def testStudentSampleMultipleTimes(self):
-    with self.test_session():
-      df = constant_op.constant(4.)
-      mu = constant_op.constant(3.)
-      sigma = constant_op.constant(math.sqrt(10.))
-      n = constant_op.constant(100)
+    df = constant_op.constant(4.)
+    mu = constant_op.constant(3.)
+    sigma = constant_op.constant(math.sqrt(10.))
+    n = constant_op.constant(100)
 
-      random_seed.set_random_seed(654321)
-      student = student_t.StudentT(
-          df=df, loc=mu, scale=sigma, name="student_t1")
-      samples1 = self.evaluate(student.sample(n, seed=123456))
+    random_seed.set_random_seed(654321)
+    student = student_t.StudentT(df=df, loc=mu, scale=sigma, name="student_t1")
+    samples1 = self.evaluate(student.sample(n, seed=123456))
 
-      random_seed.set_random_seed(654321)
-      student2 = student_t.StudentT(
-          df=df, loc=mu, scale=sigma, name="student_t2")
-      samples2 = self.evaluate(student2.sample(n, seed=123456))
+    random_seed.set_random_seed(654321)
+    student2 = student_t.StudentT(df=df, loc=mu, scale=sigma, name="student_t2")
+    samples2 = self.evaluate(student2.sample(n, seed=123456))
 
-      self.assertAllClose(samples1, samples2)
+    self.assertAllClose(samples1, samples2)
 
   def testStudentSampleSmallDfNoNan(self):
-    with self.test_session():
-      df_v = [1e-1, 1e-5, 1e-10, 1e-20]
-      df = constant_op.constant(df_v)
-      n = constant_op.constant(200000)
-      student = student_t.StudentT(df=df, loc=1., scale=1.)
-      samples = student.sample(n, seed=123456)
-      sample_values = self.evaluate(samples)
-      n_val = 200000
-      self.assertEqual(sample_values.shape, (n_val, 4))
-      self.assertTrue(np.all(np.logical_not(np.isnan(sample_values))))
+    df_v = [1e-1, 1e-5, 1e-10, 1e-20]
+    df = constant_op.constant(df_v)
+    n = constant_op.constant(200000)
+    student = student_t.StudentT(df=df, loc=1., scale=1.)
+    samples = student.sample(n, seed=123456)
+    sample_values = self.evaluate(samples)
+    n_val = 200000
+    self.assertEqual(sample_values.shape, (n_val, 4))
+    self.assertTrue(np.all(np.logical_not(np.isnan(sample_values))))
 
   def testStudentSampleMultiDimensional(self):
-    with self.test_session():
-      batch_size = 7
-      df = constant_op.constant([[5., 7.]] * batch_size)
-      mu = constant_op.constant([[3., -3.]] * batch_size)
-      sigma = constant_op.constant([[math.sqrt(10.), math.sqrt(15.)]] *
-                                   batch_size)
-      df_v = [5., 7.]
-      mu_v = [3., -3.]
-      sigma_v = [np.sqrt(10.), np.sqrt(15.)]
-      n = constant_op.constant(200000)
-      student = student_t.StudentT(df=df, loc=mu, scale=sigma)
-      samples = student.sample(n, seed=123456)
-      sample_values = self.evaluate(samples)
-      self.assertEqual(samples.get_shape(), (200000, batch_size, 2))
-      self.assertAllClose(
-          sample_values[:, 0, 0].mean(), mu_v[0], rtol=0.1, atol=0)
-      self.assertAllClose(
-          sample_values[:, 0, 0].var(),
-          sigma_v[0]**2 * df_v[0] / (df_v[0] - 2),
-          rtol=0.2,
-          atol=0)
-      self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0])
-      self.assertAllClose(
-          sample_values[:, 0, 1].mean(), mu_v[1], rtol=0.1, atol=0)
-      self.assertAllClose(
-          sample_values[:, 0, 1].var(),
-          sigma_v[1]**2 * df_v[1] / (df_v[1] - 2),
-          rtol=0.2,
-          atol=0)
-      self._checkKLApprox(df_v[1], mu_v[1], sigma_v[1], sample_values[:, 0, 1])
+    batch_size = 7
+    df = constant_op.constant([[5., 7.]] * batch_size)
+    mu = constant_op.constant([[3., -3.]] * batch_size)
+    sigma = constant_op.constant(
+        [[math.sqrt(10.), math.sqrt(15.)]] * batch_size)
+    df_v = [5., 7.]
+    mu_v = [3., -3.]
+    sigma_v = [np.sqrt(10.), np.sqrt(15.)]
+    n = constant_op.constant(200000)
+    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+    samples = student.sample(n, seed=123456)
+    sample_values = self.evaluate(samples)
+    self.assertEqual(samples.get_shape(), (200000, batch_size, 2))
+    self.assertAllClose(
+        sample_values[:, 0, 0].mean(), mu_v[0], rtol=0.1, atol=0)
+    self.assertAllClose(
+        sample_values[:, 0, 0].var(),
+        sigma_v[0]**2 * df_v[0] / (df_v[0] - 2),
+        rtol=0.2,
+        atol=0)
+    self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0])
+    self.assertAllClose(
+        sample_values[:, 0, 1].mean(), mu_v[1], rtol=0.1, atol=0)
+    self.assertAllClose(
+        sample_values[:, 0, 1].var(),
+        sigma_v[1]**2 * df_v[1] / (df_v[1] - 2),
+        rtol=0.2,
+        atol=0)
+    self._checkKLApprox(df_v[1], mu_v[1], sigma_v[1], sample_values[:, 0, 1])
 
   def _checkKLApprox(self, df, mu, sigma, samples):
     n = samples.size
@@ -325,114 +312,102 @@ class StudentTTest(test.TestCase):
     _check2d_rows(student_t.StudentT(df=7., loc=3., scale=[[2.], [3.], [4.]]))
 
   def testMeanAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self):
-    with self.test_session():
-      mu = [1., 3.3, 4.4]
-      student = student_t.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.])
-      mean = self.evaluate(student.mean())
-      self.assertAllClose([1., 3.3, 4.4], mean)
+    mu = [1., 3.3, 4.4]
+    student = student_t.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.])
+    mean = self.evaluate(student.mean())
+    self.assertAllClose([1., 3.3, 4.4], mean)
 
   def testMeanAllowNanStatsIsFalseRaisesWhenBatchMemberIsUndefined(self):
-    with self.test_session():
-      mu = [1., 3.3, 4.4]
-      student = student_t.StudentT(
-          df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.],
-          allow_nan_stats=False)
-      with self.assertRaisesOpError("x < y"):
-        self.evaluate(student.mean())
+    mu = [1., 3.3, 4.4]
+    student = student_t.StudentT(
+        df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.], allow_nan_stats=False)
+    with self.assertRaisesOpError("x < y"):
+      self.evaluate(student.mean())
 
   def testMeanAllowNanStatsIsTrueReturnsNaNForUndefinedBatchMembers(self):
-    with self.test_session():
-      mu = [-2, 0., 1., 3.3, 4.4]
-      sigma = [5., 4., 3., 2., 1.]
-      student = student_t.StudentT(
-          df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma,
-          allow_nan_stats=True)
-      mean = self.evaluate(student.mean())
-      self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean)
+    mu = [-2, 0., 1., 3.3, 4.4]
+    sigma = [5., 4., 3., 2., 1.]
+    student = student_t.StudentT(
+        df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma, allow_nan_stats=True)
+    mean = self.evaluate(student.mean())
+    self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean)
 
   def testVarianceAllowNanStatsTrueReturnsNaNforUndefinedBatchMembers(self):
-    with self.test_session():
-      # df = 0.5 ==> undefined mean ==> undefined variance.
-      # df = 1.5 ==> infinite variance.
-      df = [0.5, 1.5, 3., 5., 7.]
-      mu = [-2, 0., 1., 3.3, 4.4]
-      sigma = [5., 4., 3., 2., 1.]
-      student = student_t.StudentT(
-          df=df, loc=mu, scale=sigma, allow_nan_stats=True)
-      var = self.evaluate(student.variance())
-      ## scipy uses inf for variance when the mean is undefined.  When mean is
-      # undefined we say variance is undefined as well.  So test the first
-      # member of var, making sure it is NaN, then replace with inf and compare
-      # to scipy.
-      self.assertTrue(np.isnan(var[0]))
-      var[0] = np.inf
+    # df = 0.5 ==> undefined mean ==> undefined variance.
+    # df = 1.5 ==> infinite variance.
+    df = [0.5, 1.5, 3., 5., 7.]
+    mu = [-2, 0., 1., 3.3, 4.4]
+    sigma = [5., 4., 3., 2., 1.]
+    student = student_t.StudentT(
+        df=df, loc=mu, scale=sigma, allow_nan_stats=True)
+    var = self.evaluate(student.variance())
+    ## scipy uses inf for variance when the mean is undefined.  When mean is
+    # undefined we say variance is undefined as well.  So test the first
+    # member of var, making sure it is NaN, then replace with inf and compare
+    # to scipy.
+    self.assertTrue(np.isnan(var[0]))
+    var[0] = np.inf
 
-      if not stats:
-        return
-      expected_var = [
-          stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
-      ]
-      self.assertAllClose(expected_var, var)
+    if not stats:
+      return
+    expected_var = [
+        stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
+    ]
+    self.assertAllClose(expected_var, var)
 
   def testVarianceAllowNanStatsFalseGivesCorrectValueForDefinedBatchMembers(
       self):
-    with self.test_session():
-      # df = 1.5 ==> infinite variance.
-      df = [1.5, 3., 5., 7.]
-      mu = [0., 1., 3.3, 4.4]
-      sigma = [4., 3., 2., 1.]
-      student = student_t.StudentT(df=df, loc=mu, scale=sigma)
-      var = self.evaluate(student.variance())
+    # df = 1.5 ==> infinite variance.
+    df = [1.5, 3., 5., 7.]
+    mu = [0., 1., 3.3, 4.4]
+    sigma = [4., 3., 2., 1.]
+    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+    var = self.evaluate(student.variance())
 
-      if not stats:
-        return
-      expected_var = [
-          stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
-      ]
-      self.assertAllClose(expected_var, var)
+    if not stats:
+      return
+    expected_var = [
+        stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
+    ]
+    self.assertAllClose(expected_var, var)
 
   def testVarianceAllowNanStatsFalseRaisesForUndefinedBatchMembers(self):
-    with self.test_session():
-      # df <= 1 ==> variance not defined
-      student = student_t.StudentT(
-          df=1., loc=0., scale=1., allow_nan_stats=False)
-      with self.assertRaisesOpError("x < y"):
-        self.evaluate(student.variance())
+    # df <= 1 ==> variance not defined
+    student = student_t.StudentT(df=1., loc=0., scale=1., allow_nan_stats=False)
+    with self.assertRaisesOpError("x < y"):
+      self.evaluate(student.variance())
 
-    with self.test_session():
-      # df <= 1 ==> variance not defined
-      student = student_t.StudentT(
-          df=0.5, loc=0., scale=1., allow_nan_stats=False)
-      with self.assertRaisesOpError("x < y"):
-        self.evaluate(student.variance())
+    # df <= 1 ==> variance not defined
+    student = student_t.StudentT(
+        df=0.5, loc=0., scale=1., allow_nan_stats=False)
+    with self.assertRaisesOpError("x < y"):
+      self.evaluate(student.variance())
 
   def testStd(self):
-    with self.test_session():
-      # Defined for all batch members.
-      df = [3.5, 5., 3., 5., 7.]
-      mu = [-2.2]
-      sigma = [5., 4., 3., 2., 1.]
-      student = student_t.StudentT(df=df, loc=mu, scale=sigma)
-      # Test broadcast of mu across shape of df/sigma
-      stddev = self.evaluate(student.stddev())
-      mu *= len(df)
+    # Defined for all batch members.
+    df = [3.5, 5., 3., 5., 7.]
+    mu = [-2.2]
+    sigma = [5., 4., 3., 2., 1.]
+    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+    # Test broadcast of mu across shape of df/sigma
+    stddev = self.evaluate(student.stddev())
+    mu *= len(df)
 
-      if not stats:
-        return
-      expected_stddev = [
-          stats.t.std(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
-      ]
-      self.assertAllClose(expected_stddev, stddev)
+    if not stats:
+      return
+    expected_stddev = [
+        stats.t.std(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
+    ]
+    self.assertAllClose(expected_stddev, stddev)
 
   def testMode(self):
-    with self.test_session():
-      df = [0.5, 1., 3]
-      mu = [-1, 0., 1]
-      sigma = [5., 4., 3.]
-      student = student_t.StudentT(df=df, loc=mu, scale=sigma)
-      # Test broadcast of mu across shape of df/sigma
-      mode = self.evaluate(student.mode())
-      self.assertAllClose([-1., 0, 1], mode)
+    df = [0.5, 1., 3]
+    mu = [-1, 0., 1]
+    sigma = [5., 4., 3.]
+    student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+    # Test broadcast of mu across shape of df/sigma
+    mode = self.evaluate(student.mode())
+    self.assertAllClose([-1., 0, 1], mode)
 
   def testPdfOfSample(self):
     student = student_t.StudentT(df=3., loc=np.pi, scale=1.)
@@ -510,25 +485,23 @@ class StudentTTest(test.TestCase):
     self.assertNear(1., total, err=err)
 
   def testNegativeDofFails(self):
-    with self.test_session():
-      with self.assertRaisesOpError(r"Condition x > 0 did not hold"):
-        student = student_t.StudentT(
-            df=[2, -5.], loc=0., scale=1., validate_args=True, name="S")
-        self.evaluate(student.mean())
+    with self.assertRaisesOpError(r"Condition x > 0 did not hold"):
+      student = student_t.StudentT(
+          df=[2, -5.], loc=0., scale=1., validate_args=True, name="S")
+      self.evaluate(student.mean())
 
   def testStudentTWithAbsDfSoftplusScale(self):
-    with self.test_session():
-      df = constant_op.constant([-3.2, -4.6])
-      mu = constant_op.constant([-4.2, 3.4])
-      sigma = constant_op.constant([-6.4, -8.8])
-      student = student_t.StudentTWithAbsDfSoftplusScale(
-          df=df, loc=mu, scale=sigma)
-      self.assertAllClose(
-          math_ops.floor(self.evaluate(math_ops.abs(df))),
-          self.evaluate(student.df))
-      self.assertAllClose(self.evaluate(mu), self.evaluate(student.loc))
-      self.assertAllClose(
-          self.evaluate(nn_ops.softplus(sigma)), self.evaluate(student.scale))
+    df = constant_op.constant([-3.2, -4.6])
+    mu = constant_op.constant([-4.2, 3.4])
+    sigma = constant_op.constant([-6.4, -8.8])
+    student = student_t.StudentTWithAbsDfSoftplusScale(
+        df=df, loc=mu, scale=sigma)
+    self.assertAllClose(
+        math_ops.floor(self.evaluate(math_ops.abs(df))),
+        self.evaluate(student.df))
+    self.assertAllClose(self.evaluate(mu), self.evaluate(student.loc))
+    self.assertAllClose(
+        self.evaluate(nn_ops.softplus(sigma)), self.evaluate(student.scale))
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/distributions/uniform_test.py b/tensorflow/python/kernel_tests/distributions/uniform_test.py
index bc9c267b9a5..9cdcd369c17 100644
--- a/tensorflow/python/kernel_tests/distributions/uniform_test.py
+++ b/tensorflow/python/kernel_tests/distributions/uniform_test.py
@@ -50,255 +50,239 @@ class UniformTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformRange(self):
-    with self.test_session():
-      a = 3.0
-      b = 10.0
-      uniform = uniform_lib.Uniform(low=a, high=b)
-      self.assertAllClose(a, self.evaluate(uniform.low))
-      self.assertAllClose(b, self.evaluate(uniform.high))
-      self.assertAllClose(b - a, self.evaluate(uniform.range()))
+    a = 3.0
+    b = 10.0
+    uniform = uniform_lib.Uniform(low=a, high=b)
+    self.assertAllClose(a, self.evaluate(uniform.low))
+    self.assertAllClose(b, self.evaluate(uniform.high))
+    self.assertAllClose(b - a, self.evaluate(uniform.range()))
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformPDF(self):
-    with self.test_session():
-      a = constant_op.constant([-3.0] * 5 + [15.0])
-      b = constant_op.constant([11.0] * 5 + [20.0])
-      uniform = uniform_lib.Uniform(low=a, high=b)
+    a = constant_op.constant([-3.0] * 5 + [15.0])
+    b = constant_op.constant([11.0] * 5 + [20.0])
+    uniform = uniform_lib.Uniform(low=a, high=b)
 
-      a_v = -3.0
-      b_v = 11.0
-      x = np.array([-10.5, 4.0, 0.0, 10.99, 11.3, 17.0], dtype=np.float32)
+    a_v = -3.0
+    b_v = 11.0
+    x = np.array([-10.5, 4.0, 0.0, 10.99, 11.3, 17.0], dtype=np.float32)
 
-      def _expected_pdf():
-        pdf = np.zeros_like(x) + 1.0 / (b_v - a_v)
-        pdf[x > b_v] = 0.0
-        pdf[x < a_v] = 0.0
-        pdf[5] = 1.0 / (20.0 - 15.0)
-        return pdf
+    def _expected_pdf():
+      pdf = np.zeros_like(x) + 1.0 / (b_v - a_v)
+      pdf[x > b_v] = 0.0
+      pdf[x < a_v] = 0.0
+      pdf[5] = 1.0 / (20.0 - 15.0)
+      return pdf
 
-      expected_pdf = _expected_pdf()
+    expected_pdf = _expected_pdf()
 
-      pdf = uniform.prob(x)
-      self.assertAllClose(expected_pdf, self.evaluate(pdf))
+    pdf = uniform.prob(x)
+    self.assertAllClose(expected_pdf, self.evaluate(pdf))
 
-      log_pdf = uniform.log_prob(x)
-      self.assertAllClose(np.log(expected_pdf), self.evaluate(log_pdf))
+    log_pdf = uniform.log_prob(x)
+    self.assertAllClose(np.log(expected_pdf), self.evaluate(log_pdf))
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformShape(self):
-    with self.test_session():
-      a = constant_op.constant([-3.0] * 5)
-      b = constant_op.constant(11.0)
-      uniform = uniform_lib.Uniform(low=a, high=b)
+    a = constant_op.constant([-3.0] * 5)
+    b = constant_op.constant(11.0)
+    uniform = uniform_lib.Uniform(low=a, high=b)
 
-      self.assertEqual(self.evaluate(uniform.batch_shape_tensor()), (5,))
-      self.assertEqual(uniform.batch_shape, tensor_shape.TensorShape([5]))
-      self.assertAllEqual(self.evaluate(uniform.event_shape_tensor()), [])
-      self.assertEqual(uniform.event_shape, tensor_shape.TensorShape([]))
+    self.assertEqual(self.evaluate(uniform.batch_shape_tensor()), (5,))
+    self.assertEqual(uniform.batch_shape, tensor_shape.TensorShape([5]))
+    self.assertAllEqual(self.evaluate(uniform.event_shape_tensor()), [])
+    self.assertEqual(uniform.event_shape, tensor_shape.TensorShape([]))
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformPDFWithScalarEndpoint(self):
-    with self.test_session():
-      a = constant_op.constant([0.0, 5.0])
-      b = constant_op.constant(10.0)
-      uniform = uniform_lib.Uniform(low=a, high=b)
+    a = constant_op.constant([0.0, 5.0])
+    b = constant_op.constant(10.0)
+    uniform = uniform_lib.Uniform(low=a, high=b)
 
-      x = np.array([0.0, 8.0], dtype=np.float32)
-      expected_pdf = np.array([1.0 / (10.0 - 0.0), 1.0 / (10.0 - 5.0)])
+    x = np.array([0.0, 8.0], dtype=np.float32)
+    expected_pdf = np.array([1.0 / (10.0 - 0.0), 1.0 / (10.0 - 5.0)])
 
-      pdf = uniform.prob(x)
-      self.assertAllClose(expected_pdf, self.evaluate(pdf))
+    pdf = uniform.prob(x)
+    self.assertAllClose(expected_pdf, self.evaluate(pdf))
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformCDF(self):
-    with self.test_session():
-      batch_size = 6
-      a = constant_op.constant([1.0] * batch_size)
-      b = constant_op.constant([11.0] * batch_size)
-      a_v = 1.0
-      b_v = 11.0
-      x = np.array([-2.5, 2.5, 4.0, 0.0, 10.99, 12.0], dtype=np.float32)
+    batch_size = 6
+    a = constant_op.constant([1.0] * batch_size)
+    b = constant_op.constant([11.0] * batch_size)
+    a_v = 1.0
+    b_v = 11.0
+    x = np.array([-2.5, 2.5, 4.0, 0.0, 10.99, 12.0], dtype=np.float32)
 
-      uniform = uniform_lib.Uniform(low=a, high=b)
+    uniform = uniform_lib.Uniform(low=a, high=b)
 
-      def _expected_cdf():
-        cdf = (x - a_v) / (b_v - a_v)
-        cdf[x >= b_v] = 1
-        cdf[x < a_v] = 0
-        return cdf
+    def _expected_cdf():
+      cdf = (x - a_v) / (b_v - a_v)
+      cdf[x >= b_v] = 1
+      cdf[x < a_v] = 0
+      return cdf
 
-      cdf = uniform.cdf(x)
-      self.assertAllClose(_expected_cdf(), self.evaluate(cdf))
+    cdf = uniform.cdf(x)
+    self.assertAllClose(_expected_cdf(), self.evaluate(cdf))
 
-      log_cdf = uniform.log_cdf(x)
-      self.assertAllClose(np.log(_expected_cdf()), self.evaluate(log_cdf))
+    log_cdf = uniform.log_cdf(x)
+    self.assertAllClose(np.log(_expected_cdf()), self.evaluate(log_cdf))
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformEntropy(self):
-    with self.test_session():
-      a_v = np.array([1.0, 1.0, 1.0])
-      b_v = np.array([[1.5, 2.0, 3.0]])
-      uniform = uniform_lib.Uniform(low=a_v, high=b_v)
+    a_v = np.array([1.0, 1.0, 1.0])
+    b_v = np.array([[1.5, 2.0, 3.0]])
+    uniform = uniform_lib.Uniform(low=a_v, high=b_v)
 
-      expected_entropy = np.log(b_v - a_v)
-      self.assertAllClose(expected_entropy, self.evaluate(uniform.entropy()))
+    expected_entropy = np.log(b_v - a_v)
+    self.assertAllClose(expected_entropy, self.evaluate(uniform.entropy()))
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformAssertMaxGtMin(self):
-    with self.test_session():
-      a_v = np.array([1.0, 1.0, 1.0], dtype=np.float32)
-      b_v = np.array([1.0, 2.0, 3.0], dtype=np.float32)
+    a_v = np.array([1.0, 1.0, 1.0], dtype=np.float32)
+    b_v = np.array([1.0, 2.0, 3.0], dtype=np.float32)
 
-      with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
-                                               "x < y"):
-        uniform = uniform_lib.Uniform(low=a_v, high=b_v, validate_args=True)
-        self.evaluate(uniform.low)
+    with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+                                             "x < y"):
+      uniform = uniform_lib.Uniform(low=a_v, high=b_v, validate_args=True)
+      self.evaluate(uniform.low)
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformSample(self):
-    with self.test_session():
-      a = constant_op.constant([3.0, 4.0])
-      b = constant_op.constant(13.0)
-      a1_v = 3.0
-      a2_v = 4.0
-      b_v = 13.0
-      n = constant_op.constant(100000)
-      uniform = uniform_lib.Uniform(low=a, high=b)
+    a = constant_op.constant([3.0, 4.0])
+    b = constant_op.constant(13.0)
+    a1_v = 3.0
+    a2_v = 4.0
+    b_v = 13.0
+    n = constant_op.constant(100000)
+    uniform = uniform_lib.Uniform(low=a, high=b)
 
-      samples = uniform.sample(n, seed=137)
-      sample_values = self.evaluate(samples)
-      self.assertEqual(sample_values.shape, (100000, 2))
-      self.assertAllClose(
-          sample_values[::, 0].mean(), (b_v + a1_v) / 2, atol=1e-1, rtol=0.)
-      self.assertAllClose(
-          sample_values[::, 1].mean(), (b_v + a2_v) / 2, atol=1e-1, rtol=0.)
-      self.assertFalse(
-          np.any(sample_values[::, 0] < a1_v) or np.any(sample_values >= b_v))
-      self.assertFalse(
-          np.any(sample_values[::, 1] < a2_v) or np.any(sample_values >= b_v))
+    samples = uniform.sample(n, seed=137)
+    sample_values = self.evaluate(samples)
+    self.assertEqual(sample_values.shape, (100000, 2))
+    self.assertAllClose(
+        sample_values[::, 0].mean(), (b_v + a1_v) / 2, atol=1e-1, rtol=0.)
+    self.assertAllClose(
+        sample_values[::, 1].mean(), (b_v + a2_v) / 2, atol=1e-1, rtol=0.)
+    self.assertFalse(
+        np.any(sample_values[::, 0] < a1_v) or np.any(sample_values >= b_v))
+    self.assertFalse(
+        np.any(sample_values[::, 1] < a2_v) or np.any(sample_values >= b_v))
 
   @test_util.run_in_graph_and_eager_modes
   def _testUniformSampleMultiDimensional(self):
     # DISABLED: Please enable this test once b/issues/30149644 is resolved.
-    with self.test_session():
-      batch_size = 2
-      a_v = [3.0, 22.0]
-      b_v = [13.0, 35.0]
-      a = constant_op.constant([a_v] * batch_size)
-      b = constant_op.constant([b_v] * batch_size)
+    batch_size = 2
+    a_v = [3.0, 22.0]
+    b_v = [13.0, 35.0]
+    a = constant_op.constant([a_v] * batch_size)
+    b = constant_op.constant([b_v] * batch_size)
 
-      uniform = uniform_lib.Uniform(low=a, high=b)
+    uniform = uniform_lib.Uniform(low=a, high=b)
 
-      n_v = 100000
-      n = constant_op.constant(n_v)
-      samples = uniform.sample(n)
-      self.assertEqual(samples.get_shape(), (n_v, batch_size, 2))
+    n_v = 100000
+    n = constant_op.constant(n_v)
+    samples = uniform.sample(n)
+    self.assertEqual(samples.get_shape(), (n_v, batch_size, 2))
 
-      sample_values = self.evaluate(samples)
+    sample_values = self.evaluate(samples)
 
-      self.assertFalse(
-          np.any(sample_values[:, 0, 0] < a_v[0]) or
-          np.any(sample_values[:, 0, 0] >= b_v[0]))
-      self.assertFalse(
-          np.any(sample_values[:, 0, 1] < a_v[1]) or
-          np.any(sample_values[:, 0, 1] >= b_v[1]))
+    self.assertFalse(
+        np.any(sample_values[:, 0, 0] < a_v[0]) or
+        np.any(sample_values[:, 0, 0] >= b_v[0]))
+    self.assertFalse(
+        np.any(sample_values[:, 0, 1] < a_v[1]) or
+        np.any(sample_values[:, 0, 1] >= b_v[1]))
 
-      self.assertAllClose(
-          sample_values[:, 0, 0].mean(), (a_v[0] + b_v[0]) / 2, atol=1e-2)
-      self.assertAllClose(
-          sample_values[:, 0, 1].mean(), (a_v[1] + b_v[1]) / 2, atol=1e-2)
+    self.assertAllClose(
+        sample_values[:, 0, 0].mean(), (a_v[0] + b_v[0]) / 2, atol=1e-2)
+    self.assertAllClose(
+        sample_values[:, 0, 1].mean(), (a_v[1] + b_v[1]) / 2, atol=1e-2)
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformMean(self):
-    with self.test_session():
-      a = 10.0
-      b = 100.0
-      uniform = uniform_lib.Uniform(low=a, high=b)
-      if not stats:
-        return
-      s_uniform = stats.uniform(loc=a, scale=b - a)
-      self.assertAllClose(self.evaluate(uniform.mean()), s_uniform.mean())
+    a = 10.0
+    b = 100.0
+    uniform = uniform_lib.Uniform(low=a, high=b)
+    if not stats:
+      return
+    s_uniform = stats.uniform(loc=a, scale=b - a)
+    self.assertAllClose(self.evaluate(uniform.mean()), s_uniform.mean())
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformVariance(self):
-    with self.test_session():
-      a = 10.0
-      b = 100.0
-      uniform = uniform_lib.Uniform(low=a, high=b)
-      if not stats:
-        return
-      s_uniform = stats.uniform(loc=a, scale=b - a)
-      self.assertAllClose(self.evaluate(uniform.variance()), s_uniform.var())
+    a = 10.0
+    b = 100.0
+    uniform = uniform_lib.Uniform(low=a, high=b)
+    if not stats:
+      return
+    s_uniform = stats.uniform(loc=a, scale=b - a)
+    self.assertAllClose(self.evaluate(uniform.variance()), s_uniform.var())
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformStd(self):
-    with self.test_session():
-      a = 10.0
-      b = 100.0
-      uniform = uniform_lib.Uniform(low=a, high=b)
-      if not stats:
-        return
-      s_uniform = stats.uniform(loc=a, scale=b - a)
-      self.assertAllClose(self.evaluate(uniform.stddev()), s_uniform.std())
+    a = 10.0
+    b = 100.0
+    uniform = uniform_lib.Uniform(low=a, high=b)
+    if not stats:
+      return
+    s_uniform = stats.uniform(loc=a, scale=b - a)
+    self.assertAllClose(self.evaluate(uniform.stddev()), s_uniform.std())
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformNans(self):
-    with self.test_session():
-      a = 10.0
-      b = [11.0, 100.0]
-      uniform = uniform_lib.Uniform(low=a, high=b)
+    a = 10.0
+    b = [11.0, 100.0]
+    uniform = uniform_lib.Uniform(low=a, high=b)
 
-      no_nans = constant_op.constant(1.0)
-      nans = constant_op.constant(0.0) / constant_op.constant(0.0)
-      self.assertTrue(self.evaluate(math_ops.is_nan(nans)))
-      with_nans = array_ops.stack([no_nans, nans])
+    no_nans = constant_op.constant(1.0)
+    nans = constant_op.constant(0.0) / constant_op.constant(0.0)
+    self.assertTrue(self.evaluate(math_ops.is_nan(nans)))
+    with_nans = array_ops.stack([no_nans, nans])
 
-      pdf = uniform.prob(with_nans)
+    pdf = uniform.prob(with_nans)
 
-      is_nan = self.evaluate(math_ops.is_nan(pdf))
-      self.assertFalse(is_nan[0])
-      self.assertTrue(is_nan[1])
+    is_nan = self.evaluate(math_ops.is_nan(pdf))
+    self.assertFalse(is_nan[0])
+    self.assertTrue(is_nan[1])
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformSamplePdf(self):
-    with self.test_session():
-      a = 10.0
-      b = [11.0, 100.0]
-      uniform = uniform_lib.Uniform(a, b)
-      self.assertTrue(
-          self.evaluate(
-              math_ops.reduce_all(uniform.prob(uniform.sample(10)) > 0)))
+    a = 10.0
+    b = [11.0, 100.0]
+    uniform = uniform_lib.Uniform(a, b)
+    self.assertTrue(
+        self.evaluate(
+            math_ops.reduce_all(uniform.prob(uniform.sample(10)) > 0)))
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformBroadcasting(self):
-    with self.test_session():
-      a = 10.0
-      b = [11.0, 20.0]
-      uniform = uniform_lib.Uniform(a, b)
+    a = 10.0
+    b = [11.0, 20.0]
+    uniform = uniform_lib.Uniform(a, b)
 
-      pdf = uniform.prob([[10.5, 11.5], [9.0, 19.0], [10.5, 21.0]])
-      expected_pdf = np.array([[1.0, 0.1], [0.0, 0.1], [1.0, 0.0]])
-      self.assertAllClose(expected_pdf, self.evaluate(pdf))
+    pdf = uniform.prob([[10.5, 11.5], [9.0, 19.0], [10.5, 21.0]])
+    expected_pdf = np.array([[1.0, 0.1], [0.0, 0.1], [1.0, 0.0]])
+    self.assertAllClose(expected_pdf, self.evaluate(pdf))
 
   @test_util.run_in_graph_and_eager_modes
   def testUniformSampleWithShape(self):
-    with self.test_session():
-      a = 10.0
-      b = [11.0, 20.0]
-      uniform = uniform_lib.Uniform(a, b)
+    a = 10.0
+    b = [11.0, 20.0]
+    uniform = uniform_lib.Uniform(a, b)
 
-      pdf = uniform.prob(uniform.sample((2, 3)))
-      # pylint: disable=bad-continuation
-      expected_pdf = [
-          [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]],
-          [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]],
-      ]
-      # pylint: enable=bad-continuation
-      self.assertAllClose(expected_pdf, self.evaluate(pdf))
+    pdf = uniform.prob(uniform.sample((2, 3)))
+    # pylint: disable=bad-continuation
+    expected_pdf = [
+        [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]],
+        [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]],
+    ]
+    # pylint: enable=bad-continuation
+    self.assertAllClose(expected_pdf, self.evaluate(pdf))
 
-      pdf = uniform.prob(uniform.sample())
-      expected_pdf = [1.0, 0.1]
-      self.assertAllClose(expected_pdf, self.evaluate(pdf))
+    pdf = uniform.prob(uniform.sample())
+    expected_pdf = [1.0, 0.1]
+    self.assertAllClose(expected_pdf, self.evaluate(pdf))
 
   def testFullyReparameterized(self):
     a = constant_op.constant(0.1)
diff --git a/tensorflow/python/kernel_tests/distributions/util_test.py b/tensorflow/python/kernel_tests/distributions/util_test.py
index 61faa8466ed..27d652c2c62 100644
--- a/tensorflow/python/kernel_tests/distributions/util_test.py
+++ b/tensorflow/python/kernel_tests/distributions/util_test.py
@@ -69,7 +69,7 @@ class AssertCloseTest(test.TestCase):
     w = array_ops.placeholder(dtypes.float32)
     feed_dict = {x: [1., 5, 10, 15, 20], y: [1.1, 5, 10, 15, 20],
                  z: [1.0001, 5, 10, 15, 20], w: [1e-8, 5, 10, 15, 20]}
-    with self.test_session():
+    with self.cached_session():
       with ops.control_dependencies([du.assert_integer_form(x)]):
         array_ops.identity(x).eval(feed_dict=feed_dict)
 
@@ -122,58 +122,52 @@ class GetLogitsAndProbsTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testImproperArguments(self):
-    with self.test_session():
-      with self.assertRaises(ValueError):
-        du.get_logits_and_probs(logits=None, probs=None)
+    with self.assertRaises(ValueError):
+      du.get_logits_and_probs(logits=None, probs=None)
 
-      with self.assertRaises(ValueError):
-        du.get_logits_and_probs(logits=[0.1], probs=[0.1])
+    with self.assertRaises(ValueError):
+      du.get_logits_and_probs(logits=[0.1], probs=[0.1])
 
   @test_util.run_in_graph_and_eager_modes
   def testLogits(self):
     p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32)
     logits = _logit(p)
 
-    with self.test_session():
-      new_logits, new_p = du.get_logits_and_probs(
-          logits=logits, validate_args=True)
+    new_logits, new_p = du.get_logits_and_probs(
+        logits=logits, validate_args=True)
 
-      self.assertAllClose(p, self.evaluate(new_p), rtol=1e-5, atol=0.)
-      self.assertAllClose(logits, self.evaluate(new_logits), rtol=1e-5, atol=0.)
+    self.assertAllClose(p, self.evaluate(new_p), rtol=1e-5, atol=0.)
+    self.assertAllClose(logits, self.evaluate(new_logits), rtol=1e-5, atol=0.)
 
   @test_util.run_in_graph_and_eager_modes
   def testLogitsMultidimensional(self):
     p = np.array([0.2, 0.3, 0.5], dtype=np.float32)
     logits = np.log(p)
 
-    with self.test_session():
-      new_logits, new_p = du.get_logits_and_probs(
-          logits=logits, multidimensional=True, validate_args=True)
+    new_logits, new_p = du.get_logits_and_probs(
+        logits=logits, multidimensional=True, validate_args=True)
 
-      self.assertAllClose(self.evaluate(new_p), p)
-      self.assertAllClose(self.evaluate(new_logits), logits)
+    self.assertAllClose(self.evaluate(new_p), p)
+    self.assertAllClose(self.evaluate(new_logits), logits)
 
   @test_util.run_in_graph_and_eager_modes
   def testProbability(self):
     p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32)
 
-    with self.test_session():
-      new_logits, new_p = du.get_logits_and_probs(
-          probs=p, validate_args=True)
+    new_logits, new_p = du.get_logits_and_probs(probs=p, validate_args=True)
 
-      self.assertAllClose(_logit(p), self.evaluate(new_logits))
-      self.assertAllClose(p, self.evaluate(new_p))
+    self.assertAllClose(_logit(p), self.evaluate(new_logits))
+    self.assertAllClose(p, self.evaluate(new_p))
 
   @test_util.run_in_graph_and_eager_modes
   def testProbabilityMultidimensional(self):
     p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32)
 
-    with self.test_session():
-      new_logits, new_p = du.get_logits_and_probs(
-          probs=p, multidimensional=True, validate_args=True)
+    new_logits, new_p = du.get_logits_and_probs(
+        probs=p, multidimensional=True, validate_args=True)
 
-      self.assertAllClose(np.log(p), self.evaluate(new_logits))
-      self.assertAllClose(p, self.evaluate(new_p))
+    self.assertAllClose(np.log(p), self.evaluate(new_logits))
+    self.assertAllClose(p, self.evaluate(new_p))
 
   @test_util.run_in_graph_and_eager_modes
   def testProbabilityValidateArgs(self):
@@ -183,28 +177,22 @@ class GetLogitsAndProbsTest(test.TestCase):
     # Component greater than 1.
     p3 = [2, 0.2, 0.5, 0.3, .2]
 
-    with self.test_session():
-      _, prob = du.get_logits_and_probs(
-          probs=p, validate_args=True)
+    _, prob = du.get_logits_and_probs(probs=p, validate_args=True)
+    self.evaluate(prob)
+
+    with self.assertRaisesOpError("Condition x >= 0"):
+      _, prob = du.get_logits_and_probs(probs=p2, validate_args=True)
       self.evaluate(prob)
 
-      with self.assertRaisesOpError("Condition x >= 0"):
-        _, prob = du.get_logits_and_probs(
-            probs=p2, validate_args=True)
-        self.evaluate(prob)
+    _, prob = du.get_logits_and_probs(probs=p2, validate_args=False)
+    self.evaluate(prob)
 
-      _, prob = du.get_logits_and_probs(
-          probs=p2, validate_args=False)
+    with self.assertRaisesOpError("probs has components greater than 1"):
+      _, prob = du.get_logits_and_probs(probs=p3, validate_args=True)
       self.evaluate(prob)
 
-      with self.assertRaisesOpError("probs has components greater than 1"):
-        _, prob = du.get_logits_and_probs(
-            probs=p3, validate_args=True)
-        self.evaluate(prob)
-
-      _, prob = du.get_logits_and_probs(
-          probs=p3, validate_args=False)
-      self.evaluate(prob)
+    _, prob = du.get_logits_and_probs(probs=p3, validate_args=False)
+    self.evaluate(prob)
 
   @test_util.run_in_graph_and_eager_modes
   def testProbabilityValidateArgsMultidimensional(self):
@@ -216,41 +204,39 @@ class GetLogitsAndProbsTest(test.TestCase):
     # Does not sum to 1.
     p4 = np.array([[1.1, 0.3, 0.4], [0.1, 0.5, 0.4]], dtype=np.float32)
 
-    with self.test_session():
+    _, prob = du.get_logits_and_probs(probs=p, multidimensional=True)
+    self.evaluate(prob)
+
+    with self.assertRaisesOpError("Condition x >= 0"):
       _, prob = du.get_logits_and_probs(
-          probs=p, multidimensional=True)
+          probs=p2, multidimensional=True, validate_args=True)
       self.evaluate(prob)
 
-      with self.assertRaisesOpError("Condition x >= 0"):
-        _, prob = du.get_logits_and_probs(
-            probs=p2, multidimensional=True, validate_args=True)
-        self.evaluate(prob)
+    _, prob = du.get_logits_and_probs(
+        probs=p2, multidimensional=True, validate_args=False)
+    self.evaluate(prob)
 
+    with self.assertRaisesOpError(
+        "(probs has components greater than 1|probs does not sum to 1)"):
       _, prob = du.get_logits_and_probs(
-          probs=p2, multidimensional=True, validate_args=False)
+          probs=p3, multidimensional=True, validate_args=True)
       self.evaluate(prob)
 
-      with self.assertRaisesOpError(
-          "(probs has components greater than 1|probs does not sum to 1)"):
-        _, prob = du.get_logits_and_probs(
-            probs=p3, multidimensional=True, validate_args=True)
-        self.evaluate(prob)
+    _, prob = du.get_logits_and_probs(
+        probs=p3, multidimensional=True, validate_args=False)
+    self.evaluate(prob)
 
+    with self.assertRaisesOpError("probs does not sum to 1"):
       _, prob = du.get_logits_and_probs(
-          probs=p3, multidimensional=True, validate_args=False)
+          probs=p4, multidimensional=True, validate_args=True)
       self.evaluate(prob)
 
-      with self.assertRaisesOpError("probs does not sum to 1"):
-        _, prob = du.get_logits_and_probs(
-            probs=p4, multidimensional=True, validate_args=True)
-        self.evaluate(prob)
-
-      _, prob = du.get_logits_and_probs(
-          probs=p4, multidimensional=True, validate_args=False)
-      self.evaluate(prob)
+    _, prob = du.get_logits_and_probs(
+        probs=p4, multidimensional=True, validate_args=False)
+    self.evaluate(prob)
 
   def testProbsMultidimShape(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         p = array_ops.ones([int(2**11+1)], dtype=np.float16)
         du.get_logits_and_probs(
@@ -264,7 +250,7 @@ class GetLogitsAndProbsTest(test.TestCase):
         prob.eval(feed_dict={p: np.ones([int(2**11+1)])})
 
   def testLogitsMultidimShape(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         l = array_ops.ones([int(2**11+1)], dtype=np.float16)
         du.get_logits_and_probs(
@@ -281,7 +267,7 @@ class GetLogitsAndProbsTest(test.TestCase):
 class EmbedCheckCategoricalEventShapeTest(test.TestCase):
 
   def testTooSmall(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         param = array_ops.ones([1], dtype=np.float16)
         checked_param = du.embed_check_categorical_event_shape(
@@ -295,7 +281,7 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase):
         checked_param.eval(feed_dict={param: np.ones([1])})
 
   def testTooLarge(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaises(ValueError):
         param = array_ops.ones([int(2**11+1)], dtype=dtypes.float16)
         checked_param = du.embed_check_categorical_event_shape(
@@ -310,18 +296,17 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testUnsupportedDtype(self):
-    with self.test_session():
-      param = ops.convert_to_tensor(
-          np.ones([2**11 + 1]).astype(dtypes.qint16.as_numpy_dtype),
-          dtype=dtypes.qint16)
-      with self.assertRaises(TypeError):
-        du.embed_check_categorical_event_shape(param)
+    param = ops.convert_to_tensor(
+        np.ones([2**11 + 1]).astype(dtypes.qint16.as_numpy_dtype),
+        dtype=dtypes.qint16)
+    with self.assertRaises(TypeError):
+      du.embed_check_categorical_event_shape(param)
 
 
 class EmbedCheckIntegerCastingClosedTest(test.TestCase):
 
   def testCorrectlyAssertsNonnegative(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesOpError("Elements must be non-negative"):
         x = array_ops.placeholder(dtype=dtypes.float16)
         x_checked = du.embed_check_integer_casting_closed(
@@ -329,7 +314,7 @@ class EmbedCheckIntegerCastingClosedTest(test.TestCase):
         x_checked.eval(feed_dict={x: np.array([1, -1], dtype=np.float16)})
 
   def testCorrectlyAssersIntegerForm(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesOpError("Elements must be int16-equivalent."):
         x = array_ops.placeholder(dtype=dtypes.float16)
         x_checked = du.embed_check_integer_casting_closed(
@@ -337,7 +322,7 @@ class EmbedCheckIntegerCastingClosedTest(test.TestCase):
         x_checked.eval(feed_dict={x: np.array([1, 1.5], dtype=np.float16)})
 
   def testCorrectlyAssertsLargestPossibleInteger(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesOpError("Elements cannot exceed 32767."):
         x = array_ops.placeholder(dtype=dtypes.int32)
         x_checked = du.embed_check_integer_casting_closed(
@@ -345,7 +330,7 @@ class EmbedCheckIntegerCastingClosedTest(test.TestCase):
         x_checked.eval(feed_dict={x: np.array([1, 2**15], dtype=np.int32)})
 
   def testCorrectlyAssertsSmallestPossibleInteger(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesOpError("Elements cannot be smaller than 0."):
         x = array_ops.placeholder(dtype=dtypes.int32)
         x_checked = du.embed_check_integer_casting_closed(
@@ -365,29 +350,27 @@ class LogCombinationsTest(test.TestCase):
 
     log_combs = np.log(special.binom(n, k))
 
-    with self.test_session():
-      n = np.array(n, dtype=np.float32)
-      counts = [[1., 1], [2., 3], [4., 8], [11, 4]]
-      log_binom = du.log_combinations(n, counts)
-      self.assertEqual([4], log_binom.get_shape())
-      self.assertAllClose(log_combs, self.evaluate(log_binom))
+    n = np.array(n, dtype=np.float32)
+    counts = [[1., 1], [2., 3], [4., 8], [11, 4]]
+    log_binom = du.log_combinations(n, counts)
+    self.assertEqual([4], log_binom.get_shape())
+    self.assertAllClose(log_combs, self.evaluate(log_binom))
 
   def testLogCombinationsShape(self):
     # Shape [2, 2]
     n = [[2, 5], [12, 15]]
 
-    with self.test_session():
-      n = np.array(n, dtype=np.float32)
-      # Shape [2, 2, 4]
-      counts = [[[1., 1, 0, 0], [2., 2, 1, 0]], [[4., 4, 1, 3], [10, 1, 1, 4]]]
-      log_binom = du.log_combinations(n, counts)
-      self.assertEqual([2, 2], log_binom.get_shape())
+    n = np.array(n, dtype=np.float32)
+    # Shape [2, 2, 4]
+    counts = [[[1., 1, 0, 0], [2., 2, 1, 0]], [[4., 4, 1, 3], [10, 1, 1, 4]]]
+    log_binom = du.log_combinations(n, counts)
+    self.assertEqual([2, 2], log_binom.get_shape())
 
 
 class DynamicShapeTest(test.TestCase):
 
   def testSameDynamicShape(self):
-    with self.test_session():
+    with self.cached_session():
       scalar = constant_op.constant(2.0)
       scalar1 = array_ops.placeholder(dtype=dtypes.float32)
 
@@ -497,22 +480,21 @@ class RotateTransposeTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes
   def testRollStatic(self):
-    with self.test_session():
-      if context.executing_eagerly():
-        error_message = r"Attempt to convert a value \(None\)"
-      else:
-        error_message = "None values not supported."
-      with self.assertRaisesRegexp(ValueError, error_message):
-        du.rotate_transpose(None, 1)
-      for x in (np.ones(1), np.ones((2, 1)), np.ones((3, 2, 1))):
-        for shift in np.arange(-5, 5):
-          y = du.rotate_transpose(x, shift)
-          self.assertAllEqual(
-              self._np_rotate_transpose(x, shift), self.evaluate(y))
-          self.assertAllEqual(np.roll(x.shape, shift), y.get_shape().as_list())
+    if context.executing_eagerly():
+      error_message = r"Attempt to convert a value \(None\)"
+    else:
+      error_message = "None values not supported."
+    with self.assertRaisesRegexp(ValueError, error_message):
+      du.rotate_transpose(None, 1)
+    for x in (np.ones(1), np.ones((2, 1)), np.ones((3, 2, 1))):
+      for shift in np.arange(-5, 5):
+        y = du.rotate_transpose(x, shift)
+        self.assertAllEqual(
+            self._np_rotate_transpose(x, shift), self.evaluate(y))
+        self.assertAllEqual(np.roll(x.shape, shift), y.get_shape().as_list())
 
   def testRollDynamic(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = array_ops.placeholder(dtypes.float32)
       shift = array_ops.placeholder(dtypes.int32)
       for x_value in (np.ones(
@@ -530,7 +512,7 @@ class RotateTransposeTest(test.TestCase):
 class PickVectorTest(test.TestCase):
 
   def testCorrectlyPicksVector(self):
-    with self.test_session():
+    with self.cached_session():
       x = np.arange(10, 12)
       y = np.arange(15, 18)
       self.assertAllEqual(
@@ -568,19 +550,19 @@ class PreferStaticRankTest(test.TestCase):
   def testDynamicRankEndsUpBeingNonEmpty(self):
     x = array_ops.placeholder(np.float64, shape=None)
     rank = du.prefer_static_rank(x)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(2, rank.eval(feed_dict={x: np.zeros((2, 3))}))
 
   def testDynamicRankEndsUpBeingEmpty(self):
     x = array_ops.placeholder(np.int32, shape=None)
     rank = du.prefer_static_rank(x)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(1, rank.eval(feed_dict={x: []}))
 
   def testDynamicRankEndsUpBeingScalar(self):
     x = array_ops.placeholder(np.int32, shape=None)
     rank = du.prefer_static_rank(x)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(0, rank.eval(feed_dict={x: 1}))
 
 
@@ -607,19 +589,19 @@ class PreferStaticShapeTest(test.TestCase):
   def testDynamicShapeEndsUpBeingNonEmpty(self):
     x = array_ops.placeholder(np.float64, shape=None)
     shape = du.prefer_static_shape(x)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual((2, 3), shape.eval(feed_dict={x: np.zeros((2, 3))}))
 
   def testDynamicShapeEndsUpBeingEmpty(self):
     x = array_ops.placeholder(np.int32, shape=None)
     shape = du.prefer_static_shape(x)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(np.array([0]), shape.eval(feed_dict={x: []}))
 
   def testDynamicShapeEndsUpBeingScalar(self):
     x = array_ops.placeholder(np.int32, shape=None)
     shape = du.prefer_static_shape(x)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(np.array([]), shape.eval(feed_dict={x: 1}))
 
 
@@ -646,20 +628,20 @@ class PreferStaticValueTest(test.TestCase):
   def testDynamicValueEndsUpBeingNonEmpty(self):
     x = array_ops.placeholder(np.float64, shape=None)
     value = du.prefer_static_value(x)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(np.zeros((2, 3)),
                           value.eval(feed_dict={x: np.zeros((2, 3))}))
 
   def testDynamicValueEndsUpBeingEmpty(self):
     x = array_ops.placeholder(np.int32, shape=None)
     value = du.prefer_static_value(x)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(np.array([]), value.eval(feed_dict={x: []}))
 
   def testDynamicValueEndsUpBeingScalar(self):
     x = array_ops.placeholder(np.int32, shape=None)
     value = du.prefer_static_value(x)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllEqual(np.array(1), value.eval(feed_dict={x: 1}))
 
 
@@ -691,7 +673,7 @@ class FillTriangularTest(test.TestCase):
 
   def _run_test(self, x_, use_deferred_shape=False, **kwargs):
     x_ = np.asarray(x_)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       static_shape = None if use_deferred_shape else x_.shape
       x_pl = array_ops.placeholder_with_default(x_, shape=static_shape)
       # Add `zeros_like(x)` such that x's value and gradient are identical. We
@@ -761,7 +743,7 @@ class FillTriangularInverseTest(FillTriangularTest):
 
   def _run_test(self, x_, use_deferred_shape=False, **kwargs):
     x_ = np.asarray(x_)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       static_shape = None if use_deferred_shape else x_.shape
       x_pl = array_ops.placeholder_with_default(x_, shape=static_shape)
       zeros_like_x_pl = (x_pl * array_ops.stop_gradient(x_pl - 1.)
@@ -795,7 +777,7 @@ class ReduceWeightedLogSumExp(test.TestCase):
     logx_ = np.array([[0., -1, 1000.],
                       [0, 1, -1000.],
                       [-5, 0, 5]])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       logx = constant_op.constant(logx_)
       expected = math_ops.reduce_logsumexp(logx, axis=-1)
       grad_expected = gradients_impl.gradients(expected, logx)[0]
@@ -818,7 +800,7 @@ class ReduceWeightedLogSumExp(test.TestCase):
                    [1, -2, 1],
                    [1, 0, 1]])
     expected, _ = self._reduce_weighted_logsumexp(logx_, w_, axis=-1)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       logx = constant_op.constant(logx_)
       w = constant_op.constant(w_)
       actual, actual_sgn = du.reduce_weighted_logsumexp(
@@ -836,7 +818,7 @@ class ReduceWeightedLogSumExp(test.TestCase):
                    [1, 0, 1]])
     expected, _ = self._reduce_weighted_logsumexp(
         logx_, w_, axis=-1, keep_dims=True)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       logx = constant_op.constant(logx_)
       w = constant_op.constant(w_)
       actual, actual_sgn = du.reduce_weighted_logsumexp(
@@ -848,7 +830,7 @@ class ReduceWeightedLogSumExp(test.TestCase):
   def testDocString(self):
     """This test verifies the correctness of the docstring examples."""
 
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant([[0., 0, 0],
                                 [0, 0, 0]])
 
@@ -952,7 +934,7 @@ class SoftplusTest(test.TestCase):
           use_gpu=True)
 
   def testGradient(self):
-    with self.test_session():
+    with self.cached_session():
       x = constant_op.constant(
           [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
           shape=[2, 5],
@@ -968,7 +950,7 @@ class SoftplusTest(test.TestCase):
     self.assertLess(err, 1e-4)
 
   def testInverseSoftplusGradientNeverNan(self):
-    with self.test_session():
+    with self.cached_session():
       # Note that this range contains both zero and inf.
       x = constant_op.constant(np.logspace(-8, 6).astype(np.float16))
       y = du.softplus_inverse(x)
@@ -977,7 +959,7 @@ class SoftplusTest(test.TestCase):
       self.assertAllEqual(np.zeros_like(grads).astype(np.bool), np.isnan(grads))
 
   def testInverseSoftplusGradientFinite(self):
-    with self.test_session():
+    with self.cached_session():
       # This range of x is all finite, and so is 1 / x.  So the
       # gradient and its approximations should be finite as well.
       x = constant_op.constant(np.logspace(-4.8, 4.5).astype(np.float16))