From 4bff6c2441ba7aff209d4234166cfb288f46ef87 Mon Sep 17 00:00:00 2001 From: Trent Lo Date: Wed, 5 Jun 2019 16:48:16 -0700 Subject: [PATCH 1/5] Add an auto-clustering heuristic for RNN performance. - The heuristic is to avoid creating dependency between while loop cond and body computations. This unnecessary, artificial dependency can greatly hurt the while loop performance. --- .../compiler/jit/mark_for_compilation_pass.cc | 68 +++++++++++++++++++ .../jit/mark_for_compilation_pass_test.cc | 29 ++++++++ 2 files changed, 97 insertions(+) diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 22b47ae5fdb..70a41fd66ea 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1010,6 +1010,45 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() { return Status::OK(); } +StatusOr<bool> IsIdentityDrivingConstsInLoop(Node* node) { + if (!node->IsIdentity()) { + return false; + } + + // Check if one of the ancestors is a Switch node. + Node* switch_node = nullptr; + for (const Edge* e : node->in_edges()) { + if (e->src()->IsSwitch()) { + switch_node = e->src(); + break; + } + } + if (switch_node == nullptr) { + return false; + } + + // Check if the Switch is driven by LoopCond. + const Node* maybe_loopcond; + TF_RETURN_IF_ERROR(switch_node->input_node(1, &maybe_loopcond)); + if (!maybe_loopcond->IsLoopCond()) { + return false; + } + + // Check if the Identity is driving any const nodes through a control edge.
+ bool driving_any_consts = false; + for (const Edge* e : node->out_edges()) { + if (e->dst()->IsConstant() && e->IsControlEdge()) { + driving_any_consts = true; + break; + } + } + if (!driving_any_consts) { + return false; + } + + return true; +} + Status MarkForCompilationPassImpl::FindCompilationCandidates() { OptimizerOptions opts; std::unique_ptr pflr( @@ -1131,6 +1170,35 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() { } } + // This is a heuristic to avoid creating dependency between while loop + // condition and body computations. Dependency between them can be created + // if a special Identity node in the following pattern is clustered in. + // That is, an Identity node in the loop cond computation is used to drive + // const nodes consumed by the loop body. If this Identity node goes into + // the same cluster with nodes from the loop body, extra dependency is + // created between the loop cond and body computations and it hinders the + // progression of the loop cond computation at runtime with significant + // overhead. Specifically, we look for the below pattern and do not cluster + // in this Identity to avoid the described issue. Since Identity has low + // execution cost in native TF, the fact that this heuristic gives up these + // special Identity nodes as candidates should not harm any performance. If + // other considerations emerge in the future, we can revisit the heuristic + // and only disallow these Identities to go into the cluster with nodes from + // the loop body but still consider them candidates. + // + // LoopCond -> + // Merge -> Switch -> Identity -> i++ -> ... 
-> NextIteration + // ..> Const -> LoopBody + // (control edge) + TF_ASSIGN_OR_RETURN(bool is_identity_driving_consts_in_loop, + IsIdentityDrivingConstsInLoop(node)); + if (is_identity_driving_consts_in_loop) { + VLOG(2) << "Rejecting " << node->name() + << ": including it can create dependencies between while loop " + "condition and body computations with runtime overhead."; + continue; + } + compilation_candidates_.insert(node); --(*debug_options_.fuel); } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 46ab220d1bd..5afea26d449 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -1677,5 +1677,34 @@ TEST(XlaCompilationTest, IterationIncrementAndGroupDeps) { EXPECT_EQ(clusters["some_ctrl_input"], clusters["matmul_0"]); } +TEST(XlaCompilationTest, DontClusterTheSpecialIdentityDrivingConstsInLoop) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output cond = ops::Placeholder(root.WithOpName("cond"), DT_BOOL); + Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT); + Output loop_cond = ops::LoopCond(root.WithOpName("loop_cond"), cond); + ops::Switch switch_node(root.WithOpName("switch"), value, loop_cond); + + // The special identity driving consts. Expect that is is not in any + // clusters. 
+ Output identity = + ops::Identity(root.WithOpName("identity"), switch_node.output_true); + Output const_node = ops::Const(root.WithOpName("const"), 1.0f); + root.graph()->AddControlEdge(identity.node(), const_node.node()); + Output tanh0 = ops::Tanh(root.WithOpName("tanh0"), const_node); + Output tanh1 = ops::Tanh(root.WithOpName("tanh1"), tanh0); + Output add = ops::Add(root.WithOpName("add"), const_node, tanh1); + + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_EXPECT_OK(root.ToGraph(graph.get())); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation( + &graph, + MarkForCompilationPassTestHelper::Options().WithDeadnessAnalysis())); + auto clusters = GetClusters(*graph); + + EXPECT_EQ(clusters["identity"], ""); +} + } // namespace } // namespace tensorflow From 0e227ada63410fee179de624adc2d996ea1fd1df Mon Sep 17 00:00:00 2001 From: Trent Lo Date: Wed, 5 Jun 2019 19:25:44 -0700 Subject: [PATCH 2/5] Update results of auto-clustering integration tests. --- ...pens2s_gnmt_mixed_precision.golden_summary | 178 +++++++++--------- 1 file changed, 84 insertions(+), 94 deletions(-) diff --git a/tensorflow/compiler/jit/tests/opens2s_gnmt_mixed_precision.golden_summary b/tensorflow/compiler/jit/tests/opens2s_gnmt_mixed_precision.golden_summary index 108ba4a28a5..aa2754cf4d1 100644 --- a/tensorflow/compiler/jit/tests/opens2s_gnmt_mixed_precision.golden_summary +++ b/tensorflow/compiler/jit/tests/opens2s_gnmt_mixed_precision.golden_summary @@ -1,8 +1,8 @@ -Clustered nodes: 1988 -Unclustered nodes: 3960 +Clustered nodes: 1962 +Unclustered nodes: 3974 Number of clusters: 29 -unclustered size 3960 +unclustered size 3974 Add 17 AddN 1 ApplyAdam 38 @@ -14,7 +14,7 @@ unclustered size 3960 Cast 8 ConcatOffset 10 ConcatV2 2 - Const 704 + Const 708 ControlTrigger 5 DynamicStitch 1 Enter 874 @@ -24,7 +24,7 @@ unclustered size 3960 FloorDiv 1 FloorMod 1 GreaterEqual 7 - Identity 105 + Identity 113 IsVariableInitialized 1 IteratorGetNext 1 IteratorV2 1 @@
-42,7 +42,7 @@ unclustered size 3960 RefSwitch 166 Reshape 2 ScatterAdd 4 - Shape 4 + Shape 6 ShapeN 10 Size 2 Snapshot 1 @@ -169,15 +169,13 @@ cluster 7 size 11 Mul 2 Pow 1 Sub 1 -cluster 10 size 14 - Add 2 +cluster 10 size 8 + Add 1 All 2 - Const 4 + Const 2 GreaterEqual 1 - Identity 1 Less 1 LogicalOr 1 - Shape 2 cluster 11 size 226 Add 24 BatchMatMulV2 1 @@ -226,13 +224,12 @@ cluster 12 size 430 TanhGrad 17 Tile 2 ZerosLike 1 -cluster 13 size 25 - Add 3 +cluster 13 size 20 + Add 2 BiasAdd 1 ConcatV2 1 - Const 3 + Const 1 GreaterEqual 1 - Identity 2 MatMul 1 Mul 3 Select 3 @@ -256,13 +253,12 @@ cluster 14 size 52 Slice 2 Sum 9 TanhGrad 2 -cluster 15 size 25 - Add 3 +cluster 15 size 20 + Add 2 BiasAdd 1 ConcatV2 1 - Const 3 + Const 1 GreaterEqual 1 - Identity 2 MatMul 1 Mul 3 Select 3 @@ -290,14 +286,13 @@ cluster 17 size 52 Slice 2 Sum 9 TanhGrad 2 -cluster 19 size 30 - Add 3 +cluster 19 size 25 + Add 2 BiasAdd 1 Cast 1 ConcatV2 1 - Const 5 + Const 3 GreaterEqual 2 - Identity 2 MatMul 1 Mul 5 Select 2 @@ -305,77 +300,7 @@ cluster 19 size 30 Snapshot 1 Split 1 Tanh 2 -cluster 20 size 23 - Add 3 - BiasAdd 1 - Cast 1 - ConcatV2 1 - GreaterEqual 1 - Identity 1 - MatMul 1 - Mul 5 - Select 2 - Sigmoid 3 - Snapshot 1 - Split 1 - Tanh 2 -cluster 21 size 23 - Add 3 - BiasAdd 1 - Cast 1 - ConcatV2 1 - GreaterEqual 1 - Identity 1 - MatMul 1 - Mul 5 - Select 2 - Sigmoid 3 - Snapshot 1 - Split 1 - Tanh 2 -cluster 22 size 23 - Add 3 - BiasAdd 1 - Cast 1 - ConcatV2 1 - GreaterEqual 1 - Identity 1 - MatMul 1 - Mul 5 - Select 2 - Sigmoid 3 - Snapshot 1 - Split 1 - Tanh 2 -cluster 23 size 23 - Add 3 - BiasAdd 1 - Cast 1 - ConcatV2 1 - GreaterEqual 1 - Identity 1 - MatMul 1 - Mul 5 - Select 2 - Sigmoid 3 - Snapshot 1 - Split 1 - Tanh 2 -cluster 24 size 24 - Add 3 - BiasAdd 1 - Cast 1 - ConcatV2 1 - GreaterEqual 1 - Identity 1 - MatMul 1 - Mul 5 - Select 3 - Sigmoid 3 - Snapshot 1 - Split 1 - Tanh 2 -cluster 25 size 363 +cluster 20 size 363 Add 12 AddN 28 BiasAddGrad 6 @@ -391,6 
+316,71 @@ cluster 25 size 363 Slice 12 Sum 76 TanhGrad 12 +cluster 21 size 22 + Add 3 + BiasAdd 1 + Cast 1 + ConcatV2 1 + GreaterEqual 1 + MatMul 1 + Mul 5 + Select 2 + Sigmoid 3 + Snapshot 1 + Split 1 + Tanh 2 +cluster 22 size 22 + Add 3 + BiasAdd 1 + Cast 1 + ConcatV2 1 + GreaterEqual 1 + MatMul 1 + Mul 5 + Select 2 + Sigmoid 3 + Snapshot 1 + Split 1 + Tanh 2 +cluster 23 size 22 + Add 3 + BiasAdd 1 + Cast 1 + ConcatV2 1 + GreaterEqual 1 + MatMul 1 + Mul 5 + Select 2 + Sigmoid 3 + Snapshot 1 + Split 1 + Tanh 2 +cluster 24 size 22 + Add 3 + BiasAdd 1 + Cast 1 + ConcatV2 1 + GreaterEqual 1 + MatMul 1 + Mul 5 + Select 2 + Sigmoid 3 + Snapshot 1 + Split 1 + Tanh 2 +cluster 25 size 23 + Add 3 + BiasAdd 1 + Cast 1 + ConcatV2 1 + GreaterEqual 1 + MatMul 1 + Mul 5 + Select 3 + Sigmoid 3 + Snapshot 1 + Split 1 + Tanh 2 cluster 26 size 9 AddN 1 MatMul 2 From 466f9adf03ef9f740e9799f3c3a40a1b6e163db7 Mon Sep 17 00:00:00 2001 From: Trent Lo Date: Thu, 6 Jun 2019 14:48:07 -0700 Subject: [PATCH 3/5] Some code polishing. --- .../compiler/jit/mark_for_compilation_pass.cc | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 70a41fd66ea..05307253c6e 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1015,10 +1015,10 @@ StatusOr IsIdentityDrivingConstsInLoop(Node* node) { return false; } - // Check if one of the ancestors is a Switch node. + // Check if the Identity is driven by a Switch on its true path. Node* switch_node = nullptr; for (const Edge* e : node->in_edges()) { - if (e->src()->IsSwitch()) { + if (e->src()->IsSwitch() && e->src_output() == 1) { switch_node = e->src(); break; } @@ -1035,13 +1035,10 @@ StatusOr IsIdentityDrivingConstsInLoop(Node* node) { } // Check if the Identity is driving any const nodes through a control edge. 
- bool driving_any_consts = false; - for (const Edge* e : node->out_edges()) { - if (e->dst()->IsConstant() && e->IsControlEdge()) { - driving_any_consts = true; - break; - } - } + bool driving_any_consts = + absl::c_any_of(node->out_edges(), [](const Edge* e) { + return e->dst()->IsConstant() && e->IsControlEdge(); + }); if (!driving_any_consts) { return false; } From d6008ae64743e77877b3602d7b7b1840f1f4d9d8 Mon Sep 17 00:00:00 2001 From: Trent Lo Date: Thu, 6 Jun 2019 15:15:45 -0700 Subject: [PATCH 4/5] Minor style polishing. --- .../compiler/jit/mark_for_compilation_pass.cc | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 05307253c6e..07a2b4f036b 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1016,16 +1016,13 @@ StatusOr IsIdentityDrivingConstsInLoop(Node* node) { } // Check if the Identity is driven by a Switch on its true path. - Node* switch_node = nullptr; - for (const Edge* e : node->in_edges()) { - if (e->src()->IsSwitch() && e->src_output() == 1) { - switch_node = e->src(); - break; - } - } - if (switch_node == nullptr) { + auto it = absl::c_find_if(node->in_edges(), [](const Edge* e) { + return e->src()->IsSwitch() && e->src_output() == 1; + }); + if (it == node->in_edges().end()) { return false; } + const Node* switch_node = (*it)->src(); // Check if the Switch is driven by LoopCond. const Node* maybe_loopcond; From 49f8021d459c63de3b7e38f5b6220faa77d1be70 Mon Sep 17 00:00:00 2001 From: Trent Lo Date: Fri, 7 Jun 2019 14:11:24 -0700 Subject: [PATCH 5/5] Some more minor code/comment polishing. 
--- tensorflow/compiler/jit/mark_for_compilation_pass.cc | 6 +++--- tensorflow/compiler/jit/mark_for_compilation_pass_test.cc | 7 +++++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 07a2b4f036b..95befb73414 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1025,9 +1025,9 @@ StatusOr IsIdentityDrivingConstsInLoop(Node* node) { const Node* switch_node = (*it)->src(); // Check if the Switch is driven by LoopCond. - const Node* maybe_loopcond; - TF_RETURN_IF_ERROR(switch_node->input_node(1, &maybe_loopcond)); - if (!maybe_loopcond->IsLoopCond()) { + const Node* maybe_loop_cond; + TF_RETURN_IF_ERROR(switch_node->input_node(1, &maybe_loop_cond)); + if (!maybe_loop_cond->IsLoopCond()) { return false; } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 5afea26d449..bebff82abda 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -1677,6 +1677,11 @@ TEST(XlaCompilationTest, IterationIncrementAndGroupDeps) { EXPECT_EQ(clusters["some_ctrl_input"], clusters["matmul_0"]); } +// Test a pattern where a special Identity node is driving consts in a loop. +// Expect that the Identity node will not go into any clusters. Note that we +// create an incomplete graph here (e.g., lacking Enter/Exit/NextIteration, +// etc.) just enough to test the pattern, as a complete graph may be too +// cumbersome and unnecessary. 
TEST(XlaCompilationTest, DontClusterTheSpecialIdentityDrivingConstsInLoop) { Scope root = Scope::NewRootScope().ExitOnError(); @@ -1685,8 +1690,6 @@ TEST(XlaCompilationTest, DontClusterTheSpecialIdentityDrivingConstsInLoop) { Output loop_cond = ops::LoopCond(root.WithOpName("loop_cond"), cond); ops::Switch switch_node(root.WithOpName("switch"), value, loop_cond); - // The special identity driving consts. Expect that is is not in any - // clusters. Output identity = ops::Identity(root.WithOpName("identity"), switch_node.output_true); Output const_node = ops::Const(root.WithOpName("const"), 1.0f);