[XLA] Conditional simplifier: replace root with empty tuple if no users.

PiperOrigin-RevId: 261237215
This commit is contained in:
Yuanzhong Xu 2019-08-01 18:05:54 -07:00 committed by TensorFlower Gardener
parent f6182b8594
commit 856804b30b
2 changed files with 69 additions and 0 deletions

View File

@ -253,6 +253,31 @@ StatusOr<bool> TryRemoveUnusedConditionalOperands(
} }
return true; return true;
} }
// Replaces the roots of all branches with an empty tuple if the conditional op
// has no users. Returns if anything is changed.
bool ReplaceRootWithEmptyTupleIfNoUsers(HloInstruction* conditional_op) {
const Shape empty_tuple = ShapeUtil::MakeTupleShape({});
if (conditional_op->user_count() == 0 &&
conditional_op != conditional_op->parent()->root_instruction() &&
!ShapeUtil::Compatible(empty_tuple, conditional_op->shape())) {
for (int64 branch_id = 0; branch_id < conditional_op->branch_count();
++branch_id) {
auto branch_computation =
conditional_op->GetModule()->AddEmbeddedComputation(
conditional_op->branch_computation(branch_id)->Clone());
conditional_op->set_branch_computation(branch_id, branch_computation);
auto new_empty_root =
branch_computation->AddInstruction(HloInstruction::CreateTuple({}));
branch_computation->set_root_instruction(new_empty_root,
/*accept_different_shape=*/true);
}
*conditional_op->mutable_shape() = empty_tuple;
return true;
}
return false;
}
} // namespace } // namespace
StatusOr<bool> ConditionalSimplifier::Run(HloModule* module) { StatusOr<bool> ConditionalSimplifier::Run(HloModule* module) {
@ -274,6 +299,7 @@ StatusOr<bool> ConditionalSimplifier::Run(HloModule* module) {
std::map<HloComputation*, std::set<int64>> changed_computations; std::map<HloComputation*, std::set<int64>> changed_computations;
for (HloInstruction* conditional_op : conditional_ops) { for (HloInstruction* conditional_op : conditional_ops) {
changed |= ReplaceRootWithEmptyTupleIfNoUsers(conditional_op);
TF_ASSIGN_OR_RETURN(bool result, TryRemoveConditional(conditional_op)); TF_ASSIGN_OR_RETURN(bool result, TryRemoveConditional(conditional_op));
if (!result) { if (!result) {
TF_ASSIGN_OR_RETURN(result, TryRemoveUnusedConditionalOperands( TF_ASSIGN_OR_RETURN(result, TryRemoveUnusedConditionalOperands(

View File

@ -285,6 +285,49 @@ TEST_F(ConditionalSimplifierTest,
EXPECT_TRUE(ConditionalSimplifier().Run(module.get()).ValueOrDie()); EXPECT_TRUE(ConditionalSimplifier().Run(module.get()).ValueOrDie());
} }
TEST_F(ConditionalSimplifierTest, RemoveDeadRoots) {
absl::string_view hlo_string =
R"(
HloModule RemoveDeadRoots
on_false {
t = (f32[20,40], f32[40,40]) parameter(0)
lhs = f32[20,40] get-tuple-element(t), index=0
rhs = f32[40,40] get-tuple-element(t), index=1
dot = f32[20,40] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
after-all = token[] after-all()
outfeed = token[] outfeed(dot, after-all)
ROOT result = (f32[20,40]) tuple(dot)
}
on_true {
t = (f32[20,40], f32[40,40]) parameter(0)
lhs = f32[20,40] get-tuple-element(t), index=0
add = f32[20,40] add(lhs, lhs)
ROOT result = (f32[20,40]) tuple(add)
}
ENTRY main {
c0_0 = f32[20,40] parameter(0)
c0_1 = f32[40,40] parameter(1)
p = pred[] parameter(2)
t = (f32[20,40], f32[40,40]) tuple(c0_0, c0_1)
conditional = (f32[20, 40]) conditional(p,t,t), false_computation=on_false, true_computation=on_true
ROOT result = () tuple()
}
)";
auto status = ParseAndReturnUnverifiedModule(hlo_string);
TF_ASSERT_OK(status.status());
HloVerifier v(false, false);
TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status());
EXPECT_TRUE(
ConditionalSimplifier().Run(status.ValueOrDie().get()).ValueOrDie());
TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status());
HloInstruction* conditional =
FindInstruction(status.ValueOrDie().get(), "conditional");
// The conditional root should be replaced with an empty tuple.
EXPECT_EQ(ShapeUtil::TupleElementCount(conditional->shape()), 0);
}
} // namespace } // namespace
} // namespace xla } // namespace xla