Canonicalize FusedBatchNorm op to FusedBatchNormV3
PiperOrigin-RevId: 337394368 Change-Id: I7bd7f0513815e1b27584a1b6cba7c447a9d9c9a2
This commit is contained in:
parent
4ad4b488c9
commit
cb83c8469e
@ -62,40 +62,6 @@ func @Conv2dNCHW(%arg0: tensor<256x3x32x32xf32>, %arg1: tensor<3x3x3x16xf32>) ->
|
||||
// LAYOUT: "tfl.conv_2d"
|
||||
}
|
||||
|
||||
|
||||
func @fusedBatchNorm(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) {
|
||||
^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>):
|
||||
// OK
|
||||
%0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
|
||||
// Unsupported training
|
||||
%1:5 = "tf.FusedBatchNorm"( %0#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
|
||||
// Use other output
|
||||
%2:5 = "tf.FusedBatchNorm"( %1#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
|
||||
|
||||
return %2, %2#1 : tensor<8x8x8x8xf32>, tensor<8xf32>
|
||||
|
||||
// CHECK-LABEL: fusedBatchNorm
|
||||
// CHECK: %[[CONSTANT:.*]] = constant dense<1.000000e-03>
|
||||
// variance + epsilon
|
||||
// CHECK: %[[ADD1:.*]] = "tf.Add"(%[[ARG4:.*]], %[[CONSTANT]])
|
||||
// rsqrt(variance + epsilon)
|
||||
// CHECK: %[[RSQRT:.*]] = "tf.Rsqrt"(%[[ADD1]])
|
||||
// scale * rsqrt(variance + epsilon)
|
||||
// CHECK: %[[MUL1:.*]] = "tf.Mul"(%[[ARG1:.*]], %[[RSQRT]])
|
||||
// x * scale * rsqrt(variance + epsilon)
|
||||
// CHECK: %[[MUL2:.*]] = "tf.Mul"(%[[ARG0:.*]], %[[MUL1]])
|
||||
// mean * scale * rsqrt(variance + epsilon)
|
||||
// CHECK: %[[MUL3:.*]] = "tf.Mul"(%[[ARG3:.*]], %[[MUL1]])
|
||||
// offset - mean * scale * rsqrt(variance + epsilon)
|
||||
// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[ARG2:.*]], %[[MUL3]])
|
||||
// x * scale * rsqrt(variance + epsilon) +
|
||||
// offset - mean * scale * rsqrt(variance + epsilon)
|
||||
// CHECK: %[[ADD2:.*]] = "tf.Add"(%[[MUL2]], %[[SUB]])
|
||||
|
||||
// CHECK: %[[BATCHNORM1_a:[^,]+]], {{.*}} = "tf.FusedBatchNormV3"(%[[ADD2]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
|
||||
// CHECK: "tf.FusedBatchNormV3"(%[[BATCHNORM1_a]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
|
||||
}
|
||||
|
||||
func @fusedBatchNormV3(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) {
|
||||
^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>):
|
||||
// OK
|
||||
|
@ -740,31 +740,6 @@ struct ConvertTFBroadcastTo : public RewritePattern {
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertFusedBatchNorm : public OpRewritePattern<TF::FusedBatchNormOp> {
|
||||
explicit ConvertFusedBatchNorm(MLIRContext *context)
|
||||
: OpRewritePattern<TF::FusedBatchNormOp>(context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(TF::FusedBatchNormOp tf_fused_batch_norm_op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto new_result_types =
|
||||
llvm::to_vector<6>(tf_fused_batch_norm_op.getResultTypes());
|
||||
// reserve_space_3
|
||||
new_result_types.push_back(
|
||||
UnrankedTensorType::get(FloatType::getF32(rewriter.getContext())));
|
||||
|
||||
OperationState new_state(tf_fused_batch_norm_op.getLoc(),
|
||||
TF::FusedBatchNormV3Op::getOperationName(),
|
||||
tf_fused_batch_norm_op.getOperands(),
|
||||
new_result_types,
|
||||
tf_fused_batch_norm_op.getAttrs());
|
||||
Operation *tf_fused_batch_norm_op_v3 = rewriter.createOperation(new_state);
|
||||
|
||||
rewriter.replaceOp(tf_fused_batch_norm_op,
|
||||
tf_fused_batch_norm_op_v3->getResults().drop_back());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// The below pattern is equivalent to the DRR rule below
|
||||
// The checks are dependent on generated values, so we can't add
|
||||
// the checks on intermediate values, ideally we should find equivalent
|
||||
@ -1202,7 +1177,6 @@ void PrepareTFPass::runOnFunction() {
|
||||
patterns.insert<ConvertTFDilatedConvOp<TF::Conv2DOp>, FusedBatchNormV3Pat,
|
||||
ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(ctx);
|
||||
|
||||
patterns.insert<ConvertFusedBatchNorm>(ctx);
|
||||
TFL::populateWithGenerated(ctx, patterns);
|
||||
// TODO(karimnosseir): Split to separate pass probably after
|
||||
// deciding on long term plan for this optimization.
|
||||
|
@ -3942,6 +3942,8 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
|
@ -2336,6 +2336,41 @@ void NonMaxSuppressionV3Op::getCanonicalizationPatterns(
|
||||
results.insert<NMSV3ToNMSV4Op>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FusedBatchNormOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
|
||||
class ConvertFusedBatchNorm : public OpRewritePattern<TF::FusedBatchNormOp> {
|
||||
using OpRewritePattern<FusedBatchNormOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(TF::FusedBatchNormOp tf_fused_batch_norm_op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto new_result_types =
|
||||
llvm::to_vector<6>(tf_fused_batch_norm_op.getResultTypes());
|
||||
// reserve_space_3
|
||||
new_result_types.push_back(
|
||||
UnrankedTensorType::get(FloatType::getF32(rewriter.getContext())));
|
||||
|
||||
OperationState new_state(tf_fused_batch_norm_op.getLoc(),
|
||||
TF::FusedBatchNormV3Op::getOperationName(),
|
||||
tf_fused_batch_norm_op.getOperands(),
|
||||
new_result_types,
|
||||
tf_fused_batch_norm_op.getAttrs());
|
||||
Operation *tf_fused_batch_norm_op_v3 = rewriter.createOperation(new_state);
|
||||
|
||||
rewriter.replaceOp(tf_fused_batch_norm_op,
|
||||
tf_fused_batch_norm_op_v3->getResults().drop_back());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace.
|
||||
|
||||
void FusedBatchNormOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<ConvertFusedBatchNorm>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// UnpackOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1284,3 +1284,10 @@ func @testNMSV3ToNMSV4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tens
|
||||
%0 = "tf.NonMaxSuppressionV3"(%arg0, %arg1, %max_size, %arg2, %arg3): (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>)
|
||||
return %0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testFusedBatchNormToBatchNormV3
|
||||
func @testFusedBatchNormToBatchNormV3(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
|
||||
// CHECK: "tf.FusedBatchNormV3"
|
||||
%0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4): (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32> )
|
||||
return %0#0 : tensor<8x8x8x8xf32>
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user