diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 610c611eee4..d954f3f545d 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -288,6 +288,7 @@ cc_library(
     srcs = ["call_inliner.cc"],
     hdrs = ["call_inliner.h"],
     deps = [
+        ":call_graph",
         ":hlo_pass",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/service/call_inliner.cc b/tensorflow/compiler/xla/service/call_inliner.cc
index b604237d72a..65472d9ac92 100644
--- a/tensorflow/compiler/xla/service/call_inliner.cc
+++ b/tensorflow/compiler/xla/service/call_inliner.cc
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include <deque>
 
+#include "tensorflow/compiler/xla/service/call_graph.h"
 #include "tensorflow/core/lib/core/errors.h"
 
 namespace xla {
@@ -29,14 +30,18 @@ namespace {
 // computation have been added to the work_queue.
 class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
  public:
-  SubcomputationInsertionVisitor(HloInstruction* call,
-                                 std::deque<HloInstruction*>* work_queue)
-      : call_(call), outer_(call->parent()), work_queue_(work_queue) {}
+  // call is the call operation -- it will be replaced with the body of the
+  // called computation.
+  explicit SubcomputationInsertionVisitor(HloInstruction* call)
+      : call_(call), outer_(call->parent()) {
+    CHECK_EQ(HloOpcode::kCall, call_->opcode());
+  }
 
   // Resolves the operands to the HLO instruction in the inlined (caller) graph,
   // and clones the HLO instruction into that graph with the new operands.
   // If the instruction is a call, it is added to the work queue.
   Status DefaultAction(HloInstruction* hlo) override {
+    TF_RET_CHECK(hlo->opcode() != HloOpcode::kCall);
     std::vector<HloInstruction*> new_operands;
     for (HloInstruction* operand : hlo->operands()) {
       TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, Resolve(operand));
@@ -56,12 +61,6 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
           new_control_predecessor->AddControlDependencyTo(new_hlo_pointer));
     }
 
-    if (new_hlo_pointer->opcode() == HloOpcode::kCall) {
-      VLOG(1) << "Adding new call HLO to work queue.";
-      // Call instructions we observe in the subcomputation are added to the
-      // inliner work queue.
-      work_queue_->push_back(new_hlo_pointer);
-    }
     return Status::OK();
   }
 
@@ -121,71 +120,27 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
 }  // namespace
 
 StatusOr<bool> CallInliner::Run(HloModule* module) {
-  std::deque<HloInstruction*> work_queue;
-  tensorflow::gtl::FlatSet<HloComputation*> seen;
-
-  auto scan_computation = [&work_queue,
-                           &seen](HloComputation* computation) -> Status {
-    if (!seen.insert(computation).second) {
-      return Status::OK();  // Already seen.
-    }
-    return computation->Accept([&](HloInstruction* hlo) {
-      if (!hlo->called_computations().empty()) {
-        work_queue.push_back(hlo);
-      }
-      return Status::OK();
-    });
-  };
-
-  // Seed the work queue with call instructions from the main computation.
-  TF_RETURN_IF_ERROR(scan_computation(module->entry_computation()));
-
-  VLOG(1) << "Work queue seeded with " << work_queue.size() << " entries.";
-
-  bool mutated = false;
-  while (!work_queue.empty()) {
-    HloInstruction* caller = work_queue.front();
-    work_queue.pop_front();
-    switch (caller->opcode()) {
-      case HloOpcode::kCall:
-        mutated = true;
-        TF_RETURN_IF_ERROR(ReplaceWithInlinedBody(caller, &work_queue));
-        break;
-      case HloOpcode::kWhile:
-        TF_RETURN_IF_ERROR(scan_computation(caller->while_condition()));
-        TF_RETURN_IF_ERROR(scan_computation(caller->while_body()));
-        break;
-      case HloOpcode::kSelectAndScatter:
-        TF_RETURN_IF_ERROR(scan_computation(caller->select()));
-        TF_RETURN_IF_ERROR(scan_computation(caller->scatter()));
-        break;
-      case HloOpcode::kMap:
-      case HloOpcode::kReduceWindow:
-      case HloOpcode::kReduce:
-        TF_RETURN_IF_ERROR(scan_computation(caller->to_apply()));
-        break;
-      case HloOpcode::kFusion:
-        // Fusion nodes don't represent true calls, but instead delimit a
-        // boundary for the backend-specific fusion capabilities.
-        break;
-      default:
-        return Unimplemented("Unknown higher-order HLO opcode: %s",
-                             caller->ToString().c_str());
-    }
-  }
-  return mutated;
-}
-
-Status CallInliner::ReplaceWithInlinedBody(
-    HloInstruction* call, std::deque<HloInstruction*>* work_queue) {
-  TF_RET_CHECK(call->opcode() == HloOpcode::kCall);
-  TF_RET_CHECK(call->called_computations().size() == 1);
-  HloComputation* called = call->called_computations()[0];
-  VLOG(1) << "Replacing call " << call->ToString() << " with inlined body of "
-          << called->name();
-
-  SubcomputationInsertionVisitor visitor(call, work_queue);
-  return called->Accept(&visitor);
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
+  // Because call graph nodes are visited in post-order (callees before callers)
+  // we'll always inline kCalls into their callers in the appropriate order.
+  bool did_mutate = false;
+  TF_RETURN_IF_ERROR(
+      call_graph->VisitNodes([&](const CallGraphNode& node) -> Status {
+        for (const CallSite& callsite : node.caller_callsites()) {
+          VLOG(1) << "Visiting callsite: " << callsite.ToString();
+          if (callsite.instruction()->opcode() == HloOpcode::kCall) {
+            did_mutate = true;
+            const auto& callees = callsite.called_computations();
+            TF_RET_CHECK(callees.size() == 1);
+            HloComputation* callee = callees[0];
+            // We visit the callee, cloning its body into its caller.
+            SubcomputationInsertionVisitor visitor(callsite.instruction());
+            TF_RETURN_IF_ERROR(callee->Accept(&visitor));
+          }
+        }
+        return Status::OK();
+      }));
+  return did_mutate;
 }
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h
index 8647edffa7f..8660200bc40 100644
--- a/tensorflow/compiler/xla/service/call_inliner.h
+++ b/tensorflow/compiler/xla/service/call_inliner.h
@@ -31,16 +31,6 @@ class CallInliner : public HloPassInterface {
   tensorflow::StringPiece name() const override { return "CallInliner"; }
 
   StatusOr<bool> Run(HloModule* module) override;
-
- private:
-  // Replaces the given call operation -- which must be an operation inside the
-  // entry computation with opcode kCall -- with the called computation's body,
-  // such that the called computation is inline in the entry computation.
-  //
-  // On successful inlining, the inlined computation may have itself contained
-  // calls; if so, they are added to the work_queue.
-  Status ReplaceWithInlinedBody(HloInstruction* call,
-                                std::deque<HloInstruction*>* work_queue);
 };
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc
index 67633a95748..f3e7407c544 100644
--- a/tensorflow/compiler/xla/service/call_inliner_test.cc
+++ b/tensorflow/compiler/xla/service/call_inliner_test.cc
@@ -44,6 +44,8 @@ namespace {
 using CallInlinerTest = HloTestBase;
 
 TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
+  // "inner" computation just has a control dependency from the "zero" value to
+  // the "one" value.
   HloComputation::Builder inner(TestName() + ".inner");
   HloInstruction* zero = inner.AddInstruction(
       HloInstruction::CreateConstant(Literal::CreateR0<float>(24.0f)));
@@ -54,6 +56,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
   HloComputation* inner_computation =
       module->AddEmbeddedComputation(inner.Build());
 
+  // "outer" computation just calls the "inner" computation.
   HloComputation::Builder outer(TestName() + ".outer");
   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
   outer.AddInstruction(