For a multi-device function, if the output producing node is assigned to a remote device, colocate the retval node with its input node.

If a retval node is assigned to a remote device, the function instantiation would fail with an error "outputting tensors on remote devices is not supported.". After this change, a retval node which was assigned to a remote device might be assigned to a local device now.

PiperOrigin-RevId: 285090649
Change-Id: Iade0982fee600591e3e0e2dfe12d676efa720729
This commit is contained in:
Yujing Zhang 2019-12-11 16:50:41 -08:00 committed by TensorFlower Gardener
parent 824250a7ed
commit 126c04a9ed
2 changed files with 27 additions and 6 deletions

View File

@ -410,6 +410,18 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
<< " src_device: " << *src_device
<< " colo group: " << colocation_group;
}
// If colocation_group is not set and output producing node is assigned
// to a remote device, colocate the retval node with its input node.
// TODO(yujingzhang): Remove this when we support outputting tensors on
// remote devices.
const bool remote_src_device =
!src_device->empty() && GetFLR(*src_device) == nullptr;
if (colocation_group.empty() && remote_src_device) {
colocation_group =
absl::StrCat(kColocationGroupPrefix, it->src()->name());
VLOG(3) << "Considering src: " << src_node->name()
<< " colo group: " << colocation_group;
}
// If resource is produced by a function call node, we can't trust
// source node device assignment, because multi-device functions can

View File

@ -94,12 +94,8 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase):
c = variable_b + 1
return c, i + variable_b
with self.assertRaises(errors.UnimplementedError) as cm:
remote_output(constant_op.constant([1]))
self.assertIn(
'Currently, outputting tensors on remote devices is not supported.',
cm.exception.message)
self.assertAllEqual(
remote_output(constant_op.constant([1]))[0].numpy(), 2)
def testMultiDeviceFunctionAmbiguousDevice(self):
@ -176,6 +172,19 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase):
# Reset the context to avoid polluting other test cases.
context._reset_context()
@test_util.eager_lazy_remote_copy_on_and_off
def testReturnRemoteArgument(self):
@def_function.function
def local_func(i):
return i
with ops.device('/job:worker/replica:0/task:0'):
x = constant_op.constant([2, 1])
with ops.device('/job:worker/replica:0/task:1'):
self.assertAllEqual(local_func(x), [2, 1])
@test_util.eager_lazy_remote_copy_on_and_off
def testMultiDeviceFunctionOnLocalDevice(self):
with ops.device('/job:worker/replica:0/task:1'):