[XLA GPU] Perform layout assignment before GEMM rewriting

GEMM rewriting is sensitive to layout assignment: the bias addition can be "fused"
into the cuBLAS GEMM call only if the bias has the same layout as the GEMM output.
Hence, layout assignment has to be performed first.
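
For illustration, a minimal HLO sketch of the pattern involved, adapted from the
BiasDifferentLayoutNoRewrite test deleted below (the module name is illustrative,
and the bias layout is changed to match the dot result):

  HloModule GemmWithBiasSketch

  ENTRY AddDotsFunc {
    x = f32[2,2]{1,0} parameter(0)
    y = f32[2,2]{1,0} parameter(1)
    bias = f32[2,2]{1,0} parameter(2)
    dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    ROOT out = f32[2,2] add(dot, bias)
  }

When bias ends up with the same layout as the dot result (here {1,0}), the rewriter
can fold the add into the "__cublas$gemm" custom call as a third operand with
beta=1; if bias ends up with layout {0,1}, the add has to stay separate. That
distinction only exists once layouts are known, which is why GemmRewriter now runs
after layout assignment.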

PiperOrigin-RevId: 255295897
George Karpenkov 2019-06-26 17:11:54 -07:00 committed by TensorFlower Gardener
parent bf7fdcdb47
commit 5b469d2115
6 changed files with 14 additions and 50 deletions


@@ -180,14 +180,6 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
   return Status::OK();
 }
-static DotDimensionNumbers GetGemmDotDimensionNumbers(
-    const HloInstruction* instr) {
-  CHECK(IsCublasGemm(*instr));
-  return instr->backend_config<GemmBackendConfig>()
-      .ConsumeValueOrDie()
-      .dot_dimension_numbers();
-}
 Status GpuLayoutAssignment::AddBackendConstraints(
     LayoutConstraints* constraints) {
   // Add convolution constraints in reverse postorder that the earliest
@@ -202,14 +194,16 @@ Status GpuLayoutAssignment::AddBackendConstraints(
           Cast<HloCustomCallInstruction>(instruction), constraints));
     }
+    CHECK(!IsCublasGemm(*instruction))
+        << "Gemm rewriting should run after layout assignment";
     // For batched dot we require the default layout.
     // TODO(b/112111608): This is overly conservative, the only real restriction
     // is that batch dimensions must be major.
-    if (IsCublasGemm(*instruction) &&
-        GetGemmDotDimensionNumbers(instruction).lhs_batch_dimensions_size() >
-            0) {
+    if (IsMatrixMultiplication(*instruction) &&
+        instruction->dot_dimension_numbers().lhs_batch_dimensions_size() > 0) {
       // Verify that the batch dims come before the row and col dims.
-      DotDimensionNumbers dim_nums = GetGemmDotDimensionNumbers(instruction);
+      DotDimensionNumbers dim_nums = instruction->dot_dimension_numbers();
       CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
                dim_nums.rhs_batch_dimensions_size());
       CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2,


@@ -351,9 +351,6 @@ TEST_F(LayoutAssignmentTest, DotLayout) {
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                           ParseHloString(hlo_text));
-  GemmRewriter gemm_rewriter_pass;
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, gemm_rewriter_pass.Run(module.get()));
-  EXPECT_TRUE(changed);
   ComputationLayout computation_layout(
       module->entry_computation()->ComputeProgramShape(),
@@ -366,8 +363,8 @@ TEST_F(LayoutAssignmentTest, DotLayout) {
   Shape expected_shape =
       ShapeUtil::MakeShapeWithLayout(F32, {8, 8, 256, 64}, {3, 2, 1, 0});
   EXPECT_THAT(module->entry_computation()->root_instruction(),
-              op::CustomCall(op::ShapeWithLayout(expected_shape),
-                             op::ShapeWithLayout(expected_shape)));
+              op::Dot(op::ShapeWithLayout(expected_shape),
+                      op::ShapeWithLayout(expected_shape)));
 }
 TEST_F(LayoutAssignmentTest, SortLayout) {


@@ -1788,6 +1788,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
   // first.
   if (gemm_config.beta() != 0.0) {
     const HloInstruction* bias = inst->operand(2);
+    CHECK_EQ(bias->shape(), inst->shape());
     if (GetAllocationSlice(*bias) != GetAllocationSlice(*inst)) {
       std::vector<std::unique_ptr<Thunk>> thunks;
       thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(


@@ -265,12 +265,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
   }
-  {
-    HloPassPipeline pipeline("gemm_canonicalization");
-    pipeline.AddPass<GemmRewriter>();
-    TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
-  }
   {
     // Convert convolutions into CustomCalls to cudnn, then canonicalize them
     // (CudnnConvPaddingLegalization). Also expand cuSolver calls.
@@ -323,6 +317,9 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
     options.set_is_layout_sensitive(true);
     pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
+    // Rewrite GEMMs into custom calls.
+    pipeline.AddPass<GemmRewriter>();
     // Choose the fastest algorithm for each conv.
     //
     // We pick the algorithm before fusion so we can generate better HLO. After


@@ -307,31 +307,6 @@ ENTRY AddDotsFunc {
 )");
 }
-TEST_F(GemmRewriteTest, BiasDifferentLayoutNoRewrite) {
-  const char* hlo_text = R"(
-HloModule BiasDifferentLayoutNoRewrite
-ENTRY AddDotsFunc {
-  x = f32[2,2]{1,0} parameter(0)
-  y = f32[2,2]{1,0} parameter(1)
-  bias = f32[2,2]{0,1} parameter(2)
-  dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT out = f32[2,2] add(dot, bias)
-}
-)";
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] {
-; CHECK-NEXT:    %x = f32[2,2]{1,0} parameter(0)
-; CHECK-NEXT:    %y = f32[2,2]{1,0} parameter(1)
-; CHECK-NEXT:    %custom-call = f32[2,2]{1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{selected_algorithm:{{[0-9]+}},alpha_real:1,dot_dimension_numbers:{lhs_contracting_dimensions:[1],rhs_contracting_dimensions:[0],lhs_batch_dimensions:[],rhs_batch_dimensions:[]},batch_size:1}"
-)");
-}
 TEST_F(GemmRewriteTest, SharedBufferAssignment) {
   const char* hlo_text = R"(
 HloModule SharedBufferAssignment


@@ -1,6 +1,6 @@
 # Tests of TensorFlow kernels written using the Python API.
-load("//tensorflow:tensorflow.bzl", "tf_py_test", "sycl_py_test", "tf_custom_op_library")
+load("//tensorflow:tensorflow.bzl", "sycl_py_test", "tf_custom_op_library", "tf_py_test")
 load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 package(
@@ -3525,7 +3525,7 @@ cuda_py_test(
     ],
     # TODO(b/127344411): This test passes because XLA does not actually cluster
     # the svd op.
-    xla_enable_strict_auto_jit = False,
+    xla_enable_strict_auto_jit = True,
 )
 cuda_py_test(