diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index 351e8bdae0e..d545c2a1c04 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -689,6 +689,8 @@ def HLO_CompareOp: HLO_Op<"compare", [NoSideEffect, SameTypeOperands,
   );
   let results = (outs HLO_PredTensor);
 
+  let hasFolder = 1;
+
   let builders = [OpBuilder<
     "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
     "StringAttr comparison_direction"
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
index 6711a916896..1ebec669ea5 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -2501,8 +2501,108 @@ LogicalResult CompareOp::reifyReturnTypeShapes(
       &reifiedReturnShapes);
 }
 
+template <typename T>
+struct less : std::less<T> {};
+
+template <>
+struct less<APInt> {
+  bool operator()(const APInt& a, const APInt& b) const { return a.slt(b); }
+};
+
+template <typename T>
+struct less_equal : std::less_equal<T> {};
+
+template <>
+struct less_equal<APInt> {
+  bool operator()(const APInt& a, const APInt& b) const { return a.sle(b); }
+};
+
+template <typename T>
+struct greater : std::greater<T> {};
+
+template <>
+struct greater<APInt> {
+  bool operator()(const APInt& a, const APInt& b) const { return a.sgt(b); }
+};
+
+template <typename T>
+struct greater_equal : std::greater_equal<T> {};
+
+template <>
+struct greater_equal<APInt> {
+  bool operator()(const APInt& a, const APInt& b) const { return a.sge(b); }
+};
+
+template <typename Op, typename ElementType, typename SrcType,
+          typename Convert>
+static Attribute CompareFolder(CompareOp op, ArrayRef<Attribute> attrs) {
+  if (!attrs[0] || !attrs[1]) return {};
+
+  DenseElementsAttr lhs = attrs[0].dyn_cast<DenseElementsAttr>();
+  DenseElementsAttr rhs = attrs[1].dyn_cast<DenseElementsAttr>();
+  if (!lhs || !rhs) return {};
+
+  ShapedType operand_type =
+      op.getOperand(0).getType().template cast<ShapedType>();
+  if (!operand_type.hasStaticShape()) {
+    return {};
+  }
+
+  if (!operand_type.getElementType().isa<ElementType>()) {
+    return {};
+  }
+
+  SmallVector<bool, 6> values;
+  values.reserve(lhs.getNumElements());
+  for (const auto zip :
+       llvm::zip(lhs.getValues<SrcType>(), rhs.getValues<SrcType>())) {
+    values.push_back(Convert()(std::get<0>(zip), std::get<1>(zip)));
+  }
+
+  auto result_ty = op.getType().cast<ShapedType>();
+  return DenseElementsAttr::get(result_ty, values);
+}
+
+OpFoldResult CompareOp::fold(ArrayRef<Attribute> operands) {
+  auto result_ty = getType().cast<ShapedType>();
+  if (!result_ty.hasStaticShape()) return {};
+
+  auto direction = comparison_direction();
+  if (lhs() == rhs()) {
+    if (direction == "LE" || direction == "EQ" || direction == "GE") {
+      return DenseIntElementsAttr::get(result_ty, {true});
+    }
+
+    return DenseIntElementsAttr::get(result_ty, {false});
+  }
+
+  if (!operands[0] || !operands[1]) {
+    return {};
+  }
+
+#define COMPARE_FOLDER(Op, comparison, Func)                                \
+  if (direction == comparison) {                                            \
+    if (auto folded = CompareFolder<Op, FloatType, APFloat, Func<APFloat>>( \
+            *this, operands))                                               \
+      return folded;                                                        \
+    if (auto folded = CompareFolder<Op, IntegerType, APInt, Func<APInt>>(   \
+            *this, operands))                                               \
+      return folded;                                                        \
+  }
+
+  COMPARE_FOLDER(CompareOp, "EQ", std::equal_to);
+  COMPARE_FOLDER(CompareOp, "NE", std::not_equal_to);
+  COMPARE_FOLDER(CompareOp, "LT", less);
+  COMPARE_FOLDER(CompareOp, "LE", less_equal);
+  COMPARE_FOLDER(CompareOp, "GT", greater);
+  COMPARE_FOLDER(CompareOp, "GE", greater_equal);
+#undef COMPARE_FOLDER
+
+  return {};
+}
+
 }  // namespace mhlo
 }  // namespace mlir
+
 #define GET_OP_CLASSES
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
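As an illustration of what the new fold hook does (an aside, not a hunk of this patch): once hasFolder is set, the canonicalizer can evaluate an mhlo.compare whose operands are identical, or whose operands are constants feeding a statically shaped result, entirely at compile time. A minimal sketch, assuming an i32 element type:

  // Before canonicalization:
  %0 = mhlo.constant dense<2> : tensor<i32>
  %1 = mhlo.constant dense<3> : tensor<i32>
  %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>

  // After canonicalization %2 becomes:
  %2 = mhlo.constant dense<true> : tensor<i1>

Identical operands fold to true for EQ/LE/GE and to false for NE/LT/GT; compares whose result type is not statically shaped are left untouched.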
"mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc" diff --git a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir index 5da43d5f113..6d8814545d4 100644 --- a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir @@ -583,6 +583,262 @@ func @dce_while_without_side_effect(%arg0: tensor) -> tensor { return %arg0 : tensor } +// CHECK-LABEL: fold_compare_same_eq +func @fold_compare_same_eq(%arg0: tensor) -> tensor { + // CHECK: %0 = mhlo.constant dense : tensor + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: fold_compare_same_le +func @fold_compare_same_le(%arg0: tensor) -> tensor { + // CHECK: %0 = mhlo.constant dense : tensor + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: fold_compare_same_ge +func @fold_compare_same_ge(%arg0: tensor) -> tensor { + // CHECK: %0 = mhlo.constant dense : tensor + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor, tensor) -> tensor + return %0 : tensor +} +// CHECK-LABEL: fold_compare_same_ne +func @fold_compare_same_ne(%arg0: tensor) -> tensor { + // CHECK: %0 = mhlo.constant dense : tensor + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: fold_compare_same_lt +func @fold_compare_same_lt(%arg0: tensor) -> tensor { + // CHECK: %0 = mhlo.constant dense : tensor + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: fold_compare_same_gt +func @fold_compare_same_gt(%arg0: tensor) -> tensor { + // CHECK: %0 = mhlo.constant dense : tensor + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: fold_compare_false_eq +func @fold_compare_false_eq() -> tensor { + %0 = mhlo.constant dense<0> : tensor + %1 = mhlo.constant dense<1> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + return %2 : tensor +} +// CHECK-LABEL: fold_compare_true_eq +func @fold_compare_true_eq() -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.constant dense<1> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_eq_float +func @fold_compare_false_eq_float() -> tensor { + %0 = mhlo.constant dense<0.> : tensor + %1 = mhlo.constant dense<1.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_eq_float +func @fold_compare_true_eq_float() -> tensor { + %0 = mhlo.constant dense<1.> : tensor + %1 = mhlo.constant dense<1.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_ne +func @fold_compare_false_ne() -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.constant dense<1> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = 
"NE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_ne +func @fold_compare_true_ne() -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.constant dense<0> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "NE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_ne_float +func @fold_compare_false_ne_float() -> tensor { + %0 = mhlo.constant dense<1.> : tensor + %1 = mhlo.constant dense<1.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "NE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_ne_float +func @fold_compare_true_ne_float() -> tensor { + %0 = mhlo.constant dense<0.> : tensor + %1 = mhlo.constant dense<1.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "NE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_lt +func @fold_compare_false_lt() -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.constant dense<1> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_lt +func @fold_compare_true_lt() -> tensor { + %0 = mhlo.constant dense<0> : tensor + %1 = mhlo.constant dense<1> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_lt_float +func @fold_compare_false_lt_float() -> tensor { + %0 = mhlo.constant dense<1.> : tensor + %1 = mhlo.constant dense<1.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_lt_float +func @fold_compare_true_lt_float() -> tensor { + %0 = mhlo.constant dense<0.> : tensor + %1 = mhlo.constant dense<1.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_le +func @fold_compare_false_le() -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.constant dense<0> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_le +func @fold_compare_true_le() -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.constant dense<1> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_le_float +func @fold_compare_false_le_float() -> tensor { + %0 = mhlo.constant dense<1.> : tensor + %1 = mhlo.constant dense<0.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_le_float +func @fold_compare_true_le_float() -> tensor { + %0 = mhlo.constant dense<1.> : tensor + %1 = mhlo.constant dense<1.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 
= "mhlo.compare"(%0, %1) {comparison_direction = "LE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_gt +func @fold_compare_false_gt() -> tensor { + %0 = mhlo.constant dense<0> : tensor + %1 = mhlo.constant dense<0> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_gt +func @fold_compare_true_gt() -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.constant dense<0> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_gt_float +func @fold_compare_false_gt_float() -> tensor { + %0 = mhlo.constant dense<0.> : tensor + %1 = mhlo.constant dense<0.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_gt_float +func @fold_compare_true_gt_float() -> tensor { + %0 = mhlo.constant dense<1.> : tensor + %1 = mhlo.constant dense<0.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_ge +func @fold_compare_false_ge() -> tensor { + %0 = mhlo.constant dense<0> : tensor + %1 = mhlo.constant dense<1> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "GE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_ge +func @fold_compare_true_ge() -> tensor { + %0 = mhlo.constant dense<0> : tensor + %1 = mhlo.constant dense<0> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "GE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_ge_float +func @fold_compare_false_ge_float() -> tensor { + %0 = mhlo.constant dense<0.> : tensor + %1 = mhlo.constant dense<1.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "GE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_ge_float +func @fold_compare_true_ge_float() -> tensor { + %0 = mhlo.constant dense<0.> : tensor + %1 = mhlo.constant dense<0.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = "GE"} : (tensor, tensor) -> tensor + return %2 : tensor +} + // CHECK-LABEL: unpack_repack_same_tuple // CHECK-SAME: ([[ARG0:%.*]]: tuple, !mhlo.token, tensor>) func @unpack_repack_same_tuple(%arg0: tuple, !mhlo.token, tensor>) -> tuple, !mhlo.token, tensor> { diff --git a/tensorflow/compiler/mlir/hlo/tests/legalize-to-std.mlir b/tensorflow/compiler/mlir/hlo/tests/legalize-to-std.mlir index abe4e872b73..404be85e05e 100644 --- a/tensorflow/compiler/mlir/hlo/tests/legalize-to-std.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/legalize-to-std.mlir @@ -51,38 +51,38 @@ func @unary_ops_float(%arg0: tensor<4xf32>) -> tensor<4xf32> { return %0 : tensor<4xf32> } -// CHECK-LABEL: func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) { -func @compare_int(%arg0: tensor<4xi32>) -> 
-  // CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg0 : tensor<4xi32>
-  %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
-  // CHECK-NEXT: %1 = cmpi "ne", %arg0, %arg0 : tensor<4xi32>
-  %1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
-  // CHECK-NEXT: %2 = cmpi "slt", %arg0, %arg0 : tensor<4xi32>
-  %2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
-  // CHECK-NEXT: %3 = cmpi "sle", %arg0, %arg0 : tensor<4xi32>
-  %3 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
-  // CHECK-NEXT: %4 = cmpi "sgt", %arg0, %arg0 : tensor<4xi32>
-  %4 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
-  // CHECK-NEXT: %5 = cmpi "sge", %arg0, %arg0 : tensor<4xi32>
-  %5 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+// CHECK-LABEL: func @compare_int
+func @compare_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
+  // CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg1 : tensor<4xi32>
+  %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  // CHECK-NEXT: %1 = cmpi "ne", %arg0, %arg1 : tensor<4xi32>
+  %1 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  // CHECK-NEXT: %2 = cmpi "slt", %arg0, %arg1 : tensor<4xi32>
+  %2 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  // CHECK-NEXT: %3 = cmpi "sle", %arg0, %arg1 : tensor<4xi32>
+  %3 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  // CHECK-NEXT: %4 = cmpi "sgt", %arg0, %arg1 : tensor<4xi32>
+  %4 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  // CHECK-NEXT: %5 = cmpi "sge", %arg0, %arg1 : tensor<4xi32>
+  %5 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
   // CHECK-NEXT: return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
   return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
 }
 
 // CHECK-LABEL: func @compare_float
-func @compare_float(%arg0: tensor<4xf32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
-  // CHECK-NEXT: %0 = cmpf "oeq", %arg0, %arg0 : tensor<4xf32>
-  %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
-  // CHECK-NEXT: %1 = cmpf "une", %arg0, %arg0 : tensor<4xf32>
-  %1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
-  // CHECK-NEXT: %2 = cmpf "olt", %arg0, %arg0 : tensor<4xf32>
-  %2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
-  // CHECK-NEXT: %3 = cmpf "ole", %arg0, %arg0 : tensor<4xf32>
-  %3 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
-  // CHECK-NEXT: %4 = cmpf "ogt", %arg0, %arg0 : tensor<4xf32>
-  %4 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
-  // CHECK-NEXT: %5 = cmpf "oge", %arg0, %arg0 : tensor<4xf32>
-  %5 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+func @compare_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
+  // CHECK-NEXT: %0 = cmpf "oeq", %arg0, %arg1 : tensor<4xf32>
+  %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  // CHECK-NEXT: %1 = cmpf "une", %arg0, %arg1 : tensor<4xf32>
+  %1 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  // CHECK-NEXT: %2 = cmpf "olt", %arg0, %arg1 : tensor<4xf32>
+  %2 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  // CHECK-NEXT: %3 = cmpf "ole", %arg0, %arg1 : tensor<4xf32>
+  %3 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  // CHECK-NEXT: %4 = cmpf "ogt", %arg0, %arg1 : tensor<4xf32>
+  %4 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  // CHECK-NEXT: %5 = cmpf "oge", %arg0, %arg1 : tensor<4xf32>
+  %5 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
   return %0, %1, %2, %3, %4, %5: tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
 }
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
index 9864cffee7c..cc923070077 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir
@@ -479,12 +479,13 @@ func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> te
 }
 
 // CHECK-LABEL: func @equal(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi1> {
-// CHECK: %[[VAL_1:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_0]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
-// CHECK: return %[[VAL_1]] : tensor<2xi1>
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> {
+// CHECK: %[[VAL_2:.*]] = "tf.Equal"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+// CHECK: return %[[VAL_2]] : tensor<2xi1>
 // CHECK: }
-func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
-  %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+func @equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
+  %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
   return %0 : tensor<2xi1>
 }
 
@@ -533,12 +534,13 @@ func @equal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor
 }
 
 // CHECK-LABEL: func @notequal(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi1> {
-// CHECK: %[[VAL_1:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_0]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
-// CHECK: return %[[VAL_1]] : tensor<2xi1>
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> {
+// CHECK: %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
%[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK: return %[[VAL_2]] : tensor<2xi1> // CHECK: } -func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +func @notequal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0 : tensor<2xi1> } @@ -576,12 +578,13 @@ func @notequal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: ten } // CHECK-LABEL: func @greater( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi1> { -// CHECK: %[[VAL_1:.*]] = "tf.Greater"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> -// CHECK: return %[[VAL_1]] : tensor<2xi1> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.Greater"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK: return %[[VAL_2]] : tensor<2xi1> // CHECK: } -func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { - %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +func @greater(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0 : tensor<2xi1> } @@ -597,12 +600,13 @@ func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor< } // CHECK-LABEL: func @greater_equal( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi1> { -// CHECK: %[[VAL_1:.*]] = "tf.GreaterEqual"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> -// CHECK: return %[[VAL_1]] : tensor<2xi1> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.GreaterEqual"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK: return %[[VAL_2]] : tensor<2xi1> // CHECK: } -func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +func @greater_equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0 : tensor<2xi1> } @@ -618,12 +622,13 @@ func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> t } // CHECK-LABEL: func @less( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi1> { -// CHECK: %[[VAL_1:.*]] = "tf.Less"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> -// CHECK: return %[[VAL_1]] : tensor<2xi1> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.Less"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK: return %[[VAL_2]] : tensor<2xi1> // CHECK: } -func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> { - %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +func @less(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, 
%arg1) {comparison_direction = "LT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0 : tensor<2xi1> } @@ -639,12 +644,13 @@ func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2 } // CHECK-LABEL: func @less_equal( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>) -> tensor<2xi1> { -// CHECK: %[[VAL_1:.*]] = "tf.LessEqual"(%[[VAL_0]], %[[VAL_0]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> -// CHECK: return %[[VAL_1]] : tensor<2xi1> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> { +// CHECK: %[[VAL_2:.*]] = "tf.LessEqual"(%[[VAL_0]], %[[VAL_1]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +// CHECK: return %[[VAL_2]] : tensor<2xi1> // CHECK: } -func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +func @less_equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0 : tensor<2xi1> } @@ -1363,20 +1369,21 @@ func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> { } // CHECK-LABEL: func @sign( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { -// CHECK: %[[VAL_1:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_0]]) {incompatible_shape_error = true} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> -// CHECK: %[[VAL_2:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1x2x3x4xf32>} : () -> tensor<1x2x3x4xf32> -// CHECK: %[[VAL_3:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_0]]) {incompatible_shape_error = true} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> -// CHECK: %[[VAL_4:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1x2x3x4xf32>} : () -> tensor<1x2x3x4xf32> -// CHECK: %[[VAL_5:.*]] = "tf.Sign"(%[[VAL_0]]) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> -// CHECK: %[[VAL_6:.*]] = "tf.Select"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]]) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> -// CHECK: %[[VAL_7:.*]] = "tf.Select"(%[[VAL_1]], %[[VAL_2]], %[[VAL_6]]) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> -// CHECK: return %[[VAL_7]] : tensor<1x2x3x4xf32> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { +// CHECK: %[[VAL_2:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> +// CHECK: %[[VAL_3:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1x2x3x4xf32>} : () -> tensor<1x2x3x4xf32> +// CHECK: %[[VAL_4:.*]] = "tf.NotEqual"(%[[VAL_0]], %[[VAL_1]]) {incompatible_shape_error = true} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> +// CHECK: %[[VAL_5:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1x2x3x4xf32>} : () -> tensor<1x2x3x4xf32> +// CHECK: %[[VAL_6:.*]] = "tf.Sign"(%[[VAL_0]]) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> +// CHECK: %[[VAL_7:.*]] = "tf.Select"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> +// CHECK: %[[VAL_8:.*]] = "tf.Select"(%[[VAL_2]], %[[VAL_3]], %[[VAL_7]]) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> +// CHECK: return %[[VAL_8]] : 
tensor<1x2x3x4xf32> // CHECK: } -func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { - %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> +func @sign(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> %1 = mhlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32> - %2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> + %2 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> %3 = mhlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32> %4 = "mhlo.sign"(%arg0) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> %5 = "mhlo.select"(%2, %3, %4) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir index 887fdea5a21..7f37dbb0479 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir @@ -197,9 +197,9 @@ func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> { //===----------------------------------------------------------------------===// // CHECK-LABEL: func @equal -func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} - %0 = "tf.Equal"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +func @equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} + %0 = "tf.Equal"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0: tensor<2xi1> } @@ -265,9 +265,9 @@ func @equal_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1> } // CHECK-LABEL: func @notequal -func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} - %0 = "tf.NotEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +func @notequal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "NE"} + %0 = "tf.NotEqual"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0: tensor<2xi1> } @@ -278,9 +278,9 @@ func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { //===----------------------------------------------------------------------===// // CHECK-LABEL: func @greater -func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { - // CHECK: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} - %0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> +func @greater(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + // CHECK: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} + %0 = "tf.Greater"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0: tensor<2xi1> } @@ -310,29 +310,29 @@ func @greater_dynamic(%arg0: tensor, %arg1: tensor) -> tensor) -> tensor<*xi1> { +func @greater_uranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi1> { // CHECK: "tf.Greater" - %0 = "tf.Greater"(%arg0, %arg0) : 
+  %0 = "tf.Greater"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1>
   return %0: tensor<*xi1>
 }
 
 // CHECK-LABEL: func @greater_equal
-func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
-  // CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"}
-  %0 = "tf.GreaterEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+func @greater_equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
+  // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GE"}
+  %0 = "tf.GreaterEqual"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
   return %0: tensor<2xi1>
 }
 
 // CHECK-LABEL: func @less
-func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> {
-  // CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"}
-  %0 = "tf.Less"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+func @less(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
+  // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LT"}
+  %0 = "tf.Less"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
   return %0: tensor<2xi1>
 }
 
 // CHECK-LABEL: func @less_equal
-func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
-  // CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"}
-  %0 = "tf.LessEqual"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+func @less_equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
+  // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE"}
+  %0 = "tf.LessEqual"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
   return %0: tensor<2xi1>
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
index 8c8d99940de..a21a78cf7f4 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
@@ -125,9 +125,9 @@ func @constant(%arg0: tensor<2xf32>) -> tensor<2xf32> {
 }
 
 // CHECK-LABEL: func @greater
-func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
-  // CHECK-NEXT: "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"}
-  %0 = "tf.Greater"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
+func @greater(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
+  // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"}
+  %0 = "tf.Greater"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
   return %0: tensor<2xi1>
 }
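A note on the test updates in the lowering suites above (an aside, not part of the patch): the compares now take two distinct arguments instead of comparing %arg0 against itself, presumably because a self-comparison has become foldable and could be rewritten to a constant before the lowering pattern under test ever fires. An illustrative sketch of what the folder would do to the old form:

  // Old-style test body: identical operands, now foldable...
  %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
  // ...which the fold hook can replace with:
  %0 = mhlo.constant dense<true> : tensor<2xi1>

Using %arg0 and %arg1 keeps the comparison op live, so the CHECK lines continue to exercise the conversion patterns rather than the folder.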