Add tf.function test for device placement logging

PiperOrigin-RevId: 312789434
Change-Id: I26b4f34546cfe759a484a7f2b5b0bb234512d333
This commit is contained in:
Gaurav Jain 2020-05-21 20:41:46 -07:00 committed by TensorFlower Gardener
parent 221af69be0
commit 21fdbbb07f
1 changed files with 18 additions and 1 deletions

View File

@ -34,6 +34,7 @@ from tensorflow.core.lib.core import error_codes_pb2
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as framework_device_lib from tensorflow.python.framework import device as framework_device_lib
@ -1911,8 +1912,8 @@ class SessionTest(test_util.TensorFlowTestCase):
def __str__(self): def __str__(self):
return self._output return self._output
context.set_log_device_placement(True)
if context.executing_eagerly(): if context.executing_eagerly():
context.set_log_device_placement(True)
with CaptureStderr() as log: with CaptureStderr() as log:
a = constant_op.constant(1) a = constant_op.constant(1)
b = constant_op.constant(2) b = constant_op.constant(2)
@ -1939,6 +1940,22 @@ class SessionTest(test_util.TensorFlowTestCase):
add_executions = [l for l in str(log).splitlines() if 'AddV2' in l] add_executions = [l for l in str(log).splitlines() if 'AddV2' in l]
self.assertEqual(len(add_executions), 2) self.assertEqual(len(add_executions), 2)
@def_function.function
def fn():
a = constant_op.constant(1)
b = constant_op.constant(2)
c = a + b
d = a + b
return c, d
with CaptureStderr() as log:
c, d = self.evaluate(fn())
self.assertEqual(c, 3)
self.assertEqual(d, 3)
# Ensure that we did log device placement.
add_executions = [l for l in str(log).splitlines() if 'AddV2' in l]
self.assertEqual(len(add_executions), 2)
@test_util.run_v1_only('b/120545219') @test_util.run_v1_only('b/120545219')
def testLocalMasterSessionTimeout(self): def testLocalMasterSessionTimeout(self):
# Test that the timeout passed in a config to the session works correctly. # Test that the timeout passed in a config to the session works correctly.