From 28eeda839f124cf5ba648576e86214b38141e4ab Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Mon, 24 Sep 2018 11:28:07 -0700
Subject: [PATCH] Move from deprecated self.test_session() to
 self.cached_session().

self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about:
* the fact that the session may be reused.
* the session is not closed even when doing a "with self.test_session()" statement.

PiperOrigin-RevId: 214300210
---
 .../integration_tests/errors_test.py          |  4 +-
 .../python/autograph/core/errors_test.py      |  6 +--
 tensorflow/python/autograph/impl/api_test.py  | 26 ++++-----
 .../autograph/lang/special_functions_test.py  |  4 +-
 .../autograph/operators/py_builtins_test.py   | 16 +++---
 .../python/autograph/operators/slices_test.py |  4 +-
 tensorflow/python/eager/function_test.py      |  6 +--
 .../python/keras/engine/topology_test.py      |  2 +-
 .../keras/utils/multi_gpu_utils_test.py       |  2 +-
 .../boosted_trees/prediction_ops_test.py      |  4 +-
 .../boosted_trees/quantile_ops_test.py        |  4 +-
 .../linalg/linear_operator_addition_test.py   | 24 ++++-----
 .../logging_ops_logging_level_test.py         |  6 +--
 .../python/kernel_tests/logging_ops_test.py   | 40 +++++++-------
 .../kernel_tests/string_format_op_test.py     | 54 +++++++++----------
 .../python/kernel_tests/while_v2_test.py      | 18 +++----
 tensorflow/python/ops/image_ops_test.py       |  6 +--
 tensorflow/python/training/ftrl_test.py       |  4 +-
 .../training/learning_rate_decay_v2_test.py   |  2 +-
 19 files changed, 116 insertions(+), 116 deletions(-)

diff --git a/tensorflow/examples/autograph/integration_tests/errors_test.py b/tensorflow/examples/autograph/integration_tests/errors_test.py
index 69e5936832b..9c10dad9aa3 100644
--- a/tensorflow/examples/autograph/integration_tests/errors_test.py
+++ b/tensorflow/examples/autograph/integration_tests/errors_test.py
@@ -92,7 +92,7 @@ class ErrorsTest(tf.test.TestCase):
     compiled_fn = ag.to_graph(test_fn)
 
     with self.assertRaises(ag.TfRuntimeError) as error:
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         x = compiled_fn(tf.constant([4, 8]))
         with ag.improved_errors(compiled_fn):
           sess.run(x)
@@ -134,7 +134,7 @@ class ErrorsTest(tf.test.TestCase):
     # frame with "g" as the function name but because we don't yet add
     # try/except blocks to inner functions the name is "tf__g".
     with self.assertRaises(ag.TfRuntimeError) as error:
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         x = compiled_fn(tf.constant([4, 8]))
         with ag.improved_errors(compiled_fn):
           sess.run(x)
diff --git a/tensorflow/python/autograph/core/errors_test.py b/tensorflow/python/autograph/core/errors_test.py
index 0444ed7eab5..aa6c293268c 100644
--- a/tensorflow/python/autograph/core/errors_test.py
+++ b/tensorflow/python/autograph/core/errors_test.py
@@ -54,7 +54,7 @@ class RuntimeErrorsTest(test.TestCase):
     ops = zero_div_caller()
     with self.assertRaises(errors.TfRuntimeError) as cm:
       with errors.improved_errors(zero_div_caller):
-        with self.test_session() as sess:
+        with self.cached_session() as sess:
           sess.run(ops)
 
     for frame in cm.exception.custom_traceback:
@@ -69,7 +69,7 @@ class RuntimeErrorsTest(test.TestCase):
     ops = zero_div_caller()
     with self.assertRaises(errors.TfRuntimeError) as cm:
       with errors.improved_errors(zero_div_caller):
-        with self.test_session() as sess:
+        with self.cached_session() as sess:
           sess.run(ops)
 
     all_function_names = set()
@@ -86,7 +86,7 @@ class RuntimeErrorsTest(test.TestCase):
     ops = zero_div_caller()
     with self.assertRaises(tf_errors.InvalidArgumentError):
       with errors.improved_errors(zero_div_caller):
-        with self.test_session() as sess:
+        with self.cached_session() as sess:
           sess.run(ops)
 
   def test_improved_errors_validation(self):
diff --git a/tensorflow/python/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py
index e0770ef4c6e..8ce5022c0a0 100644
--- a/tensorflow/python/autograph/impl/api_test.py
+++ b/tensorflow/python/autograph/impl/api_test.py
@@ -55,7 +55,7 @@ class ApiTest(test.TestCase):
         return x
 
     tc = TestClass()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = tc.test_method(
           constant_op.constant([2, 4]), constant_op.constant(1),
           constant_op.constant(-2))
@@ -75,7 +75,7 @@ class ApiTest(test.TestCase):
         return x
 
     tc = TestClass()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = tc.test_method(
           constant_op.constant([2, 4]), constant_op.constant(1),
           constant_op.constant(-2))
@@ -96,7 +96,7 @@ class ApiTest(test.TestCase):
         return x
 
     tc = TestClass()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = tc.test_method(
           constant_op.constant([2, 4]), constant_op.constant(1),
           constant_op.constant(-2))
@@ -122,7 +122,7 @@ class ApiTest(test.TestCase):
         return x
 
     tc = TestClass()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = tc.test_method(
           constant_op.constant([2, 4]), constant_op.constant(1),
           constant_op.constant(-2))
@@ -145,7 +145,7 @@ class ApiTest(test.TestCase):
         return x
 
     tc = TestClass()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = tc.test_method(
           constant_op.constant([2, 4]), constant_op.constant(1),
           constant_op.constant(-2))
@@ -185,7 +185,7 @@ class ApiTest(test.TestCase):
         return x
 
     tc = TestClass()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = tc.test_method(
           constant_op.constant([2, 4]), constant_op.constant(1),
           constant_op.constant(-2))
@@ -202,7 +202,7 @@ class ApiTest(test.TestCase):
         return -x
       return x
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = api.converted_call(test_fn, api.ConversionOptions.new(),
                              constant_op.constant(-1))
       self.assertEqual(1, sess.run(x))
@@ -219,7 +219,7 @@ class ApiTest(test.TestCase):
           return -self.x
         return self.x
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tc = TestClass(constant_op.constant(-1))
       x = api.converted_call(tc.test_method, api.ConversionOptions.new(), tc)
       self.assertEqual(1, sess.run(x))
@@ -236,7 +236,7 @@ class ApiTest(test.TestCase):
           return -self.x
         return self.x
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tc = TestClass(constant_op.constant(-1))
       x = api.converted_call(
           TestClass.test_method,
@@ -255,7 +255,7 @@ class ApiTest(test.TestCase):
           return -self.x
         return self.x
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tc = TestClass(constant_op.constant(-1))
       x = api.converted_call(tc, api.ConversionOptions.new())
       self.assertEqual(1, sess.run(x))
@@ -272,7 +272,7 @@ class ApiTest(test.TestCase):
           return -self.x
         return self.x
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tc = api.converted_call(TestClass, api.ConversionOptions.new(),
                               constant_op.constant(-1))
       # tc is now a converted object.
@@ -284,7 +284,7 @@ class ApiTest(test.TestCase):
     def f(x):
       return x == 0
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = api.converted_call(f, api.ConversionOptions.new(),
                              constant_op.constant(0))
       self.assertTrue(sess.run(x))
@@ -303,7 +303,7 @@ class ApiTest(test.TestCase):
 
     compiled_fn = api.to_graph(test_fn)
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = compiled_fn(constant_op.constant([4, 8]), 4)
       self.assertListEqual([1, 2], sess.run(x).tolist())
 
diff --git a/tensorflow/python/autograph/lang/special_functions_test.py b/tensorflow/python/autograph/lang/special_functions_test.py
index 1f1cec18f7a..545dd117294 100644
--- a/tensorflow/python/autograph/lang/special_functions_test.py
+++ b/tensorflow/python/autograph/lang/special_functions_test.py
@@ -33,7 +33,7 @@ class SpecialFunctionsTest(test.TestCase):
 
     l = special_functions.tensor_list(elements)
     sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
 
   def test_tensor_list_array_from_elements(self):
@@ -41,7 +41,7 @@ class SpecialFunctionsTest(test.TestCase):
 
     l = special_functions.tensor_list(elements, use_tensor_array=True)
     sl = l.stack()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
 
   def test_stack(self):
diff --git a/tensorflow/python/autograph/operators/py_builtins_test.py b/tensorflow/python/autograph/operators/py_builtins_test.py
index a021263ffa8..d64d31cc791 100644
--- a/tensorflow/python/autograph/operators/py_builtins_test.py
+++ b/tensorflow/python/autograph/operators/py_builtins_test.py
@@ -36,7 +36,7 @@ class PyBuiltinsTest(test.TestCase):
 
   def test_abs(self):
     self.assertEqual(py_builtins.abs_(-1), 1)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       t = py_builtins.abs_(constant_op.constant(-1))
       self.assertEqual(sess.run(t), 1)
       t = py_builtins.abs_(constant_op.constant([-1, 2, -3]))
@@ -45,7 +45,7 @@ class PyBuiltinsTest(test.TestCase):
   def test_float(self):
     self.assertEqual(py_builtins.float_(10), 10.0)
     self.assertEqual(py_builtins.float_('10.0'), 10.0)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64))
       self.assertEqual(sess.run(t), 1.0)
       st = py_builtins.float_(constant_op.constant('1.0'))
@@ -54,7 +54,7 @@ class PyBuiltinsTest(test.TestCase):
   def test_int(self):
     self.assertEqual(py_builtins.int_(10.0), 10)
     self.assertEqual(py_builtins.int_('11', 2), 3)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64))
       self.assertEqual(sess.run(t), 1)
       st = py_builtins.int_(constant_op.constant('1'))
@@ -69,7 +69,7 @@ class PyBuiltinsTest(test.TestCase):
 
   def test_len(self):
     self.assertEqual(py_builtins.len_([1, 2, 3]), 3)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       t = py_builtins.len_(constant_op.constant([[1], [2], [3]]))
       self.assertEqual(t, 3)
       ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5))
@@ -82,7 +82,7 @@ class PyBuiltinsTest(test.TestCase):
       py_builtins.len_(constant_op.constant(1))
 
   def test_len_dynamic_shape(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
       t = py_builtins.len_(p)
       self.assertEqual(sess.run(t, {p: [1, 2, 3]}), 3)
@@ -95,7 +95,7 @@ class PyBuiltinsTest(test.TestCase):
     try:
       out_capturer = six.StringIO()
       sys.stdout = out_capturer
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sess.run(py_builtins.print_(constant_op.constant('test message'), 1))
         self.assertEqual(out_capturer.getvalue(), 'test message 1\n')
     finally:
@@ -105,7 +105,7 @@ class PyBuiltinsTest(test.TestCase):
     try:
       out_capturer = six.StringIO()
       sys.stdout = out_capturer
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sess.run(
             py_builtins.print_(constant_op.constant('test message'), [1, 2]))
         self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n')
@@ -118,7 +118,7 @@ class PyBuiltinsTest(test.TestCase):
     self.assertListEqual(list(py_builtins.range_(2, 0, -1)), [2, 1])
 
   def test_range_tensor(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       r = py_builtins.range_(constant_op.constant(3))
       self.assertAllEqual(sess.run(r), [0, 1, 2])
       r = py_builtins.range_(1, constant_op.constant(3))
diff --git a/tensorflow/python/autograph/operators/slices_test.py b/tensorflow/python/autograph/operators/slices_test.py
index d8b8418750c..9e4865b3c66 100644
--- a/tensorflow/python/autograph/operators/slices_test.py
+++ b/tensorflow/python/autograph/operators/slices_test.py
@@ -51,14 +51,14 @@ class SlicesTest(test.TestCase):
     t = slices.get_item(initial_str, 1,
                         slices.GetItemOpts(element_dtype=initial_str.dtype))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(sess.run(t), b'b')
 
     initial_list_str = constant_op.constant(['abcd', 'bcde'])
     t = slices.get_item(initial_list_str, 1,
                         slices.GetItemOpts(element_dtype=initial_str.dtype))
 
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(sess.run(t), b'bcde')
 
 
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index e4513cc87ce..04f42f63d4e 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -1602,7 +1602,7 @@ class FunctionTest(test.TestCase):
     defun_add = function.defun_with_attributes(
         add, attributes={'experimental_3': True, 'experimental_4': 1.0})
 
-    with context.graph_mode(), self.test_session():
+    with context.graph_mode(), self.cached_session():
       with ops.get_default_graph().as_default():
         t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
         sq = matmul(t, t)
@@ -1636,7 +1636,7 @@ class FunctionTest(test.TestCase):
 
     with self.assertRaisesRegexp(ValueError,
                                  '.*Attribute name is not whitelisted.*'):
-      with context.graph_mode(), self.test_session():
+      with context.graph_mode(), self.cached_session():
         with ops.get_default_graph().as_default():
           t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
           matmul(t, t)
@@ -1647,7 +1647,7 @@ class FunctionTest(test.TestCase):
 
     with self.assertRaisesRegexp(ValueError,
                                  '.*Unsupported attribute type.*'):
-      with context.graph_mode(), self.test_session():
+      with context.graph_mode(), self.cached_session():
         with ops.get_default_graph().as_default():
           t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
           add(t, t)
diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py
index 061db8ee344..a0da96334b3 100644
--- a/tensorflow/python/keras/engine/topology_test.py
+++ b/tensorflow/python/keras/engine/topology_test.py
@@ -915,7 +915,7 @@ class TopologyConstructionTest(test.TestCase):
 
   def test_constant_initializer_with_numpy(self):
 
-    with self.test_session():
+    with self.cached_session():
       initializer = keras.initializers.Constant(np.ones((3, 2)))
       model = keras.models.Sequential()
       model.add(keras.layers.Dense(2, input_shape=(3,),
diff --git a/tensorflow/python/keras/utils/multi_gpu_utils_test.py b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
index d6016ed7114..3d0351a11f4 100644
--- a/tensorflow/python/keras/utils/multi_gpu_utils_test.py
+++ b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
@@ -186,7 +186,7 @@ class TestMultiGPUModel(test.TestCase):
     if not check_if_compatible_devices(gpus=gpus):
       return
 
-    with self.test_session():
+    with self.cached_session():
       inputs = keras.Input((4, 3))
       init_state = keras.Input((3,))
       outputs = keras.layers.SimpleRNN(
diff --git a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
index 3b28d44cf8f..467e33ec877 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
@@ -934,7 +934,7 @@ class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
     For example, this could happen if the final ensemble contains one tree that
     got pruned up to the root.
     """
-    with self.test_session() as session:
+    with self.cached_session() as session:
       tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
       text_format.Merge(
           """
@@ -990,7 +990,7 @@ class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
 
   def testContribsMultipleTreeWhenFirstTreeIsABiasNode(self):
     """Tests case when, after training, first tree contains only a bias node."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
       text_format.Merge(
           """
diff --git a/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py
index c71b8df4ada..e0d46bae83a 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py
@@ -78,7 +78,7 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
     self.num_quantiles = constant_op.constant(3, dtype=dtypes.int64)
 
   def testBasicQuantileBucketsSingleResource(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       quantile_accumulator_handle = self.create_resource("floats", self.eps,
                                                          self.max_elements, 2)
       resources.initialize_resources(resources.shared_resources()).run()
@@ -102,7 +102,7 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
       self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())
 
   def testBasicQuantileBucketsMultipleResources(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       quantile_accumulator_handle_0 = self.create_resource("float_0", self.eps,
                                                            self.max_elements)
       quantile_accumulator_handle_1 = self.create_resource("float_1", self.eps,
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py
index 7c79fedf658..cf56168d637 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py
@@ -76,7 +76,7 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase):
         [1., 1.], is_positive_definite=True, name="A")
     op_b = linalg.LinearOperatorDiag(
         [2., 2.], is_positive_definite=True, name="B")
-    with self.test_session():
+    with self.cached_session():
       op_sum = add_operators([op_a, op_b])
       self.assertEqual(1, len(op_sum))
       op = op_sum[0]
@@ -98,7 +98,7 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase):
         [2., 2.], is_positive_definite=True, name="op2")
     op3 = linalg.LinearOperatorDiag(
         [3., 3.], is_positive_definite=True, name="op3")
-    with self.test_session():
+    with self.cached_session():
       op_sum = add_operators([op1, op2, op3])
       self.assertEqual(1, len(op_sum))
       op = op_sum[0]
@@ -121,7 +121,7 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase):
         name="tril")
     op3 = linalg.LinearOperatorDiag(
         [3., 3.], is_non_singular=True, name="diag_b")
-    with self.test_session():
+    with self.cached_session():
       op_sum = add_operators([op1, op2, op3])
       self.assertEqual(1, len(op_sum))
       op = op_sum[0]
@@ -143,7 +143,7 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase):
     op2 = linalg.LinearOperatorLowerTriangular(
         [[2., 0.], [1.5, 2.]], name="tril")
     op3 = linalg.LinearOperatorDiag([3., 3.], name="diag_b")
-    with self.test_session():
+    with self.cached_session():
       op_sum = add_operators([op0, op1, op2, op3], operator_name="my_operator")
       self.assertEqual(1, len(op_sum))
       op = op_sum[0]
@@ -233,7 +233,7 @@ class LinearOperatorOrderOfAdditionTest(test.TestCase):
     self.assertEqual(2, len(op_sum))
     found_diag = False
     found_tril = False
-    with self.test_session():
+    with self.cached_session():
       for op in op_sum:
         if isinstance(op, linalg.LinearOperatorDiag):
           found_diag = True
@@ -273,7 +273,7 @@ class AddAndReturnScaledIdentityTest(test.TestCase):
     operator = self._adder.add(id1, id2, "my_operator", hints)
     self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
 
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(2 *
                           linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
                           operator.to_dense().eval())
@@ -291,7 +291,7 @@ class AddAndReturnScaledIdentityTest(test.TestCase):
     operator = self._adder.add(id1, id2, "my_operator", hints)
     self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
 
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(3.2 *
                           linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
                           operator.to_dense().eval())
@@ -310,7 +310,7 @@ class AddAndReturnScaledIdentityTest(test.TestCase):
     operator = self._adder.add(id1, id2, "my_operator", hints)
     self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
 
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(1.2 *
                           linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
                           operator.to_dense().eval())
@@ -334,7 +334,7 @@ class AddAndReturnDiagTest(test.TestCase):
     operator = self._adder.add(id1, id2, "my_operator", hints)
     self.assertIsInstance(operator, linalg.LinearOperatorDiag)
 
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(2 *
                           linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
                           operator.to_dense().eval())
@@ -354,7 +354,7 @@ class AddAndReturnDiagTest(test.TestCase):
     operator = self._adder.add(op1, op2, "my_operator", hints)
     self.assertIsInstance(operator, linalg.LinearOperatorDiag)
 
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(
           linalg.LinearOperatorDiag(diag1 + diag2).to_dense().eval(),
           operator.to_dense().eval())
@@ -379,7 +379,7 @@ class AddAndReturnTriLTest(test.TestCase):
     operator = self._adder.add(diag, tril, "my_operator", hints)
     self.assertIsInstance(operator, linalg.LinearOperatorLowerTriangular)
 
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose([[11., 0.], [30., 2.]], operator.to_dense().eval())
     self.assertTrue(operator.is_positive_definite)
     self.assertTrue(operator.is_non_singular)
@@ -401,7 +401,7 @@ class AddAndReturnMatrixTest(test.TestCase):
     operator = self._adder.add(diag1, diag2, "my_operator", hints)
     self.assertIsInstance(operator, linalg.LinearOperatorFullMatrix)
 
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose([[0., 0.], [0., 5.]], operator.to_dense().eval())
     self.assertFalse(operator.is_positive_definite)
     self.assertFalse(operator.is_non_singular)
diff --git a/tensorflow/python/kernel_tests/logging_ops_logging_level_test.py b/tensorflow/python/kernel_tests/logging_ops_logging_level_test.py
index 252090b7bd7..0e8197dccbb 100644
--- a/tensorflow/python/kernel_tests/logging_ops_logging_level_test.py
+++ b/tensorflow/python/kernel_tests/logging_ops_logging_level_test.py
@@ -31,7 +31,7 @@ class PrintV2LoggingLevelTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testPrintOneTensorLogInfo(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2(
@@ -43,7 +43,7 @@ class PrintV2LoggingLevelTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testPrintOneTensorLogWarning(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2(
@@ -55,7 +55,7 @@ class PrintV2LoggingLevelTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testPrintOneTensorLogError(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2(
diff --git a/tensorflow/python/kernel_tests/logging_ops_test.py b/tensorflow/python/kernel_tests/logging_ops_test.py
index b24a0d0f9b5..4beddd00bb2 100644
--- a/tensorflow/python/kernel_tests/logging_ops_test.py
+++ b/tensorflow/python/kernel_tests/logging_ops_test.py
@@ -69,7 +69,7 @@ class PrintV2Test(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testPrintOneTensor(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2(tensor)
@@ -80,7 +80,7 @@ class PrintV2Test(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testPrintOneTensorVarySummarize(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2(tensor, summarize=1)
@@ -89,7 +89,7 @@ class PrintV2Test(test.TestCase):
       expected = "[0 ... 9]"
       self.assertTrue((expected + "\n") in printed.contents())
 
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2(tensor, summarize=2)
@@ -98,7 +98,7 @@ class PrintV2Test(test.TestCase):
       expected = "[0 1 ... 8 9]"
       self.assertTrue((expected + "\n") in printed.contents())
 
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2(tensor, summarize=3)
@@ -107,7 +107,7 @@ class PrintV2Test(test.TestCase):
       expected = "[0 1 2 ... 7 8 9]"
       self.assertTrue((expected + "\n") in printed.contents())
 
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2(tensor, summarize=-1)
@@ -118,7 +118,7 @@ class PrintV2Test(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testPrintOneVariable(self):
-    with self.test_session():
+    with self.cached_session():
       var = variables.Variable(math_ops.range(10))
       if not context.executing_eagerly():
         variables.global_variables_initializer().run()
@@ -130,7 +130,7 @@ class PrintV2Test(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testPrintTwoVariablesInStructWithAssignAdd(self):
-    with self.test_session():
+    with self.cached_session():
       var_one = variables.Variable(2.14)
       plus_one = var_one.assign_add(1.0)
       var_two = variables.Variable(math_ops.range(10))
@@ -145,7 +145,7 @@ class PrintV2Test(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testPrintTwoTensors(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2(tensor, tensor * 10)
@@ -155,7 +155,7 @@ class PrintV2Test(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testPrintPlaceholderGeneration(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2("{}6", {"{}": tensor * 10})
@@ -165,7 +165,7 @@ class PrintV2Test(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testPrintNoTensors(self):
-    with self.test_session():
+    with self.cached_session():
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2(23, [23, 5], {"6": 12})
         self.evaluate(print_op)
@@ -174,7 +174,7 @@ class PrintV2Test(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testPrintFloatScalar(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = ops.convert_to_tensor(434.43)
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2(tensor)
@@ -184,7 +184,7 @@ class PrintV2Test(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testPrintStringScalar(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = ops.convert_to_tensor("scalar")
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2(tensor)
@@ -194,7 +194,7 @@ class PrintV2Test(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testPrintComplexTensorStruct(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       small_tensor = constant_op.constant([0.3, 12.4, -16.1])
       big_tensor = math_ops.mul(tensor, 10)
@@ -214,7 +214,7 @@ class PrintV2Test(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testPrintSparseTensor(self):
-    with self.test_session():
+    with self.cached_session():
       ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
       val = [0, 10, 13, 4, 14, 32, 33]
       shape = [5, 6]
@@ -238,7 +238,7 @@ class PrintV2Test(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testPrintSparseTensorInDataStruct(self):
-    with self.test_session():
+    with self.cached_session():
       ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
       val = [0, 10, 13, 4, 14, 32, 33]
       shape = [5, 6]
@@ -262,7 +262,7 @@ class PrintV2Test(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testPrintOneTensorStdout(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       with self.captureWritesToStream(sys.stdout) as printed:
         print_op = logging_ops.print_v2(
@@ -273,7 +273,7 @@ class PrintV2Test(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testInvalidOutputStreamRaisesError(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       with self.assertRaises(ValueError):
         print_op = logging_ops.print_v2(
@@ -281,13 +281,13 @@ class PrintV2Test(test.TestCase):
         self.evaluate(print_op)
 
   def testPrintOpName(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       print_op = logging_ops.print_v2(tensor, name="print_name")
       self.assertEqual(print_op.name, "print_name")
 
   def testNoDuplicateFormatOpGraphModeAfterExplicitFormat(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       formatted_string = string_ops.string_format("{}", tensor)
       print_op = logging_ops.print_v2(formatted_string)
@@ -298,7 +298,7 @@ class PrintV2Test(test.TestCase):
       self.assertEqual(len(format_ops), 1)
 
   def testPrintOneTensorEagerOnOpCreate(self):
-    with self.test_session():
+    with self.cached_session():
       with context.eager_mode():
         tensor = math_ops.range(10)
         expected = "[0 1 2 ... 7 8 9]"
diff --git a/tensorflow/python/kernel_tests/string_format_op_test.py b/tensorflow/python/kernel_tests/string_format_op_test.py
index afa71db9092..74a5072bab9 100644
--- a/tensorflow/python/kernel_tests/string_format_op_test.py
+++ b/tensorflow/python/kernel_tests/string_format_op_test.py
@@ -34,14 +34,14 @@ class StringFormatOpTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneTensorOneDim(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       format_output = string_ops.string_format("{}", tensor)
       out = self.evaluate(format_output)
       expected = "[0 1 2 ... 7 8 9]"
       self.assertEqual(compat.as_text(out), expected)
 
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       format_output = string_ops.string_format("{}", [tensor])
       out = self.evaluate(format_output)
@@ -50,7 +50,7 @@ class StringFormatOpTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneVariableScalar(self):
-    with self.test_session():
+    with self.cached_session():
       var = variables.Variable(3.34)
       format_output = string_ops.string_format("{}", [var])
       if not context.executing_eagerly():
@@ -61,7 +61,7 @@ class StringFormatOpTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneVariableOneDim(self):
-    with self.test_session():
+    with self.cached_session():
       var = variables.Variable(math_ops.range(10))
       format_output = string_ops.string_format("{}", [var])
       if not context.executing_eagerly():
@@ -72,7 +72,7 @@ class StringFormatOpTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testFormatTwoVariablesWithAssignAdd(self):
-    with self.test_session():
+    with self.cached_session():
       var_one = variables.Variable(2.14)
       plus_one = var_one.assign_add(1.0)
       var_two = variables.Variable(math_ops.range(10))
@@ -86,7 +86,7 @@ class StringFormatOpTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneTensorOneDimFloat(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = constant_op.constant([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
       format_output = string_ops.string_format("{}", tensor)
       out = self.evaluate(format_output)
@@ -95,7 +95,7 @@ class StringFormatOpTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneTensorOneDimMatchesSummarize(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(6)
       format_output = string_ops.string_format("{}", tensor, summarize=3)
       out = self.evaluate(format_output)
@@ -104,28 +104,28 @@ class StringFormatOpTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneTensorOneDimVarySummarize(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(6)
       format_output = string_ops.string_format("{}", tensor, summarize=-1)
       out = self.evaluate(format_output)
       expected = "[0 1 2 3 4 5]"
       self.assertEqual(compat.as_text(out), expected)
 
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(6)
       format_output = string_ops.string_format("{}", tensor, summarize=1)
       out = self.evaluate(format_output)
       expected = "[0 ... 5]"
       self.assertEqual(compat.as_text(out), expected)
 
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(6)
       format_output = string_ops.string_format("{}", tensor, summarize=2)
       out = self.evaluate(format_output)
       expected = "[0 1 ... 4 5]"
       self.assertEqual(compat.as_text(out), expected)
 
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(6)
       format_output = string_ops.string_format("{}", tensor, summarize=10)
       out = self.evaluate(format_output)
@@ -134,7 +134,7 @@ class StringFormatOpTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneTensorOneDimAlmostSummarize(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(5)
       format_output = string_ops.string_format("{}", tensor, summarize=3)
       out = self.evaluate(format_output)
@@ -143,7 +143,7 @@ class StringFormatOpTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneTensorTwoDimLessThanSummarize(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.reshape(math_ops.range(4), [2, 2])
       format_output = string_ops.string_format("{}", tensor, summarize=3)
       out = self.evaluate(format_output)
@@ -153,7 +153,7 @@ class StringFormatOpTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneTensorTwoDim(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.reshape(math_ops.range(100), [10, 10])
       format_output = string_ops.string_format("{}", tensor)
       out = self.evaluate(format_output)
@@ -168,7 +168,7 @@ class StringFormatOpTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneTensorTwoDimSummarizeTwo(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.reshape(math_ops.range(100), [10, 10])
       format_output = string_ops.string_format("{}", tensor, summarize=2)
       out = self.evaluate(format_output)
@@ -181,7 +181,7 @@ class StringFormatOpTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneTensorThreeDim(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.reshape(math_ops.range(1000), [10, 10, 10])
       format_output = string_ops.string_format("{}", tensor)
       out = self.evaluate(format_output)
@@ -237,7 +237,7 @@ class StringFormatOpTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneTensorTemplatePrefix(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.reshape(math_ops.range(100), [10, 10])
       format_output = string_ops.string_format("tensor summary: {}", tensor)
       out = self.evaluate(format_output)
@@ -252,7 +252,7 @@ class StringFormatOpTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneTensorTemplatePrefixAndSuffix(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.reshape(math_ops.range(100), [10, 10])
       format_output = string_ops.string_format("tensor summary: {}, suffix",
                                                tensor)
@@ -268,7 +268,7 @@ class StringFormatOpTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneTensorTemplateSuffix(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.reshape(math_ops.range(100), [10, 10])
       format_output = string_ops.string_format("{}, suffix", tensor)
       out = self.evaluate(format_output)
@@ -283,7 +283,7 @@ class StringFormatOpTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testFormatNoTensor(self):
-    with self.test_session():
+    with self.cached_session():
       format_output = string_ops.string_format("No tensor.", ())
       out = self.evaluate(format_output)
       expected = "No tensor."
@@ -291,7 +291,7 @@ class StringFormatOpTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testFormatMultiTensor(self):
-    with self.test_session():
+    with self.cached_session():
       tensor_one = array_ops.reshape(math_ops.range(100), [10, 10])
       tensor_two = tensor_one * 10
       format_output = string_ops.string_format("One: {},\nTwo: {}",
@@ -315,7 +315,7 @@ class StringFormatOpTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testFormatSummarizeOne(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.reshape(math_ops.range(100), [10, 10])
       format_output = string_ops.string_format("tensor summary: {}", tensor,
                                                summarize=1)
@@ -327,7 +327,7 @@ class StringFormatOpTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testFormatSummarizeTwo(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.reshape(math_ops.range(100), [10, 10])
       format_output = string_ops.string_format("tensor summary: {}", tensor,
                                                summarize=2)
@@ -341,7 +341,7 @@ class StringFormatOpTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testFormatPlaceholder(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.reshape(math_ops.range(100), [10, 10])
       format_output = string_ops.string_format("tensor summary: %t%", tensor,
                                                placeholder="%t%")
@@ -357,21 +357,21 @@ class StringFormatOpTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testTensorCountMustMatchPlaceholderCount(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(
           ValueError, r"2 placeholder\(s\) in template does not match 1 "
                       r"tensor\(s\) provided as input"):
         tensor = math_ops.range(10)
         format_output = string_ops.string_format("{} {}", tensor)
         self.evaluate(format_output)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(
           ValueError, r"2 placeholder\(s\) in template does not match 1 "
                       r"tensor\(s\) provided as input"):
         tensor = math_ops.range(10)
         format_output = string_ops.string_format("{} {}", [tensor])
         self.evaluate(format_output)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(
           ValueError, r"1 placeholder\(s\) in template does not match 2 "
                       r"tensor\(s\) provided as input"):
diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py
index 0c3b72408ed..3a070544e83 100644
--- a/tensorflow/python/kernel_tests/while_v2_test.py
+++ b/tensorflow/python/kernel_tests/while_v2_test.py
@@ -41,7 +41,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     x = constant_op.constant(2.)
     ret = while_loop_v2(lambda v: v < 8., lambda v: v * v, [x])
     grad = gradients_impl.gradients(ret, [x])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(sess.run(ret), 16.)
       self.assertSequenceEqual(sess.run(grad), [32.])
 
@@ -58,7 +58,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
 
     # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
     grad = gradients_impl.gradients(ret, [x])  # [2*x*y]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertSequenceEqual(sess.run(ret), [45., 3.])
       self.assertSequenceEqual(sess.run(grad), [9.])
 
@@ -81,7 +81,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     grady_0 = gradients_impl.gradients(ret[0], [y])  # [2*x*y + x**2]
     grady_1 = gradients_impl.gradients(ret[1], [y])  # [x + 1]
     grady_2 = gradients_impl.gradients(ret, [y])  # [2*x*y + x**2 + x + 1]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertSequenceEqual(sess.run(ret), [120., 23.])
       self.assertSequenceEqual(sess.run(gradx_0), [39.])
       self.assertSequenceEqual(sess.run(gradx_1), [4.])
@@ -96,7 +96,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     ret2 = while_loop_v2(lambda v: v < 16., lambda v: v * v, ret1)  # x**4
     grad = gradients_impl.gradients(ret2, [x])  # 4x**3
     grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertSequenceEqual(sess.run(grad), [32.])
       self.assertSequenceEqual(sess.run(grad_grad), [48.])
 
@@ -105,7 +105,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     ret = while_loop_v2(lambda v: v < 8., lambda v: v**2, [x])  # x**4
     grad = gradients_impl.gradients(ret, [x])  # 4x**3
     grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(sess.run(ret), 16.)
       self.assertSequenceEqual(sess.run(grad), [32.])
       self.assertSequenceEqual(sess.run(grad_grad), [48.])
@@ -148,7 +148,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     y = constant_op.constant(1.)
     ret = while_loop_v2(lambda v: v + y < 9., lambda v: v * 3., [x])
     grad = gradients_impl.gradients(ret, [x])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(sess.run(ret), 18.)
       self.assertSequenceEqual(sess.run(grad), [9.])
 
@@ -157,7 +157,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     y = constant_op.constant(3.)
     ret = while_loop_v2(lambda v: v < 8., lambda v: v * y, [x])
     grad = gradients_impl.gradients(ret, [x])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(sess.run(ret), 18.)
       self.assertSequenceEqual(sess.run(grad), [9.])
 
@@ -178,7 +178,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
 
     ret = while_loop_v2(Cond, Body, [x, tensor_list])
     grad = gradients_impl.gradients(ret[0], x)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(sess.run(ret[0]), 16.)
       self.assertSequenceEqual(sess.run(grad), [32.])
 
@@ -212,7 +212,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     self.assertEqual(accumulator_count, 1)
 
     grad = gradients_impl.gradients(ret[0], x)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(sess.run(ret[0]), 16.)
       self.assertSequenceEqual(sess.run(grad), [32.])
 
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index da45f6e3e69..35fdee4fad8 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -3673,7 +3673,7 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
     # Note: There are multiple versions of non_max_suppression v2, v3, v4.
     # gen_image_ops.non_max_suppression_v2:
     for dtype in [np.float16, np.float32]:
-      with self.test_session():
+      with self.cached_session():
         boxes = constant_op.constant(boxes_np, dtype=dtype)
         scores = constant_op.constant(scores_np, dtype=dtype)
         max_output_size = constant_op.constant(max_output_size_np)
@@ -3683,7 +3683,7 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
         self.assertAllClose(selected_indices, [3, 0, 5])
     # image_ops.non_max_suppression = gen_image_ops.non_max_suppression_v3.
     for dtype in [np.float16, np.float32]:
-      with self.test_session():
+      with self.cached_session():
         boxes = constant_op.constant(boxes_np, dtype=dtype)
         scores = constant_op.constant(scores_np, dtype=dtype)
         max_output_size = constant_op.constant(max_output_size_np)
@@ -3694,7 +3694,7 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
     # gen_image_ops.non_max_suppression_v4.
     score_threshold = float('-inf')
     for dtype in [np.float16, np.float32]:
-      with self.test_session():
+      with self.cached_session():
         boxes = constant_op.constant(boxes_np, dtype=dtype)
         scores = constant_op.constant(scores_np, dtype=dtype)
         max_output_size = constant_op.constant(max_output_size_np)
diff --git a/tensorflow/python/training/ftrl_test.py b/tensorflow/python/training/ftrl_test.py
index 09d6fe36d35..15c50bc8788 100644
--- a/tensorflow/python/training/ftrl_test.py
+++ b/tensorflow/python/training/ftrl_test.py
@@ -218,7 +218,7 @@ class FtrlOptimizerTest(test.TestCase):
   def testFtrlWithL1_L2_L2ShrinkageSparse(self):
     """Tests the new FTRL op with support for l2 shrinkage on sparse grads."""
     for dtype in [dtypes.half, dtypes.float32]:
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
         var1 = variables.Variable([[4.0], [3.0]], dtype=dtype)
         grads0 = ops.IndexedSlices(
@@ -252,7 +252,7 @@ class FtrlOptimizerTest(test.TestCase):
   def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self):
     """Verifies that l2 shrinkage in FTRL does not change lr schedule."""
     for dtype in [dtypes.half, dtypes.float32]:
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
         var1 = variables.Variable([1.0, 2.0], dtype=dtype)
         grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
diff --git a/tensorflow/python/training/learning_rate_decay_v2_test.py b/tensorflow/python/training/learning_rate_decay_v2_test.py
index 0f2d60dafc8..b2ac93f06fe 100644
--- a/tensorflow/python/training/learning_rate_decay_v2_test.py
+++ b/tensorflow/python/training/learning_rate_decay_v2_test.py
@@ -62,7 +62,7 @@ class LRDecayTestV2(test_util.TensorFlowTestCase):
       self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
 
   def testVariables(self):
-    with self.test_session():
+    with self.cached_session():
       step = variables.Variable(1)
       assign_1 = step.assign(1)
       assign_2 = step.assign(2)