Make axis handling for Normalization more robust.
PiperOrigin-RevId: 316898233 Change-Id: I6888216ed21c4d2a482772fb2a314160750185b6
This commit is contained in:
parent
78ecbb0481
commit
fc296acdc1
@ -53,10 +53,13 @@ class Normalization(CombinerPreprocessingLayer):
|
||||
|
||||
Attributes:
|
||||
axis: Integer or tuple of integers, the axis or axes that should be
|
||||
normalized (typically the features axis). We will normalize each element
|
||||
in the specified axis. If set to 'None', the layer will perform scalar
|
||||
normalization (dividing the input by a single scalar value). 0 (the batch
|
||||
axis) is not allowed.
|
||||
"kept". These axes are not summed over when calculating the
|
||||
normalization statistics. By default the last axis, the `features` axis
|
||||
is kept and any `space` or `time` axes are summed. Each element in the
|
||||
axes that are kept is normalized independently. If `axis` is set to
|
||||
'None', the layer will perform scalar normalization (dividing the input
|
||||
by a single scalar value). The `batch` axis, 0, is always summed over
|
||||
(`axis=0` is not allowed).
|
||||
|
||||
Examples:
|
||||
|
||||
@ -78,10 +81,18 @@ class Normalization(CombinerPreprocessingLayer):
|
||||
# time, the dtype value will change to reflect it.
|
||||
dtype = dtype or K.floatx()
|
||||
|
||||
# Standardize `axis` to a tuple.
|
||||
if axis is None:
|
||||
axis = ()
|
||||
elif isinstance(axis, int):
|
||||
axis = (axis,)
|
||||
else:
|
||||
axis = tuple(axis)
|
||||
|
||||
super(Normalization, self).__init__(
|
||||
combiner=_NormalizingCombiner(axis), dtype=dtype, **kwargs)
|
||||
|
||||
if axis == 0:
|
||||
if 0 in axis:
|
||||
raise ValueError('The argument \'axis\' may not be 0.')
|
||||
|
||||
self.axis = axis
|
||||
@ -90,18 +101,27 @@ class Normalization(CombinerPreprocessingLayer):
|
||||
input_shape = tensor_shape.TensorShape(input_shape).as_list()
|
||||
if len(input_shape) == 1:
|
||||
input_shape = input_shape + [1]
|
||||
|
||||
ndim = len(input_shape)
|
||||
|
||||
# Sort `self.axis` to avoid transposing `mean_and_var_shape`.
|
||||
# Negative axes are not sortable until you know the number of dimensions.
|
||||
original_axis = self.axis
|
||||
self.axis = tuple(sorted(self.axis,
|
||||
key=lambda a: a if a >= 0 else ndim + a))
|
||||
|
||||
if any(a < 1-ndim for a in self.axis) or any(a >= ndim for a in self.axis):
|
||||
raise ValueError('All `axis` values must be in '
|
||||
'the range [1-ndim, ndim-1].\n'
|
||||
'Got:\n'
|
||||
' ndim: {}\n'
|
||||
' axis: {}'.format(ndim, original_axis))
|
||||
|
||||
self._broadcast_shape = [1 for _ in range(len(input_shape))]
|
||||
if isinstance(self.axis, (tuple, list)):
|
||||
mean_and_var_shape = []
|
||||
for i in self.axis:
|
||||
mean_and_var_shape.append(input_shape[i])
|
||||
self._broadcast_shape[i] = input_shape[i]
|
||||
else:
|
||||
if self.axis is None:
|
||||
mean_and_var_shape = ()
|
||||
else:
|
||||
mean_and_var_shape = input_shape[self.axis]
|
||||
self._broadcast_shape[self.axis] = input_shape[self.axis]
|
||||
mean_and_var_shape = []
|
||||
for i in self.axis:
|
||||
mean_and_var_shape.append(input_shape[i])
|
||||
self._broadcast_shape[i] = input_shape[i]
|
||||
|
||||
# count is not used in this class's call() method, but is used to re-create
|
||||
# the accumulator during multiple calls to 'adapt'.
|
||||
@ -179,11 +199,13 @@ class _NormalizingCombiner(Combiner):
|
||||
if values.ndim == 1:
|
||||
values = np.expand_dims(values, 1)
|
||||
|
||||
# `np.delete` ignores negative indexes, so use a mask to delete items.
|
||||
axis_mask = np.ones([values.ndim], dtype=bool)
|
||||
axis_mask[np.array(self.axis, dtype=np.int32)] = False
|
||||
|
||||
# This is the shape of all reduced axes (not specified in 'axis').
|
||||
if self.axis is None:
|
||||
reduction_counts = values.shape
|
||||
else:
|
||||
reduction_counts = np.delete(values.shape, self.axis)
|
||||
|
||||
reduction_counts = np.array(values.shape)[axis_mask]
|
||||
# We get the number of elements that will be reduced by multiplying all
|
||||
# values of 'shape' corresponding to the reduced axes.
|
||||
count = np.prod(reduction_counts, dtype=np.int64)
|
||||
@ -191,10 +213,7 @@ class _NormalizingCombiner(Combiner):
|
||||
# We want to reduce across dimensions except those specified in 'axis'
|
||||
# when using np.mean or np.variance; create the tuple of axes to reduce
|
||||
# over here.
|
||||
if self.axis is None:
|
||||
reduction_axes = None
|
||||
else:
|
||||
reduction_axes = tuple(np.delete(range(values.ndim), self.axis))
|
||||
reduction_axes = tuple(np.arange(values.ndim)[axis_mask])
|
||||
|
||||
mean = np.mean(values, axis=reduction_axes, dtype=np.float64)
|
||||
variance = np.var(values, axis=reduction_axes, dtype=np.float64)
|
||||
|
@ -275,6 +275,49 @@ class NormalizationTest(keras_parameterized.TestCase,
|
||||
if context.executing_eagerly():
|
||||
self.assertAllClose(output.numpy(), [[-1], [1], [-1], [1]])
|
||||
|
||||
@parameterized.parameters(
    {"axis": 0},
    {"axis": (-1, 0)},
)
def test_zeros_fail_init(self, axis):
  """Constructing the layer with axis 0, alone or inside a tuple, raises."""
  layer_cls = get_layer_class()
  # The batch axis (0) can never be a normalization axis; the constructor
  # must reject it before any build/adapt work happens.
  with self.assertRaisesRegex(ValueError,
                              "The argument 'axis' may not be 0."):
    layer_cls(axis=axis)
|
||||
|
||||
@parameterized.parameters(
    # Out of bounds
    {"axis": 3},
    {"axis": -3},
    # In a tuple
    {"axis": (1, 3)},
    {"axis": (1, -3)},
)
def test_bad_axis_fail_build(self, axis):
  """An out-of-range axis (positive or negative) is rejected at build time."""
  layer = get_layer_class()(axis=axis)
  # Validation happens in build(), once the input rank (ndim=3 here) is known.
  with self.assertRaisesRegex(ValueError,
                              r"in the range \[1-ndim, ndim-1\]."):
    layer.build([None, 2, 3])
|
||||
|
||||
@parameterized.parameters(
    # Results should be identical no matter how the axes are specified (3d).
    {"axis": (1, 2)},
    {"axis": (2, 1)},
    {"axis": (1, -1)},
    {"axis": (-1, 1)},
)
def test_axis_permutations(self, axis):
  """Equivalent axis tuples (reordered or negative) normalize identically."""
  layer = get_layer_class()(axis=axis)
  # data.shape = [2, 2, 3]
  samples = np.array([[[0., 1., 2.], [0., 2., 6.]],
                      [[2., 3., 4.], [3., 6., 10.]]])
  # Each kept position normalizes its two batch values to -1 / +1.
  expected = np.array([[[-1., -1., -1.], [-1., -1., -1.]],
                       [[1., 1., 1.], [1., 1., 1.]]])
  layer.adapt(samples)
  self.assertAllClose(expected, layer(samples))
|
||||
|
||||
|
||||
if __name__ == "__main__":
  # Run the test suite when this file is executed directly.
  test.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user