diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
index 45df154818d..c3dd7f5a398 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -402,6 +402,23 @@ static LogicalResult Verify(PackOp op) {
   if (op.getOperation()->getNumOperands() != op.values_count())
     return op.emitOpError("input count should match 'values_count' attribute");
 
+  Value *operand0 = op.getOperand(0);
+  auto input_type = operand0->getType().cast<ShapedType>();
+
+  // Check axis bounds.
+  int64_t axis_value = op.axis().getSExtValue();
+  if (abs(axis_value) > input_type.getRank())
+    return op.emitOpError("op attribute 'axis' is out of bounds, got ")
+           << axis_value;
+
+  // Make sure all inputs have the same shape and element type.
+  // TODO(rahulsp): Simplify once b/135032064 is fixed.
+  for (Value *operand : op.getOperands()) {
+    auto other_type = operand->getType().cast<ShapedType>();
+    if (input_type != other_type)
+      return op.emitOpError("operands should be of the same type");
+  }
+
   return success();
 }
 
diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir
index eada37df7d6..fe6dc486822 100644
--- a/tensorflow/compiler/mlir/lite/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir
@@ -825,6 +825,22 @@ func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
 
 // -----
 
+func @packInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x4x2xi32> {
+  // CHECK: "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32}
+  %0 = "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<1x4x2xi32>
+  return %0 : tensor<1x4x2xi32>
+}
+
+// -----
+
+func @packNegInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<2x1x4xi32> {
+  // CHECK: "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32}
+  %0 = "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<2x1x4xi32>
+  return %0 : tensor<2x1x4xi32>
+}
+
+// -----
+
 func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
   // expected-error @+1 {{input count should match 'values_count' attribute}}
   %0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 1 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
@@ -833,6 +849,22 @@ func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
 
 // -----
 
+func @pack(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
+  // expected-error @+1 {{operands should be of the same type}}
+  %0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1xi32>, tensor<2xi32>) -> tensor<2x2xi32>
+  return %0 : tensor<2x2xi32>
+}
+
+// -----
+
+func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
+  // expected-error @+1 {{op attribute 'axis' is out of bounds, got 3}}
+  %0 = "tfl.pack"(%arg0, %arg1) {axis = 3 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
+  return %0 : tensor<2x2xi32>
+}
+
+// -----
+
 func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
   // CHECK: "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32}
   %0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)