diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc
index 353a7f5cebc..40354dec3c6 100644
--- a/tensorflow/compiler/xla/service/dot_decomposer.cc
+++ b/tensorflow/compiler/xla/service/dot_decomposer.cc
@@ -31,7 +31,7 @@ namespace {
 
 // Convert a dot into a canonical form where non-contracting and contracting
 // dimensions are reshaped together and batch dimensions are the most major
-// dimensions. The requires transposing and reshapes the lhs and rhs and
+// dimensions. This requires transposing and reshapes of the lhs and rhs and
 // reshaping the output batch to the original shape.
 Status CanonicalizeDot(HloInstruction* original_dot) {
   auto computation = original_dot->parent();
@@ -80,7 +80,9 @@ Status CanonicalizeDot(HloInstruction* original_dot) {
                                        lhs_shape),
           original_dot->mutable_operand(0), lhs_transpose));
   std::vector<int64> lhs_reshape_dims = batch_dim_sizes;
-  lhs_reshape_dims.push_back(lhs_non_contracting_size);
+  if (lhs_non_contracting_size > 1) {
+    lhs_reshape_dims.push_back(lhs_non_contracting_size);
+  }
   lhs_reshape_dims.push_back(lhs_contracting_size);
   // Reshape the contracting and non-contracting dimensions together.
   HloInstruction* reshaped_lhs =
@@ -126,7 +128,9 @@ Status CanonicalizeDot(HloInstruction* original_dot) {
 
   std::vector<int64> rhs_reshape_dims = batch_dim_sizes;
   rhs_reshape_dims.push_back(rhs_contracting_size);
-  rhs_reshape_dims.push_back(rhs_non_contracting_size);
+  if (rhs_non_contracting_size > 1) {
+    rhs_reshape_dims.push_back(rhs_non_contracting_size);
+  }
   // Reshape the contracting and non-contracting dimensions together.
   HloInstruction* reshaped_rhs =
       computation->AddInstruction(HloInstruction::CreateReshape(
@@ -134,15 +138,20 @@ Status CanonicalizeDot(HloInstruction* original_dot) {
           transposed_rhs));
 
   std::vector<int64> dot_dims = batch_dim_sizes;
-  dot_dims.push_back(lhs_non_contracting_size);
-  dot_dims.push_back(rhs_non_contracting_size);
+  if (lhs_non_contracting_size > 1) {
+    dot_dims.push_back(lhs_non_contracting_size);
+  }
+  if (rhs_non_contracting_size > 1) {
+    dot_dims.push_back(rhs_non_contracting_size);
+  }
 
   DotDimensionNumbers dot_dnums;
   for (int64 i = 0; i < num_batch_dims; ++i) {
     dot_dnums.add_lhs_batch_dimensions(i);
     dot_dnums.add_rhs_batch_dimensions(i);
   }
-  dot_dnums.add_lhs_contracting_dimensions(num_batch_dims + 1);
+  dot_dnums.add_lhs_contracting_dimensions(
+      num_batch_dims + (lhs_non_contracting_size > 1 ? 1 : 0));
   dot_dnums.add_rhs_contracting_dimensions(num_batch_dims);
 
   HloInstruction* dot = computation->AddInstruction(HloInstruction::CreateDot(
@@ -174,9 +183,9 @@ StatusOr<bool> DotDecomposer::Run(HloModule* module) {
     }
     // A dot is not canonical if it has more than one non-contracting
     // dimension.
-    if (dnums.lhs_batch_dimensions_size() + 2 !=
+    if (dnums.lhs_batch_dimensions_size() + 2 <
             instruction->operand(0)->shape().rank() ||
-        dnums.rhs_batch_dimensions_size() + 2 !=
+        dnums.rhs_batch_dimensions_size() + 2 <
             instruction->operand(1)->shape().rank()) {
       non_canonical_dots.push_back(instruction);
       continue;
diff --git a/tensorflow/compiler/xla/service/dot_decomposer_test.cc b/tensorflow/compiler/xla/service/dot_decomposer_test.cc
index 67fff50eaf6..c4152393933 100644
--- a/tensorflow/compiler/xla/service/dot_decomposer_test.cc
+++ b/tensorflow/compiler/xla/service/dot_decomposer_test.cc
@@ -50,5 +50,75 @@ TEST_F(DotDecomposerTest, CanonicalizeMultipleNonContractingDims) {
                             op::Shape("f32[4032,512]"))));
 }
 
+TEST_F(DotDecomposerTest, DontCanonicalizeIfNoNoncontractingDims) {
+  absl::string_view module_string = R"(
+  HloModule module
+
+  ENTRY main {
+    p0 = f32[64,4]{1,0} parameter(0)
+    p1 = f32[64,4]{1,0} parameter(1)
+    ROOT dot = f32[64]{0} dot(p0, p1), lhs_batch_dims={0},
+                                       lhs_contracting_dims={1},
+                                       rhs_batch_dims={0},
+                                       rhs_contracting_dims={1}
+  })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool canonicalized,
+                          DotDecomposer().Run(module.get()));
+  EXPECT_FALSE(canonicalized);
+}
+
+TEST_F(DotDecomposerTest, DontAddLhsNonContractingDimIfOne) {
+  absl::string_view module_string = R"(
+  HloModule module
+
+  ENTRY main {
+    p0 = f32[64,4]{1,0} parameter(0)
+    p1 = f32[64,4,2,1]{3,2,1,0} parameter(1)
+    ROOT dot = f32[64,2,1]{2,1,0} dot(p0, p1), lhs_batch_dims={0},
+                                               lhs_contracting_dims={1},
+                                               rhs_batch_dims={0},
+                                               rhs_contracting_dims={1}
+  })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool canonicalized,
+                          DotDecomposer().Run(module.get()));
+  EXPECT_TRUE(canonicalized);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              op::Reshape(AllOf(op::Dot(op::Reshape(), op::Reshape(),
+                                        /*lhs_contracting_dim=*/1,
+                                        /*rhs_contracting_dim=*/1),
+                                op::Shape("f32[64,2]"))));
+}
+
+TEST_F(DotDecomposerTest, DontAddRhsNonContractingDimIfOne) {
+  absl::string_view module_string = R"(
+  HloModule module
+
+  ENTRY main {
+    p0 = f32[64,4,2,1]{3,2,1,0} parameter(0)
+    p1 = f32[64,4]{1,0} parameter(1)
+    ROOT dot = f32[64,2,1]{2,1,0} dot(p0, p1), lhs_batch_dims={0},
+                                               lhs_contracting_dims={1},
+                                               rhs_batch_dims={0},
+                                               rhs_contracting_dims={1}
+  })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool canonicalized,
+                          DotDecomposer().Run(module.get()));
+  EXPECT_TRUE(canonicalized);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              op::Reshape(AllOf(op::Dot(op::Reshape(), op::Reshape(),
+                                        /*lhs_contracting_dim=*/2,
+                                        /*rhs_contracting_dim=*/1),
+                                op::Shape("f32[64,2]"))));
+}
+
 }  // namespace
 }  // namespace xla
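
Note (reviewer illustration, not part of the patch): the contracting-dimension
index arithmetic above follows from the canonical operand layout
[batch..., non-contracting?, contracting], where the non-contracting slot now
exists only when its size exceeds 1. Below is a minimal standalone C++ sketch
of that logic, fed with the shapes from the DontAddLhsNonContractingDimIfOne
test; the variable names mirror the patch, but nothing here is XLA code.

  #include <cstdint>
  #include <iostream>
  #include <vector>

  int main() {
    // lhs = f32[64,4]: one batch dim of size 64, contracting size 4, and a
    // degenerate non-contracting size of 1, which the patch now drops.
    std::vector<int64_t> batch_dim_sizes = {64};
    const int64_t num_batch_dims = batch_dim_sizes.size();
    const int64_t lhs_non_contracting_size = 1;
    const int64_t lhs_contracting_size = 4;

    // Same shape-building logic as the patched CanonicalizeDot: only emit
    // the non-contracting dim when it is non-degenerate.
    std::vector<int64_t> lhs_reshape_dims = batch_dim_sizes;
    if (lhs_non_contracting_size > 1) {
      lhs_reshape_dims.push_back(lhs_non_contracting_size);
    }
    lhs_reshape_dims.push_back(lhs_contracting_size);

    // The contracting dim sits after the batch dims, plus one extra slot
    // only if the non-contracting dim was actually kept -- the same
    // expression the patch passes to add_lhs_contracting_dimensions().
    const int64_t lhs_contracting_dim =
        num_batch_dims + (lhs_non_contracting_size > 1 ? 1 : 0);

    std::cout << "lhs reshape dims:";
    for (int64_t d : lhs_reshape_dims) std::cout << " " << d;
    std::cout << "\nlhs contracting dim: " << lhs_contracting_dim << "\n";
    // Prints "lhs reshape dims: 64 4" and "lhs contracting dim: 1",
    // matching /*lhs_contracting_dim=*/1 in the test above.
  }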