[XLA] Handle higher-order HLOs (e.g. While) in CallInliner and test.
PiperOrigin-RevId: 168029345
This commit is contained in:
parent
8988ae365f
commit
f83f6b9ef1
@ -20,30 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
StatusOr<bool> CallInliner::Run(HloModule* module) {
|
||||
std::deque<HloInstruction*> work_queue;
|
||||
|
||||
// Seed the work queue with call instructions from the main computation.
|
||||
TF_RETURN_IF_ERROR(
|
||||
module->entry_computation()->Accept([&](HloInstruction* hlo) {
|
||||
if (hlo->opcode() == HloOpcode::kCall) {
|
||||
work_queue.push_back(hlo);
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
|
||||
VLOG(1) << "Work queue seeded with " << work_queue.size() << " entries.";
|
||||
|
||||
bool mutated = false;
|
||||
while (!work_queue.empty()) {
|
||||
mutated = true;
|
||||
HloInstruction* call = work_queue.front();
|
||||
work_queue.pop_front();
|
||||
TF_RETURN_IF_ERROR(ReplaceWithInlinedBody(call, &work_queue));
|
||||
}
|
||||
return mutated;
|
||||
}
|
||||
namespace {
|
||||
|
||||
// Traverses the callee computation, inlining cloned nodes into the caller
|
||||
// computation and connecting them to producers/consumers appropriately.
|
||||
@ -141,6 +118,64 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
|
||||
std::deque<HloInstruction*>* work_queue_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
StatusOr<bool> CallInliner::Run(HloModule* module) {
|
||||
std::deque<HloInstruction*> work_queue;
|
||||
tensorflow::gtl::FlatSet<HloComputation*> seen;
|
||||
|
||||
auto scan_computation = [&work_queue,
|
||||
&seen](HloComputation* computation) -> Status {
|
||||
if (!seen.insert(computation).second) {
|
||||
return Status::OK(); // Already seen.
|
||||
}
|
||||
return computation->Accept([&](HloInstruction* hlo) {
|
||||
if (!hlo->called_computations().empty()) {
|
||||
work_queue.push_back(hlo);
|
||||
}
|
||||
return Status::OK();
|
||||
});
|
||||
};
|
||||
|
||||
// Seed the work queue with call instructions from the main computation.
|
||||
TF_RETURN_IF_ERROR(scan_computation(module->entry_computation()));
|
||||
|
||||
VLOG(1) << "Work queue seeded with " << work_queue.size() << " entries.";
|
||||
|
||||
bool mutated = false;
|
||||
while (!work_queue.empty()) {
|
||||
HloInstruction* caller = work_queue.front();
|
||||
work_queue.pop_front();
|
||||
switch (caller->opcode()) {
|
||||
case HloOpcode::kCall:
|
||||
mutated = true;
|
||||
TF_RETURN_IF_ERROR(ReplaceWithInlinedBody(caller, &work_queue));
|
||||
break;
|
||||
case HloOpcode::kWhile:
|
||||
TF_RETURN_IF_ERROR(scan_computation(caller->while_condition()));
|
||||
TF_RETURN_IF_ERROR(scan_computation(caller->while_body()));
|
||||
break;
|
||||
case HloOpcode::kSelectAndScatter:
|
||||
TF_RETURN_IF_ERROR(scan_computation(caller->select()));
|
||||
TF_RETURN_IF_ERROR(scan_computation(caller->scatter()));
|
||||
break;
|
||||
case HloOpcode::kMap:
|
||||
case HloOpcode::kReduceWindow:
|
||||
case HloOpcode::kReduce:
|
||||
TF_RETURN_IF_ERROR(scan_computation(caller->to_apply()));
|
||||
break;
|
||||
case HloOpcode::kFusion:
|
||||
// Fusion nodes don't represent true calls, but instead delimit a
|
||||
// boundary for the backend-specific fusion capabilities.
|
||||
break;
|
||||
default:
|
||||
return Unimplemented("Unknown higher-order HLO opcode: %s",
|
||||
caller->ToString().c_str());
|
||||
}
|
||||
}
|
||||
return mutated;
|
||||
}
|
||||
|
||||
Status CallInliner::ReplaceWithInlinedBody(
|
||||
HloInstruction* call, std::deque<HloInstruction*>* work_queue) {
|
||||
TF_RET_CHECK(call->opcode() == HloOpcode::kCall);
|
||||
|
@ -73,5 +73,44 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
|
||||
EXPECT_EQ(prior->literal().GetFirstElement<float>(), 24);
|
||||
}
|
||||
|
||||
// Tests for referential transparency (a function that calls a function that
|
||||
// returns false should be identical to just returning false).
|
||||
TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
|
||||
const Shape pred = ShapeUtil::MakeShape(PRED, {});
|
||||
auto module = CreateNewModule();
|
||||
|
||||
// Create a lambda that calls a function that returns the false predicate.
|
||||
// Note we also use this lambda twice by reference, just to make the test a
|
||||
// little trickier.
|
||||
HloComputation::Builder just_false(TestName() + ".false");
|
||||
just_false.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||
HloComputation* false_computation =
|
||||
module->AddEmbeddedComputation(just_false.Build());
|
||||
|
||||
HloComputation::Builder call_false_builder(TestName() + ".call_false");
|
||||
call_false_builder.AddInstruction(
|
||||
HloInstruction::CreateCall(pred, {}, false_computation));
|
||||
HloComputation* call_false =
|
||||
module->AddEmbeddedComputation(call_false_builder.Build());
|
||||
|
||||
HloComputation::Builder outer(TestName() + ".outer");
|
||||
HloInstruction* init_value = outer.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||
outer.AddInstruction(
|
||||
HloInstruction::CreateWhile(pred, call_false, call_false, init_value));
|
||||
|
||||
auto computation = module->AddEntryComputation(outer.Build());
|
||||
|
||||
CallInliner call_inliner;
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
|
||||
ASSERT_TRUE(mutated);
|
||||
EXPECT_THAT(
|
||||
computation->root_instruction()->while_condition()->root_instruction(),
|
||||
op::Constant());
|
||||
EXPECT_THAT(computation->root_instruction()->while_body()->root_instruction(),
|
||||
op::Constant());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user