Support training when lowering FusedBatchNormV3 op
If is_training is true, the given mean and variance are not used; they are computed from the input values instead.

PiperOrigin-RevId: 357843515
Change-Id: Ic23c5a30ec367eda8321a51921f2e4ca23cb0d28
commit b1f7822c63 (parent 4d71a27952)
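For context, here is a minimal NumPy sketch (not part of this change; shapes chosen arbitrarily) of what the two modes compute: with is_training=true the statistics come from the batch itself, with is_training=false the supplied mean and variance are used directly.

```python
import numpy as np

x = np.random.rand(1, 1, 6, 2).astype(np.float32)   # NHWC input
scale = np.ones(2, np.float32)
offset = np.zeros(2, np.float32)
given_mean = np.zeros(2, np.float32)                 # ignored when is_training=true
given_var = np.ones(2, np.float32)                   # ignored when is_training=true
eps = 1e-3

# is_training = true: mean/variance are computed from x over the N, H, W axes.
batch_mean = x.mean(axis=(0, 1, 2))
batch_var = x.var(axis=(0, 1, 2))
y_train = (x - batch_mean) / np.sqrt(batch_var + eps) * scale + offset

# is_training = false: the mean/variance operands are used directly.
y_infer = (x - given_mean) / np.sqrt(given_var + eps) * scale + offset
```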
@ -66,15 +66,19 @@ func @fusedBatchNormV3(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor
^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>):
// OK
%0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
// Unsupported training
%1:6 = "tf.FusedBatchNormV3"( %0#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
// Training with non-broadcastable shape
%cst = constant dense<0.0> : tensor<4xf32>
%1:6 = "tf.FusedBatchNormV3"( %0#0, %cst, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<4xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
// Inference with non-broadcastable shape
%2:6 = "tf.FusedBatchNormV3"( %1#0, %cst, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<4xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
// Use other output
%2:6 = "tf.FusedBatchNormV3"( %1#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
%3:6 = "tf.FusedBatchNormV3"( %2#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
return %2, %2#1 : tensor<8x8x8x8xf32>, tensor<8xf32>
return %3, %3#1 : tensor<8x8x8x8xf32>, tensor<8xf32>
// CHECK-LABEL: fusedBatchNormV3
// CHECK: %[[CONSTANT:.*]] = constant dense<1.000000e-03>
// CHECK: %[[CONSTANT1:.*]] = constant dense<0.000000e+00> : tensor<4xf32>
// variance + epsilon
// CHECK: %[[ADD1:.*]] = "tf.Add"(%[[ARG4:.*]], %[[CONSTANT]])
// rsqrt(variance + epsilon)
@ -90,11 +94,12 @@ func @fusedBatchNormV3(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor
// x * scale * rsqrt(variance + epsilon) +
// offset - mean * scale * rsqrt(variance + epsilon)
// CHECK: %[[ADD2:.*]] = "tf.Add"(%[[MUL2]], %[[SUB]])
// CHECK: %[[BATCHNORM1_a:[^,]+]], {{.*}} = "tf.FusedBatchNormV3"(%[[ADD2]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
// CHECK: "tf.FusedBatchNormV3"(%[[BATCHNORM1_a]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
// CHECK: %[[BATCHNORM1_a:[^,]+]], {{.*}} = "tf.FusedBatchNormV3"(%[[ADD2]], %[[CONSTANT1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
// CHECK: %[[BATCHNORM1_b:[^,]+]], {{.*}} = "tf.FusedBatchNormV3"(%[[BATCHNORM1_a]], %[[CONSTANT1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
// CHECK: "tf.FusedBatchNormV3"(%[[BATCHNORM1_b]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
}
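For the inference case checked above, the lowering expands to x * scale * rsqrt(variance + epsilon) + (offset - mean * scale * rsqrt(variance + epsilon)). A small NumPy sketch (arbitrary values, for illustration only) confirms this equals the usual normalization formula:

```python
import numpy as np

x = np.random.rand(8, 8, 8, 8).astype(np.float32)
scale = np.random.rand(8).astype(np.float32)
offset = np.random.rand(8).astype(np.float32)
mean = np.random.rand(8).astype(np.float32)
variance = np.random.rand(8).astype(np.float32) + 0.5
eps = 1e-3

multiplier = scale / np.sqrt(variance + eps)               # scale * rsqrt(variance + epsilon)
y_lowered = x * multiplier + (offset - mean * multiplier)  # form produced by the lowering
y_reference = (x - mean) / np.sqrt(variance + eps) * scale + offset

assert np.allclose(y_lowered, y_reference, atol=1e-5)
```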
func @batchNormWithGlobalNormalization(
%t:tensor<1x10x10x3xf32>, %m:tensor<3xf32>, %v:tensor<3xf32>, %beta:tensor<3xf32>, %gamma:tensor<3xf32>) -> (tensor<1x10x10x3xf32>) {
%0 = "tf.BatchNormWithGlobalNormalization"(%t, %m, %v, %beta, %gamma) {T = "tfdtype$DT_FLOAT", variance_epsilon = 0.001 : f32, scale_after_normalization = false} : (tensor<1x10x10x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<1x10x10x3xf32>)
@ -779,4 +784,25 @@ func @strided_slice_unranked_input(%arg0 : tensor<*xf32>) -> tensor<*xf32> {
// CHECK: "tf.StridedSlice"
}
func @fused_batch_norm_v3_training(%arg0 : tensor<1x1x6x2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>, %arg4 : tensor<2xf32>) -> tensor<1x1x6x2xf32> {
%0, %1, %2, %3, %4, %5 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {data_format = "NHWC", epsilon = 1.000000e-03 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = true} : (tensor<1x1x6x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<1x1x6x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<*xf32>)
return %0 : tensor<1x1x6x2xf32>
// CHECK-LABEL: fused_batch_norm_v3_training
// CHECK: %[[CST:.*]] = constant dense<[0, 1, 2]> : tensor<3xi64>
// CHECK: %[[CST0:.*]] = constant dense<0.166666672> : tensor<1xf32>
// CHECK: %[[CST1:.*]] = constant dense<1.000000e-03> : tensor<f32>
// CHECK: %[[SUM:.*]] = "tf.Sum"(%arg0, %[[CST]]) {keep_dims = false} : (tensor<1x1x6x2xf32>, tensor<3xi64>) -> tensor<2xf32>
// CHECK: %[[MUL:.*]] = "tf.Mul"(%[[SUM]], %[[CST0]]) : (tensor<2xf32>, tensor<1xf32>) -> tensor<2xf32>
// CHECK: %[[SQ:.*]] = "tf.SquaredDifference"(%arg0, %[[MUL]]) : (tensor<1x1x6x2xf32>, tensor<2xf32>) -> tensor<1x1x6x2xf32>
// CHECK: %[[SUM0:.*]] = "tf.Sum"(%[[SQ]], %[[CST]]) {keep_dims = false} : (tensor<1x1x6x2xf32>, tensor<3xi64>) -> tensor<2xf32>
// CHECK: %[[MUL0:.*]] = "tf.Mul"(%[[SUM0]], %[[CST0]]) : (tensor<2xf32>, tensor<1xf32>) -> tensor<2xf32>
// CHECK: %[[ADD:.*]] = "tf.Add"(%[[MUL0]], %[[CST1]]) : (tensor<2xf32>, tensor<f32>) -> tensor<2xf32>
// CHECK: %[[RSQRT:.*]] = "tf.Rsqrt"(%[[ADD]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: %[[MUL1:.*]] = "tf.Mul"(%arg1, %[[RSQRT]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
// CHECK: %[[MUL2:.*]] = "tf.Mul"(%arg0, %[[MUL1]]) : (tensor<1x1x6x2xf32>, tensor<2xf32>) -> tensor<1x1x6x2xf32>
// CHECK: %[[MUL3:.*]] = "tf.Mul"(%[[MUL]], %[[MUL1]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
// CHECK: %[[SUB:.*]] = "tf.Sub"(%arg2, %[[MUL3]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
// CHECK: %[[ADD0:.*]] = "tf.Add"(%[[MUL2]], %[[SUB]]) : (tensor<1x1x6x2xf32>, tensor<2xf32>) -> tensor<1x1x6x2xf32>
// CHECK: return %[[ADD0]] : tensor<1x1x6x2xf32>
}
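The CHECK sequence above mirrors the training-mode lowering step by step. A hedged NumPy sketch (using the test's 1x1x6x2 input, written for illustration only) reproduces it and confirms it matches ordinary batch normalization with batch statistics:

```python
import numpy as np

x = np.random.rand(1, 1, 6, 2).astype(np.float32)
scale = np.random.rand(2).astype(np.float32)
offset = np.random.rand(2).astype(np.float32)
eps = np.float32(1e-3)
rest_size_inv = np.float32(1.0 / 6.0)      # 1 / (1 * 1 * 6); the 0.166666672 constant above

mean = x.sum(axis=(0, 1, 2)) * rest_size_inv                   # tf.Sum + tf.Mul
var = np.square(x - mean).sum(axis=(0, 1, 2)) * rest_size_inv  # tf.SquaredDifference + tf.Sum + tf.Mul
multiplier = scale / np.sqrt(var + eps)                        # tf.Add + tf.Rsqrt + tf.Mul
y = x * multiplier + (offset - mean * multiplier)              # tf.Mul / tf.Sub / tf.Add

y_ref = (x - x.mean(axis=(0, 1, 2))) / np.sqrt(x.var(axis=(0, 1, 2)) + eps) * scale + offset
assert np.allclose(y, y_ref, atol=1e-5)
```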
}
@ -926,6 +926,55 @@ struct ConvertTFBroadcastTo : public RewritePattern {
// [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
//  (HasNoUseOf:$root__3), (HasNoUseOf:$root__4),
//  (HasNoUseOf:$root__5), (AreBroadcastableTypes $multiplier, $x)]>;
//
// When is_training is set to true, the given variance and mean are not used.
// In the above calculation, they are replaced by new values. The new mean and
// variance are calculated as follows:
//   rest_size = shape(x)[0] * shape(x)[1] * shape(x)[2]
//   new_mean = sum(x, axis=[0, 1, 2]) / rest_size
//   new_variance = sum(squared_difference(x, new_mean), axis=[0, 1, 2])
//                  / rest_size
//
// The DRR rule for the is_training equals true case is as follows:
// def : Pattern<
//     (TF_FusedBatchNormV3Op:$root
//         $x, $scale, $offset, $mean, $variance,
//         F32Attr:$epsilon, $exponential_avg_factor,
//         $data_format, FalseBoolAttr:$is_training),
//     [(TF_AddOp
//          (TF_MulOp
//              $x,
//              (TF_MulOp:$multiplier
//                  $scale,
//                  (TF_RsqrtOp
//                      (TF_AddOp
//                          (TF_DivOp:$new_variance
//                              (TF_SumOp
//                                  (TF_SquaredDifferenceOp $x, $new_mean),
//                                  (TF_ConstOp [0,1,2])),
//                              $rest_size),
//                          (TF_ConstOp $epsilon))))),
//          (TF_SubOp
//              $offset,
//              (TF_MulOp
//                  (TF_DivOp:$new_mean
//                      (TF_SumOp $x, (TF_ConstOp [0,1,2])),
//                      (TF_ProdOp:$rest_size
//                          (TF_SliceOp
//                              (TF_ShapeOp $x),
//                              (TF_ConstOp 0),
//                              (TF_ConstOp 3)))),
//                  $multiplier))),
//      // We already guaranteed that the last five results have no use so it does
//      // not matter what value we provide here for replacement.
//      /*batch_mean=*/(replaceWithValue $x),
//      /*batch_variance=*/(replaceWithValue $x),
//      /*reserve_space_1=*/(replaceWithValue $x),
//      /*reserve_space_2=*/(replaceWithValue $x),
//      /*reserve_space_3=*/(replaceWithValue $x)],
//     [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
//      (HasNoUseOf:$root__3), (HasNoUseOf:$root__4),
//      (HasNoUseOf:$root__5), (AreBroadcastableTypes $multiplier, $x)]>;
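As a quick sanity check of the rest_size / new_mean / new_variance formulas above, a small NumPy sketch (arbitrary shapes, not part of the pass):

```python
import numpy as np

x = np.random.rand(2, 3, 4, 5).astype(np.float32)   # NHWC: [batch, height, width, channel]

rest_size = x.shape[0] * x.shape[1] * x.shape[2]     # shape(x)[0] * shape(x)[1] * shape(x)[2]
new_mean = x.sum(axis=(0, 1, 2)) / rest_size
new_variance = np.square(x - new_mean).sum(axis=(0, 1, 2)) / rest_size

# These are exactly the per-channel moments over the N, H, W axes.
assert np.allclose(new_mean, x.mean(axis=(0, 1, 2)), atol=1e-5)
assert np.allclose(new_variance, x.var(axis=(0, 1, 2)), atol=1e-5)
```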
struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
explicit FusedBatchNormV3Pat(::mlir::MLIRContext *context)
@ -940,7 +989,6 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
// Variables for capturing values and attributes used for creating ops
Operation::operand_range mean(fused_batch_norm->getOperands());
::mlir::FloatAttr exponential_avg_factor;
::mlir::StringAttr data_format;
::mlir::TF::FusedBatchNormV3Op root;
Operation::operand_range offset(fused_batch_norm->getOperands());
Operation::operand_range x(fused_batch_norm->getOperands());
@ -959,6 +1007,9 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
mean = fused_batch_norm_op.getODSOperands(3);
variance = fused_batch_norm_op.getODSOperands(4);
::mlir::Value mean_value = (*mean.begin());
::mlir::Value variance_value = (*variance.begin());
if (!TFTypeIsFloat32Tensor(fused_batch_norm_op.x())) return failure();
{
@ -984,25 +1035,9 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
exponential_avg_factor =
rewriter.getFloatAttr(rewriter.getF32Type(), 1.0f);
}
{
data_format =
fused_batch_norm_op->getAttrOfType<::mlir::StringAttr>("data_format");
if (!data_format) data_format = rewriter.getStringAttr("NHWC");
}
{
is_training =
fused_batch_norm_op->getAttrOfType<::mlir::BoolAttr>("is_training");
if (!is_training) is_training = rewriter.getBoolAttr(true);
if (!((!is_training.getValue()))) {
return rewriter.notifyMatchFailure(
fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
diag << "op 'tf.FusedBatchNormV3' attribute 'is_training' failed "
"to "
"satisfy constraint: FalseBoolAttr";
});
}
}
if (!TFDataFormatIsNHWC(fused_batch_norm_op) &&
!TFDataFormatIsNDHWC(fused_batch_norm_op))
return failure();
if (!(((*root.getODSResults(1).begin()).use_empty()))) {
return rewriter.notifyMatchFailure(
@ -1038,8 +1073,140 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
diag << "entities '' failed to satisfy constraint: has no use";
});
}
// Rewrite
is_training =
fused_batch_norm_op->getAttrOfType<::mlir::BoolAttr>("is_training");
auto odsLoc = rewriter.getFusedLoc({fused_batch_norm->getLoc()});
// We need to make sure input and output shapes are compatible.
{
int64_t last_dim = -1;
auto is_last_dim_compatible = [](const Value &v, int64_t &last_dim) {
auto v_type = v.getType().dyn_cast_or_null<RankedTensorType>();
if (!v_type) return true;
int64_t v_last_dim = v_type.getDimSize(v_type.getRank() - 1);
if (v_last_dim == -1) return true;
if (last_dim != -1 && v_last_dim != last_dim) return false;
last_dim = v_last_dim;
return true;
};
if (!is_last_dim_compatible(*x.begin(), last_dim) ||
!is_last_dim_compatible(*scale.begin(), last_dim) ||
!is_last_dim_compatible(*offset.begin(), last_dim)) {
return rewriter.notifyMatchFailure(
fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
diag << "Shapes of scale and offset should be 1D and "
"compatible with x";
});
}
if (!is_training.getValue()) {
if (!is_last_dim_compatible(mean_value, last_dim) ||
!is_last_dim_compatible(variance_value, last_dim)) {
return rewriter.notifyMatchFailure(
fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
diag << "Shapes of mean and variance should be 1D and "
"compatible with x";
});
}
}
// Check if output shape and input shape are compatible.
auto x_type = (*x.begin()).getType();
auto y_type = (*root.getODSResults(0).begin()).getType();
if (!OpTrait::util::getBroadcastedType(x_type, y_type)) {
return rewriter.notifyMatchFailure(
fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
diag << "Shapes of x and the first output should be compatible";
});
}
}
// For training, mean and variance are calculated from the input values.
if (is_training.getValue()) {
auto input_type = fused_batch_norm_op.x()
.getType()
.dyn_cast_or_null<RankedTensorType>();
if (!input_type || input_type.getRank() != 4 ||
!input_type.hasStaticShape()) {
return rewriter.notifyMatchFailure(
fused_batch_norm_op, [&](::mlir::Diagnostic &diag) {
diag << "op 'tf.FusedBatchNormV3' that has 'is_training' equals "
"True is only supported with static input shape";
});
}
::mlir::TF::ConstOp reduce_dim_op;
{
auto reduce_dim_type =
::mlir::RankedTensorType::get({3}, rewriter.getIntegerType(64));
::mlir::SmallVector<int64_t, 3> reduce_dim_values = {0, 1, 2};
reduce_dim_op = rewriter.create<TF::ConstOp>(
odsLoc, ::mlir::DenseIntElementsAttr::get(reduce_dim_type,
reduce_dim_values));
}
::mlir::TF::ConstOp rest_size_inv_op;
{
int64_t rest_size = input_type.getDimSize(0) *
input_type.getDimSize(1) * input_type.getDimSize(2);
auto rest_size_inv_type =
::mlir::RankedTensorType::get({1}, rewriter.getF32Type());
auto rest_size_inv_attr = ::mlir::DenseFPElementsAttr::get(
rest_size_inv_type, {1.0f / rest_size});
rest_size_inv_op =
rewriter.create<::mlir::TF::ConstOp>(odsLoc, rest_size_inv_attr);
}
::mlir::TF::SumOp sum_op_1;
{
::mlir::Value x_value = (*x.begin());
sum_op_1 = rewriter.create<TF::SumOp>(
odsLoc, x_value, reduce_dim_op,
/*keep_dims=*/rewriter.getBoolAttr(false));
}
::mlir::TF::MulOp mul_op_1;
{
::mlir::Value tblgen_value_0 = (*sum_op_1.getODSResults(0).begin());
::mlir::Value tblgen_value_1 =
(*rest_size_inv_op.getODSResults(0).begin());
mul_op_1 = rewriter.create<::mlir::TF::MulOp>(odsLoc, tblgen_value_0,
tblgen_value_1);
}
::mlir::TF::SquaredDifferenceOp square_diff_op;
{
::mlir::Value tblgen_value_0 = (*x.begin());
::mlir::Value tblgen_value_1 = (*mul_op_1.getODSResults(0).begin());
// If x has shape of [b, h, w, c], the result of mul_op_1 will have
// shape of [c]. Therefore, their shapes are always compatible.
square_diff_op = rewriter.create<::mlir::TF::SquaredDifferenceOp>(
odsLoc, tblgen_value_0, tblgen_value_1);
}
::mlir::TF::SumOp sum_op_2;
{
::mlir::Value input_value = (*square_diff_op.getODSResults(0).begin());
sum_op_2 = rewriter.create<TF::SumOp>(
odsLoc, input_value, reduce_dim_op,
/*keep_dims=*/rewriter.getBoolAttr(false));
}
::mlir::TF::MulOp mul_op_2;
{
::mlir::Value tblgen_value_0 = (*sum_op_2.getODSResults(0).begin());
::mlir::Value tblgen_value_1 =
(*rest_size_inv_op.getODSResults(0).begin());
mul_op_2 = rewriter.create<::mlir::TF::MulOp>(odsLoc, tblgen_value_0,
tblgen_value_1);
}
mean_value = (*mul_op_1.getODSResults(0).begin());
variance_value = (*mul_op_2.getODSResults(0).begin());
} // End is_training equals true if.
::llvm::SmallVector<::mlir::Value, 4> replace_values;
::mlir::TF::ConstOp epsilon_const_op;
{
@ -1049,17 +1216,12 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
}
::mlir::TF::AddOp add_op_1;
{
::mlir::Value tblgen_value_0 = (*variance.begin());
::mlir::Value tblgen_value_1 =
::mlir::Value epsilon_value =
(*epsilon_const_op.getODSResults(0).begin());
// Adding a scalar constant, no need to check broadcastability.
add_op_1 = rewriter.create<::mlir::TF::AddOp>(odsLoc,
/*x=*/tblgen_value_0,
/*y=*/tblgen_value_1);
// We need to make sure the Add operands are broadcastable.
if (mlir::failed(mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
add_op_1))) {
return failure();
}
/*x=*/variance_value,
/*y=*/epsilon_value);
}
::mlir::TF::RsqrtOp rsqrt_op;
{
@ -1073,14 +1235,9 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
{
::mlir::Value tblgen_value_0 = (*scale.begin());
::mlir::Value tblgen_value_1 = (*rsqrt_op.getODSResults(0).begin());
// We need to make sure the Add operands are broadcastable.
multiplier = rewriter.create<::mlir::TF::MulOp>(odsLoc,
/*x=*/tblgen_value_0,
/*y=*/tblgen_value_1);
if (mlir::failed(mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
multiplier))) {
return failure();
}
}
::mlir::TF::MulOp mul_op_1;
{
@ -1089,23 +1246,13 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
mul_op_1 = rewriter.create<::mlir::TF::MulOp>(odsLoc,
/*x=*/tblgen_value_0,
/*y=*/tblgen_value_1);
// We need to make sure the Mul operands are broadcastable.
if (mlir::failed(mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
mul_op_1))) {
return failure();
}
}
::mlir::TF::MulOp mul_op_2;
{
::mlir::Value tblgen_value_0 = (*mean.begin());
::mlir::Value tblgen_value_1 = (*multiplier.getODSResults(0).begin());
::mlir::Value multiplier_value = (*multiplier.getODSResults(0).begin());
mul_op_2 = rewriter.create<::mlir::TF::MulOp>(odsLoc,
/*x=*/tblgen_value_0,
/*y=*/tblgen_value_1);
if (mlir::failed(mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
mul_op_2))) {
return failure();
}
/*x=*/mean_value,
/*y=*/multiplier_value);
}
::mlir::TF::SubOp sub_op;
{
@ -1114,10 +1261,6 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
sub_op = rewriter.create<::mlir::TF::SubOp>(odsLoc,
/*x=*/tblgen_value_0,
/*y=*/tblgen_value_1);
if (failed(
mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(sub_op))) {
return failure();
}
}
::mlir::TF::AddOp add_op_2;
{
@ -1131,11 +1274,6 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern {
}
add_op_2 = rewriter.create<::mlir::TF::AddOp>(
odsLoc, tblgen_types, tblgen_values, tblgen_attrs);
// We need to make sure the Add operands are broadcastable.
if (mlir::failed(mlir::OpTrait::impl::verifyCompatibleOperandBroadcast(
add_op_2))) {
return failure();
}
}
for (auto v :
::llvm::SmallVector<::mlir::Value, 4>{add_op_2.getODSResults(0)}) {
@ -31,8 +31,18 @@ def make_fused_batch_norm_tests(options):
      "dtype": [tf.float32],
      "input_shape": [[1, 1, 6, 2]],
      "epsilon": [0.001, 0.1],
      "is_training": [False],
  }]

  # Training support in MLIR converter.
  if options.use_experimental_converter:
    test_parameters = test_parameters + [{
        "dtype": [tf.float32],
        "input_shape": [[1, 1, 6, 2]],
        "epsilon": [0.001, 0.1],
        "is_training": [True],
    }]

  def build_graph(parameters):
    """Build the testing graph for fused batch normalization."""
    input_shape = parameters["input_shape"]
@ -43,7 +53,8 @@ def make_fused_batch_norm_tests(options):
    mean = create_tensor_data(parameters["dtype"], scale_shape)
    variance = create_tensor_data(parameters["dtype"], scale_shape)

    x = create_tensor_data(parameters["dtype"], parameters["input_shape"])
    x = tf.compat.v1.placeholder(
        dtype=parameters["dtype"], name="x", shape=parameters["input_shape"])
    [x_norm, _, _] = tf.compat.v1.nn.fused_batch_norm(
        x,
        scale,
@ -52,19 +63,22 @@ def make_fused_batch_norm_tests(options):
        variance,
        parameters["epsilon"],
        data_format="NHWC",
        is_training=False)
        is_training=parameters["is_training"])

    input_tensor = tf.compat.v1.placeholder(
        dtype=parameters["dtype"],
        name="input",
        shape=parameters["input_shape"])
    out = tf.add(input_tensor, x_norm)
    return [input_tensor], [out]
    return [x, input_tensor], [out]

  def build_inputs(parameters, sess, inputs, outputs):
    input_value = create_tensor_data(parameters["dtype"],
                                     parameters["input_shape"])
    return [input_value], sess.run(
        outputs, feed_dict=dict(zip(inputs, [input_value])))
    input_values = [
        create_tensor_data(parameters["dtype"], parameters["input_shape"]),
        create_tensor_data(parameters["dtype"], parameters["input_shape"])
    ]

    return input_values, sess.run(
        outputs, feed_dict=dict(zip(inputs, input_values)))

  make_zip_of_tests(options, test_parameters, build_graph, build_inputs)