Clean up tf.print tests
- Remove with cached_session blocks - Use run_all_in_graph_and_eager_modes on entire class - Use assertIn instead of assertTrue PiperOrigin-RevId: 292730904 Change-Id: Iabc6e321a02cda4e9b2b6604c717a6d9a129d2a0
This commit is contained in:
parent
adf769043f
commit
0aba717531
@ -69,22 +69,19 @@ class LoggingOpsTest(test.TestCase):
|
|||||||
self.evaluate(out)
|
self.evaluate(out)
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class PrintV2Test(test.TestCase):
|
class PrintV2Test(test.TestCase):
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
|
||||||
def testPrintOneTensor(self):
|
def testPrintOneTensor(self):
|
||||||
with self.cached_session():
|
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(tensor)
|
print_op = logging_ops.print_v2(tensor)
|
||||||
self.evaluate(print_op)
|
self.evaluate(print_op)
|
||||||
|
|
||||||
expected = "[0 1 2 ... 7 8 9]"
|
expected = "[0 1 2 ... 7 8 9]"
|
||||||
self.assertTrue((expected + "\n") in printed.contents())
|
self.assertIn((expected + "\n"), printed.contents())
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
|
||||||
def testPrintOneStringTensor(self):
|
def testPrintOneStringTensor(self):
|
||||||
with self.cached_session():
|
|
||||||
tensor = ops.convert_to_tensor([char for char in string.ascii_lowercase])
|
tensor = ops.convert_to_tensor([char for char in string.ascii_lowercase])
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(tensor)
|
print_op = logging_ops.print_v2(tensor)
|
||||||
@ -93,47 +90,40 @@ class PrintV2Test(test.TestCase):
|
|||||||
expected = "[\"a\" \"b\" \"c\" ... \"x\" \"y\" \"z\"]"
|
expected = "[\"a\" \"b\" \"c\" ... \"x\" \"y\" \"z\"]"
|
||||||
self.assertIn((expected + "\n"), printed.contents())
|
self.assertIn((expected + "\n"), printed.contents())
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
|
||||||
def testPrintOneTensorVarySummarize(self):
|
def testPrintOneTensorVarySummarize(self):
|
||||||
with self.cached_session():
|
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(tensor, summarize=1)
|
print_op = logging_ops.print_v2(tensor, summarize=1)
|
||||||
self.evaluate(print_op)
|
self.evaluate(print_op)
|
||||||
|
|
||||||
expected = "[0 ... 9]"
|
expected = "[0 ... 9]"
|
||||||
self.assertTrue((expected + "\n") in printed.contents())
|
self.assertIn((expected + "\n"), printed.contents())
|
||||||
|
|
||||||
with self.cached_session():
|
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(tensor, summarize=2)
|
print_op = logging_ops.print_v2(tensor, summarize=2)
|
||||||
self.evaluate(print_op)
|
self.evaluate(print_op)
|
||||||
|
|
||||||
expected = "[0 1 ... 8 9]"
|
expected = "[0 1 ... 8 9]"
|
||||||
self.assertTrue((expected + "\n") in printed.contents())
|
self.assertIn((expected + "\n"), printed.contents())
|
||||||
|
|
||||||
with self.cached_session():
|
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(tensor, summarize=3)
|
print_op = logging_ops.print_v2(tensor, summarize=3)
|
||||||
self.evaluate(print_op)
|
self.evaluate(print_op)
|
||||||
|
|
||||||
expected = "[0 1 2 ... 7 8 9]"
|
expected = "[0 1 2 ... 7 8 9]"
|
||||||
self.assertTrue((expected + "\n") in printed.contents())
|
self.assertIn((expected + "\n"), printed.contents())
|
||||||
|
|
||||||
with self.cached_session():
|
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(tensor, summarize=-1)
|
print_op = logging_ops.print_v2(tensor, summarize=-1)
|
||||||
self.evaluate(print_op)
|
self.evaluate(print_op)
|
||||||
|
|
||||||
expected = "[0 1 2 3 4 5 6 7 8 9]"
|
expected = "[0 1 2 3 4 5 6 7 8 9]"
|
||||||
self.assertTrue((expected + "\n") in printed.contents())
|
self.assertIn((expected + "\n"), printed.contents())
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
|
||||||
def testPrintOneVariable(self):
|
def testPrintOneVariable(self):
|
||||||
with self.cached_session():
|
|
||||||
var = variables.Variable(math_ops.range(10))
|
var = variables.Variable(math_ops.range(10))
|
||||||
if not context.executing_eagerly():
|
if not context.executing_eagerly():
|
||||||
variables.global_variables_initializer().run()
|
variables.global_variables_initializer().run()
|
||||||
@ -141,11 +131,9 @@ class PrintV2Test(test.TestCase):
|
|||||||
print_op = logging_ops.print_v2(var)
|
print_op = logging_ops.print_v2(var)
|
||||||
self.evaluate(print_op)
|
self.evaluate(print_op)
|
||||||
expected = "[0 1 2 ... 7 8 9]"
|
expected = "[0 1 2 ... 7 8 9]"
|
||||||
self.assertTrue((expected + "\n") in printed.contents())
|
self.assertIn((expected + "\n"), printed.contents())
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
|
||||||
def testPrintTwoVariablesInStructWithAssignAdd(self):
|
def testPrintTwoVariablesInStructWithAssignAdd(self):
|
||||||
with self.cached_session():
|
|
||||||
var_one = variables.Variable(2.14)
|
var_one = variables.Variable(2.14)
|
||||||
plus_one = var_one.assign_add(1.0)
|
plus_one = var_one.assign_add(1.0)
|
||||||
var_two = variables.Variable(math_ops.range(10))
|
var_two = variables.Variable(math_ops.range(10))
|
||||||
@ -156,21 +144,17 @@ class PrintV2Test(test.TestCase):
|
|||||||
print_op = logging_ops.print_v2(var_one, {"second": var_two})
|
print_op = logging_ops.print_v2(var_one, {"second": var_two})
|
||||||
self.evaluate(print_op)
|
self.evaluate(print_op)
|
||||||
expected = "3.14 {'second': [0 1 2 ... 7 8 9]}"
|
expected = "3.14 {'second': [0 1 2 ... 7 8 9]}"
|
||||||
self.assertTrue((expected + "\n") in printed.contents())
|
self.assertIn((expected + "\n"), printed.contents())
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
|
||||||
def testPrintTwoTensors(self):
|
def testPrintTwoTensors(self):
|
||||||
with self.cached_session():
|
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(tensor, tensor * 10)
|
print_op = logging_ops.print_v2(tensor, tensor * 10)
|
||||||
self.evaluate(print_op)
|
self.evaluate(print_op)
|
||||||
expected = "[0 1 2 ... 7 8 9] [0 10 20 ... 70 80 90]"
|
expected = "[0 1 2 ... 7 8 9] [0 10 20 ... 70 80 90]"
|
||||||
self.assertTrue((expected + "\n") in printed.contents())
|
self.assertIn((expected + "\n"), printed.contents())
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
|
||||||
def testPrintTwoTensorsDifferentSep(self):
|
def testPrintTwoTensorsDifferentSep(self):
|
||||||
with self.cached_session():
|
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(tensor, tensor * 10, sep="<separator>")
|
print_op = logging_ops.print_v2(tensor, tensor * 10, sep="<separator>")
|
||||||
@ -178,48 +162,38 @@ class PrintV2Test(test.TestCase):
|
|||||||
expected = "[0 1 2 ... 7 8 9]<separator>[0 10 20 ... 70 80 90]"
|
expected = "[0 1 2 ... 7 8 9]<separator>[0 10 20 ... 70 80 90]"
|
||||||
self.assertIn(expected + "\n", printed.contents())
|
self.assertIn(expected + "\n", printed.contents())
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
|
||||||
def testPrintPlaceholderGeneration(self):
|
def testPrintPlaceholderGeneration(self):
|
||||||
with self.cached_session():
|
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2("{}6", {"{}": tensor * 10})
|
print_op = logging_ops.print_v2("{}6", {"{}": tensor * 10})
|
||||||
self.evaluate(print_op)
|
self.evaluate(print_op)
|
||||||
expected = "{}6 {'{}': [0 10 20 ... 70 80 90]}"
|
expected = "{}6 {'{}': [0 10 20 ... 70 80 90]}"
|
||||||
self.assertTrue((expected + "\n") in printed.contents())
|
self.assertIn((expected + "\n"), printed.contents())
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
|
||||||
def testPrintNoTensors(self):
|
def testPrintNoTensors(self):
|
||||||
with self.cached_session():
|
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(23, [23, 5], {"6": 12})
|
print_op = logging_ops.print_v2(23, [23, 5], {"6": 12})
|
||||||
self.evaluate(print_op)
|
self.evaluate(print_op)
|
||||||
expected = "23 [23, 5] {'6': 12}"
|
expected = "23 [23, 5] {'6': 12}"
|
||||||
self.assertTrue((expected + "\n") in printed.contents())
|
self.assertIn((expected + "\n"), printed.contents())
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
|
||||||
def testPrintFloatScalar(self):
|
def testPrintFloatScalar(self):
|
||||||
with self.cached_session():
|
|
||||||
tensor = ops.convert_to_tensor(434.43)
|
tensor = ops.convert_to_tensor(434.43)
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(tensor)
|
print_op = logging_ops.print_v2(tensor)
|
||||||
self.evaluate(print_op)
|
self.evaluate(print_op)
|
||||||
expected = "434.43"
|
expected = "434.43"
|
||||||
self.assertTrue((expected + "\n") in printed.contents())
|
self.assertIn((expected + "\n"), printed.contents())
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
|
||||||
def testPrintStringScalar(self):
|
def testPrintStringScalar(self):
|
||||||
with self.cached_session():
|
|
||||||
tensor = ops.convert_to_tensor("scalar")
|
tensor = ops.convert_to_tensor("scalar")
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(tensor)
|
print_op = logging_ops.print_v2(tensor)
|
||||||
self.evaluate(print_op)
|
self.evaluate(print_op)
|
||||||
expected = "scalar"
|
expected = "scalar"
|
||||||
self.assertTrue((expected + "\n") in printed.contents())
|
self.assertIn((expected + "\n"), printed.contents())
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
|
||||||
def testPrintStringScalarDifferentEnd(self):
|
def testPrintStringScalarDifferentEnd(self):
|
||||||
with self.cached_session():
|
|
||||||
tensor = ops.convert_to_tensor("scalar")
|
tensor = ops.convert_to_tensor("scalar")
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
print_op = logging_ops.print_v2(tensor, end="<customend>")
|
print_op = logging_ops.print_v2(tensor, end="<customend>")
|
||||||
@ -227,9 +201,7 @@ class PrintV2Test(test.TestCase):
|
|||||||
expected = "scalar<customend>"
|
expected = "scalar<customend>"
|
||||||
self.assertIn(expected, printed.contents())
|
self.assertIn(expected, printed.contents())
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
|
||||||
def testPrintComplexTensorStruct(self):
|
def testPrintComplexTensorStruct(self):
|
||||||
with self.cached_session():
|
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
small_tensor = constant_op.constant([0.3, 12.4, -16.1])
|
small_tensor = constant_op.constant([0.3, 12.4, -16.1])
|
||||||
big_tensor = math_ops.mul(tensor, 10)
|
big_tensor = math_ops.mul(tensor, 10)
|
||||||
@ -245,11 +217,9 @@ class PrintV2Test(test.TestCase):
|
|||||||
"middle: {'Big': [0 10 20 ... 70 80 90], "
|
"middle: {'Big': [0 10 20 ... 70 80 90], "
|
||||||
"'small': [0.3 12.4 -16.1]} "
|
"'small': [0.3 12.4 -16.1]} "
|
||||||
"10 [[0 2 4 ... 14 16 18], [0 1 2 ... 7 8 9]]")
|
"10 [[0 2 4 ... 14 16 18], [0 1 2 ... 7 8 9]]")
|
||||||
self.assertTrue((expected + "\n") in printed.contents())
|
self.assertIn((expected + "\n"), printed.contents())
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
|
||||||
def testPrintSparseTensor(self):
|
def testPrintSparseTensor(self):
|
||||||
with self.cached_session():
|
|
||||||
ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
|
ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
|
||||||
val = [0, 10, 13, 4, 14, 32, 33]
|
val = [0, 10, 13, 4, 14, 32, 33]
|
||||||
shape = [5, 6]
|
shape = [5, 6]
|
||||||
@ -269,11 +239,9 @@ class PrintV2Test(test.TestCase):
|
|||||||
" [1 4]\n"
|
" [1 4]\n"
|
||||||
" [3 2]\n"
|
" [3 2]\n"
|
||||||
" [3 3]], values=[0 10 13 ... 14 32 33], shape=[5 6])'")
|
" [3 3]], values=[0 10 13 ... 14 32 33], shape=[5 6])'")
|
||||||
self.assertTrue((expected + "\n") in printed.contents())
|
self.assertIn((expected + "\n"), printed.contents())
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
|
||||||
def testPrintSparseTensorInDataStruct(self):
|
def testPrintSparseTensorInDataStruct(self):
|
||||||
with self.cached_session():
|
|
||||||
ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
|
ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
|
||||||
val = [0, 10, 13, 4, 14, 32, 33]
|
val = [0, 10, 13, 4, 14, 32, 33]
|
||||||
shape = [5, 6]
|
shape = [5, 6]
|
||||||
@ -293,20 +261,17 @@ class PrintV2Test(test.TestCase):
|
|||||||
" [1 4]\n"
|
" [1 4]\n"
|
||||||
" [3 2]\n"
|
" [3 2]\n"
|
||||||
" [3 3]], values=[0 10 13 ... 14 32 33], shape=[5 6])']")
|
" [3 3]], values=[0 10 13 ... 14 32 33], shape=[5 6])']")
|
||||||
self.assertTrue((expected + "\n") in printed.contents())
|
self.assertIn((expected + "\n"), printed.contents())
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
|
||||||
def testPrintOneTensorStdout(self):
|
def testPrintOneTensorStdout(self):
|
||||||
with self.cached_session():
|
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
with self.captureWritesToStream(sys.stdout) as printed:
|
with self.captureWritesToStream(sys.stdout) as printed:
|
||||||
print_op = logging_ops.print_v2(
|
print_op = logging_ops.print_v2(
|
||||||
tensor, output_stream=sys.stdout)
|
tensor, output_stream=sys.stdout)
|
||||||
self.evaluate(print_op)
|
self.evaluate(print_op)
|
||||||
expected = "[0 1 2 ... 7 8 9]"
|
expected = "[0 1 2 ... 7 8 9]"
|
||||||
self.assertTrue((expected + "\n") in printed.contents())
|
self.assertIn((expected + "\n"), printed.contents())
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
|
||||||
def testPrintTensorsToFile(self):
|
def testPrintTensorsToFile(self):
|
||||||
tmpfile_name = tempfile.mktemp(".printv2_test")
|
tmpfile_name = tempfile.mktemp(".printv2_test")
|
||||||
tensor_0 = math_ops.range(0, 10)
|
tensor_0 = math_ops.range(0, 10)
|
||||||
@ -330,9 +295,7 @@ class PrintV2Test(test.TestCase):
|
|||||||
except IOError as e:
|
except IOError as e:
|
||||||
self.fail(e)
|
self.fail(e)
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
|
||||||
def testInvalidOutputStreamRaisesError(self):
|
def testInvalidOutputStreamRaisesError(self):
|
||||||
with self.cached_session():
|
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
print_op = logging_ops.print_v2(
|
print_op = logging_ops.print_v2(
|
||||||
@ -341,14 +304,12 @@ class PrintV2Test(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testPrintOpName(self):
|
def testPrintOpName(self):
|
||||||
with self.cached_session():
|
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
print_op = logging_ops.print_v2(tensor, name="print_name")
|
print_op = logging_ops.print_v2(tensor, name="print_name")
|
||||||
self.assertEqual(print_op.name, "print_name")
|
self.assertEqual(print_op.name, "print_name")
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testNoDuplicateFormatOpGraphModeAfterExplicitFormat(self):
|
def testNoDuplicateFormatOpGraphModeAfterExplicitFormat(self):
|
||||||
with self.cached_session():
|
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
formatted_string = string_ops.string_format("{}", tensor)
|
formatted_string = string_ops.string_format("{}", tensor)
|
||||||
print_op = logging_ops.print_v2(formatted_string)
|
print_op = logging_ops.print_v2(formatted_string)
|
||||||
@ -359,13 +320,12 @@ class PrintV2Test(test.TestCase):
|
|||||||
self.assertEqual(len(format_ops), 1)
|
self.assertEqual(len(format_ops), 1)
|
||||||
|
|
||||||
def testPrintOneTensorEagerOnOpCreate(self):
|
def testPrintOneTensorEagerOnOpCreate(self):
|
||||||
with self.cached_session():
|
|
||||||
with context.eager_mode():
|
with context.eager_mode():
|
||||||
tensor = math_ops.range(10)
|
tensor = math_ops.range(10)
|
||||||
expected = "[0 1 2 ... 7 8 9]"
|
expected = "[0 1 2 ... 7 8 9]"
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
logging_ops.print_v2(tensor)
|
logging_ops.print_v2(tensor)
|
||||||
self.assertTrue((expected + "\n") in printed.contents())
|
self.assertIn((expected + "\n"), printed.contents())
|
||||||
|
|
||||||
def testPrintsOrderedInDefun(self):
|
def testPrintsOrderedInDefun(self):
|
||||||
with context.eager_mode():
|
with context.eager_mode():
|
||||||
@ -378,9 +338,8 @@ class PrintV2Test(test.TestCase):
|
|||||||
|
|
||||||
with self.captureWritesToStream(sys.stderr) as printed:
|
with self.captureWritesToStream(sys.stderr) as printed:
|
||||||
prints()
|
prints()
|
||||||
self.assertTrue(("A\nB\nC\n") in printed.contents())
|
self.assertTrue(("A\nB\nC\n"), printed.contents())
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes()
|
|
||||||
def testPrintInDefunWithoutExplicitEvalOfPrint(self):
|
def testPrintInDefunWithoutExplicitEvalOfPrint(self):
|
||||||
@function.defun
|
@function.defun
|
||||||
def f():
|
def f():
|
||||||
@ -392,14 +351,14 @@ class PrintV2Test(test.TestCase):
|
|||||||
with self.captureWritesToStream(sys.stderr) as printed_one:
|
with self.captureWritesToStream(sys.stderr) as printed_one:
|
||||||
x = f()
|
x = f()
|
||||||
self.evaluate(x)
|
self.evaluate(x)
|
||||||
self.assertTrue((expected + "\n") in printed_one.contents())
|
self.assertIn((expected + "\n"), printed_one.contents())
|
||||||
|
|
||||||
# We execute the function again to make sure it doesn't only print on the
|
# We execute the function again to make sure it doesn't only print on the
|
||||||
# first call.
|
# first call.
|
||||||
with self.captureWritesToStream(sys.stderr) as printed_two:
|
with self.captureWritesToStream(sys.stderr) as printed_two:
|
||||||
y = f()
|
y = f()
|
||||||
self.evaluate(y)
|
self.evaluate(y)
|
||||||
self.assertTrue((expected + "\n") in printed_two.contents())
|
self.assertIn((expected + "\n"), printed_two.contents())
|
||||||
|
|
||||||
|
|
||||||
class PrintGradientTest(test.TestCase):
|
class PrintGradientTest(test.TestCase):
|
||||||
@ -417,7 +376,6 @@ class PrintGradientTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testPrintGradient(self):
|
def testPrintGradient(self):
|
||||||
with self.cached_session():
|
|
||||||
inp = constant_op.constant(2.0, shape=[100, 32], name="in")
|
inp = constant_op.constant(2.0, shape=[100, 32], name="in")
|
||||||
w = constant_op.constant(4.0, shape=[10, 100], name="w")
|
w = constant_op.constant(4.0, shape=[10, 100], name="w")
|
||||||
wx = math_ops.matmul(w, inp, name="wx")
|
wx = math_ops.matmul(w, inp, name="wx")
|
||||||
|
Loading…
Reference in New Issue
Block a user