Fix fp16 FusedBatchNorm CPU crash if batch dimension is 1.

The GPU kernel outputs NaNs for the variance in this case, which is also incorrect, but better than crashing.

PiperOrigin-RevId: 317331280
Change-Id: Iea4e5a3337625796c50244e51d7ccb4b89f4c3e4
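For context, a minimal repro sketch of the failing case, assuming the TF 2.x Python API (shapes and values are illustrative): a float16 input whose non-channel dimensions are all 1, run through fused batch norm on the CPU.

# Minimal repro sketch (assumes the TF 2.x API; values are illustrative).
# Before this fix, the 1x1x1x1 float16 case crashed on the CPU; on the GPU
# it instead returns NaN for the variance.
import numpy as np
import tensorflow as tf

x = tf.constant(np.random.rand(1, 1, 1, 1), dtype=tf.float16)
scale = tf.constant([1.0], dtype=tf.float32)   # scale/offset stay float32
offset = tf.constant([0.0], dtype=tf.float32)

with tf.device('/CPU:0'):
  y, mean, variance = tf.compat.v1.nn.fused_batch_norm(
      x, scale, offset, is_training=True)
print(variance)  # one sample per channel, so the true variance is 0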
Author: Reed Wanderman-Milne (committed by TensorFlower Gardener)
Date: 2020-06-19 10:21:14 -07:00
Commit: e0780ef031 (parent: ef1cabc7a8)
2 changed files with 30 additions and 4 deletions

@@ -57,7 +57,8 @@ struct ReduceOuterDimensions {
     if (1 == outer_dim) {
       // Nothing to do but passing input to output.
       output->template flat<OutputT>() =
-          input.template flat<OutputT>().reshape(output_dims);
+          input.template flat<InputT>().template cast<OutputT>().reshape(
+              output_dims);
       return;
     }
@@ -226,7 +227,8 @@ struct ReduceMiddleDimensions {
     if ((1 == inner_dim * outer_dim)) {
       // Nothing to do.
       output->template flat<OutputT>() =
-          input.template flat<OutputT>().reshape(output_dims);
+          input.template flat<InputT>().template cast<OutputT>().reshape(
+              output_dims);
       return;
     } else if (1 == inner_dim) {
       // Equivalent to ReduceOuterDimensions.
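Both hunks fix the same bug in the "nothing to reduce" fast paths: when InputT (float16) and OutputT (float32) differ, the old code called flat<OutputT>() on the input, asking for a float view of a float16 tensor instead of converting each element; that dtype mismatch is what crashed the CPU kernel. A NumPy sketch of the difference between reinterpreting and casting (illustrative only; the actual fix is the Eigen cast<OutputT>() shown above):

# Reinterpreting fp16 bytes as fp32 is not a conversion: 4 halves become
# 2 meaningless floats. Casting converts each element. The fix switches
# the no-op reduction path from the former behavior to the latter.
import numpy as np

x16 = np.arange(4, dtype=np.float16)
reinterpreted = x16.view(np.float32)   # 2 garbage values from 8 raw bytes
cast = x16.astype(np.float32)          # [0., 1., 2., 3.], as intended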


@@ -375,9 +375,10 @@ class BatchNormalizationTest(test.TestCase):
     self.assertLess(err_grad_x_2, err_tolerance)
     self.assertLess(err_grad_scale, err_tolerance)
 
-  def _runtests(self, x_shape, is_training, gradient_test=False):
+  def _runtests(self, x_shape, is_training, gradient_test=False,
+                cpu_only=False):
     use_gpu_vals = [False]
-    if test.is_gpu_available(cuda_only=True):
+    if test.is_gpu_available(cuda_only=True) and not cpu_only:
       use_gpu_vals += [True]
     factors = [1.0, 0.6]
     for dtype in [np.float16, np.float32]:
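The new cpu_only flag simply keeps the GPU out of the device sweep. A minimal sketch of the gating in isolation (device_configs is a hypothetical helper name; the real logic is inlined in _runtests as shown above):

# Sketch of the device sweep gating; device_configs is a hypothetical
# name, the real logic lives inline in _runtests.
from tensorflow.python.platform import test

def device_configs(cpu_only=False):
  use_gpu_vals = [False]
  # Add a GPU pass only when CUDA is available and the case is not CPU-only.
  if test.is_gpu_available(cuda_only=True) and not cpu_only:
    use_gpu_vals += [True]
  return use_gpu_vals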
@@ -438,6 +439,11 @@ class BatchNormalizationTest(test.TestCase):
     x_shape = [0, 131, 127, 6]
     self._runtests(x_shape, False)
 
+  def testInferenceShape6(self):
+    x_shape = [1, 1, 1, 1]
+    # GPU kernel doesn't properly handle case where non-channel dimensions are 1
+    self._runtests(x_shape, False, cpu_only=True)
+
   def testTrainingShape1(self):
     x_shape = [1, 1, 6, 1]
     self._runtests(x_shape, True)
@@ -459,6 +465,11 @@ class BatchNormalizationTest(test.TestCase):
     x_shape = [0, 131, 127, 6]
     self._runtests(x_shape, True)
 
+  def testTrainingShape6(self):
+    x_shape = [1, 1, 1, 1]
+    # GPU kernel doesn't properly handle case where non-channel dimensions are 1
+    self._runtests(x_shape, True, cpu_only=True)
+
   @test_util.run_deprecated_v1
   def testBatchNormGradInferenceShape1(self):
     x_shape = [1, 1, 6, 1]
@@ -485,6 +496,13 @@ class BatchNormalizationTest(test.TestCase):
     x_shape = [0, 7, 11, 4]
     self._runtests(x_shape, is_training=False, gradient_test=True)
 
+  @test_util.run_deprecated_v1
+  def testBatchNormGradInferenceShape6(self):
+    x_shape = [1, 1, 1, 1]
+    # GPU kernel doesn't properly handle case where non-channel dimensions are 1
+    self._runtests(x_shape, is_training=False, gradient_test=True,
+                   cpu_only=True)
+
   @test_util.run_deprecated_v1
   def testBatchNormGradTrainingShape1(self):
     x_shape = [1, 1, 6, 1]
@@ -511,6 +529,12 @@ class BatchNormalizationTest(test.TestCase):
     x_shape = [0, 7, 11, 4]
     self._runtests(x_shape, is_training=True, gradient_test=True)
 
+  @test_util.run_deprecated_v1
+  def testBatchNormGradTrainingShape6(self):
+    x_shape = [1, 1, 1, 1]
+    # GPU kernel doesn't properly handle case where non-channel dimensions are 1
+    self._runtests(x_shape, is_training=True, gradient_test=True, cpu_only=True)
+
   def _testBatchNormGradGrad(self, config):
     shape = config['shape']
     err_tolerance = config['err_tolerance']