Fix fp16 FusedBatchNorm CPU crash if batch dimension is 1.
The GPU kernel outputs NaNs for the variance in this case, which is also incorrect, but is better than crashing.

PiperOrigin-RevId: 317331280
Change-Id: Iea4e5a3337625796c50244e51d7ccb4b89f4c3e4
This commit is contained in:
parent
ef1cabc7a8
commit
e0780ef031
@ -57,7 +57,8 @@ struct ReduceOuterDimensions {
|
|||||||
if (1 == outer_dim) {
|
if (1 == outer_dim) {
|
||||||
// Nothing to do but passing input to output.
|
// Nothing to do but passing input to output.
|
||||||
output->template flat<OutputT>() =
|
output->template flat<OutputT>() =
|
||||||
input.template flat<OutputT>().reshape(output_dims);
|
input.template flat<InputT>().template cast<OutputT>().reshape(
|
||||||
|
output_dims);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -226,7 +227,8 @@ struct ReduceMiddleDimensions {
|
|||||||
if ((1 == inner_dim * outer_dim)) {
|
if ((1 == inner_dim * outer_dim)) {
|
||||||
// Nothing to do.
|
// Nothing to do.
|
||||||
output->template flat<OutputT>() =
|
output->template flat<OutputT>() =
|
||||||
input.template flat<OutputT>().reshape(output_dims);
|
input.template flat<InputT>().template cast<OutputT>().reshape(
|
||||||
|
output_dims);
|
||||||
return;
|
return;
|
||||||
} else if (1 == inner_dim) {
|
} else if (1 == inner_dim) {
|
||||||
// Equivalent to ReduceOuterDimensions.
|
// Equivalent to ReduceOuterDimensions.
|
||||||
|
@ -375,9 +375,10 @@ class BatchNormalizationTest(test.TestCase):
|
|||||||
self.assertLess(err_grad_x_2, err_tolerance)
|
self.assertLess(err_grad_x_2, err_tolerance)
|
||||||
self.assertLess(err_grad_scale, err_tolerance)
|
self.assertLess(err_grad_scale, err_tolerance)
|
||||||
|
|
||||||
def _runtests(self, x_shape, is_training, gradient_test=False):
|
def _runtests(self, x_shape, is_training, gradient_test=False,
|
||||||
|
cpu_only=False):
|
||||||
use_gpu_vals = [False]
|
use_gpu_vals = [False]
|
||||||
if test.is_gpu_available(cuda_only=True):
|
if test.is_gpu_available(cuda_only=True) and not cpu_only:
|
||||||
use_gpu_vals += [True]
|
use_gpu_vals += [True]
|
||||||
factors = [1.0, 0.6]
|
factors = [1.0, 0.6]
|
||||||
for dtype in [np.float16, np.float32]:
|
for dtype in [np.float16, np.float32]:
|
||||||
@ -438,6 +439,11 @@ class BatchNormalizationTest(test.TestCase):
|
|||||||
x_shape = [0, 131, 127, 6]
|
x_shape = [0, 131, 127, 6]
|
||||||
self._runtests(x_shape, False)
|
self._runtests(x_shape, False)
|
||||||
|
|
||||||
|
def testInferenceShape6(self):
|
||||||
|
x_shape = [1, 1, 1, 1]
|
||||||
|
# GPU kernel doesn't properly handle case where non-channel dimensions are 1
|
||||||
|
self._runtests(x_shape, False, cpu_only=True)
|
||||||
|
|
||||||
def testTrainingShape1(self):
|
def testTrainingShape1(self):
|
||||||
x_shape = [1, 1, 6, 1]
|
x_shape = [1, 1, 6, 1]
|
||||||
self._runtests(x_shape, True)
|
self._runtests(x_shape, True)
|
||||||
@ -459,6 +465,11 @@ class BatchNormalizationTest(test.TestCase):
|
|||||||
x_shape = [0, 131, 127, 6]
|
x_shape = [0, 131, 127, 6]
|
||||||
self._runtests(x_shape, True)
|
self._runtests(x_shape, True)
|
||||||
|
|
||||||
|
def testTrainingShape6(self):
|
||||||
|
x_shape = [1, 1, 1, 1]
|
||||||
|
# GPU kernel doesn't properly handle case where non-channel dimensions are 1
|
||||||
|
self._runtests(x_shape, True, cpu_only=True)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testBatchNormGradInferenceShape1(self):
|
def testBatchNormGradInferenceShape1(self):
|
||||||
x_shape = [1, 1, 6, 1]
|
x_shape = [1, 1, 6, 1]
|
||||||
@ -485,6 +496,13 @@ class BatchNormalizationTest(test.TestCase):
|
|||||||
x_shape = [0, 7, 11, 4]
|
x_shape = [0, 7, 11, 4]
|
||||||
self._runtests(x_shape, is_training=False, gradient_test=True)
|
self._runtests(x_shape, is_training=False, gradient_test=True)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testBatchNormGradInferenceShape6(self):
|
||||||
|
x_shape = [1, 1, 1, 1]
|
||||||
|
# GPU kernel doesn't properly handle case where non-channel dimensions are 1
|
||||||
|
self._runtests(x_shape, is_training=False, gradient_test=True,
|
||||||
|
cpu_only=True)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testBatchNormGradTrainingShape1(self):
|
def testBatchNormGradTrainingShape1(self):
|
||||||
x_shape = [1, 1, 6, 1]
|
x_shape = [1, 1, 6, 1]
|
||||||
@ -511,6 +529,12 @@ class BatchNormalizationTest(test.TestCase):
|
|||||||
x_shape = [0, 7, 11, 4]
|
x_shape = [0, 7, 11, 4]
|
||||||
self._runtests(x_shape, is_training=True, gradient_test=True)
|
self._runtests(x_shape, is_training=True, gradient_test=True)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testBatchNormGradTrainingShape6(self):
|
||||||
|
x_shape = [1, 1, 1, 1]
|
||||||
|
# GPU kernel doesn't properly handle case where non-channel dimensions are 1
|
||||||
|
self._runtests(x_shape, is_training=True, gradient_test=True, cpu_only=True)
|
||||||
|
|
||||||
def _testBatchNormGradGrad(self, config):
|
def _testBatchNormGradGrad(self, config):
|
||||||
shape = config['shape']
|
shape = config['shape']
|
||||||
err_tolerance = config['err_tolerance']
|
err_tolerance = config['err_tolerance']
|
||||||
|
Loading…
x
Reference in New Issue
Block a user