Merge pull request #29470 from trentlo:loop-cond-dependency

PiperOrigin-RevId: 252866329
This commit is contained in:
TensorFlower Gardener 2019-06-12 12:41:47 -07:00
commit e7badb16d0
3 changed files with 178 additions and 94 deletions

View File

@ -1014,6 +1014,39 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
return Status::OK();
}
// Reports whether `node` matches the "Identity driving consts in a loop"
// pattern. The answer is true only when all of the following hold:
//   * `node` is an Identity;
//   * one of its incoming edges comes from output 1 (the true path) of a
//     Switch;
//   * input 1 of that Switch is a LoopCond;
//   * `node` has at least one outgoing control edge whose destination is a
//     constant node.
// If looking up input 1 of the Switch fails, that error status is returned
// instead of a boolean.
StatusOr<bool> IsIdentityDrivingConstsInLoop(Node* node) {
  if (!node->IsIdentity()) {
    return false;
  }

  // Find the first incoming edge that originates from the true output
  // (output index 1) of a Switch.
  const Node* driving_switch = nullptr;
  for (const Edge* in_edge : node->in_edges()) {
    if (in_edge->src()->IsSwitch() && in_edge->src_output() == 1) {
      driving_switch = in_edge->src();
      break;
    }
  }
  if (driving_switch == nullptr) {
    return false;
  }

  // The Switch must take a LoopCond on its input 1.
  const Node* switch_predicate = nullptr;
  TF_RETURN_IF_ERROR(driving_switch->input_node(1, &switch_predicate));
  if (!switch_predicate->IsLoopCond()) {
    return false;
  }

  // The Identity must drive at least one const node through a control edge.
  for (const Edge* out_edge : node->out_edges()) {
    if (out_edge->dst()->IsConstant() && out_edge->IsControlEdge()) {
      return true;
    }
  }
  return false;
}
Status MarkForCompilationPassImpl::FindCompilationCandidates() {
OptimizerOptions opts;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
@ -1135,6 +1168,35 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
}
}
// This is a heuristic to avoid creating dependency between while loop
// condition and body computations. Dependency between them can be created
// if a special Identity node in the following pattern is clustered in.
// That is, an Identity node in the loop cond computation is used to drive
// const nodes consumed by the loop body. If this Identity node goes into
// the same cluster with nodes from the loop body, extra dependency is
// created between the loop cond and body computations and it hinders the
// progression of the loop cond computation at runtime with significant
// overhead. Specifically, we look for the below pattern and do not cluster
// in this Identity to avoid the described issue. Since Identity has low
// execution cost in native TF, the fact that this heuristic gives up these
// special Identity nodes as candidates should not harm any performance. If
// other considerations emerge in the future, we can revisit the heuristic
// and only disallow these Identities to go into the cluster with nodes from
// the loop body but still consider them candidates.
//
//             LoopCond ->
// Merge -> Switch -> Identity -> i++ -> ... -> NextIteration
//                       ..> Const -> LoopBody
//                     (control edge)
TF_ASSIGN_OR_RETURN(bool is_identity_driving_consts_in_loop,
IsIdentityDrivingConstsInLoop(node));
if (is_identity_driving_consts_in_loop) {
VLOG(2) << "Rejecting " << node->name()
<< ": including it can create dependencies between while loop "
"condition and body computations with runtime overhead.";
continue;
}
compilation_candidates_.insert(node);
--(*debug_options_.fuel);
}

View File

@ -1677,5 +1677,37 @@ TEST(XlaCompilationTest, IterationIncrementAndGroupDeps) {
EXPECT_EQ(clusters["some_ctrl_input"], clusters["matmul_0"]);
}
// Exercises the pattern in which a special Identity node drives consts in a
// loop, and expects that Identity to stay out of every cluster. The graph
// built here is deliberately incomplete (no Enter/Exit/NextIteration, etc.):
// it contains just enough structure to trigger the pattern, since building a
// full while loop would be cumbersome and unnecessary.
TEST(XlaCompilationTest, DontClusterTheSpecialIdentityDrivingConstsInLoop) {
  Scope scope = Scope::NewRootScope().ExitOnError();

  // LoopCond -> Switch -> Identity (taken from the Switch's true output).
  Output cond_input = ops::Placeholder(scope.WithOpName("cond"), DT_BOOL);
  Output value_input = ops::Placeholder(scope.WithOpName("value"), DT_FLOAT);
  Output loop_cond_op =
      ops::LoopCond(scope.WithOpName("loop_cond"), cond_input);
  ops::Switch switch_op(scope.WithOpName("switch"), value_input, loop_cond_op);
  Output identity_op =
      ops::Identity(scope.WithOpName("identity"), switch_op.output_true);

  // The Identity drives a Const through a control edge.
  Output const_op = ops::Const(scope.WithOpName("const"), 1.0f);
  scope.graph()->AddControlEdge(identity_op.node(), const_op.node());

  // Ops that consume the Const.
  Output first_tanh = ops::Tanh(scope.WithOpName("tanh0"), const_op);
  Output second_tanh = ops::Tanh(scope.WithOpName("tanh1"), first_tanh);
  Output sum = ops::Add(scope.WithOpName("add"), const_op, second_tanh);

  std::unique_ptr<Graph> built_graph(new Graph(OpRegistry::Global()));
  TF_EXPECT_OK(scope.ToGraph(built_graph.get()));

  TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
      &built_graph,
      MarkForCompilationPassTestHelper::Options().WithDeadnessAnalysis()));

  // An empty cluster name means the Identity was left unclustered.
  auto cluster_by_node = GetClusters(*built_graph);
  EXPECT_EQ(cluster_by_node["identity"], "");
}
} // namespace
} // namespace tensorflow

View File

@ -1,8 +1,8 @@
Clustered nodes: 1988
Unclustered nodes: 3960
Clustered nodes: 1962
Unclustered nodes: 3974
Number of clusters: 29
unclustered size 3960
unclustered size 3974
Add 17
AddN 1
ApplyAdam 38
@ -14,7 +14,7 @@ unclustered size 3960
Cast 8
ConcatOffset 10
ConcatV2 2
Const 704
Const 708
ControlTrigger 5
DynamicStitch 1
Enter 874
@ -24,7 +24,7 @@ unclustered size 3960
FloorDiv 1
FloorMod 1
GreaterEqual 7
Identity 105
Identity 113
IsVariableInitialized 1
IteratorGetNext 1
IteratorV2 1
@ -42,7 +42,7 @@ unclustered size 3960
RefSwitch 166
Reshape 2
ScatterAdd 4
Shape 4
Shape 6
ShapeN 10
Size 2
Snapshot 1
@ -169,15 +169,13 @@ cluster 7 size 11
Mul 2
Pow 1
Sub 1
cluster 10 size 14
Add 2
cluster 10 size 8
Add 1
All 2
Const 4
Const 2
GreaterEqual 1
Identity 1
Less 1
LogicalOr 1
Shape 2
cluster 11 size 226
Add 24
BatchMatMulV2 1
@ -226,13 +224,12 @@ cluster 12 size 430
TanhGrad 17
Tile 2
ZerosLike 1
cluster 13 size 25
Add 3
cluster 13 size 20
Add 2
BiasAdd 1
ConcatV2 1
Const 3
Const 1
GreaterEqual 1
Identity 2
MatMul 1
Mul 3
Select 3
@ -256,13 +253,12 @@ cluster 14 size 52
Slice 2
Sum 9
TanhGrad 2
cluster 15 size 25
Add 3
cluster 15 size 20
Add 2
BiasAdd 1
ConcatV2 1
Const 3
Const 1
GreaterEqual 1
Identity 2
MatMul 1
Mul 3
Select 3
@ -290,14 +286,13 @@ cluster 17 size 52
Slice 2
Sum 9
TanhGrad 2
cluster 19 size 30
Add 3
cluster 19 size 25
Add 2
BiasAdd 1
Cast 1
ConcatV2 1
Const 5
Const 3
GreaterEqual 2
Identity 2
MatMul 1
Mul 5
Select 2
@ -305,77 +300,7 @@ cluster 19 size 30
Snapshot 1
Split 1
Tanh 2
cluster 20 size 23
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
Identity 1
MatMul 1
Mul 5
Select 2
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 21 size 23
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
Identity 1
MatMul 1
Mul 5
Select 2
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 22 size 23
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
Identity 1
MatMul 1
Mul 5
Select 2
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 23 size 23
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
Identity 1
MatMul 1
Mul 5
Select 2
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 24 size 24
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
Identity 1
MatMul 1
Mul 5
Select 3
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 25 size 363
cluster 20 size 363
Add 12
AddN 28
BiasAddGrad 6
@ -391,6 +316,71 @@ cluster 25 size 363
Slice 12
Sum 76
TanhGrad 12
cluster 21 size 22
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
MatMul 1
Mul 5
Select 2
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 22 size 22
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
MatMul 1
Mul 5
Select 2
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 23 size 22
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
MatMul 1
Mul 5
Select 2
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 24 size 22
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
MatMul 1
Mul 5
Select 2
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 25 size 23
Add 3
BiasAdd 1
Cast 1
ConcatV2 1
GreaterEqual 1
MatMul 1
Mul 5
Select 3
Sigmoid 3
Snapshot 1
Split 1
Tanh 2
cluster 26 size 9
AddN 1
MatMul 2