From ba941c212da2f9bc52171172f61cbaf2a5d8c750 Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Tue, 12 Feb 2019 10:54:26 -0800 Subject: [PATCH] [XLA] Simplify batch dots that have no contracting dimensions into multiplies. PiperOrigin-RevId: 233637534 --- .../xla/service/algebraic_simplifier.cc | 65 ++++++++++++------- .../xla/service/algebraic_simplifier_test.cc | 16 +++-- .../compiler/xla/service/shape_inference.cc | 6 +- .../xla/service/shape_inference_test.cc | 10 +++ .../compiler/xla/tests/dot_operation_test.cc | 2 + tensorflow/compiler/xla/util.cc | 11 ++-- 6 files changed, 70 insertions(+), 40 deletions(-) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 51d924f34d1..c5deb74e96a 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1605,29 +1605,50 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { return Status::OK(); } - // If a dot only contains batch dimensions, then tranpose the rhs and lhs - // acording to the batch dimension numbers and do a simple multiply. 
- if (lhs->shape().rank() == - dot->dot_dimension_numbers().lhs_batch_dimensions_size() && - rhs->shape().rank() == - dot->dot_dimension_numbers().rhs_batch_dimensions_size()) { - HloInstruction* new_rhs = rhs; - HloInstruction* new_lhs = lhs; - if (lhs->shape().rank() > 1) { - TF_ASSIGN_OR_RETURN( - new_rhs, - MakeTransposeHlo( - rhs, AsInt64Slice( - dot->dot_dimension_numbers().rhs_batch_dimensions()))); - TF_ASSIGN_OR_RETURN( - new_lhs, - MakeTransposeHlo( - lhs, AsInt64Slice( - dot->dot_dimension_numbers().lhs_batch_dimensions()))); + // If there are no contracting dimensions, a dot can be rewritten as + // mul(broadcast(transpose(x)),broadcast(transpose(y))) + if (dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 0) { + std::vector<int64> lhs_transpose( + dot->dot_dimension_numbers().lhs_batch_dimensions().begin(), + dot->dot_dimension_numbers().lhs_batch_dimensions().end()); + for (int64 i = 0; i < lhs->shape().rank(); ++i) { + if (!absl::c_linear_search( + dot->dot_dimension_numbers().lhs_batch_dimensions(), i)) { + lhs_transpose.push_back(i); + } } - TF_ASSIGN_OR_RETURN(auto new_dot, - MakeBinaryHlo(HloOpcode::kMultiply, new_lhs, new_rhs)); - return ReplaceInstruction(dot, new_dot); + TF_ASSIGN_OR_RETURN(HloInstruction * new_lhs, + MakeTransposeHlo(lhs, lhs_transpose)); + if (dot->shape().rank() != lhs->shape().rank()) { + std::vector<int64> lhs_broadcast_dims(lhs->shape().rank()); + absl::c_iota(lhs_broadcast_dims, 0); + new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( + dot->shape(), new_lhs, lhs_broadcast_dims)); + } + std::vector<int64> rhs_transpose( + dot->dot_dimension_numbers().rhs_batch_dimensions().begin(), + dot->dot_dimension_numbers().rhs_batch_dimensions().end()); + for (int64 i = 0; i < rhs->shape().rank(); ++i) { + if (!absl::c_linear_search( + dot->dot_dimension_numbers().rhs_batch_dimensions(), i)) { + rhs_transpose.push_back(i); + } + } + TF_ASSIGN_OR_RETURN(HloInstruction * new_rhs, + MakeTransposeHlo(rhs, 
rhs_transpose)); + if (dot->shape().rank() != rhs->shape().rank()) { + std::vector<int64> rhs_broadcast_dims( + dot->dot_dimension_numbers().lhs_batch_dimensions_size()); + absl::c_iota(rhs_broadcast_dims, 0); + for (int64 i = lhs->shape().rank(); i < dot->shape().rank(); ++i) { + rhs_broadcast_dims.push_back(i); + } + new_rhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( + dot->shape(), new_rhs, rhs_broadcast_dims)); + } + return ReplaceWithNewInstruction( + dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply, + new_lhs, new_rhs)); } if (lhs->shape().rank() > 2 || rhs->shape().rank() > 2 || diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 2ace9330e79..feb6a0fb795 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -4222,8 +4222,10 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { std::tie(m, k, n, element_type) = GetParam(); Shape dot_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, m, n}); - Shape lhs_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, m, k}); - Shape rhs_shape = ShapeUtil::MakeShape(element_type, {1, 3, 5, k, n}); + Shape lhs_shape = k > 0 ? ShapeUtil::MakeShape(element_type, {1, 3, 5, m, k}) + : ShapeUtil::MakeShape(element_type, {1, 3, 5, m}); + Shape rhs_shape = k > 0 ? 
ShapeUtil::MakeShape(element_type, {1, 3, 5, k, n}) + : ShapeUtil::MakeShape(element_type, {1, 3, 5, n}); HloComputation::Builder builder(TestName()); auto lhs = builder.AddInstruction( @@ -4237,14 +4239,16 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { dot_dnums.add_rhs_batch_dimensions(0); dot_dnums.add_rhs_batch_dimensions(1); dot_dnums.add_rhs_batch_dimensions(2); - dot_dnums.add_lhs_contracting_dimensions(4); - dot_dnums.add_rhs_contracting_dimensions(3); + if (k > 0) { + dot_dnums.add_lhs_contracting_dimensions(4); + dot_dnums.add_rhs_contracting_dimensions(3); + } builder.AddInstruction(HloInstruction::CreateDot( dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(default_options_); TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get())); - const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1; + const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1 || k == -1; const bool computation_should_be_modified = dot_should_be_transformed; EXPECT_EQ(changed, computation_should_be_modified); bool has_no_dot = true; @@ -4259,7 +4263,7 @@ TEST_P(BatchDotStrengthReductionTest, BatchDotStrengthReduction) { INSTANTIATE_TEST_SUITE_P( BatchDotStrengthReductionTestInstantiation, BatchDotStrengthReductionTest, - ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2), + ::testing::Combine(::testing::Values(1, 2), ::testing::Values(-1, 1, 2), ::testing::Values(1, 2), ::testing::Values(F32, BF16))); class DotStrengthReductionTest diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 061dc46f7db..a570ee346d2 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -2700,11 +2700,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& 
operand, absl::Span<const int64> dimensions) { TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose")); - std::vector<int64> indices(operand.rank()); - std::iota(indices.begin(), indices.end(), 0); - if (dimensions.size() != operand.rank() || - !std::is_permutation(dimensions.begin(), dimensions.end(), - indices.begin())) { + if (!IsPermutation(dimensions, operand.rank())) { return InvalidArgument( "Transpose dimensions [%s] are not a permutation of the operand " "dimensions (operand shape is %s).", diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 77fcf4d27ce..f400ef51f07 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -1572,6 +1572,16 @@ TEST_F(ShapeInferenceTest, Transpose) { ShapeUtil::MakeShape(F32, {3, 4, 5, 2}))); } +TEST_F(ShapeInferenceTest, Rank1Transpose) { + Shape a_shape = ShapeUtil::MakeShape(F32, {5}); + auto inferred_shape_and_status = + ShapeInference::InferTransposeShape(a_shape, {0}); + EXPECT_IS_OK(inferred_shape_and_status); + Shape inferred_shape = inferred_shape_and_status.ValueOrDie(); + EXPECT_TRUE( + ShapeUtil::Compatible(inferred_shape, ShapeUtil::MakeShape(F32, {5}))); +} + TEST_F(ShapeInferenceTest, Conditional) { auto inferred_status0 = ShapeInference::InferConditionalShape( pred_, vector_32_, vector_64_, diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 2e35805c6a4..11343ddfd0b 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -1189,6 +1189,8 @@ std::vector GetEinsumTestCases() { p{v{5, 6}, v{6, 7}, "ab,cd->dcba"}, p{v{6}, v{6, 7}, "b,bc->c"}, p{v{77}, v{77}, "a,a->a"}, + p{v{77}, v{77, 55}, "a,ab->ba"}, + p{v{2, 3, 77}, v{77, 2, 3, 55}, "ija,aijb->baij"}, p{v{55}, v{}, "a,->a"}, p{v{11, 111}, v{11}, "ab,a->ab"}, p{v{16, 34}, v{16, 34}, 
"ab,ab->ab"}, diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 34b73b5206f..bb8bbf57c42 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -80,13 +81,9 @@ bool IsPermutation(absl::Span<const int64> permutation, int64 rank) { if (rank != permutation.size()) { return false; } - std::vector<int64> output(permutation.size(), -1); - for (auto index : permutation) { CHECK_GE(index, 0); CHECK_LT(index, rank); output[index] = 0; } - return !absl::c_linear_search(output, -1); + absl::InlinedVector<int64, 8> trivial_permutation(rank); + absl::c_iota(trivial_permutation, 0); + return absl::c_is_permutation(permutation, trivial_permutation); } std::vector<int64> InversePermutation(