From 88fb36a9c383d0c39245fd6d6dd338ffe560b2c2 Mon Sep 17 00:00:00 2001
From: Benjamin Kramer <kramerb@google.com>
Date: Thu, 23 Jan 2020 09:16:11 -0800
Subject: [PATCH] [XLA:CPU] Restrict CallInliner to functions with a single
 call site

This is a rather crude heuristic, but enough to recover performance without
causing excessive inlining when calling functions many times.

Also remove an outdated comment and a TF_RET_CHECK that doesn't hold
when not inlining everything.

PiperOrigin-RevId: 291180120
Change-Id: I50434076891b69f92ded0cdbd40039a4f5858541
---
 .../compiler/xla/service/call_inliner.cc      |  8 +++--
 .../compiler/xla/service/call_inliner.h       |  7 ++++
 .../compiler/xla/service/call_inliner_test.cc | 35 +++++++++++++++++++
 .../compiler/xla/service/cpu/cpu_compiler.cc  |  5 ++-
 4 files changed, 49 insertions(+), 6 deletions(-)

diff --git a/tensorflow/compiler/xla/service/call_inliner.cc b/tensorflow/compiler/xla/service/call_inliner.cc
index 4f2436de4fa..68c2745a206 100644
--- a/tensorflow/compiler/xla/service/call_inliner.cc
+++ b/tensorflow/compiler/xla/service/call_inliner.cc
@@ -40,9 +40,7 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
 
   // Resolves the operands to the HLO instruction in the inlined (caller) graph,
   // and clones the HLO instruction into that graph with the new operands.
-  // If the instruction is a call, it is added to the work queue.
   Status DefaultAction(HloInstruction* hlo) override {
-    TF_RET_CHECK(hlo->opcode() != HloOpcode::kCall);
     std::vector<HloInstruction*> new_operands;
     for (HloInstruction* operand : hlo->operands()) {
       TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, Resolve(operand));
@@ -146,7 +144,11 @@ StatusOr<bool> CallInliner::Run(HloModule* module) {
         VLOG(1) << "Visiting node: " << node.ToString();
         for (HloInstruction* instruction :
              node.computation()->MakeInstructionPostOrder()) {
-          if (instruction->opcode() == HloOpcode::kCall) {
+          if (instruction->opcode() == HloOpcode::kCall &&
+              (!single_call_site_ ||
+               call_graph->GetNode(instruction->to_apply())
+                       .caller_callsites()
+                       .size() == 1)) {
             TF_RETURN_IF_ERROR(Inline(instruction).status());
             did_mutate = true;
           }
diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h
index 08c4aff4f7f..22b0fdda86d 100644
--- a/tensorflow/compiler/xla/service/call_inliner.h
+++ b/tensorflow/compiler/xla/service/call_inliner.h
@@ -34,10 +34,17 @@ class CallInliner : public HloModulePass {
   // instructions to their inlined versions.
   static StatusOr<InlinedInstructionMap> Inline(HloInstruction* call);
 
+  // If single_call_site is true, only functions with a single call site will be
+  // inlined.
+  explicit CallInliner(bool single_call_site = false)
+      : single_call_site_(single_call_site) {}
   ~CallInliner() override = default;
   absl::string_view name() const override { return "CallInliner"; }
 
   StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+  bool single_call_site_;
 };
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc
index 02f43ba70c7..a1fa59313e0 100644
--- a/tensorflow/compiler/xla/service/call_inliner_test.cc
+++ b/tensorflow/compiler/xla/service/call_inliner_test.cc
@@ -207,5 +207,40 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) {
   ASSERT_TRUE(mutated);
 }
 
+TEST_F(CallInlinerTest, InlineSingleUseCalleesOnly) {
+  constexpr absl::string_view hlo_string = R"(
+  HloModule inline_module
+
+  a {
+    ROOT tuple = () tuple()
+  }
+
+  b {
+    ROOT tuple.1 = () tuple()
+  }
+
+  ENTRY inline {
+    a = () call(), to_apply=a
+    b = () call(), to_apply=a
+    c = () call(), to_apply=b
+    ROOT tuple = ((), (), ()) tuple(a, b, c)
+  })";
+
+  auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
+  CallInliner call_inliner(/*single_call_site=*/true);
+  TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
+  ASSERT_TRUE(mutated);
+
+  ASSERT_EQ(module->entry_computation()->instruction_count(), 4);
+  auto inst = module->entry_computation()->instructions().begin();
+  EXPECT_THAT(*inst, op::Call());
+  ++inst;
+  EXPECT_THAT(*inst, op::Call());
+  ++inst;
+  EXPECT_THAT(*inst, op::Tuple());
+  ++inst;
+  EXPECT_THAT(*inst, op::Tuple());
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index a04a39b4461..50d9c99fa4c 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -254,9 +254,8 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
   pipeline.AddPass<CholeskyExpander>();
   pipeline.AddPass<TriangularSolveExpander>();
 
-  // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner
-  // pass.
-  pipeline.AddPass<CallInliner>();
+  // Inline computations with a single call site.
+  pipeline.AddPass<CallInliner>(/*single_call_site=*/true);
   pipeline.AddPass<BatchDotSimplification>();
   pipeline.AddPass<DotDecomposer>();
   // After canonicalization, there may be more batch dots that can be