[XLA] Sink side effect check in while loop simplifier only when it's needed.
Side effect check is expensive (Should probably fix that later). PiperOrigin-RevId: 306965351 Change-Id: I2a8ed8e7f90dbddd25a8ad066d2b88e6ad0302f8
This commit is contained in:
parent
b43b1dc688
commit
12b485762a
@ -496,14 +496,43 @@ static StatusOr<bool> TryRemoveWhileLoop(HloInstruction* while_op) {
|
||||
// Transform while loops with static trip count of 1 into a call op, then
|
||||
// inline the call.
|
||||
if (trip_count && *trip_count == 1) {
|
||||
auto computation = while_op->parent();
|
||||
auto call_op = computation->AddInstruction(HloInstruction::CreateCall(
|
||||
while_op->shape(), while_op->operands(), while_op->while_body()));
|
||||
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op));
|
||||
TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
|
||||
CallInliner::Inline(call_op));
|
||||
(void)inlined_instructions_map;
|
||||
return true;
|
||||
// Do not simplify the loop away when there is a side-effectful op,
|
||||
// otherwise the infeed op may not inherit the data dependency from
|
||||
// the while loop.
|
||||
//
|
||||
// Example: while_body (param_a) {
|
||||
// param_a = parameter(0)
|
||||
// infeed2 = infeed()
|
||||
// }
|
||||
//
|
||||
// infeed1 = ...
|
||||
// while = while(infeed1), body=while_body // infeed2 has implicit
|
||||
// dependency on infeed1.
|
||||
//
|
||||
// After simplification:
|
||||
//
|
||||
// infeed1 = ...
|
||||
// infeed2 = infeed() // no dependency between infeed1 and infeed2. infeed1
|
||||
// // can be scheduled after infeed2.
|
||||
//
|
||||
bool has_side_effects = absl::c_any_of(
|
||||
while_op->called_computations(), [](const HloComputation* computation) {
|
||||
return computation->HasSideEffect();
|
||||
});
|
||||
if (!has_side_effects) {
|
||||
auto computation = while_op->parent();
|
||||
auto call_op = computation->AddInstruction(HloInstruction::CreateCall(
|
||||
while_op->shape(), while_op->operands(), while_op->while_body()));
|
||||
TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op));
|
||||
TF_ASSIGN_OR_RETURN(auto inlined_instructions_map,
|
||||
CallInliner::Inline(call_op));
|
||||
(void)inlined_instructions_map;
|
||||
return true;
|
||||
} else {
|
||||
VLOG(2) << "Not attempting to simplify while loop because it contains a "
|
||||
"side-effecting node: "
|
||||
<< while_op->ToShortString();
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@ -1014,35 +1043,6 @@ StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Do not simplify the loop away when there is a side-effectful op,
|
||||
// otherwise the infeed op may not inherit the data dependency from
|
||||
// the while loop.
|
||||
//
|
||||
// Example: while_body (param_a) {
|
||||
// param_a = parameter(0)
|
||||
// infeed2 = infeed()
|
||||
// }
|
||||
//
|
||||
// infeed1 = ...
|
||||
// while = while(infeed1), body=while_body // infeed2 has implicit
|
||||
// dependency on infeed1.
|
||||
//
|
||||
// After simplification:
|
||||
//
|
||||
// infeed1 = ...
|
||||
// infeed2 = infeed() // no dependency between infeed1 and infeed2. infeed1
|
||||
// // can be scheduled after infeed2.
|
||||
//
|
||||
bool has_side_effects = absl::c_any_of(
|
||||
while_op->called_computations(), [](const HloComputation* computation) {
|
||||
return computation->HasSideEffect();
|
||||
});
|
||||
if (has_side_effects) {
|
||||
VLOG(2) << "Not attempting to simplify while loop because it contains a "
|
||||
"side-effecting node: "
|
||||
<< while_op->ToShortString();
|
||||
continue;
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(bool result, TryPropagateConstant(while_op));
|
||||
changed |= result;
|
||||
|
||||
|
||||
@ -444,6 +444,47 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) {
|
||||
op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1)));
|
||||
}
|
||||
|
||||
// Check that we can remove unused loop operands even if the loop contains a
|
||||
// side-effecting instruction.
|
||||
TEST_F(WhileLoopSimplifierTest,
|
||||
RemoveUnusedLoopOperandsDespiteSideEffectingOps) {
|
||||
const string hlo_string = R"(
|
||||
HloModule RemoveUnusedOperands
|
||||
body {
|
||||
loop_var = (s32[]) parameter(0)
|
||||
gte0 = s32[] get-tuple-element(loop_var), index=0
|
||||
token0 = token[] after-all()
|
||||
unused = ((s32[], pred[]), token[]) infeed(token0)
|
||||
ROOT tuple = (s32[]) tuple(gte0)
|
||||
}
|
||||
cond {
|
||||
loop_var = (s32[]) parameter(0)
|
||||
ROOT constant = pred[] constant(true)
|
||||
}
|
||||
ENTRY RemoveUnusedOperands {
|
||||
x = s32[] parameter(0)
|
||||
tuple.1 = (s32[]) tuple(s32[] x)
|
||||
ROOT while = (s32[]) while((s32[]) tuple.1),
|
||||
condition=cond, body=body
|
||||
}
|
||||
)";
|
||||
|
||||
auto m = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||
EXPECT_TRUE(WhileLoopSimplifier().Run(m.get()).ValueOrDie());
|
||||
|
||||
// The original while instruction is still left in the module as a dead
|
||||
// instruction, find a while instruction with a different name as the new
|
||||
// while instruction.
|
||||
const auto& instrs = m->entry_computation()->instructions();
|
||||
HloInstruction* new_while_op =
|
||||
*absl::c_find_if(instrs, [&](const HloInstruction* instr) {
|
||||
return (instr->opcode() == HloOpcode::kWhile &&
|
||||
instr->name() != "while");
|
||||
});
|
||||
EXPECT_TRUE(ShapeUtil::IsEmptyTuple(new_while_op->shape()))
|
||||
<< new_while_op->shape().ToString();
|
||||
}
|
||||
|
||||
TEST_F(WhileLoopSimplifierTest, LoopWithNonTupleBodyShapeNotSimplified) {
|
||||
const string hlo_string = R"(
|
||||
HloModule BodyHasNonTupleRoot
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user