[XLA][MLIR] Reduce code bloat for LHLO->STD and HLO->STD patterns.

PiperOrigin-RevId: 298840878
Change-Id: I781008f01b5c8e478d75ba282db9aa78da546ea1
This commit is contained in:
Alexander Belyaev 2020-03-04 06:55:33 -08:00 committed by TensorFlower Gardener
parent 8a53e358fc
commit 4aaabc836e
6 changed files with 316 additions and 325 deletions

View File

@ -133,16 +133,25 @@ cc_library(
# Header-only library (srcs is empty) exposing transforms/map_xla_to_scalar_op.h,
# which presumably maps XLA HLO/LHLO ops to standard-dialect scalar ops — the
# name and deps (:hlo, :lhlo, :map_hlo_to_lhlo_op, mlir:StandardOps) suggest so;
# confirm against the header itself.
cc_library(
name = "map_xla_to_scalar_op",
srcs = [],
hdrs = ["transforms/map_xla_to_scalar_op.h"],
deps = [
":hlo",
":lhlo",
":map_hlo_to_lhlo_op",
"@llvm-project//llvm:support",
"@llvm-project//mlir:StandardOps",
],
)
# Header-only library exposing transforms/map_hlo_to_lhlo_op.h (the compile-time
# HLO-op -> LHLO-op type map). Depends only on the two dialect targets it maps
# between.
cc_library(
name = "map_hlo_to_lhlo_op",
hdrs = ["transforms/map_hlo_to_lhlo_op.h"],
deps = [
":hlo",
":lhlo",
],
)
cc_library(
name = "hlo_shape_derivation",
srcs = [],
@ -234,6 +243,7 @@ cc_library(
":hlo",
":hlo_shape_derivation",
":lhlo",
":map_hlo_to_lhlo_op",
"@com_google_absl//absl/memory",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h"
#include "tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
@ -117,7 +118,7 @@ Value InsertAllocAndDealloc(Location loc, Value result,
return alloc;
}
template <typename HloOpTy, typename LhloOpTy>
template <typename HloOpTy>
class HloToLhloOpConverter : public ConversionPattern {
public:
explicit HloToLhloOpConverter(MLIRContext* context)
@ -147,14 +148,14 @@ class HloToLhloOpConverter : public ConversionPattern {
op->getLoc(), result.value(), shape_value, &rewriter));
}
}
rewriter.create<LhloOpTy>(op->getLoc(), llvm::None, buffer_args,
op->getAttrs());
rewriter.create<xla_hlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
buffer_args, op->getAttrs());
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
return matchSuccess();
}
};
struct HloToLHloDynamicBroadcastInDimOpConverter
struct HloToLhloDynamicBroadcastInDimOpConverter
: public OpConversionPattern<xla_hlo::DynamicBroadcastInDimOp> {
public:
using OpConversionPattern::OpConversionPattern;
@ -178,7 +179,7 @@ struct HloToLHloDynamicBroadcastInDimOpConverter
}
};
struct HloToLHloReduceOpConverter
struct HloToLhloReduceOpConverter
: public OpConversionPattern<xla_hlo::ReduceOp> {
public:
using OpConversionPattern::OpConversionPattern;
@ -438,36 +439,35 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
// clang-format off
patterns->insert<
HloToLHloDynamicBroadcastInDimOpConverter,
HloToLhloDynamicBroadcastInDimOpConverter,
HloToLhloFuncOpConverter,
HloToLhloOpConverter<xla_hlo::AbsOp, xla_lhlo::AbsOp>,
HloToLhloOpConverter<xla_hlo::AddOp, xla_lhlo::AddOp>,
HloToLhloOpConverter<xla_hlo::AndOp, xla_lhlo::AndOp>,
HloToLhloOpConverter<xla_hlo::BroadcastInDimOp,
xla_lhlo::BroadcastInDimOp>,
HloToLhloOpConverter<xla_hlo::CeilOp, xla_lhlo::CeilOp>,
HloToLhloOpConverter<xla_hlo::CompareOp, xla_lhlo::CompareOp>,
HloToLhloOpConverter<xla_hlo::ConstOp, xla_lhlo::ConstOp>,
HloToLhloOpConverter<xla_hlo::ConvertOp, xla_lhlo::ConvertOp>,
HloToLhloOpConverter<xla_hlo::CopyOp, xla_lhlo::CopyOp>,
HloToLhloOpConverter<xla_hlo::CosOp, xla_lhlo::CosOp>,
HloToLhloOpConverter<xla_hlo::DivOp, xla_lhlo::DivOp>,
HloToLhloOpConverter<xla_hlo::ExpOp, xla_lhlo::ExpOp>,
HloToLhloOpConverter<xla_hlo::IotaOp, xla_lhlo::IotaOp>,
HloToLhloOpConverter<xla_hlo::LogOp, xla_lhlo::LogOp>,
HloToLhloOpConverter<xla_hlo::MaxOp, xla_lhlo::MaxOp>,
HloToLhloOpConverter<xla_hlo::MinOp, xla_lhlo::MinOp>,
HloToLhloOpConverter<xla_hlo::MulOp, xla_lhlo::MulOp>,
HloToLhloOpConverter<xla_hlo::NegOp, xla_lhlo::NegOp>,
HloToLhloOpConverter<xla_hlo::RemOp, xla_lhlo::RemOp>,
HloToLhloOpConverter<xla_hlo::SelectOp, xla_lhlo::SelectOp>,
HloToLhloOpConverter<xla_hlo::SignOp, xla_lhlo::SignOp>,
HloToLhloOpConverter<xla_hlo::SubOp, xla_lhlo::SubOp>,
HloToLhloOpConverter<xla_hlo::TanhOp, xla_lhlo::TanhOp>,
HloToLHloReduceOpConverter,
StdToLhloReturnOpConverter,
HloToLhloOpConverter<xla_hlo::AbsOp>,
HloToLhloOpConverter<xla_hlo::AddOp>,
HloToLhloOpConverter<xla_hlo::AndOp>,
HloToLhloOpConverter<xla_hlo::BroadcastInDimOp>,
HloToLhloOpConverter<xla_hlo::CeilOp>,
HloToLhloOpConverter<xla_hlo::CompareOp>,
HloToLhloOpConverter<xla_hlo::ConstOp>,
HloToLhloOpConverter<xla_hlo::ConvertOp>,
HloToLhloOpConverter<xla_hlo::CopyOp>,
HloToLhloOpConverter<xla_hlo::CosOp>,
HloToLhloOpConverter<xla_hlo::DivOp>,
HloToLhloOpConverter<xla_hlo::ExpOp>,
HloToLhloOpConverter<xla_hlo::IotaOp>,
HloToLhloOpConverter<xla_hlo::LogOp>,
HloToLhloOpConverter<xla_hlo::MaxOp>,
HloToLhloOpConverter<xla_hlo::MinOp>,
HloToLhloOpConverter<xla_hlo::MulOp>,
HloToLhloOpConverter<xla_hlo::NegOp>,
HloToLhloOpConverter<xla_hlo::RemOp>,
HloToLhloOpConverter<xla_hlo::SelectOp>,
HloToLhloOpConverter<xla_hlo::SignOp>,
HloToLhloOpConverter<xla_hlo::SubOp>,
HloToLhloOpConverter<xla_hlo::TanhOp>,
HloToLhloReduceOpConverter,
HloToLhloTensorLoadOpConverter,
HloToLhloTensorStoreOpConverter
HloToLhloTensorStoreOpConverter,
StdToLhloReturnOpConverter
>(context);
// clang-format on
}

View File

@ -31,11 +31,11 @@ namespace mlir {
namespace xla_lhlo {
namespace {
template <typename LhloOp>
struct BinaryOpConverter : public OpRewritePattern<LhloOp> {
using OpRewritePattern<LhloOp>::OpRewritePattern;
template <typename LhloOpTy>
struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
using OpRewritePattern<LhloOpTy>::OpRewritePattern;
PatternMatchResult matchAndRewrite(LhloOp op,
PatternMatchResult matchAndRewrite(LhloOpTy op,
PatternRewriter& rewriter) const override {
const auto& lhs = op.lhs();
const auto& rhs = op.rhs();
@ -56,8 +56,8 @@ struct BinaryOpConverter : public OpRewritePattern<LhloOp> {
}
auto l = rewriter.create<LoadOp>(loc, lhs, induction_vars);
auto r = rewriter.create<LoadOp>(loc, rhs, induction_vars);
Value opResult = MapXlaOpToStdScalarOp<LhloOp>(
llvm::cast<LhloOp>(op), element_type, {l, r}, &rewriter);
Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<LhloOpTy>(
op, element_type, {l, r}, &rewriter);
if (opResult == nullptr) {
return this->matchFailure();
}

View File

@ -0,0 +1,70 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_
#include <type_traits>
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
namespace mlir {
namespace xla_hlo {
// Compile-time map from an HLO op type to the LHLO op type that lowers it.
//
// The unspecialized template yields std::false_type so that callers can detect
// (e.g. with std::is_same) that no LHLO counterpart is registered for a type.
template <typename HloOpTy>
struct HloToLhloOpImpl {
  using Type = std::false_type;
};

// Convenience alias: HloToLhloOp<xla_hlo::FooOp> is xla_lhlo::FooOp, or
// std::false_type when no mapping exists.
template <typename HloOpTy>
using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;

// Explicit specializations registering every supported HLO -> LHLO pair.
// clang-format off
template <> struct HloToLhloOpImpl<xla_hlo::AbsOp> { using Type = xla_lhlo::AbsOp; };
template <> struct HloToLhloOpImpl<xla_hlo::AddOp> { using Type = xla_lhlo::AddOp; };
template <> struct HloToLhloOpImpl<xla_hlo::AndOp> { using Type = xla_lhlo::AndOp; };
template <> struct HloToLhloOpImpl<xla_hlo::BroadcastInDimOp> { using Type = xla_lhlo::BroadcastInDimOp; };
template <> struct HloToLhloOpImpl<xla_hlo::CeilOp> { using Type = xla_lhlo::CeilOp; };
template <> struct HloToLhloOpImpl<xla_hlo::ConstOp> { using Type = xla_lhlo::ConstOp; };
template <> struct HloToLhloOpImpl<xla_hlo::CompareOp> { using Type = xla_lhlo::CompareOp; };
template <> struct HloToLhloOpImpl<xla_hlo::ConvertOp> { using Type = xla_lhlo::ConvertOp; };
template <> struct HloToLhloOpImpl<xla_hlo::CopyOp> { using Type = xla_lhlo::CopyOp; };
template <> struct HloToLhloOpImpl<xla_hlo::CosOp> { using Type = xla_lhlo::CosOp; };
template <> struct HloToLhloOpImpl<xla_hlo::DivOp> { using Type = xla_lhlo::DivOp; };
template <> struct HloToLhloOpImpl<xla_hlo::ExpOp> { using Type = xla_lhlo::ExpOp; };
template <> struct HloToLhloOpImpl<xla_hlo::IotaOp> { using Type = xla_lhlo::IotaOp; };
template <> struct HloToLhloOpImpl<xla_hlo::LogOp> { using Type = xla_lhlo::LogOp; };
template <> struct HloToLhloOpImpl<xla_hlo::MaxOp> { using Type = xla_lhlo::MaxOp; };
template <> struct HloToLhloOpImpl<xla_hlo::MinOp> { using Type = xla_lhlo::MinOp; };
template <> struct HloToLhloOpImpl<xla_hlo::MulOp> { using Type = xla_lhlo::MulOp; };
template <> struct HloToLhloOpImpl<xla_hlo::NegOp> { using Type = xla_lhlo::NegOp; };
template <> struct HloToLhloOpImpl<xla_hlo::ReduceOp> { using Type = xla_lhlo::ReduceOp; };
template <> struct HloToLhloOpImpl<xla_hlo::RemOp> { using Type = xla_lhlo::RemOp; };
template <> struct HloToLhloOpImpl<xla_hlo::SelectOp> { using Type = xla_lhlo::SelectOp; };
template <> struct HloToLhloOpImpl<xla_hlo::SignOp> { using Type = xla_lhlo::SignOp; };
template <> struct HloToLhloOpImpl<xla_hlo::SubOp> { using Type = xla_lhlo::SubOp; };
template <> struct HloToLhloOpImpl<xla_hlo::TanhOp> { using Type = xla_lhlo::TanhOp; };
// clang-format on
} // namespace xla_hlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_

View File

@ -21,81 +21,63 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h"
namespace mlir {
namespace xla_lhlo {
namespace impl {
template <typename LHLO_BinaryOp>
struct ScalarOp;
// A struct to map LhloBinaryOpTy type to the corresponding floating-point and
// integer scalar operation types.
template <typename LhloBinaryOpTy>
struct LhloToScalarOp;
template <>
struct ScalarOp<xla_lhlo::AddOp> {
struct LhloToScalarOp<xla_lhlo::AddOp> {
using FOp = ::mlir::AddFOp;
using IOp = ::mlir::AddIOp;
};
template <>
struct ScalarOp<xla_hlo::AddOp> {
using FOp = ::mlir::AddFOp;
using IOp = ::mlir::AddIOp;
};
template <>
struct ScalarOp<xla_lhlo::CompareOp> {
struct LhloToScalarOp<xla_lhlo::CompareOp> {
using FOp = ::mlir::CmpFOp;
using IOp = ::mlir::CmpIOp;
};
template <>
struct ScalarOp<xla_hlo::CompareOp> {
using FOp = ::mlir::CmpFOp;
using IOp = ::mlir::CmpIOp;
};
template <>
struct ScalarOp<xla_lhlo::DivOp> {
struct LhloToScalarOp<xla_lhlo::DivOp> {
using FOp = ::mlir::DivFOp;
using IOp = ::mlir::SignedDivIOp;
};
template <>
struct ScalarOp<xla_hlo::DivOp> {
using FOp = ::mlir::DivFOp;
using IOp = ::mlir::SignedDivIOp;
};
template <>
struct ScalarOp<xla_lhlo::MulOp> {
struct LhloToScalarOp<xla_lhlo::MulOp> {
using FOp = ::mlir::MulFOp;
using IOp = ::mlir::MulIOp;
};
template <>
struct ScalarOp<xla_hlo::MulOp> {
using FOp = ::mlir::MulFOp;
using IOp = ::mlir::MulIOp;
};
template <>
struct ScalarOp<xla_lhlo::RemOp> {
struct LhloToScalarOp<xla_lhlo::RemOp> {
using FOp = ::mlir::RemFOp;
using IOp = ::mlir::SignedRemIOp;
};
template <>
struct ScalarOp<xla_hlo::RemOp> {
using FOp = ::mlir::RemFOp;
using IOp = ::mlir::SignedRemIOp;
};
template <>
struct ScalarOp<xla_lhlo::SubOp> {
using FOp = ::mlir::SubFOp;
using IOp = ::mlir::SubIOp;
};
template <>
struct ScalarOp<xla_hlo::SubOp> {
struct LhloToScalarOp<xla_lhlo::SubOp> {
using FOp = ::mlir::SubFOp;
using IOp = ::mlir::SubIOp;
};
template <typename XLA_BinaryOp>
using ScalarFOp = typename ScalarOp<XLA_BinaryOp>::FOp;
template <typename XLA_BinaryOp>
using ScalarIOp = typename ScalarOp<XLA_BinaryOp>::IOp;
template <typename LhloBinaryOpTy>
struct ScalarOp {
using FOp = typename LhloToScalarOp<LhloBinaryOpTy>::FOp;
using IOp = typename LhloToScalarOp<LhloBinaryOpTy>::IOp;
};
// Alias for the map from LHLO binary op type to STD floating-point op type.
template <typename LhloOp>
using ScalarFOp = typename ScalarOp<LhloOp>::FOp;
// Alias for the map from LHLO binary op type to STD integer op type.
template <typename LhloOp>
using ScalarIOp = typename ScalarOp<LhloOp>::IOp;
template <typename... Args>
struct MapXlaOpToStdScalarOpImpl {
struct MapLhloOpToStdScalarOpImpl {
Value operator()(Location loc, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
return nullptr;
@ -103,7 +85,7 @@ struct MapXlaOpToStdScalarOpImpl {
};
template <typename StdScalarOp>
struct MapXlaOpToStdScalarOpImpl<StdScalarOp> {
struct MapLhloOpToStdScalarOpImpl<StdScalarOp> {
Value operator()(Location loc, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
return b->template create<StdScalarOp>(loc, result_types, args, mlir::None);
@ -111,7 +93,7 @@ struct MapXlaOpToStdScalarOpImpl<StdScalarOp> {
};
template <typename SupportedType, typename StdScalarOp, typename... Args>
struct MapXlaOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
struct MapLhloOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
Value operator()(Location loc, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
Type element_type = args.front().getType();
@ -119,52 +101,34 @@ struct MapXlaOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
return b->template create<StdScalarOp>(loc, result_types, args,
mlir::None);
}
return MapXlaOpToStdScalarOpImpl<Args...>{}(loc, result_types, args, b);
return MapLhloOpToStdScalarOpImpl<Args...>{}(loc, result_types, args, b);
}
};
template <typename XlaOp>
inline Value MapXlaOpToStdScalarOp(XlaOp xla_op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<IntegerType, ScalarIOp<XlaOp>, FloatType,
ScalarFOp<XlaOp>>{}(xla_op.getLoc(),
result_types, args, b);
}
// TODO(ravishankarm): Find a way to reduce code-bloat in HLO and LHLO
// specialization.
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::AbsOp>(xla_lhlo::AbsOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
xla_op.getLoc(), result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::AbsOp>(xla_hlo::AbsOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
xla_op.getLoc(), result_types, args, b);
// Inserts the computation that corresponds to the body of the loop for lowered
// LHLO unary/binary op. Returns the value for the result.
template <typename LhloOpTy>
inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<LhloOpTy>, FloatType,
ScalarFOp<LhloOpTy>>{}(loc, result_types,
args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::AndOp>(xla_lhlo::AndOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
xla_op.getLoc(), result_types, args, b);
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AbsOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::AndOp>(xla_hlo::AndOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
xla_op.getLoc(), result_types, args, b);
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AndOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
loc, result_types, args, b);
}
template <typename PredicateType>
@ -200,7 +164,8 @@ inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>(
}
template <typename XLACompareOpTy>
inline Value MapXlaCompareOpToStdScalarOp(XLACompareOpTy xla_op,
inline Value MapXlaCompareOpToStdScalarOp(Location loc,
StringRef comparison_direction,
ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
const auto& lhs = args[0];
@ -208,101 +173,60 @@ inline Value MapXlaCompareOpToStdScalarOp(XLACompareOpTy xla_op,
Type element_type = lhs.getType();
if (element_type.isSignlessInteger()) {
Optional<CmpIPredicate> predicate =
getCmpPredicate<CmpIPredicate>(xla_op.comparison_direction());
getCmpPredicate<CmpIPredicate>(comparison_direction);
assert(predicate.hasValue() && "expected valid comparison direction");
return b->create<ScalarIOp<XLACompareOpTy>>(xla_op.getLoc(),
predicate.getValue(), lhs, rhs);
return b->create<ScalarIOp<XLACompareOpTy>>(loc, predicate.getValue(), lhs,
rhs);
}
if (element_type.isa<FloatType>()) {
Optional<CmpFPredicate> predicate =
getCmpPredicate<CmpFPredicate>(xla_op.comparison_direction());
getCmpPredicate<CmpFPredicate>(comparison_direction);
assert(predicate.hasValue() && "expected valid comparison direction");
return b->create<ScalarFOp<XLACompareOpTy>>(xla_op.getLoc(),
predicate.getValue(), lhs, rhs);
return b->create<ScalarFOp<XLACompareOpTy>>(loc, predicate.getValue(), lhs,
rhs);
}
return nullptr;
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CompareOp>(
xla_lhlo::CompareOp xla_op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
return MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(xla_op, result_types,
args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::CompareOp>(
xla_hlo::CompareOp xla_op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
return MapXlaCompareOpToStdScalarOp<xla_hlo::CompareOp>(xla_op, result_types,
args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CopyOp>(
xla_lhlo::CopyOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CopyOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return args.front();
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::CopyOp>(xla_hlo::CopyOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return args.front();
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::ExpOp>(xla_lhlo::ExpOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
xla_op.getLoc(), result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::ExpOp>(xla_hlo::ExpOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
xla_op.getLoc(), result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CeilOp>(
xla_lhlo::CeilOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ExpOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
xla_op.getLoc(), result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::CeilOp>(xla_hlo::CeilOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
xla_op.getLoc(), result_types, args, b);
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::ConvertOp>(
xla_lhlo::ConvertOp xla_op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CeilOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ConvertOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
Type sourceType = args.front().getType();
Type targetType = result_types.front();
if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) {
return b->create<mlir::SIToFPOp>(xla_op.getLoc(), result_types, args,
mlir::None);
return b->create<mlir::SIToFPOp>(loc, result_types, args, mlir::None);
} else if (sourceType.isa<FloatType>() && targetType.isa<FloatType>()) {
FloatType src = sourceType.cast<FloatType>();
FloatType res = targetType.cast<FloatType>();
if (src.getWidth() > res.getWidth()) {
return b->create<mlir::FPTruncOp>(xla_op.getLoc(), result_types, args,
mlir::None);
return b->create<mlir::FPTruncOp>(loc, result_types, args, mlir::None);
} else if (src.getWidth() < res.getWidth()) {
return b->create<mlir::FPExtOp>(xla_op.getLoc(), result_types, args,
mlir::None);
return b->create<mlir::FPExtOp>(loc, result_types, args, mlir::None);
}
// No conversion is needed for the same width floats
return args.front();
@ -311,10 +235,9 @@ inline Value MapXlaOpToStdScalarOp<xla_lhlo::ConvertOp>(
IntegerType src = sourceType.cast<IntegerType>();
IntegerType res = targetType.cast<IntegerType>();
if (src.getWidth() > res.getWidth()) {
return b->create<mlir::TruncateIOp>(xla_op.getLoc(), result_types, args,
mlir::None);
return b->create<mlir::TruncateIOp>(loc, result_types, args, mlir::None);
} else if (src.getWidth() < res.getWidth()) {
return b->create<mlir::ZeroExtendIOp>(xla_op.getLoc(), result_types, args,
return b->create<mlir::ZeroExtendIOp>(loc, result_types, args,
mlir::None);
}
// No conversion is needed for the same width integers
@ -322,35 +245,25 @@ inline Value MapXlaOpToStdScalarOp<xla_lhlo::ConvertOp>(
}
// TODO(dfki-ehna): Add other primitive type conversions
// if (mlir::FpToSiOp::areCastCompatible(sourceType, targetType)) {
// return b.create<mlir::FpToSiOp>(xla_op.getLoc(), result_types,
// return b.create<mlir::FpToSiOp>(loc, result_types,
// args,mlir::None);
// }
return nullptr;
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CosOp>(xla_lhlo::CosOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}(
xla_op.getLoc(), result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::CosOp>(xla_hlo::CosOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}(
xla_op.getLoc(), result_types, args, b);
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CosOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}(
loc, result_types, args, b);
}
/// Implements the conversion of XLA op to scalar op (to use within region of a
/// linalg.generic op) for compare-select style operations like min/max.
template <typename... Args>
struct MapXlaCompareSelectOpToStdScalarOp {
Value operator()(Location loc, StringRef comparison_direction,
struct XlaCompareSelectOpToStdScalarOp {
static Value map(Location loc, StringRef comparison_direction,
ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return nullptr;
@ -361,9 +274,9 @@ struct MapXlaCompareSelectOpToStdScalarOp {
/// dialect with a given predicate based on the element type of the operand.
template <typename SupportedType, typename StdCompareOp, typename Predicate,
typename... Args>
struct MapXlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp,
Predicate, Args...> {
Value operator()(Location loc, StringRef comparison_direction,
struct XlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
Args...> {
static Value map(Location loc, StringRef comparison_direction,
ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
Type element_type = args.front().getType();
@ -374,132 +287,130 @@ struct MapXlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp,
args[0], args[1]);
return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]);
}
return MapXlaCompareSelectOpToStdScalarOp<Args...>{}(
return XlaCompareSelectOpToStdScalarOp<Args...>::map(
loc, comparison_direction, result_types, args, b);
}
};
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::LogOp>(xla_lhlo::LogOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
xla_op.getLoc(), result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::LogOp>(xla_hlo::LogOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
xla_op.getLoc(), result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::MaxOp>(xla_lhlo::MaxOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaCompareSelectOpToStdScalarOp<
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>{}(xla_op.getLoc(), "GT",
result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::MaxOp>(xla_hlo::MaxOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaCompareSelectOpToStdScalarOp<
IntegerType, ScalarIOp<xla_hlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<xla_hlo::CompareOp>, CmpFPredicate>{}(xla_op.getLoc(), "GT",
result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::MinOp>(xla_lhlo::MinOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaCompareSelectOpToStdScalarOp<
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>{}(xla_op.getLoc(), "LT",
result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::MinOp>(xla_hlo::MinOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaCompareSelectOpToStdScalarOp<
IntegerType, ScalarIOp<xla_hlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<xla_hlo::CompareOp>, CmpFPredicate>{}(xla_op.getLoc(), "LT",
result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::NegOp>(xla_lhlo::NegOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
xla_op.getLoc(), result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::NegOp>(xla_hlo::NegOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
xla_op.getLoc(), result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::SelectOp>(
xla_lhlo::SelectOp xla_op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<::mlir::SelectOp>{}(xla_op.getLoc(),
result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::SelectOp>(
xla_hlo::SelectOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
inline Value MapLhloOpToStdScalarOp<xla_lhlo::LogOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<::mlir::SelectOp>{}(xla_op.getLoc(),
result_types, args, b);
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::SignOp>(
xla_lhlo::SignOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MaxOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return XlaCompareSelectOpToStdScalarOp<
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "GT",
result_types, args,
b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MinOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return XlaCompareSelectOpToStdScalarOp<
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "LT",
result_types, args,
b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::NegOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SelectOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args,
b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SignOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
Type element_type = args.front().getType();
if (element_type.isa<FloatType>()) {
FloatType float_type = element_type.cast<FloatType>();
APFloat const_value = float_type.isF32() ? APFloat(1.0f) : APFloat(1.0);
Value one = b->create<mlir::ConstantFloatOp>(xla_op.getLoc(), const_value,
float_type);
return b->create<::mlir::CopySignOp>(xla_op.getLoc(), result_types, one,
args[0]);
Value one = b->create<mlir::ConstantFloatOp>(loc, const_value, float_type);
return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
}
return nullptr;
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::TanhOp>(
xla_lhlo::TanhOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
inline Value MapLhloOpToStdScalarOp<xla_lhlo::TanhOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}(
xla_op.getLoc(), result_types, args, b);
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::TanhOp>(xla_hlo::TanhOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}(
xla_op.getLoc(), result_types, args, b);
} // namespace impl
// Single entry point for lowering an XLA op (either the HLO or the LHLO
// flavor) to the corresponding standard-dialect scalar op. Overload selection
// is done with SFINAE on the op type:
//   * LHLO op (no entry in the HLO->LHLO map) ........ first overload
//   * HLO op (has an entry in the HLO->LHLO map) ..... second overload
//   * CompareOp (needs comparison_direction) ......... dedicated overloads
// The two generic overloads would otherwise have identical signatures (default
// template arguments do not participate in overload differentiation), so each
// carries an unused trailing dummy parameter of a distinct type — `unsigned`
// vs `int` — to keep them from colliding.
struct XlaOpToStdScalarOp {
// Implementation for LHLO ops except xla_lhlo::CompareOp.
// Selected when HloToLhloOp<LhloOpTy> is std::false_type, i.e. the type is
// not a known HLO op and is therefore treated as already being LHLO.
template <typename XlaOpTy, typename LhloOpTy = XlaOpTy,
typename = std::enable_if_t<
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
std::is_same<typename xla_hlo::HloToLhloOp<LhloOpTy>,
std::false_type>::value>>
static Value map(XlaOpTy op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
args, b);
}
// Implementation for HLO ops except xla_hlo::CompareOp.
// LhloOpTy defaults to the mapped LHLO type; selected when that mapping
// exists (is not std::false_type).
template <typename XlaOpTy, typename LhloOpTy = xla_hlo::HloToLhloOp<XlaOpTy>,
typename = std::enable_if_t<
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
!std::is_same<LhloOpTy, std::false_type>::value>>
static Value map(XlaOpTy op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b, int i = 0) {
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
args, b);
}
// Implementation for xla_lhlo::CompareOp.
// Extracts comparison_direction from the op and dispatches to the shared
// compare lowering keyed on the LHLO CompareOp's scalar-op mapping.
template <typename LhloOpTy, typename = std::enable_if_t<std::is_same<
LhloOpTy, xla_lhlo::CompareOp>::value>>
static Value map(xla_lhlo::CompareOp op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
auto comparison_direction = op.comparison_direction();
return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
op.getLoc(), comparison_direction, result_types, args, b);
}
// Implementation for xla_hlo::CompareOp.
// Same lowering as the LHLO case: note it deliberately instantiates the
// helper with xla_lhlo::CompareOp so both flavors share one code path.
template <typename HloOpTy, typename = std::enable_if_t<std::is_same<
HloOpTy, xla_hlo::CompareOp>::value>>
static Value map(xla_hlo::CompareOp op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
auto comparison_direction = op.comparison_direction();
return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
op.getLoc(), comparison_direction, result_types, args, b);
}
};
// Backward-compatible free-function wrapper: forwards directly to
// XlaOpToStdScalarOp::map with the same arguments, for callers that predate
// the dispatcher struct.
template <typename XlaOpTy>
inline Value MapXlaOpToStdScalarOp(XlaOpTy op, ArrayRef<Type> result_types,
                                   ArrayRef<Value> args, OpBuilder* b) {
  return XlaOpToStdScalarOp::map<XlaOpTy>(op, result_types, args, b);
}
} // namespace xla_lhlo

View File

@ -149,8 +149,8 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
rewriter.setInsertionPointToEnd(block);
// TODO(ravishankarm) : For now use the method in xla_lhlo namespace. That
// method needs to be moved out of there.
Value opResult = xla_lhlo::MapXlaOpToStdScalarOp<OpTy>(
llvm::cast<OpTy>(op), bodyResultTypes, bodyArgs, &rewriter);
Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<OpTy>(
op, bodyResultTypes, bodyArgs, &rewriter);
if (!opResult) {
return ConversionPattern::matchFailure();
}
@ -180,9 +180,9 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
// TODO(ravishankarm) : Move this method out of xla_lhlo namespace.
Value opResult = xla_lhlo::MapXlaOpToStdScalarOp<LhloOp>(
llvm::cast<LhloOp>(lhlo_op), argType.getElementType(),
llvm::ArrayRef<Value>{lhs, rhs}, &rewriter);
Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<LhloOp>(
lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
&rewriter);
rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
rewriter.eraseOp(lhlo_op);
return ConversionPattern::matchSuccess();