From 7189185ec06226923a98c87c0cc9db790f75aa8a Mon Sep 17 00:00:00 2001 From: Berkin Ilbeyi Date: Tue, 28 Jan 2020 16:05:26 -0800 Subject: [PATCH] [XLA] Change CopyStart return type to (target, source, context) This is to fix a memory corruption issue where the source buffer may be reused for another HLO. Having the source in the return type of CopyStart makes it explicit that the source buffer may still be in use until CopyDone. PiperOrigin-RevId: 292034096 Change-Id: Id5e0546100410eb28c50554122c166e081f885af --- .../xla/service/hlo_dataflow_analysis.cc | 26 +++++++++++++++++-- .../xla/service/hlo_dataflow_analysis.h | 1 + .../xla/service/hlo_dataflow_analysis_test.cc | 7 ++--- .../compiler/xla/service/hlo_evaluator.cc | 15 ++++++----- .../xla/service/hlo_evaluator_test.cc | 2 +- .../compiler/xla/service/hlo_matchers_test.cc | 24 ++++++++--------- .../compiler/xla/service/hlo_parser_test.cc | 8 +++--- .../compiler/xla/service/hlo_verifier.cc | 13 ++++++++++ .../compiler/xla/service/hlo_verifier_test.cc | 15 +++++------ .../xla/service/logical_buffer_analysis.cc | 13 ++++++++-- .../xla/service/logical_buffer_analysis.h | 1 + .../xla/service/memory_space_assignment.cc | 2 +- .../xla/service/tuple_points_to_analysis.cc | 24 +++++++++++++++++ .../xla/service/tuple_points_to_analysis.h | 1 + .../service/tuple_points_to_analysis_test.cc | 5 ++-- 15 files changed, 115 insertions(+), 42 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 11d3c5fdbd0..36da176b62f 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -380,6 +380,19 @@ bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) { return changed; } +bool HloDataflowAnalysis::UpdateCopyStartValueSet(HloInstruction* copy_start) { + CHECK_EQ(copy_start->opcode(), HloOpcode::kCopyStart); + bool changed = false; + // CopyStart forwards the operand value to element {1} of its output. + const HloValueSet& operand_value_set = GetValueSet(copy_start->operand(0)); + HloValueSet& value_set = GetValueSet(copy_start, {1}); + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + return changed; +} + bool HloDataflowAnalysis::UpdateCopyDoneValueSet(HloInstruction* copy_done) { CHECK_EQ(copy_done->opcode(), HloOpcode::kCopyDone); bool changed = false; @@ -682,6 +695,8 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( return UpdateSendValueSet(instruction); case HloOpcode::kRecvDone: return UpdateRecvDoneValueSet(instruction); + case HloOpcode::kCopyStart: + return UpdateCopyStartValueSet(instruction); case HloOpcode::kCopyDone: return UpdateCopyDoneValueSet(instruction); case HloOpcode::kConditional: @@ -863,9 +878,16 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { // values flow from their operands. define_value_at(/*index=*/{}); break; + case HloOpcode::kCopyStart: + // CopyStart produces a tuple of {destination buffer, aliased operand, + // U32 context}. + define_value_at(/*index=*/{}); + define_value_at(/*index=*/{0}); + define_value_at(/*index=*/{2}); + break; case HloOpcode::kCopyDone: - // CopyDone produces an element. Its output aliases its input tuple - // element {0}; element one is a context. + // CopyDone consumes a tuple produced by CopyStart and produces an + // element. Its output aliases its input tuple element {0}. break; case HloOpcode::kRecvDone: // RecvDone produces a two-element tuple. Element zero aliases its diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 670d1e4c086..294ffea6792 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -189,6 +189,7 @@ class HloDataflowAnalysis { bool UpdateDomainValueSet(HloInstruction* domain); bool UpdateGetTupleElementValueSet(HloInstruction* gte); bool UpdateParameterValueSet(HloInstruction* parameter); + bool UpdateCopyStartValueSet(HloInstruction* copy_start); bool UpdateCopyDoneValueSet(HloInstruction* copy_done); bool UpdateRecvDoneValueSet(HloInstruction* recv_done); bool UpdateTupleSelectValueSet(HloInstruction* select); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 330779b5ebd..074d14fd810 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1177,8 +1177,8 @@ TEST_P(HloDataflowAnalysisTest, CopyStartAndCopyDone) { auto constant = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto copy_start = builder.AddInstruction(HloInstruction::CreateUnary( - ShapeUtil::MakeTupleShape( - {constant->shape(), ShapeUtil::MakeShape(U32, {})}), + ShapeUtil::MakeTupleShape({constant->shape(), constant->shape(), + ShapeUtil::MakeShape(U32, {})}), HloOpcode::kCopyStart, constant)); auto copy_done = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kCopyDone, copy_start)); @@ -1192,7 +1192,8 @@ TEST_P(HloDataflowAnalysisTest, CopyStartAndCopyDone) { EXPECT_TRUE(analysis.ValueIsDefinedAt(copy_start, /*index=*/{})); EXPECT_TRUE(analysis.ValueIsDefinedAt(copy_start, /*index=*/{0})); - EXPECT_TRUE(analysis.ValueIsDefinedAt(copy_start, /*index=*/{1})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(copy_start, /*index=*/{1})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(copy_start, /*index=*/{2})); EXPECT_FALSE(analysis.ValueIsDefinedAt(copy_done, /*index=*/{})); EXPECT_THAT( HloValuesAt(copy_done, /*index=*/{}), diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 7159e5bfdf6..106ebb7be0e 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -1872,14 +1872,15 @@ Status HloEvaluator::HandleCopyStart(HloInstruction* copy_start) { "user."); } - // The token in index {1} is undefined, but since we can't represent undefined - // values using a Literal, we just use 0. This should be safe though since we - // ensure that the only user of a kCopyStart is a kCopyDone which "eats" the - // token. Also note that MakeTuple copies its arguments, so this is - // memory-safe. - const Literal token_literal = LiteralUtil::CreateR0(0); + // The context in index {2} is undefined, but since we can't represent + // undefined values using a Literal, we just use 0. This should be safe though + // since we ensure that the only user of a kCopyStart is a kCopyDone which + // consumes the context. Also note that MakeTuple copies its arguments, so + // this is memory-safe. + const Literal context_literal = LiteralUtil::CreateR0(0); evaluated_[copy_start] = LiteralUtil::MakeTuple( - {&GetEvaluatedLiteralFor(copy_start->operand(0)), &token_literal}); + {&GetEvaluatedLiteralFor(copy_start->operand(0)), + &GetEvaluatedLiteralFor(copy_start->operand(0)), &context_literal}); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 89ea74e766c..17f43f8449d 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -4431,7 +4431,7 @@ TEST_F(HloEvaluatorTest, CopyStartCopyDone) { HloModule test ENTRY CopyStartCopyDone { init = f32[] constant(42.0) - copy-start = (f32[]{:S(1)}, u32[]) copy-start(init) + copy-start = (f32[]{:S(1)}, f32[], u32[]) copy-start(init) ROOT copy-done = f32[] copy-done(copy-start) } )"; diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index 9c63638d492..cb5cbd05d65 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -278,7 +278,7 @@ TEST_F(HloMatchersTest, AsyncCopyMatcher) { auto p0 = HloInstruction::CreateParameter(0, shape_memspace1, "p0"); auto copy_start = HloInstruction::CreateUnary( ShapeUtil::MakeTupleShape( - {shape_memspace2, ShapeUtil::MakeShape(U32, {})}), + {shape_memspace2, shape_memspace1, ShapeUtil::MakeShape(U32, {})}), HloOpcode::kCopyStart, p0.get()); auto copy_done = HloInstruction::CreateUnary( shape_memspace2, HloOpcode::kCopyDone, copy_start.get()); @@ -286,18 +286,18 @@ TEST_F(HloMatchersTest, AsyncCopyMatcher) { EXPECT_THAT(copy_done.get(), op::AsyncCopy(2, 1, op::Parameter(0))); EXPECT_THAT(Explain(copy_start.get(), op::AsyncCopy(2, 1, op::Parameter(0))), - Eq("(%copy-start = (f32[16]{0:S(2)}, u32[]) " + Eq("(%copy-start = (f32[16]{0:S(2)}, f32[16]{0:S(1)}, u32[]) " "copy-start(f32[16]{0:S(1)} %p0))")); - EXPECT_THAT( - Explain(copy_done.get(), op::AsyncCopy(3, 1, op::Parameter(0))), - "(%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, u32[]) " - "%copy-start)) " - "copies to memory space 2, expected 3"); - EXPECT_THAT( - Explain(copy_done.get(), op::AsyncCopy(2, 3, op::Parameter(0))), - "(%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, u32[]) " - "%copy-start)) " - "is in the memory space 1, expected 3"); + EXPECT_THAT(Explain(copy_done.get(), op::AsyncCopy(3, 1, op::Parameter(0))), + "(%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, " + "f32[16]{0:S(1)}, u32[]) " + "%copy-start)) " + "copies to memory space 2, expected 3"); + EXPECT_THAT(Explain(copy_done.get(), op::AsyncCopy(2, 3, op::Parameter(0))), + "(%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, " + "f32[16]{0:S(1)}, u32[]) " + "%copy-start)) " + "is in the memory space 1, expected 3"); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index e3431a4731f..7f626718389 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -317,11 +317,11 @@ R"(HloModule CopyStartAndCopyDone_module ENTRY %CopyStartAndCopyDone (v1: f32[], v2: f32[2,3]) -> (f32[], f32[2,3]) { %v1 = f32[] parameter(0) - %copy-start.1 = (f32[], u32[]) copy-start(f32[] %v1) - %copy-done.1 = f32[] copy-done((f32[], u32[]) %copy-start.1) + %copy-start.1 = (f32[], f32[], u32[]) copy-start(f32[] %v1) + %copy-done.1 = f32[] copy-done((f32[], f32[], u32[]) %copy-start.1) %v2 = f32[2,3]{1,0:S(1)} parameter(1) - %copy-start.2 = (f32[2,3]{1,0:S(2)}, u32[]) copy-start(f32[2,3]{1,0:S(1)} %v2) - %copy-done.2 = f32[2,3]{1,0:S(2)} copy-done((f32[2,3]{1,0:S(2)}, u32[]) %copy-start.2) + %copy-start.2 = (f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(f32[2,3]{1,0:S(1)} %v2) + %copy-done.2 = f32[2,3]{1,0:S(2)} copy-done((f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) %copy-start.2) ROOT %tuple = (f32[], f32[2,3]{1,0:S(2)}) tuple(f32[] %copy-done.1, f32[2,3]{1,0:S(2)} %copy-done.2) } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index b4d1996373a..195fcacf342 100755 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -817,11 +817,24 @@ Status ShapeVerifier::HandlePad(HloInstruction* pad) { Status ShapeVerifier::HandleCopyStart(HloInstruction* copy_start) { return CheckShape(copy_start, ShapeUtil::MakeTupleShape({copy_start->operand(0)->shape(), + copy_start->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})}), /*only_compare_minor_to_major_in_layout=*/true); } Status ShapeVerifier::HandleCopyDone(HloInstruction* copy_done) { + const Shape& operand_shape = copy_done->operand(0)->shape(); + const Shape& dest_shape = ShapeUtil::GetTupleElementShape(operand_shape, 0); + const Shape& src_shape = ShapeUtil::GetTupleElementShape(operand_shape, 1); + if (!ShapesSame(dest_shape, src_shape, + /*minor_to_major_only=*/false, + /*ignore_memory_space=*/true)) { + return InternalError( + "Source and destination buffers in CopyDone arguments need to be the " + "same shape found %s and %s\n%s", + StringifyShape(dest_shape), StringifyShape(src_shape), + copy_done->ToString()); + } return CheckShape(copy_done, ShapeUtil::GetTupleElementShape( copy_done->operand(0)->shape(), 0)); } diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index c174af6dec0..c7290adab23 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -622,7 +622,7 @@ TEST_F(HloVerifierTestLayoutSensitive, CopyStartAndCopyDone) { ENTRY CopyStartAndCopyDone { p0 = f32[2,3]{1,0:S(1)} parameter(0) - copy-start = (f32[2,3]{1,0:S(2)}, u32[]) copy-start(p0) + copy-start = (f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(p0) ROOT copy-done = f32[2,3]{1,0:S(2)} copy-done(copy-start) } )"; @@ -639,7 +639,7 @@ TEST_F(HloVerifierTestLayoutSensitive, CopyStartAndCopyDoneWrongLayout) { ENTRY CopyStartAndCopyDone { p0 = f32[2,3]{1,0:S(1)} parameter(0) - copy-start = (f32[2,3]{0,1:S(2)}, u32[]) copy-start(p0) + copy-start = (f32[2,3]{0,1:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(p0) ROOT copy-done = f32[2,3]{1,0:S(2)} copy-done(copy-start) } )"; @@ -667,10 +667,9 @@ TEST_F(HloVerifierTest, CopyStartAndCopyDoneWrongType) { auto status = verifier().Run(module.get()).status(); ASSERT_FALSE(status.ok()); - EXPECT_THAT( - status.error_message(), - HasSubstr( - "Expected instruction to have shape equal to (f32[2,3], u32[])")); + EXPECT_THAT(status.error_message(), + HasSubstr("Expected instruction to have shape equal to " + "(f32[2,3], f32[2,3], u32[])")); } TEST_F(HloVerifierTest, CopyStartMultipleCopyDone) { @@ -679,7 +678,7 @@ TEST_F(HloVerifierTest, CopyStartMultipleCopyDone) { ENTRY CopyStartAndCopyDone { p0 = f32[2,3] parameter(0) - copy-start = (f32[2,3], u32[]) copy-start(p0) + copy-start = (f32[2,3], f32[2,3], u32[]) copy-start(p0) copy-done.1 = f32[2,3] copy-done(copy-start) copy-done.2 = f32[2,3] copy-done(copy-start) ROOT tuple = (f32[2,3], f32[2,3]) tuple(copy-done.1, copy-done.2) @@ -702,7 +701,7 @@ TEST_F(HloVerifierTest, CopyDoneNoCopyStart) { ENTRY CopyStartAndCopyDone { p0 = f32[2,3] parameter(0) p1 = u32[] parameter(1) - tuple = (f32[2,3], u32[]) tuple(p0, p1) + tuple = (f32[2,3], f32[2,3], u32[]) tuple(p0, p0, p1) ROOT copy-done = f32[2,3] copy-done(tuple) } )"; diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index 4ba660467ac..0a05ff5ca51 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -159,9 +159,18 @@ Status LogicalBufferAnalysis::HandleSend(HloInstruction* send) { return Status::OK(); } +Status LogicalBufferAnalysis::HandleCopyStart(HloInstruction* copy_start) { + // CopyStart defines the tuple, target buffer at index {0}, and context at + // index {2}. + NewLogicalBuffer(copy_start, /*index=*/{}); + NewLogicalBuffer(copy_start, /*index=*/{0}); + NewLogicalBuffer(copy_start, /*index=*/{2}); + return Status::OK(); +} + Status LogicalBufferAnalysis::HandleCopyDone(HloInstruction* copy_done) { - // The top-level buffer (index={}) for kCopy is newly created, but all other - // buffers (in the case of a tuple shape) come from the operand. + // The output of CopyDone aliases with operand {0}. CopyDone doesn't create + // any buffers. return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h index 5f774bb25a6..8ea4bcd6f87 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h @@ -62,6 +62,7 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault { Status HandleBitcast(HloInstruction* bitcast) override; Status HandleDomain(HloInstruction* domain) override; Status HandleCopy(HloInstruction* copy) override; + Status HandleCopyStart(HloInstruction* copy_start) override; Status HandleCopyDone(HloInstruction* copy_done) override; Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index 21b222609c6..9f05f5419ea 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -1314,7 +1314,7 @@ Status MemorySpaceAssignment::CopyAllocation::Process( } } copy_start_ = computation->AddInstruction(HloInstruction::CreateUnary( - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}), + ShapeUtil::MakeTupleShape({shape, shape, ShapeUtil::MakeShape(U32, {})}), HloOpcode::kCopyStart, producing_instruction)); copy_done_ = computation->AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_)); diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 9ff819437b3..639a55e3356 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -315,6 +315,30 @@ Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { return Status::OK(); } +Status TuplePointsToAnalysis::HandleCopyStart(HloInstruction* copy_start) { + // CopyStart forwards its aliased operand to {1}. + PointsToSet& points_to_set = CreateEmptyPointsToSet(copy_start); + const PointsToSet& operand_points_to_set = + GetPointsToSet(copy_start->operand(0)); + + points_to_set.ForEachMutableElement( + [&](const ShapeIndex& target_index, PointsToSet::BufferList* buffers) { + if (target_index == ShapeIndex({1})) { + *buffers = operand_points_to_set.element(/*index=*/{}); + } else { + buffers->push_back( + &logical_buffer_analysis_->GetBuffer(copy_start, target_index)); + } + }); + + for (HloInstruction* tuple : + operand_points_to_set.tuple_sources(/*index=*/{})) { + points_to_set.add_tuple_source(/*index=*/{1}, tuple); + } + + return Status::OK(); +} + Status TuplePointsToAnalysis::HandleCopyDone(HloInstruction* copy_done) { // CopyDone forwards its aliased operand. PointsToSet& points_to_set = CreateEmptyPointsToSet(copy_done); diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index c223378b332..4ef0e16a4c5 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -250,6 +250,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleBitcast(HloInstruction* bitcast) override; Status HandleDomain(HloInstruction* domain) override; Status HandleCopy(HloInstruction* copy) override; + Status HandleCopyStart(HloInstruction* copy_start) override; Status HandleCopyDone(HloInstruction* copy_done) override; Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index a0161419cec..c66f9d96a50 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -334,8 +334,8 @@ TEST_F(TuplePointsToAnalysisTest, CopyStartAndCopyDone) { auto constant = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto copy_start = builder.AddInstruction(HloInstruction::CreateUnary( - ShapeUtil::MakeTupleShape( - {constant->shape(), ShapeUtil::MakeShape(U32, {})}), + ShapeUtil::MakeTupleShape({constant->shape(), constant->shape(), + ShapeUtil::MakeShape(U32, {})}), HloOpcode::kCopyStart, constant)); auto copy_done = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kCopyDone, copy_start)); @@ -351,6 +351,7 @@ TEST_F(TuplePointsToAnalysisTest, CopyStartAndCopyDone) { points_to_analysis_->GetPointsToSet(copy_start).element({}), {copy_start}); ExpectHasBufferAliases(copy_start, {0}, {{copy_start, {0}}, {copy_done, {}}}); + ExpectHasBufferAliases(constant, {}, {{constant, {}}, {copy_start, {1}}}); } TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) {