From 4aaabc836ea153118356377de809e52aff4de58a Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Wed, 4 Mar 2020 06:55:33 -0800 Subject: [PATCH] [XLA][MLIR] Reduce code bloat for LHLO->STD and HLO->STD patterns. PiperOrigin-RevId: 298840878 Change-Id: I781008f01b5c8e478d75ba282db9aa78da546ea1 --- tensorflow/compiler/mlir/xla/BUILD | 12 +- .../xla/transforms/hlo_legalize_to_lhlo.cc | 66 +-- .../xla/transforms/lhlo_legalize_to_affine.cc | 12 +- .../mlir/xla/transforms/map_hlo_to_lhlo_op.h | 70 +++ .../xla/transforms/map_xla_to_scalar_op.h | 471 +++++++----------- .../xla/transforms/xla_legalize_to_linalg.cc | 10 +- 6 files changed, 316 insertions(+), 325 deletions(-) create mode 100644 tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 3ae9c5549b3..830fb0789ba 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -133,16 +133,25 @@ cc_library( cc_library( name = "map_xla_to_scalar_op", - srcs = [], hdrs = ["transforms/map_xla_to_scalar_op.h"], deps = [ ":hlo", ":lhlo", + ":map_hlo_to_lhlo_op", "@llvm-project//llvm:support", "@llvm-project//mlir:StandardOps", ], ) +cc_library( + name = "map_hlo_to_lhlo_op", + hdrs = ["transforms/map_hlo_to_lhlo_op.h"], + deps = [ + ":hlo", + ":lhlo", + ], +) + cc_library( name = "hlo_shape_derivation", srcs = [], @@ -234,6 +243,7 @@ cc_library( ":hlo", ":hlo_shape_derivation", ":lhlo", + ":map_hlo_to_lhlo_op", "@com_google_absl//absl/memory", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc index 068b5765886..0e8b342fb72 100644 --- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h" +#include "tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" @@ -117,7 +118,7 @@ Value InsertAllocAndDealloc(Location loc, Value result, return alloc; } -template +template class HloToLhloOpConverter : public ConversionPattern { public: explicit HloToLhloOpConverter(MLIRContext* context) @@ -147,14 +148,14 @@ class HloToLhloOpConverter : public ConversionPattern { op->getLoc(), result.value(), shape_value, &rewriter)); } } - rewriter.create(op->getLoc(), llvm::None, buffer_args, - op->getAttrs()); + rewriter.create>(op->getLoc(), llvm::None, + buffer_args, op->getAttrs()); rewriter.replaceOp(op, ArrayRef(buffer_args).slice(operands.size())); return matchSuccess(); } }; -struct HloToLHloDynamicBroadcastInDimOpConverter +struct HloToLhloDynamicBroadcastInDimOpConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -178,7 +179,7 @@ struct HloToLHloDynamicBroadcastInDimOpConverter } }; -struct HloToLHloReduceOpConverter +struct HloToLhloReduceOpConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -438,36 +439,35 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { // clang-format off patterns->insert< - HloToLHloDynamicBroadcastInDimOpConverter, + HloToLhloDynamicBroadcastInDimOpConverter, HloToLhloFuncOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLhloOpConverter, - HloToLHloReduceOpConverter, - StdToLhloReturnOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloReduceOpConverter, HloToLhloTensorLoadOpConverter, - HloToLhloTensorStoreOpConverter + HloToLhloTensorStoreOpConverter, + StdToLhloReturnOpConverter >(context); // clang-format on } diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc index 2c550465302..32053950fed 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc @@ -31,11 +31,11 @@ namespace mlir { namespace xla_lhlo { namespace { -template -struct BinaryOpConverter : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +template +struct BinaryOpConverter : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(LhloOp op, + PatternMatchResult matchAndRewrite(LhloOpTy op, PatternRewriter& rewriter) const override { const auto& lhs = op.lhs(); const auto& rhs = op.rhs(); @@ -56,8 +56,8 @@ struct BinaryOpConverter : public OpRewritePattern { } auto l = rewriter.create(loc, lhs, induction_vars); auto r = rewriter.create(loc, rhs, induction_vars); - Value opResult = MapXlaOpToStdScalarOp( - llvm::cast(op), element_type, {l, r}, &rewriter); + Value opResult = xla_lhlo::XlaOpToStdScalarOp::map( + op, element_type, {l, r}, &rewriter); if (opResult == nullptr) { return this->matchFailure(); } diff --git a/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h b/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h new file mode 100644 index 00000000000..9852c4a60dc --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h @@ -0,0 +1,70 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_ + +#include + +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" +#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" + +namespace mlir { +namespace xla_hlo { + +template +struct HloToLhloOpImpl { + using Type = std::false_type; +}; +template +using HloToLhloOp = typename HloToLhloOpImpl::Type; + +#define MAP_HLO_TO_LHLO(OpName) \ + template <> \ + struct HloToLhloOpImpl { \ + using Type = xla_lhlo::OpName; \ + } + +MAP_HLO_TO_LHLO(AbsOp); +MAP_HLO_TO_LHLO(AddOp); +MAP_HLO_TO_LHLO(AndOp); +MAP_HLO_TO_LHLO(BroadcastInDimOp); +MAP_HLO_TO_LHLO(CeilOp); +MAP_HLO_TO_LHLO(ConstOp); +MAP_HLO_TO_LHLO(CompareOp); +MAP_HLO_TO_LHLO(ConvertOp); +MAP_HLO_TO_LHLO(CopyOp); +MAP_HLO_TO_LHLO(CosOp); +MAP_HLO_TO_LHLO(DivOp); +MAP_HLO_TO_LHLO(ExpOp); +MAP_HLO_TO_LHLO(IotaOp); +MAP_HLO_TO_LHLO(LogOp); +MAP_HLO_TO_LHLO(MaxOp); +MAP_HLO_TO_LHLO(MinOp); +MAP_HLO_TO_LHLO(MulOp); +MAP_HLO_TO_LHLO(NegOp); +MAP_HLO_TO_LHLO(ReduceOp); +MAP_HLO_TO_LHLO(RemOp); +MAP_HLO_TO_LHLO(SelectOp); +MAP_HLO_TO_LHLO(SignOp); +MAP_HLO_TO_LHLO(SubOp); +MAP_HLO_TO_LHLO(TanhOp); + +#undef MAP_HLO_TO_LHLO + +} // namespace xla_hlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_ diff --git a/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h b/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h index aa3208f3e74..32489dea16d 100644 --- a/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h +++ b/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h @@ -21,81 +21,63 @@ limitations under the License. #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" +#include "tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h" namespace mlir { namespace xla_lhlo { +namespace impl { -template -struct ScalarOp; +// A struct to map LhloBinaryOpTy type to the corresponding floating-point and +// integer scalar operation types. +template +struct LhloToScalarOp; template <> -struct ScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::AddFOp; using IOp = ::mlir::AddIOp; }; template <> -struct ScalarOp { - using FOp = ::mlir::AddFOp; - using IOp = ::mlir::AddIOp; -}; -template <> -struct ScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::CmpFOp; using IOp = ::mlir::CmpIOp; }; template <> -struct ScalarOp { - using FOp = ::mlir::CmpFOp; - using IOp = ::mlir::CmpIOp; -}; -template <> -struct ScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::DivFOp; using IOp = ::mlir::SignedDivIOp; }; template <> -struct ScalarOp { - using FOp = ::mlir::DivFOp; - using IOp = ::mlir::SignedDivIOp; -}; -template <> -struct ScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::MulFOp; using IOp = ::mlir::MulIOp; }; template <> -struct ScalarOp { - using FOp = ::mlir::MulFOp; - using IOp = ::mlir::MulIOp; -}; -template <> -struct ScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::RemFOp; using IOp = ::mlir::SignedRemIOp; }; template <> -struct ScalarOp { - using FOp = ::mlir::RemFOp; - using IOp = ::mlir::SignedRemIOp; -}; -template <> -struct ScalarOp { - using FOp = ::mlir::SubFOp; - using IOp = ::mlir::SubIOp; -}; -template <> -struct ScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::SubFOp; using IOp = ::mlir::SubIOp; }; -template -using ScalarFOp = typename ScalarOp::FOp; -template -using ScalarIOp = typename ScalarOp::IOp; +template +struct ScalarOp { + using FOp = typename LhloToScalarOp::FOp; + using IOp = typename LhloToScalarOp::IOp; +}; + +// Alias for the map from LHLO binary op type to STD floating-point op type. +template +using ScalarFOp = typename ScalarOp::FOp; +// Alias for the map from LHLO binary op type to STD integer op type. +template +using ScalarIOp = typename ScalarOp::IOp; template -struct MapXlaOpToStdScalarOpImpl { +struct MapLhloOpToStdScalarOpImpl { Value operator()(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { return nullptr; @@ -103,7 +85,7 @@ struct MapXlaOpToStdScalarOpImpl { }; template -struct MapXlaOpToStdScalarOpImpl { +struct MapLhloOpToStdScalarOpImpl { Value operator()(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { return b->template create(loc, result_types, args, mlir::None); @@ -111,7 +93,7 @@ struct MapXlaOpToStdScalarOpImpl { }; template -struct MapXlaOpToStdScalarOpImpl { +struct MapLhloOpToStdScalarOpImpl { Value operator()(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { Type element_type = args.front().getType(); @@ -119,52 +101,34 @@ struct MapXlaOpToStdScalarOpImpl { return b->template create(loc, result_types, args, mlir::None); } - return MapXlaOpToStdScalarOpImpl{}(loc, result_types, args, b); + return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, b); } }; -template -inline Value MapXlaOpToStdScalarOp(XlaOp xla_op, ArrayRef result_types, - ArrayRef args, OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl, FloatType, - ScalarFOp>{}(xla_op.getLoc(), - result_types, args, b); -} - -// TODO(ravishankarm): Find a way to reduce code-bloat in HLO and LHLO -// specialization. -template <> -inline Value MapXlaOpToStdScalarOp(xla_lhlo::AbsOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); -} -template <> -inline Value MapXlaOpToStdScalarOp(xla_hlo::AbsOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); +// Inserts the computation that corresponds to the body of the loop for lowered +// LHLO unary/binary op. Returns the value for the result. +template +inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef args, OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl, FloatType, + ScalarFOp>{}(loc, result_types, + args, b); } template <> -inline Value MapXlaOpToStdScalarOp(xla_lhlo::AndOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); } + template <> -inline Value MapXlaOpToStdScalarOp(xla_hlo::AndOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); } template @@ -200,7 +164,8 @@ inline Optional getCmpPredicate( } template -inline Value MapXlaCompareOpToStdScalarOp(XLACompareOpTy xla_op, +inline Value MapXlaCompareOpToStdScalarOp(Location loc, + StringRef comparison_direction, ArrayRef result_types, ArrayRef args, OpBuilder* b) { const auto& lhs = args[0]; @@ -208,101 +173,60 @@ inline Value MapXlaCompareOpToStdScalarOp(XLACompareOpTy xla_op, Type element_type = lhs.getType(); if (element_type.isSignlessInteger()) { Optional predicate = - getCmpPredicate(xla_op.comparison_direction()); + getCmpPredicate(comparison_direction); assert(predicate.hasValue() && "expected valid comparison direction"); - return b->create>(xla_op.getLoc(), - predicate.getValue(), lhs, rhs); + return b->create>(loc, predicate.getValue(), lhs, + rhs); } if (element_type.isa()) { Optional predicate = - getCmpPredicate(xla_op.comparison_direction()); + getCmpPredicate(comparison_direction); assert(predicate.hasValue() && "expected valid comparison direction"); - return b->create>(xla_op.getLoc(), - predicate.getValue(), lhs, rhs); + return b->create>(loc, predicate.getValue(), lhs, + rhs); } return nullptr; } -template <> -inline Value MapXlaOpToStdScalarOp( - xla_lhlo::CompareOp xla_op, ArrayRef result_types, - ArrayRef args, OpBuilder* b) { - return MapXlaCompareOpToStdScalarOp(xla_op, result_types, - args, b); -} -template <> -inline Value MapXlaOpToStdScalarOp( - xla_hlo::CompareOp xla_op, ArrayRef result_types, - ArrayRef args, OpBuilder* b) { - return MapXlaCompareOpToStdScalarOp(xla_op, result_types, - args, b); -} template <> -inline Value MapXlaOpToStdScalarOp( - xla_lhlo::CopyOp xla_op, ArrayRef result_types, ArrayRef args, +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { return args.front(); } -template <> -inline Value MapXlaOpToStdScalarOp(xla_hlo::CopyOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return args.front(); -} template <> -inline Value MapXlaOpToStdScalarOp(xla_lhlo::ExpOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); -} -template <> -inline Value MapXlaOpToStdScalarOp(xla_hlo::ExpOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); -} - -template <> -inline Value MapXlaOpToStdScalarOp( - xla_lhlo::CeilOp xla_op, ArrayRef result_types, ArrayRef args, +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); -} -template <> -inline Value MapXlaOpToStdScalarOp(xla_hlo::CeilOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); } template <> -inline Value MapXlaOpToStdScalarOp( - xla_lhlo::ConvertOp xla_op, ArrayRef result_types, - ArrayRef args, OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { Type sourceType = args.front().getType(); Type targetType = result_types.front(); if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) { - return b->create(xla_op.getLoc(), result_types, args, - mlir::None); + return b->create(loc, result_types, args, mlir::None); } else if (sourceType.isa() && targetType.isa()) { FloatType src = sourceType.cast(); FloatType res = targetType.cast(); if (src.getWidth() > res.getWidth()) { - return b->create(xla_op.getLoc(), result_types, args, - mlir::None); + return b->create(loc, result_types, args, mlir::None); } else if (src.getWidth() < res.getWidth()) { - return b->create(xla_op.getLoc(), result_types, args, - mlir::None); + return b->create(loc, result_types, args, mlir::None); } // No conversion is needed for the same width floats return args.front(); @@ -311,10 +235,9 @@ inline Value MapXlaOpToStdScalarOp( IntegerType src = sourceType.cast(); IntegerType res = targetType.cast(); if (src.getWidth() > res.getWidth()) { - return b->create(xla_op.getLoc(), result_types, args, - mlir::None); + return b->create(loc, result_types, args, mlir::None); } else if (src.getWidth() < res.getWidth()) { - return b->create(xla_op.getLoc(), result_types, args, + return b->create(loc, result_types, args, mlir::None); } // No conversion is needed for the same width integers @@ -322,35 +245,25 @@ inline Value MapXlaOpToStdScalarOp( } // TODO(dfki-ehna): Add other primitive type conversions // if (mlir::FpToSiOp::areCastCompatible(sourceType, targetType)) { - // return b.create(xla_op.getLoc(), result_types, + // return b.create(loc, result_types, // args,mlir::None); // } - return nullptr; } template <> -inline Value MapXlaOpToStdScalarOp(xla_lhlo::CosOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); -} -template <> -inline Value MapXlaOpToStdScalarOp(xla_hlo::CosOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); } /// Implements the conversion of XLA op to scalar op (to use within region of a /// linalg.generic op) for compare-select style operations like min/max. template -struct MapXlaCompareSelectOpToStdScalarOp { - Value operator()(Location loc, StringRef comparison_direction, +struct XlaCompareSelectOpToStdScalarOp { + static Value map(Location loc, StringRef comparison_direction, ArrayRef result_types, ArrayRef args, OpBuilder* b) { return nullptr; @@ -361,9 +274,9 @@ struct MapXlaCompareSelectOpToStdScalarOp { /// dialect with a given predicate based on the element type of the operand. template -struct MapXlaCompareSelectOpToStdScalarOp { - Value operator()(Location loc, StringRef comparison_direction, +struct XlaCompareSelectOpToStdScalarOp { + static Value map(Location loc, StringRef comparison_direction, ArrayRef result_types, ArrayRef args, OpBuilder* b) { Type element_type = args.front().getType(); @@ -374,132 +287,130 @@ struct MapXlaCompareSelectOpToStdScalarOpcreate<::mlir::SelectOp>(loc, cmp, args[0], args[1]); } - return MapXlaCompareSelectOpToStdScalarOp{}( + return XlaCompareSelectOpToStdScalarOp::map( loc, comparison_direction, result_types, args, b); } }; template <> -inline Value MapXlaOpToStdScalarOp(xla_lhlo::LogOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); -} -template <> -inline Value MapXlaOpToStdScalarOp(xla_hlo::LogOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); -} - -template <> -inline Value MapXlaOpToStdScalarOp(xla_lhlo::MaxOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaCompareSelectOpToStdScalarOp< - IntegerType, ScalarIOp, CmpIPredicate, FloatType, - ScalarFOp, CmpFPredicate>{}(xla_op.getLoc(), "GT", - result_types, args, b); -} -template <> -inline Value MapXlaOpToStdScalarOp(xla_hlo::MaxOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaCompareSelectOpToStdScalarOp< - IntegerType, ScalarIOp, CmpIPredicate, FloatType, - ScalarFOp, CmpFPredicate>{}(xla_op.getLoc(), "GT", - result_types, args, b); -} - -template <> -inline Value MapXlaOpToStdScalarOp(xla_lhlo::MinOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaCompareSelectOpToStdScalarOp< - IntegerType, ScalarIOp, CmpIPredicate, FloatType, - ScalarFOp, CmpFPredicate>{}(xla_op.getLoc(), "LT", - result_types, args, b); -} -template <> -inline Value MapXlaOpToStdScalarOp(xla_hlo::MinOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaCompareSelectOpToStdScalarOp< - IntegerType, ScalarIOp, CmpIPredicate, FloatType, - ScalarFOp, CmpFPredicate>{}(xla_op.getLoc(), "LT", - result_types, args, b); -} - -template <> -inline Value MapXlaOpToStdScalarOp(xla_lhlo::NegOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); -} -template <> -inline Value MapXlaOpToStdScalarOp(xla_hlo::NegOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); -} - -template <> -inline Value MapXlaOpToStdScalarOp( - xla_lhlo::SelectOp xla_op, ArrayRef result_types, - ArrayRef args, OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl<::mlir::SelectOp>{}(xla_op.getLoc(), - result_types, args, b); -} -template <> -inline Value MapXlaOpToStdScalarOp( - xla_hlo::SelectOp xla_op, ArrayRef result_types, ArrayRef args, +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl<::mlir::SelectOp>{}(xla_op.getLoc(), - result_types, args, b); + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); } template <> -inline Value MapXlaOpToStdScalarOp( - xla_lhlo::SignOp xla_op, ArrayRef result_types, ArrayRef args, +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return XlaCompareSelectOpToStdScalarOp< + IntegerType, ScalarIOp, CmpIPredicate, FloatType, + ScalarFOp, CmpFPredicate>::map(loc, "GT", + result_types, args, + b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return XlaCompareSelectOpToStdScalarOp< + IntegerType, ScalarIOp, CmpIPredicate, FloatType, + ScalarFOp, CmpFPredicate>::map(loc, "LT", + result_types, args, + b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args, + b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { Type element_type = args.front().getType(); if (element_type.isa()) { FloatType float_type = element_type.cast(); APFloat const_value = float_type.isF32() ? APFloat(1.0f) : APFloat(1.0); - Value one = b->create(xla_op.getLoc(), const_value, - float_type); - return b->create<::mlir::CopySignOp>(xla_op.getLoc(), result_types, one, - args[0]); + Value one = b->create(loc, const_value, float_type); + return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]); } return nullptr; } template <> -inline Value MapXlaOpToStdScalarOp( - xla_lhlo::TanhOp xla_op, ArrayRef result_types, ArrayRef args, +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); } -template <> -inline Value MapXlaOpToStdScalarOp(xla_hlo::TanhOp xla_op, - ArrayRef result_types, - ArrayRef args, - OpBuilder* b) { - return MapXlaOpToStdScalarOpImpl{}( - xla_op.getLoc(), result_types, args, b); + +} // namespace impl + +struct XlaOpToStdScalarOp { + // Implementation for LHLO ops except xla_lhlo::CompareOp. + template ::value && + std::is_same, + std::false_type>::value>> + static Value map(XlaOpTy op, ArrayRef result_types, + ArrayRef args, OpBuilder* b, unsigned i = 0) { + return impl::MapLhloOpToStdScalarOp(op.getLoc(), result_types, + args, b); + } + + // Implementation for HLO ops except xla_hlo::CompareOp. + template , + typename = std::enable_if_t< + !std::is_same::value && + !std::is_same::value>> + static Value map(XlaOpTy op, ArrayRef result_types, + ArrayRef args, OpBuilder* b, int i = 0) { + return impl::MapLhloOpToStdScalarOp(op.getLoc(), result_types, + args, b); + } + + // Implementation for xla_lhlo::CompareOp. + template ::value>> + static Value map(xla_lhlo::CompareOp op, ArrayRef result_types, + ArrayRef args, OpBuilder* b) { + auto comparison_direction = op.comparison_direction(); + return impl::MapXlaCompareOpToStdScalarOp( + op.getLoc(), comparison_direction, result_types, args, b); + } + + // Implementation for xla_hlo::CompareOp. + template ::value>> + static Value map(xla_hlo::CompareOp op, ArrayRef result_types, + ArrayRef args, OpBuilder* b) { + auto comparison_direction = op.comparison_direction(); + return impl::MapXlaCompareOpToStdScalarOp( + op.getLoc(), comparison_direction, result_types, args, b); + } +}; + +template +inline Value MapXlaOpToStdScalarOp(XlaOpTy xla_op, ArrayRef result_types, + ArrayRef args, OpBuilder* b) { + return XlaOpToStdScalarOp::map(xla_op, result_types, args, b); } } // namespace xla_lhlo diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc index 3548c2b6c62..5623dadcabc 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc @@ -149,8 +149,8 @@ class PointwiseToLinalgConverter : public OpConversionPattern { rewriter.setInsertionPointToEnd(block); // TODO(ravishankarm) : For now use the method in xla_lhlo namespace. That // method needs to be moved out of there. - Value opResult = xla_lhlo::MapXlaOpToStdScalarOp( - llvm::cast(op), bodyResultTypes, bodyArgs, &rewriter); + Value opResult = xla_lhlo::XlaOpToStdScalarOp::map( + op, bodyResultTypes, bodyArgs, &rewriter); if (!opResult) { return ConversionPattern::matchFailure(); } @@ -180,9 +180,9 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern { auto lhs = rewriter.create(loc, lhlo_op.lhs()); auto rhs = rewriter.create(loc, lhlo_op.rhs()); // TODO(ravishankarm) : Move this method out of xla_lhlo namespace. - Value opResult = xla_lhlo::MapXlaOpToStdScalarOp( - llvm::cast(lhlo_op), argType.getElementType(), - llvm::ArrayRef{lhs, rhs}, &rewriter); + Value opResult = xla_lhlo::XlaOpToStdScalarOp::map( + lhlo_op, argType.getElementType(), llvm::ArrayRef{lhs, rhs}, + &rewriter); rewriter.create(loc, opResult, lhlo_op.out()); rewriter.eraseOp(lhlo_op); return ConversionPattern::matchSuccess();