[XLA:GPU] Migrate GEMM Thunk emission to MLIR.

- Map the XLA HLO custom call for GEMM to the Gemm/Gemm bias operations in the
  LHLO GPU dialect.
- Make 'algorithm' an optional attribute to better match the XLA HLO backend config.
- Replace 'alpha' with 'alpha_real' and 'alpha_imag' to support complex GEMM
  correctly (see the sketch after this list).
- Generate GemmThunk from LHLO GPU Gemm operations.
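
A rough, hedged illustration of the alpha split (the struct and function names
below are hypothetical, not part of this change): a single F64 'alpha' cannot
carry the scaling factor of a complex GEMM, whereas a (real, imag) pair can be
recombined losslessly when the thunk is built.

// Sketch only: 'GemmAttrs' and 'RecoverAlpha' are illustrative names, not XLA code.
#include <complex>

struct GemmAttrs {
  double alpha_real;  // real part of the GEMM scaling factor
  double alpha_imag;  // imaginary part; 0.0 for real-valued GEMMs
};

std::complex<double> RecoverAlpha(const GemmAttrs& attrs) {
  // For f32/f64 GEMMs the imaginary part is zero, so this reduces to the old
  // single 'alpha' value.
  return {attrs.alpha_real, attrs.alpha_imag};
}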

PiperOrigin-RevId: 345250840
Change-Id: Ia1ffffd8aa09dbc49e8cbdf7402975700d60fda7
Rahul Joshi 2020-12-02 09:42:26 -08:00 committed by TensorFlower Gardener
parent de8ce88945
commit bea9ecb9aa
8 changed files with 193 additions and 7 deletions

@@ -179,9 +179,10 @@ def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> {
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemRead]>:$output,
DotDimensionNumbers:$dot_dimension_numbers,
-F64Attr:$alpha,
+F64Attr:$alpha_real,
+F64Attr:$alpha_imag,
I64Attr:$batch_size,
-I64Attr:$algorithm);
+OptionalAttr<I64Attr>:$algorithm);
}
// output = alpha(lhs * rhs) + beta * bias
@@ -192,10 +193,11 @@ def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> {
Arg<LHLO_Buffer, "", [MemRead]>:$bias,
Arg<LHLO_Buffer, "", [MemRead]>:$output,
DotDimensionNumbers:$dot_dimension_numbers,
-F64Attr:$alpha,
+F64Attr:$alpha_real,
+F64Attr:$alpha_imag,
F64Attr:$beta,
I64Attr:$batch_size,
-I64Attr:$algorithm);
+OptionalAttr<I64Attr>:$algorithm);
}
def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> {

@@ -65,7 +65,8 @@ func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5xf32>
rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>,
rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>},
-alpha = 0.5,
+alpha_real = 0.5,
+alpha_imag = 0.0,
batch_size = 1,
algorithm = 0}
: (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> ()
@@ -81,7 +82,8 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>,
rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>},
-alpha = 0.5,
+alpha_real = 0.5,
+alpha_imag = 0.0,
beta = 1.0,
batch_size = 1,
algorithm = 0}

@@ -149,6 +149,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service/gpu:backend_configs_cc",
"//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
"//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
"@llvm-project//llvm:Support",

@@ -108,7 +108,6 @@ ENTRY main {
// -----
HloModule Cholesky
// CHECK-LABEL: func @main
@@ -121,3 +120,52 @@ ENTRY main {
operand_layout_constraints={f32[3,3]},
backend_config="{\"lower\":true}"
}
// -----
HloModule Gemm
// CHECK-LABEL: func @main
// CHECK: "lmhlo_gpu.gemm"
// CHECK-SAME: algorithm = 7 : i64
// CHECK-SAME: alpha_imag = 0.000000e+00 : f64
// CHECK-SAME: alpha_real = 1.000000e+00 : f64
// CHECK-SAME: batch_size = 1 : i64
// CHECK-SAME: lhs_batching_dimensions = dense<> : tensor<0xi64>
// CHECK-SAME: lhs_contracting_dimensions = dense<1> : tensor<1xi64>
// CHECK-SAME: rhs_batching_dimensions = dense<> : tensor<0xi64>
// CHECK-SAME: rhs_contracting_dimensions = dense<0> : tensor<1xi64>
// CHECK: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
ENTRY main {
%A = f32[2,2]{1,0} parameter(0)
%B = f32[2,2]{1,0} parameter(1)
ROOT %sgemm = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %A, f32[2,2]{1,0} %B),
custom_call_target="__cublas$gemm",
backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"selected_algorithm\":\"7\"}"
}
// -----
HloModule GemmBias
// CHECK-LABEL: func @main
// CHECK: "lmhlo_gpu.gemm_bias"
// CHECK-SAME: algorithm = 0 : i64
// CHECK-SAME: alpha_imag = 0.000000e+00 : f64
// CHECK-SAME: alpha_real = 1.000000e+00 : f64
// CHECK-SAME: batch_size = 1 : i64
// CHECK-SAME: beta = 1.000000e+00 : f64
// CHECK-SAME: lhs_batching_dimensions = dense<> : tensor<0xi64>
// CHECK-SAME: lhs_contracting_dimensions = dense<1> : tensor<1xi64>
// CHECK-SAME: rhs_batching_dimensions = dense<> : tensor<0xi64>
// CHECK-SAME: rhs_contracting_dimensions = dense<0> : tensor<1xi64>
// CHECK: (memref<1x1xf32>, memref<1x4xf32>, memref<1x4xf32>, memref<1x4xf32>)
ENTRY main {
%A = f32[1,1]{1,0} parameter(0)
%B = f32[1,4]{1,0} parameter(1)
%C = f32[1,4]{1,0} parameter(2)
ROOT %sgemm_add = f32[1,4]{1,0} custom-call(f32[1,1]{0,1} %A, f32[1,4]{1,0} %B, f32[1,4]{1,0} %C),
custom_call_target="__cublas$gemm",
backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"selected_algorithm\":\"0\"}"
}

@@ -48,6 +48,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -502,6 +503,10 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitCustomCallOp(
return EmitCholesky(custom_call_instr);
}
if (xla::gpu::IsCublasGemm(*instr)) {
return EmitGemm(custom_call_instr);
}
size_t num_arguments, num_results;
TF_ASSIGN_OR_RETURN(auto custom_call,
CreateOpWithoutAttrs<lmhlo::CustomCallOp>(
@@ -527,6 +532,48 @@ StatusOr<lmhlo_gpu::CholeskyOp> LhloDialectEmitter::EmitCholesky(
return cholesky_op;
}
StatusOr<Operation*> LhloDialectEmitter::EmitGemm(
HloCustomCallInstruction* custom_call) {
TF_ASSIGN_OR_RETURN(
auto const config,
custom_call->backend_config<xla::gpu::GemmBackendConfig>());
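// Generic lambda: the same attribute-setting code is applied to both the
// GEMMOp and the GEMM_BiasOp created below.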
auto set_common_attributes = [&](auto op) -> Operation* {
auto hlo_dims = config.dot_dimension_numbers();
auto mlir_dims = mhlo::DotDimensionNumbers::get(
GetI64DenseElementsAttr(hlo_dims.lhs_batch_dimensions()),
GetI64DenseElementsAttr(hlo_dims.rhs_batch_dimensions()),
GetI64DenseElementsAttr(hlo_dims.lhs_contracting_dimensions()),
GetI64DenseElementsAttr(hlo_dims.rhs_contracting_dimensions()),
builder_.getContext());
op.dot_dimension_numbersAttr(mlir_dims);
op.alpha_realAttr(builder_.getF64FloatAttr(config.alpha_real()));
op.alpha_imagAttr(builder_.getF64FloatAttr(config.alpha_imag()));
op.batch_sizeAttr(builder_.getI64IntegerAttr(config.batch_size()));
if (config.algorithm_case() ==
xla::gpu::GemmBackendConfig::kSelectedAlgorithm) {
op.algorithmAttr(builder_.getI64IntegerAttr(config.selected_algorithm()));
}
return op.getOperation();
};
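// A plain GEMM custom call carries two operands (lhs, rhs); the bias variant
// adds a third operand for the bias buffer.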
if (custom_call->operand_count() == 2) {
TF_ASSIGN_OR_RETURN(auto gemm,
CreateOpWithoutAttrs<lmhlo_gpu::GEMMOp>(custom_call));
return set_common_attributes(gemm);
}
if (custom_call->operand_count() == 3) {
TF_ASSIGN_OR_RETURN(
auto gemm_bias,
CreateOpWithoutAttrs<lmhlo_gpu::GEMM_BiasOp>(custom_call));
gemm_bias.betaAttr(builder_.getF64FloatAttr(config.beta()));
return set_common_attributes(gemm_bias);
}
return xla::InvalidArgument("GEMM custom call should have 2 or 3 operands");
}
// Convert an XLA HLO constant to a global_memref + get_global_memref pair.
StatusOr<mlir::GetGlobalMemrefOp> LhloDialectEmitter::EmitConstant(
const HloInstruction* instr) {

@@ -61,6 +61,8 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
::xla::StatusOr<Operation*> EmitCustomCallOp(::xla::HloInstruction* instr);
::xla::StatusOr<lmhlo_gpu::CholeskyOp> EmitCholesky(
::xla::HloCustomCallInstruction* custom_call);
::xla::StatusOr<Operation*> EmitGemm(
::xla::HloCustomCallInstruction* custom_call);
::xla::StatusOr<lmhlo::ReduceOp> EmitReduceOp(::xla::HloInstruction* instr);
::xla::StatusOr<GetGlobalMemrefOp> EmitConstant(

@@ -30,6 +30,7 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
@@ -60,6 +61,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
@@ -952,6 +954,11 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
return ThunkEmitter(this).HandleCustomCall(custom_call);
}
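// GEMM and GEMM+bias ops from the lmhlo_gpu dialect are handled by the
// MLIR-based thunk emission path.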
if (mlir::isa<mlir::lmhlo_gpu::GEMMOp, mlir::lmhlo_gpu::GEMM_BiasOp>(
input.op)) {
return EmitGemmThunkFromMlir(input);
}
#if GOOGLE_CUDA
if (mlir::isa<mlir::lmhlo_gpu::CholeskyOp>(input.op)) {
return EmitCholeskyThunkFromMlir(input);
@@ -962,6 +969,82 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
custom_call->custom_call_target());
}
Status IrEmitterUnnested::EmitGemmThunkFromMlir(MlirEmitterInput input) {
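// Shared between the plain and the bias GEMM ops: pull the shapes and the
// backend-config fields out of the MLIR op into a GpuGemmConfig.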
auto build_gemm_config = [](auto op) {
GpuGemmConfig config;
GemmBackendConfig& backend = config.backend_config;
config.output_shape = TypeToShape(op.output().getType());
config.lhs_shape = TypeToShape(op.lhs().getType());
config.rhs_shape = TypeToShape(op.rhs().getType());
backend.Clear();
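// 'algorithm' is optional in the LHLO GPU dialect; forward it to the backend
// config only when the attribute is present.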
if (op.algorithm()) {
backend.set_selected_algorithm(*op.algorithm());
}
backend.set_alpha_real(op.alpha_real().convertToDouble());
backend.set_alpha_imag(op.alpha_imag().convertToDouble());
backend.set_batch_size(op.batch_size());
auto& dims = *backend.mutable_dot_dimension_numbers();
auto mlir_dims = op.dot_dimension_numbers();
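// Copy each dimension list from the MLIR dense attribute into the
// corresponding repeated field of the proto.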
auto fill_dims = [](mlir::DenseElementsAttr mlir_dim, auto* config_attrs) {
for (llvm::APInt e : mlir_dim.getIntValues())
config_attrs->Add(e.getSExtValue());
};
fill_dims(mlir_dims.lhs_batching_dimensions(),
dims.mutable_lhs_batch_dimensions());
fill_dims(mlir_dims.rhs_batching_dimensions(),
dims.mutable_rhs_batch_dimensions());
fill_dims(mlir_dims.lhs_contracting_dimensions(),
dims.mutable_lhs_contracting_dimensions());
fill_dims(mlir_dims.rhs_contracting_dimensions(),
dims.mutable_rhs_contracting_dimensions());
return config;
};
GpuGemmConfig config;
BufferAllocation::Slice lhs, rhs, bias, output;
if (auto gemm = mlir::dyn_cast<mlir::lmhlo_gpu::GEMMOp>(input.op)) {
config = build_gemm_config(gemm);
TF_ASSIGN_OR_RETURN(lhs, GetAllocationSliceForMlir(gemm.lhs()));
TF_ASSIGN_OR_RETURN(rhs, GetAllocationSliceForMlir(gemm.rhs()));
TF_ASSIGN_OR_RETURN(output, GetAllocationSliceForMlir(gemm.output()));
} else if (auto gemm_bias =
mlir::dyn_cast<mlir::lmhlo_gpu::GEMM_BiasOp>(input.op)) {
config = build_gemm_config(gemm_bias);
config.backend_config.set_beta(gemm_bias.beta().convertToDouble());
TF_ASSIGN_OR_RETURN(lhs, GetAllocationSliceForMlir(gemm_bias.lhs()));
TF_ASSIGN_OR_RETURN(rhs, GetAllocationSliceForMlir(gemm_bias.rhs()));
TF_ASSIGN_OR_RETURN(bias, GetAllocationSliceForMlir(gemm_bias.bias()));
TF_ASSIGN_OR_RETURN(output, GetAllocationSliceForMlir(gemm_bias.output()));
// The bias is passed inside the output buffer. If those buffers are shared,
// we can use it directly; otherwise, copy the bias values into the output
// buffer first.
if (bias != output) {
std::vector<std::unique_ptr<Thunk>> thunks;
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
Thunk::ThunkInfo(),
/*source_buffer=*/bias,
/*destination_buffer=*/output,
/*mem_size=*/ShapeUtil::ByteSizeOf(config.output_shape)));
thunks.push_back(absl::make_unique<GemmThunk>(
input.thunk_info, std::move(config), lhs, rhs, output,
/*implements_whole_instruction=*/false));
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
input.thunk_info, std::move(thunks)));
return Status::OK();
}
}
AddThunkToThunkSequence(absl::make_unique<GemmThunk>(
input.thunk_info, std::move(config), lhs, rhs, output,
/*implements_whole_instruction=*/true));
return Status::OK();
}
#if GOOGLE_CUDA
Status IrEmitterUnnested::EmitCholeskyThunkFromMlir(MlirEmitterInput input) {
auto cholesky_op = ::mlir::cast<mlir::lmhlo_gpu::CholeskyOp>(input.op);

@@ -168,6 +168,7 @@ class IrEmitterUnnested : public IrEmitter,
Status HandleConditional(HloInstruction* conditional) override;
Status HandleConvolution(HloInstruction* convolution) override;
Status HandleCustomCall(HloInstruction* custom_call) override;
Status EmitGemmThunkFromMlir(MlirEmitterInput input);
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
Status EmitCholeskyThunkFromMlir(MlirEmitterInput input);
#endif // (defined(GOOGLE_CUDA) && GOOGLE_CUDA)