From a873a3e0756656b29a56ab7f525ea57fe6b1da3a Mon Sep 17 00:00:00 2001 From: Zheng Xu Date: Tue, 15 Sep 2020 16:55:18 -0700 Subject: [PATCH] Fix colocation attr on Switch in lower_if_op.cc PiperOrigin-RevId: 331886863 Change-Id: I9cd889b2c380b31ae025f4050eece83f7851a4aa --- tensorflow/core/common_runtime/lower_if_op.cc | 16 +++-- .../kernel_tests/control_flow_ops_py_test.py | 59 ++++--------------- 2 files changed, 17 insertions(+), 58 deletions(-) diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc index ec52f20fac9..5cde4f9049c 100644 --- a/tensorflow/core/common_runtime/lower_if_op.cc +++ b/tensorflow/core/common_runtime/lower_if_op.cc @@ -166,15 +166,13 @@ Status CondBuilder::AddInput(Node* src, int src_output) { // the same device as the input node (if set) and sets the colocation _class // attr. It also ignores the existing colocation constraints on the input node // using colocate_with(ignore_existing=True). - TF_RETURN_IF_ERROR( - NodeBuilder(NewName(src->name()), "Switch", graph_->op_registry(), - &debug_info) - .Input(src, src_output) - .Input(pred_) - .Device(src->requested_device()) - .Attr(kColocationAttrName, - {absl::StrCat(kColocationGroupPrefix, src->name())}) - .Finalize(graph_, &input)); + TF_RETURN_IF_ERROR(NodeBuilder(NewName(src->name()), "Switch", + graph_->op_registry(), &debug_info) + .Input(src, src_output) + .Input(pred_) + .Device(src->requested_device()) + .Attr("_class", {src->name()}) + .Finalize(graph_, &input)); then_call_builder_.Input(input, kThenBranch); else_call_builder_.Input(input, kElseBranch); return Status::OK(); diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 7cedeef8916..2e13414f720 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -34,8 +34,6 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python import tf2 from tensorflow.python.client import device_lib from tensorflow.python.client import session -from tensorflow.python.data.experimental.ops import cardinality -from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import function as eager_function @@ -722,9 +720,8 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase): # We expect that everything runs on CPU, even if GPU is available. self.assertEqual(len(run_metadata.partition_graphs), 1) - def _count_matching_switch_nodes_on_device(self, run_metadata, device_str, - dtype): - # Returns the number of Switch nodes with type dtype placed on + def _count_matching_switch_nodes_on_device(self, run_metadata, device_str): + # Returns the number of Switch nodes with type float32 placed on # `device_str`. device_graphs = [ g for g in run_metadata.partition_graphs @@ -732,14 +729,14 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase): ] self.assertLen(device_graphs, 1) switch_nodes = [ - n for n in device_graphs[0].node - if n.op == "Switch" and n.attr["T"].type == dtype.as_datatype_enum + n for n in device_graphs[0].node if n.op == "Switch" and + n.attr["T"].type == dtypes.float32.as_datatype_enum ] return len(switch_nodes) @test_util.run_gpu_only @test_util.run_deprecated_v1 - def testCondSwitchColocatedWithInputWhenInputExplicitlyPlacedOnCPU(self): + def testCondSwitchColocatedWithInputWhenInputOnCPU(self): x = array_ops.placeholder(dtypes.float32) # `arg` is used in the cond then branch so a Switch node is created for it. @@ -759,46 +756,12 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase): options = config_pb2.RunOptions(output_partition_graphs=True) sess.run( r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata) - self.assertLen(run_metadata.partition_graphs, 2) + self.assertEqual(len(run_metadata.partition_graphs), 2) # Check that the Switch for `arg` gets placed on CPU. self.assertEqual( - self._count_matching_switch_nodes_on_device(run_metadata, "CPU", - dtypes.float32), 1) + self._count_matching_switch_nodes_on_device(run_metadata, "CPU"), 1) self.assertEqual( - self._count_matching_switch_nodes_on_device(run_metadata, "GPU", - dtypes.float32), 0) - - @test_util.run_gpu_only - @test_util.run_deprecated_v1 - def testCondSwitchColocatedWithInputWhenInputPlacedOnCPU(self): - x = array_ops.placeholder(dtypes.float32) - - # `arg` is used in the cond then branch so a Switch node is created for it. - # We test that the Switch node gets placed on the same device as `arg`. - # Since arg is a dataset (and only has a CPU kernel), it gets placed on CPU - # by placer. - arg = dataset_ops.Dataset.range(8) - - def true_fn(): - return cardinality.cardinality(arg) - - r = control_flow_ops.cond( - constant_op.constant(True), true_fn, - lambda: constant_op.constant(0, dtypes.int64)) - - with session.Session() as sess: - run_metadata = config_pb2.RunMetadata() - options = config_pb2.RunOptions(output_partition_graphs=True) - sess.run( - r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata) - self.assertLen(run_metadata.partition_graphs, 2) - # Check that the Switch for `arg` gets placed on CPU. - self.assertEqual( - self._count_matching_switch_nodes_on_device(run_metadata, "CPU", - dtypes.variant), 1) - self.assertEqual( - self._count_matching_switch_nodes_on_device(run_metadata, "GPU", - dtypes.variant), 0) + self._count_matching_switch_nodes_on_device(run_metadata, "GPU"), 0) @test_util.run_gpu_only @test_util.run_deprecated_v1 @@ -824,11 +787,9 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase): self.assertEqual(len(run_metadata.partition_graphs), 2) # Check that the Switch for `arg` gets placed on GPU. self.assertEqual( - self._count_matching_switch_nodes_on_device(run_metadata, "CPU", - dtypes.float32), 0) + self._count_matching_switch_nodes_on_device(run_metadata, "CPU"), 0) self.assertEqual( - self._count_matching_switch_nodes_on_device(run_metadata, "GPU", - dtypes.float32), 1) + self._count_matching_switch_nodes_on_device(run_metadata, "GPU"), 1) def testCondAccessTrueBranchTensorInFalseBranchRaises(self):