Don't rewrite conditionals with empty branches into select.

Executing an empty branch is faster than materializing the select.

PiperOrigin-RevId: 323705674
Change-Id: Ibab748535d8dd136764f5b3bc30ca2097cf75151
Yunxing Dai 2020-07-28 19:50:57 -07:00 committed by TensorFlower Gardener
parent 316acba0b1
commit 37c793e9b8
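
The idea behind the change, sketched outside of XLA: if a branch of a conditional does no real work (it only forwards parameters through tuples and get-tuple-elements), executing the conditional costs next to nothing, while lowering it into a select forces both branch results to be materialized and then blended. Below is a minimal standalone sketch of that guard; Opcode, Instruction, Computation, IsEmptyWithArrayRoot, and ShouldRewriteToSelect are simplified stand-ins invented for illustration, not the real XLA types or the pass's entry points.

#include <algorithm>
#include <vector>

// Hypothetical, stripped-down stand-ins for the XLA types in the diff below.
enum class Opcode { kParameter, kTuple, kGetTupleElement, kAdd, kSelect };

struct Instruction {
  Opcode opcode;
};

struct Computation {
  std::vector<Instruction> instructions;  // post-order, root last
  bool root_contains_array;  // some subshape of the root is an array
};

// Mirrors the spirit of ComputationIsEmptyWithArrayRoot: the branch only
// shuffles data around (parameters / tuples / get-tuple-elements) and its
// root carries array data, so running it is essentially free.
bool IsEmptyWithArrayRoot(const Computation& branch) {
  bool only_plumbing =
      std::all_of(branch.instructions.begin(), branch.instructions.end(),
                  [](const Instruction& inst) {
                    return inst.opcode == Opcode::kParameter ||
                           inst.opcode == Opcode::kTuple ||
                           inst.opcode == Opcode::kGetTupleElement;
                  });
  return only_plumbing && branch.root_contains_array;
}

// Mirrors the new guard in TryRemoveConditional: if either branch is empty,
// keep the conditional instead of rewriting it into a select, because the
// select form would compute both branch results and then blend them.
bool ShouldRewriteToSelect(const Computation& true_branch,
                           const Computation& false_branch) {
  return !IsEmptyWithArrayRoot(true_branch) &&
         !IsEmptyWithArrayRoot(false_branch);
}

int main() {
  Computation forwarding{{{Opcode::kParameter}, {Opcode::kTuple}}, true};
  Computation computing{{{Opcode::kParameter}, {Opcode::kAdd}}, true};
  // One branch is empty, so the rewrite is skipped (returns 0).
  return ShouldRewriteToSelect(forwarding, computing) ? 1 : 0;
}

In the diff below, the same check is spelled with absl::c_all_of over HloComputation::MakeInstructionPostOrder() plus ShapeUtil::ForEachSubshape, and the early return sits in TryRemoveConditional before the select is built.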

@@ -41,6 +41,26 @@ limitations under the License.
 namespace xla {
 namespace {
+// A computation with an array-typed root that contains only parameters,
+// tuples, and get-tuple-elements is considered empty.
+bool ComputationIsEmptyWithArrayRoot(const HloComputation* computation) {
+  bool empty_operations = absl::c_all_of(
+      computation->MakeInstructionPostOrder(), [](const HloInstruction* inst) {
+        return inst->opcode() == HloOpcode::kTuple ||
+               inst->opcode() == HloOpcode::kGetTupleElement ||
+               inst->opcode() == HloOpcode::kParameter;
+      });
+  bool contains_array = false;
+  ShapeUtil::ForEachSubshape(computation->root_instruction()->shape(),
+                             [&](const Shape& shape, const ShapeIndex& index) {
+                               if (shape.IsArray()) {
+                                 contains_array = true;
+                               }
+                             });
+  return empty_operations && contains_array;
+}
 // Tries to replace a conditional with a call operation of the corresponding
 // computation. If the given conditional has a constant branch_index, tries to
 // replace it with a call to its corresponding branch computation and then
@@ -124,7 +144,6 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
             << conditional->ToShortString();
     return false;
   }
   HloInstruction* true_call_op = create_call(0);
   HloInstruction* false_call_op = create_call(1);
   auto condition_broadcast = [&](const Shape& shape) {
@@ -140,6 +159,14 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
     return computation->AddInstruction(HloInstruction::CreateGetTupleElement(
         hlo->shape().tuple_shapes(i), hlo, i));
   };
+  bool branch_empty =
+      ComputationIsEmptyWithArrayRoot(conditional->branch_computation(0)) ||
+      ComputationIsEmptyWithArrayRoot(conditional->branch_computation(1));
+  // Executing an empty branch is faster than materializing the select.
+  if (branch_empty) {
+    return false;
+  }
   std::function<HloInstruction*(HloInstruction*, HloInstruction*)> select =
       [&](HloInstruction* t, HloInstruction* f) {
         if (f->shape().IsToken()) {
@@ -559,6 +586,10 @@ StatusOr<bool> ConditionalSimplifier::Run(HloModule* module) {
   absl::flat_hash_set<HloInstruction*> removed_conditionals;
   for (HloInstruction* conditional_op : conditional_ops) {
+    if (conditional_op->has_sharding()) {
+      // The code below doesn't handle sharding properly.
+      continue;
+    }
     changed |= MergeDuplicateTupleElements(conditional_op);
     changed |= RemoveUnusedTupleElements(conditional_op);
     changed |= ReplaceRootWithEmptyTupleIfNoUsers(conditional_op);
@@ -573,18 +604,27 @@ StatusOr<bool> ConditionalSimplifier::Run(HloModule* module) {
   // let's collect them first.
   absl::flat_hash_map<HloComputation*, absl::flat_hash_set<HloInstruction*>>
       calling_conditionals;
+  // Keys of calling_conditionals to get a deterministic ordering.
+  std::vector<HloComputation*> calling_computationals_vector;
   for (HloInstruction* conditional : conditional_ops) {
     if (removed_conditionals.contains(conditional)) {
       continue;
     }
     for (int64 branch = 0; branch < conditional->branch_count(); ++branch) {
-      calling_conditionals[conditional->branch_computation(branch)].insert(
-          conditional);
+      auto* branch_comp = conditional->branch_computation(branch);
+      if (!calling_conditionals.contains(branch_comp)) {
+        calling_computationals_vector.push_back(branch_comp);
+      }
+      calling_conditionals[branch_comp].insert(conditional);
     }
   }
-  for (const auto& entry : calling_conditionals) {
+  for (auto* comp : calling_computationals_vector) {
+    auto entry = calling_conditionals.find(comp);
+    CHECK(entry != calling_conditionals.end());
     TF_ASSIGN_OR_RETURN(bool result, TryRemoveUnusedConditionalOperands(
-                                         entry.first, entry.second));
+                                         entry->first, entry->second));
     changed |= result;
   }
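
The last hunk also makes the operand-pruning loop order-stable: absl::flat_hash_map iterates in an unspecified order that can differ from run to run, so the pass now records each branch computation once in calling_computationals_vector and walks that vector, looking each key back up in the map. The sketch below shows the same keys-in-insertion-order pattern in isolation, using plain std containers and made-up names (calling, key_order, record) rather than the real XLA types.

#include <cassert>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

int main() {
  // Computation name -> ids of the conditionals that call it (hash map,
  // so its own iteration order is unordered and run-dependent).
  std::unordered_map<std::string, std::unordered_set<int>> calling;
  // Keys of `calling`, recorded in first-insertion order.
  std::vector<std::string> key_order;

  auto record = [&](const std::string& comp, int conditional_id) {
    if (calling.find(comp) == calling.end()) {
      key_order.push_back(comp);  // remember each new key exactly once
    }
    calling[comp].insert(conditional_id);
  };

  record("true_branch_of_cond0", 0);
  record("false_branch_of_cond0", 0);
  record("true_branch_of_cond0", 1);  // a shared branch computation

  // Iterating key_order gives the same order on every run, unlike iterating
  // the hash map directly.
  for (const std::string& comp : key_order) {
    auto it = calling.find(comp);
    assert(it != calling.end());
    std::cout << comp << " is called by " << it->second.size()
              << " conditional(s)\n";
  }
  return 0;
}

The assert plays the role of the CHECK in the diff: every key in the vector was inserted into the map at the same time, so the lookup cannot fail.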