Add tf.function test for device placement logging
PiperOrigin-RevId: 312789434 Change-Id: I26b4f34546cfe759a484a7f2b5b0bb234512d333
This commit is contained in:
parent
221af69be0
commit
21fdbbb07f
|
@ -34,6 +34,7 @@ from tensorflow.core.lib.core import error_codes_pb2
|
|||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import device as framework_device_lib
|
||||
|
@ -1911,8 +1912,8 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||
def __str__(self):
|
||||
return self._output
|
||||
|
||||
context.set_log_device_placement(True)
|
||||
if context.executing_eagerly():
|
||||
context.set_log_device_placement(True)
|
||||
with CaptureStderr() as log:
|
||||
a = constant_op.constant(1)
|
||||
b = constant_op.constant(2)
|
||||
|
@ -1939,6 +1940,22 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||
add_executions = [l for l in str(log).splitlines() if 'AddV2' in l]
|
||||
self.assertEqual(len(add_executions), 2)
|
||||
|
||||
@def_function.function
|
||||
def fn():
|
||||
a = constant_op.constant(1)
|
||||
b = constant_op.constant(2)
|
||||
c = a + b
|
||||
d = a + b
|
||||
return c, d
|
||||
|
||||
with CaptureStderr() as log:
|
||||
c, d = self.evaluate(fn())
|
||||
self.assertEqual(c, 3)
|
||||
self.assertEqual(d, 3)
|
||||
# Ensure that we did log device placement.
|
||||
add_executions = [l for l in str(log).splitlines() if 'AddV2' in l]
|
||||
self.assertEqual(len(add_executions), 2)
|
||||
|
||||
@test_util.run_v1_only('b/120545219')
|
||||
def testLocalMasterSessionTimeout(self):
|
||||
# Test that the timeout passed in a config to the session works correctly.
|
||||
|
|
Loading…
Reference in New Issue