diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir index 5230f05f25c..b1787546d67 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir @@ -515,11 +515,9 @@ func @addN_variant(%arg0: tensor<!tf.variant<tensor<2xf32>>>, %arg1: tensor<!tf. // CHECK-LABEL: func @DynamicStitch_simple func @DynamicStitch_simple(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - // CHECK: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64> - // CHECK: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2x2xf32>) -> (tensor<2xf32>, tensor<2xf32>) - // CHECK-DAG: %[[ITEMS_1:.*]] = "tf.ExpandDims"(%[[ITEMS]]#1, %[[AXIS]]) - // CHECK-DAG: %[[ITEMS_0:.*]] = "tf.ExpandDims"(%[[ITEMS]]#0, %[[AXIS]]) - // CHECK: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS_1]], %[[ITEMS_0]], %[[AXIS]]) : (tensor<1x2xf32>, tensor<1x2xf32>, tensor<i64>) -> tensor<2x2xf32> + // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2x2xf32>) -> (tensor<2xf32>, tensor<2xf32>) + // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64> + // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[ITEMS]]#0, %[[AXIS]]) : (tensor<2xf32>, tensor<2xf32>, tensor<i64>) -> tensor<2x2xf32> // CHECK: return %[[RESULT]] %indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> @@ -535,12 +533,7 @@ func @DynamicStitch_scalar_matrix_indices(%arg0: tensor<2xf32>, %arg1: tensor<2x // CHECK-DAG: %[[INP1:.*]] = "tf.Reshape"(%arg1, %[[SHAPE]]) : (tensor<2x2x2xf32>, tensor<2xi64>) -> tensor<4x2xf32> // CHECK-DAG: %[[ITEMS1:.*]]:4 = "tf.Unpack"(%[[INP1]]) {axis = 0 : i64} : (tensor<4x2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64> - // CHECK-DAG: %[[ITEMS1_3:.*]] = "tf.ExpandDims"(%[[ITEMS1]]#3, %[[AXIS]]) - // CHECK-DAG: %[[ITEMS1_2:.*]] = "tf.ExpandDims"(%[[ITEMS1]]#2, %[[AXIS]]) - // CHECK-DAG: %[[ITEMS1_1:.*]] = "tf.ExpandDims"(%[[ITEMS1]]#1, %[[AXIS]]) - // CHECK-DAG: %[[ITEMS1_0:.*]] = "tf.ExpandDims"(%[[ITEMS1]]#0, %[[AXIS]]) - // CHECK-DAG: %[[ITEMS0_0:.*]] = "tf.ExpandDims"(%[[ITEMS0]], %[[AXIS]]) - // CHECK-DAG: "tf.ConcatV2"(%[[ITEMS1_3]], %[[ITEMS1_2]], %[[ITEMS1_1]], %[[ITEMS1_0]], %[[ITEMS0_0]], %[[AXIS]]) : (tensor<1x2xf32>, tensor<1x2xf32>, tensor<1x2xf32>, tensor<1x2xf32>, tensor<1x2xf32>, tensor<i64>) -> tensor<5x2xf32> + // CHECK-DAG: %6 = "tf.ConcatV2"(%[[ITEMS1]]#3, %[[ITEMS1]]#2, %[[ITEMS1]]#1, %[[ITEMS1]]#0, %[[ITEMS0]], %[[AXIS]]) : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<i64>) -> tensor<5x2xf32> %indices0 = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32> %indices1 = "tf.Const"() {value = dense<[[3, 2], [1, 0]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> @@ -562,9 +555,7 @@ func @DynamicStitch_uint8(%arg0: tensor<2x2xui8>) -> tensor<2x2xui8> { func @DynamicStitch_scalar_item(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-DAG: %[[ITEMS]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2xf32>) -> (tensor<f32>, tensor<f32>) // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64> - // CHECK-DAG: %[[ITEMS_1:.*]] = "tf.ExpandDims"(%[[ITEMS]]#1, %[[AXIS]]) - // CHECK-DAG: %[[ITEMS_0:.*]] = "tf.ExpandDims"(%[[ITEMS]]#0, %[[AXIS]]) - // CHECK-DAG: %[[RESULT]] = "tf.ConcatV2"(%[[ITEMS_1]], %[[ITEMS_0]], %[[AXIS]]) : (tensor<1xf32>, tensor<1xf32>, tensor<i64>) -> tensor<2xf32> + // CHECK-DAG: %[[RESULT]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[ITEMS]]#0, %[[AXIS]]) : (tensor<f32>, tensor<f32>, tensor<i64>) -> tensor<2xf32> // CHECK: return %[[RESULT]] %indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> @@ -576,9 +567,7 @@ func @DynamicStitch_scalar_item(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @DynamicStitch_matrix_item(%arg0: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> { // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2x2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>) // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64> - // CHECK-DAG: %[[ITEMS_1:.*]] = "tf.ExpandDims"(%[[ITEMS]]#1, %[[AXIS]]) - // CHECK-DAG: %[[ITEMS_0:.*]] = "tf.ExpandDims"(%[[ITEMS]]#0, %[[AXIS]]) - // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS_1]], %[[ITEMS_0]], %[[AXIS]]) : (tensor<1x2x2xf32>, tensor<1x2x2xf32>, tensor<i64>) -> tensor<2x2x2xf32> + // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[ITEMS]]#0, %[[AXIS]]) : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<i64>) -> tensor<2x2x2xf32> // CHECK: return %[[RESULT]] %indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> @@ -597,8 +586,7 @@ func @DynamicStitch_dynamic(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tenso func @DynamicStitch_duplicates(%arg0: tensor<2x2xf32>) -> tensor<1x2xf32> { // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2x2xf32>) -> (tensor<2xf32>, tensor<2xf32>) // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64> - // CHECK-DAG: %[[ITEMS_1:.*]] = "tf.ExpandDims"(%[[ITEMS]]#1, %[[AXIS]]) - // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS_1]], %[[AXIS]]) : (tensor<1x2xf32>, tensor<i64>) -> tensor<1x2xf32> + // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[AXIS]]) : (tensor<2xf32>, tensor<i64>) -> tensor<1x2xf32> // CHECK: return %[[RESULT]] %indices = "tf.Const"() {value = dense<[0, 0]> : tensor<2xi32>} : () -> tensor<2xi32> diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc index c93679ab7da..a462f967bef 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc @@ -215,7 +215,7 @@ class LowerAddNOp : public RewritePattern { }; // Lowers DynamicStitch op with constant indices and with static input and -// output shapes using Reshape, UnPack and Pack op. +// output shapes using Reshape, UnPack and ConcatV2 op. // // %indices0 = "tf.Const"() {value = dense<4> : tensor<i32>} // %indices1 = "tf.Const"() {value = dense<[[3, 2], [1, 0]]> : @@ -237,7 +237,7 @@ class LowerAddNOp : public RewritePattern { // : (tensor<4x2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, // tensor<2xf32>) // %axis = "tf.Const"() {value = dense<0> : tensor<i64>} -// %0 = "tf.Pack"(items1#3, items1#2, items1#1, items1#0, %items0, %axis) +// %0 = "tf.ConcatV2"(items1#3, items1#2, items1#1, items1#0, %items0, %axis) // : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, // tensor<2xf32>, tensor<i64>) -> tensor<5x2xf32> // @@ -303,7 +303,8 @@ class LowerDynamicStitchOp : public OpRewritePattern<TF::DynamicStitchOp> { } } - rewriter.replaceOpWithNewOp<PackOp>(op, op.getType(), values); + auto axis = rewriter.create<ConstOp>(loc, rewriter.getI64IntegerAttr(0)); + rewriter.replaceOpWithNewOp<ConcatV2Op>(op, op.getType(), values, axis); return success(); } };