[XLA GPU] Perform layout assignment before GEMM rewriting
GEMM rewriting is sensitive to layout assignment, as we can "fuse" the addition only if it has the same layout. Hence, layout assignment has to be performed first. PiperOrigin-RevId: 255295897
This commit is contained in:
parent
bf7fdcdb47
commit
5b469d2115
@ -180,14 +180,6 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static DotDimensionNumbers GetGemmDotDimensionNumbers(
|
||||
const HloInstruction* instr) {
|
||||
CHECK(IsCublasGemm(*instr));
|
||||
return instr->backend_config<GemmBackendConfig>()
|
||||
.ConsumeValueOrDie()
|
||||
.dot_dimension_numbers();
|
||||
}
|
||||
|
||||
Status GpuLayoutAssignment::AddBackendConstraints(
|
||||
LayoutConstraints* constraints) {
|
||||
// Add convolution constraints in reverse postorder that the earliest
|
||||
@ -202,14 +194,16 @@ Status GpuLayoutAssignment::AddBackendConstraints(
|
||||
Cast<HloCustomCallInstruction>(instruction), constraints));
|
||||
}
|
||||
|
||||
CHECK(!IsCublasGemm(*instruction))
|
||||
<< "Gemm rewriting should run after layout assignment";
|
||||
|
||||
// For batched dot we require the default layout.
|
||||
// TODO(b/112111608): This is overly conservative, the only real restriction
|
||||
// is that batch dimensions must be major.
|
||||
if (IsCublasGemm(*instruction) &&
|
||||
GetGemmDotDimensionNumbers(instruction).lhs_batch_dimensions_size() >
|
||||
0) {
|
||||
if (IsMatrixMultiplication(*instruction) &&
|
||||
instruction->dot_dimension_numbers().lhs_batch_dimensions_size() > 0) {
|
||||
// Verify that the batch dims come before the row and col dims.
|
||||
DotDimensionNumbers dim_nums = GetGemmDotDimensionNumbers(instruction);
|
||||
DotDimensionNumbers dim_nums = instruction->dot_dimension_numbers();
|
||||
CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
|
||||
dim_nums.rhs_batch_dimensions_size());
|
||||
CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2,
|
||||
|
@ -351,9 +351,6 @@ TEST_F(LayoutAssignmentTest, DotLayout) {
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseHloString(hlo_text));
|
||||
GemmRewriter gemm_rewriter_pass;
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed, gemm_rewriter_pass.Run(module.get()));
|
||||
EXPECT_TRUE(changed);
|
||||
|
||||
ComputationLayout computation_layout(
|
||||
module->entry_computation()->ComputeProgramShape(),
|
||||
@ -366,7 +363,7 @@ TEST_F(LayoutAssignmentTest, DotLayout) {
|
||||
Shape expected_shape =
|
||||
ShapeUtil::MakeShapeWithLayout(F32, {8, 8, 256, 64}, {3, 2, 1, 0});
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
op::CustomCall(op::ShapeWithLayout(expected_shape),
|
||||
op::Dot(op::ShapeWithLayout(expected_shape),
|
||||
op::ShapeWithLayout(expected_shape)));
|
||||
}
|
||||
|
||||
|
@ -1788,6 +1788,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
|
||||
// first.
|
||||
if (gemm_config.beta() != 0.0) {
|
||||
const HloInstruction* bias = inst->operand(2);
|
||||
CHECK_EQ(bias->shape(), inst->shape());
|
||||
if (GetAllocationSlice(*bias) != GetAllocationSlice(*inst)) {
|
||||
std::vector<std::unique_ptr<Thunk>> thunks;
|
||||
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
|
||||
|
@ -265,12 +265,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
|
||||
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
|
||||
}
|
||||
|
||||
{
|
||||
HloPassPipeline pipeline("gemm_canonicalization");
|
||||
pipeline.AddPass<GemmRewriter>();
|
||||
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
|
||||
}
|
||||
|
||||
{
|
||||
// Convert convolutions into CustomCalls to cudnn, then canonicalize them
|
||||
// (CudnnConvPaddingLegalization). Also expand cuSolver calls.
|
||||
@ -323,6 +317,9 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
|
||||
options.set_is_layout_sensitive(true);
|
||||
pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
|
||||
|
||||
// Rewrite GEMMs into custom calls.
|
||||
pipeline.AddPass<GemmRewriter>();
|
||||
|
||||
// Choose the fastest algorithm for each conv.
|
||||
//
|
||||
// We pick the algorithm before fusion so we can generate better HLO. After
|
||||
|
@ -307,31 +307,6 @@ ENTRY AddDotsFunc {
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(GemmRewriteTest, BiasDifferentLayoutNoRewrite) {
|
||||
const char* hlo_text = R"(
|
||||
HloModule BiasDifferentLayoutNoRewrite
|
||||
|
||||
ENTRY AddDotsFunc {
|
||||
x = f32[2,2]{1,0} parameter(0)
|
||||
y = f32[2,2]{1,0} parameter(1)
|
||||
bias = f32[2,2]{0,1} parameter(2)
|
||||
dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
ROOT out = f32[2,2] add(dot, bias)
|
||||
}
|
||||
|
||||
)";
|
||||
|
||||
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
|
||||
MatchOptimizedHlo(hlo_text,
|
||||
R"(
|
||||
|
||||
; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] {
|
||||
; CHECK-NEXT: %x = f32[2,2]{1,0} parameter(0)
|
||||
; CHECK-NEXT: %y = f32[2,2]{1,0} parameter(1)
|
||||
; CHECK-NEXT: %custom-call = f32[2,2]{1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{selected_algorithm:{{[0-9]+}},alpha_real:1,dot_dimension_numbers:{lhs_contracting_dimensions:[1],rhs_contracting_dimensions:[0],lhs_batch_dimensions:[],rhs_batch_dimensions:[]},batch_size:1}"
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(GemmRewriteTest, SharedBufferAssignment) {
|
||||
const char* hlo_text = R"(
|
||||
HloModule SharedBufferAssignment
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Tests of TensorFlow kernels written using the Python API.
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_test", "sycl_py_test", "tf_custom_op_library")
|
||||
load("//tensorflow:tensorflow.bzl", "sycl_py_test", "tf_custom_op_library", "tf_py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
|
||||
package(
|
||||
@ -3525,7 +3525,7 @@ cuda_py_test(
|
||||
],
|
||||
# TODO(b/127344411): This test passes because XLA does not actually cluster
|
||||
# the svd op.
|
||||
xla_enable_strict_auto_jit = False,
|
||||
xla_enable_strict_auto_jit = True,
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
|
Loading…
Reference in New Issue
Block a user