Make axis handling for Normalization more robust.
PiperOrigin-RevId: 316898233
Change-Id: I6888216ed21c4d2a482772fb2a314160750185b6
@@ -53,10 +53,13 @@ class Normalization(CombinerPreprocessingLayer):
 
   Attributes:
-    axis: Integer or tuple of integers, the axis or axes that should be
-      normalized (typically the features axis). We will normalize each element
-      in the specified axis. If set to 'None', the layer will perform scalar
-      normalization (diving the input by a single scalar value). 0 (the batch
-      axis) is not allowed.
+    axis: Integer or tuple of integers, the axis or axes that should be
+      "kept". These axes are not summed over when calculating the
+      normalization statistics. By default the last axis, the `features`
+      axis, is kept and any `space` or `time` axes are summed. Each element
+      in the axes that are kept is normalized independently. If `axis` is
+      set to 'None', the layer will perform scalar normalization (dividing
+      the input by a single scalar value). The `batch` axis, 0, is always
+      summed over (`axis=0` is not allowed).
 
   Examples:
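For intuition, here is a minimal NumPy sketch (not part of the commit; the
`normalize` helper is hypothetical) of what "kept" axes mean: statistics are
computed by reducing over every axis not listed in `axis`, so each element
along the kept axes is normalized with its own mean and variance.

import numpy as np

def normalize(values, axis=(-1,)):
  # Reduce over every axis that is not "kept".
  kept = tuple(a % values.ndim for a in axis)
  reduction_axes = tuple(i for i in range(values.ndim) if i not in kept)
  mean = np.mean(values, axis=reduction_axes, keepdims=True)
  std = np.std(values, axis=reduction_axes, keepdims=True)
  return (values - mean) / std

data = np.array([[0., 1.], [2., 5.]])  # shape [batch=2, features=2]
print(normalize(data))  # each feature column is normalized independently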
@@ -78,10 +81,18 @@ class Normalization(CombinerPreprocessingLayer):
     # time, the dtype value will change to reflect it.
     dtype = dtype or K.floatx()
 
+    # Standardize `axis` to a tuple.
+    if axis is None:
+      axis = ()
+    elif isinstance(axis, int):
+      axis = (axis,)
+    else:
+      axis = tuple(axis)
+
     super(Normalization, self).__init__(
         combiner=_NormalizingCombiner(axis), dtype=dtype, **kwargs)
 
-    if axis == 0:
+    if 0 in axis:
       raise ValueError('The argument \'axis\' may not be 0.')
 
     self.axis = axis
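Standardizing early is what lets the zero check become a simple membership
test: once `axis` is always a tuple, `0 in axis` catches `axis=0` and
`axis=(-1, 0)` alike, while `axis=None` becomes the empty tuple and passes.
A standalone sketch of the same logic (the `standardize_axis` name is
illustrative, not from the commit):

def standardize_axis(axis):
  # Mirror the constructor: always return a tuple.
  if axis is None:
    return ()
  elif isinstance(axis, int):
    return (axis,)
  return tuple(axis)

for raw in (None, -1, 0, (-1, 0), [1, 2]):
  std = standardize_axis(raw)
  print(raw, '->', std, '| rejected:', 0 in std)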
@@ -90,18 +101,27 @@ class Normalization(CombinerPreprocessingLayer):
     input_shape = tensor_shape.TensorShape(input_shape).as_list()
     if len(input_shape) == 1:
       input_shape = input_shape + [1]
 
+    ndim = len(input_shape)
+
+    # Sort `self.axis` to avoid transposing `mean_and_var_shape`.
+    # Negative axes are not sortable until you know the number of dimensions.
+    original_axis = self.axis
+    self.axis = tuple(sorted(self.axis,
+                             key=lambda a: a if a >= 0 else ndim + a))
+
+    if any(a < 1-ndim for a in self.axis) or any(a >= ndim for a in self.axis):
+      raise ValueError('All `axis` values must be in '
+                       'the range [1-ndim, ndim-1].\n'
+                       'Got:\n'
+                       '    ndim: {}\n'
+                       '    axis: {}'.format(ndim, original_axis))
+
     self._broadcast_shape = [1 for _ in range(len(input_shape))]
-    if isinstance(self.axis, (tuple, list)):
-      mean_and_var_shape = []
-      for i in self.axis:
-        mean_and_var_shape.append(input_shape[i])
-        self._broadcast_shape[i] = input_shape[i]
-    else:
-      if self.axis is None:
-        mean_and_var_shape = ()
-      else:
-        mean_and_var_shape = input_shape[self.axis]
-        self._broadcast_shape[self.axis] = input_shape[self.axis]
+    mean_and_var_shape = []
+    for i in self.axis:
+      mean_and_var_shape.append(input_shape[i])
+      self._broadcast_shape[i] = input_shape[i]
 
     # count is not used in this class's call() method, but is used to re-create
     # the accumulator during multiple calls to 'adapt'.
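The sort key `a if a >= 0 else ndim + a` orders mixed positive and negative
axes by the dimension they actually index, so `mean_and_var_shape` is built
in a consistent order no matter how the caller spelled the axes. A quick
sketch of just the sorting step, assuming `ndim = 3`:

ndim = 3
for axis in [(1, -1), (-1, 1), (2, 1)]:
  ordered = tuple(sorted(axis, key=lambda a: a if a >= 0 else ndim + a))
  print(axis, '->', ordered)  # every ordering resolves to axis 1 before axis 2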
@@ -179,11 +199,13 @@ class _NormalizingCombiner(Combiner):
     if values.ndim == 1:
       values = np.expand_dims(values, 1)
 
+    # `np.delete` ignores negative indexes, so use a mask to delete items.
+    axis_mask = np.ones([values.ndim], dtype=bool)
+    axis_mask[np.array(self.axis, dtype=np.int32)] = False
+
     # This is the shape of all reduced axes (not specified in 'axis').
-    if self.axis is None:
-      reduction_counts = values.shape
-    else:
-      reduction_counts = np.delete(values.shape, self.axis)
+    reduction_counts = np.array(values.shape)[axis_mask]
 
     # We get the number of elements that will be reduced by multiplying all
     # values of 'shape' corresponding to the reduced axes.
     count = np.prod(reduction_counts, dtype=np.int64)
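The boolean mask sidesteps the negative-index problem because NumPy's fancy
assignment does accept negative positions; after the assignment, the mask
selects exactly the reduced dimensions. A small sketch of the technique
(shapes and axes here are illustrative):

import numpy as np

shape = (4, 2, 3)   # e.g. [batch, space, features]
axis = (-1,)        # keep the features axis
axis_mask = np.ones([len(shape)], dtype=bool)
axis_mask[np.array(axis, dtype=np.int32)] = False  # negative index works here
print(np.array(shape)[axis_mask])           # [4 2] -- the reduced dimensions
print(np.prod(np.array(shape)[axis_mask]))  # 8 elements reduced per feature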
@@ -191,10 +213,7 @@ class _NormalizingCombiner(Combiner):
     # We want to reduce across dimensions except those specified in 'axis'
     # when using np.mean or np.variance; create the tuple of axes to reduce
     # over here.
-    if self.axis is None:
-      reduction_axes = None
-    else:
-      reduction_axes = tuple(np.delete(range(values.ndim), self.axis))
+    reduction_axes = tuple(np.arange(values.ndim)[axis_mask])
 
     mean = np.mean(values, axis=reduction_axes, dtype=np.float64)
     variance = np.var(values, axis=reduction_axes, dtype=np.float64)
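Reusing the same mask to build `reduction_axes` keeps the element count and
the statistics consistent, and wrapping the result in `tuple(...)` gives
`np.mean`/`np.var` the conventional multi-axis form. Continuing the sketch
above (illustrative values):

import numpy as np

values = np.arange(24, dtype=np.float64).reshape(4, 2, 3)
axis_mask = np.array([True, True, False])        # keep the last axis
reduction_axes = tuple(np.arange(values.ndim)[axis_mask])
print(reduction_axes)                            # (0, 1)
mean = np.mean(values, axis=reduction_axes)      # one mean per feature
variance = np.var(values, axis=reduction_axes)   # one variance per feature
print(mean.shape, variance.shape)                # (3,) (3,)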
@@ -275,6 +275,49 @@ class NormalizationTest(keras_parameterized.TestCase,
     if context.executing_eagerly():
       self.assertAllClose(output.numpy(), [[-1], [1], [-1], [1]])
 
+  @parameterized.parameters(
+      {"axis": 0},
+      {"axis": (-1, 0)},
+  )
+  def test_zeros_fail_init(self, axis):
+    cls = get_layer_class()
+    with self.assertRaisesRegex(ValueError,
+                                "The argument 'axis' may not be 0."):
+      cls(axis=axis)
+
+  @parameterized.parameters(
+      # Out of bounds
+      {"axis": 3},
+      {"axis": -3},
+      # In a tuple
+      {"axis": (1, 3)},
+      {"axis": (1, -3)},
+  )
+  def test_bad_axis_fail_build(self, axis):
+    cls = get_layer_class()
+    layer = cls(axis=axis)
+    with self.assertRaisesRegex(ValueError,
+                                r"in the range \[1-ndim, ndim-1\]."):
+      layer.build([None, 2, 3])
+
+  @parameterized.parameters(
+      # Results should be identical no matter how the axes are specified (3d).
+      {"axis": (1, 2)},
+      {"axis": (2, 1)},
+      {"axis": (1, -1)},
+      {"axis": (-1, 1)},
+  )
+  def test_axis_permutations(self, axis):
+    cls = get_layer_class()
+    layer = cls(axis=axis)
+    # data.shape = [2, 2, 3]
+    data = np.array([[[0., 1., 2.], [0., 2., 6.]],
+                     [[2., 3., 4.], [3., 6., 10.]]])
+    expect = np.array([[[-1., -1., -1.], [-1., -1., -1.]],
+                       [[1., 1., 1.], [1., 1., 1.]]])
+    layer.adapt(data)
+    self.assertAllClose(expect, layer(data))
+
+
 if __name__ == "__main__":
   test.main()