Use first worker as default device in tf_function_test.
PiperOrigin-RevId: 315316217 Change-Id: I45565148d19756eb8a6038549ea4df5f20c05c70
This commit is contained in:
parent
a5e5e94904
commit
be20584437
@ -50,7 +50,11 @@ class TFFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
self, distribution, run_functions_eagerly):
|
||||
|
||||
def_function.run_functions_eagerly(run_functions_eagerly)
|
||||
expected_device = (device_util.canonicalize("cpu:0")
|
||||
try:
|
||||
worker = distribution.extended.worker_devices[0]
|
||||
except RuntimeError:
|
||||
worker = None
|
||||
expected_device = (device_util.canonicalize("cpu:0", worker)
|
||||
if run_functions_eagerly else "")
|
||||
with distribution.scope():
|
||||
with ops.device_v2("cpu:0"):
|
||||
@ -72,7 +76,11 @@ class TFFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
self, distribution, run_functions_eagerly):
|
||||
|
||||
def_function.run_functions_eagerly(run_functions_eagerly)
|
||||
expected_device = (device_util.canonicalize("cpu:0")
|
||||
try:
|
||||
worker = distribution.extended.worker_devices[0]
|
||||
except RuntimeError:
|
||||
worker = None
|
||||
expected_device = (device_util.canonicalize("cpu:0", worker)
|
||||
if run_functions_eagerly else "")
|
||||
with distribution.scope():
|
||||
@def_function.function
|
||||
|
Loading…
Reference in New Issue
Block a user