Fix stddev 0 bug in Normalization layer

PiperOrigin-RevId: 322836694
Change-Id: I22a669e19f369cba271e56b63a08ca2763a6eab8
This commit is contained in:
Francois Chollet 2020-07-23 12:06:49 -07:00 committed by TensorFlower Gardener
parent c4a2703957
commit d8dcead440
2 changed files with 12 additions and 1 deletions

View File

@ -156,7 +156,8 @@ class Normalization(base_preprocessing_layer.CombinerPreprocessingLayer):
# broadcasts the data correctly.
mean = array_ops.reshape(self.mean, self._broadcast_shape)
variance = array_ops.reshape(self.variance, self._broadcast_shape)
return (inputs - mean) / math_ops.sqrt(variance)
return ((inputs - mean) /
math_ops.maximum(math_ops.sqrt(variance), K.epsilon()))
def compute_output_shape(self, input_shape):
return input_shape

View File

@ -97,6 +97,16 @@ def _get_layer_computation_test_cases():
np.float32),
"testcase_name":
"3d_multiple_axis"
}, {
"adapt_data":
np.zeros((3, 4)),
"axis": -1,
"test_data":
np.zeros((3, 4)),
"expected":
np.zeros((3, 4)),
"testcase_name":
"zero_variance"
})
crossed_test_cases = []