Use the custom FusedBatchNorm kernel only if there is a non-empty side input or a non-identity activation

PiperOrigin-RevId: 254778511
This commit is contained in:
Eugene Zhulenev 2019-06-24 09:58:47 -07:00 committed by TensorFlower Gardener
parent 5b9ebedc98
commit e7a8ed51ee

View File

@@ -523,23 +523,28 @@ struct FusedBatchNorm<GPUDevice, T, U> {
// In inference mode we use custom CUDA kernel, because cuDNN does not
// support side input and activations for inference.
if (!is_training) {
const bool has_side_input = side_input.dim_size(0) != 0;
const bool has_activation =
activation_mode != FusedBatchNormActivationMode::kIdentity;
if (!is_training && (has_side_input || has_activation)) {
FusedBatchNormInferenceFunctor<GPUDevice, T, U> inference_functor;
if (side_input.dim_size(0) == 0) {
if (has_side_input) {
inference_functor(context, tensor_format, x.tensor<T, 4>(),
scale.vec<U>(), offset.vec<U>(),
estimated_mean.vec<U>(), estimated_variance.vec<U>(),
side_input.tensor<T, 4>(), epsilon, activation_mode,
y->tensor<T, 4>());
} else {
typename TTypes<T, 4>::ConstTensor empty_tensor(nullptr, 0, 0, 0, 0);
inference_functor(context, tensor_format, x.tensor<T, 4>(),
scale.vec<U>(), offset.vec<U>(),
estimated_mean.vec<U>(), estimated_variance.vec<U>(),
empty_tensor, epsilon, activation_mode,
y->tensor<T, 4>());
} else {
inference_functor(context, tensor_format, x.tensor<T, 4>(),
scale.vec<U>(), offset.vec<U>(),
estimated_mean.vec<U>(), estimated_variance.vec<U>(),
side_input.tensor<T, 4>(), epsilon, activation_mode,
y->tensor<T, 4>());
}
return;
}