[XLA:GPU] Migrate GEMM Thunk emission to MLIR.

- Map the custom call for GEMM in XLA HLO to Gemm/Gemm bias operations in the LHLO GPU dialect.
- Make 'algorithm' an optional attribute, to better match the XLA HLO backend config.
- Replace 'alpha' with 'alpha_real' and 'alpha_imag' to support complex GEMM correctly.
- Generate GemmThunk off of LHLO GPU Gemm operations.

PiperOrigin-RevId: 345250840
Change-Id: Ia1ffffd8aa09dbc49e8cbdf7402975700d60fda7
parent de8ce88945
commit bea9ecb9aa
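As background for the attribute changes below, the semantics both ops encode (my restatement of the "output = alpha(lhs * rhs) + beta * bias" comment in the diff, with alpha taken as complex):

$$\mathrm{output} = \alpha\,(\mathrm{lhs}\cdot\mathrm{rhs}) + \beta\cdot\mathrm{bias}, \qquad \alpha = \alpha_{\mathrm{real}} + i\,\alpha_{\mathrm{imag}}$$

A single F64Attr can only carry a real alpha, which is why it is split into alpha_real and alpha_imag; for real-valued GEMMs, alpha_imag is simply 0.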
@@ -179,9 +179,10 @@ def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> {
     Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
     Arg<LHLO_Buffer, "", [MemRead]>:$output,
     DotDimensionNumbers:$dot_dimension_numbers,
-    F64Attr:$alpha,
+    F64Attr:$alpha_real,
+    F64Attr:$alpha_imag,
     I64Attr:$batch_size,
-    I64Attr:$algorithm);
+    OptionalAttr<I64Attr>:$algorithm);
 }

 // output = alpha(lhs * rhs) + beta * bias
@@ -192,10 +193,11 @@ def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> {
     Arg<LHLO_Buffer, "", [MemRead]>:$bias,
     Arg<LHLO_Buffer, "", [MemRead]>:$output,
     DotDimensionNumbers:$dot_dimension_numbers,
-    F64Attr:$alpha,
+    F64Attr:$alpha_real,
+    F64Attr:$alpha_imag,
     F64Attr:$beta,
     I64Attr:$batch_size,
-    I64Attr:$algorithm);
+    OptionalAttr<I64Attr>:$algorithm);
 }

 def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> {
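To see the split concretely, a minimal standalone sketch (plain C++ with std::complex, no XLA types; the attribute values match the updated test below):

#include <complex>
#include <iostream>

// Minimal sketch: the two f64 attributes reassemble into one complex scale.
// For a real-valued GEMM alpha_imag is 0 and alpha degenerates to alpha_real.
int main() {
  double alpha_real = 0.5, alpha_imag = 0.0;   // attribute values from the test
  std::complex<double> alpha(alpha_real, alpha_imag);
  std::complex<double> dot_entry(4.0, 0.0);    // stand-in for one (lhs * rhs) entry
  std::cout << alpha * dot_entry << "\n";      // prints (2,0)
}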
@@ -65,7 +65,8 @@ func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5xf32>
       rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
       lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>,
       rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>},
-      alpha = 0.5,
+      alpha_real = 0.5,
+      alpha_imag = 0.0,
       batch_size = 1,
       algorithm = 0}
     : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> ()
@@ -81,7 +82,8 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
       rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
       lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>,
       rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>},
-      alpha = 0.5,
+      alpha_real = 0.5,
+      alpha_imag = 0.0,
       beta = 1.0,
       batch_size = 1,
       algorithm = 0}
@@ -149,6 +149,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_casting_utils",
         "//tensorflow/compiler/xla/service:hlo_parser",
+        "//tensorflow/compiler/xla/service/gpu:backend_configs_cc",
         "//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
         "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
         "@llvm-project//llvm:Support",
@@ -108,7 +108,6 @@ ENTRY main {

 // -----

-
 HloModule Cholesky

 // CHECK-LABEL: func @main
@@ -121,3 +120,52 @@ ENTRY main {
     operand_layout_constraints={f32[3,3]},
     backend_config="{\"lower\":true}"
 }
+
+// -----
+
+HloModule Gemm
+
+// CHECK-LABEL: func @main
+// CHECK: "lmhlo_gpu.gemm"
+// CHECK-SAME: algorithm = 7 : i64
+// CHECK-SAME: alpha_imag = 0.000000e+00 : f64
+// CHECK-SAME: alpha_real = 1.000000e+00 : f64
+// CHECK-SAME: batch_size = 1 : i64
+// CHECK-SAME: lhs_batching_dimensions = dense<> : tensor<0xi64>
+// CHECK-SAME: lhs_contracting_dimensions = dense<1> : tensor<1xi64>
+// CHECK-SAME: rhs_batching_dimensions = dense<> : tensor<0xi64>
+// CHECK-SAME: rhs_contracting_dimensions = dense<0> : tensor<1xi64>
+// CHECK: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
+ENTRY main {
+  %A = f32[2,2]{1,0} parameter(0)
+  %B = f32[2,2]{1,0} parameter(1)
+  ROOT %sgemm = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %A, f32[2,2]{1,0} %B),
+    custom_call_target="__cublas$gemm",
+    backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"selected_algorithm\":\"7\"}"
+}
+
+// -----
+
+HloModule GemmBias
+
+// CHECK-LABEL: func @main
+// CHECK: "lmhlo_gpu.gemm_bias"
+// CHECK-SAME: algorithm = 0 : i64
+// CHECK-SAME: alpha_imag = 0.000000e+00 : f64
+// CHECK-SAME: alpha_real = 1.000000e+00 : f64
+// CHECK-SAME: batch_size = 1 : i64
+// CHECK-SAME: beta = 1.000000e+00 : f64
+// CHECK-SAME: lhs_batching_dimensions = dense<> : tensor<0xi64>
+// CHECK-SAME: lhs_contracting_dimensions = dense<1> : tensor<1xi64>
+// CHECK-SAME: rhs_batching_dimensions = dense<> : tensor<0xi64>
+// CHECK-SAME: rhs_contracting_dimensions = dense<0> : tensor<1xi64>
+// CHECK: (memref<1x1xf32>, memref<1x4xf32>, memref<1x4xf32>, memref<1x4xf32>)
+ENTRY main {
+  %A = f32[1,1]{1,0} parameter(0)
+  %B = f32[1,4]{1,0} parameter(1)
+  %C = f32[1,4]{1,0} parameter(2)
+  ROOT %sgemm_add = f32[1,4]{1,0} custom-call(f32[1,1]{0,1} %A, f32[1,4]{1,0} %B, f32[1,4]{1,0} %C),
+    custom_call_target="__cublas$gemm",
+    backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"selected_algorithm\":\"0\"}"
+}
+
@@ -48,6 +48,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/debug_options_flags.h"
 #include "tensorflow/compiler/xla/service/backend.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -502,6 +503,10 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitCustomCallOp(
     return EmitCholesky(custom_call_instr);
   }

+  if (xla::gpu::IsCublasGemm(*instr)) {
+    return EmitGemm(custom_call_instr);
+  }
+
   size_t num_arguments, num_results;
   TF_ASSIGN_OR_RETURN(auto custom_call,
                       CreateOpWithoutAttrs<lmhlo::CustomCallOp>(
@@ -527,6 +532,48 @@ StatusOr<lmhlo_gpu::CholeskyOp> LhloDialectEmitter::EmitCholesky(
   return cholesky_op;
 }

+StatusOr<Operation*> LhloDialectEmitter::EmitGemm(
+    HloCustomCallInstruction* custom_call) {
+  TF_ASSIGN_OR_RETURN(
+      auto const config,
+      custom_call->backend_config<xla::gpu::GemmBackendConfig>());
+
+  auto set_common_attributes = [&](auto op) -> Operation* {
+    auto hlo_dims = config.dot_dimension_numbers();
+    auto mlir_dims = mhlo::DotDimensionNumbers::get(
+        GetI64DenseElementsAttr(hlo_dims.lhs_batch_dimensions()),
+        GetI64DenseElementsAttr(hlo_dims.rhs_batch_dimensions()),
+        GetI64DenseElementsAttr(hlo_dims.lhs_contracting_dimensions()),
+        GetI64DenseElementsAttr(hlo_dims.rhs_contracting_dimensions()),
+        builder_.getContext());
+    op.dot_dimension_numbersAttr(mlir_dims);
+    op.alpha_realAttr(builder_.getF64FloatAttr(config.alpha_real()));
+    op.alpha_imagAttr(builder_.getF64FloatAttr(config.alpha_imag()));
+    op.batch_sizeAttr(builder_.getI64IntegerAttr(config.batch_size()));
+    if (config.algorithm_case() ==
+        xla::gpu::GemmBackendConfig::kSelectedAlgorithm) {
+      op.algorithmAttr(builder_.getI64IntegerAttr(config.selected_algorithm()));
+    }
+    return op.getOperation();
+  };
+
+  if (custom_call->operand_count() == 2) {
+    TF_ASSIGN_OR_RETURN(auto gemm,
+                        CreateOpWithoutAttrs<lmhlo_gpu::GEMMOp>(custom_call));
+    return set_common_attributes(gemm);
+  }
+
+  if (custom_call->operand_count() == 3) {
+    TF_ASSIGN_OR_RETURN(
+        auto gemm_bias,
+        CreateOpWithoutAttrs<lmhlo_gpu::GEMM_BiasOp>(custom_call));
+    gemm_bias.betaAttr(builder_.getF64FloatAttr(config.beta()));
+    return set_common_attributes(gemm_bias);
+  }
+
+  return xla::InvalidArgument("GEMM custom call should have 2 or 3 operands");
+}
+
 // Convert an XLA HLO constant to a global_memref + get_global_memref pair.
 StatusOr<mlir::GetGlobalMemrefOp> LhloDialectEmitter::EmitConstant(
     const HloInstruction* instr) {
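One detail worth calling out in EmitGemm above: selected_algorithm lives in a proto oneof on the HLO side and in an OptionalAttr<I64Attr> on the LHLO side, so presence is checked in both directions. A standalone sketch of that round trip, with std::optional standing in for both (all types here are hypothetical stand-ins, not the real XLA classes):

#include <cstdint>
#include <iostream>
#include <optional>

// Hypothetical stand-in for GemmBackendConfig's selected_algorithm oneof.
struct FakeBackendConfig {
  std::optional<int64_t> selected_algorithm;
};

int main() {
  FakeBackendConfig hlo_config;            // no algorithm was autotuned
  std::optional<int64_t> algorithm_attr;   // stand-in for OptionalAttr<I64Attr>

  // HLO -> LHLO (mirrors the algorithm_case() check in EmitGemm).
  if (hlo_config.selected_algorithm)
    algorithm_attr = *hlo_config.selected_algorithm;

  // LHLO -> thunk config (mirrors `if (op.algorithm())` in the thunk emitter).
  FakeBackendConfig thunk_config;
  if (algorithm_attr)
    thunk_config.selected_algorithm = *algorithm_attr;

  std::cout << (thunk_config.selected_algorithm
                    ? "explicit algorithm"
                    : "cuBLAS picks a default")
            << "\n";  // prints "cuBLAS picks a default"
}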
@@ -61,6 +61,8 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
   ::xla::StatusOr<Operation*> EmitCustomCallOp(::xla::HloInstruction* instr);
   ::xla::StatusOr<lmhlo_gpu::CholeskyOp> EmitCholesky(
       ::xla::HloCustomCallInstruction* custom_call);
+  ::xla::StatusOr<Operation*> EmitGemm(
+      ::xla::HloCustomCallInstruction* custom_call);

   ::xla::StatusOr<lmhlo::ReduceOp> EmitReduceOp(::xla::HloInstruction* instr);
   ::xla::StatusOr<GetGlobalMemrefOp> EmitConstant(
@@ -30,6 +30,7 @@ limitations under the License.
 #include "absl/strings/str_format.h"
 #include "absl/types/optional.h"
 #include "absl/types/span.h"
+#include "llvm/ADT/APInt.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Function.h"
@@ -60,6 +61,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
@@ -952,6 +954,11 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
     return ThunkEmitter(this).HandleCustomCall(custom_call);
   }

+  if (mlir::isa<mlir::lmhlo_gpu::GEMMOp, mlir::lmhlo_gpu::GEMM_BiasOp>(
+          input.op)) {
+    return EmitGemmThunkFromMlir(input);
+  }
+
 #if GOOGLE_CUDA
   if (mlir::isa<mlir::lmhlo_gpu::CholeskyOp>(input.op)) {
     return EmitCholeskyThunkFromMlir(input);
@@ -962,6 +969,82 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
       custom_call->custom_call_target());
 }

+Status IrEmitterUnnested::EmitGemmThunkFromMlir(MlirEmitterInput input) {
+  auto build_gemm_config = [](auto op) {
+    GpuGemmConfig config;
+    GemmBackendConfig& backend = config.backend_config;
+    config.output_shape = TypeToShape(op.output().getType());
+    config.lhs_shape = TypeToShape(op.lhs().getType());
+    config.rhs_shape = TypeToShape(op.rhs().getType());
+    backend.Clear();
+    if (op.algorithm()) {
+      backend.set_selected_algorithm(*op.algorithm());
+    }
+    backend.set_alpha_real(op.alpha_real().convertToDouble());
+    backend.set_alpha_imag(op.alpha_imag().convertToDouble());
+    backend.set_batch_size(op.batch_size());
+
+    auto& dims = *backend.mutable_dot_dimension_numbers();
+    auto mlir_dims = op.dot_dimension_numbers();
+
+    auto fill_dims = [](mlir::DenseElementsAttr mlir_dim, auto* config_attrs) {
+      for (llvm::APInt e : mlir_dim.getIntValues())
+        config_attrs->Add(e.getSExtValue());
+    };
+    fill_dims(mlir_dims.lhs_batching_dimensions(),
+              dims.mutable_lhs_batch_dimensions());
+    fill_dims(mlir_dims.rhs_batching_dimensions(),
+              dims.mutable_rhs_batch_dimensions());
+    fill_dims(mlir_dims.lhs_contracting_dimensions(),
+              dims.mutable_lhs_contracting_dimensions());
+    fill_dims(mlir_dims.rhs_contracting_dimensions(),
+              dims.mutable_rhs_contracting_dimensions());
+    return config;
+  };
+
+  GpuGemmConfig config;
+  BufferAllocation::Slice lhs, rhs, bias, output;
+
+  if (auto gemm = mlir::dyn_cast<mlir::lmhlo_gpu::GEMMOp>(input.op)) {
+    config = build_gemm_config(gemm);
+    TF_ASSIGN_OR_RETURN(lhs, GetAllocationSliceForMlir(gemm.lhs()));
+    TF_ASSIGN_OR_RETURN(rhs, GetAllocationSliceForMlir(gemm.rhs()));
+    TF_ASSIGN_OR_RETURN(output, GetAllocationSliceForMlir(gemm.output()));
+  } else if (auto gemm_bias =
+                 mlir::dyn_cast<mlir::lmhlo_gpu::GEMM_BiasOp>(input.op)) {
+    config = build_gemm_config(gemm_bias);
+    config.backend_config.set_beta(gemm_bias.beta().convertToDouble());
+    TF_ASSIGN_OR_RETURN(lhs, GetAllocationSliceForMlir(gemm_bias.lhs()));
+    TF_ASSIGN_OR_RETURN(rhs, GetAllocationSliceForMlir(gemm_bias.rhs()));
+    TF_ASSIGN_OR_RETURN(bias, GetAllocationSliceForMlir(gemm_bias.bias()));
+    TF_ASSIGN_OR_RETURN(output, GetAllocationSliceForMlir(gemm_bias.output()));
+
+    // The bias is passed inside the output buffer. If those buffers are
+    // shared we can just use it, otherwise copy the bias values into the
+    // output buffer first.
+    if (bias != output) {
+      std::vector<std::unique_ptr<Thunk>> thunks;
+
+      thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
+          Thunk::ThunkInfo(),
+          /*source_buffer=*/bias,
+          /*destination_buffer=*/output,
+          /*mem_size=*/ShapeUtil::ByteSizeOf(config.output_shape)));
+      thunks.push_back(absl::make_unique<GemmThunk>(
+          input.thunk_info, std::move(config), lhs, rhs, output,
+          /*implements_whole_instruction=*/false));
+      AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
+          input.thunk_info, std::move(thunks)));
+      return Status::OK();
+    }
+  }
+
+  AddThunkToThunkSequence(absl::make_unique<GemmThunk>(
+      input.thunk_info, std::move(config), lhs, rhs, output,
+      /*implements_whole_instruction=*/true));
+  return Status::OK();
+}
+
 #if GOOGLE_CUDA
 Status IrEmitterUnnested::EmitCholeskyThunkFromMlir(MlirEmitterInput input) {
   auto cholesky_op = ::mlir::cast<mlir::lmhlo_gpu::CholeskyOp>(input.op);
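The bias handling above is the one non-obvious step: GemmThunk accumulates into the output buffer, so when bias and output are distinct allocations the bias is first copied into the output and the two thunks run as a SequentialThunk. A host-side sketch of the same dataflow (plain C++ vectors, no real thunks; the 1x1 by 1x4 shapes match the GemmBias test above):

#include <cstdio>
#include <vector>

// beta = 1 GEMM on a 1x1 lhs and 1x4 rhs: out[j] += lhs[0] * rhs[j].
// Stands in for GemmThunk, which accumulates into the output buffer.
void RunGemmBeta1(const std::vector<float>& lhs, const std::vector<float>& rhs,
                  std::vector<float>& out) {
  for (size_t j = 0; j < rhs.size(); ++j) out[j] += lhs[0] * rhs[j];
}

int main() {
  std::vector<float> lhs = {2.0f}, rhs = {1, 2, 3, 4};
  std::vector<float> bias = {10, 10, 10, 10};
  std::vector<float> output(4, 0.0f);  // distinct allocation from bias

  if (bias.data() != output.data())  // mirrors `if (bias != output)`
    output = bias;                   // stands in for DeviceToDeviceCopyThunk
  RunGemmBeta1(lhs, rhs, output);    // stands in for GemmThunk with beta = 1

  for (float v : output) std::printf("%g ", v);  // prints: 12 14 16 18
  std::printf("\n");
}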
@@ -168,6 +168,7 @@ class IrEmitterUnnested : public IrEmitter,
   Status HandleConditional(HloInstruction* conditional) override;
   Status HandleConvolution(HloInstruction* convolution) override;
   Status HandleCustomCall(HloInstruction* custom_call) override;
+  Status EmitGemmThunkFromMlir(MlirEmitterInput input);
 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
   Status EmitCholeskyThunkFromMlir(MlirEmitterInput input);
 #endif  // (defined(GOOGLE_CUDA) && GOOGLE_CUDA)