[XLA] Added HasAllocationAt() helper function.
PiperOrigin-RevId: 163742985
This commit is contained in:
parent
18304683ec
commit
6ba02f0e92
@ -200,10 +200,10 @@ BufferAllocation* BufferAssignment::GetMutableAllocation(
|
||||
return const_cast<BufferAllocation*>(&GetAllocation(index));
|
||||
}
|
||||
|
||||
bool BufferAssignment::HasTopLevelAllocation(
|
||||
const HloInstruction* instruction) const {
|
||||
bool BufferAssignment::HasAllocationAt(const HloInstruction* instruction,
|
||||
const ShapeIndex& index) const {
|
||||
for (const LogicalBuffer* buffer :
|
||||
GetPointsToSet(instruction).element(/*index=*/{})) {
|
||||
GetPointsToSet(instruction).element(index)) {
|
||||
if (allocation_index_for_buffer_.count(buffer) > 0) {
|
||||
return true;
|
||||
}
|
||||
@ -211,6 +211,11 @@ bool BufferAssignment::HasTopLevelAllocation(
|
||||
return false;
|
||||
}
|
||||
|
||||
bool BufferAssignment::HasTopLevelAllocation(
|
||||
const HloInstruction* instruction) const {
|
||||
return HasAllocationAt(instruction, /*index=*/{});
|
||||
}
|
||||
|
||||
StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueSlice(
|
||||
const HloInstruction* instruction, const ShapeIndex& index) const {
|
||||
VLOG(3) << "Trying to find unique slice for " << instruction->name() << " ["
|
||||
|
@ -281,6 +281,11 @@ class BufferAssignment {
|
||||
std::set<BufferAllocation::Slice> GetAllSlices(
|
||||
const HloInstruction* instruction, const ShapeIndex& index) const;
|
||||
|
||||
// Convenience function which returns whether the buffer of the
|
||||
// instruction at the given index is assigned an allocation.
|
||||
bool HasAllocationAt(const HloInstruction* instruction,
|
||||
const ShapeIndex& index) const;
|
||||
|
||||
// Convenience function which returns whether the top-level buffer of the
|
||||
// instruction (index == {}) is assigned an allocation.
|
||||
bool HasTopLevelAllocation(const HloInstruction* instruction) const;
|
||||
|
@ -296,6 +296,34 @@ TEST_F(BufferAssignmentTest, BufferForConst) {
|
||||
GetAssignedOutputAllocation(*buffers, add);
|
||||
}
|
||||
|
||||
TEST_F(BufferAssignmentTest, HasAllocationAt) {
|
||||
// Create a tuple with non-const and const elements and check that
|
||||
// HasAllocationAt works correctly.
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, f32vec100_, "param0"));
|
||||
auto constant = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<int>(1)));
|
||||
auto negate = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
|
||||
auto tuple = builder.AddInstruction(
|
||||
HloInstruction::CreateTuple({negate, param0, constant}));
|
||||
auto module = CreateNewModule();
|
||||
module->AddEntryComputation(builder.Build());
|
||||
|
||||
auto buffers = RunBufferAssignment(module.get());
|
||||
// Make sure that HasAllocationAt() agrees with what HasTopLevelAllocation()
|
||||
// reports for the instruction directly.
|
||||
EXPECT_EQ(buffers->HasTopLevelAllocation(tuple),
|
||||
buffers->HasAllocationAt(tuple, /*index=*/{}));
|
||||
EXPECT_EQ(buffers->HasTopLevelAllocation(negate),
|
||||
buffers->HasAllocationAt(tuple, /*index=*/{0}));
|
||||
EXPECT_EQ(buffers->HasTopLevelAllocation(param0),
|
||||
buffers->HasAllocationAt(tuple, /*index=*/{1}));
|
||||
EXPECT_EQ(buffers->HasTopLevelAllocation(constant),
|
||||
buffers->HasAllocationAt(tuple, /*index=*/{2}));
|
||||
}
|
||||
|
||||
TEST_F(BufferAssignmentTest, BufferForOutputConst) {
|
||||
// This computation copies a constant to output.
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
|
Loading…
x
Reference in New Issue
Block a user