Fix stddev 0 bug in Normalization layer
PiperOrigin-RevId: 322836694 Change-Id: I22a669e19f369cba271e56b63a08ca2763a6eab8
This commit is contained in:
parent
c4a2703957
commit
d8dcead440
@ -156,7 +156,8 @@ class Normalization(base_preprocessing_layer.CombinerPreprocessingLayer):
|
||||
# broadcasts the data correctly.
|
||||
mean = array_ops.reshape(self.mean, self._broadcast_shape)
|
||||
variance = array_ops.reshape(self.variance, self._broadcast_shape)
|
||||
return (inputs - mean) / math_ops.sqrt(variance)
|
||||
return ((inputs - mean) /
|
||||
math_ops.maximum(math_ops.sqrt(variance), K.epsilon()))
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
return input_shape
|
||||
|
@ -97,6 +97,16 @@ def _get_layer_computation_test_cases():
|
||||
np.float32),
|
||||
"testcase_name":
|
||||
"3d_multiple_axis"
|
||||
}, {
|
||||
"adapt_data":
|
||||
np.zeros((3, 4)),
|
||||
"axis": -1,
|
||||
"test_data":
|
||||
np.zeros((3, 4)),
|
||||
"expected":
|
||||
np.zeros((3, 4)),
|
||||
"testcase_name":
|
||||
"zero_variance"
|
||||
})
|
||||
|
||||
crossed_test_cases = []
|
||||
|
Loading…
Reference in New Issue
Block a user