[XLA] Switch CallInliner to use CallGraph::VisitNodes.
PiperOrigin-RevId: 168078645
This commit is contained in:
parent
aba3466f17
commit
405def792e
@ -288,6 +288,7 @@ cc_library(
|
||||
srcs = ["call_inliner.cc"],
|
||||
hdrs = ["call_inliner.h"],
|
||||
deps = [
|
||||
":call_graph",
|
||||
":hlo_pass",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <deque>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/call_graph.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace xla {
|
||||
@ -29,14 +30,18 @@ namespace {
|
||||
// computation have been added to the work_queue.
|
||||
class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
|
||||
public:
|
||||
SubcomputationInsertionVisitor(HloInstruction* call,
|
||||
std::deque<HloInstruction*>* work_queue)
|
||||
: call_(call), outer_(call->parent()), work_queue_(work_queue) {}
|
||||
// call is the call operation -- it will be replaced with the body of the
|
||||
// called computation.
|
||||
explicit SubcomputationInsertionVisitor(HloInstruction* call)
|
||||
: call_(call), outer_(call->parent()) {
|
||||
CHECK_EQ(HloOpcode::kCall, call_->opcode());
|
||||
}
|
||||
|
||||
// Resolves the operands to the HLO instruction in the inlined (caller) graph,
|
||||
// and clones the HLO instruction into that graph with the new operands.
|
||||
// If the instruction is a call, it is added to the work queue.
|
||||
Status DefaultAction(HloInstruction* hlo) override {
|
||||
TF_RET_CHECK(hlo->opcode() != HloOpcode::kCall);
|
||||
std::vector<HloInstruction*> new_operands;
|
||||
for (HloInstruction* operand : hlo->operands()) {
|
||||
TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, Resolve(operand));
|
||||
@ -56,12 +61,6 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
|
||||
new_control_predecessor->AddControlDependencyTo(new_hlo_pointer));
|
||||
}
|
||||
|
||||
if (new_hlo_pointer->opcode() == HloOpcode::kCall) {
|
||||
VLOG(1) << "Adding new call HLO to work queue.";
|
||||
// Call instructions we observe in the subcomputation are added to the
|
||||
// inliner work queue.
|
||||
work_queue_->push_back(new_hlo_pointer);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -121,71 +120,27 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
|
||||
} // namespace
|
||||
|
||||
StatusOr<bool> CallInliner::Run(HloModule* module) {
|
||||
std::deque<HloInstruction*> work_queue;
|
||||
tensorflow::gtl::FlatSet<HloComputation*> seen;
|
||||
|
||||
auto scan_computation = [&work_queue,
|
||||
&seen](HloComputation* computation) -> Status {
|
||||
if (!seen.insert(computation).second) {
|
||||
return Status::OK(); // Already seen.
|
||||
}
|
||||
return computation->Accept([&](HloInstruction* hlo) {
|
||||
if (!hlo->called_computations().empty()) {
|
||||
work_queue.push_back(hlo);
|
||||
}
|
||||
return Status::OK();
|
||||
});
|
||||
};
|
||||
|
||||
// Seed the work queue with call instructions from the main computation.
|
||||
TF_RETURN_IF_ERROR(scan_computation(module->entry_computation()));
|
||||
|
||||
VLOG(1) << "Work queue seeded with " << work_queue.size() << " entries.";
|
||||
|
||||
bool mutated = false;
|
||||
while (!work_queue.empty()) {
|
||||
HloInstruction* caller = work_queue.front();
|
||||
work_queue.pop_front();
|
||||
switch (caller->opcode()) {
|
||||
case HloOpcode::kCall:
|
||||
mutated = true;
|
||||
TF_RETURN_IF_ERROR(ReplaceWithInlinedBody(caller, &work_queue));
|
||||
break;
|
||||
case HloOpcode::kWhile:
|
||||
TF_RETURN_IF_ERROR(scan_computation(caller->while_condition()));
|
||||
TF_RETURN_IF_ERROR(scan_computation(caller->while_body()));
|
||||
break;
|
||||
case HloOpcode::kSelectAndScatter:
|
||||
TF_RETURN_IF_ERROR(scan_computation(caller->select()));
|
||||
TF_RETURN_IF_ERROR(scan_computation(caller->scatter()));
|
||||
break;
|
||||
case HloOpcode::kMap:
|
||||
case HloOpcode::kReduceWindow:
|
||||
case HloOpcode::kReduce:
|
||||
TF_RETURN_IF_ERROR(scan_computation(caller->to_apply()));
|
||||
break;
|
||||
case HloOpcode::kFusion:
|
||||
// Fusion nodes don't represent true calls, but instead delimit a
|
||||
// boundary for the backend-specific fusion capabilities.
|
||||
break;
|
||||
default:
|
||||
return Unimplemented("Unknown higher-order HLO opcode: %s",
|
||||
caller->ToString().c_str());
|
||||
}
|
||||
}
|
||||
return mutated;
|
||||
}
|
||||
|
||||
Status CallInliner::ReplaceWithInlinedBody(
|
||||
HloInstruction* call, std::deque<HloInstruction*>* work_queue) {
|
||||
TF_RET_CHECK(call->opcode() == HloOpcode::kCall);
|
||||
TF_RET_CHECK(call->called_computations().size() == 1);
|
||||
HloComputation* called = call->called_computations()[0];
|
||||
VLOG(1) << "Replacing call " << call->ToString() << " with inlined body of "
|
||||
<< called->name();
|
||||
|
||||
SubcomputationInsertionVisitor visitor(call, work_queue);
|
||||
return called->Accept(&visitor);
|
||||
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
||||
// Because call graph nodes are visited in post-order (callees before callers)
|
||||
// we'll always inline kCalls into their callers in the appropriate order.
|
||||
bool did_mutate = false;
|
||||
TF_RETURN_IF_ERROR(
|
||||
call_graph->VisitNodes([&](const CallGraphNode& node) -> Status {
|
||||
for (const CallSite& callsite : node.caller_callsites()) {
|
||||
VLOG(1) << "Visiting callsite: " << callsite.ToString();
|
||||
if (callsite.instruction()->opcode() == HloOpcode::kCall) {
|
||||
did_mutate = true;
|
||||
const auto& callees = callsite.called_computations();
|
||||
TF_RET_CHECK(callees.size() == 1);
|
||||
HloComputation* callee = callees[0];
|
||||
// We visit the callee, cloning its body into its caller.
|
||||
SubcomputationInsertionVisitor visitor(callsite.instruction());
|
||||
TF_RETURN_IF_ERROR(callee->Accept(&visitor));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
return did_mutate;
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -31,16 +31,6 @@ class CallInliner : public HloPassInterface {
|
||||
tensorflow::StringPiece name() const override { return "CallInliner"; }
|
||||
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
||||
private:
|
||||
// Replaces the given call operation -- which must be an operation inside the
|
||||
// entry computation with opcode kCall -- with the called computation's body,
|
||||
// such that the called computation is inline in the entry computation.
|
||||
//
|
||||
// On successful inlining, the inlined computation may have itself contained
|
||||
// calls; if so, they are added to the work_queue.
|
||||
Status ReplaceWithInlinedBody(HloInstruction* call,
|
||||
std::deque<HloInstruction*>* work_queue);
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -44,6 +44,8 @@ namespace {
|
||||
using CallInlinerTest = HloTestBase;
|
||||
|
||||
TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
|
||||
// "inner" computation just has a control dependency from the "zero" value to
|
||||
// the "one" value.
|
||||
HloComputation::Builder inner(TestName() + ".inner");
|
||||
HloInstruction* zero = inner.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(24.0f)));
|
||||
@ -54,6 +56,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
|
||||
HloComputation* inner_computation =
|
||||
module->AddEmbeddedComputation(inner.Build());
|
||||
|
||||
// "outer" computation just calls the "inner" computation.
|
||||
HloComputation::Builder outer(TestName() + ".outer");
|
||||
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
|
||||
outer.AddInstruction(
|
||||
|
Loading…
Reference in New Issue
Block a user