Do not DCE while bodies which have IO operations.

PiperOrigin-RevId: 212750173
This commit is contained in:
A. Unique TensorFlower 2018-09-12 20:32:37 -07:00 committed by TensorFlower Gardener
parent f03e8e0b9b
commit f4d8442e13
2 changed files with 37 additions and 19 deletions

View File

@ -36,23 +36,6 @@ namespace xla {
namespace {
bool HasSendRecv(HloComputation* computation) {
for (auto* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kSend ||
instruction->opcode() == HloOpcode::kSendDone ||
instruction->opcode() == HloOpcode::kRecv ||
instruction->opcode() == HloOpcode::kRecvDone) {
return true;
}
for (auto* sub_computation : instruction->called_computations()) {
if (HasSendRecv(sub_computation)) {
return true;
}
}
}
return false;
}
StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
bool changed = false;
for (auto* computation : module->computations()) {
@ -68,9 +51,10 @@ StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
if (!ShapeUtil::IsTuple(xla_while->shape()) ||
while_body_root->opcode() != HloOpcode::kTuple ||
HasSendRecv(while_body_comp)) {
while_body_comp->HasSideEffect() ||
xla_while->while_condition()->HasSideEffect()) {
// Only run DCE on tuple-shaped while loops where body root is Tuple,
// with no send/recv instructions.
// with no I/O instructions.
VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString();
continue;
}

View File

@ -367,5 +367,39 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
"while.2", 1));
}
// Tests that a while whose body has outfeed operations is not DCE-ed.
TEST_F(HloModuleDceTest, WhileWithOutfeed) {
auto module = ParseHloString(R"(
HloModule OutfeedLoop
WhileBody {
loop_var.1 = (s32[]) parameter(0)
token = token[] after-all()
constant.2 = s32[] constant(2)
outfeed_tuple = (s32[]) outfeed(constant.2, token)
get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
constant.1 = s32[] constant(1)
add = s32[] add(get-tuple-element.1, constant.1)
ROOT tuple = (s32[]) tuple(add)
}
WhileCondition {
loop_var.2 = (s32[]) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
constant.2 = s32[] constant(10)
ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
}
ENTRY SimpleLoop {
constant.3 = s32[] constant(0)
tuple.1 = (s32[]) tuple(constant.3)
ROOT while = (s32[]) while(tuple.1), condition=WhileCondition,
body=WhileBody
})")
.ValueOrDie();
HloModuleDCE dce;
EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
"while", 0));
}
} // namespace
} // namespace xla