Remove pattern for FusedBatchNormV3 and add the generated C++ version with added extra conditions about broadcastability.
PiperOrigin-RevId: 336802159
Change-Id: Ib0ce51f6df8f9eba4d3e4d8dce67df8d82a1734a
parent 66c99931b6
commit 78ba66c122
@@ -666,4 +666,11 @@ func @xla_gather_to_slice(%arg0 : tensor<1x9x104x768xf32>) -> tensor<*xf32> {
   // CHECK: return %[[V0]] : tensor<*xf32>
 }
 
+// CHECK-LABEL: DontMatchFusedBatchNormV3
+func @DontMatchFusedBatchNormV3(%arg0 : tensor<?x576x1x1xf32>, %arg1 : tensor<576xf32>, %arg2 : tensor<576xf32>, %arg3 : tensor<576xf32>, %arg4 : tensor<576xf32>) -> (tensor<?x576x1x1xf32>) {
+  %result:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {data_format = "NHWC", device = "", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = false} : (tensor<?x576x1x1xf32>, tensor<576xf32>, tensor<576xf32>, tensor<576xf32>, tensor<576xf32>) -> (tensor<?x576x1x1xf32>, tensor<576xf32>, tensor<576xf32>, tensor<576xf32>, tensor<576xf32>, tensor<*xf32>)
+  return %result#0 : tensor<?x576x1x1xf32>
+  // CHECK: "tf.FusedBatchNormV3"
+}
+
 }
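A note on why this test must not match: `scale` (and hence the multiplier) has shape 576, while the op's input and first result have shape ?x576x1x1. The two shapes are pairwise broadcastable, but the broadcasted shape ?x576x1x576 no longer matches the declared result type, which is exactly what the new broadcastability conditions catch. A minimal, self-contained sketch of the right-aligned broadcast rule (illustrative only; BroadcastShapes is a hypothetical helper, not code from this commit):

// broadcast_demo.cc - NumPy-style broadcast rule, for intuition only.
#include <algorithm>
#include <cstdio>
#include <vector>

// Right-aligned broadcasting: missing dims count as 1; a 1 stretches to the
// other size; anything else must match. -1 marks a dynamic dimension.
bool BroadcastShapes(const std::vector<long>& a, const std::vector<long>& b,
                     std::vector<long>* out) {
  size_t rank = std::max(a.size(), b.size());
  out->assign(rank, 1);
  for (size_t i = 0; i < rank; ++i) {
    long da = i < a.size() ? a[a.size() - 1 - i] : 1;
    long db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1) return false;  // incompatible dims
    (*out)[rank - 1 - i] = da == 1 ? db : da;
  }
  return true;
}

int main() {
  std::vector<long> scale = {576};        // multiplier operand
  std::vector<long> x = {-1, 576, 1, 1};  // input / declared result shape
  std::vector<long> result;
  if (BroadcastShapes(scale, x, &result)) {
    for (long d : result) std::printf("%ld ", d);  // prints: -1 576 1 576
    std::printf("\n");  // != -1x576x1x1, so the pattern must not fire
  }
  return 0;
}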
@@ -40,42 +40,6 @@ def : Pat<
   (TF_MulOp $t, (TF_MulOp:$mul (TF_RsqrtOp (TF_AddOp $v, (TF_ConstOp $variance_epsilon))), $gamma)),
   (TF_SubOp $beta, (TF_MulOp $m, $mul)))>;
 
-// Converts tf.FusedBatchNormV3 into a sequence of more primitive arithmetic
-// operations. Specifically, it performs the following calculation:
-//
-//   (x - mean) * scale / sqrt(variance + epsilon) + offset
-//
-// Let multiplier = scale / sqrt(variance + epsilon); computing
-//   (x - mean) * scale / sqrt(variance + epsilon) + offset
-// then reduces to computing
-//   (x * multiplier) + (offset - mean * multiplier).
-def : Pattern<
-    (TF_FusedBatchNormV3Op:$root
-        $x, $scale, $offset, $mean, $variance,
-        F32Attr:$epsilon, $exponential_avg_factor,
-        $data_format, FalseBoolAttr:$is_training),
-    [(TF_AddOp
-        (TF_MulOp
-            $x,
-            (TF_MulOp:$multiplier
-                $scale,
-                (TF_RsqrtOp
-                    (TF_AddOp $variance,
-                              (TF_ConstOp $epsilon))))),
-        (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))),
-     // We already guaranteed that the last five results have no use, so it
-     // does not matter what value we provide here for replacement.
-     /*batch_mean=*/(replaceWithValue $x),
-     /*batch_variance=*/(replaceWithValue $x),
-     /*reserve_space_1=*/(replaceWithValue $x),
-     /*reserve_space_2=*/(replaceWithValue $x),
-     /*reserve_space_3=*/(replaceWithValue $x)],
-    [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
-     (HasNoUseOf:$root__3), (HasNoUseOf:$root__4),
-     (HasNoUseOf:$root__5)]>;
-
 class TFi32<int v> : ConstantAttr<I32ElementsAttr, !cast<string>(v)>;
 
 // Matmul without transpose on b to matmul with explicit transpose op and
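Written out, the algebra behind the removed rule (which the C++ replacement below preserves):

\begin{align*}
y &= \frac{(x - \mu)\,\gamma}{\sqrt{\sigma^{2} + \epsilon}} + \beta
    && \text{($\gamma$ = scale, $\beta$ = offset, $\mu$ = mean, $\sigma^{2}$ = variance)} \\
m &= \frac{\gamma}{\sqrt{\sigma^{2} + \epsilon}}
    && \text{(the multiplier; constant once the graph is frozen)} \\
y &= x \cdot m + (\beta - \mu \cdot m)
    && \text{(one Mul on $x$; everything else is constant-foldable)}
\end{align*}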
@@ -765,6 +765,278 @@ struct ConvertFusedBatchNorm : public OpRewritePattern<TF::FusedBatchNormOp> {
   }
 };
 
+// The pattern below is equivalent to the DRR rule quoted in the comment.
+// The FileCheck checks depend on generated values, so we can't add checks
+// on intermediate values; ideally we should find equivalent checks that
+// guarantee the resultant ops are valid.
+// The extra conditions are the broadcastability conditions.
+//
+// The pattern lowers FusedBatchNormV3 to arithmetic ops.
+// Specifically, it performs the following calculation:
+//
+//   (x - mean) * scale / sqrt(variance + epsilon) + offset
+//
+// Let multiplier = scale / sqrt(variance + epsilon); computing
+//   (x - mean) * scale / sqrt(variance + epsilon) + offset
+// then reduces to computing
+//   (x * multiplier) + (offset - mean * multiplier).
+//
+// def : Pattern<
+//     (TF_FusedBatchNormV3Op:$root
+//         $x, $scale, $offset, $mean, $variance,
+//         F32Attr:$epsilon, $exponential_avg_factor,
+//         $data_format, FalseBoolAttr:$is_training),
+//     [(TF_AddOp
+//         (TF_MulOp
+//             $x,
+//             (TF_MulOp:$multiplier
+//                 $scale,
+//                 (TF_RsqrtOp
+//                     (TF_AddOp $variance,
+//                               (TF_ConstOp $epsilon))))),
+//         (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))),
+//      // We already guaranteed that the last five results have no use, so it
+//      // does not matter what value we provide here for replacement.
+//      /*batch_mean=*/(replaceWithValue $x),
+//      /*batch_variance=*/(replaceWithValue $x),
+//      /*reserve_space_1=*/(replaceWithValue $x),
+//      /*reserve_space_2=*/(replaceWithValue $x),
+//      /*reserve_space_3=*/(replaceWithValue $x)],
+//     [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
+//      (HasNoUseOf:$root__3), (HasNoUseOf:$root__4),
+//      (HasNoUseOf:$root__5), (AreBroadcastableTypes $multiplier, $x)]>;
+
+struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
+  explicit FusedBatchNormV3Pat(::mlir::MLIRContext *context)
+      : ::mlir::RewritePattern(
+            "tf.FusedBatchNormV3",
+            {"tf.Add", "tf.Const", "tf.Mul", "tf.Rsqrt", "tf.Sub"}, 1,
+            context) {}
+
+  ::mlir::LogicalResult matchAndRewrite(
+      ::mlir::Operation *fused_batch_norm,
+      ::mlir::PatternRewriter &rewriter) const override {
+    // Variables for capturing the values and attributes used when creating
+    // the replacement ops.
+    Operation::operand_range mean(fused_batch_norm->getOperands());
+    ::mlir::FloatAttr exponential_avg_factor;
+    ::mlir::StringAttr data_format;
+    ::mlir::TF::FusedBatchNormV3Op root;
+    Operation::operand_range offset(fused_batch_norm->getOperands());
+    Operation::operand_range x(fused_batch_norm->getOperands());
+    Operation::operand_range scale(fused_batch_norm->getOperands());
+    Operation::operand_range variance(fused_batch_norm->getOperands());
+    ::mlir::FloatAttr epsilon;
+    ::mlir::BoolAttr is_training;
+
+    // Match.
+    auto fused_batch_norm_op =
+        dyn_cast_or_null<::mlir::TF::FusedBatchNormV3Op>(fused_batch_norm);
+    root = fused_batch_norm_op;
+    x = fused_batch_norm_op.getODSOperands(0);
+    scale = fused_batch_norm_op.getODSOperands(1);
+    offset = fused_batch_norm_op.getODSOperands(2);
+    mean = fused_batch_norm_op.getODSOperands(3);
+    variance = fused_batch_norm_op.getODSOperands(4);
+    {
+      epsilon = fused_batch_norm_op.getAttrOfType<::mlir::FloatAttr>("epsilon");
+      if (!epsilon)
+        epsilon = rewriter.getFloatAttr(rewriter.getF32Type(), 0.0001f);
+
+      if (!(((epsilon.isa<::mlir::FloatAttr>())) &&
+            ((epsilon.cast<::mlir::FloatAttr>().getType().isF32())))) {
+        return rewriter.notifyMatchFailure(
+            fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
+              diag << "op 'tf.FusedBatchNormV3' attribute 'epsilon' failed to "
+                      "satisfy constraint: 32-bit float attribute";
+            });
+      }
+    }
+    {
+      exponential_avg_factor =
+          fused_batch_norm_op.getAttrOfType<::mlir::FloatAttr>(
+              "exponential_avg_factor");
+      if (!exponential_avg_factor)
+        exponential_avg_factor =
+            rewriter.getFloatAttr(rewriter.getF32Type(), 1.0f);
+    }
+    {
+      data_format =
+          fused_batch_norm_op.getAttrOfType<::mlir::StringAttr>("data_format");
+      if (!data_format) data_format = rewriter.getStringAttr("NHWC");
+    }
+    {
+      is_training =
+          fused_batch_norm_op.getAttrOfType<::mlir::BoolAttr>("is_training");
+      if (!is_training) is_training = rewriter.getBoolAttr(true);
+
+      if (!((!is_training.getValue()))) {
+        return rewriter.notifyMatchFailure(
+            fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
+              diag << "op 'tf.FusedBatchNormV3' attribute 'is_training' "
+                      "failed to satisfy constraint: FalseBoolAttr";
+            });
+      }
+    }
+
+    if (!(((*root.getODSResults(1).begin()).use_empty()))) {
+      return rewriter.notifyMatchFailure(
+          fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
+            diag << "entities '' failed to satisfy constraint: has no use";
+          });
+    }
+
+    if (!(((*root.getODSResults(2).begin()).use_empty()))) {
+      return rewriter.notifyMatchFailure(
+          fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
+            diag << "entities '' failed to satisfy constraint: has no use";
+          });
+    }
+
+    if (!(((*root.getODSResults(3).begin()).use_empty()))) {
+      return rewriter.notifyMatchFailure(
+          fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
+            diag << "entities '' failed to satisfy constraint: has no use";
+          });
+    }
+
+    if (!(((*root.getODSResults(4).begin()).use_empty()))) {
+      return rewriter.notifyMatchFailure(
+          fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
+            diag << "entities '' failed to satisfy constraint: has no use";
+          });
+    }
+
+    if (!(((*root.getODSResults(5).begin()).use_empty()))) {
+      return rewriter.notifyMatchFailure(
+          fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
+            diag << "entities '' failed to satisfy constraint: has no use";
+          });
+    }
+
+    // Rewrite.
+    auto odsLoc = rewriter.getFusedLoc({fused_batch_norm->getLoc()});
+    ::llvm::SmallVector<::mlir::Value, 4> replace_values;
+    ::mlir::TF::ConstOp epsilon_const_op;
+    {
+      epsilon_const_op =
+          rewriter.create<::mlir::TF::ConstOp>(odsLoc,
+                                               /*value=*/epsilon);
+    }
+    ::mlir::TF::AddOp add_op_1;
+    {
+      ::mlir::Value tblgen_value_0 = (*variance.begin());
+      ::mlir::Value tblgen_value_1 =
+          (*epsilon_const_op.getODSResults(0).begin());
+      add_op_1 = rewriter.create<::mlir::TF::AddOp>(odsLoc,
+                                                    /*x=*/tblgen_value_0,
+                                                    /*y=*/tblgen_value_1);
+      // We need to make sure the Add operands are broadcastable.
+      if (mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(add_op_1)
+              .value == LogicalResult::Failure) {
+        return failure();
+      }
+    }
+    ::mlir::TF::RsqrtOp rsqrt_op;
+    {
+      ::mlir::SmallVector<::mlir::Value, 4> tblgen_values;
+      ::mlir::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs;
+      tblgen_values.push_back((*add_op_1.getODSResults(0).begin()));
+      rsqrt_op = rewriter.create<::mlir::TF::RsqrtOp>(odsLoc, tblgen_values,
+                                                      tblgen_attrs);
+    }
+    ::mlir::TF::MulOp multiplier;
+    {
+      ::mlir::Value tblgen_value_0 = (*scale.begin());
+      ::mlir::Value tblgen_value_1 = (*rsqrt_op.getODSResults(0).begin());
+      multiplier = rewriter.create<::mlir::TF::MulOp>(odsLoc,
+                                                      /*x=*/tblgen_value_0,
+                                                      /*y=*/tblgen_value_1);
+      // We need to make sure the Mul operands are broadcastable.
+      if (mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(multiplier)
+              .value == LogicalResult::Failure) {
+        return failure();
+      }
+    }
+    ::mlir::TF::MulOp mul_op_1;
+    {
+      ::mlir::Value tblgen_value_0 = (*x.begin());
+      ::mlir::Value tblgen_value_1 = (*multiplier.getODSResults(0).begin());
+      mul_op_1 = rewriter.create<::mlir::TF::MulOp>(odsLoc,
+                                                    /*x=*/tblgen_value_0,
+                                                    /*y=*/tblgen_value_1);
+      // We need to make sure the Mul operands are broadcastable.
+      if (mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(mul_op_1)
+              .value == LogicalResult::Failure) {
+        return failure();
+      }
+    }
+    ::mlir::TF::MulOp mul_op_2;
+    {
+      ::mlir::Value tblgen_value_0 = (*mean.begin());
+      ::mlir::Value tblgen_value_1 = (*multiplier.getODSResults(0).begin());
+      mul_op_2 = rewriter.create<::mlir::TF::MulOp>(odsLoc,
+                                                    /*x=*/tblgen_value_0,
+                                                    /*y=*/tblgen_value_1);
+      if (mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(mul_op_2)
+              .value == LogicalResult::Failure) {
+        return failure();
+      }
+    }
+    ::mlir::TF::SubOp sub_op;
+    {
+      ::mlir::Value tblgen_value_0 = (*offset.begin());
+      ::mlir::Value tblgen_value_1 = (*mul_op_2.getODSResults(0).begin());
+      sub_op = rewriter.create<::mlir::TF::SubOp>(odsLoc,
+                                                  /*x=*/tblgen_value_0,
+                                                  /*y=*/tblgen_value_1);
+      if (mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(sub_op).value ==
+          LogicalResult::Failure) {
+        return failure();
+      }
+    }
+    ::mlir::TF::AddOp add_op_2;
+    {
+      ::mlir::SmallVector<::mlir::Value, 4> tblgen_values;
+      ::mlir::SmallVector<::mlir::NamedAttribute, 4> tblgen_attrs;
+      tblgen_values.push_back((*mul_op_1.getODSResults(0).begin()));
+      tblgen_values.push_back((*sub_op.getODSResults(0).begin()));
+      ::mlir::SmallVector<::mlir::Type, 4> tblgen_types;
+      for (auto v : fused_batch_norm_op.getODSResults(0)) {
+        tblgen_types.push_back(v.getType());
+      }
+      add_op_2 = rewriter.create<::mlir::TF::AddOp>(
+          odsLoc, tblgen_types, tblgen_values, tblgen_attrs);
+      // We need to make sure the Add operands are broadcastable.
+      if (mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(add_op_2)
+              .value == LogicalResult::Failure) {
+        return failure();
+      }
+    }
+    for (auto v :
+         ::llvm::SmallVector<::mlir::Value, 4>{add_op_2.getODSResults(0)}) {
+      replace_values.push_back(v);
+    }
+    // The remaining five results are unused, so replace each of them with x
+    // as a placeholder value.
+    for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
+      replace_values.push_back(v);
+    }
+    for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
+      replace_values.push_back(v);
+    }
+    for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
+      replace_values.push_back(v);
+    }
+    for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
+      replace_values.push_back(v);
+    }
+    for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{x}) {
+      replace_values.push_back(v);
+    }
+    rewriter.replaceOp(fused_batch_norm, replace_values);
+    return success();
+  }
+};
+
 #include "tensorflow/compiler/mlir/lite/transforms/generated_prepare_tf.inc"
 
 // Returns success if all the operations in the `op`'s regions including `op`
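The broadcastability guard repeated after each created op above could be factored into one helper. A hedged sketch (an editorial suggestion built on the same MLIR trait verifier the commit calls; OperandsBroadcastable is a hypothetical name, not TF code):

// Shared guard for the per-op broadcast checks above.
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"

static bool OperandsBroadcastable(mlir::Operation *op) {
  // verifyCompatibleOperandBroadcast is the verifier behind the
  // ResultsBroadcastableShape trait; it fails when the operands' shapes
  // cannot broadcast to the op's declared result shape.
  return !mlir::failed(
      mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(op));
}

// Usage inside matchAndRewrite, replacing each inline check:
//   if (!OperandsBroadcastable(add_op_1)) return failure();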
@@ -927,7 +1199,7 @@ void PrepareTFPass::runOnFunction() {
   // This pattern will try to identify and optimize for dilated convolution.
   // e.g. Patterns like "SpaceToBatchND -> Conv2D -> BatchToSpaceND" will be
   // replaced with a single Conv op with dilation parameter.
-  patterns.insert<ConvertTFDilatedConvOp<TF::Conv2DOp>,
+  patterns.insert<ConvertTFDilatedConvOp<TF::Conv2DOp>, FusedBatchNormV3Pat,
                   ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(ctx);
 
   patterns.insert<ConvertFusedBatchNorm>(ctx);
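For context, this is roughly how patterns registered here get applied. A generic sketch of the greedy rewrite driver using the 2020-era MLIR API (not the exact PrepareTFPass body; ApplyPrepareTFPatterns is a hypothetical wrapper):

#include "mlir/IR/Function.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

void ApplyPrepareTFPatterns(mlir::FuncOp func, mlir::MLIRContext *ctx) {
  mlir::OwningRewritePatternList patterns;
  // FusedBatchNormV3Pat sits alongside the other prepare-TF patterns; any
  // op it rejects (e.g. via a failed broadcast check) is simply left as-is.
  patterns.insert<FusedBatchNormV3Pat>(ctx);
  // Applies the patterns repeatedly until no more of them match.
  applyPatternsAndFoldGreedily(func, patterns);
}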