From 50429bdff14d62c46c8b417c7fefb73fd648fc86 Mon Sep 17 00:00:00 2001
From: Trent Lo <trentl@nvidia.com>
Date: Fri, 16 Aug 2019 13:37:19 -0700
Subject: [PATCH] Add a new unittest in mark_for_compilation_pass_test.

The new test tests that ClusterScopingPass works and MarkForCompilationPass
accordingly preserves the required clustering scopes.
---
 .../jit/mark_for_compilation_pass_test.cc     | 91 +++++++++++++++++++
 .../mark_for_compilation_pass_test_helper.cc  | 13 ++-
 .../mark_for_compilation_pass_test_helper.h   | 14 ++-
 3 files changed, 115 insertions(+), 3 deletions(-)

diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index e056ecd8272..577ddbfca00 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -1718,5 +1718,96 @@ TEST(XlaCompilationTest, UnsupportedEnterExitPattern) {
   EXPECT_EQ(0, clusters.size());
 }
 
+namespace {
+// Adds a "Stage" node called `name` to the graph under construction in
+// `builder`, with `values` as its inputs and `dtypes` as its "dtypes" attr.
+// Returns nullptr if the builder options already report an error.
+Node* MakeStageNode(GraphDefBuilder& builder, string name,
+                    std::initializer_list<DataType> dtypes,
+                    gtl::ArraySlice<ops::NodeOut> values) {
+  // `name` is read again below when constructing the NodeBuilder, so it must
+  // not be moved from here.
+  auto opts = builder.opts()
+                  .WithName(name)
+                  .WithAttr("dtypes", std::move(dtypes));
+  if (opts.HaveError()) {
+    return nullptr;
+  }
+
+  NodeBuilder node_builder(name, "Stage", opts.op_registry());
+  node_builder.Input(values);
+  return opts.FinalizeBuilder(&node_builder);
+}
+}  // namespace
+
+TEST(XlaCompilationTest, StagePipelinePreservedByClusterScopingPass) {
+  auto build_staged_graph = [](std::unique_ptr<Graph>* graph) -> Status {
+    // Construct a graph as below with two pipeline stages and test that nodes
+    // in different stages will not be merged if ClusterScopingPass is on.
+    //
+    //       b
+    //       |
+    //       v
+    // a -> add0 -> relu0 -> stage
+    //
+    //             b
+    //             |
+    //             v
+    // unstage -> add1 -> relu1
+    GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+    Node* a = ops::SourceOp("Const", builder.opts()
+                                         .WithName("a")
+                                         .WithAttr("dtype", DT_FLOAT)
+                                         .WithAttr("value", Tensor()));
+    Node* b = ops::SourceOp("Const", builder.opts()
+                                         .WithName("b")
+                                         .WithAttr("dtype", DT_FLOAT)
+                                         .WithAttr("value", Tensor()));
+    Node* unstage = ops::SourceOp(
+        "Unstage",
+        builder.opts().WithName("unstage").WithAttr("dtypes", {DT_FLOAT}));
+
+    Node* add0 = ops::BinaryOp("Add", a, b, builder.opts().WithName("add0"));
+    Node* add1 =
+        ops::BinaryOp("Add", unstage, b, builder.opts().WithName("add1"));
+    Node* relu0 = ops::UnaryOp("Relu", add0, builder.opts().WithName("relu0"));
+    ops::UnaryOp("Relu", add1, builder.opts().WithName("relu1"));
+    MakeStageNode(builder, "stage", {DT_FLOAT}, {relu0});
+
+    return GraphDefBuilderToGraph(builder, graph->get());
+  };
+
+  // All nodes go into the same cluster if ClusterScopingPass is off.
+  {
+    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+    TF_ASSERT_OK(build_staged_graph(&graph));
+
+    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
+        &graph,
+        MarkForCompilationPassTestHelper::Options().WithNoClusterScoping()));
+
+    std::unordered_map<string, string> clusters = GetClusters(*graph);
+    EXPECT_EQ(clusters["add0"], clusters["add1"]);
+    EXPECT_EQ(clusters["add0"], clusters["relu1"]);
+    EXPECT_EQ(clusters["relu0"], clusters["add1"]);
+    EXPECT_EQ(clusters["relu0"], clusters["relu1"]);
+  }
+
+  // By default, ClusterScopingPass is on and different pipeline stages should
+  // not be merged.
+  {
+    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+    TF_ASSERT_OK(build_staged_graph(&graph));
+
+    TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+    std::unordered_map<string, string> clusters = GetClusters(*graph);
+    EXPECT_NE(clusters["add0"], clusters["add1"]);
+    EXPECT_NE(clusters["add0"], clusters["relu1"]);
+    EXPECT_NE(clusters["relu0"], clusters["add1"]);
+    EXPECT_NE(clusters["relu0"], clusters["relu1"]);
+  }
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
index fa5abdfe508..44bd7b47d54 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
@@ -14,6 +14,8 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
+
+#include "tensorflow/compiler/jit/cluster_scoping_pass.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/public/session_options.h"
@@ -48,8 +50,15 @@ namespace tensorflow {
   opt_options.graph = graph;
   opt_options.session_options = &session_options;
   opt_options.flib_def = flib_def;
-  MarkForCompilationPass pass;
-  return pass.RunForTest(
+
+  // Run ClusterScopingPass ahead of MarkForCompilationPass unless the test
+  // opted out through Options::WithNoClusterScoping().
+  if (options.enable_cluster_scoping) {
+    ClusterScopingPass cluster_scoping_pass;
+    TF_RETURN_IF_ERROR(cluster_scoping_pass.Run(opt_options));
+  }
+
+  MarkForCompilationPass mark_for_compilation_pass;
+  return mark_for_compilation_pass.RunForTest(
       opt_options,
       /*disable_deadness_analysis=*/options.disable_deadness_analysis);
 }
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h
index b81fca43c80..f482a80f5b5 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h
@@ -24,8 +24,13 @@ class MarkForCompilationPassTestHelper {
   struct Options {
     bool enable_global_jit;
     bool disable_deadness_analysis;
+    // When true, ClusterScopingPass runs before MarkForCompilationPass.
+    bool enable_cluster_scoping;
 
-    Options() : enable_global_jit(true), disable_deadness_analysis(true) {}
+    Options()
+        : enable_global_jit(true),
+          disable_deadness_analysis(true),
+          enable_cluster_scoping(true) {}
 
     Options WithNoGlobalJit() {
       Options copy = *this;
@@ -38,6 +43,13 @@ class MarkForCompilationPassTestHelper {
       copy.disable_deadness_analysis = false;
       return copy;
     }
+
+    // Returns a copy of these options with cluster scoping switched off.
+    Options WithNoClusterScoping() {
+      Options copy = *this;
+      copy.enable_cluster_scoping = false;
+      return copy;
+    }
   };
 
   // Runs the MarkForCompilation pass on `graph` after assigning all nodes in