Conditional code motion -- don't visit other conditionals as boundary.
If a conditional is an input to another conditional, don't attempt to move the second conditional into the first one. This helps avoid quadratic when multiple conditionals are chained together. PiperOrigin-RevId: 331263856 Change-Id: Ib3ef639db44d31ebfd7dc69ede0a3b3ecf8d9817
This commit is contained in:
parent
6c1840fa5b
commit
f1f8573343
@ -911,14 +911,15 @@ class GroupConnectedBoundaries {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::vector<Boundary> BoundariesToMoveInOrOut(const Boundary& b) {
|
std::vector<Boundary> BoundariesToMoveInOrOut(HloInstruction* conditional,
|
||||||
|
const Boundary& b) {
|
||||||
// At the beginning of optimization, a conditional itself is added to a
|
// At the beginning of optimization, a conditional itself is added to a
|
||||||
// worklist. Here the conditional is expanded into two sets of boundaries:
|
// worklist. Here the conditional is expanded into two sets of boundaries:
|
||||||
// the first set contains the boundary that is inside branches and
|
// the first set contains the boundary that is inside branches and
|
||||||
// contains the root of all branches; the second set of boundaries
|
// contains the root of all branches; the second set of boundaries
|
||||||
// contains all the users of the conditional.
|
// contains all the users of the conditional.
|
||||||
HloInstruction* inst = b.operands()[0];
|
HloInstruction* inst = b.operands()[0];
|
||||||
if (inst->opcode() == HloOpcode::kConditional) {
|
if (inst == conditional) {
|
||||||
int branch_count = inst->branch_count();
|
int branch_count = inst->branch_count();
|
||||||
// Add conditional roots as a new boundary to visit.
|
// Add conditional roots as a new boundary to visit.
|
||||||
Boundary boundary_in(Boundary::Position::kInsideBranch);
|
Boundary boundary_in(Boundary::Position::kInsideBranch);
|
||||||
@ -949,7 +950,8 @@ ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion(
|
|||||||
HloInstruction* conditional, const Boundary& cur_boundary,
|
HloInstruction* conditional, const Boundary& cur_boundary,
|
||||||
std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries) {
|
std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries) {
|
||||||
GroupConnectedBoundaries connect(conditional, is_layout_sensitive_);
|
GroupConnectedBoundaries connect(conditional, is_layout_sensitive_);
|
||||||
auto move_in_or_out = connect.BoundariesToMoveInOrOut(cur_boundary);
|
auto move_in_or_out =
|
||||||
|
connect.BoundariesToMoveInOrOut(conditional, cur_boundary);
|
||||||
if (!move_in_or_out.empty()) {
|
if (!move_in_or_out.empty()) {
|
||||||
auto benefit = connect.BenefitForMovingBoundaries(move_in_or_out);
|
auto benefit = connect.BenefitForMovingBoundaries(move_in_or_out);
|
||||||
VLOG(2) << "benefit of moving in or out "
|
VLOG(2) << "benefit of moving in or out "
|
||||||
|
Loading…
x
Reference in New Issue
Block a user