From 50429bdff14d62c46c8b417c7fefb73fd648fc86 Mon Sep 17 00:00:00 2001
From: Trent Lo <trentl@nvidia.com>
Date: Fri, 16 Aug 2019 13:37:19 -0700
Subject: [PATCH] Add a new unittest in mark_for_compilation_pass_test.

The new test tests that ClusterScopingPass works and MarkForCompilationPass
accordingly preserves the required clustering scopes.
---
 .../jit/mark_for_compilation_pass_test.cc     | 86 +++++++++++++++++++
 .../mark_for_compilation_pass_test_helper.cc  | 12 ++-
 .../mark_for_compilation_pass_test_helper.h   | 12 ++-
 3 files changed, 107 insertions(+), 3 deletions(-)

diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index e056ecd8272..577ddbfca00 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -1718,5 +1718,91 @@ TEST(XlaCompilationTest, UnsupportedEnterExitPattern) {
   EXPECT_EQ(0, clusters.size());
 }
 
+namespace {
+// Adds a "Stage" node named `name` fed by `values`; nullptr on an opts error.
+// `name` is deliberately not moved: NodeBuilder below reads it a second time.
+Node* MakeStageNode(GraphDefBuilder& builder, string name,
+                    std::initializer_list<DataType> dtypes,
+                    gtl::ArraySlice<ops::NodeOut> values) {
+  auto opts = builder.opts()
+                  .WithName(name)
+                  .WithAttr("dtypes", std::move(dtypes));
+  if (opts.HaveError()) return nullptr;
+
+  NodeBuilder node_builder(name, "Stage", opts.op_registry());
+  node_builder.Input(values);
+  return opts.FinalizeBuilder(&node_builder);
+}
+}  // namespace
+
+TEST(XlaCompilationTest, StagePipelinePreservedByClusterScopingPass) {
+  auto build_staged_graph = [](std::unique_ptr<Graph>* graph) -> Status {
+    // Builds the two-stage pipeline drawn below into `*graph`. The stages are
+    // linked only through the shared constant `b`; Stage/Unstage bound them.
+    //
+    //       b
+    //       |
+    //       v
+    // a -> add0 -> relu0 -> stage
+    //
+    //             b
+    //             |
+    //             v
+    // unstage -> add1 -> relu1
+    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+    Node* a = ops::SourceOp("Const", builder.opts()
+                                         .WithName("a")
+                                         .WithAttr("dtype", DT_FLOAT)
+                                         .WithAttr("value", Tensor()));
+    Node* b = ops::SourceOp("Const", builder.opts()
+                                         .WithName("b")
+                                         .WithAttr("dtype", DT_FLOAT)
+                                         .WithAttr("value", Tensor()));
+    Node* unstage = ops::SourceOp(
+        "Unstage",
+        builder.opts().WithName("unstage").WithAttr("dtypes", {DT_FLOAT}));
+
+    Node* add0 = ops::BinaryOp("Add", a, b, builder.opts().WithName("add0"));
+    Node* add1 =
+        ops::BinaryOp("Add", unstage, b, builder.opts().WithName("add1"));
+    Node* relu0 = ops::UnaryOp("Relu", add0, builder.opts().WithName("relu0"));
+    ops::UnaryOp("Relu", add1, builder.opts().WithName("relu1"));
+    MakeStageNode(builder, "stage", {DT_FLOAT}, {relu0});
+
+    return GraphDefBuilderToGraph(builder, graph->get());
+  };
+
+  // With ClusterScopingPass off, both stages are expected in one cluster.
+  {
+    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+    TF_ASSERT_OK(build_staged_graph(&graph));
+
+    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
+        &graph,
+        MarkForCompilationPassTestHelper::Options().WithNoClusterScoping()));
+
+    std::unordered_map<string, string> clusters = GetClusters(*graph);
+    EXPECT_EQ(clusters["add0"], clusters["add1"]);
+    EXPECT_EQ(clusters["add0"], clusters["relu1"]);
+    EXPECT_EQ(clusters["relu0"], clusters["add1"]);
+    EXPECT_EQ(clusters["relu0"], clusters["relu1"]);
+  }
+
+  // ClusterScopingPass is on by default, so the stages must stay separate.
+  // NOTE(review): clusters[x] is "" if x is absent; EXPECT_EQ above allows it.
+  {
+    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+    TF_ASSERT_OK(build_staged_graph(&graph));
+
+    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+    std::unordered_map<string, string> clusters = GetClusters(*graph);
+    EXPECT_NE(clusters["add0"], clusters["add1"]);
+    EXPECT_NE(clusters["add0"], clusters["relu1"]);
+    EXPECT_NE(clusters["relu0"], clusters["add1"]);
+    EXPECT_NE(clusters["relu0"], clusters["relu1"]);
+  }
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
index fa5abdfe508..44bd7b47d54 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
@@ -14,6 +14,8 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
+
+#include "tensorflow/compiler/jit/cluster_scoping_pass.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/public/session_options.h"
@@ -48,8 +50,14 @@ namespace tensorflow {
   opt_options.graph = graph;
   opt_options.session_options = &session_options;
   opt_options.flib_def = flib_def;
-  MarkForCompilationPass pass;
-  return pass.RunForTest(
+
+  if (options.enable_cluster_scoping) {
+    ClusterScopingPass cluster_scoping_pass;
+    TF_RETURN_IF_ERROR(cluster_scoping_pass.Run(opt_options));
+  }
+
+  MarkForCompilationPass mark_for_compilation_pass;
+  return mark_for_compilation_pass.RunForTest(
       opt_options,
       /*disable_deadness_analysis=*/options.disable_deadness_analysis);
 }
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h
index b81fca43c80..f482a80f5b5 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h
@@ -24,8 +24,12 @@ class MarkForCompilationPassTestHelper {
   struct Options {
     bool enable_global_jit;
     bool disable_deadness_analysis;
+    bool enable_cluster_scoping;
 
-    Options() : enable_global_jit(true), disable_deadness_analysis(true) {}
+    Options()
+        : enable_global_jit(true),
+          disable_deadness_analysis(true),
+          enable_cluster_scoping(true) {}
 
     Options WithNoGlobalJit() {
       Options copy = *this;
@@ -38,6 +42,12 @@ class MarkForCompilationPassTestHelper {
       copy.disable_deadness_analysis = false;
       return copy;
     }
+
+    Options WithNoClusterScoping() {
+      Options without_scoping(*this);  // work on a copy; *this is untouched
+      without_scoping.enable_cluster_scoping = false;
+      return without_scoping;
+    }
   };
 
   // Runs the MarkForCompilation pass on `graph` after assigning all nodes in