Do not DCE while bodies which have IO operations.
PiperOrigin-RevId: 212750173
This commit is contained in:
parent
f03e8e0b9b
commit
f4d8442e13
@ -36,23 +36,6 @@ namespace xla {
|
||||
|
||||
namespace {
|
||||
|
||||
bool HasSendRecv(HloComputation* computation) {
|
||||
for (auto* instruction : computation->instructions()) {
|
||||
if (instruction->opcode() == HloOpcode::kSend ||
|
||||
instruction->opcode() == HloOpcode::kSendDone ||
|
||||
instruction->opcode() == HloOpcode::kRecv ||
|
||||
instruction->opcode() == HloOpcode::kRecvDone) {
|
||||
return true;
|
||||
}
|
||||
for (auto* sub_computation : instruction->called_computations()) {
|
||||
if (HasSendRecv(sub_computation)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
|
||||
bool changed = false;
|
||||
for (auto* computation : module->computations()) {
|
||||
@ -68,9 +51,10 @@ StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
|
||||
|
||||
if (!ShapeUtil::IsTuple(xla_while->shape()) ||
|
||||
while_body_root->opcode() != HloOpcode::kTuple ||
|
||||
HasSendRecv(while_body_comp)) {
|
||||
while_body_comp->HasSideEffect() ||
|
||||
xla_while->while_condition()->HasSideEffect()) {
|
||||
// Only run DCE on tuple-shaped while loops where body root is Tuple,
|
||||
// with no send/recv instructions.
|
||||
// with no I/O instructions.
|
||||
VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString();
|
||||
continue;
|
||||
}
|
||||
|
@ -367,5 +367,39 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
|
||||
"while.2", 1));
|
||||
}
|
||||
|
||||
// Tests that a while whose body has outfeed operations is not DCE-ed.
|
||||
TEST_F(HloModuleDceTest, WhileWithOutfeed) {
|
||||
auto module = ParseHloString(R"(
|
||||
HloModule OutfeedLoop
|
||||
WhileBody {
|
||||
loop_var.1 = (s32[]) parameter(0)
|
||||
token = token[] after-all()
|
||||
constant.2 = s32[] constant(2)
|
||||
outfeed_tuple = (s32[]) outfeed(constant.2, token)
|
||||
get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
|
||||
constant.1 = s32[] constant(1)
|
||||
add = s32[] add(get-tuple-element.1, constant.1)
|
||||
ROOT tuple = (s32[]) tuple(add)
|
||||
}
|
||||
WhileCondition {
|
||||
loop_var.2 = (s32[]) parameter(0)
|
||||
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
||||
constant.2 = s32[] constant(10)
|
||||
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
|
||||
}
|
||||
ENTRY SimpleLoop {
|
||||
constant.3 = s32[] constant(0)
|
||||
tuple.1 = (s32[]) tuple(constant.3)
|
||||
ROOT while = (s32[]) while(tuple.1), condition=WhileCondition,
|
||||
body=WhileBody
|
||||
})")
|
||||
.ValueOrDie();
|
||||
|
||||
HloModuleDCE dce;
|
||||
EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
|
||||
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
|
||||
"while", 0));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user