[XLA:HLO] Small refactoring and more comments in tuple_simplifier.

Explain that optimizing partially used tuples within a single computation falls out of the existing optimizations; the sketch after the diffstat below illustrates this. There is still the option to optimize partially used tuples across computations. I will look into that in a separate CL.

PiperOrigin-RevId: 315319714
Change-Id: Ifcc41929cb8213cab661ccefea00138e099d551e
Authored by Dimitris Vardoulakis on 2020-06-08 11:44:27 -07:00; committed by TensorFlower Gardener.
parent b0b763203e
commit 81ceabffc6
2 changed files with 51 additions and 40 deletions
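To make the first point of the commit message concrete, here is a sketch in the style of the pass's unit tests (it assumes the HloTestBase fixture from tensorflow/compiler/xla/tests; the test name and module text are illustrative, not taken from this CL). Only element 0 of the tuple is used, and the existing GTE-forwarding rule in Run() short-circuits that use on its own:

#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"

class TupleSimplifierExampleTest : public xla::HloTestBase {};

TEST_F(TupleSimplifierExampleTest, PartiallyUsedTuple) {
  // Only element 0 of `t` is used.
  constexpr char kHlo[] = R"(
HloModule m

ENTRY e {
  p0 = f32[8] parameter(0)
  p1 = f32[8] parameter(1)
  t = (f32[8], f32[8]) tuple(p0, p1)
  gte = f32[8] get-tuple-element(t), index=0
  ROOT neg = f32[8] negate(gte)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  xla::TupleSimplifier simplifier;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get()));
  // `gte` now reads `p0` directly; the dead Tuple is cleaned up afterwards
  // (e.g. by HloDCE in the surrounding pipeline).
  EXPECT_TRUE(changed);
}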

tensorflow/compiler/xla/service/tuple_simplifier.cc

@@ -33,6 +33,36 @@ namespace xla {
 TupleSimplifier::TupleSimplifier(bool exclude_entry_computation)
     : exclude_entry_computation_(exclude_entry_computation) {}
 
+StatusOr<bool> TupleSimplifier::RemoveWholeTuple(HloInstruction* tuple) {
+  bool changed = false;
+  HloInstruction* top_tuple = nullptr;
+  bool can_simplify = true;
+  for (int64 operand_number = 0; operand_number < tuple->operand_count();
+       ++operand_number) {
+    HloInstruction* operand = tuple->mutable_operand(operand_number);
+    if (operand->opcode() != HloOpcode::kGetTupleElement ||
+        operand->tuple_index() != operand_number) {
+      can_simplify = false;
+      break;
+    }
+    if (top_tuple == nullptr) {
+      top_tuple = operand->mutable_operand(0);
+      if (!ShapeUtil::Compatible(top_tuple->shape(), tuple->shape())) {
+        can_simplify = false;
+        break;
+      }
+    } else if (top_tuple != operand->operand(0)) {
+      can_simplify = false;
+      break;
+    }
+  }
+  if (can_simplify && top_tuple != nullptr) {
+    changed = true;
+    TF_RETURN_IF_ERROR(tuple->parent()->ReplaceInstruction(tuple, top_tuple));
+  }
+  return changed;
+}
+
 StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
   // Initially add all GTE and Tuple instructions to the worklist.
   bool changed = false;
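For reference, the structure that the factored-out RemoveWholeTuple collapses, as HLO text (an illustrative module under the same fixture assumptions as the sketch above): every operand of the rebuilt tuple is a GTE of the same source tuple at the matching index, and the shapes are compatible, so the rebuild is replaced by the source tuple itself.

TEST_F(TupleSimplifierExampleTest, WholeTupleRebuild) {
  constexpr char kHlo[] = R"(
HloModule m

ENTRY e {
  t = (f32[], f32[], f32[]) parameter(0)
  e0 = f32[] get-tuple-element(t), index=0
  e1 = f32[] get-tuple-element(t), index=1
  e2 = f32[] get-tuple-element(t), index=2
  ROOT rebuilt = (f32[], f32[], f32[]) tuple(e0, e1, e2)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          xla::TupleSimplifier().Run(module.get()));
  EXPECT_TRUE(changed);  // The computation root is now the parameter `t`.
}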
@@ -43,46 +73,7 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
     }
     for (auto* instruction : computation->MakeInstructionPostOrder()) {
       if (instruction->opcode() == HloOpcode::kTuple) {
-        // Collapse the following structure into just 'Tuple-shaped Op':
-        //
-        //   Tuple-shaped Op
-        //         |
-        //   +-----+-----+
-        //   |     |     |
-        //  GTE   GTE   GTE
-        //   |     |     |
-        //   +-----+-----+
-        //         |
-        //       Tuple
-        //
-        HloInstruction* top_tuple = nullptr;
-        bool can_simplify = true;
-        for (int64 operand_number = 0;
-             operand_number < instruction->operand_count(); ++operand_number) {
-          HloInstruction* operand =
-              instruction->mutable_operand(operand_number);
-          if (operand->opcode() != HloOpcode::kGetTupleElement ||
-              operand->tuple_index() != operand_number) {
-            can_simplify = false;
-            break;
-          }
-          if (top_tuple == nullptr) {
-            top_tuple = operand->mutable_operand(0);
-            if (!ShapeUtil::Compatible(top_tuple->shape(),
-                                       instruction->shape())) {
-              can_simplify = false;
-              break;
-            }
-          } else if (top_tuple != operand->operand(0)) {
-            can_simplify = false;
-            break;
-          }
-        }
-        if (can_simplify && top_tuple != nullptr) {
-          changed = true;
-          TF_RETURN_IF_ERROR(
-              computation->ReplaceInstruction(instruction, top_tuple));
-        }
+        TF_ASSIGN_OR_RETURN(changed, RemoveWholeTuple(instruction));
       } else {
         auto ancestor = instruction->LatestNonGteAncestorAndIndex();
         if (ancestor.first == instruction) {
@@ -102,6 +93,11 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
         //         GTE
         //          |
         //         GTE
+        //
+        // Note that this deletes the Tuple instruction altogether. In
+        // addition, if only a subset of the tuple's elements are used, this
+        // transform optimizes them one at a time, and after the last use is
+        // optimized, the Tuple will also be deleted.
         if (ShapeUtil::Compatible(ancestor.first->shape(),
                                   instruction->shape())) {
           changed = true;

tensorflow/compiler/xla/service/tuple_simplifier.h

@@ -18,6 +18,7 @@ limitations under the License.
 #include <utility>
 
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
@@ -41,6 +42,20 @@ class TupleSimplifier : public HloModulePass {
   // apart from the module's entry computation. This is used by Graphcore's
   // backend.
   bool exclude_entry_computation_;
+
+  // Collapse the following structure into just 'Tuple-shaped Op':
+  //
+  //   Tuple-shaped Op
+  //         |
+  //   +-----+-----+
+  //   |     |     |
+  //  GTE   GTE   GTE
+  //   |     |     |
+  //   +-----+-----+
+  //         |
+  //       Tuple
+  //
+  StatusOr<bool> RemoveWholeTuple(HloInstruction* tuple);
 };
 
 }  // namespace xla
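Finally, for context on the exclude_entry_computation_ knob documented above: a backend opts in when assembling its pass pipeline, roughly as below (the function and pipeline names are hypothetical; only TupleSimplifier's constructor argument comes from this header).

#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"

xla::Status SimplifyTuples(xla::HloModule* module) {
  xla::HloPassPipeline pipeline("tuple-simplification");
  // Leave tuples in the entry computation alone; simplify everywhere else.
  pipeline.AddPass<xla::TupleSimplifier>(/*exclude_entry_computation=*/true);
  return pipeline.Run(module).status();
}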