Add a verify method to TF.Slice.
Add test cases. PiperOrigin-RevId: 274837040 Change-Id: I14e697471df1ec6c93b734fe9efdbea39703e1ee
This commit is contained in:
parent
7af994f509
commit
82c9b5abe1
@ -3852,6 +3852,10 @@ whose values are extracted from 'input' starting at the offsets in
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>;
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_SnapshotOp : TF_Op<"Snapshot", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
@ -1212,6 +1213,70 @@ static LogicalResult Verify(ShapeNOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SliceOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Verifies that,
|
||||
//
|
||||
// - operands begin and size are 1D with the same number of elements.
|
||||
// - if the input is a ranked tensor, the rank of the input equals the number
|
||||
// of elements in operands begin and size.
|
||||
// - if begin are constants, 0 <= begin[i] < input_ty.getShape()[i]
|
||||
//
|
||||
static LogicalResult Verify(SliceOp op) {
|
||||
RankedTensorType begin_ty = GetRankedTensorTypeForOperand(op.begin());
|
||||
if (begin_ty && begin_ty.getRank() != 1) {
|
||||
return op.emitOpError() << "requires begin operand to be 1D tensor";
|
||||
}
|
||||
|
||||
RankedTensorType size_ty = GetRankedTensorTypeForOperand(op.size());
|
||||
if (size_ty && size_ty.getRank() != 1) {
|
||||
return op.emitOpError() << "requires size operand to be 1D tensor";
|
||||
}
|
||||
|
||||
if (!begin_ty || !size_ty || !begin_ty.hasStaticShape() ||
|
||||
!size_ty.hasStaticShape())
|
||||
return success();
|
||||
|
||||
if (begin_ty.getNumElements() != size_ty.getNumElements()) {
|
||||
return op.emitOpError() << "requires begin and size operands to have the"
|
||||
" same number of elements";
|
||||
}
|
||||
|
||||
auto input_ty = op.input()->getType().dyn_cast<RankedTensorType>();
|
||||
if (input_ty && begin_ty.getNumElements() != input_ty.getRank()) {
|
||||
return op.emitOpError() << "requires number of elements in begin and size"
|
||||
"are equal to input rank";
|
||||
}
|
||||
|
||||
DenseIntElementsAttr begin_indices;
|
||||
if (matchPattern(op.begin(), m_Constant(&begin_indices))) {
|
||||
DenseIntElementsAttr slice_sizes;
|
||||
bool constant_slice_sizes =
|
||||
matchPattern(op.size(), m_Constant(&slice_sizes));
|
||||
int dim = 0;
|
||||
for (APInt raw_begin_index : begin_indices.getValues<APInt>()) {
|
||||
int64_t begin_index = raw_begin_index.getSExtValue();
|
||||
int64_t input_size = input_ty ? input_ty.getShape()[dim] : -1;
|
||||
int64_t slice_size = constant_slice_sizes
|
||||
? slice_sizes.getValue<APInt>(dim).getSExtValue()
|
||||
: 0;
|
||||
if (slice_size == -1 && input_size != -1) {
|
||||
slice_size = input_size - begin_index;
|
||||
}
|
||||
if (begin_index < 0 ||
|
||||
(input_size != -1 && begin_index + slice_size > input_size)) {
|
||||
return op.emitOpError()
|
||||
<< "requires 0 <= begin[i] <= begin[i] + size[i] <= Di";
|
||||
}
|
||||
++dim;
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SoftmaxOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1311,6 +1311,54 @@ func @testPack(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>, %axis: tensor<i32
|
||||
|
||||
// -----
|
||||
|
||||
// Valid slice operation.
|
||||
func @testSlice(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> {
|
||||
%sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>)
|
||||
%0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32>
|
||||
return %0 : tensor<1x4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testSlice_begin_2d(%arg0: tensor<4xi32>, %begins: tensor<2x2xi64>) -> tensor<3xi32> {
|
||||
%sizes = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
|
||||
// expected-error @+1 {{requires begin operand to be 1D tensor}}
|
||||
%0 = "tf.Slice"(%arg0, %begins, %sizes) : (tensor<4xi32>, tensor<2x2xi64>, tensor<1xi64>) -> tensor<3xi32>
|
||||
return %0 : tensor<3xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testSlice_size_two_much_elements(%arg0: tensor<4xi32>) -> tensor<3xi32> {
|
||||
%begins = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
|
||||
%sizes = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> (tensor<2xi64>)
|
||||
// expected-error @+1 {{requires begin and size operands to have the same number of elements}}
|
||||
%0 = "tf.Slice"(%arg0, %begins, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<2xi64>) -> tensor<3xi32>
|
||||
return %0 : tensor<3xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testSlice_begin_negative(%arg0: tensor<4xi32>) -> tensor<2xi32> {
|
||||
%begins = "tf.Const"() {value = dense<[-1]> : tensor<1xi64>} : () -> (tensor<1xi64>)
|
||||
%sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>)
|
||||
// expected-error @+1 {{requires 0 <= begin[i] <= begin[i] + size[i] <= Di}}
|
||||
%0 = "tf.Slice"(%arg0, %begins, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi32>
|
||||
return %0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testSlice_begin_out_of_bound(%arg0: tensor<4xi32>) -> tensor<2xi32> {
|
||||
%begins = "tf.Const"() {value = dense<[4]> : tensor<1xi64>} : () -> (tensor<1xi64>)
|
||||
%sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>)
|
||||
// expected-error @+1 {{requires 0 <= begin[i] <= begin[i] + size[i] <= Di}}
|
||||
%0 = "tf.Slice"(%arg0, %begins, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi32>
|
||||
return %0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Valid StridedSlice operation.
|
||||
func @testStridedSlice(%input: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>, %strides: tensor<2xi64>) -> tensor<?x?xf32> {
|
||||
%0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<?x?xf32>
|
||||
|
Loading…
Reference in New Issue
Block a user