From f9cda530e4a34197f5d7ed0aa19f2d2f6477806a Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Mon, 25 Mar 2019 11:46:10 -0700 Subject: [PATCH] [XLA:GPU] layout_assignment: Force the default layout for FFT cuFFT wants a dim0 major layout, otherwise XLA:GPU hits a check fail later. PiperOrigin-RevId: 240186519 --- .../xla/service/gpu/gpu_layout_assignment.cc | 10 +++++++ .../service/gpu/gpu_layout_assignment_test.cc | 29 +++++++++++++++++++ .../compiler/xla/service/hlo_matchers.h | 1 + tensorflow/python/kernel_tests/linalg/BUILD | 2 +- 4 files changed, 41 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index a6d80f0b6dd..09ea2652341 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -214,6 +214,16 @@ Status GpuLayoutAssignment::AddBackendConstraints( constraints->SetOperandLayout(op1_shape, instruction, 1)); TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(output_shape, instruction)); + } else if (instruction->opcode() == HloOpcode::kFft) { + // cuFFT requires a dim0 major layout. + Shape op0_shape = instruction->operand(0)->shape(); + LayoutUtil::SetToDefaultLayout(&op0_shape); + Shape output_shape = instruction->shape(); + LayoutUtil::SetToDefaultLayout(&output_shape); + TF_RETURN_IF_ERROR( + constraints->SetOperandLayout(op0_shape, instruction, 0)); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(output_shape, instruction)); } else if (instruction->opcode() == HloOpcode::kSort && instruction->operand(0)->shape().rank() > 1) { // Make sure that all the operands and the output(s) have the same layout. diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 3630c3e38c5..d9453aff69c 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -402,6 +402,35 @@ TEST_F(LayoutAssignmentTest, SortLayout) { op::ShapeWithLayout(expected_shape))); } +TEST_F(LayoutAssignmentTest, FftLayout) { + const char* hlo_text = R"( + HloModule Fft_module + + ENTRY Fft { + input = c64[8,32]{0,1} parameter(0) + fft = c64[8,32] fft(input), fft_type=FFT, fft_length={32} + ROOT transpose = c64[32,8] transpose(fft), dimensions={1,0} + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_text)); + + ComputationLayout computation_layout( + module->entry_computation()->ComputeProgramShape(), + /*ignore_layouts=*/false); + GpuLayoutAssignment layout_assignment( + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + backend().default_stream_executor()); + EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); + + Shape expected_shape = ShapeUtil::MakeShapeWithLayout(C64, {8, 32}, {1, 0}); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Copy(op::Transpose(op::ShapeWithLayout(expected_shape)))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Copy(op::Transpose(op::Fft(op::ShapeWithLayout(expected_shape))))); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 756f4d2c6bc..0adfef7d9e7 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -201,6 +201,7 @@ HLO_MATCHER(Domain); HLO_MATCHER(DynamicSlice); HLO_MATCHER(DynamicUpdateSlice); HLO_MATCHER(Exp); +HLO_MATCHER(Fft); HLO_MATCHER(Floor); HLO_MATCHER(Fusion); HLO_MATCHER(AfterAll); diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD index 53815858e4c..c1aa99cc375 100644 --- a/tensorflow/python/kernel_tests/linalg/BUILD +++ b/tensorflow/python/kernel_tests/linalg/BUILD @@ -150,7 +150,7 @@ cuda_py_test( "noasan", # times out, b/63678675 "optonly", # times out, b/79171797 ], - xla_enable_strict_auto_jit = False, + xla_enable_strict_auto_jit = True, ) cuda_py_test(