diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 7beccaf6ad7..68617ee0241 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -4035,6 +4035,10 @@ receive 0, 0, and 1, respectively. The appropriate bits in `begin_mask` and TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>; + + let verifier = [{ + return Verify(*this); + }]; } def TF_SubOp : TF_Op<"Sub", [Broadcastable, NoSideEffect]>, diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 8d28ec26507..826389b09aa 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -21,11 +21,13 @@ limitations under the License. #include #include +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/iterator_range.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/Dialect/Traits.h" // TF:local_config_mlir @@ -1029,6 +1031,66 @@ void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results, results.insert(context); } +//===----------------------------------------------------------------------===// +// StridedSliceOp +//===----------------------------------------------------------------------===// + +// Verifies that, +// +// - begin, end and strides operands are 1D and they have the same number of +// elements. Here, the number of elements should be less than 32 to support +// 32-bit mask attributes. +// - None of the strides values are zero. +// +static LogicalResult Verify(StridedSliceOp op) { + // Expected size for operands begin, end and strides vector operands. + int64_t expected_size = -1; + + for (Value *val : llvm::drop_begin(op.getOperands(), 1)) { + auto operand_ty = val->getType().dyn_cast(); + if (!operand_ty || !operand_ty.hasStaticShape()) { + // TensorFlow constant ops may have non-static shape because the shape is + // not propagated during constant folding. If the defining op for this + // operand is a constant op, use the constant op's attribute to get the + // actual shape. + DenseIntElementsAttr attr; + if (!matchPattern(val, m_Constant(&attr))) continue; + operand_ty = attr.getType(); + } + + if (operand_ty.getRank() != 1) + return op.emitOpError() + << "requires begin, end and strides to be 1D tensors"; + + int64_t length = operand_ty.getDimSize(0); + if (length == -1) continue; + + if (expected_size == -1) { + // This op uses 32-bit masks. + if (length >= 32) + return op.emitOpError( + "requires begin, end and strides operands with less than 32 " + "elements"); + + expected_size = length; + } else if (length != expected_size) { + return op.emitOpError() << "requires begin, end and strides to have the " + "same number of elements"; + } + } + + // If strides are constants, verify that none of the element is zero. + DenseIntElementsAttr strides; + if (matchPattern(op.strides(), m_Constant(&strides))) { + if (llvm::is_contained(strides.getValues(), 0)) + return op.emitOpError("requires non-zero strides"); + } + + // TODO(hinsu): Validate attributes. + + return success(); +} + //===----------------------------------------------------------------------===// // TensorListReserveOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index b702f5fe88c..e207e3d3562 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -1173,3 +1173,56 @@ func @testPack(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>, %axis: tensor, tensor<4x8xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } + +// ----- + +// Valid StridedSlice operation. +func @testStridedSlice(%input: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>, %strides: tensor<2xi64>) -> tensor { + %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + return %0 : tensor +} + +// ----- + +func @testStridedSlice(%input: tensor<4x8xf32>, %begin: tensor, %end: tensor, %strides: tensor) -> tensor { + // expected-error @+1 {{requires begin, end and strides to be 1D tensors}} + %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @testStridedSlice(%input: tensor<4x8xf32>, %begin: tensor<32xi64>, %end: tensor<2xi64>, %strides: tensor<2xi64>) -> tensor { + // expected-error @+1 {{with less than 32 elements}} + %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<32xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor + return %0 : tensor +} + +// ----- + +func @testStridedSlice(%input: tensor<4x8xf32>, %begin: tensor, %end: tensor<3xi64>, %strides: tensor<2xi64>) -> tensor { + // expected-error @+1 {{to have the same number of elements}} + %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor, tensor<3xi64>, tensor<2xi64>) -> tensor + return %0 : tensor +} + +// ----- + +func @testStridedSlice(%input: tensor<4x8xf32>) -> tensor { + %begin = "tf.Const"() { value = dense<[0, 0]> : tensor<2xi64> } : () -> tensor + %end = "tf.Const"() { value = dense<[5, 10]> : tensor<2xi64> } : () -> tensor + %strides = "tf.Const"() { value = dense<[2, 3, 4]> : tensor<3xi64> } : () -> tensor + + // expected-error @+1 {{to have the same number of elements}} + %1 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor, tensor, tensor) -> tensor +} + +// ----- + +func @testStridedSlice(%input: tensor<4x8xf32>, %begin: tensor<2xi32>, %end: tensor<2xi32>) -> tensor { + %strides = "tf.Const"() { value = dense<[2, 0]> : tensor<2xi32> } : () -> tensor<2xi32> + + // expected-error @+1 {{requires non-zero strides}} + %1 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor + return %1 : tensor +}