Add initial TensorFlow StridedSlice op verification

Added a TODO to verify the mask attributes.

PiperOrigin-RevId: 270359381
This commit is contained in:
Smit Hinsu 2019-09-20 15:09:15 -07:00 committed by TensorFlower Gardener
parent d2247c815f
commit 436a8b6698
3 changed files with 119 additions and 0 deletions

View File

@ -4035,6 +4035,10 @@ receive 0, 0, and 1, respectively. The appropriate bits in `begin_mask` and
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>;
let verifier = [{
return Verify(*this);
}];
}
def TF_SubOp : TF_Op<"Sub", [Broadcastable, NoSideEffect]>,

View File

@ -21,11 +21,13 @@ limitations under the License.
#include <string>
#include <type_traits>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/Dialect/Traits.h" // TF:local_config_mlir
@ -1029,6 +1031,66 @@ void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<SubOfNeg>(context);
}
//===----------------------------------------------------------------------===//
// StridedSliceOp
//===----------------------------------------------------------------------===//
// Verifies that,
//
// - begin, end and strides operands are 1D and they have the same number of
// elements. Here, the number of elements should be less than 32 to support
// 32-bit mask attributes.
// - None of the strides values are zero.
//
static LogicalResult Verify(StridedSliceOp op) {
// Expected size for operands begin, end and strides vector operands.
int64_t expected_size = -1;
for (Value *val : llvm::drop_begin(op.getOperands(), 1)) {
auto operand_ty = val->getType().dyn_cast<ShapedType>();
if (!operand_ty || !operand_ty.hasStaticShape()) {
// TensorFlow constant ops may have non-static shape because the shape is
// not propagated during constant folding. If the defining op for this
// operand is a constant op, use the constant op's attribute to get the
// actual shape.
DenseIntElementsAttr attr;
if (!matchPattern(val, m_Constant(&attr))) continue;
operand_ty = attr.getType();
}
if (operand_ty.getRank() != 1)
return op.emitOpError()
<< "requires begin, end and strides to be 1D tensors";
int64_t length = operand_ty.getDimSize(0);
if (length == -1) continue;
if (expected_size == -1) {
// This op uses 32-bit masks.
if (length >= 32)
return op.emitOpError(
"requires begin, end and strides operands with less than 32 "
"elements");
expected_size = length;
} else if (length != expected_size) {
return op.emitOpError() << "requires begin, end and strides to have the "
"same number of elements";
}
}
// If strides are constants, verify that none of the element is zero.
DenseIntElementsAttr strides;
if (matchPattern(op.strides(), m_Constant(&strides))) {
if (llvm::is_contained(strides.getValues<APInt>(), 0))
return op.emitOpError("requires non-zero strides");
}
// TODO(hinsu): Validate attributes.
return success();
}
//===----------------------------------------------------------------------===//
// TensorListReserveOp
//===----------------------------------------------------------------------===//

View File

@ -1173,3 +1173,56 @@ func @testPack(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>, %axis: tensor<i32
%0 = "tf.Pack"(%arg0, %arg1) {axis = 3 : i64, N = 2: i64} : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
// Valid StridedSlice operation.
func @testStridedSlice(%input: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>, %strides: tensor<2xi64>) -> tensor<?x?xf32> {
%0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// -----
func @testStridedSlice(%input: tensor<4x8xf32>, %begin: tensor<i64>, %end: tensor<i64>, %strides: tensor<i64>) -> tensor<?x?xf32> {
// expected-error @+1 {{requires begin, end and strides to be 1D tensors}}
%0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// -----
func @testStridedSlice(%input: tensor<4x8xf32>, %begin: tensor<32xi64>, %end: tensor<2xi64>, %strides: tensor<2xi64>) -> tensor<?x?xf32> {
// expected-error @+1 {{with less than 32 elements}}
%0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<32xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// -----
func @testStridedSlice(%input: tensor<4x8xf32>, %begin: tensor<?xi64>, %end: tensor<3xi64>, %strides: tensor<2xi64>) -> tensor<?x?xf32> {
// expected-error @+1 {{to have the same number of elements}}
%0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<?xi64>, tensor<3xi64>, tensor<2xi64>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// -----
func @testStridedSlice(%input: tensor<4x8xf32>) -> tensor<?x?xf32> {
%begin = "tf.Const"() { value = dense<[0, 0]> : tensor<2xi64> } : () -> tensor<?xi64>
%end = "tf.Const"() { value = dense<[5, 10]> : tensor<2xi64> } : () -> tensor<?xi64>
%strides = "tf.Const"() { value = dense<[2, 3, 4]> : tensor<3xi64> } : () -> tensor<?xi64>
// expected-error @+1 {{to have the same number of elements}}
%1 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<?xi64>, tensor<?xi64>, tensor<?xi64>) -> tensor<?x?xf32>
}
// -----
func @testStridedSlice(%input: tensor<4x8xf32>, %begin: tensor<2xi32>, %end: tensor<2xi32>) -> tensor<?x?xf32> {
%strides = "tf.Const"() { value = dense<[2, 0]> : tensor<2xi32> } : () -> tensor<2xi32>
// expected-error @+1 {{requires non-zero strides}}
%1 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}