[XLA] Fix a subtle issue in copy_insertion due to the interaction between copy

overriding logic and RecordIndicesToColocatingBuffers:

- When building the ShapeTree of instructions to be copy-overridden, it is
possible that we create a single kCopy for two identical instructions. An
example is:

    %tuple.19 = tuple(%constant.4, %constant.1793, %constant.1793)

where it is used in a while.init operand, and constant.1793 is read-only within
the loop and also used by another while loop. The copy overriding pass will then
create the following (logical, not finalized) tuple:

    %tuple.19 = tuple(%constant.4, %copy.5, %copy.5)

- In the subsequent pass RecordAmbiguousOrNonDistinctIndices, which adds copies
to ensure the points-to set is distinct, the duplicate %copy.5 entries are
ignored because they are not yet finalized, and these indices (1 and 2 in the
example) are still marked as to-be-copied.

Therefore distinctness is lost.

This fix applies to the override building stage, to explicitly avoid creating
shared copies for non-distinct buffers.

PiperOrigin-RevId: 157872231
This commit is contained in:
Kay Zhu 2017-06-02 14:02:25 -07:00 committed by TensorFlower Gardener
parent f4b8d21b8e
commit 366990d92d
2 changed files with 167 additions and 30 deletions

View File

@ -229,25 +229,26 @@ Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices(
// Mapping from LogicalBuffer to index (used to detect non-distinct indices).
FlatMap<const LogicalBuffer*, std::vector<ShapeIndex>>
buffer_to_source_indices;
TF_RETURN_IF_ERROR(points_to.ForEachElement([this, &buffer_to_source_indices](
const ShapeIndex& index, bool /*is_leaf*/,
const std::vector<const LogicalBuffer*>& buffers) {
if (buffers.size() > 1) {
// Record ambiguous points-to set at 'index'.
if (!indices_to_copy_.element(index)) {
VLOG(2) << "Adding copy of buffer for instruction: "
<< instruction_->name()
<< " at index: " << tensorflow::str_util::Join(index, ",")
<< " with ambiguous points-to set.";
RecordIndex(index);
}
}
// For each 'buffer': record a mapping from 'buffer' to 'index'.
for (const LogicalBuffer* buffer : buffers) {
buffer_to_source_indices[buffer].push_back(index);
}
return Status::OK();
}));
TF_RETURN_IF_ERROR(points_to.ForEachElement(
[this, &buffer_to_source_indices](
const ShapeIndex& index, bool /*is_leaf*/,
const std::vector<const LogicalBuffer*>& buffers) {
if (buffers.size() > 1) {
// Record ambiguous points-to set at 'index'.
if (!indices_to_copy_.element(index)) {
VLOG(2) << "Adding copy of buffer for instruction: "
<< instruction_->name()
<< " at index: " << tensorflow::str_util::Join(index, ",")
<< " with ambiguous points-to set.";
RecordIndex(index);
}
}
// For each 'buffer': record a mapping from 'buffer' to 'index'.
for (const LogicalBuffer* buffer : buffers) {
buffer_to_source_indices[buffer].push_back(index);
}
return Status::OK();
}));
// Record all non-distinct indices detected in 'buffer_to_source_indices'.
for (const auto& buff_to_src : buffer_to_source_indices) {
@ -449,11 +450,15 @@ RevertReadOnlyIndicesForEntryParamsAndConstants(
FlatMap<const HloInstruction*, HloInstruction*>* shared_copies) {
const HloInstruction* init_hlo = while_hlo->operand(0);
const PointsToSet& points_to = points_to_analysis.GetPointsToSet(init_hlo);
// Mapping from LogicalBuffer to index (used to detect non-distinct indices).
FlatSet<const LogicalBuffer*> buffer_set;
ShapeTree<HloInstruction*> copy_overrides(init_hlo->shape());
TF_RETURN_IF_ERROR(points_to.ForEachElement(
[init_hlo, read_only_indices, shared_copies, &copy_overrides](
const ShapeIndex& index, bool /*is_leaf*/,
const std::vector<const LogicalBuffer*>& buffers) {
[init_hlo, read_only_indices, shared_copies, &buffer_set,
&copy_overrides](const ShapeIndex& index, bool /*is_leaf*/,
const std::vector<const LogicalBuffer*>& buffers) {
// Look for read-only entry parameters.
if (!read_only_indices->element(index)) {
return Status::OK();
@ -468,6 +473,7 @@ RevertReadOnlyIndicesForEntryParamsAndConstants(
if (!is_entry_parameter && !is_constant) {
continue;
}
// We have found an entry parameter or constant that is read-only in
// the while body. These buffers are managed by the caller, and cannot
// be aliased with non-parameter buffers. Revert this read-only index,
@ -476,16 +482,17 @@ RevertReadOnlyIndicesForEntryParamsAndConstants(
// Optimization to allow multiple while loops that share the same
// read-only entry parameters (or constants) to share a single copy.
// Only unambiguous array-shaped buffers are allowed, to reduce code
// complexity. The shape of the entry parameter must be identical to
// the shape of the init_hlo at this index, to ensure there were no
// intervening bitcast or GTE instructions, which are also hard to
// handle.
// Only unambiguous and distinct array-shaped buffers are allowed, to
// reduce code complexity. The shape of the entry parameter must be
// identical to the shape of the init_hlo at this index, to ensure
// there were no intervening bitcast or GTE instructions, which are
// also hard to handle.
const Shape& pointee_shape = pointee->shape();
const Shape& init_shape =
ShapeUtil::GetSubshape(init_hlo->shape(), index);
if (buffers.size() == 1 && ShapeUtil::IsArray(pointee_shape) &&
ShapeUtil::Equal(pointee_shape, init_shape)) {
ShapeUtil::Equal(pointee_shape, init_shape) &&
buffer_set.count(buffer) < 1) {
HloInstruction** copy = &(*shared_copies)[pointee];
if (*copy == nullptr) {
*copy =
@ -496,6 +503,9 @@ RevertReadOnlyIndicesForEntryParamsAndConstants(
*copy_overrides.mutable_element(index) = *copy;
}
// Tracks whether this current buffer is distinct.
buffer_set.insert(buffer);
// We've already reverted the read-only index and handled the
// single-copy optimization above, so there's nothing more to do.
break;

View File

@ -44,13 +44,20 @@ class CopyInsertionTest : public HloTestBase {
EXPECT_IS_OK(copy_insertion.Run(module).status());
// Verify the points to set of the root of the computation after copy
// insertion contains no constants or parameters.
// insertion contains no constants or parameters, and is distinct and
// non-ambiguous.
auto points_to_analysis =
TuplePointsToAnalysis::Run(module).ConsumeValueOrDie();
const auto& points_to = points_to_analysis->GetPointsToSet(
module->entry_computation()->root_instruction());
EXPECT_TRUE(points_to.IsDistinct());
EXPECT_TRUE(!points_to.IsAmbiguous());
tensorflow::gtl::FlatSet<const LogicalBuffer*> maybe_live_out_buffers =
points_to_analysis
->GetPointsToSet(module->entry_computation()->root_instruction())
.CreateFlattenedSet();
for (const LogicalBuffer* buffer : maybe_live_out_buffers) {
EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kConstant);
EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kParameter);
@ -390,6 +397,47 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
return builder.Build();
}
// Builds a While body computation with two output tuple elements dependent on
// both input tuple elements.
//
// EX: Body({in0, in1, in2})
// out0 = Add(in0, 1)
// out1 = in1
// out2 = in2
// Tuple(out0, out1, out2)
std::unique_ptr<HloComputation> BuildDependentBodyComputation2() {
auto builder = HloComputation::Builder(TestName() + ".Body");
const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
{induction_variable_shape_, data_shape_, data_shape_});
auto loop_state = builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
// Update the induction variable GTE(0).
auto induction_variable =
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
induction_variable_shape_, loop_state, 0));
auto inc = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
// add0 = Add(in0, 1)
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
// data1 = GTE(1).
HloInstruction* data1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
// data2 = GTE(2).
HloInstruction* data2 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 2));
// Create output Tuple.
builder.AddInstruction(HloInstruction::CreateTuple({add0, data1, data2}));
return builder.Build();
}
// Builds a While body computation with read-only tuple element 0.
// EX:
// Body({in0, in1})
@ -408,6 +456,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
// Update data GTE(1).
auto data = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
// Use 'induction_variable' in computation with no path to output tuple.
auto update = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8}));
@ -431,6 +480,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
// Create param instruction to access loop state.
const Shape& loop_state_shape =
nested ? nested_loop_state_shape_ : loop_state_shape_;
auto loop_state = builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
// Update the induction variable GTE(0).
@ -972,7 +1022,8 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) {
op::Copy(old_init->operand(1)->operand(0)))));
}
// Tests while init instruction buffer which interfers with while result buffer.
// Tests while init instruction buffer which interferes with while result
// buffer.
//
// init_data = Broadcast(...)
// add_unrelated = Add(init_data) // takes a reference to cause interference
@ -989,5 +1040,81 @@ TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) {
op::Copy(old_init->operand(1))));
}
// Tests a while init instruction whose points-to set is non-distinct:
//
//     init = Tuple(Parameter(S32, {}), Parameter(F32, {8}),
//                  Parameter(F32, {8}))
//
// where the second and third tuple elements are the same parameter *and* the
// init tuple is shared by a second while instruction.
//
// Verifies that after copy insertion each while's init operand is a Tuple of
// Copys and that its points-to set is distinct, i.e. copy sharing did not
// place the same Copy at more than one index of the resulting tuple.
TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
  auto cond_a = module_.AddEmbeddedComputation(BuildConditionComputation());
  auto cond_b = module_.AddEmbeddedComputation(BuildConditionComputation());
  // Each loop body takes and returns a three-element tuple whose data
  // elements are forwarded straight from the init tuple.
  auto body_a = module_.AddEmbeddedComputation(BuildDependentBodyComputation2());
  auto body_b = module_.AddEmbeddedComputation(BuildDependentBodyComputation2());

  auto builder = HloComputation::Builder(TestName() + ".While");

  auto iter = builder.AddInstruction(
      HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
  auto data = builder.AddInstruction(
      HloInstruction::CreateParameter(1, data_shape_, "data"));

  // The same parameter appears at tuple indices 1 and 2 of the init tuple.
  auto shared_init = builder.AddInstruction(
      HloInstruction::CreateTuple({iter, data, data}));

  const Shape state_shape = ShapeUtil::MakeTupleShape(
      {induction_variable_shape_, data_shape_, data_shape_});

  // Both while loops consume the very same init tuple.
  auto while_a = builder.AddInstruction(
      HloInstruction::CreateWhile(state_shape, cond_a, body_a, shared_init));
  auto while_b = builder.AddInstruction(
      HloInstruction::CreateWhile(state_shape, cond_b, body_b, shared_init));

  module_.AddEntryComputation(builder.Build());

  auto analysis = TuplePointsToAnalysis::Run(&module_).ConsumeValueOrDie();

  // Precondition: before copy insertion, neither init operand is distinct.
  ASSERT_FALSE(analysis->GetPointsToSet(while_a->operand(0)).IsDistinct());
  ASSERT_FALSE(analysis->GetPointsToSet(while_b->operand(0)).IsDistinct());

  // Capture the original init operands before the pass changes them.
  auto init_before_a = while_a->operand(0);
  auto init_before_b = while_b->operand(0);

  InsertCopies(&module_);

  // Every element of each new init tuple must be a Copy of the matching
  // element of the original init tuple.
  EXPECT_THAT(while_a->operand(0),
              op::Tuple(op::Copy(init_before_a->operand(0)),
                        op::Copy(init_before_a->operand(1)),
                        op::Copy(init_before_a->operand(2))));
  EXPECT_THAT(while_b->operand(0),
              op::Tuple(op::Copy(init_before_b->operand(0)),
                        op::Copy(init_before_b->operand(1)),
                        op::Copy(init_before_b->operand(2))));

  // Re-run the analysis on the modified module; both init tuples must now
  // have distinct points-to sets.
  analysis = TuplePointsToAnalysis::Run(&module_).ConsumeValueOrDie();
  const auto& init_points_to_a = analysis->GetPointsToSet(while_a->operand(0));
  EXPECT_TRUE(init_points_to_a.IsDistinct());
  const auto& init_points_to_b = analysis->GetPointsToSet(while_b->operand(0));
  EXPECT_TRUE(init_points_to_b.IsDistinct());
}
} // namespace
} // namespace xla