[XLA][MLIR] Reduce code bloat for LHLO->STD and HLO->STD patterns.
PiperOrigin-RevId: 298840878 Change-Id: I781008f01b5c8e478d75ba282db9aa78da546ea1
This commit is contained in:
		
							parent
							
								
									8a53e358fc
								
							
						
					
					
						commit
						4aaabc836e
					
				@ -133,16 +133,25 @@ cc_library(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
cc_library(
 | 
					cc_library(
 | 
				
			||||||
    name = "map_xla_to_scalar_op",
 | 
					    name = "map_xla_to_scalar_op",
 | 
				
			||||||
    srcs = [],
 | 
					 | 
				
			||||||
    hdrs = ["transforms/map_xla_to_scalar_op.h"],
 | 
					    hdrs = ["transforms/map_xla_to_scalar_op.h"],
 | 
				
			||||||
    deps = [
 | 
					    deps = [
 | 
				
			||||||
        ":hlo",
 | 
					        ":hlo",
 | 
				
			||||||
        ":lhlo",
 | 
					        ":lhlo",
 | 
				
			||||||
 | 
					        ":map_hlo_to_lhlo_op",
 | 
				
			||||||
        "@llvm-project//llvm:support",
 | 
					        "@llvm-project//llvm:support",
 | 
				
			||||||
        "@llvm-project//mlir:StandardOps",
 | 
					        "@llvm-project//mlir:StandardOps",
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cc_library(
 | 
				
			||||||
 | 
					    name = "map_hlo_to_lhlo_op",
 | 
				
			||||||
 | 
					    hdrs = ["transforms/map_hlo_to_lhlo_op.h"],
 | 
				
			||||||
 | 
					    deps = [
 | 
				
			||||||
 | 
					        ":hlo",
 | 
				
			||||||
 | 
					        ":lhlo",
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cc_library(
 | 
					cc_library(
 | 
				
			||||||
    name = "hlo_shape_derivation",
 | 
					    name = "hlo_shape_derivation",
 | 
				
			||||||
    srcs = [],
 | 
					    srcs = [],
 | 
				
			||||||
@ -234,6 +243,7 @@ cc_library(
 | 
				
			|||||||
        ":hlo",
 | 
					        ":hlo",
 | 
				
			||||||
        ":hlo_shape_derivation",
 | 
					        ":hlo_shape_derivation",
 | 
				
			||||||
        ":lhlo",
 | 
					        ":lhlo",
 | 
				
			||||||
 | 
					        ":map_hlo_to_lhlo_op",
 | 
				
			||||||
        "@com_google_absl//absl/memory",
 | 
					        "@com_google_absl//absl/memory",
 | 
				
			||||||
        "@llvm-project//mlir:IR",
 | 
					        "@llvm-project//mlir:IR",
 | 
				
			||||||
        "@llvm-project//mlir:Pass",
 | 
					        "@llvm-project//mlir:Pass",
 | 
				
			||||||
 | 
				
			|||||||
@ -31,6 +31,7 @@ limitations under the License.
 | 
				
			|||||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
 | 
					#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
 | 
				
			||||||
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
 | 
					#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
 | 
				
			||||||
#include "tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h"
 | 
					#include "tensorflow/compiler/mlir/xla/transforms/hlo_shape_derivation.h"
 | 
				
			||||||
 | 
					#include "tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h"
 | 
				
			||||||
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
 | 
					#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
 | 
				
			||||||
#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
 | 
					#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -117,7 +118,7 @@ Value InsertAllocAndDealloc(Location loc, Value result,
 | 
				
			|||||||
  return alloc;
 | 
					  return alloc;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename HloOpTy, typename LhloOpTy>
 | 
					template <typename HloOpTy>
 | 
				
			||||||
class HloToLhloOpConverter : public ConversionPattern {
 | 
					class HloToLhloOpConverter : public ConversionPattern {
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  explicit HloToLhloOpConverter(MLIRContext* context)
 | 
					  explicit HloToLhloOpConverter(MLIRContext* context)
 | 
				
			||||||
@ -147,14 +148,14 @@ class HloToLhloOpConverter : public ConversionPattern {
 | 
				
			|||||||
            op->getLoc(), result.value(), shape_value, &rewriter));
 | 
					            op->getLoc(), result.value(), shape_value, &rewriter));
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    rewriter.create<LhloOpTy>(op->getLoc(), llvm::None, buffer_args,
 | 
					    rewriter.create<xla_hlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
 | 
				
			||||||
                              op->getAttrs());
 | 
					                                                   buffer_args, op->getAttrs());
 | 
				
			||||||
    rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
 | 
					    rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
 | 
				
			||||||
    return matchSuccess();
 | 
					    return matchSuccess();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct HloToLHloDynamicBroadcastInDimOpConverter
 | 
					struct HloToLhloDynamicBroadcastInDimOpConverter
 | 
				
			||||||
    : public OpConversionPattern<xla_hlo::DynamicBroadcastInDimOp> {
 | 
					    : public OpConversionPattern<xla_hlo::DynamicBroadcastInDimOp> {
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  using OpConversionPattern::OpConversionPattern;
 | 
					  using OpConversionPattern::OpConversionPattern;
 | 
				
			||||||
@ -178,7 +179,7 @@ struct HloToLHloDynamicBroadcastInDimOpConverter
 | 
				
			|||||||
  }
 | 
					  }
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct HloToLHloReduceOpConverter
 | 
					struct HloToLhloReduceOpConverter
 | 
				
			||||||
    : public OpConversionPattern<xla_hlo::ReduceOp> {
 | 
					    : public OpConversionPattern<xla_hlo::ReduceOp> {
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  using OpConversionPattern::OpConversionPattern;
 | 
					  using OpConversionPattern::OpConversionPattern;
 | 
				
			||||||
@ -438,36 +439,35 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
 | 
				
			|||||||
                                        OwningRewritePatternList* patterns) {
 | 
					                                        OwningRewritePatternList* patterns) {
 | 
				
			||||||
  // clang-format off
 | 
					  // clang-format off
 | 
				
			||||||
  patterns->insert<
 | 
					  patterns->insert<
 | 
				
			||||||
      HloToLHloDynamicBroadcastInDimOpConverter,
 | 
					      HloToLhloDynamicBroadcastInDimOpConverter,
 | 
				
			||||||
      HloToLhloFuncOpConverter,
 | 
					      HloToLhloFuncOpConverter,
 | 
				
			||||||
      HloToLhloOpConverter<xla_hlo::AbsOp, xla_lhlo::AbsOp>,
 | 
					      HloToLhloOpConverter<xla_hlo::AbsOp>,
 | 
				
			||||||
      HloToLhloOpConverter<xla_hlo::AddOp, xla_lhlo::AddOp>,
 | 
					      HloToLhloOpConverter<xla_hlo::AddOp>,
 | 
				
			||||||
      HloToLhloOpConverter<xla_hlo::AndOp, xla_lhlo::AndOp>,
 | 
					      HloToLhloOpConverter<xla_hlo::AndOp>,
 | 
				
			||||||
      HloToLhloOpConverter<xla_hlo::BroadcastInDimOp,
 | 
					      HloToLhloOpConverter<xla_hlo::BroadcastInDimOp>,
 | 
				
			||||||
                           xla_lhlo::BroadcastInDimOp>,
 | 
					      HloToLhloOpConverter<xla_hlo::CeilOp>,
 | 
				
			||||||
      HloToLhloOpConverter<xla_hlo::CeilOp, xla_lhlo::CeilOp>,
 | 
					      HloToLhloOpConverter<xla_hlo::CompareOp>,
 | 
				
			||||||
      HloToLhloOpConverter<xla_hlo::CompareOp, xla_lhlo::CompareOp>,
 | 
					      HloToLhloOpConverter<xla_hlo::ConstOp>,
 | 
				
			||||||
      HloToLhloOpConverter<xla_hlo::ConstOp, xla_lhlo::ConstOp>,
 | 
					      HloToLhloOpConverter<xla_hlo::ConvertOp>,
 | 
				
			||||||
      HloToLhloOpConverter<xla_hlo::ConvertOp, xla_lhlo::ConvertOp>,
 | 
					      HloToLhloOpConverter<xla_hlo::CopyOp>,
 | 
				
			||||||
      HloToLhloOpConverter<xla_hlo::CopyOp, xla_lhlo::CopyOp>,
 | 
					      HloToLhloOpConverter<xla_hlo::CosOp>,
 | 
				
			||||||
      HloToLhloOpConverter<xla_hlo::CosOp, xla_lhlo::CosOp>,
 | 
					      HloToLhloOpConverter<xla_hlo::DivOp>,
 | 
				
			||||||
      HloToLhloOpConverter<xla_hlo::DivOp, xla_lhlo::DivOp>,
 | 
					      HloToLhloOpConverter<xla_hlo::ExpOp>,
 | 
				
			||||||
      HloToLhloOpConverter<xla_hlo::ExpOp, xla_lhlo::ExpOp>,
 | 
					      HloToLhloOpConverter<xla_hlo::IotaOp>,
 | 
				
			||||||
      HloToLhloOpConverter<xla_hlo::IotaOp, xla_lhlo::IotaOp>,
 | 
					      HloToLhloOpConverter<xla_hlo::LogOp>,
 | 
				
			||||||
      HloToLhloOpConverter<xla_hlo::LogOp, xla_lhlo::LogOp>,
 | 
					      HloToLhloOpConverter<xla_hlo::MaxOp>,
 | 
				
			||||||
      HloToLhloOpConverter<xla_hlo::MaxOp, xla_lhlo::MaxOp>,
 | 
					      HloToLhloOpConverter<xla_hlo::MinOp>,
 | 
				
			||||||
      HloToLhloOpConverter<xla_hlo::MinOp, xla_lhlo::MinOp>,
 | 
					      HloToLhloOpConverter<xla_hlo::MulOp>,
 | 
				
			||||||
      HloToLhloOpConverter<xla_hlo::MulOp, xla_lhlo::MulOp>,
 | 
					      HloToLhloOpConverter<xla_hlo::NegOp>,
 | 
				
			||||||
      HloToLhloOpConverter<xla_hlo::NegOp, xla_lhlo::NegOp>,
 | 
					      HloToLhloOpConverter<xla_hlo::RemOp>,
 | 
				
			||||||
      HloToLhloOpConverter<xla_hlo::RemOp, xla_lhlo::RemOp>,
 | 
					      HloToLhloOpConverter<xla_hlo::SelectOp>,
 | 
				
			||||||
      HloToLhloOpConverter<xla_hlo::SelectOp, xla_lhlo::SelectOp>,
 | 
					      HloToLhloOpConverter<xla_hlo::SignOp>,
 | 
				
			||||||
      HloToLhloOpConverter<xla_hlo::SignOp, xla_lhlo::SignOp>,
 | 
					      HloToLhloOpConverter<xla_hlo::SubOp>,
 | 
				
			||||||
      HloToLhloOpConverter<xla_hlo::SubOp, xla_lhlo::SubOp>,
 | 
					      HloToLhloOpConverter<xla_hlo::TanhOp>,
 | 
				
			||||||
      HloToLhloOpConverter<xla_hlo::TanhOp, xla_lhlo::TanhOp>,
 | 
					      HloToLhloReduceOpConverter,
 | 
				
			||||||
      HloToLHloReduceOpConverter,
 | 
					 | 
				
			||||||
      StdToLhloReturnOpConverter,
 | 
					 | 
				
			||||||
      HloToLhloTensorLoadOpConverter,
 | 
					      HloToLhloTensorLoadOpConverter,
 | 
				
			||||||
      HloToLhloTensorStoreOpConverter
 | 
					      HloToLhloTensorStoreOpConverter,
 | 
				
			||||||
 | 
					      StdToLhloReturnOpConverter
 | 
				
			||||||
  >(context);
 | 
					  >(context);
 | 
				
			||||||
  // clang-format on
 | 
					  // clang-format on
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -31,11 +31,11 @@ namespace mlir {
 | 
				
			|||||||
namespace xla_lhlo {
 | 
					namespace xla_lhlo {
 | 
				
			||||||
namespace {
 | 
					namespace {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename LhloOp>
 | 
					template <typename LhloOpTy>
 | 
				
			||||||
struct BinaryOpConverter : public OpRewritePattern<LhloOp> {
 | 
					struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
 | 
				
			||||||
  using OpRewritePattern<LhloOp>::OpRewritePattern;
 | 
					  using OpRewritePattern<LhloOpTy>::OpRewritePattern;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  PatternMatchResult matchAndRewrite(LhloOp op,
 | 
					  PatternMatchResult matchAndRewrite(LhloOpTy op,
 | 
				
			||||||
                                     PatternRewriter& rewriter) const override {
 | 
					                                     PatternRewriter& rewriter) const override {
 | 
				
			||||||
    const auto& lhs = op.lhs();
 | 
					    const auto& lhs = op.lhs();
 | 
				
			||||||
    const auto& rhs = op.rhs();
 | 
					    const auto& rhs = op.rhs();
 | 
				
			||||||
@ -56,8 +56,8 @@ struct BinaryOpConverter : public OpRewritePattern<LhloOp> {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
    auto l = rewriter.create<LoadOp>(loc, lhs, induction_vars);
 | 
					    auto l = rewriter.create<LoadOp>(loc, lhs, induction_vars);
 | 
				
			||||||
    auto r = rewriter.create<LoadOp>(loc, rhs, induction_vars);
 | 
					    auto r = rewriter.create<LoadOp>(loc, rhs, induction_vars);
 | 
				
			||||||
    Value opResult = MapXlaOpToStdScalarOp<LhloOp>(
 | 
					    Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<LhloOpTy>(
 | 
				
			||||||
        llvm::cast<LhloOp>(op), element_type, {l, r}, &rewriter);
 | 
					        op, element_type, {l, r}, &rewriter);
 | 
				
			||||||
    if (opResult == nullptr) {
 | 
					    if (opResult == nullptr) {
 | 
				
			||||||
      return this->matchFailure();
 | 
					      return this->matchFailure();
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										70
									
								
								tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										70
									
								
								tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,70 @@
 | 
				
			|||||||
 | 
					/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					You may obtain a copy of the License at
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					limitations under the License.
 | 
				
			||||||
 | 
					==============================================================================*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_
 | 
				
			||||||
 | 
					#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <type_traits>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
 | 
				
			||||||
 | 
					#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace mlir {
 | 
				
			||||||
 | 
					namespace xla_hlo {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <typename HloOpTy>
 | 
				
			||||||
 | 
					struct HloToLhloOpImpl {
 | 
				
			||||||
 | 
					  using Type = std::false_type;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					template <typename HloOpTy>
 | 
				
			||||||
 | 
					using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#define MAP_HLO_TO_LHLO(OpName)             \
 | 
				
			||||||
 | 
					  template <>                               \
 | 
				
			||||||
 | 
					  struct HloToLhloOpImpl<xla_hlo::OpName> { \
 | 
				
			||||||
 | 
					    using Type = xla_lhlo::OpName;          \
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(AbsOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(AddOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(AndOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(BroadcastInDimOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(CeilOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(ConstOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(CompareOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(ConvertOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(CopyOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(CosOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(DivOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(ExpOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(IotaOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(LogOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(MaxOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(MinOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(MulOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(NegOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(ReduceOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(RemOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(SelectOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(SignOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(SubOp);
 | 
				
			||||||
 | 
					MAP_HLO_TO_LHLO(TanhOp);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#undef MAP_HLO_TO_LHLO
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}  // namespace xla_hlo
 | 
				
			||||||
 | 
					}  // namespace mlir
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#endif  // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_
 | 
				
			||||||
@ -21,81 +21,63 @@ limitations under the License.
 | 
				
			|||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 | 
					#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
 | 
				
			||||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
 | 
					#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
 | 
				
			||||||
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
 | 
					#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
 | 
				
			||||||
 | 
					#include "tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
namespace mlir {
 | 
					namespace mlir {
 | 
				
			||||||
namespace xla_lhlo {
 | 
					namespace xla_lhlo {
 | 
				
			||||||
 | 
					namespace impl {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename LHLO_BinaryOp>
 | 
					// A struct to map LhloBinaryOpTy type to the corresponding floating-point and
 | 
				
			||||||
struct ScalarOp;
 | 
					// integer scalar operation types.
 | 
				
			||||||
 | 
					template <typename LhloBinaryOpTy>
 | 
				
			||||||
 | 
					struct LhloToScalarOp;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <>
 | 
					template <>
 | 
				
			||||||
struct ScalarOp<xla_lhlo::AddOp> {
 | 
					struct LhloToScalarOp<xla_lhlo::AddOp> {
 | 
				
			||||||
  using FOp = ::mlir::AddFOp;
 | 
					  using FOp = ::mlir::AddFOp;
 | 
				
			||||||
  using IOp = ::mlir::AddIOp;
 | 
					  using IOp = ::mlir::AddIOp;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
template <>
 | 
					template <>
 | 
				
			||||||
struct ScalarOp<xla_hlo::AddOp> {
 | 
					struct LhloToScalarOp<xla_lhlo::CompareOp> {
 | 
				
			||||||
  using FOp = ::mlir::AddFOp;
 | 
					 | 
				
			||||||
  using IOp = ::mlir::AddIOp;
 | 
					 | 
				
			||||||
};
 | 
					 | 
				
			||||||
template <>
 | 
					 | 
				
			||||||
struct ScalarOp<xla_lhlo::CompareOp> {
 | 
					 | 
				
			||||||
  using FOp = ::mlir::CmpFOp;
 | 
					  using FOp = ::mlir::CmpFOp;
 | 
				
			||||||
  using IOp = ::mlir::CmpIOp;
 | 
					  using IOp = ::mlir::CmpIOp;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
template <>
 | 
					template <>
 | 
				
			||||||
struct ScalarOp<xla_hlo::CompareOp> {
 | 
					struct LhloToScalarOp<xla_lhlo::DivOp> {
 | 
				
			||||||
  using FOp = ::mlir::CmpFOp;
 | 
					 | 
				
			||||||
  using IOp = ::mlir::CmpIOp;
 | 
					 | 
				
			||||||
};
 | 
					 | 
				
			||||||
template <>
 | 
					 | 
				
			||||||
struct ScalarOp<xla_lhlo::DivOp> {
 | 
					 | 
				
			||||||
  using FOp = ::mlir::DivFOp;
 | 
					  using FOp = ::mlir::DivFOp;
 | 
				
			||||||
  using IOp = ::mlir::SignedDivIOp;
 | 
					  using IOp = ::mlir::SignedDivIOp;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
template <>
 | 
					template <>
 | 
				
			||||||
struct ScalarOp<xla_hlo::DivOp> {
 | 
					struct LhloToScalarOp<xla_lhlo::MulOp> {
 | 
				
			||||||
  using FOp = ::mlir::DivFOp;
 | 
					 | 
				
			||||||
  using IOp = ::mlir::SignedDivIOp;
 | 
					 | 
				
			||||||
};
 | 
					 | 
				
			||||||
template <>
 | 
					 | 
				
			||||||
struct ScalarOp<xla_lhlo::MulOp> {
 | 
					 | 
				
			||||||
  using FOp = ::mlir::MulFOp;
 | 
					  using FOp = ::mlir::MulFOp;
 | 
				
			||||||
  using IOp = ::mlir::MulIOp;
 | 
					  using IOp = ::mlir::MulIOp;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
template <>
 | 
					template <>
 | 
				
			||||||
struct ScalarOp<xla_hlo::MulOp> {
 | 
					struct LhloToScalarOp<xla_lhlo::RemOp> {
 | 
				
			||||||
  using FOp = ::mlir::MulFOp;
 | 
					 | 
				
			||||||
  using IOp = ::mlir::MulIOp;
 | 
					 | 
				
			||||||
};
 | 
					 | 
				
			||||||
template <>
 | 
					 | 
				
			||||||
struct ScalarOp<xla_lhlo::RemOp> {
 | 
					 | 
				
			||||||
  using FOp = ::mlir::RemFOp;
 | 
					  using FOp = ::mlir::RemFOp;
 | 
				
			||||||
  using IOp = ::mlir::SignedRemIOp;
 | 
					  using IOp = ::mlir::SignedRemIOp;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
template <>
 | 
					template <>
 | 
				
			||||||
struct ScalarOp<xla_hlo::RemOp> {
 | 
					struct LhloToScalarOp<xla_lhlo::SubOp> {
 | 
				
			||||||
  using FOp = ::mlir::RemFOp;
 | 
					 | 
				
			||||||
  using IOp = ::mlir::SignedRemIOp;
 | 
					 | 
				
			||||||
};
 | 
					 | 
				
			||||||
template <>
 | 
					 | 
				
			||||||
struct ScalarOp<xla_lhlo::SubOp> {
 | 
					 | 
				
			||||||
  using FOp = ::mlir::SubFOp;
 | 
					 | 
				
			||||||
  using IOp = ::mlir::SubIOp;
 | 
					 | 
				
			||||||
};
 | 
					 | 
				
			||||||
template <>
 | 
					 | 
				
			||||||
struct ScalarOp<xla_hlo::SubOp> {
 | 
					 | 
				
			||||||
  using FOp = ::mlir::SubFOp;
 | 
					  using FOp = ::mlir::SubFOp;
 | 
				
			||||||
  using IOp = ::mlir::SubIOp;
 | 
					  using IOp = ::mlir::SubIOp;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename XLA_BinaryOp>
 | 
					template <typename LhloBinaryOpTy>
 | 
				
			||||||
using ScalarFOp = typename ScalarOp<XLA_BinaryOp>::FOp;
 | 
					struct ScalarOp {
 | 
				
			||||||
template <typename XLA_BinaryOp>
 | 
					  using FOp = typename LhloToScalarOp<LhloBinaryOpTy>::FOp;
 | 
				
			||||||
using ScalarIOp = typename ScalarOp<XLA_BinaryOp>::IOp;
 | 
					  using IOp = typename LhloToScalarOp<LhloBinaryOpTy>::IOp;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Alias for the map from LHLO binary op type to STD floating-point op type.
 | 
				
			||||||
 | 
					template <typename LhloOp>
 | 
				
			||||||
 | 
					using ScalarFOp = typename ScalarOp<LhloOp>::FOp;
 | 
				
			||||||
 | 
					// Alias for the map from LHLO binary op type to STD integer op type.
 | 
				
			||||||
 | 
					template <typename LhloOp>
 | 
				
			||||||
 | 
					using ScalarIOp = typename ScalarOp<LhloOp>::IOp;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename... Args>
 | 
					template <typename... Args>
 | 
				
			||||||
struct MapXlaOpToStdScalarOpImpl {
 | 
					struct MapLhloOpToStdScalarOpImpl {
 | 
				
			||||||
  Value operator()(Location loc, ArrayRef<Type> result_types,
 | 
					  Value operator()(Location loc, ArrayRef<Type> result_types,
 | 
				
			||||||
                   ArrayRef<Value> args, OpBuilder* b) {
 | 
					                   ArrayRef<Value> args, OpBuilder* b) {
 | 
				
			||||||
    return nullptr;
 | 
					    return nullptr;
 | 
				
			||||||
@ -103,7 +85,7 @@ struct MapXlaOpToStdScalarOpImpl {
 | 
				
			|||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename StdScalarOp>
 | 
					template <typename StdScalarOp>
 | 
				
			||||||
struct MapXlaOpToStdScalarOpImpl<StdScalarOp> {
 | 
					struct MapLhloOpToStdScalarOpImpl<StdScalarOp> {
 | 
				
			||||||
  Value operator()(Location loc, ArrayRef<Type> result_types,
 | 
					  Value operator()(Location loc, ArrayRef<Type> result_types,
 | 
				
			||||||
                   ArrayRef<Value> args, OpBuilder* b) {
 | 
					                   ArrayRef<Value> args, OpBuilder* b) {
 | 
				
			||||||
    return b->template create<StdScalarOp>(loc, result_types, args, mlir::None);
 | 
					    return b->template create<StdScalarOp>(loc, result_types, args, mlir::None);
 | 
				
			||||||
@ -111,7 +93,7 @@ struct MapXlaOpToStdScalarOpImpl<StdScalarOp> {
 | 
				
			|||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename SupportedType, typename StdScalarOp, typename... Args>
 | 
					template <typename SupportedType, typename StdScalarOp, typename... Args>
 | 
				
			||||||
struct MapXlaOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
 | 
					struct MapLhloOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
 | 
				
			||||||
  Value operator()(Location loc, ArrayRef<Type> result_types,
 | 
					  Value operator()(Location loc, ArrayRef<Type> result_types,
 | 
				
			||||||
                   ArrayRef<Value> args, OpBuilder* b) {
 | 
					                   ArrayRef<Value> args, OpBuilder* b) {
 | 
				
			||||||
    Type element_type = args.front().getType();
 | 
					    Type element_type = args.front().getType();
 | 
				
			||||||
@ -119,52 +101,34 @@ struct MapXlaOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
 | 
				
			|||||||
      return b->template create<StdScalarOp>(loc, result_types, args,
 | 
					      return b->template create<StdScalarOp>(loc, result_types, args,
 | 
				
			||||||
                                             mlir::None);
 | 
					                                             mlir::None);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    return MapXlaOpToStdScalarOpImpl<Args...>{}(loc, result_types, args, b);
 | 
					    return MapLhloOpToStdScalarOpImpl<Args...>{}(loc, result_types, args, b);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename XlaOp>
 | 
					// Inserts the computation that corresponds to the body of the loop for lowered
 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp(XlaOp xla_op, ArrayRef<Type> result_types,
 | 
					// LHLO unary/binary op. Returns the value for the result.
 | 
				
			||||||
                                   ArrayRef<Value> args, OpBuilder* b) {
 | 
					template <typename LhloOpTy>
 | 
				
			||||||
  return MapXlaOpToStdScalarOpImpl<IntegerType, ScalarIOp<XlaOp>, FloatType,
 | 
					inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef<Type> result_types,
 | 
				
			||||||
                                   ScalarFOp<XlaOp>>{}(xla_op.getLoc(),
 | 
					                                    ArrayRef<Value> args, OpBuilder* b) {
 | 
				
			||||||
                                                       result_types, args, b);
 | 
					  return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<LhloOpTy>, FloatType,
 | 
				
			||||||
}
 | 
					                                    ScalarFOp<LhloOpTy>>{}(loc, result_types,
 | 
				
			||||||
 | 
					                                                           args, b);
 | 
				
			||||||
// TODO(ravishankarm): Find a way to reduce code-bloat in HLO and LHLO
 | 
					 | 
				
			||||||
// specialization.
 | 
					 | 
				
			||||||
template <>
 | 
					 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::AbsOp>(xla_lhlo::AbsOp xla_op,
 | 
					 | 
				
			||||||
                                                    ArrayRef<Type> result_types,
 | 
					 | 
				
			||||||
                                                    ArrayRef<Value> args,
 | 
					 | 
				
			||||||
                                                    OpBuilder* b) {
 | 
					 | 
				
			||||||
  return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
 | 
					 | 
				
			||||||
      xla_op.getLoc(), result_types, args, b);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
template <>
 | 
					 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::AbsOp>(xla_hlo::AbsOp xla_op,
 | 
					 | 
				
			||||||
                                                   ArrayRef<Type> result_types,
 | 
					 | 
				
			||||||
                                                   ArrayRef<Value> args,
 | 
					 | 
				
			||||||
                                                   OpBuilder* b) {
 | 
					 | 
				
			||||||
  return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
 | 
					 | 
				
			||||||
      xla_op.getLoc(), result_types, args, b);
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <>
 | 
					template <>
 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::AndOp>(xla_lhlo::AndOp xla_op,
 | 
					inline Value MapLhloOpToStdScalarOp<xla_lhlo::AbsOp>(
 | 
				
			||||||
                                                    ArrayRef<Type> result_types,
 | 
					    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
				
			||||||
                                                    ArrayRef<Value> args,
 | 
					    OpBuilder* b) {
 | 
				
			||||||
                                                    OpBuilder* b) {
 | 
					  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
 | 
				
			||||||
  return MapXlaOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
 | 
					      loc, result_types, args, b);
 | 
				
			||||||
      xla_op.getLoc(), result_types, args, b);
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <>
 | 
					template <>
 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::AndOp>(xla_hlo::AndOp xla_op,
 | 
					inline Value MapLhloOpToStdScalarOp<xla_lhlo::AndOp>(
 | 
				
			||||||
                                                   ArrayRef<Type> result_types,
 | 
					    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
				
			||||||
                                                   ArrayRef<Value> args,
 | 
					    OpBuilder* b) {
 | 
				
			||||||
                                                   OpBuilder* b) {
 | 
					  return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
 | 
				
			||||||
  return MapXlaOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
 | 
					      loc, result_types, args, b);
 | 
				
			||||||
      xla_op.getLoc(), result_types, args, b);
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename PredicateType>
 | 
					template <typename PredicateType>
 | 
				
			||||||
@ -200,7 +164,8 @@ inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>(
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename XLACompareOpTy>
 | 
					template <typename XLACompareOpTy>
 | 
				
			||||||
inline Value MapXlaCompareOpToStdScalarOp(XLACompareOpTy xla_op,
 | 
					inline Value MapXlaCompareOpToStdScalarOp(Location loc,
 | 
				
			||||||
 | 
					                                          StringRef comparison_direction,
 | 
				
			||||||
                                          ArrayRef<Type> result_types,
 | 
					                                          ArrayRef<Type> result_types,
 | 
				
			||||||
                                          ArrayRef<Value> args, OpBuilder* b) {
 | 
					                                          ArrayRef<Value> args, OpBuilder* b) {
 | 
				
			||||||
  const auto& lhs = args[0];
 | 
					  const auto& lhs = args[0];
 | 
				
			||||||
@ -208,101 +173,60 @@ inline Value MapXlaCompareOpToStdScalarOp(XLACompareOpTy xla_op,
 | 
				
			|||||||
  Type element_type = lhs.getType();
 | 
					  Type element_type = lhs.getType();
 | 
				
			||||||
  if (element_type.isSignlessInteger()) {
 | 
					  if (element_type.isSignlessInteger()) {
 | 
				
			||||||
    Optional<CmpIPredicate> predicate =
 | 
					    Optional<CmpIPredicate> predicate =
 | 
				
			||||||
        getCmpPredicate<CmpIPredicate>(xla_op.comparison_direction());
 | 
					        getCmpPredicate<CmpIPredicate>(comparison_direction);
 | 
				
			||||||
    assert(predicate.hasValue() && "expected valid comparison direction");
 | 
					    assert(predicate.hasValue() && "expected valid comparison direction");
 | 
				
			||||||
    return b->create<ScalarIOp<XLACompareOpTy>>(xla_op.getLoc(),
 | 
					    return b->create<ScalarIOp<XLACompareOpTy>>(loc, predicate.getValue(), lhs,
 | 
				
			||||||
                                                predicate.getValue(), lhs, rhs);
 | 
					                                                rhs);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  if (element_type.isa<FloatType>()) {
 | 
					  if (element_type.isa<FloatType>()) {
 | 
				
			||||||
    Optional<CmpFPredicate> predicate =
 | 
					    Optional<CmpFPredicate> predicate =
 | 
				
			||||||
        getCmpPredicate<CmpFPredicate>(xla_op.comparison_direction());
 | 
					        getCmpPredicate<CmpFPredicate>(comparison_direction);
 | 
				
			||||||
    assert(predicate.hasValue() && "expected valid comparison direction");
 | 
					    assert(predicate.hasValue() && "expected valid comparison direction");
 | 
				
			||||||
    return b->create<ScalarFOp<XLACompareOpTy>>(xla_op.getLoc(),
 | 
					    return b->create<ScalarFOp<XLACompareOpTy>>(loc, predicate.getValue(), lhs,
 | 
				
			||||||
                                                predicate.getValue(), lhs, rhs);
 | 
					                                                rhs);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  return nullptr;
 | 
					  return nullptr;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
template <>
 | 
					 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CompareOp>(
 | 
					 | 
				
			||||||
    xla_lhlo::CompareOp xla_op, ArrayRef<Type> result_types,
 | 
					 | 
				
			||||||
    ArrayRef<Value> args, OpBuilder* b) {
 | 
					 | 
				
			||||||
  return MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(xla_op, result_types,
 | 
					 | 
				
			||||||
                                                           args, b);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
template <>
 | 
					 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::CompareOp>(
 | 
					 | 
				
			||||||
    xla_hlo::CompareOp xla_op, ArrayRef<Type> result_types,
 | 
					 | 
				
			||||||
    ArrayRef<Value> args, OpBuilder* b) {
 | 
					 | 
				
			||||||
  return MapXlaCompareOpToStdScalarOp<xla_hlo::CompareOp>(xla_op, result_types,
 | 
					 | 
				
			||||||
                                                          args, b);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <>
 | 
					template <>
 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CopyOp>(
 | 
					inline Value MapLhloOpToStdScalarOp<xla_lhlo::CopyOp>(
 | 
				
			||||||
    xla_lhlo::CopyOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
					    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
				
			||||||
    OpBuilder* b) {
 | 
					    OpBuilder* b) {
 | 
				
			||||||
  return args.front();
 | 
					  return args.front();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
template <>
 | 
					 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::CopyOp>(xla_hlo::CopyOp xla_op,
 | 
					 | 
				
			||||||
                                                    ArrayRef<Type> result_types,
 | 
					 | 
				
			||||||
                                                    ArrayRef<Value> args,
 | 
					 | 
				
			||||||
                                                    OpBuilder* b) {
 | 
					 | 
				
			||||||
  return args.front();
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <>
 | 
					template <>
 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::ExpOp>(xla_lhlo::ExpOp xla_op,
 | 
					inline Value MapLhloOpToStdScalarOp<xla_lhlo::ExpOp>(
 | 
				
			||||||
                                                    ArrayRef<Type> result_types,
 | 
					    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
				
			||||||
                                                    ArrayRef<Value> args,
 | 
					 | 
				
			||||||
                                                    OpBuilder* b) {
 | 
					 | 
				
			||||||
  return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
 | 
					 | 
				
			||||||
      xla_op.getLoc(), result_types, args, b);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
template <>
 | 
					 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::ExpOp>(xla_hlo::ExpOp xla_op,
 | 
					 | 
				
			||||||
                                                   ArrayRef<Type> result_types,
 | 
					 | 
				
			||||||
                                                   ArrayRef<Value> args,
 | 
					 | 
				
			||||||
                                                   OpBuilder* b) {
 | 
					 | 
				
			||||||
  return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
 | 
					 | 
				
			||||||
      xla_op.getLoc(), result_types, args, b);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
template <>
 | 
					 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CeilOp>(
 | 
					 | 
				
			||||||
    xla_lhlo::CeilOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
					 | 
				
			||||||
    OpBuilder* b) {
 | 
					    OpBuilder* b) {
 | 
				
			||||||
  return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
 | 
					  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
 | 
				
			||||||
      xla_op.getLoc(), result_types, args, b);
 | 
					      loc, result_types, args, b);
 | 
				
			||||||
}
 | 
					 | 
				
			||||||
template <>
 | 
					 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::CeilOp>(xla_hlo::CeilOp xla_op,
 | 
					 | 
				
			||||||
                                                    ArrayRef<Type> result_types,
 | 
					 | 
				
			||||||
                                                    ArrayRef<Value> args,
 | 
					 | 
				
			||||||
                                                    OpBuilder* b) {
 | 
					 | 
				
			||||||
  return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
 | 
					 | 
				
			||||||
      xla_op.getLoc(), result_types, args, b);
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <>
 | 
					template <>
 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::ConvertOp>(
 | 
					inline Value MapLhloOpToStdScalarOp<xla_lhlo::CeilOp>(
 | 
				
			||||||
    xla_lhlo::ConvertOp xla_op, ArrayRef<Type> result_types,
 | 
					    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
				
			||||||
    ArrayRef<Value> args, OpBuilder* b) {
 | 
					    OpBuilder* b) {
 | 
				
			||||||
 | 
					  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
 | 
				
			||||||
 | 
					      loc, result_types, args, b);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					inline Value MapLhloOpToStdScalarOp<xla_lhlo::ConvertOp>(
 | 
				
			||||||
 | 
					    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
				
			||||||
 | 
					    OpBuilder* b) {
 | 
				
			||||||
  Type sourceType = args.front().getType();
 | 
					  Type sourceType = args.front().getType();
 | 
				
			||||||
  Type targetType = result_types.front();
 | 
					  Type targetType = result_types.front();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) {
 | 
					  if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) {
 | 
				
			||||||
    return b->create<mlir::SIToFPOp>(xla_op.getLoc(), result_types, args,
 | 
					    return b->create<mlir::SIToFPOp>(loc, result_types, args, mlir::None);
 | 
				
			||||||
                                     mlir::None);
 | 
					 | 
				
			||||||
  } else if (sourceType.isa<FloatType>() && targetType.isa<FloatType>()) {
 | 
					  } else if (sourceType.isa<FloatType>() && targetType.isa<FloatType>()) {
 | 
				
			||||||
    FloatType src = sourceType.cast<FloatType>();
 | 
					    FloatType src = sourceType.cast<FloatType>();
 | 
				
			||||||
    FloatType res = targetType.cast<FloatType>();
 | 
					    FloatType res = targetType.cast<FloatType>();
 | 
				
			||||||
    if (src.getWidth() > res.getWidth()) {
 | 
					    if (src.getWidth() > res.getWidth()) {
 | 
				
			||||||
      return b->create<mlir::FPTruncOp>(xla_op.getLoc(), result_types, args,
 | 
					      return b->create<mlir::FPTruncOp>(loc, result_types, args, mlir::None);
 | 
				
			||||||
                                        mlir::None);
 | 
					 | 
				
			||||||
    } else if (src.getWidth() < res.getWidth()) {
 | 
					    } else if (src.getWidth() < res.getWidth()) {
 | 
				
			||||||
      return b->create<mlir::FPExtOp>(xla_op.getLoc(), result_types, args,
 | 
					      return b->create<mlir::FPExtOp>(loc, result_types, args, mlir::None);
 | 
				
			||||||
                                      mlir::None);
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    // No conversion is needed for the same width floats
 | 
					    // No conversion is needed for the same width floats
 | 
				
			||||||
    return args.front();
 | 
					    return args.front();
 | 
				
			||||||
@ -311,10 +235,9 @@ inline Value MapXlaOpToStdScalarOp<xla_lhlo::ConvertOp>(
 | 
				
			|||||||
    IntegerType src = sourceType.cast<IntegerType>();
 | 
					    IntegerType src = sourceType.cast<IntegerType>();
 | 
				
			||||||
    IntegerType res = targetType.cast<IntegerType>();
 | 
					    IntegerType res = targetType.cast<IntegerType>();
 | 
				
			||||||
    if (src.getWidth() > res.getWidth()) {
 | 
					    if (src.getWidth() > res.getWidth()) {
 | 
				
			||||||
      return b->create<mlir::TruncateIOp>(xla_op.getLoc(), result_types, args,
 | 
					      return b->create<mlir::TruncateIOp>(loc, result_types, args, mlir::None);
 | 
				
			||||||
                                          mlir::None);
 | 
					 | 
				
			||||||
    } else if (src.getWidth() < res.getWidth()) {
 | 
					    } else if (src.getWidth() < res.getWidth()) {
 | 
				
			||||||
      return b->create<mlir::ZeroExtendIOp>(xla_op.getLoc(), result_types, args,
 | 
					      return b->create<mlir::ZeroExtendIOp>(loc, result_types, args,
 | 
				
			||||||
                                            mlir::None);
 | 
					                                            mlir::None);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    // No conversion is needed for the same width integers
 | 
					    // No conversion is needed for the same width integers
 | 
				
			||||||
@ -322,35 +245,25 @@ inline Value MapXlaOpToStdScalarOp<xla_lhlo::ConvertOp>(
 | 
				
			|||||||
  }
 | 
					  }
 | 
				
			||||||
  // TODO(dfki-ehna): Add other primitive type conversions
 | 
					  // TODO(dfki-ehna): Add other primitive type conversions
 | 
				
			||||||
  // if (mlir::FpToSiOp::areCastCompatible(sourceType, targetType)) {
 | 
					  // if (mlir::FpToSiOp::areCastCompatible(sourceType, targetType)) {
 | 
				
			||||||
  //   return b.create<mlir::FpToSiOp>(xla_op.getLoc(), result_types,
 | 
					  //   return b.create<mlir::FpToSiOp>(loc, result_types,
 | 
				
			||||||
  //   args,mlir::None);
 | 
					  //   args,mlir::None);
 | 
				
			||||||
  // }
 | 
					  // }
 | 
				
			||||||
 | 
					 | 
				
			||||||
  return nullptr;
 | 
					  return nullptr;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <>
 | 
					template <>
 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CosOp>(xla_lhlo::CosOp xla_op,
 | 
					inline Value MapLhloOpToStdScalarOp<xla_lhlo::CosOp>(
 | 
				
			||||||
                                                    ArrayRef<Type> result_types,
 | 
					    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
				
			||||||
                                                    ArrayRef<Value> args,
 | 
					    OpBuilder* b) {
 | 
				
			||||||
                                                    OpBuilder* b) {
 | 
					  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}(
 | 
				
			||||||
  return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}(
 | 
					      loc, result_types, args, b);
 | 
				
			||||||
      xla_op.getLoc(), result_types, args, b);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
template <>
 | 
					 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::CosOp>(xla_hlo::CosOp xla_op,
 | 
					 | 
				
			||||||
                                                   ArrayRef<Type> result_types,
 | 
					 | 
				
			||||||
                                                   ArrayRef<Value> args,
 | 
					 | 
				
			||||||
                                                   OpBuilder* b) {
 | 
					 | 
				
			||||||
  return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}(
 | 
					 | 
				
			||||||
      xla_op.getLoc(), result_types, args, b);
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/// Implements the conversion of XLA op to scalar op (to use within region of a
 | 
					/// Implements the conversion of XLA op to scalar op (to use within region of a
 | 
				
			||||||
/// linalg.generic op) for compare-select style operations like min/max.
 | 
					/// linalg.generic op) for compare-select style operations like min/max.
 | 
				
			||||||
template <typename... Args>
 | 
					template <typename... Args>
 | 
				
			||||||
struct MapXlaCompareSelectOpToStdScalarOp {
 | 
					struct XlaCompareSelectOpToStdScalarOp {
 | 
				
			||||||
  Value operator()(Location loc, StringRef comparison_direction,
 | 
					  static Value map(Location loc, StringRef comparison_direction,
 | 
				
			||||||
                   ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
					                   ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
				
			||||||
                   OpBuilder* b) {
 | 
					                   OpBuilder* b) {
 | 
				
			||||||
    return nullptr;
 | 
					    return nullptr;
 | 
				
			||||||
@ -361,9 +274,9 @@ struct MapXlaCompareSelectOpToStdScalarOp {
 | 
				
			|||||||
/// dialect with a given predicate based on the element type of the operand.
 | 
					/// dialect with a given predicate based on the element type of the operand.
 | 
				
			||||||
template <typename SupportedType, typename StdCompareOp, typename Predicate,
 | 
					template <typename SupportedType, typename StdCompareOp, typename Predicate,
 | 
				
			||||||
          typename... Args>
 | 
					          typename... Args>
 | 
				
			||||||
struct MapXlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp,
 | 
					struct XlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
 | 
				
			||||||
                                          Predicate, Args...> {
 | 
					                                       Args...> {
 | 
				
			||||||
  Value operator()(Location loc, StringRef comparison_direction,
 | 
					  static Value map(Location loc, StringRef comparison_direction,
 | 
				
			||||||
                   ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
					                   ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
				
			||||||
                   OpBuilder* b) {
 | 
					                   OpBuilder* b) {
 | 
				
			||||||
    Type element_type = args.front().getType();
 | 
					    Type element_type = args.front().getType();
 | 
				
			||||||
@ -374,132 +287,130 @@ struct MapXlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp,
 | 
				
			|||||||
                                                  args[0], args[1]);
 | 
					                                                  args[0], args[1]);
 | 
				
			||||||
      return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]);
 | 
					      return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    return MapXlaCompareSelectOpToStdScalarOp<Args...>{}(
 | 
					    return XlaCompareSelectOpToStdScalarOp<Args...>::map(
 | 
				
			||||||
        loc, comparison_direction, result_types, args, b);
 | 
					        loc, comparison_direction, result_types, args, b);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <>
 | 
					template <>
 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::LogOp>(xla_lhlo::LogOp xla_op,
 | 
					inline Value MapLhloOpToStdScalarOp<xla_lhlo::LogOp>(
 | 
				
			||||||
                                                    ArrayRef<Type> result_types,
 | 
					    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
				
			||||||
                                                    ArrayRef<Value> args,
 | 
					 | 
				
			||||||
                                                    OpBuilder* b) {
 | 
					 | 
				
			||||||
  return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
 | 
					 | 
				
			||||||
      xla_op.getLoc(), result_types, args, b);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
template <>
 | 
					 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::LogOp>(xla_hlo::LogOp xla_op,
 | 
					 | 
				
			||||||
                                                   ArrayRef<Type> result_types,
 | 
					 | 
				
			||||||
                                                   ArrayRef<Value> args,
 | 
					 | 
				
			||||||
                                                   OpBuilder* b) {
 | 
					 | 
				
			||||||
  return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
 | 
					 | 
				
			||||||
      xla_op.getLoc(), result_types, args, b);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
template <>
 | 
					 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::MaxOp>(xla_lhlo::MaxOp xla_op,
 | 
					 | 
				
			||||||
                                                    ArrayRef<Type> result_types,
 | 
					 | 
				
			||||||
                                                    ArrayRef<Value> args,
 | 
					 | 
				
			||||||
                                                    OpBuilder* b) {
 | 
					 | 
				
			||||||
  return MapXlaCompareSelectOpToStdScalarOp<
 | 
					 | 
				
			||||||
      IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
 | 
					 | 
				
			||||||
      ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>{}(xla_op.getLoc(), "GT",
 | 
					 | 
				
			||||||
                                                       result_types, args, b);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
template <>
 | 
					 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::MaxOp>(xla_hlo::MaxOp xla_op,
 | 
					 | 
				
			||||||
                                                   ArrayRef<Type> result_types,
 | 
					 | 
				
			||||||
                                                   ArrayRef<Value> args,
 | 
					 | 
				
			||||||
                                                   OpBuilder* b) {
 | 
					 | 
				
			||||||
  return MapXlaCompareSelectOpToStdScalarOp<
 | 
					 | 
				
			||||||
      IntegerType, ScalarIOp<xla_hlo::CompareOp>, CmpIPredicate, FloatType,
 | 
					 | 
				
			||||||
      ScalarFOp<xla_hlo::CompareOp>, CmpFPredicate>{}(xla_op.getLoc(), "GT",
 | 
					 | 
				
			||||||
                                                      result_types, args, b);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
template <>
 | 
					 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::MinOp>(xla_lhlo::MinOp xla_op,
 | 
					 | 
				
			||||||
                                                    ArrayRef<Type> result_types,
 | 
					 | 
				
			||||||
                                                    ArrayRef<Value> args,
 | 
					 | 
				
			||||||
                                                    OpBuilder* b) {
 | 
					 | 
				
			||||||
  return MapXlaCompareSelectOpToStdScalarOp<
 | 
					 | 
				
			||||||
      IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
 | 
					 | 
				
			||||||
      ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>{}(xla_op.getLoc(), "LT",
 | 
					 | 
				
			||||||
                                                       result_types, args, b);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
template <>
 | 
					 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::MinOp>(xla_hlo::MinOp xla_op,
 | 
					 | 
				
			||||||
                                                   ArrayRef<Type> result_types,
 | 
					 | 
				
			||||||
                                                   ArrayRef<Value> args,
 | 
					 | 
				
			||||||
                                                   OpBuilder* b) {
 | 
					 | 
				
			||||||
  return MapXlaCompareSelectOpToStdScalarOp<
 | 
					 | 
				
			||||||
      IntegerType, ScalarIOp<xla_hlo::CompareOp>, CmpIPredicate, FloatType,
 | 
					 | 
				
			||||||
      ScalarFOp<xla_hlo::CompareOp>, CmpFPredicate>{}(xla_op.getLoc(), "LT",
 | 
					 | 
				
			||||||
                                                      result_types, args, b);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
template <>
 | 
					 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::NegOp>(xla_lhlo::NegOp xla_op,
 | 
					 | 
				
			||||||
                                                    ArrayRef<Type> result_types,
 | 
					 | 
				
			||||||
                                                    ArrayRef<Value> args,
 | 
					 | 
				
			||||||
                                                    OpBuilder* b) {
 | 
					 | 
				
			||||||
  return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
 | 
					 | 
				
			||||||
      xla_op.getLoc(), result_types, args, b);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
template <>
 | 
					 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::NegOp>(xla_hlo::NegOp xla_op,
 | 
					 | 
				
			||||||
                                                   ArrayRef<Type> result_types,
 | 
					 | 
				
			||||||
                                                   ArrayRef<Value> args,
 | 
					 | 
				
			||||||
                                                   OpBuilder* b) {
 | 
					 | 
				
			||||||
  return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
 | 
					 | 
				
			||||||
      xla_op.getLoc(), result_types, args, b);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
template <>
 | 
					 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::SelectOp>(
 | 
					 | 
				
			||||||
    xla_lhlo::SelectOp xla_op, ArrayRef<Type> result_types,
 | 
					 | 
				
			||||||
    ArrayRef<Value> args, OpBuilder* b) {
 | 
					 | 
				
			||||||
  return MapXlaOpToStdScalarOpImpl<::mlir::SelectOp>{}(xla_op.getLoc(),
 | 
					 | 
				
			||||||
                                                       result_types, args, b);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
template <>
 | 
					 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::SelectOp>(
 | 
					 | 
				
			||||||
    xla_hlo::SelectOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
					 | 
				
			||||||
    OpBuilder* b) {
 | 
					    OpBuilder* b) {
 | 
				
			||||||
  return MapXlaOpToStdScalarOpImpl<::mlir::SelectOp>{}(xla_op.getLoc(),
 | 
					  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
 | 
				
			||||||
                                                       result_types, args, b);
 | 
					      loc, result_types, args, b);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <>
 | 
					template <>
 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::SignOp>(
 | 
					inline Value MapLhloOpToStdScalarOp<xla_lhlo::MaxOp>(
 | 
				
			||||||
    xla_lhlo::SignOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
					    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
				
			||||||
 | 
					    OpBuilder* b) {
 | 
				
			||||||
 | 
					  return XlaCompareSelectOpToStdScalarOp<
 | 
				
			||||||
 | 
					      IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
 | 
				
			||||||
 | 
					      ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "GT",
 | 
				
			||||||
 | 
					                                                          result_types, args,
 | 
				
			||||||
 | 
					                                                          b);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					inline Value MapLhloOpToStdScalarOp<xla_lhlo::MinOp>(
 | 
				
			||||||
 | 
					    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
				
			||||||
 | 
					    OpBuilder* b) {
 | 
				
			||||||
 | 
					  return XlaCompareSelectOpToStdScalarOp<
 | 
				
			||||||
 | 
					      IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
 | 
				
			||||||
 | 
					      ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "LT",
 | 
				
			||||||
 | 
					                                                          result_types, args,
 | 
				
			||||||
 | 
					                                                          b);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					inline Value MapLhloOpToStdScalarOp<xla_lhlo::NegOp>(
 | 
				
			||||||
 | 
					    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
				
			||||||
 | 
					    OpBuilder* b) {
 | 
				
			||||||
 | 
					  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
 | 
				
			||||||
 | 
					      loc, result_types, args, b);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					inline Value MapLhloOpToStdScalarOp<xla_lhlo::SelectOp>(
 | 
				
			||||||
 | 
					    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
				
			||||||
 | 
					    OpBuilder* b) {
 | 
				
			||||||
 | 
					  return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args,
 | 
				
			||||||
 | 
					                                                        b);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					inline Value MapLhloOpToStdScalarOp<xla_lhlo::SignOp>(
 | 
				
			||||||
 | 
					    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
				
			||||||
    OpBuilder* b) {
 | 
					    OpBuilder* b) {
 | 
				
			||||||
  Type element_type = args.front().getType();
 | 
					  Type element_type = args.front().getType();
 | 
				
			||||||
  if (element_type.isa<FloatType>()) {
 | 
					  if (element_type.isa<FloatType>()) {
 | 
				
			||||||
    FloatType float_type = element_type.cast<FloatType>();
 | 
					    FloatType float_type = element_type.cast<FloatType>();
 | 
				
			||||||
    APFloat const_value = float_type.isF32() ? APFloat(1.0f) : APFloat(1.0);
 | 
					    APFloat const_value = float_type.isF32() ? APFloat(1.0f) : APFloat(1.0);
 | 
				
			||||||
    Value one = b->create<mlir::ConstantFloatOp>(xla_op.getLoc(), const_value,
 | 
					    Value one = b->create<mlir::ConstantFloatOp>(loc, const_value, float_type);
 | 
				
			||||||
                                                 float_type);
 | 
					    return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
 | 
				
			||||||
    return b->create<::mlir::CopySignOp>(xla_op.getLoc(), result_types, one,
 | 
					 | 
				
			||||||
                                         args[0]);
 | 
					 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  return nullptr;
 | 
					  return nullptr;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <>
 | 
					template <>
 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::TanhOp>(
 | 
					inline Value MapLhloOpToStdScalarOp<xla_lhlo::TanhOp>(
 | 
				
			||||||
    xla_lhlo::TanhOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
					    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
 | 
				
			||||||
    OpBuilder* b) {
 | 
					    OpBuilder* b) {
 | 
				
			||||||
  return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}(
 | 
					  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}(
 | 
				
			||||||
      xla_op.getLoc(), result_types, args, b);
 | 
					      loc, result_types, args, b);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
template <>
 | 
					
 | 
				
			||||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::TanhOp>(xla_hlo::TanhOp xla_op,
 | 
					}  // namespace impl
 | 
				
			||||||
                                                    ArrayRef<Type> result_types,
 | 
					
 | 
				
			||||||
                                                    ArrayRef<Value> args,
 | 
					struct XlaOpToStdScalarOp {
 | 
				
			||||||
                                                    OpBuilder* b) {
 | 
					  // Implementation for LHLO ops except xla_lhlo::CompareOp.
 | 
				
			||||||
  return MapXlaOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}(
 | 
					  template <typename XlaOpTy, typename LhloOpTy = XlaOpTy,
 | 
				
			||||||
      xla_op.getLoc(), result_types, args, b);
 | 
					            typename = std::enable_if_t<
 | 
				
			||||||
 | 
					                !std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
 | 
				
			||||||
 | 
					                std::is_same<typename xla_hlo::HloToLhloOp<LhloOpTy>,
 | 
				
			||||||
 | 
					                             std::false_type>::value>>
 | 
				
			||||||
 | 
					  static Value map(XlaOpTy op, ArrayRef<Type> result_types,
 | 
				
			||||||
 | 
					                   ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
 | 
				
			||||||
 | 
					    return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
 | 
				
			||||||
 | 
					                                                  args, b);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Implementation for HLO ops except xla_hlo::CompareOp.
 | 
				
			||||||
 | 
					  template <typename XlaOpTy, typename LhloOpTy = xla_hlo::HloToLhloOp<XlaOpTy>,
 | 
				
			||||||
 | 
					            typename = std::enable_if_t<
 | 
				
			||||||
 | 
					                !std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
 | 
				
			||||||
 | 
					                !std::is_same<LhloOpTy, std::false_type>::value>>
 | 
				
			||||||
 | 
					  static Value map(XlaOpTy op, ArrayRef<Type> result_types,
 | 
				
			||||||
 | 
					                   ArrayRef<Value> args, OpBuilder* b, int i = 0) {
 | 
				
			||||||
 | 
					    return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
 | 
				
			||||||
 | 
					                                                  args, b);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Implementation for xla_lhlo::CompareOp.
 | 
				
			||||||
 | 
					  template <typename LhloOpTy, typename = std::enable_if_t<std::is_same<
 | 
				
			||||||
 | 
					                                   LhloOpTy, xla_lhlo::CompareOp>::value>>
 | 
				
			||||||
 | 
					  static Value map(xla_lhlo::CompareOp op, ArrayRef<Type> result_types,
 | 
				
			||||||
 | 
					                   ArrayRef<Value> args, OpBuilder* b) {
 | 
				
			||||||
 | 
					    auto comparison_direction = op.comparison_direction();
 | 
				
			||||||
 | 
					    return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
 | 
				
			||||||
 | 
					        op.getLoc(), comparison_direction, result_types, args, b);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Implementation for xla_hlo::CompareOp.
 | 
				
			||||||
 | 
					  template <typename HloOpTy, typename = std::enable_if_t<std::is_same<
 | 
				
			||||||
 | 
					                                  HloOpTy, xla_hlo::CompareOp>::value>>
 | 
				
			||||||
 | 
					  static Value map(xla_hlo::CompareOp op, ArrayRef<Type> result_types,
 | 
				
			||||||
 | 
					                   ArrayRef<Value> args, OpBuilder* b) {
 | 
				
			||||||
 | 
					    auto comparison_direction = op.comparison_direction();
 | 
				
			||||||
 | 
					    return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
 | 
				
			||||||
 | 
					        op.getLoc(), comparison_direction, result_types, args, b);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <typename XlaOpTy>
 | 
				
			||||||
 | 
					inline Value MapXlaOpToStdScalarOp(XlaOpTy xla_op, ArrayRef<Type> result_types,
 | 
				
			||||||
 | 
					                                   ArrayRef<Value> args, OpBuilder* b) {
 | 
				
			||||||
 | 
					  return XlaOpToStdScalarOp::map<XlaOpTy>(xla_op, result_types, args, b);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace xla_lhlo
 | 
					}  // namespace xla_lhlo
 | 
				
			||||||
 | 
				
			|||||||
@ -149,8 +149,8 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
 | 
				
			|||||||
    rewriter.setInsertionPointToEnd(block);
 | 
					    rewriter.setInsertionPointToEnd(block);
 | 
				
			||||||
    // TODO(ravishankarm) : For now use the method in xla_lhlo namespace. That
 | 
					    // TODO(ravishankarm) : For now use the method in xla_lhlo namespace. That
 | 
				
			||||||
    // method needs to be moved out of there.
 | 
					    // method needs to be moved out of there.
 | 
				
			||||||
    Value opResult = xla_lhlo::MapXlaOpToStdScalarOp<OpTy>(
 | 
					    Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<OpTy>(
 | 
				
			||||||
        llvm::cast<OpTy>(op), bodyResultTypes, bodyArgs, &rewriter);
 | 
					        op, bodyResultTypes, bodyArgs, &rewriter);
 | 
				
			||||||
    if (!opResult) {
 | 
					    if (!opResult) {
 | 
				
			||||||
      return ConversionPattern::matchFailure();
 | 
					      return ConversionPattern::matchFailure();
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
@ -180,9 +180,9 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
 | 
				
			|||||||
    auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
 | 
					    auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
 | 
				
			||||||
    auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
 | 
					    auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
 | 
				
			||||||
    // TODO(ravishankarm) : Move this method out of xla_lhlo namespace.
 | 
					    // TODO(ravishankarm) : Move this method out of xla_lhlo namespace.
 | 
				
			||||||
    Value opResult = xla_lhlo::MapXlaOpToStdScalarOp<LhloOp>(
 | 
					    Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<LhloOp>(
 | 
				
			||||||
        llvm::cast<LhloOp>(lhlo_op), argType.getElementType(),
 | 
					        lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
 | 
				
			||||||
        llvm::ArrayRef<Value>{lhs, rhs}, &rewriter);
 | 
					        &rewriter);
 | 
				
			||||||
    rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
 | 
					    rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
 | 
				
			||||||
    rewriter.eraseOp(lhlo_op);
 | 
					    rewriter.eraseOp(lhlo_op);
 | 
				
			||||||
    return ConversionPattern::matchSuccess();
 | 
					    return ConversionPattern::matchSuccess();
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user