diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index d53bafff638..fd24b7284c1 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -7436,9 +7436,15 @@ select(condition, t, e) ==> [[1, 2],
   );

   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
+
+  let hasCanonicalizer = 1;
+
+  let verifier = [{
+    return Verify(*this);
+  }];
 }

-def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect]> {
+def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect, ResultsBroadcastableShape]> {
   let summary = "";

   let description = [{
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index 7fcc82f6757..1b6dbfe3e1a 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -251,6 +251,39 @@ static LogicalResult VerifyTypesCompatibility(
   return success();
 }

+// This is a helper for the Select to SelectV2 canonicalization. The `data` rank
+// refers to the rank of `t`/`e` (these two inputs have equal rank; this is
+// checked in the verifier).
+//
+// In most cases, the predicate for Select can be used directly as the predicate
+// for SelectV2. However, there is one case that varies, which is when the
+// predicate is a vector and the data is multidimensional. In this case, Select
+// op semantics dictate that the predicate vector length must match the size of
+// the first data dimension. This varies from normal broadcasting semantics
+// (which are used in SelectV2), so we must reshape the predicate in this case
+// to be compatible.
+static Value ReshapeSelectPredIfNecessary(OpBuilder *builder, Location loc,
+                                          Value cond, int data_rank) {
+  auto cond_tensor = cond.getType().cast<RankedTensorType>();
+  // Reshape is only needed in the case that the cond rank is 1 (i.e. it is
+  // a vector) AND t/e rank is > 1.
+  if (cond_tensor.getRank() != 1 || data_rank <= 1) {
+    // No reshape necessary. Leave cond as it is.
+    return cond;
+  }
+
+  // This is the case where a reshape is needed. We want to construct the
+  // shape [x,1,...,1], where x is the size of the pred vector and the
+  // length of the shape is equal to data_rank.
+  SmallVector<int64_t, 8> shape(data_rank, 1);
+  shape[0] = cond_tensor.getShape().front();
+  auto new_shape_type =
+      RankedTensorType::get({data_rank}, builder->getIntegerType(64));
+  auto shape_attr = DenseIntElementsAttr::get(new_shape_type, shape);
+  auto new_shape = builder->create<ConstOp>(loc, shape_attr);
+  return builder->create<ReshapeOp>(loc, cond, new_shape);
+}
+
 //===----------------------------------------------------------------------===//
 // Helper functions detect device capabilities from RuntimeDevices.
 //===----------------------------------------------------------------------===//
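Note: for a vector predicate and rank-3 data (shapes chosen here for illustration; they do not appear in the patch), the helper above emits the equivalent of

  %shape = "tf.Const"() {value = dense<[8, 1, 1]> : tensor<3xi64>} : () -> tensor<3xi64>
  %pred = "tf.Reshape"(%cond, %shape) : (tensor<8xi1>, tensor<3xi64>) -> tensor<8x1x1xi1>

so that SelectV2's broadcasting of the tensor<8x1x1xi1> predicate selects whole rows of the rank-3 data, reproducing Select's vector-predicate semantics.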
@@ -2550,6 +2583,81 @@ void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor,
   return unranked();
 }

+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                           MLIRContext *context) {
+  results.insert<SelectToSelectV2>(context);
+}
+
+// Verifies a few extra requirements on SelectOp:
+// (1) `then` and `else` must have compatible shapes
+// (2) At least one of the following must be true:
+//     (a) `cond` has the same rank as `then` and `else`
+//     (b) `cond` is a scalar
+//     (c) `cond` is a vector AND `then` and `else` are non-scalar with their
+//         first dimension equal to `cond`.
+static LogicalResult Verify(SelectOp op) {
+  auto then_tensor = op.t().getType().cast<TensorType>();
+  auto else_tensor = op.e().getType().cast<TensorType>();
+  // Check (1).
+  if (!AreCastCompatible({then_tensor, else_tensor}))
+    return op.emitOpError() << "requires t and e have compatible shapes";
+
+  // Get data rank (if exists).
+  int data_rank;
+  // If data is unranked or data_rank is 0, this will remain -2. Otherwise
+  // refers to first dimension of then and/or else.
+  int data_first_dim = -2;
+  bool then_has_rank = then_tensor.hasRank();
+  bool else_has_rank = else_tensor.hasRank();
+  if (then_has_rank && else_has_rank) {
+    data_rank = then_tensor.getRank();
+    if (then_tensor.getRank() > 0)
+      data_first_dim = then_tensor.getShape().front();
+    if (else_tensor.getRank() > 0)
+      data_first_dim = std::max(
+          static_cast<int>(else_tensor.getShape().front()), data_first_dim);
+  } else if (then_has_rank) {
+    data_rank = then_tensor.getRank();
+    if (then_tensor.getRank() > 0)
+      data_first_dim = then_tensor.getShape().front();
+  } else if (else_has_rank) {
+    data_rank = else_tensor.getRank();
+    if (else_tensor.getRank() > 0)
+      data_first_dim = else_tensor.getShape().front();
+  } else {
+    // Neither has a rank.
+    return success();
+  }
+
+  auto cond_tensor = op.condition().getType().dyn_cast<RankedTensorType>();
+  if (!cond_tensor) return success();
+  auto cond_rank = cond_tensor.getRank();
+  // Check (2a) and (2b).
+  if (cond_rank == 0 || cond_rank == data_rank) return success();
+  // Check (2c).
+  if (cond_rank == 1) {
+    auto cond_shape = cond_tensor.getShape().front();
+    if (data_rank == 0) {
+      return op.emitOpError()
+             << "requires that t and e are nonscalar when pred is a vector";
+    }
+    // We know `data` tensor has a rank of at least 1.
+    if (data_first_dim != -1 && cond_shape != -1 &&
+        data_first_dim != cond_shape) {
+      return op.emitOpError() << "requires that, when pred is a vector, the "
+                                 "shape matches the first dimension of t and e";
+    }
+    return success();
+  }
+  // None of (2a,b,c) were true; fail.
+  return op.emitOpError() << "requires that pred is a scalar OR has the same "
+                             "rank as t and e OR is a vector";
+}
+
 //===----------------------------------------------------------------------===//
 // SelectV2Op
 //===----------------------------------------------------------------------===//
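Note: check (2c) only fires when both sizes are static. When either the predicate length or the first data dimension is dynamic (-1), the comparison is skipped, so IR such as the following sketch (shapes chosen for illustration) passes the verifier even though the match can only be established at runtime:

  %0 = "tf.Select"(%pred, %t, %e)
       : (tensor<2xi1>, tensor<?x3xf16>, tensor<?x3xf16>) -> tensor<?x3xf16>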
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
index e05894dc266..20f4dd79715 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
@@ -258,6 +258,59 @@ func @testDoubleReciprocal(%arg0: tensor<8x16x32x64xi32>) -> tensor<8x16x32x64xi
   // CHECK: return %arg0
 }

+// CHECK-LABEL: testSelectScalarPred
+func @testSelectScalarPred(%arg0: tensor<i1>, %arg1: tensor<4x2xf16>, %arg2: tensor<4x2xf16>) -> tensor<4x2xf16> {
+  // CHECK-NEXT: "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<4x2xf16>, tensor<4x2xf16>) -> tensor<4x2xf16>
+  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<4x2xf16>, tensor<4x2xf16>) -> tensor<4x2xf16>
+  return %0: tensor<4x2xf16>
+}
+
+// CHECK-LABEL: testSelectVectorPred
+func @testSelectVectorPred(%arg0: tensor<2xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
+  // CHECK-NEXT: %[[SHAPE:.*]] = "tf.Const"
+  // CHECK-NEXT: %[[PRED:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2xi1>, tensor<2xi64>) -> tensor<2x1xi1>
+  // CHECK-NEXT: "tf.SelectV2"(%[[PRED]], %arg1, %arg2) : (tensor<2x1xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
+  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
+  return %0: tensor<2x3xf16>
+}
+
+// CHECK-LABEL: testSelectAllSameShape
+func @testSelectAllSameShape(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
+  // CHECK-NEXT: "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
+  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
+  return %0: tensor<2x3xf16>
+}
+
+// If we don't have guarantees on input shapes, we can't support canonicalizing
+// to SelectV2. Test these cases.
+// CHECK-LABEL: testSelectInvalid
+func @testSelectInvalid(%arg0: tensor<?xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
+  // CHECK-NEXT: tf.Select
+  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<?xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
+  return %0: tensor<2x3xf16>
+}
+
+// CHECK-LABEL: testSelectInvalidUnranked
+func @testSelectInvalidUnranked(%arg0: tensor<6x7xi1>, %arg1: tensor<*xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> {
+  // CHECK-NEXT: tf.Select
+  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<6x7xi1>, tensor<*xf16>, tensor<*xf16>) -> tensor<*xf16>
+  return %0: tensor<*xf16>
+}
+
+// CHECK-LABEL: testSelectThenUnranked
+func @testSelectThenUnranked(%arg0: tensor<3xi1>, %arg1: tensor<*xf16>, %arg2: tensor<3x2xf16>) -> tensor<*xf16> {
+  // CHECK-NEXT: tf.Select
+  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<*xf16>, tensor<3x2xf16>) -> tensor<*xf16>
+  return %0: tensor<*xf16>
+}
+
+// CHECK-LABEL: testSelectElseUnranked
+func @testSelectElseUnranked(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> {
+  // CHECK-NEXT: tf.Select
+  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<*xf16>) -> tensor<*xf16>
+  return %0: tensor<*xf16>
+}
+
 // CHECK-LABEL: testLogicalNotOfEqual
 func @testLogicalNotOfEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> {
   %0 = "tf.Equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
index 3560fec7b7d..82e60a08e2e 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
@@ -1007,6 +1007,116 @@ func @pcall_func_2(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {

 // -----

+//===--------------------------------------------------------------------===//
+// tf.Select
+//===--------------------------------------------------------------------===//
+
+// Test valid tf.Select
+// CHECK-LABEL: func @testSelect
+func @testSelect(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<3x2xf16>) -> tensor<3x2xf16> {
+  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<3x2xf16>) -> tensor<3x2xf16>
+  return %0: tensor<3x2xf16>
+}
+
+// -----
+
+func @testInvalidSelect(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
+  // expected-error @+1 {{requires that, when pred is a vector, the shape matches the first dimension of t and e}}
+  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
+  return %0: tensor<2x3xf16>
+}
+
+// -----
+
+// Test invalid tf.Select - broadcasting then/else parameters is not supported
+func @selectBroadcastThen(%arg0: tensor<i1>, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
+  // expected-error @+1 {{requires t and e have compatible shapes}}
+  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
+  return %0: tensor<2x8x8xi32>
+}
+
+// -----
+
+func @invalidSelect(%arg0: tensor<2xi1>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<2xi32> {
+  // expected-error @+1 {{requires that t and e are nonscalar when pred is a vector}}
+  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<i32>, tensor<i32>) -> tensor<2xi32>
+  return %0: tensor<2xi32>
+}
+
+// -----
+
+func @invalidSelect(%arg0: tensor<1x8xi1>, %arg1: tensor<1x8x8xi32>, %arg2: tensor<1x8x8xi32>) -> tensor<1x8x8xi32> {
+  // expected-error @+1 {{requires that pred is a scalar OR has the same rank as t and e OR is a vector}}
+  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<1x8xi1>, tensor<1x8x8xi32>, tensor<1x8x8xi32>) -> tensor<1x8x8xi32>
+  return %0: tensor<1x8x8xi32>
+}
+
+// -----
+
+//===--------------------------------------------------------------------===//
+// tf.SelectV2
+//===--------------------------------------------------------------------===//
+
+// Test valid tf.SelectV2
+// CHECK-LABEL: func @selectV2BroadcastThen
+func @selectV2BroadcastThen(%arg0: tensor<i1>, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
+  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
+  return %0: tensor<2x8x8xi32>
+}
+
+// -----
+
+// Test valid tf.SelectV2
+// CHECK-LABEL: func @selectV2BroadcastElse
+func @selectV2BroadcastElse(%arg0: tensor<i1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<8x1xi32>) -> tensor<2x8x8xi32> {
+  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x8x8xi32>, tensor<8x1xi32>) -> tensor<2x8x8xi32>
+  return %0: tensor<2x8x8xi32>
+}
+
+// -----
+
+// Test valid tf.SelectV2
+// CHECK-LABEL: func @selectV2BroadcastPred
+func @selectV2BroadcastPred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
+  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
+  return %0: tensor<2x8x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @selectV2BroadcastAll
+func @selectV2BroadcastAll(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> {
+  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8x1x1xi1>, tensor<1x8x1xi32>, tensor<1x1x8xi32>) -> tensor<8x8x8xi32>
+  return %0: tensor<8x8x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @selectV2DynamicRanked
+func @selectV2DynamicRanked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> {
+  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32>
+  return %0: tensor<2x?x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @selectV2Unranked
+func @selectV2Unranked(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<*xi32>) -> tensor<*xi32> {
+  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<*xi32>) -> tensor<*xi32>
+  return %0: tensor<*xi32>
+}
+
+// -----
+
+// Test invalid tf.SelectV2: this is an invalid broadcast for the predicate
+func @testInvalidSelectV2(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<3x2xf16>) -> tensor<3x2xf16> {
+  // expected-error @+1 {{operands don't have broadcast-compatible shapes}}
+  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<3x2xf16>) -> tensor<3x2xf16>
+  return %0: tensor<3x2xf16>
+}
+
+// -----
+
 //===--------------------------------------------------------------------===//
 // tf.Softmax
 //===--------------------------------------------------------------------===//
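Note: ResultsBroadcastableShape verifies the operands against the result with the usual broadcast rule: aligning shapes from the right, each dimension pair must be equal or contain a 1. For @selectV2BroadcastAll, 8x1x1, 1x8x1, and 1x1x8 all broadcast to the declared 8x8x8 result. For @testInvalidSelectV2, the predicate's 3 lines up with the trailing data dimension 2, and neither is 1, so verification fails; this is exactly the case the canonicalization avoids by reshaping a vector predicate to 3x1.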
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
index ccc3e83a2a2..cf09f8d64fb 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
@@ -152,6 +152,23 @@ def RealDivWithSqrtDivisor : Pat<(TF_RealDivOp $arg0, (TF_SqrtOp $arg1)),
 def ReciprocalNested : Pat<(TF_ReciprocalOp (TF_ReciprocalOp $arg)),
                            (replaceWithValue $arg)>;

+//===----------------------------------------------------------------------===//
+// Select op patterns.
+//===----------------------------------------------------------------------===//
+
+def ReshapeSelectPredIfNecessary : NativeCodeCall<
+  "ReshapeSelectPredIfNecessary(&($_builder), $0.getOwner()->getLoc(), $1, "
+  "$2.getType().cast<RankedTensorType>().getRank())">;
+
+// Select supports a vector `condition` whose length equals the first
+// dimension of t and e. The SelectV2 op supports normal broadcasting, so in
+// these cases the condition needs to be reshaped.
+def SelectToSelectV2 : Pat<
+  (TF_SelectOp:$op StaticShapeTensorOf<[AnyType]>:$cond,
+                   StaticShapeTensorOf<[AnyType]>:$t,
+                   StaticShapeTensorOf<[AnyType]>:$e),
+  (TF_SelectV2Op (ReshapeSelectPredIfNecessary $op, $cond, $t), $t, $e)>;
+
 //===----------------------------------------------------------------------===//
 // Square op patterns.
 //===----------------------------------------------------------------------===//
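Note on the NativeCodeCall bindings above: $0 is the result of the matched Select op (its getOwner() recovers the op, used only for its location), $1 is the condition, and $2 is t, whose static rank supplies the data_rank argument of the C++ helper. Because all three operands are constrained to StaticShapeTensorOf, the rank and leading dimension the helper needs are always known; unranked or dynamically shaped inputs are deliberately left as tf.Select, as the negative tests in canonicalize.mlir above check.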
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index bfa96413e7c..2288e0fefc4 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -1320,27 +1320,6 @@ func @relu_grad(%gradients: tensor<4x8xf32>, %features: tensor<?x?xf32>) -> tens
 // Select op legalizations.
 //===----------------------------------------------------------------------===//

-// CHECK-LABEL: func @select
-func @select(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
-  // CHECK-NEXT: "xla_hlo.select"(%arg0, %arg1, %arg2)
-  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
-  return %0: tensor<2xi32>
-}
-
-// CHECK-LABEL: func @select_float
-func @select_float(%arg0: tensor<2xi1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<2xf32> {
-  // CHECK-NEXT: "xla_hlo.select"(%arg0, %arg1, %arg2)
-  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
-  return %0: tensor<2xf32>
-}
-
-// CHECK-LABEL: func @select_multidimensional
-func @select_multidimensional(%arg0: tensor<3x2xi1>, %arg1: tensor<3x2xi32>, %arg2: tensor<3x2xi32>) -> tensor<3x2xi32> {
-  // CHECK-NEXT: "xla_hlo.select"(%arg0, %arg1, %arg2)
-  %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
-  return %0: tensor<3x2xi32>
-}
-
 // CHECK-LABEL: func @selectv2
 func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
   // CHECK-NEXT: "xla_hlo.select"(%arg0, %arg1, %arg2)
   %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
   return %0: tensor<2xi32>
 }
@@ -1379,6 +1358,14 @@ func @selectv2_broadcast_pred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %ar
   return %0: tensor<2x8x8xi32>
 }

+// CHECK-LABEL: func @selectv2_broadcast_tensor_pred
+func @selectv2_broadcast_tensor_pred(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
+  // CHECK: %[[BROADCAST:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi1>) -> tensor<2x3xi1>
+  // CHECK: "xla_hlo.select"(%[[BROADCAST]], %arg1, %arg2)
+  %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
+  return %0: tensor<2x3xf16>
+}
+
 // CHECK-LABEL: func @selectv2_broadcast_all
 func @selectv2_broadcast_all(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> {
   // CHECK-DAG: %[[BROADCAST_0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x1x1xi1>) -> tensor<8x8x8xi1>
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
index 33c92ee65d5..19fc42714b0 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td
@@ -521,18 +521,6 @@ def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, $1, &$_builder)">;
 def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (TF_ConstOp $axis)),
           (HLO_ReverseOp $values, (ConvertAxisAttr $values, $axis))>;

-//===----------------------------------------------------------------------===//
-// Ternary op patterns.
-//===----------------------------------------------------------------------===//
-
-def BothTypesMatch : Constraint<CPred<"$0.getType() == $1.getType()">,
-                                "types must be equal">;
-
-def : Pat<(TF_SelectOp $cond, $t, $e), (HLO_SelectOp $cond, $t, $e),
-          // TODO(jpienaar): This restriction is to avoid creating a currently
-          // unsupported HLO select.
-          [(BothTypesMatch $t, $e)]>;
-
 //===----------------------------------------------------------------------===//
 // Unary op patterns.
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py
index a1bb64eb88d..7bbfecff403 100644
--- a/tensorflow/compiler/tests/ternary_ops_test.py
+++ b/tensorflow/compiler/tests/ternary_ops_test.py
@@ -77,7 +77,6 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
         np.int32(2),
         expected=np.array([1, 3, 5], dtype=np.int32))

-  @test_util.disable_mlir_bridge('TODO(b/155949336)')
   def testSelect(self):
     for dtype in self.numeric_types:
       self._testTernary(