Check for axis bounds and that all operands have the same shape and element type in PackOp verifier.
PiperOrigin-RevId: 264921464
This commit is contained in:
parent
36d59de3b6
commit
78299d8dbf
@ -402,6 +402,23 @@ static LogicalResult Verify(PackOp op) {
|
||||
if (op.getOperation()->getNumOperands() != op.values_count())
|
||||
return op.emitOpError("input count should match 'values_count' attribute");
|
||||
|
||||
Value *operand0 = op.getOperand(0);
|
||||
auto input_type = operand0->getType().cast<ShapedType>();
|
||||
|
||||
// Check axis bounds.
|
||||
int64_t axis_value = op.axis().getSExtValue();
|
||||
if (abs(axis_value) > input_type.getRank())
|
||||
return op.emitOpError("op attribute 'axis' is out of bounds, got ")
|
||||
<< axis_value;
|
||||
|
||||
// Make sure all inputs have the same shape and element type.
|
||||
// TODO(rahulsp): Simplify once b/135032064 is fixed.
|
||||
for (Value *operand : op.getOperands()) {
|
||||
auto other_type = operand->getType().cast<ShapedType>();
|
||||
if (input_type != other_type)
|
||||
return op.emitOpError("operands should be of the same type");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -825,6 +825,22 @@ func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
|
||||
|
||||
// -----
|
||||
|
||||
func @packInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x4x2xi32> {
|
||||
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32}
|
||||
%0 = "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<1x4x2xi32>
|
||||
return %0 : tensor<1x4x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @packNegInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<2x1x4xi32> {
|
||||
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32}
|
||||
%0 = "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<2x1x4xi32>
|
||||
return %0 : tensor<2x1x4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
|
||||
// expected-error @+1 {{input count should match 'values_count' attribute}}
|
||||
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 1 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
|
||||
@ -833,6 +849,22 @@ func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
|
||||
|
||||
// -----
|
||||
|
||||
func @pack(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
|
||||
// expected-error @+1 {{operands should be of the same type}}
|
||||
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1xi32>, tensor<2xi32>) -> tensor<2x2xi32>
|
||||
return %0 : tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
|
||||
// expected-error @+1 {{op attribute 'axis' is out of bounds, got 3}}
|
||||
%0 = "tfl.pack"(%arg0, %arg1) {axis = 3 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
|
||||
return %0 : tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
|
||||
// CHECK: "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32}
|
||||
%0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
|
||||
|
Loading…
Reference in New Issue
Block a user