Fix bug in computing post order when ROOT instruction has trace user.
A 'ROOT' instruction in a computation may still have users that are Trace instructions. In such a case, we would never execute 'ComputeInstructionPoseOrder(...)' if we only check that the instruction has no user. This change fixes this to compute post order if the instruction has only trace users. PiperOrigin-RevId: 289270887 Change-Id: I2e3d6dadb85e580adfc40932e9f9ea4650e0c6c6
This commit is contained in:
parent
e3cf342cc0
commit
6aaa640cec
@ -466,6 +466,12 @@ HloComputation::ComputeChannelDependencies() const {
|
|||||||
return channel_dependency_group;
|
return channel_dependency_group;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static inline bool HasOnlyTraceUsers(const HloInstruction* instruction) {
|
||||||
|
return absl::c_all_of(instruction->users(), [](HloInstruction* user) {
|
||||||
|
return user->opcode() == HloOpcode::kTrace;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
|
std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
|
||||||
auto channel_dependency_group = ComputeChannelDependencies();
|
auto channel_dependency_group = ComputeChannelDependencies();
|
||||||
std::vector<HloInstruction*> post_order;
|
std::vector<HloInstruction*> post_order;
|
||||||
@ -479,7 +485,7 @@ std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
|
|||||||
// instructions to the post order at the end (necessarily they have no
|
// instructions to the post order at the end (necessarily they have no
|
||||||
// users).
|
// users).
|
||||||
trace_instructions.push_back(instruction.get());
|
trace_instructions.push_back(instruction.get());
|
||||||
} else if (instruction->users().empty()) {
|
} else if (HasOnlyTraceUsers(instruction.get())) {
|
||||||
ComputeInstructionPostOrder(channel_dependency_group, &post_order,
|
ComputeInstructionPostOrder(channel_dependency_group, &post_order,
|
||||||
instruction.get(), &visited);
|
instruction.get(), &visited);
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user