Fix stddev 0 bug in Normalization layer
PiperOrigin-RevId: 322836694 Change-Id: I22a669e19f369cba271e56b63a08ca2763a6eab8
This commit is contained in:
parent
c4a2703957
commit
d8dcead440
@ -156,7 +156,8 @@ class Normalization(base_preprocessing_layer.CombinerPreprocessingLayer):
|
|||||||
# broadcasts the data correctly.
|
# broadcasts the data correctly.
|
||||||
mean = array_ops.reshape(self.mean, self._broadcast_shape)
|
mean = array_ops.reshape(self.mean, self._broadcast_shape)
|
||||||
variance = array_ops.reshape(self.variance, self._broadcast_shape)
|
variance = array_ops.reshape(self.variance, self._broadcast_shape)
|
||||||
return (inputs - mean) / math_ops.sqrt(variance)
|
return ((inputs - mean) /
|
||||||
|
math_ops.maximum(math_ops.sqrt(variance), K.epsilon()))
|
||||||
|
|
||||||
def compute_output_shape(self, input_shape):
|
def compute_output_shape(self, input_shape):
|
||||||
return input_shape
|
return input_shape
|
||||||
|
@ -97,6 +97,16 @@ def _get_layer_computation_test_cases():
|
|||||||
np.float32),
|
np.float32),
|
||||||
"testcase_name":
|
"testcase_name":
|
||||||
"3d_multiple_axis"
|
"3d_multiple_axis"
|
||||||
|
}, {
|
||||||
|
"adapt_data":
|
||||||
|
np.zeros((3, 4)),
|
||||||
|
"axis": -1,
|
||||||
|
"test_data":
|
||||||
|
np.zeros((3, 4)),
|
||||||
|
"expected":
|
||||||
|
np.zeros((3, 4)),
|
||||||
|
"testcase_name":
|
||||||
|
"zero_variance"
|
||||||
})
|
})
|
||||||
|
|
||||||
crossed_test_cases = []
|
crossed_test_cases = []
|
||||||
|
Loading…
Reference in New Issue
Block a user