Support training when lowering FusedBatchNormV3 op

If is_training is true, the given mean and variance are not used; they are
instead computed from the input values.
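
For reference, the training-mode computation introduced by this lowering can be sketched in NumPy as follows (an illustrative sketch only, mirroring the formulas documented in the lowering pass below; the function name and shapes are assumptions, not code from this change):

    import numpy as np

    def fused_batch_norm_v3_training(x, scale, offset, epsilon=1e-3):
        """Batch norm in training mode for an NHWC input x, per-channel scale/offset."""
        # Statistics are computed from x over the batch and spatial axes.
        rest_size = x.shape[0] * x.shape[1] * x.shape[2]
        new_mean = np.sum(x, axis=(0, 1, 2)) / rest_size
        new_variance = np.sum(np.square(x - new_mean), axis=(0, 1, 2)) / rest_size
        # Applied in the same factored form as the inference lowering:
        #   y = x * (scale * rsqrt(var + eps)) + (offset - mean * scale * rsqrt(var + eps))
        multiplier = scale / np.sqrt(new_variance + epsilon)
        return x * multiplier + (offset - new_mean * multiplier)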

PiperOrigin-RevId: 357843515
Change-Id: Ic23c5a30ec367eda8321a51921f2e4ca23cb0d28
Thai Nguyen 2021-02-16 17:46:52 -08:00 committed by TensorFlower Gardener
parent 4d71a27952
commit b1f7822c63
3 changed files with 249 additions and 71 deletions


@@ -66,15 +66,19 @@ func @fusedBatchNormV3(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor
^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>):
// OK
%0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
// Unsupported training
%1:6 = "tf.FusedBatchNormV3"( %0#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
// Training with non-broadcastable shape
%cst = constant dense<0.0> : tensor<4xf32>
%1:6 = "tf.FusedBatchNormV3"( %0#0, %cst, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<4xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
// Inference with non-broadcastable shape
%2:6 = "tf.FusedBatchNormV3"( %1#0, %cst, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<4xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
// Use other output
%2:6 = "tf.FusedBatchNormV3"( %1#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
%3:6 = "tf.FusedBatchNormV3"( %2#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
return %2, %2#1 : tensor<8x8x8x8xf32>, tensor<8xf32>
return %3, %3#1 : tensor<8x8x8x8xf32>, tensor<8xf32>
// CHECK-LABEL: fusedBatchNormV3
// CHECK: %[[CONSTANT:.*]] = constant dense<1.000000e-03>
// CHECK: %[[CONSTANT1:.*]] = constant dense<0.000000e+00> : tensor<4xf32>
// variance + epsilon
// CHECK: %[[ADD1:.*]] = "tf.Add"(%[[ARG4:.*]], %[[CONSTANT]])
// rsqrt(variance + epsilon)
@@ -90,11 +94,12 @@ func @fusedBatchNormV3(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor
// x * scale * rsqrt(variance + epsilon) +
// offset - mean * scale * rsqrt(variance + epsilon)
// CHECK: %[[ADD2:.*]] = "tf.Add"(%[[MUL2]], %[[SUB]])
// CHECK: %[[BATCHNORM1_a:[^,]+]], {{.*}} = "tf.FusedBatchNormV3"(%[[ADD2]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
// CHECK: "tf.FusedBatchNormV3"(%[[BATCHNORM1_a]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
// CHECK: %[[BATCHNORM1_a:[^,]+]], {{.*}} = "tf.FusedBatchNormV3"(%[[ADD2]], %[[CONSTANT1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
// CHECK: %[[BATCHNORM1_b:[^,]+]], {{.*}} = "tf.FusedBatchNormV3"(%[[BATCHNORM1_a]], %[[CONSTANT1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
// CHECK: "tf.FusedBatchNormV3"(%[[BATCHNORM1_b]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
}
func @batchNormWithGlobalNormalization(
%t:tensor<1x10x10x3xf32>, %m:tensor<3xf32>, %v:tensor<3xf32>, %beta:tensor<3xf32>, %gamma:tensor<3xf32>) -> (tensor<1x10x10x3xf32>) {
%0 = "tf.BatchNormWithGlobalNormalization"(%t, %m, %v, %beta, %gamma) {T = "tfdtype$DT_FLOAT", variance_epsilon = 0.001 : f32, scale_after_normalization = false} : (tensor<1x10x10x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<1x10x10x3xf32>)
@@ -779,4 +784,25 @@ func @strided_slice_unranked_input(%arg0 : tensor<*xf32>) -> tensor<*xf32> {
// CHECK: "tf.StridedSlice"
}
func @fused_batch_norm_v3_training(%arg0 : tensor<1x1x6x2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>) -> tensor<1x1x6x2xf32> {
%0, %1, %2, %3, %4, %5 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {data_format = "NHWC", epsilon = 1.000000e-03 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = true} : (tensor<1x1x6x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<1x1x6x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<*xf32>)
return %0 : tensor<1x1x6x2xf32>
// CHECK-LABEL: fused_batch_norm_v3_training
// CHECK: %[[CST:.*]] = constant dense<[0, 1, 2]> : tensor<3xi64>
// CHECK: %[[CST0:.*]] = constant dense<0.166666672> : tensor<1xf32>
// CHECK: %[[CST1:.*]] = constant dense<1.000000e-03> : tensor<f32>
// CHECK: %[[SUM:.*]] = "tf.Sum"(%arg0, %[[CST]]) {keep_dims = false} : (tensor<1x1x6x2xf32>, tensor<3xi64>) -> tensor<2xf32>
// CHECK: %[[MUL:.*]] = "tf.Mul"(%[[SUM]], %[[CST0]]) : (tensor<2xf32>, tensor<1xf32>) -> tensor<2xf32>
// CHECK: %[[SQ:.*]] = "tf.SquaredDifference"(%arg0, %[[MUL]]) : (tensor<1x1x6x2xf32>, tensor<2xf32>) -> tensor<1x1x6x2xf32>
// CHECK: %[[SUM0:.*]] = "tf.Sum"(%[[SQ]], %[[CST]]) {keep_dims = false} : (tensor<1x1x6x2xf32>, tensor<3xi64>) -> tensor<2xf32>
// CHECK: %[[MUL0:.*]] = "tf.Mul"(%[[SUM0]], %[[CST0]]) : (tensor<2xf32>, tensor<1xf32>) -> tensor<2xf32>
// CHECK: %[[ADD:.*]] = "tf.Add"(%[[MUL0]], %[[CST1]]) : (tensor<2xf32>, tensor<f32>) -> tensor<2xf32>
// CHECK: %[[RSQRT:.*]] = "tf.Rsqrt"(%[[ADD]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: %[[MUL1:.*]] = "tf.Mul"(%arg1, %[[RSQRT]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
// CHECK: %[[MUL2:.*]] = "tf.Mul"(%arg0, %[[MUL1]]) : (tensor<1x1x6x2xf32>, tensor<2xf32>) -> tensor<1x1x6x2xf32>
// CHECK: %[[MUL3:.*]] = "tf.Mul"(%[[MUL]], %[[MUL1]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
// CHECK: %[[SUB:.*]] = "tf.Sub"(%arg2, %[[MUL3]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
// CHECK: %[[ADD0:.*]] = "tf.Add"(%[[MUL2]], %[[SUB]]) : (tensor<1x1x6x2xf32>, tensor<2xf32>) -> tensor<1x1x6x2xf32>
// CHECK: return %[[ADD0]] : tensor<1x1x6x2xf32>
}
}


@@ -926,6 +926,55 @@ struct ConvertTFBroadcastTo : public RewritePattern {
// [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
// (HasNoUseOf:$root__3), (HasNoUseOf:$root__4),
// (HasNoUseOf:$root__5), (AreBroadcastableTypes $multiplier, $x)]>;
//
// When is_training is set to true, the given variance and mean are not used.
// In the above calculation they are replaced by newly computed values; the new
// mean and variance are calculated as follows:
// rest_size = shape(x)[0] * shape(x)[1] * shape(x)[2]
// new_mean = sum(x, axis=[0, 1, 2]) / rest_size
// new_variance = sum(squared_difference(x, new_mean), axis=[0, 1, 2])
// / rest_size
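//
// For example, with an input x of shape [1, 1, 6, 2] (the shape used by the
// lowering test in this change):
//   rest_size    = 1 * 1 * 6 = 6
//   new_mean     = sum(x, axis=[0, 1, 2]) * (1 / 6)             (shape [2])
//   new_variance = sum(squared_difference(x, new_mean),
//                      axis=[0, 1, 2]) * (1 / 6)                (shape [2])
// which is where the constant dense<0.166666672> checked in that test comes
// from.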
//
// The DRR rule for the is_training == true case is as follows:
// def : Pattern<
// (TF_FusedBatchNormV3Op:$root
// $x, $scale, $offset, $mean, $variance,
// F32Attr:$epsilon, $exponential_avg_factor,
// $data_format, TrueBoolAttr:$is_training),
// [(TF_AddOp
// (TF_MulOp
// $x,
// (TF_MulOp:$multiplier
// $scale,
// (TF_RsqrtOp
// (TF_AddOp
// (TF_DivOp:$new_variance
// (TF_SumOp
// (TF_SquaredDifferenceOp $x, $new_mean),
// (TF_ConstOp [0,1,2])),
// $rest_size),
// (TF_ConstOp $epsilon))))),
// (TF_SubOp
// $offset,
// (TF_MulOp
// (TF_DivOp:$new_mean
// (TF_SumOp $x, (TF_ConstOp [0,1,2])),
// (TF_ProdOp:$rest_size
// (TF_SliceOp
// (TF_ShapeOp $x),
// (TF_ConstOp 0),
// (TF_ConstOp 3)))),
// $multiplier))),
// // We already guaranteed that the last five results have no use so it does
// // not matter what value we provide here for replacement.
// /*batch_mean=*/(replaceWithValue $x),
// /*batch_variance=*/(replaceWithValue $x),
// /*reserve_space_1=*/(replaceWithValue $x),
// /*reserve_space_2=*/(replaceWithValue $x),
// /*reserve_space_3=*/(replaceWithValue $x)],
// [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
// (HasNoUseOf:$root__3), (HasNoUseOf:$root__4),
// (HasNoUseOf:$root__5), (AreBroadcastableTypes $multiplier, $x)]>;
struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
explicit FusedBatchNormV3Pat(::mlir::MLIRContext *context)
@@ -940,7 +989,6 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
// Variables for capturing values and attributes used for creating ops
Operation::operand_range mean(fused_batch_norm->getOperands());
::mlir::FloatAttr exponential_avg_factor;
::mlir::StringAttr data_format;
::mlir::TF::FusedBatchNormV3Op root;
Operation::operand_range offset(fused_batch_norm->getOperands());
Operation::operand_range x(fused_batch_norm->getOperands());
@@ -959,6 +1007,9 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
mean = fused_batch_norm_op.getODSOperands(3);
variance = fused_batch_norm_op.getODSOperands(4);
::mlir::Value mean_value = (*mean.begin());
::mlir::Value variance_value = (*variance.begin());
if (!TFTypeIsFloat32Tensor(fused_batch_norm_op.x())) return failure();
{
@@ -984,25 +1035,9 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
exponential_avg_factor =
rewriter.getFloatAttr(rewriter.getF32Type(), 1.0f);
}
{
data_format =
fused_batch_norm_op->getAttrOfType<::mlir::StringAttr>("data_format");
if (!data_format) data_format = rewriter.getStringAttr("NHWC");
}
{
is_training =
fused_batch_norm_op->getAttrOfType<::mlir::BoolAttr>("is_training");
if (!is_training) is_training = rewriter.getBoolAttr(true);
if (!((!is_training.getValue()))) {
return rewriter.notifyMatchFailure(
fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
diag << "op 'tf.FusedBatchNormV3' attribute 'is_training' failed "
"to "
"satisfy constraint: FalseBoolAttr";
});
}
}
if (!TFDataFormatIsNHWC(fused_batch_norm_op) &&
!TFDataFormatIsNDHWC(fused_batch_norm_op))
return failure();
if (!(((*root.getODSResults(1).begin()).use_empty()))) {
return rewriter.notifyMatchFailure(
@@ -1038,8 +1073,140 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
diag << "entities '' failed to satisfy constraint: has no use";
});
}
// Rewrite
is_training =
fused_batch_norm_op->getAttrOfType<::mlir::BoolAttr>("is_training");
auto odsLoc = rewriter.getFusedLoc({fused_batch_norm->getLoc()});
// We need to make sure input and output shapes are compatible.
{
int64_t last_dim = -1;
auto is_last_dim_compatible = [](const Value &v, int64_t &last_dim) {
auto v_type = v.getType().dyn_cast_or_null<RankedTensorType>();
if (!v_type) return true;
int64_t v_last_dim = v_type.getDimSize(v_type.getRank() - 1);
if (v_last_dim == -1) return true;
if (last_dim != -1 && v_last_dim != last_dim) return false;
last_dim = v_last_dim;
return true;
};
if (!is_last_dim_compatible(*x.begin(), last_dim) ||
!is_last_dim_compatible(*scale.begin(), last_dim) ||
!is_last_dim_compatible(*offset.begin(), last_dim)) {
return rewriter.notifyMatchFailure(
fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
diag << "Shapes of scale and offset should be 1D and "
"compatible with x";
});
}
if (!is_training.getValue()) {
if (!is_last_dim_compatible(mean_value, last_dim) ||
!is_last_dim_compatible(variance_value, last_dim)) {
return rewriter.notifyMatchFailure(
fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
diag << "Shapes of mean and variance should be 1D and "
"compatible with x";
});
}
}
// Check if output shape and input shape are compatible.
auto x_type = (*x.begin()).getType();
auto y_type = (*root.getODSResults(0).begin()).getType();
if (!OpTrait::util::getBroadcastedType(x_type, y_type)) {
return rewriter.notifyMatchFailure(
fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
diag << "Shapes of x and the first output should be compatible";
});
}
}
// For training, mean and variance are calculated from the input values.
if (is_training.getValue()) {
auto input_type = fused_batch_norm_op.x()
.getType()
.dyn_cast_or_null<RankedTensorType>();
if (!input_type || input_type.getRank() != 4 ||
!input_type.hasStaticShape()) {
return rewriter.notifyMatchFailure(
fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
diag << "op 'tf.FusedBatchNormV3' that has 'is_training' equals "
"True is only supported with static input shape";
});
}
::mlir::TF::ConstOp reduce_dim_op;
{
auto reduce_dim_type =
::mlir::RankedTensorType::get({3}, rewriter.getIntegerType(64));
::mlir::SmallVector<int64_t, 3> reduce_dim_values = {0, 1, 2};
reduce_dim_op = rewriter.create<TF::ConstOp>(
odsLoc, ::mlir::DenseIntElementsAttr::get(reduce_dim_type,
reduce_dim_values));
}
::mlir::TF::ConstOp rest_size_inv_op;
{
int64_t rest_size = input_type.getDimSize(0) *
input_type.getDimSize(1) * input_type.getDimSize(2);
auto rest_size_inv_type =
::mlir::RankedTensorType::get({1}, rewriter.getF32Type());
auto rest_size_inv_attr = ::mlir::DenseFPElementsAttr::get(
rest_size_inv_type, {1.0f / rest_size});
rest_size_inv_op =
rewriter.create<::mlir::TF::ConstOp>(odsLoc, rest_size_inv_attr);
}
::mlir::TF::SumOp sum_op_1;
{
::mlir::Value x_value = (*x.begin());
sum_op_1 = rewriter.create<TF::SumOp>(
odsLoc, x_value, reduce_dim_op,
/*keep_dims=*/rewriter.getBoolAttr(false));
}
::mlir::TF::MulOp mul_op_1;
{
::mlir::Value tblgen_value_0 = (*sum_op_1.getODSResults(0).begin());
::mlir::Value tblgen_value_1 =
(*rest_size_inv_op.getODSResults(0).begin());
mul_op_1 = rewriter.create<::mlir::TF::MulOp>(odsLoc, tblgen_value_0,
tblgen_value_1);
}
::mlir::TF::SquaredDifferenceOp square_diff_op;
{
::mlir::Value tblgen_value_0 = (*x.begin());
::mlir::Value tblgen_value_1 = (*mul_op_1.getODSResults(0).begin());
// If x has shape [b, h, w, c], the result of mul_op_1 will have shape [c],
// so their shapes are always compatible.
square_diff_op = rewriter.create<::mlir::TF::SquaredDifferenceOp>(
odsLoc, tblgen_value_0, tblgen_value_1);
}
::mlir::TF::SumOp sum_op_2;
{
::mlir::Value input_value = (*square_diff_op.getODSResults(0).begin());
sum_op_2 = rewriter.create<TF::SumOp>(
odsLoc, input_value, reduce_dim_op,
/*keep_dims=*/rewriter.getBoolAttr(false));
}
::mlir::TF::MulOp mul_op_2;
{
::mlir::Value tblgen_value_0 = (*sum_op_2.getODSResults(0).begin());
::mlir::Value tblgen_value_1 =
(*rest_size_inv_op.getODSResults(0).begin());
mul_op_2 = rewriter.create<::mlir::TF::MulOp>(odsLoc, tblgen_value_0,
tblgen_value_1);
}
mean_value = (*mul_op_1.getODSResults(0).begin());
variance_value = (*mul_op_2.getODSResults(0).begin());
} // End of the is_training == true case.
::llvm::SmallVector<::mlir::Value, 4> replace_values;
::mlir::TF::ConstOp epsilon_const_op;
{
@@ -1049,17 +1216,12 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
}
::mlir::TF::AddOp add_op_1;
{
::mlir::Value tblgen_value_0 = (*variance.begin());
::mlir::Value tblgen_value_1 =
::mlir::Value epsilon_value =
(*epsilon_const_op.getODSResults(0).begin());
// Adding a constant, no need to check broadcastability.
add_op_1 = rewriter.create<::mlir::TF::AddOp>(odsLoc,
/*x=*/tblgen_value_0,
/*y=*/tblgen_value_1);
// We need to make sure the Add operands are broadcastable.
if (mlir::failed(mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
add_op_1))) {
return failure();
}
/*x=*/variance_value,
/*y=*/epsilon_value);
}
::mlir::TF::RsqrtOp rsqrt_op;
{
@@ -1073,14 +1235,9 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
{
::mlir::Value tblgen_value_0 = (*scale.begin());
::mlir::Value tblgen_value_1 = (*rsqrt_op.getODSResults(0).begin());
// We need to make sure the Add operands are broadcastable.
multiplier = rewriter.create<::mlir::TF::MulOp>(odsLoc,
/*x=*/tblgen_value_0,
/*y=*/tblgen_value_1);
if (mlir::failed(mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
multiplier))) {
return failure();
}
}
::mlir::TF::MulOp mul_op_1;
{
@@ -1089,23 +1246,13 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
mul_op_1 = rewriter.create<::mlir::TF::MulOp>(odsLoc,
/*x=*/tblgen_value_0,
/*y=*/tblgen_value_1);
// We need to make sure the Mul operands are broadcastable.
if (mlir::failed(mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
mul_op_1))) {
return failure();
}
}
::mlir::TF::MulOp mul_op_2;
{
::mlir::Value tblgen_value_0 = (*mean.begin());
::mlir::Value tblgen_value_1 = (*multiplier.getODSResults(0).begin());
::mlir::Value multiplier_value = (*multiplier.getODSResults(0).begin());
mul_op_2 = rewriter.create<::mlir::TF::MulOp>(odsLoc,
/*x=*/tblgen_value_0,
/*y=*/tblgen_value_1);
if (mlir::failed(mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
mul_op_2))) {
return failure();
}
/*x=*/mean_value,
/*y=*/multiplier_value);
}
::mlir::TF::SubOp sub_op;
{
@@ -1114,10 +1261,6 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
sub_op = rewriter.create<::mlir::TF::SubOp>(odsLoc,
/*x=*/tblgen_value_0,
/*y=*/tblgen_value_1);
if (failed(
mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(sub_op))) {
return failure();
}
}
::mlir::TF::AddOp add_op_2;
{
@@ -1131,11 +1274,6 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
}
add_op_2 = rewriter.create<::mlir::TF::AddOp>(
odsLoc, tblgen_types, tblgen_values, tblgen_attrs);
// We need to make sure the Add operands are broadcastable.
if (mlir::failed(mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
add_op_2))) {
return failure();
}
}
for (auto v :
::llvm::SmallVector<::mlir::Value, 4>{add_op_2.getODSResults(0)}) {


@@ -31,8 +31,18 @@ def make_fused_batch_norm_tests(options):
"dtype": [tf.float32],
"input_shape": [[1, 1, 6, 2]],
"epsilon": [0.001, 0.1],
"is_training": [False],
}]
# Training is only supported with the MLIR converter; see the conversion sketch below.
if options.use_experimental_converter:
test_parameters = test_parameters + [{
"dtype": [tf.float32],
"input_shape": [[1, 1, 6, 2]],
"epsilon": [0.001, 0.1],
"is_training": [True],
}]
def build_graph(parameters):
"""Build the testing graph for fused batch normalization."""
input_shape = parameters["input_shape"]
@@ -43,7 +53,8 @@ def make_fused_batch_norm_tests(options):
mean = create_tensor_data(parameters["dtype"], scale_shape)
variance = create_tensor_data(parameters["dtype"], scale_shape)
x = create_tensor_data(parameters["dtype"], parameters["input_shape"])
x = tf.compat.v1.placeholder(
dtype=parameters["dtype"], name="x", shape=parameters["input_shape"])
[x_norm, _, _] = tf.compat.v1.nn.fused_batch_norm(
x,
scale,
@@ -52,19 +63,22 @@ def make_fused_batch_norm_tests(options):
variance,
parameters["epsilon"],
data_format="NHWC",
is_training=False)
is_training=parameters["is_training"])
input_tensor = tf.compat.v1.placeholder(
dtype=parameters["dtype"],
name="input",
shape=parameters["input_shape"])
out = tf.add(input_tensor, x_norm)
return [input_tensor], [out]
return [x, input_tensor], [out]
def build_inputs(parameters, sess, inputs, outputs):
input_value = create_tensor_data(parameters["dtype"],
parameters["input_shape"])
return [input_value], sess.run(
outputs, feed_dict=dict(zip(inputs, [input_value])))
input_values = [
create_tensor_data(parameters["dtype"], parameters["input_shape"]),
create_tensor_data(parameters["dtype"], parameters["input_shape"])
]
return input_values, sess.run(
outputs, feed_dict=dict(zip(inputs, input_values)))
make_zip_of_tests(options, test_parameters, build_graph, build_inputs)
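
For context, a minimal end-to-end sketch of exercising this path from the TFLite converter API (not part of this change; it assumes TF 2.x, uses the tf.compat.v1 graph APIs as the test above does, and enables the MLIR-based converter, which is what use_experimental_converter corresponds to, via experimental_new_converter):

    import numpy as np
    import tensorflow as tf

    with tf.Graph().as_default() as graph:
        x = tf.compat.v1.placeholder(tf.float32, shape=[1, 1, 6, 2], name="x")
        scale = tf.constant(np.ones([2], np.float32))
        offset = tf.constant(np.zeros([2], np.float32))
        mean = tf.constant(np.zeros([2], np.float32))
        variance = tf.constant(np.ones([2], np.float32))
        # In training mode the mean/variance inputs are ignored; the statistics
        # are recomputed from x, which is what the new lowering materializes.
        y, _, _ = tf.compat.v1.nn.fused_batch_norm(
            x, scale, offset, mean, variance,
            epsilon=0.001, data_format="NHWC", is_training=True)
        with tf.compat.v1.Session(graph=graph) as sess:
            converter = tf.compat.v1.lite.TFLiteConverter.from_session(
                sess, [x], [y])
            converter.experimental_new_converter = True  # MLIR converter path
            tflite_model = converter.convert()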