Use first worker as default device in tf_function_test.

PiperOrigin-RevId: 315316217
Change-Id: I45565148d19756eb8a6038549ea4df5f20c05c70
This commit is contained in:
Bruce Fontaine 2020-06-08 11:27:48 -07:00 committed by TensorFlower Gardener
parent a5e5e94904
commit be20584437

View File

@ -50,7 +50,11 @@ class TFFunctionTest(test.TestCase, parameterized.TestCase):
self, distribution, run_functions_eagerly):
def_function.run_functions_eagerly(run_functions_eagerly)
expected_device = (device_util.canonicalize("cpu:0")
try:
worker = distribution.extended.worker_devices[0]
except RuntimeError:
worker = None
expected_device = (device_util.canonicalize("cpu:0", worker)
if run_functions_eagerly else "")
with distribution.scope():
with ops.device_v2("cpu:0"):
@ -72,7 +76,11 @@ class TFFunctionTest(test.TestCase, parameterized.TestCase):
self, distribution, run_functions_eagerly):
def_function.run_functions_eagerly(run_functions_eagerly)
expected_device = (device_util.canonicalize("cpu:0")
try:
worker = distribution.extended.worker_devices[0]
except RuntimeError:
worker = None
expected_device = (device_util.canonicalize("cpu:0", worker)
if run_functions_eagerly else "")
with distribution.scope():
@def_function.function