Internal change.
Change: 151587999
This commit is contained in:
parent
19526c1318
commit
4718ac6b15
tensorflow/compiler/xla/service
@ -212,6 +212,13 @@ StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueTopLevelSlice(
|
||||
return GetUniqueSlice(instruction, /*index=*/{});
|
||||
}
|
||||
|
||||
bool BufferAssignment::SharesSliceAtIndex(
|
||||
const HloInstruction* hlo_a, const ShapeIndex& shape_index_a,
|
||||
const HloInstruction* hlo_b, const ShapeIndex& shape_index_b) const {
|
||||
return GetUniqueSlice(hlo_a, shape_index_a).ConsumeValueOrDie() ==
|
||||
GetUniqueSlice(hlo_b, shape_index_b).ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
StatusOr<BufferAllocation::Slice>
|
||||
BufferAssignment::GetUniqueTopLevelOutputSlice() const {
|
||||
return GetUniqueTopLevelSlice(
|
||||
|
@ -294,6 +294,15 @@ class BufferAssignment {
|
||||
return GetPointsToSet(instruction).element(index);
|
||||
}
|
||||
|
||||
// Returns true if 'hlo_a{shape_index_a}' and 'hlo_b{shape_index_b}'
|
||||
// share the same BufferAllocation::Slice.
|
||||
// Returns false otherwise.
|
||||
// REQUIRES: BufferAssignment assigned allocations to both instructions.
|
||||
bool SharesSliceAtIndex(const HloInstruction* hlo_a,
|
||||
const ShapeIndex& shape_index_a,
|
||||
const HloInstruction* hlo_b,
|
||||
const ShapeIndex& shape_index_b) const;
|
||||
|
||||
// Returns the underlying points-to analysis used for this assignment.
|
||||
const TuplePointsToAnalysis& points_to_analysis() const {
|
||||
return liveness_->points_to_analysis();
|
||||
|
@ -121,6 +121,7 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
|
||||
// *) Is element-wise.
|
||||
// *) Is a loop fusion instruction (with DynamicUpdateSlice fused root) where
|
||||
// the singleton use of 'a' at 'a.index' is the fused root at operand 0.
|
||||
// *) Use of 'operand' is DynamicUpdateSlice at operand index 0.
|
||||
for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) {
|
||||
if (b.instruction()->IsUserOf(alias.instruction()) &&
|
||||
!CanShareOperandBufferWithUser(alias.instruction(), alias.index(),
|
||||
|
@ -612,6 +612,93 @@ TEST_F(FusedDynamicUpdateSliceLivenessTest, WithInterference) {
|
||||
EXPECT_TRUE(Run(/*update_uses_tuple_element1=*/true));
|
||||
}
|
||||
|
||||
class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
|
||||
protected:
|
||||
// Builds and runs a computation (see test case computation graphs below).
|
||||
// Runs BufferLiveness on this computation.
|
||||
// Returns whether buffer interference is detected between tuple-shaped
|
||||
// parameter and root instructions at tuple element 1.
|
||||
bool Run(const bool tuple_element1_has_two_uses) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
// Create param0 Tuple.
|
||||
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
|
||||
Shape update_shape = ShapeUtil::MakeShape(F32, {3});
|
||||
auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "param0"));
|
||||
|
||||
auto gte0 = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 0));
|
||||
|
||||
auto gte1 = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));
|
||||
|
||||
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
|
||||
LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
|
||||
|
||||
if (tuple_element1_has_two_uses) {
|
||||
// Add 'gte0' and 'gte1' to create another user of 'gte1'.
|
||||
gte0 = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
data_shape, HloOpcode::kAdd, gte0, gte1));
|
||||
}
|
||||
// Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
|
||||
auto starts = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
|
||||
auto dynamic_update_slice =
|
||||
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
|
||||
data_shape, gte1, update, starts));
|
||||
// Create output tuple.
|
||||
auto tuple_root = builder.AddInstruction(
|
||||
HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
|
||||
// Build module and get reference to entry computation.
|
||||
auto module = MakeUnique<HloModule>(TestName());
|
||||
module->AddEntryComputation(builder.Build());
|
||||
// Run BufferLiveness on 'module'.
|
||||
auto liveness =
|
||||
BufferLiveness::Run(module.get(),
|
||||
MakeUnique<DependencyHloOrdering>(module.get()))
|
||||
.ConsumeValueOrDie();
|
||||
// Return whether or not buffers interfernce is detected between
|
||||
// 'tuple_param0' and 'tuple_root' at shape index '{1}'.
|
||||
return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
|
||||
}
|
||||
};
|
||||
|
||||
// Tests that live ranges of buffers Param0[1] and Tuple[1] do not overlap in
|
||||
// the following computation (because DynamicUpdateSlice (at operand 0) is the
|
||||
// unique user):
|
||||
//
|
||||
// Parameter0
|
||||
// | |
|
||||
// GTE(0) GTE(1) Const Const
|
||||
// | \ | /
|
||||
// | DynamicUpdateSlice
|
||||
// \ /
|
||||
// Tuple
|
||||
//
|
||||
TEST_F(DynamicUpdateSliceLivenessTest, NoInterference) {
|
||||
EXPECT_FALSE(Run(/*tuple_element1_has_two_uses=*/false));
|
||||
}
|
||||
|
||||
// Tests that live ranges of buffers Param0[1] and Tuple[1] do overlap because
|
||||
// GTE(1) has two users:
|
||||
// 1) DynamicUpdateSlice at operand 0.
|
||||
// 2) Add at operand 1.
|
||||
//
|
||||
// Parameter0
|
||||
// | |
|
||||
// GTE(0) GTE(1)
|
||||
// | / |
|
||||
// | / |
|
||||
// Add | Const Const
|
||||
// | | | |
|
||||
// | DynamicUpdateSlice
|
||||
// \ /
|
||||
// Tuple
|
||||
//
|
||||
TEST_F(DynamicUpdateSliceLivenessTest, WithInterference) {
|
||||
EXPECT_TRUE(Run(/*tuple_element1_has_two_uses=*/true));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace xla
|
||||
|
@ -283,14 +283,7 @@ bool CanUpdateDynamicSliceInPlace(const BufferAssignment& assignment,
|
||||
return false;
|
||||
}
|
||||
auto* operand = fusion->operand(fusion_operand->parameter_number());
|
||||
|
||||
BufferAllocation::Slice operand_slice =
|
||||
assignment.GetUniqueSlice(operand, index).ConsumeValueOrDie();
|
||||
|
||||
BufferAllocation::Slice fusion_slice =
|
||||
assignment.GetUniqueTopLevelSlice(fusion).ConsumeValueOrDie();
|
||||
|
||||
return operand_slice == fusion_slice;
|
||||
return assignment.SharesSliceAtIndex(fusion, {}, operand, index);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -387,9 +380,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
|
||||
TF_RETURN_IF_ERROR(root->Accept(&fused_emitter));
|
||||
|
||||
// Recursively lookup 'fusion_operand' for DynamicUpdateSlice operand 0.
|
||||
ShapeIndex index_unused;
|
||||
auto* fusion_operand =
|
||||
LatestNonGteAncestorAndIndex(root->operand(0), &index_unused);
|
||||
auto* fusion_operand = LatestNonGteAncestor(root->operand(0));
|
||||
CHECK_EQ(HloOpcode::kParameter, fusion_operand->opcode());
|
||||
|
||||
// Operand(0) the input array which shares an allocation with the output.
|
||||
|
@ -106,6 +106,7 @@ std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
|
||||
// *) Is a loop fusion instruction where the only use of 'operand' at 'index'
|
||||
// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root
|
||||
// at operand 0.
|
||||
// *) Use of 'operand' is DynamicUpdateSlice at operand index 0.
|
||||
bool CanShareOperandBufferWithUser(
|
||||
HloInstruction* operand, const ShapeIndex& operand_index,
|
||||
HloInstruction* user, const ShapeIndex& user_index,
|
||||
@ -143,6 +144,11 @@ bool CanShareOperandBufferWithUser(
|
||||
break;
|
||||
}
|
||||
return false;
|
||||
} else if (user->opcode() == HloOpcode::kDynamicUpdateSlice) {
|
||||
// We eliminated other users in BufferLiveness::live_range_strictly_before,
|
||||
// so here we just need to check that the use is at operand index 0.
|
||||
std::vector<int64> operand_indices = user->OperandIndices(operand);
|
||||
return operand_indices.size() == 1 && operand_indices[0] == 0;
|
||||
}
|
||||
// Check if 'user' is element-wise.
|
||||
return user->IsElementwise();
|
||||
|
Loading…
Reference in New Issue
Block a user