[XLA:GPU] layout_assignment: Force the default layout for FFT
cuFFT wants a dim0 major layout, otherwise XLA:GPU hits a check fail later. PiperOrigin-RevId: 240186519
This commit is contained in:
parent
9bb938c6f9
commit
f9cda530e4
@ -214,6 +214,16 @@ Status GpuLayoutAssignment::AddBackendConstraints(
|
|||||||
constraints->SetOperandLayout(op1_shape, instruction, 1));
|
constraints->SetOperandLayout(op1_shape, instruction, 1));
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
constraints->SetInstructionLayout(output_shape, instruction));
|
constraints->SetInstructionLayout(output_shape, instruction));
|
||||||
|
} else if (instruction->opcode() == HloOpcode::kFft) {
|
||||||
|
// cuFFT requires a dim0 major layout.
|
||||||
|
Shape op0_shape = instruction->operand(0)->shape();
|
||||||
|
LayoutUtil::SetToDefaultLayout(&op0_shape);
|
||||||
|
Shape output_shape = instruction->shape();
|
||||||
|
LayoutUtil::SetToDefaultLayout(&output_shape);
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
constraints->SetOperandLayout(op0_shape, instruction, 0));
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
constraints->SetInstructionLayout(output_shape, instruction));
|
||||||
} else if (instruction->opcode() == HloOpcode::kSort &&
|
} else if (instruction->opcode() == HloOpcode::kSort &&
|
||||||
instruction->operand(0)->shape().rank() > 1) {
|
instruction->operand(0)->shape().rank() > 1) {
|
||||||
// Make sure that all the operands and the output(s) have the same layout.
|
// Make sure that all the operands and the output(s) have the same layout.
|
||||||
|
@ -402,6 +402,35 @@ TEST_F(LayoutAssignmentTest, SortLayout) {
|
|||||||
op::ShapeWithLayout(expected_shape)));
|
op::ShapeWithLayout(expected_shape)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(LayoutAssignmentTest, FftLayout) {
|
||||||
|
const char* hlo_text = R"(
|
||||||
|
HloModule Fft_module
|
||||||
|
|
||||||
|
ENTRY Fft {
|
||||||
|
input = c64[8,32]{0,1} parameter(0)
|
||||||
|
fft = c64[8,32] fft(input), fft_type=FFT, fft_length={32}
|
||||||
|
ROOT transpose = c64[32,8] transpose(fft), dimensions={1,0}
|
||||||
|
})";
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||||
|
ParseHloString(hlo_text));
|
||||||
|
|
||||||
|
ComputationLayout computation_layout(
|
||||||
|
module->entry_computation()->ComputeProgramShape(),
|
||||||
|
/*ignore_layouts=*/false);
|
||||||
|
GpuLayoutAssignment layout_assignment(
|
||||||
|
&computation_layout, LayoutAssignment::InstructionCanChangeLayout,
|
||||||
|
backend().default_stream_executor());
|
||||||
|
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
|
||||||
|
|
||||||
|
Shape expected_shape = ShapeUtil::MakeShapeWithLayout(C64, {8, 32}, {1, 0});
|
||||||
|
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||||
|
op::Copy(op::Transpose(op::ShapeWithLayout(expected_shape))));
|
||||||
|
EXPECT_THAT(
|
||||||
|
module->entry_computation()->root_instruction(),
|
||||||
|
op::Copy(op::Transpose(op::Fft(op::ShapeWithLayout(expected_shape)))));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -201,6 +201,7 @@ HLO_MATCHER(Domain);
|
|||||||
HLO_MATCHER(DynamicSlice);
|
HLO_MATCHER(DynamicSlice);
|
||||||
HLO_MATCHER(DynamicUpdateSlice);
|
HLO_MATCHER(DynamicUpdateSlice);
|
||||||
HLO_MATCHER(Exp);
|
HLO_MATCHER(Exp);
|
||||||
|
HLO_MATCHER(Fft);
|
||||||
HLO_MATCHER(Floor);
|
HLO_MATCHER(Floor);
|
||||||
HLO_MATCHER(Fusion);
|
HLO_MATCHER(Fusion);
|
||||||
HLO_MATCHER(AfterAll);
|
HLO_MATCHER(AfterAll);
|
||||||
|
@ -150,7 +150,7 @@ cuda_py_test(
|
|||||||
"noasan", # times out, b/63678675
|
"noasan", # times out, b/63678675
|
||||||
"optonly", # times out, b/79171797
|
"optonly", # times out, b/79171797
|
||||||
],
|
],
|
||||||
xla_enable_strict_auto_jit = False,
|
xla_enable_strict_auto_jit = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
|
Loading…
Reference in New Issue
Block a user