418 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			418 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 | 
						|
 | 
						|
Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
you may not use this file except in compliance with the License.
 | 
						|
You may obtain a copy of the License at
 | 
						|
 | 
						|
    http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
 | 
						|
Unless required by applicable law or agreed to in writing, software
 | 
						|
distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
See the License for the specific language governing permissions and
 | 
						|
limitations under the License.
 | 
						|
==============================================================================*/
 | 
						|
 | 
						|
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
 | 
						|
 | 
						|
#include <utility>
 | 
						|
#include <vector>
 | 
						|
 | 
						|
#include "absl/strings/str_cat.h"
 | 
						|
#include "absl/strings/str_format.h"
 | 
						|
#include "absl/strings/str_join.h"
 | 
						|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
 | 
						|
#include "tensorflow/compiler/xla/shape_util.h"
 | 
						|
#include "tensorflow/compiler/xla/status_macros.h"
 | 
						|
#include "tensorflow/compiler/xla/statusor.h"
 | 
						|
#include "tensorflow/compiler/xla/types.h"
 | 
						|
#include "tensorflow/compiler/xla/util.h"
 | 
						|
#include "tensorflow/core/lib/core/errors.h"
 | 
						|
#include "tensorflow/core/platform/logging.h"
 | 
						|
 | 
						|
namespace xla {
 | 
						|
 | 
						|
bool HloOrdering::ExecutesBefore(const HloInstruction* a,
 | 
						|
                                 const HloInstruction* b) const {
 | 
						|
  // 'a' and 'b' may be in different computations. In this case, find the
 | 
						|
  // callgraph ancestor instructions which call (potentially transitively) the
 | 
						|
  // computations containing 'a' and 'b' and use these ancestor instructions to
 | 
						|
  // compare order.
 | 
						|
  const HloInstruction* a_ancestor;
 | 
						|
  const HloInstruction* b_ancestor;
 | 
						|
  std::tie(a_ancestor, b_ancestor) =
 | 
						|
      call_graph_->NearestAncestorsInSameComputation(
 | 
						|
          const_cast<HloInstruction*>(a), const_cast<HloInstruction*>(b));
 | 
						|
 | 
						|
  if (a_ancestor == nullptr) {
 | 
						|
    // Ancestors in a common computation could not be found so consider the
 | 
						|
    // instructions 'a' and 'b' to be unordered.
 | 
						|
    return false;
 | 
						|
  }
 | 
						|
  // a_ancestor and b_ancestor must be either both null or both non-null.
 | 
						|
  CHECK_NE(b_ancestor, nullptr);
 | 
						|
  CHECK_EQ(a_ancestor->parent(), b_ancestor->parent());
 | 
						|
 | 
						|
  // If the common ancestor is a while instruction there is an additional
 | 
						|
  // ordering criteria which may apply. The condition computation is considered
 | 
						|
  // to execute before the body computation so if 'a' is in the condition and
 | 
						|
  // 'b' is in the body, then 'a' executes before 'b'.
 | 
						|
  if (a_ancestor == b_ancestor && a_ancestor->opcode() == HloOpcode::kWhile) {
 | 
						|
    const HloComputation* body = a_ancestor->while_body();
 | 
						|
    const HloComputation* condition = a_ancestor->while_condition();
 | 
						|
    if (call_graph_->InstructionIsNestedIn(a, condition) &&
 | 
						|
        call_graph_->InstructionIsNestedIn(b, body)) {
 | 
						|
      return true;
 | 
						|
    }
 | 
						|
  }
 | 
						|
 | 
						|
  // If the common ancestor is a conditional instruction, even though the branch
 | 
						|
  // computations are not really ordered per-se, we define the 0th branch
 | 
						|
  // computation to be ordered before the 1st one, before the 2nd and so forth.
 | 
						|
  // This ensures that buffers can still be shared among branch computations
 | 
						|
  // as they will forcibly have disjoint liveness.
 | 
						|
  if (a_ancestor == b_ancestor &&
 | 
						|
      (a_ancestor->opcode() == HloOpcode::kConditional)) {
 | 
						|
    int a_branch = -1;
 | 
						|
    int b_branch = -1;
 | 
						|
    for (int j = 0; j < a_ancestor->branch_count(); ++j) {
 | 
						|
      if (call_graph_->InstructionIsNestedIn(
 | 
						|
              a, a_ancestor->branch_computation(j))) {
 | 
						|
        a_branch = j;
 | 
						|
      }
 | 
						|
      if (call_graph_->InstructionIsNestedIn(
 | 
						|
              b, a_ancestor->branch_computation(j))) {
 | 
						|
        b_branch = j;
 | 
						|
      }
 | 
						|
    }
 | 
						|
    if (a_branch != -1 && a_branch < b_branch) {
 | 
						|
      return true;
 | 
						|
    }
 | 
						|
    // If 'b' is the conditional ancestor, and 'a' is within a branch
 | 
						|
    // computation, 'a' executes before 'b'.
 | 
						|
    if (b == a_ancestor && a_branch != -1) {
 | 
						|
      return true;
 | 
						|
    }
 | 
						|
  }
 | 
						|
 | 
						|
  return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor);
 | 
						|
}
 | 
						|
 | 
						|
bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
 | 
						|
  // Entry parameter should always be defined before other instructions.
 | 
						|
  const HloModule* module = b.defining_instruction()->parent()->parent();
 | 
						|
  if (b.defining_instruction()->parent() == module->entry_computation() &&
 | 
						|
      b.defining_instruction()->opcode() == HloOpcode::kParameter) {
 | 
						|
    return false;
 | 
						|
  }
 | 
						|
 | 
						|
  if (a.defining_instruction()->parent() == module->entry_computation() &&
 | 
						|
      a.defining_instruction()->opcode() == HloOpcode::kParameter) {
 | 
						|
    return true;
 | 
						|
  }
 | 
						|
 | 
						|
  // Phi values require special handling. Because XLA does not have a phi
 | 
						|
  // instruction, the definition instruction of the phis values are
 | 
						|
  // placeholders: either the subcomputation parameter (body or condition) or
 | 
						|
  // the while instruction. However, the program point where these values are
 | 
						|
  // logically defined does not necessarily coincide exactly with program point
 | 
						|
  // of these place-holder instructions. So we explicitly define the following
 | 
						|
  // order for phi values:
 | 
						|
  //
 | 
						|
  //   body/condition parameter phi:
 | 
						|
  //     Defined before all values defined in its computation excepting other
 | 
						|
  //     phis.
 | 
						|
  //
 | 
						|
  //   while phi:
 | 
						|
  //     defined after all values defined in the condition or body.
 | 
						|
  //
 | 
						|
  auto is_body_or_condition_phi = [](const HloValue& v) {
 | 
						|
    return v.is_phi() &&
 | 
						|
           v.defining_instruction()->opcode() == HloOpcode::kParameter;
 | 
						|
  };
 | 
						|
  if (is_body_or_condition_phi(a) && !is_body_or_condition_phi(b) &&
 | 
						|
      call_graph_->InstructionIsNestedIn(b.defining_instruction(),
 | 
						|
                                         a.defining_instruction()->parent())) {
 | 
						|
    return true;
 | 
						|
  }
 | 
						|
  if (is_body_or_condition_phi(b) &&
 | 
						|
      call_graph_->InstructionIsNestedIn(a.defining_instruction(),
 | 
						|
                                         b.defining_instruction()->parent())) {
 | 
						|
    return false;
 | 
						|
  }
 | 
						|
 | 
						|
  // If 'b' is a while phi and 'a' is in the body or condition, then 'a'
 | 
						|
  // executes before 'b'.
 | 
						|
  if (b.is_phi() && b.defining_instruction()->opcode() == HloOpcode::kWhile &&
 | 
						|
      (call_graph_->InstructionIsNestedIn(
 | 
						|
           a.defining_instruction(), b.defining_instruction()->while_body()) ||
 | 
						|
       call_graph_->InstructionIsNestedIn(
 | 
						|
           a.defining_instruction(),
 | 
						|
           b.defining_instruction()->while_condition()))) {
 | 
						|
    return true;
 | 
						|
  }
 | 
						|
  // If 'b' is a conditional phi and 'a' is in some branch computation, then 'a'
 | 
						|
  // executes before 'b'.
 | 
						|
  if (b.is_phi() &&
 | 
						|
      b.defining_instruction()->opcode() == HloOpcode::kConditional) {
 | 
						|
    for (int j = 0; j < b.defining_instruction()->branch_count(); ++j) {
 | 
						|
      if (call_graph_->InstructionIsNestedIn(
 | 
						|
              a.defining_instruction(),
 | 
						|
              b.defining_instruction()->branch_computation(j))) {
 | 
						|
        return true;
 | 
						|
      }
 | 
						|
    }
 | 
						|
  }
 | 
						|
  return ExecutesBefore(a.defining_instruction(), b.defining_instruction());
 | 
						|
}
 | 
						|
 | 
						|
/* static */
 | 
						|
bool HloOrdering::UseIsBeforeValueDefinition(
 | 
						|
    const HloUse& use, const HloValue& value,
 | 
						|
    const HloDataflowAnalysis& dataflow) const {
 | 
						|
  VLOG(4) << "UseIsBeforeValueDefinition(use=" << use
 | 
						|
          << ", value=" << value.ToShortString() << ")";
 | 
						|
  if (ExecutesBefore(use.instruction, value.defining_instruction())) {
 | 
						|
    VLOG(4) << "  use instruction executes before value-defining instruction";
 | 
						|
    return true;
 | 
						|
  }
 | 
						|
 | 
						|
  // If the use is at the instruction where the value is defined, then the use
 | 
						|
  // is before the def if the instruction allows buffer sharing (in place
 | 
						|
  // computation).
 | 
						|
  if (use.instruction == value.defining_instruction() &&
 | 
						|
      dataflow.CanShareOperandBufferWithUser(
 | 
						|
          use.instruction->mutable_operand(use.operand_number),
 | 
						|
          use.operand_index, value.defining_instruction(),
 | 
						|
          value.defining_index())) {
 | 
						|
    VLOG(4) << "  use is value def, and instruction can share use buffer";
 | 
						|
    return true;
 | 
						|
  }
 | 
						|
 | 
						|
  // The use at a while is an input to a phi, and logically occurs before values
 | 
						|
  // are defined in the body. Note that the use is *not* before the value if the
 | 
						|
  // value is defined in the condition and is not the condition parameter, since
 | 
						|
  // the input of a while's life range is only ended at the start the body.
 | 
						|
  if (use.instruction->opcode() == HloOpcode::kWhile) {
 | 
						|
    const HloInstruction* xla_while = use.instruction;
 | 
						|
    if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
 | 
						|
                                           xla_while->while_body())) {
 | 
						|
      VLOG(4) << "  use is while " << use.instruction->name()
 | 
						|
              << " and def is in body";
 | 
						|
      return true;
 | 
						|
    }
 | 
						|
    if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
 | 
						|
                                           xla_while->while_condition())) {
 | 
						|
      if (value.defining_instruction() !=
 | 
						|
          xla_while->while_condition()->parameter_instruction(0)) {
 | 
						|
        VLOG(4) << "  use is while " << use.instruction->name()
 | 
						|
                << " and def is in condition and is not the parameter";
 | 
						|
        return false;
 | 
						|
      } else {
 | 
						|
        VLOG(4) << "  use is while " << use.instruction->name()
 | 
						|
                << " and def is in condition and is the parameter";
 | 
						|
        return true;
 | 
						|
      }
 | 
						|
    }
 | 
						|
  }
 | 
						|
 | 
						|
  // Similarly if the value is defined at a while, it logically occurs after any
 | 
						|
  // uses in the body or condition computations.
 | 
						|
  if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
 | 
						|
    CHECK(value.is_phi());
 | 
						|
    const HloInstruction* xla_while = value.defining_instruction();
 | 
						|
    if (call_graph_->InstructionIsNestedIn(use.instruction,
 | 
						|
                                           xla_while->while_body()) ||
 | 
						|
        call_graph_->InstructionIsNestedIn(use.instruction,
 | 
						|
                                           xla_while->while_condition())) {
 | 
						|
      VLOG(4) << "  value is while " << value.defining_instruction()->name()
 | 
						|
              << " and use is in condition or body";
 | 
						|
      return true;
 | 
						|
    }
 | 
						|
  }
 | 
						|
 | 
						|
  // The use at a call occurs before values that are defined in the called
 | 
						|
  // computation.
 | 
						|
  if (use.instruction->opcode() == HloOpcode::kCall) {
 | 
						|
    const HloInstruction* call = use.instruction;
 | 
						|
    if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
 | 
						|
                                           call->to_apply())) {
 | 
						|
      VLOG(4) << "  use is call " << use.instruction->name()
 | 
						|
              << " and def is in called computation";
 | 
						|
      return true;
 | 
						|
    }
 | 
						|
  }
 | 
						|
 | 
						|
  if (use.instruction->opcode() == HloOpcode::kConditional) {
 | 
						|
    const HloInstruction* conditional = use.instruction;
 | 
						|
    for (int j = 0; j < conditional->branch_count(); ++j) {
 | 
						|
      if (call_graph_->InstructionIsNestedIn(
 | 
						|
              value.defining_instruction(),
 | 
						|
              conditional->branch_computation(j))) {
 | 
						|
        VLOG(4) << "  use is conditional " << use.instruction->name()
 | 
						|
                << " and def is in " << j << "th branch computation";
 | 
						|
        return true;
 | 
						|
      }
 | 
						|
    }
 | 
						|
    if (value.defining_instruction() == use.instruction) {
 | 
						|
      VLOG(4) << "  use is conditional " << use << " and def is "
 | 
						|
              << value.ToShortString();
 | 
						|
      return true;
 | 
						|
    }
 | 
						|
  }
 | 
						|
 | 
						|
  VLOG(4) << "  use is not before value";
 | 
						|
  return false;
 | 
						|
}
 | 
						|
 | 
						|
bool HloOrdering::LiveRangeStrictlyBefore(
 | 
						|
    const HloValue& a, const HloValue& b,
 | 
						|
    const HloDataflowAnalysis& dataflow) const {
 | 
						|
  VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString()
 | 
						|
          << ", b = " << b.ToShortString() << ")";
 | 
						|
  if (!IsDefinedBefore(a, b)) {
 | 
						|
    VLOG(4) << a << " not defined before " << b;
 | 
						|
    return false;
 | 
						|
  }
 | 
						|
 | 
						|
  if (a.live_out_of_module()) {
 | 
						|
    VLOG(4) << a << " is live out of module and not defined before " << b;
 | 
						|
    return false;
 | 
						|
  }
 | 
						|
 | 
						|
  // If the root instruction aliases the buffer 'a', the live range of 'a' is
 | 
						|
  // until the end of the computation and can never be strictly before another
 | 
						|
  // buffer nested in the same computation. This is needed to prevent the root
 | 
						|
  // instruction's buffers from being reused by later instructions even when
 | 
						|
  // the root is not the last instruction in the schedule.
 | 
						|
  for (const HloPosition& pos : a.positions()) {
 | 
						|
    if (pos.instruction->parent()->root_instruction() == pos.instruction &&
 | 
						|
        call_graph().InstructionIsNestedIn(b.instruction(),
 | 
						|
                                           pos.instruction->parent())) {
 | 
						|
      return false;
 | 
						|
    }
 | 
						|
  }
 | 
						|
 | 
						|
  // All uses of 'a' must be before 'b' is defined.
 | 
						|
  for (const HloUse& use : a.uses()) {
 | 
						|
    if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(),
 | 
						|
                                         use.instruction)) {
 | 
						|
      continue;
 | 
						|
    }
 | 
						|
    if (!UseIsBeforeValueDefinition(use, b, dataflow)) {
 | 
						|
      VLOG(4) << "use of " << a << " (" << use << ") not before " << b
 | 
						|
              << " is defined";
 | 
						|
      return false;
 | 
						|
    }
 | 
						|
  }
 | 
						|
 | 
						|
  if (a.instruction()->parent() == b.instruction()->parent()) {
 | 
						|
    for (const HloPosition& position : a.positions()) {
 | 
						|
      if (position.instruction ==
 | 
						|
          a.instruction()->parent()->root_instruction()) {
 | 
						|
        VLOG(4) << a << " is live out of computation and defined before " << b
 | 
						|
                << " which is in same computation";
 | 
						|
        return false;
 | 
						|
      }
 | 
						|
    }
 | 
						|
  }
 | 
						|
 | 
						|
  return true;
 | 
						|
}
 | 
						|
 | 
						|
bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b,
 | 
						|
                               const HloDataflowAnalysis& dataflow) const {
 | 
						|
  // Buffers without disjoint liveness may interfere.
 | 
						|
  return !LiveRangeStrictlyBefore(a, b, dataflow) &&
 | 
						|
         !LiveRangeStrictlyBefore(b, a, dataflow);
 | 
						|
}
 | 
						|
 | 
						|
// Forwards 'module' to the HloOrdering base; this constructor does not
// populate the predecessor maps itself.
PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module)
    : HloOrdering(module) {
}
 | 
						|
 | 
						|
bool PredecessorHloOrdering::ExecutesBeforeInSameComputation(
 | 
						|
    const HloInstruction* a, const HloInstruction* b) const {
 | 
						|
  CHECK_EQ(a->parent(), b->parent());
 | 
						|
 | 
						|
  // 'a' executes before 'b' if 'a' is in the strict predecessor set of 'b'.
 | 
						|
  return a != b && predecessors_.at(a->parent())->IsReachable(a, b);
 | 
						|
}
 | 
						|
 | 
						|
// Builds a multi-line description of this ordering: 'name' on the first line,
// then for every non-fusion computation a header line followed by each
// instruction and the instructions from which it is reachable.
string PredecessorHloOrdering::ToStringHelper(const string& name) const {
  std::vector<string> lines = {name};
  for (auto* computation : module_->MakeNonfusionComputations()) {
    lines.push_back(absl::StrFormat("computation %s:", computation->name()));
    const auto post_order = computation->MakeInstructionPostOrder();
    for (auto target : post_order) {
      lines.push_back(absl::StrFormat("  %s predecessors:", target->name()));
      for (auto candidate : post_order) {
        const bool reaches_target =
            predecessors_.at(computation)->IsReachable(candidate, target);
        if (!reaches_target) {
          continue;
        }
        lines.push_back(absl::StrFormat("    %s", candidate->name()));
      }
    }
  }
  return absl::StrJoin(lines, "\n");
}
 | 
						|
 | 
						|
// Builds an ordering derived purely from data/control dependencies: one
// reachability map is created per non-fusion computation of 'module'.
DependencyHloOrdering::DependencyHloOrdering(const HloModule* module)
    : PredecessorHloOrdering(module) {
  // With these predecessor relationships, ExecutesBefore returns true iff
  // there exists a path in the HLO computation graph from 'a' to 'b'.
  for (auto* computation : module->MakeNonfusionComputations()) {
    auto reachability = HloReachabilityMap::Build(computation);
    predecessors_.emplace(computation, std::move(reachability));
  }
}
 | 
						|
 | 
						|
// Returns the predecessor listing produced by ToStringHelper, labelled with
// this class's name.
string DependencyHloOrdering::ToString() const {
  const string ordering_name = "DependencyHloOrdering";
  return ToStringHelper(ordering_name);
}
 | 
						|
 | 
						|
// Constructs the ordering from a copy of 'schedule' and indexes the position
// of every scheduled instruction.
SequentialHloOrdering::SequentialHloOrdering(const HloSchedule& schedule)
    : HloOrdering(schedule.module()),
      schedule_(schedule) {
  Initialize();
}
 | 
						|
 | 
						|
// Constructs the ordering by moving 'schedule' into this object and indexes
// the position of every scheduled instruction.
SequentialHloOrdering::SequentialHloOrdering(HloSchedule&& schedule)
    : HloOrdering(schedule.module()),
      schedule_(std::move(schedule)) {
  Initialize();
}
 | 
						|
 | 
						|
void SequentialHloOrdering::Initialize() {
 | 
						|
  // Create a map from instruction to its order position.
 | 
						|
  TF_DCHECK_OK(schedule_.Verify());
 | 
						|
  for (const auto& computation_sequence : schedule_.sequences()) {
 | 
						|
    const auto& order = computation_sequence.second.instructions();
 | 
						|
    for (int i = 0; i < order.size(); ++i) {
 | 
						|
      InsertOrDie(&order_position_, order[i], i);
 | 
						|
    }
 | 
						|
  }
 | 
						|
}
 | 
						|
 | 
						|
bool SequentialHloOrdering::ExecutesBeforeInSameComputation(
 | 
						|
    const HloInstruction* a, const HloInstruction* b) const {
 | 
						|
  CHECK_EQ(a->parent(), b->parent());
 | 
						|
  // If either instruction is not in the order, then 'a' and 'b' are unordered.
 | 
						|
  if (!order_position_.contains(a) || !order_position_.contains(b)) {
 | 
						|
    return false;
 | 
						|
  }
 | 
						|
  return order_position_.at(a) < order_position_.at(b);
 | 
						|
}
 | 
						|
 | 
						|
// Returns a pointer to the scheduled instruction sequence of 'computation', or
// nullptr if the schedule reports that the computation is not scheduled.
const HloInstructionSequence* SequentialHloOrdering::SequentialOrder(
    const HloComputation& computation) const {
  if (!schedule_.is_computation_scheduled(&computation)) {
    return nullptr;
  }
  return &schedule_.sequence(&computation);
}
 | 
						|
 | 
						|
// Returns a header line naming this ordering followed by the string form of
// the underlying schedule.
string SequentialHloOrdering::ToString() const {
  const string schedule_text = schedule_.ToString();
  return absl::StrCat("SequentialHloOrdering\n", schedule_text);
}
 | 
						|
 | 
						|
}  // namespace xla
 |