Replace BiasAdd op with AddV2 op in BiasAdd and Mul op rewrite pattern

This triggered a test failure after adding BiasAdd verifier to have 1D shape constraint on bias operand.

PiperOrigin-RevId: 273226065
This commit is contained in:
Smit Hinsu 2019-10-06 23:31:24 -07:00 committed by TensorFlower Gardener
parent 168b3b88c3
commit e0a97f34c2
2 changed files with 11 additions and 8 deletions

View File

@ -13,7 +13,7 @@ func @convbiasaddmul(%arg: tensor<256x32x32x3xf32>) -> tensor<256x30x30x16xf32>
// CHECK-NEXT: %[[cst:.*]] = "tf.Const{{.*}} dense<8.000000e+00> : tensor<3x3x3x16xf32>
// CHECK-NEXT: %[[cst_0:.*]] = "tf.Const{{.*}} dense<1.200000e+01> : tensor<16xf32>
// CHECK-NEXT: %[[conv:.*]] = "tf.Conv2D"(%arg0, %[[cst]])
// CHECK-NEXT: %[[bias:.*]] = "tf.BiasAdd"(%[[conv]], %[[cst_0]])
// CHECK-NEXT: %[[bias:.*]] = "tf.AddV2"(%[[conv]], %[[cst_0]])
// CHECK-NEXT: return %[[bias]] : tensor<256x30x30x16xf32>
}

View File

@ -46,27 +46,30 @@ def FuseMulAndConv2D : Pat<(TF_MulOp (TF_Conv2DOp $input,
// This rule does the following pattern match and rewrite:
//
// input bias input value bias value
// | / => \ / \ /
// \ / => \ / \ /
// BiasAdd value Mul Mul
// \ / \ /
// Mul BiasAdd
// Mul AddV2
// This is to enable the FuseMulAndConv2D pattern.
// Here, root of the result is AddV2 instead of BiasAdd because the value may
// not have rank one and therefore the second operand may not have rank one
// that is required by the BiasAdd. BiasAdd with 'NHWC' data format equivalent
// to AddV2 op.
def PassthroughMulAndBiasAdd :
Pat<(TF_MulOp
(TF_BiasAddOp $input,
(ConstantOp F32ElementsAttr:$bias), IsDataFormatNHWC:$same_format),
(ConstantOp F32ElementsAttr:$bias), IsDataFormatNHWC:$format),
(ConstantOp F32ElementsAttr:$value)),
(TF_BiasAddOp
(TF_AddV2Op
(TF_MulOp $input, (ConstantOp $value)),
(TF_MulOp (ConstantOp $bias), (ConstantOp $value)),
$same_format),
(TF_MulOp (ConstantOp $bias), (ConstantOp $value))),
[(DefinedByConv2D $input)]>;
// This rule does the following pattern match and rewrite:
//
// input bias input value bias value
// | / => \ / \ /
// \ / => \ / \ /
// AddV2 value Mul Mul
// \ / \ /
// Mul AddV2