diff --git a/tensorflow/compiler/xla/service/call_inliner.cc b/tensorflow/compiler/xla/service/call_inliner.cc index 817b59f7627..b604237d72a 100644 --- a/tensorflow/compiler/xla/service/call_inliner.cc +++ b/tensorflow/compiler/xla/service/call_inliner.cc @@ -20,30 +20,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" namespace xla { - -StatusOr CallInliner::Run(HloModule* module) { - std::deque work_queue; - - // Seed the work queue with call instructions from the main computation. - TF_RETURN_IF_ERROR( - module->entry_computation()->Accept([&](HloInstruction* hlo) { - if (hlo->opcode() == HloOpcode::kCall) { - work_queue.push_back(hlo); - } - return Status::OK(); - })); - - VLOG(1) << "Work queue seeded with " << work_queue.size() << " entries."; - - bool mutated = false; - while (!work_queue.empty()) { - mutated = true; - HloInstruction* call = work_queue.front(); - work_queue.pop_front(); - TF_RETURN_IF_ERROR(ReplaceWithInlinedBody(call, &work_queue)); - } - return mutated; -} +namespace { // Traverses the callee computation, inlining cloned nodes into the caller // computation and connecting them to producers/consumers appropriately. @@ -141,6 +118,64 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { std::deque* work_queue_; }; +} // namespace + +StatusOr CallInliner::Run(HloModule* module) { + std::deque work_queue; + tensorflow::gtl::FlatSet seen; + + auto scan_computation = [&work_queue, + &seen](HloComputation* computation) -> Status { + if (!seen.insert(computation).second) { + return Status::OK(); // Already seen. + } + return computation->Accept([&](HloInstruction* hlo) { + if (!hlo->called_computations().empty()) { + work_queue.push_back(hlo); + } + return Status::OK(); + }); + }; + + // Seed the work queue with call instructions from the main computation. + TF_RETURN_IF_ERROR(scan_computation(module->entry_computation())); + + VLOG(1) << "Work queue seeded with " << work_queue.size() << " entries."; + + bool mutated = false; + while (!work_queue.empty()) { + HloInstruction* caller = work_queue.front(); + work_queue.pop_front(); + switch (caller->opcode()) { + case HloOpcode::kCall: + mutated = true; + TF_RETURN_IF_ERROR(ReplaceWithInlinedBody(caller, &work_queue)); + break; + case HloOpcode::kWhile: + TF_RETURN_IF_ERROR(scan_computation(caller->while_condition())); + TF_RETURN_IF_ERROR(scan_computation(caller->while_body())); + break; + case HloOpcode::kSelectAndScatter: + TF_RETURN_IF_ERROR(scan_computation(caller->select())); + TF_RETURN_IF_ERROR(scan_computation(caller->scatter())); + break; + case HloOpcode::kMap: + case HloOpcode::kReduceWindow: + case HloOpcode::kReduce: + TF_RETURN_IF_ERROR(scan_computation(caller->to_apply())); + break; + case HloOpcode::kFusion: + // Fusion nodes don't represent true calls, but instead delimit a + // boundary for the backend-specific fusion capabilities. + break; + default: + return Unimplemented("Unknown higher-order HLO opcode: %s", + caller->ToString().c_str()); + } + } + return mutated; +} + Status CallInliner::ReplaceWithInlinedBody( HloInstruction* call, std::deque* work_queue) { TF_RET_CHECK(call->opcode() == HloOpcode::kCall); diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index 77528d0b75f..67633a95748 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -73,5 +73,44 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { EXPECT_EQ(prior->literal().GetFirstElement(), 24); } +// Tests for referential transparency (a function that calls a function that +// returns false should be identical to just returning false). +TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { + const Shape pred = ShapeUtil::MakeShape(PRED, {}); + auto module = CreateNewModule(); + + // Create a lambda that calls a function that returns the false predicate. + // Note we also use this lambda twice by reference, just to make the test a + // little trickier. + HloComputation::Builder just_false(TestName() + ".false"); + just_false.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloComputation* false_computation = + module->AddEmbeddedComputation(just_false.Build()); + + HloComputation::Builder call_false_builder(TestName() + ".call_false"); + call_false_builder.AddInstruction( + HloInstruction::CreateCall(pred, {}, false_computation)); + HloComputation* call_false = + module->AddEmbeddedComputation(call_false_builder.Build()); + + HloComputation::Builder outer(TestName() + ".outer"); + HloInstruction* init_value = outer.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + outer.AddInstruction( + HloInstruction::CreateWhile(pred, call_false, call_false, init_value)); + + auto computation = module->AddEntryComputation(outer.Build()); + + CallInliner call_inliner; + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + ASSERT_TRUE(mutated); + EXPECT_THAT( + computation->root_instruction()->while_condition()->root_instruction(), + op::Constant()); + EXPECT_THAT(computation->root_instruction()->while_body()->root_instruction(), + op::Constant()); +} + } // namespace } // namespace xla