Match type of unused outputs of tf.FusedBatchNormV3 during legalization.
During inference, the `mean`, `variance`, and reserve-space outputs of `tf.FusedBatchNormV3` are unused, so any values may be forwarded to them — but the forwarded values must match the declared result types. We therefore forward the input `mean` and `variance`, which have exactly those types, to the corresponding outputs. PiperOrigin-RevId: 307814606 Change-Id: I51331b566138b872634d564dcfe7538e624c5662
This commit is contained in:
parent
3a3dad2f8b
commit
1fb6cec4a6
@@ -27,14 +27,20 @@ func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf3
|
||||
}
|
||||
|
||||
// CHECK-LABEL: fusedBatchNormV3_noTraining_mixedPrecision
|
||||
func @fusedBatchNormV3_noTraining_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) {
|
||||
// CHECK: %[[RESULT0:.*]] = "xla_hlo.convert"(%arg0) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
|
||||
// CHECK: %[[RESULT1:.*]] = "xla_hlo.batch_norm_inference"(%[[RESULT0]], %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
|
||||
%0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
|
||||
// CHECK-NEXT: "xla_hlo.convert"(%[[RESULT1]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
|
||||
return %0#0 : tensor<8x8x8x8xbf16>
|
||||
// CHECK-SAME: ([[X:%.*]]: tensor<8x8x8x8xbf16>, [[SCALE:%.*]]: tensor<8xf32>, [[OFFSET:%.*]]: tensor<8xf32>, [[MEAN:%.*]]: tensor<8xf32>, [[VARIANCE:%.*]]: tensor<8xf32>)
|
||||
func @fusedBatchNormV3_noTraining_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) {
|
||||
// CHECK: [[CONVERT_X:%.*]] = "xla_hlo.convert"([[X]]) : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32>
|
||||
// CHECK: [[Y:%.*]] = "xla_hlo.batch_norm_inference"([[CONVERT_X]], [[SCALE]], [[OFFSET]], [[MEAN]], [[VARIANCE]]) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
|
||||
%0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>)
|
||||
// CHECK: [[Y_CONVERT:%.*]] = "xla_hlo.convert"([[Y]]) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16>
|
||||
// CHECK: [[DUMMY:%.*]] = xla_hlo.constant {value = dense<[]> : tensor<0xf32>} : tensor<*xf32>
|
||||
// CHECK: return [[Y_CONVERT]], [[MEAN]], [[VARIANCE]], [[MEAN]], [[VARIANCE]], [[DUMMY]]
|
||||
return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
// CHECK-LABEL: fusedBatchNormV3_training
|
||||
func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
|
||||
// CHECK: %[[RESULT0:.*]] = "xla_hlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
|
||||
|
||||
@@ -1421,11 +1421,26 @@ class ConvertFusedBatchNormV3Op
|
||||
|
||||
// The mean, variance, and reserved space outputs of the batch norm op are
|
||||
// not used for inference. It doesn't matter what values we provide for
|
||||
// the last 5 results.
|
||||
rewriter.replaceOp(
|
||||
op, {/*y=*/y_out, /*batch_mean=*/op.x(),
|
||||
/*batch_variance=*/op.x(), /*reserve_space_1=*/op.x(),
|
||||
/*reserve_space_2=*/op.x(), /*reserve_space_3=*/op.x()});
|
||||
// the last 5 results as long as they are of the same type. Forward
|
||||
// input mean and variance to output mean, variance, reserved_space_1 and
|
||||
// reserve_space_2. Create a constant tensor to forward to last
|
||||
// reserve_space_3 output.
|
||||
auto reserve_space_3_type = op.getResult(5).getType().cast<TensorType>();
|
||||
int num_elements = reserve_space_3_type.hasStaticShape()
|
||||
? reserve_space_3_type.getNumElements()
|
||||
: 0;
|
||||
auto const_attr_type = RankedTensorType::get(
|
||||
{num_elements}, getElementTypeOrSelf(reserve_space_3_type));
|
||||
auto dummy_const = rewriter.create<ConstOp>(
|
||||
op.getLoc(), reserve_space_3_type,
|
||||
DenseFPElementsAttr::get(const_attr_type,
|
||||
std::vector<float>(num_elements, 0)));
|
||||
rewriter.replaceOp(op, {/*y=*/y_out,
|
||||
/*batch_mean=*/op.mean(),
|
||||
/*batch_variance=*/op.variance(),
|
||||
/*reserve_space_1=*/op.mean(),
|
||||
/*reserve_space_2=*/op.variance(),
|
||||
/*reserve_space_3=*/dummy_const.getResult()});
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user