Move from deprecated self.test_session() to self.cached_session().

self.test_session() was deprecated in 9962eb5e84 because its name confuses readers of the test. This change moves to cached_session() instead, whose name is more explicit about the fact that:
* the session may be reused, and
* the session is not closed when the "with" block exits.
An illustrative sketch of this usage pattern follows the commit metadata below.

PiperOrigin-RevId: 214300210
Authored by A. Unique TensorFlower on 2018-09-24 11:28:07 -07:00; committed by TensorFlower Gardener.
Commit 28eeda839f (parent 5fbb064ba1).
19 changed files with 116 additions and 116 deletions.
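
The sketch below is illustrative only and is not part of this commit. It assumes a TF 1.x graph-mode test environment and uses a hypothetical test class and values; it shows the pattern that every hunk below applies, i.e. replacing the deprecated self.test_session() context manager with self.cached_session() inside a tf.test.TestCase.

import tensorflow as tf


class CachedSessionExampleTest(tf.test.TestCase):
  """Hypothetical test, shown only to illustrate the migration pattern."""

  def test_cached_session_is_reused(self):
    x = tf.constant([4, 8])
    doubled = x * 2

    # Old, deprecated spelling that this commit removes:
    #   with self.test_session() as sess:
    #     sess.run(doubled)
    # New spelling; the name makes the caching behavior explicit:
    with self.cached_session() as sess:
      self.assertAllEqual(sess.run(doubled), [8, 16])

    # Asking again inside the same test hands back the cached session
    # (it was not closed when the previous "with" block exited).
    with self.cached_session() as sess:
      self.assertAllEqual(sess.run(doubled), [8, 16])


if __name__ == "__main__":
  tf.test.main()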


@@ -92,7 +92,7 @@ class ErrorsTest(tf.test.TestCase):
     compiled_fn = ag.to_graph(test_fn)
     with self.assertRaises(ag.TfRuntimeError) as error:
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         x = compiled_fn(tf.constant([4, 8]))
         with ag.improved_errors(compiled_fn):
           sess.run(x)
@@ -134,7 +134,7 @@ class ErrorsTest(tf.test.TestCase):
     # frame with "g" as the function name but because we don't yet add
     # try/except blocks to inner functions the name is "tf__g".
     with self.assertRaises(ag.TfRuntimeError) as error:
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         x = compiled_fn(tf.constant([4, 8]))
         with ag.improved_errors(compiled_fn):
           sess.run(x)


@@ -54,7 +54,7 @@ class RuntimeErrorsTest(test.TestCase):
     ops = zero_div_caller()
     with self.assertRaises(errors.TfRuntimeError) as cm:
       with errors.improved_errors(zero_div_caller):
-        with self.test_session() as sess:
+        with self.cached_session() as sess:
           sess.run(ops)
     for frame in cm.exception.custom_traceback:
@@ -69,7 +69,7 @@ class RuntimeErrorsTest(test.TestCase):
     ops = zero_div_caller()
     with self.assertRaises(errors.TfRuntimeError) as cm:
       with errors.improved_errors(zero_div_caller):
-        with self.test_session() as sess:
+        with self.cached_session() as sess:
           sess.run(ops)
     all_function_names = set()
@@ -86,7 +86,7 @@ class RuntimeErrorsTest(test.TestCase):
     ops = zero_div_caller()
     with self.assertRaises(tf_errors.InvalidArgumentError):
       with errors.improved_errors(zero_div_caller):
-        with self.test_session() as sess:
+        with self.cached_session() as sess:
           sess.run(ops)
   def test_improved_errors_validation(self):


@@ -55,7 +55,7 @@ class ApiTest(test.TestCase):
         return x
     tc = TestClass()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = tc.test_method(
           constant_op.constant([2, 4]), constant_op.constant(1),
           constant_op.constant(-2))
@@ -75,7 +75,7 @@ class ApiTest(test.TestCase):
         return x
     tc = TestClass()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = tc.test_method(
           constant_op.constant([2, 4]), constant_op.constant(1),
           constant_op.constant(-2))
@@ -96,7 +96,7 @@ class ApiTest(test.TestCase):
         return x
     tc = TestClass()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = tc.test_method(
           constant_op.constant([2, 4]), constant_op.constant(1),
           constant_op.constant(-2))
@@ -122,7 +122,7 @@ class ApiTest(test.TestCase):
         return x
     tc = TestClass()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = tc.test_method(
           constant_op.constant([2, 4]), constant_op.constant(1),
           constant_op.constant(-2))
@@ -145,7 +145,7 @@ class ApiTest(test.TestCase):
         return x
     tc = TestClass()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = tc.test_method(
           constant_op.constant([2, 4]), constant_op.constant(1),
           constant_op.constant(-2))
@@ -185,7 +185,7 @@ class ApiTest(test.TestCase):
         return x
     tc = TestClass()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = tc.test_method(
           constant_op.constant([2, 4]), constant_op.constant(1),
           constant_op.constant(-2))
@@ -202,7 +202,7 @@ class ApiTest(test.TestCase):
         return -x
       return x
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = api.converted_call(test_fn, api.ConversionOptions.new(),
           constant_op.constant(-1))
       self.assertEqual(1, sess.run(x))
@@ -219,7 +219,7 @@ class ApiTest(test.TestCase):
           return -self.x
         return self.x
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tc = TestClass(constant_op.constant(-1))
       x = api.converted_call(tc.test_method, api.ConversionOptions.new(), tc)
       self.assertEqual(1, sess.run(x))
@@ -236,7 +236,7 @@ class ApiTest(test.TestCase):
           return -self.x
         return self.x
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tc = TestClass(constant_op.constant(-1))
       x = api.converted_call(
           TestClass.test_method,
@@ -255,7 +255,7 @@ class ApiTest(test.TestCase):
           return -self.x
         return self.x
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tc = TestClass(constant_op.constant(-1))
       x = api.converted_call(tc, api.ConversionOptions.new())
       self.assertEqual(1, sess.run(x))
@@ -272,7 +272,7 @@ class ApiTest(test.TestCase):
           return -self.x
         return self.x
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       tc = api.converted_call(TestClass, api.ConversionOptions.new(),
           constant_op.constant(-1))
       # tc is now a converted object.
@@ -284,7 +284,7 @@ class ApiTest(test.TestCase):
     def f(x):
       return x == 0
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = api.converted_call(f, api.ConversionOptions.new(),
           constant_op.constant(0))
       self.assertTrue(sess.run(x))
@@ -303,7 +303,7 @@ class ApiTest(test.TestCase):
     compiled_fn = api.to_graph(test_fn)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       x = compiled_fn(constant_op.constant([4, 8]), 4)
       self.assertListEqual([1, 2], sess.run(x).tolist())


@@ -33,7 +33,7 @@ class SpecialFunctionsTest(test.TestCase):
     l = special_functions.tensor_list(elements)
     sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
   def test_tensor_list_array_from_elements(self):
@@ -41,7 +41,7 @@ class SpecialFunctionsTest(test.TestCase):
     l = special_functions.tensor_list(elements, use_tensor_array=True)
     sl = l.stack()
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
   def test_stack(self):


@@ -36,7 +36,7 @@ class PyBuiltinsTest(test.TestCase):
   def test_abs(self):
     self.assertEqual(py_builtins.abs_(-1), 1)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       t = py_builtins.abs_(constant_op.constant(-1))
       self.assertEqual(sess.run(t), 1)
       t = py_builtins.abs_(constant_op.constant([-1, 2, -3]))
@@ -45,7 +45,7 @@ class PyBuiltinsTest(test.TestCase):
   def test_float(self):
     self.assertEqual(py_builtins.float_(10), 10.0)
     self.assertEqual(py_builtins.float_('10.0'), 10.0)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64))
       self.assertEqual(sess.run(t), 1.0)
       st = py_builtins.float_(constant_op.constant('1.0'))
@@ -54,7 +54,7 @@ class PyBuiltinsTest(test.TestCase):
   def test_int(self):
     self.assertEqual(py_builtins.int_(10.0), 10)
     self.assertEqual(py_builtins.int_('11', 2), 3)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64))
       self.assertEqual(sess.run(t), 1)
       st = py_builtins.int_(constant_op.constant('1'))
@@ -69,7 +69,7 @@ class PyBuiltinsTest(test.TestCase):
   def test_len(self):
     self.assertEqual(py_builtins.len_([1, 2, 3]), 3)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       t = py_builtins.len_(constant_op.constant([[1], [2], [3]]))
       self.assertEqual(t, 3)
       ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5))
@@ -82,7 +82,7 @@ class PyBuiltinsTest(test.TestCase):
       py_builtins.len_(constant_op.constant(1))
   def test_len_dynamic_shape(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
       t = py_builtins.len_(p)
       self.assertEqual(sess.run(t, {p: [1, 2, 3]}), 3)
@@ -95,7 +95,7 @@ class PyBuiltinsTest(test.TestCase):
     try:
       out_capturer = six.StringIO()
       sys.stdout = out_capturer
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sess.run(py_builtins.print_(constant_op.constant('test message'), 1))
         self.assertEqual(out_capturer.getvalue(), 'test message 1\n')
     finally:
@@ -105,7 +105,7 @@ class PyBuiltinsTest(test.TestCase):
     try:
       out_capturer = six.StringIO()
       sys.stdout = out_capturer
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         sess.run(
             py_builtins.print_(constant_op.constant('test message'), [1, 2]))
         self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n')
@@ -118,7 +118,7 @@ class PyBuiltinsTest(test.TestCase):
     self.assertListEqual(list(py_builtins.range_(2, 0, -1)), [2, 1])
   def test_range_tensor(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       r = py_builtins.range_(constant_op.constant(3))
       self.assertAllEqual(sess.run(r), [0, 1, 2])
       r = py_builtins.range_(1, constant_op.constant(3))


@@ -51,14 +51,14 @@ class SlicesTest(test.TestCase):
     t = slices.get_item(initial_str, 1,
         slices.GetItemOpts(element_dtype=initial_str.dtype))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(sess.run(t), b'b')
     initial_list_str = constant_op.constant(['abcd', 'bcde'])
     t = slices.get_item(initial_list_str, 1,
         slices.GetItemOpts(element_dtype=initial_str.dtype))
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(sess.run(t), b'bcde')


@@ -1602,7 +1602,7 @@ class FunctionTest(test.TestCase):
     defun_add = function.defun_with_attributes(
         add, attributes={'experimental_3': True, 'experimental_4': 1.0})
-    with context.graph_mode(), self.test_session():
+    with context.graph_mode(), self.cached_session():
       with ops.get_default_graph().as_default():
         t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
         sq = matmul(t, t)
@@ -1636,7 +1636,7 @@ class FunctionTest(test.TestCase):
     with self.assertRaisesRegexp(ValueError,
         '.*Attribute name is not whitelisted.*'):
-      with context.graph_mode(), self.test_session():
+      with context.graph_mode(), self.cached_session():
        with ops.get_default_graph().as_default():
          t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
          matmul(t, t)
@@ -1647,7 +1647,7 @@ class FunctionTest(test.TestCase):
     with self.assertRaisesRegexp(ValueError,
         '.*Unsupported attribute type.*'):
-      with context.graph_mode(), self.test_session():
+      with context.graph_mode(), self.cached_session():
        with ops.get_default_graph().as_default():
          t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
          add(t, t)


@@ -915,7 +915,7 @@ class TopologyConstructionTest(test.TestCase):
   def test_constant_initializer_with_numpy(self):
-    with self.test_session():
+    with self.cached_session():
       initializer = keras.initializers.Constant(np.ones((3, 2)))
       model = keras.models.Sequential()
       model.add(keras.layers.Dense(2, input_shape=(3,),


@@ -186,7 +186,7 @@ class TestMultiGPUModel(test.TestCase):
     if not check_if_compatible_devices(gpus=gpus):
       return
-    with self.test_session():
+    with self.cached_session():
       inputs = keras.Input((4, 3))
       init_state = keras.Input((3,))
       outputs = keras.layers.SimpleRNN(


@@ -934,7 +934,7 @@ class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
     For example, this could happen if the final ensemble contains one tree that
     got pruned up to the root.
     """
-    with self.test_session() as session:
+    with self.cached_session() as session:
       tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
       text_format.Merge(
           """
@@ -990,7 +990,7 @@ class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
   def testContribsMultipleTreeWhenFirstTreeIsABiasNode(self):
     """Tests case when, after training, first tree contains only a bias node."""
-    with self.test_session() as session:
+    with self.cached_session() as session:
       tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
       text_format.Merge(
           """


@@ -78,7 +78,7 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
     self.num_quantiles = constant_op.constant(3, dtype=dtypes.int64)
   def testBasicQuantileBucketsSingleResource(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       quantile_accumulator_handle = self.create_resource("floats", self.eps,
           self.max_elements, 2)
       resources.initialize_resources(resources.shared_resources()).run()
@@ -102,7 +102,7 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
     self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())
   def testBasicQuantileBucketsMultipleResources(self):
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       quantile_accumulator_handle_0 = self.create_resource("float_0", self.eps,
           self.max_elements)
       quantile_accumulator_handle_1 = self.create_resource("float_1", self.eps,


@@ -76,7 +76,7 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase):
         [1., 1.], is_positive_definite=True, name="A")
     op_b = linalg.LinearOperatorDiag(
         [2., 2.], is_positive_definite=True, name="B")
-    with self.test_session():
+    with self.cached_session():
       op_sum = add_operators([op_a, op_b])
       self.assertEqual(1, len(op_sum))
       op = op_sum[0]
@@ -98,7 +98,7 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase):
         [2., 2.], is_positive_definite=True, name="op2")
     op3 = linalg.LinearOperatorDiag(
         [3., 3.], is_positive_definite=True, name="op3")
-    with self.test_session():
+    with self.cached_session():
       op_sum = add_operators([op1, op2, op3])
       self.assertEqual(1, len(op_sum))
       op = op_sum[0]
@@ -121,7 +121,7 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase):
         name="tril")
     op3 = linalg.LinearOperatorDiag(
         [3., 3.], is_non_singular=True, name="diag_b")
-    with self.test_session():
+    with self.cached_session():
       op_sum = add_operators([op1, op2, op3])
       self.assertEqual(1, len(op_sum))
       op = op_sum[0]
@@ -143,7 +143,7 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase):
     op2 = linalg.LinearOperatorLowerTriangular(
         [[2., 0.], [1.5, 2.]], name="tril")
     op3 = linalg.LinearOperatorDiag([3., 3.], name="diag_b")
-    with self.test_session():
+    with self.cached_session():
       op_sum = add_operators([op0, op1, op2, op3], operator_name="my_operator")
       self.assertEqual(1, len(op_sum))
       op = op_sum[0]
@@ -233,7 +233,7 @@ class LinearOperatorOrderOfAdditionTest(test.TestCase):
     self.assertEqual(2, len(op_sum))
     found_diag = False
     found_tril = False
-    with self.test_session():
+    with self.cached_session():
       for op in op_sum:
         if isinstance(op, linalg.LinearOperatorDiag):
           found_diag = True
@@ -273,7 +273,7 @@ class AddAndReturnScaledIdentityTest(test.TestCase):
     operator = self._adder.add(id1, id2, "my_operator", hints)
     self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(2 *
           linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
           operator.to_dense().eval())
@@ -291,7 +291,7 @@ class AddAndReturnScaledIdentityTest(test.TestCase):
     operator = self._adder.add(id1, id2, "my_operator", hints)
     self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(3.2 *
           linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
          operator.to_dense().eval())
@@ -310,7 +310,7 @@ class AddAndReturnScaledIdentityTest(test.TestCase):
     operator = self._adder.add(id1, id2, "my_operator", hints)
     self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(1.2 *
           linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
          operator.to_dense().eval())
@@ -334,7 +334,7 @@ class AddAndReturnDiagTest(test.TestCase):
     operator = self._adder.add(id1, id2, "my_operator", hints)
     self.assertIsInstance(operator, linalg.LinearOperatorDiag)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(2 *
           linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
          operator.to_dense().eval())
@@ -354,7 +354,7 @@ class AddAndReturnDiagTest(test.TestCase):
     operator = self._adder.add(op1, op2, "my_operator", hints)
     self.assertIsInstance(operator, linalg.LinearOperatorDiag)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose(
           linalg.LinearOperatorDiag(diag1 + diag2).to_dense().eval(),
          operator.to_dense().eval())
@@ -379,7 +379,7 @@ class AddAndReturnTriLTest(test.TestCase):
     operator = self._adder.add(diag, tril, "my_operator", hints)
     self.assertIsInstance(operator, linalg.LinearOperatorLowerTriangular)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose([[11., 0.], [30., 2.]], operator.to_dense().eval())
       self.assertTrue(operator.is_positive_definite)
       self.assertTrue(operator.is_non_singular)
@@ -401,7 +401,7 @@ class AddAndReturnMatrixTest(test.TestCase):
     operator = self._adder.add(diag1, diag2, "my_operator", hints)
     self.assertIsInstance(operator, linalg.LinearOperatorFullMatrix)
-    with self.test_session():
+    with self.cached_session():
       self.assertAllClose([[0., 0.], [0., 5.]], operator.to_dense().eval())
       self.assertFalse(operator.is_positive_definite)
       self.assertFalse(operator.is_non_singular)


@@ -31,7 +31,7 @@ class PrintV2LoggingLevelTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testPrintOneTensorLogInfo(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2(
@@ -43,7 +43,7 @@ class PrintV2LoggingLevelTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testPrintOneTensorLogWarning(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2(
@@ -55,7 +55,7 @@ class PrintV2LoggingLevelTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testPrintOneTensorLogError(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2(


@@ -69,7 +69,7 @@ class PrintV2Test(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testPrintOneTensor(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2(tensor)
@@ -80,7 +80,7 @@ class PrintV2Test(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testPrintOneTensorVarySummarize(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2(tensor, summarize=1)
@@ -89,7 +89,7 @@ class PrintV2Test(test.TestCase):
       expected = "[0 ... 9]"
       self.assertTrue((expected + "\n") in printed.contents())
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2(tensor, summarize=2)
@@ -98,7 +98,7 @@ class PrintV2Test(test.TestCase):
       expected = "[0 1 ... 8 9]"
       self.assertTrue((expected + "\n") in printed.contents())
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2(tensor, summarize=3)
@@ -107,7 +107,7 @@ class PrintV2Test(test.TestCase):
       expected = "[0 1 2 ... 7 8 9]"
      self.assertTrue((expected + "\n") in printed.contents())
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2(tensor, summarize=-1)
@@ -118,7 +118,7 @@ class PrintV2Test(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testPrintOneVariable(self):
-    with self.test_session():
+    with self.cached_session():
       var = variables.Variable(math_ops.range(10))
       if not context.executing_eagerly():
         variables.global_variables_initializer().run()
@@ -130,7 +130,7 @@ class PrintV2Test(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testPrintTwoVariablesInStructWithAssignAdd(self):
-    with self.test_session():
+    with self.cached_session():
       var_one = variables.Variable(2.14)
       plus_one = var_one.assign_add(1.0)
       var_two = variables.Variable(math_ops.range(10))
@@ -145,7 +145,7 @@ class PrintV2Test(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testPrintTwoTensors(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2(tensor, tensor * 10)
@@ -155,7 +155,7 @@ class PrintV2Test(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testPrintPlaceholderGeneration(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2("{}6", {"{}": tensor * 10})
@@ -165,7 +165,7 @@ class PrintV2Test(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testPrintNoTensors(self):
-    with self.test_session():
+    with self.cached_session():
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2(23, [23, 5], {"6": 12})
         self.evaluate(print_op)
@@ -174,7 +174,7 @@ class PrintV2Test(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testPrintFloatScalar(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = ops.convert_to_tensor(434.43)
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2(tensor)
@@ -184,7 +184,7 @@ class PrintV2Test(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testPrintStringScalar(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = ops.convert_to_tensor("scalar")
       with self.captureWritesToStream(sys.stderr) as printed:
         print_op = logging_ops.print_v2(tensor)
@@ -194,7 +194,7 @@ class PrintV2Test(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testPrintComplexTensorStruct(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       small_tensor = constant_op.constant([0.3, 12.4, -16.1])
       big_tensor = math_ops.mul(tensor, 10)
@@ -214,7 +214,7 @@ class PrintV2Test(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testPrintSparseTensor(self):
-    with self.test_session():
+    with self.cached_session():
       ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
       val = [0, 10, 13, 4, 14, 32, 33]
       shape = [5, 6]
@@ -238,7 +238,7 @@ class PrintV2Test(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testPrintSparseTensorInDataStruct(self):
-    with self.test_session():
+    with self.cached_session():
       ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
       val = [0, 10, 13, 4, 14, 32, 33]
       shape = [5, 6]
@@ -262,7 +262,7 @@ class PrintV2Test(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testPrintOneTensorStdout(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       with self.captureWritesToStream(sys.stdout) as printed:
         print_op = logging_ops.print_v2(
@@ -273,7 +273,7 @@ class PrintV2Test(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testInvalidOutputStreamRaisesError(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       with self.assertRaises(ValueError):
         print_op = logging_ops.print_v2(
@@ -281,13 +281,13 @@ class PrintV2Test(test.TestCase):
         self.evaluate(print_op)
   def testPrintOpName(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       print_op = logging_ops.print_v2(tensor, name="print_name")
       self.assertEqual(print_op.name, "print_name")
   def testNoDuplicateFormatOpGraphModeAfterExplicitFormat(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       formatted_string = string_ops.string_format("{}", tensor)
       print_op = logging_ops.print_v2(formatted_string)
@@ -298,7 +298,7 @@ class PrintV2Test(test.TestCase):
       self.assertEqual(len(format_ops), 1)
   def testPrintOneTensorEagerOnOpCreate(self):
-    with self.test_session():
+    with self.cached_session():
       with context.eager_mode():
         tensor = math_ops.range(10)
         expected = "[0 1 2 ... 7 8 9]"


@@ -34,14 +34,14 @@ class StringFormatOpTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneTensorOneDim(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       format_output = string_ops.string_format("{}", tensor)
       out = self.evaluate(format_output)
       expected = "[0 1 2 ... 7 8 9]"
       self.assertEqual(compat.as_text(out), expected)
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(10)
       format_output = string_ops.string_format("{}", [tensor])
       out = self.evaluate(format_output)
@@ -50,7 +50,7 @@ class StringFormatOpTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneVariableScalar(self):
-    with self.test_session():
+    with self.cached_session():
       var = variables.Variable(3.34)
       format_output = string_ops.string_format("{}", [var])
       if not context.executing_eagerly():
@@ -61,7 +61,7 @@ class StringFormatOpTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneVariableOneDim(self):
-    with self.test_session():
+    with self.cached_session():
       var = variables.Variable(math_ops.range(10))
       format_output = string_ops.string_format("{}", [var])
       if not context.executing_eagerly():
@@ -72,7 +72,7 @@ class StringFormatOpTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testFormatTwoVariablesWithAssignAdd(self):
-    with self.test_session():
+    with self.cached_session():
       var_one = variables.Variable(2.14)
       plus_one = var_one.assign_add(1.0)
       var_two = variables.Variable(math_ops.range(10))
@@ -86,7 +86,7 @@ class StringFormatOpTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneTensorOneDimFloat(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = constant_op.constant([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
       format_output = string_ops.string_format("{}", tensor)
       out = self.evaluate(format_output)
@@ -95,7 +95,7 @@ class StringFormatOpTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneTensorOneDimMatchesSummarize(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(6)
       format_output = string_ops.string_format("{}", tensor, summarize=3)
       out = self.evaluate(format_output)
@@ -104,28 +104,28 @@ class StringFormatOpTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneTensorOneDimVarySummarize(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(6)
       format_output = string_ops.string_format("{}", tensor, summarize=-1)
       out = self.evaluate(format_output)
       expected = "[0 1 2 3 4 5]"
       self.assertEqual(compat.as_text(out), expected)
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(6)
       format_output = string_ops.string_format("{}", tensor, summarize=1)
       out = self.evaluate(format_output)
       expected = "[0 ... 5]"
       self.assertEqual(compat.as_text(out), expected)
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(6)
       format_output = string_ops.string_format("{}", tensor, summarize=2)
       out = self.evaluate(format_output)
       expected = "[0 1 ... 4 5]"
       self.assertEqual(compat.as_text(out), expected)
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(6)
       format_output = string_ops.string_format("{}", tensor, summarize=10)
       out = self.evaluate(format_output)
@@ -134,7 +134,7 @@ class StringFormatOpTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneTensorOneDimAlmostSummarize(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = math_ops.range(5)
       format_output = string_ops.string_format("{}", tensor, summarize=3)
       out = self.evaluate(format_output)
@@ -143,7 +143,7 @@ class StringFormatOpTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneTensorTwoDimLessThanSummarize(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.reshape(math_ops.range(4), [2, 2])
       format_output = string_ops.string_format("{}", tensor, summarize=3)
       out = self.evaluate(format_output)
@@ -153,7 +153,7 @@ class StringFormatOpTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneTensorTwoDim(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.reshape(math_ops.range(100), [10, 10])
       format_output = string_ops.string_format("{}", tensor)
       out = self.evaluate(format_output)
@@ -168,7 +168,7 @@ class StringFormatOpTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneTensorTwoDimSummarizeTwo(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.reshape(math_ops.range(100), [10, 10])
       format_output = string_ops.string_format("{}", tensor, summarize=2)
       out = self.evaluate(format_output)
@@ -181,7 +181,7 @@ class StringFormatOpTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneTensorThreeDim(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.reshape(math_ops.range(1000), [10, 10, 10])
       format_output = string_ops.string_format("{}", tensor)
       out = self.evaluate(format_output)
@@ -237,7 +237,7 @@ class StringFormatOpTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneTensorTemplatePrefix(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.reshape(math_ops.range(100), [10, 10])
       format_output = string_ops.string_format("tensor summary: {}", tensor)
       out = self.evaluate(format_output)
@@ -252,7 +252,7 @@ class StringFormatOpTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneTensorTemplatePrefixAndSuffix(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.reshape(math_ops.range(100), [10, 10])
       format_output = string_ops.string_format("tensor summary: {}, suffix",
           tensor)
@@ -268,7 +268,7 @@ class StringFormatOpTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testFormatOneTensorTemplateSuffix(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.reshape(math_ops.range(100), [10, 10])
       format_output = string_ops.string_format("{}, suffix", tensor)
       out = self.evaluate(format_output)
@@ -283,7 +283,7 @@ class StringFormatOpTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testFormatNoTensor(self):
-    with self.test_session():
+    with self.cached_session():
       format_output = string_ops.string_format("No tensor.", ())
       out = self.evaluate(format_output)
       expected = "No tensor."
@@ -291,7 +291,7 @@ class StringFormatOpTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testFormatMultiTensor(self):
-    with self.test_session():
+    with self.cached_session():
       tensor_one = array_ops.reshape(math_ops.range(100), [10, 10])
       tensor_two = tensor_one * 10
       format_output = string_ops.string_format("One: {},\nTwo: {}",
@@ -315,7 +315,7 @@ class StringFormatOpTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testFormatSummarizeOne(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.reshape(math_ops.range(100), [10, 10])
       format_output = string_ops.string_format("tensor summary: {}", tensor,
           summarize=1)
@@ -327,7 +327,7 @@ class StringFormatOpTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testFormatSummarizeTwo(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.reshape(math_ops.range(100), [10, 10])
       format_output = string_ops.string_format("tensor summary: {}", tensor,
           summarize=2)
@@ -341,7 +341,7 @@ class StringFormatOpTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testFormatPlaceholder(self):
-    with self.test_session():
+    with self.cached_session():
       tensor = array_ops.reshape(math_ops.range(100), [10, 10])
       format_output = string_ops.string_format("tensor summary: %t%", tensor,
           placeholder="%t%")
@@ -357,21 +357,21 @@ class StringFormatOpTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes()
   def testTensorCountMustMatchPlaceholderCount(self):
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(
           ValueError, r"2 placeholder\(s\) in template does not match 1 "
           r"tensor\(s\) provided as input"):
         tensor = math_ops.range(10)
         format_output = string_ops.string_format("{} {}", tensor)
         self.evaluate(format_output)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(
           ValueError, r"2 placeholder\(s\) in template does not match 1 "
          r"tensor\(s\) provided as input"):
        tensor = math_ops.range(10)
        format_output = string_ops.string_format("{} {}", [tensor])
        self.evaluate(format_output)
-    with self.test_session():
+    with self.cached_session():
       with self.assertRaisesRegexp(
           ValueError, r"1 placeholder\(s\) in template does not match 2 "
          r"tensor\(s\) provided as input"):


@@ -41,7 +41,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     x = constant_op.constant(2.)
     ret = while_loop_v2(lambda v: v < 8., lambda v: v * v, [x])
     grad = gradients_impl.gradients(ret, [x])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(sess.run(ret), 16.)
       self.assertSequenceEqual(sess.run(grad), [32.])
@@ -58,7 +58,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
     grad = gradients_impl.gradients(ret, [x])  # [2*x*y]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertSequenceEqual(sess.run(ret), [45., 3.])
       self.assertSequenceEqual(sess.run(grad), [9.])
@@ -81,7 +81,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     grady_0 = gradients_impl.gradients(ret[0], [y])  # [2*x*y + x**2]
     grady_1 = gradients_impl.gradients(ret[1], [y])  # [x + 1]
     grady_2 = gradients_impl.gradients(ret, [y])  # [2*x*y + x**2 + x + 1]
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertSequenceEqual(sess.run(ret), [120., 23.])
       self.assertSequenceEqual(sess.run(gradx_0), [39.])
       self.assertSequenceEqual(sess.run(gradx_1), [4.])
@@ -96,7 +96,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     ret2 = while_loop_v2(lambda v: v < 16., lambda v: v * v, ret1)  # x**4
     grad = gradients_impl.gradients(ret2, [x])  # 4x**3
     grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertSequenceEqual(sess.run(grad), [32.])
       self.assertSequenceEqual(sess.run(grad_grad), [48.])
@@ -105,7 +105,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     ret = while_loop_v2(lambda v: v < 8., lambda v: v**2, [x])  # x**4
     grad = gradients_impl.gradients(ret, [x])  # 4x**3
     grad_grad = gradients_impl.gradients(grad, [x])  # 12x**2
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(sess.run(ret), 16.)
       self.assertSequenceEqual(sess.run(grad), [32.])
       self.assertSequenceEqual(sess.run(grad_grad), [48.])
@@ -148,7 +148,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     y = constant_op.constant(1.)
     ret = while_loop_v2(lambda v: v + y < 9., lambda v: v * 3., [x])
     grad = gradients_impl.gradients(ret, [x])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(sess.run(ret), 18.)
       self.assertSequenceEqual(sess.run(grad), [9.])
@@ -157,7 +157,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     y = constant_op.constant(3.)
     ret = while_loop_v2(lambda v: v < 8., lambda v: v * y, [x])
     grad = gradients_impl.gradients(ret, [x])
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
       self.assertEqual(sess.run(ret), 18.)
       self.assertSequenceEqual(sess.run(grad), [9.])
@@ -178,7 +178,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     ret = while_loop_v2(Cond, Body, [x, tensor_list])
     grad = gradients_impl.gradients(ret[0], x)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
      self.assertEqual(sess.run(ret[0]), 16.)
      self.assertSequenceEqual(sess.run(grad), [32.])
@@ -212,7 +212,7 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
     self.assertEqual(accumulator_count, 1)
     grad = gradients_impl.gradients(ret[0], x)
-    with self.test_session() as sess:
+    with self.cached_session() as sess:
      self.assertEqual(sess.run(ret[0]), 16.)
      self.assertSequenceEqual(sess.run(grad), [32.])


@@ -3673,7 +3673,7 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
     # Note: There are multiple versions of non_max_suppression v2, v3, v4.
     # gen_image_ops.non_max_suppression_v2:
     for dtype in [np.float16, np.float32]:
-      with self.test_session():
+      with self.cached_session():
         boxes = constant_op.constant(boxes_np, dtype=dtype)
         scores = constant_op.constant(scores_np, dtype=dtype)
         max_output_size = constant_op.constant(max_output_size_np)
@@ -3683,7 +3683,7 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
         self.assertAllClose(selected_indices, [3, 0, 5])
     # image_ops.non_max_suppression = gen_image_ops.non_max_suppression_v3.
     for dtype in [np.float16, np.float32]:
-      with self.test_session():
+      with self.cached_session():
         boxes = constant_op.constant(boxes_np, dtype=dtype)
         scores = constant_op.constant(scores_np, dtype=dtype)
         max_output_size = constant_op.constant(max_output_size_np)
@@ -3694,7 +3694,7 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
     # gen_image_ops.non_max_suppression_v4.
     score_threshold = float('-inf')
     for dtype in [np.float16, np.float32]:
-      with self.test_session():
+      with self.cached_session():
        boxes = constant_op.constant(boxes_np, dtype=dtype)
        scores = constant_op.constant(scores_np, dtype=dtype)
        max_output_size = constant_op.constant(max_output_size_np)


@@ -218,7 +218,7 @@ class FtrlOptimizerTest(test.TestCase):
   def testFtrlWithL1_L2_L2ShrinkageSparse(self):
     """Tests the new FTRL op with support for l2 shrinkage on sparse grads."""
     for dtype in [dtypes.half, dtypes.float32]:
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
         var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
         var1 = variables.Variable([[4.0], [3.0]], dtype=dtype)
         grads0 = ops.IndexedSlices(
@@ -252,7 +252,7 @@ class FtrlOptimizerTest(test.TestCase):
   def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self):
     """Verifies that l2 shrinkage in FTRL does not change lr schedule."""
     for dtype in [dtypes.half, dtypes.float32]:
-      with self.test_session() as sess:
+      with self.cached_session() as sess:
        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
        var1 = variables.Variable([1.0, 2.0], dtype=dtype)
        grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)


@@ -62,7 +62,7 @@ class LRDecayTestV2(test_util.TensorFlowTestCase):
     self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
   def testVariables(self):
-    with self.test_session():
+    with self.cached_session():
      step = variables.Variable(1)
      assign_1 = step.assign(1)
      assign_2 = step.assign(2)