Check for axis bounds and that all operands have the same shape and element type in PackOp verifier.
PiperOrigin-RevId: 264921464
This commit is contained in:
parent
36d59de3b6
commit
78299d8dbf
@ -402,6 +402,23 @@ static LogicalResult Verify(PackOp op) {
|
|||||||
if (op.getOperation()->getNumOperands() != op.values_count())
|
if (op.getOperation()->getNumOperands() != op.values_count())
|
||||||
return op.emitOpError("input count should match 'values_count' attribute");
|
return op.emitOpError("input count should match 'values_count' attribute");
|
||||||
|
|
||||||
|
Value *operand0 = op.getOperand(0);
|
||||||
|
auto input_type = operand0->getType().cast<ShapedType>();
|
||||||
|
|
||||||
|
// Check axis bounds.
|
||||||
|
int64_t axis_value = op.axis().getSExtValue();
|
||||||
|
if (abs(axis_value) > input_type.getRank())
|
||||||
|
return op.emitOpError("op attribute 'axis' is out of bounds, got ")
|
||||||
|
<< axis_value;
|
||||||
|
|
||||||
|
// Make sure all inputs have the same shape and element type.
|
||||||
|
// TODO(rahulsp): Simplify once b/135032064 is fixed.
|
||||||
|
for (Value *operand : op.getOperands()) {
|
||||||
|
auto other_type = operand->getType().cast<ShapedType>();
|
||||||
|
if (input_type != other_type)
|
||||||
|
return op.emitOpError("operands should be of the same type");
|
||||||
|
}
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -825,6 +825,22 @@ func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
func @packInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x4x2xi32> {
|
||||||
|
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32}
|
||||||
|
%0 = "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<1x4x2xi32>
|
||||||
|
return %0 : tensor<1x4x2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @packNegInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<2x1x4xi32> {
|
||||||
|
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32}
|
||||||
|
%0 = "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<2x1x4xi32>
|
||||||
|
return %0 : tensor<2x1x4xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
|
func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
|
||||||
// expected-error @+1 {{input count should match 'values_count' attribute}}
|
// expected-error @+1 {{input count should match 'values_count' attribute}}
|
||||||
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 1 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
|
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 1 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
|
||||||
@ -833,6 +849,22 @@ func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
func @pack(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
|
||||||
|
// expected-error @+1 {{operands should be of the same type}}
|
||||||
|
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1xi32>, tensor<2xi32>) -> tensor<2x2xi32>
|
||||||
|
return %0 : tensor<2x2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
|
||||||
|
// expected-error @+1 {{op attribute 'axis' is out of bounds, got 3}}
|
||||||
|
%0 = "tfl.pack"(%arg0, %arg1) {axis = 3 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
|
||||||
|
return %0 : tensor<2x2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
|
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
|
||||||
// CHECK: "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32}
|
// CHECK: "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32}
|
||||||
%0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
|
%0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user