[XLA] Switch CallInliner to use CallGraph::VisitNodes.

PiperOrigin-RevId: 168078645
This commit is contained in:
Chris Leary 2017-09-08 19:58:41 -07:00 committed by TensorFlower Gardener
parent aba3466f17
commit 405def792e
4 changed files with 33 additions and 84 deletions

View File

@ -288,6 +288,7 @@ cc_library(
srcs = ["call_inliner.cc"],
hdrs = ["call_inliner.h"],
deps = [
":call_graph",
":hlo_pass",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:lib",

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <deque>
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
@ -29,14 +30,18 @@ namespace {
// computation have been added to the work_queue.
class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
public:
SubcomputationInsertionVisitor(HloInstruction* call,
std::deque<HloInstruction*>* work_queue)
: call_(call), outer_(call->parent()), work_queue_(work_queue) {}
// call is the call operation -- it will be replaced with the body of the
// called computation.
explicit SubcomputationInsertionVisitor(HloInstruction* call)
: call_(call), outer_(call->parent()) {
CHECK_EQ(HloOpcode::kCall, call_->opcode());
}
// Resolves the operands to the HLO instruction in the inlined (caller) graph,
// and clones the HLO instruction into that graph with the new operands.
// If the instruction is a call, it is added to the work queue.
Status DefaultAction(HloInstruction* hlo) override {
TF_RET_CHECK(hlo->opcode() != HloOpcode::kCall);
std::vector<HloInstruction*> new_operands;
for (HloInstruction* operand : hlo->operands()) {
TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, Resolve(operand));
@ -56,12 +61,6 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
new_control_predecessor->AddControlDependencyTo(new_hlo_pointer));
}
if (new_hlo_pointer->opcode() == HloOpcode::kCall) {
VLOG(1) << "Adding new call HLO to work queue.";
// Call instructions we observe in the subcomputation are added to the
// inliner work queue.
work_queue_->push_back(new_hlo_pointer);
}
return Status::OK();
}
@ -121,71 +120,27 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
} // namespace
StatusOr<bool> CallInliner::Run(HloModule* module) {
std::deque<HloInstruction*> work_queue;
tensorflow::gtl::FlatSet<HloComputation*> seen;
auto scan_computation = [&work_queue,
&seen](HloComputation* computation) -> Status {
if (!seen.insert(computation).second) {
return Status::OK(); // Already seen.
}
return computation->Accept([&](HloInstruction* hlo) {
if (!hlo->called_computations().empty()) {
work_queue.push_back(hlo);
}
return Status::OK();
});
};
// Seed the work queue with call instructions from the main computation.
TF_RETURN_IF_ERROR(scan_computation(module->entry_computation()));
VLOG(1) << "Work queue seeded with " << work_queue.size() << " entries.";
bool mutated = false;
while (!work_queue.empty()) {
HloInstruction* caller = work_queue.front();
work_queue.pop_front();
switch (caller->opcode()) {
case HloOpcode::kCall:
mutated = true;
TF_RETURN_IF_ERROR(ReplaceWithInlinedBody(caller, &work_queue));
break;
case HloOpcode::kWhile:
TF_RETURN_IF_ERROR(scan_computation(caller->while_condition()));
TF_RETURN_IF_ERROR(scan_computation(caller->while_body()));
break;
case HloOpcode::kSelectAndScatter:
TF_RETURN_IF_ERROR(scan_computation(caller->select()));
TF_RETURN_IF_ERROR(scan_computation(caller->scatter()));
break;
case HloOpcode::kMap:
case HloOpcode::kReduceWindow:
case HloOpcode::kReduce:
TF_RETURN_IF_ERROR(scan_computation(caller->to_apply()));
break;
case HloOpcode::kFusion:
// Fusion nodes don't represent true calls, but instead delimit a
// boundary for the backend-specific fusion capabilities.
break;
default:
return Unimplemented("Unknown higher-order HLO opcode: %s",
caller->ToString().c_str());
}
}
return mutated;
}
Status CallInliner::ReplaceWithInlinedBody(
HloInstruction* call, std::deque<HloInstruction*>* work_queue) {
TF_RET_CHECK(call->opcode() == HloOpcode::kCall);
TF_RET_CHECK(call->called_computations().size() == 1);
HloComputation* called = call->called_computations()[0];
VLOG(1) << "Replacing call " << call->ToString() << " with inlined body of "
<< called->name();
SubcomputationInsertionVisitor visitor(call, work_queue);
return called->Accept(&visitor);
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
// Because call graph nodes are visited in post-order (callees before callers)
// we'll always inline kCalls into their callers in the appropriate order.
bool did_mutate = false;
TF_RETURN_IF_ERROR(
call_graph->VisitNodes([&](const CallGraphNode& node) -> Status {
for (const CallSite& callsite : node.caller_callsites()) {
VLOG(1) << "Visiting callsite: " << callsite.ToString();
if (callsite.instruction()->opcode() == HloOpcode::kCall) {
did_mutate = true;
const auto& callees = callsite.called_computations();
TF_RET_CHECK(callees.size() == 1);
HloComputation* callee = callees[0];
// We visit the callee, cloning its body into its caller.
SubcomputationInsertionVisitor visitor(callsite.instruction());
TF_RETURN_IF_ERROR(callee->Accept(&visitor));
}
}
return Status::OK();
}));
return did_mutate;
}
} // namespace xla

View File

@ -31,16 +31,6 @@ class CallInliner : public HloPassInterface {
tensorflow::StringPiece name() const override { return "CallInliner"; }
StatusOr<bool> Run(HloModule* module) override;
private:
// Replaces the given call operation -- which must be an operation inside the
// entry computation with opcode kCall -- with the called computation's body,
// such that the called computation is inline in the entry computation.
//
// On successful inlining, the inlined computation may have itself contained
// calls; if so, they are added to the work_queue.
Status ReplaceWithInlinedBody(HloInstruction* call,
std::deque<HloInstruction*>* work_queue);
};
} // namespace xla

View File

@ -44,6 +44,8 @@ namespace {
using CallInlinerTest = HloTestBase;
TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
// "inner" computation just has a control dependency from the "zero" value to
// the "one" value.
HloComputation::Builder inner(TestName() + ".inner");
HloInstruction* zero = inner.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(24.0f)));
@ -54,6 +56,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
HloComputation* inner_computation =
module->AddEmbeddedComputation(inner.Build());
// "outer" computation just calls the "inner" computation.
HloComputation::Builder outer(TestName() + ".outer");
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
outer.AddInstruction(