diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir index 4e7c08945d4..186c8631e56 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir @@ -62,40 +62,6 @@ func @Conv2dNCHW(%arg0: tensor<256x3x32x32xf32>, %arg1: tensor<3x3x3x16xf32>) -> // LAYOUT: "tfl.conv_2d" } - -func @fusedBatchNorm(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) { -^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>): - // OK - %0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - // Unsupported training - %1:5 = "tf.FusedBatchNorm"( %0#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - // Use other output - %2:5 = "tf.FusedBatchNorm"( %1#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) - - return %2, %2#1 : tensor<8x8x8x8xf32>, tensor<8xf32> - -// CHECK-LABEL: fusedBatchNorm -// CHECK: %[[CONSTANT:.*]] = constant dense<1.000000e-03> -// variance + epsilon -// CHECK: %[[ADD1:.*]] = "tf.Add"(%[[ARG4:.*]], %[[CONSTANT]]) -// rsqrt(variance + epsilon) -// CHECK: %[[RSQRT:.*]] = 
"tf.Rsqrt"(%[[ADD1]]) -// scale * rsqrt(variance + epsilon) -// CHECK: %[[MUL1:.*]] = "tf.Mul"(%[[ARG1:.*]], %[[RSQRT]]) -// x * scale * rsqrt(variance + epsilon) -// CHECK: %[[MUL2:.*]] = "tf.Mul"(%[[ARG0:.*]], %[[MUL1]]) -// mean * scale * rsqrt(variance + epsilon) -// CHECK: %[[MUL3:.*]] = "tf.Mul"(%[[ARG3:.*]], %[[MUL1]]) -// offset - mean * scale * rsqrt(variance + epsilon) -// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[ARG2:.*]], %[[MUL3]]) -// x * scale * rsqrt(variance + epsilon) + -// offset - mean * scale * rsqrt(variance + epsilon) -// CHECK: %[[ADD2:.*]] = "tf.Add"(%[[MUL2]], %[[SUB]]) - -// CHECK: %[[BATCHNORM1_a:[^,]+]], {{.*}} = "tf.FusedBatchNormV3"(%[[ADD2]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) -// CHECK: "tf.FusedBatchNormV3"(%[[BATCHNORM1_a]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) -} - func @fusedBatchNormV3(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) { ^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>): // OK diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index ecca3d38deb..c4f30c22be3 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -740,31 +740,6 @@ struct ConvertTFBroadcastTo : public RewritePattern { } }; -struct ConvertFusedBatchNorm : public OpRewritePattern { - explicit ConvertFusedBatchNorm(MLIRContext *context) - : OpRewritePattern(context) {} - - LogicalResult matchAndRewrite(TF::FusedBatchNormOp tf_fused_batch_norm_op, - PatternRewriter &rewriter) const override { - auto new_result_types = - llvm::to_vector<6>(tf_fused_batch_norm_op.getResultTypes()); - // reserve_space_3 - new_result_types.push_back( - UnrankedTensorType::get(FloatType::getF32(rewriter.getContext()))); - - OperationState 
new_state(tf_fused_batch_norm_op.getLoc(), - TF::FusedBatchNormV3Op::getOperationName(), - tf_fused_batch_norm_op.getOperands(), - new_result_types, - tf_fused_batch_norm_op.getAttrs()); - Operation *tf_fused_batch_norm_op_v3 = rewriter.createOperation(new_state); - - rewriter.replaceOp(tf_fused_batch_norm_op, - tf_fused_batch_norm_op_v3->getResults().drop_back()); - return success(); - } -}; - // The below pattern is equivalent to the DRR rule below // The checks are dependent on generated values, so we can't add // the checks on intermediate values, ideally we should find equivalent @@ -1202,7 +1177,6 @@ void PrepareTFPass::runOnFunction() { patterns.insert, FusedBatchNormV3Pat, ConvertTFDilatedConvOp>(ctx); - patterns.insert(ctx); TFL::populateWithGenerated(ctx, patterns); // TODO(karimnosseir): Split to separate pass probably after // deciding on long term plan for this optimization. diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index aa087ec8c57..aa1b7bb81a9 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -3942,6 +3942,8 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors. 
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + let hasCanonicalizer = 1; + let verifier = [{ return Verify(*this); }]; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index d7a24271062..e9ccbed53db 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -2336,6 +2336,44 @@ void NonMaxSuppressionV3Op::getCanonicalizationPatterns( results.insert(context); } +//===----------------------------------------------------------------------===// +// FusedBatchNormOp +//===----------------------------------------------------------------------===// + +namespace { + +// Canonicalizes tf.FusedBatchNorm to tf.FusedBatchNormV3. The V3 op is built +// with the same operands and attributes plus one extra trailing result +// (reserve_space_3), which is dropped when the original op is replaced. +class ConvertFusedBatchNorm : public OpRewritePattern<TF::FusedBatchNormOp> { + using OpRewritePattern<TF::FusedBatchNormOp>::OpRewritePattern; + LogicalResult matchAndRewrite(TF::FusedBatchNormOp tf_fused_batch_norm_op, + PatternRewriter &rewriter) const override { + auto new_result_types = + llvm::to_vector<6>(tf_fused_batch_norm_op.getResultTypes()); + // reserve_space_3 + new_result_types.push_back( + UnrankedTensorType::get(FloatType::getF32(rewriter.getContext()))); + + OperationState new_state(tf_fused_batch_norm_op.getLoc(), + TF::FusedBatchNormV3Op::getOperationName(), + tf_fused_batch_norm_op.getOperands(), + new_result_types, + tf_fused_batch_norm_op.getAttrs()); + Operation *tf_fused_batch_norm_op_v3 = rewriter.createOperation(new_state); + + rewriter.replaceOp(tf_fused_batch_norm_op, + tf_fused_batch_norm_op_v3->getResults().drop_back()); + return success(); + } +}; +} // namespace. 
+ void FusedBatchNormOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert<ConvertFusedBatchNorm>(context); +} + //===----------------------------------------------------------------------===// // UnpackOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index ea9820d42e0..e77dd365abf 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -1284,3 +1284,10 @@ func @testNMSV3ToNMSV4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tens %0 = "tf.NonMaxSuppressionV3"(%arg0, %arg1, %max_size, %arg2, %arg3): (tensor<3x4xf32>, tensor<3xf32>, tensor, tensor, tensor) -> (tensor<2xi32>) return %0 : tensor<2xi32> } + +// CHECK-LABEL: testFusedBatchNormToBatchNormV3 +func @testFusedBatchNormToBatchNormV3(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK: "tf.FusedBatchNormV3" + %0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4): (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32> ) + return %0#0 : tensor<8x8x8x8xf32> +}