From 4edc3da19631fa3cb3525a2e4362262685d87e20 Mon Sep 17 00:00:00 2001 From: Trent Lo Date: Fri, 22 Jan 2021 12:02:06 -0800 Subject: [PATCH 01/10] Make H-loop-fusion share operands with users. --- .../xla/service/copy_insertion_test.cc | 52 +++++++ .../xla/service/hlo_dataflow_analysis.cc | 140 ++++++++++++++++++ .../xla/service/hlo_dataflow_analysis_test.cc | 48 ++++++ 3 files changed, 240 insertions(+) diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 23d3be6e17d..55ea07a1d2c 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -2782,5 +2782,57 @@ ENTRY main { EXPECT_EQ(CountCopies(*module), 1); } +TEST_F(CopyInsertionTest, HorizontalLoopFusionNoCopy) { + const string& hlo_string = R"( + HloModule test + + fused_computation { + p0 = f32[10,20] parameter(0) + p1 = f32[10,20] parameter(1) + p2 = f32[10,10] parameter(2) + p3 = f32[10,10] parameter(3) + add0 = f32[10, 20] add(p0, p1) + sub0 = f32[10, 10] subtract(p2, p3) + reshape0 = f32[200] reshape(add0) + reshape1 = f32[100] reshape(sub0) + concat0 = f32[300] concatenate(reshape0, reshape1), dimensions={0} + slice0 = f32[200] slice(concat0), slice={[0:200]} + slice1 = f32[100] slice(concat0), slice={[200:300]} + ROOT tuple = (f32[200], f32[100]) tuple(slice0, slice1) + } + + ENTRY test { + p0 = f32[10,20] parameter(0) + p1 = f32[10,20] parameter(1) + p2 = f32[10,10] parameter(2) + p3 = f32[10,10] parameter(3) + fusion = (f32[200], f32[100]) fusion(p0, p1, p2, p3), kind=kInput, calls=fused_computation + gte0 = f32[200] get-tuple-element(fusion), index=0 + gte1 = f32[100] get-tuple-element(fusion), index=1 + bitcast0 = f32[10,20] bitcast(gte0) + bitcast1 = f32[10,10] bitcast(gte1) + ROOT tuple = (f32[10,20], f32[10,10]) tuple(bitcast0, bitcast1) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + // Set up the aliasing manually which normally would be set by + // alias_passthrough_params pass. + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{0}, + /*param_number=*/0, + /*param_index=*/{})); + ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( + /*output_index=*/{1}, + /*param_number=*/3, + /*param_index=*/{})); + + InsertCopies(module.get()); + + // There should be no copies inserted. + EXPECT_EQ(CountCopies(*module), 0); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index bc1063f9d48..c2add4ae830 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -120,6 +120,135 @@ bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple( return true; } +namespace { +bool Is1dSliceWithoutStrides(const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kSlice && + 1 == instr->slice_starts().size() && + 1 == instr->slice_limits().size() && + 1 == instr->slice_strides().size() && + 1 == instr->slice_strides().at(0); +} + +bool IsSliceInputFusion(const HloInstruction& unnested_hlo) { + if (!unnested_hlo.IsInputFusion()) { + return false; + } + + const HloInstruction* root = unnested_hlo.fused_expression_root(); + if (Is1dSliceWithoutStrides(root)) { + return true; + } + + if (root->opcode() != HloOpcode::kTuple) { + return false; + } + + return absl::c_all_of(root->operands(), [](const HloInstruction* instr) { + return Is1dSliceWithoutStrides(instr); + }); +} + +bool ConcatHasNoEffect(const HloInstruction* concat) { + // Check if this concat is in the below pattern. In addition, we check + // that the slices combiningly are in effect a reverse function of the + // concat. + // + // Concat + // | | + // v v + // Slice Slice + // + std::vector users = concat->users(); + bool all_1d_slices = absl::c_all_of(users, [](const HloInstruction* i) { + return Is1dSliceWithoutStrides(i); + }); + if (!all_1d_slices) { + // Limit our supported cases to 1 dimensional slices. + return false; + } + absl::c_sort(users, [](const HloInstruction* a, const HloInstruction* b) { + return a->slice_starts().at(0) < b->slice_starts().at(0); + }); + + // Verify that each operand to the concat is reversed by a slice. + if (users.size() != concat->operand_count()) { + return false; + } + int64 prev_limit = 0; + for (size_t i = 0; i < users.size(); ++i) { + const HloInstruction* u = users[i]; + int64 slice_size = u->slice_limits().at(0) - u->slice_starts().at(0); + if (u->slice_starts().at(0) != prev_limit || + slice_size != ShapeUtil::ElementsIn(concat->operand(i)->shape())) { + return false; + } + prev_limit = u->slice_limits().at(0); + } + + return true; +} + +// Returns whether we can prove the transitive uses are in effect elementwise +// operations. A concat followed by slices is considered (effectively) +// elementwise if the slices combiningly is a reverse function of the concat. +// We can prove more patterns but we currently do just enough for +// SliceInputFusion. +bool AreTransitiveUsesEffectivelyElementwise(const HloInstruction* instr) { + absl::flat_hash_set visited; + absl::InlinedVector stack; + stack.push_back(instr); + while (!stack.empty()) { + const HloInstruction* current = stack.back(); + stack.pop_back(); + visited.insert(current); + for (const HloInstruction* user : current->users()) { + VLOG(3) << "Visiting: " << user->ToString(); + switch (user->opcode()) { + case HloOpcode::kTuple: + // We say reshape is fine because it does not reorder elements. + case HloOpcode::kReshape: + break; + case HloOpcode::kConcatenate: + if (!ConcatHasNoEffect(user)) { + return false; + } + break; + case HloOpcode::kSlice: + if (user->operand(0)->opcode() != HloOpcode::kConcatenate) { + return false; + } + break; + default: + if (user->IsElementwise()) { + for (const int64 use_index : user->OperandIndices(current)) { + if (!user->IsElementwiseOnOperand(use_index)) { + // Found a user that is non-elementwise on the current + // instruction. + return false; + } + } + } else { + VLOG(3) << "Cannot prove that the op is effectively elementwise: " + << user->ToString(); + return false; + } + break; + } // end of switch + if (user->opcode() != HloOpcode::kTuple && + !LayoutUtil::IsMonotonicWithDim0Major(user->shape().layout())) { + // Simply check that all the layout is row-major to make sure there + // is no layout change. + return false; + } + if (!visited.contains(user)) { + stack.push_back(user); + } + } + } + return true; +} +} // namespace + bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction, const ShapeIndex& index) const { const HloValueSet& value_set = GetValueSet(instruction, index); @@ -1266,10 +1395,21 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( if (operand->opcode() == HloOpcode::kConstant) { return false; } + const Shape& operand_subshape = ShapeUtil::GetSubshape(operand->shape(), operand_index); const Shape& user_subshape = ShapeUtil::GetSubshape(user->shape(), user_index); + if (IsSliceInputFusion(*user)) { + HloInstruction* fusion_param = + user->fused_parameter(user->operand_index(operand)); + // We don't require the same dimensions but only the same number of elements + // and type (to make sure the same buffer size). + return ShapeUtil::ElementsIn(operand_subshape) == + ShapeUtil::ElementsIn(user_subshape) && + ShapeUtil::SameElementType(operand_subshape, user_subshape) && + AreTransitiveUsesEffectivelyElementwise(fusion_param); + } // Check that operand and user emit the same shape and layout. if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 1fa6fe95c40..b83f24df894 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -2795,5 +2795,53 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { dataflow_analysis_->CanShareOperandBufferWithUser(reverse, {}, call, {})); } +TEST_F(CanShareOperandBufferWithUserTest, ConcatSliceWithElementwise) { + const char* kModule = R"( + HloModule test + + fused_computation { + p0 = f32[10,20] parameter(0) + p1 = f32[10,20] parameter(1) + p2 = f32[10,10] parameter(2) + p3 = f32[10,10] parameter(3) + add0 = f32[10, 20] add(p0, p1) + sub0 = f32[10, 10] subtract(p2, p3) + reshape0 = f32[200] reshape(add0) + reshape1 = f32[100] reshape(sub0) + concat0 = f32[300] concatenate(reshape0, reshape1), dimensions={0} + slice0 = f32[200] slice(concat0), slice={[0:200]} + slice1 = f32[100] slice(concat0), slice={[200:300]} + ROOT tuple = (f32[200], f32[100]) tuple(slice0, slice1) + } + + ENTRY test { + p0 = f32[10,20] parameter(0) + p1 = f32[10,20] parameter(1) + p2 = f32[10,10] parameter(2) + p3 = f32[10,10] parameter(3) + ROOT fusion = (f32[200], f32[100]) fusion(p0, p1, p2, p3), kind=kInput, calls=fused_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kModule)); + auto* fusion = module_->entry_computation()->root_instruction(); + auto* param0 = module_->entry_computation()->parameter_instruction(0); + auto* param1 = module_->entry_computation()->parameter_instruction(1); + auto* param2 = module_->entry_computation()->parameter_instruction(2); + auto* param3 = module_->entry_computation()->parameter_instruction(3); + + RunAnalysis(); + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {0})); + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {0})); + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param2, {}, + fusion, {1})); + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param3, {}, + fusion, {1})); + // Tensors of different sizes cannot share buffer. + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {1})); +} + } // namespace } // namespace xla From 42b6a802a6d1d7743a7965eec6f2b42bfc7f6902 Mon Sep 17 00:00:00 2001 From: Trent Lo Date: Mon, 25 Jan 2021 11:08:55 -0800 Subject: [PATCH 02/10] Revision to address review comments. --- .../xla/service/hlo_dataflow_analysis.cc | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index c2add4ae830..ef0e0a4ef7a 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -175,7 +175,7 @@ bool ConcatHasNoEffect(const HloInstruction* concat) { return false; } int64 prev_limit = 0; - for (size_t i = 0; i < users.size(); ++i) { + for (int64 i = 0; i < users.size(); ++i) { const HloInstruction* u = users[i]; int64 slice_size = u->slice_limits().at(0) - u->slice_starts().at(0); if (u->slice_starts().at(0) != prev_limit || @@ -205,8 +205,11 @@ bool AreTransitiveUsesEffectivelyElementwise(const HloInstruction* instr) { VLOG(3) << "Visiting: " << user->ToString(); switch (user->opcode()) { case HloOpcode::kTuple: - // We say reshape is fine because it does not reorder elements. + break; case HloOpcode::kReshape: + if (!ShapeUtil::ReshapeIsBitcast(current->shape(), user->shape())) { + return false; + } break; case HloOpcode::kConcatenate: if (!ConcatHasNoEffect(user)) { @@ -215,6 +218,8 @@ bool AreTransitiveUsesEffectivelyElementwise(const HloInstruction* instr) { break; case HloOpcode::kSlice: if (user->operand(0)->opcode() != HloOpcode::kConcatenate) { + // Check that we have seen and verified a preceding concat of this + // Slice. return false; } break; @@ -227,6 +232,11 @@ bool AreTransitiveUsesEffectivelyElementwise(const HloInstruction* instr) { return false; } } + if (!LayoutUtil::Equal(current->shape().layout(), + user->shape().layout())) { + // Make sure the layout is not changed by the elementwise op. + return false; + } } else { VLOG(3) << "Cannot prove that the op is effectively elementwise: " << user->ToString(); @@ -234,12 +244,6 @@ bool AreTransitiveUsesEffectivelyElementwise(const HloInstruction* instr) { } break; } // end of switch - if (user->opcode() != HloOpcode::kTuple && - !LayoutUtil::IsMonotonicWithDim0Major(user->shape().layout())) { - // Simply check that all the layout is row-major to make sure there - // is no layout change. - return false; - } if (!visited.contains(user)) { stack.push_back(user); } From 45194938fc1a93d45609eeb5ffe19ae6f8ffa3e2 Mon Sep 17 00:00:00 2001 From: Trent Lo Date: Mon, 25 Jan 2021 13:48:26 -0800 Subject: [PATCH 03/10] [NFC] Polish coding style. --- tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index ef0e0a4ef7a..1a3ce3d0950 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -159,10 +159,9 @@ bool ConcatHasNoEffect(const HloInstruction* concat) { // Slice Slice // std::vector users = concat->users(); - bool all_1d_slices = absl::c_all_of(users, [](const HloInstruction* i) { - return Is1dSliceWithoutStrides(i); - }); - if (!all_1d_slices) { + if (!absl::c_all_of(users, [](const HloInstruction* i) { + return Is1dSliceWithoutStrides(i); + })) { // Limit our supported cases to 1 dimensional slices. return false; } From dde133c3f3edc717b0fa2b08f798e0d6c473d8bc Mon Sep 17 00:00:00 2001 From: Trent Lo Date: Tue, 26 Jan 2021 15:15:07 -0800 Subject: [PATCH 04/10] Fixing a bug that concat can create extra constraints to buffer sharing. - Also, add a negative test. --- .../xla/service/hlo_dataflow_analysis.cc | 99 +++++++++++++------ .../xla/service/hlo_dataflow_analysis_test.cc | 45 +++++++++ 2 files changed, 112 insertions(+), 32 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 1a3ce3d0950..f08f22d3acf 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -133,32 +133,33 @@ bool IsSliceInputFusion(const HloInstruction& unnested_hlo) { if (!unnested_hlo.IsInputFusion()) { return false; } - const HloInstruction* root = unnested_hlo.fused_expression_root(); - if (Is1dSliceWithoutStrides(root)) { - return true; - } - if (root->opcode() != HloOpcode::kTuple) { return false; } - return absl::c_all_of(root->operands(), [](const HloInstruction* instr) { return Is1dSliceWithoutStrides(instr); }); } -bool ConcatHasNoEffect(const HloInstruction* concat) { - // Check if this concat is in the below pattern. In addition, we check - // that the slices combiningly are in effect a reverse function of the - // concat. +// Returns whether the concat is used in an elementwise manner. +// A concat followed by slices is considered (effectively) elementwise if the +// slices combinedly is a reverse function of the concat. +// prev_concat_opnd_idx: previously seen concat and its operand index. +// prev_concat_opnd_idx.first is nullptr if no previously seen concat. +bool ConcatIsEffectivelyElementwise( + const HloInstruction& concat, const HloInstruction& operand, + std::pair* prev_concat_opnd_idx, + const HloInstruction** slice_to_recover_opnd) { + // First, check if this concat is in the below pattern. Also, we check + // that the slices combinedly are in effect a reverse function of the concat. // // Concat // | | // v v // Slice Slice // - std::vector users = concat->users(); + std::vector users = concat.users(); if (!absl::c_all_of(users, [](const HloInstruction* i) { return Is1dSliceWithoutStrides(i); })) { @@ -168,9 +169,9 @@ bool ConcatHasNoEffect(const HloInstruction* concat) { absl::c_sort(users, [](const HloInstruction* a, const HloInstruction* b) { return a->slice_starts().at(0) < b->slice_starts().at(0); }); - // Verify that each operand to the concat is reversed by a slice. - if (users.size() != concat->operand_count()) { + if (users.size() != concat.operand_count() || + concat.operand_count() != concat.unique_operands().size()) { return false; } int64 prev_limit = 0; @@ -178,24 +179,48 @@ bool ConcatHasNoEffect(const HloInstruction* concat) { const HloInstruction* u = users[i]; int64 slice_size = u->slice_limits().at(0) - u->slice_starts().at(0); if (u->slice_starts().at(0) != prev_limit || - slice_size != ShapeUtil::ElementsIn(concat->operand(i)->shape())) { + slice_size != ShapeUtil::ElementsIn(concat.operand(i)->shape())) { return false; } prev_limit = u->slice_limits().at(0); } + // If we have seen other concats, make sure they are identical. + int64 operand_idx = concat.operand_index(&operand); + *slice_to_recover_opnd = users.at(operand_idx); + if (prev_concat_opnd_idx->first == nullptr) { + prev_concat_opnd_idx->first = &concat; + prev_concat_opnd_idx->second = operand_idx; + } else { + bool is_concat_identical = prev_concat_opnd_idx->first->Identical( + concat, + /*eq_operands=*/[](const HloInstruction*, const HloInstruction*) { + // Operands don't need to be the same. + return true; + }); + if (!is_concat_identical || prev_concat_opnd_idx->second != operand_idx) { + return false; + } + } + return true; } -// Returns whether we can prove the transitive uses are in effect elementwise -// operations. A concat followed by slices is considered (effectively) -// elementwise if the slices combiningly is a reverse function of the concat. -// We can prove more patterns but we currently do just enough for +// Returns whether we can prove the transitive uses of `param` are in effect +// elementwise. In other words, we prove that the transitive use closure will +// all be computed in the same iteration space without any reorder of elements. +// Theoretically, We can prove more patterns but our primary use case is // SliceInputFusion. -bool AreTransitiveUsesEffectivelyElementwise(const HloInstruction* instr) { +bool AreTransitiveUsesEffectivelyElementwise(const HloInstruction* param, + const HloInstruction* root_tuple, + const ShapeIndex& out_shape_idx) { + CHECK(root_tuple->opcode() == HloOpcode::kTuple); + CHECK(out_shape_idx.size() == 1); absl::flat_hash_set visited; absl::InlinedVector stack; - stack.push_back(instr); + stack.push_back(param); + std::pair prev_concat_opnd_idx(nullptr, 0); + bool is_output_reachable = false; while (!stack.empty()) { const HloInstruction* current = stack.back(); stack.pop_back(); @@ -204,24 +229,33 @@ bool AreTransitiveUsesEffectivelyElementwise(const HloInstruction* instr) { VLOG(3) << "Visiting: " << user->ToString(); switch (user->opcode()) { case HloOpcode::kTuple: + if (user == root_tuple && + current == root_tuple->operand(out_shape_idx.back())) { + // We need to know if the output is reachable by the `param` to make + // sure that they will be computed in the same iteration space. + is_output_reachable = true; + } break; case HloOpcode::kReshape: if (!ShapeUtil::ReshapeIsBitcast(current->shape(), user->shape())) { return false; } break; - case HloOpcode::kConcatenate: - if (!ConcatHasNoEffect(user)) { + case HloOpcode::kConcatenate: { + const HloInstruction* slice_to_recover_opnd = nullptr; + if (!ConcatIsEffectivelyElementwise(*user, *current, + &prev_concat_opnd_idx, + &slice_to_recover_opnd)) { return false; } - break; - case HloOpcode::kSlice: - if (user->operand(0)->opcode() != HloOpcode::kConcatenate) { - // Check that we have seen and verified a preceding concat of this - // Slice. - return false; - } - break; + // Early continue as we only want to traverse through the slice that + // recovers the operand. It is guaranteed that the operand to the + // concat and the slice have the same iteration space. Insert the + // slice instead of the concat. + CHECK(!visited.contains(slice_to_recover_opnd)); + stack.push_back(slice_to_recover_opnd); + continue; + } default: if (user->IsElementwise()) { for (const int64 use_index : user->OperandIndices(current)) { @@ -248,7 +282,7 @@ bool AreTransitiveUsesEffectivelyElementwise(const HloInstruction* instr) { } } } - return true; + return is_output_reachable; } } // namespace @@ -1411,7 +1445,8 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( return ShapeUtil::ElementsIn(operand_subshape) == ShapeUtil::ElementsIn(user_subshape) && ShapeUtil::SameElementType(operand_subshape, user_subshape) && - AreTransitiveUsesEffectivelyElementwise(fusion_param); + AreTransitiveUsesEffectivelyElementwise( + fusion_param, user->fused_expression_root(), user_index); } // Check that operand and user emit the same shape and layout. diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index b83f24df894..0a17edf8c29 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -2843,5 +2843,50 @@ TEST_F(CanShareOperandBufferWithUserTest, ConcatSliceWithElementwise) { fusion, {1})); } +TEST_F(CanShareOperandBufferWithUserTest, ConcatSliceNegativeTest) { + const char* kModule = R"( + HloModule test + + fused_computation { + // p0 has multiple transitive uses fed to concat. So, p0 cannot share + // buffer with outputs because the aliased output could be written before + // all the uses of p0 are finished. + p0 = f32[100] parameter(0) + p1 = f32[100] parameter(1) + add0 = f32[100] add(p0, p1) + concat0 = f32[200] concatenate(p0, add0), dimensions={0} + slice0 = f32[100] slice(concat0), slice={[0:100]} + slice1 = f32[100] slice(concat0), slice={[100:200]} + ROOT tuple = (f32[100], f32[100]) tuple(slice0, slice1) + } + + ENTRY test { + p0 = f32[100] parameter(0) + p1 = f32[100] parameter(1) + ROOT fusion = (f32[100], f32[100]) fusion(p0, p1), + kind=kInput, calls=fused_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kModule)); + auto* fusion = module_->entry_computation()->root_instruction(); + auto* param0 = module_->entry_computation()->parameter_instruction(0); + auto* param1 = module_->entry_computation()->parameter_instruction(1); + + RunAnalysis(); + // p0 cannot share with either fusion{0} or fusion{1}. + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {0})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {1})); + // p1 cannot share with fusion{0} because we're not sure about their + // relationship. + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {0})); + // p1 can share with fusion{1} because they will be executed in an + // elementwise manner. + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {1})); +} + } // namespace } // namespace xla From f90bdfc74bee50a7f4a2607ae634952069835882 Mon Sep 17 00:00:00 2001 From: Trent Lo Date: Tue, 26 Jan 2021 15:55:43 -0800 Subject: [PATCH 05/10] Add one more unitest for h-loop-fusion buffer sharing. --- .../xla/service/hlo_dataflow_analysis_test.cc | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 0a17edf8c29..08f34330136 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -2888,5 +2888,57 @@ TEST_F(CanShareOperandBufferWithUserTest, ConcatSliceNegativeTest) { fusion, {1})); } +TEST_F(CanShareOperandBufferWithUserTest, MultipleConcatenates) { + const char* kModule = R"( + HloModule test + + fused_computation { + p0 = f32[100] parameter(0) + p1 = f32[100] parameter(1) + add0 = f32[100] add(p0, p1) + sub0 = f32[100] subtract(p1, p1) + concat0 = f32[200] concatenate(p0, add0), dimensions={0} + slice0 = f32[100] slice(concat0), slice={[0:100]} + slice1 = f32[100] slice(concat0), slice={[100:200]} + concat1 = f32[200] concatenate(p0, sub0), dimensions={0} + slice2 = f32[100] slice(concat1), slice={[0:100]} + slice3 = f32[100] slice(concat1), slice={[100:200]} + ROOT tuple = (f32[100], f32[100], f32[100], f32[100]) + tuple(slice0, slice1, slice2, slice3) + } + + ENTRY test { + p0 = f32[100] parameter(0) + p1 = f32[100] parameter(1) + ROOT fusion = (f32[100], f32[100], f32[100], f32[100]) + fusion(p0, p1), kind=kInput, calls=fused_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kModule)); + auto* fusion = module_->entry_computation()->root_instruction(); + auto* param0 = module_->entry_computation()->parameter_instruction(0); + auto* param1 = module_->entry_computation()->parameter_instruction(1); + + RunAnalysis(); + // p0 cannot share. + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {0})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {1})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {2})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {3})); + // p1 can share with either fusion{1} or fusion{3}. + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {1})); + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {3})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {0})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {2})); +} + } // namespace } // namespace xla From 24c07a32c800c0ab4d2705491a2c48c8c9bdbff8 Mon Sep 17 00:00:00 2001 From: Trent Lo Date: Wed, 27 Jan 2021 12:10:40 -0800 Subject: [PATCH 06/10] [XLA/GPU] Do not horizontally fuse loop-fusions if they share parameters. --- .../xla/service/gpu/horizontal_loop_fusion.cc | 25 ++++++++++++++ .../gpu/horizontal_loop_fusion_test.cc | 34 +++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc index 9d1e0533a91..a2a2264cfb7 100644 --- a/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc @@ -203,6 +203,24 @@ bool HasOnlyRowMajorLayout(const HloInstruction& fusion_instr) { return true; } +// Returns whether any operand of `instr` is a parameter instruction that +// is shared with `fusion_instrs`. +bool AnyOpndIsParamSharedAmongFusions( + const HloInstruction* instr, + const absl::flat_hash_set& fusion_instrs) { + for (const HloInstruction* opnd : instr->operands()) { + if (opnd->opcode() != HloOpcode::kParameter) { + continue; + } + for (const HloInstruction* user : opnd->users()) { + if (user != instr && fusion_instrs.contains(user)) { + return true; + } + } + } + return false; +} + void HorizontalLoopFusionImpl::FusionCandidates::Initialize( HloInstruction* consumer) { // First, find out all fusion instructions. We will filter out @@ -230,6 +248,13 @@ void HorizontalLoopFusionImpl::FusionCandidates::Initialize( } else if (!HasOnlyRowMajorLayout(*instr)) { VLOG(2) << "Reject non-row-major fusion instr " << instr->ToString(); continue; + } else if (AnyOpndIsParamSharedAmongFusions(instr, fusion_instrs)) { + // Don't fuse fusions whose operands are parameter instructions that are + // shared among fusions because we cannot i/o alias the produced + // horizontal fusion due to the concat insertion. + VLOG(2) << "Reject the fusion instr because it shares parameter with" + << " other fusion candidates, instr: ", instr->ToString(); + continue; } else { VLOG(2) << "Find a fusion candidate " << instr->ToString(); fusion_instrs_.push_back(instr); diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion_test.cc index 8091330cd47..a124ad765fd 100644 --- a/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion_test.cc @@ -403,6 +403,40 @@ TEST_F(HorizontalLoopFusionTest, NegativeTestForDynamicUpdateSlice) { EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie()); } +TEST_F(HorizontalLoopFusionTest, NegativeTestForSharedParam) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule BasicTest + + fused_computation.1 { + arg.1 = f16[123]{0} parameter(0) + arg.2 = f16[123]{0} parameter(1) + ROOT mul.1 = f16[123]{0} multiply(arg.1, arg.2) + } + + fused_computation.2 { + arg.1 = f16[123]{0} parameter(0) + arg.2 = f16[123]{0} parameter(1) + ROOT add.1 = f16[123]{0} add(arg.1, arg.2) + } + + ENTRY entry_computation { + arg.1 = f16[123]{0} parameter(0) + // arg.2 is shared by fusion.1 and fusion.2 + arg.2 = f16[123]{0} parameter(1) + arg.3 = f16[123]{0} parameter(2) + fusion.1 = f16[123]{0} + fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1 + fusion.2 = f16[123]{0} + fusion(arg.3, arg.2), kind=kLoop, calls=fused_computation.2 + ROOT tuple.1 = (f16[123]{0}, f16[123]{0}) + tuple(fusion.1, fusion.2) + } +)") + .ValueOrDie(); + + EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie()); +} + } // namespace } // namespace gpu } // namespace xla From 3fb0103a881da1fdde23d2f3c2fde3dde2d3a988 Mon Sep 17 00:00:00 2001 From: Trent Lo Date: Wed, 10 Feb 2021 11:01:34 -0800 Subject: [PATCH 07/10] [XLA/GPU] Use IsElementwiseOnOperand() to test elementwise relationship. Otherwise, we will exclude DynamicSliceUpdate because it is not a pure elementwise instruction. --- .../xla/service/hlo_dataflow_analysis.cc | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index f08f22d3acf..a2d456a05c3 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -257,22 +257,16 @@ bool AreTransitiveUsesEffectivelyElementwise(const HloInstruction* param, continue; } default: - if (user->IsElementwise()) { - for (const int64 use_index : user->OperandIndices(current)) { - if (!user->IsElementwiseOnOperand(use_index)) { - // Found a user that is non-elementwise on the current - // instruction. - return false; - } - } - if (!LayoutUtil::Equal(current->shape().layout(), - user->shape().layout())) { - // Make sure the layout is not changed by the elementwise op. + for (const int64 use_index : user->OperandIndices(current)) { + if (!user->IsElementwiseOnOperand(use_index)) { + // Found a user that is non-elementwise on the current + // instruction. return false; } - } else { - VLOG(3) << "Cannot prove that the op is effectively elementwise: " - << user->ToString(); + } + if (!LayoutUtil::Equal(current->shape().layout(), + user->shape().layout())) { + // Make sure the layout is not changed by the elementwise op. return false; } break; From 7bfeff2f6bb60306b266fd71d4b3a303319ddaf4 Mon Sep 17 00:00:00 2001 From: Trent Lo Date: Wed, 10 Feb 2021 14:08:16 -0800 Subject: [PATCH 08/10] [XLA/GPU] Horizontally fuse DynamicUpdateSlice. Previously we encountered an issue that additional copies are inserted due to horizontal fusion. Re-enable the handling of DynamicUpdateSlice since a fix has been committed. With the fix, we observe a 30% speedup due to horizontal loop fusion from the previously problematic case. See the below trace from nvprof. //Horizontal fusion (w/ fix), kernel time ~= 3.33ms (30% speedup vs. w/o horizontal fusion.) 29.51% 1.8112ms 1000 1.8110us 1.7910us 2.3360us fusion_5 24.68% 1.5148ms 1000 1.5140us 1.5030us 1.9840us fusion_3 0.09% 5.5040us 3 1.8340us 1.5360us 2.1120us [CUDA memcpy DtoD] //Horizontal fusion (w/o fix), kernel time ~= 5.74ms 28.91% 2.4695ms 2003 1.2320us 1.2150us 1.9200us [CUDA memcpy DtoD] 21.13% 1.8052ms 1000 1.8050us 1.7910us 2.3360us fusion_5 17.20% 1.4691ms 1000 1.4690us 1.4390us 1.6320us fusion_3 // No horizontal fusion, kernel time = ~4.36ms 21.21% 1.5209ms 1000 1.5200us 1.4720us 10.110us fusion_3 19.81% 1.4204ms 1000 1.4200us 1.3750us 8.5750us fusion_1 19.74% 1.4155ms 1000 1.4150us 1.3750us 1.7600us fusion 0.07% 5.2800us 3 1.7600us 1.5040us 1.8880us [CUDA memcpy DtoD] --- .../xla/service/gpu/horizontal_loop_fusion.cc | 8 ------ .../gpu/horizontal_loop_fusion_test.cc | 26 ++++++++++++------- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc index a2a2264cfb7..36bdf1db1eb 100644 --- a/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc @@ -174,14 +174,6 @@ bool IsProfitableFusionCandidate(const HloInstruction& instr) { return false; } - // We can emit DUS in-place, horizontally fusing it makes the emitter no - // longer recognize that it can be done in-place. This creates much slower - // code. This restriction could be lifted if buffer assignment would recognize - // that the DUS can be done in-place even inside of a horizontal fusion. - if (root->opcode() == HloOpcode::kDynamicUpdateSlice) { - return false; - } - return true; } diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion_test.cc index a124ad765fd..d956438cb5a 100644 --- a/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion_test.cc @@ -364,33 +364,33 @@ TEST_F(HorizontalLoopFusionTest, RMSPropLike) { EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1.0e-5, 1.0e-5})); } -TEST_F(HorizontalLoopFusionTest, NegativeTestForDynamicUpdateSlice) { +TEST_F(HorizontalLoopFusionTest, DynamicUpdateSlice) { auto module = ParseAndReturnVerifiedModule(R"( HloModule NegativeTestForDynamicUpdateSlice fusion.1 { p.0 = f16[5,9,10]{2,1,0} parameter(0) - p.1 = s32[1]{0} parameter(1) + p.1 = s32[] parameter(1) p.2 = f16[1,9,10]{2,1,0} parameter(2) c.0 = s32[] constant(0) - pad = s32[3]{0} pad(p.1, c.0), padding=0_2 - ROOT %dynamic-update-slice = f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, pad) + ROOT %dynamic-update-slice = + f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0) } fusion.2 { p.0 = f16[5,9,10]{2,1,0} parameter(0) - p.1 = s32[1]{0} parameter(1) + p.1 = s32[] parameter(1) p.2 = f16[1,9,10]{2,1,0} parameter(2) c.0 = s32[] constant(0) - pad = s32[3]{0} pad(p.1, c.0), padding=0_2 - ROOT %dynamic-update-slice = f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, pad) + ROOT %dynamic-update-slice = + f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0) } ENTRY entry { p.00 = f16[5,9,10]{2,1,0} parameter(0) p.01 = f16[5,9,10]{2,1,0} parameter(1) - p.10 = s32[1]{0} parameter(2) - p.11 = s32[1]{0} parameter(3) + p.10 = s32[] parameter(2) + p.11 = s32[] parameter(3) p.20 = f16[1,9,10]{2,1,0} parameter(4) p.21 = f16[1,9,10]{2,1,0} parameter(5) @@ -400,7 +400,13 @@ TEST_F(HorizontalLoopFusionTest, NegativeTestForDynamicUpdateSlice) { })") .ValueOrDie(); - EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie()); + + VLOG(2) << "Dump after horizontal fusion:"; + VLOG(2) << module->ToString(); + + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0})); } TEST_F(HorizontalLoopFusionTest, NegativeTestForSharedParam) { From 23defae7b0e80dd9a821116962ae96250cc4f38f Mon Sep 17 00:00:00 2001 From: Trent Lo Date: Wed, 10 Feb 2021 15:31:50 -0800 Subject: [PATCH 09/10] [XLA/GPU] Comment polishing. --- .../compiler/xla/service/hlo_dataflow_analysis.cc | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index a2d456a05c3..549579dd13d 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -166,14 +166,14 @@ bool ConcatIsEffectivelyElementwise( // Limit our supported cases to 1 dimensional slices. return false; } - absl::c_sort(users, [](const HloInstruction* a, const HloInstruction* b) { - return a->slice_starts().at(0) < b->slice_starts().at(0); - }); // Verify that each operand to the concat is reversed by a slice. if (users.size() != concat.operand_count() || concat.operand_count() != concat.unique_operands().size()) { return false; } + absl::c_sort(users, [](const HloInstruction* a, const HloInstruction* b) { + return a->slice_starts().at(0) < b->slice_starts().at(0); + }); int64 prev_limit = 0; for (int64 i = 0; i < users.size(); ++i) { const HloInstruction* u = users[i]; @@ -185,7 +185,10 @@ bool ConcatIsEffectivelyElementwise( prev_limit = u->slice_limits().at(0); } - // If we have seen other concats, make sure they are identical. + // If we have seen other concats, make sure they are identical. Multiple + // concats exist because horizontal fusion inserts one concat for each output + // of the fusion candidates. Check all concats and operand ids are the same + // to make sure that the compute iteration space is the same. int64 operand_idx = concat.operand_index(&operand); *slice_to_recover_opnd = users.at(operand_idx); if (prev_concat_opnd_idx->first == nullptr) { @@ -207,8 +210,10 @@ bool ConcatIsEffectivelyElementwise( } // Returns whether we can prove the transitive uses of `param` are in effect -// elementwise. In other words, we prove that the transitive use closure will +// elementwise. In other words, we prove that the "transitive use closure" will // all be computed in the same iteration space without any reorder of elements. +// In addition, we check that the "transitive use closure" includes the output +// in the `root_tuple`. // Theoretically, We can prove more patterns but our primary use case is // SliceInputFusion. bool AreTransitiveUsesEffectivelyElementwise(const HloInstruction* param, From fa4e5ca6ca53bd8d209e7fddf7a5b54a20a69d06 Mon Sep 17 00:00:00 2001 From: Trent Lo Date: Thu, 11 Feb 2021 11:48:17 -0800 Subject: [PATCH 10/10] [XLA/GPU] Polish coding styles. --- .../xla/service/copy_insertion_test.cc | 2 - .../xla/service/gpu/horizontal_loop_fusion.cc | 17 ++--- .../xla/service/hlo_dataflow_analysis.cc | 68 ++++++++++--------- 3 files changed, 42 insertions(+), 45 deletions(-) diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 3c920a0d382..39d8a9e6002 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -2912,8 +2912,6 @@ TEST_F(CopyInsertionTest, HorizontalLoopFusionNoCopy) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_string)); - // Set up the aliasing manually which normally would be set by - // alias_passthrough_params pass. ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias( /*output_index=*/{0}, /*param_number=*/0, diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc index 36bdf1db1eb..da04f4fa1c3 100644 --- a/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc @@ -200,17 +200,12 @@ bool HasOnlyRowMajorLayout(const HloInstruction& fusion_instr) { bool AnyOpndIsParamSharedAmongFusions( const HloInstruction* instr, const absl::flat_hash_set& fusion_instrs) { - for (const HloInstruction* opnd : instr->operands()) { - if (opnd->opcode() != HloOpcode::kParameter) { - continue; - } - for (const HloInstruction* user : opnd->users()) { - if (user != instr && fusion_instrs.contains(user)) { - return true; - } - } - } - return false; + return absl::c_any_of(instr->operands(), [&](const HloInstruction* opnd) { + return opnd->opcode() == HloOpcode::kParameter && + absl::c_any_of(opnd->users(), [&](const HloInstruction* user) { + return user != instr && fusion_instrs.contains(user); + }); + }); } void HorizontalLoopFusionImpl::FusionCandidates::Initialize( diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 549579dd13d..887761254d8 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -142,15 +142,21 @@ bool IsSliceInputFusion(const HloInstruction& unnested_hlo) { }); } -// Returns whether the concat is used in an elementwise manner. -// A concat followed by slices is considered (effectively) elementwise if the -// slices combinedly is a reverse function of the concat. -// prev_concat_opnd_idx: previously seen concat and its operand index. -// prev_concat_opnd_idx.first is nullptr if no previously seen concat. -bool ConcatIsEffectivelyElementwise( +struct ConcatUsageInfo { + // Pointer to a previously seen concat. nullptr if no previously seen concat. + const HloInstruction* prev_concat; + // The opnd id of the seen concat. + int64 concat_opnd_idx; + // The slice that recovers the opnd in the concat outputs. + const HloInstruction* slice_to_recover_opnd; +}; + +// Returns an optional concat usage info to denote whether the concat is used in +// an elementwise manner. A concat followed by slices is considered effectively +// elementwise if the slices combinedly is a reverse function of the concat. +absl::optional ConcatIsEffectivelyElementwise( const HloInstruction& concat, const HloInstruction& operand, - std::pair* prev_concat_opnd_idx, - const HloInstruction** slice_to_recover_opnd) { + const ConcatUsageInfo& info) { // First, check if this concat is in the below pattern. Also, we check // that the slices combinedly are in effect a reverse function of the concat. // @@ -160,16 +166,14 @@ bool ConcatIsEffectivelyElementwise( // Slice Slice // std::vector users = concat.users(); - if (!absl::c_all_of(users, [](const HloInstruction* i) { - return Is1dSliceWithoutStrides(i); - })) { + if (!absl::c_all_of(users, Is1dSliceWithoutStrides)) { // Limit our supported cases to 1 dimensional slices. - return false; + return absl::optional(); } // Verify that each operand to the concat is reversed by a slice. if (users.size() != concat.operand_count() || concat.operand_count() != concat.unique_operands().size()) { - return false; + return absl::optional(); } absl::c_sort(users, [](const HloInstruction* a, const HloInstruction* b) { return a->slice_starts().at(0) < b->slice_starts().at(0); @@ -180,33 +184,32 @@ bool ConcatIsEffectivelyElementwise( int64 slice_size = u->slice_limits().at(0) - u->slice_starts().at(0); if (u->slice_starts().at(0) != prev_limit || slice_size != ShapeUtil::ElementsIn(concat.operand(i)->shape())) { - return false; + return absl::optional(); } prev_limit = u->slice_limits().at(0); } // If we have seen other concats, make sure they are identical. Multiple // concats exist because horizontal fusion inserts one concat for each output - // of the fusion candidates. Check all concats and operand ids are the same - // to make sure that the compute iteration space is the same. + // of the fusion candidates. Check that all concats and operand ids are the + // same to know that the "transitive use closure" will be computed in the same + // iteration space. int64 operand_idx = concat.operand_index(&operand); - *slice_to_recover_opnd = users.at(operand_idx); - if (prev_concat_opnd_idx->first == nullptr) { - prev_concat_opnd_idx->first = &concat; - prev_concat_opnd_idx->second = operand_idx; - } else { - bool is_concat_identical = prev_concat_opnd_idx->first->Identical( + if (info.prev_concat != nullptr) { + bool is_concat_identical = info.prev_concat->Identical( concat, /*eq_operands=*/[](const HloInstruction*, const HloInstruction*) { // Operands don't need to be the same. return true; }); - if (!is_concat_identical || prev_concat_opnd_idx->second != operand_idx) { - return false; + if (!is_concat_identical || info.concat_opnd_idx != operand_idx) { + return absl::optional(); } } - return true; + const HloInstruction* slice_to_recover_opnd = users.at(operand_idx); + return absl::optional( + ConcatUsageInfo{&concat, operand_idx, slice_to_recover_opnd}); } // Returns whether we can prove the transitive uses of `param` are in effect @@ -224,7 +227,7 @@ bool AreTransitiveUsesEffectivelyElementwise(const HloInstruction* param, absl::flat_hash_set visited; absl::InlinedVector stack; stack.push_back(param); - std::pair prev_concat_opnd_idx(nullptr, 0); + ConcatUsageInfo concat_usage_info{nullptr, 0, nullptr}; bool is_output_reachable = false; while (!stack.empty()) { const HloInstruction* current = stack.back(); @@ -247,18 +250,19 @@ bool AreTransitiveUsesEffectivelyElementwise(const HloInstruction* param, } break; case HloOpcode::kConcatenate: { - const HloInstruction* slice_to_recover_opnd = nullptr; - if (!ConcatIsEffectivelyElementwise(*user, *current, - &prev_concat_opnd_idx, - &slice_to_recover_opnd)) { + absl::optional optional_concat_info = + ConcatIsEffectivelyElementwise(*user, *current, + concat_usage_info); + if (!optional_concat_info) { return false; } + concat_usage_info = *optional_concat_info; // Early continue as we only want to traverse through the slice that // recovers the operand. It is guaranteed that the operand to the // concat and the slice have the same iteration space. Insert the // slice instead of the concat. - CHECK(!visited.contains(slice_to_recover_opnd)); - stack.push_back(slice_to_recover_opnd); + CHECK(!visited.contains(concat_usage_info.slice_to_recover_opnd)); + stack.push_back(concat_usage_info.slice_to_recover_opnd); continue; } default: