Explain that optimizing partially used tuples within a single computation falls out of the existing optimizations. There is still the option to optimize partially used tuples across computations. Will look into that in a separate CL. PiperOrigin-RevId: 315319714 Change-Id: Ifcc41929cb8213cab661ccefea00138e099d551e
119 lines
4.3 KiB
C++
119 lines
4.3 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
|
|
|
|
#include <queue>
|
|
|
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
|
#include "tensorflow/compiler/xla/status_macros.h"
|
|
#include "tensorflow/compiler/xla/types.h"
|
|
#include "tensorflow/compiler/xla/util.h"
|
|
#include "tensorflow/core/lib/core/errors.h"
|
|
#include "tensorflow/core/lib/core/status.h"
|
|
#include "tensorflow/core/platform/logging.h"
|
|
#include "tensorflow/core/platform/types.h"
|
|
|
|
namespace xla {
|
|
|
|
// Constructs the pass. When `exclude_entry_computation` is true, Run() will
// skip the module's entry computation and only simplify tuples in
// non-entry computations.
TupleSimplifier::TupleSimplifier(bool exclude_entry_computation) {
  exclude_entry_computation_ = exclude_entry_computation;
}
|
|
|
|
// Replaces `tuple` with the tuple-shaped instruction it was rebuilt from,
// when every operand is GetTupleElement(i) of one common source whose shape
// is compatible with `tuple`. In pictures:
//
//   T = Tuple(GTE(S, 0), GTE(S, 1), ..., GTE(S, n-1))  ==>  T replaced by S
//
// Returns true iff the replacement was performed. An empty tuple is left
// untouched (no common source exists).
StatusOr<bool> TupleSimplifier::RemoveWholeTuple(HloInstruction* tuple) {
  // The single instruction all GTE operands must read from.
  HloInstruction* source = nullptr;
  for (int64 i = 0; i < tuple->operand_count(); ++i) {
    HloInstruction* element = tuple->mutable_operand(i);
    // Operand i must be exactly GetTupleElement(source, i).
    if (element->opcode() != HloOpcode::kGetTupleElement ||
        element->tuple_index() != i) {
      return false;
    }
    if (source == nullptr) {
      source = element->mutable_operand(0);
      // The source must have the same shape as the rebuilt tuple; otherwise
      // the rebuilt tuple selects only a prefix of the source's elements.
      if (!ShapeUtil::Compatible(source->shape(), tuple->shape())) {
        return false;
      }
    } else if (source != element->operand(0)) {
      // All elements must come from the same producer.
      return false;
    }
  }
  if (source == nullptr) {
    // Zero-operand tuple: nothing to collapse.
    return false;
  }
  TF_RETURN_IF_ERROR(tuple->parent()->ReplaceInstruction(tuple, source));
  return true;
}
|
|
|
|
// Runs tuple simplification over every computation in `module` (optionally
// skipping the entry computation). Two rewrites are applied:
//   1. A Tuple rebuilt verbatim from GTEs of one source is replaced by that
//      source (RemoveWholeTuple).
//   2. A chain of GTEs is short-circuited to the instruction that produces
//      the selected element.
// Returns true iff any instruction was replaced.
StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
  bool changed = false;
  for (auto* computation : module->computations()) {
    if (exclude_entry_computation_ &&
        computation == module->entry_computation()) {
      continue;
    }
    for (auto* instruction : computation->MakeInstructionPostOrder()) {
      if (instruction->opcode() == HloOpcode::kTuple) {
        // BUG FIX: the previous code assigned the result straight into
        // `changed`, so a later tuple that could NOT be simplified reset the
        // flag and the pass falsely reported "no change". Accumulate instead.
        TF_ASSIGN_OR_RETURN(bool removed, RemoveWholeTuple(instruction));
        changed |= removed;
      } else {
        auto ancestor = instruction->LatestNonGteAncestorAndIndex();
        if (ancestor.first == instruction) {
          // Not a GTE (or a GTE of a non-tuple ancestor chain); nothing to do.
          continue;
        }
        // If possible replace a chain of GTE with the operation which produces
        // the element. For example, replace uses of GTE with below with just
        // 'Op' (assuming 'Op' is at the index of the GTE instruction):
        //
        //     ...  Op  ...
        //       \  |   /
        //        Tuple
        //          |
        //         GTE
        //         ...
        //          |
        //         GTE
        //          |
        //         GTE
        //
        // Note that this deletes the Tuple instruction altogether. In addition,
        // if only a subset of tuple's elements are used, this transform
        // optimizes them one at a time, and after the last use is optimized,
        // the Tuple will also be deleted.
        if (ShapeUtil::Compatible(ancestor.first->shape(),
                                  instruction->shape())) {
          // The whole GTE chain selects the ancestor itself.
          changed = true;
          TF_RETURN_IF_ERROR(
              computation->ReplaceInstruction(instruction, ancestor.first));
        } else if (ancestor.first->opcode() == HloOpcode::kTuple) {
          // The chain selects one element of a Tuple; jump straight to the
          // operand that produced that element.
          changed = true;
          TF_RETURN_IF_ERROR(computation->ReplaceInstruction(
              instruction,
              ancestor.first->mutable_operand(ancestor.second[0])));
        }
      }
    }
  }
  return changed;
}
|
|
|
|
} // namespace xla
|