diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 45df154818d..c3dd7f5a398 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -402,6 +402,23 @@ static LogicalResult Verify(PackOp op) { if (op.getOperation()->getNumOperands() != op.values_count()) return op.emitOpError("input count should match 'values_count' attribute"); + Value *operand0 = op.getOperand(0); + auto input_type = operand0->getType().cast<ShapedType>(); + + // Check axis bounds. + int64_t axis_value = op.axis().getSExtValue(); + if (abs(axis_value) > input_type.getRank()) + return op.emitOpError("op attribute 'axis' is out of bounds, got ") + << axis_value; + + // Make sure all inputs have the same shape and element type. + // TODO(rahulsp): Simplify once b/135032064 is fixed. + for (Value *operand : op.getOperands()) { + auto other_type = operand->getType().cast<ShapedType>(); + if (input_type != other_type) + return op.emitOpError("operands should be of the same type"); + } + return success(); } diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index eada37df7d6..fe6dc486822 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -825,6 +825,22 @@ func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { // ----- +func @packInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x4x2xi32> { + // CHECK: "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32} + %0 = "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<1x4x2xi32> + return %0 : tensor<1x4x2xi32> +} + +// ----- + +func @packNegInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<2x1x4xi32> { + // CHECK: "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32} + %0 = "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<2x1x4xi32> + return %0 : tensor<2x1x4xi32> +} + +// ----- + func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { // expected-error @+1 {{input count should match 'values_count' attribute}} %0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 1 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32> @@ -833,6 +849,22 @@ func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { // ----- +func @pack(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { + // expected-error @+1 {{operands should be of the same type}} + %0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1xi32>, tensor<2xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} + +// ----- + +func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { + // expected-error @+1 {{op attribute 'axis' is out of bounds, got 3}} + %0 = "tfl.pack"(%arg0, %arg1) {axis = 3 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} + +// ----- + func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { // CHECK: "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} %0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)