Fix bug in computing post order when ROOT instruction has trace user.

A 'ROOT' instruction in a computation may still have users that are Trace instructions. In such a case, we would never execute 'ComputeInstructionPoseOrder(...)' if we only check that the instruction has no user. This change fixes this to compute post order if the instruction has only trace users.

PiperOrigin-RevId: 289270887
Change-Id: I2e3d6dadb85e580adfc40932e9f9ea4650e0c6c6
This commit is contained in:
Prakalp Srivastava 2020-01-11 14:28:17 -08:00 committed by TensorFlower Gardener
parent e3cf342cc0
commit 6aaa640cec

View File

@ -466,6 +466,12 @@ HloComputation::ComputeChannelDependencies() const {
return channel_dependency_group;
}
static inline bool HasOnlyTraceUsers(const HloInstruction* instruction) {
return absl::c_all_of(instruction->users(), [](HloInstruction* user) {
return user->opcode() == HloOpcode::kTrace;
});
}
std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
auto channel_dependency_group = ComputeChannelDependencies();
std::vector<HloInstruction*> post_order;
@ -479,7 +485,7 @@ std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
// instructions to the post order at the end (necessarily they have no
// users).
trace_instructions.push_back(instruction.get());
} else if (instruction->users().empty()) {
} else if (HasOnlyTraceUsers(instruction.get())) {
ComputeInstructionPostOrder(channel_dependency_group, &post_order,
instruction.get(), &visited);
}