[XLA][MLIR] Reduce code bloat for LHLO->STD and HLO->STD patterns.
PiperOrigin-RevId: 298840878 Change-Id: I781008f01b5c8e478d75ba282db9aa78da546ea1
This commit is contained in:
parent
8a53e358fc
commit
4aaabc836e
@ -133,16 +133,25 @@ cc_library(
|
|||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "map_xla_to_scalar_op",
|
name = "map_xla_to_scalar_op",
|
||||||
srcs = [],
|
|
||||||
hdrs = ["transforms/map_xla_to_scalar_op.h"],
|
hdrs = ["transforms/map_xla_to_scalar_op.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":hlo",
|
":hlo",
|
||||||
":lhlo",
|
":lhlo",
|
||||||
|
":map_hlo_to_lhlo_op",
|
||||||
"@llvm-project//llvm:support",
|
"@llvm-project//llvm:support",
|
||||||
"@llvm-project//mlir:StandardOps",
|
"@llvm-project//mlir:StandardOps",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "map_hlo_to_lhlo_op",
|
||||||
|
hdrs = ["transforms/map_hlo_to_lhlo_op.h"],
|
||||||
|
deps = [
|
||||||
|
":hlo",
|
||||||
|
":lhlo",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "hlo_shape_derivation",
|
name = "hlo_shape_derivation",
|
||||||
srcs = [],
|
srcs = [],
|
||||||
@ -234,6 +243,7 @@ cc_library(
|
|||||||
":hlo",
|
":hlo",
|
||||||
":hlo_shape_derivation",
|
":hlo_shape_derivation",
|
||||||
":lhlo",
|
":lhlo",
|
||||||
|
":map_hlo_to_lhlo_op",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||||
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
|
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
|
||||||
#include "tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h"
|
#include "tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h"
|
||||||
|
#include "tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h"
|
||||||
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
||||||
#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
|
#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
|
||||||
|
|
||||||
@ -117,7 +118,7 @@ Value InsertAllocAndDealloc(Location loc, Value result,
|
|||||||
return alloc;
|
return alloc;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename HloOpTy, typename LhloOpTy>
|
template <typename HloOpTy>
|
||||||
class HloToLhloOpConverter : public ConversionPattern {
|
class HloToLhloOpConverter : public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
explicit HloToLhloOpConverter(MLIRContext* context)
|
explicit HloToLhloOpConverter(MLIRContext* context)
|
||||||
@ -147,14 +148,14 @@ class HloToLhloOpConverter : public ConversionPattern {
|
|||||||
op->getLoc(), result.value(), shape_value, &rewriter));
|
op->getLoc(), result.value(), shape_value, &rewriter));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rewriter.create<LhloOpTy>(op->getLoc(), llvm::None, buffer_args,
|
rewriter.create<xla_hlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
|
||||||
op->getAttrs());
|
buffer_args, op->getAttrs());
|
||||||
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
|
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct HloToLHloDynamicBroadcastInDimOpConverter
|
struct HloToLhloDynamicBroadcastInDimOpConverter
|
||||||
: public OpConversionPattern<xla_hlo::DynamicBroadcastInDimOp> {
|
: public OpConversionPattern<xla_hlo::DynamicBroadcastInDimOp> {
|
||||||
public:
|
public:
|
||||||
using OpConversionPattern::OpConversionPattern;
|
using OpConversionPattern::OpConversionPattern;
|
||||||
@ -178,7 +179,7 @@ struct HloToLHloDynamicBroadcastInDimOpConverter
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct HloToLHloReduceOpConverter
|
struct HloToLhloReduceOpConverter
|
||||||
: public OpConversionPattern<xla_hlo::ReduceOp> {
|
: public OpConversionPattern<xla_hlo::ReduceOp> {
|
||||||
public:
|
public:
|
||||||
using OpConversionPattern::OpConversionPattern;
|
using OpConversionPattern::OpConversionPattern;
|
||||||
@ -438,36 +439,35 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
|||||||
OwningRewritePatternList* patterns) {
|
OwningRewritePatternList* patterns) {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
patterns->insert<
|
patterns->insert<
|
||||||
HloToLHloDynamicBroadcastInDimOpConverter,
|
HloToLhloDynamicBroadcastInDimOpConverter,
|
||||||
HloToLhloFuncOpConverter,
|
HloToLhloFuncOpConverter,
|
||||||
HloToLhloOpConverter<xla_hlo::AbsOp, xla_lhlo::AbsOp>,
|
HloToLhloOpConverter<xla_hlo::AbsOp>,
|
||||||
HloToLhloOpConverter<xla_hlo::AddOp, xla_lhlo::AddOp>,
|
HloToLhloOpConverter<xla_hlo::AddOp>,
|
||||||
HloToLhloOpConverter<xla_hlo::AndOp, xla_lhlo::AndOp>,
|
HloToLhloOpConverter<xla_hlo::AndOp>,
|
||||||
HloToLhloOpConverter<xla_hlo::BroadcastInDimOp,
|
HloToLhloOpConverter<xla_hlo::BroadcastInDimOp>,
|
||||||
xla_lhlo::BroadcastInDimOp>,
|
HloToLhloOpConverter<xla_hlo::CeilOp>,
|
||||||
HloToLhloOpConverter<xla_hlo::CeilOp, xla_lhlo::CeilOp>,
|
HloToLhloOpConverter<xla_hlo::CompareOp>,
|
||||||
HloToLhloOpConverter<xla_hlo::CompareOp, xla_lhlo::CompareOp>,
|
HloToLhloOpConverter<xla_hlo::ConstOp>,
|
||||||
HloToLhloOpConverter<xla_hlo::ConstOp, xla_lhlo::ConstOp>,
|
HloToLhloOpConverter<xla_hlo::ConvertOp>,
|
||||||
HloToLhloOpConverter<xla_hlo::ConvertOp, xla_lhlo::ConvertOp>,
|
HloToLhloOpConverter<xla_hlo::CopyOp>,
|
||||||
HloToLhloOpConverter<xla_hlo::CopyOp, xla_lhlo::CopyOp>,
|
HloToLhloOpConverter<xla_hlo::CosOp>,
|
||||||
HloToLhloOpConverter<xla_hlo::CosOp, xla_lhlo::CosOp>,
|
HloToLhloOpConverter<xla_hlo::DivOp>,
|
||||||
HloToLhloOpConverter<xla_hlo::DivOp, xla_lhlo::DivOp>,
|
HloToLhloOpConverter<xla_hlo::ExpOp>,
|
||||||
HloToLhloOpConverter<xla_hlo::ExpOp, xla_lhlo::ExpOp>,
|
HloToLhloOpConverter<xla_hlo::IotaOp>,
|
||||||
HloToLhloOpConverter<xla_hlo::IotaOp, xla_lhlo::IotaOp>,
|
HloToLhloOpConverter<xla_hlo::LogOp>,
|
||||||
HloToLhloOpConverter<xla_hlo::LogOp, xla_lhlo::LogOp>,
|
HloToLhloOpConverter<xla_hlo::MaxOp>,
|
||||||
HloToLhloOpConverter<xla_hlo::MaxOp, xla_lhlo::MaxOp>,
|
HloToLhloOpConverter<xla_hlo::MinOp>,
|
||||||
HloToLhloOpConverter<xla_hlo::MinOp, xla_lhlo::MinOp>,
|
HloToLhloOpConverter<xla_hlo::MulOp>,
|
||||||
HloToLhloOpConverter<xla_hlo::MulOp, xla_lhlo::MulOp>,
|
HloToLhloOpConverter<xla_hlo::NegOp>,
|
||||||
HloToLhloOpConverter<xla_hlo::NegOp, xla_lhlo::NegOp>,
|
HloToLhloOpConverter<xla_hlo::RemOp>,
|
||||||
HloToLhloOpConverter<xla_hlo::RemOp, xla_lhlo::RemOp>,
|
HloToLhloOpConverter<xla_hlo::SelectOp>,
|
||||||
HloToLhloOpConverter<xla_hlo::SelectOp, xla_lhlo::SelectOp>,
|
HloToLhloOpConverter<xla_hlo::SignOp>,
|
||||||
HloToLhloOpConverter<xla_hlo::SignOp, xla_lhlo::SignOp>,
|
HloToLhloOpConverter<xla_hlo::SubOp>,
|
||||||
HloToLhloOpConverter<xla_hlo::SubOp, xla_lhlo::SubOp>,
|
HloToLhloOpConverter<xla_hlo::TanhOp>,
|
||||||
HloToLhloOpConverter<xla_hlo::TanhOp, xla_lhlo::TanhOp>,
|
HloToLhloReduceOpConverter,
|
||||||
HloToLHloReduceOpConverter,
|
|
||||||
StdToLhloReturnOpConverter,
|
|
||||||
HloToLhloTensorLoadOpConverter,
|
HloToLhloTensorLoadOpConverter,
|
||||||
HloToLhloTensorStoreOpConverter
|
HloToLhloTensorStoreOpConverter,
|
||||||
|
StdToLhloReturnOpConverter
|
||||||
>(context);
|
>(context);
|
||||||
// clang-format on
|
// clang-format on
|
||||||
}
|
}
|
||||||
|
@ -31,11 +31,11 @@ namespace mlir {
|
|||||||
namespace xla_lhlo {
|
namespace xla_lhlo {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
template <typename LhloOp>
|
template <typename LhloOpTy>
|
||||||
struct BinaryOpConverter : public OpRewritePattern<LhloOp> {
|
struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
|
||||||
using OpRewritePattern<LhloOp>::OpRewritePattern;
|
using OpRewritePattern<LhloOpTy>::OpRewritePattern;
|
||||||
|
|
||||||
PatternMatchResult matchAndRewrite(LhloOp op,
|
PatternMatchResult matchAndRewrite(LhloOpTy op,
|
||||||
PatternRewriter& rewriter) const override {
|
PatternRewriter& rewriter) const override {
|
||||||
const auto& lhs = op.lhs();
|
const auto& lhs = op.lhs();
|
||||||
const auto& rhs = op.rhs();
|
const auto& rhs = op.rhs();
|
||||||
@ -56,8 +56,8 @@ struct BinaryOpConverter : public OpRewritePattern<LhloOp> {
|
|||||||
}
|
}
|
||||||
auto l = rewriter.create<LoadOp>(loc, lhs, induction_vars);
|
auto l = rewriter.create<LoadOp>(loc, lhs, induction_vars);
|
||||||
auto r = rewriter.create<LoadOp>(loc, rhs, induction_vars);
|
auto r = rewriter.create<LoadOp>(loc, rhs, induction_vars);
|
||||||
Value opResult = MapXlaOpToStdScalarOp<LhloOp>(
|
Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<LhloOpTy>(
|
||||||
llvm::cast<LhloOp>(op), element_type, {l, r}, &rewriter);
|
op, element_type, {l, r}, &rewriter);
|
||||||
if (opResult == nullptr) {
|
if (opResult == nullptr) {
|
||||||
return this->matchFailure();
|
return this->matchFailure();
|
||||||
}
|
}
|
||||||
|
70
tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h
Normal file
70
tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_
|
||||||
|
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||||
|
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla_hlo {
|
||||||
|
|
||||||
|
template <typename HloOpTy>
|
||||||
|
struct HloToLhloOpImpl {
|
||||||
|
using Type = std::false_type;
|
||||||
|
};
|
||||||
|
template <typename HloOpTy>
|
||||||
|
using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;
|
||||||
|
|
||||||
|
#define MAP_HLO_TO_LHLO(OpName) \
|
||||||
|
template <> \
|
||||||
|
struct HloToLhloOpImpl<xla_hlo::OpName> { \
|
||||||
|
using Type = xla_lhlo::OpName; \
|
||||||
|
}
|
||||||
|
|
||||||
|
MAP_HLO_TO_LHLO(AbsOp);
|
||||||
|
MAP_HLO_TO_LHLO(AddOp);
|
||||||
|
MAP_HLO_TO_LHLO(AndOp);
|
||||||
|
MAP_HLO_TO_LHLO(BroadcastInDimOp);
|
||||||
|
MAP_HLO_TO_LHLO(CeilOp);
|
||||||
|
MAP_HLO_TO_LHLO(ConstOp);
|
||||||
|
MAP_HLO_TO_LHLO(CompareOp);
|
||||||
|
MAP_HLO_TO_LHLO(ConvertOp);
|
||||||
|
MAP_HLO_TO_LHLO(CopyOp);
|
||||||
|
MAP_HLO_TO_LHLO(CosOp);
|
||||||
|
MAP_HLO_TO_LHLO(DivOp);
|
||||||
|
MAP_HLO_TO_LHLO(ExpOp);
|
||||||
|
MAP_HLO_TO_LHLO(IotaOp);
|
||||||
|
MAP_HLO_TO_LHLO(LogOp);
|
||||||
|
MAP_HLO_TO_LHLO(MaxOp);
|
||||||
|
MAP_HLO_TO_LHLO(MinOp);
|
||||||
|
MAP_HLO_TO_LHLO(MulOp);
|
||||||
|
MAP_HLO_TO_LHLO(NegOp);
|
||||||
|
MAP_HLO_TO_LHLO(ReduceOp);
|
||||||
|
MAP_HLO_TO_LHLO(RemOp);
|
||||||
|
MAP_HLO_TO_LHLO(SelectOp);
|
||||||
|
MAP_HLO_TO_LHLO(SignOp);
|
||||||
|
MAP_HLO_TO_LHLO(SubOp);
|
||||||
|
MAP_HLO_TO_LHLO(TanhOp);
|
||||||
|
|
||||||
|
#undef MAP_HLO_TO_LHLO
|
||||||
|
|
||||||
|
} // namespace xla_hlo
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_
|
@ -21,81 +21,63 @@ limitations under the License.
|
|||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||||
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
|
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
|
||||||
|
#include "tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace xla_lhlo {
|
namespace xla_lhlo {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
template <typename LHLO_BinaryOp>
|
// A struct to map LhloBinaryOpTy type to the corresponding floating-point and
|
||||||
struct ScalarOp;
|
// integer scalar operation types.
|
||||||
|
template <typename LhloBinaryOpTy>
|
||||||
|
struct LhloToScalarOp;
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct ScalarOp<xla_lhlo::AddOp> {
|
struct LhloToScalarOp<xla_lhlo::AddOp> {
|
||||||
using FOp = ::mlir::AddFOp;
|
using FOp = ::mlir::AddFOp;
|
||||||
using IOp = ::mlir::AddIOp;
|
using IOp = ::mlir::AddIOp;
|
||||||
};
|
};
|
||||||
template <>
|
template <>
|
||||||
struct ScalarOp<xla_hlo::AddOp> {
|
struct LhloToScalarOp<xla_lhlo::CompareOp> {
|
||||||
using FOp = ::mlir::AddFOp;
|
|
||||||
using IOp = ::mlir::AddIOp;
|
|
||||||
};
|
|
||||||
template <>
|
|
||||||
struct ScalarOp<xla_lhlo::CompareOp> {
|
|
||||||
using FOp = ::mlir::CmpFOp;
|
using FOp = ::mlir::CmpFOp;
|
||||||
using IOp = ::mlir::CmpIOp;
|
using IOp = ::mlir::CmpIOp;
|
||||||
};
|
};
|
||||||
template <>
|
template <>
|
||||||
struct ScalarOp<xla_hlo::CompareOp> {
|
struct LhloToScalarOp<xla_lhlo::DivOp> {
|
||||||
using FOp = ::mlir::CmpFOp;
|
|
||||||
using IOp = ::mlir::CmpIOp;
|
|
||||||
};
|
|
||||||
template <>
|
|
||||||
struct ScalarOp<xla_lhlo::DivOp> {
|
|
||||||
using FOp = ::mlir::DivFOp;
|
using FOp = ::mlir::DivFOp;
|
||||||
using IOp = ::mlir::SignedDivIOp;
|
using IOp = ::mlir::SignedDivIOp;
|
||||||
};
|
};
|
||||||
template <>
|
template <>
|
||||||
struct ScalarOp<xla_hlo::DivOp> {
|
struct LhloToScalarOp<xla_lhlo::MulOp> {
|
||||||
using FOp = ::mlir::DivFOp;
|
|
||||||
using IOp = ::mlir::SignedDivIOp;
|
|
||||||
};
|
|
||||||
template <>
|
|
||||||
struct ScalarOp<xla_lhlo::MulOp> {
|
|
||||||
using FOp = ::mlir::MulFOp;
|
using FOp = ::mlir::MulFOp;
|
||||||
using IOp = ::mlir::MulIOp;
|
using IOp = ::mlir::MulIOp;
|
||||||
};
|
};
|
||||||
template <>
|
template <>
|
||||||
struct ScalarOp<xla_hlo::MulOp> {
|
struct LhloToScalarOp<xla_lhlo::RemOp> {
|
||||||
using FOp = ::mlir::MulFOp;
|
|
||||||
using IOp = ::mlir::MulIOp;
|
|
||||||
};
|
|
||||||
template <>
|
|
||||||
struct ScalarOp<xla_lhlo::RemOp> {
|
|
||||||
using FOp = ::mlir::RemFOp;
|
using FOp = ::mlir::RemFOp;
|
||||||
using IOp = ::mlir::SignedRemIOp;
|
using IOp = ::mlir::SignedRemIOp;
|
||||||
};
|
};
|
||||||
template <>
|
template <>
|
||||||
struct ScalarOp<xla_hlo::RemOp> {
|
struct LhloToScalarOp<xla_lhlo::SubOp> {
|
||||||
using FOp = ::mlir::RemFOp;
|
|
||||||
using IOp = ::mlir::SignedRemIOp;
|
|
||||||
};
|
|
||||||
template <>
|
|
||||||
struct ScalarOp<xla_lhlo::SubOp> {
|
|
||||||
using FOp = ::mlir::SubFOp;
|
|
||||||
using IOp = ::mlir::SubIOp;
|
|
||||||
};
|
|
||||||
template <>
|
|
||||||
struct ScalarOp<xla_hlo::SubOp> {
|
|
||||||
using FOp = ::mlir::SubFOp;
|
using FOp = ::mlir::SubFOp;
|
||||||
using IOp = ::mlir::SubIOp;
|
using IOp = ::mlir::SubIOp;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename XLA_BinaryOp>
|
template <typename LhloBinaryOpTy>
|
||||||
using ScalarFOp = typename ScalarOp<XLA_BinaryOp>::FOp;
|
struct ScalarOp {
|
||||||
template <typename XLA_BinaryOp>
|
using FOp = typename LhloToScalarOp<LhloBinaryOpTy>::FOp;
|
||||||
using ScalarIOp = typename ScalarOp<XLA_BinaryOp>::IOp;
|
using IOp = typename LhloToScalarOp<LhloBinaryOpTy>::IOp;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Alias for the map from LHLO binary op type to STD floating-point op type.
|
||||||
|
template <typename LhloOp>
|
||||||
|
using ScalarFOp = typename ScalarOp<LhloOp>::FOp;
|
||||||
|
// Alias for the map from LHLO binary op type to STD integer op type.
|
||||||
|
template <typename LhloOp>
|
||||||
|
using ScalarIOp = typename ScalarOp<LhloOp>::IOp;
|
||||||
|
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
struct MapXlaOpToStdScalarOpImpl {
|
struct MapLhloOpToStdScalarOpImpl {
|
||||||
Value operator()(Location loc, ArrayRef<Type> result_types,
|
Value operator()(Location loc, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args, OpBuilder* b) {
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -103,7 +85,7 @@ struct MapXlaOpToStdScalarOpImpl {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename StdScalarOp>
|
template <typename StdScalarOp>
|
||||||
struct MapXlaOpToStdScalarOpImpl<StdScalarOp> {
|
struct MapLhloOpToStdScalarOpImpl<StdScalarOp> {
|
||||||
Value operator()(Location loc, ArrayRef<Type> result_types,
|
Value operator()(Location loc, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args, OpBuilder* b) {
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
return b->template create<StdScalarOp>(loc, result_types, args, mlir::None);
|
return b->template create<StdScalarOp>(loc, result_types, args, mlir::None);
|
||||||
@ -111,7 +93,7 @@ struct MapXlaOpToStdScalarOpImpl<StdScalarOp> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename SupportedType, typename StdScalarOp, typename... Args>
|
template <typename SupportedType, typename StdScalarOp, typename... Args>
|
||||||
struct MapXlaOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
|
struct MapLhloOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
|
||||||
Value operator()(Location loc, ArrayRef<Type> result_types,
|
Value operator()(Location loc, ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args, OpBuilder* b) {
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
Type element_type = args.front().getType();
|
Type element_type = args.front().getType();
|
||||||
@ -119,52 +101,34 @@ struct MapXlaOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
|
|||||||
return b->template create<StdScalarOp>(loc, result_types, args,
|
return b->template create<StdScalarOp>(loc, result_types, args,
|
||||||
mlir::None);
|
mlir::None);
|
||||||
}
|
}
|
||||||
return MapXlaOpToStdScalarOpImpl<Args...>{}(loc, result_types, args, b);
|
return MapLhloOpToStdScalarOpImpl<Args...>{}(loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename XlaOp>
|
// Inserts the computation that corresponds to the body of the loop for lowered
|
||||||
inline Value MapXlaOpToStdScalarOp(XlaOp xla_op, ArrayRef<Type> result_types,
|
// LHLO unary/binary op. Returns the value for the result.
|
||||||
ArrayRef<Value> args, OpBuilder* b) {
|
template <typename LhloOpTy>
|
||||||
return MapXlaOpToStdScalarOpImpl<IntegerType, ScalarIOp<XlaOp>, FloatType,
|
inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef<Type> result_types,
|
||||||
ScalarFOp<XlaOp>>{}(xla_op.getLoc(),
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
result_types, args, b);
|
return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<LhloOpTy>, FloatType,
|
||||||
}
|
ScalarFOp<LhloOpTy>>{}(loc, result_types,
|
||||||
|
args, b);
|
||||||
// TODO(ravishankarm): Find a way to reduce code-bloat in HLO and LHLO
|
|
||||||
// specialization.
|
|
||||||
template <>
|
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::AbsOp>(xla_lhlo::AbsOp xla_op,
|
|
||||||
ArrayRef<Type> result_types,
|
|
||||||
ArrayRef<Value> args,
|
|
||||||
OpBuilder* b) {
|
|
||||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
|
|
||||||
xla_op.getLoc(), result_types, args, b);
|
|
||||||
}
|
|
||||||
template <>
|
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::AbsOp>(xla_hlo::AbsOp xla_op,
|
|
||||||
ArrayRef<Type> result_types,
|
|
||||||
ArrayRef<Value> args,
|
|
||||||
OpBuilder* b) {
|
|
||||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
|
|
||||||
xla_op.getLoc(), result_types, args, b);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::AndOp>(xla_lhlo::AndOp xla_op,
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AbsOp>(
|
||||||
ArrayRef<Type> result_types,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
ArrayRef<Value> args,
|
OpBuilder* b) {
|
||||||
OpBuilder* b) {
|
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
|
||||||
return MapXlaOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
|
loc, result_types, args, b);
|
||||||
xla_op.getLoc(), result_types, args, b);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::AndOp>(xla_hlo::AndOp xla_op,
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AndOp>(
|
||||||
ArrayRef<Type> result_types,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
ArrayRef<Value> args,
|
OpBuilder* b) {
|
||||||
OpBuilder* b) {
|
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
|
||||||
return MapXlaOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
|
loc, result_types, args, b);
|
||||||
xla_op.getLoc(), result_types, args, b);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename PredicateType>
|
template <typename PredicateType>
|
||||||
@ -200,7 +164,8 @@ inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename XLACompareOpTy>
|
template <typename XLACompareOpTy>
|
||||||
inline Value MapXlaCompareOpToStdScalarOp(XLACompareOpTy xla_op,
|
inline Value MapXlaCompareOpToStdScalarOp(Location loc,
|
||||||
|
StringRef comparison_direction,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args, OpBuilder* b) {
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
const auto& lhs = args[0];
|
const auto& lhs = args[0];
|
||||||
@ -208,101 +173,60 @@ inline Value MapXlaCompareOpToStdScalarOp(XLACompareOpTy xla_op,
|
|||||||
Type element_type = lhs.getType();
|
Type element_type = lhs.getType();
|
||||||
if (element_type.isSignlessInteger()) {
|
if (element_type.isSignlessInteger()) {
|
||||||
Optional<CmpIPredicate> predicate =
|
Optional<CmpIPredicate> predicate =
|
||||||
getCmpPredicate<CmpIPredicate>(xla_op.comparison_direction());
|
getCmpPredicate<CmpIPredicate>(comparison_direction);
|
||||||
assert(predicate.hasValue() && "expected valid comparison direction");
|
assert(predicate.hasValue() && "expected valid comparison direction");
|
||||||
return b->create<ScalarIOp<XLACompareOpTy>>(xla_op.getLoc(),
|
return b->create<ScalarIOp<XLACompareOpTy>>(loc, predicate.getValue(), lhs,
|
||||||
predicate.getValue(), lhs, rhs);
|
rhs);
|
||||||
}
|
}
|
||||||
if (element_type.isa<FloatType>()) {
|
if (element_type.isa<FloatType>()) {
|
||||||
Optional<CmpFPredicate> predicate =
|
Optional<CmpFPredicate> predicate =
|
||||||
getCmpPredicate<CmpFPredicate>(xla_op.comparison_direction());
|
getCmpPredicate<CmpFPredicate>(comparison_direction);
|
||||||
assert(predicate.hasValue() && "expected valid comparison direction");
|
assert(predicate.hasValue() && "expected valid comparison direction");
|
||||||
return b->create<ScalarFOp<XLACompareOpTy>>(xla_op.getLoc(),
|
return b->create<ScalarFOp<XLACompareOpTy>>(loc, predicate.getValue(), lhs,
|
||||||
predicate.getValue(), lhs, rhs);
|
rhs);
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
template <>
|
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CompareOp>(
|
|
||||||
xla_lhlo::CompareOp xla_op, ArrayRef<Type> result_types,
|
|
||||||
ArrayRef<Value> args, OpBuilder* b) {
|
|
||||||
return MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(xla_op, result_types,
|
|
||||||
args, b);
|
|
||||||
}
|
|
||||||
template <>
|
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::CompareOp>(
|
|
||||||
xla_hlo::CompareOp xla_op, ArrayRef<Type> result_types,
|
|
||||||
ArrayRef<Value> args, OpBuilder* b) {
|
|
||||||
return MapXlaCompareOpToStdScalarOp<xla_hlo::CompareOp>(xla_op, result_types,
|
|
||||||
args, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CopyOp>(
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CopyOp>(
|
||||||
xla_lhlo::CopyOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return args.front();
|
return args.front();
|
||||||
}
|
}
|
||||||
template <>
|
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::CopyOp>(xla_hlo::CopyOp xla_op,
|
|
||||||
ArrayRef<Type> result_types,
|
|
||||||
ArrayRef<Value> args,
|
|
||||||
OpBuilder* b) {
|
|
||||||
return args.front();
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::ExpOp>(xla_lhlo::ExpOp xla_op,
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ExpOp>(
|
||||||
ArrayRef<Type> result_types,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
ArrayRef<Value> args,
|
|
||||||
OpBuilder* b) {
|
|
||||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
|
|
||||||
xla_op.getLoc(), result_types, args, b);
|
|
||||||
}
|
|
||||||
template <>
|
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::ExpOp>(xla_hlo::ExpOp xla_op,
|
|
||||||
ArrayRef<Type> result_types,
|
|
||||||
ArrayRef<Value> args,
|
|
||||||
OpBuilder* b) {
|
|
||||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
|
|
||||||
xla_op.getLoc(), result_types, args, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CeilOp>(
|
|
||||||
xla_lhlo::CeilOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
|
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
|
||||||
xla_op.getLoc(), result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
|
||||||
template <>
|
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::CeilOp>(xla_hlo::CeilOp xla_op,
|
|
||||||
ArrayRef<Type> result_types,
|
|
||||||
ArrayRef<Value> args,
|
|
||||||
OpBuilder* b) {
|
|
||||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
|
|
||||||
xla_op.getLoc(), result_types, args, b);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::ConvertOp>(
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CeilOp>(
|
||||||
xla_lhlo::ConvertOp xla_op, ArrayRef<Type> result_types,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
ArrayRef<Value> args, OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
|
||||||
|
loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ConvertOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
Type sourceType = args.front().getType();
|
Type sourceType = args.front().getType();
|
||||||
Type targetType = result_types.front();
|
Type targetType = result_types.front();
|
||||||
|
|
||||||
if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) {
|
if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) {
|
||||||
return b->create<mlir::SIToFPOp>(xla_op.getLoc(), result_types, args,
|
return b->create<mlir::SIToFPOp>(loc, result_types, args, mlir::None);
|
||||||
mlir::None);
|
|
||||||
} else if (sourceType.isa<FloatType>() && targetType.isa<FloatType>()) {
|
} else if (sourceType.isa<FloatType>() && targetType.isa<FloatType>()) {
|
||||||
FloatType src = sourceType.cast<FloatType>();
|
FloatType src = sourceType.cast<FloatType>();
|
||||||
FloatType res = targetType.cast<FloatType>();
|
FloatType res = targetType.cast<FloatType>();
|
||||||
if (src.getWidth() > res.getWidth()) {
|
if (src.getWidth() > res.getWidth()) {
|
||||||
return b->create<mlir::FPTruncOp>(xla_op.getLoc(), result_types, args,
|
return b->create<mlir::FPTruncOp>(loc, result_types, args, mlir::None);
|
||||||
mlir::None);
|
|
||||||
} else if (src.getWidth() < res.getWidth()) {
|
} else if (src.getWidth() < res.getWidth()) {
|
||||||
return b->create<mlir::FPExtOp>(xla_op.getLoc(), result_types, args,
|
return b->create<mlir::FPExtOp>(loc, result_types, args, mlir::None);
|
||||||
mlir::None);
|
|
||||||
}
|
}
|
||||||
// No conversion is needed for the same width floats
|
// No conversion is needed for the same width floats
|
||||||
return args.front();
|
return args.front();
|
||||||
@ -311,10 +235,9 @@ inline Value MapXlaOpToStdScalarOp<xla_lhlo::ConvertOp>(
|
|||||||
IntegerType src = sourceType.cast<IntegerType>();
|
IntegerType src = sourceType.cast<IntegerType>();
|
||||||
IntegerType res = targetType.cast<IntegerType>();
|
IntegerType res = targetType.cast<IntegerType>();
|
||||||
if (src.getWidth() > res.getWidth()) {
|
if (src.getWidth() > res.getWidth()) {
|
||||||
return b->create<mlir::TruncateIOp>(xla_op.getLoc(), result_types, args,
|
return b->create<mlir::TruncateIOp>(loc, result_types, args, mlir::None);
|
||||||
mlir::None);
|
|
||||||
} else if (src.getWidth() < res.getWidth()) {
|
} else if (src.getWidth() < res.getWidth()) {
|
||||||
return b->create<mlir::ZeroExtendIOp>(xla_op.getLoc(), result_types, args,
|
return b->create<mlir::ZeroExtendIOp>(loc, result_types, args,
|
||||||
mlir::None);
|
mlir::None);
|
||||||
}
|
}
|
||||||
// No conversion is needed for the same width integers
|
// No conversion is needed for the same width integers
|
||||||
@ -322,35 +245,25 @@ inline Value MapXlaOpToStdScalarOp<xla_lhlo::ConvertOp>(
|
|||||||
}
|
}
|
||||||
// TODO(dfki-ehna): Add other primitive type conversions
|
// TODO(dfki-ehna): Add other primitive type conversions
|
||||||
// if (mlir::FpToSiOp::areCastCompatible(sourceType, targetType)) {
|
// if (mlir::FpToSiOp::areCastCompatible(sourceType, targetType)) {
|
||||||
// return b.create<mlir::FpToSiOp>(xla_op.getLoc(), result_types,
|
// return b.create<mlir::FpToSiOp>(loc, result_types,
|
||||||
// args,mlir::None);
|
// args,mlir::None);
|
||||||
// }
|
// }
|
||||||
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CosOp>(xla_lhlo::CosOp xla_op,
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CosOp>(
|
||||||
ArrayRef<Type> result_types,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
ArrayRef<Value> args,
|
OpBuilder* b) {
|
||||||
OpBuilder* b) {
|
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}(
|
||||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}(
|
loc, result_types, args, b);
|
||||||
xla_op.getLoc(), result_types, args, b);
|
|
||||||
}
|
|
||||||
template <>
|
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::CosOp>(xla_hlo::CosOp xla_op,
|
|
||||||
ArrayRef<Type> result_types,
|
|
||||||
ArrayRef<Value> args,
|
|
||||||
OpBuilder* b) {
|
|
||||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}(
|
|
||||||
xla_op.getLoc(), result_types, args, b);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Implements the conversion of XLA op to scalar op (to use within region of a
|
/// Implements the conversion of XLA op to scalar op (to use within region of a
|
||||||
/// linalg.generic op) for compare-select style operations like min/max.
|
/// linalg.generic op) for compare-select style operations like min/max.
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
struct MapXlaCompareSelectOpToStdScalarOp {
|
struct XlaCompareSelectOpToStdScalarOp {
|
||||||
Value operator()(Location loc, StringRef comparison_direction,
|
static Value map(Location loc, StringRef comparison_direction,
|
||||||
ArrayRef<Type> result_types, ArrayRef<Value> args,
|
ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -361,9 +274,9 @@ struct MapXlaCompareSelectOpToStdScalarOp {
|
|||||||
/// dialect with a given predicate based on the element type of the operand.
|
/// dialect with a given predicate based on the element type of the operand.
|
||||||
template <typename SupportedType, typename StdCompareOp, typename Predicate,
|
template <typename SupportedType, typename StdCompareOp, typename Predicate,
|
||||||
typename... Args>
|
typename... Args>
|
||||||
struct MapXlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp,
|
struct XlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
|
||||||
Predicate, Args...> {
|
Args...> {
|
||||||
Value operator()(Location loc, StringRef comparison_direction,
|
static Value map(Location loc, StringRef comparison_direction,
|
||||||
ArrayRef<Type> result_types, ArrayRef<Value> args,
|
ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
Type element_type = args.front().getType();
|
Type element_type = args.front().getType();
|
||||||
@ -374,132 +287,130 @@ struct MapXlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp,
|
|||||||
args[0], args[1]);
|
args[0], args[1]);
|
||||||
return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]);
|
return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]);
|
||||||
}
|
}
|
||||||
return MapXlaCompareSelectOpToStdScalarOp<Args...>{}(
|
return XlaCompareSelectOpToStdScalarOp<Args...>::map(
|
||||||
loc, comparison_direction, result_types, args, b);
|
loc, comparison_direction, result_types, args, b);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::LogOp>(xla_lhlo::LogOp xla_op,
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::LogOp>(
|
||||||
ArrayRef<Type> result_types,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
ArrayRef<Value> args,
|
|
||||||
OpBuilder* b) {
|
|
||||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
|
|
||||||
xla_op.getLoc(), result_types, args, b);
|
|
||||||
}
|
|
||||||
template <>
|
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::LogOp>(xla_hlo::LogOp xla_op,
|
|
||||||
ArrayRef<Type> result_types,
|
|
||||||
ArrayRef<Value> args,
|
|
||||||
OpBuilder* b) {
|
|
||||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
|
|
||||||
xla_op.getLoc(), result_types, args, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::MaxOp>(xla_lhlo::MaxOp xla_op,
|
|
||||||
ArrayRef<Type> result_types,
|
|
||||||
ArrayRef<Value> args,
|
|
||||||
OpBuilder* b) {
|
|
||||||
return MapXlaCompareSelectOpToStdScalarOp<
|
|
||||||
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
|
|
||||||
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>{}(xla_op.getLoc(), "GT",
|
|
||||||
result_types, args, b);
|
|
||||||
}
|
|
||||||
template <>
|
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::MaxOp>(xla_hlo::MaxOp xla_op,
|
|
||||||
ArrayRef<Type> result_types,
|
|
||||||
ArrayRef<Value> args,
|
|
||||||
OpBuilder* b) {
|
|
||||||
return MapXlaCompareSelectOpToStdScalarOp<
|
|
||||||
IntegerType, ScalarIOp<xla_hlo::CompareOp>, CmpIPredicate, FloatType,
|
|
||||||
ScalarFOp<xla_hlo::CompareOp>, CmpFPredicate>{}(xla_op.getLoc(), "GT",
|
|
||||||
result_types, args, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::MinOp>(xla_lhlo::MinOp xla_op,
|
|
||||||
ArrayRef<Type> result_types,
|
|
||||||
ArrayRef<Value> args,
|
|
||||||
OpBuilder* b) {
|
|
||||||
return MapXlaCompareSelectOpToStdScalarOp<
|
|
||||||
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
|
|
||||||
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>{}(xla_op.getLoc(), "LT",
|
|
||||||
result_types, args, b);
|
|
||||||
}
|
|
||||||
template <>
|
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::MinOp>(xla_hlo::MinOp xla_op,
|
|
||||||
ArrayRef<Type> result_types,
|
|
||||||
ArrayRef<Value> args,
|
|
||||||
OpBuilder* b) {
|
|
||||||
return MapXlaCompareSelectOpToStdScalarOp<
|
|
||||||
IntegerType, ScalarIOp<xla_hlo::CompareOp>, CmpIPredicate, FloatType,
|
|
||||||
ScalarFOp<xla_hlo::CompareOp>, CmpFPredicate>{}(xla_op.getLoc(), "LT",
|
|
||||||
result_types, args, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::NegOp>(xla_lhlo::NegOp xla_op,
|
|
||||||
ArrayRef<Type> result_types,
|
|
||||||
ArrayRef<Value> args,
|
|
||||||
OpBuilder* b) {
|
|
||||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
|
|
||||||
xla_op.getLoc(), result_types, args, b);
|
|
||||||
}
|
|
||||||
template <>
|
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::NegOp>(xla_hlo::NegOp xla_op,
|
|
||||||
ArrayRef<Type> result_types,
|
|
||||||
ArrayRef<Value> args,
|
|
||||||
OpBuilder* b) {
|
|
||||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
|
|
||||||
xla_op.getLoc(), result_types, args, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::SelectOp>(
|
|
||||||
xla_lhlo::SelectOp xla_op, ArrayRef<Type> result_types,
|
|
||||||
ArrayRef<Value> args, OpBuilder* b) {
|
|
||||||
return MapXlaOpToStdScalarOpImpl<::mlir::SelectOp>{}(xla_op.getLoc(),
|
|
||||||
result_types, args, b);
|
|
||||||
}
|
|
||||||
template <>
|
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::SelectOp>(
|
|
||||||
xla_hlo::SelectOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapXlaOpToStdScalarOpImpl<::mlir::SelectOp>{}(xla_op.getLoc(),
|
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
|
||||||
result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::SignOp>(
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MaxOp>(
|
||||||
xla_lhlo::SignOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return XlaCompareSelectOpToStdScalarOp<
|
||||||
|
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
|
||||||
|
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "GT",
|
||||||
|
result_types, args,
|
||||||
|
b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MinOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return XlaCompareSelectOpToStdScalarOp<
|
||||||
|
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
|
||||||
|
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "LT",
|
||||||
|
result_types, args,
|
||||||
|
b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::NegOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
|
||||||
|
loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SelectOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args,
|
||||||
|
b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SignOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
Type element_type = args.front().getType();
|
Type element_type = args.front().getType();
|
||||||
if (element_type.isa<FloatType>()) {
|
if (element_type.isa<FloatType>()) {
|
||||||
FloatType float_type = element_type.cast<FloatType>();
|
FloatType float_type = element_type.cast<FloatType>();
|
||||||
APFloat const_value = float_type.isF32() ? APFloat(1.0f) : APFloat(1.0);
|
APFloat const_value = float_type.isF32() ? APFloat(1.0f) : APFloat(1.0);
|
||||||
Value one = b->create<mlir::ConstantFloatOp>(xla_op.getLoc(), const_value,
|
Value one = b->create<mlir::ConstantFloatOp>(loc, const_value, float_type);
|
||||||
float_type);
|
return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
|
||||||
return b->create<::mlir::CopySignOp>(xla_op.getLoc(), result_types, one,
|
|
||||||
args[0]);
|
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::TanhOp>(
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::TanhOp>(
|
||||||
xla_lhlo::TanhOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}(
|
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}(
|
||||||
xla_op.getLoc(), result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
template <>
|
|
||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::TanhOp>(xla_hlo::TanhOp xla_op,
|
} // namespace impl
|
||||||
ArrayRef<Type> result_types,
|
|
||||||
ArrayRef<Value> args,
|
struct XlaOpToStdScalarOp {
|
||||||
OpBuilder* b) {
|
// Implementation for LHLO ops except xla_lhlo::CompareOp.
|
||||||
return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}(
|
template <typename XlaOpTy, typename LhloOpTy = XlaOpTy,
|
||||||
xla_op.getLoc(), result_types, args, b);
|
typename = std::enable_if_t<
|
||||||
|
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
|
||||||
|
std::is_same<typename xla_hlo::HloToLhloOp<LhloOpTy>,
|
||||||
|
std::false_type>::value>>
|
||||||
|
static Value map(XlaOpTy op, ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
|
||||||
|
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
|
||||||
|
args, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implementation for HLO ops except xla_hlo::CompareOp.
|
||||||
|
template <typename XlaOpTy, typename LhloOpTy = xla_hlo::HloToLhloOp<XlaOpTy>,
|
||||||
|
typename = std::enable_if_t<
|
||||||
|
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
|
||||||
|
!std::is_same<LhloOpTy, std::false_type>::value>>
|
||||||
|
static Value map(XlaOpTy op, ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args, OpBuilder* b, int i = 0) {
|
||||||
|
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
|
||||||
|
args, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implementation for xla_lhlo::CompareOp.
|
||||||
|
template <typename LhloOpTy, typename = std::enable_if_t<std::is_same<
|
||||||
|
LhloOpTy, xla_lhlo::CompareOp>::value>>
|
||||||
|
static Value map(xla_lhlo::CompareOp op, ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
|
auto comparison_direction = op.comparison_direction();
|
||||||
|
return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
|
||||||
|
op.getLoc(), comparison_direction, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implementation for xla_hlo::CompareOp.
|
||||||
|
template <typename HloOpTy, typename = std::enable_if_t<std::is_same<
|
||||||
|
HloOpTy, xla_hlo::CompareOp>::value>>
|
||||||
|
static Value map(xla_hlo::CompareOp op, ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
|
auto comparison_direction = op.comparison_direction();
|
||||||
|
return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
|
||||||
|
op.getLoc(), comparison_direction, result_types, args, b);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename XlaOpTy>
|
||||||
|
inline Value MapXlaOpToStdScalarOp(XlaOpTy xla_op, ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
|
return XlaOpToStdScalarOp::map<XlaOpTy>(xla_op, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace xla_lhlo
|
} // namespace xla_lhlo
|
||||||
|
@ -149,8 +149,8 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
|
|||||||
rewriter.setInsertionPointToEnd(block);
|
rewriter.setInsertionPointToEnd(block);
|
||||||
// TODO(ravishankarm) : For now use the method in xla_lhlo namespace. That
|
// TODO(ravishankarm) : For now use the method in xla_lhlo namespace. That
|
||||||
// method needs to be moved out of there.
|
// method needs to be moved out of there.
|
||||||
Value opResult = xla_lhlo::MapXlaOpToStdScalarOp<OpTy>(
|
Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<OpTy>(
|
||||||
llvm::cast<OpTy>(op), bodyResultTypes, bodyArgs, &rewriter);
|
op, bodyResultTypes, bodyArgs, &rewriter);
|
||||||
if (!opResult) {
|
if (!opResult) {
|
||||||
return ConversionPattern::matchFailure();
|
return ConversionPattern::matchFailure();
|
||||||
}
|
}
|
||||||
@ -180,9 +180,9 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
|
|||||||
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
|
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
|
||||||
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
|
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
|
||||||
// TODO(ravishankarm) : Move this method out of xla_lhlo namespace.
|
// TODO(ravishankarm) : Move this method out of xla_lhlo namespace.
|
||||||
Value opResult = xla_lhlo::MapXlaOpToStdScalarOp<LhloOp>(
|
Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<LhloOp>(
|
||||||
llvm::cast<LhloOp>(lhlo_op), argType.getElementType(),
|
lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
|
||||||
llvm::ArrayRef<Value>{lhs, rhs}, &rewriter);
|
&rewriter);
|
||||||
rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
|
rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
|
||||||
rewriter.eraseOp(lhlo_op);
|
rewriter.eraseOp(lhlo_op);
|
||||||
return ConversionPattern::matchSuccess();
|
return ConversionPattern::matchSuccess();
|
||||||
|
Loading…
x
Reference in New Issue
Block a user