Folders for mhlo.compare
Constant evaluation of compare for the case where inputs are either the same variable or the values are constant. PiperOrigin-RevId: 333342328 Change-Id: I089e3f73797f5a9ecc2cbded00d07488c9b6da92
This commit is contained in:
parent
ee8caa1798
commit
736aea1228
@ -689,6 +689,8 @@ def HLO_CompareOp: HLO_Op<"compare", [NoSideEffect, SameTypeOperands,
|
||||
);
|
||||
let results = (outs HLO_PredTensor);
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
|
||||
"StringAttr comparison_direction"
|
||||
|
@ -2501,8 +2501,108 @@ LogicalResult CompareOp::reifyReturnTypeShapes(
|
||||
&reifiedReturnShapes);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct less : std::less<T> {};
|
||||
|
||||
template <>
|
||||
struct less<APInt> {
|
||||
bool operator()(const APInt& a, const APInt& b) const { return a.slt(b); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct less_equal : std::less_equal<T> {};
|
||||
|
||||
template <>
|
||||
struct less_equal<APInt> {
|
||||
bool operator()(const APInt& a, const APInt& b) const { return a.sle(b); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct greater : std::greater<T> {};
|
||||
|
||||
template <>
|
||||
struct greater<APInt> {
|
||||
bool operator()(const APInt& a, const APInt& b) const { return a.sgt(b); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct greater_equal : std::greater_equal<T> {};
|
||||
|
||||
template <>
|
||||
struct greater_equal<APInt> {
|
||||
bool operator()(const APInt& a, const APInt& b) const { return a.sge(b); }
|
||||
};
|
||||
|
||||
template <typename Op, typename ElementType, typename SrcType, typename Convert>
|
||||
static Attribute CompareFolder(CompareOp op, ArrayRef<Attribute> attrs) {
|
||||
if (!attrs[0] || !attrs[1]) return {};
|
||||
|
||||
DenseElementsAttr lhs = attrs[0].dyn_cast<DenseElementsAttr>();
|
||||
DenseElementsAttr rhs = attrs[1].dyn_cast<DenseElementsAttr>();
|
||||
if (!lhs || !rhs) return {};
|
||||
|
||||
ShapedType operand_type =
|
||||
op.getOperand(0).getType().template cast<ShapedType>();
|
||||
if (!operand_type.hasStaticShape()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
if (!operand_type.getElementType().isa<ElementType>()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
SmallVector<bool, 6> values;
|
||||
values.reserve(lhs.getNumElements());
|
||||
for (const auto zip :
|
||||
llvm::zip(lhs.getValues<SrcType>(), rhs.getValues<SrcType>())) {
|
||||
values.push_back(Convert()(std::get<0>(zip), std::get<1>(zip)));
|
||||
}
|
||||
|
||||
auto result_ty = op.getType().cast<ShapedType>();
|
||||
return DenseElementsAttr::get(result_ty, values);
|
||||
}
|
||||
|
||||
OpFoldResult CompareOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto result_ty = getType().cast<ShapedType>();
|
||||
if (!result_ty.hasStaticShape()) return {};
|
||||
|
||||
auto direction = comparison_direction();
|
||||
if (lhs() == rhs()) {
|
||||
if (direction == "LE" || direction == "EQ" || direction == "GE") {
|
||||
return DenseIntElementsAttr::get(result_ty, {true});
|
||||
}
|
||||
|
||||
return DenseIntElementsAttr::get(result_ty, {false});
|
||||
}
|
||||
|
||||
if (!operands[0] || !operands[1]) {
|
||||
return {};
|
||||
}
|
||||
|
||||
#define COMPARE_FOLDER(Op, comparison, Func) \
|
||||
if (direction == comparison) { \
|
||||
if (auto folded = CompareFolder<Op, FloatType, APFloat, Func<APFloat>>( \
|
||||
*this, operands)) \
|
||||
return folded; \
|
||||
if (auto folded = CompareFolder<Op, IntegerType, APInt, Func<APInt>>( \
|
||||
*this, operands)) \
|
||||
return folded; \
|
||||
}
|
||||
|
||||
COMPARE_FOLDER(CompareOp, "EQ", std::equal_to);
|
||||
COMPARE_FOLDER(CompareOp, "NE", std::not_equal_to);
|
||||
COMPARE_FOLDER(CompareOp, "LT", less);
|
||||
COMPARE_FOLDER(CompareOp, "LE", less_equal);
|
||||
COMPARE_FOLDER(CompareOp, "GT", greater);
|
||||
COMPARE_FOLDER(CompareOp, "GE", greater_equal);
|
||||
#undef COMPARE_FOLDER
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
|
||||
|
||||
|
@ -583,6 +583,262 @@ func @dce_while_without_side_effect(%arg0: tensor<i64>) -> tensor<i64> {
|
||||
return %arg0 : tensor<i64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_same_eq
|
||||
func @fold_compare_same_eq(%arg0: tensor<i64>) -> tensor<i1> {
|
||||
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_same_le
|
||||
func @fold_compare_same_le(%arg0: tensor<i64>) -> tensor<i1> {
|
||||
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_same_ge
|
||||
func @fold_compare_same_ge(%arg0: tensor<i64>) -> tensor<i1> {
|
||||
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
// CHECK-LABEL: fold_compare_same_ne
|
||||
func @fold_compare_same_ne(%arg0: tensor<i64>) -> tensor<i1> {
|
||||
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_same_lt
|
||||
func @fold_compare_same_lt(%arg0: tensor<i64>) -> tensor<i1> {
|
||||
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_same_gt
|
||||
func @fold_compare_same_gt(%arg0: tensor<i64>) -> tensor<i1> {
|
||||
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_false_eq
|
||||
func @fold_compare_false_eq() -> tensor<i1> {
|
||||
%0 = mhlo.constant dense<0> : tensor<i32>
|
||||
%1 = mhlo.constant dense<1> : tensor<i32>
|
||||
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "EQ"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
return %2 : tensor<i1>
|
||||
}
|
||||
// CHECK-LABEL: fold_compare_true_eq
|
||||
func @fold_compare_true_eq() -> tensor<i1> {
|
||||
%0 = mhlo.constant dense<1> : tensor<i32>
|
||||
%1 = mhlo.constant dense<1> : tensor<i32>
|
||||
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "EQ"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
return %2 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_false_eq_float
|
||||
func @fold_compare_false_eq_float() -> tensor<i1> {
|
||||
%0 = mhlo.constant dense<0.> : tensor<f32>
|
||||
%1 = mhlo.constant dense<1.> : tensor<f32>
|
||||
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
return %2 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_true_eq_float
|
||||
func @fold_compare_true_eq_float() -> tensor<i1> {
|
||||
%0 = mhlo.constant dense<1.> : tensor<f32>
|
||||
%1 = mhlo.constant dense<1.> : tensor<f32>
|
||||
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "EQ"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
return %2 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_false_ne
|
||||
func @fold_compare_false_ne() -> tensor<i1> {
|
||||
%0 = mhlo.constant dense<1> : tensor<i32>
|
||||
%1 = mhlo.constant dense<1> : tensor<i32>
|
||||
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "NE"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
return %2 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_true_ne
|
||||
func @fold_compare_true_ne() -> tensor<i1> {
|
||||
%0 = mhlo.constant dense<1> : tensor<i32>
|
||||
%1 = mhlo.constant dense<0> : tensor<i32>
|
||||
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "NE"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
return %2 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_false_ne_float
|
||||
func @fold_compare_false_ne_float() -> tensor<i1> {
|
||||
%0 = mhlo.constant dense<1.> : tensor<f32>
|
||||
%1 = mhlo.constant dense<1.> : tensor<f32>
|
||||
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "NE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
return %2 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_true_ne_float
|
||||
func @fold_compare_true_ne_float() -> tensor<i1> {
|
||||
%0 = mhlo.constant dense<0.> : tensor<f32>
|
||||
%1 = mhlo.constant dense<1.> : tensor<f32>
|
||||
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "NE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
return %2 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_false_lt
|
||||
func @fold_compare_false_lt() -> tensor<i1> {
|
||||
%0 = mhlo.constant dense<1> : tensor<i32>
|
||||
%1 = mhlo.constant dense<1> : tensor<i32>
|
||||
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
return %2 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_true_lt
|
||||
func @fold_compare_true_lt() -> tensor<i1> {
|
||||
%0 = mhlo.constant dense<0> : tensor<i32>
|
||||
%1 = mhlo.constant dense<1> : tensor<i32>
|
||||
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
return %2 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_false_lt_float
|
||||
func @fold_compare_false_lt_float() -> tensor<i1> {
|
||||
%0 = mhlo.constant dense<1.> : tensor<f32>
|
||||
%1 = mhlo.constant dense<1.> : tensor<f32>
|
||||
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
return %2 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_true_lt_float
|
||||
func @fold_compare_true_lt_float() -> tensor<i1> {
|
||||
%0 = mhlo.constant dense<0.> : tensor<f32>
|
||||
%1 = mhlo.constant dense<1.> : tensor<f32>
|
||||
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
return %2 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_false_le
|
||||
func @fold_compare_false_le() -> tensor<i1> {
|
||||
%0 = mhlo.constant dense<1> : tensor<i32>
|
||||
%1 = mhlo.constant dense<0> : tensor<i32>
|
||||
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "LE"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
return %2 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_true_le
|
||||
func @fold_compare_true_le() -> tensor<i1> {
|
||||
%0 = mhlo.constant dense<1> : tensor<i32>
|
||||
%1 = mhlo.constant dense<1> : tensor<i32>
|
||||
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "LE"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
return %2 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_false_le_float
|
||||
func @fold_compare_false_le_float() -> tensor<i1> {
|
||||
%0 = mhlo.constant dense<1.> : tensor<f32>
|
||||
%1 = mhlo.constant dense<0.> : tensor<f32>
|
||||
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "LE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
return %2 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_true_le_float
|
||||
func @fold_compare_true_le_float() -> tensor<i1> {
|
||||
%0 = mhlo.constant dense<1.> : tensor<f32>
|
||||
%1 = mhlo.constant dense<1.> : tensor<f32>
|
||||
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "LE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
return %2 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_false_gt
|
||||
func @fold_compare_false_gt() -> tensor<i1> {
|
||||
%0 = mhlo.constant dense<0> : tensor<i32>
|
||||
%1 = mhlo.constant dense<0> : tensor<i32>
|
||||
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
return %2 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_true_gt
|
||||
func @fold_compare_true_gt() -> tensor<i1> {
|
||||
%0 = mhlo.constant dense<1> : tensor<i32>
|
||||
%1 = mhlo.constant dense<0> : tensor<i32>
|
||||
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
return %2 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_false_gt_float
|
||||
func @fold_compare_false_gt_float() -> tensor<i1> {
|
||||
%0 = mhlo.constant dense<0.> : tensor<f32>
|
||||
%1 = mhlo.constant dense<0.> : tensor<f32>
|
||||
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
return %2 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_true_gt_float
|
||||
func @fold_compare_true_gt_float() -> tensor<i1> {
|
||||
%0 = mhlo.constant dense<1.> : tensor<f32>
|
||||
%1 = mhlo.constant dense<0.> : tensor<f32>
|
||||
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
return %2 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_false_ge
|
||||
func @fold_compare_false_ge() -> tensor<i1> {
|
||||
%0 = mhlo.constant dense<0> : tensor<i32>
|
||||
%1 = mhlo.constant dense<1> : tensor<i32>
|
||||
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "GE"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
return %2 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_true_ge
|
||||
func @fold_compare_true_ge() -> tensor<i1> {
|
||||
%0 = mhlo.constant dense<0> : tensor<i32>
|
||||
%1 = mhlo.constant dense<0> : tensor<i32>
|
||||
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "GE"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
return %2 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_false_ge_float
|
||||
func @fold_compare_false_ge_float() -> tensor<i1> {
|
||||
%0 = mhlo.constant dense<0.> : tensor<f32>
|
||||
%1 = mhlo.constant dense<1.> : tensor<f32>
|
||||
// CHECK: %0 = mhlo.constant dense<false> : tensor<i1>
|
||||
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
return %2 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fold_compare_true_ge_float
|
||||
func @fold_compare_true_ge_float() -> tensor<i1> {
|
||||
%0 = mhlo.constant dense<0.> : tensor<f32>
|
||||
%1 = mhlo.constant dense<0.> : tensor<f32>
|
||||
// CHECK: %0 = mhlo.constant dense<true> : tensor<i1>
|
||||
%2 = "mhlo.compare"(%0, %1) {comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
return %2 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: unpack_repack_same_tuple
|
||||
// CHECK-SAME: ([[ARG0:%.*]]: tuple<tensor<i32>, !mhlo.token, tensor<f32>>)
|
||||
func @unpack_repack_same_tuple(%arg0: tuple<tensor<i32>, !mhlo.token, tensor<f32>>) -> tuple<tensor<i32>, !mhlo.token, tensor<f32>> {
|
||||
|
@ -51,38 +51,38 @@ func @unary_ops_float(%arg0: tensor<4xf32>) -> tensor<4xf32> {
|
||||
return %0 : tensor<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
|
||||
func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
|
||||
// CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg0 : tensor<4xi32>
|
||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %1 = cmpi "ne", %arg0, %arg0 : tensor<4xi32>
|
||||
%1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %2 = cmpi "slt", %arg0, %arg0 : tensor<4xi32>
|
||||
%2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %3 = cmpi "sle", %arg0, %arg0 : tensor<4xi32>
|
||||
%3 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %4 = cmpi "sgt", %arg0, %arg0 : tensor<4xi32>
|
||||
%4 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %5 = cmpi "sge", %arg0, %arg0 : tensor<4xi32>
|
||||
%5 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||
// CHECK-LABEL: func @compare_int
|
||||
func @compare_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
|
||||
// CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg1 : tensor<4xi32>
|
||||
%0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %1 = cmpi "ne", %arg0, %arg1 : tensor<4xi32>
|
||||
%1 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %2 = cmpi "slt", %arg0, %arg1 : tensor<4xi32>
|
||||
%2 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %3 = cmpi "sle", %arg0, %arg1 : tensor<4xi32>
|
||||
%3 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %4 = cmpi "sgt", %arg0, %arg1 : tensor<4xi32>
|
||||
%4 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %5 = cmpi "sge", %arg0, %arg1 : tensor<4xi32>
|
||||
%5 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
|
||||
return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @compare_float
|
||||
func @compare_float(%arg0: tensor<4xf32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
|
||||
// CHECK-NEXT: %0 = cmpf "oeq", %arg0, %arg0 : tensor<4xf32>
|
||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %1 = cmpf "une", %arg0, %arg0 : tensor<4xf32>
|
||||
%1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %2 = cmpf "olt", %arg0, %arg0 : tensor<4xf32>
|
||||
%2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %3 = cmpf "ole", %arg0, %arg0 : tensor<4xf32>
|
||||
%3 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %4 = cmpf "ogt", %arg0, %arg0 : tensor<4xf32>
|
||||
%4 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %5 = cmpf "oge", %arg0, %arg0 : tensor<4xf32>
|
||||
%5 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
func @compare_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
|
||||
// CHECK-NEXT: %0 = cmpf "oeq", %arg0, %arg1 : tensor<4xf32>
|
||||
%0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %1 = cmpf "une", %arg0, %arg1 : tensor<4xf32>
|
||||
%1 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %2 = cmpf "olt", %arg0, %arg1 : tensor<4xf32>
|
||||
%2 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %3 = cmpf "ole", %arg0, %arg1 : tensor<4xf32>
|
||||
%3 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %4 = cmpf "ogt", %arg0, %arg1 : tensor<4xf32>
|
||||
%4 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
// CHECK-NEXT: %5 = cmpf "oge", %arg0, %arg1 : tensor<4xf32>
|
||||
%5 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
|
||||
return %0, %1, %2, %3, %4, %5: tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
|
||||
}
|
||||
|
||||
|
@ -479,12 +479,13 @@ func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> te
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @equal(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK: %[[VAL_1:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_0]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
// CHECK: return %[[VAL_1]] : tensor<2xi1>
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK: %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
// CHECK: return %[[VAL_2]] : tensor<2xi1>
|
||||
// CHECK: }
|
||||
func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
func @equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
|
||||
%0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
return %0 : tensor<2xi1>
|
||||
}
|
||||
|
||||
@ -533,12 +534,13 @@ func @equal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @notequal(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK: %[[VAL_1:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_0]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
// CHECK: return %[[VAL_1]] : tensor<2xi1>
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK: %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
// CHECK: return %[[VAL_2]] : tensor<2xi1>
|
||||
// CHECK: }
|
||||
func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
func @notequal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
|
||||
%0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
return %0 : tensor<2xi1>
|
||||
}
|
||||
|
||||
@ -576,12 +578,13 @@ func @notequal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: ten
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @greater(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK: %[[VAL_1:.*]] = "tf.Greater"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
// CHECK: return %[[VAL_1]] : tensor<2xi1>
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK: %[[VAL_2:.*]] = "tf.Greater"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
// CHECK: return %[[VAL_2]] : tensor<2xi1>
|
||||
// CHECK: }
|
||||
func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
func @greater(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
|
||||
%0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
return %0 : tensor<2xi1>
|
||||
}
|
||||
|
||||
@ -597,12 +600,13 @@ func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @greater_equal(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK: %[[VAL_1:.*]] = "tf.GreaterEqual"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
// CHECK: return %[[VAL_1]] : tensor<2xi1>
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK: %[[VAL_2:.*]] = "tf.GreaterEqual"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
// CHECK: return %[[VAL_2]] : tensor<2xi1>
|
||||
// CHECK: }
|
||||
func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
func @greater_equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
|
||||
%0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
return %0 : tensor<2xi1>
|
||||
}
|
||||
|
||||
@ -618,12 +622,13 @@ func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> t
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @less(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK: %[[VAL_1:.*]] = "tf.Less"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
// CHECK: return %[[VAL_1]] : tensor<2xi1>
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK: %[[VAL_2:.*]] = "tf.Less"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
// CHECK: return %[[VAL_2]] : tensor<2xi1>
|
||||
// CHECK: }
|
||||
func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
func @less(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
|
||||
%0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
return %0 : tensor<2xi1>
|
||||
}
|
||||
|
||||
@ -639,12 +644,13 @@ func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @less_equal(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK: %[[VAL_1:.*]] = "tf.LessEqual"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
// CHECK: return %[[VAL_1]] : tensor<2xi1>
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK: %[[VAL_2:.*]] = "tf.LessEqual"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
// CHECK: return %[[VAL_2]] : tensor<2xi1>
|
||||
// CHECK: }
|
||||
func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
func @less_equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
|
||||
%0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
return %0 : tensor<2xi1>
|
||||
}
|
||||
|
||||
@ -1363,20 +1369,21 @@ func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @sign(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
|
||||
// CHECK: %[[VAL_1:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_0]]) {incompatible_shape_error = true} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
|
||||
// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1x2x3x4xf32>} : () -> tensor<1x2x3x4xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_0]]) {incompatible_shape_error = true} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
|
||||
// CHECK: %[[VAL_4:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1x2x3x4xf32>} : () -> tensor<1x2x3x4xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = "tf.Sign"(%[[VAL_0]]) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = "tf.Select"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]]) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = "tf.Select"(%[[VAL_1]], %[[VAL_2]], %[[VAL_6]]) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
|
||||
// CHECK: return %[[VAL_7]] : tensor<1x2x3x4xf32>
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4xf32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
|
||||
// CHECK: %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
|
||||
// CHECK: %[[VAL_3:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1x2x3x4xf32>} : () -> tensor<1x2x3x4xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
|
||||
// CHECK: %[[VAL_5:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1x2x3x4xf32>} : () -> tensor<1x2x3x4xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = "tf.Sign"(%[[VAL_0]]) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = "tf.Select"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
|
||||
// CHECK: %[[VAL_8:.*]] = "tf.Select"(%[[VAL_2]], %[[VAL_3]], %[[VAL_7]]) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
|
||||
// CHECK: return %[[VAL_8]] : tensor<1x2x3x4xf32>
|
||||
// CHECK: }
|
||||
func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
|
||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
|
||||
func @sign(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
|
||||
%0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
|
||||
%1 = mhlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32>
|
||||
%2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
|
||||
%2 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
|
||||
%3 = mhlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32>
|
||||
%4 = "mhlo.sign"(%arg0) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
|
||||
%5 = "mhlo.select"(%2, %3, %4) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
|
||||
|
@ -197,9 +197,9 @@ func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: func @equal
|
||||
func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"}
|
||||
%0 = "tf.Equal"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
func @equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"}
|
||||
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
return %0: tensor<2xi1>
|
||||
}
|
||||
|
||||
@ -265,9 +265,9 @@ func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @notequal
|
||||
func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"}
|
||||
%0 = "tf.NotEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
func @notequal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"}
|
||||
%0 = "tf.NotEqual"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
return %0: tensor<2xi1>
|
||||
}
|
||||
|
||||
@ -278,9 +278,9 @@ func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: func @greater
|
||||
func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"}
|
||||
%0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
func @greater(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"}
|
||||
%0 = "tf.Greater"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
return %0: tensor<2xi1>
|
||||
}
|
||||
|
||||
@ -310,29 +310,29 @@ func @greater_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @greater_uranked
|
||||
func @greater_uranked(%arg0: tensor<*xi32>) -> tensor<*xi1> {
|
||||
func @greater_uranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1> {
|
||||
// CHECK: "tf.Greater"
|
||||
%0 = "tf.Greater"(%arg0, %arg0) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1>
|
||||
%0 = "tf.Greater"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1>
|
||||
return %0: tensor<*xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @greater_equal
|
||||
func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"}
|
||||
%0 = "tf.GreaterEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
func @greater_equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GE"}
|
||||
%0 = "tf.GreaterEqual"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
return %0: tensor<2xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @less
|
||||
func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"}
|
||||
%0 = "tf.Less"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
func @less(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LT"}
|
||||
%0 = "tf.Less"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
return %0: tensor<2xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @less_equal
|
||||
func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"}
|
||||
%0 = "tf.LessEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
func @less_equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE"}
|
||||
%0 = "tf.LessEqual"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
return %0: tensor<2xi1>
|
||||
}
|
||||
|
@ -125,9 +125,9 @@ func @constant(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @greater
|
||||
func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"}
|
||||
%0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
func @greater(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
|
||||
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"}
|
||||
%0 = "tf.Greater"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
|
||||
return %0: tensor<2xi1>
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user