Use custom FusedBatchNorm kernel only if there is non-empty side input or activation
PiperOrigin-RevId: 254778511
This commit is contained in:
parent
5b9ebedc98
commit
e7a8ed51ee
@ -523,23 +523,28 @@ struct FusedBatchNorm<GPUDevice, T, U> {
|
||||
|
||||
// In inference mode we use custom CUDA kernel, because cuDNN does not
|
||||
// support side input and activations for inference.
|
||||
if (!is_training) {
|
||||
const bool has_side_input = side_input.dim_size(0) != 0;
|
||||
const bool has_activation =
|
||||
activation_mode != FusedBatchNormActivationMode::kIdentity;
|
||||
|
||||
if (!is_training && (has_side_input || has_activation)) {
|
||||
FusedBatchNormInferenceFunctor<GPUDevice, T, U> inference_functor;
|
||||
if (side_input.dim_size(0) == 0) {
|
||||
|
||||
if (has_side_input) {
|
||||
inference_functor(context, tensor_format, x.tensor<T, 4>(),
|
||||
scale.vec<U>(), offset.vec<U>(),
|
||||
estimated_mean.vec<U>(), estimated_variance.vec<U>(),
|
||||
side_input.tensor<T, 4>(), epsilon, activation_mode,
|
||||
y->tensor<T, 4>());
|
||||
} else {
|
||||
typename TTypes<T, 4>::ConstTensor empty_tensor(nullptr, 0, 0, 0, 0);
|
||||
inference_functor(context, tensor_format, x.tensor<T, 4>(),
|
||||
scale.vec<U>(), offset.vec<U>(),
|
||||
estimated_mean.vec<U>(), estimated_variance.vec<U>(),
|
||||
empty_tensor, epsilon, activation_mode,
|
||||
y->tensor<T, 4>());
|
||||
|
||||
} else {
|
||||
inference_functor(context, tensor_format, x.tensor<T, 4>(),
|
||||
scale.vec<U>(), offset.vec<U>(),
|
||||
estimated_mean.vec<U>(), estimated_variance.vec<U>(),
|
||||
side_input.tensor<T, 4>(), epsilon, activation_mode,
|
||||
y->tensor<T, 4>());
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user