[XLA:HLO] Small refactoring and more comments in tuple_simplifier.
Explain that optimizing partially used tuples within a single computation falls out of the existing optimizations. There is still the option to optimize partially used tuples across computations. Will look into that in a separate CL. PiperOrigin-RevId: 315319714 Change-Id: Ifcc41929cb8213cab661ccefea00138e099d551e
This commit is contained in:
parent
b0b763203e
commit
81ceabffc6
@ -33,6 +33,36 @@ namespace xla {
|
|||||||
TupleSimplifier::TupleSimplifier(bool exclude_entry_computation)
|
TupleSimplifier::TupleSimplifier(bool exclude_entry_computation)
|
||||||
: exclude_entry_computation_(exclude_entry_computation) {}
|
: exclude_entry_computation_(exclude_entry_computation) {}
|
||||||
|
|
||||||
|
StatusOr<bool> TupleSimplifier::RemoveWholeTuple(HloInstruction* tuple) {
|
||||||
|
bool changed = false;
|
||||||
|
HloInstruction* top_tuple = nullptr;
|
||||||
|
bool can_simplify = true;
|
||||||
|
for (int64 operand_number = 0; operand_number < tuple->operand_count();
|
||||||
|
++operand_number) {
|
||||||
|
HloInstruction* operand = tuple->mutable_operand(operand_number);
|
||||||
|
if (operand->opcode() != HloOpcode::kGetTupleElement ||
|
||||||
|
operand->tuple_index() != operand_number) {
|
||||||
|
can_simplify = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (top_tuple == nullptr) {
|
||||||
|
top_tuple = operand->mutable_operand(0);
|
||||||
|
if (!ShapeUtil::Compatible(top_tuple->shape(), tuple->shape())) {
|
||||||
|
can_simplify = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else if (top_tuple != operand->operand(0)) {
|
||||||
|
can_simplify = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (can_simplify && top_tuple != nullptr) {
|
||||||
|
changed = true;
|
||||||
|
TF_RETURN_IF_ERROR(tuple->parent()->ReplaceInstruction(tuple, top_tuple));
|
||||||
|
}
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
|
StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
|
||||||
// Initially add all GTE and Tuple instructions to the worklist.
|
// Initially add all GTE and Tuple instructions to the worklist.
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
@ -43,46 +73,7 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
|
|||||||
}
|
}
|
||||||
for (auto* instruction : computation->MakeInstructionPostOrder()) {
|
for (auto* instruction : computation->MakeInstructionPostOrder()) {
|
||||||
if (instruction->opcode() == HloOpcode::kTuple) {
|
if (instruction->opcode() == HloOpcode::kTuple) {
|
||||||
// Collapse the following structure into just 'Tuple-shaped Op':
|
TF_ASSIGN_OR_RETURN(changed, RemoveWholeTuple(instruction));
|
||||||
//
|
|
||||||
// Tuple-shaped Op
|
|
||||||
// |
|
|
||||||
// +-----+-----+
|
|
||||||
// | | |
|
|
||||||
// GTE GTE GTE
|
|
||||||
// | | |
|
|
||||||
// +-----+-----+
|
|
||||||
// |
|
|
||||||
// Tuple
|
|
||||||
//
|
|
||||||
HloInstruction* top_tuple = nullptr;
|
|
||||||
bool can_simplify = true;
|
|
||||||
for (int64 operand_number = 0;
|
|
||||||
operand_number < instruction->operand_count(); ++operand_number) {
|
|
||||||
HloInstruction* operand =
|
|
||||||
instruction->mutable_operand(operand_number);
|
|
||||||
if (operand->opcode() != HloOpcode::kGetTupleElement ||
|
|
||||||
operand->tuple_index() != operand_number) {
|
|
||||||
can_simplify = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (top_tuple == nullptr) {
|
|
||||||
top_tuple = operand->mutable_operand(0);
|
|
||||||
if (!ShapeUtil::Compatible(top_tuple->shape(),
|
|
||||||
instruction->shape())) {
|
|
||||||
can_simplify = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
} else if (top_tuple != operand->operand(0)) {
|
|
||||||
can_simplify = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (can_simplify && top_tuple != nullptr) {
|
|
||||||
changed = true;
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
computation->ReplaceInstruction(instruction, top_tuple));
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
auto ancestor = instruction->LatestNonGteAncestorAndIndex();
|
auto ancestor = instruction->LatestNonGteAncestorAndIndex();
|
||||||
if (ancestor.first == instruction) {
|
if (ancestor.first == instruction) {
|
||||||
@ -102,6 +93,11 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
|
|||||||
// GTE
|
// GTE
|
||||||
// |
|
// |
|
||||||
// GTE
|
// GTE
|
||||||
|
//
|
||||||
|
// Note that this deletes the Tuple instruction altogether. In addition,
|
||||||
|
// if only a subset of tuple's elements are used, this transform
|
||||||
|
// optimizes them one at a time, and after the last use is optimized,
|
||||||
|
// the Tuple will also be deleted.
|
||||||
if (ShapeUtil::Compatible(ancestor.first->shape(),
|
if (ShapeUtil::Compatible(ancestor.first->shape(),
|
||||||
instruction->shape())) {
|
instruction->shape())) {
|
||||||
changed = true;
|
changed = true;
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||||
|
|
||||||
@ -41,6 +42,20 @@ class TupleSimplifier : public HloModulePass {
|
|||||||
// apart from the module's entry computation. This is used by Graphcore's
|
// apart from the module's entry computation. This is used by Graphcore's
|
||||||
// backend.
|
// backend.
|
||||||
bool exclude_entry_computation_;
|
bool exclude_entry_computation_;
|
||||||
|
|
||||||
|
// Collapse the following structure into just 'Tuple-shaped Op':
|
||||||
|
//
|
||||||
|
// Tuple-shaped Op
|
||||||
|
// |
|
||||||
|
// +-----+-----+
|
||||||
|
// | | |
|
||||||
|
// GTE GTE GTE
|
||||||
|
// | | |
|
||||||
|
// +-----+-----+
|
||||||
|
// |
|
||||||
|
// Tuple
|
||||||
|
//
|
||||||
|
StatusOr<bool> RemoveWholeTuple(HloInstruction* tuple);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
Reference in New Issue
Block a user