Rename HloLocation to HloPosition, to avoid ambiguity with MemoryLocation.

PiperOrigin-RevId: 161716528
This commit is contained in:
A. Unique TensorFlower 2017-07-12 14:18:55 -07:00 committed by TensorFlower Gardener
parent 8e7f573716
commit 6b28eb0843
10 changed files with 162 additions and 162 deletions

View File

@ -94,13 +94,13 @@ void HloAliasAnalysis::CombineBuffers(
unified_buffer.AddValue(dataflow_analysis_->GetValue(value_id));
}
// Iterate through all locations where the buffer-to-eliminate exists and
// Iterate through all positions where the buffer-to-eliminate exists and
// replace it with the unified buffer.
for (const HloLocation& location : buffer.locations()) {
VLOG(4) << "Replacing in " << location;
GetBufferSet(location.instruction, location.index)
for (const HloPosition& position : buffer.positions()) {
VLOG(4) << "Replacing in " << position;
GetBufferSet(position.instruction, position.index)
.RemoveBufferOrDie(buffer_id);
GetBufferSet(location.instruction, location.index)
GetBufferSet(position.instruction, position.index)
.AddBuffer(unified_buffer.id());
}

View File

@ -45,7 +45,7 @@ class HloAliasAnalysis {
InstructionBufferSet& GetInstructionBufferSet(
const HloInstruction* instruction);
// Return the HloBufferSet for the given location.
// Return the HloBufferSet for the given position.
const HloBufferSet& GetBufferSet(const HloInstruction* instruction,
const ShapeIndex& index = {}) const;
HloBufferSet& GetBufferSet(const HloInstruction* instruction,
@ -59,8 +59,8 @@ class HloAliasAnalysis {
return buffers_.at(buffer_id);
}
// Returns the unique buffer at the given location. CHECK fails if the buffer
// set at that location does not contain exactly one buffer.
// Returns the unique buffer at the given position. CHECK fails if the buffer
// set at that position does not contain exactly one buffer.
const HloBuffer& GetUniqueBufferAt(const HloInstruction* instruction,
const ShapeIndex& index = {}) const {
return GetBuffer(GetBufferSet(instruction, index).GetUniqueBufferId());

View File

@ -46,7 +46,7 @@ class HloAliasAnalysisTest : public HloTestBase {
return *analysis_;
}
// Return a vector of the buffers in the buffer set at the current location.
// Return a vector of the buffers in the buffer set at the current position.
std::vector<HloBuffer> GetBuffersAt(const HloInstruction* instruction,
const ShapeIndex& index = {}) const {
std::vector<HloBuffer> buffers;
@ -66,7 +66,7 @@ class HloAliasAnalysisTest : public HloTestBase {
return values;
}
// Return the HloValue defined at the given location.
// Return the HloValue defined at the given position.
const HloValue& GetValueDefinedAt(const HloInstruction* instruction,
const ShapeIndex& index = {}) const {
return analysis_->dataflow_analysis().GetValueDefinedAt(instruction, index);
@ -174,11 +174,11 @@ TEST_F(HloAliasAnalysisTest, TupleAndGtes) {
EXPECT_EQ(analysis.GetUniqueBufferAt(param0),
analysis.GetUniqueBufferAt(gte0));
// Verify the locations of an aliased buffer.
// Verify the positions of an aliased buffer.
EXPECT_THAT(
analysis.GetUniqueBufferAt(param0).locations(),
UnorderedElementsAre(HloLocation{param0, {}}, HloLocation{tuple, {0}},
HloLocation{gte0, {}}));
analysis.GetUniqueBufferAt(param0).positions(),
UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}},
HloPosition{gte0, {}}));
EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsAmbiguous());
EXPECT_TRUE(analysis.GetInstructionBufferSet(tuple).IsDistinct());
@ -201,9 +201,9 @@ TEST_F(HloAliasAnalysisTest, NondistinctTuple) {
const HloAliasAnalysis& analysis = RunAnalysis();
EXPECT_THAT(
analysis.GetUniqueBufferAt(param0).locations(),
UnorderedElementsAre(HloLocation{param0, {}}, HloLocation{tuple, {0}},
HloLocation{tuple, {2}}));
analysis.GetUniqueBufferAt(param0).positions(),
UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}},
HloPosition{tuple, {2}}));
EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsAmbiguous());
EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsDistinct());
@ -236,17 +236,17 @@ TEST_F(HloAliasAnalysisTest, SingleCall) {
const HloAliasAnalysis& analysis = RunAnalysis();
// Verify aliasing of the kCall operands and the subcomputation parameters.
EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).locations(),
UnorderedElementsAre(HloLocation{constant1, {}},
HloLocation{subparam0, {}}));
EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).locations(),
UnorderedElementsAre(HloLocation{constant2, {}},
HloLocation{subparam1, {}}));
EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).positions(),
UnorderedElementsAre(HloPosition{constant1, {}},
HloPosition{subparam0, {}}));
EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).positions(),
UnorderedElementsAre(HloPosition{constant2, {}},
HloPosition{subparam1, {}}));
// The subcomputation root and the kCall itself should alias.
EXPECT_THAT(
analysis.GetUniqueBufferAt(add).locations(),
UnorderedElementsAre(HloLocation{add, {}}, HloLocation{call, {}}));
analysis.GetUniqueBufferAt(add).positions(),
UnorderedElementsAre(HloPosition{add, {}}, HloPosition{call, {}}));
EXPECT_FALSE(AnyValuesInSameBufferInterfere());
}
@ -276,20 +276,20 @@ TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) {
const HloAliasAnalysis& analysis = RunAnalysis();
EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).locations(),
UnorderedElementsAre(HloLocation{constant1, {}},
HloLocation{subparam0, {}}));
EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).locations(),
UnorderedElementsAre(HloLocation{constant2, {}},
HloLocation{subparam1, {}}));
EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).positions(),
UnorderedElementsAre(HloPosition{constant1, {}},
HloPosition{subparam0, {}}));
EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).positions(),
UnorderedElementsAre(HloPosition{constant2, {}},
HloPosition{subparam1, {}}));
// The 'add' (root of the subcomputation) aliases the two call instruction,
// and the first parameter of the subcomputation because 'call1' it is passed
// as an argument to the subcomputation in 'call2'.
EXPECT_THAT(
analysis.GetUniqueBufferAt(add).locations(),
UnorderedElementsAre(HloLocation{add, {}}, HloLocation{call1, {}},
HloLocation{subparam0, {}}, HloLocation{call2, {}}));
analysis.GetUniqueBufferAt(add).positions(),
UnorderedElementsAre(HloPosition{add, {}}, HloPosition{call1, {}},
HloPosition{subparam0, {}}, HloPosition{call2, {}}));
EXPECT_THAT(GetBuffersAt(subparam0),
UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1),
@ -361,24 +361,24 @@ TEST_F(HloAliasAnalysisTest, SingleWhile) {
const HloAliasAnalysis& analysis = RunAnalysis();
// Verify the locations of the aliased while buffers.
EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{}).locations(),
// Verify the positions of the aliased while buffers.
EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{}).positions(),
UnorderedElementsAre(
HloLocation{tuple, {}}, HloLocation{xla_while, {}},
HloLocation{body_param, {}}, HloLocation{body_tuple, {}},
HloLocation{cond_param, {}}));
EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}).locations(),
HloPosition{tuple, {}}, HloPosition{xla_while, {}},
HloPosition{body_param, {}}, HloPosition{body_tuple, {}},
HloPosition{cond_param, {}}));
EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}).positions(),
UnorderedElementsAre(
HloLocation{constant1, {}}, HloLocation{tuple, {0}},
HloLocation{xla_while, {0}}, HloLocation{body_param, {0}},
HloLocation{body_element_0, {}}, HloLocation{body_tuple, {0}},
HloLocation{cond_param, {0}}));
EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).locations(),
HloPosition{constant1, {}}, HloPosition{tuple, {0}},
HloPosition{xla_while, {0}}, HloPosition{body_param, {0}},
HloPosition{body_element_0, {}}, HloPosition{body_tuple, {0}},
HloPosition{cond_param, {0}}));
EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).positions(),
UnorderedElementsAre(
HloLocation{constant2, {}}, HloLocation{tuple, {1}},
HloLocation{xla_while, {1}}, HloLocation{body_param, {1}},
HloLocation{body_element_1, {}}, HloLocation{add, {}},
HloLocation{body_tuple, {1}}, HloLocation{cond_param, {1}}));
HloPosition{constant2, {}}, HloPosition{tuple, {1}},
HloPosition{xla_while, {1}}, HloPosition{body_param, {1}},
HloPosition{body_element_1, {}}, HloPosition{add, {}},
HloPosition{body_tuple, {1}}, HloPosition{cond_param, {1}}));
EXPECT_THAT(
GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0})),
@ -619,7 +619,7 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) {
const HloAliasAnalysis& analysis = RunAnalysis();
// The swizzling while makes most locations in the module alias leaving only 3
// The swizzling while makes most positions in the module alias leaving only 3
// HloBuffers.
EXPECT_THAT(
analysis.buffers(),

View File

@ -46,11 +46,11 @@ void HloBuffer::AddValue(const HloValue& value) {
value_ids_.push_back(value.id());
// Add all of the locations of the HloValue to this buffer.
for (const HloLocation& location : value.locations()) {
if (std::find(locations_.begin(), locations_.end(), location) ==
locations_.end()) {
locations_.push_back(location);
// Add all of the positions of the HloValue to this buffer.
for (const HloPosition& position : value.positions()) {
if (std::find(positions_.begin(), positions_.end(), position) ==
positions_.end()) {
positions_.push_back(position);
}
}
}
@ -60,7 +60,7 @@ bool HloBuffer::operator==(const HloBuffer& other) const {
if (equal) {
// DCHECK because these comparisons are expensive (linear time).
DCHECK(value_ids() == other.value_ids());
DCHECK(locations() == other.locations());
DCHECK(positions() == other.positions());
}
return equal;
}

View File

@ -33,7 +33,7 @@ namespace xla {
// from. Generally there is a one-to-one correspondence between HloBuffers and
// HloValue where each HloValue in the module is held in a unique HloBuffer. An
// exception is the while instruction which updates the loop state in-place. In
// this case, we have a single HloBuffer for each HloLocation in the loop state,
// this case, we have a single HloBuffer for each HloPosition in the loop state,
// but multiple HloValues. For example:
//
// %init = ...
@ -53,7 +53,7 @@ namespace xla {
// HloValue{%while}, HloValue{%body_param}, HloValue{%body_root}, and
// HloValue{%cond_param}.
//
// HloBuffers may appear at different HloLocations in the module mirroring the
// HloBuffers may appear at different HloPositions in the module mirroring the
// same propery of HloValues. For example:
//
// %sub = Sub(...)
@ -62,11 +62,11 @@ namespace xla {
// %gte = GetTupleElement(%tuple, 0)
//
// In this case, the HloBuffer containing %add appears at the following
// locations: HloLocation{%add, {}}, HloLocation{%tuple, {0}}, and
// HloLocation{%gte, {}}.
// positions: HloPosition{%add, {}}, HloPosition{%tuple, {0}}, and
// HloPosition{%gte, {}}.
//
// Different HloLocations which share the same HloBuffer indicate mandatory
// aliasing in the HLO module. These locations must share the same memory
// Different HloPositions which share the same HloBuffer indicate mandatory
// aliasing in the HLO module. These positions must share the same memory
// allocation for correctness (the backends rely on this property). This differs
// from incidental aliasing introduced by memory reuse in BufferAssignment where
// different instructions may happen to get the same allocation.
@ -80,17 +80,17 @@ class HloBuffer {
Id id() const { return id_; }
// Add a value to the set of values held by this buffer. Also adds the
// HloLocations of the value to the locations vector of the buffer. If the
// HloPositions of the value to the positions vector of the buffer. If the
// buffer already contains this value, then this method is a nop.
void AddValue(const HloValue& value);
// Return the IDs of all values contained in this buffer.
const std::vector<HloValue::Id>& value_ids() const { return value_ids_; }
// Return the locations (output of which instruction and at what index) where
// the buffer is used. This is exactly the union of the locations of the
// Return the positions (output of which instruction and at what index) where
// the buffer is used. This is exactly the union of the positions of the
// HloValues contained by the buffer.
const std::vector<HloLocation>& locations() const { return locations_; }
const std::vector<HloPosition>& positions() const { return positions_; }
string ToString() const;
@ -104,16 +104,16 @@ class HloBuffer {
// The set of values contained in this buffer.
std::vector<HloValue::Id> value_ids_;
// The set of locations where this buffer is used.
std::vector<HloLocation> locations_;
// The set of positions where this buffer is used.
std::vector<HloPosition> positions_;
};
std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer);
// A class representing the set of possible HloBuffers at a particular
// HloLocation (shape index in the output of an instruction) in the XLA
// HloPosition (shape index in the output of an instruction) in the XLA
// graph. In most cases, the buffer set will have a single HloBuffer indicating
// that the HloBuffer which appears at that particular location is known
// that the HloBuffer which appears at that particular position is known
// unambiguously at compile-time. However, tuple-shaped Select instructions can
// introduce ambiguity as the tuple elements of the operands are passed by
// reference into the output of the Select. For example:
@ -123,7 +123,7 @@ std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer);
// %tuple1 = Tuple(%x, %y)
// %select = Select(%pred, %tuple0, %tuple1)
//
// In this case the HloBufferSet at HloLocation{%select, {0}} contains the
// In this case the HloBufferSet at HloPosition{%select, {0}} contains the
// HloBuffer holding %a and the HloBuffer holding %x.
class HloBufferSet {
public:

View File

@ -247,11 +247,11 @@ InstructionValueSet HloDataflowAnalysis::Phi(
return new_value_set;
}
void HloDataflowAnalysis::UpdateLocationsOfValuesAt(
void HloDataflowAnalysis::UpdatePositionsOfValuesAt(
HloInstruction* instruction, const InstructionValueSet& new_value_set,
const InstructionValueSet* prev_value_set) {
if (prev_value_set != nullptr) {
// Remove locations from the old value set.
// Remove positions from the old value set.
prev_value_set->ForEachElement(
[this, instruction](const ShapeIndex& index,
const HloValueSet& value_set) {
@ -260,17 +260,17 @@ void HloDataflowAnalysis::UpdateLocationsOfValuesAt(
if (!ContainsKey(values_, value_id)) {
continue;
}
// Don't remove the defining location of the value.
// Don't remove the defining position of the value.
HloValue& value = GetValue(value_id);
if (instruction == value.defining_instruction()) {
CHECK_EQ(index, value.defining_index());
} else {
value.RemoveLocation(instruction, index);
value.RemovePosition(instruction, index);
}
}
});
}
// Add locations in the new value set.
// Add positions in the new value set.
new_value_set.ForEachElement(
[this, instruction](const ShapeIndex& index,
const HloValueSet& value_set) {
@ -279,7 +279,7 @@ void HloDataflowAnalysis::UpdateLocationsOfValuesAt(
if (instruction == value.defining_instruction()) {
CHECK_EQ(index, value.defining_index());
} else {
value.AddLocation(instruction, index);
value.AddPosition(instruction, index);
}
}
});
@ -466,7 +466,7 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate(
// Update uses. First clear all of the old uses at the particular
// operands. Then add the new uses. There may be overlap between the old
// uses and new uses.
UpdateLocationsOfValuesAt(instruction, GetInstructionValueSet(instruction),
UpdatePositionsOfValuesAt(instruction, GetInstructionValueSet(instruction),
&old_value);
}
}
@ -600,7 +600,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
define_all_values();
break;
}
UpdateLocationsOfValuesAt(instruction.get(),
UpdatePositionsOfValuesAt(instruction.get(),
GetInstructionValueSet(instruction.get()));
}
}

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Analysis for determining the possible set of values for all locations
// Analysis for determining the possible set of values for all positions
// (instructions and ShapeIndexes) in the HLO module. Analysis is module-scoped
// tracking values across computation boundaries.
@ -170,14 +170,14 @@ class HloDataflowAnalysis {
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs,
bool skip_top_level = false);
// Updates the locations of the HloValues in the output of the given
// Updates the positions of the HloValues in the output of the given
// instruction. This should be called after the instruction value set of
// 'instruction' has been changed. 'prev_value_set' must point to the previous
// state of the value set prior to the change. 'prev_value_set' may be null if
// this is the first time locations are being computed. The previous state is
// necessary to efficiently remove locations which have been eliminated due to
// this is the first time positions are being computed. The previous state is
// necessary to efficiently remove positions which have been eliminated due to
// changes in the instructions' InstructionValueSet.
void UpdateLocationsOfValuesAt(
void UpdatePositionsOfValuesAt(
HloInstruction* instruction, const InstructionValueSet& new_value_set,
const InstructionValueSet* prev_value_set = nullptr);

View File

@ -51,7 +51,7 @@ class HloDataflowAnalysisTest : public HloTestBase,
return *analysis_;
}
// Return a vector of the HloValues at the given program location.
// Return a vector of the HloValues at the given program position.
std::vector<HloValue> HloValuesAt(const HloInstruction* instruction,
const ShapeIndex& index = {}) {
CHECK(analysis_ != nullptr);
@ -101,14 +101,14 @@ TEST_P(HloDataflowAnalysisTest, BinaryOperation) {
EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
// Verify the locations of the values. These locations are all trivial because
// Verify the positions of the values. These positions are all trivial because
// there are no instructions which forward values.
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).locations(),
UnorderedElementsAre(HloLocation{constant1, {}}));
EXPECT_THAT(analysis.GetValueDefinedAt(constant2).locations(),
UnorderedElementsAre(HloLocation{constant2, {}}));
EXPECT_THAT(analysis.GetValueDefinedAt(add).locations(),
UnorderedElementsAre(HloLocation{add, {}}));
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).positions(),
UnorderedElementsAre(HloPosition{constant1, {}}));
EXPECT_THAT(analysis.GetValueDefinedAt(constant2).positions(),
UnorderedElementsAre(HloPosition{constant2, {}}));
EXPECT_THAT(analysis.GetValueDefinedAt(add).positions(),
UnorderedElementsAre(HloPosition{add, {}}));
// Verify the uses of the values.
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
@ -155,17 +155,17 @@ TEST_P(HloDataflowAnalysisTest, TupleAndGtes) {
EXPECT_FALSE(analysis.ValueIsDefinedAt(gte1));
EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
// Verify the locations of the values.
// Verify the positions of the values.
EXPECT_THAT(
analysis.GetValueDefinedAt(param0).locations(),
UnorderedElementsAre(HloLocation{param0, {}}, HloLocation{tuple, {0}},
HloLocation{gte0, {}}));
analysis.GetValueDefinedAt(param0).positions(),
UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}},
HloPosition{gte0, {}}));
EXPECT_THAT(
analysis.GetValueDefinedAt(param1).locations(),
UnorderedElementsAre(HloLocation{param1, {}}, HloLocation{tuple, {1}},
HloLocation{gte1, {}}));
EXPECT_THAT(analysis.GetValueDefinedAt(tuple).locations(),
UnorderedElementsAre(HloLocation{tuple, {}}));
analysis.GetValueDefinedAt(param1).positions(),
UnorderedElementsAre(HloPosition{param1, {}}, HloPosition{tuple, {1}},
HloPosition{gte1, {}}));
EXPECT_THAT(analysis.GetValueDefinedAt(tuple).positions(),
UnorderedElementsAre(HloPosition{tuple, {}}));
// Verify uses. Of interest is that a GetTupleElement instruction is only a
// use of the top-level value in the tuple operand.
@ -200,15 +200,15 @@ TEST_P(HloDataflowAnalysisTest, NestedTuple) {
EXPECT_EQ(analysis.values().size(), 4);
// Verify locations and uses.
// Verify positions and uses.
EXPECT_THAT(
analysis.GetValueDefinedAt(constant1).locations(),
analysis.GetValueDefinedAt(constant1).positions(),
UnorderedElementsAre(
HloLocation{constant1, {}}, HloLocation{tuple, {0}},
HloLocation{nested_tuple, {0, 0}}, HloLocation{nested_tuple, {1, 0}},
HloLocation{nested_tuple, {2}}, HloLocation{gte_tuple, {0}},
HloLocation{gte_out, {}}));
// Constant values should have no uses though one is live out. The locations
HloPosition{constant1, {}}, HloPosition{tuple, {0}},
HloPosition{nested_tuple, {0, 0}}, HloPosition{nested_tuple, {1, 0}},
HloPosition{nested_tuple, {2}}, HloPosition{gte_tuple, {0}},
HloPosition{gte_out, {}}));
// Constant values should have no uses though one is live out. The positions
// where they appear as operands are on instructions which do not use the
// values (eg, Tuple).
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).uses().empty());

View File

@ -38,18 +38,18 @@ namespace xla {
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
const Shape& HloLocation::shape() const {
const Shape& HloPosition::shape() const {
return ShapeUtil::GetSubshape(instruction->shape(), index);
}
string HloLocation::ToString() const {
string HloPosition::ToString() const {
string index_str =
ShapeUtil::IsTuple(instruction->shape()) ? (" " + index.ToString()) : "";
return StrCat(instruction->name(), index_str);
}
std::ostream& operator<<(std::ostream& out, const HloLocation& location) {
out << location.ToString();
std::ostream& operator<<(std::ostream& out, const HloPosition& position) {
out << position.ToString();
return out;
}
@ -69,8 +69,8 @@ std::ostream& operator<<(std::ostream& out, const HloUse& use) {
HloValue::HloValue(HloValue::Id id, HloInstruction* instruction,
const ShapeIndex& index, bool is_phi)
: id_(id), is_phi_(is_phi) {
// The defining location is always the first element in the locations_ vector.
AddLocation(instruction, index);
// The defining position is always the first element in the positions_ vector.
AddPosition(instruction, index);
}
bool HloValue::operator==(const HloValue& other) const {
@ -95,9 +95,9 @@ string HloValue::ToShortString() const {
string HloValue::ToString(int indent) const {
string indentation(indent, ' ');
string out = StrCat(indentation, ToShortString(), ", locations:\n");
for (const HloLocation& location : locations()) {
StrAppend(&out, indentation, " ", location.ToString(), "\n");
string out = StrCat(indentation, ToShortString(), ", positions:\n");
for (const HloPosition& position : positions()) {
StrAppend(&out, indentation, " ", position.ToString(), "\n");
}
StrAppend(&out, indentation, " uses:\n");
for (const HloUse& use : uses()) {
@ -150,22 +150,22 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index,
} // namespace
void HloValue::AddLocation(HloInstruction* instruction,
void HloValue::AddPosition(HloInstruction* instruction,
const ShapeIndex& index) {
HloLocation new_location{instruction, index};
HloPosition new_position{instruction, index};
// The new location must not already exist in locations_.
for (const HloLocation& location : locations_) {
DCHECK_NE(location, new_location);
// The new position must not already exist in positions_.
for (const HloPosition& position : positions_) {
DCHECK_NE(position, new_position);
}
// The shape of the new location must match existing locations.
if (!locations_.empty()) {
// The shape of the new position must match existing positions.
if (!positions_.empty()) {
CHECK(
ShapeUtil::Compatible(locations_.front().shape(), new_location.shape()))
<< "front: " << locations_.front() << " new: " << new_location;
ShapeUtil::Compatible(positions_.front().shape(), new_position.shape()))
<< "front: " << positions_.front() << " new: " << new_position;
}
locations_.push_back(std::move(new_location));
positions_.push_back(std::move(new_position));
// Update uses.
for (HloInstruction* user : instruction->users()) {
@ -194,23 +194,23 @@ void HloValue::AddLocation(HloInstruction* instruction,
}
}
void HloValue::RemoveLocation(HloInstruction* instruction,
void HloValue::RemovePosition(HloInstruction* instruction,
const ShapeIndex& index) {
// The defining location cannot be removed.
// The defining position cannot be removed.
CHECK(!(instruction == defining_instruction() && index == defining_index()));
int64 size_before = locations_.size();
locations_.erase(
std::remove_if(locations_.begin(), locations_.end(),
[instruction, &index](const HloLocation& location) {
return location.instruction == instruction &&
location.index == index;
int64 size_before = positions_.size();
positions_.erase(
std::remove_if(positions_.begin(), positions_.end(),
[instruction, &index](const HloPosition& position) {
return position.instruction == instruction &&
position.index == index;
}),
locations_.end());
// Only a single location should have been removed.
CHECK_EQ(locations_.size(), size_before - 1);
positions_.end());
// Only a single position should have been removed.
CHECK_EQ(positions_.size(), size_before - 1);
// Update uses which referred to this location.
// Update uses which referred to this position.
uses_.erase(std::remove_if(uses_.begin(), uses_.end(),
[instruction, &index](const HloUse& use) {
return use.instruction->operand(
@ -221,8 +221,8 @@ void HloValue::RemoveLocation(HloInstruction* instruction,
// Returns whether this value is contained in the given instruction's output.
auto is_contained_in = [this](const HloInstruction* instruction) {
for (const HloLocation& location : locations()) {
if (location.instruction == instruction) {
for (const HloPosition& position : positions()) {
if (position.instruction == instruction) {
return true;
}
}
@ -231,7 +231,7 @@ void HloValue::RemoveLocation(HloInstruction* instruction,
const HloModule& module = *instruction->parent()->parent();
if (instruction == module.entry_computation()->root_instruction()) {
// Value has been removed from a location in the entry root instruction.
// Value has been removed from a position in the entry root instruction.
live_out_of_module_ =
is_contained_in(module.entry_computation()->root_instruction());
}

View File

@ -30,24 +30,24 @@ limitations under the License.
namespace xla {
// Abstraction which identifies a specific point in the XLA graph. An
// HloLocation specifies a ShapeIndex within the output of a specific
// HloPosition specifies a ShapeIndex within the output of a specific
// instruction.
struct HloLocation {
struct HloPosition {
HloInstruction* instruction;
ShapeIndex index;
// Returns the shape at this location.
// Returns the shape at this position.
const Shape& shape() const;
string ToString() const;
bool operator==(const HloLocation& other) const {
bool operator==(const HloPosition& other) const {
return instruction == other.instruction && index == other.index;
}
bool operator!=(const HloLocation& other) const { return !(*this == other); }
bool operator!=(const HloPosition& other) const { return !(*this == other); }
};
std::ostream& operator<<(std::ostream& out, const HloLocation& location);
std::ostream& operator<<(std::ostream& out, const HloPosition& position);
// Defines a single use of an HLO value.
struct HloUse {
@ -111,28 +111,28 @@ class HloValue {
// Returns whether this value is a phi value.
bool is_phi() const { return is_phi_; }
// Return the location where this value is defined.
const HloLocation& defining_location() const { return locations_[0]; }
// Return the position where this value is defined.
const HloPosition& defining_position() const { return positions_[0]; }
// Return the instruction which defines this HloValue.
HloInstruction* defining_instruction() const {
return defining_location().instruction;
return defining_position().instruction;
}
// Return the shape index at which this HloValue is defined in the output of
// its defining instruction.
const ShapeIndex& defining_index() const { return defining_location().index; }
const ShapeIndex& defining_index() const { return defining_position().index; }
// Return the shape of this HloValue.
const Shape& shape() const { return defining_location().shape(); }
const Shape& shape() const { return defining_position().shape(); }
// Add or remove a location at which the HloValue appears. The definition
// location can not be removed. The uses of the HloValue are updated.
void AddLocation(HloInstruction* instruction, const ShapeIndex& index);
void RemoveLocation(HloInstruction* instruction, const ShapeIndex& index);
// Add or remove a position at which the HloValue appears. The definition
// position can not be removed. The uses of the HloValue are updated.
void AddPosition(HloInstruction* instruction, const ShapeIndex& index);
void RemovePosition(HloInstruction* instruction, const ShapeIndex& index);
// Return all locations of the HloValue in the module.
const std::vector<HloLocation>& locations() const { return locations_; }
// Return all positions of the HloValue in the module.
const std::vector<HloPosition>& positions() const { return positions_; }
// Return all uses of the HloValue.
const std::vector<HloUse>& uses() const { return uses_; }
@ -158,9 +158,9 @@ class HloValue {
// Whether this instruction is a phi value.
const bool is_phi_;
// The set of locations of this HloValue. The first element is always the
// location of the definition.
std::vector<HloLocation> locations_;
// The set of positions of this HloValue. The first element is always the
// position of the definition.
std::vector<HloPosition> positions_;
// The set of uses of this HloValue.
std::vector<HloUse> uses_;