From aee0999827f45f5865b3e34763ba92a63780e37f Mon Sep 17 00:00:00 2001 From: Rahul Joshi <jurahul@google.com> Date: Mon, 21 Dec 2020 10:18:39 -0800 Subject: [PATCH] [XLA:GPU] Convert batch norm custom calls to LHLO GPU dialect operations - Note that custom call for batch norm inference is not currently generated by XLA:GPU, so that case is not handled here. - Fix issues in ThunkEmission for batch norm forward inference. This path is currently not exercised by XLA because BN inference is expanded earlier, but enabled it locally to test conversion to MLIR. PiperOrigin-RevId: 348486912 Change-Id: If6c2f04c35b5fa2d1e5535ad2e164eeec534239b --- tensorflow/compiler/mlir/xla/BUILD | 1 + .../hlo_text_to_lhlo_no_opt.hlotxt | 65 +++++++++++++++++ .../xla/transforms/mhlo_to_lhlo_with_xla.cc | 71 ++++++++++++++++--- .../xla/transforms/mhlo_to_lhlo_with_xla.h | 22 +++--- .../xla/service/gpu/ir_emitter_unnested.cc | 14 +++- .../compiler/xla/service/gpu/thunk_emitter.cc | 4 +- 6 files changed, 156 insertions(+), 21 deletions(-) diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 1b40370f94a..4d016d61114 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -155,6 +155,7 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/types:optional", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt index 34db86a4464..358454f339b 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt @@ -334,4 +334,69 @@ ENTRY main { %bias = f16[32]{0} parameter(2) %side = f16[32]{0} parameter(3) ROOT %custom-call.2 = (f16[1,32,9,9]{1,3,2,0}, u8[0]{0}) custom-call(f16[1,17,9,9]{1,3,2,0} %input, f16[3,3,17,32]{2,1,0,3} %filter, f16[32]{0} %bias, f16[32]{0} %side), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convBiasActivationForward", backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"2\",\"side_input_scale\":1}" +} + +// ----- + +HloModule BatchNormForwardTraining + +// CHECK: func @main +// CHECK: "lmhlo_gpu.batch_norm_training" +// CHECK-SAME: epsilon = 1.000000e-03 : f32 +// CHECK-SAME: feature_index = 3 : i64 +// CHECK-SAME: (memref<1x1x10x1xf32>, memref<1xf32>, memref<1xf32>, memref<1x1x10x1xf32>, memref<1xf32>, memref<1xf32>) -> () + +ENTRY main { + %input = f32[1,1,10,1]{3,2,1,0} parameter(0) + %scale = f32[1]{0} parameter(1) + %offset = f32[1]{0} parameter(2) + %constant = f32[] constant(0.001) + %constant_1 = s64[] constant(3) + %custom-call = (f32[1,1,10,1]{3,2,1,0}, f32[1]{0}, f32[1]{0}) + custom-call(f32[1,1,10,1]{3,2,1,0} %input, f32[1]{0} %scale, f32[1]{0} %offset, f32[] %constant, s64[] %constant_1), + custom_call_target="__cudnn$batchNormalizationForwardTraining" +} + +// ----- + +HloModule BatchNormBackward + +// CHECK: func @main +// CHECK: "lmhlo_gpu.batch_norm_grad" +// CHECK-SAME: epsilon = 1.000000e-03 : f32 +// CHECK-SAME: feature_index = 2 : i64 +// CHECK-SAME: (memref<2x2x2x1xf16>, memref<2xf32>, memref<2xf32>, memref<2xf32>, memref<2x2x2x1xf16>, memref<2x2x2x1xf16>, memref<2xf32>, memref<2xf32>) +ENTRY main { + %input = f16[2,2,2,1]{3,2,1,0} parameter(0) + %scale = f32[2]{0} parameter(1) + %mean = f32[2]{0} parameter(2) + %stddev = f32[2]{0} parameter(3) + %grad = f16[2,2,2,1]{3,2,1,0} parameter(4) + %constant = f32[] constant(0.001) + %constant_2 = s64[] constant(2) + ROOT %custom-call = (f16[2,2,2,1]{3,2,1,0}, f32[2]{0}, f32[2]{0}) + custom-call(f16[2,2,2,1]{3,2,1,0} %input, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %stddev, f16[2,2,2,1]{3,2,1,0} %grad, f32[] %constant, s64[] %constant_2), + custom_call_target="__cudnn$batchNormalizationBackward" +} + +// ----- + +HloModule BatchNormForwardInference + +// CHECK: func @main +// CHECK: lmhlo_gpu.batch_norm_inference" +// CHECK-SAME: epsilon = 1.000000e-03 : f32 +// CHECK-SAME: feature_index = 0 : i64 +// CHECK-SAME: (memref<2x2x2x2xf32>, memref<2xf32>, memref<2xf32>, memref<2xf32>, memref<2xf32>, memref<2x2x2x2xf32>) -> () +ENTRY main { + %input = f32[2,2,2,2]{3,2,1,0} parameter(0) + %offset = f32[2]{0} parameter(1) + %scale = f32[2]{0} parameter(2) + %mean = f32[2]{0} parameter(3) + %variance = f32[2]{0} parameter(4) + %constant = f32[] constant(0.001) + %constant_1 = s64[] constant(0) + ROOT %custom-call = f32[2,2,2,2]{3,2,1,0} + custom-call(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %offset, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[] %constant, s64[] %constant_1), + custom_call_target="__cudnn$batchNormalizationForwardInference" } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc index f4b216577aa..3bc8afc4ffe 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc @@ -198,11 +198,16 @@ class XlaHloToLhloPass } // namespace +// Creates MLIR operands corresponding to operands and results of the XLA HLO +// instruction. If `num_operands` is not -1, then only the first `num_operands` +// operands of the HLO instruction will be considered. Status LhloDialectEmitter::CreateOperands( HloInstruction* instr, llvm::SmallVectorImpl<Value>& operands, - size_t& num_arguments, size_t& num_results) { - for (const HloInstruction* operand : instr->operands()) { - TF_RETURN_IF_ERROR(GetOrCreateView(operand, &operands)); + size_t& num_arguments, size_t& num_results, + absl::optional<xla::int64> num_operands) { + for (xla::int64 i = 0; i < num_operands.value_or(instr->operand_count()); + i++) { + TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(i), &operands)); } num_arguments = operands.size(); TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands)); @@ -211,17 +216,17 @@ Status LhloDialectEmitter::CreateOperands( } template <typename OpType> -StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs(HloInstruction* instr, - size_t& num_arguments, - size_t& num_results) { +StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs( + HloInstruction* instr, size_t& num_arguments, size_t& num_results, + absl::optional<xla::int64> num_operands) { Location loc = getLocation(instr); std::pair<Identifier, Attribute> attrs[] = { {Identifier::get("name", builder_.getContext()), builder_.getStringAttr(instr->name())}, }; llvm::SmallVector<Value, 4> operands; - TF_RETURN_IF_ERROR( - CreateOperands(instr, operands, num_arguments, num_results)); + TF_RETURN_IF_ERROR(CreateOperands(instr, operands, num_arguments, num_results, + num_operands)); return builder_.create<OpType>(loc, llvm::None, operands, attrs); } @@ -562,6 +567,10 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitCustomCallOp( return EmitDnnConvolution(custom_call_instr); } + if (xla::gpu::IsCustomCallToDnnBatchNorm(*instr)) { + return EmitDnnBatchNorm(custom_call_instr); + } + size_t num_arguments, num_results; TF_ASSIGN_OR_RETURN(auto custom_call, CreateOpWithoutAttrs<lmhlo::CustomCallOp>( @@ -773,6 +782,52 @@ StatusOr<Operation*> LhloDialectEmitter::EmitDnnConvolution( } } +StatusOr<Operation*> LhloDialectEmitter::EmitDnnBatchNorm( + HloCustomCallInstruction* custom_call) { + const xla::int64 num_operands = custom_call->operand_count(); + auto set_batchnorm_attributes = [&](auto op) -> StatusOr<Operation*> { + // The last 2 operands of a custom call for batch norm are the epsilon and + // feature_index. + const HloInstruction* epsilon = custom_call->operand(num_operands - 2); + TF_RET_CHECK(epsilon->IsConstant()); + float epsilon_value = epsilon->literal().Get<float>({}); + + const HloInstruction* feature_index = + custom_call->operand(num_operands - 1); + TF_RET_CHECK(feature_index->IsConstant()); + xla::int64 feature_index_value = + feature_index->literal().Get<xla::int64>({}); + + op.epsilonAttr(builder_.getF32FloatAttr(epsilon_value)); + op.feature_indexAttr(builder_.getI64IntegerAttr(feature_index_value)); + return op.getOperation(); + }; + + const std::string& target = custom_call->custom_call_target(); + if (target == xla::gpu::kCudnnBatchNormForwardTrainingCallTarget) { + TF_ASSIGN_OR_RETURN(auto fwd_training, + CreateOpWithoutAttrs<lmhlo_gpu::BatchNormTrainingOp>( + custom_call, num_operands - 2)); + return set_batchnorm_attributes(fwd_training); + } + + if (target == xla::gpu::kCudnnBatchNormBackwardCallTarget) { + TF_ASSIGN_OR_RETURN(auto backward, + CreateOpWithoutAttrs<lmhlo_gpu::BatchNormGradOp>( + custom_call, num_operands - 2)); + return set_batchnorm_attributes(backward); + } + + if (target == xla::gpu::kCudnnBatchNormForwardInferenceCallTarget) { + TF_ASSIGN_OR_RETURN(auto fwd_inference, + CreateOpWithoutAttrs<lmhlo_gpu::BatchNormInferenceOp>( + custom_call, num_operands - 2)); + return set_batchnorm_attributes(fwd_inference); + } + + return xla::Unimplemented("Unsupported batch norm operation"); +} + // Convert an XLA HLO constant to a global_memref + get_global_memref pair. StatusOr<mlir::GetGlobalMemrefOp> LhloDialectEmitter::EmitConstant( const HloInstruction* instr) { diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h index 3c42d1012b1..73d15930e52 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_ #define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_ +#include "absl/types/optional.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -65,6 +66,8 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { ::xla::HloCustomCallInstruction* custom_call); ::xla::StatusOr<Operation*> EmitDnnConvolution( ::xla::HloCustomCallInstruction* custom_call); + ::xla::StatusOr<Operation*> EmitDnnBatchNorm( + ::xla::HloCustomCallInstruction* custom_call); ::xla::StatusOr<lmhlo::ReduceOp> EmitReduceOp(::xla::HloInstruction* instr); ::xla::StatusOr<GetGlobalMemrefOp> EmitConstant( @@ -84,20 +87,23 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { ::xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp( ::xla::HloInstruction* instr); - ::xla::Status CreateOperands(::xla::HloInstruction* instr, - SmallVectorImpl<Value>& operands, - size_t& num_arguments, size_t& num_results); + ::xla::Status CreateOperands( + ::xla::HloInstruction* instr, SmallVectorImpl<Value>& operands, + size_t& num_arguments, size_t& num_results, + absl::optional<xla::int64> num_operands = absl::nullopt); template <typename OpType> - ::xla::StatusOr<OpType> CreateOpWithoutAttrs(::xla::HloInstruction* instr) { + ::xla::StatusOr<OpType> CreateOpWithoutAttrs( + ::xla::HloInstruction* instr, + absl::optional<xla::int64> num_operands = absl::nullopt) { size_t unused; - return CreateOpWithoutAttrs<OpType>(instr, unused, unused); + return CreateOpWithoutAttrs<OpType>(instr, unused, unused, num_operands); } template <typename OpType> - ::xla::StatusOr<OpType> CreateOpWithoutAttrs(::xla::HloInstruction* instr, - size_t& num_arguments, - size_t& num_results); + ::xla::StatusOr<OpType> CreateOpWithoutAttrs( + ::xla::HloInstruction* instr, size_t& num_arguments, size_t& num_results, + absl::optional<xla::int64> num_operands = absl::nullopt); template <typename T> DenseIntElementsAttr GetI64DenseElementsAttr(const T& container) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index ed6c831e1be..ba720872129 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -1031,9 +1031,12 @@ Status IrEmitterUnnested::EmitSliceToDynamicFromMlir( } Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { + using mlir::dyn_cast; + using mlir::isa; + TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(custom_call)); - if (auto call = mlir::dyn_cast<mlir::lmhlo::CustomCallOp>(input.op)) { + if (auto call = dyn_cast<mlir::lmhlo::CustomCallOp>(input.op)) { if (call.call_target_name() == "PadToStatic") { return EmitPadToStaticFromMlir(input); } @@ -1043,8 +1046,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { return ThunkEmitter(this).HandleCustomCall(custom_call); } - if (mlir::isa<mlir::lmhlo_gpu::GEMMOp, mlir::lmhlo_gpu::GEMM_BiasOp>( - input.op)) { + if (isa<mlir::lmhlo_gpu::GEMMOp, mlir::lmhlo_gpu::GEMM_BiasOp>(input.op)) { return EmitGemmThunkFromMlir(input); } @@ -1056,6 +1058,12 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { return EmitConvolutionThunkFromMlir(input); } + if (isa<mlir::lmhlo_gpu::BatchNormTrainingOp, + mlir::lmhlo_gpu::BatchNormInferenceOp, + mlir::lmhlo_gpu::BatchNormGradOp>(input.op)) { + return ThunkEmitter(this).HandleCustomCall(custom_call); + } + #if GOOGLE_CUDA if (mlir::isa<mlir::lmhlo_gpu::CholeskyOp>(input.op)) { return EmitCholeskyThunkFromMlir(input); diff --git a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc index 4310edff679..568e153c518 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc @@ -219,8 +219,8 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { CHECK(feature_index->IsConstant()); int64 feature_index_value = feature_index->literal().Get<int64>({}); - CHECK_EQ(custom_call->shape().tuple_shapes_size(), 3); - CHECK(LayoutUtil::LayoutsInShapesEqual(custom_call->shape().tuple_shapes(0), + CHECK(custom_call->shape().IsArray()); + CHECK(LayoutUtil::LayoutsInShapesEqual(custom_call->shape(), custom_call->operand(0)->shape())); CheckBatchNormInputOutputPrimitivetypeAreValid(custom_call); CudnnBatchNormConfig config = GetCudnnBatchNormConfig(