[XLA:GPU] layout_assignment: Force the default layout for FFT
cuFFT wants a dim0-major layout; otherwise XLA:GPU hits a CHECK failure later.

PiperOrigin-RevId: 240186519
commit f9cda530e4
parent 9bb938c6f9

@@ -214,6 +214,16 @@ Status GpuLayoutAssignment::AddBackendConstraints(
           constraints->SetOperandLayout(op1_shape, instruction, 1));
       TF_RETURN_IF_ERROR(
           constraints->SetInstructionLayout(output_shape, instruction));
+    } else if (instruction->opcode() == HloOpcode::kFft) {
+      // cuFFT requires a dim0 major layout.
+      Shape op0_shape = instruction->operand(0)->shape();
+      LayoutUtil::SetToDefaultLayout(&op0_shape);
+      Shape output_shape = instruction->shape();
+      LayoutUtil::SetToDefaultLayout(&output_shape);
+      TF_RETURN_IF_ERROR(
+          constraints->SetOperandLayout(op0_shape, instruction, 0));
+      TF_RETURN_IF_ERROR(
+          constraints->SetInstructionLayout(output_shape, instruction));
     } else if (instruction->opcode() == HloOpcode::kSort &&
                instruction->operand(0)->shape().rank() > 1) {
       // Make sure that all the operands and the output(s) have the same layout.
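To make the "dim0 major" requirement concrete: the default layout that LayoutUtil::SetToDefaultLayout assigns lists dimensions minor-to-major as {rank-1, ..., 1, 0}, so dimension 0 becomes the most-major (slowest-varying) dimension. Below is a minimal sketch of that behavior; the helper name and include paths are assumptions for illustration and are not part of this change, while MakeShapeWithLayout and SetToDefaultLayout are the same helpers used in the diff and test.

// Illustration only (not part of this change): what SetToDefaultLayout does to
// a rank-2 FFT operand shape like the one in the FftLayout test below.
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

// Hypothetical helper name, for illustration only.
Shape DefaultFftOperandLayoutSketch() {
  // c64[8,32] with the non-default {0,1} (dim0 minor) layout.
  Shape shape = ShapeUtil::MakeShapeWithLayout(C64, {8, 32}, {0, 1});
  // Rewrites minor_to_major to {1, 0}, i.e. dim0 major: the layout the
  // constraints above now force on kFft operands and outputs.
  LayoutUtil::SetToDefaultLayout(&shape);
  return shape;
}

}  // namespace xla
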
@@ -402,6 +402,35 @@ TEST_F(LayoutAssignmentTest, SortLayout) {
                        op::ShapeWithLayout(expected_shape)));
 }
 
+TEST_F(LayoutAssignmentTest, FftLayout) {
+  const char* hlo_text = R"(
+  HloModule Fft_module
+
+  ENTRY Fft {
+    input = c64[8,32]{0,1} parameter(0)
+    fft = c64[8,32] fft(input), fft_type=FFT, fft_length={32}
+    ROOT transpose = c64[32,8] transpose(fft), dimensions={1,0}
+  })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(hlo_text));
+
+  ComputationLayout computation_layout(
+      module->entry_computation()->ComputeProgramShape(),
+      /*ignore_layouts=*/false);
+  GpuLayoutAssignment layout_assignment(
+      &computation_layout, LayoutAssignment::InstructionCanChangeLayout,
+      backend().default_stream_executor());
+  EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
+
+  Shape expected_shape = ShapeUtil::MakeShapeWithLayout(C64, {8, 32}, {1, 0});
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              op::Copy(op::Transpose(op::ShapeWithLayout(expected_shape))));
+  EXPECT_THAT(
+      module->entry_computation()->root_instruction(),
+      op::Copy(op::Transpose(op::Fft(op::ShapeWithLayout(expected_shape)))));
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
@@ -201,6 +201,7 @@ HLO_MATCHER(Domain);
 HLO_MATCHER(DynamicSlice);
 HLO_MATCHER(DynamicUpdateSlice);
 HLO_MATCHER(Exp);
+HLO_MATCHER(Fft);
 HLO_MATCHER(Floor);
 HLO_MATCHER(Fusion);
 HLO_MATCHER(AfterAll);
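The new HLO_MATCHER(Fft) line is what makes op::Fft available to the FftLayout test above. As a standalone illustration, here is a minimal sketch of composing it with other matchers; the test name, HLO text, and include paths are assumptions for illustration, while ParseHloString and the op:: matchers follow the usage shown in the test hunk above.

// Sketch only: op::Fft composed with other HLO matchers, mirroring the
// instruction tree built by the FftLayout test (transpose(fft(parameter))).
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/core/platform/test.h"

namespace op = xla::testing::opcode_matchers;

TEST(FftMatcherSketch, MatchesFftInstruction) {
  // Parse a tiny module whose root is transpose(fft(parameter)).
  auto module = xla::ParseHloString(R"(
    HloModule m
    ENTRY e {
      p = c64[8,32] parameter(0)
      f = c64[8,32] fft(p), fft_type=FFT, fft_length={32}
      ROOT t = c64[32,8] transpose(f), dimensions={1,0}
    })")
                    .ValueOrDie();
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Transpose(op::Fft(op::Parameter())));
}
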
@@ -150,7 +150,7 @@ cuda_py_test(
         "noasan",  # times out, b/63678675
         "optonly",  # times out, b/79171797
     ],
-    xla_enable_strict_auto_jit = False,
+    xla_enable_strict_auto_jit = True,
 )
 
 cuda_py_test(