Fix fp16 FusedBatchNorm CPU crash if batch dimension is 1.
The GPU kernel outputs NaNs for the variance in this case, which is also incorrect, but better than crashing. PiperOrigin-RevId: 317331280 Change-Id: Iea4e5a3337625796c50244e51d7ccb4b89f4c3e4
This commit is contained in:
parent
ef1cabc7a8
commit
e0780ef031
tensorflow
@ -57,7 +57,8 @@ struct ReduceOuterDimensions {
|
||||
if (1 == outer_dim) {
|
||||
// Nothing to do but passing input to output.
|
||||
output->template flat<OutputT>() =
|
||||
input.template flat<OutputT>().reshape(output_dims);
|
||||
input.template flat<InputT>().template cast<OutputT>().reshape(
|
||||
output_dims);
|
||||
return;
|
||||
}
|
||||
|
||||
@ -226,7 +227,8 @@ struct ReduceMiddleDimensions {
|
||||
if ((1 == inner_dim * outer_dim)) {
|
||||
// Nothing to do.
|
||||
output->template flat<OutputT>() =
|
||||
input.template flat<OutputT>().reshape(output_dims);
|
||||
input.template flat<InputT>().template cast<OutputT>().reshape(
|
||||
output_dims);
|
||||
return;
|
||||
} else if (1 == inner_dim) {
|
||||
// Equivalent to ReduceOuterDimensions.
|
||||
|
@ -375,9 +375,10 @@ class BatchNormalizationTest(test.TestCase):
|
||||
self.assertLess(err_grad_x_2, err_tolerance)
|
||||
self.assertLess(err_grad_scale, err_tolerance)
|
||||
|
||||
def _runtests(self, x_shape, is_training, gradient_test=False):
|
||||
def _runtests(self, x_shape, is_training, gradient_test=False,
|
||||
cpu_only=False):
|
||||
use_gpu_vals = [False]
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
if test.is_gpu_available(cuda_only=True) and not cpu_only:
|
||||
use_gpu_vals += [True]
|
||||
factors = [1.0, 0.6]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
@ -438,6 +439,11 @@ class BatchNormalizationTest(test.TestCase):
|
||||
x_shape = [0, 131, 127, 6]
|
||||
self._runtests(x_shape, False)
|
||||
|
||||
def testInferenceShape6(self):
|
||||
x_shape = [1, 1, 1, 1]
|
||||
# GPU kernel doesn't properly handle case where non-channel dimensions are 1
|
||||
self._runtests(x_shape, False, cpu_only=True)
|
||||
|
||||
def testTrainingShape1(self):
|
||||
x_shape = [1, 1, 6, 1]
|
||||
self._runtests(x_shape, True)
|
||||
@ -459,6 +465,11 @@ class BatchNormalizationTest(test.TestCase):
|
||||
x_shape = [0, 131, 127, 6]
|
||||
self._runtests(x_shape, True)
|
||||
|
||||
def testTrainingShape6(self):
|
||||
x_shape = [1, 1, 1, 1]
|
||||
# GPU kernel doesn't properly handle case where non-channel dimensions are 1
|
||||
self._runtests(x_shape, True, cpu_only=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBatchNormGradInferenceShape1(self):
|
||||
x_shape = [1, 1, 6, 1]
|
||||
@ -485,6 +496,13 @@ class BatchNormalizationTest(test.TestCase):
|
||||
x_shape = [0, 7, 11, 4]
|
||||
self._runtests(x_shape, is_training=False, gradient_test=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBatchNormGradInferenceShape6(self):
|
||||
x_shape = [1, 1, 1, 1]
|
||||
# GPU kernel doesn't properly handle case where non-channel dimensions are 1
|
||||
self._runtests(x_shape, is_training=False, gradient_test=True,
|
||||
cpu_only=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBatchNormGradTrainingShape1(self):
|
||||
x_shape = [1, 1, 6, 1]
|
||||
@ -511,6 +529,12 @@ class BatchNormalizationTest(test.TestCase):
|
||||
x_shape = [0, 7, 11, 4]
|
||||
self._runtests(x_shape, is_training=True, gradient_test=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBatchNormGradTrainingShape6(self):
|
||||
x_shape = [1, 1, 1, 1]
|
||||
# GPU kernel doesn't properly handle case where non-channel dimensions are 1
|
||||
self._runtests(x_shape, is_training=True, gradient_test=True, cpu_only=True)
|
||||
|
||||
def _testBatchNormGradGrad(self, config):
|
||||
shape = config['shape']
|
||||
err_tolerance = config['err_tolerance']
|
||||
|
Loading…
Reference in New Issue
Block a user