diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 26a55c61989..1ae789f5468 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -1021,24 +1021,6 @@ func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> t // CHECK: "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32> } -func @concat2Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>) -> tensor<2x2xi32> { - %0 = "tf.Const"() { value = dense<1> : tensor } : () -> tensor - %1 = "tf.Concat"(%0, %arg0, %arg1) : (tensor, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32> - return %1 : tensor<2x2xi32> - -// CHECK-LABEL: concat2Tensors -// CHECK: "tfl.concatenation"(%arg0, %arg1) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32> -} - -func @concat3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> { - %0 = "tf.Const"() { value = dense<-1> : tensor } : () -> tensor - %1 = "tf.Concat"(%0, %arg0, %arg1, %arg2) : (tensor, tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32> - return %1 : tensor<2x3xi32> - -// CHECK-LABEL: concat3Tensors -// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32> -} - func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> { %0 = "tf.Const"() { value = dense<-1> : tensor } : () -> tensor %1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor) -> tensor<2x3xi32> diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index bfcbc190638..953f24c3b45 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -123,7 +123,6 @@ bool HasSameStaticShapes(Operation* op) { // operands are properly supported in declarative rewrite rule specification. DECL_CONVERT_OP(Assert); -DECL_CONVERT_OP(Concat); DECL_CONVERT_OP(ConcatV2); DECL_CONVERT_OP(MatMul); DECL_CONVERT_OP(MatrixDiagV2); @@ -184,25 +183,6 @@ LogicalResult ConvertTFRandomUniformOp::matchAndRewrite( return failure(); } -LogicalResult ConvertTFConcatOp::matchAndRewrite( - Operation* op, PatternRewriter& rewriter) const { - auto tf_concat_op = cast(op); - - auto values = tf_concat_op.values(); - auto output_type = tf_concat_op.output().getType(); - // Extract axis attribute from constant concat_dims tensor - ElementsAttr axis; - if (!matchPattern(tf_concat_op.concat_dim(), m_Constant(&axis))) - return failure(); - - StringAttr fused_activation_function = - StringAttr::get("NONE", rewriter.getContext()); - rewriter.replaceOpWithNewOp( - op, output_type, values, mlir::TFL::ExtractSingleElementAsInteger(axis), - fused_activation_function); - return success(); -} - // Converts any IntegerAttr to an IntegerAttr of an i32 type. // The value won't change in the new attribute, but if the value is out of // the bound of i32, the function returns a failure. @@ -756,12 +736,12 @@ void LegalizeTF::runOnFunction() { // Add the generated patterns to the list. populateWithGenerated(context, &patterns); - patterns.insert(context); + patterns + .insert(context); // Ophint python converter converted tf node pattern. patterns.insert { let verifier = [{ return Verify(*this); }]; + + let hasCanonicalizer = 1; } def TF_ConcatOffsetOp : TF_Op<"ConcatOffset", [NoSideEffect]> { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index a96853b31e9..512011c3a0f 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -873,6 +873,11 @@ static LogicalResult Verify(OpT op) { /*mask_one_dim=*/true, op.getOperation()); } +void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // ConcatOffsetOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 61b9ed64035..a127168157f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -133,6 +133,16 @@ func @testSameCastTypeAcrossBasicBlocks(tensor<8x16x32x64xf32>) -> tensor<8x16x3 // CHECK: return %arg0 } +// CHECK-LABEL: testConcatCanonicalization +func @testConcatCanonicalization(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>) -> tensor<2x2xi32> { + // CHECK: %[[AXIS:.*]] = "tf.Const" + %0 = "tf.Const"() { value = dense<1> : tensor } : () -> tensor + + // CHECK: "tf.ConcatV2"(%arg0, %arg1, %[[AXIS]]) + %1 = "tf.Concat"(%0, %arg0, %arg1) : (tensor, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32> + return %1 : tensor<2x2xi32> +} + // CHECK-LABEL: testLogOfSoftmax func @testLogOfSoftmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { %0 = "tf.Softmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td index 43661902138..041d1c6cdaf 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td @@ -83,6 +83,13 @@ def BitcastSameType : Pat<(TF_BitcastOp:$res $arg), (replaceWithValue $arg), def BitcastNested : Pat<(TF_BitcastOp (TF_BitcastOp $arg)), (TF_BitcastOp $arg)>; +//===----------------------------------------------------------------------===// +// Concat op patterns. +//===----------------------------------------------------------------------===// + +def ConvertToConcatV2 : Pat<(TF_ConcatOp $axis, $inputs), + (TF_ConcatV2Op $inputs, $axis)>; + //===----------------------------------------------------------------------===// // Conj op patterns. //===----------------------------------------------------------------------===//