From 21fdbbb07f8ff7d27d3545d740c0bace5a3f23eb Mon Sep 17 00:00:00 2001 From: Gaurav Jain Date: Thu, 21 May 2020 20:41:46 -0700 Subject: [PATCH] Add tf.function test for device placement logging PiperOrigin-RevId: 312789434 Change-Id: I26b4f34546cfe759a484a7f2b5b0bb234512d333 --- tensorflow/python/client/session_test.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 1c244c1b297..074b50bf69b 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -34,6 +34,7 @@ from tensorflow.core.lib.core import error_codes_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.eager import context +from tensorflow.python.eager import def_function from tensorflow.python.framework import config from tensorflow.python.framework import constant_op from tensorflow.python.framework import device as framework_device_lib @@ -1911,8 +1912,8 @@ class SessionTest(test_util.TensorFlowTestCase): def __str__(self): return self._output + context.set_log_device_placement(True) if context.executing_eagerly(): - context.set_log_device_placement(True) with CaptureStderr() as log: a = constant_op.constant(1) b = constant_op.constant(2) @@ -1939,6 +1940,22 @@ class SessionTest(test_util.TensorFlowTestCase): add_executions = [l for l in str(log).splitlines() if 'AddV2' in l] self.assertEqual(len(add_executions), 2) + @def_function.function + def fn(): + a = constant_op.constant(1) + b = constant_op.constant(2) + c = a + b + d = a + b + return c, d + + with CaptureStderr() as log: + c, d = self.evaluate(fn()) + self.assertEqual(c, 3) + self.assertEqual(d, 3) + # Ensure that we did log device placement. + add_executions = [l for l in str(log).splitlines() if 'AddV2' in l] + self.assertEqual(len(add_executions), 2) + @test_util.run_v1_only('b/120545219') def testLocalMasterSessionTimeout(self): # Test that the timeout passed in a config to the session works correctly.