From 856804b30b00348a7abb01fc4fd13a3106049a2d Mon Sep 17 00:00:00 2001 From: Yuanzhong Xu Date: Thu, 1 Aug 2019 18:05:54 -0700 Subject: [PATCH] [XLA] Conditional simplifier: replace root with empty tuple if no users. PiperOrigin-RevId: 261237215 --- .../xla/service/conditional_simplifier.cc | 26 +++++++++++ .../service/conditional_simplifier_test.cc | 43 +++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc index f1936035fed..985603b08e4 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc @@ -253,6 +253,31 @@ StatusOr TryRemoveUnusedConditionalOperands( } return true; } + +// Replaces the roots of all branches with an empty tuple if the conditional op +// has no users. Returns if anything is changed. +bool ReplaceRootWithEmptyTupleIfNoUsers(HloInstruction* conditional_op) { + const Shape empty_tuple = ShapeUtil::MakeTupleShape({}); + if (conditional_op->user_count() == 0 && + conditional_op != conditional_op->parent()->root_instruction() && + !ShapeUtil::Compatible(empty_tuple, conditional_op->shape())) { + for (int64 branch_id = 0; branch_id < conditional_op->branch_count(); + ++branch_id) { + auto branch_computation = + conditional_op->GetModule()->AddEmbeddedComputation( + conditional_op->branch_computation(branch_id)->Clone()); + conditional_op->set_branch_computation(branch_id, branch_computation); + auto new_empty_root = + branch_computation->AddInstruction(HloInstruction::CreateTuple({})); + branch_computation->set_root_instruction(new_empty_root, + /*accept_different_shape=*/true); + } + *conditional_op->mutable_shape() = empty_tuple; + return true; + } + return false; +} + } // namespace StatusOr ConditionalSimplifier::Run(HloModule* module) { @@ -274,6 +299,7 @@ StatusOr ConditionalSimplifier::Run(HloModule* module) { std::map> changed_computations; for (HloInstruction* conditional_op : conditional_ops) { + changed |= ReplaceRootWithEmptyTupleIfNoUsers(conditional_op); TF_ASSIGN_OR_RETURN(bool result, TryRemoveConditional(conditional_op)); if (!result) { TF_ASSIGN_OR_RETURN(result, TryRemoveUnusedConditionalOperands( diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc index 58659156a75..d409e22463e 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -285,6 +285,49 @@ TEST_F(ConditionalSimplifierTest, EXPECT_TRUE(ConditionalSimplifier().Run(module.get()).ValueOrDie()); } +TEST_F(ConditionalSimplifierTest, RemoveDeadRoots) { + absl::string_view hlo_string = + R"( +HloModule RemoveDeadRoots +on_false { + t = (f32[20,40], f32[40,40]) parameter(0) + lhs = f32[20,40] get-tuple-element(t), index=0 + rhs = f32[40,40] get-tuple-element(t), index=1 + dot = f32[20,40] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0} + after-all = token[] after-all() + outfeed = token[] outfeed(dot, after-all) + ROOT result = (f32[20,40]) tuple(dot) +} + +on_true { + t = (f32[20,40], f32[40,40]) parameter(0) + lhs = f32[20,40] get-tuple-element(t), index=0 + add = f32[20,40] add(lhs, lhs) + ROOT result = (f32[20,40]) tuple(add) +} + +ENTRY main { + c0_0 = f32[20,40] parameter(0) + c0_1 = f32[40,40] parameter(1) + p = pred[] parameter(2) + t = (f32[20,40], f32[40,40]) tuple(c0_0, c0_1) + conditional = (f32[20, 40]) conditional(p,t,t), false_computation=on_false, true_computation=on_true + ROOT result = () tuple() +} +)"; + auto status = ParseAndReturnUnverifiedModule(hlo_string); + TF_ASSERT_OK(status.status()); + HloVerifier v(false, false); + TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status()); + EXPECT_TRUE( + ConditionalSimplifier().Run(status.ValueOrDie().get()).ValueOrDie()); + TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status()); + HloInstruction* conditional = + FindInstruction(status.ValueOrDie().get(), "conditional"); + // The conditional root should be replaced with an empty tuple. + EXPECT_EQ(ShapeUtil::TupleElementCount(conditional->shape()), 0); +} + } // namespace } // namespace xla