[XLA GPU] Perform layout assignment before GEMM rewriting

GEMM rewriting is sensitive to layout assignment: the bias addition can be
"fused" into the cuBLAS GEMM custom call only if the bias has the same layout
as the GEMM output.
Hence, layout assignment has to be performed first.

PiperOrigin-RevId: 255295897
George Karpenkov 2019-06-26 17:11:54 -07:00 committed by TensorFlower Gardener
parent bf7fdcdb47
commit 5b469d2115
6 changed files with 14 additions and 50 deletions
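For illustration, below is a minimal HLO sketch of the pattern the GemmRewriter targets, adapted from the BiasDifferentLayoutNoRewrite test removed in this change; the module name and the {1,0} bias layout are illustrative substitutions, not part of the original test. With the new pass order, layout assignment fixes these layouts before the rewriter runs, so the rewriter can tell whether the add is fusable into the __cublas$gemm custom call: with matching layouts (as below) it is a candidate, while with the {0,1} bias layout from the deleted test only the dot itself becomes the custom call and the add stays behind.

    HloModule GemmWithBiasSameLayout

    ENTRY AddDotsFunc {
      x = f32[2,2]{1,0} parameter(0)
      y = f32[2,2]{1,0} parameter(1)
      // Illustrative change: the deleted test used layout {0,1} here, which blocked fusion.
      bias = f32[2,2]{1,0} parameter(2)
      dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
      ROOT out = f32[2,2] add(dot, bias)
    }

When the fusion does happen, the bias becomes the third operand of the custom call, which is why the beta != 0.0 path in BuildGemmThunk (changed below) now CHECKs that the bias shape, including its layout, matches the GEMM output shape.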

View File

@@ -180,14 +180,6 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
   return Status::OK();
 }
 
-static DotDimensionNumbers GetGemmDotDimensionNumbers(
-    const HloInstruction* instr) {
-  CHECK(IsCublasGemm(*instr));
-  return instr->backend_config<GemmBackendConfig>()
-      .ConsumeValueOrDie()
-      .dot_dimension_numbers();
-}
-
 Status GpuLayoutAssignment::AddBackendConstraints(
     LayoutConstraints* constraints) {
   // Add convolution constraints in reverse postorder that the earliest
@@ -202,14 +194,16 @@ Status GpuLayoutAssignment::AddBackendConstraints(
           Cast<HloCustomCallInstruction>(instruction), constraints));
     }
 
+    CHECK(!IsCublasGemm(*instruction))
+        << "Gemm rewriting should run after layout assignment";
+
     // For batched dot we require the default layout.
     // TODO(b/112111608): This is overly conservative, the only real restriction
     // is that batch dimensions must be major.
-    if (IsCublasGemm(*instruction) &&
-        GetGemmDotDimensionNumbers(instruction).lhs_batch_dimensions_size() >
-            0) {
+    if (IsMatrixMultiplication(*instruction) &&
+        instruction->dot_dimension_numbers().lhs_batch_dimensions_size() > 0) {
       // Verify that the batch dims come before the row and col dims.
-      DotDimensionNumbers dim_nums = GetGemmDotDimensionNumbers(instruction);
+      DotDimensionNumbers dim_nums = instruction->dot_dimension_numbers();
       CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
                dim_nums.rhs_batch_dimensions_size());
       CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2,

View File

@@ -351,9 +351,6 @@ TEST_F(LayoutAssignmentTest, DotLayout) {
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                           ParseHloString(hlo_text));
 
-  GemmRewriter gemm_rewriter_pass;
-  TF_ASSERT_OK_AND_ASSIGN(bool changed, gemm_rewriter_pass.Run(module.get()));
-  EXPECT_TRUE(changed);
 
   ComputationLayout computation_layout(
       module->entry_computation()->ComputeProgramShape(),
@@ -366,8 +363,8 @@ TEST_F(LayoutAssignmentTest, DotLayout) {
   Shape expected_shape =
       ShapeUtil::MakeShapeWithLayout(F32, {8, 8, 256, 64}, {3, 2, 1, 0});
   EXPECT_THAT(module->entry_computation()->root_instruction(),
-              op::CustomCall(op::ShapeWithLayout(expected_shape),
-                             op::ShapeWithLayout(expected_shape)));
+              op::Dot(op::ShapeWithLayout(expected_shape),
+                      op::ShapeWithLayout(expected_shape)));
 }
 
 TEST_F(LayoutAssignmentTest, SortLayout) {

View File

@@ -1788,6 +1788,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
   // first.
   if (gemm_config.beta() != 0.0) {
     const HloInstruction* bias = inst->operand(2);
+    CHECK_EQ(bias->shape(), inst->shape());
    if (GetAllocationSlice(*bias) != GetAllocationSlice(*inst)) {
       std::vector<std::unique_ptr<Thunk>> thunks;
       thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(

View File

@@ -265,12 +265,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
     TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
   }
 
-  {
-    HloPassPipeline pipeline("gemm_canonicalization");
-    pipeline.AddPass<GemmRewriter>();
-    TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
-  }
-
   {
     // Convert convolutions into CustomCalls to cudnn, then canonicalize them
     // (CudnnConvPaddingLegalization). Also expand cuSolver calls.
@@ -323,6 +317,9 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
     options.set_is_layout_sensitive(true);
     pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
 
+    // Rewrite GEMMs into custom calls.
+    pipeline.AddPass<GemmRewriter>();
+
     // Choose the fastest algorithm for each conv.
     //
     // We pick the algorithm before fusion so we can generate better HLO. After

View File

@@ -307,31 +307,6 @@ ENTRY AddDotsFunc {
 )");
 }
 
-TEST_F(GemmRewriteTest, BiasDifferentLayoutNoRewrite) {
-  const char* hlo_text = R"(
-HloModule BiasDifferentLayoutNoRewrite
-
-ENTRY AddDotsFunc {
-  x = f32[2,2]{1,0} parameter(0)
-  y = f32[2,2]{1,0} parameter(1)
-  bias = f32[2,2]{0,1} parameter(2)
-  dot = f32[2,2] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-  ROOT out = f32[2,2] add(dot, bias)
-}
-
-)";
-
-  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
-  MatchOptimizedHlo(hlo_text,
-                    R"(
-
-; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[2,2], bias: f32[2,2]) -> f32[2,2] {
-; CHECK-NEXT:    %x = f32[2,2]{1,0} parameter(0)
-; CHECK-NEXT:    %y = f32[2,2]{1,0} parameter(1)
-; CHECK-NEXT:    %custom-call = f32[2,2]{1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{selected_algorithm:{{[0-9]+}},alpha_real:1,dot_dimension_numbers:{lhs_contracting_dimensions:[1],rhs_contracting_dimensions:[0],lhs_batch_dimensions:[],rhs_batch_dimensions:[]},batch_size:1}"
-)");
-}
-
 TEST_F(GemmRewriteTest, SharedBufferAssignment) {
   const char* hlo_text = R"(
 HloModule SharedBufferAssignment

View File

@@ -1,6 +1,6 @@
 # Tests of TensorFlow kernels written using the Python API.
 
-load("//tensorflow:tensorflow.bzl", "tf_py_test", "sycl_py_test", "tf_custom_op_library")
+load("//tensorflow:tensorflow.bzl", "sycl_py_test", "tf_custom_op_library", "tf_py_test")
 load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 
 package(
@@ -3525,7 +3525,7 @@ cuda_py_test(
     ],
     # TODO(b/127344411): This test passes because XLA does not actually cluster
     # the svd op.
-    xla_enable_strict_auto_jit = False,
+    xla_enable_strict_auto_jit = True,
 )
 
 cuda_py_test(