diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
index 5230f05f25c..b1787546d67 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
@@ -515,11 +515,9 @@ func @addN_variant(%arg0: tensor<!tf.variant<tensor<2xf32>>>, %arg1: tensor<!tf.
 
 // CHECK-LABEL: func @DynamicStitch_simple
 func @DynamicStitch_simple(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
-  // CHECK: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
-  // CHECK: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2x2xf32>) -> (tensor<2xf32>, tensor<2xf32>)
-  // CHECK-DAG: %[[ITEMS_1:.*]] = "tf.ExpandDims"(%[[ITEMS]]#1, %[[AXIS]])
-  // CHECK-DAG: %[[ITEMS_0:.*]] = "tf.ExpandDims"(%[[ITEMS]]#0, %[[AXIS]])
-  // CHECK: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS_1]], %[[ITEMS_0]], %[[AXIS]]) : (tensor<1x2xf32>, tensor<1x2xf32>, tensor<i64>) -> tensor<2x2xf32>
+  // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2x2xf32>) -> (tensor<2xf32>, tensor<2xf32>)
+  // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
+  // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[ITEMS]]#0, %[[AXIS]]) : (tensor<2xf32>, tensor<2xf32>, tensor<i64>) -> tensor<2x2xf32>
   // CHECK: return %[[RESULT]]
 
   %indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
@@ -535,12 +533,7 @@ func @DynamicStitch_scalar_matrix_indices(%arg0: tensor<2xf32>, %arg1: tensor<2x
   // CHECK-DAG: %[[INP1:.*]] = "tf.Reshape"(%arg1, %[[SHAPE]]) : (tensor<2x2x2xf32>, tensor<2xi64>) -> tensor<4x2xf32>
   // CHECK-DAG: %[[ITEMS1:.*]]:4 = "tf.Unpack"(%[[INP1]]) {axis = 0 : i64} : (tensor<4x2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>)
   // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
-  // CHECK-DAG: %[[ITEMS1_3:.*]] = "tf.ExpandDims"(%[[ITEMS1]]#3, %[[AXIS]])
-  // CHECK-DAG: %[[ITEMS1_2:.*]] = "tf.ExpandDims"(%[[ITEMS1]]#2, %[[AXIS]])
-  // CHECK-DAG: %[[ITEMS1_1:.*]] = "tf.ExpandDims"(%[[ITEMS1]]#1, %[[AXIS]])
-  // CHECK-DAG: %[[ITEMS1_0:.*]] = "tf.ExpandDims"(%[[ITEMS1]]#0, %[[AXIS]])
-  // CHECK-DAG: %[[ITEMS0_0:.*]] = "tf.ExpandDims"(%[[ITEMS0]], %[[AXIS]])
-  // CHECK-DAG: "tf.ConcatV2"(%[[ITEMS1_3]], %[[ITEMS1_2]], %[[ITEMS1_1]], %[[ITEMS1_0]], %[[ITEMS0_0]], %[[AXIS]]) : (tensor<1x2xf32>, tensor<1x2xf32>, tensor<1x2xf32>, tensor<1x2xf32>, tensor<1x2xf32>, tensor<i64>) -> tensor<5x2xf32>
+  // CHECK-DAG: %6 = "tf.ConcatV2"(%[[ITEMS1]]#3, %[[ITEMS1]]#2, %[[ITEMS1]]#1, %[[ITEMS1]]#0, %[[ITEMS0]], %[[AXIS]]) : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<i64>) -> tensor<5x2xf32>
 
   %indices0 = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
   %indices1 = "tf.Const"() {value = dense<[[3, 2], [1, 0]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
@@ -562,9 +555,7 @@ func @DynamicStitch_uint8(%arg0: tensor<2x2xui8>) -> tensor<2x2xui8> {
 func @DynamicStitch_scalar_item(%arg0: tensor<2xf32>) -> tensor<2xf32> {
   // CHECK-DAG: %[[ITEMS]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2xf32>) -> (tensor<f32>, tensor<f32>)
   // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
-  // CHECK-DAG: %[[ITEMS_1:.*]] = "tf.ExpandDims"(%[[ITEMS]]#1, %[[AXIS]])
-  // CHECK-DAG: %[[ITEMS_0:.*]] = "tf.ExpandDims"(%[[ITEMS]]#0, %[[AXIS]])
-  // CHECK-DAG: %[[RESULT]] = "tf.ConcatV2"(%[[ITEMS_1]], %[[ITEMS_0]], %[[AXIS]]) : (tensor<1xf32>, tensor<1xf32>, tensor<i64>) -> tensor<2xf32>
+  // CHECK-DAG: %[[RESULT]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[ITEMS]]#0, %[[AXIS]]) : (tensor<f32>, tensor<f32>, tensor<i64>) -> tensor<2xf32>
   // CHECK: return %[[RESULT]]
 
   %indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
@@ -576,9 +567,7 @@ func @DynamicStitch_scalar_item(%arg0: tensor<2xf32>) -> tensor<2xf32> {
 func @DynamicStitch_matrix_item(%arg0: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> {
   // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2x2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>)
   // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
-  // CHECK-DAG: %[[ITEMS_1:.*]] = "tf.ExpandDims"(%[[ITEMS]]#1, %[[AXIS]])
-  // CHECK-DAG: %[[ITEMS_0:.*]] = "tf.ExpandDims"(%[[ITEMS]]#0, %[[AXIS]])
-  // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS_1]], %[[ITEMS_0]], %[[AXIS]]) : (tensor<1x2x2xf32>, tensor<1x2x2xf32>, tensor<i64>) -> tensor<2x2x2xf32>
+  // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[ITEMS]]#0, %[[AXIS]]) : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<i64>) -> tensor<2x2x2xf32>
   // CHECK: return %[[RESULT]]
 
   %indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
@@ -597,8 +586,7 @@ func @DynamicStitch_dynamic(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tenso
 func @DynamicStitch_duplicates(%arg0: tensor<2x2xf32>) -> tensor<1x2xf32> {
   // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2x2xf32>) -> (tensor<2xf32>, tensor<2xf32>)
   // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
-  // CHECK-DAG: %[[ITEMS_1:.*]] = "tf.ExpandDims"(%[[ITEMS]]#1, %[[AXIS]])
-  // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS_1]], %[[AXIS]]) : (tensor<1x2xf32>, tensor<i64>) -> tensor<1x2xf32>
+  // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[AXIS]]) : (tensor<2xf32>, tensor<i64>) -> tensor<1x2xf32>
   // CHECK: return %[[RESULT]]
 
   %indices = "tf.Const"() {value = dense<[0, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc
index c93679ab7da..a462f967bef 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc
@@ -215,7 +215,7 @@ class LowerAddNOp : public RewritePattern {
 };
 
 // Lowers DynamicStitch op with constant indices and with static input and
-// output shapes using Reshape, UnPack and Pack op.
+// output shapes using Reshape, UnPack and ConcatV2 op.
 //
 //   %indices0 = "tf.Const"() {value = dense<4> : tensor<i32>}
 //   %indices1 = "tf.Const"() {value = dense<[[3, 2], [1, 0]]> :
@@ -237,7 +237,7 @@ class LowerAddNOp : public RewritePattern {
 //     : (tensor<4x2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
 //     tensor<2xf32>)
 //   %axis = "tf.Const"() {value = dense<0> : tensor<i64>}
-//   %0 = "tf.Pack"(items1#3, items1#2, items1#1, items1#0, %items0, %axis)
+//   %0 = "tf.ConcatV2"(items1#3, items1#2, items1#1, items1#0, %items0, %axis)
 //     : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
 //        tensor<2xf32>, tensor<i64>) -> tensor<5x2xf32>
 //
@@ -303,7 +303,8 @@ class LowerDynamicStitchOp : public OpRewritePattern<TF::DynamicStitchOp> {
       }
     }
 
-    rewriter.replaceOpWithNewOp<PackOp>(op, op.getType(), values);
+    auto axis = rewriter.create<ConstOp>(loc, rewriter.getI64IntegerAttr(0));
+    rewriter.replaceOpWithNewOp<ConcatV2Op>(op, op.getType(), values, axis);
     return success();
   }
 };