Match type of unused outputs of tf.FusedBatchNormV3 during legalization.

During inference, the `mean`, `variance`, and reserved outputs of `tf.FusedBatchNormV3` are not used, so we previously forwarded arbitrary values to those outputs. However, the forwarded values must match the output types, so we now forward the input `mean` and `variance` — which have the same types — to these outputs.

PiperOrigin-RevId: 307814606
Change-Id: I51331b566138b872634d564dcfe7538e624c5662
This commit is contained in:
Prakalp Srivastava 2020-04-22 07:43:50 -07:00 committed by TensorFlower Gardener
parent 3a3dad2f8b
commit 1fb6cec4a6
2 changed files with 32 additions and 11 deletions

View File

@ -27,14 +27,20 @@ func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf3
}
// CHECK-LABEL: fusedBatchNormV3_noTraining_mixedPrecision
func @fusedBatchNormV3_noTraining_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
// CHECK: %[[RESULT0:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
// CHECK: %[[RESULT1:.*]] = "xla_hlo.batch_norm_inference"(%[[RESULT0]], %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
%0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
// CHECK-NEXT: "xla_hlo.convert"(%[[RESULT1]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
return %0#0 : tensor<8x8x8x8xbf16>
// CHECK-SAME: ([[X:%.*]]: tensor<8x8x8x8xbf16>, [[SCALE:%.*]]: tensor<8xf32>, [[OFFSET:%.*]]: tensor<8xf32>, [[MEAN:%.*]]: tensor<8xf32>, [[VARIANCE:%.*]]: tensor<8xf32>)
func @fusedBatchNormV3_noTraining_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) {
// CHECK: [[CONVERT_X:%.*]] = "xla_hlo.convert"([[X]]) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
// CHECK: [[Y:%.*]] = "xla_hlo.batch_norm_inference"([[CONVERT_X]], [[SCALE]], [[OFFSET]], [[MEAN]], [[VARIANCE]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
%0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>)
// CHECK: [[Y_CONVERT:%.*]] = "xla_hlo.convert"([[Y]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
// CHECK: [[DUMMY:%.*]] = xla_hlo.constant {value = dense<[]> : tensor<0xf32>} : tensor<*xf32>
// CHECK: return [[Y_CONVERT]], [[MEAN]], [[VARIANCE]], [[MEAN]], [[VARIANCE]], [[DUMMY]]
return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>
}
// CHECK-LABEL: fusedBatchNormV3_training
func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
// CHECK: %[[RESULT0:.*]] = "xla_hlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>

View File

@ -1421,11 +1421,26 @@ class ConvertFusedBatchNormV3Op
// The mean, variance, and reserved space outputs of the batch norm op are
// not used for inference. It doesn't matter what values we provide for
// the last 5 results.
rewriter.replaceOp(
op, {/*y=*/y_out, /*batch_mean=*/op.x(),
/*batch_variance=*/op.x(), /*reserve_space_1=*/op.x(),
/*reserve_space_2=*/op.x(), /*reserve_space_3=*/op.x()});
// the last 5 results as long as they are of the same type. Forward
// input mean and variance to output mean, variance, reserve_space_1 and
// reserve_space_2. Create a constant tensor to forward to the last
// reserve_space_3 output.
auto reserve_space_3_type = op.getResult(5).getType().cast<TensorType>();
int num_elements = reserve_space_3_type.hasStaticShape()
? reserve_space_3_type.getNumElements()
: 0;
auto const_attr_type = RankedTensorType::get(
{num_elements}, getElementTypeOrSelf(reserve_space_3_type));
auto dummy_const = rewriter.create<ConstOp>(
op.getLoc(), reserve_space_3_type,
DenseFPElementsAttr::get(const_attr_type,
std::vector<float>(num_elements, 0)));
rewriter.replaceOp(op, {/*y=*/y_out,
/*batch_mean=*/op.mean(),
/*batch_variance=*/op.variance(),
/*reserve_space_1=*/op.mean(),
/*reserve_space_2=*/op.variance(),
/*reserve_space_3=*/dummy_const.getResult()});
}
return success();
}