Conditional code motion -- don't visit other conditionals as boundary.

If a conditional is an input to another conditional, don't attempt to move the second conditional into the first one. This helps avoid quadratic when multiple conditionals are chained together.

PiperOrigin-RevId: 331263856
Change-Id: Ib3ef639db44d31ebfd7dc69ede0a3b3ecf8d9817
This commit is contained in:
Yunxing Dai 2020-09-11 18:47:21 -07:00 committed by TensorFlower Gardener
parent 6c1840fa5b
commit f1f8573343

View File

@ -911,14 +911,15 @@ class GroupConnectedBoundaries {
}
}
}
std::vector<Boundary> BoundariesToMoveInOrOut(const Boundary& b) {
std::vector<Boundary> BoundariesToMoveInOrOut(HloInstruction* conditional,
const Boundary& b) {
// At the beginning of optimization, a conditional itself is added to a
// worklist. Here the conditional is expanded into two sets of boundaries:
// the first set contains the boundary that is inside branches and
// contains the root of all branches; the second set of boundaries
// contains all the users of the conditional.
HloInstruction* inst = b.operands()[0];
if (inst->opcode() == HloOpcode::kConditional) {
if (inst == conditional) {
int branch_count = inst->branch_count();
// Add conditional roots as a new boundary to visit.
Boundary boundary_in(Boundary::Position::kInsideBranch);
@ -949,7 +950,8 @@ ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion(
HloInstruction* conditional, const Boundary& cur_boundary,
std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries) {
GroupConnectedBoundaries connect(conditional, is_layout_sensitive_);
auto move_in_or_out = connect.BoundariesToMoveInOrOut(cur_boundary);
auto move_in_or_out =
connect.BoundariesToMoveInOrOut(conditional, cur_boundary);
if (!move_in_or_out.empty()) {
auto benefit = connect.BenefitForMovingBoundaries(move_in_or_out);
VLOG(2) << "benefit of moving in or out "