From bea9ecb9aad42dde0816a6e451320fdf5601f0a0 Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Wed, 2 Dec 2020 09:42:26 -0800 Subject: [PATCH] [XLA:GPU] Migrate GEMM Thunk emission to MLIR. - Map Custom call for GEMM in XLA HLO to Gemm/Gemm bias operations in LHLO GPU dialect. - Make 'algorithm' an optional attribute to better match with XLA HLO backend config. - Replace 'alpha' with 'alpha_real' and 'alpha_complex' to support complex GEMM correctly. - Generate GemmThunk off of LHLO GPU Gemm operations. PiperOrigin-RevId: 345250840 Change-Id: Ia1ffffd8aa09dbc49e8cbdf7402975700d60fda7 --- .../mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td | 10 ++- .../compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir | 6 +- tensorflow/compiler/mlir/xla/BUILD | 1 + .../hlo_text_to_lhlo_no_opt.hlotxt | 50 ++++++++++- .../xla/transforms/mhlo_to_lhlo_with_xla.cc | 47 +++++++++++ .../xla/transforms/mhlo_to_lhlo_with_xla.h | 2 + .../xla/service/gpu/ir_emitter_unnested.cc | 83 +++++++++++++++++++ .../xla/service/gpu/ir_emitter_unnested.h | 1 + 8 files changed, 193 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td index 4613627c36d..f4f2d859fba 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td @@ -179,9 +179,10 @@ def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> { Arg:$rhs, Arg:$output, DotDimensionNumbers:$dot_dimension_numbers, - F64Attr:$alpha, + F64Attr:$alpha_real, + F64Attr:$alpha_imag, I64Attr:$batch_size, - I64Attr:$algorithm); + OptionalAttr:$algorithm); } // output = alpha(lhs * rhs) + beta * bias @@ -192,10 +193,11 @@ def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> { Arg:$bias, Arg:$output, DotDimensionNumbers:$dot_dimension_numbers, - F64Attr:$alpha, + F64Attr:$alpha_real, + F64Attr:$alpha_imag, F64Attr:$beta, I64Attr:$batch_size, - I64Attr:$algorithm); + OptionalAttr:$algorithm); } def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> { diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir index a939cab6d10..bd5df3875f1 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir @@ -65,7 +65,8 @@ func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5xf32> rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>}, - alpha = 0.5, + alpha_real = 0.5, + alpha_imag = 0.0, batch_size = 1, algorithm = 0} : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> () @@ -81,7 +82,8 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>}, - alpha = 0.5, + alpha_real = 0.5, + alpha_imag = 0.0, beta = 1.0, batch_size = 1, algorithm = 0} diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index ab7e21587f0..2daa8a86d37 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -149,6 +149,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_parser", + 
"//tensorflow/compiler/xla/service/gpu:backend_configs_cc", "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", "@llvm-project//llvm:Support", diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt index 395d4bb8f9f..ce42ccf9054 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt @@ -108,7 +108,6 @@ ENTRY main { // ----- - HloModule Cholesky // CHECK-LABEL: func @main @@ -121,3 +120,52 @@ ENTRY main { operand_layout_constraints={f32[3,3]}, backend_config="{\"lower\":true}" } + +// ----- + +HloModule Gemm + +// CHECK-LABEL: func @main +// CHECK: "lmhlo_gpu.gemm" +// CHECK-SAME: algorithm = 7 : i64 +// CHECK-SAME: alpha_imag = 0.000000e+00 : f64 +// CHECK-SAME: alpha_real = 1.000000e+00 : f64 +// CHECK-SAME: batch_size = 1 : i64 +// CHECK-SAME: lhs_batching_dimensions = dense<> : tensor<0xi64> +// CHECK-SAME: lhs_contracting_dimensions = dense<1> : tensor<1xi64> +// CHECK-SAME: rhs_batching_dimensions = dense<> : tensor<0xi64> +// CHECK-SAME: rhs_contracting_dimensions = dense<0> : tensor<1xi64> +// CHECK: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () +ENTRY main { + %A = f32[2,2]{1,0} parameter(0) + %B = f32[2,2]{1,0} parameter(1) + ROOT %sgemm = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %A, f32[2,2]{1,0} %B), + custom_call_target="__cublas$gemm", + backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"selected_algorithm\":\"7\"}" +} + +// ----- + +HloModule GemmBias + +// CHECK-LABEL: func @main +// CHECK: "lmhlo_gpu.gemm_bias" +// CHECK-SAME: algorithm = 0 : i64 +// CHECK-SAME: alpha_imag = 0.000000e+00 : f64 +// CHECK-SAME: alpha_real = 1.000000e+00 : f64 +// CHECK-SAME: batch_size = 1 : i64 +// CHECK-SAME: beta = 1.000000e+00 : f64 +// CHECK-SAME: lhs_batching_dimensions = dense<> : tensor<0xi64> +// CHECK-SAME: lhs_contracting_dimensions = dense<1> : tensor<1xi64> +// CHECK-SAME: rhs_batching_dimensions = dense<> : tensor<0xi64> +// CHECK-SAME: rhs_contracting_dimensions = dense<0> : tensor<1xi64> +// CHECK: (memref<1x1xf32>, memref<1x4xf32>, memref<1x4xf32>, memref<1x4xf32>) +ENTRY main { + %A = f32[1,1]{1,0} parameter(0) + %B = f32[1,4]{1,0} parameter(1) + %C = f32[1,4]{1,0} parameter(2) + ROOT %sgemm_add = f32[1,4]{1,0} custom-call(f32[1,1]{0,1} %A, f32[1,4]{1,0} %B, f32[1,4]{1,0} %C), + custom_call_target="__cublas$gemm", + backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"selected_algorithm\":\"0\"}" +} + diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc index fbe602d4a57..bf11d526c75 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc @@ -48,6 +48,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -502,6 +503,10 @@ StatusOr LhloDialectEmitter::EmitCustomCallOp( return EmitCholesky(custom_call_instr); } + if (xla::gpu::IsCublasGemm(*instr)) { + return EmitGemm(custom_call_instr); + } + size_t num_arguments, num_results; TF_ASSIGN_OR_RETURN(auto custom_call, CreateOpWithoutAttrs( @@ -527,6 +532,48 @@ StatusOr LhloDialectEmitter::EmitCholesky( return cholesky_op; } +StatusOr LhloDialectEmitter::EmitGemm( + HloCustomCallInstruction* custom_call) { + TF_ASSIGN_OR_RETURN( + auto const config, + custom_call->backend_config()); + + auto set_common_attributes = [&](auto op) -> Operation* { + auto hlo_dims = config.dot_dimension_numbers(); + auto mlir_dims = mhlo::DotDimensionNumbers::get( + GetI64DenseElementsAttr(hlo_dims.lhs_batch_dimensions()), + GetI64DenseElementsAttr(hlo_dims.rhs_batch_dimensions()), + GetI64DenseElementsAttr(hlo_dims.lhs_contracting_dimensions()), + GetI64DenseElementsAttr(hlo_dims.rhs_contracting_dimensions()), + builder_.getContext()); + op.dot_dimension_numbersAttr(mlir_dims); + op.alpha_realAttr(builder_.getF64FloatAttr(config.alpha_real())); + op.alpha_imagAttr(builder_.getF64FloatAttr(config.alpha_imag())); + op.batch_sizeAttr(builder_.getI64IntegerAttr(config.batch_size())); + if (config.algorithm_case() == + xla::gpu::GemmBackendConfig::kSelectedAlgorithm) { + op.algorithmAttr(builder_.getI64IntegerAttr(config.selected_algorithm())); + } + return op.getOperation(); + }; + + if (custom_call->operand_count() == 2) { + TF_ASSIGN_OR_RETURN(auto gemm, + CreateOpWithoutAttrs(custom_call)); + return set_common_attributes(gemm); + } + + if (custom_call->operand_count() == 3) { + TF_ASSIGN_OR_RETURN( + auto gemm_bias, + CreateOpWithoutAttrs(custom_call)); + gemm_bias.betaAttr(builder_.getF64FloatAttr(config.beta())); + return set_common_attributes(gemm_bias); + } + + return xla::InvalidArgument("GEMM custom call should have 2 or 3 operands"); +} + // Convert an XLA HLO constant to a global_memref + get_global_memref pair. StatusOr LhloDialectEmitter::EmitConstant( const HloInstruction* instr) { diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h index 6c66a670fd0..82144510669 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h @@ -61,6 +61,8 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { ::xla::StatusOr EmitCustomCallOp(::xla::HloInstruction* instr); ::xla::StatusOr EmitCholesky( ::xla::HloCustomCallInstruction* custom_call); + ::xla::StatusOr EmitGemm( + ::xla::HloCustomCallInstruction* custom_call); ::xla::StatusOr EmitReduceOp(::xla::HloInstruction* instr); ::xla::StatusOr EmitConstant( diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 9f901ed835d..907e13d2175 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "absl/strings/str_format.h" #include "absl/types/optional.h" #include "absl/types/span.h" +#include "llvm/ADT/APInt.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" @@ -60,6 +61,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h" #include "tensorflow/compiler/xla/service/gpu/for_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h" #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" @@ -952,6 +954,11 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { return ThunkEmitter(this).HandleCustomCall(custom_call); } + if (mlir::isa( + input.op)) { + return EmitGemmThunkFromMlir(input); + } + #if GOOGLE_CUDA if (mlir::isa(input.op)) { return EmitCholeskyThunkFromMlir(input); @@ -962,6 +969,82 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { custom_call->custom_call_target()); } +Status IrEmitterUnnested::EmitGemmThunkFromMlir(MlirEmitterInput input) { + auto build_gemm_config = [](auto op) { + GpuGemmConfig config; + GemmBackendConfig& backend = config.backend_config; + config.output_shape = TypeToShape(op.output().getType()); + config.lhs_shape = TypeToShape(op.lhs().getType()); + config.rhs_shape = TypeToShape(op.rhs().getType()); + backend.Clear(); + if (op.algorithm()) { + backend.set_selected_algorithm(*op.algorithm()); + } + backend.set_alpha_real(op.alpha_real().convertToDouble()); + backend.set_alpha_imag(op.alpha_imag().convertToDouble()); + backend.set_batch_size(op.batch_size()); + + auto& dims = *backend.mutable_dot_dimension_numbers(); + auto mlir_dims = op.dot_dimension_numbers(); + + auto fill_dims = [](mlir::DenseElementsAttr mlir_dim, auto* config_attrs) { + for (llvm::APInt e : mlir_dim.getIntValues()) + config_attrs->Add(e.getSExtValue()); + }; + fill_dims(mlir_dims.lhs_batching_dimensions(), + dims.mutable_lhs_batch_dimensions()); + fill_dims(mlir_dims.rhs_batching_dimensions(), + dims.mutable_rhs_batch_dimensions()); + fill_dims(mlir_dims.lhs_contracting_dimensions(), + dims.mutable_lhs_contracting_dimensions()); + fill_dims(mlir_dims.rhs_contracting_dimensions(), + dims.mutable_rhs_contracting_dimensions()); + return config; + }; + + GpuGemmConfig config; + BufferAllocation::Slice lhs, rhs, bias, output; + + if (auto gemm = mlir::dyn_cast(input.op)) { + config = build_gemm_config(gemm); + TF_ASSIGN_OR_RETURN(lhs, GetAllocationSliceForMlir(gemm.lhs())); + TF_ASSIGN_OR_RETURN(rhs, GetAllocationSliceForMlir(gemm.rhs())); + TF_ASSIGN_OR_RETURN(output, GetAllocationSliceForMlir(gemm.output())); + } else if (auto gemm_bias = + mlir::dyn_cast(input.op)) { + config = build_gemm_config(gemm_bias); + config.backend_config.set_beta(gemm_bias.beta().convertToDouble()); + TF_ASSIGN_OR_RETURN(lhs, GetAllocationSliceForMlir(gemm_bias.lhs())); + TF_ASSIGN_OR_RETURN(rhs, GetAllocationSliceForMlir(gemm_bias.rhs())); + TF_ASSIGN_OR_RETURN(bias, GetAllocationSliceForMlir(gemm_bias.bias())); + TF_ASSIGN_OR_RETURN(output, GetAllocationSliceForMlir(gemm_bias.output())); + + // The bias is passed inside the output buffer. If those buffers are shared + // we can just use it, otherwise copy the bias values into the output buffer + // first. 
+    if (bias != output) {
+      std::vector<std::unique_ptr<Thunk>> thunks;
+      thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
+          Thunk::ThunkInfo(),
+          /*source_buffer=*/bias,
+          /*destination_buffer=*/output,
+          /*mem_size=*/ShapeUtil::ByteSizeOf(config.output_shape)));
+      thunks.push_back(absl::make_unique<GemmThunk>(
+          input.thunk_info, std::move(config), lhs, rhs, output,
+          /*implements_whole_instruction=*/false));
+      AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
+          input.thunk_info, std::move(thunks)));
+      return Status::OK();
+    }
+  }
+
+  AddThunkToThunkSequence(absl::make_unique<GemmThunk>(
+      input.thunk_info, std::move(config), lhs, rhs, output,
+      /*implements_whole_instruction=*/true));
+  return Status::OK();
+}
+
 #if GOOGLE_CUDA
 Status IrEmitterUnnested::EmitCholeskyThunkFromMlir(MlirEmitterInput input) {
   auto cholesky_op = ::mlir::cast<mlir::lmhlo_gpu::CholeskyOp>(input.op);
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 5c6a06851d7..1869792548c 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -168,6 +168,7 @@ class IrEmitterUnnested : public IrEmitter,
   Status HandleConditional(HloInstruction* conditional) override;
   Status HandleConvolution(HloInstruction* convolution) override;
   Status HandleCustomCall(HloInstruction* custom_call) override;
+  Status EmitGemmThunkFromMlir(MlirEmitterInput input);
 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
   Status EmitCholeskyThunkFromMlir(MlirEmitterInput input);
 #endif  // (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
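
A minimal sketch of what the now-optional 'algorithm' attribute permits (not part of the diff; it only mirrors the lhlo_gpu_ops.mlir test above, and the function name is made up for illustration): an lmhlo_gpu.gemm op can simply omit the attribute when the XLA backend config carries no selected_algorithm, instead of having to fabricate a dummy value.

  func @gemm_no_algorithm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
                          %output: memref<5x5xf32>) {
    // 'algorithm' is omitted; with OptionalAttr this should still verify.
    "lmhlo_gpu.gemm"(%lhs, %rhs, %output) { dot_dimension_numbers = {
         lhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
         rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
         lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>,
         rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>},
         alpha_real = 0.5,
         alpha_imag = 0.0,
         batch_size = 1}
      : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> ()
    return
  }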
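
Similarly, splitting 'alpha' into 'alpha_real' and 'alpha_imag' lets a complex scale factor alpha = alpha_real + i * alpha_imag survive the trip through the two F64 attributes. A hypothetical complex GEMM, assuming the LHLO buffer type accepts complex element types (sketch only, not taken from the patch):

  func @gemm_complex_alpha(%lhs: memref<5x4xcomplex<f32>>,
                           %rhs: memref<4x5xcomplex<f32>>,
                           %output: memref<5x5xcomplex<f32>>) {
    // alpha = 0.5 + 2.0i; a single F64 'alpha' attribute could not carry this.
    "lmhlo_gpu.gemm"(%lhs, %rhs, %output) { dot_dimension_numbers = {
         lhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
         rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
         lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>,
         rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>},
         alpha_real = 0.5,
         alpha_imag = 2.0,
         batch_size = 1,
         algorithm = 0}
      : (memref<5x4xcomplex<f32>>, memref<4x5xcomplex<f32>>, memref<5x5xcomplex<f32>>) -> ()
    return
  }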