[XLA] Fix a subtle issue in copy_insertion due to the interaction between copy
overriding logic and RecordIndicesToColocatingBuffers: - When building the instruction ShapeTree to be copy-overridden, it is possible that we create a single kCopy for two identical instructions. An example can be: %tuple.19 = tuple(%constant.4, %constant.1793, %constant.1793) where it is used as a while.init operand, and constant.1793 is read-only within the loop and also used by another while loop. The copy overriding pass will then create the following (logical, not finalized) tuple: %tuple.19 = tuple(%constant.4, %copy.5, %copy.5) - In the subsequent pass RecordAmbiguousOrNonDistinctIndices, which adds copies to ensure the points-to set is distinct, the duplicate %copy.5 instructions are ignored because they are not yet finalized, and these indices (1 and 2 in the example) are still marked as to-be-copied. Therefore distinctness is lost. This fix applies to the override building stage, to explicitly avoid creating shared copies for non-distinct buffers. PiperOrigin-RevId: 157872231
This commit is contained in:
parent
f4b8d21b8e
commit
366990d92d
@ -229,25 +229,26 @@ Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices(
|
||||
// Mapping from LogicalBuffer to index (used to detect non-distinct indices).
|
||||
FlatMap<const LogicalBuffer*, std::vector<ShapeIndex>>
|
||||
buffer_to_source_indices;
|
||||
TF_RETURN_IF_ERROR(points_to.ForEachElement([this, &buffer_to_source_indices](
|
||||
const ShapeIndex& index, bool /*is_leaf*/,
|
||||
const std::vector<const LogicalBuffer*>& buffers) {
|
||||
if (buffers.size() > 1) {
|
||||
// Record ambiguous points-to set at 'index'.
|
||||
if (!indices_to_copy_.element(index)) {
|
||||
VLOG(2) << "Adding copy of buffer for instruction: "
|
||||
<< instruction_->name()
|
||||
<< " at index: " << tensorflow::str_util::Join(index, ",")
|
||||
<< " with ambiguous points-to set.";
|
||||
RecordIndex(index);
|
||||
}
|
||||
}
|
||||
// For each 'buffer': record a mapping from 'buffer' to 'index'.
|
||||
for (const LogicalBuffer* buffer : buffers) {
|
||||
buffer_to_source_indices[buffer].push_back(index);
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
TF_RETURN_IF_ERROR(points_to.ForEachElement(
|
||||
[this, &buffer_to_source_indices](
|
||||
const ShapeIndex& index, bool /*is_leaf*/,
|
||||
const std::vector<const LogicalBuffer*>& buffers) {
|
||||
if (buffers.size() > 1) {
|
||||
// Record ambiguous points-to set at 'index'.
|
||||
if (!indices_to_copy_.element(index)) {
|
||||
VLOG(2) << "Adding copy of buffer for instruction: "
|
||||
<< instruction_->name()
|
||||
<< " at index: " << tensorflow::str_util::Join(index, ",")
|
||||
<< " with ambiguous points-to set.";
|
||||
RecordIndex(index);
|
||||
}
|
||||
}
|
||||
// For each 'buffer': record a mapping from 'buffer' to 'index'.
|
||||
for (const LogicalBuffer* buffer : buffers) {
|
||||
buffer_to_source_indices[buffer].push_back(index);
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
|
||||
// Record all non-distinct indices detected in 'buffer_to_source_indices'.
|
||||
for (const auto& buff_to_src : buffer_to_source_indices) {
|
||||
@ -449,11 +450,15 @@ RevertReadOnlyIndicesForEntryParamsAndConstants(
|
||||
FlatMap<const HloInstruction*, HloInstruction*>* shared_copies) {
|
||||
const HloInstruction* init_hlo = while_hlo->operand(0);
|
||||
const PointsToSet& points_to = points_to_analysis.GetPointsToSet(init_hlo);
|
||||
|
||||
// Mapping from LogicalBuffer to index (used to detect non-distinct indices).
|
||||
FlatSet<const LogicalBuffer*> buffer_set;
|
||||
|
||||
ShapeTree<HloInstruction*> copy_overrides(init_hlo->shape());
|
||||
TF_RETURN_IF_ERROR(points_to.ForEachElement(
|
||||
[init_hlo, read_only_indices, shared_copies, ©_overrides](
|
||||
const ShapeIndex& index, bool /*is_leaf*/,
|
||||
const std::vector<const LogicalBuffer*>& buffers) {
|
||||
[init_hlo, read_only_indices, shared_copies, &buffer_set,
|
||||
©_overrides](const ShapeIndex& index, bool /*is_leaf*/,
|
||||
const std::vector<const LogicalBuffer*>& buffers) {
|
||||
// Look for read-only entry parameters.
|
||||
if (!read_only_indices->element(index)) {
|
||||
return Status::OK();
|
||||
@ -468,6 +473,7 @@ RevertReadOnlyIndicesForEntryParamsAndConstants(
|
||||
if (!is_entry_parameter && !is_constant) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// We have found an entry parameter or constant that is read-only in
|
||||
// the while body. These buffers are managed by the caller, and cannot
|
||||
// be aliased with non-parameter buffers. Revert this read-only index,
|
||||
@ -476,16 +482,17 @@ RevertReadOnlyIndicesForEntryParamsAndConstants(
|
||||
|
||||
// Optimization to allow multiple while loops that share the same
|
||||
// read-only entry parameters (or constants) to share a single copy.
|
||||
// Only unambiguous array-shaped buffers are allowed, to reduce code
|
||||
// complexity. The shape of the entry parameter must be identical to
|
||||
// the shape of the init_hlo at this index, to ensure there were no
|
||||
// intervening bitcast or GTE instructions, which are also hard to
|
||||
// handle.
|
||||
// Only unambiguous and distinct array-shaped buffers are allowed, to
|
||||
// reduce code complexity. The shape of the entry parameter must be
|
||||
// identical to the shape of the init_hlo at this index, to ensure
|
||||
// there were no intervening bitcast or GTE instructions, which are
|
||||
// also hard to handle.
|
||||
const Shape& pointee_shape = pointee->shape();
|
||||
const Shape& init_shape =
|
||||
ShapeUtil::GetSubshape(init_hlo->shape(), index);
|
||||
if (buffers.size() == 1 && ShapeUtil::IsArray(pointee_shape) &&
|
||||
ShapeUtil::Equal(pointee_shape, init_shape)) {
|
||||
ShapeUtil::Equal(pointee_shape, init_shape) &&
|
||||
buffer_set.count(buffer) < 1) {
|
||||
HloInstruction** copy = &(*shared_copies)[pointee];
|
||||
if (*copy == nullptr) {
|
||||
*copy =
|
||||
@ -496,6 +503,9 @@ RevertReadOnlyIndicesForEntryParamsAndConstants(
|
||||
*copy_overrides.mutable_element(index) = *copy;
|
||||
}
|
||||
|
||||
// Tracks whether this current buffer is distinct.
|
||||
buffer_set.insert(buffer);
|
||||
|
||||
// We've already reverted the read-only index and handled the
|
||||
// single-copy optimization above, so there's nothing more to do.
|
||||
break;
|
||||
|
@ -44,13 +44,20 @@ class CopyInsertionTest : public HloTestBase {
|
||||
EXPECT_IS_OK(copy_insertion.Run(module).status());
|
||||
|
||||
// Verify the points to set of the root of the computation after copy
|
||||
// insertion contains no constants or parameters.
|
||||
// insertion contains no constants or parameters, and is distinct and
|
||||
// non-ambiguous.
|
||||
auto points_to_analysis =
|
||||
TuplePointsToAnalysis::Run(module).ConsumeValueOrDie();
|
||||
const auto& points_to = points_to_analysis->GetPointsToSet(
|
||||
module->entry_computation()->root_instruction());
|
||||
EXPECT_TRUE(points_to.IsDistinct());
|
||||
EXPECT_TRUE(!points_to.IsAmbiguous());
|
||||
|
||||
tensorflow::gtl::FlatSet<const LogicalBuffer*> maybe_live_out_buffers =
|
||||
points_to_analysis
|
||||
->GetPointsToSet(module->entry_computation()->root_instruction())
|
||||
.CreateFlattenedSet();
|
||||
|
||||
for (const LogicalBuffer* buffer : maybe_live_out_buffers) {
|
||||
EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kConstant);
|
||||
EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kParameter);
|
||||
@ -390,6 +397,47 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
|
||||
return builder.Build();
|
||||
}
|
||||
|
||||
// Builds a While body computation with two output tuple elements dependent on
|
||||
// both input tuple elements.
|
||||
//
|
||||
// EX: Body({in0, in1, in2})
|
||||
// out0 = Add(in0, 1)
|
||||
// out1 = in1
|
||||
// out2 = in2
|
||||
// Tuple(out0, out1, out2)
|
||||
std::unique_ptr<HloComputation> BuildDependentBodyComputation2() {
|
||||
auto builder = HloComputation::Builder(TestName() + ".Body");
|
||||
|
||||
const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
|
||||
{induction_variable_shape_, data_shape_, data_shape_});
|
||||
|
||||
auto loop_state = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
|
||||
|
||||
// Update the induction variable GTE(0).
|
||||
auto induction_variable =
|
||||
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
|
||||
induction_variable_shape_, loop_state, 0));
|
||||
auto inc = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
|
||||
|
||||
// add0 = Add(in0, 1)
|
||||
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||
induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
|
||||
// data1 = GTE(1).
|
||||
HloInstruction* data1 = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
|
||||
|
||||
// data2 = GTE(2).
|
||||
HloInstruction* data2 = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 2));
|
||||
|
||||
// Create output Tuple.
|
||||
builder.AddInstruction(HloInstruction::CreateTuple({add0, data1, data2}));
|
||||
|
||||
return builder.Build();
|
||||
}
|
||||
|
||||
// Builds a While body computation with read-only tuple element 0.
|
||||
// EX:
|
||||
// Body({in0, in1})
|
||||
@ -408,6 +456,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
|
||||
// Update data GTE(1).
|
||||
auto data = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
|
||||
|
||||
// Use 'induction_variable' in computation with no path to output tuple.
|
||||
auto update = builder.AddInstruction(
|
||||
HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8}));
|
||||
@ -431,6 +480,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
|
||||
// Create param instruction to access loop state.
|
||||
const Shape& loop_state_shape =
|
||||
nested ? nested_loop_state_shape_ : loop_state_shape_;
|
||||
|
||||
auto loop_state = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
|
||||
// Update the induction variable GTE(0).
|
||||
@ -972,7 +1022,8 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) {
|
||||
op::Copy(old_init->operand(1)->operand(0)))));
|
||||
}
|
||||
|
||||
// Tests while init instruction buffer which interfers with while result buffer.
|
||||
// Tests while init instruction buffer which interferes with while result
|
||||
// buffer.
|
||||
//
|
||||
// init_data = Broadcast(...)
|
||||
// add_unrelated = Add(init_data) // takes a reference to cause interference
|
||||
@ -989,5 +1040,81 @@ TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) {
|
||||
op::Copy(old_init->operand(1))));
|
||||
}
|
||||
|
||||
// Tests while init instruction buffer which has a non-distinct points-to set:
|
||||
//
|
||||
// init = Tuple(Parameter(S32, {}), Parameter(F32, {8},
|
||||
// Parameter(F32, {8})))
|
||||
//
|
||||
// where the second and third parameters are identical *and* the tuple is shared
|
||||
// by another while instruction.
|
||||
//
|
||||
// Verifies that the resulting point-to set is distinct in the resulting Tuple
|
||||
// (non-identical Copys). In other words, verifies that copy sharing does not
|
||||
// insert identical copies to the resulting tuple.
|
||||
TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
|
||||
auto condition1 = module_.AddEmbeddedComputation(BuildConditionComputation());
|
||||
auto condition2 = module_.AddEmbeddedComputation(BuildConditionComputation());
|
||||
// Loop body that outputs tuple comprises two elements dependent on the init
|
||||
// tuple.
|
||||
auto body1 = module_.AddEmbeddedComputation(BuildDependentBodyComputation2());
|
||||
auto body2 = module_.AddEmbeddedComputation(BuildDependentBodyComputation2());
|
||||
|
||||
auto builder = HloComputation::Builder(TestName() + ".While");
|
||||
|
||||
auto iter_param = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, induction_variable_shape_, "iter"));
|
||||
auto data_param = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, data_shape_, "data"));
|
||||
|
||||
// Loop init tuple contains two identical parameter buffers.
|
||||
auto loop_init = builder.AddInstruction(
|
||||
HloInstruction::CreateTuple({iter_param, data_param, data_param}));
|
||||
|
||||
const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
|
||||
{induction_variable_shape_, data_shape_, data_shape_});
|
||||
|
||||
// Two while loops share the same loop init tuple.
|
||||
auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
|
||||
loop_state_shape, condition1, body1, loop_init));
|
||||
auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
|
||||
loop_state_shape, condition2, body2, loop_init));
|
||||
|
||||
module_.AddEntryComputation(builder.Build());
|
||||
|
||||
auto points_to_analysis =
|
||||
TuplePointsToAnalysis::Run(&module_).ConsumeValueOrDie();
|
||||
|
||||
// Asserts that the init tuples before copy insertion are non-distinct.
|
||||
ASSERT_FALSE(
|
||||
points_to_analysis->GetPointsToSet(while_hlo1->operand(0)).IsDistinct());
|
||||
ASSERT_FALSE(
|
||||
points_to_analysis->GetPointsToSet(while_hlo2->operand(0)).IsDistinct());
|
||||
|
||||
auto old_init1 = while_hlo1->operand(0);
|
||||
auto old_init2 = while_hlo2->operand(0);
|
||||
|
||||
InsertCopies(&module_);
|
||||
|
||||
EXPECT_THAT(while_hlo1->operand(0),
|
||||
op::Tuple(op::Copy(old_init1->operand(0)),
|
||||
op::Copy(old_init1->operand(1)),
|
||||
op::Copy(old_init1->operand(2))));
|
||||
|
||||
EXPECT_THAT(while_hlo2->operand(0),
|
||||
op::Tuple(op::Copy(old_init2->operand(0)),
|
||||
op::Copy(old_init2->operand(1)),
|
||||
op::Copy(old_init2->operand(2))));
|
||||
|
||||
// Verifies the init tuples after copy insertion are distinct.
|
||||
points_to_analysis = TuplePointsToAnalysis::Run(&module_).ConsumeValueOrDie();
|
||||
const auto& points_to1 =
|
||||
points_to_analysis->GetPointsToSet(while_hlo1->operand(0));
|
||||
EXPECT_TRUE(points_to1.IsDistinct());
|
||||
|
||||
const auto& points_to2 =
|
||||
points_to_analysis->GetPointsToSet(while_hlo2->operand(0));
|
||||
EXPECT_TRUE(points_to2.IsDistinct());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user