418 lines
16 KiB
C++
418 lines
16 KiB
C++
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
|
|
|
|
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
|
|
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "absl/strings/str_cat.h"
|
|
#include "absl/strings/str_format.h"
|
|
#include "absl/strings/str_join.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
|
#include "tensorflow/compiler/xla/shape_util.h"
|
|
#include "tensorflow/compiler/xla/status_macros.h"
|
|
#include "tensorflow/compiler/xla/statusor.h"
|
|
#include "tensorflow/compiler/xla/types.h"
|
|
#include "tensorflow/compiler/xla/util.h"
|
|
#include "tensorflow/core/lib/core/errors.h"
|
|
#include "tensorflow/core/platform/logging.h"
|
|
|
|
namespace xla {
|
|
|
|
bool HloOrdering::ExecutesBefore(const HloInstruction* a,
|
|
const HloInstruction* b) const {
|
|
// 'a' and 'b' may be in different computations. In this case, find the
|
|
// callgraph ancestor instructions which call (potentially transitively) the
|
|
// computations containing 'a' and 'b' and use these ancestor instructions to
|
|
// compare order.
|
|
const HloInstruction* a_ancestor;
|
|
const HloInstruction* b_ancestor;
|
|
std::tie(a_ancestor, b_ancestor) =
|
|
call_graph_->NearestAncestorsInSameComputation(
|
|
const_cast<HloInstruction*>(a), const_cast<HloInstruction*>(b));
|
|
|
|
if (a_ancestor == nullptr) {
|
|
// Ancestors in a common computation could not be found so consider the
|
|
// instructions 'a' and 'b' to be unordered.
|
|
return false;
|
|
}
|
|
// a_ancestor and b_ancestor must be either both null or both non-null.
|
|
CHECK_NE(b_ancestor, nullptr);
|
|
CHECK_EQ(a_ancestor->parent(), b_ancestor->parent());
|
|
|
|
// If the common ancestor is a while instruction there is an additional
|
|
// ordering criteria which may apply. The condition computation is considered
|
|
// to execute before the body computation so if 'a' is in the condition and
|
|
// 'b' is in the body, then 'a' executes before 'b'.
|
|
if (a_ancestor == b_ancestor && a_ancestor->opcode() == HloOpcode::kWhile) {
|
|
const HloComputation* body = a_ancestor->while_body();
|
|
const HloComputation* condition = a_ancestor->while_condition();
|
|
if (call_graph_->InstructionIsNestedIn(a, condition) &&
|
|
call_graph_->InstructionIsNestedIn(b, body)) {
|
|
return true;
|
|
}
|
|
}
|
|
|
|
// If the common ancestor is a conditional instruction, even though the branch
|
|
// computations are not really ordered per-se, we define the 0th branch
|
|
// computation to be ordered before the 1st one, before the 2nd and so forth.
|
|
// This ensures that buffers can still be shared among branch computations
|
|
// as they will forcibly have disjoint liveness.
|
|
if (a_ancestor == b_ancestor &&
|
|
(a_ancestor->opcode() == HloOpcode::kConditional)) {
|
|
int a_branch = -1;
|
|
int b_branch = -1;
|
|
for (int j = 0; j < a_ancestor->branch_count(); ++j) {
|
|
if (call_graph_->InstructionIsNestedIn(
|
|
a, a_ancestor->branch_computation(j))) {
|
|
a_branch = j;
|
|
}
|
|
if (call_graph_->InstructionIsNestedIn(
|
|
b, a_ancestor->branch_computation(j))) {
|
|
b_branch = j;
|
|
}
|
|
}
|
|
if (a_branch != -1 && a_branch < b_branch) {
|
|
return true;
|
|
}
|
|
// If 'b' is the conditional ancestor, and 'a' is within a branch
|
|
// computation, 'a' executes before 'b'.
|
|
if (b == a_ancestor && a_branch != -1) {
|
|
return true;
|
|
}
|
|
}
|
|
|
|
return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor);
|
|
}
|
|
|
|
bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
|
|
// Entry parameter should always be defined before other instructions.
|
|
const HloModule* module = b.defining_instruction()->parent()->parent();
|
|
if (b.defining_instruction()->parent() == module->entry_computation() &&
|
|
b.defining_instruction()->opcode() == HloOpcode::kParameter) {
|
|
return false;
|
|
}
|
|
|
|
if (a.defining_instruction()->parent() == module->entry_computation() &&
|
|
a.defining_instruction()->opcode() == HloOpcode::kParameter) {
|
|
return true;
|
|
}
|
|
|
|
// Phi values require special handling. Because XLA does not have a phi
|
|
// instruction, the definition instruction of the phis values are
|
|
// placeholders: either the subcomputation parameter (body or condition) or
|
|
// the while instruction. However, the program point where these values are
|
|
// logically defined does not necessarily coincide exactly with program point
|
|
// of these place-holder instructions. So we explicitly define the following
|
|
// order for phi values:
|
|
//
|
|
// body/condition parameter phi:
|
|
// Defined before all values defined in its computation excepting other
|
|
// phis.
|
|
//
|
|
// while phi:
|
|
// defined after all values defined in the condition or body.
|
|
//
|
|
auto is_body_or_condition_phi = [](const HloValue& v) {
|
|
return v.is_phi() &&
|
|
v.defining_instruction()->opcode() == HloOpcode::kParameter;
|
|
};
|
|
if (is_body_or_condition_phi(a) && !is_body_or_condition_phi(b) &&
|
|
call_graph_->InstructionIsNestedIn(b.defining_instruction(),
|
|
a.defining_instruction()->parent())) {
|
|
return true;
|
|
}
|
|
if (is_body_or_condition_phi(b) &&
|
|
call_graph_->InstructionIsNestedIn(a.defining_instruction(),
|
|
b.defining_instruction()->parent())) {
|
|
return false;
|
|
}
|
|
|
|
// If 'b' is a while phi and 'a' is in the body or condition, then 'a'
|
|
// executes before 'b'.
|
|
if (b.is_phi() && b.defining_instruction()->opcode() == HloOpcode::kWhile &&
|
|
(call_graph_->InstructionIsNestedIn(
|
|
a.defining_instruction(), b.defining_instruction()->while_body()) ||
|
|
call_graph_->InstructionIsNestedIn(
|
|
a.defining_instruction(),
|
|
b.defining_instruction()->while_condition()))) {
|
|
return true;
|
|
}
|
|
// If 'b' is a conditional phi and 'a' is in some branch computation, then 'a'
|
|
// executes before 'b'.
|
|
if (b.is_phi() &&
|
|
b.defining_instruction()->opcode() == HloOpcode::kConditional) {
|
|
for (int j = 0; j < b.defining_instruction()->branch_count(); ++j) {
|
|
if (call_graph_->InstructionIsNestedIn(
|
|
a.defining_instruction(),
|
|
b.defining_instruction()->branch_computation(j))) {
|
|
return true;
|
|
}
|
|
}
|
|
}
|
|
return ExecutesBefore(a.defining_instruction(), b.defining_instruction());
|
|
}
|
|
|
|
/* static */
|
|
bool HloOrdering::UseIsBeforeValueDefinition(
|
|
const HloUse& use, const HloValue& value,
|
|
const HloDataflowAnalysis& dataflow) const {
|
|
VLOG(4) << "UseIsBeforeValueDefinition(use=" << use
|
|
<< ", value=" << value.ToShortString() << ")";
|
|
if (ExecutesBefore(use.instruction, value.defining_instruction())) {
|
|
VLOG(4) << " use instruction executes before value-defining instruction";
|
|
return true;
|
|
}
|
|
|
|
// If the use is at the instruction where the value is defined, then the use
|
|
// is before the def if the instruction allows buffer sharing (in place
|
|
// computation).
|
|
if (use.instruction == value.defining_instruction() &&
|
|
dataflow.CanShareOperandBufferWithUser(
|
|
use.instruction->mutable_operand(use.operand_number),
|
|
use.operand_index, value.defining_instruction(),
|
|
value.defining_index())) {
|
|
VLOG(4) << " use is value def, and instruction can share use buffer";
|
|
return true;
|
|
}
|
|
|
|
// The use at a while is an input to a phi, and logically occurs before values
|
|
// are defined in the body. Note that the use is *not* before the value if the
|
|
// value is defined in the condition and is not the condition parameter, since
|
|
// the input of a while's life range is only ended at the start the body.
|
|
if (use.instruction->opcode() == HloOpcode::kWhile) {
|
|
const HloInstruction* xla_while = use.instruction;
|
|
if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
|
|
xla_while->while_body())) {
|
|
VLOG(4) << " use is while " << use.instruction->name()
|
|
<< " and def is in body";
|
|
return true;
|
|
}
|
|
if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
|
|
xla_while->while_condition())) {
|
|
if (value.defining_instruction() !=
|
|
xla_while->while_condition()->parameter_instruction(0)) {
|
|
VLOG(4) << " use is while " << use.instruction->name()
|
|
<< " and def is in condition and is not the parameter";
|
|
return false;
|
|
} else {
|
|
VLOG(4) << " use is while " << use.instruction->name()
|
|
<< " and def is in condition and is the parameter";
|
|
return true;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Similarly if the value is defined at a while, it logically occurs after any
|
|
// uses in the body or condition computations.
|
|
if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
|
|
CHECK(value.is_phi());
|
|
const HloInstruction* xla_while = value.defining_instruction();
|
|
if (call_graph_->InstructionIsNestedIn(use.instruction,
|
|
xla_while->while_body()) ||
|
|
call_graph_->InstructionIsNestedIn(use.instruction,
|
|
xla_while->while_condition())) {
|
|
VLOG(4) << " value is while " << value.defining_instruction()->name()
|
|
<< " and use is in condition or body";
|
|
return true;
|
|
}
|
|
}
|
|
|
|
// The use at a call occurs before values that are defined in the called
|
|
// computation.
|
|
if (use.instruction->opcode() == HloOpcode::kCall) {
|
|
const HloInstruction* call = use.instruction;
|
|
if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
|
|
call->to_apply())) {
|
|
VLOG(4) << " use is call " << use.instruction->name()
|
|
<< " and def is in called computation";
|
|
return true;
|
|
}
|
|
}
|
|
|
|
if (use.instruction->opcode() == HloOpcode::kConditional) {
|
|
const HloInstruction* conditional = use.instruction;
|
|
for (int j = 0; j < conditional->branch_count(); ++j) {
|
|
if (call_graph_->InstructionIsNestedIn(
|
|
value.defining_instruction(),
|
|
conditional->branch_computation(j))) {
|
|
VLOG(4) << " use is conditional " << use.instruction->name()
|
|
<< " and def is in " << j << "th branch computation";
|
|
return true;
|
|
}
|
|
}
|
|
if (value.defining_instruction() == use.instruction) {
|
|
VLOG(4) << " use is conditional " << use << " and def is "
|
|
<< value.ToShortString();
|
|
return true;
|
|
}
|
|
}
|
|
|
|
VLOG(4) << " use is not before value";
|
|
return false;
|
|
}
|
|
|
|
bool HloOrdering::LiveRangeStrictlyBefore(
|
|
const HloValue& a, const HloValue& b,
|
|
const HloDataflowAnalysis& dataflow) const {
|
|
VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString()
|
|
<< ", b = " << b.ToShortString() << ")";
|
|
if (!IsDefinedBefore(a, b)) {
|
|
VLOG(4) << a << " not defined before " << b;
|
|
return false;
|
|
}
|
|
|
|
if (a.live_out_of_module()) {
|
|
VLOG(4) << a << " is live out of module and not defined before " << b;
|
|
return false;
|
|
}
|
|
|
|
// If the root instruction aliases the buffer 'a', the live range of 'a' is
|
|
// until the end of the computation and can never be strictly before another
|
|
// buffer nested in the same computation. This is needed to prevent the root
|
|
// instruction's buffers from being reused by later instructions even when
|
|
// the root is not the last instruction in the schedule.
|
|
for (const HloPosition& pos : a.positions()) {
|
|
if (pos.instruction->parent()->root_instruction() == pos.instruction &&
|
|
call_graph().InstructionIsNestedIn(b.instruction(),
|
|
pos.instruction->parent())) {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// All uses of 'a' must be before 'b' is defined.
|
|
for (const HloUse& use : a.uses()) {
|
|
if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(),
|
|
use.instruction)) {
|
|
continue;
|
|
}
|
|
if (!UseIsBeforeValueDefinition(use, b, dataflow)) {
|
|
VLOG(4) << "use of " << a << " (" << use << ") not before " << b
|
|
<< " is defined";
|
|
return false;
|
|
}
|
|
}
|
|
|
|
if (a.instruction()->parent() == b.instruction()->parent()) {
|
|
for (const HloPosition& position : a.positions()) {
|
|
if (position.instruction ==
|
|
a.instruction()->parent()->root_instruction()) {
|
|
VLOG(4) << a << " is live out of computation and defined before " << b
|
|
<< " which is in same computation";
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b,
|
|
const HloDataflowAnalysis& dataflow) const {
|
|
// Buffers without disjoint liveness may interfere.
|
|
return !LiveRangeStrictlyBefore(a, b, dataflow) &&
|
|
!LiveRangeStrictlyBefore(b, a, dataflow);
|
|
}
|
|
|
|
// Forwards 'module' to the HloOrdering base. This constructor does not fill
// predecessors_.
PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module)
    : HloOrdering(module) {
}
|
|
|
|
bool PredecessorHloOrdering::ExecutesBeforeInSameComputation(
|
|
const HloInstruction* a, const HloInstruction* b) const {
|
|
CHECK_EQ(a->parent(), b->parent());
|
|
|
|
// 'a' executes before 'b' if 'a' is in the strict predecessor set of 'b'.
|
|
return a != b && predecessors_.at(a->parent())->IsReachable(a, b);
|
|
}
|
|
|
|
// Builds a newline-separated description of the ordering. The first line is
// 'name'; then, for every non-fusion computation, each instruction (in post
// order) is followed by the instructions from which it is reachable.
string PredecessorHloOrdering::ToStringHelper(const string& name) const {
  std::vector<string> lines;
  lines.push_back(name);
  for (auto* computation : module_->MakeNonfusionComputations()) {
    lines.push_back(absl::StrFormat("computation %s:", computation->name()));
    const auto post_order = computation->MakeInstructionPostOrder();
    for (auto instruction : post_order) {
      lines.push_back(
          absl::StrFormat(" %s predecessors:", instruction->name()));
      // Look the map up once per instruction instead of once per candidate.
      const auto& reachability = predecessors_.at(computation);
      for (auto candidate : post_order) {
        if (!reachability->IsReachable(candidate, instruction)) {
          continue;
        }
        lines.push_back(absl::StrFormat(" %s", candidate->name()));
      }
    }
  }
  return absl::StrJoin(lines, "\n");
}
|
|
|
|
// Builds a reachability map for each non-fusion computation of 'module' and
// stores it in predecessors_, so ordering follows the dependencies between
// instructions: ExecutesBefore returns true iff a path from 'a' to 'b' exists
// in the HLO computation graph.
DependencyHloOrdering::DependencyHloOrdering(const HloModule* module)
    : PredecessorHloOrdering(module) {
  const auto computations = module->MakeNonfusionComputations();
  for (auto* nonfusion_computation : computations) {
    predecessors_.emplace(nonfusion_computation,
                          HloReachabilityMap::Build(nonfusion_computation));
  }
}
|
|
|
|
// Returns the ToStringHelper() description, headed by the name of this
// ordering class.
string DependencyHloOrdering::ToString() const {
  const string description = ToStringHelper("DependencyHloOrdering");
  return description;
}
|
|
|
|
// Constructs the ordering from a copy of 'schedule' and then runs
// Initialize() to build the instruction position table.
SequentialHloOrdering::SequentialHloOrdering(const HloSchedule& schedule)
    : HloOrdering(schedule.module()),
      schedule_(schedule) {
  Initialize();
}
|
|
|
|
// Constructs the ordering by moving 'schedule' into this object and then runs
// Initialize() to build the instruction position table.
SequentialHloOrdering::SequentialHloOrdering(HloSchedule&& schedule)
    : HloOrdering(schedule.module()),
      schedule_(std::move(schedule)) {
  Initialize();
}
|
|
|
|
void SequentialHloOrdering::Initialize() {
|
|
// Create a map from instruction to its order position.
|
|
TF_DCHECK_OK(schedule_.Verify());
|
|
for (const auto& computation_sequence : schedule_.sequences()) {
|
|
const auto& order = computation_sequence.second.instructions();
|
|
for (int i = 0; i < order.size(); ++i) {
|
|
InsertOrDie(&order_position_, order[i], i);
|
|
}
|
|
}
|
|
}
|
|
|
|
bool SequentialHloOrdering::ExecutesBeforeInSameComputation(
|
|
const HloInstruction* a, const HloInstruction* b) const {
|
|
CHECK_EQ(a->parent(), b->parent());
|
|
// If either instruction is not in the order, then 'a' and 'b' are unordered.
|
|
if (!order_position_.contains(a) || !order_position_.contains(b)) {
|
|
return false;
|
|
}
|
|
return order_position_.at(a) < order_position_.at(b);
|
|
}
|
|
|
|
// Returns a pointer to the scheduled instruction sequence of 'computation',
// or nullptr if the schedule holds no sequence for it.
const HloInstructionSequence* SequentialHloOrdering::SequentialOrder(
    const HloComputation& computation) const {
  if (!schedule_.is_computation_scheduled(&computation)) {
    return nullptr;
  }
  return &schedule_.sequence(&computation);
}
|
|
|
|
// Returns a header line naming this ordering class followed by the textual
// form of the underlying schedule.
string SequentialHloOrdering::ToString() const {
  const string schedule_text = schedule_.ToString();
  return absl::StrCat("SequentialHloOrdering\n", schedule_text);
}
|
|
|
|
} // namespace xla
|