Function/control flow inlining: only use the function's job/replica/task placement if the body node doesn't have its own
At least for function placement this matches the behavior of un-inlined PartitionedCallOp (default_device to placer only affects nodes without a requested placement). For functional cond (and presumably while) the un-inlined behavior appears to be single-device-only, so this change causes behavior to diverge there (it already diverged somewhat in allowing multi-device body graphs at all). Fixes #44011. PiperOrigin-RevId: 347720744 Change-Id: Id32e4aacd2d82811e31dd9efd66c84cc7219a1dc
This commit is contained in:
parent
66f670ded6
commit
1ce28a3f8a
tensorflow
@ -1289,7 +1289,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsAndPlaceInlinedNodes) {
|
||||
auto g = absl::make_unique<Graph>(OpRegistry::Global());
|
||||
TF_ASSERT_OK(construct_graph(&g));
|
||||
|
||||
const string merged_device = "/job:call/replica:0/task:1/device:CPU:*";
|
||||
const string merged_device = "/job:body/replica:0/task:1/device:CPU:*";
|
||||
|
||||
ExpandInlineFunctions(flr0_, g.get(), opts);
|
||||
GraphDef expected = expected_graph({/*a*/ arg_device, //
|
||||
|
@ -231,17 +231,19 @@ class MultiDeviceFunctionBodyPlacer : public InlinedFunctionBodyPlacer {
|
||||
if (!DeviceNameUtils::ParseFullName(ndef.device(), &ndef_parsed_device))
|
||||
return ndef.device();
|
||||
|
||||
if (caller_parsed_device_.has_job) {
|
||||
// Nodes with explicit device placements in the function body have those
|
||||
// respected, but otherwise the function's placement provides a default.
|
||||
if (caller_parsed_device_.has_job && !ndef_parsed_device.has_job) {
|
||||
ndef_parsed_device.has_job = caller_parsed_device_.has_job;
|
||||
ndef_parsed_device.job = caller_parsed_device_.job;
|
||||
}
|
||||
|
||||
if (caller_parsed_device_.has_replica) {
|
||||
if (caller_parsed_device_.has_replica && !ndef_parsed_device.has_replica) {
|
||||
ndef_parsed_device.has_replica = caller_parsed_device_.has_replica;
|
||||
ndef_parsed_device.replica = caller_parsed_device_.replica;
|
||||
}
|
||||
|
||||
if (caller_parsed_device_.has_task) {
|
||||
if (caller_parsed_device_.has_task && !ndef_parsed_device.has_task) {
|
||||
ndef_parsed_device.has_task = caller_parsed_device_.has_task;
|
||||
ndef_parsed_device.task = caller_parsed_device_.task;
|
||||
}
|
||||
|
@ -19,11 +19,14 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.eager import remote
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -288,23 +291,24 @@ class CondV2Test(test.TestCase):
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testDefunInCond(self):
|
||||
x = constant_op.constant(1.0, name="x")
|
||||
y = constant_op.constant(2.0, name="y")
|
||||
with ops.Graph().as_default():
|
||||
x = constant_op.constant(1.0, name="x")
|
||||
y = constant_op.constant(2.0, name="y")
|
||||
|
||||
def true_fn():
|
||||
def true_fn():
|
||||
|
||||
@function.defun
|
||||
def fn():
|
||||
return x * y * 2.0
|
||||
@function.defun
|
||||
def fn():
|
||||
return x * y * 2.0
|
||||
|
||||
return fn()
|
||||
return fn()
|
||||
|
||||
def false_fn():
|
||||
return 2.0
|
||||
def false_fn():
|
||||
return 2.0
|
||||
|
||||
self._testCond(true_fn, false_fn, [x])
|
||||
self._testCond(true_fn, false_fn, [x, y])
|
||||
self._testCond(true_fn, false_fn, [y])
|
||||
self._testCond(true_fn, false_fn, [x])
|
||||
self._testCond(true_fn, false_fn, [x, y])
|
||||
self._testCond(true_fn, false_fn, [y])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testNestedDefunInCond(self):
|
||||
@ -942,24 +946,23 @@ class CondV2Test(test.TestCase):
|
||||
self.assertAllEqual(self.evaluate(fn_output), [2.0, 4.0])
|
||||
|
||||
def testGradientTapeOfCondWithResourceVariableInFunction(self):
|
||||
with context.eager_mode():
|
||||
v = variables.Variable(2.)
|
||||
v = variables.Variable(2.)
|
||||
|
||||
@def_function.function
|
||||
def fn_with_cond():
|
||||
with backprop.GradientTape() as tape:
|
||||
pred = constant_op.constant(True, dtype=dtypes.bool)
|
||||
@def_function.function
|
||||
def fn_with_cond():
|
||||
with backprop.GradientTape() as tape:
|
||||
pred = constant_op.constant(True, dtype=dtypes.bool)
|
||||
|
||||
def true_fn():
|
||||
return math_ops.pow(v, 3)
|
||||
def true_fn():
|
||||
return math_ops.pow(v, 3)
|
||||
|
||||
def false_fn():
|
||||
return v
|
||||
def false_fn():
|
||||
return v
|
||||
|
||||
cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
|
||||
return tape.gradient(cond, v)
|
||||
cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
|
||||
return tape.gradient(cond, v)
|
||||
|
||||
self.assertAllEqual(fn_with_cond(), 12.0)
|
||||
self.assertAllEqual(fn_with_cond(), 12.0)
|
||||
|
||||
def _CheckIteratedCosGradients(self, func):
|
||||
|
||||
@ -1458,9 +1461,10 @@ class CondV2ContainerTest(test.TestCase):
|
||||
self.assertEqual(compat.as_bytes(""), container(q5.queue_ref))
|
||||
|
||||
|
||||
class CondV2ColocationGroupAndDeviceTest(test.TestCase):
|
||||
class CondV2ColocationGroupAndDeviceTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
context._reset_context()
|
||||
super(CondV2ColocationGroupAndDeviceTest, self).setUp()
|
||||
cpus = context.context().list_physical_devices("CPU")
|
||||
context.context().set_logical_device_configuration(
|
||||
@ -1468,6 +1472,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
|
||||
context.LogicalDeviceConfiguration(),
|
||||
context.LogicalDeviceConfiguration()
|
||||
])
|
||||
workers, _ = test_util.create_local_cluster(num_workers=1, num_ps=0)
|
||||
remote.connect_to_remote_host(workers[0].target)
|
||||
|
||||
def testColocateWithBeforeCond(self):
|
||||
with ops.Graph().as_default() as g:
|
||||
@ -1544,64 +1550,113 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
|
||||
self.assertTrue(len(run_metadata.partition_graphs) >= 2)
|
||||
|
||||
def testDeviceBeforeCond(self):
|
||||
with context.eager_mode():
|
||||
def fn():
|
||||
cpu_zero_op = test_ops.device_placement_op()
|
||||
self.assertEqual("/device:CPU:0", cpu_zero_op.device)
|
||||
with ops.device("CPU:1"):
|
||||
cpu_one_op = test_ops.device_placement_op()
|
||||
self.assertEqual("/device:CPU:1", cpu_one_op.device)
|
||||
return cpu_zero_op, cpu_one_op
|
||||
|
||||
@def_function.function
|
||||
def _cond_wrapper():
|
||||
with ops.device("/device:CPU:0"):
|
||||
return cond_v2.cond_v2(constant_op.constant(True), fn, fn)
|
||||
def fn():
|
||||
cpu_zero_op = test_ops.device_placement_op()
|
||||
self.assertEqual("/job:localhost/device:CPU:0", cpu_zero_op.device)
|
||||
with ops.device("CPU:1"):
|
||||
cpu_one_op = test_ops.device_placement_op()
|
||||
self.assertEqual("/job:localhost/device:CPU:1", cpu_one_op.device)
|
||||
return cpu_zero_op, cpu_one_op
|
||||
|
||||
zero_expected, one_expected = self.evaluate(_cond_wrapper())
|
||||
self.assertIn(compat.as_bytes("CPU:0"), zero_expected)
|
||||
self.assertIn(compat.as_bytes("CPU:1"), one_expected)
|
||||
@def_function.function
|
||||
def _cond_wrapper():
|
||||
with ops.device("/job:localhost/device:CPU:0"):
|
||||
return cond_v2.cond_v2(constant_op.constant(True), fn, fn)
|
||||
|
||||
def fn2():
|
||||
self.assertEqual("/device:GPU:0", constant_op.constant(3.0).op.device)
|
||||
return test_ops.device_placement_op()
|
||||
zero_expected, one_expected = self.evaluate(_cond_wrapper())
|
||||
self.assertIn(compat.as_bytes("CPU:0"), zero_expected)
|
||||
self.assertIn(compat.as_bytes("CPU:1"), one_expected)
|
||||
self.assertIn(compat.as_bytes("job:localhost"), zero_expected)
|
||||
self.assertIn(compat.as_bytes("job:localhost"), one_expected)
|
||||
|
||||
@def_function.function
|
||||
def _cond_wrapper2():
|
||||
with ops.device("/device:GPU:0"):
|
||||
return cond_v2.cond_v2(constant_op.constant(True), fn2, fn2)
|
||||
def fn2():
|
||||
self.assertEqual("/job:localhost/device:GPU:0",
|
||||
constant_op.constant(3.0).op.device)
|
||||
return test_ops.device_placement_op()
|
||||
|
||||
if test_util.is_gpu_available():
|
||||
self.assertIn(compat.as_bytes("GPU:0"),
|
||||
self.evaluate(_cond_wrapper2()))
|
||||
else:
|
||||
self.skipTest("Test requires a GPU to check GPU device placement.")
|
||||
@def_function.function
|
||||
def _cond_wrapper2():
|
||||
with ops.device("/job:localhost/device:GPU:0"):
|
||||
return cond_v2.cond_v2(constant_op.constant(True), fn2, fn2)
|
||||
|
||||
if test_util.is_gpu_available():
|
||||
self.assertIn(compat.as_bytes("GPU:0"), self.evaluate(_cond_wrapper2()))
|
||||
self.assertIn(
|
||||
compat.as_bytes("job:localhost"), self.evaluate(_cond_wrapper2()))
|
||||
else:
|
||||
self.skipTest("Test requires a GPU to check GPU device placement.")
|
||||
|
||||
@parameterized.named_parameters([
|
||||
dict(
|
||||
testcase_name="Function",
|
||||
functional_op_to_test=lambda fn: def_function.function(fn)()),
|
||||
dict(
|
||||
testcase_name="Cond",
|
||||
functional_op_to_test=
|
||||
lambda fn: cond_v2.cond_v2(constant_op.constant(True), fn, fn))
|
||||
])
|
||||
def testDeviceBeforeRemote(self, functional_op_to_test):
|
||||
context.context().log_device_placement = True
|
||||
|
||||
def _fn():
|
||||
local_op = test_ops.device_placement_op()
|
||||
with ops.device("/job:worker/CPU:0"):
|
||||
worker_op = test_ops.device_placement_op()
|
||||
return local_op, worker_op
|
||||
|
||||
@def_function.function
|
||||
def _wrapper():
|
||||
with ops.device("/job:localhost"):
|
||||
return functional_op_to_test(_fn)
|
||||
|
||||
local_expected, worker_expected = self.evaluate(_wrapper())
|
||||
self.assertIn(compat.as_bytes("job:localhost"), local_expected)
|
||||
self.assertIn(compat.as_bytes("job:worker"), worker_expected)
|
||||
|
||||
del _fn, _wrapper
|
||||
|
||||
# There's nothing special about localhost; if we swap roles (functional op
|
||||
# on worker, op on localhost) the inner placement still wins.
|
||||
def _fn2():
|
||||
local_op = test_ops.device_placement_op()
|
||||
with ops.device("/job:localhost/CPU:0"):
|
||||
worker_op = test_ops.device_placement_op()
|
||||
return local_op, worker_op
|
||||
|
||||
@def_function.function
|
||||
def _wrapper2():
|
||||
with ops.device("/job:worker"):
|
||||
return functional_op_to_test(_fn2)
|
||||
|
||||
worker_expected, local_expected = self.evaluate(_wrapper2())
|
||||
self.assertIn(compat.as_bytes("job:worker"), worker_expected)
|
||||
self.assertIn(compat.as_bytes("job:localhost"), local_expected)
|
||||
|
||||
def testColocationBeforeCond(self):
|
||||
with context.eager_mode():
|
||||
|
||||
def _fn():
|
||||
result = test_ops.device_placement_op()
|
||||
self.assertIn("colocation_test_op",
|
||||
result.op.colocation_groups()[0].decode())
|
||||
return result
|
||||
def _fn():
|
||||
result = test_ops.device_placement_op()
|
||||
self.assertIn("colocation_test_op",
|
||||
result.op.colocation_groups()[0].decode())
|
||||
return result
|
||||
|
||||
@def_function.function(autograph=False)
|
||||
def _cond_wrapper():
|
||||
with ops.device("/device:CPU:0"):
|
||||
op_on_cpu_0 = test_ops.device_placement_op(name="colocation_test_op")
|
||||
with ops.device("/device:CPU:1"):
|
||||
op_on_cpu_1 = test_ops.device_placement_op(
|
||||
name="colocation_test_op_1")
|
||||
condition = constant_op.constant(True)
|
||||
with ops.colocate_with(op_on_cpu_0.op):
|
||||
zero_expected = cond_v2.cond_v2(condition, _fn, _fn)
|
||||
with ops.colocate_with(op_on_cpu_1.op):
|
||||
one_expected = cond_v2.cond_v2(condition, _fn, _fn)
|
||||
return zero_expected, one_expected
|
||||
zero_expected, one_expected = self.evaluate(_cond_wrapper())
|
||||
self.assertIn(compat.as_bytes("CPU:0"), zero_expected)
|
||||
self.assertIn(compat.as_bytes("CPU:1"), one_expected)
|
||||
@def_function.function(autograph=False)
|
||||
def _cond_wrapper():
|
||||
with ops.device("/device:CPU:0"):
|
||||
op_on_cpu_0 = test_ops.device_placement_op(name="colocation_test_op")
|
||||
with ops.device("/device:CPU:1"):
|
||||
op_on_cpu_1 = test_ops.device_placement_op(name="colocation_test_op_1")
|
||||
condition = constant_op.constant(True)
|
||||
with ops.colocate_with(op_on_cpu_0.op):
|
||||
zero_expected = cond_v2.cond_v2(condition, _fn, _fn)
|
||||
with ops.colocate_with(op_on_cpu_1.op):
|
||||
one_expected = cond_v2.cond_v2(condition, _fn, _fn)
|
||||
return zero_expected, one_expected
|
||||
|
||||
zero_expected, one_expected = self.evaluate(_cond_wrapper())
|
||||
self.assertIn(compat.as_bytes("CPU:0"), zero_expected)
|
||||
self.assertIn(compat.as_bytes("CPU:1"), one_expected)
|
||||
|
||||
def testDeviceInAndOutOfCond(self):
|
||||
with ops.Graph().as_default() as g:
|
||||
@ -1702,4 +1757,5 @@ def _has_node_with_op(run_metadata, op_type):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ops.enable_eager_execution()
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user