From b5359acda5115674d2a1e81943e9c7d2c2581a99 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Mon, 26 Oct 2020 12:57:48 -0700 Subject: [PATCH] Add compare_type optional attribute to CompareOp in HLO dialects If unspecified, `compare_type` is FLOAT for float element types, SIGNED for signed element types and UNSIGNED for unsigned element types. compare_type can be TOTALORDER for float element types. - Added import and export support the attribute. - Restricted legalization from HLO to TF to the default compare types. - Updated existing usage of the CompareOp PiperOrigin-RevId: 339099219 Change-Id: I47c94603c497dc225373f638974d20eb28f56b76 --- .../mlir-hlo/Dialect/mhlo/IR/chlo_ops.td | 11 ++-- .../mlir-hlo/Dialect/mhlo/IR/hlo_ops.td | 13 +++-- .../mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td | 21 +++++++- .../mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td | 3 +- .../mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc | 5 +- .../mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc | 6 ++- .../mhlo/transforms/chlo_legalize_to_hlo.cc | 6 +-- .../chlo_legalize_to_hlo_patterns.td | 6 ++- .../tests/mhlo_infer_shape_type_methods.mlir | 2 +- .../mlir/tensorflow/tests/legalize_hlo.mlir | 14 +++++ .../transforms/legalize_hlo_patterns.td | 22 ++++++-- tensorflow/compiler/mlir/xla/BUILD | 1 + .../mlir/xla/hlo_function_importer.cc | 19 +++++-- .../compiler/mlir/xla/hlo_function_importer.h | 6 ++- .../compiler/mlir/xla/ir/mlir_hlo_builder.cc | 6 ++- .../compiler/mlir/xla/ir/mlir_hlo_builder.h | 3 +- .../compiler/mlir/xla/mlir_hlo_to_hlo.cc | 22 ++++++++ .../xla/tests/legalize-tf-with-tf2xla.mlir | 2 +- .../mlir/xla/tests/translate/export.mlir | 8 +-- .../mlir/xla/tests/translate/import.hlotxt | 6 +-- .../xla/transforms/legalize_tf_patterns.td | 51 ++++++++++++------- tensorflow/compiler/xla/client/xla_builder.h | 5 +- 22 files changed, 178 insertions(+), 60 deletions(-) diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index 13d5f02368b..a6f530876d1 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -427,7 +427,10 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp< string summary = "Compare operator (with optional broadcasting)"; string description = [{ - Compares `lhs` and `rhs` elementwise according to `comparison_direction`. + Compares `lhs` and `rhs` elementwise according to `comparison_direction` + and `compare_type`. If unspecified, `compare_type` is FLOAT for float element + types, SIGNED for signed element types and UNSIGNED for unsigned element + types. See https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations. @@ -437,13 +440,15 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp< HLO_Tensor:$lhs, HLO_Tensor:$rhs, OptionalAttr:$broadcast_dimensions, - HLO_ComparisonDirectionAttr:$comparison_direction + HLO_ComparisonDirectionAttr:$comparison_direction, + OptionalAttr:$compare_type ); let results = (outs HLO_PredTensor); let builders = [OpBuilder< "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " - "DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction" + "DenseIntElementsAttr broadcast_dimensions, " + "StringAttr comparison_direction, StringAttr compare_type = {}" >]; } diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 33c13aaca29..0b52ed2075e 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -680,16 +680,19 @@ def HLO_CompareOp: HLO_Op<"compare", [NoSideEffect, SameTypeOperands, let arguments = (ins HLO_Tensor:$lhs, HLO_Tensor:$rhs, - HLO_ComparisonDirectionAttr:$comparison_direction + HLO_ComparisonDirectionAttr:$comparison_direction, + OptionalAttr:$compare_type ); let results = (outs HLO_PredTensor); let hasFolder = 1; - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " - "StringAttr comparison_direction" - >]; + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " + "StringAttr comparison_direction, StringAttr compare_type = {}">, + ]; + + let hasCustomHLOConverter = 1; } //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td index da8c921a47b..c30adde80a1 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td @@ -749,11 +749,30 @@ def HLO_ComparisonDirectionAttr : StrEnumAttr<"ComparisonDirection", HLO_COMPARISON_DIRECTION_LT ]>; +def HLO_DEFAULT_COMPARISON_TYPE : NativeCodeCall<"StringAttr()">; +def HLO_COMPARISON_TYPE_FLOAT : StrEnumAttrCase<"FLOAT">; +def HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER : StrEnumAttrCase<"TOTALORDER">; +def HLO_COMPARISON_TYPE_SIGNED : StrEnumAttrCase<"SIGNED">; +def HLO_COMPARISON_TYPE_UNSIGNED : StrEnumAttrCase<"UNSIGNED">; + +def HLO_ComparisonTypeAttr : StrEnumAttr<"ComparisonType", + "Which comparison type to use.", + [ + HLO_COMPARISON_TYPE_FLOAT, + HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER, + HLO_COMPARISON_TYPE_SIGNED, + HLO_COMPARISON_TYPE_UNSIGNED + ]>; + + class BASE_HLO_CompareOp { string summary = "Comparison operator"; string description = [{ - Compares `lhs` and `rhs` elementwise according to `comparison_direction`. + Compares `lhs` and `rhs` elementwise according to `comparison_direction` + and `compare_type`. If unspecified, `compare_type` is FLOAT for float element + types, SIGNED for signed element types and UNSIGNED for unsigned element + types. See https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations. diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index 28e51351c7e..3873724826e 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -284,7 +284,8 @@ def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp { Arg:$rhs, Arg:$out, OptionalAttr:$broadcast_dimensions, - HLO_ComparisonDirectionAttr:$comparison_direction + HLO_ComparisonDirectionAttr:$comparison_direction, + OptionalAttr:$compare_type ); } diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc index 99b22a75a14..7ea42c6f806 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc @@ -190,11 +190,12 @@ LogicalResult BroadcastComplexOp::reifyReturnTypeShapes( void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result, Value lhs, Value rhs, DenseIntElementsAttr broadcast_dimensions, - StringAttr comparison_direction) { + StringAttr comparison_direction, + StringAttr compare_type) { auto new_type = GetBroadcastType(lhs.getType(), rhs.getType(), builder.getI1Type(), broadcast_dimensions); build(builder, result, new_type, lhs, rhs, broadcast_dimensions, - comparison_direction); + comparison_direction, compare_type); } LogicalResult BroadcastCompareOp::inferReturnTypeComponents( diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc index 17255224c9a..fe5198e903f 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -2611,10 +2611,12 @@ void UnaryEinsumOp::getCanonicalizationPatterns( //===----------------------------------------------------------------------===// void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs, - Value rhs, StringAttr comparison_direction) { + Value rhs, StringAttr comparison_direction, + StringAttr compare_type) { auto new_type = UpdateResultElementType(&builder, lhs.getType(), builder.getI1Type()); - build(builder, result, new_type, lhs, rhs, comparison_direction); + build(builder, result, new_type, lhs, rhs, comparison_direction, + compare_type); } LogicalResult CompareOp::inferReturnTypeComponents( diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc index 42d6d70b524..f261d6adb23 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -505,9 +505,9 @@ struct HloCompareAdaptor { static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type, Value broadcasted_lhs, Value broadcasted_rhs, OpBuilder &builder) { - return builder.create(from_op.getLoc(), result_type, - broadcasted_lhs, broadcasted_rhs, - from_op.comparison_direction()); + return builder.create( + from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs, + from_op.comparison_direction(), from_op.compare_typeAttr()); } }; diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td index a48abb6190c..2ad07eed773 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td @@ -31,7 +31,8 @@ def : Pat<(HLOClient_AcosOp $input), (HLO_CompareOp $input, (HLO_ConstantLike<"-1"> $input), - HLO_COMPARISON_DIRECTION_NE + HLO_COMPARISON_DIRECTION_NE, + (HLO_DEFAULT_COMPARISON_TYPE) ), (HLO_MulOp (HLO_ConstantLike<"2"> $input), @@ -67,7 +68,8 @@ def : Pat<(HLOClient_SinhOp $input), (HLO_CompareOp (HLO_AbsOp $input), (HLO_ConstantLike<"1"> $input), - HLO_COMPARISON_DIRECTION_LT + HLO_COMPARISON_DIRECTION_LT, + (HLO_DEFAULT_COMPARISON_TYPE) ), (HLO_DivOp (HLO_SubOp diff --git a/tensorflow/compiler/mlir/hlo/tests/mhlo_infer_shape_type_methods.mlir b/tensorflow/compiler/mlir/hlo/tests/mhlo_infer_shape_type_methods.mlir index d626f520824..8829e4c7328 100644 --- a/tensorflow/compiler/mlir/hlo/tests/mhlo_infer_shape_type_methods.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/mhlo_infer_shape_type_methods.mlir @@ -28,7 +28,7 @@ func @compare(%a : tensor<2x?xf32>, %b : tensor<2x?xf32>) -> tensor<2xi64> { // CHECK: %[[DIM:.*]] = index_cast %[[DIM_AS_INDEX]] : index to i64 // CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[C2]], %[[DIM]] : tensor<2xi64> // CHECK: return %[[SHAPE]] : tensor<2xi64> - %0 = "mhlo.compare"(%a, %b) { comparison_direction = "NE" } + %0 = "mhlo.compare"(%a, %b) {comparison_direction = "NE"} : (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xi1> %1 = "mhlo_test.reify_return_type_shapes"(%0) : (tensor<2x?xi1>) -> tensor<2xi64> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index cc923070077..bc2229d944b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -533,6 +533,13 @@ func @equal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor return %0 : tensor } +// CHECK-LABEL: func @equal_unsupported_compare_type +func @equal_unsupported_compare_type(%arg0: tensor<1xf32>, %arg1: tensor<1x2xf32>) -> tensor<1x2xi1> { + // CHECK: chlo.broadcast_compare + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, compare_type = "TOTALORDER", comparison_direction = "EQ"} : (tensor<1xf32>, tensor<1x2xf32>) -> tensor<1x2xi1> + return %0 : tensor<1x2xi1> +} + // CHECK-LABEL: func @notequal( // CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> { @@ -599,6 +606,13 @@ func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor< return %0 : tensor<1x2xi1> } +// CHECK-LABEL: func @greater_unsupported_compare_type +func @greater_unsupported_compare_type(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xi1> { + // CHECK: mhlo.compare + %0 = "mhlo.compare"(%arg0, %arg1) {compare_type = "TOTALORDER", comparison_direction = "GT"} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> + return %0 : tensor<2xi1> +} + // CHECK-LABEL: func @greater_equal( // CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td index 17d7f00369d..8035d58857d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo_patterns.td @@ -157,6 +157,16 @@ def : Pat<(HLO_SelectOp $cond, $t, $e), (TF_SelectOp $cond, $t, $e)>; def : Pat<(HLO_ConcatenateOp $inputs, $dim), (TF_ConcatV2Op $inputs, (TF_ConstOp $dim))>; +class HasCompareType : + CPred<"$_self.cast().getValue() == \"" # value # "\"">; + +// Attribute value should be such that it matches the comparison used by +// TensorFlow, if the attribute is present. +def IsTFCompareType : AttrConstraint< + Or<[CPred<"!$_self">, HasCompareType<"FLOAT">, HasCompareType<"SIGNED">, + HasCompareType<"UNSIGNED">]>, + "compare type supported by TensorFlow">; + //===----------------------------------------------------------------------===// // Compare op patterns. // Note that these are legalized from chlo.broadcast_* ops, since those are @@ -166,18 +176,22 @@ def : Pat<(HLO_ConcatenateOp $inputs, $dim), foreach p = [[TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ], [TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE]] in { - def : Pat<(HLOClient_BroadcastCompareOp $l, $r, $_, p[1]), (p[0] $l, $r, ConstBoolAttrTrue), + def : Pat<(HLOClient_BroadcastCompareOp $l, $r, $_, p[1], IsTFCompareType:$type), + (p[0] $l, $r, ConstBoolAttrTrue), [(AreBroadcastCompatible $l, $r)]>; - def : Pat<(HLO_CompareOp $l, $r, p[1]), (p[0] $l, $r, ConstBoolAttrTrue)>; + def : Pat<(HLO_CompareOp $l, $r, p[1], IsTFCompareType:$type), + (p[0] $l, $r, ConstBoolAttrTrue)>; } foreach pair = [[TF_GreaterEqualOp, HLO_COMPARISON_DIRECTION_GE], [TF_GreaterOp, HLO_COMPARISON_DIRECTION_GT], [TF_LessEqualOp, HLO_COMPARISON_DIRECTION_LE], [TF_LessOp, HLO_COMPARISON_DIRECTION_LT]] in { - def : Pat<(HLOClient_BroadcastCompareOp $l, $r, $_, pair[1]), (pair[0] $l, $r), + def : Pat<(HLOClient_BroadcastCompareOp $l, $r, $_, pair[1], IsTFCompareType:$type), + (pair[0] $l, $r), [(AreBroadcastCompatible $l, $r)]>; - def : Pat<(HLO_CompareOp $l, $r, pair[1]), (pair[0] $l, $r)>; + def : Pat<(HLO_CompareOp $l, $r, pair[1], IsTFCompareType:$type), + (pair[0] $l, $r)>; } def ConvertDotOp : NativeCodeCall<"ConvertDotOp($_builder, " diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 32dd1e202ee..1e664ff40f4 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -300,6 +300,7 @@ cc_library( ":hlo_utils", "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index d682b6cb44b..81facea3857 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/xla/attribute_importer.h" #include "tensorflow/compiler/mlir/xla/hlo_utils.h" +#include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -305,7 +306,12 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( MakeAndReturn(CustomCallOp); } case HloOpcode::kCompare: { - attributes.push_back(ConvertComparisonDirection(instruction)); + auto compare = Cast(instruction); + attributes.push_back(ConvertComparisonDirection(compare->direction())); + auto default_type = Comparison::DefaultComparisonType( + compare->operand(0)->shape().element_type()); + if (compare->type() != default_type) + attributes.push_back(ConvertComparisonType(compare->type())); MakeAndReturn(CompareOp); } case HloOpcode::kCholesky: { @@ -855,11 +861,16 @@ StatusOr HloFunctionImporter::GetMlirValue(HloInstruction* instruction) { } mlir::NamedAttribute HloFunctionImporter::ConvertComparisonDirection( - HloInstruction* instruction) { + ComparisonDirection direction) { return builder_->getNamedAttr( "comparison_direction", - builder_->getStringAttr( - ComparisonDirectionToString(instruction->comparison_direction()))); + builder_->getStringAttr(ComparisonDirectionToString(direction))); +} + +mlir::NamedAttribute HloFunctionImporter::ConvertComparisonType( + Comparison::Type type) { + return builder_->getNamedAttr( + "compare_type", builder_->getStringAttr(ComparisonTypeToString(type))); } mlir::DenseIntElementsAttr HloFunctionImporter::ConvertDimensions( diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h index 4a75b079d76..ee0372ad2b2 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -118,7 +119,10 @@ class HloFunctionImporter { // Converts an XLA ComparisonDirection to the corresponding MLIR attribute. mlir::NamedAttribute ConvertComparisonDirection( - xla::HloInstruction* instruction); + ComparisonDirection direction); + + // Converts an XLA Comparison::Type to the corresponding MLIR attribute. + mlir::NamedAttribute ConvertComparisonType(Comparison::Type type); // Converts the dimensions of an HLO instruction into an MLIR attribute. mlir::DenseIntElementsAttr ConvertDimensions( diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc index daea2d9b8f6..4d045179f21 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc @@ -385,12 +385,14 @@ StatusOr MlirHloBuilder::AddInstruction( StatusOr MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, - ComparisonDirection direction) { + ComparisonDirection direction, + Comparison::Type type) { TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( shape, builder_)); auto op = builder_.create( loc_, ty, GetValue(lhs), GetValue(rhs), - builder_.getStringAttr(ComparisonDirectionToString(direction))); + builder_.getStringAttr(ComparisonDirectionToString(direction)), + builder_.getStringAttr(ComparisonTypeToString(type))); return MakeXlaOp(op.getResult()); } diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h index 59b4bc7b1e0..cbdc6f48fdc 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h @@ -205,7 +205,8 @@ class MlirHloBuilder : public XlaBuilder { absl::Span operands) override; StatusOr Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, - ComparisonDirection direction) override; + ComparisonDirection direction, + Comparison::Type type) override; XlaOp BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape, XlaOp lhs, XlaOp rhs) override; diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 875f521f520..0e904c153bb 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -731,6 +731,28 @@ LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) { return success(); } +// Specialize CompareOp export to set broadcast_dimensions argument. +mlir::LogicalResult ExportXlaOp(mlir::mhlo::CompareOp op, + OpLoweringContext ctx) { + auto& value_map = *ctx.values; + xla::XlaOp lhs, rhs; + if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure(); + if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure(); + auto dir = Convert_comparison_direction(op.comparison_direction()); + auto type_attr = op.compare_typeAttr(); + + xla::XlaOp xla_result; + if (type_attr) { + auto type = + xla::StringToComparisonType(type_attr.getValue().str()).ValueOrDie(); + xla_result = xla::Compare(lhs, rhs, /*broadcast_dimensions=*/{}, dir, type); + } else { + xla_result = xla::Compare(lhs, rhs, dir); + } + value_map[op] = xla_result; + return mlir::success(); +} + LogicalResult ExportXlaOp(ConstOp op, OpLoweringContext ctx) { return failure(); } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir index a21a78cf7f4..ee32407c24f 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir @@ -126,7 +126,7 @@ func @constant(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-LABEL: func @greater func @greater(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { - // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {compare_type = "SIGNED", comparison_direction = "GT"} %0 = "tf.Greater"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> return %0: tensor<2xi1> } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index b857b2963f9..9a9ac78cc70 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -750,7 +750,7 @@ func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) -> te %0 = mhlo.constant dense<0.000000e+00> : tensor %1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ( { ^bb0(%arg3: tensor, %arg4: tensor): // no predecessors - %2 = "mhlo.compare"(%arg3, %arg4) {comparison_direction = "GE"} : (tensor, tensor) -> tensor + %2 = "mhlo.compare"(%arg3, %arg4) {compare_type = "TOTALORDER", comparison_direction = "GE"} : (tensor, tensor) -> tensor "mhlo.return"(%2) : (tensor) -> () }, { ^bb0(%arg3: tensor, %arg4: tensor): // no predecessors @@ -764,7 +764,7 @@ func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) -> te } // CHECK: %[[SELECT_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> pred[] { -// CHECK: ROOT %[[RESULT:.*]] = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GE +// CHECK: ROOT %[[RESULT:.*]] = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GE, type=TOTALORDER // CHECK: %[[SCATTER_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[] { // CHECK: ROOT %[[RESULT:.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]]) @@ -950,7 +950,7 @@ func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): - %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + %7 = "mhlo.compare"(%arg0, %arg1) {compare_type = "FLOAT", comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return @@ -969,7 +969,7 @@ func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { func @main(%input0: tensor<16x16xf32>) { %0 = "mhlo.sort"(%input0) ( { ^bb0(%arg0: tensor, %arg1: tensor): - %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + %7 = "mhlo.compare"(%arg0, %arg1) {compare_type = "FLOAT", comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>) -> (tensor<16x16xf32>) return diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index cce49b16c6c..b1a54af2c6e 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -144,10 +144,10 @@ add { %Arg_2.3 = f32[3] parameter(2) // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> - %compare.4 = pred[3] compare(Arg_0.1, Arg_1.2), direction=EQ + %compare.4 = pred[3] compare(Arg_0.1, Arg_1.2), direction=EQ, type=FLOAT - // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> - %compare.5 = pred[3] compare(Arg_0.1, Arg_1.2), direction=LE + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {compare_type = "TOTALORDER", comparison_direction = "LE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> + %compare.5 = pred[3] compare(Arg_0.1, Arg_1.2), direction=LE, type=TOTALORDER // Requires broadcast of compatible tensors. // CHECK-NEXT: "mhlo.compare"(%arg0, %arg2) {comparison_direction = "GT"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index d8baef14e62..11af809ffb7 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -143,10 +143,13 @@ def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r), (HLO_SelectOp (HLOClient_BroadcastCompareOp (HLOClient_BroadcastCompareOp $l, (HLO_ConstOp (GetScalarOfType<0> $l)), - (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT), + (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT, + (HLO_DEFAULT_COMPARISON_TYPE)), (HLOClient_BroadcastCompareOp $r, (HLO_ConstOp (GetScalarOfType<0> $r)), - (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT), - (BinBroadcastDimensions $l, $r), HLO_COMPARISON_DIRECTION_EQ), + (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT, + (HLO_DEFAULT_COMPARISON_TYPE)), + (BinBroadcastDimensions $l, $r), HLO_COMPARISON_DIRECTION_EQ, + (HLO_DEFAULT_COMPARISON_TYPE)), (HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r)), (HLOClient_BroadcastDivOp (HLO_NegOp:$neg (HLOClient_BroadcastAddOp (HLO_AbsOp $l), @@ -170,14 +173,18 @@ def : Pat<(TF_FloorModOp AnyRankedTensor:$l, AnyRankedTensor:$r), (HLOClient_BroadcastCompareOp (HLOClient_BroadcastRemOp:$rem $l, $r, (BinBroadcastDimensions $l, $r)), (HLO_ConstOp:$l_zeros (GetScalarOfType<0> $l)), - (BinBroadcastDimensions $l, $rem), HLO_COMPARISON_DIRECTION_NE), + (BinBroadcastDimensions $l, $rem), HLO_COMPARISON_DIRECTION_NE, + (HLO_DEFAULT_COMPARISON_TYPE)), (HLOClient_BroadcastCompareOp (HLOClient_BroadcastCompareOp:$r_cmp $r, (HLO_ConstOp:$r_zeros (GetScalarOfType<0> $r)), - (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT), + (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT, + (HLO_DEFAULT_COMPARISON_TYPE)), (HLOClient_BroadcastCompareOp:$rem_cmp $rem, $r_zeros, - (BinBroadcastDimensions $rem, $r_zeros), HLO_COMPARISON_DIRECTION_LT), - (BinBroadcastDimensions $r_cmp, $rem_cmp), HLO_COMPARISON_DIRECTION_NE), + (BinBroadcastDimensions $rem, $r_zeros), HLO_COMPARISON_DIRECTION_LT, + (HLO_DEFAULT_COMPARISON_TYPE)), + (BinBroadcastDimensions $r_cmp, $rem_cmp), HLO_COMPARISON_DIRECTION_NE, + (HLO_DEFAULT_COMPARISON_TYPE)), (NullDenseIntElementsAttr)), (HLOClient_BroadcastAddOp $r, $rem, (BinBroadcastDimensions $r, $rem)), $rem)>; @@ -204,7 +211,8 @@ foreach fromToBinPair = [[TF_LogicalAndOp, HLOClient_BroadcastAndOp], class DirectComparePat : Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r), (HLOClient_BroadcastCompareOp - $l, $r, (BinBroadcastDimensions $l, $r), direction)>; + $l, $r, (BinBroadcastDimensions $l, $r), direction, + (HLO_DEFAULT_COMPARISON_TYPE))>; def : DirectComparePat; def : DirectComparePat; @@ -215,7 +223,8 @@ class EqualityPat : Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r, TrueBoolAttr:$incompatible_shape_error), (HLOClient_BroadcastCompareOp - $l, $r, (BinBroadcastDimensions $l, $r), direction), + $l, $r, (BinBroadcastDimensions $l, $r), direction, + (HLO_DEFAULT_COMPARISON_TYPE)), [(AreBroadcastCompatible $l, $r)]>; def : EqualityPat; @@ -404,14 +413,15 @@ def : Pattern<(TF_MatrixBandPartOp:$op AnyStaticShapeTensor:$input, $num_lower, (HLO_SelectOp:$num_lower_or_m (HLO_CompareOp $num_lower, (HLO_ConstOp:$zero (ConstantSplat<"0"> $num_lower)), - HLO_COMPARISON_DIRECTION_LT + HLO_COMPARISON_DIRECTION_LT, (HLO_DEFAULT_COMPARISON_TYPE) ), $m_dim, $num_lower ), (HLO_SelectOp:$num_upper_or_n (HLO_CompareOp - $num_upper, $zero, HLO_COMPARISON_DIRECTION_LT + $num_upper, $zero, HLO_COMPARISON_DIRECTION_LT, + (HLO_DEFAULT_COMPARISON_TYPE) ), $n_dim, $num_upper @@ -424,10 +434,12 @@ def : Pattern<(TF_MatrixBandPartOp:$op AnyStaticShapeTensor:$input, $num_lower, (createIotaOp<"1"> $op, $input, $num_lower), (createIotaOp<"0"> $op, $input, $num_lower) ), - (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE + (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE, + (HLO_DEFAULT_COMPARISON_TYPE) ), (HLOClient_BroadcastCompareOp $offset, $num_upper_or_n, - (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE + (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE, + (HLO_DEFAULT_COMPARISON_TYPE) ) ), $input, @@ -452,7 +464,7 @@ def : Pat<(TF_EluOp AnyRankedTensor:$features), $features, (HLO_ConstOp:$zero (GetScalarOfType<0> $features)), (BinBroadcastDimensions $zero, $features), - HLO_COMPARISON_DIRECTION_GT), + HLO_COMPARISON_DIRECTION_GT, (HLO_DEFAULT_COMPARISON_TYPE)), $features, (HLO_Expm1Op $features))>; @@ -462,7 +474,7 @@ def : Pat<(TF_EluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$featur $features, (HLO_ConstOp:$zero (GetScalarOfType<0> $features)), (BinBroadcastDimensions $zero, $features), - HLO_COMPARISON_DIRECTION_GT), + HLO_COMPARISON_DIRECTION_GT, (HLO_DEFAULT_COMPARISON_TYPE)), $gradients, (HLO_MulOp $gradients, @@ -507,7 +519,8 @@ def : Pat<(TF_ReluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$featu (HLO_SelectOp (HLOClient_BroadcastCompareOp $features, (HLO_ConstOp (GetScalarOfType<0> $features)), - (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_GT), + (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_GT, + (HLO_DEFAULT_COMPARISON_TYPE)), $gradients, (HLO_ConstOp (ConstantSplat<"0"> $gradients)))>; //===----------------------------------------------------------------------===// @@ -679,7 +692,8 @@ def : Pattern<(TF_SoftplusOp AnyTensor:$features), $features, (HLO_NegOp $threshold), (NullDenseIntElementsAttr), - HLO_COMPARISON_DIRECTION_GT + HLO_COMPARISON_DIRECTION_GT, + (HLO_DEFAULT_COMPARISON_TYPE) ), $features, (HLO_SelectOp @@ -687,7 +701,8 @@ def : Pattern<(TF_SoftplusOp AnyTensor:$features), $features, $threshold, (NullDenseIntElementsAttr), - HLO_COMPARISON_DIRECTION_LT + HLO_COMPARISON_DIRECTION_LT, + (HLO_DEFAULT_COMPARISON_TYPE) ), $features_exp, (HLO_Log1pOp $features_exp) diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 05efc038082..cd4252f4f55 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -852,9 +852,10 @@ class XlaBuilder { absl::optional direction = absl::nullopt, absl::optional type = absl::nullopt); + StatusOr Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, + ComparisonDirection direction); + // Internal helper method for binary op compare without broadcast dimensions. - virtual StatusOr Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, - ComparisonDirection direction); virtual StatusOr Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, ComparisonDirection direction, Comparison::Type type);