Update the multiple-devices-found check to allow a single device colocated with the default_device.
If multiple devices are found in the cluster but there is a single device
colocated with default_device, the placer uses that single colocated device.
This change relaxes the current check to account for the parent device scope.
For example:

    @tf.function()
    def replica_fn():
      with tf.device("gpu:0"):
        ...

    with tf.device("worker1"):
      replica_fn()

Here there may be multiple gpu:0 devices in the cluster, but the gpu:0 on
worker1 should be picked, so no multiple-devices-found error is raised.

PiperOrigin-RevId: 357996270
Change-Id: I6a19deb5f26d7741ba7f7adff7f14550679b3d6d
commit 79cd2fd4ac
parent acb69957a7
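A minimal sketch of the behavior this change enables, assuming an eager
context already connected to a remote cluster with a job named "worker"
(cluster setup elided; the job and task names are illustrative):

    import tensorflow as tf

    @tf.function
    def replica_fn():
      # A bare "CPU:0" matches a device on every task in the cluster.
      with tf.device("CPU:0"):
        return tf.constant(1) + tf.constant(2)

    # The enclosing scope narrows the match: the placer should pick the
    # CPU:0 colocated with /job:worker/task:1 rather than raising a
    # "multiple devices" error.
    with tf.device("/job:worker/task:1"):
      replica_fn()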
@@ -516,6 +516,22 @@ Status ProcessFunctionLibraryRuntime::PinArgsAndRets(
           if (on_same_task) {
             continue;
           }
+          // Compare with default_device if it has a narrower scope matching
+          // requested device.
+          int colocated_on_default_device = 0;
+          for (int i = 0; i < matching_devices.size(); ++i) {
+            if (DeviceNameUtils::IsSameAddressSpace(
+                    default_device->parsed_name(),
+                    matching_devices.at(i)->parsed_name())) {
+              colocated_on_default_device++;
+            }
+          }
+          // Continue to raise error if multiple colocated devices are
+          // found.
+          if (colocated_on_default_device == 1) {
+            continue;
+          }
+
           // Convert a vector of devices to a string.
           // Using absl::StrJoin did not work in Android builds.
           string devices = "[";
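For context, the colocation count above relies on
DeviceNameUtils::IsSameAddressSpace, which, roughly, treats two parsed device
names as sharing an address space when their job, replica, and task fields
all match. A rough Python sketch of that comparison (is_same_address_space
is a hypothetical helper for illustration, not the TensorFlow API):

    def is_same_address_space(a, b):
      """Two device names share an address space when job, replica,
      and task all match."""
      def parse(name):
        parts = dict(p.split(':', 1)
                     for p in name.strip('/').split('/') if ':' in p)
        return (parts.get('job'), parts.get('replica'), parts.get('task'))
      return parse(a) == parse(b)

    # Same task, different device type: same address space.
    assert is_same_address_space('/job:worker/replica:0/task:1/device:CPU:0',
                                 '/job:worker/replica:0/task:1/device:GPU:0')
    # Different tasks: distinct address spaces, so both count as
    # separate matches.
    assert not is_same_address_space(
        '/job:worker/replica:0/task:0/device:CPU:0',
        '/job:worker/replica:0/task:1/device:CPU:0')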
@@ -1046,6 +1046,7 @@ cuda_py_test(
         "//tensorflow/python:functional_ops",
         "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:tensor_spec",
+        "//tensorflow/python:test_ops",
         "@absl_py//absl/testing:parameterized",
         "@six_archive//:six",
     ],
@@ -1795,6 +1795,18 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
     has_device.f()
     self.assertIn('CPU', has_device.v.device)
 
+  @test_util.run_in_graph_and_eager_modes
+  def testMultipleDeviceCheck(self):
+
+    def f():
+      with ops.device('cpu'):
+        return test_ops.device_placement_op()
+
+    func = function.defun(f)
+    with ops.device('cpu:0'):
+      output = self.evaluate(func())
+      self.assertIn(compat.as_bytes('CPU:0'), output)
+
   @test_util.run_in_graph_and_eager_modes
   def testDeviceAnnotationsRespected(self):
@@ -38,6 +38,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import test_ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -45,9 +46,11 @@ from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import string_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.training import server_lib
 from tensorflow.python.training.server_lib import ClusterSpec
+from tensorflow.python.util import compat
 
 
 class SingleWorkerTest(test.TestCase, parameterized.TestCase):
@@ -112,20 +115,6 @@ class SingleWorkerTest(test.TestCase, parameterized.TestCase):
     self.assertEqual(rets[1].backing_device,
                      '/job:worker/replica:0/task:0/device:CPU:0')
 
-  def testMultiDeviceFunctionAmbiguousDevice(self):
-
-    @def_function.function
-    def ambiguous_device(i):
-      with ops.device('cpu:0'):
-        return i + constant_op.constant([2])
-
-    with self.assertRaises(errors.InvalidArgumentError) as cm:
-      with ops.device('/job:worker/replica:0/task:0/cpu:0'):
-        ambiguous_device(constant_op.constant([2])).numpy()
-
-    self.assertIn('the output node must match exactly one device',
-                  cm.exception.message)
-
   def testStreaming(self):
     """A mini stress test for streaming - issuing many RPCs back to back."""
     with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
@@ -318,6 +307,21 @@ class MultiWorkersTest(test.TestCase, parameterized.TestCase):
     with ops.device('/job:worker/replica:0/task:1'):
       self.assertAllEqual(local_func(x), [2, 1])
 
+  def testMultiDeviceFunctionAmbiguousDevice(self):
+
+    @def_function.function
+    def ambiguous_device(i):
+      with ops.device('/job:worker'):
+        # Multiple worker tasks, thus ambiguous device found error will be
+        # raised.
+        return i + constant_op.constant([2])
+
+    with self.assertRaises(errors.InvalidArgumentError) as cm:
+      ambiguous_device(constant_op.constant([2])).numpy()
+
+    self.assertIn('the output node must match exactly one device',
+                  cm.exception.message)
+
   # Note that the following tests for remote function cancellation only works
   # when non-streaming RPC. We need to disable streaming explicitly and restore
   # this config to its initial value at the end of each test case.
@@ -579,6 +583,32 @@ class MultiJobsTest(test.TestCase, parameterized.TestCase):
     # Reset the context to avoid polluting other test cases.
     context._reset_context()
 
+  def testMultipleDeviceFoundCheck(self):
+    remote.connect_to_cluster(self._cluster)
+
+    @def_function.function
+    def func():
+      with ops.device('cpu:0'):
+        # Multiple CPU:0 devices match would be found, but the CPU:0 from the
+        # parent device scope should be picked.
+        x = test_ops.device_placement_op()
+        y = string_ops.string_upper(x)
+        packed_var_0 = array_ops.stack([x, y], 0)
+        return packed_var_0
+
+    with ops.device('/job:my_worker/task:1'):
+      output = self.evaluate(func())
+      self.assertEqual(
+          compat.as_bytes('/job:my_worker/replica:0/task:1/device:CPU:0'),
+          output[0])
+      self.assertIn(compat.as_bytes('/JOB:MY_WORKER'), output[1])
+    with ops.device('/job:my_ps/task:1'):
+      output = self.evaluate(func())
+      self.assertEqual(
+          compat.as_bytes('/job:my_ps/replica:0/task:1/device:CPU:0'),
+          output[0])
+      self.assertIn(compat.as_bytes('/JOB:MY_PS'), output[1])
+
   def testSimpleParameterServer(self):
     remote.connect_to_cluster(self._cluster)