diff --git a/tensorflow/python/distribute/tf_function_test.py b/tensorflow/python/distribute/tf_function_test.py index 5dc82cfd81b..6621f51cf32 100644 --- a/tensorflow/python/distribute/tf_function_test.py +++ b/tensorflow/python/distribute/tf_function_test.py @@ -50,7 +50,11 @@ class TFFunctionTest(test.TestCase, parameterized.TestCase): self, distribution, run_functions_eagerly): def_function.run_functions_eagerly(run_functions_eagerly) - expected_device = (device_util.canonicalize("cpu:0") + try: + worker = distribution.extended.worker_devices[0] + except RuntimeError: + worker = None + expected_device = (device_util.canonicalize("cpu:0", worker) if run_functions_eagerly else "") with distribution.scope(): with ops.device_v2("cpu:0"): @@ -72,7 +76,11 @@ class TFFunctionTest(test.TestCase, parameterized.TestCase): self, distribution, run_functions_eagerly): def_function.run_functions_eagerly(run_functions_eagerly) - expected_device = (device_util.canonicalize("cpu:0") + try: + worker = distribution.extended.worker_devices[0] + except RuntimeError: + worker = None + expected_device = (device_util.canonicalize("cpu:0", worker) if run_functions_eagerly else "") with distribution.scope(): @def_function.function