From 5b469d2115f9c987eda339b25321e463b816d39d Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Wed, 26 Jun 2019 17:11:54 -0700 Subject: [PATCH] [XLA GPU] Perform layout assignment before GEMM rewriting GEMM rewriting is sensitive to layout assignment, as we can "fuse" the addition only if it has the same layout. Hence, layout assignment has to be performed first. PiperOrigin-RevId: 255295897 --- .../xla/service/gpu/gpu_layout_assignment.cc | 18 +++++-------- .../service/gpu/gpu_layout_assignment_test.cc | 7 ++---- .../xla/service/gpu/ir_emitter_unnested.cc | 1 + .../xla/service/gpu/nvptx_compiler.cc | 9 +++---- .../service/gpu/tests/gemm_rewrite_test.cc | 25 ------------------- tensorflow/python/kernel_tests/BUILD | 4 +-- 6 files changed, 14 insertions(+), 50 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index ca1b2018d6c..2879acecbce 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -180,14 +180,6 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( return Status::OK(); } -static DotDimensionNumbers GetGemmDotDimensionNumbers( - const HloInstruction* instr) { - CHECK(IsCublasGemm(*instr)); - return instr->backend_config() - .ConsumeValueOrDie() - .dot_dimension_numbers(); -} - Status GpuLayoutAssignment::AddBackendConstraints( LayoutConstraints* constraints) { // Add convolution constraints in reverse postorder that the earliest @@ -202,14 +194,16 @@ Status GpuLayoutAssignment::AddBackendConstraints( Cast(instruction), constraints)); } + CHECK(!IsCublasGemm(*instruction)) + << "Gemm rewriting should run after layout assignment"; + // For batched dot we require the default layout. // TODO(b/112111608): This is overly conservative, the only real restriction // is that batch dimensions must be major. - if (IsCublasGemm(*instruction) && - GetGemmDotDimensionNumbers(instruction).lhs_batch_dimensions_size() > - 0) { + if (IsMatrixMultiplication(*instruction) && + instruction->dot_dimension_numbers().lhs_batch_dimensions_size() > 0) { // Verify that the batch dims come before the row and col dims. - DotDimensionNumbers dim_nums = GetGemmDotDimensionNumbers(instruction); + DotDimensionNumbers dim_nums = instruction->dot_dimension_numbers(); CHECK_EQ(dim_nums.lhs_batch_dimensions_size(), dim_nums.rhs_batch_dimensions_size()); CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index e08c652e93b..d0d1abf5946 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -351,9 +351,6 @@ TEST_F(LayoutAssignmentTest, DotLayout) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseHloString(hlo_text)); - GemmRewriter gemm_rewriter_pass; - TF_ASSERT_OK_AND_ASSIGN(bool changed, gemm_rewriter_pass.Run(module.get())); - EXPECT_TRUE(changed); ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape(), @@ -366,8 +363,8 @@ TEST_F(LayoutAssignmentTest, DotLayout) { Shape expected_shape = ShapeUtil::MakeShapeWithLayout(F32, {8, 8, 256, 64}, {3, 2, 1, 0}); EXPECT_THAT(module->entry_computation()->root_instruction(), - op::CustomCall(op::ShapeWithLayout(expected_shape), - op::ShapeWithLayout(expected_shape))); + op::Dot(op::ShapeWithLayout(expected_shape), + op::ShapeWithLayout(expected_shape))); } TEST_F(LayoutAssignmentTest, SortLayout) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index ef2bb4ecdfe..168156edf8e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -1788,6 +1788,7 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( // first. if (gemm_config.beta() != 0.0) { const HloInstruction* bias = inst->operand(2); + CHECK_EQ(bias->shape(), inst->shape()); if (GetAllocationSlice(*bias) != GetAllocationSlice(*inst)) { std::vector> thunks; thunks.push_back(absl::make_unique( diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index cb029b65935..06269bd4e48 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -265,12 +265,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } - { - HloPassPipeline pipeline("gemm_canonicalization"); - pipeline.AddPass(); - TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); - } - { // Convert convolutions into CustomCalls to cudnn, then canonicalize them // (CudnnConvPaddingLegalization). Also expand cuSolver calls. @@ -323,6 +317,9 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, options.set_is_layout_sensitive(true); pipeline.AddPass>(options); + // Rewrite GEMMs into custom calls. + pipeline.AddPass(); + // Choose the fastest algorithm for each conv. // // We pick the algorithm before fusion so we can generate better HLO. After diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc index 5fbeb08e133..9e643bd4b99 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc @@ -307,31 +307,6 @@ ENTRY AddDotsFunc { )"); } -TEST_F(GemmRewriteTest, BiasDifferentLayoutNoRewrite) { - const char* hlo_text = R"( -HloModule BiasDifferentLayoutNoRewrite - -ENTRY AddDotsFunc { - x = f32[2,2]{1,0} parameter(0) - y = f32[2,2]{1,0} parameter(1) - bias = f32[2,2]{0,1} parameter(2) - dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT out = f32[2,2] add(dot, bias) -} - -)"; - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); - MatchOptimizedHlo(hlo_text, - R"( - -; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] { -; CHECK-NEXT: %x = f32[2,2]{1,0} parameter(0) -; CHECK-NEXT: %y = f32[2,2]{1,0} parameter(1) -; CHECK-NEXT: %custom-call = f32[2,2]{1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{selected_algorithm:{{[0-9]+}},alpha_real:1,dot_dimension_numbers:{lhs_contracting_dimensions:[1],rhs_contracting_dimensions:[0],lhs_batch_dimensions:[],rhs_batch_dimensions:[]},batch_size:1}" - )"); -} - TEST_F(GemmRewriteTest, SharedBufferAssignment) { const char* hlo_text = R"( HloModule SharedBufferAssignment diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index cc1fa819dac..b1bc231f111 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -1,6 +1,6 @@ # Tests of TensorFlow kernels written using the Python API. -load("//tensorflow:tensorflow.bzl", "tf_py_test", "sycl_py_test", "tf_custom_op_library") +load("//tensorflow:tensorflow.bzl", "sycl_py_test", "tf_custom_op_library", "tf_py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") package( @@ -3525,7 +3525,7 @@ cuda_py_test( ], # TODO(b/127344411): This test passes because XLA does not actually cluster # the svd op. - xla_enable_strict_auto_jit = False, + xla_enable_strict_auto_jit = True, ) cuda_py_test(