From e0a97f34c2858b7f0c7bf987501b42cfaca2275e Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Sun, 6 Oct 2019 23:31:24 -0700 Subject: [PATCH] Replace BiasAdd op with AddV2 op in BiasAdd and Mul op rewrite pattern This triggered a test failure after adding BiasAdd verifier to have 1D shape constraint on bias operand. PiperOrigin-RevId: 273226065 --- .../mlir/tensorflow/tests/optimize.mlir | 2 +- .../mlir/tensorflow/transforms/optimize.td | 17 ++++++++++------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/optimize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/optimize.mlir index ad43a3502ab..fc286fbd640 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/optimize.mlir @@ -13,7 +13,7 @@ func @convbiasaddmul(%arg: tensor<256x32x32x3xf32>) -> tensor<256x30x30x16xf32> // CHECK-NEXT: %[[cst:.*]] = "tf.Const{{.*}} dense<8.000000e+00> : tensor<3x3x3x16xf32> // CHECK-NEXT: %[[cst_0:.*]] = "tf.Const{{.*}} dense<1.200000e+01> : tensor<16xf32> // CHECK-NEXT: %[[conv:.*]] = "tf.Conv2D"(%arg0, %[[cst]]) -// CHECK-NEXT: %[[bias:.*]] = "tf.BiasAdd"(%[[conv]], %[[cst_0]]) +// CHECK-NEXT: %[[bias:.*]] = "tf.AddV2"(%[[conv]], %[[cst_0]]) // CHECK-NEXT: return %[[bias]] : tensor<256x30x30x16xf32> } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td index 49793f43cf3..55038b13ab4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td @@ -46,27 +46,30 @@ def FuseMulAndConv2D : Pat<(TF_MulOp (TF_Conv2DOp $input, // This rule does the following pattern match and rewrite: // // input bias input value bias value -// | / => \ / \ / +// \ / => \ / \ / // BiasAdd value Mul Mul // \ / \ / -// Mul BiasAdd +// Mul AddV2 // This is to enable the FuseMulAndConv2D pattern. +// Here, root of the result is AddV2 instead of BiasAdd because the value may +// not have rank one and therefore the second operand may not have rank one +// that is required by the BiasAdd. BiasAdd with 'NHWC' data format equivalent +// to AddV2 op. def PassthroughMulAndBiasAdd : Pat<(TF_MulOp (TF_BiasAddOp $input, - (ConstantOp F32ElementsAttr:$bias), IsDataFormatNHWC:$same_format), + (ConstantOp F32ElementsAttr:$bias), IsDataFormatNHWC:$format), (ConstantOp F32ElementsAttr:$value)), - (TF_BiasAddOp + (TF_AddV2Op (TF_MulOp $input, (ConstantOp $value)), - (TF_MulOp (ConstantOp $bias), (ConstantOp $value)), - $same_format), + (TF_MulOp (ConstantOp $bias), (ConstantOp $value))), [(DefinedByConv2D $input)]>; // This rule does the following pattern match and rewrite: // // input bias input value bias value -// | / => \ / \ / +// \ / => \ / \ / // AddV2 value Mul Mul // \ / \ / // Mul AddV2