Canonicalize tf.Select to tf.SelectV2.

The ops are mostly equivalent, except that Select has stricter requirements and does not support broadcasting, whereas SelectV2 does.

There is one special case to consider in this canonicalization: when the predicate is a vector (rank-1 tensor) and the data is multidimensional. In this case, Select op semantics dictate that the length of the predicate vector must match the size of the first data dimension. This differs from the normal broadcasting semantics used by SelectV2, so we must reshape the predicate tensor to keep the two ops equivalent.
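For intuition, here is a minimal sketch of the behavioral difference, using the Python raw-op bindings tf.raw_ops.Select and tf.raw_ops.SelectV2 (the values are illustrative only, not taken from this change):

import tensorflow as tf

pred = tf.constant([True, False])              # shape [2]
t = tf.constant([[1, 2, 3], [4, 5, 6]])        # shape [2, 3]
e = tf.constant([[7, 8, 9], [10, 11, 12]])     # shape [2, 3]

# Select: a vector predicate whose length matches the first data
# dimension selects whole rows of t/e.
tf.raw_ops.Select(condition=pred, t=t, e=e)
# ==> [[1, 2, 3], [10, 11, 12]]

# SelectV2 uses numpy-style broadcasting, under which a [2] predicate
# is not broadcast-compatible with [2, 3] data (trailing dims 2 vs 3).
# Reshaping the predicate to [2, 1] recovers the row-wise behavior;
# this is exactly the reshape the canonicalization inserts.
tf.raw_ops.SelectV2(condition=tf.reshape(pred, [2, 1]), t=t, e=e)
# ==> [[1, 2, 3], [10, 11, 12]]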

This also adds verifiers and tests for the Select and SelectV2 ops in the MLIR TF dialect.
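For reference, the rank checks the new Select verifier performs can be mirrored in a few lines of Python (select_pred_ok is a hypothetical helper, not part of this change, and models only the fully-static-shape cases):

def select_pred_ok(cond_shape, data_shape):
    # Mirrors the SelectOp verifier's predicate checks for static shapes.
    # cond_shape and data_shape are lists of dimension sizes.
    cond_rank, data_rank = len(cond_shape), len(data_shape)
    if cond_rank == 0 or cond_rank == data_rank:  # scalar pred, or equal ranks
        return True
    if cond_rank == 1:  # vector pred: data non-scalar, first dims must match
        return data_rank >= 1 and cond_shape[0] == data_shape[0]
    return False

assert select_pred_ok([], [4, 2])             # scalar pred
assert select_pred_ok([3], [3, 2])            # vector pred, first dims match
assert not select_pred_ok([2], [])            # vector pred but scalar data
assert not select_pred_ok([1, 8], [1, 8, 8])  # pred rank 2, data rank 3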

PiperOrigin-RevId: 312362580
Change-Id: I43f326ad330c92ce279b25cecf5a2cf46714ce3f
Author:    Lucy Fox
Date:      2020-05-19 15:12:23 -07:00
Committer: TensorFlower Gardener
Parent:    af2263101b
Commit:    ca53894d61
8 changed files with 303 additions and 35 deletions


@@ -7436,9 +7436,15 @@ select(condition, t, e) ==> [[1, 2],
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
let hasCanonicalizer = 1;
let verifier = [{
return Verify(*this);
}];
}
-def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect]> {
+def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect, ResultsBroadcastableShape]> {
let summary = "";
let description = [{


@@ -251,6 +251,39 @@ static LogicalResult VerifyTypesCompatibility(
return success();
}
// This is a helper for the Select to SelectV2 canonicalization. The `data` rank
// refers to the rank of `t`/`e` (these two inputs have equal rank; this is
// checked in the verifier).
//
// In most cases, the predicate for Select can be used directly as the predicate
// for SelectV2. However, there is one case that differs: when the predicate is
// a vector (rank-1 tensor) and the data is multidimensional. In this case,
// Select op semantics dictate that the length of the predicate vector must
// match the size of the first data dimension. This differs from the normal
// broadcasting semantics used by SelectV2, so we must reshape the predicate
// in this case to make it compatible.
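// For example, a tensor<3xi1> predicate with rank-3 data is reshaped to
// tensor<3x1x1xi1>; under SelectV2 broadcasting this selects along the
// first dimension, exactly as Select would.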
static Value ReshapeSelectPredIfNecessary(OpBuilder *builder, Location loc,
Value cond, int data_rank) {
auto cond_tensor = cond.getType().cast<RankedTensorType>();
// Reshape is only needed in the case that the cond rank is 1 (i.e. it is
// a vector) AND t/e rank is > 1.
if (cond_tensor.getRank() != 1 || data_rank <= 1) {
// No reshape necessary. Leave cond as it is.
return cond;
}
// This is the case where a reshape is needed: we construct the shape
// [x,1,...,1], where x is the length of the pred vector and the number
// of dimensions is equal to data_rank.
SmallVector<int64_t, 8> shape(data_rank, 1);
shape[0] = cond_tensor.getShape().front();
auto new_shape_type =
RankedTensorType::get({data_rank}, builder->getIntegerType(64));
auto shape_attr = DenseIntElementsAttr::get(new_shape_type, shape);
auto new_shape = builder->create<ConstOp>(loc, shape_attr);
return builder->create<ReshapeOp>(loc, cond, new_shape);
}
//===----------------------------------------------------------------------===//
// Helper functions detect device capabilities from RuntimeDevices.
//===----------------------------------------------------------------------===//
@@ -2550,6 +2583,81 @@ void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor,
return unranked();
}
//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//
void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<SelectToSelectV2>(context);
}
// Verifies a few extra requirements on SelectOp:
// (1) `then` and `else` must have same shape
// (2) At least one of the following must be true:
// (a) `cond` has the same rank as `then` and `else`
// (b) `cond` is a scalar
// (c) `cond` is a vector AND `then` and `else` are non-scalar with their
// first dimension equal to the length of `cond`.
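// For example (all of these appear in the tests added in this change):
//   cond tensor<i1>,     t/e tensor<4x2xf16>    -> ok, case (b)
//   cond tensor<3xi1>,   t/e tensor<3x2xf16>    -> ok, case (c)
//   cond tensor<1x8xi1>, t/e tensor<1x8x8xi32>  -> error, no case applies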
static LogicalResult Verify(SelectOp op) {
auto then_tensor = op.t().getType().cast<TensorType>();
auto else_tensor = op.e().getType().cast<TensorType>();
// Check (1).
if (!AreCastCompatible({then_tensor, else_tensor}))
return op.emitOpError() << "requires t and e have compatible shapes";
// Get data rank (if exists).
int data_rank;
// If the data is unranked or data_rank is 0, data_first_dim will remain -2.
// Otherwise it holds the first dimension of `then` and/or `else`.
int data_first_dim = -2;
bool then_has_rank = then_tensor.hasRank();
bool else_has_rank = else_tensor.hasRank();
if (then_has_rank && else_has_rank) {
data_rank = then_tensor.getRank();
if (then_tensor.getRank() > 0)
data_first_dim = then_tensor.getShape().front();
if (else_tensor.getRank() > 0)
data_first_dim = std::max(
static_cast<int>(else_tensor.getShape().front()), data_first_dim);
} else if (then_has_rank) {
data_rank = then_tensor.getRank();
if (then_tensor.getRank() > 0)
data_first_dim = then_tensor.getShape().front();
} else if (else_has_rank) {
data_rank = else_tensor.getRank();
if (else_tensor.getRank() > 0)
data_first_dim = else_tensor.getShape().front();
} else {
// Neither has a rank.
return success();
}
auto cond_tensor = op.condition().getType().dyn_cast<RankedTensorType>();
if (!cond_tensor) return success();
auto cond_rank = cond_tensor.getRank();
// Check (2a) and (2b).
if (cond_rank == 0 || cond_rank == data_rank) return success();
// Check (2c).
if (cond_rank == 1) {
auto cond_shape = cond_tensor.getShape().front();
if (data_rank == 0) {
return op.emitOpError()
<< "requires that t and e are nonscalar when pred is a vector";
}
// We know `data` tensor has a rank of at least 1.
if (data_first_dim != -1 && cond_shape != -1 &&
data_first_dim != cond_shape) {
return op.emitOpError() << "requires that, when pred is a vector, the "
"shape matches the first dimension of t and e";
}
return success();
}
// None of (2a,b,c) were true; fail.
return op.emitOpError() << "requires that pred is a scalar OR has the same "
"rank as t and e OR is a vector";
}
//===----------------------------------------------------------------------===//
// SelectV2Op
//===----------------------------------------------------------------------===//


@@ -258,6 +258,59 @@ func @testDoubleReciprocal(%arg0: tensor<8x16x32x64xi32>) -> tensor<8x16x32x64xi
// CHECK: return %arg0
}
// CHECK-LABEL: testSelectScalarPred
func @testSelectScalarPred(%arg0: tensor<i1>, %arg1: tensor<4x2xf16>, %arg2: tensor<4x2xf16>) -> tensor<4x2xf16> {
// CHECK-NEXT: "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<4x2xf16>, tensor<4x2xf16>) -> tensor<4x2xf16>
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<4x2xf16>, tensor<4x2xf16>) -> tensor<4x2xf16>
return %0: tensor<4x2xf16>
}
// CHECK-LABEL: testSelectVectorPred
func @testSelectVectorPred(%arg0: tensor<2xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
// CHECK-NEXT: %[[SHAPE:.*]] = "tf.Const"
// CHECK-NEXT: %[[PRED:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2xi1>, tensor<2xi64>) -> tensor<2x1xi1>
// CHECK-NEXT: "tf.SelectV2"(%[[PRED]], %arg1, %arg2) : (tensor<2x1xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
return %0: tensor<2x3xf16>
}
// CHECK-LABEL: testSelectAllSameShape
func @testSelectAllSameShape(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
// CHECK-NEXT: "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
return %0: tensor<2x3xf16>
}
// If we don't have guarantees on input shapes, we can't support canonicalizing
// to SelectV2. Test these cases.
// CHECK-LABEL: testSelectInvalid
func @testSelectInvalid(%arg0: tensor<?xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
// CHECK-NEXT: tf.Select
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<?xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
return %0: tensor<2x3xf16>
}
// CHECK-LABEL: testSelectInvalidUnranked
func @testSelectInvalidUnranked(%arg0: tensor<6x7xi1>, %arg1: tensor<*xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> {
// CHECK-NEXT: tf.Select
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<6x7xi1>, tensor<*xf16>, tensor<*xf16>) -> tensor<*xf16>
return %0: tensor<*xf16>
}
// CHECK-LABEL: testSelectThenUnranked
func @testSelectThenUnranked(%arg0: tensor<3xi1>, %arg1: tensor<*xf16>, %arg2: tensor<3x2xf16>) -> tensor<*xf16> {
// CHECK-NEXT: tf.Select
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<*xf16>, tensor<3x2xf16>) -> tensor<*xf16>
return %0: tensor<*xf16>
}
// CHECK-LABEL: testSelectElseUnranked
func @testSelectElseUnranked(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> {
// CHECK-NEXT: tf.Select
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<*xf16>) -> tensor<*xf16>
return %0: tensor<*xf16>
}
// CHECK-LABEL: testLogicalNotOfEqual
func @testLogicalNotOfEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xi1> {
%0 = "tf.Equal"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xi1>


@@ -1007,6 +1007,116 @@ func @pcall_func_2(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
// -----
//===--------------------------------------------------------------------===//
// tf.Select
//===--------------------------------------------------------------------===//
// Test valid tf.Select
// CHECK-LABEL: func @testSelect
func @testSelect(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<3x2xf16>) -> tensor<3x2xf16> {
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<3x2xf16>) -> tensor<3x2xf16>
return %0: tensor<3x2xf16>
}
// -----
func @testInvalidSelect(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
// expected-error @+1 {{requires that, when pred is a vector, the shape matches the first dimension of t and e}}
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
return %0: tensor<2x3xf16>
}
// -----
// Test invalid tf.Select - broadcasting then/else parameters is not supported
func @selectBroadcastThen(%arg0: tensor<i1>, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
// expected-error @+1 {{requires t and e have compatible shapes}}
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
return %0: tensor<2x8x8xi32>
}
// -----
func @invalidSelect(%arg0: tensor<2xi1>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<2xi32> {
// expected-error @+1 {{requires that t and e are nonscalar when pred is a vector}}
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<i32>, tensor<i32>) -> tensor<2xi32>
return %0: tensor<2xi32>
}
// -----
func @invalidSelect(%arg0: tensor<1x8xi1>, %arg1: tensor<1x8x8xi32>, %arg2: tensor<1x8x8xi32>) -> tensor<1x8x8xi32> {
// expected-error @+1 {{requires that pred is a scalar OR has the same rank as t and e OR is a vector}}
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<1x8xi1>, tensor<1x8x8xi32>, tensor<1x8x8xi32>) -> tensor<1x8x8xi32>
return %0: tensor<1x8x8xi32>
}
// -----
//===--------------------------------------------------------------------===//
// tf.SelectV2
//===--------------------------------------------------------------------===//
// Test valid tf.SelectV2
// CHECK-LABEL: func @selectV2BroadcastThen
func @selectV2BroadcastThen(%arg0: tensor<i1>, %arg1: tensor<8x1xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<8x1xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
return %0: tensor<2x8x8xi32>
}
// -----
// Test valid tf.SelectV2
// CHECK-LABEL: func @selectV2BroadcastElse
func @selectV2BroadcastElse(%arg0: tensor<i1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<8x1xi32>) -> tensor<2x8x8xi32> {
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x8x8xi32>, tensor<8x1xi32>) -> tensor<2x8x8xi32>
return %0: tensor<2x8x8xi32>
}
// -----
// Test valid tf.SelectV2
// CHECK-LABEL: func @selectV2BroadcastPred
func @selectV2BroadcastPred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x8x8xi32> {
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<2x8x8xi32>) -> tensor<2x8x8xi32>
return %0: tensor<2x8x8xi32>
}
// -----
// CHECK-LABEL: func @selectV2BroadcastAll
func @selectV2BroadcastAll(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> {
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8x1x1xi1>, tensor<1x8x1xi32>, tensor<1x1x8xi32>) -> tensor<8x8x8xi32>
return %0: tensor<8x8x8xi32>
}
// -----
// CHECK-LABEL: func @selectV2DynamicRanked
func @selectV2DynamicRanked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> {
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32>
return %0: tensor<2x?x8xi32>
}
// -----
// CHECK-LABEL: func @selectV2Unranked
func @selectV2Unranked(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %arg2: tensor<*xi32>) -> tensor<*xi32> {
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x8x8xi32>, tensor<*xi32>) -> tensor<*xi32>
return %0: tensor<*xi32>
}
// -----
// Test invalid tf.SelectV2: this is an invalid broadcast for the predicate
func @testInvalidSelectV2(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<3x2xf16>) -> tensor<3x2xf16> {
// expected-error @+1 {{operands don't have broadcast-compatible shapes}}
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<3x2xf16>) -> tensor<3x2xf16>
return %0: tensor<3x2xf16>
}
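// Note: the diagnostic above is produced by the ResultsBroadcastableShape
// trait added to TF_SelectV2Op in this change, so no custom verifier is
// needed for the broadcast-compatibility check.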
// -----
//===--------------------------------------------------------------------===//
// tf.Softmax
//===--------------------------------------------------------------------===//


@@ -152,6 +152,23 @@ def RealDivWithSqrtDivisor : Pat<(TF_RealDivOp $arg0, (TF_SqrtOp $arg1)),
def ReciprocalNested : Pat<(TF_ReciprocalOp (TF_ReciprocalOp $arg)),
(replaceWithValue $arg)>;
//===----------------------------------------------------------------------===//
// Select op patterns.
//===----------------------------------------------------------------------===//
def ReshapeSelectPredIfNecessary : NativeCodeCall<
"ReshapeSelectPredIfNecessary(&($_builder), $0.getOwner()->getLoc(), $1, "
"$2.getType().cast<RankedTensorType>().getRank())">;
// Select supports tensor `condition` where the shape is equal to the first
// dimension of t and e. SelectV2 op supports normal broadcasting, so in these
// cases the condition needs to be reshaped.
def SelectToSelectV2 : Pat<
(TF_SelectOp:$op StaticShapeTensorOf<[AnyType]>:$cond,
StaticShapeTensorOf<[AnyType]>:$t,
StaticShapeTensorOf<[AnyType]>:$e),
(TF_SelectV2Op (ReshapeSelectPredIfNecessary $op, $cond, $t), $t, $e)>;
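// Note: the StaticShapeTensorOf constraints restrict this pattern to operands
// with fully static shapes. With dynamic or unranked shapes, the target shape
// for the reshape cannot be computed, so tf.Select is left untouched (see the
// testSelectInvalid* canonicalization tests in this change).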
//===----------------------------------------------------------------------===//
// Square op patterns.
//===----------------------------------------------------------------------===//


@@ -1320,27 +1320,6 @@ func @relu_grad(%gradients: tensor<4x8xf32>, %features: tensor<?x?xf32>) -> tens
// Select op legalizations.
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @select
func @select(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-NEXT: "xla_hlo.select"(%arg0, %arg1, %arg2)
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0: tensor<2xi32>
}
// CHECK-LABEL: func @select_float
func @select_float(%arg0: tensor<2xi1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-NEXT: "xla_hlo.select"(%arg0, %arg1, %arg2)
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
return %0: tensor<2xf32>
}
// CHECK-LABEL: func @select_multidimensional
func @select_multidimensional(%arg0: tensor<3x2xi1>, %arg1: tensor<3x2xi32>, %arg2: tensor<3x2xi32>) -> tensor<3x2xi32> {
// CHECK-NEXT: "xla_hlo.select"(%arg0, %arg1, %arg2)
%0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
return %0: tensor<3x2xi32>
}
// CHECK-LABEL: func @selectv2
func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
// CHECK-NEXT: "xla_hlo.select"(%arg0, %arg1, %arg2)
@@ -1379,6 +1358,14 @@ func @selectv2_broadcast_pred(%arg0: tensor<1xi1>, %arg1: tensor<2x8x8xi32>, %ar
return %0: tensor<2x8x8xi32>
}
// CHECK-LABEL: func @selectv2_broadcast_tensor_pred
func @selectv2_broadcast_tensor_pred(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf16>, %arg2: tensor<2x3xf16>) -> tensor<2x3xf16> {
// CHECK: %[[BROADCAST:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi1>) -> tensor<2x3xi1>
// CHECK: "xla_hlo.select"(%[[BROADCAST]], %arg1, %arg2)
%0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf16>, tensor<2x3xf16>) -> tensor<2x3xf16>
return %0: tensor<2x3xf16>
}
// CHECK-LABEL: func @selectv2_broadcast_all
func @selectv2_broadcast_all(%arg0: tensor<8x1x1xi1>, %arg1: tensor<1x8x1xi32>, %arg2: tensor<1x1x8xi32>) -> tensor<8x8x8xi32> {
// CHECK-DAG: %[[BROADCAST_0:.*]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x1x1xi1>) -> tensor<8x8x8xi1>


@@ -521,18 +521,6 @@ def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, $1, &$_builder)">;
def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (TF_ConstOp $axis)),
(HLO_ReverseOp $values, (ConvertAxisAttr $values, $axis))>;
//===----------------------------------------------------------------------===//
// Ternary op patterns.
//===----------------------------------------------------------------------===//
def BothTypesMatch : Constraint<CPred<"$0.getType() == $1.getType()">,
"types must be equal">;
def : Pat<(TF_SelectOp $cond, $t, $e), (HLO_SelectOp $cond, $t, $e),
// TODO(jpienaar): This restriction is to avoid creating a currently
// unsupported HLO select.
[(BothTypesMatch $t, $e)]>;
//===----------------------------------------------------------------------===//
// Unary op patterns.
//===----------------------------------------------------------------------===//


@@ -77,7 +77,6 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
np.int32(2),
expected=np.array([1, 3, 5], dtype=np.int32))
@test_util.disable_mlir_bridge('TODO(b/155949336)')
def testSelect(self):
for dtype in self.numeric_types:
self._testTernary(