For a multi-device function, if the output-producing node is assigned to a remote device, colocate the retval node with its input node.
Previously, if a retval node was assigned to a remote device, function instantiation failed with the error "outputting tensors on remote devices is not supported". After this change, a retval node that would have been assigned to a remote device may be assigned to a local device instead.

PiperOrigin-RevId: 285090649
Change-Id: Iade0982fee600591e3e0e2dfe12d676efa720729
This commit is contained in:
parent 824250a7ed
commit 126c04a9ed
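In user-facing terms, the change lets a tf.function return a tensor whose producing op lives on a remote task. The sketch below is illustrative only (not part of the commit); it mirrors the testReturnRemoteArgument test added below, and it assumes a two-task "worker" job has already been set up and connected to the eager context (the connect_to_cluster call is a hypothetical placeholder for whatever cluster setup is in use).

# Illustrative sketch, assuming a remote "worker" job is already connected.
import tensorflow as tf

# Hypothetical setup, normally done by test fixtures or deployment code:
# tf.config.experimental.connect_to_cluster(cluster_resolver)

@tf.function
def local_func(i):
  return i  # After this change the retval is colocated with its input node.

with tf.device('/job:worker/replica:0/task:0'):
  x = tf.constant([2, 1])  # Produced on a remote task.

with tf.device('/job:worker/replica:0/task:1'):
  # Before this change, outputting a tensor placed on a remote device made
  # function instantiation fail with an UnimplementedError; now the retval
  # node can be pinned to a local device and the call succeeds.
  print(local_func(x).numpy())  # [2 1]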
@@ -410,6 +410,18 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
                 << " src_device: " << *src_device
                 << " colo group: " << colocation_group;
       }
+      // If colocation_group is not set and output producing node is assigned
+      // to a remote device, colocate the retval node with its input node.
+      // TODO(yujingzhang): Remove this when we support outputting tensors on
+      // remote devices.
+      const bool remote_src_device =
+          !src_device->empty() && GetFLR(*src_device) == nullptr;
+      if (colocation_group.empty() && remote_src_device) {
+        colocation_group =
+            absl::StrCat(kColocationGroupPrefix, it->src()->name());
+        VLOG(3) << "Considering src: " << src_node->name()
+                << " colo group: " << colocation_group;
+      }
 
       // If resource is produced by a function call node, we can't trust
       // source node device assignment, because multi-device functions can
@@ -94,12 +94,8 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase):
       c = variable_b + 1
       return c, i + variable_b
 
-    with self.assertRaises(errors.UnimplementedError) as cm:
-      remote_output(constant_op.constant([1]))
-
-    self.assertIn(
-        'Currently, outputting tensors on remote devices is not supported.',
-        cm.exception.message)
+    self.assertAllEqual(
+        remote_output(constant_op.constant([1]))[0].numpy(), 2)
 
   def testMultiDeviceFunctionAmbiguousDevice(self):
@@ -176,6 +172,19 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase):
     # Reset the context to avoid polluting other test cases.
     context._reset_context()
 
+  @test_util.eager_lazy_remote_copy_on_and_off
+  def testReturnRemoteArgument(self):
+
+    @def_function.function
+    def local_func(i):
+      return i
+
+    with ops.device('/job:worker/replica:0/task:0'):
+      x = constant_op.constant([2, 1])
+
+    with ops.device('/job:worker/replica:0/task:1'):
+      self.assertAllEqual(local_func(x), [2, 1])
+
   @test_util.eager_lazy_remote_copy_on_and_off
   def testMultiDeviceFunctionOnLocalDevice(self):
     with ops.device('/job:worker/replica:0/task:1'):