Clean up tf.print tests

- Remove with cached_session blocks
- Use run_all_in_graph_and_eager_modes on entire class
- Use assertIn instead of assertTrue

PiperOrigin-RevId: 292730904
Change-Id: Iabc6e321a02cda4e9b2b6604c717a6d9a129d2a0
This commit is contained in:
Gaurav Jain 2020-02-01 14:27:14 -08:00 committed by TensorFlower Gardener
parent adf769043f
commit 0aba717531

View File

@ -69,244 +69,209 @@ class LoggingOpsTest(test.TestCase):
self.evaluate(out) self.evaluate(out)
@test_util.run_all_in_graph_and_eager_modes
class PrintV2Test(test.TestCase): class PrintV2Test(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def testPrintOneTensor(self): def testPrintOneTensor(self):
with self.cached_session(): tensor = math_ops.range(10)
tensor = math_ops.range(10) with self.captureWritesToStream(sys.stderr) as printed:
with self.captureWritesToStream(sys.stderr) as printed: print_op = logging_ops.print_v2(tensor)
print_op = logging_ops.print_v2(tensor) self.evaluate(print_op)
self.evaluate(print_op)
expected = "[0 1 2 ... 7 8 9]" expected = "[0 1 2 ... 7 8 9]"
self.assertTrue((expected + "\n") in printed.contents()) self.assertIn((expected + "\n"), printed.contents())
@test_util.run_in_graph_and_eager_modes()
def testPrintOneStringTensor(self): def testPrintOneStringTensor(self):
with self.cached_session(): tensor = ops.convert_to_tensor([char for char in string.ascii_lowercase])
tensor = ops.convert_to_tensor([char for char in string.ascii_lowercase]) with self.captureWritesToStream(sys.stderr) as printed:
with self.captureWritesToStream(sys.stderr) as printed: print_op = logging_ops.print_v2(tensor)
print_op = logging_ops.print_v2(tensor) self.evaluate(print_op)
self.evaluate(print_op)
expected = "[\"a\" \"b\" \"c\" ... \"x\" \"y\" \"z\"]" expected = "[\"a\" \"b\" \"c\" ... \"x\" \"y\" \"z\"]"
self.assertIn((expected + "\n"), printed.contents()) self.assertIn((expected + "\n"), printed.contents())
@test_util.run_in_graph_and_eager_modes()
def testPrintOneTensorVarySummarize(self): def testPrintOneTensorVarySummarize(self):
with self.cached_session(): tensor = math_ops.range(10)
tensor = math_ops.range(10) with self.captureWritesToStream(sys.stderr) as printed:
with self.captureWritesToStream(sys.stderr) as printed: print_op = logging_ops.print_v2(tensor, summarize=1)
print_op = logging_ops.print_v2(tensor, summarize=1) self.evaluate(print_op)
self.evaluate(print_op)
expected = "[0 ... 9]" expected = "[0 ... 9]"
self.assertTrue((expected + "\n") in printed.contents()) self.assertIn((expected + "\n"), printed.contents())
with self.cached_session(): tensor = math_ops.range(10)
tensor = math_ops.range(10) with self.captureWritesToStream(sys.stderr) as printed:
with self.captureWritesToStream(sys.stderr) as printed: print_op = logging_ops.print_v2(tensor, summarize=2)
print_op = logging_ops.print_v2(tensor, summarize=2) self.evaluate(print_op)
self.evaluate(print_op)
expected = "[0 1 ... 8 9]" expected = "[0 1 ... 8 9]"
self.assertTrue((expected + "\n") in printed.contents()) self.assertIn((expected + "\n"), printed.contents())
with self.cached_session(): tensor = math_ops.range(10)
tensor = math_ops.range(10) with self.captureWritesToStream(sys.stderr) as printed:
with self.captureWritesToStream(sys.stderr) as printed: print_op = logging_ops.print_v2(tensor, summarize=3)
print_op = logging_ops.print_v2(tensor, summarize=3) self.evaluate(print_op)
self.evaluate(print_op)
expected = "[0 1 2 ... 7 8 9]" expected = "[0 1 2 ... 7 8 9]"
self.assertTrue((expected + "\n") in printed.contents()) self.assertIn((expected + "\n"), printed.contents())
with self.cached_session(): tensor = math_ops.range(10)
tensor = math_ops.range(10) with self.captureWritesToStream(sys.stderr) as printed:
with self.captureWritesToStream(sys.stderr) as printed: print_op = logging_ops.print_v2(tensor, summarize=-1)
print_op = logging_ops.print_v2(tensor, summarize=-1) self.evaluate(print_op)
self.evaluate(print_op)
expected = "[0 1 2 3 4 5 6 7 8 9]" expected = "[0 1 2 3 4 5 6 7 8 9]"
self.assertTrue((expected + "\n") in printed.contents()) self.assertIn((expected + "\n"), printed.contents())
@test_util.run_in_graph_and_eager_modes()
def testPrintOneVariable(self): def testPrintOneVariable(self):
with self.cached_session(): var = variables.Variable(math_ops.range(10))
var = variables.Variable(math_ops.range(10)) if not context.executing_eagerly():
if not context.executing_eagerly(): variables.global_variables_initializer().run()
variables.global_variables_initializer().run() with self.captureWritesToStream(sys.stderr) as printed:
with self.captureWritesToStream(sys.stderr) as printed: print_op = logging_ops.print_v2(var)
print_op = logging_ops.print_v2(var) self.evaluate(print_op)
self.evaluate(print_op) expected = "[0 1 2 ... 7 8 9]"
expected = "[0 1 2 ... 7 8 9]" self.assertIn((expected + "\n"), printed.contents())
self.assertTrue((expected + "\n") in printed.contents())
@test_util.run_in_graph_and_eager_modes()
def testPrintTwoVariablesInStructWithAssignAdd(self): def testPrintTwoVariablesInStructWithAssignAdd(self):
with self.cached_session(): var_one = variables.Variable(2.14)
var_one = variables.Variable(2.14) plus_one = var_one.assign_add(1.0)
plus_one = var_one.assign_add(1.0) var_two = variables.Variable(math_ops.range(10))
var_two = variables.Variable(math_ops.range(10)) if not context.executing_eagerly():
if not context.executing_eagerly(): variables.global_variables_initializer().run()
variables.global_variables_initializer().run() with self.captureWritesToStream(sys.stderr) as printed:
with self.captureWritesToStream(sys.stderr) as printed: self.evaluate(plus_one)
self.evaluate(plus_one) print_op = logging_ops.print_v2(var_one, {"second": var_two})
print_op = logging_ops.print_v2(var_one, {"second": var_two}) self.evaluate(print_op)
self.evaluate(print_op) expected = "3.14 {'second': [0 1 2 ... 7 8 9]}"
expected = "3.14 {'second': [0 1 2 ... 7 8 9]}" self.assertIn((expected + "\n"), printed.contents())
self.assertTrue((expected + "\n") in printed.contents())
@test_util.run_in_graph_and_eager_modes()
def testPrintTwoTensors(self): def testPrintTwoTensors(self):
with self.cached_session(): tensor = math_ops.range(10)
tensor = math_ops.range(10) with self.captureWritesToStream(sys.stderr) as printed:
with self.captureWritesToStream(sys.stderr) as printed: print_op = logging_ops.print_v2(tensor, tensor * 10)
print_op = logging_ops.print_v2(tensor, tensor * 10) self.evaluate(print_op)
self.evaluate(print_op) expected = "[0 1 2 ... 7 8 9] [0 10 20 ... 70 80 90]"
expected = "[0 1 2 ... 7 8 9] [0 10 20 ... 70 80 90]" self.assertIn((expected + "\n"), printed.contents())
self.assertTrue((expected + "\n") in printed.contents())
@test_util.run_in_graph_and_eager_modes()
def testPrintTwoTensorsDifferentSep(self): def testPrintTwoTensorsDifferentSep(self):
with self.cached_session(): tensor = math_ops.range(10)
tensor = math_ops.range(10) with self.captureWritesToStream(sys.stderr) as printed:
with self.captureWritesToStream(sys.stderr) as printed: print_op = logging_ops.print_v2(tensor, tensor * 10, sep="<separator>")
print_op = logging_ops.print_v2(tensor, tensor * 10, sep="<separator>") self.evaluate(print_op)
self.evaluate(print_op) expected = "[0 1 2 ... 7 8 9]<separator>[0 10 20 ... 70 80 90]"
expected = "[0 1 2 ... 7 8 9]<separator>[0 10 20 ... 70 80 90]" self.assertIn(expected + "\n", printed.contents())
self.assertIn(expected + "\n", printed.contents())
@test_util.run_in_graph_and_eager_modes()
def testPrintPlaceholderGeneration(self): def testPrintPlaceholderGeneration(self):
with self.cached_session(): tensor = math_ops.range(10)
tensor = math_ops.range(10) with self.captureWritesToStream(sys.stderr) as printed:
with self.captureWritesToStream(sys.stderr) as printed: print_op = logging_ops.print_v2("{}6", {"{}": tensor * 10})
print_op = logging_ops.print_v2("{}6", {"{}": tensor * 10}) self.evaluate(print_op)
self.evaluate(print_op) expected = "{}6 {'{}': [0 10 20 ... 70 80 90]}"
expected = "{}6 {'{}': [0 10 20 ... 70 80 90]}" self.assertIn((expected + "\n"), printed.contents())
self.assertTrue((expected + "\n") in printed.contents())
@test_util.run_in_graph_and_eager_modes()
def testPrintNoTensors(self): def testPrintNoTensors(self):
with self.cached_session(): with self.captureWritesToStream(sys.stderr) as printed:
with self.captureWritesToStream(sys.stderr) as printed: print_op = logging_ops.print_v2(23, [23, 5], {"6": 12})
print_op = logging_ops.print_v2(23, [23, 5], {"6": 12}) self.evaluate(print_op)
self.evaluate(print_op) expected = "23 [23, 5] {'6': 12}"
expected = "23 [23, 5] {'6': 12}" self.assertIn((expected + "\n"), printed.contents())
self.assertTrue((expected + "\n") in printed.contents())
@test_util.run_in_graph_and_eager_modes()
def testPrintFloatScalar(self): def testPrintFloatScalar(self):
with self.cached_session(): tensor = ops.convert_to_tensor(434.43)
tensor = ops.convert_to_tensor(434.43) with self.captureWritesToStream(sys.stderr) as printed:
with self.captureWritesToStream(sys.stderr) as printed: print_op = logging_ops.print_v2(tensor)
print_op = logging_ops.print_v2(tensor) self.evaluate(print_op)
self.evaluate(print_op) expected = "434.43"
expected = "434.43" self.assertIn((expected + "\n"), printed.contents())
self.assertTrue((expected + "\n") in printed.contents())
@test_util.run_in_graph_and_eager_modes()
def testPrintStringScalar(self): def testPrintStringScalar(self):
with self.cached_session(): tensor = ops.convert_to_tensor("scalar")
tensor = ops.convert_to_tensor("scalar") with self.captureWritesToStream(sys.stderr) as printed:
with self.captureWritesToStream(sys.stderr) as printed: print_op = logging_ops.print_v2(tensor)
print_op = logging_ops.print_v2(tensor) self.evaluate(print_op)
self.evaluate(print_op) expected = "scalar"
expected = "scalar" self.assertIn((expected + "\n"), printed.contents())
self.assertTrue((expected + "\n") in printed.contents())
@test_util.run_in_graph_and_eager_modes()
def testPrintStringScalarDifferentEnd(self): def testPrintStringScalarDifferentEnd(self):
with self.cached_session(): tensor = ops.convert_to_tensor("scalar")
tensor = ops.convert_to_tensor("scalar") with self.captureWritesToStream(sys.stderr) as printed:
with self.captureWritesToStream(sys.stderr) as printed: print_op = logging_ops.print_v2(tensor, end="<customend>")
print_op = logging_ops.print_v2(tensor, end="<customend>") self.evaluate(print_op)
self.evaluate(print_op) expected = "scalar<customend>"
expected = "scalar<customend>" self.assertIn(expected, printed.contents())
self.assertIn(expected, printed.contents())
@test_util.run_in_graph_and_eager_modes()
def testPrintComplexTensorStruct(self): def testPrintComplexTensorStruct(self):
with self.cached_session(): tensor = math_ops.range(10)
tensor = math_ops.range(10) small_tensor = constant_op.constant([0.3, 12.4, -16.1])
small_tensor = constant_op.constant([0.3, 12.4, -16.1]) big_tensor = math_ops.mul(tensor, 10)
big_tensor = math_ops.mul(tensor, 10) with self.captureWritesToStream(sys.stderr) as printed:
with self.captureWritesToStream(sys.stderr) as printed: print_op = logging_ops.print_v2(
print_op = logging_ops.print_v2( "first:", tensor, "middle:",
"first:", tensor, "middle:", {"small": small_tensor, "Big": big_tensor}, 10,
{"small": small_tensor, "Big": big_tensor}, 10, [tensor * 2, tensor])
[tensor * 2, tensor]) self.evaluate(print_op)
self.evaluate(print_op) # Note that the keys in the dict will always be sorted,
# Note that the keys in the dict will always be sorted, # so 'Big' comes before 'small'
# so 'Big' comes before 'small' expected = ("first: [0 1 2 ... 7 8 9] "
expected = ("first: [0 1 2 ... 7 8 9] " "middle: {'Big': [0 10 20 ... 70 80 90], "
"middle: {'Big': [0 10 20 ... 70 80 90], " "'small': [0.3 12.4 -16.1]} "
"'small': [0.3 12.4 -16.1]} " "10 [[0 2 4 ... 14 16 18], [0 1 2 ... 7 8 9]]")
"10 [[0 2 4 ... 14 16 18], [0 1 2 ... 7 8 9]]") self.assertIn((expected + "\n"), printed.contents())
self.assertTrue((expected + "\n") in printed.contents())
@test_util.run_in_graph_and_eager_modes()
def testPrintSparseTensor(self): def testPrintSparseTensor(self):
with self.cached_session(): ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]] val = [0, 10, 13, 4, 14, 32, 33]
val = [0, 10, 13, 4, 14, 32, 33] shape = [5, 6]
shape = [5, 6]
sparse = sparse_tensor.SparseTensor( sparse = sparse_tensor.SparseTensor(
constant_op.constant(ind, dtypes.int64), constant_op.constant(ind, dtypes.int64),
constant_op.constant(val, dtypes.int64), constant_op.constant(val, dtypes.int64),
constant_op.constant(shape, dtypes.int64)) constant_op.constant(shape, dtypes.int64))
with self.captureWritesToStream(sys.stderr) as printed: with self.captureWritesToStream(sys.stderr) as printed:
print_op = logging_ops.print_v2(sparse) print_op = logging_ops.print_v2(sparse)
self.evaluate(print_op) self.evaluate(print_op)
expected = ("'SparseTensor(indices=[[0 0]\n" expected = ("'SparseTensor(indices=[[0 0]\n"
" [1 0]\n" " [1 0]\n"
" [1 3]\n" " [1 3]\n"
" ...\n" " ...\n"
" [1 4]\n" " [1 4]\n"
" [3 2]\n" " [3 2]\n"
" [3 3]], values=[0 10 13 ... 14 32 33], shape=[5 6])'") " [3 3]], values=[0 10 13 ... 14 32 33], shape=[5 6])'")
self.assertTrue((expected + "\n") in printed.contents()) self.assertIn((expected + "\n"), printed.contents())
@test_util.run_in_graph_and_eager_modes()
def testPrintSparseTensorInDataStruct(self): def testPrintSparseTensorInDataStruct(self):
with self.cached_session(): ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]] val = [0, 10, 13, 4, 14, 32, 33]
val = [0, 10, 13, 4, 14, 32, 33] shape = [5, 6]
shape = [5, 6]
sparse = sparse_tensor.SparseTensor( sparse = sparse_tensor.SparseTensor(
constant_op.constant(ind, dtypes.int64), constant_op.constant(ind, dtypes.int64),
constant_op.constant(val, dtypes.int64), constant_op.constant(val, dtypes.int64),
constant_op.constant(shape, dtypes.int64)) constant_op.constant(shape, dtypes.int64))
with self.captureWritesToStream(sys.stderr) as printed: with self.captureWritesToStream(sys.stderr) as printed:
print_op = logging_ops.print_v2([sparse]) print_op = logging_ops.print_v2([sparse])
self.evaluate(print_op) self.evaluate(print_op)
expected = ("['SparseTensor(indices=[[0 0]\n" expected = ("['SparseTensor(indices=[[0 0]\n"
" [1 0]\n" " [1 0]\n"
" [1 3]\n" " [1 3]\n"
" ...\n" " ...\n"
" [1 4]\n" " [1 4]\n"
" [3 2]\n" " [3 2]\n"
" [3 3]], values=[0 10 13 ... 14 32 33], shape=[5 6])']") " [3 3]], values=[0 10 13 ... 14 32 33], shape=[5 6])']")
self.assertTrue((expected + "\n") in printed.contents()) self.assertIn((expected + "\n"), printed.contents())
@test_util.run_in_graph_and_eager_modes()
def testPrintOneTensorStdout(self): def testPrintOneTensorStdout(self):
with self.cached_session(): tensor = math_ops.range(10)
tensor = math_ops.range(10) with self.captureWritesToStream(sys.stdout) as printed:
with self.captureWritesToStream(sys.stdout) as printed: print_op = logging_ops.print_v2(
print_op = logging_ops.print_v2( tensor, output_stream=sys.stdout)
tensor, output_stream=sys.stdout) self.evaluate(print_op)
self.evaluate(print_op) expected = "[0 1 2 ... 7 8 9]"
expected = "[0 1 2 ... 7 8 9]" self.assertIn((expected + "\n"), printed.contents())
self.assertTrue((expected + "\n") in printed.contents())
@test_util.run_in_graph_and_eager_modes()
def testPrintTensorsToFile(self): def testPrintTensorsToFile(self):
tmpfile_name = tempfile.mktemp(".printv2_test") tmpfile_name = tempfile.mktemp(".printv2_test")
tensor_0 = math_ops.range(0, 10) tensor_0 = math_ops.range(0, 10)
@ -330,42 +295,37 @@ class PrintV2Test(test.TestCase):
except IOError as e: except IOError as e:
self.fail(e) self.fail(e)
@test_util.run_in_graph_and_eager_modes()
def testInvalidOutputStreamRaisesError(self): def testInvalidOutputStreamRaisesError(self):
with self.cached_session(): tensor = math_ops.range(10)
tensor = math_ops.range(10) with self.assertRaises(ValueError):
with self.assertRaises(ValueError): print_op = logging_ops.print_v2(
print_op = logging_ops.print_v2( tensor, output_stream="unknown")
tensor, output_stream="unknown") self.evaluate(print_op)
self.evaluate(print_op)
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testPrintOpName(self): def testPrintOpName(self):
with self.cached_session(): tensor = math_ops.range(10)
tensor = math_ops.range(10) print_op = logging_ops.print_v2(tensor, name="print_name")
print_op = logging_ops.print_v2(tensor, name="print_name") self.assertEqual(print_op.name, "print_name")
self.assertEqual(print_op.name, "print_name")
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testNoDuplicateFormatOpGraphModeAfterExplicitFormat(self): def testNoDuplicateFormatOpGraphModeAfterExplicitFormat(self):
with self.cached_session(): tensor = math_ops.range(10)
tensor = math_ops.range(10) formatted_string = string_ops.string_format("{}", tensor)
formatted_string = string_ops.string_format("{}", tensor) print_op = logging_ops.print_v2(formatted_string)
print_op = logging_ops.print_v2(formatted_string) self.evaluate(print_op)
self.evaluate(print_op) graph_ops = ops.get_default_graph().get_operations()
graph_ops = ops.get_default_graph().get_operations() format_ops = [op for op in graph_ops if op.type == "StringFormat"]
format_ops = [op for op in graph_ops if op.type == "StringFormat"] # Should be only 1 format_op for graph mode.
# Should be only 1 format_op for graph mode. self.assertEqual(len(format_ops), 1)
self.assertEqual(len(format_ops), 1)
def testPrintOneTensorEagerOnOpCreate(self): def testPrintOneTensorEagerOnOpCreate(self):
with self.cached_session(): with context.eager_mode():
with context.eager_mode(): tensor = math_ops.range(10)
tensor = math_ops.range(10) expected = "[0 1 2 ... 7 8 9]"
expected = "[0 1 2 ... 7 8 9]" with self.captureWritesToStream(sys.stderr) as printed:
with self.captureWritesToStream(sys.stderr) as printed: logging_ops.print_v2(tensor)
logging_ops.print_v2(tensor) self.assertIn((expected + "\n"), printed.contents())
self.assertTrue((expected + "\n") in printed.contents())
def testPrintsOrderedInDefun(self): def testPrintsOrderedInDefun(self):
with context.eager_mode(): with context.eager_mode():
@ -378,9 +338,8 @@ class PrintV2Test(test.TestCase):
with self.captureWritesToStream(sys.stderr) as printed: with self.captureWritesToStream(sys.stderr) as printed:
prints() prints()
self.assertTrue(("A\nB\nC\n") in printed.contents()) self.assertTrue(("A\nB\nC\n"), printed.contents())
@test_util.run_in_graph_and_eager_modes()
def testPrintInDefunWithoutExplicitEvalOfPrint(self): def testPrintInDefunWithoutExplicitEvalOfPrint(self):
@function.defun @function.defun
def f(): def f():
@ -392,14 +351,14 @@ class PrintV2Test(test.TestCase):
with self.captureWritesToStream(sys.stderr) as printed_one: with self.captureWritesToStream(sys.stderr) as printed_one:
x = f() x = f()
self.evaluate(x) self.evaluate(x)
self.assertTrue((expected + "\n") in printed_one.contents()) self.assertIn((expected + "\n"), printed_one.contents())
# We execute the function again to make sure it doesn't only print on the # We execute the function again to make sure it doesn't only print on the
# first call. # first call.
with self.captureWritesToStream(sys.stderr) as printed_two: with self.captureWritesToStream(sys.stderr) as printed_two:
y = f() y = f()
self.evaluate(y) self.evaluate(y)
self.assertTrue((expected + "\n") in printed_two.contents()) self.assertIn((expected + "\n"), printed_two.contents())
class PrintGradientTest(test.TestCase): class PrintGradientTest(test.TestCase):
@ -417,15 +376,14 @@ class PrintGradientTest(test.TestCase):
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testPrintGradient(self): def testPrintGradient(self):
with self.cached_session(): inp = constant_op.constant(2.0, shape=[100, 32], name="in")
inp = constant_op.constant(2.0, shape=[100, 32], name="in") w = constant_op.constant(4.0, shape=[10, 100], name="w")
w = constant_op.constant(4.0, shape=[10, 100], name="w") wx = math_ops.matmul(w, inp, name="wx")
wx = math_ops.matmul(w, inp, name="wx") wx_print = logging_ops.Print(wx, [w, w, w])
wx_print = logging_ops.Print(wx, [w, w, w]) wx_grad = gradients_impl.gradients(wx, w)[0]
wx_grad = gradients_impl.gradients(wx, w)[0] wx_print_grad = gradients_impl.gradients(wx_print, w)[0]
wx_print_grad = gradients_impl.gradients(wx_print, w)[0] wxg = self.evaluate(wx_grad)
wxg = self.evaluate(wx_grad) wxpg = self.evaluate(wx_print_grad)
wxpg = self.evaluate(wx_print_grad)
self.assertAllEqual(wxg, wxpg) self.assertAllEqual(wxg, wxpg)