Internal change.

Change: 151587999
This commit is contained in:
A. Unique TensorFlower 2017-03-29 09:00:26 -08:00 committed by TensorFlower Gardener
parent 19526c1318
commit 4718ac6b15
6 changed files with 112 additions and 11 deletions

View File

@ -212,6 +212,13 @@ StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueTopLevelSlice(
return GetUniqueSlice(instruction, /*index=*/{});
}
bool BufferAssignment::SharesSliceAtIndex(
const HloInstruction* hlo_a, const ShapeIndex& shape_index_a,
const HloInstruction* hlo_b, const ShapeIndex& shape_index_b) const {
return GetUniqueSlice(hlo_a, shape_index_a).ConsumeValueOrDie() ==
GetUniqueSlice(hlo_b, shape_index_b).ConsumeValueOrDie();
}
StatusOr<BufferAllocation::Slice>
BufferAssignment::GetUniqueTopLevelOutputSlice() const {
return GetUniqueTopLevelSlice(

View File

@ -294,6 +294,15 @@ class BufferAssignment {
return GetPointsToSet(instruction).element(index);
}
// Returns true if 'hlo_a{shape_index_a}' and 'hlo_b{shape_index_b}'
// share the same BufferAllocation::Slice.
// Returns false otherwise.
// REQUIRES: BufferAssignment assigned allocations to both instructions.
bool SharesSliceAtIndex(const HloInstruction* hlo_a,
const ShapeIndex& shape_index_a,
const HloInstruction* hlo_b,
const ShapeIndex& shape_index_b) const;
// Returns the underlying points-to analysis used for this assignment.
const TuplePointsToAnalysis& points_to_analysis() const {
return liveness_->points_to_analysis();

View File

@ -121,6 +121,7 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
// *) Is element-wise.
// *) Is a loop fusion instruction (with DynamicUpdateSlice fused root) where
// the singleton use of 'a' at 'a.index' is the fused root at operand 0.
// *) Use of 'operand' is DynamicUpdateSlice at operand index 0.
for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) {
if (b.instruction()->IsUserOf(alias.instruction()) &&
!CanShareOperandBufferWithUser(alias.instruction(), alias.index(),

View File

@ -612,6 +612,93 @@ TEST_F(FusedDynamicUpdateSliceLivenessTest, WithInterference) {
EXPECT_TRUE(Run(/*update_uses_tuple_element1=*/true));
}
class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
protected:
// Builds and runs a computation (see test case computation graphs below).
// Runs BufferLiveness on this computation.
// Returns whether buffer interference is detected between tuple-shaped
// parameter and root instructions at tuple element 1.
bool Run(const bool tuple_element1_has_two_uses) {
auto builder = HloComputation::Builder(TestName());
// Create param0 Tuple.
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
Shape update_shape = ShapeUtil::MakeShape(F32, {3});
auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "param0"));
auto gte0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 0));
auto gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
if (tuple_element1_has_two_uses) {
// Add 'gte0' and 'gte1' to create another user of 'gte1'.
gte0 = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape, HloOpcode::kAdd, gte0, gte1));
}
// Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
auto starts = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
auto dynamic_update_slice =
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, starts));
// Create output tuple.
auto tuple_root = builder.AddInstruction(
HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
// Build module and get reference to entry computation.
auto module = MakeUnique<HloModule>(TestName());
module->AddEntryComputation(builder.Build());
// Run BufferLiveness on 'module'.
auto liveness =
BufferLiveness::Run(module.get(),
MakeUnique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// Return whether or not buffers interfernce is detected between
// 'tuple_param0' and 'tuple_root' at shape index '{1}'.
return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
}
};
// Tests that live ranges of buffers Param0[1] and Tuple[1] do not overlap in
// the following computation (because DynamicUpdateSlice (at operand 0) is the
// unique user):
//
// Parameter0
// | |
// GTE(0) GTE(1) Const Const
// | \ | /
// | DynamicUpdateSlice
// \ /
// Tuple
//
TEST_F(DynamicUpdateSliceLivenessTest, NoInterference) {
EXPECT_FALSE(Run(/*tuple_element1_has_two_uses=*/false));
}
// Tests that live ranges of buffers Param0[1] and Tuple[1] do overlap because
// GTE(1) has two users:
// 1) DynamicUpdateSlice at operand 0.
// 2) Add at operand 1.
//
// Parameter0
// | |
// GTE(0) GTE(1)
// | / |
// | / |
// Add | Const Const
// | | | |
// | DynamicUpdateSlice
// \ /
// Tuple
//
TEST_F(DynamicUpdateSliceLivenessTest, WithInterference) {
EXPECT_TRUE(Run(/*tuple_element1_has_two_uses=*/true));
}
} // namespace
} // namespace xla

View File

@ -283,14 +283,7 @@ bool CanUpdateDynamicSliceInPlace(const BufferAssignment& assignment,
return false;
}
auto* operand = fusion->operand(fusion_operand->parameter_number());
BufferAllocation::Slice operand_slice =
assignment.GetUniqueSlice(operand, index).ConsumeValueOrDie();
BufferAllocation::Slice fusion_slice =
assignment.GetUniqueTopLevelSlice(fusion).ConsumeValueOrDie();
return operand_slice == fusion_slice;
return assignment.SharesSliceAtIndex(fusion, {}, operand, index);
}
} // namespace
@ -387,9 +380,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
TF_RETURN_IF_ERROR(root->Accept(&fused_emitter));
// Recursively lookup 'fusion_operand' for DynamicUpdateSlice operand 0.
ShapeIndex index_unused;
auto* fusion_operand =
LatestNonGteAncestorAndIndex(root->operand(0), &index_unused);
auto* fusion_operand = LatestNonGteAncestor(root->operand(0));
CHECK_EQ(HloOpcode::kParameter, fusion_operand->opcode());
// Operand(0) the input array which shares an allocation with the output.

View File

@ -106,6 +106,7 @@ std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
// *) Is a loop fusion instruction where the only use of 'operand' at 'index'
// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root
// at operand 0.
// *) Use of 'operand' is DynamicUpdateSlice at operand index 0.
bool CanShareOperandBufferWithUser(
HloInstruction* operand, const ShapeIndex& operand_index,
HloInstruction* user, const ShapeIndex& user_index,
@ -143,6 +144,11 @@ bool CanShareOperandBufferWithUser(
break;
}
return false;
} else if (user->opcode() == HloOpcode::kDynamicUpdateSlice) {
// We eliminated other users in BufferLiveness::live_range_strictly_before,
// so here we just need to check that the use is at operand index 0.
std::vector<int64> operand_indices = user->OperandIndices(operand);
return operand_indices.size() == 1 && operand_indices[0] == 0;
}
// Check if 'user' is element-wise.
return user->IsElementwise();