[XLA] Add module-scoped HLO dataflow analysis.
This is the first step to replacing TuplePointsToAnalysis with a global, module-scoped analysis. This dataflow analysis identifies all values and their defs and uses in the XLA graph. The analysis is currently unused. Follow-up CLs will add buffer alias analysis using this dataflow analysis, and incrementally switch the transformation passes (for example, CopyInsertion) to use these new module-scoped analyses.

PiperOrigin-RevId: 158067910
This commit is contained in: parent 93c57c6e4f, commit 54ccc3e5a8
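As a rough sketch of how a later pass might consume this analysis once the follow-up CLs wire it in (hypothetical usage only; this commit leaves the analysis unused, and the surrounding context such as the HloModule* named 'module' is assumed):

    // Run the module-scoped analysis in SSA form and walk every value it found.
    TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow,
                        HloDataflowAnalysis::Run(module, /*ssa_form=*/true));
    for (const HloValue* value : dataflow->values()) {
      VLOG(2) << "value defined by " << value->instruction()->name()
              << " at index " << value->index().ToString() << " with "
              << value->uses().size() << " uses";
    }

The call signatures above come from the header added in this commit; only the enclosing function (one that can return a Status) is assumed.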
@@ -1142,6 +1142,51 @@ cc_library(
    ],
)

cc_library(
    name = "hlo_dataflow_analysis",
    srcs = [
        "hlo_dataflow_analysis.cc",
    ],
    hdrs = [
        "hlo_dataflow_analysis.h",
    ],
    deps = [
        ":call_graph",
        ":hlo",
        ":liveness_util",
        "//tensorflow/compiler/xla:shape_tree",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
    ],
)

cc_test(
    name = "hlo_dataflow_analysis_test",
    srcs = ["hlo_dataflow_analysis_test.cc"],
    deps = [
        ":hlo",
        ":hlo_dataflow_analysis",
        ":hlo_matchers",
        ":instruction_fusion",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:test_helpers",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

cc_library(
    name = "tuple_points_to_analysis",
    srcs = [
tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc (new file, 810 lines)
@@ -0,0 +1,810 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"

#include <algorithm>
#include <iosfwd>
#include <queue>
#include <set>
#include <vector>

#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/liveness_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;

string HloUse::ToString() const {
  string index_str =
      ShapeUtil::IsTuple(instruction->operand(operand_number)->shape())
          ? (" " + operand_index.ToString())
          : "";
  return StrCat(instruction->FullyQualifiedName(), ", operand ", operand_number,
                index_str);
}

std::ostream& operator<<(std::ostream& out, const HloUse& use) {
  out << use.ToString();
  return out;
}
bool HloValue::operator==(const HloValue& other) const {
  bool equal = instruction() == other.instruction() && index() == other.index();
  // If the values are equal they must both be phi (or non-phi).
  CHECK(!(equal && is_phi() != other.is_phi()));
  return equal;
}

bool HloValue::operator!=(const HloValue& other) const {
  return !(*this == other);
}

string HloValue::ToShortString() const {
  string index_str =
      ShapeUtil::IsTuple(instruction_->shape()) ? index_.ToString() : "";
  return StrCat(is_phi_ ? "PHI " : "", instruction_->FullyQualifiedName(),
                index_str);
}

string HloValue::ToString(int indent) const {
  string indentation(indent, ' ');
  string out = StrCat(indentation, ToShortString(), ", uses:\n");
  for (const HloUse& use : uses()) {
    StrAppend(&out, indentation, " ", use.ToString(), "\n");
  }
  return out;
}

void HloValue::AddUse(HloInstruction* instruction, int64 operand_number,
                      const ShapeIndex& operand_index) {
  HloUse use = {instruction, operand_number, operand_index};
  CHECK(std::find(uses_.begin(), uses_.end(), use) == uses_.end());
  uses_.push_back(std::move(use));
}

void HloValue::RemoveUse(HloInstruction* instruction, int64 operand_number,
                         const ShapeIndex& operand_index) {
  HloUse use = {instruction, operand_number, operand_index};
  auto it = std::find(uses_.begin(), uses_.end(), use);
  CHECK(it != uses_.end());
  uses_.erase(it);
  DCHECK(std::find(uses_.begin(), uses_.end(), use) == uses_.end());
}

std::ostream& operator<<(std::ostream& out, const HloValue& value) {
  out << value.ToString();
  return out;
}

void HloValueSet::SortAndUniquifyValues() {
  std::sort(value_ids_.begin(), value_ids_.end());
  value_ids_.erase(std::unique(value_ids_.begin(), value_ids_.end()),
                   value_ids_.end());
}

string HloValueSet::ToString() const {
  return StrCat("HloValueSet: ", tensorflow::str_util::Join(value_ids_, ", "));
}

/* static */
HloValueSet HloValueSet::Union(
    tensorflow::gtl::ArraySlice<const HloValueSet*> inputs) {
  HloValueSet union_set;
  for (const HloValueSet* input : inputs) {
    for (HloValue::Id value_id : input->value_ids()) {
      union_set.value_ids_.push_back(value_id);
    }
  }
  union_set.SortAndUniquifyValues();
  return union_set;
}

std::ostream& operator<<(std::ostream& out, const HloValueSet& value_set) {
  out << value_set.ToString();
  return out;
}

void InstructionValueSet::ForEachValueSet(
    const InstructionValueSet::VisitorFunction& func) const {
  ForEachElement([&func](const ShapeIndex& index, bool /*is_leaf*/,
                         const HloValueSet& value_set) {
    func(index, value_set);
    return Status::OK();
  })
      .IgnoreError();
}

void InstructionValueSet::ForEachMutableValueSet(
    const InstructionValueSet::MutableVisitorFunction& func) {
  ForEachMutableElement([&func](const ShapeIndex& index, bool /*is_leaf*/,
                                HloValueSet* value_set) {
    func(index, value_set);
    return Status::OK();
  })
      .IgnoreError();
}

InstructionValueSet InstructionValueSet::Union(
    tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
  CHECK_GT(inputs.size(), 0);
  for (int i = 1; i < inputs.size(); ++i) {
    CHECK(ShapeUtil::Compatible(inputs[0]->shape(), inputs[i]->shape()));
  }
  InstructionValueSet union_set(inputs[0]->shape());
  union_set
      .ForEachMutableElement([&inputs](const ShapeIndex& index,
                                       bool /*is_leaf*/,
                                       HloValueSet* value_set) {
        std::vector<const HloValueSet*> input_sets;
        for (const InstructionValueSet* input : inputs) {
          input_sets.push_back(&input->element(index));
        }
        *value_set = HloValueSet::Union(input_sets);
        return Status::OK();
      })
      .IgnoreError();
  return union_set;
}

std::ostream& operator<<(std::ostream& out,
                         const InstructionValueSet& instruction_value_set) {
  out << instruction_value_set.ToString();
  return out;
}

string InstructionValueSet::ToString() const {
  string out =
      StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), ")");
  ForEachElement([this, &out](const ShapeIndex& index, bool /*is_leaf*/,
                              const HloValueSet& value_set) {
    StrAppend(&out, index.ToString(), " : ", value_set.ToString(), "\n");
    return Status::OK();
  })
      .IgnoreError();
  return out;
}

HloDataflowAnalysis::HloDataflowAnalysis(HloModule* module, bool ssa_form,
                                         bool bitcast_defines_value)
    : module_(module),
      ssa_form_(ssa_form),
      bitcast_defines_value_(bitcast_defines_value),
      call_graph_(CallGraph::Build(module)) {}

bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
                                           const ShapeIndex& index) const {
  const HloValueSet& value_set = GetValueSet(instruction, index);
  if (value_set.value_ids().size() != 1) {
    return false;
  }
  return GetValue(value_set.GetUniqueValueId()).instruction() == instruction;
}

const HloValue& HloDataflowAnalysis::GetValueDefinedAt(
    const HloInstruction* instruction, const ShapeIndex& index) const {
  CHECK(ValueIsDefinedAt(instruction, index));
  return GetUniqueValueAt(instruction, index);
}

HloValue& HloDataflowAnalysis::GetValueDefinedAt(
    const HloInstruction* instruction, const ShapeIndex& index) {
  CHECK(ValueIsDefinedAt(instruction, index));
  return GetUniqueValueAt(instruction, index);
}

HloValue::Id HloDataflowAnalysis::NewHloValue(HloInstruction* instruction,
                                              const ShapeIndex& index,
                                              bool is_phi) {
  int64 value_id = next_value_id_++;
  auto it_added = values_.emplace(
      std::piecewise_construct, std::forward_as_tuple(value_id),
      std::forward_as_tuple(value_id, instruction, index, is_phi));
  CHECK(it_added.second);

  // Clear the vector of values as it is now stale. It will be lazily
  // reconstructed if needed when HloDataflowAnalysis::values() is called.
  values_vector_.clear();

  return value_id;
}

void HloDataflowAnalysis::DeleteHloValue(HloValue::Id value_id) {
  values_.erase(value_id);

  // Clear the vector of values as it is now stale. It will be lazily
  // reconstructed if needed when HloDataflowAnalysis::values() is called.
  values_vector_.clear();
}

string HloDataflowAnalysis::ToString() const {
  string out = StrCat("HloDataflowAnalysis, module ", module_->name(), "\n");
  StrAppend(&out, " Instruction value sets:\n");
  for (const std::unique_ptr<HloComputation>& computation :
       module_->computations()) {
    for (const std::unique_ptr<HloInstruction>& instruction :
         computation->instructions()) {
      StrAppend(&out, " ", instruction->FullyQualifiedName(), ":\n");
      if (ShapeUtil::IsTuple(instruction->shape())) {
        GetInstructionValueSet(instruction.get())
            .ForEachValueSet([this, &instruction, &out](
                                 const ShapeIndex& index,
                                 const HloValueSet& value_set) {
              StrAppend(&out, " tuple index ", index.ToString(), ":\n");
              for (HloValue::Id value_id : value_set.value_ids()) {
                StrAppend(
                    &out, " ", GetValue(value_id).ToShortString(),
                    ValueIsDefinedAt(instruction.get(), index) ? " (def)" : "",
                    "\n");
              }
            });
      } else {
        const HloValueSet& top_level_value_set =
            GetValueSet(instruction.get(), /*index=*/{});
        for (HloValue::Id value_id : top_level_value_set.value_ids()) {
          StrAppend(&out, " ", GetValue(value_id).ToShortString(),
                    ValueIsDefinedAt(instruction.get()) ? " (def)" : "", "\n");
        }
      }
    }
  }
  StrAppend(&out, " HloValues:\n");
  for (const auto& pair : values_) {
    StrAppend(&out, pair.second.ToString(/*indent=*/4));
  }
  return out;
}

const HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) const {
  return values_.at(value_id);
}

HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) {
  return values_.at(value_id);
}

const HloValueSet& HloDataflowAnalysis::GetValueSet(
    const HloInstruction* instruction, const ShapeIndex& index) const {
  return GetInstructionValueSet(instruction).element(index);
}

HloValueSet& HloDataflowAnalysis::GetValueSet(const HloInstruction* instruction,
                                              const ShapeIndex& index) {
  return *GetInstructionValueSet(instruction).mutable_element(index);
}

std::vector<const HloValue*>& HloDataflowAnalysis::values() const {
  if (values_vector_.empty()) {
    // Lazily construct vector of values.
    values_vector_.reserve(values_.size());
    for (auto& pair : values_) {
      values_vector_.push_back(&pair.second);
    }
    std::sort(
        values_vector_.begin(), values_vector_.end(),
        [](const HloValue* a, const HloValue* b) { return a->id() < b->id(); });
  } else {
    CHECK_EQ(values_vector_.size(), values_.size());
    for (const HloValue* value : values_vector_) {
      DCHECK(ContainsKey(values_, value->id()));
      DCHECK(&GetValue(value->id()) == value);
    }
  }
  return values_vector_;
}

InstructionValueSet HloDataflowAnalysis::Phi(
    HloInstruction* instruction,
    tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs,
    bool skip_top_level) {
  CHECK(ssa_form_);

  for (const InstructionValueSet* input : inputs) {
    CHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
  }
  InstructionValueSet new_value_set(instruction->shape());
  new_value_set.ForEachMutableValueSet(
      [this, instruction, &inputs, skip_top_level](const ShapeIndex& index,
                                                   HloValueSet* value_set) {
        // If we're skipping the top level, just copy over the existing
        // HloValueSet.
        if (skip_top_level && index.empty()) {
          *value_set = GetInstructionValueSet(instruction).element(index);
          return;
        }
        // Return the unique value at the current index in the given
        // InstructionValueSet. Returns null if the value has not yet been
        // determined.
        auto unique_value_or_null = [this,
                                     &index](const InstructionValueSet& ivset) {
          const HloValueSet& vset = ivset.element(index);
          CHECK_LE(vset.value_ids().size(), 1);
          return vset.value_ids().empty() ? nullptr
                                          : &GetValue(vset.GetUniqueValueId());
        };

        // Save the old value at this index.
        const HloValue* old_value =
            unique_value_or_null(GetInstructionValueSet(instruction));
        bool old_value_is_phi = old_value != nullptr && old_value->is_phi() &&
                                ValueIsDefinedAt(instruction, index);

        // Construct a vector of unique value IDs of the inputs.
        std::vector<HloValue::Id> input_value_ids;
        for (const InstructionValueSet* input : inputs) {
          // All values must be unique.
          const HloValue* input_value = unique_value_or_null(*input);
          if (input_value != nullptr) {
            input_value_ids.push_back(input_value->id());
          }
        }
        input_value_ids.erase(
            std::unique(input_value_ids.begin(), input_value_ids.end()),
            input_value_ids.end());

        // Remove the existing phi value (if it exists). The phi can be its own
        // input, for example, in while body parameters where the body passes
        // through the parameter value.
        if (old_value_is_phi) {
          auto it = std::find(input_value_ids.begin(), input_value_ids.end(),
                              old_value->id());
          if (it != input_value_ids.end()) {
            input_value_ids.erase(it);
          }
        }

        if (input_value_ids.size() <= 1) {
          if (input_value_ids.size() == 1) {
            *value_set = HloValueSet({input_value_ids[0]});
          }
          if (old_value_is_phi) {
            // The merge point does not have multiple distinct inputs (which are
            // not the phi value itself). Therefore there is no need to insert a
            // phi value because there is a single reaching definition (or no
            // reaching definition).
            DeleteHloValue(old_value->id());
          }
        } else if (input_value_ids.size() > 1) {
          // Multiple distinct values reach this point. A phi value is
          // necessary.
          if (old_value_is_phi) {
            // A phi value already exists so reuse it in the new
            // InstructionValueSet.
            *value_set = HloValueSet({old_value->id()});
          } else {
            // Create a new phi value.
            *value_set =
                HloValueSet({NewHloValue(instruction, index, /*is_phi=*/true)});
          }
        }
      });
  return new_value_set;
}

void HloDataflowAnalysis::UpdateUsesOfValuesAt(
    HloInstruction* instruction, const InstructionValueSet& new_value_set,
    const InstructionValueSet* prev_value_set) {
  for (HloInstruction* user : instruction->users()) {
    for (int64 operand_number : user->OperandIndices(instruction)) {
      if (prev_value_set != nullptr) {
        // Remove uses from the old value set.
        prev_value_set->ForEachValueSet(
            [this, instruction, user, operand_number](
                const ShapeIndex& index, const HloValueSet& value_set) {
              for (HloValue::Id value_id : value_set.value_ids()) {
                // HloValues in the previous value set may have been deleted.
                if (!ContainsKey(values_, value_id)) {
                  continue;
                }
                if (!DoesNotUseOperandBuffer(instruction, index, user)) {
                  GetValue(value_id).RemoveUse(user, operand_number, index);
                }
              }
            });
      }
      // Add uses in the new value set.
      new_value_set.ForEachValueSet(
          [this, instruction, user, operand_number](
              const ShapeIndex& index, const HloValueSet& value_set) {
            for (HloValue::Id value_id : value_set.value_ids()) {
              if (!DoesNotUseOperandBuffer(instruction, index, user)) {
                GetValue(value_id).AddUse(user, operand_number, index);
              }
            }
          });
    }
  }
}

void HloDataflowAnalysis::UpdateLiveOutValues(
    const InstructionValueSet& new_root_value_set,
    const InstructionValueSet* prev_root_value_set) {
  if (prev_root_value_set != nullptr) {
    // Clear the old live out set.
    prev_root_value_set->ForEachValueSet(
        [this](const ShapeIndex& index, const HloValueSet& value_set) {
          for (HloValue::Id value_id : value_set.value_ids()) {
            // HloValues in the previous value set may have been deleted.
            if (!ContainsKey(values_, value_id)) {
              continue;
            }
            GetValue(value_id).set_live_out_of_module(false);
          }
        });
  }
  new_root_value_set.ForEachValueSet(
      [this](const ShapeIndex& index, const HloValueSet& value_set) {
        for (HloValue::Id value_id : value_set.value_ids()) {
          GetValue(value_id).set_live_out_of_module(true);
        }
      });
}

InstructionValueSet HloDataflowAnalysis::RecomputeBitcastValueSet(
    HloInstruction* bitcast) {
  CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast);
  if (bitcast_defines_value_) {
    return GetInstructionValueSet(bitcast);
  } else {
    return GetInstructionValueSet(bitcast->operand(0));
  }
}

InstructionValueSet HloDataflowAnalysis::RecomputeCopyValueSet(
    HloInstruction* copy) {
  CHECK_EQ(copy->opcode(), HloOpcode::kCopy);
  InstructionValueSet new_value_set = GetInstructionValueSet(copy);
  if (ShapeUtil::IsTuple(copy->shape())) {
    for (int i = 0; i < ShapeUtil::TupleElementCount(copy->shape()); ++i) {
      new_value_set.CopySubtreeFrom(GetInstructionValueSet(copy->operand(0)),
                                    /*source_base_index=*/{i},
                                    /*target_base_index=*/{i});
    }
  }
  return new_value_set;
}

InstructionValueSet HloDataflowAnalysis::RecomputeGetTupleElementValueSet(
    HloInstruction* gte) {
  CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement);
  InstructionValueSet new_value_set(gte->shape());
  new_value_set.CopySubtreeFrom(GetInstructionValueSet(gte->operand(0)),
                                /*source_base_index=*/{gte->tuple_index()},
                                /*target_base_index=*/{});
  return new_value_set;
}

InstructionValueSet HloDataflowAnalysis::RecomputeSelectValueSet(
    HloInstruction* select) {
  CHECK_EQ(select->opcode(), HloOpcode::kSelect);
  std::vector<const InstructionValueSet*> inputs = {
      &GetInstructionValueSet(select->operand(1)),
      &GetInstructionValueSet(select->operand(2))};
  InstructionValueSet new_value_set =
      ssa_form_ ? Phi(select, inputs, /*skip_top_level=*/true)
                : InstructionValueSet::Union(inputs);
  *new_value_set.mutable_element(/*index=*/{}) =
      GetInstructionValueSet(select).element(/*index=*/{});
  return new_value_set;
}

InstructionValueSet HloDataflowAnalysis::RecomputeTupleValueSet(
    HloInstruction* tuple) {
  CHECK_EQ(tuple->opcode(), HloOpcode::kTuple);
  InstructionValueSet new_value_set(tuple->shape());
  *new_value_set.mutable_element(/*index=*/{}) =
      GetInstructionValueSet(tuple).element(/*index=*/{});
  for (int64 i = 0; i < tuple->operands().size(); ++i) {
    new_value_set.CopySubtreeFrom(GetInstructionValueSet(tuple->operand(i)),
                                  /*source_base_index=*/{},
                                  /*target_base_index=*/{i});
  }
  return new_value_set;
}

InstructionValueSet HloDataflowAnalysis::RecomputeWhileValueSet(
    HloInstruction* xla_while) {
  CHECK_EQ(xla_while->opcode(), HloOpcode::kWhile);
  std::vector<const InstructionValueSet*> inputs = {
      &GetInstructionValueSet(xla_while->while_body()->root_instruction()),
      &GetInstructionValueSet(xla_while->operand(0))};
  if (ssa_form_) {
    return Phi(xla_while, inputs);
  } else {
    return InstructionValueSet::Union(inputs);
  }
}

void HloDataflowAnalysis::UpdateInstructionValueSet(
    HloInstruction* instruction) {
  // Recompute from operands.
  InstructionValueSet& value_set = GetInstructionValueSet(instruction);
  switch (instruction->opcode()) {
    case HloOpcode::kBitcast:
      value_set = RecomputeBitcastValueSet(instruction);
      break;
    case HloOpcode::kCopy:
      value_set = RecomputeCopyValueSet(instruction);
      break;
    case HloOpcode::kGetTupleElement:
      value_set = RecomputeGetTupleElementValueSet(instruction);
      break;
    case HloOpcode::kSelect:
      value_set = RecomputeSelectValueSet(instruction);
      break;
    case HloOpcode::kTuple:
      value_set = RecomputeTupleValueSet(instruction);
      break;
    case HloOpcode::kParameter:
      value_set = RecomputeParameterValueSet(instruction);
      break;
    case HloOpcode::kCall:
      // The output of a kCall instruction is exactly the output of the root of
      // the subcomputation.
      value_set =
          GetInstructionValueSet(instruction->to_apply()->root_instruction());
      break;
    case HloOpcode::kWhile:
      value_set = RecomputeWhileValueSet(instruction);
      break;
    default:
      // Instruction does not forward HloValues (it defines all values in its
      // output). No update is necessary.
      return;
  }
}

void HloDataflowAnalysis::UpdateInstructionsAndPropagate(
    tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
  std::queue<HloInstruction*> worklist;
  for (HloInstruction* instruction : instructions) {
    worklist.push(instruction);
  }

  while (!worklist.empty()) {
    HloInstruction* instruction = worklist.front();
    worklist.pop();

    VLOG(3) << "Worklist top: " << instruction->name();
    VLOG(3) << ToString();

    // Save old value for recomputing uses and live out.
    InstructionValueSet old_value = GetInstructionValueSet(instruction);
    UpdateInstructionValueSet(instruction);

    if (GetInstructionValueSet(instruction) == old_value) {
      // No change to the instruction's value set.
      continue;
    }

    // Instruction value was updated. Add users to work list.
    for (HloInstruction* user : instruction->users()) {
      worklist.push(user);

      // If user calls a computation, then the respective parameter(s) of the
      // computation need to be updated.
      for (HloComputation* called_computation : user->called_computations()) {
        for (int64 operand_number : user->OperandIndices(instruction)) {
          worklist.push(
              called_computation->parameter_instruction(operand_number));
        }
      }
    }

    // If instruction is a root instruction, then propagate out to any calling
    // instruction and across any while backedge.
    if (instruction == instruction->parent()->root_instruction()) {
      const CallGraphNode& call_graph_node =
          call_graph_->GetNode(instruction->parent());
      for (const CallSite& callsite : call_graph_node.caller_callsites()) {
        if (callsite.instruction()->opcode() == HloOpcode::kCall) {
          worklist.push(callsite.instruction());
        } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
          // Add the while itself, and the body and condition parameters.
          worklist.push(callsite.instruction());
          worklist.push(
              callsite.instruction()->while_body()->parameter_instruction(0));
          worklist.push(
              callsite.instruction()->while_condition()->parameter_instruction(
                  0));
        }
      }
    }

    // Update uses. First clear all of the old uses at the particular
    // operands. Then add the new uses. There may be overlap between the old
    // uses and new uses.
    UpdateUsesOfValuesAt(instruction, GetInstructionValueSet(instruction),
                         &old_value);

    // Reset module live-out values.
    if (instruction == module_->entry_computation()->root_instruction()) {
      UpdateLiveOutValues(GetInstructionValueSet(instruction), &old_value);
    }
  }
}

InstructionValueSet HloDataflowAnalysis::RecomputeParameterValueSet(
    HloInstruction* parameter) {
  CHECK_EQ(parameter->opcode(), HloOpcode::kParameter);
  const CallGraphNode& call_graph_node =
      call_graph_->GetNode(parameter->parent());

  // Subcomputations called in a parallel context (eg, map) do not have dataflow
  // from the caller operands.
  if (call_graph_node.context() == CallContext::kParallel ||
      call_graph_node.caller_callsites().empty()) {
    return GetInstructionValueSet(parameter);
  }
  CHECK_EQ(call_graph_node.context(), CallContext::kSequential);

  std::vector<const InstructionValueSet*> inputs;
  for (const CallSite& callsite : call_graph_node.caller_callsites()) {
    inputs.push_back(&GetInstructionValueSet(
        callsite.instruction()->operand(parameter->parameter_number())));
    if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
      // In a while instruction, the backedge is also a dataflow input to the
      // parameter instruction. This code covers the case where the parameter is
      // in the while body or the parameter is in the while condition.
      inputs.push_back(&GetInstructionValueSet(
          callsite.instruction()->while_body()->root_instruction()));
    }
  }

  if (ssa_form_) {
    return Phi(parameter, inputs);
  } else {
    return InstructionValueSet::Union(inputs);
  }
}

const InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
    const HloInstruction* instruction) const {
  return value_sets_.at(instruction);
}

InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
    const HloInstruction* instruction) {
  return value_sets_.at(instruction);
}

Status HloDataflowAnalysis::InitializeInstructionValueSets() {
  for (const std::unique_ptr<HloComputation>& computation :
       module_->computations()) {
    const CallGraphNode& call_graph_node =
        call_graph_->GetNode(computation.get());
    for (const std::unique_ptr<HloInstruction>& instruction :
         computation->instructions()) {
      // Create an empty shape tree.
      value_sets_.emplace(std::piecewise_construct,
                          std::forward_as_tuple(instruction.get()),
                          std::forward_as_tuple(instruction->shape()));

      // Lambda to set the value set to define all values in the output of the
      // instruction.
      auto define_all_values = [this, &instruction]() {
        GetInstructionValueSet(instruction.get())
            .ForEachMutableValueSet([this, &instruction](
                                        const ShapeIndex& index,
                                        HloValueSet* value_set) {
              *value_set = HloValueSet({NewHloValue(instruction.get(), index)});
            });
      };

      // Lambda to set the value set to define only the top-level buffer in the
      // output of the instruction. Any other values flow from the operands of
      // the instruction (or from cross-computation dataflow).
      auto define_top_level_only = [this, &instruction]() {
        GetValueSet(instruction.get(), /*index=*/{}) =
            HloValueSet({NewHloValue(instruction.get(), /*index=*/{})});
      };

      switch (instruction->opcode()) {
        case HloOpcode::kBitcast:
          if (bitcast_defines_value_) {
            define_all_values();
          }
          break;
        case HloOpcode::kCall:
        case HloOpcode::kWhile:
        case HloOpcode::kGetTupleElement:
          // These instructions define no values. The values in their output
          // flow from their operands or from cross computation dataflow.
          break;
        case HloOpcode::kParameter:
          if (call_graph_node.caller_callsites().empty() ||
              call_graph_node.context() == CallContext::kParallel) {
            // Parameters of computations called in a parallel context (eg, map
            // and reduce) as well as parameters of dead computations define all
            // values in their output. Otherwise the values of the parameter
            // come from the caller (eg, operands to the kCall instruction).
            define_all_values();
          } else if (call_graph_node.context() == CallContext::kBoth) {
            // We do not support a subcomputation that is called from both a
            // parallel and sequential context. In this case, the parameter
            // would both define a value and propagate a value from its
            // caller. This limitation is not really a problem because the call
            // graph is typically flattened.
            return Unimplemented(
                "Computation %s is called in both a parallel (eg, kMap) and "
                "sequential (eg, kCall) context",
                computation->name().c_str());
          }
          break;
        case HloOpcode::kCopy:
        case HloOpcode::kSelect:
        case HloOpcode::kTuple:
          // These instructions only define their top-level values. Any other
          // values flow from their operands.
          define_top_level_only();
          break;
        default:
          define_all_values();
          break;
      }
      UpdateUsesOfValuesAt(instruction.get(),
                           GetInstructionValueSet(instruction.get()));
    }
  }
  UpdateLiveOutValues(
      GetInstructionValueSet(module_->entry_computation()->root_instruction()));
  return Status::OK();
}

/* static */
StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
    HloModule* module, bool ssa_form, bool bitcast_defines_value) {
  VLOG(1) << "HloDataflowAnalysis::Run on module " << module->name();
  XLA_VLOG_LINES(2, module->ToString());

  auto dataflow_analysis = WrapUnique(
      new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value));

  TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());

  // Construct list of all instructions to initialize the worklist to propagate
  // the data flow. For efficiency sort the instructions in post order so
  // producers appear before consumers.
  std::vector<HloInstruction*> all_instructions;
  for (const HloComputation* computation : module->MakeComputationPostOrder()) {
    for (HloInstruction* instruction :
         computation->MakeInstructionPostOrder()) {
      all_instructions.push_back(instruction);
    }
  }
  dataflow_analysis->UpdateInstructionsAndPropagate(all_instructions);

  VLOG(1) << dataflow_analysis->ToString();
  return std::move(dataflow_analysis);
}

}  // namespace xla
tensorflow/compiler/xla/service/hlo_dataflow_analysis.h (new file, 408 lines)
@@ -0,0 +1,408 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Analysis for determining the possible set of values for all locations
// (instructions and ShapeIndexes) in the HLO module. Analysis is module-scoped
// tracking values across computation boundaries.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_

#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// Defines a single use of an HLO value.
struct HloUse {
  // Instruction at which the value is used.
  HloInstruction* instruction;

  // The operand number in which the value appears.
  int64 operand_number;

  // The shape index within the operand in which the value appears.
  ShapeIndex operand_index;

  string ToString() const;

  bool operator==(const HloUse& other) const {
    return instruction == other.instruction &&
           operand_number == other.operand_number &&
           operand_index == other.operand_index;
  }

  bool operator!=(const HloUse& other) const { return !(*this == other); }
};

std::ostream& operator<<(std::ostream& out, const HloUse& use);
// Class describing a value used by the dataflow analysis. XLA arrays are
// trivially a single HloValue. Tuples are made up of more than one HloValue: an
// HloValue for the pointer vector, and an HloValue for each child element.
//
// Every HloValue is defined by a particular instruction and most instructions
// define only a single HloValue. Instructions which define a single HloValue
// include array-shaped instructions such as Add but also include Tuple-shaped
// instructions such as Tuple. The Tuple instruction defines a single HloValue
// which is a vector of pointers to the values containing the Tuple
// instruction's operands. Though the result of the Tuple instruction includes
// multiple values only the top-level HloValue (the vector of pointers) is
// defined by the Tuple instruction. The values containing the tuple elements
// are defined by earlier instructions, usually the operands of the Tuple
// instruction.
//
// Instructions which construct both the tuple *and* the tuple elements define
// more than one HloValue. This includes (at least) tuple-shaped Constant,
// Parameter, Infeed and While instructions. These tuple-shaped instructions do
// not assemble a tuple from existing HloValues like the Tuple instruction does,
// but rather define all the HloValues in the tuple.
class HloValue {
 public:
  using Id = int64;

  // Construct an HloValue defined by 'instruction' at shape index 'index'. If
  // is_phi is true, then this value is a phi value, for example, at the
  // parameter of a while body computation or in a select instruction. Phi
  // values are only used in the SSA dataflow analysis
  // (HloDataflowAnalysis::ssa_form_ is true).
  HloValue(HloValue::Id id, HloInstruction* instruction,
           const ShapeIndex& index, bool is_phi = false)
      : id_(id), instruction_(instruction), index_(index), is_phi_(is_phi) {}

  // Return a unique identifier for this HloValue. This value is used for stable
  // sorting and iteration.
  Id id() const { return id_; }

  // Returns whether this value is a phi value.
  bool is_phi() const { return is_phi_; }

  // Return the instruction which defines this HloValue.
  HloInstruction* instruction() const { return instruction_; }

  // Return the shape index at which this HloValue is defined in the output of
  // instruction().
  const ShapeIndex& index() const { return index_; }

  // Add or remove a use of the HloValue at a particular operand of an
  // instruction.
  void AddUse(HloInstruction* instruction, int64 operand_number,
              const ShapeIndex& operand_index);
  void RemoveUse(HloInstruction* instruction, int64 operand_number,
                 const ShapeIndex& operand_index);

  // Return all uses of the HloValue.
  const std::vector<HloUse>& uses() const { return uses_; }

  // Set/get whether this HloValue is live out of the module.
  bool live_out_of_module() const { return live_out_of_module_; }
  void set_live_out_of_module(bool value) { live_out_of_module_ = value; }

  bool operator==(const HloValue& other) const;
  bool operator!=(const HloValue& other) const;

  // Return a single-line string representation of the value.
  string ToShortString() const;

  string ToString(int indent = 0) const;

 private:
  // Unique identifier for this HloValue. Used for stable sorting and iteration.
  const Id id_;

  // The instruction defining this value.
  HloInstruction* const instruction_;

  // Shape index at which this value is defined.
  const ShapeIndex index_;

  // Whether this value is a phi value.
  const bool is_phi_;

  // The set of uses of this HloValue.
  std::vector<HloUse> uses_;

  // Whether this value is live out of the HLO module.
  bool live_out_of_module_ = false;
};

std::ostream& operator<<(std::ostream& out, const HloValue& hlo_value);
// A class representing the possible set of HloValues at a particular point
// (shape index in the output of an instruction) in the XLA graph. This set
// contains the set of reaching HloValue definitions. For a simple array-shaped
// instruction like Add, the HloValueSet of the top-level of the instruction's
// output trivially contains only the HloValue defined by the instruction. For
// instructions which have non-trivial dataflow such as Tuple or Select, the
// HloValueSets of the instruction's output contain one or more HloValues
// defined by the instruction's operands or defined further up in the XLA graph.
class HloValueSet {
 public:
  HloValueSet() = default;

  explicit HloValueSet(tensorflow::gtl::ArraySlice<HloValue::Id> value_ids)
      : value_ids_(value_ids.begin(), value_ids.end()) {
    SortAndUniquifyValues();
  }

  // Return the union of the given HloValueSets.
  static HloValueSet Union(
      tensorflow::gtl::ArraySlice<const HloValueSet*> inputs);

  // Return the vector of the IDs of all HloValues in the set. Values in the
  // vector are unique and sorted.
  const std::vector<HloValue::Id>& value_ids() const { return value_ids_; }

  // Return the unique HLO value in the set. CHECKs if the set does not contain
  // exactly one value.
  HloValue::Id GetUniqueValueId() const {
    CHECK_EQ(value_ids().size(), 1);
    return value_ids()[0];
  }

  bool operator==(const HloValueSet& other) const {
    return value_ids() == other.value_ids();
  }
  bool operator!=(const HloValueSet& other) const { return !(*this == other); }

  string ToString() const;

 private:
  // Sorts value_ids_ and removes duplicates. This should be called after adding
  // any elements to value_ids_.
  void SortAndUniquifyValues();

  // HloValues sorted by HloValue::Id.
  std::vector<HloValue::Id> value_ids_;
};

std::ostream& operator<<(std::ostream& out, const HloValueSet& hlo_value);

// A class collecting the HloValues which might be contained in the output of
// an HLO instruction. For array-shaped instructions, an InstructionValueSet
// trivially holds a single HloValueSet. Tuple-shaped InstructionValueSets
// hold multiple HloValueSets.
class InstructionValueSet : public ShapeTree<HloValueSet> {
 public:
  InstructionValueSet(const Shape& shape) : ShapeTree<HloValueSet>(shape) {}

  // Return the union of the given InstructionValueSets.
  static InstructionValueSet Union(
      tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs);

  // Traverse the shape and call the given function for each HloValueSet
  // contained in the InstructionValueSet at the corresponding ShapeIndex.
  using VisitorFunction = std::function<void(const ShapeIndex& /*index*/,
                                             const HloValueSet& /*value_set*/)>;
  void ForEachValueSet(const VisitorFunction& func) const;

  // Mutable form of ForEachValueSet.
  using MutableVisitorFunction = std::function<void(
      const ShapeIndex& /*index*/, HloValueSet* /*value_set*/)>;
  void ForEachMutableValueSet(const MutableVisitorFunction& func);

  string ToString() const;
};

std::ostream& operator<<(std::ostream& out,
                         const InstructionValueSet& instruction_value_set);
// Analysis which identifies all HLO values and their uses in an HLO module.
class HloDataflowAnalysis {
 public:
  // Run dataflow analysis on the given module. Parameters:
  //
  //   ssa_form : If true then new values are defined at merge points in the XLA
  //     graph. Abusing nomenclature somewhat, we call these "phi values".
  //     Merge points exist at Select instructions, While instructions (formed
  //     by the init value and loop backedge), and subcomputations which are
  //     called via kCall from more than one callsite. The SSA form is minimal
  //     in that a new phi value is defined only if the merge point is reachable
  //     by multiple different values. The SSA form is also in loop-closed form
  //     in that no values defined inside of a loop (while body) are used
  //     outside of the loop. In SSA form every location in the HLO graph
  //     (instruction and ShapeIndex) has a single unique value (a unique
  //     reaching definition).
  //
  //     If ssa_form is false, then merge points do not define new
  //     values. Rather, the HloValueSet for the merge point contains the union
  //     of the merged HloValues. Therefore a location in the HLO graph
  //     (instruction and ShapeIndex) may have more than one value (multiple
  //     reaching definitions).
  //
  //   bitcast_defines_value : If true then the Bitcast HLO instruction defines
  //     a new HLO value in the analysis. If false then Bitcast forwards the
  //     value of its operand.
  static StatusOr<std::unique_ptr<HloDataflowAnalysis>> Run(
      HloModule* module, bool ssa_form = false,
      bool bitcast_defines_value = false);

  // Returns true if 'instruction' defines an HLO value at the given shape index
  // of its output.
  bool ValueIsDefinedAt(const HloInstruction* instruction,
                        const ShapeIndex& index = {}) const;

  // Return the HloValue defined by 'instruction' at the given shape index of
  // its output.
  //
  // Precondition: ValueIsDefinedAt is true for this instruction and index.
  const HloValue& GetValueDefinedAt(const HloInstruction* instruction,
                                    const ShapeIndex& index = {}) const;
  HloValue& GetValueDefinedAt(const HloInstruction* instruction,
                              const ShapeIndex& index = {});

  // Return the InstructionValueSet for the given instruction.
  const InstructionValueSet& GetInstructionValueSet(
      const HloInstruction* instruction) const;
  InstructionValueSet& GetInstructionValueSet(
      const HloInstruction* instruction);

  // Return the HloValueSet for the given instruction at the given index.
  const HloValueSet& GetValueSet(const HloInstruction* instruction,
                                 const ShapeIndex& index = {}) const;
  HloValueSet& GetValueSet(const HloInstruction* instruction,
                           const ShapeIndex& index = {});

  // Return the unique value in the HloValueSet at the given instruction and
  // shape index. CHECKs if the value set does not contain exactly one value.
  const HloValue& GetUniqueValueAt(const HloInstruction* instruction,
                                   const ShapeIndex& index = {}) const {
    return GetValue(GetValueSet(instruction, index).GetUniqueValueId());
  }
  HloValue& GetUniqueValueAt(const HloInstruction* instruction,
                             const ShapeIndex& index = {}) {
    return GetValue(GetValueSet(instruction, index).GetUniqueValueId());
  }

  // Return the HloValue with the given Id.
  const HloValue& GetValue(HloValue::Id value_id) const;
  HloValue& GetValue(HloValue::Id value_id);

  // Return the total number of HloValues.
  int64 value_count() const { return values_.size(); }

  // Return a vector of all HloValues stably sorted by HloValue::Id. This
  // vector is lazily computed. Mutating operations on HloDataflowAnalysis may
  // invalidate the underlying vector requiring recomputation.
  std::vector<const HloValue*>& values() const;

  string ToString() const;

 protected:
  HloDataflowAnalysis(HloModule* module, bool ssa_form,
                      bool bitcast_defines_value = false);

  // Creates a new HloValue defined at the given instruction and shape index and
  // returns its ID.
  HloValue::Id NewHloValue(HloInstruction* instruction, const ShapeIndex& index,
                           bool is_phi = false);

  // Delete the HloValue with the given ID.
  void DeleteHloValue(HloValue::Id value_id);

  // Constructs and initializes the InstructionValueSets of all instructions to
  // contain exactly the HloValues defined by each instruction. These values can
  // then be propagated throughout the HLO graph by calling
  // UpdateInstructionsAndPropagate.
  Status InitializeInstructionValueSets();

  // Updates the value set of the given instruction based on the values flowing
  // into the instruction (operands and cross-computation dataflow).
  void UpdateInstructionValueSet(HloInstruction* instruction);

  // Recomputes and returns the value set for the given instruction.
  InstructionValueSet RecomputeBitcastValueSet(HloInstruction* bitcast);
  InstructionValueSet RecomputeCopyValueSet(HloInstruction* copy);
  InstructionValueSet RecomputeGetTupleElementValueSet(HloInstruction* gte);
  InstructionValueSet RecomputeParameterValueSet(HloInstruction* parameter);
  InstructionValueSet RecomputeSelectValueSet(HloInstruction* select);
  InstructionValueSet RecomputeTupleValueSet(HloInstruction* tuple);
  InstructionValueSet RecomputeWhileValueSet(HloInstruction* xla_while);

  // Update the value sets of the given instructions and propagate the
  // changes to fixed point.
  void UpdateInstructionsAndPropagate(
      tensorflow::gtl::ArraySlice<HloInstruction*> instructions);

  // Return the result of the SSA Phi function applied to the given inputs at
  // the given instruction. If skip_top_level is true, then the top level of the
  // value set of 'instruction' is not modified.
  InstructionValueSet Phi(
      HloInstruction* instruction,
      tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs,
      bool skip_top_level = false);

  // Updates the HloUses of the HloValues contained in the output of the given
  // instruction at all of the users of 'instruction'. This should be called
  // after the instruction value set of 'instruction' has been
  // changed. 'prev_value_set' must point to the previous state of the value set
  // prior to the change. 'prev_value_set' may be null if this is the first time
  // uses are being computed. The previous state is necessary to efficiently
  // remove uses which have been eliminated due to changes in the instructions'
  // InstructionValueSet.
  void UpdateUsesOfValuesAt(
      HloInstruction* instruction, const InstructionValueSet& new_value_set,
      const InstructionValueSet* prev_value_set = nullptr);

  // Updates the values live out of the module. This should be called after
  // the instruction value set of the root instruction of the entry computation
  // has been changed. 'prev_root_value_set' should point to the previous
  // InstructionValueSet of the entry root instruction. 'prev_root_value_set'
  // can be nullptr if this is the first time live-out values are being
  // computed.
  void UpdateLiveOutValues(
      const InstructionValueSet& new_root_value_set,
      const InstructionValueSet* prev_root_value_set = nullptr);

  HloModule* const module_;
  const bool ssa_form_;
  const bool bitcast_defines_value_;

  std::unique_ptr<CallGraph> call_graph_;

  // The map of all HloValues in the module.
  std::unordered_map<HloValue::Id, HloValue> values_;

  // A map from instruction to InstructionValueSet.
  std::unordered_map<const HloInstruction*, InstructionValueSet> value_sets_;

  // A lazily constructed vector containing all HloValues sorted by
  // HloValue::Id.
  mutable std::vector<const HloValue*> values_vector_;

  // The Id to use for the next HloValue.
  HloValue::Id next_value_id_ = 0;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_
tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc (new file, 1152 lines; diff not shown because it is too large)
@@ -28,6 +28,17 @@ limitations under the License.

namespace xla {

bool DoesNotUseOperandBuffer(const HloInstruction* operand,
                             const ShapeIndex& index,
                             const HloInstruction* user) {
  CHECK(user->IsUserOf(operand))
      << "user: " << user->ToString() << " operand: " << operand->ToString();

  // GetTupleElement instructions only access the top-level buffer of their
  // operand.
  return (user->opcode() == HloOpcode::kGetTupleElement && !index.empty());
}

bool DoesNotUseOperandBuffer(const HloInstruction* operand,
                             const ShapeIndex& index,
                             const HloInstruction* user,
@@ -18,9 +18,6 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_

#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -37,6 +34,12 @@ bool DoesNotUseOperandBuffer(const HloInstruction* operand,
                             const HloInstruction* user,
                             const TuplePointsToAnalysis& points_to_analysis);

// Overload which does not require points-to analysis. The result is more
// conservative (returns false more often).
bool DoesNotUseOperandBuffer(const HloInstruction* operand,
                             const ShapeIndex& index,
                             const HloInstruction* user);

// Returns true if 'user' (at 'user_index') can share a buffer with its operand
// 'operand' (at 'operand_index').
// Returns false otherwise.
@@ -163,7 +163,8 @@ std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer(
  return TransferFromDevice(result_shape, device_base);
}

string HloTestBase::TestName() const {
/* static */
string HloTestBase::TestName() {
  return ::testing::UnitTest::GetInstance()->current_test_info()->name();
}
@@ -94,7 +94,7 @@ class HloTestBase : public ::testing::Test {
        ->Clear();
  }

  string TestName() const;
  static string TestName();

  std::unique_ptr<Backend> backend_;