From 88fb36a9c383d0c39245fd6d6dd338ffe560b2c2 Mon Sep 17 00:00:00 2001
From: Benjamin Kramer <kramerb@google.com>
Date: Thu, 23 Jan 2020 09:16:11 -0800
Subject: [PATCH] [XLA:CPU] Restrict CallInliner to functions with a single call site

This is a rather crude heuristic, but enough to recover performance without
causing excessive inlining when calling functions many times.

Also remove an outdated comment and a TF_RET_CHECK that doesn't hold when
not inlining everything.

PiperOrigin-RevId: 291180120
Change-Id: I50434076891b69f92ded0cdbd40039a4f5858541
---
 .../compiler/xla/service/call_inliner.cc      |  8 +++--
 .../compiler/xla/service/call_inliner.h       |  7 ++++
 .../compiler/xla/service/call_inliner_test.cc | 35 +++++++++++++++++++
 .../compiler/xla/service/cpu/cpu_compiler.cc  |  5 ++-
 4 files changed, 49 insertions(+), 6 deletions(-)

diff --git a/tensorflow/compiler/xla/service/call_inliner.cc b/tensorflow/compiler/xla/service/call_inliner.cc
index 4f2436de4fa..68c2745a206 100644
--- a/tensorflow/compiler/xla/service/call_inliner.cc
+++ b/tensorflow/compiler/xla/service/call_inliner.cc
@@ -40,9 +40,7 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
 
   // Resolves the operands to the HLO instruction in the inlined (caller) graph,
   // and clones the HLO instruction into that graph with the new operands.
-  // If the instruction is a call, it is added to the work queue.
   Status DefaultAction(HloInstruction* hlo) override {
-    TF_RET_CHECK(hlo->opcode() != HloOpcode::kCall);
     std::vector<HloInstruction*> new_operands;
     for (HloInstruction* operand : hlo->operands()) {
       TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, Resolve(operand));
@@ -146,7 +144,11 @@ StatusOr<bool> CallInliner::Run(HloModule* module) {
     VLOG(1) << "Visiting node: " << node.ToString();
     for (HloInstruction* instruction :
          node.computation()->MakeInstructionPostOrder()) {
-      if (instruction->opcode() == HloOpcode::kCall) {
+      if (instruction->opcode() == HloOpcode::kCall &&
+          (!single_call_site_ ||
+           call_graph->GetNode(instruction->to_apply())
+                   .caller_callsites()
+                   .size() == 1)) {
         TF_RETURN_IF_ERROR(Inline(instruction).status());
         did_mutate = true;
       }
diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h
index 08c4aff4f7f..22b0fdda86d 100644
--- a/tensorflow/compiler/xla/service/call_inliner.h
+++ b/tensorflow/compiler/xla/service/call_inliner.h
@@ -34,10 +34,17 @@ class CallInliner : public HloModulePass {
   // instructions to their inlined versions.
   static StatusOr<InlinedInstructionMap> Inline(HloInstruction* call);
 
+  // If single_call_site is true, only functions with a single call site will be
+  // inlined.
+  explicit CallInliner(bool single_call_site = false)
+      : single_call_site_(single_call_site) {}
   ~CallInliner() override = default;
   absl::string_view name() const override { return "CallInliner"; }
 
   StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+  bool single_call_site_;
 };
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc
index 02f43ba70c7..a1fa59313e0 100644
--- a/tensorflow/compiler/xla/service/call_inliner_test.cc
+++ b/tensorflow/compiler/xla/service/call_inliner_test.cc
@@ -207,5 +207,40 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) {
   ASSERT_TRUE(mutated);
 }
 
+TEST_F(CallInlinerTest, InlineSingleUseCalleesOnly) {
+  constexpr absl::string_view hlo_string = R"(
+  HloModule inline_module
+
+  a {
+    ROOT tuple = () tuple()
+  }
+
+  b {
+    ROOT tuple.1 = () tuple()
+  }
+
+  ENTRY inline {
+    a = () call(), to_apply=a
+    b = () call(), to_apply=a
+    c = () call(), to_apply=b
+    ROOT tuple = ((), (), ()) tuple(a, b, c)
+  })";
+
+  auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
+  CallInliner call_inliner(/*single_call_site=*/true);
+  TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
+  ASSERT_TRUE(mutated);
+
+  ASSERT_EQ(module->entry_computation()->instruction_count(), 4);
+  auto inst = module->entry_computation()->instructions().begin();
+  EXPECT_THAT(*inst, op::Call());
+  ++inst;
+  EXPECT_THAT(*inst, op::Call());
+  ++inst;
+  EXPECT_THAT(*inst, op::Tuple());
+  ++inst;
+  EXPECT_THAT(*inst, op::Tuple());
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index a04a39b4461..50d9c99fa4c 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -254,9 +254,8 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
   pipeline.AddPass<CholeskyExpander>();
   pipeline.AddPass<TriangularSolveExpander>();
 
-  // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner
-  // pass.
-  pipeline.AddPass<CallInliner>();
+  // Inline computations with a single call site.
+  pipeline.AddPass<CallInliner>(/*single_call_site=*/true);
   pipeline.AddPass<BatchDotSimplification>();
   pipeline.AddPass<DotDecomposer>();
   // After canonicalization, there may be more batch dots that can be
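Note (not part of the diff above): a minimal usage sketch of the new constructor
argument, mirroring the new test and the cpu_compiler.cc change. It assumes a
std::unique_ptr<HloModule> named module is already in scope; everything else is
the API shown in the patch.

    #include "tensorflow/compiler/xla/service/call_inliner.h"

    // Inline only kCall instructions whose callee has exactly one call site;
    // the default argument (false) keeps the old inline-every-kCall behavior.
    xla::CallInliner inliner(/*single_call_site=*/true);
    TF_ASSIGN_OR_RETURN(bool changed, inliner.Run(module.get()));
    // `changed` is true if at least one call was inlined; computations with
    // multiple call sites are left untouched.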