Add tf.function test for device placement logging
PiperOrigin-RevId: 312789434 Change-Id: I26b4f34546cfe759a484a7f2b5b0bb234512d333
This commit is contained in:
parent
221af69be0
commit
21fdbbb07f
|
@ -34,6 +34,7 @@ from tensorflow.core.lib.core import error_codes_pb2
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import config
|
from tensorflow.python.framework import config
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import device as framework_device_lib
|
from tensorflow.python.framework import device as framework_device_lib
|
||||||
|
@ -1911,8 +1912,8 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self._output
|
return self._output
|
||||||
|
|
||||||
|
context.set_log_device_placement(True)
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
context.set_log_device_placement(True)
|
|
||||||
with CaptureStderr() as log:
|
with CaptureStderr() as log:
|
||||||
a = constant_op.constant(1)
|
a = constant_op.constant(1)
|
||||||
b = constant_op.constant(2)
|
b = constant_op.constant(2)
|
||||||
|
@ -1939,6 +1940,22 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||||
add_executions = [l for l in str(log).splitlines() if 'AddV2' in l]
|
add_executions = [l for l in str(log).splitlines() if 'AddV2' in l]
|
||||||
self.assertEqual(len(add_executions), 2)
|
self.assertEqual(len(add_executions), 2)
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def fn():
|
||||||
|
a = constant_op.constant(1)
|
||||||
|
b = constant_op.constant(2)
|
||||||
|
c = a + b
|
||||||
|
d = a + b
|
||||||
|
return c, d
|
||||||
|
|
||||||
|
with CaptureStderr() as log:
|
||||||
|
c, d = self.evaluate(fn())
|
||||||
|
self.assertEqual(c, 3)
|
||||||
|
self.assertEqual(d, 3)
|
||||||
|
# Ensure that we did log device placement.
|
||||||
|
add_executions = [l for l in str(log).splitlines() if 'AddV2' in l]
|
||||||
|
self.assertEqual(len(add_executions), 2)
|
||||||
|
|
||||||
@test_util.run_v1_only('b/120545219')
|
@test_util.run_v1_only('b/120545219')
|
||||||
def testLocalMasterSessionTimeout(self):
|
def testLocalMasterSessionTimeout(self):
|
||||||
# Test that the timeout passed in a config to the session works correctly.
|
# Test that the timeout passed in a config to the session works correctly.
|
||||||
|
|
Loading…
Reference in New Issue