Rename HloLocation to HloPosition, to avoid ambiguity with MemoryLocation.
PiperOrigin-RevId: 161716528
This commit is contained in:
parent
8e7f573716
commit
6b28eb0843
@ -94,13 +94,13 @@ void HloAliasAnalysis::CombineBuffers(
|
||||
unified_buffer.AddValue(dataflow_analysis_->GetValue(value_id));
|
||||
}
|
||||
|
||||
// Iterate through all locations where the buffer-to-eliminate exists and
|
||||
// Iterate through all positions where the buffer-to-eliminate exists and
|
||||
// replace it with the unified buffer.
|
||||
for (const HloLocation& location : buffer.locations()) {
|
||||
VLOG(4) << "Replacing in " << location;
|
||||
GetBufferSet(location.instruction, location.index)
|
||||
for (const HloPosition& position : buffer.positions()) {
|
||||
VLOG(4) << "Replacing in " << position;
|
||||
GetBufferSet(position.instruction, position.index)
|
||||
.RemoveBufferOrDie(buffer_id);
|
||||
GetBufferSet(location.instruction, location.index)
|
||||
GetBufferSet(position.instruction, position.index)
|
||||
.AddBuffer(unified_buffer.id());
|
||||
}
|
||||
|
||||
|
@ -45,7 +45,7 @@ class HloAliasAnalysis {
|
||||
InstructionBufferSet& GetInstructionBufferSet(
|
||||
const HloInstruction* instruction);
|
||||
|
||||
// Return the HloBufferSet for the given location.
|
||||
// Return the HloBufferSet for the given position.
|
||||
const HloBufferSet& GetBufferSet(const HloInstruction* instruction,
|
||||
const ShapeIndex& index = {}) const;
|
||||
HloBufferSet& GetBufferSet(const HloInstruction* instruction,
|
||||
@ -59,8 +59,8 @@ class HloAliasAnalysis {
|
||||
return buffers_.at(buffer_id);
|
||||
}
|
||||
|
||||
// Returns the unique buffer at the given location. CHECK fails if the buffer
|
||||
// set at that location does not contain exactly one buffer.
|
||||
// Returns the unique buffer at the given position. CHECK fails if the buffer
|
||||
// set at that position does not contain exactly one buffer.
|
||||
const HloBuffer& GetUniqueBufferAt(const HloInstruction* instruction,
|
||||
const ShapeIndex& index = {}) const {
|
||||
return GetBuffer(GetBufferSet(instruction, index).GetUniqueBufferId());
|
||||
|
@ -46,7 +46,7 @@ class HloAliasAnalysisTest : public HloTestBase {
|
||||
return *analysis_;
|
||||
}
|
||||
|
||||
// Return a vector of the buffers in the buffer set at the current location.
|
||||
// Return a vector of the buffers in the buffer set at the current position.
|
||||
std::vector<HloBuffer> GetBuffersAt(const HloInstruction* instruction,
|
||||
const ShapeIndex& index = {}) const {
|
||||
std::vector<HloBuffer> buffers;
|
||||
@ -66,7 +66,7 @@ class HloAliasAnalysisTest : public HloTestBase {
|
||||
return values;
|
||||
}
|
||||
|
||||
// Return the HloValue defined at the given location.
|
||||
// Return the HloValue defined at the given position.
|
||||
const HloValue& GetValueDefinedAt(const HloInstruction* instruction,
|
||||
const ShapeIndex& index = {}) const {
|
||||
return analysis_->dataflow_analysis().GetValueDefinedAt(instruction, index);
|
||||
@ -174,11 +174,11 @@ TEST_F(HloAliasAnalysisTest, TupleAndGtes) {
|
||||
EXPECT_EQ(analysis.GetUniqueBufferAt(param0),
|
||||
analysis.GetUniqueBufferAt(gte0));
|
||||
|
||||
// Verify the locations of an aliased buffer.
|
||||
// Verify the positions of an aliased buffer.
|
||||
EXPECT_THAT(
|
||||
analysis.GetUniqueBufferAt(param0).locations(),
|
||||
UnorderedElementsAre(HloLocation{param0, {}}, HloLocation{tuple, {0}},
|
||||
HloLocation{gte0, {}}));
|
||||
analysis.GetUniqueBufferAt(param0).positions(),
|
||||
UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}},
|
||||
HloPosition{gte0, {}}));
|
||||
|
||||
EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsAmbiguous());
|
||||
EXPECT_TRUE(analysis.GetInstructionBufferSet(tuple).IsDistinct());
|
||||
@ -201,9 +201,9 @@ TEST_F(HloAliasAnalysisTest, NondistinctTuple) {
|
||||
const HloAliasAnalysis& analysis = RunAnalysis();
|
||||
|
||||
EXPECT_THAT(
|
||||
analysis.GetUniqueBufferAt(param0).locations(),
|
||||
UnorderedElementsAre(HloLocation{param0, {}}, HloLocation{tuple, {0}},
|
||||
HloLocation{tuple, {2}}));
|
||||
analysis.GetUniqueBufferAt(param0).positions(),
|
||||
UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}},
|
||||
HloPosition{tuple, {2}}));
|
||||
|
||||
EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsAmbiguous());
|
||||
EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsDistinct());
|
||||
@ -236,17 +236,17 @@ TEST_F(HloAliasAnalysisTest, SingleCall) {
|
||||
const HloAliasAnalysis& analysis = RunAnalysis();
|
||||
|
||||
// Verify aliasing of the kCall operands and the subcomputation parameters.
|
||||
EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).locations(),
|
||||
UnorderedElementsAre(HloLocation{constant1, {}},
|
||||
HloLocation{subparam0, {}}));
|
||||
EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).locations(),
|
||||
UnorderedElementsAre(HloLocation{constant2, {}},
|
||||
HloLocation{subparam1, {}}));
|
||||
EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).positions(),
|
||||
UnorderedElementsAre(HloPosition{constant1, {}},
|
||||
HloPosition{subparam0, {}}));
|
||||
EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).positions(),
|
||||
UnorderedElementsAre(HloPosition{constant2, {}},
|
||||
HloPosition{subparam1, {}}));
|
||||
|
||||
// The subcomputation root and the kCall itself should alias.
|
||||
EXPECT_THAT(
|
||||
analysis.GetUniqueBufferAt(add).locations(),
|
||||
UnorderedElementsAre(HloLocation{add, {}}, HloLocation{call, {}}));
|
||||
analysis.GetUniqueBufferAt(add).positions(),
|
||||
UnorderedElementsAre(HloPosition{add, {}}, HloPosition{call, {}}));
|
||||
|
||||
EXPECT_FALSE(AnyValuesInSameBufferInterfere());
|
||||
}
|
||||
@ -276,20 +276,20 @@ TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) {
|
||||
|
||||
const HloAliasAnalysis& analysis = RunAnalysis();
|
||||
|
||||
EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).locations(),
|
||||
UnorderedElementsAre(HloLocation{constant1, {}},
|
||||
HloLocation{subparam0, {}}));
|
||||
EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).locations(),
|
||||
UnorderedElementsAre(HloLocation{constant2, {}},
|
||||
HloLocation{subparam1, {}}));
|
||||
EXPECT_THAT(analysis.GetUniqueBufferAt(constant1).positions(),
|
||||
UnorderedElementsAre(HloPosition{constant1, {}},
|
||||
HloPosition{subparam0, {}}));
|
||||
EXPECT_THAT(analysis.GetUniqueBufferAt(constant2).positions(),
|
||||
UnorderedElementsAre(HloPosition{constant2, {}},
|
||||
HloPosition{subparam1, {}}));
|
||||
|
||||
// The 'add' (root of the subcomputation) aliases the two call instruction,
|
||||
// and the first parameter of the subcomputation because 'call1' it is passed
|
||||
// as an argument to the subcomputation in 'call2'.
|
||||
EXPECT_THAT(
|
||||
analysis.GetUniqueBufferAt(add).locations(),
|
||||
UnorderedElementsAre(HloLocation{add, {}}, HloLocation{call1, {}},
|
||||
HloLocation{subparam0, {}}, HloLocation{call2, {}}));
|
||||
analysis.GetUniqueBufferAt(add).positions(),
|
||||
UnorderedElementsAre(HloPosition{add, {}}, HloPosition{call1, {}},
|
||||
HloPosition{subparam0, {}}, HloPosition{call2, {}}));
|
||||
|
||||
EXPECT_THAT(GetBuffersAt(subparam0),
|
||||
UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1),
|
||||
@ -361,24 +361,24 @@ TEST_F(HloAliasAnalysisTest, SingleWhile) {
|
||||
|
||||
const HloAliasAnalysis& analysis = RunAnalysis();
|
||||
|
||||
// Verify the locations of the aliased while buffers.
|
||||
EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{}).locations(),
|
||||
// Verify the positions of the aliased while buffers.
|
||||
EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{}).positions(),
|
||||
UnorderedElementsAre(
|
||||
HloLocation{tuple, {}}, HloLocation{xla_while, {}},
|
||||
HloLocation{body_param, {}}, HloLocation{body_tuple, {}},
|
||||
HloLocation{cond_param, {}}));
|
||||
EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}).locations(),
|
||||
HloPosition{tuple, {}}, HloPosition{xla_while, {}},
|
||||
HloPosition{body_param, {}}, HloPosition{body_tuple, {}},
|
||||
HloPosition{cond_param, {}}));
|
||||
EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0}).positions(),
|
||||
UnorderedElementsAre(
|
||||
HloLocation{constant1, {}}, HloLocation{tuple, {0}},
|
||||
HloLocation{xla_while, {0}}, HloLocation{body_param, {0}},
|
||||
HloLocation{body_element_0, {}}, HloLocation{body_tuple, {0}},
|
||||
HloLocation{cond_param, {0}}));
|
||||
EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).locations(),
|
||||
HloPosition{constant1, {}}, HloPosition{tuple, {0}},
|
||||
HloPosition{xla_while, {0}}, HloPosition{body_param, {0}},
|
||||
HloPosition{body_element_0, {}}, HloPosition{body_tuple, {0}},
|
||||
HloPosition{cond_param, {0}}));
|
||||
EXPECT_THAT(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).positions(),
|
||||
UnorderedElementsAre(
|
||||
HloLocation{constant2, {}}, HloLocation{tuple, {1}},
|
||||
HloLocation{xla_while, {1}}, HloLocation{body_param, {1}},
|
||||
HloLocation{body_element_1, {}}, HloLocation{add, {}},
|
||||
HloLocation{body_tuple, {1}}, HloLocation{cond_param, {1}}));
|
||||
HloPosition{constant2, {}}, HloPosition{tuple, {1}},
|
||||
HloPosition{xla_while, {1}}, HloPosition{body_param, {1}},
|
||||
HloPosition{body_element_1, {}}, HloPosition{add, {}},
|
||||
HloPosition{body_tuple, {1}}, HloPosition{cond_param, {1}}));
|
||||
|
||||
EXPECT_THAT(
|
||||
GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0})),
|
||||
@ -619,7 +619,7 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) {
|
||||
|
||||
const HloAliasAnalysis& analysis = RunAnalysis();
|
||||
|
||||
// The swizzling while makes most locations in the module alias leaving only 3
|
||||
// The swizzling while makes most positions in the module alias leaving only 3
|
||||
// HloBuffers.
|
||||
EXPECT_THAT(
|
||||
analysis.buffers(),
|
||||
|
@ -46,11 +46,11 @@ void HloBuffer::AddValue(const HloValue& value) {
|
||||
|
||||
value_ids_.push_back(value.id());
|
||||
|
||||
// Add all of the locations of the HloValue to this buffer.
|
||||
for (const HloLocation& location : value.locations()) {
|
||||
if (std::find(locations_.begin(), locations_.end(), location) ==
|
||||
locations_.end()) {
|
||||
locations_.push_back(location);
|
||||
// Add all of the positions of the HloValue to this buffer.
|
||||
for (const HloPosition& position : value.positions()) {
|
||||
if (std::find(positions_.begin(), positions_.end(), position) ==
|
||||
positions_.end()) {
|
||||
positions_.push_back(position);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -60,7 +60,7 @@ bool HloBuffer::operator==(const HloBuffer& other) const {
|
||||
if (equal) {
|
||||
// DCHECK because these comparisons are expensive (linear time).
|
||||
DCHECK(value_ids() == other.value_ids());
|
||||
DCHECK(locations() == other.locations());
|
||||
DCHECK(positions() == other.positions());
|
||||
}
|
||||
return equal;
|
||||
}
|
||||
|
@ -33,7 +33,7 @@ namespace xla {
|
||||
// from. Generally there is a one-to-one correspondence between HloBuffers and
|
||||
// HloValue where each HloValue in the module is held in a unique HloBuffer. An
|
||||
// exception is the while instruction which updates the loop state in-place. In
|
||||
// this case, we have a single HloBuffer for each HloLocation in the loop state,
|
||||
// this case, we have a single HloBuffer for each HloPosition in the loop state,
|
||||
// but multiple HloValues. For example:
|
||||
//
|
||||
// %init = ...
|
||||
@ -53,7 +53,7 @@ namespace xla {
|
||||
// HloValue{%while}, HloValue{%body_param}, HloValue{%body_root}, and
|
||||
// HloValue{%cond_param}.
|
||||
//
|
||||
// HloBuffers may appear at different HloLocations in the module mirroring the
|
||||
// HloBuffers may appear at different HloPositions in the module mirroring the
|
||||
// same propery of HloValues. For example:
|
||||
//
|
||||
// %sub = Sub(...)
|
||||
@ -62,11 +62,11 @@ namespace xla {
|
||||
// %gte = GetTupleElement(%tuple, 0)
|
||||
//
|
||||
// In this case, the HloBuffer containing %add appears at the following
|
||||
// locations: HloLocation{%add, {}}, HloLocation{%tuple, {0}}, and
|
||||
// HloLocation{%gte, {}}.
|
||||
// positions: HloPosition{%add, {}}, HloPosition{%tuple, {0}}, and
|
||||
// HloPosition{%gte, {}}.
|
||||
//
|
||||
// Different HloLocations which share the same HloBuffer indicate mandatory
|
||||
// aliasing in the HLO module. These locations must share the same memory
|
||||
// Different HloPositions which share the same HloBuffer indicate mandatory
|
||||
// aliasing in the HLO module. These positions must share the same memory
|
||||
// allocation for correctness (the backends rely on this property). This differs
|
||||
// from incidental aliasing introduced by memory reuse in BufferAssignment where
|
||||
// different instructions may happen to get the same allocation.
|
||||
@ -80,17 +80,17 @@ class HloBuffer {
|
||||
Id id() const { return id_; }
|
||||
|
||||
// Add a value to the set of values held by this buffer. Also adds the
|
||||
// HloLocations of the value to the locations vector of the buffer. If the
|
||||
// HloPositions of the value to the positions vector of the buffer. If the
|
||||
// buffer already contains this value, then this method is a nop.
|
||||
void AddValue(const HloValue& value);
|
||||
|
||||
// Return the IDs of all values contained in this buffer.
|
||||
const std::vector<HloValue::Id>& value_ids() const { return value_ids_; }
|
||||
|
||||
// Return the locations (output of which instruction and at what index) where
|
||||
// the buffer is used. This is exactly the union of the locations of the
|
||||
// Return the positions (output of which instruction and at what index) where
|
||||
// the buffer is used. This is exactly the union of the positions of the
|
||||
// HloValues contained by the buffer.
|
||||
const std::vector<HloLocation>& locations() const { return locations_; }
|
||||
const std::vector<HloPosition>& positions() const { return positions_; }
|
||||
|
||||
string ToString() const;
|
||||
|
||||
@ -104,16 +104,16 @@ class HloBuffer {
|
||||
// The set of values contained in this buffer.
|
||||
std::vector<HloValue::Id> value_ids_;
|
||||
|
||||
// The set of locations where this buffer is used.
|
||||
std::vector<HloLocation> locations_;
|
||||
// The set of positions where this buffer is used.
|
||||
std::vector<HloPosition> positions_;
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer);
|
||||
|
||||
// A class representing the set of possible HloBuffers at a particular
|
||||
// HloLocation (shape index in the output of an instruction) in the XLA
|
||||
// HloPosition (shape index in the output of an instruction) in the XLA
|
||||
// graph. In most cases, the buffer set will have a single HloBuffer indicating
|
||||
// that the HloBuffer which appears at that particular location is known
|
||||
// that the HloBuffer which appears at that particular position is known
|
||||
// unambiguously at compile-time. However, tuple-shaped Select instructions can
|
||||
// introduce ambiguity as the tuple elements of the operands are passed by
|
||||
// reference into the output of the Select. For example:
|
||||
@ -123,7 +123,7 @@ std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer);
|
||||
// %tuple1 = Tuple(%x, %y)
|
||||
// %select = Select(%pred, %tuple0, %tuple1)
|
||||
//
|
||||
// In this case the HloBufferSet at HloLocation{%select, {0}} contains the
|
||||
// In this case the HloBufferSet at HloPosition{%select, {0}} contains the
|
||||
// HloBuffer holding %a and the HloBuffer holding %x.
|
||||
class HloBufferSet {
|
||||
public:
|
||||
|
@ -247,11 +247,11 @@ InstructionValueSet HloDataflowAnalysis::Phi(
|
||||
return new_value_set;
|
||||
}
|
||||
|
||||
void HloDataflowAnalysis::UpdateLocationsOfValuesAt(
|
||||
void HloDataflowAnalysis::UpdatePositionsOfValuesAt(
|
||||
HloInstruction* instruction, const InstructionValueSet& new_value_set,
|
||||
const InstructionValueSet* prev_value_set) {
|
||||
if (prev_value_set != nullptr) {
|
||||
// Remove locations from the old value set.
|
||||
// Remove positions from the old value set.
|
||||
prev_value_set->ForEachElement(
|
||||
[this, instruction](const ShapeIndex& index,
|
||||
const HloValueSet& value_set) {
|
||||
@ -260,17 +260,17 @@ void HloDataflowAnalysis::UpdateLocationsOfValuesAt(
|
||||
if (!ContainsKey(values_, value_id)) {
|
||||
continue;
|
||||
}
|
||||
// Don't remove the defining location of the value.
|
||||
// Don't remove the defining position of the value.
|
||||
HloValue& value = GetValue(value_id);
|
||||
if (instruction == value.defining_instruction()) {
|
||||
CHECK_EQ(index, value.defining_index());
|
||||
} else {
|
||||
value.RemoveLocation(instruction, index);
|
||||
value.RemovePosition(instruction, index);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
// Add locations in the new value set.
|
||||
// Add positions in the new value set.
|
||||
new_value_set.ForEachElement(
|
||||
[this, instruction](const ShapeIndex& index,
|
||||
const HloValueSet& value_set) {
|
||||
@ -279,7 +279,7 @@ void HloDataflowAnalysis::UpdateLocationsOfValuesAt(
|
||||
if (instruction == value.defining_instruction()) {
|
||||
CHECK_EQ(index, value.defining_index());
|
||||
} else {
|
||||
value.AddLocation(instruction, index);
|
||||
value.AddPosition(instruction, index);
|
||||
}
|
||||
}
|
||||
});
|
||||
@ -466,7 +466,7 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate(
|
||||
// Update uses. First clear all of the old uses at the particular
|
||||
// operands. Then add the new uses. There may be overlap between the old
|
||||
// uses and new uses.
|
||||
UpdateLocationsOfValuesAt(instruction, GetInstructionValueSet(instruction),
|
||||
UpdatePositionsOfValuesAt(instruction, GetInstructionValueSet(instruction),
|
||||
&old_value);
|
||||
}
|
||||
}
|
||||
@ -600,7 +600,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
|
||||
define_all_values();
|
||||
break;
|
||||
}
|
||||
UpdateLocationsOfValuesAt(instruction.get(),
|
||||
UpdatePositionsOfValuesAt(instruction.get(),
|
||||
GetInstructionValueSet(instruction.get()));
|
||||
}
|
||||
}
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Analysis for determining the possible set of values for all locations
|
||||
// Analysis for determining the possible set of values for all positions
|
||||
// (instructions and ShapeIndexes) in the HLO module. Analysis is module-scoped
|
||||
// tracking values across computation boundaries.
|
||||
|
||||
@ -170,14 +170,14 @@ class HloDataflowAnalysis {
|
||||
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs,
|
||||
bool skip_top_level = false);
|
||||
|
||||
// Updates the locations of the HloValues in the output of the given
|
||||
// Updates the positions of the HloValues in the output of the given
|
||||
// instruction. This should be called after the instruction value set of
|
||||
// 'instruction' has been changed. 'prev_value_set' must point to the previous
|
||||
// state of the value set prior to the change. 'prev_value_set' may be null if
|
||||
// this is the first time locations are being computed. The previous state is
|
||||
// necessary to efficiently remove locations which have been eliminated due to
|
||||
// this is the first time positions are being computed. The previous state is
|
||||
// necessary to efficiently remove positions which have been eliminated due to
|
||||
// changes in the instructions' InstructionValueSet.
|
||||
void UpdateLocationsOfValuesAt(
|
||||
void UpdatePositionsOfValuesAt(
|
||||
HloInstruction* instruction, const InstructionValueSet& new_value_set,
|
||||
const InstructionValueSet* prev_value_set = nullptr);
|
||||
|
||||
|
@ -51,7 +51,7 @@ class HloDataflowAnalysisTest : public HloTestBase,
|
||||
return *analysis_;
|
||||
}
|
||||
|
||||
// Return a vector of the HloValues at the given program location.
|
||||
// Return a vector of the HloValues at the given program position.
|
||||
std::vector<HloValue> HloValuesAt(const HloInstruction* instruction,
|
||||
const ShapeIndex& index = {}) {
|
||||
CHECK(analysis_ != nullptr);
|
||||
@ -101,14 +101,14 @@ TEST_P(HloDataflowAnalysisTest, BinaryOperation) {
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
|
||||
|
||||
// Verify the locations of the values. These locations are all trivial because
|
||||
// Verify the positions of the values. These positions are all trivial because
|
||||
// there are no instructions which forward values.
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).locations(),
|
||||
UnorderedElementsAre(HloLocation{constant1, {}}));
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(constant2).locations(),
|
||||
UnorderedElementsAre(HloLocation{constant2, {}}));
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(add).locations(),
|
||||
UnorderedElementsAre(HloLocation{add, {}}));
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).positions(),
|
||||
UnorderedElementsAre(HloPosition{constant1, {}}));
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(constant2).positions(),
|
||||
UnorderedElementsAre(HloPosition{constant2, {}}));
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(add).positions(),
|
||||
UnorderedElementsAre(HloPosition{add, {}}));
|
||||
|
||||
// Verify the uses of the values.
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
|
||||
@ -155,17 +155,17 @@ TEST_P(HloDataflowAnalysisTest, TupleAndGtes) {
|
||||
EXPECT_FALSE(analysis.ValueIsDefinedAt(gte1));
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
|
||||
|
||||
// Verify the locations of the values.
|
||||
// Verify the positions of the values.
|
||||
EXPECT_THAT(
|
||||
analysis.GetValueDefinedAt(param0).locations(),
|
||||
UnorderedElementsAre(HloLocation{param0, {}}, HloLocation{tuple, {0}},
|
||||
HloLocation{gte0, {}}));
|
||||
analysis.GetValueDefinedAt(param0).positions(),
|
||||
UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}},
|
||||
HloPosition{gte0, {}}));
|
||||
EXPECT_THAT(
|
||||
analysis.GetValueDefinedAt(param1).locations(),
|
||||
UnorderedElementsAre(HloLocation{param1, {}}, HloLocation{tuple, {1}},
|
||||
HloLocation{gte1, {}}));
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(tuple).locations(),
|
||||
UnorderedElementsAre(HloLocation{tuple, {}}));
|
||||
analysis.GetValueDefinedAt(param1).positions(),
|
||||
UnorderedElementsAre(HloPosition{param1, {}}, HloPosition{tuple, {1}},
|
||||
HloPosition{gte1, {}}));
|
||||
EXPECT_THAT(analysis.GetValueDefinedAt(tuple).positions(),
|
||||
UnorderedElementsAre(HloPosition{tuple, {}}));
|
||||
|
||||
// Verify uses. Of interest is that a GetTupleElement instruction is only a
|
||||
// use of the top-level value in the tuple operand.
|
||||
@ -200,15 +200,15 @@ TEST_P(HloDataflowAnalysisTest, NestedTuple) {
|
||||
|
||||
EXPECT_EQ(analysis.values().size(), 4);
|
||||
|
||||
// Verify locations and uses.
|
||||
// Verify positions and uses.
|
||||
EXPECT_THAT(
|
||||
analysis.GetValueDefinedAt(constant1).locations(),
|
||||
analysis.GetValueDefinedAt(constant1).positions(),
|
||||
UnorderedElementsAre(
|
||||
HloLocation{constant1, {}}, HloLocation{tuple, {0}},
|
||||
HloLocation{nested_tuple, {0, 0}}, HloLocation{nested_tuple, {1, 0}},
|
||||
HloLocation{nested_tuple, {2}}, HloLocation{gte_tuple, {0}},
|
||||
HloLocation{gte_out, {}}));
|
||||
// Constant values should have no uses though one is live out. The locations
|
||||
HloPosition{constant1, {}}, HloPosition{tuple, {0}},
|
||||
HloPosition{nested_tuple, {0, 0}}, HloPosition{nested_tuple, {1, 0}},
|
||||
HloPosition{nested_tuple, {2}}, HloPosition{gte_tuple, {0}},
|
||||
HloPosition{gte_out, {}}));
|
||||
// Constant values should have no uses though one is live out. The positions
|
||||
// where they appear as operands are on instructions which do not use the
|
||||
// values (eg, Tuple).
|
||||
EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).uses().empty());
|
||||
|
@ -38,18 +38,18 @@ namespace xla {
|
||||
using ::tensorflow::strings::StrAppend;
|
||||
using ::tensorflow::strings::StrCat;
|
||||
|
||||
const Shape& HloLocation::shape() const {
|
||||
const Shape& HloPosition::shape() const {
|
||||
return ShapeUtil::GetSubshape(instruction->shape(), index);
|
||||
}
|
||||
|
||||
string HloLocation::ToString() const {
|
||||
string HloPosition::ToString() const {
|
||||
string index_str =
|
||||
ShapeUtil::IsTuple(instruction->shape()) ? (" " + index.ToString()) : "";
|
||||
return StrCat(instruction->name(), index_str);
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const HloLocation& location) {
|
||||
out << location.ToString();
|
||||
std::ostream& operator<<(std::ostream& out, const HloPosition& position) {
|
||||
out << position.ToString();
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -69,8 +69,8 @@ std::ostream& operator<<(std::ostream& out, const HloUse& use) {
|
||||
HloValue::HloValue(HloValue::Id id, HloInstruction* instruction,
|
||||
const ShapeIndex& index, bool is_phi)
|
||||
: id_(id), is_phi_(is_phi) {
|
||||
// The defining location is always the first element in the locations_ vector.
|
||||
AddLocation(instruction, index);
|
||||
// The defining position is always the first element in the positions_ vector.
|
||||
AddPosition(instruction, index);
|
||||
}
|
||||
|
||||
bool HloValue::operator==(const HloValue& other) const {
|
||||
@ -95,9 +95,9 @@ string HloValue::ToShortString() const {
|
||||
|
||||
string HloValue::ToString(int indent) const {
|
||||
string indentation(indent, ' ');
|
||||
string out = StrCat(indentation, ToShortString(), ", locations:\n");
|
||||
for (const HloLocation& location : locations()) {
|
||||
StrAppend(&out, indentation, " ", location.ToString(), "\n");
|
||||
string out = StrCat(indentation, ToShortString(), ", positions:\n");
|
||||
for (const HloPosition& position : positions()) {
|
||||
StrAppend(&out, indentation, " ", position.ToString(), "\n");
|
||||
}
|
||||
StrAppend(&out, indentation, " uses:\n");
|
||||
for (const HloUse& use : uses()) {
|
||||
@ -150,22 +150,22 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index,
|
||||
|
||||
} // namespace
|
||||
|
||||
void HloValue::AddLocation(HloInstruction* instruction,
|
||||
void HloValue::AddPosition(HloInstruction* instruction,
|
||||
const ShapeIndex& index) {
|
||||
HloLocation new_location{instruction, index};
|
||||
HloPosition new_position{instruction, index};
|
||||
|
||||
// The new location must not already exist in locations_.
|
||||
for (const HloLocation& location : locations_) {
|
||||
DCHECK_NE(location, new_location);
|
||||
// The new position must not already exist in positions_.
|
||||
for (const HloPosition& position : positions_) {
|
||||
DCHECK_NE(position, new_position);
|
||||
}
|
||||
// The shape of the new location must match existing locations.
|
||||
if (!locations_.empty()) {
|
||||
// The shape of the new position must match existing positions.
|
||||
if (!positions_.empty()) {
|
||||
CHECK(
|
||||
ShapeUtil::Compatible(locations_.front().shape(), new_location.shape()))
|
||||
<< "front: " << locations_.front() << " new: " << new_location;
|
||||
ShapeUtil::Compatible(positions_.front().shape(), new_position.shape()))
|
||||
<< "front: " << positions_.front() << " new: " << new_position;
|
||||
}
|
||||
|
||||
locations_.push_back(std::move(new_location));
|
||||
positions_.push_back(std::move(new_position));
|
||||
|
||||
// Update uses.
|
||||
for (HloInstruction* user : instruction->users()) {
|
||||
@ -194,23 +194,23 @@ void HloValue::AddLocation(HloInstruction* instruction,
|
||||
}
|
||||
}
|
||||
|
||||
void HloValue::RemoveLocation(HloInstruction* instruction,
|
||||
void HloValue::RemovePosition(HloInstruction* instruction,
|
||||
const ShapeIndex& index) {
|
||||
// The defining location cannot be removed.
|
||||
// The defining position cannot be removed.
|
||||
CHECK(!(instruction == defining_instruction() && index == defining_index()));
|
||||
|
||||
int64 size_before = locations_.size();
|
||||
locations_.erase(
|
||||
std::remove_if(locations_.begin(), locations_.end(),
|
||||
[instruction, &index](const HloLocation& location) {
|
||||
return location.instruction == instruction &&
|
||||
location.index == index;
|
||||
int64 size_before = positions_.size();
|
||||
positions_.erase(
|
||||
std::remove_if(positions_.begin(), positions_.end(),
|
||||
[instruction, &index](const HloPosition& position) {
|
||||
return position.instruction == instruction &&
|
||||
position.index == index;
|
||||
}),
|
||||
locations_.end());
|
||||
// Only a single location should have been removed.
|
||||
CHECK_EQ(locations_.size(), size_before - 1);
|
||||
positions_.end());
|
||||
// Only a single position should have been removed.
|
||||
CHECK_EQ(positions_.size(), size_before - 1);
|
||||
|
||||
// Update uses which referred to this location.
|
||||
// Update uses which referred to this position.
|
||||
uses_.erase(std::remove_if(uses_.begin(), uses_.end(),
|
||||
[instruction, &index](const HloUse& use) {
|
||||
return use.instruction->operand(
|
||||
@ -221,8 +221,8 @@ void HloValue::RemoveLocation(HloInstruction* instruction,
|
||||
|
||||
// Returns whether this value is contained in the given instruction's output.
|
||||
auto is_contained_in = [this](const HloInstruction* instruction) {
|
||||
for (const HloLocation& location : locations()) {
|
||||
if (location.instruction == instruction) {
|
||||
for (const HloPosition& position : positions()) {
|
||||
if (position.instruction == instruction) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@ -231,7 +231,7 @@ void HloValue::RemoveLocation(HloInstruction* instruction,
|
||||
|
||||
const HloModule& module = *instruction->parent()->parent();
|
||||
if (instruction == module.entry_computation()->root_instruction()) {
|
||||
// Value has been removed from a location in the entry root instruction.
|
||||
// Value has been removed from a position in the entry root instruction.
|
||||
live_out_of_module_ =
|
||||
is_contained_in(module.entry_computation()->root_instruction());
|
||||
}
|
||||
|
@ -30,24 +30,24 @@ limitations under the License.
|
||||
namespace xla {
|
||||
|
||||
// Abstraction which identifies a specific point in the XLA graph. An
|
||||
// HloLocation specifies a ShapeIndex within the output of a specific
|
||||
// HloPosition specifies a ShapeIndex within the output of a specific
|
||||
// instruction.
|
||||
struct HloLocation {
|
||||
struct HloPosition {
|
||||
HloInstruction* instruction;
|
||||
ShapeIndex index;
|
||||
|
||||
// Returns the shape at this location.
|
||||
// Returns the shape at this position.
|
||||
const Shape& shape() const;
|
||||
|
||||
string ToString() const;
|
||||
|
||||
bool operator==(const HloLocation& other) const {
|
||||
bool operator==(const HloPosition& other) const {
|
||||
return instruction == other.instruction && index == other.index;
|
||||
}
|
||||
bool operator!=(const HloLocation& other) const { return !(*this == other); }
|
||||
bool operator!=(const HloPosition& other) const { return !(*this == other); }
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const HloLocation& location);
|
||||
std::ostream& operator<<(std::ostream& out, const HloPosition& position);
|
||||
|
||||
// Defines a single use of an HLO value.
|
||||
struct HloUse {
|
||||
@ -111,28 +111,28 @@ class HloValue {
|
||||
// Returns whether this value is a phi value.
|
||||
bool is_phi() const { return is_phi_; }
|
||||
|
||||
// Return the location where this value is defined.
|
||||
const HloLocation& defining_location() const { return locations_[0]; }
|
||||
// Return the position where this value is defined.
|
||||
const HloPosition& defining_position() const { return positions_[0]; }
|
||||
|
||||
// Return the instruction which defines this HloValue.
|
||||
HloInstruction* defining_instruction() const {
|
||||
return defining_location().instruction;
|
||||
return defining_position().instruction;
|
||||
}
|
||||
|
||||
// Return the shape index at which this HloValue is defined in the output of
|
||||
// its defining instruction.
|
||||
const ShapeIndex& defining_index() const { return defining_location().index; }
|
||||
const ShapeIndex& defining_index() const { return defining_position().index; }
|
||||
|
||||
// Return the shape of this HloValue.
|
||||
const Shape& shape() const { return defining_location().shape(); }
|
||||
const Shape& shape() const { return defining_position().shape(); }
|
||||
|
||||
// Add or remove a location at which the HloValue appears. The definition
|
||||
// location can not be removed. The uses of the HloValue are updated.
|
||||
void AddLocation(HloInstruction* instruction, const ShapeIndex& index);
|
||||
void RemoveLocation(HloInstruction* instruction, const ShapeIndex& index);
|
||||
// Add or remove a position at which the HloValue appears. The definition
|
||||
// position can not be removed. The uses of the HloValue are updated.
|
||||
void AddPosition(HloInstruction* instruction, const ShapeIndex& index);
|
||||
void RemovePosition(HloInstruction* instruction, const ShapeIndex& index);
|
||||
|
||||
// Return all locations of the HloValue in the module.
|
||||
const std::vector<HloLocation>& locations() const { return locations_; }
|
||||
// Return all positions of the HloValue in the module.
|
||||
const std::vector<HloPosition>& positions() const { return positions_; }
|
||||
|
||||
// Return all uses of the HloValue.
|
||||
const std::vector<HloUse>& uses() const { return uses_; }
|
||||
@ -158,9 +158,9 @@ class HloValue {
|
||||
// Whether this instruction is a phi value.
|
||||
const bool is_phi_;
|
||||
|
||||
// The set of locations of this HloValue. The first element is always the
|
||||
// location of the definition.
|
||||
std::vector<HloLocation> locations_;
|
||||
// The set of positions of this HloValue. The first element is always the
|
||||
// position of the definition.
|
||||
std::vector<HloPosition> positions_;
|
||||
|
||||
// The set of uses of this HloValue.
|
||||
std::vector<HloUse> uses_;
|
||||
|
Loading…
Reference in New Issue
Block a user