Add compare_type optional attribute to CompareOp in HLO dialects

If unspecified, `compare_type` is FLOAT for float element types, SIGNED for signed element types and UNSIGNED for unsigned element types. compare_type can be TOTALORDER for float element types.

- Added import and export support the attribute.
- Restricted legalization from HLO to TF to the default compare types.
- Updated existing usage of the CompareOp

PiperOrigin-RevId: 339099219
Change-Id: I47c94603c497dc225373f638974d20eb28f56b76
This commit is contained in:
Smit Hinsu 2020-10-26 12:57:48 -07:00 committed by TensorFlower Gardener
parent f9d8d8a02f
commit b5359acda5
22 changed files with 178 additions and 60 deletions

View File

@ -427,7 +427,10 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
string summary = "Compare operator (with optional broadcasting)";
string description = [{
Compares `lhs` and `rhs` elementwise according to `comparison_direction`.
Compares `lhs` and `rhs` elementwise according to `comparison_direction`
and `compare_type`. If unspecified, `compare_type` is FLOAT for float element
types, SIGNED for signed element types and UNSIGNED for unsigned element
types.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.
@ -437,13 +440,15 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
HLO_Tensor:$lhs,
HLO_Tensor:$rhs,
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
HLO_ComparisonDirectionAttr:$comparison_direction
HLO_ComparisonDirectionAttr:$comparison_direction,
OptionalAttr<HLO_ComparisonTypeAttr>:$compare_type
);
let results = (outs HLO_PredTensor);
let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
"DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction"
"DenseIntElementsAttr broadcast_dimensions, "
"StringAttr comparison_direction, StringAttr compare_type = {}"
>];
}

View File

@ -680,16 +680,19 @@ def HLO_CompareOp: HLO_Op<"compare", [NoSideEffect, SameTypeOperands,
let arguments = (ins
HLO_Tensor:$lhs,
HLO_Tensor:$rhs,
HLO_ComparisonDirectionAttr:$comparison_direction
HLO_ComparisonDirectionAttr:$comparison_direction,
OptionalAttr<HLO_ComparisonTypeAttr>:$compare_type
);
let results = (outs HLO_PredTensor);
let hasFolder = 1;
let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
"StringAttr comparison_direction"
>];
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
"StringAttr comparison_direction, StringAttr compare_type = {}">,
];
let hasCustomHLOConverter = 1;
}
//===----------------------------------------------------------------------===//

View File

@ -749,11 +749,30 @@ def HLO_ComparisonDirectionAttr : StrEnumAttr<"ComparisonDirection",
HLO_COMPARISON_DIRECTION_LT
]>;
def HLO_DEFAULT_COMPARISON_TYPE : NativeCodeCall<"StringAttr()">;
def HLO_COMPARISON_TYPE_FLOAT : StrEnumAttrCase<"FLOAT">;
def HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER : StrEnumAttrCase<"TOTALORDER">;
def HLO_COMPARISON_TYPE_SIGNED : StrEnumAttrCase<"SIGNED">;
def HLO_COMPARISON_TYPE_UNSIGNED : StrEnumAttrCase<"UNSIGNED">;
def HLO_ComparisonTypeAttr : StrEnumAttr<"ComparisonType",
"Which comparison type to use.",
[
HLO_COMPARISON_TYPE_FLOAT,
HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER,
HLO_COMPARISON_TYPE_SIGNED,
HLO_COMPARISON_TYPE_UNSIGNED
]>;
class BASE_HLO_CompareOp {
string summary = "Comparison operator";
string description = [{
Compares `lhs` and `rhs` elementwise according to `comparison_direction`.
Compares `lhs` and `rhs` elementwise according to `comparison_direction`
and `compare_type`. If unspecified, `compare_type` is FLOAT for float element
types, SIGNED for signed element types and UNSIGNED for unsigned element
types.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.

View File

@ -284,7 +284,8 @@ def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp {
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_PredBuffer, "", [MemWrite]>:$out,
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
HLO_ComparisonDirectionAttr:$comparison_direction
HLO_ComparisonDirectionAttr:$comparison_direction,
OptionalAttr<HLO_ComparisonTypeAttr>:$compare_type
);
}

View File

@ -190,11 +190,12 @@ LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result,
Value lhs, Value rhs,
DenseIntElementsAttr broadcast_dimensions,
StringAttr comparison_direction) {
StringAttr comparison_direction,
StringAttr compare_type) {
auto new_type = GetBroadcastType(lhs.getType(), rhs.getType(),
builder.getI1Type(), broadcast_dimensions);
build(builder, result, new_type, lhs, rhs, broadcast_dimensions,
comparison_direction);
comparison_direction, compare_type);
}
LogicalResult BroadcastCompareOp::inferReturnTypeComponents(

View File

@ -2611,10 +2611,12 @@ void UnaryEinsumOp::getCanonicalizationPatterns(
//===----------------------------------------------------------------------===//
void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
Value rhs, StringAttr comparison_direction) {
Value rhs, StringAttr comparison_direction,
StringAttr compare_type) {
auto new_type =
UpdateResultElementType(&builder, lhs.getType(), builder.getI1Type());
build(builder, result, new_type, lhs, rhs, comparison_direction);
build(builder, result, new_type, lhs, rhs, comparison_direction,
compare_type);
}
LogicalResult CompareOp::inferReturnTypeComponents(

View File

@ -505,9 +505,9 @@ struct HloCompareAdaptor {
static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
Value broadcasted_lhs, Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<mhlo::CompareOp>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs,
from_op.comparison_direction());
return builder.create<mhlo::CompareOp>(
from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs,
from_op.comparison_direction(), from_op.compare_typeAttr());
}
};

View File

@ -31,7 +31,8 @@ def : Pat<(HLOClient_AcosOp $input),
(HLO_CompareOp
$input,
(HLO_ConstantLike<"-1"> $input),
HLO_COMPARISON_DIRECTION_NE
HLO_COMPARISON_DIRECTION_NE,
(HLO_DEFAULT_COMPARISON_TYPE)
),
(HLO_MulOp
(HLO_ConstantLike<"2"> $input),
@ -67,7 +68,8 @@ def : Pat<(HLOClient_SinhOp $input),
(HLO_CompareOp
(HLO_AbsOp $input),
(HLO_ConstantLike<"1"> $input),
HLO_COMPARISON_DIRECTION_LT
HLO_COMPARISON_DIRECTION_LT,
(HLO_DEFAULT_COMPARISON_TYPE)
),
(HLO_DivOp
(HLO_SubOp

View File

@ -28,7 +28,7 @@ func @compare(%a : tensor<2x?xf32>, %b : tensor<2x?xf32>) -> tensor<2xi64> {
// CHECK: %[[DIM:.*]] = index_cast %[[DIM_AS_INDEX]] : index to i64
// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[C2]], %[[DIM]] : tensor<2xi64>
// CHECK: return %[[SHAPE]] : tensor<2xi64>
%0 = "mhlo.compare"(%a, %b) { comparison_direction = "NE" }
%0 = "mhlo.compare"(%a, %b) {comparison_direction = "NE"}
: (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xi1>
%1 = "mhlo_test.reify_return_type_shapes"(%0)
: (tensor<2x?xi1>) -> tensor<2xi64>

View File

@ -533,6 +533,13 @@ func @equal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor
return %0 : tensor<?xi1>
}
// CHECK-LABEL: func @equal_unsupported_compare_type
func @equal_unsupported_compare_type(%arg0: tensor<1xf32>, %arg1: tensor<1x2xf32>) -> tensor<1x2xi1> {
// CHECK: chlo.broadcast_compare
%0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, compare_type = "TOTALORDER", comparison_direction = "EQ"} : (tensor<1xf32>, tensor<1x2xf32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
// CHECK-LABEL: func @notequal(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> {
@ -599,6 +606,13 @@ func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<
return %0 : tensor<1x2xi1>
}
// CHECK-LABEL: func @greater_unsupported_compare_type
func @greater_unsupported_compare_type(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xi1> {
// CHECK: mhlo.compare
%0 = "mhlo.compare"(%arg0, %arg1) {compare_type = "TOTALORDER", comparison_direction = "GT"} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1>
return %0 : tensor<2xi1>
}
// CHECK-LABEL: func @greater_equal(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2xi32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<2xi32>) -> tensor<2xi1> {

View File

@ -157,6 +157,16 @@ def : Pat<(HLO_SelectOp $cond, $t, $e), (TF_SelectOp $cond, $t, $e)>;
def : Pat<(HLO_ConcatenateOp $inputs, $dim),
(TF_ConcatV2Op $inputs, (TF_ConstOp $dim))>;
class HasCompareType<string value> :
CPred<"$_self.cast<StringAttr>().getValue() == \"" # value # "\"">;
// Attribute value should be such that it matches the comparison used by
// TensorFlow, if the attribute is present.
def IsTFCompareType : AttrConstraint<
Or<[CPred<"!$_self">, HasCompareType<"FLOAT">, HasCompareType<"SIGNED">,
HasCompareType<"UNSIGNED">]>,
"compare type supported by TensorFlow">;
//===----------------------------------------------------------------------===//
// Compare op patterns.
// Note that these are legalized from chlo.broadcast_* ops, since those are
@ -166,18 +176,22 @@ def : Pat<(HLO_ConcatenateOp $inputs, $dim),
foreach p = [[TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ],
[TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE]] in {
def : Pat<(HLOClient_BroadcastCompareOp $l, $r, $_, p[1]), (p[0] $l, $r, ConstBoolAttrTrue),
def : Pat<(HLOClient_BroadcastCompareOp $l, $r, $_, p[1], IsTFCompareType:$type),
(p[0] $l, $r, ConstBoolAttrTrue),
[(AreBroadcastCompatible $l, $r)]>;
def : Pat<(HLO_CompareOp $l, $r, p[1]), (p[0] $l, $r, ConstBoolAttrTrue)>;
def : Pat<(HLO_CompareOp $l, $r, p[1], IsTFCompareType:$type),
(p[0] $l, $r, ConstBoolAttrTrue)>;
}
foreach pair = [[TF_GreaterEqualOp, HLO_COMPARISON_DIRECTION_GE],
[TF_GreaterOp, HLO_COMPARISON_DIRECTION_GT],
[TF_LessEqualOp, HLO_COMPARISON_DIRECTION_LE],
[TF_LessOp, HLO_COMPARISON_DIRECTION_LT]] in {
def : Pat<(HLOClient_BroadcastCompareOp $l, $r, $_, pair[1]), (pair[0] $l, $r),
def : Pat<(HLOClient_BroadcastCompareOp $l, $r, $_, pair[1], IsTFCompareType:$type),
(pair[0] $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
def : Pat<(HLO_CompareOp $l, $r, pair[1]), (pair[0] $l, $r)>;
def : Pat<(HLO_CompareOp $l, $r, pair[1], IsTFCompareType:$type),
(pair[0] $l, $r)>;
}
def ConvertDotOp : NativeCodeCall<"ConvertDotOp($_builder, "

View File

@ -300,6 +300,7 @@ cc_library(
":hlo_utils",
"//tensorflow/compiler/mlir/hlo",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/xla:comparison_util",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@ -305,7 +306,12 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
MakeAndReturn(CustomCallOp);
}
case HloOpcode::kCompare: {
attributes.push_back(ConvertComparisonDirection(instruction));
auto compare = Cast<HloCompareInstruction>(instruction);
attributes.push_back(ConvertComparisonDirection(compare->direction()));
auto default_type = Comparison::DefaultComparisonType(
compare->operand(0)->shape().element_type());
if (compare->type() != default_type)
attributes.push_back(ConvertComparisonType(compare->type()));
MakeAndReturn(CompareOp);
}
case HloOpcode::kCholesky: {
@ -855,11 +861,16 @@ StatusOr<Value> HloFunctionImporter::GetMlirValue(HloInstruction* instruction) {
}
mlir::NamedAttribute HloFunctionImporter::ConvertComparisonDirection(
HloInstruction* instruction) {
ComparisonDirection direction) {
return builder_->getNamedAttr(
"comparison_direction",
builder_->getStringAttr(
ComparisonDirectionToString(instruction->comparison_direction())));
builder_->getStringAttr(ComparisonDirectionToString(direction)));
}
mlir::NamedAttribute HloFunctionImporter::ConvertComparisonType(
Comparison::Type type) {
return builder_->getNamedAttr(
"compare_type", builder_->getStringAttr(ComparisonTypeToString(type)));
}
mlir::DenseIntElementsAttr HloFunctionImporter::ConvertDimensions(

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@ -118,7 +119,10 @@ class HloFunctionImporter {
// Converts an XLA ComparisonDirection to the corresponding MLIR attribute.
mlir::NamedAttribute ConvertComparisonDirection(
xla::HloInstruction* instruction);
ComparisonDirection direction);
// Converts an XLA Comparison::Type to the corresponding MLIR attribute.
mlir::NamedAttribute ConvertComparisonType(Comparison::Type type);
// Converts the dimensions of an HLO instruction into an MLIR attribute.
mlir::DenseIntElementsAttr ConvertDimensions(

View File

@ -385,12 +385,14 @@ StatusOr<XlaOp> MlirHloBuilder::AddInstruction(
StatusOr<XlaOp> MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs,
XlaOp rhs,
ComparisonDirection direction) {
ComparisonDirection direction,
Comparison::Type type) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::mhlo::CompareOp>(
loc_, ty, GetValue(lhs), GetValue(rhs),
builder_.getStringAttr(ComparisonDirectionToString(direction)));
builder_.getStringAttr(ComparisonDirectionToString(direction)),
builder_.getStringAttr(ComparisonTypeToString(type)));
return MakeXlaOp(op.getResult());
}

View File

@ -205,7 +205,8 @@ class MlirHloBuilder : public XlaBuilder {
absl::Span<const XlaOp> operands) override;
StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
ComparisonDirection direction) override;
ComparisonDirection direction,
Comparison::Type type) override;
XlaOp BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape, XlaOp lhs,
XlaOp rhs) override;

View File

@ -731,6 +731,28 @@ LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) {
return success();
}
// Specialize CompareOp export to set broadcast_dimensions argument.
mlir::LogicalResult ExportXlaOp(mlir::mhlo::CompareOp op,
OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::XlaOp lhs, rhs;
if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure();
auto dir = Convert_comparison_direction(op.comparison_direction());
auto type_attr = op.compare_typeAttr();
xla::XlaOp xla_result;
if (type_attr) {
auto type =
xla::StringToComparisonType(type_attr.getValue().str()).ValueOrDie();
xla_result = xla::Compare(lhs, rhs, /*broadcast_dimensions=*/{}, dir, type);
} else {
xla_result = xla::Compare(lhs, rhs, dir);
}
value_map[op] = xla_result;
return mlir::success();
}
LogicalResult ExportXlaOp(ConstOp op, OpLoweringContext ctx) {
return failure();
}

View File

@ -126,7 +126,7 @@ func @constant(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-LABEL: func @greater
func @greater(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> {
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"}
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {compare_type = "SIGNED", comparison_direction = "GT"}
%0 = "tf.Greater"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0: tensor<2xi1>
}

View File

@ -750,7 +750,7 @@ func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) -> te
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ( {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): // no predecessors
%2 = "mhlo.compare"(%arg3, %arg4) {comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
%2 = "mhlo.compare"(%arg3, %arg4) {compare_type = "TOTALORDER", comparison_direction = "GE"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): // no predecessors
@ -764,7 +764,7 @@ func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) -> te
}
// CHECK: %[[SELECT_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> pred[] {
// CHECK: ROOT %[[RESULT:.*]] = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GE
// CHECK: ROOT %[[RESULT:.*]] = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GE, type=TOTALORDER
// CHECK: %[[SCATTER_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[] {
// CHECK: ROOT %[[RESULT:.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]])
@ -950,7 +950,7 @@ func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
%0:2 = "mhlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
%7 = "mhlo.compare"(%arg0, %arg1) {compare_type = "FLOAT", comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
return
@ -969,7 +969,7 @@ func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
func @main(%input0: tensor<16x16xf32>) {
%0 = "mhlo.sort"(%input0) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
%7 = "mhlo.compare"(%arg0, %arg1) {compare_type = "FLOAT", comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>) -> (tensor<16x16xf32>)
return

View File

@ -144,10 +144,10 @@ add {
%Arg_2.3 = f32[3] parameter(2)
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
%compare.4 = pred[3] compare(Arg_0.1, Arg_1.2), direction=EQ
%compare.4 = pred[3] compare(Arg_0.1, Arg_1.2), direction=EQ, type=FLOAT
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
%compare.5 = pred[3] compare(Arg_0.1, Arg_1.2), direction=LE
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {compare_type = "TOTALORDER", comparison_direction = "LE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
%compare.5 = pred[3] compare(Arg_0.1, Arg_1.2), direction=LE, type=TOTALORDER
// Requires broadcast of compatible tensors.
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg2) {comparison_direction = "GT"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>

View File

@ -143,10 +143,13 @@ def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r),
(HLO_SelectOp
(HLOClient_BroadcastCompareOp
(HLOClient_BroadcastCompareOp $l, (HLO_ConstOp (GetScalarOfType<0> $l)),
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT),
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT,
(HLO_DEFAULT_COMPARISON_TYPE)),
(HLOClient_BroadcastCompareOp $r, (HLO_ConstOp (GetScalarOfType<0> $r)),
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT),
(BinBroadcastDimensions $l, $r), HLO_COMPARISON_DIRECTION_EQ),
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT,
(HLO_DEFAULT_COMPARISON_TYPE)),
(BinBroadcastDimensions $l, $r), HLO_COMPARISON_DIRECTION_EQ,
(HLO_DEFAULT_COMPARISON_TYPE)),
(HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r)),
(HLOClient_BroadcastDivOp
(HLO_NegOp:$neg (HLOClient_BroadcastAddOp (HLO_AbsOp $l),
@ -170,14 +173,18 @@ def : Pat<(TF_FloorModOp AnyRankedTensor:$l, AnyRankedTensor:$r),
(HLOClient_BroadcastCompareOp
(HLOClient_BroadcastRemOp:$rem $l, $r, (BinBroadcastDimensions $l, $r)),
(HLO_ConstOp:$l_zeros (GetScalarOfType<0> $l)),
(BinBroadcastDimensions $l, $rem), HLO_COMPARISON_DIRECTION_NE),
(BinBroadcastDimensions $l, $rem), HLO_COMPARISON_DIRECTION_NE,
(HLO_DEFAULT_COMPARISON_TYPE)),
(HLOClient_BroadcastCompareOp
(HLOClient_BroadcastCompareOp:$r_cmp $r,
(HLO_ConstOp:$r_zeros (GetScalarOfType<0> $r)),
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT),
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT,
(HLO_DEFAULT_COMPARISON_TYPE)),
(HLOClient_BroadcastCompareOp:$rem_cmp $rem, $r_zeros,
(BinBroadcastDimensions $rem, $r_zeros), HLO_COMPARISON_DIRECTION_LT),
(BinBroadcastDimensions $r_cmp, $rem_cmp), HLO_COMPARISON_DIRECTION_NE),
(BinBroadcastDimensions $rem, $r_zeros), HLO_COMPARISON_DIRECTION_LT,
(HLO_DEFAULT_COMPARISON_TYPE)),
(BinBroadcastDimensions $r_cmp, $rem_cmp), HLO_COMPARISON_DIRECTION_NE,
(HLO_DEFAULT_COMPARISON_TYPE)),
(NullDenseIntElementsAttr)),
(HLOClient_BroadcastAddOp $r,
$rem, (BinBroadcastDimensions $r, $rem)), $rem)>;
@ -204,7 +211,8 @@ foreach fromToBinPair = [[TF_LogicalAndOp, HLOClient_BroadcastAndOp],
class DirectComparePat<Op FromOp, StrEnumAttrCase direction>
: Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r),
(HLOClient_BroadcastCompareOp
$l, $r, (BinBroadcastDimensions $l, $r), direction)>;
$l, $r, (BinBroadcastDimensions $l, $r), direction,
(HLO_DEFAULT_COMPARISON_TYPE))>;
def : DirectComparePat<TF_GreaterOp, HLO_COMPARISON_DIRECTION_GT>;
def : DirectComparePat<TF_GreaterEqualOp, HLO_COMPARISON_DIRECTION_GE>;
@ -215,7 +223,8 @@ class EqualityPat<Op FromOp, StrEnumAttrCase direction>
: Pat<(FromOp AnyRankedTensor:$l, AnyRankedTensor:$r,
TrueBoolAttr:$incompatible_shape_error),
(HLOClient_BroadcastCompareOp
$l, $r, (BinBroadcastDimensions $l, $r), direction),
$l, $r, (BinBroadcastDimensions $l, $r), direction,
(HLO_DEFAULT_COMPARISON_TYPE)),
[(AreBroadcastCompatible $l, $r)]>;
def : EqualityPat<TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ>;
@ -404,14 +413,15 @@ def : Pattern<(TF_MatrixBandPartOp:$op AnyStaticShapeTensor:$input, $num_lower,
(HLO_SelectOp:$num_lower_or_m
(HLO_CompareOp
$num_lower, (HLO_ConstOp:$zero (ConstantSplat<"0"> $num_lower)),
HLO_COMPARISON_DIRECTION_LT
HLO_COMPARISON_DIRECTION_LT, (HLO_DEFAULT_COMPARISON_TYPE)
),
$m_dim,
$num_lower
),
(HLO_SelectOp:$num_upper_or_n
(HLO_CompareOp
$num_upper, $zero, HLO_COMPARISON_DIRECTION_LT
$num_upper, $zero, HLO_COMPARISON_DIRECTION_LT,
(HLO_DEFAULT_COMPARISON_TYPE)
),
$n_dim,
$num_upper
@ -424,10 +434,12 @@ def : Pattern<(TF_MatrixBandPartOp:$op AnyStaticShapeTensor:$input, $num_lower,
(createIotaOp<"1"> $op, $input, $num_lower),
(createIotaOp<"0"> $op, $input, $num_lower)
),
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE,
(HLO_DEFAULT_COMPARISON_TYPE)
),
(HLOClient_BroadcastCompareOp $offset, $num_upper_or_n,
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE,
(HLO_DEFAULT_COMPARISON_TYPE)
)
),
$input,
@ -452,7 +464,7 @@ def : Pat<(TF_EluOp AnyRankedTensor:$features),
$features,
(HLO_ConstOp:$zero (GetScalarOfType<0> $features)),
(BinBroadcastDimensions $zero, $features),
HLO_COMPARISON_DIRECTION_GT),
HLO_COMPARISON_DIRECTION_GT, (HLO_DEFAULT_COMPARISON_TYPE)),
$features,
(HLO_Expm1Op $features))>;
@ -462,7 +474,7 @@ def : Pat<(TF_EluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$featur
$features,
(HLO_ConstOp:$zero (GetScalarOfType<0> $features)),
(BinBroadcastDimensions $zero, $features),
HLO_COMPARISON_DIRECTION_GT),
HLO_COMPARISON_DIRECTION_GT, (HLO_DEFAULT_COMPARISON_TYPE)),
$gradients,
(HLO_MulOp
$gradients,
@ -507,7 +519,8 @@ def : Pat<(TF_ReluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$featu
(HLO_SelectOp
(HLOClient_BroadcastCompareOp $features,
(HLO_ConstOp (GetScalarOfType<0> $features)),
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_GT),
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_GT,
(HLO_DEFAULT_COMPARISON_TYPE)),
$gradients, (HLO_ConstOp (ConstantSplat<"0"> $gradients)))>;
//===----------------------------------------------------------------------===//
@ -679,7 +692,8 @@ def : Pattern<(TF_SoftplusOp AnyTensor:$features),
$features,
(HLO_NegOp $threshold),
(NullDenseIntElementsAttr),
HLO_COMPARISON_DIRECTION_GT
HLO_COMPARISON_DIRECTION_GT,
(HLO_DEFAULT_COMPARISON_TYPE)
),
$features,
(HLO_SelectOp
@ -687,7 +701,8 @@ def : Pattern<(TF_SoftplusOp AnyTensor:$features),
$features,
$threshold,
(NullDenseIntElementsAttr),
HLO_COMPARISON_DIRECTION_LT
HLO_COMPARISON_DIRECTION_LT,
(HLO_DEFAULT_COMPARISON_TYPE)
),
$features_exp,
(HLO_Log1pOp $features_exp)

View File

@ -852,9 +852,10 @@ class XlaBuilder {
absl::optional<ComparisonDirection> direction = absl::nullopt,
absl::optional<Comparison::Type> type = absl::nullopt);
StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
ComparisonDirection direction);
// Internal helper method for binary op compare without broadcast dimensions.
virtual StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
ComparisonDirection direction);
virtual StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
ComparisonDirection direction,
Comparison::Type type);