[XLA:CPU] Restrict CallInliner to functions with a single call site
This is a rather crude heuristic, but enough to recover performance without causing excessive inlining when calling functions many times. Also remove an outdated comment and a TF_RET_CHECK that doesn't hold when not inlining everything. PiperOrigin-RevId: 291180120 Change-Id: I50434076891b69f92ded0cdbd40039a4f5858541
This commit is contained in:
parent
fae39def8e
commit
88fb36a9c3
tensorflow/compiler/xla/service
@ -40,9 +40,7 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
|
||||
|
||||
// Resolves the operands to the HLO instruction in the inlined (caller) graph,
|
||||
// and clones the HLO instruction into that graph with the new operands.
|
||||
// If the instruction is a call, it is added to the work queue.
|
||||
Status DefaultAction(HloInstruction* hlo) override {
|
||||
TF_RET_CHECK(hlo->opcode() != HloOpcode::kCall);
|
||||
std::vector<HloInstruction*> new_operands;
|
||||
for (HloInstruction* operand : hlo->operands()) {
|
||||
TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, Resolve(operand));
|
||||
@ -146,7 +144,11 @@ StatusOr<bool> CallInliner::Run(HloModule* module) {
|
||||
VLOG(1) << "Visiting node: " << node.ToString();
|
||||
for (HloInstruction* instruction :
|
||||
node.computation()->MakeInstructionPostOrder()) {
|
||||
if (instruction->opcode() == HloOpcode::kCall) {
|
||||
if (instruction->opcode() == HloOpcode::kCall &&
|
||||
(!single_call_site_ ||
|
||||
call_graph->GetNode(instruction->to_apply())
|
||||
.caller_callsites()
|
||||
.size() == 1)) {
|
||||
TF_RETURN_IF_ERROR(Inline(instruction).status());
|
||||
did_mutate = true;
|
||||
}
|
||||
|
@ -34,10 +34,17 @@ class CallInliner : public HloModulePass {
|
||||
// instructions to their inlined versions.
|
||||
static StatusOr<InlinedInstructionMap> Inline(HloInstruction* call);
|
||||
|
||||
// If single_call_site is true, only functions with a single call site will be
|
||||
// inlined.
|
||||
explicit CallInliner(bool single_call_site = false)
|
||||
: single_call_site_(single_call_site) {}
|
||||
~CallInliner() override = default;
|
||||
absl::string_view name() const override { return "CallInliner"; }
|
||||
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
||||
private:
|
||||
bool single_call_site_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -207,5 +207,40 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) {
|
||||
ASSERT_TRUE(mutated);
|
||||
}
|
||||
|
||||
TEST_F(CallInlinerTest, InlineSingleUseCalleesOnly) {
|
||||
constexpr absl::string_view hlo_string = R"(
|
||||
HloModule inline_module
|
||||
|
||||
a {
|
||||
ROOT tuple = () tuple()
|
||||
}
|
||||
|
||||
b {
|
||||
ROOT tuple.1 = () tuple()
|
||||
}
|
||||
|
||||
ENTRY inline {
|
||||
a = () call(), to_apply=a
|
||||
b = () call(), to_apply=a
|
||||
c = () call(), to_apply=b
|
||||
ROOT tuple = ((), (), ()) tuple(a, b, c)
|
||||
})";
|
||||
|
||||
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
|
||||
CallInliner call_inliner(/*single_call_site=*/true);
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
|
||||
ASSERT_TRUE(mutated);
|
||||
|
||||
ASSERT_EQ(module->entry_computation()->instruction_count(), 4);
|
||||
auto inst = module->entry_computation()->instructions().begin();
|
||||
EXPECT_THAT(*inst, op::Call());
|
||||
++inst;
|
||||
EXPECT_THAT(*inst, op::Call());
|
||||
++inst;
|
||||
EXPECT_THAT(*inst, op::Tuple());
|
||||
++inst;
|
||||
EXPECT_THAT(*inst, op::Tuple());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -254,9 +254,8 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
|
||||
pipeline.AddPass<CholeskyExpander>();
|
||||
pipeline.AddPass<TriangularSolveExpander>();
|
||||
|
||||
// TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner
|
||||
// pass.
|
||||
pipeline.AddPass<CallInliner>();
|
||||
// Inline computations with a single call site.
|
||||
pipeline.AddPass<CallInliner>(/*single_call_site=*/true);
|
||||
pipeline.AddPass<BatchDotSimplification>();
|
||||
pipeline.AddPass<DotDecomposer>();
|
||||
// After canonicalization, there may be more batch dots that can be
|
||||
|
Loading…
Reference in New Issue
Block a user