diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 4c01978e6d5..7bd5d09be97 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -410,6 +410,18 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
                   << " src_device: " << *src_device
                   << " colo group: " << colocation_group;
         }
+        // If colocation_group is empty and the output-producing node is
+        // assigned to a remote device, colocate the retval node with its input.
+        // TODO(yujingzhang): Remove this when we support outputting tensors on
+        // remote devices.
+        const bool remote_src_device =
+            !src_device->empty() && GetFLR(*src_device) == nullptr;
+        if (colocation_group.empty() && remote_src_device) {
+          colocation_group =
+              absl::StrCat(kColocationGroupPrefix, it->src()->name());
+          VLOG(3) << "Considering src: " << src_node->name()
+                  << " colo group: " << colocation_group;
+        }
 
         // If resource is produced by a function call node, we can't trust
         // source node device assignment, because multi-device functions can
diff --git a/tensorflow/python/eager/remote_test.py b/tensorflow/python/eager/remote_test.py
index f4322b8b0e4..acafbb2626d 100644
--- a/tensorflow/python/eager/remote_test.py
+++ b/tensorflow/python/eager/remote_test.py
@@ -94,12 +94,8 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase):
         c = variable_b + 1
       return c, i + variable_b
 
-    with self.assertRaises(errors.UnimplementedError) as cm:
-      remote_output(constant_op.constant([1]))
-
-    self.assertIn(
-        'Currently, outputting tensors on remote devices is not supported.',
-        cm.exception.message)
+    self.assertAllEqual(
+        remote_output(constant_op.constant([1]))[0].numpy(), 2)
 
   def testMultiDeviceFunctionAmbiguousDevice(self):
 
@@ -176,6 +172,19 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase):
     # Reset the context to avoid polluting other test cases.
     context._reset_context()
 
+  @test_util.eager_lazy_remote_copy_on_and_off
+  def testReturnRemoteArgument(self):
+
+    @def_function.function
+    def local_func(i):
+      return i  # Identity: the function's only output is its argument.
+
+    with ops.device('/job:worker/replica:0/task:0'):
+      x = constant_op.constant([2, 1])  # Created under the task:0 scope.
+
+    with ops.device('/job:worker/replica:0/task:1'):
+      self.assertAllEqual(local_func(x), [2, 1])  # Called under task:1.
+
   @test_util.eager_lazy_remote_copy_on_and_off
   def testMultiDeviceFunctionOnLocalDevice(self):
     with ops.device('/job:worker/replica:0/task:1'):