[XLA] Conditional simplifier: replace root with empty tuple if no users.
PiperOrigin-RevId: 261237215
This commit is contained in:
parent
f6182b8594
commit
856804b30b
@ -253,6 +253,31 @@ StatusOr<bool> TryRemoveUnusedConditionalOperands(
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Replaces the roots of all branches with an empty tuple if the conditional op
|
||||
// has no users. Returns if anything is changed.
|
||||
bool ReplaceRootWithEmptyTupleIfNoUsers(HloInstruction* conditional_op) {
|
||||
const Shape empty_tuple = ShapeUtil::MakeTupleShape({});
|
||||
if (conditional_op->user_count() == 0 &&
|
||||
conditional_op != conditional_op->parent()->root_instruction() &&
|
||||
!ShapeUtil::Compatible(empty_tuple, conditional_op->shape())) {
|
||||
for (int64 branch_id = 0; branch_id < conditional_op->branch_count();
|
||||
++branch_id) {
|
||||
auto branch_computation =
|
||||
conditional_op->GetModule()->AddEmbeddedComputation(
|
||||
conditional_op->branch_computation(branch_id)->Clone());
|
||||
conditional_op->set_branch_computation(branch_id, branch_computation);
|
||||
auto new_empty_root =
|
||||
branch_computation->AddInstruction(HloInstruction::CreateTuple({}));
|
||||
branch_computation->set_root_instruction(new_empty_root,
|
||||
/*accept_different_shape=*/true);
|
||||
}
|
||||
*conditional_op->mutable_shape() = empty_tuple;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
StatusOr<bool> ConditionalSimplifier::Run(HloModule* module) {
|
||||
@ -274,6 +299,7 @@ StatusOr<bool> ConditionalSimplifier::Run(HloModule* module) {
|
||||
|
||||
std::map<HloComputation*, std::set<int64>> changed_computations;
|
||||
for (HloInstruction* conditional_op : conditional_ops) {
|
||||
changed |= ReplaceRootWithEmptyTupleIfNoUsers(conditional_op);
|
||||
TF_ASSIGN_OR_RETURN(bool result, TryRemoveConditional(conditional_op));
|
||||
if (!result) {
|
||||
TF_ASSIGN_OR_RETURN(result, TryRemoveUnusedConditionalOperands(
|
||||
|
@ -285,6 +285,49 @@ TEST_F(ConditionalSimplifierTest,
|
||||
EXPECT_TRUE(ConditionalSimplifier().Run(module.get()).ValueOrDie());
|
||||
}
|
||||
|
||||
TEST_F(ConditionalSimplifierTest, RemoveDeadRoots) {
|
||||
absl::string_view hlo_string =
|
||||
R"(
|
||||
HloModule RemoveDeadRoots
|
||||
on_false {
|
||||
t = (f32[20,40], f32[40,40]) parameter(0)
|
||||
lhs = f32[20,40] get-tuple-element(t), index=0
|
||||
rhs = f32[40,40] get-tuple-element(t), index=1
|
||||
dot = f32[20,40] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
after-all = token[] after-all()
|
||||
outfeed = token[] outfeed(dot, after-all)
|
||||
ROOT result = (f32[20,40]) tuple(dot)
|
||||
}
|
||||
|
||||
on_true {
|
||||
t = (f32[20,40], f32[40,40]) parameter(0)
|
||||
lhs = f32[20,40] get-tuple-element(t), index=0
|
||||
add = f32[20,40] add(lhs, lhs)
|
||||
ROOT result = (f32[20,40]) tuple(add)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
c0_0 = f32[20,40] parameter(0)
|
||||
c0_1 = f32[40,40] parameter(1)
|
||||
p = pred[] parameter(2)
|
||||
t = (f32[20,40], f32[40,40]) tuple(c0_0, c0_1)
|
||||
conditional = (f32[20, 40]) conditional(p,t,t), false_computation=on_false, true_computation=on_true
|
||||
ROOT result = () tuple()
|
||||
}
|
||||
)";
|
||||
auto status = ParseAndReturnUnverifiedModule(hlo_string);
|
||||
TF_ASSERT_OK(status.status());
|
||||
HloVerifier v(false, false);
|
||||
TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status());
|
||||
EXPECT_TRUE(
|
||||
ConditionalSimplifier().Run(status.ValueOrDie().get()).ValueOrDie());
|
||||
TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status());
|
||||
HloInstruction* conditional =
|
||||
FindInstruction(status.ValueOrDie().get(), "conditional");
|
||||
// The conditional root should be replaced with an empty tuple.
|
||||
EXPECT_EQ(ShapeUtil::TupleElementCount(conditional->shape()), 0);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user