diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 659d29601e8..c837f1de14b 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -516,6 +516,22 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets( if (on_same_task) { continue; } + // default_device may disambiguate the match: count the matching + // devices that share its address space. + int colocated_on_default_device = 0; + for (int i = 0; i < matching_devices.size(); ++i) { + if (DeviceNameUtils::IsSameAddressSpace( + default_device->parsed_name(), + matching_devices.at(i)->parsed_name())) { + colocated_on_default_device++; + } + } + // Exactly one such device: skip the error below. Zero or several + // such devices still fall through and raise it. + if (colocated_on_default_device == 1) { + continue; + } + // Convert a vector of devices to a string. // Using absl::StrJoin did not work in Android builds. 
string devices = "["; diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 66b0771f3d1..6e8d4363505 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -1046,6 +1046,7 @@ cuda_py_test( "//tensorflow/python:functional_ops", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:tensor_spec", + "//tensorflow/python:test_ops", "@absl_py//absl/testing:parameterized", "@six_archive//:six", ], diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 5c8ee1459f2..92e9c17df12 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -1795,6 +1795,18 @@ class FunctionTest(test.TestCase, parameterized.TestCase): has_device.f() self.assertIn('CPU', has_device.v.device) + @test_util.run_in_graph_and_eager_modes + def testMultipleDeviceCheck(self): + + def f(): + with ops.device('cpu'): + return test_ops.device_placement_op() + + func = function.defun(f) + with ops.device('cpu:0'): + output = self.evaluate(func()) + self.assertIn(compat.as_bytes('CPU:0'), output) + @test_util.run_in_graph_and_eager_modes def testDeviceAnnotationsRespected(self): diff --git a/tensorflow/python/eager/remote_test.py b/tensorflow/python/eager/remote_test.py index ebe70d9b964..2b23a15abcc 100644 --- a/tensorflow/python/eager/remote_test.py +++ b/tensorflow/python/eager/remote_test.py @@ -38,6 +38,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -45,9 +46,11 @@ from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops from 
tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import string_ops from tensorflow.python.ops import variables from tensorflow.python.training import server_lib from tensorflow.python.training.server_lib import ClusterSpec +from tensorflow.python.util import compat class SingleWorkerTest(test.TestCase, parameterized.TestCase): @@ -112,20 +115,6 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase): self.assertEqual(rets[1].backing_device, '/job:worker/replica:0/task:0/device:CPU:0') - def testMultiDeviceFunctionAmbiguousDevice(self): - - @def_function.function - def ambiguous_device(i): - with ops.device('cpu:0'): - return i + constant_op.constant([2]) - - with self.assertRaises(errors.InvalidArgumentError) as cm: - with ops.device('/job:worker/replica:0/task:0/cpu:0'): - ambiguous_device(constant_op.constant([2])).numpy() - - self.assertIn('the output node must match exactly one device', - cm.exception.message) - def testStreaming(self): """A mini stress test for streaming - issuing many RPCs back to back.""" with ops.device('job:worker/replica:0/task:0/device:CPU:0'): @@ -318,6 +307,21 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase): with ops.device('/job:worker/replica:0/task:1'): self.assertAllEqual(local_func(x), [2, 1]) + def testMultiDeviceFunctionAmbiguousDevice(self): + + @def_function.function + def ambiguous_device(i): + with ops.device('/job:worker'): + # '/job:worker' matches several worker tasks, so the ambiguous-device + # error is expected to be raised. + return i + constant_op.constant([2]) + + with self.assertRaises(errors.InvalidArgumentError) as cm: + ambiguous_device(constant_op.constant([2])).numpy() + + self.assertIn('the output node must match exactly one device', + cm.exception.message) + # Note that the following tests for remote function cancellation only works # when non-streaming RPC. We need to disable streaming explicitly and restore # this config to its initial value at the end of each test case. 
@@ -579,6 +583,32 @@ class MultiJobsTest(test.TestCase, parameterized.TestCase): # Reset the context to avoid polluting other test cases. context._reset_context() + def testMultipleDeviceFoundCheck(self): + remote.connect_to_cluster(self._cluster) + + @def_function.function + def func(): + with ops.device('cpu:0'): + # 'cpu:0' alone matches several CPU:0 devices; the one under the + # caller's enclosing ops.device() scope should be picked. + x = test_ops.device_placement_op() + y = string_ops.string_upper(x) + packed_var_0 = array_ops.stack([x, y], 0) + return packed_var_0 + + with ops.device('/job:my_worker/task:1'): + output = self.evaluate(func()) + self.assertEqual( + compat.as_bytes('/job:my_worker/replica:0/task:1/device:CPU:0'), + output[0]) + self.assertIn(compat.as_bytes('/JOB:MY_WORKER'), output[1]) + with ops.device('/job:my_ps/task:1'): + output = self.evaluate(func()) + self.assertEqual( + compat.as_bytes('/job:my_ps/replica:0/task:1/device:CPU:0'), + output[0]) + self.assertIn(compat.as_bytes('/JOB:MY_PS'), output[1]) + def testSimpleParameterServer(self): remote.connect_to_cluster(self._cluster)