Rename HloLocation to HloPosition, to avoid ambiguity with MemoryLocation.

PiperOrigin-RevId: 161716528
This commit is contained in:
A. Unique TensorFlower 2017-07-12 14:18:55 -07:00 committed by TensorFlower Gardener
parent 8e7f573716
commit 6b28eb0843
10 changed files with 162 additions and 162 deletions

View File

@ -94,13 +94,13 @@ void HloAliasAnalysis::CombineBuffers(
unified_buffer.AddValue(dataflow_analysis_->GetValue(value_id)); unified_buffer.AddValue(dataflow_analysis_->GetValue(value_id));
} }
// Iterate through all locations where the buffer-to-eliminate exists and // Iterate through all positions where the buffer-to-eliminate exists and
// replace it with the unified buffer. // replace it with the unified buffer.
for (const HloLocation& location : buffer.locations()) { for (const HloPosition& position : buffer.positions()) {
VLOG(4) << "Replacing in " << location; VLOG(4) << "Replacing in " << position;
GetBufferSet(location.instruction, location.index) GetBufferSet(position.instruction, position.index)
.RemoveBufferOrDie(buffer_id); .RemoveBufferOrDie(buffer_id);
GetBufferSet(location.instruction, location.index) GetBufferSet(position.instruction, position.index)
.AddBuffer(unified_buffer.id()); .AddBuffer(unified_buffer.id());
} }

View File

@ -45,7 +45,7 @@ class HloAliasAnalysis {
InstructionBufferSet& GetInstructionBufferSet( InstructionBufferSet& GetInstructionBufferSet(
const HloInstruction* instruction); const HloInstruction* instruction);
// Return the HloBufferSet for the given location. // Return the HloBufferSet for the given position.
const HloBufferSet& GetBufferSet(const HloInstruction* instruction, const HloBufferSet& GetBufferSet(const HloInstruction* instruction,
const ShapeIndex& index = {}) const; const ShapeIndex& index = {}) const;
HloBufferSet& GetBufferSet(const HloInstruction* instruction, HloBufferSet& GetBufferSet(const HloInstruction* instruction,
@ -59,8 +59,8 @@ class HloAliasAnalysis {
return buffers_.at(buffer_id); return buffers_.at(buffer_id);
} }
// Returns the unique buffer at the given location. CHECK fails if the buffer // Returns the unique buffer at the given position. CHECK fails if the buffer
// set at that location does not contain exactly one buffer. // set at that position does not contain exactly one buffer.
const HloBuffer& GetUniqueBufferAt(const HloInstruction* instruction, const HloBuffer& GetUniqueBufferAt(const HloInstruction* instruction,
const ShapeIndex& index = {}) const { const ShapeIndex& index = {}) const {
return GetBuffer(GetBufferSet(instruction, index).GetUniqueBufferId()); return GetBuffer(GetBufferSet(instruction, index).GetUniqueBufferId());

View File

@ -46,7 +46,7 @@ class HloAliasAnalysisTest : public HloTestBase {
return *analysis_; return *analysis_;
} }
// Return a vector of the buffers in the buffer set at the current location. // Return a vector of the buffers in the buffer set at the current position.
std::vector<HloBuffer> GetBuffersAt(const HloInstruction* instruction, std::vector<HloBuffer> GetBuffersAt(const HloInstruction* instruction,
const ShapeIndex& index = {}) const { const ShapeIndex& index = {}) const {
std::vector<HloBuffer> buffers; std::vector<HloBuffer> buffers;
@ -66,7 +66,7 @@ class HloAliasAnalysisTest : public HloTestBase {
return values; return values;
} }
// Return the HloValue defined at the given location. // Return the HloValue defined at the given position.
const HloValue& GetValueDefinedAt(const HloInstruction* instruction, const HloValue& GetValueDefinedAt(const HloInstruction* instruction,
const ShapeIndex& index = {}) const { const ShapeIndex& index = {}) const {
return analysis_->dataflow_analysis().GetValueDefinedAt(instruction, index); return analysis_->dataflow_analysis().GetValueDefinedAt(instruction, index);
@ -174,11 +174,11 @@ TEST_F(HloAliasAnalysisTest, TupleAndGtes) {
EXPECT_EQ(analysis.GetUniqueBufferAt(param0), EXPECT_EQ(analysis.GetUniqueBufferAt(param0),
analysis.GetUniqueBufferAt(gte0)); analysis.GetUniqueBufferAt(gte0));
// Verify the locations of an aliased buffer. // Verify the positions of an aliased buffer.
EXPECT_THAT( EXPECT_THAT(
analysis.GetUniqueBufferAt(param0).locations(), analysis.GetUniqueBufferAt(param0).positions(),
UnorderedElementsAre(HloLocation{param0, {}}, HloLocation{tuple, {0}}, UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}},
HloLocation{gte0, {}})); HloPosition{gte0, {}}));
EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsAmbiguous()); EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsAmbiguous());
EXPECT_TRUE(analysis.GetInstructionBufferSet(tuple).IsDistinct()); EXPECT_TRUE(analysis.GetInstructionBufferSet(tuple).IsDistinct());
@ -201,9 +201,9 @@ TEST_F(HloAliasAnalysisTest, NondistinctTuple) {
const HloAliasAnalysis& analysis = RunAnalysis(); const HloAliasAnalysis& analysis = RunAnalysis();
EXPECT_THAT( EXPECT_THAT(
analysis.GetUniqueBufferAt(param0).locations(), analysis.GetUniqueBufferAt(param0).positions(),
UnorderedElementsAre(HloLocation{param0, {}}, HloLocation{tuple, {0}}, UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}},
HloLocation{tuple, {2}})); HloPosition{tuple, {2}}));
EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsAmbiguous()); EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsAmbiguous());
EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsDistinct()); EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsDistinct());
@ -236,17 +236,17 @@ TEST_F(HloAliasAnalysisTest, SingleCall) {
const HloAliasAnalysis& analysis = RunAnalysis(); const HloAliasAnalysis& analysis = RunAnalysis();
// Verify aliasing of the kCall operands and the subcomputation parameters. // Verify aliasing of the kCall operands and the subcomputation parameters.
EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).locations(), EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).positions(),
UnorderedElementsAre(HloLocation{constant1, {}}, UnorderedElementsAre(HloPosition{constant1, {}},
HloLocation{subparam0, {}})); HloPosition{subparam0, {}}));
EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).locations(), EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).positions(),
UnorderedElementsAre(HloLocation{constant2, {}}, UnorderedElementsAre(HloPosition{constant2, {}},
HloLocation{subparam1, {}})); HloPosition{subparam1, {}}));
// The subcomputation root and the kCall itself should alias. // The subcomputation root and the kCall itself should alias.
EXPECT_THAT( EXPECT_THAT(
analysis.GetUniqueBufferAt(add).locations(), analysis.GetUniqueBufferAt(add).positions(),
UnorderedElementsAre(HloLocation{add, {}}, HloLocation{call, {}})); UnorderedElementsAre(HloPosition{add, {}}, HloPosition{call, {}}));
EXPECT_FALSE(AnyValuesInSameBufferInterfere()); EXPECT_FALSE(AnyValuesInSameBufferInterfere());
} }
@ -276,20 +276,20 @@ TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) {
const HloAliasAnalysis& analysis = RunAnalysis(); const HloAliasAnalysis& analysis = RunAnalysis();
EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).locations(), EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).positions(),
UnorderedElementsAre(HloLocation{constant1, {}}, UnorderedElementsAre(HloPosition{constant1, {}},
HloLocation{subparam0, {}})); HloPosition{subparam0, {}}));
EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).locations(), EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).positions(),
UnorderedElementsAre(HloLocation{constant2, {}}, UnorderedElementsAre(HloPosition{constant2, {}},
HloLocation{subparam1, {}})); HloPosition{subparam1, {}}));
// The 'add' (root of the subcomputation) aliases the two call instruction, // The 'add' (root of the subcomputation) aliases the two call instruction,
// and the first parameter of the subcomputation because 'call1' it is passed // and the first parameter of the subcomputation because 'call1' it is passed
// as an argument to the subcomputation in 'call2'. // as an argument to the subcomputation in 'call2'.
EXPECT_THAT( EXPECT_THAT(
analysis.GetUniqueBufferAt(add).locations(), analysis.GetUniqueBufferAt(add).positions(),
UnorderedElementsAre(HloLocation{add, {}}, HloLocation{call1, {}}, UnorderedElementsAre(HloPosition{add, {}}, HloPosition{call1, {}},
HloLocation{subparam0, {}}, HloLocation{call2, {}})); HloPosition{subparam0, {}}, HloPosition{call2, {}}));
EXPECT_THAT(GetBuffersAt(subparam0), EXPECT_THAT(GetBuffersAt(subparam0),
UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1), UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1),
@ -361,24 +361,24 @@ TEST_F(HloAliasAnalysisTest, SingleWhile) {
const HloAliasAnalysis& analysis = RunAnalysis(); const HloAliasAnalysis& analysis = RunAnalysis();
// Verify the locations of the aliased while buffers. // Verify the positions of the aliased while buffers.
EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{}).locations(), EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{}).positions(),
UnorderedElementsAre( UnorderedElementsAre(
HloLocation{tuple, {}}, HloLocation{xla_while, {}}, HloPosition{tuple, {}}, HloPosition{xla_while, {}},
HloLocation{body_param, {}}, HloLocation{body_tuple, {}}, HloPosition{body_param, {}}, HloPosition{body_tuple, {}},
HloLocation{cond_param, {}})); HloPosition{cond_param, {}}));
EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}).locations(), EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}).positions(),
UnorderedElementsAre( UnorderedElementsAre(
HloLocation{constant1, {}}, HloLocation{tuple, {0}}, HloPosition{constant1, {}}, HloPosition{tuple, {0}},
HloLocation{xla_while, {0}}, HloLocation{body_param, {0}}, HloPosition{xla_while, {0}}, HloPosition{body_param, {0}},
HloLocation{body_element_0, {}}, HloLocation{body_tuple, {0}}, HloPosition{body_element_0, {}}, HloPosition{body_tuple, {0}},
HloLocation{cond_param, {0}})); HloPosition{cond_param, {0}}));
EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).locations(), EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).positions(),
UnorderedElementsAre( UnorderedElementsAre(
HloLocation{constant2, {}}, HloLocation{tuple, {1}}, HloPosition{constant2, {}}, HloPosition{tuple, {1}},
HloLocation{xla_while, {1}}, HloLocation{body_param, {1}}, HloPosition{xla_while, {1}}, HloPosition{body_param, {1}},
HloLocation{body_element_1, {}}, HloLocation{add, {}}, HloPosition{body_element_1, {}}, HloPosition{add, {}},
HloLocation{body_tuple, {1}}, HloLocation{cond_param, {1}})); HloPosition{body_tuple, {1}}, HloPosition{cond_param, {1}}));
EXPECT_THAT( EXPECT_THAT(
GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0})), GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0})),
@ -619,7 +619,7 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) {
const HloAliasAnalysis& analysis = RunAnalysis(); const HloAliasAnalysis& analysis = RunAnalysis();
// The swizzling while makes most locations in the module alias leaving only 3 // The swizzling while makes most positions in the module alias leaving only 3
// HloBuffers. // HloBuffers.
EXPECT_THAT( EXPECT_THAT(
analysis.buffers(), analysis.buffers(),

View File

@ -46,11 +46,11 @@ void HloBuffer::AddValue(const HloValue& value) {
value_ids_.push_back(value.id()); value_ids_.push_back(value.id());
// Add all of the locations of the HloValue to this buffer. // Add all of the positions of the HloValue to this buffer.
for (const HloLocation& location : value.locations()) { for (const HloPosition& position : value.positions()) {
if (std::find(locations_.begin(), locations_.end(), location) == if (std::find(positions_.begin(), positions_.end(), position) ==
locations_.end()) { positions_.end()) {
locations_.push_back(location); positions_.push_back(position);
} }
} }
} }
@ -60,7 +60,7 @@ bool HloBuffer::operator==(const HloBuffer& other) const {
if (equal) { if (equal) {
// DCHECK because these comparisons are expensive (linear time). // DCHECK because these comparisons are expensive (linear time).
DCHECK(value_ids() == other.value_ids()); DCHECK(value_ids() == other.value_ids());
DCHECK(locations() == other.locations()); DCHECK(positions() == other.positions());
} }
return equal; return equal;
} }

View File

@ -33,7 +33,7 @@ namespace xla {
// from. Generally there is a one-to-one correspondence between HloBuffers and // from. Generally there is a one-to-one correspondence between HloBuffers and
// HloValue where each HloValue in the module is held in a unique HloBuffer. An // HloValue where each HloValue in the module is held in a unique HloBuffer. An
// exception is the while instruction which updates the loop state in-place. In // exception is the while instruction which updates the loop state in-place. In
// this case, we have a single HloBuffer for each HloLocation in the loop state, // this case, we have a single HloBuffer for each HloPosition in the loop state,
// but multiple HloValues. For example: // but multiple HloValues. For example:
// //
// %init = ... // %init = ...
@ -53,7 +53,7 @@ namespace xla {
// HloValue{%while}, HloValue{%body_param}, HloValue{%body_root}, and // HloValue{%while}, HloValue{%body_param}, HloValue{%body_root}, and
// HloValue{%cond_param}. // HloValue{%cond_param}.
// //
// HloBuffers may appear at different HloLocations in the module mirroring the // HloBuffers may appear at different HloPositions in the module mirroring the
// same propery of HloValues. For example: // same propery of HloValues. For example:
// //
// %sub = Sub(...) // %sub = Sub(...)
@ -62,11 +62,11 @@ namespace xla {
// %gte = GetTupleElement(%tuple, 0) // %gte = GetTupleElement(%tuple, 0)
// //
// In this case, the HloBuffer containing %add appears at the following // In this case, the HloBuffer containing %add appears at the following
// locations: HloLocation{%add, {}}, HloLocation{%tuple, {0}}, and // positions: HloPosition{%add, {}}, HloPosition{%tuple, {0}}, and
// HloLocation{%gte, {}}. // HloPosition{%gte, {}}.
// //
// Different HloLocations which share the same HloBuffer indicate mandatory // Different HloPositions which share the same HloBuffer indicate mandatory
// aliasing in the HLO module. These locations must share the same memory // aliasing in the HLO module. These positions must share the same memory
// allocation for correctness (the backends rely on this property). This differs // allocation for correctness (the backends rely on this property). This differs
// from incidental aliasing introduced by memory reuse in BufferAssignment where // from incidental aliasing introduced by memory reuse in BufferAssignment where
// different instructions may happen to get the same allocation. // different instructions may happen to get the same allocation.
@ -80,17 +80,17 @@ class HloBuffer {
Id id() const { return id_; } Id id() const { return id_; }
// Add a value to the set of values held by this buffer. Also adds the // Add a value to the set of values held by this buffer. Also adds the
// HloLocations of the value to the locations vector of the buffer. If the // HloPositions of the value to the positions vector of the buffer. If the
// buffer already contains this value, then this method is a nop. // buffer already contains this value, then this method is a nop.
void AddValue(const HloValue& value); void AddValue(const HloValue& value);
// Return the IDs of all values contained in this buffer. // Return the IDs of all values contained in this buffer.
const std::vector<HloValue::Id>& value_ids() const { return value_ids_; } const std::vector<HloValue::Id>& value_ids() const { return value_ids_; }
// Return the locations (output of which instruction and at what index) where // Return the positions (output of which instruction and at what index) where
// the buffer is used. This is exactly the union of the locations of the // the buffer is used. This is exactly the union of the positions of the
// HloValues contained by the buffer. // HloValues contained by the buffer.
const std::vector<HloLocation>& locations() const { return locations_; } const std::vector<HloPosition>& positions() const { return positions_; }
string ToString() const; string ToString() const;
@ -104,16 +104,16 @@ class HloBuffer {
// The set of values contained in this buffer. // The set of values contained in this buffer.
std::vector<HloValue::Id> value_ids_; std::vector<HloValue::Id> value_ids_;
// The set of locations where this buffer is used. // The set of positions where this buffer is used.
std::vector<HloLocation> locations_; std::vector<HloPosition> positions_;
}; };
std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer); std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer);
// A class representing the set of possible HloBuffers at a particular // A class representing the set of possible HloBuffers at a particular
// HloLocation (shape index in the output of an instruction) in the XLA // HloPosition (shape index in the output of an instruction) in the XLA
// graph. In most cases, the buffer set will have a single HloBuffer indicating // graph. In most cases, the buffer set will have a single HloBuffer indicating
// that the HloBuffer which appears at that particular location is known // that the HloBuffer which appears at that particular position is known
// unambiguously at compile-time. However, tuple-shaped Select instructions can // unambiguously at compile-time. However, tuple-shaped Select instructions can
// introduce ambiguity as the tuple elements of the operands are passed by // introduce ambiguity as the tuple elements of the operands are passed by
// reference into the output of the Select. For example: // reference into the output of the Select. For example:
@ -123,7 +123,7 @@ std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer);
// %tuple1 = Tuple(%x, %y) // %tuple1 = Tuple(%x, %y)
// %select = Select(%pred, %tuple0, %tuple1) // %select = Select(%pred, %tuple0, %tuple1)
// //
// In this case the HloBufferSet at HloLocation{%select, {0}} contains the // In this case the HloBufferSet at HloPosition{%select, {0}} contains the
// HloBuffer holding %a and the HloBuffer holding %x. // HloBuffer holding %a and the HloBuffer holding %x.
class HloBufferSet { class HloBufferSet {
public: public:

View File

@ -247,11 +247,11 @@ InstructionValueSet HloDataflowAnalysis::Phi(
return new_value_set; return new_value_set;
} }
void HloDataflowAnalysis::UpdateLocationsOfValuesAt( void HloDataflowAnalysis::UpdatePositionsOfValuesAt(
HloInstruction* instruction, const InstructionValueSet& new_value_set, HloInstruction* instruction, const InstructionValueSet& new_value_set,
const InstructionValueSet* prev_value_set) { const InstructionValueSet* prev_value_set) {
if (prev_value_set != nullptr) { if (prev_value_set != nullptr) {
// Remove locations from the old value set. // Remove positions from the old value set.
prev_value_set->ForEachElement( prev_value_set->ForEachElement(
[this, instruction](const ShapeIndex& index, [this, instruction](const ShapeIndex& index,
const HloValueSet& value_set) { const HloValueSet& value_set) {
@ -260,17 +260,17 @@ void HloDataflowAnalysis::UpdateLocationsOfValuesAt(
if (!ContainsKey(values_, value_id)) { if (!ContainsKey(values_, value_id)) {
continue; continue;
} }
// Don't remove the defining location of the value. // Don't remove the defining position of the value.
HloValue& value = GetValue(value_id); HloValue& value = GetValue(value_id);
if (instruction == value.defining_instruction()) { if (instruction == value.defining_instruction()) {
CHECK_EQ(index, value.defining_index()); CHECK_EQ(index, value.defining_index());
} else { } else {
value.RemoveLocation(instruction, index); value.RemovePosition(instruction, index);
} }
} }
}); });
} }
// Add locations in the new value set. // Add positions in the new value set.
new_value_set.ForEachElement( new_value_set.ForEachElement(
[this, instruction](const ShapeIndex& index, [this, instruction](const ShapeIndex& index,
const HloValueSet& value_set) { const HloValueSet& value_set) {
@ -279,7 +279,7 @@ void HloDataflowAnalysis::UpdateLocationsOfValuesAt(
if (instruction == value.defining_instruction()) { if (instruction == value.defining_instruction()) {
CHECK_EQ(index, value.defining_index()); CHECK_EQ(index, value.defining_index());
} else { } else {
value.AddLocation(instruction, index); value.AddPosition(instruction, index);
} }
} }
}); });
@ -466,7 +466,7 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate(
// Update uses. First clear all of the old uses at the particular // Update uses. First clear all of the old uses at the particular
// operands. Then add the new uses. There may be overlap between the old // operands. Then add the new uses. There may be overlap between the old
// uses and new uses. // uses and new uses.
UpdateLocationsOfValuesAt(instruction, GetInstructionValueSet(instruction), UpdatePositionsOfValuesAt(instruction, GetInstructionValueSet(instruction),
&old_value); &old_value);
} }
} }
@ -600,7 +600,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
define_all_values(); define_all_values();
break; break;
} }
UpdateLocationsOfValuesAt(instruction.get(), UpdatePositionsOfValuesAt(instruction.get(),
GetInstructionValueSet(instruction.get())); GetInstructionValueSet(instruction.get()));
} }
} }

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
// Analysis for determining the possible set of values for all locations // Analysis for determining the possible set of values for all positions
// (instructions and ShapeIndexes) in the HLO module. Analysis is module-scoped // (instructions and ShapeIndexes) in the HLO module. Analysis is module-scoped
// tracking values across computation boundaries. // tracking values across computation boundaries.
@ -170,14 +170,14 @@ class HloDataflowAnalysis {
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs, tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs,
bool skip_top_level = false); bool skip_top_level = false);
// Updates the locations of the HloValues in the output of the given // Updates the positions of the HloValues in the output of the given
// instruction. This should be called after the instruction value set of // instruction. This should be called after the instruction value set of
// 'instruction' has been changed. 'prev_value_set' must point to the previous // 'instruction' has been changed. 'prev_value_set' must point to the previous
// state of the value set prior to the change. 'prev_value_set' may be null if // state of the value set prior to the change. 'prev_value_set' may be null if
// this is the first time locations are being computed. The previous state is // this is the first time positions are being computed. The previous state is
// necessary to efficiently remove locations which have been eliminated due to // necessary to efficiently remove positions which have been eliminated due to
// changes in the instructions' InstructionValueSet. // changes in the instructions' InstructionValueSet.
void UpdateLocationsOfValuesAt( void UpdatePositionsOfValuesAt(
HloInstruction* instruction, const InstructionValueSet& new_value_set, HloInstruction* instruction, const InstructionValueSet& new_value_set,
const InstructionValueSet* prev_value_set = nullptr); const InstructionValueSet* prev_value_set = nullptr);

View File

@ -51,7 +51,7 @@ class HloDataflowAnalysisTest : public HloTestBase,
return *analysis_; return *analysis_;
} }
// Return a vector of the HloValues at the given program location. // Return a vector of the HloValues at the given program position.
std::vector<HloValue> HloValuesAt(const HloInstruction* instruction, std::vector<HloValue> HloValuesAt(const HloInstruction* instruction,
const ShapeIndex& index = {}) { const ShapeIndex& index = {}) {
CHECK(analysis_ != nullptr); CHECK(analysis_ != nullptr);
@ -101,14 +101,14 @@ TEST_P(HloDataflowAnalysisTest, BinaryOperation) {
EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
EXPECT_TRUE(analysis.ValueIsDefinedAt(add)); EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
// Verify the locations of the values. These locations are all trivial because // Verify the positions of the values. These positions are all trivial because
// there are no instructions which forward values. // there are no instructions which forward values.
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).locations(), EXPECT_THAT(analysis.GetValueDefinedAt(constant1).positions(),
UnorderedElementsAre(HloLocation{constant1, {}})); UnorderedElementsAre(HloPosition{constant1, {}}));
EXPECT_THAT(analysis.GetValueDefinedAt(constant2).locations(), EXPECT_THAT(analysis.GetValueDefinedAt(constant2).positions(),
UnorderedElementsAre(HloLocation{constant2, {}})); UnorderedElementsAre(HloPosition{constant2, {}}));
EXPECT_THAT(analysis.GetValueDefinedAt(add).locations(), EXPECT_THAT(analysis.GetValueDefinedAt(add).positions(),
UnorderedElementsAre(HloLocation{add, {}})); UnorderedElementsAre(HloPosition{add, {}}));
// Verify the uses of the values. // Verify the uses of the values.
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
@ -155,17 +155,17 @@ TEST_P(HloDataflowAnalysisTest, TupleAndGtes) {
EXPECT_FALSE(analysis.ValueIsDefinedAt(gte1)); EXPECT_FALSE(analysis.ValueIsDefinedAt(gte1));
EXPECT_TRUE(analysis.ValueIsDefinedAt(add)); EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
// Verify the locations of the values. // Verify the positions of the values.
EXPECT_THAT( EXPECT_THAT(
analysis.GetValueDefinedAt(param0).locations(), analysis.GetValueDefinedAt(param0).positions(),
UnorderedElementsAre(HloLocation{param0, {}}, HloLocation{tuple, {0}}, UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}},
HloLocation{gte0, {}})); HloPosition{gte0, {}}));
EXPECT_THAT( EXPECT_THAT(
analysis.GetValueDefinedAt(param1).locations(), analysis.GetValueDefinedAt(param1).positions(),
UnorderedElementsAre(HloLocation{param1, {}}, HloLocation{tuple, {1}}, UnorderedElementsAre(HloPosition{param1, {}}, HloPosition{tuple, {1}},
HloLocation{gte1, {}})); HloPosition{gte1, {}}));
EXPECT_THAT(analysis.GetValueDefinedAt(tuple).locations(), EXPECT_THAT(analysis.GetValueDefinedAt(tuple).positions(),
UnorderedElementsAre(HloLocation{tuple, {}})); UnorderedElementsAre(HloPosition{tuple, {}}));
// Verify uses. Of interest is that a GetTupleElement instruction is only a // Verify uses. Of interest is that a GetTupleElement instruction is only a
// use of the top-level value in the tuple operand. // use of the top-level value in the tuple operand.
@ -200,15 +200,15 @@ TEST_P(HloDataflowAnalysisTest, NestedTuple) {
EXPECT_EQ(analysis.values().size(), 4); EXPECT_EQ(analysis.values().size(), 4);
// Verify locations and uses. // Verify positions and uses.
EXPECT_THAT( EXPECT_THAT(
analysis.GetValueDefinedAt(constant1).locations(), analysis.GetValueDefinedAt(constant1).positions(),
UnorderedElementsAre( UnorderedElementsAre(
HloLocation{constant1, {}}, HloLocation{tuple, {0}}, HloPosition{constant1, {}}, HloPosition{tuple, {0}},
HloLocation{nested_tuple, {0, 0}}, HloLocation{nested_tuple, {1, 0}}, HloPosition{nested_tuple, {0, 0}}, HloPosition{nested_tuple, {1, 0}},
HloLocation{nested_tuple, {2}}, HloLocation{gte_tuple, {0}}, HloPosition{nested_tuple, {2}}, HloPosition{gte_tuple, {0}},
HloLocation{gte_out, {}})); HloPosition{gte_out, {}}));
// Constant values should have no uses though one is live out. The locations // Constant values should have no uses though one is live out. The positions
// where they appear as operands are on instructions which do not use the // where they appear as operands are on instructions which do not use the
// values (eg, Tuple). // values (eg, Tuple).
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).uses().empty()); EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).uses().empty());

View File

@ -38,18 +38,18 @@ namespace xla {
using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat; using ::tensorflow::strings::StrCat;
const Shape& HloLocation::shape() const { const Shape& HloPosition::shape() const {
return ShapeUtil::GetSubshape(instruction->shape(), index); return ShapeUtil::GetSubshape(instruction->shape(), index);
} }
string HloLocation::ToString() const { string HloPosition::ToString() const {
string index_str = string index_str =
ShapeUtil::IsTuple(instruction->shape()) ? (" " + index.ToString()) : ""; ShapeUtil::IsTuple(instruction->shape()) ? (" " + index.ToString()) : "";
return StrCat(instruction->name(), index_str); return StrCat(instruction->name(), index_str);
} }
std::ostream& operator<<(std::ostream& out, const HloLocation& location) { std::ostream& operator<<(std::ostream& out, const HloPosition& position) {
out << location.ToString(); out << position.ToString();
return out; return out;
} }
@ -69,8 +69,8 @@ std::ostream& operator<<(std::ostream& out, const HloUse& use) {
HloValue::HloValue(HloValue::Id id, HloInstruction* instruction, HloValue::HloValue(HloValue::Id id, HloInstruction* instruction,
const ShapeIndex& index, bool is_phi) const ShapeIndex& index, bool is_phi)
: id_(id), is_phi_(is_phi) { : id_(id), is_phi_(is_phi) {
// The defining location is always the first element in the locations_ vector. // The defining position is always the first element in the positions_ vector.
AddLocation(instruction, index); AddPosition(instruction, index);
} }
bool HloValue::operator==(const HloValue& other) const { bool HloValue::operator==(const HloValue& other) const {
@ -95,9 +95,9 @@ string HloValue::ToShortString() const {
string HloValue::ToString(int indent) const { string HloValue::ToString(int indent) const {
string indentation(indent, ' '); string indentation(indent, ' ');
string out = StrCat(indentation, ToShortString(), ", locations:\n"); string out = StrCat(indentation, ToShortString(), ", positions:\n");
for (const HloLocation& location : locations()) { for (const HloPosition& position : positions()) {
StrAppend(&out, indentation, " ", location.ToString(), "\n"); StrAppend(&out, indentation, " ", position.ToString(), "\n");
} }
StrAppend(&out, indentation, " uses:\n"); StrAppend(&out, indentation, " uses:\n");
for (const HloUse& use : uses()) { for (const HloUse& use : uses()) {
@ -150,22 +150,22 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index,
} // namespace } // namespace
void HloValue::AddLocation(HloInstruction* instruction, void HloValue::AddPosition(HloInstruction* instruction,
const ShapeIndex& index) { const ShapeIndex& index) {
HloLocation new_location{instruction, index}; HloPosition new_position{instruction, index};
// The new location must not already exist in locations_. // The new position must not already exist in positions_.
for (const HloLocation& location : locations_) { for (const HloPosition& position : positions_) {
DCHECK_NE(location, new_location); DCHECK_NE(position, new_position);
} }
// The shape of the new location must match existing locations. // The shape of the new position must match existing positions.
if (!locations_.empty()) { if (!positions_.empty()) {
CHECK( CHECK(
ShapeUtil::Compatible(locations_.front().shape(), new_location.shape())) ShapeUtil::Compatible(positions_.front().shape(), new_position.shape()))
<< "front: " << locations_.front() << " new: " << new_location; << "front: " << positions_.front() << " new: " << new_position;
} }
locations_.push_back(std::move(new_location)); positions_.push_back(std::move(new_position));
// Update uses. // Update uses.
for (HloInstruction* user : instruction->users()) { for (HloInstruction* user : instruction->users()) {
@ -194,23 +194,23 @@ void HloValue::AddLocation(HloInstruction* instruction,
} }
} }
void HloValue::RemoveLocation(HloInstruction* instruction, void HloValue::RemovePosition(HloInstruction* instruction,
const ShapeIndex& index) { const ShapeIndex& index) {
// The defining location cannot be removed. // The defining position cannot be removed.
CHECK(!(instruction == defining_instruction() && index == defining_index())); CHECK(!(instruction == defining_instruction() && index == defining_index()));
int64 size_before = locations_.size(); int64 size_before = positions_.size();
locations_.erase( positions_.erase(
std::remove_if(locations_.begin(), locations_.end(), std::remove_if(positions_.begin(), positions_.end(),
[instruction, &index](const HloLocation& location) { [instruction, &index](const HloPosition& position) {
return location.instruction == instruction && return position.instruction == instruction &&
location.index == index; position.index == index;
}), }),
locations_.end()); positions_.end());
// Only a single location should have been removed. // Only a single position should have been removed.
CHECK_EQ(locations_.size(), size_before - 1); CHECK_EQ(positions_.size(), size_before - 1);
// Update uses which referred to this location. // Update uses which referred to this position.
uses_.erase(std::remove_if(uses_.begin(), uses_.end(), uses_.erase(std::remove_if(uses_.begin(), uses_.end(),
[instruction, &index](const HloUse& use) { [instruction, &index](const HloUse& use) {
return use.instruction->operand( return use.instruction->operand(
@ -221,8 +221,8 @@ void HloValue::RemoveLocation(HloInstruction* instruction,
// Returns whether this value is contained in the given instruction's output. // Returns whether this value is contained in the given instruction's output.
auto is_contained_in = [this](const HloInstruction* instruction) { auto is_contained_in = [this](const HloInstruction* instruction) {
for (const HloLocation& location : locations()) { for (const HloPosition& position : positions()) {
if (location.instruction == instruction) { if (position.instruction == instruction) {
return true; return true;
} }
} }
@ -231,7 +231,7 @@ void HloValue::RemoveLocation(HloInstruction* instruction,
const HloModule& module = *instruction->parent()->parent(); const HloModule& module = *instruction->parent()->parent();
if (instruction == module.entry_computation()->root_instruction()) { if (instruction == module.entry_computation()->root_instruction()) {
// Value has been removed from a location in the entry root instruction. // Value has been removed from a position in the entry root instruction.
live_out_of_module_ = live_out_of_module_ =
is_contained_in(module.entry_computation()->root_instruction()); is_contained_in(module.entry_computation()->root_instruction());
} }

View File

@ -30,24 +30,24 @@ limitations under the License.
namespace xla { namespace xla {
// Abstraction which identifies a specific point in the XLA graph. An // Abstraction which identifies a specific point in the XLA graph. An
// HloLocation specifies a ShapeIndex within the output of a specific // HloPosition specifies a ShapeIndex within the output of a specific
// instruction. // instruction.
struct HloLocation { struct HloPosition {
HloInstruction* instruction; HloInstruction* instruction;
ShapeIndex index; ShapeIndex index;
// Returns the shape at this location. // Returns the shape at this position.
const Shape& shape() const; const Shape& shape() const;
string ToString() const; string ToString() const;
bool operator==(const HloLocation& other) const { bool operator==(const HloPosition& other) const {
return instruction == other.instruction && index == other.index; return instruction == other.instruction && index == other.index;
} }
bool operator!=(const HloLocation& other) const { return !(*this == other); } bool operator!=(const HloPosition& other) const { return !(*this == other); }
}; };
std::ostream& operator<<(std::ostream& out, const HloLocation& location); std::ostream& operator<<(std::ostream& out, const HloPosition& position);
// Defines a single use of an HLO value. // Defines a single use of an HLO value.
struct HloUse { struct HloUse {
@ -111,28 +111,28 @@ class HloValue {
// Returns whether this value is a phi value. // Returns whether this value is a phi value.
bool is_phi() const { return is_phi_; } bool is_phi() const { return is_phi_; }
// Return the location where this value is defined. // Return the position where this value is defined.
const HloLocation& defining_location() const { return locations_[0]; } const HloPosition& defining_position() const { return positions_[0]; }
// Return the instruction which defines this HloValue. // Return the instruction which defines this HloValue.
HloInstruction* defining_instruction() const { HloInstruction* defining_instruction() const {
return defining_location().instruction; return defining_position().instruction;
} }
// Return the shape index at which this HloValue is defined in the output of // Return the shape index at which this HloValue is defined in the output of
// its defining instruction. // its defining instruction.
const ShapeIndex& defining_index() const { return defining_location().index; } const ShapeIndex& defining_index() const { return defining_position().index; }
// Return the shape of this HloValue. // Return the shape of this HloValue.
const Shape& shape() const { return defining_location().shape(); } const Shape& shape() const { return defining_position().shape(); }
// Add or remove a location at which the HloValue appears. The definition // Add or remove a position at which the HloValue appears. The definition
// location can not be removed. The uses of the HloValue are updated. // position can not be removed. The uses of the HloValue are updated.
void AddLocation(HloInstruction* instruction, const ShapeIndex& index); void AddPosition(HloInstruction* instruction, const ShapeIndex& index);
void RemoveLocation(HloInstruction* instruction, const ShapeIndex& index); void RemovePosition(HloInstruction* instruction, const ShapeIndex& index);
// Return all locations of the HloValue in the module. // Return all positions of the HloValue in the module.
const std::vector<HloLocation>& locations() const { return locations_; } const std::vector<HloPosition>& positions() const { return positions_; }
// Return all uses of the HloValue. // Return all uses of the HloValue.
const std::vector<HloUse>& uses() const { return uses_; } const std::vector<HloUse>& uses() const { return uses_; }
@ -158,9 +158,9 @@ class HloValue {
// Whether this instruction is a phi value. // Whether this instruction is a phi value.
const bool is_phi_; const bool is_phi_;
// The set of locations of this HloValue. The first element is always the // The set of positions of this HloValue. The first element is always the
// location of the definition. // position of the definition.
std::vector<HloLocation> locations_; std::vector<HloPosition> positions_;
// The set of uses of this HloValue. // The set of uses of this HloValue.
std::vector<HloUse> uses_; std::vector<HloUse> uses_;