From 673b993983f37f332ff70cdb642305f69089337d Mon Sep 17 00:00:00 2001
From: Rohan Jain <rohanj@google.com>
Date: Wed, 21 Oct 2020 05:44:08 -0700
Subject: [PATCH] Ensuring that the Switch op used as a pivot is always placed
 on the CPU. For this we set a private attribute _PivotSwitch while creating
 this op and then make sure that the device overwriting logic in
 GraphPartition isn't executed for this op.

Note: Had to fix up control_flow_ops_py_test so that we don't expect a GPU graph when we don't get one. The reason is that now since we already know the switch_pred is going to be placed on CPU, the placer ensures that its input is placed on the CPU as well and we end up saving a copy. This means there is no GPU graph when we partition.
PiperOrigin-RevId: 338246477
Change-Id: I5641c9ae1b2d593a2996947bafe92b22cb63371d
---
 tensorflow/core/common_runtime/BUILD          |   3 +-
 tensorflow/core/common_runtime/lower_if_op.cc |  11 +-
 .../core/common_runtime/lower_if_op_test.cc   | 109 ++++++++++++++++++
 tensorflow/core/graph/graph_partition.cc      |   7 ++
 .../kernel_tests/control_flow_ops_py_test.py  |   3 +-
 5 files changed, 130 insertions(+), 3 deletions(-)

diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD
index cf053b0af51..fcbf0c52905 100644
--- a/tensorflow/core/common_runtime/BUILD
+++ b/tensorflow/core/common_runtime/BUILD
@@ -2522,10 +2522,11 @@ tf_cc_test(
     ],
 )
 
-tf_cc_test(
+tf_cc_test_gpu(
     name = "lower_if_op_test",
     size = "small",
     srcs = ["lower_if_op_test.cc"],
+    tags = tf_cuda_tests_tags(),
     deps = [
         ":core_cpu",
         ":core_cpu_internal",
diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc
index ff010ad8a63..2a0e5d35de5 100644
--- a/tensorflow/core/common_runtime/lower_if_op.cc
+++ b/tensorflow/core/common_runtime/lower_if_op.cc
@@ -148,13 +148,22 @@ Status CondBuilder::SetColocationAndFinalize(NodeBuilder node_builder,
 Status CondBuilder::CreatePivotNodes() {
   // Construct the basic cond body (consisting of feeding in the predicate to
   // create pivot nodes).
+
+  // This is a special pivot switch node for lowering. We mark this with a
+  // special _PivotSwitch attr on it as later on in the graph partitioner we
+  // do some special placement for Switch nodes and its necessary to distinguish
+  // between a "normal" Switch node and one of these pivot switches. We would
+  // like to place this node on the CPU always as the pred_ will be on the CPU
+  // as well (either a CPU op output or a GPU op with HostMemory annotation).
+  // TODO(b/171321391): Fix this for NUMA cases.
   Node* switch_pred;
   TF_RETURN_IF_ERROR(
       SetColocationAndFinalize(NodeBuilder(NewName("switch_pred"), "Switch",
                                            graph_->op_registry(), &debug_info_)
                                    .Input(NodeOut(pred_))
                                    .Input(NodeOut(pred_))
-                                   .Device(if_op_->requested_device()),
+                                   .Attr("_PivotSwitch", true)
+                                   .Device("/CPU:0"),
                                graph_, &switch_pred));
   control_predecessor_ = switch_pred;
   TF_RETURN_IF_ERROR(
diff --git a/tensorflow/core/common_runtime/lower_if_op_test.cc b/tensorflow/core/common_runtime/lower_if_op_test.cc
index cf7d35409bb..b0304cfe293 100644
--- a/tensorflow/core/common_runtime/lower_if_op_test.cc
+++ b/tensorflow/core/common_runtime/lower_if_op_test.cc
@@ -147,6 +147,115 @@ TEST(LowerIfOpTest, Simple) {
   }
 }
 
+TEST(LowerIfOpTest, GPUPlacement) {
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+  // Add test functions for then and else branch.
+  FunctionDefLibrary f_lib_proto;
+  *(f_lib_proto.add_function()) = test::function::XTimesTwo();
+  *(f_lib_proto.add_function()) = test::function::XTimesFour();
+
+  // Construct simple conditional that switches on `pred` and operates only on
+  // single input `A`.
+  Scope root = Scope::NewRootScope().ExitOnError();
+  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
+  auto a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
+  auto x = ops::Placeholder(root.WithOpName("X"), DT_INT32);
+  auto y = ops::Placeholder(root.WithOpName("Y"), DT_INT32);
+  Node* pred;
+  TF_ASSERT_OK(NodeBuilder("greater", "Greater", &root.graph()->flib_def())
+                   .Input(x.node())
+                   .Input(y.node())
+                   .Device("/GPU:0")
+                   .Finalize(root.graph(), &pred));
+  Node* written_if;
+  std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
+  TF_ASSERT_OK(
+      NodeBuilder("if", "If", &root.graph()->flib_def())
+          .Input(pred)
+          .Input(inputs)
+          .Attr("then_branch", FuncAttr("XTimesTwo"))
+          .Attr("else_branch", FuncAttr("XTimesFour"))
+          .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
+          .Attr("Tout", {DT_INT32})
+          .Device("/GPU:0")
+          .Finalize(root.graph(), &written_if));
+  TF_ASSERT_OK(root.DoShapeInference(written_if));
+  TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+  // The input graph has no switch or merge nodes.
+  int node_called_if_count = 0;
+  for (const auto* op : graph->op_nodes()) {
+    ASSERT_FALSE(op->IsSwitch());
+    ASSERT_FALSE(op->IsMerge());
+    if (op->name() == "if") {
+      ++node_called_if_count;
+    }
+  }
+  ASSERT_EQ(node_called_if_count, 1);
+
+  TF_ASSERT_OK(Rewrite(&graph));
+
+  // Verify the resultant graph has switch and merge nodes, and a node called
+  // `if` (but not If nodes).
+  int switch_count = 0;
+  int merge_count = 0;
+  node_called_if_count = 0;
+  for (const auto* op : graph->op_nodes()) {
+    if (op->IsSwitch()) {
+      ++switch_count;
+    }
+    if (op->IsMerge()) {
+      ++merge_count;
+    }
+    ASSERT_NE(op->type_string(), "If");
+    if (op->name() == "if") {
+      ++node_called_if_count;
+    }
+  }
+  // One switch for predicate and one for input (A).
+  ASSERT_EQ(switch_count, 2);
+  // One merge for the single output value of then and else, and one more merge
+  // to enforce then and else function call execution (`branch_executed` node).
+  ASSERT_EQ(merge_count, 2);
+  ASSERT_EQ(node_called_if_count, 1);
+
+  // Verify execution.
+  ClientSession session(root, SessionOptionsWithInlining());
+  {
+    RunMetadata metadata;
+    RunOptions options;
+    options.set_output_partition_graphs(true);
+    ClientSession::FeedType feeds;
+    feeds.emplace(Output(x.node()), Input::Initializer(5));
+    feeds.emplace(Output(y.node()), Input::Initializer(10));
+    feeds.emplace(Output(a.node()), Input::Initializer(10));
+    std::vector<Tensor> out_tensors;
+    TF_ASSERT_OK(session.Run(options, feeds, {Output(written_if)}, {},
+                             &out_tensors, &metadata));
+    GraphDef cpu_graph = metadata.partition_graphs(1);
+    int num_cpu_switch = 0;
+    for (const auto& node : cpu_graph.node()) {
+      if (node.op() == "Switch") {
+        ++num_cpu_switch;
+      }
+    }
+    EXPECT_EQ(num_cpu_switch, 2);
+    EXPECT_EQ(out_tensors.size(), 1);
+    EXPECT_EQ(out_tensors[0].scalar<int>()(), 40);
+  }
+  {
+    ClientSession::FeedType feeds;
+    feeds.emplace(Output(x.node()), Input::Initializer(10));
+    feeds.emplace(Output(y.node()), Input::Initializer(5));
+    feeds.emplace(Output(a.node()), Input::Initializer(10));
+    std::vector<Tensor> out_tensors;
+    TF_ASSERT_OK(session.Run(feeds, {Output(written_if)}, &out_tensors));
+    EXPECT_EQ(out_tensors.size(), 1);
+    EXPECT_EQ(out_tensors[0].scalar<int>()(), 20);
+  }
+}
+
 TEST(LowerIfOpTest, BranchFunctionsWithoutOutputs) {
   using ::tensorflow::test::function::GDef;
   using ::tensorflow::test::function::NDef;
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
index bf57e263441..7680bcacba5 100644
--- a/tensorflow/core/graph/graph_partition.cc
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -371,6 +371,13 @@ NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef,
 void OptimizeControlFlowColocation(Graph* graph) {
   auto visit = [](Node* node) {
     if (IsSwitch(node)) {
+      // Pivot Switch nodes (which are also of type Switch) are already placed
+      // on the CPU and colocated with its inputs that are also already on the
+      // CPU (or might be placed on GPU but in host memory).
+      if (HasNodeAttr(node->def(), "_PivotSwitch")) {
+        DCHECK(node->requested_device().find("CPU") != string::npos);
+        return;
+      }
       for (const Edge* in_edge : node->in_edges()) {
         if (in_edge->dst_input() == 0) {
           // Colocate with the data input.
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 54bbd2b2e9e..532dac1d85a 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -730,6 +730,8 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
         g for g in run_metadata.partition_graphs
         if device_str in g.node[0].device
     ]
+    if not device_graphs:
+      return 0
     self.assertLen(device_graphs, 1)
     switch_nodes = [
         n for n in device_graphs[0].node
@@ -759,7 +761,6 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
       options = config_pb2.RunOptions(output_partition_graphs=True)
       sess.run(
           r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
-      self.assertLen(run_metadata.partition_graphs, 2)
       # Check that the Switch for `arg` gets placed on CPU.
       self.assertEqual(
           self._count_matching_switch_nodes_on_device(run_metadata, "CPU",