diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
index ca1b2018d6c..2879acecbce 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
@@ -180,14 +180,6 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
   return Status::OK();
 }
 
-static DotDimensionNumbers GetGemmDotDimensionNumbers(
-    const HloInstruction* instr) {
-  CHECK(IsCublasGemm(*instr));
-  return instr->backend_config<GemmBackendConfig>()
-      .ConsumeValueOrDie()
-      .dot_dimension_numbers();
-}
-
 Status GpuLayoutAssignment::AddBackendConstraints(
     LayoutConstraints* constraints) {
   // Add convolution constraints in reverse postorder that the earliest
@@ -202,14 +194,16 @@ Status GpuLayoutAssignment::AddBackendConstraints(
           Cast<HloCustomCallInstruction>(instruction), constraints));
     }
 
+    CHECK(!IsCublasGemm(*instruction))
+        << "Gemm rewriting should run after layout assignment";
+
     // For batched dot we require the default layout.
     // TODO(b/112111608): This is overly conservative, the only real restriction
    // is that batch dimensions must be major.
-    if (IsCublasGemm(*instruction) &&
-        GetGemmDotDimensionNumbers(instruction).lhs_batch_dimensions_size() >
-            0) {
+    if (IsMatrixMultiplication(*instruction) &&
+        instruction->dot_dimension_numbers().lhs_batch_dimensions_size() > 0) {
       // Verify that the batch dims come before the row and col dims.
-      DotDimensionNumbers dim_nums = GetGemmDotDimensionNumbers(instruction);
+      DotDimensionNumbers dim_nums = instruction->dot_dimension_numbers();
       CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
                dim_nums.rhs_batch_dimensions_size());
       CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2,
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
index e08c652e93b..d0d1abf5946 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
@@ -351,9 +351,6 @@ TEST_F(LayoutAssignmentTest, DotLayout) {
 
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                           ParseHloString(hlo_text));
-  GemmRewriter gemm_rewriter_pass;
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, gemm_rewriter_pass.Run(module.get()));
-  EXPECT_TRUE(changed);
 
   ComputationLayout computation_layout(
       module->entry_computation()->ComputeProgramShape(),
@@ -366,8 +363,8 @@ TEST_F(LayoutAssignmentTest, DotLayout) {
   Shape expected_shape =
       ShapeUtil::MakeShapeWithLayout(F32, {8, 8, 256, 64}, {3, 2, 1, 0});
   EXPECT_THAT(module->entry_computation()->root_instruction(),
-              op::CustomCall(op::ShapeWithLayout(expected_shape),
-                             op::ShapeWithLayout(expected_shape)));
+              op::Dot(op::ShapeWithLayout(expected_shape),
+                      op::ShapeWithLayout(expected_shape)));
 }
 
 TEST_F(LayoutAssignmentTest, SortLayout) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index ef2bb4ecdfe..168156edf8e 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -1788,6 +1788,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
     // first.
     if (gemm_config.beta() != 0.0) {
       const HloInstruction* bias = inst->operand(2);
+      CHECK_EQ(bias->shape(), inst->shape());
       if (GetAllocationSlice(*bias) != GetAllocationSlice(*inst)) {
         std::vector<std::unique_ptr<Thunk>> thunks;
         thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index cb029b65935..06269bd4e48 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -265,12 +265,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
   }
 
-  {
-    HloPassPipeline pipeline("gemm_canonicalization");
-    pipeline.AddPass<GemmRewriter>();
-    TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
-  }
-
   {
     // Convert convolutions into CustomCalls to cudnn, then canonicalize them
     // (CudnnConvPaddingLegalization). Also expand cuSolver calls.
@@ -323,6 +317,9 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
     options.set_is_layout_sensitive(true);
     pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
 
+    // Rewrite GEMMs into custom calls.
+    pipeline.AddPass<GemmRewriter>();
+
     // Choose the fastest algorithm for each conv.
     //
     // We pick the algorithm before fusion so we can generate better HLO. After
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
index 5fbeb08e133..9e643bd4b99 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gemm_rewrite_test.cc
@@ -307,31 +307,6 @@ ENTRY AddDotsFunc {
 )");
 }
 
-TEST_F(GemmRewriteTest, BiasDifferentLayoutNoRewrite) {
-  const char* hlo_text = R"(
-HloModule BiasDifferentLayoutNoRewrite
-
-ENTRY AddDotsFunc {
-  x = f32[2,2]{1,0} parameter(0)
-  y = f32[2,2]{1,0} parameter(1)
-  bias = f32[2,2]{0,1} parameter(2)
-  dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT out = f32[2,2] add(dot, bias)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] {
-; CHECK-NEXT: %x = f32[2,2]{1,0} parameter(0)
-; CHECK-NEXT: %y = f32[2,2]{1,0} parameter(1)
-; CHECK-NEXT: %custom-call = f32[2,2]{1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{selected_algorithm:{{[0-9]+}},alpha_real:1,dot_dimension_numbers:{lhs_contracting_dimensions:[1],rhs_contracting_dimensions:[0],lhs_batch_dimensions:[],rhs_batch_dimensions:[]},batch_size:1}"
-  )");
-}
-
 TEST_F(GemmRewriteTest, SharedBufferAssignment) {
   const char* hlo_text = R"(
 HloModule SharedBufferAssignment
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index cc1fa819dac..b1bc231f111 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1,6 +1,6 @@
 # Tests of TensorFlow kernels written using the Python API.
 
-load("//tensorflow:tensorflow.bzl", "tf_py_test", "sycl_py_test", "tf_custom_op_library")
+load("//tensorflow:tensorflow.bzl", "sycl_py_test", "tf_custom_op_library", "tf_py_test")
 load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 
 package(
@@ -3525,7 +3525,7 @@ cuda_py_test(
     ],
     # TODO(b/127344411): This test passes because XLA does not actually cluster
     # the svd op.
-    xla_enable_strict_auto_jit = False,
+    xla_enable_strict_auto_jit = True,
 )
 
 cuda_py_test(