[XLA] Add HloLocation to dataflow analysis.
Add an HloLocation abstraction to dataflow analysis which indicates where (in the output of what instruction and at which index) an HloValue may appear. Previously only uses were stored with an HLO value where a use is an edge in the HLO graph (instruction, operand number and ShapeIndex). Also, change the handling of tuple-shaped kSelect instructions when ssa_form is true. Previously a phi value would be created. With this change the the value set instead contains the union of it's inputs identical to the ssa_form=false case. PiperOrigin-RevId: 158276598
This commit is contained in:
parent
b9d5e14419
commit
9f17c26caa
@ -43,6 +43,17 @@ namespace xla {
|
||||
using ::tensorflow::strings::StrAppend;
|
||||
using ::tensorflow::strings::StrCat;
|
||||
|
||||
string HloLocation::ToString() const {
|
||||
string index_str =
|
||||
ShapeUtil::IsTuple(instruction->shape()) ? (" " + index.ToString()) : "";
|
||||
return StrCat(instruction->FullyQualifiedName(), index_str);
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const HloLocation& location) {
|
||||
out << location.ToString();
|
||||
return out;
|
||||
}
|
||||
|
||||
string HloUse::ToString() const {
|
||||
string index_str =
|
||||
ShapeUtil::IsTuple(instruction->operand(operand_number)->shape())
|
||||
@ -57,6 +68,13 @@ std::ostream& operator<<(std::ostream& out, const HloUse& use) {
|
||||
return out;
|
||||
}
|
||||
|
||||
HloValue::HloValue(HloValue::Id id, HloInstruction* instruction,
|
||||
const ShapeIndex& index, bool is_phi)
|
||||
: id_(id), is_phi_(is_phi) {
|
||||
// The defining location is always the first element in the locations_ vector.
|
||||
AddLocation(instruction, index);
|
||||
}
|
||||
|
||||
bool HloValue::operator==(const HloValue& other) const {
|
||||
bool equal = instruction() == other.instruction() && index() == other.index();
|
||||
// If the values are equal they most both be phi (or non phi).
|
||||
@ -70,34 +88,95 @@ bool HloValue::operator!=(const HloValue& other) const {
|
||||
|
||||
string HloValue::ToShortString() const {
|
||||
string index_str =
|
||||
ShapeUtil::IsTuple(instruction_->shape()) ? index_.ToString() : "";
|
||||
return StrCat(is_phi_ ? "PHI " : "", instruction_->FullyQualifiedName(),
|
||||
ShapeUtil::IsTuple(instruction()->shape()) ? index().ToString() : "";
|
||||
return StrCat(is_phi_ ? "PHI " : "", instruction()->FullyQualifiedName(),
|
||||
index_str);
|
||||
}
|
||||
|
||||
string HloValue::ToString(int indent) const {
|
||||
string indentation(indent, ' ');
|
||||
string out = StrCat(indentation, ToShortString(), ", uses:\n");
|
||||
string out = StrCat(indentation, ToShortString(), ", locations:\n");
|
||||
for (const HloLocation& location : locations()) {
|
||||
StrAppend(&out, indentation, " ", location.ToString(), "\n");
|
||||
}
|
||||
StrAppend(&out, indentation, " uses:\n");
|
||||
for (const HloUse& use : uses()) {
|
||||
StrAppend(&out, indentation, " ", use.ToString(), "\n");
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
void HloValue::AddUse(HloInstruction* instruction, int64 operand_number,
|
||||
const ShapeIndex& operand_index) {
|
||||
HloUse use = {instruction, operand_number, operand_index};
|
||||
CHECK(std::find(uses_.begin(), uses_.end(), use) == uses_.end());
|
||||
uses_.push_back(std::move(use));
|
||||
void HloValue::AddLocation(HloInstruction* instruction,
|
||||
const ShapeIndex& index) {
|
||||
// The given location should not already exist in locations_.
|
||||
for (const HloLocation& location : locations_) {
|
||||
DCHECK(!(location.instruction == instruction && location.index == index));
|
||||
}
|
||||
|
||||
locations_.push_back(HloLocation{instruction, index});
|
||||
|
||||
// Update uses.
|
||||
for (HloInstruction* user : instruction->users()) {
|
||||
for (int64 operand_number : user->OperandIndices(instruction)) {
|
||||
if (!DoesNotUseOperandBuffer(instruction, index, user)) {
|
||||
for (const HloUse& use : uses_) {
|
||||
// Verify that this use does not already exist.
|
||||
DCHECK(!(use.instruction == user &&
|
||||
use.operand_number == operand_number &&
|
||||
use.operand_index == index));
|
||||
}
|
||||
|
||||
uses_.push_back(HloUse{user, operand_number, index});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update liveout status of this HloValue.
|
||||
const HloModule& module = *instruction->parent()->parent();
|
||||
if (instruction == module.entry_computation()->root_instruction()) {
|
||||
live_out_of_module_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
void HloValue::RemoveUse(HloInstruction* instruction, int64 operand_number,
|
||||
const ShapeIndex& operand_index) {
|
||||
HloUse use = {instruction, operand_number, operand_index};
|
||||
auto it = std::find(uses_.begin(), uses_.end(), use);
|
||||
CHECK(it != uses_.end());
|
||||
uses_.erase(it);
|
||||
DCHECK(std::find(uses_.begin(), uses_.end(), use) == uses_.end());
|
||||
void HloValue::RemoveLocation(HloInstruction* instruction,
|
||||
const ShapeIndex& index) {
|
||||
// The defining location cannot be removed.
|
||||
CHECK(!(instruction == this->instruction() && index == this->index()));
|
||||
|
||||
int64 size_before = locations_.size();
|
||||
locations_.erase(
|
||||
std::remove_if(locations_.begin(), locations_.end(),
|
||||
[instruction, &index](const HloLocation& location) {
|
||||
return location.instruction == instruction &&
|
||||
location.index == index;
|
||||
}),
|
||||
locations_.end());
|
||||
// Only a single location should have been removed.
|
||||
CHECK_EQ(locations_.size(), size_before - 1);
|
||||
|
||||
// Update uses which referred to this location.
|
||||
uses_.erase(std::remove_if(uses_.begin(), uses_.end(),
|
||||
[instruction, &index](const HloUse& use) {
|
||||
return use.instruction->operand(
|
||||
use.operand_number) == instruction &&
|
||||
use.operand_index == index;
|
||||
}),
|
||||
uses_.end());
|
||||
|
||||
const HloModule& module = *instruction->parent()->parent();
|
||||
if (instruction == module.entry_computation()->root_instruction()) {
|
||||
// Value has been removed from a location in the entry root instruction.
|
||||
// Check if the value is still live out of the module by walking all
|
||||
// remaining locations.
|
||||
live_out_of_module_ = false;
|
||||
for (const HloLocation& location : locations()) {
|
||||
if (location.instruction ==
|
||||
module.entry_computation()->root_instruction()) {
|
||||
live_out_of_module_ = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const HloValue& value) {
|
||||
@ -159,11 +238,11 @@ std::ostream& operator<<(std::ostream& out,
|
||||
|
||||
string InstructionValueSet::ToString() const {
|
||||
string out =
|
||||
StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), ")");
|
||||
ForEachElement(
|
||||
[this, &out](const ShapeIndex& index, const HloValueSet& value_set) {
|
||||
StrAppend(&out, index.ToString(), " : ", value_set.ToString(), "\n");
|
||||
});
|
||||
StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), ")\n");
|
||||
ForEachElement([this, &out](const ShapeIndex& index,
|
||||
const HloValueSet& value_set) {
|
||||
StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n");
|
||||
});
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -315,30 +394,19 @@ InstructionValueSet HloDataflowAnalysis::Phi(
|
||||
*value_set = GetInstructionValueSet(instruction).element(index);
|
||||
return;
|
||||
}
|
||||
// Return the unique value at the current index in the given
|
||||
// InstructionValueSet. Returns null if the value has not yet been
|
||||
// determined.
|
||||
auto unique_value_or_null = [this,
|
||||
&index](const InstructionValueSet& ivset) {
|
||||
const HloValueSet& vset = ivset.element(index);
|
||||
CHECK_LE(vset.value_ids().size(), 1);
|
||||
return vset.value_ids().empty() ? nullptr
|
||||
: &GetValue(vset.GetUniqueValueId());
|
||||
};
|
||||
|
||||
// Save the old value at this index.
|
||||
const HloValue* old_value =
|
||||
unique_value_or_null(GetInstructionValueSet(instruction));
|
||||
bool old_value_is_phi = old_value != nullptr && old_value->is_phi() &&
|
||||
ValueIsDefinedAt(instruction, index);
|
||||
// Identify the existing phi value at this index if it exists.
|
||||
const HloValue* existing_phi_value = nullptr;
|
||||
if (ValueIsDefinedAt(instruction, index) &&
|
||||
GetUniqueValueAt(instruction, index).is_phi()) {
|
||||
existing_phi_value = &GetUniqueValueAt(instruction, index);
|
||||
}
|
||||
|
||||
// Construct a vector of unique value IDs of the inputs.
|
||||
std::vector<HloValue::Id> input_value_ids;
|
||||
for (const InstructionValueSet* input : inputs) {
|
||||
// All values must be unique.
|
||||
const HloValue* input_value = unique_value_or_null(*input);
|
||||
if (input_value != nullptr) {
|
||||
input_value_ids.push_back(input_value->id());
|
||||
for (HloValue::Id value_id : input->element(index).value_ids()) {
|
||||
input_value_ids.push_back(value_id);
|
||||
}
|
||||
}
|
||||
input_value_ids.erase(
|
||||
@ -348,9 +416,9 @@ InstructionValueSet HloDataflowAnalysis::Phi(
|
||||
// Remove the existing phi value (if it exists). The phi can be its own
|
||||
// input, for example, in while body parameters where the body passes
|
||||
// through the parameter value.
|
||||
if (old_value_is_phi) {
|
||||
if (existing_phi_value != nullptr) {
|
||||
auto it = std::find(input_value_ids.begin(), input_value_ids.end(),
|
||||
old_value->id());
|
||||
existing_phi_value->id());
|
||||
if (it != input_value_ids.end()) {
|
||||
input_value_ids.erase(it);
|
||||
}
|
||||
@ -360,20 +428,20 @@ InstructionValueSet HloDataflowAnalysis::Phi(
|
||||
if (input_value_ids.size() == 1) {
|
||||
*value_set = HloValueSet({input_value_ids[0]});
|
||||
}
|
||||
if (old_value_is_phi) {
|
||||
if (existing_phi_value) {
|
||||
// The merge point does not have multiple distinct inputs (which are
|
||||
// not the phi value itself). Therefore there is no need to insert a
|
||||
// phi value because there is a single reaching definition (or no
|
||||
// reaching definition).
|
||||
DeleteHloValue(old_value->id());
|
||||
DeleteHloValue(existing_phi_value->id());
|
||||
}
|
||||
} else if (input_value_ids.size() > 1) {
|
||||
// Multiple distinct values reach this point. A phi value is
|
||||
// necessary.
|
||||
if (old_value_is_phi) {
|
||||
if (existing_phi_value) {
|
||||
// A phi value already exists so reuse it in the new
|
||||
// InstructionValueSet.
|
||||
*value_set = HloValueSet({old_value->id()});
|
||||
*value_set = HloValueSet({existing_phi_value->id()});
|
||||
} else {
|
||||
// Create a new phi value.
|
||||
*value_set =
|
||||
@ -384,61 +452,40 @@ InstructionValueSet HloDataflowAnalysis::Phi(
|
||||
return new_value_set;
|
||||
}
|
||||
|
||||
void HloDataflowAnalysis::UpdateUsesOfValuesAt(
|
||||
void HloDataflowAnalysis::UpdateLocationsOfValuesAt(
|
||||
HloInstruction* instruction, const InstructionValueSet& new_value_set,
|
||||
const InstructionValueSet* prev_value_set) {
|
||||
for (HloInstruction* user : instruction->users()) {
|
||||
for (int64 operand_number : user->OperandIndices(instruction)) {
|
||||
if (prev_value_set != nullptr) {
|
||||
// Remove uses from the old value set.
|
||||
prev_value_set->ForEachElement(
|
||||
[this, instruction, user, operand_number](
|
||||
const ShapeIndex& index, const HloValueSet& value_set) {
|
||||
for (HloValue::Id value_id : value_set.value_ids()) {
|
||||
// HloValues in the previous value set may have been deleted.
|
||||
if (!ContainsKey(values_, value_id)) {
|
||||
continue;
|
||||
}
|
||||
if (!DoesNotUseOperandBuffer(instruction, index, user)) {
|
||||
GetValue(value_id).RemoveUse(user, operand_number, index);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
// Add uses in the new value set.
|
||||
new_value_set.ForEachElement(
|
||||
[this, instruction, user, operand_number](
|
||||
const ShapeIndex& index, const HloValueSet& value_set) {
|
||||
for (HloValue::Id value_id : value_set.value_ids()) {
|
||||
if (!DoesNotUseOperandBuffer(instruction, index, user)) {
|
||||
GetValue(value_id).AddUse(user, operand_number, index);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void HloDataflowAnalysis::UpdateLiveOutValues(
|
||||
const InstructionValueSet& new_root_value_set,
|
||||
const InstructionValueSet* prev_root_value_set) {
|
||||
if (prev_root_value_set != nullptr) {
|
||||
// Clear the old live out set.
|
||||
prev_root_value_set->ForEachElement(
|
||||
[this](const ShapeIndex& index, const HloValueSet& value_set) {
|
||||
if (prev_value_set != nullptr) {
|
||||
// Remove locations from the old value set.
|
||||
prev_value_set->ForEachElement(
|
||||
[this, instruction](const ShapeIndex& index,
|
||||
const HloValueSet& value_set) {
|
||||
for (HloValue::Id value_id : value_set.value_ids()) {
|
||||
// HloValues in the previous value set may have been deleted.
|
||||
if (!ContainsKey(values_, value_id)) {
|
||||
continue;
|
||||
}
|
||||
GetValue(value_id).set_live_out_of_module(false);
|
||||
// Don't remove the defining location of the value.
|
||||
HloValue& value = GetValue(value_id);
|
||||
if (instruction == value.instruction()) {
|
||||
CHECK_EQ(index, value.index());
|
||||
} else {
|
||||
value.RemoveLocation(instruction, index);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
new_root_value_set.ForEachElement(
|
||||
[this](const ShapeIndex& index, const HloValueSet& value_set) {
|
||||
// Add locations in the new value set.
|
||||
new_value_set.ForEachElement(
|
||||
[this, instruction](const ShapeIndex& index,
|
||||
const HloValueSet& value_set) {
|
||||
for (HloValue::Id value_id : value_set.value_ids()) {
|
||||
GetValue(value_id).set_live_out_of_module(true);
|
||||
HloValue& value = GetValue(value_id);
|
||||
if (instruction == value.instruction()) {
|
||||
CHECK_EQ(index, value.index());
|
||||
} else {
|
||||
value.AddLocation(instruction, index);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
@ -483,9 +530,11 @@ InstructionValueSet HloDataflowAnalysis::RecomputeSelectValueSet(
|
||||
std::vector<const InstructionValueSet*> inputs = {
|
||||
&GetInstructionValueSet(select->operand(1)),
|
||||
&GetInstructionValueSet(select->operand(2))};
|
||||
InstructionValueSet new_value_set =
|
||||
ssa_form_ ? Phi(select, inputs, /*skip_top_level=*/true)
|
||||
: InstructionValueSet::Union(inputs);
|
||||
// A phi value is not defined at a kSelect instruction because kSelect does
|
||||
// not create a new value. Rather it forwards a value from its operands. This
|
||||
// contrasts with kWhile instruction (which does define a phi value) which has
|
||||
// in-place update semantics.
|
||||
InstructionValueSet new_value_set = InstructionValueSet::Union(inputs);
|
||||
*new_value_set.mutable_element(/*index=*/{}) =
|
||||
GetInstructionValueSet(select).element(/*index=*/{});
|
||||
return new_value_set;
|
||||
@ -577,9 +626,14 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate(
|
||||
|
||||
if (GetInstructionValueSet(instruction) == old_value) {
|
||||
// No change to the instruction's value set.
|
||||
VLOG(4) << "No change.";
|
||||
continue;
|
||||
}
|
||||
|
||||
VLOG(4) << "New value set for " << instruction->name() << ": "
|
||||
<< GetInstructionValueSet(instruction);
|
||||
VLOG(4) << "Previously: " << old_value;
|
||||
|
||||
// Instruction value was updated. Add users to work list.
|
||||
for (HloInstruction* user : instruction->users()) {
|
||||
worklist.push(user);
|
||||
@ -617,13 +671,8 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate(
|
||||
// Update uses. First clear all of the old uses at the particular
|
||||
// operands. Then add the new uses. There may be overlap between the old
|
||||
// uses and new uses.
|
||||
UpdateUsesOfValuesAt(instruction, GetInstructionValueSet(instruction),
|
||||
&old_value);
|
||||
|
||||
// Reset module live-out values.
|
||||
if (instruction == module_->entry_computation()->root_instruction()) {
|
||||
UpdateLiveOutValues(GetInstructionValueSet(instruction), &old_value);
|
||||
}
|
||||
UpdateLocationsOfValuesAt(instruction, GetInstructionValueSet(instruction),
|
||||
&old_value);
|
||||
}
|
||||
}
|
||||
|
||||
@ -745,12 +794,10 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
|
||||
define_all_values();
|
||||
break;
|
||||
}
|
||||
UpdateUsesOfValuesAt(instruction.get(),
|
||||
GetInstructionValueSet(instruction.get()));
|
||||
UpdateLocationsOfValuesAt(instruction.get(),
|
||||
GetInstructionValueSet(instruction.get()));
|
||||
}
|
||||
}
|
||||
UpdateLiveOutValues(
|
||||
GetInstructionValueSet(module_->entry_computation()->root_instruction()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -43,6 +43,23 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Abstraction which identifies a specific point in the XLA graph. An
|
||||
// HloLocation specifies a ShapeIndex within the output of a specific
|
||||
// instruction.
|
||||
struct HloLocation {
|
||||
HloInstruction* instruction;
|
||||
ShapeIndex index;
|
||||
|
||||
string ToString() const;
|
||||
|
||||
bool operator==(const HloLocation& other) const {
|
||||
return instruction == other.instruction && index == other.index;
|
||||
}
|
||||
bool operator!=(const HloLocation& other) const { return !(*this == other); }
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const HloLocation& location);
|
||||
|
||||
// Defines a single use of an HLO value.
|
||||
struct HloUse {
|
||||
// Instruction at which the value is used.
|
||||
@ -93,12 +110,10 @@ class HloValue {
|
||||
|
||||
// Construct an HloValue defined by 'instruction' at shape index 'index'. If
|
||||
// is_phi is true, then this value is a phi value, for example, at the
|
||||
// parameter of a while body computation or in a select instruction. Phi
|
||||
// values are only used in the SSA dataflow analysis
|
||||
// (HloDataflowAnalysis::ssa_form_ is true).
|
||||
// parameter of a while body computation. Phi values are only used in the SSA
|
||||
// dataflow analysis (HloDataflowAnalysis::ssa_form_ is true).
|
||||
HloValue(HloValue::Id id, HloInstruction* instruction,
|
||||
const ShapeIndex& index, bool is_phi = false)
|
||||
: id_(id), instruction_(instruction), index_(index), is_phi_(is_phi) {}
|
||||
const ShapeIndex& index, bool is_phi = false);
|
||||
|
||||
// Return a unique identifier for this HloValue. This value is used for stable
|
||||
// sorting and iteration
|
||||
@ -107,26 +122,31 @@ class HloValue {
|
||||
// Returns whether this value is a phi value.
|
||||
bool is_phi() const { return is_phi_; }
|
||||
|
||||
// Return the location where this value is defined.
|
||||
const HloLocation& DefinitionLocation() const { return locations_[0]; }
|
||||
|
||||
// Return the instruction which defines this HloValue.
|
||||
HloInstruction* instruction() const { return instruction_; }
|
||||
HloInstruction* instruction() const {
|
||||
return DefinitionLocation().instruction;
|
||||
}
|
||||
|
||||
// Return the shape index at which this HloValue is defined in the output of
|
||||
// instruction().
|
||||
const ShapeIndex& index() const { return index_; }
|
||||
const ShapeIndex& index() const { return DefinitionLocation().index; }
|
||||
|
||||
// Add or remove a use of the HloValue at a particular operand of an
|
||||
// instruction.
|
||||
void AddUse(HloInstruction* instruction, int64 operand_number,
|
||||
const ShapeIndex& operand_index);
|
||||
void RemoveUse(HloInstruction* instruction, int64 operand_number,
|
||||
const ShapeIndex& operand_index);
|
||||
// Add or remove a location at which the HloValue appears. The definition
|
||||
// location can not be removed. The uses of the HloValue are updated.
|
||||
void AddLocation(HloInstruction* instruction, const ShapeIndex& index);
|
||||
void RemoveLocation(HloInstruction* instruction, const ShapeIndex& index);
|
||||
|
||||
// Return all locations of the HloValue in the module.
|
||||
const std::vector<HloLocation>& locations() const { return locations_; }
|
||||
|
||||
// Return all uses of the HloValue.
|
||||
const std::vector<HloUse>& uses() const { return uses_; }
|
||||
|
||||
// Set/get whether this HloValue is live out of the module.
|
||||
bool live_out_of_module() const { return live_out_of_module_; }
|
||||
void set_live_out_of_module(bool value) { live_out_of_module_ = value; }
|
||||
|
||||
bool operator==(const HloValue& other) const;
|
||||
bool operator!=(const HloValue& other) const;
|
||||
@ -140,15 +160,13 @@ class HloValue {
|
||||
// Unique identifier for this HloValue. Used for stable sorting and iteration.
|
||||
const Id id_;
|
||||
|
||||
// The instruction defining this value.
|
||||
HloInstruction* const instruction_;
|
||||
|
||||
// Shape index at which this value is defined.
|
||||
const ShapeIndex index_;
|
||||
|
||||
// Whether this instruction is a phi value.
|
||||
const bool is_phi_;
|
||||
|
||||
// The set of locations of this HloValue. The first element is always the
|
||||
// location of the definition.
|
||||
std::vector<HloLocation> locations_;
|
||||
|
||||
// The set of uses of this HloValue.
|
||||
std::vector<HloUse> uses_;
|
||||
|
||||
@ -233,15 +251,14 @@ class HloDataflowAnalysis {
|
||||
//
|
||||
// ssa_form : If true then new values are defined at merge points in the XLA
|
||||
// graph. Abusing nomenclature somewhat, we call these "phi values".
|
||||
// Merge points exist at Select instructions, While instructions (formed
|
||||
// by the init value and loop backedge), and subcomputations which are
|
||||
// called via kCall from more than one callsite. The SSA form is minimal
|
||||
// in that a new phi value is defined only if the merge point is reachable
|
||||
// by multiple different values. The SSA form is also in loop-closed form
|
||||
// in that no values defined inside of a loop (while body) is used outside
|
||||
// of the loop. In SSA form every location in the HLO graph (instruction
|
||||
// and ShapeIndex) has a single unique value (a unique reaching
|
||||
// definition).
|
||||
// Merge points exist at While instructions (formed by the init value and
|
||||
// loop backedge), and subcomputations which are called via kCall from
|
||||
// more than one callsite. The SSA form is minimal in that a new phi value
|
||||
// is defined only if the merge point is reachable by multiple different
|
||||
// values. The SSA form is also in loop-closed form in that no values
|
||||
// defined inside of a loop (while body) is used outside of the loop. In
|
||||
// SSA form every location in the HLO graph (instruction and ShapeIndex)
|
||||
// has a single unique value (a unique reaching definition).
|
||||
//
|
||||
// If ssa_form is false, then merge points do not define new
|
||||
// values. Rather, the HloValueSet for the merge point contains the union
|
||||
@ -351,27 +368,17 @@ class HloDataflowAnalysis {
|
||||
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs,
|
||||
bool skip_top_level = false);
|
||||
|
||||
// Updates the HloUses of the HloValues contained in the output of the given
|
||||
// instruction at all of the users of 'instruction'. This should be called
|
||||
// after the instruction value set of 'instruction' has been
|
||||
// changed. 'prev_value_set' must point to the previous state of the value set
|
||||
// prior to the change. 'prev_value_set' may be null if this is the first time
|
||||
// uses are being computed. The previous state is necessary to efficiently
|
||||
// remove uses which have been eliminated due to changes in the instructions'
|
||||
// InstructionValueSet.
|
||||
void UpdateUsesOfValuesAt(
|
||||
// Updates the locations of the HloValues in the output of the given
|
||||
// instruction. This should be called after the instruction value set of
|
||||
// 'instruction' has been changed. 'prev_value_set' must point to the previous
|
||||
// state of the value set prior to the change. 'prev_value_set' may be null if
|
||||
// this is the first time locations are being computed. The previous state is
|
||||
// necessary to efficiently remove locations which have been eliminated due to
|
||||
// changes in the instructions' InstructionValueSet.
|
||||
void UpdateLocationsOfValuesAt(
|
||||
HloInstruction* instruction, const InstructionValueSet& new_value_set,
|
||||
const InstructionValueSet* prev_value_set = nullptr);
|
||||
|
||||
// Updates the values live out of the module. This should be called after
|
||||
// the instruction value set of the root instruction of the entry computation
|
||||
// has been changed. 'prev_root_value_set' should point to the previous
|
||||
// InstructionValueSet of the entry root instruction. 'prev_root_set' can be
|
||||
// nullptr if this is the first time live-out values are being computed.
|
||||
void UpdateLiveOutValues(
|
||||
const InstructionValueSet& new_root_value_set,
|
||||
const InstructionValueSet* prev_root_value_set = nullptr);
|
||||
|
||||
HloModule* const module_;
|
||||
const bool ssa_form_;
|
||||
const bool bitcast_defines_value_;
|
||||
|
@ -89,6 +89,15 @@ TEST_P(HloDataflowAnalysisTest, BinaryOperation) {
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
|
||||
|
||||
// Verify the locations of the values. These locations are all trivial because
|
||||
// there are no instructions which forward values.
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).locations(),
|
||||
UnorderedElementsAre(HloLocation{constant1, {}}));
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(constant2).locations(),
|
||||
UnorderedElementsAre(HloLocation{constant2, {}}));
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(add).locations(),
|
||||
UnorderedElementsAre(HloLocation{add, {}}));
|
||||
|
||||
// Verify the uses of the values.
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
|
||||
UnorderedElementsAre(HloUse{add, 0, {}}));
|
||||
@ -134,6 +143,18 @@ TEST_P(HloDataflowAnalysisTest, TupleAndGtes) {
|
||||
EXPECT_FALSE(analysis.ValueIsDefinedAt(gte1));
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
|
||||
|
||||
// Verify the locations of the values.
|
||||
EXPECT_THAT(
|
||||
analysis.GetValueDefinedAt(param0).locations(),
|
||||
UnorderedElementsAre(HloLocation{param0, {}}, HloLocation{tuple, {0}},
|
||||
HloLocation{gte0, {}}));
|
||||
EXPECT_THAT(
|
||||
analysis.GetValueDefinedAt(param1).locations(),
|
||||
UnorderedElementsAre(HloLocation{param1, {}}, HloLocation{tuple, {1}},
|
||||
HloLocation{gte1, {}}));
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(tuple).locations(),
|
||||
UnorderedElementsAre(HloLocation{tuple, {}}));
|
||||
|
||||
// Verify uses. Of interest is that a GetTupleElement instruction is only a
|
||||
// use of the top-level value in the tuple operand.
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(param0).uses(),
|
||||
@ -173,6 +194,14 @@ TEST_P(HloDataflowAnalysisTest, NestedTuple) {
|
||||
|
||||
EXPECT_EQ(analysis.values().size(), 4);
|
||||
|
||||
// Verify locations and uses.
|
||||
EXPECT_THAT(
|
||||
analysis.GetValueDefinedAt(constant1).locations(),
|
||||
UnorderedElementsAre(
|
||||
HloLocation{constant1, {}}, HloLocation{tuple, {0}},
|
||||
HloLocation{nested_tuple, {0, 0}}, HloLocation{nested_tuple, {1, 0}},
|
||||
HloLocation{nested_tuple, {2}}, HloLocation{gte_tuple, {0}},
|
||||
HloLocation{gte_out, {}}));
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
|
||||
UnorderedElementsAre(
|
||||
HloUse{tuple, 0, {}}, HloUse{nested_tuple, 0, {0}},
|
||||
@ -860,30 +889,6 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) {
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(select34));
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(select1234));
|
||||
|
||||
if (ssa_form) {
|
||||
// All non-top-level elements should be phi instructions except for
|
||||
// %select11 which selects between the same values.
|
||||
EXPECT_FALSE(analysis.ValueIsDefinedAt(select11, /*index=*/{0}));
|
||||
EXPECT_THAT(HloValuesAt(select11, /*index=*/{0}),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(constant1)));
|
||||
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(select12, /*index=*/{0}));
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(select12, /*index=*/{0}).is_phi());
|
||||
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(select34, /*index=*/{0}));
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(select34, /*index=*/{0}).is_phi());
|
||||
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(select1234, /*index=*/{0}));
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(select1234, /*index=*/{0}).is_phi());
|
||||
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
|
||||
UnorderedElementsAre(
|
||||
HloUse{tuple1, 0, {}}, HloUse{select11, 1, {0}},
|
||||
HloUse{select11, 2, {0}}, HloUse{select12, 1, {0}}));
|
||||
EXPECT_THAT(
|
||||
analysis.GetValueDefinedAt(constant2).uses(),
|
||||
UnorderedElementsAre(HloUse{tuple2, 0, {}}, HloUse{select12, 2, {0}}));
|
||||
} else {
|
||||
EXPECT_FALSE(analysis.ValueIsDefinedAt(select11, /*index=*/{0}));
|
||||
EXPECT_FALSE(analysis.ValueIsDefinedAt(select12, /*index=*/{0}));
|
||||
EXPECT_FALSE(analysis.ValueIsDefinedAt(select34, /*index=*/{0}));
|
||||
@ -912,7 +917,6 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) {
|
||||
analysis.GetValueDefinedAt(constant2).uses(),
|
||||
UnorderedElementsAre(HloUse{tuple2, 0, {}}, HloUse{select12, 2, {0}},
|
||||
HloUse{select1234, 1, {0}}));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) {
|
||||
@ -948,14 +952,6 @@ TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) {
|
||||
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(select));
|
||||
|
||||
if (ssa_form) {
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(select, /*index=*/{0}));
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(select, /*index=*/{1}));
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(select, /*index=*/{1, 0}));
|
||||
EXPECT_FALSE(analysis.ValueIsDefinedAt(select, /*index=*/{1, 1}));
|
||||
EXPECT_THAT(HloValuesAt(select, /*index=*/{1, 1}),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(constant3)));
|
||||
} else {
|
||||
EXPECT_THAT(HloValuesAt(select, /*index=*/{0}),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(constant1),
|
||||
analysis.GetValueDefinedAt(constant4)));
|
||||
@ -967,7 +963,6 @@ TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) {
|
||||
analysis.GetValueDefinedAt(constant5)));
|
||||
EXPECT_THAT(HloValuesAt(select, /*index=*/{1, 1}),
|
||||
UnorderedElementsAre(analysis.GetValueDefinedAt(constant3)));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) {
|
||||
@ -1045,19 +1040,16 @@ TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) {
|
||||
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
|
||||
|
||||
if (ssa_form) {
|
||||
EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
|
||||
EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{0}),
|
||||
UnorderedElementsAre(
|
||||
analysis.GetValueDefinedAt(select, /*index=*/{0})));
|
||||
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}).is_phi());
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi());
|
||||
|
||||
EXPECT_FALSE(analysis.ValueIsDefinedAt(select, /*index=*/{0}));
|
||||
|
||||
EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
|
||||
EXPECT_FALSE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
|
||||
EXPECT_FALSE(analysis.GetValueDefinedAt(constant3).live_out_of_module());
|
||||
EXPECT_TRUE(
|
||||
analysis.GetValueDefinedAt(select, /*index=*/{0}).live_out_of_module());
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
|
||||
.live_out_of_module());
|
||||
} else {
|
||||
|
Loading…
x
Reference in New Issue
Block a user