diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index ec63ae1ac61..728adb67a0d 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1873,8 +1873,9 @@ cc_library(
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:types",
-        "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/strings",
     ],
 )
 
diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc
index b2ba2617902..855424067d2 100644
--- a/tensorflow/compiler/xla/service/dot_decomposer.cc
+++ b/tensorflow/compiler/xla/service/dot_decomposer.cc
@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/dot_decomposer.h"
 
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -156,29 +158,187 @@ Status DecomposeBatchDot(HloInstruction* dot) {
   return computation->ReplaceInstruction(dot, new_dot);
 }
 
+// Convert a dot into a canonical form where non-contracting and contracting
+// dimensions are reshaped together and batch dimensions are the most major
+// dimensions. This requires transposing and reshaping the lhs and rhs, and
+// reshaping the output back to the original shape.
+Status CanonicalizeDot(HloInstruction* original_dot) {
+  auto computation = original_dot->parent();
+  const auto& original_dnums = original_dot->dot_dimension_numbers();
+  const int64 num_batch_dims = original_dnums.lhs_batch_dimensions_size();
+  const int64 num_contracting_dims =
+      original_dnums.lhs_contracting_dimensions_size();
+
+  const auto& lhs_shape = original_dot->operand(0)->shape();
+  const int64 lhs_rank = lhs_shape.rank();
+  const int64 num_lhs_non_contracting_dims =
+      lhs_rank - num_batch_dims - num_contracting_dims;
+
+  std::vector<int64> lhs_non_contracting_dims;
+  lhs_non_contracting_dims.reserve(num_lhs_non_contracting_dims);
+  int64 lhs_contracting_size = 1;
+  int64 lhs_non_contracting_size = 1;
+  std::vector<int64> batch_dim_sizes;
+  batch_dim_sizes.reserve(num_batch_dims);
+  for (int64 i = 0; i < lhs_rank; ++i) {
+    if (absl::c_linear_search(original_dnums.lhs_contracting_dimensions(), i)) {
+      lhs_contracting_size *= lhs_shape.dimensions(i);
+    } else if (absl::c_linear_search(original_dnums.lhs_batch_dimensions(),
+                                     i)) {
+      batch_dim_sizes.push_back(lhs_shape.dimensions(i));
+    } else {
+      lhs_non_contracting_dims.push_back(i);
+      lhs_non_contracting_size *= lhs_shape.dimensions(i);
+    }
+  }
+  // The canonical form of the lhs is
+  // [BatchDims, NonContractingDims, ContractingDims]
+  std::vector<int64> lhs_transpose;
+  lhs_transpose.reserve(lhs_rank);
+  lhs_transpose.insert(lhs_transpose.end(),
+                       original_dnums.lhs_batch_dimensions().begin(),
+                       original_dnums.lhs_batch_dimensions().end());
+  lhs_transpose.insert(lhs_transpose.end(), lhs_non_contracting_dims.begin(),
+                       lhs_non_contracting_dims.end());
+  lhs_transpose.insert(lhs_transpose.end(),
+                       original_dnums.lhs_contracting_dimensions().begin(),
+                       original_dnums.lhs_contracting_dimensions().end());
+  HloInstruction* transposed_lhs =
+      computation->AddInstruction(HloInstruction::CreateTranspose(
+          ShapeUtil::PermuteDimensions(InversePermutation(lhs_transpose),
+                                       lhs_shape),
+          original_dot->mutable_operand(0), lhs_transpose));
+  std::vector<int64> lhs_reshape_dims = batch_dim_sizes;
+  lhs_reshape_dims.push_back(lhs_non_contracting_size);
+  lhs_reshape_dims.push_back(lhs_contracting_size);
+  // Reshape the contracting and non-contracting dimensions together.
+  HloInstruction* reshaped_lhs =
+      computation->AddInstruction(HloInstruction::CreateReshape(
+          ShapeUtil::MakeShape(lhs_shape.element_type(), lhs_reshape_dims),
+          transposed_lhs));
+
+  const auto& rhs_shape = original_dot->operand(1)->shape();
+  const int64 rhs_rank = rhs_shape.rank();
+  const int64 num_rhs_non_contracting_dims =
+      rhs_rank - num_batch_dims - num_contracting_dims;
+  std::vector<int64> rhs_non_contracting_dims;
+  rhs_non_contracting_dims.reserve(num_rhs_non_contracting_dims);
+  int64 rhs_non_contracting_size = 1;
+  int64 rhs_contracting_size = 1;
+  for (int64 i = 0; i < rhs_rank; ++i) {
+    if (absl::c_linear_search(original_dnums.rhs_contracting_dimensions(), i)) {
+      rhs_contracting_size *= rhs_shape.dimensions(i);
+    } else if (!absl::c_linear_search(original_dnums.rhs_batch_dimensions(),
+                                      i)) {
+      rhs_non_contracting_dims.push_back(i);
+      rhs_non_contracting_size *= rhs_shape.dimensions(i);
+    }
+  }
+
+  // The canonical form of the rhs is
+  // [BatchDims, ContractingDims, NonContractingDims]
+  std::vector<int64> rhs_transpose;
+  rhs_transpose.reserve(rhs_rank);
+  rhs_transpose.insert(rhs_transpose.end(),
+                       original_dnums.rhs_batch_dimensions().begin(),
+                       original_dnums.rhs_batch_dimensions().end());
+  rhs_transpose.insert(rhs_transpose.end(),
+                       original_dnums.rhs_contracting_dimensions().begin(),
+                       original_dnums.rhs_contracting_dimensions().end());
+  rhs_transpose.insert(rhs_transpose.end(), rhs_non_contracting_dims.begin(),
+                       rhs_non_contracting_dims.end());
+  HloInstruction* transposed_rhs =
+      computation->AddInstruction(HloInstruction::CreateTranspose(
+          ShapeUtil::PermuteDimensions(InversePermutation(rhs_transpose),
+                                       rhs_shape),
+          original_dot->mutable_operand(1), rhs_transpose));
+
+  std::vector<int64> rhs_reshape_dims = batch_dim_sizes;
+  rhs_reshape_dims.push_back(rhs_contracting_size);
+  rhs_reshape_dims.push_back(rhs_non_contracting_size);
+  // Reshape the contracting and non-contracting dimensions together.
+  HloInstruction* reshaped_rhs =
+      computation->AddInstruction(HloInstruction::CreateReshape(
+          ShapeUtil::MakeShape(rhs_shape.element_type(), rhs_reshape_dims),
+          transposed_rhs));
+
+  std::vector<int64> dot_dims = batch_dim_sizes;
+  dot_dims.push_back(lhs_non_contracting_size);
+  dot_dims.push_back(rhs_non_contracting_size);
+
+  DotDimensionNumbers dot_dnums;
+  for (int64 i = 0; i < num_batch_dims; ++i) {
+    dot_dnums.add_lhs_batch_dimensions(i);
+    dot_dnums.add_rhs_batch_dimensions(i);
+  }
+  dot_dnums.add_lhs_contracting_dimensions(num_batch_dims + 1);
+  dot_dnums.add_rhs_contracting_dimensions(num_batch_dims);
+
+  HloInstruction* dot = computation->AddInstruction(HloInstruction::CreateDot(
+      ShapeUtil::MakeShape(original_dot->shape().element_type(), dot_dims),
+      reshaped_lhs, reshaped_rhs, dot_dnums, original_dot->precision_config()));
+
+  return computation->ReplaceInstruction(
+      original_dot, computation->AddInstruction(HloInstruction::CreateReshape(
+                        original_dot->shape(), dot)));
+}
+
 }  // namespace
 
 StatusOr<bool> DotDecomposer::Run(HloModule* module) {
   XLA_VLOG_LINES(2, "DotDecomposer ENTRY\n" + module->ToString());
-  // Gather all batch Dot operations.
-  std::vector<HloInstruction*> batch_dots;
+  // Gather all non-canonical Dot operations.
+  std::vector<HloInstruction*> non_canonical_dots;
   for (auto* computation : module->MakeNonfusionComputations()) {
     for (auto* instruction : computation->instructions()) {
       if (instruction->opcode() != HloOpcode::kDot) {
         continue;
       }
       const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers();
-      if (dnums.lhs_batch_dimensions_size() > 0 && decompose_batch_dot_) {
-        batch_dots.push_back(instruction);
+      // A dot is not canonical if it has more than one contracting
+      // dimension.
+      if (dnums.lhs_contracting_dimensions_size() > 1) {
+        non_canonical_dots.push_back(instruction);
+        continue;
+      }
+      if (dnums.lhs_batch_dimensions().empty()) {
+        continue;
+      }
+      std::vector<int64> canonical_batch_dims(
+          dnums.lhs_batch_dimensions_size());
+      absl::c_iota(canonical_batch_dims, 0);
+      if (!absl::c_equal(dnums.lhs_batch_dimensions(), canonical_batch_dims) ||
+          !absl::c_equal(dnums.rhs_batch_dimensions(), canonical_batch_dims)) {
+        non_canonical_dots.push_back(instruction);
       }
     }
   }
-  // Decompose each batch Dot in 'batch_dots'.
   bool changed = false;
-  for (auto* dot : batch_dots) {
-    TF_RETURN_IF_ERROR(DecomposeBatchDot(dot));
+  for (auto* dot : non_canonical_dots) {
+    TF_RETURN_IF_ERROR(CanonicalizeDot(dot));
     changed = true;
   }
+
+  if (decompose_batch_dot_) {
+    std::vector<HloInstruction*> batch_dots;
+    for (auto* computation : module->MakeNonfusionComputations()) {
+      for (auto* instruction : computation->instructions()) {
+        if (instruction->opcode() != HloOpcode::kDot) {
+          continue;
+        }
+        const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers();
+        if (!dnums.lhs_batch_dimensions().empty()) {
+          batch_dots.push_back(instruction);
+        }
+      }
+    }
+    // Decompose each batch Dot in 'batch_dots'.
+
+    for (auto* dot : batch_dots) {
+      TF_RETURN_IF_ERROR(DecomposeBatchDot(dot));
+      changed = true;
+    }
+  }
   XLA_VLOG_LINES(2, "DotDecompose EXIT\n" + module->ToString());
   return changed;
 }
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 5e81281bae5..30a8bda606d 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -699,6 +699,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:call_inliner",
         "//tensorflow/compiler/xla/service:conditional_simplifier",
         "//tensorflow/compiler/xla/service:convolution_group_converter",
+        "//tensorflow/compiler/xla/service:dot_decomposer",
         "//tensorflow/compiler/xla/service:dynamic_index_splitter",
        "//tensorflow/compiler/xla/service:executable",
         "//tensorflow/compiler/xla/service:flatten_call_graph",
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 9be76095083..d1522280baf 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -37,6 +37,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/call_inliner.h"
 #include "tensorflow/compiler/xla/service/conditional_simplifier.h"
 #include "tensorflow/compiler/xla/service/convolution_group_converter.h"
+#include "tensorflow/compiler/xla/service/dot_decomposer.h"
 #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
@@ -165,6 +166,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
       // We need a cost model for GPUs. Currently, do nothing.
       return false;
     };
+    pipeline.AddPass<DotDecomposer>(false);
     pipeline.AddPass<ConvolutionGroupConverter>(
         cost_model,
         /*convert_batch_groups_only=*/true);
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 106b3fd6c9e..0fd0fc108a6 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -711,6 +711,7 @@ xla_test(
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/compiler/xla/tests:client_library_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/compiler/xla/tests:test_utils",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -768,6 +769,7 @@ xla_test(
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/compiler/xla/tests:client_library_test_base",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/compiler/xla/tests:test_utils",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index c5d8b663f4a..2e02968ac5f 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/reference_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/compiler/xla/tests/test_macros.h"
 #include "tensorflow/compiler/xla/tests/test_utils.h"
@@ -1147,5 +1148,38 @@ XLA_TEST_F(DotOperationTest, DotRank2AndRank2NonDefaultContractionDims) {
   ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
 }
 
+
+class DotOperationTextTest : public HloTestBase {};
+
+XLA_TEST_F(DotOperationTextTest, DotReorderedDotDims) {
+  absl::string_view hlo_string =
+      R"(
+HloModule ComplexDotMultipleNonContracting
+
+ENTRY %test {
+  %lhs = f32[7,17,10,13]{3,2,1,0} parameter(0)
+  %rhs = f32[7,9,10,13,6]{4,3,2,1,0} parameter(1)
+  ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={2,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={3}, rhs_contracting_dims={3}
+}
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3}));
+}
+
+XLA_TEST_F(DotOperationTextTest, DotReorderedDotDimsAndMultipleContracting) {
+  absl::string_view hlo_string =
+      R"(
+HloModule ComplexDotMultipleNonContracting
+
+ENTRY %test {
+  %lhs = f32[7,5,17,10,13]{4,3,2,1,0} parameter(0)
+  %rhs = f32[7,9,10,13,6,5]{5,4,3,2,1,0} parameter(1)
+  ROOT %dot = f32[10,7,17,9,6]{4,3,2,1,0} dot(%lhs, %rhs), lhs_batch_dims={3,0}, rhs_batch_dims={2,0}, lhs_contracting_dims={1,4}, rhs_contracting_dims={5,3}
+}
+)";
+
+  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{1e-3, 1e-3}));
+}
+
 }  // namespace
 }  // namespace xla
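
As a sketch of the canonical form produced by CanonicalizeDot above (hypothetical shapes and illustrative instruction names, not taken from the patch), a dot with two contracting dimensions is rewritten roughly as follows.

Before canonicalization (non-canonical, since there is more than one contracting dimension):

  %lhs = f32[2,3,4,5,6] parameter(0)
  %rhs = f32[2,3,7,5] parameter(1)
  %dot = f32[2,4,6,7] dot(%lhs, %rhs), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={3,1}, rhs_contracting_dims={3,1}

After canonicalization (roughly):

  // lhs is transposed to [BatchDims, NonContractingDims, ContractingDims],
  // then each group is flattened into a single dimension.
  %t_lhs = f32[2,4,6,5,3] transpose(%lhs), dimensions={0,2,4,3,1}
  %r_lhs = f32[2,24,15] reshape(%t_lhs)
  // rhs is transposed to [BatchDims, ContractingDims, NonContractingDims].
  %t_rhs = f32[2,5,3,7] transpose(%rhs), dimensions={0,3,1,2}
  %r_rhs = f32[2,15,7] reshape(%t_rhs)
  // The canonical dot contracts one flattened dimension per side, and the
  // result is reshaped back to the original dot shape.
  %c_dot = f32[2,24,7] dot(%r_lhs, %r_rhs), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
  %out = f32[2,4,6,7] reshape(%c_dot)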