Add a verify method to TF.Slice.

Add test cases.

PiperOrigin-RevId: 274837040
Change-Id: I14e697471df1ec6c93b734fe9efdbea39703e1ee
This commit is contained in:
Bixia Zheng 2019-10-15 10:38:06 -07:00 committed by TensorFlower Gardener
parent 7af994f509
commit 82c9b5abe1
3 changed files with 117 additions and 0 deletions

View File

@ -3852,6 +3852,10 @@ whose values are extracted from 'input' starting at the offsets in
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>;
let verifier = [{
return Verify(*this);
}];
}
def TF_SnapshotOp : TF_Op<"Snapshot", [NoSideEffect, SameOperandsAndResultType]> {

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include <algorithm>
#include <cstdint>
#include <functional>
#include <numeric>
#include <string>
@ -1212,6 +1213,70 @@ static LogicalResult Verify(ShapeNOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// SliceOp
//===----------------------------------------------------------------------===//
// Verifies that,
//
// - operands begin and size are 1D with the same number of elements.
// - if the input is a ranked tensor, the rank of the input equals the number
// of elements in operands begin and size.
// - if begin are constants, 0 <= begin[i] < input_ty.getShape()[i]
//
static LogicalResult Verify(SliceOp op) {
RankedTensorType begin_ty = GetRankedTensorTypeForOperand(op.begin());
if (begin_ty && begin_ty.getRank() != 1) {
return op.emitOpError() << "requires begin operand to be 1D tensor";
}
RankedTensorType size_ty = GetRankedTensorTypeForOperand(op.size());
if (size_ty && size_ty.getRank() != 1) {
return op.emitOpError() << "requires size operand to be 1D tensor";
}
if (!begin_ty || !size_ty || !begin_ty.hasStaticShape() ||
!size_ty.hasStaticShape())
return success();
if (begin_ty.getNumElements() != size_ty.getNumElements()) {
return op.emitOpError() << "requires begin and size operands to have the"
" same number of elements";
}
auto input_ty = op.input()->getType().dyn_cast<RankedTensorType>();
if (input_ty && begin_ty.getNumElements() != input_ty.getRank()) {
return op.emitOpError() << "requires number of elements in begin and size"
"are equal to input rank";
}
DenseIntElementsAttr begin_indices;
if (matchPattern(op.begin(), m_Constant(&begin_indices))) {
DenseIntElementsAttr slice_sizes;
bool constant_slice_sizes =
matchPattern(op.size(), m_Constant(&slice_sizes));
int dim = 0;
for (APInt raw_begin_index : begin_indices.getValues<APInt>()) {
int64_t begin_index = raw_begin_index.getSExtValue();
int64_t input_size = input_ty ? input_ty.getShape()[dim] : -1;
int64_t slice_size = constant_slice_sizes
? slice_sizes.getValue<APInt>(dim).getSExtValue()
: 0;
if (slice_size == -1 && input_size != -1) {
slice_size = input_size - begin_index;
}
if (begin_index < 0 ||
(input_size != -1 && begin_index + slice_size > input_size)) {
return op.emitOpError()
<< "requires 0 <= begin[i] <= begin[i] + size[i] <= Di";
}
++dim;
}
}
return success();
}
//===----------------------------------------------------------------------===//
// SoftmaxOp
//===----------------------------------------------------------------------===//

View File

@ -1311,6 +1311,54 @@ func @testPack(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>, %axis: tensor<i32
// -----
// Valid slice operation.
func @testSlice(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
%sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>)
%0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32>
return %0 : tensor<1x4xi32>
}
// -----
func @testSlice_begin_2d(%arg0: tensor<4xi32>, %begins: tensor<2x2xi64>) -> tensor<3xi32> {
%sizes = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
// expected-error @+1 {{requires begin operand to be 1D tensor}}
%0 = "tf.Slice"(%arg0, %begins, %sizes) : (tensor<4xi32>, tensor<2x2xi64>, tensor<1xi64>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}
// -----
func @testSlice_size_two_much_elements(%arg0: tensor<4xi32>) -> tensor<3xi32> {
%begins = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
%sizes = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> (tensor<2xi64>)
// expected-error @+1 {{requires begin and size operands to have the same number of elements}}
%0 = "tf.Slice"(%arg0, %begins, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<2xi64>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}
// -----
func @testSlice_begin_negative(%arg0: tensor<4xi32>) -> tensor<2xi32> {
%begins = "tf.Const"() {value = dense<[-1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
%sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>)
// expected-error @+1 {{requires 0 <= begin[i] <= begin[i] + size[i] <= Di}}
%0 = "tf.Slice"(%arg0, %begins, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}
// -----
func @testSlice_begin_out_of_bound(%arg0: tensor<4xi32>) -> tensor<2xi32> {
%begins = "tf.Const"() {value = dense<[4]> : tensor<1xi64>} : () -> (tensor<1xi64>)
%sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>)
// expected-error @+1 {{requires 0 <= begin[i] <= begin[i] + size[i] <= Di}}
%0 = "tf.Slice"(%arg0, %begins, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}
// -----
// Valid StridedSlice operation.
func @testStridedSlice(%input: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>, %strides: tensor<2xi64>) -> tensor<?x?xf32> {
%0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<?x?xf32>