For a multi-device function, if the output-producing node is assigned to a remote device, colocate the retval node with its input node.
Previously, if a retval node was assigned to a remote device, function instantiation failed with the error "outputting tensors on remote devices is not supported." After this change, a retval node that would have been assigned to a remote device may now be assigned to a local device instead.

PiperOrigin-RevId: 285090649
Change-Id: Iade0982fee600591e3e0e2dfe12d676efa720729
This commit is contained in:
parent 824250a7ed
commit 126c04a9ed
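The user-visible effect is easiest to see from the test changes below: returning a tensor produced on a remote device from a multi-device function used to raise UnimplementedError and now returns the value. A minimal sketch of that scenario using the public API (illustrative only; it assumes a remote worker has already been configured, and the device strings mirror the ones used in the tests in this diff):

import tensorflow as tf

# Assumption: a cluster with a '/job:worker' task has been set up, e.g. via
# tf.config.experimental_connect_to_cluster(...). The device string below is
# illustrative and mirrors the test code in this commit.

@tf.function
def remote_output(i):
  # The op producing the output runs on the remote worker. Before this
  # change, instantiating the function failed with UnimplementedError
  # ("outputting tensors on remote devices is not supported"); after it,
  # the retval node is colocated with its input and placed on a local device.
  with tf.device('/job:worker/replica:0/task:0'):
    c = i + 1
  return c

# remote_output(tf.constant([1]))  # -> [2] once a remote worker is available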
@@ -410,6 +410,18 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
                   << " src_device: " << *src_device
                   << " colo group: " << colocation_group;
         }
+        // If colocation_group is not set and output producing node is assigned
+        // to a remote device, colocate the retval node with its input node.
+        // TODO(yujingzhang): Remove this when we support outputting tensors on
+        // remote devices.
+        const bool remote_src_device =
+            !src_device->empty() && GetFLR(*src_device) == nullptr;
+        if (colocation_group.empty() && remote_src_device) {
+          colocation_group =
+              absl::StrCat(kColocationGroupPrefix, it->src()->name());
+          VLOG(3) << "Considering src: " << src_node->name()
+                  << " colo group: " << colocation_group;
+        }
 
         // If resource is produced by a function call node, we can't trust
         // source node device assignment, because multi-device functions can
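For context, the colocation group built above ends up as the standard "_class" attribute on the retval node (kColocationGroupPrefix is the usual "loc:@" prefix), which is what the placer uses to keep the retval on the same device as the node producing it. A small, hypothetical inspection snippet, not part of this change:

import tensorflow as tf

@tf.function
def f(x):
  return x + 1

# Print any colocation constraints ('_class' attrs) recorded in the traced
# graph. Entries look like b'loc:@<node_name>'; a trivial function like this
# one may have none, which is fine for the purpose of the sketch.
graph_def = f.get_concrete_function(tf.constant([1])).graph.as_graph_def()
for node in graph_def.node:
  if '_class' in node.attr:
    print(node.name, list(node.attr['_class'].list.s))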
@@ -94,12 +94,8 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase):
         c = variable_b + 1
       return c, i + variable_b
 
-    with self.assertRaises(errors.UnimplementedError) as cm:
-      remote_output(constant_op.constant([1]))
-
-    self.assertIn(
-        'Currently, outputting tensors on remote devices is not supported.',
-        cm.exception.message)
+    self.assertAllEqual(
+        remote_output(constant_op.constant([1]))[0].numpy(), 2)
 
   def testMultiDeviceFunctionAmbiguousDevice(self):
@@ -176,6 +172,19 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase):
     # Reset the context to avoid polluting other test cases.
     context._reset_context()
 
+  @test_util.eager_lazy_remote_copy_on_and_off
+  def testReturnRemoteArgument(self):
+
+    @def_function.function
+    def local_func(i):
+      return i
+
+    with ops.device('/job:worker/replica:0/task:0'):
+      x = constant_op.constant([2, 1])
+
+    with ops.device('/job:worker/replica:0/task:1'):
+      self.assertAllEqual(local_func(x), [2, 1])
+
   @test_util.eager_lazy_remote_copy_on_and_off
   def testMultiDeviceFunctionOnLocalDevice(self):
     with ops.device('/job:worker/replica:0/task:1'):