Don't rewrite conditionals with empty branches into select.
Executing an empty branch is faster than materializing the select. PiperOrigin-RevId: 323705674 Change-Id: Ibab748535d8dd136764f5b3bc30ca2097cf75151
This commit is contained in:
parent
316acba0b1
commit
37c793e9b8
@ -41,6 +41,26 @@ limitations under the License.
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
|
||||
// A computation with array type that only contains parameters and tuples is
|
||||
// considered emtpy.
|
||||
bool ComputationIsEmptyWithArrayRoot(const HloComputation* computation) {
|
||||
bool empty_operations = absl::c_all_of(
|
||||
computation->MakeInstructionPostOrder(), [](const HloInstruction* inst) {
|
||||
return inst->opcode() == HloOpcode::kTuple ||
|
||||
inst->opcode() == HloOpcode::kGetTupleElement ||
|
||||
inst->opcode() == HloOpcode::kParameter;
|
||||
});
|
||||
bool contains_array = false;
|
||||
ShapeUtil::ForEachSubshape(computation->root_instruction()->shape(),
|
||||
[&](const Shape& shape, const ShapeIndex& index) {
|
||||
if (shape.IsArray()) {
|
||||
contains_array = true;
|
||||
}
|
||||
});
|
||||
return empty_operations && contains_array;
|
||||
}
|
||||
|
||||
// Tries to replace a conditional with a call operation of the corresponding
|
||||
// computation. If the given conditional has a constant branch_index, tries to
|
||||
// replace it with a call to its corresponding branch computation and then
|
||||
@ -124,7 +144,6 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
|
||||
<< conditional->ToShortString();
|
||||
return false;
|
||||
}
|
||||
|
||||
HloInstruction* true_call_op = create_call(0);
|
||||
HloInstruction* false_call_op = create_call(1);
|
||||
auto condition_broadcast = [&](const Shape& shape) {
|
||||
@ -140,6 +159,14 @@ StatusOr<bool> TryRemoveConditional(HloInstruction* conditional) {
|
||||
return computation->AddInstruction(HloInstruction::CreateGetTupleElement(
|
||||
hlo->shape().tuple_shapes(i), hlo, i));
|
||||
};
|
||||
|
||||
bool branch_empty =
|
||||
ComputationIsEmptyWithArrayRoot(conditional->branch_computation(0)) ||
|
||||
ComputationIsEmptyWithArrayRoot(conditional->branch_computation(1));
|
||||
// Empty branch is faster to execute than select.
|
||||
if (branch_empty) {
|
||||
return false;
|
||||
}
|
||||
std::function<HloInstruction*(HloInstruction*, HloInstruction*)> select =
|
||||
[&](HloInstruction* t, HloInstruction* f) {
|
||||
if (f->shape().IsToken()) {
|
||||
@ -559,6 +586,10 @@ StatusOr<bool> ConditionalSimplifier::Run(HloModule* module) {
|
||||
|
||||
absl::flat_hash_set<HloInstruction*> removed_conditionals;
|
||||
for (HloInstruction* conditional_op : conditional_ops) {
|
||||
if (conditional_op->has_sharding()) {
|
||||
// The code below doesn't handle sharding properly.
|
||||
continue;
|
||||
}
|
||||
changed |= MergeDuplicateTupleElements(conditional_op);
|
||||
changed |= RemoveUnusedTupleElements(conditional_op);
|
||||
changed |= ReplaceRootWithEmptyTupleIfNoUsers(conditional_op);
|
||||
@ -573,18 +604,27 @@ StatusOr<bool> ConditionalSimplifier::Run(HloModule* module) {
|
||||
// lets collect them first.
|
||||
absl::flat_hash_map<HloComputation*, absl::flat_hash_set<HloInstruction*>>
|
||||
calling_conditionals;
|
||||
// Keys of calling_conditionals to get a deterministic ordering.
|
||||
std::vector<HloComputation*> calling_computationals_vector;
|
||||
for (HloInstruction* conditional : conditional_ops) {
|
||||
if (removed_conditionals.contains(conditional)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int64 branch = 0; branch < conditional->branch_count(); ++branch) {
|
||||
calling_conditionals[conditional->branch_computation(branch)].insert(
|
||||
conditional);
|
||||
auto* branch_comp = conditional->branch_computation(branch);
|
||||
if (!calling_conditionals.contains(branch_comp)) {
|
||||
calling_computationals_vector.push_back(branch_comp);
|
||||
}
|
||||
calling_conditionals[branch_comp].insert(conditional);
|
||||
}
|
||||
}
|
||||
for (const auto& entry : calling_conditionals) {
|
||||
|
||||
for (auto* comp : calling_computationals_vector) {
|
||||
auto entry = calling_conditionals.find(comp);
|
||||
CHECK(entry != calling_conditionals.end());
|
||||
TF_ASSIGN_OR_RETURN(bool result, TryRemoveUnusedConditionalOperands(
|
||||
entry.first, entry.second));
|
||||
entry->first, entry->second));
|
||||
changed |= result;
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user