[XLA] Sink side effect check in while loop simplifier only when it's needed.

Side effect check is expensive (Should probably fix that later).

PiperOrigin-RevId: 306965351
Change-Id: I2a8ed8e7f90dbddd25a8ad066d2b88e6ad0302f8
This commit is contained in:
Yunxing Dai 2020-04-16 18:53:24 -07:00 committed by TensorFlower Gardener
parent b43b1dc688
commit 12b485762a
2 changed files with 78 additions and 37 deletions

View File

@ -496,14 +496,43 @@ static StatusOr<bool> TryRemoveWhileLoop(HloInstruction* while_op) {
// Transform while loops with static trip count of 1 into a call op, then
// inline the call.
if (trip_count && *trip_count == 1) {
auto computation = while_op->parent();
auto call_op = computation->AddInstruction(HloInstruction::CreateCall(
while_op->shape(), while_op->operands(), while_op->while_body()));
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op));
TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
CallInliner::Inline(call_op));
(void)inlined_instructions_map;
return true;
// Do not simplify the loop away when there is a side-effectful op,
// otherwise the infeed op may not inherit the data dependency from
// the while loop.
//
// Example: while_body (param_a) {
// param_a = parameter(0)
// infeed2 = infeed()
// }
//
// infeed1 = ...
// while = while(infeed1), body=while_body // infeed2 has implicit
// dependency on infeed1.
//
// After simplification:
//
// infeed1 = ...
// infeed2 = infeed() // no dependency between infeed1 and infeed2. infeed1
// // can be scheduled after infeed2.
//
bool has_side_effects = absl::c_any_of(
while_op->called_computations(), [](const HloComputation* computation) {
return computation->HasSideEffect();
});
if (!has_side_effects) {
auto computation = while_op->parent();
auto call_op = computation->AddInstruction(HloInstruction::CreateCall(
while_op->shape(), while_op->operands(), while_op->while_body()));
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op));
TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
CallInliner::Inline(call_op));
(void)inlined_instructions_map;
return true;
} else {
VLOG(2) << "Not attempting to simplify while loop because it contains a "
"side-effecting node: "
<< while_op->ToShortString();
}
}
return false;
}
@ -1014,35 +1043,6 @@ StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
continue;
}
// Do not simplify the loop away when there is a side-effectful op,
// otherwise the infeed op may not inherit the data dependency from
// the while loop.
//
// Example: while_body (param_a) {
// param_a = parameter(0)
// infeed2 = infeed()
// }
//
// infeed1 = ...
// while = while(infeed1), body=while_body // infeed2 has implicit
// dependency on infeed1.
//
// After simplification:
//
// infeed1 = ...
// infeed2 = infeed() // no dependency between infeed1 and infeed2. infeed1
// // can be scheduled after infeed2.
//
bool has_side_effects = absl::c_any_of(
while_op->called_computations(), [](const HloComputation* computation) {
return computation->HasSideEffect();
});
if (has_side_effects) {
VLOG(2) << "Not attempting to simplify while loop because it contains a "
"side-effecting node: "
<< while_op->ToShortString();
continue;
}
TF_ASSIGN_OR_RETURN(bool result, TryPropagateConstant(while_op));
changed |= result;

View File

@ -444,6 +444,47 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) {
op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1)));
}
// Check that we can remove unused loop operands even if the loop contains a
// side-effecting instruction.
TEST_F(WhileLoopSimplifierTest,
RemoveUnusedLoopOperandsDespiteSideEffectingOps) {
const string hlo_string = R"(
HloModule RemoveUnusedOperands
body {
loop_var = (s32[]) parameter(0)
gte0 = s32[] get-tuple-element(loop_var), index=0
token0 = token[] after-all()
unused = ((s32[], pred[]), token[]) infeed(token0)
ROOT tuple = (s32[]) tuple(gte0)
}
cond {
loop_var = (s32[]) parameter(0)
ROOT constant = pred[] constant(true)
}
ENTRY RemoveUnusedOperands {
x = s32[] parameter(0)
tuple.1 = (s32[]) tuple(s32[] x)
ROOT while = (s32[]) while((s32[]) tuple.1),
condition=cond, body=body
}
)";
auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
// The original while instruction is still left in the module as a dead
// instruction, find a while instruction with a different name as the new
// while instruction.
const auto& instrs = m->entry_computation()->instructions();
HloInstruction* new_while_op =
*absl::c_find_if(instrs, [&](const HloInstruction* instr) {
return (instr->opcode() == HloOpcode::kWhile &&
instr->name() != "while");
});
EXPECT_TRUE(ShapeUtil::IsEmptyTuple(new_while_op->shape()))
<< new_while_op->shape().ToString();
}
TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) {
const string hlo_string = R"(
HloModule BodyHasNonTupleRoot