Function/control flow inlining: only use the function's job/replica/task placement if the body node doesn't have its own

At least for function placement, this matches the behavior of the un-inlined PartitionedCallOp (the default_device passed to the placer only affects nodes without a requested placement).

For functional cond (and presumably while), the un-inlined behavior appears to be single-device-only, so this change causes behavior to diverge there (it had already diverged somewhat by allowing multi-device body graphs at all).
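
As an illustration of the intended behavior, here is a minimal sketch (not part of this change). It assumes an in-process gRPC server standing in for a remote "worker" job, similar to the local cluster the updated tests spin up:

import tensorflow as tf

# Assumption: start a local server so a "/job:worker" device exists to pin to.
server = tf.distribute.Server.create_local_server()
tf.config.experimental_connect_to_host(server.target, job_name="worker")

def branch():
  unconstrained = tf.constant(1.0)  # No requested device: inherits the
                                    # caller's /job:localhost placement.
  with tf.device("/job:worker/CPU:0"):
    pinned = tf.constant(2.0)       # Explicit job in the body is respected;
                                    # the caller's job no longer overrides it.
  return unconstrained, pinned

@tf.function
def wrapper():
  with tf.device("/job:localhost"):
    # The same rule applies whether `branch` is inlined as a function call or
    # as a functional cond branch.
    return tf.cond(tf.constant(True), branch, branch)

wrapper()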

Fixes .

PiperOrigin-RevId: 347720744
Change-Id: Id32e4aacd2d82811e31dd9efd66c84cc7219a1dc
Allen Lavoie 2020-12-15 16:54:01 -08:00 committed by TensorFlower Gardener
parent 66f670ded6
commit 1ce28a3f8a
3 changed files with 137 additions and 79 deletions
tensorflow/core/common_runtime
tensorflow/python/kernel_tests


@@ -1289,7 +1289,7 @@ TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsAndPlaceInlinedNodes) {
   auto g = absl::make_unique<Graph>(OpRegistry::Global());
   TF_ASSERT_OK(construct_graph(&g));
-  const string merged_device = "/job:call/replica:0/task:1/device:CPU:*";
+  const string merged_device = "/job:body/replica:0/task:1/device:CPU:*";
   ExpandInlineFunctions(flr0_, g.get(), opts);
   GraphDef expected = expected_graph({/*a*/ arg_device,  //


@@ -231,17 +231,19 @@ class MultiDeviceFunctionBodyPlacer : public InlinedFunctionBodyPlacer {
     if (!DeviceNameUtils::ParseFullName(ndef.device(), &ndef_parsed_device))
       return ndef.device();
-    if (caller_parsed_device_.has_job) {
+    // Nodes with explicit device placements in the function body have those
+    // respected, but otherwise the function's placement provides a default.
+    if (caller_parsed_device_.has_job && !ndef_parsed_device.has_job) {
       ndef_parsed_device.has_job = caller_parsed_device_.has_job;
       ndef_parsed_device.job = caller_parsed_device_.job;
     }
-    if (caller_parsed_device_.has_replica) {
+    if (caller_parsed_device_.has_replica && !ndef_parsed_device.has_replica) {
       ndef_parsed_device.has_replica = caller_parsed_device_.has_replica;
       ndef_parsed_device.replica = caller_parsed_device_.replica;
     }
-    if (caller_parsed_device_.has_task) {
+    if (caller_parsed_device_.has_task && !ndef_parsed_device.has_task) {
       ndef_parsed_device.has_task = caller_parsed_device_.has_task;
       ndef_parsed_device.task = caller_parsed_device_.task;
     }
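
For reference, a rough Python analogue of the merge rule above using the public tf.DeviceSpec API (illustrative only; the real placer handles the device type/index fields separately from job/replica/task):

import tensorflow as tf

caller = tf.DeviceSpec.from_string("/job:call/replica:0/task:1")
body_node = tf.DeviceSpec.from_string("/job:body/device:CPU:0")

# Fields the body node already specifies win; anything it leaves unset falls
# back to the caller's placement.
merged = caller.make_merged_spec(body_node)
print(merged.to_string())  # /job:body/replica:0/task:1/device:CPU:0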


@@ -19,11 +19,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.eager import remote
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -288,23 +291,24 @@ class CondV2Test(test.TestCase):
@test_util.run_v1_only("b/120545219")
def testDefunInCond(self):
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(2.0, name="y")
with ops.Graph().as_default():
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(2.0, name="y")
def true_fn():
def true_fn():
@function.defun
def fn():
return x * y * 2.0
@function.defun
def fn():
return x * y * 2.0
return fn()
return fn()
def false_fn():
return 2.0
def false_fn():
return 2.0
self._testCond(true_fn, false_fn, [x])
self._testCond(true_fn, false_fn, [x, y])
self._testCond(true_fn, false_fn, [y])
self._testCond(true_fn, false_fn, [x])
self._testCond(true_fn, false_fn, [x, y])
self._testCond(true_fn, false_fn, [y])
@test_util.run_deprecated_v1
def testNestedDefunInCond(self):
@@ -942,24 +946,23 @@ class CondV2Test(test.TestCase):
self.assertAllEqual(self.evaluate(fn_output), [2.0, 4.0])
def testGradientTapeOfCondWithResourceVariableInFunction(self):
with context.eager_mode():
v = variables.Variable(2.)
v = variables.Variable(2.)
@def_function.function
def fn_with_cond():
with backprop.GradientTape() as tape:
pred = constant_op.constant(True, dtype=dtypes.bool)
@def_function.function
def fn_with_cond():
with backprop.GradientTape() as tape:
pred = constant_op.constant(True, dtype=dtypes.bool)
def true_fn():
return math_ops.pow(v, 3)
def true_fn():
return math_ops.pow(v, 3)
def false_fn():
return v
def false_fn():
return v
cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
return tape.gradient(cond, v)
cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
return tape.gradient(cond, v)
self.assertAllEqual(fn_with_cond(), 12.0)
self.assertAllEqual(fn_with_cond(), 12.0)
def _CheckIteratedCosGradients(self, func):
@@ -1458,9 +1461,10 @@ class CondV2ContainerTest(test.TestCase):
self.assertEqual(compat.as_bytes(""), container(q5.queue_ref))
class CondV2ColocationGroupAndDeviceTest(test.TestCase):
class CondV2ColocationGroupAndDeviceTest(test.TestCase, parameterized.TestCase):
def setUp(self):
context._reset_context()
super(CondV2ColocationGroupAndDeviceTest, self).setUp()
cpus = context.context().list_physical_devices("CPU")
context.context().set_logical_device_configuration(
@@ -1468,6 +1472,8 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
context.LogicalDeviceConfiguration(),
context.LogicalDeviceConfiguration()
])
workers, _ = test_util.create_local_cluster(num_workers=1, num_ps=0)
remote.connect_to_remote_host(workers[0].target)
def testColocateWithBeforeCond(self):
with ops.Graph().as_default() as g:
@@ -1544,64 +1550,113 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
self.assertTrue(len(run_metadata.partition_graphs) >= 2)
def testDeviceBeforeCond(self):
with context.eager_mode():
def fn():
cpu_zero_op = test_ops.device_placement_op()
self.assertEqual("/device:CPU:0", cpu_zero_op.device)
with ops.device("CPU:1"):
cpu_one_op = test_ops.device_placement_op()
self.assertEqual("/device:CPU:1", cpu_one_op.device)
return cpu_zero_op, cpu_one_op
@def_function.function
def _cond_wrapper():
with ops.device("/device:CPU:0"):
return cond_v2.cond_v2(constant_op.constant(True), fn, fn)
def fn():
cpu_zero_op = test_ops.device_placement_op()
self.assertEqual("/job:localhost/device:CPU:0", cpu_zero_op.device)
with ops.device("CPU:1"):
cpu_one_op = test_ops.device_placement_op()
self.assertEqual("/job:localhost/device:CPU:1", cpu_one_op.device)
return cpu_zero_op, cpu_one_op
zero_expected, one_expected = self.evaluate(_cond_wrapper())
self.assertIn(compat.as_bytes("CPU:0"), zero_expected)
self.assertIn(compat.as_bytes("CPU:1"), one_expected)
@def_function.function
def _cond_wrapper():
with ops.device("/job:localhost/device:CPU:0"):
return cond_v2.cond_v2(constant_op.constant(True), fn, fn)
def fn2():
self.assertEqual("/device:GPU:0", constant_op.constant(3.0).op.device)
return test_ops.device_placement_op()
zero_expected, one_expected = self.evaluate(_cond_wrapper())
self.assertIn(compat.as_bytes("CPU:0"), zero_expected)
self.assertIn(compat.as_bytes("CPU:1"), one_expected)
self.assertIn(compat.as_bytes("job:localhost"), zero_expected)
self.assertIn(compat.as_bytes("job:localhost"), one_expected)
@def_function.function
def _cond_wrapper2():
with ops.device("/device:GPU:0"):
return cond_v2.cond_v2(constant_op.constant(True), fn2, fn2)
def fn2():
self.assertEqual("/job:localhost/device:GPU:0",
constant_op.constant(3.0).op.device)
return test_ops.device_placement_op()
if test_util.is_gpu_available():
self.assertIn(compat.as_bytes("GPU:0"),
self.evaluate(_cond_wrapper2()))
else:
self.skipTest("Test requires a GPU to check GPU device placement.")
@def_function.function
def _cond_wrapper2():
with ops.device("/job:localhost/device:GPU:0"):
return cond_v2.cond_v2(constant_op.constant(True), fn2, fn2)
if test_util.is_gpu_available():
self.assertIn(compat.as_bytes("GPU:0"), self.evaluate(_cond_wrapper2()))
self.assertIn(
compat.as_bytes("job:localhost"), self.evaluate(_cond_wrapper2()))
else:
self.skipTest("Test requires a GPU to check GPU device placement.")
@parameterized.named_parameters([
dict(
testcase_name="Function",
functional_op_to_test=lambda fn: def_function.function(fn)()),
dict(
testcase_name="Cond",
functional_op_to_test=
lambda fn: cond_v2.cond_v2(constant_op.constant(True), fn, fn))
])
def testDeviceBeforeRemote(self, functional_op_to_test):
context.context().log_device_placement = True
def _fn():
local_op = test_ops.device_placement_op()
with ops.device("/job:worker/CPU:0"):
worker_op = test_ops.device_placement_op()
return local_op, worker_op
@def_function.function
def _wrapper():
with ops.device("/job:localhost"):
return functional_op_to_test(_fn)
local_expected, worker_expected = self.evaluate(_wrapper())
self.assertIn(compat.as_bytes("job:localhost"), local_expected)
self.assertIn(compat.as_bytes("job:worker"), worker_expected)
del _fn, _wrapper
# There's nothing special about localhost; if we swap roles (functional op
# on worker, op on localhost) the inner placement still wins.
def _fn2():
local_op = test_ops.device_placement_op()
with ops.device("/job:localhost/CPU:0"):
worker_op = test_ops.device_placement_op()
return local_op, worker_op
@def_function.function
def _wrapper2():
with ops.device("/job:worker"):
return functional_op_to_test(_fn2)
worker_expected, local_expected = self.evaluate(_wrapper2())
self.assertIn(compat.as_bytes("job:worker"), worker_expected)
self.assertIn(compat.as_bytes("job:localhost"), local_expected)
def testColocationBeforeCond(self):
with context.eager_mode():
def _fn():
result = test_ops.device_placement_op()
self.assertIn("colocation_test_op",
result.op.colocation_groups()[0].decode())
return result
def _fn():
result = test_ops.device_placement_op()
self.assertIn("colocation_test_op",
result.op.colocation_groups()[0].decode())
return result
@def_function.function(autograph=False)
def _cond_wrapper():
with ops.device("/device:CPU:0"):
op_on_cpu_0 = test_ops.device_placement_op(name="colocation_test_op")
with ops.device("/device:CPU:1"):
op_on_cpu_1 = test_ops.device_placement_op(
name="colocation_test_op_1")
condition = constant_op.constant(True)
with ops.colocate_with(op_on_cpu_0.op):
zero_expected = cond_v2.cond_v2(condition, _fn, _fn)
with ops.colocate_with(op_on_cpu_1.op):
one_expected = cond_v2.cond_v2(condition, _fn, _fn)
return zero_expected, one_expected
zero_expected, one_expected = self.evaluate(_cond_wrapper())
self.assertIn(compat.as_bytes("CPU:0"), zero_expected)
self.assertIn(compat.as_bytes("CPU:1"), one_expected)
@def_function.function(autograph=False)
def _cond_wrapper():
with ops.device("/device:CPU:0"):
op_on_cpu_0 = test_ops.device_placement_op(name="colocation_test_op")
with ops.device("/device:CPU:1"):
op_on_cpu_1 = test_ops.device_placement_op(name="colocation_test_op_1")
condition = constant_op.constant(True)
with ops.colocate_with(op_on_cpu_0.op):
zero_expected = cond_v2.cond_v2(condition, _fn, _fn)
with ops.colocate_with(op_on_cpu_1.op):
one_expected = cond_v2.cond_v2(condition, _fn, _fn)
return zero_expected, one_expected
zero_expected, one_expected = self.evaluate(_cond_wrapper())
self.assertIn(compat.as_bytes("CPU:0"), zero_expected)
self.assertIn(compat.as_bytes("CPU:1"), one_expected)
def testDeviceInAndOutOfCond(self):
with ops.Graph().as_default() as g:
@@ -1702,4 +1757,5 @@ def _has_node_with_op(run_metadata, op_type):
if __name__ == "__main__":
ops.enable_eager_execution()
test.main()