[XLA][MLIR] Reduce code bloat for LHLO->STD and HLO->STD patterns.
PiperOrigin-RevId: 298840878 Change-Id: I781008f01b5c8e478d75ba282db9aa78da546ea1
This commit is contained in:
parent
8a53e358fc
commit
4aaabc836e
@ -133,16 +133,25 @@ cc_library(
|
||||
|
||||
cc_library(
|
||||
name = "map_xla_to_scalar_op",
|
||||
srcs = [],
|
||||
hdrs = ["transforms/map_xla_to_scalar_op.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":lhlo",
|
||||
":map_hlo_to_lhlo_op",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "map_hlo_to_lhlo_op",
|
||||
hdrs = ["transforms/map_hlo_to_lhlo_op.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":lhlo",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_shape_derivation",
|
||||
srcs = [],
|
||||
@ -234,6 +243,7 @@ cc_library(
|
||||
":hlo",
|
||||
":hlo_shape_derivation",
|
||||
":lhlo",
|
||||
":map_hlo_to_lhlo_op",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h"
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h"
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
|
||||
|
||||
@ -117,7 +118,7 @@ Value InsertAllocAndDealloc(Location loc, Value result,
|
||||
return alloc;
|
||||
}
|
||||
|
||||
template <typename HloOpTy, typename LhloOpTy>
|
||||
template <typename HloOpTy>
|
||||
class HloToLhloOpConverter : public ConversionPattern {
|
||||
public:
|
||||
explicit HloToLhloOpConverter(MLIRContext* context)
|
||||
@ -147,14 +148,14 @@ class HloToLhloOpConverter : public ConversionPattern {
|
||||
op->getLoc(), result.value(), shape_value, &rewriter));
|
||||
}
|
||||
}
|
||||
rewriter.create<LhloOpTy>(op->getLoc(), llvm::None, buffer_args,
|
||||
op->getAttrs());
|
||||
rewriter.create<xla_hlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
|
||||
buffer_args, op->getAttrs());
|
||||
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
struct HloToLHloDynamicBroadcastInDimOpConverter
|
||||
struct HloToLhloDynamicBroadcastInDimOpConverter
|
||||
: public OpConversionPattern<xla_hlo::DynamicBroadcastInDimOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
@ -178,7 +179,7 @@ struct HloToLHloDynamicBroadcastInDimOpConverter
|
||||
}
|
||||
};
|
||||
|
||||
struct HloToLHloReduceOpConverter
|
||||
struct HloToLhloReduceOpConverter
|
||||
: public OpConversionPattern<xla_hlo::ReduceOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
@ -438,36 +439,35 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
||||
OwningRewritePatternList* patterns) {
|
||||
// clang-format off
|
||||
patterns->insert<
|
||||
HloToLHloDynamicBroadcastInDimOpConverter,
|
||||
HloToLhloDynamicBroadcastInDimOpConverter,
|
||||
HloToLhloFuncOpConverter,
|
||||
HloToLhloOpConverter<xla_hlo::AbsOp, xla_lhlo::AbsOp>,
|
||||
HloToLhloOpConverter<xla_hlo::AddOp, xla_lhlo::AddOp>,
|
||||
HloToLhloOpConverter<xla_hlo::AndOp, xla_lhlo::AndOp>,
|
||||
HloToLhloOpConverter<xla_hlo::BroadcastInDimOp,
|
||||
xla_lhlo::BroadcastInDimOp>,
|
||||
HloToLhloOpConverter<xla_hlo::CeilOp, xla_lhlo::CeilOp>,
|
||||
HloToLhloOpConverter<xla_hlo::CompareOp, xla_lhlo::CompareOp>,
|
||||
HloToLhloOpConverter<xla_hlo::ConstOp, xla_lhlo::ConstOp>,
|
||||
HloToLhloOpConverter<xla_hlo::ConvertOp, xla_lhlo::ConvertOp>,
|
||||
HloToLhloOpConverter<xla_hlo::CopyOp, xla_lhlo::CopyOp>,
|
||||
HloToLhloOpConverter<xla_hlo::CosOp, xla_lhlo::CosOp>,
|
||||
HloToLhloOpConverter<xla_hlo::DivOp, xla_lhlo::DivOp>,
|
||||
HloToLhloOpConverter<xla_hlo::ExpOp, xla_lhlo::ExpOp>,
|
||||
HloToLhloOpConverter<xla_hlo::IotaOp, xla_lhlo::IotaOp>,
|
||||
HloToLhloOpConverter<xla_hlo::LogOp, xla_lhlo::LogOp>,
|
||||
HloToLhloOpConverter<xla_hlo::MaxOp, xla_lhlo::MaxOp>,
|
||||
HloToLhloOpConverter<xla_hlo::MinOp, xla_lhlo::MinOp>,
|
||||
HloToLhloOpConverter<xla_hlo::MulOp, xla_lhlo::MulOp>,
|
||||
HloToLhloOpConverter<xla_hlo::NegOp, xla_lhlo::NegOp>,
|
||||
HloToLhloOpConverter<xla_hlo::RemOp, xla_lhlo::RemOp>,
|
||||
HloToLhloOpConverter<xla_hlo::SelectOp, xla_lhlo::SelectOp>,
|
||||
HloToLhloOpConverter<xla_hlo::SignOp, xla_lhlo::SignOp>,
|
||||
HloToLhloOpConverter<xla_hlo::SubOp, xla_lhlo::SubOp>,
|
||||
HloToLhloOpConverter<xla_hlo::TanhOp, xla_lhlo::TanhOp>,
|
||||
HloToLHloReduceOpConverter,
|
||||
StdToLhloReturnOpConverter,
|
||||
HloToLhloOpConverter<xla_hlo::AbsOp>,
|
||||
HloToLhloOpConverter<xla_hlo::AddOp>,
|
||||
HloToLhloOpConverter<xla_hlo::AndOp>,
|
||||
HloToLhloOpConverter<xla_hlo::BroadcastInDimOp>,
|
||||
HloToLhloOpConverter<xla_hlo::CeilOp>,
|
||||
HloToLhloOpConverter<xla_hlo::CompareOp>,
|
||||
HloToLhloOpConverter<xla_hlo::ConstOp>,
|
||||
HloToLhloOpConverter<xla_hlo::ConvertOp>,
|
||||
HloToLhloOpConverter<xla_hlo::CopyOp>,
|
||||
HloToLhloOpConverter<xla_hlo::CosOp>,
|
||||
HloToLhloOpConverter<xla_hlo::DivOp>,
|
||||
HloToLhloOpConverter<xla_hlo::ExpOp>,
|
||||
HloToLhloOpConverter<xla_hlo::IotaOp>,
|
||||
HloToLhloOpConverter<xla_hlo::LogOp>,
|
||||
HloToLhloOpConverter<xla_hlo::MaxOp>,
|
||||
HloToLhloOpConverter<xla_hlo::MinOp>,
|
||||
HloToLhloOpConverter<xla_hlo::MulOp>,
|
||||
HloToLhloOpConverter<xla_hlo::NegOp>,
|
||||
HloToLhloOpConverter<xla_hlo::RemOp>,
|
||||
HloToLhloOpConverter<xla_hlo::SelectOp>,
|
||||
HloToLhloOpConverter<xla_hlo::SignOp>,
|
||||
HloToLhloOpConverter<xla_hlo::SubOp>,
|
||||
HloToLhloOpConverter<xla_hlo::TanhOp>,
|
||||
HloToLhloReduceOpConverter,
|
||||
HloToLhloTensorLoadOpConverter,
|
||||
HloToLhloTensorStoreOpConverter
|
||||
HloToLhloTensorStoreOpConverter,
|
||||
StdToLhloReturnOpConverter
|
||||
>(context);
|
||||
// clang-format on
|
||||
}
|
||||
|
@ -31,11 +31,11 @@ namespace mlir {
|
||||
namespace xla_lhlo {
|
||||
namespace {
|
||||
|
||||
template <typename LhloOp>
|
||||
struct BinaryOpConverter : public OpRewritePattern<LhloOp> {
|
||||
using OpRewritePattern<LhloOp>::OpRewritePattern;
|
||||
template <typename LhloOpTy>
|
||||
struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
|
||||
using OpRewritePattern<LhloOpTy>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(LhloOp op,
|
||||
PatternMatchResult matchAndRewrite(LhloOpTy op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
const auto& lhs = op.lhs();
|
||||
const auto& rhs = op.rhs();
|
||||
@ -56,8 +56,8 @@ struct BinaryOpConverter : public OpRewritePattern<LhloOp> {
|
||||
}
|
||||
auto l = rewriter.create<LoadOp>(loc, lhs, induction_vars);
|
||||
auto r = rewriter.create<LoadOp>(loc, rhs, induction_vars);
|
||||
Value opResult = MapXlaOpToStdScalarOp<LhloOp>(
|
||||
llvm::cast<LhloOp>(op), element_type, {l, r}, &rewriter);
|
||||
Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<LhloOpTy>(
|
||||
op, element_type, {l, r}, &rewriter);
|
||||
if (opResult == nullptr) {
|
||||
return this->matchFailure();
|
||||
}
|
||||
|
70
tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h
Normal file
70
tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h
Normal file
@ -0,0 +1,70 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_hlo {
|
||||
|
||||
template <typename HloOpTy>
|
||||
struct HloToLhloOpImpl {
|
||||
using Type = std::false_type;
|
||||
};
|
||||
template <typename HloOpTy>
|
||||
using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;
|
||||
|
||||
#define MAP_HLO_TO_LHLO(OpName) \
|
||||
template <> \
|
||||
struct HloToLhloOpImpl<xla_hlo::OpName> { \
|
||||
using Type = xla_lhlo::OpName; \
|
||||
}
|
||||
|
||||
MAP_HLO_TO_LHLO(AbsOp);
|
||||
MAP_HLO_TO_LHLO(AddOp);
|
||||
MAP_HLO_TO_LHLO(AndOp);
|
||||
MAP_HLO_TO_LHLO(BroadcastInDimOp);
|
||||
MAP_HLO_TO_LHLO(CeilOp);
|
||||
MAP_HLO_TO_LHLO(ConstOp);
|
||||
MAP_HLO_TO_LHLO(CompareOp);
|
||||
MAP_HLO_TO_LHLO(ConvertOp);
|
||||
MAP_HLO_TO_LHLO(CopyOp);
|
||||
MAP_HLO_TO_LHLO(CosOp);
|
||||
MAP_HLO_TO_LHLO(DivOp);
|
||||
MAP_HLO_TO_LHLO(ExpOp);
|
||||
MAP_HLO_TO_LHLO(IotaOp);
|
||||
MAP_HLO_TO_LHLO(LogOp);
|
||||
MAP_HLO_TO_LHLO(MaxOp);
|
||||
MAP_HLO_TO_LHLO(MinOp);
|
||||
MAP_HLO_TO_LHLO(MulOp);
|
||||
MAP_HLO_TO_LHLO(NegOp);
|
||||
MAP_HLO_TO_LHLO(ReduceOp);
|
||||
MAP_HLO_TO_LHLO(RemOp);
|
||||
MAP_HLO_TO_LHLO(SelectOp);
|
||||
MAP_HLO_TO_LHLO(SignOp);
|
||||
MAP_HLO_TO_LHLO(SubOp);
|
||||
MAP_HLO_TO_LHLO(TanhOp);
|
||||
|
||||
#undef MAP_HLO_TO_LHLO
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_
|
@ -21,81 +21,63 @@ limitations under the License.
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_lhlo {
|
||||
namespace impl {
|
||||
|
||||
template <typename LHLO_BinaryOp>
|
||||
struct ScalarOp;
|
||||
// A struct to map LhloBinaryOpTy type to the corresponding floating-point and
|
||||
// integer scalar operation types.
|
||||
template <typename LhloBinaryOpTy>
|
||||
struct LhloToScalarOp;
|
||||
|
||||
template <>
|
||||
struct ScalarOp<xla_lhlo::AddOp> {
|
||||
struct LhloToScalarOp<xla_lhlo::AddOp> {
|
||||
using FOp = ::mlir::AddFOp;
|
||||
using IOp = ::mlir::AddIOp;
|
||||
};
|
||||
template <>
|
||||
struct ScalarOp<xla_hlo::AddOp> {
|
||||
using FOp = ::mlir::AddFOp;
|
||||
using IOp = ::mlir::AddIOp;
|
||||
};
|
||||
template <>
|
||||
struct ScalarOp<xla_lhlo::CompareOp> {
|
||||
struct LhloToScalarOp<xla_lhlo::CompareOp> {
|
||||
using FOp = ::mlir::CmpFOp;
|
||||
using IOp = ::mlir::CmpIOp;
|
||||
};
|
||||
template <>
|
||||
struct ScalarOp<xla_hlo::CompareOp> {
|
||||
using FOp = ::mlir::CmpFOp;
|
||||
using IOp = ::mlir::CmpIOp;
|
||||
};
|
||||
template <>
|
||||
struct ScalarOp<xla_lhlo::DivOp> {
|
||||
struct LhloToScalarOp<xla_lhlo::DivOp> {
|
||||
using FOp = ::mlir::DivFOp;
|
||||
using IOp = ::mlir::SignedDivIOp;
|
||||
};
|
||||
template <>
|
||||
struct ScalarOp<xla_hlo::DivOp> {
|
||||
using FOp = ::mlir::DivFOp;
|
||||
using IOp = ::mlir::SignedDivIOp;
|
||||
};
|
||||
template <>
|
||||
struct ScalarOp<xla_lhlo::MulOp> {
|
||||
struct LhloToScalarOp<xla_lhlo::MulOp> {
|
||||
using FOp = ::mlir::MulFOp;
|
||||
using IOp = ::mlir::MulIOp;
|
||||
};
|
||||
template <>
|
||||
struct ScalarOp<xla_hlo::MulOp> {
|
||||
using FOp = ::mlir::MulFOp;
|
||||
using IOp = ::mlir::MulIOp;
|
||||
};
|
||||
template <>
|
||||
struct ScalarOp<xla_lhlo::RemOp> {
|
||||
struct LhloToScalarOp<xla_lhlo::RemOp> {
|
||||
using FOp = ::mlir::RemFOp;
|
||||
using IOp = ::mlir::SignedRemIOp;
|
||||
};
|
||||
template <>
|
||||
struct ScalarOp<xla_hlo::RemOp> {
|
||||
using FOp = ::mlir::RemFOp;
|
||||
using IOp = ::mlir::SignedRemIOp;
|
||||
};
|
||||
template <>
|
||||
struct ScalarOp<xla_lhlo::SubOp> {
|
||||
using FOp = ::mlir::SubFOp;
|
||||
using IOp = ::mlir::SubIOp;
|
||||
};
|
||||
template <>
|
||||
struct ScalarOp<xla_hlo::SubOp> {
|
||||
struct LhloToScalarOp<xla_lhlo::SubOp> {
|
||||
using FOp = ::mlir::SubFOp;
|
||||
using IOp = ::mlir::SubIOp;
|
||||
};
|
||||
|
||||
template <typename XLA_BinaryOp>
|
||||
using ScalarFOp = typename ScalarOp<XLA_BinaryOp>::FOp;
|
||||
template <typename XLA_BinaryOp>
|
||||
using ScalarIOp = typename ScalarOp<XLA_BinaryOp>::IOp;
|
||||
template <typename LhloBinaryOpTy>
|
||||
struct ScalarOp {
|
||||
using FOp = typename LhloToScalarOp<LhloBinaryOpTy>::FOp;
|
||||
using IOp = typename LhloToScalarOp<LhloBinaryOpTy>::IOp;
|
||||
};
|
||||
|
||||
// Alias for the map from LHLO binary op type to STD floating-point op type.
|
||||
template <typename LhloOp>
|
||||
using ScalarFOp = typename ScalarOp<LhloOp>::FOp;
|
||||
// Alias for the map from LHLO binary op type to STD integer op type.
|
||||
template <typename LhloOp>
|
||||
using ScalarIOp = typename ScalarOp<LhloOp>::IOp;
|
||||
|
||||
template <typename... Args>
|
||||
struct MapXlaOpToStdScalarOpImpl {
|
||||
struct MapLhloOpToStdScalarOpImpl {
|
||||
Value operator()(Location loc, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
return nullptr;
|
||||
@ -103,7 +85,7 @@ struct MapXlaOpToStdScalarOpImpl {
|
||||
};
|
||||
|
||||
template <typename StdScalarOp>
|
||||
struct MapXlaOpToStdScalarOpImpl<StdScalarOp> {
|
||||
struct MapLhloOpToStdScalarOpImpl<StdScalarOp> {
|
||||
Value operator()(Location loc, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
return b->template create<StdScalarOp>(loc, result_types, args, mlir::None);
|
||||
@ -111,7 +93,7 @@ struct MapXlaOpToStdScalarOpImpl<StdScalarOp> {
|
||||
};
|
||||
|
||||
template <typename SupportedType, typename StdScalarOp, typename... Args>
|
||||
struct MapXlaOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
|
||||
struct MapLhloOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
|
||||
Value operator()(Location loc, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
Type element_type = args.front().getType();
|
||||
@ -119,52 +101,34 @@ struct MapXlaOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
|
||||
return b->template create<StdScalarOp>(loc, result_types, args,
|
||||
mlir::None);
|
||||
}
|
||||
return MapXlaOpToStdScalarOpImpl<Args...>{}(loc, result_types, args, b);
|
||||
return MapLhloOpToStdScalarOpImpl<Args...>{}(loc, result_types, args, b);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename XlaOp>
|
||||
inline Value MapXlaOpToStdScalarOp(XlaOp xla_op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
return MapXlaOpToStdScalarOpImpl<IntegerType, ScalarIOp<XlaOp>, FloatType,
|
||||
ScalarFOp<XlaOp>>{}(xla_op.getLoc(),
|
||||
result_types, args, b);
|
||||
}
|
||||
|
||||
// TODO(ravishankarm): Find a way to reduce code-bloat in HLO and LHLO
|
||||
// specialization.
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::AbsOp>(xla_lhlo::AbsOp xla_op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
|
||||
xla_op.getLoc(), result_types, args, b);
|
||||
}
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::AbsOp>(xla_hlo::AbsOp xla_op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
|
||||
xla_op.getLoc(), result_types, args, b);
|
||||
// Inserts the computation that corresponds to the body of the loop for lowered
|
||||
// LHLO unary/binary op. Returns the value for the result.
|
||||
template <typename LhloOpTy>
|
||||
inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<LhloOpTy>, FloatType,
|
||||
ScalarFOp<LhloOpTy>>{}(loc, result_types,
|
||||
args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::AndOp>(xla_lhlo::AndOp xla_op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapXlaOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
|
||||
xla_op.getLoc(), result_types, args, b);
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AbsOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::AndOp>(xla_hlo::AndOp xla_op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapXlaOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
|
||||
xla_op.getLoc(), result_types, args, b);
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AndOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <typename PredicateType>
|
||||
@ -200,7 +164,8 @@ inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>(
|
||||
}
|
||||
|
||||
template <typename XLACompareOpTy>
|
||||
inline Value MapXlaCompareOpToStdScalarOp(XLACompareOpTy xla_op,
|
||||
inline Value MapXlaCompareOpToStdScalarOp(Location loc,
|
||||
StringRef comparison_direction,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
const auto& lhs = args[0];
|
||||
@ -208,101 +173,60 @@ inline Value MapXlaCompareOpToStdScalarOp(XLACompareOpTy xla_op,
|
||||
Type element_type = lhs.getType();
|
||||
if (element_type.isSignlessInteger()) {
|
||||
Optional<CmpIPredicate> predicate =
|
||||
getCmpPredicate<CmpIPredicate>(xla_op.comparison_direction());
|
||||
getCmpPredicate<CmpIPredicate>(comparison_direction);
|
||||
assert(predicate.hasValue() && "expected valid comparison direction");
|
||||
return b->create<ScalarIOp<XLACompareOpTy>>(xla_op.getLoc(),
|
||||
predicate.getValue(), lhs, rhs);
|
||||
return b->create<ScalarIOp<XLACompareOpTy>>(loc, predicate.getValue(), lhs,
|
||||
rhs);
|
||||
}
|
||||
if (element_type.isa<FloatType>()) {
|
||||
Optional<CmpFPredicate> predicate =
|
||||
getCmpPredicate<CmpFPredicate>(xla_op.comparison_direction());
|
||||
getCmpPredicate<CmpFPredicate>(comparison_direction);
|
||||
assert(predicate.hasValue() && "expected valid comparison direction");
|
||||
return b->create<ScalarFOp<XLACompareOpTy>>(xla_op.getLoc(),
|
||||
predicate.getValue(), lhs, rhs);
|
||||
return b->create<ScalarFOp<XLACompareOpTy>>(loc, predicate.getValue(), lhs,
|
||||
rhs);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CompareOp>(
|
||||
xla_lhlo::CompareOp xla_op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
return MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(xla_op, result_types,
|
||||
args, b);
|
||||
}
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::CompareOp>(
|
||||
xla_hlo::CompareOp xla_op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
return MapXlaCompareOpToStdScalarOp<xla_hlo::CompareOp>(xla_op, result_types,
|
||||
args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CopyOp>(
|
||||
xla_lhlo::CopyOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CopyOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return args.front();
|
||||
}
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::CopyOp>(xla_hlo::CopyOp xla_op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return args.front();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::ExpOp>(xla_lhlo::ExpOp xla_op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
|
||||
xla_op.getLoc(), result_types, args, b);
|
||||
}
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::ExpOp>(xla_hlo::ExpOp xla_op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
|
||||
xla_op.getLoc(), result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CeilOp>(
|
||||
xla_lhlo::CeilOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ExpOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
|
||||
xla_op.getLoc(), result_types, args, b);
|
||||
}
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::CeilOp>(xla_hlo::CeilOp xla_op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
|
||||
xla_op.getLoc(), result_types, args, b);
|
||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::ConvertOp>(
|
||||
xla_lhlo::ConvertOp xla_op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CeilOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ConvertOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type sourceType = args.front().getType();
|
||||
Type targetType = result_types.front();
|
||||
|
||||
if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) {
|
||||
return b->create<mlir::SIToFPOp>(xla_op.getLoc(), result_types, args,
|
||||
mlir::None);
|
||||
return b->create<mlir::SIToFPOp>(loc, result_types, args, mlir::None);
|
||||
} else if (sourceType.isa<FloatType>() && targetType.isa<FloatType>()) {
|
||||
FloatType src = sourceType.cast<FloatType>();
|
||||
FloatType res = targetType.cast<FloatType>();
|
||||
if (src.getWidth() > res.getWidth()) {
|
||||
return b->create<mlir::FPTruncOp>(xla_op.getLoc(), result_types, args,
|
||||
mlir::None);
|
||||
return b->create<mlir::FPTruncOp>(loc, result_types, args, mlir::None);
|
||||
} else if (src.getWidth() < res.getWidth()) {
|
||||
return b->create<mlir::FPExtOp>(xla_op.getLoc(), result_types, args,
|
||||
mlir::None);
|
||||
return b->create<mlir::FPExtOp>(loc, result_types, args, mlir::None);
|
||||
}
|
||||
// No conversion is needed for the same width floats
|
||||
return args.front();
|
||||
@ -311,10 +235,9 @@ inline Value MapXlaOpToStdScalarOp<xla_lhlo::ConvertOp>(
|
||||
IntegerType src = sourceType.cast<IntegerType>();
|
||||
IntegerType res = targetType.cast<IntegerType>();
|
||||
if (src.getWidth() > res.getWidth()) {
|
||||
return b->create<mlir::TruncateIOp>(xla_op.getLoc(), result_types, args,
|
||||
mlir::None);
|
||||
return b->create<mlir::TruncateIOp>(loc, result_types, args, mlir::None);
|
||||
} else if (src.getWidth() < res.getWidth()) {
|
||||
return b->create<mlir::ZeroExtendIOp>(xla_op.getLoc(), result_types, args,
|
||||
return b->create<mlir::ZeroExtendIOp>(loc, result_types, args,
|
||||
mlir::None);
|
||||
}
|
||||
// No conversion is needed for the same width integers
|
||||
@ -322,35 +245,25 @@ inline Value MapXlaOpToStdScalarOp<xla_lhlo::ConvertOp>(
|
||||
}
|
||||
// TODO(dfki-ehna): Add other primitive type conversions
|
||||
// if (mlir::FpToSiOp::areCastCompatible(sourceType, targetType)) {
|
||||
// return b.create<mlir::FpToSiOp>(xla_op.getLoc(), result_types,
|
||||
// return b.create<mlir::FpToSiOp>(loc, result_types,
|
||||
// args,mlir::None);
|
||||
// }
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CosOp>(xla_lhlo::CosOp xla_op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}(
|
||||
xla_op.getLoc(), result_types, args, b);
|
||||
}
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::CosOp>(xla_hlo::CosOp xla_op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}(
|
||||
xla_op.getLoc(), result_types, args, b);
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CosOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
/// Implements the conversion of XLA op to scalar op (to use within region of a
|
||||
/// linalg.generic op) for compare-select style operations like min/max.
|
||||
template <typename... Args>
|
||||
struct MapXlaCompareSelectOpToStdScalarOp {
|
||||
Value operator()(Location loc, StringRef comparison_direction,
|
||||
struct XlaCompareSelectOpToStdScalarOp {
|
||||
static Value map(Location loc, StringRef comparison_direction,
|
||||
ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return nullptr;
|
||||
@ -361,9 +274,9 @@ struct MapXlaCompareSelectOpToStdScalarOp {
|
||||
/// dialect with a given predicate based on the element type of the operand.
|
||||
template <typename SupportedType, typename StdCompareOp, typename Predicate,
|
||||
typename... Args>
|
||||
struct MapXlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp,
|
||||
Predicate, Args...> {
|
||||
Value operator()(Location loc, StringRef comparison_direction,
|
||||
struct XlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
|
||||
Args...> {
|
||||
static Value map(Location loc, StringRef comparison_direction,
|
||||
ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type element_type = args.front().getType();
|
||||
@ -374,132 +287,130 @@ struct MapXlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp,
|
||||
args[0], args[1]);
|
||||
return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]);
|
||||
}
|
||||
return MapXlaCompareSelectOpToStdScalarOp<Args...>{}(
|
||||
return XlaCompareSelectOpToStdScalarOp<Args...>::map(
|
||||
loc, comparison_direction, result_types, args, b);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::LogOp>(xla_lhlo::LogOp xla_op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
|
||||
xla_op.getLoc(), result_types, args, b);
|
||||
}
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::LogOp>(xla_hlo::LogOp xla_op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
|
||||
xla_op.getLoc(), result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::MaxOp>(xla_lhlo::MaxOp xla_op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapXlaCompareSelectOpToStdScalarOp<
|
||||
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
|
||||
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>{}(xla_op.getLoc(), "GT",
|
||||
result_types, args, b);
|
||||
}
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::MaxOp>(xla_hlo::MaxOp xla_op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapXlaCompareSelectOpToStdScalarOp<
|
||||
IntegerType, ScalarIOp<xla_hlo::CompareOp>, CmpIPredicate, FloatType,
|
||||
ScalarFOp<xla_hlo::CompareOp>, CmpFPredicate>{}(xla_op.getLoc(), "GT",
|
||||
result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::MinOp>(xla_lhlo::MinOp xla_op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapXlaCompareSelectOpToStdScalarOp<
|
||||
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
|
||||
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>{}(xla_op.getLoc(), "LT",
|
||||
result_types, args, b);
|
||||
}
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::MinOp>(xla_hlo::MinOp xla_op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapXlaCompareSelectOpToStdScalarOp<
|
||||
IntegerType, ScalarIOp<xla_hlo::CompareOp>, CmpIPredicate, FloatType,
|
||||
ScalarFOp<xla_hlo::CompareOp>, CmpFPredicate>{}(xla_op.getLoc(), "LT",
|
||||
result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::NegOp>(xla_lhlo::NegOp xla_op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
|
||||
xla_op.getLoc(), result_types, args, b);
|
||||
}
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::NegOp>(xla_hlo::NegOp xla_op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
|
||||
xla_op.getLoc(), result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::SelectOp>(
|
||||
xla_lhlo::SelectOp xla_op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
return MapXlaOpToStdScalarOpImpl<::mlir::SelectOp>{}(xla_op.getLoc(),
|
||||
result_types, args, b);
|
||||
}
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::SelectOp>(
|
||||
xla_hlo::SelectOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::LogOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapXlaOpToStdScalarOpImpl<::mlir::SelectOp>{}(xla_op.getLoc(),
|
||||
result_types, args, b);
|
||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::SignOp>(
|
||||
xla_lhlo::SignOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MaxOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return XlaCompareSelectOpToStdScalarOp<
|
||||
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
|
||||
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "GT",
|
||||
result_types, args,
|
||||
b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MinOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return XlaCompareSelectOpToStdScalarOp<
|
||||
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
|
||||
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "LT",
|
||||
result_types, args,
|
||||
b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::NegOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SelectOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args,
|
||||
b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SignOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type element_type = args.front().getType();
|
||||
if (element_type.isa<FloatType>()) {
|
||||
FloatType float_type = element_type.cast<FloatType>();
|
||||
APFloat const_value = float_type.isF32() ? APFloat(1.0f) : APFloat(1.0);
|
||||
Value one = b->create<mlir::ConstantFloatOp>(xla_op.getLoc(), const_value,
|
||||
float_type);
|
||||
return b->create<::mlir::CopySignOp>(xla_op.getLoc(), result_types, one,
|
||||
args[0]);
|
||||
Value one = b->create<mlir::ConstantFloatOp>(loc, const_value, float_type);
|
||||
return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::TanhOp>(
|
||||
xla_lhlo::TanhOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::TanhOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}(
|
||||
xla_op.getLoc(), result_types, args, b);
|
||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::TanhOp>(xla_hlo::TanhOp xla_op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}(
|
||||
xla_op.getLoc(), result_types, args, b);
|
||||
|
||||
} // namespace impl
|
||||
|
||||
struct XlaOpToStdScalarOp {
|
||||
// Implementation for LHLO ops except xla_lhlo::CompareOp.
|
||||
template <typename XlaOpTy, typename LhloOpTy = XlaOpTy,
|
||||
typename = std::enable_if_t<
|
||||
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
|
||||
std::is_same<typename xla_hlo::HloToLhloOp<LhloOpTy>,
|
||||
std::false_type>::value>>
|
||||
static Value map(XlaOpTy op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
|
||||
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
|
||||
args, b);
|
||||
}
|
||||
|
||||
// Implementation for HLO ops except xla_hlo::CompareOp.
|
||||
template <typename XlaOpTy, typename LhloOpTy = xla_hlo::HloToLhloOp<XlaOpTy>,
|
||||
typename = std::enable_if_t<
|
||||
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
|
||||
!std::is_same<LhloOpTy, std::false_type>::value>>
|
||||
static Value map(XlaOpTy op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b, int i = 0) {
|
||||
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
|
||||
args, b);
|
||||
}
|
||||
|
||||
// Implementation for xla_lhlo::CompareOp.
|
||||
template <typename LhloOpTy, typename = std::enable_if_t<std::is_same<
|
||||
LhloOpTy, xla_lhlo::CompareOp>::value>>
|
||||
static Value map(xla_lhlo::CompareOp op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
auto comparison_direction = op.comparison_direction();
|
||||
return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
|
||||
op.getLoc(), comparison_direction, result_types, args, b);
|
||||
}
|
||||
|
||||
// Implementation for xla_hlo::CompareOp.
|
||||
template <typename HloOpTy, typename = std::enable_if_t<std::is_same<
|
||||
HloOpTy, xla_hlo::CompareOp>::value>>
|
||||
static Value map(xla_hlo::CompareOp op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
auto comparison_direction = op.comparison_direction();
|
||||
return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
|
||||
op.getLoc(), comparison_direction, result_types, args, b);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename XlaOpTy>
|
||||
inline Value MapXlaOpToStdScalarOp(XlaOpTy xla_op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
return XlaOpToStdScalarOp::map<XlaOpTy>(xla_op, result_types, args, b);
|
||||
}
|
||||
|
||||
} // namespace xla_lhlo
|
||||
|
@ -149,8 +149,8 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
|
||||
rewriter.setInsertionPointToEnd(block);
|
||||
// TODO(ravishankarm) : For now use the method in xla_lhlo namespace. That
|
||||
// method needs to be moved out of there.
|
||||
Value opResult = xla_lhlo::MapXlaOpToStdScalarOp<OpTy>(
|
||||
llvm::cast<OpTy>(op), bodyResultTypes, bodyArgs, &rewriter);
|
||||
Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<OpTy>(
|
||||
op, bodyResultTypes, bodyArgs, &rewriter);
|
||||
if (!opResult) {
|
||||
return ConversionPattern::matchFailure();
|
||||
}
|
||||
@ -180,9 +180,9 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
|
||||
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
|
||||
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
|
||||
// TODO(ravishankarm) : Move this method out of xla_lhlo namespace.
|
||||
Value opResult = xla_lhlo::MapXlaOpToStdScalarOp<LhloOp>(
|
||||
llvm::cast<LhloOp>(lhlo_op), argType.getElementType(),
|
||||
llvm::ArrayRef<Value>{lhs, rhs}, &rewriter);
|
||||
Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<LhloOp>(
|
||||
lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
|
||||
&rewriter);
|
||||
rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
|
||||
rewriter.eraseOp(lhlo_op);
|
||||
return ConversionPattern::matchSuccess();
|
||||
|
Loading…
x
Reference in New Issue
Block a user