[XLA][MLIR] Reduce code bloat for LHLO->STD and HLO->STD patterns.

PiperOrigin-RevId: 298840878
Change-Id: I781008f01b5c8e478d75ba282db9aa78da546ea1
Alexander Belyaev 2020-03-04 06:55:33 -08:00 committed by TensorFlower Gardener
parent 8a53e358fc
commit 4aaabc836e
6 changed files with 316 additions and 325 deletions
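
In short: the per-op conversion patterns lose their second template argument, for example (taken from the HLO-to-LHLO lowering hunk below)

    HloToLhloOpConverter<xla_hlo::AbsOp, xla_lhlo::AbsOp>   // before
    HloToLhloOpConverter<xla_hlo::AbsOp>                     // after

because the HLO-to-LHLO pairing now lives in the new map_hlo_to_lhlo_op.h type trait, and map_xla_to_scalar_op.h keeps only the LHLO specializations, routing HLO ops through that trait instead of spelling out every specialization twice.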

View File

@@ -133,16 +133,25 @@ cc_library(
cc_library( cc_library(
name = "map_xla_to_scalar_op", name = "map_xla_to_scalar_op",
srcs = [],
hdrs = ["transforms/map_xla_to_scalar_op.h"], hdrs = ["transforms/map_xla_to_scalar_op.h"],
deps = [ deps = [
":hlo", ":hlo",
":lhlo", ":lhlo",
":map_hlo_to_lhlo_op",
"@llvm-project//llvm:support", "@llvm-project//llvm:support",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
], ],
) )
cc_library(
name = "map_hlo_to_lhlo_op",
hdrs = ["transforms/map_hlo_to_lhlo_op.h"],
deps = [
":hlo",
":lhlo",
],
)
cc_library( cc_library(
name = "hlo_shape_derivation", name = "hlo_shape_derivation",
srcs = [], srcs = [],
@@ -234,6 +243,7 @@ cc_library(
":hlo", ":hlo",
":hlo_shape_derivation", ":hlo_shape_derivation",
":lhlo", ":lhlo",
":map_hlo_to_lhlo_op",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",

View File

@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h" #include "tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h"
#include "tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" #include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
@@ -117,7 +118,7 @@ Value InsertAllocAndDealloc(Location loc, Value result,
return alloc; return alloc;
} }
template <typename HloOpTy, typename LhloOpTy> template <typename HloOpTy>
class HloToLhloOpConverter : public ConversionPattern { class HloToLhloOpConverter : public ConversionPattern {
public: public:
explicit HloToLhloOpConverter(MLIRContext* context) explicit HloToLhloOpConverter(MLIRContext* context)
@@ -147,14 +148,14 @@ class HloToLhloOpConverter : public ConversionPattern {
op->getLoc(), result.value(), shape_value, &rewriter)); op->getLoc(), result.value(), shape_value, &rewriter));
} }
} }
rewriter.create<LhloOpTy>(op->getLoc(), llvm::None, buffer_args, rewriter.create<xla_hlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
op->getAttrs()); buffer_args, op->getAttrs());
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size())); rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
return matchSuccess(); return matchSuccess();
} }
}; };
struct HloToLHloDynamicBroadcastInDimOpConverter struct HloToLhloDynamicBroadcastInDimOpConverter
: public OpConversionPattern<xla_hlo::DynamicBroadcastInDimOp> { : public OpConversionPattern<xla_hlo::DynamicBroadcastInDimOp> {
public: public:
using OpConversionPattern::OpConversionPattern; using OpConversionPattern::OpConversionPattern;
@@ -178,7 +179,7 @@ struct HloToLHloDynamicBroadcastInDimOpConverter
} }
}; };
struct HloToLHloReduceOpConverter struct HloToLhloReduceOpConverter
: public OpConversionPattern<xla_hlo::ReduceOp> { : public OpConversionPattern<xla_hlo::ReduceOp> {
public: public:
using OpConversionPattern::OpConversionPattern; using OpConversionPattern::OpConversionPattern;
@@ -438,36 +439,35 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) { OwningRewritePatternList* patterns) {
// clang-format off // clang-format off
patterns->insert< patterns->insert<
HloToLHloDynamicBroadcastInDimOpConverter, HloToLhloDynamicBroadcastInDimOpConverter,
HloToLhloFuncOpConverter, HloToLhloFuncOpConverter,
HloToLhloOpConverter<xla_hlo::AbsOp, xla_lhlo::AbsOp>, HloToLhloOpConverter<xla_hlo::AbsOp>,
HloToLhloOpConverter<xla_hlo::AddOp, xla_lhlo::AddOp>, HloToLhloOpConverter<xla_hlo::AddOp>,
HloToLhloOpConverter<xla_hlo::AndOp, xla_lhlo::AndOp>, HloToLhloOpConverter<xla_hlo::AndOp>,
HloToLhloOpConverter<xla_hlo::BroadcastInDimOp, HloToLhloOpConverter<xla_hlo::BroadcastInDimOp>,
xla_lhlo::BroadcastInDimOp>, HloToLhloOpConverter<xla_hlo::CeilOp>,
HloToLhloOpConverter<xla_hlo::CeilOp, xla_lhlo::CeilOp>, HloToLhloOpConverter<xla_hlo::CompareOp>,
HloToLhloOpConverter<xla_hlo::CompareOp, xla_lhlo::CompareOp>, HloToLhloOpConverter<xla_hlo::ConstOp>,
HloToLhloOpConverter<xla_hlo::ConstOp, xla_lhlo::ConstOp>, HloToLhloOpConverter<xla_hlo::ConvertOp>,
HloToLhloOpConverter<xla_hlo::ConvertOp, xla_lhlo::ConvertOp>, HloToLhloOpConverter<xla_hlo::CopyOp>,
HloToLhloOpConverter<xla_hlo::CopyOp, xla_lhlo::CopyOp>, HloToLhloOpConverter<xla_hlo::CosOp>,
HloToLhloOpConverter<xla_hlo::CosOp, xla_lhlo::CosOp>, HloToLhloOpConverter<xla_hlo::DivOp>,
HloToLhloOpConverter<xla_hlo::DivOp, xla_lhlo::DivOp>, HloToLhloOpConverter<xla_hlo::ExpOp>,
HloToLhloOpConverter<xla_hlo::ExpOp, xla_lhlo::ExpOp>, HloToLhloOpConverter<xla_hlo::IotaOp>,
HloToLhloOpConverter<xla_hlo::IotaOp, xla_lhlo::IotaOp>, HloToLhloOpConverter<xla_hlo::LogOp>,
HloToLhloOpConverter<xla_hlo::LogOp, xla_lhlo::LogOp>, HloToLhloOpConverter<xla_hlo::MaxOp>,
HloToLhloOpConverter<xla_hlo::MaxOp, xla_lhlo::MaxOp>, HloToLhloOpConverter<xla_hlo::MinOp>,
HloToLhloOpConverter<xla_hlo::MinOp, xla_lhlo::MinOp>, HloToLhloOpConverter<xla_hlo::MulOp>,
HloToLhloOpConverter<xla_hlo::MulOp, xla_lhlo::MulOp>, HloToLhloOpConverter<xla_hlo::NegOp>,
HloToLhloOpConverter<xla_hlo::NegOp, xla_lhlo::NegOp>, HloToLhloOpConverter<xla_hlo::RemOp>,
HloToLhloOpConverter<xla_hlo::RemOp, xla_lhlo::RemOp>, HloToLhloOpConverter<xla_hlo::SelectOp>,
HloToLhloOpConverter<xla_hlo::SelectOp, xla_lhlo::SelectOp>, HloToLhloOpConverter<xla_hlo::SignOp>,
HloToLhloOpConverter<xla_hlo::SignOp, xla_lhlo::SignOp>, HloToLhloOpConverter<xla_hlo::SubOp>,
HloToLhloOpConverter<xla_hlo::SubOp, xla_lhlo::SubOp>, HloToLhloOpConverter<xla_hlo::TanhOp>,
HloToLhloOpConverter<xla_hlo::TanhOp, xla_lhlo::TanhOp>, HloToLhloReduceOpConverter,
HloToLHloReduceOpConverter,
StdToLhloReturnOpConverter,
HloToLhloTensorLoadOpConverter, HloToLhloTensorLoadOpConverter,
HloToLhloTensorStoreOpConverter HloToLhloTensorStoreOpConverter,
StdToLhloReturnOpConverter
>(context); >(context);
// clang-format on // clang-format on
} }

View File

@@ -31,11 +31,11 @@ namespace mlir {
namespace xla_lhlo { namespace xla_lhlo {
namespace { namespace {
template <typename LhloOp> template <typename LhloOpTy>
struct BinaryOpConverter : public OpRewritePattern<LhloOp> { struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
using OpRewritePattern<LhloOp>::OpRewritePattern; using OpRewritePattern<LhloOpTy>::OpRewritePattern;
PatternMatchResult matchAndRewrite(LhloOp op, PatternMatchResult matchAndRewrite(LhloOpTy op,
PatternRewriter& rewriter) const override { PatternRewriter& rewriter) const override {
const auto& lhs = op.lhs(); const auto& lhs = op.lhs();
const auto& rhs = op.rhs(); const auto& rhs = op.rhs();
@@ -56,8 +56,8 @@ struct BinaryOpConverter : public OpRewritePattern<LhloOp> {
} }
auto l = rewriter.create<LoadOp>(loc, lhs, induction_vars); auto l = rewriter.create<LoadOp>(loc, lhs, induction_vars);
auto r = rewriter.create<LoadOp>(loc, rhs, induction_vars); auto r = rewriter.create<LoadOp>(loc, rhs, induction_vars);
Value opResult = MapXlaOpToStdScalarOp<LhloOp>( Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<LhloOpTy>(
llvm::cast<LhloOp>(op), element_type, {l, r}, &rewriter); op, element_type, {l, r}, &rewriter);
if (opResult == nullptr) { if (opResult == nullptr) {
return this->matchFailure(); return this->matchFailure();
} }

View File

@@ -0,0 +1,70 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_
#include <type_traits>
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
namespace mlir {
namespace xla_hlo {
template <typename HloOpTy>
struct HloToLhloOpImpl {
using Type = std::false_type;
};
template <typename HloOpTy>
using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;
#define MAP_HLO_TO_LHLO(OpName) \
template <> \
struct HloToLhloOpImpl<xla_hlo::OpName> { \
using Type = xla_lhlo::OpName; \
}
MAP_HLO_TO_LHLO(AbsOp);
MAP_HLO_TO_LHLO(AddOp);
MAP_HLO_TO_LHLO(AndOp);
MAP_HLO_TO_LHLO(BroadcastInDimOp);
MAP_HLO_TO_LHLO(CeilOp);
MAP_HLO_TO_LHLO(ConstOp);
MAP_HLO_TO_LHLO(CompareOp);
MAP_HLO_TO_LHLO(ConvertOp);
MAP_HLO_TO_LHLO(CopyOp);
MAP_HLO_TO_LHLO(CosOp);
MAP_HLO_TO_LHLO(DivOp);
MAP_HLO_TO_LHLO(ExpOp);
MAP_HLO_TO_LHLO(IotaOp);
MAP_HLO_TO_LHLO(LogOp);
MAP_HLO_TO_LHLO(MaxOp);
MAP_HLO_TO_LHLO(MinOp);
MAP_HLO_TO_LHLO(MulOp);
MAP_HLO_TO_LHLO(NegOp);
MAP_HLO_TO_LHLO(ReduceOp);
MAP_HLO_TO_LHLO(RemOp);
MAP_HLO_TO_LHLO(SelectOp);
MAP_HLO_TO_LHLO(SignOp);
MAP_HLO_TO_LHLO(SubOp);
MAP_HLO_TO_LHLO(TanhOp);
#undef MAP_HLO_TO_LHLO
} // namespace xla_hlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_
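
For readers skimming the new header above: each MAP_HLO_TO_LHLO entry is just an explicit specialization of HloToLhloOpImpl, and any op without an entry resolves to std::false_type, which the scalar-op dispatch in map_xla_to_scalar_op.h keys on. A minimal self-contained sketch of the same device, using stand-in structs rather than the real MLIR op classes:

    // Stand-in op classes; the real header maps xla_hlo::* to xla_lhlo::*.
    #include <type_traits>

    namespace hlo  { struct AddOp {}; struct UnmappedOp {}; }
    namespace lhlo { struct AddOp {}; }

    // Primary template: "no known LHLO counterpart" is encoded as std::false_type.
    template <typename HloOpTy>
    struct HloToLhloOpImpl { using Type = std::false_type; };

    template <typename HloOpTy>
    using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;

    // Each MAP_HLO_TO_LHLO(OpName) line expands to a specialization of this shape.
    template <>
    struct HloToLhloOpImpl<hlo::AddOp> { using Type = lhlo::AddOp; };

    static_assert(std::is_same<HloToLhloOp<hlo::AddOp>, lhlo::AddOp>::value,
                  "a mapped op resolves to its LHLO counterpart");
    static_assert(std::is_same<HloToLhloOp<hlo::UnmappedOp>, std::false_type>::value,
                  "an unmapped op falls back to std::false_type");

    int main() {}

Keeping the fallback as std::false_type rather than a hard error lets callers detect "no LHLO counterpart" with a plain std::is_same check, which is exactly what the enable_if dispatch in map_xla_to_scalar_op.h does.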

View File

@@ -21,81 +21,63 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h"
namespace mlir { namespace mlir {
namespace xla_lhlo { namespace xla_lhlo {
namespace impl {
template <typename LHLO_BinaryOp> // A struct to map LhloBinaryOpTy type to the corresponding floating-point and
struct ScalarOp; // integer scalar operation types.
template <typename LhloBinaryOpTy>
struct LhloToScalarOp;
template <> template <>
struct ScalarOp<xla_lhlo::AddOp> { struct LhloToScalarOp<xla_lhlo::AddOp> {
using FOp = ::mlir::AddFOp; using FOp = ::mlir::AddFOp;
using IOp = ::mlir::AddIOp; using IOp = ::mlir::AddIOp;
}; };
template <> template <>
struct ScalarOp<xla_hlo::AddOp> { struct LhloToScalarOp<xla_lhlo::CompareOp> {
using FOp = ::mlir::AddFOp;
using IOp = ::mlir::AddIOp;
};
template <>
struct ScalarOp<xla_lhlo::CompareOp> {
using FOp = ::mlir::CmpFOp; using FOp = ::mlir::CmpFOp;
using IOp = ::mlir::CmpIOp; using IOp = ::mlir::CmpIOp;
}; };
template <> template <>
struct ScalarOp<xla_hlo::CompareOp> { struct LhloToScalarOp<xla_lhlo::DivOp> {
using FOp = ::mlir::CmpFOp;
using IOp = ::mlir::CmpIOp;
};
template <>
struct ScalarOp<xla_lhlo::DivOp> {
using FOp = ::mlir::DivFOp; using FOp = ::mlir::DivFOp;
using IOp = ::mlir::SignedDivIOp; using IOp = ::mlir::SignedDivIOp;
}; };
template <> template <>
struct ScalarOp<xla_hlo::DivOp> { struct LhloToScalarOp<xla_lhlo::MulOp> {
using FOp = ::mlir::DivFOp;
using IOp = ::mlir::SignedDivIOp;
};
template <>
struct ScalarOp<xla_lhlo::MulOp> {
using FOp = ::mlir::MulFOp; using FOp = ::mlir::MulFOp;
using IOp = ::mlir::MulIOp; using IOp = ::mlir::MulIOp;
}; };
template <> template <>
struct ScalarOp<xla_hlo::MulOp> { struct LhloToScalarOp<xla_lhlo::RemOp> {
using FOp = ::mlir::MulFOp;
using IOp = ::mlir::MulIOp;
};
template <>
struct ScalarOp<xla_lhlo::RemOp> {
using FOp = ::mlir::RemFOp; using FOp = ::mlir::RemFOp;
using IOp = ::mlir::SignedRemIOp; using IOp = ::mlir::SignedRemIOp;
}; };
template <> template <>
struct ScalarOp<xla_hlo::RemOp> { struct LhloToScalarOp<xla_lhlo::SubOp> {
using FOp = ::mlir::RemFOp;
using IOp = ::mlir::SignedRemIOp;
};
template <>
struct ScalarOp<xla_lhlo::SubOp> {
using FOp = ::mlir::SubFOp;
using IOp = ::mlir::SubIOp;
};
template <>
struct ScalarOp<xla_hlo::SubOp> {
using FOp = ::mlir::SubFOp; using FOp = ::mlir::SubFOp;
using IOp = ::mlir::SubIOp; using IOp = ::mlir::SubIOp;
}; };
template <typename XLA_BinaryOp> template <typename LhloBinaryOpTy>
using ScalarFOp = typename ScalarOp<XLA_BinaryOp>::FOp; struct ScalarOp {
template <typename XLA_BinaryOp> using FOp = typename LhloToScalarOp<LhloBinaryOpTy>::FOp;
using ScalarIOp = typename ScalarOp<XLA_BinaryOp>::IOp; using IOp = typename LhloToScalarOp<LhloBinaryOpTy>::IOp;
};
// Alias for the map from LHLO binary op type to STD floating-point op type.
template <typename LhloOp>
using ScalarFOp = typename ScalarOp<LhloOp>::FOp;
// Alias for the map from LHLO binary op type to STD integer op type.
template <typename LhloOp>
using ScalarIOp = typename ScalarOp<LhloOp>::IOp;
template <typename... Args> template <typename... Args>
struct MapXlaOpToStdScalarOpImpl { struct MapLhloOpToStdScalarOpImpl {
Value operator()(Location loc, ArrayRef<Type> result_types, Value operator()(Location loc, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) { ArrayRef<Value> args, OpBuilder* b) {
return nullptr; return nullptr;
@@ -103,7 +85,7 @@ struct MapXlaOpToStdScalarOpImpl {
}; };
template <typename StdScalarOp> template <typename StdScalarOp>
struct MapXlaOpToStdScalarOpImpl<StdScalarOp> { struct MapLhloOpToStdScalarOpImpl<StdScalarOp> {
Value operator()(Location loc, ArrayRef<Type> result_types, Value operator()(Location loc, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) { ArrayRef<Value> args, OpBuilder* b) {
return b->template create<StdScalarOp>(loc, result_types, args, mlir::None); return b->template create<StdScalarOp>(loc, result_types, args, mlir::None);
@@ -111,7 +93,7 @@ struct MapXlaOpToStdScalarOpImpl<StdScalarOp> {
}; };
template <typename SupportedType, typename StdScalarOp, typename... Args> template <typename SupportedType, typename StdScalarOp, typename... Args>
struct MapXlaOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> { struct MapLhloOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
Value operator()(Location loc, ArrayRef<Type> result_types, Value operator()(Location loc, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) { ArrayRef<Value> args, OpBuilder* b) {
Type element_type = args.front().getType(); Type element_type = args.front().getType();
@@ -119,52 +101,34 @@ struct MapXlaOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
return b->template create<StdScalarOp>(loc, result_types, args, return b->template create<StdScalarOp>(loc, result_types, args,
mlir::None); mlir::None);
} }
return MapXlaOpToStdScalarOpImpl<Args...>{}(loc, result_types, args, b); return MapLhloOpToStdScalarOpImpl<Args...>{}(loc, result_types, args, b);
} }
}; };
template <typename XlaOp> // Inserts the computation that corresponds to the body of the loop for lowered
inline Value MapXlaOpToStdScalarOp(XlaOp xla_op, ArrayRef<Type> result_types, // LHLO unary/binary op. Returns the value for the result.
ArrayRef<Value> args, OpBuilder* b) { template <typename LhloOpTy>
return MapXlaOpToStdScalarOpImpl<IntegerType, ScalarIOp<XlaOp>, FloatType, inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef<Type> result_types,
ScalarFOp<XlaOp>>{}(xla_op.getLoc(), ArrayRef<Value> args, OpBuilder* b) {
result_types, args, b); return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<LhloOpTy>, FloatType,
} ScalarFOp<LhloOpTy>>{}(loc, result_types,
args, b);
// TODO(ravishankarm): Find a way to reduce code-bloat in HLO and LHLO
// specialization.
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::AbsOp>(xla_lhlo::AbsOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
xla_op.getLoc(), result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::AbsOp>(xla_hlo::AbsOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
xla_op.getLoc(), result_types, args, b);
} }
template <> template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::AndOp>(xla_lhlo::AndOp xla_op, inline Value MapLhloOpToStdScalarOp<xla_lhlo::AbsOp>(
ArrayRef<Type> result_types, Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
ArrayRef<Value> args, OpBuilder* b) {
OpBuilder* b) { return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
return MapXlaOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}( loc, result_types, args, b);
xla_op.getLoc(), result_types, args, b);
} }
template <> template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::AndOp>(xla_hlo::AndOp xla_op, inline Value MapLhloOpToStdScalarOp<xla_lhlo::AndOp>(
ArrayRef<Type> result_types, Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
ArrayRef<Value> args, OpBuilder* b) {
OpBuilder* b) { return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
return MapXlaOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}( loc, result_types, args, b);
xla_op.getLoc(), result_types, args, b);
} }
template <typename PredicateType> template <typename PredicateType>
@@ -200,7 +164,8 @@ inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>(
} }
template <typename XLACompareOpTy> template <typename XLACompareOpTy>
inline Value MapXlaCompareOpToStdScalarOp(XLACompareOpTy xla_op, inline Value MapXlaCompareOpToStdScalarOp(Location loc,
StringRef comparison_direction,
ArrayRef<Type> result_types, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) { ArrayRef<Value> args, OpBuilder* b) {
const auto& lhs = args[0]; const auto& lhs = args[0];
@@ -208,101 +173,60 @@ inline Value MapXlaCompareOpToStdScalarOp(XLACompareOpTy xla_op,
Type element_type = lhs.getType(); Type element_type = lhs.getType();
if (element_type.isSignlessInteger()) { if (element_type.isSignlessInteger()) {
Optional<CmpIPredicate> predicate = Optional<CmpIPredicate> predicate =
getCmpPredicate<CmpIPredicate>(xla_op.comparison_direction()); getCmpPredicate<CmpIPredicate>(comparison_direction);
assert(predicate.hasValue() && "expected valid comparison direction"); assert(predicate.hasValue() && "expected valid comparison direction");
return b->create<ScalarIOp<XLACompareOpTy>>(xla_op.getLoc(), return b->create<ScalarIOp<XLACompareOpTy>>(loc, predicate.getValue(), lhs,
predicate.getValue(), lhs, rhs); rhs);
} }
if (element_type.isa<FloatType>()) { if (element_type.isa<FloatType>()) {
Optional<CmpFPredicate> predicate = Optional<CmpFPredicate> predicate =
getCmpPredicate<CmpFPredicate>(xla_op.comparison_direction()); getCmpPredicate<CmpFPredicate>(comparison_direction);
assert(predicate.hasValue() && "expected valid comparison direction"); assert(predicate.hasValue() && "expected valid comparison direction");
return b->create<ScalarFOp<XLACompareOpTy>>(xla_op.getLoc(), return b->create<ScalarFOp<XLACompareOpTy>>(loc, predicate.getValue(), lhs,
predicate.getValue(), lhs, rhs); rhs);
} }
return nullptr; return nullptr;
} }
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CompareOp>(
xla_lhlo::CompareOp xla_op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
return MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(xla_op, result_types,
args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::CompareOp>(
xla_hlo::CompareOp xla_op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
return MapXlaCompareOpToStdScalarOp<xla_hlo::CompareOp>(xla_op, result_types,
args, b);
}
template <> template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CopyOp>( inline Value MapLhloOpToStdScalarOp<xla_lhlo::CopyOp>(
xla_lhlo::CopyOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args, Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
return args.front(); return args.front();
} }
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::CopyOp>(xla_hlo::CopyOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return args.front();
}
template <> template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::ExpOp>(xla_lhlo::ExpOp xla_op, inline Value MapLhloOpToStdScalarOp<xla_lhlo::ExpOp>(
ArrayRef<Type> result_types, Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
xla_op.getLoc(), result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::ExpOp>(xla_hlo::ExpOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
xla_op.getLoc(), result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CeilOp>(
xla_lhlo::CeilOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}( return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
xla_op.getLoc(), result_types, args, b); loc, result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::CeilOp>(xla_hlo::CeilOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
xla_op.getLoc(), result_types, args, b);
} }
template <> template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::ConvertOp>( inline Value MapLhloOpToStdScalarOp<xla_lhlo::CeilOp>(
xla_lhlo::ConvertOp xla_op, ArrayRef<Type> result_types, Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
ArrayRef<Value> args, OpBuilder* b) { OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ConvertOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
Type sourceType = args.front().getType(); Type sourceType = args.front().getType();
Type targetType = result_types.front(); Type targetType = result_types.front();
if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) { if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) {
return b->create<mlir::SIToFPOp>(xla_op.getLoc(), result_types, args, return b->create<mlir::SIToFPOp>(loc, result_types, args, mlir::None);
mlir::None);
} else if (sourceType.isa<FloatType>() && targetType.isa<FloatType>()) { } else if (sourceType.isa<FloatType>() && targetType.isa<FloatType>()) {
FloatType src = sourceType.cast<FloatType>(); FloatType src = sourceType.cast<FloatType>();
FloatType res = targetType.cast<FloatType>(); FloatType res = targetType.cast<FloatType>();
if (src.getWidth() > res.getWidth()) { if (src.getWidth() > res.getWidth()) {
return b->create<mlir::FPTruncOp>(xla_op.getLoc(), result_types, args, return b->create<mlir::FPTruncOp>(loc, result_types, args, mlir::None);
mlir::None);
} else if (src.getWidth() < res.getWidth()) { } else if (src.getWidth() < res.getWidth()) {
return b->create<mlir::FPExtOp>(xla_op.getLoc(), result_types, args, return b->create<mlir::FPExtOp>(loc, result_types, args, mlir::None);
mlir::None);
} }
// No conversion is needed for the same width floats // No conversion is needed for the same width floats
return args.front(); return args.front();
@@ -311,10 +235,9 @@ inline Value MapXlaOpToStdScalarOp<xla_lhlo::ConvertOp>(
IntegerType src = sourceType.cast<IntegerType>(); IntegerType src = sourceType.cast<IntegerType>();
IntegerType res = targetType.cast<IntegerType>(); IntegerType res = targetType.cast<IntegerType>();
if (src.getWidth() > res.getWidth()) { if (src.getWidth() > res.getWidth()) {
return b->create<mlir::TruncateIOp>(xla_op.getLoc(), result_types, args, return b->create<mlir::TruncateIOp>(loc, result_types, args, mlir::None);
mlir::None);
} else if (src.getWidth() < res.getWidth()) { } else if (src.getWidth() < res.getWidth()) {
return b->create<mlir::ZeroExtendIOp>(xla_op.getLoc(), result_types, args, return b->create<mlir::ZeroExtendIOp>(loc, result_types, args,
mlir::None); mlir::None);
} }
// No conversion is needed for the same width integers // No conversion is needed for the same width integers
@@ -322,35 +245,25 @@ inline Value MapXlaOpToStdScalarOp<xla_lhlo::ConvertOp>(
} }
// TODO(dfki-ehna): Add other primitive type conversions // TODO(dfki-ehna): Add other primitive type conversions
// if (mlir::FpToSiOp::areCastCompatible(sourceType, targetType)) { // if (mlir::FpToSiOp::areCastCompatible(sourceType, targetType)) {
// return b.create<mlir::FpToSiOp>(xla_op.getLoc(), result_types, // return b.create<mlir::FpToSiOp>(loc, result_types,
// args,mlir::None); // args,mlir::None);
// } // }
return nullptr; return nullptr;
} }
template <> template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CosOp>(xla_lhlo::CosOp xla_op, inline Value MapLhloOpToStdScalarOp<xla_lhlo::CosOp>(
ArrayRef<Type> result_types, Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
ArrayRef<Value> args, OpBuilder* b) {
OpBuilder* b) { return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}(
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}( loc, result_types, args, b);
xla_op.getLoc(), result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::CosOp>(xla_hlo::CosOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}(
xla_op.getLoc(), result_types, args, b);
} }
/// Implements the conversion of XLA op to scalar op (to use within region of a /// Implements the conversion of XLA op to scalar op (to use within region of a
/// linalg.generic op) for compare-select style operations like min/max. /// linalg.generic op) for compare-select style operations like min/max.
template <typename... Args> template <typename... Args>
struct MapXlaCompareSelectOpToStdScalarOp { struct XlaCompareSelectOpToStdScalarOp {
Value operator()(Location loc, StringRef comparison_direction, static Value map(Location loc, StringRef comparison_direction,
ArrayRef<Type> result_types, ArrayRef<Value> args, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
return nullptr; return nullptr;
@@ -361,9 +274,9 @@ struct MapXlaCompareSelectOpToStdScalarOp {
/// dialect with a given predicate based on the element type of the operand. /// dialect with a given predicate based on the element type of the operand.
template <typename SupportedType, typename StdCompareOp, typename Predicate, template <typename SupportedType, typename StdCompareOp, typename Predicate,
typename... Args> typename... Args>
struct MapXlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, struct XlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
Predicate, Args...> { Args...> {
Value operator()(Location loc, StringRef comparison_direction, static Value map(Location loc, StringRef comparison_direction,
ArrayRef<Type> result_types, ArrayRef<Value> args, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
Type element_type = args.front().getType(); Type element_type = args.front().getType();
@@ -374,132 +287,130 @@ struct MapXlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp,
args[0], args[1]); args[0], args[1]);
return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]); return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]);
} }
return MapXlaCompareSelectOpToStdScalarOp<Args...>{}( return XlaCompareSelectOpToStdScalarOp<Args...>::map(
loc, comparison_direction, result_types, args, b); loc, comparison_direction, result_types, args, b);
} }
}; };
template <> template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::LogOp>(xla_lhlo::LogOp xla_op, inline Value MapLhloOpToStdScalarOp<xla_lhlo::LogOp>(
ArrayRef<Type> result_types, Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
xla_op.getLoc(), result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::LogOp>(xla_hlo::LogOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
xla_op.getLoc(), result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::MaxOp>(xla_lhlo::MaxOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaCompareSelectOpToStdScalarOp<
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>{}(xla_op.getLoc(), "GT",
result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::MaxOp>(xla_hlo::MaxOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaCompareSelectOpToStdScalarOp<
IntegerType, ScalarIOp<xla_hlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<xla_hlo::CompareOp>, CmpFPredicate>{}(xla_op.getLoc(), "GT",
result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::MinOp>(xla_lhlo::MinOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaCompareSelectOpToStdScalarOp<
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>{}(xla_op.getLoc(), "LT",
result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::MinOp>(xla_hlo::MinOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaCompareSelectOpToStdScalarOp<
IntegerType, ScalarIOp<xla_hlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<xla_hlo::CompareOp>, CmpFPredicate>{}(xla_op.getLoc(), "LT",
result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::NegOp>(xla_lhlo::NegOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
xla_op.getLoc(), result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::NegOp>(xla_hlo::NegOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
xla_op.getLoc(), result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::SelectOp>(
xla_lhlo::SelectOp xla_op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<::mlir::SelectOp>{}(xla_op.getLoc(),
result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::SelectOp>(
xla_hlo::SelectOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<::mlir::SelectOp>{}(xla_op.getLoc(), return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
result_types, args, b); loc, result_types, args, b);
} }
template <> template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::SignOp>( inline Value MapLhloOpToStdScalarOp<xla_lhlo::MaxOp>(
xla_lhlo::SignOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args, Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return XlaCompareSelectOpToStdScalarOp<
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "GT",
result_types, args,
b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MinOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return XlaCompareSelectOpToStdScalarOp<
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "LT",
result_types, args,
b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::NegOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SelectOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args,
b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SignOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
Type element_type = args.front().getType(); Type element_type = args.front().getType();
if (element_type.isa<FloatType>()) { if (element_type.isa<FloatType>()) {
FloatType float_type = element_type.cast<FloatType>(); FloatType float_type = element_type.cast<FloatType>();
APFloat const_value = float_type.isF32() ? APFloat(1.0f) : APFloat(1.0); APFloat const_value = float_type.isF32() ? APFloat(1.0f) : APFloat(1.0);
Value one = b->create<mlir::ConstantFloatOp>(xla_op.getLoc(), const_value, Value one = b->create<mlir::ConstantFloatOp>(loc, const_value, float_type);
float_type); return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
return b->create<::mlir::CopySignOp>(xla_op.getLoc(), result_types, one,
args[0]);
} }
return nullptr; return nullptr;
} }
template <> template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::TanhOp>( inline Value MapLhloOpToStdScalarOp<xla_lhlo::TanhOp>(
xla_lhlo::TanhOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args, Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}( return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}(
xla_op.getLoc(), result_types, args, b); loc, result_types, args, b);
} }
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::TanhOp>(xla_hlo::TanhOp xla_op, } // namespace impl
ArrayRef<Type> result_types,
ArrayRef<Value> args, struct XlaOpToStdScalarOp {
OpBuilder* b) { // Implementation for LHLO ops except xla_lhlo::CompareOp.
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}( template <typename XlaOpTy, typename LhloOpTy = XlaOpTy,
xla_op.getLoc(), result_types, args, b); typename = std::enable_if_t<
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
std::is_same<typename xla_hlo::HloToLhloOp<LhloOpTy>,
std::false_type>::value>>
static Value map(XlaOpTy op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
args, b);
}
// Implementation for HLO ops except xla_hlo::CompareOp.
template <typename XlaOpTy, typename LhloOpTy = xla_hlo::HloToLhloOp<XlaOpTy>,
typename = std::enable_if_t<
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
!std::is_same<LhloOpTy, std::false_type>::value>>
static Value map(XlaOpTy op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b, int i = 0) {
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
args, b);
}
// Implementation for xla_lhlo::CompareOp.
template <typename LhloOpTy, typename = std::enable_if_t<std::is_same<
LhloOpTy, xla_lhlo::CompareOp>::value>>
static Value map(xla_lhlo::CompareOp op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
auto comparison_direction = op.comparison_direction();
return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
op.getLoc(), comparison_direction, result_types, args, b);
}
// Implementation for xla_hlo::CompareOp.
template <typename HloOpTy, typename = std::enable_if_t<std::is_same<
HloOpTy, xla_hlo::CompareOp>::value>>
static Value map(xla_hlo::CompareOp op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
auto comparison_direction = op.comparison_direction();
return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
op.getLoc(), comparison_direction, result_types, args, b);
}
};
template <typename XlaOpTy>
inline Value MapXlaOpToStdScalarOp(XlaOpTy xla_op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
return XlaOpToStdScalarOp::map<XlaOpTy>(xla_op, result_types, args, b);
} }
} // namespace xla_lhlo } // namespace xla_lhlo
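
The XlaOpToStdScalarOp entry point above selects its body statically: LHLO ops are the ones the HloToLhloOp trait does not know (it yields std::false_type), HLO ops resolve through the trait to their LHLO counterpart, and CompareOp is peeled off separately because it needs comparison_direction; the otherwise-unused trailing parameters (unsigned i vs. int i) appear to exist only so the two map templates do not collide on signature. A self-contained sketch of that enable_if dispatch, with stand-in types instead of the real MLIR ops and the scalar-op construction reduced to a print:

    #include <iostream>
    #include <type_traits>

    namespace hlo  { struct AddOp {}; struct CompareOp {}; }
    namespace lhlo { struct AddOp {}; struct CompareOp {}; }

    // Same trait shape as map_hlo_to_lhlo_op.h, restricted to two ops.
    template <typename T> struct HloToLhloOpImpl { using Type = std::false_type; };
    template <> struct HloToLhloOpImpl<hlo::AddOp>     { using Type = lhlo::AddOp; };
    template <> struct HloToLhloOpImpl<hlo::CompareOp> { using Type = lhlo::CompareOp; };
    template <typename T> using HloToLhloOp = typename HloToLhloOpImpl<T>::Type;

    struct OpToScalarOp {
      // LHLO ops: the trait yields std::false_type for them; CompareOp excluded.
      template <typename OpTy, typename LhloOpTy = OpTy,
                typename = std::enable_if_t<
                    !std::is_same<LhloOpTy, lhlo::CompareOp>::value &&
                    std::is_same<HloToLhloOp<LhloOpTy>, std::false_type>::value>>
      static void map(OpTy, unsigned = 0) { std::cout << "lhlo op\n"; }

      // HLO ops: translated to their LHLO counterpart via the trait; CompareOp excluded.
      template <typename OpTy, typename LhloOpTy = HloToLhloOp<OpTy>,
                typename = std::enable_if_t<
                    !std::is_same<LhloOpTy, lhlo::CompareOp>::value &&
                    !std::is_same<LhloOpTy, std::false_type>::value>>
      static void map(OpTy, int = 0) { std::cout << "hlo op via trait\n"; }

      // CompareOp variants get dedicated overloads (the real code reads
      // comparison_direction here before delegating).
      static void map(lhlo::CompareOp) { std::cout << "lhlo compare\n"; }
      static void map(hlo::CompareOp)  { std::cout << "hlo compare\n"; }
    };

    int main() {
      OpToScalarOp::map(lhlo::AddOp{});     // picks the LHLO overload
      OpToScalarOp::map(hlo::AddOp{});      // picks the HLO overload
      OpToScalarOp::map(hlo::CompareOp{});  // picks the CompareOp overload
    }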

View File

@@ -149,8 +149,8 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
rewriter.setInsertionPointToEnd(block); rewriter.setInsertionPointToEnd(block);
// TODO(ravishankarm) : For now use the method in xla_lhlo namespace. That // TODO(ravishankarm) : For now use the method in xla_lhlo namespace. That
// method needs to be moved out of there. // method needs to be moved out of there.
Value opResult = xla_lhlo::MapXlaOpToStdScalarOp<OpTy>( Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<OpTy>(
llvm::cast<OpTy>(op), bodyResultTypes, bodyArgs, &rewriter); op, bodyResultTypes, bodyArgs, &rewriter);
if (!opResult) { if (!opResult) {
return ConversionPattern::matchFailure(); return ConversionPattern::matchFailure();
} }
@@ -180,9 +180,9 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs()); auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs()); auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
// TODO(ravishankarm) : Move this method out of xla_lhlo namespace. // TODO(ravishankarm) : Move this method out of xla_lhlo namespace.
Value opResult = xla_lhlo::MapXlaOpToStdScalarOp<LhloOp>( Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<LhloOp>(
llvm::cast<LhloOp>(lhlo_op), argType.getElementType(), lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
llvm::ArrayRef<Value>{lhs, rhs}, &rewriter); &rewriter);
rewriter.create<StoreOp>(loc, opResult, lhlo_op.out()); rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
rewriter.eraseOp(lhlo_op); rewriter.eraseOp(lhlo_op);
return ConversionPattern::matchSuccess(); return ConversionPattern::matchSuccess();