Make axis handling for Normalization more robust.

PiperOrigin-RevId: 316898233
Change-Id: I6888216ed21c4d2a482772fb2a314160750185b6
Mark Daoust authored 2020-06-17 09:00:01 -07:00, committed by TensorFlower Gardener
parent 78ecbb0481, commit fc296acdc1
2 changed files with 86 additions and 24 deletions


@@ -53,10 +53,13 @@ class Normalization(CombinerPreprocessingLayer):
 
   Attributes:
     axis: Integer or tuple of integers, the axis or axes that should be
-      normalized (typically the features axis). We will normalize each element
-      in the specified axis. If set to 'None', the layer will perform scalar
-      normalization (diving the input by a single scalar value). 0 (the batch
-      axis) is not allowed.
+      "kept". These axes are not summed over when calculating the
+      normalization statistics. By default the last axis, the `features`
+      axis, is kept and any `space` or `time` axes are summed over. Each
+      element in the axes that are kept is normalized independently. If
+      `axis` is set to 'None', the layer will perform scalar normalization
+      (dividing the input by a single scalar value). The `batch` axis, 0,
+      is always summed over (`axis=0` is not allowed).
 
   Examples:
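To make the new docstring concrete, here is a minimal numpy sketch of the "kept axes" semantics (illustrative only; the shapes and values are made up, and this code is not part of the diff):

import numpy as np

# (batch, time, features); keeping only the last axis means the statistics
# are reduced over the batch and time axes, one mean/variance per feature.
data = np.arange(24, dtype=np.float64).reshape(2, 3, 4)
mean = data.mean(axis=(0, 1))   # shape (4,)
var = data.var(axis=(0, 1))     # shape (4,)

# Each element along the kept (features) axis is normalized independently.
normalized = (data - mean) / np.sqrt(var)
print(normalized.shape)         # (2, 3, 4)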
@@ -78,10 +81,18 @@ class Normalization(CombinerPreprocessingLayer):
     # time, the dtype value will change to reflect it.
     dtype = dtype or K.floatx()
 
+    # Standardize `axis` to a tuple.
+    if axis is None:
+      axis = ()
+    elif isinstance(axis, int):
+      axis = (axis,)
+    else:
+      axis = tuple(axis)
+
     super(Normalization, self).__init__(
         combiner=_NormalizingCombiner(axis), dtype=dtype, **kwargs)
 
-    if axis == 0:
+    if 0 in axis:
       raise ValueError('The argument \'axis\' may not be 0.')
 
     self.axis = axis
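The standardization above guarantees `axis` is a tuple by the time the `0 in axis` check runs, which is what lets a bare `axis=0` and a tuple like `(-1, 0)` both be rejected (the old `axis == 0` test missed the tuple case). A standalone sketch of the same logic, using a hypothetical helper name:

def standardize_axis(axis):
  # Mirrors the __init__ logic above; `standardize_axis` is a made-up name.
  if axis is None:
    axis = ()
  elif isinstance(axis, int):
    axis = (axis,)
  else:
    axis = tuple(axis)
  if 0 in axis:
    raise ValueError('The argument \'axis\' may not be 0.')
  return axis

assert standardize_axis(None) == ()
assert standardize_axis(-1) == (-1,)
assert standardize_axis([1, 2]) == (1, 2)
# standardize_axis((-1, 0)) raises ValueError, since `0 in (-1, 0)` is True.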
@@ -90,18 +101,27 @@ class Normalization(CombinerPreprocessingLayer):
     input_shape = tensor_shape.TensorShape(input_shape).as_list()
     if len(input_shape) == 1:
       input_shape = input_shape + [1]
+
+    ndim = len(input_shape)
+
+    # Sort `self.axis` to avoid transposing `mean_and_var_shape`.
+    # Negative axes are not sortable until you know the number of dimensions.
+    original_axis = self.axis
+    self.axis = tuple(sorted(self.axis,
+                             key=lambda a: a if a >= 0 else ndim + a))
+
+    if any(a < 1-ndim for a in self.axis) or any(a >= ndim for a in self.axis):
+      raise ValueError('All `axis` values must be in '
+                       'the range [1-ndim, ndim-1].\n'
+                       'Got:\n'
+                       '    ndim: {}\n'
+                       '    axis: {}'.format(ndim, original_axis))
+
     self._broadcast_shape = [1 for _ in range(len(input_shape))]
-    if isinstance(self.axis, (tuple, list)):
-      mean_and_var_shape = []
-      for i in self.axis:
-        mean_and_var_shape.append(input_shape[i])
-        self._broadcast_shape[i] = input_shape[i]
-    else:
-      if self.axis is None:
-        mean_and_var_shape = ()
-      else:
-        mean_and_var_shape = input_shape[self.axis]
-        self._broadcast_shape[self.axis] = input_shape[self.axis]
+    mean_and_var_shape = []
+    for i in self.axis:
+      mean_and_var_shape.append(input_shape[i])
+      self._broadcast_shape[i] = input_shape[i]
 
     # count is not used in this class's call() method, but is used to re-create
     # the accumulator during multiple calls to 'adapt'.
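Tracing the new build-time bookkeeping by hand for `input_shape=[None, 2, 3]` and `axis=(-1, 1)` (a standalone walk-through with made-up values, not part of the diff):

input_shape = [None, 2, 3]
ndim = len(input_shape)                        # 3

# Sorting by the equivalent positive index turns (-1, 1) into (1, -1), so
# mean_and_var_shape comes out in axis order without any transposing.
axis = tuple(sorted((-1, 1), key=lambda a: a if a >= 0 else ndim + a))

broadcast_shape = [1] * ndim
mean_and_var_shape = []
for i in axis:
  mean_and_var_shape.append(input_shape[i])
  broadcast_shape[i] = input_shape[i]

print(axis)                # (1, -1)
print(mean_and_var_shape)  # [2, 3]
print(broadcast_shape)     # [1, 2, 3]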
@@ -179,11 +199,13 @@ class _NormalizingCombiner(Combiner):
     if values.ndim == 1:
       values = np.expand_dims(values, 1)
 
+    # `np.delete` ignores negative indexes, so use a mask to delete items.
+    axis_mask = np.ones([values.ndim], dtype=bool)
+    axis_mask[np.array(self.axis, dtype=np.int32)] = False
+
     # This is the shape of all reduced axes (not specified in 'axis').
-    if self.axis is None:
-      reduction_counts = values.shape
-    else:
-      reduction_counts = np.delete(values.shape, self.axis)
+    reduction_counts = np.array(values.shape)[axis_mask]
 
     # We get the number of elements that will be reduced by multiplying all
     # values of 'shape' corresponding to the reduced axes.
     count = np.prod(reduction_counts, dtype=np.int64)
@@ -191,10 +213,7 @@ class _NormalizingCombiner(Combiner):
     # We want to reduce across dimensions except those specified in 'axis'
     # when using np.mean or np.variance; create the tuple of axes to reduce
     # over here.
-    if self.axis is None:
-      reduction_axes = None
-    else:
-      reduction_axes = tuple(np.delete(range(values.ndim), self.axis))
+    reduction_axes = tuple(np.arange(values.ndim)[axis_mask])
 
     mean = np.mean(values, axis=reduction_axes, dtype=np.float64)
     variance = np.var(values, axis=reduction_axes, dtype=np.float64)
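The mask is the heart of this change: per the commit's own comment, `np.delete` ignores negative indexes, so an axis like `(-1,)` was not removed from the reduction. A standalone numpy sketch of the masked reduction (illustrative shapes and values):

import numpy as np

values = np.zeros([4, 2, 3])
axis = (-1,)  # keep the features axis, given as a negative index

axis_mask = np.ones([values.ndim], dtype=bool)
axis_mask[np.array(axis, dtype=np.int32)] = False  # negative indexes wrap here

reduction_axes = tuple(np.arange(values.ndim)[axis_mask])  # (0, 1)
reduction_counts = np.array(values.shape)[axis_mask]       # [4, 2]
count = np.prod(reduction_counts, dtype=np.int64)          # 8 elements reduced

mean = np.mean(values, axis=reduction_axes, dtype=np.float64)  # shape (3,)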


@@ -275,6 +275,49 @@ class NormalizationTest(keras_parameterized.TestCase,
     if context.executing_eagerly():
       self.assertAllClose(output.numpy(), [[-1], [1], [-1], [1]])
 
+  @parameterized.parameters(
+      {"axis": 0},
+      {"axis": (-1, 0)},
+  )
+  def test_zeros_fail_init(self, axis):
+    cls = get_layer_class()
+    with self.assertRaisesRegex(ValueError,
+                                "The argument 'axis' may not be 0."):
+      cls(axis=axis)
+
+  @parameterized.parameters(
+      # Out of bounds
+      {"axis": 3},
+      {"axis": -3},
+      # In a tuple
+      {"axis": (1, 3)},
+      {"axis": (1, -3)},
+  )
+  def test_bad_axis_fail_build(self, axis):
+    cls = get_layer_class()
+    layer = cls(axis=axis)
+    with self.assertRaisesRegex(ValueError,
+                                r"in the range \[1-ndim, ndim-1\]."):
+      layer.build([None, 2, 3])
+
+  @parameterized.parameters(
+      # Results should be identical no matter how the axes are specified (3d).
+      {"axis": (1, 2)},
+      {"axis": (2, 1)},
+      {"axis": (1, -1)},
+      {"axis": (-1, 1)},
+  )
+  def test_axis_permutations(self, axis):
+    cls = get_layer_class()
+    layer = cls(axis=axis)
+    # data.shape = [2, 2, 3]
+    data = np.array([[[0., 1., 2.], [0., 2., 6.]],
+                     [[2., 3., 4.], [3., 6., 10.]]])
+    expect = np.array([[[-1., -1., -1.], [-1., -1., -1.]],
+                       [[1., 1., 1.], [1., 1., 1.]]])
+    layer.adapt(data)
+    self.assertAllClose(expect, layer(data))
+
 
 if __name__ == "__main__":
   test.main()
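As a sanity check on `test_axis_permutations`: keeping axes (1, 2) reduces only over the batch axis 0, and the expected tensor can be reproduced directly with numpy (a quick verification, not part of the diff):

import numpy as np

data = np.array([[[0., 1., 2.], [0., 2., 6.]],
                 [[2., 3., 4.], [3., 6., 10.]]])
mean = data.mean(axis=0)    # [[1., 2., 3.], [1.5, 4., 8.]]
std = data.std(axis=0)      # [[1., 1., 1.], [1.5, 2., 2.]]
print((data - mean) / std)  # -1 everywhere in batch 0, +1 in batch 1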