Add initial TensorFlow StridedSlice op verification
Added a TODO to verify the mask attributes. PiperOrigin-RevId: 270359381
This commit is contained in:
parent
d2247c815f
commit
436a8b6698
@ -4035,6 +4035,10 @@ receive 0, 0, and 1, respectively. The appropriate bits in `begin_mask` and
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>;
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_SubOp : TF_Op<"Sub", [Broadcastable, NoSideEffect]>,
|
||||
|
@ -21,11 +21,13 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/ADT/iterator_range.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/Traits.h" // TF:local_config_mlir
|
||||
@ -1029,6 +1031,66 @@ void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
results.insert<SubOfNeg>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// StridedSliceOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Verifies that,
|
||||
//
|
||||
// - begin, end and strides operands are 1D and they have the same number of
|
||||
// elements. Here, the number of elements should be less than 32 to support
|
||||
// 32-bit mask attributes.
|
||||
// - None of the strides values are zero.
|
||||
//
|
||||
static LogicalResult Verify(StridedSliceOp op) {
|
||||
// Expected size for operands begin, end and strides vector operands.
|
||||
int64_t expected_size = -1;
|
||||
|
||||
for (Value *val : llvm::drop_begin(op.getOperands(), 1)) {
|
||||
auto operand_ty = val->getType().dyn_cast<ShapedType>();
|
||||
if (!operand_ty || !operand_ty.hasStaticShape()) {
|
||||
// TensorFlow constant ops may have non-static shape because the shape is
|
||||
// not propagated during constant folding. If the defining op for this
|
||||
// operand is a constant op, use the constant op's attribute to get the
|
||||
// actual shape.
|
||||
DenseIntElementsAttr attr;
|
||||
if (!matchPattern(val, m_Constant(&attr))) continue;
|
||||
operand_ty = attr.getType();
|
||||
}
|
||||
|
||||
if (operand_ty.getRank() != 1)
|
||||
return op.emitOpError()
|
||||
<< "requires begin, end and strides to be 1D tensors";
|
||||
|
||||
int64_t length = operand_ty.getDimSize(0);
|
||||
if (length == -1) continue;
|
||||
|
||||
if (expected_size == -1) {
|
||||
// This op uses 32-bit masks.
|
||||
if (length >= 32)
|
||||
return op.emitOpError(
|
||||
"requires begin, end and strides operands with less than 32 "
|
||||
"elements");
|
||||
|
||||
expected_size = length;
|
||||
} else if (length != expected_size) {
|
||||
return op.emitOpError() << "requires begin, end and strides to have the "
|
||||
"same number of elements";
|
||||
}
|
||||
}
|
||||
|
||||
// If strides are constants, verify that none of the element is zero.
|
||||
DenseIntElementsAttr strides;
|
||||
if (matchPattern(op.strides(), m_Constant(&strides))) {
|
||||
if (llvm::is_contained(strides.getValues<APInt>(), 0))
|
||||
return op.emitOpError("requires non-zero strides");
|
||||
}
|
||||
|
||||
// TODO(hinsu): Validate attributes.
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorListReserveOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1173,3 +1173,56 @@ func @testPack(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>, %axis: tensor<i32
|
||||
%0 = "tf.Pack"(%arg0, %arg1) {axis = 3 : i64, N = 2: i64} : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Valid StridedSlice operation.
|
||||
func @testStridedSlice(%input: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>, %strides: tensor<2xi64>) -> tensor<?x?xf32> {
|
||||
%0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testStridedSlice(%input: tensor<4x8xf32>, %begin: tensor<i64>, %end: tensor<i64>, %strides: tensor<i64>) -> tensor<?x?xf32> {
|
||||
// expected-error @+1 {{requires begin, end and strides to be 1D tensors}}
|
||||
%0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testStridedSlice(%input: tensor<4x8xf32>, %begin: tensor<32xi64>, %end: tensor<2xi64>, %strides: tensor<2xi64>) -> tensor<?x?xf32> {
|
||||
// expected-error @+1 {{with less than 32 elements}}
|
||||
%0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<32xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testStridedSlice(%input: tensor<4x8xf32>, %begin: tensor<?xi64>, %end: tensor<3xi64>, %strides: tensor<2xi64>) -> tensor<?x?xf32> {
|
||||
// expected-error @+1 {{to have the same number of elements}}
|
||||
%0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<?xi64>, tensor<3xi64>, tensor<2xi64>) -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testStridedSlice(%input: tensor<4x8xf32>) -> tensor<?x?xf32> {
|
||||
%begin = "tf.Const"() { value = dense<[0, 0]> : tensor<2xi64> } : () -> tensor<?xi64>
|
||||
%end = "tf.Const"() { value = dense<[5, 10]> : tensor<2xi64> } : () -> tensor<?xi64>
|
||||
%strides = "tf.Const"() { value = dense<[2, 3, 4]> : tensor<3xi64> } : () -> tensor<?xi64>
|
||||
|
||||
// expected-error @+1 {{to have the same number of elements}}
|
||||
%1 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<?xi64>, tensor<?xi64>, tensor<?xi64>) -> tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testStridedSlice(%input: tensor<4x8xf32>, %begin: tensor<2xi32>, %end: tensor<2xi32>) -> tensor<?x?xf32> {
|
||||
%strides = "tf.Const"() { value = dense<[2, 0]> : tensor<2xi32> } : () -> tensor<2xi32>
|
||||
|
||||
// expected-error @+1 {{requires non-zero strides}}
|
||||
%1 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x?xf32>
|
||||
return %1 : tensor<?x?xf32>
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user