From b8c9a576c90a0aba678760b4c0d85ea0a044b007 Mon Sep 17 00:00:00 2001
From: Trent Lo
Date: Wed, 24 Feb 2021 15:10:37 -0800
Subject: [PATCH] [XLA/GPU] Re-enable h-loop-fusion to share operands with users.

Will add a check in the next commit to avoid invoking ElementsIn() with
tuple-shaped operands/users.
---
 .../xla/service/copy_insertion_test.cc        |  50 +++++
 .../xla/service/gpu/horizontal_loop_fusion.cc |  29 ++-
 .../gpu/horizontal_loop_fusion_test.cc        |  58 +++++-
 .../xla/service/hlo_dataflow_analysis.cc      | 181 ++++++++++++++++++
 .../xla/service/hlo_dataflow_analysis_test.cc | 145 ++++++++++++++
 5 files changed, 446 insertions(+), 17 deletions(-)

diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc
index 04da56f57df..39d8a9e6002 100644
--- a/tensorflow/compiler/xla/service/copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -2877,6 +2877,56 @@ ENTRY main {
   EXPECT_EQ(CountCopies(*module), 1);
 }
 
+TEST_F(CopyInsertionTest, HorizontalLoopFusionNoCopy) {
+  const string& hlo_string = R"(
+  HloModule test
+
+  fused_computation {
+    p0 = f32[10,20] parameter(0)
+    p1 = f32[10,20] parameter(1)
+    p2 = f32[10,10] parameter(2)
+    p3 = f32[10,10] parameter(3)
+    add0 = f32[10, 20] add(p0, p1)
+    sub0 = f32[10, 10] subtract(p2, p3)
+    reshape0 = f32[200] reshape(add0)
+    reshape1 = f32[100] reshape(sub0)
+    concat0 = f32[300] concatenate(reshape0, reshape1), dimensions={0}
+    slice0 = f32[200] slice(concat0), slice={[0:200]}
+    slice1 = f32[100] slice(concat0), slice={[200:300]}
+    ROOT tuple = (f32[200], f32[100]) tuple(slice0, slice1)
+  }
+
+  ENTRY test {
+    p0 = f32[10,20] parameter(0)
+    p1 = f32[10,20] parameter(1)
+    p2 = f32[10,10] parameter(2)
+    p3 = f32[10,10] parameter(3)
+    fusion = (f32[200], f32[100]) fusion(p0, p1, p2, p3), kind=kInput, calls=fused_computation
+    gte0 = f32[200] get-tuple-element(fusion), index=0
+    gte1 = f32[100] get-tuple-element(fusion), index=1
+    bitcast0 = f32[10,20] bitcast(gte0)
+    bitcast1 = f32[10,10] bitcast(gte1)
+    ROOT tuple = (f32[10,20], f32[10,10]) tuple(bitcast0, bitcast1)
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
+      /*output_index=*/{0},
+      /*param_number=*/0,
+      /*param_index=*/{}));
+  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
+      /*output_index=*/{1},
+      /*param_number=*/3,
+      /*param_index=*/{}));
+
+  InsertCopies(module.get());
+
+  // There should be no copies inserted.
+  EXPECT_EQ(CountCopies(*module), 0);
+}
+
 TEST_F(CopyInsertionTest, NestedWhileAndConditional3) {
   const string& hlo_string = R"(
 HloModule TestModule
diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc
index 9d1e0533a91..95f8015df22 100644
--- a/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc
@@ -174,14 +174,6 @@ bool IsProfitableFusionCandidate(const HloInstruction& instr) {
     return false;
   }
 
-  // We can emit DUS in-place, horizontally fusing it makes the emitter no
-  // longer recognize that it can be done in-place. This creates much slower
-  // code. This restriction could be lifted if buffer assignment would recognize
-  // that the DUS can be done in-place even inside of a horizontal fusion.
-  if (root->opcode() == HloOpcode::kDynamicUpdateSlice) {
-    return false;
-  }
-
   return true;
 }
 
@@ -203,6 +195,19 @@ bool HasOnlyRowMajorLayout(const HloInstruction& fusion_instr) {
   return true;
 }
 
+// Returns whether any operand of `instr` is a parameter instruction that
+// is shared with `fusion_instrs`.
+bool AnyOpndIsParamSharedAmongFusions(
+    const HloInstruction* instr,
+    const absl::flat_hash_set<HloInstruction*>& fusion_instrs) {
+  return absl::c_any_of(instr->operands(), [&](const HloInstruction* opnd) {
+    return opnd->opcode() == HloOpcode::kParameter &&
+           absl::c_any_of(opnd->users(), [&](const HloInstruction* user) {
+             return user != instr && fusion_instrs.contains(user);
+           });
+  });
+}
+
 void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
     HloInstruction* consumer) {
   // First, find out all fusion instructions. We will filter out
@@ -230,6 +235,14 @@ void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
     } else if (!HasOnlyRowMajorLayout(*instr)) {
       VLOG(2) << "Reject non-row-major fusion instr " << instr->ToString();
       continue;
+    } else if (AnyOpndIsParamSharedAmongFusions(instr, fusion_instrs)) {
+      // Don't fuse fusions whose operands are parameter instructions that are
+      // shared among fusions, because we cannot i/o alias the produced
+      // horizontal fusion due to the concat insertion.
+      VLOG(2) << "Reject the fusion instr because it shares a parameter with"
+              << " other fusion candidates, instr: " << instr->ToString();
+      continue;
     } else {
       VLOG(2) << "Find a fusion candidate " << instr->ToString();
       fusion_instrs_.push_back(instr);
diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion_test.cc
index 8091330cd47..d956438cb5a 100644
--- a/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion_test.cc
@@ -364,33 +364,33 @@ TEST_F(HorizontalLoopFusionTest, RMSPropLike) {
   EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1.0e-5, 1.0e-5}));
 }
 
-TEST_F(HorizontalLoopFusionTest, NegativeTestForDynamicUpdateSlice) {
+TEST_F(HorizontalLoopFusionTest, DynamicUpdateSlice) {
   auto module = ParseAndReturnVerifiedModule(R"(
  HloModule NegativeTestForDynamicUpdateSlice
 
   fusion.1 {
     p.0 = f16[5,9,10]{2,1,0} parameter(0)
-    p.1 = s32[1]{0} parameter(1)
+    p.1 = s32[] parameter(1)
     p.2 = f16[1,9,10]{2,1,0} parameter(2)
     c.0 = s32[] constant(0)
-    pad = s32[3]{0} pad(p.1, c.0), padding=0_2
-    ROOT %dynamic-update-slice = f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, pad)
+    ROOT %dynamic-update-slice =
+        f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
   }
 
   fusion.2 {
     p.0 = f16[5,9,10]{2,1,0} parameter(0)
-    p.1 = s32[1]{0} parameter(1)
+    p.1 = s32[] parameter(1)
     p.2 = f16[1,9,10]{2,1,0} parameter(2)
     c.0 = s32[] constant(0)
-    pad = s32[3]{0} pad(p.1, c.0), padding=0_2
-    ROOT %dynamic-update-slice = f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, pad)
+    ROOT %dynamic-update-slice =
+        f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
   }
 
   ENTRY entry {
     p.00 = f16[5,9,10]{2,1,0} parameter(0)
     p.01 = f16[5,9,10]{2,1,0} parameter(1)
-    p.10 = s32[1]{0} parameter(2)
-    p.11 = s32[1]{0} parameter(3)
+    p.10 = s32[] parameter(2)
+    p.11 = s32[] parameter(3)
     p.20 = f16[1,9,10]{2,1,0} parameter(4)
     p.21 = f16[1,9,10]{2,1,0} parameter(5)
 
@@ -400,6 +400,46 @@ TEST_F(HorizontalLoopFusionTest, NegativeTestForDynamicUpdateSlice) {
   })")
                     .ValueOrDie();
 
+  EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
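+  // DCE cleans up any instructions left dead by the fusion rewrite.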
+  EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());
+
+  VLOG(2) << "Dump after horizontal fusion:";
+  VLOG(2) << module->ToString();
+
+  EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
+}
+
+TEST_F(HorizontalLoopFusionTest, NegativeTestForSharedParam) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+ HloModule BasicTest
+
+ fused_computation.1 {
+   arg.1 = f16[123]{0} parameter(0)
+   arg.2 = f16[123]{0} parameter(1)
+   ROOT mul.1 = f16[123]{0} multiply(arg.1, arg.2)
+ }
+
+ fused_computation.2 {
+   arg.1 = f16[123]{0} parameter(0)
+   arg.2 = f16[123]{0} parameter(1)
+   ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
+ }
+
+ ENTRY entry_computation {
+   arg.1 = f16[123]{0} parameter(0)
+   // arg.2 is shared by fusion.1 and fusion.2.
+   arg.2 = f16[123]{0} parameter(1)
+   arg.3 = f16[123]{0} parameter(2)
+   fusion.1 = f16[123]{0}
+       fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
+   fusion.2 = f16[123]{0}
+       fusion(arg.3, arg.2), kind=kLoop, calls=fused_computation.2
+   ROOT tuple.1 = (f16[123]{0}, f16[123]{0})
+       tuple(fusion.1, fusion.2)
+ }
+)")
+                    .ValueOrDie();
+
+  EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
+}
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index bc1063f9d48..31132fa5d9e 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -120,6 +120,175 @@ bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
   return true;
 }
 
+namespace {
+bool Is1dSliceWithoutStrides(const HloInstruction* instr) {
+  return instr->opcode() == HloOpcode::kSlice &&
+         1 == instr->slice_starts().size() &&
+         1 == instr->slice_limits().size() &&
+         1 == instr->slice_strides().size() &&
+         1 == instr->slice_strides().at(0);
+}
+
+bool IsSliceInputFusion(const HloInstruction& unnested_hlo) {
+  if (!unnested_hlo.IsInputFusion()) {
+    return false;
+  }
+  const HloInstruction* root = unnested_hlo.fused_expression_root();
+  if (root->opcode() != HloOpcode::kTuple) {
+    return false;
+  }
+  return absl::c_all_of(root->operands(), [](const HloInstruction* instr) {
+    return Is1dSliceWithoutStrides(instr);
+  });
+}
+
+struct ConcatUsageInfo {
+  // Pointer to a previously seen concat. nullptr if no previously seen concat.
+  const HloInstruction* prev_concat;
+  // The opnd id of the seen concat.
+  int64 concat_opnd_idx;
+  // The slice that recovers the opnd in the concat outputs.
+  const HloInstruction* slice_to_recover_opnd;
+};
+
+// Returns an optional ConcatUsageInfo to denote whether the concat is used in
+// an elementwise manner. A concat followed by slices is considered effectively
+// elementwise if the slices, taken together, form an inverse of the concat.
+absl::optional<ConcatUsageInfo> ConcatIsEffectivelyElementwise(
+    const HloInstruction& concat, const HloInstruction& operand,
+    const ConcatUsageInfo& info) {
+  // First, check if this concat is in the below pattern. Also, check that the
+  // slices, taken together, are in effect an inverse of the concat.
+  //
+  //     Concat
+  //     |    |
+  //     v    v
+  //   Slice Slice
+  //
+  std::vector<HloInstruction*> users = concat.users();
+  if (!absl::c_all_of(users, Is1dSliceWithoutStrides)) {
+    // Limit our supported cases to 1-dimensional slices.
+    return absl::optional<ConcatUsageInfo>();
+  }
+  // Verify that each operand to the concat is reversed by a slice.
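+  // Duplicate operands are disallowed: operand_index() below must map each
+  // operand to a unique recovering slice.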
+  if (users.size() != concat.operand_count() ||
+      concat.operand_count() != concat.unique_operands().size()) {
+    return absl::optional<ConcatUsageInfo>();
+  }
+  absl::c_sort(users, [](const HloInstruction* a, const HloInstruction* b) {
+    return a->slice_starts().at(0) < b->slice_starts().at(0);
+  });
+  int64 prev_limit = 0;
+  for (int64 i = 0; i < users.size(); ++i) {
+    const HloInstruction* u = users[i];
+    int64 slice_size = u->slice_limits().at(0) - u->slice_starts().at(0);
+    if (u->slice_starts().at(0) != prev_limit ||
+        slice_size != ShapeUtil::ElementsIn(concat.operand(i)->shape())) {
+      return absl::optional<ConcatUsageInfo>();
+    }
+    prev_limit = u->slice_limits().at(0);
+  }
+
+  // If we have seen other concats, make sure they are identical. Multiple
+  // concats exist because horizontal fusion inserts one concat for each output
+  // of the fusion candidates. Check that all concats and operand ids are the
+  // same to know that the "transitive use closure" will be computed in the
+  // same iteration space.
+  int64 operand_idx = concat.operand_index(&operand);
+  if (info.prev_concat != nullptr) {
+    bool is_concat_identical = info.prev_concat->Identical(
+        concat,
+        /*eq_operands=*/[](const HloInstruction*, const HloInstruction*) {
+          // Operands don't need to be the same.
+          return true;
+        });
+    if (!is_concat_identical || info.concat_opnd_idx != operand_idx) {
+      return absl::optional<ConcatUsageInfo>();
+    }
+  }
+
+  const HloInstruction* slice_to_recover_opnd = users.at(operand_idx);
+  return absl::optional<ConcatUsageInfo>(
+      ConcatUsageInfo{&concat, operand_idx, slice_to_recover_opnd});
+}
+
+// Returns whether we can prove the transitive uses of `param` are in effect
+// elementwise. In other words, we prove that the "transitive use closure" will
+// all be computed in the same iteration space without any reorder of elements.
+// In addition, we check that the "transitive use closure" includes the output
+// in the `root_tuple`.
+// Theoretically, we can prove more patterns, but our primary use case is
+// SliceInputFusion.
+bool AreTransitiveUsesEffectivelyElementwise(const HloInstruction* param,
+                                             const HloInstruction* root_tuple,
+                                             const ShapeIndex& out_shape_idx) {
+  CHECK_EQ(root_tuple->opcode(), HloOpcode::kTuple);
+  CHECK_EQ(out_shape_idx.size(), 1);
+  absl::flat_hash_set<const HloInstruction*> visited;
+  absl::InlinedVector<const HloInstruction*, 4> stack;
+  stack.push_back(param);
+  ConcatUsageInfo concat_usage_info{nullptr, 0, nullptr};
+  bool is_output_reachable = false;
+  while (!stack.empty()) {
+    const HloInstruction* current = stack.back();
+    stack.pop_back();
+    visited.insert(current);
+    for (const HloInstruction* user : current->users()) {
+      VLOG(3) << "Visiting: " << user->ToString();
+      switch (user->opcode()) {
+        case HloOpcode::kTuple:
+          if (user == root_tuple &&
+              current == root_tuple->operand(out_shape_idx.back())) {
+            // We need to know if the output is reachable by the `param` to
+            // make sure that they will be computed in the same iteration
+            // space.
+            is_output_reachable = true;
+          }
+          break;
+        case HloOpcode::kReshape:
+          if (!ShapeUtil::ReshapeIsBitcast(current->shape(), user->shape())) {
+            return false;
+          }
+          break;
+        case HloOpcode::kConcatenate: {
+          absl::optional<ConcatUsageInfo> optional_concat_info =
+              ConcatIsEffectivelyElementwise(*user, *current,
+                                             concat_usage_info);
+          if (!optional_concat_info) {
+            return false;
+          }
+          concat_usage_info = *optional_concat_info;
+          // Early continue as we only want to traverse through the slice that
+          // recovers the operand. It is guaranteed that the operand to the
+          // concat and the slice have the same iteration space.
+          // Insert the slice instead of the concat.
+          CHECK(!visited.contains(concat_usage_info.slice_to_recover_opnd));
+          stack.push_back(concat_usage_info.slice_to_recover_opnd);
+          continue;
+        }
+        default:
+          for (const int64 use_index : user->OperandIndices(current)) {
+            if (!user->IsElementwiseOnOperand(use_index)) {
+              // Found a user that is non-elementwise on the current
+              // instruction.
+              return false;
+            }
+          }
+          if (!LayoutUtil::Equal(current->shape().layout(),
+                                 user->shape().layout())) {
+            // Make sure the layout is not changed by the elementwise op.
+            return false;
+          }
+          break;
+      }  // end of switch
+      if (!visited.contains(user)) {
+        stack.push_back(user);
+      }
+    }
+  }
+  return is_output_reachable;
+}
+}  // namespace
+
 bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
                                            const ShapeIndex& index) const {
   const HloValueSet& value_set = GetValueSet(instruction, index);
@@ -1266,10 +1435,22 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
   if (operand->opcode() == HloOpcode::kConstant) {
     return false;
   }
+
   const Shape& operand_subshape =
       ShapeUtil::GetSubshape(operand->shape(), operand_index);
   const Shape& user_subshape =
       ShapeUtil::GetSubshape(user->shape(), user_index);
+  if (IsSliceInputFusion(*user)) {
+    HloInstruction* fusion_param =
+        user->fused_parameter(user->operand_index(operand));
+    // We don't require the same dimensions, only the same number of elements
+    // and the same element type (so that the buffer sizes are equal).
+    return ShapeUtil::ElementsIn(operand_subshape) ==
+               ShapeUtil::ElementsIn(user_subshape) &&
+           ShapeUtil::SameElementType(operand_subshape, user_subshape) &&
+           AreTransitiveUsesEffectivelyElementwise(
+               fusion_param, user->fused_expression_root(), user_index);
+  }
   // Check that operand and user emit the same shape and layout.
   if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 1fa6fe95c40..9981db81b16 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -2795,5 +2795,150 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
       dataflow_analysis_->CanShareOperandBufferWithUser(reverse, {}, call, {}));
 }
 
+TEST_F(CanShareOperandBufferWithUserTest, ConcatSliceWithElementwise) {
+  const char* kModule = R"(
+    HloModule test
+
+    fused_computation {
+      p0 = f32[10,20] parameter(0)
+      p1 = f32[10,20] parameter(1)
+      p2 = f32[10,10] parameter(2)
+      p3 = f32[10,10] parameter(3)
+      add0 = f32[10, 20] add(p0, p1)
+      sub0 = f32[10, 10] subtract(p2, p3)
+      reshape0 = f32[200] reshape(add0)
+      reshape1 = f32[100] reshape(sub0)
+      concat0 = f32[300] concatenate(reshape0, reshape1), dimensions={0}
+      slice0 = f32[200] slice(concat0), slice={[0:200]}
+      slice1 = f32[100] slice(concat0), slice={[200:300]}
+      ROOT tuple = (f32[200], f32[100]) tuple(slice0, slice1)
+    }
+
+    ENTRY test {
+      p0 = f32[10,20] parameter(0)
+      p1 = f32[10,20] parameter(1)
+      p2 = f32[10,10] parameter(2)
+      p3 = f32[10,10] parameter(3)
+      ROOT fusion = (f32[200], f32[100]) fusion(p0, p1, p2, p3), kind=kInput, calls=fused_computation
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kModule));
+  auto* fusion = module_->entry_computation()->root_instruction();
+  auto* param0 = module_->entry_computation()->parameter_instruction(0);
+  auto* param1 = module_->entry_computation()->parameter_instruction(1);
+  auto* param2 = module_->entry_computation()->parameter_instruction(2);
+  auto* param3 = module_->entry_computation()->parameter_instruction(3);
+
+  RunAnalysis();
+  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
+                                                                fusion, {0}));
+  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
+                                                                fusion, {0}));
+  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param2, {},
+                                                                fusion, {1}));
+  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param3, {},
+                                                                fusion, {1}));
+  // Tensors of different sizes cannot share a buffer.
+  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
+                                                                 fusion, {1}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest, ConcatSliceNegativeTest) {
+  const char* kModule = R"(
+    HloModule test
+
+    fused_computation {
+      // p0 has multiple transitive uses feeding the concat, so p0 cannot
+      // share a buffer with the outputs: the aliased output could be written
+      // before all uses of p0 have finished.
+      p0 = f32[100] parameter(0)
+      p1 = f32[100] parameter(1)
+      add0 = f32[100] add(p0, p1)
+      concat0 = f32[200] concatenate(p0, add0), dimensions={0}
+      slice0 = f32[100] slice(concat0), slice={[0:100]}
+      slice1 = f32[100] slice(concat0), slice={[100:200]}
+      ROOT tuple = (f32[100], f32[100]) tuple(slice0, slice1)
+    }
+
+    ENTRY test {
+      p0 = f32[100] parameter(0)
+      p1 = f32[100] parameter(1)
+      ROOT fusion = (f32[100], f32[100]) fusion(p0, p1),
+          kind=kInput, calls=fused_computation
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kModule));
+  auto* fusion = module_->entry_computation()->root_instruction();
+  auto* param0 = module_->entry_computation()->parameter_instruction(0);
+  auto* param1 = module_->entry_computation()->parameter_instruction(1);
+
+  RunAnalysis();
+  // p0 cannot share a buffer with either fusion{0} or fusion{1}: it reaches
+  // the concat both directly (operand 0) and through add0 (operand 1), i.e.,
+  // in two different iteration spaces.
+  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
+                                                                 fusion, {0}));
+  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
+                                                                 fusion, {1}));
+  // p1 cannot share a buffer with fusion{0} because its transitive uses do
+  // not reach that output in the same iteration space.
+  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
+                                                                 fusion, {0}));
+  // p1 can share a buffer with fusion{1} because they are computed in an
+  // elementwise manner.
+  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
+                                                                fusion, {1}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest, MultipleConcatenates) {
+  const char* kModule = R"(
+    HloModule test
+
+    fused_computation {
+      p0 = f32[100] parameter(0)
+      p1 = f32[100] parameter(1)
+      add0 = f32[100] add(p0, p1)
+      sub0 = f32[100] subtract(p1, p1)
+      concat0 = f32[200] concatenate(p0, add0), dimensions={0}
+      slice0 = f32[100] slice(concat0), slice={[0:100]}
+      slice1 = f32[100] slice(concat0), slice={[100:200]}
+      concat1 = f32[200] concatenate(p0, sub0), dimensions={0}
+      slice2 = f32[100] slice(concat1), slice={[0:100]}
+      slice3 = f32[100] slice(concat1), slice={[100:200]}
+      ROOT tuple = (f32[100], f32[100], f32[100], f32[100])
+          tuple(slice0, slice1, slice2, slice3)
+    }
+
+    ENTRY test {
+      p0 = f32[100] parameter(0)
+      p1 = f32[100] parameter(1)
+      ROOT fusion = (f32[100], f32[100], f32[100], f32[100])
+          fusion(p0, p1), kind=kInput, calls=fused_computation
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kModule));
+  auto* fusion = module_->entry_computation()->root_instruction();
+  auto* param0 = module_->entry_computation()->parameter_instruction(0);
+  auto* param1 = module_->entry_computation()->parameter_instruction(1);
+
+  RunAnalysis();
+  // p0 cannot share a buffer with any output: it enters the concats both
+  // directly (operand 0) and through add0 (operand 1), i.e., at inconsistent
+  // operand indices.
+  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
+                                                                 fusion, {0}));
+  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
+                                                                 fusion, {1}));
+  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
+                                                                 fusion, {2}));
+  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
+                                                                 fusion, {3}));
+  // p1 can share a buffer with either fusion{1} or fusion{3}.
+  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
+                                                                fusion, {1}));
+  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
+                                                                fusion, {3}));
+  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
+                                                                 fusion, {0}));
+  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
+                                                                 fusion, {2}));
+}
+
 }  // namespace
 }  // namespace xla
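
Note: ConcatIsEffectivelyElementwise above accepts a concat only when its slice
users, sorted by start offset, exactly tile the concat output and the i-th slice
recovers exactly the i-th operand. A minimal standalone sketch of that coverage
condition follows; the names (SlicesInvertConcat, Slice1D) are hypothetical and
not part of the patch.

  // slice_coverage_sketch.cc -- standalone illustration of the coverage check.
  #include <algorithm>
  #include <cstdint>
  #include <iostream>
  #include <vector>

  struct Slice1D {
    int64_t start;
    int64_t limit;  // Exclusive, as in HLO slice limits.
  };

  // Returns true iff the slices, sorted by start offset, exactly tile
  // [0, sum(operand_sizes)) with the i-th slice's size equal to the i-th
  // concat operand's element count, i.e., the slices invert the concat.
  bool SlicesInvertConcat(const std::vector<int64_t>& operand_sizes,
                          std::vector<Slice1D> slices) {
    if (slices.size() != operand_sizes.size()) return false;
    std::sort(slices.begin(), slices.end(),
              [](const Slice1D& a, const Slice1D& b) { return a.start < b.start; });
    int64_t prev_limit = 0;
    for (size_t i = 0; i < slices.size(); ++i) {
      const int64_t slice_size = slices[i].limit - slices[i].start;
      if (slices[i].start != prev_limit || slice_size != operand_sizes[i]) {
        return false;  // Gap, overlap, or size mismatch.
      }
      prev_limit = slices[i].limit;
    }
    return true;
  }

  int main() {
    // Mirrors the ConcatSliceWithElementwise test: f32[200] and f32[100] are
    // concatenated into f32[300] and recovered by slice [0:200] and [200:300].
    std::cout << SlicesInvertConcat({200, 100}, {{0, 200}, {200, 300}})   // 1
              << SlicesInvertConcat({200, 100}, {{0, 150}, {150, 300}})   // 0: sizes differ
              << SlicesInvertConcat({200, 100}, {{0, 200}, {210, 300}})   // 0: gap
              << "\n";
    return 0;
  }

The same arithmetic appears in the loop over the sorted slice users in
ConcatIsEffectivelyElementwise, with ShapeUtil::ElementsIn supplying the operand
sizes.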