Check for axis bounds and that all operands have the same shape and element type in PackOp verifier.

PiperOrigin-RevId: 264921464
This commit is contained in:
A. Unique TensorFlower 2019-08-22 14:35:33 -07:00 committed by TensorFlower Gardener
parent 36d59de3b6
commit 78299d8dbf
2 changed files with 49 additions and 0 deletions

View File

@ -402,6 +402,23 @@ static LogicalResult Verify(PackOp op) {
if (op.getOperation()->getNumOperands() != op.values_count())
return op.emitOpError("input count should match 'values_count' attribute");
Value *operand0 = op.getOperand(0);
auto input_type = operand0->getType().cast<ShapedType>();
// Check axis bounds.
int64_t axis_value = op.axis().getSExtValue();
if (abs(axis_value) > input_type.getRank())
return op.emitOpError("op attribute 'axis' is out of bounds, got ")
<< axis_value;
// Make sure all inputs have the same shape and element type.
// TODO(rahulsp): Simplify once b/135032064 is fixed.
for (Value *operand : op.getOperands()) {
auto other_type = operand->getType().cast<ShapedType>();
if (input_type != other_type)
return op.emitOpError("operands should be of the same type");
}
return success();
}

View File

@ -825,6 +825,22 @@ func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
// -----
func @packInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x4x2xi32> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<1x4x2xi32>
return %0 : tensor<1x4x2xi32>
}
// -----
func @packNegInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<2x1x4xi32> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<2x1x4xi32>
return %0 : tensor<2x1x4xi32>
}
// -----
func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
// expected-error @+1 {{input count should match 'values_count' attribute}}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 1 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
@ -833,6 +849,22 @@ func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
// -----
func @pack(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
// expected-error @+1 {{operands should be of the same type}}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
// expected-error @+1 {{op attribute 'axis' is out of bounds, got 3}}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 3 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
// CHECK: "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32}
%0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)