[XLA:HLO] Small refactoring and more comments in tuple_simplifier.
Explain that optimizing partially used tuples within a single computation falls out of the existing optimizations. Optimizing partially used tuples across computations is still an option; that will be looked into in a separate CL.

PiperOrigin-RevId: 315319714
Change-Id: Ifcc41929cb8213cab661ccefea00138e099d551e
This commit is contained in:
parent
b0b763203e
commit
81ceabffc6
@ -33,6 +33,36 @@ namespace xla {
|
||||
TupleSimplifier::TupleSimplifier(bool exclude_entry_computation)
|
||||
: exclude_entry_computation_(exclude_entry_computation) {}
|
||||
|
||||
// Tries to replace `tuple` with the instruction its elements were extracted
// from. The replacement happens only when every operand i of `tuple` is a
// GetTupleElement with tuple_index() == i, all of those GTEs read from the
// same source instruction, and that source's shape is compatible with the
// shape of `tuple`.
//
// Returns true if `tuple` was replaced, false if the pattern did not match,
// or the error reported by ReplaceInstruction.
StatusOr<bool> TupleSimplifier::RemoveWholeTuple(HloInstruction* tuple) {
  const int64 element_count = tuple->operand_count();
  // With no operands there is no GTE to recover a source instruction from.
  if (element_count == 0) {
    return false;
  }

  HloInstruction* source = nullptr;
  for (int64 index = 0; index < element_count; ++index) {
    HloInstruction* element = tuple->mutable_operand(index);

    // Each operand must be a GTE that picks the element at its own position.
    const bool is_positional_gte =
        element->opcode() == HloOpcode::kGetTupleElement &&
        element->tuple_index() == index;
    if (!is_positional_gte) {
      return false;
    }

    if (source == nullptr) {
      // The first GTE fixes the candidate source; its shape is checked once
      // against the shape of the tuple being rebuilt.
      source = element->mutable_operand(0);
      if (!ShapeUtil::Compatible(source->shape(), tuple->shape())) {
        return false;
      }
      continue;
    }

    // Every later GTE has to read from that same source instruction.
    if (element->operand(0) != source) {
      return false;
    }
  }

  TF_RETURN_IF_ERROR(tuple->parent()->ReplaceInstruction(tuple, source));
  return true;
}
|
||||
|
||||
StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
|
||||
// Initially add all GTE and Tuple instructions to the worklist.
|
||||
bool changed = false;
|
||||
@ -43,46 +73,7 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
|
||||
}
|
||||
for (auto* instruction : computation->MakeInstructionPostOrder()) {
|
||||
if (instruction->opcode() == HloOpcode::kTuple) {
|
||||
// Collapse the following structure into just 'Tuple-shaped Op':
|
||||
//
|
||||
// Tuple-shaped Op
|
||||
// |
|
||||
// +-----+-----+
|
||||
// | | |
|
||||
// GTE GTE GTE
|
||||
// | | |
|
||||
// +-----+-----+
|
||||
// |
|
||||
// Tuple
|
||||
//
|
||||
HloInstruction* top_tuple = nullptr;
|
||||
bool can_simplify = true;
|
||||
for (int64 operand_number = 0;
|
||||
operand_number < instruction->operand_count(); ++operand_number) {
|
||||
HloInstruction* operand =
|
||||
instruction->mutable_operand(operand_number);
|
||||
if (operand->opcode() != HloOpcode::kGetTupleElement ||
|
||||
operand->tuple_index() != operand_number) {
|
||||
can_simplify = false;
|
||||
break;
|
||||
}
|
||||
if (top_tuple == nullptr) {
|
||||
top_tuple = operand->mutable_operand(0);
|
||||
if (!ShapeUtil::Compatible(top_tuple->shape(),
|
||||
instruction->shape())) {
|
||||
can_simplify = false;
|
||||
break;
|
||||
}
|
||||
} else if (top_tuple != operand->operand(0)) {
|
||||
can_simplify = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (can_simplify && top_tuple != nullptr) {
|
||||
changed = true;
|
||||
TF_RETURN_IF_ERROR(
|
||||
computation->ReplaceInstruction(instruction, top_tuple));
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(changed, RemoveWholeTuple(instruction));
|
||||
} else {
|
||||
auto ancestor = instruction->LatestNonGteAncestorAndIndex();
|
||||
if (ancestor.first == instruction) {
|
||||
@ -102,6 +93,11 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
|
||||
// GTE
|
||||
// |
|
||||
// GTE
|
||||
//
|
||||
// Note that this deletes the Tuple instruction altogether. In addition,
|
||||
// if only a subset of tuple's elements are used, this transform
|
||||
// optimizes them one at a time, and after the last use is optimized,
|
||||
// the Tuple will also be deleted.
|
||||
if (ShapeUtil::Compatible(ancestor.first->shape(),
|
||||
instruction->shape())) {
|
||||
changed = true;
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
@ -41,6 +42,20 @@ class TupleSimplifier : public HloModulePass {
|
||||
// apart from the module's entry computation. This is used by Graphcore's
|
||||
// backend.
|
||||
bool exclude_entry_computation_;
|
||||
|
||||
// Collapse the following structure into just 'Tuple-shaped Op':
//
//   Tuple-shaped Op
//         |
//   +-----+-----+
//   |     |     |
//  GTE   GTE   GTE
//   |     |     |
//   +-----+-----+
//         |
//       Tuple
//
|
||||
StatusOr<bool> RemoveWholeTuple(HloInstruction* tuple);
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user