[XLA] Change CopyStart return type to (target, source, context)

This is to fix a memory corruption issue where the source buffer may be reused
for another HLO. Having the source in the return type of CopyStart makes it
explicit that the source buffer may still be in use until CopyDone.

PiperOrigin-RevId: 292034096
Change-Id: Id5e0546100410eb28c50554122c166e081f885af
This commit is contained in:
Berkin Ilbeyi 2020-01-28 16:05:26 -08:00 committed by TensorFlower Gardener
parent 77ae99f06d
commit 7189185ec0
15 changed files with 115 additions and 42 deletions

View File

@ -380,6 +380,19 @@ bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
return changed;
}
bool HloDataflowAnalysis::UpdateCopyStartValueSet(HloInstruction* copy_start) {
CHECK_EQ(copy_start->opcode(), HloOpcode::kCopyStart);
bool changed = false;
// CopyStart forwards the operand value to element {1} of its output.
const HloValueSet& operand_value_set = GetValueSet(copy_start->operand(0));
HloValueSet& value_set = GetValueSet(copy_start, {1});
if (value_set != operand_value_set) {
value_set = operand_value_set;
changed = true;
}
return changed;
}
bool HloDataflowAnalysis::UpdateCopyDoneValueSet(HloInstruction* copy_done) {
CHECK_EQ(copy_done->opcode(), HloOpcode::kCopyDone);
bool changed = false;
@ -682,6 +695,8 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
return UpdateSendValueSet(instruction);
case HloOpcode::kRecvDone:
return UpdateRecvDoneValueSet(instruction);
case HloOpcode::kCopyStart:
return UpdateCopyStartValueSet(instruction);
case HloOpcode::kCopyDone:
return UpdateCopyDoneValueSet(instruction);
case HloOpcode::kConditional:
@ -863,9 +878,16 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
// values flow from their operands.
define_value_at(/*index=*/{});
break;
case HloOpcode::kCopyStart:
// CopyStart produces a tuple of {destination buffer, aliased operand,
// U32 context}.
define_value_at(/*index=*/{});
define_value_at(/*index=*/{0});
define_value_at(/*index=*/{2});
break;
case HloOpcode::kCopyDone:
// CopyDone produces an element. Its output aliases its input tuple
// element {0}; element one is a context.
// CopyDone consumes a tuple produced by CopyStart and produces an
// element. Its output aliases its input tuple element {0}.
break;
case HloOpcode::kRecvDone:
// RecvDone produces a two-element tuple. Element zero aliases its

View File

@ -189,6 +189,7 @@ class HloDataflowAnalysis {
bool UpdateDomainValueSet(HloInstruction* domain);
bool UpdateGetTupleElementValueSet(HloInstruction* gte);
bool UpdateParameterValueSet(HloInstruction* parameter);
bool UpdateCopyStartValueSet(HloInstruction* copy_start);
bool UpdateCopyDoneValueSet(HloInstruction* copy_done);
bool UpdateRecvDoneValueSet(HloInstruction* recv_done);
bool UpdateTupleSelectValueSet(HloInstruction* select);

View File

@ -1177,8 +1177,8 @@ TEST_P(HloDataflowAnalysisTest, CopyStartAndCopyDone) {
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto copy_start = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeTupleShape(
{constant->shape(), ShapeUtil::MakeShape(U32, {})}),
ShapeUtil::MakeTupleShape({constant->shape(), constant->shape(),
ShapeUtil::MakeShape(U32, {})}),
HloOpcode::kCopyStart, constant));
auto copy_done = builder.AddInstruction(HloInstruction::CreateUnary(
constant->shape(), HloOpcode::kCopyDone, copy_start));
@ -1192,7 +1192,8 @@ TEST_P(HloDataflowAnalysisTest, CopyStartAndCopyDone) {
EXPECT_TRUE(analysis.ValueIsDefinedAt(copy_start, /*index=*/{}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(copy_start, /*index=*/{0}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(copy_start, /*index=*/{1}));
EXPECT_FALSE(analysis.ValueIsDefinedAt(copy_start, /*index=*/{1}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(copy_start, /*index=*/{2}));
EXPECT_FALSE(analysis.ValueIsDefinedAt(copy_done, /*index=*/{}));
EXPECT_THAT(
HloValuesAt(copy_done, /*index=*/{}),

View File

@ -1872,14 +1872,15 @@ Status HloEvaluator::HandleCopyStart(HloInstruction* copy_start) {
"user.");
}
// The token in index {1} is undefined, but since we can't represent undefined
// values using a Literal, we just use 0. This should be safe though since we
// ensure that the only user of a kCopyStart is a kCopyDone which "eats" the
// token. Also note that MakeTuple copies its arguments, so this is
// memory-safe.
const Literal token_literal = LiteralUtil::CreateR0<uint32>(0);
// The context in index {2} is undefined, but since we can't represent
// undefined values using a Literal, we just use 0. This should be safe though
// since we ensure that the only user of a kCopyStart is a kCopyDone which
// consumes the context. Also note that MakeTuple copies its arguments, so
// this is memory-safe.
const Literal context_literal = LiteralUtil::CreateR0<uint32>(0);
evaluated_[copy_start] = LiteralUtil::MakeTuple(
{&GetEvaluatedLiteralFor(copy_start->operand(0)), &token_literal});
{&GetEvaluatedLiteralFor(copy_start->operand(0)),
&GetEvaluatedLiteralFor(copy_start->operand(0)), &context_literal});
return Status::OK();
}

View File

@ -4431,7 +4431,7 @@ TEST_F(HloEvaluatorTest, CopyStartCopyDone) {
HloModule test
ENTRY CopyStartCopyDone {
init = f32[] constant(42.0)
copy-start = (f32[]{:S(1)}, u32[]) copy-start(init)
copy-start = (f32[]{:S(1)}, f32[], u32[]) copy-start(init)
ROOT copy-done = f32[] copy-done(copy-start)
}
)";

View File

@ -278,7 +278,7 @@ TEST_F(HloMatchersTest, AsyncCopyMatcher) {
auto p0 = HloInstruction::CreateParameter(0, shape_memspace1, "p0");
auto copy_start = HloInstruction::CreateUnary(
ShapeUtil::MakeTupleShape(
{shape_memspace2, ShapeUtil::MakeShape(U32, {})}),
{shape_memspace2, shape_memspace1, ShapeUtil::MakeShape(U32, {})}),
HloOpcode::kCopyStart, p0.get());
auto copy_done = HloInstruction::CreateUnary(
shape_memspace2, HloOpcode::kCopyDone, copy_start.get());
@ -286,18 +286,18 @@ TEST_F(HloMatchersTest, AsyncCopyMatcher) {
EXPECT_THAT(copy_done.get(), op::AsyncCopy(2, 1, op::Parameter(0)));
EXPECT_THAT(Explain(copy_start.get(), op::AsyncCopy(2, 1, op::Parameter(0))),
Eq("(%copy-start = (f32[16]{0:S(2)}, u32[]) "
Eq("(%copy-start = (f32[16]{0:S(2)}, f32[16]{0:S(1)}, u32[]) "
"copy-start(f32[16]{0:S(1)} %p0))"));
EXPECT_THAT(
Explain(copy_done.get(), op::AsyncCopy(3, 1, op::Parameter(0))),
"(%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, u32[]) "
"%copy-start)) "
"copies to memory space 2, expected 3");
EXPECT_THAT(
Explain(copy_done.get(), op::AsyncCopy(2, 3, op::Parameter(0))),
"(%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, u32[]) "
"%copy-start)) "
"is in the memory space 1, expected 3");
EXPECT_THAT(Explain(copy_done.get(), op::AsyncCopy(3, 1, op::Parameter(0))),
"(%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, "
"f32[16]{0:S(1)}, u32[]) "
"%copy-start)) "
"copies to memory space 2, expected 3");
EXPECT_THAT(Explain(copy_done.get(), op::AsyncCopy(2, 3, op::Parameter(0))),
"(%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, "
"f32[16]{0:S(1)}, u32[]) "
"%copy-start)) "
"is in the memory space 1, expected 3");
}
} // namespace

View File

@ -317,11 +317,11 @@ R"(HloModule CopyStartAndCopyDone_module
ENTRY %CopyStartAndCopyDone (v1: f32[], v2: f32[2,3]) -> (f32[], f32[2,3]) {
%v1 = f32[] parameter(0)
%copy-start.1 = (f32[], u32[]) copy-start(f32[] %v1)
%copy-done.1 = f32[] copy-done((f32[], u32[]) %copy-start.1)
%copy-start.1 = (f32[], f32[], u32[]) copy-start(f32[] %v1)
%copy-done.1 = f32[] copy-done((f32[], f32[], u32[]) %copy-start.1)
%v2 = f32[2,3]{1,0:S(1)} parameter(1)
%copy-start.2 = (f32[2,3]{1,0:S(2)}, u32[]) copy-start(f32[2,3]{1,0:S(1)} %v2)
%copy-done.2 = f32[2,3]{1,0:S(2)} copy-done((f32[2,3]{1,0:S(2)}, u32[]) %copy-start.2)
%copy-start.2 = (f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(f32[2,3]{1,0:S(1)} %v2)
%copy-done.2 = f32[2,3]{1,0:S(2)} copy-done((f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) %copy-start.2)
ROOT %tuple = (f32[], f32[2,3]{1,0:S(2)}) tuple(f32[] %copy-done.1, f32[2,3]{1,0:S(2)} %copy-done.2)
}

View File

@ -817,11 +817,24 @@ Status ShapeVerifier::HandlePad(HloInstruction* pad) {
Status ShapeVerifier::HandleCopyStart(HloInstruction* copy_start) {
return CheckShape(copy_start,
ShapeUtil::MakeTupleShape({copy_start->operand(0)->shape(),
copy_start->operand(0)->shape(),
ShapeUtil::MakeShape(U32, {})}),
/*only_compare_minor_to_major_in_layout=*/true);
}
Status ShapeVerifier::HandleCopyDone(HloInstruction* copy_done) {
const Shape& operand_shape = copy_done->operand(0)->shape();
const Shape& dest_shape = ShapeUtil::GetTupleElementShape(operand_shape, 0);
const Shape& src_shape = ShapeUtil::GetTupleElementShape(operand_shape, 1);
if (!ShapesSame(dest_shape, src_shape,
/*minor_to_major_only=*/false,
/*ignore_memory_space=*/true)) {
return InternalError(
"Source and destination buffers in CopyDone arguments need to be the "
"same shape found %s and %s\n%s",
StringifyShape(dest_shape), StringifyShape(src_shape),
copy_done->ToString());
}
return CheckShape(copy_done, ShapeUtil::GetTupleElementShape(
copy_done->operand(0)->shape(), 0));
}

View File

@ -622,7 +622,7 @@ TEST_F(HloVerifierTestLayoutSensitive, CopyStartAndCopyDone) {
ENTRY CopyStartAndCopyDone {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
copy-start = (f32[2,3]{1,0:S(2)}, u32[]) copy-start(p0)
copy-start = (f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(p0)
ROOT copy-done = f32[2,3]{1,0:S(2)} copy-done(copy-start)
}
)";
@ -639,7 +639,7 @@ TEST_F(HloVerifierTestLayoutSensitive, CopyStartAndCopyDoneWrongLayout) {
ENTRY CopyStartAndCopyDone {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
copy-start = (f32[2,3]{0,1:S(2)}, u32[]) copy-start(p0)
copy-start = (f32[2,3]{0,1:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(p0)
ROOT copy-done = f32[2,3]{1,0:S(2)} copy-done(copy-start)
}
)";
@ -667,10 +667,9 @@ TEST_F(HloVerifierTest, CopyStartAndCopyDoneWrongType) {
auto status = verifier().Run(module.get()).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(
status.error_message(),
HasSubstr(
"Expected instruction to have shape equal to (f32[2,3], u32[])"));
EXPECT_THAT(status.error_message(),
HasSubstr("Expected instruction to have shape equal to "
"(f32[2,3], f32[2,3], u32[])"));
}
TEST_F(HloVerifierTest, CopyStartMultipleCopyDone) {
@ -679,7 +678,7 @@ TEST_F(HloVerifierTest, CopyStartMultipleCopyDone) {
ENTRY CopyStartAndCopyDone {
p0 = f32[2,3] parameter(0)
copy-start = (f32[2,3], u32[]) copy-start(p0)
copy-start = (f32[2,3], f32[2,3], u32[]) copy-start(p0)
copy-done.1 = f32[2,3] copy-done(copy-start)
copy-done.2 = f32[2,3] copy-done(copy-start)
ROOT tuple = (f32[2,3], f32[2,3]) tuple(copy-done.1, copy-done.2)
@ -702,7 +701,7 @@ TEST_F(HloVerifierTest, CopyDoneNoCopyStart) {
ENTRY CopyStartAndCopyDone {
p0 = f32[2,3] parameter(0)
p1 = u32[] parameter(1)
tuple = (f32[2,3], u32[]) tuple(p0, p1)
tuple = (f32[2,3], f32[2,3], u32[]) tuple(p0, p0, p1)
ROOT copy-done = f32[2,3] copy-done(tuple)
}
)";

View File

@ -159,9 +159,18 @@ Status LogicalBufferAnalysis::HandleSend(HloInstruction* send) {
return Status::OK();
}
Status LogicalBufferAnalysis::HandleCopyStart(HloInstruction* copy_start) {
// CopyStart defines the tuple, target buffer at index {0}, and context at
// index {2}.
NewLogicalBuffer(copy_start, /*index=*/{});
NewLogicalBuffer(copy_start, /*index=*/{0});
NewLogicalBuffer(copy_start, /*index=*/{2});
return Status::OK();
}
Status LogicalBufferAnalysis::HandleCopyDone(HloInstruction* copy_done) {
// The top-level buffer (index={}) for kCopy is newly created, but all other
// buffers (in the case of a tuple shape) come from the operand.
// The output of CopyDone aliases with operand {0}. CopyDone doesn't create
// any buffers.
return Status::OK();
}

View File

@ -62,6 +62,7 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault {
Status HandleBitcast(HloInstruction* bitcast) override;
Status HandleDomain(HloInstruction* domain) override;
Status HandleCopy(HloInstruction* copy) override;
Status HandleCopyStart(HloInstruction* copy_start) override;
Status HandleCopyDone(HloInstruction* copy_done) override;
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandleSend(HloInstruction* send) override;

View File

@ -1314,7 +1314,7 @@ Status MemorySpaceAssignment::CopyAllocation::Process(
}
}
copy_start_ = computation->AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}),
ShapeUtil::MakeTupleShape({shape, shape, ShapeUtil::MakeShape(U32, {})}),
HloOpcode::kCopyStart, producing_instruction));
copy_done_ = computation->AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_));

View File

@ -315,6 +315,30 @@ Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
return Status::OK();
}
Status TuplePointsToAnalysis::HandleCopyStart(HloInstruction* copy_start) {
// CopyStart forwards its aliased operand to {1}.
PointsToSet& points_to_set = CreateEmptyPointsToSet(copy_start);
const PointsToSet& operand_points_to_set =
GetPointsToSet(copy_start->operand(0));
points_to_set.ForEachMutableElement(
[&](const ShapeIndex& target_index, PointsToSet::BufferList* buffers) {
if (target_index == ShapeIndex({1})) {
*buffers = operand_points_to_set.element(/*index=*/{});
} else {
buffers->push_back(
&logical_buffer_analysis_->GetBuffer(copy_start, target_index));
}
});
for (HloInstruction* tuple :
operand_points_to_set.tuple_sources(/*index=*/{})) {
points_to_set.add_tuple_source(/*index=*/{1}, tuple);
}
return Status::OK();
}
Status TuplePointsToAnalysis::HandleCopyDone(HloInstruction* copy_done) {
// CopyDone forwards its aliased operand.
PointsToSet& points_to_set = CreateEmptyPointsToSet(copy_done);

View File

@ -250,6 +250,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
Status HandleBitcast(HloInstruction* bitcast) override;
Status HandleDomain(HloInstruction* domain) override;
Status HandleCopy(HloInstruction* copy) override;
Status HandleCopyStart(HloInstruction* copy_start) override;
Status HandleCopyDone(HloInstruction* copy_done) override;
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandleSend(HloInstruction* send) override;

View File

@ -334,8 +334,8 @@ TEST_F(TuplePointsToAnalysisTest, CopyStartAndCopyDone) {
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto copy_start = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeTupleShape(
{constant->shape(), ShapeUtil::MakeShape(U32, {})}),
ShapeUtil::MakeTupleShape({constant->shape(), constant->shape(),
ShapeUtil::MakeShape(U32, {})}),
HloOpcode::kCopyStart, constant));
auto copy_done = builder.AddInstruction(HloInstruction::CreateUnary(
constant->shape(), HloOpcode::kCopyDone, copy_start));
@ -351,6 +351,7 @@ TEST_F(TuplePointsToAnalysisTest, CopyStartAndCopyDone) {
points_to_analysis_->GetPointsToSet(copy_start).element({}),
{copy_start});
ExpectHasBufferAliases(copy_start, {0}, {{copy_start, {0}}, {copy_done, {}}});
ExpectHasBufferAliases(constant, {}, {{constant, {}}, {copy_start, {1}}});
}
TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) {