[XLA:GPU] layout_assignment: Force the default layout for FFT

cuFFT wants a dim0-major layout; otherwise XLA:GPU hits a check failure later.

PiperOrigin-RevId: 240186519
This commit is contained in:
Benjamin Kramer 2019-03-25 11:46:10 -07:00 committed by TensorFlower Gardener
parent 9bb938c6f9
commit f9cda530e4
4 changed files with 41 additions and 1 deletions

View File

@ -214,6 +214,16 @@ Status GpuLayoutAssignment::AddBackendConstraints(
constraints->SetOperandLayout(op1_shape, instruction, 1));
TF_RETURN_IF_ERROR(
constraints->SetInstructionLayout(output_shape, instruction));
} else if (instruction->opcode() == HloOpcode::kFft) {
// cuFFT requires a dim0 major layout.
Shape op0_shape = instruction->operand(0)->shape();
LayoutUtil::SetToDefaultLayout(&op0_shape);
Shape output_shape = instruction->shape();
LayoutUtil::SetToDefaultLayout(&output_shape);
TF_RETURN_IF_ERROR(
constraints->SetOperandLayout(op0_shape, instruction, 0));
TF_RETURN_IF_ERROR(
constraints->SetInstructionLayout(output_shape, instruction));
} else if (instruction->opcode() == HloOpcode::kSort &&
instruction->operand(0)->shape().rank() > 1) {
// Make sure that all the operands and the output(s) have the same layout.

View File

@ -402,6 +402,35 @@ TEST_F(LayoutAssignmentTest, SortLayout) {
op::ShapeWithLayout(expected_shape)));
}
// cuFFT only accepts dim0-major ({1,0}) layouts, so the layout-assignment
// pass must force the default layout onto an fft's operand and result even
// when the surrounding graph (here: a {0,1}-laid-out parameter followed by a
// transpose) would prefer something else.
TEST_F(LayoutAssignmentTest, FftLayout) {
  const char* module_text = R"(
HloModule Fft_module
ENTRY Fft {
input = c64[8,32]{0,1} parameter(0)
fft = c64[8,32] fft(input), fft_type=FFT, fft_length={32}
ROOT transpose = c64[32,8] transpose(fft), dimensions={1,0}
})";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed_module,
                          ParseHloString(module_text));

  ComputationLayout entry_layout(
      parsed_module->entry_computation()->ComputeProgramShape(),
      /*ignore_layouts=*/false);

  GpuLayoutAssignment assignment(
      &entry_layout, LayoutAssignment::InstructionCanChangeLayout,
      backend().default_stream_executor());
  EXPECT_TRUE(assignment.Run(parsed_module.get()).ValueOrDie());

  // Both the fft instruction itself and its operand (the transpose's input)
  // must end up with the default dim0-major layout.
  const Shape dim0_major_shape =
      ShapeUtil::MakeShapeWithLayout(C64, {8, 32}, {1, 0});
  const auto* root = parsed_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              op::Copy(op::Transpose(op::ShapeWithLayout(dim0_major_shape))));
  EXPECT_THAT(root, op::Copy(op::Transpose(
                        op::Fft(op::ShapeWithLayout(dim0_major_shape)))));
}
} // namespace
} // namespace gpu
} // namespace xla

View File

@ -201,6 +201,7 @@ HLO_MATCHER(Domain);
HLO_MATCHER(DynamicSlice);
HLO_MATCHER(DynamicUpdateSlice);
HLO_MATCHER(Exp);
HLO_MATCHER(Fft);
HLO_MATCHER(Floor);
HLO_MATCHER(Fusion);
HLO_MATCHER(AfterAll);

View File

@ -150,7 +150,7 @@ cuda_py_test(
"noasan", # times out, b/63678675
"optonly", # times out, b/79171797
],
xla_enable_strict_auto_jit = False,
xla_enable_strict_auto_jit = True,
)
cuda_py_test(