Merge pull request #29470 from trentlo:loop-cond-dependency
PiperOrigin-RevId: 252866329
This commit is contained in:
commit
e7badb16d0
@ -1014,6 +1014,39 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<bool> IsIdentityDrivingConstsInLoop(Node* node) {
|
||||
if (!node->IsIdentity()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if the Identity is driven by a Switch on its true path.
|
||||
auto it = absl::c_find_if(node->in_edges(), [](const Edge* e) {
|
||||
return e->src()->IsSwitch() && e->src_output() == 1;
|
||||
});
|
||||
if (it == node->in_edges().end()) {
|
||||
return false;
|
||||
}
|
||||
const Node* switch_node = (*it)->src();
|
||||
|
||||
// Check if the Switch is driven by LoopCond.
|
||||
const Node* maybe_loop_cond;
|
||||
TF_RETURN_IF_ERROR(switch_node->input_node(1, &maybe_loop_cond));
|
||||
if (!maybe_loop_cond->IsLoopCond()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if the Identity is driving any const nodes through a control edge.
|
||||
bool driving_any_consts =
|
||||
absl::c_any_of(node->out_edges(), [](const Edge* e) {
|
||||
return e->dst()->IsConstant() && e->IsControlEdge();
|
||||
});
|
||||
if (!driving_any_consts) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
||||
OptimizerOptions opts;
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
|
||||
@ -1135,6 +1168,35 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
||||
}
|
||||
}
|
||||
|
||||
// This is a heuristic to avoid creating dependency between while loop
|
||||
// condition and body computations. Dependency between them can be created
|
||||
// if a special Identity node in the following pattern is clustered in.
|
||||
// That is, an Identity node in the loop cond computation is used to drive
|
||||
// const nodes consumed by the loop body. If this Identity node goes into
|
||||
// the same cluster with nodes from the loop body, extra dependency is
|
||||
// created between the loop cond and body computations and it hinders the
|
||||
// progression of the loop cond computation at runtime with significant
|
||||
// overhead. Specifically, we look for the below pattern and do not cluster
|
||||
// in this Identity to avoid the described issue. Since Identity has low
|
||||
// execution cost in native TF, the fact that this heuristic gives up these
|
||||
// special Identity nodes as candidates should not harm any performance. If
|
||||
// other considerations emerge in the future, we can revisit the heuristic
|
||||
// and only disallow these Identities to go into the cluster with nodes from
|
||||
// the loop body but still consider them candidates.
|
||||
//
|
||||
// LoopCond ->
|
||||
// Merge -> Switch -> Identity -> i++ -> ... -> NextIteration
|
||||
// ..> Const -> LoopBody
|
||||
// (control edge)
|
||||
TF_ASSIGN_OR_RETURN(bool is_identity_driving_consts_in_loop,
|
||||
IsIdentityDrivingConstsInLoop(node));
|
||||
if (is_identity_driving_consts_in_loop) {
|
||||
VLOG(2) << "Rejecting " << node->name()
|
||||
<< ": including it can create dependencies between while loop "
|
||||
"condition and body computations with runtime overhead.";
|
||||
continue;
|
||||
}
|
||||
|
||||
compilation_candidates_.insert(node);
|
||||
--(*debug_options_.fuel);
|
||||
}
|
||||
|
@ -1677,5 +1677,37 @@ TEST(XlaCompilationTest, IterationIncrementAndGroupDeps) {
|
||||
EXPECT_EQ(clusters["some_ctrl_input"], clusters["matmul_0"]);
|
||||
}
|
||||
|
||||
// Test a pattern where a special Identity node is driving consts in a loop.
|
||||
// Expect that the Identity node will not go into any clusters. Note that we
|
||||
// create an incomplete graph here (e.g., lacking Enter/Exit/NextIteration,
|
||||
// etc.) just enough to test the pattern, as a complete graph may be too
|
||||
// cumbersome and unnecessary.
|
||||
TEST(XlaCompilationTest, DontClusterTheSpecialIdentityDrivingConstsInLoop) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
Output cond = ops::Placeholder(root.WithOpName("cond"), DT_BOOL);
|
||||
Output value = ops::Placeholder(root.WithOpName("value"), DT_FLOAT);
|
||||
Output loop_cond = ops::LoopCond(root.WithOpName("loop_cond"), cond);
|
||||
ops::Switch switch_node(root.WithOpName("switch"), value, loop_cond);
|
||||
|
||||
Output identity =
|
||||
ops::Identity(root.WithOpName("identity"), switch_node.output_true);
|
||||
Output const_node = ops::Const(root.WithOpName("const"), 1.0f);
|
||||
root.graph()->AddControlEdge(identity.node(), const_node.node());
|
||||
Output tanh0 = ops::Tanh(root.WithOpName("tanh0"), const_node);
|
||||
Output tanh1 = ops::Tanh(root.WithOpName("tanh1"), tanh0);
|
||||
Output add = ops::Add(root.WithOpName("add"), const_node, tanh1);
|
||||
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_EXPECT_OK(root.ToGraph(graph.get()));
|
||||
|
||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(
|
||||
&graph,
|
||||
MarkForCompilationPassTestHelper::Options().WithDeadnessAnalysis()));
|
||||
auto clusters = GetClusters(*graph);
|
||||
|
||||
EXPECT_EQ(clusters["identity"], "");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -1,8 +1,8 @@
|
||||
Clustered nodes: 1988
|
||||
Unclustered nodes: 3960
|
||||
Clustered nodes: 1962
|
||||
Unclustered nodes: 3974
|
||||
Number of clusters: 29
|
||||
|
||||
unclustered size 3960
|
||||
unclustered size 3974
|
||||
Add 17
|
||||
AddN 1
|
||||
ApplyAdam 38
|
||||
@ -14,7 +14,7 @@ unclustered size 3960
|
||||
Cast 8
|
||||
ConcatOffset 10
|
||||
ConcatV2 2
|
||||
Const 704
|
||||
Const 708
|
||||
ControlTrigger 5
|
||||
DynamicStitch 1
|
||||
Enter 874
|
||||
@ -24,7 +24,7 @@ unclustered size 3960
|
||||
FloorDiv 1
|
||||
FloorMod 1
|
||||
GreaterEqual 7
|
||||
Identity 105
|
||||
Identity 113
|
||||
IsVariableInitialized 1
|
||||
IteratorGetNext 1
|
||||
IteratorV2 1
|
||||
@ -42,7 +42,7 @@ unclustered size 3960
|
||||
RefSwitch 166
|
||||
Reshape 2
|
||||
ScatterAdd 4
|
||||
Shape 4
|
||||
Shape 6
|
||||
ShapeN 10
|
||||
Size 2
|
||||
Snapshot 1
|
||||
@ -169,15 +169,13 @@ cluster 7 size 11
|
||||
Mul 2
|
||||
Pow 1
|
||||
Sub 1
|
||||
cluster 10 size 14
|
||||
Add 2
|
||||
cluster 10 size 8
|
||||
Add 1
|
||||
All 2
|
||||
Const 4
|
||||
Const 2
|
||||
GreaterEqual 1
|
||||
Identity 1
|
||||
Less 1
|
||||
LogicalOr 1
|
||||
Shape 2
|
||||
cluster 11 size 226
|
||||
Add 24
|
||||
BatchMatMulV2 1
|
||||
@ -226,13 +224,12 @@ cluster 12 size 430
|
||||
TanhGrad 17
|
||||
Tile 2
|
||||
ZerosLike 1
|
||||
cluster 13 size 25
|
||||
Add 3
|
||||
cluster 13 size 20
|
||||
Add 2
|
||||
BiasAdd 1
|
||||
ConcatV2 1
|
||||
Const 3
|
||||
Const 1
|
||||
GreaterEqual 1
|
||||
Identity 2
|
||||
MatMul 1
|
||||
Mul 3
|
||||
Select 3
|
||||
@ -256,13 +253,12 @@ cluster 14 size 52
|
||||
Slice 2
|
||||
Sum 9
|
||||
TanhGrad 2
|
||||
cluster 15 size 25
|
||||
Add 3
|
||||
cluster 15 size 20
|
||||
Add 2
|
||||
BiasAdd 1
|
||||
ConcatV2 1
|
||||
Const 3
|
||||
Const 1
|
||||
GreaterEqual 1
|
||||
Identity 2
|
||||
MatMul 1
|
||||
Mul 3
|
||||
Select 3
|
||||
@ -290,14 +286,13 @@ cluster 17 size 52
|
||||
Slice 2
|
||||
Sum 9
|
||||
TanhGrad 2
|
||||
cluster 19 size 30
|
||||
Add 3
|
||||
cluster 19 size 25
|
||||
Add 2
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
Const 5
|
||||
Const 3
|
||||
GreaterEqual 2
|
||||
Identity 2
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
@ -305,77 +300,7 @@ cluster 19 size 30
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 20 size 23
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
Identity 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 21 size 23
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
Identity 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 22 size 23
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
Identity 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 23 size 23
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
Identity 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 24 size 24
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
Identity 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 3
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 25 size 363
|
||||
cluster 20 size 363
|
||||
Add 12
|
||||
AddN 28
|
||||
BiasAddGrad 6
|
||||
@ -391,6 +316,71 @@ cluster 25 size 363
|
||||
Slice 12
|
||||
Sum 76
|
||||
TanhGrad 12
|
||||
cluster 21 size 22
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 22 size 22
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 23 size 22
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 24 size 22
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 2
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 25 size 23
|
||||
Add 3
|
||||
BiasAdd 1
|
||||
Cast 1
|
||||
ConcatV2 1
|
||||
GreaterEqual 1
|
||||
MatMul 1
|
||||
Mul 5
|
||||
Select 3
|
||||
Sigmoid 3
|
||||
Snapshot 1
|
||||
Split 1
|
||||
Tanh 2
|
||||
cluster 26 size 9
|
||||
AddN 1
|
||||
MatMul 2
|
||||
|
Loading…
Reference in New Issue
Block a user