From 86ea22210462f42c0c85920f5962f603b81a5e55 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Tue, 26 May 2020 05:46:06 -0700 Subject: [PATCH] Improve DotDecomposer to not add unnecessary non-contracting dimensions. These would be removed by AlgebraicSimplifier, then DotDecomposer would add them again, which makes the HloPassFix iterate until it hits the maximum number of iterations. Also consider operands of dots without non-contracting dimension to be canonical. PiperOrigin-RevId: 313174496 Change-Id: I8e8ac404198a9df01378820ad16834c9893336a5 --- .../compiler/xla/service/dot_decomposer.cc | 25 ++++--- .../xla/service/dot_decomposer_test.cc | 70 +++++++++++++++++++ 2 files changed, 87 insertions(+), 8 deletions(-) diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc index 353a7f5cebc..40354dec3c6 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.cc +++ b/tensorflow/compiler/xla/service/dot_decomposer.cc @@ -31,7 +31,7 @@ namespace { // Convert a dot into a canonical form where non-contracting and contracting // dimensions are reshaped together and batch dimensions are the most major -// dimensions. The requires transposing and reshapes the lhs and rhs and +// dimensions. This requires transposing and reshapes of the lhs and rhs and // reshaping the output batch to the original shape. Status CanonicalizeDot(HloInstruction* original_dot) { auto computation = original_dot->parent(); @@ -80,7 +80,9 @@ Status CanonicalizeDot(HloInstruction* original_dot) { lhs_shape), original_dot->mutable_operand(0), lhs_transpose)); std::vector<int64> lhs_reshape_dims = batch_dim_sizes; - lhs_reshape_dims.push_back(lhs_non_contracting_size); + if (lhs_non_contracting_size > 1) { + lhs_reshape_dims.push_back(lhs_non_contracting_size); + } lhs_reshape_dims.push_back(lhs_contracting_size); // Reshape the contracting and non-contracting dimensions together. 
HloInstruction* reshaped_lhs = @@ -126,7 +128,9 @@ Status CanonicalizeDot(HloInstruction* original_dot) { std::vector<int64> rhs_reshape_dims = batch_dim_sizes; rhs_reshape_dims.push_back(rhs_contracting_size); - rhs_reshape_dims.push_back(rhs_non_contracting_size); + if (rhs_non_contracting_size > 1) { + rhs_reshape_dims.push_back(rhs_non_contracting_size); + } // Reshape the contracting and non-contracting dimensions together. HloInstruction* reshaped_rhs = computation->AddInstruction(HloInstruction::CreateReshape( @@ -134,15 +138,20 @@ Status CanonicalizeDot(HloInstruction* original_dot) { transposed_rhs)); std::vector<int64> dot_dims = batch_dim_sizes; - dot_dims.push_back(lhs_non_contracting_size); - dot_dims.push_back(rhs_non_contracting_size); + if (lhs_non_contracting_size > 1) { + dot_dims.push_back(lhs_non_contracting_size); + } + if (rhs_non_contracting_size > 1) { + dot_dims.push_back(rhs_non_contracting_size); + } DotDimensionNumbers dot_dnums; for (int64 i = 0; i < num_batch_dims; ++i) { dot_dnums.add_lhs_batch_dimensions(i); dot_dnums.add_rhs_batch_dimensions(i); } - dot_dnums.add_lhs_contracting_dimensions(num_batch_dims + 1); + dot_dnums.add_lhs_contracting_dimensions( + num_batch_dims + (lhs_non_contracting_size > 1 ? 1 : 0)); dot_dnums.add_rhs_contracting_dimensions(num_batch_dims); HloInstruction* dot = computation->AddInstruction(HloInstruction::CreateDot( @@ -174,9 +183,9 @@ StatusOr<bool> DotDecomposer::Run(HloModule* module) { } // A dot is not canonical if it has more than one non-contracting // dimension. 
- if (dnums.lhs_batch_dimensions_size() + 2 != + if (dnums.lhs_batch_dimensions_size() + 2 < instruction->operand(0)->shape().rank() || - dnums.rhs_batch_dimensions_size() + 2 != + dnums.rhs_batch_dimensions_size() + 2 < instruction->operand(1)->shape().rank()) { non_canonical_dots.push_back(instruction); continue; diff --git a/tensorflow/compiler/xla/service/dot_decomposer_test.cc b/tensorflow/compiler/xla/service/dot_decomposer_test.cc index 67fff50eaf6..c4152393933 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer_test.cc +++ b/tensorflow/compiler/xla/service/dot_decomposer_test.cc @@ -50,5 +50,75 @@ TEST_F(DotDecomposerTest, CanonicalizeMultipleNonContractingDims) { op::Shape("f32[4032,512]")))); } +TEST_F(DotDecomposerTest, DontCanonicalizeIfNoNoncontractingDims) { + absl::string_view module_string = R"( + HloModule module + + ENTRY main { + p0 = f32[64,4]{1,0} parameter(0) + p1 = f32[64,4]{1,0} parameter(1) + ROOT dot = f32[64]{0} dot(p0, p1), lhs_batch_dims={0}, + lhs_contracting_dims={1}, + rhs_batch_dims={0}, + rhs_contracting_dims={1} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN(bool canonicalized, + DotDecomposer().Run(module.get())); + EXPECT_FALSE(canonicalized); +} + +TEST_F(DotDecomposerTest, DontAddLhsNonContractingDimIfOne) { + absl::string_view module_string = R"( + HloModule module + + ENTRY main { + p0 = f32[64,4]{1,0} parameter(0) + p1 = f32[64,4,2,1]{3,2,1,0} parameter(1) + ROOT dot = f32[64,2,1]{2,1,0} dot(p0, p1), lhs_batch_dims={0}, + lhs_contracting_dims={1}, + rhs_batch_dims={0}, + rhs_contracting_dims={1} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN(bool canonicalized, + DotDecomposer().Run(module.get())); + EXPECT_TRUE(canonicalized); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Reshape(AllOf(op::Dot(op::Reshape(), op::Reshape(), 
+ /*lhs_contracting_dim=*/1, + /*rhs_contracting_dim=*/1), + op::Shape("f32[64,2]")))); +} + +TEST_F(DotDecomposerTest, DontAddRhsNonContractingDimIfOne) { + absl::string_view module_string = R"( + HloModule module + + ENTRY main { + p0 = f32[64,4,2,1]{3,2,1,0} parameter(0) + p1 = f32[64,4]{1,0} parameter(1) + ROOT dot = f32[64,2,1]{2,1,0} dot(p0, p1), lhs_batch_dims={0}, + lhs_contracting_dims={1}, + rhs_batch_dims={0}, + rhs_contracting_dims={1} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseAndReturnVerifiedModule(module_string)); + TF_ASSERT_OK_AND_ASSIGN(bool canonicalized, + DotDecomposer().Run(module.get())); + EXPECT_TRUE(canonicalized); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Reshape(AllOf(op::Dot(op::Reshape(), op::Reshape(), + /*lhs_contracting_dim=*/2, + /*rhs_contracting_dim=*/1), + op::Shape("f32[64,2]")))); +} + } // namespace } // namespace xla