From e7a8ed51ee3c2010272a9a5718bed7fc356f29fc Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Mon, 24 Jun 2019 09:58:47 -0700 Subject: [PATCH] Use custom FusedBatchNorm kernel only if there is non-empty side input or activation PiperOrigin-RevId: 254778511 --- .../core/kernels/fused_batch_norm_op.cc | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc index d2d6c20259e..aeae8f53403 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/fused_batch_norm_op.cc @@ -523,23 +523,28 @@ struct FusedBatchNorm { // In inference mode we use custom CUDA kernel, because cuDNN does not // support side input and activations for inference. - if (!is_training) { + const bool has_side_input = side_input.dim_size(0) != 0; + const bool has_activation = + activation_mode != FusedBatchNormActivationMode::kIdentity; + + if (!is_training && (has_side_input || has_activation)) { FusedBatchNormInferenceFunctor inference_functor; - if (side_input.dim_size(0) == 0) { + + if (has_side_input) { + inference_functor(context, tensor_format, x.tensor(), + scale.vec(), offset.vec(), + estimated_mean.vec(), estimated_variance.vec(), + side_input.tensor(), epsilon, activation_mode, + y->tensor()); + } else { typename TTypes::ConstTensor empty_tensor(nullptr, 0, 0, 0, 0); inference_functor(context, tensor_format, x.tensor(), scale.vec(), offset.vec(), estimated_mean.vec(), estimated_variance.vec(), empty_tensor, epsilon, activation_mode, y->tensor()); - - } else { - inference_functor(context, tensor_format, x.tensor(), - scale.vec(), offset.vec(), - estimated_mean.vec(), estimated_variance.vec(), - side_input.tensor(), epsilon, activation_mode, - y->tensor()); } + return; }