Condintional Simplifier: Return early before creating any instruction.
The newly created instruction can activate DCE and then trigger HloFixPass, which creates long compilation time. PiperOrigin-RevId: 343602826 Change-Id: I46eec0bccad24205f8fe8ee59fd73f6d2875c9e4
This commit is contained in:
parent
2f4d749c04
commit
d9255973eb
@ -144,6 +144,15 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
|
|||||||
<< conditional->ToShortString();
|
<< conditional->ToShortString();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool branch_empty =
|
||||||
|
ComputationIsEmptyWithArrayRoot(conditional->branch_computation(0)) ||
|
||||||
|
ComputationIsEmptyWithArrayRoot(conditional->branch_computation(1));
|
||||||
|
// Empty branch is faster to execute than select.
|
||||||
|
if (branch_empty) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
HloInstruction* true_call_op = create_call(0);
|
HloInstruction* true_call_op = create_call(0);
|
||||||
HloInstruction* false_call_op = create_call(1);
|
HloInstruction* false_call_op = create_call(1);
|
||||||
auto condition_broadcast = [&](const Shape& shape) {
|
auto condition_broadcast = [&](const Shape& shape) {
|
||||||
@ -160,13 +169,6 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
|
|||||||
hlo->shape().tuple_shapes(i), hlo, i));
|
hlo->shape().tuple_shapes(i), hlo, i));
|
||||||
};
|
};
|
||||||
|
|
||||||
bool branch_empty =
|
|
||||||
ComputationIsEmptyWithArrayRoot(conditional->branch_computation(0)) ||
|
|
||||||
ComputationIsEmptyWithArrayRoot(conditional->branch_computation(1));
|
|
||||||
// Empty branch is faster to execute than select.
|
|
||||||
if (branch_empty) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
std::function<HloInstruction*(HloInstruction*, HloInstruction*)> select =
|
std::function<HloInstruction*(HloInstruction*, HloInstruction*)> select =
|
||||||
[&](HloInstruction* t, HloInstruction* f) {
|
[&](HloInstruction* t, HloInstruction* f) {
|
||||||
if (f->shape().IsToken()) {
|
if (f->shape().IsToken()) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user