[XLA:GPU] Convert batch norm custom calls to LHLO GPU dialect operations
- Note that custom call for batch norm inference is not currently generated by XLA:GPU, so that case is not handled here. - Fix issues in ThunkEmission for batch norm forward inference. This path is currently not exercised by XLA because BN inference is expanded earlier, but enabled it locally to test conversion to MLIR. PiperOrigin-RevId: 348486912 Change-Id: If6c2f04c35b5fa2d1e5535ad2e164eeec534239b
This commit is contained in:
parent
f6972931f2
commit
aee0999827
@ -155,6 +155,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
|
@ -334,4 +334,69 @@ ENTRY main {
|
||||
%bias = f16[32]{0} parameter(2)
|
||||
%side = f16[32]{0} parameter(3)
|
||||
ROOT %custom-call.2 = (f16[1,32,9,9]{1,3,2,0}, u8[0]{0}) custom-call(f16[1,17,9,9]{1,3,2,0} %input, f16[3,3,17,32]{2,1,0,3} %filter, f16[32]{0} %bias, f16[32]{0} %side), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convBiasActivationForward", backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"2\",\"side_input_scale\":1}"
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
HloModule BatchNormForwardTraining
|
||||
|
||||
// CHECK: func @main
|
||||
// CHECK: "lmhlo_gpu.batch_norm_training"
|
||||
// CHECK-SAME: epsilon = 1.000000e-03 : f32
|
||||
// CHECK-SAME: feature_index = 3 : i64
|
||||
// CHECK-SAME: (memref<1x1x10x1xf32>, memref<1xf32>, memref<1xf32>, memref<1x1x10x1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
|
||||
ENTRY main {
|
||||
%input = f32[1,1,10,1]{3,2,1,0} parameter(0)
|
||||
%scale = f32[1]{0} parameter(1)
|
||||
%offset = f32[1]{0} parameter(2)
|
||||
%constant = f32[] constant(0.001)
|
||||
%constant_1 = s64[] constant(3)
|
||||
%custom-call = (f32[1,1,10,1]{3,2,1,0}, f32[1]{0}, f32[1]{0})
|
||||
custom-call(f32[1,1,10,1]{3,2,1,0} %input, f32[1]{0} %scale, f32[1]{0} %offset, f32[] %constant, s64[] %constant_1),
|
||||
custom_call_target="__cudnn$batchNormalizationForwardTraining"
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
HloModule BatchNormBackward
|
||||
|
||||
// CHECK: func @main
|
||||
// CHECK: "lmhlo_gpu.batch_norm_grad"
|
||||
// CHECK-SAME: epsilon = 1.000000e-03 : f32
|
||||
// CHECK-SAME: feature_index = 2 : i64
|
||||
// CHECK-SAME: (memref<2x2x2x1xf16>, memref<2xf32>, memref<2xf32>, memref<2xf32>, memref<2x2x2x1xf16>, memref<2x2x2x1xf16>, memref<2xf32>, memref<2xf32>)
|
||||
ENTRY main {
|
||||
%input = f16[2,2,2,1]{3,2,1,0} parameter(0)
|
||||
%scale = f32[2]{0} parameter(1)
|
||||
%mean = f32[2]{0} parameter(2)
|
||||
%stddev = f32[2]{0} parameter(3)
|
||||
%grad = f16[2,2,2,1]{3,2,1,0} parameter(4)
|
||||
%constant = f32[] constant(0.001)
|
||||
%constant_2 = s64[] constant(2)
|
||||
ROOT %custom-call = (f16[2,2,2,1]{3,2,1,0}, f32[2]{0}, f32[2]{0})
|
||||
custom-call(f16[2,2,2,1]{3,2,1,0} %input, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %stddev, f16[2,2,2,1]{3,2,1,0} %grad, f32[] %constant, s64[] %constant_2),
|
||||
custom_call_target="__cudnn$batchNormalizationBackward"
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
HloModule BatchNormForwardInference
|
||||
|
||||
// CHECK: func @main
|
||||
// CHECK: lmhlo_gpu.batch_norm_inference"
|
||||
// CHECK-SAME: epsilon = 1.000000e-03 : f32
|
||||
// CHECK-SAME: feature_index = 0 : i64
|
||||
// CHECK-SAME: (memref<2x2x2x2xf32>, memref<2xf32>, memref<2xf32>, memref<2xf32>, memref<2xf32>, memref<2x2x2x2xf32>) -> ()
|
||||
ENTRY main {
|
||||
%input = f32[2,2,2,2]{3,2,1,0} parameter(0)
|
||||
%offset = f32[2]{0} parameter(1)
|
||||
%scale = f32[2]{0} parameter(2)
|
||||
%mean = f32[2]{0} parameter(3)
|
||||
%variance = f32[2]{0} parameter(4)
|
||||
%constant = f32[] constant(0.001)
|
||||
%constant_1 = s64[] constant(0)
|
||||
ROOT %custom-call = f32[2,2,2,2]{3,2,1,0}
|
||||
custom-call(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %offset, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[] %constant, s64[] %constant_1),
|
||||
custom_call_target="__cudnn$batchNormalizationForwardInference"
|
||||
}
|
@ -198,11 +198,16 @@ class XlaHloToLhloPass
|
||||
|
||||
} // namespace
|
||||
|
||||
// Creates MLIR operands corresponding to operands and results of the XLA HLO
|
||||
// instruction. If `num_operands` is not -1, then only the first `num_operands`
|
||||
// operands of the HLO instruction will be considered.
|
||||
Status LhloDialectEmitter::CreateOperands(
|
||||
HloInstruction* instr, llvm::SmallVectorImpl<Value>& operands,
|
||||
size_t& num_arguments, size_t& num_results) {
|
||||
for (const HloInstruction* operand : instr->operands()) {
|
||||
TF_RETURN_IF_ERROR(GetOrCreateView(operand, &operands));
|
||||
size_t& num_arguments, size_t& num_results,
|
||||
absl::optional<xla::int64> num_operands) {
|
||||
for (xla::int64 i = 0; i < num_operands.value_or(instr->operand_count());
|
||||
i++) {
|
||||
TF_RETURN_IF_ERROR(GetOrCreateView(instr->operand(i), &operands));
|
||||
}
|
||||
num_arguments = operands.size();
|
||||
TF_RETURN_IF_ERROR(GetOrCreateView(instr, &operands));
|
||||
@ -211,17 +216,17 @@ Status LhloDialectEmitter::CreateOperands(
|
||||
}
|
||||
|
||||
template <typename OpType>
|
||||
StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs(HloInstruction* instr,
|
||||
size_t& num_arguments,
|
||||
size_t& num_results) {
|
||||
StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs(
|
||||
HloInstruction* instr, size_t& num_arguments, size_t& num_results,
|
||||
absl::optional<xla::int64> num_operands) {
|
||||
Location loc = getLocation(instr);
|
||||
std::pair<Identifier, Attribute> attrs[] = {
|
||||
{Identifier::get("name", builder_.getContext()),
|
||||
builder_.getStringAttr(instr->name())},
|
||||
};
|
||||
llvm::SmallVector<Value, 4> operands;
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateOperands(instr, operands, num_arguments, num_results));
|
||||
TF_RETURN_IF_ERROR(CreateOperands(instr, operands, num_arguments, num_results,
|
||||
num_operands));
|
||||
return builder_.create<OpType>(loc, llvm::None, operands, attrs);
|
||||
}
|
||||
|
||||
@ -562,6 +567,10 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitCustomCallOp(
|
||||
return EmitDnnConvolution(custom_call_instr);
|
||||
}
|
||||
|
||||
if (xla::gpu::IsCustomCallToDnnBatchNorm(*instr)) {
|
||||
return EmitDnnBatchNorm(custom_call_instr);
|
||||
}
|
||||
|
||||
size_t num_arguments, num_results;
|
||||
TF_ASSIGN_OR_RETURN(auto custom_call,
|
||||
CreateOpWithoutAttrs<lmhlo::CustomCallOp>(
|
||||
@ -773,6 +782,52 @@ StatusOr<Operation*> LhloDialectEmitter::EmitDnnConvolution(
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<Operation*> LhloDialectEmitter::EmitDnnBatchNorm(
|
||||
HloCustomCallInstruction* custom_call) {
|
||||
const xla::int64 num_operands = custom_call->operand_count();
|
||||
auto set_batchnorm_attributes = [&](auto op) -> StatusOr<Operation*> {
|
||||
// The last 2 operands of a custom call for batch norm are the epsilon and
|
||||
// feature_index.
|
||||
const HloInstruction* epsilon = custom_call->operand(num_operands - 2);
|
||||
TF_RET_CHECK(epsilon->IsConstant());
|
||||
float epsilon_value = epsilon->literal().Get<float>({});
|
||||
|
||||
const HloInstruction* feature_index =
|
||||
custom_call->operand(num_operands - 1);
|
||||
TF_RET_CHECK(feature_index->IsConstant());
|
||||
xla::int64 feature_index_value =
|
||||
feature_index->literal().Get<xla::int64>({});
|
||||
|
||||
op.epsilonAttr(builder_.getF32FloatAttr(epsilon_value));
|
||||
op.feature_indexAttr(builder_.getI64IntegerAttr(feature_index_value));
|
||||
return op.getOperation();
|
||||
};
|
||||
|
||||
const std::string& target = custom_call->custom_call_target();
|
||||
if (target == xla::gpu::kCudnnBatchNormForwardTrainingCallTarget) {
|
||||
TF_ASSIGN_OR_RETURN(auto fwd_training,
|
||||
CreateOpWithoutAttrs<lmhlo_gpu::BatchNormTrainingOp>(
|
||||
custom_call, num_operands - 2));
|
||||
return set_batchnorm_attributes(fwd_training);
|
||||
}
|
||||
|
||||
if (target == xla::gpu::kCudnnBatchNormBackwardCallTarget) {
|
||||
TF_ASSIGN_OR_RETURN(auto backward,
|
||||
CreateOpWithoutAttrs<lmhlo_gpu::BatchNormGradOp>(
|
||||
custom_call, num_operands - 2));
|
||||
return set_batchnorm_attributes(backward);
|
||||
}
|
||||
|
||||
if (target == xla::gpu::kCudnnBatchNormForwardInferenceCallTarget) {
|
||||
TF_ASSIGN_OR_RETURN(auto fwd_inference,
|
||||
CreateOpWithoutAttrs<lmhlo_gpu::BatchNormInferenceOp>(
|
||||
custom_call, num_operands - 2));
|
||||
return set_batchnorm_attributes(fwd_inference);
|
||||
}
|
||||
|
||||
return xla::Unimplemented("Unsupported batch norm operation");
|
||||
}
|
||||
|
||||
// Convert an XLA HLO constant to a global_memref + get_global_memref pair.
|
||||
StatusOr<mlir::GetGlobalMemrefOp> LhloDialectEmitter::EmitConstant(
|
||||
const HloInstruction* instr) {
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
@ -65,6 +66,8 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
|
||||
::xla::HloCustomCallInstruction* custom_call);
|
||||
::xla::StatusOr<Operation*> EmitDnnConvolution(
|
||||
::xla::HloCustomCallInstruction* custom_call);
|
||||
::xla::StatusOr<Operation*> EmitDnnBatchNorm(
|
||||
::xla::HloCustomCallInstruction* custom_call);
|
||||
|
||||
::xla::StatusOr<lmhlo::ReduceOp> EmitReduceOp(::xla::HloInstruction* instr);
|
||||
::xla::StatusOr<GetGlobalMemrefOp> EmitConstant(
|
||||
@ -84,20 +87,23 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
|
||||
::xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(
|
||||
::xla::HloInstruction* instr);
|
||||
|
||||
::xla::Status CreateOperands(::xla::HloInstruction* instr,
|
||||
SmallVectorImpl<Value>& operands,
|
||||
size_t& num_arguments, size_t& num_results);
|
||||
::xla::Status CreateOperands(
|
||||
::xla::HloInstruction* instr, SmallVectorImpl<Value>& operands,
|
||||
size_t& num_arguments, size_t& num_results,
|
||||
absl::optional<xla::int64> num_operands = absl::nullopt);
|
||||
|
||||
template <typename OpType>
|
||||
::xla::StatusOr<OpType> CreateOpWithoutAttrs(::xla::HloInstruction* instr) {
|
||||
::xla::StatusOr<OpType> CreateOpWithoutAttrs(
|
||||
::xla::HloInstruction* instr,
|
||||
absl::optional<xla::int64> num_operands = absl::nullopt) {
|
||||
size_t unused;
|
||||
return CreateOpWithoutAttrs<OpType>(instr, unused, unused);
|
||||
return CreateOpWithoutAttrs<OpType>(instr, unused, unused, num_operands);
|
||||
}
|
||||
|
||||
template <typename OpType>
|
||||
::xla::StatusOr<OpType> CreateOpWithoutAttrs(::xla::HloInstruction* instr,
|
||||
size_t& num_arguments,
|
||||
size_t& num_results);
|
||||
::xla::StatusOr<OpType> CreateOpWithoutAttrs(
|
||||
::xla::HloInstruction* instr, size_t& num_arguments, size_t& num_results,
|
||||
absl::optional<xla::int64> num_operands = absl::nullopt);
|
||||
|
||||
template <typename T>
|
||||
DenseIntElementsAttr GetI64DenseElementsAttr(const T& container) {
|
||||
|
@ -1031,9 +1031,12 @@ Status IrEmitterUnnested::EmitSliceToDynamicFromMlir(
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
|
||||
using mlir::dyn_cast;
|
||||
using mlir::isa;
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(custom_call));
|
||||
|
||||
if (auto call = mlir::dyn_cast<mlir::lmhlo::CustomCallOp>(input.op)) {
|
||||
if (auto call = dyn_cast<mlir::lmhlo::CustomCallOp>(input.op)) {
|
||||
if (call.call_target_name() == "PadToStatic") {
|
||||
return EmitPadToStaticFromMlir(input);
|
||||
}
|
||||
@ -1043,8 +1046,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
|
||||
return ThunkEmitter(this).HandleCustomCall(custom_call);
|
||||
}
|
||||
|
||||
if (mlir::isa<mlir::lmhlo_gpu::GEMMOp, mlir::lmhlo_gpu::GEMM_BiasOp>(
|
||||
input.op)) {
|
||||
if (isa<mlir::lmhlo_gpu::GEMMOp, mlir::lmhlo_gpu::GEMM_BiasOp>(input.op)) {
|
||||
return EmitGemmThunkFromMlir(input);
|
||||
}
|
||||
|
||||
@ -1056,6 +1058,12 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
|
||||
return EmitConvolutionThunkFromMlir(input);
|
||||
}
|
||||
|
||||
if (isa<mlir::lmhlo_gpu::BatchNormTrainingOp,
|
||||
mlir::lmhlo_gpu::BatchNormInferenceOp,
|
||||
mlir::lmhlo_gpu::BatchNormGradOp>(input.op)) {
|
||||
return ThunkEmitter(this).HandleCustomCall(custom_call);
|
||||
}
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
if (mlir::isa<mlir::lmhlo_gpu::CholeskyOp>(input.op)) {
|
||||
return EmitCholeskyThunkFromMlir(input);
|
||||
|
@ -219,8 +219,8 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
|
||||
CHECK(feature_index->IsConstant());
|
||||
int64 feature_index_value = feature_index->literal().Get<int64>({});
|
||||
|
||||
CHECK_EQ(custom_call->shape().tuple_shapes_size(), 3);
|
||||
CHECK(LayoutUtil::LayoutsInShapesEqual(custom_call->shape().tuple_shapes(0),
|
||||
CHECK(custom_call->shape().IsArray());
|
||||
CHECK(LayoutUtil::LayoutsInShapesEqual(custom_call->shape(),
|
||||
custom_call->operand(0)->shape()));
|
||||
CheckBatchNormInputOutputPrimitivetypeAreValid(custom_call);
|
||||
CudnnBatchNormConfig config = GetCudnnBatchNormConfig(
|
||||
|
Loading…
Reference in New Issue
Block a user