Replace BiasAdd op with AddV2 op in BiasAdd and Mul op rewrite pattern
This triggered a test failure after adding BiasAdd verifier to have 1D shape constraint on bias operand. PiperOrigin-RevId: 273226065
This commit is contained in:
parent
168b3b88c3
commit
e0a97f34c2
@ -13,7 +13,7 @@ func @convbiasaddmul(%arg: tensor<256x32x32x3xf32>) -> tensor<256x30x30x16xf32>
|
||||
// CHECK-NEXT: %[[cst:.*]] = "tf.Const{{.*}} dense<8.000000e+00> : tensor<3x3x3x16xf32>
|
||||
// CHECK-NEXT: %[[cst_0:.*]] = "tf.Const{{.*}} dense<1.200000e+01> : tensor<16xf32>
|
||||
// CHECK-NEXT: %[[conv:.*]] = "tf.Conv2D"(%arg0, %[[cst]])
|
||||
// CHECK-NEXT: %[[bias:.*]] = "tf.BiasAdd"(%[[conv]], %[[cst_0]])
|
||||
// CHECK-NEXT: %[[bias:.*]] = "tf.AddV2"(%[[conv]], %[[cst_0]])
|
||||
// CHECK-NEXT: return %[[bias]] : tensor<256x30x30x16xf32>
|
||||
}
|
||||
|
||||
|
@ -46,27 +46,30 @@ def FuseMulAndConv2D : Pat<(TF_MulOp (TF_Conv2DOp $input,
|
||||
// This rule does the following pattern match and rewrite:
|
||||
//
|
||||
// input bias input value bias value
|
||||
// | / => \ / \ /
|
||||
// \ / => \ / \ /
|
||||
// BiasAdd value Mul Mul
|
||||
// \ / \ /
|
||||
// Mul BiasAdd
|
||||
// Mul AddV2
|
||||
// This is to enable the FuseMulAndConv2D pattern.
|
||||
// Here, root of the result is AddV2 instead of BiasAdd because the value may
|
||||
// not have rank one and therefore the second operand may not have rank one
|
||||
// that is required by the BiasAdd. BiasAdd with 'NHWC' data format equivalent
|
||||
// to AddV2 op.
|
||||
def PassthroughMulAndBiasAdd :
|
||||
Pat<(TF_MulOp
|
||||
(TF_BiasAddOp $input,
|
||||
(ConstantOp F32ElementsAttr:$bias), IsDataFormatNHWC:$same_format),
|
||||
(ConstantOp F32ElementsAttr:$bias), IsDataFormatNHWC:$format),
|
||||
(ConstantOp F32ElementsAttr:$value)),
|
||||
(TF_BiasAddOp
|
||||
(TF_AddV2Op
|
||||
(TF_MulOp $input, (ConstantOp $value)),
|
||||
(TF_MulOp (ConstantOp $bias), (ConstantOp $value)),
|
||||
$same_format),
|
||||
(TF_MulOp (ConstantOp $bias), (ConstantOp $value))),
|
||||
[(DefinedByConv2D $input)]>;
|
||||
|
||||
|
||||
// This rule does the following pattern match and rewrite:
|
||||
//
|
||||
// input bias input value bias value
|
||||
// | / => \ / \ /
|
||||
// \ / => \ / \ /
|
||||
// AddV2 value Mul Mul
|
||||
// \ / \ /
|
||||
// Mul AddV2
|
||||
|
Loading…
x
Reference in New Issue
Block a user