[XLA:GPU] layout_assignment: Force the default layout for FFT
cuFFT wants a dim0-major layout; otherwise XLA:GPU hits a check failure later.

PiperOrigin-RevId: 240186519
parent 9bb938c6f9
commit f9cda530e4
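For context: a "dim0 major" layout is XLA's default descending minor_to_major order, in which dimension 0 varies slowest (row-major for rank 2). Below is a minimal sketch, not part of the commit, of what LayoutUtil::SetToDefaultLayout (the helper the patch uses) does to a shape's layout; the function name Dim0MajorExample and the CHECK are illustrative only, and the includes assume the TF-1.x source tree paths.

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

// Illustrative helper (not from the patch).
void Dim0MajorExample() {
  // c64[8,32] with an explicit {0,1} layout: dim0 is minor (column-major).
  // This is the layout the new test below gives its parameter.
  xla::Shape shape =
      xla::ShapeUtil::MakeShapeWithLayout(xla::C64, {8, 32}, {0, 1});
  // SetToDefaultLayout rewrites minor_to_major to the descending default
  // {1,0}, making dim0 the most-major dimension -- the layout cuFFT expects
  // for both the FFT operand and its result.
  xla::LayoutUtil::SetToDefaultLayout(&shape);
  CHECK(xla::LayoutUtil::Equal(shape.layout(),
                               xla::LayoutUtil::MakeLayout({1, 0})));
}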
@@ -214,6 +214,16 @@ Status GpuLayoutAssignment::AddBackendConstraints(
         constraints->SetOperandLayout(op1_shape, instruction, 1));
     TF_RETURN_IF_ERROR(
         constraints->SetInstructionLayout(output_shape, instruction));
+  } else if (instruction->opcode() == HloOpcode::kFft) {
+    // cuFFT requires a dim0 major layout.
+    Shape op0_shape = instruction->operand(0)->shape();
+    LayoutUtil::SetToDefaultLayout(&op0_shape);
+    Shape output_shape = instruction->shape();
+    LayoutUtil::SetToDefaultLayout(&output_shape);
+    TF_RETURN_IF_ERROR(
+        constraints->SetOperandLayout(op0_shape, instruction, 0));
+    TF_RETURN_IF_ERROR(
+        constraints->SetInstructionLayout(output_shape, instruction));
   } else if (instruction->opcode() == HloOpcode::kSort &&
              instruction->operand(0)->shape().rank() > 1) {
     // Make sure that all the operands and the output(s) have the same layout.
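With this branch in place, layout assignment pins both the FFT operand and its result to the default layout, and the pass inserts copies wherever a neighboring instruction wants something else; the new test below checks exactly that pattern.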
@@ -402,6 +402,35 @@ TEST_F(LayoutAssignmentTest, SortLayout) {
                           op::ShapeWithLayout(expected_shape)));
 }
 
+TEST_F(LayoutAssignmentTest, FftLayout) {
+  const char* hlo_text = R"(
+  HloModule Fft_module
+
+  ENTRY Fft {
+    input = c64[8,32]{0,1} parameter(0)
+    fft = c64[8,32] fft(input), fft_type=FFT, fft_length={32}
+    ROOT transpose = c64[32,8] transpose(fft), dimensions={1,0}
+  })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(hlo_text));
+
+  ComputationLayout computation_layout(
+      module->entry_computation()->ComputeProgramShape(),
+      /*ignore_layouts=*/false);
+  GpuLayoutAssignment layout_assignment(
+      &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+      backend().default_stream_executor());
+  EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
+
+  Shape expected_shape = ShapeUtil::MakeShapeWithLayout(C64, {8, 32}, {1, 0});
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              op::Copy(op::Transpose(op::ShapeWithLayout(expected_shape))));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      op::Copy(op::Transpose(op::Fft(op::ShapeWithLayout(expected_shape)))));
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
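Note what the matchers assert: the fft keeps the forced c64[8,32]{1,0} layout, and the user-visible transposed result is produced through an explicit copy after the transpose, rather than by folding the transpose into the fft's output layout.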
@@ -201,6 +201,7 @@ HLO_MATCHER(Domain);
 HLO_MATCHER(DynamicSlice);
 HLO_MATCHER(DynamicUpdateSlice);
 HLO_MATCHER(Exp);
+HLO_MATCHER(Fft);
 HLO_MATCHER(Floor);
 HLO_MATCHER(Fusion);
 HLO_MATCHER(AfterAll);
@@ -150,7 +150,7 @@ cuda_py_test(
         "noasan",  # times out, b/63678675
         "optonly",  # times out, b/79171797
     ],
-    xla_enable_strict_auto_jit = False,
+    xla_enable_strict_auto_jit = True,
 )
 
 cuda_py_test(
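The BUILD change flips xla_enable_strict_auto_jit from False to True, presumably re-enabling strict auto-jit for the affected cuda_py_test (the target name is not shown in this excerpt) now that the FFT layout check failure is fixed.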