[XLA] Change CopyStart return type to (target, source, context)
This is to fix a memory corruption issue where the source buffer may be reused for another HLO. Having the source in the return type of CopyStart makes it explicit that the source buffer may still be in use until CopyDone. PiperOrigin-RevId: 292034096 Change-Id: Id5e0546100410eb28c50554122c166e081f885af
This commit is contained in:
parent
77ae99f06d
commit
7189185ec0
@ -380,6 +380,19 @@ bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool HloDataflowAnalysis::UpdateCopyStartValueSet(HloInstruction* copy_start) {
|
||||
CHECK_EQ(copy_start->opcode(), HloOpcode::kCopyStart);
|
||||
bool changed = false;
|
||||
// CopyStart forwards the operand value to element {1} of its output.
|
||||
const HloValueSet& operand_value_set = GetValueSet(copy_start->operand(0));
|
||||
HloValueSet& value_set = GetValueSet(copy_start, {1});
|
||||
if (value_set != operand_value_set) {
|
||||
value_set = operand_value_set;
|
||||
changed = true;
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool HloDataflowAnalysis::UpdateCopyDoneValueSet(HloInstruction* copy_done) {
|
||||
CHECK_EQ(copy_done->opcode(), HloOpcode::kCopyDone);
|
||||
bool changed = false;
|
||||
@ -682,6 +695,8 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
|
||||
return UpdateSendValueSet(instruction);
|
||||
case HloOpcode::kRecvDone:
|
||||
return UpdateRecvDoneValueSet(instruction);
|
||||
case HloOpcode::kCopyStart:
|
||||
return UpdateCopyStartValueSet(instruction);
|
||||
case HloOpcode::kCopyDone:
|
||||
return UpdateCopyDoneValueSet(instruction);
|
||||
case HloOpcode::kConditional:
|
||||
@ -863,9 +878,16 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
|
||||
// values flow from their operands.
|
||||
define_value_at(/*index=*/{});
|
||||
break;
|
||||
case HloOpcode::kCopyStart:
|
||||
// CopyStart produces a tuple of {destination buffer, aliased operand,
|
||||
// U32 context}.
|
||||
define_value_at(/*index=*/{});
|
||||
define_value_at(/*index=*/{0});
|
||||
define_value_at(/*index=*/{2});
|
||||
break;
|
||||
case HloOpcode::kCopyDone:
|
||||
// CopyDone produces an element. Its output aliases its input tuple
|
||||
// element {0}; element one is a context.
|
||||
// CopyDone consumes a tuple produced by CopyStart and produces an
|
||||
// element. Its output aliases its input tuple element {0}.
|
||||
break;
|
||||
case HloOpcode::kRecvDone:
|
||||
// RecvDone produces a two-element tuple. Element zero aliases its
|
||||
|
||||
@ -189,6 +189,7 @@ class HloDataflowAnalysis {
|
||||
bool UpdateDomainValueSet(HloInstruction* domain);
|
||||
bool UpdateGetTupleElementValueSet(HloInstruction* gte);
|
||||
bool UpdateParameterValueSet(HloInstruction* parameter);
|
||||
bool UpdateCopyStartValueSet(HloInstruction* copy_start);
|
||||
bool UpdateCopyDoneValueSet(HloInstruction* copy_done);
|
||||
bool UpdateRecvDoneValueSet(HloInstruction* recv_done);
|
||||
bool UpdateTupleSelectValueSet(HloInstruction* select);
|
||||
|
||||
@ -1177,8 +1177,8 @@ TEST_P(HloDataflowAnalysisTest, CopyStartAndCopyDone) {
|
||||
auto constant = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
|
||||
auto copy_start = builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
ShapeUtil::MakeTupleShape(
|
||||
{constant->shape(), ShapeUtil::MakeShape(U32, {})}),
|
||||
ShapeUtil::MakeTupleShape({constant->shape(), constant->shape(),
|
||||
ShapeUtil::MakeShape(U32, {})}),
|
||||
HloOpcode::kCopyStart, constant));
|
||||
auto copy_done = builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
constant->shape(), HloOpcode::kCopyDone, copy_start));
|
||||
@ -1192,7 +1192,8 @@ TEST_P(HloDataflowAnalysisTest, CopyStartAndCopyDone) {
|
||||
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(copy_start, /*index=*/{}));
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(copy_start, /*index=*/{0}));
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(copy_start, /*index=*/{1}));
|
||||
EXPECT_FALSE(analysis.ValueIsDefinedAt(copy_start, /*index=*/{1}));
|
||||
EXPECT_TRUE(analysis.ValueIsDefinedAt(copy_start, /*index=*/{2}));
|
||||
EXPECT_FALSE(analysis.ValueIsDefinedAt(copy_done, /*index=*/{}));
|
||||
EXPECT_THAT(
|
||||
HloValuesAt(copy_done, /*index=*/{}),
|
||||
|
||||
@ -1872,14 +1872,15 @@ Status HloEvaluator::HandleCopyStart(HloInstruction* copy_start) {
|
||||
"user.");
|
||||
}
|
||||
|
||||
// The token in index {1} is undefined, but since we can't represent undefined
|
||||
// values using a Literal, we just use 0. This should be safe though since we
|
||||
// ensure that the only user of a kCopyStart is a kCopyDone which "eats" the
|
||||
// token. Also note that MakeTuple copies its arguments, so this is
|
||||
// memory-safe.
|
||||
const Literal token_literal = LiteralUtil::CreateR0<uint32>(0);
|
||||
// The context in index {2} is undefined, but since we can't represent
|
||||
// undefined values using a Literal, we just use 0. This should be safe though
|
||||
// since we ensure that the only user of a kCopyStart is a kCopyDone which
|
||||
// consumes the context. Also note that MakeTuple copies its arguments, so
|
||||
// this is memory-safe.
|
||||
const Literal context_literal = LiteralUtil::CreateR0<uint32>(0);
|
||||
evaluated_[copy_start] = LiteralUtil::MakeTuple(
|
||||
{&GetEvaluatedLiteralFor(copy_start->operand(0)), &token_literal});
|
||||
{&GetEvaluatedLiteralFor(copy_start->operand(0)),
|
||||
&GetEvaluatedLiteralFor(copy_start->operand(0)), &context_literal});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
@ -4431,7 +4431,7 @@ TEST_F(HloEvaluatorTest, CopyStartCopyDone) {
|
||||
HloModule test
|
||||
ENTRY CopyStartCopyDone {
|
||||
init = f32[] constant(42.0)
|
||||
copy-start = (f32[]{:S(1)}, u32[]) copy-start(init)
|
||||
copy-start = (f32[]{:S(1)}, f32[], u32[]) copy-start(init)
|
||||
ROOT copy-done = f32[] copy-done(copy-start)
|
||||
}
|
||||
)";
|
||||
|
||||
@ -278,7 +278,7 @@ TEST_F(HloMatchersTest, AsyncCopyMatcher) {
|
||||
auto p0 = HloInstruction::CreateParameter(0, shape_memspace1, "p0");
|
||||
auto copy_start = HloInstruction::CreateUnary(
|
||||
ShapeUtil::MakeTupleShape(
|
||||
{shape_memspace2, ShapeUtil::MakeShape(U32, {})}),
|
||||
{shape_memspace2, shape_memspace1, ShapeUtil::MakeShape(U32, {})}),
|
||||
HloOpcode::kCopyStart, p0.get());
|
||||
auto copy_done = HloInstruction::CreateUnary(
|
||||
shape_memspace2, HloOpcode::kCopyDone, copy_start.get());
|
||||
@ -286,18 +286,18 @@ TEST_F(HloMatchersTest, AsyncCopyMatcher) {
|
||||
EXPECT_THAT(copy_done.get(), op::AsyncCopy(2, 1, op::Parameter(0)));
|
||||
|
||||
EXPECT_THAT(Explain(copy_start.get(), op::AsyncCopy(2, 1, op::Parameter(0))),
|
||||
Eq("(%copy-start = (f32[16]{0:S(2)}, u32[]) "
|
||||
Eq("(%copy-start = (f32[16]{0:S(2)}, f32[16]{0:S(1)}, u32[]) "
|
||||
"copy-start(f32[16]{0:S(1)} %p0))"));
|
||||
EXPECT_THAT(
|
||||
Explain(copy_done.get(), op::AsyncCopy(3, 1, op::Parameter(0))),
|
||||
"(%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, u32[]) "
|
||||
"%copy-start)) "
|
||||
"copies to memory space 2, expected 3");
|
||||
EXPECT_THAT(
|
||||
Explain(copy_done.get(), op::AsyncCopy(2, 3, op::Parameter(0))),
|
||||
"(%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, u32[]) "
|
||||
"%copy-start)) "
|
||||
"is in the memory space 1, expected 3");
|
||||
EXPECT_THAT(Explain(copy_done.get(), op::AsyncCopy(3, 1, op::Parameter(0))),
|
||||
"(%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, "
|
||||
"f32[16]{0:S(1)}, u32[]) "
|
||||
"%copy-start)) "
|
||||
"copies to memory space 2, expected 3");
|
||||
EXPECT_THAT(Explain(copy_done.get(), op::AsyncCopy(2, 3, op::Parameter(0))),
|
||||
"(%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, "
|
||||
"f32[16]{0:S(1)}, u32[]) "
|
||||
"%copy-start)) "
|
||||
"is in the memory space 1, expected 3");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
@ -317,11 +317,11 @@ R"(HloModule CopyStartAndCopyDone_module
|
||||
|
||||
ENTRY %CopyStartAndCopyDone (v1: f32[], v2: f32[2,3]) -> (f32[], f32[2,3]) {
|
||||
%v1 = f32[] parameter(0)
|
||||
%copy-start.1 = (f32[], u32[]) copy-start(f32[] %v1)
|
||||
%copy-done.1 = f32[] copy-done((f32[], u32[]) %copy-start.1)
|
||||
%copy-start.1 = (f32[], f32[], u32[]) copy-start(f32[] %v1)
|
||||
%copy-done.1 = f32[] copy-done((f32[], f32[], u32[]) %copy-start.1)
|
||||
%v2 = f32[2,3]{1,0:S(1)} parameter(1)
|
||||
%copy-start.2 = (f32[2,3]{1,0:S(2)}, u32[]) copy-start(f32[2,3]{1,0:S(1)} %v2)
|
||||
%copy-done.2 = f32[2,3]{1,0:S(2)} copy-done((f32[2,3]{1,0:S(2)}, u32[]) %copy-start.2)
|
||||
%copy-start.2 = (f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(f32[2,3]{1,0:S(1)} %v2)
|
||||
%copy-done.2 = f32[2,3]{1,0:S(2)} copy-done((f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) %copy-start.2)
|
||||
ROOT %tuple = (f32[], f32[2,3]{1,0:S(2)}) tuple(f32[] %copy-done.1, f32[2,3]{1,0:S(2)} %copy-done.2)
|
||||
}
|
||||
|
||||
|
||||
@ -817,11 +817,24 @@ Status ShapeVerifier::HandlePad(HloInstruction* pad) {
|
||||
Status ShapeVerifier::HandleCopyStart(HloInstruction* copy_start) {
|
||||
return CheckShape(copy_start,
|
||||
ShapeUtil::MakeTupleShape({copy_start->operand(0)->shape(),
|
||||
copy_start->operand(0)->shape(),
|
||||
ShapeUtil::MakeShape(U32, {})}),
|
||||
/*only_compare_minor_to_major_in_layout=*/true);
|
||||
}
|
||||
|
||||
Status ShapeVerifier::HandleCopyDone(HloInstruction* copy_done) {
|
||||
const Shape& operand_shape = copy_done->operand(0)->shape();
|
||||
const Shape& dest_shape = ShapeUtil::GetTupleElementShape(operand_shape, 0);
|
||||
const Shape& src_shape = ShapeUtil::GetTupleElementShape(operand_shape, 1);
|
||||
if (!ShapesSame(dest_shape, src_shape,
|
||||
/*minor_to_major_only=*/false,
|
||||
/*ignore_memory_space=*/true)) {
|
||||
return InternalError(
|
||||
"Source and destination buffers in CopyDone arguments need to be the "
|
||||
"same shape found %s and %s\n%s",
|
||||
StringifyShape(dest_shape), StringifyShape(src_shape),
|
||||
copy_done->ToString());
|
||||
}
|
||||
return CheckShape(copy_done, ShapeUtil::GetTupleElementShape(
|
||||
copy_done->operand(0)->shape(), 0));
|
||||
}
|
||||
|
||||
@ -622,7 +622,7 @@ TEST_F(HloVerifierTestLayoutSensitive, CopyStartAndCopyDone) {
|
||||
|
||||
ENTRY CopyStartAndCopyDone {
|
||||
p0 = f32[2,3]{1,0:S(1)} parameter(0)
|
||||
copy-start = (f32[2,3]{1,0:S(2)}, u32[]) copy-start(p0)
|
||||
copy-start = (f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(p0)
|
||||
ROOT copy-done = f32[2,3]{1,0:S(2)} copy-done(copy-start)
|
||||
}
|
||||
)";
|
||||
@ -639,7 +639,7 @@ TEST_F(HloVerifierTestLayoutSensitive, CopyStartAndCopyDoneWrongLayout) {
|
||||
|
||||
ENTRY CopyStartAndCopyDone {
|
||||
p0 = f32[2,3]{1,0:S(1)} parameter(0)
|
||||
copy-start = (f32[2,3]{0,1:S(2)}, u32[]) copy-start(p0)
|
||||
copy-start = (f32[2,3]{0,1:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(p0)
|
||||
ROOT copy-done = f32[2,3]{1,0:S(2)} copy-done(copy-start)
|
||||
}
|
||||
)";
|
||||
@ -667,10 +667,9 @@ TEST_F(HloVerifierTest, CopyStartAndCopyDoneWrongType) {
|
||||
|
||||
auto status = verifier().Run(module.get()).status();
|
||||
ASSERT_FALSE(status.ok());
|
||||
EXPECT_THAT(
|
||||
status.error_message(),
|
||||
HasSubstr(
|
||||
"Expected instruction to have shape equal to (f32[2,3], u32[])"));
|
||||
EXPECT_THAT(status.error_message(),
|
||||
HasSubstr("Expected instruction to have shape equal to "
|
||||
"(f32[2,3], f32[2,3], u32[])"));
|
||||
}
|
||||
|
||||
TEST_F(HloVerifierTest, CopyStartMultipleCopyDone) {
|
||||
@ -679,7 +678,7 @@ TEST_F(HloVerifierTest, CopyStartMultipleCopyDone) {
|
||||
|
||||
ENTRY CopyStartAndCopyDone {
|
||||
p0 = f32[2,3] parameter(0)
|
||||
copy-start = (f32[2,3], u32[]) copy-start(p0)
|
||||
copy-start = (f32[2,3], f32[2,3], u32[]) copy-start(p0)
|
||||
copy-done.1 = f32[2,3] copy-done(copy-start)
|
||||
copy-done.2 = f32[2,3] copy-done(copy-start)
|
||||
ROOT tuple = (f32[2,3], f32[2,3]) tuple(copy-done.1, copy-done.2)
|
||||
@ -702,7 +701,7 @@ TEST_F(HloVerifierTest, CopyDoneNoCopyStart) {
|
||||
ENTRY CopyStartAndCopyDone {
|
||||
p0 = f32[2,3] parameter(0)
|
||||
p1 = u32[] parameter(1)
|
||||
tuple = (f32[2,3], u32[]) tuple(p0, p1)
|
||||
tuple = (f32[2,3], f32[2,3], u32[]) tuple(p0, p0, p1)
|
||||
ROOT copy-done = f32[2,3] copy-done(tuple)
|
||||
}
|
||||
)";
|
||||
|
||||
@ -159,9 +159,18 @@ Status LogicalBufferAnalysis::HandleSend(HloInstruction* send) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LogicalBufferAnalysis::HandleCopyStart(HloInstruction* copy_start) {
|
||||
// CopyStart defines the tuple, target buffer at index {0}, and context at
|
||||
// index {2}.
|
||||
NewLogicalBuffer(copy_start, /*index=*/{});
|
||||
NewLogicalBuffer(copy_start, /*index=*/{0});
|
||||
NewLogicalBuffer(copy_start, /*index=*/{2});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LogicalBufferAnalysis::HandleCopyDone(HloInstruction* copy_done) {
|
||||
// The top-level buffer (index={}) for kCopy is newly created, but all other
|
||||
// buffers (in the case of a tuple shape) come from the operand.
|
||||
// The output of CopyDone aliases with operand {0}. CopyDone doesn't create
|
||||
// any buffers.
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
@ -62,6 +62,7 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault {
|
||||
Status HandleBitcast(HloInstruction* bitcast) override;
|
||||
Status HandleDomain(HloInstruction* domain) override;
|
||||
Status HandleCopy(HloInstruction* copy) override;
|
||||
Status HandleCopyStart(HloInstruction* copy_start) override;
|
||||
Status HandleCopyDone(HloInstruction* copy_done) override;
|
||||
Status HandleRecvDone(HloInstruction* recv_done) override;
|
||||
Status HandleSend(HloInstruction* send) override;
|
||||
|
||||
@ -1314,7 +1314,7 @@ Status MemorySpaceAssignment::CopyAllocation::Process(
|
||||
}
|
||||
}
|
||||
copy_start_ = computation->AddInstruction(HloInstruction::CreateUnary(
|
||||
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}),
|
||||
ShapeUtil::MakeTupleShape({shape, shape, ShapeUtil::MakeShape(U32, {})}),
|
||||
HloOpcode::kCopyStart, producing_instruction));
|
||||
copy_done_ = computation->AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_));
|
||||
|
||||
@ -315,6 +315,30 @@ Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TuplePointsToAnalysis::HandleCopyStart(HloInstruction* copy_start) {
|
||||
// CopyStart forwards its aliased operand to {1}.
|
||||
PointsToSet& points_to_set = CreateEmptyPointsToSet(copy_start);
|
||||
const PointsToSet& operand_points_to_set =
|
||||
GetPointsToSet(copy_start->operand(0));
|
||||
|
||||
points_to_set.ForEachMutableElement(
|
||||
[&](const ShapeIndex& target_index, PointsToSet::BufferList* buffers) {
|
||||
if (target_index == ShapeIndex({1})) {
|
||||
*buffers = operand_points_to_set.element(/*index=*/{});
|
||||
} else {
|
||||
buffers->push_back(
|
||||
&logical_buffer_analysis_->GetBuffer(copy_start, target_index));
|
||||
}
|
||||
});
|
||||
|
||||
for (HloInstruction* tuple :
|
||||
operand_points_to_set.tuple_sources(/*index=*/{})) {
|
||||
points_to_set.add_tuple_source(/*index=*/{1}, tuple);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TuplePointsToAnalysis::HandleCopyDone(HloInstruction* copy_done) {
|
||||
// CopyDone forwards its aliased operand.
|
||||
PointsToSet& points_to_set = CreateEmptyPointsToSet(copy_done);
|
||||
|
||||
@ -250,6 +250,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
|
||||
Status HandleBitcast(HloInstruction* bitcast) override;
|
||||
Status HandleDomain(HloInstruction* domain) override;
|
||||
Status HandleCopy(HloInstruction* copy) override;
|
||||
Status HandleCopyStart(HloInstruction* copy_start) override;
|
||||
Status HandleCopyDone(HloInstruction* copy_done) override;
|
||||
Status HandleRecvDone(HloInstruction* recv_done) override;
|
||||
Status HandleSend(HloInstruction* send) override;
|
||||
|
||||
@ -334,8 +334,8 @@ TEST_F(TuplePointsToAnalysisTest, CopyStartAndCopyDone) {
|
||||
auto constant = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
|
||||
auto copy_start = builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
ShapeUtil::MakeTupleShape(
|
||||
{constant->shape(), ShapeUtil::MakeShape(U32, {})}),
|
||||
ShapeUtil::MakeTupleShape({constant->shape(), constant->shape(),
|
||||
ShapeUtil::MakeShape(U32, {})}),
|
||||
HloOpcode::kCopyStart, constant));
|
||||
auto copy_done = builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
constant->shape(), HloOpcode::kCopyDone, copy_start));
|
||||
@ -351,6 +351,7 @@ TEST_F(TuplePointsToAnalysisTest, CopyStartAndCopyDone) {
|
||||
points_to_analysis_->GetPointsToSet(copy_start).element({}),
|
||||
{copy_start});
|
||||
ExpectHasBufferAliases(copy_start, {0}, {{copy_start, {0}}, {copy_done, {}}});
|
||||
ExpectHasBufferAliases(constant, {}, {{constant, {}}, {copy_start, {1}}});
|
||||
}
|
||||
|
||||
TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user