Use pack instead of concat
Pack works for scalars, while concat requires non-scalars. PiperOrigin-RevId: 338013269 Change-Id: I0815ac6a45806a14a69ee5f0961480f0d69dd708
This commit is contained in:
parent
72c19e8880
commit
495a9f4ac2
tensorflow/compiler/mlir/tensorflow
@ -515,11 +515,9 @@ func @addN_variant(%arg0: tensor<!tf.variant<tensor<2xf32>>>, %arg1: tensor<!tf.
|
||||
|
||||
// CHECK-LABEL: func @DynamicStitch_simple
|
||||
func @DynamicStitch_simple(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
|
||||
// CHECK: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2x2xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
// CHECK-DAG: %[[ITEMS_1:.*]] = "tf.ExpandDims"(%[[ITEMS]]#1, %[[AXIS]])
|
||||
// CHECK-DAG: %[[ITEMS_0:.*]] = "tf.ExpandDims"(%[[ITEMS]]#0, %[[AXIS]])
|
||||
// CHECK: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS_1]], %[[ITEMS_0]], %[[AXIS]]) : (tensor<1x2xf32>, tensor<1x2xf32>, tensor<i64>) -> tensor<2x2xf32>
|
||||
// CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2x2xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
// CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
|
||||
// CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[ITEMS]]#0, %[[AXIS]]) : (tensor<2xf32>, tensor<2xf32>, tensor<i64>) -> tensor<2x2xf32>
|
||||
// CHECK: return %[[RESULT]]
|
||||
|
||||
%indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
@ -535,12 +533,7 @@ func @DynamicStitch_scalar_matrix_indices(%arg0: tensor<2xf32>, %arg1: tensor<2x
|
||||
// CHECK-DAG: %[[INP1:.*]] = "tf.Reshape"(%arg1, %[[SHAPE]]) : (tensor<2x2x2xf32>, tensor<2xi64>) -> tensor<4x2xf32>
|
||||
// CHECK-DAG: %[[ITEMS1:.*]]:4 = "tf.Unpack"(%[[INP1]]) {axis = 0 : i64} : (tensor<4x2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>)
|
||||
// CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
|
||||
// CHECK-DAG: %[[ITEMS1_3:.*]] = "tf.ExpandDims"(%[[ITEMS1]]#3, %[[AXIS]])
|
||||
// CHECK-DAG: %[[ITEMS1_2:.*]] = "tf.ExpandDims"(%[[ITEMS1]]#2, %[[AXIS]])
|
||||
// CHECK-DAG: %[[ITEMS1_1:.*]] = "tf.ExpandDims"(%[[ITEMS1]]#1, %[[AXIS]])
|
||||
// CHECK-DAG: %[[ITEMS1_0:.*]] = "tf.ExpandDims"(%[[ITEMS1]]#0, %[[AXIS]])
|
||||
// CHECK-DAG: %[[ITEMS0_0:.*]] = "tf.ExpandDims"(%[[ITEMS0]], %[[AXIS]])
|
||||
// CHECK-DAG: "tf.ConcatV2"(%[[ITEMS1_3]], %[[ITEMS1_2]], %[[ITEMS1_1]], %[[ITEMS1_0]], %[[ITEMS0_0]], %[[AXIS]]) : (tensor<1x2xf32>, tensor<1x2xf32>, tensor<1x2xf32>, tensor<1x2xf32>, tensor<1x2xf32>, tensor<i64>) -> tensor<5x2xf32>
|
||||
// CHECK-DAG: %6 = "tf.ConcatV2"(%[[ITEMS1]]#3, %[[ITEMS1]]#2, %[[ITEMS1]]#1, %[[ITEMS1]]#0, %[[ITEMS0]], %[[AXIS]]) : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<i64>) -> tensor<5x2xf32>
|
||||
|
||||
%indices0 = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
|
||||
%indices1 = "tf.Const"() {value = dense<[[3, 2], [1, 0]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
|
||||
@ -562,9 +555,7 @@ func @DynamicStitch_uint8(%arg0: tensor<2x2xui8>) -> tensor<2x2xui8> {
|
||||
func @DynamicStitch_scalar_item(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// CHECK-DAG: %[[ITEMS]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2xf32>) -> (tensor<f32>, tensor<f32>)
|
||||
// CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
|
||||
// CHECK-DAG: %[[ITEMS_1:.*]] = "tf.ExpandDims"(%[[ITEMS]]#1, %[[AXIS]])
|
||||
// CHECK-DAG: %[[ITEMS_0:.*]] = "tf.ExpandDims"(%[[ITEMS]]#0, %[[AXIS]])
|
||||
// CHECK-DAG: %[[RESULT]] = "tf.ConcatV2"(%[[ITEMS_1]], %[[ITEMS_0]], %[[AXIS]]) : (tensor<1xf32>, tensor<1xf32>, tensor<i64>) -> tensor<2xf32>
|
||||
// CHECK-DAG: %[[RESULT]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[ITEMS]]#0, %[[AXIS]]) : (tensor<f32>, tensor<f32>, tensor<i64>) -> tensor<2xf32>
|
||||
// CHECK: return %[[RESULT]]
|
||||
|
||||
%indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
@ -576,9 +567,7 @@ func @DynamicStitch_scalar_item(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
func @DynamicStitch_matrix_item(%arg0: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> {
|
||||
// CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2x2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>)
|
||||
// CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
|
||||
// CHECK-DAG: %[[ITEMS_1:.*]] = "tf.ExpandDims"(%[[ITEMS]]#1, %[[AXIS]])
|
||||
// CHECK-DAG: %[[ITEMS_0:.*]] = "tf.ExpandDims"(%[[ITEMS]]#0, %[[AXIS]])
|
||||
// CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS_1]], %[[ITEMS_0]], %[[AXIS]]) : (tensor<1x2x2xf32>, tensor<1x2x2xf32>, tensor<i64>) -> tensor<2x2x2xf32>
|
||||
// CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[ITEMS]]#0, %[[AXIS]]) : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<i64>) -> tensor<2x2x2xf32>
|
||||
// CHECK: return %[[RESULT]]
|
||||
|
||||
%indices = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
@ -597,8 +586,7 @@ func @DynamicStitch_dynamic(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tenso
|
||||
func @DynamicStitch_duplicates(%arg0: tensor<2x2xf32>) -> tensor<1x2xf32> {
|
||||
// CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2x2xf32>) -> (tensor<2xf32>, tensor<2xf32>)
|
||||
// CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
|
||||
// CHECK-DAG: %[[ITEMS_1:.*]] = "tf.ExpandDims"(%[[ITEMS]]#1, %[[AXIS]])
|
||||
// CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS_1]], %[[AXIS]]) : (tensor<1x2xf32>, tensor<i64>) -> tensor<1x2xf32>
|
||||
// CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[AXIS]]) : (tensor<2xf32>, tensor<i64>) -> tensor<1x2xf32>
|
||||
// CHECK: return %[[RESULT]]
|
||||
|
||||
%indices = "tf.Const"() {value = dense<[0, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
|
@ -215,7 +215,7 @@ class LowerAddNOp : public RewritePattern {
|
||||
};
|
||||
|
||||
// Lowers DynamicStitch op with constant indices and with static input and
|
||||
// output shapes using Reshape, UnPack and Pack op.
|
||||
// output shapes using Reshape, UnPack and ConcatV2 op.
|
||||
//
|
||||
// %indices0 = "tf.Const"() {value = dense<4> : tensor<i32>}
|
||||
// %indices1 = "tf.Const"() {value = dense<[[3, 2], [1, 0]]> :
|
||||
@ -237,7 +237,7 @@ class LowerAddNOp : public RewritePattern {
|
||||
// : (tensor<4x2xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
|
||||
// tensor<2xf32>)
|
||||
// %axis = "tf.Const"() {value = dense<0> : tensor<i64>}
|
||||
// %0 = "tf.Pack"(items1#3, items1#2, items1#1, items1#0, %items0, %axis)
|
||||
// %0 = "tf.ConcatV2"(items1#3, items1#2, items1#1, items1#0, %items0, %axis)
|
||||
// : (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>,
|
||||
// tensor<2xf32>, tensor<i64>) -> tensor<5x2xf32>
|
||||
//
|
||||
@ -303,7 +303,8 @@ class LowerDynamicStitchOp : public OpRewritePattern<TF::DynamicStitchOp> {
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<PackOp>(op, op.getType(), values);
|
||||
auto axis = rewriter.create<ConstOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||
rewriter.replaceOpWithNewOp<ConcatV2Op>(op, op.getType(), values, axis);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user