[XLA:CPU] Restrict CallInliner to functions with a single call site

This is a rather crude heuristic, but enough to recover performance without
causing excessive inlining when calling functions many times.

Also remove an outdated comment and a TF_RET_CHECK that's doesn't hold
when not inlining everything.

PiperOrigin-RevId: 291180120
Change-Id: I50434076891b69f92ded0cdbd40039a4f5858541
This commit is contained in:
Benjamin Kramer 2020-01-23 09:16:11 -08:00 committed by TensorFlower Gardener
parent fae39def8e
commit 88fb36a9c3
4 changed files with 49 additions and 6 deletions

View File

@ -40,9 +40,7 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
// Resolves the operands to the HLO instruction in the inlined (caller) graph,
// and clones the HLO instruction into that graph with the new operands.
// If the instruction is a call, it is added to the work queue.
Status DefaultAction(HloInstruction* hlo) override {
TF_RET_CHECK(hlo->opcode() != HloOpcode::kCall);
std::vector<HloInstruction*> new_operands;
for (HloInstruction* operand : hlo->operands()) {
TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, Resolve(operand));
@ -146,7 +144,11 @@ StatusOr<bool> CallInliner::Run(HloModule* module) {
VLOG(1) << "Visiting node: " << node.ToString();
for (HloInstruction* instruction :
node.computation()->MakeInstructionPostOrder()) {
if (instruction->opcode() == HloOpcode::kCall) {
if (instruction->opcode() == HloOpcode::kCall &&
(!single_call_site_ ||
call_graph->GetNode(instruction->to_apply())
.caller_callsites()
.size() == 1)) {
TF_RETURN_IF_ERROR(Inline(instruction).status());
did_mutate = true;
}

View File

@ -34,10 +34,17 @@ class CallInliner : public HloModulePass {
// instructions to their inlined versions.
static StatusOr<InlinedInstructionMap> Inline(HloInstruction* call);
// If single_call_site is true, only functions with a single call site will be
// inlined.
explicit CallInliner(bool single_call_site = false)
: single_call_site_(single_call_site) {}
~CallInliner() override = default;
absl::string_view name() const override { return "CallInliner"; }
StatusOr<bool> Run(HloModule* module) override;
private:
bool single_call_site_;
};
} // namespace xla

View File

@ -207,5 +207,40 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) {
ASSERT_TRUE(mutated);
}
TEST_F(CallInlinerTest, InlineSingleUseCalleesOnly) {
constexpr absl::string_view hlo_string = R"(
HloModule inline_module
a {
ROOT tuple = () tuple()
}
b {
ROOT tuple.1 = () tuple()
}
ENTRY inline {
a = () call(), to_apply=a
b = () call(), to_apply=a
c = () call(), to_apply=b
ROOT tuple = ((), (), ()) tuple(a, b, c)
})";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
CallInliner call_inliner(/*single_call_site=*/true);
TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
ASSERT_TRUE(mutated);
ASSERT_EQ(module->entry_computation()->instruction_count(), 4);
auto inst = module->entry_computation()->instructions().begin();
EXPECT_THAT(*inst, op::Call());
++inst;
EXPECT_THAT(*inst, op::Call());
++inst;
EXPECT_THAT(*inst, op::Tuple());
++inst;
EXPECT_THAT(*inst, op::Tuple());
}
} // namespace
} // namespace xla

View File

@ -254,9 +254,8 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
pipeline.AddPass<CholeskyExpander>();
pipeline.AddPass<TriangularSolveExpander>();
// TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner
// pass.
pipeline.AddPass<CallInliner>();
// Inline computations with a single call site.
pipeline.AddPass<CallInliner>(/*single_call_site=*/true);
pipeline.AddPass<BatchDotSimplification>();
pipeline.AddPass<DotDecomposer>();
// After canonicalization, there may be more batch dots that can be