From 79cd2fd4ac45ab4304b7762d6c17db5774c30cf9 Mon Sep 17 00:00:00 2001
From: Isha Arkatkar <ishark@google.com>
Date: Wed, 17 Feb 2021 11:37:23 -0800
Subject: [PATCH] Update check for multiple devices found to check whether
 there is a single device colocated with the default_device.

If multiple devices are found in the cluster, but there is a single device colocated with default_device, the single colocated device would be used by placer. This change relaxes the current check to account for parent device scope.

For example,
@tf.function()
def replica_fn():
  with tf.device("gpu:0"):
    ...

with tf.device("worker1"):
   replica_fn()

In this case there may be multiple gpu:0 devices in the cluster, but the gpu:0 on worker1 should be picked. Hence, we should not raise a "multiple devices found" error here.

PiperOrigin-RevId: 357996270
Change-Id: I6a19deb5f26d7741ba7f7adff7f14550679b3d6d
---
 .../process_function_library_runtime.cc       | 16 +++++
 tensorflow/python/eager/BUILD                 |  1 +
 tensorflow/python/eager/function_test.py      | 12 ++++
 tensorflow/python/eager/remote_test.py        | 58 ++++++++++++++-----
 4 files changed, 73 insertions(+), 14 deletions(-)

diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 659d29601e8..c837f1de14b 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -516,6 +516,22 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
             if (on_same_task) {
               continue;
             }
+            // Compare with default_device if it has a narrower scope matching
+            // requested device.
+            int colocated_on_default_device = 0;
+            for (int i = 0; i < matching_devices.size(); ++i) {
+              if (DeviceNameUtils::IsSameAddressSpace(
+                      default_device->parsed_name(),
+                      matching_devices.at(i)->parsed_name())) {
+                colocated_on_default_device++;
+              }
+            }
+            // Continue to raise error if multiple colocated devices are
+            // found.
+            if (colocated_on_default_device == 1) {
+              continue;
+            }
+
             // Convert a vector of devices to a string.
             // Using absl::StrJoin did not work in Android builds.
             string devices = "[";
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 66b0771f3d1..6e8d4363505 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -1046,6 +1046,7 @@ cuda_py_test(
         "//tensorflow/python:functional_ops",
         "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:tensor_spec",
+        "//tensorflow/python:test_ops",
         "@absl_py//absl/testing:parameterized",
         "@six_archive//:six",
     ],
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 5c8ee1459f2..92e9c17df12 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -1795,6 +1795,18 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
       has_device.f()
     self.assertIn('CPU', has_device.v.device)
 
+  @test_util.run_in_graph_and_eager_modes
+  def testMultipleDeviceCheck(self):
+
+    def f():
+      with ops.device('cpu'):
+        return test_ops.device_placement_op()
+
+    func = function.defun(f)
+    with ops.device('cpu:0'):
+      output = self.evaluate(func())
+      self.assertIn(compat.as_bytes('CPU:0'), output)
+
   @test_util.run_in_graph_and_eager_modes
   def testDeviceAnnotationsRespected(self):
 
diff --git a/tensorflow/python/eager/remote_test.py b/tensorflow/python/eager/remote_test.py
index ebe70d9b964..2b23a15abcc 100644
--- a/tensorflow/python/eager/remote_test.py
+++ b/tensorflow/python/eager/remote_test.py
@@ -38,6 +38,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import test_ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -45,9 +46,11 @@ from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import string_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.training import server_lib
 from tensorflow.python.training.server_lib import ClusterSpec
+from tensorflow.python.util import compat
 
 
 class SingleWorkerTest(test.TestCase, parameterized.TestCase):
@@ -112,20 +115,6 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase):
     self.assertEqual(rets[1].backing_device,
                      '/job:worker/replica:0/task:0/device:CPU:0')
 
-  def testMultiDeviceFunctionAmbiguousDevice(self):
-
-    @def_function.function
-    def ambiguous_device(i):
-      with ops.device('cpu:0'):
-        return i + constant_op.constant([2])
-
-    with self.assertRaises(errors.InvalidArgumentError) as cm:
-      with ops.device('/job:worker/replica:0/task:0/cpu:0'):
-        ambiguous_device(constant_op.constant([2])).numpy()
-
-    self.assertIn('the output node must match exactly one device',
-                  cm.exception.message)
-
   def testStreaming(self):
     """A mini stress test for streaming - issuing many RPCs back to back."""
     with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
@@ -318,6 +307,21 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase):
     with ops.device('/job:worker/replica:0/task:1'):
       self.assertAllEqual(local_func(x), [2, 1])
 
+  def testMultiDeviceFunctionAmbiguousDevice(self):
+
+    @def_function.function
+    def ambiguous_device(i):
+      with ops.device('/job:worker'):
+        # Multiple worker tasks, thus ambiguous device found error will be
+        # raised.
+        return i + constant_op.constant([2])
+
+    with self.assertRaises(errors.InvalidArgumentError) as cm:
+      ambiguous_device(constant_op.constant([2])).numpy()
+
+    self.assertIn('the output node must match exactly one device',
+                  cm.exception.message)
+
   # Note that the following tests for remote function cancellation only works
   # when non-streaming RPC. We need to disable streaming explicitly and restore
   # this config to its initial value at the end of each test case.
@@ -579,6 +583,32 @@ class MultiJobsTest(test.TestCase, parameterized.TestCase):
     # Reset the context to avoid polluting other test cases.
     context._reset_context()
 
+  def testMultipleDeviceFoundCheck(self):
+    remote.connect_to_cluster(self._cluster)
+
+    @def_function.function
+    def func():
+      with ops.device('cpu:0'):
+        # Multiple CPU:0 devices match would be found, but the CPU:0 from the
+        # parent device scope should be picked.
+        x = test_ops.device_placement_op()
+        y = string_ops.string_upper(x)
+        packed_var_0 = array_ops.stack([x, y], 0)
+        return packed_var_0
+
+    with ops.device('/job:my_worker/task:1'):
+      output = self.evaluate(func())
+      self.assertEqual(
+          compat.as_bytes('/job:my_worker/replica:0/task:1/device:CPU:0'),
+          output[0])
+      self.assertIn(compat.as_bytes('/JOB:MY_WORKER'), output[1])
+    with ops.device('/job:my_ps/task:1'):
+      output = self.evaluate(func())
+      self.assertEqual(
+          compat.as_bytes('/job:my_ps/replica:0/task:1/device:CPU:0'),
+          output[0])
+      self.assertIn(compat.as_bytes('/JOB:MY_PS'), output[1])
+
   def testSimpleParameterServer(self):
     remote.connect_to_cluster(self._cluster)