diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 4c01978e6d5..7bd5d09be97 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -410,6 +410,18 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets( << " src_device: " << *src_device << " colo group: " << colocation_group; } + // If colocation_group is not set and output producing node is assigned + // to a remote device, colocate the retval node with its input node. + // TODO(yujingzhang): Remove this when we support outputting tensors on + // remote devices. + const bool remote_src_device = + !src_device->empty() && GetFLR(*src_device) == nullptr; + if (colocation_group.empty() && remote_src_device) { + colocation_group = + absl::StrCat(kColocationGroupPrefix, it->src()->name()); + VLOG(3) << "Considering src: " << src_node->name() + << " colo group: " << colocation_group; + } // If resource is produced by a function call node, we can't trust // source node device assignment, because multi-device functions can diff --git a/tensorflow/python/eager/remote_test.py b/tensorflow/python/eager/remote_test.py index f4322b8b0e4..acafbb2626d 100644 --- a/tensorflow/python/eager/remote_test.py +++ b/tensorflow/python/eager/remote_test.py @@ -94,12 +94,8 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase): c = variable_b + 1 return c, i + variable_b - with self.assertRaises(errors.UnimplementedError) as cm: - remote_output(constant_op.constant([1])) - - self.assertIn( - 'Currently, outputting tensors on remote devices is not supported.', - cm.exception.message) + self.assertAllEqual( + remote_output(constant_op.constant([1]))[0].numpy(), 2) def testMultiDeviceFunctionAmbiguousDevice(self): @@ -176,6 +172,19 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase): # Reset the context to avoid polluting other test cases. context._reset_context() + @test_util.eager_lazy_remote_copy_on_and_off + def testReturnRemoteArgument(self): + + @def_function.function + def local_func(i): + return i + + with ops.device('/job:worker/replica:0/task:0'): + x = constant_op.constant([2, 1]) + + with ops.device('/job:worker/replica:0/task:1'): + self.assertAllEqual(local_func(x), [2, 1]) + @test_util.eager_lazy_remote_copy_on_and_off def testMultiDeviceFunctionOnLocalDevice(self): with ops.device('/job:worker/replica:0/task:1'):