diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index be917d6763b..be5a1ca21ed 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1841,6 +1841,7 @@ cc_library(
         "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/core:lib",
+        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/strings",
    ],
)
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc
index f6dac508e5f..f1936035fed 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier.cc
+++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include <string>
 #include <utility>
 
+#include "absl/algorithm/container.h"
 #include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/service/call_graph.h"
@@ -55,15 +56,24 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
   }
 
   // We can always inline a 1-branch conditional due to default branch fallback.
-  int branch_index = 0;
-  if (conditional->branch_count() > 1) {
-    if (conditional->operand(0)->opcode() != HloOpcode::kConstant) {
-      VLOG(2) << "Not attempting to remove conditional as its branch_index is "
-                 "not a compile-time constant: "
-              << conditional->ToShortString();
-      return false;
-    }
+  auto computation = conditional->parent();
+  auto create_call = [&](int64 branch) {
+    auto call = computation->AddInstruction(HloInstruction::CreateCall(
+        conditional->shape(), {conditional->mutable_operand(1 + branch)},
+        conditional->branch_computation(branch)));
+    conditional->SetupDerivedInstruction(call);
+    return call;
+  };
+  if (conditional->branch_count() == 1) {
+    HloInstruction* call_op = create_call(0);
+    TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op));
+    TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status());
+    return true;
+  }
+
+  if (conditional->operand(0)->opcode() == HloOpcode::kConstant) {
+    int branch_index = 0;
     if (conditional->operand(0)->shape().element_type() == PRED) {
       branch_index = conditional->operand(0)->literal().Get<bool>({}) ?
                          0 : 1;
     } else {
@@ -72,16 +82,83 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
         branch_index = conditional->branch_count() - 1;
       }
     }
-  }
-  auto computation = conditional->parent();
-  HloInstruction* call_op;
-  call_op = computation->AddInstruction(HloInstruction::CreateCall(
-      conditional->shape(), {conditional->mutable_operand(branch_index + 1)},
-      conditional->branch_computation(branch_index)));
-  conditional->SetupDerivedInstruction(call_op);
-  TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op));
-  TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status());
+    HloInstruction* call_op = create_call(branch_index);
+    TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op));
+    TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status());
+    return true;
+  }
+
+  auto instruction_is_expensive = [](const HloInstruction* hlo) {
+    switch (hlo->opcode()) {
+      case HloOpcode::kBroadcast:
+      case HloOpcode::kConcatenate:
+      case HloOpcode::kDynamicSlice:
+      case HloOpcode::kDynamicUpdateSlice:
+      case HloOpcode::kGetTupleElement:
+      case HloOpcode::kReduce:
+      case HloOpcode::kReshape:
+      case HloOpcode::kPad:
+      case HloOpcode::kParameter:
+      case HloOpcode::kSlice:
+      case HloOpcode::kTuple:
+        return false;
+      default:
+        return !hlo->IsElementwise();
+    }
+  };
+
+  if (conditional->branch_count() != 2 ||
+      conditional->operand(0)->shape().element_type() != PRED ||
+      absl::c_any_of(conditional->branch_computation(0)->instructions(),
+                     instruction_is_expensive) ||
+      absl::c_any_of(conditional->branch_computation(1)->instructions(),
+                     instruction_is_expensive)) {
+    VLOG(2)
+        << "Not attempting to remove conditional as its branch_index is not a "
+           "compile-time constant or contains expensive instructions: "
+        << conditional->ToShortString();
+    return false;
+  }
+
+  HloInstruction* true_call_op = create_call(0);
+  HloInstruction* false_call_op = create_call(1);
+  auto condition_broadcast = [&](const Shape& shape) {
+    if (ShapeUtil::IsScalar(shape)) {
+      return conditional->mutable_operand(0);
+    }
+    return computation->AddInstruction(HloInstruction::CreateBroadcast(
+        ShapeUtil::ChangeElementType(shape, PRED),
+        conditional->mutable_operand(0), {}));
+  };
+
+  auto gte = [&](HloInstruction* hlo, int64 i) {
+    return computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+        hlo->shape().tuple_shapes(i), hlo, i));
+  };
+  std::function<HloInstruction*(HloInstruction*, HloInstruction*)> select =
+      [&](HloInstruction* t, HloInstruction* f) {
+        if (f->shape().IsArray()) {
+          return computation->AddInstruction(HloInstruction::CreateTernary(
+              f->shape(), HloOpcode::kSelect, condition_broadcast(f->shape()),
+              t, f));
+        }
+        std::vector<HloInstruction*> selects;
+        const int64 tuple_element_count =
+            ShapeUtil::TupleElementCount(f->shape());
+        selects.reserve(tuple_element_count);
+        for (int64 i = 0; i < tuple_element_count; ++i) {
+          selects.push_back(select(gte(t, i), gte(f, i)));
+        }
+        return computation->AddInstruction(
+            HloInstruction::CreateTuple(selects));
+      };
+
+  TF_RETURN_IF_ERROR(computation->ReplaceInstruction(
+      conditional, select(true_call_op, false_call_op)));
+
+  TF_RETURN_IF_ERROR(CallInliner::Inline(false_call_op).status());
+  TF_RETURN_IF_ERROR(CallInliner::Inline(true_call_op).status());
   return true;
 }
 
 StatusOr<bool> TryRemoveUnusedConditionalOperands(
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
index a584aba816f..2a560fe9dd2 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
@@ -41,10 +41,11 @@ namespace op = xla::testing::opcode_matchers;
 class ConditionalSimplifierTest : public HloTestBase {
  public:
   // Makes a computation that contains a conditional with constant predicate.
-  HloComputation* MakeConditional(HloModule* module);
+  HloComputation* MakeConditional(HloModule* module, bool is_constant = true);
 };
 
-HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) {
+HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module,
+                                                           bool is_constant) {
   HloComputation::Builder builder(TestName());
 
   // true_computation returns param+1.
@@ -83,7 +84,10 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) {
   }
 
   auto false_instrn = builder.AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+      is_constant
+          ? HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false))
+          : HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(PRED, {}),
+                                            "cond"));
   auto false_param = builder.AddInstruction(HloInstruction::CreateParameter(
       0, ShapeUtil::MakeShape(S32, {}), "false_param"));
   auto one = builder.AddInstruction(
@@ -104,6 +108,16 @@ TEST_F(ConditionalSimplifierTest, ConditionalGetsInlined) {
               op::Add(op::Parameter(), op::Constant()));
 }
 
+TEST_F(ConditionalSimplifierTest, BranchGetsInlined) {
+  auto m = CreateNewVerifiedModule();
+  HloComputation* computation = MakeConditional(m.get(), /*is_constant=*/false);
+  ASSERT_TRUE(ConditionalSimplifier().Run(m.get()).ValueOrDie());
+  EXPECT_THAT(
+      computation->root_instruction(),
+      op::Select(op::Parameter(1), op::Add(op::Constant(), op::Constant()),
+                 op::Add(op::Parameter(0), op::Constant())));
+}
+
 TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) {
   auto m = CreateNewVerifiedModule();
   HloComputation* computation = MakeConditional(m.get());